mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 06:30:03 +00:00
Compare commits
1127 Commits
tool-proxi
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5d0eed6084 | ||
|
|
2fed5c882b | ||
|
|
aa938d76d7 | ||
|
|
2940628aa6 | ||
|
|
7f23928134 | ||
|
|
20e17c84c7 | ||
|
|
389ddf6068 | ||
|
|
1e2443fb90 | ||
|
|
6387bd1892 | ||
|
|
7d22724d1c | ||
|
|
f6f12f6895 | ||
|
|
934127f323 | ||
|
|
1780e3cc91 | ||
|
|
5e7fab2f34 | ||
|
|
92ae76f95e | ||
|
|
18755bdd9b | ||
|
|
0f20adcbf4 | ||
|
|
18e2a829c9 | ||
|
|
cd44501a71 | ||
|
|
f8ebdf3fd4 | ||
|
|
7c6fca18ad | ||
|
|
5fab798707 | ||
|
|
cb30a24e05 | ||
|
|
530761d08c | ||
|
|
73fbc28744 | ||
|
|
b5b6538762 | ||
|
|
a9761061fc | ||
|
|
9388996a15 | ||
|
|
875868b7e5 | ||
|
|
502819ae52 | ||
|
|
cada1a44fc | ||
|
|
6192767451 | ||
|
|
5c3e6eca54 | ||
|
|
59d9d4ac50 | ||
|
|
3931ccccee | ||
|
|
55717043f6 | ||
|
|
ececcb8b17 | ||
|
|
420e9d3dd5 | ||
|
|
749eed3d0b | ||
|
|
bd03a513e3 | ||
|
|
fcdb4fb5e8 | ||
|
|
e787c896eb | ||
|
|
23aeaff5db | ||
|
|
689dd79597 | ||
|
|
0c15af90b1 | ||
|
|
cdd6ff6557 | ||
|
|
72b3d94453 | ||
|
|
7e88d09e5d | ||
|
|
74a4a237dc | ||
|
|
c3f01c6619 | ||
|
|
6b408823d4 | ||
|
|
3fc81ac5d8 | ||
|
|
2652f8a5b0 | ||
|
|
d711eefe96 | ||
|
|
79206f3919 | ||
|
|
de971d9452 | ||
|
|
1b4d5ca0dd | ||
|
|
81989e8258 | ||
|
|
dc262d1698 | ||
|
|
69f9c93869 | ||
|
|
74bf80b25c | ||
|
|
d9a92a7208 | ||
|
|
02e93d993d | ||
|
|
6b6495f48c | ||
|
|
249dd9ce37 | ||
|
|
9134ab0478 | ||
|
|
10ef68c9d0 | ||
|
|
7d65cf1c2b | ||
|
|
13c6cc59c1 | ||
|
|
6381f7dd4e | ||
|
|
e6ac4008fe | ||
|
|
1af09f114d | ||
|
|
be7da983e7 | ||
|
|
8b9e595d85 | ||
|
|
398f3acc8d | ||
|
|
e04baa7ed8 | ||
|
|
e5586b6f20 | ||
|
|
addf57cab7 | ||
|
|
648b3f1d20 | ||
|
|
a75a9e23f9 | ||
|
|
73256389cf | ||
|
|
d609efca49 | ||
|
|
772860b667 | ||
|
|
ea2fd8b04a | ||
|
|
2c73deac20 | ||
|
|
47f3907e5e | ||
|
|
727495c553 | ||
|
|
a3b08a5b44 | ||
|
|
81532ada2a | ||
|
|
43f71374e5 | ||
|
|
d5c0322e2a | ||
|
|
3b66a3176c | ||
|
|
dc6db847ca | ||
|
|
ed0063aada | ||
|
|
9a6a55b6da | ||
|
|
12a8368216 | ||
|
|
3f6d6f15ea | ||
|
|
126fa01b14 | ||
|
|
e06debad5f | ||
|
|
6492852f7d | ||
|
|
00a621f33a | ||
|
|
e92ffc6fdc | ||
|
|
fe185e5b8d | ||
|
|
9f3d9ab860 | ||
|
|
1c0adde380 | ||
|
|
3c56bd0d0b | ||
|
|
86664ebda2 | ||
|
|
db18b743d1 | ||
|
|
9e85cc9065 | ||
|
|
aaaa6f002d | ||
|
|
47dcbcb74b | ||
|
|
ddbfd94193 | ||
|
|
8dec60ab8b | ||
|
|
84b2e4bab4 | ||
|
|
193ca6fd63 | ||
|
|
2afdd7f026 | ||
|
|
f364475f64 | ||
|
|
b254de6ed6 | ||
|
|
08dedcaf95 | ||
|
|
c726eb8ebd | ||
|
|
5f0d39e5f1 | ||
|
|
8c82fc5495 | ||
|
|
6d81a15e97 | ||
|
|
5478e4234c | ||
|
|
4056278fef | ||
|
|
ee6530fe00 | ||
|
|
7c1decbcc3 | ||
|
|
8a3c724b31 | ||
|
|
15d4e9dbf5 | ||
|
|
174dee0fe6 | ||
|
|
f7bfd38b28 | ||
|
|
187e5da61e | ||
|
|
175ed58d2e | ||
|
|
820ee3a843 | ||
|
|
462f2e9494 | ||
|
|
c4968a641e | ||
|
|
c6ece177cd | ||
|
|
a3e6a5622d | ||
|
|
e8d11fdfa6 | ||
|
|
72393dc369 | ||
|
|
556b0a1da5 | ||
|
|
844167ba06 | ||
|
|
6fa3acb1ca | ||
|
|
32c268a21e | ||
|
|
ed34c2b929 | ||
|
|
06e827573c | ||
|
|
74e76d4cda | ||
|
|
db5c69ca76 | ||
|
|
9fd063266b | ||
|
|
05aa9d7cca | ||
|
|
dcececd118 | ||
|
|
eaf39bb15b | ||
|
|
6515481624 | ||
|
|
6a7e3b6d77 | ||
|
|
02804fecce | ||
|
|
324a8cd4cf | ||
|
|
ce5cd5561a | ||
|
|
adeefce9aa | ||
|
|
5ab43fd12c | ||
|
|
5894e47189 | ||
|
|
ca61d81f4a | ||
|
|
b12d0ca7b1 | ||
|
|
21996af626 | ||
|
|
cc3b174e5a | ||
|
|
faee58fb1e | ||
|
|
d439e48b39 | ||
|
|
3f0f155d64 | ||
|
|
d82d512319 | ||
|
|
76aea1716f | ||
|
|
586649b73f | ||
|
|
0349a79cb3 | ||
|
|
78a255bdd7 | ||
|
|
5b30e71aa1 | ||
|
|
99d84aece9 | ||
|
|
525d8eb66d | ||
|
|
4c810108e0 | ||
|
|
fc03cdc76a | ||
|
|
9779a563f3 | ||
|
|
6141c3c348 | ||
|
|
c3726ddfc9 | ||
|
|
10eaa8143e | ||
|
|
0c4f4e1f0c | ||
|
|
b225c3cd80 | ||
|
|
b558645d6b | ||
|
|
03b0889b15 | ||
|
|
943fe3651c | ||
|
|
65e57be4dd | ||
|
|
13ad3b5dce | ||
|
|
918bbf0369 | ||
|
|
5006271abb | ||
|
|
a6625ec5de | ||
|
|
1a2104f474 | ||
|
|
444abb8283 | ||
|
|
ee86537f21 | ||
|
|
17a736a927 | ||
|
|
6b5779054d | ||
|
|
14296632ef | ||
|
|
2a3f0e455a | ||
|
|
8aa44c415b | ||
|
|
0566c41a32 | ||
|
|
876b04c058 | ||
|
|
b49a5934e2 | ||
|
|
5fb063914e | ||
|
|
b9941e29a9 | ||
|
|
8ef321d784 | ||
|
|
8353f9c649 | ||
|
|
cb6b3aa406 | ||
|
|
36c7bd9206 | ||
|
|
fea94379d7 | ||
|
|
e602d941ca | ||
|
|
f41f69a268 | ||
|
|
ff72251878 | ||
|
|
7751fb52dd | ||
|
|
87a44d101d | ||
|
|
80148f25b6 | ||
|
|
8e3e4a8b09 | ||
|
|
9389b4a1e8 | ||
|
|
4245e5bd2e | ||
|
|
e7d2af2405 | ||
|
|
4c32a96370 | ||
|
|
f61d112cea | ||
|
|
2c55c6cd9a | ||
|
|
f1d714b5c1 | ||
|
|
69d9dc672a | ||
|
|
9192e010e8 | ||
|
|
f24cea0877 | ||
|
|
a29bfa7489 | ||
|
|
2246866a09 | ||
|
|
7b17fde34a | ||
|
|
df57053613 | ||
|
|
5662be12b5 | ||
|
|
d3e9d66b07 | ||
|
|
e0bdbcbe38 | ||
|
|
05c835ed02 | ||
|
|
9e7f1ad1c0 | ||
|
|
f910a82683 | ||
|
|
d8b7e86f8d | ||
|
|
aef3e0b4bb | ||
|
|
b0eee7be24 | ||
|
|
197e94302b | ||
|
|
98e949d2fd | ||
|
|
83e7a928f1 | ||
|
|
ccd29b7d4e | ||
|
|
5b6cfa6ecc | ||
|
|
f91846ce2d | ||
|
|
87e24ab96e | ||
|
|
40c3e5568c | ||
|
|
7958d29e13 | ||
|
|
a6fafa6a4d | ||
|
|
3ad38f53fd | ||
|
|
d90b1c57e5 | ||
|
|
a69a0e100f | ||
|
|
b0d4576a95 | ||
|
|
2a4ab3aca1 | ||
|
|
e0fd11a86e | ||
|
|
de369f8b5e | ||
|
|
af3e16c4fc | ||
|
|
aacf281222 | ||
|
|
6d8f083c6f | ||
|
|
909bc421c0 | ||
|
|
d14f04d79c | ||
|
|
e0a9f08632 | ||
|
|
09e7c1b97f | ||
|
|
4adffe762a | ||
|
|
9a937d2686 | ||
|
|
e68da34c13 | ||
|
|
9b9f95710a | ||
|
|
3352d42414 | ||
|
|
899b30da5e | ||
|
|
dc2faf7a7e | ||
|
|
67e0d222d1 | ||
|
|
17698ce774 | ||
|
|
7d1c8c008b | ||
|
|
9e58eb02b3 | ||
|
|
3f7de867cc | ||
|
|
2c2bdd37d5 | ||
|
|
6a00319c2d | ||
|
|
66870279d3 | ||
|
|
fbf7cf874b | ||
|
|
ba7278b80f | ||
|
|
9d649de6f9 | ||
|
|
7929afbf58 | ||
|
|
ceaf942e70 | ||
|
|
f355601a44 | ||
|
|
4ff99a1e86 | ||
|
|
129084ba92 | ||
|
|
2288df1293 | ||
|
|
d9dfac55e7 | ||
|
|
404cf4b7c7 | ||
|
|
f1c1fc123b | ||
|
|
9f19c7ee4c | ||
|
|
155e74eca1 | ||
|
|
ea2dc4dbcb | ||
|
|
616edc97de | ||
|
|
b017e99c79 | ||
|
|
f698e9d3e1 | ||
|
|
d366502850 | ||
|
|
3d6757c170 | ||
|
|
cb8302add8 | ||
|
|
9d266e9fad | ||
|
|
ae94c9d31e | ||
|
|
83ab232dcd | ||
|
|
eea85772a3 | ||
|
|
0fe7e223cc | ||
|
|
3789d2eb03 | ||
|
|
d54469532e | ||
|
|
9884e51836 | ||
|
|
6626723180 | ||
|
|
0c251e066b | ||
|
|
0957034bfa | ||
|
|
44521cd893 | ||
|
|
b17f846730 | ||
|
|
6dd32fd4ca | ||
|
|
b17b1c70b5 | ||
|
|
3f5b31fb5f | ||
|
|
06bda6bd55 | ||
|
|
7dd97821a8 | ||
|
|
695191d888 | ||
|
|
1dbcef24c7 | ||
|
|
e086c79da0 | ||
|
|
6ae8d34b27 | ||
|
|
2e23e547d3 | ||
|
|
fa11dc9828 | ||
|
|
673fa70bc5 | ||
|
|
a0660a54c1 | ||
|
|
1137bf4280 | ||
|
|
da41c898d8 | ||
|
|
21e5c261ef | ||
|
|
a7d61b9d59 | ||
|
|
c5fe25c149 | ||
|
|
6a4cb617f9 | ||
|
|
94f70e6de5 | ||
|
|
ab4ebf9a9d | ||
|
|
9f7945fcf5 | ||
|
|
d8ec3c008c | ||
|
|
2f00691246 | ||
|
|
9b2383b074 | ||
|
|
e4e9910575 | ||
|
|
f448e4a615 | ||
|
|
c4e8daf50e | ||
|
|
5aa4ec1b9f | ||
|
|
125ce0aad3 | ||
|
|
ababc9ae04 | ||
|
|
62ac90746e | ||
|
|
096f6d91a2 | ||
|
|
d28ef6b094 | ||
|
|
8fb945ab09 | ||
|
|
835d71727c | ||
|
|
ce32dd2907 | ||
|
|
72bc24a490 | ||
|
|
d6c49bdbf0 | ||
|
|
1805292528 | ||
|
|
d09ce7e1f7 | ||
|
|
a8d2024791 | ||
|
|
f0b954dbfb | ||
|
|
50bee7c2b0 | ||
|
|
e7b15b316e | ||
|
|
a4507008c1 | ||
|
|
c5ba85f929 | ||
|
|
2e636bd67e | ||
|
|
4a039f1abf | ||
|
|
434d8e2070 | ||
|
|
160ad2dc79 | ||
|
|
0ec86c2c71 | ||
|
|
03452ffd9f | ||
|
|
da6317a242 | ||
|
|
8b8e616557 | ||
|
|
d260f1a1a6 | ||
|
|
9d452e3b04 | ||
|
|
e012189672 | ||
|
|
4c31e9a8b1 | ||
|
|
7cfc230316 | ||
|
|
9605e85f1c | ||
|
|
498e2b772c | ||
|
|
dad897da51 | ||
|
|
02ad5f062e | ||
|
|
4eb9471b4f | ||
|
|
b505d207d7 | ||
|
|
3c954bd07f | ||
|
|
c00b6459dc | ||
|
|
eb4d776784 | ||
|
|
5d7a890533 | ||
|
|
9c6aefef1e | ||
|
|
e4554d6c09 | ||
|
|
c184b63df8 | ||
|
|
6bb4195393 | ||
|
|
7827a4d40d | ||
|
|
f09fa8231a | ||
|
|
96ff10000d | ||
|
|
9460636867 | ||
|
|
6c43245295 | ||
|
|
266b6cf638 | ||
|
|
70183e234a | ||
|
|
17b9c359ca | ||
|
|
045630b8a5 | ||
|
|
55ff7dd640 | ||
|
|
e6d64f71f2 | ||
|
|
e72313ebdd | ||
|
|
65d5bd72cd | ||
|
|
dc0cbb41f0 | ||
|
|
c4a54a85be | ||
|
|
5b2738aec9 | ||
|
|
892312fc08 | ||
|
|
c2ccf2c72c | ||
|
|
80aaecb5f0 | ||
|
|
946865a335 | ||
|
|
5de15c8413 | ||
|
|
67268fd35a | ||
|
|
42fc771833 | ||
|
|
444b1a0b65 | ||
|
|
814ea1c016 | ||
|
|
4d34dc4234 | ||
|
|
d567399f2b | ||
|
|
77f4f8d8b0 | ||
|
|
a2d04beaa1 | ||
|
|
ba49eea23d | ||
|
|
82beafc086 | ||
|
|
7d8ed2d102 | ||
|
|
aab8d3a4f1 | ||
|
|
76658d50a0 | ||
|
|
88ba22342c | ||
|
|
11a1460af9 | ||
|
|
2cd4c41316 | ||
|
|
b910f308f2 | ||
|
|
763aa73ea4 | ||
|
|
30c79e92d4 | ||
|
|
402d5e054b | ||
|
|
0e211df206 | ||
|
|
e24a0ac686 | ||
|
|
8c91b1c527 | ||
|
|
2b38f80d04 | ||
|
|
282bd35f52 | ||
|
|
cc9b4c2bcb | ||
|
|
068ce4970a | ||
|
|
cf19165ad8 | ||
|
|
68c479f3a5 | ||
|
|
ba496a772b | ||
|
|
3b27db36f2 | ||
|
|
f803def69b | ||
|
|
52065e69a4 | ||
|
|
50f5e8a955 | ||
|
|
2d0e97b66d | ||
|
|
5f3cc5a392 | ||
|
|
ac66d77512 | ||
|
|
50cf653d4a | ||
|
|
56256051d2 | ||
|
|
c0361ff03d | ||
|
|
f153435c08 | ||
|
|
9aa7f22fa6 | ||
|
|
52b7bda5f8 | ||
|
|
21aefa2778 | ||
|
|
a89ff71c9e | ||
|
|
4c275816be | ||
|
|
f8dfbcfc80 | ||
|
|
d317f6473d | ||
|
|
00b4e133d4 | ||
|
|
b6349e4efb | ||
|
|
6ca3d9585c | ||
|
|
5935a0283a | ||
|
|
5400a6ec06 | ||
|
|
6574d9cc84 | ||
|
|
42b83c5994 | ||
|
|
896612a5a3 | ||
|
|
0ee875bee4 | ||
|
|
8ce345cd94 | ||
|
|
da2f8477e6 | ||
|
|
82b47b5673 | ||
|
|
7c15a4c7ff | ||
|
|
3369b910b4 | ||
|
|
ec0c4c3b84 | ||
|
|
f74e2c9da1 | ||
|
|
e26ad3c475 | ||
|
|
145c3b8ad0 | ||
|
|
0ff6c6a154 | ||
|
|
641cf5a4c1 | ||
|
|
09b9576eef | ||
|
|
18b71ca2f2 | ||
|
|
e0eb7f456e | ||
|
|
188d118fc0 | ||
|
|
adcdce8d76 | ||
|
|
b865a7aec1 | ||
|
|
cec8c72b46 | ||
|
|
b052e32805 | ||
|
|
816f660be3 | ||
|
|
fc8be45d5a | ||
|
|
e749c936c9 | ||
|
|
b2b9670a23 | ||
|
|
2f88890c94 | ||
|
|
6366663f03 | ||
|
|
20fe7dc6d1 | ||
|
|
4b9153069e | ||
|
|
80406d0753 | ||
|
|
35f4c11784 | ||
|
|
7896526f19 | ||
|
|
f7db22edff | ||
|
|
0e4196f036 | ||
|
|
1bf6af6eeb | ||
|
|
5a9bc6d2bf | ||
|
|
f7f6042579 | ||
|
|
c4a598f3d3 | ||
|
|
7c23f43c63 | ||
|
|
7e2cbdd88c | ||
|
|
3b3a04a249 | ||
|
|
f9b2c95695 | ||
|
|
c2c18e8319 | ||
|
|
384ad3e0ac | ||
|
|
8c986aaa7f | ||
|
|
bb4ea76d30 | ||
|
|
2868e47cf8 | ||
|
|
e0adc3e5d5 | ||
|
|
e55d1a5865 | ||
|
|
018273c6b2 | ||
|
|
44b8a11c04 | ||
|
|
56e5aba559 | ||
|
|
46904ccd54 | ||
|
|
5b7c7a4471 | ||
|
|
9da4215d1f | ||
|
|
f39ac9945f | ||
|
|
a0cc2e4d46 | ||
|
|
4065041a9f | ||
|
|
f08067a161 | ||
|
|
545caacfa3 | ||
|
|
a06f646637 | ||
|
|
578c68205a | ||
|
|
f09f1433a9 | ||
|
|
15a9e97a1e | ||
|
|
b3af4ee50b | ||
|
|
07d59b6640 | ||
|
|
e25b988dc8 | ||
|
|
2410bd8654 | ||
|
|
44d21ab703 | ||
|
|
e283957c8f | ||
|
|
b1210c4902 | ||
|
|
e7430f0fbc | ||
|
|
92d6ae54c3 | ||
|
|
f82be23ca9 | ||
|
|
8c3f75e3e2 | ||
|
|
193d59f193 | ||
|
|
c2bebbaefa | ||
|
|
7ae5a9c5a5 | ||
|
|
3b69bea23d | ||
|
|
ab05726b99 | ||
|
|
b2b04268e9 | ||
|
|
bd73fa9ae7 | ||
|
|
927d10d66e | ||
|
|
b67329623c | ||
|
|
6f47aa802b | ||
|
|
3417c73011 | ||
|
|
6a02bcf15b | ||
|
|
cd0fbf79a3 | ||
|
|
15d2d0115b | ||
|
|
d1a0fe6e91 | ||
|
|
1db80d140f | ||
|
|
896dcf1f9e | ||
|
|
819a12fb49 | ||
|
|
c68273706c | ||
|
|
6bb0cd535a | ||
|
|
cb9ec69cf6 | ||
|
|
143854fa81 | ||
|
|
2f48a3d7d5 | ||
|
|
ec95dafe1e | ||
|
|
3d1fe724e5 | ||
|
|
5c615d6f2d | ||
|
|
d72558eb36 | ||
|
|
65c33ad915 | ||
|
|
9be128a963 | ||
|
|
eb05132008 | ||
|
|
f94a093e8c | ||
|
|
0d0c2daf64 | ||
|
|
823d948b25 | ||
|
|
56831fbcf2 | ||
|
|
bf49b9cb88 | ||
|
|
e01adffbad | ||
|
|
08a5d52d82 | ||
|
|
fdae235742 | ||
|
|
9903fad1e9 | ||
|
|
14bbd5338d | ||
|
|
4a236c2f6f | ||
|
|
0a8cdbd7f1 | ||
|
|
94c49843be | ||
|
|
9281fac898 | ||
|
|
0b2736f454 | ||
|
|
ae116b0d0d | ||
|
|
ba260e3382 | ||
|
|
1282e7687f | ||
|
|
b1d8266eef | ||
|
|
7acae6935b | ||
|
|
092c01cae7 | ||
|
|
56a1066c30 | ||
|
|
1356d71839 | ||
|
|
1eb011e8c3 | ||
|
|
e349eb28b0 | ||
|
|
b000b235a2 | ||
|
|
16fe92282e | ||
|
|
e218e88cf4 | ||
|
|
888ea81a32 | ||
|
|
735fab7640 | ||
|
|
45745c2a47 | ||
|
|
4caff0fcf6 | ||
|
|
762ea6ce7f | ||
|
|
8b4f6553f3 | ||
|
|
a61e44d175 | ||
|
|
e1b1558fc9 | ||
|
|
53225bda4e | ||
|
|
5212769848 | ||
|
|
d5ded3c9f4 | ||
|
|
c92d778894 | ||
|
|
829abd1ad6 | ||
|
|
266d256a07 | ||
|
|
8380cac3e7 | ||
|
|
a24652f901 | ||
|
|
2d203d3c70 | ||
|
|
48d21600da | ||
|
|
2508d0fbb3 | ||
|
|
e90e80c289 | ||
|
|
5e4748f9d9 | ||
|
|
212952f3e9 | ||
|
|
f99b6496c5 | ||
|
|
67423d51b9 | ||
|
|
58465ece65 | ||
|
|
8ede3a0173 | ||
|
|
ad2f0f8950 | ||
|
|
76973a4b4c | ||
|
|
b198e2e029 | ||
|
|
4d6ea401b5 | ||
|
|
b00c4cc3b6 | ||
|
|
4185e64c65 | ||
|
|
6eb2c884a2 | ||
|
|
6c0362a4cf | ||
|
|
50b1755a63 | ||
|
|
ff3c7eb5fb | ||
|
|
3755316d49 | ||
|
|
f952046847 | ||
|
|
969cdb4a63 | ||
|
|
f336d44595 | ||
|
|
a53f93c195 | ||
|
|
fcb334ce33 | ||
|
|
8ddf04a904 | ||
|
|
29698ca169 | ||
|
|
a9baf7436a | ||
|
|
99a8962183 | ||
|
|
afc5b15a6b | ||
|
|
b6ab508e27 | ||
|
|
789e65557a | ||
|
|
8a7806ab2d | ||
|
|
493303e103 | ||
|
|
1d9af05e9e | ||
|
|
5b07c5f2e8 | ||
|
|
2a4ec0cf5b | ||
|
|
a00c44386e | ||
|
|
a38d71bbfb | ||
|
|
a24a3f868c | ||
|
|
f60c516185 | ||
|
|
26f4646304 | ||
|
|
3a351f67e6 | ||
|
|
e7c09cb91e | ||
|
|
ae1a6ef303 | ||
|
|
2ff477a339 | ||
|
|
793f3fb683 | ||
|
|
a472ee7602 | ||
|
|
c62040e232 | ||
|
|
2e7cb510ae | ||
|
|
dbe45904d7 | ||
|
|
5623734276 | ||
|
|
d3b592bffc | ||
|
|
4fcbdae5bf | ||
|
|
ca95d7275a | ||
|
|
61baf3701c | ||
|
|
bbce872ac5 | ||
|
|
0f7ebcd8e4 | ||
|
|
82fc19e7b7 | ||
|
|
839a12bed4 | ||
|
|
2ef23fe1b3 | ||
|
|
fd905b1a06 | ||
|
|
1372210004 | ||
|
|
ade704d065 | ||
|
|
42f48649b9 | ||
|
|
0b08e8b617 | ||
|
|
926b2f1a1b | ||
|
|
1770a1a45f | ||
|
|
50ed2a64c6 | ||
|
|
2332344988 | ||
|
|
7ccc8cdc58 | ||
|
|
ecec9f913e | ||
|
|
777f40fc5e | ||
|
|
327ae35420 | ||
|
|
0d48159da8 | ||
|
|
d36f12a4ea | ||
|
|
709488beb1 | ||
|
|
a9e4583695 | ||
|
|
4702dec933 | ||
|
|
e6352dd691 | ||
|
|
240ea3b857 | ||
|
|
f0908af3c0 | ||
|
|
6834961dd1 | ||
|
|
b404162364 | ||
|
|
e879ef805f | ||
|
|
7077ca5e98 | ||
|
|
a1e6978c8f | ||
|
|
584391dd59 | ||
|
|
bab3ae809c | ||
|
|
c78518baf0 | ||
|
|
556d7e0497 | ||
|
|
2d27936dab | ||
|
|
0cc22de545 | ||
|
|
63f6127049 | ||
|
|
f34e00c986 | ||
|
|
55f60a9fe1 | ||
|
|
7da3618e0c | ||
|
|
56bfa98633 | ||
|
|
96f6188722 | ||
|
|
aa9d359039 | ||
|
|
cef5731028 | ||
|
|
5bc28bd4fd | ||
|
|
55a1d867c3 | ||
|
|
6c3a79802e | ||
|
|
c35c5e0793 | ||
|
|
7bc83caa99 | ||
|
|
3aceca63c6 | ||
|
|
9bc166ffd4 | ||
|
|
fc01b90007 | ||
|
|
e35f1d70e4 | ||
|
|
cab1f3787a | ||
|
|
bb42f4cbc1 | ||
|
|
98dc418a51 | ||
|
|
322b4eb18c | ||
|
|
7f1cc30ed8 | ||
|
|
7b45a6b956 | ||
|
|
e36769e70f | ||
|
|
bd4a4cc4af | ||
|
|
8343fe63cb | ||
|
|
7d89fb8461 | ||
|
|
098955d230 | ||
|
|
d254d14928 | ||
|
|
0a3e8ca535 | ||
|
|
b8a10e0962 | ||
|
|
0aceda96e4 | ||
|
|
44b6ec25a2 | ||
|
|
1b84d1fa9d | ||
|
|
78d5ed2ed2 | ||
|
|
142477ab9b | ||
|
|
b414f79bc5 | ||
|
|
6e08fe21d0 | ||
|
|
9b839655a7 | ||
|
|
3353c0ee1d | ||
|
|
aaecf52c99 | ||
|
|
8b3e960be0 | ||
|
|
3351f71813 | ||
|
|
7490256303 | ||
|
|
041d600e45 | ||
|
|
b4e2588a24 | ||
|
|
68dc14c5a1 | ||
|
|
ef35864e16 | ||
|
|
c0d385b983 | ||
|
|
b2df431fa4 | ||
|
|
69a4bd415a | ||
|
|
4862548e65 | ||
|
|
50248cc9ea | ||
|
|
430822bae3 | ||
|
|
dd9d18208d | ||
|
|
e5b1a71659 | ||
|
|
35f4b13237 | ||
|
|
5f5c31cd5b | ||
|
|
e9530d5ec5 | ||
|
|
143f4aa886 | ||
|
|
ece5c8bb31 | ||
|
|
31baf181a3 | ||
|
|
3bae30c70c | ||
|
|
12b18c6bd1 | ||
|
|
787d9e3bf5 | ||
|
|
f325b54895 | ||
|
|
c5616705b0 | ||
|
|
c0f693d35d | ||
|
|
52a5f132c1 | ||
|
|
f14eac6d10 | ||
|
|
e90fe117ec | ||
|
|
381d737d24 | ||
|
|
7cab5b3b09 | ||
|
|
9f911cb5cb | ||
|
|
3da7cba06c | ||
|
|
b47af9600f | ||
|
|
92c3c707e1 | ||
|
|
5acc54e609 | ||
|
|
9c6352dd5b | ||
|
|
8e29a07df5 | ||
|
|
bd88cd3a06 | ||
|
|
f371b9702f | ||
|
|
3ff4ae29af | ||
|
|
eae0f2e7a9 | ||
|
|
305a98bb79 | ||
|
|
8040a3ed60 | ||
|
|
bb9de7d9b0 | ||
|
|
d8e8bc0068 | ||
|
|
6577e9d852 | ||
|
|
3f8625c65a | ||
|
|
92d69636a7 | ||
|
|
9c28817fba | ||
|
|
773788fb32 | ||
|
|
a393ad8e04 | ||
|
|
71d3714347 | ||
|
|
b7e1329c13 | ||
|
|
59e6d9d10e | ||
|
|
46efb446fb | ||
|
|
d31e3a54fd | ||
|
|
c4e471ac47 | ||
|
|
3b8733e085 | ||
|
|
a7c67d83ca | ||
|
|
8abc1de26d | ||
|
|
2ca9f708a6 | ||
|
|
f8f369fbb2 | ||
|
|
3e9155767b | ||
|
|
8cd4195657 | ||
|
|
ad1a944276 | ||
|
|
02ff4c5657 | ||
|
|
b1b27f2dde | ||
|
|
5097f77469 | ||
|
|
7e826d5002 | ||
|
|
fe8143a56c | ||
|
|
e5442a713a | ||
|
|
1982a46f36 | ||
|
|
c8c3640baf | ||
|
|
fdf47b3f2c | ||
|
|
93fa4b6a37 | ||
|
|
90e9ab70b0 | ||
|
|
573c2386b7 | ||
|
|
d2176aeeb9 | ||
|
|
920aec5c3e | ||
|
|
b792c5459a | ||
|
|
87fbf05fa1 | ||
|
|
67c53250c5 | ||
|
|
d657eea910 | ||
|
|
b5fbb825ed | ||
|
|
d094e7a4c6 | ||
|
|
945c155b17 | ||
|
|
f798072a1e | ||
|
|
f967214b57 | ||
|
|
d0b92e2540 | ||
|
|
8ddfe272bf | ||
|
|
b7a6bad7cd | ||
|
|
e2f6c04406 | ||
|
|
c662725955 | ||
|
|
4b66ddfdef | ||
|
|
2d55b1f592 | ||
|
|
14adfabf7e | ||
|
|
e7a76ede76 | ||
|
|
de47df3bf9 | ||
|
|
5475e6f7c5 | ||
|
|
8e3f3d74d4 | ||
|
|
046f6c66ed | ||
|
|
79f9d6552e | ||
|
|
56b4b63749 | ||
|
|
b3246a48c7 | ||
|
|
71722ef6a3 | ||
|
|
ebf8f00302 | ||
|
|
7445928c7e | ||
|
|
5ab7602f2f | ||
|
|
a340aff63a | ||
|
|
f82042ff00 | ||
|
|
920422e28c | ||
|
|
50d6b7a6f8 | ||
|
|
41d624a36a | ||
|
|
f42c37c82e | ||
|
|
119fcdf6f6 | ||
|
|
a5b093d1a9 | ||
|
|
e07cb44a3e | ||
|
|
fec1bcfd5c | ||
|
|
dbcf658343 | ||
|
|
d89e78c9ca | ||
|
|
ec50650dfa | ||
|
|
7432e551f9 | ||
|
|
4ee6bd44d1 | ||
|
|
26f819098d | ||
|
|
a1c79f93d7 | ||
|
|
9c1b202d74 | ||
|
|
8ad0f59f19 | ||
|
|
50fbe3d5af | ||
|
|
af40a77d24 | ||
|
|
8af9a5e921 | ||
|
|
9807788ecb | ||
|
|
5e2f329f15 | ||
|
|
9572a7adaa | ||
|
|
1ba94f4f5f | ||
|
|
237afa0a3a | ||
|
|
d80b7017cf | ||
|
|
56793c8db7 | ||
|
|
8edb217943 | ||
|
|
23ebcf1065 | ||
|
|
68a5a3d62a | ||
|
|
8d7236b0db | ||
|
|
96c7daf818 | ||
|
|
9d8073d468 | ||
|
|
fc4942e189 | ||
|
|
ca69d025bd | ||
|
|
ffa428e32a | ||
|
|
c24e90eaae | ||
|
|
ab32eff588 | ||
|
|
7f592f2b35 | ||
|
|
3bf7f67adf | ||
|
|
594ce05292 | ||
|
|
fe02ca68d5 | ||
|
|
21ef27ee9b | ||
|
|
09d37f669f | ||
|
|
416b776062 | ||
|
|
5ed05d4020 | ||
|
|
4004bfb5ef | ||
|
|
45aace8966 | ||
|
|
d9fc623dcb | ||
|
|
dbb822f6b0 | ||
|
|
3d64dffc32 | ||
|
|
130ece7bc0 | ||
|
|
b2809b2e9a | ||
|
|
29e89d2965 | ||
|
|
e7d54a639e | ||
|
|
22df98e9bb | ||
|
|
0d45c44c6f | ||
|
|
63c6912841 | ||
|
|
73bce73034 | ||
|
|
b2582796a2 | ||
|
|
8babb6e68f | ||
|
|
d1d28df8a1 | ||
|
|
cd556d5d43 | ||
|
|
2855283a2c | ||
|
|
06c29500f2 | ||
|
|
81104153a6 | ||
|
|
23bfd4683c | ||
|
|
a52a3e3158 | ||
|
|
44e524e3c3 | ||
|
|
9a430f73e2 | ||
|
|
fdea40ec11 | ||
|
|
526d340849 | ||
|
|
fe95f6ad81 | ||
|
|
39e73c37ab | ||
|
|
39b36b6857 | ||
|
|
44e98748c5 | ||
|
|
8a7aeee955 | ||
|
|
1c7befb8d3 | ||
|
|
d5d59ac62c | ||
|
|
562f0762a0 | ||
|
|
e46aedce21 | ||
|
|
57cc09b1d7 | ||
|
|
e1e608b744 | ||
|
|
cbfa5a5118 | ||
|
|
ea9ab5b27c | ||
|
|
357ced6cba | ||
|
|
3ffda69651 | ||
|
|
e1bf4e0762 | ||
|
|
ec7f14b82d | ||
|
|
6520be5b85 | ||
|
|
17e4fad6fb | ||
|
|
d84c416421 | ||
|
|
32803c89a3 | ||
|
|
a86bcb5c29 | ||
|
|
7d76a33790 | ||
|
|
8552e81022 | ||
|
|
eacdde829f | ||
|
|
d873539856 | ||
|
|
24bb2e469d | ||
|
|
e1aa2cc0b8 | ||
|
|
d073947f3b | ||
|
|
3243740dd1 | ||
|
|
f9bd566a3b | ||
|
|
183251487c | ||
|
|
ff532210f7 | ||
|
|
d0a04d9801 | ||
|
|
ea6533db4e | ||
|
|
89d5e7bee5 | ||
|
|
7e6cdee592 | ||
|
|
990c2fb416 | ||
|
|
09e054c6aa | ||
|
|
23f648f53a | ||
|
|
07fa656e7c | ||
|
|
7858c48f11 | ||
|
|
e56d54c3f0 | ||
|
|
f37ca95c10 | ||
|
|
72e51bb072 | ||
|
|
dcfcbf54be | ||
|
|
204936b2d0 | ||
|
|
98856b39ac | ||
|
|
ad5f707486 | ||
|
|
5ecfb0ce6d | ||
|
|
2147b3f06f | ||
|
|
7daed3daaf | ||
|
|
481df4d604 | ||
|
|
cf333873fd | ||
|
|
ae700e8f3a | ||
|
|
16386a9524 | ||
|
|
7e7ce276b2 | ||
|
|
71c6b41b83 | ||
|
|
4b2faae29a | ||
|
|
7e28e562d0 | ||
|
|
93c2e2a597 | ||
|
|
c45d13d834 | ||
|
|
330276cdf7 | ||
|
|
22c7015c69 | ||
|
|
cc67d4a1e2 | ||
|
|
eeb9da696f | ||
|
|
4979e1ac9a | ||
|
|
545353dabf | ||
|
|
545376740c | ||
|
|
8289b02ab0 | ||
|
|
fc0060662b | ||
|
|
df9d432d29 | ||
|
|
76fd6e15cc | ||
|
|
06982efda5 | ||
|
|
3cd9a72495 | ||
|
|
0ce27f274a | ||
|
|
e60f78ac4a | ||
|
|
637d3a24a1 | ||
|
|
24c8b24b1f | ||
|
|
5ad34e2216 | ||
|
|
64c42f0ddf | ||
|
|
0a31ddaae6 | ||
|
|
38476cfeb8 | ||
|
|
decc31f1f0 | ||
|
|
ea0aa64330 | ||
|
|
e9a6044645 | ||
|
|
474d700df2 | ||
|
|
c50ff6faa3 | ||
|
|
c8efef8f04 | ||
|
|
1d22f77568 | ||
|
|
5aa51f5f36 | ||
|
|
335c21c48a | ||
|
|
c35d1cecfe | ||
|
|
0d3e6157cd | ||
|
|
68e4cf4d14 | ||
|
|
9454150f7d | ||
|
|
0a0e16547e | ||
|
|
0aec1b9969 | ||
|
|
3e1ec23409 | ||
|
|
2f9f428a2f | ||
|
|
da15cde49c | ||
|
|
e6ed37139a | ||
|
|
377e33c148 | ||
|
|
e567d88951 | ||
|
|
89b2937b11 | ||
|
|
142ed75468 | ||
|
|
d80eeb044c | ||
|
|
7c69e99914 | ||
|
|
5e1aaf5a44 | ||
|
|
ad610d2f90 | ||
|
|
02934452d6 | ||
|
|
8b054010e1 | ||
|
|
5b77f3839b | ||
|
|
231b792452 | ||
|
|
b468e0c164 | ||
|
|
fa1f9d7009 | ||
|
|
c5a8f3abcd | ||
|
|
dfe6a8d3e3 | ||
|
|
292257770c | ||
|
|
b4c6b2b08b | ||
|
|
6cb4577e1b | ||
|
|
456784db48 | ||
|
|
dd9ea46e58 | ||
|
|
ed3af2fac0 | ||
|
|
02f8132f3a | ||
|
|
55bd90fad9 | ||
|
|
cd7bbb45c3 | ||
|
|
6c7fc0ed22 | ||
|
|
5421bc1386 | ||
|
|
051841e566 | ||
|
|
0c68815cf2 | ||
|
|
0c1138179b | ||
|
|
1f3d1cc73e | ||
|
|
707d1332de | ||
|
|
f6c88da81b | ||
|
|
a651e6e518 | ||
|
|
bea89b93eb | ||
|
|
244c9b96a2 | ||
|
|
a37bd76950 | ||
|
|
9d70032de8 | ||
|
|
e4945b41e9 | ||
|
|
493dc8689c | ||
|
|
bdac2ffa27 | ||
|
|
b1235f3ce0 | ||
|
|
ba4bb63a1f | ||
|
|
3227b0e69c | ||
|
|
29c899627e | ||
|
|
5923781484 | ||
|
|
8bb263a2ec | ||
|
|
94c7bba168 | ||
|
|
f9ad4c068a | ||
|
|
19d68252cd | ||
|
|
72bbe3b1ce | ||
|
|
856824316b | ||
|
|
95e189d1d8 | ||
|
|
c629460acb | ||
|
|
f235a94986 | ||
|
|
632cba86e9 | ||
|
|
6b92c7eccc | ||
|
|
ab0da1abac | ||
|
|
7f31ac7bcb | ||
|
|
57a6fb31b2 | ||
|
|
fd2b6c111c | ||
|
|
302458b505 | ||
|
|
8978a4cf2d | ||
|
|
c70be12bfd | ||
|
|
4241307990 | ||
|
|
727a8ef13d | ||
|
|
7c92558ad1 | ||
|
|
45083d29a6 | ||
|
|
5089d86095 | ||
|
|
80e55ef385 | ||
|
|
b5ed98445f | ||
|
|
82d377abf5 | ||
|
|
2dbea5d1b2 | ||
|
|
4ba35d6189 | ||
|
|
cec3f987f2 | ||
|
|
55050a9f58 | ||
|
|
502dc9ec52 | ||
|
|
9c8999a3ae | ||
|
|
90db42ce3a | ||
|
|
551130f0e1 | ||
|
|
2940a60b3c | ||
|
|
76b9bc0d56 | ||
|
|
42422ccdcd | ||
|
|
e9702ae2de | ||
|
|
5c54852ebe | ||
|
|
718a86ecda | ||
|
|
1223fd2149 | ||
|
|
b09386d102 | ||
|
|
6464698b6d | ||
|
|
9230fd3bd6 | ||
|
|
7771609ea0 | ||
|
|
561a125c92 | ||
|
|
7149461d8e | ||
|
|
02c8bd06f5 | ||
|
|
9f17eb1d28 |
@@ -1,9 +1,42 @@
|
||||
API_KEY=<LLM api key (for example, open ai key)>
|
||||
LLM_NAME=docsgpt
|
||||
VITE_API_STREAMING=true
|
||||
INTERNAL_KEY=<internal key for worker-to-backend authentication>
|
||||
|
||||
# Provider-specific API keys (optional - use these to enable multiple providers)
|
||||
# OPENAI_API_KEY=<your-openai-api-key>
|
||||
# ANTHROPIC_API_KEY=<your-anthropic-api-key>
|
||||
# GOOGLE_API_KEY=<your-google-api-key>
|
||||
# GROQ_API_KEY=<your-groq-api-key>
|
||||
# NOVITA_API_KEY=<your-novita-api-key>
|
||||
# OPEN_ROUTER_API_KEY=<your-openrouter-api-key>
|
||||
|
||||
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
|
||||
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
|
||||
EMBEDDINGS_BASE_URL=
|
||||
EMBEDDINGS_KEY=
|
||||
|
||||
#For Azure (you can delete it if you don't use Azure)
|
||||
OPENAI_API_BASE=
|
||||
OPENAI_API_VERSION=
|
||||
AZURE_DEPLOYMENT_NAME=
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
|
||||
#Azure AD Application (client) ID
|
||||
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
|
||||
#Azure AD Application client secret
|
||||
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
|
||||
#Azure AD Tenant ID (or 'common' for multi-tenant)
|
||||
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
|
||||
#If you are using a Microsoft Entra ID tenant,
|
||||
#configure the AUTHORITY variable as
|
||||
#"https://login.microsoftonline.com/TENANT_GUID"
|
||||
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
|
||||
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
|
||||
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
|
||||
|
||||
# User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration).
|
||||
# Standard Postgres URI — `postgres://` and `postgresql://` both work.
|
||||
# Leave unset while the migration is still being rolled out; the app will
|
||||
# fall back to MongoDB for user data until POSTGRES_URI is configured.
|
||||
# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
|
||||
|
||||
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
# Auto detect text files and perform LF normalization
|
||||
* text=auto
|
||||
99
.github/INCIDENT_RESPONSE.md
vendored
Normal file
99
.github/INCIDENT_RESPONSE.md
vendored
Normal file
@@ -0,0 +1,99 @@
|
||||
# DocsGPT Incident Response Plan (IRP)
|
||||
|
||||
This playbook describes how maintainers respond to confirmed or suspected security incidents.
|
||||
|
||||
- Vulnerability reporting: [`SECURITY.md`](../SECURITY.md)
|
||||
- Non-security bugs/features: [`CONTRIBUTING.md`](../CONTRIBUTING.md)
|
||||
|
||||
## Severity
|
||||
|
||||
| Severity | Definition | Typical examples |
|
||||
|---|---|---|
|
||||
| **Critical** | Active exploitation, supply-chain compromise, or confirmed data breach requiring immediate user action. | Compromised release artifact/image; remote execution. |
|
||||
| **High** | Serious undisclosed vulnerability with no practical workaround, or CVSS >= 7.0. | key leakage; prompt injection enabling cross-tenant access. |
|
||||
| **Medium** | Material impact but constrained by preconditions/scope, or a practical workaround exists. | Auth-required exploit; dependency CVE with limited reachability. |
|
||||
| **Low** | Defense-in-depth or narrow availability impact with no confirmed data exposure. | Missing rate limiting; hardening gap without exploit evidence. |
|
||||
|
||||
|
||||
## Response workflow
|
||||
|
||||
### 1) Triage (target: initial response within 48 hours)
|
||||
|
||||
1. Acknowledge report.
|
||||
2. Validate on latest release and `main`.
|
||||
3. Confirm in-scope security issue vs. hardening item (per `SECURITY.md`).
|
||||
4. Assign severity and open a **draft GitHub Security Advisory (GHSA)** (no public issue).
|
||||
5. Determine whether root cause is DocsGPT code or upstream dependency/provider.
|
||||
|
||||
### 2) Investigation
|
||||
|
||||
1. Identify affected components, versions, and deployment scope (self-hosted, cloud, or both).
|
||||
2. For AI issues, explicitly evaluate prompt injection, document isolation, and output leakage.
|
||||
3. Request a CVE through GHSA for **Medium+** issues.
|
||||
|
||||
### 3) Containment, fix, and disclosure
|
||||
|
||||
1. Implement and test fix in private security workflow (GHSA private fork/branch).
|
||||
2. Merge fix to `main`, cut patched release, and verify published artifacts/images.
|
||||
3. Patch managed cloud deployment (`app.docsgpt.cloud`) and other deployments as soon as validated.
|
||||
4. Publish GHSA with CVE (if assigned), affected/fixed versions, CVSS, mitigations, and upgrade guidance.
|
||||
5. **Critical/High:** coordinate disclosure timing with reporter (goal: <= 90 days) and publish a notice.
|
||||
6. **Medium/Low:** include in next scheduled release unless risk requires immediate out-of-band patching.
|
||||
|
||||
### 4) Post-incident
|
||||
|
||||
1. Monitor support channels (GitHub/Discord) for regressions or exploitation reports.
|
||||
2. Run a short retrospective (root cause, detection, response gaps, prevention work).
|
||||
3. Track follow-up hardening actions with owners/dates.
|
||||
4. Update this IRP and related runbooks as needed.
|
||||
|
||||
## Scenario playbooks
|
||||
|
||||
### Supply-chain compromise
|
||||
|
||||
1. Freeze releases and investigate blast radius.
|
||||
2. Rotate credentials in order: Docker Hub -> GitHub tokens -> LLM provider keys -> DB credentials -> `JWT_SECRET_KEY` -> `ENCRYPTION_SECRET_KEY` -> `INTERNAL_KEY`.
|
||||
3. Replace compromised artifacts/tags with clean releases and revoke/remove bad tags where possible.
|
||||
4. Publish advisory with exact affected versions and required user actions.
|
||||
|
||||
### Data exposure
|
||||
|
||||
1. Determine scope (users, documents, keys, logs, time window).
|
||||
2. Disable affected path or hotfix immediately for managed cloud.
|
||||
3. Notify affected users with concrete remediation steps (for example, rotate keys).
|
||||
4. Continue through standard fix/disclosure workflow.
|
||||
|
||||
### Critical regression with security impact
|
||||
|
||||
1. Identify introducing change (`git bisect` if needed).
|
||||
2. Publish workaround within 24 hours (for example, pin to known-good version).
|
||||
3. Ship patch release with regression test and close incident with public summary.
|
||||
|
||||
## AI-specific guidance
|
||||
|
||||
Treat confirmed AI-specific abuse as security incidents:
|
||||
|
||||
- Prompt injection causing sensitive data exfiltration (from tools that don't belong to the agent) -> **High**
|
||||
- Cross-tenant retrieval/isolation failure -> **High**
|
||||
- API key disclosure in output -> **High**
|
||||
|
||||
## Secret rotation quick reference
|
||||
|
||||
| Secret | Standard rotation action |
|
||||
|---|---|
|
||||
| Docker Hub credentials | Revoke/replace in Docker Hub; update CI/CD secrets |
|
||||
| GitHub tokens/PATs | Revoke/replace in GitHub; update automation secrets |
|
||||
| LLM provider API keys | Rotate in provider console; update runtime/deploy secrets |
|
||||
| Database credentials | Rotate in DB platform; redeploy with new secrets |
|
||||
| `JWT_SECRET_KEY` | Rotate and redeploy (invalidates all active user sessions/tokens) |
|
||||
| `ENCRYPTION_SECRET_KEY` | Rotate and redeploy (re-encrypt stored data if possible; existing encrypted data may become inaccessible) |
|
||||
| `INTERNAL_KEY` | Rotate and redeploy (invalidates worker-to-backend authentication) |
|
||||
|
||||
## Maintenance
|
||||
|
||||
Review this document:
|
||||
|
||||
- after every **Critical/High** incident, and
|
||||
- at least annually.
|
||||
|
||||
Changes should be proposed via pull request to `main`.
|
||||
6
.github/dependabot.yml
vendored
6
.github/dependabot.yml
vendored
@@ -13,7 +13,11 @@ updates:
|
||||
directory: "/frontend" # Location of package manifests
|
||||
schedule:
|
||||
interval: "daily"
|
||||
- package-ecosystem: "npm"
|
||||
directory: "/extensions/react-widget"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
interval: "daily"
|
||||
11
.github/styles/DocsGPT/Spelling.yml
vendored
Normal file
11
.github/styles/DocsGPT/Spelling.yml
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
extends: spelling
|
||||
level: warning
|
||||
message: "Did you really mean '%s'?"
|
||||
ignore:
|
||||
- "**/node_modules/**"
|
||||
- "**/dist/**"
|
||||
- "**/build/**"
|
||||
- "**/coverage/**"
|
||||
- "**/public/**"
|
||||
- "**/static/**"
|
||||
vocab: DocsGPT
|
||||
80
.github/styles/config/vocabularies/DocsGPT/accept.txt
vendored
Normal file
80
.github/styles/config/vocabularies/DocsGPT/accept.txt
vendored
Normal file
@@ -0,0 +1,80 @@
|
||||
Agentic
|
||||
Anthropic's
|
||||
api
|
||||
APIs
|
||||
Atlassian
|
||||
automations
|
||||
autoescaping
|
||||
Autoescaping
|
||||
backfill
|
||||
backfills
|
||||
bool
|
||||
boolean
|
||||
brave_web_search
|
||||
chatbot
|
||||
Chatwoot
|
||||
config
|
||||
configs
|
||||
CSVs
|
||||
dev
|
||||
diarization
|
||||
Docling
|
||||
docsgpt
|
||||
docstrings
|
||||
Entra
|
||||
env
|
||||
enqueues
|
||||
EOL
|
||||
ESLint
|
||||
feedbacks
|
||||
Figma
|
||||
GPUs
|
||||
Groq
|
||||
hardcode
|
||||
hardcoding
|
||||
Idempotency
|
||||
JSONPath
|
||||
kubectl
|
||||
Lightsail
|
||||
llama_cpp
|
||||
llm
|
||||
LLM
|
||||
LLMs
|
||||
LMDeploy
|
||||
Milvus
|
||||
Mixtral
|
||||
namespace
|
||||
namespaces
|
||||
needs_auth
|
||||
Nextra
|
||||
Novita
|
||||
npm
|
||||
OAuth
|
||||
Ollama
|
||||
opencode
|
||||
parsable
|
||||
passthrough
|
||||
PDFs
|
||||
pgvector
|
||||
Postgres
|
||||
Premade
|
||||
Pydantic
|
||||
pytest
|
||||
Qdrant
|
||||
qdrant
|
||||
Repo
|
||||
repo
|
||||
Sanitization
|
||||
SDKs
|
||||
SGLang
|
||||
Shareability
|
||||
Signup
|
||||
Supabase
|
||||
UIs
|
||||
uncomment
|
||||
URl
|
||||
vectorstore
|
||||
Vite
|
||||
VSCode
|
||||
VSCode's
|
||||
widget's
|
||||
3
.github/workflows/lint.yml
vendored
3
.github/workflows/lint.yml
vendored
@@ -7,6 +7,9 @@ on:
|
||||
pull_request:
|
||||
types: [ opened, synchronize ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
114
.github/workflows/npm-publish.yml
vendored
Normal file
114
.github/workflows/npm-publish.yml
vendored
Normal file
@@ -0,0 +1,114 @@
|
||||
name: Publish npm libraries
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: >
|
||||
Version bump type (patch | minor | major) or explicit semver (e.g. 1.2.3).
|
||||
Applies to both docsgpt and docsgpt-react.
|
||||
required: true
|
||||
default: patch
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
runs-on: ubuntu-latest
|
||||
environment: npm-release
|
||||
defaults:
|
||||
run:
|
||||
working-directory: extensions/react-widget
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 20
|
||||
registry-url: https://registry.npmjs.org
|
||||
|
||||
- name: Install dependencies
|
||||
run: npm ci
|
||||
|
||||
# ── docsgpt (HTML embedding bundle) ──────────────────────────────────
|
||||
# Uses the `build` script (parcel build src/browser.tsx) and keeps
|
||||
# the `targets` field so Parcel produces browser-optimised bundles.
|
||||
|
||||
- name: Set package name → docsgpt
|
||||
run: jq --arg n "docsgpt" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json
|
||||
|
||||
- name: Bump version (docsgpt)
|
||||
id: version_docsgpt
|
||||
run: |
|
||||
VERSION="${{ github.event.inputs.version }}"
|
||||
NEW_VER=$(npm version "${VERSION:-patch}" --no-git-tag-version)
|
||||
echo "version=${NEW_VER#v}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Build docsgpt
|
||||
run: npm run build
|
||||
|
||||
- name: Publish docsgpt
|
||||
run: npm publish --verbose
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
|
||||
|
||||
# ── docsgpt-react (React library bundle) ─────────────────────────────
|
||||
# Uses `build:react` script (parcel build src/index.ts) and strips
|
||||
# the `targets` field so Parcel treats the output as a plain library
|
||||
# without browser-specific target resolution, producing a smaller bundle.
|
||||
|
||||
- name: Reset package.json from source control
|
||||
run: git checkout -- package.json
|
||||
|
||||
- name: Set package name → docsgpt-react
|
||||
run: jq --arg n "docsgpt-react" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json
|
||||
|
||||
- name: Remove targets field (react library build)
|
||||
run: jq 'del(.targets)' package.json > _tmp.json && mv _tmp.json package.json
|
||||
|
||||
- name: Bump version (docsgpt-react) to match docsgpt
|
||||
run: npm version "${{ steps.version_docsgpt.outputs.version }}" --no-git-tag-version
|
||||
|
||||
- name: Clean dist before react build
|
||||
run: rm -rf dist
|
||||
|
||||
- name: Build docsgpt-react
|
||||
run: npm run build:react
|
||||
|
||||
- name: Publish docsgpt-react
|
||||
run: npm publish --verbose
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
|
||||
|
||||
# ── Commit the bumped version back to the repository ─────────────────
|
||||
|
||||
- name: Reset package.json and write final version
|
||||
run: |
|
||||
git checkout -- package.json
|
||||
jq --arg v "${{ steps.version_docsgpt.outputs.version }}" '.version=$v' \
|
||||
package.json > _tmp.json && mv _tmp.json package.json
|
||||
npm install --package-lock-only
|
||||
|
||||
- name: Commit version bump and create PR
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
BRANCH="chore/bump-npm-v${{ steps.version_docsgpt.outputs.version }}"
|
||||
git checkout -b "$BRANCH"
|
||||
git add package.json package-lock.json
|
||||
git commit -m "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}"
|
||||
git push origin "$BRANCH"
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Create PR
|
||||
run: |
|
||||
gh pr create \
|
||||
--title "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}" \
|
||||
--body "Automated version bump after npm publish." \
|
||||
--base main
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
10
.github/workflows/pytest.yml
vendored
10
.github/workflows/pytest.yml
vendored
@@ -1,5 +1,9 @@
|
||||
name: Run python tests with pytest
|
||||
on: [push, pull_request]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pytest_and_coverage:
|
||||
name: Run tests and count coverage
|
||||
@@ -16,15 +20,15 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest pytest-cov
|
||||
cd application
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
cd ../tests
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
- name: Test with pytest and generate coverage report
|
||||
run: |
|
||||
python -m pytest --cov=application --cov-report=xml
|
||||
python -m pytest --cov=application --cov-report=xml --cov-report=term-missing
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: github.event_name == 'pull_request' && matrix.python-version == '3.12'
|
||||
uses: codecov/codecov-action@v5
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
|
||||
34
.github/workflows/react-widget-build.yml
vendored
Normal file
34
.github/workflows/react-widget-build.yml
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
name: React Widget Build
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'extensions/react-widget/**'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'extensions/react-widget/**'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: extensions/react-widget
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 20
|
||||
cache: npm
|
||||
cache-dependency-path: extensions/react-widget/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: npm ci
|
||||
|
||||
- name: Build
|
||||
run: npm run build
|
||||
30
.github/workflows/vale.yml
vendored
Normal file
30
.github/workflows/vale.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: Vale Documentation Linter
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'docs/**/*.md'
|
||||
- 'docs/**/*.mdx'
|
||||
- '**/*.md'
|
||||
- '.vale.ini'
|
||||
- '.github/styles/**'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
vale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Vale linter
|
||||
uses: errata-ai/vale-action@v2
|
||||
with:
|
||||
files: docs
|
||||
fail_on_error: false
|
||||
version: 3.0.5
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
25
.github/workflows/zizmor.yml
vendored
Normal file
25
.github/workflows/zizmor.yml
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
name: GitHub Actions Security Analysis
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["master"]
|
||||
pull_request:
|
||||
branches: ["**"]
|
||||
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
zizmor:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
security-events: write # Required for upload-sarif (used by zizmor-action) to upload SARIF files.
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Run zizmor 🌈
|
||||
uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # v0.5.2
|
||||
12
.gitignore
vendored
12
.gitignore
vendored
@@ -2,7 +2,10 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
results.txt
|
||||
experiments/
|
||||
|
||||
experiments
|
||||
# C extensions
|
||||
*.so
|
||||
*.next
|
||||
@@ -69,6 +72,7 @@ instance/
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
docs/public/_pagefind/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
@@ -104,6 +108,8 @@ celerybeat.pid
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
# Machine-specific Claude Code guidance (see CLAUDE.md preamble)
|
||||
CLAUDE.md
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
@@ -113,6 +119,7 @@ venv.bak/
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
.jwt_secret_key
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
@@ -144,6 +151,10 @@ frontend/yarn-error.log*
|
||||
frontend/pnpm-debug.log*
|
||||
frontend/lerna-debug.log*
|
||||
|
||||
# Keep frontend utility helpers tracked (overrides global lib/ ignore)
|
||||
!frontend/src/lib/
|
||||
!frontend/src/lib/**
|
||||
|
||||
frontend/node_modules
|
||||
frontend/dist
|
||||
frontend/dist-ssr
|
||||
@@ -172,5 +183,6 @@ application/vectors/
|
||||
|
||||
node_modules/
|
||||
.vscode/settings.json
|
||||
.vscode/sftp.json
|
||||
/models/
|
||||
model/
|
||||
|
||||
@@ -1,2 +1,6 @@
|
||||
# Allow lines to be as long as 120 characters.
|
||||
line-length = 120
|
||||
line-length = 120
|
||||
|
||||
[lint.per-file-ignores]
|
||||
# Integration tests use sys.path.insert() before imports for standalone execution
|
||||
"tests/integration/*.py" = ["E402"]
|
||||
7
.vale.ini
Normal file
7
.vale.ini
Normal file
@@ -0,0 +1,7 @@
|
||||
MinAlertLevel = warning
|
||||
StylesPath = .github/styles
|
||||
Vocab = DocsGPT
|
||||
|
||||
[*.{md,mdx}]
|
||||
BasedOnStyles = DocsGPT
|
||||
|
||||
33
.vscode/launch.json
vendored
33
.vscode/launch.json
vendored
@@ -2,15 +2,11 @@
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Docker Debug Frontend",
|
||||
"name": "Frontend Debug (npm)",
|
||||
"type": "node-terminal",
|
||||
"request": "launch",
|
||||
"type": "chrome",
|
||||
"preLaunchTask": "docker-compose: debug:frontend",
|
||||
"url": "http://127.0.0.1:5173",
|
||||
"webRoot": "${workspaceFolder}/frontend",
|
||||
"skipFiles": [
|
||||
"<node_internals>/**"
|
||||
]
|
||||
"command": "npm run dev",
|
||||
"cwd": "${workspaceFolder}/frontend"
|
||||
},
|
||||
{
|
||||
"name": "Flask Debugger",
|
||||
@@ -49,6 +45,27 @@
|
||||
"--pool=solo"
|
||||
],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"name": "Dev Containers (Mongo + Redis)",
|
||||
"type": "node-terminal",
|
||||
"request": "launch",
|
||||
"command": "docker compose -f deployment/docker-compose-dev.yaml up --build",
|
||||
"cwd": "${workspaceFolder}"
|
||||
}
|
||||
],
|
||||
"compounds": [
|
||||
{
|
||||
"name": "DocsGPT: Full Stack",
|
||||
"configurations": [
|
||||
"Frontend Debug (npm)",
|
||||
"Flask Debugger",
|
||||
"Celery Debugger"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "DocsGPT",
|
||||
"order": 1
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
21
.vscode/tasks.json
vendored
21
.vscode/tasks.json
vendored
@@ -1,21 +0,0 @@
|
||||
{
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"type": "docker-compose",
|
||||
"label": "docker-compose: debug:frontend",
|
||||
"dockerCompose": {
|
||||
"up": {
|
||||
"detached": true,
|
||||
"services": [
|
||||
"frontend"
|
||||
],
|
||||
"build": true
|
||||
},
|
||||
"files": [
|
||||
"${workspaceFolder}/docker-compose.yaml"
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
134
AGENTS.md
Normal file
134
AGENTS.md
Normal file
@@ -0,0 +1,134 @@
|
||||
# AGENTS.md
|
||||
|
||||
- Read `CONTRIBUTING.md` before making non-trivial changes.
|
||||
- For day-to-day development and feature work, follow the development-environment workflow rather than defaulting to `setup.sh` / `setup.ps1`.
|
||||
- Avoid using the setup scripts during normal feature work unless the user explicitly asks for them. Users configure `.env` usually.
|
||||
- Try to follow red/green TDD
|
||||
|
||||
### Check existing dev prerequisites first
|
||||
|
||||
For feature work, do **not** assume the environment needs to be recreated.
|
||||
|
||||
- Check whether the user already has a Python virtual environment such as `venv/` or `.venv/`.
|
||||
- Check whether MongoDB is already running.
|
||||
- Check whether Redis is already running.
|
||||
- Reuse what is already working. Do not stop or recreate MongoDB, Redis, or the Python environment unless the task is environment setup or troubleshooting.
|
||||
|
||||
## Normal local development commands
|
||||
|
||||
Use these commands once the dev prerequisites above are satisfied.
|
||||
|
||||
### Backend
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate # macOS/Linux
|
||||
uv pip install -r application/requirements.txt # or: pip install -r application/requirements.txt
|
||||
```
|
||||
|
||||
Run the Flask API (if needed):
|
||||
|
||||
```bash
|
||||
flask --app application/app.py run --host=0.0.0.0 --port=7091
|
||||
```
|
||||
|
||||
Run the Celery worker in a separate terminal (if needed):
|
||||
|
||||
```bash
|
||||
celery -A application.app.celery worker -l INFO
|
||||
```
|
||||
|
||||
On macOS, prefer the solo pool for Celery:
|
||||
|
||||
```bash
|
||||
python -m celery -A application.app.celery worker -l INFO --pool=solo
|
||||
```
|
||||
|
||||
### Frontend
|
||||
|
||||
Install dependencies only when needed, then run the dev server:
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm install --include=dev
|
||||
npm run dev
|
||||
```
|
||||
|
||||
### Docs site
|
||||
|
||||
```bash
|
||||
cd docs
|
||||
npm install
|
||||
```
|
||||
|
||||
### Python / backend changes validation
|
||||
|
||||
```bash
|
||||
ruff check .
|
||||
python -m pytest
|
||||
```
|
||||
|
||||
### Frontend changes
|
||||
|
||||
```bash
|
||||
cd frontend && npm run lint
|
||||
cd frontend && npm run build
|
||||
```
|
||||
|
||||
### Documentation changes
|
||||
|
||||
```bash
|
||||
cd docs && npm run build
|
||||
```
|
||||
|
||||
If Vale is installed locally and you edited prose, also run:
|
||||
|
||||
```bash
|
||||
vale .
|
||||
```
|
||||
|
||||
## Repository map
|
||||
|
||||
- `application/`: Flask backend, API routes, agent logic, retrieval, parsing, security, storage, Celery worker, and WSGI entrypoints.
|
||||
- `tests/`: backend unit/integration tests and test-only Python dependencies.
|
||||
- `frontend/`: Vite + React + TypeScript application.
|
||||
- `frontend/src/`: main UI code, including `components`, `conversation`, `hooks`, `locale`, `settings`, `upload`, and Redux store wiring in `store.ts`.
|
||||
- `docs/`: separate documentation site built with Next.js/Nextra.
|
||||
- `extensions/`: integrations and widgets such as Chatwoot, Chrome, Discord, React widget, Slack bot, and web widget.
|
||||
- `deployment/`: Docker Compose variants and Kubernetes manifests.
|
||||
|
||||
## Coding rules
|
||||
|
||||
### Backend
|
||||
|
||||
- Follow PEP 8 and keep Python line length at or under 120 characters.
|
||||
- Use type hints for function arguments and return values.
|
||||
- Add Google-style docstrings to new or substantially changed functions and classes.
|
||||
- Add or update tests under `tests/` for backend behavior changes.
|
||||
- Keep changes narrow in `api`, `auth`, `security`, `parser`, `retriever`, and `storage` areas.
|
||||
|
||||
### Backend Abstractions
|
||||
|
||||
- LLM providers implement a common interface in `application/llm/` (add new providers by extending the base class).
|
||||
- Vector stores are abstracted in `application/vectorstore/`.
|
||||
- Parsers live in `application/parser/` and handle different document formats in the ingestion stage.
|
||||
- Agents and tools are in `application/agents/` and `application/agents/tools/`.
|
||||
- Celery setup/config lives in `application/celery_init.py` and `application/celeryconfig.py`.
|
||||
- Settings and env vars are managed via Pydantic in `application/core/settings.py`.
|
||||
|
||||
### Frontend
|
||||
|
||||
- Follow the existing ESLint + Prettier setup.
|
||||
- Prefer small, reusable functional components and hooks.
|
||||
- If shared state must be added, use Redux rather than introducing a new global state library.
|
||||
- Avoid broad UI refactors unless the task explicitly asks for them.
|
||||
- Do not re-create components if we already have some in the app.
|
||||
|
||||
## PR readiness
|
||||
|
||||
Before opening a PR:
|
||||
|
||||
- run the relevant validation commands above
|
||||
- confirm backend changes still work end-to-end after ingesting sample data when applicable
|
||||
- clearly summarize user-visible behavior changes
|
||||
- mention any config, dependency, or deployment implications
|
||||
- Ask your user to attach a screenshot or a video to it
|
||||
@@ -22,6 +22,11 @@ Thank you for choosing to contribute to DocsGPT! We are all very grateful!
|
||||
|
||||
- We have a frontend built on React (Vite) and a backend in Python.
|
||||
|
||||
> **Required for every PR:** Please attach screenshots or a short screen
|
||||
> recording that shows the working version of your changes. This makes the
|
||||
> requirement visible to reviewers and helps them quickly verify what you are
|
||||
> submitting.
|
||||
|
||||
|
||||
Before creating issues, please check out how the latest version of our app looks and works by launching it via [Quickstart](https://github.com/arc53/DocsGPT#quickstart) the version on our live demo is slightly modified with login. Your issues should relate to the version you can launch via [Quickstart](https://github.com/arc53/DocsGPT#quickstart).
|
||||
|
||||
@@ -125,7 +130,7 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
|
||||
```
|
||||
|
||||
9. **Submit a Pull Request (PR):**
|
||||
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes and reference any related issues.
|
||||
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes, reference any related issues, and attach screenshots or a screen recording showing the working version.
|
||||
|
||||
10. **Collaborate:**
|
||||
- Be responsive to comments and feedback on your PR.
|
||||
@@ -147,5 +152,5 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
|
||||
Thank you for considering contributing to DocsGPT! 🙏
|
||||
|
||||
## Questions/collaboration
|
||||
Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
|
||||
Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
|
||||
# Thank you so much for considering to contributing DocsGPT!🙏
|
||||
|
||||
39
HACKTOBERFEST.md
Normal file
39
HACKTOBERFEST.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# **🎉 Join the Hacktoberfest with DocsGPT and win a Free T-shirt for a meaningful PR! 🎉**
|
||||
|
||||
Welcome, contributors! We're excited to announce that DocsGPT is participating in Hacktoberfest. Get involved by submitting meaningful pull requests.
|
||||
|
||||
All Meaningful contributors with accepted PRs that were created for issues with the `hacktoberfest` label (set by our maintainer team: dartpain, siiddhantt, pabik, ManishMadan2882) will receive a cool T-shirt! 🤩.
|
||||
<img width="1331" height="678" alt="hacktoberfest-mocks-preview" src="https://github.com/user-attachments/assets/633f6377-38db-48f5-b519-a8b3855a9eb4" />
|
||||
|
||||
Fill in [this form](https://forms.gle/Npaba4n9Epfyx56S8
|
||||
) after your PR was merged please
|
||||
|
||||
If you are in doubt don't hesitate to ping us on discord, ping me - Alex (dartpain).
|
||||
|
||||
## 📜 Here's How to Contribute:
|
||||
```text
|
||||
🛠️ Code: This is the golden ticket! Make meaningful contributions through PRs.
|
||||
|
||||
🧩 API extension: Build an app utilising DocsGPT API. We prefer submissions that showcase original ideas and turn the API into an AI agent.
|
||||
They can be a completely separate repos.
|
||||
For example:
|
||||
https://github.com/arc53/tg-bot-docsgpt-extenstion or
|
||||
https://github.com/arc53/DocsGPT-cli
|
||||
|
||||
Non-Code Contributions:
|
||||
|
||||
📚 Wiki: Improve our documentation, create a guide.
|
||||
|
||||
🖥️ Design: Improve the UI/UX or design a new feature.
|
||||
```
|
||||
|
||||
### 📝 Guidelines for Pull Requests:
|
||||
- Familiarize yourself with the current contributions and our [Roadmap](https://github.com/orgs/arc53/projects/2).
|
||||
- Before contributing check existing [issues](https://github.com/arc53/DocsGPT/issues) or [create](https://github.com/arc53/DocsGPT/issues/new/choose) an issue and wait to get assigned.
|
||||
- Once you are finished with your contribution, please fill in this [form](https://forms.gle/Npaba4n9Epfyx56S8).
|
||||
- Refer to the [Documentation](https://docs.docsgpt.cloud/).
|
||||
- Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj) server. We're here to help newcomers, so don't hesitate to jump in! Join us [here](https://discord.gg/vN7YFfdMpj).
|
||||
|
||||
Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your contributions (not just simple typos) could earn you a stylish new t-shirt.
|
||||
|
||||
We will publish a t-shirt design later into the October.
|
||||
64
README.md
64
README.md
@@ -3,11 +3,11 @@
|
||||
</h1>
|
||||
|
||||
<p align="center">
|
||||
<strong>Open-Source RAG Assistant</strong>
|
||||
<strong>Private AI for agents, assistants and enterprise search</strong>
|
||||
</p>
|
||||
|
||||
<p align="left">
|
||||
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source genAI tool that helps users get reliable answers from any knowledge source, while avoiding hallucinations. It enables quick and reliable information retrieval, with tooling and agentic system capability built in.
|
||||
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content, and audio), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
@@ -16,23 +16,27 @@
|
||||
<a href="https://github.com/arc53/DocsGPT"></a>
|
||||
<a href="https://github.com/arc53/DocsGPT/blob/main/LICENSE"></a>
|
||||
<a href="https://www.bestpractices.dev/projects/9907"><img src="https://www.bestpractices.dev/projects/9907/badge"></a>
|
||||
<a href="https://discord.gg/n5BX8dh8rU"></a>
|
||||
<a href="https://twitter.com/docsgptai"></a>
|
||||
<a href="https://discord.gg/vN7YFfdMpj"></a>
|
||||
<a href="https://x.com/docsgptai"></a>
|
||||
|
||||
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
|
||||
<br>
|
||||
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a> • <a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a> • <a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
|
||||
<br>
|
||||
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/vN7YFfdMpj">💬 Discord</a>
|
||||
<br>
|
||||
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a> • <a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a> • <a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
|
||||
<br>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<div align="center">
|
||||
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
|
||||
<br>
|
||||
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demo-26.gif" alt="video-example-of-docs-gpt" width="800" height="480">
|
||||
</div>
|
||||
<h3 align="left">
|
||||
<strong>Key Features:</strong>
|
||||
</h3>
|
||||
<ul align="left">
|
||||
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, and images.</li>
|
||||
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, images, and audio files such as MP3, WAV, M4A, OGG, and WebM.</li>
|
||||
<li><strong>🎙️ Speech Workflows:</strong> Record voice input into chat, transcribe audio on the backend, and ingest meeting recordings or voice notes as searchable knowledge.</li>
|
||||
<li><strong>🌐 Web & Data Integration:</strong> Ingests from URLs, sitemaps, Reddit, GitHub and web crawlers.</li>
|
||||
<li><strong>✅ Reliable Answers:</strong> Get accurate, hallucination-free responses with source citations viewable in a clean UI.</li>
|
||||
<li><strong>🔑 Streamlined API Keys:</strong> Generate keys linked to your settings, documents, and models, simplifying chatbot and integration setup.</li>
|
||||
@@ -43,15 +47,11 @@
|
||||
</ul>
|
||||
|
||||
## Roadmap
|
||||
|
||||
- [x] Full GoogleAI compatibility (Jan 2025)
|
||||
- [x] Add tools (Jan 2025)
|
||||
- [x] Manually updating chunks in the app UI (Feb 2025)
|
||||
- [x] Devcontainer for easy development (Feb 2025)
|
||||
- [ ] Anthropic Tool compatibility
|
||||
- [ ] Add triggerable actions / tools (webhook)
|
||||
- [ ] Add OAuth 2.0 authentication for tools and sources
|
||||
- [ ] Chatbots menu re-design to handle tools, scheduling, and more
|
||||
- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
|
||||
- [x] Deep Agents ( October 2025 )
|
||||
- [x] Prompt Templating ( October 2025 )
|
||||
- [x] Full api tooling ( Dec 2025 )
|
||||
- [ ] Agent scheduling ( Jan 2026 )
|
||||
|
||||
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
|
||||
|
||||
@@ -65,11 +65,10 @@ We're eager to provide personalized assistance when deploying your DocsGPT to a
|
||||
|
||||
## Join the Lighthouse Program 🌟
|
||||
|
||||
Calling all developers and GenAI innovators! The **DocsGPT Lighthouse Program** connects technical leaders actively deploying or extending DocsGPT in real-world scenarios. Collaborate directly with our team to shape the roadmap, access priority support, and build enterprise-ready solutions with exclusive community insights.
|
||||
Calling all developers and GenAI innovators! The **DocsGPT Lighthouse Program** connects technical leaders actively deploying or extending DocsGPT in real-world scenarios. Collaborate directly with our team to shape the roadmap, access priority support, and build enterprise-ready solutions with exclusive community insights.
|
||||
|
||||
[Learn More & Apply →](https://docs.google.com/forms/d/1KAADiJinUJ8EMQyfTXUIGyFbqINNClNR3jBNWq7DgTE)
|
||||
|
||||
|
||||
## QuickStart
|
||||
|
||||
> [!Note]
|
||||
@@ -92,13 +91,15 @@ A more detailed [Quickstart](https://docs.docsgpt.cloud/quickstart) is available
|
||||
./setup.sh
|
||||
```
|
||||
|
||||
This interactive script will guide you through setting up DocsGPT. It offers four options: using the public API, running locally, connecting to a local inference engine, or using a cloud API provider. The script will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
|
||||
|
||||
**For Windows:**
|
||||
|
||||
2. **Follow the Docker Deployment Guide:**
|
||||
2. **Run the PowerShell setup script:**
|
||||
|
||||
Please refer to the [Docker Deployment documentation](https://docs.docsgpt.cloud/Deploying/Docker-Deploying) for detailed step-by-step instructions on setting up DocsGPT using Docker.
|
||||
```powershell
|
||||
PowerShell -ExecutionPolicy Bypass -File .\setup.ps1
|
||||
```
|
||||
|
||||
Either script will guide you through setting up DocsGPT. Five options available: using the public API, running locally, connecting to a local inference engine, using a cloud API provider, or build the docker image locally. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
|
||||
|
||||
**Navigate to http://localhost:5173/**
|
||||
|
||||
@@ -107,7 +108,8 @@ To stop DocsGPT, open a terminal in the `DocsGPT` directory and run:
|
||||
```bash
|
||||
docker compose -f deployment/docker-compose.yaml down
|
||||
```
|
||||
(or use the specific `docker compose down` command shown after running `setup.sh`).
|
||||
|
||||
(or use the specific `docker compose down` command shown after running the setup script).
|
||||
|
||||
> [!Note]
|
||||
> For development environment setup instructions, please refer to the [Development Environment Guide](https://docs.docsgpt.cloud/Deploying/Development-Environment).
|
||||
@@ -134,7 +136,6 @@ Please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file for information abou
|
||||
|
||||
We as members, contributors, and leaders, pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. Please refer to the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file for more information about contributing.
|
||||
|
||||
|
||||
## Many Thanks To Our Contributors⚡
|
||||
|
||||
<a href="https://github.com/arc53/DocsGPT/graphs/contributors" alt="View Contributors">
|
||||
@@ -145,9 +146,16 @@ We as members, contributors, and leaders, pledge to make participation in our co
|
||||
|
||||
The source code license is [MIT](https://opensource.org/license/mit/), as described in the [LICENSE](LICENSE) file.
|
||||
|
||||
<p>This project is supported by:</p>
|
||||
## This project is supported by:
|
||||
|
||||
<p>
|
||||
<a href="https://www.digitalocean.com/?utm_medium=opensource&utm_source=DocsGPT">
|
||||
<img src="https://opensource.nyc3.cdn.digitaloceanspaces.com/attribution/assets/SVG/DO_Logo_horizontal_blue.svg" width="201px">
|
||||
</a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://get.neon.com/docsgpt">
|
||||
<img width="201" alt="color" src="https://github.com/user-attachments/assets/7d9813b7-0e6d-403f-b5af-68af066b326f" />
|
||||
</a>
|
||||
|
||||
</p>
|
||||
|
||||
18
SECURITY.md
18
SECURITY.md
@@ -2,13 +2,21 @@
|
||||
|
||||
## Supported Versions
|
||||
|
||||
Supported Versions:
|
||||
|
||||
Currently, we support security patches by committing changes and bumping the version published on Github.
|
||||
Security patches target the latest release and the `main` branch. We recommend always running the most recent version.
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
Found a vulnerability? Please email us:
|
||||
Preferred method: use GitHub's private vulnerability reporting flow:
|
||||
https://github.com/arc53/DocsGPT/security
|
||||
|
||||
security@arc53.com
|
||||
Then click **Report a vulnerability**.
|
||||
|
||||
|
||||
Alternatively, email us at: security@arc53.com
|
||||
|
||||
We aim to acknowledge reports within 48 hours.
|
||||
|
||||
## Incident Handling
|
||||
|
||||
For the public incident response process, see [`INCIDENT_RESPONSE.md`](./.github/INCIDENT_RESPONSE.md). If you believe an active exploit is occurring, include **URGENT** in your report subject line.
|
||||
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
API_KEY=your_api_key
|
||||
EMBEDDINGS_KEY=your_api_key
|
||||
API_URL=http://localhost:7091
|
||||
FLASK_APP=application/app.py
|
||||
FLASK_DEBUG=true
|
||||
|
||||
#For OPENAI on Azure
|
||||
OPENAI_API_BASE=
|
||||
OPENAI_API_VERSION=
|
||||
AZURE_DEPLOYMENT_NAME=
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
@@ -7,7 +7,7 @@ RUN apt-get update && \
|
||||
apt-get install -y software-properties-common && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.12 python3.12-venv && \
|
||||
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Verify Python installation and setup symlink
|
||||
@@ -48,7 +48,12 @@ FROM ubuntu:24.04 as final
|
||||
RUN apt-get update && \
|
||||
apt-get install -y software-properties-common && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update && apt-get install -y --no-install-recommends python3.12 && \
|
||||
apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.12 \
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
poppler-utils \
|
||||
&& \
|
||||
ln -s /usr/bin/python3.12 /usr/bin/python && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
@@ -84,4 +89,4 @@ EXPOSE 7091
|
||||
USER appuser
|
||||
|
||||
# Start Gunicorn
|
||||
CMD ["gunicorn", "-w", "2", "--timeout", "120", "--bind", "0.0.0.0:7091", "application.wsgi:app"]
|
||||
CMD ["gunicorn", "-w", "1", "--timeout", "120", "--bind", "0.0.0.0:7091", "--preload", "application.wsgi:app"]
|
||||
|
||||
0
application/agents/__init__.py
Normal file
0
application/agents/__init__.py
Normal file
@@ -1,9 +1,20 @@
|
||||
import logging
|
||||
|
||||
from application.agents.agentic_agent import AgenticAgent
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.agents.research_agent import ResearchAgent
|
||||
from application.agents.workflow_agent import WorkflowAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentCreator:
|
||||
agents = {
|
||||
"classic": ClassicAgent,
|
||||
"react": ClassicAgent, # backwards compat: react falls back to classic
|
||||
"agentic": AgenticAgent,
|
||||
"research": ResearchAgent,
|
||||
"workflow": WorkflowAgent,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
63
application/agents/agentic_agent.py
Normal file
63
application/agents/agentic_agent.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import logging
|
||||
from typing import Dict, Generator, Optional
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.tools.internal_search import (
|
||||
INTERNAL_TOOL_ID,
|
||||
add_internal_search_tool,
|
||||
)
|
||||
from application.logging import LogContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgenticAgent(BaseAgent):
|
||||
"""Agent where the LLM controls retrieval via tools.
|
||||
|
||||
Unlike ClassicAgent which pre-fetches docs into the prompt,
|
||||
AgenticAgent gives the LLM an internal_search tool so it can
|
||||
decide when, what, and whether to search.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
retriever_config: Optional[Dict] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.retriever_config = retriever_config or {}
|
||||
|
||||
def _gen_inner(
|
||||
self, query: str, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
tools_dict = self.tool_executor.get_tools()
|
||||
add_internal_search_tool(tools_dict, self.retriever_config)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
# 4. Build messages (prompt has NO pre-fetched docs)
|
||||
messages = self._build_messages(self.prompt, query)
|
||||
|
||||
# 5. Call LLM — the handler manages the tool loop
|
||||
llm_response = self._llm_gen(messages, log_context)
|
||||
|
||||
yield from self._handle_response(
|
||||
llm_response, tools_dict, messages, log_context
|
||||
)
|
||||
|
||||
# 6. Collect sources from internal search tool results
|
||||
self._collect_internal_sources()
|
||||
|
||||
yield {"sources": self.retrieved_docs}
|
||||
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
|
||||
log_context.stacks.append(
|
||||
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||
)
|
||||
|
||||
def _collect_internal_sources(self):
|
||||
"""Collect retrieved docs from the cached InternalSearchTool instance."""
|
||||
cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
|
||||
tool = self.tool_executor._loaded_tools.get(cache_key)
|
||||
if tool and hasattr(tool, "retrieved_docs") and tool.retrieved_docs:
|
||||
self.retrieved_docs = tool.retrieved_docs
|
||||
@@ -1,153 +1,585 @@
|
||||
from typing import Dict, Generator
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
from application.agents.llm_handler import get_llm_handler
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.llm.handlers.base import ToolCall
|
||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.logging import build_stack_data, log_activity, LogContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseAgent:
|
||||
class BaseAgent(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
endpoint,
|
||||
llm_name,
|
||||
gpt_model,
|
||||
api_key,
|
||||
user_api_key=None,
|
||||
decoded_token=None,
|
||||
endpoint: str,
|
||||
llm_name: str,
|
||||
model_id: str,
|
||||
api_key: str,
|
||||
agent_id: Optional[str] = None,
|
||||
user_api_key: Optional[str] = None,
|
||||
prompt: str = "",
|
||||
chat_history: Optional[List[Dict]] = None,
|
||||
retrieved_docs: Optional[List[Dict]] = None,
|
||||
decoded_token: Optional[Dict] = None,
|
||||
attachments: Optional[List[Dict]] = None,
|
||||
json_schema: Optional[Dict] = None,
|
||||
limited_token_mode: Optional[bool] = False,
|
||||
token_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["token_limit"],
|
||||
limited_request_mode: Optional[bool] = False,
|
||||
request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
|
||||
compressed_summary: Optional[str] = None,
|
||||
llm=None,
|
||||
llm_handler=None,
|
||||
tool_executor: Optional[ToolExecutor] = None,
|
||||
backup_models: Optional[List[str]] = None,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.llm = LLMCreator.create_llm(
|
||||
llm_name,
|
||||
api_key=api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
self.llm_handler = get_llm_handler(llm_name)
|
||||
self.gpt_model = gpt_model
|
||||
self.tools = []
|
||||
self.tool_config = {}
|
||||
self.tool_calls = []
|
||||
self.llm_name = llm_name
|
||||
self.model_id = model_id
|
||||
self.api_key = api_key
|
||||
self.agent_id = agent_id
|
||||
self.user_api_key = user_api_key
|
||||
self.prompt = prompt
|
||||
self.decoded_token = decoded_token or {}
|
||||
self.user: str = self.decoded_token.get("sub")
|
||||
self.tools: List[Dict] = []
|
||||
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
|
||||
|
||||
def gen(self, *args, **kwargs) -> Generator[Dict, None, None]:
|
||||
raise NotImplementedError('Method "gen" must be implemented in the child class')
|
||||
# Dependency injection for LLM — fall back to creating if not provided
|
||||
if llm is not None:
|
||||
self.llm = llm
|
||||
else:
|
||||
self.llm = LLMCreator.create_llm(
|
||||
llm_name,
|
||||
api_key=api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
backup_models=backup_models,
|
||||
)
|
||||
|
||||
def _get_user_tools(self, user="local"):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo["docsgpt"]
|
||||
user_tools_collection = db["user_tools"]
|
||||
user_tools = user_tools_collection.find({"user": user, "status": True})
|
||||
user_tools = list(user_tools)
|
||||
tools_by_id = {str(tool["_id"]): tool for tool in user_tools}
|
||||
return tools_by_id
|
||||
self.retrieved_docs = retrieved_docs or []
|
||||
|
||||
def _build_tool_parameters(self, action):
|
||||
params = {"type": "object", "properties": {}, "required": []}
|
||||
for param_type in ["query_params", "headers", "body", "parameters"]:
|
||||
if param_type in action and action[param_type].get("properties"):
|
||||
for k, v in action[param_type]["properties"].items():
|
||||
if v.get("filled_by_llm", True):
|
||||
params["properties"][k] = {
|
||||
key: value
|
||||
for key, value in v.items()
|
||||
if key != "filled_by_llm" and key != "value"
|
||||
}
|
||||
if llm_handler is not None:
|
||||
self.llm_handler = llm_handler
|
||||
else:
|
||||
self.llm_handler = LLMHandlerCreator.create_handler(
|
||||
llm_name if llm_name else "default"
|
||||
)
|
||||
|
||||
params["required"].append(k)
|
||||
return params
|
||||
# Tool executor — injected or created
|
||||
if tool_executor is not None:
|
||||
self.tool_executor = tool_executor
|
||||
else:
|
||||
self.tool_executor = ToolExecutor(
|
||||
user_api_key=user_api_key,
|
||||
user=self.user,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
def _prepare_tools(self, tools_dict):
|
||||
self.tools = [
|
||||
{
|
||||
self.attachments = attachments or []
|
||||
self.json_schema = None
|
||||
if json_schema is not None:
|
||||
try:
|
||||
self.json_schema = normalize_json_schema_payload(json_schema)
|
||||
except JsonSchemaValidationError as exc:
|
||||
logger.warning("Ignoring invalid JSON schema payload: %s", exc)
|
||||
self.limited_token_mode = limited_token_mode
|
||||
self.token_limit = token_limit
|
||||
self.limited_request_mode = limited_request_mode
|
||||
self.request_limit = request_limit
|
||||
self.compressed_summary = compressed_summary
|
||||
self.current_token_count = 0
|
||||
self.context_limit_reached = False
|
||||
|
||||
@log_activity()
|
||||
def gen(
|
||||
self, query: str, log_context: LogContext = None
|
||||
) -> Generator[Dict, None, None]:
|
||||
yield from self._gen_inner(query, log_context)
|
||||
|
||||
@abstractmethod
|
||||
def _gen_inner(
|
||||
self, query: str, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
pass
|
||||
|
||||
def gen_continuation(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
tools_dict: Dict,
|
||||
pending_tool_calls: List[Dict],
|
||||
tool_actions: List[Dict],
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Resume generation after tool actions are resolved.
|
||||
|
||||
Processes the client-provided *tool_actions* (approvals, denials,
|
||||
or client-side results), appends the resulting messages, then
|
||||
hands back to the LLM to continue the conversation.
|
||||
|
||||
Args:
|
||||
messages: The saved messages array from the pause point.
|
||||
tools_dict: The saved tools dictionary.
|
||||
pending_tool_calls: The pending tool call descriptors from the pause.
|
||||
tool_actions: Client-provided actions resolving the pending calls.
|
||||
"""
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
actions_by_id = {a["call_id"]: a for a in tool_actions}
|
||||
|
||||
# Build a single assistant message containing all tool calls so
|
||||
# the message history matches the format LLM providers expect
|
||||
# (one assistant message with N tool_calls, followed by N tool results).
|
||||
tc_objects: List[Dict[str, Any]] = []
|
||||
for pending in pending_tool_calls:
|
||||
call_id = pending["call_id"]
|
||||
args = pending["arguments"]
|
||||
args_str = (
|
||||
json.dumps(args) if isinstance(args, dict) else (args or "{}")
|
||||
)
|
||||
tc_obj: Dict[str, Any] = {
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f"{action['name']}_{tool_id}",
|
||||
"description": action["description"],
|
||||
"parameters": self._build_tool_parameters(action),
|
||||
"name": pending["name"],
|
||||
"arguments": args_str,
|
||||
},
|
||||
}
|
||||
for tool_id, tool in tools_dict.items()
|
||||
if (
|
||||
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
|
||||
or (tool["name"] != "api_tool" and "actions" in tool)
|
||||
)
|
||||
for action in (
|
||||
tool["config"]["actions"].values()
|
||||
if tool["name"] == "api_tool"
|
||||
else tool["actions"]
|
||||
)
|
||||
if action.get("active", True)
|
||||
]
|
||||
if pending.get("thought_signature"):
|
||||
tc_obj["thought_signature"] = pending["thought_signature"]
|
||||
tc_objects.append(tc_obj)
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": tc_objects,
|
||||
})
|
||||
|
||||
# Now process each pending call and append tool result messages
|
||||
for pending in pending_tool_calls:
|
||||
call_id = pending["call_id"]
|
||||
args = pending["arguments"]
|
||||
action = actions_by_id.get(call_id)
|
||||
if not action:
|
||||
action = {
|
||||
"call_id": call_id,
|
||||
"decision": "denied",
|
||||
"comment": "No response provided",
|
||||
}
|
||||
|
||||
if action.get("decision") == "approved":
|
||||
# Execute the tool server-side
|
||||
tc = ToolCall(
|
||||
id=call_id,
|
||||
name=pending["name"],
|
||||
arguments=(
|
||||
json.dumps(args) if isinstance(args, dict) else args
|
||||
),
|
||||
)
|
||||
tool_gen = self._execute_tool_action(tools_dict, tc)
|
||||
tool_response = None
|
||||
while True:
|
||||
try:
|
||||
event = next(tool_gen)
|
||||
yield event
|
||||
except StopIteration as e:
|
||||
tool_response, _ = e.value
|
||||
break
|
||||
messages.append(
|
||||
self.llm_handler.create_tool_message(tc, tool_response)
|
||||
)
|
||||
|
||||
elif action.get("decision") == "denied":
|
||||
comment = action.get("comment", "")
|
||||
denial = (
|
||||
f"Tool execution denied by user. Reason: {comment}"
|
||||
if comment
|
||||
else "Tool execution denied by user."
|
||||
)
|
||||
tc = ToolCall(
|
||||
id=call_id, name=pending["name"], arguments=args
|
||||
)
|
||||
messages.append(
|
||||
self.llm_handler.create_tool_message(tc, denial)
|
||||
)
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": pending.get("tool_name", "unknown"),
|
||||
"call_id": call_id,
|
||||
"action_name": pending.get("llm_name", pending["name"]),
|
||||
"arguments": args,
|
||||
"status": "denied",
|
||||
},
|
||||
}
|
||||
|
||||
elif "result" in action:
|
||||
result = action["result"]
|
||||
result_str = (
|
||||
json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else result
|
||||
)
|
||||
tc = ToolCall(
|
||||
id=call_id, name=pending["name"], arguments=args
|
||||
)
|
||||
messages.append(
|
||||
self.llm_handler.create_tool_message(tc, result_str)
|
||||
)
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": pending.get("tool_name", "unknown"),
|
||||
"call_id": call_id,
|
||||
"action_name": pending.get("llm_name", pending["name"]),
|
||||
"arguments": args,
|
||||
"result": (
|
||||
result_str[:50] + "..."
|
||||
if len(result_str) > 50
|
||||
else result_str
|
||||
),
|
||||
"status": "completed",
|
||||
},
|
||||
}
|
||||
|
||||
# Resume the LLM loop with the updated messages
|
||||
llm_response = self._llm_gen(messages)
|
||||
yield from self._handle_response(
|
||||
llm_response, tools_dict, messages, None
|
||||
)
|
||||
|
||||
yield {"sources": self.retrieved_docs}
|
||||
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
|
||||
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
|
||||
|
||||
@property
|
||||
def tool_calls(self) -> List[Dict]:
|
||||
return self.tool_executor.tool_calls
|
||||
|
||||
@tool_calls.setter
|
||||
def tool_calls(self, value: List[Dict]):
|
||||
self.tool_executor.tool_calls = value
|
||||
|
||||
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
|
||||
return self.tool_executor._get_tools_by_api_key(api_key or self.user_api_key)
|
||||
|
||||
def _get_user_tools(self, user="local"):
|
||||
return self.tool_executor._get_user_tools(user)
|
||||
|
||||
def _build_tool_parameters(self, action):
|
||||
return self.tool_executor._build_tool_parameters(action)
|
||||
|
||||
def _prepare_tools(self, tools_dict):
|
||||
self.tools = self.tool_executor.prepare_tools_for_llm(tools_dict)
|
||||
|
||||
def _execute_tool_action(self, tools_dict, call):
|
||||
parser = ToolActionParser(self.llm.__class__.__name__)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
action_data = (
|
||||
tool_data["config"]["actions"][action_name]
|
||||
if tool_data["name"] == "api_tool"
|
||||
else next(
|
||||
action
|
||||
for action in tool_data["actions"]
|
||||
if action["name"] == action_name
|
||||
)
|
||||
return self.tool_executor.execute(
|
||||
tools_dict, call, self.llm.__class__.__name__
|
||||
)
|
||||
|
||||
query_params, headers, body, parameters = {}, {}, {}, {}
|
||||
param_types = {
|
||||
"query_params": query_params,
|
||||
"headers": headers,
|
||||
"body": body,
|
||||
"parameters": parameters,
|
||||
}
|
||||
def _get_truncated_tool_calls(self):
|
||||
return self.tool_executor.get_truncated_tool_calls()
|
||||
|
||||
for param_type, target_dict in param_types.items():
|
||||
if param_type in action_data and action_data[param_type].get("properties"):
|
||||
for param, details in action_data[param_type]["properties"].items():
|
||||
if param not in call_args and "value" in details:
|
||||
target_dict[param] = details["value"]
|
||||
# ---- Context / token management ----
|
||||
|
||||
for param, value in call_args.items():
|
||||
for param_type, target_dict in param_types.items():
|
||||
if param_type in action_data and param in action_data[param_type].get(
|
||||
"properties", {}
|
||||
):
|
||||
target_dict[param] = value
|
||||
|
||||
tm = ToolManager(config={})
|
||||
tool = tm.load_tool(
|
||||
tool_data["name"],
|
||||
tool_config=(
|
||||
{
|
||||
"url": tool_data["config"]["actions"][action_name]["url"],
|
||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
||||
"headers": headers,
|
||||
"query_params": query_params,
|
||||
}
|
||||
if tool_data["name"] == "api_tool"
|
||||
else tool_data["config"]
|
||||
),
|
||||
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
|
||||
from application.api.answer.services.compression.token_counter import (
|
||||
TokenCounter,
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
print(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
return TokenCounter.count_message_tokens(messages)
|
||||
|
||||
def _check_context_limit(self, messages: List[Dict]) -> bool:
|
||||
from application.core.model_utils import get_token_limit
|
||||
|
||||
try:
|
||||
current_tokens = self._calculate_current_context_tokens(messages)
|
||||
self.current_token_count = current_tokens
|
||||
context_limit = get_token_limit(self.model_id)
|
||||
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
|
||||
|
||||
if current_tokens >= threshold:
|
||||
logger.warning(
|
||||
f"Context limit approaching: {current_tokens}/{context_limit} tokens "
|
||||
f"({(current_tokens/context_limit)*100:.1f}%)"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
|
||||
return False
|
||||
|
||||
def _validate_context_size(self, messages: List[Dict]) -> None:
|
||||
from application.core.model_utils import get_token_limit
|
||||
|
||||
current_tokens = self._calculate_current_context_tokens(messages)
|
||||
self.current_token_count = current_tokens
|
||||
context_limit = get_token_limit(self.model_id)
|
||||
percentage = (current_tokens / context_limit) * 100
|
||||
|
||||
if current_tokens >= context_limit:
|
||||
logger.warning(
|
||||
f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
|
||||
f"({percentage:.1f}%). Model: {self.model_id}"
|
||||
)
|
||||
elif current_tokens >= int(
|
||||
context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE
|
||||
):
|
||||
logger.info(
|
||||
f"Context approaching limit: {current_tokens:,}/{context_limit:,} tokens "
|
||||
f"({percentage:.1f}%)"
|
||||
)
|
||||
result = tool.execute_action(action_name, **body)
|
||||
else:
|
||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||
result = tool.execute_action(action_name, **parameters)
|
||||
call_id = getattr(call, "id", None)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": tool_data["name"],
|
||||
"call_id": call_id if call_id is not None else "None",
|
||||
"action_name": f"{action_name}_{tool_id}",
|
||||
"arguments": call_args,
|
||||
"result": result,
|
||||
}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
def _truncate_text_middle(self, text: str, max_tokens: int) -> str:
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
return result, call_id
|
||||
current_tokens = num_tokens_from_string(text)
|
||||
if current_tokens <= max_tokens:
|
||||
return text
|
||||
|
||||
chars_per_token = len(text) / current_tokens if current_tokens > 0 else 4
|
||||
target_chars = int(max_tokens * chars_per_token * 0.95)
|
||||
|
||||
if target_chars <= 0:
|
||||
return ""
|
||||
|
||||
start_chars = int(target_chars * 0.4)
|
||||
end_chars = int(target_chars * 0.4)
|
||||
|
||||
truncation_marker = "\n\n[... content truncated to fit context limit ...]\n\n"
|
||||
truncated = text[:start_chars] + truncation_marker + text[-end_chars:]
|
||||
|
||||
logger.info(
|
||||
f"Truncated text from {current_tokens:,} to ~{max_tokens:,} tokens "
|
||||
f"(removed middle section)"
|
||||
)
|
||||
return truncated
|
||||
|
||||
# ---- Message building ----
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
system_prompt: str,
|
||||
query: str,
|
||||
) -> List[Dict]:
|
||||
"""Build messages using pre-rendered system prompt"""
|
||||
from application.core.model_utils import get_token_limit
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
if self.compressed_summary:
|
||||
compression_context = (
|
||||
"\n\n---\n\n"
|
||||
"This session is being continued from a previous conversation that "
|
||||
"has been compressed to fit within context limits. "
|
||||
"The conversation is summarized below:\n\n"
|
||||
f"{self.compressed_summary}"
|
||||
)
|
||||
system_prompt = system_prompt + compression_context
|
||||
|
||||
context_limit = get_token_limit(self.model_id)
|
||||
system_tokens = num_tokens_from_string(system_prompt)
|
||||
|
||||
safety_buffer = int(context_limit * 0.1)
|
||||
available_after_system = context_limit - system_tokens - safety_buffer
|
||||
|
||||
max_query_tokens = int(available_after_system * 0.8)
|
||||
query_tokens = num_tokens_from_string(query)
|
||||
|
||||
if query_tokens > max_query_tokens:
|
||||
query = self._truncate_text_middle(query, max_query_tokens)
|
||||
query_tokens = num_tokens_from_string(query)
|
||||
|
||||
available_for_history = max(available_after_system - query_tokens, 0)
|
||||
|
||||
working_history = self._truncate_history_to_fit(
|
||||
self.chat_history,
|
||||
available_for_history,
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
for i in working_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
messages.append({"role": "user", "content": i["prompt"]})
|
||||
messages.append({"role": "assistant", "content": i["response"]})
|
||||
if "tool_calls" in i:
|
||||
for tool_call in i["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
args = tool_call.get("arguments")
|
||||
args_str = (
|
||||
json.dumps(args)
|
||||
if isinstance(args, dict)
|
||||
else (args or "{}")
|
||||
)
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.get("action_name", ""),
|
||||
"arguments": args_str,
|
||||
},
|
||||
}],
|
||||
})
|
||||
result = tool_call.get("result")
|
||||
result_str = (
|
||||
json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else (result or "")
|
||||
)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": result_str,
|
||||
})
|
||||
messages.append({"role": "user", "content": query})
|
||||
return messages
|
||||
|
||||
def _truncate_history_to_fit(
|
||||
self,
|
||||
history: List[Dict],
|
||||
max_tokens: int,
|
||||
) -> List[Dict]:
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
if not history or max_tokens <= 0:
|
||||
return []
|
||||
|
||||
truncated = []
|
||||
current_tokens = 0
|
||||
|
||||
for message in reversed(history):
|
||||
message_tokens = 0
|
||||
|
||||
if "prompt" in message and "response" in message:
|
||||
message_tokens += num_tokens_from_string(message["prompt"])
|
||||
message_tokens += num_tokens_from_string(message["response"])
|
||||
|
||||
if "tool_calls" in message:
|
||||
for tool_call in message["tool_calls"]:
|
||||
tool_str = (
|
||||
f"Tool: {tool_call.get('tool_name')} | "
|
||||
f"Action: {tool_call.get('action_name')} | "
|
||||
f"Args: {tool_call.get('arguments')} | "
|
||||
f"Response: {tool_call.get('result')}"
|
||||
)
|
||||
message_tokens += num_tokens_from_string(tool_str)
|
||||
|
||||
if current_tokens + message_tokens <= max_tokens:
|
||||
current_tokens += message_tokens
|
||||
truncated.insert(0, message)
|
||||
else:
|
||||
break
|
||||
|
||||
if len(truncated) < len(history):
|
||||
logger.info(
|
||||
f"Truncated chat history from {len(history)} to {len(truncated)} messages "
|
||||
f"to fit within {max_tokens:,} token budget"
|
||||
)
|
||||
|
||||
return truncated
|
||||
|
||||
# ---- LLM generation ----
|
||||
|
||||
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
|
||||
self._validate_context_size(messages)
|
||||
|
||||
gen_kwargs = {"model": self.model_id, "messages": messages}
|
||||
if self.attachments:
|
||||
gen_kwargs["_usage_attachments"] = self.attachments
|
||||
|
||||
if (
|
||||
hasattr(self.llm, "_supports_tools")
|
||||
and self.llm._supports_tools
|
||||
and self.tools
|
||||
):
|
||||
gen_kwargs["tools"] = self.tools
|
||||
if (
|
||||
self.json_schema
|
||||
and hasattr(self.llm, "_supports_structured_output")
|
||||
and self.llm._supports_structured_output()
|
||||
):
|
||||
structured_format = self.llm.prepare_structured_output_format(
|
||||
self.json_schema
|
||||
)
|
||||
if structured_format:
|
||||
if self.llm_name == "openai":
|
||||
gen_kwargs["response_format"] = structured_format
|
||||
elif self.llm_name == "google":
|
||||
gen_kwargs["response_schema"] = structured_format
|
||||
resp = self.llm.gen_stream(**gen_kwargs)
|
||||
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm, exclude_attributes=["client"])
|
||||
log_context.stacks.append({"component": "llm", "data": data})
|
||||
return resp
|
||||
|
||||
def _llm_handler(
|
||||
self,
|
||||
resp,
|
||||
tools_dict: Dict,
|
||||
messages: List[Dict],
|
||||
log_context: Optional[LogContext] = None,
|
||||
attachments: Optional[List[Dict]] = None,
|
||||
):
|
||||
resp = self.llm_handler.process_message_flow(
|
||||
self, resp, tools_dict, messages, attachments, True
|
||||
)
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm_handler, exclude_attributes=["tool_calls"])
|
||||
log_context.stacks.append({"component": "llm_handler", "data": data})
|
||||
return resp
|
||||
|
||||
def _handle_response(self, response, tools_dict, messages, log_context):
|
||||
is_structured_output = (
|
||||
self.json_schema is not None
|
||||
and hasattr(self.llm, "_supports_structured_output")
|
||||
and self.llm._supports_structured_output()
|
||||
)
|
||||
|
||||
if isinstance(response, str):
|
||||
answer_data = {"answer": response}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
return
|
||||
if hasattr(response, "message") and getattr(response.message, "content", None):
|
||||
answer_data = {"answer": response.message.content}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
return
|
||||
processed_response_gen = self._llm_handler(
|
||||
response, tools_dict, messages, log_context, self.attachments
|
||||
)
|
||||
|
||||
for event in processed_response_gen:
|
||||
if isinstance(event, str):
|
||||
answer_data = {"answer": event}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
elif hasattr(event, "message") and getattr(event.message, "content", None):
|
||||
answer_data = {"answer": event.message.content}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
elif isinstance(event, dict) and "type" in event:
|
||||
yield event
|
||||
|
||||
@@ -1,140 +1,33 @@
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Dict, Generator
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.logging import build_stack_data, log_activity, LogContext
|
||||
from application.logging import LogContext
|
||||
|
||||
from application.retriever.base import BaseRetriever
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClassicAgent(BaseAgent):
|
||||
def __init__(
|
||||
self,
|
||||
endpoint,
|
||||
llm_name,
|
||||
gpt_model,
|
||||
api_key,
|
||||
user_api_key=None,
|
||||
prompt="",
|
||||
chat_history=None,
|
||||
decoded_token=None,
|
||||
):
|
||||
super().__init__(
|
||||
endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token
|
||||
)
|
||||
self.user = decoded_token.get("sub")
|
||||
self.prompt = prompt
|
||||
self.chat_history = chat_history if chat_history is not None else []
|
||||
|
||||
@log_activity()
|
||||
def gen(
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
|
||||
) -> Generator[Dict, None, None]:
|
||||
yield from self._gen_inner(query, retriever, log_context)
|
||||
"""A simplified agent with clear execution flow"""
|
||||
|
||||
def _gen_inner(
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
||||
self, query: str, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
retrieved_data = self._retriever_search(retriever, query, log_context)
|
||||
"""Core generator function for ClassicAgent execution flow"""
|
||||
|
||||
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
|
||||
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
|
||||
if len(self.chat_history) > 0:
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": i["response"]}
|
||||
)
|
||||
if "tool_calls" in i:
|
||||
for tool_call in i["tool_calls"]:
|
||||
call_id = tool_call.get("call_id")
|
||||
if call_id is None or call_id == "None":
|
||||
call_id = str(uuid.uuid4())
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"args": tool_call.get("arguments"),
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"response": {"result": tool_call.get("result")},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
messages_combine.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": query})
|
||||
|
||||
tools_dict = self._get_user_tools(self.user)
|
||||
tools_dict = self.tool_executor.get_tools()
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
resp = self._llm_gen(messages_combine, log_context)
|
||||
messages = self._build_messages(self.prompt, query)
|
||||
llm_response = self._llm_gen(messages, log_context)
|
||||
|
||||
if isinstance(resp, str):
|
||||
yield {"answer": resp}
|
||||
return
|
||||
if (
|
||||
hasattr(resp, "message")
|
||||
and hasattr(resp.message, "content")
|
||||
and resp.message.content is not None
|
||||
):
|
||||
yield {"answer": resp.message.content}
|
||||
return
|
||||
|
||||
resp = self._llm_handler(resp, tools_dict, messages_combine, log_context)
|
||||
|
||||
if isinstance(resp, str):
|
||||
yield {"answer": resp}
|
||||
elif (
|
||||
hasattr(resp, "message")
|
||||
and hasattr(resp.message, "content")
|
||||
and resp.message.content is not None
|
||||
):
|
||||
yield {"answer": resp.message.content}
|
||||
else:
|
||||
completion = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages_combine, tools=self.tools
|
||||
)
|
||||
for line in completion:
|
||||
if isinstance(line, str):
|
||||
yield {"answer": line}
|
||||
|
||||
yield {"sources": retrieved_data}
|
||||
yield {"tool_calls": self.tool_calls.copy()}
|
||||
|
||||
def _retriever_search(self, retriever, query, log_context):
|
||||
retrieved_data = retriever.search(query)
|
||||
if log_context:
|
||||
data = build_stack_data(retriever, exclude_attributes=["llm"])
|
||||
log_context.stacks.append({"component": "retriever", "data": data})
|
||||
return retrieved_data
|
||||
|
||||
def _llm_gen(self, messages_combine, log_context):
|
||||
resp = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages_combine, tools=self.tools
|
||||
yield from self._handle_response(
|
||||
llm_response, tools_dict, messages, log_context
|
||||
)
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm)
|
||||
log_context.stacks.append({"component": "llm", "data": data})
|
||||
return resp
|
||||
|
||||
def _llm_handler(self, resp, tools_dict, messages_combine, log_context):
|
||||
resp = self.llm_handler.handle_response(
|
||||
self, resp, tools_dict, messages_combine
|
||||
yield {"sources": self.retrieved_docs}
|
||||
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
|
||||
log_context.stacks.append(
|
||||
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||
)
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm_handler)
|
||||
log_context.stacks.append({"component": "llm_handler", "data": data})
|
||||
return resp
|
||||
|
||||
@@ -1,254 +0,0 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from application.logging import build_stack_data
|
||||
|
||||
|
||||
class LLMHandler(ABC):
|
||||
def __init__(self):
|
||||
self.llm_calls = []
|
||||
self.tool_calls = []
|
||||
|
||||
@abstractmethod
|
||||
def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class OpenAILLMHandler(LLMHandler):
|
||||
def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True):
|
||||
if not stream:
|
||||
while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls":
|
||||
message = json.loads(resp.model_dump_json())["message"]
|
||||
keys_to_remove = {"audio", "function_call", "refusal"}
|
||||
filtered_data = {
|
||||
k: v for k, v in message.items() if k not in keys_to_remove
|
||||
}
|
||||
messages.append(filtered_data)
|
||||
|
||||
tool_calls = resp.message.tool_calls
|
||||
for call in tool_calls:
|
||||
try:
|
||||
self.tool_calls.append(call)
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, call
|
||||
)
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": call.function.name,
|
||||
"args": call.function.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": call.function.name,
|
||||
"response": {"result": tool_response},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
messages.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": f"Error executing tool: {str(e)}",
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
)
|
||||
resp = agent.llm.gen_stream(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
return resp
|
||||
|
||||
else:
|
||||
while True:
|
||||
tool_calls = {}
|
||||
for chunk in resp:
|
||||
if isinstance(chunk, str) and len(chunk) > 0:
|
||||
return
|
||||
elif hasattr(chunk, "delta"):
|
||||
chunk_delta = chunk.delta
|
||||
|
||||
if (
|
||||
hasattr(chunk_delta, "tool_calls")
|
||||
and chunk_delta.tool_calls is not None
|
||||
):
|
||||
for tool_call in chunk_delta.tool_calls:
|
||||
index = tool_call.index
|
||||
if index not in tool_calls:
|
||||
tool_calls[index] = {
|
||||
"id": "",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
|
||||
current = tool_calls[index]
|
||||
if tool_call.id:
|
||||
current["id"] = tool_call.id
|
||||
if tool_call.function.name:
|
||||
current["function"][
|
||||
"name"
|
||||
] = tool_call.function.name
|
||||
if tool_call.function.arguments:
|
||||
current["function"][
|
||||
"arguments"
|
||||
] += tool_call.function.arguments
|
||||
tool_calls[index] = current
|
||||
|
||||
if (
|
||||
hasattr(chunk, "finish_reason")
|
||||
and chunk.finish_reason == "tool_calls"
|
||||
):
|
||||
for index in sorted(tool_calls.keys()):
|
||||
call = tool_calls[index]
|
||||
try:
|
||||
self.tool_calls.append(call)
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, call
|
||||
)
|
||||
if isinstance(call["function"]["arguments"], str):
|
||||
call["function"]["arguments"] = json.loads(call["function"]["arguments"])
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": call["function"]["name"],
|
||||
"args": call["function"]["arguments"],
|
||||
"call_id": call["id"],
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": call["function"]["name"],
|
||||
"response": {"result": tool_response},
|
||||
"call_id": call["id"],
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [function_call_dict],
|
||||
}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [function_response_dict],
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Error executing tool: {str(e)}",
|
||||
}
|
||||
)
|
||||
tool_calls = {}
|
||||
|
||||
if (
|
||||
hasattr(chunk, "finish_reason")
|
||||
and chunk.finish_reason == "stop"
|
||||
):
|
||||
return
|
||||
elif isinstance(chunk, str) and len(chunk) == 0:
|
||||
continue
|
||||
|
||||
resp = agent.llm.gen_stream(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
|
||||
class GoogleLLMHandler(LLMHandler):
|
||||
def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True):
|
||||
from google.genai import types
|
||||
|
||||
while True:
|
||||
if not stream:
|
||||
response = agent.llm.gen(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
if response.candidates and response.candidates[0].content.parts:
|
||||
tool_call_found = False
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part.function_call:
|
||||
tool_call_found = True
|
||||
self.tool_calls.append(part.function_call)
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, part.function_call
|
||||
)
|
||||
function_response_part = types.Part.from_function_response(
|
||||
name=part.function_call.name,
|
||||
response={"result": tool_response},
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{"role": "model", "content": [part.to_json_dict()]}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [function_response_part.to_json_dict()],
|
||||
}
|
||||
)
|
||||
|
||||
if (
|
||||
not tool_call_found
|
||||
and response.candidates[0].content.parts
|
||||
and response.candidates[0].content.parts[0].text
|
||||
):
|
||||
return response.candidates[0].content.parts[0].text
|
||||
elif not tool_call_found:
|
||||
return response.candidates[0].content.parts
|
||||
|
||||
else:
|
||||
return response
|
||||
|
||||
else:
|
||||
response = agent.llm.gen_stream(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
tool_call_found = False
|
||||
for result in response:
|
||||
if hasattr(result, "function_call"):
|
||||
tool_call_found = True
|
||||
self.tool_calls.append(result.function_call)
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, result.function_call
|
||||
)
|
||||
function_response_part = types.Part.from_function_response(
|
||||
name=result.function_call.name,
|
||||
response={"result": tool_response},
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{"role": "model", "content": [result.to_json_dict()]}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [function_response_part.to_json_dict()],
|
||||
}
|
||||
)
|
||||
|
||||
if not tool_call_found:
|
||||
return response
|
||||
|
||||
|
||||
def get_llm_handler(llm_type):
|
||||
handlers = {
|
||||
"openai": OpenAILLMHandler(),
|
||||
"google": GoogleLLMHandler(),
|
||||
}
|
||||
return handlers.get(llm_type, OpenAILLMHandler())
|
||||
698
application/agents/research_agent.py
Normal file
698
application/agents/research_agent.py
Normal file
@@ -0,0 +1,698 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Generator, List, Optional
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.agents.tools.internal_search import (
|
||||
INTERNAL_TOOL_ID,
|
||||
add_internal_search_tool,
|
||||
)
|
||||
from application.agents.tools.think import THINK_TOOL_ENTRY, THINK_TOOL_ID
|
||||
from application.logging import LogContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Defaults (can be overridden via constructor)
|
||||
DEFAULT_MAX_STEPS = 6
|
||||
DEFAULT_MAX_SUB_ITERATIONS = 5
|
||||
DEFAULT_TIMEOUT_SECONDS = 300 # 5 minutes
|
||||
DEFAULT_TOKEN_BUDGET = 100_000
|
||||
DEFAULT_PARALLEL_WORKERS = 3
|
||||
|
||||
# Adaptive depth caps per complexity level
|
||||
COMPLEXITY_CAPS = {
|
||||
"simple": 2,
|
||||
"moderate": 4,
|
||||
"complex": 6,
|
||||
}
|
||||
|
||||
_PROMPTS_DIR = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
"prompts",
|
||||
"research",
|
||||
)
|
||||
|
||||
|
||||
def _load_prompt(name: str) -> str:
|
||||
with open(os.path.join(_PROMPTS_DIR, name), "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
CLARIFICATION_PROMPT = _load_prompt("clarification.txt")
|
||||
PLANNING_PROMPT = _load_prompt("planning.txt")
|
||||
STEP_PROMPT = _load_prompt("step.txt")
|
||||
SYNTHESIS_PROMPT = _load_prompt("synthesis.txt")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CitationManager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CitationManager:
|
||||
"""Tracks and deduplicates citations across research steps."""
|
||||
|
||||
def __init__(self):
|
||||
self.citations: Dict[int, Dict] = {}
|
||||
self._counter = 0
|
||||
|
||||
def add(self, doc: Dict) -> int:
|
||||
"""Register a source, return its citation number. Deduplicates by source."""
|
||||
source = doc.get("source", "")
|
||||
title = doc.get("title", "")
|
||||
for num, existing in self.citations.items():
|
||||
if existing.get("source") == source and existing.get("title") == title:
|
||||
return num
|
||||
self._counter += 1
|
||||
self.citations[self._counter] = doc
|
||||
return self._counter
|
||||
|
||||
def add_docs(self, docs: List[Dict]) -> str:
|
||||
"""Register multiple docs, return formatted citation mapping text."""
|
||||
mapping_lines = []
|
||||
for doc in docs:
|
||||
num = self.add(doc)
|
||||
title = doc.get("title", "Untitled")
|
||||
mapping_lines.append(f"[{num}] {title}")
|
||||
return "\n".join(mapping_lines)
|
||||
|
||||
def format_references(self) -> str:
|
||||
"""Generate [N] -> source mapping for report footer."""
|
||||
if not self.citations:
|
||||
return "No sources found."
|
||||
lines = []
|
||||
for num, doc in sorted(self.citations.items()):
|
||||
title = doc.get("title", "Untitled")
|
||||
source = doc.get("source", "Unknown")
|
||||
filename = doc.get("filename", "")
|
||||
display = filename or title
|
||||
lines.append(f"[{num}] {display} — {source}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def get_all_docs(self) -> List[Dict]:
|
||||
return list(self.citations.values())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ResearchAgent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ResearchAgent(BaseAgent):
|
||||
"""Multi-step research agent with parallel execution and budget controls.
|
||||
|
||||
Orchestrates: Plan -> Research (per step, optionally parallel) -> Synthesize.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
retriever_config: Optional[Dict] = None,
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
max_sub_iterations: int = DEFAULT_MAX_SUB_ITERATIONS,
|
||||
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
|
||||
token_budget: int = DEFAULT_TOKEN_BUDGET,
|
||||
parallel_workers: int = DEFAULT_PARALLEL_WORKERS,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.retriever_config = retriever_config or {}
|
||||
self.max_steps = max_steps
|
||||
self.max_sub_iterations = max_sub_iterations
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.token_budget = token_budget
|
||||
self.parallel_workers = parallel_workers
|
||||
self.citations = CitationManager()
|
||||
self._start_time: float = 0
|
||||
self._tokens_used: int = 0
|
||||
self._last_token_snapshot: int = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Budget & timeout helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _is_timed_out(self) -> bool:
|
||||
return (time.monotonic() - self._start_time) >= self.timeout_seconds
|
||||
|
||||
def _elapsed(self) -> float:
|
||||
return round(time.monotonic() - self._start_time, 1)
|
||||
|
||||
def _track_tokens(self, count: int):
|
||||
self._tokens_used += count
|
||||
|
||||
def _budget_remaining(self) -> int:
|
||||
return max(self.token_budget - self._tokens_used, 0)
|
||||
|
||||
def _is_over_budget(self) -> bool:
|
||||
return self._tokens_used >= self.token_budget
|
||||
|
||||
def _snapshot_llm_tokens(self) -> int:
|
||||
"""Read current token usage from LLM and return delta since last snapshot."""
|
||||
current = self.llm.token_usage.get("prompt_tokens", 0) + self.llm.token_usage.get("generated_tokens", 0)
|
||||
delta = current - self._last_token_snapshot
|
||||
self._last_token_snapshot = current
|
||||
return delta
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Main orchestration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _gen_inner(
|
||||
self, query: str, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
self._start_time = time.monotonic()
|
||||
tools_dict = self._setup_tools()
|
||||
|
||||
# Phase 0: Clarification (skip if user is responding to a prior clarification)
|
||||
if not self._is_follow_up():
|
||||
clarification = self._clarification_phase(query)
|
||||
if clarification:
|
||||
yield {"metadata": {"is_clarification": True}}
|
||||
yield {"answer": clarification}
|
||||
yield {"sources": []}
|
||||
yield {"tool_calls": []}
|
||||
log_context.stacks.append(
|
||||
{"component": "agent", "data": {"clarification": True}}
|
||||
)
|
||||
return
|
||||
|
||||
# Phase 1: Planning (with adaptive depth)
|
||||
yield {"type": "research_progress", "data": {"status": "planning"}}
|
||||
plan, complexity = self._planning_phase(query)
|
||||
|
||||
if not plan:
|
||||
logger.warning("ResearchAgent: Planning produced no steps, falling back")
|
||||
plan = [{"query": query, "rationale": "Direct investigation"}]
|
||||
complexity = "simple"
|
||||
|
||||
yield {
|
||||
"type": "research_plan",
|
||||
"data": {"steps": plan, "complexity": complexity},
|
||||
}
|
||||
|
||||
# Phase 2: Research each step (yields progress events in real-time)
|
||||
intermediate_reports = []
|
||||
for i, step in enumerate(plan):
|
||||
step_num = i + 1
|
||||
step_query = step.get("query", query)
|
||||
|
||||
if self._is_timed_out():
|
||||
logger.warning(
|
||||
f"ResearchAgent: Timeout at step {step_num}/{len(plan)} "
|
||||
f"({self._elapsed()}s)"
|
||||
)
|
||||
break
|
||||
if self._is_over_budget():
|
||||
logger.warning(
|
||||
f"ResearchAgent: Token budget exhausted at step {step_num}/{len(plan)}"
|
||||
)
|
||||
break
|
||||
|
||||
yield {
|
||||
"type": "research_progress",
|
||||
"data": {
|
||||
"step": step_num,
|
||||
"total": len(plan),
|
||||
"query": step_query,
|
||||
"status": "researching",
|
||||
},
|
||||
}
|
||||
|
||||
report = self._research_step(step_query, tools_dict)
|
||||
intermediate_reports.append({"step": step, "content": report})
|
||||
|
||||
yield {
|
||||
"type": "research_progress",
|
||||
"data": {
|
||||
"step": step_num,
|
||||
"total": len(plan),
|
||||
"query": step_query,
|
||||
"status": "complete",
|
||||
},
|
||||
}
|
||||
|
||||
# Phase 3: Synthesis (streaming)
|
||||
if self._is_timed_out():
|
||||
logger.warning(
|
||||
f"ResearchAgent: Timeout ({self._elapsed()}s) before synthesis, "
|
||||
f"synthesizing with {len(intermediate_reports)} reports"
|
||||
)
|
||||
yield {
|
||||
"type": "research_progress",
|
||||
"data": {
|
||||
"status": "synthesizing",
|
||||
"elapsed_seconds": self._elapsed(),
|
||||
"tokens_used": self._tokens_used,
|
||||
},
|
||||
}
|
||||
yield from self._synthesis_phase(
|
||||
query, plan, intermediate_reports, tools_dict, log_context
|
||||
)
|
||||
|
||||
# Sources and tool calls
|
||||
self.retrieved_docs = self.citations.get_all_docs()
|
||||
yield {"sources": self.retrieved_docs}
|
||||
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
|
||||
logger.info(
|
||||
f"ResearchAgent completed: {len(intermediate_reports)}/{len(plan)} steps, "
|
||||
f"{self._elapsed()}s, ~{self._tokens_used} tokens"
|
||||
)
|
||||
log_context.stacks.append(
|
||||
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool setup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _setup_tools(self) -> Dict:
|
||||
"""Build tools_dict with user tools + internal search + think."""
|
||||
tools_dict = self.tool_executor.get_tools()
|
||||
|
||||
add_internal_search_tool(tools_dict, self.retriever_config)
|
||||
|
||||
think_entry = dict(THINK_TOOL_ENTRY)
|
||||
think_entry["config"] = {}
|
||||
tools_dict[THINK_TOOL_ID] = think_entry
|
||||
|
||||
self._prepare_tools(tools_dict)
|
||||
return tools_dict
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 0: Clarification
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _is_follow_up(self) -> bool:
|
||||
"""Check if the user is responding to a prior clarification.
|
||||
|
||||
Uses the metadata flag stored in the conversation DB — no string matching.
|
||||
Only skip clarification when the last query was explicitly flagged
|
||||
as a clarification by this agent.
|
||||
"""
|
||||
if not self.chat_history:
|
||||
return False
|
||||
last = self.chat_history[-1]
|
||||
meta = last.get("metadata", {})
|
||||
return bool(meta.get("is_clarification"))
|
||||
|
||||
def _clarification_phase(self, question: str) -> Optional[str]:
|
||||
"""Ask the LLM whether the question needs clarification.
|
||||
|
||||
Returns formatted clarification text if needed, or None to proceed.
|
||||
Uses response_format to force valid JSON output.
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": CLARIFICATION_PROMPT},
|
||||
{"role": "user", "content": question},
|
||||
]
|
||||
|
||||
try:
|
||||
response = self.llm.gen(
|
||||
model=self.model_id,
|
||||
messages=messages,
|
||||
tools=None,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
text = self._extract_text(response)
|
||||
self._track_tokens(self._snapshot_llm_tokens())
|
||||
logger.info(f"ResearchAgent clarification response: {text[:300]}")
|
||||
|
||||
data = self._parse_clarification_json(text)
|
||||
if not data or not data.get("needs_clarification"):
|
||||
return None
|
||||
|
||||
questions = data.get("questions", [])
|
||||
if not questions:
|
||||
return None
|
||||
|
||||
# Format as a friendly response
|
||||
lines = [
|
||||
"Before I begin researching, I'd like to clarify a few things:\n"
|
||||
]
|
||||
for i, q in enumerate(questions[:3], 1):
|
||||
lines.append(f"{i}. {q}")
|
||||
lines.append(
|
||||
"\nPlease provide these details and I'll start the research."
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Clarification phase failed: {e}", exc_info=True)
|
||||
return None # proceed with research on failure
|
||||
|
||||
def _parse_clarification_json(self, text: str) -> Optional[Dict]:
|
||||
"""Parse clarification JSON from LLM response."""
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try extracting from code fences
|
||||
for marker in ["```json", "```"]:
|
||||
if marker in text:
|
||||
start = text.index(marker) + len(marker)
|
||||
end = text.index("```", start) if "```" in text[start:] else len(text)
|
||||
try:
|
||||
return json.loads(text[start:end].strip())
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# Try finding JSON object
|
||||
for i, ch in enumerate(text):
|
||||
if ch == "{":
|
||||
for j in range(len(text) - 1, i, -1):
|
||||
if text[j] == "}":
|
||||
try:
|
||||
return json.loads(text[i : j + 1])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
break
|
||||
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 1: Planning (with adaptive depth)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _planning_phase(self, question: str) -> tuple[List[Dict], str]:
|
||||
"""Decompose the question into research steps via LLM.
|
||||
|
||||
Returns (steps, complexity) where complexity is simple/moderate/complex.
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": PLANNING_PROMPT},
|
||||
{"role": "user", "content": question},
|
||||
]
|
||||
|
||||
try:
|
||||
response = self.llm.gen(
|
||||
model=self.model_id,
|
||||
messages=messages,
|
||||
tools=None,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
text = self._extract_text(response)
|
||||
self._track_tokens(self._snapshot_llm_tokens())
|
||||
logger.info(f"ResearchAgent planning LLM response: {text[:500]}")
|
||||
|
||||
plan_data = self._parse_plan_json(text)
|
||||
if isinstance(plan_data, dict):
|
||||
complexity = plan_data.get("complexity", "moderate")
|
||||
steps = plan_data.get("steps", [])
|
||||
else:
|
||||
complexity = "moderate"
|
||||
steps = plan_data
|
||||
|
||||
# Adaptive depth: cap steps based on assessed complexity
|
||||
cap = COMPLEXITY_CAPS.get(complexity, self.max_steps)
|
||||
cap = min(cap, self.max_steps)
|
||||
steps = steps[:cap]
|
||||
|
||||
logger.info(
|
||||
f"ResearchAgent plan: complexity={complexity}, "
|
||||
f"steps={len(steps)} (cap={cap})"
|
||||
)
|
||||
return steps, complexity
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Planning phase failed: {e}", exc_info=True)
|
||||
return (
|
||||
[{"query": question, "rationale": "Direct investigation (planning failed)"}],
|
||||
"simple",
|
||||
)
|
||||
|
||||
def _parse_plan_json(self, text: str):
|
||||
"""Extract JSON plan from LLM response. Returns dict or list."""
|
||||
# Try direct parse
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if isinstance(data, dict) and "steps" in data:
|
||||
return data
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try extracting from markdown code fences
|
||||
for marker in ["```json", "```"]:
|
||||
if marker in text:
|
||||
start = text.index(marker) + len(marker)
|
||||
end = text.index("```", start) if "```" in text[start:] else len(text)
|
||||
try:
|
||||
data = json.loads(text[start:end].strip())
|
||||
if isinstance(data, dict) and "steps" in data:
|
||||
return data
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# Try finding JSON object in text
|
||||
for i, ch in enumerate(text):
|
||||
if ch == "{":
|
||||
for j in range(len(text) - 1, i, -1):
|
||||
if text[j] == "}":
|
||||
try:
|
||||
data = json.loads(text[i : j + 1])
|
||||
if isinstance(data, dict) and "steps" in data:
|
||||
return data
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
break
|
||||
|
||||
logger.warning(f"Could not parse plan JSON from: {text[:200]}")
|
||||
return []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 2: Research step (core loop)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _research_step(self, step_query: str, tools_dict: Dict) -> str:
|
||||
"""Run a focused research loop for one sub-question (sequential path)."""
|
||||
report = self._research_step_with_executor(
|
||||
step_query, tools_dict, self.tool_executor
|
||||
)
|
||||
self._collect_step_sources()
|
||||
return report
|
||||
|
||||
def _research_step_with_executor(
|
||||
self, step_query: str, tools_dict: Dict, executor: ToolExecutor
|
||||
) -> str:
|
||||
"""Core research loop. Works with any ToolExecutor instance."""
|
||||
system_prompt = STEP_PROMPT.replace("{step_query}", step_query)
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": step_query},
|
||||
]
|
||||
|
||||
last_search_empty = False
|
||||
|
||||
for iteration in range(self.max_sub_iterations):
|
||||
# Check timeout and budget
|
||||
if self._is_timed_out():
|
||||
logger.info(
|
||||
f"Research step '{step_query[:50]}' timed out at iteration {iteration}"
|
||||
)
|
||||
break
|
||||
if self._is_over_budget():
|
||||
logger.info(
|
||||
f"Research step '{step_query[:50]}' hit token budget at iteration {iteration}"
|
||||
)
|
||||
break
|
||||
|
||||
try:
|
||||
response = self.llm.gen(
|
||||
model=self.model_id,
|
||||
messages=messages,
|
||||
tools=self.tools if self.tools else None,
|
||||
)
|
||||
self._track_tokens(self._snapshot_llm_tokens())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Research step LLM call failed (iteration {iteration}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
break
|
||||
|
||||
parsed = self.llm_handler.parse_response(response)
|
||||
|
||||
if not parsed.requires_tool_call:
|
||||
return parsed.content or "No findings for this step."
|
||||
|
||||
# Execute tool calls
|
||||
messages, last_search_empty = self._execute_step_tools_with_refinement(
|
||||
parsed.tool_calls, tools_dict, messages, executor, last_search_empty
|
||||
)
|
||||
|
||||
# Max iterations / timeout / budget — ask for summary
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Please summarize your findings so far based on the information gathered.",
|
||||
}
|
||||
)
|
||||
try:
|
||||
response = self.llm.gen(
|
||||
model=self.model_id, messages=messages, tools=None
|
||||
)
|
||||
self._track_tokens(self._snapshot_llm_tokens())
|
||||
text = self._extract_text(response)
|
||||
return text or "Research step completed."
|
||||
except Exception:
|
||||
return "Research step completed."
|
||||
|
||||
def _execute_step_tools_with_refinement(
|
||||
self,
|
||||
tool_calls,
|
||||
tools_dict: Dict,
|
||||
messages: List[Dict],
|
||||
executor: ToolExecutor,
|
||||
last_search_empty: bool,
|
||||
) -> tuple[List[Dict], bool]:
|
||||
"""Execute tool calls with query refinement on empty results.
|
||||
|
||||
Returns (updated_messages, was_last_search_empty).
|
||||
"""
|
||||
search_returned_empty = False
|
||||
|
||||
for call in tool_calls:
|
||||
gen = executor.execute(
|
||||
tools_dict, call, self.llm.__class__.__name__
|
||||
)
|
||||
result = None
|
||||
call_id = None
|
||||
while True:
|
||||
try:
|
||||
event = next(gen)
|
||||
# Log tool_call status events instead of discarding them
|
||||
if isinstance(event, dict) and event.get("type") == "tool_call":
|
||||
logger.debug(
|
||||
"Tool %s status: %s",
|
||||
event.get("data", {}).get("action_name", ""),
|
||||
event.get("data", {}).get("status", ""),
|
||||
)
|
||||
except StopIteration as e:
|
||||
result, call_id = e.value
|
||||
break
|
||||
|
||||
# Detect empty search results for refinement
|
||||
is_search = "search" in (call.name or "").lower()
|
||||
result_str = str(result) if result else ""
|
||||
if is_search and "No documents found" in result_str:
|
||||
search_returned_empty = True
|
||||
if last_search_empty:
|
||||
# Two consecutive empty searches — inject refinement hint
|
||||
result_str += (
|
||||
"\n\nHint: Previous search also returned no results. "
|
||||
"Try a very different query with different keywords, "
|
||||
"or broaden your search terms."
|
||||
)
|
||||
result = result_str
|
||||
|
||||
import json as _json
|
||||
|
||||
args_str = (
|
||||
_json.dumps(call.arguments)
|
||||
if isinstance(call.arguments, dict)
|
||||
else call.arguments
|
||||
)
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {"name": call.name, "arguments": args_str},
|
||||
}],
|
||||
})
|
||||
tool_message = self.llm_handler.create_tool_message(call, result)
|
||||
messages.append(tool_message)
|
||||
|
||||
return messages, search_returned_empty
|
||||
|
||||
def _collect_step_sources(self):
|
||||
"""Collect sources from InternalSearchTool and register with CitationManager."""
|
||||
cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
|
||||
tool = self.tool_executor._loaded_tools.get(cache_key)
|
||||
if tool and hasattr(tool, "retrieved_docs"):
|
||||
for doc in tool.retrieved_docs:
|
||||
self.citations.add(doc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 3: Synthesis
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _synthesis_phase(
|
||||
self,
|
||||
question: str,
|
||||
plan: List[Dict],
|
||||
intermediate_reports: List[Dict],
|
||||
tools_dict: Dict,
|
||||
log_context: LogContext,
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Compile all findings into a final cited report (streaming)."""
|
||||
plan_lines = []
|
||||
for i, step in enumerate(plan, 1):
|
||||
plan_lines.append(
|
||||
f"{i}. {step.get('query', 'Unknown')} — {step.get('rationale', '')}"
|
||||
)
|
||||
plan_summary = "\n".join(plan_lines)
|
||||
|
||||
findings_parts = []
|
||||
for i, report in enumerate(intermediate_reports, 1):
|
||||
step_query = report["step"].get("query", "Unknown")
|
||||
content = report["content"]
|
||||
findings_parts.append(
|
||||
f"--- Step {i}: {step_query} ---\n{content}"
|
||||
)
|
||||
findings = "\n\n".join(findings_parts)
|
||||
|
||||
references = self.citations.format_references()
|
||||
|
||||
synthesis_prompt = SYNTHESIS_PROMPT.replace("{question}", question)
|
||||
synthesis_prompt = synthesis_prompt.replace("{plan_summary}", plan_summary)
|
||||
synthesis_prompt = synthesis_prompt.replace("{findings}", findings)
|
||||
synthesis_prompt = synthesis_prompt.replace("{references}", references)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": synthesis_prompt},
|
||||
{"role": "user", "content": f"Please write the research report for: {question}"},
|
||||
]
|
||||
|
||||
llm_response = self.llm.gen_stream(
|
||||
model=self.model_id, messages=messages, tools=None
|
||||
)
|
||||
|
||||
if log_context:
|
||||
from application.logging import build_stack_data
|
||||
|
||||
log_context.stacks.append(
|
||||
{"component": "synthesis_llm", "data": build_stack_data(self.llm)}
|
||||
)
|
||||
|
||||
yield from self._handle_response(
|
||||
llm_response, tools_dict, messages, log_context
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _extract_text(self, response) -> str:
|
||||
"""Extract text content from a non-streaming LLM response."""
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
if hasattr(response, "message") and hasattr(response.message, "content"):
|
||||
return response.message.content or ""
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "message") and hasattr(choice.message, "content"):
|
||||
return choice.message.content or ""
|
||||
if hasattr(response, "content") and isinstance(response.content, list):
|
||||
if response.content and hasattr(response.content[0], "text"):
|
||||
return response.content[0].text or ""
|
||||
return str(response) if response else ""
|
||||
477
application/agents/tool_executor.py
Normal file
477
application/agents/tool_executor.py
Normal file
@@ -0,0 +1,477 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.security.encryption import decrypt_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""Handles tool discovery, preparation, and execution.
|
||||
|
||||
Extracted from BaseAgent to separate concerns and enable tool caching.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_api_key: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
decoded_token: Optional[Dict] = None,
|
||||
):
|
||||
self.user_api_key = user_api_key
|
||||
self.user = user
|
||||
self.decoded_token = decoded_token
|
||||
self.tool_calls: List[Dict] = []
|
||||
self._loaded_tools: Dict[str, object] = {}
|
||||
self.conversation_id: Optional[str] = None
|
||||
self.client_tools: Optional[List[Dict]] = None
|
||||
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
|
||||
self._tool_to_name: Dict[Tuple[str, str], str] = {}
|
||||
|
||||
def get_tools(self) -> Dict[str, Dict]:
|
||||
"""Load tool configs from DB based on user context.
|
||||
|
||||
If *client_tools* have been set on this executor, they are
|
||||
automatically merged into the returned dict.
|
||||
"""
|
||||
if self.user_api_key:
|
||||
tools = self._get_tools_by_api_key(self.user_api_key)
|
||||
else:
|
||||
tools = self._get_user_tools(self.user or "local")
|
||||
if self.client_tools:
|
||||
self.merge_client_tools(tools, self.client_tools)
|
||||
return tools
|
||||
|
||||
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
agents_collection = db["agents"]
|
||||
tools_collection = db["user_tools"]
|
||||
|
||||
agent_data = agents_collection.find_one({"key": api_key})
|
||||
tool_ids = agent_data.get("tools", []) if agent_data else []
|
||||
|
||||
tools = (
|
||||
tools_collection.find(
|
||||
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
|
||||
)
|
||||
if tool_ids
|
||||
else []
|
||||
)
|
||||
tools = list(tools)
|
||||
return {str(tool["_id"]): tool for tool in tools} if tools else {}
|
||||
|
||||
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict]:
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
user_tools_collection = db["user_tools"]
|
||||
user_tools = user_tools_collection.find({"user": user, "status": True})
|
||||
user_tools = list(user_tools)
|
||||
return {str(i): tool for i, tool in enumerate(user_tools)}
|
||||
|
||||
def merge_client_tools(
|
||||
self, tools_dict: Dict, client_tools: List[Dict]
|
||||
) -> Dict:
|
||||
"""Merge client-provided tool definitions into tools_dict.
|
||||
|
||||
Client tools use the standard function-calling format::
|
||||
|
||||
[{"type": "function", "function": {"name": "get_weather",
|
||||
"description": "...", "parameters": {...}}}]
|
||||
|
||||
They are stored in *tools_dict* with ``client_side: True`` so that
|
||||
:meth:`check_pause` returns a pause signal instead of trying to
|
||||
execute them server-side.
|
||||
|
||||
Args:
|
||||
tools_dict: The mutable server tools dict (will be modified in place).
|
||||
client_tools: List of tool definitions in function-calling format.
|
||||
|
||||
Returns:
|
||||
The updated *tools_dict* (same reference, for convenience).
|
||||
"""
|
||||
for i, ct in enumerate(client_tools):
|
||||
func = ct.get("function", ct) # tolerate bare {"name":..} too
|
||||
name = func.get("name", f"clienttool{i}")
|
||||
tool_id = f"ct{i}"
|
||||
|
||||
tools_dict[tool_id] = {
|
||||
"name": name,
|
||||
"client_side": True,
|
||||
"actions": [
|
||||
{
|
||||
"name": name,
|
||||
"description": func.get("description", ""),
|
||||
"active": True,
|
||||
"parameters": func.get("parameters", {}),
|
||||
}
|
||||
],
|
||||
}
|
||||
return tools_dict
|
||||
|
||||
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
|
||||
"""Convert tool configs to LLM function schemas.
|
||||
|
||||
Action names are kept clean for the LLM:
|
||||
- Unique action names appear as-is (e.g. ``get_weather``).
|
||||
- Duplicate action names get numbered suffixes (e.g. ``search_1``,
|
||||
``search_2``).
|
||||
|
||||
A reverse mapping is stored in ``_name_to_tool`` so that tool calls
|
||||
can be routed back to the correct ``(tool_id, action_name)`` without
|
||||
brittle string splitting.
|
||||
"""
|
||||
# Pass 1: collect entries and count action name occurrences
|
||||
entries: List[Tuple[str, str, Dict, bool]] = [] # (tool_id, action_name, action, is_client)
|
||||
name_counts: Counter = Counter()
|
||||
|
||||
for tool_id, tool in tools_dict.items():
|
||||
is_api = tool["name"] == "api_tool"
|
||||
is_client = tool.get("client_side", False)
|
||||
|
||||
if is_api and "actions" not in tool.get("config", {}):
|
||||
continue
|
||||
if not is_api and "actions" not in tool:
|
||||
continue
|
||||
|
||||
actions = (
|
||||
tool["config"]["actions"].values()
|
||||
if is_api
|
||||
else tool["actions"]
|
||||
)
|
||||
|
||||
for action in actions:
|
||||
if not action.get("active", True):
|
||||
continue
|
||||
entries.append((tool_id, action["name"], action, is_client))
|
||||
name_counts[action["name"]] += 1
|
||||
|
||||
# Pass 2: assign LLM-visible names and build mappings
|
||||
self._name_to_tool = {}
|
||||
self._tool_to_name = {}
|
||||
collision_counters: Dict[str, int] = {}
|
||||
all_llm_names: set = set()
|
||||
|
||||
result = []
|
||||
for tool_id, action_name, action, is_client in entries:
|
||||
if name_counts[action_name] == 1:
|
||||
llm_name = action_name
|
||||
else:
|
||||
counter = collision_counters.get(action_name, 1)
|
||||
candidate = f"{action_name}_{counter}"
|
||||
# Skip if candidate collides with a unique action name
|
||||
while candidate in all_llm_names or (
|
||||
candidate in name_counts and name_counts[candidate] == 1
|
||||
):
|
||||
counter += 1
|
||||
candidate = f"{action_name}_{counter}"
|
||||
collision_counters[action_name] = counter + 1
|
||||
llm_name = candidate
|
||||
|
||||
all_llm_names.add(llm_name)
|
||||
self._name_to_tool[llm_name] = (tool_id, action_name)
|
||||
self._tool_to_name[(tool_id, action_name)] = llm_name
|
||||
|
||||
if is_client:
|
||||
params = action.get("parameters", {})
|
||||
else:
|
||||
params = self._build_tool_parameters(action)
|
||||
|
||||
result.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": llm_name,
|
||||
"description": action.get("description", ""),
|
||||
"parameters": params,
|
||||
},
|
||||
})
|
||||
return result
|
||||
|
||||
def _build_tool_parameters(self, action: Dict) -> Dict:
|
||||
params = {"type": "object", "properties": {}, "required": []}
|
||||
for param_type in ["query_params", "headers", "body", "parameters"]:
|
||||
if param_type in action and action[param_type].get("properties"):
|
||||
for k, v in action[param_type]["properties"].items():
|
||||
if v.get("filled_by_llm", True):
|
||||
params["properties"][k] = {
|
||||
key: value
|
||||
for key, value in v.items()
|
||||
if key not in ("filled_by_llm", "value", "required")
|
||||
}
|
||||
if v.get("required", False):
|
||||
params["required"].append(k)
|
||||
return params
|
||||
|
||||
def check_pause(
|
||||
self, tools_dict: Dict, call, llm_class_name: str
|
||||
) -> Optional[Dict]:
|
||||
"""Check if a tool call requires pausing for approval or client execution.
|
||||
|
||||
Returns a dict describing the pending action if pause is needed, None otherwise.
|
||||
"""
|
||||
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||
llm_name = getattr(call, "name", "")
|
||||
|
||||
if tool_id is None or action_name is None or tool_id not in tools_dict:
|
||||
return None # Will be handled as error by execute()
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
|
||||
# Client-side tools
|
||||
if tool_data.get("client_side"):
|
||||
return {
|
||||
"call_id": call_id,
|
||||
"name": llm_name,
|
||||
"tool_name": tool_data.get("name", "unknown"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"llm_name": llm_name,
|
||||
"arguments": call_args if isinstance(call_args, dict) else {},
|
||||
"pause_type": "requires_client_execution",
|
||||
"thought_signature": getattr(call, "thought_signature", None),
|
||||
}
|
||||
|
||||
# Approval required
|
||||
if tool_data["name"] == "api_tool":
|
||||
action_data = tool_data.get("config", {}).get("actions", {}).get(
|
||||
action_name, {}
|
||||
)
|
||||
else:
|
||||
action_data = next(
|
||||
(a for a in tool_data.get("actions", []) if a["name"] == action_name),
|
||||
{},
|
||||
)
|
||||
|
||||
if action_data.get("require_approval"):
|
||||
return {
|
||||
"call_id": call_id,
|
||||
"name": llm_name,
|
||||
"tool_name": tool_data.get("name", "unknown"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"llm_name": llm_name,
|
||||
"arguments": call_args if isinstance(call_args, dict) else {},
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": getattr(call, "thought_signature", None),
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
def execute(self, tools_dict: Dict, call, llm_class_name: str):
|
||||
"""Execute a tool call. Yields status events, returns (result, call_id)."""
|
||||
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
llm_name = getattr(call, "name", "unknown")
|
||||
|
||||
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||
|
||||
if tool_id is None or action_name is None:
|
||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
|
||||
logger.error(error_message)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
"call_id": call_id,
|
||||
"action_name": llm_name,
|
||||
"arguments": call_args or {},
|
||||
"result": f"Failed to parse tool call. Invalid tool name format: {llm_name}",
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return "Failed to parse tool call.", call_id
|
||||
|
||||
if tool_id not in tools_dict:
|
||||
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
|
||||
logger.error(error_message)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
"call_id": call_id,
|
||||
"action_name": llm_name,
|
||||
"arguments": call_args,
|
||||
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return f"Tool with ID {tool_id} not found.", call_id
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": tools_dict[tool_id]["name"],
|
||||
"call_id": call_id,
|
||||
"action_name": llm_name,
|
||||
"arguments": call_args,
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
action_data = (
|
||||
tool_data["config"]["actions"][action_name]
|
||||
if tool_data["name"] == "api_tool"
|
||||
else next(
|
||||
action
|
||||
for action in tool_data["actions"]
|
||||
if action["name"] == action_name
|
||||
)
|
||||
)
|
||||
|
||||
query_params, headers, body, parameters = {}, {}, {}, {}
|
||||
param_types = {
|
||||
"query_params": query_params,
|
||||
"headers": headers,
|
||||
"body": body,
|
||||
"parameters": parameters,
|
||||
}
|
||||
|
||||
for param_type, target_dict in param_types.items():
|
||||
if param_type in action_data and action_data[param_type].get("properties"):
|
||||
for param, details in action_data[param_type]["properties"].items():
|
||||
if (
|
||||
param not in call_args
|
||||
and "value" in details
|
||||
and details["value"]
|
||||
):
|
||||
target_dict[param] = details["value"]
|
||||
for param, value in call_args.items():
|
||||
for param_type, target_dict in param_types.items():
|
||||
if param_type in action_data and param in action_data[param_type].get(
|
||||
"properties", {}
|
||||
):
|
||||
target_dict[param] = value
|
||||
|
||||
# Load tool (with caching)
|
||||
tool = self._get_or_load_tool(
|
||||
tool_data, tool_id, action_name,
|
||||
headers=headers, query_params=query_params,
|
||||
)
|
||||
|
||||
resolved_arguments = (
|
||||
{"query_params": query_params, "headers": headers, "body": body}
|
||||
if tool_data["name"] == "api_tool"
|
||||
else parameters
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
logger.debug(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
)
|
||||
result = tool.execute_action(action_name, **body)
|
||||
else:
|
||||
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
|
||||
result = tool.execute_action(action_name, **parameters)
|
||||
|
||||
get_artifact_id = (
|
||||
getattr(tool, "get_artifact_id", None)
|
||||
if tool_data["name"] != "api_tool"
|
||||
else None
|
||||
)
|
||||
|
||||
artifact_id = None
|
||||
if callable(get_artifact_id):
|
||||
try:
|
||||
artifact_id = get_artifact_id(action_name, **parameters)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to extract artifact_id from tool %s for action %s",
|
||||
tool_data["name"],
|
||||
action_name,
|
||||
)
|
||||
|
||||
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
|
||||
if artifact_id:
|
||||
tool_call_data["artifact_id"] = artifact_id
|
||||
result_full = str(result)
|
||||
tool_call_data["resolved_arguments"] = resolved_arguments
|
||||
tool_call_data["result_full"] = result_full
|
||||
tool_call_data["result"] = (
|
||||
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
|
||||
)
|
||||
|
||||
stream_tool_call_data = {
|
||||
key: value
|
||||
for key, value in tool_call_data.items()
|
||||
if key not in {"result_full", "resolved_arguments"}
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
|
||||
return result, call_id
|
||||
|
||||
def _get_or_load_tool(
|
||||
self, tool_data: Dict, tool_id: str, action_name: str,
|
||||
headers: Optional[Dict] = None, query_params: Optional[Dict] = None,
|
||||
):
|
||||
"""Load a tool, using cache when possible."""
|
||||
cache_key = f"{tool_data['name']}:{tool_id}:{self.user or ''}"
|
||||
if cache_key in self._loaded_tools:
|
||||
return self._loaded_tools[cache_key]
|
||||
|
||||
tm = ToolManager(config={})
|
||||
|
||||
if tool_data["name"] == "api_tool":
|
||||
action_config = tool_data["config"]["actions"][action_name]
|
||||
tool_config = {
|
||||
"url": action_config["url"],
|
||||
"method": action_config["method"],
|
||||
"headers": headers or {},
|
||||
"query_params": query_params or {},
|
||||
}
|
||||
if "body_content_type" in action_config:
|
||||
tool_config["body_content_type"] = action_config.get(
|
||||
"body_content_type", "application/json"
|
||||
)
|
||||
tool_config["body_encoding_rules"] = action_config.get(
|
||||
"body_encoding_rules", {}
|
||||
)
|
||||
else:
|
||||
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
|
||||
if tool_config.get("encrypted_credentials") and self.user:
|
||||
decrypted = decrypt_credentials(
|
||||
tool_config["encrypted_credentials"], self.user
|
||||
)
|
||||
tool_config.update(decrypted)
|
||||
tool_config["auth_credentials"] = decrypted
|
||||
tool_config.pop("encrypted_credentials", None)
|
||||
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
|
||||
if self.conversation_id:
|
||||
tool_config["conversation_id"] = self.conversation_id
|
||||
if tool_data["name"] == "mcp_tool":
|
||||
tool_config["query_mode"] = True
|
||||
|
||||
tool = tm.load_tool(
|
||||
tool_data["name"],
|
||||
tool_config=tool_config,
|
||||
user_id=self.user,
|
||||
)
|
||||
|
||||
# Don't cache api_tool since config varies by action
|
||||
if tool_data["name"] != "api_tool":
|
||||
self._loaded_tools[cache_key] = tool
|
||||
|
||||
return tool
|
||||
|
||||
def get_truncated_tool_calls(self) -> List[Dict]:
|
||||
return [
|
||||
{
|
||||
"tool_name": tool_call.get("tool_name"),
|
||||
"call_id": tool_call.get("call_id"),
|
||||
"action_name": tool_call.get("action_name"),
|
||||
"arguments": tool_call.get("arguments"),
|
||||
"artifact_id": tool_call.get("artifact_id"),
|
||||
"result": (
|
||||
f"{str(tool_call['result'])[:50]}..."
|
||||
if len(str(tool_call["result"])) > 50
|
||||
else tool_call["result"]
|
||||
),
|
||||
"status": "completed",
|
||||
}
|
||||
for tool_call in self.tool_calls
|
||||
]
|
||||
323
application/agents/tools/api_body_serializer.py
Normal file
323
application/agents/tools/api_body_serializer.py
Normal file
@@ -0,0 +1,323 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContentType(str, Enum):
|
||||
"""Supported content types for request bodies."""
|
||||
|
||||
JSON = "application/json"
|
||||
FORM_URLENCODED = "application/x-www-form-urlencoded"
|
||||
MULTIPART_FORM_DATA = "multipart/form-data"
|
||||
TEXT_PLAIN = "text/plain"
|
||||
XML = "application/xml"
|
||||
OCTET_STREAM = "application/octet-stream"
|
||||
|
||||
|
||||
class RequestBodySerializer:
|
||||
"""Serializes request bodies according to content-type and OpenAPI 3.1 spec."""
|
||||
|
||||
@staticmethod
|
||||
def serialize(
|
||||
body_data: Dict[str, Any],
|
||||
content_type: str = ContentType.JSON,
|
||||
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
) -> tuple[Union[str, bytes], Dict[str, str]]:
|
||||
"""
|
||||
Serialize body data to appropriate format.
|
||||
|
||||
Args:
|
||||
body_data: Dictionary of body parameters
|
||||
content_type: Content-Type header value
|
||||
encoding_rules: OpenAPI Encoding Object rules per field
|
||||
|
||||
Returns:
|
||||
Tuple of (serialized_body, updated_headers_dict)
|
||||
|
||||
Raises:
|
||||
ValueError: If serialization fails
|
||||
"""
|
||||
if not body_data:
|
||||
return None, {}
|
||||
|
||||
try:
|
||||
content_type_lower = content_type.lower().split(";")[0].strip()
|
||||
|
||||
if content_type_lower == ContentType.JSON:
|
||||
return RequestBodySerializer._serialize_json(body_data)
|
||||
|
||||
elif content_type_lower == ContentType.FORM_URLENCODED:
|
||||
return RequestBodySerializer._serialize_form_urlencoded(
|
||||
body_data, encoding_rules
|
||||
)
|
||||
|
||||
elif content_type_lower == ContentType.MULTIPART_FORM_DATA:
|
||||
return RequestBodySerializer._serialize_multipart_form_data(
|
||||
body_data, encoding_rules
|
||||
)
|
||||
|
||||
elif content_type_lower == ContentType.TEXT_PLAIN:
|
||||
return RequestBodySerializer._serialize_text_plain(body_data)
|
||||
|
||||
elif content_type_lower == ContentType.XML:
|
||||
return RequestBodySerializer._serialize_xml(body_data)
|
||||
|
||||
elif content_type_lower == ContentType.OCTET_STREAM:
|
||||
return RequestBodySerializer._serialize_octet_stream(body_data)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unknown content type: {content_type}, treating as JSON"
|
||||
)
|
||||
return RequestBodySerializer._serialize_json(body_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error serializing body: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Failed to serialize request body: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _serialize_json(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
|
||||
"""Serialize body as JSON per OpenAPI spec."""
|
||||
try:
|
||||
serialized = json.dumps(
|
||||
body_data, separators=(",", ":"), ensure_ascii=False
|
||||
)
|
||||
headers = {"Content-Type": ContentType.JSON.value}
|
||||
return serialized, headers
|
||||
except (TypeError, ValueError) as e:
|
||||
raise ValueError(f"Failed to serialize JSON body: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _serialize_form_urlencoded(
|
||||
body_data: Dict[str, Any],
|
||||
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
) -> tuple[str, Dict[str, str]]:
|
||||
"""Serialize body as application/x-www-form-urlencoded per RFC1866/RFC3986."""
|
||||
encoding_rules = encoding_rules or {}
|
||||
params = []
|
||||
|
||||
for key, value in body_data.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
rule = encoding_rules.get(key, {})
|
||||
style = rule.get("style", "form")
|
||||
explode = rule.get("explode", style == "form")
|
||||
content_type = rule.get("contentType", "text/plain")
|
||||
|
||||
serialized_value = RequestBodySerializer._serialize_form_value(
|
||||
value, style, explode, content_type, key
|
||||
)
|
||||
|
||||
if isinstance(serialized_value, list):
|
||||
for sv in serialized_value:
|
||||
params.append((key, sv))
|
||||
else:
|
||||
params.append((key, serialized_value))
|
||||
|
||||
# Use standard urlencode (replaces space with +)
|
||||
serialized = urlencode(params, safe="")
|
||||
headers = {"Content-Type": ContentType.FORM_URLENCODED.value}
|
||||
return serialized, headers
|
||||
|
||||
@staticmethod
|
||||
def _serialize_form_value(
|
||||
value: Any, style: str, explode: bool, content_type: str, key: str
|
||||
) -> Union[str, list]:
|
||||
"""Serialize individual form value with encoding rules."""
|
||||
if isinstance(value, dict):
|
||||
if content_type == "application/json":
|
||||
return json.dumps(value, separators=(",", ":"))
|
||||
elif content_type == "application/xml":
|
||||
return RequestBodySerializer._dict_to_xml(value)
|
||||
else:
|
||||
if style == "deepObject" and explode:
|
||||
return [
|
||||
f"{RequestBodySerializer._percent_encode(str(v))}"
|
||||
for v in value.values()
|
||||
]
|
||||
elif explode:
|
||||
return [
|
||||
f"{RequestBodySerializer._percent_encode(str(v))}"
|
||||
for v in value.values()
|
||||
]
|
||||
else:
|
||||
pairs = [f"{k},{v}" for k, v in value.items()]
|
||||
return RequestBodySerializer._percent_encode(",".join(pairs))
|
||||
|
||||
elif isinstance(value, (list, tuple)):
|
||||
if explode:
|
||||
return [
|
||||
RequestBodySerializer._percent_encode(str(item)) for item in value
|
||||
]
|
||||
else:
|
||||
return RequestBodySerializer._percent_encode(
|
||||
",".join(str(v) for v in value)
|
||||
)
|
||||
|
||||
else:
|
||||
return RequestBodySerializer._percent_encode(str(value))
|
||||
|
||||
@staticmethod
|
||||
def _serialize_multipart_form_data(
|
||||
body_data: Dict[str, Any],
|
||||
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
) -> tuple[bytes, Dict[str, str]]:
|
||||
"""
|
||||
Serialize body as multipart/form-data per RFC7578.
|
||||
|
||||
Supports file uploads and encoding rules.
|
||||
"""
|
||||
import secrets
|
||||
|
||||
encoding_rules = encoding_rules or {}
|
||||
boundary = f"----DocsGPT{secrets.token_hex(16)}"
|
||||
parts = []
|
||||
|
||||
for key, value in body_data.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
rule = encoding_rules.get(key, {})
|
||||
content_type = rule.get("contentType", "text/plain")
|
||||
headers_rule = rule.get("headers", {})
|
||||
|
||||
part = RequestBodySerializer._create_multipart_part(
|
||||
key, value, content_type, headers_rule
|
||||
)
|
||||
parts.append(part)
|
||||
|
||||
body_bytes = f"--{boundary}\r\n".encode("utf-8")
|
||||
body_bytes += f"--{boundary}\r\n".join(parts).encode("utf-8")
|
||||
body_bytes += f"\r\n--{boundary}--\r\n".encode("utf-8")
|
||||
|
||||
headers = {
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary}",
|
||||
}
|
||||
return body_bytes, headers
|
||||
|
||||
@staticmethod
|
||||
def _create_multipart_part(
|
||||
name: str, value: Any, content_type: str, headers_rule: Dict[str, Any]
|
||||
) -> str:
|
||||
"""Create a single multipart/form-data part."""
|
||||
headers = [
|
||||
f'Content-Disposition: form-data; name="{RequestBodySerializer._percent_encode(name)}"'
|
||||
]
|
||||
|
||||
if isinstance(value, bytes):
|
||||
if content_type == "application/octet-stream":
|
||||
value_encoded = base64.b64encode(value).decode("utf-8")
|
||||
else:
|
||||
value_encoded = value.decode("utf-8", errors="replace")
|
||||
headers.append(f"Content-Type: {content_type}")
|
||||
headers.append("Content-Transfer-Encoding: base64")
|
||||
elif isinstance(value, dict):
|
||||
if content_type == "application/json":
|
||||
value_encoded = json.dumps(value, separators=(",", ":"))
|
||||
elif content_type == "application/xml":
|
||||
value_encoded = RequestBodySerializer._dict_to_xml(value)
|
||||
else:
|
||||
value_encoded = str(value)
|
||||
headers.append(f"Content-Type: {content_type}")
|
||||
elif isinstance(value, str) and content_type != "text/plain":
|
||||
try:
|
||||
if content_type == "application/json":
|
||||
json.loads(value)
|
||||
value_encoded = value
|
||||
elif content_type == "application/xml":
|
||||
value_encoded = value
|
||||
else:
|
||||
value_encoded = str(value)
|
||||
except json.JSONDecodeError:
|
||||
value_encoded = str(value)
|
||||
headers.append(f"Content-Type: {content_type}")
|
||||
else:
|
||||
value_encoded = str(value)
|
||||
if content_type != "text/plain":
|
||||
headers.append(f"Content-Type: {content_type}")
|
||||
|
||||
part = "\r\n".join(headers) + "\r\n\r\n" + value_encoded + "\r\n"
|
||||
return part
|
||||
|
||||
@staticmethod
|
||||
def _serialize_text_plain(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
|
||||
"""Serialize body as plain text."""
|
||||
if len(body_data) == 1:
|
||||
value = list(body_data.values())[0]
|
||||
return str(value), {"Content-Type": ContentType.TEXT_PLAIN.value}
|
||||
else:
|
||||
text = "\n".join(f"{k}: {v}" for k, v in body_data.items())
|
||||
return text, {"Content-Type": ContentType.TEXT_PLAIN.value}
|
||||
|
||||
@staticmethod
|
||||
def _serialize_xml(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
|
||||
"""Serialize body as XML."""
|
||||
xml_str = RequestBodySerializer._dict_to_xml(body_data)
|
||||
return xml_str, {"Content-Type": ContentType.XML.value}
|
||||
|
||||
@staticmethod
|
||||
def _serialize_octet_stream(
|
||||
body_data: Dict[str, Any],
|
||||
) -> tuple[bytes, Dict[str, str]]:
|
||||
"""Serialize body as binary octet stream."""
|
||||
if isinstance(body_data, bytes):
|
||||
return body_data, {"Content-Type": ContentType.OCTET_STREAM.value}
|
||||
elif isinstance(body_data, str):
|
||||
return body_data.encode("utf-8"), {
|
||||
"Content-Type": ContentType.OCTET_STREAM.value
|
||||
}
|
||||
else:
|
||||
serialized = json.dumps(body_data)
|
||||
return serialized.encode("utf-8"), {
|
||||
"Content-Type": ContentType.OCTET_STREAM.value
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _percent_encode(value: str, safe_chars: str = "") -> str:
|
||||
"""
|
||||
Percent-encode per RFC3986.
|
||||
|
||||
Args:
|
||||
value: String to encode
|
||||
safe_chars: Additional characters to not encode
|
||||
"""
|
||||
return quote(value, safe=safe_chars)
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_xml(data: Dict[str, Any], root_name: str = "root") -> str:
|
||||
"""
|
||||
Convert dict to simple XML format.
|
||||
"""
|
||||
|
||||
def build_xml(obj: Any, name: str) -> str:
|
||||
if isinstance(obj, dict):
|
||||
inner = "".join(build_xml(v, k) for k, v in obj.items())
|
||||
return f"<{name}>{inner}</{name}>"
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
items = "".join(
|
||||
build_xml(item, f"{name[:-1] if name.endswith('s') else name}")
|
||||
for item in obj
|
||||
)
|
||||
return items
|
||||
else:
|
||||
return f"<{name}>{RequestBodySerializer._escape_xml(str(obj))}</{name}>"
|
||||
|
||||
root = build_xml(data, root_name)
|
||||
return f'<?xml version="1.0" encoding="UTF-8"?>{root}'
|
||||
|
||||
@staticmethod
|
||||
def _escape_xml(value: str) -> str:
|
||||
"""Escape XML special characters."""
|
||||
return (
|
||||
value.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
.replace('"', """)
|
||||
.replace("'", "'")
|
||||
)
|
||||
@@ -1,72 +1,280 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
|
||||
from application.agents.tools.api_body_serializer import (
|
||||
ContentType,
|
||||
RequestBodySerializer,
|
||||
)
|
||||
from application.agents.tools.base import Tool
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_TIMEOUT = 90 # seconds
|
||||
|
||||
|
||||
class APITool(Tool):
|
||||
"""
|
||||
API Tool
|
||||
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs
|
||||
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.url = config.get("url", "")
|
||||
self.method = config.get("method", "GET")
|
||||
self.headers = config.get("headers", {"Content-Type": "application/json"})
|
||||
self.headers = config.get("headers", {})
|
||||
self.query_params = config.get("query_params", {})
|
||||
self.body_content_type = config.get("body_content_type", ContentType.JSON)
|
||||
self.body_encoding_rules = config.get("body_encoding_rules", {})
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
"""Execute an API action with the given arguments."""
|
||||
return self._make_api_call(
|
||||
self.url, self.method, self.headers, self.query_params, kwargs
|
||||
self.url,
|
||||
self.method,
|
||||
self.headers,
|
||||
self.query_params,
|
||||
kwargs,
|
||||
self.body_content_type,
|
||||
self.body_encoding_rules,
|
||||
)
|
||||
|
||||
def _make_api_call(self, url, method, headers, query_params, body):
|
||||
if query_params:
|
||||
url = f"{url}?{requests.compat.urlencode(query_params)}"
|
||||
if isinstance(body, dict):
|
||||
body = json.dumps(body)
|
||||
def _make_api_call(
|
||||
self,
|
||||
url: str,
|
||||
method: str,
|
||||
headers: Dict[str, str],
|
||||
query_params: Dict[str, Any],
|
||||
body: Dict[str, Any],
|
||||
content_type: str = ContentType.JSON,
|
||||
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Make an API call with proper body serialization and error handling.
|
||||
|
||||
Args:
|
||||
url: API endpoint URL
|
||||
method: HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS)
|
||||
headers: Request headers dict
|
||||
query_params: URL query parameters
|
||||
body: Request body as dict
|
||||
content_type: Content-Type for serialization
|
||||
encoding_rules: OpenAPI encoding rules
|
||||
|
||||
Returns:
|
||||
Dict with status_code, data, and message
|
||||
"""
|
||||
request_url = url
|
||||
request_headers = headers.copy() if headers else {}
|
||||
response = None
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
try:
|
||||
print(f"Making API call: {method} {url} with body: {body}")
|
||||
if body == "{}":
|
||||
body = None
|
||||
response = requests.request(method, url, headers=headers, data=body)
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get(
|
||||
"Content-Type", "application/json"
|
||||
).lower()
|
||||
if "application/json" in content_type:
|
||||
validate_url(request_url)
|
||||
except SSRFError as e:
|
||||
logger.error(f"URL validation failed: {e}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"URL validation error: {e}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
try:
|
||||
path_params_used = set()
|
||||
if query_params:
|
||||
for match in re.finditer(r"\{([^}]+)\}", request_url):
|
||||
param_name = match.group(1)
|
||||
if param_name in query_params:
|
||||
request_url = request_url.replace(
|
||||
f"{{{param_name}}}", str(query_params[param_name])
|
||||
)
|
||||
path_params_used.add(param_name)
|
||||
remaining_params = {
|
||||
k: v for k, v in query_params.items() if k not in path_params_used
|
||||
}
|
||||
if remaining_params:
|
||||
query_string = urlencode(remaining_params)
|
||||
separator = "&" if "?" in request_url else "?"
|
||||
request_url = f"{request_url}{separator}{query_string}"
|
||||
|
||||
# Re-validate URL after parameter substitution to prevent SSRF via path params
|
||||
try:
|
||||
validate_url(request_url)
|
||||
except SSRFError as e:
|
||||
logger.error(f"URL validation failed after parameter substitution: {e}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"URL validation error: {e}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
# Serialize body based on content type
|
||||
|
||||
if body and body != {}:
|
||||
try:
|
||||
data = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error decoding JSON: {e}. Raw response: {response.text}")
|
||||
serialized_body, body_headers = RequestBodySerializer.serialize(
|
||||
body, content_type, encoding_rules
|
||||
)
|
||||
request_headers.update(body_headers)
|
||||
except ValueError as e:
|
||||
logger.error(f"Body serialization failed: {str(e)}")
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"API call returned invalid JSON. Error: {e}",
|
||||
"data": response.text,
|
||||
"status_code": None,
|
||||
"message": f"Body serialization error: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
elif "text/" in content_type or "application/xml" in content_type:
|
||||
data = response.text
|
||||
elif not response.content:
|
||||
data = None
|
||||
else:
|
||||
print(f"Unsupported content type: {content_type}")
|
||||
data = response.content
|
||||
serialized_body = None
|
||||
if "Content-Type" not in request_headers and method not in [
|
||||
"GET",
|
||||
"HEAD",
|
||||
"DELETE",
|
||||
]:
|
||||
request_headers["Content-Type"] = ContentType.JSON
|
||||
logger.debug(
|
||||
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
|
||||
)
|
||||
|
||||
if method.upper() == "GET":
|
||||
response = requests.get(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "POST":
|
||||
response = requests.post(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "PUT":
|
||||
response = requests.put(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "DELETE":
|
||||
response = requests.delete(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "PATCH":
|
||||
response = requests.patch(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "HEAD":
|
||||
response = requests.head(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "OPTIONS":
|
||||
response = requests.options(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Unsupported HTTP method: {method}",
|
||||
"data": None,
|
||||
}
|
||||
response.raise_for_status()
|
||||
|
||||
data = self._parse_response(response)
|
||||
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"data": data,
|
||||
"message": "API call successful.",
|
||||
}
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error(f"Request timeout for {request_url}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Request timeout ({DEFAULT_TIMEOUT}s exceeded)",
|
||||
"data": None,
|
||||
}
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
logger.error(f"Connection error: {str(e)}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Connection error: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logger.error(f"HTTP error {response.status_code}: {str(e)}")
|
||||
try:
|
||||
error_data = response.json()
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
error_data = response.text
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"HTTP Error {response.status_code}",
|
||||
"data": error_data,
|
||||
}
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Request failed: {str(e)}")
|
||||
return {
|
||||
"status_code": response.status_code if response else None,
|
||||
"message": f"API call failed: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in API call: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Unexpected error: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
def _parse_response(self, response: requests.Response) -> Any:
|
||||
"""
|
||||
Parse response based on Content-Type header.
|
||||
|
||||
Supports: JSON, XML, plain text, binary data.
|
||||
"""
|
||||
content_type = response.headers.get("Content-Type", "").lower()
|
||||
|
||||
if not response.content:
|
||||
return None
|
||||
# JSON response
|
||||
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
return response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse JSON response: {str(e)}")
|
||||
return response.text
|
||||
# XML response
|
||||
|
||||
elif "application/xml" in content_type or "text/xml" in content_type:
|
||||
return response.text
|
||||
# Plain text response
|
||||
|
||||
elif "text/plain" in content_type or "text/html" in content_type:
|
||||
return response.text
|
||||
# Binary/unknown response
|
||||
|
||||
else:
|
||||
# Try to decode as text first, fall back to base64
|
||||
|
||||
try:
|
||||
return response.text
|
||||
except (UnicodeDecodeError, AttributeError):
|
||||
import base64
|
||||
|
||||
return base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
def get_actions_metadata(self):
|
||||
"""Return metadata for available actions (none for API Tool - actions are user-defined)."""
|
||||
return []
|
||||
|
||||
def get_config_requirements(self):
|
||||
"""Return configuration requirements for the tool."""
|
||||
return {}
|
||||
|
||||
@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
internal: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def execute_action(self, action_name: str, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BraveSearchTool(Tool):
|
||||
"""
|
||||
@@ -25,27 +30,35 @@ class BraveSearchTool(Tool):
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
|
||||
def _web_search(self, query, country="ALL", search_lang="en", count=10,
|
||||
offset=0, safesearch="off", freshness=None,
|
||||
result_filter=None, extra_snippets=False, summary=False):
|
||||
def _web_search(
|
||||
self,
|
||||
query,
|
||||
country="ALL",
|
||||
search_lang="en",
|
||||
count=10,
|
||||
offset=0,
|
||||
safesearch="off",
|
||||
freshness=None,
|
||||
result_filter=None,
|
||||
extra_snippets=False,
|
||||
summary=False,
|
||||
):
|
||||
"""
|
||||
Performs a web search using the Brave Search API.
|
||||
"""
|
||||
print(f"Performing Brave web search for: {query}")
|
||||
|
||||
logger.debug("Performing Brave web search for: %s", query)
|
||||
|
||||
url = f"{self.base_url}/web/search"
|
||||
|
||||
# Build query parameters
|
||||
|
||||
params = {
|
||||
"q": query,
|
||||
"country": country,
|
||||
"search_lang": search_lang,
|
||||
"count": min(count, 20),
|
||||
"offset": min(offset, 9),
|
||||
"safesearch": safesearch
|
||||
"safesearch": safesearch,
|
||||
}
|
||||
|
||||
# Add optional parameters only if they have values
|
||||
|
||||
if freshness:
|
||||
params["freshness"] = freshness
|
||||
if result_filter:
|
||||
@@ -54,68 +67,69 @@ class BraveSearchTool(Tool):
|
||||
params["extra_snippets"] = 1
|
||||
if summary:
|
||||
params["summary"] = 1
|
||||
|
||||
# Set up headers
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Accept-Encoding": "gzip",
|
||||
"X-Subscription-Token": self.token
|
||||
"X-Subscription-Token": self.token,
|
||||
}
|
||||
|
||||
# Make the request
|
||||
response = requests.get(url, params=params, headers=headers)
|
||||
|
||||
|
||||
response = requests.get(url, params=params, headers=headers, timeout=100)
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"results": response.json(),
|
||||
"message": "Search completed successfully."
|
||||
"message": "Search completed successfully.",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"Search failed with status code: {response.status_code}."
|
||||
"message": f"Search failed with status code: {response.status_code}.",
|
||||
}
|
||||
|
||||
def _image_search(self, query, country="ALL", search_lang="en", count=5,
|
||||
safesearch="off", spellcheck=False):
|
||||
|
||||
def _image_search(
|
||||
self,
|
||||
query,
|
||||
country="ALL",
|
||||
search_lang="en",
|
||||
count=5,
|
||||
safesearch="off",
|
||||
spellcheck=False,
|
||||
):
|
||||
"""
|
||||
Performs an image search using the Brave Search API.
|
||||
"""
|
||||
print(f"Performing Brave image search for: {query}")
|
||||
|
||||
logger.debug("Performing Brave image search for: %s", query)
|
||||
|
||||
url = f"{self.base_url}/images/search"
|
||||
|
||||
# Build query parameters
|
||||
|
||||
params = {
|
||||
"q": query,
|
||||
"country": country,
|
||||
"search_lang": search_lang,
|
||||
"count": min(count, 100), # API max is 100
|
||||
"safesearch": safesearch,
|
||||
"spellcheck": 1 if spellcheck else 0
|
||||
"spellcheck": 1 if spellcheck else 0,
|
||||
}
|
||||
|
||||
# Set up headers
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Accept-Encoding": "gzip",
|
||||
"X-Subscription-Token": self.token
|
||||
"X-Subscription-Token": self.token,
|
||||
}
|
||||
|
||||
# Make the request
|
||||
response = requests.get(url, params=params, headers=headers)
|
||||
|
||||
|
||||
response = requests.get(url, params=params, headers=headers, timeout=100)
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"results": response.json(),
|
||||
"message": "Image search completed successfully."
|
||||
"message": "Image search completed successfully.",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"Image search failed with status code: {response.status_code}."
|
||||
"message": f"Image search failed with status code: {response.status_code}.",
|
||||
}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
@@ -130,42 +144,14 @@ class BraveSearchTool(Tool):
|
||||
"type": "string",
|
||||
"description": "The search query (max 400 characters, 50 words)",
|
||||
},
|
||||
# "country": {
|
||||
# "type": "string",
|
||||
# "description": "The 2-character country code (default: US)",
|
||||
# },
|
||||
"search_lang": {
|
||||
"type": "string",
|
||||
"description": "The search language preference (default: en)",
|
||||
},
|
||||
# "count": {
|
||||
# "type": "integer",
|
||||
# "description": "Number of results to return (max 20, default: 10)",
|
||||
# },
|
||||
# "offset": {
|
||||
# "type": "integer",
|
||||
# "description": "Pagination offset (max 9, default: 0)",
|
||||
# },
|
||||
# "safesearch": {
|
||||
# "type": "string",
|
||||
# "description": "Filter level for adult content (off, moderate, strict)",
|
||||
# },
|
||||
"freshness": {
|
||||
"type": "string",
|
||||
"description": "Time filter for results (pd: last 24h, pw: last week, pm: last month, py: last year)",
|
||||
},
|
||||
# "result_filter": {
|
||||
# "type": "string",
|
||||
# "description": "Comma-delimited list of result types to include",
|
||||
# },
|
||||
# "extra_snippets": {
|
||||
# "type": "boolean",
|
||||
# "description": "Get additional excerpts from result pages",
|
||||
# },
|
||||
# "summary": {
|
||||
# "type": "boolean",
|
||||
# "description": "Enable summary generation in search results",
|
||||
# }
|
||||
},
|
||||
"required": ["query"],
|
||||
"additionalProperties": False,
|
||||
@@ -181,37 +167,25 @@ class BraveSearchTool(Tool):
|
||||
"type": "string",
|
||||
"description": "The search query (max 400 characters, 50 words)",
|
||||
},
|
||||
# "country": {
|
||||
# "type": "string",
|
||||
# "description": "The 2-character country code (default: US)",
|
||||
# },
|
||||
# "search_lang": {
|
||||
# "type": "string",
|
||||
# "description": "The search language preference (default: en)",
|
||||
# },
|
||||
"count": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return (max 100, default: 5)",
|
||||
},
|
||||
# "safesearch": {
|
||||
# "type": "string",
|
||||
# "description": "Filter level for adult content (off, strict). Default: strict",
|
||||
# },
|
||||
# "spellcheck": {
|
||||
# "type": "boolean",
|
||||
# "description": "Whether to spellcheck provided query (default: true)",
|
||||
# }
|
||||
},
|
||||
"required": ["query"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {
|
||||
"token": {
|
||||
"type": "string",
|
||||
"description": "Brave Search API key for authentication"
|
||||
"type": "string",
|
||||
"label": "API Key",
|
||||
"description": "Brave Search API key for authentication",
|
||||
"required": True,
|
||||
"secret": True,
|
||||
"order": 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ class CryptoPriceTool(Tool):
|
||||
returns price in USD.
|
||||
"""
|
||||
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
|
||||
response = requests.get(url)
|
||||
response = requests.get(url, timeout=100)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if currency.upper() in data:
|
||||
|
||||
209
application/agents/tools/duckduckgo.py
Normal file
209
application/agents/tools/duckduckgo.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY = 2.0
|
||||
DEFAULT_TIMEOUT = 15
|
||||
|
||||
|
||||
class DuckDuckGoSearchTool(Tool):
|
||||
"""
|
||||
DuckDuckGo Search
|
||||
A tool for performing web and image searches using DuckDuckGo.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.timeout = config.get("timeout", DEFAULT_TIMEOUT)
|
||||
|
||||
def _get_ddgs_client(self):
|
||||
from ddgs import DDGS
|
||||
|
||||
return DDGS(timeout=self.timeout)
|
||||
|
||||
def _execute_with_retry(self, operation, operation_name: str) -> Dict[str, Any]:
|
||||
last_error = None
|
||||
for attempt in range(1, MAX_RETRIES + 1):
|
||||
try:
|
||||
results = operation()
|
||||
return {
|
||||
"status_code": 200,
|
||||
"results": list(results) if results else [],
|
||||
"message": f"{operation_name} completed successfully.",
|
||||
}
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
error_str = str(e).lower()
|
||||
if "ratelimit" in error_str or "429" in error_str:
|
||||
if attempt < MAX_RETRIES:
|
||||
delay = RETRY_DELAY * attempt
|
||||
logger.warning(
|
||||
f"{operation_name} rate limited, retrying in {delay}s (attempt {attempt}/{MAX_RETRIES})"
|
||||
)
|
||||
time.sleep(delay)
|
||||
continue
|
||||
logger.error(f"{operation_name} failed: {e}")
|
||||
break
|
||||
return {
|
||||
"status_code": 500,
|
||||
"results": [],
|
||||
"message": f"{operation_name} failed: {str(last_error)}",
|
||||
}
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
actions = {
|
||||
"ddg_web_search": self._web_search,
|
||||
"ddg_image_search": self._image_search,
|
||||
"ddg_news_search": self._news_search,
|
||||
}
|
||||
if action_name not in actions:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
return actions[action_name](**kwargs)
|
||||
|
||||
def _web_search(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 5,
|
||||
region: str = "wt-wt",
|
||||
safesearch: str = "moderate",
|
||||
timelimit: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
logger.info(f"DuckDuckGo web search: {query}")
|
||||
|
||||
def operation():
|
||||
client = self._get_ddgs_client()
|
||||
return client.text(
|
||||
query,
|
||||
region=region,
|
||||
safesearch=safesearch,
|
||||
timelimit=timelimit,
|
||||
max_results=min(max_results, 20),
|
||||
)
|
||||
|
||||
return self._execute_with_retry(operation, "Web search")
|
||||
|
||||
def _image_search(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 5,
|
||||
region: str = "wt-wt",
|
||||
safesearch: str = "moderate",
|
||||
timelimit: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
logger.info(f"DuckDuckGo image search: {query}")
|
||||
|
||||
def operation():
|
||||
client = self._get_ddgs_client()
|
||||
return client.images(
|
||||
query,
|
||||
region=region,
|
||||
safesearch=safesearch,
|
||||
timelimit=timelimit,
|
||||
max_results=min(max_results, 50),
|
||||
)
|
||||
|
||||
return self._execute_with_retry(operation, "Image search")
|
||||
|
||||
def _news_search(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 5,
|
||||
region: str = "wt-wt",
|
||||
safesearch: str = "moderate",
|
||||
timelimit: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
logger.info(f"DuckDuckGo news search: {query}")
|
||||
|
||||
def operation():
|
||||
client = self._get_ddgs_client()
|
||||
return client.news(
|
||||
query,
|
||||
region=region,
|
||||
safesearch=safesearch,
|
||||
timelimit=timelimit,
|
||||
max_results=min(max_results, 20),
|
||||
)
|
||||
|
||||
return self._execute_with_retry(operation, "News search")
|
||||
|
||||
def get_actions_metadata(self):
|
||||
return [
|
||||
{
|
||||
"name": "ddg_web_search",
|
||||
"description": "Search the web using DuckDuckGo. Returns titles, URLs, and snippets.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results (default: 5, max: 20)",
|
||||
},
|
||||
"region": {
|
||||
"type": "string",
|
||||
"description": "Region code (default: wt-wt for worldwide, us-en for US)",
|
||||
},
|
||||
"timelimit": {
|
||||
"type": "string",
|
||||
"description": "Time filter: d (day), w (week), m (month), y (year)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "ddg_image_search",
|
||||
"description": "Search for images using DuckDuckGo. Returns image URLs and metadata.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Image search query",
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results (default: 5, max: 50)",
|
||||
},
|
||||
"region": {
|
||||
"type": "string",
|
||||
"description": "Region code (default: wt-wt for worldwide)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "ddg_news_search",
|
||||
"description": "Search for news articles using DuckDuckGo. Returns recent news.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "News search query",
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results (default: 5, max: 20)",
|
||||
},
|
||||
"timelimit": {
|
||||
"type": "string",
|
||||
"description": "Time filter: d (day), w (week), m (month)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {}
|
||||
438
application/agents/tools/internal_search.py
Normal file
438
application/agents/tools/internal_search.py
Normal file
@@ -0,0 +1,438 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
from application.core.settings import settings
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InternalSearchTool(Tool):
|
||||
"""Wraps the ClassicRAG retriever as an LLM-callable tool.
|
||||
|
||||
Instead of pre-fetching docs into the prompt, the LLM decides
|
||||
when and what to search. Supports multiple searches per session.
|
||||
|
||||
Optional capabilities (enabled when sources have directory_structure):
|
||||
- path_filter on search: restrict results to a specific file/folder
|
||||
- list_files action: browse the file/folder structure
|
||||
"""
|
||||
|
||||
internal = True
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
self.config = config
|
||||
self.retrieved_docs: List[Dict] = []
|
||||
self._retriever = None
|
||||
self._directory_structure: Optional[Dict] = None
|
||||
self._dir_structure_loaded = False
|
||||
|
||||
def _get_retriever(self):
|
||||
if self._retriever is None:
|
||||
self._retriever = RetrieverCreator.create_retriever(
|
||||
self.config.get("retriever_name", "classic"),
|
||||
source=self.config.get("source", {}),
|
||||
chat_history=[],
|
||||
prompt="",
|
||||
chunks=int(self.config.get("chunks", 2)),
|
||||
doc_token_limit=int(self.config.get("doc_token_limit", 50000)),
|
||||
model_id=self.config.get("model_id", "docsgpt-local"),
|
||||
user_api_key=self.config.get("user_api_key"),
|
||||
agent_id=self.config.get("agent_id"),
|
||||
llm_name=self.config.get("llm_name", settings.LLM_PROVIDER),
|
||||
api_key=self.config.get("api_key", settings.API_KEY),
|
||||
decoded_token=self.config.get("decoded_token"),
|
||||
)
|
||||
return self._retriever
|
||||
|
||||
def _get_directory_structure(self) -> Optional[Dict]:
|
||||
"""Load directory structure from MongoDB for the configured sources."""
|
||||
if self._dir_structure_loaded:
|
||||
return self._directory_structure
|
||||
|
||||
self._dir_structure_loaded = True
|
||||
source = self.config.get("source", {})
|
||||
active_docs = source.get("active_docs", [])
|
||||
if not active_docs:
|
||||
return None
|
||||
|
||||
try:
|
||||
from bson.objectid import ObjectId
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
sources_collection = db["sources"]
|
||||
|
||||
if isinstance(active_docs, str):
|
||||
active_docs = [active_docs]
|
||||
|
||||
merged_structure = {}
|
||||
for doc_id in active_docs:
|
||||
try:
|
||||
source_doc = sources_collection.find_one(
|
||||
{"_id": ObjectId(doc_id)}
|
||||
)
|
||||
if not source_doc:
|
||||
continue
|
||||
dir_str = source_doc.get("directory_structure")
|
||||
if dir_str:
|
||||
if isinstance(dir_str, str):
|
||||
dir_str = json.loads(dir_str)
|
||||
source_name = source_doc.get("name", doc_id)
|
||||
if len(active_docs) > 1:
|
||||
merged_structure[source_name] = dir_str
|
||||
else:
|
||||
merged_structure = dir_str
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not load dir structure for {doc_id}: {e}")
|
||||
|
||||
self._directory_structure = merged_structure if merged_structure else None
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load directory structures: {e}")
|
||||
|
||||
return self._directory_structure
|
||||
|
||||
def execute_action(self, action_name: str, **kwargs):
|
||||
if action_name == "search":
|
||||
return self._execute_search(**kwargs)
|
||||
elif action_name == "list_files":
|
||||
return self._execute_list_files(**kwargs)
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def _execute_search(self, **kwargs) -> str:
|
||||
query = kwargs.get("query", "")
|
||||
path_filter = kwargs.get("path_filter", "")
|
||||
|
||||
if not query:
|
||||
return "Error: 'query' parameter is required."
|
||||
|
||||
try:
|
||||
retriever = self._get_retriever()
|
||||
docs = retriever.search(query)
|
||||
except Exception as e:
|
||||
logger.error(f"Internal search failed: {e}", exc_info=True)
|
||||
return "Search failed: an internal error occurred."
|
||||
|
||||
if not docs:
|
||||
return "No documents found matching your query."
|
||||
|
||||
# Apply path filter if specified
|
||||
if path_filter:
|
||||
path_lower = path_filter.lower()
|
||||
docs = [
|
||||
d
|
||||
for d in docs
|
||||
if path_lower in d.get("source", "").lower()
|
||||
or path_lower in d.get("filename", "").lower()
|
||||
or path_lower in d.get("title", "").lower()
|
||||
]
|
||||
if not docs:
|
||||
return f"No documents found matching query '{query}' in path '{path_filter}'."
|
||||
|
||||
# Accumulate for source tracking
|
||||
for doc in docs:
|
||||
if doc not in self.retrieved_docs:
|
||||
self.retrieved_docs.append(doc)
|
||||
|
||||
# Format results for the LLM
|
||||
formatted = []
|
||||
for i, doc in enumerate(docs, 1):
|
||||
title = doc.get("title", "Untitled")
|
||||
text = doc.get("text", "")
|
||||
source = doc.get("source", "Unknown")
|
||||
filename = doc.get("filename", "")
|
||||
header = filename or title
|
||||
formatted.append(f"[{i}] {header} (source: {source})\n{text}")
|
||||
|
||||
return "\n\n---\n\n".join(formatted)
|
||||
|
||||
def _execute_list_files(self, **kwargs) -> str:
|
||||
path = kwargs.get("path", "")
|
||||
dir_structure = self._get_directory_structure()
|
||||
|
||||
if not dir_structure:
|
||||
return "No file structure available for the current sources."
|
||||
|
||||
# Navigate to the requested path
|
||||
current = dir_structure
|
||||
if path:
|
||||
for part in path.strip("/").split("/"):
|
||||
if not part:
|
||||
continue
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return f"Path '{path}' not found in the file structure."
|
||||
|
||||
# Format the structure for the LLM
|
||||
return self._format_structure(current, path or "/")
|
||||
|
||||
def _format_structure(self, node: Dict, current_path: str) -> str:
|
||||
if not isinstance(node, dict):
|
||||
return f"'{current_path}' is a file, not a directory."
|
||||
|
||||
lines = [f"File structure at '{current_path}':\n"]
|
||||
folders = []
|
||||
files = []
|
||||
|
||||
for name, value in sorted(node.items()):
|
||||
if isinstance(value, dict):
|
||||
# Check if it's a file metadata dict or a folder
|
||||
if "type" in value or "size_bytes" in value or "token_count" in value:
|
||||
# It's a file with metadata
|
||||
size = value.get("token_count", "")
|
||||
ftype = value.get("type", "")
|
||||
info_parts = []
|
||||
if ftype:
|
||||
info_parts.append(ftype)
|
||||
if size:
|
||||
info_parts.append(f"{size} tokens")
|
||||
info = f" ({', '.join(info_parts)})" if info_parts else ""
|
||||
files.append(f" {name}{info}")
|
||||
else:
|
||||
# It's a folder
|
||||
count = self._count_files(value)
|
||||
folders.append(f" {name}/ ({count} items)")
|
||||
else:
|
||||
files.append(f" {name}")
|
||||
|
||||
if folders:
|
||||
lines.append("Folders:")
|
||||
lines.extend(folders)
|
||||
if files:
|
||||
lines.append("Files:")
|
||||
lines.extend(files)
|
||||
if not folders and not files:
|
||||
lines.append(" (empty)")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _count_files(self, node: Dict) -> int:
|
||||
count = 0
|
||||
for value in node.values():
|
||||
if isinstance(value, dict):
|
||||
if "type" in value or "size_bytes" in value or "token_count" in value:
|
||||
count += 1
|
||||
else:
|
||||
count += self._count_files(value)
|
||||
else:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def get_actions_metadata(self):
|
||||
actions = [
|
||||
{
|
||||
"name": "search",
|
||||
"description": (
|
||||
"Search the user's uploaded documents and knowledge base. "
|
||||
"Use this to find relevant information before answering questions. "
|
||||
"You can call this multiple times with different queries."
|
||||
),
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query. Be specific and focused.",
|
||||
"filled_by_llm": True,
|
||||
"required": True,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Add path_filter and list_files only if directory structure exists
|
||||
has_structure = self.config.get("has_directory_structure", False)
|
||||
if has_structure:
|
||||
actions[0]["parameters"]["properties"]["path_filter"] = {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional: filter results to a specific file or folder path. "
|
||||
"Use list_files first to see available paths."
|
||||
),
|
||||
"filled_by_llm": True,
|
||||
"required": False,
|
||||
}
|
||||
actions.append(
|
||||
{
|
||||
"name": "list_files",
|
||||
"description": (
|
||||
"Browse the file and folder structure of the knowledge base. "
|
||||
"Use this to see what files are available before searching. "
|
||||
"Optionally provide a path to browse a specific folder."
|
||||
),
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Optional: folder path to browse. Leave empty for root.",
|
||||
"filled_by_llm": True,
|
||||
"required": False,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {}
|
||||
|
||||
|
||||
# Constants for building synthetic tools_dict entries
|
||||
INTERNAL_TOOL_ID = "internal"
|
||||
|
||||
|
||||
def build_internal_tool_entry(has_directory_structure: bool = False) -> Dict:
|
||||
"""Build the tools_dict entry for InternalSearchTool.
|
||||
|
||||
Dynamically includes list_files and path_filter based on
|
||||
whether the sources have directory structure.
|
||||
"""
|
||||
search_params = {
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query. Be specific and focused.",
|
||||
"filled_by_llm": True,
|
||||
"required": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
actions = [
|
||||
{
|
||||
"name": "search",
|
||||
"description": (
|
||||
"Search the user's uploaded documents and knowledge base. "
|
||||
"Use this to find relevant information before answering questions. "
|
||||
"You can call this multiple times with different queries."
|
||||
),
|
||||
"active": True,
|
||||
"parameters": search_params,
|
||||
}
|
||||
]
|
||||
|
||||
if has_directory_structure:
|
||||
search_params["properties"]["path_filter"] = {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional: filter results to a specific file or folder path. "
|
||||
"Use list_files first to see available paths."
|
||||
),
|
||||
"filled_by_llm": True,
|
||||
"required": False,
|
||||
}
|
||||
actions.append(
|
||||
{
|
||||
"name": "list_files",
|
||||
"description": (
|
||||
"Browse the file and folder structure of the knowledge base. "
|
||||
"Use this to see what files are available before searching. "
|
||||
"Optionally provide a path to browse a specific folder."
|
||||
),
|
||||
"active": True,
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Optional: folder path to browse. Leave empty for root.",
|
||||
"filled_by_llm": True,
|
||||
"required": False,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return {"name": "internal_search", "actions": actions}
|
||||
|
||||
|
||||
# Keep backward compat
|
||||
INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False)
|
||||
|
||||
|
||||
def sources_have_directory_structure(source: Dict) -> bool:
|
||||
"""Check if any of the active sources have directory_structure in MongoDB."""
|
||||
active_docs = source.get("active_docs", [])
|
||||
if not active_docs:
|
||||
return False
|
||||
|
||||
try:
|
||||
from bson.objectid import ObjectId
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
sources_collection = db["sources"]
|
||||
|
||||
if isinstance(active_docs, str):
|
||||
active_docs = [active_docs]
|
||||
|
||||
for doc_id in active_docs:
|
||||
try:
|
||||
source_doc = sources_collection.find_one(
|
||||
{"_id": ObjectId(doc_id)},
|
||||
{"directory_structure": 1},
|
||||
)
|
||||
if source_doc and source_doc.get("directory_structure"):
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not check directory structure: {e}")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def add_internal_search_tool(tools_dict: Dict, retriever_config: Dict) -> None:
|
||||
"""Add the internal search tool to tools_dict if sources are configured.
|
||||
|
||||
Shared by AgenticAgent and ResearchAgent to avoid duplicate setup logic.
|
||||
Mutates tools_dict in place.
|
||||
"""
|
||||
source = retriever_config.get("source", {})
|
||||
has_sources = bool(source.get("active_docs"))
|
||||
if not retriever_config or not has_sources:
|
||||
return
|
||||
|
||||
has_dir = sources_have_directory_structure(source)
|
||||
internal_entry = build_internal_tool_entry(has_directory_structure=has_dir)
|
||||
internal_entry["config"] = build_internal_tool_config(
|
||||
**retriever_config,
|
||||
has_directory_structure=has_dir,
|
||||
)
|
||||
tools_dict[INTERNAL_TOOL_ID] = internal_entry
|
||||
|
||||
|
||||
def build_internal_tool_config(
|
||||
source: Dict,
|
||||
retriever_name: str = "classic",
|
||||
chunks: int = 2,
|
||||
doc_token_limit: int = 50000,
|
||||
model_id: str = "docsgpt-local",
|
||||
user_api_key: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
llm_name: str = None,
|
||||
api_key: str = None,
|
||||
decoded_token: Optional[Dict] = None,
|
||||
has_directory_structure: bool = False,
|
||||
) -> Dict:
|
||||
"""Build the config dict for InternalSearchTool."""
|
||||
return {
|
||||
"source": source,
|
||||
"retriever_name": retriever_name,
|
||||
"chunks": chunks,
|
||||
"doc_token_limit": doc_token_limit,
|
||||
"model_id": model_id,
|
||||
"user_api_key": user_api_key,
|
||||
"agent_id": agent_id,
|
||||
"llm_name": llm_name or settings.LLM_PROVIDER,
|
||||
"api_key": api_key or settings.API_KEY,
|
||||
"decoded_token": decoded_token,
|
||||
"has_directory_structure": has_directory_structure,
|
||||
}
|
||||
996
application/agents/tools/mcp_tool.py
Normal file
996
application/agents/tools/mcp_tool.py
Normal file
@@ -0,0 +1,996 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from fastmcp import Client
|
||||
from fastmcp.client.auth import BearerAuth
|
||||
from fastmcp.client.transports import (
|
||||
SSETransport,
|
||||
StdioTransport,
|
||||
StreamableHttpTransport,
|
||||
)
|
||||
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
||||
from pydantic import AnyHttpUrl, ValidationError
|
||||
from redis import Redis
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.core.url_validation import SSRFError, validate_url
|
||||
from application.security.encryption import decrypt_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
_mcp_clients_cache = {}
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
"""
|
||||
MCP Tool
|
||||
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
|
||||
"""
|
||||
Initialize the MCP Tool with configuration.
|
||||
|
||||
Args:
|
||||
config: Dictionary containing MCP server configuration:
|
||||
- server_url: URL of the remote MCP server
|
||||
- transport_type: Transport type (auto, sse, http, stdio)
|
||||
- auth_type: Type of authentication (bearer, oauth, api_key, basic, none)
|
||||
- encrypted_credentials: Encrypted credentials (if available)
|
||||
- timeout: Request timeout in seconds (default: 30)
|
||||
- headers: Custom headers for requests
|
||||
- command: Command for STDIO transport
|
||||
- args: Arguments for STDIO transport
|
||||
- oauth_scopes: OAuth scopes for oauth auth type
|
||||
- oauth_client_name: OAuth client name for oauth auth type
|
||||
- query_mode: If True, use non-interactive OAuth (fail-fast on 401)
|
||||
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
|
||||
"""
|
||||
self.config = config
|
||||
self.user_id = user_id
|
||||
raw_url = config.get("server_url", "")
|
||||
self.server_url = self._validate_server_url(raw_url) if raw_url else ""
|
||||
self.transport_type = config.get("transport_type", "auto")
|
||||
self.auth_type = config.get("auth_type", "none")
|
||||
self.timeout = config.get("timeout", 30)
|
||||
self.custom_headers = config.get("headers", {})
|
||||
|
||||
self.auth_credentials = {}
|
||||
if config.get("encrypted_credentials") and user_id:
|
||||
self.auth_credentials = decrypt_credentials(
|
||||
config["encrypted_credentials"], user_id
|
||||
)
|
||||
else:
|
||||
self.auth_credentials = config.get("auth_credentials", {})
|
||||
self.oauth_scopes = config.get("oauth_scopes", [])
|
||||
self.oauth_task_id = config.get("oauth_task_id", None)
|
||||
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
|
||||
self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))
|
||||
|
||||
self.available_tools = []
|
||||
self._cache_key = self._generate_cache_key()
|
||||
self._client = None
|
||||
self.query_mode = config.get("query_mode", False)
|
||||
|
||||
if self.server_url and self.auth_type != "oauth":
|
||||
self._setup_client()
|
||||
|
||||
@staticmethod
|
||||
def _validate_server_url(server_url: str) -> str:
|
||||
"""Validate server_url to prevent SSRF to internal networks.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL points to a private/internal address.
|
||||
"""
|
||||
try:
|
||||
return validate_url(server_url)
|
||||
except SSRFError as exc:
|
||||
raise ValueError(f"Invalid MCP server URL: {exc}") from exc
|
||||
|
||||
def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
|
||||
if configured_redirect_uri:
|
||||
return configured_redirect_uri.rstrip("/")
|
||||
|
||||
explicit = getattr(settings, "MCP_OAUTH_REDIRECT_URI", None)
|
||||
if explicit:
|
||||
return explicit.rstrip("/")
|
||||
|
||||
connector_base = getattr(settings, "CONNECTOR_REDIRECT_BASE_URI", None)
|
||||
if connector_base:
|
||||
parsed = urlparse(connector_base)
|
||||
if parsed.scheme and parsed.netloc:
|
||||
return f"{parsed.scheme}://{parsed.netloc}/api/mcp_server/callback"
|
||||
|
||||
return f"{settings.API_URL.rstrip('/')}/api/mcp_server/callback"
|
||||
|
||||
def _generate_cache_key(self) -> str:
|
||||
"""Generate a unique cache key for this MCP server configuration."""
|
||||
auth_key = ""
|
||||
if self.auth_type == "oauth":
|
||||
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
|
||||
oauth_identity = self.user_id or self.oauth_task_id or "anonymous"
|
||||
auth_key = (
|
||||
f"oauth:{oauth_identity}:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
|
||||
)
|
||||
elif self.auth_type in ["bearer"]:
|
||||
token = self.auth_credentials.get(
|
||||
"bearer_token", ""
|
||||
) or self.auth_credentials.get("access_token", "")
|
||||
auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
|
||||
elif self.auth_type == "api_key":
|
||||
api_key = self.auth_credentials.get("api_key", "")
|
||||
auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
|
||||
elif self.auth_type == "basic":
|
||||
username = self.auth_credentials.get("username", "")
|
||||
auth_key = f"basic:{username}"
|
||||
else:
|
||||
auth_key = "none"
|
||||
return f"{self.server_url}#{self.transport_type}#{auth_key}"
|
||||
|
||||
def _setup_client(self):
|
||||
global _mcp_clients_cache
|
||||
if self._cache_key in _mcp_clients_cache:
|
||||
cached_data = _mcp_clients_cache[self._cache_key]
|
||||
if time.time() - cached_data["created_at"] < 300:
|
||||
self._client = cached_data["client"]
|
||||
return
|
||||
else:
|
||||
del _mcp_clients_cache[self._cache_key]
|
||||
transport = self._create_transport()
|
||||
auth = None
|
||||
|
||||
if self.auth_type == "oauth":
|
||||
redis_client = get_redis_instance()
|
||||
if self.query_mode:
|
||||
auth = NonInteractiveOAuth(
|
||||
mcp_url=self.server_url,
|
||||
scopes=self.oauth_scopes,
|
||||
redis_client=redis_client,
|
||||
redirect_uri=self.redirect_uri,
|
||||
db=db,
|
||||
user_id=self.user_id,
|
||||
)
|
||||
else:
|
||||
auth = DocsGPTOAuth(
|
||||
mcp_url=self.server_url,
|
||||
scopes=self.oauth_scopes,
|
||||
redis_client=redis_client,
|
||||
redirect_uri=self.redirect_uri,
|
||||
task_id=self.oauth_task_id,
|
||||
db=db,
|
||||
user_id=self.user_id,
|
||||
)
|
||||
elif self.auth_type == "bearer":
|
||||
token = self.auth_credentials.get(
|
||||
"bearer_token", ""
|
||||
) or self.auth_credentials.get("access_token", "")
|
||||
if token:
|
||||
auth = BearerAuth(token)
|
||||
self._client = Client(transport, auth=auth)
|
||||
_mcp_clients_cache[self._cache_key] = {
|
||||
"client": self._client,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
|
||||
def _create_transport(self):
|
||||
"""Create appropriate transport based on configuration."""
|
||||
headers = {"Content-Type": "application/json", "User-Agent": "DocsGPT-MCP/1.0"}
|
||||
headers.update(self.custom_headers)
|
||||
|
||||
if self.auth_type == "api_key":
|
||||
api_key = self.auth_credentials.get("api_key", "")
|
||||
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
|
||||
if api_key:
|
||||
headers[header_name] = api_key
|
||||
elif self.auth_type == "basic":
|
||||
username = self.auth_credentials.get("username", "")
|
||||
password = self.auth_credentials.get("password", "")
|
||||
if username and password:
|
||||
credentials = base64.b64encode(
|
||||
f"{username}:{password}".encode()
|
||||
).decode()
|
||||
headers["Authorization"] = f"Basic {credentials}"
|
||||
if self.transport_type == "auto":
|
||||
if "sse" in self.server_url.lower() or self.server_url.endswith("/sse"):
|
||||
transport_type = "sse"
|
||||
else:
|
||||
transport_type = "http"
|
||||
else:
|
||||
transport_type = self.transport_type
|
||||
if transport_type == "stdio":
|
||||
raise ValueError("STDIO transport is disabled")
|
||||
if transport_type == "sse":
|
||||
headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"})
|
||||
return SSETransport(url=self.server_url, headers=headers)
|
||||
elif transport_type == "http":
|
||||
return StreamableHttpTransport(url=self.server_url, headers=headers)
|
||||
elif transport_type == "stdio":
|
||||
command = self.config.get("command", "python")
|
||||
args = self.config.get("args", [])
|
||||
env = self.auth_credentials if self.auth_credentials else None
|
||||
return StdioTransport(command=command, args=args, env=env)
|
||||
else:
|
||||
return StreamableHttpTransport(url=self.server_url, headers=headers)
|
||||
|
||||
def _format_tools(self, tools_response) -> List[Dict]:
|
||||
"""Format tools response to match expected format."""
|
||||
if hasattr(tools_response, "tools"):
|
||||
tools = tools_response.tools
|
||||
elif isinstance(tools_response, list):
|
||||
tools = tools_response
|
||||
else:
|
||||
tools = []
|
||||
tools_dict = []
|
||||
for tool in tools:
|
||||
if hasattr(tool, "name"):
|
||||
tool_dict = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
}
|
||||
if hasattr(tool, "inputSchema"):
|
||||
tool_dict["inputSchema"] = tool.inputSchema
|
||||
tools_dict.append(tool_dict)
|
||||
elif isinstance(tool, dict):
|
||||
tools_dict.append(tool)
|
||||
else:
|
||||
if hasattr(tool, "model_dump"):
|
||||
tools_dict.append(tool.model_dump())
|
||||
else:
|
||||
tools_dict.append({"name": str(tool), "description": ""})
|
||||
return tools_dict
|
||||
|
||||
async def _execute_with_client(self, operation: str, *args, **kwargs):
|
||||
"""Execute operation with FastMCP client."""
|
||||
if not self._client:
|
||||
raise Exception("FastMCP client not initialized")
|
||||
async with self._client:
|
||||
if operation == "ping":
|
||||
return await self._client.ping()
|
||||
elif operation == "list_tools":
|
||||
tools_response = await self._client.list_tools()
|
||||
self.available_tools = self._format_tools(tools_response)
|
||||
return self.available_tools
|
||||
elif operation == "call_tool":
|
||||
tool_name = args[0]
|
||||
tool_args = kwargs
|
||||
return await self._client.call_tool(tool_name, tool_args)
|
||||
elif operation == "list_resources":
|
||||
return await self._client.list_resources()
|
||||
elif operation == "list_prompts":
|
||||
return await self._client.list_prompts()
|
||||
else:
|
||||
raise Exception(f"Unknown operation: {operation}")
|
||||
|
||||
_ERROR_MAP = [
|
||||
(concurrent.futures.TimeoutError, lambda op, t, _: f"Timed out after {t}s"),
|
||||
(ConnectionRefusedError, lambda *_: "Connection refused"),
|
||||
]
|
||||
|
||||
_ERROR_PATTERNS = {
|
||||
("403", "Forbidden"): "Access denied (403 Forbidden)",
|
||||
("401", "Unauthorized"): "Authentication failed (401 Unauthorized)",
|
||||
("ECONNREFUSED",): "Connection refused",
|
||||
("SSL", "certificate"): "SSL/TLS error",
|
||||
}
|
||||
|
||||
def _run_async_operation(self, operation: str, *args, **kwargs):
|
||||
try:
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
self._run_in_new_loop, operation, *args, **kwargs
|
||||
)
|
||||
return future.result(timeout=self.timeout)
|
||||
except RuntimeError:
|
||||
return self._run_in_new_loop(operation, *args, **kwargs)
|
||||
except Exception as e:
|
||||
raise self._map_error(operation, e) from e
|
||||
raise self._map_error(operation, e) from e
|
||||
|
||||
def _run_in_new_loop(self, operation, *args, **kwargs):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(
|
||||
self._execute_with_client(operation, *args, **kwargs)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
def _map_error(self, operation: str, exc: Exception) -> Exception:
|
||||
for exc_type, msg_fn in self._ERROR_MAP:
|
||||
if isinstance(exc, exc_type):
|
||||
return Exception(msg_fn(operation, self.timeout, exc))
|
||||
error_msg = str(exc)
|
||||
for patterns, friendly in self._ERROR_PATTERNS.items():
|
||||
if any(p.lower() in error_msg.lower() for p in patterns):
|
||||
return Exception(friendly)
|
||||
logger.error("MCP %s failed: %s", operation, exc)
|
||||
return exc
|
||||
|
||||
def discover_tools(self) -> List[Dict]:
|
||||
"""
|
||||
Discover available tools from the MCP server using FastMCP.
|
||||
|
||||
Returns:
|
||||
List of tool definitions from the server
|
||||
"""
|
||||
if not self.server_url:
|
||||
return []
|
||||
if not self._client:
|
||||
self._setup_client()
|
||||
try:
|
||||
tools = self._run_async_operation("list_tools")
|
||||
self.available_tools = tools
|
||||
return self.available_tools
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
|
||||
|
||||
def execute_action(self, action_name: str, **kwargs) -> Any:
|
||||
if not self.server_url:
|
||||
raise Exception("No MCP server configured")
|
||||
if not self._client:
|
||||
self._setup_client()
|
||||
cleaned_kwargs = {}
|
||||
for key, value in kwargs.items():
|
||||
if value == "" or value is None:
|
||||
continue
|
||||
cleaned_kwargs[key] = value
|
||||
try:
|
||||
result = self._run_async_operation(
|
||||
"call_tool", action_name, **cleaned_kwargs
|
||||
)
|
||||
return self._format_result(result)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
lower_msg = error_msg.lower()
|
||||
is_auth_error = (
|
||||
"401" in error_msg
|
||||
or "unauthorized" in lower_msg
|
||||
or "session expired" in lower_msg
|
||||
or "re-authorize" in lower_msg
|
||||
)
|
||||
if is_auth_error:
|
||||
if self.auth_type == "oauth":
|
||||
raise Exception(
|
||||
f"Action '{action_name}' failed: OAuth session expired. "
|
||||
"Please re-authorize this MCP server in tool settings."
|
||||
) from e
|
||||
global _mcp_clients_cache
|
||||
_mcp_clients_cache.pop(self._cache_key, None)
|
||||
self._client = None
|
||||
self._setup_client()
|
||||
try:
|
||||
result = self._run_async_operation(
|
||||
"call_tool", action_name, **cleaned_kwargs
|
||||
)
|
||||
return self._format_result(result)
|
||||
except Exception as retry_e:
|
||||
raise Exception(
|
||||
f"Action '{action_name}' failed after re-auth attempt: {retry_e}. "
|
||||
"Your credentials may have expired — please re-authorize in tool settings."
|
||||
) from retry_e
|
||||
raise Exception(
|
||||
f"Failed to execute action '{action_name}': {error_msg}"
|
||||
) from e
|
||||
|
||||
def _format_result(self, result) -> Dict:
|
||||
"""Format FastMCP result to match expected format."""
|
||||
if hasattr(result, "content"):
|
||||
content_list = []
|
||||
for content_item in result.content:
|
||||
if hasattr(content_item, "text"):
|
||||
content_list.append({"type": "text", "text": content_item.text})
|
||||
elif hasattr(content_item, "data"):
|
||||
content_list.append({"type": "data", "data": content_item.data})
|
||||
else:
|
||||
content_list.append(
|
||||
{"type": "unknown", "content": str(content_item)}
|
||||
)
|
||||
return {
|
||||
"content": content_list,
|
||||
"isError": getattr(result, "isError", False),
|
||||
}
|
||||
else:
|
||||
return result
|
||||
|
||||
def test_connection(self) -> Dict:
|
||||
if not self.server_url:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "No server URL configured",
|
||||
"tools_count": 0,
|
||||
}
|
||||
try:
|
||||
parsed = urlparse(self.server_url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Invalid URL scheme '{parsed.scheme}' — use http:// or https://",
|
||||
"tools_count": 0,
|
||||
}
|
||||
except Exception:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Invalid URL format",
|
||||
"tools_count": 0,
|
||||
}
|
||||
if not self._client:
|
||||
try:
|
||||
self._setup_client()
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Client init failed: {str(e)}",
|
||||
"tools_count": 0,
|
||||
}
|
||||
try:
|
||||
if self.auth_type == "oauth":
|
||||
return self._test_oauth_connection()
|
||||
else:
|
||||
return self._test_regular_connection()
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Connection failed: {str(e)}",
|
||||
"tools_count": 0,
|
||||
}
|
||||
|
||||
def _test_regular_connection(self) -> Dict:
|
||||
ping_ok = False
|
||||
ping_error = None
|
||||
try:
|
||||
self._run_async_operation("ping")
|
||||
ping_ok = True
|
||||
except Exception as e:
|
||||
ping_error = str(e)
|
||||
|
||||
try:
|
||||
tools = self.discover_tools()
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Connection failed: {ping_error or str(e)}",
|
||||
"tools_count": 0,
|
||||
}
|
||||
|
||||
if not tools and not ping_ok:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Connection failed: {ping_error or 'No tools found'}",
|
||||
"tools_count": 0,
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
|
||||
"tools_count": len(tools),
|
||||
"tools": [
|
||||
{
|
||||
"name": tool.get("name", "unknown"),
|
||||
"description": tool.get("description", ""),
|
||||
}
|
||||
for tool in tools
|
||||
],
|
||||
}
|
||||
|
||||
def _test_oauth_connection(self) -> Dict:
|
||||
storage = DBTokenStorage(
|
||||
server_url=self.server_url, user_id=self.user_id, db_client=db
|
||||
)
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
tokens = loop.run_until_complete(storage.get_tokens())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
if tokens and tokens.access_token:
|
||||
self.query_mode = True
|
||||
_mcp_clients_cache.pop(self._cache_key, None)
|
||||
self._client = None
|
||||
self._setup_client()
|
||||
try:
|
||||
tools = self.discover_tools()
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
|
||||
"tools_count": len(tools),
|
||||
"tools": [
|
||||
{
|
||||
"name": t.get("name", "unknown"),
|
||||
"description": t.get("description", ""),
|
||||
}
|
||||
for t in tools
|
||||
],
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning("OAuth token validation failed: %s", e)
|
||||
_mcp_clients_cache.pop(self._cache_key, None)
|
||||
self._client = None
|
||||
|
||||
return self._start_oauth_task()
|
||||
|
||||
def _start_oauth_task(self) -> Dict:
|
||||
task_config = self.config.copy()
|
||||
task_config.pop("query_mode", None)
|
||||
result = mcp_oauth_task.delay(task_config, self.user_id)
|
||||
return {
|
||||
"success": False,
|
||||
"requires_oauth": True,
|
||||
"task_id": result.id,
|
||||
"message": "OAuth authorization required.",
|
||||
"tools_count": 0,
|
||||
}
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict]:
|
||||
"""
|
||||
Get metadata for all available actions.
|
||||
|
||||
Returns:
|
||||
List of action metadata dictionaries
|
||||
"""
|
||||
actions = []
|
||||
for tool in self.available_tools:
|
||||
input_schema = (
|
||||
tool.get("inputSchema")
|
||||
or tool.get("input_schema")
|
||||
or tool.get("schema")
|
||||
or tool.get("parameters")
|
||||
)
|
||||
|
||||
parameters_schema = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
if input_schema:
|
||||
if isinstance(input_schema, dict):
|
||||
if "properties" in input_schema:
|
||||
parameters_schema = {
|
||||
"type": input_schema.get("type", "object"),
|
||||
"properties": input_schema.get("properties", {}),
|
||||
"required": input_schema.get("required", []),
|
||||
}
|
||||
|
||||
for key in ["additionalProperties", "description"]:
|
||||
if key in input_schema:
|
||||
parameters_schema[key] = input_schema[key]
|
||||
else:
|
||||
parameters_schema["properties"] = input_schema
|
||||
action = {
|
||||
"name": tool.get("name", ""),
|
||||
"description": tool.get("description", ""),
|
||||
"parameters": parameters_schema,
|
||||
}
|
||||
actions.append(action)
|
||||
return actions
|
||||
|
||||
def get_config_requirements(self) -> Dict:
|
||||
return {
|
||||
"server_url": {
|
||||
"type": "string",
|
||||
"label": "Server URL",
|
||||
"description": "URL of the remote MCP server",
|
||||
"required": True,
|
||||
"secret": False,
|
||||
"order": 1,
|
||||
},
|
||||
"auth_type": {
|
||||
"type": "string",
|
||||
"label": "Authentication Type",
|
||||
"description": "Authentication method for the MCP server",
|
||||
"enum": ["none", "bearer", "oauth", "api_key", "basic"],
|
||||
"default": "none",
|
||||
"required": True,
|
||||
"secret": False,
|
||||
"order": 2,
|
||||
},
|
||||
"api_key": {
|
||||
"type": "string",
|
||||
"label": "API Key",
|
||||
"description": "API key for authentication",
|
||||
"required": False,
|
||||
"secret": True,
|
||||
"order": 3,
|
||||
"depends_on": {"auth_type": "api_key"},
|
||||
},
|
||||
"api_key_header": {
|
||||
"type": "string",
|
||||
"label": "API Key Header",
|
||||
"description": "Header name for API key (default: X-API-Key)",
|
||||
"default": "X-API-Key",
|
||||
"required": False,
|
||||
"secret": False,
|
||||
"order": 4,
|
||||
"depends_on": {"auth_type": "api_key"},
|
||||
},
|
||||
"bearer_token": {
|
||||
"type": "string",
|
||||
"label": "Bearer Token",
|
||||
"description": "Bearer token for authentication",
|
||||
"required": False,
|
||||
"secret": True,
|
||||
"order": 3,
|
||||
"depends_on": {"auth_type": "bearer"},
|
||||
},
|
||||
"username": {
|
||||
"type": "string",
|
||||
"label": "Username",
|
||||
"description": "Username for basic authentication",
|
||||
"required": False,
|
||||
"secret": False,
|
||||
"order": 3,
|
||||
"depends_on": {"auth_type": "basic"},
|
||||
},
|
||||
"password": {
|
||||
"type": "string",
|
||||
"label": "Password",
|
||||
"description": "Password for basic authentication",
|
||||
"required": False,
|
||||
"secret": True,
|
||||
"order": 4,
|
||||
"depends_on": {"auth_type": "basic"},
|
||||
},
|
||||
"oauth_scopes": {
|
||||
"type": "string",
|
||||
"label": "OAuth Scopes",
|
||||
"description": "Comma-separated OAuth scopes to request",
|
||||
"required": False,
|
||||
"secret": False,
|
||||
"order": 3,
|
||||
"depends_on": {"auth_type": "oauth"},
|
||||
},
|
||||
"timeout": {
|
||||
"type": "number",
|
||||
"label": "Timeout (seconds)",
|
||||
"description": "Request timeout in seconds (1-300)",
|
||||
"default": 30,
|
||||
"required": False,
|
||||
"secret": False,
|
||||
"order": 10,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class DocsGPTOAuth(OAuthClientProvider):
|
||||
"""
|
||||
Custom OAuth handler for DocsGPT that uses frontend redirect instead of browser.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_url: str,
|
||||
redirect_uri: str,
|
||||
redis_client: Redis | None = None,
|
||||
redis_prefix: str = "mcp_oauth:",
|
||||
task_id: str = None,
|
||||
scopes: str | list[str] | None = None,
|
||||
client_name: str = "DocsGPT-MCP",
|
||||
user_id=None,
|
||||
db=None,
|
||||
additional_client_metadata: dict[str, Any] | None = None,
|
||||
skip_redirect_validation: bool = False,
|
||||
):
|
||||
self.redirect_uri = redirect_uri
|
||||
self.redis_client = redis_client
|
||||
self.redis_prefix = redis_prefix
|
||||
self.task_id = task_id
|
||||
self.user_id = user_id
|
||||
self.db = db
|
||||
|
||||
parsed_url = urlparse(mcp_url)
|
||||
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
if isinstance(scopes, list):
|
||||
scopes = " ".join(scopes)
|
||||
client_metadata = OAuthClientMetadata(
|
||||
client_name=client_name,
|
||||
redirect_uris=[AnyHttpUrl(redirect_uri)],
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
scope=scopes,
|
||||
**(additional_client_metadata or {}),
|
||||
)
|
||||
|
||||
storage = DBTokenStorage(
|
||||
server_url=self.server_base_url,
|
||||
user_id=self.user_id,
|
||||
db_client=self.db,
|
||||
expected_redirect_uri=None if skip_redirect_validation else redirect_uri,
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
server_url=self.server_base_url,
|
||||
client_metadata=client_metadata,
|
||||
storage=storage,
|
||||
redirect_handler=self.redirect_handler,
|
||||
callback_handler=self.callback_handler,
|
||||
)
|
||||
|
||||
self.auth_url = None
|
||||
self.extracted_state = None
|
||||
|
||||
def _process_auth_url(self, authorization_url: str) -> tuple[str, str]:
|
||||
"""Process authorization URL to extract state"""
|
||||
try:
|
||||
parsed_url = urlparse(authorization_url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
|
||||
state_params = query_params.get("state", [])
|
||||
if state_params:
|
||||
state = state_params[0]
|
||||
else:
|
||||
raise ValueError("No state in auth URL")
|
||||
return authorization_url, state
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to process auth URL: {e}")
|
||||
|
||||
async def redirect_handler(self, authorization_url: str) -> None:
|
||||
"""Store auth URL and state in Redis for frontend to use."""
|
||||
auth_url, state = self._process_auth_url(authorization_url)
|
||||
logger.info("Processed auth_url: %s, state: %s", auth_url, state)
|
||||
self.auth_url = auth_url
|
||||
self.extracted_state = state
|
||||
|
||||
if self.redis_client and self.extracted_state:
|
||||
key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
|
||||
self.redis_client.setex(key, 600, auth_url)
|
||||
logger.info("Stored auth_url in Redis: %s", key)
|
||||
|
||||
if self.task_id:
|
||||
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||
status_data = {
|
||||
"status": "requires_redirect",
|
||||
"message": "Authorization required",
|
||||
"authorization_url": self.auth_url,
|
||||
"state": self.extracted_state,
|
||||
"requires_oauth": True,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
|
||||
async def callback_handler(self) -> tuple[str, str | None]:
|
||||
"""Wait for auth code from Redis using the state value."""
|
||||
if not self.redis_client or not self.extracted_state:
|
||||
raise Exception("Redis client or state not configured for OAuth")
|
||||
poll_interval = 1
|
||||
max_wait_time = 300
|
||||
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
|
||||
|
||||
if self.task_id:
|
||||
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||
status_data = {
|
||||
"status": "awaiting_callback",
|
||||
"message": "Waiting for authorization...",
|
||||
"authorization_url": self.auth_url,
|
||||
"state": self.extracted_state,
|
||||
"requires_oauth": True,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
code_data = self.redis_client.get(code_key)
|
||||
if code_data:
|
||||
code = code_data.decode()
|
||||
returned_state = self.extracted_state
|
||||
|
||||
self.redis_client.delete(code_key)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}auth_url:{self.extracted_state}"
|
||||
)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}state:{self.extracted_state}"
|
||||
)
|
||||
|
||||
if self.task_id:
|
||||
status_data = {
|
||||
"status": "callback_received",
|
||||
"message": "Completing authentication...",
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
return code, returned_state
|
||||
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
|
||||
error_data = self.redis_client.get(error_key)
|
||||
if error_data:
|
||||
error_msg = error_data.decode()
|
||||
self.redis_client.delete(error_key)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}auth_url:{self.extracted_state}"
|
||||
)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}state:{self.extracted_state}"
|
||||
)
|
||||
raise Exception(f"OAuth error: {error_msg}")
|
||||
await asyncio.sleep(poll_interval)
|
||||
self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
|
||||
self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
|
||||
raise Exception("OAuth timeout: no code received within 5 minutes")
|
||||
|
||||
|
||||
class NonInteractiveOAuth(DocsGPTOAuth):
|
||||
"""OAuth provider that fails fast on 401 instead of starting interactive auth.
|
||||
|
||||
Used during query execution to prevent the streaming response from blocking
|
||||
while waiting for user authorization that will never come.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs.setdefault("task_id", None)
|
||||
kwargs["skip_redirect_validation"] = True
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def redirect_handler(self, authorization_url: str) -> None:
|
||||
raise Exception(
|
||||
"OAuth session expired — please re-authorize this MCP server in tool settings."
|
||||
)
|
||||
|
||||
async def callback_handler(self) -> tuple[str, str | None]:
|
||||
raise Exception(
|
||||
"OAuth session expired — please re-authorize this MCP server in tool settings."
|
||||
)
|
||||
|
||||
|
||||
class DBTokenStorage(TokenStorage):
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
user_id: str,
|
||||
db_client,
|
||||
expected_redirect_uri: Optional[str] = None,
|
||||
):
|
||||
self.server_url = server_url
|
||||
self.user_id = user_id
|
||||
self.db_client = db_client
|
||||
self.expected_redirect_uri = expected_redirect_uri
|
||||
self.collection = db_client["connector_sessions"]
|
||||
|
||||
@staticmethod
|
||||
def get_base_url(url: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
return f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
def get_db_key(self) -> dict:
|
||||
return {
|
||||
"server_url": self.get_base_url(self.server_url),
|
||||
"user_id": self.user_id,
|
||||
}
|
||||
|
||||
async def get_tokens(self) -> OAuthToken | None:
|
||||
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
|
||||
if not doc or "tokens" not in doc:
|
||||
return None
|
||||
try:
|
||||
return OAuthToken.model_validate(doc["tokens"])
|
||||
except ValidationError as e:
|
||||
logger.error("Could not load tokens: %s", e)
|
||||
return None
|
||||
|
||||
async def set_tokens(self, tokens: OAuthToken) -> None:
|
||||
await asyncio.to_thread(
|
||||
self.collection.update_one,
|
||||
self.get_db_key(),
|
||||
{"$set": {"tokens": tokens.model_dump()}},
|
||||
True,
|
||||
)
|
||||
logger.info("Saved tokens for %s", self.get_base_url(self.server_url))
|
||||
|
||||
async def get_client_info(self) -> OAuthClientInformationFull | None:
|
||||
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
|
||||
if not doc or "client_info" not in doc:
|
||||
logger.debug(
|
||||
"No client_info in DB for %s", self.get_base_url(self.server_url)
|
||||
)
|
||||
return None
|
||||
try:
|
||||
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
|
||||
if self.expected_redirect_uri:
|
||||
stored_uris = [
|
||||
str(uri).rstrip("/") for uri in client_info.redirect_uris
|
||||
]
|
||||
expected_uri = self.expected_redirect_uri.rstrip("/")
|
||||
if expected_uri not in stored_uris:
|
||||
logger.warning(
|
||||
"Redirect URI mismatch for %s: expected=%s stored=%s — clearing.",
|
||||
self.get_base_url(self.server_url),
|
||||
expected_uri,
|
||||
stored_uris,
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
self.collection.update_one,
|
||||
self.get_db_key(),
|
||||
{"$unset": {"client_info": "", "tokens": ""}},
|
||||
)
|
||||
return None
|
||||
return client_info
|
||||
except ValidationError as e:
|
||||
logger.error("Could not load client info: %s", e)
|
||||
return None
|
||||
|
||||
def _serialize_client_info(self, info: dict) -> dict:
|
||||
if "redirect_uris" in info and isinstance(info["redirect_uris"], list):
|
||||
info["redirect_uris"] = [str(u) for u in info["redirect_uris"]]
|
||||
return info
|
||||
|
||||
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
|
||||
serialized_info = self._serialize_client_info(client_info.model_dump())
|
||||
await asyncio.to_thread(
|
||||
self.collection.update_one,
|
||||
self.get_db_key(),
|
||||
{"$set": {"client_info": serialized_info}},
|
||||
True,
|
||||
)
|
||||
logger.info("Saved client info for %s", self.get_base_url(self.server_url))
|
||||
|
||||
async def clear(self) -> None:
|
||||
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
|
||||
logger.info("Cleared OAuth cache for %s", self.get_base_url(self.server_url))
|
||||
|
||||
@classmethod
|
||||
async def clear_all(cls, db_client) -> None:
|
||||
collection = db_client["connector_sessions"]
|
||||
await asyncio.to_thread(collection.delete_many, {})
|
||||
logger.info("Cleared all OAuth client cache data.")
|
||||
|
||||
|
||||
class MCPOAuthManager:
|
||||
"""Manager for handling MCP OAuth callbacks."""
|
||||
|
||||
def __init__(self, redis_client: Redis | None, redis_prefix: str = "mcp_oauth:"):
|
||||
self.redis_client = redis_client
|
||||
self.redis_prefix = redis_prefix
|
||||
|
||||
def handle_oauth_callback(
|
||||
self, state: str, code: str, error: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Handle OAuth callback from provider.
|
||||
|
||||
Args:
|
||||
state: The state parameter from OAuth callback
|
||||
code: The authorization code from OAuth callback
|
||||
error: Error message if OAuth failed
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
if not self.redis_client or not state:
|
||||
raise Exception("Redis client or state not provided")
|
||||
if error:
|
||||
error_key = f"{self.redis_prefix}error:{state}"
|
||||
self.redis_client.setex(error_key, 300, error)
|
||||
raise Exception(f"OAuth error received: {error}")
|
||||
code_key = f"{self.redis_prefix}code:{state}"
|
||||
self.redis_client.setex(code_key, 300, code)
|
||||
|
||||
state_key = f"{self.redis_prefix}state:{state}"
|
||||
self.redis_client.setex(state_key, 300, "completed")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error handling OAuth callback: %s", e)
|
||||
return False
|
||||
|
||||
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""Get current status of OAuth flow using provided task_id."""
|
||||
if not task_id:
|
||||
return {"status": "not_started", "message": "OAuth flow not started"}
|
||||
return mcp_oauth_status_task(task_id)
|
||||
546
application/agents/tools/memory.py
Normal file
546
application/agents/tools/memory.py
Normal file
@@ -0,0 +1,546 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class MemoryTool(Tool):
|
||||
"""Memory
|
||||
|
||||
Stores and retrieves information across conversations through a memory file directory.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration. Should include:
|
||||
- tool_id: Unique identifier for this memory tool instance (from user_tools._id)
|
||||
This ensures each user's tool configuration has isolated memories
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
# Fallback for backward compatibility or testing
|
||||
self.tool_id = f"default_{user_id}"
|
||||
else:
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["memories"]
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of view, create, str_replace, insert, delete, rename.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: MemoryTool requires a valid user_id."
|
||||
|
||||
if action_name == "view":
|
||||
return self._view(
|
||||
kwargs.get("path", "/"),
|
||||
kwargs.get("view_range")
|
||||
)
|
||||
|
||||
if action_name == "create":
|
||||
return self._create(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("file_text", "")
|
||||
)
|
||||
|
||||
if action_name == "str_replace":
|
||||
return self._str_replace(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("old_str", ""),
|
||||
kwargs.get("new_str", "")
|
||||
)
|
||||
|
||||
if action_name == "insert":
|
||||
return self._insert(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("insert_line", 1),
|
||||
kwargs.get("insert_text", "")
|
||||
)
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete(kwargs.get("path", ""))
|
||||
|
||||
if action_name == "rename":
|
||||
return self._rename(
|
||||
kwargs.get("old_path", ""),
|
||||
kwargs.get("new_path", "")
|
||||
)
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "view",
|
||||
"description": "Shows directory contents or file contents with optional line ranges.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to file or directory (e.g., /notes.txt or /project/ or /)."
|
||||
},
|
||||
"view_range": {
|
||||
"type": "array",
|
||||
"items": {"type": "integer"},
|
||||
"description": "Optional [start_line, end_line] to view specific lines (1-indexed)."
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "create",
|
||||
"description": "Create or overwrite a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path to create (e.g., /notes.txt or /project/task.txt)."
|
||||
},
|
||||
"file_text": {
|
||||
"type": "string",
|
||||
"description": "Content to write to the file."
|
||||
}
|
||||
},
|
||||
"required": ["path", "file_text"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "str_replace",
|
||||
"description": "Replace text in a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path (e.g., /notes.txt)."
|
||||
},
|
||||
"old_str": {
|
||||
"type": "string",
|
||||
"description": "String to find."
|
||||
},
|
||||
"new_str": {
|
||||
"type": "string",
|
||||
"description": "String to replace with."
|
||||
}
|
||||
},
|
||||
"required": ["path", "old_str", "new_str"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "insert",
|
||||
"description": "Insert text at a specific line in a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path (e.g., /notes.txt)."
|
||||
},
|
||||
"insert_line": {
|
||||
"type": "integer",
|
||||
"description": "Line number to insert at (1-indexed)."
|
||||
},
|
||||
"insert_text": {
|
||||
"type": "string",
|
||||
"description": "Text to insert."
|
||||
}
|
||||
},
|
||||
"required": ["path", "insert_line", "insert_text"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete a file or directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to delete (e.g., /notes.txt or /project/)."
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "rename",
|
||||
"description": "Rename or move a file/directory.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"old_path": {
|
||||
"type": "string",
|
||||
"description": "Current path (e.g., /old.txt)."
|
||||
},
|
||||
"new_path": {
|
||||
"type": "string",
|
||||
"description": "New path (e.g., /new.txt)."
|
||||
}
|
||||
},
|
||||
"required": ["old_path", "new_path"]
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements."""
|
||||
return {}
|
||||
|
||||
# -----------------------------
|
||||
# Path validation
|
||||
# -----------------------------
|
||||
def _validate_path(self, path: str) -> Optional[str]:
|
||||
"""Validate and normalize path.
|
||||
|
||||
Args:
|
||||
path: User-provided path.
|
||||
|
||||
Returns:
|
||||
Normalized path or None if invalid.
|
||||
"""
|
||||
if not path:
|
||||
return None
|
||||
|
||||
# Remove any leading/trailing whitespace
|
||||
path = path.strip()
|
||||
|
||||
# Preserve whether path ends with / (indicates directory)
|
||||
is_directory = path.endswith("/")
|
||||
|
||||
# Ensure path starts with / for consistency
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
# Check for directory traversal patterns
|
||||
if ".." in path or path.count("//") > 0:
|
||||
return None
|
||||
|
||||
# Normalize the path
|
||||
try:
|
||||
# Convert to Path object and resolve to canonical form
|
||||
normalized = str(Path(path).as_posix())
|
||||
|
||||
# Ensure it still starts with /
|
||||
if not normalized.startswith("/"):
|
||||
return None
|
||||
|
||||
# Preserve trailing slash for directories
|
||||
if is_directory and not normalized.endswith("/") and normalized != "/":
|
||||
normalized = normalized + "/"
|
||||
|
||||
return normalized
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------
|
||||
def _view(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||
"""View directory contents or file contents."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
# Check if viewing directory (ends with / or is root)
|
||||
if validated_path == "/" or validated_path.endswith("/"):
|
||||
return self._view_directory(validated_path)
|
||||
|
||||
# Otherwise view file
|
||||
return self._view_file(validated_path, view_range)
|
||||
|
||||
def _view_directory(self, path: str) -> str:
|
||||
"""List files in a directory."""
|
||||
# Ensure path ends with / for proper prefix matching
|
||||
search_path = path if path.endswith("/") else path + "/"
|
||||
|
||||
# Find all files that start with this directory path
|
||||
query = {
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(search_path)}"}
|
||||
}
|
||||
|
||||
docs = list(self.collection.find(query, {"path": 1}))
|
||||
|
||||
if not docs:
|
||||
return f"Directory: {path}\n(empty)"
|
||||
|
||||
# Extract filenames relative to the directory
|
||||
files = []
|
||||
for doc in docs:
|
||||
file_path = doc["path"]
|
||||
# Remove the directory prefix
|
||||
if file_path.startswith(search_path):
|
||||
relative = file_path[len(search_path):]
|
||||
if relative:
|
||||
files.append(relative)
|
||||
|
||||
files.sort()
|
||||
file_list = "\n".join(f"- {f}" for f in files)
|
||||
return f"Directory: {path}\n{file_list}"
|
||||
|
||||
def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||
"""View file contents with optional line range."""
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {path}"
|
||||
|
||||
content = str(doc["content"])
|
||||
|
||||
# Apply view_range if specified
|
||||
if view_range and len(view_range) == 2:
|
||||
lines = content.split("\n")
|
||||
start, end = view_range
|
||||
# Convert to 0-indexed
|
||||
start_idx = max(0, start - 1)
|
||||
end_idx = min(len(lines), end)
|
||||
|
||||
if start_idx >= len(lines):
|
||||
return f"Error: Line range out of bounds. File has {len(lines)} lines."
|
||||
|
||||
selected_lines = lines[start_idx:end_idx]
|
||||
# Add line numbers (enumerate with 1-based start)
|
||||
numbered_lines = [f"{i}: {line}" for i, line in enumerate(selected_lines, start=start)]
|
||||
return "\n".join(numbered_lines)
|
||||
|
||||
return content
|
||||
|
||||
def _create(self, path: str, file_text: str) -> str:
|
||||
"""Create or overwrite a file."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_path == "/" or validated_path.endswith("/"):
|
||||
return "Error: Cannot create a file at directory path."
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": file_text,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
return f"File created: {validated_path}"
|
||||
|
||||
def _str_replace(self, path: str, old_str: str, new_str: str) -> str:
|
||||
"""Replace text in a file."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if not old_str:
|
||||
return "Error: old_str is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
current_content = str(doc["content"])
|
||||
|
||||
# Check if old_str exists (case-insensitive)
|
||||
if old_str.lower() not in current_content.lower():
|
||||
return f"Error: String '{old_str}' not found in file."
|
||||
|
||||
# Replace the string (case-insensitive)
|
||||
import re as regex_module
|
||||
updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": updated_content,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return f"File updated: {validated_path}"
|
||||
|
||||
def _insert(self, path: str, insert_line: int, insert_text: str) -> str:
|
||||
"""Insert text at a specific line."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if not insert_text:
|
||||
return "Error: insert_text is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
current_content = str(doc["content"])
|
||||
lines = current_content.split("\n")
|
||||
|
||||
# Convert to 0-indexed
|
||||
index = insert_line - 1
|
||||
if index < 0 or index > len(lines):
|
||||
return f"Error: Invalid line number. File has {len(lines)} lines."
|
||||
|
||||
lines.insert(index, insert_text)
|
||||
updated_content = "\n".join(lines)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": updated_content,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return f"Text inserted at line {insert_line} in {validated_path}"
|
||||
|
||||
def _delete(self, path: str) -> str:
|
||||
"""Delete a file or directory."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_path == "/":
|
||||
# Delete all files for this user and tool
|
||||
result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
return f"Deleted {result.deleted_count} file(s) from memory."
|
||||
|
||||
# Check if it's a directory (ends with /)
|
||||
if validated_path.endswith("/"):
|
||||
# Delete all files in directory
|
||||
result = self.collection.delete_many({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(validated_path)}"}
|
||||
})
|
||||
return f"Deleted directory and {result.deleted_count} file(s)."
|
||||
|
||||
# Try to delete as directory first (without trailing slash)
|
||||
# Check if any files start with this path + /
|
||||
search_path = validated_path + "/"
|
||||
directory_result = self.collection.delete_many({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(search_path)}"}
|
||||
})
|
||||
|
||||
if directory_result.deleted_count > 0:
|
||||
return f"Deleted directory and {directory_result.deleted_count} file(s)."
|
||||
|
||||
# Delete single file
|
||||
result = self.collection.delete_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_path
|
||||
})
|
||||
|
||||
if result.deleted_count:
|
||||
return f"Deleted: {validated_path}"
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
def _rename(self, old_path: str, new_path: str) -> str:
|
||||
"""Rename or move a file/directory."""
|
||||
validated_old = self._validate_path(old_path)
|
||||
validated_new = self._validate_path(new_path)
|
||||
|
||||
if not validated_old or not validated_new:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if validated_old == "/" or validated_new == "/":
|
||||
return "Error: Cannot rename root directory."
|
||||
|
||||
# Check if renaming a directory
|
||||
if validated_old.endswith("/"):
|
||||
# Ensure validated_new also ends with / for proper path replacement
|
||||
if not validated_new.endswith("/"):
|
||||
validated_new = validated_new + "/"
|
||||
|
||||
# Find all files in the old directory
|
||||
docs = list(self.collection.find({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(validated_old)}"}
|
||||
}))
|
||||
|
||||
if not docs:
|
||||
return f"Error: Directory not found: {validated_old}"
|
||||
|
||||
# Update paths for all files
|
||||
for doc in docs:
|
||||
old_file_path = doc["path"]
|
||||
new_file_path = old_file_path.replace(validated_old, validated_new, 1)
|
||||
|
||||
self.collection.update_one(
|
||||
{"_id": doc["_id"]},
|
||||
{"$set": {"path": new_file_path, "updated_at": datetime.now()}}
|
||||
)
|
||||
|
||||
return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"
|
||||
|
||||
# Rename single file
|
||||
doc = self.collection.find_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_old
|
||||
})
|
||||
|
||||
if not doc:
|
||||
return f"Error: File not found: {validated_old}"
|
||||
|
||||
# Check if new path already exists
|
||||
existing = self.collection.find_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_new
|
||||
})
|
||||
|
||||
if existing:
|
||||
return f"Error: File already exists at {validated_new}"
|
||||
|
||||
# Delete the old document and create a new one with the new path
|
||||
self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old})
|
||||
self.collection.insert_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_new,
|
||||
"content": doc.get("content", ""),
|
||||
"updated_at": datetime.now()
|
||||
})
|
||||
|
||||
return f"Renamed: {validated_old} -> {validated_new}"
|
||||
223
application/agents/tools/notes.py
Normal file
223
application/agents/tools/notes.py
Normal file
@@ -0,0 +1,223 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class NotesTool(Tool):
|
||||
"""Notepad
|
||||
|
||||
Single note. Supports viewing, overwriting, string replacement.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration. Should include:
|
||||
- tool_id: Unique identifier for this notes tool instance (from user_tools._id)
|
||||
This ensures each user's tool configuration has isolated notes
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
# Fallback for backward compatibility or testing
|
||||
self.tool_id = f"default_{user_id}"
|
||||
else:
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["notes"]
|
||||
|
||||
self._last_artifact_id: Optional[str] = None
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of view, overwrite, str_replace, insert, delete.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: NotesTool requires a valid user_id."
|
||||
|
||||
self._last_artifact_id = None
|
||||
|
||||
if action_name == "view":
|
||||
return self._get_note()
|
||||
|
||||
if action_name == "overwrite":
|
||||
return self._overwrite_note(kwargs.get("text", ""))
|
||||
|
||||
if action_name == "str_replace":
|
||||
return self._str_replace(kwargs.get("old_str", ""), kwargs.get("new_str", ""))
|
||||
|
||||
if action_name == "insert":
|
||||
return self._insert(kwargs.get("line_number", 1), kwargs.get("text", ""))
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete_note()
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "view",
|
||||
"description": "Retrieve the user's note.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "overwrite",
|
||||
"description": "Replace the entire note content (creates if doesn't exist).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "New note content."}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "str_replace",
|
||||
"description": "Replace occurrences of old_str with new_str in the note.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"old_str": {"type": "string", "description": "String to find."},
|
||||
"new_str": {"type": "string", "description": "String to replace with."}
|
||||
},
|
||||
"required": ["old_str", "new_str"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "insert",
|
||||
"description": "Insert text at the specified line number (1-indexed).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"line_number": {"type": "integer", "description": "Line number to insert at (1-indexed)."},
|
||||
"text": {"type": "string", "description": "Text to insert."}
|
||||
},
|
||||
"required": ["line_number", "text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete the user's note.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements (none for now)."""
|
||||
return {}
|
||||
|
||||
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
|
||||
return self._last_artifact_id
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers (single-note)
|
||||
# -----------------------------
|
||||
def _get_note(self) -> str:
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
return str(doc["note"])
|
||||
|
||||
def _overwrite_note(self, content: str) -> str:
|
||||
content = (content or "").strip()
|
||||
if not content:
|
||||
return "Note content required."
|
||||
result = self.collection.find_one_and_update(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
|
||||
upsert=True,
|
||||
return_document=True,
|
||||
)
|
||||
if result and result.get("_id") is not None:
|
||||
self._last_artifact_id = str(result.get("_id"))
|
||||
return "Note saved."
|
||||
|
||||
def _str_replace(self, old_str: str, new_str: str) -> str:
|
||||
if not old_str:
|
||||
return "old_str is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
|
||||
current_note = str(doc["note"])
|
||||
|
||||
# Case-insensitive search
|
||||
if old_str.lower() not in current_note.lower():
|
||||
return f"String '{old_str}' not found in note."
|
||||
|
||||
# Case-insensitive replacement
|
||||
import re
|
||||
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
|
||||
|
||||
result = self.collection.find_one_and_update(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||
return_document=True,
|
||||
)
|
||||
if result and result.get("_id") is not None:
|
||||
self._last_artifact_id = str(result.get("_id"))
|
||||
return "Note updated."
|
||||
|
||||
def _insert(self, line_number: int, text: str) -> str:
|
||||
if not text:
|
||||
return "Text is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
|
||||
current_note = str(doc["note"])
|
||||
lines = current_note.split("\n")
|
||||
|
||||
# Convert to 0-indexed and validate
|
||||
index = line_number - 1
|
||||
if index < 0 or index > len(lines):
|
||||
return f"Invalid line number. Note has {len(lines)} lines."
|
||||
|
||||
lines.insert(index, text)
|
||||
updated_note = "\n".join(lines)
|
||||
|
||||
result = self.collection.find_one_and_update(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||
return_document=True,
|
||||
)
|
||||
if result and result.get("_id") is not None:
|
||||
self._last_artifact_id = str(result.get("_id"))
|
||||
return "Text inserted."
|
||||
|
||||
def _delete_note(self) -> str:
|
||||
doc = self.collection.find_one_and_delete(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id}
|
||||
)
|
||||
if not doc:
|
||||
return "No note found to delete."
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
return "Note deleted."
|
||||
@@ -71,7 +71,7 @@ class NtfyTool(Tool):
|
||||
if self.token:
|
||||
headers["Authorization"] = f"Basic {self.token}"
|
||||
data = message.encode("utf-8")
|
||||
response = requests.post(url, headers=headers, data=data)
|
||||
response = requests.post(url, headers=headers, data=data, timeout=100)
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
@@ -116,12 +116,13 @@ class NtfyTool(Tool):
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
"""
|
||||
Specify the configuration requirements.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary describing required config parameters.
|
||||
"""
|
||||
return {
|
||||
"token": {"type": "string", "description": "Access token for authentication"},
|
||||
"token": {
|
||||
"type": "string",
|
||||
"label": "Access Token",
|
||||
"description": "Ntfy access token for authentication",
|
||||
"required": True,
|
||||
"secret": True,
|
||||
"order": 1,
|
||||
},
|
||||
}
|
||||
@@ -1,6 +1,12 @@
|
||||
import psycopg2
|
||||
import logging
|
||||
|
||||
import psycopg
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostgresTool(Tool):
|
||||
"""
|
||||
PostgreSQL Database Tool
|
||||
@@ -17,25 +23,25 @@ class PostgresTool(Tool):
|
||||
"postgres_execute_sql": self._execute_sql,
|
||||
"postgres_get_schema": self._get_schema,
|
||||
}
|
||||
|
||||
if action_name in actions:
|
||||
return actions[action_name](**kwargs)
|
||||
else:
|
||||
if action_name not in actions:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
return actions[action_name](**kwargs)
|
||||
|
||||
def _execute_sql(self, sql_query):
|
||||
"""
|
||||
Executes an SQL query against the PostgreSQL database using a connection string.
|
||||
"""
|
||||
conn = None # Initialize conn to None for error handling
|
||||
conn = None
|
||||
try:
|
||||
conn = psycopg2.connect(self.connection_string)
|
||||
conn = psycopg.connect(self.connection_string)
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql_query)
|
||||
conn.commit()
|
||||
|
||||
if sql_query.strip().lower().startswith("select"):
|
||||
column_names = [desc[0] for desc in cur.description] if cur.description else []
|
||||
column_names = (
|
||||
[desc[0] for desc in cur.description] if cur.description else []
|
||||
)
|
||||
results = []
|
||||
rows = cur.fetchall()
|
||||
for row in rows:
|
||||
@@ -43,7 +49,9 @@ class PostgresTool(Tool):
|
||||
response_data = {"data": results, "column_names": column_names}
|
||||
else:
|
||||
row_count = cur.rowcount
|
||||
response_data = {"message": f"Query executed successfully, {row_count} rows affected."}
|
||||
response_data = {
|
||||
"message": f"Query executed successfully, {row_count} rows affected."
|
||||
}
|
||||
|
||||
cur.close()
|
||||
return {
|
||||
@@ -52,28 +60,29 @@ class PostgresTool(Tool):
|
||||
"response_data": response_data,
|
||||
}
|
||||
|
||||
except psycopg2.Error as e:
|
||||
except psycopg.Error as e:
|
||||
error_message = f"Database error: {e}"
|
||||
print(f"Database error: {e}")
|
||||
logger.error("PostgreSQL execute_sql error: %s", e)
|
||||
return {
|
||||
"status_code": 500,
|
||||
"message": "Failed to execute SQL query.",
|
||||
"error": error_message,
|
||||
}
|
||||
finally:
|
||||
if conn: # Ensure connection is closed even if errors occur
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
def _get_schema(self, db_name):
|
||||
"""
|
||||
Retrieves the schema of the PostgreSQL database using a connection string.
|
||||
"""
|
||||
conn = None # Initialize conn to None for error handling
|
||||
conn = None
|
||||
try:
|
||||
conn = psycopg2.connect(self.connection_string)
|
||||
conn = psycopg.connect(self.connection_string)
|
||||
cur = conn.cursor()
|
||||
|
||||
cur.execute("""
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT
|
||||
table_name,
|
||||
column_name,
|
||||
@@ -87,19 +96,22 @@ class PostgresTool(Tool):
|
||||
ORDER BY
|
||||
table_name,
|
||||
ordinal_position;
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
schema_data = {}
|
||||
for row in cur.fetchall():
|
||||
table_name, column_name, data_type, column_default, is_nullable = row
|
||||
if table_name not in schema_data:
|
||||
schema_data[table_name] = []
|
||||
schema_data[table_name].append({
|
||||
"column_name": column_name,
|
||||
"data_type": data_type,
|
||||
"column_default": column_default,
|
||||
"is_nullable": is_nullable
|
||||
})
|
||||
schema_data[table_name].append(
|
||||
{
|
||||
"column_name": column_name,
|
||||
"data_type": data_type,
|
||||
"column_default": column_default,
|
||||
"is_nullable": is_nullable,
|
||||
}
|
||||
)
|
||||
|
||||
cur.close()
|
||||
return {
|
||||
@@ -108,16 +120,16 @@ class PostgresTool(Tool):
|
||||
"schema": schema_data,
|
||||
}
|
||||
|
||||
except psycopg2.Error as e:
|
||||
except psycopg.Error as e:
|
||||
error_message = f"Database error: {e}"
|
||||
print(f"Database error: {e}")
|
||||
logger.error("PostgreSQL get_schema error: %s", e)
|
||||
return {
|
||||
"status_code": 500,
|
||||
"message": "Failed to retrieve database schema.",
|
||||
"error": error_message,
|
||||
}
|
||||
finally:
|
||||
if conn: # Ensure connection is closed even if errors occur
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
def get_actions_metadata(self):
|
||||
@@ -158,6 +170,10 @@ class PostgresTool(Tool):
|
||||
return {
|
||||
"token": {
|
||||
"type": "string",
|
||||
"description": "PostgreSQL database connection string (e.g., 'postgresql://user:password@host:port/dbname')",
|
||||
"label": "Connection String",
|
||||
"description": "PostgreSQL database connection string",
|
||||
"required": True,
|
||||
"secret": True,
|
||||
"order": 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
84
application/agents/tools/read_webpage.py
Normal file
84
application/agents/tools/read_webpage.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import requests
|
||||
from markdownify import markdownify
|
||||
from application.agents.tools.base import Tool
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
|
||||
class ReadWebpageTool(Tool):
|
||||
"""
|
||||
Read Webpage (browser)
|
||||
A tool to fetch the HTML content of a URL and convert it to Markdown.
|
||||
"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
"""
|
||||
Initializes the tool.
|
||||
:param config: Optional configuration dictionary. Not used by this tool.
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
def execute_action(self, action_name: str, **kwargs) -> str:
|
||||
"""
|
||||
Executes the specified action. For this tool, the only action is 'read_webpage'.
|
||||
|
||||
:param action_name: The name of the action to execute. Should be 'read_webpage'.
|
||||
:param kwargs: Keyword arguments, must include 'url'.
|
||||
:return: The Markdown content of the webpage or an error message.
|
||||
"""
|
||||
if action_name != "read_webpage":
|
||||
return f"Error: Unknown action '{action_name}'. This tool only supports 'read_webpage'."
|
||||
|
||||
url = kwargs.get("url")
|
||||
if not url:
|
||||
return "Error: URL parameter is missing."
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
try:
|
||||
url = validate_url(url)
|
||||
except SSRFError as e:
|
||||
return f"Error: URL validation failed - {e}"
|
||||
|
||||
try:
|
||||
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
|
||||
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
|
||||
|
||||
html_content = response.text
|
||||
#soup = BeautifulSoup(html_content, 'html.parser')
|
||||
|
||||
|
||||
markdown_content = markdownify(html_content, heading_style="ATX", newline_style="BACKSLASH")
|
||||
|
||||
return markdown_content
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
return f"Error fetching URL {url}: {e}"
|
||||
except Exception as e:
|
||||
return f"Error processing URL {url}: {e}"
|
||||
|
||||
def get_actions_metadata(self):
|
||||
"""
|
||||
Returns metadata for the actions supported by this tool.
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"name": "read_webpage",
|
||||
"description": "Fetches the HTML content of a given URL and returns it as clean Markdown text. Input must be a valid URL.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The fully qualified URL of the webpage to read (e.g., 'https://www.example.com').",
|
||||
}
|
||||
},
|
||||
"required": ["url"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
"""
|
||||
Returns a dictionary describing the configuration requirements for the tool.
|
||||
This tool does not require any specific configuration.
|
||||
"""
|
||||
return {}
|
||||
342
application/agents/tools/spec_parser.py
Normal file
342
application/agents/tools/spec_parser.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
API Specification Parser
|
||||
|
||||
Parses OpenAPI 3.x and Swagger 2.0 specifications and converts them
|
||||
to API Tool action definitions for use in DocsGPT.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SUPPORTED_METHODS = frozenset(
|
||||
{"get", "post", "put", "delete", "patch", "head", "options"}
|
||||
)
|
||||
|
||||
|
||||
def parse_spec(spec_content: str) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Parse an API specification and convert operations to action definitions.
|
||||
|
||||
Supports OpenAPI 3.x and Swagger 2.0 formats in JSON or YAML.
|
||||
|
||||
Args:
|
||||
spec_content: Raw specification content as string
|
||||
|
||||
Returns:
|
||||
Tuple of (metadata dict, list of action dicts)
|
||||
|
||||
Raises:
|
||||
ValueError: If the spec is invalid or uses an unsupported format
|
||||
"""
|
||||
spec = _load_spec(spec_content)
|
||||
_validate_spec(spec)
|
||||
|
||||
is_swagger = "swagger" in spec
|
||||
metadata = _extract_metadata(spec, is_swagger)
|
||||
actions = _extract_actions(spec, is_swagger)
|
||||
|
||||
return metadata, actions
|
||||
|
||||
|
||||
def _load_spec(content: str) -> Dict[str, Any]:
|
||||
"""Parse spec content from JSON or YAML string."""
|
||||
content = content.strip()
|
||||
if not content:
|
||||
raise ValueError("Empty specification content")
|
||||
try:
|
||||
if content.startswith("{"):
|
||||
return json.loads(content)
|
||||
return yaml.safe_load(content)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON format: {e.msg}")
|
||||
except yaml.YAMLError as e:
|
||||
raise ValueError(f"Invalid YAML format: {e}")
|
||||
|
||||
|
||||
def _validate_spec(spec: Dict[str, Any]) -> None:
|
||||
"""Validate spec version and required fields."""
|
||||
if not isinstance(spec, dict):
|
||||
raise ValueError("Specification must be a valid object")
|
||||
openapi_version = spec.get("openapi", "")
|
||||
swagger_version = spec.get("swagger", "")
|
||||
|
||||
if not (openapi_version.startswith("3.") or swagger_version == "2.0"):
|
||||
raise ValueError(
|
||||
"Unsupported specification version. Expected OpenAPI 3.x or Swagger 2.0"
|
||||
)
|
||||
if "paths" not in spec or not spec["paths"]:
|
||||
raise ValueError("No API paths defined in the specification")
|
||||
|
||||
|
||||
def _extract_metadata(spec: Dict[str, Any], is_swagger: bool) -> Dict[str, Any]:
|
||||
"""Extract API metadata from specification."""
|
||||
info = spec.get("info", {})
|
||||
base_url = _get_base_url(spec, is_swagger)
|
||||
|
||||
return {
|
||||
"title": info.get("title", "Untitled API"),
|
||||
"description": (info.get("description", "") or "")[:500],
|
||||
"version": info.get("version", ""),
|
||||
"base_url": base_url,
|
||||
}
|
||||
|
||||
|
||||
def _get_base_url(spec: Dict[str, Any], is_swagger: bool) -> str:
|
||||
"""Extract base URL from spec (handles both OpenAPI 3.x and Swagger 2.0)."""
|
||||
if is_swagger:
|
||||
schemes = spec.get("schemes", ["https"])
|
||||
host = spec.get("host", "")
|
||||
base_path = spec.get("basePath", "")
|
||||
if host:
|
||||
scheme = schemes[0] if schemes else "https"
|
||||
return f"{scheme}://{host}{base_path}".rstrip("/")
|
||||
return ""
|
||||
servers = spec.get("servers", [])
|
||||
if servers and isinstance(servers, list) and servers[0].get("url"):
|
||||
return servers[0]["url"].rstrip("/")
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_actions(spec: Dict[str, Any], is_swagger: bool) -> List[Dict[str, Any]]:
|
||||
"""Extract all API operations as action definitions."""
|
||||
actions = []
|
||||
paths = spec.get("paths", {})
|
||||
base_url = _get_base_url(spec, is_swagger)
|
||||
|
||||
components = spec.get("components", {})
|
||||
definitions = spec.get("definitions", {})
|
||||
|
||||
for path, path_item in paths.items():
|
||||
if not isinstance(path_item, dict):
|
||||
continue
|
||||
path_params = path_item.get("parameters", [])
|
||||
|
||||
for method in SUPPORTED_METHODS:
|
||||
operation = path_item.get(method)
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
try:
|
||||
action = _build_action(
|
||||
path=path,
|
||||
method=method,
|
||||
operation=operation,
|
||||
path_params=path_params,
|
||||
base_url=base_url,
|
||||
components=components,
|
||||
definitions=definitions,
|
||||
is_swagger=is_swagger,
|
||||
)
|
||||
actions.append(action)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to parse operation {method.upper()} {path}: {e}"
|
||||
)
|
||||
continue
|
||||
return actions
|
||||
|
||||
|
||||
def _build_action(
|
||||
path: str,
|
||||
method: str,
|
||||
operation: Dict[str, Any],
|
||||
path_params: List[Dict],
|
||||
base_url: str,
|
||||
components: Dict[str, Any],
|
||||
definitions: Dict[str, Any],
|
||||
is_swagger: bool,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build a single action from an API operation."""
|
||||
action_name = _generate_action_name(operation, method, path)
|
||||
full_url = f"{base_url}{path}" if base_url else path
|
||||
|
||||
all_params = path_params + operation.get("parameters", [])
|
||||
query_params, headers = _categorize_parameters(all_params, components, definitions)
|
||||
|
||||
body, body_content_type = _extract_request_body(
|
||||
operation, components, definitions, is_swagger
|
||||
)
|
||||
|
||||
description = operation.get("summary", "") or operation.get("description", "")
|
||||
|
||||
return {
|
||||
"name": action_name,
|
||||
"url": full_url,
|
||||
"method": method.upper(),
|
||||
"description": (description or "")[:500],
|
||||
"query_params": {"type": "object", "properties": query_params},
|
||||
"headers": {"type": "object", "properties": headers},
|
||||
"body": {"type": "object", "properties": body},
|
||||
"body_content_type": body_content_type,
|
||||
"active": True,
|
||||
}
|
||||
|
||||
|
||||
def _generate_action_name(operation: Dict[str, Any], method: str, path: str) -> str:
|
||||
"""Generate a valid action name from operationId or method+path."""
|
||||
if operation.get("operationId"):
|
||||
name = operation["operationId"]
|
||||
else:
|
||||
path_slug = re.sub(r"[{}]", "", path)
|
||||
path_slug = re.sub(r"[^a-zA-Z0-9]", "_", path_slug)
|
||||
path_slug = re.sub(r"_+", "_", path_slug).strip("_")
|
||||
name = f"{method}_{path_slug}"
|
||||
name = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
|
||||
return name[:64]
|
||||
|
||||
|
||||
def _categorize_parameters(
|
||||
parameters: List[Dict],
|
||||
components: Dict[str, Any],
|
||||
definitions: Dict[str, Any],
|
||||
) -> Tuple[Dict, Dict]:
|
||||
"""Categorize parameters into query params and headers."""
|
||||
query_params = {}
|
||||
headers = {}
|
||||
|
||||
for param in parameters:
|
||||
resolved = _resolve_ref(param, components, definitions)
|
||||
if not resolved or "name" not in resolved:
|
||||
continue
|
||||
location = resolved.get("in", "query")
|
||||
prop = _param_to_property(resolved)
|
||||
|
||||
if location in ("query", "path"):
|
||||
query_params[resolved["name"]] = prop
|
||||
elif location == "header":
|
||||
headers[resolved["name"]] = prop
|
||||
return query_params, headers
|
||||
|
||||
|
||||
def _param_to_property(param: Dict) -> Dict[str, Any]:
|
||||
"""Convert an API parameter to an action property definition."""
|
||||
schema = param.get("schema", {})
|
||||
param_type = schema.get("type", param.get("type", "string"))
|
||||
|
||||
mapped_type = "integer" if param_type in ("integer", "number") else "string"
|
||||
|
||||
return {
|
||||
"type": mapped_type,
|
||||
"description": (param.get("description", "") or "")[:200],
|
||||
"value": "",
|
||||
"filled_by_llm": param.get("required", False),
|
||||
"required": param.get("required", False),
|
||||
}
|
||||
|
||||
|
||||
def _extract_request_body(
|
||||
operation: Dict[str, Any],
|
||||
components: Dict[str, Any],
|
||||
definitions: Dict[str, Any],
|
||||
is_swagger: bool,
|
||||
) -> Tuple[Dict, str]:
|
||||
"""Extract request body schema and content type."""
|
||||
content_types = [
|
||||
"application/json",
|
||||
"application/x-www-form-urlencoded",
|
||||
"multipart/form-data",
|
||||
"text/plain",
|
||||
"application/xml",
|
||||
]
|
||||
|
||||
if is_swagger:
|
||||
consumes = operation.get("consumes", [])
|
||||
body_param = next(
|
||||
(p for p in operation.get("parameters", []) if p.get("in") == "body"), None
|
||||
)
|
||||
if not body_param:
|
||||
return {}, "application/json"
|
||||
selected_type = consumes[0] if consumes else "application/json"
|
||||
schema = body_param.get("schema", {})
|
||||
else:
|
||||
request_body = operation.get("requestBody", {})
|
||||
if not request_body:
|
||||
return {}, "application/json"
|
||||
request_body = _resolve_ref(request_body, components, definitions)
|
||||
content = request_body.get("content", {})
|
||||
|
||||
selected_type = "application/json"
|
||||
schema = {}
|
||||
|
||||
for ct in content_types:
|
||||
if ct in content:
|
||||
selected_type = ct
|
||||
schema = content[ct].get("schema", {})
|
||||
break
|
||||
if not schema and content:
|
||||
first_type = next(iter(content))
|
||||
selected_type = first_type
|
||||
schema = content[first_type].get("schema", {})
|
||||
properties = _schema_to_properties(schema, components, definitions)
|
||||
return properties, selected_type
|
||||
|
||||
|
||||
def _schema_to_properties(
|
||||
schema: Dict,
|
||||
components: Dict[str, Any],
|
||||
definitions: Dict[str, Any],
|
||||
depth: int = 0,
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert schema to action body properties (limited depth to prevent recursion)."""
|
||||
if depth > 3:
|
||||
return {}
|
||||
schema = _resolve_ref(schema, components, definitions)
|
||||
if not schema or not isinstance(schema, dict):
|
||||
return {}
|
||||
properties = {}
|
||||
schema_type = schema.get("type", "object")
|
||||
|
||||
if schema_type == "object":
|
||||
required_fields = set(schema.get("required", []))
|
||||
for prop_name, prop_schema in schema.get("properties", {}).items():
|
||||
resolved = _resolve_ref(prop_schema, components, definitions)
|
||||
if not isinstance(resolved, dict):
|
||||
continue
|
||||
prop_type = resolved.get("type", "string")
|
||||
mapped_type = "integer" if prop_type in ("integer", "number") else "string"
|
||||
|
||||
properties[prop_name] = {
|
||||
"type": mapped_type,
|
||||
"description": (resolved.get("description", "") or "")[:200],
|
||||
"value": "",
|
||||
"filled_by_llm": prop_name in required_fields,
|
||||
"required": prop_name in required_fields,
|
||||
}
|
||||
return properties
|
||||
|
||||
|
||||
def _resolve_ref(
|
||||
obj: Any,
|
||||
components: Dict[str, Any],
|
||||
definitions: Dict[str, Any],
|
||||
) -> Optional[Dict]:
|
||||
"""Resolve $ref references in the specification."""
|
||||
if not isinstance(obj, dict):
|
||||
return obj if isinstance(obj, dict) else None
|
||||
if "$ref" not in obj:
|
||||
return obj
|
||||
ref_path = obj["$ref"]
|
||||
|
||||
if ref_path.startswith("#/components/"):
|
||||
parts = ref_path.replace("#/components/", "").split("/")
|
||||
return _traverse_path(components, parts)
|
||||
elif ref_path.startswith("#/definitions/"):
|
||||
parts = ref_path.replace("#/definitions/", "").split("/")
|
||||
return _traverse_path(definitions, parts)
|
||||
logger.debug(f"Unsupported ref path: {ref_path}")
|
||||
return None
|
||||
|
||||
|
||||
def _traverse_path(obj: Dict, parts: List[str]) -> Optional[Dict]:
|
||||
"""Traverse a nested dictionary using path parts."""
|
||||
try:
|
||||
for part in parts:
|
||||
obj = obj[part]
|
||||
return obj if isinstance(obj, dict) else None
|
||||
except (KeyError, TypeError):
|
||||
return None
|
||||
@@ -1,6 +1,11 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TelegramTool(Tool):
|
||||
"""
|
||||
@@ -18,24 +23,22 @@ class TelegramTool(Tool):
|
||||
"telegram_send_message": self._send_message,
|
||||
"telegram_send_image": self._send_image,
|
||||
}
|
||||
|
||||
if action_name in actions:
|
||||
return actions[action_name](**kwargs)
|
||||
else:
|
||||
if action_name not in actions:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
return actions[action_name](**kwargs)
|
||||
|
||||
def _send_message(self, text, chat_id):
|
||||
print(f"Sending message: {text}")
|
||||
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
|
||||
payload = {"chat_id": chat_id, "text": text}
|
||||
response = requests.post(url, data=payload)
|
||||
response = requests.post(url, data=payload, timeout=100)
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def _send_image(self, image_url, chat_id):
|
||||
print(f"Sending image: {image_url}")
|
||||
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
|
||||
payload = {"chat_id": chat_id, "photo": image_url}
|
||||
response = requests.post(url, data=payload)
|
||||
response = requests.post(url, data=payload, timeout=100)
|
||||
return {"status_code": response.status_code, "message": "Image sent"}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
@@ -82,5 +85,12 @@ class TelegramTool(Tool):
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {
|
||||
"token": {"type": "string", "description": "Bot token for authentication"},
|
||||
"token": {
|
||||
"type": "string",
|
||||
"label": "Bot Token",
|
||||
"description": "Telegram bot token for authentication",
|
||||
"required": True,
|
||||
"secret": True,
|
||||
"order": 1,
|
||||
},
|
||||
}
|
||||
|
||||
70
application/agents/tools/think.py
Normal file
70
application/agents/tools/think.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
|
||||
THINK_TOOL_ID = "think"
|
||||
|
||||
THINK_TOOL_ENTRY = {
|
||||
"name": "think",
|
||||
"actions": [
|
||||
{
|
||||
"name": "reason",
|
||||
"description": (
|
||||
"Use this tool to think through your reasoning step by step "
|
||||
"before deciding on your next action. Always reason before "
|
||||
"searching or answering."
|
||||
),
|
||||
"active": True,
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": "Your step-by-step reasoning and analysis",
|
||||
"filled_by_llm": True,
|
||||
"required": True,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class ThinkTool(Tool):
|
||||
"""Pseudo-tool that captures chain-of-thought reasoning.
|
||||
|
||||
Returns a short acknowledgment so the LLM can continue.
|
||||
The reasoning content is captured in tool_call data for transparency.
|
||||
"""
|
||||
|
||||
internal = True
|
||||
|
||||
def __init__(self, config=None):
|
||||
pass
|
||||
|
||||
def execute_action(self, action_name: str, **kwargs):
|
||||
return "Continue."
|
||||
|
||||
def get_actions_metadata(self):
|
||||
return [
|
||||
{
|
||||
"name": "reason",
|
||||
"description": (
|
||||
"Use this tool to think through your reasoning step by step "
|
||||
"before deciding on your next action. Always reason before "
|
||||
"searching or answering."
|
||||
),
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": "Your step-by-step reasoning and analysis",
|
||||
"filled_by_llm": True,
|
||||
"required": True,
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {}
|
||||
333
application/agents/tools/todo_list.py
Normal file
333
application/agents/tools/todo_list.py
Normal file
@@ -0,0 +1,333 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class TodoListTool(Tool):
|
||||
"""Todo List
|
||||
|
||||
Manages todo items for users. Supports creating, viewing, updating, and deleting todos.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||
"""Initialize the tool.
|
||||
|
||||
Args:
|
||||
tool_config: Optional tool configuration. Should include:
|
||||
- tool_id: Unique identifier for this todo list tool instance (from user_tools._id)
|
||||
This ensures each user's tool configuration has isolated todos
|
||||
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||
"""
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
# Fallback for backward compatibility or testing
|
||||
self.tool_id = f"default_{user_id}"
|
||||
else:
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["todos"]
|
||||
|
||||
self._last_artifact_id: Optional[str] = None
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of list, create, get, update, complete, delete.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: TodoListTool requires a valid user_id."
|
||||
|
||||
self._last_artifact_id = None
|
||||
|
||||
if action_name == "list":
|
||||
return self._list()
|
||||
|
||||
if action_name == "create":
|
||||
return self._create(kwargs.get("title", ""))
|
||||
|
||||
if action_name == "get":
|
||||
return self._get(kwargs.get("todo_id"))
|
||||
|
||||
if action_name == "update":
|
||||
return self._update(
|
||||
kwargs.get("todo_id"),
|
||||
kwargs.get("title", "")
|
||||
)
|
||||
|
||||
if action_name == "complete":
|
||||
return self._complete(kwargs.get("todo_id"))
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete(kwargs.get("todo_id"))
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||
return [
|
||||
{
|
||||
"name": "list",
|
||||
"description": "List all todos for the user.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
{
|
||||
"name": "create",
|
||||
"description": "Create a new todo item.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "Title of the todo item."
|
||||
}
|
||||
},
|
||||
"required": ["title"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "get",
|
||||
"description": "Get a specific todo by ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to retrieve."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "update",
|
||||
"description": "Update a todo's title by ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to update."
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "The new title for the todo."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id", "title"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "complete",
|
||||
"description": "Mark a todo as completed.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to mark as completed."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "delete",
|
||||
"description": "Delete a specific todo by ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todo_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the todo to delete."
|
||||
}
|
||||
},
|
||||
"required": ["todo_id"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
|
||||
"""Return configuration requirements."""
|
||||
return {}
|
||||
|
||||
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
|
||||
return self._last_artifact_id
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------
|
||||
def _coerce_todo_id(self, value: Optional[Any]) -> Optional[int]:
|
||||
"""Convert todo identifiers to sequential integers."""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, int):
|
||||
return value if value > 0 else None
|
||||
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
if stripped.isdigit():
|
||||
numeric_value = int(stripped)
|
||||
return numeric_value if numeric_value > 0 else None
|
||||
|
||||
return None
|
||||
|
||||
def _get_next_todo_id(self) -> int:
|
||||
"""Get the next sequential todo_id for this user and tool.
|
||||
|
||||
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
|
||||
With 5-10 todos max, scanning is negligible.
|
||||
"""
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id}
|
||||
todos = list(self.collection.find(query, {"todo_id": 1}))
|
||||
|
||||
# Find the maximum todo_id
|
||||
max_id = 0
|
||||
for todo in todos:
|
||||
todo_id = self._coerce_todo_id(todo.get("todo_id"))
|
||||
if todo_id is not None:
|
||||
max_id = max(max_id, todo_id)
|
||||
|
||||
return max_id + 1
|
||||
|
||||
def _list(self) -> str:
|
||||
"""List all todos for the user."""
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id}
|
||||
todos = list(self.collection.find(query))
|
||||
|
||||
if not todos:
|
||||
return "No todos found."
|
||||
|
||||
result_lines = ["Todos:"]
|
||||
for doc in todos:
|
||||
todo_id = doc.get("todo_id")
|
||||
title = doc.get("title", "Untitled")
|
||||
status = doc.get("status", "open")
|
||||
|
||||
line = f"[{todo_id}] {title} ({status})"
|
||||
result_lines.append(line)
|
||||
|
||||
return "\n".join(result_lines)
|
||||
|
||||
def _create(self, title: str) -> str:
|
||||
"""Create a new todo item."""
|
||||
title = (title or "").strip()
|
||||
if not title:
|
||||
return "Error: Title is required."
|
||||
|
||||
now = datetime.now()
|
||||
todo_id = self._get_next_todo_id()
|
||||
|
||||
doc = {
|
||||
"todo_id": todo_id,
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"title": title,
|
||||
"status": "open",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
insert_result = self.collection.insert_one(doc)
|
||||
inserted_id = getattr(insert_result, "inserted_id", None) or doc.get("_id")
|
||||
if inserted_id is not None:
|
||||
self._last_artifact_id = str(inserted_id)
|
||||
return f"Todo created with ID {todo_id}: {title}"
|
||||
|
||||
def _get(self, todo_id: Optional[Any]) -> str:
|
||||
"""Get a specific todo by ID."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one(query)
|
||||
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
|
||||
title = doc.get("title", "Untitled")
|
||||
status = doc.get("status", "open")
|
||||
|
||||
result = f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
|
||||
|
||||
return result
|
||||
|
||||
def _update(self, todo_id: Optional[Any], title: str) -> str:
|
||||
"""Update a todo's title by ID."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
title = (title or "").strip()
|
||||
if not title:
|
||||
return "Error: Title is required."
|
||||
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one_and_update(
|
||||
query,
|
||||
{"$set": {"title": title, "updated_at": datetime.now()}},
|
||||
)
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
|
||||
return f"Todo {parsed_todo_id} updated to: {title}"
|
||||
|
||||
def _complete(self, todo_id: Optional[Any]) -> str:
|
||||
"""Mark a todo as completed."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one_and_update(
|
||||
query,
|
||||
{"$set": {"status": "completed", "updated_at": datetime.now()}},
|
||||
)
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
|
||||
return f"Todo {parsed_todo_id} marked as completed."
|
||||
|
||||
def _delete(self, todo_id: Optional[Any]) -> str:
|
||||
"""Delete a specific todo by ID."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one_and_delete(query)
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
|
||||
return f"Todo {parsed_todo_id} deleted."
|
||||
@@ -5,8 +5,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolActionParser:
|
||||
def __init__(self, llm_type):
|
||||
def __init__(self, llm_type, name_mapping=None):
|
||||
self.llm_type = llm_type
|
||||
self.name_mapping = name_mapping
|
||||
self.parsers = {
|
||||
"OpenAILLM": self._parse_openai_llm,
|
||||
"GoogleLLM": self._parse_google_llm,
|
||||
@@ -16,27 +17,70 @@ class ToolActionParser:
|
||||
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
|
||||
return parser(call)
|
||||
|
||||
def _resolve_via_mapping(self, call_name):
|
||||
"""Look up (tool_id, action_name) from the name mapping if available."""
|
||||
if self.name_mapping and call_name in self.name_mapping:
|
||||
return self.name_mapping[call_name]
|
||||
return None
|
||||
|
||||
def _parse_openai_llm(self, call):
|
||||
if isinstance(call, dict):
|
||||
try:
|
||||
call_args = json.loads(call["function"]["arguments"])
|
||||
tool_id = call["function"]["name"].split("_")[-1]
|
||||
action_name = call["function"]["name"].rsplit("_", 1)[0]
|
||||
except (KeyError, TypeError) as e:
|
||||
logger.error(f"Error parsing OpenAI LLM call: {e}")
|
||||
return None, None, None
|
||||
else:
|
||||
try:
|
||||
call_args = json.loads(call.function.arguments)
|
||||
tool_id = call.function.name.split("_")[-1]
|
||||
action_name = call.function.name.rsplit("_", 1)[0]
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.error(f"Error parsing OpenAI LLM call: {e}")
|
||||
try:
|
||||
call_args = json.loads(call.arguments)
|
||||
|
||||
resolved = self._resolve_via_mapping(call.name)
|
||||
if resolved:
|
||||
return resolved[0], resolved[1], call_args
|
||||
|
||||
# Fallback: legacy split on "_" for backward compatibility
|
||||
tool_parts = call.name.split("_")
|
||||
|
||||
if len(tool_parts) < 2:
|
||||
logger.warning(
|
||||
f"Invalid tool name format: {call.name}. "
|
||||
"Could not resolve via mapping or legacy parsing."
|
||||
)
|
||||
return None, None, None
|
||||
|
||||
tool_id = tool_parts[-1]
|
||||
action_name = "_".join(tool_parts[:-1])
|
||||
|
||||
if not tool_id.isdigit():
|
||||
logger.warning(
|
||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||
)
|
||||
|
||||
except (AttributeError, TypeError, json.JSONDecodeError) as e:
|
||||
logger.error(f"Error parsing OpenAI LLM call: {e}")
|
||||
return None, None, None
|
||||
return tool_id, action_name, call_args
|
||||
|
||||
def _parse_google_llm(self, call):
|
||||
call_args = call.args
|
||||
tool_id = call.name.split("_")[-1]
|
||||
action_name = call.name.rsplit("_", 1)[0]
|
||||
try:
|
||||
call_args = call.arguments
|
||||
|
||||
resolved = self._resolve_via_mapping(call.name)
|
||||
if resolved:
|
||||
return resolved[0], resolved[1], call_args
|
||||
|
||||
# Fallback: legacy split on "_" for backward compatibility
|
||||
tool_parts = call.name.split("_")
|
||||
|
||||
if len(tool_parts) < 2:
|
||||
logger.warning(
|
||||
f"Invalid tool name format: {call.name}. "
|
||||
"Could not resolve via mapping or legacy parsing."
|
||||
)
|
||||
return None, None, None
|
||||
|
||||
tool_id = tool_parts[-1]
|
||||
action_name = "_".join(tool_parts[:-1])
|
||||
|
||||
if not tool_id.isdigit():
|
||||
logger.warning(
|
||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||
)
|
||||
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.error(f"Error parsing Google LLM call: {e}")
|
||||
return None, None, None
|
||||
return tool_id, action_name, call_args
|
||||
|
||||
@@ -19,20 +19,27 @@ class ToolManager:
|
||||
continue
|
||||
module = importlib.import_module(f"application.agents.tools.{name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
if issubclass(obj, Tool) and obj is not Tool and not obj.internal:
|
||||
tool_config = self.config.get(name, {})
|
||||
self.tools[name] = obj(tool_config)
|
||||
|
||||
def load_tool(self, tool_name, tool_config):
|
||||
def load_tool(self, tool_name, tool_config, user_id=None):
|
||||
self.config[tool_name] = tool_config
|
||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
return obj(tool_config)
|
||||
if tool_name in {"mcp_tool", "notes", "memory", "todo_list"} and user_id:
|
||||
return obj(tool_config, user_id)
|
||||
else:
|
||||
return obj(tool_config)
|
||||
|
||||
def execute_action(self, tool_name, action_name, **kwargs):
|
||||
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
if tool_name in {"mcp_tool", "memory", "todo_list", "notes"} and user_id:
|
||||
tool_config = self.config.get(tool_name, {})
|
||||
tool = self.load_tool(tool_name, tool_config, user_id)
|
||||
return tool.execute_action(action_name, **kwargs)
|
||||
return self.tools[tool_name].execute_action(action_name, **kwargs)
|
||||
|
||||
def get_all_actions_metadata(self):
|
||||
|
||||
265
application/agents/workflow_agent.py
Normal file
265
application/agents/workflow_agent.py
Normal file
@@ -0,0 +1,265 @@
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.workflows.schemas import (
|
||||
ExecutionStatus,
|
||||
Workflow,
|
||||
WorkflowEdge,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
WorkflowRun,
|
||||
)
|
||||
from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.logging import log_activity, LogContext
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
|
||||
from application.storage.db.repositories.workflows import WorkflowsRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAgent(BaseAgent):
|
||||
"""A specialized agent that executes predefined workflows."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
workflow_id: Optional[str] = None,
|
||||
workflow: Optional[Dict[str, Any]] = None,
|
||||
workflow_owner: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.workflow_id = workflow_id
|
||||
self.workflow_owner = workflow_owner
|
||||
self._workflow_data = workflow
|
||||
self._engine: Optional[WorkflowEngine] = None
|
||||
|
||||
@log_activity()
|
||||
def gen(
|
||||
self, query: str, log_context: LogContext = None
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
yield from self._gen_inner(query, log_context)
|
||||
|
||||
def _gen_inner(
|
||||
self, query: str, log_context: LogContext
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
graph = self._load_workflow_graph()
|
||||
if not graph:
|
||||
yield {"type": "error", "error": "Failed to load workflow configuration."}
|
||||
return
|
||||
self._engine = WorkflowEngine(graph, self)
|
||||
yield from self._engine.execute({}, query)
|
||||
self._save_workflow_run(query)
|
||||
|
||||
def _load_workflow_graph(self) -> Optional[WorkflowGraph]:
|
||||
if self._workflow_data:
|
||||
return self._parse_embedded_workflow()
|
||||
if self.workflow_id:
|
||||
return self._load_from_database()
|
||||
return None
|
||||
|
||||
def _parse_embedded_workflow(self) -> Optional[WorkflowGraph]:
|
||||
try:
|
||||
nodes_data = self._workflow_data.get("nodes", [])
|
||||
edges_data = self._workflow_data.get("edges", [])
|
||||
|
||||
workflow = Workflow(
|
||||
name=self._workflow_data.get("name", "Embedded Workflow"),
|
||||
description=self._workflow_data.get("description"),
|
||||
)
|
||||
|
||||
nodes = []
|
||||
for n in nodes_data:
|
||||
node_config = n.get("data", {})
|
||||
nodes.append(
|
||||
WorkflowNode(
|
||||
id=n["id"],
|
||||
workflow_id=self.workflow_id or "embedded",
|
||||
type=n["type"],
|
||||
title=n.get("title", "Node"),
|
||||
description=n.get("description"),
|
||||
position=n.get("position", {"x": 0, "y": 0}),
|
||||
config=node_config,
|
||||
)
|
||||
)
|
||||
edges = []
|
||||
for e in edges_data:
|
||||
edges.append(
|
||||
WorkflowEdge(
|
||||
id=e["id"],
|
||||
workflow_id=self.workflow_id or "embedded",
|
||||
source=e.get("source") or e.get("source_id"),
|
||||
target=e.get("target") or e.get("target_id"),
|
||||
sourceHandle=e.get("sourceHandle") or e.get("source_handle"),
|
||||
targetHandle=e.get("targetHandle") or e.get("target_handle"),
|
||||
)
|
||||
)
|
||||
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid embedded workflow: {e}")
|
||||
return None
|
||||
|
||||
def _load_from_database(self) -> Optional[WorkflowGraph]:
|
||||
try:
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
if not self.workflow_id or not ObjectId.is_valid(self.workflow_id):
|
||||
logger.error(f"Invalid workflow ID: {self.workflow_id}")
|
||||
return None
|
||||
owner_id = self.workflow_owner
|
||||
if not owner_id and isinstance(self.decoded_token, dict):
|
||||
owner_id = self.decoded_token.get("sub")
|
||||
if not owner_id:
|
||||
logger.error(
|
||||
f"Workflow owner not available for workflow load: {self.workflow_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
workflows_coll = db["workflows"]
|
||||
workflow_nodes_coll = db["workflow_nodes"]
|
||||
workflow_edges_coll = db["workflow_edges"]
|
||||
|
||||
workflow_doc = workflows_coll.find_one(
|
||||
{"_id": ObjectId(self.workflow_id), "user": owner_id}
|
||||
)
|
||||
if not workflow_doc:
|
||||
logger.error(
|
||||
f"Workflow {self.workflow_id} not found or inaccessible for user {owner_id}"
|
||||
)
|
||||
return None
|
||||
workflow = Workflow(**workflow_doc)
|
||||
graph_version = workflow_doc.get("current_graph_version", 1)
|
||||
try:
|
||||
graph_version = int(graph_version)
|
||||
if graph_version <= 0:
|
||||
graph_version = 1
|
||||
except (ValueError, TypeError):
|
||||
graph_version = 1
|
||||
|
||||
nodes_docs = list(
|
||||
workflow_nodes_coll.find(
|
||||
{"workflow_id": self.workflow_id, "graph_version": graph_version}
|
||||
)
|
||||
)
|
||||
if not nodes_docs and graph_version == 1:
|
||||
nodes_docs = list(
|
||||
workflow_nodes_coll.find(
|
||||
{
|
||||
"workflow_id": self.workflow_id,
|
||||
"graph_version": {"$exists": False},
|
||||
}
|
||||
)
|
||||
)
|
||||
nodes = [WorkflowNode(**doc) for doc in nodes_docs]
|
||||
|
||||
edges_docs = list(
|
||||
workflow_edges_coll.find(
|
||||
{"workflow_id": self.workflow_id, "graph_version": graph_version}
|
||||
)
|
||||
)
|
||||
if not edges_docs and graph_version == 1:
|
||||
edges_docs = list(
|
||||
workflow_edges_coll.find(
|
||||
{
|
||||
"workflow_id": self.workflow_id,
|
||||
"graph_version": {"$exists": False},
|
||||
}
|
||||
)
|
||||
)
|
||||
edges = [WorkflowEdge(**doc) for doc in edges_docs]
|
||||
|
||||
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load workflow from database: {e}")
|
||||
return None
|
||||
|
||||
def _save_workflow_run(self, query: str) -> None:
|
||||
if not self._engine:
|
||||
return
|
||||
owner_id = self.workflow_owner
|
||||
if not owner_id and isinstance(self.decoded_token, dict):
|
||||
owner_id = self.decoded_token.get("sub")
|
||||
try:
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
workflow_runs_coll = db["workflow_runs"]
|
||||
|
||||
run = WorkflowRun(
|
||||
workflow_id=self.workflow_id or "unknown",
|
||||
user=owner_id,
|
||||
status=self._determine_run_status(),
|
||||
inputs={"query": query},
|
||||
outputs=self._serialize_state(self._engine.state),
|
||||
steps=self._engine.get_execution_summary(),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
result = workflow_runs_coll.insert_one(run.to_mongo_doc())
|
||||
legacy_mongo_id = (
|
||||
str(result.inserted_id)
|
||||
if getattr(result, "inserted_id", None) is not None
|
||||
else None
|
||||
)
|
||||
|
||||
def _pg_write(repo: WorkflowRunsRepository) -> None:
|
||||
if not self.workflow_id or not owner_id or not legacy_mongo_id:
|
||||
return
|
||||
workflow = WorkflowsRepository(repo._conn).get_by_legacy_id(
|
||||
self.workflow_id, owner_id,
|
||||
)
|
||||
if workflow is None:
|
||||
return
|
||||
repo.create(
|
||||
workflow["id"],
|
||||
owner_id,
|
||||
run.status.value,
|
||||
inputs=run.inputs,
|
||||
result=run.outputs,
|
||||
steps=[step.model_dump(mode="json") for step in run.steps],
|
||||
started_at=run.created_at,
|
||||
ended_at=run.completed_at,
|
||||
legacy_mongo_id=legacy_mongo_id,
|
||||
)
|
||||
|
||||
dual_write(WorkflowRunsRepository, _pg_write)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save workflow run: {e}")
|
||||
|
||||
def _determine_run_status(self) -> ExecutionStatus:
|
||||
if not self._engine or not self._engine.execution_log:
|
||||
return ExecutionStatus.COMPLETED
|
||||
for log in self._engine.execution_log:
|
||||
if log.get("status") == ExecutionStatus.FAILED.value:
|
||||
return ExecutionStatus.FAILED
|
||||
return ExecutionStatus.COMPLETED
|
||||
|
||||
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
serialized: Dict[str, Any] = {}
|
||||
for key, value in state.items():
|
||||
serialized[key] = self._serialize_state_value(value)
|
||||
return serialized
|
||||
|
||||
def _serialize_state_value(self, value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
str(dict_key): self._serialize_state_value(dict_value)
|
||||
for dict_key, dict_value in value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [self._serialize_state_value(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return [self._serialize_state_value(item) for item in value]
|
||||
if isinstance(value, datetime):
|
||||
return value.isoformat()
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
return value
|
||||
return str(value)
|
||||
64
application/agents/workflows/cel_evaluator.py
Normal file
64
application/agents/workflows/cel_evaluator.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import celpy
|
||||
import celpy.celtypes
|
||||
|
||||
|
||||
class CelEvaluationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _convert_value(value: Any) -> Any:
|
||||
if isinstance(value, bool):
|
||||
return celpy.celtypes.BoolType(value)
|
||||
if isinstance(value, int):
|
||||
return celpy.celtypes.IntType(value)
|
||||
if isinstance(value, float):
|
||||
return celpy.celtypes.DoubleType(value)
|
||||
if isinstance(value, str):
|
||||
return celpy.celtypes.StringType(value)
|
||||
if isinstance(value, list):
|
||||
return celpy.celtypes.ListType([_convert_value(item) for item in value])
|
||||
if isinstance(value, dict):
|
||||
return celpy.celtypes.MapType(
|
||||
{celpy.celtypes.StringType(k): _convert_value(v) for k, v in value.items()}
|
||||
)
|
||||
if value is None:
|
||||
return celpy.celtypes.BoolType(False)
|
||||
return celpy.celtypes.StringType(str(value))
|
||||
|
||||
|
||||
def build_activation(state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {k: _convert_value(v) for k, v in state.items()}
|
||||
|
||||
|
||||
def evaluate_cel(expression: str, state: Dict[str, Any]) -> Any:
|
||||
if not expression or not expression.strip():
|
||||
raise CelEvaluationError("Empty expression")
|
||||
try:
|
||||
env = celpy.Environment()
|
||||
ast = env.compile(expression)
|
||||
program = env.program(ast)
|
||||
activation = build_activation(state)
|
||||
result = program.evaluate(activation)
|
||||
except celpy.CELEvalError as exc:
|
||||
raise CelEvaluationError(f"CEL evaluation error: {exc}") from exc
|
||||
except Exception as exc:
|
||||
raise CelEvaluationError(f"CEL error: {exc}") from exc
|
||||
return cel_to_python(result)
|
||||
|
||||
|
||||
def cel_to_python(value: Any) -> Any:
|
||||
if isinstance(value, celpy.celtypes.BoolType):
|
||||
return bool(value)
|
||||
if isinstance(value, celpy.celtypes.IntType):
|
||||
return int(value)
|
||||
if isinstance(value, celpy.celtypes.DoubleType):
|
||||
return float(value)
|
||||
if isinstance(value, celpy.celtypes.StringType):
|
||||
return str(value)
|
||||
if isinstance(value, celpy.celtypes.ListType):
|
||||
return [cel_to_python(item) for item in value]
|
||||
if isinstance(value, celpy.celtypes.MapType):
|
||||
return {str(k): cel_to_python(v) for k, v in value.items()}
|
||||
return value
|
||||
104
application/agents/workflows/node_agent.py
Normal file
104
application/agents/workflows/node_agent.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Workflow Node Agents - defines specialized agents for workflow nodes."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from application.agents.agentic_agent import AgenticAgent
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.agents.research_agent import ResearchAgent
|
||||
from application.agents.workflows.schemas import AgentType
|
||||
|
||||
|
||||
class ToolFilterMixin:
|
||||
"""Mixin that filters fetched tools to only those specified in tool_ids."""
|
||||
|
||||
_allowed_tool_ids: List[str]
|
||||
|
||||
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict[str, Any]]:
|
||||
all_tools = super()._get_user_tools(user)
|
||||
if not self._allowed_tool_ids:
|
||||
return {}
|
||||
filtered_tools = {
|
||||
tool_id: tool
|
||||
for tool_id, tool in all_tools.items()
|
||||
if str(tool.get("_id", "")) in self._allowed_tool_ids
|
||||
}
|
||||
return filtered_tools
|
||||
|
||||
def _get_tools(self, api_key: str = None) -> Dict[str, Dict[str, Any]]:
|
||||
all_tools = super()._get_tools(api_key)
|
||||
if not self._allowed_tool_ids:
|
||||
return {}
|
||||
filtered_tools = {
|
||||
tool_id: tool
|
||||
for tool_id, tool in all_tools.items()
|
||||
if str(tool.get("_id", "")) in self._allowed_tool_ids
|
||||
}
|
||||
return filtered_tools
|
||||
|
||||
|
||||
class _WorkflowNodeMixin:
|
||||
"""Common __init__ for all workflow node agents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
llm_name: str,
|
||||
model_id: str,
|
||||
api_key: str,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
endpoint=endpoint,
|
||||
llm_name=llm_name,
|
||||
model_id=model_id,
|
||||
api_key=api_key,
|
||||
**kwargs,
|
||||
)
|
||||
self._allowed_tool_ids = tool_ids or []
|
||||
|
||||
|
||||
class WorkflowNodeClassicAgent(ToolFilterMixin, _WorkflowNodeMixin, ClassicAgent):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowNodeAgenticAgent(ToolFilterMixin, _WorkflowNodeMixin, AgenticAgent):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowNodeResearchAgent(ToolFilterMixin, _WorkflowNodeMixin, ResearchAgent):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowNodeAgentFactory:
|
||||
|
||||
_agents: Dict[AgentType, Type[BaseAgent]] = {
|
||||
AgentType.CLASSIC: WorkflowNodeClassicAgent,
|
||||
AgentType.REACT: WorkflowNodeClassicAgent, # backwards compat
|
||||
AgentType.AGENTIC: WorkflowNodeAgenticAgent,
|
||||
AgentType.RESEARCH: WorkflowNodeResearchAgent,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
agent_type: AgentType,
|
||||
endpoint: str,
|
||||
llm_name: str,
|
||||
model_id: str,
|
||||
api_key: str,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> BaseAgent:
|
||||
agent_class = cls._agents.get(agent_type)
|
||||
if not agent_class:
|
||||
raise ValueError(f"Unsupported agent type: {agent_type}")
|
||||
return agent_class(
|
||||
endpoint=endpoint,
|
||||
llm_name=llm_name,
|
||||
model_id=model_id,
|
||||
api_key=api_key,
|
||||
tool_ids=tool_ids,
|
||||
**kwargs,
|
||||
)
|
||||
242
application/agents/workflows/schemas.py
Normal file
242
application/agents/workflows/schemas.py
Normal file
@@ -0,0 +1,242 @@
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class NodeType(str, Enum):
|
||||
START = "start"
|
||||
END = "end"
|
||||
AGENT = "agent"
|
||||
NOTE = "note"
|
||||
STATE = "state"
|
||||
CONDITION = "condition"
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
CLASSIC = "classic"
|
||||
REACT = "react"
|
||||
AGENTIC = "agentic"
|
||||
RESEARCH = "research"
|
||||
|
||||
|
||||
class ExecutionStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class Position(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
x: float = 0.0
|
||||
y: float = 0.0
|
||||
|
||||
|
||||
class AgentNodeConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
agent_type: AgentType = AgentType.CLASSIC
|
||||
llm_name: Optional[str] = None
|
||||
system_prompt: str = "You are a helpful assistant."
|
||||
prompt_template: str = ""
|
||||
output_variable: Optional[str] = None
|
||||
stream_to_user: bool = True
|
||||
tools: List[str] = Field(default_factory=list)
|
||||
sources: List[str] = Field(default_factory=list)
|
||||
chunks: str = "2"
|
||||
retriever: str = ""
|
||||
model_id: Optional[str] = None
|
||||
json_schema: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ConditionCase(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
name: Optional[str] = None
|
||||
expression: str = ""
|
||||
source_handle: str = Field(..., alias="sourceHandle")
|
||||
|
||||
|
||||
class ConditionNodeConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
mode: Literal["simple", "advanced"] = "simple"
|
||||
cases: List[ConditionCase] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StateOperation(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
expression: str = ""
|
||||
target_variable: str = ""
|
||||
|
||||
|
||||
class WorkflowEdgeCreate(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
id: str
|
||||
workflow_id: str
|
||||
source_id: str = Field(..., alias="source")
|
||||
target_id: str = Field(..., alias="target")
|
||||
source_handle: Optional[str] = Field(None, alias="sourceHandle")
|
||||
target_handle: Optional[str] = Field(None, alias="targetHandle")
|
||||
|
||||
|
||||
class WorkflowEdge(WorkflowEdgeCreate):
|
||||
mongo_id: Optional[str] = Field(None, alias="_id")
|
||||
|
||||
@field_validator("mongo_id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"source_id": self.source_id,
|
||||
"target_id": self.target_id,
|
||||
"source_handle": self.source_handle,
|
||||
"target_handle": self.target_handle,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowNodeCreate(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
id: str
|
||||
workflow_id: str
|
||||
type: NodeType
|
||||
title: str = "Node"
|
||||
description: Optional[str] = None
|
||||
position: Position = Field(default_factory=Position)
|
||||
config: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("position", mode="before")
|
||||
@classmethod
|
||||
def parse_position(cls, v: Union[Dict[str, float], Position]) -> Position:
|
||||
if isinstance(v, dict):
|
||||
return Position(**v)
|
||||
return v
|
||||
|
||||
|
||||
class WorkflowNode(WorkflowNodeCreate):
|
||||
mongo_id: Optional[str] = Field(None, alias="_id")
|
||||
|
||||
@field_validator("mongo_id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"type": self.type.value,
|
||||
"title": self.title,
|
||||
"description": self.description,
|
||||
"position": self.position.model_dump(),
|
||||
"config": self.config,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowCreate(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
name: str = "New Workflow"
|
||||
description: Optional[str] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
class Workflow(WorkflowCreate):
|
||||
id: Optional[str] = Field(None, alias="_id")
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"user": self.user,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowGraph(BaseModel):
|
||||
workflow: Workflow
|
||||
nodes: List[WorkflowNode] = Field(default_factory=list)
|
||||
edges: List[WorkflowEdge] = Field(default_factory=list)
|
||||
|
||||
def get_node_by_id(self, node_id: str) -> Optional[WorkflowNode]:
|
||||
for node in self.nodes:
|
||||
if node.id == node_id:
|
||||
return node
|
||||
return None
|
||||
|
||||
def get_start_node(self) -> Optional[WorkflowNode]:
|
||||
for node in self.nodes:
|
||||
if node.type == NodeType.START:
|
||||
return node
|
||||
return None
|
||||
|
||||
def get_outgoing_edges(self, node_id: str) -> List[WorkflowEdge]:
|
||||
return [edge for edge in self.edges if edge.source_id == node_id]
|
||||
|
||||
|
||||
class NodeExecutionLog(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
node_id: str
|
||||
node_type: str
|
||||
status: ExecutionStatus
|
||||
started_at: datetime
|
||||
completed_at: Optional[datetime] = None
|
||||
error: Optional[str] = None
|
||||
state_snapshot: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class WorkflowRunCreate(BaseModel):
|
||||
workflow_id: str
|
||||
inputs: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class WorkflowRun(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
id: Optional[str] = Field(None, alias="_id")
|
||||
workflow_id: str
|
||||
user: Optional[str] = None
|
||||
status: ExecutionStatus = ExecutionStatus.PENDING
|
||||
inputs: Dict[str, str] = Field(default_factory=dict)
|
||||
outputs: Dict[str, Any] = Field(default_factory=dict)
|
||||
steps: List[NodeExecutionLog] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
doc = {
|
||||
"workflow_id": self.workflow_id,
|
||||
"status": self.status.value,
|
||||
"inputs": self.inputs,
|
||||
"outputs": self.outputs,
|
||||
"steps": [step.model_dump() for step in self.steps],
|
||||
"created_at": self.created_at,
|
||||
"completed_at": self.completed_at,
|
||||
}
|
||||
if self.user:
|
||||
doc["user"] = self.user
|
||||
doc["user_id"] = self.user
|
||||
return doc
|
||||
470
application/agents/workflows/workflow_engine.py
Normal file
470
application/agents/workflows/workflow_engine.py
Normal file
@@ -0,0 +1,470 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
|
||||
|
||||
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
|
||||
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
|
||||
from application.agents.workflows.schemas import (
|
||||
AgentNodeConfig,
|
||||
AgentType,
|
||||
ConditionNodeConfig,
|
||||
ExecutionStatus,
|
||||
NodeExecutionLog,
|
||||
NodeType,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
)
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.error import sanitize_api_error
|
||||
from application.templates.namespaces import NamespaceManager
|
||||
from application.templates.template_engine import TemplateEngine, TemplateRenderError
|
||||
|
||||
try:
|
||||
import jsonschema
|
||||
except ImportError: # pragma: no cover - optional dependency in some deployments.
|
||||
jsonschema = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from application.agents.base import BaseAgent
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
StateValue = Any
|
||||
WorkflowState = Dict[str, StateValue]
|
||||
TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"}
|
||||
|
||||
|
||||
class WorkflowEngine:
|
||||
MAX_EXECUTION_STEPS = 50
|
||||
|
||||
def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"):
|
||||
self.graph = graph
|
||||
self.agent = agent
|
||||
self.state: WorkflowState = {}
|
||||
self.execution_log: List[Dict[str, Any]] = []
|
||||
self._condition_result: Optional[str] = None
|
||||
self._template_engine = TemplateEngine()
|
||||
self._namespace_manager = NamespaceManager()
|
||||
|
||||
def execute(
|
||||
self, initial_inputs: WorkflowState, query: str
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
self._initialize_state(initial_inputs, query)
|
||||
|
||||
start_node = self.graph.get_start_node()
|
||||
if not start_node:
|
||||
yield {"type": "error", "error": "No start node found in workflow."}
|
||||
return
|
||||
current_node_id: Optional[str] = start_node.id
|
||||
steps = 0
|
||||
|
||||
while current_node_id and steps < self.MAX_EXECUTION_STEPS:
|
||||
node = self.graph.get_node_by_id(current_node_id)
|
||||
if not node:
|
||||
yield {"type": "error", "error": f"Node {current_node_id} not found."}
|
||||
break
|
||||
log_entry = self._create_log_entry(node)
|
||||
|
||||
yield {
|
||||
"type": "workflow_step",
|
||||
"node_id": node.id,
|
||||
"node_type": node.type.value,
|
||||
"node_title": node.title,
|
||||
"status": "running",
|
||||
}
|
||||
|
||||
try:
|
||||
yield from self._execute_node(node)
|
||||
log_entry["status"] = ExecutionStatus.COMPLETED.value
|
||||
log_entry["completed_at"] = datetime.now(timezone.utc)
|
||||
|
||||
output_key = f"node_{node.id}_output"
|
||||
node_output = self.state.get(output_key)
|
||||
|
||||
yield {
|
||||
"type": "workflow_step",
|
||||
"node_id": node.id,
|
||||
"node_type": node.type.value,
|
||||
"node_title": node.title,
|
||||
"status": "completed",
|
||||
"state_snapshot": dict(self.state),
|
||||
"output": node_output,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing node {node.id}: {e}", exc_info=True)
|
||||
log_entry["status"] = ExecutionStatus.FAILED.value
|
||||
log_entry["error"] = str(e)
|
||||
log_entry["completed_at"] = datetime.now(timezone.utc)
|
||||
log_entry["state_snapshot"] = dict(self.state)
|
||||
self.execution_log.append(log_entry)
|
||||
|
||||
user_friendly_error = sanitize_api_error(e)
|
||||
yield {
|
||||
"type": "workflow_step",
|
||||
"node_id": node.id,
|
||||
"node_type": node.type.value,
|
||||
"node_title": node.title,
|
||||
"status": "failed",
|
||||
"state_snapshot": dict(self.state),
|
||||
"error": user_friendly_error,
|
||||
}
|
||||
yield {"type": "error", "error": user_friendly_error}
|
||||
break
|
||||
log_entry["state_snapshot"] = dict(self.state)
|
||||
self.execution_log.append(log_entry)
|
||||
|
||||
if node.type == NodeType.END:
|
||||
break
|
||||
current_node_id = self._get_next_node_id(current_node_id)
|
||||
if current_node_id is None and node.type != NodeType.END:
|
||||
logger.warning(
|
||||
f"Branch ended at node '{node.title}' ({node.id}) without reaching an end node"
|
||||
)
|
||||
steps += 1
|
||||
if steps >= self.MAX_EXECUTION_STEPS:
|
||||
logger.warning(
|
||||
f"Workflow reached max steps limit ({self.MAX_EXECUTION_STEPS})"
|
||||
)
|
||||
|
||||
def _initialize_state(self, initial_inputs: WorkflowState, query: str) -> None:
|
||||
self.state.update(initial_inputs)
|
||||
self.state["query"] = query
|
||||
self.state["chat_history"] = str(self.agent.chat_history)
|
||||
|
||||
def _create_log_entry(self, node: WorkflowNode) -> Dict[str, Any]:
|
||||
return {
|
||||
"node_id": node.id,
|
||||
"node_type": node.type.value,
|
||||
"started_at": datetime.now(timezone.utc),
|
||||
"completed_at": None,
|
||||
"status": ExecutionStatus.RUNNING.value,
|
||||
"error": None,
|
||||
"state_snapshot": {},
|
||||
}
|
||||
|
||||
def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
|
||||
node = self.graph.get_node_by_id(current_node_id)
|
||||
edges = self.graph.get_outgoing_edges(current_node_id)
|
||||
if not edges:
|
||||
return None
|
||||
|
||||
if node and node.type == NodeType.CONDITION and self._condition_result:
|
||||
target_handle = self._condition_result
|
||||
self._condition_result = None
|
||||
for edge in edges:
|
||||
if edge.source_handle == target_handle:
|
||||
return edge.target_id
|
||||
return None
|
||||
|
||||
return edges[0].target_id
|
||||
|
||||
def _execute_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
logger.info(f"Executing node {node.id} ({node.type.value})")
|
||||
|
||||
node_handlers = {
|
||||
NodeType.START: self._execute_start_node,
|
||||
NodeType.NOTE: self._execute_note_node,
|
||||
NodeType.AGENT: self._execute_agent_node,
|
||||
NodeType.STATE: self._execute_state_node,
|
||||
NodeType.CONDITION: self._execute_condition_node,
|
||||
NodeType.END: self._execute_end_node,
|
||||
}
|
||||
|
||||
handler = node_handlers.get(node.type)
|
||||
if handler:
|
||||
yield from handler(node)
|
||||
|
||||
def _execute_start_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
yield from ()
|
||||
|
||||
def _execute_note_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
yield from ()
|
||||
|
||||
def _execute_agent_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_model_capabilities,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
|
||||
node_config = AgentNodeConfig(**node.config.get("config", node.config))
|
||||
|
||||
if node_config.prompt_template:
|
||||
formatted_prompt = self._format_template(node_config.prompt_template)
|
||||
else:
|
||||
formatted_prompt = self.state.get("query", "")
|
||||
node_json_schema = self._normalize_node_json_schema(
|
||||
node_config.json_schema, node.title
|
||||
)
|
||||
node_model_id = node_config.model_id or self.agent.model_id
|
||||
node_llm_name = (
|
||||
node_config.llm_name
|
||||
or get_provider_from_model_id(node_model_id or "")
|
||||
or self.agent.llm_name
|
||||
)
|
||||
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
|
||||
|
||||
if node_json_schema and node_model_id:
|
||||
model_capabilities = get_model_capabilities(node_model_id)
|
||||
if model_capabilities and not model_capabilities.get(
|
||||
"supports_structured_output", False
|
||||
):
|
||||
raise ValueError(
|
||||
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
|
||||
)
|
||||
|
||||
factory_kwargs = {
|
||||
"agent_type": node_config.agent_type,
|
||||
"endpoint": self.agent.endpoint,
|
||||
"llm_name": node_llm_name,
|
||||
"model_id": node_model_id,
|
||||
"api_key": node_api_key,
|
||||
"tool_ids": node_config.tools,
|
||||
"prompt": node_config.system_prompt,
|
||||
"chat_history": self.agent.chat_history,
|
||||
"decoded_token": self.agent.decoded_token,
|
||||
"json_schema": node_json_schema,
|
||||
}
|
||||
|
||||
# Agentic/research agents need retriever_config for on-demand search
|
||||
if node_config.agent_type in (AgentType.AGENTIC, AgentType.RESEARCH):
|
||||
factory_kwargs["retriever_config"] = {
|
||||
"source": {"active_docs": node_config.sources} if node_config.sources else {},
|
||||
"retriever_name": node_config.retriever or "classic",
|
||||
"chunks": int(node_config.chunks) if node_config.chunks else 2,
|
||||
"model_id": node_model_id,
|
||||
"llm_name": node_llm_name,
|
||||
"api_key": node_api_key,
|
||||
"decoded_token": self.agent.decoded_token,
|
||||
}
|
||||
|
||||
node_agent = WorkflowNodeAgentFactory.create(**factory_kwargs)
|
||||
|
||||
full_response_parts: List[str] = []
|
||||
structured_response_parts: List[str] = []
|
||||
has_structured_response = False
|
||||
first_chunk = True
|
||||
for event in node_agent.gen(formatted_prompt):
|
||||
if "answer" in event:
|
||||
chunk = str(event["answer"])
|
||||
full_response_parts.append(chunk)
|
||||
if event.get("structured"):
|
||||
has_structured_response = True
|
||||
structured_response_parts.append(chunk)
|
||||
if node_config.stream_to_user:
|
||||
if first_chunk and hasattr(self, "_has_streamed"):
|
||||
yield {"answer": "\n\n"}
|
||||
first_chunk = False
|
||||
yield event
|
||||
|
||||
if node_config.stream_to_user:
|
||||
self._has_streamed = True
|
||||
|
||||
full_response = "".join(full_response_parts).strip()
|
||||
output_value: Any = full_response
|
||||
if has_structured_response:
|
||||
structured_response = "".join(structured_response_parts).strip()
|
||||
response_to_parse = structured_response or full_response
|
||||
parsed_success, parsed_structured = self._parse_structured_output(
|
||||
response_to_parse
|
||||
)
|
||||
output_value = parsed_structured if parsed_success else response_to_parse
|
||||
if node_json_schema:
|
||||
self._validate_structured_output(node_json_schema, output_value)
|
||||
elif node_json_schema:
|
||||
parsed_success, parsed_structured = self._parse_structured_output(
|
||||
full_response
|
||||
)
|
||||
if not parsed_success:
|
||||
raise ValueError(
|
||||
"Structured output was expected but response was not valid JSON"
|
||||
)
|
||||
output_value = parsed_structured
|
||||
self._validate_structured_output(node_json_schema, output_value)
|
||||
|
||||
default_output_key = f"node_{node.id}_output"
|
||||
self.state[default_output_key] = output_value
|
||||
|
||||
if node_config.output_variable:
|
||||
self.state[node_config.output_variable] = output_value
|
||||
|
||||
def _execute_state_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = node.config.get("config", node.config)
|
||||
for op in config.get("operations", []):
|
||||
expression = op.get("expression", "")
|
||||
target_variable = op.get("target_variable", "")
|
||||
if expression and target_variable:
|
||||
self.state[target_variable] = evaluate_cel(expression, self.state)
|
||||
yield from ()
|
||||
|
||||
def _execute_condition_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = ConditionNodeConfig(**node.config.get("config", node.config))
|
||||
matched_handle = None
|
||||
|
||||
for case in config.cases:
|
||||
if not case.expression.strip():
|
||||
continue
|
||||
try:
|
||||
if evaluate_cel(case.expression, self.state):
|
||||
matched_handle = case.source_handle
|
||||
break
|
||||
except CelEvaluationError:
|
||||
continue
|
||||
|
||||
self._condition_result = matched_handle or "else"
|
||||
yield from ()
|
||||
|
||||
def _execute_end_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = node.config.get("config", node.config)
|
||||
output_template = str(config.get("output_template", ""))
|
||||
if output_template:
|
||||
formatted_output = self._format_template(output_template)
|
||||
yield {"answer": formatted_output}
|
||||
|
||||
def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
|
||||
normalized_response = raw_response.strip()
|
||||
if not normalized_response:
|
||||
return False, None
|
||||
|
||||
try:
|
||||
return True, json.loads(normalized_response)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Workflow agent returned structured output that was not valid JSON"
|
||||
)
|
||||
return False, None
|
||||
|
||||
def _normalize_node_json_schema(
|
||||
self, schema: Optional[Dict[str, Any]], node_title: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if schema is None:
|
||||
return None
|
||||
try:
|
||||
return normalize_json_schema_payload(schema)
|
||||
except JsonSchemaValidationError as exc:
|
||||
raise ValueError(
|
||||
f'Invalid JSON schema for node "{node_title}": {exc}'
|
||||
) from exc
|
||||
|
||||
def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
|
||||
if jsonschema is None:
|
||||
logger.warning(
|
||||
"jsonschema package is not available, skipping structured output validation"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
normalized_schema = normalize_json_schema_payload(schema)
|
||||
except JsonSchemaValidationError as exc:
|
||||
raise ValueError(f"Invalid JSON schema: {exc}") from exc
|
||||
|
||||
try:
|
||||
jsonschema.validate(instance=output_value, schema=normalized_schema)
|
||||
except jsonschema.exceptions.ValidationError as exc:
|
||||
raise ValueError(f"Structured output did not match schema: {exc.message}") from exc
|
||||
except jsonschema.exceptions.SchemaError as exc:
|
||||
raise ValueError(f"Invalid JSON schema: {exc.message}") from exc
|
||||
|
||||
def _format_template(self, template: str) -> str:
|
||||
context = self._build_template_context()
|
||||
try:
|
||||
return self._template_engine.render(template, context)
|
||||
except TemplateRenderError as e:
|
||||
logger.warning(
|
||||
"Workflow template rendering failed, using raw template: %s", str(e)
|
||||
)
|
||||
return template
|
||||
|
||||
def _build_template_context(self) -> Dict[str, Any]:
|
||||
docs, docs_together = self._get_source_template_data()
|
||||
passthrough_data = (
|
||||
self.state.get("passthrough")
|
||||
if isinstance(self.state.get("passthrough"), dict)
|
||||
else None
|
||||
)
|
||||
tools_data = (
|
||||
self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None
|
||||
)
|
||||
|
||||
context = self._namespace_manager.build_context(
|
||||
user_id=getattr(self.agent, "user", None),
|
||||
request_id=getattr(self.agent, "request_id", None),
|
||||
passthrough_data=passthrough_data,
|
||||
docs=docs,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
agent_context: Dict[str, Any] = {}
|
||||
for key, value in self.state.items():
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
normalized_key = key.strip()
|
||||
if not normalized_key:
|
||||
continue
|
||||
agent_context[normalized_key] = value
|
||||
|
||||
context["agent"] = agent_context
|
||||
|
||||
# Keep legacy top-level variables working while namespaced variables are adopted.
|
||||
for key, value in agent_context.items():
|
||||
if key in TEMPLATE_RESERVED_NAMESPACES:
|
||||
context[f"agent_{key}"] = value
|
||||
continue
|
||||
if key not in context:
|
||||
context[key] = value
|
||||
|
||||
return context
|
||||
|
||||
def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
|
||||
docs = getattr(self.agent, "retrieved_docs", None)
|
||||
if not isinstance(docs, list) or len(docs) == 0:
|
||||
return None, None
|
||||
|
||||
docs_together_parts: List[str] = []
|
||||
for doc in docs:
|
||||
if not isinstance(doc, dict):
|
||||
continue
|
||||
text = doc.get("text")
|
||||
if not isinstance(text, str):
|
||||
continue
|
||||
|
||||
filename = doc.get("filename") or doc.get("title") or doc.get("source")
|
||||
if isinstance(filename, str) and filename.strip():
|
||||
docs_together_parts.append(f"{filename}\n{text}")
|
||||
else:
|
||||
docs_together_parts.append(text)
|
||||
|
||||
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
|
||||
return docs, docs_together
|
||||
|
||||
def get_execution_summary(self) -> List[NodeExecutionLog]:
|
||||
return [
|
||||
NodeExecutionLog(
|
||||
node_id=log["node_id"],
|
||||
node_type=log["node_type"],
|
||||
status=ExecutionStatus(log["status"]),
|
||||
started_at=log["started_at"],
|
||||
completed_at=log.get("completed_at"),
|
||||
error=log.get("error"),
|
||||
state_snapshot=log.get("state_snapshot", {}),
|
||||
)
|
||||
for log in self.execution_log
|
||||
]
|
||||
52
application/alembic.ini
Normal file
52
application/alembic.ini
Normal file
@@ -0,0 +1,52 @@
|
||||
# Alembic configuration for the DocsGPT user-data Postgres database.
|
||||
#
|
||||
# The SQLAlchemy URL is deliberately NOT set here — env.py reads it from
|
||||
# ``application.core.settings.settings.POSTGRES_URI`` so the same config
|
||||
# source serves the running app and migrations. To run from the project
|
||||
# root::
|
||||
#
|
||||
# alembic -c application/alembic.ini upgrade head
|
||||
|
||||
[alembic]
|
||||
script_location = %(here)s/alembic
|
||||
prepend_sys_path = ..
|
||||
version_path_separator = os
|
||||
|
||||
# sqlalchemy.url is intentionally left blank — env.py supplies it.
|
||||
sqlalchemy.url =
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
82
application/alembic/env.py
Normal file
82
application/alembic/env.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Alembic environment for the DocsGPT user-data Postgres database.
|
||||
|
||||
The URL is pulled from ``application.core.settings`` rather than
|
||||
``alembic.ini`` so that a single ``POSTGRES_URI`` env var drives both the
|
||||
running app and ``alembic`` CLI invocations.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
from pathlib import Path
|
||||
|
||||
# Make the project root importable regardless of cwd. env.py lives at
|
||||
# <repo>/application/alembic/env.py, so parents[2] is the repo root.
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(_PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_PROJECT_ROOT))
|
||||
|
||||
from alembic import context # noqa: E402
|
||||
from sqlalchemy import engine_from_config, pool # noqa: E402
|
||||
|
||||
from application.core.settings import settings # noqa: E402
|
||||
from application.storage.db.models import metadata as target_metadata # noqa: E402
|
||||
|
||||
config = context.config
|
||||
|
||||
# Populate the runtime URL from settings.
|
||||
if settings.POSTGRES_URI:
|
||||
config.set_main_option("sqlalchemy.url", settings.POSTGRES_URI)
|
||||
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode (emits SQL without a live DB)."""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
if not url:
|
||||
raise RuntimeError(
|
||||
"POSTGRES_URI is not configured. Set it in your .env to a "
|
||||
"psycopg3 URI such as "
|
||||
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||
)
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
compare_type=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode against a live connection."""
|
||||
if not config.get_main_option("sqlalchemy.url"):
|
||||
raise RuntimeError(
|
||||
"POSTGRES_URI is not configured. Set it in your .env to a "
|
||||
"psycopg3 URI such as "
|
||||
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||
)
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
future=True,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
26
application/alembic/script.py.mako
Normal file
26
application/alembic/script.py.mako
Normal file
@@ -0,0 +1,26 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
825
application/alembic/versions/0001_initial.py
Normal file
825
application/alembic/versions/0001_initial.py
Normal file
@@ -0,0 +1,825 @@
|
||||
"""0001 initial schema — consolidated Phase-1..3 baseline.
|
||||
|
||||
Revision ID: 0001_initial
|
||||
Revises:
|
||||
Create Date: 2026-04-13
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0001_initial"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
# Extensions
|
||||
# ------------------------------------------------------------------
|
||||
op.execute('CREATE EXTENSION IF NOT EXISTS "pgcrypto";')
|
||||
op.execute('CREATE EXTENSION IF NOT EXISTS "citext";')
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Trigger functions
|
||||
# ------------------------------------------------------------------
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION set_updated_at() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
NEW.updated_at = now();
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION ensure_user_exists() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
IF NEW.user_id IS NOT NULL THEN
|
||||
INSERT INTO users (user_id) VALUES (NEW.user_id)
|
||||
ON CONFLICT (user_id) DO NOTHING;
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION cleanup_message_attachment_refs() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
UPDATE conversation_messages
|
||||
SET attachments = array_remove(attachments, OLD.id)
|
||||
WHERE OLD.id = ANY(attachments);
|
||||
RETURN OLD;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION cleanup_agent_extra_source_refs() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
UPDATE agents
|
||||
SET extra_source_ids = array_remove(extra_source_ids, OLD.id)
|
||||
WHERE OLD.id = ANY(extra_source_ids);
|
||||
RETURN OLD;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION cleanup_user_agent_prefs() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
agent_id_text text := OLD.id::text;
|
||||
BEGIN
|
||||
UPDATE users
|
||||
SET agent_preferences = jsonb_set(
|
||||
jsonb_set(
|
||||
agent_preferences,
|
||||
'{pinned}',
|
||||
COALESCE((
|
||||
SELECT jsonb_agg(e)
|
||||
FROM jsonb_array_elements(
|
||||
COALESCE(agent_preferences->'pinned', '[]'::jsonb)
|
||||
) e
|
||||
WHERE (e #>> '{}') <> agent_id_text
|
||||
), '[]'::jsonb)
|
||||
),
|
||||
'{shared_with_me}',
|
||||
COALESCE((
|
||||
SELECT jsonb_agg(e)
|
||||
FROM jsonb_array_elements(
|
||||
COALESCE(agent_preferences->'shared_with_me', '[]'::jsonb)
|
||||
) e
|
||||
WHERE (e #>> '{}') <> agent_id_text
|
||||
), '[]'::jsonb)
|
||||
)
|
||||
WHERE agent_preferences->'pinned' @> to_jsonb(agent_id_text)
|
||||
OR agent_preferences->'shared_with_me' @> to_jsonb(agent_id_text);
|
||||
RETURN OLD;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION conversation_messages_fill_user_id() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
IF NEW.user_id IS NULL THEN
|
||||
SELECT user_id INTO NEW.user_id
|
||||
FROM conversations
|
||||
WHERE id = NEW.conversation_id;
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tables
|
||||
# ------------------------------------------------------------------
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE users (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL UNIQUE,
|
||||
agent_preferences JSONB NOT NULL
|
||||
DEFAULT '{"pinned": [], "shared_with_me": []}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE prompts (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE user_tools (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
custom_name TEXT,
|
||||
display_name TEXT,
|
||||
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE token_usage (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id TEXT,
|
||||
api_key TEXT,
|
||||
agent_id UUID,
|
||||
prompt_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
generated_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE token_usage ADD CONSTRAINT token_usage_attribution_chk "
|
||||
"CHECK (user_id IS NOT NULL OR api_key IS NOT NULL) NOT VALID;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE user_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id TEXT,
|
||||
endpoint TEXT,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
data JSONB
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE stack_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
activity_id TEXT NOT NULL,
|
||||
endpoint TEXT,
|
||||
level TEXT,
|
||||
user_id TEXT,
|
||||
api_key TEXT,
|
||||
query TEXT,
|
||||
stacks JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE agent_folders (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE sources (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE agents (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
agent_type TEXT,
|
||||
status TEXT NOT NULL,
|
||||
key CITEXT UNIQUE,
|
||||
source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
|
||||
extra_source_ids UUID[] NOT NULL DEFAULT '{}',
|
||||
chunks INTEGER,
|
||||
retriever TEXT,
|
||||
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
|
||||
tools JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
json_schema JSONB,
|
||||
models JSONB,
|
||||
default_model_id TEXT,
|
||||
folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
|
||||
limited_token_mode BOOLEAN NOT NULL DEFAULT false,
|
||||
token_limit INTEGER,
|
||||
limited_request_mode BOOLEAN NOT NULL DEFAULT false,
|
||||
request_limit INTEGER,
|
||||
shared BOOLEAN NOT NULL DEFAULT false,
|
||||
incoming_webhook_token CITEXT UNIQUE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
last_used_at TIMESTAMPTZ,
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE token_usage ADD CONSTRAINT token_usage_agent_fk "
|
||||
"FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE SET NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE attachments (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
filename TEXT NOT NULL,
|
||||
upload_path TEXT NOT NULL,
|
||||
mime_type TEXT,
|
||||
size BIGINT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE memories (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
path TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE todos (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
title TEXT NOT NULL,
|
||||
completed BOOLEAN NOT NULL DEFAULT false,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE notes (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
title TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE connector_sessions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
session_data JSONB NOT NULL,
|
||||
expires_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE conversations (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
agent_id UUID REFERENCES agents(id) ON DELETE SET NULL,
|
||||
name TEXT,
|
||||
api_key TEXT,
|
||||
is_shared_usage BOOLEAN NOT NULL DEFAULT false,
|
||||
shared_token TEXT,
|
||||
date TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
shared_with TEXT[] NOT NULL DEFAULT '{}'::text[],
|
||||
compression_metadata JSONB,
|
||||
legacy_mongo_id TEXT,
|
||||
CONSTRAINT conversations_api_key_nonempty_chk
|
||||
CHECK (api_key IS NULL OR api_key <> '')
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE conversation_messages (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
position INTEGER NOT NULL,
|
||||
prompt TEXT,
|
||||
response TEXT,
|
||||
thought TEXT,
|
||||
sources JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
tool_calls JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
attachments UUID[] NOT NULL DEFAULT '{}'::uuid[],
|
||||
model_id TEXT,
|
||||
message_metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
feedback JSONB,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
user_id TEXT NOT NULL,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE shared_conversations (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
user_id TEXT NOT NULL,
|
||||
is_promptable BOOLEAN NOT NULL DEFAULT false,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
uuid UUID NOT NULL,
|
||||
first_n_queries INTEGER NOT NULL DEFAULT 0,
|
||||
api_key TEXT,
|
||||
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
|
||||
chunks INTEGER,
|
||||
CONSTRAINT shared_conversations_api_key_nonempty_chk
|
||||
CHECK (api_key IS NULL OR api_key <> '')
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE pending_tool_state (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
user_id TEXT NOT NULL,
|
||||
messages JSONB NOT NULL,
|
||||
pending_tool_calls JSONB NOT NULL,
|
||||
tools_dict JSONB NOT NULL,
|
||||
tool_schemas JSONB NOT NULL,
|
||||
agent_config JSONB NOT NULL,
|
||||
client_tools JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
expires_at TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE workflows (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
current_graph_version INTEGER NOT NULL DEFAULT 1,
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE workflow_nodes (
|
||||
id UUID DEFAULT gen_random_uuid() NOT NULL,
|
||||
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||
graph_version INTEGER NOT NULL,
|
||||
node_type TEXT NOT NULL,
|
||||
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
node_id TEXT NOT NULL,
|
||||
title TEXT,
|
||||
description TEXT,
|
||||
position JSONB NOT NULL DEFAULT '{"x": 0, "y": 0}'::jsonb,
|
||||
legacy_mongo_id TEXT,
|
||||
PRIMARY KEY (id),
|
||||
CONSTRAINT workflow_nodes_id_wf_ver_key
|
||||
UNIQUE (id, workflow_id, graph_version)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE workflow_edges (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||
graph_version INTEGER NOT NULL,
|
||||
from_node_id UUID NOT NULL,
|
||||
to_node_id UUID NOT NULL,
|
||||
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
edge_id TEXT NOT NULL,
|
||||
source_handle TEXT,
|
||||
target_handle TEXT,
|
||||
CONSTRAINT workflow_edges_from_node_fk
|
||||
FOREIGN KEY (from_node_id, workflow_id, graph_version)
|
||||
REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE,
|
||||
CONSTRAINT workflow_edges_to_node_fk
|
||||
FOREIGN KEY (to_node_id, workflow_id, graph_version)
|
||||
REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE workflow_runs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||
user_id TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
started_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
ended_at TIMESTAMPTZ,
|
||||
result JSONB,
|
||||
inputs JSONB,
|
||||
steps JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
legacy_mongo_id TEXT,
|
||||
CONSTRAINT workflow_runs_status_chk
|
||||
CHECK (status IN ('pending', 'running', 'completed', 'failed'))
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Indexes
|
||||
# ------------------------------------------------------------------
|
||||
op.execute("CREATE INDEX agent_folders_user_idx ON agent_folders (user_id);")
|
||||
|
||||
op.execute("CREATE INDEX agents_user_idx ON agents (user_id);")
|
||||
op.execute("CREATE INDEX agents_shared_idx ON agents (shared) WHERE shared = true;")
|
||||
op.execute("CREATE INDEX agents_status_idx ON agents (status);")
|
||||
op.execute("CREATE INDEX agents_source_id_idx ON agents (source_id);")
|
||||
op.execute("CREATE INDEX agents_prompt_id_idx ON agents (prompt_id);")
|
||||
op.execute("CREATE INDEX agents_folder_id_idx ON agents (folder_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX agents_legacy_mongo_id_uidx "
|
||||
"ON agents (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX attachments_user_idx ON attachments (user_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX attachments_legacy_mongo_id_uidx "
|
||||
"ON attachments (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX connector_sessions_user_provider_uidx "
|
||||
"ON connector_sessions (user_id, provider);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX connector_sessions_expiry_idx "
|
||||
"ON connector_sessions (expires_at) WHERE expires_at IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX conversation_messages_conv_pos_uidx "
|
||||
"ON conversation_messages (conversation_id, position);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX conversation_messages_user_ts_idx "
|
||||
"ON conversation_messages (user_id, timestamp DESC);"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX conversations_user_date_idx ON conversations (user_id, date DESC);")
|
||||
op.execute("CREATE INDEX conversations_agent_idx ON conversations (agent_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX conversations_shared_token_uidx "
|
||||
"ON conversations (shared_token) WHERE shared_token IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX conversations_api_key_date_idx "
|
||||
"ON conversations (api_key, date DESC) WHERE api_key IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX conversations_legacy_mongo_id_uidx "
|
||||
"ON conversations (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX memories_user_tool_path_uidx "
|
||||
"ON memories (user_id, tool_id, path);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX memories_user_path_null_tool_uidx "
|
||||
"ON memories (user_id, path) WHERE tool_id IS NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX memories_path_prefix_idx "
|
||||
"ON memories (user_id, tool_id, path text_pattern_ops);"
|
||||
)
|
||||
op.execute("CREATE INDEX memories_tool_id_idx ON memories (tool_id);")
|
||||
|
||||
op.execute("CREATE UNIQUE INDEX notes_user_tool_uidx ON notes (user_id, tool_id);")
|
||||
op.execute("CREATE INDEX notes_tool_id_idx ON notes (tool_id);")
|
||||
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX pending_tool_state_conv_user_uidx "
|
||||
"ON pending_tool_state (conversation_id, user_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX pending_tool_state_expires_idx ON pending_tool_state (expires_at);"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX prompts_user_id_idx ON prompts (user_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX prompts_legacy_mongo_id_uidx "
|
||||
"ON prompts (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX shared_conversations_user_idx ON shared_conversations (user_id);")
|
||||
op.execute("CREATE INDEX shared_conversations_conv_idx ON shared_conversations (conversation_id);")
|
||||
op.execute(
|
||||
"CREATE INDEX shared_conversations_prompt_id_idx ON shared_conversations (prompt_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX shared_conversations_uuid_uidx ON shared_conversations (uuid);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX shared_conversations_dedup_uidx "
|
||||
"ON shared_conversations (conversation_id, user_id, is_promptable, first_n_queries, COALESCE(api_key, ''));"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX sources_user_idx ON sources (user_id);")
|
||||
|
||||
op.execute('CREATE INDEX stack_logs_timestamp_idx ON stack_logs ("timestamp" DESC);')
|
||||
op.execute('CREATE INDEX stack_logs_user_ts_idx ON stack_logs (user_id, "timestamp" DESC);')
|
||||
op.execute('CREATE INDEX stack_logs_level_ts_idx ON stack_logs (level, "timestamp" DESC);')
|
||||
op.execute("CREATE INDEX stack_logs_activity_idx ON stack_logs (activity_id);")
|
||||
|
||||
op.execute("CREATE INDEX todos_user_tool_idx ON todos (user_id, tool_id);")
|
||||
op.execute("CREATE INDEX todos_tool_id_idx ON todos (tool_id);")
|
||||
|
||||
op.execute('CREATE INDEX token_usage_user_ts_idx ON token_usage (user_id, "timestamp" DESC);')
|
||||
op.execute('CREATE INDEX token_usage_key_ts_idx ON token_usage (api_key, "timestamp" DESC);')
|
||||
op.execute('CREATE INDEX token_usage_agent_ts_idx ON token_usage (agent_id, "timestamp" DESC);')
|
||||
|
||||
op.execute('CREATE INDEX user_logs_user_ts_idx ON user_logs (user_id, "timestamp" DESC);')
|
||||
|
||||
op.execute("CREATE INDEX user_tools_user_id_idx ON user_tools (user_id);")
|
||||
|
||||
op.execute("CREATE INDEX workflow_edges_from_node_idx ON workflow_edges (from_node_id);")
|
||||
op.execute("CREATE INDEX workflow_edges_to_node_idx ON workflow_edges (to_node_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX workflow_edges_wf_ver_eid_uidx "
|
||||
"ON workflow_edges (workflow_id, graph_version, edge_id);"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX workflow_nodes_wf_ver_nid_uidx "
|
||||
"ON workflow_nodes (workflow_id, graph_version, node_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX workflow_nodes_legacy_mongo_id_uidx "
|
||||
"ON workflow_nodes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX workflow_runs_workflow_idx ON workflow_runs (workflow_id);")
|
||||
op.execute("CREATE INDEX workflow_runs_user_idx ON workflow_runs (user_id);")
|
||||
op.execute(
|
||||
"CREATE INDEX workflow_runs_status_started_idx "
|
||||
"ON workflow_runs (status, started_at DESC);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX workflow_runs_legacy_mongo_id_uidx "
|
||||
"ON workflow_runs (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX workflows_user_idx ON workflows (user_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX workflows_legacy_mongo_id_uidx "
|
||||
"ON workflows (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# user_id foreign keys (deferrable so backfills can stage rows)
|
||||
# ------------------------------------------------------------------
|
||||
user_fk_tables = (
|
||||
"agent_folders",
|
||||
"agents",
|
||||
"attachments",
|
||||
"connector_sessions",
|
||||
"conversation_messages",
|
||||
"conversations",
|
||||
"memories",
|
||||
"notes",
|
||||
"pending_tool_state",
|
||||
"prompts",
|
||||
"shared_conversations",
|
||||
"sources",
|
||||
"stack_logs",
|
||||
"todos",
|
||||
"token_usage",
|
||||
"user_logs",
|
||||
"user_tools",
|
||||
"workflow_runs",
|
||||
"workflows",
|
||||
)
|
||||
for table in user_fk_tables:
|
||||
op.execute(
|
||||
f"ALTER TABLE {table} "
|
||||
f"ADD CONSTRAINT {table}_user_id_fk "
|
||||
f"FOREIGN KEY (user_id) REFERENCES users(user_id) "
|
||||
f"ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Triggers
|
||||
# ------------------------------------------------------------------
|
||||
updated_at_tables = (
|
||||
"agent_folders",
|
||||
"agents",
|
||||
"conversation_messages",
|
||||
"conversations",
|
||||
"memories",
|
||||
"notes",
|
||||
"prompts",
|
||||
"sources",
|
||||
"todos",
|
||||
"user_tools",
|
||||
"users",
|
||||
"workflows",
|
||||
)
|
||||
for table in updated_at_tables:
|
||||
op.execute(
|
||||
f"CREATE TRIGGER {table}_set_updated_at "
|
||||
f"BEFORE UPDATE ON {table} "
|
||||
f"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
|
||||
f"EXECUTE FUNCTION set_updated_at();"
|
||||
)
|
||||
|
||||
ensure_user_tables = (
|
||||
"agent_folders",
|
||||
"agents",
|
||||
"attachments",
|
||||
"connector_sessions",
|
||||
"conversation_messages",
|
||||
"conversations",
|
||||
"memories",
|
||||
"notes",
|
||||
"pending_tool_state",
|
||||
"prompts",
|
||||
"shared_conversations",
|
||||
"sources",
|
||||
"stack_logs",
|
||||
"todos",
|
||||
"token_usage",
|
||||
"user_logs",
|
||||
"user_tools",
|
||||
"workflow_runs",
|
||||
"workflows",
|
||||
)
|
||||
for table in ensure_user_tables:
|
||||
op.execute(
|
||||
f"CREATE TRIGGER {table}_ensure_user "
|
||||
f"BEFORE INSERT OR UPDATE OF user_id ON {table} "
|
||||
f"FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE TRIGGER conversation_messages_fill_user "
|
||||
"BEFORE INSERT ON conversation_messages "
|
||||
"FOR EACH ROW EXECUTE FUNCTION conversation_messages_fill_user_id();"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE TRIGGER attachments_cleanup_message_refs "
|
||||
"AFTER DELETE ON attachments "
|
||||
"FOR EACH ROW EXECUTE FUNCTION cleanup_message_attachment_refs();"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE TRIGGER agents_cleanup_user_prefs "
|
||||
"AFTER DELETE ON agents "
|
||||
"FOR EACH ROW EXECUTE FUNCTION cleanup_user_agent_prefs();"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE TRIGGER sources_cleanup_agent_extra_refs "
|
||||
"AFTER DELETE ON sources "
|
||||
"FOR EACH ROW EXECUTE FUNCTION cleanup_agent_extra_source_refs();"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Seed sentinel __system__ user (system/template sources attribute here)
|
||||
# ------------------------------------------------------------------
|
||||
op.execute(
|
||||
"INSERT INTO users (user_id) VALUES ('__system__') "
|
||||
"ON CONFLICT (user_id) DO NOTHING;"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Nuclear downgrade: drop everything this migration created. The
|
||||
# ordering drops FK-bearing children before parents; CASCADE would
|
||||
# also work but explicit ordering is easier to reason about in code
|
||||
# review.
|
||||
tables_in_drop_order = (
|
||||
"workflow_edges",
|
||||
"workflow_runs",
|
||||
"workflow_nodes",
|
||||
"workflows",
|
||||
"pending_tool_state",
|
||||
"shared_conversations",
|
||||
"conversation_messages",
|
||||
"conversations",
|
||||
"connector_sessions",
|
||||
"notes",
|
||||
"todos",
|
||||
"memories",
|
||||
"attachments",
|
||||
"agents",
|
||||
"sources",
|
||||
"agent_folders",
|
||||
"stack_logs",
|
||||
"user_logs",
|
||||
"token_usage",
|
||||
"user_tools",
|
||||
"prompts",
|
||||
"users",
|
||||
)
|
||||
for table in tables_in_drop_order:
|
||||
op.execute(f"DROP TABLE IF EXISTS {table} CASCADE;")
|
||||
|
||||
for fn in (
|
||||
"conversation_messages_fill_user_id",
|
||||
"cleanup_user_agent_prefs",
|
||||
"cleanup_agent_extra_source_refs",
|
||||
"cleanup_message_attachment_refs",
|
||||
"ensure_user_exists",
|
||||
"set_updated_at",
|
||||
):
|
||||
op.execute(f"DROP FUNCTION IF EXISTS {fn}();")
|
||||
@@ -0,0 +1,7 @@
|
||||
from flask_restx import Api
|
||||
|
||||
api = Api(
|
||||
version="1.0",
|
||||
title="DocsGPT API",
|
||||
description="API for DocsGPT",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from flask import Blueprint
|
||||
|
||||
from application.api import api
|
||||
from application.api.answer.routes.answer import AnswerResource
|
||||
from application.api.answer.routes.base import answer_ns
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
from application.api.answer.routes.stream import StreamResource
|
||||
|
||||
|
||||
answer = Blueprint("answer", __name__)
|
||||
|
||||
api.add_namespace(answer_ns)
|
||||
|
||||
|
||||
def init_answer_routes():
|
||||
api.add_resource(StreamResource, "/stream")
|
||||
api.add_resource(AnswerResource, "/api/answer")
|
||||
api.add_resource(SearchResource, "/api/search")
|
||||
|
||||
|
||||
init_answer_routes()
|
||||
|
||||
@@ -1,786 +0,0 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import Blueprint, make_response, request, Response
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.error import bad_request
|
||||
from application.extensions import api
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import check_required_fields, limit_chat_history
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo["docsgpt"]
|
||||
conversations_collection = db["conversations"]
|
||||
sources_collection = db["sources"]
|
||||
prompts_collection = db["prompts"]
|
||||
api_key_collection = db["api_keys"]
|
||||
user_logs_collection = db["user_logs"]
|
||||
|
||||
answer = Blueprint("answer", __name__)
|
||||
answer_ns = Namespace("answer", description="Answer related operations", path="/")
|
||||
api.add_namespace(answer_ns)
|
||||
|
||||
gpt_model = ""
|
||||
# to have some kind of default behaviour
|
||||
if settings.LLM_NAME == "openai":
|
||||
gpt_model = "gpt-4o-mini"
|
||||
elif settings.LLM_NAME == "anthropic":
|
||||
gpt_model = "claude-2"
|
||||
elif settings.LLM_NAME == "groq":
|
||||
gpt_model = "llama3-8b-8192"
|
||||
elif settings.LLM_NAME == "novita":
|
||||
gpt_model = "deepseek/deepseek-r1"
|
||||
|
||||
if settings.MODEL_NAME: # in case there is particular model name configured
|
||||
gpt_model = settings.MODEL_NAME
|
||||
|
||||
# load the prompts
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
|
||||
chat_combine_template = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f:
|
||||
chat_reduce_template = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
|
||||
chat_combine_creative = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
|
||||
chat_combine_strict = f.read()
|
||||
|
||||
api_key_set = settings.API_KEY is not None
|
||||
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
|
||||
|
||||
|
||||
async def async_generate(chain, question, chat_history):
|
||||
result = await chain.arun({"question": question, "chat_history": chat_history})
|
||||
return result
|
||||
|
||||
|
||||
def run_async_chain(chain, question, chat_history):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
result = {}
|
||||
try:
|
||||
answer = loop.run_until_complete(async_generate(chain, question, chat_history))
|
||||
finally:
|
||||
loop.close()
|
||||
result["answer"] = answer
|
||||
return result
|
||||
|
||||
|
||||
def get_data_from_api_key(api_key):
|
||||
data = api_key_collection.find_one({"key": api_key})
|
||||
# # Raise custom exception if the API key is not found
|
||||
if data is None:
|
||||
raise Exception("Invalid API Key, please generate new key", 401)
|
||||
|
||||
if "source" in data and isinstance(data["source"], DBRef):
|
||||
source_doc = db.dereference(data["source"])
|
||||
data["source"] = str(source_doc["_id"])
|
||||
if "retriever" in source_doc:
|
||||
data["retriever"] = source_doc["retriever"]
|
||||
else:
|
||||
data["source"] = {}
|
||||
return data
|
||||
|
||||
|
||||
def get_retriever(source_id: str):
|
||||
doc = sources_collection.find_one({"_id": ObjectId(source_id)})
|
||||
if doc is None:
|
||||
raise Exception("Source document does not exist", 404)
|
||||
retriever_name = None if "retriever" not in doc else doc["retriever"]
|
||||
return retriever_name
|
||||
|
||||
|
||||
def is_azure_configured():
|
||||
return (
|
||||
settings.OPENAI_API_BASE
|
||||
and settings.OPENAI_API_VERSION
|
||||
and settings.AZURE_DEPLOYMENT_NAME
|
||||
)
|
||||
|
||||
|
||||
def save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
decoded_token,
|
||||
index=None,
|
||||
api_key=None,
|
||||
):
|
||||
current_time = datetime.datetime.now(datetime.timezone.utc)
|
||||
if conversation_id is not None and index is not None:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{index}.prompt": question,
|
||||
f"queries.{index}.response": response,
|
||||
f"queries.{index}.sources": source_log_docs,
|
||||
f"queries.{index}.tool_calls": tool_calls,
|
||||
f"queries.{index}.timestamp": current_time,
|
||||
}
|
||||
},
|
||||
)
|
||||
##remove following queries from the array
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
|
||||
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
||||
)
|
||||
elif conversation_id is not None and conversation_id != "None":
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{
|
||||
"$push": {
|
||||
"queries": {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
else:
|
||||
# create new conversation
|
||||
# generate summary
|
||||
messages_summary = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Summarise following conversation in no more than 3 "
|
||||
"words, respond ONLY with the summary, use the same "
|
||||
"language as the system",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Summarise following conversation in no more than 3 words, "
|
||||
"respond ONLY with the summary, use the same language as the "
|
||||
"system \n\nUser: " + question + "\n\n" + "AI: " + response,
|
||||
},
|
||||
]
|
||||
|
||||
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
|
||||
conversation_data = {
|
||||
"user": decoded_token.get("sub"),
|
||||
"date": datetime.datetime.utcnow(),
|
||||
"name": completion,
|
||||
"queries": [
|
||||
{
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
}
|
||||
],
|
||||
}
|
||||
if api_key:
|
||||
api_key_doc = api_key_collection.find_one({"key": api_key})
|
||||
if api_key_doc:
|
||||
conversation_data["api_key"] = api_key_doc["key"]
|
||||
conversation_id = conversations_collection.insert_one(
|
||||
conversation_data
|
||||
).inserted_id
|
||||
return conversation_id
|
||||
|
||||
|
||||
def get_prompt(prompt_id):
|
||||
if prompt_id == "default":
|
||||
prompt = chat_combine_template
|
||||
elif prompt_id == "creative":
|
||||
prompt = chat_combine_creative
|
||||
elif prompt_id == "strict":
|
||||
prompt = chat_combine_strict
|
||||
else:
|
||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
||||
return prompt
|
||||
|
||||
|
||||
def complete_stream(
|
||||
question,
|
||||
agent,
|
||||
retriever,
|
||||
conversation_id,
|
||||
user_api_key,
|
||||
decoded_token,
|
||||
isNoneDoc=False,
|
||||
index=None,
|
||||
should_save_conversation=True,
|
||||
):
|
||||
try:
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
tool_calls = []
|
||||
|
||||
answer = agent.gen(query=question, retriever=retriever)
|
||||
|
||||
for line in answer:
|
||||
if "answer" in line:
|
||||
response_full += str(line["answer"])
|
||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "sources" in line:
|
||||
truncated_sources = []
|
||||
source_log_docs = line["sources"]
|
||||
for source in line["sources"]:
|
||||
truncated_source = source.copy()
|
||||
if "text" in truncated_source:
|
||||
truncated_source["text"] = (
|
||||
truncated_source["text"][:100].strip() + "..."
|
||||
)
|
||||
truncated_sources.append(truncated_source)
|
||||
if len(truncated_sources) > 0:
|
||||
data = json.dumps({"type": "source", "source": truncated_sources})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "tool_calls" in line:
|
||||
tool_calls = line["tool_calls"]
|
||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
if should_save_conversation:
|
||||
conversation_id = save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
decoded_token,
|
||||
index,
|
||||
api_key=user_api_key,
|
||||
)
|
||||
else:
|
||||
conversation_id = None
|
||||
|
||||
# send data.type = "end" to indicate that the stream has ended as json
|
||||
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
retriever_params = retriever.get_params()
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
"action": "stream_answer",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"retriever_params": retriever_params,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
)
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
data = json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Please try again later. We apologize for any inconvenience.",
|
||||
}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
|
||||
@answer_ns.route("/stream")
|
||||
class Stream(Resource):
|
||||
stream_model = api.model(
|
||||
"StreamModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="Question to be asked"
|
||||
),
|
||||
"history": fields.List(
|
||||
fields.String, required=False, description="Chat history"
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=False, description="Conversation ID"
|
||||
),
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
"index": fields.Integer(
|
||||
required=False, description="The position where query is to be updated"
|
||||
),
|
||||
"save_conversation": fields.Boolean(
|
||||
required=False, default=True, description="Flag to save conversation"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(stream_model)
|
||||
@api.doc(description="Stream a response based on the question and retriever")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
if "index" in data:
|
||||
required_fields = ["question", "conversation_id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
save_conv = data.get("save_conversation", True)
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
history = limit_chat_history(
|
||||
json.loads(data.get("history", [])), gpt_model=gpt_model
|
||||
)
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
|
||||
index = data.get("index", None)
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
prompt_id = data_key.get("prompt_id", "default")
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
retriever_name = data_key.get("retriever", retriever_name)
|
||||
user_api_key = data["api_key"]
|
||||
decoded_token = {"sub": data_key.get("user")}
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
if not decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
logger.info(
|
||||
f"/stream - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
if "isNoneDoc" in data and data["isNoneDoc"] is True:
|
||||
chunks = 0
|
||||
|
||||
agent = AgentCreator.create_agent(
|
||||
settings.AGENT_NAME,
|
||||
endpoint="stream",
|
||||
llm_name=settings.LLM_NAME,
|
||||
gpt_model=gpt_model,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
chat_history=history,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
return Response(
|
||||
complete_stream(
|
||||
question=question,
|
||||
agent=agent,
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=index,
|
||||
should_save_conversation=save_conv,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
message = "Malformed request body"
|
||||
logger.error(f"/stream - error: {message}")
|
||||
return Response(
|
||||
error_stream_generate(message),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
status_code = 400
|
||||
return Response(
|
||||
error_stream_generate("Unknown error occurred"),
|
||||
status=status_code,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
def error_stream_generate(err_response):
|
||||
data = json.dumps({"type": "error", "error": err_response})
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
|
||||
@answer_ns.route("/api/answer")
|
||||
class Answer(Resource):
|
||||
answer_model = api.model(
|
||||
"AnswerModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="The question to answer"
|
||||
),
|
||||
"history": fields.List(
|
||||
fields.String, required=False, description="Conversation history"
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=False, description="Conversation ID"
|
||||
),
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(answer_model)
|
||||
@api.doc(description="Provide an answer based on the question and retriever")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
history = limit_chat_history(
|
||||
json.loads(data.get("history", [])), gpt_model=gpt_model
|
||||
)
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
prompt_id = data_key.get("prompt_id", "default")
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
retriever_name = data_key.get("retriever", retriever_name)
|
||||
user_api_key = data["api_key"]
|
||||
decoded_token = {"sub": data_key.get("user")}
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
if not decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
logger.info(
|
||||
f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
agent = AgentCreator.create_agent(
|
||||
settings.AGENT_NAME,
|
||||
endpoint="api/answer",
|
||||
llm_name=settings.LLM_NAME,
|
||||
gpt_model=gpt_model,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
chat_history=history,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
tool_calls = []
|
||||
stream_ended = False
|
||||
|
||||
for line in complete_stream(
|
||||
question=question,
|
||||
agent=agent,
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=False,
|
||||
):
|
||||
try:
|
||||
event_data = line.replace("data: ", "").strip()
|
||||
event = json.loads(event_data)
|
||||
|
||||
if event["type"] == "answer":
|
||||
response_full += event["answer"]
|
||||
elif event["type"] == "source":
|
||||
source_log_docs = event["source"]
|
||||
elif event["type"] == "tool_calls":
|
||||
tool_calls = event["tool_calls"]
|
||||
elif event["type"] == "error":
|
||||
logger.error(f"Error from stream: {event['error']}")
|
||||
return bad_request(500, event["error"])
|
||||
elif event["type"] == "end":
|
||||
stream_ended = True
|
||||
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"Error parsing stream event: {e}, line: {line}")
|
||||
continue
|
||||
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return bad_request(500, "Stream ended unexpectedly.")
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
result = {"answer": response_full, "sources": source_log_docs}
|
||||
result["conversation_id"] = str(
|
||||
save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
decoded_token,
|
||||
api_key=user_api_key,
|
||||
)
|
||||
)
|
||||
|
||||
retriever_params = retriever.get_params()
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
"action": "api_answer",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"retriever_params": retriever_params,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return bad_request(500, str(e))
|
||||
|
||||
return make_response(result, 200)
|
||||
|
||||
|
||||
@answer_ns.route("/api/search")
|
||||
class Search(Resource):
|
||||
search_model = api.model(
|
||||
"SearchModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="The question to search"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"api_key": fields.String(
|
||||
required=False, description="API key for authentication"
|
||||
),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents for retrieval"
|
||||
),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"token_limit": fields.Integer(
|
||||
required=False, description="Limit for tokens"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(search_model)
|
||||
@api.doc(
|
||||
description="Search for relevant documents based on the question and retriever"
|
||||
)
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
user_api_key = data["api_key"]
|
||||
decoded_token = {"sub": data_key.get("user")}
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
if not decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
logger.info(
|
||||
f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
source=source,
|
||||
chat_history=[],
|
||||
prompt="default",
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
docs = retriever.search(question)
|
||||
retriever_params = retriever.get_params()
|
||||
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
"action": "api_search",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"sources": docs,
|
||||
"retriever_params": retriever_params,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
)
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return bad_request(500, str(e))
|
||||
|
||||
return make_response(docs, 200)
|
||||
0
application/api/answer/routes/__init__.py
Normal file
0
application/api/answer/routes/__init__.py
Normal file
153
application/api/answer/routes/answer.py
Normal file
153
application/api/answer/routes/answer.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_restx import fields, Resource
|
||||
|
||||
from application.api import api
|
||||
|
||||
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
|
||||
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@answer_ns.route("/api/answer")
|
||||
class AnswerResource(Resource, BaseAnswerResource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
Resource.__init__(self, *args, **kwargs)
|
||||
BaseAnswerResource.__init__(self)
|
||||
|
||||
answer_model = answer_ns.model(
|
||||
"AnswerModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="Question to be asked"
|
||||
),
|
||||
"history": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="Conversation history (only for new conversations)",
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=False,
|
||||
description="Existing conversation ID (loads history)",
|
||||
),
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"agent_id": fields.String(required=False, description="Agent ID"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
"save_conversation": fields.Boolean(
|
||||
required=False,
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
"model_id": fields.String(
|
||||
required=False,
|
||||
description="Model ID to use for this request",
|
||||
),
|
||||
"passthrough": fields.Raw(
|
||||
required=False,
|
||||
description="Dynamic parameters to inject into prompt template",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(answer_model)
|
||||
@api.doc(description="Provide a response based on the question and retriever")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
if error := self.validate_request(data):
|
||||
return error
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
try:
|
||||
# ---- Continuation mode ----
|
||||
if data.get("tool_actions"):
|
||||
(
|
||||
agent,
|
||||
messages,
|
||||
tools_dict,
|
||||
pending_tool_calls,
|
||||
tool_actions,
|
||||
) = processor.resume_from_tool_actions(
|
||||
data["tool_actions"], data["conversation_id"]
|
||||
)
|
||||
if not processor.decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
stream = self.complete_stream(
|
||||
question="",
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
_continuation={
|
||||
"messages": messages,
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
},
|
||||
)
|
||||
else:
|
||||
# ---- Normal mode ----
|
||||
agent = processor.build_agent(data.get("question", ""))
|
||||
if not processor.decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
|
||||
stream = self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
agent_id=processor.agent_id,
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
)
|
||||
|
||||
stream_result = self.process_response_stream(stream)
|
||||
|
||||
if stream_result["error"]:
|
||||
return make_response({"error": stream_result["error"]}, 400)
|
||||
|
||||
result = {
|
||||
"conversation_id": stream_result["conversation_id"],
|
||||
"answer": stream_result["answer"],
|
||||
"sources": stream_result["sources"],
|
||||
"tool_calls": stream_result["tool_calls"],
|
||||
"thought": stream_result["thought"],
|
||||
}
|
||||
|
||||
extra_info = stream_result.get("extra")
|
||||
if extra_info:
|
||||
result.update(extra_info)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return make_response({"error": "An error occurred processing your request"}, 500)
|
||||
return make_response(result, 200)
|
||||
640
application/api/answer/routes/base.py
Normal file
640
application/api/answer/routes/base.py
Normal file
@@ -0,0 +1,640 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
from flask import jsonify, make_response, Response
|
||||
from flask_restx import Namespace
|
||||
|
||||
from application.api.answer.services.continuation_service import ContinuationService
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.error import sanitize_api_error
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.utils import check_required_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
answer_ns = Namespace("answer", description="Answer related operations", path="/")
|
||||
|
||||
|
||||
class BaseAnswerResource:
|
||||
"""Shared base class for answer endpoints"""
|
||||
|
||||
def __init__(self):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
self.db = db
|
||||
self.user_logs_collection = db["user_logs"]
|
||||
self.default_model_id = get_default_model_id()
|
||||
self.conversation_service = ConversationService()
|
||||
|
||||
def validate_request(
|
||||
self, data: Dict[str, Any], require_conversation_id: bool = False
|
||||
) -> Optional[Response]:
|
||||
"""Common request validation.
|
||||
|
||||
Continuation requests (``tool_actions`` present) require
|
||||
``conversation_id`` but not ``question``.
|
||||
"""
|
||||
if data.get("tool_actions"):
|
||||
# Continuation mode — question is not required
|
||||
if missing := check_required_fields(data, ["conversation_id"]):
|
||||
return missing
|
||||
return None
|
||||
required_fields = ["question"]
|
||||
if require_conversation_id:
|
||||
required_fields.append("conversation_id")
|
||||
if missing_fields := check_required_fields(data, required_fields):
|
||||
return missing_fields
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _prepare_tool_calls_for_logging(
|
||||
tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000
|
||||
) -> List[Dict[str, Any]]:
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
prepared = []
|
||||
for tool_call in tool_calls:
|
||||
if not isinstance(tool_call, dict):
|
||||
prepared.append({"result": str(tool_call)[:max_chars]})
|
||||
continue
|
||||
|
||||
item = dict(tool_call)
|
||||
for key in ("result", "result_full"):
|
||||
value = item.get(key)
|
||||
if isinstance(value, str) and len(value) > max_chars:
|
||||
item[key] = value[:max_chars]
|
||||
prepared.append(item)
|
||||
return prepared
|
||||
|
||||
def check_usage(self, agent_config: Dict) -> Optional[Response]:
|
||||
"""Check if there is a usage limit and if it is exceeded
|
||||
|
||||
Args:
|
||||
agent_config: The config dict of agent instance
|
||||
|
||||
Returns:
|
||||
None or Response if either of limits exceeded.
|
||||
|
||||
"""
|
||||
api_key = agent_config.get("user_api_key")
|
||||
if not api_key:
|
||||
return None
|
||||
agents_collection = self.db["agents"]
|
||||
agent = agents_collection.find_one({"key": api_key})
|
||||
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid API key."}), 401
|
||||
)
|
||||
limited_token_mode_raw = agent.get("limited_token_mode", False)
|
||||
limited_request_mode_raw = agent.get("limited_request_mode", False)
|
||||
|
||||
limited_token_mode = (
|
||||
limited_token_mode_raw
|
||||
if isinstance(limited_token_mode_raw, bool)
|
||||
else limited_token_mode_raw == "True"
|
||||
)
|
||||
limited_request_mode = (
|
||||
limited_request_mode_raw
|
||||
if isinstance(limited_request_mode_raw, bool)
|
||||
else limited_request_mode_raw == "True"
|
||||
)
|
||||
|
||||
token_limit = int(
|
||||
agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
|
||||
)
|
||||
request_limit = int(
|
||||
agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
|
||||
)
|
||||
|
||||
token_usage_collection = self.db["token_usage"]
|
||||
|
||||
end_date = datetime.datetime.now()
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
|
||||
match_query = {
|
||||
"timestamp": {"$gte": start_date, "$lte": end_date},
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
||||
if limited_token_mode:
|
||||
token_pipeline = [
|
||||
{"$match": match_query},
|
||||
{
|
||||
"$group": {
|
||||
"_id": None,
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
},
|
||||
]
|
||||
token_result = list(token_usage_collection.aggregate(token_pipeline))
|
||||
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
|
||||
else:
|
||||
daily_token_usage = 0
|
||||
if limited_request_mode:
|
||||
daily_request_usage = token_usage_collection.count_documents(match_query)
|
||||
else:
|
||||
daily_request_usage = 0
|
||||
if not limited_token_mode and not limited_request_mode:
|
||||
return None
|
||||
token_exceeded = (
|
||||
limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
|
||||
)
|
||||
request_exceeded = (
|
||||
limited_request_mode
|
||||
and request_limit > 0
|
||||
and daily_request_usage >= request_limit
|
||||
)
|
||||
|
||||
if token_exceeded or request_exceeded:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Exceeding usage limit, please try again later.",
|
||||
}
|
||||
),
|
||||
429,
|
||||
)
|
||||
return None
|
||||
|
||||
def complete_stream(
|
||||
self,
|
||||
question: str,
|
||||
agent: Any,
|
||||
conversation_id: Optional[str],
|
||||
user_api_key: Optional[str],
|
||||
decoded_token: Dict[str, Any],
|
||||
isNoneDoc: bool = False,
|
||||
index: Optional[int] = None,
|
||||
should_save_conversation: bool = True,
|
||||
attachment_ids: Optional[List[str]] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
model_id: Optional[str] = None,
|
||||
_continuation: Optional[Dict] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Generator function that streams the complete conversation response.
|
||||
|
||||
Args:
|
||||
question: The user's question
|
||||
agent: The agent instance
|
||||
retriever: The retriever instance
|
||||
conversation_id: Existing conversation ID
|
||||
user_api_key: User's API key if any
|
||||
decoded_token: Decoded JWT token
|
||||
isNoneDoc: Flag for document-less responses
|
||||
index: Index of message to update
|
||||
should_save_conversation: Whether to persist the conversation
|
||||
attachment_ids: List of attachment IDs
|
||||
agent_id: ID of agent used
|
||||
is_shared_usage: Flag for shared agent usage
|
||||
shared_token: Token for shared agent
|
||||
model_id: Model ID used for the request
|
||||
retrieved_docs: Pre-fetched documents for sources (optional)
|
||||
|
||||
Yields:
|
||||
Server-sent event strings
|
||||
"""
|
||||
try:
|
||||
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
query_metadata = {}
|
||||
paused = False
|
||||
|
||||
if _continuation:
|
||||
gen_iter = agent.gen_continuation(
|
||||
messages=_continuation["messages"],
|
||||
tools_dict=_continuation["tools_dict"],
|
||||
pending_tool_calls=_continuation["pending_tool_calls"],
|
||||
tool_actions=_continuation["tool_actions"],
|
||||
)
|
||||
else:
|
||||
gen_iter = agent.gen(query=question)
|
||||
|
||||
for line in gen_iter:
|
||||
if "metadata" in line:
|
||||
query_metadata.update(line["metadata"])
|
||||
elif "answer" in line:
|
||||
response_full += str(line["answer"])
|
||||
if line.get("structured"):
|
||||
is_structured = True
|
||||
schema_info = line.get("schema")
|
||||
structured_chunks.append(line["answer"])
|
||||
else:
|
||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "sources" in line:
|
||||
truncated_sources = []
|
||||
source_log_docs = line["sources"]
|
||||
for source in line["sources"]:
|
||||
truncated_source = source.copy()
|
||||
if "text" in truncated_source:
|
||||
truncated_source["text"] = (
|
||||
truncated_source["text"][:100].strip() + "..."
|
||||
)
|
||||
truncated_sources.append(truncated_source)
|
||||
if truncated_sources:
|
||||
data = json.dumps(
|
||||
{"type": "source", "source": truncated_sources}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
elif "tool_calls" in line:
|
||||
tool_calls = line["tool_calls"]
|
||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "thought" in line:
|
||||
thought += line["thought"]
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "type" in line:
|
||||
if line.get("type") == "tool_calls_pending":
|
||||
# Save continuation state and end the stream
|
||||
paused = True
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
elif line.get("type") == "error":
|
||||
sanitized_error = {
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(line.get("error", "An error occurred"))
|
||||
}
|
||||
data = json.dumps(sanitized_error)
|
||||
yield f"data: {data}\n\n"
|
||||
else:
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
if is_structured and structured_chunks:
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
"answer": response_full,
|
||||
"structured": True,
|
||||
"schema": schema_info,
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# ---- Paused: save continuation state and end stream early ----
|
||||
if paused:
|
||||
continuation = getattr(agent, "_pending_continuation", None)
|
||||
if continuation:
|
||||
# Ensure we have a conversation_id — create a partial
|
||||
# conversation if this is the first turn.
|
||||
if not conversation_id and should_save_conversation:
|
||||
try:
|
||||
provider = (
|
||||
get_provider_from_model_id(model_id)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
sys_api_key = get_api_key_for_provider(
|
||||
provider or settings.LLM_PROVIDER
|
||||
)
|
||||
llm = LLMCreator.create_llm(
|
||||
provider or settings.LLM_PROVIDER,
|
||||
api_key=sys_api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
conversation_id = (
|
||||
self.conversation_service.save_conversation(
|
||||
None,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create conversation for continuation: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if conversation_id:
|
||||
try:
|
||||
cont_service = ContinuationService()
|
||||
cont_service.save_state(
|
||||
conversation_id=str(conversation_id),
|
||||
user=decoded_token.get("sub", "local"),
|
||||
messages=continuation["messages"],
|
||||
pending_tool_calls=continuation["pending_tool_calls"],
|
||||
tools_dict=continuation["tools_dict"],
|
||||
tool_schemas=getattr(agent, "tools", []),
|
||||
agent_config={
|
||||
"model_id": model_id or self.default_model_id,
|
||||
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
|
||||
"api_key": getattr(agent, "api_key", None),
|
||||
"user_api_key": user_api_key,
|
||||
"agent_id": agent_id,
|
||||
"agent_type": agent.__class__.__name__,
|
||||
"prompt": getattr(agent, "prompt", ""),
|
||||
"json_schema": getattr(agent, "json_schema", None),
|
||||
"retriever_config": getattr(agent, "retriever_config", None),
|
||||
},
|
||||
client_tools=getattr(
|
||||
agent.tool_executor, "client_tools", None
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save continuation state: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
provider = (
|
||||
get_provider_from_model_id(model_id)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
provider or settings.LLM_PROVIDER,
|
||||
api_key=system_api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
if should_save_conversation:
|
||||
conversation_id = self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
if conversation_id and compression_meta and not compression_saved:
|
||||
try:
|
||||
self.conversation_service.update_compression_metadata(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
self.conversation_service.append_compression_message(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
agent.compression_saved = True
|
||||
logger.info(
|
||||
f"Persisted compression metadata for conversation {conversation_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to persist compression metadata: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
conversation_id = None
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
|
||||
getattr(agent, "tool_calls", tool_calls) or tool_calls
|
||||
)
|
||||
|
||||
log_data = {
|
||||
"action": "stream_answer",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"agent_id": agent_id,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls_for_logging,
|
||||
"attachments": attachment_ids,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
if is_structured:
|
||||
log_data["structured_output"] = True
|
||||
if schema_info:
|
||||
log_data["schema"] = schema_info
|
||||
# Clean up text fields to be no longer than 10000 characters
|
||||
|
||||
for key, value in log_data.items():
|
||||
if isinstance(value, str) and len(value) > 10000:
|
||||
log_data[key] = value[:10000]
|
||||
self.user_logs_collection.insert_one(log_data)
|
||||
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.user_logs import UserLogsRepository
|
||||
|
||||
dual_write(
|
||||
UserLogsRepository,
|
||||
lambda repo, d=log_data: repo.insert(
|
||||
user_id=d.get("user"),
|
||||
endpoint="stream_answer",
|
||||
data=d,
|
||||
),
|
||||
)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
except GeneratorExit:
|
||||
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
|
||||
# Save partial response
|
||||
|
||||
if should_save_conversation and response_full:
|
||||
try:
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_PROVIDER,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
if conversation_id and compression_meta and not compression_saved:
|
||||
try:
|
||||
self.conversation_service.update_compression_metadata(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
self.conversation_service.append_compression_message(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
agent.compression_saved = True
|
||||
logger.info(
|
||||
f"Persisted compression metadata for conversation {conversation_id} (partial stream)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to persist compression metadata (partial stream): {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error saving partial response: {str(e)}", exc_info=True
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
||||
data = json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Please try again later. We apologize for any inconvenience.",
|
||||
}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
def process_response_stream(self, stream) -> Dict[str, Any]:
|
||||
"""Process the stream response for non-streaming endpoint.
|
||||
|
||||
Returns:
|
||||
Dict with keys: conversation_id, answer, sources, tool_calls,
|
||||
thought, error, and optional extra.
|
||||
"""
|
||||
conversation_id = ""
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
tool_calls = []
|
||||
thought = ""
|
||||
stream_ended = False
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
pending_tool_calls = None
|
||||
|
||||
for line in stream:
|
||||
try:
|
||||
event_data = line.replace("data: ", "").strip()
|
||||
event = json.loads(event_data)
|
||||
|
||||
if event["type"] == "id":
|
||||
conversation_id = event["id"]
|
||||
elif event["type"] == "answer":
|
||||
response_full += event["answer"]
|
||||
elif event["type"] == "structured_answer":
|
||||
response_full = event["answer"]
|
||||
is_structured = True
|
||||
schema_info = event.get("schema")
|
||||
elif event["type"] == "source":
|
||||
source_log_docs = event["source"]
|
||||
elif event["type"] == "tool_calls":
|
||||
tool_calls = event["tool_calls"]
|
||||
elif event["type"] == "tool_calls_pending":
|
||||
pending_tool_calls = event.get("data", {}).get(
|
||||
"pending_tool_calls", []
|
||||
)
|
||||
elif event["type"] == "thought":
|
||||
thought = event["thought"]
|
||||
elif event["type"] == "error":
|
||||
logger.error(f"Error from stream: {event['error']}")
|
||||
return {
|
||||
"conversation_id": None,
|
||||
"answer": None,
|
||||
"sources": None,
|
||||
"tool_calls": None,
|
||||
"thought": None,
|
||||
"error": event["error"],
|
||||
}
|
||||
elif event["type"] == "end":
|
||||
stream_ended = True
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"Error parsing stream event: {e}, line: {line}")
|
||||
continue
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return {
|
||||
"conversation_id": None,
|
||||
"answer": None,
|
||||
"sources": None,
|
||||
"tool_calls": None,
|
||||
"thought": None,
|
||||
"error": "Stream ended unexpectedly",
|
||||
}
|
||||
|
||||
result: Dict[str, Any] = {
|
||||
"conversation_id": conversation_id,
|
||||
"answer": response_full,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"thought": thought,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
if pending_tool_calls is not None:
|
||||
result["extra"] = {"pending_tool_calls": pending_tool_calls}
|
||||
|
||||
if is_structured:
|
||||
result["extra"] = {"structured": True, "schema": schema_info}
|
||||
|
||||
return result
|
||||
|
||||
def error_stream_generate(self, err_response):
|
||||
data = json.dumps({"type": "error", "error": err_response})
|
||||
yield f"data: {data}\n\n"
|
||||
186
application/api/answer/routes/search.py
Normal file
186
application/api/answer/routes/search.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_restx import fields, Resource
|
||||
|
||||
from bson.dbref import DBRef
|
||||
|
||||
from application.api.answer.routes.base import answer_ns
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@answer_ns.route("/api/search")
|
||||
class SearchResource(Resource):
|
||||
"""Fast search endpoint for retrieving relevant documents"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
mongo = MongoDB.get_client()
|
||||
self.db = mongo[settings.MONGO_DB_NAME]
|
||||
self.agents_collection = self.db["agents"]
|
||||
|
||||
search_model = answer_ns.model(
|
||||
"SearchModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="Search query"
|
||||
),
|
||||
"api_key": fields.String(
|
||||
required=True, description="API key for authentication"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=5, description="Number of results to return"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
|
||||
"""Get source IDs connected to the API key/agent.
|
||||
|
||||
"""
|
||||
agent_data = self.agents_collection.find_one({"key": api_key})
|
||||
if not agent_data:
|
||||
return []
|
||||
|
||||
source_ids = []
|
||||
|
||||
# Handle multiple sources (only if non-empty)
|
||||
sources = agent_data.get("sources", [])
|
||||
if sources and isinstance(sources, list) and len(sources) > 0:
|
||||
for source_ref in sources:
|
||||
# Skip "default" - it's a placeholder, not an actual vectorstore
|
||||
if source_ref == "default":
|
||||
continue
|
||||
elif isinstance(source_ref, DBRef):
|
||||
source_doc = self.db.dereference(source_ref)
|
||||
if source_doc:
|
||||
source_ids.append(str(source_doc["_id"]))
|
||||
|
||||
# Handle single source (legacy) - check if sources was empty or didn't yield results
|
||||
if not source_ids:
|
||||
source = agent_data.get("source")
|
||||
if isinstance(source, DBRef):
|
||||
source_doc = self.db.dereference(source)
|
||||
if source_doc:
|
||||
source_ids.append(str(source_doc["_id"]))
|
||||
# Skip "default" - it's a placeholder, not an actual vectorstore
|
||||
elif source and source != "default":
|
||||
source_ids.append(source)
|
||||
|
||||
return source_ids
|
||||
|
||||
def _search_vectorstores(
|
||||
self, query: str, source_ids: List[str], chunks: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search across vectorstores and return results"""
|
||||
if not source_ids:
|
||||
return []
|
||||
|
||||
results = []
|
||||
chunks_per_source = max(1, chunks // len(source_ids))
|
||||
seen_texts = set()
|
||||
|
||||
for source_id in source_ids:
|
||||
if not source_id or not source_id.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
docsearch = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
|
||||
)
|
||||
docs = docsearch.search(query, k=chunks_per_source * 2)
|
||||
|
||||
for doc in docs:
|
||||
if len(results) >= chunks:
|
||||
break
|
||||
|
||||
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
|
||||
page_content = doc.page_content
|
||||
metadata = doc.metadata
|
||||
else:
|
||||
page_content = doc.get("text", doc.get("page_content", ""))
|
||||
metadata = doc.get("metadata", {})
|
||||
|
||||
# Skip duplicates
|
||||
text_hash = hash(page_content[:200])
|
||||
if text_hash in seen_texts:
|
||||
continue
|
||||
seen_texts.add(text_hash)
|
||||
|
||||
title = metadata.get(
|
||||
"title", metadata.get("post_title", "")
|
||||
)
|
||||
if not isinstance(title, str):
|
||||
title = str(title) if title else ""
|
||||
|
||||
# Clean up title
|
||||
if title:
|
||||
title = title.split("/")[-1]
|
||||
else:
|
||||
# Use filename or first part of content as title
|
||||
title = metadata.get("filename", page_content[:50] + "...")
|
||||
|
||||
source = metadata.get("source", source_id)
|
||||
|
||||
results.append({
|
||||
"text": page_content,
|
||||
"title": title,
|
||||
"source": source,
|
||||
})
|
||||
|
||||
if len(results) >= chunks:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error searching vectorstore {source_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
|
||||
return results[:chunks]
|
||||
|
||||
@answer_ns.expect(search_model)
|
||||
@answer_ns.doc(description="Search for relevant documents based on query")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
|
||||
question = data.get("question")
|
||||
api_key = data.get("api_key")
|
||||
chunks = data.get("chunks", 5)
|
||||
|
||||
if not question:
|
||||
return make_response({"error": "question is required"}, 400)
|
||||
|
||||
if not api_key:
|
||||
return make_response({"error": "api_key is required"}, 400)
|
||||
|
||||
# Validate API key
|
||||
agent = self.agents_collection.find_one({"key": api_key})
|
||||
if not agent:
|
||||
return make_response({"error": "Invalid API key"}, 401)
|
||||
|
||||
try:
|
||||
# Get sources connected to this API key
|
||||
source_ids = self._get_sources_from_api_key(api_key)
|
||||
|
||||
if not source_ids:
|
||||
return make_response([], 200)
|
||||
|
||||
# Perform search
|
||||
results = self._search_vectorstores(question, source_ids, chunks)
|
||||
|
||||
return make_response(results, 200)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/search - error: {str(e)}",
|
||||
extra={"error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response({"error": "Search failed"}, 500)
|
||||
171
application/api/answer/routes/stream.py
Normal file
171
application/api/answer/routes/stream.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from flask import request, Response
|
||||
from flask_restx import fields, Resource
|
||||
|
||||
from application.api import api
|
||||
|
||||
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
|
||||
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@answer_ns.route("/stream")
|
||||
class StreamResource(Resource, BaseAnswerResource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
Resource.__init__(self, *args, **kwargs)
|
||||
BaseAnswerResource.__init__(self)
|
||||
|
||||
stream_model = answer_ns.model(
|
||||
"StreamModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="Question to be asked"
|
||||
),
|
||||
"history": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="Conversation history (only for new conversations)",
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=False,
|
||||
description="Existing conversation ID (loads history)",
|
||||
),
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"agent_id": fields.String(required=False, description="Agent ID"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
"index": fields.Integer(
|
||||
required=False, description="Index of the query to update"
|
||||
),
|
||||
"save_conversation": fields.Boolean(
|
||||
required=False,
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
"model_id": fields.String(
|
||||
required=False,
|
||||
description="Model ID to use for this request",
|
||||
),
|
||||
"attachments": fields.List(
|
||||
fields.String, required=False, description="List of attachment IDs"
|
||||
),
|
||||
"passthrough": fields.Raw(
|
||||
required=False,
|
||||
description="Dynamic parameters to inject into prompt template",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(stream_model)
|
||||
@api.doc(description="Stream a response based on the question and retriever")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
if error := self.validate_request(data, "index" in data):
|
||||
return error
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
|
||||
try:
|
||||
# ---- Continuation mode ----
|
||||
if data.get("tool_actions"):
|
||||
(
|
||||
agent,
|
||||
messages,
|
||||
tools_dict,
|
||||
pending_tool_calls,
|
||||
tool_actions,
|
||||
) = processor.resume_from_tool_actions(
|
||||
data["tool_actions"], data["conversation_id"]
|
||||
)
|
||||
if not processor.decoded_token:
|
||||
return Response(
|
||||
self.error_stream_generate("Unauthorized"),
|
||||
status=401,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
return Response(
|
||||
self.complete_stream(
|
||||
question="",
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
_continuation={
|
||||
"messages": messages,
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
},
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
# ---- Normal mode ----
|
||||
agent = processor.build_agent(data["question"])
|
||||
if not processor.decoded_token:
|
||||
return Response(
|
||||
self.error_stream_generate("Unauthorized"),
|
||||
status=401,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
return Response(
|
||||
self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=data.get("index"),
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
attachment_ids=data.get("attachments", []),
|
||||
agent_id=processor.agent_id,
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except ValueError as e:
|
||||
message = "Malformed request body"
|
||||
logger.error(
|
||||
f"/stream - error: {message} - specific error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return Response(
|
||||
self.error_stream_generate(message),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return Response(
|
||||
self.error_stream_generate("Unknown error occurred"),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
0
application/api/answer/services/__init__.py
Normal file
0
application/api/answer/services/__init__.py
Normal file
20
application/api/answer/services/compression/__init__.py
Normal file
20
application/api/answer/services/compression/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
Compression module for managing conversation context compression.
|
||||
|
||||
"""
|
||||
|
||||
from application.api.answer.services.compression.orchestrator import (
|
||||
CompressionOrchestrator,
|
||||
)
|
||||
from application.api.answer.services.compression.service import CompressionService
|
||||
from application.api.answer.services.compression.types import (
|
||||
CompressionResult,
|
||||
CompressionMetadata,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CompressionOrchestrator",
|
||||
"CompressionService",
|
||||
"CompressionResult",
|
||||
"CompressionMetadata",
|
||||
]
|
||||
249
application/api/answer/services/compression/message_builder.py
Normal file
249
application/api/answer/services/compression/message_builder.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""Message reconstruction utilities for compression."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageBuilder:
|
||||
"""Builds message arrays from compressed context."""
|
||||
|
||||
@staticmethod
|
||||
def build_from_compressed_context(
|
||||
system_prompt: str,
|
||||
compressed_summary: Optional[str],
|
||||
recent_queries: List[Dict],
|
||||
include_tool_calls: bool = False,
|
||||
context_type: str = "pre_request",
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Build messages from compressed context.
|
||||
|
||||
Args:
|
||||
system_prompt: Original system prompt
|
||||
compressed_summary: Compressed summary (if any)
|
||||
recent_queries: Recent uncompressed queries
|
||||
include_tool_calls: Whether to include tool calls from history
|
||||
context_type: Type of context ('pre_request' or 'mid_execution')
|
||||
|
||||
Returns:
|
||||
List of message dicts ready for LLM
|
||||
"""
|
||||
# Append compression summary to system prompt if present
|
||||
if compressed_summary:
|
||||
system_prompt = MessageBuilder._append_compression_context(
|
||||
system_prompt, compressed_summary, context_type
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
# Add recent history
|
||||
for query in recent_queries:
|
||||
if "prompt" in query and "response" in query:
|
||||
messages.append({"role": "user", "content": query["prompt"]})
|
||||
messages.append({"role": "assistant", "content": query["response"]})
|
||||
|
||||
# Add tool calls from history if present
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
args = tool_call.get("arguments")
|
||||
args_str = (
|
||||
json.dumps(args)
|
||||
if isinstance(args, dict)
|
||||
else (args or "{}")
|
||||
)
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.get("action_name", ""),
|
||||
"arguments": args_str,
|
||||
},
|
||||
}],
|
||||
})
|
||||
result = tool_call.get("result")
|
||||
result_str = (
|
||||
json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else (result or "")
|
||||
)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": result_str,
|
||||
})
|
||||
|
||||
# If no recent queries (everything was compressed), add a continuation user message
|
||||
if len(recent_queries) == 0 and compressed_summary:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": "Please continue with the remaining tasks based on the context above."
|
||||
})
|
||||
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _append_compression_context(
|
||||
system_prompt: str, compressed_summary: str, context_type: str = "pre_request"
|
||||
) -> str:
|
||||
"""
|
||||
Append compression context to system prompt.
|
||||
|
||||
Args:
|
||||
system_prompt: Original system prompt
|
||||
compressed_summary: Summary to append
|
||||
context_type: Type of compression context
|
||||
|
||||
Returns:
|
||||
Updated system prompt
|
||||
"""
|
||||
# Remove existing compression context if present
|
||||
if "This session is being continued" in system_prompt or "Context window limit reached" in system_prompt:
|
||||
parts = system_prompt.split("\n\n---\n\n")
|
||||
system_prompt = parts[0]
|
||||
|
||||
# Build appropriate context message based on type
|
||||
if context_type == "mid_execution":
|
||||
context_message = (
|
||||
"\n\n---\n\n"
|
||||
"Context window limit reached during execution. "
|
||||
"Previous conversation has been compressed to fit within limits. "
|
||||
"The conversation is summarized below:\n\n"
|
||||
f"{compressed_summary}"
|
||||
)
|
||||
else: # pre_request
|
||||
context_message = (
|
||||
"\n\n---\n\n"
|
||||
"This session is being continued from a previous conversation that "
|
||||
"has been compressed to fit within context limits. "
|
||||
"The conversation is summarized below:\n\n"
|
||||
f"{compressed_summary}"
|
||||
)
|
||||
|
||||
return system_prompt + context_message
|
||||
|
||||
@staticmethod
|
||||
def rebuild_messages_after_compression(
|
||||
messages: List[Dict],
|
||||
compressed_summary: Optional[str],
|
||||
recent_queries: List[Dict],
|
||||
include_current_execution: bool = False,
|
||||
include_tool_calls: bool = False,
|
||||
) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Rebuild the message list after compression so tool execution can continue.
|
||||
|
||||
Args:
|
||||
messages: Original message list
|
||||
compressed_summary: Compressed summary
|
||||
recent_queries: Recent uncompressed queries
|
||||
include_current_execution: Whether to preserve current execution messages
|
||||
include_tool_calls: Whether to include tool calls from history
|
||||
|
||||
Returns:
|
||||
Rebuilt message list or None if failed
|
||||
"""
|
||||
# Find the system message
|
||||
system_message = next(
|
||||
(msg for msg in messages if msg.get("role") == "system"), None
|
||||
)
|
||||
if not system_message:
|
||||
logger.warning("No system message found in messages list")
|
||||
return None
|
||||
|
||||
# Update system message with compressed summary
|
||||
if compressed_summary:
|
||||
content = system_message.get("content", "")
|
||||
system_message["content"] = MessageBuilder._append_compression_context(
|
||||
content, compressed_summary, "mid_execution"
|
||||
)
|
||||
logger.info(
|
||||
"Appended compression summary to system prompt (truncated): %s",
|
||||
(
|
||||
compressed_summary[:500] + "..."
|
||||
if len(compressed_summary) > 500
|
||||
else compressed_summary
|
||||
),
|
||||
)
|
||||
|
||||
rebuilt_messages = [system_message]
|
||||
|
||||
# Add recent history from compressed context
|
||||
for query in recent_queries:
|
||||
if "prompt" in query and "response" in query:
|
||||
rebuilt_messages.append({"role": "user", "content": query["prompt"]})
|
||||
rebuilt_messages.append(
|
||||
{"role": "assistant", "content": query["response"]}
|
||||
)
|
||||
|
||||
# Add tool calls from history if present
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
args = tool_call.get("arguments")
|
||||
args_str = (
|
||||
json.dumps(args)
|
||||
if isinstance(args, dict)
|
||||
else (args or "{}")
|
||||
)
|
||||
rebuilt_messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.get("action_name", ""),
|
||||
"arguments": args_str,
|
||||
},
|
||||
}],
|
||||
})
|
||||
result = tool_call.get("result")
|
||||
result_str = (
|
||||
json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else (result or "")
|
||||
)
|
||||
rebuilt_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": result_str,
|
||||
})
|
||||
|
||||
# If no recent queries (everything was compressed), add a continuation user message
|
||||
if len(recent_queries) == 0 and compressed_summary:
|
||||
rebuilt_messages.append({
|
||||
"role": "user",
|
||||
"content": "Please continue with the remaining tasks based on the context above."
|
||||
})
|
||||
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
|
||||
|
||||
if include_current_execution:
|
||||
# Preserve any messages that were added during the current execution cycle
|
||||
recent_msg_count = 1 # system message
|
||||
for query in recent_queries:
|
||||
if "prompt" in query and "response" in query:
|
||||
recent_msg_count += 2
|
||||
if "tool_calls" in query:
|
||||
recent_msg_count += len(query["tool_calls"]) * 2
|
||||
|
||||
if len(messages) > recent_msg_count:
|
||||
current_execution_messages = messages[recent_msg_count:]
|
||||
rebuilt_messages.extend(current_execution_messages)
|
||||
logger.info(
|
||||
f"Preserved {len(current_execution_messages)} messages from current execution cycle"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Messages rebuilt: {len(messages)} → {len(rebuilt_messages)} messages. "
|
||||
f"Ready to continue tool execution."
|
||||
)
|
||||
return rebuilt_messages
|
||||
233
application/api/answer/services/compression/orchestrator.py
Normal file
233
application/api/answer/services/compression/orchestrator.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""High-level compression orchestration."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.api.answer.services.compression.service import CompressionService
|
||||
from application.api.answer.services.compression.threshold_checker import (
|
||||
CompressionThresholdChecker,
|
||||
)
|
||||
from application.api.answer.services.compression.types import CompressionResult
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionOrchestrator:
|
||||
"""
|
||||
Facade for compression operations.
|
||||
|
||||
Coordinates between all compression components and provides
|
||||
a simple interface for callers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conversation_service: ConversationService,
|
||||
threshold_checker: Optional[CompressionThresholdChecker] = None,
|
||||
):
|
||||
"""
|
||||
Initialize orchestrator.
|
||||
|
||||
Args:
|
||||
conversation_service: Service for DB operations
|
||||
threshold_checker: Custom threshold checker (optional)
|
||||
"""
|
||||
self.conversation_service = conversation_service
|
||||
self.threshold_checker = threshold_checker or CompressionThresholdChecker()
|
||||
|
||||
def compress_if_needed(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
current_query_tokens: int = 500,
|
||||
) -> CompressionResult:
|
||||
"""
|
||||
Check if compression is needed and perform it if so.
|
||||
|
||||
This is the main entry point for compression operations.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
user_id: User ID
|
||||
model_id: Model being used for conversation
|
||||
decoded_token: User's decoded JWT token
|
||||
current_query_tokens: Estimated tokens for current query
|
||||
|
||||
Returns:
|
||||
CompressionResult with summary and recent queries
|
||||
"""
|
||||
try:
|
||||
# Load conversation
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
conversation_id, user_id
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
logger.warning(
|
||||
f"Conversation {conversation_id} not found for user {user_id}"
|
||||
)
|
||||
return CompressionResult.failure("Conversation not found")
|
||||
|
||||
# Check if compression is needed
|
||||
if not self.threshold_checker.should_compress(
|
||||
conversation, model_id, current_query_tokens
|
||||
):
|
||||
# No compression needed, return full history
|
||||
queries = conversation.get("queries", [])
|
||||
return CompressionResult.success_no_compression(queries)
|
||||
|
||||
# Perform compression
|
||||
return self._perform_compression(
|
||||
conversation_id, conversation, model_id, decoded_token
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in compress_if_needed: {str(e)}", exc_info=True
|
||||
)
|
||||
return CompressionResult.failure(str(e))
|
||||
|
||||
def _perform_compression(
|
||||
self,
|
||||
conversation_id: str,
|
||||
conversation: Dict[str, Any],
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
) -> CompressionResult:
|
||||
"""
|
||||
Perform the actual compression operation.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
conversation: Conversation document
|
||||
model_id: Model ID for conversation
|
||||
decoded_token: User token
|
||||
|
||||
Returns:
|
||||
CompressionResult
|
||||
"""
|
||||
try:
|
||||
# Determine which model to use for compression
|
||||
compression_model = (
|
||||
settings.COMPRESSION_MODEL_OVERRIDE
|
||||
if settings.COMPRESSION_MODEL_OVERRIDE
|
||||
else model_id
|
||||
)
|
||||
|
||||
# Get provider and API key for compression model
|
||||
provider = get_provider_from_model_id(compression_model)
|
||||
api_key = get_api_key_for_provider(provider)
|
||||
|
||||
# Create compression LLM
|
||||
compression_llm = LLMCreator.create_llm(
|
||||
provider,
|
||||
api_key=api_key,
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
model_id=compression_model,
|
||||
agent_id=conversation.get("agent_id"),
|
||||
)
|
||||
|
||||
# Create compression service with DB update capability
|
||||
compression_service = CompressionService(
|
||||
llm=compression_llm,
|
||||
model_id=compression_model,
|
||||
conversation_service=self.conversation_service,
|
||||
)
|
||||
|
||||
# Compress all queries up to the latest
|
||||
queries_count = len(conversation.get("queries", []))
|
||||
compress_up_to = queries_count - 1
|
||||
|
||||
if compress_up_to < 0:
|
||||
logger.warning("No queries to compress")
|
||||
return CompressionResult.success_no_compression([])
|
||||
|
||||
logger.info(
|
||||
f"Initiating compression for conversation {conversation_id}: "
|
||||
f"compressing all {queries_count} queries (0-{compress_up_to})"
|
||||
)
|
||||
|
||||
# Perform compression and save to DB
|
||||
metadata = compression_service.compress_and_save(
|
||||
conversation_id, conversation, compress_up_to
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Compression successful - ratio: {metadata.compression_ratio:.1f}x, "
|
||||
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
|
||||
)
|
||||
|
||||
# Reload conversation with updated metadata
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
conversation_id, user_id=decoded_token.get("sub")
|
||||
)
|
||||
|
||||
# Get compressed context
|
||||
compressed_summary, recent_queries = (
|
||||
compression_service.get_compressed_context(conversation)
|
||||
)
|
||||
|
||||
return CompressionResult.success_with_compression(
|
||||
compressed_summary, recent_queries, metadata
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing compression: {str(e)}", exc_info=True)
|
||||
return CompressionResult.failure(str(e))
|
||||
|
||||
def compress_mid_execution(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_id: str,
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
current_conversation: Optional[Dict[str, Any]] = None,
|
||||
) -> CompressionResult:
|
||||
"""
|
||||
Perform compression during tool execution.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
user_id: User ID
|
||||
model_id: Model ID
|
||||
decoded_token: User token
|
||||
current_conversation: Pre-loaded conversation (optional)
|
||||
|
||||
Returns:
|
||||
CompressionResult
|
||||
"""
|
||||
try:
|
||||
# Load conversation if not provided
|
||||
if current_conversation:
|
||||
conversation = current_conversation
|
||||
else:
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
conversation_id, user_id
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
logger.warning(
|
||||
f"Could not load conversation {conversation_id} for mid-execution compression"
|
||||
)
|
||||
return CompressionResult.failure("Conversation not found")
|
||||
|
||||
# Perform compression
|
||||
return self._perform_compression(
|
||||
conversation_id, conversation, model_id, decoded_token
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in mid-execution compression: {str(e)}", exc_info=True
|
||||
)
|
||||
return CompressionResult.failure(str(e))
|
||||
149
application/api/answer/services/compression/prompt_builder.py
Normal file
149
application/api/answer/services/compression/prompt_builder.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Compression prompt building logic."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionPromptBuilder:
|
||||
"""Builds prompts for LLM compression calls."""
|
||||
|
||||
def __init__(self, version: str = "v1.0"):
|
||||
"""
|
||||
Initialize prompt builder.
|
||||
|
||||
Args:
|
||||
version: Prompt template version to use
|
||||
"""
|
||||
self.version = version
|
||||
self.system_prompt = self._load_prompt(version)
|
||||
|
||||
def _load_prompt(self, version: str) -> str:
|
||||
"""
|
||||
Load prompt template from file.
|
||||
|
||||
Args:
|
||||
version: Version string (e.g., 'v1.0')
|
||||
|
||||
Returns:
|
||||
Prompt template content
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If prompt template file doesn't exist
|
||||
"""
|
||||
current_dir = Path(__file__).resolve().parents[4]
|
||||
prompt_path = current_dir / "prompts" / "compression" / f"{version}.txt"
|
||||
|
||||
try:
|
||||
with open(prompt_path, "r") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Compression prompt template not found: {prompt_path}")
|
||||
raise FileNotFoundError(
|
||||
f"Compression prompt template '{version}' not found at {prompt_path}. "
|
||||
f"Please ensure the template file exists."
|
||||
)
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
queries: List[Dict[str, Any]],
|
||||
existing_compressions: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Build messages for compression LLM call.
|
||||
|
||||
Args:
|
||||
queries: List of query objects to compress
|
||||
existing_compressions: List of previous compression points
|
||||
|
||||
Returns:
|
||||
List of message dicts for LLM
|
||||
"""
|
||||
# Build conversation text
|
||||
conversation_text = self._format_conversation(queries)
|
||||
|
||||
# Add existing compression context if present
|
||||
existing_compression_context = ""
|
||||
if existing_compressions and len(existing_compressions) > 0:
|
||||
existing_compression_context = (
|
||||
"\n\nIMPORTANT: This conversation has been compressed before. "
|
||||
"Previous compression summaries:\n\n"
|
||||
)
|
||||
for i, comp in enumerate(existing_compressions):
|
||||
existing_compression_context += (
|
||||
f"--- Compression {i + 1} (up to message {comp.get('query_index', 'unknown')}) ---\n"
|
||||
f"{comp.get('compressed_summary', '')}\n\n"
|
||||
)
|
||||
existing_compression_context += (
|
||||
"Your task is to create a NEW summary that incorporates the context from "
|
||||
"previous compressions AND the new messages below. The final summary should "
|
||||
"be comprehensive and include all important information from both previous "
|
||||
"compressions and new messages.\n\n"
|
||||
)
|
||||
|
||||
user_prompt = (
|
||||
f"{existing_compression_context}"
|
||||
f"Here is the conversation to summarize:\n\n"
|
||||
f"{conversation_text}"
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": self.system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
def _format_conversation(self, queries: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Format conversation queries into readable text for compression.
|
||||
|
||||
Args:
|
||||
queries: List of query objects
|
||||
|
||||
Returns:
|
||||
Formatted conversation text
|
||||
"""
|
||||
conversation_lines = []
|
||||
|
||||
for i, query in enumerate(queries):
|
||||
conversation_lines.append(f"--- Message {i + 1} ---")
|
||||
conversation_lines.append(f"User: {query.get('prompt', '')}")
|
||||
|
||||
# Add tool calls if present
|
||||
tool_calls = query.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
conversation_lines.append("\nTool Calls:")
|
||||
for tc in tool_calls:
|
||||
tool_name = tc.get("tool_name", "unknown")
|
||||
action_name = tc.get("action_name", "unknown")
|
||||
arguments = tc.get("arguments", {})
|
||||
result = tc.get("result", "")
|
||||
if result is None:
|
||||
result = ""
|
||||
status = tc.get("status", "unknown")
|
||||
|
||||
# Include full tool result for complete compression context
|
||||
conversation_lines.append(
|
||||
f" - {tool_name}.{action_name}({arguments}) "
|
||||
f"[{status}] → {result}"
|
||||
)
|
||||
|
||||
# Add agent thought if present
|
||||
thought = query.get("thought", "")
|
||||
if thought:
|
||||
conversation_lines.append(f"\nAgent Thought: {thought}")
|
||||
|
||||
# Add assistant response
|
||||
conversation_lines.append(f"\nAssistant: {query.get('response', '')}")
|
||||
|
||||
# Add sources if present
|
||||
sources = query.get("sources", [])
|
||||
if sources:
|
||||
conversation_lines.append(f"\nSources Used: {len(sources)} documents")
|
||||
|
||||
conversation_lines.append("") # Empty line between messages
|
||||
|
||||
return "\n".join(conversation_lines)
|
||||
306
application/api/answer/services/compression/service.py
Normal file
306
application/api/answer/services/compression/service.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""Core compression service with simplified responsibilities."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.api.answer.services.compression.prompt_builder import (
|
||||
CompressionPromptBuilder,
|
||||
)
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
from application.api.answer.services.compression.types import (
|
||||
CompressionMetadata,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionService:
|
||||
"""
|
||||
Service for compressing conversation history.
|
||||
|
||||
Handles DB updates.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm,
|
||||
model_id: str,
|
||||
conversation_service=None,
|
||||
prompt_builder: Optional[CompressionPromptBuilder] = None,
|
||||
):
|
||||
"""
|
||||
Initialize compression service.
|
||||
|
||||
Args:
|
||||
llm: LLM instance to use for compression
|
||||
model_id: Model ID for compression
|
||||
conversation_service: Service for DB operations (optional, for DB updates)
|
||||
prompt_builder: Custom prompt builder (optional)
|
||||
"""
|
||||
self.llm = llm
|
||||
self.model_id = model_id
|
||||
self.conversation_service = conversation_service
|
||||
self.prompt_builder = prompt_builder or CompressionPromptBuilder(
|
||||
version=settings.COMPRESSION_PROMPT_VERSION
|
||||
)
|
||||
|
||||
def compress_conversation(
|
||||
self,
|
||||
conversation: Dict[str, Any],
|
||||
compress_up_to_index: int,
|
||||
) -> CompressionMetadata:
|
||||
"""
|
||||
Compress conversation history up to specified index.
|
||||
|
||||
Args:
|
||||
conversation: Full conversation document
|
||||
compress_up_to_index: Last query index to include in compression
|
||||
|
||||
Returns:
|
||||
CompressionMetadata with compression details
|
||||
|
||||
Raises:
|
||||
ValueError: If compress_up_to_index is invalid
|
||||
"""
|
||||
try:
|
||||
queries = conversation.get("queries", [])
|
||||
|
||||
if compress_up_to_index < 0 or compress_up_to_index >= len(queries):
|
||||
raise ValueError(
|
||||
f"Invalid compress_up_to_index: {compress_up_to_index} "
|
||||
f"(conversation has {len(queries)} queries)"
|
||||
)
|
||||
|
||||
# Get queries to compress
|
||||
queries_to_compress = queries[: compress_up_to_index + 1]
|
||||
|
||||
# Check if there are existing compressions
|
||||
existing_compressions = conversation.get("compression_metadata", {}).get(
|
||||
"compression_points", []
|
||||
)
|
||||
|
||||
if existing_compressions:
|
||||
logger.info(
|
||||
f"Found {len(existing_compressions)} previous compression(s) - "
|
||||
f"will incorporate into new summary"
|
||||
)
|
||||
|
||||
# Calculate original token count
|
||||
original_tokens = TokenCounter.count_query_tokens(queries_to_compress)
|
||||
|
||||
# Log tool call stats
|
||||
self._log_tool_call_stats(queries_to_compress)
|
||||
|
||||
# Build compression prompt
|
||||
messages = self.prompt_builder.build_prompt(
|
||||
queries_to_compress, existing_compressions
|
||||
)
|
||||
|
||||
# Call LLM to generate compression
|
||||
logger.info(
|
||||
f"Starting compression: {len(queries_to_compress)} queries "
|
||||
f"(messages 0-{compress_up_to_index}, {original_tokens} tokens) "
|
||||
f"using model {self.model_id}"
|
||||
)
|
||||
|
||||
response = self.llm.gen(
|
||||
model=self.model_id, messages=messages, max_tokens=4000
|
||||
)
|
||||
|
||||
# Extract summary from response
|
||||
compressed_summary = self._extract_summary(response)
|
||||
|
||||
# Calculate compressed token count
|
||||
compressed_tokens = TokenCounter.count_message_tokens(
|
||||
[{"content": compressed_summary}]
|
||||
)
|
||||
|
||||
# Calculate compression ratio
|
||||
compression_ratio = (
|
||||
original_tokens / compressed_tokens if compressed_tokens > 0 else 0
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Compression complete: {original_tokens} → {compressed_tokens} tokens "
|
||||
f"({compression_ratio:.1f}x compression)"
|
||||
)
|
||||
|
||||
# Build compression metadata
|
||||
compression_metadata = CompressionMetadata(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
query_index=compress_up_to_index,
|
||||
compressed_summary=compressed_summary,
|
||||
original_token_count=original_tokens,
|
||||
compressed_token_count=compressed_tokens,
|
||||
compression_ratio=compression_ratio,
|
||||
model_used=self.model_id,
|
||||
compression_prompt_version=self.prompt_builder.version,
|
||||
)
|
||||
|
||||
return compression_metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error compressing conversation: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def compress_and_save(
|
||||
self,
|
||||
conversation_id: str,
|
||||
conversation: Dict[str, Any],
|
||||
compress_up_to_index: int,
|
||||
) -> CompressionMetadata:
|
||||
"""
|
||||
Compress conversation and save to database.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
conversation: Full conversation document
|
||||
compress_up_to_index: Last query index to include
|
||||
|
||||
Returns:
|
||||
CompressionMetadata
|
||||
|
||||
Raises:
|
||||
ValueError: If conversation_service not provided or invalid index
|
||||
"""
|
||||
if not self.conversation_service:
|
||||
raise ValueError(
|
||||
"conversation_service required for compress_and_save operation"
|
||||
)
|
||||
|
||||
# Perform compression
|
||||
metadata = self.compress_conversation(conversation, compress_up_to_index)
|
||||
|
||||
# Save to database
|
||||
self.conversation_service.update_compression_metadata(
|
||||
conversation_id, metadata.to_dict()
|
||||
)
|
||||
|
||||
logger.info(f"Compression metadata saved to database for {conversation_id}")
|
||||
|
||||
return metadata
|
||||
|
||||
def get_compressed_context(
|
||||
self, conversation: Dict[str, Any]
|
||||
) -> tuple[Optional[str], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Get compressed summary + recent uncompressed messages.
|
||||
|
||||
Args:
|
||||
conversation: Full conversation document
|
||||
|
||||
Returns:
|
||||
(compressed_summary, recent_messages)
|
||||
"""
|
||||
try:
|
||||
compression_metadata = conversation.get("compression_metadata", {})
|
||||
|
||||
if not compression_metadata.get("is_compressed"):
|
||||
logger.debug("No compression metadata found - using full history")
|
||||
queries = conversation.get("queries", [])
|
||||
if queries is None:
|
||||
logger.error("Conversation queries is None - returning empty list")
|
||||
return None, []
|
||||
return None, queries
|
||||
|
||||
compression_points = compression_metadata.get("compression_points", [])
|
||||
|
||||
if not compression_points:
|
||||
logger.debug("No compression points found - using full history")
|
||||
queries = conversation.get("queries", [])
|
||||
if queries is None:
|
||||
logger.error("Conversation queries is None - returning empty list")
|
||||
return None, []
|
||||
return None, queries
|
||||
|
||||
# Get the most recent compression point
|
||||
latest_compression = compression_points[-1]
|
||||
compressed_summary = latest_compression.get("compressed_summary")
|
||||
last_compressed_index = latest_compression.get("query_index")
|
||||
compressed_tokens = latest_compression.get("compressed_token_count", 0)
|
||||
original_tokens = latest_compression.get("original_token_count", 0)
|
||||
|
||||
# Get only messages after compression point
|
||||
queries = conversation.get("queries", [])
|
||||
total_queries = len(queries)
|
||||
recent_queries = queries[last_compressed_index + 1 :]
|
||||
|
||||
logger.info(
|
||||
f"Using compressed context: summary ({compressed_tokens} tokens, "
|
||||
f"compressed from {original_tokens}) + {len(recent_queries)} recent messages "
|
||||
f"(messages {last_compressed_index + 1}-{total_queries - 1})"
|
||||
)
|
||||
|
||||
return compressed_summary, recent_queries
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting compressed context: {str(e)}", exc_info=True
|
||||
)
|
||||
queries = conversation.get("queries", [])
|
||||
if queries is None:
|
||||
return None, []
|
||||
return None, queries
|
||||
|
||||
def _extract_summary(self, llm_response: str) -> str:
|
||||
"""
|
||||
Extract clean summary from LLM response.
|
||||
|
||||
Args:
|
||||
llm_response: Raw LLM response
|
||||
|
||||
Returns:
|
||||
Cleaned summary text
|
||||
"""
|
||||
try:
|
||||
# Try to extract content within <summary> tags
|
||||
summary_match = re.search(
|
||||
r"<summary>(.*?)</summary>", llm_response, re.DOTALL
|
||||
)
|
||||
|
||||
if summary_match:
|
||||
summary = summary_match.group(1).strip()
|
||||
else:
|
||||
# If no summary tags, remove analysis tags and use the rest
|
||||
summary = re.sub(
|
||||
r"<analysis>.*?</analysis>", "", llm_response, flags=re.DOTALL
|
||||
).strip()
|
||||
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error extracting summary: {str(e)}, using full response")
|
||||
return llm_response
|
||||
|
||||
def _log_tool_call_stats(self, queries: List[Dict[str, Any]]) -> None:
|
||||
"""Log statistics about tool calls in queries."""
|
||||
total_tool_calls = 0
|
||||
total_tool_result_chars = 0
|
||||
tool_call_breakdown = {}
|
||||
|
||||
for q in queries:
|
||||
for tc in q.get("tool_calls", []):
|
||||
total_tool_calls += 1
|
||||
tool_name = tc.get("tool_name", "unknown")
|
||||
action_name = tc.get("action_name", "unknown")
|
||||
key = f"{tool_name}.{action_name}"
|
||||
tool_call_breakdown[key] = tool_call_breakdown.get(key, 0) + 1
|
||||
|
||||
# Track total tool result size
|
||||
result = tc.get("result", "")
|
||||
if result:
|
||||
total_tool_result_chars += len(str(result))
|
||||
|
||||
if total_tool_calls > 0:
|
||||
tool_breakdown_str = ", ".join(
|
||||
f"{tool}({count})"
|
||||
for tool, count in sorted(tool_call_breakdown.items())
|
||||
)
|
||||
tool_result_kb = total_tool_result_chars / 1024
|
||||
logger.info(
|
||||
f"Tool call breakdown: {tool_breakdown_str} "
|
||||
f"(total result size: {tool_result_kb:.1f} KB, {total_tool_result_chars:,} chars)"
|
||||
)
|
||||
103
application/api/answer/services/compression/threshold_checker.py
Normal file
103
application/api/answer/services/compression/threshold_checker.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Compression threshold checking logic."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from application.core.model_utils import get_token_limit
|
||||
from application.core.settings import settings
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionThresholdChecker:
|
||||
"""Determines if compression is needed based on token thresholds."""
|
||||
|
||||
def __init__(self, threshold_percentage: float = None):
|
||||
"""
|
||||
Initialize threshold checker.
|
||||
|
||||
Args:
|
||||
threshold_percentage: Percentage of context to use as threshold
|
||||
(defaults to settings.COMPRESSION_THRESHOLD_PERCENTAGE)
|
||||
"""
|
||||
self.threshold_percentage = (
|
||||
threshold_percentage or settings.COMPRESSION_THRESHOLD_PERCENTAGE
|
||||
)
|
||||
|
||||
def should_compress(
|
||||
self,
|
||||
conversation: Dict[str, Any],
|
||||
model_id: str,
|
||||
current_query_tokens: int = 500,
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if compression is needed.
|
||||
|
||||
Args:
|
||||
conversation: Full conversation document
|
||||
model_id: Target model for this request
|
||||
current_query_tokens: Estimated tokens for current query
|
||||
|
||||
Returns:
|
||||
True if tokens >= threshold% of context window
|
||||
"""
|
||||
try:
|
||||
# Calculate total tokens in conversation
|
||||
total_tokens = TokenCounter.count_conversation_tokens(conversation)
|
||||
total_tokens += current_query_tokens
|
||||
|
||||
# Get context window limit for model
|
||||
context_limit = get_token_limit(model_id)
|
||||
|
||||
# Calculate threshold
|
||||
threshold = int(context_limit * self.threshold_percentage)
|
||||
|
||||
compression_needed = total_tokens >= threshold
|
||||
percentage_used = (total_tokens / context_limit) * 100
|
||||
|
||||
if compression_needed:
|
||||
logger.warning(
|
||||
f"COMPRESSION TRIGGERED: {total_tokens} tokens / {context_limit} limit "
|
||||
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%)"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Compression check: {total_tokens}/{context_limit} tokens "
|
||||
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%) - No compression needed"
|
||||
)
|
||||
|
||||
return compression_needed
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
|
||||
return False
|
||||
|
||||
def check_message_tokens(self, messages: list, model_id: str) -> bool:
|
||||
"""
|
||||
Check if message list exceeds threshold.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts
|
||||
model_id: Target model
|
||||
|
||||
Returns:
|
||||
True if at or above threshold
|
||||
"""
|
||||
try:
|
||||
current_tokens = TokenCounter.count_message_tokens(messages)
|
||||
context_limit = get_token_limit(model_id)
|
||||
threshold = int(context_limit * self.threshold_percentage)
|
||||
|
||||
if current_tokens >= threshold:
|
||||
logger.warning(
|
||||
f"Message context limit approaching: {current_tokens}/{context_limit} tokens "
|
||||
f"({(current_tokens/context_limit)*100:.1f}%)"
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking message tokens: {str(e)}", exc_info=True)
|
||||
return False
|
||||
103
application/api/answer/services/compression/token_counter.py
Normal file
103
application/api/answer/services/compression/token_counter.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Token counting utilities for compression."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from application.utils import num_tokens_from_string
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenCounter:
|
||||
"""Centralized token counting for conversations and messages."""
|
||||
|
||||
@staticmethod
|
||||
def count_message_tokens(messages: List[Dict]) -> int:
|
||||
"""
|
||||
Calculate total tokens in a list of messages.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'content' field
|
||||
|
||||
Returns:
|
||||
Total token count
|
||||
"""
|
||||
total_tokens = 0
|
||||
for message in messages:
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, str):
|
||||
total_tokens += num_tokens_from_string(content)
|
||||
elif isinstance(content, list):
|
||||
# Handle structured content (tool calls, etc.)
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
total_tokens += num_tokens_from_string(str(item))
|
||||
return total_tokens
|
||||
|
||||
@staticmethod
|
||||
def count_query_tokens(
|
||||
queries: List[Dict[str, Any]], include_tool_calls: bool = True
|
||||
) -> int:
|
||||
"""
|
||||
Count tokens across multiple query objects.
|
||||
|
||||
Args:
|
||||
queries: List of query objects from conversation
|
||||
include_tool_calls: Whether to count tool call tokens
|
||||
|
||||
Returns:
|
||||
Total token count
|
||||
"""
|
||||
total_tokens = 0
|
||||
|
||||
for query in queries:
|
||||
# Count prompt and response tokens
|
||||
if "prompt" in query:
|
||||
total_tokens += num_tokens_from_string(query["prompt"])
|
||||
if "response" in query:
|
||||
total_tokens += num_tokens_from_string(query["response"])
|
||||
if "thought" in query:
|
||||
total_tokens += num_tokens_from_string(query.get("thought", ""))
|
||||
|
||||
# Count tool call tokens
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
tool_call_string = (
|
||||
f"Tool: {tool_call.get('tool_name')} | "
|
||||
f"Action: {tool_call.get('action_name')} | "
|
||||
f"Args: {tool_call.get('arguments')} | "
|
||||
f"Response: {tool_call.get('result')}"
|
||||
)
|
||||
total_tokens += num_tokens_from_string(tool_call_string)
|
||||
|
||||
return total_tokens
|
||||
|
||||
@staticmethod
|
||||
def count_conversation_tokens(
|
||||
conversation: Dict[str, Any], include_system_prompt: bool = False
|
||||
) -> int:
|
||||
"""
|
||||
Calculate total tokens in a conversation.
|
||||
|
||||
Args:
|
||||
conversation: Conversation document
|
||||
include_system_prompt: Whether to include system prompt in count
|
||||
|
||||
Returns:
|
||||
Total token count
|
||||
"""
|
||||
try:
|
||||
queries = conversation.get("queries", [])
|
||||
total_tokens = TokenCounter.count_query_tokens(queries)
|
||||
|
||||
# Add system prompt tokens if requested
|
||||
if include_system_prompt:
|
||||
# Rough estimate for system prompt
|
||||
total_tokens += settings.RESERVED_TOKENS.get("system_prompt", 500)
|
||||
|
||||
return total_tokens
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating conversation tokens: {str(e)}")
|
||||
return 0
|
||||
83
application/api/answer/services/compression/types.py
Normal file
83
application/api/answer/services/compression/types.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Type definitions for compression module."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompressionMetadata:
|
||||
"""Metadata about a compression operation."""
|
||||
|
||||
timestamp: datetime
|
||||
query_index: int
|
||||
compressed_summary: str
|
||||
original_token_count: int
|
||||
compressed_token_count: int
|
||||
compression_ratio: float
|
||||
model_used: str
|
||||
compression_prompt_version: str
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for DB storage."""
|
||||
return {
|
||||
"timestamp": self.timestamp,
|
||||
"query_index": self.query_index,
|
||||
"compressed_summary": self.compressed_summary,
|
||||
"original_token_count": self.original_token_count,
|
||||
"compressed_token_count": self.compressed_token_count,
|
||||
"compression_ratio": self.compression_ratio,
|
||||
"model_used": self.model_used,
|
||||
"compression_prompt_version": self.compression_prompt_version,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompressionResult:
|
||||
"""Result of a compression operation."""
|
||||
|
||||
success: bool
|
||||
compressed_summary: Optional[str] = None
|
||||
recent_queries: List[Dict[str, Any]] = field(default_factory=list)
|
||||
metadata: Optional[CompressionMetadata] = None
|
||||
error: Optional[str] = None
|
||||
compression_performed: bool = False
|
||||
|
||||
@classmethod
|
||||
def success_with_compression(
|
||||
cls, summary: str, queries: List[Dict], metadata: CompressionMetadata
|
||||
) -> "CompressionResult":
|
||||
"""Create a successful result with compression."""
|
||||
return cls(
|
||||
success=True,
|
||||
compressed_summary=summary,
|
||||
recent_queries=queries,
|
||||
metadata=metadata,
|
||||
compression_performed=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def success_no_compression(cls, queries: List[Dict]) -> "CompressionResult":
|
||||
"""Create a successful result without compression needed."""
|
||||
return cls(
|
||||
success=True,
|
||||
recent_queries=queries,
|
||||
compression_performed=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def failure(cls, error: str) -> "CompressionResult":
|
||||
"""Create a failure result."""
|
||||
return cls(success=False, error=error, compression_performed=False)
|
||||
|
||||
def as_history(self) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Convert recent queries to history format.
|
||||
|
||||
Returns:
|
||||
List of prompt/response dicts
|
||||
"""
|
||||
return [
|
||||
{"prompt": q["prompt"], "response": q["response"]}
|
||||
for q in self.recent_queries
|
||||
]
|
||||
175
application/api/answer/services/continuation_service.py
Normal file
175
application/api/answer/services/continuation_service.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Service for saving and restoring tool-call continuation state.
|
||||
|
||||
When a stream pauses (tool needs approval or client-side execution),
|
||||
the full execution state is persisted to MongoDB so the client can
|
||||
resume later by sending tool_actions.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from bson import ObjectId
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.pending_tool_state import (
|
||||
PendingToolStateRepository,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TTL for pending states — auto-cleaned after this period
|
||||
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _make_serializable(obj: Any) -> Any:
|
||||
"""Recursively convert MongoDB ObjectIds and other non-JSON types."""
|
||||
if isinstance(obj, ObjectId):
|
||||
return str(obj)
|
||||
if isinstance(obj, dict):
|
||||
return {str(k): _make_serializable(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_make_serializable(v) for v in obj]
|
||||
if isinstance(obj, bytes):
|
||||
return obj.decode("utf-8", errors="replace")
|
||||
return obj
|
||||
|
||||
|
||||
class ContinuationService:
|
||||
"""Manages pending tool-call state in MongoDB."""
|
||||
|
||||
def __init__(self):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
self.collection = db["pending_tool_state"]
|
||||
self._ensure_indexes()
|
||||
|
||||
def _ensure_indexes(self):
|
||||
try:
|
||||
self.collection.create_index(
|
||||
"expires_at", expireAfterSeconds=0
|
||||
)
|
||||
self.collection.create_index(
|
||||
[("conversation_id", 1), ("user", 1)], unique=True
|
||||
)
|
||||
except Exception:
|
||||
# Indexes may already exist or mongomock doesn't support TTL
|
||||
pass
|
||||
|
||||
def save_state(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user: str,
|
||||
messages: List[Dict],
|
||||
pending_tool_calls: List[Dict],
|
||||
tools_dict: Dict,
|
||||
tool_schemas: List[Dict],
|
||||
agent_config: Dict,
|
||||
client_tools: Optional[List[Dict]] = None,
|
||||
) -> str:
|
||||
"""Save execution state for later continuation.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation this state belongs to.
|
||||
user: Owner user ID.
|
||||
messages: Full messages array at the pause point.
|
||||
pending_tool_calls: Tool calls awaiting client action.
|
||||
tools_dict: Serializable tools configuration dict.
|
||||
tool_schemas: LLM-formatted tool schemas (agent.tools).
|
||||
agent_config: Config needed to recreate the agent on resume.
|
||||
client_tools: Client-provided tool schemas for client-side execution.
|
||||
|
||||
Returns:
|
||||
The string ID of the saved state document.
|
||||
"""
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
expires_at = now + datetime.timedelta(seconds=PENDING_STATE_TTL_SECONDS)
|
||||
|
||||
doc = {
|
||||
"conversation_id": conversation_id,
|
||||
"user": user,
|
||||
"messages": _make_serializable(messages),
|
||||
"pending_tool_calls": _make_serializable(pending_tool_calls),
|
||||
"tools_dict": _make_serializable(tools_dict),
|
||||
"tool_schemas": _make_serializable(tool_schemas),
|
||||
"agent_config": _make_serializable(agent_config),
|
||||
"client_tools": _make_serializable(client_tools) if client_tools else None,
|
||||
"created_at": now,
|
||||
"expires_at": expires_at,
|
||||
}
|
||||
|
||||
# Upsert — only one pending state per conversation per user
|
||||
result = self.collection.replace_one(
|
||||
{"conversation_id": conversation_id, "user": user},
|
||||
doc,
|
||||
upsert=True,
|
||||
)
|
||||
state_id = str(result.upserted_id) if result.upserted_id else conversation_id
|
||||
logger.info(
|
||||
f"Saved continuation state for conversation {conversation_id} "
|
||||
f"with {len(pending_tool_calls)} pending tool call(s)"
|
||||
)
|
||||
|
||||
# Dual-write to Postgres — upsert against the same Mongo conversation
|
||||
# by resolving its UUID via conversations.legacy_mongo_id.
|
||||
def _pg_save(_: PendingToolStateRepository) -> None:
|
||||
conn = _._conn # reuse the existing transaction
|
||||
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
_.save_state(
|
||||
conv["id"],
|
||||
user,
|
||||
messages=_make_serializable(messages),
|
||||
pending_tool_calls=_make_serializable(pending_tool_calls),
|
||||
tools_dict=_make_serializable(tools_dict),
|
||||
tool_schemas=_make_serializable(tool_schemas),
|
||||
agent_config=_make_serializable(agent_config),
|
||||
client_tools=_make_serializable(client_tools) if client_tools else None,
|
||||
)
|
||||
|
||||
dual_write(PendingToolStateRepository, _pg_save)
|
||||
return state_id
|
||||
|
||||
def load_state(
|
||||
self, conversation_id: str, user: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Load pending continuation state.
|
||||
|
||||
Returns:
|
||||
The state dict, or None if no pending state exists.
|
||||
"""
|
||||
doc = self.collection.find_one(
|
||||
{"conversation_id": conversation_id, "user": user}
|
||||
)
|
||||
if not doc:
|
||||
return None
|
||||
doc["_id"] = str(doc["_id"])
|
||||
return doc
|
||||
|
||||
def delete_state(self, conversation_id: str, user: str) -> bool:
|
||||
"""Delete pending state after successful resumption.
|
||||
|
||||
Returns:
|
||||
True if a document was deleted.
|
||||
"""
|
||||
result = self.collection.delete_one(
|
||||
{"conversation_id": conversation_id, "user": user}
|
||||
)
|
||||
if result.deleted_count:
|
||||
logger.info(
|
||||
f"Deleted continuation state for conversation {conversation_id}"
|
||||
)
|
||||
|
||||
# Dual-write to Postgres — delete the same row.
|
||||
def _pg_delete(repo: PendingToolStateRepository) -> None:
|
||||
conv = ConversationsRepository(repo._conn).get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.delete_state(conv["id"], user)
|
||||
|
||||
dual_write(PendingToolStateRepository, _pg_delete)
|
||||
return result.deleted_count > 0
|
||||
399
application/api/answer/services/conversation_service.py
Normal file
399
application/api/answer/services/conversation_service.py
Normal file
@@ -0,0 +1,399 @@
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationService:
|
||||
def __init__(self):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
self.conversations_collection = db["conversations"]
|
||||
self.agents_collection = db["agents"]
|
||||
|
||||
def get_conversation(
|
||||
self, conversation_id: str, user_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieve a conversation with proper access control"""
|
||||
if not conversation_id or not user_id:
|
||||
return None
|
||||
try:
|
||||
conversation = self.conversations_collection.find_one(
|
||||
{
|
||||
"_id": ObjectId(conversation_id),
|
||||
"$or": [{"user": user_id}, {"shared_with": user_id}],
|
||||
}
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
logger.warning(
|
||||
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
|
||||
)
|
||||
return None
|
||||
conversation["_id"] = str(conversation["_id"])
|
||||
return conversation
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
def save_conversation(
|
||||
self,
|
||||
conversation_id: Optional[str],
|
||||
question: str,
|
||||
response: str,
|
||||
thought: str,
|
||||
sources: List[Dict[str, Any]],
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
llm: Any,
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
index: Optional[int] = None,
|
||||
api_key: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
attachment_ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""Save or update a conversation in the database"""
|
||||
if decoded_token is None:
|
||||
raise ValueError("Invalid or missing authentication token")
|
||||
user_id = decoded_token.get("sub")
|
||||
if not user_id:
|
||||
raise ValueError("User ID not found in token")
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# clean up in sources array such that we save max 1k characters for text part
|
||||
for source in sources:
|
||||
if "text" in source and isinstance(source["text"], str):
|
||||
source["text"] = source["text"][:1000]
|
||||
|
||||
if conversation_id is not None and index is not None:
|
||||
# Update existing conversation with new query
|
||||
|
||||
result = self.conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(conversation_id),
|
||||
"user": user_id,
|
||||
f"queries.{index}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{index}.prompt": question,
|
||||
f"queries.{index}.response": response,
|
||||
f"queries.{index}.thought": thought,
|
||||
f"queries.{index}.sources": sources,
|
||||
f"queries.{index}.tool_calls": tool_calls,
|
||||
f"queries.{index}.timestamp": current_time,
|
||||
f"queries.{index}.attachments": attachment_ids,
|
||||
f"queries.{index}.model_id": model_id,
|
||||
**(
|
||||
{f"queries.{index}.metadata": metadata}
|
||||
if metadata
|
||||
else {}
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if result.matched_count == 0:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
self.conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(conversation_id),
|
||||
"user": user_id,
|
||||
f"queries.{index}": {"$exists": True},
|
||||
},
|
||||
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
||||
)
|
||||
# Dual-write to Postgres: update the message at :index and
|
||||
# truncate anything after it, mirroring Mongo's $set+$slice.
|
||||
def _pg_update_at_index(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.update_message_at(conv["id"], index, {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
"timestamp": current_time,
|
||||
**({"metadata": metadata} if metadata else {}),
|
||||
})
|
||||
repo.truncate_after(conv["id"], index)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_update_at_index)
|
||||
return conversation_id
|
||||
elif conversation_id:
|
||||
# Append new message to existing conversation
|
||||
|
||||
result = self.conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id), "user": user_id},
|
||||
{
|
||||
"$push": {
|
||||
"queries": {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
**({"metadata": metadata} if metadata else {}),
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if result.matched_count == 0:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
|
||||
# Dual-write to Postgres: append the same message.
|
||||
def _pg_append(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.append_message(conv["id"], {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
"timestamp": current_time,
|
||||
"metadata": metadata or {},
|
||||
})
|
||||
|
||||
dual_write(ConversationsRepository, _pg_append)
|
||||
return conversation_id
|
||||
else:
|
||||
# Create new conversation
|
||||
|
||||
messages_summary = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that creates concise conversation titles. "
|
||||
"Summarize conversations in 3 words or less using the same language as the user.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Summarise following conversation in no more than 3 words, "
|
||||
"respond ONLY with the summary, use the same language as the "
|
||||
"user query \n\nUser: " + question + "\n\n" + "AI: " + response,
|
||||
},
|
||||
]
|
||||
|
||||
completion = llm.gen(
|
||||
model=model_id, messages=messages_summary, max_tokens=500
|
||||
)
|
||||
|
||||
if not completion or not completion.strip():
|
||||
completion = question[:50] if question else "New Conversation"
|
||||
|
||||
query_doc = {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
}
|
||||
if metadata:
|
||||
query_doc["metadata"] = metadata
|
||||
|
||||
conversation_data = {
|
||||
"user": user_id,
|
||||
"date": current_time,
|
||||
"name": completion,
|
||||
"queries": [query_doc],
|
||||
}
|
||||
|
||||
if api_key:
|
||||
if agent_id:
|
||||
conversation_data["agent_id"] = agent_id
|
||||
if is_shared_usage:
|
||||
conversation_data["is_shared_usage"] = is_shared_usage
|
||||
conversation_data["shared_token"] = shared_token
|
||||
agent = self.agents_collection.find_one({"key": api_key})
|
||||
if agent:
|
||||
conversation_data["api_key"] = agent["key"]
|
||||
result = self.conversations_collection.insert_one(conversation_data)
|
||||
inserted_id = str(result.inserted_id)
|
||||
|
||||
# Dual-write to Postgres: create the conversation row with
|
||||
# legacy_mongo_id and append the first message.
|
||||
def _pg_create(repo: ConversationsRepository) -> None:
|
||||
conv = repo.create(
|
||||
user_id,
|
||||
completion,
|
||||
agent_id=conversation_data.get("agent_id"),
|
||||
api_key=conversation_data.get("api_key"),
|
||||
is_shared_usage=conversation_data.get("is_shared_usage", False),
|
||||
shared_token=conversation_data.get("shared_token"),
|
||||
legacy_mongo_id=inserted_id,
|
||||
)
|
||||
repo.append_message(conv["id"], {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
"timestamp": current_time,
|
||||
"metadata": metadata or {},
|
||||
})
|
||||
|
||||
dual_write(ConversationsRepository, _pg_create)
|
||||
return inserted_id
|
||||
|
||||
def update_compression_metadata(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Update conversation with compression metadata.
|
||||
|
||||
Uses $push with $slice to keep only the most recent compression points,
|
||||
preventing unbounded array growth. Since each compression incorporates
|
||||
previous compressions, older points become redundant.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
compression_metadata: Compression point data
|
||||
"""
|
||||
try:
|
||||
self.conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{
|
||||
"$set": {
|
||||
"compression_metadata.is_compressed": True,
|
||||
"compression_metadata.last_compression_at": compression_metadata.get(
|
||||
"timestamp"
|
||||
),
|
||||
},
|
||||
"$push": {
|
||||
"compression_metadata.compression_points": {
|
||||
"$each": [compression_metadata],
|
||||
"$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"Updated compression metadata for conversation {conversation_id}"
|
||||
)
|
||||
|
||||
# Dual-write to Postgres: mirror $set + $push $slice.
|
||||
def _pg_compression(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.set_compression_flags(
|
||||
conv["id"],
|
||||
is_compressed=True,
|
||||
last_compression_at=compression_metadata.get("timestamp"),
|
||||
)
|
||||
repo.append_compression_point(
|
||||
conv["id"],
|
||||
compression_metadata,
|
||||
max_points=settings.COMPRESSION_MAX_HISTORY_POINTS,
|
||||
)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_compression)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error updating compression metadata: {str(e)}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def append_compression_message(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Append a synthetic compression summary entry into the conversation history.
|
||||
This makes the summary visible in the DB alongside normal queries.
|
||||
"""
|
||||
try:
|
||||
summary = compression_metadata.get("compressed_summary", "")
|
||||
if not summary:
|
||||
return
|
||||
timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc))
|
||||
|
||||
self.conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{
|
||||
"$push": {
|
||||
"queries": {
|
||||
"prompt": "[Context Compression Summary]",
|
||||
"response": summary,
|
||||
"thought": "",
|
||||
"sources": [],
|
||||
"tool_calls": [],
|
||||
"timestamp": timestamp,
|
||||
"attachments": [],
|
||||
"model_id": compression_metadata.get("model_used"),
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def _pg_append_summary(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
return
|
||||
repo.append_message(conv["id"], {
|
||||
"prompt": "[Context Compression Summary]",
|
||||
"response": summary,
|
||||
"thought": "",
|
||||
"sources": [],
|
||||
"tool_calls": [],
|
||||
"attachments": [],
|
||||
"model_id": compression_metadata.get("model_used"),
|
||||
"timestamp": timestamp,
|
||||
})
|
||||
|
||||
dual_write(ConversationsRepository, _pg_append_summary)
|
||||
logger.info(f"Appended compression summary to conversation {conversation_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error appending compression summary: {str(e)}", exc_info=True
|
||||
)
|
||||
|
||||
def get_compression_metadata(
|
||||
self, conversation_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get compression metadata for a conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
|
||||
Returns:
|
||||
Compression metadata dict or None
|
||||
"""
|
||||
try:
|
||||
conversation = self.conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
|
||||
)
|
||||
return conversation.get("compression_metadata") if conversation else None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting compression metadata: {str(e)}", exc_info=True
|
||||
)
|
||||
return None
|
||||
97
application/api/answer/services/prompt_renderer.py
Normal file
97
application/api/answer/services/prompt_renderer.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.templates.namespaces import NamespaceManager
|
||||
|
||||
from application.templates.template_engine import TemplateEngine, TemplateRenderError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptRenderer:
|
||||
"""Service for rendering prompts with dynamic context using namespaces"""
|
||||
|
||||
def __init__(self):
|
||||
self.template_engine = TemplateEngine()
|
||||
self.namespace_manager = NamespaceManager()
|
||||
|
||||
def render_prompt(
|
||||
self,
|
||||
prompt_content: str,
|
||||
user_id: Optional[str] = None,
|
||||
request_id: Optional[str] = None,
|
||||
passthrough_data: Optional[Dict[str, Any]] = None,
|
||||
docs: Optional[list] = None,
|
||||
docs_together: Optional[str] = None,
|
||||
tools_data: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Render prompt with full context from all namespaces.
|
||||
|
||||
Args:
|
||||
prompt_content: Raw prompt template string
|
||||
user_id: Current user identifier
|
||||
request_id: Unique request identifier
|
||||
passthrough_data: Parameters from web request
|
||||
docs: RAG retrieved documents
|
||||
docs_together: Concatenated document content
|
||||
tools_data: Pre-fetched tool results organized by tool name
|
||||
**kwargs: Additional parameters for namespace builders
|
||||
|
||||
Returns:
|
||||
Rendered prompt string with all variables substituted
|
||||
|
||||
Raises:
|
||||
TemplateRenderError: If template rendering fails
|
||||
"""
|
||||
if not prompt_content:
|
||||
return ""
|
||||
|
||||
uses_template = self._uses_template_syntax(prompt_content)
|
||||
|
||||
if not uses_template:
|
||||
return self._apply_legacy_substitutions(prompt_content, docs_together)
|
||||
|
||||
try:
|
||||
context = self.namespace_manager.build_context(
|
||||
user_id=user_id,
|
||||
request_id=request_id,
|
||||
passthrough_data=passthrough_data,
|
||||
docs=docs,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self.template_engine.render(prompt_content, context)
|
||||
except TemplateRenderError:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Prompt rendering failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise TemplateRenderError(error_msg) from e
|
||||
|
||||
def _uses_template_syntax(self, prompt_content: str) -> bool:
|
||||
"""Check if prompt uses Jinja2 template syntax"""
|
||||
return "{{" in prompt_content and "}}" in prompt_content
|
||||
|
||||
def _apply_legacy_substitutions(
|
||||
self, prompt_content: str, docs_together: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Apply backward-compatible substitutions for old prompt format.
|
||||
|
||||
Handles legacy {summaries} and {query} placeholders during transition period.
|
||||
"""
|
||||
if docs_together:
|
||||
prompt_content = prompt_content.replace("{summaries}", docs_together)
|
||||
return prompt_content
|
||||
|
||||
def validate_template(self, prompt_content: str) -> bool:
|
||||
"""Validate prompt template syntax"""
|
||||
return self.template_engine.validate_template(prompt_content)
|
||||
|
||||
def extract_variables(self, prompt_content: str) -> set[str]:
|
||||
"""Extract all variable names from prompt template"""
|
||||
return self.template_engine.extract_variables(prompt_content)
|
||||
1072
application/api/answer/services/stream_processor.py
Normal file
1072
application/api/answer/services/stream_processor.py
Normal file
File diff suppressed because it is too large
Load Diff
545
application/api/connector/routes.py
Normal file
545
application/api/connector/routes.py
Normal file
@@ -0,0 +1,545 @@
|
||||
import base64
|
||||
import datetime
|
||||
import html
|
||||
import json
|
||||
import uuid
|
||||
from urllib.parse import urlencode
|
||||
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import (
|
||||
Blueprint,
|
||||
current_app,
|
||||
jsonify,
|
||||
make_response,
|
||||
request
|
||||
)
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
|
||||
from application.api.user.tasks import (
|
||||
ingest_connector_task,
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.api import api
|
||||
|
||||
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
sources_collection = db["sources"]
|
||||
sessions_collection = db["connector_sessions"]
|
||||
|
||||
connector = Blueprint("connector", __name__)
|
||||
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
|
||||
api.add_namespace(connectors_ns)
|
||||
|
||||
# Fixed callback status path to prevent open redirect
|
||||
CALLBACK_STATUS_PATH = "/api/connectors/callback-status"
|
||||
|
||||
|
||||
def build_callback_redirect(params: dict) -> str:
|
||||
"""Build a safe redirect URL to the callback status page.
|
||||
|
||||
Uses a fixed path and properly URL-encodes all parameters
|
||||
to prevent URL injection and open redirect vulnerabilities.
|
||||
"""
|
||||
return f"{CALLBACK_STATUS_PATH}?{urlencode(params)}"
|
||||
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/auth")
|
||||
class ConnectorAuth(Resource):
|
||||
@api.doc(description="Get connector OAuth authorization URL", params={"provider": "Connector provider (e.g., google_drive)"})
|
||||
def get(self):
|
||||
try:
|
||||
provider = request.args.get('provider') or request.args.get('source')
|
||||
if not provider:
|
||||
return make_response(jsonify({"success": False, "error": "Missing provider"}), 400)
|
||||
|
||||
if not ConnectorCreator.is_supported(provider):
|
||||
return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user_id = decoded_token.get('sub')
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
result = sessions_collection.insert_one({
|
||||
"provider": provider,
|
||||
"user": user_id,
|
||||
"status": "pending",
|
||||
"created_at": now
|
||||
})
|
||||
state_dict = {
|
||||
"provider": provider,
|
||||
"object_id": str(result.inserted_id)
|
||||
}
|
||||
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
|
||||
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
authorization_url = auth.get_authorization_url(state=state)
|
||||
return make_response(jsonify({
|
||||
"success": True,
|
||||
"authorization_url": authorization_url,
|
||||
"state": state
|
||||
}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error generating connector auth URL: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to generate authorization URL"}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/callback")
|
||||
class ConnectorsCallback(Resource):
|
||||
@api.doc(description="Handle OAuth callback for external connectors")
|
||||
def get(self):
|
||||
"""Handle OAuth callback for external connectors"""
|
||||
try:
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from flask import request, redirect
|
||||
|
||||
authorization_code = request.args.get('code')
|
||||
state = request.args.get('state')
|
||||
error = request.args.get('error')
|
||||
|
||||
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
provider = state_dict.get("provider")
|
||||
state_object_id = state_dict.get("object_id")
|
||||
|
||||
# Validate provider
|
||||
if not provider or not isinstance(provider, str) or not ConnectorCreator.is_supported(provider):
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Invalid provider"
|
||||
}))
|
||||
|
||||
if error:
|
||||
if error == "access_denied":
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "cancelled",
|
||||
"message": "Authentication was cancelled. You can try again if you'd like to connect your account.",
|
||||
"provider": provider
|
||||
}))
|
||||
else:
|
||||
current_app.logger.warning(f"OAuth error in callback: {error}")
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
|
||||
"provider": provider
|
||||
}))
|
||||
|
||||
if not authorization_code:
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
|
||||
"provider": provider
|
||||
}))
|
||||
|
||||
try:
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
token_info = auth.exchange_code_for_tokens(authorization_code)
|
||||
|
||||
session_token = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
if provider == "google_drive":
|
||||
credentials = auth.create_credentials_from_token_info(token_info)
|
||||
service = auth.build_drive_service(credentials)
|
||||
user_info = service.about().get(fields="user").execute()
|
||||
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
|
||||
else:
|
||||
user_email = token_info.get('user_info', {}).get('email', 'Connected User')
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.warning(f"Could not get user info: {e}")
|
||||
user_email = 'Connected User'
|
||||
|
||||
sanitized_token_info = auth.sanitize_token_info(token_info)
|
||||
|
||||
sessions_collection.find_one_and_update(
|
||||
{"_id": ObjectId(state_object_id), "provider": provider},
|
||||
{
|
||||
"$set": {
|
||||
"session_token": session_token,
|
||||
"token_info": sanitized_token_info,
|
||||
"user_email": user_email,
|
||||
"status": "authorized"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Redirect to success page with session token and user email
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "success",
|
||||
"message": "Authentication successful",
|
||||
"provider": provider,
|
||||
"session_token": session_token,
|
||||
"user_email": user_email
|
||||
}))
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
|
||||
"provider": provider
|
||||
}))
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error handling connector callback: {e}")
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions."
|
||||
}))
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/files")
|
||||
class ConnectorFiles(Resource):
|
||||
@api.expect(api.model("ConnectorFilesModel", {
|
||||
"provider": fields.String(required=True),
|
||||
"session_token": fields.String(required=True),
|
||||
"folder_id": fields.String(required=False),
|
||||
"limit": fields.Integer(required=False),
|
||||
"page_token": fields.String(required=False),
|
||||
"search_query": fields.String(required=False),
|
||||
}))
|
||||
@api.doc(description="List files from a connector provider (supports pagination and search)")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
session_token = data.get('session_token')
|
||||
limit = data.get('limit', 10)
|
||||
|
||||
if not provider or not session_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user = decoded_token.get('sub')
|
||||
session = sessions_collection.find_one({"session_token": session_token, "user": user})
|
||||
if not session:
|
||||
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
|
||||
|
||||
loader = ConnectorCreator.create_connector(provider, session_token)
|
||||
|
||||
generic_keys = {'provider', 'session_token'}
|
||||
input_config = {
|
||||
k: v for k, v in data.items() if k not in generic_keys
|
||||
}
|
||||
input_config['list_only'] = True
|
||||
|
||||
documents = loader.load_data(input_config)
|
||||
|
||||
files = []
|
||||
for doc in documents[:limit]:
|
||||
metadata = doc.extra_info
|
||||
modified_time = metadata.get('modified_time')
|
||||
if modified_time:
|
||||
date_part = modified_time.split('T')[0]
|
||||
time_part = modified_time.split('T')[1].split('.')[0].split('Z')[0]
|
||||
formatted_time = f"{date_part} {time_part}"
|
||||
else:
|
||||
formatted_time = None
|
||||
|
||||
files.append({
|
||||
'id': doc.doc_id,
|
||||
'name': metadata.get('file_name', 'Unknown File'),
|
||||
'type': metadata.get('mime_type', 'unknown'),
|
||||
'size': metadata.get('size', None),
|
||||
'modifiedTime': formatted_time,
|
||||
'isFolder': metadata.get('is_folder', False)
|
||||
})
|
||||
|
||||
next_token = getattr(loader, 'next_page_token', None)
|
||||
has_more = bool(next_token)
|
||||
|
||||
return make_response(jsonify({
|
||||
"success": True,
|
||||
"files": files,
|
||||
"total": len(files),
|
||||
"next_page_token": next_token,
|
||||
"has_more": has_more
|
||||
}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error loading connector files: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to load files"}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/validate-session")
|
||||
class ConnectorValidateSession(Resource):
|
||||
@api.expect(api.model("ConnectorValidateSessionModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True)}))
|
||||
@api.doc(description="Validate connector session token and return user info and access token")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
session_token = data.get('session_token')
|
||||
if not provider or not session_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user = decoded_token.get('sub')
|
||||
|
||||
session = sessions_collection.find_one({"session_token": session_token, "user": user})
|
||||
if not session or "token_info" not in session:
|
||||
return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
|
||||
|
||||
token_info = session["token_info"]
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
is_expired = auth.is_token_expired(token_info)
|
||||
|
||||
if is_expired and token_info.get('refresh_token'):
|
||||
try:
|
||||
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
|
||||
sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
|
||||
sessions_collection.update_one(
|
||||
{"session_token": session_token},
|
||||
{"$set": {"token_info": sanitized_token_info}}
|
||||
)
|
||||
token_info = sanitized_token_info
|
||||
is_expired = False
|
||||
except Exception as refresh_error:
|
||||
current_app.logger.error(f"Failed to refresh token: {refresh_error}")
|
||||
|
||||
if is_expired:
|
||||
return make_response(jsonify({
|
||||
"success": False,
|
||||
"expired": True,
|
||||
"error": "Session token has expired. Please reconnect."
|
||||
}), 401)
|
||||
|
||||
_base_fields = {"access_token", "refresh_token", "token_uri", "expiry"}
|
||||
provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields}
|
||||
|
||||
response_data = {
|
||||
"success": True,
|
||||
"expired": False,
|
||||
"user_email": session.get('user_email', 'Connected User'),
|
||||
"access_token": token_info.get('access_token'),
|
||||
**provider_extras,
|
||||
}
|
||||
|
||||
return make_response(jsonify(response_data), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error validating connector session: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/disconnect")
|
||||
class ConnectorDisconnect(Resource):
|
||||
@api.expect(api.model("ConnectorDisconnectModel", {"provider": fields.String(required=True), "session_token": fields.String(required=False)}))
|
||||
@api.doc(description="Disconnect a connector session")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
session_token = data.get('session_token')
|
||||
if not provider:
|
||||
return make_response(jsonify({"success": False, "error": "provider is required"}), 400)
|
||||
|
||||
|
||||
if session_token:
|
||||
sessions_collection.delete_one({"session_token": session_token})
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error disconnecting connector session: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to disconnect session"}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/sync")
|
||||
class ConnectorSync(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ConnectorSyncModel",
|
||||
{
|
||||
"source_id": fields.String(required=True, description="Source ID to sync"),
|
||||
"session_token": fields.String(required=True, description="Authentication token")
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Sync connector source to check for modifications")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
|
||||
try:
|
||||
data = request.get_json()
|
||||
source_id = data.get('source_id')
|
||||
session_token = data.get('session_token')
|
||||
|
||||
if not all([source_id, session_token]):
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "source_id and session_token are required"
|
||||
}),
|
||||
400
|
||||
)
|
||||
source = sources_collection.find_one({"_id": ObjectId(source_id)})
|
||||
if not source:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Source not found"
|
||||
}),
|
||||
404
|
||||
)
|
||||
|
||||
if source.get('user') != decoded_token.get('sub'):
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Unauthorized access to source"
|
||||
}),
|
||||
403
|
||||
)
|
||||
|
||||
remote_data = {}
|
||||
try:
|
||||
if source.get('remote_data'):
|
||||
remote_data = json.loads(source.get('remote_data'))
|
||||
except json.JSONDecodeError:
|
||||
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
|
||||
remote_data = {}
|
||||
|
||||
source_type = remote_data.get('provider')
|
||||
if not source_type:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Source provider not found in remote_data"
|
||||
}),
|
||||
400
|
||||
)
|
||||
|
||||
# Extract configuration from remote_data
|
||||
file_ids = remote_data.get('file_ids', [])
|
||||
folder_ids = remote_data.get('folder_ids', [])
|
||||
recursive = remote_data.get('recursive', True)
|
||||
|
||||
# Start the sync task
|
||||
task = ingest_connector_task.delay(
|
||||
job_name=source.get('name'),
|
||||
user=decoded_token.get('sub'),
|
||||
source_type=source_type,
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=recursive,
|
||||
retriever=source.get('retriever', 'classic'),
|
||||
operation_mode="sync",
|
||||
doc_id=source_id,
|
||||
sync_frequency=source.get('sync_frequency', 'never')
|
||||
)
|
||||
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": True,
|
||||
"task_id": task.id
|
||||
}),
|
||||
200
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error syncing connector source: {err}",
|
||||
exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Failed to sync connector source"
|
||||
}),
|
||||
400
|
||||
)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/callback-status")
|
||||
class ConnectorCallbackStatus(Resource):
|
||||
@api.doc(description="Return HTML page with connector authentication status")
|
||||
def get(self):
|
||||
"""Return HTML page with connector authentication status"""
|
||||
try:
|
||||
# Validate and sanitize status to a known value
|
||||
status_raw = request.args.get('status', 'error')
|
||||
status = status_raw if status_raw in ('success', 'error', 'cancelled') else 'error'
|
||||
|
||||
# Escape all user-controlled values for HTML context
|
||||
message = html.escape(request.args.get('message', ''))
|
||||
provider_raw = request.args.get('provider', 'connector')
|
||||
provider = html.escape(provider_raw.replace('_', ' ').title())
|
||||
session_token = request.args.get('session_token', '')
|
||||
user_email = html.escape(request.args.get('user_email', ''))
|
||||
|
||||
def safe_js_string(value: str) -> str:
|
||||
"""Safely encode a string for embedding in inline JavaScript."""
|
||||
js_encoded = json.dumps(value)
|
||||
return js_encoded.replace('</', '<\\/').replace('<!--', '<\\!--')
|
||||
|
||||
js_status = safe_js_string(status)
|
||||
js_session_token = safe_js_string(session_token)
|
||||
js_user_email = safe_js_string(user_email)
|
||||
js_provider_type = safe_js_string(provider_raw)
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>{provider} Authentication</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
|
||||
.container {{ max-width: 600px; margin: 0 auto; }}
|
||||
.success {{ color: #4CAF50; }}
|
||||
.error {{ color: #F44336; }}
|
||||
.cancelled {{ color: #FF9800; }}
|
||||
</style>
|
||||
<script>
|
||||
window.onload = function() {{
|
||||
const status = {js_status};
|
||||
const sessionToken = {js_session_token};
|
||||
const userEmail = {js_user_email};
|
||||
const providerType = {js_provider_type};
|
||||
|
||||
if (status === "success" && window.opener) {{
|
||||
window.opener.postMessage({{
|
||||
type: providerType + '_auth_success',
|
||||
session_token: sessionToken,
|
||||
user_email: userEmail
|
||||
}}, '*');
|
||||
|
||||
setTimeout(() => window.close(), 3000);
|
||||
}} else if (status === "cancelled" || status === "error") {{
|
||||
setTimeout(() => window.close(), 3000);
|
||||
}}
|
||||
}};
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h2>{provider} Authentication</h2>
|
||||
<div class="{status}">
|
||||
<p>{message}</p>
|
||||
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
|
||||
</div>
|
||||
<p><small>You can close this window. {f"Your {provider} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return make_response(html_content, 200, {'Content-Type': 'text/html'})
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error rendering callback status page: {e}")
|
||||
return make_response("Authentication error occurred", 500, {'Content-Type': 'text/html'})
|
||||
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
import os
|
||||
import datetime
|
||||
from flask import Blueprint, request, send_from_directory
|
||||
import json
|
||||
from flask import Blueprint, request, send_from_directory, jsonify
|
||||
from werkzeug.utils import secure_filename
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
import logging
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo["docsgpt"]
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
conversations_collection = db["conversations"]
|
||||
sources_collection = db["sources"]
|
||||
|
||||
@@ -20,6 +24,24 @@ current_dir = os.path.dirname(
|
||||
internal = Blueprint("internal", __name__)
|
||||
|
||||
|
||||
@internal.before_request
|
||||
def verify_internal_key():
|
||||
"""Verify INTERNAL_KEY for all internal endpoint requests.
|
||||
|
||||
Deny by default: if INTERNAL_KEY is not configured, reject all requests.
|
||||
"""
|
||||
if not settings.INTERNAL_KEY:
|
||||
logger.warning(
|
||||
f"Internal API request rejected from {request.remote_addr}: "
|
||||
"INTERNAL_KEY is not configured"
|
||||
)
|
||||
return jsonify({"error": "Unauthorized", "message": "Internal API is not configured"}), 401
|
||||
internal_key = request.headers.get("X-Internal-Key")
|
||||
if not internal_key or internal_key != settings.INTERNAL_KEY:
|
||||
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
|
||||
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
|
||||
|
||||
|
||||
@internal.route("/api/download", methods=["get"])
|
||||
def download_file():
|
||||
user = secure_filename(request.args.get("user"))
|
||||
@@ -34,71 +56,101 @@ def upload_index_files():
|
||||
"""Upload two files(index.faiss, index.pkl) to the user's folder."""
|
||||
if "user" not in request.form:
|
||||
return {"status": "no user"}
|
||||
user = secure_filename(request.form["user"])
|
||||
user = request.form["user"]
|
||||
if "name" not in request.form:
|
||||
return {"status": "no name"}
|
||||
job_name = secure_filename(request.form["name"])
|
||||
tokens = secure_filename(request.form["tokens"])
|
||||
retriever = secure_filename(request.form["retriever"])
|
||||
id = secure_filename(request.form["id"])
|
||||
type = secure_filename(request.form["type"])
|
||||
job_name = request.form["name"]
|
||||
tokens = request.form["tokens"]
|
||||
retriever = request.form["retriever"]
|
||||
id = request.form["id"]
|
||||
type = request.form["type"]
|
||||
remote_data = request.form["remote_data"] if "remote_data" in request.form else None
|
||||
sync_frequency = secure_filename(request.form["sync_frequency"]) if "sync_frequency" in request.form else None
|
||||
sync_frequency = request.form["sync_frequency"] if "sync_frequency" in request.form else None
|
||||
|
||||
file_path = request.form.get("file_path")
|
||||
directory_structure = request.form.get("directory_structure")
|
||||
file_name_map = request.form.get("file_name_map")
|
||||
|
||||
if directory_structure:
|
||||
try:
|
||||
directory_structure = json.loads(directory_structure)
|
||||
except Exception:
|
||||
logger.error("Error parsing directory_structure")
|
||||
directory_structure = {}
|
||||
else:
|
||||
directory_structure = {}
|
||||
if file_name_map:
|
||||
try:
|
||||
file_name_map = json.loads(file_name_map)
|
||||
except Exception:
|
||||
logger.error("Error parsing file_name_map")
|
||||
file_name_map = None
|
||||
else:
|
||||
file_name_map = None
|
||||
|
||||
save_dir = os.path.join(current_dir, "indexes", str(id))
|
||||
storage = StorageCreator.get_storage()
|
||||
index_base_path = f"indexes/{id}"
|
||||
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
if "file_faiss" not in request.files:
|
||||
print("No file part")
|
||||
logger.error("No file_faiss part")
|
||||
return {"status": "no file"}
|
||||
file_faiss = request.files["file_faiss"]
|
||||
if file_faiss.filename == "":
|
||||
return {"status": "no file name"}
|
||||
if "file_pkl" not in request.files:
|
||||
print("No file part")
|
||||
logger.error("No file_pkl part")
|
||||
return {"status": "no file"}
|
||||
file_pkl = request.files["file_pkl"]
|
||||
if file_pkl.filename == "":
|
||||
return {"status": "no file name"}
|
||||
# saves index files
|
||||
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
file_faiss.save(os.path.join(save_dir, "index.faiss"))
|
||||
file_pkl.save(os.path.join(save_dir, "index.pkl"))
|
||||
# Save index files to storage
|
||||
faiss_storage_path = f"{index_base_path}/index.faiss"
|
||||
pkl_storage_path = f"{index_base_path}/index.pkl"
|
||||
storage.save_file(file_faiss, faiss_storage_path)
|
||||
storage.save_file(file_pkl, pkl_storage_path)
|
||||
|
||||
|
||||
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
|
||||
if existing_entry:
|
||||
update_fields = {
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
if file_name_map is not None:
|
||||
update_fields["file_name_map"] = file_name_map
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(id)},
|
||||
{
|
||||
"$set": {
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
}
|
||||
},
|
||||
{"$set": update_fields},
|
||||
)
|
||||
else:
|
||||
sources_collection.insert_one(
|
||||
{
|
||||
"_id": ObjectId(id),
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
}
|
||||
)
|
||||
insert_doc = {
|
||||
"_id": ObjectId(id),
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
if file_name_map is not None:
|
||||
insert_doc["file_name_map"] = file_name_map
|
||||
sources_collection.insert_one(insert_doc)
|
||||
return {"status": "ok"}
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""User API module - provides all user-related API endpoints"""
|
||||
|
||||
from .routes import user
|
||||
|
||||
__all__ = ["user"]
|
||||
|
||||
8
application/api/user/agents/__init__.py
Normal file
8
application/api/user/agents/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Agents module."""
|
||||
|
||||
from .routes import agents_ns
|
||||
from .sharing import agents_sharing_ns
|
||||
from .webhooks import agents_webhooks_ns
|
||||
from .folders import agents_folders_ns
|
||||
|
||||
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns", "agents_folders_ns"]
|
||||
276
application/api/user/agents/folders.py
Normal file
276
application/api/user/agents/folders.py
Normal file
@@ -0,0 +1,276 @@
|
||||
"""
|
||||
Agent folders management routes.
|
||||
Provides virtual folder organization for agents (Google Drive-like structure).
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agent_folders_collection,
|
||||
agents_collection,
|
||||
)
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.agent_folders import AgentFoldersRepository
|
||||
|
||||
agents_folders_ns = Namespace(
|
||||
"agents_folders", description="Agent folder management", path="/api/agents/folders"
|
||||
)
|
||||
|
||||
|
||||
def _folder_error_response(message: str, err: Exception):
|
||||
current_app.logger.error(f"{message}: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "message": message}), 400)
|
||||
|
||||
|
||||
@agents_folders_ns.route("/")
|
||||
class AgentFolders(Resource):
|
||||
@api.doc(description="Get all folders for the user")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
folders = list(agent_folders_collection.find({"user": user}))
|
||||
result = [
|
||||
{
|
||||
"id": str(f["_id"]),
|
||||
"name": f["name"],
|
||||
"parent_id": f.get("parent_id"),
|
||||
"created_at": f.get("created_at", "").isoformat() if f.get("created_at") else None,
|
||||
"updated_at": f.get("updated_at", "").isoformat() if f.get("updated_at") else None,
|
||||
}
|
||||
for f in folders
|
||||
]
|
||||
return make_response(jsonify({"folders": result}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to fetch folders", err)
|
||||
|
||||
@api.doc(description="Create a new folder")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CreateFolder",
|
||||
{
|
||||
"name": fields.String(required=True, description="Folder name"),
|
||||
"parent_id": fields.String(required=False, description="Parent folder ID"),
|
||||
},
|
||||
)
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
if not data or not data.get("name"):
|
||||
return make_response(jsonify({"success": False, "message": "Folder name is required"}), 400)
|
||||
|
||||
parent_id = data.get("parent_id")
|
||||
if parent_id:
|
||||
parent = agent_folders_collection.find_one({"_id": ObjectId(parent_id), "user": user})
|
||||
if not parent:
|
||||
return make_response(jsonify({"success": False, "message": "Parent folder not found"}), 404)
|
||||
|
||||
try:
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
folder = {
|
||||
"user": user,
|
||||
"name": data["name"],
|
||||
"parent_id": parent_id,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
result = agent_folders_collection.insert_one(folder)
|
||||
dual_write(
|
||||
AgentFoldersRepository,
|
||||
lambda repo, u=user, n=data["name"]: repo.create(u, n),
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
|
||||
201,
|
||||
)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to create folder", err)
|
||||
|
||||
|
||||
@agents_folders_ns.route("/<string:folder_id>")
|
||||
class AgentFolder(Resource):
|
||||
@api.doc(description="Get a specific folder with its agents")
|
||||
def get(self, folder_id):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
|
||||
if not folder:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
|
||||
agents = list(agents_collection.find({"user": user, "folder_id": folder_id}))
|
||||
agents_list = [
|
||||
{"id": str(a["_id"]), "name": a["name"], "description": a.get("description", "")}
|
||||
for a in agents
|
||||
]
|
||||
subfolders = list(agent_folders_collection.find({"user": user, "parent_id": folder_id}))
|
||||
subfolders_list = [{"id": str(sf["_id"]), "name": sf["name"]} for sf in subfolders]
|
||||
|
||||
return make_response(
|
||||
jsonify({
|
||||
"id": str(folder["_id"]),
|
||||
"name": folder["name"],
|
||||
"parent_id": folder.get("parent_id"),
|
||||
"agents": agents_list,
|
||||
"subfolders": subfolders_list,
|
||||
}),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to fetch folder", err)
|
||||
|
||||
@api.doc(description="Update a folder")
|
||||
def put(self, folder_id):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return make_response(jsonify({"success": False, "message": "No data provided"}), 400)
|
||||
|
||||
try:
|
||||
update_fields = {"updated_at": datetime.datetime.now(datetime.timezone.utc)}
|
||||
if "name" in data:
|
||||
update_fields["name"] = data["name"]
|
||||
if "parent_id" in data:
|
||||
if data["parent_id"] == folder_id:
|
||||
return make_response(jsonify({"success": False, "message": "Cannot set folder as its own parent"}), 400)
|
||||
update_fields["parent_id"] = data["parent_id"]
|
||||
|
||||
result = agent_folders_collection.update_one(
|
||||
{"_id": ObjectId(folder_id), "user": user}, {"$set": update_fields}
|
||||
)
|
||||
if result.matched_count == 0:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to update folder", err)
|
||||
|
||||
@api.doc(description="Delete a folder")
|
||||
def delete(self, folder_id):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
agents_collection.update_many(
|
||||
{"user": user, "folder_id": folder_id}, {"$unset": {"folder_id": ""}}
|
||||
)
|
||||
agent_folders_collection.update_many(
|
||||
{"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
|
||||
)
|
||||
result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
|
||||
dual_write(
|
||||
AgentFoldersRepository,
|
||||
lambda repo, fid=folder_id, u=user: repo.delete(fid, u),
|
||||
)
|
||||
if result.deleted_count == 0:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to delete folder", err)
|
||||
|
||||
|
||||
@agents_folders_ns.route("/move_agent")
|
||||
class MoveAgentToFolder(Resource):
|
||||
@api.doc(description="Move an agent to a folder or remove from folder")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MoveAgent",
|
||||
{
|
||||
"agent_id": fields.String(required=True, description="Agent ID to move"),
|
||||
"folder_id": fields.String(required=False, description="Target folder ID (null to remove from folder)"),
|
||||
},
|
||||
)
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
if not data or not data.get("agent_id"):
|
||||
return make_response(jsonify({"success": False, "message": "Agent ID is required"}), 400)
|
||||
|
||||
agent_id = data["agent_id"]
|
||||
folder_id = data.get("folder_id")
|
||||
|
||||
try:
|
||||
agent = agents_collection.find_one({"_id": ObjectId(agent_id), "user": user})
|
||||
if not agent:
|
||||
return make_response(jsonify({"success": False, "message": "Agent not found"}), 404)
|
||||
|
||||
if folder_id:
|
||||
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
|
||||
if not folder:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id)}, {"$set": {"folder_id": folder_id}}
|
||||
)
|
||||
else:
|
||||
agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id)}, {"$unset": {"folder_id": ""}}
|
||||
)
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to move agent", err)
|
||||
|
||||
|
||||
@agents_folders_ns.route("/bulk_move")
|
||||
class BulkMoveAgents(Resource):
|
||||
@api.doc(description="Move multiple agents to a folder")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"BulkMoveAgents",
|
||||
{
|
||||
"agent_ids": fields.List(fields.String, required=True, description="List of agent IDs"),
|
||||
"folder_id": fields.String(required=False, description="Target folder ID"),
|
||||
},
|
||||
)
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
if not data or not data.get("agent_ids"):
|
||||
return make_response(jsonify({"success": False, "message": "Agent IDs are required"}), 400)
|
||||
|
||||
agent_ids = data["agent_ids"]
|
||||
folder_id = data.get("folder_id")
|
||||
|
||||
try:
|
||||
if folder_id:
|
||||
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
|
||||
if not folder:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
|
||||
object_ids = [ObjectId(aid) for aid in agent_ids]
|
||||
if folder_id:
|
||||
agents_collection.update_many(
|
||||
{"_id": {"$in": object_ids}, "user": user},
|
||||
{"$set": {"folder_id": folder_id}},
|
||||
)
|
||||
else:
|
||||
agents_collection.update_many(
|
||||
{"_id": {"$in": object_ids}, "user": user},
|
||||
{"$unset": {"folder_id": ""}},
|
||||
)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to move agents", err)
|
||||
1509
application/api/user/agents/routes.py
Normal file
1509
application/api/user/agents/routes.py
Normal file
File diff suppressed because it is too large
Load Diff
271
application/api/user/agents/sharing.py
Normal file
271
application/api/user/agents/sharing.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""Agent management sharing functionality."""
|
||||
|
||||
import datetime
|
||||
import secrets
|
||||
|
||||
from bson import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.core.settings import settings
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
db,
|
||||
ensure_user_doc,
|
||||
resolve_tool_details,
|
||||
user_tools_collection,
|
||||
users_collection,
|
||||
)
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
from application.utils import generate_image_url
|
||||
|
||||
agents_sharing_ns = Namespace(
|
||||
"agents", description="Agent management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/shared_agent")
|
||||
class SharedAgent(Resource):
|
||||
@api.doc(
|
||||
params={
|
||||
"token": "Shared token of the agent",
|
||||
},
|
||||
description="Get a shared agent by token or ID",
|
||||
)
|
||||
def get(self):
|
||||
shared_token = request.args.get("token")
|
||||
|
||||
if not shared_token:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Token or ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
query = {
|
||||
"shared_publicly": True,
|
||||
"shared_token": shared_token,
|
||||
}
|
||||
shared_agent = agents_collection.find_one(query)
|
||||
if not shared_agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Shared agent not found"}),
|
||||
404,
|
||||
)
|
||||
agent_id = str(shared_agent["_id"])
|
||||
data = {
|
||||
"id": agent_id,
|
||||
"user": shared_agent.get("user", ""),
|
||||
"name": shared_agent.get("name", ""),
|
||||
"image": (
|
||||
generate_image_url(shared_agent["image"])
|
||||
if shared_agent.get("image")
|
||||
else ""
|
||||
),
|
||||
"description": shared_agent.get("description", ""),
|
||||
"source": (
|
||||
str(source_doc["_id"])
|
||||
if isinstance(shared_agent.get("source"), DBRef)
|
||||
and (source_doc := db.dereference(shared_agent.get("source")))
|
||||
else ""
|
||||
),
|
||||
"chunks": shared_agent.get("chunks", "0"),
|
||||
"retriever": shared_agent.get("retriever", "classic"),
|
||||
"prompt_id": shared_agent.get("prompt_id", "default"),
|
||||
"tools": shared_agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(shared_agent.get("tools", [])),
|
||||
"agent_type": shared_agent.get("agent_type", ""),
|
||||
"status": shared_agent.get("status", ""),
|
||||
"json_schema": shared_agent.get("json_schema"),
|
||||
"limited_token_mode": shared_agent.get("limited_token_mode", False),
|
||||
"token_limit": shared_agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
|
||||
"limited_request_mode": shared_agent.get("limited_request_mode", False),
|
||||
"request_limit": shared_agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
|
||||
"created_at": shared_agent.get("createdAt", ""),
|
||||
"updated_at": shared_agent.get("updatedAt", ""),
|
||||
"shared": shared_agent.get("shared_publicly", False),
|
||||
"shared_token": shared_agent.get("shared_token", ""),
|
||||
"shared_metadata": shared_agent.get("shared_metadata", {}),
|
||||
}
|
||||
|
||||
if data["tools"]:
|
||||
enriched_tools = []
|
||||
for tool in data["tools"]:
|
||||
tool_data = user_tools_collection.find_one({"_id": ObjectId(tool)})
|
||||
if tool_data:
|
||||
enriched_tools.append(tool_data.get("name", ""))
|
||||
data["tools"] = enriched_tools
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
if decoded_token:
|
||||
user_id = decoded_token.get("sub")
|
||||
owner_id = shared_agent.get("user")
|
||||
|
||||
if user_id != owner_id:
|
||||
ensure_user_doc(user_id)
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
|
||||
)
|
||||
dual_write(UsersRepository,
|
||||
lambda repo, uid=user_id, aid=agent_id: repo.add_shared(uid, aid)
|
||||
)
|
||||
return make_response(jsonify(data), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving shared agent: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/shared_agents")
|
||||
class SharedAgents(Resource):
|
||||
@api.doc(description="Get shared agents explicitly shared with the user")
|
||||
def get(self):
|
||||
try:
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
|
||||
user_doc = ensure_user_doc(user_id)
|
||||
shared_with_ids = user_doc.get("agent_preferences", {}).get(
|
||||
"shared_with_me", []
|
||||
)
|
||||
shared_object_ids = [ObjectId(id) for id in shared_with_ids]
|
||||
|
||||
shared_agents_cursor = agents_collection.find(
|
||||
{"_id": {"$in": shared_object_ids}, "shared_publicly": True}
|
||||
)
|
||||
shared_agents = list(shared_agents_cursor)
|
||||
|
||||
found_ids_set = {str(agent["_id"]) for agent in shared_agents}
|
||||
stale_ids = [id for id in shared_with_ids if id not in found_ids_set]
|
||||
if stale_ids:
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
|
||||
)
|
||||
dual_write(UsersRepository,
|
||||
lambda repo, uid=user_id, ids=stale_ids: repo.remove_shared_bulk(uid, ids)
|
||||
)
|
||||
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
|
||||
|
||||
list_shared_agents = [
|
||||
{
|
||||
"id": str(agent["_id"]),
|
||||
"name": agent.get("name", ""),
|
||||
"description": agent.get("description", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"tools": agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||
"agent_type": agent.get("agent_type", ""),
|
||||
"status": agent.get("status", ""),
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"limited_token_mode": agent.get("limited_token_mode", False),
|
||||
"token_limit": agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
|
||||
"limited_request_mode": agent.get("limited_request_mode", False),
|
||||
"request_limit": agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"pinned": str(agent["_id"]) in pinned_ids,
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
}
|
||||
for agent in shared_agents
|
||||
]
|
||||
|
||||
return make_response(jsonify(list_shared_agents), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving shared agents: {err}")
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/share_agent")
|
||||
class ShareAgent(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ShareAgentModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="ID of the agent"),
|
||||
"shared": fields.Boolean(
|
||||
required=True, description="Share or unshare the agent"
|
||||
),
|
||||
"username": fields.String(
|
||||
required=False, description="Name of the user"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Share or unshare an agent")
|
||||
def put(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing JSON body"}), 400
|
||||
)
|
||||
agent_id = data.get("id")
|
||||
shared = data.get("shared")
|
||||
username = data.get("username", "")
|
||||
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
if shared is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Shared parameter is required and must be true or false",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
try:
|
||||
agent_oid = ObjectId(agent_id)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid agent ID"}), 400
|
||||
)
|
||||
agent = agents_collection.find_one({"_id": agent_oid, "user": user})
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
if shared:
|
||||
shared_metadata = {
|
||||
"shared_by": username,
|
||||
"shared_at": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
shared_token = secrets.token_urlsafe(32)
|
||||
agents_collection.update_one(
|
||||
{"_id": agent_oid, "user": user},
|
||||
{
|
||||
"$set": {
|
||||
"shared_publicly": shared,
|
||||
"shared_metadata": shared_metadata,
|
||||
"shared_token": shared_token,
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
agents_collection.update_one(
|
||||
{"_id": agent_oid, "user": user},
|
||||
{"$set": {"shared_publicly": shared, "shared_token": None}},
|
||||
{"$unset": {"shared_metadata": ""}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
|
||||
shared_token = shared_token if shared else None
|
||||
return make_response(
|
||||
jsonify({"success": True, "shared_token": shared_token}), 200
|
||||
)
|
||||
119
application/api/user/agents/webhooks.py
Normal file
119
application/api/user/agents/webhooks.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Agent management webhook handlers."""
|
||||
|
||||
import secrets
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import agents_collection, require_agent
|
||||
from application.api.user.tasks import process_agent_webhook
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
agents_webhooks_ns = Namespace(
|
||||
"agents", description="Agent management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@agents_webhooks_ns.route("/agent_webhook")
|
||||
class AgentWebhook(Resource):
|
||||
@api.doc(
|
||||
params={"id": "ID of the agent"},
|
||||
description="Generate webhook URL for the agent",
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
agent_id = request.args.get("id")
|
||||
if not agent_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
agent = agents_collection.find_one(
|
||||
{"_id": ObjectId(agent_id), "user": user}
|
||||
)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
webhook_token = agent.get("incoming_webhook_token")
|
||||
if not webhook_token:
|
||||
webhook_token = secrets.token_urlsafe(32)
|
||||
agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id), "user": user},
|
||||
{"$set": {"incoming_webhook_token": webhook_token}},
|
||||
)
|
||||
base_url = settings.API_URL.rstrip("/")
|
||||
full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error generating webhook URL: {err}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error generating webhook URL"}),
|
||||
400,
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": True, "webhook_url": full_webhook_url}), 200
|
||||
)
|
||||
|
||||
|
||||
@agents_webhooks_ns.route("/webhooks/agents/<string:webhook_token>")
|
||||
class AgentWebhookListener(Resource):
|
||||
method_decorators = [require_agent]
|
||||
|
||||
def _enqueue_webhook_task(self, agent_id_str, payload, source_method):
|
||||
if not payload:
|
||||
current_app.logger.warning(
|
||||
f"Webhook ({source_method}) received for agent {agent_id_str} with empty payload."
|
||||
)
|
||||
current_app.logger.info(
|
||||
f"Incoming {source_method} webhook for agent {agent_id_str}. Enqueuing task with payload: {payload}"
|
||||
)
|
||||
|
||||
try:
|
||||
task = process_agent_webhook.delay(
|
||||
agent_id=agent_id_str,
|
||||
payload=payload,
|
||||
)
|
||||
current_app.logger.info(
|
||||
f"Task {task.id} enqueued for agent {agent_id_str} ({source_method})."
|
||||
)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error enqueuing webhook task ({source_method}) for agent {agent_id_str}: {err}",
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error processing webhook"}), 500
|
||||
)
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (POST). Expects JSON payload, which is used to trigger processing.",
|
||||
)
|
||||
def post(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.get_json()
|
||||
if payload is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Invalid or missing JSON data in request body",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="POST")
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (GET). Uses URL query parameters as payload to trigger processing.",
|
||||
)
|
||||
def get(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="GET")
|
||||
5
application/api/user/analytics/__init__.py
Normal file
5
application/api/user/analytics/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Analytics module."""
|
||||
|
||||
from .routes import analytics_ns
|
||||
|
||||
__all__ = ["analytics_ns"]
|
||||
540
application/api/user/analytics/routes.py
Normal file
540
application/api/user/analytics/routes.py
Normal file
@@ -0,0 +1,540 @@
|
||||
"""Analytics and reporting routes."""
|
||||
|
||||
import datetime
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
conversations_collection,
|
||||
generate_date_range,
|
||||
generate_hourly_range,
|
||||
generate_minute_range,
|
||||
token_usage_collection,
|
||||
user_logs_collection,
|
||||
)
|
||||
|
||||
analytics_ns = Namespace(
|
||||
"analytics", description="Analytics and reporting operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_message_analytics")
|
||||
class GetMessageAnalytics(Resource):
|
||||
get_message_analytics_model = api.model(
|
||||
"GetMessageAnalyticsModel",
|
||||
{
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"filter_option": fields.String(
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_message_analytics_model)
|
||||
@api.doc(description="Get message analytics based on filter option")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else 14 if filter_option == "last_15_days" else 29
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"user": user,
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
pipeline = [
|
||||
match_stage,
|
||||
{"$unwind": "$queries"},
|
||||
{
|
||||
"$match": {
|
||||
"queries.timestamp": {"$gte": start_date, "$lte": end_date}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.timestamp",
|
||||
}
|
||||
},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
|
||||
message_data = conversations_collection.aggregate(pipeline)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
daily_messages = {interval: 0 for interval in intervals}
|
||||
|
||||
for entry in message_data:
|
||||
daily_messages[entry["_id"]] = entry["count"]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting message analytics: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "messages": daily_messages}), 200
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_token_analytics")
|
||||
class GetTokenAnalytics(Resource):
|
||||
get_token_analytics_model = api.model(
|
||||
"GetTokenAnalyticsModel",
|
||||
{
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"filter_option": fields.String(
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_token_analytics_model)
|
||||
@api.doc(description="Get token analytics data")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"minute": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"hour": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else (14 if filter_option == "last_15_days" else 29)
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"day": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"user_id": user,
|
||||
"timestamp": {"$gte": start_date, "$lte": end_date},
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
token_usage_data = token_usage_collection.aggregate(
|
||||
[
|
||||
match_stage,
|
||||
group_stage,
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
daily_token_usage = {interval: 0 for interval in intervals}
|
||||
|
||||
for entry in token_usage_data:
|
||||
if filter_option == "last_hour":
|
||||
daily_token_usage[entry["_id"]["minute"]] = entry["total_tokens"]
|
||||
elif filter_option == "last_24_hour":
|
||||
daily_token_usage[entry["_id"]["hour"]] = entry["total_tokens"]
|
||||
else:
|
||||
daily_token_usage[entry["_id"]["day"]] = entry["total_tokens"]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting token analytics: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "token_usage": daily_token_usage}), 200
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_feedback_analytics")
|
||||
class GetFeedbackAnalytics(Resource):
|
||||
get_feedback_analytics_model = api.model(
|
||||
"GetFeedbackAnalyticsModel",
|
||||
{
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"filter_option": fields.String(
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_feedback_analytics_model)
|
||||
@api.doc(description="Get feedback analytics data")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else (14 if filter_option == "last_15_days" else 29)
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"queries.feedback_timestamp": {
|
||||
"$gte": start_date,
|
||||
"$lte": end_date,
|
||||
},
|
||||
"queries.feedback": {"$exists": True},
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
pipeline = [
|
||||
match_stage,
|
||||
{"$unwind": "$queries"},
|
||||
{"$match": {"queries.feedback": {"$exists": True}}},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {"time": date_field, "feedback": "$queries.feedback"},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$_id.time",
|
||||
"positive": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$eq": ["$_id.feedback", "LIKE"]},
|
||||
"$count",
|
||||
0,
|
||||
]
|
||||
}
|
||||
},
|
||||
"negative": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$eq": ["$_id.feedback", "DISLIKE"]},
|
||||
"$count",
|
||||
0,
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
|
||||
feedback_data = conversations_collection.aggregate(pipeline)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
daily_feedback = {
|
||||
interval: {"positive": 0, "negative": 0} for interval in intervals
|
||||
}
|
||||
|
||||
for entry in feedback_data:
|
||||
daily_feedback[entry["_id"]] = {
|
||||
"positive": entry["positive"],
|
||||
"negative": entry["negative"],
|
||||
}
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting feedback analytics: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(
|
||||
jsonify({"success": True, "feedback": daily_feedback}), 200
|
||||
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_user_logs")
|
||||
class GetUserLogs(Resource):
|
||||
get_user_logs_model = api.model(
|
||||
"GetUserLogsModel",
|
||||
{
|
||||
"page": fields.Integer(
|
||||
required=False,
|
||||
description="Page number for pagination",
|
||||
default=1,
|
||||
),
|
||||
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||
"page_size": fields.Integer(
|
||||
required=False,
|
||||
description="Number of logs per page",
|
||||
default=10,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(get_user_logs_model)
|
||||
@api.doc(description="Get user logs with pagination")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
page = int(data.get("page", 1))
|
||||
api_key_id = data.get("api_key_id")
|
||||
page_size = int(data.get("page_size", 10))
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
try:
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
query = {"user": user}
|
||||
if api_key:
|
||||
query = {"api_key": api_key}
|
||||
items_cursor = (
|
||||
user_logs_collection.find(query)
|
||||
.sort("timestamp", -1)
|
||||
.skip(skip)
|
||||
.limit(page_size + 1)
|
||||
)
|
||||
items = list(items_cursor)
|
||||
|
||||
results = [
|
||||
{
|
||||
"id": str(item.get("_id")),
|
||||
"action": item.get("action"),
|
||||
"level": item.get("level"),
|
||||
"user": item.get("user"),
|
||||
"question": item.get("question"),
|
||||
"sources": item.get("sources"),
|
||||
"retriever_params": item.get("retriever_params"),
|
||||
"timestamp": item.get("timestamp"),
|
||||
}
|
||||
for item in items[:page_size]
|
||||
]
|
||||
|
||||
has_more = len(items) > page_size
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"logs": results,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": has_more,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
5
application/api/user/attachments/__init__.py
Normal file
5
application/api/user/attachments/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Attachments module."""
|
||||
|
||||
from .routes import attachments_ns
|
||||
|
||||
__all__ = ["attachments_ns"]
|
||||
678
application/api/user/attachments/routes.py
Normal file
678
application/api/user/attachments/routes.py
Normal file
@@ -0,0 +1,678 @@
|
||||
"""File attachments and media routes."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.stt.constants import (
|
||||
SUPPORTED_AUDIO_EXTENSIONS,
|
||||
SUPPORTED_AUDIO_MIME_TYPES,
|
||||
)
|
||||
from application.stt.upload_limits import (
|
||||
AudioFileTooLargeError,
|
||||
build_stt_file_size_limit_message,
|
||||
enforce_audio_file_size_limit,
|
||||
is_audio_filename,
|
||||
)
|
||||
from application.stt.live_session import (
|
||||
apply_live_stt_hypothesis,
|
||||
create_live_stt_session,
|
||||
delete_live_stt_session,
|
||||
finalize_live_stt_session,
|
||||
get_live_stt_transcript_text,
|
||||
load_live_stt_session,
|
||||
save_live_stt_session,
|
||||
)
|
||||
from application.stt.stt_creator import STTCreator
|
||||
from application.tts.tts_creator import TTSCreator
|
||||
from application.utils import safe_filename
|
||||
|
||||
|
||||
attachments_ns = Namespace(
|
||||
"attachments", description="File attachments and media operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
def _resolve_authenticated_user():
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
api_key = request.form.get("api_key") or request.args.get("api_key")
|
||||
|
||||
if decoded_token:
|
||||
return safe_filename(decoded_token.get("sub"))
|
||||
|
||||
if api_key:
|
||||
from application.api.user.base import agents_collection
|
||||
|
||||
agent = agents_collection.find_one({"key": api_key})
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid API key"}), 401
|
||||
)
|
||||
return safe_filename(agent.get("user"))
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_uploaded_file_size(file) -> int:
|
||||
try:
|
||||
current_position = file.stream.tell()
|
||||
file.stream.seek(0, os.SEEK_END)
|
||||
size_bytes = file.stream.tell()
|
||||
file.stream.seek(current_position)
|
||||
return size_bytes
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def _is_supported_audio_mimetype(mimetype: str) -> bool:
|
||||
if not mimetype:
|
||||
return True
|
||||
normalized = mimetype.split(";")[0].strip().lower()
|
||||
return normalized.startswith("audio/") or normalized in SUPPORTED_AUDIO_MIME_TYPES
|
||||
|
||||
|
||||
def _enforce_uploaded_audio_size_limit(file, filename: str) -> None:
|
||||
if not is_audio_filename(filename):
|
||||
return
|
||||
size_bytes = _get_uploaded_file_size(file)
|
||||
if size_bytes:
|
||||
enforce_audio_file_size_limit(size_bytes)
|
||||
|
||||
|
||||
def _get_store_attachment_user_error(exc: Exception) -> str:
|
||||
if isinstance(exc, AudioFileTooLargeError):
|
||||
return build_stt_file_size_limit_message()
|
||||
return "Failed to process file"
|
||||
|
||||
|
||||
def _require_live_stt_redis():
|
||||
redis_client = get_redis_instance()
|
||||
if redis_client:
|
||||
return redis_client
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Live transcription is unavailable"}),
|
||||
503,
|
||||
)
|
||||
|
||||
|
||||
def _parse_bool_form_value(value: str | None) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
@attachments_ns.route("/store_attachment")
|
||||
class StoreAttachment(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AttachmentModel",
|
||||
{
|
||||
"file": fields.Raw(required=True, description="File(s) to upload"),
|
||||
"api_key": fields.String(
|
||||
required=False, description="API key (optional)"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Stores one or multiple attachments without vectorization or training. Supports user or API key authentication."
|
||||
)
|
||||
def post(self):
|
||||
auth_user = _resolve_authenticated_user()
|
||||
if hasattr(auth_user, "status_code"):
|
||||
return auth_user
|
||||
|
||||
files = request.files.getlist("file")
|
||||
if not files:
|
||||
single_file = request.files.get("file")
|
||||
if single_file:
|
||||
files = [single_file]
|
||||
|
||||
if not files or all(f.filename == "" for f in files):
|
||||
return make_response(
|
||||
jsonify({"status": "error", "message": "Missing file(s)"}),
|
||||
400,
|
||||
)
|
||||
|
||||
user = auth_user
|
||||
if not user:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}), 401
|
||||
)
|
||||
|
||||
try:
|
||||
from application.api.user.tasks import store_attachment
|
||||
from application.api.user.base import storage
|
||||
|
||||
tasks = []
|
||||
errors = []
|
||||
original_file_count = len(files)
|
||||
|
||||
for idx, file in enumerate(files):
|
||||
try:
|
||||
attachment_id = ObjectId()
|
||||
original_filename = safe_filename(os.path.basename(file.filename))
|
||||
_enforce_uploaded_audio_size_limit(file, original_filename)
|
||||
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
|
||||
|
||||
metadata = storage.save_file(file, relative_path)
|
||||
file_info = {
|
||||
"filename": original_filename,
|
||||
"attachment_id": str(attachment_id),
|
||||
"path": relative_path,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
task = store_attachment.delay(file_info, user)
|
||||
tasks.append({
|
||||
"task_id": task.id,
|
||||
"filename": original_filename,
|
||||
"attachment_id": str(attachment_id),
|
||||
"upload_index": idx,
|
||||
})
|
||||
except Exception as file_err:
|
||||
current_app.logger.error(f"Error processing file {idx} ({file.filename}): {file_err}", exc_info=True)
|
||||
errors.append({
|
||||
"upload_index": idx,
|
||||
"filename": file.filename,
|
||||
"error": _get_store_attachment_user_error(file_err),
|
||||
})
|
||||
|
||||
if not tasks:
|
||||
if errors and all(
|
||||
error.get("error") == build_stt_file_size_limit_message()
|
||||
for error in errors
|
||||
):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": build_stt_file_size_limit_message(),
|
||||
"errors": errors,
|
||||
}
|
||||
),
|
||||
413,
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"status": "error", "message": "No valid files to upload"}),
|
||||
400,
|
||||
)
|
||||
|
||||
if original_file_count == 1 and len(tasks) == 1:
|
||||
current_app.logger.info("Returning single task_id response")
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"task_id": tasks[0]["task_id"],
|
||||
"message": "File uploaded successfully. Processing started.",
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
else:
|
||||
response_data = {
|
||||
"success": True,
|
||||
"tasks": tasks,
|
||||
"message": f"{len(tasks)} file(s) uploaded successfully. Processing started.",
|
||||
}
|
||||
if errors:
|
||||
response_data["errors"] = errors
|
||||
response_data["message"] += f" {len(errors)} file(s) failed."
|
||||
|
||||
return make_response(
|
||||
jsonify(response_data),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to store attachment"}), 400)
|
||||
|
||||
|
||||
@attachments_ns.route("/stt")
|
||||
class SpeechToText(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"SpeechToTextModel",
|
||||
{
|
||||
"file": fields.Raw(required=True, description="Audio file"),
|
||||
"language": fields.String(
|
||||
required=False, description="Optional transcription language hint"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Transcribe an uploaded audio file")
|
||||
def post(self):
|
||||
auth_user = _resolve_authenticated_user()
|
||||
if hasattr(auth_user, "status_code"):
|
||||
return auth_user
|
||||
if not auth_user:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}),
|
||||
401,
|
||||
)
|
||||
|
||||
file = request.files.get("file")
|
||||
if not file or file.filename == "":
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing file"}),
|
||||
400,
|
||||
)
|
||||
|
||||
filename = safe_filename(os.path.basename(file.filename))
|
||||
suffix = Path(filename).suffix.lower()
|
||||
if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Unsupported audio format"}),
|
||||
400,
|
||||
)
|
||||
|
||||
if not _is_supported_audio_mimetype(file.mimetype or ""):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Unsupported audio MIME type"}),
|
||||
400,
|
||||
)
|
||||
|
||||
try:
|
||||
_enforce_uploaded_audio_size_limit(file, filename)
|
||||
except AudioFileTooLargeError:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": build_stt_file_size_limit_message(),
|
||||
}
|
||||
),
|
||||
413,
|
||||
)
|
||||
|
||||
temp_path = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
|
||||
file.save(temp_file.name)
|
||||
temp_path = Path(temp_file.name)
|
||||
|
||||
stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
|
||||
transcript = stt_instance.transcribe(
|
||||
temp_path,
|
||||
language=request.form.get("language") or settings.STT_LANGUAGE,
|
||||
timestamps=settings.STT_ENABLE_TIMESTAMPS,
|
||||
diarize=settings.STT_ENABLE_DIARIZATION,
|
||||
)
|
||||
return make_response(jsonify({"success": True, **transcript}), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error transcribing audio: {err}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Failed to transcribe audio"}),
|
||||
400,
|
||||
)
|
||||
finally:
|
||||
if temp_path and temp_path.exists():
|
||||
temp_path.unlink()
|
||||
|
||||
|
||||
@attachments_ns.route("/stt/live/start")
|
||||
class LiveSpeechToTextStart(Resource):
|
||||
@api.doc(description="Start a live speech-to-text session")
|
||||
def post(self):
|
||||
auth_user = _resolve_authenticated_user()
|
||||
if hasattr(auth_user, "status_code"):
|
||||
return auth_user
|
||||
if not auth_user:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}),
|
||||
401,
|
||||
)
|
||||
|
||||
redis_client = _require_live_stt_redis()
|
||||
if hasattr(redis_client, "status_code"):
|
||||
return redis_client
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
session_state = create_live_stt_session(
|
||||
user=auth_user,
|
||||
language=payload.get("language") or settings.STT_LANGUAGE,
|
||||
)
|
||||
save_live_stt_session(redis_client, session_state)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"session_id": session_state["session_id"],
|
||||
"language": session_state.get("language"),
|
||||
"committed_text": "",
|
||||
"mutable_text": "",
|
||||
"previous_hypothesis": "",
|
||||
"latest_hypothesis": "",
|
||||
"finalized_text": "",
|
||||
"pending_text": "",
|
||||
"transcript_text": "",
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@attachments_ns.route("/stt/live/chunk")
|
||||
class LiveSpeechToTextChunk(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"LiveSpeechToTextChunkModel",
|
||||
{
|
||||
"session_id": fields.String(
|
||||
required=True, description="Live transcription session ID"
|
||||
),
|
||||
"chunk_index": fields.Integer(
|
||||
required=True, description="Sequential chunk index"
|
||||
),
|
||||
"is_silence": fields.Boolean(
|
||||
required=False,
|
||||
description="Whether the latest capture window was mostly silence",
|
||||
),
|
||||
"file": fields.Raw(required=True, description="Audio chunk"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Transcribe a chunk for a live speech-to-text session")
|
||||
def post(self):
|
||||
auth_user = _resolve_authenticated_user()
|
||||
if hasattr(auth_user, "status_code"):
|
||||
return auth_user
|
||||
if not auth_user:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}),
|
||||
401,
|
||||
)
|
||||
|
||||
redis_client = _require_live_stt_redis()
|
||||
if hasattr(redis_client, "status_code"):
|
||||
return redis_client
|
||||
|
||||
session_id = request.form.get("session_id", "").strip()
|
||||
if not session_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing session_id"}),
|
||||
400,
|
||||
)
|
||||
|
||||
session_state = load_live_stt_session(redis_client, session_id)
|
||||
if not session_state:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Live transcription session not found",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
|
||||
if safe_filename(str(session_state.get("user", ""))) != auth_user:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Forbidden"}),
|
||||
403,
|
||||
)
|
||||
|
||||
chunk_index_raw = request.form.get("chunk_index", "").strip()
|
||||
if chunk_index_raw == "":
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing chunk_index"}),
|
||||
400,
|
||||
)
|
||||
|
||||
try:
|
||||
chunk_index = int(chunk_index_raw)
|
||||
except ValueError:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid chunk_index"}),
|
||||
400,
|
||||
)
|
||||
is_silence = _parse_bool_form_value(request.form.get("is_silence"))
|
||||
|
||||
file = request.files.get("file")
|
||||
if not file or file.filename == "":
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing file"}),
|
||||
400,
|
||||
)
|
||||
|
||||
filename = safe_filename(os.path.basename(file.filename))
|
||||
suffix = Path(filename).suffix.lower()
|
||||
if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Unsupported audio format"}),
|
||||
400,
|
||||
)
|
||||
|
||||
if not _is_supported_audio_mimetype(file.mimetype or ""):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Unsupported audio MIME type"}),
|
||||
400,
|
||||
)
|
||||
|
||||
try:
|
||||
_enforce_uploaded_audio_size_limit(file, filename)
|
||||
except AudioFileTooLargeError:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": build_stt_file_size_limit_message(),
|
||||
}
|
||||
),
|
||||
413,
|
||||
)
|
||||
|
||||
temp_path = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
|
||||
file.save(temp_file.name)
|
||||
temp_path = Path(temp_file.name)
|
||||
|
||||
session_language = session_state.get("language") or settings.STT_LANGUAGE
|
||||
stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
|
||||
transcript = stt_instance.transcribe(
|
||||
temp_path,
|
||||
language=session_language,
|
||||
timestamps=False,
|
||||
diarize=False,
|
||||
)
|
||||
if not session_state.get("language") and transcript.get("language"):
|
||||
session_state["language"] = transcript["language"]
|
||||
|
||||
try:
|
||||
apply_live_stt_hypothesis(
|
||||
session_state,
|
||||
str(transcript.get("text", "")),
|
||||
chunk_index,
|
||||
is_silence=is_silence,
|
||||
)
|
||||
except ValueError:
|
||||
current_app.logger.warning(
|
||||
"Invalid live transcription chunk",
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Invalid live transcription chunk",
|
||||
}
|
||||
),
|
||||
409,
|
||||
)
|
||||
save_live_stt_session(redis_client, session_state)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"session_id": session_id,
|
||||
"chunk_index": chunk_index,
|
||||
"chunk_text": transcript.get("text", ""),
|
||||
"is_silence": is_silence,
|
||||
"language": session_state.get("language"),
|
||||
"committed_text": session_state.get("committed_text", ""),
|
||||
"mutable_text": session_state.get("mutable_text", ""),
|
||||
"previous_hypothesis": session_state.get(
|
||||
"previous_hypothesis", ""
|
||||
),
|
||||
"latest_hypothesis": session_state.get(
|
||||
"latest_hypothesis", ""
|
||||
),
|
||||
"finalized_text": session_state.get("committed_text", ""),
|
||||
"pending_text": session_state.get("mutable_text", ""),
|
||||
"transcript_text": get_live_stt_transcript_text(session_state),
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error transcribing live audio chunk: {err}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Failed to transcribe audio"}),
|
||||
400,
|
||||
)
|
||||
finally:
|
||||
if temp_path and temp_path.exists():
|
||||
temp_path.unlink()
|
||||
|
||||
|
||||
@attachments_ns.route("/stt/live/finish")
|
||||
class LiveSpeechToTextFinish(Resource):
|
||||
@api.doc(description="Finish a live speech-to-text session")
|
||||
def post(self):
|
||||
auth_user = _resolve_authenticated_user()
|
||||
if hasattr(auth_user, "status_code"):
|
||||
return auth_user
|
||||
if not auth_user:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}),
|
||||
401,
|
||||
)
|
||||
|
||||
redis_client = _require_live_stt_redis()
|
||||
if hasattr(redis_client, "status_code"):
|
||||
return redis_client
|
||||
|
||||
payload = request.get_json(silent=True) or {}
|
||||
session_id = str(payload.get("session_id", "")).strip()
|
||||
if not session_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing session_id"}),
|
||||
400,
|
||||
)
|
||||
|
||||
session_state = load_live_stt_session(redis_client, session_id)
|
||||
if not session_state:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Live transcription session not found",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
|
||||
if safe_filename(str(session_state.get("user", ""))) != auth_user:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Forbidden"}),
|
||||
403,
|
||||
)
|
||||
|
||||
final_text = finalize_live_stt_session(session_state)
|
||||
delete_live_stt_session(redis_client, session_id)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"session_id": session_id,
|
||||
"language": session_state.get("language"),
|
||||
"text": final_text,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@attachments_ns.route("/images/<path:image_path>")
|
||||
class ServeImage(Resource):
|
||||
@api.doc(description="Serve an image from storage")
|
||||
def get(self, image_path):
|
||||
if ".." in image_path or image_path.startswith("/") or "\x00" in image_path:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid image path"}), 400
|
||||
)
|
||||
try:
|
||||
from application.api.user.base import storage
|
||||
|
||||
file_obj = storage.get_file(image_path)
|
||||
extension = image_path.split(".")[-1].lower()
|
||||
content_type = f"image/{extension}"
|
||||
if extension == "jpg":
|
||||
content_type = "image/jpeg"
|
||||
response = make_response(file_obj.read())
|
||||
response.headers.set("Content-Type", content_type)
|
||||
response.headers.set("Cache-Control", "max-age=86400")
|
||||
|
||||
return response
|
||||
except FileNotFoundError:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Image not found"}), 404
|
||||
)
|
||||
except ValueError:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid image path"}), 400
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error serving image: {e}")
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error retrieving image"}), 500
|
||||
)
|
||||
|
||||
|
||||
@attachments_ns.route("/tts")
|
||||
class TextToSpeech(Resource):
|
||||
tts_model = api.model(
|
||||
"TextToSpeechModel",
|
||||
{
|
||||
"text": fields.String(
|
||||
required=True, description="Text to be synthesized as audio"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(tts_model)
|
||||
@api.doc(description="Synthesize audio speech from text")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
text = data["text"]
|
||||
try:
|
||||
tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
|
||||
audio_base64, detected_language = tts_instance.text_to_speech(text)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"audio_base64": audio_base64,
|
||||
"lang": detected_language,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error synthesizing audio: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
259
application/api/user/base.py
Normal file
259
application/api/user/base.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
Shared utilities, database connections, and helper functions for user API routes.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, Response
|
||||
from pymongo import ReturnDocument
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
|
||||
conversations_collection = db["conversations"]
|
||||
sources_collection = db["sources"]
|
||||
prompts_collection = db["prompts"]
|
||||
feedback_collection = db["feedback"]
|
||||
agents_collection = db["agents"]
|
||||
agent_folders_collection = db["agent_folders"]
|
||||
token_usage_collection = db["token_usage"]
|
||||
shared_conversations_collections = db["shared_conversations"]
|
||||
users_collection = db["users"]
|
||||
user_logs_collection = db["user_logs"]
|
||||
user_tools_collection = db["user_tools"]
|
||||
attachments_collection = db["attachments"]
|
||||
workflow_runs_collection = db["workflow_runs"]
|
||||
workflows_collection = db["workflows"]
|
||||
workflow_nodes_collection = db["workflow_nodes"]
|
||||
workflow_edges_collection = db["workflow_edges"]
|
||||
|
||||
|
||||
try:
|
||||
agents_collection.create_index(
|
||||
[("shared", 1)],
|
||||
name="shared_index",
|
||||
background=True,
|
||||
)
|
||||
users_collection.create_index("user_id", unique=True)
|
||||
workflows_collection.create_index(
|
||||
[("user", 1)], name="workflow_user_index", background=True
|
||||
)
|
||||
workflow_nodes_collection.create_index(
|
||||
[("workflow_id", 1)], name="node_workflow_index", background=True
|
||||
)
|
||||
workflow_nodes_collection.create_index(
|
||||
[("workflow_id", 1), ("graph_version", 1)],
|
||||
name="node_workflow_graph_version_index",
|
||||
background=True,
|
||||
)
|
||||
workflow_edges_collection.create_index(
|
||||
[("workflow_id", 1)], name="edge_workflow_index", background=True
|
||||
)
|
||||
workflow_edges_collection.create_index(
|
||||
[("workflow_id", 1), ("graph_version", 1)],
|
||||
name="edge_workflow_graph_version_index",
|
||||
background=True,
|
||||
)
|
||||
except Exception as e:
|
||||
print("Error creating indexes:", e)
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
|
||||
|
||||
def generate_minute_range(start_date, end_date):
|
||||
"""Generate a dictionary with minute-level time ranges."""
|
||||
return {
|
||||
(start_date + datetime.timedelta(minutes=i)).strftime("%Y-%m-%d %H:%M:00"): 0
|
||||
for i in range(int((end_date - start_date).total_seconds() // 60) + 1)
|
||||
}
|
||||
|
||||
|
||||
def generate_hourly_range(start_date, end_date):
|
||||
"""Generate a dictionary with hourly time ranges."""
|
||||
return {
|
||||
(start_date + datetime.timedelta(hours=i)).strftime("%Y-%m-%d %H:00"): 0
|
||||
for i in range(int((end_date - start_date).total_seconds() // 3600) + 1)
|
||||
}
|
||||
|
||||
|
||||
def generate_date_range(start_date, end_date):
|
||||
"""Generate a dictionary with daily date ranges."""
|
||||
return {
|
||||
(start_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d"): 0
|
||||
for i in range((end_date - start_date).days + 1)
|
||||
}
|
||||
|
||||
|
||||
def ensure_user_doc(user_id):
|
||||
"""
|
||||
Ensure user document exists with proper agent preferences structure.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to ensure
|
||||
|
||||
Returns:
|
||||
The user document
|
||||
"""
|
||||
default_prefs = {
|
||||
"pinned": [],
|
||||
"shared_with_me": [],
|
||||
}
|
||||
|
||||
user_doc = users_collection.find_one_and_update(
|
||||
{"user_id": user_id},
|
||||
{"$setOnInsert": {"agent_preferences": default_prefs}},
|
||||
upsert=True,
|
||||
return_document=ReturnDocument.AFTER,
|
||||
)
|
||||
|
||||
prefs = user_doc.get("agent_preferences", {})
|
||||
updates = {}
|
||||
if "pinned" not in prefs:
|
||||
updates["agent_preferences.pinned"] = []
|
||||
if "shared_with_me" not in prefs:
|
||||
updates["agent_preferences.shared_with_me"] = []
|
||||
if updates:
|
||||
users_collection.update_one({"user_id": user_id}, {"$set": updates})
|
||||
user_doc = users_collection.find_one({"user_id": user_id})
|
||||
|
||||
dual_write(UsersRepository, lambda repo: repo.upsert(user_id))
|
||||
|
||||
return user_doc
|
||||
|
||||
|
||||
def resolve_tool_details(tool_ids):
|
||||
"""
|
||||
Resolve tool IDs to their details.
|
||||
|
||||
Args:
|
||||
tool_ids: List of tool IDs
|
||||
|
||||
Returns:
|
||||
List of tool details with id, name, and display_name
|
||||
"""
|
||||
valid_ids = []
|
||||
for tid in tool_ids:
|
||||
try:
|
||||
valid_ids.append(ObjectId(tid))
|
||||
except Exception:
|
||||
continue
|
||||
tools = user_tools_collection.find(
|
||||
{"_id": {"$in": valid_ids}}
|
||||
) if valid_ids else []
|
||||
return [
|
||||
{
|
||||
"id": str(tool["_id"]),
|
||||
"name": tool.get("name", ""),
|
||||
"display_name": tool.get("customName")
|
||||
or tool.get("displayName")
|
||||
or tool.get("name", ""),
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
|
||||
|
||||
def get_vector_store(source_id):
|
||||
"""
|
||||
Get the Vector Store for a given source ID.
|
||||
|
||||
Args:
|
||||
source_id (str): source id of the document
|
||||
|
||||
Returns:
|
||||
Vector store instance
|
||||
"""
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
return store
|
||||
|
||||
|
||||
def handle_image_upload(
|
||||
request, existing_url: str, user: str, storage, base_path: str = "attachments/"
|
||||
) -> Tuple[str, Optional[Response]]:
|
||||
"""
|
||||
Handle image file upload from request.
|
||||
|
||||
Args:
|
||||
request: Flask request object
|
||||
existing_url: Existing image URL (fallback)
|
||||
user: User ID
|
||||
storage: Storage instance
|
||||
base_path: Base path for upload
|
||||
|
||||
Returns:
|
||||
Tuple of (image_url, error_response)
|
||||
"""
|
||||
image_url = existing_url
|
||||
|
||||
if "image" in request.files:
|
||||
file = request.files["image"]
|
||||
if file.filename != "":
|
||||
filename = secure_filename(file.filename)
|
||||
upload_path = f"{settings.UPLOAD_FOLDER.rstrip('/')}/{user}/{base_path.rstrip('/')}/{uuid.uuid4()}_{filename}"
|
||||
try:
|
||||
storage.save_file(file, upload_path, storage_class="STANDARD")
|
||||
image_url = upload_path
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error uploading image: {e}")
|
||||
return None, make_response(
|
||||
jsonify({"success": False, "message": "Image upload failed"}),
|
||||
400,
|
||||
)
|
||||
return image_url, None
|
||||
|
||||
|
||||
def require_agent(func):
|
||||
"""
|
||||
Decorator to require valid agent webhook token.
|
||||
|
||||
Args:
|
||||
func: Function to decorate
|
||||
|
||||
Returns:
|
||||
Wrapped function
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
webhook_token = kwargs.get("webhook_token")
|
||||
if not webhook_token:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Webhook token missing"}), 400
|
||||
)
|
||||
agent = agents_collection.find_one(
|
||||
{"incoming_webhook_token": webhook_token}, {"_id": 1}
|
||||
)
|
||||
if not agent:
|
||||
current_app.logger.warning(
|
||||
f"Webhook attempt with invalid token: {webhook_token}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
kwargs["agent"] = agent
|
||||
kwargs["agent_id_str"] = str(agent["_id"])
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
5
application/api/user/conversations/__init__.py
Normal file
5
application/api/user/conversations/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Conversation management module."""
|
||||
|
||||
from .routes import conversations_ns
|
||||
|
||||
__all__ = ["conversations_ns"]
|
||||
320
application/api/user/conversations/routes.py
Normal file
320
application/api/user/conversations/routes.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""Conversation management routes."""
|
||||
|
||||
import datetime
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import attachments_collection, conversations_collection
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.utils import check_required_fields
|
||||
|
||||
conversations_ns = Namespace(
|
||||
"conversations", description="Conversation management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@conversations_ns.route("/delete_conversation")
|
||||
class DeleteConversation(Resource):
|
||||
@api.doc(
|
||||
description="Deletes a conversation by ID",
|
||||
params={"id": "The ID of the conversation to delete"},
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
conversation_id = request.args.get("id")
|
||||
if not conversation_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
user_id = decoded_token["sub"]
|
||||
try:
|
||||
conversations_collection.delete_one(
|
||||
{"_id": ObjectId(conversation_id), "user": user_id}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
def _pg_delete(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is not None:
|
||||
repo.delete(conv["id"], user_id)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_delete)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/delete_all_conversations")
|
||||
class DeleteAllConversations(Resource):
|
||||
@api.doc(
|
||||
description="Deletes all conversations for a specific user",
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
conversations_collection.delete_many({"user": user_id})
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting all conversations: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
dual_write(
|
||||
ConversationsRepository,
|
||||
lambda r, uid=user_id: r.delete_all_for_user(uid),
|
||||
)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/get_conversations")
|
||||
class GetConversations(Resource):
|
||||
@api.doc(
|
||||
description="Retrieve a list of the latest 30 conversations (excluding API key conversations)",
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
try:
|
||||
conversations = (
|
||||
conversations_collection.find(
|
||||
{
|
||||
"$or": [
|
||||
{"api_key": {"$exists": False}},
|
||||
{"agent_id": {"$exists": True}},
|
||||
],
|
||||
"user": decoded_token.get("sub"),
|
||||
}
|
||||
)
|
||||
.sort("date", -1)
|
||||
.limit(30)
|
||||
)
|
||||
|
||||
list_conversations = [
|
||||
{
|
||||
"id": str(conversation["_id"]),
|
||||
"name": conversation["name"],
|
||||
"agent_id": conversation.get("agent_id", None),
|
||||
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||
"shared_token": conversation.get("shared_token", None),
|
||||
}
|
||||
for conversation in conversations
|
||||
]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving conversations: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(list_conversations), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/get_single_conversation")
|
||||
class GetSingleConversation(Resource):
|
||||
@api.doc(
|
||||
description="Retrieve a single conversation by ID",
|
||||
params={"id": "The conversation ID"},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
conversation_id = request.args.get("id")
|
||||
if not conversation_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
if not conversation:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
# Process queries to include attachment names
|
||||
|
||||
queries = conversation["queries"]
|
||||
for query in queries:
|
||||
if "attachments" in query and query["attachments"]:
|
||||
attachment_details = []
|
||||
for attachment_id in query["attachments"]:
|
||||
try:
|
||||
attachment = attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id)}
|
||||
)
|
||||
if attachment:
|
||||
attachment_details.append(
|
||||
{
|
||||
"id": str(attachment["_id"]),
|
||||
"fileName": attachment.get(
|
||||
"filename", "Unknown file"
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
query["attachments"] = attachment_details
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving conversation: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
data = {
|
||||
"queries": queries,
|
||||
"agent_id": conversation.get("agent_id"),
|
||||
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||
"shared_token": conversation.get("shared_token", None),
|
||||
}
|
||||
return make_response(jsonify(data), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/update_conversation_name")
|
||||
class UpdateConversationName(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateConversationModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Conversation ID"),
|
||||
"name": fields.String(
|
||||
required=True, description="New name of the conversation"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Updates the name of a conversation",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user_id},
|
||||
{"$set": {"name": data["name"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating conversation name: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
def _pg_rename(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(data["id"])
|
||||
if conv is not None:
|
||||
repo.rename(conv["id"], user_id, data["name"])
|
||||
|
||||
dual_write(ConversationsRepository, _pg_rename)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/feedback")
|
||||
class SubmitFeedback(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"FeedbackModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=False, description="The user question"
|
||||
),
|
||||
"answer": fields.String(required=False, description="The AI answer"),
|
||||
"feedback": fields.String(required=True, description="User feedback"),
|
||||
"question_index": fields.Integer(
|
||||
required=True,
|
||||
description="The question number in that particular conversation",
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=True, description="id of the particular conversation"
|
||||
),
|
||||
"api_key": fields.String(description="Optional API key"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Submit feedback for a conversation",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["feedback", "conversation_id", "question_index"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
if data["feedback"] is None:
|
||||
# Remove feedback and feedback_timestamp if feedback is null
|
||||
|
||||
conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(data["conversation_id"]),
|
||||
"user": decoded_token.get("sub"),
|
||||
f"queries.{data['question_index']}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$unset": {
|
||||
f"queries.{data['question_index']}.feedback": "",
|
||||
f"queries.{data['question_index']}.feedback_timestamp": "",
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
# Set feedback and feedback_timestamp if feedback has a value
|
||||
|
||||
conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(data["conversation_id"]),
|
||||
"user": decoded_token.get("sub"),
|
||||
f"queries.{data['question_index']}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{data['question_index']}.feedback": data[
|
||||
"feedback"
|
||||
],
|
||||
f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
# Dual-write to Postgres: mirror the per-message feedback set/unset.
|
||||
feedback_value = data["feedback"]
|
||||
question_index = int(data["question_index"])
|
||||
feedback_payload = (
|
||||
None if feedback_value is None
|
||||
else {"text": feedback_value, "timestamp": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
).isoformat()}
|
||||
)
|
||||
|
||||
def _pg_feedback(repo: ConversationsRepository) -> None:
|
||||
conv = repo.get_by_legacy_id(data["conversation_id"])
|
||||
if conv is not None:
|
||||
repo.set_feedback(conv["id"], question_index, feedback_payload)
|
||||
|
||||
dual_write(ConversationsRepository, _pg_feedback)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
3
application/api/user/models/__init__.py
Normal file
3
application/api/user/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .routes import models_ns
|
||||
|
||||
__all__ = ["models_ns"]
|
||||
25
application/api/user/models/routes.py
Normal file
25
application/api/user/models/routes.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from flask import current_app, jsonify, make_response
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
models_ns = Namespace("models", description="Available models", path="/api")
|
||||
|
||||
|
||||
@models_ns.route("/models")
|
||||
class ModelsListResource(Resource):
|
||||
def get(self):
|
||||
"""Get list of available models with their capabilities."""
|
||||
try:
|
||||
registry = ModelRegistry.get_instance()
|
||||
models = registry.get_enabled_models()
|
||||
|
||||
response = {
|
||||
"models": [model.to_dict() for model in models],
|
||||
"default_model_id": registry.default_model_id,
|
||||
"count": len(models),
|
||||
}
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
return make_response(jsonify(response), 200)
|
||||
5
application/api/user/prompts/__init__.py
Normal file
5
application/api/user/prompts/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Prompts module."""
|
||||
|
||||
from .routes import prompts_ns
|
||||
|
||||
__all__ = ["prompts_ns"]
|
||||
209
application/api/user/prompts/routes.py
Normal file
209
application/api/user/prompts/routes.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""Prompt management routes."""
|
||||
|
||||
import os
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import current_dir, prompts_collection
|
||||
from application.storage.db.dual_write import dual_write
|
||||
from application.storage.db.repositories.prompts import PromptsRepository
|
||||
from application.utils import check_required_fields
|
||||
|
||||
prompts_ns = Namespace(
|
||||
"prompts", description="Prompt management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@prompts_ns.route("/create_prompt")
|
||||
class CreatePrompt(Resource):
|
||||
create_prompt_model = api.model(
|
||||
"CreatePromptModel",
|
||||
{
|
||||
"content": fields.String(
|
||||
required=True, description="Content of the prompt"
|
||||
),
|
||||
"name": fields.String(required=True, description="Name of the prompt"),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(create_prompt_model)
|
||||
@api.doc(description="Create a new prompt")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.get_json()
|
||||
required_fields = ["content", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
|
||||
resp = prompts_collection.insert_one(
|
||||
{
|
||||
"name": data["name"],
|
||||
"content": data["content"],
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
new_id = str(resp.inserted_id)
|
||||
dual_write(
|
||||
PromptsRepository,
|
||||
lambda repo, u=user, n=data["name"], c=data["content"], mid=new_id: repo.create(
|
||||
u, n, c, legacy_mongo_id=mid,
|
||||
),
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"id": new_id}), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/get_prompts")
|
||||
class GetPrompts(Resource):
|
||||
@api.doc(description="Get all prompts for the user")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
prompts = prompts_collection.find({"user": user})
|
||||
list_prompts = [
|
||||
{"id": "default", "name": "default", "type": "public"},
|
||||
{"id": "creative", "name": "creative", "type": "public"},
|
||||
{"id": "strict", "name": "strict", "type": "public"},
|
||||
]
|
||||
|
||||
for prompt in prompts:
|
||||
list_prompts.append(
|
||||
{
|
||||
"id": str(prompt["_id"]),
|
||||
"name": prompt["name"],
|
||||
"type": "private",
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving prompts: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(list_prompts), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/get_single_prompt")
|
||||
class GetSinglePrompt(Resource):
|
||||
@api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
prompt_id = request.args.get("id")
|
||||
if not prompt_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
if prompt_id == "default":
|
||||
with open(
|
||||
os.path.join(current_dir, "prompts", "chat_combine_default.txt"),
|
||||
"r",
|
||||
) as f:
|
||||
chat_combine_template = f.read()
|
||||
return make_response(jsonify({"content": chat_combine_template}), 200)
|
||||
elif prompt_id == "creative":
|
||||
with open(
|
||||
os.path.join(current_dir, "prompts", "chat_combine_creative.txt"),
|
||||
"r",
|
||||
) as f:
|
||||
chat_reduce_creative = f.read()
|
||||
return make_response(jsonify({"content": chat_reduce_creative}), 200)
|
||||
elif prompt_id == "strict":
|
||||
with open(
|
||||
os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r"
|
||||
) as f:
|
||||
chat_reduce_strict = f.read()
|
||||
return make_response(jsonify({"content": chat_reduce_strict}), 200)
|
||||
prompt = prompts_collection.find_one(
|
||||
{"_id": ObjectId(prompt_id), "user": user}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"content": prompt["content"]}), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/delete_prompt")
|
||||
class DeletePrompt(Resource):
|
||||
delete_prompt_model = api.model(
|
||||
"DeletePromptModel",
|
||||
{"id": fields.String(required=True, description="Prompt ID to delete")},
|
||||
)
|
||||
|
||||
@api.expect(delete_prompt_model)
|
||||
@api.doc(description="Delete a prompt by ID")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
|
||||
dual_write(
|
||||
PromptsRepository,
|
||||
lambda repo, pid=data["id"], u=user: repo.delete_by_legacy_id(pid, u),
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/update_prompt")
|
||||
class UpdatePrompt(Resource):
|
||||
update_prompt_model = api.model(
|
||||
"UpdatePromptModel",
|
||||
{
|
||||
"id": fields.String(required=True, description="Prompt ID to update"),
|
||||
"name": fields.String(required=True, description="New name of the prompt"),
|
||||
"content": fields.String(
|
||||
required=True, description="New content of the prompt"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(update_prompt_model)
|
||||
@api.doc(description="Update an existing prompt")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json()
|
||||
required_fields = ["id", "name", "content"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
prompts_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"name": data["name"], "content": data["content"]}},
|
||||
)
|
||||
dual_write(
|
||||
PromptsRepository,
|
||||
lambda repo, pid=data["id"], u=user, n=data["name"], c=data["content"]: repo.update_by_legacy_id(
|
||||
pid, u, n, c,
|
||||
),
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user