Mirror of https://github.com/arc53/DocsGPT.git (synced 2026-05-07 06:30:03 +00:00)

Compare commits: auto-chunk...feat-bring (863 commits)
@@ -1,9 +1,39 @@
API_KEY=<LLM api key (for example, open ai key)>
LLM_NAME=docsgpt
VITE_API_STREAMING=true
INTERNAL_KEY=<internal key for worker-to-backend authentication>

# Provider-specific API keys (optional - use these to enable multiple providers)
# OPENAI_API_KEY=<your-openai-api-key>
# ANTHROPIC_API_KEY=<your-anthropic-api-key>
# GOOGLE_API_KEY=<your-google-api-key>
# GROQ_API_KEY=<your-groq-api-key>
# NOVITA_API_KEY=<your-novita-api-key>
# OPEN_ROUTER_API_KEY=<your-openrouter-api-key>

# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
EMBEDDINGS_BASE_URL=
EMBEDDINGS_KEY=

#For Azure (you can delete it if you don't use Azure)
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=

#Azure AD Application (client) ID
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
#Azure AD Application client secret
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
#Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
#If you are using a Microsoft Entra ID tenant,
#configure the AUTHORITY variable as
#"https://login.microsoftonline.com/TENANT_GUID"
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
#Alternatively, use "https://login.microsoftonline.com/common" for a multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}


# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
2 .gitattributes vendored Normal file
@@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto
99 .github/INCIDENT_RESPONSE.md vendored Normal file
@@ -0,0 +1,99 @@
# DocsGPT Incident Response Plan (IRP)

This playbook describes how maintainers respond to confirmed or suspected security incidents.

- Vulnerability reporting: [`SECURITY.md`](../SECURITY.md)
- Non-security bugs/features: [`CONTRIBUTING.md`](../CONTRIBUTING.md)

## Severity

| Severity | Definition | Typical examples |
|---|---|---|
| **Critical** | Active exploitation, supply-chain compromise, or confirmed data breach requiring immediate user action. | Compromised release artifact/image; remote execution. |
| **High** | Serious undisclosed vulnerability with no practical workaround, or CVSS >= 7.0. | Key leakage; prompt injection enabling cross-tenant access. |
| **Medium** | Material impact but constrained by preconditions/scope, or a practical workaround exists. | Auth-required exploit; dependency CVE with limited reachability. |
| **Low** | Defense-in-depth or narrow availability impact with no confirmed data exposure. | Missing rate limiting; hardening gap without exploit evidence. |

## Response workflow

### 1) Triage (target: initial response within 48 hours)

1. Acknowledge report.
2. Validate on latest release and `main`.
3. Confirm in-scope security issue vs. hardening item (per `SECURITY.md`).
4. Assign severity and open a **draft GitHub Security Advisory (GHSA)** (no public issue).
5. Determine whether the root cause is DocsGPT code or an upstream dependency/provider.

### 2) Investigation

1. Identify affected components, versions, and deployment scope (self-hosted, cloud, or both).
2. For AI issues, explicitly evaluate prompt injection, document isolation, and output leakage.
3. Request a CVE through GHSA for **Medium+** issues.

### 3) Containment, fix, and disclosure

1. Implement and test the fix in a private security workflow (GHSA private fork/branch).
2. Merge the fix to `main`, cut a patched release, and verify published artifacts/images.
3. Patch the managed cloud deployment (`app.docsgpt.cloud`) and other deployments as soon as validated.
4. Publish the GHSA with CVE (if assigned), affected/fixed versions, CVSS, mitigations, and upgrade guidance.
5. **Critical/High:** coordinate disclosure timing with the reporter (goal: <= 90 days) and publish a notice.
6. **Medium/Low:** include in the next scheduled release unless risk requires immediate out-of-band patching.

### 4) Post-incident

1. Monitor support channels (GitHub/Discord) for regressions or exploitation reports.
2. Run a short retrospective (root cause, detection, response gaps, prevention work).
3. Track follow-up hardening actions with owners/dates.
4. Update this IRP and related runbooks as needed.

## Scenario playbooks

### Supply-chain compromise

1. Freeze releases and investigate the blast radius.
2. Rotate credentials in order: Docker Hub -> GitHub tokens -> LLM provider keys -> DB credentials -> `JWT_SECRET_KEY` -> `ENCRYPTION_SECRET_KEY` -> `INTERNAL_KEY`.
3. Replace compromised artifacts/tags with clean releases and revoke/remove bad tags where possible.
4. Publish an advisory with exact affected versions and required user actions.

### Data exposure

1. Determine scope (users, documents, keys, logs, time window).
2. Disable the affected path or hotfix immediately for managed cloud.
3. Notify affected users with concrete remediation steps (for example, rotate keys).
4. Continue through the standard fix/disclosure workflow.

### Critical regression with security impact

1. Identify the introducing change (`git bisect` if needed).
2. Publish a workaround within 24 hours (for example, pin to a known-good version).
3. Ship a patch release with a regression test and close the incident with a public summary.

## AI-specific guidance

Treat confirmed AI-specific abuse as security incidents:

- Prompt injection causing sensitive data exfiltration (from tools that don't belong to the agent) -> **High**
- Cross-tenant retrieval/isolation failure -> **High**
- API key disclosure in output -> **High**

## Secret rotation quick reference

| Secret | Standard rotation action |
|---|---|
| Docker Hub credentials | Revoke/replace in Docker Hub; update CI/CD secrets |
| GitHub tokens/PATs | Revoke/replace in GitHub; update automation secrets |
| LLM provider API keys | Rotate in provider console; update runtime/deploy secrets |
| Database credentials | Rotate in DB platform; redeploy with new secrets |
| `JWT_SECRET_KEY` | Rotate and redeploy (invalidates all active user sessions/tokens) |
| `ENCRYPTION_SECRET_KEY` | Rotate and redeploy (re-encrypt stored data if possible; existing encrypted data may become inaccessible) |
| `INTERNAL_KEY` | Rotate and redeploy (invalidates worker-to-backend authentication) |

## Maintenance

Review this document:

- after every **Critical/High** incident, and
- at least annually.

Changes should be proposed via pull request to `main`.
144 .github/THREAT_MODEL.md vendored Normal file
@@ -0,0 +1,144 @@
# DocsGPT Public Threat Model

**Classification:** Public
**Last updated:** 2026-04-15
**Applies to:** Open-source and self-hosted DocsGPT deployments

## 1) Overview

DocsGPT ingests content (files/URLs/connectors), indexes it, and answers queries via LLM-backed APIs and optional tools.

Core components:
- Backend API (`application/`)
- Workers/ingestion (`application/worker.py` and related modules)
- Datastores (MongoDB/Redis/vector stores)
- Frontend (`frontend/`)
- Optional extensions/integrations (`extensions/`)

## 2) Scope and assumptions

In scope:
- Application-level threats in this repository.
- Local and internet-exposed self-hosted deployments.

Assumptions:
- Internet-facing instances enable auth and use strong secrets.
- Datastores/internal services are not publicly exposed.

Out of scope:
- Cloud hardware/provider compromise.
- Security guarantees of external LLM vendors.
- Full security audits of third-party systems targeted by tools (external DBs/MCP servers/code-exec APIs).

## 3) Security objectives

- Protect document/conversation confidentiality.
- Preserve integrity of prompts, agents, tools, and indexed data.
- Maintain API/worker availability.
- Enforce tenant isolation in authenticated deployments.

## 4) Assets

- Documents, attachments, chunks/embeddings, summaries.
- Conversations, agents, workflows, prompt templates.
- Secrets (JWT secret, `INTERNAL_KEY`, provider/API/OAuth credentials).
- Operational capacity (worker throughput, queue depth, model quota/cost).

## 5) Trust boundaries and untrusted input

Trust boundaries:
- Internet ↔ Frontend
- Frontend ↔ Backend API
- Backend ↔ Workers/internal APIs
- Backend/workers ↔ Datastores
- Backend ↔ External LLM/connectors/remote URLs

Untrusted input includes API payloads, file uploads, remote URLs, OAuth/webhook data, retrieved content, and LLM/tool arguments.

## 6) Main attack surfaces

1. Auth/authz paths and sharing tokens.
2. File upload + parsing pipeline.
3. Remote URL fetching and connectors (SSRF risk).
4. Agent/tool execution from LLM output.
5. Template/workflow rendering.
6. Frontend rendering + token storage.
7. Internal service endpoints (`INTERNAL_KEY`).
8. High-impact integrations (SQL tool, generic API tool, remote MCP tools).

## 7) Key threats and expected mitigations

### A. Auth/authz misconfiguration
- Threat: weak/no auth or leaked tokens lead to broad data access.
- Mitigations: require auth for public deployments, short-lived tokens, rotation/revocation, least-privilege sharing.

### B. Untrusted file ingestion
- Threat: malicious files/archives trigger traversal, parser exploits, or resource exhaustion.
- Mitigations: strict path checks, archive safeguards, file limits, patched parser dependencies.

### C. SSRF/outbound abuse
- Threat: URL loaders/tools access private/internal/metadata endpoints.
- Mitigations: validate URLs + redirects, block private/link-local ranges, apply egress controls/allowlists.
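A minimal sketch of the private-range check called for above, using only the Python standard library. The helper name and the resolve-then-check flow are illustrative, not DocsGPT's actual loader code; a production check also has to re-validate every redirect hop and the address actually connected to.

```python
import ipaddress
import socket
from urllib.parse import urlparse


def is_safe_public_url(url: str) -> bool:
    """Reject URLs whose host resolves to private, loopback, link-local, or reserved ranges."""
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.hostname:
        return False
    try:
        infos = socket.getaddrinfo(parsed.hostname, None)
    except socket.gaierror:
        return False
    for info in infos:
        # Strip any IPv6 zone id before parsing the resolved address.
        addr = ipaddress.ip_address(info[4][0].split("%")[0])
        if addr.is_private or addr.is_loopback or addr.is_link_local or addr.is_reserved:
            return False
    return True
```

Redirects need the same check on every hop, which is why the mitigation above calls out redirect validation separately from the initial URL check.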
### D. Prompt injection + tool abuse
- Threat: retrieved text manipulates model behavior and causes unsafe tool calls.
- Principle: never rely on the model to "choose correctly" under adversarial input.
- Mitigations: treat retrieved/model output as untrusted, enforce tool policies, only expose tools explicitly assigned by the user/admin to that agent, separate system instructions from retrieved content, audit tool calls.

### E. Dangerous tool capability chaining (SQL/API/MCP)
- Threat: write-capable SQL credentials allow destructive queries.
- Threat: the API tool can trigger side effects (infra/payment/webhook/code-exec endpoints).
- Threat: remote MCP tools may expose privileged operations.
- Mitigations: read-only-by-default credentials, destination allowlists, explicit approval for write/exec actions, per-tool policy enforcement + logging.

### F. Frontend/XSS + token theft
- Threat: XSS can steal local tokens and call APIs.
- Mitigations: reduce unsafe rendering paths, strong CSP, scoped short-lived credentials.

### G. Internal endpoint exposure
- Threat: weak/unset `INTERNAL_KEY` enables internal API abuse.
- Mitigations: fail closed, require strong random keys, keep internal APIs private.

### H. DoS and cost abuse
- Threat: request floods, large ingestion jobs, expensive prompts/crawls.
- Mitigations: rate limits, quotas, timeouts, queue backpressure, usage budgets.

## 8) Example attacker stories

- Internet-exposed deployment runs with weak/no auth and receives unauthorized data access/abuse.
- Intranet deployment intentionally using weak/no auth is vulnerable to insider misuse and lateral-movement abuse.
- Crafted archive attempts path traversal during extraction.
- Malicious URL/redirect chain targets internal services.
- Poisoned document causes data exfiltration through tool calls.
- Over-privileged SQL/API/MCP tool performs destructive side effects.

## 9) Severity calibration

- **Critical:** unauthenticated public data access; prompt-injection-driven exfiltration; SSRF to sensitive internal endpoints.
- **High:** cross-tenant leakage, persistent token compromise, over-privileged destructive tools.
- **Medium:** DoS/cost amplification and non-critical information disclosure.
- **Low:** minor hardening gaps with limited impact.

## 10) Baseline controls for public deployments

1. Enforce authentication and secure defaults.
2. Set/rotate strong secrets (`JWT`, `INTERNAL_KEY`, encryption keys); see the sketch after this list.
3. Restrict CORS and front the API with a hardened proxy.
4. Add rate limiting/quotas for answer/upload/crawl/token endpoints.
5. Enforce URL+redirect SSRF protections and egress restrictions.
6. Apply upload/archive/parsing hardening.
7. Require least-privilege tool credentials and auditable tool execution.
8. Monitor auth failures, tool anomalies, ingestion spikes, and cost anomalies.
9. Keep dependencies/images patched and scanned.
10. Validate multi-tenant isolation with explicit tests.
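For item 2, one quick way to generate suitably strong values is Python's `secrets` module; the variable names below simply mirror the keys used elsewhere in this repository, and nothing here is DocsGPT-specific tooling.

```python
import secrets

# token_urlsafe(32) yields a ~43-character string with 256 bits of entropy.
for name in ("JWT_SECRET_KEY", "ENCRYPTION_SECRET_KEY", "INTERNAL_KEY"):
    print(f"{name}={secrets.token_urlsafe(32)}")
```

Rotating these still has the side effects listed in the incident response plan's rotation table, for example `JWT_SECRET_KEY` invalidates all active sessions.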
## 11) Maintenance

Review this model after major auth, ingestion, connector, tool, or workflow changes.

## References

- [OWASP Top 10 for LLM Applications](https://owasp.org/www-project-top-10-for-large-language-model-applications/)
- [OWASP ASVS](https://owasp.org/www-project-application-security-verification-standard/)
- [STRIDE overview](https://learn.microsoft.com/azure/security/develop/threat-modeling-tool-threats)
- [DocsGPT SECURITY.md](../SECURITY.md)
6 .github/dependabot.yml vendored
@@ -13,7 +13,11 @@ updates:
    directory: "/frontend" # Location of package manifests
    schedule:
      interval: "daily"
  - package-ecosystem: "npm"
    directory: "/extensions/react-widget"
    schedule:
      interval: "daily"
  - package-ecosystem: "github-actions"
    directory: "/"
    schedule:
      interval: "daily"
      interval: "daily"
11 .github/styles/DocsGPT/Spelling.yml vendored Normal file
@@ -0,0 +1,11 @@
extends: spelling
level: warning
message: "Did you really mean '%s'?"
ignore:
  - "**/node_modules/**"
  - "**/dist/**"
  - "**/build/**"
  - "**/coverage/**"
  - "**/public/**"
  - "**/static/**"
vocab: DocsGPT
80 .github/styles/config/vocabularies/DocsGPT/accept.txt vendored Normal file
@@ -0,0 +1,80 @@
Agentic
Anthropic's
api
APIs
Atlassian
automations
autoescaping
Autoescaping
backfill
backfills
bool
boolean
brave_web_search
chatbot
Chatwoot
config
configs
CSVs
dev
diarization
Docling
docsgpt
docstrings
Entra
env
enqueues
EOL
ESLint
feedbacks
Figma
GPUs
Groq
hardcode
hardcoding
Idempotency
JSONPath
kubectl
Lightsail
llama_cpp
llm
LLM
LLMs
LMDeploy
Milvus
Mixtral
namespace
namespaces
needs_auth
Nextra
Novita
npm
OAuth
Ollama
opencode
parsable
passthrough
PDFs
pgvector
Postgres
Premade
Pydantic
pytest
Qdrant
qdrant
Repo
repo
Sanitization
SDKs
SGLang
Shareability
Signup
Supabase
UIs
uncomment
URl
vectorstore
Vite
VSCode
VSCode's
widget's
3 .github/workflows/lint.yml vendored
@@ -7,6 +7,9 @@ on:
  pull_request:
    types: [ opened, synchronize ]

permissions:
  contents: read

jobs:
  ruff:
    runs-on: ubuntu-latest
114 .github/workflows/npm-publish.yml vendored Normal file
@@ -0,0 +1,114 @@
name: Publish npm libraries

on:
  workflow_dispatch:
    inputs:
      version:
        description: >
          Version bump type (patch | minor | major) or explicit semver (e.g. 1.2.3).
          Applies to both docsgpt and docsgpt-react.
        required: true
        default: patch

permissions:
  contents: write
  pull-requests: write

jobs:
  publish:
    runs-on: ubuntu-latest
    environment: npm-release
    defaults:
      run:
        working-directory: extensions/react-widget

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-node@v4
        with:
          node-version: 20
          registry-url: https://registry.npmjs.org

      - name: Install dependencies
        run: npm ci

      # ── docsgpt (HTML embedding bundle) ──────────────────────────────────
      # Uses the `build` script (parcel build src/browser.tsx) and keeps
      # the `targets` field so Parcel produces browser-optimised bundles.

      - name: Set package name → docsgpt
        run: jq --arg n "docsgpt" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json

      - name: Bump version (docsgpt)
        id: version_docsgpt
        run: |
          VERSION="${{ github.event.inputs.version }}"
          NEW_VER=$(npm version "${VERSION:-patch}" --no-git-tag-version)
          echo "version=${NEW_VER#v}" >> "$GITHUB_OUTPUT"

      - name: Build docsgpt
        run: npm run build

      - name: Publish docsgpt
        run: npm publish --verbose
        env:
          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

      # ── docsgpt-react (React library bundle) ─────────────────────────────
      # Uses `build:react` script (parcel build src/index.ts) and strips
      # the `targets` field so Parcel treats the output as a plain library
      # without browser-specific target resolution, producing a smaller bundle.

      - name: Reset package.json from source control
        run: git checkout -- package.json

      - name: Set package name → docsgpt-react
        run: jq --arg n "docsgpt-react" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json

      - name: Remove targets field (react library build)
        run: jq 'del(.targets)' package.json > _tmp.json && mv _tmp.json package.json

      - name: Bump version (docsgpt-react) to match docsgpt
        run: npm version "${{ steps.version_docsgpt.outputs.version }}" --no-git-tag-version

      - name: Clean dist before react build
        run: rm -rf dist

      - name: Build docsgpt-react
        run: npm run build:react

      - name: Publish docsgpt-react
        run: npm publish --verbose
        env:
          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

      # ── Commit the bumped version back to the repository ─────────────────

      - name: Reset package.json and write final version
        run: |
          git checkout -- package.json
          jq --arg v "${{ steps.version_docsgpt.outputs.version }}" '.version=$v' \
            package.json > _tmp.json && mv _tmp.json package.json
          npm install --package-lock-only

      - name: Commit version bump and create PR
        run: |
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"
          BRANCH="chore/bump-npm-v${{ steps.version_docsgpt.outputs.version }}"
          git checkout -b "$BRANCH"
          git add package.json package-lock.json
          git commit -m "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}"
          git push origin "$BRANCH"
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

      - name: Create PR
        run: |
          gh pr create \
            --title "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}" \
            --body "Automated version bump after npm publish." \
            --base main
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
10 .github/workflows/pytest.yml vendored
@@ -1,5 +1,9 @@
name: Run python tests with pytest
on: [push, pull_request]

permissions:
  contents: read

jobs:
  pytest_and_coverage:
    name: Run tests and count coverage
@@ -16,15 +20,15 @@ jobs:
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest pytest-cov
          cd application
          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
          cd ../tests
          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
      - name: Test with pytest and generate coverage report
        run: |
          python -m pytest --cov=application --cov-report=xml
          python -m pytest --cov=application --cov-report=xml --cov-report=term-missing
      - name: Upload coverage reports to Codecov
        if: github.event_name == 'pull_request' && matrix.python-version == '3.12'
        uses: codecov/codecov-action@v5
        env:
          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
34 .github/workflows/react-widget-build.yml vendored Normal file
@@ -0,0 +1,34 @@
name: React Widget Build

on:
  push:
    paths:
      - 'extensions/react-widget/**'
  pull_request:
    paths:
      - 'extensions/react-widget/**'

permissions:
  contents: read

jobs:
  build:
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: extensions/react-widget

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-node@v4
        with:
          node-version: 20
          cache: npm
          cache-dependency-path: extensions/react-widget/package-lock.json

      - name: Install dependencies
        run: npm ci

      - name: Build
        run: npm run build
34 .github/workflows/vale.yml vendored Normal file
@@ -0,0 +1,34 @@
name: Vale Documentation Linter

on:
  pull_request:
    paths:
      - 'docs/**/*.md'
      - 'docs/**/*.mdx'
      - '**/*.md'
      - '.vale.ini'
      - '.github/styles/**'

permissions:
  contents: read

jobs:
  vale:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install Vale
        run: |
          curl -fsSL -o vale.tar.gz \
            https://github.com/errata-ai/vale/releases/download/v3.0.5/vale_3.0.5_Linux_64-bit.tar.gz
          tar -xzf vale.tar.gz
          sudo mv vale /usr/local/bin/vale
          vale --version

      - name: Sync Vale packages
        run: vale sync

      - name: Run Vale
        run: vale --minAlertLevel=error docs
25 .github/workflows/zizmor.yml vendored Normal file
@@ -0,0 +1,25 @@
name: GitHub Actions Security Analysis

on:
  push:
    branches: ["master"]
  pull_request:
    branches: ["**"]

permissions: {}

jobs:
  zizmor:
    runs-on: ubuntu-latest

    permissions:
      security-events: write # Required for upload-sarif (used by zizmor-action) to upload SARIF files.

    steps:
      - name: Checkout repository
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false

      - name: Run zizmor 🌈
        uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # v0.5.2
19 .gitignore vendored
@@ -2,7 +2,10 @@
__pycache__/
*.py[cod]
*$py.class
results.txt
experiments/

experiments
# C extensions
*.so
*.next
@@ -69,6 +72,7 @@ instance/

# Sphinx documentation
docs/_build/
docs/public/_pagefind/

# PyBuilder
target/
@@ -104,6 +108,8 @@ celerybeat.pid
# Environments
.env
.venv
# Machine-specific Claude Code guidance (see CLAUDE.md preamble)
CLAUDE.md
env/
venv/
ENV/
@@ -145,6 +151,10 @@ frontend/yarn-error.log*
frontend/pnpm-debug.log*
frontend/lerna-debug.log*

# Keep frontend utility helpers tracked (overrides global lib/ ignore)
!frontend/src/lib/
!frontend/src/lib/**

frontend/node_modules
frontend/dist
frontend/dist-ssr
@@ -173,5 +183,14 @@ application/vectors/

node_modules/
.vscode/settings.json
.vscode/sftp.json
/models/
model/

# E2E test artifacts
.e2e-tmp/
/tmp/docsgpt-e2e/
tests/e2e/node_modules/
tests/e2e/playwright-report/
tests/e2e/test-results/
tests/e2e/.e2e-last-run.json
@@ -1,2 +1,6 @@
# Allow lines to be as long as 120 characters.
line-length = 120
line-length = 120

[lint.per-file-ignores]
# Integration tests use sys.path.insert() before imports for standalone execution
"tests/integration/*.py" = ["E402"]
.vale.ini
Normal file
7
.vale.ini
Normal file
@@ -0,0 +1,7 @@
|
||||
MinAlertLevel = warning
|
||||
StylesPath = .github/styles
|
||||
Vocab = DocsGPT
|
||||
|
||||
[*.{md,mdx}]
|
||||
BasedOnStyles = DocsGPT
|
||||
|
||||
33 .vscode/launch.json vendored
@@ -2,15 +2,11 @@
  "version": "0.2.0",
  "configurations": [
    {
      "name": "Docker Debug Frontend",
      "name": "Frontend Debug (npm)",
      "type": "node-terminal",
      "request": "launch",
      "type": "chrome",
      "preLaunchTask": "docker-compose: debug:frontend",
      "url": "http://127.0.0.1:5173",
      "webRoot": "${workspaceFolder}/frontend",
      "skipFiles": [
        "<node_internals>/**"
      ]
      "command": "npm run dev",
      "cwd": "${workspaceFolder}/frontend"
    },
    {
      "name": "Flask Debugger",
@@ -49,6 +45,27 @@
        "--pool=solo"
      ],
      "cwd": "${workspaceFolder}"
    },
    {
      "name": "Dev Containers (Mongo + Redis)",
      "type": "node-terminal",
      "request": "launch",
      "command": "docker compose -f deployment/docker-compose-dev.yaml up --build",
      "cwd": "${workspaceFolder}"
    }
  ],
  "compounds": [
    {
      "name": "DocsGPT: Full Stack",
      "configurations": [
        "Frontend Debug (npm)",
        "Flask Debugger",
        "Celery Debugger"
      ],
      "presentation": {
        "group": "DocsGPT",
        "order": 1
      }
    }
  ]
}
21 .vscode/tasks.json vendored
@@ -1,21 +0,0 @@
{
  "version": "2.0.0",
  "tasks": [
    {
      "type": "docker-compose",
      "label": "docker-compose: debug:frontend",
      "dockerCompose": {
        "up": {
          "detached": true,
          "services": [
            "frontend"
          ],
          "build": true
        },
        "files": [
          "${workspaceFolder}/docker-compose.yaml"
        ]
      }
    }
  ]
}
156 AGENTS.md Normal file
@@ -0,0 +1,156 @@
# AGENTS.md

- Read `CONTRIBUTING.md` before making non-trivial changes.
- For day-to-day development and feature work, follow the development-environment workflow rather than defaulting to `setup.sh` / `setup.ps1`.
- Avoid using the setup scripts during normal feature work unless the user explicitly asks for them. Users usually configure `.env` themselves.
- Try to follow red/green TDD.

### Check existing dev prerequisites first

For feature work, do **not** assume the environment needs to be recreated.

- Check whether the user already has a Python virtual environment such as `venv/` or `.venv/`.
- Check whether Postgres is already running and reachable via `POSTGRES_URI` (the canonical user-data store).
- Check whether Redis is already running.
- Reuse what is already working. Do not stop or recreate Postgres, Redis, or the Python environment unless the task is environment setup or troubleshooting.

> MongoDB is **not** required for the default install. It is only needed if
> the user opts into the Mongo vector-store backend (`VECTOR_STORE=mongodb`)
> or is running the one-shot `scripts/db/backfill.py` to migrate existing
> user data from the legacy Mongo-based install. In those cases, `pymongo`
> is available as an optional extra, not a core dependency.

## Normal local development commands

Use these commands once the dev prerequisites above are satisfied.

### Backend

```bash
source .venv/bin/activate  # macOS/Linux
uv pip install -r application/requirements.txt  # or: pip install -r application/requirements.txt
```

Run the Flask API (if needed):

```bash
flask --app application/app.py run --host=0.0.0.0 --port=7091
```

That's the fast inner-loop option — quick startup, the Werkzeug interactive
debugger still works, and it hot-reloads on source changes. It serves the
Flask routes only (`/api/*`, `/stream`, etc.).

If you need to exercise the full ASGI stack — the `/mcp` FastMCP endpoint,
or to match the production runtime exactly — run the ASGI composition under
uvicorn instead:

```bash
uvicorn application.asgi:asgi_app --host 0.0.0.0 --port 7091 --reload
```

Production uses `gunicorn -k uvicorn_worker.UvicornWorker` against the same
`application.asgi:asgi_app` target; see `application/Dockerfile` for the
full flag set.

Run the Celery worker in a separate terminal (if needed):

```bash
celery -A application.app.celery worker -l INFO
```

On macOS, prefer the solo pool for Celery:

```bash
python -m celery -A application.app.celery worker -l INFO --pool=solo
```

### Frontend

Install dependencies only when needed, then run the dev server:

```bash
cd frontend
npm install --include=dev
npm run dev
```

### Docs site

```bash
cd docs
npm install
```

### Python / backend changes validation

```bash
ruff check .
python -m pytest
```

### Frontend changes

```bash
cd frontend && npm run lint
cd frontend && npm run build
```

### Documentation changes

```bash
cd docs && npm run build
```

If Vale is installed locally and you edited prose, also run:

```bash
vale .
```

## Repository map

- `application/`: Flask backend, API routes, agent logic, retrieval, parsing, security, storage, Celery worker, and WSGI entrypoints.
- `tests/`: backend unit/integration tests and test-only Python dependencies.
- `frontend/`: Vite + React + TypeScript application.
- `frontend/src/`: main UI code, including `components`, `conversation`, `hooks`, `locale`, `settings`, `upload`, and Redux store wiring in `store.ts`.
- `docs/`: separate documentation site built with Next.js/Nextra.
- `extensions/`: integrations and widgets — currently the Chatwoot webhook bridge and the React widget (published to npm as `docsgpt`). The Discord bot, Slack bot, and Chrome extension have been moved to their own repos under `arc53/`.
- `deployment/`: Docker Compose variants and Kubernetes manifests.

## Coding rules

### Backend

- Follow PEP 8 and keep Python line length at or under 120 characters.
- Use type hints for function arguments and return values.
- Add Google-style docstrings to new or substantially changed functions and classes.
- Add or update tests under `tests/` for backend behavior changes.
- Keep changes narrow in `api`, `auth`, `security`, `parser`, `retriever`, and `storage` areas.

### Backend Abstractions

- LLM providers implement a common interface in `application/llm/`; add new providers by extending the base class (see the sketch after this list).
- Vector stores are abstracted in `application/vectorstore/`.
- Parsers live in `application/parser/` and handle different document formats in the ingestion stage.
- Agents and tools are in `application/agents/` and `application/agents/tools/`.
- Celery setup/config lives in `application/celery_init.py` and `application/celeryconfig.py`.
- Settings and env vars are managed via Pydantic in `application/core/settings.py`.
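A rough sketch of what a new provider under `application/llm/` might look like. The `BaseLLM` import path, constructor arguments, and `gen` method name are assumptions for illustration only; check the existing providers in that package for the real interface before copying this shape.

```python
# application/llm/acme.py  (hypothetical provider, for illustration only)
from application.llm.base import BaseLLM  # assumed base-class location


class AcmeLLM(BaseLLM):
    """Minimal wrapper around a fictional Acme completion API."""

    def __init__(self, api_key: str, model: str = "acme-small", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.api_key = api_key
        self.model = model

    def gen(self, prompt: str, **kwargs) -> str:
        """Return one completion for the prompt; real providers may stream or accept message lists."""
        raise NotImplementedError("call the provider SDK or HTTP API here")
```

The actual base class may also require a streaming method and registration with a provider factory, so treat this only as a starting shape.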
### Frontend

- Follow the existing ESLint + Prettier setup.
- Prefer small, reusable functional components and hooks.
- If shared state must be added, use Redux rather than introducing a new global state library.
- Avoid broad UI refactors unless the task explicitly asks for them.
- Do not re-create components that already exist in the app.

## PR readiness

Before opening a PR:

- run the relevant validation commands above
- confirm backend changes still work end-to-end after ingesting sample data when applicable
- clearly summarize user-visible behavior changes
- mention any config, dependency, or deployment implications
- ask the user to attach a screenshot or a video to the PR
@@ -22,6 +22,11 @@ Thank you for choosing to contribute to DocsGPT! We are all very grateful!

- We have a frontend built on React (Vite) and a backend in Python.

> **Required for every PR:** Please attach screenshots or a short screen
> recording that shows the working version of your changes. This makes the
> requirement visible to reviewers and helps them quickly verify what you are
> submitting.

Before creating issues, please check out how the latest version of our app looks and works by launching it via [Quickstart](https://github.com/arc53/DocsGPT#quickstart). The version on our live demo is slightly modified with login. Your issues should relate to the version you can launch via [Quickstart](https://github.com/arc53/DocsGPT#quickstart).

@@ -125,7 +130,7 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
```

9. **Submit a Pull Request (PR):**
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes and reference any related issues.
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes, reference any related issues, and attach screenshots or a screen recording showing the working version.

10. **Collaborate:**
- Be responsive to comments and feedback on your PR.
@@ -147,5 +152,5 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
Thank you for considering contributing to DocsGPT! 🙏

## Questions/collaboration
Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
# Thank you so much for considering contributing to DocsGPT! 🙏
39
HACKTOBERFEST.md
Normal file
39
HACKTOBERFEST.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# **🎉 Join the Hacktoberfest with DocsGPT and win a Free T-shirt for a meaningful PR! 🎉**
|
||||
|
||||
Welcome, contributors! We're excited to announce that DocsGPT is participating in Hacktoberfest. Get involved by submitting meaningful pull requests.
|
||||
|
||||
All Meaningful contributors with accepted PRs that were created for issues with the `hacktoberfest` label (set by our maintainer team: dartpain, siiddhantt, pabik, ManishMadan2882) will receive a cool T-shirt! 🤩.
|
||||
<img width="1331" height="678" alt="hacktoberfest-mocks-preview" src="https://github.com/user-attachments/assets/633f6377-38db-48f5-b519-a8b3855a9eb4" />
|
||||
|
||||
Fill in [this form](https://forms.gle/Npaba4n9Epfyx56S8
|
||||
) after your PR was merged please
|
||||
|
||||
If you are in doubt don't hesitate to ping us on discord, ping me - Alex (dartpain).
|
||||
|
||||
## 📜 Here's How to Contribute:
|
||||
```text
|
||||
🛠️ Code: This is the golden ticket! Make meaningful contributions through PRs.
|
||||
|
||||
🧩 API extension: Build an app utilising DocsGPT API. We prefer submissions that showcase original ideas and turn the API into an AI agent.
|
||||
They can be a completely separate repos.
|
||||
For example:
|
||||
https://github.com/arc53/tg-bot-docsgpt-extenstion or
|
||||
https://github.com/arc53/DocsGPT-cli
|
||||
|
||||
Non-Code Contributions:
|
||||
|
||||
📚 Wiki: Improve our documentation, create a guide.
|
||||
|
||||
🖥️ Design: Improve the UI/UX or design a new feature.
|
||||
```
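For the API extension track, here is a minimal sketch of calling the DocsGPT answer API from Python. It assumes the hosted endpoint `https://gptcloud.arc53.com/api/answer` and a JSON payload with `question`, `api_key`, and `history` fields; please verify the exact route and fields against the current [Documentation](https://docs.docsgpt.cloud/) before building on it.

```python
# Hypothetical sketch of a tiny DocsGPT API extension.
# The endpoint URL and payload fields are assumptions; check the API docs.
import requests

API_URL = "https://gptcloud.arc53.com/api/answer"


def ask_docsgpt(question: str, api_key: str) -> str:
    payload = {
        "question": question,
        "api_key": api_key,  # key generated in the DocsGPT UI
        "history": [],       # previous turns, if any
    }
    response = requests.post(API_URL, json=payload, timeout=60)
    response.raise_for_status()
    return response.json().get("answer", "")


if __name__ == "__main__":
    print(ask_docsgpt("How do I run DocsGPT locally?", "YOUR_API_KEY"))
```

Original ideas should go well beyond this, of course; the example repos above show fuller agents built on the same idea.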
|
||||
|
||||
### 📝 Guidelines for Pull Requests:
|
||||
- Familiarize yourself with the current contributions and our [Roadmap](https://github.com/orgs/arc53/projects/2).
|
||||
- Before contributing, check existing [issues](https://github.com/arc53/DocsGPT/issues) or [create](https://github.com/arc53/DocsGPT/issues/new/choose) one and wait to get assigned.
|
||||
- Once you are finished with your contribution, please fill in this [form](https://forms.gle/Npaba4n9Epfyx56S8).
|
||||
- Refer to the [Documentation](https://docs.docsgpt.cloud/).
|
||||
- Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj) server. We're here to help newcomers, so don't hesitate to jump in!
|
||||
|
||||
Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your meaningful contributions (simple typo fixes don't count) could earn you a stylish new t-shirt.
|
||||
|
||||
We will publish the t-shirt design later in October.
|
||||
57
README.md
@@ -3,11 +3,11 @@
|
||||
</h1>
|
||||
|
||||
<p align="center">
|
||||
<strong>Open-Source RAG Assistant</strong>
|
||||
<strong>Private AI for agents, assistants and enterprise search</strong>
|
||||
</p>
|
||||
|
||||
<p align="left">
|
||||
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source genAI tool that helps users get reliable answers from any knowledge source, while avoiding hallucinations. It enables quick and reliable information retrieval, with tooling and agentic system capability built in.
|
||||
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. It features an Agent Builder, deep research tools, document analysis (PDF, Office files, web content, and audio), multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
@@ -16,23 +16,27 @@
|
||||
<a href="https://github.com/arc53/DocsGPT"></a>
|
||||
<a href="https://github.com/arc53/DocsGPT/blob/main/LICENSE"></a>
|
||||
<a href="https://www.bestpractices.dev/projects/9907"><img src="https://www.bestpractices.dev/projects/9907/badge"></a>
|
||||
<a href="https://discord.gg/n5BX8dh8rU"></a>
|
||||
<a href="https://twitter.com/docsgptai"></a>
|
||||
<a href="https://discord.gg/vN7YFfdMpj"></a>
|
||||
<a href="https://x.com/docsgptai"></a>
|
||||
|
||||
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
|
||||
<br>
|
||||
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a> • <a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a> • <a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
|
||||
<br>
|
||||
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/vN7YFfdMpj">💬 Discord</a>
|
||||
<br>
|
||||
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a> • <a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a> • <a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
|
||||
<br>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<div align="center">
|
||||
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
|
||||
<br>
|
||||
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demo-26.gif" alt="video-example-of-docs-gpt" width="800" height="480">
|
||||
</div>
|
||||
<h3 align="left">
|
||||
<strong>Key Features:</strong>
|
||||
</h3>
|
||||
<ul align="left">
|
||||
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, and images.</li>
|
||||
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, images, and audio files such as MP3, WAV, M4A, OGG, and WebM.</li>
|
||||
<li><strong>🎙️ Speech Workflows:</strong> Record voice input into chat, transcribe audio on the backend, and ingest meeting recordings or voice notes as searchable knowledge.</li>
|
||||
<li><strong>🌐 Web & Data Integration:</strong> Ingests from URLs, sitemaps, Reddit, GitHub and web crawlers.</li>
|
||||
<li><strong>✅ Reliable Answers:</strong> Get accurate, hallucination-free responses with source citations viewable in a clean UI.</li>
|
||||
<li><strong>🔑 Streamlined API Keys:</strong> Generate keys linked to your settings, documents, and models, simplifying chatbot and integration setup.</li>
|
||||
@@ -43,18 +47,11 @@
|
||||
</ul>
|
||||
|
||||
## Roadmap
|
||||
|
||||
- [x] Full GoogleAI compatibility (Jan 2025)
|
||||
- [x] Add tools (Jan 2025)
|
||||
- [x] Manually updating chunks in the app UI (Feb 2025)
|
||||
- [x] Devcontainer for easy development (Feb 2025)
|
||||
- [x] ReACT agent (March 2025)
|
||||
- [x] Chatbots menu re-design to handle tools, agent types, and more (April 2025)
|
||||
- [x] New input box in the conversation menu (April 2025)
|
||||
- [x] Add triggerable actions / tools (webhook) (April 2025)
|
||||
- [ ] Anthropic Tool compatibility (May 2025)
|
||||
- [ ] Add OAuth 2.0 authentication for tools and sources
|
||||
- [ ] Agent scheduling
|
||||
- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
|
||||
- [x] Deep Agents ( October 2025 )
|
||||
- [x] Prompt Templating ( October 2025 )
|
||||
- [x] Full api tooling ( Dec 2025 )
|
||||
- [ ] Agent scheduling ( Jan 2026 )
|
||||
|
||||
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
|
||||
|
||||
@@ -68,11 +65,10 @@ We're eager to provide personalized assistance when deploying your DocsGPT to a
|
||||
|
||||
## Join the Lighthouse Program 🌟
|
||||
|
||||
Calling all developers and GenAI innovators! The **DocsGPT Lighthouse Program** connects technical leaders actively deploying or extending DocsGPT in real-world scenarios. Collaborate directly with our team to shape the roadmap, access priority support, and build enterprise-ready solutions with exclusive community insights.
|
||||
Calling all developers and GenAI innovators! The **DocsGPT Lighthouse Program** connects technical leaders actively deploying or extending DocsGPT in real-world scenarios. Collaborate directly with our team to shape the roadmap, access priority support, and build enterprise-ready solutions with exclusive community insights.
|
||||
|
||||
[Learn More & Apply →](https://docs.google.com/forms/d/1KAADiJinUJ8EMQyfTXUIGyFbqINNClNR3jBNWq7DgTE)
|
||||
|
||||
|
||||
## QuickStart
|
||||
|
||||
> [!Note]
|
||||
@@ -103,7 +99,7 @@ A more detailed [Quickstart](https://docs.docsgpt.cloud/quickstart) is available
|
||||
PowerShell -ExecutionPolicy Bypass -File .\setup.ps1
|
||||
```
|
||||
|
||||
Either script will guide you through setting up DocsGPT. Four options available: using the public API, running locally, connecting to a local inference engine, or using a cloud API provider. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
|
||||
Either script will guide you through setting up DocsGPT. Five options are available: using the public API, running locally, connecting to a local inference engine, using a cloud API provider, or building the Docker image locally. The scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
|
||||
|
||||
**Navigate to http://localhost:5173/**
|
||||
|
||||
@@ -112,6 +108,7 @@ To stop DocsGPT, open a terminal in the `DocsGPT` directory and run:
|
||||
```bash
|
||||
docker compose -f deployment/docker-compose.yaml down
|
||||
```
|
||||
|
||||
(or use the specific `docker compose down` command shown after running the setup script).
|
||||
|
||||
> [!Note]
|
||||
@@ -139,7 +136,6 @@ Please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file for information abou
|
||||
|
||||
We as members, contributors, and leaders, pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. Please refer to the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file for more information about contributing.
|
||||
|
||||
|
||||
## Many Thanks To Our Contributors⚡
|
||||
|
||||
<a href="https://github.com/arc53/DocsGPT/graphs/contributors" alt="View Contributors">
|
||||
@@ -150,9 +146,16 @@ We as members, contributors, and leaders, pledge to make participation in our co
|
||||
|
||||
The source code license is [MIT](https://opensource.org/license/mit/), as described in the [LICENSE](LICENSE) file.
|
||||
|
||||
<p>This project is supported by:</p>
|
||||
## This project is supported by:
|
||||
|
||||
<p>
|
||||
<a href="https://www.digitalocean.com/?utm_medium=opensource&utm_source=DocsGPT">
|
||||
<img src="https://opensource.nyc3.cdn.digitaloceanspaces.com/attribution/assets/SVG/DO_Logo_horizontal_blue.svg" width="201px">
|
||||
</a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://get.neon.com/docsgpt">
|
||||
<img width="201" alt="color" src="https://github.com/user-attachments/assets/7d9813b7-0e6d-403f-b5af-68af066b326f" />
|
||||
</a>
|
||||
|
||||
</p>
|
||||
|
||||
18
SECURITY.md
@@ -2,13 +2,21 @@
|
||||
|
||||
## Supported Versions
|
||||
|
||||
Supported Versions:
|
||||
|
||||
Currently, we support security patches by committing changes and bumping the version published on Github.
|
||||
Security patches target the latest release and the `main` branch. We recommend always running the most recent version.
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
Found a vulnerability? Please email us:
|
||||
Preferred method: use GitHub's private vulnerability reporting flow:
|
||||
https://github.com/arc53/DocsGPT/security
|
||||
|
||||
security@arc53.com
|
||||
Then click **Report a vulnerability**.
|
||||
|
||||
|
||||
Alternatively, email us at: security@arc53.com
|
||||
|
||||
We aim to acknowledge reports within 48 hours.
|
||||
|
||||
## Incident Handling
|
||||
|
||||
For the public incident response process, see [`INCIDENT_RESPONSE.md`](./.github/INCIDENT_RESPONSE.md). If you believe an active exploit is occurring, include **URGENT** in your report subject line.
|
||||
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
API_KEY=your_api_key
|
||||
EMBEDDINGS_KEY=your_api_key
|
||||
API_URL=http://localhost:7091
|
||||
FLASK_APP=application/app.py
|
||||
FLASK_DEBUG=true
|
||||
|
||||
#For OPENAI on Azure
|
||||
OPENAI_API_BASE=
|
||||
OPENAI_API_VERSION=
|
||||
AZURE_DEPLOYMENT_NAME=
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
@@ -7,7 +7,7 @@ RUN apt-get update && \
|
||||
apt-get install -y software-properties-common && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.12 python3.12-venv && \
|
||||
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Verify Python installation and setup symlink
|
||||
@@ -48,7 +48,12 @@ FROM ubuntu:24.04 as final
|
||||
RUN apt-get update && \
|
||||
apt-get install -y software-properties-common && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update && apt-get install -y --no-install-recommends python3.12 && \
|
||||
apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.12 \
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
poppler-utils \
|
||||
&& \
|
||||
ln -s /usr/bin/python3.12 /usr/bin/python && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
@@ -83,5 +88,15 @@ EXPOSE 7091
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
# Start Gunicorn
|
||||
CMD ["gunicorn", "-w", "1", "--timeout", "120", "--bind", "0.0.0.0:7091", "--preload", "application.wsgi:app"]
|
||||
CMD ["gunicorn", \
|
||||
"-w", "1", \
|
||||
"-k", "uvicorn_worker.UvicornWorker", \
|
||||
"--bind", "0.0.0.0:7091", \
|
||||
"--timeout", "180", \
|
||||
"--graceful-timeout", "120", \
|
||||
"--keep-alive", "5", \
|
||||
"--worker-tmp-dir", "/dev/shm", \
|
||||
"--max-requests", "1000", \
|
||||
"--max-requests-jitter", "100", \
|
||||
"--config", "application/gunicorn_conf.py", \
|
||||
"application.asgi:asgi_app"]
|
||||
|
||||
@@ -1,11 +1,20 @@
|
||||
import logging
|
||||
|
||||
from application.agents.agentic_agent import AgenticAgent
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.agents.react_agent import ReActAgent
|
||||
from application.agents.research_agent import ResearchAgent
|
||||
from application.agents.workflow_agent import WorkflowAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentCreator:
|
||||
agents = {
|
||||
"classic": ClassicAgent,
|
||||
"react": ReActAgent,
|
||||
"react": ClassicAgent, # backwards compat: react falls back to classic
|
||||
"agentic": AgenticAgent,
|
||||
"research": ResearchAgent,
|
||||
"workflow": WorkflowAgent,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
63
application/agents/agentic_agent.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import logging
|
||||
from typing import Dict, Generator, Optional
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.tools.internal_search import (
|
||||
INTERNAL_TOOL_ID,
|
||||
add_internal_search_tool,
|
||||
)
|
||||
from application.logging import LogContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgenticAgent(BaseAgent):
|
||||
"""Agent where the LLM controls retrieval via tools.
|
||||
|
||||
Unlike ClassicAgent which pre-fetches docs into the prompt,
|
||||
AgenticAgent gives the LLM an internal_search tool so it can
|
||||
decide when, what, and whether to search.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
retriever_config: Optional[Dict] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.retriever_config = retriever_config or {}
|
||||
|
||||
def _gen_inner(
|
||||
self, query: str, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
tools_dict = self.tool_executor.get_tools()
|
||||
add_internal_search_tool(tools_dict, self.retriever_config)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
# 4. Build messages (prompt has NO pre-fetched docs)
|
||||
messages = self._build_messages(self.prompt, query)
|
||||
|
||||
# 5. Call LLM — the handler manages the tool loop
|
||||
llm_response = self._llm_gen(messages, log_context)
|
||||
|
||||
yield from self._handle_response(
|
||||
llm_response, tools_dict, messages, log_context
|
||||
)
|
||||
|
||||
# 6. Collect sources from internal search tool results
|
||||
self._collect_internal_sources()
|
||||
|
||||
yield {"sources": self.retrieved_docs}
|
||||
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
|
||||
log_context.stacks.append(
|
||||
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||
)
|
||||
|
||||
def _collect_internal_sources(self):
|
||||
"""Collect retrieved docs from the cached InternalSearchTool instance."""
|
||||
cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
|
||||
tool = self.tool_executor._loaded_tools.get(cache_key)
|
||||
if tool and hasattr(tool, "retrieved_docs") and tool.retrieved_docs:
|
||||
self.retrieved_docs = tool.retrieved_docs
|
||||
@@ -1,17 +1,21 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Generator, List, Optional
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
from application.agents.llm_handler import get_llm_handler
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.llm.handlers.base import ToolCall
|
||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.logging import build_stack_data, log_activity, LogContext
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.core.settings import settings
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
@@ -19,242 +23,526 @@ class BaseAgent(ABC):
|
||||
self,
|
||||
endpoint: str,
|
||||
llm_name: str,
|
||||
gpt_model: str,
|
||||
model_id: str,
|
||||
api_key: str,
|
||||
agent_id: Optional[str] = None,
|
||||
user_api_key: Optional[str] = None,
|
||||
prompt: str = "",
|
||||
chat_history: Optional[List[Dict]] = None,
|
||||
retrieved_docs: Optional[List[Dict]] = None,
|
||||
decoded_token: Optional[Dict] = None,
|
||||
attachments: Optional[List[Dict]] = None,
|
||||
json_schema: Optional[Dict] = None,
|
||||
limited_token_mode: Optional[bool] = False,
|
||||
token_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["token_limit"],
|
||||
limited_request_mode: Optional[bool] = False,
|
||||
request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
|
||||
compressed_summary: Optional[str] = None,
|
||||
llm=None,
|
||||
llm_handler=None,
|
||||
tool_executor: Optional[ToolExecutor] = None,
|
||||
backup_models: Optional[List[str]] = None,
|
||||
model_user_id: Optional[str] = None,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.llm_name = llm_name
|
||||
self.gpt_model = gpt_model
|
||||
self.model_id = model_id
|
||||
self.api_key = api_key
|
||||
self.agent_id = agent_id
|
||||
self.user_api_key = user_api_key
|
||||
self.prompt = prompt
|
||||
self.decoded_token = decoded_token or {}
|
||||
self.user: str = decoded_token.get("sub")
|
||||
self.tool_config: Dict = {}
|
||||
self.user: str = self.decoded_token.get("sub")
|
||||
# BYOM-resolution scope: owner for shared agents, caller for
|
||||
# caller-owned BYOM, None for built-ins. Falls back to self.user
|
||||
# for worker/legacy callers that don't thread model_user_id.
|
||||
self.model_user_id = model_user_id
|
||||
self.tools: List[Dict] = []
|
||||
self.tool_calls: List[Dict] = []
|
||||
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
|
||||
self.llm = LLMCreator.create_llm(
|
||||
llm_name,
|
||||
api_key=api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
|
||||
if llm is not None:
|
||||
self.llm = llm
|
||||
else:
|
||||
self.llm = LLMCreator.create_llm(
|
||||
llm_name,
|
||||
api_key=api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
backup_models=backup_models,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
|
||||
# For BYOM, registry id (UUID) differs from upstream model id
|
||||
# (e.g. ``mistral-large-latest``). LLMCreator resolved this onto
|
||||
# the LLM instance; cache it for subsequent gen calls.
|
||||
self.upstream_model_id = (
|
||||
getattr(self.llm, "model_id", None) or model_id
|
||||
)
|
||||
self.llm_handler = get_llm_handler(llm_name)
|
||||
|
||||
self.retrieved_docs = retrieved_docs or []
|
||||
|
||||
if llm_handler is not None:
|
||||
self.llm_handler = llm_handler
|
||||
else:
|
||||
self.llm_handler = LLMHandlerCreator.create_handler(
|
||||
llm_name if llm_name else "default"
|
||||
)
|
||||
|
||||
# Tool executor — injected or created
|
||||
if tool_executor is not None:
|
||||
self.tool_executor = tool_executor
|
||||
else:
|
||||
self.tool_executor = ToolExecutor(
|
||||
user_api_key=user_api_key,
|
||||
user=self.user,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
self.attachments = attachments or []
|
||||
self.json_schema = None
|
||||
if json_schema is not None:
|
||||
try:
|
||||
self.json_schema = normalize_json_schema_payload(json_schema)
|
||||
except JsonSchemaValidationError as exc:
|
||||
logger.warning("Ignoring invalid JSON schema payload: %s", exc)
|
||||
self.limited_token_mode = limited_token_mode
|
||||
self.token_limit = token_limit
|
||||
self.limited_request_mode = limited_request_mode
|
||||
self.request_limit = request_limit
|
||||
self.compressed_summary = compressed_summary
|
||||
self.current_token_count = 0
|
||||
self.context_limit_reached = False
|
||||
|
||||
@log_activity()
|
||||
def gen(
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
|
||||
self, query: str, log_context: LogContext = None
|
||||
) -> Generator[Dict, None, None]:
|
||||
yield from self._gen_inner(query, retriever, log_context)
|
||||
yield from self._gen_inner(query, log_context)
|
||||
|
||||
@abstractmethod
|
||||
def _gen_inner(
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
||||
self, query: str, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
pass
|
||||
|
||||
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
agents_collection = db["agents"]
|
||||
tools_collection = db["user_tools"]
|
||||
def gen_continuation(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
tools_dict: Dict,
|
||||
pending_tool_calls: List[Dict],
|
||||
tool_actions: List[Dict],
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Resume generation after tool actions are resolved.
|
||||
|
||||
agent_data = agents_collection.find_one({"key": api_key or self.user_api_key})
|
||||
tool_ids = agent_data.get("tools", []) if agent_data else []
|
||||
Processes the client-provided *tool_actions* (approvals, denials,
|
||||
or client-side results), appends the resulting messages, then
|
||||
hands back to the LLM to continue the conversation.
|
||||
|
||||
tools = (
|
||||
tools_collection.find(
|
||||
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
|
||||
Args:
|
||||
messages: The saved messages array from the pause point.
|
||||
tools_dict: The saved tools dictionary.
|
||||
pending_tool_calls: The pending tool call descriptors from the pause.
|
||||
tool_actions: Client-provided actions resolving the pending calls.
|
||||
"""
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
actions_by_id = {a["call_id"]: a for a in tool_actions}
|
||||
|
||||
# Build a single assistant message containing all tool calls so
|
||||
# the message history matches the format LLM providers expect
|
||||
# (one assistant message with N tool_calls, followed by N tool results).
|
||||
tc_objects: List[Dict[str, Any]] = []
|
||||
for pending in pending_tool_calls:
|
||||
call_id = pending["call_id"]
|
||||
args = pending["arguments"]
|
||||
args_str = (
|
||||
json.dumps(args) if isinstance(args, dict) else (args or "{}")
|
||||
)
|
||||
if tool_ids
|
||||
else []
|
||||
)
|
||||
tools = list(tools)
|
||||
tools_by_id = {str(tool["_id"]): tool for tool in tools} if tools else {}
|
||||
|
||||
return tools_by_id
|
||||
|
||||
def _get_user_tools(self, user="local"):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
user_tools_collection = db["user_tools"]
|
||||
user_tools = user_tools_collection.find({"user": user, "status": True})
|
||||
user_tools = list(user_tools)
|
||||
tools_by_id = {str(tool["_id"]): tool for tool in user_tools}
|
||||
return tools_by_id
|
||||
|
||||
def _build_tool_parameters(self, action):
|
||||
params = {"type": "object", "properties": {}, "required": []}
|
||||
for param_type in ["query_params", "headers", "body", "parameters"]:
|
||||
if param_type in action and action[param_type].get("properties"):
|
||||
for k, v in action[param_type]["properties"].items():
|
||||
if v.get("filled_by_llm", True):
|
||||
params["properties"][k] = {
|
||||
key: value
|
||||
for key, value in v.items()
|
||||
if key != "filled_by_llm" and key != "value"
|
||||
}
|
||||
|
||||
params["required"].append(k)
|
||||
return params
|
||||
|
||||
def _prepare_tools(self, tools_dict):
|
||||
self.tools = [
|
||||
{
|
||||
tc_obj: Dict[str, Any] = {
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f"{action['name']}_{tool_id}",
|
||||
"description": action["description"],
|
||||
"parameters": self._build_tool_parameters(action),
|
||||
"name": pending["name"],
|
||||
"arguments": args_str,
|
||||
},
|
||||
}
|
||||
for tool_id, tool in tools_dict.items()
|
||||
if (
|
||||
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
|
||||
or (tool["name"] != "api_tool" and "actions" in tool)
|
||||
)
|
||||
for action in (
|
||||
tool["config"]["actions"].values()
|
||||
if tool["name"] == "api_tool"
|
||||
else tool["actions"]
|
||||
)
|
||||
if action.get("active", True)
|
||||
]
|
||||
if pending.get("thought_signature"):
|
||||
tc_obj["thought_signature"] = pending["thought_signature"]
|
||||
tc_objects.append(tc_obj)
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": tc_objects,
|
||||
})
|
||||
|
||||
# Now process each pending call and append tool result messages
|
||||
for pending in pending_tool_calls:
|
||||
call_id = pending["call_id"]
|
||||
args = pending["arguments"]
|
||||
action = actions_by_id.get(call_id)
|
||||
if not action:
|
||||
action = {
|
||||
"call_id": call_id,
|
||||
"decision": "denied",
|
||||
"comment": "No response provided",
|
||||
}
|
||||
|
||||
if action.get("decision") == "approved":
|
||||
# Execute the tool server-side
|
||||
tc = ToolCall(
|
||||
id=call_id,
|
||||
name=pending["name"],
|
||||
arguments=(
|
||||
json.dumps(args) if isinstance(args, dict) else args
|
||||
),
|
||||
)
|
||||
tool_gen = self._execute_tool_action(tools_dict, tc)
|
||||
tool_response = None
|
||||
while True:
|
||||
try:
|
||||
event = next(tool_gen)
|
||||
yield event
|
||||
except StopIteration as e:
|
||||
tool_response, _ = e.value
|
||||
break
|
||||
messages.append(
|
||||
self.llm_handler.create_tool_message(tc, tool_response)
|
||||
)
|
||||
|
||||
elif action.get("decision") == "denied":
|
||||
comment = action.get("comment", "")
|
||||
denial = (
|
||||
f"Tool execution denied by user. Reason: {comment}"
|
||||
if comment
|
||||
else "Tool execution denied by user."
|
||||
)
|
||||
tc = ToolCall(
|
||||
id=call_id, name=pending["name"], arguments=args
|
||||
)
|
||||
messages.append(
|
||||
self.llm_handler.create_tool_message(tc, denial)
|
||||
)
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": pending.get("tool_name", "unknown"),
|
||||
"call_id": call_id,
|
||||
"action_name": pending.get("llm_name", pending["name"]),
|
||||
"arguments": args,
|
||||
"status": "denied",
|
||||
},
|
||||
}
|
||||
|
||||
elif "result" in action:
|
||||
result = action["result"]
|
||||
result_str = (
|
||||
json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else result
|
||||
)
|
||||
tc = ToolCall(
|
||||
id=call_id, name=pending["name"], arguments=args
|
||||
)
|
||||
messages.append(
|
||||
self.llm_handler.create_tool_message(tc, result_str)
|
||||
)
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": pending.get("tool_name", "unknown"),
|
||||
"call_id": call_id,
|
||||
"action_name": pending.get("llm_name", pending["name"]),
|
||||
"arguments": args,
|
||||
"result": (
|
||||
result_str[:50] + "..."
|
||||
if len(result_str) > 50
|
||||
else result_str
|
||||
),
|
||||
"status": "completed",
|
||||
},
|
||||
}
|
||||
|
||||
# Resume the LLM loop with the updated messages
|
||||
llm_response = self._llm_gen(messages)
|
||||
yield from self._handle_response(
|
||||
llm_response, tools_dict, messages, None
|
||||
)
|
||||
|
||||
yield {"sources": self.retrieved_docs}
|
||||
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
|
||||
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
|
||||
|
||||
@property
|
||||
def tool_calls(self) -> List[Dict]:
|
||||
return self.tool_executor.tool_calls
|
||||
|
||||
@tool_calls.setter
|
||||
def tool_calls(self, value: List[Dict]):
|
||||
self.tool_executor.tool_calls = value
|
||||
|
||||
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
|
||||
return self.tool_executor._get_tools_by_api_key(api_key or self.user_api_key)
|
||||
|
||||
def _get_user_tools(self, user="local"):
|
||||
return self.tool_executor._get_user_tools(user)
|
||||
|
||||
def _build_tool_parameters(self, action):
|
||||
return self.tool_executor._build_tool_parameters(action)
|
||||
|
||||
def _prepare_tools(self, tools_dict):
|
||||
self.tools = self.tool_executor.prepare_tools_for_llm(tools_dict)
|
||||
|
||||
def _execute_tool_action(self, tools_dict, call):
|
||||
parser = ToolActionParser(self.llm.__class__.__name__)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
action_data = (
|
||||
tool_data["config"]["actions"][action_name]
|
||||
if tool_data["name"] == "api_tool"
|
||||
else next(
|
||||
action
|
||||
for action in tool_data["actions"]
|
||||
if action["name"] == action_name
|
||||
)
|
||||
return self.tool_executor.execute(
|
||||
tools_dict, call, self.llm.__class__.__name__
|
||||
)
|
||||
|
||||
query_params, headers, body, parameters = {}, {}, {}, {}
|
||||
param_types = {
|
||||
"query_params": query_params,
|
||||
"headers": headers,
|
||||
"body": body,
|
||||
"parameters": parameters,
|
||||
}
|
||||
def _get_truncated_tool_calls(self):
|
||||
return self.tool_executor.get_truncated_tool_calls()
|
||||
|
||||
for param_type, target_dict in param_types.items():
|
||||
if param_type in action_data and action_data[param_type].get("properties"):
|
||||
for param, details in action_data[param_type]["properties"].items():
|
||||
if param not in call_args and "value" in details:
|
||||
target_dict[param] = details["value"]
|
||||
for param, value in call_args.items():
|
||||
for param_type, target_dict in param_types.items():
|
||||
if param_type in action_data and param in action_data[param_type].get(
|
||||
"properties", {}
|
||||
):
|
||||
target_dict[param] = value
|
||||
tm = ToolManager(config={})
|
||||
tool = tm.load_tool(
|
||||
tool_data["name"],
|
||||
tool_config=(
|
||||
{
|
||||
"url": tool_data["config"]["actions"][action_name]["url"],
|
||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
||||
"headers": headers,
|
||||
"query_params": query_params,
|
||||
}
|
||||
if tool_data["name"] == "api_tool"
|
||||
else tool_data["config"]
|
||||
),
|
||||
# ---- Context / token management ----
|
||||
|
||||
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
|
||||
from application.api.answer.services.compression.token_counter import (
|
||||
TokenCounter,
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
print(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
return TokenCounter.count_message_tokens(messages)
|
||||
|
||||
def _check_context_limit(self, messages: List[Dict]) -> bool:
|
||||
from application.core.model_utils import get_token_limit
|
||||
|
||||
try:
|
||||
current_tokens = self._calculate_current_context_tokens(messages)
|
||||
self.current_token_count = current_tokens
|
||||
context_limit = get_token_limit(
|
||||
self.model_id, user_id=self.model_user_id or self.user
|
||||
)
|
||||
result = tool.execute_action(action_name, **body)
|
||||
else:
|
||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||
result = tool.execute_action(action_name, **parameters)
|
||||
call_id = getattr(call, "id", None)
|
||||
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": tool_data["name"],
|
||||
"call_id": call_id if call_id is not None else "None",
|
||||
"action_name": f"{action_name}_{tool_id}",
|
||||
"arguments": call_args,
|
||||
"result": result,
|
||||
}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
if current_tokens >= threshold:
|
||||
logger.warning(
|
||||
f"Context limit approaching: {current_tokens}/{context_limit} tokens "
|
||||
f"({(current_tokens/context_limit)*100:.1f}%)"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
|
||||
return False
|
||||
|
||||
return result, call_id
|
||||
def _validate_context_size(self, messages: List[Dict]) -> None:
|
||||
from application.core.model_utils import get_token_limit
|
||||
|
||||
current_tokens = self._calculate_current_context_tokens(messages)
|
||||
self.current_token_count = current_tokens
|
||||
context_limit = get_token_limit(
|
||||
self.model_id, user_id=self.model_user_id or self.user
|
||||
)
|
||||
percentage = (current_tokens / context_limit) * 100
|
||||
|
||||
if current_tokens >= context_limit:
|
||||
logger.warning(
|
||||
f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
|
||||
f"({percentage:.1f}%). Model: {self.model_id}"
|
||||
)
|
||||
elif current_tokens >= int(
|
||||
context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE
|
||||
):
|
||||
logger.info(
|
||||
f"Context approaching limit: {current_tokens:,}/{context_limit:,} tokens "
|
||||
f"({percentage:.1f}%)"
|
||||
)
|
||||
|
||||
def _truncate_text_middle(self, text: str, max_tokens: int) -> str:
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
current_tokens = num_tokens_from_string(text)
|
||||
if current_tokens <= max_tokens:
|
||||
return text
|
||||
|
||||
chars_per_token = len(text) / current_tokens if current_tokens > 0 else 4
|
||||
target_chars = int(max_tokens * chars_per_token * 0.95)
|
||||
|
||||
if target_chars <= 0:
|
||||
return ""
|
||||
|
||||
start_chars = int(target_chars * 0.4)
|
||||
end_chars = int(target_chars * 0.4)
|
||||
|
||||
truncation_marker = "\n\n[... content truncated to fit context limit ...]\n\n"
|
||||
truncated = text[:start_chars] + truncation_marker + text[-end_chars:]
|
||||
|
||||
logger.info(
|
||||
f"Truncated text from {current_tokens:,} to ~{max_tokens:,} tokens "
|
||||
f"(removed middle section)"
|
||||
)
|
||||
return truncated
|
||||
|
||||
# ---- Message building ----
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
system_prompt: str,
|
||||
query: str,
|
||||
retrieved_data: List[Dict],
|
||||
) -> List[Dict]:
|
||||
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
|
||||
p_chat_combine = system_prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
"""Build messages using pre-rendered system prompt"""
|
||||
from application.core.model_utils import get_token_limit
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
for i in self.chat_history:
|
||||
if self.compressed_summary:
|
||||
compression_context = (
|
||||
"\n\n---\n\n"
|
||||
"This session is being continued from a previous conversation that "
|
||||
"has been compressed to fit within context limits. "
|
||||
"The conversation is summarized below:\n\n"
|
||||
f"{self.compressed_summary}"
|
||||
)
|
||||
system_prompt = system_prompt + compression_context
|
||||
|
||||
context_limit = get_token_limit(
|
||||
self.model_id, user_id=self.model_user_id or self.user
|
||||
)
|
||||
system_tokens = num_tokens_from_string(system_prompt)
|
||||
|
||||
safety_buffer = int(context_limit * 0.1)
|
||||
available_after_system = context_limit - system_tokens - safety_buffer
|
||||
|
||||
max_query_tokens = int(available_after_system * 0.8)
|
||||
query_tokens = num_tokens_from_string(query)
|
||||
|
||||
if query_tokens > max_query_tokens:
|
||||
query = self._truncate_text_middle(query, max_query_tokens)
|
||||
query_tokens = num_tokens_from_string(query)
|
||||
|
||||
available_for_history = max(available_after_system - query_tokens, 0)
|
||||
|
||||
working_history = self._truncate_history_to_fit(
|
||||
self.chat_history,
|
||||
available_for_history,
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
for i in working_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append({"role": "assistant", "content": i["response"]})
|
||||
messages.append({"role": "user", "content": i["prompt"]})
|
||||
messages.append({"role": "assistant", "content": i["response"]})
|
||||
if "tool_calls" in i:
|
||||
for tool_call in i["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"args": tool_call.get("arguments"),
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"response": {"result": tool_call.get("result")},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
args = tool_call.get("arguments")
|
||||
args_str = (
|
||||
json.dumps(args)
|
||||
if isinstance(args, dict)
|
||||
else (args or "{}")
|
||||
)
|
||||
messages_combine.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.get("action_name", ""),
|
||||
"arguments": args_str,
|
||||
},
|
||||
}],
|
||||
})
|
||||
result = tool_call.get("result")
|
||||
result_str = (
|
||||
json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else (result or "")
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": query})
|
||||
return messages_combine
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": result_str,
|
||||
})
|
||||
messages.append({"role": "user", "content": query})
|
||||
return messages
|
||||
|
||||
def _retriever_search(
|
||||
def _truncate_history_to_fit(
|
||||
self,
|
||||
retriever: BaseRetriever,
|
||||
query: str,
|
||||
log_context: Optional[LogContext] = None,
|
||||
history: List[Dict],
|
||||
max_tokens: int,
|
||||
) -> List[Dict]:
|
||||
retrieved_data = retriever.search(query)
|
||||
if log_context:
|
||||
data = build_stack_data(retriever, exclude_attributes=["llm"])
|
||||
log_context.stacks.append({"component": "retriever", "data": data})
|
||||
return retrieved_data
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
if not history or max_tokens <= 0:
|
||||
return []
|
||||
|
||||
truncated = []
|
||||
current_tokens = 0
|
||||
|
||||
for message in reversed(history):
|
||||
message_tokens = 0
|
||||
|
||||
if "prompt" in message and "response" in message:
|
||||
message_tokens += num_tokens_from_string(message["prompt"])
|
||||
message_tokens += num_tokens_from_string(message["response"])
|
||||
|
||||
if "tool_calls" in message:
|
||||
for tool_call in message["tool_calls"]:
|
||||
tool_str = (
|
||||
f"Tool: {tool_call.get('tool_name')} | "
|
||||
f"Action: {tool_call.get('action_name')} | "
|
||||
f"Args: {tool_call.get('arguments')} | "
|
||||
f"Response: {tool_call.get('result')}"
|
||||
)
|
||||
message_tokens += num_tokens_from_string(tool_str)
|
||||
|
||||
if current_tokens + message_tokens <= max_tokens:
|
||||
current_tokens += message_tokens
|
||||
truncated.insert(0, message)
|
||||
else:
|
||||
break
|
||||
|
||||
if len(truncated) < len(history):
|
||||
logger.info(
|
||||
f"Truncated chat history from {len(history)} to {len(truncated)} messages "
|
||||
f"to fit within {max_tokens:,} token budget"
|
||||
)
|
||||
|
||||
return truncated
|
||||
|
||||
# ---- LLM generation ----
|
||||
|
||||
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
|
||||
resp = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages, tools=self.tools
|
||||
)
|
||||
self._validate_context_size(messages)
|
||||
|
||||
# Use the upstream id resolved by LLMCreator (see __init__).
|
||||
# Built-in models: same as self.model_id. BYOM: the user's
|
||||
# typed model name, not the internal UUID.
|
||||
gen_kwargs = {"model": self.upstream_model_id, "messages": messages}
|
||||
if self.attachments:
|
||||
gen_kwargs["_usage_attachments"] = self.attachments
|
||||
|
||||
if (
|
||||
hasattr(self.llm, "_supports_tools")
|
||||
and self.llm._supports_tools
|
||||
and self.tools
|
||||
):
|
||||
gen_kwargs["tools"] = self.tools
|
||||
if (
|
||||
self.json_schema
|
||||
and hasattr(self.llm, "_supports_structured_output")
|
||||
and self.llm._supports_structured_output()
|
||||
):
|
||||
structured_format = self.llm.prepare_structured_output_format(
|
||||
self.json_schema
|
||||
)
|
||||
if structured_format:
|
||||
if self.llm_name == "openai":
|
||||
gen_kwargs["response_format"] = structured_format
|
||||
elif self.llm_name == "google":
|
||||
gen_kwargs["response_schema"] = structured_format
|
||||
resp = self.llm.gen_stream(**gen_kwargs)
|
||||
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm, exclude_attributes=["client"])
|
||||
log_context.stacks.append({"component": "llm", "data": data})
|
||||
@@ -268,10 +556,51 @@ class BaseAgent(ABC):
|
||||
log_context: Optional[LogContext] = None,
|
||||
attachments: Optional[List[Dict]] = None,
|
||||
):
|
||||
resp = self.llm_handler.handle_response(
|
||||
self, resp, tools_dict, messages, attachments
|
||||
resp = self.llm_handler.process_message_flow(
|
||||
self, resp, tools_dict, messages, attachments, True
|
||||
)
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm_handler, exclude_attributes=["tool_calls"])
|
||||
log_context.stacks.append({"component": "llm_handler", "data": data})
|
||||
return resp
|
||||
|
||||
def _handle_response(self, response, tools_dict, messages, log_context):
|
||||
is_structured_output = (
|
||||
self.json_schema is not None
|
||||
and hasattr(self.llm, "_supports_structured_output")
|
||||
and self.llm._supports_structured_output()
|
||||
)
|
||||
|
||||
if isinstance(response, str):
|
||||
answer_data = {"answer": response}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
return
|
||||
if hasattr(response, "message") and getattr(response.message, "content", None):
|
||||
answer_data = {"answer": response.message.content}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
return
|
||||
processed_response_gen = self._llm_handler(
|
||||
response, tools_dict, messages, log_context, self.attachments
|
||||
)
|
||||
|
||||
for event in processed_response_gen:
|
||||
if isinstance(event, str):
|
||||
answer_data = {"answer": event}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
elif hasattr(event, "message") and getattr(event.message, "content", None):
|
||||
answer_data = {"answer": event.message.content}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
elif isinstance(event, dict) and "type" in event:
|
||||
yield event
|
||||
|
||||
@@ -1,64 +1,33 @@
|
||||
import logging
|
||||
from typing import Dict, Generator
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.logging import LogContext
|
||||
|
||||
from application.retriever.base import BaseRetriever
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClassicAgent(BaseAgent):
|
||||
"""A simplified agent with clear execution flow"""
|
||||
|
||||
def _gen_inner(
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
||||
self, query: str, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
retrieved_data = self._retriever_search(retriever, query, log_context)
|
||||
if self.user_api_key:
|
||||
tools_dict = self._get_tools(self.user_api_key)
|
||||
else:
|
||||
tools_dict = self._get_user_tools(self.user)
|
||||
"""Core generator function for ClassicAgent execution flow"""
|
||||
|
||||
tools_dict = self.tool_executor.get_tools()
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
messages = self._build_messages(self.prompt, query, retrieved_data)
|
||||
messages = self._build_messages(self.prompt, query)
|
||||
llm_response = self._llm_gen(messages, log_context)
|
||||
|
||||
resp = self._llm_gen(messages, log_context)
|
||||
yield from self._handle_response(
|
||||
llm_response, tools_dict, messages, log_context
|
||||
)
|
||||
|
||||
attachments = self.attachments
|
||||
|
||||
if isinstance(resp, str):
|
||||
yield {"answer": resp}
|
||||
return
|
||||
if (
|
||||
hasattr(resp, "message")
|
||||
and hasattr(resp.message, "content")
|
||||
and resp.message.content is not None
|
||||
):
|
||||
yield {"answer": resp.message.content}
|
||||
return
|
||||
|
||||
resp = self._llm_handler(resp, tools_dict, messages, log_context, attachments)
|
||||
|
||||
if isinstance(resp, str):
|
||||
yield {"answer": resp}
|
||||
elif (
|
||||
hasattr(resp, "message")
|
||||
and hasattr(resp.message, "content")
|
||||
and resp.message.content is not None
|
||||
):
|
||||
yield {"answer": resp.message.content}
|
||||
else:
|
||||
for line in resp:
|
||||
if isinstance(line, str):
|
||||
yield {"answer": line}
|
||||
yield {"sources": self.retrieved_docs}
|
||||
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
|
||||
log_context.stacks.append(
|
||||
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||
)
|
||||
|
||||
yield {"sources": retrieved_data}
|
||||
# clean tool_call_data only send first 50 characters of tool_call['result']
|
||||
for tool_call in self.tool_calls:
|
||||
if len(str(tool_call["result"])) > 50:
|
||||
tool_call["result"] = str(tool_call["result"])[:50] + "..."
|
||||
yield {"tool_calls": self.tool_calls.copy()}
|
||||
|
||||
@@ -1,351 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from application.logging import build_stack_data
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMHandler(ABC):
|
||||
def __init__(self):
|
||||
self.llm_calls = []
|
||||
self.tool_calls = []
|
||||
|
||||
@abstractmethod
|
||||
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, **kwargs):
|
||||
pass
|
||||
|
||||
def prepare_messages_with_attachments(self, agent, messages, attachments=None):
|
||||
"""
|
||||
Prepare messages with attachment content if available.
|
||||
|
||||
Args:
|
||||
agent: The current agent instance.
|
||||
messages (list): List of message dictionaries.
|
||||
attachments (list): List of attachment dictionaries with content.
|
||||
|
||||
Returns:
|
||||
list: Messages with attachment context added to the system prompt.
|
||||
"""
|
||||
if not attachments:
|
||||
return messages
|
||||
|
||||
logger.info(f"Preparing messages with {len(attachments)} attachments")
|
||||
|
||||
supported_types = agent.llm.get_supported_attachment_types()
|
||||
|
||||
supported_attachments = []
|
||||
unsupported_attachments = []
|
||||
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.get('mime_type')
|
||||
if mime_type in supported_types:
|
||||
supported_attachments.append(attachment)
|
||||
else:
|
||||
unsupported_attachments.append(attachment)
|
||||
|
||||
# Process supported attachments with the LLM's custom method
|
||||
prepared_messages = messages
|
||||
if supported_attachments:
|
||||
logger.info(f"Processing {len(supported_attachments)} supported attachments with {agent.llm.__class__.__name__}'s method")
|
||||
prepared_messages = agent.llm.prepare_messages_with_attachments(messages, supported_attachments)
|
||||
|
||||
# Process unsupported attachments with the default method
|
||||
if unsupported_attachments:
|
||||
logger.info(f"Processing {len(unsupported_attachments)} unsupported attachments with default method")
|
||||
prepared_messages = self._append_attachment_content_to_system(prepared_messages, unsupported_attachments)
|
||||
|
||||
return prepared_messages
|
||||
|
||||
def _append_attachment_content_to_system(self, messages, attachments):
|
||||
"""
|
||||
Default method to append attachment content to the system prompt.
|
||||
|
||||
Args:
|
||||
messages (list): List of message dictionaries.
|
||||
attachments (list): List of attachment dictionaries with content.
|
||||
|
||||
Returns:
|
||||
list: Messages with attachment context added to the system prompt.
|
||||
"""
|
||||
prepared_messages = messages.copy()
|
||||
|
||||
attachment_texts = []
|
||||
for attachment in attachments:
|
||||
logger.info(f"Adding attachment {attachment.get('id')} to context")
|
||||
if 'content' in attachment:
|
||||
attachment_texts.append(f"Attached file content:\n\n{attachment['content']}")
|
||||
|
||||
if attachment_texts:
|
||||
combined_attachment_text = "\n\n".join(attachment_texts)
|
||||
|
||||
system_found = False
|
||||
for i in range(len(prepared_messages)):
|
||||
if prepared_messages[i].get("role") == "system":
|
||||
prepared_messages[i]["content"] += f"\n\n{combined_attachment_text}"
|
||||
system_found = True
|
||||
break
|
||||
|
||||
if not system_found:
|
||||
prepared_messages.insert(0, {"role": "system", "content": combined_attachment_text})
|
||||
|
||||
return prepared_messages
|
||||
|
||||
class OpenAILLMHandler(LLMHandler):
|
||||
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True):
|
||||
|
||||
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
|
||||
logger.info(f"Messages with attachments: {messages}")
|
||||
if not stream:
|
||||
while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls":
|
||||
message = json.loads(resp.model_dump_json())["message"]
|
||||
keys_to_remove = {"audio", "function_call", "refusal"}
|
||||
filtered_data = {
|
||||
k: v for k, v in message.items() if k not in keys_to_remove
|
||||
}
|
||||
messages.append(filtered_data)
|
||||
|
||||
tool_calls = resp.message.tool_calls
|
||||
for call in tool_calls:
|
||||
try:
|
||||
self.tool_calls.append(call)
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, call
|
||||
)
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": call.function.name,
|
||||
"args": call.function.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": call.function.name,
|
||||
"response": {"result": tool_response},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
messages.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
|
||||
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
|
||||
except Exception as e:
|
||||
logging.error(f"Error executing tool: {str(e)}", exc_info=True)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": f"Error executing tool: {str(e)}",
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
)
|
||||
resp = agent.llm.gen_stream(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
return resp
|
||||
|
||||
else:
|
||||
text_buffer = ""
|
||||
while True:
|
||||
tool_calls = {}
|
||||
for chunk in resp:
|
||||
if isinstance(chunk, str) and len(chunk) > 0:
|
||||
yield chunk
|
||||
continue
|
||||
elif hasattr(chunk, "delta"):
|
||||
chunk_delta = chunk.delta
|
||||
|
||||
if (
|
||||
hasattr(chunk_delta, "tool_calls")
|
||||
and chunk_delta.tool_calls is not None
|
||||
):
|
||||
for tool_call in chunk_delta.tool_calls:
|
||||
index = tool_call.index
|
||||
if index not in tool_calls:
|
||||
tool_calls[index] = {
|
||||
"id": "",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
|
||||
current = tool_calls[index]
|
||||
if tool_call.id:
|
||||
current["id"] = tool_call.id
|
||||
if tool_call.function.name:
|
||||
current["function"][
|
||||
"name"
|
||||
] = tool_call.function.name
|
||||
if tool_call.function.arguments:
|
||||
current["function"][
|
||||
"arguments"
|
||||
] += tool_call.function.arguments
|
||||
tool_calls[index] = current
|
||||
|
||||
if (
|
||||
hasattr(chunk, "finish_reason")
|
||||
and chunk.finish_reason == "tool_calls"
|
||||
):
|
||||
for index in sorted(tool_calls.keys()):
|
||||
call = tool_calls[index]
|
||||
try:
|
||||
self.tool_calls.append(call)
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, call
|
||||
)
|
||||
if isinstance(call["function"]["arguments"], str):
|
||||
call["function"]["arguments"] = json.loads(call["function"]["arguments"])
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": call["function"]["name"],
|
||||
"args": call["function"]["arguments"],
|
||||
"call_id": call["id"],
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": call["function"]["name"],
|
||||
"response": {"result": tool_response},
|
||||
"call_id": call["id"],
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [function_call_dict],
|
||||
}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [function_response_dict],
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error executing tool: {str(e)}", exc_info=True)
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Error executing tool: {str(e)}",
|
||||
}
|
||||
)
|
||||
tool_calls = {}
|
||||
if hasattr(chunk_delta, "content") and chunk_delta.content:
|
||||
# Add to buffer or yield immediately based on your preference
|
||||
text_buffer += chunk_delta.content
|
||||
yield text_buffer
|
||||
text_buffer = ""
|
||||
|
||||
if (
|
||||
hasattr(chunk, "finish_reason")
|
||||
and chunk.finish_reason == "stop"
|
||||
):
|
||||
return resp
|
||||
elif isinstance(chunk, str) and len(chunk) == 0:
|
||||
continue
|
||||
|
||||
logger.info(f"Regenerating with messages: {messages}")
|
||||
resp = agent.llm.gen_stream(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
|
||||
class GoogleLLMHandler(LLMHandler):
    def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True):
        from google.genai import types

        messages = self.prepare_messages_with_attachments(agent, messages, attachments)

        while True:
            if not stream:
                response = agent.llm.gen(
                    model=agent.gpt_model, messages=messages, tools=agent.tools
                )
                self.llm_calls.append(build_stack_data(agent.llm))
                if response.candidates and response.candidates[0].content.parts:
                    tool_call_found = False
                    for part in response.candidates[0].content.parts:
                        if part.function_call:
                            tool_call_found = True
                            self.tool_calls.append(part.function_call)
                            tool_response, call_id = agent._execute_tool_action(
                                tools_dict, part.function_call
                            )
                            function_response_part = types.Part.from_function_response(
                                name=part.function_call.name,
                                response={"result": tool_response},
                            )

                            messages.append(
                                {"role": "model", "content": [part.to_json_dict()]}
                            )
                            messages.append(
                                {
                                    "role": "tool",
                                    "content": [function_response_part.to_json_dict()],
                                }
                            )

                    if (
                        not tool_call_found
                        and response.candidates[0].content.parts
                        and response.candidates[0].content.parts[0].text
                    ):
                        return response.candidates[0].content.parts[0].text
                    elif not tool_call_found:
                        return response.candidates[0].content.parts

                else:
                    return response

            else:
                response = agent.llm.gen_stream(
                    model=agent.gpt_model, messages=messages, tools=agent.tools
                )
                self.llm_calls.append(build_stack_data(agent.llm))

                tool_call_found = False
                for result in response:
                    if hasattr(result, "function_call"):
                        tool_call_found = True
                        self.tool_calls.append(result.function_call)
                        tool_response, call_id = agent._execute_tool_action(
                            tools_dict, result.function_call
                        )
                        function_response_part = types.Part.from_function_response(
                            name=result.function_call.name,
                            response={"result": tool_response},
                        )

                        messages.append(
                            {"role": "model", "content": [result.to_json_dict()]}
                        )
                        messages.append(
                            {
                                "role": "tool",
                                "content": [function_response_part.to_json_dict()],
                            }
                        )
                    else:
                        tool_call_found = False
                        yield result

                if not tool_call_found:
                    return response


def get_llm_handler(llm_type):
    handlers = {
        "openai": OpenAILLMHandler(),
        "google": GoogleLLMHandler(),
    }
    return handlers.get(llm_type, OpenAILLMHandler())
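# --- Usage sketch (assumed calling convention, not part of the original file) ----
# The factory above maps a provider key to a handler and falls back to the OpenAI
# handler for anything unrecognised; agent / resp / tools_dict / messages are
# whatever the answer pipeline already holds.
handler = get_llm_handler("google")      # -> GoogleLLMHandler
fallback = get_llm_handler("anthropic")  # unknown key -> OpenAILLMHandler
# for chunk in handler.handle_response(agent, resp, tools_dict, messages, stream=True):
#     ...  # forward streamed text / tool-call results to the caller
# ----------------------------------------------------------------------------------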
@@ -1,229 +0,0 @@
import os
|
||||
from typing import Dict, Generator, List, Any
|
||||
import logging
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.logging import build_stack_data, LogContext
|
||||
from application.retriever.base import BaseRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
with open(
|
||||
os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"), "r"
|
||||
) as f:
|
||||
planning_prompt_template = f.read()
|
||||
with open(
|
||||
os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"),
|
||||
"r",
|
||||
) as f:
|
||||
final_prompt_template = f.read()
|
||||
|
||||
MAX_ITERATIONS_REASONING = 10
|
||||
|
||||
class ReActAgent(BaseAgent):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.plan: str = ""
|
||||
self.observations: List[str] = []
|
||||
|
||||
def _extract_content_from_llm_response(self, resp: Any) -> str:
|
||||
"""
|
||||
Helper to extract string content from various LLM response types.
|
||||
Handles strings, message objects (OpenAI-like), and streams.
|
||||
Adapt stream handling for your specific LLM client if not OpenAI.
|
||||
"""
|
||||
collected_content = []
|
||||
if isinstance(resp, str):
|
||||
collected_content.append(resp)
|
||||
elif ( # OpenAI non-streaming or Anthropic non-streaming (older SDK style)
|
||||
hasattr(resp, "message")
|
||||
and hasattr(resp.message, "content")
|
||||
and resp.message.content is not None
|
||||
):
|
||||
collected_content.append(resp.message.content)
|
||||
elif ( # OpenAI non-streaming (Pydantic model), Anthropic new SDK non-streaming
|
||||
hasattr(resp, "choices") and resp.choices and
|
||||
hasattr(resp.choices[0], "message") and
|
||||
hasattr(resp.choices[0].message, "content") and
|
||||
resp.choices[0].message.content is not None
|
||||
):
|
||||
collected_content.append(resp.choices[0].message.content) # OpenAI
|
||||
elif ( # Anthropic new SDK non-streaming content block
|
||||
hasattr(resp, "content") and isinstance(resp.content, list) and resp.content and
|
||||
hasattr(resp.content[0], "text")
|
||||
):
|
||||
collected_content.append(resp.content[0].text) # Anthropic
|
||||
else:
|
||||
# Assume resp is a stream if not a recognized object
|
||||
try:
|
||||
for chunk in resp: # This will fail if resp is not iterable (e.g. a non-streaming response object)
|
||||
content_piece = ""
|
||||
# OpenAI-like stream
|
||||
if hasattr(chunk, 'choices') and len(chunk.choices) > 0 and \
|
||||
hasattr(chunk.choices[0], 'delta') and \
|
||||
hasattr(chunk.choices[0].delta, 'content') and \
|
||||
chunk.choices[0].delta.content is not None:
|
||||
content_piece = chunk.choices[0].delta.content
|
||||
# Anthropic-like stream (ContentBlockDelta)
|
||||
elif hasattr(chunk, 'type') and chunk.type == 'content_block_delta' and \
|
||||
hasattr(chunk, 'delta') and hasattr(chunk.delta, 'text'):
|
||||
content_piece = chunk.delta.text
|
||||
elif isinstance(chunk, str): # Simplest case: stream of strings
|
||||
content_piece = chunk
|
||||
|
||||
if content_piece:
|
||||
collected_content.append(content_piece)
|
||||
except TypeError: # If resp is not iterable (e.g. a final response object that wasn't caught above)
|
||||
logger.debug(f"Response type {type(resp)} could not be iterated as a stream. It might be a non-streaming object not handled by specific checks.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing potential stream chunk: {e}, chunk was: {getattr(chunk, '__dict__', chunk)}")
|
||||
|
||||
|
||||
return "".join(collected_content)
|
||||
|
||||
def _gen_inner(
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
# Reset state for this generation call
|
||||
self.plan = ""
|
||||
self.observations = []
|
||||
retrieved_data = self._retriever_search(retriever, query, log_context)
|
||||
|
||||
if self.user_api_key:
|
||||
tools_dict = self._get_tools(self.user_api_key)
|
||||
else:
|
||||
tools_dict = self._get_user_tools(self.user)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
|
||||
iterating_reasoning = 0
|
||||
while iterating_reasoning < MAX_ITERATIONS_REASONING:
|
||||
iterating_reasoning += 1
|
||||
# 1. Create Plan
|
||||
logger.info("ReActAgent: Creating plan...")
|
||||
plan_stream = self._create_plan(query, docs_together, log_context)
|
||||
current_plan_parts = []
|
||||
yield {"thought": f"Reasoning... (iteration {iterating_reasoning})\n\n"}
|
||||
for line_chunk in plan_stream:
|
||||
current_plan_parts.append(line_chunk)
|
||||
yield {"thought": line_chunk}
|
||||
self.plan = "".join(current_plan_parts)
|
||||
if self.plan:
|
||||
self.observations.append(f"Plan: {self.plan} Iteration: {iterating_reasoning}")
|
||||
|
||||
|
||||
max_obs_len = 20000
|
||||
obs_str = "\n".join(self.observations)
|
||||
if len(obs_str) > max_obs_len:
|
||||
obs_str = obs_str[:max_obs_len] + "\n...[observations truncated]"
|
||||
            execution_prompt_str = (
                (self.prompt or "")
                + f"\n\nFollow this plan:\n{self.plan}"
                + f"\n\nObservations:\n{obs_str}"
                + f"\n\nIf there is enough data to complete the user query '{query}', respond with 'SATISFIED' only. Otherwise, continue. Do not mention 'SATISFIED' in your response if you are not ready."
            )
|
||||
|
||||
messages = self._build_messages(execution_prompt_str, query, retrieved_data)
|
||||
|
||||
resp_from_llm_gen = self._llm_gen(messages, log_context)
|
||||
|
||||
initial_llm_thought_content = self._extract_content_from_llm_response(resp_from_llm_gen)
|
||||
if initial_llm_thought_content:
|
||||
self.observations.append(f"Initial thought/response: {initial_llm_thought_content}")
|
||||
else:
|
||||
logger.info("ReActAgent: Initial LLM response (before handler) had no textual content (might be only tool calls).")
|
||||
resp_after_handler = self._llm_handler(resp_from_llm_gen, tools_dict, messages, log_context)
|
||||
|
||||
for tool_call_info in self.tool_calls: # Iterate over self.tool_calls populated by _llm_handler
|
||||
observation_string = (
|
||||
f"Executed Action: Tool '{tool_call_info.get('tool_name', 'N/A')}' "
|
||||
f"with arguments '{tool_call_info.get('arguments', '{}')}'. Result: '{str(tool_call_info.get('result', ''))[:200]}...'"
|
||||
)
|
||||
self.observations.append(observation_string)
|
||||
|
||||
content_after_handler = self._extract_content_from_llm_response(resp_after_handler)
|
||||
if content_after_handler:
|
||||
self.observations.append(f"Response after tool execution: {content_after_handler}")
|
||||
else:
|
||||
logger.info("ReActAgent: LLM response after handler had no textual content.")
|
||||
|
||||
if log_context:
|
||||
log_context.stacks.append(
|
||||
{"component": "agent_tool_calls", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||
)
|
||||
|
||||
yield {"sources": retrieved_data}
|
||||
|
||||
display_tool_calls = []
|
||||
for tc in self.tool_calls:
|
||||
cleaned_tc = tc.copy()
|
||||
if len(str(cleaned_tc.get("result", ""))) > 50:
|
||||
cleaned_tc["result"] = str(cleaned_tc["result"])[:50] + "..."
|
||||
display_tool_calls.append(cleaned_tc)
|
||||
if display_tool_calls:
|
||||
yield {"tool_calls": display_tool_calls}
|
||||
|
||||
if "SATISFIED" in content_after_handler:
|
||||
logger.info("ReActAgent: LLM satisfied with the plan and data. Stopping reasoning.")
|
||||
break
|
||||
|
||||
# 3. Create Final Answer based on all observations
|
||||
final_answer_stream = self._create_final_answer(query, self.observations, log_context)
|
||||
for answer_chunk in final_answer_stream:
|
||||
yield {"answer": answer_chunk}
|
||||
logger.info("ReActAgent: Finished generating final answer.")
|
||||
|
||||
def _create_plan(
|
||||
self, query: str, docs_data: str, log_context: LogContext = None
|
||||
) -> Generator[str, None, None]:
|
||||
plan_prompt_filled = planning_prompt_template.replace("{query}", query)
|
||||
if "{summaries}" in plan_prompt_filled:
|
||||
summaries = docs_data if docs_data else "No documents retrieved."
|
||||
plan_prompt_filled = plan_prompt_filled.replace("{summaries}", summaries)
|
||||
plan_prompt_filled = plan_prompt_filled.replace("{prompt}", self.prompt or "")
|
||||
plan_prompt_filled = plan_prompt_filled.replace("{observations}", "\n".join(self.observations))
|
||||
|
||||
messages = [{"role": "user", "content": plan_prompt_filled}]
|
||||
|
||||
plan_stream_from_llm = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages, tools=getattr(self, 'tools', None) # Use self.tools
|
||||
)
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm)
|
||||
log_context.stacks.append({"component": "planning_llm", "data": data})
|
||||
|
||||
for chunk in plan_stream_from_llm:
|
||||
content_piece = self._extract_content_from_llm_response(chunk)
|
||||
if content_piece:
|
||||
yield content_piece
|
||||
|
||||
def _create_final_answer(
|
||||
self, query: str, observations: List[str], log_context: LogContext = None
|
||||
) -> Generator[str, None, None]:
|
||||
observation_string = "\n".join(observations)
|
||||
max_obs_len = 10000
|
||||
if len(observation_string) > max_obs_len:
|
||||
observation_string = observation_string[:max_obs_len] + "\n...[observations truncated]"
|
||||
logger.warning("ReActAgent: Truncated observations for final answer prompt due to length.")
|
||||
|
||||
final_answer_prompt_filled = final_prompt_template.format(
|
||||
query=query, observations=observation_string
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": final_answer_prompt_filled}]
|
||||
|
||||
# Final answer should synthesize, not call tools.
|
||||
final_answer_stream_from_llm = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages, tools=None
|
||||
)
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm)
|
||||
log_context.stacks.append({"component": "final_answer_llm", "data": data})
|
||||
|
||||
for chunk in final_answer_stream_from_llm:
|
||||
content_piece = self._extract_content_from_llm_response(chunk)
|
||||
if content_piece:
|
||||
yield content_piece
|
||||
application/agents/research_agent.py (new file, 698 lines)
@@ -0,0 +1,698 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Generator, List, Optional
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.agents.tools.internal_search import (
|
||||
INTERNAL_TOOL_ID,
|
||||
add_internal_search_tool,
|
||||
)
|
||||
from application.agents.tools.think import THINK_TOOL_ENTRY, THINK_TOOL_ID
|
||||
from application.logging import LogContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Defaults (can be overridden via constructor)
|
||||
DEFAULT_MAX_STEPS = 6
|
||||
DEFAULT_MAX_SUB_ITERATIONS = 5
|
||||
DEFAULT_TIMEOUT_SECONDS = 300 # 5 minutes
|
||||
DEFAULT_TOKEN_BUDGET = 100_000
|
||||
DEFAULT_PARALLEL_WORKERS = 3
|
||||
|
||||
# Adaptive depth caps per complexity level
|
||||
COMPLEXITY_CAPS = {
|
||||
"simple": 2,
|
||||
"moderate": 4,
|
||||
"complex": 6,
|
||||
}
|
||||
|
||||
_PROMPTS_DIR = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
"prompts",
|
||||
"research",
|
||||
)
|
||||
|
||||
|
||||
def _load_prompt(name: str) -> str:
|
||||
with open(os.path.join(_PROMPTS_DIR, name), "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
CLARIFICATION_PROMPT = _load_prompt("clarification.txt")
|
||||
PLANNING_PROMPT = _load_prompt("planning.txt")
|
||||
STEP_PROMPT = _load_prompt("step.txt")
|
||||
SYNTHESIS_PROMPT = _load_prompt("synthesis.txt")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CitationManager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CitationManager:
|
||||
"""Tracks and deduplicates citations across research steps."""
|
||||
|
||||
def __init__(self):
|
||||
self.citations: Dict[int, Dict] = {}
|
||||
self._counter = 0
|
||||
|
||||
def add(self, doc: Dict) -> int:
|
||||
"""Register a source, return its citation number. Deduplicates by source."""
|
||||
source = doc.get("source", "")
|
||||
title = doc.get("title", "")
|
||||
for num, existing in self.citations.items():
|
||||
if existing.get("source") == source and existing.get("title") == title:
|
||||
return num
|
||||
self._counter += 1
|
||||
self.citations[self._counter] = doc
|
||||
return self._counter
|
||||
|
||||
def add_docs(self, docs: List[Dict]) -> str:
|
||||
"""Register multiple docs, return formatted citation mapping text."""
|
||||
mapping_lines = []
|
||||
for doc in docs:
|
||||
num = self.add(doc)
|
||||
title = doc.get("title", "Untitled")
|
||||
mapping_lines.append(f"[{num}] {title}")
|
||||
return "\n".join(mapping_lines)
|
||||
|
||||
def format_references(self) -> str:
|
||||
"""Generate [N] -> source mapping for report footer."""
|
||||
if not self.citations:
|
||||
return "No sources found."
|
||||
lines = []
|
||||
for num, doc in sorted(self.citations.items()):
|
||||
title = doc.get("title", "Untitled")
|
||||
source = doc.get("source", "Unknown")
|
||||
filename = doc.get("filename", "")
|
||||
display = filename or title
|
||||
lines.append(f"[{num}] {display} — {source}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def get_all_docs(self) -> List[Dict]:
|
||||
return list(self.citations.values())
|
||||
|
||||
|
||||
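# --- Usage sketch (illustrative; document dicts invented) -------------------------
# CitationManager above deduplicates by (source, title) and numbers sources in the
# order they are first seen:
_cm = CitationManager()
_n1 = _cm.add({"title": "Install guide", "source": "docs/install.md"})
_n2 = _cm.add({"title": "Install guide", "source": "docs/install.md"})  # same source -> same number
_n3 = _cm.add({"title": "API reference", "source": "docs/api.md"})
assert (_n1, _n2, _n3) == (1, 1, 2)
_refs = _cm.format_references()  # one "[N] title ... source" line per unique source
# -----------------------------------------------------------------------------------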
# ---------------------------------------------------------------------------
|
||||
# ResearchAgent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ResearchAgent(BaseAgent):
|
||||
"""Multi-step research agent with parallel execution and budget controls.
|
||||
|
||||
Orchestrates: Plan -> Research (per step, optionally parallel) -> Synthesize.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
retriever_config: Optional[Dict] = None,
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
max_sub_iterations: int = DEFAULT_MAX_SUB_ITERATIONS,
|
||||
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
|
||||
token_budget: int = DEFAULT_TOKEN_BUDGET,
|
||||
parallel_workers: int = DEFAULT_PARALLEL_WORKERS,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.retriever_config = retriever_config or {}
|
||||
self.max_steps = max_steps
|
||||
self.max_sub_iterations = max_sub_iterations
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.token_budget = token_budget
|
||||
self.parallel_workers = parallel_workers
|
||||
self.citations = CitationManager()
|
||||
self._start_time: float = 0
|
||||
self._tokens_used: int = 0
|
||||
self._last_token_snapshot: int = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Budget & timeout helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _is_timed_out(self) -> bool:
|
||||
return (time.monotonic() - self._start_time) >= self.timeout_seconds
|
||||
|
||||
def _elapsed(self) -> float:
|
||||
return round(time.monotonic() - self._start_time, 1)
|
||||
|
||||
def _track_tokens(self, count: int):
|
||||
self._tokens_used += count
|
||||
|
||||
def _budget_remaining(self) -> int:
|
||||
return max(self.token_budget - self._tokens_used, 0)
|
||||
|
||||
def _is_over_budget(self) -> bool:
|
||||
return self._tokens_used >= self.token_budget
|
||||
|
||||
def _snapshot_llm_tokens(self) -> int:
|
||||
"""Read current token usage from LLM and return delta since last snapshot."""
|
||||
current = self.llm.token_usage.get("prompt_tokens", 0) + self.llm.token_usage.get("generated_tokens", 0)
|
||||
delta = current - self._last_token_snapshot
|
||||
self._last_token_snapshot = current
|
||||
return delta
|
||||
|
||||
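    # --- Illustrative note (numbers invented) ------------------------------------
    # _snapshot_llm_tokens() reads the LLM's *cumulative* counters and returns only
    # the growth since the previous snapshot, so repeated calls never double-count:
    # cumulative 1200 -> delta 1200; cumulative 1900 on the next call -> delta 700;
    # _track_tokens() therefore adds 1200 + 700 = 1900 to self._tokens_used overall.
    # -------------------------------------------------------------------------------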
# ------------------------------------------------------------------
|
||||
# Main orchestration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _gen_inner(
|
||||
self, query: str, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
self._start_time = time.monotonic()
|
||||
tools_dict = self._setup_tools()
|
||||
|
||||
# Phase 0: Clarification (skip if user is responding to a prior clarification)
|
||||
if not self._is_follow_up():
|
||||
clarification = self._clarification_phase(query)
|
||||
if clarification:
|
||||
yield {"metadata": {"is_clarification": True}}
|
||||
yield {"answer": clarification}
|
||||
yield {"sources": []}
|
||||
yield {"tool_calls": []}
|
||||
log_context.stacks.append(
|
||||
{"component": "agent", "data": {"clarification": True}}
|
||||
)
|
||||
return
|
||||
|
||||
# Phase 1: Planning (with adaptive depth)
|
||||
yield {"type": "research_progress", "data": {"status": "planning"}}
|
||||
plan, complexity = self._planning_phase(query)
|
||||
|
||||
if not plan:
|
||||
logger.warning("ResearchAgent: Planning produced no steps, falling back")
|
||||
plan = [{"query": query, "rationale": "Direct investigation"}]
|
||||
complexity = "simple"
|
||||
|
||||
yield {
|
||||
"type": "research_plan",
|
||||
"data": {"steps": plan, "complexity": complexity},
|
||||
}
|
||||
|
||||
# Phase 2: Research each step (yields progress events in real-time)
|
||||
intermediate_reports = []
|
||||
for i, step in enumerate(plan):
|
||||
step_num = i + 1
|
||||
step_query = step.get("query", query)
|
||||
|
||||
if self._is_timed_out():
|
||||
logger.warning(
|
||||
f"ResearchAgent: Timeout at step {step_num}/{len(plan)} "
|
||||
f"({self._elapsed()}s)"
|
||||
)
|
||||
break
|
||||
if self._is_over_budget():
|
||||
logger.warning(
|
||||
f"ResearchAgent: Token budget exhausted at step {step_num}/{len(plan)}"
|
||||
)
|
||||
break
|
||||
|
||||
yield {
|
||||
"type": "research_progress",
|
||||
"data": {
|
||||
"step": step_num,
|
||||
"total": len(plan),
|
||||
"query": step_query,
|
||||
"status": "researching",
|
||||
},
|
||||
}
|
||||
|
||||
report = self._research_step(step_query, tools_dict)
|
||||
intermediate_reports.append({"step": step, "content": report})
|
||||
|
||||
yield {
|
||||
"type": "research_progress",
|
||||
"data": {
|
||||
"step": step_num,
|
||||
"total": len(plan),
|
||||
"query": step_query,
|
||||
"status": "complete",
|
||||
},
|
||||
}
|
||||
|
||||
# Phase 3: Synthesis (streaming)
|
||||
if self._is_timed_out():
|
||||
logger.warning(
|
||||
f"ResearchAgent: Timeout ({self._elapsed()}s) before synthesis, "
|
||||
f"synthesizing with {len(intermediate_reports)} reports"
|
||||
)
|
||||
yield {
|
||||
"type": "research_progress",
|
||||
"data": {
|
||||
"status": "synthesizing",
|
||||
"elapsed_seconds": self._elapsed(),
|
||||
"tokens_used": self._tokens_used,
|
||||
},
|
||||
}
|
||||
yield from self._synthesis_phase(
|
||||
query, plan, intermediate_reports, tools_dict, log_context
|
||||
)
|
||||
|
||||
# Sources and tool calls
|
||||
self.retrieved_docs = self.citations.get_all_docs()
|
||||
yield {"sources": self.retrieved_docs}
|
||||
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
|
||||
logger.info(
|
||||
f"ResearchAgent completed: {len(intermediate_reports)}/{len(plan)} steps, "
|
||||
f"{self._elapsed()}s, ~{self._tokens_used} tokens"
|
||||
)
|
||||
log_context.stacks.append(
|
||||
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||
)
|
||||
|
||||
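    # --- Consumer sketch (illustrative; event shapes copied from the yields above) --
    # _gen_inner streams typed progress events alongside the answer; a caller might
    # dispatch on them roughly like this (show_plan / show_status / stream_to_client /
    # attach_sources are hypothetical helpers):
    #
    #     for event in agent._gen_inner("How does DocsGPT chunk documents?", log_context):
    #         if event.get("type") == "research_plan":
    #             show_plan(event["data"]["steps"])
    #         elif event.get("type") == "research_progress":
    #             show_status(event["data"].get("status"))
    #         elif "answer" in event:
    #             stream_to_client(event["answer"])
    #         elif "sources" in event:
    #             attach_sources(event["sources"])
    # ----------------------------------------------------------------------------------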
# ------------------------------------------------------------------
|
||||
# Tool setup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _setup_tools(self) -> Dict:
|
||||
"""Build tools_dict with user tools + internal search + think."""
|
||||
tools_dict = self.tool_executor.get_tools()
|
||||
|
||||
add_internal_search_tool(tools_dict, self.retriever_config)
|
||||
|
||||
think_entry = dict(THINK_TOOL_ENTRY)
|
||||
think_entry["config"] = {}
|
||||
tools_dict[THINK_TOOL_ID] = think_entry
|
||||
|
||||
self._prepare_tools(tools_dict)
|
||||
return tools_dict
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 0: Clarification
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _is_follow_up(self) -> bool:
|
||||
"""Check if the user is responding to a prior clarification.
|
||||
|
||||
Uses the metadata flag stored in the conversation DB — no string matching.
|
||||
Only skip clarification when the last query was explicitly flagged
|
||||
as a clarification by this agent.
|
||||
"""
|
||||
if not self.chat_history:
|
||||
return False
|
||||
last = self.chat_history[-1]
|
||||
meta = last.get("metadata", {})
|
||||
return bool(meta.get("is_clarification"))
|
||||
|
||||
def _clarification_phase(self, question: str) -> Optional[str]:
|
||||
"""Ask the LLM whether the question needs clarification.
|
||||
|
||||
Returns formatted clarification text if needed, or None to proceed.
|
||||
Uses response_format to force valid JSON output.
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": CLARIFICATION_PROMPT},
|
||||
{"role": "user", "content": question},
|
||||
]
|
||||
|
||||
try:
|
||||
response = self.llm.gen(
|
||||
model=self.upstream_model_id,
|
||||
messages=messages,
|
||||
tools=None,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
text = self._extract_text(response)
|
||||
self._track_tokens(self._snapshot_llm_tokens())
|
||||
logger.info(f"ResearchAgent clarification response: {text[:300]}")
|
||||
|
||||
data = self._parse_clarification_json(text)
|
||||
if not data or not data.get("needs_clarification"):
|
||||
return None
|
||||
|
||||
questions = data.get("questions", [])
|
||||
if not questions:
|
||||
return None
|
||||
|
||||
# Format as a friendly response
|
||||
lines = [
|
||||
"Before I begin researching, I'd like to clarify a few things:\n"
|
||||
]
|
||||
for i, q in enumerate(questions[:3], 1):
|
||||
lines.append(f"{i}. {q}")
|
||||
lines.append(
|
||||
"\nPlease provide these details and I'll start the research."
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Clarification phase failed: {e}", exc_info=True)
|
||||
return None # proceed with research on failure
|
||||
|
||||
def _parse_clarification_json(self, text: str) -> Optional[Dict]:
|
||||
"""Parse clarification JSON from LLM response."""
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try extracting from code fences
|
||||
for marker in ["```json", "```"]:
|
||||
if marker in text:
|
||||
start = text.index(marker) + len(marker)
|
||||
end = text.index("```", start) if "```" in text[start:] else len(text)
|
||||
try:
|
||||
return json.loads(text[start:end].strip())
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# Try finding JSON object
|
||||
for i, ch in enumerate(text):
|
||||
if ch == "{":
|
||||
for j in range(len(text) - 1, i, -1):
|
||||
if text[j] == "}":
|
||||
try:
|
||||
return json.loads(text[i : j + 1])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
break
|
||||
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 1: Planning (with adaptive depth)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _planning_phase(self, question: str) -> tuple[List[Dict], str]:
|
||||
"""Decompose the question into research steps via LLM.
|
||||
|
||||
Returns (steps, complexity) where complexity is simple/moderate/complex.
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": PLANNING_PROMPT},
|
||||
{"role": "user", "content": question},
|
||||
]
|
||||
|
||||
try:
|
||||
response = self.llm.gen(
|
||||
model=self.upstream_model_id,
|
||||
messages=messages,
|
||||
tools=None,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
text = self._extract_text(response)
|
||||
self._track_tokens(self._snapshot_llm_tokens())
|
||||
logger.info(f"ResearchAgent planning LLM response: {text[:500]}")
|
||||
|
||||
plan_data = self._parse_plan_json(text)
|
||||
if isinstance(plan_data, dict):
|
||||
complexity = plan_data.get("complexity", "moderate")
|
||||
steps = plan_data.get("steps", [])
|
||||
else:
|
||||
complexity = "moderate"
|
||||
steps = plan_data
|
||||
|
||||
# Adaptive depth: cap steps based on assessed complexity
|
||||
cap = COMPLEXITY_CAPS.get(complexity, self.max_steps)
|
||||
cap = min(cap, self.max_steps)
|
||||
steps = steps[:cap]
|
||||
|
||||
logger.info(
|
||||
f"ResearchAgent plan: complexity={complexity}, "
|
||||
f"steps={len(steps)} (cap={cap})"
|
||||
)
|
||||
return steps, complexity
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Planning phase failed: {e}", exc_info=True)
|
||||
return (
|
||||
[{"query": question, "rationale": "Direct investigation (planning failed)"}],
|
||||
"simple",
|
||||
)
|
||||
|
||||
def _parse_plan_json(self, text: str):
|
||||
"""Extract JSON plan from LLM response. Returns dict or list."""
|
||||
# Try direct parse
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if isinstance(data, dict) and "steps" in data:
|
||||
return data
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try extracting from markdown code fences
|
||||
for marker in ["```json", "```"]:
|
||||
if marker in text:
|
||||
start = text.index(marker) + len(marker)
|
||||
end = text.index("```", start) if "```" in text[start:] else len(text)
|
||||
try:
|
||||
data = json.loads(text[start:end].strip())
|
||||
if isinstance(data, dict) and "steps" in data:
|
||||
return data
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# Try finding JSON object in text
|
||||
for i, ch in enumerate(text):
|
||||
if ch == "{":
|
||||
for j in range(len(text) - 1, i, -1):
|
||||
if text[j] == "}":
|
||||
try:
|
||||
data = json.loads(text[i : j + 1])
|
||||
if isinstance(data, dict) and "steps" in data:
|
||||
return data
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
break
|
||||
|
||||
logger.warning(f"Could not parse plan JSON from: {text[:200]}")
|
||||
return []
|
||||
|
||||
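    # --- Illustrative example (plan content invented) ---------------------------------
    # The fallback above recovers a plan that the model wrapped in a ```json fence,
    # where a direct json.loads on the full text would fail:
    #
    #     raw = (
    #         "Here is the plan:\n```json\n"
    #         '{"complexity": "simple", "steps": [{"query": "What is DocsGPT?", "rationale": "background"}]}'
    #         "\n```"
    #     )
    #     self._parse_plan_json(raw)
    #     # -> {"complexity": "simple", "steps": [...]}
    # ------------------------------------------------------------------------------------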
# ------------------------------------------------------------------
|
||||
# Phase 2: Research step (core loop)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _research_step(self, step_query: str, tools_dict: Dict) -> str:
|
||||
"""Run a focused research loop for one sub-question (sequential path)."""
|
||||
report = self._research_step_with_executor(
|
||||
step_query, tools_dict, self.tool_executor
|
||||
)
|
||||
self._collect_step_sources()
|
||||
return report
|
||||
|
||||
def _research_step_with_executor(
|
||||
self, step_query: str, tools_dict: Dict, executor: ToolExecutor
|
||||
) -> str:
|
||||
"""Core research loop. Works with any ToolExecutor instance."""
|
||||
system_prompt = STEP_PROMPT.replace("{step_query}", step_query)
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": step_query},
|
||||
]
|
||||
|
||||
last_search_empty = False
|
||||
|
||||
for iteration in range(self.max_sub_iterations):
|
||||
# Check timeout and budget
|
||||
if self._is_timed_out():
|
||||
logger.info(
|
||||
f"Research step '{step_query[:50]}' timed out at iteration {iteration}"
|
||||
)
|
||||
break
|
||||
if self._is_over_budget():
|
||||
logger.info(
|
||||
f"Research step '{step_query[:50]}' hit token budget at iteration {iteration}"
|
||||
)
|
||||
break
|
||||
|
||||
try:
|
||||
response = self.llm.gen(
|
||||
model=self.upstream_model_id,
|
||||
messages=messages,
|
||||
tools=self.tools if self.tools else None,
|
||||
)
|
||||
self._track_tokens(self._snapshot_llm_tokens())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Research step LLM call failed (iteration {iteration}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
break
|
||||
|
||||
parsed = self.llm_handler.parse_response(response)
|
||||
|
||||
if not parsed.requires_tool_call:
|
||||
return parsed.content or "No findings for this step."
|
||||
|
||||
# Execute tool calls
|
||||
messages, last_search_empty = self._execute_step_tools_with_refinement(
|
||||
parsed.tool_calls, tools_dict, messages, executor, last_search_empty
|
||||
)
|
||||
|
||||
# Max iterations / timeout / budget — ask for summary
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Please summarize your findings so far based on the information gathered.",
|
||||
}
|
||||
)
|
||||
try:
|
||||
response = self.llm.gen(
|
||||
model=self.upstream_model_id, messages=messages, tools=None
|
||||
)
|
||||
self._track_tokens(self._snapshot_llm_tokens())
|
||||
text = self._extract_text(response)
|
||||
return text or "Research step completed."
|
||||
except Exception:
|
||||
return "Research step completed."
|
||||
|
||||
def _execute_step_tools_with_refinement(
|
||||
self,
|
||||
tool_calls,
|
||||
tools_dict: Dict,
|
||||
messages: List[Dict],
|
||||
executor: ToolExecutor,
|
||||
last_search_empty: bool,
|
||||
) -> tuple[List[Dict], bool]:
|
||||
"""Execute tool calls with query refinement on empty results.
|
||||
|
||||
Returns (updated_messages, was_last_search_empty).
|
||||
"""
|
||||
search_returned_empty = False
|
||||
|
||||
for call in tool_calls:
|
||||
gen = executor.execute(
|
||||
tools_dict, call, self.llm.__class__.__name__
|
||||
)
|
||||
result = None
|
||||
call_id = None
|
||||
while True:
|
||||
try:
|
||||
event = next(gen)
|
||||
# Log tool_call status events instead of discarding them
|
||||
if isinstance(event, dict) and event.get("type") == "tool_call":
|
||||
logger.debug(
|
||||
"Tool %s status: %s",
|
||||
event.get("data", {}).get("action_name", ""),
|
||||
event.get("data", {}).get("status", ""),
|
||||
)
|
||||
except StopIteration as e:
|
||||
result, call_id = e.value
|
||||
break
|
||||
|
||||
# Detect empty search results for refinement
|
||||
is_search = "search" in (call.name or "").lower()
|
||||
result_str = str(result) if result else ""
|
||||
if is_search and "No documents found" in result_str:
|
||||
search_returned_empty = True
|
||||
if last_search_empty:
|
||||
# Two consecutive empty searches — inject refinement hint
|
||||
result_str += (
|
||||
"\n\nHint: Previous search also returned no results. "
|
||||
"Try a very different query with different keywords, "
|
||||
"or broaden your search terms."
|
||||
)
|
||||
result = result_str
|
||||
|
||||
import json as _json
|
||||
|
||||
args_str = (
|
||||
_json.dumps(call.arguments)
|
||||
if isinstance(call.arguments, dict)
|
||||
else call.arguments
|
||||
)
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {"name": call.name, "arguments": args_str},
|
||||
}],
|
||||
})
|
||||
tool_message = self.llm_handler.create_tool_message(call, result)
|
||||
messages.append(tool_message)
|
||||
|
||||
return messages, search_returned_empty
|
||||
|
||||
def _collect_step_sources(self):
|
||||
"""Collect sources from InternalSearchTool and register with CitationManager."""
|
||||
cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
|
||||
tool = self.tool_executor._loaded_tools.get(cache_key)
|
||||
if tool and hasattr(tool, "retrieved_docs"):
|
||||
for doc in tool.retrieved_docs:
|
||||
self.citations.add(doc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 3: Synthesis
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _synthesis_phase(
|
||||
self,
|
||||
question: str,
|
||||
plan: List[Dict],
|
||||
intermediate_reports: List[Dict],
|
||||
tools_dict: Dict,
|
||||
log_context: LogContext,
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Compile all findings into a final cited report (streaming)."""
|
||||
plan_lines = []
|
||||
for i, step in enumerate(plan, 1):
|
||||
plan_lines.append(
|
||||
f"{i}. {step.get('query', 'Unknown')} — {step.get('rationale', '')}"
|
||||
)
|
||||
plan_summary = "\n".join(plan_lines)
|
||||
|
||||
findings_parts = []
|
||||
for i, report in enumerate(intermediate_reports, 1):
|
||||
step_query = report["step"].get("query", "Unknown")
|
||||
content = report["content"]
|
||||
findings_parts.append(
|
||||
f"--- Step {i}: {step_query} ---\n{content}"
|
||||
)
|
||||
findings = "\n\n".join(findings_parts)
|
||||
|
||||
references = self.citations.format_references()
|
||||
|
||||
synthesis_prompt = SYNTHESIS_PROMPT.replace("{question}", question)
|
||||
synthesis_prompt = synthesis_prompt.replace("{plan_summary}", plan_summary)
|
||||
synthesis_prompt = synthesis_prompt.replace("{findings}", findings)
|
||||
synthesis_prompt = synthesis_prompt.replace("{references}", references)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": synthesis_prompt},
|
||||
{"role": "user", "content": f"Please write the research report for: {question}"},
|
||||
]
|
||||
|
||||
llm_response = self.llm.gen_stream(
|
||||
model=self.upstream_model_id, messages=messages, tools=None
|
||||
)
|
||||
|
||||
if log_context:
|
||||
from application.logging import build_stack_data
|
||||
|
||||
log_context.stacks.append(
|
||||
{"component": "synthesis_llm", "data": build_stack_data(self.llm)}
|
||||
)
|
||||
|
||||
yield from self._handle_response(
|
||||
llm_response, tools_dict, messages, log_context
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _extract_text(self, response) -> str:
|
||||
"""Extract text content from a non-streaming LLM response."""
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
if hasattr(response, "message") and hasattr(response.message, "content"):
|
||||
return response.message.content or ""
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "message") and hasattr(choice.message, "content"):
|
||||
return choice.message.content or ""
|
||||
if hasattr(response, "content") and isinstance(response.content, list):
|
||||
if response.content and hasattr(response.content[0], "text"):
|
||||
return response.content[0].text or ""
|
||||
return str(response) if response else ""
|
||||
application/agents/tool_executor.py (new file, 494 lines)
@@ -0,0 +1,494 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.security.encryption import decrypt_credentials
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""Handles tool discovery, preparation, and execution.
|
||||
|
||||
Extracted from BaseAgent to separate concerns and enable tool caching.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_api_key: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
decoded_token: Optional[Dict] = None,
|
||||
):
|
||||
self.user_api_key = user_api_key
|
||||
self.user = user
|
||||
self.decoded_token = decoded_token
|
||||
self.tool_calls: List[Dict] = []
|
||||
self._loaded_tools: Dict[str, object] = {}
|
||||
self.conversation_id: Optional[str] = None
|
||||
self.client_tools: Optional[List[Dict]] = None
|
||||
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
|
||||
self._tool_to_name: Dict[Tuple[str, str], str] = {}
|
||||
|
||||
def get_tools(self) -> Dict[str, Dict]:
|
||||
"""Load tool configs from DB based on user context.
|
||||
|
||||
If *client_tools* have been set on this executor, they are
|
||||
automatically merged into the returned dict.
|
||||
"""
|
||||
if self.user_api_key:
|
||||
tools = self._get_tools_by_api_key(self.user_api_key)
|
||||
else:
|
||||
tools = self._get_user_tools(self.user or "local")
|
||||
if self.client_tools:
|
||||
self.merge_client_tools(tools, self.client_tools)
|
||||
return tools
|
||||
|
||||
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
|
||||
# Per-operation session: the answer pipeline spans a long-lived
|
||||
# generator; wrapping it in a single connection would pin a PG
|
||||
# conn for the whole stream. Open, fetch, close.
|
||||
with db_readonly() as conn:
|
||||
agent_data = AgentsRepository(conn).find_by_key(api_key)
|
||||
tool_ids = agent_data.get("tools", []) if agent_data else []
|
||||
if not tool_ids:
|
||||
return {}
|
||||
tools_repo = UserToolsRepository(conn)
|
||||
tools: List[Dict] = []
|
||||
owner = (agent_data.get("user_id") or agent_data.get("user")) if agent_data else None
|
||||
for tid in tool_ids:
|
||||
row = None
|
||||
if owner:
|
||||
row = tools_repo.get_any(str(tid), owner)
|
||||
if row is not None:
|
||||
tools.append(row)
|
||||
return {str(tool["id"]): tool for tool in tools} if tools else {}
|
||||
|
||||
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict]:
|
||||
with db_readonly() as conn:
|
||||
user_tools = UserToolsRepository(conn).list_active_for_user(user)
|
||||
return {str(i): tool for i, tool in enumerate(user_tools)}
|
||||
|
||||
def merge_client_tools(
|
||||
self, tools_dict: Dict, client_tools: List[Dict]
|
||||
) -> Dict:
|
||||
"""Merge client-provided tool definitions into tools_dict.
|
||||
|
||||
Client tools use the standard function-calling format::
|
||||
|
||||
[{"type": "function", "function": {"name": "get_weather",
|
||||
"description": "...", "parameters": {...}}}]
|
||||
|
||||
They are stored in *tools_dict* with ``client_side: True`` so that
|
||||
:meth:`check_pause` returns a pause signal instead of trying to
|
||||
execute them server-side.
|
||||
|
||||
Args:
|
||||
tools_dict: The mutable server tools dict (will be modified in place).
|
||||
client_tools: List of tool definitions in function-calling format.
|
||||
|
||||
Returns:
|
||||
The updated *tools_dict* (same reference, for convenience).
|
||||
"""
|
||||
for i, ct in enumerate(client_tools):
|
||||
func = ct.get("function", ct) # tolerate bare {"name":..} too
|
||||
name = func.get("name", f"clienttool{i}")
|
||||
tool_id = f"ct{i}"
|
||||
|
||||
tools_dict[tool_id] = {
|
||||
"name": name,
|
||||
"client_side": True,
|
||||
"actions": [
|
||||
{
|
||||
"name": name,
|
||||
"description": func.get("description", ""),
|
||||
"active": True,
|
||||
"parameters": func.get("parameters", {}),
|
||||
}
|
||||
],
|
||||
}
|
||||
return tools_dict
|
||||
|
||||
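    # --- Merge sketch (tool definition invented) ----------------------------------------
    # merge_client_tools() stores client definitions under client_side=True so that
    # check_pause() later returns a "requires_client_execution" signal instead of
    # running the tool server-side:
    #
    #     client_tools = [{
    #         "type": "function",
    #         "function": {
    #             "name": "get_weather",
    #             "description": "Get the current weather for a city",
    #             "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    #         },
    #     }]
    #     executor.merge_client_tools(tools_dict, client_tools)
    #     # tools_dict["ct0"] == {"name": "get_weather", "client_side": True, "actions": [...]}
    # --------------------------------------------------------------------------------------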
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
|
||||
"""Convert tool configs to LLM function schemas.
|
||||
|
||||
Action names are kept clean for the LLM:
|
||||
- Unique action names appear as-is (e.g. ``get_weather``).
|
||||
- Duplicate action names get numbered suffixes (e.g. ``search_1``,
|
||||
``search_2``).
|
||||
|
||||
A reverse mapping is stored in ``_name_to_tool`` so that tool calls
|
||||
can be routed back to the correct ``(tool_id, action_name)`` without
|
||||
brittle string splitting.
|
||||
"""
|
||||
# Pass 1: collect entries and count action name occurrences
|
||||
entries: List[Tuple[str, str, Dict, bool]] = [] # (tool_id, action_name, action, is_client)
|
||||
name_counts: Counter = Counter()
|
||||
|
||||
for tool_id, tool in tools_dict.items():
|
||||
is_api = tool["name"] == "api_tool"
|
||||
is_client = tool.get("client_side", False)
|
||||
|
||||
if is_api and "actions" not in tool.get("config", {}):
|
||||
continue
|
||||
if not is_api and "actions" not in tool:
|
||||
continue
|
||||
|
||||
actions = (
|
||||
tool["config"]["actions"].values()
|
||||
if is_api
|
||||
else tool["actions"]
|
||||
)
|
||||
|
||||
for action in actions:
|
||||
if not action.get("active", True):
|
||||
continue
|
||||
entries.append((tool_id, action["name"], action, is_client))
|
||||
name_counts[action["name"]] += 1
|
||||
|
||||
# Pass 2: assign LLM-visible names and build mappings
|
||||
self._name_to_tool = {}
|
||||
self._tool_to_name = {}
|
||||
collision_counters: Dict[str, int] = {}
|
||||
all_llm_names: set = set()
|
||||
|
||||
result = []
|
||||
for tool_id, action_name, action, is_client in entries:
|
||||
if name_counts[action_name] == 1:
|
||||
llm_name = action_name
|
||||
else:
|
||||
counter = collision_counters.get(action_name, 1)
|
||||
candidate = f"{action_name}_{counter}"
|
||||
# Skip if candidate collides with a unique action name
|
||||
while candidate in all_llm_names or (
|
||||
candidate in name_counts and name_counts[candidate] == 1
|
||||
):
|
||||
counter += 1
|
||||
candidate = f"{action_name}_{counter}"
|
||||
collision_counters[action_name] = counter + 1
|
||||
llm_name = candidate
|
||||
|
||||
all_llm_names.add(llm_name)
|
||||
self._name_to_tool[llm_name] = (tool_id, action_name)
|
||||
self._tool_to_name[(tool_id, action_name)] = llm_name
|
||||
|
||||
if is_client:
|
||||
params = action.get("parameters", {})
|
||||
else:
|
||||
params = self._build_tool_parameters(action)
|
||||
|
||||
result.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": llm_name,
|
||||
"description": action.get("description", ""),
|
||||
"parameters": params,
|
||||
},
|
||||
})
|
||||
return result
|
||||
|
||||
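    # --- Naming-rule sketch (action names invented) ---------------------------------------
    # prepare_tools_for_llm() keeps unique action names as-is and numbers duplicates:
    #
    #     actions: "search" (tool A), "search" (tool B), "get_weather" (tool C)
    #     LLM-visible names: "search_1", "search_2", "get_weather"
    #     _name_to_tool: {"search_1": (A, "search"), "search_2": (B, "search"),
    #                     "get_weather": (C, "get_weather")}
    # ----------------------------------------------------------------------------------------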
def _build_tool_parameters(self, action: Dict) -> Dict:
|
||||
params = {"type": "object", "properties": {}, "required": []}
|
||||
for param_type in ["query_params", "headers", "body", "parameters"]:
|
||||
if param_type in action and action[param_type].get("properties"):
|
||||
for k, v in action[param_type]["properties"].items():
|
||||
if v.get("filled_by_llm", True):
|
||||
params["properties"][k] = {
|
||||
key: value
|
||||
for key, value in v.items()
|
||||
if key not in ("filled_by_llm", "value", "required")
|
||||
}
|
||||
if v.get("required", False):
|
||||
params["required"].append(k)
|
||||
return params
|
||||
|
||||
def check_pause(
|
||||
self, tools_dict: Dict, call, llm_class_name: str
|
||||
) -> Optional[Dict]:
|
||||
"""Check if a tool call requires pausing for approval or client execution.
|
||||
|
||||
Returns a dict describing the pending action if pause is needed, None otherwise.
|
||||
"""
|
||||
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||
llm_name = getattr(call, "name", "")
|
||||
|
||||
if tool_id is None or action_name is None or tool_id not in tools_dict:
|
||||
return None # Will be handled as error by execute()
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
|
||||
# Client-side tools
|
||||
if tool_data.get("client_side"):
|
||||
return {
|
||||
"call_id": call_id,
|
||||
"name": llm_name,
|
||||
"tool_name": tool_data.get("name", "unknown"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"llm_name": llm_name,
|
||||
"arguments": call_args if isinstance(call_args, dict) else {},
|
||||
"pause_type": "requires_client_execution",
|
||||
"thought_signature": getattr(call, "thought_signature", None),
|
||||
}
|
||||
|
||||
# Approval required
|
||||
if tool_data["name"] == "api_tool":
|
||||
action_data = tool_data.get("config", {}).get("actions", {}).get(
|
||||
action_name, {}
|
||||
)
|
||||
else:
|
||||
action_data = next(
|
||||
(a for a in tool_data.get("actions", []) if a["name"] == action_name),
|
||||
{},
|
||||
)
|
||||
|
||||
if action_data.get("require_approval"):
|
||||
return {
|
||||
"call_id": call_id,
|
||||
"name": llm_name,
|
||||
"tool_name": tool_data.get("name", "unknown"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"llm_name": llm_name,
|
||||
"arguments": call_args if isinstance(call_args, dict) else {},
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": getattr(call, "thought_signature", None),
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
def execute(self, tools_dict: Dict, call, llm_class_name: str):
|
||||
"""Execute a tool call. Yields status events, returns (result, call_id)."""
|
||||
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
llm_name = getattr(call, "name", "unknown")
|
||||
|
||||
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||
|
||||
if tool_id is None or action_name is None:
|
||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
|
||||
logger.error(error_message)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
"call_id": call_id,
|
||||
"action_name": llm_name,
|
||||
"arguments": call_args or {},
|
||||
"result": f"Failed to parse tool call. Invalid tool name format: {llm_name}",
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return "Failed to parse tool call.", call_id
|
||||
|
||||
if tool_id not in tools_dict:
|
||||
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
|
||||
logger.error(error_message)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
"call_id": call_id,
|
||||
"action_name": llm_name,
|
||||
"arguments": call_args,
|
||||
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return f"Tool with ID {tool_id} not found.", call_id
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": tools_dict[tool_id]["name"],
|
||||
"call_id": call_id,
|
||||
"action_name": llm_name,
|
||||
"arguments": call_args,
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
action_data = (
|
||||
tool_data["config"]["actions"][action_name]
|
||||
if tool_data["name"] == "api_tool"
|
||||
else next(
|
||||
action
|
||||
for action in tool_data["actions"]
|
||||
if action["name"] == action_name
|
||||
)
|
||||
)
|
||||
|
||||
query_params, headers, body, parameters = {}, {}, {}, {}
|
||||
param_types = {
|
||||
"query_params": query_params,
|
||||
"headers": headers,
|
||||
"body": body,
|
||||
"parameters": parameters,
|
||||
}
|
||||
|
||||
for param_type, target_dict in param_types.items():
|
||||
if param_type in action_data and action_data[param_type].get("properties"):
|
||||
for param, details in action_data[param_type]["properties"].items():
|
||||
if (
|
||||
param not in call_args
|
||||
and "value" in details
|
||||
and details["value"]
|
||||
):
|
||||
target_dict[param] = details["value"]
|
||||
for param, value in call_args.items():
|
||||
for param_type, target_dict in param_types.items():
|
||||
if param_type in action_data and param in action_data[param_type].get(
|
||||
"properties", {}
|
||||
):
|
||||
target_dict[param] = value
|
||||
|
||||
# Load tool (with caching)
|
||||
tool = self._get_or_load_tool(
|
||||
tool_data, tool_id, action_name,
|
||||
headers=headers, query_params=query_params,
|
||||
)
|
||||
|
||||
if tool is None:
|
||||
error_message = (
|
||||
f"Failed to load tool '{tool_data.get('name')}' (tool_id key={tool_id}): "
|
||||
"missing 'id' on tool row."
|
||||
)
|
||||
logger.error(error_message)
|
||||
tool_call_data["result"] = error_message
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return error_message, call_id
|
||||
|
||||
resolved_arguments = (
|
||||
{"query_params": query_params, "headers": headers, "body": body}
|
||||
if tool_data["name"] == "api_tool"
|
||||
else parameters
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
logger.debug(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
)
|
||||
result = tool.execute_action(action_name, **body)
|
||||
else:
|
||||
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
|
||||
result = tool.execute_action(action_name, **parameters)
|
||||
|
||||
get_artifact_id = (
|
||||
getattr(tool, "get_artifact_id", None)
|
||||
if tool_data["name"] != "api_tool"
|
||||
else None
|
||||
)
|
||||
|
||||
artifact_id = None
|
||||
if callable(get_artifact_id):
|
||||
try:
|
||||
artifact_id = get_artifact_id(action_name, **parameters)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to extract artifact_id from tool %s for action %s",
|
||||
tool_data["name"],
|
||||
action_name,
|
||||
)
|
||||
|
||||
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
|
||||
if artifact_id:
|
||||
tool_call_data["artifact_id"] = artifact_id
|
||||
result_full = str(result)
|
||||
tool_call_data["resolved_arguments"] = resolved_arguments
|
||||
tool_call_data["result_full"] = result_full
|
||||
tool_call_data["result"] = (
|
||||
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
|
||||
)
|
||||
|
||||
stream_tool_call_data = {
|
||||
key: value
|
||||
for key, value in tool_call_data.items()
|
||||
if key not in {"result_full", "resolved_arguments"}
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
|
||||
return result, call_id
|
||||
|
||||
def _get_or_load_tool(
|
||||
self, tool_data: Dict, tool_id: str, action_name: str,
|
||||
headers: Optional[Dict] = None, query_params: Optional[Dict] = None,
|
||||
):
|
||||
"""Load a tool, using cache when possible."""
|
||||
cache_key = f"{tool_data['name']}:{tool_id}:{self.user or ''}"
|
||||
if cache_key in self._loaded_tools:
|
||||
return self._loaded_tools[cache_key]
|
||||
|
||||
tm = ToolManager(config={})
|
||||
|
||||
if tool_data["name"] == "api_tool":
|
||||
action_config = tool_data["config"]["actions"][action_name]
|
||||
tool_config = {
|
||||
"url": action_config["url"],
|
||||
"method": action_config["method"],
|
||||
"headers": headers or {},
|
||||
"query_params": query_params or {},
|
||||
}
|
||||
if "body_content_type" in action_config:
|
||||
tool_config["body_content_type"] = action_config.get(
|
||||
"body_content_type", "application/json"
|
||||
)
|
||||
tool_config["body_encoding_rules"] = action_config.get(
|
||||
"body_encoding_rules", {}
|
||||
)
|
||||
else:
|
||||
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
|
||||
if tool_config.get("encrypted_credentials") and self.user:
|
||||
decrypted = decrypt_credentials(
|
||||
tool_config["encrypted_credentials"], self.user
|
||||
)
|
||||
tool_config.update(decrypted)
|
||||
tool_config["auth_credentials"] = decrypted
|
||||
tool_config.pop("encrypted_credentials", None)
|
||||
row_id = tool_data.get("id")
|
||||
if not row_id:
|
||||
logger.error(
|
||||
"Tool data missing 'id' for tool name=%s (enumerate-key tool_id=%s); "
|
||||
"skipping load to avoid binding a non-UUID downstream.",
|
||||
tool_data.get("name"),
|
||||
tool_id,
|
||||
)
|
||||
return None
|
||||
tool_config["tool_id"] = str(row_id)
|
||||
if self.conversation_id:
|
||||
tool_config["conversation_id"] = self.conversation_id
|
||||
if tool_data["name"] == "mcp_tool":
|
||||
tool_config["query_mode"] = True
|
||||
|
||||
tool = tm.load_tool(
|
||||
tool_data["name"],
|
||||
tool_config=tool_config,
|
||||
user_id=self.user,
|
||||
)
|
||||
|
||||
# Don't cache api_tool since config varies by action
|
||||
if tool_data["name"] != "api_tool":
|
||||
self._loaded_tools[cache_key] = tool
|
||||
|
||||
return tool
|
||||
|
||||
def get_truncated_tool_calls(self) -> List[Dict]:
|
||||
return [
|
||||
{
|
||||
"tool_name": tool_call.get("tool_name"),
|
||||
"call_id": tool_call.get("call_id"),
|
||||
"action_name": tool_call.get("action_name"),
|
||||
"arguments": tool_call.get("arguments"),
|
||||
"artifact_id": tool_call.get("artifact_id"),
|
||||
"result": (
|
||||
f"{str(tool_call['result'])[:50]}..."
|
||||
if len(str(tool_call["result"])) > 50
|
||||
else tool_call["result"]
|
||||
),
|
||||
"status": "completed",
|
||||
}
|
||||
for tool_call in self.tool_calls
|
||||
]
|
||||
application/agents/tools/api_body_serializer.py (new file, 323 lines)
@@ -0,0 +1,323 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContentType(str, Enum):
|
||||
"""Supported content types for request bodies."""
|
||||
|
||||
JSON = "application/json"
|
||||
FORM_URLENCODED = "application/x-www-form-urlencoded"
|
||||
MULTIPART_FORM_DATA = "multipart/form-data"
|
||||
TEXT_PLAIN = "text/plain"
|
||||
XML = "application/xml"
|
||||
OCTET_STREAM = "application/octet-stream"
|
||||
|
||||
|
||||
class RequestBodySerializer:
|
||||
"""Serializes request bodies according to content-type and OpenAPI 3.1 spec."""
|
||||
|
||||
@staticmethod
|
||||
def serialize(
|
||||
body_data: Dict[str, Any],
|
||||
content_type: str = ContentType.JSON,
|
||||
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
) -> tuple[Union[str, bytes], Dict[str, str]]:
|
||||
"""
|
||||
Serialize body data to appropriate format.
|
||||
|
||||
Args:
|
||||
body_data: Dictionary of body parameters
|
||||
content_type: Content-Type header value
|
||||
encoding_rules: OpenAPI Encoding Object rules per field
|
||||
|
||||
Returns:
|
||||
Tuple of (serialized_body, updated_headers_dict)
|
||||
|
||||
Raises:
|
||||
ValueError: If serialization fails
|
||||
"""
|
||||
if not body_data:
|
||||
return None, {}
|
||||
|
||||
try:
|
||||
content_type_lower = content_type.lower().split(";")[0].strip()
|
||||
|
||||
if content_type_lower == ContentType.JSON:
|
||||
return RequestBodySerializer._serialize_json(body_data)
|
||||
|
||||
elif content_type_lower == ContentType.FORM_URLENCODED:
|
||||
return RequestBodySerializer._serialize_form_urlencoded(
|
||||
body_data, encoding_rules
|
||||
)
|
||||
|
||||
elif content_type_lower == ContentType.MULTIPART_FORM_DATA:
|
||||
return RequestBodySerializer._serialize_multipart_form_data(
|
||||
body_data, encoding_rules
|
||||
)
|
||||
|
||||
elif content_type_lower == ContentType.TEXT_PLAIN:
|
||||
return RequestBodySerializer._serialize_text_plain(body_data)
|
||||
|
||||
elif content_type_lower == ContentType.XML:
|
||||
return RequestBodySerializer._serialize_xml(body_data)
|
||||
|
||||
elif content_type_lower == ContentType.OCTET_STREAM:
|
||||
return RequestBodySerializer._serialize_octet_stream(body_data)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unknown content type: {content_type}, treating as JSON"
|
||||
)
|
||||
return RequestBodySerializer._serialize_json(body_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error serializing body: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Failed to serialize request body: {str(e)}")
|
||||
|
||||
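# Illustrative usage of serialize() (a sketch, not part of the module): the
# content type selects the serializer and the matching header is returned.
#
#     >>> body, headers = RequestBodySerializer.serialize(
#     ...     {"name": "docs", "tags": ["a", "b"]}, "application/json"
#     ... )
#     >>> body
#     '{"name":"docs","tags":["a","b"]}'
#     >>> headers
#     {'Content-Type': 'application/json'}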
@staticmethod
|
||||
def _serialize_json(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
|
||||
"""Serialize body as JSON per OpenAPI spec."""
|
||||
try:
|
||||
serialized = json.dumps(
|
||||
body_data, separators=(",", ":"), ensure_ascii=False
|
||||
)
|
||||
headers = {"Content-Type": ContentType.JSON.value}
|
||||
return serialized, headers
|
||||
except (TypeError, ValueError) as e:
|
||||
raise ValueError(f"Failed to serialize JSON body: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _serialize_form_urlencoded(
|
||||
body_data: Dict[str, Any],
|
||||
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
) -> tuple[str, Dict[str, str]]:
|
||||
"""Serialize body as application/x-www-form-urlencoded per RFC1866/RFC3986."""
|
||||
encoding_rules = encoding_rules or {}
|
||||
params = []
|
||||
|
||||
for key, value in body_data.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
rule = encoding_rules.get(key, {})
|
||||
style = rule.get("style", "form")
|
||||
explode = rule.get("explode", style == "form")
|
||||
content_type = rule.get("contentType", "text/plain")
|
||||
|
||||
serialized_value = RequestBodySerializer._serialize_form_value(
|
||||
value, style, explode, content_type, key
|
||||
)
|
||||
|
||||
if isinstance(serialized_value, list):
|
||||
for sv in serialized_value:
|
||||
params.append((key, sv))
|
||||
else:
|
||||
params.append((key, serialized_value))
|
||||
|
||||
# Use standard urlencode (replaces space with +)
|
||||
serialized = urlencode(params, safe="")
|
||||
headers = {"Content-Type": ContentType.FORM_URLENCODED.value}
|
||||
return serialized, headers
|
||||
|
||||
@staticmethod
|
||||
def _serialize_form_value(
|
||||
value: Any, style: str, explode: bool, content_type: str, key: str
|
||||
) -> Union[str, list]:
|
||||
"""Serialize individual form value with encoding rules."""
|
||||
if isinstance(value, dict):
|
||||
if content_type == "application/json":
|
||||
return json.dumps(value, separators=(",", ":"))
|
||||
elif content_type == "application/xml":
|
||||
return RequestBodySerializer._dict_to_xml(value)
|
||||
else:
|
||||
if style == "deepObject" and explode:
|
||||
return [
|
||||
f"{RequestBodySerializer._percent_encode(str(v))}"
|
||||
for v in value.values()
|
||||
]
|
||||
elif explode:
|
||||
return [
|
||||
f"{RequestBodySerializer._percent_encode(str(v))}"
|
||||
for v in value.values()
|
||||
]
|
||||
else:
|
||||
pairs = [f"{k},{v}" for k, v in value.items()]
|
||||
return RequestBodySerializer._percent_encode(",".join(pairs))
|
||||
|
||||
elif isinstance(value, (list, tuple)):
|
||||
if explode:
|
||||
return [
|
||||
RequestBodySerializer._percent_encode(str(item)) for item in value
|
||||
]
|
||||
else:
|
||||
return RequestBodySerializer._percent_encode(
|
||||
",".join(str(v) for v in value)
|
||||
)
|
||||
|
||||
else:
|
||||
return RequestBodySerializer._percent_encode(str(value))
|
||||
|
||||
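# Illustrative behaviour of the form-urlencoded path (a sketch): with the
# default style="form"/explode=True rules, a list value expands into one
# key=value pair per item.
#
#     >>> body, headers = RequestBodySerializer.serialize(
#     ...     {"tag": ["a", "b"], "page": 2},
#     ...     "application/x-www-form-urlencoded",
#     ... )
#     >>> body
#     'tag=a&tag=b&page=2'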
@staticmethod
|
||||
def _serialize_multipart_form_data(
|
||||
body_data: Dict[str, Any],
|
||||
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
) -> tuple[bytes, Dict[str, str]]:
|
||||
"""
|
||||
Serialize body as multipart/form-data per RFC7578.
|
||||
|
||||
Supports file uploads and encoding rules.
|
||||
"""
|
||||
import secrets
|
||||
|
||||
encoding_rules = encoding_rules or {}
|
||||
boundary = f"----DocsGPT{secrets.token_hex(16)}"
|
||||
parts = []
|
||||
|
||||
for key, value in body_data.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
rule = encoding_rules.get(key, {})
|
||||
content_type = rule.get("contentType", "text/plain")
|
||||
headers_rule = rule.get("headers", {})
|
||||
|
||||
part = RequestBodySerializer._create_multipart_part(
|
||||
key, value, content_type, headers_rule
|
||||
)
|
||||
parts.append(part)
|
||||
|
||||
body_bytes = f"--{boundary}\r\n".encode("utf-8")
|
||||
body_bytes += f"--{boundary}\r\n".join(parts).encode("utf-8")
|
||||
body_bytes += f"\r\n--{boundary}--\r\n".encode("utf-8")
|
||||
|
||||
headers = {
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary}",
|
||||
}
|
||||
return body_bytes, headers
|
||||
|
||||
@staticmethod
|
||||
def _create_multipart_part(
|
||||
name: str, value: Any, content_type: str, headers_rule: Dict[str, Any]
|
||||
) -> str:
|
||||
"""Create a single multipart/form-data part."""
|
||||
headers = [
|
||||
f'Content-Disposition: form-data; name="{RequestBodySerializer._percent_encode(name)}"'
|
||||
]
|
||||
|
||||
if isinstance(value, bytes):
|
||||
if content_type == "application/octet-stream":
|
||||
value_encoded = base64.b64encode(value).decode("utf-8")
|
||||
else:
|
||||
value_encoded = value.decode("utf-8", errors="replace")
|
||||
headers.append(f"Content-Type: {content_type}")
|
||||
headers.append("Content-Transfer-Encoding: base64")
|
||||
elif isinstance(value, dict):
|
||||
if content_type == "application/json":
|
||||
value_encoded = json.dumps(value, separators=(",", ":"))
|
||||
elif content_type == "application/xml":
|
||||
value_encoded = RequestBodySerializer._dict_to_xml(value)
|
||||
else:
|
||||
value_encoded = str(value)
|
||||
headers.append(f"Content-Type: {content_type}")
|
||||
elif isinstance(value, str) and content_type != "text/plain":
|
||||
try:
|
||||
if content_type == "application/json":
|
||||
json.loads(value)
|
||||
value_encoded = value
|
||||
elif content_type == "application/xml":
|
||||
value_encoded = value
|
||||
else:
|
||||
value_encoded = str(value)
|
||||
except json.JSONDecodeError:
|
||||
value_encoded = str(value)
|
||||
headers.append(f"Content-Type: {content_type}")
|
||||
else:
|
||||
value_encoded = str(value)
|
||||
if content_type != "text/plain":
|
||||
headers.append(f"Content-Type: {content_type}")
|
||||
|
||||
part = "\r\n".join(headers) + "\r\n\r\n" + value_encoded + "\r\n"
|
||||
return part
|
||||
|
||||
@staticmethod
|
||||
def _serialize_text_plain(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
|
||||
"""Serialize body as plain text."""
|
||||
if len(body_data) == 1:
|
||||
value = list(body_data.values())[0]
|
||||
return str(value), {"Content-Type": ContentType.TEXT_PLAIN.value}
|
||||
else:
|
||||
text = "\n".join(f"{k}: {v}" for k, v in body_data.items())
|
||||
return text, {"Content-Type": ContentType.TEXT_PLAIN.value}
|
||||
|
||||
@staticmethod
|
||||
def _serialize_xml(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
|
||||
"""Serialize body as XML."""
|
||||
xml_str = RequestBodySerializer._dict_to_xml(body_data)
|
||||
return xml_str, {"Content-Type": ContentType.XML.value}
|
||||
|
||||
@staticmethod
|
||||
def _serialize_octet_stream(
|
||||
body_data: Dict[str, Any],
|
||||
) -> tuple[bytes, Dict[str, str]]:
|
||||
"""Serialize body as binary octet stream."""
|
||||
if isinstance(body_data, bytes):
|
||||
return body_data, {"Content-Type": ContentType.OCTET_STREAM.value}
|
||||
elif isinstance(body_data, str):
|
||||
return body_data.encode("utf-8"), {
|
||||
"Content-Type": ContentType.OCTET_STREAM.value
|
||||
}
|
||||
else:
|
||||
serialized = json.dumps(body_data)
|
||||
return serialized.encode("utf-8"), {
|
||||
"Content-Type": ContentType.OCTET_STREAM.value
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _percent_encode(value: str, safe_chars: str = "") -> str:
|
||||
"""
|
||||
Percent-encode per RFC3986.
|
||||
|
||||
Args:
|
||||
value: String to encode
|
||||
safe_chars: Additional characters to not encode
|
||||
"""
|
||||
return quote(value, safe=safe_chars)
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_xml(data: Dict[str, Any], root_name: str = "root") -> str:
|
||||
"""
|
||||
Convert dict to simple XML format.
|
||||
"""
|
||||
|
||||
def build_xml(obj: Any, name: str) -> str:
|
||||
if isinstance(obj, dict):
|
||||
inner = "".join(build_xml(v, k) for k, v in obj.items())
|
||||
return f"<{name}>{inner}</{name}>"
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
items = "".join(
|
||||
build_xml(item, f"{name[:-1] if name.endswith('s') else name}")
|
||||
for item in obj
|
||||
)
|
||||
return items
|
||||
else:
|
||||
return f"<{name}>{RequestBodySerializer._escape_xml(str(obj))}</{name}>"
|
||||
|
||||
root = build_xml(data, root_name)
|
||||
return f'<?xml version="1.0" encoding="UTF-8"?>{root}'
|
||||
|
||||
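# Illustrative output of _dict_to_xml (a sketch): nested dicts become nested
# elements, and list items reuse the singular form of the parent key.
#
#     >>> RequestBodySerializer._dict_to_xml({"user": {"name": "Ada", "roles": ["admin"]}})
#     '<?xml version="1.0" encoding="UTF-8"?><root><user><name>Ada</name><role>admin</role></user></root>'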
    @staticmethod
    def _escape_xml(value: str) -> str:
        """Escape XML special characters."""
        return (
            value.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace('"', "&quot;")
            .replace("'", "&#39;")
        )
|
||||
@@ -1,72 +1,280 @@
|
||||
import json
import logging
import re
from typing import Any, Dict, Optional
from urllib.parse import urlencode

import requests

from application.agents.tools.api_body_serializer import (
    ContentType,
    RequestBodySerializer,
)
from application.agents.tools.base import Tool
from application.core.url_validation import validate_url, SSRFError

logger = logging.getLogger(__name__)

DEFAULT_TIMEOUT = 90  # seconds


class APITool(Tool):
    """
    API Tool
    A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs.
    """

    def __init__(self, config):
        self.config = config
        self.url = config.get("url", "")
        self.method = config.get("method", "GET")
        self.headers = config.get("headers", {})
        self.query_params = config.get("query_params", {})
        self.body_content_type = config.get("body_content_type", ContentType.JSON)
        self.body_encoding_rules = config.get("body_encoding_rules", {})

    def execute_action(self, action_name, **kwargs):
        """Execute an API action with the given arguments."""
        return self._make_api_call(
            self.url,
            self.method,
            self.headers,
            self.query_params,
            kwargs,
            self.body_content_type,
            self.body_encoding_rules,
        )

    def _make_api_call(
        self,
        url: str,
        method: str,
        headers: Dict[str, str],
        query_params: Dict[str, Any],
        body: Dict[str, Any],
        content_type: str = ContentType.JSON,
        encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """
        Make an API call with proper body serialization and error handling.

        Args:
            url: API endpoint URL
            method: HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS)
            headers: Request headers dict
            query_params: URL query parameters
            body: Request body as dict
            content_type: Content-Type for serialization
            encoding_rules: OpenAPI encoding rules

        Returns:
            Dict with status_code, data, and message
        """
        request_url = url
        request_headers = headers.copy() if headers else {}
        response = None

        # Validate URL to prevent SSRF attacks
        try:
            validate_url(request_url)
        except SSRFError as e:
            logger.error(f"URL validation failed: {e}")
            return {
                "status_code": None,
                "message": f"URL validation error: {e}",
                "data": None,
            }

        try:
            path_params_used = set()
            if query_params:
                for match in re.finditer(r"\{([^}]+)\}", request_url):
                    param_name = match.group(1)
                    if param_name in query_params:
                        request_url = request_url.replace(
                            f"{{{param_name}}}", str(query_params[param_name])
                        )
                        path_params_used.add(param_name)
                remaining_params = {
                    k: v for k, v in query_params.items() if k not in path_params_used
                }
                if remaining_params:
                    query_string = urlencode(remaining_params)
                    separator = "&" if "?" in request_url else "?"
                    request_url = f"{request_url}{separator}{query_string}"

            # Re-validate URL after parameter substitution to prevent SSRF via path params
            try:
                validate_url(request_url)
            except SSRFError as e:
                logger.error(f"URL validation failed after parameter substitution: {e}")
                return {
                    "status_code": None,
                    "message": f"URL validation error: {e}",
                    "data": None,
                }

            # Serialize body based on content type
            if body and body != {}:
                try:
                    serialized_body, body_headers = RequestBodySerializer.serialize(
                        body, content_type, encoding_rules
                    )
                    request_headers.update(body_headers)
                except ValueError as e:
                    logger.error(f"Body serialization failed: {str(e)}")
                    return {
                        "status_code": None,
                        "message": f"Body serialization error: {str(e)}",
                        "data": None,
                    }
            else:
                serialized_body = None

            if "Content-Type" not in request_headers and method not in [
                "GET",
                "HEAD",
                "DELETE",
            ]:
                request_headers["Content-Type"] = ContentType.JSON

            logger.debug(
                f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
            )

            if method.upper() == "GET":
                response = requests.get(
                    request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
                )
            elif method.upper() == "POST":
                response = requests.post(
                    request_url,
                    data=serialized_body,
                    headers=request_headers,
                    timeout=DEFAULT_TIMEOUT,
                )
            elif method.upper() == "PUT":
                response = requests.put(
                    request_url,
                    data=serialized_body,
                    headers=request_headers,
                    timeout=DEFAULT_TIMEOUT,
                )
            elif method.upper() == "DELETE":
                response = requests.delete(
                    request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
                )
            elif method.upper() == "PATCH":
                response = requests.patch(
                    request_url,
                    data=serialized_body,
                    headers=request_headers,
                    timeout=DEFAULT_TIMEOUT,
                )
            elif method.upper() == "HEAD":
                response = requests.head(
                    request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
                )
            elif method.upper() == "OPTIONS":
                response = requests.options(
                    request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
                )
            else:
                return {
                    "status_code": None,
                    "message": f"Unsupported HTTP method: {method}",
                    "data": None,
                }
            response.raise_for_status()

            data = self._parse_response(response)

            return {
                "status_code": response.status_code,
                "data": data,
                "message": "API call successful.",
            }
        except requests.exceptions.Timeout:
            logger.error(f"Request timeout for {request_url}")
            return {
                "status_code": None,
                "message": f"Request timeout ({DEFAULT_TIMEOUT}s exceeded)",
                "data": None,
            }
        except requests.exceptions.ConnectionError as e:
            logger.error(f"Connection error: {str(e)}")
            return {
                "status_code": None,
                "message": f"Connection error: {str(e)}",
                "data": None,
            }
        except requests.exceptions.HTTPError as e:
            logger.error(f"HTTP error {response.status_code}: {str(e)}")
            try:
                error_data = response.json()
            except (json.JSONDecodeError, ValueError):
                error_data = response.text
            return {
                "status_code": response.status_code,
                "message": f"HTTP Error {response.status_code}",
                "data": error_data,
            }
        except requests.exceptions.RequestException as e:
            logger.error(f"Request failed: {str(e)}")
            return {
                "status_code": response.status_code if response else None,
                "message": f"API call failed: {str(e)}",
                "data": None,
            }
        except Exception as e:
            logger.error(f"Unexpected error in API call: {str(e)}", exc_info=True)
            return {
                "status_code": None,
                "message": f"Unexpected error: {str(e)}",
                "data": None,
            }

    def _parse_response(self, response: requests.Response) -> Any:
        """
        Parse response based on Content-Type header.

        Supports: JSON, XML, plain text, binary data.
        """
        content_type = response.headers.get("Content-Type", "").lower()

        if not response.content:
            return None

        # JSON response
        if "application/json" in content_type:
            try:
                return response.json()
            except json.JSONDecodeError as e:
                logger.warning(f"Failed to parse JSON response: {str(e)}")
                return response.text
        # XML response
        elif "application/xml" in content_type or "text/xml" in content_type:
            return response.text
        # Plain text response
        elif "text/plain" in content_type or "text/html" in content_type:
            return response.text
        # Binary/unknown response
        else:
            # Try to decode as text first, fall back to base64
            try:
                return response.text
            except (UnicodeDecodeError, AttributeError):
                import base64

                return base64.b64encode(response.content).decode("utf-8")

    def get_actions_metadata(self):
        """Return metadata for available actions (none for API Tool - actions are user-defined)."""
        return []

    def get_config_requirements(self):
        """Return configuration requirements for the tool."""
        return {}
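A rough usage sketch for the reworked APITool (the endpoint, credentials, and body fields below are hypothetical; the config keys match __init__ above):

config = {
    "url": "https://api.example.com/v1/messages",
    "method": "POST",
    "headers": {"Authorization": "Bearer <token>"},
    "query_params": {},
    "body_content_type": "application/json",
}
tool = APITool(config)
result = tool.execute_action("send_message", text="hello", channel="general")
# result is a dict shaped like {"status_code": 200, "data": <parsed body>, "message": "API call successful."}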
@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
internal: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def execute_action(self, action_name: str, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BraveSearchTool(Tool):
|
||||
"""
|
||||
@@ -25,27 +30,35 @@ class BraveSearchTool(Tool):
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
|
||||
def _web_search(
|
||||
self,
|
||||
query,
|
||||
country="ALL",
|
||||
search_lang="en",
|
||||
count=10,
|
||||
offset=0,
|
||||
safesearch="off",
|
||||
freshness=None,
|
||||
result_filter=None,
|
||||
extra_snippets=False,
|
||||
summary=False,
|
||||
):
|
||||
"""
|
||||
Performs a web search using the Brave Search API.
|
||||
"""
|
||||
print(f"Performing Brave web search for: {query}")
|
||||
|
||||
logger.debug("Performing Brave web search for: %s", query)
|
||||
|
||||
url = f"{self.base_url}/web/search"
|
||||
|
||||
# Build query parameters
|
||||
|
||||
params = {
|
||||
"q": query,
|
||||
"country": country,
|
||||
"search_lang": search_lang,
|
||||
"count": min(count, 20),
|
||||
"offset": min(offset, 9),
|
||||
"safesearch": safesearch
|
||||
"safesearch": safesearch,
|
||||
}
|
||||
|
||||
# Add optional parameters only if they have values
|
||||
|
||||
if freshness:
|
||||
params["freshness"] = freshness
|
||||
if result_filter:
|
||||
@@ -54,68 +67,69 @@ class BraveSearchTool(Tool):
|
||||
params["extra_snippets"] = 1
|
||||
if summary:
|
||||
params["summary"] = 1
|
||||
|
||||
# Set up headers
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Accept-Encoding": "gzip",
|
||||
"X-Subscription-Token": self.token
|
||||
"X-Subscription-Token": self.token,
|
||||
}
|
||||
|
||||
# Make the request
|
||||
response = requests.get(url, params=params, headers=headers, timeout=100)
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"results": response.json(),
|
||||
"message": "Search completed successfully."
|
||||
"message": "Search completed successfully.",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"Search failed with status code: {response.status_code}."
|
||||
"message": f"Search failed with status code: {response.status_code}.",
|
||||
}
|
||||
|
||||
def _image_search(
|
||||
self,
|
||||
query,
|
||||
country="ALL",
|
||||
search_lang="en",
|
||||
count=5,
|
||||
safesearch="off",
|
||||
spellcheck=False,
|
||||
):
|
||||
"""
|
||||
Performs an image search using the Brave Search API.
|
||||
"""
|
||||
print(f"Performing Brave image search for: {query}")
|
||||
|
||||
logger.debug("Performing Brave image search for: %s", query)
|
||||
|
||||
url = f"{self.base_url}/images/search"
|
||||
|
||||
# Build query parameters
|
||||
|
||||
params = {
|
||||
"q": query,
|
||||
"country": country,
|
||||
"search_lang": search_lang,
|
||||
"count": min(count, 100), # API max is 100
|
||||
"safesearch": safesearch,
|
||||
"spellcheck": 1 if spellcheck else 0
|
||||
"spellcheck": 1 if spellcheck else 0,
|
||||
}
|
||||
|
||||
# Set up headers
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Accept-Encoding": "gzip",
|
||||
"X-Subscription-Token": self.token
|
||||
"X-Subscription-Token": self.token,
|
||||
}
|
||||
|
||||
# Make the request
|
||||
response = requests.get(url, params=params, headers=headers, timeout=100)
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"results": response.json(),
|
||||
"message": "Image search completed successfully."
|
||||
"message": "Image search completed successfully.",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"Image search failed with status code: {response.status_code}."
|
||||
"message": f"Image search failed with status code: {response.status_code}.",
|
||||
}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
@@ -130,42 +144,14 @@ class BraveSearchTool(Tool):
|
||||
"type": "string",
|
||||
"description": "The search query (max 400 characters, 50 words)",
|
||||
},
|
||||
# "country": {
|
||||
# "type": "string",
|
||||
# "description": "The 2-character country code (default: US)",
|
||||
# },
|
||||
"search_lang": {
|
||||
"type": "string",
|
||||
"description": "The search language preference (default: en)",
|
||||
},
|
||||
# "count": {
|
||||
# "type": "integer",
|
||||
# "description": "Number of results to return (max 20, default: 10)",
|
||||
# },
|
||||
# "offset": {
|
||||
# "type": "integer",
|
||||
# "description": "Pagination offset (max 9, default: 0)",
|
||||
# },
|
||||
# "safesearch": {
|
||||
# "type": "string",
|
||||
# "description": "Filter level for adult content (off, moderate, strict)",
|
||||
# },
|
||||
"freshness": {
|
||||
"type": "string",
|
||||
"description": "Time filter for results (pd: last 24h, pw: last week, pm: last month, py: last year)",
|
||||
},
|
||||
# "result_filter": {
|
||||
# "type": "string",
|
||||
# "description": "Comma-delimited list of result types to include",
|
||||
# },
|
||||
# "extra_snippets": {
|
||||
# "type": "boolean",
|
||||
# "description": "Get additional excerpts from result pages",
|
||||
# },
|
||||
# "summary": {
|
||||
# "type": "boolean",
|
||||
# "description": "Enable summary generation in search results",
|
||||
# }
|
||||
},
|
||||
"required": ["query"],
|
||||
"additionalProperties": False,
|
||||
@@ -181,37 +167,25 @@ class BraveSearchTool(Tool):
|
||||
"type": "string",
|
||||
"description": "The search query (max 400 characters, 50 words)",
|
||||
},
|
||||
# "country": {
|
||||
# "type": "string",
|
||||
# "description": "The 2-character country code (default: US)",
|
||||
# },
|
||||
# "search_lang": {
|
||||
# "type": "string",
|
||||
# "description": "The search language preference (default: en)",
|
||||
# },
|
||||
"count": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return (max 100, default: 5)",
|
||||
},
|
||||
# "safesearch": {
|
||||
# "type": "string",
|
||||
# "description": "Filter level for adult content (off, strict). Default: strict",
|
||||
# },
|
||||
# "spellcheck": {
|
||||
# "type": "boolean",
|
||||
# "description": "Whether to spellcheck provided query (default: true)",
|
||||
# }
|
||||
},
|
||||
"required": ["query"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {
    "token": {
        "type": "string",
        "label": "API Key",
        "description": "Brave Search API key for authentication",
        "required": True,
        "secret": True,
        "order": 1,
    },
}
|
||||
|
||||
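A minimal usage sketch for the Brave tool (this assumes __init__ stores config["token"] on self.token, which is implied but not shown in this hunk; the query is illustrative):

tool = BraveSearchTool({"token": "<brave-api-key>"})
result = tool._web_search("open source rag frameworks", count=5, freshness="pw")
if result["status_code"] == 200:
    print(result["results"])      # raw Brave Search JSON payload
else:
    print(result["message"])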
@@ -28,7 +28,7 @@ class CryptoPriceTool(Tool):
|
||||
returns price in USD.
|
||||
"""
|
||||
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
|
||||
response = requests.get(url, timeout=100)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if currency.upper() in data:
|
||||
|
||||
application/agents/tools/duckduckgo.py (new file, 209 lines)
@@ -0,0 +1,209 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY = 2.0
|
||||
DEFAULT_TIMEOUT = 15
|
||||
|
||||
|
||||
class DuckDuckGoSearchTool(Tool):
|
||||
"""
|
||||
DuckDuckGo Search
|
||||
A tool for performing web and image searches using DuckDuckGo.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.timeout = config.get("timeout", DEFAULT_TIMEOUT)
|
||||
|
||||
def _get_ddgs_client(self):
|
||||
from ddgs import DDGS
|
||||
|
||||
return DDGS(timeout=self.timeout)
|
||||
|
||||
def _execute_with_retry(self, operation, operation_name: str) -> Dict[str, Any]:
|
||||
last_error = None
|
||||
for attempt in range(1, MAX_RETRIES + 1):
|
||||
try:
|
||||
results = operation()
|
||||
return {
|
||||
"status_code": 200,
|
||||
"results": list(results) if results else [],
|
||||
"message": f"{operation_name} completed successfully.",
|
||||
}
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
error_str = str(e).lower()
|
||||
if "ratelimit" in error_str or "429" in error_str:
|
||||
if attempt < MAX_RETRIES:
|
||||
delay = RETRY_DELAY * attempt
|
||||
logger.warning(
|
||||
f"{operation_name} rate limited, retrying in {delay}s (attempt {attempt}/{MAX_RETRIES})"
|
||||
)
|
||||
time.sleep(delay)
|
||||
continue
|
||||
logger.error(f"{operation_name} failed: {e}")
|
||||
break
|
||||
return {
|
||||
"status_code": 500,
|
||||
"results": [],
|
||||
"message": f"{operation_name} failed: {str(last_error)}",
|
||||
}
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
actions = {
|
||||
"ddg_web_search": self._web_search,
|
||||
"ddg_image_search": self._image_search,
|
||||
"ddg_news_search": self._news_search,
|
||||
}
|
||||
if action_name not in actions:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
return actions[action_name](**kwargs)
|
||||
|
||||
def _web_search(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 5,
|
||||
region: str = "wt-wt",
|
||||
safesearch: str = "moderate",
|
||||
timelimit: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
logger.info(f"DuckDuckGo web search: {query}")
|
||||
|
||||
def operation():
|
||||
client = self._get_ddgs_client()
|
||||
return client.text(
|
||||
query,
|
||||
region=region,
|
||||
safesearch=safesearch,
|
||||
timelimit=timelimit,
|
||||
max_results=min(max_results, 20),
|
||||
)
|
||||
|
||||
return self._execute_with_retry(operation, "Web search")
|
||||
|
||||
def _image_search(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 5,
|
||||
region: str = "wt-wt",
|
||||
safesearch: str = "moderate",
|
||||
timelimit: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
logger.info(f"DuckDuckGo image search: {query}")
|
||||
|
||||
def operation():
|
||||
client = self._get_ddgs_client()
|
||||
return client.images(
|
||||
query,
|
||||
region=region,
|
||||
safesearch=safesearch,
|
||||
timelimit=timelimit,
|
||||
max_results=min(max_results, 50),
|
||||
)
|
||||
|
||||
return self._execute_with_retry(operation, "Image search")
|
||||
|
||||
def _news_search(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 5,
|
||||
region: str = "wt-wt",
|
||||
safesearch: str = "moderate",
|
||||
timelimit: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
logger.info(f"DuckDuckGo news search: {query}")
|
||||
|
||||
def operation():
|
||||
client = self._get_ddgs_client()
|
||||
return client.news(
|
||||
query,
|
||||
region=region,
|
||||
safesearch=safesearch,
|
||||
timelimit=timelimit,
|
||||
max_results=min(max_results, 20),
|
||||
)
|
||||
|
||||
return self._execute_with_retry(operation, "News search")
|
||||
|
||||
def get_actions_metadata(self):
|
||||
return [
|
||||
{
|
||||
"name": "ddg_web_search",
|
||||
"description": "Search the web using DuckDuckGo. Returns titles, URLs, and snippets.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results (default: 5, max: 20)",
|
||||
},
|
||||
"region": {
|
||||
"type": "string",
|
||||
"description": "Region code (default: wt-wt for worldwide, us-en for US)",
|
||||
},
|
||||
"timelimit": {
|
||||
"type": "string",
|
||||
"description": "Time filter: d (day), w (week), m (month), y (year)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "ddg_image_search",
|
||||
"description": "Search for images using DuckDuckGo. Returns image URLs and metadata.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Image search query",
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results (default: 5, max: 50)",
|
||||
},
|
||||
"region": {
|
||||
"type": "string",
|
||||
"description": "Region code (default: wt-wt for worldwide)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "ddg_news_search",
|
||||
"description": "Search for news articles using DuckDuckGo. Returns recent news.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "News search query",
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results (default: 5, max: 20)",
|
||||
},
|
||||
"timelimit": {
|
||||
"type": "string",
|
||||
"description": "Time filter: d (day), w (week), m (month)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {}
|
||||
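A brief usage sketch for the DuckDuckGo tool defined above (the query and limits are illustrative; the ddgs package must be installed):

tool = DuckDuckGoSearchTool({"timeout": 10})
result = tool.execute_action(
    "ddg_news_search", query="open source LLM agents", max_results=3, timelimit="w"
)
if result["status_code"] == 200:
    for item in result["results"]:
        print(item)
else:
    print(result["message"])      # includes the last error after retries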
application/agents/tools/internal_search.py (new file, 461 lines)
@@ -0,0 +1,461 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
from application.core.settings import settings
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InternalSearchTool(Tool):
|
||||
"""Wraps the ClassicRAG retriever as an LLM-callable tool.
|
||||
|
||||
Instead of pre-fetching docs into the prompt, the LLM decides
|
||||
when and what to search. Supports multiple searches per session.
|
||||
|
||||
Optional capabilities (enabled when sources have directory_structure):
|
||||
- path_filter on search: restrict results to a specific file/folder
|
||||
- list_files action: browse the file/folder structure
|
||||
"""
|
||||
|
||||
internal = True
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
self.config = config
|
||||
self.retrieved_docs: List[Dict] = []
|
||||
self._retriever = None
|
||||
self._directory_structure: Optional[Dict] = None
|
||||
self._dir_structure_loaded = False
|
||||
|
||||
def _get_retriever(self):
|
||||
if self._retriever is None:
|
||||
self._retriever = RetrieverCreator.create_retriever(
|
||||
self.config.get("retriever_name", "classic"),
|
||||
source=self.config.get("source", {}),
|
||||
chat_history=[],
|
||||
prompt="",
|
||||
chunks=int(self.config.get("chunks", 2)),
|
||||
doc_token_limit=int(self.config.get("doc_token_limit", 50000)),
|
||||
model_id=self.config.get("model_id", "docsgpt-local"),
|
||||
model_user_id=self.config.get("model_user_id"),
|
||||
user_api_key=self.config.get("user_api_key"),
|
||||
agent_id=self.config.get("agent_id"),
|
||||
llm_name=self.config.get("llm_name", settings.LLM_PROVIDER),
|
||||
api_key=self.config.get("api_key", settings.API_KEY),
|
||||
decoded_token=self.config.get("decoded_token"),
|
||||
)
|
||||
return self._retriever
|
||||
|
||||
def _get_directory_structure(self) -> Optional[Dict]:
|
||||
"""Load directory structure from Postgres for the configured sources."""
|
||||
if self._dir_structure_loaded:
|
||||
return self._directory_structure
|
||||
|
||||
self._dir_structure_loaded = True
|
||||
source = self.config.get("source", {})
|
||||
active_docs = source.get("active_docs", [])
|
||||
if not active_docs:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Per-operation session: this tool runs inside the answer
|
||||
# generator hot path, so we open a short-lived read
|
||||
# connection for the batch lookup and release immediately.
|
||||
from application.storage.db.repositories.sources import (
|
||||
SourcesRepository,
|
||||
)
|
||||
from application.storage.db.session import db_readonly
|
||||
|
||||
if isinstance(active_docs, str):
|
||||
active_docs = [active_docs]
|
||||
|
||||
decoded_token = self.config.get("decoded_token") or {}
|
||||
user_id = decoded_token.get("sub") if decoded_token else None
|
||||
|
||||
merged_structure = {}
|
||||
with db_readonly() as conn:
|
||||
repo = SourcesRepository(conn)
|
||||
for doc_id in active_docs:
|
||||
try:
|
||||
source_doc = repo.get_any(str(doc_id), user_id) if user_id else None
|
||||
if not source_doc:
|
||||
continue
|
||||
dir_str = source_doc.get("directory_structure")
|
||||
if dir_str:
|
||||
if isinstance(dir_str, str):
|
||||
dir_str = json.loads(dir_str)
|
||||
source_name = source_doc.get("name", doc_id)
|
||||
if len(active_docs) > 1:
|
||||
merged_structure[source_name] = dir_str
|
||||
else:
|
||||
merged_structure = dir_str
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not load dir structure for {doc_id}: {e}")
|
||||
|
||||
self._directory_structure = merged_structure if merged_structure else None
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load directory structures: {e}")
|
||||
|
||||
return self._directory_structure
|
||||
|
||||
def execute_action(self, action_name: str, **kwargs):
|
||||
if action_name == "search":
|
||||
return self._execute_search(**kwargs)
|
||||
elif action_name == "list_files":
|
||||
return self._execute_list_files(**kwargs)
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def _execute_search(self, **kwargs) -> str:
|
||||
query = kwargs.get("query", "")
|
||||
path_filter = kwargs.get("path_filter", "")
|
||||
|
||||
if not query:
|
||||
return "Error: 'query' parameter is required."
|
||||
|
||||
try:
|
||||
retriever = self._get_retriever()
|
||||
docs = retriever.search(query)
|
||||
except Exception as e:
|
||||
logger.error(f"Internal search failed: {e}", exc_info=True)
|
||||
return "Search failed: an internal error occurred."
|
||||
|
||||
if not docs:
|
||||
return "No documents found matching your query."
|
||||
|
||||
# Apply path filter if specified
|
||||
if path_filter:
|
||||
path_lower = path_filter.lower()
|
||||
docs = [
|
||||
d
|
||||
for d in docs
|
||||
if path_lower in d.get("source", "").lower()
|
||||
or path_lower in d.get("filename", "").lower()
|
||||
or path_lower in d.get("title", "").lower()
|
||||
]
|
||||
if not docs:
|
||||
return f"No documents found matching query '{query}' in path '{path_filter}'."
|
||||
|
||||
# Accumulate for source tracking
|
||||
for doc in docs:
|
||||
if doc not in self.retrieved_docs:
|
||||
self.retrieved_docs.append(doc)
|
||||
|
||||
# Format results for the LLM
|
||||
formatted = []
|
||||
for i, doc in enumerate(docs, 1):
|
||||
title = doc.get("title", "Untitled")
|
||||
text = doc.get("text", "")
|
||||
source = doc.get("source", "Unknown")
|
||||
filename = doc.get("filename", "")
|
||||
header = filename or title
|
||||
formatted.append(f"[{i}] {header} (source: {source})\n{text}")
|
||||
|
||||
return "\n\n---\n\n".join(formatted)
|
||||
|
||||
def _execute_list_files(self, **kwargs) -> str:
|
||||
path = kwargs.get("path", "")
|
||||
dir_structure = self._get_directory_structure()
|
||||
|
||||
if not dir_structure:
|
||||
return "No file structure available for the current sources."
|
||||
|
||||
# Navigate to the requested path
|
||||
current = dir_structure
|
||||
if path:
|
||||
for part in path.strip("/").split("/"):
|
||||
if not part:
|
||||
continue
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return f"Path '{path}' not found in the file structure."
|
||||
|
||||
# Format the structure for the LLM
|
||||
return self._format_structure(current, path or "/")
|
||||
|
||||
def _format_structure(self, node: Dict, current_path: str) -> str:
|
||||
if not isinstance(node, dict):
|
||||
return f"'{current_path}' is a file, not a directory."
|
||||
|
||||
lines = [f"File structure at '{current_path}':\n"]
|
||||
folders = []
|
||||
files = []
|
||||
|
||||
for name, value in sorted(node.items()):
|
||||
if isinstance(value, dict):
|
||||
# Check if it's a file metadata dict or a folder
|
||||
if "type" in value or "size_bytes" in value or "token_count" in value:
|
||||
# It's a file with metadata
|
||||
size = value.get("token_count", "")
|
||||
ftype = value.get("type", "")
|
||||
info_parts = []
|
||||
if ftype:
|
||||
info_parts.append(ftype)
|
||||
if size:
|
||||
info_parts.append(f"{size} tokens")
|
||||
info = f" ({', '.join(info_parts)})" if info_parts else ""
|
||||
files.append(f" {name}{info}")
|
||||
else:
|
||||
# It's a folder
|
||||
count = self._count_files(value)
|
||||
folders.append(f" {name}/ ({count} items)")
|
||||
else:
|
||||
files.append(f" {name}")
|
||||
|
||||
if folders:
|
||||
lines.append("Folders:")
|
||||
lines.extend(folders)
|
||||
if files:
|
||||
lines.append("Files:")
|
||||
lines.extend(files)
|
||||
if not folders and not files:
|
||||
lines.append(" (empty)")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _count_files(self, node: Dict) -> int:
|
||||
count = 0
|
||||
for value in node.values():
|
||||
if isinstance(value, dict):
|
||||
if "type" in value or "size_bytes" in value or "token_count" in value:
|
||||
count += 1
|
||||
else:
|
||||
count += self._count_files(value)
|
||||
else:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def get_actions_metadata(self):
|
||||
actions = [
|
||||
{
|
||||
"name": "search",
|
||||
"description": (
|
||||
"Search the user's uploaded documents and knowledge base. "
|
||||
"Use this to find relevant information before answering questions. "
|
||||
"You can call this multiple times with different queries."
|
||||
),
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query. Be specific and focused.",
|
||||
"filled_by_llm": True,
|
||||
"required": True,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Add path_filter and list_files only if directory structure exists
|
||||
has_structure = self.config.get("has_directory_structure", False)
|
||||
if has_structure:
|
||||
actions[0]["parameters"]["properties"]["path_filter"] = {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional: filter results to a specific file or folder path. "
|
||||
"Use list_files first to see available paths."
|
||||
),
|
||||
"filled_by_llm": True,
|
||||
"required": False,
|
||||
}
|
||||
actions.append(
|
||||
{
|
||||
"name": "list_files",
|
||||
"description": (
|
||||
"Browse the file and folder structure of the knowledge base. "
|
||||
"Use this to see what files are available before searching. "
|
||||
"Optionally provide a path to browse a specific folder."
|
||||
),
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Optional: folder path to browse. Leave empty for root.",
|
||||
"filled_by_llm": True,
|
||||
"required": False,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {}
|
||||
|
||||
|
||||
# Constants for building synthetic tools_dict entries
|
||||
INTERNAL_TOOL_ID = "internal"
|
||||
|
||||
|
||||
def build_internal_tool_entry(has_directory_structure: bool = False) -> Dict:
|
||||
"""Build the tools_dict entry for InternalSearchTool.
|
||||
|
||||
Dynamically includes list_files and path_filter based on
|
||||
whether the sources have directory structure.
|
||||
"""
|
||||
search_params = {
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query. Be specific and focused.",
|
||||
"filled_by_llm": True,
|
||||
"required": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
actions = [
|
||||
{
|
||||
"name": "search",
|
||||
"description": (
|
||||
"Search the user's uploaded documents and knowledge base. "
|
||||
"Use this to find relevant information before answering questions. "
|
||||
"You can call this multiple times with different queries."
|
||||
),
|
||||
"active": True,
|
||||
"parameters": search_params,
|
||||
}
|
||||
]
|
||||
|
||||
if has_directory_structure:
|
||||
search_params["properties"]["path_filter"] = {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional: filter results to a specific file or folder path. "
|
||||
"Use list_files first to see available paths."
|
||||
),
|
||||
"filled_by_llm": True,
|
||||
"required": False,
|
||||
}
|
||||
actions.append(
|
||||
{
|
||||
"name": "list_files",
|
||||
"description": (
|
||||
"Browse the file and folder structure of the knowledge base. "
|
||||
"Use this to see what files are available before searching. "
|
||||
"Optionally provide a path to browse a specific folder."
|
||||
),
|
||||
"active": True,
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Optional: folder path to browse. Leave empty for root.",
|
||||
"filled_by_llm": True,
|
||||
"required": False,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return {"name": "internal_search", "actions": actions}
|
||||
|
||||
|
||||
# Keep backward compat
|
||||
INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False)
|
||||
|
||||
|
||||
def sources_have_directory_structure(source: Dict) -> bool:
|
||||
"""Check if any of the active sources have a ``directory_structure`` row."""
|
||||
active_docs = source.get("active_docs", [])
|
||||
if not active_docs:
|
||||
return False
|
||||
|
||||
try:
|
||||
# TODO(pg-cutover): SourcesRepository.get_any requires ``user_id``
|
||||
# scoping, but callers in the agent build path don't always
|
||||
# thread the decoded token through here. Use a direct
|
||||
# short-lived SQL lookup instead of the repo until the call
|
||||
# sites are updated to propagate user context.
|
||||
from sqlalchemy import text as _text
|
||||
|
||||
from application.storage.db.session import db_readonly
|
||||
|
||||
if isinstance(active_docs, str):
|
||||
active_docs = [active_docs]
|
||||
|
||||
with db_readonly() as conn:
|
||||
for doc_id in active_docs:
|
||||
try:
|
||||
value = str(doc_id)
|
||||
if len(value) == 36 and "-" in value:
|
||||
row = conn.execute(
|
||||
_text(
|
||||
"SELECT directory_structure FROM sources "
|
||||
"WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
{"id": value},
|
||||
).fetchone()
|
||||
else:
|
||||
row = conn.execute(
|
||||
_text(
|
||||
"SELECT directory_structure FROM sources "
|
||||
"WHERE legacy_mongo_id = :lid"
|
||||
),
|
||||
{"lid": value},
|
||||
).fetchone()
|
||||
if row is not None and row[0]:
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not check directory structure: {e}")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def add_internal_search_tool(tools_dict: Dict, retriever_config: Dict) -> None:
|
||||
"""Add the internal search tool to tools_dict if sources are configured.
|
||||
|
||||
Shared by AgenticAgent and ResearchAgent to avoid duplicate setup logic.
|
||||
Mutates tools_dict in place.
|
||||
"""
|
||||
source = retriever_config.get("source", {})
|
||||
has_sources = bool(source.get("active_docs"))
|
||||
if not retriever_config or not has_sources:
|
||||
return
|
||||
|
||||
has_dir = sources_have_directory_structure(source)
|
||||
internal_entry = build_internal_tool_entry(has_directory_structure=has_dir)
|
||||
internal_entry["config"] = build_internal_tool_config(
|
||||
**retriever_config,
|
||||
has_directory_structure=has_dir,
|
||||
)
|
||||
tools_dict[INTERNAL_TOOL_ID] = internal_entry
|
||||
|
||||
|
||||
def build_internal_tool_config(
|
||||
source: Dict,
|
||||
retriever_name: str = "classic",
|
||||
chunks: int = 2,
|
||||
doc_token_limit: int = 50000,
|
||||
model_id: str = "docsgpt-local",
|
||||
model_user_id: Optional[str] = None,
|
||||
user_api_key: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
llm_name: str = None,
|
||||
api_key: str = None,
|
||||
decoded_token: Optional[Dict] = None,
|
||||
has_directory_structure: bool = False,
|
||||
) -> Dict:
|
||||
"""Build the config dict for InternalSearchTool."""
|
||||
return {
|
||||
"source": source,
|
||||
"retriever_name": retriever_name,
|
||||
"chunks": chunks,
|
||||
"doc_token_limit": doc_token_limit,
|
||||
"model_id": model_id,
|
||||
"model_user_id": model_user_id,
|
||||
"user_api_key": user_api_key,
|
||||
"agent_id": agent_id,
|
||||
"llm_name": llm_name or settings.LLM_PROVIDER,
|
||||
"api_key": api_key or settings.API_KEY,
|
||||
"decoded_token": decoded_token,
|
||||
"has_directory_structure": has_directory_structure,
|
||||
}
|
||||
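A short sketch of how an agent build path can wire in the internal search tool via the helpers above (source ids and limits are hypothetical):

tools_dict = {}
retriever_config = {
    "source": {"active_docs": ["<source-uuid>"]},
    "chunks": 2,
    "doc_token_limit": 50000,
    "decoded_token": {"sub": "user-123"},
}
add_internal_search_tool(tools_dict, retriever_config)
# tools_dict now contains an "internal" entry whose actions include "search" and,
# when the sources expose a directory structure, "list_files".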
application/agents/tools/mcp_tool.py (new file, 1045 lines; diff suppressed because it is too large)
application/agents/tools/memory.py (new file, 548 lines)
@@ -0,0 +1,548 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.storage.db.repositories.memories import MemoriesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryTool(Tool):
|
||||
"""Memory
|
||||
|
||||
Stores and retrieves information across conversations through a memory file directory.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration. Should include:
|
||||
- tool_id: Unique identifier for this memory tool instance (from user_tools._id)
|
||||
This ensures each user's tool configuration has isolated memories
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the UUID string from user_tools.id.
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
# Fallback for backward compatibility or testing
|
||||
self.tool_id = f"default_{user_id}"
|
||||
else:
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
def _pg_enabled(self) -> bool:
|
||||
"""Return True if this MemoryTool's tool_id is a real ``user_tools.id``.
|
||||
|
||||
The ``memories`` PG table has a UUID foreign key to ``user_tools``.
|
||||
The sentinel ``default_{uid}`` fallback tool_id is not a UUID and
|
||||
has no row in ``user_tools``, so any storage operation would fail
|
||||
the foreign-key check. After the Postgres cutover Postgres is the
|
||||
only store, so for the sentinel case there is nowhere to read or
|
||||
write — operations become no-ops and the tool returns an
|
||||
explanatory error to the caller.
|
||||
"""
|
||||
tool_id = getattr(self, "tool_id", None)
|
||||
if not tool_id or not isinstance(tool_id, str):
|
||||
return False
|
||||
if tool_id.startswith("default_"):
|
||||
logger.debug(
|
||||
"Skipping Postgres operation for MemoryTool with sentinel tool_id=%s",
|
||||
tool_id,
|
||||
)
|
||||
return False
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
|
||||
if not looks_like_uuid(tool_id):
|
||||
logger.debug(
|
||||
"Skipping Postgres operation for MemoryTool with non-UUID tool_id=%s",
|
||||
tool_id,
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of view, create, str_replace, insert, delete, rename.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: MemoryTool requires a valid user_id."
|
||||
|
||||
if not self._pg_enabled():
|
||||
return (
|
||||
"Error: MemoryTool is not configured with a persistent tool_id; "
|
||||
"memory storage is unavailable for this session."
|
||||
)
|
||||
|
||||
if action_name == "view":
|
||||
return self._view(
|
||||
kwargs.get("path", "/"),
|
||||
kwargs.get("view_range")
|
||||
)
|
||||
|
||||
if action_name == "create":
|
||||
return self._create(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("file_text", "")
|
||||
)
|
||||
|
||||
if action_name == "str_replace":
|
||||
return self._str_replace(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("old_str", ""),
|
||||
kwargs.get("new_str", "")
|
||||
)
|
||||
|
||||
if action_name == "insert":
|
||||
return self._insert(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("insert_line", 1),
|
||||
kwargs.get("insert_text", "")
|
||||
)
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete(kwargs.get("path", ""))
|
||||
|
||||
if action_name == "rename":
|
||||
return self._rename(
|
||||
kwargs.get("old_path", ""),
|
||||
kwargs.get("new_path", "")
|
||||
)
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "view",
|
||||
"description": "Shows directory contents or file contents with optional line ranges.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to file or directory (e.g., /notes.txt or /project/ or /)."
|
||||
},
|
||||
"view_range": {
|
||||
"type": "array",
|
||||
"items": {"type": "integer"},
|
||||
"description": "Optional [start_line, end_line] to view specific lines (1-indexed)."
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "create",
|
||||
"description": "Create or overwrite a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path to create (e.g., /notes.txt or /project/task.txt)."
|
||||
},
|
||||
"file_text": {
|
||||
"type": "string",
|
||||
"description": "Content to write to the file."
|
||||
}
|
||||
},
|
||||
"required": ["path", "file_text"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "str_replace",
|
||||
"description": "Replace text in a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path (e.g., /notes.txt)."
|
||||
},
|
||||
"old_str": {
|
||||
"type": "string",
|
||||
"description": "String to find."
|
||||
},
|
||||
"new_str": {
|
||||
"type": "string",
|
||||
"description": "String to replace with."
|
||||
}
|
||||
},
|
||||
"required": ["path", "old_str", "new_str"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "insert",
|
||||
"description": "Insert text at a specific line in a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path (e.g., /notes.txt)."
|
||||
},
|
||||
"insert_line": {
|
||||
"type": "integer",
|
||||
"description": "Line number to insert at (1-indexed)."
|
||||
},
|
||||
"insert_text": {
|
||||
"type": "string",
|
||||
"description": "Text to insert."
|
||||
}
|
||||
},
|
||||
"required": ["path", "insert_line", "insert_text"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete a file or directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to delete (e.g., /notes.txt or /project/)."
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "rename",
|
||||
"description": "Rename or move a file/directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"old_path": {
|
||||
"type": "string",
|
||||
"description": "Current path (e.g., /old.txt)."
|
||||
},
|
||||
"new_path": {
|
||||
"type": "string",
|
||||
"description": "New path (e.g., /new.txt)."
|
||||
}
|
||||
},
|
||||
"required": ["old_path", "new_path"]
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements."""
|
||||
return {}
|
||||
|
||||
# -----------------------------
|
||||
# Path validation
|
||||
# -----------------------------
|
||||
def _validate_path(self, path: str) -> Optional[str]:
|
||||
"""Validate and normalize path.
|
||||
|
||||
Args:
|
||||
path: User-provided path.
|
||||
|
||||
Returns:
|
||||
Normalized path or None if invalid.
|
||||
"""
|
||||
if not path:
|
||||
return None
|
||||
|
||||
# Remove any leading/trailing whitespace
|
||||
path = path.strip()
|
||||
|
||||
# Preserve whether path ends with / (indicates directory)
|
||||
is_directory = path.endswith("/")
|
||||
|
||||
# Ensure path starts with / for consistency
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
# Check for directory traversal patterns
|
||||
if ".." in path or path.count("//") > 0:
|
||||
return None
|
||||
|
||||
# Normalize the path
|
||||
try:
|
||||
# Convert to Path object and resolve to canonical form
|
||||
normalized = str(Path(path).as_posix())
|
||||
|
||||
# Ensure it still starts with /
|
||||
if not normalized.startswith("/"):
|
||||
return None
|
||||
|
||||
# Preserve trailing slash for directories
|
||||
if is_directory and not normalized.endswith("/") and normalized != "/":
|
||||
normalized = normalized + "/"
|
||||
|
||||
return normalized
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------
|
||||
def _view(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||
"""View directory contents or file contents."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
# Check if viewing directory (ends with / or is root)
|
||||
if validated_path == "/" or validated_path.endswith("/"):
|
||||
return self._view_directory(validated_path)
|
||||
|
||||
# Otherwise view file
|
||||
return self._view_file(validated_path, view_range)
|
||||
|
||||
def _view_directory(self, path: str) -> str:
|
||||
"""List files in a directory."""
|
||||
# Ensure path ends with / for proper prefix matching
|
||||
search_path = path if path.endswith("/") else path + "/"
|
||||
|
||||
with db_readonly() as conn:
|
||||
docs = MemoriesRepository(conn).list_by_prefix(
|
||||
self.user_id, self.tool_id, search_path
|
||||
)
|
||||
|
||||
if not docs:
|
||||
return f"Directory: {path}\n(empty)"
|
||||
|
||||
# Extract filenames relative to the directory
|
||||
files = []
|
||||
for doc in docs:
|
||||
file_path = doc["path"]
|
||||
# Remove the directory prefix
|
||||
if file_path.startswith(search_path):
|
||||
relative = file_path[len(search_path):]
|
||||
if relative:
|
||||
files.append(relative)
|
||||
|
||||
files.sort()
|
||||
file_list = "\n".join(f"- {f}" for f in files)
|
||||
return f"Directory: {path}\n{file_list}"
|
||||
|
||||
def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||
"""View file contents with optional line range."""
|
||||
with db_readonly() as conn:
|
||||
doc = MemoriesRepository(conn).get_by_path(
|
||||
self.user_id, self.tool_id, path
|
||||
)
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {path}"
|
||||
|
||||
content = str(doc["content"])
|
||||
|
||||
# Apply view_range if specified
|
||||
if view_range and len(view_range) == 2:
|
||||
lines = content.split("\n")
|
||||
start, end = view_range
|
||||
# Convert to 0-indexed
|
||||
start_idx = max(0, start - 1)
|
||||
end_idx = min(len(lines), end)
|
||||
|
||||
if start_idx >= len(lines):
|
||||
return f"Error: Line range out of bounds. File has {len(lines)} lines."
|
||||
|
||||
selected_lines = lines[start_idx:end_idx]
|
||||
# Add line numbers (enumerate with 1-based start)
|
||||
numbered_lines = [f"{i}: {line}" for i, line in enumerate(selected_lines, start=start)]
|
||||
return "\n".join(numbered_lines)
|
||||
|
||||
return content
|
||||
|
||||
def _create(self, path: str, file_text: str) -> str:
|
||||
"""Create or overwrite a file."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_path == "/" or validated_path.endswith("/"):
|
||||
return "Error: Cannot create a file at directory path."
|
||||
|
||||
with db_session() as conn:
|
||||
MemoriesRepository(conn).upsert(
|
||||
self.user_id, self.tool_id, validated_path, file_text
|
||||
)
|
||||
|
||||
return f"File created: {validated_path}"
|
||||
|
||||
def _str_replace(self, path: str, old_str: str, new_str: str) -> str:
|
||||
"""Replace text in a file."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if not old_str:
|
||||
return "Error: old_str is required."
|
||||
|
||||
with db_session() as conn:
|
||||
repo = MemoriesRepository(conn)
|
||||
doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
current_content = str(doc["content"])
|
||||
|
||||
# Check if old_str exists (case-insensitive)
|
||||
if old_str.lower() not in current_content.lower():
|
||||
return f"Error: String '{old_str}' not found in file."
|
||||
|
||||
# Case-insensitive replace
|
||||
import re as regex_module
|
||||
updated_content = regex_module.sub(
|
||||
regex_module.escape(old_str),
|
||||
new_str,
|
||||
current_content,
|
||||
flags=regex_module.IGNORECASE,
|
||||
)
|
||||
|
||||
repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)
|
||||
|
||||
return f"File updated: {validated_path}"
|
||||
|
||||
def _insert(self, path: str, insert_line: int, insert_text: str) -> str:
|
||||
"""Insert text at a specific line."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if not insert_text:
|
||||
return "Error: insert_text is required."
|
||||
|
||||
with db_session() as conn:
|
||||
repo = MemoriesRepository(conn)
|
||||
doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
current_content = str(doc["content"])
|
||||
lines = current_content.split("\n")
|
||||
|
||||
# Convert to 0-indexed
|
||||
index = insert_line - 1
|
||||
if index < 0 or index > len(lines):
|
||||
return f"Error: Invalid line number. File has {len(lines)} lines."
|
||||
|
||||
lines.insert(index, insert_text)
|
||||
updated_content = "\n".join(lines)
|
||||
|
||||
repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)
|
||||
|
||||
return f"Text inserted at line {insert_line} in {validated_path}"
|
||||
|
||||
def _delete(self, path: str) -> str:
|
||||
"""Delete a file or directory."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_path == "/":
|
||||
# Delete all files for this user and tool
|
||||
with db_session() as conn:
|
||||
deleted = MemoriesRepository(conn).delete_all(
|
||||
self.user_id, self.tool_id
|
||||
)
|
||||
return f"Deleted {deleted} file(s) from memory."
|
||||
|
||||
# Check if it's a directory (ends with /)
|
||||
if validated_path.endswith("/"):
|
||||
with db_session() as conn:
|
||||
deleted = MemoriesRepository(conn).delete_by_prefix(
|
||||
self.user_id, self.tool_id, validated_path
|
||||
)
|
||||
return f"Deleted directory and {deleted} file(s)."
|
||||
|
||||
# Try as directory first (without trailing slash)
|
||||
search_path = validated_path + "/"
|
||||
with db_session() as conn:
|
||||
repo = MemoriesRepository(conn)
|
||||
directory_deleted = repo.delete_by_prefix(
|
||||
self.user_id, self.tool_id, search_path
|
||||
)
|
||||
if directory_deleted > 0:
|
||||
return f"Deleted directory and {directory_deleted} file(s)."
|
||||
|
||||
# Otherwise delete a single file
|
||||
file_deleted = repo.delete_by_path(
|
||||
self.user_id, self.tool_id, validated_path
|
||||
)
|
||||
|
||||
if file_deleted:
|
||||
return f"Deleted: {validated_path}"
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
def _rename(self, old_path: str, new_path: str) -> str:
|
||||
"""Rename or move a file/directory."""
|
||||
validated_old = self._validate_path(old_path)
|
||||
validated_new = self._validate_path(new_path)
|
||||
|
||||
if not validated_old or not validated_new:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_old == "/" or validated_new == "/":
|
||||
return "Error: Cannot rename root directory."
|
||||
|
||||
# Directory rename: do all path updates inside one transaction so
|
||||
# the rename is atomic from the caller's perspective.
|
||||
if validated_old.endswith("/"):
|
||||
# Ensure validated_new also ends with / for proper path replacement
|
||||
if not validated_new.endswith("/"):
|
||||
validated_new = validated_new + "/"
|
||||
|
||||
with db_session() as conn:
|
||||
repo = MemoriesRepository(conn)
|
||||
docs = repo.list_by_prefix(
|
||||
self.user_id, self.tool_id, validated_old
|
||||
)
|
||||
|
||||
if not docs:
|
||||
return f"Error: Directory not found: {validated_old}"
|
||||
|
||||
for doc in docs:
|
||||
old_file_path = doc["path"]
|
||||
new_file_path = old_file_path.replace(
|
||||
validated_old, validated_new, 1
|
||||
)
|
||||
repo.update_path(
|
||||
self.user_id, self.tool_id, old_file_path, new_file_path
|
||||
)
|
||||
|
||||
return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"
|
||||
|
||||
# Single-file rename: lookup, collision check, and update in one txn.
|
||||
with db_session() as conn:
|
||||
repo = MemoriesRepository(conn)
|
||||
doc = repo.get_by_path(self.user_id, self.tool_id, validated_old)
|
||||
if not doc:
|
||||
return f"Error: File not found: {validated_old}"
|
||||
|
||||
existing = repo.get_by_path(self.user_id, self.tool_id, validated_new)
|
||||
if existing:
|
||||
return f"Error: File already exists at {validated_new}"
|
||||
|
||||
repo.update_path(
|
||||
self.user_id, self.tool_id, validated_old, validated_new
|
||||
)
|
||||
|
||||
return f"Renamed: {validated_old} -> {validated_new}"
|
||||
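A minimal usage sketch for the memory actions above. It assumes the class is exported as MemoryTool from application/agents/tools/memory.py and that its constructor takes (tool_config, user_id), which matches how ToolManager.load_tool instantiates it later in this diff; the tool_id UUID and user id are placeholders and must map to a real user_tools row for the Postgres-backed repository calls to succeed.

# Hypothetical MemoryTool round trip; identifiers below are placeholders.
from application.agents.tools.memory import MemoryTool

memory = MemoryTool(
    tool_config={"tool_id": "2f1c9a4e-6b7d-4e0a-9c3b-1d2e3f4a5b6c"},  # real user_tools UUID
    user_id="user-123",
)

print(memory.execute_action("create", path="/project/notes.txt", file_text="alpha\nbeta"))
print(memory.execute_action("view", path="/project/"))                      # directory listing
print(memory.execute_action("view", path="/project/notes.txt", view_range=[1, 2]))
print(memory.execute_action("str_replace", path="/project/notes.txt",
                            old_str="alpha", new_str="gamma"))
print(memory.execute_action("rename", old_path="/project/notes.txt",
                            new_path="/project/renamed.txt"))
print(memory.execute_action("delete", path="/project/"))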
application/agents/tools/notes.py (new file, 258 lines)
@@ -0,0 +1,258 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.storage.db.repositories.notes import NotesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
# Stable synthetic title used in the Postgres ``notes.title`` column.
|
||||
# The notes tool stores one note per (user_id, tool_id); there is no
|
||||
# user-facing title. PG requires ``title`` NOT NULL, so we write a stable
|
||||
# constant alongside the actual note body in ``content``.
|
||||
_NOTE_TITLE = "note"
|
||||
|
||||
|
||||
class NotesTool(Tool):
|
||||
"""Notepad
|
||||
|
||||
Single note per (user, tool). Supports viewing, overwriting, string replacement, line insertion, and deletion.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration. Should include:
|
||||
- tool_id: Unique identifier for this notes tool instance (from user_tools._id)
|
||||
This ensures each user's tool configuration has isolated notes
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
# Fallback for backward compatibility or testing
|
||||
self.tool_id = f"default_{user_id}"
|
||||
else:
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
self._last_artifact_id: Optional[str] = None
|
||||
|
||||
def _pg_enabled(self) -> bool:
|
||||
"""Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.
|
||||
|
||||
``notes.tool_id`` is a UUID FK to ``user_tools``; repo queries
|
||||
``CAST(:tool_id AS uuid)``. The sentinel ``default_{uid}``
|
||||
fallback is neither a UUID nor a ``user_tools`` row, so any DB
|
||||
operation would crash. Mirror MemoryTool's guard and no-op.
|
||||
"""
|
||||
tool_id = getattr(self, "tool_id", None)
|
||||
if not tool_id or not isinstance(tool_id, str):
|
||||
return False
|
||||
if tool_id.startswith("default_"):
|
||||
return False
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
|
||||
return looks_like_uuid(tool_id)
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of view, overwrite, str_replace, insert, delete.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: NotesTool requires a valid user_id."
|
||||
|
||||
if not self._pg_enabled():
|
||||
return (
|
||||
"Error: NotesTool is not configured with a persistent "
|
||||
"tool_id; note storage is unavailable for this session."
|
||||
)
|
||||
|
||||
self._last_artifact_id = None
|
||||
|
||||
if action_name == "view":
|
||||
return self._get_note()
|
||||
|
||||
if action_name == "overwrite":
|
||||
return self._overwrite_note(kwargs.get("text", ""))
|
||||
|
||||
if action_name == "str_replace":
|
||||
return self._str_replace(kwargs.get("old_str", ""), kwargs.get("new_str", ""))
|
||||
|
||||
if action_name == "insert":
|
||||
return self._insert(kwargs.get("line_number", 1), kwargs.get("text", ""))
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete_note()
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "view",
|
||||
"description": "Retrieve the user's note.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "overwrite",
|
||||
"description": "Replace the entire note content (creates if doesn't exist).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "New note content."}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "str_replace",
|
||||
"description": "Replace occurrences of old_str with new_str in the note.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"old_str": {"type": "string", "description": "String to find."},
|
||||
"new_str": {"type": "string", "description": "String to replace with."}
|
||||
},
|
||||
"required": ["old_str", "new_str"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "insert",
|
||||
"description": "Insert text at the specified line number (1-indexed).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"line_number": {"type": "integer", "description": "Line number to insert at (1-indexed)."},
|
||||
"text": {"type": "string", "description": "Text to insert."}
|
||||
},
|
||||
"required": ["line_number", "text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete the user's note.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements (none for now)."""
|
||||
return {}
|
||||
|
||||
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
|
||||
return self._last_artifact_id
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers (single-note)
|
||||
# -----------------------------
|
||||
def _fetch_note(self) -> Optional[dict]:
|
||||
"""Read the note row for this (user, tool) from Postgres."""
|
||||
with db_readonly() as conn:
|
||||
return NotesRepository(conn).get_for_user_tool(self.user_id, self.tool_id)
|
||||
|
||||
def _get_note(self) -> str:
|
||||
doc = self._fetch_note()
|
||||
# ``content`` is the PG column; expose as ``note`` to callers via the
|
||||
# textual return value. Frontends that read the artifact via the
|
||||
# repo dict get ``content`` (PG-native) plus the artifact id below.
|
||||
body = (doc or {}).get("content")
|
||||
if not doc or not body:
|
||||
return "No note found."
|
||||
if doc.get("id") is not None:
|
||||
self._last_artifact_id = str(doc.get("id"))
|
||||
return str(body)
|
||||
|
||||
def _overwrite_note(self, content: str) -> str:
|
||||
content = (content or "").strip()
|
||||
if not content:
|
||||
return "Note content required."
|
||||
with db_session() as conn:
|
||||
row = NotesRepository(conn).upsert(
|
||||
self.user_id, self.tool_id, _NOTE_TITLE, content
|
||||
)
|
||||
if row and row.get("id") is not None:
|
||||
self._last_artifact_id = str(row.get("id"))
|
||||
return "Note saved."
|
||||
|
||||
def _str_replace(self, old_str: str, new_str: str) -> str:
|
||||
if not old_str:
|
||||
return "old_str is required."
|
||||
|
||||
doc = self._fetch_note()
|
||||
existing = (doc or {}).get("content")
|
||||
if not doc or not existing:
|
||||
return "No note found."
|
||||
|
||||
current_note = str(existing)
|
||||
|
||||
# Case-insensitive search
|
||||
if old_str.lower() not in current_note.lower():
|
||||
return f"String '{old_str}' not found in note."
|
||||
|
||||
# Case-insensitive replacement
|
||||
import re
|
||||
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
|
||||
|
||||
with db_session() as conn:
|
||||
row = NotesRepository(conn).upsert(
|
||||
self.user_id, self.tool_id, _NOTE_TITLE, updated_note
|
||||
)
|
||||
if row and row.get("id") is not None:
|
||||
self._last_artifact_id = str(row.get("id"))
|
||||
return "Note updated."
|
||||
|
||||
def _insert(self, line_number: int, text: str) -> str:
|
||||
if not text:
|
||||
return "Text is required."
|
||||
|
||||
doc = self._fetch_note()
|
||||
existing = (doc or {}).get("content")
|
||||
if not doc or not existing:
|
||||
return "No note found."
|
||||
|
||||
current_note = str(existing)
|
||||
lines = current_note.split("\n")
|
||||
|
||||
# Convert to 0-indexed and validate
|
||||
index = line_number - 1
|
||||
if index < 0 or index > len(lines):
|
||||
return f"Invalid line number. Note has {len(lines)} lines."
|
||||
|
||||
lines.insert(index, text)
|
||||
updated_note = "\n".join(lines)
|
||||
|
||||
with db_session() as conn:
|
||||
row = NotesRepository(conn).upsert(
|
||||
self.user_id, self.tool_id, _NOTE_TITLE, updated_note
|
||||
)
|
||||
if row and row.get("id") is not None:
|
||||
self._last_artifact_id = str(row.get("id"))
|
||||
return "Text inserted."
|
||||
|
||||
def _delete_note(self) -> str:
|
||||
# Capture the id (for artifact tracking) before deleting.
|
||||
existing = self._fetch_note()
|
||||
if not existing:
|
||||
return "No note found to delete."
|
||||
with db_session() as conn:
|
||||
deleted = NotesRepository(conn).delete(self.user_id, self.tool_id)
|
||||
if not deleted:
|
||||
return "No note found to delete."
|
||||
if existing.get("id") is not None:
|
||||
self._last_artifact_id = str(existing.get("id"))
|
||||
return "Note deleted."
|
||||
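A sketch of the single-note flow above. NotesTool's constructor is shown in this file; the tool_id UUID and user id are placeholders, and the UUID must reference an existing user_tools row.

# Hypothetical NotesTool usage; identifiers below are placeholders.
from application.agents.tools.notes import NotesTool

notes = NotesTool(
    tool_config={"tool_id": "0b8e2c44-9f1a-4d2b-8c3d-5e6f7a8b9c0d"},
    user_id="user-123",
)

print(notes.execute_action("overwrite", text="line one\nline two"))
print(notes.execute_action("insert", line_number=2, text="inserted"))
print(notes.execute_action("str_replace", old_str="LINE ONE", new_str="first line"))  # matching is case-insensitive
print(notes.execute_action("view"))
print(notes.get_artifact_id("view"))  # id of the underlying note row, if one was touched
print(notes.execute_action("delete"))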
@@ -71,7 +71,7 @@ class NtfyTool(Tool):
|
||||
if self.token:
|
||||
headers["Authorization"] = f"Basic {self.token}"
|
||||
data = message.encode("utf-8")
|
||||
response = requests.post(url, headers=headers, data=data)
|
||||
response = requests.post(url, headers=headers, data=data, timeout=100)
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
@@ -116,12 +116,13 @@ class NtfyTool(Tool):
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
"""
|
||||
Specify the configuration requirements.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary describing required config parameters.
|
||||
"""
|
||||
return {
|
||||
"token": {"type": "string", "description": "Access token for authentication"},
|
||||
"token": {
|
||||
"type": "string",
|
||||
"label": "Access Token",
|
||||
"description": "Ntfy access token for authentication",
|
||||
"required": True,
|
||||
"secret": True,
|
||||
"order": 1,
|
||||
},
|
||||
}
|
||||
@@ -1,6 +1,12 @@
|
||||
import psycopg2
|
||||
import logging
|
||||
|
||||
import psycopg
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostgresTool(Tool):
|
||||
"""
|
||||
PostgreSQL Database Tool
|
||||
@@ -17,25 +23,25 @@ class PostgresTool(Tool):
|
||||
"postgres_execute_sql": self._execute_sql,
|
||||
"postgres_get_schema": self._get_schema,
|
||||
}
|
||||
|
||||
if action_name in actions:
|
||||
return actions[action_name](**kwargs)
|
||||
else:
|
||||
if action_name not in actions:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
return actions[action_name](**kwargs)
|
||||
|
||||
def _execute_sql(self, sql_query):
|
||||
"""
|
||||
Executes an SQL query against the PostgreSQL database using a connection string.
|
||||
"""
|
||||
conn = None # Initialize conn to None for error handling
|
||||
conn = None
|
||||
try:
|
||||
conn = psycopg2.connect(self.connection_string)
|
||||
conn = psycopg.connect(self.connection_string)
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql_query)
|
||||
conn.commit()
|
||||
|
||||
if sql_query.strip().lower().startswith("select"):
|
||||
column_names = [desc[0] for desc in cur.description] if cur.description else []
|
||||
column_names = (
|
||||
[desc[0] for desc in cur.description] if cur.description else []
|
||||
)
|
||||
results = []
|
||||
rows = cur.fetchall()
|
||||
for row in rows:
|
||||
@@ -43,7 +49,9 @@ class PostgresTool(Tool):
|
||||
response_data = {"data": results, "column_names": column_names}
|
||||
else:
|
||||
row_count = cur.rowcount
|
||||
response_data = {"message": f"Query executed successfully, {row_count} rows affected."}
|
||||
response_data = {
|
||||
"message": f"Query executed successfully, {row_count} rows affected."
|
||||
}
|
||||
|
||||
cur.close()
|
||||
return {
|
||||
@@ -52,28 +60,29 @@ class PostgresTool(Tool):
|
||||
"response_data": response_data,
|
||||
}
|
||||
|
||||
except psycopg2.Error as e:
|
||||
except psycopg.Error as e:
|
||||
error_message = f"Database error: {e}"
|
||||
print(f"Database error: {e}")
|
||||
logger.error("PostgreSQL execute_sql error: %s", e)
|
||||
return {
|
||||
"status_code": 500,
|
||||
"message": "Failed to execute SQL query.",
|
||||
"error": error_message,
|
||||
}
|
||||
finally:
|
||||
if conn: # Ensure connection is closed even if errors occur
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
def _get_schema(self, db_name):
|
||||
"""
|
||||
Retrieves the schema of the PostgreSQL database using a connection string.
|
||||
"""
|
||||
conn = None # Initialize conn to None for error handling
|
||||
conn = None
|
||||
try:
|
||||
conn = psycopg2.connect(self.connection_string)
|
||||
conn = psycopg.connect(self.connection_string)
|
||||
cur = conn.cursor()
|
||||
|
||||
cur.execute("""
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT
|
||||
table_name,
|
||||
column_name,
|
||||
@@ -87,19 +96,22 @@ class PostgresTool(Tool):
|
||||
ORDER BY
|
||||
table_name,
|
||||
ordinal_position;
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
schema_data = {}
|
||||
for row in cur.fetchall():
|
||||
table_name, column_name, data_type, column_default, is_nullable = row
|
||||
if table_name not in schema_data:
|
||||
schema_data[table_name] = []
|
||||
schema_data[table_name].append({
|
||||
"column_name": column_name,
|
||||
"data_type": data_type,
|
||||
"column_default": column_default,
|
||||
"is_nullable": is_nullable
|
||||
})
|
||||
schema_data[table_name].append(
|
||||
{
|
||||
"column_name": column_name,
|
||||
"data_type": data_type,
|
||||
"column_default": column_default,
|
||||
"is_nullable": is_nullable,
|
||||
}
|
||||
)
|
||||
|
||||
cur.close()
|
||||
return {
|
||||
@@ -108,16 +120,16 @@ class PostgresTool(Tool):
|
||||
"schema": schema_data,
|
||||
}
|
||||
|
||||
except psycopg2.Error as e:
|
||||
except psycopg.Error as e:
|
||||
error_message = f"Database error: {e}"
|
||||
print(f"Database error: {e}")
|
||||
logger.error("PostgreSQL get_schema error: %s", e)
|
||||
return {
|
||||
"status_code": 500,
|
||||
"message": "Failed to retrieve database schema.",
|
||||
"error": error_message,
|
||||
}
|
||||
finally:
|
||||
if conn: # Ensure connection is closed even if errors occur
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
def get_actions_metadata(self):
|
||||
@@ -158,6 +170,10 @@ class PostgresTool(Tool):
|
||||
return {
|
||||
"token": {
|
||||
"type": "string",
|
||||
"description": "PostgreSQL database connection string (e.g., 'postgresql://user:password@host:port/dbname')",
|
||||
"label": "Connection String",
|
||||
"description": "PostgreSQL database connection string",
|
||||
"required": True,
|
||||
"secret": True,
|
||||
"order": 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
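The hunks above migrate PostgresTool from psycopg2 to psycopg (psycopg 3) while keeping the manual connect/cursor/close flow. For comparison, a sketch of the same SELECT path using psycopg 3's context managers, which commit and close automatically; the connection string is a placeholder.

# Sketch only: psycopg 3 context-manager style for a read query.
import psycopg

conn_str = "postgresql://user:password@localhost:5432/dbname"  # placeholder

with psycopg.connect(conn_str) as conn:
    with conn.cursor() as cur:
        cur.execute("SELECT table_name FROM information_schema.tables LIMIT 5;")
        column_names = [desc[0] for desc in cur.description] if cur.description else []
        rows = cur.fetchall()

print({"data": rows, "column_names": column_names})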
@@ -1,7 +1,7 @@
|
||||
import requests
|
||||
from markdownify import markdownify
|
||||
from application.agents.tools.base import Tool
|
||||
from urllib.parse import urlparse
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
|
||||
class ReadWebpageTool(Tool):
|
||||
"""
|
||||
@@ -31,11 +31,12 @@ class ReadWebpageTool(Tool):
|
||||
if not url:
|
||||
return "Error: URL parameter is missing."
|
||||
|
||||
# Ensure the URL has a scheme (if not, default to http)
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme:
|
||||
url = "http://" + url
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
try:
|
||||
url = validate_url(url)
|
||||
except SSRFError as e:
|
||||
return f"Error: URL validation failed - {e}"
|
||||
|
||||
try:
|
||||
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
|
||||
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
|
||||
|
||||
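A sketch of the scheme-defaulting and SSRF check that ReadWebpageTool now performs before fetching; the example URL is a placeholder, and validate_url is assumed to raise SSRFError for disallowed targets, as the except clause above implies.

# Sketch of the URL validation flow used by ReadWebpageTool.
from urllib.parse import urlparse

from application.core.url_validation import validate_url, SSRFError

url = "example.com/docs"  # placeholder
if not urlparse(url).scheme:
    url = "http://" + url  # default scheme, as in execute_action above

try:
    url = validate_url(url)  # assumed to raise SSRFError for private/loopback targets
except SSRFError as e:
    print(f"Error: URL validation failed - {e}")
else:
    print(f"Safe to fetch: {url}")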
application/agents/tools/spec_parser.py (new file, 342 lines)
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
API Specification Parser
|
||||
|
||||
Parses OpenAPI 3.x and Swagger 2.0 specifications and converts them
|
||||
to API Tool action definitions for use in DocsGPT.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SUPPORTED_METHODS = frozenset(
|
||||
{"get", "post", "put", "delete", "patch", "head", "options"}
|
||||
)
|
||||
|
||||
|
||||
def parse_spec(spec_content: str) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Parse an API specification and convert operations to action definitions.
|
||||
|
||||
Supports OpenAPI 3.x and Swagger 2.0 formats in JSON or YAML.
|
||||
|
||||
Args:
|
||||
spec_content: Raw specification content as string
|
||||
|
||||
Returns:
|
||||
Tuple of (metadata dict, list of action dicts)
|
||||
|
||||
Raises:
|
||||
ValueError: If the spec is invalid or uses an unsupported format
|
||||
"""
|
||||
spec = _load_spec(spec_content)
|
||||
_validate_spec(spec)
|
||||
|
||||
is_swagger = "swagger" in spec
|
||||
metadata = _extract_metadata(spec, is_swagger)
|
||||
actions = _extract_actions(spec, is_swagger)
|
||||
|
||||
return metadata, actions
|
||||
|
||||
|
||||
def _load_spec(content: str) -> Dict[str, Any]:
|
||||
"""Parse spec content from JSON or YAML string."""
|
||||
content = content.strip()
|
||||
if not content:
|
||||
raise ValueError("Empty specification content")
|
||||
try:
|
||||
if content.startswith("{"):
|
||||
return json.loads(content)
|
||||
return yaml.safe_load(content)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON format: {e.msg}")
|
||||
except yaml.YAMLError as e:
|
||||
raise ValueError(f"Invalid YAML format: {e}")
|
||||
|
||||
|
||||
def _validate_spec(spec: Dict[str, Any]) -> None:
|
||||
"""Validate spec version and required fields."""
|
||||
if not isinstance(spec, dict):
|
||||
raise ValueError("Specification must be a valid object")
|
||||
openapi_version = spec.get("openapi", "")
|
||||
swagger_version = spec.get("swagger", "")
|
||||
|
||||
if not (openapi_version.startswith("3.") or swagger_version == "2.0"):
|
||||
raise ValueError(
|
||||
"Unsupported specification version. Expected OpenAPI 3.x or Swagger 2.0"
|
||||
)
|
||||
if "paths" not in spec or not spec["paths"]:
|
||||
raise ValueError("No API paths defined in the specification")
|
||||
|
||||
|
||||
def _extract_metadata(spec: Dict[str, Any], is_swagger: bool) -> Dict[str, Any]:
|
||||
"""Extract API metadata from specification."""
|
||||
info = spec.get("info", {})
|
||||
base_url = _get_base_url(spec, is_swagger)
|
||||
|
||||
return {
|
||||
"title": info.get("title", "Untitled API"),
|
||||
"description": (info.get("description", "") or "")[:500],
|
||||
"version": info.get("version", ""),
|
||||
"base_url": base_url,
|
||||
}
|
||||
|
||||
|
||||
def _get_base_url(spec: Dict[str, Any], is_swagger: bool) -> str:
|
||||
"""Extract base URL from spec (handles both OpenAPI 3.x and Swagger 2.0)."""
|
||||
if is_swagger:
|
||||
schemes = spec.get("schemes", ["https"])
|
||||
host = spec.get("host", "")
|
||||
base_path = spec.get("basePath", "")
|
||||
if host:
|
||||
scheme = schemes[0] if schemes else "https"
|
||||
return f"{scheme}://{host}{base_path}".rstrip("/")
|
||||
return ""
|
||||
servers = spec.get("servers", [])
|
||||
if servers and isinstance(servers, list) and servers[0].get("url"):
|
||||
return servers[0]["url"].rstrip("/")
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_actions(spec: Dict[str, Any], is_swagger: bool) -> List[Dict[str, Any]]:
|
||||
"""Extract all API operations as action definitions."""
|
||||
actions = []
|
||||
paths = spec.get("paths", {})
|
||||
base_url = _get_base_url(spec, is_swagger)
|
||||
|
||||
components = spec.get("components", {})
|
||||
definitions = spec.get("definitions", {})
|
||||
|
||||
for path, path_item in paths.items():
|
||||
if not isinstance(path_item, dict):
|
||||
continue
|
||||
path_params = path_item.get("parameters", [])
|
||||
|
||||
for method in SUPPORTED_METHODS:
|
||||
operation = path_item.get(method)
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
try:
|
||||
action = _build_action(
|
||||
path=path,
|
||||
method=method,
|
||||
operation=operation,
|
||||
path_params=path_params,
|
||||
base_url=base_url,
|
||||
components=components,
|
||||
definitions=definitions,
|
||||
is_swagger=is_swagger,
|
||||
)
|
||||
actions.append(action)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to parse operation {method.upper()} {path}: {e}"
|
||||
)
|
||||
continue
|
||||
return actions
|
||||
|
||||
|
||||
def _build_action(
|
||||
path: str,
|
||||
method: str,
|
||||
operation: Dict[str, Any],
|
||||
path_params: List[Dict],
|
||||
base_url: str,
|
||||
components: Dict[str, Any],
|
||||
definitions: Dict[str, Any],
|
||||
is_swagger: bool,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build a single action from an API operation."""
|
||||
action_name = _generate_action_name(operation, method, path)
|
||||
full_url = f"{base_url}{path}" if base_url else path
|
||||
|
||||
all_params = path_params + operation.get("parameters", [])
|
||||
query_params, headers = _categorize_parameters(all_params, components, definitions)
|
||||
|
||||
body, body_content_type = _extract_request_body(
|
||||
operation, components, definitions, is_swagger
|
||||
)
|
||||
|
||||
description = operation.get("summary", "") or operation.get("description", "")
|
||||
|
||||
return {
|
||||
"name": action_name,
|
||||
"url": full_url,
|
||||
"method": method.upper(),
|
||||
"description": (description or "")[:500],
|
||||
"query_params": {"type": "object", "properties": query_params},
|
||||
"headers": {"type": "object", "properties": headers},
|
||||
"body": {"type": "object", "properties": body},
|
||||
"body_content_type": body_content_type,
|
||||
"active": True,
|
||||
}
|
||||
|
||||
|
||||
def _generate_action_name(operation: Dict[str, Any], method: str, path: str) -> str:
|
||||
"""Generate a valid action name from operationId or method+path."""
|
||||
if operation.get("operationId"):
|
||||
name = operation["operationId"]
|
||||
else:
|
||||
path_slug = re.sub(r"[{}]", "", path)
|
||||
path_slug = re.sub(r"[^a-zA-Z0-9]", "_", path_slug)
|
||||
path_slug = re.sub(r"_+", "_", path_slug).strip("_")
|
||||
name = f"{method}_{path_slug}"
|
||||
name = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
|
||||
return name[:64]
|
||||
|
||||
|
||||
def _categorize_parameters(
|
||||
parameters: List[Dict],
|
||||
components: Dict[str, Any],
|
||||
definitions: Dict[str, Any],
|
||||
) -> Tuple[Dict, Dict]:
|
||||
"""Categorize parameters into query params and headers."""
|
||||
query_params = {}
|
||||
headers = {}
|
||||
|
||||
for param in parameters:
|
||||
resolved = _resolve_ref(param, components, definitions)
|
||||
if not resolved or "name" not in resolved:
|
||||
continue
|
||||
location = resolved.get("in", "query")
|
||||
prop = _param_to_property(resolved)
|
||||
|
||||
if location in ("query", "path"):
|
||||
query_params[resolved["name"]] = prop
|
||||
elif location == "header":
|
||||
headers[resolved["name"]] = prop
|
||||
return query_params, headers
|
||||
|
||||
|
||||
def _param_to_property(param: Dict) -> Dict[str, Any]:
|
||||
"""Convert an API parameter to an action property definition."""
|
||||
schema = param.get("schema", {})
|
||||
param_type = schema.get("type", param.get("type", "string"))
|
||||
|
||||
mapped_type = "integer" if param_type in ("integer", "number") else "string"
|
||||
|
||||
return {
|
||||
"type": mapped_type,
|
||||
"description": (param.get("description", "") or "")[:200],
|
||||
"value": "",
|
||||
"filled_by_llm": param.get("required", False),
|
||||
"required": param.get("required", False),
|
||||
}
|
||||
|
||||
|
||||
def _extract_request_body(
|
||||
operation: Dict[str, Any],
|
||||
components: Dict[str, Any],
|
||||
definitions: Dict[str, Any],
|
||||
is_swagger: bool,
|
||||
) -> Tuple[Dict, str]:
|
||||
"""Extract request body schema and content type."""
|
||||
content_types = [
|
||||
"application/json",
|
||||
"application/x-www-form-urlencoded",
|
||||
"multipart/form-data",
|
||||
"text/plain",
|
||||
"application/xml",
|
||||
]
|
||||
|
||||
if is_swagger:
|
||||
consumes = operation.get("consumes", [])
|
||||
body_param = next(
|
||||
(p for p in operation.get("parameters", []) if p.get("in") == "body"), None
|
||||
)
|
||||
if not body_param:
|
||||
return {}, "application/json"
|
||||
selected_type = consumes[0] if consumes else "application/json"
|
||||
schema = body_param.get("schema", {})
|
||||
else:
|
||||
request_body = operation.get("requestBody", {})
|
||||
if not request_body:
|
||||
return {}, "application/json"
|
||||
request_body = _resolve_ref(request_body, components, definitions)
|
||||
content = request_body.get("content", {})
|
||||
|
||||
selected_type = "application/json"
|
||||
schema = {}
|
||||
|
||||
for ct in content_types:
|
||||
if ct in content:
|
||||
selected_type = ct
|
||||
schema = content[ct].get("schema", {})
|
||||
break
|
||||
if not schema and content:
|
||||
first_type = next(iter(content))
|
||||
selected_type = first_type
|
||||
schema = content[first_type].get("schema", {})
|
||||
properties = _schema_to_properties(schema, components, definitions)
|
||||
return properties, selected_type
|
||||
|
||||
|
||||
def _schema_to_properties(
|
||||
schema: Dict,
|
||||
components: Dict[str, Any],
|
||||
definitions: Dict[str, Any],
|
||||
depth: int = 0,
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert schema to action body properties (limited depth to prevent recursion)."""
|
||||
if depth > 3:
|
||||
return {}
|
||||
schema = _resolve_ref(schema, components, definitions)
|
||||
if not schema or not isinstance(schema, dict):
|
||||
return {}
|
||||
properties = {}
|
||||
schema_type = schema.get("type", "object")
|
||||
|
||||
if schema_type == "object":
|
||||
required_fields = set(schema.get("required", []))
|
||||
for prop_name, prop_schema in schema.get("properties", {}).items():
|
||||
resolved = _resolve_ref(prop_schema, components, definitions)
|
||||
if not isinstance(resolved, dict):
|
||||
continue
|
||||
prop_type = resolved.get("type", "string")
|
||||
mapped_type = "integer" if prop_type in ("integer", "number") else "string"
|
||||
|
||||
properties[prop_name] = {
|
||||
"type": mapped_type,
|
||||
"description": (resolved.get("description", "") or "")[:200],
|
||||
"value": "",
|
||||
"filled_by_llm": prop_name in required_fields,
|
||||
"required": prop_name in required_fields,
|
||||
}
|
||||
return properties
|
||||
|
||||
|
||||
def _resolve_ref(
|
||||
obj: Any,
|
||||
components: Dict[str, Any],
|
||||
definitions: Dict[str, Any],
|
||||
) -> Optional[Dict]:
|
||||
"""Resolve $ref references in the specification."""
|
||||
if not isinstance(obj, dict):
|
||||
return obj if isinstance(obj, dict) else None
|
||||
if "$ref" not in obj:
|
||||
return obj
|
||||
ref_path = obj["$ref"]
|
||||
|
||||
if ref_path.startswith("#/components/"):
|
||||
parts = ref_path.replace("#/components/", "").split("/")
|
||||
return _traverse_path(components, parts)
|
||||
elif ref_path.startswith("#/definitions/"):
|
||||
parts = ref_path.replace("#/definitions/", "").split("/")
|
||||
return _traverse_path(definitions, parts)
|
||||
logger.debug(f"Unsupported ref path: {ref_path}")
|
||||
return None
|
||||
|
||||
|
||||
def _traverse_path(obj: Dict, parts: List[str]) -> Optional[Dict]:
|
||||
"""Traverse a nested dictionary using path parts."""
|
||||
try:
|
||||
for part in parts:
|
||||
obj = obj[part]
|
||||
return obj if isinstance(obj, dict) else None
|
||||
except (KeyError, TypeError):
|
||||
return None
|
||||
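A small worked example of parse_spec on an illustrative OpenAPI 3 document: the operationId becomes the action name, the path parameter lands in query_params, and the first server URL becomes base_url.

# Illustrative spec; not taken from the repository.
import json

from application.agents.tools.spec_parser import parse_spec

spec = {
    "openapi": "3.0.0",
    "info": {"title": "Petstore", "version": "1.0.0"},
    "servers": [{"url": "https://api.example.com/v1"}],
    "paths": {
        "/pets/{petId}": {
            "get": {
                "operationId": "getPet",
                "summary": "Fetch a pet by id",
                "parameters": [
                    {"name": "petId", "in": "path", "required": True,
                     "schema": {"type": "integer"}}
                ],
            }
        }
    },
}

metadata, actions = parse_spec(json.dumps(spec))
print(metadata["base_url"])   # https://api.example.com/v1
print(actions[0]["name"])     # getPet
print(actions[0]["method"])   # GET
print(actions[0]["query_params"]["properties"]["petId"]["type"])  # integer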
@@ -1,6 +1,11 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TelegramTool(Tool):
|
||||
"""
|
||||
@@ -18,24 +23,22 @@ class TelegramTool(Tool):
|
||||
"telegram_send_message": self._send_message,
|
||||
"telegram_send_image": self._send_image,
|
||||
}
|
||||
|
||||
if action_name in actions:
|
||||
return actions[action_name](**kwargs)
|
||||
else:
|
||||
if action_name not in actions:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
return actions[action_name](**kwargs)
|
||||
|
||||
def _send_message(self, text, chat_id):
|
||||
print(f"Sending message: {text}")
|
||||
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
|
||||
payload = {"chat_id": chat_id, "text": text}
|
||||
response = requests.post(url, data=payload)
|
||||
response = requests.post(url, data=payload, timeout=100)
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def _send_image(self, image_url, chat_id):
|
||||
print(f"Sending image: {image_url}")
|
||||
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
|
||||
payload = {"chat_id": chat_id, "photo": image_url}
|
||||
response = requests.post(url, data=payload)
|
||||
response = requests.post(url, data=payload, timeout=100)
|
||||
return {"status_code": response.status_code, "message": "Image sent"}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
@@ -82,5 +85,12 @@ class TelegramTool(Tool):
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {
|
||||
"token": {"type": "string", "description": "Bot token for authentication"},
|
||||
"token": {
|
||||
"type": "string",
|
||||
"label": "Bot Token",
|
||||
"description": "Telegram bot token for authentication",
|
||||
"required": True,
|
||||
"secret": True,
|
||||
"order": 1,
|
||||
},
|
||||
}
|
||||
|
||||
application/agents/tools/think.py (new file, 70 lines)
@@ -0,0 +1,70 @@
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
|
||||
THINK_TOOL_ID = "think"
|
||||
|
||||
THINK_TOOL_ENTRY = {
|
||||
"name": "think",
|
||||
"actions": [
|
||||
{
|
||||
"name": "reason",
|
||||
"description": (
|
||||
"Use this tool to think through your reasoning step by step "
|
||||
"before deciding on your next action. Always reason before "
|
||||
"searching or answering."
|
||||
),
|
||||
"active": True,
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": "Your step-by-step reasoning and analysis",
|
||||
"filled_by_llm": True,
|
||||
"required": True,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class ThinkTool(Tool):
|
||||
"""Pseudo-tool that captures chain-of-thought reasoning.
|
||||
|
||||
Returns a short acknowledgment so the LLM can continue.
|
||||
The reasoning content is captured in tool_call data for transparency.
|
||||
"""
|
||||
|
||||
internal = True
|
||||
|
||||
def __init__(self, config=None):
|
||||
pass
|
||||
|
||||
def execute_action(self, action_name: str, **kwargs):
|
||||
return "Continue."
|
||||
|
||||
def get_actions_metadata(self):
|
||||
return [
|
||||
{
|
||||
"name": "reason",
|
||||
"description": (
|
||||
"Use this tool to think through your reasoning step by step "
|
||||
"before deciding on your next action. Always reason before "
|
||||
"searching or answering."
|
||||
),
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": "Your step-by-step reasoning and analysis",
|
||||
"filled_by_llm": True,
|
||||
"required": True,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {}
|
||||
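Because ThinkTool sets internal = True, ToolManager's auto-loader (changed later in this diff) skips it; it can still be invoked directly, and its reply is a fixed acknowledgment.

# Direct invocation sketch; the reasoning text is illustrative.
from application.agents.tools.think import THINK_TOOL_ENTRY, ThinkTool

tool = ThinkTool()
print(tool.execute_action("reason", reasoning="List the options, then pick one."))  # Continue.
print(THINK_TOOL_ENTRY["actions"][0]["name"])  # reason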
application/agents/tools/todo_list.py (new file, 351 lines)
@@ -0,0 +1,351 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.storage.db.repositories.todos import TodosRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
def _status_from_completed(completed: Any) -> str:
|
||||
"""Translate the PG ``completed`` boolean to the legacy status string.
|
||||
|
||||
The frontend (and prior LLM-facing tool output) expects
|
||||
``"open"`` / ``"completed"``. Keeping that contract at the tool
|
||||
boundary insulates callers from the schema change.
|
||||
"""
|
||||
return "completed" if bool(completed) else "open"
|
||||
|
||||
|
||||
class TodoListTool(Tool):
|
||||
"""Todo List
|
||||
|
||||
Manages todo items for users. Supports creating, viewing, updating, completing, and deleting todos.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration. Should include:
|
||||
- tool_id: Unique identifier for this todo list tool instance (from user_tools._id)
|
||||
This ensures each user's tool configuration has isolated todos
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
# Fallback for backward compatibility or testing
|
||||
self.tool_id = f"default_{user_id}"
|
||||
else:
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
self._last_artifact_id: Optional[str] = None
|
||||
|
||||
def _pg_enabled(self) -> bool:
|
||||
"""Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.
|
||||
|
||||
The ``todos`` PG table has a UUID foreign key to ``user_tools`` and
|
||||
the repo queries ``CAST(:tool_id AS uuid)``. The sentinel
|
||||
``default_{uid}`` fallback is neither a UUID nor a row in
|
||||
``user_tools``; binding it would fail with ``invalid input syntax for
type uuid``, and even if it didn't, the FK would reject it. Mirror
|
||||
the MemoryTool guard and no-op in that case.
|
||||
"""
|
||||
tool_id = getattr(self, "tool_id", None)
|
||||
if not tool_id or not isinstance(tool_id, str):
|
||||
return False
|
||||
if tool_id.startswith("default_"):
|
||||
return False
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
|
||||
return looks_like_uuid(tool_id)
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of list, create, get, update, complete, delete.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: TodoListTool requires a valid user_id."
|
||||
|
||||
if not self._pg_enabled():
|
||||
return (
|
||||
"Error: TodoListTool is not configured with a persistent "
|
||||
"tool_id; todo storage is unavailable for this session."
|
||||
)
|
||||
|
||||
self._last_artifact_id = None
|
||||
|
||||
if action_name == "list":
|
||||
return self._list()
|
||||
|
||||
if action_name == "create":
|
||||
return self._create(kwargs.get("title", ""))
|
||||
|
||||
if action_name == "get":
|
||||
return self._get(kwargs.get("todo_id"))
|
||||
|
||||
if action_name == "update":
|
||||
return self._update(
|
||||
kwargs.get("todo_id"),
|
||||
kwargs.get("title", "")
|
||||
)
|
||||
|
||||
if action_name == "complete":
|
||||
return self._complete(kwargs.get("todo_id"))
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete(kwargs.get("todo_id"))
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "list",
|
||||
"description": "List all todos for the user.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "create",
|
||||
"description": "Create a new todo item.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "Title of the todo item."
|
||||
}
|
||||
},
|
||||
"required": ["title"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "get",
|
||||
"description": "Get a specific todo by ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to retrieve."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "update",
|
||||
"description": "Update a todo's title by ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to update."
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "The new title for the todo."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id", "title"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "complete",
|
||||
"description": "Mark a todo as completed.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to mark as completed."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete a specific todo by ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to delete."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements."""
|
||||
return {}
|
||||
|
||||
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
|
||||
return self._last_artifact_id
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------
|
||||
def _coerce_todo_id(self, value: Optional[Any]) -> Optional[int]:
|
||||
"""Convert todo identifiers to sequential integers."""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, int):
|
||||
return value if value > 0 else None
|
||||
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
if stripped.isdigit():
|
||||
numeric_value = int(stripped)
|
||||
return numeric_value if numeric_value > 0 else None
|
||||
|
||||
return None
|
||||
|
||||
def _list(self) -> str:
|
||||
"""List all todos for the user."""
|
||||
with db_readonly() as conn:
|
||||
todos = TodosRepository(conn).list_for_tool(self.user_id, self.tool_id)
|
||||
|
||||
if not todos:
|
||||
return "No todos found."
|
||||
|
||||
result_lines = ["Todos:"]
|
||||
for doc in todos:
|
||||
todo_id = doc.get("todo_id")
|
||||
title = doc.get("title", "Untitled")
|
||||
status = _status_from_completed(doc.get("completed"))
|
||||
|
||||
line = f"[{todo_id}] {title} ({status})"
|
||||
result_lines.append(line)
|
||||
|
||||
return "\n".join(result_lines)
|
||||
|
||||
def _create(self, title: str) -> str:
|
||||
"""Create a new todo item.
|
||||
|
||||
``TodosRepository.create`` allocates the per-tool monotonic
|
||||
``todo_id`` inside the same transaction (``COALESCE(MAX(todo_id),0)+1``
|
||||
scoped to ``tool_id``), so we no longer need a separate read-then-
|
||||
write step here.
|
||||
"""
|
||||
title = (title or "").strip()
|
||||
if not title:
|
||||
return "Error: Title is required."
|
||||
|
||||
with db_session() as conn:
|
||||
row = TodosRepository(conn).create(self.user_id, self.tool_id, title)
|
||||
|
||||
todo_id = row.get("todo_id")
|
||||
if row.get("id") is not None:
|
||||
self._last_artifact_id = str(row.get("id"))
|
||||
return f"Todo created with ID {todo_id}: {title}"
|
||||
|
||||
def _get(self, todo_id: Optional[Any]) -> str:
|
||||
"""Get a specific todo by ID."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
with db_readonly() as conn:
|
||||
doc = TodosRepository(conn).get_by_tool_and_todo_id(
|
||||
self.user_id, self.tool_id, parsed_todo_id
|
||||
)
|
||||
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if doc.get("id") is not None:
|
||||
self._last_artifact_id = str(doc.get("id"))
|
||||
|
||||
title = doc.get("title", "Untitled")
|
||||
status = _status_from_completed(doc.get("completed"))
|
||||
|
||||
return f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
|
||||
|
||||
def _update(self, todo_id: Optional[Any], title: str) -> str:
|
||||
"""Update a todo's title by ID."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
title = (title or "").strip()
|
||||
if not title:
|
||||
return "Error: Title is required."
|
||||
|
||||
with db_session() as conn:
|
||||
repo = TodosRepository(conn)
|
||||
existing = repo.get_by_tool_and_todo_id(
|
||||
self.user_id, self.tool_id, parsed_todo_id
|
||||
)
|
||||
if not existing:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
repo.update_title_by_tool_and_todo_id(
|
||||
self.user_id, self.tool_id, parsed_todo_id, title
|
||||
)
|
||||
|
||||
if existing.get("id") is not None:
|
||||
self._last_artifact_id = str(existing.get("id"))
|
||||
|
||||
return f"Todo {parsed_todo_id} updated to: {title}"
|
||||
|
||||
def _complete(self, todo_id: Optional[Any]) -> str:
|
||||
"""Mark a todo as completed."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
with db_session() as conn:
|
||||
repo = TodosRepository(conn)
|
||||
existing = repo.get_by_tool_and_todo_id(
|
||||
self.user_id, self.tool_id, parsed_todo_id
|
||||
)
|
||||
if not existing:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
repo.set_completed(self.user_id, self.tool_id, parsed_todo_id, True)
|
||||
|
||||
if existing.get("id") is not None:
|
||||
self._last_artifact_id = str(existing.get("id"))
|
||||
|
||||
return f"Todo {parsed_todo_id} marked as completed."
|
||||
|
||||
def _delete(self, todo_id: Optional[Any]) -> str:
|
||||
"""Delete a specific todo by ID."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
with db_session() as conn:
|
||||
repo = TodosRepository(conn)
|
||||
existing = repo.get_by_tool_and_todo_id(
|
||||
self.user_id, self.tool_id, parsed_todo_id
|
||||
)
|
||||
if not existing:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
repo.delete_by_tool_and_todo_id(
|
||||
self.user_id, self.tool_id, parsed_todo_id
|
||||
)
|
||||
|
||||
if existing.get("id") is not None:
|
||||
self._last_artifact_id = str(existing.get("id"))
|
||||
|
||||
return f"Todo {parsed_todo_id} deleted."
|
||||
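A sketch of the todo round trip above: string ids are coerced by _coerce_todo_id, the first todo is expected to get todo_id 1 per the COALESCE(MAX)+1 allocation described in _create, and the tool_id UUID and user id are placeholders that must map to a real user_tools row.

# Hypothetical TodoListTool usage; identifiers below are placeholders.
from application.agents.tools.todo_list import TodoListTool

todos = TodoListTool(
    tool_config={"tool_id": "7d4f2a10-3c5b-4e6d-9a8b-0c1d2e3f4a5b"},
    user_id="user-123",
)

print(todos.execute_action("create", title="Write release notes"))
print(todos.execute_action("list"))                    # [1] Write release notes (open)
print(todos.execute_action("complete", todo_id="1"))   # string ids are coerced to int
print(todos.execute_action("get", todo_id=1))          # Status: completed
print(todos.execute_action("delete", todo_id=1))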
@@ -5,8 +5,9 @@ logger = logging.getLogger(__name__)


class ToolActionParser:
    def __init__(self, llm_type):
    def __init__(self, llm_type, name_mapping=None):
        self.llm_type = llm_type
        self.name_mapping = name_mapping
        self.parsers = {
            "OpenAILLM": self._parse_openai_llm,
            "GoogleLLM": self._parse_google_llm,
@@ -16,27 +17,70 @@ class ToolActionParser:
        parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
        return parser(call)

    def _resolve_via_mapping(self, call_name):
        """Look up (tool_id, action_name) from the name mapping if available."""
        if self.name_mapping and call_name in self.name_mapping:
            return self.name_mapping[call_name]
        return None

    def _parse_openai_llm(self, call):
        if isinstance(call, dict):
            try:
                call_args = json.loads(call["function"]["arguments"])
                tool_id = call["function"]["name"].split("_")[-1]
                action_name = call["function"]["name"].rsplit("_", 1)[0]
            except (KeyError, TypeError) as e:
                logger.error(f"Error parsing OpenAI LLM call: {e}")
                return None, None, None
        else:
            try:
                call_args = json.loads(call.function.arguments)
                tool_id = call.function.name.split("_")[-1]
                action_name = call.function.name.rsplit("_", 1)[0]
            except (AttributeError, TypeError) as e:
                logger.error(f"Error parsing OpenAI LLM call: {e}")
            try:
                call_args = json.loads(call.arguments)

                resolved = self._resolve_via_mapping(call.name)
                if resolved:
                    return resolved[0], resolved[1], call_args

                # Fallback: legacy split on "_" for backward compatibility
                tool_parts = call.name.split("_")

                if len(tool_parts) < 2:
                    logger.warning(
                        f"Invalid tool name format: {call.name}. "
                        "Could not resolve via mapping or legacy parsing."
                    )
                    return None, None, None

                tool_id = tool_parts[-1]
                action_name = "_".join(tool_parts[:-1])

                if not tool_id.isdigit():
                    logger.warning(
                        f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
                    )

            except (AttributeError, TypeError, json.JSONDecodeError) as e:
                logger.error(f"Error parsing OpenAI LLM call: {e}")
                return None, None, None
        return tool_id, action_name, call_args

    def _parse_google_llm(self, call):
        call_args = call.args
        tool_id = call.name.split("_")[-1]
        action_name = call.name.rsplit("_", 1)[0]
        try:
            call_args = call.arguments

            resolved = self._resolve_via_mapping(call.name)
            if resolved:
                return resolved[0], resolved[1], call_args

            # Fallback: legacy split on "_" for backward compatibility
            tool_parts = call.name.split("_")

            if len(tool_parts) < 2:
                logger.warning(
                    f"Invalid tool name format: {call.name}. "
                    "Could not resolve via mapping or legacy parsing."
                )
                return None, None, None

            tool_id = tool_parts[-1]
            action_name = "_".join(tool_parts[:-1])

            if not tool_id.isdigit():
                logger.warning(
                    f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
                )

        except (AttributeError, TypeError) as e:
            logger.error(f"Error parsing Google LLM call: {e}")
            return None, None, None
        return tool_id, action_name, call_args
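
A hedged sketch of the two resolution paths introduced above (the tool name and mapping entry are illustrative, not taken from a real registry):

# Assumed: a parser built with an explicit name mapping.
parser = ToolActionParser(
    "OpenAILLM",
    name_mapping={"read_file_12": ("12", "read_file")},  # hypothetical entry
)

# Preferred path: the mapping resolves (tool_id, action_name) directly.
assert parser._resolve_via_mapping("read_file_12") == ("12", "read_file")

# Legacy fallback: split on "_" so the trailing segment is taken as the tool id.
tool_parts = "read_file_12".split("_")
tool_id, action_name = tool_parts[-1], "_".join(tool_parts[:-1])  # "12", "read_file"
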
@@ -19,20 +19,27 @@ class ToolManager:
                continue
            module = importlib.import_module(f"application.agents.tools.{name}")
            for member_name, obj in inspect.getmembers(module, inspect.isclass):
                if issubclass(obj, Tool) and obj is not Tool:
                if issubclass(obj, Tool) and obj is not Tool and not obj.internal:
                    tool_config = self.config.get(name, {})
                    self.tools[name] = obj(tool_config)

    def load_tool(self, tool_name, tool_config):
    def load_tool(self, tool_name, tool_config, user_id=None):
        self.config[tool_name] = tool_config
        module = importlib.import_module(f"application.agents.tools.{tool_name}")
        for member_name, obj in inspect.getmembers(module, inspect.isclass):
            if issubclass(obj, Tool) and obj is not Tool:
                return obj(tool_config)
                if tool_name in {"mcp_tool", "notes", "memory", "todo_list"} and user_id:
                    return obj(tool_config, user_id)
                else:
                    return obj(tool_config)

    def execute_action(self, tool_name, action_name, **kwargs):
    def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
        if tool_name not in self.tools:
            raise ValueError(f"Tool '{tool_name}' not loaded")
        if tool_name in {"mcp_tool", "memory", "todo_list", "notes"} and user_id:
            tool_config = self.config.get(tool_name, {})
            tool = self.load_tool(tool_name, tool_config, user_id)
            return tool.execute_action(action_name, **kwargs)
        return self.tools[tool_name].execute_action(action_name, **kwargs)

    def get_all_actions_metadata(self):

application/agents/workflow_agent.py (new file, 254 lines)
@@ -0,0 +1,254 @@
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.workflows.schemas import (
|
||||
ExecutionStatus,
|
||||
Workflow,
|
||||
WorkflowEdge,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
WorkflowRun,
|
||||
)
|
||||
from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||
from application.logging import log_activity, LogContext
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
|
||||
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
|
||||
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
|
||||
from application.storage.db.repositories.workflows import WorkflowsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAgent(BaseAgent):
|
||||
"""A specialized agent that executes predefined workflows."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
workflow_id: Optional[str] = None,
|
||||
workflow: Optional[Dict[str, Any]] = None,
|
||||
workflow_owner: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.workflow_id = workflow_id
|
||||
self.workflow_owner = workflow_owner
|
||||
self._workflow_data = workflow
|
||||
self._engine: Optional[WorkflowEngine] = None
|
||||
|
||||
@log_activity()
|
||||
def gen(
|
||||
self, query: str, log_context: LogContext = None
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
yield from self._gen_inner(query, log_context)
|
||||
|
||||
def _gen_inner(
|
||||
self, query: str, log_context: LogContext
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
graph = self._load_workflow_graph()
|
||||
if not graph:
|
||||
yield {"type": "error", "error": "Failed to load workflow configuration."}
|
||||
return
|
||||
self._engine = WorkflowEngine(graph, self)
|
||||
yield from self._engine.execute({}, query)
|
||||
self._save_workflow_run(query)
|
||||
|
||||
def _load_workflow_graph(self) -> Optional[WorkflowGraph]:
|
||||
if self._workflow_data:
|
||||
return self._parse_embedded_workflow()
|
||||
if self.workflow_id:
|
||||
return self._load_from_database()
|
||||
return None
|
||||
|
||||
def _parse_embedded_workflow(self) -> Optional[WorkflowGraph]:
|
||||
try:
|
||||
nodes_data = self._workflow_data.get("nodes", [])
|
||||
edges_data = self._workflow_data.get("edges", [])
|
||||
|
||||
workflow = Workflow(
|
||||
name=self._workflow_data.get("name", "Embedded Workflow"),
|
||||
description=self._workflow_data.get("description"),
|
||||
)
|
||||
|
||||
nodes = []
|
||||
for n in nodes_data:
|
||||
node_config = n.get("data", {})
|
||||
nodes.append(
|
||||
WorkflowNode(
|
||||
id=n["id"],
|
||||
workflow_id=self.workflow_id or "embedded",
|
||||
type=n["type"],
|
||||
title=n.get("title", "Node"),
|
||||
description=n.get("description"),
|
||||
position=n.get("position", {"x": 0, "y": 0}),
|
||||
config=node_config,
|
||||
)
|
||||
)
|
||||
edges = []
|
||||
for e in edges_data:
|
||||
edges.append(
|
||||
WorkflowEdge(
|
||||
id=e["id"],
|
||||
workflow_id=self.workflow_id or "embedded",
|
||||
source=e.get("source") or e.get("source_id"),
|
||||
target=e.get("target") or e.get("target_id"),
|
||||
sourceHandle=e.get("sourceHandle") or e.get("source_handle"),
|
||||
targetHandle=e.get("targetHandle") or e.get("target_handle"),
|
||||
)
|
||||
)
|
||||
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid embedded workflow: {e}")
|
||||
return None
|
||||
|
||||
def _load_from_database(self) -> Optional[WorkflowGraph]:
|
||||
try:
|
||||
if not self.workflow_id:
|
||||
logger.error("Missing workflow ID for load")
|
||||
return None
|
||||
owner_id = self.workflow_owner
|
||||
if not owner_id and isinstance(self.decoded_token, dict):
|
||||
owner_id = self.decoded_token.get("sub")
|
||||
if not owner_id:
|
||||
logger.error(
|
||||
f"Workflow owner not available for workflow load: {self.workflow_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with db_readonly() as conn:
|
||||
wf_repo = WorkflowsRepository(conn)
|
||||
if looks_like_uuid(self.workflow_id):
|
||||
workflow_row = wf_repo.get(self.workflow_id, owner_id)
|
||||
else:
|
||||
workflow_row = wf_repo.get_by_legacy_id(self.workflow_id, owner_id)
|
||||
if workflow_row is None:
|
||||
logger.error(
|
||||
f"Workflow {self.workflow_id} not found or inaccessible "
|
||||
f"for user {owner_id}"
|
||||
)
|
||||
return None
|
||||
pg_workflow_id = str(workflow_row["id"])
|
||||
graph_version = workflow_row.get("current_graph_version", 1)
|
||||
try:
|
||||
graph_version = int(graph_version)
|
||||
if graph_version <= 0:
|
||||
graph_version = 1
|
||||
except (ValueError, TypeError):
|
||||
graph_version = 1
|
||||
|
||||
node_rows = WorkflowNodesRepository(conn).find_by_version(
|
||||
pg_workflow_id, graph_version,
|
||||
)
|
||||
edge_rows = WorkflowEdgesRepository(conn).find_by_version(
|
||||
pg_workflow_id, graph_version,
|
||||
)
|
||||
|
||||
workflow = Workflow(
|
||||
name=workflow_row.get("name"),
|
||||
description=workflow_row.get("description"),
|
||||
)
|
||||
nodes = [
|
||||
WorkflowNode(
|
||||
id=n["node_id"],
|
||||
workflow_id=pg_workflow_id,
|
||||
type=n["node_type"],
|
||||
title=n.get("title") or "Node",
|
||||
description=n.get("description"),
|
||||
position=n.get("position") or {"x": 0, "y": 0},
|
||||
config=n.get("config") or {},
|
||||
)
|
||||
for n in node_rows
|
||||
]
|
||||
edges = [
|
||||
WorkflowEdge(
|
||||
id=e["edge_id"],
|
||||
workflow_id=pg_workflow_id,
|
||||
source=e.get("source_id"),
|
||||
target=e.get("target_id"),
|
||||
sourceHandle=e.get("source_handle"),
|
||||
targetHandle=e.get("target_handle"),
|
||||
)
|
||||
for e in edge_rows
|
||||
]
|
||||
|
||||
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load workflow from database: {e}")
|
||||
return None
|
||||
|
||||
def _save_workflow_run(self, query: str) -> None:
|
||||
if not self._engine:
|
||||
return
|
||||
owner_id = self.workflow_owner
|
||||
if not owner_id and isinstance(self.decoded_token, dict):
|
||||
owner_id = self.decoded_token.get("sub")
|
||||
try:
|
||||
run = WorkflowRun(
|
||||
workflow_id=self.workflow_id or "unknown",
|
||||
user=owner_id,
|
||||
status=self._determine_run_status(),
|
||||
inputs={"query": query},
|
||||
outputs=self._serialize_state(self._engine.state),
|
||||
steps=self._engine.get_execution_summary(),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
if not self.workflow_id or not owner_id:
|
||||
return
|
||||
with db_session() as conn:
|
||||
wf_repo = WorkflowsRepository(conn)
|
||||
if looks_like_uuid(self.workflow_id):
|
||||
workflow_row = wf_repo.get(self.workflow_id, owner_id)
|
||||
else:
|
||||
workflow_row = wf_repo.get_by_legacy_id(
|
||||
self.workflow_id, owner_id,
|
||||
)
|
||||
if workflow_row is None:
|
||||
return
|
||||
WorkflowRunsRepository(conn).create(
|
||||
str(workflow_row["id"]),
|
||||
owner_id,
|
||||
run.status.value,
|
||||
inputs=run.inputs,
|
||||
result=run.outputs,
|
||||
steps=[step.model_dump(mode="json") for step in run.steps],
|
||||
started_at=run.created_at,
|
||||
ended_at=run.completed_at,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save workflow run: {e}")
|
||||
|
||||
def _determine_run_status(self) -> ExecutionStatus:
|
||||
if not self._engine or not self._engine.execution_log:
|
||||
return ExecutionStatus.COMPLETED
|
||||
for log in self._engine.execution_log:
|
||||
if log.get("status") == ExecutionStatus.FAILED.value:
|
||||
return ExecutionStatus.FAILED
|
||||
return ExecutionStatus.COMPLETED
|
||||
|
||||
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
serialized: Dict[str, Any] = {}
|
||||
for key, value in state.items():
|
||||
serialized[key] = self._serialize_state_value(value)
|
||||
return serialized
|
||||
|
||||
def _serialize_state_value(self, value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
str(dict_key): self._serialize_state_value(dict_value)
|
||||
for dict_key, dict_value in value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [self._serialize_state_value(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return [self._serialize_state_value(item) for item in value]
|
||||
if isinstance(value, datetime):
|
||||
return value.isoformat()
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
return value
|
||||
return str(value)
|
||||
application/agents/workflows/cel_evaluator.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from typing import Any, Dict

import celpy
import celpy.celtypes


class CelEvaluationError(Exception):
    pass


def _convert_value(value: Any) -> Any:
    if isinstance(value, bool):
        return celpy.celtypes.BoolType(value)
    if isinstance(value, int):
        return celpy.celtypes.IntType(value)
    if isinstance(value, float):
        return celpy.celtypes.DoubleType(value)
    if isinstance(value, str):
        return celpy.celtypes.StringType(value)
    if isinstance(value, list):
        return celpy.celtypes.ListType([_convert_value(item) for item in value])
    if isinstance(value, dict):
        return celpy.celtypes.MapType(
            {celpy.celtypes.StringType(k): _convert_value(v) for k, v in value.items()}
        )
    if value is None:
        return celpy.celtypes.BoolType(False)
    return celpy.celtypes.StringType(str(value))


def build_activation(state: Dict[str, Any]) -> Dict[str, Any]:
    return {k: _convert_value(v) for k, v in state.items()}


def evaluate_cel(expression: str, state: Dict[str, Any]) -> Any:
    if not expression or not expression.strip():
        raise CelEvaluationError("Empty expression")
    try:
        env = celpy.Environment()
        ast = env.compile(expression)
        program = env.program(ast)
        activation = build_activation(state)
        result = program.evaluate(activation)
    except celpy.CELEvalError as exc:
        raise CelEvaluationError(f"CEL evaluation error: {exc}") from exc
    except Exception as exc:
        raise CelEvaluationError(f"CEL error: {exc}") from exc
    return cel_to_python(result)


def cel_to_python(value: Any) -> Any:
    if isinstance(value, celpy.celtypes.BoolType):
        return bool(value)
    if isinstance(value, celpy.celtypes.IntType):
        return int(value)
    if isinstance(value, celpy.celtypes.DoubleType):
        return float(value)
    if isinstance(value, celpy.celtypes.StringType):
        return str(value)
    if isinstance(value, celpy.celtypes.ListType):
        return [cel_to_python(item) for item in value]
    if isinstance(value, celpy.celtypes.MapType):
        return {str(k): cel_to_python(v) for k, v in value.items()}
    return value
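
A hedged usage sketch for the evaluator above, driven from a workflow-style state dict (the keys and expressions are made up for illustration):

state = {"score": 7, "tags": ["internal", "urgent"]}

assert evaluate_cel("score > 5", state) is True         # BoolType -> bool
assert evaluate_cel("'urgent' in tags", state) is True   # CEL `in` over a list
assert evaluate_cel("score * 2", state) == 14            # IntType -> int

try:
    evaluate_cel("   ", state)
except CelEvaluationError as exc:
    print(exc)  # "Empty expression"
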
application/agents/workflows/node_agent.py (new file, 104 lines)
@@ -0,0 +1,104 @@
|
||||
"""Workflow Node Agents - defines specialized agents for workflow nodes."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from application.agents.agentic_agent import AgenticAgent
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.agents.research_agent import ResearchAgent
|
||||
from application.agents.workflows.schemas import AgentType
|
||||
|
||||
|
||||
class ToolFilterMixin:
|
||||
"""Mixin that filters fetched tools to only those specified in tool_ids."""
|
||||
|
||||
_allowed_tool_ids: List[str]
|
||||
|
||||
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict[str, Any]]:
|
||||
all_tools = super()._get_user_tools(user)
|
||||
if not self._allowed_tool_ids:
|
||||
return {}
|
||||
filtered_tools = {
|
||||
tool_id: tool
|
||||
for tool_id, tool in all_tools.items()
|
||||
if str(tool.get("_id", "")) in self._allowed_tool_ids
|
||||
}
|
||||
return filtered_tools
|
||||
|
||||
def _get_tools(self, api_key: str = None) -> Dict[str, Dict[str, Any]]:
|
||||
all_tools = super()._get_tools(api_key)
|
||||
if not self._allowed_tool_ids:
|
||||
return {}
|
||||
filtered_tools = {
|
||||
tool_id: tool
|
||||
for tool_id, tool in all_tools.items()
|
||||
if str(tool.get("_id", "")) in self._allowed_tool_ids
|
||||
}
|
||||
return filtered_tools
|
||||
|
||||
|
||||
class _WorkflowNodeMixin:
|
||||
"""Common __init__ for all workflow node agents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
llm_name: str,
|
||||
model_id: str,
|
||||
api_key: str,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
endpoint=endpoint,
|
||||
llm_name=llm_name,
|
||||
model_id=model_id,
|
||||
api_key=api_key,
|
||||
**kwargs,
|
||||
)
|
||||
self._allowed_tool_ids = tool_ids or []
|
||||
|
||||
|
||||
class WorkflowNodeClassicAgent(ToolFilterMixin, _WorkflowNodeMixin, ClassicAgent):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowNodeAgenticAgent(ToolFilterMixin, _WorkflowNodeMixin, AgenticAgent):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowNodeResearchAgent(ToolFilterMixin, _WorkflowNodeMixin, ResearchAgent):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowNodeAgentFactory:
|
||||
|
||||
_agents: Dict[AgentType, Type[BaseAgent]] = {
|
||||
AgentType.CLASSIC: WorkflowNodeClassicAgent,
|
||||
AgentType.REACT: WorkflowNodeClassicAgent, # backwards compat
|
||||
AgentType.AGENTIC: WorkflowNodeAgenticAgent,
|
||||
AgentType.RESEARCH: WorkflowNodeResearchAgent,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
agent_type: AgentType,
|
||||
endpoint: str,
|
||||
llm_name: str,
|
||||
model_id: str,
|
||||
api_key: str,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> BaseAgent:
|
||||
agent_class = cls._agents.get(agent_type)
|
||||
if not agent_class:
|
||||
raise ValueError(f"Unsupported agent type: {agent_type}")
|
||||
return agent_class(
|
||||
endpoint=endpoint,
|
||||
llm_name=llm_name,
|
||||
model_id=model_id,
|
||||
api_key=api_key,
|
||||
tool_ids=tool_ids,
|
||||
**kwargs,
|
||||
)
|
||||
application/agents/workflows/schemas.py (new file, 168 lines)
@@ -0,0 +1,168 @@
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class NodeType(str, Enum):
|
||||
START = "start"
|
||||
END = "end"
|
||||
AGENT = "agent"
|
||||
NOTE = "note"
|
||||
STATE = "state"
|
||||
CONDITION = "condition"
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
CLASSIC = "classic"
|
||||
REACT = "react"
|
||||
AGENTIC = "agentic"
|
||||
RESEARCH = "research"
|
||||
|
||||
|
||||
class ExecutionStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class Position(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
x: float = 0.0
|
||||
y: float = 0.0
|
||||
|
||||
|
||||
class AgentNodeConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
agent_type: AgentType = AgentType.CLASSIC
|
||||
llm_name: Optional[str] = None
|
||||
system_prompt: str = "You are a helpful assistant."
|
||||
prompt_template: str = ""
|
||||
output_variable: Optional[str] = None
|
||||
stream_to_user: bool = True
|
||||
tools: List[str] = Field(default_factory=list)
|
||||
sources: List[str] = Field(default_factory=list)
|
||||
chunks: str = "2"
|
||||
retriever: str = ""
|
||||
model_id: Optional[str] = None
|
||||
json_schema: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ConditionCase(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
name: Optional[str] = None
|
||||
expression: str = ""
|
||||
source_handle: str = Field(..., alias="sourceHandle")
|
||||
|
||||
|
||||
class ConditionNodeConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
mode: Literal["simple", "advanced"] = "simple"
|
||||
cases: List[ConditionCase] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StateOperation(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
expression: str = ""
|
||||
target_variable: str = ""
|
||||
|
||||
|
||||
class WorkflowEdgeCreate(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
id: str
|
||||
workflow_id: str
|
||||
source_id: str = Field(..., alias="source")
|
||||
target_id: str = Field(..., alias="target")
|
||||
source_handle: Optional[str] = Field(None, alias="sourceHandle")
|
||||
target_handle: Optional[str] = Field(None, alias="targetHandle")
|
||||
|
||||
|
||||
class WorkflowEdge(WorkflowEdgeCreate):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowNodeCreate(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
id: str
|
||||
workflow_id: str
|
||||
type: NodeType
|
||||
title: str = "Node"
|
||||
description: Optional[str] = None
|
||||
position: Position = Field(default_factory=Position)
|
||||
config: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("position", mode="before")
|
||||
@classmethod
|
||||
def parse_position(cls, v: Union[Dict[str, float], Position]) -> Position:
|
||||
if isinstance(v, dict):
|
||||
return Position(**v)
|
||||
return v
|
||||
|
||||
|
||||
class WorkflowNode(WorkflowNodeCreate):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowCreate(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
name: str = "New Workflow"
|
||||
description: Optional[str] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
class Workflow(WorkflowCreate):
|
||||
id: Optional[str] = None
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class WorkflowGraph(BaseModel):
|
||||
workflow: Workflow
|
||||
nodes: List[WorkflowNode] = Field(default_factory=list)
|
||||
edges: List[WorkflowEdge] = Field(default_factory=list)
|
||||
|
||||
def get_node_by_id(self, node_id: str) -> Optional[WorkflowNode]:
|
||||
for node in self.nodes:
|
||||
if node.id == node_id:
|
||||
return node
|
||||
return None
|
||||
|
||||
def get_start_node(self) -> Optional[WorkflowNode]:
|
||||
for node in self.nodes:
|
||||
if node.type == NodeType.START:
|
||||
return node
|
||||
return None
|
||||
|
||||
def get_outgoing_edges(self, node_id: str) -> List[WorkflowEdge]:
|
||||
return [edge for edge in self.edges if edge.source_id == node_id]
|
||||
|
||||
|
||||
class NodeExecutionLog(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
node_id: str
|
||||
node_type: str
|
||||
status: ExecutionStatus
|
||||
started_at: datetime
|
||||
completed_at: Optional[datetime] = None
|
||||
error: Optional[str] = None
|
||||
state_snapshot: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class WorkflowRunCreate(BaseModel):
|
||||
workflow_id: str
|
||||
inputs: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class WorkflowRun(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
id: Optional[str] = None
|
||||
workflow_id: str
|
||||
user: Optional[str] = None
|
||||
status: ExecutionStatus = ExecutionStatus.PENDING
|
||||
inputs: Dict[str, str] = Field(default_factory=dict)
|
||||
outputs: Dict[str, Any] = Field(default_factory=dict)
|
||||
steps: List[NodeExecutionLog] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
completed_at: Optional[datetime] = None
|
||||
application/agents/workflows/workflow_engine.py (new file, 508 lines)
@@ -0,0 +1,508 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
|
||||
|
||||
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
|
||||
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
|
||||
from application.agents.workflows.schemas import (
|
||||
AgentNodeConfig,
|
||||
AgentType,
|
||||
ConditionNodeConfig,
|
||||
ExecutionStatus,
|
||||
NodeExecutionLog,
|
||||
NodeType,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
)
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.error import sanitize_api_error
|
||||
from application.templates.namespaces import NamespaceManager
|
||||
from application.templates.template_engine import TemplateEngine, TemplateRenderError
|
||||
|
||||
try:
|
||||
import jsonschema
|
||||
except ImportError: # pragma: no cover - optional dependency in some deployments.
|
||||
jsonschema = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from application.agents.base import BaseAgent
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
StateValue = Any
|
||||
WorkflowState = Dict[str, StateValue]
|
||||
TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"}
|
||||
|
||||
|
||||
class WorkflowEngine:
|
||||
MAX_EXECUTION_STEPS = 50
|
||||
|
||||
def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"):
|
||||
self.graph = graph
|
||||
self.agent = agent
|
||||
self.state: WorkflowState = {}
|
||||
self.execution_log: List[Dict[str, Any]] = []
|
||||
self._condition_result: Optional[str] = None
|
||||
self._template_engine = TemplateEngine()
|
||||
self._namespace_manager = NamespaceManager()
|
||||
|
||||
def execute(
|
||||
self, initial_inputs: WorkflowState, query: str
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
self._initialize_state(initial_inputs, query)
|
||||
|
||||
start_node = self.graph.get_start_node()
|
||||
if not start_node:
|
||||
yield {"type": "error", "error": "No start node found in workflow."}
|
||||
return
|
||||
current_node_id: Optional[str] = start_node.id
|
||||
steps = 0
|
||||
|
||||
while current_node_id and steps < self.MAX_EXECUTION_STEPS:
|
||||
node = self.graph.get_node_by_id(current_node_id)
|
||||
if not node:
|
||||
yield {"type": "error", "error": f"Node {current_node_id} not found."}
|
||||
break
|
||||
log_entry = self._create_log_entry(node)
|
||||
|
||||
yield {
|
||||
"type": "workflow_step",
|
||||
"node_id": node.id,
|
||||
"node_type": node.type.value,
|
||||
"node_title": node.title,
|
||||
"status": "running",
|
||||
}
|
||||
|
||||
try:
|
||||
yield from self._execute_node(node)
|
||||
log_entry["status"] = ExecutionStatus.COMPLETED.value
|
||||
log_entry["completed_at"] = datetime.now(timezone.utc)
|
||||
|
||||
output_key = f"node_{node.id}_output"
|
||||
node_output = self.state.get(output_key)
|
||||
|
||||
yield {
|
||||
"type": "workflow_step",
|
||||
"node_id": node.id,
|
||||
"node_type": node.type.value,
|
||||
"node_title": node.title,
|
||||
"status": "completed",
|
||||
"state_snapshot": dict(self.state),
|
||||
"output": node_output,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing node {node.id}: {e}", exc_info=True)
|
||||
log_entry["status"] = ExecutionStatus.FAILED.value
|
||||
log_entry["error"] = str(e)
|
||||
log_entry["completed_at"] = datetime.now(timezone.utc)
|
||||
log_entry["state_snapshot"] = dict(self.state)
|
||||
self.execution_log.append(log_entry)
|
||||
|
||||
user_friendly_error = sanitize_api_error(e)
|
||||
yield {
|
||||
"type": "workflow_step",
|
||||
"node_id": node.id,
|
||||
"node_type": node.type.value,
|
||||
"node_title": node.title,
|
||||
"status": "failed",
|
||||
"state_snapshot": dict(self.state),
|
||||
"error": user_friendly_error,
|
||||
}
|
||||
yield {"type": "error", "error": user_friendly_error}
|
||||
break
|
||||
log_entry["state_snapshot"] = dict(self.state)
|
||||
self.execution_log.append(log_entry)
|
||||
|
||||
if node.type == NodeType.END:
|
||||
break
|
||||
current_node_id = self._get_next_node_id(current_node_id)
|
||||
if current_node_id is None and node.type != NodeType.END:
|
||||
logger.warning(
|
||||
f"Branch ended at node '{node.title}' ({node.id}) without reaching an end node"
|
||||
)
|
||||
steps += 1
|
||||
if steps >= self.MAX_EXECUTION_STEPS:
|
||||
logger.warning(
|
||||
f"Workflow reached max steps limit ({self.MAX_EXECUTION_STEPS})"
|
||||
)
|
||||
|
||||
def _initialize_state(self, initial_inputs: WorkflowState, query: str) -> None:
|
||||
self.state.update(initial_inputs)
|
||||
self.state["query"] = query
|
||||
self.state["chat_history"] = str(self.agent.chat_history)
|
||||
|
||||
def _create_log_entry(self, node: WorkflowNode) -> Dict[str, Any]:
|
||||
return {
|
||||
"node_id": node.id,
|
||||
"node_type": node.type.value,
|
||||
"started_at": datetime.now(timezone.utc),
|
||||
"completed_at": None,
|
||||
"status": ExecutionStatus.RUNNING.value,
|
||||
"error": None,
|
||||
"state_snapshot": {},
|
||||
}
|
||||
|
||||
def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
|
||||
node = self.graph.get_node_by_id(current_node_id)
|
||||
edges = self.graph.get_outgoing_edges(current_node_id)
|
||||
if not edges:
|
||||
return None
|
||||
|
||||
if node and node.type == NodeType.CONDITION and self._condition_result:
|
||||
target_handle = self._condition_result
|
||||
self._condition_result = None
|
||||
for edge in edges:
|
||||
if edge.source_handle == target_handle:
|
||||
return edge.target_id
|
||||
return None
|
||||
|
||||
return edges[0].target_id
|
||||
|
||||
def _execute_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
logger.info(f"Executing node {node.id} ({node.type.value})")
|
||||
|
||||
node_handlers = {
|
||||
NodeType.START: self._execute_start_node,
|
||||
NodeType.NOTE: self._execute_note_node,
|
||||
NodeType.AGENT: self._execute_agent_node,
|
||||
NodeType.STATE: self._execute_state_node,
|
||||
NodeType.CONDITION: self._execute_condition_node,
|
||||
NodeType.END: self._execute_end_node,
|
||||
}
|
||||
|
||||
handler = node_handlers.get(node.type)
|
||||
if handler:
|
||||
yield from handler(node)
|
||||
|
||||
def _execute_start_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
yield from ()
|
||||
|
||||
def _execute_note_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
yield from ()
|
||||
|
||||
def _execute_agent_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_model_capabilities,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
|
||||
node_config = AgentNodeConfig(**node.config.get("config", node.config))
|
||||
|
||||
if node_config.sources:
|
||||
self._retrieve_node_sources(node_config)
|
||||
|
||||
if node_config.prompt_template:
|
||||
formatted_prompt = self._format_template(node_config.prompt_template)
|
||||
else:
|
||||
formatted_prompt = self.state.get("query", "")
|
||||
node_json_schema = self._normalize_node_json_schema(
|
||||
node_config.json_schema, node.title
|
||||
)
|
||||
node_model_id = node_config.model_id or self.agent.model_id
|
||||
# Inherit BYOM scope from parent agent so owner-stored BYOM
|
||||
# resolves on shared workflows.
|
||||
node_user_id = getattr(self.agent, "model_user_id", None) or (
|
||||
self.agent.decoded_token.get("sub")
|
||||
if isinstance(self.agent.decoded_token, dict)
|
||||
else None
|
||||
)
|
||||
node_llm_name = (
|
||||
node_config.llm_name
|
||||
or get_provider_from_model_id(
|
||||
node_model_id or "", user_id=node_user_id
|
||||
)
|
||||
or self.agent.llm_name
|
||||
)
|
||||
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
|
||||
|
||||
if node_json_schema and node_model_id:
|
||||
model_capabilities = get_model_capabilities(
|
||||
node_model_id, user_id=node_user_id
|
||||
)
|
||||
if model_capabilities and not model_capabilities.get(
|
||||
"supports_structured_output", False
|
||||
):
|
||||
raise ValueError(
|
||||
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
|
||||
)
|
||||
|
||||
factory_kwargs = {
|
||||
"agent_type": node_config.agent_type,
|
||||
"endpoint": self.agent.endpoint,
|
||||
"llm_name": node_llm_name,
|
||||
"model_id": node_model_id,
|
||||
"model_user_id": getattr(self.agent, "model_user_id", None),
|
||||
"api_key": node_api_key,
|
||||
"tool_ids": node_config.tools,
|
||||
"prompt": node_config.system_prompt,
|
||||
"chat_history": self.agent.chat_history,
|
||||
"decoded_token": self.agent.decoded_token,
|
||||
"json_schema": node_json_schema,
|
||||
}
|
||||
|
||||
# Agentic/research agents need retriever_config for on-demand search
|
||||
if node_config.agent_type in (AgentType.AGENTIC, AgentType.RESEARCH):
|
||||
factory_kwargs["retriever_config"] = {
|
||||
"source": {"active_docs": node_config.sources} if node_config.sources else {},
|
||||
"retriever_name": node_config.retriever or "classic",
|
||||
"chunks": int(node_config.chunks) if node_config.chunks else 2,
|
||||
"model_id": node_model_id,
|
||||
"llm_name": node_llm_name,
|
||||
"api_key": node_api_key,
|
||||
"decoded_token": self.agent.decoded_token,
|
||||
}
|
||||
|
||||
node_agent = WorkflowNodeAgentFactory.create(**factory_kwargs)
|
||||
|
||||
full_response_parts: List[str] = []
|
||||
structured_response_parts: List[str] = []
|
||||
has_structured_response = False
|
||||
first_chunk = True
|
||||
for event in node_agent.gen(formatted_prompt):
|
||||
if "answer" in event:
|
||||
chunk = str(event["answer"])
|
||||
full_response_parts.append(chunk)
|
||||
if event.get("structured"):
|
||||
has_structured_response = True
|
||||
structured_response_parts.append(chunk)
|
||||
if node_config.stream_to_user:
|
||||
if first_chunk and hasattr(self, "_has_streamed"):
|
||||
yield {"answer": "\n\n"}
|
||||
first_chunk = False
|
||||
yield event
|
||||
|
||||
if node_config.stream_to_user:
|
||||
self._has_streamed = True
|
||||
|
||||
full_response = "".join(full_response_parts).strip()
|
||||
output_value: Any = full_response
|
||||
if has_structured_response:
|
||||
structured_response = "".join(structured_response_parts).strip()
|
||||
response_to_parse = structured_response or full_response
|
||||
parsed_success, parsed_structured = self._parse_structured_output(
|
||||
response_to_parse
|
||||
)
|
||||
output_value = parsed_structured if parsed_success else response_to_parse
|
||||
if node_json_schema:
|
||||
self._validate_structured_output(node_json_schema, output_value)
|
||||
elif node_json_schema:
|
||||
parsed_success, parsed_structured = self._parse_structured_output(
|
||||
full_response
|
||||
)
|
||||
if not parsed_success:
|
||||
raise ValueError(
|
||||
"Structured output was expected but response was not valid JSON"
|
||||
)
|
||||
output_value = parsed_structured
|
||||
self._validate_structured_output(node_json_schema, output_value)
|
||||
|
||||
default_output_key = f"node_{node.id}_output"
|
||||
self.state[default_output_key] = output_value
|
||||
|
||||
if node_config.output_variable:
|
||||
self.state[node_config.output_variable] = output_value
|
||||
|
||||
def _execute_state_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = node.config.get("config", node.config)
|
||||
for op in config.get("operations", []):
|
||||
expression = op.get("expression", "")
|
||||
target_variable = op.get("target_variable", "")
|
||||
if expression and target_variable:
|
||||
self.state[target_variable] = evaluate_cel(expression, self.state)
|
||||
yield from ()
|
||||
|
||||
def _execute_condition_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = ConditionNodeConfig(**node.config.get("config", node.config))
|
||||
matched_handle = None
|
||||
|
||||
for case in config.cases:
|
||||
if not case.expression.strip():
|
||||
continue
|
||||
try:
|
||||
if evaluate_cel(case.expression, self.state):
|
||||
matched_handle = case.source_handle
|
||||
break
|
||||
except CelEvaluationError:
|
||||
continue
|
||||
|
||||
self._condition_result = matched_handle or "else"
|
||||
yield from ()
|
||||
|
||||
def _execute_end_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = node.config.get("config", node.config)
|
||||
output_template = str(config.get("output_template", ""))
|
||||
if output_template:
|
||||
formatted_output = self._format_template(output_template)
|
||||
yield {"answer": formatted_output}
|
||||
|
||||
def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
|
||||
normalized_response = raw_response.strip()
|
||||
if not normalized_response:
|
||||
return False, None
|
||||
|
||||
try:
|
||||
return True, json.loads(normalized_response)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Workflow agent returned structured output that was not valid JSON"
|
||||
)
|
||||
return False, None
|
||||
|
||||
def _normalize_node_json_schema(
|
||||
self, schema: Optional[Dict[str, Any]], node_title: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if schema is None:
|
||||
return None
|
||||
try:
|
||||
return normalize_json_schema_payload(schema)
|
||||
except JsonSchemaValidationError as exc:
|
||||
raise ValueError(
|
||||
f'Invalid JSON schema for node "{node_title}": {exc}'
|
||||
) from exc
|
||||
|
||||
def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
|
||||
if jsonschema is None:
|
||||
logger.warning(
|
||||
"jsonschema package is not available, skipping structured output validation"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
normalized_schema = normalize_json_schema_payload(schema)
|
||||
except JsonSchemaValidationError as exc:
|
||||
raise ValueError(f"Invalid JSON schema: {exc}") from exc
|
||||
|
||||
try:
|
||||
jsonschema.validate(instance=output_value, schema=normalized_schema)
|
||||
except jsonschema.exceptions.ValidationError as exc:
|
||||
raise ValueError(f"Structured output did not match schema: {exc.message}") from exc
|
||||
except jsonschema.exceptions.SchemaError as exc:
|
||||
raise ValueError(f"Invalid JSON schema: {exc.message}") from exc
|
||||
|
||||
def _format_template(self, template: str) -> str:
|
||||
context = self._build_template_context()
|
||||
try:
|
||||
return self._template_engine.render(template, context)
|
||||
except TemplateRenderError as e:
|
||||
logger.warning(
|
||||
"Workflow template rendering failed, using raw template: %s", str(e)
|
||||
)
|
||||
return template
|
||||
|
||||
def _build_template_context(self) -> Dict[str, Any]:
|
||||
docs, docs_together = self._get_source_template_data()
|
||||
passthrough_data = (
|
||||
self.state.get("passthrough")
|
||||
if isinstance(self.state.get("passthrough"), dict)
|
||||
else None
|
||||
)
|
||||
tools_data = (
|
||||
self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None
|
||||
)
|
||||
|
||||
context = self._namespace_manager.build_context(
|
||||
user_id=getattr(self.agent, "user", None),
|
||||
request_id=getattr(self.agent, "request_id", None),
|
||||
passthrough_data=passthrough_data,
|
||||
docs=docs,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
agent_context: Dict[str, Any] = {}
|
||||
for key, value in self.state.items():
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
normalized_key = key.strip()
|
||||
if not normalized_key:
|
||||
continue
|
||||
agent_context[normalized_key] = value
|
||||
|
||||
context["agent"] = agent_context
|
||||
|
||||
# Keep legacy top-level variables working while namespaced variables are adopted.
|
||||
for key, value in agent_context.items():
|
||||
if key in TEMPLATE_RESERVED_NAMESPACES:
|
||||
context[f"agent_{key}"] = value
|
||||
continue
|
||||
if key not in context:
|
||||
context[key] = value
|
||||
|
||||
return context
|
||||
|
||||
def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
|
||||
docs = getattr(self.agent, "retrieved_docs", None)
|
||||
if not isinstance(docs, list) or len(docs) == 0:
|
||||
return None, None
|
||||
|
||||
docs_together_parts: List[str] = []
|
||||
for doc in docs:
|
||||
if not isinstance(doc, dict):
|
||||
continue
|
||||
text = doc.get("text")
|
||||
if not isinstance(text, str):
|
||||
continue
|
||||
|
||||
filename = doc.get("filename") or doc.get("title") or doc.get("source")
|
||||
if isinstance(filename, str) and filename.strip():
|
||||
docs_together_parts.append(f"{filename}\n{text}")
|
||||
else:
|
||||
docs_together_parts.append(text)
|
||||
|
||||
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
|
||||
return docs, docs_together
|
||||
|
||||
def _retrieve_node_sources(self, node_config: AgentNodeConfig) -> None:
|
||||
"""Retrieve documents from the node's sources for template resolution."""
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
|
||||
query = self.state.get("query", "")
|
||||
if not query:
|
||||
return
|
||||
|
||||
try:
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
node_config.retriever or "classic",
|
||||
source={"active_docs": node_config.sources},
|
||||
chat_history=[],
|
||||
prompt="",
|
||||
chunks=int(node_config.chunks) if node_config.chunks else 2,
|
||||
decoded_token=self.agent.decoded_token,
|
||||
)
|
||||
docs = retriever.search(query)
|
||||
if docs:
|
||||
self.agent.retrieved_docs = docs
|
||||
except Exception:
|
||||
logger.exception("Failed to retrieve docs for workflow node")
|
||||
|
||||
def get_execution_summary(self) -> List[NodeExecutionLog]:
|
||||
return [
|
||||
NodeExecutionLog(
|
||||
node_id=log["node_id"],
|
||||
node_type=log["node_type"],
|
||||
status=ExecutionStatus(log["status"]),
|
||||
started_at=log["started_at"],
|
||||
completed_at=log.get("completed_at"),
|
||||
error=log.get("error"),
|
||||
state_snapshot=log.get("state_snapshot", {}),
|
||||
)
|
||||
for log in self.execution_log
|
||||
]
|
||||
application/alembic.ini (new file, 52 lines)
@@ -0,0 +1,52 @@
# Alembic configuration for the DocsGPT user-data Postgres database.
#
# The SQLAlchemy URL is deliberately NOT set here — env.py reads it from
# ``application.core.settings.settings.POSTGRES_URI`` so the same config
# source serves the running app and migrations. To run from the project
# root::
#
#     alembic -c application/alembic.ini upgrade head

[alembic]
script_location = %(here)s/alembic
prepend_sys_path = ..
version_path_separator = os

# sqlalchemy.url is intentionally left blank — env.py supplies it.
sqlalchemy.url =

[post_write_hooks]

[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARNING
handlers = console
qualname =

[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
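
The comments above document the canonical CLI invocation ("alembic -c application/alembic.ini upgrade head"). As a hedged convenience sketch, the same upgrade can be driven programmatically, assuming POSTGRES_URI is set in the environment that env.py reads:

from alembic import command
from alembic.config import Config

cfg = Config("application/alembic.ini")  # picks up script_location from the ini
command.upgrade(cfg, "head")             # env.py injects the URL from settings
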
application/alembic/env.py (new file, 82 lines)
@@ -0,0 +1,82 @@
|
||||
"""Alembic environment for the DocsGPT user-data Postgres database.
|
||||
|
||||
The URL is pulled from ``application.core.settings`` rather than
|
||||
``alembic.ini`` so that a single ``POSTGRES_URI`` env var drives both the
|
||||
running app and ``alembic`` CLI invocations.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
from pathlib import Path
|
||||
|
||||
# Make the project root importable regardless of cwd. env.py lives at
|
||||
# <repo>/application/alembic/env.py, so parents[2] is the repo root.
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(_PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_PROJECT_ROOT))
|
||||
|
||||
from alembic import context # noqa: E402
|
||||
from sqlalchemy import engine_from_config, pool # noqa: E402
|
||||
|
||||
from application.core.settings import settings # noqa: E402
|
||||
from application.storage.db.models import metadata as target_metadata # noqa: E402
|
||||
|
||||
config = context.config
|
||||
|
||||
# Populate the runtime URL from settings.
|
||||
if settings.POSTGRES_URI:
|
||||
config.set_main_option("sqlalchemy.url", settings.POSTGRES_URI)
|
||||
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode (emits SQL without a live DB)."""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
if not url:
|
||||
raise RuntimeError(
|
||||
"POSTGRES_URI is not configured. Set it in your .env to a "
|
||||
"psycopg3 URI such as "
|
||||
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||
)
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
compare_type=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode against a live connection."""
|
||||
if not config.get_main_option("sqlalchemy.url"):
|
||||
raise RuntimeError(
|
||||
"POSTGRES_URI is not configured. Set it in your .env to a "
|
||||
"psycopg3 URI such as "
|
||||
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||
)
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
future=True,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
application/alembic/script.py.mako (new file, 26 lines)
@@ -0,0 +1,26 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
application/alembic/versions/0001_initial.py (new file, 927 lines)
@@ -0,0 +1,927 @@
|
||||
"""0001 initial schema — consolidated Phase-1..3 baseline.
|
||||
|
||||
Revision ID: 0001_initial
|
||||
Revises:
|
||||
Create Date: 2026-04-13
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0001_initial"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
# Extensions
|
||||
# ------------------------------------------------------------------
|
||||
op.execute('CREATE EXTENSION IF NOT EXISTS "pgcrypto";')
|
||||
op.execute('CREATE EXTENSION IF NOT EXISTS "citext";')
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Trigger functions
|
||||
# ------------------------------------------------------------------
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION set_updated_at() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
NEW.updated_at = now();
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION ensure_user_exists() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
IF NEW.user_id IS NOT NULL THEN
|
||||
INSERT INTO users (user_id) VALUES (NEW.user_id)
|
||||
ON CONFLICT (user_id) DO NOTHING;
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION cleanup_message_attachment_refs() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
UPDATE conversation_messages
|
||||
SET attachments = array_remove(attachments, OLD.id)
|
||||
WHERE OLD.id = ANY(attachments);
|
||||
RETURN OLD;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION cleanup_agent_extra_source_refs() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
UPDATE agents
|
||||
SET extra_source_ids = array_remove(extra_source_ids, OLD.id)
|
||||
WHERE OLD.id = ANY(extra_source_ids);
|
||||
RETURN OLD;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION cleanup_user_agent_prefs() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
agent_id_text text := OLD.id::text;
|
||||
BEGIN
|
||||
UPDATE users
|
||||
SET agent_preferences = jsonb_set(
|
||||
jsonb_set(
|
||||
agent_preferences,
|
||||
'{pinned}',
|
||||
COALESCE((
|
||||
SELECT jsonb_agg(e)
|
||||
FROM jsonb_array_elements(
|
||||
COALESCE(agent_preferences->'pinned', '[]'::jsonb)
|
||||
) e
|
||||
WHERE (e #>> '{}') <> agent_id_text
|
||||
), '[]'::jsonb)
|
||||
),
|
||||
'{shared_with_me}',
|
||||
COALESCE((
|
||||
SELECT jsonb_agg(e)
|
||||
FROM jsonb_array_elements(
|
||||
COALESCE(agent_preferences->'shared_with_me', '[]'::jsonb)
|
||||
) e
|
||||
WHERE (e #>> '{}') <> agent_id_text
|
||||
), '[]'::jsonb)
|
||||
)
|
||||
WHERE agent_preferences->'pinned' @> to_jsonb(agent_id_text)
|
||||
OR agent_preferences->'shared_with_me' @> to_jsonb(agent_id_text);
|
||||
RETURN OLD;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION conversation_messages_fill_user_id() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
IF NEW.user_id IS NULL THEN
|
||||
SELECT user_id INTO NEW.user_id
|
||||
FROM conversations
|
||||
WHERE id = NEW.conversation_id;
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tables
|
||||
# ------------------------------------------------------------------
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE users (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL UNIQUE,
|
||||
agent_preferences JSONB NOT NULL
|
||||
DEFAULT '{"pinned": [], "shared_with_me": []}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE prompts (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE user_tools (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
custom_name TEXT,
|
||||
display_name TEXT,
|
||||
description TEXT,
|
||||
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
config_requirements JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
actions JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
status BOOLEAN NOT NULL DEFAULT true,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE token_usage (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id TEXT,
|
||||
api_key TEXT,
|
||||
agent_id UUID,
|
||||
prompt_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
generated_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE token_usage ADD CONSTRAINT token_usage_attribution_chk "
|
||||
"CHECK (user_id IS NOT NULL OR api_key IS NOT NULL) NOT VALID;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE user_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id TEXT,
|
||||
endpoint TEXT,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
data JSONB,
|
||||
mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE stack_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
activity_id TEXT NOT NULL,
|
||||
endpoint TEXT,
|
||||
level TEXT,
|
||||
user_id TEXT,
|
||||
api_key TEXT,
|
||||
query TEXT,
|
||||
stacks JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE agent_folders (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
parent_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE sources (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
language TEXT,
|
||||
date TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
model TEXT,
|
||||
type TEXT,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
retriever TEXT,
|
||||
sync_frequency TEXT,
|
||||
tokens TEXT,
|
||||
file_path TEXT,
|
||||
remote_data JSONB,
|
||||
directory_structure JSONB,
|
||||
file_name_map JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE agents (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
agent_type TEXT,
|
||||
status TEXT NOT NULL,
|
||||
key CITEXT UNIQUE,
|
||||
image TEXT,
|
||||
source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
|
||||
extra_source_ids UUID[] NOT NULL DEFAULT '{}',
|
||||
chunks INTEGER,
|
||||
retriever TEXT,
|
||||
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
|
||||
tools JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
json_schema JSONB,
|
||||
models JSONB,
|
||||
default_model_id TEXT,
|
||||
folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
|
||||
workflow_id UUID,
|
||||
limited_token_mode BOOLEAN NOT NULL DEFAULT false,
|
||||
token_limit INTEGER,
|
||||
limited_request_mode BOOLEAN NOT NULL DEFAULT false,
|
||||
request_limit INTEGER,
|
||||
allow_system_prompt_override BOOLEAN NOT NULL DEFAULT false,
|
||||
shared BOOLEAN NOT NULL DEFAULT false,
|
||||
shared_token CITEXT UNIQUE,
|
||||
shared_metadata JSONB,
|
||||
incoming_webhook_token CITEXT UNIQUE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
last_used_at TIMESTAMPTZ,
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE token_usage ADD CONSTRAINT token_usage_agent_fk "
|
||||
"FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE SET NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE attachments (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
filename TEXT NOT NULL,
|
||||
upload_path TEXT NOT NULL,
|
||||
mime_type TEXT,
|
||||
size BIGINT,
|
||||
content TEXT,
|
||||
token_count INTEGER,
|
||||
openai_file_id TEXT,
|
||||
google_file_uri TEXT,
|
||||
metadata JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE memories (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
path TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE todos (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
todo_id INTEGER,
|
||||
title TEXT NOT NULL,
|
||||
completed BOOLEAN NOT NULL DEFAULT false,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE notes (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
title TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE connector_sessions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
server_url TEXT,
|
||||
session_token TEXT UNIQUE,
|
||||
user_email TEXT,
|
||||
status TEXT,
|
||||
token_info JSONB,
|
||||
session_data JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
expires_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE conversations (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
agent_id UUID REFERENCES agents(id) ON DELETE SET NULL,
|
||||
name TEXT,
|
||||
api_key TEXT,
|
||||
is_shared_usage BOOLEAN NOT NULL DEFAULT false,
|
||||
shared_token TEXT,
|
||||
date TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
shared_with TEXT[] NOT NULL DEFAULT '{}'::text[],
|
||||
compression_metadata JSONB,
|
||||
legacy_mongo_id TEXT,
|
||||
CONSTRAINT conversations_api_key_nonempty_chk
|
||||
CHECK (api_key IS NULL OR api_key <> '')
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE conversation_messages (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
position INTEGER NOT NULL,
|
||||
prompt TEXT,
|
||||
response TEXT,
|
||||
thought TEXT,
|
||||
sources JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
tool_calls JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
attachments UUID[] NOT NULL DEFAULT '{}'::uuid[],
|
||||
model_id TEXT,
|
||||
message_metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
feedback JSONB,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
user_id TEXT NOT NULL,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE shared_conversations (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
user_id TEXT NOT NULL,
|
||||
is_promptable BOOLEAN NOT NULL DEFAULT false,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
uuid UUID NOT NULL,
|
||||
first_n_queries INTEGER NOT NULL DEFAULT 0,
|
||||
api_key TEXT,
|
||||
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
|
||||
chunks INTEGER,
|
||||
CONSTRAINT shared_conversations_api_key_nonempty_chk
|
||||
CHECK (api_key IS NULL OR api_key <> '')
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE pending_tool_state (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
user_id TEXT NOT NULL,
|
||||
messages JSONB NOT NULL,
|
||||
pending_tool_calls JSONB NOT NULL,
|
||||
tools_dict JSONB NOT NULL,
|
||||
tool_schemas JSONB NOT NULL,
|
||||
agent_config JSONB NOT NULL,
|
||||
client_tools JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
expires_at TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE workflows (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
current_graph_version INTEGER NOT NULL DEFAULT 1,
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
    # Backfill the agents.workflow_id FK now that workflows exists. The
    # column was created without a FK (a forward reference to a table that
    # hadn't been declared yet); adding the constraint here ensures that
    # deleting a workflow still clears agents.workflow_id via ON DELETE SET NULL.
    op.execute(
        "ALTER TABLE agents ADD CONSTRAINT agents_workflow_fk "
        "FOREIGN KEY (workflow_id) REFERENCES workflows(id) ON DELETE SET NULL;"
    )
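    # With ON DELETE SET NULL, removing a workflow detaches its agents
    # instead of deleting them or blocking the delete. Illustrative only:
    #
    #     DELETE FROM workflows WHERE id = :workflow_id;
    #     -- agents rows that referenced it now have workflow_id = NULL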
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE workflow_nodes (
|
||||
id UUID DEFAULT gen_random_uuid() NOT NULL,
|
||||
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||
graph_version INTEGER NOT NULL,
|
||||
node_type TEXT NOT NULL,
|
||||
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
node_id TEXT NOT NULL,
|
||||
title TEXT,
|
||||
description TEXT,
|
||||
position JSONB NOT NULL DEFAULT '{"x": 0, "y": 0}'::jsonb,
|
||||
legacy_mongo_id TEXT,
|
||||
PRIMARY KEY (id),
|
||||
CONSTRAINT workflow_nodes_id_wf_ver_key
|
||||
UNIQUE (id, workflow_id, graph_version)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE workflow_edges (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||
graph_version INTEGER NOT NULL,
|
||||
from_node_id UUID NOT NULL,
|
||||
to_node_id UUID NOT NULL,
|
||||
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
edge_id TEXT NOT NULL,
|
||||
source_handle TEXT,
|
||||
target_handle TEXT,
|
||||
CONSTRAINT workflow_edges_from_node_fk
|
||||
FOREIGN KEY (from_node_id, workflow_id, graph_version)
|
||||
REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE,
|
||||
CONSTRAINT workflow_edges_to_node_fk
|
||||
FOREIGN KEY (to_node_id, workflow_id, graph_version)
|
||||
REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE workflow_runs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||
user_id TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
started_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
ended_at TIMESTAMPTZ,
|
||||
result JSONB,
|
||||
inputs JSONB,
|
||||
steps JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
legacy_mongo_id TEXT,
|
||||
CONSTRAINT workflow_runs_status_chk
|
||||
CHECK (status IN ('pending', 'running', 'completed', 'failed'))
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Indexes
|
||||
# ------------------------------------------------------------------
|
||||
op.execute("CREATE INDEX agent_folders_user_idx ON agent_folders (user_id);")
|
||||
|
||||
op.execute("CREATE INDEX agents_user_idx ON agents (user_id);")
|
||||
op.execute("CREATE INDEX agents_shared_idx ON agents (shared) WHERE shared = true;")
|
||||
op.execute("CREATE INDEX agents_status_idx ON agents (status);")
|
||||
op.execute("CREATE INDEX agents_source_id_idx ON agents (source_id);")
|
||||
op.execute("CREATE INDEX agents_prompt_id_idx ON agents (prompt_id);")
|
||||
op.execute("CREATE INDEX agents_folder_id_idx ON agents (folder_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX agents_legacy_mongo_id_uidx "
|
||||
"ON agents (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX attachments_user_idx ON attachments (user_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX attachments_legacy_mongo_id_uidx "
|
||||
"ON attachments (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
    op.execute(
        # MCP and OAuth connectors share the ``provider`` slot, so the
        # dedup key is ``(user_id, server_url, provider)``: MCP rows
        # differentiate by server_url (one per MCP server), OAuth rows
        # have server_url = NULL and differentiate by provider alone.
        # COALESCE lets NULL server_url participate in the constraint.
        "CREATE UNIQUE INDEX connector_sessions_user_endpoint_uidx "
        "ON connector_sessions (user_id, COALESCE(server_url, ''), provider);"
    )
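    # Why the COALESCE: in a plain unique index NULLs never compare equal,
    # so two OAuth rows with server_url = NULL and the same provider would
    # both be accepted. The expression index makes the second insert fail.
    # Illustrative sketch (the provider value is made up):
    #
    #     INSERT INTO connector_sessions (user_id, provider)
    #     VALUES ('alice', 'google_drive');   -- ok
    #     INSERT INTO connector_sessions (user_id, provider)
    #     VALUES ('alice', 'google_drive');   -- rejected by the unique index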
|
||||
op.execute(
|
||||
"CREATE INDEX connector_sessions_expiry_idx "
|
||||
"ON connector_sessions (expires_at) WHERE expires_at IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX connector_sessions_server_url_idx "
|
||||
"ON connector_sessions (server_url) WHERE server_url IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX connector_sessions_legacy_mongo_id_uidx "
|
||||
"ON connector_sessions (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX conversation_messages_conv_pos_uidx "
|
||||
"ON conversation_messages (conversation_id, position);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX conversation_messages_user_ts_idx "
|
||||
"ON conversation_messages (user_id, timestamp DESC);"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX conversations_user_date_idx ON conversations (user_id, date DESC);")
|
||||
op.execute("CREATE INDEX conversations_agent_idx ON conversations (agent_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX conversations_shared_token_uidx "
|
||||
"ON conversations (shared_token) WHERE shared_token IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX conversations_api_key_date_idx "
|
||||
"ON conversations (api_key, date DESC) WHERE api_key IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX conversations_legacy_mongo_id_uidx "
|
||||
"ON conversations (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX memories_user_tool_path_uidx "
|
||||
"ON memories (user_id, tool_id, path);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX memories_user_path_null_tool_uidx "
|
||||
"ON memories (user_id, path) WHERE tool_id IS NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX memories_path_prefix_idx "
|
||||
"ON memories (user_id, tool_id, path text_pattern_ops);"
|
||||
)
|
||||
op.execute("CREATE INDEX memories_tool_id_idx ON memories (tool_id);")
|
||||
|
||||
op.execute("CREATE UNIQUE INDEX notes_user_tool_uidx ON notes (user_id, tool_id);")
|
||||
op.execute("CREATE INDEX notes_tool_id_idx ON notes (tool_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX notes_legacy_mongo_id_uidx "
|
||||
"ON notes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX pending_tool_state_conv_user_uidx "
|
||||
"ON pending_tool_state (conversation_id, user_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX pending_tool_state_expires_idx ON pending_tool_state (expires_at);"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX prompts_user_id_idx ON prompts (user_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX prompts_legacy_mongo_id_uidx "
|
||||
"ON prompts (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX shared_conversations_user_idx ON shared_conversations (user_id);")
|
||||
op.execute("CREATE INDEX shared_conversations_conv_idx ON shared_conversations (conversation_id);")
|
||||
op.execute(
|
||||
"CREATE INDEX shared_conversations_prompt_id_idx ON shared_conversations (prompt_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX shared_conversations_uuid_uidx ON shared_conversations (uuid);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX shared_conversations_dedup_uidx "
|
||||
"ON shared_conversations (conversation_id, user_id, is_promptable, first_n_queries, COALESCE(api_key, ''));"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX sources_user_idx ON sources (user_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX sources_legacy_mongo_id_uidx "
|
||||
"ON sources (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX user_tools_legacy_mongo_id_uidx "
|
||||
"ON user_tools (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX agent_folders_legacy_mongo_id_uidx "
|
||||
"ON agent_folders (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
op.execute("CREATE INDEX agent_folders_parent_idx ON agent_folders (parent_id);")
|
||||
op.execute("CREATE INDEX agents_workflow_idx ON agents (workflow_id);")
|
||||
|
||||
op.execute('CREATE INDEX stack_logs_timestamp_idx ON stack_logs ("timestamp" DESC);')
|
||||
op.execute('CREATE INDEX stack_logs_user_ts_idx ON stack_logs (user_id, "timestamp" DESC);')
|
||||
op.execute('CREATE INDEX stack_logs_level_ts_idx ON stack_logs (level, "timestamp" DESC);')
|
||||
op.execute("CREATE INDEX stack_logs_activity_idx ON stack_logs (activity_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX stack_logs_mongo_id_uidx "
|
||||
"ON stack_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX todos_user_tool_idx ON todos (user_id, tool_id);")
|
||||
op.execute("CREATE INDEX todos_tool_id_idx ON todos (tool_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX todos_legacy_mongo_id_uidx "
|
||||
"ON todos (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX todos_tool_todo_id_uidx "
|
||||
"ON todos (tool_id, todo_id) WHERE todo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute('CREATE INDEX token_usage_user_ts_idx ON token_usage (user_id, "timestamp" DESC);')
|
||||
op.execute('CREATE INDEX token_usage_key_ts_idx ON token_usage (api_key, "timestamp" DESC);')
|
||||
op.execute('CREATE INDEX token_usage_agent_ts_idx ON token_usage (agent_id, "timestamp" DESC);')
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX token_usage_mongo_id_uidx "
|
||||
"ON token_usage (mongo_id) WHERE mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute('CREATE INDEX user_logs_user_ts_idx ON user_logs (user_id, "timestamp" DESC);')
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX user_logs_mongo_id_uidx "
|
||||
"ON user_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX user_tools_user_id_idx ON user_tools (user_id);")
|
||||
|
||||
op.execute("CREATE INDEX workflow_edges_from_node_idx ON workflow_edges (from_node_id);")
|
||||
op.execute("CREATE INDEX workflow_edges_to_node_idx ON workflow_edges (to_node_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX workflow_edges_wf_ver_eid_uidx "
|
||||
"ON workflow_edges (workflow_id, graph_version, edge_id);"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX workflow_nodes_wf_ver_nid_uidx "
|
||||
"ON workflow_nodes (workflow_id, graph_version, node_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX workflow_nodes_legacy_mongo_id_uidx "
|
||||
"ON workflow_nodes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX workflow_runs_workflow_idx ON workflow_runs (workflow_id);")
|
||||
op.execute("CREATE INDEX workflow_runs_user_idx ON workflow_runs (user_id);")
|
||||
op.execute(
|
||||
"CREATE INDEX workflow_runs_status_started_idx "
|
||||
"ON workflow_runs (status, started_at DESC);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX workflow_runs_legacy_mongo_id_uidx "
|
||||
"ON workflow_runs (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX workflows_user_idx ON workflows (user_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX workflows_legacy_mongo_id_uidx "
|
||||
"ON workflows (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# user_id foreign keys (deferrable so backfills can stage rows)
|
||||
# ------------------------------------------------------------------
|
||||
user_fk_tables = (
|
||||
"agent_folders",
|
||||
"agents",
|
||||
"attachments",
|
||||
"connector_sessions",
|
||||
"conversation_messages",
|
||||
"conversations",
|
||||
"memories",
|
||||
"notes",
|
||||
"pending_tool_state",
|
||||
"prompts",
|
||||
"shared_conversations",
|
||||
"sources",
|
||||
"stack_logs",
|
||||
"todos",
|
||||
"token_usage",
|
||||
"user_logs",
|
||||
"user_tools",
|
||||
"workflow_runs",
|
||||
"workflows",
|
||||
)
|
||||
for table in user_fk_tables:
|
||||
op.execute(
|
||||
f"ALTER TABLE {table} "
|
||||
f"ADD CONSTRAINT {table}_user_id_fk "
|
||||
f"FOREIGN KEY (user_id) REFERENCES users(user_id) "
|
||||
f"ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
|
||||
)
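    # DEFERRABLE INITIALLY IMMEDIATE keeps the FK checks immediate in normal
    # operation, but lets a backfill postpone them inside one transaction so
    # child rows can be staged before their parent users row. Illustrative
    # sketch only:
    #
    #     BEGIN;
    #     SET CONSTRAINTS ALL DEFERRED;
    #     -- stage child rows here, insert the parent later in the same
    #     -- transaction; the deferred FK checks run at COMMIT.
    #     COMMIT;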
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Triggers
|
||||
# ------------------------------------------------------------------
|
||||
updated_at_tables = (
|
||||
"agent_folders",
|
||||
"agents",
|
||||
"conversation_messages",
|
||||
"conversations",
|
||||
"memories",
|
||||
"notes",
|
||||
"prompts",
|
||||
"sources",
|
||||
"todos",
|
||||
"user_tools",
|
||||
"users",
|
||||
"workflows",
|
||||
)
|
||||
for table in updated_at_tables:
|
||||
op.execute(
|
||||
f"CREATE TRIGGER {table}_set_updated_at "
|
||||
f"BEFORE UPDATE ON {table} "
|
||||
f"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
|
||||
f"EXECUTE FUNCTION set_updated_at();"
|
||||
)
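    # set_updated_at() is created earlier in this migration; the WHEN clause
    # keeps no-op UPDATEs from bumping the timestamp. The usual shape of such
    # a trigger function, assumed here purely for orientation:
    #
    #     CREATE FUNCTION set_updated_at() RETURNS trigger AS $$
    #     BEGIN
    #         NEW.updated_at = now();
    #         RETURN NEW;
    #     END;
    #     $$ LANGUAGE plpgsql;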
|
||||
|
||||
ensure_user_tables = (
|
||||
"agent_folders",
|
||||
"agents",
|
||||
"attachments",
|
||||
"connector_sessions",
|
||||
"conversation_messages",
|
||||
"conversations",
|
||||
"memories",
|
||||
"notes",
|
||||
"pending_tool_state",
|
||||
"prompts",
|
||||
"shared_conversations",
|
||||
"sources",
|
||||
"stack_logs",
|
||||
"todos",
|
||||
"token_usage",
|
||||
"user_logs",
|
||||
"user_tools",
|
||||
"workflow_runs",
|
||||
"workflows",
|
||||
)
|
||||
for table in ensure_user_tables:
|
||||
op.execute(
|
||||
f"CREATE TRIGGER {table}_ensure_user "
|
||||
f"BEFORE INSERT OR UPDATE OF user_id ON {table} "
|
||||
f"FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE TRIGGER conversation_messages_fill_user "
|
||||
"BEFORE INSERT ON conversation_messages "
|
||||
"FOR EACH ROW EXECUTE FUNCTION conversation_messages_fill_user_id();"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE TRIGGER attachments_cleanup_message_refs "
|
||||
"AFTER DELETE ON attachments "
|
||||
"FOR EACH ROW EXECUTE FUNCTION cleanup_message_attachment_refs();"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE TRIGGER agents_cleanup_user_prefs "
|
||||
"AFTER DELETE ON agents "
|
||||
"FOR EACH ROW EXECUTE FUNCTION cleanup_user_agent_prefs();"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE TRIGGER sources_cleanup_agent_extra_refs "
|
||||
"AFTER DELETE ON sources "
|
||||
"FOR EACH ROW EXECUTE FUNCTION cleanup_agent_extra_source_refs();"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Seed sentinel __system__ user (system/template sources attribute here)
|
||||
# ------------------------------------------------------------------
|
||||
op.execute(
|
||||
"INSERT INTO users (user_id) VALUES ('__system__') "
|
||||
"ON CONFLICT (user_id) DO NOTHING;"
|
||||
)
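    # Rows owned by the instance rather than a person can point at the
    # sentinel, keeping the user_id FK satisfied. Illustrative only (the
    # name and type values below are made up):
    #
    #     INSERT INTO sources (user_id, name, type)
    #     VALUES ('__system__', 'default-docs', 'template');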
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Nuclear downgrade: drop everything this migration created. The
|
||||
# ordering drops FK-bearing children before parents; CASCADE would
|
||||
# also work but explicit ordering is easier to reason about in code
|
||||
# review.
|
||||
tables_in_drop_order = (
|
||||
"workflow_edges",
|
||||
"workflow_runs",
|
||||
"workflow_nodes",
|
||||
"workflows",
|
||||
"pending_tool_state",
|
||||
"shared_conversations",
|
||||
"conversation_messages",
|
||||
"conversations",
|
||||
"connector_sessions",
|
||||
"notes",
|
||||
"todos",
|
||||
"memories",
|
||||
"attachments",
|
||||
"agents",
|
||||
"sources",
|
||||
"agent_folders",
|
||||
"stack_logs",
|
||||
"user_logs",
|
||||
"token_usage",
|
||||
"user_tools",
|
||||
"prompts",
|
||||
"users",
|
||||
)
|
||||
for table in tables_in_drop_order:
|
||||
op.execute(f"DROP TABLE IF EXISTS {table} CASCADE;")
|
||||
|
||||
for fn in (
|
||||
"conversation_messages_fill_user_id",
|
||||
"cleanup_user_agent_prefs",
|
||||
"cleanup_agent_extra_source_refs",
|
||||
"cleanup_message_attachment_refs",
|
||||
"ensure_user_exists",
|
||||
"set_updated_at",
|
||||
):
|
||||
op.execute(f"DROP FUNCTION IF EXISTS {fn}();")
|
||||
37
application/alembic/versions/0002_app_metadata.py
Normal file
@@ -0,0 +1,37 @@
"""0002 app_metadata — singleton key/value table for instance-wide state.

Used by the startup version-check client to persist the anonymous
instance UUID and a one-shot "notice shown" flag. Both values are tiny
plain-text strings; this is a deliberate generic-config table rather
than dedicated columns so future one-off settings (telemetry opt-in
timestamps, feature-flag overrides, etc.) don't each need their own
migration.

Revision ID: 0002_app_metadata
Revises: 0001_initial
"""

from typing import Sequence, Union

from alembic import op


revision: str = "0002_app_metadata"
down_revision: Union[str, None] = "0001_initial"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.execute(
        """
        CREATE TABLE app_metadata (
            key TEXT PRIMARY KEY,
            value TEXT NOT NULL
        );
        """
    )


def downgrade() -> None:
    op.execute("DROP TABLE IF EXISTS app_metadata;")
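Since app_metadata is a plain two-column key/value store, the startup version-check client only needs an insert-if-absent plus a small update per setting. A minimal sketch of that access pattern (the key names are assumptions, not taken from the client code):

INSERT INTO app_metadata (key, value)
VALUES ('instance_uuid', gen_random_uuid()::text)
ON CONFLICT (key) DO NOTHING;

UPDATE app_metadata SET value = 'shown' WHERE key = 'version_notice_shown';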
65
application/alembic/versions/0003_user_custom_models.py
Normal file
@@ -0,0 +1,65 @@
"""0003 user_custom_models — per-user OpenAI-compatible model registrations.

Revision ID: 0003_user_custom_models
Revises: 0002_app_metadata
"""

from typing import Sequence, Union

from alembic import op


revision: str = "0003_user_custom_models"
down_revision: Union[str, None] = "0002_app_metadata"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.execute(
        """
        CREATE TABLE user_custom_models (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            upstream_model_id TEXT NOT NULL,
            display_name TEXT NOT NULL,
            description TEXT NOT NULL DEFAULT '',
            base_url TEXT NOT NULL,
            api_key_encrypted TEXT NOT NULL,
            capabilities JSONB NOT NULL DEFAULT '{}'::jsonb,
            enabled BOOLEAN NOT NULL DEFAULT true,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """
    )
    op.execute(
        "CREATE INDEX user_custom_models_user_id_idx "
        "ON user_custom_models (user_id);"
    )

    # Mirror the project-wide invariants set up in 0001_initial:
    # * user_id FK with ON DELETE RESTRICT (deferrable),
    # * ensure_user_exists() trigger so the parent users row autocreates,
    # * set_updated_at() trigger.
    op.execute(
        "ALTER TABLE user_custom_models "
        "ADD CONSTRAINT user_custom_models_user_id_fk "
        "FOREIGN KEY (user_id) REFERENCES users(user_id) "
        "ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
    )
    op.execute(
        "CREATE TRIGGER user_custom_models_ensure_user "
        "BEFORE INSERT OR UPDATE OF user_id ON user_custom_models "
        "FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
    )
    op.execute(
        "CREATE TRIGGER user_custom_models_set_updated_at "
        "BEFORE UPDATE ON user_custom_models "
        "FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
        "EXECUTE FUNCTION set_updated_at();"
    )


def downgrade() -> None:
    op.execute("DROP TABLE IF EXISTS user_custom_models;")
@@ -0,0 +1,7 @@
from flask_restx import Api

api = Api(
    version="1.0",
    title="DocsGPT API",
    description="API for DocsGPT",
)
@@ -0,0 +1,21 @@
from flask import Blueprint

from application.api import api
from application.api.answer.routes.answer import AnswerResource
from application.api.answer.routes.base import answer_ns
from application.api.answer.routes.search import SearchResource
from application.api.answer.routes.stream import StreamResource


answer = Blueprint("answer", __name__)

api.add_namespace(answer_ns)


def init_answer_routes():
    api.add_resource(StreamResource, "/stream")
    api.add_resource(AnswerResource, "/api/answer")
    api.add_resource(SearchResource, "/api/search")


init_answer_routes()
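For orientation, a minimal sketch of how this blueprint and the shared Api object could be wired into a Flask app. The app-factory shape and the module path of the blueprint package are assumptions inferred from the imports above, not the project's actual entry point:

from flask import Flask

from application.api import api
from application.api.answer import answer


def create_app() -> Flask:
    app = Flask(__name__)
    app.register_blueprint(answer)  # registers the blueprint itself
    api.init_app(app)  # flask_restx serves the resources added in init_answer_routes()
    return app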
@@ -1,916 +0,0 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import Blueprint, make_response, request, Response
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.error import bad_request
|
||||
from application.extensions import api
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import check_required_fields, limit_chat_history
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
conversations_collection = db["conversations"]
|
||||
sources_collection = db["sources"]
|
||||
prompts_collection = db["prompts"]
|
||||
agents_collection = db["agents"]
|
||||
user_logs_collection = db["user_logs"]
|
||||
attachments_collection = db["attachments"]
|
||||
|
||||
answer = Blueprint("answer", __name__)
|
||||
answer_ns = Namespace("answer", description="Answer related operations", path="/")
|
||||
api.add_namespace(answer_ns)
|
||||
|
||||
gpt_model = ""
|
||||
# to have some kind of default behaviour
|
||||
if settings.LLM_NAME == "openai":
|
||||
gpt_model = "gpt-4o-mini"
|
||||
elif settings.LLM_NAME == "anthropic":
|
||||
gpt_model = "claude-2"
|
||||
elif settings.LLM_NAME == "groq":
|
||||
gpt_model = "llama3-8b-8192"
|
||||
elif settings.LLM_NAME == "novita":
|
||||
gpt_model = "deepseek/deepseek-r1"
|
||||
|
||||
if settings.MODEL_NAME: # in case there is particular model name configured
|
||||
gpt_model = settings.MODEL_NAME
|
||||
|
||||
# load the prompts
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
|
||||
chat_combine_template = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f:
|
||||
chat_reduce_template = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
|
||||
chat_combine_creative = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
|
||||
chat_combine_strict = f.read()
|
||||
|
||||
api_key_set = settings.API_KEY is not None
|
||||
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
|
||||
|
||||
|
||||
async def async_generate(chain, question, chat_history):
|
||||
result = await chain.arun({"question": question, "chat_history": chat_history})
|
||||
return result
|
||||
|
||||
|
||||
def run_async_chain(chain, question, chat_history):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
result = {}
|
||||
try:
|
||||
answer = loop.run_until_complete(async_generate(chain, question, chat_history))
|
||||
finally:
|
||||
loop.close()
|
||||
result["answer"] = answer
|
||||
return result
|
||||
|
||||
|
||||
def get_agent_key(agent_id, user_id):
|
||||
if not agent_id:
|
||||
return None, False, None
|
||||
|
||||
try:
|
||||
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
|
||||
if agent is None:
|
||||
raise Exception("Agent not found", 404)
|
||||
|
||||
is_owner = agent.get("user") == user_id
|
||||
|
||||
if is_owner:
|
||||
agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id)},
|
||||
{"$set": {"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)}},
|
||||
)
|
||||
return str(agent["key"]), False, None
|
||||
|
||||
is_shared_with_user = agent.get(
|
||||
"shared_publicly", False
|
||||
) or user_id in agent.get("shared_with", [])
|
||||
|
||||
if is_shared_with_user:
|
||||
return str(agent["key"]), True, agent.get("shared_token")
|
||||
|
||||
raise Exception("Unauthorized access to the agent", 403)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def get_data_from_api_key(api_key):
|
||||
data = agents_collection.find_one({"key": api_key})
|
||||
if not data:
|
||||
raise Exception("Invalid API Key, please generate a new key", 401)
|
||||
|
||||
source = data.get("source")
|
||||
if isinstance(source, DBRef):
|
||||
source_doc = db.dereference(source)
|
||||
data["source"] = str(source_doc["_id"])
|
||||
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
||||
else:
|
||||
data["source"] = {}
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def get_retriever(source_id: str):
|
||||
doc = sources_collection.find_one({"_id": ObjectId(source_id)})
|
||||
if doc is None:
|
||||
raise Exception("Source document does not exist", 404)
|
||||
retriever_name = None if "retriever" not in doc else doc["retriever"]
|
||||
return retriever_name
|
||||
|
||||
|
||||
def is_azure_configured():
|
||||
return (
|
||||
settings.OPENAI_API_BASE
|
||||
and settings.OPENAI_API_VERSION
|
||||
and settings.AZURE_DEPLOYMENT_NAME
|
||||
)
|
||||
|
||||
|
||||
def save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
decoded_token,
|
||||
index=None,
|
||||
api_key=None,
|
||||
agent_id=None,
|
||||
is_shared_usage=False,
|
||||
shared_token=None,
|
||||
):
|
||||
current_time = datetime.datetime.now(datetime.timezone.utc)
|
||||
if conversation_id is not None and index is not None:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{index}.prompt": question,
|
||||
f"queries.{index}.response": response,
|
||||
f"queries.{index}.thought": thought,
|
||||
f"queries.{index}.sources": source_log_docs,
|
||||
f"queries.{index}.tool_calls": tool_calls,
|
||||
f"queries.{index}.timestamp": current_time,
|
||||
}
|
||||
},
|
||||
)
|
||||
        # remove any queries after the updated index from the array
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
|
||||
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
||||
)
|
||||
elif conversation_id is not None and conversation_id != "None":
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{
|
||||
"$push": {
|
||||
"queries": {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
else:
|
||||
# create new conversation
|
||||
# generate summary
|
||||
messages_summary = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Summarise following conversation in no more than 3 "
|
||||
"words, respond ONLY with the summary, use the same "
|
||||
"language as the system",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Summarise following conversation in no more than 3 words, "
|
||||
"respond ONLY with the summary, use the same language as the "
|
||||
"system \n\nUser: " + question + "\n\n" + "AI: " + response,
|
||||
},
|
||||
]
|
||||
|
||||
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
|
||||
conversation_data = {
|
||||
"user": decoded_token.get("sub"),
|
||||
"date": datetime.datetime.utcnow(),
|
||||
"name": completion,
|
||||
"queries": [
|
||||
{
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
}
|
||||
],
|
||||
}
|
||||
if api_key:
|
||||
if agent_id:
|
||||
conversation_data["agent_id"] = agent_id
|
||||
if is_shared_usage:
|
||||
conversation_data["is_shared_usage"] = is_shared_usage
|
||||
conversation_data["shared_token"] = shared_token
|
||||
api_key_doc = agents_collection.find_one({"key": api_key})
|
||||
if api_key_doc:
|
||||
conversation_data["api_key"] = api_key_doc["key"]
|
||||
conversation_id = conversations_collection.insert_one(
|
||||
conversation_data
|
||||
).inserted_id
|
||||
return conversation_id
|
||||
|
||||
|
||||
def get_prompt(prompt_id):
|
||||
if prompt_id == "default":
|
||||
prompt = chat_combine_template
|
||||
elif prompt_id == "creative":
|
||||
prompt = chat_combine_creative
|
||||
elif prompt_id == "strict":
|
||||
prompt = chat_combine_strict
|
||||
else:
|
||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
||||
return prompt
|
||||
|
||||
|
||||
def complete_stream(
|
||||
question,
|
||||
agent,
|
||||
retriever,
|
||||
conversation_id,
|
||||
user_api_key,
|
||||
decoded_token,
|
||||
isNoneDoc=False,
|
||||
index=None,
|
||||
should_save_conversation=True,
|
||||
attachments=None,
|
||||
agent_id=None,
|
||||
is_shared_usage=False,
|
||||
shared_token=None,
|
||||
):
|
||||
try:
|
||||
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
||||
attachment_ids = []
|
||||
|
||||
if attachments:
|
||||
attachment_ids = [attachment["id"] for attachment in attachments]
|
||||
logger.info(
|
||||
f"Processing request with {len(attachments)} attachments: {attachment_ids}"
|
||||
)
|
||||
|
||||
answer = agent.gen(query=question, retriever=retriever)
|
||||
|
||||
for line in answer:
|
||||
if "answer" in line:
|
||||
response_full += str(line["answer"])
|
||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "sources" in line:
|
||||
truncated_sources = []
|
||||
source_log_docs = line["sources"]
|
||||
for source in line["sources"]:
|
||||
truncated_source = source.copy()
|
||||
if "text" in truncated_source:
|
||||
truncated_source["text"] = (
|
||||
truncated_source["text"][:100].strip() + "..."
|
||||
)
|
||||
truncated_sources.append(truncated_source)
|
||||
if len(truncated_sources) > 0:
|
||||
data = json.dumps({"type": "source", "source": truncated_sources})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "tool_calls" in line:
|
||||
tool_calls = line["tool_calls"]
|
||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "thought" in line:
|
||||
thought += line["thought"]
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
if should_save_conversation:
|
||||
conversation_id = save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
decoded_token,
|
||||
index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
)
|
||||
else:
|
||||
conversation_id = None
|
||||
|
||||
        # send the conversation id event, then an "end" event to signal the stream is done
|
||||
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
retriever_params = retriever.get_params()
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
"action": "stream_answer",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"retriever_params": retriever_params,
|
||||
"attachments": attachment_ids,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
)
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
||||
data = json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Please try again later. We apologize for any inconvenience.",
|
||||
}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
|
||||
@answer_ns.route("/stream")
|
||||
class Stream(Resource):
|
||||
stream_model = api.model(
|
||||
"StreamModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="Question to be asked"
|
||||
),
|
||||
"history": fields.List(
|
||||
fields.String, required=False, description="Chat history"
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=False, description="Conversation ID"
|
||||
),
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
"index": fields.Integer(
|
||||
required=False, description="Index of the query to update"
|
||||
),
|
||||
"save_conversation": fields.Boolean(
|
||||
required=False,
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
"attachments": fields.List(
|
||||
fields.String, required=False, description="List of attachment IDs"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(stream_model)
|
||||
@api.doc(description="Stream a response based on the question and retriever")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
if "index" in data:
|
||||
required_fields = ["question", "conversation_id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
save_conv = data.get("save_conversation", True)
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
history = limit_chat_history(
|
||||
json.loads(data.get("history", "[]")), gpt_model=gpt_model
|
||||
)
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
attachment_ids = data.get("attachments", [])
|
||||
|
||||
index = data.get("index", None)
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
agent_id = data.get("agent_id", None)
|
||||
agent_type = settings.AGENT_NAME
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
user_sub = decoded_token.get("sub") if decoded_token else None
|
||||
agent_key, is_shared_usage, shared_token = get_agent_key(
|
||||
agent_id, user_sub
|
||||
)
|
||||
|
||||
if agent_key:
|
||||
data.update({"api_key": agent_key})
|
||||
else:
|
||||
agent_id = None
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
prompt_id = data_key.get("prompt_id", "default")
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
retriever_name = data_key.get("retriever", retriever_name)
|
||||
user_api_key = data["api_key"]
|
||||
agent_type = data_key.get("agent_type", agent_type)
|
||||
if is_shared_usage:
|
||||
decoded_token = request.decoded_token
|
||||
else:
|
||||
decoded_token = {"sub": data_key.get("user")}
|
||||
is_shared_usage = False
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
if not decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
attachments = get_attachments_content(
|
||||
attachment_ids, decoded_token.get("sub")
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"/stream - request_data: {data}, source: {source}, attachments: {len(attachments)}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
if "isNoneDoc" in data and data["isNoneDoc"] is True:
|
||||
chunks = 0
|
||||
|
||||
agent = AgentCreator.create_agent(
|
||||
agent_type,
|
||||
endpoint="stream",
|
||||
llm_name=settings.LLM_NAME,
|
||||
gpt_model=gpt_model,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
chat_history=history,
|
||||
decoded_token=decoded_token,
|
||||
attachments=attachments,
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
return Response(
|
||||
complete_stream(
|
||||
question=question,
|
||||
agent=agent,
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=index,
|
||||
should_save_conversation=save_conv,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
message = "Malformed request body"
|
||||
logger.error(f"/stream - error: {message}")
|
||||
return Response(
|
||||
error_stream_generate(message),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
status_code = 400
|
||||
return Response(
|
||||
error_stream_generate("Unknown error occurred"),
|
||||
status=status_code,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
def error_stream_generate(err_response):
|
||||
data = json.dumps({"type": "error", "error": err_response})
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
|
||||
@answer_ns.route("/api/answer")
|
||||
class Answer(Resource):
|
||||
answer_model = api.model(
|
||||
"AnswerModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="The question to answer"
|
||||
),
|
||||
"history": fields.List(
|
||||
fields.String, required=False, description="Conversation history"
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=False, description="Conversation ID"
|
||||
),
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(answer_model)
|
||||
@api.doc(description="Provide an answer based on the question and retriever")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
            history = limit_chat_history(
                json.loads(data.get("history", "[]")), gpt_model=gpt_model
            )
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
agent_type = settings.AGENT_NAME
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
prompt_id = data_key.get("prompt_id", "default")
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
retriever_name = data_key.get("retriever", retriever_name)
|
||||
user_api_key = data["api_key"]
|
||||
agent_type = data_key.get("agent_type", agent_type)
|
||||
decoded_token = {"sub": data_key.get("user")}
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
if not decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
logger.info(
|
||||
f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
agent = AgentCreator.create_agent(
|
||||
agent_type,
|
||||
endpoint="api/answer",
|
||||
llm_name=settings.LLM_NAME,
|
||||
gpt_model=gpt_model,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
chat_history=history,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
tool_calls = []
|
||||
stream_ended = False
|
||||
thought = ""
|
||||
|
||||
for line in complete_stream(
|
||||
question=question,
|
||||
agent=agent,
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=False,
|
||||
):
|
||||
try:
|
||||
event_data = line.replace("data: ", "").strip()
|
||||
event = json.loads(event_data)
|
||||
|
||||
if event["type"] == "answer":
|
||||
response_full += event["answer"]
|
||||
elif event["type"] == "source":
|
||||
source_log_docs = event["source"]
|
||||
elif event["type"] == "tool_calls":
|
||||
tool_calls = event["tool_calls"]
|
||||
elif event["type"] == "thought":
|
||||
thought = event["thought"]
|
||||
elif event["type"] == "error":
|
||||
logger.error(f"Error from stream: {event['error']}")
|
||||
return bad_request(500, event["error"])
|
||||
elif event["type"] == "end":
|
||||
stream_ended = True
|
||||
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"Error parsing stream event: {e}, line: {line}")
|
||||
continue
|
||||
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return bad_request(500, "Stream ended unexpectedly.")
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
result = {"answer": response_full, "sources": source_log_docs}
|
||||
result["conversation_id"] = str(
|
||||
save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
decoded_token,
|
||||
api_key=user_api_key,
|
||||
)
|
||||
)
|
||||
|
||||
retriever_params = retriever.get_params()
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
"action": "api_answer",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"retriever_params": retriever_params,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return bad_request(500, str(e))
|
||||
|
||||
return make_response(result, 200)
|
||||
|
||||
|
||||
@answer_ns.route("/api/search")
|
||||
class Search(Resource):
|
||||
search_model = api.model(
|
||||
"SearchModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="The question to search"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"api_key": fields.String(
|
||||
required=False, description="API key for authentication"
|
||||
),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents for retrieval"
|
||||
),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"token_limit": fields.Integer(
|
||||
required=False, description="Limit for tokens"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(search_model)
|
||||
@api.doc(
|
||||
description="Search for relevant documents based on the question and retriever"
|
||||
)
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
user_api_key = data["api_key"]
|
||||
decoded_token = {"sub": data_key.get("user")}
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
if not decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
logger.info(
|
||||
f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
source=source,
|
||||
chat_history=[],
|
||||
prompt="default",
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
docs = retriever.search(question)
|
||||
retriever_params = retriever.get_params()
|
||||
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
"action": "api_search",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"sources": docs,
|
||||
"retriever_params": retriever_params,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
)
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return bad_request(500, str(e))
|
||||
|
||||
return make_response(docs, 200)
|
||||
|
||||
|
||||
def get_attachments_content(attachment_ids, user):
|
||||
"""
|
||||
Retrieve content from attachment documents based on their IDs.
|
||||
|
||||
Args:
|
||||
attachment_ids (list): List of attachment document IDs
|
||||
user (str): User identifier to verify ownership
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing attachment content and metadata
|
||||
"""
|
||||
if not attachment_ids:
|
||||
return []
|
||||
|
||||
attachments = []
|
||||
for attachment_id in attachment_ids:
|
||||
try:
|
||||
attachment_doc = attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id), "user": user}
|
||||
)
|
||||
|
||||
if attachment_doc:
|
||||
attachments.append(attachment_doc)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
|
||||
)
|
||||
|
||||
return attachments
|
||||
application/api/answer/routes/answer.py (new file, 153 lines)
@@ -0,0 +1,153 @@
import logging
import traceback

from flask import make_response, request
from flask_restx import fields, Resource

from application.api import api

from application.api.answer.routes.base import answer_ns, BaseAnswerResource

from application.api.answer.services.stream_processor import StreamProcessor


logger = logging.getLogger(__name__)


@answer_ns.route("/api/answer")
class AnswerResource(Resource, BaseAnswerResource):
    def __init__(self, *args, **kwargs):
        Resource.__init__(self, *args, **kwargs)
        BaseAnswerResource.__init__(self)

    answer_model = answer_ns.model(
        "AnswerModel",
        {
            "question": fields.String(
                required=True, description="Question to be asked"
            ),
            "history": fields.List(
                fields.String,
                required=False,
                description="Conversation history (only for new conversations)",
            ),
            "conversation_id": fields.String(
                required=False,
                description="Existing conversation ID (loads history)",
            ),
            "prompt_id": fields.String(
                required=False, default="default", description="Prompt ID"
            ),
            "chunks": fields.Integer(
                required=False, default=2, description="Number of chunks"
            ),
            "retriever": fields.String(required=False, description="Retriever type"),
            "api_key": fields.String(required=False, description="API key"),
            "agent_id": fields.String(required=False, description="Agent ID"),
            "active_docs": fields.String(
                required=False, description="Active documents"
            ),
            "isNoneDoc": fields.Boolean(
                required=False, description="Flag indicating if no document is used"
            ),
            "save_conversation": fields.Boolean(
                required=False,
                default=True,
                description="Whether to save the conversation",
            ),
            "model_id": fields.String(
                required=False,
                description="Model ID to use for this request",
            ),
            "passthrough": fields.Raw(
                required=False,
                description="Dynamic parameters to inject into prompt template",
            ),
        },
    )

    @api.expect(answer_model)
    @api.doc(description="Provide a response based on the question and retriever")
    def post(self):
        data = request.get_json()
        if error := self.validate_request(data):
            return error
        decoded_token = getattr(request, "decoded_token", None)
        processor = StreamProcessor(data, decoded_token)
        try:
            # ---- Continuation mode ----
            if data.get("tool_actions"):
                (
                    agent,
                    messages,
                    tools_dict,
                    pending_tool_calls,
                    tool_actions,
                ) = processor.resume_from_tool_actions(
                    data["tool_actions"], data["conversation_id"]
                )
                if not processor.decoded_token:
                    return make_response({"error": "Unauthorized"}, 401)
                if error := self.check_usage(processor.agent_config):
                    return error
                stream = self.complete_stream(
                    question="",
                    agent=agent,
                    conversation_id=processor.conversation_id,
                    user_api_key=processor.agent_config.get("user_api_key"),
                    decoded_token=processor.decoded_token,
                    agent_id=processor.agent_id,
                    model_id=processor.model_id,
                    _continuation={
                        "messages": messages,
                        "tools_dict": tools_dict,
                        "pending_tool_calls": pending_tool_calls,
                        "tool_actions": tool_actions,
                    },
                )
            else:
                # ---- Normal mode ----
                agent = processor.build_agent(data.get("question", ""))
                if not processor.decoded_token:
                    return make_response({"error": "Unauthorized"}, 401)

                if error := self.check_usage(processor.agent_config):
                    return error

                stream = self.complete_stream(
                    question=data["question"],
                    agent=agent,
                    conversation_id=processor.conversation_id,
                    user_api_key=processor.agent_config.get("user_api_key"),
                    decoded_token=processor.decoded_token,
                    isNoneDoc=data.get("isNoneDoc"),
                    index=None,
                    should_save_conversation=data.get("save_conversation", True),
                    agent_id=processor.agent_id,
                    is_shared_usage=processor.is_shared_usage,
                    shared_token=processor.shared_token,
                    model_id=processor.model_id,
                )

            stream_result = self.process_response_stream(stream)

            if stream_result["error"]:
                return make_response({"error": stream_result["error"]}, 400)

            result = {
                "conversation_id": stream_result["conversation_id"],
                "answer": stream_result["answer"],
                "sources": stream_result["sources"],
                "tool_calls": stream_result["tool_calls"],
                "thought": stream_result["thought"],
            }

            extra_info = stream_result.get("extra")
            if extra_info:
                result.update(extra_info)
        except Exception as e:
            logger.error(
                f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
                extra={"error": str(e), "traceback": traceback.format_exc()},
            )
            return make_response({"error": "An error occurred processing your request"}, 500)
        return make_response(result, 200)
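For reference, a minimal client-side sketch of calling the non-streaming /api/answer endpoint defined above. The base URL and bearer token are placeholders, not values from this diff; the expected response keys follow the result dict built in the handler.

# Minimal sketch: calling POST /api/answer from a client.
# BASE_URL and the JWT below are illustrative placeholders.
import requests

BASE_URL = "http://localhost:7091"  # assumed local DocsGPT API address
payload = {
    "question": "What does the /api/answer endpoint return?",
    "chunks": 2,
    "retriever": "classic",
    "save_conversation": True,
}
resp = requests.post(
    f"{BASE_URL}/api/answer",
    json=payload,
    headers={"Authorization": "Bearer <JWT>"},  # placeholder token
    timeout=120,
)
resp.raise_for_status()
body = resp.json()
# Keys per the handler above: conversation_id, answer, sources,
# tool_calls, thought (plus optional extras such as pending_tool_calls).
print(body["answer"])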
application/api/answer/routes/base.py (new file, 673 lines)
@@ -0,0 +1,673 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
from flask import jsonify, make_response, Response
|
||||
from flask_restx import Namespace
|
||||
|
||||
from application.api.answer.services.continuation_service import ContinuationService
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.error import sanitize_api_error
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
from application.storage.db.repositories.user_logs import UserLogsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import check_required_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
answer_ns = Namespace("answer", description="Answer related operations", path="/")
|
||||
|
||||
|
||||
class BaseAnswerResource:
|
||||
"""Shared base class for answer endpoints"""
|
||||
|
||||
def __init__(self):
|
||||
self.default_model_id = get_default_model_id()
|
||||
self.conversation_service = ConversationService()
|
||||
|
||||
def validate_request(
|
||||
self, data: Dict[str, Any], require_conversation_id: bool = False
|
||||
) -> Optional[Response]:
|
||||
"""Common request validation.
|
||||
|
||||
Continuation requests (``tool_actions`` present) require
|
||||
``conversation_id`` but not ``question``.
|
||||
"""
|
||||
if data.get("tool_actions"):
|
||||
# Continuation mode — question is not required
|
||||
if missing := check_required_fields(data, ["conversation_id"]):
|
||||
return missing
|
||||
return None
|
||||
required_fields = ["question"]
|
||||
if require_conversation_id:
|
||||
required_fields.append("conversation_id")
|
||||
if missing_fields := check_required_fields(data, required_fields):
|
||||
return missing_fields
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _prepare_tool_calls_for_logging(
|
||||
tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000
|
||||
) -> List[Dict[str, Any]]:
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
prepared = []
|
||||
for tool_call in tool_calls:
|
||||
if not isinstance(tool_call, dict):
|
||||
prepared.append({"result": str(tool_call)[:max_chars]})
|
||||
continue
|
||||
|
||||
item = dict(tool_call)
|
||||
for key in ("result", "result_full"):
|
||||
value = item.get(key)
|
||||
if isinstance(value, str) and len(value) > max_chars:
|
||||
item[key] = value[:max_chars]
|
||||
prepared.append(item)
|
||||
return prepared
|
||||
|
||||
def check_usage(self, agent_config: Dict) -> Optional[Response]:
|
||||
"""Check if there is a usage limit and if it is exceeded
|
||||
|
||||
Args:
|
||||
agent_config: The config dict of agent instance
|
||||
|
||||
Returns:
|
||||
None or Response if either of limits exceeded.
|
||||
|
||||
"""
|
||||
api_key = agent_config.get("user_api_key")
|
||||
if not api_key:
|
||||
return None
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid API key."}), 401
|
||||
)
|
||||
limited_token_mode_raw = agent.get("limited_token_mode", False)
|
||||
limited_request_mode_raw = agent.get("limited_request_mode", False)
|
||||
|
||||
limited_token_mode = (
|
||||
limited_token_mode_raw
|
||||
if isinstance(limited_token_mode_raw, bool)
|
||||
else limited_token_mode_raw == "True"
|
||||
)
|
||||
limited_request_mode = (
|
||||
limited_request_mode_raw
|
||||
if isinstance(limited_request_mode_raw, bool)
|
||||
else limited_request_mode_raw == "True"
|
||||
)
|
||||
|
||||
token_limit = int(
|
||||
agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"]
|
||||
)
|
||||
request_limit = int(
|
||||
agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"]
|
||||
)
|
||||
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
|
||||
if limited_token_mode or limited_request_mode:
|
||||
with db_readonly() as conn:
|
||||
token_repo = TokenUsageRepository(conn)
|
||||
if limited_token_mode:
|
||||
daily_token_usage = token_repo.sum_tokens_in_range(
|
||||
start=start_date, end=end_date, api_key=api_key,
|
||||
)
|
||||
else:
|
||||
daily_token_usage = 0
|
||||
if limited_request_mode:
|
||||
daily_request_usage = token_repo.count_in_range(
|
||||
start=start_date, end=end_date, api_key=api_key,
|
||||
)
|
||||
else:
|
||||
daily_request_usage = 0
|
||||
else:
|
||||
daily_token_usage = 0
|
||||
daily_request_usage = 0
|
||||
if not limited_token_mode and not limited_request_mode:
|
||||
return None
|
||||
token_exceeded = (
|
||||
limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
|
||||
)
|
||||
request_exceeded = (
|
||||
limited_request_mode
|
||||
and request_limit > 0
|
||||
and daily_request_usage >= request_limit
|
||||
)
|
||||
|
||||
if token_exceeded or request_exceeded:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Exceeding usage limit, please try again later.",
|
||||
}
|
||||
),
|
||||
429,
|
||||
)
|
||||
return None
|
||||
|
||||
def complete_stream(
|
||||
self,
|
||||
question: str,
|
||||
agent: Any,
|
||||
conversation_id: Optional[str],
|
||||
user_api_key: Optional[str],
|
||||
decoded_token: Dict[str, Any],
|
||||
isNoneDoc: bool = False,
|
||||
index: Optional[int] = None,
|
||||
should_save_conversation: bool = True,
|
||||
attachment_ids: Optional[List[str]] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
model_id: Optional[str] = None,
|
||||
model_user_id: Optional[str] = None,
|
||||
_continuation: Optional[Dict] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Generator function that streams the complete conversation response.
|
||||
|
||||
Args:
|
||||
question: The user's question
|
||||
agent: The agent instance
|
||||
retriever: The retriever instance
|
||||
conversation_id: Existing conversation ID
|
||||
user_api_key: User's API key if any
|
||||
decoded_token: Decoded JWT token
|
||||
isNoneDoc: Flag for document-less responses
|
||||
index: Index of message to update
|
||||
should_save_conversation: Whether to persist the conversation
|
||||
attachment_ids: List of attachment IDs
|
||||
agent_id: ID of agent used
|
||||
is_shared_usage: Flag for shared agent usage
|
||||
shared_token: Token for shared agent
|
||||
model_id: Model ID used for the request
|
||||
retrieved_docs: Pre-fetched documents for sources (optional)
|
||||
|
||||
Yields:
|
||||
Server-sent event strings
|
||||
"""
|
||||
try:
|
||||
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
query_metadata = {}
|
||||
paused = False
|
||||
|
||||
if _continuation:
|
||||
gen_iter = agent.gen_continuation(
|
||||
messages=_continuation["messages"],
|
||||
tools_dict=_continuation["tools_dict"],
|
||||
pending_tool_calls=_continuation["pending_tool_calls"],
|
||||
tool_actions=_continuation["tool_actions"],
|
||||
)
|
||||
else:
|
||||
gen_iter = agent.gen(query=question)
|
||||
|
||||
for line in gen_iter:
|
||||
if "metadata" in line:
|
||||
query_metadata.update(line["metadata"])
|
||||
elif "answer" in line:
|
||||
response_full += str(line["answer"])
|
||||
if line.get("structured"):
|
||||
is_structured = True
|
||||
schema_info = line.get("schema")
|
||||
structured_chunks.append(line["answer"])
|
||||
else:
|
||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "sources" in line:
|
||||
truncated_sources = []
|
||||
source_log_docs = line["sources"]
|
||||
for source in line["sources"]:
|
||||
truncated_source = source.copy()
|
||||
if "text" in truncated_source:
|
||||
truncated_source["text"] = (
|
||||
truncated_source["text"][:100].strip() + "..."
|
||||
)
|
||||
truncated_sources.append(truncated_source)
|
||||
if truncated_sources:
|
||||
data = json.dumps(
|
||||
{"type": "source", "source": truncated_sources}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
elif "tool_calls" in line:
|
||||
tool_calls = line["tool_calls"]
|
||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "thought" in line:
|
||||
thought += line["thought"]
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "type" in line:
|
||||
if line.get("type") == "tool_calls_pending":
|
||||
# Save continuation state and end the stream
|
||||
paused = True
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
elif line.get("type") == "error":
|
||||
sanitized_error = {
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(line.get("error", "An error occurred"))
|
||||
}
|
||||
data = json.dumps(sanitized_error)
|
||||
yield f"data: {data}\n\n"
|
||||
else:
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
if is_structured and structured_chunks:
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
"answer": response_full,
|
||||
"structured": True,
|
||||
"schema": schema_info,
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# ---- Paused: save continuation state and end stream early ----
|
||||
if paused:
|
||||
continuation = getattr(agent, "_pending_continuation", None)
|
||||
if continuation:
|
||||
# Ensure we have a conversation_id — create a partial
|
||||
# conversation if this is the first turn.
|
||||
if not conversation_id and should_save_conversation:
|
||||
try:
|
||||
# Use model-owner scope so shared-agent
|
||||
# owner-BYOM resolves to its registered plugin.
|
||||
provider = (
|
||||
get_provider_from_model_id(
|
||||
model_id,
|
||||
user_id=model_user_id
|
||||
or (
|
||||
decoded_token.get("sub")
|
||||
if decoded_token
|
||||
else None
|
||||
),
|
||||
)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
sys_api_key = get_api_key_for_provider(
|
||||
provider or settings.LLM_PROVIDER
|
||||
)
|
||||
llm = LLMCreator.create_llm(
|
||||
provider or settings.LLM_PROVIDER,
|
||||
api_key=sys_api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
conversation_id = (
|
||||
self.conversation_service.save_conversation(
|
||||
None,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create conversation for continuation: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if conversation_id:
|
||||
try:
|
||||
cont_service = ContinuationService()
|
||||
cont_service.save_state(
|
||||
conversation_id=str(conversation_id),
|
||||
user=decoded_token.get("sub", "local"),
|
||||
messages=continuation["messages"],
|
||||
pending_tool_calls=continuation["pending_tool_calls"],
|
||||
tools_dict=continuation["tools_dict"],
|
||||
tool_schemas=getattr(agent, "tools", []),
|
||||
agent_config={
|
||||
"model_id": model_id or self.default_model_id,
|
||||
# Persist BYOM scope so resume doesn't
|
||||
# fall back to caller's layer.
|
||||
"model_user_id": model_user_id,
|
||||
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
|
||||
"api_key": getattr(agent, "api_key", None),
|
||||
"user_api_key": user_api_key,
|
||||
"agent_id": agent_id,
|
||||
"agent_type": agent.__class__.__name__,
|
||||
"prompt": getattr(agent, "prompt", ""),
|
||||
"json_schema": getattr(agent, "json_schema", None),
|
||||
"retriever_config": getattr(agent, "retriever_config", None),
|
||||
},
|
||||
client_tools=getattr(
|
||||
agent.tool_executor, "client_tools", None
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save continuation state: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
# Run under model-owner scope so title-gen LLM inside
|
||||
# save_conversation uses the owner's BYOM provider/key.
|
||||
provider = (
|
||||
get_provider_from_model_id(
|
||||
model_id,
|
||||
user_id=model_user_id
|
||||
or (decoded_token.get("sub") if decoded_token else None),
|
||||
)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
provider or settings.LLM_PROVIDER,
|
||||
api_key=system_api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
|
||||
if should_save_conversation:
|
||||
conversation_id = self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
if conversation_id and compression_meta and not compression_saved:
|
||||
try:
|
||||
self.conversation_service.update_compression_metadata(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
self.conversation_service.append_compression_message(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
agent.compression_saved = True
|
||||
logger.info(
|
||||
f"Persisted compression metadata for conversation {conversation_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to persist compression metadata: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
conversation_id = None
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
|
||||
getattr(agent, "tool_calls", tool_calls) or tool_calls
|
||||
)
|
||||
|
||||
log_data = {
|
||||
"action": "stream_answer",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"agent_id": agent_id,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls_for_logging,
|
||||
"attachments": attachment_ids,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
if is_structured:
|
||||
log_data["structured_output"] = True
|
||||
if schema_info:
|
||||
log_data["schema"] = schema_info
|
||||
# Clean up text fields to be no longer than 10000 characters
|
||||
|
||||
for key, value in log_data.items():
|
||||
if isinstance(value, str) and len(value) > 10000:
|
||||
log_data[key] = value[:10000]
|
||||
try:
|
||||
with db_session() as conn:
|
||||
UserLogsRepository(conn).insert(
|
||||
user_id=log_data.get("user"),
|
||||
endpoint="stream_answer",
|
||||
data=log_data,
|
||||
)
|
||||
except Exception as log_err:
|
||||
logger.error(
|
||||
f"Failed to persist stream_answer user log: {log_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
except GeneratorExit:
|
||||
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
|
||||
# Save partial response
|
||||
|
||||
if should_save_conversation and response_full:
|
||||
try:
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
# Mirror the normal-path provider resolution so the
|
||||
# partial-save title LLM uses the model-owner's BYOM
|
||||
# registration (shared-agent dispatch) rather than
|
||||
# the deployment default with the instance api key.
|
||||
provider = (
|
||||
get_provider_from_model_id(
|
||||
model_id,
|
||||
user_id=model_user_id
|
||||
or (
|
||||
decoded_token.get("sub")
|
||||
if decoded_token
|
||||
else None
|
||||
),
|
||||
)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
sys_api_key = get_api_key_for_provider(
|
||||
provider or settings.LLM_PROVIDER
|
||||
)
|
||||
llm = LLMCreator.create_llm(
|
||||
provider or settings.LLM_PROVIDER,
|
||||
api_key=sys_api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
if conversation_id and compression_meta and not compression_saved:
|
||||
try:
|
||||
self.conversation_service.update_compression_metadata(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
self.conversation_service.append_compression_message(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
agent.compression_saved = True
|
||||
logger.info(
|
||||
f"Persisted compression metadata for conversation {conversation_id} (partial stream)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to persist compression metadata (partial stream): {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error saving partial response: {str(e)}", exc_info=True
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
||||
data = json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Please try again later. We apologize for any inconvenience.",
|
||||
}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
def process_response_stream(self, stream) -> Dict[str, Any]:
|
||||
"""Process the stream response for non-streaming endpoint.
|
||||
|
||||
Returns:
|
||||
Dict with keys: conversation_id, answer, sources, tool_calls,
|
||||
thought, error, and optional extra.
|
||||
"""
|
||||
conversation_id = ""
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
tool_calls = []
|
||||
thought = ""
|
||||
stream_ended = False
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
pending_tool_calls = None
|
||||
|
||||
for line in stream:
|
||||
try:
|
||||
event_data = line.replace("data: ", "").strip()
|
||||
event = json.loads(event_data)
|
||||
|
||||
if event["type"] == "id":
|
||||
conversation_id = event["id"]
|
||||
elif event["type"] == "answer":
|
||||
response_full += event["answer"]
|
||||
elif event["type"] == "structured_answer":
|
||||
response_full = event["answer"]
|
||||
is_structured = True
|
||||
schema_info = event.get("schema")
|
||||
elif event["type"] == "source":
|
||||
source_log_docs = event["source"]
|
||||
elif event["type"] == "tool_calls":
|
||||
tool_calls = event["tool_calls"]
|
||||
elif event["type"] == "tool_calls_pending":
|
||||
pending_tool_calls = event.get("data", {}).get(
|
||||
"pending_tool_calls", []
|
||||
)
|
||||
elif event["type"] == "thought":
|
||||
thought = event["thought"]
|
||||
elif event["type"] == "error":
|
||||
logger.error(f"Error from stream: {event['error']}")
|
||||
return {
|
||||
"conversation_id": None,
|
||||
"answer": None,
|
||||
"sources": None,
|
||||
"tool_calls": None,
|
||||
"thought": None,
|
||||
"error": event["error"],
|
||||
}
|
||||
elif event["type"] == "end":
|
||||
stream_ended = True
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"Error parsing stream event: {e}, line: {line}")
|
||||
continue
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return {
|
||||
"conversation_id": None,
|
||||
"answer": None,
|
||||
"sources": None,
|
||||
"tool_calls": None,
|
||||
"thought": None,
|
||||
"error": "Stream ended unexpectedly",
|
||||
}
|
||||
|
||||
result: Dict[str, Any] = {
|
||||
"conversation_id": conversation_id,
|
||||
"answer": response_full,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"thought": thought,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
if pending_tool_calls is not None:
|
||||
result["extra"] = {"pending_tool_calls": pending_tool_calls}
|
||||
|
||||
if is_structured:
|
||||
result["extra"] = {"structured": True, "schema": schema_info}
|
||||
|
||||
return result
|
||||
|
||||
def error_stream_generate(self, err_response):
|
||||
data = json.dumps({"type": "error", "error": err_response})
|
||||
yield f"data: {data}\n\n"
|
||||
application/api/answer/routes/search.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import logging

from flask import make_response, request
from flask_restx import fields, Resource

from application.api.answer.routes.base import answer_ns
from application.services.search_service import (
    InvalidAPIKey,
    SearchFailed,
    search,
)

logger = logging.getLogger(__name__)


@answer_ns.route("/api/search")
class SearchResource(Resource):
    """Fast search endpoint for retrieving relevant documents."""

    search_model = answer_ns.model(
        "SearchModel",
        {
            "question": fields.String(
                required=True, description="Search query"
            ),
            "api_key": fields.String(
                required=True, description="API key for authentication"
            ),
            "chunks": fields.Integer(
                required=False, default=5, description="Number of results to return"
            ),
        },
    )

    @answer_ns.expect(search_model)
    @answer_ns.doc(description="Search for relevant documents based on query")
    def post(self):
        data = request.get_json() or {}

        question = data.get("question")
        api_key = data.get("api_key")
        chunks = data.get("chunks", 5)

        if not question:
            return make_response({"error": "question is required"}, 400)
        if not api_key:
            return make_response({"error": "api_key is required"}, 400)

        try:
            return make_response(search(api_key, question, chunks), 200)
        except InvalidAPIKey:
            return make_response({"error": "Invalid API key"}, 401)
        except SearchFailed:
            logger.exception("/api/search failed")
            return make_response({"error": "Search failed"}, 500)
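A minimal sketch of exercising the /api/search endpoint above. The URL and agent API key are placeholders; the response body is simply whatever search() returns for the key's configured source.

# Minimal sketch: calling POST /api/search with an agent API key.
import requests

resp = requests.post(
    "http://localhost:7091/api/search",  # assumed local API address
    json={"question": "how to deploy", "api_key": "<agent-api-key>", "chunks": 5},
    timeout=30,
)
if resp.status_code == 401:
    print("Invalid API key")
else:
    resp.raise_for_status()
    print(resp.json())  # list of retrieved documents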
application/api/answer/routes/stream.py (new file, 173 lines)
@@ -0,0 +1,173 @@
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from flask import request, Response
|
||||
from flask_restx import fields, Resource
|
||||
|
||||
from application.api import api
|
||||
|
||||
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
|
||||
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@answer_ns.route("/stream")
|
||||
class StreamResource(Resource, BaseAnswerResource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
Resource.__init__(self, *args, **kwargs)
|
||||
BaseAnswerResource.__init__(self)
|
||||
|
||||
stream_model = answer_ns.model(
|
||||
"StreamModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="Question to be asked"
|
||||
),
|
||||
"history": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="Conversation history (only for new conversations)",
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=False,
|
||||
description="Existing conversation ID (loads history)",
|
||||
),
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"agent_id": fields.String(required=False, description="Agent ID"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
"index": fields.Integer(
|
||||
required=False, description="Index of the query to update"
|
||||
),
|
||||
"save_conversation": fields.Boolean(
|
||||
required=False,
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
"model_id": fields.String(
|
||||
required=False,
|
||||
description="Model ID to use for this request",
|
||||
),
|
||||
"attachments": fields.List(
|
||||
fields.String, required=False, description="List of attachment IDs"
|
||||
),
|
||||
"passthrough": fields.Raw(
|
||||
required=False,
|
||||
description="Dynamic parameters to inject into prompt template",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(stream_model)
|
||||
@api.doc(description="Stream a response based on the question and retriever")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
if error := self.validate_request(data, "index" in data):
|
||||
return error
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
|
||||
try:
|
||||
# ---- Continuation mode ----
|
||||
if data.get("tool_actions"):
|
||||
(
|
||||
agent,
|
||||
messages,
|
||||
tools_dict,
|
||||
pending_tool_calls,
|
||||
tool_actions,
|
||||
) = processor.resume_from_tool_actions(
|
||||
data["tool_actions"], data["conversation_id"]
|
||||
)
|
||||
if not processor.decoded_token:
|
||||
return Response(
|
||||
self.error_stream_generate("Unauthorized"),
|
||||
status=401,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
return Response(
|
||||
self.complete_stream(
|
||||
question="",
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
model_user_id=processor.model_user_id,
|
||||
_continuation={
|
||||
"messages": messages,
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
},
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
# ---- Normal mode ----
|
||||
agent = processor.build_agent(data["question"])
|
||||
if not processor.decoded_token:
|
||||
return Response(
|
||||
self.error_stream_generate("Unauthorized"),
|
||||
status=401,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
return Response(
|
||||
self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=data.get("index"),
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
attachment_ids=data.get("attachments", []),
|
||||
agent_id=processor.agent_id,
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
model_user_id=processor.model_user_id,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except ValueError as e:
|
||||
message = "Malformed request body"
|
||||
logger.error(
|
||||
f"/stream - error: {message} - specific error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return Response(
|
||||
self.error_stream_generate(message),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return Response(
|
||||
self.error_stream_generate("Unknown error occurred"),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
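A minimal sketch of consuming the /stream endpoint above from a client. Per complete_stream in base.py, events arrive as "data: <json>\n\n" lines whose "type" field is one of answer, source, thought, tool_calls, id, error, or end; the URL and token here are placeholders.

# Minimal sketch: reading the SSE stream produced by POST /stream.
import json
import requests

with requests.post(
    "http://localhost:7091/stream",  # assumed local API address
    json={"question": "Summarize the README", "save_conversation": False},
    headers={"Authorization": "Bearer <JWT>"},  # placeholder token
    stream=True,
    timeout=300,
) as resp:
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data: "):
            continue
        event = json.loads(raw[len("data: "):])
        if event["type"] == "answer":
            print(event["answer"], end="", flush=True)
        elif event["type"] == "error":
            print("\nerror:", event["error"])
        elif event["type"] == "end":
            break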
application/api/answer/services/__init__.py (new file, empty)
application/api/answer/services/compression/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
"""
Compression module for managing conversation context compression.

"""

from application.api.answer.services.compression.orchestrator import (
    CompressionOrchestrator,
)
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.types import (
    CompressionResult,
    CompressionMetadata,
)

__all__ = [
    "CompressionOrchestrator",
    "CompressionService",
    "CompressionResult",
    "CompressionMetadata",
]
application/api/answer/services/compression/message_builder.py (new file, 249 lines)
@@ -0,0 +1,249 @@
|
||||
"""Message reconstruction utilities for compression."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageBuilder:
|
||||
"""Builds message arrays from compressed context."""
|
||||
|
||||
@staticmethod
|
||||
def build_from_compressed_context(
|
||||
system_prompt: str,
|
||||
compressed_summary: Optional[str],
|
||||
recent_queries: List[Dict],
|
||||
include_tool_calls: bool = False,
|
||||
context_type: str = "pre_request",
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Build messages from compressed context.
|
||||
|
||||
Args:
|
||||
system_prompt: Original system prompt
|
||||
compressed_summary: Compressed summary (if any)
|
||||
recent_queries: Recent uncompressed queries
|
||||
include_tool_calls: Whether to include tool calls from history
|
||||
context_type: Type of context ('pre_request' or 'mid_execution')
|
||||
|
||||
Returns:
|
||||
List of message dicts ready for LLM
|
||||
"""
|
||||
# Append compression summary to system prompt if present
|
||||
if compressed_summary:
|
||||
system_prompt = MessageBuilder._append_compression_context(
|
||||
system_prompt, compressed_summary, context_type
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
# Add recent history
|
||||
for query in recent_queries:
|
||||
if "prompt" in query and "response" in query:
|
||||
messages.append({"role": "user", "content": query["prompt"]})
|
||||
messages.append({"role": "assistant", "content": query["response"]})
|
||||
|
||||
# Add tool calls from history if present
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
args = tool_call.get("arguments")
|
||||
args_str = (
|
||||
json.dumps(args)
|
||||
if isinstance(args, dict)
|
||||
else (args or "{}")
|
||||
)
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.get("action_name", ""),
|
||||
"arguments": args_str,
|
||||
},
|
||||
}],
|
||||
})
|
||||
result = tool_call.get("result")
|
||||
result_str = (
|
||||
json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else (result or "")
|
||||
)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": result_str,
|
||||
})
|
||||
|
||||
# If no recent queries (everything was compressed), add a continuation user message
|
||||
if len(recent_queries) == 0 and compressed_summary:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": "Please continue with the remaining tasks based on the context above."
|
||||
})
|
||||
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _append_compression_context(
|
||||
system_prompt: str, compressed_summary: str, context_type: str = "pre_request"
|
||||
) -> str:
|
||||
"""
|
||||
Append compression context to system prompt.
|
||||
|
||||
Args:
|
||||
system_prompt: Original system prompt
|
||||
compressed_summary: Summary to append
|
||||
context_type: Type of compression context
|
||||
|
||||
Returns:
|
||||
Updated system prompt
|
||||
"""
|
||||
# Remove existing compression context if present
|
||||
if "This session is being continued" in system_prompt or "Context window limit reached" in system_prompt:
|
||||
parts = system_prompt.split("\n\n---\n\n")
|
||||
system_prompt = parts[0]
|
||||
|
||||
# Build appropriate context message based on type
|
||||
if context_type == "mid_execution":
|
||||
context_message = (
|
||||
"\n\n---\n\n"
|
||||
"Context window limit reached during execution. "
|
||||
"Previous conversation has been compressed to fit within limits. "
|
||||
"The conversation is summarized below:\n\n"
|
||||
f"{compressed_summary}"
|
||||
)
|
||||
else: # pre_request
|
||||
context_message = (
|
||||
"\n\n---\n\n"
|
||||
"This session is being continued from a previous conversation that "
|
||||
"has been compressed to fit within context limits. "
|
||||
"The conversation is summarized below:\n\n"
|
||||
f"{compressed_summary}"
|
||||
)
|
||||
|
||||
return system_prompt + context_message
|
||||
|
||||
@staticmethod
|
||||
def rebuild_messages_after_compression(
|
||||
messages: List[Dict],
|
||||
compressed_summary: Optional[str],
|
||||
recent_queries: List[Dict],
|
||||
include_current_execution: bool = False,
|
||||
include_tool_calls: bool = False,
|
||||
) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Rebuild the message list after compression so tool execution can continue.
|
||||
|
||||
Args:
|
||||
messages: Original message list
|
||||
compressed_summary: Compressed summary
|
||||
recent_queries: Recent uncompressed queries
|
||||
include_current_execution: Whether to preserve current execution messages
|
||||
include_tool_calls: Whether to include tool calls from history
|
||||
|
||||
Returns:
|
||||
Rebuilt message list or None if failed
|
||||
"""
|
||||
# Find the system message
|
||||
system_message = next(
|
||||
(msg for msg in messages if msg.get("role") == "system"), None
|
||||
)
|
||||
if not system_message:
|
||||
logger.warning("No system message found in messages list")
|
||||
return None
|
||||
|
||||
# Update system message with compressed summary
|
||||
if compressed_summary:
|
||||
content = system_message.get("content", "")
|
||||
system_message["content"] = MessageBuilder._append_compression_context(
|
||||
content, compressed_summary, "mid_execution"
|
||||
)
|
||||
logger.info(
|
||||
"Appended compression summary to system prompt (truncated): %s",
|
||||
(
|
||||
compressed_summary[:500] + "..."
|
||||
if len(compressed_summary) > 500
|
||||
else compressed_summary
|
||||
),
|
||||
)
|
||||
|
||||
rebuilt_messages = [system_message]
|
||||
|
||||
# Add recent history from compressed context
|
||||
for query in recent_queries:
|
||||
if "prompt" in query and "response" in query:
|
||||
rebuilt_messages.append({"role": "user", "content": query["prompt"]})
|
||||
rebuilt_messages.append(
|
||||
{"role": "assistant", "content": query["response"]}
|
||||
)
|
||||
|
||||
# Add tool calls from history if present
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
args = tool_call.get("arguments")
|
||||
args_str = (
|
||||
json.dumps(args)
|
||||
if isinstance(args, dict)
|
||||
else (args or "{}")
|
||||
)
|
||||
rebuilt_messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.get("action_name", ""),
|
||||
"arguments": args_str,
|
||||
},
|
||||
}],
|
||||
})
|
||||
result = tool_call.get("result")
|
||||
result_str = (
|
||||
json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else (result or "")
|
||||
)
|
||||
rebuilt_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": result_str,
|
||||
})
|
||||
|
||||
# If no recent queries (everything was compressed), add a continuation user message
|
||||
if len(recent_queries) == 0 and compressed_summary:
|
||||
rebuilt_messages.append({
|
||||
"role": "user",
|
||||
"content": "Please continue with the remaining tasks based on the context above."
|
||||
})
|
||||
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
|
||||
|
||||
if include_current_execution:
|
||||
# Preserve any messages that were added during the current execution cycle
|
||||
recent_msg_count = 1 # system message
|
||||
for query in recent_queries:
|
||||
if "prompt" in query and "response" in query:
|
||||
recent_msg_count += 2
|
||||
if "tool_calls" in query:
|
||||
recent_msg_count += len(query["tool_calls"]) * 2
|
||||
|
||||
if len(messages) > recent_msg_count:
|
||||
current_execution_messages = messages[recent_msg_count:]
|
||||
rebuilt_messages.extend(current_execution_messages)
|
||||
logger.info(
|
||||
f"Preserved {len(current_execution_messages)} messages from current execution cycle"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Messages rebuilt: {len(messages)} → {len(rebuilt_messages)} messages. "
|
||||
f"Ready to continue tool execution."
|
||||
)
|
||||
return rebuilt_messages
|
||||
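A minimal sketch of how build_from_compressed_context (defined above) turns a compressed summary plus recent uncompressed turns into an LLM message array. The summary and query contents are invented purely for illustration.

# Minimal sketch: assembling messages from compressed context.
from application.api.answer.services.compression.message_builder import MessageBuilder

messages = MessageBuilder.build_from_compressed_context(
    system_prompt="You are a helpful assistant.",
    compressed_summary="User asked about deployment; Docker Compose was recommended.",
    recent_queries=[
        {"prompt": "Which port does the API use?", "response": "It defaults to 7091."}
    ],
    include_tool_calls=False,
    context_type="pre_request",
)
# Result: [system (prompt + compression banner), user, assistant]
for m in messages:
    print(m["role"], "->", str(m["content"])[:60])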
application/api/answer/services/compression/orchestrator.py (new file, 270 lines)
@@ -0,0 +1,270 @@
|
||||
"""High-level compression orchestration."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.api.answer.services.compression.service import CompressionService
|
||||
from application.api.answer.services.compression.threshold_checker import (
|
||||
CompressionThresholdChecker,
|
||||
)
|
||||
from application.api.answer.services.compression.types import CompressionResult
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionOrchestrator:
|
||||
"""
|
||||
Facade for compression operations.
|
||||
|
||||
Coordinates between all compression components and provides
|
||||
a simple interface for callers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conversation_service: ConversationService,
|
||||
threshold_checker: Optional[CompressionThresholdChecker] = None,
|
||||
):
|
||||
"""
|
||||
Initialize orchestrator.
|
||||
|
||||
Args:
|
||||
conversation_service: Service for DB operations
|
||||
threshold_checker: Custom threshold checker (optional)
|
||||
"""
|
||||
self.conversation_service = conversation_service
|
||||
self.threshold_checker = threshold_checker or CompressionThresholdChecker()
|
||||
|
||||
def compress_if_needed(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
current_query_tokens: int = 500,
|
||||
model_user_id: Optional[str] = None,
|
||||
) -> CompressionResult:
|
||||
"""
|
||||
Check if compression is needed and perform it if so.
|
||||
|
||||
This is the main entry point for compression operations.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
user_id: Caller's user id — used for conversation access checks
|
||||
model_id: Model being used for conversation
|
||||
decoded_token: User's decoded JWT token
|
||||
current_query_tokens: Estimated tokens for current query
|
||||
model_user_id: BYOM-resolution scope (model owner); defaults
|
||||
to ``user_id`` for built-in / caller-owned models.
|
||||
|
||||
Returns:
|
||||
CompressionResult with summary and recent queries
|
||||
"""
|
||||
try:
|
||||
# Conversation row is owned by the caller, not the model owner.
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
conversation_id, user_id
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
logger.warning(
|
||||
f"Conversation {conversation_id} not found for user {user_id}"
|
||||
)
|
||||
return CompressionResult.failure("Conversation not found")
|
||||
|
||||
# Use model-owner scope so per-user BYOM context windows
|
||||
# (e.g. 8k) compute the threshold against the right limit.
|
||||
registry_user_id = model_user_id or user_id
|
||||
if not self.threshold_checker.should_compress(
|
||||
conversation,
|
||||
model_id,
|
||||
current_query_tokens,
|
||||
user_id=registry_user_id,
|
||||
):
|
||||
# No compression needed, return full history
|
||||
queries = conversation.get("queries", [])
|
||||
return CompressionResult.success_no_compression(queries)
|
||||
|
||||
# Perform compression
|
||||
return self._perform_compression(
|
||||
conversation_id,
|
||||
conversation,
|
||||
model_id,
|
||||
decoded_token,
|
||||
user_id=user_id,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in compress_if_needed: {str(e)}", exc_info=True
|
||||
)
|
||||
return CompressionResult.failure(str(e))
|
||||
|
||||
def _perform_compression(
|
||||
self,
|
||||
conversation_id: str,
|
||||
conversation: Dict[str, Any],
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
user_id: Optional[str] = None,
|
||||
model_user_id: Optional[str] = None,
|
||||
) -> CompressionResult:
|
||||
"""
|
||||
Perform the actual compression operation.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
conversation: Conversation document
|
||||
model_id: Model ID for conversation
|
||||
decoded_token: User token
|
||||
user_id: Caller's id (for conversation reload after compression)
|
||||
model_user_id: BYOM-resolution scope (model owner)
|
||||
|
||||
Returns:
|
||||
CompressionResult
|
||||
"""
|
||||
try:
|
||||
# Determine which model to use for compression
|
||||
compression_model = (
|
||||
settings.COMPRESSION_MODEL_OVERRIDE
|
||||
if settings.COMPRESSION_MODEL_OVERRIDE
|
||||
else model_id
|
||||
)
|
||||
|
||||
# Use model-owner scope so provider/api_key resolves to the
|
||||
# owner's BYOM record (shared-agent dispatch).
|
||||
caller_user_id = user_id
|
||||
if caller_user_id is None and isinstance(decoded_token, dict):
|
||||
caller_user_id = decoded_token.get("sub")
|
||||
registry_user_id = model_user_id or caller_user_id
|
||||
provider = get_provider_from_model_id(
|
||||
compression_model, user_id=registry_user_id
|
||||
)
|
||||
api_key = get_api_key_for_provider(provider)
|
||||
|
||||
compression_llm = LLMCreator.create_llm(
|
||||
provider,
|
||||
api_key=api_key,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
model_id=compression_model,
|
||||
agent_id=conversation.get("agent_id"),
|
||||
model_user_id=registry_user_id,
|
||||
)
|
||||
|
||||
# Create compression service with DB update capability
|
||||
compression_service = CompressionService(
|
||||
llm=compression_llm,
|
||||
model_id=compression_model,
|
||||
conversation_service=self.conversation_service,
|
||||
)
|
||||
|
||||
# Compress all queries up to the latest
|
||||
queries_count = len(conversation.get("queries", []))
|
||||
compress_up_to = queries_count - 1
|
||||
|
||||
if compress_up_to < 0:
|
||||
logger.warning("No queries to compress")
|
||||
return CompressionResult.success_no_compression([])
|
||||
|
||||
logger.info(
|
||||
f"Initiating compression for conversation {conversation_id}: "
|
||||
f"compressing all {queries_count} queries (0-{compress_up_to})"
|
||||
)
|
||||
|
||||
# Perform compression and save to DB
|
||||
metadata = compression_service.compress_and_save(
|
||||
conversation_id, conversation, compress_up_to
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Compression successful - ratio: {metadata.compression_ratio:.1f}x, "
|
||||
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
|
||||
)
|
||||
|
||||
# Reload under caller (conversation is owned by caller).
|
||||
reload_user_id = caller_user_id
|
||||
if reload_user_id is None and isinstance(decoded_token, dict):
|
||||
reload_user_id = decoded_token.get("sub")
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
conversation_id, user_id=reload_user_id
|
||||
)
|
||||
|
||||
# Get compressed context
|
||||
compressed_summary, recent_queries = (
|
||||
compression_service.get_compressed_context(conversation)
|
||||
)
|
||||
|
||||
return CompressionResult.success_with_compression(
|
||||
compressed_summary, recent_queries, metadata
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing compression: {str(e)}", exc_info=True)
|
||||
return CompressionResult.failure(str(e))
|
||||
|
||||
def compress_mid_execution(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
current_conversation: Optional[Dict[str, Any]] = None,
|
||||
model_user_id: Optional[str] = None,
|
||||
) -> CompressionResult:
|
||||
"""
|
||||
Perform compression during tool execution.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
user_id: Caller's user id — used for conversation access checks
|
||||
model_id: Model ID
|
||||
decoded_token: User token
|
||||
current_conversation: Pre-loaded conversation (optional)
|
||||
model_user_id: BYOM-resolution scope (model owner). For
|
||||
shared-agent dispatch this is the agent owner; defaults
|
||||
to ``user_id`` so built-in / caller-owned models are
|
||||
unaffected.
|
||||
|
||||
Returns:
|
||||
CompressionResult
|
||||
"""
|
||||
try:
|
||||
# Load conversation if not provided
|
||||
if current_conversation:
|
||||
conversation = current_conversation
|
||||
else:
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
conversation_id, user_id
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
logger.warning(
|
||||
f"Could not load conversation {conversation_id} for mid-execution compression"
|
||||
)
|
||||
return CompressionResult.failure("Conversation not found")
|
||||
|
||||
# Perform compression
|
||||
return self._perform_compression(
|
||||
conversation_id,
|
||||
conversation,
|
||||
model_id,
|
||||
decoded_token,
|
||||
user_id=user_id,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in mid-execution compression: {str(e)}", exc_info=True
|
||||
)
|
||||
return CompressionResult.failure(str(e))
|
||||
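A minimal sketch of driving CompressionOrchestrator.compress_if_needed from calling code. The conversation id, user id, and model id are placeholders; ConversationService() with no arguments follows BaseAnswerResource above, and the attribute names read off CompressionResult are assumptions inferred from its factory methods, since types.py is not shown in this diff.

# Minimal sketch: asking the orchestrator to compress a conversation if needed.
from application.api.answer.services.compression import CompressionOrchestrator
from application.api.answer.services.conversation_service import ConversationService

orchestrator = CompressionOrchestrator(ConversationService())
result = orchestrator.compress_if_needed(
    conversation_id="65f0c0ffee00000000000001",  # placeholder id
    user_id="local",
    model_id="gpt-4o-mini",          # placeholder model id
    decoded_token={"sub": "local"},
    current_query_tokens=500,
)
# Attribute names below are assumed from CompressionResult's factory methods.
if getattr(result, "compressed_summary", None):
    print("History compressed; summary length:", len(result.compressed_summary))
else:
    print("No compression needed; queries kept:", len(getattr(result, "recent_queries", [])))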
application/api/answer/services/compression/prompt_builder.py (new file, 149 lines)
@@ -0,0 +1,149 @@
|
||||
"""Compression prompt building logic."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionPromptBuilder:
|
||||
"""Builds prompts for LLM compression calls."""
|
||||
|
||||
def __init__(self, version: str = "v1.0"):
|
||||
"""
|
||||
Initialize prompt builder.
|
||||
|
||||
Args:
|
||||
version: Prompt template version to use
|
||||
"""
|
||||
self.version = version
|
||||
self.system_prompt = self._load_prompt(version)
|
||||
|
||||
def _load_prompt(self, version: str) -> str:
|
||||
"""
|
||||
Load prompt template from file.
|
||||
|
||||
Args:
|
||||
version: Version string (e.g., 'v1.0')
|
||||
|
||||
Returns:
|
||||
Prompt template content
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If prompt template file doesn't exist
|
||||
"""
|
||||
current_dir = Path(__file__).resolve().parents[4]
|
||||
prompt_path = current_dir / "prompts" / "compression" / f"{version}.txt"
|
||||
|
||||
try:
|
||||
with open(prompt_path, "r") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Compression prompt template not found: {prompt_path}")
|
||||
raise FileNotFoundError(
|
||||
f"Compression prompt template '{version}' not found at {prompt_path}. "
|
||||
f"Please ensure the template file exists."
|
||||
)
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
queries: List[Dict[str, Any]],
|
||||
existing_compressions: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Build messages for compression LLM call.
|
||||
|
||||
Args:
|
||||
queries: List of query objects to compress
|
||||
existing_compressions: List of previous compression points
|
||||
|
||||
Returns:
|
||||
List of message dicts for LLM
|
||||
"""
|
||||
# Build conversation text
|
||||
conversation_text = self._format_conversation(queries)
|
||||
|
||||
# Add existing compression context if present
|
||||
existing_compression_context = ""
|
||||
if existing_compressions and len(existing_compressions) > 0:
|
||||
existing_compression_context = (
|
||||
"\n\nIMPORTANT: This conversation has been compressed before. "
|
||||
"Previous compression summaries:\n\n"
|
||||
)
|
||||
for i, comp in enumerate(existing_compressions):
|
||||
existing_compression_context += (
|
||||
f"--- Compression {i + 1} (up to message {comp.get('query_index', 'unknown')}) ---\n"
|
||||
f"{comp.get('compressed_summary', '')}\n\n"
|
||||
)
|
||||
existing_compression_context += (
|
||||
"Your task is to create a NEW summary that incorporates the context from "
|
||||
"previous compressions AND the new messages below. The final summary should "
|
||||
"be comprehensive and include all important information from both previous "
|
||||
"compressions and new messages.\n\n"
|
||||
)
|
||||
|
||||
user_prompt = (
|
||||
f"{existing_compression_context}"
|
||||
f"Here is the conversation to summarize:\n\n"
|
||||
f"{conversation_text}"
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": self.system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
def _format_conversation(self, queries: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Format conversation queries into readable text for compression.
|
||||
|
||||
Args:
|
||||
queries: List of query objects
|
||||
|
||||
Returns:
|
||||
Formatted conversation text
|
||||
"""
|
||||
conversation_lines = []
|
||||
|
||||
for i, query in enumerate(queries):
|
||||
conversation_lines.append(f"--- Message {i + 1} ---")
|
||||
conversation_lines.append(f"User: {query.get('prompt', '')}")
|
||||
|
||||
# Add tool calls if present
|
||||
tool_calls = query.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
conversation_lines.append("\nTool Calls:")
|
||||
for tc in tool_calls:
|
||||
tool_name = tc.get("tool_name", "unknown")
|
||||
action_name = tc.get("action_name", "unknown")
|
||||
arguments = tc.get("arguments", {})
|
||||
result = tc.get("result", "")
|
||||
if result is None:
|
||||
result = ""
|
||||
status = tc.get("status", "unknown")
|
||||
|
||||
# Include full tool result for complete compression context
|
||||
conversation_lines.append(
|
||||
f" - {tool_name}.{action_name}({arguments}) "
|
||||
f"[{status}] → {result}"
|
||||
)
|
||||
|
||||
# Add agent thought if present
|
||||
thought = query.get("thought", "")
|
||||
if thought:
|
||||
conversation_lines.append(f"\nAgent Thought: {thought}")
|
||||
|
||||
# Add assistant response
|
||||
conversation_lines.append(f"\nAssistant: {query.get('response', '')}")
|
||||
|
||||
# Add sources if present
|
||||
sources = query.get("sources", [])
|
||||
if sources:
|
||||
conversation_lines.append(f"\nSources Used: {len(sources)} documents")
|
||||
|
||||
conversation_lines.append("") # Empty line between messages
|
||||
|
||||
return "\n".join(conversation_lines)
|
||||
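A short standalone sketch of the builder above; the sample query is made up, and the chosen version must exist as application/prompts/compression/v1.0.txt:

    from application.api.answer.services.compression.prompt_builder import (
        CompressionPromptBuilder,
    )

    builder = CompressionPromptBuilder(version="v1.0")
    messages = builder.build_prompt(
        queries=[
            {
                "prompt": "What is DocsGPT?",
                "response": "An open-source RAG assistant.",
                "tool_calls": [],
            }
        ],
        existing_compressions=None,
    )
    # messages[0] is the system prompt loaded from the template file,
    # messages[1] holds the formatted conversation to be summarized.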
311
application/api/answer/services/compression/service.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""Core compression service with simplified responsibilities."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.api.answer.services.compression.prompt_builder import (
|
||||
CompressionPromptBuilder,
|
||||
)
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
from application.api.answer.services.compression.types import (
|
||||
CompressionMetadata,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionService:
|
||||
"""
|
||||
Service for compressing conversation history.
|
||||
|
||||
Persists compression metadata via the injected conversation_service when one is provided.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm,
|
||||
model_id: str,
|
||||
conversation_service=None,
|
||||
prompt_builder: Optional[CompressionPromptBuilder] = None,
|
||||
):
|
||||
"""
|
||||
Initialize compression service.
|
||||
|
||||
Args:
|
||||
llm: LLM instance to use for compression
|
||||
model_id: Model ID for compression
|
||||
conversation_service: Service for DB operations (optional, for DB updates)
|
||||
prompt_builder: Custom prompt builder (optional)
|
||||
"""
|
||||
self.llm = llm
|
||||
self.model_id = model_id
|
||||
self.conversation_service = conversation_service
|
||||
self.prompt_builder = prompt_builder or CompressionPromptBuilder(
|
||||
version=settings.COMPRESSION_PROMPT_VERSION
|
||||
)
|
||||
|
||||
def compress_conversation(
|
||||
self,
|
||||
conversation: Dict[str, Any],
|
||||
compress_up_to_index: int,
|
||||
) -> CompressionMetadata:
|
||||
"""
|
||||
Compress conversation history up to specified index.
|
||||
|
||||
Args:
|
||||
conversation: Full conversation document
|
||||
compress_up_to_index: Last query index to include in compression
|
||||
|
||||
Returns:
|
||||
CompressionMetadata with compression details
|
||||
|
||||
Raises:
|
||||
ValueError: If compress_up_to_index is invalid
|
||||
"""
|
||||
try:
|
||||
queries = conversation.get("queries", [])
|
||||
|
||||
if compress_up_to_index < 0 or compress_up_to_index >= len(queries):
|
||||
raise ValueError(
|
||||
f"Invalid compress_up_to_index: {compress_up_to_index} "
|
||||
f"(conversation has {len(queries)} queries)"
|
||||
)
|
||||
|
||||
# Get queries to compress
|
||||
queries_to_compress = queries[: compress_up_to_index + 1]
|
||||
|
||||
# Check if there are existing compressions
|
||||
existing_compressions = conversation.get("compression_metadata", {}).get(
|
||||
"compression_points", []
|
||||
)
|
||||
|
||||
if existing_compressions:
|
||||
logger.info(
|
||||
f"Found {len(existing_compressions)} previous compression(s) - "
|
||||
f"will incorporate into new summary"
|
||||
)
|
||||
|
||||
# Calculate original token count
|
||||
original_tokens = TokenCounter.count_query_tokens(queries_to_compress)
|
||||
|
||||
# Log tool call stats
|
||||
self._log_tool_call_stats(queries_to_compress)
|
||||
|
||||
# Build compression prompt
|
||||
messages = self.prompt_builder.build_prompt(
|
||||
queries_to_compress, existing_compressions
|
||||
)
|
||||
|
||||
# Call LLM to generate compression
|
||||
logger.info(
|
||||
f"Starting compression: {len(queries_to_compress)} queries "
|
||||
f"(messages 0-{compress_up_to_index}, {original_tokens} tokens) "
|
||||
f"using model {self.model_id}"
|
||||
)
|
||||
|
||||
# See note in conversation_service.py: ``self.model_id`` is
|
||||
# the registry id (UUID for BYOM); the LLM's own model_id is
|
||||
# what the provider's API actually expects.
|
||||
response = self.llm.gen(
|
||||
model=getattr(self.llm, "model_id", None) or self.model_id,
|
||||
messages=messages,
|
||||
max_tokens=4000,
|
||||
)
|
||||
|
||||
# Extract summary from response
|
||||
compressed_summary = self._extract_summary(response)
|
||||
|
||||
# Calculate compressed token count
|
||||
compressed_tokens = TokenCounter.count_message_tokens(
|
||||
[{"content": compressed_summary}]
|
||||
)
|
||||
|
||||
# Calculate compression ratio
|
||||
compression_ratio = (
|
||||
original_tokens / compressed_tokens if compressed_tokens > 0 else 0
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Compression complete: {original_tokens} → {compressed_tokens} tokens "
|
||||
f"({compression_ratio:.1f}x compression)"
|
||||
)
|
||||
|
||||
# Build compression metadata
|
||||
compression_metadata = CompressionMetadata(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
query_index=compress_up_to_index,
|
||||
compressed_summary=compressed_summary,
|
||||
original_token_count=original_tokens,
|
||||
compressed_token_count=compressed_tokens,
|
||||
compression_ratio=compression_ratio,
|
||||
model_used=self.model_id,
|
||||
compression_prompt_version=self.prompt_builder.version,
|
||||
)
|
||||
|
||||
return compression_metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error compressing conversation: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def compress_and_save(
|
||||
self,
|
||||
conversation_id: str,
|
||||
conversation: Dict[str, Any],
|
||||
compress_up_to_index: int,
|
||||
) -> CompressionMetadata:
|
||||
"""
|
||||
Compress conversation and save to database.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
conversation: Full conversation document
|
||||
compress_up_to_index: Last query index to include
|
||||
|
||||
Returns:
|
||||
CompressionMetadata
|
||||
|
||||
Raises:
|
||||
ValueError: If conversation_service not provided or invalid index
|
||||
"""
|
||||
if not self.conversation_service:
|
||||
raise ValueError(
|
||||
"conversation_service required for compress_and_save operation"
|
||||
)
|
||||
|
||||
# Perform compression
|
||||
metadata = self.compress_conversation(conversation, compress_up_to_index)
|
||||
|
||||
# Save to database
|
||||
self.conversation_service.update_compression_metadata(
|
||||
conversation_id, metadata.to_dict()
|
||||
)
|
||||
|
||||
logger.info(f"Compression metadata saved to database for {conversation_id}")
|
||||
|
||||
return metadata
|
||||
|
||||
def get_compressed_context(
|
||||
self, conversation: Dict[str, Any]
|
||||
) -> tuple[Optional[str], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Get compressed summary + recent uncompressed messages.
|
||||
|
||||
Args:
|
||||
conversation: Full conversation document
|
||||
|
||||
Returns:
|
||||
(compressed_summary, recent_messages)
|
||||
"""
|
||||
try:
|
||||
compression_metadata = conversation.get("compression_metadata", {})
|
||||
|
||||
if not compression_metadata.get("is_compressed"):
|
||||
logger.debug("No compression metadata found - using full history")
|
||||
queries = conversation.get("queries", [])
|
||||
if queries is None:
|
||||
logger.error("Conversation queries is None - returning empty list")
|
||||
return None, []
|
||||
return None, queries
|
||||
|
||||
compression_points = compression_metadata.get("compression_points", [])
|
||||
|
||||
if not compression_points:
|
||||
logger.debug("No compression points found - using full history")
|
||||
queries = conversation.get("queries", [])
|
||||
if queries is None:
|
||||
logger.error("Conversation queries is None - returning empty list")
|
||||
return None, []
|
||||
return None, queries
|
||||
|
||||
# Get the most recent compression point
|
||||
latest_compression = compression_points[-1]
|
||||
compressed_summary = latest_compression.get("compressed_summary")
|
||||
last_compressed_index = latest_compression.get("query_index")
|
||||
compressed_tokens = latest_compression.get("compressed_token_count", 0)
|
||||
original_tokens = latest_compression.get("original_token_count", 0)
|
||||
|
||||
# Get only messages after compression point
|
||||
queries = conversation.get("queries", [])
|
||||
total_queries = len(queries)
|
||||
recent_queries = queries[last_compressed_index + 1 :]
|
||||
|
||||
logger.info(
|
||||
f"Using compressed context: summary ({compressed_tokens} tokens, "
|
||||
f"compressed from {original_tokens}) + {len(recent_queries)} recent messages "
|
||||
f"(messages {last_compressed_index + 1}-{total_queries - 1})"
|
||||
)
|
||||
|
||||
return compressed_summary, recent_queries
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting compressed context: {str(e)}", exc_info=True
|
||||
)
|
||||
queries = conversation.get("queries", [])
|
||||
if queries is None:
|
||||
return None, []
|
||||
return None, queries
|
||||
|
||||
def _extract_summary(self, llm_response: str) -> str:
|
||||
"""
|
||||
Extract clean summary from LLM response.
|
||||
|
||||
Args:
|
||||
llm_response: Raw LLM response
|
||||
|
||||
Returns:
|
||||
Cleaned summary text
|
||||
"""
|
||||
try:
|
||||
# Try to extract content within <summary> tags
|
||||
summary_match = re.search(
|
||||
r"<summary>(.*?)</summary>", llm_response, re.DOTALL
|
||||
)
|
||||
|
||||
if summary_match:
|
||||
summary = summary_match.group(1).strip()
|
||||
else:
|
||||
# If no summary tags, remove analysis tags and use the rest
|
||||
summary = re.sub(
|
||||
r"<analysis>.*?</analysis>", "", llm_response, flags=re.DOTALL
|
||||
).strip()
|
||||
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error extracting summary: {str(e)}, using full response")
|
||||
return llm_response
|
||||
|
||||
def _log_tool_call_stats(self, queries: List[Dict[str, Any]]) -> None:
|
||||
"""Log statistics about tool calls in queries."""
|
||||
total_tool_calls = 0
|
||||
total_tool_result_chars = 0
|
||||
tool_call_breakdown = {}
|
||||
|
||||
for q in queries:
|
||||
for tc in q.get("tool_calls", []):
|
||||
total_tool_calls += 1
|
||||
tool_name = tc.get("tool_name", "unknown")
|
||||
action_name = tc.get("action_name", "unknown")
|
||||
key = f"{tool_name}.{action_name}"
|
||||
tool_call_breakdown[key] = tool_call_breakdown.get(key, 0) + 1
|
||||
|
||||
# Track total tool result size
|
||||
result = tc.get("result", "")
|
||||
if result:
|
||||
total_tool_result_chars += len(str(result))
|
||||
|
||||
if total_tool_calls > 0:
|
||||
tool_breakdown_str = ", ".join(
|
||||
f"{tool}({count})"
|
||||
for tool, count in sorted(tool_call_breakdown.items())
|
||||
)
|
||||
tool_result_kb = total_tool_result_chars / 1024
|
||||
logger.info(
|
||||
f"Tool call breakdown: {tool_breakdown_str} "
|
||||
f"(total result size: {tool_result_kb:.1f} KB, {total_tool_result_chars:,} chars)"
|
||||
)
|
||||
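A sketch of the end-to-end flow through CompressionService; llm, conversation_service and the conversation document are assumed to come from the surrounding answer pipeline, and the index is illustrative:

    service = CompressionService(
        llm=llm,
        model_id=model_id,
        conversation_service=conversation_service,
    )
    # Compress everything up to and including query index 10 and persist the point.
    metadata = service.compress_and_save(conversation_id, conversation, compress_up_to_index=10)
    # Later turns rebuild context as: latest summary + messages after that point.
    summary, recent = service.get_compressed_context(conversation)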
110
application/api/answer/services/compression/threshold_checker.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Compression threshold checking logic."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from application.core.model_utils import get_token_limit
|
||||
from application.core.settings import settings
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionThresholdChecker:
|
||||
"""Determines if compression is needed based on token thresholds."""
|
||||
|
||||
def __init__(self, threshold_percentage: float = None):
|
||||
"""
|
||||
Initialize threshold checker.
|
||||
|
||||
Args:
|
||||
threshold_percentage: Percentage of context to use as threshold
|
||||
(defaults to settings.COMPRESSION_THRESHOLD_PERCENTAGE)
|
||||
"""
|
||||
self.threshold_percentage = (
|
||||
threshold_percentage or settings.COMPRESSION_THRESHOLD_PERCENTAGE
|
||||
)
|
||||
|
||||
def should_compress(
|
||||
self,
|
||||
conversation: Dict[str, Any],
|
||||
model_id: str,
|
||||
current_query_tokens: int = 500,
|
||||
user_id: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if compression is needed.
|
||||
|
||||
Args:
|
||||
conversation: Full conversation document
|
||||
model_id: Target model for this request
|
||||
current_query_tokens: Estimated tokens for current query
|
||||
user_id: Owner — needed so per-user BYOM custom-model UUIDs
|
||||
resolve when looking up the context window.
|
||||
|
||||
Returns:
|
||||
True if tokens >= threshold% of context window
|
||||
"""
|
||||
try:
|
||||
# Calculate total tokens in conversation
|
||||
total_tokens = TokenCounter.count_conversation_tokens(conversation)
|
||||
total_tokens += current_query_tokens
|
||||
|
||||
# Get context window limit for model
|
||||
context_limit = get_token_limit(model_id, user_id=user_id)
|
||||
|
||||
# Calculate threshold
|
||||
threshold = int(context_limit * self.threshold_percentage)
|
||||
|
||||
compression_needed = total_tokens >= threshold
|
||||
percentage_used = (total_tokens / context_limit) * 100
|
||||
|
||||
if compression_needed:
|
||||
logger.warning(
|
||||
f"COMPRESSION TRIGGERED: {total_tokens} tokens / {context_limit} limit "
|
||||
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%)"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Compression check: {total_tokens}/{context_limit} tokens "
|
||||
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%) - No compression needed"
|
||||
)
|
||||
|
||||
return compression_needed
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
|
||||
return False
|
||||
|
||||
def check_message_tokens(
|
||||
self, messages: list, model_id: str, user_id: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if message list exceeds threshold.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts
|
||||
model_id: Target model
|
||||
user_id: Owner — needed so per-user BYOM custom-model UUIDs
|
||||
resolve when looking up the context window.
|
||||
|
||||
Returns:
|
||||
True if at or above threshold
|
||||
"""
|
||||
try:
|
||||
current_tokens = TokenCounter.count_message_tokens(messages)
|
||||
context_limit = get_token_limit(model_id, user_id=user_id)
|
||||
threshold = int(context_limit * self.threshold_percentage)
|
||||
|
||||
if current_tokens >= threshold:
|
||||
logger.warning(
|
||||
f"Message context limit approaching: {current_tokens}/{context_limit} tokens "
|
||||
f"({(current_tokens/context_limit)*100:.1f}%)"
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking message tokens: {str(e)}", exc_info=True)
|
||||
return False
|
||||
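The trigger condition above is simply total_tokens >= context_limit * threshold_percentage. A worked sketch with assumed numbers (a 128k-token context window and a 0.8 threshold):

    checker = CompressionThresholdChecker(threshold_percentage=0.8)
    # If get_token_limit(model_id, user_id=user_id) returns 128_000, the threshold
    # is 102_400 tokens, so a ~105_000-token conversation returns True here.
    needs_compression = checker.should_compress(
        conversation,
        model_id=model_id,
        current_query_tokens=500,
        user_id=user_id,
    )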
103
application/api/answer/services/compression/token_counter.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Token counting utilities for compression."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from application.utils import num_tokens_from_string
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenCounter:
|
||||
"""Centralized token counting for conversations and messages."""
|
||||
|
||||
@staticmethod
|
||||
def count_message_tokens(messages: List[Dict]) -> int:
|
||||
"""
|
||||
Calculate total tokens in a list of messages.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'content' field
|
||||
|
||||
Returns:
|
||||
Total token count
|
||||
"""
|
||||
total_tokens = 0
|
||||
for message in messages:
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, str):
|
||||
total_tokens += num_tokens_from_string(content)
|
||||
elif isinstance(content, list):
|
||||
# Handle structured content (tool calls, etc.)
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
total_tokens += num_tokens_from_string(str(item))
|
||||
return total_tokens
|
||||
|
||||
@staticmethod
|
||||
def count_query_tokens(
|
||||
queries: List[Dict[str, Any]], include_tool_calls: bool = True
|
||||
) -> int:
|
||||
"""
|
||||
Count tokens across multiple query objects.
|
||||
|
||||
Args:
|
||||
queries: List of query objects from conversation
|
||||
include_tool_calls: Whether to count tool call tokens
|
||||
|
||||
Returns:
|
||||
Total token count
|
||||
"""
|
||||
total_tokens = 0
|
||||
|
||||
for query in queries:
|
||||
# Count prompt and response tokens
|
||||
if "prompt" in query:
|
||||
total_tokens += num_tokens_from_string(query["prompt"])
|
||||
if "response" in query:
|
||||
total_tokens += num_tokens_from_string(query["response"])
|
||||
if "thought" in query:
|
||||
total_tokens += num_tokens_from_string(query.get("thought", ""))
|
||||
|
||||
# Count tool call tokens
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
tool_call_string = (
|
||||
f"Tool: {tool_call.get('tool_name')} | "
|
||||
f"Action: {tool_call.get('action_name')} | "
|
||||
f"Args: {tool_call.get('arguments')} | "
|
||||
f"Response: {tool_call.get('result')}"
|
||||
)
|
||||
total_tokens += num_tokens_from_string(tool_call_string)
|
||||
|
||||
return total_tokens
|
||||
|
||||
@staticmethod
|
||||
def count_conversation_tokens(
|
||||
conversation: Dict[str, Any], include_system_prompt: bool = False
|
||||
) -> int:
|
||||
"""
|
||||
Calculate total tokens in a conversation.
|
||||
|
||||
Args:
|
||||
conversation: Conversation document
|
||||
include_system_prompt: Whether to include system prompt in count
|
||||
|
||||
Returns:
|
||||
Total token count
|
||||
"""
|
||||
try:
|
||||
queries = conversation.get("queries", [])
|
||||
total_tokens = TokenCounter.count_query_tokens(queries)
|
||||
|
||||
# Add system prompt tokens if requested
|
||||
if include_system_prompt:
|
||||
# Rough estimate for system prompt
|
||||
total_tokens += settings.RESERVED_TOKENS.get("system_prompt", 500)
|
||||
|
||||
return total_tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating conversation tokens: {str(e)}")
|
||||
return 0
|
||||
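A small sketch of the counter on a hand-made query object; the values are illustrative only:

    queries = [
        {
            "prompt": "Summarize the release notes",
            "response": "Done.",
            "thought": "",
            "tool_calls": [
                {
                    "tool_name": "web",
                    "action_name": "search",
                    "arguments": {"q": "release notes"},
                    "result": "...",
                }
            ],
        }
    ]
    total = TokenCounter.count_query_tokens(queries)  # prompt + response + thought + tool calls
    without_tools = TokenCounter.count_query_tokens(queries, include_tool_calls=False)
    assert total >= without_tools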
83
application/api/answer/services/compression/types.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Type definitions for compression module."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompressionMetadata:
|
||||
"""Metadata about a compression operation."""
|
||||
|
||||
timestamp: datetime
|
||||
query_index: int
|
||||
compressed_summary: str
|
||||
original_token_count: int
|
||||
compressed_token_count: int
|
||||
compression_ratio: float
|
||||
model_used: str
|
||||
compression_prompt_version: str
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for DB storage."""
|
||||
return {
|
||||
"timestamp": self.timestamp,
|
||||
"query_index": self.query_index,
|
||||
"compressed_summary": self.compressed_summary,
|
||||
"original_token_count": self.original_token_count,
|
||||
"compressed_token_count": self.compressed_token_count,
|
||||
"compression_ratio": self.compression_ratio,
|
||||
"model_used": self.model_used,
|
||||
"compression_prompt_version": self.compression_prompt_version,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompressionResult:
|
||||
"""Result of a compression operation."""
|
||||
|
||||
success: bool
|
||||
compressed_summary: Optional[str] = None
|
||||
recent_queries: List[Dict[str, Any]] = field(default_factory=list)
|
||||
metadata: Optional[CompressionMetadata] = None
|
||||
error: Optional[str] = None
|
||||
compression_performed: bool = False
|
||||
|
||||
@classmethod
|
||||
def success_with_compression(
|
||||
cls, summary: str, queries: List[Dict], metadata: CompressionMetadata
|
||||
) -> "CompressionResult":
|
||||
"""Create a successful result with compression."""
|
||||
return cls(
|
||||
success=True,
|
||||
compressed_summary=summary,
|
||||
recent_queries=queries,
|
||||
metadata=metadata,
|
||||
compression_performed=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def success_no_compression(cls, queries: List[Dict]) -> "CompressionResult":
|
||||
"""Create a successful result without compression needed."""
|
||||
return cls(
|
||||
success=True,
|
||||
recent_queries=queries,
|
||||
compression_performed=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def failure(cls, error: str) -> "CompressionResult":
|
||||
"""Create a failure result."""
|
||||
return cls(success=False, error=error, compression_performed=False)
|
||||
|
||||
def as_history(self) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Convert recent queries to history format.
|
||||
|
||||
Returns:
|
||||
List of prompt/response dicts
|
||||
"""
|
||||
return [
|
||||
{"prompt": q["prompt"], "response": q["response"]}
|
||||
for q in self.recent_queries
|
||||
]
|
||||
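The dataclasses above give three result shapes; a sketch of how a caller might branch on them (the queries are assumed to carry prompt/response keys, as as_history requires):

    result = CompressionResult.success_no_compression(queries)
    if not result.success:
        raise RuntimeError(result.error)
    if result.compression_performed:
        summary, recent = result.compressed_summary, result.recent_queries
    else:
        summary, recent = None, result.as_history()  # plain prompt/response history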
157
application/api/answer/services/continuation_service.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""Service for saving and restoring tool-call continuation state.
|
||||
|
||||
When a stream pauses (tool needs approval or client-side execution),
|
||||
the full execution state is persisted to Postgres so the client can
|
||||
resume later by sending tool_actions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.pending_tool_state import (
|
||||
PendingToolStateRepository,
|
||||
)
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TTL for pending states — auto-cleaned after this period
|
||||
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _make_serializable(obj: Any) -> Any:
|
||||
"""Recursively coerce non-JSON values into JSON-safe forms.
|
||||
|
||||
Handles ``uuid.UUID`` (from PG columns), ``bytes``, and recurses into
|
||||
dicts/lists. Post-Mongo-cutover the ObjectId branch is gone — none of
|
||||
our writers produce them anymore.
|
||||
"""
|
||||
if isinstance(obj, UUID):
|
||||
return str(obj)
|
||||
if isinstance(obj, dict):
|
||||
return {str(k): _make_serializable(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_make_serializable(v) for v in obj]
|
||||
if isinstance(obj, bytes):
|
||||
return obj.decode("utf-8", errors="replace")
|
||||
return obj
|
||||
|
||||
|
||||
class ContinuationService:
|
||||
"""Manages pending tool-call state in Postgres."""
|
||||
|
||||
def __init__(self):
|
||||
# No-op constructor retained for call-site compatibility. State
|
||||
# lives in Postgres now; each operation opens its own short-lived
|
||||
# session rather than holding a connection on the service.
|
||||
pass
|
||||
|
||||
def save_state(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user: str,
|
||||
messages: List[Dict],
|
||||
pending_tool_calls: List[Dict],
|
||||
tools_dict: Dict,
|
||||
tool_schemas: List[Dict],
|
||||
agent_config: Dict,
|
||||
client_tools: Optional[List[Dict]] = None,
|
||||
) -> str:
|
||||
"""Save execution state for later continuation.
|
||||
|
||||
``conversation_id`` may be a Postgres UUID or the legacy Mongo
|
||||
``ObjectId`` string — the latter is resolved via
|
||||
``conversations.legacy_mongo_id`` to find the matching row.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation this state belongs to.
|
||||
user: Owner user ID.
|
||||
messages: Full messages array at the pause point.
|
||||
pending_tool_calls: Tool calls awaiting client action.
|
||||
tools_dict: Serializable tools configuration dict.
|
||||
tool_schemas: LLM-formatted tool schemas (agent.tools).
|
||||
agent_config: Config needed to recreate the agent on resume.
|
||||
client_tools: Client-provided tool schemas for client-side execution.
|
||||
|
||||
Returns:
|
||||
The string ID (conversation_id as provided) of the saved state.
|
||||
"""
|
||||
with db_session() as conn:
|
||||
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||
if conv is not None:
|
||||
pg_conv_id = conv["id"]
|
||||
elif looks_like_uuid(conversation_id):
|
||||
pg_conv_id = conversation_id
|
||||
else:
|
||||
# Unresolvable legacy ObjectId — downstream ``CAST AS uuid``
|
||||
# would raise and poison the save. Surface the mismatch so
|
||||
# the caller can decide (the stream loop in routes/base.py
|
||||
# already wraps this in try/except).
|
||||
raise ValueError(
|
||||
f"Cannot save continuation state: conversation_id "
|
||||
f"{conversation_id!r} is neither a PG UUID nor a "
|
||||
f"backfilled legacy Mongo id."
|
||||
)
|
||||
PendingToolStateRepository(conn).save_state(
|
||||
pg_conv_id,
|
||||
user,
|
||||
messages=_make_serializable(messages),
|
||||
pending_tool_calls=_make_serializable(pending_tool_calls),
|
||||
tools_dict=_make_serializable(tools_dict),
|
||||
tool_schemas=_make_serializable(tool_schemas),
|
||||
agent_config=_make_serializable(agent_config),
|
||||
client_tools=_make_serializable(client_tools) if client_tools else None,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Saved continuation state for conversation {conversation_id} "
|
||||
f"with {len(pending_tool_calls)} pending tool call(s)"
|
||||
)
|
||||
return conversation_id
|
||||
|
||||
def load_state(
|
||||
self, conversation_id: str, user: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Load pending continuation state.
|
||||
|
||||
Returns:
|
||||
The state dict, or None if no pending state exists.
|
||||
"""
|
||||
with db_readonly() as conn:
|
||||
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||
if conv is not None:
|
||||
pg_conv_id = conv["id"]
|
||||
elif looks_like_uuid(conversation_id):
|
||||
pg_conv_id = conversation_id
|
||||
else:
|
||||
# Unresolvable legacy ObjectId → no state can exist for it.
|
||||
return None
|
||||
doc = PendingToolStateRepository(conn).load_state(pg_conv_id, user)
|
||||
if not doc:
|
||||
return None
|
||||
return doc
|
||||
|
||||
def delete_state(self, conversation_id: str, user: str) -> bool:
|
||||
"""Delete pending state after successful resumption.
|
||||
|
||||
Returns:
|
||||
True if a row was deleted.
|
||||
"""
|
||||
with db_session() as conn:
|
||||
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||
if conv is not None:
|
||||
pg_conv_id = conv["id"]
|
||||
elif looks_like_uuid(conversation_id):
|
||||
pg_conv_id = conversation_id
|
||||
else:
|
||||
# Unresolvable legacy ObjectId → nothing to delete.
|
||||
return False
|
||||
deleted = PendingToolStateRepository(conn).delete_state(pg_conv_id, user)
|
||||
if deleted:
|
||||
logger.info(
|
||||
f"Deleted continuation state for conversation {conversation_id}"
|
||||
)
|
||||
return deleted
|
||||
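A round-trip sketch for the pause/resume flow; the payload variables are placeholders that the real stream loop supplies:

    svc = ContinuationService()
    svc.save_state(
        conversation_id=conv_id,
        user=user_id,
        messages=messages,                  # full message array at the pause point
        pending_tool_calls=pending_calls,   # tool calls awaiting approval or client execution
        tools_dict=tools_dict,
        tool_schemas=tool_schemas,
        agent_config=agent_config,
    )
    # On resume: restore the state, run the approved tools, then clear the row.
    state = svc.load_state(conv_id, user_id)
    if state is not None:
        ...  # recreate the agent from state["agent_config"] and continue streaming
        svc.delete_state(conv_id, user_id)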
286
application/api/answer/services/conversation_service.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""Conversation persistence service backed by Postgres.
|
||||
|
||||
Handles create / append / update / compression for conversations during
|
||||
the answer-streaming path. Connections are opened per-operation rather
|
||||
than held for the duration of a stream.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationService:
|
||||
def get_conversation(
|
||||
self, conversation_id: str, user_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieve a conversation with owner-or-shared access control.
|
||||
|
||||
Returns a dict in the legacy Mongo shape — ``queries`` is a list
|
||||
of message dicts (prompt/response/...) — for compatibility with
|
||||
the streaming pipeline that consumes this shape.
|
||||
"""
|
||||
if not conversation_id or not user_id:
|
||||
return None
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_any(conversation_id, user_id)
|
||||
if conv is None:
|
||||
logger.warning(
|
||||
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
|
||||
)
|
||||
return None
|
||||
messages = repo.get_messages(str(conv["id"]))
|
||||
conv["queries"] = messages
|
||||
conv["_id"] = str(conv["id"])
|
||||
return conv
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
def save_conversation(
|
||||
self,
|
||||
conversation_id: Optional[str],
|
||||
question: str,
|
||||
response: str,
|
||||
thought: str,
|
||||
sources: List[Dict[str, Any]],
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
llm: Any,
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
index: Optional[int] = None,
|
||||
api_key: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
attachment_ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""Save or update a conversation in Postgres.
|
||||
|
||||
Returns the string conversation id (PG UUID as string, or the
|
||||
caller-provided id if it was already a UUID).
|
||||
"""
|
||||
if decoded_token is None:
|
||||
raise ValueError("Invalid or missing authentication token")
|
||||
user_id = decoded_token.get("sub")
|
||||
if not user_id:
|
||||
raise ValueError("User ID not found in token")
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# Trim huge inline source text to a reasonable max before persist.
|
||||
for source in sources:
|
||||
if "text" in source and isinstance(source["text"], str):
|
||||
source["text"] = source["text"][:1000]
|
||||
|
||||
message_payload = {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
"timestamp": current_time,
|
||||
}
|
||||
if metadata:
|
||||
message_payload["metadata"] = metadata
|
||||
|
||||
if conversation_id is not None and index is not None:
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_any(conversation_id, user_id)
|
||||
if conv is None:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
conv_pg_id = str(conv["id"])
|
||||
repo.update_message_at(conv_pg_id, index, message_payload)
|
||||
repo.truncate_after(conv_pg_id, index)
|
||||
return conversation_id
|
||||
elif conversation_id:
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_any(conversation_id, user_id)
|
||||
if conv is None:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
conv_pg_id = str(conv["id"])
|
||||
# append_message expects 'metadata' key either way; normalise.
|
||||
append_payload = dict(message_payload)
|
||||
append_payload.setdefault("metadata", metadata or {})
|
||||
repo.append_message(conv_pg_id, append_payload)
|
||||
return conversation_id
|
||||
else:
|
||||
messages_summary = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that creates concise conversation titles. "
|
||||
"Summarize conversations in 3 words or less using the same language as the user.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Summarise following conversation in no more than 3 words, "
|
||||
"respond ONLY with the summary, use the same language as the "
|
||||
"user query \n\nUser: " + question + "\n\n" + "AI: " + response,
|
||||
},
|
||||
]
|
||||
|
||||
# ``model_id`` here is the registry id (a UUID for BYOM
|
||||
# records). The LLM's own ``model_id`` is the upstream name
|
||||
# LLMCreator resolved at construction time — that's what
|
||||
# the provider's API expects. Built-ins are unaffected.
|
||||
completion = llm.gen(
|
||||
model=getattr(llm, "model_id", None) or model_id,
|
||||
messages=messages_summary,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
if not completion or not completion.strip():
|
||||
completion = question[:50] if question else "New Conversation"
|
||||
|
||||
resolved_api_key: Optional[str] = None
|
||||
resolved_agent_id: Optional[str] = None
|
||||
if api_key:
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
if agent:
|
||||
resolved_api_key = agent.get("key")
|
||||
if agent_id:
|
||||
resolved_agent_id = agent_id
|
||||
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.create(
|
||||
user_id,
|
||||
completion,
|
||||
agent_id=resolved_agent_id,
|
||||
api_key=resolved_api_key,
|
||||
is_shared_usage=bool(resolved_agent_id and is_shared_usage),
|
||||
shared_token=(
|
||||
shared_token
|
||||
if (resolved_agent_id and is_shared_usage)
|
||||
else None
|
||||
),
|
||||
)
|
||||
conv_pg_id = str(conv["id"])
|
||||
append_payload = dict(message_payload)
|
||||
append_payload.setdefault("metadata", metadata or {})
|
||||
repo.append_message(conv_pg_id, append_payload)
|
||||
return conv_pg_id
|
||||
|
||||
def update_compression_metadata(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Persist compression flags and append a compression point.
|
||||
|
||||
Mirrors the Mongo-era ``$set`` + ``$push $slice`` on
|
||||
``compression_metadata`` but goes through the PG repo API.
|
||||
"""
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
# conversation_id here comes from the streaming pipeline
|
||||
# which has already resolved it; accept either UUID or
|
||||
# legacy id for safety.
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
conv_pg_id = (
|
||||
str(conv["id"]) if conv is not None else conversation_id
|
||||
)
|
||||
repo.set_compression_flags(
|
||||
conv_pg_id,
|
||||
is_compressed=True,
|
||||
last_compression_at=compression_metadata.get("timestamp"),
|
||||
)
|
||||
repo.append_compression_point(
|
||||
conv_pg_id,
|
||||
compression_metadata,
|
||||
max_points=settings.COMPRESSION_MAX_HISTORY_POINTS,
|
||||
)
|
||||
logger.info(
|
||||
f"Updated compression metadata for conversation {conversation_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error updating compression metadata: {str(e)}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def append_compression_message(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Append a synthetic compression summary message to the conversation."""
|
||||
try:
|
||||
summary = compression_metadata.get("compressed_summary", "")
|
||||
if not summary:
|
||||
return
|
||||
timestamp = compression_metadata.get(
|
||||
"timestamp", datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
conv_pg_id = (
|
||||
str(conv["id"]) if conv is not None else conversation_id
|
||||
)
|
||||
repo.append_message(conv_pg_id, {
|
||||
"prompt": "[Context Compression Summary]",
|
||||
"response": summary,
|
||||
"thought": "",
|
||||
"sources": [],
|
||||
"tool_calls": [],
|
||||
"attachments": [],
|
||||
"model_id": compression_metadata.get("model_used"),
|
||||
"timestamp": timestamp,
|
||||
})
|
||||
logger.info(
|
||||
f"Appended compression summary to conversation {conversation_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error appending compression summary: {str(e)}", exc_info=True
|
||||
)
|
||||
|
||||
def get_compression_metadata(
|
||||
self, conversation_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Fetch the stored compression metadata JSONB blob for a conversation."""
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
# Fallback to UUID lookup without user scoping — the
|
||||
# caller already holds an authenticated conversation
|
||||
# id from the streaming path. Gate on id shape so a
|
||||
# non-UUID (legacy ObjectId that wasn't backfilled)
|
||||
# doesn't reach CAST — the cast raises and spams the
|
||||
# logs with a stack trace on every call.
|
||||
if not looks_like_uuid(conversation_id):
|
||||
return None
|
||||
result = conn.execute(
|
||||
sql_text(
|
||||
"SELECT compression_metadata FROM conversations "
|
||||
"WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
{"id": conversation_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row[0] if row is not None else None
|
||||
return conv.get("compression_metadata") if conv else None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting compression metadata: {str(e)}", exc_info=True
|
||||
)
|
||||
return None
|
||||
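save_conversation above has three branches: update at an index, append to an existing conversation, and create with an LLM-generated title. A sketch of the create path, with all arguments assumed to come from the streaming pipeline:

    conversation_service = ConversationService()
    new_id = conversation_service.save_conversation(
        conversation_id=None,   # None triggers create + a 3-word LLM-generated title
        question=question,
        response=answer,
        thought="",
        sources=sources,
        tool_calls=[],
        llm=llm,
        model_id=model_id,
        decoded_token=decoded_token,
    )
    # Later turns pass new_id back in; passing index as well rewrites that message
    # and truncates everything after it (the edit-and-regenerate case).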
97
application/api/answer/services/prompt_renderer.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.templates.namespaces import NamespaceManager
|
||||
|
||||
from application.templates.template_engine import TemplateEngine, TemplateRenderError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptRenderer:
|
||||
"""Service for rendering prompts with dynamic context using namespaces"""
|
||||
|
||||
def __init__(self):
|
||||
self.template_engine = TemplateEngine()
|
||||
self.namespace_manager = NamespaceManager()
|
||||
|
||||
def render_prompt(
|
||||
self,
|
||||
prompt_content: str,
|
||||
user_id: Optional[str] = None,
|
||||
request_id: Optional[str] = None,
|
||||
passthrough_data: Optional[Dict[str, Any]] = None,
|
||||
docs: Optional[list] = None,
|
||||
docs_together: Optional[str] = None,
|
||||
tools_data: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Render prompt with full context from all namespaces.
|
||||
|
||||
Args:
|
||||
prompt_content: Raw prompt template string
|
||||
user_id: Current user identifier
|
||||
request_id: Unique request identifier
|
||||
passthrough_data: Parameters from web request
|
||||
docs: RAG retrieved documents
|
||||
docs_together: Concatenated document content
|
||||
tools_data: Pre-fetched tool results organized by tool name
|
||||
**kwargs: Additional parameters for namespace builders
|
||||
|
||||
Returns:
|
||||
Rendered prompt string with all variables substituted
|
||||
|
||||
Raises:
|
||||
TemplateRenderError: If template rendering fails
|
||||
"""
|
||||
if not prompt_content:
|
||||
return ""
|
||||
|
||||
uses_template = self._uses_template_syntax(prompt_content)
|
||||
|
||||
if not uses_template:
|
||||
return self._apply_legacy_substitutions(prompt_content, docs_together)
|
||||
|
||||
try:
|
||||
context = self.namespace_manager.build_context(
|
||||
user_id=user_id,
|
||||
request_id=request_id,
|
||||
passthrough_data=passthrough_data,
|
||||
docs=docs,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self.template_engine.render(prompt_content, context)
|
||||
except TemplateRenderError:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Prompt rendering failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise TemplateRenderError(error_msg) from e
|
||||
|
||||
def _uses_template_syntax(self, prompt_content: str) -> bool:
|
||||
"""Check if prompt uses Jinja2 template syntax"""
|
||||
return "{{" in prompt_content and "}}" in prompt_content
|
||||
|
||||
def _apply_legacy_substitutions(
|
||||
self, prompt_content: str, docs_together: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Apply backward-compatible substitutions for old prompt format.
|
||||
|
||||
Currently substitutes the legacy {summaries} placeholder during the transition period.
|
||||
"""
|
||||
if docs_together:
|
||||
prompt_content = prompt_content.replace("{summaries}", docs_together)
|
||||
return prompt_content
|
||||
|
||||
def validate_template(self, prompt_content: str) -> bool:
|
||||
"""Validate prompt template syntax"""
|
||||
return self.template_engine.validate_template(prompt_content)
|
||||
|
||||
def extract_variables(self, prompt_content: str) -> set[str]:
|
||||
"""Extract all variable names from prompt template"""
|
||||
return self.template_engine.extract_variables(prompt_content)
|
||||
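A sketch of the two rendering paths; whether a given variable (here docs_together) is actually exposed to templates depends on NamespaceManager, which is not shown in this diff:

    renderer = PromptRenderer()
    # Jinja2 path: "{{ ... }}" is detected, so the namespace context is built and rendered.
    templated = renderer.render_prompt(
        "Answer using these docs:\n{{ docs_together }}",  # variable name assumed
        user_id=user_id,
        docs_together=docs_text,
    )
    # Legacy path: no "{{ }}", so only the old {summaries} placeholder is substituted.
    legacy = renderer.render_prompt("Context: {summaries}", docs_together=docs_text)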
1177
application/api/answer/services/stream_processor.py
Normal file
File diff suppressed because it is too large
551
application/api/connector/routes.py
Normal file
@@ -0,0 +1,551 @@
|
||||
import base64
|
||||
import html
|
||||
import json
|
||||
import uuid
|
||||
from urllib.parse import urlencode
|
||||
|
||||
|
||||
from flask import (
|
||||
Blueprint,
|
||||
current_app,
|
||||
jsonify,
|
||||
make_response,
|
||||
request
|
||||
)
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.tasks import (
|
||||
ingest_connector_task,
|
||||
)
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.storage.db.repositories.connector_sessions import (
|
||||
ConnectorSessionsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
connector = Blueprint("connector", __name__)
|
||||
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
|
||||
api.add_namespace(connectors_ns)
|
||||
|
||||
# Fixed callback status path to prevent open redirect
|
||||
CALLBACK_STATUS_PATH = "/api/connectors/callback-status"
|
||||
|
||||
|
||||
def build_callback_redirect(params: dict) -> str:
|
||||
"""Build a safe redirect URL to the callback status page.
|
||||
|
||||
Uses a fixed path and properly URL-encodes all parameters
|
||||
to prevent URL injection and open redirect vulnerabilities.
|
||||
"""
|
||||
return f"{CALLBACK_STATUS_PATH}?{urlencode(params)}"
|
||||
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/auth")
|
||||
class ConnectorAuth(Resource):
|
||||
@api.doc(description="Get connector OAuth authorization URL", params={"provider": "Connector provider (e.g., google_drive)"})
|
||||
def get(self):
|
||||
try:
|
||||
provider = request.args.get('provider') or request.args.get('source')
|
||||
if not provider:
|
||||
return make_response(jsonify({"success": False, "error": "Missing provider"}), 400)
|
||||
|
||||
if not ConnectorCreator.is_supported(provider):
|
||||
return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user_id = decoded_token.get('sub')
|
||||
|
||||
with db_session() as conn:
|
||||
session_row = ConnectorSessionsRepository(conn).upsert(
|
||||
user_id, provider, status="pending",
|
||||
)
|
||||
session_pg_id = str(session_row["id"])
|
||||
state_dict = {
|
||||
"provider": provider,
|
||||
"object_id": session_pg_id,
|
||||
}
|
||||
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
|
||||
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
authorization_url = auth.get_authorization_url(state=state)
|
||||
return make_response(jsonify({
|
||||
"success": True,
|
||||
"authorization_url": authorization_url,
|
||||
"state": state
|
||||
}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error generating connector auth URL: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to generate authorization URL"}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/callback")
|
||||
class ConnectorsCallback(Resource):
|
||||
@api.doc(description="Handle OAuth callback for external connectors")
|
||||
def get(self):
|
||||
"""Handle OAuth callback for external connectors"""
|
||||
try:
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from flask import request, redirect
|
||||
|
||||
authorization_code = request.args.get('code')
|
||||
state = request.args.get('state')
|
||||
error = request.args.get('error')
|
||||
|
||||
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
provider = state_dict.get("provider")
|
||||
state_object_id = state_dict.get("object_id")
|
||||
|
||||
# Validate provider
|
||||
if not provider or not isinstance(provider, str) or not ConnectorCreator.is_supported(provider):
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Invalid provider"
|
||||
}))
|
||||
|
||||
if error:
|
||||
if error == "access_denied":
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "cancelled",
|
||||
"message": "Authentication was cancelled. You can try again if you'd like to connect your account.",
|
||||
"provider": provider
|
||||
}))
|
||||
else:
|
||||
current_app.logger.warning(f"OAuth error in callback: {error}")
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
|
||||
"provider": provider
|
||||
}))
|
||||
|
||||
if not authorization_code:
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
|
||||
"provider": provider
|
||||
}))
|
||||
|
||||
try:
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
token_info = auth.exchange_code_for_tokens(authorization_code)
|
||||
|
||||
session_token = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
if provider == "google_drive":
|
||||
credentials = auth.create_credentials_from_token_info(token_info)
|
||||
service = auth.build_drive_service(credentials)
|
||||
user_info = service.about().get(fields="user").execute()
|
||||
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
|
||||
else:
|
||||
user_email = token_info.get('user_info', {}).get('email', 'Connected User')
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.warning(f"Could not get user info: {e}")
|
||||
user_email = 'Connected User'
|
||||
|
||||
sanitized_token_info = auth.sanitize_token_info(token_info)
|
||||
|
||||
# ``object_id`` in the OAuth state is the PG session row
|
||||
# UUID (new flow) or a legacy Mongo ObjectId (pre-cutover
|
||||
# issued state). Try UUID update first; fall back to
|
||||
# legacy id path.
|
||||
patch = {
|
||||
"session_token": session_token,
|
||||
"token_info": sanitized_token_info,
|
||||
"user_email": user_email,
|
||||
"status": "authorized",
|
||||
}
|
||||
with db_session() as conn:
|
||||
repo = ConnectorSessionsRepository(conn)
|
||||
if state_object_id:
|
||||
value = str(state_object_id)
|
||||
updated = False
|
||||
if len(value) == 36 and "-" in value:
|
||||
updated = repo.update(value, patch)
|
||||
if not updated:
|
||||
repo.update_by_legacy_id(value, patch)
|
||||
|
||||
# Redirect to success page with session token and user email
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "success",
|
||||
"message": "Authentication successful",
|
||||
"provider": provider,
|
||||
"session_token": session_token,
|
||||
"user_email": user_email
|
||||
}))
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
|
||||
"provider": provider
|
||||
}))
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error handling connector callback: {e}")
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions."
|
||||
}))
|
||||
|
||||
|
||||
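The OAuth state round-trip used by the /auth and /callback handlers above is URL-safe base64 over a small JSON object; a standalone sketch:

    import base64
    import json

    state_dict = {"provider": "google_drive", "object_id": session_pg_id}
    state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
    # ...the provider redirects back with ?state=...&code=..., and the callback reverses it:
    decoded = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
    assert decoded["provider"] == "google_drive"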
@connectors_ns.route("/api/connectors/files")
|
||||
class ConnectorFiles(Resource):
|
||||
@api.expect(api.model("ConnectorFilesModel", {
|
||||
"provider": fields.String(required=True),
|
||||
"session_token": fields.String(required=True),
|
||||
"folder_id": fields.String(required=False),
|
||||
"limit": fields.Integer(required=False),
|
||||
"page_token": fields.String(required=False),
|
||||
"search_query": fields.String(required=False),
|
||||
}))
|
||||
@api.doc(description="List files from a connector provider (supports pagination and search)")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
session_token = data.get('session_token')
|
||||
limit = data.get('limit', 10)
|
||||
|
||||
if not provider or not session_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user = decoded_token.get('sub')
|
||||
with db_readonly() as conn:
|
||||
session = ConnectorSessionsRepository(conn).get_by_session_token(
|
||||
session_token,
|
||||
)
|
||||
if not session or session.get("user_id") != user:
|
||||
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
|
||||
|
||||
loader = ConnectorCreator.create_connector(provider, session_token)
|
||||
|
||||
generic_keys = {'provider', 'session_token'}
|
||||
input_config = {
|
||||
k: v for k, v in data.items() if k not in generic_keys
|
||||
}
|
||||
input_config['list_only'] = True
|
||||
|
||||
documents = loader.load_data(input_config)
|
||||
|
||||
files = []
|
||||
for doc in documents[:limit]:
|
||||
metadata = doc.extra_info
|
||||
modified_time = metadata.get('modified_time')
|
||||
if modified_time:
|
||||
date_part = modified_time.split('T')[0]
|
||||
time_part = modified_time.split('T')[1].split('.')[0].split('Z')[0]
|
||||
formatted_time = f"{date_part} {time_part}"
|
||||
else:
|
||||
formatted_time = None
|
||||
|
||||
files.append({
|
||||
'id': doc.doc_id,
|
||||
'name': metadata.get('file_name', 'Unknown File'),
|
||||
'type': metadata.get('mime_type', 'unknown'),
|
||||
'size': metadata.get('size', None),
|
||||
'modifiedTime': formatted_time,
|
||||
'isFolder': metadata.get('is_folder', False)
|
||||
})
|
||||
|
||||
next_token = getattr(loader, 'next_page_token', None)
|
||||
has_more = bool(next_token)
|
||||
|
||||
return make_response(jsonify({
|
||||
"success": True,
|
||||
"files": files,
|
||||
"total": len(files),
|
||||
"next_page_token": next_token,
|
||||
"has_more": has_more
|
||||
}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error loading connector files: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to load files"}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/validate-session")
|
||||
class ConnectorValidateSession(Resource):
|
||||
@api.expect(api.model("ConnectorValidateSessionModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True)}))
|
||||
@api.doc(description="Validate connector session token and return user info and access token")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
session_token = data.get('session_token')
|
||||
if not provider or not session_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user = decoded_token.get('sub')
|
||||
|
||||
with db_readonly() as conn:
|
||||
session = ConnectorSessionsRepository(conn).get_by_session_token(
|
||||
session_token,
|
||||
)
|
||||
if not session or session.get("user_id") != user or not session.get("token_info"):
|
||||
return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
|
||||
|
||||
token_info = session["token_info"]
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
is_expired = auth.is_token_expired(token_info)
|
||||
|
||||
if is_expired and token_info.get('refresh_token'):
|
||||
try:
|
||||
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
|
||||
sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
|
||||
with db_session() as conn:
|
||||
repo = ConnectorSessionsRepository(conn)
|
||||
row = repo.get_by_session_token(session_token)
|
||||
if row:
|
||||
repo.update(str(row["id"]), {"token_info": sanitized_token_info})
|
||||
token_info = sanitized_token_info
|
||||
is_expired = False
|
||||
except Exception as refresh_error:
|
||||
current_app.logger.error(f"Failed to refresh token: {refresh_error}")
|
||||
|
||||
if is_expired:
|
||||
return make_response(jsonify({
|
||||
"success": False,
|
||||
"expired": True,
|
||||
"error": "Session token has expired. Please reconnect."
|
||||
}), 401)
|
||||
|
||||
_base_fields = {"access_token", "refresh_token", "token_uri", "expiry"}
|
||||
provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields}
|
||||
|
||||
response_data = {
|
||||
"success": True,
|
||||
"expired": False,
|
||||
"user_email": session.get('user_email', 'Connected User'),
|
||||
"access_token": token_info.get('access_token'),
|
||||
**provider_extras,
|
||||
}
|
||||
|
||||
return make_response(jsonify(response_data), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error validating connector session: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500)
|
||||
|
||||
|
||||
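The validation path above relies on `auth.is_token_expired(token_info)` and refreshes only when a refresh token is stored. One plausible way such a check could work, given that `token_info` carries an `expiry` field, is a simple clock comparison with a small safety skew; this is an illustrative sketch only, not the actual ConnectorCreator auth implementation:

# Illustrative only: expiry decision from token_info["expiry"]; real provider auth classes may differ.
import datetime

def is_token_expired(token_info: dict, skew_seconds: int = 60) -> bool:
    expiry = token_info.get("expiry")
    if not expiry:
        return True  # no expiry recorded: treat as expired and force a refresh
    expires_at = datetime.datetime.fromisoformat(expiry)
    if expires_at.tzinfo is None:
        expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
    now = datetime.datetime.now(datetime.timezone.utc)
    return now >= expires_at - datetime.timedelta(seconds=skew_seconds)
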
@connectors_ns.route("/api/connectors/disconnect")
|
||||
class ConnectorDisconnect(Resource):
|
||||
@api.expect(api.model("ConnectorDisconnectModel", {"provider": fields.String(required=True), "session_token": fields.String(required=False)}))
|
||||
@api.doc(description="Disconnect a connector session")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
session_token = data.get('session_token')
|
||||
if not provider:
|
||||
return make_response(jsonify({"success": False, "error": "provider is required"}), 400)
|
||||
|
||||
|
||||
if session_token:
|
||||
with db_session() as conn:
|
||||
ConnectorSessionsRepository(conn).delete_by_session_token(
|
||||
session_token,
|
||||
)
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error disconnecting connector session: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to disconnect session"}), 500)
|
||||
|
||||
|
||||
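Disconnect is idempotent: the handler returns 200 whether or not a session row existed, and only deletes when a `session_token` is supplied. A minimal call sketch (URL, JWT and token values are placeholders):

# Illustrative only: revoke a stored connector session.
import requests

requests.post(
    "http://localhost:7091/api/connectors/disconnect",
    json={"provider": "google_drive", "session_token": "<session-token>"},
    headers={"Authorization": "Bearer <jwt>"},
).raise_for_status()
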
@connectors_ns.route("/api/connectors/sync")
|
||||
class ConnectorSync(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ConnectorSyncModel",
|
||||
{
|
||||
"source_id": fields.String(required=True, description="Source ID to sync"),
|
||||
"session_token": fields.String(required=True, description="Authentication token")
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Sync connector source to check for modifications")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
|
||||
try:
|
||||
data = request.get_json()
|
||||
source_id = data.get('source_id')
|
||||
session_token = data.get('session_token')
|
||||
|
||||
if not all([source_id, session_token]):
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "source_id and session_token are required"
|
||||
}),
|
||||
400
|
||||
)
|
||||
user_id = decoded_token.get('sub')
|
||||
with db_readonly() as conn:
|
||||
source = SourcesRepository(conn).get_any(source_id, user_id)
|
||||
if not source:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Source not found"
|
||||
}),
|
||||
404
|
||||
)
|
||||
|
||||
# ``get_any`` already scopes by ``user_id``; an extra guard
|
||||
# here would be dead code.
|
||||
|
||||
remote_data = source.get('remote_data') or {}
|
||||
if isinstance(remote_data, str):
|
||||
try:
|
||||
remote_data = json.loads(remote_data)
|
||||
except json.JSONDecodeError:
|
||||
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
|
||||
remote_data = {}
|
||||
|
||||
source_type = remote_data.get('provider')
|
||||
if not source_type:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Source provider not found in remote_data"
|
||||
}),
|
||||
400
|
||||
)
|
||||
|
||||
# Extract configuration from remote_data
|
||||
file_ids = remote_data.get('file_ids', [])
|
||||
folder_ids = remote_data.get('folder_ids', [])
|
||||
recursive = remote_data.get('recursive', True)
|
||||
|
||||
# Start the sync task
|
||||
task = ingest_connector_task.delay(
|
||||
job_name=source.get('name'),
|
||||
user=decoded_token.get('sub'),
|
||||
source_type=source_type,
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=recursive,
|
||||
retriever=source.get('retriever', 'classic'),
|
||||
operation_mode="sync",
|
||||
doc_id=str(source.get('id') or source_id),
|
||||
sync_frequency=source.get('sync_frequency', 'never')
|
||||
)
|
||||
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": True,
|
||||
"task_id": task.id
|
||||
}),
|
||||
200
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error syncing connector source: {err}",
|
||||
exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Failed to sync connector source"
|
||||
}),
|
||||
400
|
||||
)
|
||||
|
||||
|
||||
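The sync route only validates the request and then hands everything to the `ingest_connector_task` Celery task, returning its id immediately. A short client sketch of triggering a re-sync (base URL, JWT and ids are placeholders; how the caller polls the task afterwards is outside this diff):

# Illustrative only: kick off a re-sync of a connector-backed source.
import requests

resp = requests.post(
    "http://localhost:7091/api/connectors/sync",
    json={"source_id": "<source-uuid>", "session_token": "<session-token>"},
    headers={"Authorization": "Bearer <jwt>"},
)
resp.raise_for_status()
task_id = resp.json()["task_id"]  # track progress through whatever task-status mechanism the app exposes
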
@connectors_ns.route("/api/connectors/callback-status")
|
||||
class ConnectorCallbackStatus(Resource):
|
||||
@api.doc(description="Return HTML page with connector authentication status")
|
||||
def get(self):
|
||||
"""Return HTML page with connector authentication status"""
|
||||
try:
|
||||
# Validate and sanitize status to a known value
|
||||
status_raw = request.args.get('status', 'error')
|
||||
status = status_raw if status_raw in ('success', 'error', 'cancelled') else 'error'
|
||||
|
||||
# Escape all user-controlled values for HTML context
|
||||
message = html.escape(request.args.get('message', ''))
|
||||
provider_raw = request.args.get('provider', 'connector')
|
||||
provider = html.escape(provider_raw.replace('_', ' ').title())
|
||||
session_token = request.args.get('session_token', '')
|
||||
user_email = html.escape(request.args.get('user_email', ''))
|
||||
|
||||
def safe_js_string(value: str) -> str:
|
||||
"""Safely encode a string for embedding in inline JavaScript."""
|
||||
js_encoded = json.dumps(value)
|
||||
return js_encoded.replace('</', '<\\/').replace('<!--', '<\\!--')
|
||||
|
||||
js_status = safe_js_string(status)
|
||||
js_session_token = safe_js_string(session_token)
|
||||
js_user_email = safe_js_string(user_email)
|
||||
js_provider_type = safe_js_string(provider_raw)
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>{provider} Authentication</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
|
||||
.container {{ max-width: 600px; margin: 0 auto; }}
|
||||
.success {{ color: #4CAF50; }}
|
||||
.error {{ color: #F44336; }}
|
||||
.cancelled {{ color: #FF9800; }}
|
||||
</style>
|
||||
<script>
|
||||
window.onload = function() {{
|
||||
const status = {js_status};
|
||||
const sessionToken = {js_session_token};
|
||||
const userEmail = {js_user_email};
|
||||
const providerType = {js_provider_type};
|
||||
|
||||
if (status === "success" && window.opener) {{
|
||||
window.opener.postMessage({{
|
||||
type: providerType + '_auth_success',
|
||||
session_token: sessionToken,
|
||||
user_email: userEmail
|
||||
}}, '*');
|
||||
|
||||
setTimeout(() => window.close(), 3000);
|
||||
}} else if (status === "cancelled" || status === "error") {{
|
||||
setTimeout(() => window.close(), 3000);
|
||||
}}
|
||||
}};
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h2>{provider} Authentication</h2>
|
||||
<div class="{status}">
|
||||
<p>{message}</p>
|
||||
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
|
||||
</div>
|
||||
<p><small>You can close this window. {f"Your {provider} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return make_response(html_content, 200, {'Content-Type': 'text/html'})
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error rendering callback status page: {e}")
|
||||
return make_response("Authentication error occurred", 500, {'Content-Type': 'text/html'})
|
||||
|
||||
|
||||
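The callback page interpolates user-controlled query values into inline JavaScript, so each value is JSON-encoded and the two sequences that could terminate the surrounding script block are escaped. A small self-contained illustration of that helper and why the escaping matters:

# Illustrative only: behaviour of the safe_js_string helper defined above.
import json

def safe_js_string(value: str) -> str:
    js_encoded = json.dumps(value)
    return js_encoded.replace('</', '<\\/').replace('<!--', '<\\!--')

print(safe_js_string('alice@example.com'))          # "alice@example.com"
print(safe_js_string('</script><script>alert(1)'))  # "<\/script><script>alert(1)" -- cannot close the <script> tag
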
@@ -1,19 +1,18 @@
import os
import datetime
from flask import Blueprint, request, send_from_directory
import json
from flask import Blueprint, request, send_from_directory, jsonify
from werkzeug.utils import secure_filename
from bson.objectid import ObjectId
import logging
from application.core.mongo_db import MongoDB

from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_session
from application.storage.storage_creator import StorageCreator


logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
conversations_collection = db["conversations"]
sources_collection = db["sources"]

current_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -23,6 +22,24 @@ current_dir = os.path.dirname(
internal = Blueprint("internal", __name__)


@internal.before_request
def verify_internal_key():
    """Verify INTERNAL_KEY for all internal endpoint requests.

    Deny by default: if INTERNAL_KEY is not configured, reject all requests.
    """
    if not settings.INTERNAL_KEY:
        logger.warning(
            f"Internal API request rejected from {request.remote_addr}: "
            "INTERNAL_KEY is not configured"
        )
        return jsonify({"error": "Unauthorized", "message": "Internal API is not configured"}), 401
    internal_key = request.headers.get("X-Internal-Key")
    if not internal_key or internal_key != settings.INTERNAL_KEY:
        logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
        return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401

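Every request to this internal blueprint must now carry the shared `X-Internal-Key` header, and with `INTERNAL_KEY` unset the blueprint rejects everything (deny by default). A minimal caller sketch against the download route shown below (base URL and key are placeholders; any further query parameters for that route are outside the visible hunk):

# Illustrative only: calling an internal endpoint with the shared key header.
import requests

resp = requests.get(
    "http://localhost:7091/api/download",
    params={"user": "local"},
    headers={"X-Internal-Key": "<value of settings.INTERNAL_KEY>"},
)
assert resp.status_code != 401  # 401 means the key is missing, wrong, or not configured server-side
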
@internal.route("/api/download", methods=["get"])
|
||||
def download_file():
|
||||
user = secure_filename(request.args.get("user"))
|
||||
@@ -37,20 +54,41 @@ def upload_index_files():
    """Upload two files(index.faiss, index.pkl) to the user's folder."""
    if "user" not in request.form:
        return {"status": "no user"}
    user = secure_filename(request.form["user"])
    user = request.form["user"]
    if "name" not in request.form:
        return {"status": "no name"}
    job_name = secure_filename(request.form["name"])
    tokens = secure_filename(request.form["tokens"])
    retriever = secure_filename(request.form["retriever"])
    id = secure_filename(request.form["id"])
    type = secure_filename(request.form["type"])
    job_name = request.form["name"]
    tokens = request.form["tokens"]
    retriever = request.form["retriever"]
    source_id = request.form["id"]
    type = request.form["type"]
    remote_data = request.form["remote_data"] if "remote_data" in request.form else None
    sync_frequency = secure_filename(request.form["sync_frequency"]) if "sync_frequency" in request.form else None
    sync_frequency = request.form["sync_frequency"] if "sync_frequency" in request.form else None

    file_path = request.form.get("file_path")
    directory_structure = request.form.get("directory_structure")
    file_name_map = request.form.get("file_name_map")

    if directory_structure:
        try:
            directory_structure = json.loads(directory_structure)
        except Exception:
            logger.error("Error parsing directory_structure")
            directory_structure = {}
    else:
        directory_structure = {}
    if file_name_map:
        try:
            file_name_map = json.loads(file_name_map)
        except Exception:
            logger.error("Error parsing file_name_map")
            file_name_map = None
    else:
        file_name_map = None

    storage = StorageCreator.get_storage()
    index_base_path = f"indexes/{id}"

    index_base_path = f"indexes/{source_id}"

    if settings.VECTOR_STORE == "faiss":
        if "file_faiss" not in request.files:
            logger.error("No file_faiss part")
@@ -64,44 +102,55 @@ def upload_index_files():
        file_pkl = request.files["file_pkl"]
        if file_pkl.filename == "":
            return {"status": "no file name"}

        # Save index files to storage
        storage.save_file(file_faiss, f"{index_base_path}/index.faiss")
        storage.save_file(file_pkl, f"{index_base_path}/index.pkl")

        existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
        if existing_entry:
            sources_collection.update_one(
                {"_id": ObjectId(id)},
                {
                    "$set": {
                        "user": user,
                        "name": job_name,
                        "language": job_name,
                        "date": datetime.datetime.now(),
                        "model": settings.EMBEDDINGS_NAME,
                        "type": type,
                        "tokens": tokens,
                        "retriever": retriever,
                        "remote_data": remote_data,
                        "sync_frequency": sync_frequency,
                    }
                },
            )
        else:
            sources_collection.insert_one(
                {
                    "_id": ObjectId(id),
                    "user": user,
                    "name": job_name,
                    "language": job_name,
                    "date": datetime.datetime.now(),
                    "model": settings.EMBEDDINGS_NAME,
                    "type": type,
                    "tokens": tokens,
                    "retriever": retriever,
                    "remote_data": remote_data,
                    "sync_frequency": sync_frequency,
                }
            )
    # Save index files to storage
    faiss_storage_path = f"{index_base_path}/index.faiss"
    pkl_storage_path = f"{index_base_path}/index.pkl"
    storage.save_file(file_faiss, faiss_storage_path)
    storage.save_file(file_pkl, pkl_storage_path)

    now = datetime.datetime.now(datetime.timezone.utc)
    update_fields = {
        "name": job_name,
        "type": type,
        "language": job_name,
        "date": now,
        "model": settings.EMBEDDINGS_NAME,
        "tokens": tokens,
        "retriever": retriever,
        "remote_data": remote_data,
        "sync_frequency": sync_frequency,
        "file_path": file_path,
        "directory_structure": directory_structure,
    }
    if file_name_map is not None:
        update_fields["file_name_map"] = file_name_map

    with db_session() as conn:
        repo = SourcesRepository(conn)
        existing = None
        if looks_like_uuid(source_id):
            existing = repo.get(source_id, user)
        if existing is None:
            existing = repo.get_by_legacy_id(source_id, user)
        if existing is not None:
            repo.update(str(existing["id"]), user, update_fields)
        else:
            repo.create(
                job_name,
                source_id=source_id if looks_like_uuid(source_id) else None,
                user_id=user,
                type=type,
                tokens=tokens,
                retriever=retriever,
                remote_data=remote_data,
                sync_frequency=sync_frequency,
                file_path=file_path,
                directory_structure=directory_structure,
                file_name_map=file_name_map,
                language=job_name,
                model=settings.EMBEDDINGS_NAME,
                date=now,
                legacy_mongo_id=None if looks_like_uuid(source_id) else str(source_id),
            )
    return {"status": "ok"}

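This hunk moves the source upsert from the Mongo `sources` collection to the SQL `SourcesRepository`, keyed on either a UUID or a legacy Mongo id. A rough sketch of how a worker might push a finished FAISS index through this handler; the route path `/api/upload_index` is assumed from context (only the function body is visible in this hunk), and every value below is a placeholder:

# Illustrative only: multipart upload to the internal index-upload handler shown above.
import json
import requests

form = {
    "user": "local",
    "name": "my-source",
    "tokens": "12345",
    "retriever": "classic",
    "id": "<source-uuid-or-legacy-id>",
    "type": "connector",
    "sync_frequency": "never",
    "directory_structure": json.dumps({}),
}
files = {
    "file_faiss": open("index.faiss", "rb"),
    "file_pkl": open("index.pkl", "rb"),
}
resp = requests.post(
    "http://localhost:7091/api/upload_index",  # assumption: route path not shown in the hunk
    data=form,
    files=files,
    headers={"X-Internal-Key": "<value of settings.INTERNAL_KEY>"},
)
print(resp.json())  # {"status": "ok"} on success
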
@@ -0,0 +1,5 @@
"""User API module - provides all user-related API endpoints"""

from .routes import user

__all__ = ["user"]

application/api/user/agents/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""Agents module."""

from .routes import agents_ns
from .sharing import agents_sharing_ns
from .webhooks import agents_webhooks_ns
from .folders import agents_folders_ns

__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns", "agents_folders_ns"]

application/api/user/agents/folders.py (new file, 366 lines)
@@ -0,0 +1,366 @@
"""
Agent folders management routes.
Provides virtual folder organization for agents (Google Drive-like structure).
"""

from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource, fields
from sqlalchemy import text as _sql_text

from application.api import api
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agent_folders import AgentFoldersRepository
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly, db_session


agents_folders_ns = Namespace(
    "agents_folders", description="Agent folder management", path="/api/agents/folders"
)


def _resolve_folder_id(repo: AgentFoldersRepository, folder_id: str, user: str):
    """Resolve a folder id that may be either a UUID or legacy Mongo ObjectId."""
    if not folder_id:
        return None
    if looks_like_uuid(folder_id):
        row = repo.get(folder_id, user)
        if row is not None:
            return row
    return repo.get_by_legacy_id(folder_id, user)


def _folder_error_response(message: str, err: Exception):
    current_app.logger.error(f"{message}: {err}", exc_info=True)
    return make_response(jsonify({"success": False, "message": message}), 400)


def _serialize_folder(f: dict) -> dict:
    created_at = f.get("created_at")
    updated_at = f.get("updated_at")
    return {
        "id": str(f["id"]),
        "name": f.get("name"),
        "parent_id": str(f["parent_id"]) if f.get("parent_id") else None,
        "created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else created_at,
        "updated_at": updated_at.isoformat() if hasattr(updated_at, "isoformat") else updated_at,
    }

@agents_folders_ns.route("/")
|
||||
class AgentFolders(Resource):
|
||||
@api.doc(description="Get all folders for the user")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
folders = AgentFoldersRepository(conn).list_for_user(user)
|
||||
result = [_serialize_folder(f) for f in folders]
|
||||
return make_response(jsonify({"folders": result}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to fetch folders", err)
|
||||
|
||||
@api.doc(description="Create a new folder")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CreateFolder",
|
||||
{
|
||||
"name": fields.String(required=True, description="Folder name"),
|
||||
"parent_id": fields.String(required=False, description="Parent folder ID"),
|
||||
},
|
||||
)
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
if not data or not data.get("name"):
|
||||
return make_response(jsonify({"success": False, "message": "Folder name is required"}), 400)
|
||||
|
||||
parent_id_input = data.get("parent_id")
|
||||
description = data.get("description")
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = AgentFoldersRepository(conn)
|
||||
pg_parent_id = None
|
||||
if parent_id_input:
|
||||
parent = _resolve_folder_id(repo, parent_id_input, user)
|
||||
if not parent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Parent folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_parent_id = str(parent["id"])
|
||||
folder = repo.create(
|
||||
user, data["name"],
|
||||
description=description,
|
||||
parent_id=pg_parent_id,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"id": str(folder["id"]),
|
||||
"name": folder["name"],
|
||||
"parent_id": pg_parent_id,
|
||||
}
|
||||
),
|
||||
201,
|
||||
)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to create folder", err)
|
||||
|
||||
|
||||
@agents_folders_ns.route("/<string:folder_id>")
|
||||
class AgentFolder(Resource):
|
||||
@api.doc(description="Get a specific folder with its agents")
|
||||
def get(self, folder_id):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
folders_repo = AgentFoldersRepository(conn)
|
||||
folder = _resolve_folder_id(folders_repo, folder_id, user)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = str(folder["id"])
|
||||
|
||||
agents_rows = conn.execute(
|
||||
_sql_text(
|
||||
"SELECT id, name, description FROM agents "
|
||||
"WHERE user_id = :user_id AND folder_id = CAST(:fid AS uuid) "
|
||||
"ORDER BY created_at DESC"
|
||||
),
|
||||
{"user_id": user, "fid": pg_folder_id},
|
||||
).fetchall()
|
||||
agents_list = [
|
||||
{
|
||||
"id": str(row._mapping["id"]),
|
||||
"name": row._mapping["name"],
|
||||
"description": row._mapping.get("description", "") or "",
|
||||
}
|
||||
for row in agents_rows
|
||||
]
|
||||
|
||||
subfolders = folders_repo.list_children(pg_folder_id, user)
|
||||
subfolders_list = [
|
||||
{"id": str(sf["id"]), "name": sf["name"]}
|
||||
for sf in subfolders
|
||||
]
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"id": pg_folder_id,
|
||||
"name": folder["name"],
|
||||
"parent_id": (
|
||||
str(folder["parent_id"]) if folder.get("parent_id") else None
|
||||
),
|
||||
"agents": agents_list,
|
||||
"subfolders": subfolders_list,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to fetch folder", err)
|
||||
|
||||
@api.doc(description="Update a folder")
|
||||
def put(self, folder_id):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return make_response(jsonify({"success": False, "message": "No data provided"}), 400)
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = AgentFoldersRepository(conn)
|
||||
folder = _resolve_folder_id(repo, folder_id, user)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = str(folder["id"])
|
||||
|
||||
update_fields: dict = {}
|
||||
if "name" in data:
|
||||
update_fields["name"] = data["name"]
|
||||
if "description" in data:
|
||||
update_fields["description"] = data["description"]
|
||||
if "parent_id" in data:
|
||||
parent_input = data.get("parent_id")
|
||||
if parent_input:
|
||||
if parent_input == folder_id or parent_input == pg_folder_id:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Cannot set folder as its own parent",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
parent = _resolve_folder_id(repo, parent_input, user)
|
||||
if not parent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Parent folder not found"}),
|
||||
404,
|
||||
)
|
||||
if str(parent["id"]) == pg_folder_id:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Cannot set folder as its own parent",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
update_fields["parent_id"] = str(parent["id"])
|
||||
else:
|
||||
update_fields["parent_id"] = None
|
||||
|
||||
if update_fields:
|
||||
repo.update(pg_folder_id, user, update_fields)
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to update folder", err)
|
||||
|
||||
@api.doc(description="Delete a folder")
|
||||
def delete(self, folder_id):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = AgentFoldersRepository(conn)
|
||||
folder = _resolve_folder_id(repo, folder_id, user)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = str(folder["id"])
|
||||
# Clear folder assignments from agents; self-FK
|
||||
# ``ON DELETE SET NULL`` handles child folders.
|
||||
AgentsRepository(conn).clear_folder_for_all(pg_folder_id, user)
|
||||
deleted = repo.delete(pg_folder_id, user)
|
||||
if not deleted:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to delete folder", err)
|
||||
|
||||
|
||||
@agents_folders_ns.route("/move_agent")
|
||||
class MoveAgentToFolder(Resource):
|
||||
@api.doc(description="Move an agent to a folder or remove from folder")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MoveAgent",
|
||||
{
|
||||
"agent_id": fields.String(required=True, description="Agent ID to move"),
|
||||
"folder_id": fields.String(required=False, description="Target folder ID (null to remove from folder)"),
|
||||
},
|
||||
)
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
if not data or not data.get("agent_id"):
|
||||
return make_response(jsonify({"success": False, "message": "Agent ID is required"}), 400)
|
||||
|
||||
agent_id_input = data["agent_id"]
|
||||
folder_id_input = data.get("folder_id")
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
agents_repo = AgentsRepository(conn)
|
||||
agent = agents_repo.get_any(agent_id_input, user)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = None
|
||||
if folder_id_input:
|
||||
folders_repo = AgentFoldersRepository(conn)
|
||||
folder = _resolve_folder_id(folders_repo, folder_id_input, user)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = str(folder["id"])
|
||||
agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to move agent", err)
|
||||
|
||||
|
||||
@agents_folders_ns.route("/bulk_move")
|
||||
class BulkMoveAgents(Resource):
|
||||
@api.doc(description="Move multiple agents to a folder")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"BulkMoveAgents",
|
||||
{
|
||||
"agent_ids": fields.List(fields.String, required=True, description="List of agent IDs"),
|
||||
"folder_id": fields.String(required=False, description="Target folder ID"),
|
||||
},
|
||||
)
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
if not data or not data.get("agent_ids"):
|
||||
return make_response(jsonify({"success": False, "message": "Agent IDs are required"}), 400)
|
||||
|
||||
agent_ids = data["agent_ids"]
|
||||
folder_id_input = data.get("folder_id")
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
agents_repo = AgentsRepository(conn)
|
||||
pg_folder_id = None
|
||||
if folder_id_input:
|
||||
folders_repo = AgentFoldersRepository(conn)
|
||||
folder = _resolve_folder_id(folders_repo, folder_id_input, user)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = str(folder["id"])
|
||||
for agent_id_input in agent_ids:
|
||||
agent = agents_repo.get_any(agent_id_input, user)
|
||||
if agent is not None:
|
||||
agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to move agents", err)
|
||||
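The folder routes above provide a Drive-like organisation layer: create/list/update/delete folders under /api/agents/folders, plus move_agent and bulk_move to assign agents to them. A short end-to-end sketch (base URL, JWT and the agent id are placeholders):

# Illustrative only: create a folder, then move an agent into it via the routes in folders.py.
import requests

BASE = "http://localhost:7091"
HEADERS = {"Authorization": "Bearer <jwt>"}

folder = requests.post(
    f"{BASE}/api/agents/folders/",
    json={"name": "Support bots"},
    headers=HEADERS,
).json()

requests.post(
    f"{BASE}/api/agents/folders/move_agent",
    json={"agent_id": "<agent-id>", "folder_id": folder["id"]},
    headers=HEADERS,
).raise_for_status()
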
application/api/user/agents/routes.py (new file, 1436 lines; file diff suppressed because it is too large)
application/api/user/agents/sharing.py (new file, 274 lines)
@@ -0,0 +1,274 @@
"""Agent management sharing functionality."""

import datetime
import secrets

from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from sqlalchemy import text as _sql_text

from application.api import api
from application.core.settings import settings
from application.api.user.base import resolve_tool_details
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.users import UsersRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import generate_image_url

agents_sharing_ns = Namespace(
    "agents", description="Agent management operations", path="/api"
)

def _serialize_agent_basic(agent: dict) -> dict:
|
||||
"""Shape a PG agent row into the API response dict."""
|
||||
source_id = agent.get("source_id")
|
||||
return {
|
||||
"id": str(agent["id"]),
|
||||
"user": agent.get("user_id", ""),
|
||||
"name": agent.get("name", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"description": agent.get("description", ""),
|
||||
"source": str(source_id) if source_id else "",
|
||||
"chunks": str(agent["chunks"]) if agent.get("chunks") is not None else "0",
|
||||
"retriever": agent.get("retriever", "classic") or "classic",
|
||||
"prompt_id": str(agent["prompt_id"]) if agent.get("prompt_id") else "default",
|
||||
"tools": agent.get("tools", []) or [],
|
||||
"tool_details": resolve_tool_details(agent.get("tools", []) or []),
|
||||
"agent_type": agent.get("agent_type", "") or "",
|
||||
"status": agent.get("status", "") or "",
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"limited_token_mode": agent.get("limited_token_mode", False),
|
||||
"token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
|
||||
"limited_request_mode": agent.get("limited_request_mode", False),
|
||||
"request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
|
||||
"created_at": agent.get("created_at", ""),
|
||||
"updated_at": agent.get("updated_at", ""),
|
||||
"shared": bool(agent.get("shared", False)),
|
||||
"shared_token": agent.get("shared_token", "") or "",
|
||||
"shared_metadata": agent.get("shared_metadata", {}) or {},
|
||||
}
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/shared_agent")
|
||||
class SharedAgent(Resource):
|
||||
@api.doc(
|
||||
params={
|
||||
"token": "Shared token of the agent",
|
||||
},
|
||||
description="Get a shared agent by token or ID",
|
||||
)
|
||||
def get(self):
|
||||
shared_token = request.args.get("token")
|
||||
|
||||
if not shared_token:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Token or ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
shared_agent = AgentsRepository(conn).find_by_shared_token(
|
||||
shared_token,
|
||||
)
|
||||
if not shared_agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Shared agent not found"}),
|
||||
404,
|
||||
)
|
||||
agent_id = str(shared_agent["id"])
|
||||
data = _serialize_agent_basic(shared_agent)
|
||||
|
||||
if data["tools"]:
|
||||
enriched_tools = []
|
||||
for detail in data["tool_details"]:
|
||||
enriched_tools.append(detail.get("name", ""))
|
||||
data["tools"] = enriched_tools
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
if decoded_token:
|
||||
user_id = decoded_token.get("sub")
|
||||
owner_id = shared_agent.get("user_id")
|
||||
|
||||
if user_id != owner_id:
|
||||
with db_session() as conn:
|
||||
users_repo = UsersRepository(conn)
|
||||
users_repo.upsert(user_id)
|
||||
users_repo.add_shared(user_id, agent_id)
|
||||
return make_response(jsonify(data), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving shared agent: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/shared_agents")
|
||||
class SharedAgents(Resource):
|
||||
@api.doc(description="Get shared agents explicitly shared with the user")
|
||||
def get(self):
|
||||
try:
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
|
||||
with db_session() as conn:
|
||||
users_repo = UsersRepository(conn)
|
||||
user_doc = users_repo.upsert(user_id)
|
||||
shared_with_ids = (
|
||||
user_doc.get("agent_preferences", {}).get("shared_with_me", [])
|
||||
if isinstance(user_doc.get("agent_preferences"), dict)
|
||||
else []
|
||||
)
|
||||
# Keep only UUID-shaped ids; ObjectId leftovers are stripped below.
|
||||
uuid_ids = [sid for sid in shared_with_ids if looks_like_uuid(sid)]
|
||||
non_uuid_ids = [sid for sid in shared_with_ids if not looks_like_uuid(sid)]
|
||||
|
||||
if uuid_ids:
|
||||
result = conn.execute(
|
||||
_sql_text(
|
||||
"SELECT * FROM agents "
|
||||
"WHERE id = ANY(CAST(:ids AS uuid[])) "
|
||||
"AND shared = true"
|
||||
),
|
||||
{"ids": uuid_ids},
|
||||
)
|
||||
shared_agents = [dict(row._mapping) for row in result.fetchall()]
|
||||
else:
|
||||
shared_agents = []
|
||||
|
||||
found_ids_set = {str(agent["id"]) for agent in shared_agents}
|
||||
stale_ids = [sid for sid in uuid_ids if sid not in found_ids_set]
|
||||
stale_ids.extend(non_uuid_ids)
|
||||
if stale_ids:
|
||||
users_repo.remove_shared_bulk(user_id, stale_ids)
|
||||
|
||||
pinned_ids = set(
|
||||
user_doc.get("agent_preferences", {}).get("pinned", [])
|
||||
if isinstance(user_doc.get("agent_preferences"), dict)
|
||||
else []
|
||||
)
|
||||
|
||||
list_shared_agents = []
|
||||
for agent in shared_agents:
|
||||
agent_id_str = str(agent["id"])
|
||||
list_shared_agents.append(
|
||||
{
|
||||
"id": agent_id_str,
|
||||
"name": agent.get("name", ""),
|
||||
"description": agent.get("description", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"tools": agent.get("tools", []) or [],
|
||||
"tool_details": resolve_tool_details(
|
||||
agent.get("tools", []) or []
|
||||
),
|
||||
"agent_type": agent.get("agent_type", "") or "",
|
||||
"status": agent.get("status", "") or "",
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"limited_token_mode": agent.get("limited_token_mode", False),
|
||||
"token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
|
||||
"limited_request_mode": agent.get("limited_request_mode", False),
|
||||
"request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
|
||||
"created_at": agent.get("created_at", ""),
|
||||
"updated_at": agent.get("updated_at", ""),
|
||||
"pinned": agent_id_str in pinned_ids,
|
||||
"shared": bool(agent.get("shared", False)),
|
||||
"shared_token": agent.get("shared_token", "") or "",
|
||||
"shared_metadata": agent.get("shared_metadata", {}) or {},
|
||||
}
|
||||
)
|
||||
|
||||
return make_response(jsonify(list_shared_agents), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving shared agents: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/share_agent")
|
||||
class ShareAgent(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ShareAgentModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="ID of the agent"),
|
||||
"shared": fields.Boolean(
|
||||
required=True, description="Share or unshare the agent"
|
||||
),
|
||||
"username": fields.String(
|
||||
required=False, description="Name of the user"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Share or unshare an agent")
|
||||
def put(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing JSON body"}), 400
|
||||
)
|
||||
agent_id = data.get("id")
|
||||
shared = data.get("shared")
|
||||
username = data.get("username", "")
|
||||
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
if shared is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Shared parameter is required and must be true or false",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
shared_token = None
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = AgentsRepository(conn)
|
||||
agent = repo.get_any(agent_id, user)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
if shared:
|
||||
shared_metadata = {
|
||||
"shared_by": username,
|
||||
"shared_at": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
).isoformat(),
|
||||
}
|
||||
shared_token = secrets.token_urlsafe(32)
|
||||
repo.update(
|
||||
str(agent["id"]), user,
|
||||
{
|
||||
"shared": True,
|
||||
"shared_token": shared_token,
|
||||
"shared_metadata": shared_metadata,
|
||||
},
|
||||
)
|
||||
else:
|
||||
repo.update(
|
||||
str(agent["id"]), user,
|
||||
{
|
||||
"shared": False,
|
||||
"shared_token": None,
|
||||
"shared_metadata": None,
|
||||
},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "shared_token": shared_token}), 200
|
||||
)
|
||||
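sharing.py above covers the whole lifecycle: PUT /api/share_agent issues (or revokes) a `shared_token`, GET /api/shared_agent resolves an agent by that token, and GET /api/shared_agents lists agents explicitly shared with the caller. A minimal round-trip sketch (base URL, JWT and the agent id are placeholders):

# Illustrative only: share an agent, then fetch it through the token-based lookup.
import requests

BASE = "http://localhost:7091"
HEADERS = {"Authorization": "Bearer <jwt>"}

share = requests.put(
    f"{BASE}/api/share_agent",
    json={"id": "<agent-id>", "shared": True, "username": "alice"},
    headers=HEADERS,
).json()

shared = requests.get(
    f"{BASE}/api/shared_agent",
    params={"token": share["shared_token"]},
).json()
print(shared["name"], shared["tools"])
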
application/api/user/agents/webhooks.py (new file, 120 lines)
@@ -0,0 +1,120 @@
"""Agent management webhook handlers."""

import secrets

from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource

from application.api import api
from application.api.user.base import require_agent
from application.api.user.tasks import process_agent_webhook
from application.core.settings import settings
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly, db_session


agents_webhooks_ns = Namespace(
    "agents", description="Agent management operations", path="/api"
)

@agents_webhooks_ns.route("/agent_webhook")
|
||||
class AgentWebhook(Resource):
|
||||
@api.doc(
|
||||
params={"id": "ID of the agent"},
|
||||
description="Generate webhook URL for the agent",
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
agent_id = request.args.get("id")
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).get_any(agent_id, user)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
webhook_token = agent.get("incoming_webhook_token")
|
||||
if not webhook_token:
|
||||
webhook_token = secrets.token_urlsafe(32)
|
||||
with db_session() as conn:
|
||||
AgentsRepository(conn).update(
|
||||
str(agent["id"]), user,
|
||||
{"incoming_webhook_token": webhook_token},
|
||||
)
|
||||
base_url = settings.API_URL.rstrip("/")
|
||||
full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error generating webhook URL: {err}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error generating webhook URL"}),
|
||||
400,
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": True, "webhook_url": full_webhook_url}), 200
|
||||
)
|
||||
|
||||
|
||||
@agents_webhooks_ns.route("/webhooks/agents/<string:webhook_token>")
|
||||
class AgentWebhookListener(Resource):
|
||||
method_decorators = [require_agent]
|
||||
|
||||
def _enqueue_webhook_task(self, agent_id_str, payload, source_method):
|
||||
if not payload:
|
||||
current_app.logger.warning(
|
||||
f"Webhook ({source_method}) received for agent {agent_id_str} with empty payload."
|
||||
)
|
||||
current_app.logger.info(
|
||||
f"Incoming {source_method} webhook for agent {agent_id_str}. Enqueuing task with payload: {payload}"
|
||||
)
|
||||
|
||||
try:
|
||||
task = process_agent_webhook.delay(
|
||||
agent_id=agent_id_str,
|
||||
payload=payload,
|
||||
)
|
||||
current_app.logger.info(
|
||||
f"Task {task.id} enqueued for agent {agent_id_str} ({source_method})."
|
||||
)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error enqueuing webhook task ({source_method}) for agent {agent_id_str}: {err}",
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error processing webhook"}), 500
|
||||
)
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (POST). Expects JSON payload, which is used to trigger processing.",
|
||||
)
|
||||
def post(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.get_json()
|
||||
if payload is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Invalid or missing JSON data in request body",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="POST")
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (GET). Uses URL query parameters as payload to trigger processing.",
|
||||
)
|
||||
def get(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="GET")
|
||||
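The webhook module above has two halves: an authenticated GET /api/agent_webhook that lazily generates the agent's `incoming_webhook_token` and returns the full listener URL, and the listener itself, which authenticates via the token embedded in that URL and simply enqueues `process_agent_webhook`. A short usage sketch (base URL, JWT and agent id are placeholders):

# Illustrative only: fetch the agent's webhook URL, then trigger it.
import requests

BASE = "http://localhost:7091"
HEADERS = {"Authorization": "Bearer <jwt>"}

hook = requests.get(
    f"{BASE}/api/agent_webhook",
    params={"id": "<agent-id>"},
    headers=HEADERS,
).json()

# The POST listener expects a JSON body; the GET variant uses query parameters as the payload.
requests.post(hook["webhook_url"], json={"event": "ticket_created", "subject": "Printer offline"})
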
application/api/user/analytics/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
"""Analytics module."""

from .routes import analytics_ns

__all__ = ["analytics_ns"]
application/api/user/analytics/routes.py (new file, 437 lines)
@@ -0,0 +1,437 @@
"""Analytics and reporting routes."""

import datetime

from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from sqlalchemy import text as _sql_text

from application.api import api
from application.api.user.base import (
    generate_date_range,
    generate_hourly_range,
    generate_minute_range,
)
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.token_usage import TokenUsageRepository
from application.storage.db.repositories.user_logs import UserLogsRepository
from application.storage.db.session import db_readonly


analytics_ns = Namespace(
    "analytics", description="Analytics and reporting operations", path="/api"
)


_FILTER_BUCKETS = {
    "last_hour": ("minute", "%Y-%m-%d %H:%M:00", "YYYY-MM-DD HH24:MI:00"),
    "last_24_hour": ("hour", "%Y-%m-%d %H:00", "YYYY-MM-DD HH24:00"),
    "last_7_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
    "last_15_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
    "last_30_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
}


def _range_for_filter(filter_option: str):
    """Return ``(start_date, end_date, bucket_unit, pg_fmt)`` for the filter.

    Returns ``None`` on invalid filter.
    """
    if filter_option not in _FILTER_BUCKETS:
        return None
    end_date = datetime.datetime.now(datetime.timezone.utc)
    bucket_unit, _py_fmt, pg_fmt = _FILTER_BUCKETS[filter_option]

    if filter_option == "last_hour":
        start_date = end_date - datetime.timedelta(hours=1)
    elif filter_option == "last_24_hour":
        start_date = end_date - datetime.timedelta(hours=24)
    else:
        days = {
            "last_7_days": 6,
            "last_15_days": 14,
            "last_30_days": 29,
        }[filter_option]
        start_date = end_date - datetime.timedelta(days=days)
        start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
        end_date = end_date.replace(
            hour=23, minute=59, second=59, microsecond=999999
        )
    return start_date, end_date, bucket_unit, pg_fmt


def _intervals_for_filter(filter_option, start_date, end_date):
    if filter_option == "last_hour":
        return generate_minute_range(start_date, end_date)
    if filter_option == "last_24_hour":
        return generate_hourly_range(start_date, end_date)
    return generate_date_range(start_date, end_date)


def _resolve_api_key(conn, api_key_id, user_id):
    """Look up the ``agents.key`` value for a given agent id.

    Scoped by ``user_id`` so an authenticated caller can't probe another
    user's agents. Accepts either UUID or legacy Mongo ObjectId shape.
    """
    if not api_key_id:
        return None
    agent = AgentsRepository(conn).get_any(api_key_id, user_id)
    return (agent or {}).get("key") if agent else None

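A worked example of the window helper above, to make the bucketing concrete. With "now" at 2024-05-10 13:37 UTC (an arbitrary example instant), `_range_for_filter("last_7_days")` resolves to:

#   start_date  = 2024-05-04 00:00:00+00:00            (6 days back, floored to midnight)
#   end_date    = 2024-05-10 23:59:59.999999+00:00     (current day, pushed to end of day)
#   bucket_unit = "day"
#   pg_fmt      = "YYYY-MM-DD"
#
# generate_date_range(start_date, end_date) then yields the seven "YYYY-MM-DD" keys that each
# analytics route pre-fills with zero before merging in the per-bucket SQL counts, so days with
# no activity still appear in the response.
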
@analytics_ns.route("/get_message_analytics")
|
||||
class GetMessageAnalytics(Resource):
|
||||
get_message_analytics_model = api.model(
|
||||
"GetMessageAnalyticsModel",
|
||||
{
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"filter_option": fields.String(
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=list(_FILTER_BUCKETS.keys()),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_message_analytics_model)
|
||||
@api.doc(description="Get message analytics based on filter option")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json() or {}
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
window = _range_for_filter(filter_option)
|
||||
if window is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date, end_date, _bucket_unit, pg_fmt = window
|
||||
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
api_key = _resolve_api_key(conn, api_key_id, user)
|
||||
|
||||
# Count messages per bucket, filtered by the conversation's
|
||||
# owner (user_id) and optionally the agent api_key. The
|
||||
# ``user_id`` filter is always applied post-cutover to
|
||||
# prevent cross-tenant leakage on admin dashboards.
|
||||
clauses = [
|
||||
"c.user_id = :user_id",
|
||||
"m.timestamp >= :start",
|
||||
"m.timestamp <= :end",
|
||||
]
|
||||
params: dict = {
|
||||
"user_id": user,
|
||||
"start": start_date,
|
||||
"end": end_date,
|
||||
"fmt": pg_fmt,
|
||||
}
|
||||
if api_key:
|
||||
clauses.append("c.api_key = :api_key")
|
||||
params["api_key"] = api_key
|
||||
where = " AND ".join(clauses)
|
||||
sql = (
|
||||
"SELECT to_char(m.timestamp AT TIME ZONE 'UTC', :fmt) AS bucket, "
|
||||
"COUNT(*) AS count "
|
||||
"FROM conversation_messages m "
|
||||
"JOIN conversations c ON c.id = m.conversation_id "
|
||||
f"WHERE {where} "
|
||||
"GROUP BY bucket ORDER BY bucket ASC"
|
||||
)
|
||||
rows = conn.execute(_sql_text(sql), params).fetchall()
|
||||
|
||||
intervals = _intervals_for_filter(filter_option, start_date, end_date)
|
||||
daily_messages = {interval: 0 for interval in intervals}
|
||||
for row in rows:
|
||||
daily_messages[row._mapping["bucket"]] = int(row._mapping["count"])
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting message analytics: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "messages": daily_messages}), 200
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_token_analytics")
|
||||
class GetTokenAnalytics(Resource):
|
||||
get_token_analytics_model = api.model(
|
||||
"GetTokenAnalyticsModel",
|
||||
{
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"filter_option": fields.String(
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=list(_FILTER_BUCKETS.keys()),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_token_analytics_model)
|
||||
@api.doc(description="Get token analytics data")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json() or {}
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
window = _range_for_filter(filter_option)
|
||||
if window is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date, end_date, bucket_unit, _pg_fmt = window
|
||||
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
api_key = _resolve_api_key(conn, api_key_id, user)
|
||||
# ``bucketed_totals`` applies user_id / api_key filters
|
||||
# directly — no need to reshape a Mongo pipeline.
|
||||
rows = TokenUsageRepository(conn).bucketed_totals(
|
||||
bucket_unit=bucket_unit,
|
||||
user_id=user,
|
||||
api_key=api_key,
|
||||
timestamp_gte=start_date,
|
||||
timestamp_lt=end_date,
|
||||
)
|
||||
|
||||
intervals = _intervals_for_filter(filter_option, start_date, end_date)
|
||||
daily_token_usage = {interval: 0 for interval in intervals}
|
||||
for entry in rows:
|
||||
daily_token_usage[entry["bucket"]] = int(
|
||||
entry["prompt_tokens"] + entry["generated_tokens"]
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting token analytics: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "token_usage": daily_token_usage}), 200
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_feedback_analytics")
|
||||
class GetFeedbackAnalytics(Resource):
|
||||
get_feedback_analytics_model = api.model(
|
||||
"GetFeedbackAnalyticsModel",
|
||||
{
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"filter_option": fields.String(
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=list(_FILTER_BUCKETS.keys()),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_feedback_analytics_model)
|
||||
@api.doc(description="Get feedback analytics data")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json() or {}
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
window = _range_for_filter(filter_option)
|
||||
if window is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date, end_date, _bucket_unit, pg_fmt = window
|
||||
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
api_key = _resolve_api_key(conn, api_key_id, user)
|
||||
|
||||
# Feedback lives inside the ``conversation_messages.feedback``
|
||||
# JSONB as ``{"text": "like"|"dislike", "timestamp": "..."}``.
|
||||
# There is no scalar ``feedback_timestamp`` column — extract
|
||||
# the timestamp from the JSONB and cast it to timestamptz for
|
||||
# the range filter + bucket grouping.
|
||||
clauses = [
|
||||
"c.user_id = :user_id",
|
||||
"m.feedback IS NOT NULL",
|
||||
"(m.feedback->>'timestamp')::timestamptz >= :start",
|
||||
"(m.feedback->>'timestamp')::timestamptz <= :end",
|
||||
]
|
||||
params: dict = {
|
||||
"user_id": user,
|
||||
"start": start_date,
|
||||
"end": end_date,
|
||||
"fmt": pg_fmt,
|
||||
}
|
||||
if api_key:
|
||||
clauses.append("c.api_key = :api_key")
|
||||
params["api_key"] = api_key
|
||||
where = " AND ".join(clauses)
|
||||
sql = (
|
||||
"SELECT to_char("
|
||||
"(m.feedback->>'timestamp')::timestamptz AT TIME ZONE 'UTC', :fmt"
|
||||
") AS bucket, "
|
||||
"SUM(CASE WHEN m.feedback->>'text' = 'like' THEN 1 ELSE 0 END) AS positive, "
|
||||
"SUM(CASE WHEN m.feedback->>'text' = 'dislike' THEN 1 ELSE 0 END) AS negative "
|
||||
"FROM conversation_messages m "
|
||||
"JOIN conversations c ON c.id = m.conversation_id "
|
||||
f"WHERE {where} "
|
||||
"GROUP BY bucket ORDER BY bucket ASC"
|
||||
)
|
||||
rows = conn.execute(_sql_text(sql), params).fetchall()
|
||||
|
||||
intervals = _intervals_for_filter(filter_option, start_date, end_date)
|
||||
daily_feedback = {
|
||||
interval: {"positive": 0, "negative": 0} for interval in intervals
|
||||
}
|
||||
for row in rows:
|
||||
bucket = row._mapping["bucket"]
|
||||
daily_feedback[bucket] = {
|
||||
"positive": int(row._mapping["positive"] or 0),
|
||||
"negative": int(row._mapping["negative"] or 0),
|
||||
}
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting feedback analytics: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "feedback": daily_feedback}), 200
|
||||
)
|
||||
|
||||
|
||||
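The `pg_fmt` value unpacked from `_range_for_filter` is assumed to be a Postgres `to_char` pattern whose output matches the interval keys produced by `_intervals_for_filter` (for example `'YYYY-MM-DD'` for daily buckets). A small sketch of that assumption, with the Python-side format it would have to mirror:

# Hypothetical pairing of to_char pattern and strftime pattern; the actual
# _FILTER_BUCKETS / _range_for_filter mapping may use different values.
import datetime

pg_fmt = "YYYY-MM-DD"    # assumed daily bucket label on the SQL side
python_fmt = "%Y-%m-%d"  # must render the same label for the pre-seeded dict

ts = datetime.datetime(2024, 5, 1, 13, 45, tzinfo=datetime.timezone.utc)
assert ts.strftime(python_fmt) == "2024-05-01"  # matches to_char(..., pg_fmt)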
@analytics_ns.route("/get_user_logs")
|
||||
class GetUserLogs(Resource):
|
||||
get_user_logs_model = api.model(
|
||||
"GetUserLogsModel",
|
||||
{
|
||||
"page": fields.Integer(
|
||||
required=False,
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
),
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"page_size": fields.Integer(
|
||||
required=False,
|
||||
description="Number of logs per page",
|
||||
default=10,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_user_logs_model)
|
||||
@api.doc(description="Get user logs with pagination")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json() or {}
|
||||
page = int(data.get("page", 1))
|
||||
api_key_id = data.get("api_key_id")
|
||||
page_size = int(data.get("page_size", 10))
|
||||
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
api_key = _resolve_api_key(conn, api_key_id, user)
|
||||
logs_repo = UserLogsRepository(conn)
|
||||
if api_key:
|
||||
# ``find_by_api_key`` filters on ``data->>'api_key'``
|
||||
# — the PG shape of the legacy top-level ``api_key``
|
||||
# filter. Paginate client-side using offset/limit.
|
||||
all_rows = logs_repo.find_by_api_key(api_key)
|
||||
offset = (page - 1) * page_size
|
||||
window = all_rows[offset: offset + page_size + 1]
|
||||
items = window
|
||||
else:
|
||||
items, has_more_flag = logs_repo.list_paginated(
|
||||
user_id=user,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
# list_paginated already trims to page_size and
|
||||
# returns has_more separately.
|
||||
results = [
|
||||
{
|
||||
"id": str(item.get("id") or item.get("_id")),
|
||||
"action": (item.get("data") or {}).get("action"),
|
||||
"level": (item.get("data") or {}).get("level"),
|
||||
"user": item.get("user_id"),
|
||||
"question": (item.get("data") or {}).get("question"),
|
||||
"sources": (item.get("data") or {}).get("sources"),
|
||||
"retriever_params": (item.get("data") or {}).get(
|
||||
"retriever_params"
|
||||
),
|
||||
"timestamp": (
|
||||
item["timestamp"].isoformat()
|
||||
if hasattr(item.get("timestamp"), "isoformat")
|
||||
else item.get("timestamp")
|
||||
),
|
||||
}
|
||||
for item in items
|
||||
]
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"logs": results,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": has_more_flag,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
has_more = len(items) > page_size
|
||||
items = items[:page_size]
|
||||
results = [
|
||||
{
|
||||
"id": str(item.get("id") or item.get("_id")),
|
||||
"action": (item.get("data") or {}).get("action"),
|
||||
"level": (item.get("data") or {}).get("level"),
|
||||
"user": item.get("user_id"),
|
||||
"question": (item.get("data") or {}).get("question"),
|
||||
"sources": (item.get("data") or {}).get("sources"),
|
||||
"retriever_params": (item.get("data") or {}).get(
|
||||
"retriever_params"
|
||||
),
|
||||
"timestamp": (
|
||||
item["timestamp"].isoformat()
|
||||
if hasattr(item.get("timestamp"), "isoformat")
|
||||
else item.get("timestamp")
|
||||
),
|
||||
}
|
||||
for item in items
|
||||
]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting user logs: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"logs": results,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": has_more,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
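The API-key branch above uses the common fetch-one-extra-row trick: slice `page_size + 1` items, report `has_more` if the extra row showed up, then trim to the page. A standalone sketch of the same pattern:

def paginate(rows, page, page_size):
    # Take one extra row so we can tell whether another page exists
    # without issuing a separate COUNT query.
    offset = (page - 1) * page_size
    window = rows[offset : offset + page_size + 1]
    has_more = len(window) > page_size
    return window[:page_size], has_more

items, has_more = paginate(list(range(25)), page=2, page_size=10)
assert items == list(range(10, 20)) and has_more is True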
5	application/api/user/attachments/__init__.py	Normal file
@@ -0,0 +1,5 @@
"""Attachments module."""

from .routes import attachments_ns

__all__ = ["attachments_ns"]
680	application/api/user/attachments/routes.py	Normal file
@@ -0,0 +1,680 @@
"""File attachments and media routes."""

import os
import tempfile
from pathlib import Path

import uuid

from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource

from application.api import api
from application.cache import get_redis_instance
from application.core.settings import settings
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly
from application.stt.constants import (
    SUPPORTED_AUDIO_EXTENSIONS,
    SUPPORTED_AUDIO_MIME_TYPES,
)
from application.stt.upload_limits import (
    AudioFileTooLargeError,
    build_stt_file_size_limit_message,
    enforce_audio_file_size_limit,
    is_audio_filename,
)
from application.stt.live_session import (
    apply_live_stt_hypothesis,
    create_live_stt_session,
    delete_live_stt_session,
    finalize_live_stt_session,
    get_live_stt_transcript_text,
    load_live_stt_session,
    save_live_stt_session,
)
from application.stt.stt_creator import STTCreator
from application.tts.tts_creator import TTSCreator
from application.utils import safe_filename


attachments_ns = Namespace(
    "attachments", description="File attachments and media operations", path="/api"
)


def _resolve_authenticated_user():
    decoded_token = getattr(request, "decoded_token", None)
    api_key = request.form.get("api_key") or request.args.get("api_key")

    if decoded_token:
        return safe_filename(decoded_token.get("sub"))

    if api_key:
        with db_readonly() as conn:
            agent = AgentsRepository(conn).find_by_key(api_key)
        if not agent:
            return make_response(
                jsonify({"success": False, "message": "Invalid API key"}), 401
            )
        return safe_filename(agent.get("user_id"))

    return None


def _get_uploaded_file_size(file) -> int:
    try:
        current_position = file.stream.tell()
        file.stream.seek(0, os.SEEK_END)
        size_bytes = file.stream.tell()
        file.stream.seek(current_position)
        return size_bytes
    except Exception:
        return 0


def _is_supported_audio_mimetype(mimetype: str) -> bool:
    if not mimetype:
        return True
    normalized = mimetype.split(";")[0].strip().lower()
    return normalized.startswith("audio/") or normalized in SUPPORTED_AUDIO_MIME_TYPES


def _enforce_uploaded_audio_size_limit(file, filename: str) -> None:
    if not is_audio_filename(filename):
        return
    size_bytes = _get_uploaded_file_size(file)
    if size_bytes:
        enforce_audio_file_size_limit(size_bytes)


def _get_store_attachment_user_error(exc: Exception) -> str:
    if isinstance(exc, AudioFileTooLargeError):
        return build_stt_file_size_limit_message()
    return "Failed to process file"


def _require_live_stt_redis():
    redis_client = get_redis_instance()
    if redis_client:
        return redis_client
    return make_response(
        jsonify({"success": False, "message": "Live transcription is unavailable"}),
        503,
    )


def _parse_bool_form_value(value: str | None) -> bool:
    if value is None:
        return False
    return value.strip().lower() in {"1", "true", "yes", "on"}


@attachments_ns.route("/store_attachment")
class StoreAttachment(Resource):
    @api.expect(
        api.model(
            "AttachmentModel",
            {
                "file": fields.Raw(required=True, description="File(s) to upload"),
                "api_key": fields.String(
                    required=False, description="API key (optional)"
                ),
            },
        )
    )
    @api.doc(
        description="Stores one or multiple attachments without vectorization or training. Supports user or API key authentication."
    )
    def post(self):
        auth_user = _resolve_authenticated_user()
        if hasattr(auth_user, "status_code"):
            return auth_user

        files = request.files.getlist("file")
        if not files:
            single_file = request.files.get("file")
            if single_file:
                files = [single_file]

        if not files or all(f.filename == "" for f in files):
            return make_response(
                jsonify({"status": "error", "message": "Missing file(s)"}), 400
            )

        user = auth_user
        if not user:
            return make_response(
                jsonify({"success": False, "message": "Authentication required"}), 401
            )

        try:
            from application.api.user.tasks import store_attachment
            from application.api.user.base import storage

            tasks = []
            errors = []
            original_file_count = len(files)

            for idx, file in enumerate(files):
                try:
                    attachment_id = uuid.uuid4()
                    original_filename = safe_filename(os.path.basename(file.filename))
                    _enforce_uploaded_audio_size_limit(file, original_filename)
                    relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"

                    metadata = storage.save_file(file, relative_path)
                    file_info = {
                        "filename": original_filename,
                        "attachment_id": str(attachment_id),
                        "path": relative_path,
                        "metadata": metadata,
                    }

                    task = store_attachment.delay(file_info, user)
                    tasks.append({
                        "task_id": task.id,
                        "filename": original_filename,
                        "attachment_id": str(attachment_id),
                        "upload_index": idx,
                    })
                except Exception as file_err:
                    current_app.logger.error(
                        f"Error processing file {idx} ({file.filename}): {file_err}",
                        exc_info=True,
                    )
                    errors.append({
                        "upload_index": idx,
                        "filename": file.filename,
                        "error": _get_store_attachment_user_error(file_err),
                    })

            if not tasks:
                if errors and all(
                    error.get("error") == build_stt_file_size_limit_message()
                    for error in errors
                ):
                    return make_response(
                        jsonify(
                            {
                                "success": False,
                                "message": build_stt_file_size_limit_message(),
                                "errors": errors,
                            }
                        ),
                        413,
                    )
                return make_response(
                    jsonify({"status": "error", "message": "No valid files to upload"}),
                    400,
                )

            if original_file_count == 1 and len(tasks) == 1:
                current_app.logger.info("Returning single task_id response")
                return make_response(
                    jsonify(
                        {
                            "success": True,
                            "task_id": tasks[0]["task_id"],
                            "message": "File uploaded successfully. Processing started.",
                        }
                    ),
                    200,
                )
            else:
                response_data = {
                    "success": True,
                    "tasks": tasks,
                    "message": f"{len(tasks)} file(s) uploaded successfully. Processing started.",
                }
                if errors:
                    response_data["errors"] = errors
                    response_data["message"] += f" {len(errors)} file(s) failed."

                return make_response(jsonify(response_data), 200)
        except Exception as err:
            current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
            return make_response(
                jsonify({"success": False, "error": "Failed to store attachment"}), 400
            )


@attachments_ns.route("/stt")
class SpeechToText(Resource):
    @api.expect(
        api.model(
            "SpeechToTextModel",
            {
                "file": fields.Raw(required=True, description="Audio file"),
                "language": fields.String(
                    required=False, description="Optional transcription language hint"
                ),
            },
        )
    )
    @api.doc(description="Transcribe an uploaded audio file")
    def post(self):
        auth_user = _resolve_authenticated_user()
        if hasattr(auth_user, "status_code"):
            return auth_user
        if not auth_user:
            return make_response(
                jsonify({"success": False, "message": "Authentication required"}), 401
            )

        file = request.files.get("file")
        if not file or file.filename == "":
            return make_response(
                jsonify({"success": False, "message": "Missing file"}), 400
            )

        filename = safe_filename(os.path.basename(file.filename))
        suffix = Path(filename).suffix.lower()
        if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
            return make_response(
                jsonify({"success": False, "message": "Unsupported audio format"}), 400
            )

        if not _is_supported_audio_mimetype(file.mimetype or ""):
            return make_response(
                jsonify({"success": False, "message": "Unsupported audio MIME type"}),
                400,
            )

        try:
            _enforce_uploaded_audio_size_limit(file, filename)
        except AudioFileTooLargeError:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": build_stt_file_size_limit_message(),
                    }
                ),
                413,
            )

        temp_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                file.save(temp_file.name)
                temp_path = Path(temp_file.name)

            stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
            transcript = stt_instance.transcribe(
                temp_path,
                language=request.form.get("language") or settings.STT_LANGUAGE,
                timestamps=settings.STT_ENABLE_TIMESTAMPS,
                diarize=settings.STT_ENABLE_DIARIZATION,
            )
            return make_response(jsonify({"success": True, **transcript}), 200)
        except Exception as err:
            current_app.logger.error(f"Error transcribing audio: {err}", exc_info=True)
            return make_response(
                jsonify({"success": False, "message": "Failed to transcribe audio"}),
                400,
            )
        finally:
            if temp_path and temp_path.exists():
                temp_path.unlink()


@attachments_ns.route("/stt/live/start")
class LiveSpeechToTextStart(Resource):
    @api.doc(description="Start a live speech-to-text session")
    def post(self):
        auth_user = _resolve_authenticated_user()
        if hasattr(auth_user, "status_code"):
            return auth_user
        if not auth_user:
            return make_response(
                jsonify({"success": False, "message": "Authentication required"}), 401
            )

        redis_client = _require_live_stt_redis()
        if hasattr(redis_client, "status_code"):
            return redis_client

        payload = request.get_json(silent=True) or {}
        session_state = create_live_stt_session(
            user=auth_user,
            language=payload.get("language") or settings.STT_LANGUAGE,
        )
        save_live_stt_session(redis_client, session_state)

        return make_response(
            jsonify(
                {
                    "success": True,
                    "session_id": session_state["session_id"],
                    "language": session_state.get("language"),
                    "committed_text": "",
                    "mutable_text": "",
                    "previous_hypothesis": "",
                    "latest_hypothesis": "",
                    "finalized_text": "",
                    "pending_text": "",
                    "transcript_text": "",
                }
            ),
            200,
        )


@attachments_ns.route("/stt/live/chunk")
class LiveSpeechToTextChunk(Resource):
    @api.expect(
        api.model(
            "LiveSpeechToTextChunkModel",
            {
                "session_id": fields.String(
                    required=True, description="Live transcription session ID"
                ),
                "chunk_index": fields.Integer(
                    required=True, description="Sequential chunk index"
                ),
                "is_silence": fields.Boolean(
                    required=False,
                    description="Whether the latest capture window was mostly silence",
                ),
                "file": fields.Raw(required=True, description="Audio chunk"),
            },
        )
    )
    @api.doc(description="Transcribe a chunk for a live speech-to-text session")
    def post(self):
        auth_user = _resolve_authenticated_user()
        if hasattr(auth_user, "status_code"):
            return auth_user
        if not auth_user:
            return make_response(
                jsonify({"success": False, "message": "Authentication required"}), 401
            )

        redis_client = _require_live_stt_redis()
        if hasattr(redis_client, "status_code"):
            return redis_client

        session_id = request.form.get("session_id", "").strip()
        if not session_id:
            return make_response(
                jsonify({"success": False, "message": "Missing session_id"}), 400
            )

        session_state = load_live_stt_session(redis_client, session_id)
        if not session_state:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": "Live transcription session not found",
                    }
                ),
                404,
            )

        if safe_filename(str(session_state.get("user", ""))) != auth_user:
            return make_response(
                jsonify({"success": False, "message": "Forbidden"}), 403
            )

        chunk_index_raw = request.form.get("chunk_index", "").strip()
        if chunk_index_raw == "":
            return make_response(
                jsonify({"success": False, "message": "Missing chunk_index"}), 400
            )

        try:
            chunk_index = int(chunk_index_raw)
        except ValueError:
            return make_response(
                jsonify({"success": False, "message": "Invalid chunk_index"}), 400
            )
        is_silence = _parse_bool_form_value(request.form.get("is_silence"))

        file = request.files.get("file")
        if not file or file.filename == "":
            return make_response(
                jsonify({"success": False, "message": "Missing file"}), 400
            )

        filename = safe_filename(os.path.basename(file.filename))
        suffix = Path(filename).suffix.lower()
        if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
            return make_response(
                jsonify({"success": False, "message": "Unsupported audio format"}), 400
            )

        if not _is_supported_audio_mimetype(file.mimetype or ""):
            return make_response(
                jsonify({"success": False, "message": "Unsupported audio MIME type"}),
                400,
            )

        try:
            _enforce_uploaded_audio_size_limit(file, filename)
        except AudioFileTooLargeError:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": build_stt_file_size_limit_message(),
                    }
                ),
                413,
            )

        temp_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                file.save(temp_file.name)
                temp_path = Path(temp_file.name)

            session_language = session_state.get("language") or settings.STT_LANGUAGE
            stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
            transcript = stt_instance.transcribe(
                temp_path,
                language=session_language,
                timestamps=False,
                diarize=False,
            )
            if not session_state.get("language") and transcript.get("language"):
                session_state["language"] = transcript["language"]

            try:
                apply_live_stt_hypothesis(
                    session_state,
                    str(transcript.get("text", "")),
                    chunk_index,
                    is_silence=is_silence,
                )
            except ValueError:
                current_app.logger.warning(
                    "Invalid live transcription chunk",
                    exc_info=True,
                )
                return make_response(
                    jsonify(
                        {
                            "success": False,
                            "message": "Invalid live transcription chunk",
                        }
                    ),
                    409,
                )
            save_live_stt_session(redis_client, session_state)

            return make_response(
                jsonify(
                    {
                        "success": True,
                        "session_id": session_id,
                        "chunk_index": chunk_index,
                        "chunk_text": transcript.get("text", ""),
                        "is_silence": is_silence,
                        "language": session_state.get("language"),
                        "committed_text": session_state.get("committed_text", ""),
                        "mutable_text": session_state.get("mutable_text", ""),
                        "previous_hypothesis": session_state.get(
                            "previous_hypothesis", ""
                        ),
                        "latest_hypothesis": session_state.get(
                            "latest_hypothesis", ""
                        ),
                        "finalized_text": session_state.get("committed_text", ""),
                        "pending_text": session_state.get("mutable_text", ""),
                        "transcript_text": get_live_stt_transcript_text(session_state),
                    }
                ),
                200,
            )
        except Exception as err:
            current_app.logger.error(
                f"Error transcribing live audio chunk: {err}", exc_info=True
            )
            return make_response(
                jsonify({"success": False, "message": "Failed to transcribe audio"}),
                400,
            )
        finally:
            if temp_path and temp_path.exists():
                temp_path.unlink()


@attachments_ns.route("/stt/live/finish")
class LiveSpeechToTextFinish(Resource):
    @api.doc(description="Finish a live speech-to-text session")
    def post(self):
        auth_user = _resolve_authenticated_user()
        if hasattr(auth_user, "status_code"):
            return auth_user
        if not auth_user:
            return make_response(
                jsonify({"success": False, "message": "Authentication required"}), 401
            )

        redis_client = _require_live_stt_redis()
        if hasattr(redis_client, "status_code"):
            return redis_client

        payload = request.get_json(silent=True) or {}
        session_id = str(payload.get("session_id", "")).strip()
        if not session_id:
            return make_response(
                jsonify({"success": False, "message": "Missing session_id"}), 400
            )

        session_state = load_live_stt_session(redis_client, session_id)
        if not session_state:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": "Live transcription session not found",
                    }
                ),
                404,
            )

        if safe_filename(str(session_state.get("user", ""))) != auth_user:
            return make_response(
                jsonify({"success": False, "message": "Forbidden"}), 403
            )

        final_text = finalize_live_stt_session(session_state)
        delete_live_stt_session(redis_client, session_id)

        return make_response(
            jsonify(
                {
                    "success": True,
                    "session_id": session_id,
                    "language": session_state.get("language"),
                    "text": final_text,
                }
            ),
            200,
        )


@attachments_ns.route("/images/<path:image_path>")
class ServeImage(Resource):
    @api.doc(description="Serve an image from storage")
    def get(self, image_path):
        if ".." in image_path or image_path.startswith("/") or "\x00" in image_path:
            return make_response(
                jsonify({"success": False, "message": "Invalid image path"}), 400
            )
        try:
            from application.api.user.base import storage

            file_obj = storage.get_file(image_path)
            extension = image_path.split(".")[-1].lower()
            content_type = f"image/{extension}"
            if extension == "jpg":
                content_type = "image/jpeg"
            response = make_response(file_obj.read())
            response.headers.set("Content-Type", content_type)
            response.headers.set("Cache-Control", "max-age=86400")

            return response
        except FileNotFoundError:
            return make_response(
                jsonify({"success": False, "message": "Image not found"}), 404
            )
        except ValueError:
            return make_response(
                jsonify({"success": False, "message": "Invalid image path"}), 400
            )
        except Exception as e:
            current_app.logger.error(f"Error serving image: {e}")
            return make_response(
                jsonify({"success": False, "message": "Error retrieving image"}), 500
            )


@attachments_ns.route("/tts")
class TextToSpeech(Resource):
    tts_model = api.model(
        "TextToSpeechModel",
        {
            "text": fields.String(
                required=True, description="Text to be synthesized as audio"
            ),
        },
    )

    @api.expect(tts_model)
    @api.doc(description="Synthesize audio speech from text")
    def post(self):
        data = request.get_json()
        text = data["text"]
        try:
            tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
            audio_base64, detected_language = tts_instance.text_to_speech(text)
            return make_response(
                jsonify(
                    {
                        "success": True,
                        "audio_base64": audio_base64,
                        "lang": detected_language,
                    }
                ),
                200,
            )
        except Exception as err:
            current_app.logger.error(f"Error synthesizing audio: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
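Taken together, the three live-STT routes form a start/chunk/finish protocol. A hedged client sketch follows; the base URL, JWT header, and local chunk files are assumptions, while the form and JSON field names mirror what the handlers above parse:

import requests

BASE = "http://localhost:7091/api"           # assumed dev server URL
HEADERS = {"Authorization": "Bearer <jwt>"}  # assumed auth scheme

session = requests.post(
    f"{BASE}/stt/live/start", json={"language": "en"}, headers=HEADERS
).json()
session_id = session["session_id"]

for idx, chunk_path in enumerate(["chunk0.webm", "chunk1.webm"]):  # assumed files
    with open(chunk_path, "rb") as fh:
        resp = requests.post(
            f"{BASE}/stt/live/chunk",
            headers=HEADERS,
            data={"session_id": session_id, "chunk_index": idx, "is_silence": "false"},
            files={"file": fh},
        ).json()
    print(resp["transcript_text"])  # committed + pending text so far

final = requests.post(
    f"{BASE}/stt/live/finish", json={"session_id": session_id}, headers=HEADERS
).json()
print(final["text"])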
236	application/api/user/base.py	Normal file
@@ -0,0 +1,236 @@
"""
Shared utilities, database connections, and helper functions for user API routes.
"""

import datetime
import os
import uuid
from functools import wraps
from typing import Optional, Tuple

from flask import current_app, jsonify, make_response, Response
from werkzeug.utils import secure_filename

from sqlalchemy import text as _sql_text

from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.repositories.users import UsersRepository
from application.storage.db.session import db_readonly, db_session
from application.storage.storage_creator import StorageCreator
from application.vectorstore.vector_creator import VectorCreator


storage = StorageCreator.get_storage()


current_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)


def generate_minute_range(start_date, end_date):
    """Generate a dictionary with minute-level time ranges."""
    return {
        (start_date + datetime.timedelta(minutes=i)).strftime("%Y-%m-%d %H:%M:00"): 0
        for i in range(int((end_date - start_date).total_seconds() // 60) + 1)
    }


def generate_hourly_range(start_date, end_date):
    """Generate a dictionary with hourly time ranges."""
    return {
        (start_date + datetime.timedelta(hours=i)).strftime("%Y-%m-%d %H:00"): 0
        for i in range(int((end_date - start_date).total_seconds() // 3600) + 1)
    }


def generate_date_range(start_date, end_date):
    """Generate a dictionary with daily date ranges."""
    return {
        (start_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d"): 0
        for i in range((end_date - start_date).days + 1)
    }


def ensure_user_doc(user_id):
    """
    Ensure a Postgres ``users`` row exists for ``user_id``.

    Returns the row as a dict with the shape legacy callers expect — in
    particular ``user_id`` and ``agent_preferences`` (with ``pinned`` and
    ``shared_with_me`` list keys always present).

    Args:
        user_id: The user ID to ensure

    Returns:
        The user document as a dict.
    """
    with db_session() as conn:
        user_doc = UsersRepository(conn).upsert(user_id)

    prefs = user_doc.get("agent_preferences") or {}
    if not isinstance(prefs, dict):
        prefs = {}
    prefs.setdefault("pinned", [])
    prefs.setdefault("shared_with_me", [])
    user_doc["agent_preferences"] = prefs
    return user_doc


def resolve_tool_details(tool_ids):
    """
    Resolve tool IDs to their display details.

    Accepts either Postgres UUIDs or legacy Mongo ObjectId strings (mixed
    lists are supported — each id is looked up via ``get_any``, which
    resolves to whichever column matches). Unknown ids are silently
    skipped.

    Args:
        tool_ids: List of tool IDs (UUIDs or legacy Mongo ObjectId strings).

    Returns:
        List of tool details with ``id``, ``name``, and ``display_name``.
    """
    if not tool_ids:
        return []

    uuid_ids: list[str] = []
    legacy_ids: list[str] = []
    for tid in tool_ids:
        if not tid:
            continue
        tid_str = str(tid)
        if looks_like_uuid(tid_str):
            uuid_ids.append(tid_str)
        else:
            legacy_ids.append(tid_str)

    if not uuid_ids and not legacy_ids:
        return []

    rows: list[dict] = []
    with db_readonly() as conn:
        if uuid_ids:
            result = conn.execute(
                _sql_text(
                    "SELECT * FROM user_tools "
                    "WHERE id = ANY(CAST(:ids AS uuid[]))"
                ),
                {"ids": uuid_ids},
            )
            rows.extend(row_to_dict(r) for r in result.fetchall())
        if legacy_ids:
            result = conn.execute(
                _sql_text(
                    "SELECT * FROM user_tools "
                    "WHERE legacy_mongo_id = ANY(:ids)"
                ),
                {"ids": legacy_ids},
            )
            rows.extend(row_to_dict(r) for r in result.fetchall())

    return [
        {
            "id": str(tool.get("id") or tool.get("legacy_mongo_id") or ""),
            "name": tool.get("name", "") or "",
            "display_name": (
                tool.get("custom_name")
                or tool.get("display_name")
                or tool.get("name", "")
                or ""
            ),
        }
        for tool in rows
    ]


def get_vector_store(source_id):
    """
    Get the Vector Store for a given source ID.

    Args:
        source_id (str): source id of the document

    Returns:
        Vector store instance
    """
    store = VectorCreator.create_vectorstore(
        settings.VECTOR_STORE,
        source_id=source_id,
        embeddings_key=os.getenv("EMBEDDINGS_KEY"),
    )
    return store


def handle_image_upload(
    request, existing_url: str, user: str, storage, base_path: str = "attachments/"
) -> Tuple[str, Optional[Response]]:
    """
    Handle image file upload from request.

    Args:
        request: Flask request object
        existing_url: Existing image URL (fallback)
        user: User ID
        storage: Storage instance
        base_path: Base path for upload

    Returns:
        Tuple of (image_url, error_response)
    """
    image_url = existing_url

    if "image" in request.files:
        file = request.files["image"]
        if file.filename != "":
            filename = secure_filename(file.filename)
            upload_path = f"{settings.UPLOAD_FOLDER.rstrip('/')}/{user}/{base_path.rstrip('/')}/{uuid.uuid4()}_{filename}"
            try:
                storage.save_file(file, upload_path, storage_class="STANDARD")
                image_url = upload_path
            except Exception as e:
                current_app.logger.error(f"Error uploading image: {e}")
                return None, make_response(
                    jsonify({"success": False, "message": "Image upload failed"}), 400
                )
    return image_url, None


def require_agent(func):
    """
    Decorator to require valid agent webhook token.

    Args:
        func: Function to decorate

    Returns:
        Wrapped function
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        from application.storage.db.repositories.agents import AgentsRepository

        webhook_token = kwargs.get("webhook_token")
        if not webhook_token:
            return make_response(
                jsonify({"success": False, "message": "Webhook token missing"}), 400
            )
        with db_readonly() as conn:
            agent = AgentsRepository(conn).find_by_webhook_token(webhook_token)
        if not agent:
            current_app.logger.warning(
                f"Webhook attempt with invalid token: {webhook_token}"
            )
            return make_response(
                jsonify({"success": False, "message": "Agent not found"}), 404
            )
        kwargs["agent"] = agent
        kwargs["agent_id_str"] = str(agent["id"])
        return func(*args, **kwargs)

    return wrapper
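For reference, the three range helpers defined above produce pre-zeroed dicts keyed exactly like the analytics bucket labels, so a caller can fill only the buckets that have data. For example:

import datetime

start = datetime.datetime(2024, 5, 1)
end = datetime.datetime(2024, 5, 3)
assert generate_date_range(start, end) == {
    "2024-05-01": 0,
    "2024-05-02": 0,
    "2024-05-03": 0,
}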
5	application/api/user/conversations/__init__.py	Normal file
@@ -0,0 +1,5 @@
"""Conversation management module."""

from .routes import conversations_ns

__all__ = ["conversations_ns"]
303	application/api/user/conversations/routes.py	Normal file
@@ -0,0 +1,303 @@
"""Conversation management routes."""

import datetime

from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource

from application.api import api
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields

conversations_ns = Namespace(
    "conversations", description="Conversation management operations", path="/api"
)


@conversations_ns.route("/delete_conversation")
class DeleteConversation(Resource):
    @api.doc(
        description="Deletes a conversation by ID",
        params={"id": "The ID of the conversation to delete"},
    )
    def post(self):
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        conversation_id = request.args.get("id")
        if not conversation_id:
            return make_response(
                jsonify({"success": False, "message": "ID is required"}), 400
            )
        user_id = decoded_token["sub"]
        try:
            with db_session() as conn:
                repo = ConversationsRepository(conn)
                conv = repo.get_any(conversation_id, user_id)
                if conv is not None:
                    repo.delete(str(conv["id"]), user_id)
        except Exception as err:
            current_app.logger.error(
                f"Error deleting conversation: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)


@conversations_ns.route("/delete_all_conversations")
class DeleteAllConversations(Resource):
    @api.doc(
        description="Deletes all conversations for a specific user",
    )
    def get(self):
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user_id = decoded_token.get("sub")
        try:
            with db_session() as conn:
                ConversationsRepository(conn).delete_all_for_user(user_id)
        except Exception as err:
            current_app.logger.error(
                f"Error deleting all conversations: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)


@conversations_ns.route("/get_conversations")
class GetConversations(Resource):
    @api.doc(
        description="Retrieve a list of the latest 30 conversations (excluding API key conversations)",
    )
    def get(self):
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user_id = decoded_token.get("sub")
        try:
            with db_readonly() as conn:
                conversations = ConversationsRepository(conn).list_for_user(
                    user_id, limit=30
                )
            list_conversations = [
                {
                    "id": str(conversation["id"]),
                    "name": conversation["name"],
                    "agent_id": (
                        str(conversation["agent_id"])
                        if conversation.get("agent_id")
                        else None
                    ),
                    "is_shared_usage": conversation.get("is_shared_usage", False),
                    "shared_token": conversation.get("shared_token", None),
                }
                for conversation in conversations
            ]
        except Exception as err:
            current_app.logger.error(
                f"Error retrieving conversations: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify(list_conversations), 200)


@conversations_ns.route("/get_single_conversation")
class GetSingleConversation(Resource):
    @api.doc(
        description="Retrieve a single conversation by ID",
        params={"id": "The conversation ID"},
    )
    def get(self):
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        conversation_id = request.args.get("id")
        if not conversation_id:
            return make_response(
                jsonify({"success": False, "message": "ID is required"}), 400
            )
        user_id = decoded_token.get("sub")
        try:
            with db_readonly() as conn:
                repo = ConversationsRepository(conn)
                conversation = repo.get_any(conversation_id, user_id)
                if not conversation:
                    return make_response(jsonify({"status": "not found"}), 404)
                conv_pg_id = str(conversation["id"])
                messages = repo.get_messages(conv_pg_id)

                # Resolve attachment details (id, fileName) for each message.
                attachments_repo = AttachmentsRepository(conn)
                queries = []
                for msg in messages:
                    query = {
                        "prompt": msg.get("prompt"),
                        "response": msg.get("response"),
                        "thought": msg.get("thought"),
                        "sources": msg.get("sources") or [],
                        "tool_calls": msg.get("tool_calls") or [],
                        "timestamp": msg.get("timestamp"),
                        "model_id": msg.get("model_id"),
                    }
                    if msg.get("metadata"):
                        query["metadata"] = msg["metadata"]
                    # Feedback on conversation_messages is a JSONB blob with
                    # shape {"text": <str>, "timestamp": <iso>}. The legacy
                    # frontend consumed a flat scalar feedback string, so
                    # unwrap the ``text`` field for compat.
                    feedback = msg.get("feedback")
                    if feedback is not None:
                        if isinstance(feedback, dict):
                            query["feedback"] = feedback.get("text")
                            if feedback.get("timestamp"):
                                query["feedback_timestamp"] = feedback["timestamp"]
                        else:
                            query["feedback"] = feedback
                    attachments = msg.get("attachments") or []
                    if attachments:
                        attachment_details = []
                        for attachment_id in attachments:
                            try:
                                att = attachments_repo.get_any(
                                    str(attachment_id), user_id
                                )
                                if att:
                                    attachment_details.append(
                                        {
                                            "id": str(att["id"]),
                                            "fileName": att.get(
                                                "filename", "Unknown file"
                                            ),
                                        }
                                    )
                            except Exception as e:
                                current_app.logger.error(
                                    f"Error retrieving attachment {attachment_id}: {e}",
                                    exc_info=True,
                                )
                        query["attachments"] = attachment_details
                    queries.append(query)
        except Exception as err:
            current_app.logger.error(
                f"Error retrieving conversation: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        data = {
            "queries": queries,
            "agent_id": (
                str(conversation["agent_id"]) if conversation.get("agent_id") else None
            ),
            "is_shared_usage": conversation.get("is_shared_usage", False),
            "shared_token": conversation.get("shared_token", None),
        }
        return make_response(jsonify(data), 200)


@conversations_ns.route("/update_conversation_name")
class UpdateConversationName(Resource):
    @api.expect(
        api.model(
            "UpdateConversationModel",
            {
                "id": fields.String(required=True, description="Conversation ID"),
                "name": fields.String(
                    required=True, description="New name of the conversation"
                ),
            },
        )
    )
    @api.doc(
        description="Updates the name of a conversation",
    )
    def post(self):
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        data = request.get_json()
        required_fields = ["id", "name"]
        missing_fields = check_required_fields(data, required_fields)
        if missing_fields:
            return missing_fields
        user_id = decoded_token.get("sub")
        try:
            with db_session() as conn:
                repo = ConversationsRepository(conn)
                conv = repo.get_any(data["id"], user_id)
                if conv is not None:
                    repo.rename(str(conv["id"]), user_id, data["name"])
        except Exception as err:
            current_app.logger.error(
                f"Error updating conversation name: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)


@conversations_ns.route("/feedback")
class SubmitFeedback(Resource):
    @api.expect(
        api.model(
            "FeedbackModel",
            {
                "question": fields.String(
                    required=False, description="The user question"
                ),
                "answer": fields.String(required=False, description="The AI answer"),
                "feedback": fields.String(required=True, description="User feedback"),
                "question_index": fields.Integer(
                    required=True,
                    description="The question number in that particular conversation",
                ),
                "conversation_id": fields.String(
                    required=True, description="id of the particular conversation"
                ),
                "api_key": fields.String(description="Optional API key"),
            },
        )
    )
    @api.doc(
        description="Submit feedback for a conversation",
    )
    def post(self):
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        data = request.get_json()
        required_fields = ["feedback", "conversation_id", "question_index"]
        missing_fields = check_required_fields(data, required_fields)
        if missing_fields:
            return missing_fields
        user_id = decoded_token.get("sub")
        feedback_value = data["feedback"]
        question_index = int(data["question_index"])
        # Normalize string feedback to lowercase so analytics queries
        # (which match 'like'/'dislike') count rows correctly. Tolerate
        # legacy uppercase clients on ingest. Non-string values pass through.
        if isinstance(feedback_value, str):
            feedback_value = feedback_value.lower()
        feedback_payload = (
            None
            if feedback_value is None
            else {
                "text": feedback_value,
                "timestamp": datetime.datetime.now(
                    datetime.timezone.utc
                ).isoformat(),
            }
        )
        try:
            with db_session() as conn:
                repo = ConversationsRepository(conn)
                conv = repo.get_any(data["conversation_id"], user_id)
                if conv is None:
                    return make_response(
                        jsonify({"success": False, "message": "Not found"}), 404
                    )
                repo.set_feedback(str(conv["id"]), question_index, feedback_payload)
        except Exception as err:
            current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)
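A hedged example of the payload this /api/feedback route expects. The URL and bearer token are assumptions and the conversation id is an illustrative placeholder; the field names mirror the FeedbackModel above, and the handler lowercases the feedback string before storing it in the message's JSONB blob:

import requests

resp = requests.post(
    "http://localhost:7091/api/feedback",       # assumed dev server URL
    headers={"Authorization": "Bearer <jwt>"},  # assumed auth scheme
    json={
        "conversation_id": "<conversation uuid>",  # e.g. from /api/get_conversations
        "question_index": 0,
        "feedback": "LIKE",  # stored as {"text": "like", "timestamp": ...}
    },
)
print(resp.status_code, resp.json())  # 200 {"success": true} on success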
3	application/api/user/models/__init__.py	Normal file
@@ -0,0 +1,3 @@
from .routes import models_ns

__all__ = ["models_ns"]
Some files were not shown because too many files have changed in this diff.