mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Compare commits
593 Commits
dependabot
...
improve-va
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5a9d512679 | ||
|
|
6dd32fd4ca | ||
|
|
b17b1c70b5 | ||
|
|
3f5b31fb5f | ||
|
|
06bda6bd55 | ||
|
|
7dd97821a8 | ||
|
|
695191d888 | ||
|
|
1dbcef24c7 | ||
|
|
e086c79da0 | ||
|
|
6ae8d34b27 | ||
|
|
2e23e547d3 | ||
|
|
fa11dc9828 | ||
|
|
673fa70bc5 | ||
|
|
a0660a54c1 | ||
|
|
1137bf4280 | ||
|
|
da41c898d8 | ||
|
|
21e5c261ef | ||
|
|
a7d61b9d59 | ||
|
|
c5fe25c149 | ||
|
|
6a4cb617f9 | ||
|
|
94f70e6de5 | ||
|
|
ab4ebf9a9d | ||
|
|
9f7945fcf5 | ||
|
|
d8ec3c008c | ||
|
|
2f00691246 | ||
|
|
9b2383b074 | ||
|
|
e4e9910575 | ||
|
|
f448e4a615 | ||
|
|
c4e8daf50e | ||
|
|
5aa4ec1b9f | ||
|
|
125ce0aad3 | ||
|
|
ababc9ae04 | ||
|
|
62ac90746e | ||
|
|
096f6d91a2 | ||
|
|
d28ef6b094 | ||
|
|
8fb945ab09 | ||
|
|
835d71727c | ||
|
|
ce32dd2907 | ||
|
|
72bc24a490 | ||
|
|
d6c49bdbf0 | ||
|
|
1805292528 | ||
|
|
d09ce7e1f7 | ||
|
|
a8d2024791 | ||
|
|
f0b954dbfb | ||
|
|
50bee7c2b0 | ||
|
|
e7b15b316e | ||
|
|
a4507008c1 | ||
|
|
c5ba85f929 | ||
|
|
2e636bd67e | ||
|
|
4a039f1abf | ||
|
|
434d8e2070 | ||
|
|
160ad2dc79 | ||
|
|
0ec86c2c71 | ||
|
|
03452ffd9f | ||
|
|
da6317a242 | ||
|
|
8b8e616557 | ||
|
|
d260f1a1a6 | ||
|
|
9d452e3b04 | ||
|
|
e012189672 | ||
|
|
4c31e9a8b1 | ||
|
|
7cfc230316 | ||
|
|
9605e85f1c | ||
|
|
498e2b772c | ||
|
|
dad897da51 | ||
|
|
02ad5f062e | ||
|
|
4eb9471b4f | ||
|
|
b505d207d7 | ||
|
|
3c954bd07f | ||
|
|
c00b6459dc | ||
|
|
eb4d776784 | ||
|
|
5d7a890533 | ||
|
|
9c6aefef1e | ||
|
|
e4554d6c09 | ||
|
|
c184b63df8 | ||
|
|
6bb4195393 | ||
|
|
7827a4d40d | ||
|
|
f09fa8231a | ||
|
|
96ff10000d | ||
|
|
9460636867 | ||
|
|
6c43245295 | ||
|
|
266b6cf638 | ||
|
|
70183e234a | ||
|
|
17b9c359ca | ||
|
|
045630b8a5 | ||
|
|
55ff7dd640 | ||
|
|
e6d64f71f2 | ||
|
|
e72313ebdd | ||
|
|
65d5bd72cd | ||
|
|
dc0cbb41f0 | ||
|
|
c4a54a85be | ||
|
|
5b2738aec9 | ||
|
|
892312fc08 | ||
|
|
c2ccf2c72c | ||
|
|
80aaecb5f0 | ||
|
|
946865a335 | ||
|
|
5de15c8413 | ||
|
|
67268fd35a | ||
|
|
42fc771833 | ||
|
|
444b1a0b65 | ||
|
|
814ea1c016 | ||
|
|
4d34dc4234 | ||
|
|
d567399f2b | ||
|
|
77f4f8d8b0 | ||
|
|
a2d04beaa1 | ||
|
|
ba49eea23d | ||
|
|
82beafc086 | ||
|
|
7d8ed2d102 | ||
|
|
aab8d3a4f1 | ||
|
|
76658d50a0 | ||
|
|
88ba22342c | ||
|
|
11a1460af9 | ||
|
|
2cd4c41316 | ||
|
|
b910f308f2 | ||
|
|
763aa73ea4 | ||
|
|
30c79e92d4 | ||
|
|
402d5e054b | ||
|
|
0e211df206 | ||
|
|
e24a0ac686 | ||
|
|
8c91b1c527 | ||
|
|
2b38f80d04 | ||
|
|
282bd35f52 | ||
|
|
cc9b4c2bcb | ||
|
|
068ce4970a | ||
|
|
cf19165ad8 | ||
|
|
68c479f3a5 | ||
|
|
ba496a772b | ||
|
|
3b27db36f2 | ||
|
|
f803def69b | ||
|
|
52065e69a4 | ||
|
|
50f5e8a955 | ||
|
|
2d0e97b66d | ||
|
|
5f3cc5a392 | ||
|
|
ac66d77512 | ||
|
|
50cf653d4a | ||
|
|
56256051d2 | ||
|
|
c0361ff03d | ||
|
|
f153435c08 | ||
|
|
9aa7f22fa6 | ||
|
|
52b7bda5f8 | ||
|
|
21aefa2778 | ||
|
|
a89ff71c9e | ||
|
|
4c275816be | ||
|
|
f8dfbcfc80 | ||
|
|
d317f6473d | ||
|
|
00b4e133d4 | ||
|
|
b6349e4efb | ||
|
|
6ca3d9585c | ||
|
|
5935a0283a | ||
|
|
5400a6ec06 | ||
|
|
6574d9cc84 | ||
|
|
42b83c5994 | ||
|
|
896612a5a3 | ||
|
|
0ee875bee4 | ||
|
|
8ce345cd94 | ||
|
|
da2f8477e6 | ||
|
|
82b47b5673 | ||
|
|
7c15a4c7ff | ||
|
|
3369b910b4 | ||
|
|
ec0c4c3b84 | ||
|
|
f74e2c9da1 | ||
|
|
e26ad3c475 | ||
|
|
145c3b8ad0 | ||
|
|
0ff6c6a154 | ||
|
|
641cf5a4c1 | ||
|
|
09b9576eef | ||
|
|
18b71ca2f2 | ||
|
|
e0eb7f456e | ||
|
|
188d118fc0 | ||
|
|
adcdce8d76 | ||
|
|
b865a7aec1 | ||
|
|
cec8c72b46 | ||
|
|
b052e32805 | ||
|
|
816f660be3 | ||
|
|
fc8be45d5a | ||
|
|
e749c936c9 | ||
|
|
b2b9670a23 | ||
|
|
2f88890c94 | ||
|
|
6366663f03 | ||
|
|
20fe7dc6d1 | ||
|
|
4b9153069e | ||
|
|
80406d0753 | ||
|
|
35f4c11784 | ||
|
|
7896526f19 | ||
|
|
f7db22edff | ||
|
|
0e4196f036 | ||
|
|
1bf6af6eeb | ||
|
|
5a9bc6d2bf | ||
|
|
f7f6042579 | ||
|
|
c4a598f3d3 | ||
|
|
7c23f43c63 | ||
|
|
7e2cbdd88c | ||
|
|
3b3a04a249 | ||
|
|
f9b2c95695 | ||
|
|
c2c18e8319 | ||
|
|
384ad3e0ac | ||
|
|
8c986aaa7f | ||
|
|
bb4ea76d30 | ||
|
|
2868e47cf8 | ||
|
|
e0adc3e5d5 | ||
|
|
e55d1a5865 | ||
|
|
018273c6b2 | ||
|
|
44b8a11c04 | ||
|
|
56e5aba559 | ||
|
|
46904ccd54 | ||
|
|
5b7c7a4471 | ||
|
|
9da4215d1f | ||
|
|
f39ac9945f | ||
|
|
a0cc2e4d46 | ||
|
|
4065041a9f | ||
|
|
f08067a161 | ||
|
|
545caacfa3 | ||
|
|
a06f646637 | ||
|
|
578c68205a | ||
|
|
f09f1433a9 | ||
|
|
15a9e97a1e | ||
|
|
b3af4ee50b | ||
|
|
07d59b6640 | ||
|
|
e25b988dc8 | ||
|
|
2410bd8654 | ||
|
|
44d21ab703 | ||
|
|
e283957c8f | ||
|
|
b1210c4902 | ||
|
|
e7430f0fbc | ||
|
|
92d6ae54c3 | ||
|
|
f82be23ca9 | ||
|
|
8c3f75e3e2 | ||
|
|
193d59f193 | ||
|
|
c2bebbaefa | ||
|
|
7ae5a9c5a5 | ||
|
|
3b69bea23d | ||
|
|
ab05726b99 | ||
|
|
b2b04268e9 | ||
|
|
bd73fa9ae7 | ||
|
|
927d10d66e | ||
|
|
b67329623c | ||
|
|
6f47aa802b | ||
|
|
3417c73011 | ||
|
|
6a02bcf15b | ||
|
|
cd0fbf79a3 | ||
|
|
15d2d0115b | ||
|
|
d1a0fe6e91 | ||
|
|
1db80d140f | ||
|
|
896dcf1f9e | ||
|
|
819a12fb49 | ||
|
|
c68273706c | ||
|
|
6bb0cd535a | ||
|
|
cb9ec69cf6 | ||
|
|
143854fa81 | ||
|
|
2f48a3d7d5 | ||
|
|
ec95dafe1e | ||
|
|
3d1fe724e5 | ||
|
|
5c615d6f2d | ||
|
|
d72558eb36 | ||
|
|
65c33ad915 | ||
|
|
9be128a963 | ||
|
|
eb05132008 | ||
|
|
f94a093e8c | ||
|
|
0d0c2daf64 | ||
|
|
823d948b25 | ||
|
|
56831fbcf2 | ||
|
|
bf49b9cb88 | ||
|
|
e01adffbad | ||
|
|
08a5d52d82 | ||
|
|
fdae235742 | ||
|
|
9903fad1e9 | ||
|
|
14bbd5338d | ||
|
|
4a236c2f6f | ||
|
|
0a8cdbd7f1 | ||
|
|
94c49843be | ||
|
|
9281fac898 | ||
|
|
0b2736f454 | ||
|
|
ae116b0d0d | ||
|
|
ba260e3382 | ||
|
|
1282e7687f | ||
|
|
b1d8266eef | ||
|
|
7acae6935b | ||
|
|
092c01cae7 | ||
|
|
56a1066c30 | ||
|
|
1356d71839 | ||
|
|
1eb011e8c3 | ||
|
|
e349eb28b0 | ||
|
|
b000b235a2 | ||
|
|
16fe92282e | ||
|
|
e218e88cf4 | ||
|
|
888ea81a32 | ||
|
|
735fab7640 | ||
|
|
45745c2a47 | ||
|
|
4caff0fcf6 | ||
|
|
762ea6ce7f | ||
|
|
8b4f6553f3 | ||
|
|
a61e44d175 | ||
|
|
e1b1558fc9 | ||
|
|
53225bda4e | ||
|
|
5212769848 | ||
|
|
d5ded3c9f4 | ||
|
|
c92d778894 | ||
|
|
829abd1ad6 | ||
|
|
266d256a07 | ||
|
|
8380cac3e7 | ||
|
|
a24652f901 | ||
|
|
2d203d3c70 | ||
|
|
48d21600da | ||
|
|
2508d0fbb3 | ||
|
|
e90e80c289 | ||
|
|
5e4748f9d9 | ||
|
|
212952f3e9 | ||
|
|
f99b6496c5 | ||
|
|
67423d51b9 | ||
|
|
58465ece65 | ||
|
|
8ede3a0173 | ||
|
|
ad2f0f8950 | ||
|
|
76973a4b4c | ||
|
|
b198e2e029 | ||
|
|
4d6ea401b5 | ||
|
|
b00c4cc3b6 | ||
|
|
4185e64c65 | ||
|
|
6eb2c884a2 | ||
|
|
6c0362a4cf | ||
|
|
50b1755a63 | ||
|
|
ff3c7eb5fb | ||
|
|
3755316d49 | ||
|
|
f952046847 | ||
|
|
969cdb4a63 | ||
|
|
f336d44595 | ||
|
|
a53f93c195 | ||
|
|
fcb334ce33 | ||
|
|
8ddf04a904 | ||
|
|
29698ca169 | ||
|
|
a9baf7436a | ||
|
|
99a8962183 | ||
|
|
afc5b15a6b | ||
|
|
b6ab508e27 | ||
|
|
789e65557a | ||
|
|
8a7806ab2d | ||
|
|
493303e103 | ||
|
|
1d9af05e9e | ||
|
|
5b07c5f2e8 | ||
|
|
2a4ec0cf5b | ||
|
|
a00c44386e | ||
|
|
a38d71bbfb | ||
|
|
a24a3f868c | ||
|
|
f60c516185 | ||
|
|
26f4646304 | ||
|
|
3a351f67e6 | ||
|
|
e7c09cb91e | ||
|
|
ae1a6ef303 | ||
|
|
2ff477a339 | ||
|
|
793f3fb683 | ||
|
|
a472ee7602 | ||
|
|
c62040e232 | ||
|
|
2e7cb510ae | ||
|
|
dbe45904d7 | ||
|
|
5623734276 | ||
|
|
d3b592bffc | ||
|
|
4fcbdae5bf | ||
|
|
ca95d7275a | ||
|
|
61baf3701c | ||
|
|
bbce872ac5 | ||
|
|
0f7ebcd8e4 | ||
|
|
82fc19e7b7 | ||
|
|
839a12bed4 | ||
|
|
2ef23fe1b3 | ||
|
|
fd905b1a06 | ||
|
|
1372210004 | ||
|
|
ade704d065 | ||
|
|
42f48649b9 | ||
|
|
0b08e8b617 | ||
|
|
926b2f1a1b | ||
|
|
1770a1a45f | ||
|
|
50ed2a64c6 | ||
|
|
2332344988 | ||
|
|
7ccc8cdc58 | ||
|
|
ecec9f913e | ||
|
|
777f40fc5e | ||
|
|
327ae35420 | ||
|
|
0d48159da8 | ||
|
|
d36f12a4ea | ||
|
|
709488beb1 | ||
|
|
a9e4583695 | ||
|
|
4702dec933 | ||
|
|
e6352dd691 | ||
|
|
240ea3b857 | ||
|
|
f0908af3c0 | ||
|
|
6834961dd1 | ||
|
|
b404162364 | ||
|
|
e879ef805f | ||
|
|
7077ca5e98 | ||
|
|
a1e6978c8f | ||
|
|
584391dd59 | ||
|
|
bab3ae809c | ||
|
|
c78518baf0 | ||
|
|
556d7e0497 | ||
|
|
2d27936dab | ||
|
|
0cc22de545 | ||
|
|
63f6127049 | ||
|
|
f34e00c986 | ||
|
|
55f60a9fe1 | ||
|
|
7da3618e0c | ||
|
|
56bfa98633 | ||
|
|
96f6188722 | ||
|
|
aa9d359039 | ||
|
|
cef5731028 | ||
|
|
5bc28bd4fd | ||
|
|
55a1d867c3 | ||
|
|
6c3a79802e | ||
|
|
c35c5e0793 | ||
|
|
7bc83caa99 | ||
|
|
3aceca63c6 | ||
|
|
9bc166ffd4 | ||
|
|
fc01b90007 | ||
|
|
e35f1d70e4 | ||
|
|
cab1f3787a | ||
|
|
bb42f4cbc1 | ||
|
|
98dc418a51 | ||
|
|
322b4eb18c | ||
|
|
7f1cc30ed8 | ||
|
|
7b45a6b956 | ||
|
|
e36769e70f | ||
|
|
bd4a4cc4af | ||
|
|
8343fe63cb | ||
|
|
7d89fb8461 | ||
|
|
098955d230 | ||
|
|
d254d14928 | ||
|
|
0a3e8ca535 | ||
|
|
b8a10e0962 | ||
|
|
0aceda96e4 | ||
|
|
44b6ec25a2 | ||
|
|
1b84d1fa9d | ||
|
|
78d5ed2ed2 | ||
|
|
142477ab9b | ||
|
|
b414f79bc5 | ||
|
|
6e08fe21d0 | ||
|
|
9b839655a7 | ||
|
|
3353c0ee1d | ||
|
|
aaecf52c99 | ||
|
|
8b3e960be0 | ||
|
|
3351f71813 | ||
|
|
7490256303 | ||
|
|
041d600e45 | ||
|
|
b4e2588a24 | ||
|
|
68dc14c5a1 | ||
|
|
ef35864e16 | ||
|
|
c0d385b983 | ||
|
|
b2df431fa4 | ||
|
|
69a4bd415a | ||
|
|
4862548e65 | ||
|
|
50248cc9ea | ||
|
|
430822bae3 | ||
|
|
dd9d18208d | ||
|
|
e5b1a71659 | ||
|
|
35f4b13237 | ||
|
|
5f5c31cd5b | ||
|
|
e9530d5ec5 | ||
|
|
143f4aa886 | ||
|
|
ece5c8bb31 | ||
|
|
31baf181a3 | ||
|
|
3bae30c70c | ||
|
|
12b18c6bd1 | ||
|
|
787d9e3bf5 | ||
|
|
f325b54895 | ||
|
|
c5616705b0 | ||
|
|
c0f693d35d | ||
|
|
52a5f132c1 | ||
|
|
f14eac6d10 | ||
|
|
e90fe117ec | ||
|
|
381d737d24 | ||
|
|
7cab5b3b09 | ||
|
|
9f911cb5cb | ||
|
|
3da7cba06c | ||
|
|
b47af9600f | ||
|
|
92c3c707e1 | ||
|
|
5acc54e609 | ||
|
|
9c6352dd5b | ||
|
|
8e29a07df5 | ||
|
|
bd88cd3a06 | ||
|
|
f371b9702f | ||
|
|
3ff4ae29af | ||
|
|
eae0f2e7a9 | ||
|
|
305a98bb79 | ||
|
|
8040a3ed60 | ||
|
|
bb9de7d9b0 | ||
|
|
d8e8bc0068 | ||
|
|
6577e9d852 | ||
|
|
3f8625c65a | ||
|
|
92d69636a7 | ||
|
|
9c28817fba | ||
|
|
773788fb32 | ||
|
|
a393ad8e04 | ||
|
|
71d3714347 | ||
|
|
b7e1329c13 | ||
|
|
59e6d9d10e | ||
|
|
46efb446fb | ||
|
|
d31e3a54fd | ||
|
|
c4e471ac47 | ||
|
|
3b8733e085 | ||
|
|
a7c67d83ca | ||
|
|
8abc1de26d | ||
|
|
2ca9f708a6 | ||
|
|
f8f369fbb2 | ||
|
|
3e9155767b | ||
|
|
8cd4195657 | ||
|
|
ad1a944276 | ||
|
|
02ff4c5657 | ||
|
|
b1b27f2dde | ||
|
|
5097f77469 | ||
|
|
7e826d5002 | ||
|
|
fe8143a56c | ||
|
|
e5442a713a | ||
|
|
1982a46f36 | ||
|
|
c8c3640baf | ||
|
|
fdf47b3f2c | ||
|
|
93fa4b6a37 | ||
|
|
90e9ab70b0 | ||
|
|
573c2386b7 | ||
|
|
d2176aeeb9 | ||
|
|
920aec5c3e | ||
|
|
b792c5459a | ||
|
|
87fbf05fa1 | ||
|
|
67c53250c5 | ||
|
|
d657eea910 | ||
|
|
b5fbb825ed | ||
|
|
d094e7a4c6 | ||
|
|
945c155b17 | ||
|
|
f798072a1e | ||
|
|
f967214b57 | ||
|
|
d0b92e2540 | ||
|
|
8ddfe272bf | ||
|
|
b7a6bad7cd | ||
|
|
e2f6c04406 | ||
|
|
c662725955 | ||
|
|
4b66ddfdef | ||
|
|
2d55b1f592 | ||
|
|
14adfabf7e | ||
|
|
e7a76ede76 | ||
|
|
de47df3bf9 | ||
|
|
5475e6f7c5 | ||
|
|
8e3f3d74d4 | ||
|
|
046f6c66ed | ||
|
|
79f9d6552e | ||
|
|
56b4b63749 | ||
|
|
b3246a48c7 | ||
|
|
71722ef6a3 | ||
|
|
ebf8f00302 | ||
|
|
7445928c7e | ||
|
|
5ab7602f2f | ||
|
|
a340aff63a | ||
|
|
f82042ff00 | ||
|
|
920422e28c | ||
|
|
50d6b7a6f8 | ||
|
|
41d624a36a | ||
|
|
f42c37c82e | ||
|
|
119fcdf6f6 | ||
|
|
a5b093d1a9 | ||
|
|
e07cb44a3e | ||
|
|
fec1bcfd5c | ||
|
|
dbcf658343 | ||
|
|
d89e78c9ca | ||
|
|
ec50650dfa | ||
|
|
7432e551f9 | ||
|
|
4ee6bd44d1 | ||
|
|
26f819098d | ||
|
|
a1c79f93d7 | ||
|
|
9c1b202d74 | ||
|
|
8ad0f59f19 | ||
|
|
af40a77d24 | ||
|
|
d80b7017cf | ||
|
|
56793c8db7 | ||
|
|
8edb217943 | ||
|
|
23ebcf1065 | ||
|
|
68a5a3d62a | ||
|
|
8d7236b0db | ||
|
|
96c7daf818 | ||
|
|
fc4942e189 | ||
|
|
ca69d025bd | ||
|
|
ffa428e32a | ||
|
|
c24e90eaae | ||
|
|
ab32eff588 | ||
|
|
7f592f2b35 | ||
|
|
130ece7bc0 | ||
|
|
b2582796a2 | ||
|
|
cd556d5d43 | ||
|
|
2855283a2c | ||
|
|
06c29500f2 | ||
|
|
81104153a6 | ||
|
|
e1e608b744 | ||
|
|
ea9ab5b27c | ||
|
|
357ced6cba | ||
|
|
3ffda69651 | ||
|
|
e1bf4e0762 | ||
|
|
17e4fad6fb | ||
|
|
8552e81022 | ||
|
|
eacdde829f | ||
|
|
d873539856 |
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# Auto detect text files and perform LF normalization
|
||||||
|
* text=auto
|
||||||
11
.github/styles/DocsGPT/Spelling.yml
vendored
Normal file
11
.github/styles/DocsGPT/Spelling.yml
vendored
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
extends: spelling
|
||||||
|
level: warning
|
||||||
|
message: "Did you really mean '%s'?"
|
||||||
|
ignore:
|
||||||
|
- "**/node_modules/**"
|
||||||
|
- "**/dist/**"
|
||||||
|
- "**/build/**"
|
||||||
|
- "**/coverage/**"
|
||||||
|
- "**/public/**"
|
||||||
|
- "**/static/**"
|
||||||
|
vocab: DocsGPT
|
||||||
46
.github/styles/config/vocabularies/DocsGPT/accept.txt
vendored
Normal file
46
.github/styles/config/vocabularies/DocsGPT/accept.txt
vendored
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
Ollama
|
||||||
|
Qdrant
|
||||||
|
Milvus
|
||||||
|
Chatwoot
|
||||||
|
Nextra
|
||||||
|
VSCode
|
||||||
|
npm
|
||||||
|
LLMs
|
||||||
|
APIs
|
||||||
|
Groq
|
||||||
|
SGLang
|
||||||
|
LMDeploy
|
||||||
|
OAuth
|
||||||
|
Vite
|
||||||
|
LLM
|
||||||
|
JSONPath
|
||||||
|
UIs
|
||||||
|
configs
|
||||||
|
uncomment
|
||||||
|
qdrant
|
||||||
|
vectorstore
|
||||||
|
docsgpt
|
||||||
|
llm
|
||||||
|
GPUs
|
||||||
|
kubectl
|
||||||
|
Lightsail
|
||||||
|
enqueues
|
||||||
|
chatbot
|
||||||
|
VSCode's
|
||||||
|
Shareability
|
||||||
|
feedbacks
|
||||||
|
automations
|
||||||
|
Premade
|
||||||
|
Signup
|
||||||
|
Repo
|
||||||
|
repo
|
||||||
|
env
|
||||||
|
URl
|
||||||
|
agentic
|
||||||
|
llama_cpp
|
||||||
|
parsable
|
||||||
|
SDKs
|
||||||
|
boolean
|
||||||
|
bool
|
||||||
|
hardcode
|
||||||
|
EOL
|
||||||
6
.github/workflows/pytest.yml
vendored
6
.github/workflows/pytest.yml
vendored
@@ -16,15 +16,15 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install pytest pytest-cov
|
|
||||||
cd application
|
cd application
|
||||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||||
|
cd ../tests
|
||||||
|
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||||
- name: Test with pytest and generate coverage report
|
- name: Test with pytest and generate coverage report
|
||||||
run: |
|
run: |
|
||||||
python -m pytest --cov=application --cov-report=xml
|
python -m pytest --cov=application --cov-report=xml --cov-report=term-missing
|
||||||
- name: Upload coverage reports to Codecov
|
- name: Upload coverage reports to Codecov
|
||||||
if: github.event_name == 'pull_request' && matrix.python-version == '3.12'
|
if: github.event_name == 'pull_request' && matrix.python-version == '3.12'
|
||||||
uses: codecov/codecov-action@v5
|
uses: codecov/codecov-action@v5
|
||||||
env:
|
env:
|
||||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
|
||||||
|
|||||||
26
.github/workflows/vale.yml
vendored
Normal file
26
.github/workflows/vale.yml
vendored
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
name: Vale Documentation Linter
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- 'docs/**/*.md'
|
||||||
|
- 'docs/**/*.mdx'
|
||||||
|
- '**/*.md'
|
||||||
|
- '.vale.ini'
|
||||||
|
- '.github/styles/**'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
vale:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Vale linter
|
||||||
|
uses: errata-ai/vale-action@v2
|
||||||
|
with:
|
||||||
|
files: docs
|
||||||
|
fail_on_error: false
|
||||||
|
version: 3.0.5
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -3,6 +3,7 @@ __pycache__/
|
|||||||
*.py[cod]
|
*.py[cod]
|
||||||
*$py.class
|
*$py.class
|
||||||
|
|
||||||
|
experiments
|
||||||
# C extensions
|
# C extensions
|
||||||
*.so
|
*.so
|
||||||
*.next
|
*.next
|
||||||
|
|||||||
5
.vale.ini
Normal file
5
.vale.ini
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
MinAlertLevel = warning
|
||||||
|
StylesPath = .github/styles
|
||||||
|
|
||||||
|
[*.{md,mdx}]
|
||||||
|
BasedOnStyles = DocsGPT
|
||||||
33
.vscode/launch.json
vendored
33
.vscode/launch.json
vendored
@@ -2,15 +2,11 @@
|
|||||||
"version": "0.2.0",
|
"version": "0.2.0",
|
||||||
"configurations": [
|
"configurations": [
|
||||||
{
|
{
|
||||||
"name": "Docker Debug Frontend",
|
"name": "Frontend Debug (npm)",
|
||||||
|
"type": "node-terminal",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"type": "chrome",
|
"command": "npm run dev",
|
||||||
"preLaunchTask": "docker-compose: debug:frontend",
|
"cwd": "${workspaceFolder}/frontend"
|
||||||
"url": "http://127.0.0.1:5173",
|
|
||||||
"webRoot": "${workspaceFolder}/frontend",
|
|
||||||
"skipFiles": [
|
|
||||||
"<node_internals>/**"
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Flask Debugger",
|
"name": "Flask Debugger",
|
||||||
@@ -49,6 +45,27 @@
|
|||||||
"--pool=solo"
|
"--pool=solo"
|
||||||
],
|
],
|
||||||
"cwd": "${workspaceFolder}"
|
"cwd": "${workspaceFolder}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Dev Containers (Mongo + Redis)",
|
||||||
|
"type": "node-terminal",
|
||||||
|
"request": "launch",
|
||||||
|
"command": "docker compose -f deployment/docker-compose-dev.yaml up --build",
|
||||||
|
"cwd": "${workspaceFolder}"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"compounds": [
|
||||||
|
{
|
||||||
|
"name": "DocsGPT: Full Stack",
|
||||||
|
"configurations": [
|
||||||
|
"Frontend Debug (npm)",
|
||||||
|
"Flask Debugger",
|
||||||
|
"Celery Debugger"
|
||||||
|
],
|
||||||
|
"presentation": {
|
||||||
|
"group": "DocsGPT",
|
||||||
|
"order": 1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
21
.vscode/tasks.json
vendored
21
.vscode/tasks.json
vendored
@@ -1,21 +0,0 @@
|
|||||||
{
|
|
||||||
"version": "2.0.0",
|
|
||||||
"tasks": [
|
|
||||||
{
|
|
||||||
"type": "docker-compose",
|
|
||||||
"label": "docker-compose: debug:frontend",
|
|
||||||
"dockerCompose": {
|
|
||||||
"up": {
|
|
||||||
"detached": true,
|
|
||||||
"services": [
|
|
||||||
"frontend"
|
|
||||||
],
|
|
||||||
"build": true
|
|
||||||
},
|
|
||||||
"files": [
|
|
||||||
"${workspaceFolder}/docker-compose.yaml"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
39
HACKTOBERFEST.md
Normal file
39
HACKTOBERFEST.md
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# **🎉 Join the Hacktoberfest with DocsGPT and win a Free T-shirt for a meaningful PR! 🎉**
|
||||||
|
|
||||||
|
Welcome, contributors! We're excited to announce that DocsGPT is participating in Hacktoberfest. Get involved by submitting meaningful pull requests.
|
||||||
|
|
||||||
|
All Meaningful contributors with accepted PRs that were created for issues with the `hacktoberfest` label (set by our maintainer team: dartpain, siiddhantt, pabik, ManishMadan2882) will receive a cool T-shirt! 🤩.
|
||||||
|
<img width="1331" height="678" alt="hacktoberfest-mocks-preview" src="https://github.com/user-attachments/assets/633f6377-38db-48f5-b519-a8b3855a9eb4" />
|
||||||
|
|
||||||
|
Fill in [this form](https://forms.gle/Npaba4n9Epfyx56S8
|
||||||
|
) after your PR was merged please
|
||||||
|
|
||||||
|
If you are in doubt don't hesitate to ping us on discord, ping me - Alex (dartpain).
|
||||||
|
|
||||||
|
## 📜 Here's How to Contribute:
|
||||||
|
```text
|
||||||
|
🛠️ Code: This is the golden ticket! Make meaningful contributions through PRs.
|
||||||
|
|
||||||
|
🧩 API extension: Build an app utilising DocsGPT API. We prefer submissions that showcase original ideas and turn the API into an AI agent.
|
||||||
|
They can be a completely separate repos.
|
||||||
|
For example:
|
||||||
|
https://github.com/arc53/tg-bot-docsgpt-extenstion or
|
||||||
|
https://github.com/arc53/DocsGPT-cli
|
||||||
|
|
||||||
|
Non-Code Contributions:
|
||||||
|
|
||||||
|
📚 Wiki: Improve our documentation, create a guide.
|
||||||
|
|
||||||
|
🖥️ Design: Improve the UI/UX or design a new feature.
|
||||||
|
```
|
||||||
|
|
||||||
|
### 📝 Guidelines for Pull Requests:
|
||||||
|
- Familiarize yourself with the current contributions and our [Roadmap](https://github.com/orgs/arc53/projects/2).
|
||||||
|
- Before contributing check existing [issues](https://github.com/arc53/DocsGPT/issues) or [create](https://github.com/arc53/DocsGPT/issues/new/choose) an issue and wait to get assigned.
|
||||||
|
- Once you are finished with your contribution, please fill in this [form](https://forms.gle/Npaba4n9Epfyx56S8).
|
||||||
|
- Refer to the [Documentation](https://docs.docsgpt.cloud/).
|
||||||
|
- Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU) server. We're here to help newcomers, so don't hesitate to jump in! Join us [here](https://discord.gg/n5BX8dh8rU).
|
||||||
|
|
||||||
|
Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your contributions (not just simple typos) could earn you a stylish new t-shirt.
|
||||||
|
|
||||||
|
We will publish a t-shirt design later into the October.
|
||||||
41
README.md
41
README.md
@@ -3,11 +3,11 @@
|
|||||||
</h1>
|
</h1>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<strong>Open-Source RAG Assistant</strong>
|
<strong>Private AI for agents, assistants and enterprise search</strong>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<p align="left">
|
<p align="left">
|
||||||
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source genAI tool that helps users get reliable answers from any knowledge source, while avoiding hallucinations. It enables quick and reliable information retrieval, with tooling and agentic system capability built in.
|
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
@@ -17,15 +17,25 @@
|
|||||||
<a href="https://github.com/arc53/DocsGPT/blob/main/LICENSE"></a>
|
<a href="https://github.com/arc53/DocsGPT/blob/main/LICENSE"></a>
|
||||||
<a href="https://www.bestpractices.dev/projects/9907"><img src="https://www.bestpractices.dev/projects/9907/badge"></a>
|
<a href="https://www.bestpractices.dev/projects/9907"><img src="https://www.bestpractices.dev/projects/9907/badge"></a>
|
||||||
<a href="https://discord.gg/n5BX8dh8rU"></a>
|
<a href="https://discord.gg/n5BX8dh8rU"></a>
|
||||||
<a href="https://twitter.com/docsgptai"></a>
|
<a href="https://x.com/docsgptai"></a>
|
||||||
|
|
||||||
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
|
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a> • <a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a> • <a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
|
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a> • <a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a> • <a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
|
||||||
<br>
|
<br>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
<br>
|
||||||
|
🎃 <a href="https://github.com/arc53/DocsGPT/blob/main/HACKTOBERFEST.md"> Hacktoberfest Prizes, Rules & Q&A </a> 🎃
|
||||||
|
<br>
|
||||||
|
<br>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<br>
|
||||||
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
|
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
|
||||||
</div>
|
</div>
|
||||||
<h3 align="left">
|
<h3 align="left">
|
||||||
@@ -52,8 +62,14 @@
|
|||||||
- [x] Chatbots menu re-design to handle tools, agent types, and more (April 2025)
|
- [x] Chatbots menu re-design to handle tools, agent types, and more (April 2025)
|
||||||
- [x] New input box in the conversation menu (April 2025)
|
- [x] New input box in the conversation menu (April 2025)
|
||||||
- [x] Add triggerable actions / tools (webhook) (April 2025)
|
- [x] Add triggerable actions / tools (webhook) (April 2025)
|
||||||
- [ ] Anthropic Tool compatibility (May 2025)
|
- [x] Agent optimisations (May 2025)
|
||||||
- [ ] Add OAuth 2.0 authentication for tools and sources
|
- [x] Filesystem sources update (July 2025)
|
||||||
|
- [x] Json Responses (August 2025)
|
||||||
|
- [x] MCP support (August 2025)
|
||||||
|
- [x] Google Drive integration (September 2025)
|
||||||
|
- [x] Add OAuth 2.0 authentication for MCP (September 2025)
|
||||||
|
- [ ] SharePoint integration (October 2025)
|
||||||
|
- [ ] Deep Agents (October 2025)
|
||||||
- [ ] Agent scheduling
|
- [ ] Agent scheduling
|
||||||
|
|
||||||
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
|
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
|
||||||
@@ -68,11 +84,10 @@ We're eager to provide personalized assistance when deploying your DocsGPT to a
|
|||||||
|
|
||||||
## Join the Lighthouse Program 🌟
|
## Join the Lighthouse Program 🌟
|
||||||
|
|
||||||
Calling all developers and GenAI innovators! The **DocsGPT Lighthouse Program** connects technical leaders actively deploying or extending DocsGPT in real-world scenarios. Collaborate directly with our team to shape the roadmap, access priority support, and build enterprise-ready solutions with exclusive community insights.
|
Calling all developers and GenAI innovators! The **DocsGPT Lighthouse Program** connects technical leaders actively deploying or extending DocsGPT in real-world scenarios. Collaborate directly with our team to shape the roadmap, access priority support, and build enterprise-ready solutions with exclusive community insights.
|
||||||
|
|
||||||
[Learn More & Apply →](https://docs.google.com/forms/d/1KAADiJinUJ8EMQyfTXUIGyFbqINNClNR3jBNWq7DgTE)
|
[Learn More & Apply →](https://docs.google.com/forms/d/1KAADiJinUJ8EMQyfTXUIGyFbqINNClNR3jBNWq7DgTE)
|
||||||
|
|
||||||
|
|
||||||
## QuickStart
|
## QuickStart
|
||||||
|
|
||||||
> [!Note]
|
> [!Note]
|
||||||
@@ -103,7 +118,7 @@ A more detailed [Quickstart](https://docs.docsgpt.cloud/quickstart) is available
|
|||||||
PowerShell -ExecutionPolicy Bypass -File .\setup.ps1
|
PowerShell -ExecutionPolicy Bypass -File .\setup.ps1
|
||||||
```
|
```
|
||||||
|
|
||||||
Either script will guide you through setting up DocsGPT. Four options available: using the public API, running locally, connecting to a local inference engine, or using a cloud API provider. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
|
Either script will guide you through setting up DocsGPT. Five options available: using the public API, running locally, connecting to a local inference engine, using a cloud API provider, or build the docker image locally. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
|
||||||
|
|
||||||
**Navigate to http://localhost:5173/**
|
**Navigate to http://localhost:5173/**
|
||||||
|
|
||||||
@@ -112,6 +127,7 @@ To stop DocsGPT, open a terminal in the `DocsGPT` directory and run:
|
|||||||
```bash
|
```bash
|
||||||
docker compose -f deployment/docker-compose.yaml down
|
docker compose -f deployment/docker-compose.yaml down
|
||||||
```
|
```
|
||||||
|
|
||||||
(or use the specific `docker compose down` command shown after running the setup script).
|
(or use the specific `docker compose down` command shown after running the setup script).
|
||||||
|
|
||||||
> [!Note]
|
> [!Note]
|
||||||
@@ -139,7 +155,6 @@ Please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file for information abou
|
|||||||
|
|
||||||
We as members, contributors, and leaders, pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. Please refer to the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file for more information about contributing.
|
We as members, contributors, and leaders, pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. Please refer to the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file for more information about contributing.
|
||||||
|
|
||||||
|
|
||||||
## Many Thanks To Our Contributors⚡
|
## Many Thanks To Our Contributors⚡
|
||||||
|
|
||||||
<a href="https://github.com/arc53/DocsGPT/graphs/contributors" alt="View Contributors">
|
<a href="https://github.com/arc53/DocsGPT/graphs/contributors" alt="View Contributors">
|
||||||
|
|||||||
@@ -1,17 +1,19 @@
|
|||||||
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Generator, List, Optional
|
from typing import Dict, Generator, List, Optional
|
||||||
|
|
||||||
from application.agents.llm_handler import get_llm_handler
|
from bson.objectid import ObjectId
|
||||||
|
|
||||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||||
from application.agents.tools.tool_manager import ToolManager
|
from application.agents.tools.tool_manager import ToolManager
|
||||||
|
|
||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||||
from application.llm.llm_creator import LLMCreator
|
from application.llm.llm_creator import LLMCreator
|
||||||
from application.logging import build_stack_data, log_activity, LogContext
|
from application.logging import build_stack_data, log_activity, LogContext
|
||||||
from application.retriever.base import BaseRetriever
|
|
||||||
from application.core.settings import settings
|
logger = logging.getLogger(__name__)
|
||||||
from bson.objectid import ObjectId
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(ABC):
|
class BaseAgent(ABC):
|
||||||
@@ -24,8 +26,14 @@ class BaseAgent(ABC):
|
|||||||
user_api_key: Optional[str] = None,
|
user_api_key: Optional[str] = None,
|
||||||
prompt: str = "",
|
prompt: str = "",
|
||||||
chat_history: Optional[List[Dict]] = None,
|
chat_history: Optional[List[Dict]] = None,
|
||||||
|
retrieved_docs: Optional[List[Dict]] = None,
|
||||||
decoded_token: Optional[Dict] = None,
|
decoded_token: Optional[Dict] = None,
|
||||||
attachments: Optional[List[Dict]] = None,
|
attachments: Optional[List[Dict]] = None,
|
||||||
|
json_schema: Optional[Dict] = None,
|
||||||
|
limited_token_mode: Optional[bool] = False,
|
||||||
|
token_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["token_limit"],
|
||||||
|
limited_request_mode: Optional[bool] = False,
|
||||||
|
request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
|
||||||
):
|
):
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.llm_name = llm_name
|
self.llm_name = llm_name
|
||||||
@@ -34,7 +42,7 @@ class BaseAgent(ABC):
|
|||||||
self.user_api_key = user_api_key
|
self.user_api_key = user_api_key
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
self.decoded_token = decoded_token or {}
|
self.decoded_token = decoded_token or {}
|
||||||
self.user: str = decoded_token.get("sub")
|
self.user: str = self.decoded_token.get("sub")
|
||||||
self.tool_config: Dict = {}
|
self.tool_config: Dict = {}
|
||||||
self.tools: List[Dict] = []
|
self.tools: List[Dict] = []
|
||||||
self.tool_calls: List[Dict] = []
|
self.tool_calls: List[Dict] = []
|
||||||
@@ -45,18 +53,26 @@ class BaseAgent(ABC):
|
|||||||
user_api_key=user_api_key,
|
user_api_key=user_api_key,
|
||||||
decoded_token=decoded_token,
|
decoded_token=decoded_token,
|
||||||
)
|
)
|
||||||
self.llm_handler = get_llm_handler(llm_name)
|
self.retrieved_docs = retrieved_docs or []
|
||||||
|
self.llm_handler = LLMHandlerCreator.create_handler(
|
||||||
|
llm_name if llm_name else "default"
|
||||||
|
)
|
||||||
self.attachments = attachments or []
|
self.attachments = attachments or []
|
||||||
|
self.json_schema = json_schema
|
||||||
|
self.limited_token_mode = limited_token_mode
|
||||||
|
self.token_limit = token_limit
|
||||||
|
self.limited_request_mode = limited_request_mode
|
||||||
|
self.request_limit = request_limit
|
||||||
|
|
||||||
@log_activity()
|
@log_activity()
|
||||||
def gen(
|
def gen(
|
||||||
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
|
self, query: str, log_context: LogContext = None
|
||||||
) -> Generator[Dict, None, None]:
|
) -> Generator[Dict, None, None]:
|
||||||
yield from self._gen_inner(query, retriever, log_context)
|
yield from self._gen_inner(query, log_context)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _gen_inner(
|
def _gen_inner(
|
||||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
self, query: str, log_context: LogContext
|
||||||
) -> Generator[Dict, None, None]:
|
) -> Generator[Dict, None, None]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -87,8 +103,8 @@ class BaseAgent(ABC):
|
|||||||
user_tools_collection = db["user_tools"]
|
user_tools_collection = db["user_tools"]
|
||||||
user_tools = user_tools_collection.find({"user": user, "status": True})
|
user_tools = user_tools_collection.find({"user": user, "status": True})
|
||||||
user_tools = list(user_tools)
|
user_tools = list(user_tools)
|
||||||
tools_by_id = {str(tool["_id"]): tool for tool in user_tools}
|
|
||||||
return tools_by_id
|
return {str(i): tool for i, tool in enumerate(user_tools)}
|
||||||
|
|
||||||
def _build_tool_parameters(self, action):
|
def _build_tool_parameters(self, action):
|
||||||
params = {"type": "object", "properties": {}, "required": []}
|
params = {"type": "object", "properties": {}, "required": []}
|
||||||
@@ -132,6 +148,50 @@ class BaseAgent(ABC):
|
|||||||
parser = ToolActionParser(self.llm.__class__.__name__)
|
parser = ToolActionParser(self.llm.__class__.__name__)
|
||||||
tool_id, action_name, call_args = parser.parse_args(call)
|
tool_id, action_name, call_args = parser.parse_args(call)
|
||||||
|
|
||||||
|
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Check if parsing failed
|
||||||
|
|
||||||
|
if tool_id is None or action_name is None:
|
||||||
|
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
|
||||||
|
logger.error(error_message)
|
||||||
|
|
||||||
|
tool_call_data = {
|
||||||
|
"tool_name": "unknown",
|
||||||
|
"call_id": call_id,
|
||||||
|
"action_name": getattr(call, "name", "unknown"),
|
||||||
|
"arguments": call_args or {},
|
||||||
|
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
|
||||||
|
}
|
||||||
|
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||||
|
self.tool_calls.append(tool_call_data)
|
||||||
|
return "Failed to parse tool call.", call_id
|
||||||
|
# Check if tool_id exists in available tools
|
||||||
|
|
||||||
|
if tool_id not in tools_dict:
|
||||||
|
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
|
||||||
|
logger.error(error_message)
|
||||||
|
|
||||||
|
# Return error result
|
||||||
|
|
||||||
|
tool_call_data = {
|
||||||
|
"tool_name": "unknown",
|
||||||
|
"call_id": call_id,
|
||||||
|
"action_name": f"{action_name}_{tool_id}",
|
||||||
|
"arguments": call_args,
|
||||||
|
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
|
||||||
|
}
|
||||||
|
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||||
|
self.tool_calls.append(tool_call_data)
|
||||||
|
return f"Tool with ID {tool_id} not found.", call_id
|
||||||
|
tool_call_data = {
|
||||||
|
"tool_name": tools_dict[tool_id]["name"],
|
||||||
|
"call_id": call_id,
|
||||||
|
"action_name": f"{action_name}_{tool_id}",
|
||||||
|
"arguments": call_args,
|
||||||
|
}
|
||||||
|
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||||
|
|
||||||
tool_data = tools_dict[tool_id]
|
tool_data = tools_dict[tool_id]
|
||||||
action_data = (
|
action_data = (
|
||||||
tool_data["config"]["actions"][action_name]
|
tool_data["config"]["actions"][action_name]
|
||||||
@@ -163,18 +223,26 @@ class BaseAgent(ABC):
|
|||||||
):
|
):
|
||||||
target_dict[param] = value
|
target_dict[param] = value
|
||||||
tm = ToolManager(config={})
|
tm = ToolManager(config={})
|
||||||
|
|
||||||
|
# Prepare tool_config and add tool_id for memory tools
|
||||||
|
|
||||||
|
if tool_data["name"] == "api_tool":
|
||||||
|
tool_config = {
|
||||||
|
"url": tool_data["config"]["actions"][action_name]["url"],
|
||||||
|
"method": tool_data["config"]["actions"][action_name]["method"],
|
||||||
|
"headers": headers,
|
||||||
|
"query_params": query_params,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
|
||||||
|
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
|
||||||
|
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
|
||||||
|
|
||||||
|
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
|
||||||
tool = tm.load_tool(
|
tool = tm.load_tool(
|
||||||
tool_data["name"],
|
tool_data["name"],
|
||||||
tool_config=(
|
tool_config=tool_config,
|
||||||
{
|
user_id=self.user, # Pass user ID for MCP tools credential decryption
|
||||||
"url": tool_data["config"]["actions"][action_name]["url"],
|
|
||||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
|
||||||
"headers": headers,
|
|
||||||
"query_params": query_params,
|
|
||||||
}
|
|
||||||
if tool_data["name"] == "api_tool"
|
|
||||||
else tool_data["config"]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
if tool_data["name"] == "api_tool":
|
if tool_data["name"] == "api_tool":
|
||||||
print(
|
print(
|
||||||
@@ -184,33 +252,41 @@ class BaseAgent(ABC):
|
|||||||
else:
|
else:
|
||||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||||
result = tool.execute_action(action_name, **parameters)
|
result = tool.execute_action(action_name, **parameters)
|
||||||
call_id = getattr(call, "id", None)
|
tool_call_data["result"] = (
|
||||||
|
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
|
||||||
|
)
|
||||||
|
|
||||||
tool_call_data = {
|
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
|
||||||
"tool_name": tool_data["name"],
|
|
||||||
"call_id": call_id if call_id is not None else "None",
|
|
||||||
"action_name": f"{action_name}_{tool_id}",
|
|
||||||
"arguments": call_args,
|
|
||||||
"result": result,
|
|
||||||
}
|
|
||||||
self.tool_calls.append(tool_call_data)
|
self.tool_calls.append(tool_call_data)
|
||||||
|
|
||||||
return result, call_id
|
return result, call_id
|
||||||
|
|
||||||
|
def _get_truncated_tool_calls(self):
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
**tool_call,
|
||||||
|
"result": (
|
||||||
|
f"{str(tool_call['result'])[:50]}..."
|
||||||
|
if len(str(tool_call["result"])) > 50
|
||||||
|
else tool_call["result"]
|
||||||
|
),
|
||||||
|
"status": "completed",
|
||||||
|
}
|
||||||
|
for tool_call in self.tool_calls
|
||||||
|
]
|
||||||
|
|
||||||
def _build_messages(
|
def _build_messages(
|
||||||
self,
|
self,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
query: str,
|
query: str,
|
||||||
retrieved_data: List[Dict],
|
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
|
"""Build messages using pre-rendered system prompt"""
|
||||||
p_chat_combine = system_prompt.replace("{summaries}", docs_together)
|
messages = [{"role": "system", "content": system_prompt}]
|
||||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
|
||||||
|
|
||||||
for i in self.chat_history:
|
for i in self.chat_history:
|
||||||
if "prompt" in i and "response" in i:
|
if "prompt" in i and "response" in i:
|
||||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
messages.append({"role": "user", "content": i["prompt"]})
|
||||||
messages_combine.append({"role": "assistant", "content": i["response"]})
|
messages.append({"role": "assistant", "content": i["response"]})
|
||||||
if "tool_calls" in i:
|
if "tool_calls" in i:
|
||||||
for tool_call in i["tool_calls"]:
|
for tool_call in i["tool_calls"]:
|
||||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||||
@@ -230,31 +306,39 @@ class BaseAgent(ABC):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
messages_combine.append(
|
messages.append(
|
||||||
{"role": "assistant", "content": [function_call_dict]}
|
{"role": "assistant", "content": [function_call_dict]}
|
||||||
)
|
)
|
||||||
messages_combine.append(
|
messages.append(
|
||||||
{"role": "tool", "content": [function_response_dict]}
|
{"role": "tool", "content": [function_response_dict]}
|
||||||
)
|
)
|
||||||
messages_combine.append({"role": "user", "content": query})
|
messages.append({"role": "user", "content": query})
|
||||||
return messages_combine
|
return messages
|
||||||
|
|
||||||
def _retriever_search(
|
|
||||||
self,
|
|
||||||
retriever: BaseRetriever,
|
|
||||||
query: str,
|
|
||||||
log_context: Optional[LogContext] = None,
|
|
||||||
) -> List[Dict]:
|
|
||||||
retrieved_data = retriever.search(query)
|
|
||||||
if log_context:
|
|
||||||
data = build_stack_data(retriever, exclude_attributes=["llm"])
|
|
||||||
log_context.stacks.append({"component": "retriever", "data": data})
|
|
||||||
return retrieved_data
|
|
||||||
|
|
||||||
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
|
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
|
||||||
resp = self.llm.gen_stream(
|
gen_kwargs = {"model": self.gpt_model, "messages": messages}
|
||||||
model=self.gpt_model, messages=messages, tools=self.tools
|
|
||||||
)
|
if (
|
||||||
|
hasattr(self.llm, "_supports_tools")
|
||||||
|
and self.llm._supports_tools
|
||||||
|
and self.tools
|
||||||
|
):
|
||||||
|
gen_kwargs["tools"] = self.tools
|
||||||
|
if (
|
||||||
|
self.json_schema
|
||||||
|
and hasattr(self.llm, "_supports_structured_output")
|
||||||
|
and self.llm._supports_structured_output()
|
||||||
|
):
|
||||||
|
structured_format = self.llm.prepare_structured_output_format(
|
||||||
|
self.json_schema
|
||||||
|
)
|
||||||
|
if structured_format:
|
||||||
|
if self.llm_name == "openai":
|
||||||
|
gen_kwargs["response_format"] = structured_format
|
||||||
|
elif self.llm_name == "google":
|
||||||
|
gen_kwargs["response_schema"] = structured_format
|
||||||
|
resp = self.llm.gen_stream(**gen_kwargs)
|
||||||
|
|
||||||
if log_context:
|
if log_context:
|
||||||
data = build_stack_data(self.llm, exclude_attributes=["client"])
|
data = build_stack_data(self.llm, exclude_attributes=["client"])
|
||||||
log_context.stacks.append({"component": "llm", "data": data})
|
log_context.stacks.append({"component": "llm", "data": data})
|
||||||
@@ -268,10 +352,51 @@ class BaseAgent(ABC):
|
|||||||
log_context: Optional[LogContext] = None,
|
log_context: Optional[LogContext] = None,
|
||||||
attachments: Optional[List[Dict]] = None,
|
attachments: Optional[List[Dict]] = None,
|
||||||
):
|
):
|
||||||
resp = self.llm_handler.handle_response(
|
resp = self.llm_handler.process_message_flow(
|
||||||
self, resp, tools_dict, messages, attachments
|
self, resp, tools_dict, messages, attachments, True
|
||||||
)
|
)
|
||||||
if log_context:
|
if log_context:
|
||||||
data = build_stack_data(self.llm_handler, exclude_attributes=["tool_calls"])
|
data = build_stack_data(self.llm_handler, exclude_attributes=["tool_calls"])
|
||||||
log_context.stacks.append({"component": "llm_handler", "data": data})
|
log_context.stacks.append({"component": "llm_handler", "data": data})
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
def _handle_response(self, response, tools_dict, messages, log_context):
|
||||||
|
is_structured_output = (
|
||||||
|
self.json_schema is not None
|
||||||
|
and hasattr(self.llm, "_supports_structured_output")
|
||||||
|
and self.llm._supports_structured_output()
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(response, str):
|
||||||
|
answer_data = {"answer": response}
|
||||||
|
if is_structured_output:
|
||||||
|
answer_data["structured"] = True
|
||||||
|
answer_data["schema"] = self.json_schema
|
||||||
|
yield answer_data
|
||||||
|
return
|
||||||
|
if hasattr(response, "message") and getattr(response.message, "content", None):
|
||||||
|
answer_data = {"answer": response.message.content}
|
||||||
|
if is_structured_output:
|
||||||
|
answer_data["structured"] = True
|
||||||
|
answer_data["schema"] = self.json_schema
|
||||||
|
yield answer_data
|
||||||
|
return
|
||||||
|
processed_response_gen = self._llm_handler(
|
||||||
|
response, tools_dict, messages, log_context, self.attachments
|
||||||
|
)
|
||||||
|
|
||||||
|
for event in processed_response_gen:
|
||||||
|
if isinstance(event, str):
|
||||||
|
answer_data = {"answer": event}
|
||||||
|
if is_structured_output:
|
||||||
|
answer_data["structured"] = True
|
||||||
|
answer_data["schema"] = self.json_schema
|
||||||
|
yield answer_data
|
||||||
|
elif hasattr(event, "message") and getattr(event.message, "content", None):
|
||||||
|
answer_data = {"answer": event.message.content}
|
||||||
|
if is_structured_output:
|
||||||
|
answer_data["structured"] = True
|
||||||
|
answer_data["schema"] = self.json_schema
|
||||||
|
yield answer_data
|
||||||
|
elif isinstance(event, dict) and "type" in event:
|
||||||
|
yield event
|
||||||
|
|||||||
@@ -1,64 +1,37 @@
|
|||||||
|
import logging
|
||||||
from typing import Dict, Generator
|
from typing import Dict, Generator
|
||||||
|
|
||||||
from application.agents.base import BaseAgent
|
from application.agents.base import BaseAgent
|
||||||
from application.logging import LogContext
|
from application.logging import LogContext
|
||||||
|
|
||||||
from application.retriever.base import BaseRetriever
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ClassicAgent(BaseAgent):
|
class ClassicAgent(BaseAgent):
|
||||||
|
"""A simplified agent with clear execution flow"""
|
||||||
|
|
||||||
def _gen_inner(
|
def _gen_inner(
|
||||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
self, query: str, log_context: LogContext
|
||||||
) -> Generator[Dict, None, None]:
|
) -> Generator[Dict, None, None]:
|
||||||
retrieved_data = self._retriever_search(retriever, query, log_context)
|
"""Core generator function for ClassicAgent execution flow"""
|
||||||
if self.user_api_key:
|
|
||||||
tools_dict = self._get_tools(self.user_api_key)
|
tools_dict = (
|
||||||
else:
|
self._get_user_tools(self.user)
|
||||||
tools_dict = self._get_user_tools(self.user)
|
if not self.user_api_key
|
||||||
|
else self._get_tools(self.user_api_key)
|
||||||
|
)
|
||||||
self._prepare_tools(tools_dict)
|
self._prepare_tools(tools_dict)
|
||||||
|
|
||||||
messages = self._build_messages(self.prompt, query, retrieved_data)
|
messages = self._build_messages(self.prompt, query)
|
||||||
|
llm_response = self._llm_gen(messages, log_context)
|
||||||
|
|
||||||
resp = self._llm_gen(messages, log_context)
|
yield from self._handle_response(
|
||||||
|
llm_response, tools_dict, messages, log_context
|
||||||
|
)
|
||||||
|
|
||||||
attachments = self.attachments
|
yield {"sources": self.retrieved_docs}
|
||||||
|
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||||
if isinstance(resp, str):
|
|
||||||
yield {"answer": resp}
|
|
||||||
return
|
|
||||||
if (
|
|
||||||
hasattr(resp, "message")
|
|
||||||
and hasattr(resp.message, "content")
|
|
||||||
and resp.message.content is not None
|
|
||||||
):
|
|
||||||
yield {"answer": resp.message.content}
|
|
||||||
return
|
|
||||||
|
|
||||||
resp = self._llm_handler(resp, tools_dict, messages, log_context, attachments)
|
|
||||||
|
|
||||||
if isinstance(resp, str):
|
|
||||||
yield {"answer": resp}
|
|
||||||
elif (
|
|
||||||
hasattr(resp, "message")
|
|
||||||
and hasattr(resp.message, "content")
|
|
||||||
and resp.message.content is not None
|
|
||||||
):
|
|
||||||
yield {"answer": resp.message.content}
|
|
||||||
else:
|
|
||||||
for line in resp:
|
|
||||||
if isinstance(line, str):
|
|
||||||
yield {"answer": line}
|
|
||||||
|
|
||||||
log_context.stacks.append(
|
log_context.stacks.append(
|
||||||
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
|
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||||
)
|
)
|
||||||
|
|
||||||
yield {"sources": retrieved_data}
|
|
||||||
# clean tool_call_data only send first 50 characters of tool_call['result']
|
|
||||||
for tool_call in self.tool_calls:
|
|
||||||
if len(str(tool_call["result"])) > 50:
|
|
||||||
tool_call["result"] = str(tool_call["result"])[:50] + "..."
|
|
||||||
yield {"tool_calls": self.tool_calls.copy()}
|
|
||||||
|
|||||||
@@ -1,351 +0,0 @@
|
|||||||
import json
|
|
||||||
import logging
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
from application.logging import build_stack_data
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LLMHandler(ABC):
|
|
||||||
def __init__(self):
|
|
||||||
self.llm_calls = []
|
|
||||||
self.tool_calls = []
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def prepare_messages_with_attachments(self, agent, messages, attachments=None):
|
|
||||||
"""
|
|
||||||
Prepare messages with attachment content if available.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent: The current agent instance.
|
|
||||||
messages (list): List of message dictionaries.
|
|
||||||
attachments (list): List of attachment dictionaries with content.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: Messages with attachment context added to the system prompt.
|
|
||||||
"""
|
|
||||||
if not attachments:
|
|
||||||
return messages
|
|
||||||
|
|
||||||
logger.info(f"Preparing messages with {len(attachments)} attachments")
|
|
||||||
|
|
||||||
supported_types = agent.llm.get_supported_attachment_types()
|
|
||||||
|
|
||||||
supported_attachments = []
|
|
||||||
unsupported_attachments = []
|
|
||||||
|
|
||||||
for attachment in attachments:
|
|
||||||
mime_type = attachment.get('mime_type')
|
|
||||||
if mime_type in supported_types:
|
|
||||||
supported_attachments.append(attachment)
|
|
||||||
else:
|
|
||||||
unsupported_attachments.append(attachment)
|
|
||||||
|
|
||||||
# Process supported attachments with the LLM's custom method
|
|
||||||
prepared_messages = messages
|
|
||||||
if supported_attachments:
|
|
||||||
logger.info(f"Processing {len(supported_attachments)} supported attachments with {agent.llm.__class__.__name__}'s method")
|
|
||||||
prepared_messages = agent.llm.prepare_messages_with_attachments(messages, supported_attachments)
|
|
||||||
|
|
||||||
# Process unsupported attachments with the default method
|
|
||||||
if unsupported_attachments:
|
|
||||||
logger.info(f"Processing {len(unsupported_attachments)} unsupported attachments with default method")
|
|
||||||
prepared_messages = self._append_attachment_content_to_system(prepared_messages, unsupported_attachments)
|
|
||||||
|
|
||||||
return prepared_messages
|
|
||||||
|
|
||||||
def _append_attachment_content_to_system(self, messages, attachments):
|
|
||||||
"""
|
|
||||||
Default method to append attachment content to the system prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages (list): List of message dictionaries.
|
|
||||||
attachments (list): List of attachment dictionaries with content.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: Messages with attachment context added to the system prompt.
|
|
||||||
"""
|
|
||||||
prepared_messages = messages.copy()
|
|
||||||
|
|
||||||
attachment_texts = []
|
|
||||||
for attachment in attachments:
|
|
||||||
logger.info(f"Adding attachment {attachment.get('id')} to context")
|
|
||||||
if 'content' in attachment:
|
|
||||||
attachment_texts.append(f"Attached file content:\n\n{attachment['content']}")
|
|
||||||
|
|
||||||
if attachment_texts:
|
|
||||||
combined_attachment_text = "\n\n".join(attachment_texts)
|
|
||||||
|
|
||||||
system_found = False
|
|
||||||
for i in range(len(prepared_messages)):
|
|
||||||
if prepared_messages[i].get("role") == "system":
|
|
||||||
prepared_messages[i]["content"] += f"\n\n{combined_attachment_text}"
|
|
||||||
system_found = True
|
|
||||||
break
|
|
||||||
|
|
||||||
if not system_found:
|
|
||||||
prepared_messages.insert(0, {"role": "system", "content": combined_attachment_text})
|
|
||||||
|
|
||||||
return prepared_messages
|
|
||||||
|
|
||||||
class OpenAILLMHandler(LLMHandler):
|
|
||||||
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True):
|
|
||||||
|
|
||||||
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
|
|
||||||
logger.info(f"Messages with attachments: {messages}")
|
|
||||||
if not stream:
|
|
||||||
while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls":
|
|
||||||
message = json.loads(resp.model_dump_json())["message"]
|
|
||||||
keys_to_remove = {"audio", "function_call", "refusal"}
|
|
||||||
filtered_data = {
|
|
||||||
k: v for k, v in message.items() if k not in keys_to_remove
|
|
||||||
}
|
|
||||||
messages.append(filtered_data)
|
|
||||||
|
|
||||||
tool_calls = resp.message.tool_calls
|
|
||||||
for call in tool_calls:
|
|
||||||
try:
|
|
||||||
self.tool_calls.append(call)
|
|
||||||
tool_response, call_id = agent._execute_tool_action(
|
|
||||||
tools_dict, call
|
|
||||||
)
|
|
||||||
function_call_dict = {
|
|
||||||
"function_call": {
|
|
||||||
"name": call.function.name,
|
|
||||||
"args": call.function.arguments,
|
|
||||||
"call_id": call_id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
function_response_dict = {
|
|
||||||
"function_response": {
|
|
||||||
"name": call.function.name,
|
|
||||||
"response": {"result": tool_response},
|
|
||||||
"call_id": call_id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
messages.append(
|
|
||||||
{"role": "assistant", "content": [function_call_dict]}
|
|
||||||
)
|
|
||||||
messages.append(
|
|
||||||
{"role": "tool", "content": [function_response_dict]}
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error executing tool: {str(e)}", exc_info=True)
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"content": f"Error executing tool: {str(e)}",
|
|
||||||
"tool_call_id": call_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
resp = agent.llm.gen_stream(
|
|
||||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
|
||||||
)
|
|
||||||
self.llm_calls.append(build_stack_data(agent.llm))
|
|
||||||
return resp
|
|
||||||
|
|
||||||
else:
|
|
||||||
text_buffer = ""
|
|
||||||
while True:
|
|
||||||
tool_calls = {}
|
|
||||||
for chunk in resp:
|
|
||||||
if isinstance(chunk, str) and len(chunk) > 0:
|
|
||||||
yield chunk
|
|
||||||
continue
|
|
||||||
elif hasattr(chunk, "delta"):
|
|
||||||
chunk_delta = chunk.delta
|
|
||||||
|
|
||||||
if (
|
|
||||||
hasattr(chunk_delta, "tool_calls")
|
|
||||||
and chunk_delta.tool_calls is not None
|
|
||||||
):
|
|
||||||
for tool_call in chunk_delta.tool_calls:
|
|
||||||
index = tool_call.index
|
|
||||||
if index not in tool_calls:
|
|
||||||
tool_calls[index] = {
|
|
||||||
"id": "",
|
|
||||||
"function": {"name": "", "arguments": ""},
|
|
||||||
}
|
|
||||||
|
|
||||||
current = tool_calls[index]
|
|
||||||
if tool_call.id:
|
|
||||||
current["id"] = tool_call.id
|
|
||||||
if tool_call.function.name:
|
|
||||||
current["function"][
|
|
||||||
"name"
|
|
||||||
] = tool_call.function.name
|
|
||||||
if tool_call.function.arguments:
|
|
||||||
current["function"][
|
|
||||||
"arguments"
|
|
||||||
] += tool_call.function.arguments
|
|
||||||
tool_calls[index] = current
|
|
||||||
|
|
||||||
if (
|
|
||||||
hasattr(chunk, "finish_reason")
|
|
||||||
and chunk.finish_reason == "tool_calls"
|
|
||||||
):
|
|
||||||
for index in sorted(tool_calls.keys()):
|
|
||||||
call = tool_calls[index]
|
|
||||||
try:
|
|
||||||
self.tool_calls.append(call)
|
|
||||||
tool_response, call_id = agent._execute_tool_action(
|
|
||||||
tools_dict, call
|
|
||||||
)
|
|
||||||
if isinstance(call["function"]["arguments"], str):
|
|
||||||
call["function"]["arguments"] = json.loads(call["function"]["arguments"])
|
|
||||||
|
|
||||||
function_call_dict = {
|
|
||||||
"function_call": {
|
|
||||||
"name": call["function"]["name"],
|
|
||||||
"args": call["function"]["arguments"],
|
|
||||||
"call_id": call["id"],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
function_response_dict = {
|
|
||||||
"function_response": {
|
|
||||||
"name": call["function"]["name"],
|
|
||||||
"response": {"result": tool_response},
|
|
||||||
"call_id": call["id"],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [function_call_dict],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"content": [function_response_dict],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error executing tool: {str(e)}", exc_info=True)
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": f"Error executing tool: {str(e)}",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
tool_calls = {}
|
|
||||||
if hasattr(chunk_delta, "content") and chunk_delta.content:
|
|
||||||
# Add to buffer or yield immediately based on your preference
|
|
||||||
text_buffer += chunk_delta.content
|
|
||||||
yield text_buffer
|
|
||||||
text_buffer = ""
|
|
||||||
|
|
||||||
if (
|
|
||||||
hasattr(chunk, "finish_reason")
|
|
||||||
and chunk.finish_reason == "stop"
|
|
||||||
):
|
|
||||||
return resp
|
|
||||||
elif isinstance(chunk, str) and len(chunk) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.info(f"Regenerating with messages: {messages}")
|
|
||||||
resp = agent.llm.gen_stream(
|
|
||||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
|
||||||
)
|
|
||||||
self.llm_calls.append(build_stack_data(agent.llm))
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleLLMHandler(LLMHandler):
|
|
||||||
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True):
|
|
||||||
from google.genai import types
|
|
||||||
|
|
||||||
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
if not stream:
|
|
||||||
response = agent.llm.gen(
|
|
||||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
|
||||||
)
|
|
||||||
self.llm_calls.append(build_stack_data(agent.llm))
|
|
||||||
if response.candidates and response.candidates[0].content.parts:
|
|
||||||
tool_call_found = False
|
|
||||||
for part in response.candidates[0].content.parts:
|
|
||||||
if part.function_call:
|
|
||||||
tool_call_found = True
|
|
||||||
self.tool_calls.append(part.function_call)
|
|
||||||
tool_response, call_id = agent._execute_tool_action(
|
|
||||||
tools_dict, part.function_call
|
|
||||||
)
|
|
||||||
function_response_part = types.Part.from_function_response(
|
|
||||||
name=part.function_call.name,
|
|
||||||
response={"result": tool_response},
|
|
||||||
)
|
|
||||||
|
|
||||||
messages.append(
|
|
||||||
{"role": "model", "content": [part.to_json_dict()]}
|
|
||||||
)
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"content": [function_response_part.to_json_dict()],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
not tool_call_found
|
|
||||||
and response.candidates[0].content.parts
|
|
||||||
and response.candidates[0].content.parts[0].text
|
|
||||||
):
|
|
||||||
return response.candidates[0].content.parts[0].text
|
|
||||||
elif not tool_call_found:
|
|
||||||
return response.candidates[0].content.parts
|
|
||||||
|
|
||||||
else:
|
|
||||||
return response
|
|
||||||
|
|
||||||
else:
|
|
||||||
response = agent.llm.gen_stream(
|
|
||||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
|
||||||
)
|
|
||||||
self.llm_calls.append(build_stack_data(agent.llm))
|
|
||||||
|
|
||||||
tool_call_found = False
|
|
||||||
for result in response:
|
|
||||||
if hasattr(result, "function_call"):
|
|
||||||
tool_call_found = True
|
|
||||||
self.tool_calls.append(result.function_call)
|
|
||||||
tool_response, call_id = agent._execute_tool_action(
|
|
||||||
tools_dict, result.function_call
|
|
||||||
)
|
|
||||||
function_response_part = types.Part.from_function_response(
|
|
||||||
name=result.function_call.name,
|
|
||||||
response={"result": tool_response},
|
|
||||||
)
|
|
||||||
|
|
||||||
messages.append(
|
|
||||||
{"role": "model", "content": [result.to_json_dict()]}
|
|
||||||
)
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"content": [function_response_part.to_json_dict()],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
tool_call_found = False
|
|
||||||
yield result
|
|
||||||
|
|
||||||
if not tool_call_found:
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def get_llm_handler(llm_type):
|
|
||||||
handlers = {
|
|
||||||
"openai": OpenAILLMHandler(),
|
|
||||||
"google": GoogleLLMHandler(),
|
|
||||||
}
|
|
||||||
return handlers.get(llm_type, OpenAILLMHandler())
|
|
||||||
@@ -1,9 +1,13 @@
|
|||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Dict, Generator, List
|
from typing import Any, Dict, Generator, List
|
||||||
|
|
||||||
from application.agents.base import BaseAgent
|
from application.agents.base import BaseAgent
|
||||||
from application.logging import build_stack_data, LogContext
|
from application.logging import build_stack_data, LogContext
|
||||||
from application.retriever.base import BaseRetriever
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MAX_ITERATIONS_REASONING = 10
|
||||||
|
|
||||||
current_dir = os.path.dirname(
|
current_dir = os.path.dirname(
|
||||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
@@ -11,122 +15,224 @@ current_dir = os.path.dirname(
|
|||||||
with open(
|
with open(
|
||||||
os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"), "r"
|
os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"), "r"
|
||||||
) as f:
|
) as f:
|
||||||
planning_prompt = f.read()
|
PLANNING_PROMPT_TEMPLATE = f.read()
|
||||||
with open(
|
with open(
|
||||||
os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"),
|
os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"), "r"
|
||||||
"r",
|
|
||||||
) as f:
|
) as f:
|
||||||
final_prompt = f.read()
|
FINAL_PROMPT_TEMPLATE = f.read()
|
||||||
|
|
||||||
|
|
||||||
class ReActAgent(BaseAgent):
|
class ReActAgent(BaseAgent):
|
||||||
|
"""
|
||||||
|
Research and Action (ReAct) Agent - Advanced reasoning agent with iterative planning.
|
||||||
|
|
||||||
|
Implements a think-act-observe loop for complex problem-solving:
|
||||||
|
1. Creates a strategic plan based on the query
|
||||||
|
2. Executes tools and gathers observations
|
||||||
|
3. Iteratively refines approach until satisfied
|
||||||
|
4. Synthesizes final answer from all observations
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.plan = ""
|
self.plan: str = ""
|
||||||
self.observations: List[str] = []
|
self.observations: List[str] = []
|
||||||
|
|
||||||
def _gen_inner(
|
def _gen_inner(
|
||||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
self, query: str, log_context: LogContext
|
||||||
) -> Generator[Dict, None, None]:
|
) -> Generator[Dict, None, None]:
|
||||||
retrieved_data = self._retriever_search(retriever, query, log_context)
|
"""Execute ReAct reasoning loop with planning, action, and observation cycles"""
|
||||||
|
|
||||||
if self.user_api_key:
|
self._reset_state()
|
||||||
tools_dict = self._get_tools(self.user_api_key)
|
|
||||||
else:
|
tools_dict = (
|
||||||
tools_dict = self._get_user_tools(self.user)
|
self._get_tools(self.user_api_key)
|
||||||
|
if self.user_api_key
|
||||||
|
else self._get_user_tools(self.user)
|
||||||
|
)
|
||||||
self._prepare_tools(tools_dict)
|
self._prepare_tools(tools_dict)
|
||||||
|
|
||||||
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
|
for iteration in range(1, MAX_ITERATIONS_REASONING + 1):
|
||||||
plan = self._create_plan(query, docs_together, log_context)
|
yield {"thought": f"Reasoning... (iteration {iteration})\n\n"}
|
||||||
for line in plan:
|
|
||||||
if isinstance(line, str):
|
|
||||||
self.plan += line
|
|
||||||
yield {"thought": line}
|
|
||||||
|
|
||||||
prompt = self.prompt + f"\nFollow this plan: {self.plan}"
|
yield from self._planning_phase(query, log_context)
|
||||||
messages = self._build_messages(prompt, query, retrieved_data)
|
|
||||||
|
|
||||||
resp = self._llm_gen(messages, log_context)
|
if not self.plan:
|
||||||
|
logger.warning(
|
||||||
|
f"ReActAgent: No plan generated in iteration {iteration}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
self.observations.append(f"Plan (iteration {iteration}): {self.plan}")
|
||||||
|
|
||||||
if isinstance(resp, str):
|
satisfied = yield from self._execution_phase(query, tools_dict, log_context)
|
||||||
self.observations.append(resp)
|
|
||||||
if (
|
|
||||||
hasattr(resp, "message")
|
|
||||||
and hasattr(resp.message, "content")
|
|
||||||
and resp.message.content is not None
|
|
||||||
):
|
|
||||||
self.observations.append(resp.message.content)
|
|
||||||
|
|
||||||
resp = self._llm_handler(resp, tools_dict, messages, log_context)
|
if satisfied:
|
||||||
|
logger.info("ReActAgent: Goal satisfied, stopping reasoning loop")
|
||||||
|
break
|
||||||
|
yield from self._synthesis_phase(query, log_context)
|
||||||
|
|
||||||
|
def _reset_state(self):
|
||||||
|
"""Reset agent state for new query"""
|
||||||
|
self.plan = ""
|
||||||
|
self.observations = []
|
||||||
|
|
||||||
|
def _planning_phase(
|
||||||
|
self, query: str, log_context: LogContext
|
||||||
|
) -> Generator[Dict, None, None]:
|
||||||
|
"""Generate strategic plan for query"""
|
||||||
|
logger.info("ReActAgent: Creating plan...")
|
||||||
|
|
||||||
|
plan_prompt = self._build_planning_prompt(query)
|
||||||
|
messages = [{"role": "user", "content": plan_prompt}]
|
||||||
|
|
||||||
|
plan_stream = self.llm.gen_stream(
|
||||||
|
model=self.gpt_model,
|
||||||
|
messages=messages,
|
||||||
|
tools=self.tools if self.tools else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if log_context:
|
||||||
|
log_context.stacks.append(
|
||||||
|
{"component": "planning_llm", "data": build_stack_data(self.llm)}
|
||||||
|
)
|
||||||
|
plan_parts = []
|
||||||
|
for chunk in plan_stream:
|
||||||
|
content = self._extract_content(chunk)
|
||||||
|
if content:
|
||||||
|
plan_parts.append(content)
|
||||||
|
yield {"thought": content}
|
||||||
|
self.plan = "".join(plan_parts)
|
||||||
|
|
||||||
|
def _execution_phase(
|
||||||
|
self, query: str, tools_dict: Dict, log_context: LogContext
|
||||||
|
) -> Generator[bool, None, None]:
|
||||||
|
"""Execute plan with tool calls and observations"""
|
||||||
|
execution_prompt = self._build_execution_prompt(query)
|
||||||
|
messages = self._build_messages(execution_prompt, query)
|
||||||
|
|
||||||
|
llm_response = self._llm_gen(messages, log_context)
|
||||||
|
initial_content = self._extract_content(llm_response)
|
||||||
|
|
||||||
|
if initial_content:
|
||||||
|
self.observations.append(f"Initial response: {initial_content}")
|
||||||
|
processed_response = self._llm_handler(
|
||||||
|
llm_response, tools_dict, messages, log_context
|
||||||
|
)
|
||||||
|
|
||||||
for tool_call in self.tool_calls:
|
for tool_call in self.tool_calls:
|
||||||
observation = (
|
observation = (
|
||||||
f"Action '{tool_call['action_name']}' of tool '{tool_call['tool_name']}' "
|
f"Executed: {tool_call.get('tool_name', 'Unknown')} "
|
||||||
f"with arguments '{tool_call['arguments']}' returned: '{tool_call['result']}'"
|
f"with args {tool_call.get('arguments', {})}. "
|
||||||
|
f"Result: {str(tool_call.get('result', ''))[:200]}"
|
||||||
)
|
)
|
||||||
self.observations.append(observation)
|
self.observations.append(observation)
|
||||||
|
final_content = self._extract_content(processed_response)
|
||||||
if isinstance(resp, str):
|
if final_content:
|
||||||
self.observations.append(resp)
|
self.observations.append(f"Response after tools: {final_content}")
|
||||||
elif (
|
if log_context:
|
||||||
hasattr(resp, "message")
|
log_context.stacks.append(
|
||||||
and hasattr(resp.message, "content")
|
{
|
||||||
and resp.message.content is not None
|
"component": "agent_tool_calls",
|
||||||
):
|
"data": {"tool_calls": self.tool_calls.copy()},
|
||||||
self.observations.append(resp.message.content)
|
}
|
||||||
else:
|
|
||||||
completion = self.llm.gen_stream(
|
|
||||||
model=self.gpt_model, messages=messages, tools=self.tools
|
|
||||||
)
|
)
|
||||||
for line in completion:
|
yield {"sources": self.retrieved_docs}
|
||||||
if isinstance(line, str):
|
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||||
self.observations.append(line)
|
|
||||||
|
|
||||||
log_context.stacks.append(
|
return "SATISFIED" in (final_content or "")
|
||||||
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
|
|
||||||
|
def _synthesis_phase(
|
||||||
|
self, query: str, log_context: LogContext
|
||||||
|
) -> Generator[Dict, None, None]:
|
||||||
|
"""Synthesize final answer from all observations"""
|
||||||
|
logger.info("ReActAgent: Generating final answer...")
|
||||||
|
|
||||||
|
final_prompt = self._build_final_answer_prompt(query)
|
||||||
|
messages = [{"role": "user", "content": final_prompt}]
|
||||||
|
|
||||||
|
final_stream = self.llm.gen_stream(
|
||||||
|
model=self.gpt_model, messages=messages, tools=None
|
||||||
)
|
)
|
||||||
|
|
||||||
yield {"sources": retrieved_data}
|
|
||||||
# clean tool_call_data only send first 50 characters of tool_call['result']
|
|
||||||
for tool_call in self.tool_calls:
|
|
||||||
if len(str(tool_call["result"])) > 50:
|
|
||||||
tool_call["result"] = str(tool_call["result"])[:50] + "..."
|
|
||||||
yield {"tool_calls": self.tool_calls.copy()}
|
|
||||||
|
|
||||||
final_answer = self._create_final_answer(query, self.observations, log_context)
|
|
||||||
for line in final_answer:
|
|
||||||
if isinstance(line, str):
|
|
||||||
yield {"answer": line}
|
|
||||||
|
|
||||||
def _create_plan(
|
|
||||||
self, query: str, docs_data: str, log_context: LogContext = None
|
|
||||||
) -> Generator[str, None, None]:
|
|
||||||
plan_prompt = planning_prompt.replace("{query}", query)
|
|
||||||
if "{summaries}" in planning_prompt:
|
|
||||||
summaries = docs_data
|
|
||||||
plan_prompt = plan_prompt.replace("{summaries}", summaries)
|
|
||||||
|
|
||||||
messages = [{"role": "user", "content": plan_prompt}]
|
|
||||||
print(self.tools)
|
|
||||||
plan = self.llm.gen_stream(
|
|
||||||
model=self.gpt_model, messages=messages, tools=self.tools
|
|
||||||
)
|
|
||||||
if log_context:
|
if log_context:
|
||||||
data = build_stack_data(self.llm)
|
log_context.stacks.append(
|
||||||
log_context.stacks.append({"component": "planning_llm", "data": data})
|
{"component": "final_answer_llm", "data": build_stack_data(self.llm)}
|
||||||
return plan
|
)
|
||||||
|
for chunk in final_stream:
|
||||||
|
content = self._extract_content(chunk)
|
||||||
|
if content:
|
||||||
|
yield {"answer": content}
|
||||||
|
|
||||||
def _create_final_answer(
|
def _build_planning_prompt(self, query: str) -> str:
|
||||||
self, query: str, observations: List[str], log_context: LogContext = None
|
"""Build planning phase prompt"""
|
||||||
) -> str:
|
prompt = PLANNING_PROMPT_TEMPLATE.replace("{query}", query)
|
||||||
observation_string = "\n".join(observations)
|
prompt = prompt.replace("{prompt}", self.prompt or "")
|
||||||
final_answer_prompt = final_prompt.format(
|
prompt = prompt.replace("{summaries}", "")
|
||||||
query=query, observations=observation_string
|
prompt = prompt.replace("{observations}", "\n".join(self.observations))
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def _build_execution_prompt(self, query: str) -> str:
|
||||||
|
"""Build execution phase prompt with plan and observations"""
|
||||||
|
observations_str = "\n".join(self.observations)
|
||||||
|
|
||||||
|
if len(observations_str) > 20000:
|
||||||
|
observations_str = observations_str[:20000] + "\n...[truncated]"
|
||||||
|
return (
|
||||||
|
f"{self.prompt or ''}\n\n"
|
||||||
|
f"Follow this plan:\n{self.plan}\n\n"
|
||||||
|
f"Observations:\n{observations_str}\n\n"
|
||||||
|
f"If sufficient data exists to answer '{query}', respond with 'SATISFIED'. "
|
||||||
|
f"Otherwise, continue executing the plan."
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [{"role": "user", "content": final_answer_prompt}]
|
def _build_final_answer_prompt(self, query: str) -> str:
|
||||||
final_answer = self.llm.gen_stream(model=self.gpt_model, messages=messages)
|
"""Build final synthesis prompt"""
|
||||||
if log_context:
|
observations_str = "\n".join(self.observations)
|
||||||
data = build_stack_data(self.llm)
|
|
||||||
log_context.stacks.append({"component": "final_answer_llm", "data": data})
|
if len(observations_str) > 10000:
|
||||||
return final_answer
|
observations_str = observations_str[:10000] + "\n...[truncated]"
|
||||||
|
logger.warning("ReActAgent: Observations truncated for final answer")
|
||||||
|
return FINAL_PROMPT_TEMPLATE.format(query=query, observations=observations_str)
|
||||||
|
|
||||||
|
def _extract_content(self, response: Any) -> str:
|
||||||
|
"""Extract text content from various LLM response formats"""
|
||||||
|
if not response:
|
||||||
|
return ""
|
||||||
|
collected = []
|
||||||
|
|
||||||
|
if isinstance(response, str):
|
||||||
|
return response
|
||||||
|
if hasattr(response, "message") and hasattr(response.message, "content"):
|
||||||
|
if response.message.content:
|
||||||
|
return response.message.content
|
||||||
|
if hasattr(response, "choices") and response.choices:
|
||||||
|
if hasattr(response.choices[0], "message"):
|
||||||
|
content = response.choices[0].message.content
|
||||||
|
if content:
|
||||||
|
return content
|
||||||
|
if hasattr(response, "content") and isinstance(response.content, list):
|
||||||
|
if response.content and hasattr(response.content[0], "text"):
|
||||||
|
return response.content[0].text
|
||||||
|
try:
|
||||||
|
for chunk in response:
|
||||||
|
content_piece = ""
|
||||||
|
|
||||||
|
if hasattr(chunk, "choices") and chunk.choices:
|
||||||
|
if hasattr(chunk.choices[0], "delta"):
|
||||||
|
delta_content = chunk.choices[0].delta.content
|
||||||
|
if delta_content:
|
||||||
|
content_piece = delta_content
|
||||||
|
elif hasattr(chunk, "type") and chunk.type == "content_block_delta":
|
||||||
|
if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
|
||||||
|
content_piece = chunk.delta.text
|
||||||
|
elif isinstance(chunk, str):
|
||||||
|
content_piece = chunk
|
||||||
|
if content_piece:
|
||||||
|
collected.append(content_piece)
|
||||||
|
except (TypeError, AttributeError):
|
||||||
|
logger.debug(
|
||||||
|
f"Response not iterable or unexpected format: {type(response)}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error extracting content: {e}")
|
||||||
|
return "".join(collected)
|
||||||
|
|||||||
@@ -25,27 +25,35 @@ class BraveSearchTool(Tool):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown action: {action_name}")
|
raise ValueError(f"Unknown action: {action_name}")
|
||||||
|
|
||||||
def _web_search(self, query, country="ALL", search_lang="en", count=10,
|
def _web_search(
|
||||||
offset=0, safesearch="off", freshness=None,
|
self,
|
||||||
result_filter=None, extra_snippets=False, summary=False):
|
query,
|
||||||
|
country="ALL",
|
||||||
|
search_lang="en",
|
||||||
|
count=10,
|
||||||
|
offset=0,
|
||||||
|
safesearch="off",
|
||||||
|
freshness=None,
|
||||||
|
result_filter=None,
|
||||||
|
extra_snippets=False,
|
||||||
|
summary=False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Performs a web search using the Brave Search API.
|
Performs a web search using the Brave Search API.
|
||||||
"""
|
"""
|
||||||
print(f"Performing Brave web search for: {query}")
|
print(f"Performing Brave web search for: {query}")
|
||||||
|
|
||||||
url = f"{self.base_url}/web/search"
|
url = f"{self.base_url}/web/search"
|
||||||
|
|
||||||
# Build query parameters
|
|
||||||
params = {
|
params = {
|
||||||
"q": query,
|
"q": query,
|
||||||
"country": country,
|
"country": country,
|
||||||
"search_lang": search_lang,
|
"search_lang": search_lang,
|
||||||
"count": min(count, 20),
|
"count": min(count, 20),
|
||||||
"offset": min(offset, 9),
|
"offset": min(offset, 9),
|
||||||
"safesearch": safesearch
|
"safesearch": safesearch,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add optional parameters only if they have values
|
|
||||||
if freshness:
|
if freshness:
|
||||||
params["freshness"] = freshness
|
params["freshness"] = freshness
|
||||||
if result_filter:
|
if result_filter:
|
||||||
@@ -54,68 +62,69 @@ class BraveSearchTool(Tool):
|
|||||||
params["extra_snippets"] = 1
|
params["extra_snippets"] = 1
|
||||||
if summary:
|
if summary:
|
||||||
params["summary"] = 1
|
params["summary"] = 1
|
||||||
|
|
||||||
# Set up headers
|
|
||||||
headers = {
|
headers = {
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
"Accept-Encoding": "gzip",
|
"Accept-Encoding": "gzip",
|
||||||
"X-Subscription-Token": self.token
|
"X-Subscription-Token": self.token,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Make the request
|
|
||||||
response = requests.get(url, params=params, headers=headers)
|
response = requests.get(url, params=params, headers=headers)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return {
|
return {
|
||||||
"status_code": response.status_code,
|
"status_code": response.status_code,
|
||||||
"results": response.json(),
|
"results": response.json(),
|
||||||
"message": "Search completed successfully."
|
"message": "Search completed successfully.",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return {
|
return {
|
||||||
"status_code": response.status_code,
|
"status_code": response.status_code,
|
||||||
"message": f"Search failed with status code: {response.status_code}."
|
"message": f"Search failed with status code: {response.status_code}.",
|
||||||
}
|
}
|
||||||
|
|
||||||
def _image_search(self, query, country="ALL", search_lang="en", count=5,
|
def _image_search(
|
||||||
safesearch="off", spellcheck=False):
|
self,
|
||||||
|
query,
|
||||||
|
country="ALL",
|
||||||
|
search_lang="en",
|
||||||
|
count=5,
|
||||||
|
safesearch="off",
|
||||||
|
spellcheck=False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Performs an image search using the Brave Search API.
|
Performs an image search using the Brave Search API.
|
||||||
"""
|
"""
|
||||||
print(f"Performing Brave image search for: {query}")
|
print(f"Performing Brave image search for: {query}")
|
||||||
|
|
||||||
url = f"{self.base_url}/images/search"
|
url = f"{self.base_url}/images/search"
|
||||||
|
|
||||||
# Build query parameters
|
|
||||||
params = {
|
params = {
|
||||||
"q": query,
|
"q": query,
|
||||||
"country": country,
|
"country": country,
|
||||||
"search_lang": search_lang,
|
"search_lang": search_lang,
|
||||||
"count": min(count, 100), # API max is 100
|
"count": min(count, 100), # API max is 100
|
||||||
"safesearch": safesearch,
|
"safesearch": safesearch,
|
||||||
"spellcheck": 1 if spellcheck else 0
|
"spellcheck": 1 if spellcheck else 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Set up headers
|
|
||||||
headers = {
|
headers = {
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
"Accept-Encoding": "gzip",
|
"Accept-Encoding": "gzip",
|
||||||
"X-Subscription-Token": self.token
|
"X-Subscription-Token": self.token,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Make the request
|
|
||||||
response = requests.get(url, params=params, headers=headers)
|
response = requests.get(url, params=params, headers=headers)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return {
|
return {
|
||||||
"status_code": response.status_code,
|
"status_code": response.status_code,
|
||||||
"results": response.json(),
|
"results": response.json(),
|
||||||
"message": "Image search completed successfully."
|
"message": "Image search completed successfully.",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return {
|
return {
|
||||||
"status_code": response.status_code,
|
"status_code": response.status_code,
|
||||||
"message": f"Image search failed with status code: {response.status_code}."
|
"message": f"Image search failed with status code: {response.status_code}.",
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_actions_metadata(self):
|
def get_actions_metadata(self):
|
||||||
@@ -130,42 +139,14 @@ class BraveSearchTool(Tool):
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The search query (max 400 characters, 50 words)",
|
"description": "The search query (max 400 characters, 50 words)",
|
||||||
},
|
},
|
||||||
# "country": {
|
|
||||||
# "type": "string",
|
|
||||||
# "description": "The 2-character country code (default: US)",
|
|
||||||
# },
|
|
||||||
"search_lang": {
|
"search_lang": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The search language preference (default: en)",
|
"description": "The search language preference (default: en)",
|
||||||
},
|
},
|
||||||
# "count": {
|
|
||||||
# "type": "integer",
|
|
||||||
# "description": "Number of results to return (max 20, default: 10)",
|
|
||||||
# },
|
|
||||||
# "offset": {
|
|
||||||
# "type": "integer",
|
|
||||||
# "description": "Pagination offset (max 9, default: 0)",
|
|
||||||
# },
|
|
||||||
# "safesearch": {
|
|
||||||
# "type": "string",
|
|
||||||
# "description": "Filter level for adult content (off, moderate, strict)",
|
|
||||||
# },
|
|
||||||
"freshness": {
|
"freshness": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Time filter for results (pd: last 24h, pw: last week, pm: last month, py: last year)",
|
"description": "Time filter for results (pd: last 24h, pw: last week, pm: last month, py: last year)",
|
||||||
},
|
},
|
||||||
# "result_filter": {
|
|
||||||
# "type": "string",
|
|
||||||
# "description": "Comma-delimited list of result types to include",
|
|
||||||
# },
|
|
||||||
# "extra_snippets": {
|
|
||||||
# "type": "boolean",
|
|
||||||
# "description": "Get additional excerpts from result pages",
|
|
||||||
# },
|
|
||||||
# "summary": {
|
|
||||||
# "type": "boolean",
|
|
||||||
# "description": "Enable summary generation in search results",
|
|
||||||
# }
|
|
||||||
},
|
},
|
||||||
"required": ["query"],
|
"required": ["query"],
|
||||||
"additionalProperties": False,
|
"additionalProperties": False,
|
||||||
@@ -181,37 +162,21 @@ class BraveSearchTool(Tool):
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The search query (max 400 characters, 50 words)",
|
"description": "The search query (max 400 characters, 50 words)",
|
||||||
},
|
},
|
||||||
# "country": {
|
|
||||||
# "type": "string",
|
|
||||||
# "description": "The 2-character country code (default: US)",
|
|
||||||
# },
|
|
||||||
# "search_lang": {
|
|
||||||
# "type": "string",
|
|
||||||
# "description": "The search language preference (default: en)",
|
|
||||||
# },
|
|
||||||
"count": {
|
"count": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "Number of results to return (max 100, default: 5)",
|
"description": "Number of results to return (max 100, default: 5)",
|
||||||
},
|
},
|
||||||
# "safesearch": {
|
|
||||||
# "type": "string",
|
|
||||||
# "description": "Filter level for adult content (off, strict). Default: strict",
|
|
||||||
# },
|
|
||||||
# "spellcheck": {
|
|
||||||
# "type": "boolean",
|
|
||||||
# "description": "Whether to spellcheck provided query (default: true)",
|
|
||||||
# }
|
|
||||||
},
|
},
|
||||||
"required": ["query"],
|
"required": ["query"],
|
||||||
"additionalProperties": False,
|
"additionalProperties": False,
|
||||||
},
|
},
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_config_requirements(self):
|
def get_config_requirements(self):
|
||||||
return {
|
return {
|
||||||
"token": {
|
"token": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Brave Search API key for authentication"
|
"description": "Brave Search API key for authentication",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
114
application/agents/tools/duckduckgo.py
Normal file
114
application/agents/tools/duckduckgo.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
from application.agents.tools.base import Tool
|
||||||
|
from duckduckgo_search import DDGS
|
||||||
|
|
||||||
|
|
||||||
|
class DuckDuckGoSearchTool(Tool):
|
||||||
|
"""
|
||||||
|
DuckDuckGo Search
|
||||||
|
A tool for performing web and image searches using DuckDuckGo.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def execute_action(self, action_name, **kwargs):
|
||||||
|
actions = {
|
||||||
|
"ddg_web_search": self._web_search,
|
||||||
|
"ddg_image_search": self._image_search,
|
||||||
|
}
|
||||||
|
|
||||||
|
if action_name in actions:
|
||||||
|
return actions[action_name](**kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown action: {action_name}")
|
||||||
|
|
||||||
|
def _web_search(
|
||||||
|
self,
|
||||||
|
query,
|
||||||
|
max_results=5,
|
||||||
|
):
|
||||||
|
print(f"Performing DuckDuckGo web search for: {query}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = DDGS().text(
|
||||||
|
query,
|
||||||
|
max_results=max_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status_code": 200,
|
||||||
|
"results": results,
|
||||||
|
"message": "Web search completed successfully.",
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status_code": 500,
|
||||||
|
"message": f"Web search failed: {str(e)}",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _image_search(
|
||||||
|
self,
|
||||||
|
query,
|
||||||
|
max_results=5,
|
||||||
|
):
|
||||||
|
print(f"Performing DuckDuckGo image search for: {query}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = DDGS().images(
|
||||||
|
keywords=query,
|
||||||
|
max_results=max_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status_code": 200,
|
||||||
|
"results": results,
|
||||||
|
"message": "Image search completed successfully.",
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status_code": 500,
|
||||||
|
"message": f"Image search failed: {str(e)}",
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_actions_metadata(self):
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "ddg_web_search",
|
||||||
|
"description": "Perform a web search using DuckDuckGo.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Search query",
|
||||||
|
},
|
||||||
|
"max_results": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Number of results to return (default: 5)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "ddg_image_search",
|
||||||
|
"description": "Perform an image search using DuckDuckGo.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Search query",
|
||||||
|
},
|
||||||
|
"max_results": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Number of results to return (default: 5, max: 50)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_config_requirements(self):
|
||||||
|
return {}
|
||||||
861
application/agents/tools/mcp_tool.py
Normal file
861
application/agents/tools/mcp_tool.py
Normal file
@@ -0,0 +1,861 @@
|
|||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
from application.agents.tools.base import Tool
|
||||||
|
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
|
||||||
|
from application.cache import get_redis_instance
|
||||||
|
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
|
||||||
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
from application.security.encryption import decrypt_credentials
|
||||||
|
from fastmcp import Client
|
||||||
|
from fastmcp.client.auth import BearerAuth
|
||||||
|
from fastmcp.client.transports import (
|
||||||
|
SSETransport,
|
||||||
|
StdioTransport,
|
||||||
|
StreamableHttpTransport,
|
||||||
|
)
|
||||||
|
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
||||||
|
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
||||||
|
|
||||||
|
from pydantic import AnyHttpUrl, ValidationError
|
||||||
|
from redis import Redis
|
||||||
|
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
|
|
||||||
|
_mcp_clients_cache = {}
|
||||||
|
|
||||||
|
|
||||||
|
class MCPTool(Tool):
|
||||||
|
"""
|
||||||
|
MCP Tool
|
||||||
|
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize the MCP Tool with configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Dictionary containing MCP server configuration:
|
||||||
|
- server_url: URL of the remote MCP server
|
||||||
|
- transport_type: Transport type (auto, sse, http, stdio)
|
||||||
|
- auth_type: Type of authentication (bearer, oauth, api_key, basic, none)
|
||||||
|
- encrypted_credentials: Encrypted credentials (if available)
|
||||||
|
- timeout: Request timeout in seconds (default: 30)
|
||||||
|
- headers: Custom headers for requests
|
||||||
|
- command: Command for STDIO transport
|
||||||
|
- args: Arguments for STDIO transport
|
||||||
|
- oauth_scopes: OAuth scopes for oauth auth type
|
||||||
|
- oauth_client_name: OAuth client name for oauth auth type
|
||||||
|
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
self.user_id = user_id
|
||||||
|
self.server_url = config.get("server_url", "")
|
||||||
|
self.transport_type = config.get("transport_type", "auto")
|
||||||
|
self.auth_type = config.get("auth_type", "none")
|
||||||
|
self.timeout = config.get("timeout", 30)
|
||||||
|
self.custom_headers = config.get("headers", {})
|
||||||
|
|
||||||
|
self.auth_credentials = {}
|
||||||
|
if config.get("encrypted_credentials") and user_id:
|
||||||
|
self.auth_credentials = decrypt_credentials(
|
||||||
|
config["encrypted_credentials"], user_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.auth_credentials = config.get("auth_credentials", {})
|
||||||
|
self.oauth_scopes = config.get("oauth_scopes", [])
|
||||||
|
self.oauth_task_id = config.get("oauth_task_id", None)
|
||||||
|
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
|
||||||
|
self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
|
||||||
|
|
||||||
|
self.available_tools = []
|
||||||
|
self._cache_key = self._generate_cache_key()
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
# Only validate and setup if server_url is provided and not OAuth
|
||||||
|
|
||||||
|
if self.server_url and self.auth_type != "oauth":
|
||||||
|
self._setup_client()
|
||||||
|
|
||||||
|
def _generate_cache_key(self) -> str:
|
||||||
|
"""Generate a unique cache key for this MCP server configuration."""
|
||||||
|
auth_key = ""
|
||||||
|
if self.auth_type == "oauth":
|
||||||
|
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
|
||||||
|
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
|
||||||
|
elif self.auth_type in ["bearer"]:
|
||||||
|
token = self.auth_credentials.get(
|
||||||
|
"bearer_token", ""
|
||||||
|
) or self.auth_credentials.get("access_token", "")
|
||||||
|
auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
|
||||||
|
elif self.auth_type == "api_key":
|
||||||
|
api_key = self.auth_credentials.get("api_key", "")
|
||||||
|
auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
|
||||||
|
elif self.auth_type == "basic":
|
||||||
|
username = self.auth_credentials.get("username", "")
|
||||||
|
auth_key = f"basic:{username}"
|
||||||
|
else:
|
||||||
|
auth_key = "none"
|
||||||
|
return f"{self.server_url}#{self.transport_type}#{auth_key}"
|
||||||
|
|
||||||
|
def _setup_client(self):
|
||||||
|
"""Setup FastMCP client with proper transport and authentication."""
|
||||||
|
global _mcp_clients_cache
|
||||||
|
if self._cache_key in _mcp_clients_cache:
|
||||||
|
cached_data = _mcp_clients_cache[self._cache_key]
|
||||||
|
if time.time() - cached_data["created_at"] < 1800:
|
||||||
|
self._client = cached_data["client"]
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
del _mcp_clients_cache[self._cache_key]
|
||||||
|
transport = self._create_transport()
|
||||||
|
auth = None
|
||||||
|
|
||||||
|
if self.auth_type == "oauth":
|
||||||
|
redis_client = get_redis_instance()
|
||||||
|
auth = DocsGPTOAuth(
|
||||||
|
mcp_url=self.server_url,
|
||||||
|
scopes=self.oauth_scopes,
|
||||||
|
redis_client=redis_client,
|
||||||
|
redirect_uri=self.redirect_uri,
|
||||||
|
task_id=self.oauth_task_id,
|
||||||
|
db=db,
|
||||||
|
user_id=self.user_id,
|
||||||
|
)
|
||||||
|
elif self.auth_type == "bearer":
|
||||||
|
token = self.auth_credentials.get(
|
||||||
|
"bearer_token", ""
|
||||||
|
) or self.auth_credentials.get("access_token", "")
|
||||||
|
if token:
|
||||||
|
auth = BearerAuth(token)
|
||||||
|
self._client = Client(transport, auth=auth)
|
||||||
|
_mcp_clients_cache[self._cache_key] = {
|
||||||
|
"client": self._client,
|
||||||
|
"created_at": time.time(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_transport(self):
|
||||||
|
"""Create appropriate transport based on configuration."""
|
||||||
|
headers = {"Content-Type": "application/json", "User-Agent": "DocsGPT-MCP/1.0"}
|
||||||
|
headers.update(self.custom_headers)
|
||||||
|
|
||||||
|
if self.auth_type == "api_key":
|
||||||
|
api_key = self.auth_credentials.get("api_key", "")
|
||||||
|
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
|
||||||
|
if api_key:
|
||||||
|
headers[header_name] = api_key
|
||||||
|
elif self.auth_type == "basic":
|
||||||
|
username = self.auth_credentials.get("username", "")
|
||||||
|
password = self.auth_credentials.get("password", "")
|
||||||
|
if username and password:
|
||||||
|
credentials = base64.b64encode(
|
||||||
|
f"{username}:{password}".encode()
|
||||||
|
).decode()
|
||||||
|
headers["Authorization"] = f"Basic {credentials}"
|
||||||
|
if self.transport_type == "auto":
|
||||||
|
if "sse" in self.server_url.lower() or self.server_url.endswith("/sse"):
|
||||||
|
transport_type = "sse"
|
||||||
|
else:
|
||||||
|
transport_type = "http"
|
||||||
|
else:
|
||||||
|
transport_type = self.transport_type
|
||||||
|
if transport_type == "sse":
|
||||||
|
headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"})
|
||||||
|
return SSETransport(url=self.server_url, headers=headers)
|
||||||
|
elif transport_type == "http":
|
||||||
|
return StreamableHttpTransport(url=self.server_url, headers=headers)
|
||||||
|
elif transport_type == "stdio":
|
||||||
|
command = self.config.get("command", "python")
|
||||||
|
args = self.config.get("args", [])
|
||||||
|
env = self.auth_credentials if self.auth_credentials else None
|
||||||
|
return StdioTransport(command=command, args=args, env=env)
|
||||||
|
else:
|
||||||
|
return StreamableHttpTransport(url=self.server_url, headers=headers)
|
||||||
|
|
||||||
|
def _format_tools(self, tools_response) -> List[Dict]:
|
||||||
|
"""Format tools response to match expected format."""
|
||||||
|
if hasattr(tools_response, "tools"):
|
||||||
|
tools = tools_response.tools
|
||||||
|
elif isinstance(tools_response, list):
|
||||||
|
tools = tools_response
|
||||||
|
else:
|
||||||
|
tools = []
|
||||||
|
tools_dict = []
|
||||||
|
for tool in tools:
|
||||||
|
if hasattr(tool, "name"):
|
||||||
|
tool_dict = {
|
||||||
|
"name": tool.name,
|
||||||
|
"description": tool.description,
|
||||||
|
}
|
||||||
|
if hasattr(tool, "inputSchema"):
|
||||||
|
tool_dict["inputSchema"] = tool.inputSchema
|
||||||
|
tools_dict.append(tool_dict)
|
||||||
|
elif isinstance(tool, dict):
|
||||||
|
tools_dict.append(tool)
|
||||||
|
else:
|
||||||
|
if hasattr(tool, "model_dump"):
|
||||||
|
tools_dict.append(tool.model_dump())
|
||||||
|
else:
|
||||||
|
tools_dict.append({"name": str(tool), "description": ""})
|
||||||
|
return tools_dict
|
||||||
|
|
||||||
|
async def _execute_with_client(self, operation: str, *args, **kwargs):
|
||||||
|
"""Execute operation with FastMCP client."""
|
||||||
|
if not self._client:
|
||||||
|
raise Exception("FastMCP client not initialized")
|
||||||
|
async with self._client:
|
||||||
|
if operation == "ping":
|
||||||
|
return await self._client.ping()
|
||||||
|
elif operation == "list_tools":
|
||||||
|
tools_response = await self._client.list_tools()
|
||||||
|
self.available_tools = self._format_tools(tools_response)
|
||||||
|
return self.available_tools
|
||||||
|
elif operation == "call_tool":
|
||||||
|
tool_name = args[0]
|
||||||
|
tool_args = kwargs
|
||||||
|
return await self._client.call_tool(tool_name, tool_args)
|
||||||
|
elif operation == "list_resources":
|
||||||
|
return await self._client.list_resources()
|
||||||
|
elif operation == "list_prompts":
|
||||||
|
return await self._client.list_prompts()
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown operation: {operation}")
|
||||||
|
|
||||||
|
def _run_async_operation(self, operation: str, *args, **kwargs):
|
||||||
|
"""Run async operation in sync context."""
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
def run_in_thread():
|
||||||
|
new_loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(new_loop)
|
||||||
|
try:
|
||||||
|
return new_loop.run_until_complete(
|
||||||
|
self._execute_with_client(operation, *args, **kwargs)
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
new_loop.close()
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
future = executor.submit(run_in_thread)
|
||||||
|
return future.result(timeout=self.timeout)
|
||||||
|
except RuntimeError:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
return loop.run_until_complete(
|
||||||
|
self._execute_with_client(operation, *args, **kwargs)
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error occurred while running async operation: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def discover_tools(self) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Discover available tools from the MCP server using FastMCP.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tool definitions from the server
|
||||||
|
"""
|
||||||
|
if not self.server_url:
|
||||||
|
return []
|
||||||
|
if not self._client:
|
||||||
|
self._setup_client()
|
||||||
|
try:
|
||||||
|
tools = self._run_async_operation("list_tools")
|
||||||
|
self.available_tools = tools
|
||||||
|
return self.available_tools
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
|
||||||
|
|
||||||
|
def execute_action(self, action_name: str, **kwargs) -> Any:
|
||||||
|
"""
|
||||||
|
Execute an action on the remote MCP server using FastMCP.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_name: Name of the action to execute
|
||||||
|
**kwargs: Parameters for the action
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result from the MCP server
|
||||||
|
"""
|
||||||
|
if not self.server_url:
|
||||||
|
raise Exception("No MCP server configured")
|
||||||
|
if not self._client:
|
||||||
|
self._setup_client()
|
||||||
|
cleaned_kwargs = {}
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if value == "" or value is None:
|
||||||
|
continue
|
||||||
|
cleaned_kwargs[key] = value
|
||||||
|
try:
|
||||||
|
result = self._run_async_operation(
|
||||||
|
"call_tool", action_name, **cleaned_kwargs
|
||||||
|
)
|
||||||
|
return self._format_result(result)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
|
||||||
|
|
||||||
|
def _format_result(self, result) -> Dict:
|
||||||
|
"""Format FastMCP result to match expected format."""
|
||||||
|
if hasattr(result, "content"):
|
||||||
|
content_list = []
|
||||||
|
for content_item in result.content:
|
||||||
|
if hasattr(content_item, "text"):
|
||||||
|
content_list.append({"type": "text", "text": content_item.text})
|
||||||
|
elif hasattr(content_item, "data"):
|
||||||
|
content_list.append({"type": "data", "data": content_item.data})
|
||||||
|
else:
|
||||||
|
content_list.append(
|
||||||
|
{"type": "unknown", "content": str(content_item)}
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"content": content_list,
|
||||||
|
"isError": getattr(result, "isError", False),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return result
|
||||||
|
|
||||||
|
def test_connection(self) -> Dict:
|
||||||
|
"""
|
||||||
|
Test the connection to the MCP server and validate functionality.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with connection test results including tool count
|
||||||
|
"""
|
||||||
|
if not self.server_url:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": "No MCP server URL configured",
|
||||||
|
"tools_count": 0,
|
||||||
|
"transport_type": self.transport_type,
|
||||||
|
"auth_type": self.auth_type,
|
||||||
|
"error_type": "ConfigurationError",
|
||||||
|
}
|
||||||
|
if not self._client:
|
||||||
|
self._setup_client()
|
||||||
|
try:
|
||||||
|
if self.auth_type == "oauth":
|
||||||
|
return self._test_oauth_connection()
|
||||||
|
else:
|
||||||
|
return self._test_regular_connection()
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"Connection failed: {str(e)}",
|
||||||
|
"tools_count": 0,
|
||||||
|
"transport_type": self.transport_type,
|
||||||
|
"auth_type": self.auth_type,
|
||||||
|
"error_type": type(e).__name__,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _test_regular_connection(self) -> Dict:
|
||||||
|
"""Test connection for non-OAuth auth types."""
|
||||||
|
try:
|
||||||
|
self._run_async_operation("ping")
|
||||||
|
ping_success = True
|
||||||
|
except Exception:
|
||||||
|
ping_success = False
|
||||||
|
tools = self.discover_tools()
|
||||||
|
|
||||||
|
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
|
||||||
|
if not ping_success:
|
||||||
|
message += " (Ping not supported, but tool discovery worked)"
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": message,
|
||||||
|
"tools_count": len(tools),
|
||||||
|
"transport_type": self.transport_type,
|
||||||
|
"auth_type": self.auth_type,
|
||||||
|
"ping_supported": ping_success,
|
||||||
|
"tools": [tool.get("name", "unknown") for tool in tools],
|
||||||
|
}
|
||||||
|
|
||||||
|
def _test_oauth_connection(self) -> Dict:
|
||||||
|
"""Test connection for OAuth auth type with proper async handling."""
|
||||||
|
try:
|
||||||
|
task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
|
||||||
|
if not task:
|
||||||
|
raise Exception("Failed to start OAuth authentication")
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"requires_oauth": True,
|
||||||
|
"task_id": task.id,
|
||||||
|
"status": "pending",
|
||||||
|
"message": "OAuth flow started",
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"message": f"OAuth connection failed: {str(e)}",
|
||||||
|
"tools_count": 0,
|
||||||
|
"transport_type": self.transport_type,
|
||||||
|
"auth_type": self.auth_type,
|
||||||
|
"error_type": type(e).__name__,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_actions_metadata(self) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Get metadata for all available actions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of action metadata dictionaries
|
||||||
|
"""
|
||||||
|
actions = []
|
||||||
|
for tool in self.available_tools:
|
||||||
|
input_schema = (
|
||||||
|
tool.get("inputSchema")
|
||||||
|
or tool.get("input_schema")
|
||||||
|
or tool.get("schema")
|
||||||
|
or tool.get("parameters")
|
||||||
|
)
|
||||||
|
|
||||||
|
parameters_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
if input_schema:
|
||||||
|
if isinstance(input_schema, dict):
|
||||||
|
if "properties" in input_schema:
|
||||||
|
parameters_schema = {
|
||||||
|
"type": input_schema.get("type", "object"),
|
||||||
|
"properties": input_schema.get("properties", {}),
|
||||||
|
"required": input_schema.get("required", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
for key in ["additionalProperties", "description"]:
|
||||||
|
if key in input_schema:
|
||||||
|
parameters_schema[key] = input_schema[key]
|
||||||
|
else:
|
||||||
|
parameters_schema["properties"] = input_schema
|
||||||
|
action = {
|
||||||
|
"name": tool.get("name", ""),
|
||||||
|
"description": tool.get("description", ""),
|
||||||
|
"parameters": parameters_schema,
|
||||||
|
}
|
||||||
|
actions.append(action)
|
||||||
|
return actions
|
||||||
|
|
||||||
|
def get_config_requirements(self) -> Dict:
|
||||||
|
"""Get configuration requirements for the MCP tool."""
|
||||||
|
return {
|
||||||
|
"server_url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)",
|
||||||
|
"required": True,
|
||||||
|
},
|
||||||
|
"transport_type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Transport type for connection",
|
||||||
|
"enum": ["auto", "sse", "http", "stdio"],
|
||||||
|
"default": "auto",
|
||||||
|
"required": False,
|
||||||
|
"help": {
|
||||||
|
"auto": "Automatically detect best transport",
|
||||||
|
"sse": "Server-Sent Events (for real-time streaming)",
|
||||||
|
"http": "HTTP streaming (recommended for production)",
|
||||||
|
"stdio": "Standard I/O (for local servers)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"auth_type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Authentication type",
|
||||||
|
"enum": ["none", "bearer", "oauth", "api_key", "basic"],
|
||||||
|
"default": "none",
|
||||||
|
"required": True,
|
||||||
|
"help": {
|
||||||
|
"none": "No authentication",
|
||||||
|
"bearer": "Bearer token authentication",
|
||||||
|
"oauth": "OAuth 2.1 authentication (with frontend integration)",
|
||||||
|
"api_key": "API key authentication",
|
||||||
|
"basic": "Basic authentication",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"auth_credentials": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Authentication credentials (varies by auth_type)",
|
||||||
|
"required": False,
|
||||||
|
"properties": {
|
||||||
|
"bearer_token": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Bearer token for bearer auth",
|
||||||
|
},
|
||||||
|
"access_token": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Access token for OAuth (if pre-obtained)",
|
||||||
|
},
|
||||||
|
"api_key": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "API key for api_key auth",
|
||||||
|
},
|
||||||
|
"api_key_header": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Header name for API key (default: X-API-Key)",
|
||||||
|
},
|
||||||
|
"username": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Username for basic auth",
|
||||||
|
},
|
||||||
|
"password": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Password for basic auth",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"oauth_scopes": {
|
||||||
|
"type": "array",
|
||||||
|
"description": "OAuth scopes to request (for oauth auth_type)",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"required": False,
|
||||||
|
"default": [],
|
||||||
|
},
|
||||||
|
"oauth_client_name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Client name for OAuth registration (for oauth auth_type)",
|
||||||
|
"default": "DocsGPT-MCP",
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
"headers": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Custom headers to send with requests",
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
"timeout": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Request timeout in seconds",
|
||||||
|
"default": 30,
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": 300,
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Command to run for STDIO transport (e.g., 'python')",
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
"args": {
|
||||||
|
"type": "array",
|
||||||
|
"description": "Arguments for STDIO command",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DocsGPTOAuth(OAuthClientProvider):
|
||||||
|
"""
|
||||||
|
Custom OAuth handler for DocsGPT that uses frontend redirect instead of browser.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_url: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
redis_client: Redis | None = None,
|
||||||
|
redis_prefix: str = "mcp_oauth:",
|
||||||
|
task_id: str = None,
|
||||||
|
scopes: str | list[str] | None = None,
|
||||||
|
client_name: str = "DocsGPT-MCP",
|
||||||
|
user_id=None,
|
||||||
|
db=None,
|
||||||
|
additional_client_metadata: dict[str, Any] | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize custom OAuth client provider for DocsGPT.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_url: Full URL to the MCP endpoint
|
||||||
|
redirect_uri: Custom redirect URI for DocsGPT frontend
|
||||||
|
redis_client: Redis client for storing auth state
|
||||||
|
redis_prefix: Prefix for Redis keys
|
||||||
|
task_id: Task ID for tracking auth status
|
||||||
|
scopes: OAuth scopes to request
|
||||||
|
client_name: Name for this client during registration
|
||||||
|
user_id: User ID for token storage
|
||||||
|
db: Database instance for token storage
|
||||||
|
additional_client_metadata: Extra fields for OAuthClientMetadata
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.redirect_uri = redirect_uri
|
||||||
|
self.redis_client = redis_client
|
||||||
|
self.redis_prefix = redis_prefix
|
||||||
|
self.task_id = task_id
|
||||||
|
self.user_id = user_id
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
parsed_url = urlparse(mcp_url)
|
||||||
|
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||||
|
|
||||||
|
if isinstance(scopes, list):
|
||||||
|
scopes = " ".join(scopes)
|
||||||
|
client_metadata = OAuthClientMetadata(
|
||||||
|
client_name=client_name,
|
||||||
|
redirect_uris=[AnyHttpUrl(redirect_uri)],
|
||||||
|
grant_types=["authorization_code", "refresh_token"],
|
||||||
|
response_types=["code"],
|
||||||
|
scope=scopes,
|
||||||
|
**(additional_client_metadata or {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
storage = DBTokenStorage(
|
||||||
|
server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
server_url=self.server_base_url,
|
||||||
|
client_metadata=client_metadata,
|
||||||
|
storage=storage,
|
||||||
|
redirect_handler=self.redirect_handler,
|
||||||
|
callback_handler=self.callback_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.auth_url = None
|
||||||
|
self.extracted_state = None
|
||||||
|
|
||||||
|
def _process_auth_url(self, authorization_url: str) -> tuple[str, str]:
|
||||||
|
"""Process authorization URL to extract state"""
|
||||||
|
try:
|
||||||
|
parsed_url = urlparse(authorization_url)
|
||||||
|
query_params = parse_qs(parsed_url.query)
|
||||||
|
|
||||||
|
state_params = query_params.get("state", [])
|
||||||
|
if state_params:
|
||||||
|
state = state_params[0]
|
||||||
|
else:
|
||||||
|
raise ValueError("No state in auth URL")
|
||||||
|
return authorization_url, state
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Failed to process auth URL: {e}")
|
||||||
|
|
||||||
|
async def redirect_handler(self, authorization_url: str) -> None:
|
||||||
|
"""Store auth URL and state in Redis for frontend to use."""
|
||||||
|
auth_url, state = self._process_auth_url(authorization_url)
|
||||||
|
logging.info(
|
||||||
|
"[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
|
||||||
|
)
|
||||||
|
self.auth_url = auth_url
|
||||||
|
self.extracted_state = state
|
||||||
|
|
||||||
|
if self.redis_client and self.extracted_state:
|
||||||
|
key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
|
||||||
|
self.redis_client.setex(key, 600, auth_url)
|
||||||
|
logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)
|
||||||
|
|
||||||
|
if self.task_id:
|
||||||
|
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||||
|
status_data = {
|
||||||
|
"status": "requires_redirect",
|
||||||
|
"message": "OAuth authorization required",
|
||||||
|
"authorization_url": self.auth_url,
|
||||||
|
"state": self.extracted_state,
|
||||||
|
"requires_oauth": True,
|
||||||
|
"task_id": self.task_id,
|
||||||
|
}
|
||||||
|
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||||
|
|
||||||
|
async def callback_handler(self) -> tuple[str, str | None]:
|
||||||
|
"""Wait for auth code from Redis using the state value."""
|
||||||
|
if not self.redis_client or not self.extracted_state:
|
||||||
|
raise Exception("Redis client or state not configured for OAuth")
|
||||||
|
poll_interval = 1
|
||||||
|
max_wait_time = 300
|
||||||
|
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
|
||||||
|
|
||||||
|
if self.task_id:
|
||||||
|
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||||
|
status_data = {
|
||||||
|
"status": "awaiting_callback",
|
||||||
|
"message": "Waiting for OAuth callback...",
|
||||||
|
"authorization_url": self.auth_url,
|
||||||
|
"state": self.extracted_state,
|
||||||
|
"requires_oauth": True,
|
||||||
|
"task_id": self.task_id,
|
||||||
|
}
|
||||||
|
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||||
|
start_time = time.time()
|
||||||
|
while time.time() - start_time < max_wait_time:
|
||||||
|
code_data = self.redis_client.get(code_key)
|
||||||
|
if code_data:
|
||||||
|
code = code_data.decode()
|
||||||
|
returned_state = self.extracted_state
|
||||||
|
|
||||||
|
self.redis_client.delete(code_key)
|
||||||
|
self.redis_client.delete(
|
||||||
|
f"{self.redis_prefix}auth_url:{self.extracted_state}"
|
||||||
|
)
|
||||||
|
self.redis_client.delete(
|
||||||
|
f"{self.redis_prefix}state:{self.extracted_state}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.task_id:
|
||||||
|
status_data = {
|
||||||
|
"status": "callback_received",
|
||||||
|
"message": "OAuth callback received, completing authentication...",
|
||||||
|
"task_id": self.task_id,
|
||||||
|
}
|
||||||
|
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||||
|
return code, returned_state
|
||||||
|
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
|
||||||
|
error_data = self.redis_client.get(error_key)
|
||||||
|
if error_data:
|
||||||
|
error_msg = error_data.decode()
|
||||||
|
self.redis_client.delete(error_key)
|
||||||
|
self.redis_client.delete(
|
||||||
|
f"{self.redis_prefix}auth_url:{self.extracted_state}"
|
||||||
|
)
|
||||||
|
self.redis_client.delete(
|
||||||
|
f"{self.redis_prefix}state:{self.extracted_state}"
|
||||||
|
)
|
||||||
|
raise Exception(f"OAuth error: {error_msg}")
|
||||||
|
await asyncio.sleep(poll_interval)
|
||||||
|
self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
|
||||||
|
self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
|
||||||
|
raise Exception("OAuth callback timeout: no code received within 5 minutes")
|
||||||
|
|
||||||
|
|
||||||
|
class DBTokenStorage(TokenStorage):
|
||||||
|
def __init__(self, server_url: str, user_id: str, db_client):
|
||||||
|
self.server_url = server_url
|
||||||
|
self.user_id = user_id
|
||||||
|
self.db_client = db_client
|
||||||
|
self.collection = db_client["connector_sessions"]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_base_url(url: str) -> str:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
return f"{parsed.scheme}://{parsed.netloc}"
|
||||||
|
|
||||||
|
def get_db_key(self) -> dict:
|
||||||
|
return {
|
||||||
|
"server_url": self.get_base_url(self.server_url),
|
||||||
|
"user_id": self.user_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_tokens(self) -> OAuthToken | None:
|
||||||
|
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
|
||||||
|
if not doc or "tokens" not in doc:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
tokens = OAuthToken.model_validate(doc["tokens"])
|
||||||
|
return tokens
|
||||||
|
except ValidationError as e:
|
||||||
|
logging.error(f"Could not load tokens: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def set_tokens(self, tokens: OAuthToken) -> None:
|
||||||
|
await asyncio.to_thread(
|
||||||
|
self.collection.update_one,
|
||||||
|
self.get_db_key(),
|
||||||
|
{"$set": {"tokens": tokens.model_dump()}},
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")
|
||||||
|
|
||||||
|
async def get_client_info(self) -> OAuthClientInformationFull | None:
|
||||||
|
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
|
||||||
|
if not doc or "client_info" not in doc:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
|
||||||
|
tokens = await self.get_tokens()
|
||||||
|
if tokens is None:
|
||||||
|
logging.debug(
|
||||||
|
"No tokens found, clearing client info to force fresh registration."
|
||||||
|
)
|
||||||
|
await asyncio.to_thread(
|
||||||
|
self.collection.update_one,
|
||||||
|
self.get_db_key(),
|
||||||
|
{"$unset": {"client_info": ""}},
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
return client_info
|
||||||
|
except ValidationError as e:
|
||||||
|
logging.error(f"Could not load client info: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _serialize_client_info(self, info: dict) -> dict:
|
||||||
|
if "redirect_uris" in info and isinstance(info["redirect_uris"], list):
|
||||||
|
info["redirect_uris"] = [str(u) for u in info["redirect_uris"]]
|
||||||
|
return info
|
||||||
|
|
||||||
|
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
|
||||||
|
serialized_info = self._serialize_client_info(client_info.model_dump())
|
||||||
|
await asyncio.to_thread(
|
||||||
|
self.collection.update_one,
|
||||||
|
self.get_db_key(),
|
||||||
|
{"$set": {"client_info": serialized_info}},
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")
|
||||||
|
|
||||||
|
async def clear(self) -> None:
|
||||||
|
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
|
||||||
|
logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def clear_all(cls, db_client) -> None:
|
||||||
|
collection = db_client["connector_sessions"]
|
||||||
|
await asyncio.to_thread(collection.delete_many, {})
|
||||||
|
logging.info("Cleared all OAuth client cache data.")
|
||||||
|
|
||||||
|
|
||||||
|
class MCPOAuthManager:
|
||||||
|
"""Manager for handling MCP OAuth callbacks."""
|
||||||
|
|
||||||
|
def __init__(self, redis_client: Redis | None, redis_prefix: str = "mcp_oauth:"):
|
||||||
|
self.redis_client = redis_client
|
||||||
|
self.redis_prefix = redis_prefix
|
||||||
|
|
||||||
|
def handle_oauth_callback(
|
||||||
|
self, state: str, code: str, error: Optional[str] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Handle OAuth callback from provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The state parameter from OAuth callback
|
||||||
|
code: The authorization code from OAuth callback
|
||||||
|
error: Error message if OAuth failed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not self.redis_client or not state:
|
||||||
|
raise Exception("Redis client or state not provided")
|
||||||
|
if error:
|
||||||
|
error_key = f"{self.redis_prefix}error:{state}"
|
||||||
|
self.redis_client.setex(error_key, 300, error)
|
||||||
|
raise Exception(f"OAuth error received: {error}")
|
||||||
|
code_key = f"{self.redis_prefix}code:{state}"
|
||||||
|
self.redis_client.setex(code_key, 300, code)
|
||||||
|
|
||||||
|
state_key = f"{self.redis_prefix}state:{state}"
|
||||||
|
self.redis_client.setex(state_key, 300, "completed")
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error handling OAuth callback: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get current status of OAuth flow using provided task_id."""
|
||||||
|
if not task_id:
|
||||||
|
return {"status": "not_started", "message": "OAuth flow not started"}
|
||||||
|
return mcp_oauth_status_task(task_id)
|
||||||
546
application/agents/tools/memory.py
Normal file
546
application/agents/tools/memory.py
Normal file
@@ -0,0 +1,546 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from .base import Tool
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryTool(Tool):
|
||||||
|
"""Memory
|
||||||
|
|
||||||
|
Stores and retrieves information across conversations through a memory file directory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||||
|
"""Initialize the tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_config: Optional tool configuration. Should include:
|
||||||
|
- tool_id: Unique identifier for this memory tool instance (from user_tools._id)
|
||||||
|
This ensures each user's tool configuration has isolated memories
|
||||||
|
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||||
|
"""
|
||||||
|
self.user_id: Optional[str] = user_id
|
||||||
|
|
||||||
|
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||||
|
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||||
|
if tool_config and "tool_id" in tool_config:
|
||||||
|
self.tool_id = tool_config["tool_id"]
|
||||||
|
elif user_id:
|
||||||
|
# Fallback for backward compatibility or testing
|
||||||
|
self.tool_id = f"default_{user_id}"
|
||||||
|
else:
|
||||||
|
# Last resort fallback (shouldn't happen in normal use)
|
||||||
|
self.tool_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||||
|
self.collection = db["memories"]
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Action implementations
|
||||||
|
# -----------------------------
|
||||||
|
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||||
|
"""Execute an action by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_name: One of view, create, str_replace, insert, delete, rename.
|
||||||
|
**kwargs: Parameters for the action.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A human-readable string result.
|
||||||
|
"""
|
||||||
|
if not self.user_id:
|
||||||
|
return "Error: MemoryTool requires a valid user_id."
|
||||||
|
|
||||||
|
if action_name == "view":
|
||||||
|
return self._view(
|
||||||
|
kwargs.get("path", "/"),
|
||||||
|
kwargs.get("view_range")
|
||||||
|
)
|
||||||
|
|
||||||
|
if action_name == "create":
|
||||||
|
return self._create(
|
||||||
|
kwargs.get("path", ""),
|
||||||
|
kwargs.get("file_text", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
if action_name == "str_replace":
|
||||||
|
return self._str_replace(
|
||||||
|
kwargs.get("path", ""),
|
||||||
|
kwargs.get("old_str", ""),
|
||||||
|
kwargs.get("new_str", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
if action_name == "insert":
|
||||||
|
return self._insert(
|
||||||
|
kwargs.get("path", ""),
|
||||||
|
kwargs.get("insert_line", 1),
|
||||||
|
kwargs.get("insert_text", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
if action_name == "delete":
|
||||||
|
return self._delete(kwargs.get("path", ""))
|
||||||
|
|
||||||
|
if action_name == "rename":
|
||||||
|
return self._rename(
|
||||||
|
kwargs.get("old_path", ""),
|
||||||
|
kwargs.get("new_path", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"Unknown action: {action_name}"
|
||||||
|
|
||||||
|
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "view",
|
||||||
|
"description": "Shows directory contents or file contents with optional line ranges.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Path to file or directory (e.g., /notes.txt or /project/ or /)."
|
||||||
|
},
|
||||||
|
"view_range": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "integer"},
|
||||||
|
"description": "Optional [start_line, end_line] to view specific lines (1-indexed)."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path"]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "create",
|
||||||
|
"description": "Create or overwrite a file.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "File path to create (e.g., /notes.txt or /project/task.txt)."
|
||||||
|
},
|
||||||
|
"file_text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Content to write to the file."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path", "file_text"]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "str_replace",
|
||||||
|
"description": "Replace text in a file.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "File path (e.g., /notes.txt)."
|
||||||
|
},
|
||||||
|
"old_str": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "String to find."
|
||||||
|
},
|
||||||
|
"new_str": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "String to replace with."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path", "old_str", "new_str"]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "insert",
|
||||||
|
"description": "Insert text at a specific line in a file.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "File path (e.g., /notes.txt)."
|
||||||
|
},
|
||||||
|
"insert_line": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Line number to insert at (1-indexed)."
|
||||||
|
},
|
||||||
|
"insert_text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Text to insert."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path", "insert_line", "insert_text"]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "delete",
|
||||||
|
"description": "Delete a file or directory.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Path to delete (e.g., /notes.txt or /project/)."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path"]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "rename",
|
||||||
|
"description": "Rename or move a file/directory.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"old_path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Current path (e.g., /old.txt)."
|
||||||
|
},
|
||||||
|
"new_path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "New path (e.g., /new.txt)."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["old_path", "new_path"]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_config_requirements(self) -> Dict[str, Any]:
|
||||||
|
"""Return configuration requirements."""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Path validation
|
||||||
|
# -----------------------------
|
||||||
|
def _validate_path(self, path: str) -> Optional[str]:
|
||||||
|
"""Validate and normalize path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: User-provided path.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized path or None if invalid.
|
||||||
|
"""
|
||||||
|
if not path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Remove any leading/trailing whitespace
|
||||||
|
path = path.strip()
|
||||||
|
|
||||||
|
# Preserve whether path ends with / (indicates directory)
|
||||||
|
is_directory = path.endswith("/")
|
||||||
|
|
||||||
|
# Ensure path starts with / for consistency
|
||||||
|
if not path.startswith("/"):
|
||||||
|
path = "/" + path
|
||||||
|
|
||||||
|
# Check for directory traversal patterns
|
||||||
|
if ".." in path or path.count("//") > 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Normalize the path
|
||||||
|
try:
|
||||||
|
# Convert to Path object and resolve to canonical form
|
||||||
|
normalized = str(Path(path).as_posix())
|
||||||
|
|
||||||
|
# Ensure it still starts with /
|
||||||
|
if not normalized.startswith("/"):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Preserve trailing slash for directories
|
||||||
|
if is_directory and not normalized.endswith("/") and normalized != "/":
|
||||||
|
normalized = normalized + "/"
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Internal helpers
|
||||||
|
# -----------------------------
|
||||||
|
def _view(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||||
|
"""View directory contents or file contents."""
|
||||||
|
validated_path = self._validate_path(path)
|
||||||
|
if not validated_path:
|
||||||
|
return "Error: Invalid path."
|
||||||
|
|
||||||
|
# Check if viewing directory (ends with / or is root)
|
||||||
|
if validated_path == "/" or validated_path.endswith("/"):
|
||||||
|
return self._view_directory(validated_path)
|
||||||
|
|
||||||
|
# Otherwise view file
|
||||||
|
return self._view_file(validated_path, view_range)
|
||||||
|
|
||||||
|
def _view_directory(self, path: str) -> str:
|
||||||
|
"""List files in a directory."""
|
||||||
|
# Ensure path ends with / for proper prefix matching
|
||||||
|
search_path = path if path.endswith("/") else path + "/"
|
||||||
|
|
||||||
|
# Find all files that start with this directory path
|
||||||
|
query = {
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"tool_id": self.tool_id,
|
||||||
|
"path": {"$regex": f"^{re.escape(search_path)}"}
|
||||||
|
}
|
||||||
|
|
||||||
|
docs = list(self.collection.find(query, {"path": 1}))
|
||||||
|
|
||||||
|
if not docs:
|
||||||
|
return f"Directory: {path}\n(empty)"
|
||||||
|
|
||||||
|
# Extract filenames relative to the directory
|
||||||
|
files = []
|
||||||
|
for doc in docs:
|
||||||
|
file_path = doc["path"]
|
||||||
|
# Remove the directory prefix
|
||||||
|
if file_path.startswith(search_path):
|
||||||
|
relative = file_path[len(search_path):]
|
||||||
|
if relative:
|
||||||
|
files.append(relative)
|
||||||
|
|
||||||
|
files.sort()
|
||||||
|
file_list = "\n".join(f"- {f}" for f in files)
|
||||||
|
return f"Directory: {path}\n{file_list}"
|
||||||
|
|
||||||
|
def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||||
|
"""View file contents with optional line range."""
|
||||||
|
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
|
||||||
|
|
||||||
|
if not doc or not doc.get("content"):
|
||||||
|
return f"Error: File not found: {path}"
|
||||||
|
|
||||||
|
content = str(doc["content"])
|
||||||
|
|
||||||
|
# Apply view_range if specified
|
||||||
|
if view_range and len(view_range) == 2:
|
||||||
|
lines = content.split("\n")
|
||||||
|
start, end = view_range
|
||||||
|
# Convert to 0-indexed
|
||||||
|
start_idx = max(0, start - 1)
|
||||||
|
end_idx = min(len(lines), end)
|
||||||
|
|
||||||
|
if start_idx >= len(lines):
|
||||||
|
return f"Error: Line range out of bounds. File has {len(lines)} lines."
|
||||||
|
|
||||||
|
selected_lines = lines[start_idx:end_idx]
|
||||||
|
# Add line numbers (enumerate with 1-based start)
|
||||||
|
numbered_lines = [f"{i}: {line}" for i, line in enumerate(selected_lines, start=start)]
|
||||||
|
return "\n".join(numbered_lines)
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
def _create(self, path: str, file_text: str) -> str:
|
||||||
|
"""Create or overwrite a file."""
|
||||||
|
validated_path = self._validate_path(path)
|
||||||
|
if not validated_path:
|
||||||
|
return "Error: Invalid path."
|
||||||
|
|
||||||
|
if validated_path == "/" or validated_path.endswith("/"):
|
||||||
|
return "Error: Cannot create a file at directory path."
|
||||||
|
|
||||||
|
self.collection.update_one(
|
||||||
|
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||||
|
{
|
||||||
|
"$set": {
|
||||||
|
"content": file_text,
|
||||||
|
"updated_at": datetime.now()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
upsert=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"File created: {validated_path}"
|
||||||
|
|
||||||
|
def _str_replace(self, path: str, old_str: str, new_str: str) -> str:
|
||||||
|
"""Replace text in a file."""
|
||||||
|
validated_path = self._validate_path(path)
|
||||||
|
if not validated_path:
|
||||||
|
return "Error: Invalid path."
|
||||||
|
|
||||||
|
if not old_str:
|
||||||
|
return "Error: old_str is required."
|
||||||
|
|
||||||
|
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
|
||||||
|
|
||||||
|
if not doc or not doc.get("content"):
|
||||||
|
return f"Error: File not found: {validated_path}"
|
||||||
|
|
||||||
|
current_content = str(doc["content"])
|
||||||
|
|
||||||
|
# Check if old_str exists (case-insensitive)
|
||||||
|
if old_str.lower() not in current_content.lower():
|
||||||
|
return f"Error: String '{old_str}' not found in file."
|
||||||
|
|
||||||
|
# Replace the string (case-insensitive)
|
||||||
|
import re as regex_module
|
||||||
|
updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
|
||||||
|
|
||||||
|
self.collection.update_one(
|
||||||
|
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||||
|
{
|
||||||
|
"$set": {
|
||||||
|
"content": updated_content,
|
||||||
|
"updated_at": datetime.now()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"File updated: {validated_path}"
|
||||||
|
|
||||||
|
def _insert(self, path: str, insert_line: int, insert_text: str) -> str:
|
||||||
|
"""Insert text at a specific line."""
|
||||||
|
validated_path = self._validate_path(path)
|
||||||
|
if not validated_path:
|
||||||
|
return "Error: Invalid path."
|
||||||
|
|
||||||
|
if not insert_text:
|
||||||
|
return "Error: insert_text is required."
|
||||||
|
|
||||||
|
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
|
||||||
|
|
||||||
|
if not doc or not doc.get("content"):
|
||||||
|
return f"Error: File not found: {validated_path}"
|
||||||
|
|
||||||
|
current_content = str(doc["content"])
|
||||||
|
lines = current_content.split("\n")
|
||||||
|
|
||||||
|
# Convert to 0-indexed
|
||||||
|
index = insert_line - 1
|
||||||
|
if index < 0 or index > len(lines):
|
||||||
|
return f"Error: Invalid line number. File has {len(lines)} lines."
|
||||||
|
|
||||||
|
lines.insert(index, insert_text)
|
||||||
|
updated_content = "\n".join(lines)
|
||||||
|
|
||||||
|
self.collection.update_one(
|
||||||
|
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||||
|
{
|
||||||
|
"$set": {
|
||||||
|
"content": updated_content,
|
||||||
|
"updated_at": datetime.now()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"Text inserted at line {insert_line} in {validated_path}"
|
||||||
|
|
||||||
|
def _delete(self, path: str) -> str:
|
||||||
|
"""Delete a file or directory."""
|
||||||
|
validated_path = self._validate_path(path)
|
||||||
|
if not validated_path:
|
||||||
|
return "Error: Invalid path."
|
||||||
|
|
||||||
|
if validated_path == "/":
|
||||||
|
# Delete all files for this user and tool
|
||||||
|
result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||||
|
return f"Deleted {result.deleted_count} file(s) from memory."
|
||||||
|
|
||||||
|
# Check if it's a directory (ends with /)
|
||||||
|
if validated_path.endswith("/"):
|
||||||
|
# Delete all files in directory
|
||||||
|
result = self.collection.delete_many({
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"tool_id": self.tool_id,
|
||||||
|
"path": {"$regex": f"^{re.escape(validated_path)}"}
|
||||||
|
})
|
||||||
|
return f"Deleted directory and {result.deleted_count} file(s)."
|
||||||
|
|
||||||
|
# Try to delete as directory first (without trailing slash)
|
||||||
|
# Check if any files start with this path + /
|
||||||
|
search_path = validated_path + "/"
|
||||||
|
directory_result = self.collection.delete_many({
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"tool_id": self.tool_id,
|
||||||
|
"path": {"$regex": f"^{re.escape(search_path)}"}
|
||||||
|
})
|
||||||
|
|
||||||
|
if directory_result.deleted_count > 0:
|
||||||
|
return f"Deleted directory and {directory_result.deleted_count} file(s)."
|
||||||
|
|
||||||
|
# Delete single file
|
||||||
|
result = self.collection.delete_one({
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"tool_id": self.tool_id,
|
||||||
|
"path": validated_path
|
||||||
|
})
|
||||||
|
|
||||||
|
if result.deleted_count:
|
||||||
|
return f"Deleted: {validated_path}"
|
||||||
|
return f"Error: File not found: {validated_path}"
|
||||||
|
|
||||||
|
def _rename(self, old_path: str, new_path: str) -> str:
|
||||||
|
"""Rename or move a file/directory."""
|
||||||
|
validated_old = self._validate_path(old_path)
|
||||||
|
validated_new = self._validate_path(new_path)
|
||||||
|
|
||||||
|
if not validated_old or not validated_new:
|
||||||
|
return "Error: Invalid path."
|
||||||
|
|
||||||
|
if validated_old == "/" or validated_new == "/":
|
||||||
|
return "Error: Cannot rename root directory."
|
||||||
|
|
||||||
|
# Check if renaming a directory
|
||||||
|
if validated_old.endswith("/"):
|
||||||
|
# Ensure validated_new also ends with / for proper path replacement
|
||||||
|
if not validated_new.endswith("/"):
|
||||||
|
validated_new = validated_new + "/"
|
||||||
|
|
||||||
|
# Find all files in the old directory
|
||||||
|
docs = list(self.collection.find({
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"tool_id": self.tool_id,
|
||||||
|
"path": {"$regex": f"^{re.escape(validated_old)}"}
|
||||||
|
}))
|
||||||
|
|
||||||
|
if not docs:
|
||||||
|
return f"Error: Directory not found: {validated_old}"
|
||||||
|
|
||||||
|
# Update paths for all files
|
||||||
|
for doc in docs:
|
||||||
|
old_file_path = doc["path"]
|
||||||
|
new_file_path = old_file_path.replace(validated_old, validated_new, 1)
|
||||||
|
|
||||||
|
self.collection.update_one(
|
||||||
|
{"_id": doc["_id"]},
|
||||||
|
{"$set": {"path": new_file_path, "updated_at": datetime.now()}}
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"
|
||||||
|
|
||||||
|
# Rename single file
|
||||||
|
doc = self.collection.find_one({
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"tool_id": self.tool_id,
|
||||||
|
"path": validated_old
|
||||||
|
})
|
||||||
|
|
||||||
|
if not doc:
|
||||||
|
return f"Error: File not found: {validated_old}"
|
||||||
|
|
||||||
|
# Check if new path already exists
|
||||||
|
existing = self.collection.find_one({
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"tool_id": self.tool_id,
|
||||||
|
"path": validated_new
|
||||||
|
})
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
return f"Error: File already exists at {validated_new}"
|
||||||
|
|
||||||
|
# Delete the old document and create a new one with the new path
|
||||||
|
self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old})
|
||||||
|
self.collection.insert_one({
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"tool_id": self.tool_id,
|
||||||
|
"path": validated_new,
|
||||||
|
"content": doc.get("content", ""),
|
||||||
|
"updated_at": datetime.now()
|
||||||
|
})
|
||||||
|
|
||||||
|
return f"Renamed: {validated_old} -> {validated_new}"
|
||||||
199
application/agents/tools/notes.py
Normal file
199
application/agents/tools/notes.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from .base import Tool
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class NotesTool(Tool):
|
||||||
|
"""Notepad
|
||||||
|
|
||||||
|
Single note. Supports viewing, overwriting, string replacement.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||||
|
"""Initialize the tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_config: Optional tool configuration. Should include:
|
||||||
|
- tool_id: Unique identifier for this notes tool instance (from user_tools._id)
|
||||||
|
This ensures each user's tool configuration has isolated notes
|
||||||
|
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||||
|
"""
|
||||||
|
self.user_id: Optional[str] = user_id
|
||||||
|
|
||||||
|
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||||
|
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||||
|
if tool_config and "tool_id" in tool_config:
|
||||||
|
self.tool_id = tool_config["tool_id"]
|
||||||
|
elif user_id:
|
||||||
|
# Fallback for backward compatibility or testing
|
||||||
|
self.tool_id = f"default_{user_id}"
|
||||||
|
else:
|
||||||
|
# Last resort fallback (shouldn't happen in normal use)
|
||||||
|
self.tool_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||||
|
self.collection = db["notes"]
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Action implementations
|
||||||
|
# -----------------------------
|
||||||
|
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||||
|
"""Execute an action by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_name: One of view, overwrite, str_replace, insert, delete.
|
||||||
|
**kwargs: Parameters for the action.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A human-readable string result.
|
||||||
|
"""
|
||||||
|
if not self.user_id:
|
||||||
|
return "Error: NotesTool requires a valid user_id."
|
||||||
|
|
||||||
|
if action_name == "view":
|
||||||
|
return self._get_note()
|
||||||
|
|
||||||
|
if action_name == "overwrite":
|
||||||
|
return self._overwrite_note(kwargs.get("text", ""))
|
||||||
|
|
||||||
|
if action_name == "str_replace":
|
||||||
|
return self._str_replace(kwargs.get("old_str", ""), kwargs.get("new_str", ""))
|
||||||
|
|
||||||
|
if action_name == "insert":
|
||||||
|
return self._insert(kwargs.get("line_number", 1), kwargs.get("text", ""))
|
||||||
|
|
||||||
|
if action_name == "delete":
|
||||||
|
return self._delete_note()
|
||||||
|
|
||||||
|
return f"Unknown action: {action_name}"
|
||||||
|
|
||||||
|
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "view",
|
||||||
|
"description": "Retrieve the user's note.",
|
||||||
|
"parameters": {"type": "object", "properties": {}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "overwrite",
|
||||||
|
"description": "Replace the entire note content (creates if doesn't exist).",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": {"type": "string", "description": "New note content."}
|
||||||
|
},
|
||||||
|
"required": ["text"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "str_replace",
|
||||||
|
"description": "Replace occurrences of old_str with new_str in the note.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"old_str": {"type": "string", "description": "String to find."},
|
||||||
|
"new_str": {"type": "string", "description": "String to replace with."}
|
||||||
|
},
|
||||||
|
"required": ["old_str", "new_str"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "insert",
|
||||||
|
"description": "Insert text at the specified line number (1-indexed).",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"line_number": {"type": "integer", "description": "Line number to insert at (1-indexed)."},
|
||||||
|
"text": {"type": "string", "description": "Text to insert."}
|
||||||
|
},
|
||||||
|
"required": ["line_number", "text"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "delete",
|
||||||
|
"description": "Delete the user's note.",
|
||||||
|
"parameters": {"type": "object", "properties": {}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_config_requirements(self) -> Dict[str, Any]:
|
||||||
|
"""Return configuration requirements (none for now)."""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Internal helpers (single-note)
|
||||||
|
# -----------------------------
|
||||||
|
def _get_note(self) -> str:
|
||||||
|
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||||
|
if not doc or not doc.get("note"):
|
||||||
|
return "No note found."
|
||||||
|
return str(doc["note"])
|
||||||
|
|
||||||
|
def _overwrite_note(self, content: str) -> str:
|
||||||
|
content = (content or "").strip()
|
||||||
|
if not content:
|
||||||
|
return "Note content required."
|
||||||
|
self.collection.update_one(
|
||||||
|
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||||
|
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
|
||||||
|
upsert=True, # ✅ create if missing
|
||||||
|
)
|
||||||
|
return "Note saved."
|
||||||
|
|
||||||
|
def _str_replace(self, old_str: str, new_str: str) -> str:
|
||||||
|
if not old_str:
|
||||||
|
return "old_str is required."
|
||||||
|
|
||||||
|
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||||
|
if not doc or not doc.get("note"):
|
||||||
|
return "No note found."
|
||||||
|
|
||||||
|
current_note = str(doc["note"])
|
||||||
|
|
||||||
|
# Case-insensitive search
|
||||||
|
if old_str.lower() not in current_note.lower():
|
||||||
|
return f"String '{old_str}' not found in note."
|
||||||
|
|
||||||
|
# Case-insensitive replacement
|
||||||
|
import re
|
||||||
|
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
|
||||||
|
|
||||||
|
self.collection.update_one(
|
||||||
|
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||||
|
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||||
|
)
|
||||||
|
return "Note updated."
|
||||||
|
|
||||||
|
def _insert(self, line_number: int, text: str) -> str:
|
||||||
|
if not text:
|
||||||
|
return "Text is required."
|
||||||
|
|
||||||
|
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||||
|
if not doc or not doc.get("note"):
|
||||||
|
return "No note found."
|
||||||
|
|
||||||
|
current_note = str(doc["note"])
|
||||||
|
lines = current_note.split("\n")
|
||||||
|
|
||||||
|
# Convert to 0-indexed and validate
|
||||||
|
index = line_number - 1
|
||||||
|
if index < 0 or index > len(lines):
|
||||||
|
return f"Invalid line number. Note has {len(lines)} lines."
|
||||||
|
|
||||||
|
lines.insert(index, text)
|
||||||
|
updated_note = "\n".join(lines)
|
||||||
|
|
||||||
|
self.collection.update_one(
|
||||||
|
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||||
|
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||||
|
)
|
||||||
|
return "Text inserted."
|
||||||
|
|
||||||
|
def _delete_note(self) -> str:
|
||||||
|
res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||||
|
return "Note deleted." if res.deleted_count else "No note found to delete."
|
||||||
321
application/agents/tools/todo_list.py
Normal file
321
application/agents/tools/todo_list.py
Normal file
@@ -0,0 +1,321 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from .base import Tool
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class TodoListTool(Tool):
|
||||||
|
"""Todo List
|
||||||
|
|
||||||
|
Manages todo items for users. Supports creating, viewing, updating, and deleting todos.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
|
||||||
|
"""Initialize the tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_config: Optional tool configuration. Should include:
|
||||||
|
- tool_id: Unique identifier for this todo list tool instance (from user_tools._id)
|
||||||
|
This ensures each user's tool configuration has isolated todos
|
||||||
|
user_id: The authenticated user's id (should come from decoded_token["sub"]).
|
||||||
|
"""
|
||||||
|
self.user_id: Optional[str] = user_id
|
||||||
|
|
||||||
|
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||||
|
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||||
|
if tool_config and "tool_id" in tool_config:
|
||||||
|
self.tool_id = tool_config["tool_id"]
|
||||||
|
elif user_id:
|
||||||
|
# Fallback for backward compatibility or testing
|
||||||
|
self.tool_id = f"default_{user_id}"
|
||||||
|
else:
|
||||||
|
# Last resort fallback (shouldn't happen in normal use)
|
||||||
|
self.tool_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||||
|
self.collection = db["todos"]
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Action implementations
|
||||||
|
# -----------------------------
|
||||||
|
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||||
|
"""Execute an action by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_name: One of list, create, get, update, complete, delete.
|
||||||
|
**kwargs: Parameters for the action.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A human-readable string result.
|
||||||
|
"""
|
||||||
|
if not self.user_id:
|
||||||
|
return "Error: TodoListTool requires a valid user_id."
|
||||||
|
|
||||||
|
if action_name == "list":
|
||||||
|
return self._list()
|
||||||
|
|
||||||
|
if action_name == "create":
|
||||||
|
return self._create(kwargs.get("title", ""))
|
||||||
|
|
||||||
|
if action_name == "get":
|
||||||
|
return self._get(kwargs.get("todo_id"))
|
||||||
|
|
||||||
|
if action_name == "update":
|
||||||
|
return self._update(
|
||||||
|
kwargs.get("todo_id"),
|
||||||
|
kwargs.get("title", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
if action_name == "complete":
|
||||||
|
return self._complete(kwargs.get("todo_id"))
|
||||||
|
|
||||||
|
if action_name == "delete":
|
||||||
|
return self._delete(kwargs.get("todo_id"))
|
||||||
|
|
||||||
|
return f"Unknown action: {action_name}"
|
||||||
|
|
||||||
|
def get_actions_metadata(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Return JSON metadata describing supported actions for tool schemas."""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "list",
|
||||||
|
"description": "List all todos for the user.",
|
||||||
|
"parameters": {"type": "object", "properties": {}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "create",
|
||||||
|
"description": "Create a new todo item.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"title": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Title of the todo item."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["title"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "get",
|
||||||
|
"description": "Get a specific todo by ID.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"todo_id": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "The ID of the todo to retrieve."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["todo_id"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "update",
|
||||||
|
"description": "Update a todo's title by ID.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"todo_id": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "The ID of the todo to update."
|
||||||
|
},
|
||||||
|
"title": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The new title for the todo."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["todo_id", "title"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "complete",
|
||||||
|
"description": "Mark a todo as completed.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"todo_id": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "The ID of the todo to mark as completed."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["todo_id"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "delete",
|
||||||
|
"description": "Delete a specific todo by ID.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"todo_id": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "The ID of the todo to delete."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["todo_id"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_config_requirements(self) -> Dict[str, Any]:
|
||||||
|
"""Return configuration requirements."""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Internal helpers
|
||||||
|
# -----------------------------
|
||||||
|
def _coerce_todo_id(self, value: Optional[Any]) -> Optional[int]:
|
||||||
|
"""Convert todo identifiers to sequential integers."""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(value, int):
|
||||||
|
return value if value > 0 else None
|
||||||
|
|
||||||
|
if isinstance(value, str):
|
||||||
|
stripped = value.strip()
|
||||||
|
if stripped.isdigit():
|
||||||
|
numeric_value = int(stripped)
|
||||||
|
return numeric_value if numeric_value > 0 else None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_next_todo_id(self) -> int:
|
||||||
|
"""Get the next sequential todo_id for this user and tool.
|
||||||
|
|
||||||
|
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
|
||||||
|
With 5-10 todos max, scanning is negligible.
|
||||||
|
"""
|
||||||
|
# Find all todos for this user/tool and get their IDs
|
||||||
|
todos = list(self.collection.find(
|
||||||
|
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||||
|
{"todo_id": 1}
|
||||||
|
))
|
||||||
|
|
||||||
|
# Find the maximum todo_id
|
||||||
|
max_id = 0
|
||||||
|
for todo in todos:
|
||||||
|
todo_id = self._coerce_todo_id(todo.get("todo_id"))
|
||||||
|
if todo_id is not None:
|
||||||
|
max_id = max(max_id, todo_id)
|
||||||
|
|
||||||
|
return max_id + 1
|
||||||
|
|
||||||
|
def _list(self) -> str:
|
||||||
|
"""List all todos for the user."""
|
||||||
|
cursor = self.collection.find({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||||
|
todos = list(cursor)
|
||||||
|
|
||||||
|
if not todos:
|
||||||
|
return "No todos found."
|
||||||
|
|
||||||
|
result_lines = ["Todos:"]
|
||||||
|
for doc in todos:
|
||||||
|
todo_id = doc.get("todo_id")
|
||||||
|
title = doc.get("title", "Untitled")
|
||||||
|
status = doc.get("status", "open")
|
||||||
|
|
||||||
|
line = f"[{todo_id}] {title} ({status})"
|
||||||
|
result_lines.append(line)
|
||||||
|
|
||||||
|
return "\n".join(result_lines)
|
||||||
|
|
||||||
|
def _create(self, title: str) -> str:
|
||||||
|
"""Create a new todo item."""
|
||||||
|
title = (title or "").strip()
|
||||||
|
if not title:
|
||||||
|
return "Error: Title is required."
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
todo_id = self._get_next_todo_id()
|
||||||
|
|
||||||
|
doc = {
|
||||||
|
"todo_id": todo_id,
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"tool_id": self.tool_id,
|
||||||
|
"title": title,
|
||||||
|
"status": "open",
|
||||||
|
"created_at": now,
|
||||||
|
"updated_at": now,
|
||||||
|
}
|
||||||
|
self.collection.insert_one(doc)
|
||||||
|
return f"Todo created with ID {todo_id}: {title}"
|
||||||
|
|
||||||
|
def _get(self, todo_id: Optional[Any]) -> str:
|
||||||
|
"""Get a specific todo by ID."""
|
||||||
|
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||||
|
if parsed_todo_id is None:
|
||||||
|
return "Error: todo_id must be a positive integer."
|
||||||
|
|
||||||
|
doc = self.collection.find_one({
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"tool_id": self.tool_id,
|
||||||
|
"todo_id": parsed_todo_id
|
||||||
|
})
|
||||||
|
|
||||||
|
if not doc:
|
||||||
|
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||||
|
|
||||||
|
title = doc.get("title", "Untitled")
|
||||||
|
status = doc.get("status", "open")
|
||||||
|
|
||||||
|
result = f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _update(self, todo_id: Optional[Any], title: str) -> str:
|
||||||
|
"""Update a todo's title by ID."""
|
||||||
|
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||||
|
if parsed_todo_id is None:
|
||||||
|
return "Error: todo_id must be a positive integer."
|
||||||
|
|
||||||
|
title = (title or "").strip()
|
||||||
|
if not title:
|
||||||
|
return "Error: Title is required."
|
||||||
|
|
||||||
|
result = self.collection.update_one(
|
||||||
|
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
|
||||||
|
{"$set": {"title": title, "updated_at": datetime.now()}}
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.matched_count == 0:
|
||||||
|
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||||
|
|
||||||
|
return f"Todo {parsed_todo_id} updated to: {title}"
|
||||||
|
|
||||||
|
def _complete(self, todo_id: Optional[Any]) -> str:
|
||||||
|
"""Mark a todo as completed."""
|
||||||
|
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||||
|
if parsed_todo_id is None:
|
||||||
|
return "Error: todo_id must be a positive integer."
|
||||||
|
|
||||||
|
result = self.collection.update_one(
|
||||||
|
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
|
||||||
|
{"$set": {"status": "completed", "updated_at": datetime.now()}}
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.matched_count == 0:
|
||||||
|
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||||
|
|
||||||
|
return f"Todo {parsed_todo_id} marked as completed."
|
||||||
|
|
||||||
|
def _delete(self, todo_id: Optional[Any]) -> str:
|
||||||
|
"""Delete a specific todo by ID."""
|
||||||
|
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||||
|
if parsed_todo_id is None:
|
||||||
|
return "Error: todo_id must be a positive integer."
|
||||||
|
|
||||||
|
result = self.collection.delete_one({
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"tool_id": self.tool_id,
|
||||||
|
"todo_id": parsed_todo_id
|
||||||
|
})
|
||||||
|
|
||||||
|
if result.deleted_count == 0:
|
||||||
|
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||||
|
|
||||||
|
return f"Todo {parsed_todo_id} deleted."
|
||||||
@@ -17,26 +17,53 @@ class ToolActionParser:
|
|||||||
return parser(call)
|
return parser(call)
|
||||||
|
|
||||||
def _parse_openai_llm(self, call):
|
def _parse_openai_llm(self, call):
|
||||||
if isinstance(call, dict):
|
try:
|
||||||
try:
|
call_args = json.loads(call.arguments)
|
||||||
call_args = json.loads(call["function"]["arguments"])
|
tool_parts = call.name.split("_")
|
||||||
tool_id = call["function"]["name"].split("_")[-1]
|
|
||||||
action_name = call["function"]["name"].rsplit("_", 1)[0]
|
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
||||||
except (KeyError, TypeError) as e:
|
if len(tool_parts) < 2:
|
||||||
logger.error(f"Error parsing OpenAI LLM call: {e}")
|
logger.warning(
|
||||||
return None, None, None
|
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
||||||
else:
|
)
|
||||||
try:
|
|
||||||
call_args = json.loads(call.function.arguments)
|
|
||||||
tool_id = call.function.name.split("_")[-1]
|
|
||||||
action_name = call.function.name.rsplit("_", 1)[0]
|
|
||||||
except (AttributeError, TypeError) as e:
|
|
||||||
logger.error(f"Error parsing OpenAI LLM call: {e}")
|
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
|
tool_id = tool_parts[-1]
|
||||||
|
action_name = "_".join(tool_parts[:-1])
|
||||||
|
|
||||||
|
# Validate that tool_id looks like a numerical ID
|
||||||
|
if not tool_id.isdigit():
|
||||||
|
logger.warning(
|
||||||
|
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||||
|
)
|
||||||
|
|
||||||
|
except (AttributeError, TypeError, json.JSONDecodeError) as e:
|
||||||
|
logger.error(f"Error parsing OpenAI LLM call: {e}")
|
||||||
|
return None, None, None
|
||||||
return tool_id, action_name, call_args
|
return tool_id, action_name, call_args
|
||||||
|
|
||||||
def _parse_google_llm(self, call):
|
def _parse_google_llm(self, call):
|
||||||
call_args = call.args
|
try:
|
||||||
tool_id = call.name.split("_")[-1]
|
call_args = call.arguments
|
||||||
action_name = call.name.rsplit("_", 1)[0]
|
tool_parts = call.name.split("_")
|
||||||
|
|
||||||
|
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
||||||
|
if len(tool_parts) < 2:
|
||||||
|
logger.warning(
|
||||||
|
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
||||||
|
)
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
tool_id = tool_parts[-1]
|
||||||
|
action_name = "_".join(tool_parts[:-1])
|
||||||
|
|
||||||
|
# Validate that tool_id looks like a numerical ID
|
||||||
|
if not tool_id.isdigit():
|
||||||
|
logger.warning(
|
||||||
|
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||||
|
)
|
||||||
|
|
||||||
|
except (AttributeError, TypeError) as e:
|
||||||
|
logger.error(f"Error parsing Google LLM call: {e}")
|
||||||
|
return None, None, None
|
||||||
return tool_id, action_name, call_args
|
return tool_id, action_name, call_args
|
||||||
|
|||||||
@@ -23,16 +23,23 @@ class ToolManager:
|
|||||||
tool_config = self.config.get(name, {})
|
tool_config = self.config.get(name, {})
|
||||||
self.tools[name] = obj(tool_config)
|
self.tools[name] = obj(tool_config)
|
||||||
|
|
||||||
def load_tool(self, tool_name, tool_config):
|
def load_tool(self, tool_name, tool_config, user_id=None):
|
||||||
self.config[tool_name] = tool_config
|
self.config[tool_name] = tool_config
|
||||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||||
if issubclass(obj, Tool) and obj is not Tool:
|
if issubclass(obj, Tool) and obj is not Tool:
|
||||||
return obj(tool_config)
|
if tool_name in {"mcp_tool", "notes", "memory", "todo_list"} and user_id:
|
||||||
|
return obj(tool_config, user_id)
|
||||||
|
else:
|
||||||
|
return obj(tool_config)
|
||||||
|
|
||||||
def execute_action(self, tool_name, action_name, **kwargs):
|
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
||||||
if tool_name not in self.tools:
|
if tool_name not in self.tools:
|
||||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||||
|
if tool_name in {"mcp_tool", "memory", "todo_list"} and user_id:
|
||||||
|
tool_config = self.config.get(tool_name, {})
|
||||||
|
tool = self.load_tool(tool_name, tool_config, user_id)
|
||||||
|
return tool.execute_action(action_name, **kwargs)
|
||||||
return self.tools[tool_name].execute_action(action_name, **kwargs)
|
return self.tools[tool_name].execute_action(action_name, **kwargs)
|
||||||
|
|
||||||
def get_all_actions_metadata(self):
|
def get_all_actions_metadata(self):
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
from flask_restx import Api
|
||||||
|
|
||||||
|
api = Api(
|
||||||
|
version="1.0",
|
||||||
|
title="DocsGPT API",
|
||||||
|
description="API for DocsGPT",
|
||||||
|
)
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
from flask import Blueprint
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
from application.api.answer.routes.answer import AnswerResource
|
||||||
|
from application.api.answer.routes.base import answer_ns
|
||||||
|
from application.api.answer.routes.stream import StreamResource
|
||||||
|
|
||||||
|
|
||||||
|
answer = Blueprint("answer", __name__)
|
||||||
|
|
||||||
|
api.add_namespace(answer_ns)
|
||||||
|
|
||||||
|
|
||||||
|
def init_answer_routes():
|
||||||
|
api.add_resource(StreamResource, "/stream")
|
||||||
|
api.add_resource(AnswerResource, "/api/answer")
|
||||||
|
|
||||||
|
|
||||||
|
init_answer_routes()
|
||||||
|
|||||||
@@ -1,915 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import datetime
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
from bson.dbref import DBRef
|
|
||||||
from bson.objectid import ObjectId
|
|
||||||
from flask import Blueprint, make_response, request, Response
|
|
||||||
from flask_restx import fields, Namespace, Resource
|
|
||||||
|
|
||||||
from application.agents.agent_creator import AgentCreator
|
|
||||||
|
|
||||||
from application.core.mongo_db import MongoDB
|
|
||||||
from application.core.settings import settings
|
|
||||||
from application.error import bad_request
|
|
||||||
from application.extensions import api
|
|
||||||
from application.llm.llm_creator import LLMCreator
|
|
||||||
from application.retriever.retriever_creator import RetrieverCreator
|
|
||||||
from application.utils import check_required_fields, limit_chat_history
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
mongo = MongoDB.get_client()
|
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
|
||||||
conversations_collection = db["conversations"]
|
|
||||||
sources_collection = db["sources"]
|
|
||||||
prompts_collection = db["prompts"]
|
|
||||||
agents_collection = db["agents"]
|
|
||||||
user_logs_collection = db["user_logs"]
|
|
||||||
attachments_collection = db["attachments"]
|
|
||||||
|
|
||||||
answer = Blueprint("answer", __name__)
|
|
||||||
answer_ns = Namespace("answer", description="Answer related operations", path="/")
|
|
||||||
api.add_namespace(answer_ns)
|
|
||||||
|
|
||||||
gpt_model = ""
|
|
||||||
# to have some kind of default behaviour
|
|
||||||
if settings.LLM_NAME == "openai":
|
|
||||||
gpt_model = "gpt-4o-mini"
|
|
||||||
elif settings.LLM_NAME == "anthropic":
|
|
||||||
gpt_model = "claude-2"
|
|
||||||
elif settings.LLM_NAME == "groq":
|
|
||||||
gpt_model = "llama3-8b-8192"
|
|
||||||
elif settings.LLM_NAME == "novita":
|
|
||||||
gpt_model = "deepseek/deepseek-r1"
|
|
||||||
|
|
||||||
if settings.MODEL_NAME: # in case there is particular model name configured
|
|
||||||
gpt_model = settings.MODEL_NAME
|
|
||||||
|
|
||||||
# load the prompts
|
|
||||||
current_dir = os.path.dirname(
|
|
||||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
)
|
|
||||||
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
|
|
||||||
chat_combine_template = f.read()
|
|
||||||
|
|
||||||
with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f:
|
|
||||||
chat_reduce_template = f.read()
|
|
||||||
|
|
||||||
with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
|
|
||||||
chat_combine_creative = f.read()
|
|
||||||
|
|
||||||
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
|
|
||||||
chat_combine_strict = f.read()
|
|
||||||
|
|
||||||
api_key_set = settings.API_KEY is not None
|
|
||||||
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
|
|
||||||
|
|
||||||
|
|
||||||
async def async_generate(chain, question, chat_history):
|
|
||||||
result = await chain.arun({"question": question, "chat_history": chat_history})
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def run_async_chain(chain, question, chat_history):
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
result = {}
|
|
||||||
try:
|
|
||||||
answer = loop.run_until_complete(async_generate(chain, question, chat_history))
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
result["answer"] = answer
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_key(agent_id, user_id):
|
|
||||||
if not agent_id:
|
|
||||||
return None, False, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
|
|
||||||
if agent is None:
|
|
||||||
raise Exception("Agent not found", 404)
|
|
||||||
|
|
||||||
is_owner = agent.get("user") == user_id
|
|
||||||
|
|
||||||
if is_owner:
|
|
||||||
agents_collection.update_one(
|
|
||||||
{"_id": ObjectId(agent_id)},
|
|
||||||
{"$set": {"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)}},
|
|
||||||
)
|
|
||||||
return str(agent["key"]), False, None
|
|
||||||
|
|
||||||
is_shared_with_user = agent.get(
|
|
||||||
"shared_publicly", False
|
|
||||||
) or user_id in agent.get("shared_with", [])
|
|
||||||
|
|
||||||
if is_shared_with_user:
|
|
||||||
return str(agent["key"]), True, agent.get("shared_token")
|
|
||||||
|
|
||||||
raise Exception("Unauthorized access to the agent", 403)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def get_data_from_api_key(api_key):
|
|
||||||
data = agents_collection.find_one({"key": api_key})
|
|
||||||
if not data:
|
|
||||||
raise Exception("Invalid API Key, please generate a new key", 401)
|
|
||||||
|
|
||||||
source = data.get("source")
|
|
||||||
if isinstance(source, DBRef):
|
|
||||||
source_doc = db.dereference(source)
|
|
||||||
data["source"] = str(source_doc["_id"])
|
|
||||||
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
|
||||||
else:
|
|
||||||
data["source"] = {}
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def get_retriever(source_id: str):
|
|
||||||
doc = sources_collection.find_one({"_id": ObjectId(source_id)})
|
|
||||||
if doc is None:
|
|
||||||
raise Exception("Source document does not exist", 404)
|
|
||||||
retriever_name = None if "retriever" not in doc else doc["retriever"]
|
|
||||||
return retriever_name
|
|
||||||
|
|
||||||
|
|
||||||
def is_azure_configured():
|
|
||||||
return (
|
|
||||||
settings.OPENAI_API_BASE
|
|
||||||
and settings.OPENAI_API_VERSION
|
|
||||||
and settings.AZURE_DEPLOYMENT_NAME
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def save_conversation(
|
|
||||||
conversation_id,
|
|
||||||
question,
|
|
||||||
response,
|
|
||||||
thought,
|
|
||||||
source_log_docs,
|
|
||||||
tool_calls,
|
|
||||||
llm,
|
|
||||||
decoded_token,
|
|
||||||
index=None,
|
|
||||||
api_key=None,
|
|
||||||
agent_id=None,
|
|
||||||
is_shared_usage=False,
|
|
||||||
shared_token=None,
|
|
||||||
):
|
|
||||||
current_time = datetime.datetime.now(datetime.timezone.utc)
|
|
||||||
if conversation_id is not None and index is not None:
|
|
||||||
conversations_collection.update_one(
|
|
||||||
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
|
|
||||||
{
|
|
||||||
"$set": {
|
|
||||||
f"queries.{index}.prompt": question,
|
|
||||||
f"queries.{index}.response": response,
|
|
||||||
f"queries.{index}.thought": thought,
|
|
||||||
f"queries.{index}.sources": source_log_docs,
|
|
||||||
f"queries.{index}.tool_calls": tool_calls,
|
|
||||||
f"queries.{index}.timestamp": current_time,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
##remove following queries from the array
|
|
||||||
conversations_collection.update_one(
|
|
||||||
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
|
|
||||||
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
|
||||||
)
|
|
||||||
elif conversation_id is not None and conversation_id != "None":
|
|
||||||
conversations_collection.update_one(
|
|
||||||
{"_id": ObjectId(conversation_id)},
|
|
||||||
{
|
|
||||||
"$push": {
|
|
||||||
"queries": {
|
|
||||||
"prompt": question,
|
|
||||||
"response": response,
|
|
||||||
"thought": thought,
|
|
||||||
"sources": source_log_docs,
|
|
||||||
"tool_calls": tool_calls,
|
|
||||||
"timestamp": current_time,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# create new conversation
|
|
||||||
# generate summary
|
|
||||||
messages_summary = [
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "Summarise following conversation in no more than 3 "
|
|
||||||
"words, respond ONLY with the summary, use the same "
|
|
||||||
"language as the system",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "Summarise following conversation in no more than 3 words, "
|
|
||||||
"respond ONLY with the summary, use the same language as the "
|
|
||||||
"system \n\nUser: " + question + "\n\n" + "AI: " + response,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
|
|
||||||
conversation_data = {
|
|
||||||
"user": decoded_token.get("sub"),
|
|
||||||
"date": datetime.datetime.utcnow(),
|
|
||||||
"name": completion,
|
|
||||||
"queries": [
|
|
||||||
{
|
|
||||||
"prompt": question,
|
|
||||||
"response": response,
|
|
||||||
"thought": thought,
|
|
||||||
"sources": source_log_docs,
|
|
||||||
"tool_calls": tool_calls,
|
|
||||||
"timestamp": current_time,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
if api_key:
|
|
||||||
if agent_id:
|
|
||||||
conversation_data["agent_id"] = agent_id
|
|
||||||
if is_shared_usage:
|
|
||||||
conversation_data["is_shared_usage"] = is_shared_usage
|
|
||||||
conversation_data["shared_token"] = shared_token
|
|
||||||
api_key_doc = agents_collection.find_one({"key": api_key})
|
|
||||||
if api_key_doc:
|
|
||||||
conversation_data["api_key"] = api_key_doc["key"]
|
|
||||||
conversation_id = conversations_collection.insert_one(
|
|
||||||
conversation_data
|
|
||||||
).inserted_id
|
|
||||||
return conversation_id
|
|
||||||
|
|
||||||
|
|
||||||
def get_prompt(prompt_id):
|
|
||||||
if prompt_id == "default":
|
|
||||||
prompt = chat_combine_template
|
|
||||||
elif prompt_id == "creative":
|
|
||||||
prompt = chat_combine_creative
|
|
||||||
elif prompt_id == "strict":
|
|
||||||
prompt = chat_combine_strict
|
|
||||||
else:
|
|
||||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
|
|
||||||
def complete_stream(
|
|
||||||
question,
|
|
||||||
agent,
|
|
||||||
retriever,
|
|
||||||
conversation_id,
|
|
||||||
user_api_key,
|
|
||||||
decoded_token,
|
|
||||||
isNoneDoc=False,
|
|
||||||
index=None,
|
|
||||||
should_save_conversation=True,
|
|
||||||
attachments=None,
|
|
||||||
agent_id=None,
|
|
||||||
is_shared_usage=False,
|
|
||||||
shared_token=None,
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
|
||||||
attachment_ids = []
|
|
||||||
|
|
||||||
if attachments:
|
|
||||||
attachment_ids = [attachment["id"] for attachment in attachments]
|
|
||||||
logger.info(
|
|
||||||
f"Processing request with {len(attachments)} attachments: {attachment_ids}"
|
|
||||||
)
|
|
||||||
|
|
||||||
answer = agent.gen(query=question, retriever=retriever)
|
|
||||||
|
|
||||||
for line in answer:
|
|
||||||
if "answer" in line:
|
|
||||||
response_full += str(line["answer"])
|
|
||||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
elif "sources" in line:
|
|
||||||
truncated_sources = []
|
|
||||||
source_log_docs = line["sources"]
|
|
||||||
for source in line["sources"]:
|
|
||||||
truncated_source = source.copy()
|
|
||||||
if "text" in truncated_source:
|
|
||||||
truncated_source["text"] = (
|
|
||||||
truncated_source["text"][:100].strip() + "..."
|
|
||||||
)
|
|
||||||
truncated_sources.append(truncated_source)
|
|
||||||
if len(truncated_sources) > 0:
|
|
||||||
data = json.dumps({"type": "source", "source": truncated_sources})
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
elif "tool_calls" in line:
|
|
||||||
tool_calls = line["tool_calls"]
|
|
||||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
elif "thought" in line:
|
|
||||||
thought += line["thought"]
|
|
||||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
|
|
||||||
if isNoneDoc:
|
|
||||||
for doc in source_log_docs:
|
|
||||||
doc["source"] = "None"
|
|
||||||
|
|
||||||
llm = LLMCreator.create_llm(
|
|
||||||
settings.LLM_NAME,
|
|
||||||
api_key=settings.API_KEY,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
if should_save_conversation:
|
|
||||||
conversation_id = save_conversation(
|
|
||||||
conversation_id,
|
|
||||||
question,
|
|
||||||
response_full,
|
|
||||||
thought,
|
|
||||||
source_log_docs,
|
|
||||||
tool_calls,
|
|
||||||
llm,
|
|
||||||
decoded_token,
|
|
||||||
index,
|
|
||||||
api_key=user_api_key,
|
|
||||||
agent_id=agent_id,
|
|
||||||
is_shared_usage=is_shared_usage,
|
|
||||||
shared_token=shared_token,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
conversation_id = None
|
|
||||||
|
|
||||||
# send data.type = "end" to indicate that the stream has ended as json
|
|
||||||
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
|
|
||||||
retriever_params = retriever.get_params()
|
|
||||||
user_logs_collection.insert_one(
|
|
||||||
{
|
|
||||||
"action": "stream_answer",
|
|
||||||
"level": "info",
|
|
||||||
"user": decoded_token.get("sub"),
|
|
||||||
"api_key": user_api_key,
|
|
||||||
"question": question,
|
|
||||||
"response": response_full,
|
|
||||||
"sources": source_log_docs,
|
|
||||||
"retriever_params": retriever_params,
|
|
||||||
"attachments": attachment_ids,
|
|
||||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
data = json.dumps({"type": "end"})
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
|
||||||
data = json.dumps(
|
|
||||||
{
|
|
||||||
"type": "error",
|
|
||||||
"error": "Please try again later. We apologize for any inconvenience.",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
@answer_ns.route("/stream")
|
|
||||||
class Stream(Resource):
|
|
||||||
stream_model = api.model(
|
|
||||||
"StreamModel",
|
|
||||||
{
|
|
||||||
"question": fields.String(
|
|
||||||
required=True, description="Question to be asked"
|
|
||||||
),
|
|
||||||
"history": fields.List(
|
|
||||||
fields.String, required=False, description="Chat history"
|
|
||||||
),
|
|
||||||
"conversation_id": fields.String(
|
|
||||||
required=False, description="Conversation ID"
|
|
||||||
),
|
|
||||||
"prompt_id": fields.String(
|
|
||||||
required=False, default="default", description="Prompt ID"
|
|
||||||
),
|
|
||||||
"chunks": fields.Integer(
|
|
||||||
required=False, default=2, description="Number of chunks"
|
|
||||||
),
|
|
||||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
|
||||||
"retriever": fields.String(required=False, description="Retriever type"),
|
|
||||||
"api_key": fields.String(required=False, description="API key"),
|
|
||||||
"active_docs": fields.String(
|
|
||||||
required=False, description="Active documents"
|
|
||||||
),
|
|
||||||
"isNoneDoc": fields.Boolean(
|
|
||||||
required=False, description="Flag indicating if no document is used"
|
|
||||||
),
|
|
||||||
"index": fields.Integer(
|
|
||||||
required=False, description="Index of the query to update"
|
|
||||||
),
|
|
||||||
"save_conversation": fields.Boolean(
|
|
||||||
required=False,
|
|
||||||
default=True,
|
|
||||||
description="Whether to save the conversation",
|
|
||||||
),
|
|
||||||
"attachments": fields.List(
|
|
||||||
fields.String, required=False, description="List of attachment IDs"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@api.expect(stream_model)
|
|
||||||
@api.doc(description="Stream a response based on the question and retriever")
|
|
||||||
def post(self):
|
|
||||||
data = request.get_json()
|
|
||||||
required_fields = ["question"]
|
|
||||||
if "index" in data:
|
|
||||||
required_fields = ["question", "conversation_id"]
|
|
||||||
missing_fields = check_required_fields(data, required_fields)
|
|
||||||
if missing_fields:
|
|
||||||
return missing_fields
|
|
||||||
|
|
||||||
save_conv = data.get("save_conversation", True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
question = data["question"]
|
|
||||||
history = limit_chat_history(
|
|
||||||
json.loads(data.get("history", [])), gpt_model=gpt_model
|
|
||||||
)
|
|
||||||
conversation_id = data.get("conversation_id")
|
|
||||||
prompt_id = data.get("prompt_id", "default")
|
|
||||||
attachment_ids = data.get("attachments", [])
|
|
||||||
|
|
||||||
index = data.get("index", None)
|
|
||||||
chunks = int(data.get("chunks", 2))
|
|
||||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
|
||||||
retriever_name = data.get("retriever", "classic")
|
|
||||||
agent_id = data.get("agent_id", None)
|
|
||||||
agent_type = settings.AGENT_NAME
|
|
||||||
agent_key, is_shared_usage, shared_token = get_agent_key(
|
|
||||||
agent_id, request.decoded_token.get("sub")
|
|
||||||
)
|
|
||||||
|
|
||||||
if agent_key:
|
|
||||||
data.update({"api_key": agent_key})
|
|
||||||
else:
|
|
||||||
agent_id = None
|
|
||||||
|
|
||||||
if "api_key" in data:
|
|
||||||
data_key = get_data_from_api_key(data["api_key"])
|
|
||||||
chunks = int(data_key.get("chunks", 2))
|
|
||||||
prompt_id = data_key.get("prompt_id", "default")
|
|
||||||
source = {"active_docs": data_key.get("source")}
|
|
||||||
retriever_name = data_key.get("retriever", retriever_name)
|
|
||||||
user_api_key = data["api_key"]
|
|
||||||
agent_type = data_key.get("agent_type", agent_type)
|
|
||||||
if is_shared_usage:
|
|
||||||
decoded_token = request.decoded_token
|
|
||||||
else:
|
|
||||||
decoded_token = {"sub": data_key.get("user")}
|
|
||||||
is_shared_usage = False
|
|
||||||
|
|
||||||
elif "active_docs" in data:
|
|
||||||
source = {"active_docs": data["active_docs"]}
|
|
||||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
|
||||||
user_api_key = None
|
|
||||||
decoded_token = request.decoded_token
|
|
||||||
|
|
||||||
else:
|
|
||||||
source = {}
|
|
||||||
user_api_key = None
|
|
||||||
decoded_token = request.decoded_token
|
|
||||||
|
|
||||||
if not decoded_token:
|
|
||||||
return make_response({"error": "Unauthorized"}, 401)
|
|
||||||
|
|
||||||
attachments = get_attachments_content(
|
|
||||||
attachment_ids, decoded_token.get("sub")
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"/stream - request_data: {data}, source: {source}, attachments: {len(attachments)}",
|
|
||||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = get_prompt(prompt_id)
|
|
||||||
if "isNoneDoc" in data and data["isNoneDoc"] is True:
|
|
||||||
chunks = 0
|
|
||||||
|
|
||||||
agent = AgentCreator.create_agent(
|
|
||||||
agent_type,
|
|
||||||
endpoint="stream",
|
|
||||||
llm_name=settings.LLM_NAME,
|
|
||||||
gpt_model=gpt_model,
|
|
||||||
api_key=settings.API_KEY,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
prompt=prompt,
|
|
||||||
chat_history=history,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
attachments=attachments,
|
|
||||||
)
|
|
||||||
|
|
||||||
retriever = RetrieverCreator.create_retriever(
|
|
||||||
retriever_name,
|
|
||||||
source=source,
|
|
||||||
chat_history=history,
|
|
||||||
prompt=prompt,
|
|
||||||
chunks=chunks,
|
|
||||||
token_limit=token_limit,
|
|
||||||
gpt_model=gpt_model,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
)
|
|
||||||
is_shared_usage_val = data.get("is_shared_usage", False)
|
|
||||||
is_shared_token_val = data.get("shared_token", None)
|
|
||||||
return Response(
|
|
||||||
complete_stream(
|
|
||||||
question=question,
|
|
||||||
agent=agent,
|
|
||||||
retriever=retriever,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
isNoneDoc=data.get("isNoneDoc"),
|
|
||||||
index=index,
|
|
||||||
should_save_conversation=save_conv,
|
|
||||||
agent_id=agent_id,
|
|
||||||
is_shared_usage=is_shared_usage_val,
|
|
||||||
shared_token=is_shared_token_val,
|
|
||||||
),
|
|
||||||
mimetype="text/event-stream",
|
|
||||||
)
|
|
||||||
|
|
||||||
except ValueError:
|
|
||||||
message = "Malformed request body"
|
|
||||||
logger.error(f"/stream - error: {message}")
|
|
||||||
return Response(
|
|
||||||
error_stream_generate(message),
|
|
||||||
status=400,
|
|
||||||
mimetype="text/event-stream",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
|
|
||||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
|
||||||
)
|
|
||||||
status_code = 400
|
|
||||||
return Response(
|
|
||||||
error_stream_generate("Unknown error occurred"),
|
|
||||||
status=status_code,
|
|
||||||
mimetype="text/event-stream",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def error_stream_generate(err_response):
|
|
||||||
data = json.dumps({"type": "error", "error": err_response})
|
|
||||||
yield f"data: {data}\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
@answer_ns.route("/api/answer")
|
|
||||||
class Answer(Resource):
|
|
||||||
answer_model = api.model(
|
|
||||||
"AnswerModel",
|
|
||||||
{
|
|
||||||
"question": fields.String(
|
|
||||||
required=True, description="The question to answer"
|
|
||||||
),
|
|
||||||
"history": fields.List(
|
|
||||||
fields.String, required=False, description="Conversation history"
|
|
||||||
),
|
|
||||||
"conversation_id": fields.String(
|
|
||||||
required=False, description="Conversation ID"
|
|
||||||
),
|
|
||||||
"prompt_id": fields.String(
|
|
||||||
required=False, default="default", description="Prompt ID"
|
|
||||||
),
|
|
||||||
"chunks": fields.Integer(
|
|
||||||
required=False, default=2, description="Number of chunks"
|
|
||||||
),
|
|
||||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
|
||||||
"retriever": fields.String(required=False, description="Retriever type"),
|
|
||||||
"api_key": fields.String(required=False, description="API key"),
|
|
||||||
"active_docs": fields.String(
|
|
||||||
required=False, description="Active documents"
|
|
||||||
),
|
|
||||||
"isNoneDoc": fields.Boolean(
|
|
||||||
required=False, description="Flag indicating if no document is used"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@api.expect(answer_model)
|
|
||||||
@api.doc(description="Provide an answer based on the question and retriever")
|
|
||||||
def post(self):
|
|
||||||
data = request.get_json()
|
|
||||||
required_fields = ["question"]
|
|
||||||
missing_fields = check_required_fields(data, required_fields)
|
|
||||||
if missing_fields:
|
|
||||||
return missing_fields
|
|
||||||
|
|
||||||
try:
|
|
||||||
question = data["question"]
|
|
||||||
history = limit_chat_history(
|
|
||||||
json.loads(data.get("history", [])), gpt_model=gpt_model
|
|
||||||
)
|
|
||||||
conversation_id = data.get("conversation_id")
|
|
||||||
prompt_id = data.get("prompt_id", "default")
|
|
||||||
chunks = int(data.get("chunks", 2))
|
|
||||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
|
||||||
retriever_name = data.get("retriever", "classic")
|
|
||||||
agent_type = settings.AGENT_NAME
|
|
||||||
|
|
||||||
if "api_key" in data:
|
|
||||||
data_key = get_data_from_api_key(data["api_key"])
|
|
||||||
chunks = int(data_key.get("chunks", 2))
|
|
||||||
prompt_id = data_key.get("prompt_id", "default")
|
|
||||||
source = {"active_docs": data_key.get("source")}
|
|
||||||
retriever_name = data_key.get("retriever", retriever_name)
|
|
||||||
user_api_key = data["api_key"]
|
|
||||||
agent_type = data_key.get("agent_type", agent_type)
|
|
||||||
decoded_token = {"sub": data_key.get("user")}
|
|
||||||
|
|
||||||
elif "active_docs" in data:
|
|
||||||
source = {"active_docs": data["active_docs"]}
|
|
||||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
|
||||||
user_api_key = None
|
|
||||||
decoded_token = request.decoded_token
|
|
||||||
|
|
||||||
else:
|
|
||||||
source = {}
|
|
||||||
user_api_key = None
|
|
||||||
decoded_token = request.decoded_token
|
|
||||||
|
|
||||||
if not decoded_token:
|
|
||||||
return make_response({"error": "Unauthorized"}, 401)
|
|
||||||
|
|
||||||
prompt = get_prompt(prompt_id)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"/api/answer - request_data: {data}, source: {source}",
|
|
||||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
|
||||||
)
|
|
||||||
|
|
||||||
agent = AgentCreator.create_agent(
|
|
||||||
agent_type,
|
|
||||||
endpoint="api/answer",
|
|
||||||
llm_name=settings.LLM_NAME,
|
|
||||||
gpt_model=gpt_model,
|
|
||||||
api_key=settings.API_KEY,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
prompt=prompt,
|
|
||||||
chat_history=history,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
retriever = RetrieverCreator.create_retriever(
|
|
||||||
retriever_name,
|
|
||||||
source=source,
|
|
||||||
chat_history=history,
|
|
||||||
prompt=prompt,
|
|
||||||
chunks=chunks,
|
|
||||||
token_limit=token_limit,
|
|
||||||
gpt_model=gpt_model,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
response_full = ""
|
|
||||||
source_log_docs = []
|
|
||||||
tool_calls = []
|
|
||||||
stream_ended = False
|
|
||||||
thought = ""
|
|
||||||
|
|
||||||
for line in complete_stream(
|
|
||||||
question=question,
|
|
||||||
agent=agent,
|
|
||||||
retriever=retriever,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
isNoneDoc=data.get("isNoneDoc"),
|
|
||||||
index=None,
|
|
||||||
should_save_conversation=False,
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
event_data = line.replace("data: ", "").strip()
|
|
||||||
event = json.loads(event_data)
|
|
||||||
|
|
||||||
if event["type"] == "answer":
|
|
||||||
response_full += event["answer"]
|
|
||||||
elif event["type"] == "source":
|
|
||||||
source_log_docs = event["source"]
|
|
||||||
elif event["type"] == "tool_calls":
|
|
||||||
tool_calls = event["tool_calls"]
|
|
||||||
elif event["type"] == "thought":
|
|
||||||
thought = event["thought"]
|
|
||||||
elif event["type"] == "error":
|
|
||||||
logger.error(f"Error from stream: {event['error']}")
|
|
||||||
return bad_request(500, event["error"])
|
|
||||||
elif event["type"] == "end":
|
|
||||||
stream_ended = True
|
|
||||||
|
|
||||||
except (json.JSONDecodeError, KeyError) as e:
|
|
||||||
logger.warning(f"Error parsing stream event: {e}, line: {line}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not stream_ended:
|
|
||||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
|
||||||
return bad_request(500, "Stream ended unexpectedly.")
|
|
||||||
|
|
||||||
if data.get("isNoneDoc"):
|
|
||||||
for doc in source_log_docs:
|
|
||||||
doc["source"] = "None"
|
|
||||||
|
|
||||||
llm = LLMCreator.create_llm(
|
|
||||||
settings.LLM_NAME,
|
|
||||||
api_key=settings.API_KEY,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = {"answer": response_full, "sources": source_log_docs}
|
|
||||||
result["conversation_id"] = str(
|
|
||||||
save_conversation(
|
|
||||||
conversation_id,
|
|
||||||
question,
|
|
||||||
response_full,
|
|
||||||
thought,
|
|
||||||
source_log_docs,
|
|
||||||
tool_calls,
|
|
||||||
llm,
|
|
||||||
decoded_token,
|
|
||||||
api_key=user_api_key,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
retriever_params = retriever.get_params()
|
|
||||||
user_logs_collection.insert_one(
|
|
||||||
{
|
|
||||||
"action": "api_answer",
|
|
||||||
"level": "info",
|
|
||||||
"user": decoded_token.get("sub"),
|
|
||||||
"api_key": user_api_key,
|
|
||||||
"question": question,
|
|
||||||
"response": response_full,
|
|
||||||
"sources": source_log_docs,
|
|
||||||
"retriever_params": retriever_params,
|
|
||||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
|
||||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
|
||||||
)
|
|
||||||
return bad_request(500, str(e))
|
|
||||||
|
|
||||||
return make_response(result, 200)
|
|
||||||
|
|
||||||
|
|
||||||
@answer_ns.route("/api/search")
|
|
||||||
class Search(Resource):
|
|
||||||
search_model = api.model(
|
|
||||||
"SearchModel",
|
|
||||||
{
|
|
||||||
"question": fields.String(
|
|
||||||
required=True, description="The question to search"
|
|
||||||
),
|
|
||||||
"chunks": fields.Integer(
|
|
||||||
required=False, default=2, description="Number of chunks"
|
|
||||||
),
|
|
||||||
"api_key": fields.String(
|
|
||||||
required=False, description="API key for authentication"
|
|
||||||
),
|
|
||||||
"active_docs": fields.String(
|
|
||||||
required=False, description="Active documents for retrieval"
|
|
||||||
),
|
|
||||||
"retriever": fields.String(required=False, description="Retriever type"),
|
|
||||||
"token_limit": fields.Integer(
|
|
||||||
required=False, description="Limit for tokens"
|
|
||||||
),
|
|
||||||
"isNoneDoc": fields.Boolean(
|
|
||||||
required=False, description="Flag indicating if no document is used"
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@api.expect(search_model)
|
|
||||||
@api.doc(
|
|
||||||
description="Search for relevant documents based on the question and retriever"
|
|
||||||
)
|
|
||||||
def post(self):
|
|
||||||
data = request.get_json()
|
|
||||||
required_fields = ["question"]
|
|
||||||
missing_fields = check_required_fields(data, required_fields)
|
|
||||||
if missing_fields:
|
|
||||||
return missing_fields
|
|
||||||
|
|
||||||
try:
|
|
||||||
question = data["question"]
|
|
||||||
chunks = int(data.get("chunks", 2))
|
|
||||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
|
||||||
retriever_name = data.get("retriever", "classic")
|
|
||||||
|
|
||||||
if "api_key" in data:
|
|
||||||
data_key = get_data_from_api_key(data["api_key"])
|
|
||||||
chunks = int(data_key.get("chunks", 2))
|
|
||||||
source = {"active_docs": data_key.get("source")}
|
|
||||||
user_api_key = data["api_key"]
|
|
||||||
decoded_token = {"sub": data_key.get("user")}
|
|
||||||
|
|
||||||
elif "active_docs" in data:
|
|
||||||
source = {"active_docs": data["active_docs"]}
|
|
||||||
user_api_key = None
|
|
||||||
decoded_token = request.decoded_token
|
|
||||||
|
|
||||||
else:
|
|
||||||
source = {}
|
|
||||||
user_api_key = None
|
|
||||||
decoded_token = request.decoded_token
|
|
||||||
|
|
||||||
if not decoded_token:
|
|
||||||
return make_response({"error": "Unauthorized"}, 401)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"/api/answer - request_data: {data}, source: {source}",
|
|
||||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
|
||||||
)
|
|
||||||
|
|
||||||
retriever = RetrieverCreator.create_retriever(
|
|
||||||
retriever_name,
|
|
||||||
source=source,
|
|
||||||
chat_history=[],
|
|
||||||
prompt="default",
|
|
||||||
chunks=chunks,
|
|
||||||
token_limit=token_limit,
|
|
||||||
gpt_model=gpt_model,
|
|
||||||
user_api_key=user_api_key,
|
|
||||||
decoded_token=decoded_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
docs = retriever.search(question)
|
|
||||||
retriever_params = retriever.get_params()
|
|
||||||
|
|
||||||
user_logs_collection.insert_one(
|
|
||||||
{
|
|
||||||
"action": "api_search",
|
|
||||||
"level": "info",
|
|
||||||
"user": decoded_token.get("sub"),
|
|
||||||
"api_key": user_api_key,
|
|
||||||
"question": question,
|
|
||||||
"sources": docs,
|
|
||||||
"retriever_params": retriever_params,
|
|
||||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if data.get("isNoneDoc"):
|
|
||||||
for doc in docs:
|
|
||||||
doc["source"] = "None"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}",
|
|
||||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
|
||||||
)
|
|
||||||
return bad_request(500, str(e))
|
|
||||||
|
|
||||||
return make_response(docs, 200)
|
|
||||||
|
|
||||||
|
|
||||||
def get_attachments_content(attachment_ids, user):
|
|
||||||
"""
|
|
||||||
Retrieve content from attachment documents based on their IDs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
attachment_ids (list): List of attachment document IDs
|
|
||||||
user (str): User identifier to verify ownership
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: List of dictionaries containing attachment content and metadata
|
|
||||||
"""
|
|
||||||
if not attachment_ids:
|
|
||||||
return []
|
|
||||||
|
|
||||||
attachments = []
|
|
||||||
for attachment_id in attachment_ids:
|
|
||||||
try:
|
|
||||||
attachment_doc = attachments_collection.find_one(
|
|
||||||
{"_id": ObjectId(attachment_id), "user": user}
|
|
||||||
)
|
|
||||||
|
|
||||||
if attachment_doc:
|
|
||||||
attachments.append(attachment_doc)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
return attachments
|
|
||||||
0
application/api/answer/routes/__init__.py
Normal file
0
application/api/answer/routes/__init__.py
Normal file
137
application/api/answer/routes/answer.py
Normal file
137
application/api/answer/routes/answer.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from flask import make_response, request
|
||||||
|
from flask_restx import fields, Resource
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
|
||||||
|
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
|
||||||
|
|
||||||
|
from application.api.answer.services.stream_processor import StreamProcessor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@answer_ns.route("/api/answer")
|
||||||
|
class AnswerResource(Resource, BaseAnswerResource):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
Resource.__init__(self, *args, **kwargs)
|
||||||
|
BaseAnswerResource.__init__(self)
|
||||||
|
|
||||||
|
answer_model = answer_ns.model(
|
||||||
|
"AnswerModel",
|
||||||
|
{
|
||||||
|
"question": fields.String(
|
||||||
|
required=True, description="Question to be asked"
|
||||||
|
),
|
||||||
|
"history": fields.List(
|
||||||
|
fields.String,
|
||||||
|
required=False,
|
||||||
|
description="Conversation history (only for new conversations)",
|
||||||
|
),
|
||||||
|
"conversation_id": fields.String(
|
||||||
|
required=False,
|
||||||
|
description="Existing conversation ID (loads history)",
|
||||||
|
),
|
||||||
|
"prompt_id": fields.String(
|
||||||
|
required=False, default="default", description="Prompt ID"
|
||||||
|
),
|
||||||
|
"chunks": fields.Integer(
|
||||||
|
required=False, default=2, description="Number of chunks"
|
||||||
|
),
|
||||||
|
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||||
|
"retriever": fields.String(required=False, description="Retriever type"),
|
||||||
|
"api_key": fields.String(required=False, description="API key"),
|
||||||
|
"active_docs": fields.String(
|
||||||
|
required=False, description="Active documents"
|
||||||
|
),
|
||||||
|
"isNoneDoc": fields.Boolean(
|
||||||
|
required=False, description="Flag indicating if no document is used"
|
||||||
|
),
|
||||||
|
"save_conversation": fields.Boolean(
|
||||||
|
required=False,
|
||||||
|
default=True,
|
||||||
|
description="Whether to save the conversation",
|
||||||
|
),
|
||||||
|
"passthrough": fields.Raw(
|
||||||
|
required=False,
|
||||||
|
description="Dynamic parameters to inject into prompt template",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(answer_model)
|
||||||
|
@api.doc(description="Provide a response based on the question and retriever")
|
||||||
|
def post(self):
|
||||||
|
data = request.get_json()
|
||||||
|
if error := self.validate_request(data):
|
||||||
|
return error
|
||||||
|
decoded_token = getattr(request, "decoded_token", None)
|
||||||
|
processor = StreamProcessor(data, decoded_token)
|
||||||
|
try:
|
||||||
|
processor.initialize()
|
||||||
|
if not processor.decoded_token:
|
||||||
|
return make_response({"error": "Unauthorized"}, 401)
|
||||||
|
|
||||||
|
docs_together, docs_list = processor.pre_fetch_docs(
|
||||||
|
data.get("question", "")
|
||||||
|
)
|
||||||
|
tools_data = processor.pre_fetch_tools()
|
||||||
|
|
||||||
|
agent = processor.create_agent(
|
||||||
|
docs_together=docs_together,
|
||||||
|
docs=docs_list,
|
||||||
|
tools_data=tools_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
if error := self.check_usage(processor.agent_config):
|
||||||
|
return error
|
||||||
|
|
||||||
|
stream = self.complete_stream(
|
||||||
|
question=data["question"],
|
||||||
|
agent=agent,
|
||||||
|
conversation_id=processor.conversation_id,
|
||||||
|
user_api_key=processor.agent_config.get("user_api_key"),
|
||||||
|
decoded_token=processor.decoded_token,
|
||||||
|
isNoneDoc=data.get("isNoneDoc"),
|
||||||
|
index=None,
|
||||||
|
should_save_conversation=data.get("save_conversation", True),
|
||||||
|
)
|
||||||
|
stream_result = self.process_response_stream(stream)
|
||||||
|
|
||||||
|
if len(stream_result) == 7:
|
||||||
|
(
|
||||||
|
conversation_id,
|
||||||
|
response,
|
||||||
|
sources,
|
||||||
|
tool_calls,
|
||||||
|
thought,
|
||||||
|
error,
|
||||||
|
structured_info,
|
||||||
|
) = stream_result
|
||||||
|
else:
|
||||||
|
conversation_id, response, sources, tool_calls, thought, error = (
|
||||||
|
stream_result
|
||||||
|
)
|
||||||
|
structured_info = None
|
||||||
|
|
||||||
|
if error:
|
||||||
|
return make_response({"error": error}, 400)
|
||||||
|
result = {
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"answer": response,
|
||||||
|
"sources": sources,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
"thought": thought,
|
||||||
|
}
|
||||||
|
|
||||||
|
if structured_info:
|
||||||
|
result.update(structured_info)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||||
|
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||||
|
)
|
||||||
|
return make_response({"error": str(e)}, 500)
|
||||||
|
return make_response(result, 200)
|
||||||
398
application/api/answer/routes/base.py
Normal file
398
application/api/answer/routes/base.py
Normal file
@@ -0,0 +1,398 @@
|
|||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Generator, List, Optional
|
||||||
|
|
||||||
|
from flask import jsonify, make_response, Response
|
||||||
|
from flask_restx import Namespace
|
||||||
|
|
||||||
|
from application.api.answer.services.conversation_service import ConversationService
|
||||||
|
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.llm.llm_creator import LLMCreator
|
||||||
|
from application.utils import check_required_fields, get_gpt_model
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
answer_ns = Namespace("answer", description="Answer related operations", path="/")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAnswerResource:
|
||||||
|
"""Shared base class for answer endpoints"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
|
self.db = db
|
||||||
|
self.user_logs_collection = db["user_logs"]
|
||||||
|
self.gpt_model = get_gpt_model()
|
||||||
|
self.conversation_service = ConversationService()
|
||||||
|
|
||||||
|
def validate_request(
|
||||||
|
self, data: Dict[str, Any], require_conversation_id: bool = False
|
||||||
|
) -> Optional[Response]:
|
||||||
|
"""Common request validation"""
|
||||||
|
required_fields = ["question"]
|
||||||
|
if require_conversation_id:
|
||||||
|
required_fields.append("conversation_id")
|
||||||
|
if missing_fields := check_required_fields(data, required_fields):
|
||||||
|
return missing_fields
|
||||||
|
return None
|
||||||
|
|
||||||
|
def check_usage(self, agent_config: Dict) -> Optional[Response]:
|
||||||
|
"""Check if there is a usage limit and if it is exceeded
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_config: The config dict of agent instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None or Response if either of limits exceeded.
|
||||||
|
|
||||||
|
"""
|
||||||
|
api_key = agent_config.get("user_api_key")
|
||||||
|
if not api_key:
|
||||||
|
return None
|
||||||
|
|
||||||
|
agents_collection = self.db["agents"]
|
||||||
|
agent = agents_collection.find_one({"key": api_key})
|
||||||
|
|
||||||
|
if not agent:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid API key."}), 401
|
||||||
|
)
|
||||||
|
|
||||||
|
limited_token_mode_raw = agent.get("limited_token_mode", False)
|
||||||
|
limited_request_mode_raw = agent.get("limited_request_mode", False)
|
||||||
|
|
||||||
|
limited_token_mode = (
|
||||||
|
limited_token_mode_raw
|
||||||
|
if isinstance(limited_token_mode_raw, bool)
|
||||||
|
else limited_token_mode_raw == "True"
|
||||||
|
)
|
||||||
|
limited_request_mode = (
|
||||||
|
limited_request_mode_raw
|
||||||
|
if isinstance(limited_request_mode_raw, bool)
|
||||||
|
else limited_request_mode_raw == "True"
|
||||||
|
)
|
||||||
|
|
||||||
|
token_limit = int(
|
||||||
|
agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
|
||||||
|
)
|
||||||
|
request_limit = int(
|
||||||
|
agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
|
||||||
|
)
|
||||||
|
|
||||||
|
token_usage_collection = self.db["token_usage"]
|
||||||
|
|
||||||
|
end_date = datetime.datetime.now()
|
||||||
|
start_date = end_date - datetime.timedelta(hours=24)
|
||||||
|
|
||||||
|
match_query = {
|
||||||
|
"timestamp": {"$gte": start_date, "$lte": end_date},
|
||||||
|
"api_key": api_key,
|
||||||
|
}
|
||||||
|
|
||||||
|
if limited_token_mode:
|
||||||
|
token_pipeline = [
|
||||||
|
{"$match": match_query},
|
||||||
|
{
|
||||||
|
"$group": {
|
||||||
|
"_id": None,
|
||||||
|
"total_tokens": {
|
||||||
|
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
token_result = list(token_usage_collection.aggregate(token_pipeline))
|
||||||
|
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
|
||||||
|
else:
|
||||||
|
daily_token_usage = 0
|
||||||
|
|
||||||
|
if limited_request_mode:
|
||||||
|
daily_request_usage = token_usage_collection.count_documents(match_query)
|
||||||
|
else:
|
||||||
|
daily_request_usage = 0
|
||||||
|
|
||||||
|
if not limited_token_mode and not limited_request_mode:
|
||||||
|
return None
|
||||||
|
|
||||||
|
token_exceeded = (
|
||||||
|
limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
|
||||||
|
)
|
||||||
|
request_exceeded = (
|
||||||
|
limited_request_mode
|
||||||
|
and request_limit > 0
|
||||||
|
and daily_request_usage >= request_limit
|
||||||
|
)
|
||||||
|
|
||||||
|
if token_exceeded or request_exceeded:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "Exceeding usage limit, please try again later.",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
429,
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def complete_stream(
|
||||||
|
self,
|
||||||
|
question: str,
|
||||||
|
agent: Any,
|
||||||
|
conversation_id: Optional[str],
|
||||||
|
user_api_key: Optional[str],
|
||||||
|
decoded_token: Dict[str, Any],
|
||||||
|
isNoneDoc: bool = False,
|
||||||
|
index: Optional[int] = None,
|
||||||
|
should_save_conversation: bool = True,
|
||||||
|
attachment_ids: Optional[List[str]] = None,
|
||||||
|
agent_id: Optional[str] = None,
|
||||||
|
is_shared_usage: bool = False,
|
||||||
|
shared_token: Optional[str] = None,
|
||||||
|
) -> Generator[str, None, None]:
|
||||||
|
"""
|
||||||
|
Generator function that streams the complete conversation response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
question: The user's question
|
||||||
|
agent: The agent instance
|
||||||
|
retriever: The retriever instance
|
||||||
|
conversation_id: Existing conversation ID
|
||||||
|
user_api_key: User's API key if any
|
||||||
|
decoded_token: Decoded JWT token
|
||||||
|
isNoneDoc: Flag for document-less responses
|
||||||
|
index: Index of message to update
|
||||||
|
should_save_conversation: Whether to persist the conversation
|
||||||
|
attachment_ids: List of attachment IDs
|
||||||
|
agent_id: ID of agent used
|
||||||
|
is_shared_usage: Flag for shared agent usage
|
||||||
|
shared_token: Token for shared agent
|
||||||
|
retrieved_docs: Pre-fetched documents for sources (optional)
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Server-sent event strings
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
||||||
|
is_structured = False
|
||||||
|
schema_info = None
|
||||||
|
structured_chunks = []
|
||||||
|
|
||||||
|
for line in agent.gen(query=question):
|
||||||
|
if "answer" in line:
|
||||||
|
response_full += str(line["answer"])
|
||||||
|
if line.get("structured"):
|
||||||
|
is_structured = True
|
||||||
|
schema_info = line.get("schema")
|
||||||
|
structured_chunks.append(line["answer"])
|
||||||
|
else:
|
||||||
|
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
elif "sources" in line:
|
||||||
|
truncated_sources = []
|
||||||
|
source_log_docs = line["sources"]
|
||||||
|
for source in line["sources"]:
|
||||||
|
truncated_source = source.copy()
|
||||||
|
if "text" in truncated_source:
|
||||||
|
truncated_source["text"] = (
|
||||||
|
truncated_source["text"][:100].strip() + "..."
|
||||||
|
)
|
||||||
|
truncated_sources.append(truncated_source)
|
||||||
|
if truncated_sources:
|
||||||
|
data = json.dumps(
|
||||||
|
{"type": "source", "source": truncated_sources}
|
||||||
|
)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
elif "tool_calls" in line:
|
||||||
|
tool_calls = line["tool_calls"]
|
||||||
|
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
elif "thought" in line:
|
||||||
|
thought += line["thought"]
|
||||||
|
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
elif "type" in line:
|
||||||
|
data = json.dumps(line)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
|
||||||
|
if is_structured and structured_chunks:
|
||||||
|
structured_data = {
|
||||||
|
"type": "structured_answer",
|
||||||
|
"answer": response_full,
|
||||||
|
"structured": True,
|
||||||
|
"schema": schema_info,
|
||||||
|
}
|
||||||
|
data = json.dumps(structured_data)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
|
||||||
|
if isNoneDoc:
|
||||||
|
for doc in source_log_docs:
|
||||||
|
doc["source"] = "None"
|
||||||
|
llm = LLMCreator.create_llm(
|
||||||
|
settings.LLM_PROVIDER,
|
||||||
|
api_key=settings.API_KEY,
|
||||||
|
user_api_key=user_api_key,
|
||||||
|
decoded_token=decoded_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_save_conversation:
|
||||||
|
conversation_id = self.conversation_service.save_conversation(
|
||||||
|
conversation_id,
|
||||||
|
question,
|
||||||
|
response_full,
|
||||||
|
thought,
|
||||||
|
source_log_docs,
|
||||||
|
tool_calls,
|
||||||
|
llm,
|
||||||
|
self.gpt_model,
|
||||||
|
decoded_token,
|
||||||
|
index=index,
|
||||||
|
api_key=user_api_key,
|
||||||
|
agent_id=agent_id,
|
||||||
|
is_shared_usage=is_shared_usage,
|
||||||
|
shared_token=shared_token,
|
||||||
|
attachment_ids=attachment_ids,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
conversation_id = None
|
||||||
|
id_data = {"type": "id", "id": str(conversation_id)}
|
||||||
|
data = json.dumps(id_data)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
|
||||||
|
log_data = {
|
||||||
|
"action": "stream_answer",
|
||||||
|
"level": "info",
|
||||||
|
"user": decoded_token.get("sub"),
|
||||||
|
"api_key": user_api_key,
|
||||||
|
"question": question,
|
||||||
|
"response": response_full,
|
||||||
|
"sources": source_log_docs,
|
||||||
|
"attachments": attachment_ids,
|
||||||
|
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||||
|
}
|
||||||
|
if is_structured:
|
||||||
|
log_data["structured_output"] = True
|
||||||
|
if schema_info:
|
||||||
|
log_data["schema"] = schema_info
|
||||||
|
|
||||||
|
# Clean up text fields to be no longer than 10000 characters
|
||||||
|
for key, value in log_data.items():
|
||||||
|
if isinstance(value, str) and len(value) > 10000:
|
||||||
|
log_data[key] = value[:10000]
|
||||||
|
|
||||||
|
self.user_logs_collection.insert_one(log_data)
|
||||||
|
|
||||||
|
data = json.dumps({"type": "end"})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
except GeneratorExit:
|
||||||
|
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
|
||||||
|
# Save partial response
|
||||||
|
if should_save_conversation and response_full:
|
||||||
|
try:
|
||||||
|
if isNoneDoc:
|
||||||
|
for doc in source_log_docs:
|
||||||
|
doc["source"] = "None"
|
||||||
|
llm = LLMCreator.create_llm(
|
||||||
|
settings.LLM_PROVIDER,
|
||||||
|
api_key=settings.API_KEY,
|
||||||
|
user_api_key=user_api_key,
|
||||||
|
decoded_token=decoded_token,
|
||||||
|
)
|
||||||
|
self.conversation_service.save_conversation(
|
||||||
|
conversation_id,
|
||||||
|
question,
|
||||||
|
response_full,
|
||||||
|
thought,
|
||||||
|
source_log_docs,
|
||||||
|
tool_calls,
|
||||||
|
llm,
|
||||||
|
self.gpt_model,
|
||||||
|
decoded_token,
|
||||||
|
index=index,
|
||||||
|
api_key=user_api_key,
|
||||||
|
agent_id=agent_id,
|
||||||
|
is_shared_usage=is_shared_usage,
|
||||||
|
shared_token=shared_token,
|
||||||
|
attachment_ids=attachment_ids,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error saving partial response: {str(e)}", exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
||||||
|
data = json.dumps(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"error": "Please try again later. We apologize for any inconvenience.",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
def process_response_stream(self, stream):
|
||||||
|
"""Process the stream response for non-streaming endpoint"""
|
||||||
|
conversation_id = ""
|
||||||
|
response_full = ""
|
||||||
|
source_log_docs = []
|
||||||
|
tool_calls = []
|
||||||
|
thought = ""
|
||||||
|
stream_ended = False
|
||||||
|
is_structured = False
|
||||||
|
schema_info = None
|
||||||
|
|
||||||
|
for line in stream:
|
||||||
|
try:
|
||||||
|
event_data = line.replace("data: ", "").strip()
|
||||||
|
event = json.loads(event_data)
|
||||||
|
|
||||||
|
if event["type"] == "id":
|
||||||
|
conversation_id = event["id"]
|
||||||
|
elif event["type"] == "answer":
|
||||||
|
response_full += event["answer"]
|
||||||
|
elif event["type"] == "structured_answer":
|
||||||
|
response_full = event["answer"]
|
||||||
|
is_structured = True
|
||||||
|
schema_info = event.get("schema")
|
||||||
|
elif event["type"] == "source":
|
||||||
|
source_log_docs = event["source"]
|
||||||
|
elif event["type"] == "tool_calls":
|
||||||
|
tool_calls = event["tool_calls"]
|
||||||
|
elif event["type"] == "thought":
|
||||||
|
thought = event["thought"]
|
||||||
|
elif event["type"] == "error":
|
||||||
|
logger.error(f"Error from stream: {event['error']}")
|
||||||
|
return None, None, None, None, event["error"]
|
||||||
|
elif event["type"] == "end":
|
||||||
|
stream_ended = True
|
||||||
|
except (json.JSONDecodeError, KeyError) as e:
|
||||||
|
logger.warning(f"Error parsing stream event: {e}, line: {line}")
|
||||||
|
continue
|
||||||
|
if not stream_ended:
|
||||||
|
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||||
|
return None, None, None, None, "Stream ended unexpectedly"
|
||||||
|
|
||||||
|
result = (
|
||||||
|
conversation_id,
|
||||||
|
response_full,
|
||||||
|
source_log_docs,
|
||||||
|
tool_calls,
|
||||||
|
thought,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_structured:
|
||||||
|
result = result + ({"structured": True, "schema": schema_info},)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def error_stream_generate(self, err_response):
|
||||||
|
data = json.dumps({"type": "error", "error": err_response})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
127
application/api/answer/routes/stream.py
Normal file
127
application/api/answer/routes/stream.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from flask import request, Response
|
||||||
|
from flask_restx import fields, Resource
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
|
||||||
|
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
|
||||||
|
|
||||||
|
from application.api.answer.services.stream_processor import StreamProcessor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@answer_ns.route("/stream")
|
||||||
|
class StreamResource(Resource, BaseAnswerResource):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
Resource.__init__(self, *args, **kwargs)
|
||||||
|
BaseAnswerResource.__init__(self)
|
||||||
|
|
||||||
|
stream_model = answer_ns.model(
|
||||||
|
"StreamModel",
|
||||||
|
{
|
||||||
|
"question": fields.String(
|
||||||
|
required=True, description="Question to be asked"
|
||||||
|
),
|
||||||
|
"history": fields.List(
|
||||||
|
fields.String,
|
||||||
|
required=False,
|
||||||
|
description="Conversation history (only for new conversations)",
|
||||||
|
),
|
||||||
|
"conversation_id": fields.String(
|
||||||
|
required=False,
|
||||||
|
description="Existing conversation ID (loads history)",
|
||||||
|
),
|
||||||
|
"prompt_id": fields.String(
|
||||||
|
required=False, default="default", description="Prompt ID"
|
||||||
|
),
|
||||||
|
"chunks": fields.Integer(
|
||||||
|
required=False, default=2, description="Number of chunks"
|
||||||
|
),
|
||||||
|
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||||
|
"retriever": fields.String(required=False, description="Retriever type"),
|
||||||
|
"api_key": fields.String(required=False, description="API key"),
|
||||||
|
"active_docs": fields.String(
|
||||||
|
required=False, description="Active documents"
|
||||||
|
),
|
||||||
|
"isNoneDoc": fields.Boolean(
|
||||||
|
required=False, description="Flag indicating if no document is used"
|
||||||
|
),
|
||||||
|
"index": fields.Integer(
|
||||||
|
required=False, description="Index of the query to update"
|
||||||
|
),
|
||||||
|
"save_conversation": fields.Boolean(
|
||||||
|
required=False,
|
||||||
|
default=True,
|
||||||
|
description="Whether to save the conversation",
|
||||||
|
),
|
||||||
|
"attachments": fields.List(
|
||||||
|
fields.String, required=False, description="List of attachment IDs"
|
||||||
|
),
|
||||||
|
"passthrough": fields.Raw(
|
||||||
|
required=False,
|
||||||
|
description="Dynamic parameters to inject into prompt template",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(stream_model)
|
||||||
|
@api.doc(description="Stream a response based on the question and retriever")
|
||||||
|
def post(self):
|
||||||
|
data = request.get_json()
|
||||||
|
if error := self.validate_request(data, "index" in data):
|
||||||
|
return error
|
||||||
|
decoded_token = getattr(request, "decoded_token", None)
|
||||||
|
processor = StreamProcessor(data, decoded_token)
|
||||||
|
try:
|
||||||
|
processor.initialize()
|
||||||
|
|
||||||
|
docs_together, docs_list = processor.pre_fetch_docs(data["question"])
|
||||||
|
tools_data = processor.pre_fetch_tools()
|
||||||
|
|
||||||
|
agent = processor.create_agent(
|
||||||
|
docs_together=docs_together, docs=docs_list, tools_data=tools_data
|
||||||
|
)
|
||||||
|
|
||||||
|
if error := self.check_usage(processor.agent_config):
|
||||||
|
return error
|
||||||
|
return Response(
|
||||||
|
self.complete_stream(
|
||||||
|
question=data["question"],
|
||||||
|
agent=agent,
|
||||||
|
conversation_id=processor.conversation_id,
|
||||||
|
user_api_key=processor.agent_config.get("user_api_key"),
|
||||||
|
decoded_token=processor.decoded_token,
|
||||||
|
isNoneDoc=data.get("isNoneDoc"),
|
||||||
|
index=data.get("index"),
|
||||||
|
should_save_conversation=data.get("save_conversation", True),
|
||||||
|
attachment_ids=data.get("attachments", []),
|
||||||
|
agent_id=data.get("agent_id"),
|
||||||
|
is_shared_usage=processor.is_shared_usage,
|
||||||
|
shared_token=processor.shared_token,
|
||||||
|
),
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
message = "Malformed request body"
|
||||||
|
logger.error(
|
||||||
|
f"/stream - error: {message} - specific error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||||
|
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||||
|
)
|
||||||
|
return Response(
|
||||||
|
self.error_stream_generate(message),
|
||||||
|
status=400,
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||||
|
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||||
|
)
|
||||||
|
return Response(
|
||||||
|
self.error_stream_generate("Unknown error occurred"),
|
||||||
|
status=400,
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
)
|
||||||
0
application/api/answer/services/__init__.py
Normal file
0
application/api/answer/services/__init__.py
Normal file
179
application/api/answer/services/conversation_service.py
Normal file
179
application/api/answer/services/conversation_service.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
|
||||||
|
from application.core.settings import settings
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationService:
|
||||||
|
def __init__(self):
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
|
self.conversations_collection = db["conversations"]
|
||||||
|
self.agents_collection = db["agents"]
|
||||||
|
|
||||||
|
def get_conversation(
|
||||||
|
self, conversation_id: str, user_id: str
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Retrieve a conversation with proper access control"""
|
||||||
|
if not conversation_id or not user_id:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
conversation = self.conversations_collection.find_one(
|
||||||
|
{
|
||||||
|
"_id": ObjectId(conversation_id),
|
||||||
|
"$or": [{"user": user_id}, {"shared_with": user_id}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
logger.warning(
|
||||||
|
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
conversation["_id"] = str(conversation["_id"])
|
||||||
|
return conversation
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def save_conversation(
|
||||||
|
self,
|
||||||
|
conversation_id: Optional[str],
|
||||||
|
question: str,
|
||||||
|
response: str,
|
||||||
|
thought: str,
|
||||||
|
sources: List[Dict[str, Any]],
|
||||||
|
tool_calls: List[Dict[str, Any]],
|
||||||
|
llm: Any,
|
||||||
|
gpt_model: str,
|
||||||
|
decoded_token: Dict[str, Any],
|
||||||
|
index: Optional[int] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
agent_id: Optional[str] = None,
|
||||||
|
is_shared_usage: bool = False,
|
||||||
|
shared_token: Optional[str] = None,
|
||||||
|
attachment_ids: Optional[List[str]] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Save or update a conversation in the database"""
|
||||||
|
user_id = decoded_token.get("sub")
|
||||||
|
if not user_id:
|
||||||
|
raise ValueError("User ID not found in token")
|
||||||
|
current_time = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# clean up in sources array such that we save max 1k characters for text part
|
||||||
|
for source in sources:
|
||||||
|
if "text" in source and isinstance(source["text"], str):
|
||||||
|
source["text"] = source["text"][:1000]
|
||||||
|
|
||||||
|
if conversation_id is not None and index is not None:
|
||||||
|
# Update existing conversation with new query
|
||||||
|
|
||||||
|
result = self.conversations_collection.update_one(
|
||||||
|
{
|
||||||
|
"_id": ObjectId(conversation_id),
|
||||||
|
"user": user_id,
|
||||||
|
f"queries.{index}": {"$exists": True},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$set": {
|
||||||
|
f"queries.{index}.prompt": question,
|
||||||
|
f"queries.{index}.response": response,
|
||||||
|
f"queries.{index}.thought": thought,
|
||||||
|
f"queries.{index}.sources": sources,
|
||||||
|
f"queries.{index}.tool_calls": tool_calls,
|
||||||
|
f"queries.{index}.timestamp": current_time,
|
||||||
|
f"queries.{index}.attachments": attachment_ids,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.matched_count == 0:
|
||||||
|
raise ValueError("Conversation not found or unauthorized")
|
||||||
|
self.conversations_collection.update_one(
|
||||||
|
{
|
||||||
|
"_id": ObjectId(conversation_id),
|
||||||
|
"user": user_id,
|
||||||
|
f"queries.{index}": {"$exists": True},
|
||||||
|
},
|
||||||
|
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
||||||
|
)
|
||||||
|
return conversation_id
|
||||||
|
elif conversation_id:
|
||||||
|
# Append new message to existing conversation
|
||||||
|
|
||||||
|
result = self.conversations_collection.update_one(
|
||||||
|
{"_id": ObjectId(conversation_id), "user": user_id},
|
||||||
|
{
|
||||||
|
"$push": {
|
||||||
|
"queries": {
|
||||||
|
"prompt": question,
|
||||||
|
"response": response,
|
||||||
|
"thought": thought,
|
||||||
|
"sources": sources,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
"timestamp": current_time,
|
||||||
|
"attachments": attachment_ids,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.matched_count == 0:
|
||||||
|
raise ValueError("Conversation not found or unauthorized")
|
||||||
|
return conversation_id
|
||||||
|
else:
|
||||||
|
# Create new conversation
|
||||||
|
|
||||||
|
messages_summary = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant that creates concise conversation titles. "
|
||||||
|
"Summarize conversations in 3 words or less using the same language as the user.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Summarise following conversation in no more than 3 words, "
|
||||||
|
"respond ONLY with the summary, use the same language as the "
|
||||||
|
"user query \n\nUser: " + question + "\n\n" + "AI: " + response,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
completion = llm.gen(
|
||||||
|
model=gpt_model, messages=messages_summary, max_tokens=30
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_data = {
|
||||||
|
"user": user_id,
|
||||||
|
"date": current_time,
|
||||||
|
"name": completion,
|
||||||
|
"queries": [
|
||||||
|
{
|
||||||
|
"prompt": question,
|
||||||
|
"response": response,
|
||||||
|
"thought": thought,
|
||||||
|
"sources": sources,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
"timestamp": current_time,
|
||||||
|
"attachments": attachment_ids,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
if agent_id:
|
||||||
|
conversation_data["agent_id"] = agent_id
|
||||||
|
if is_shared_usage:
|
||||||
|
conversation_data["is_shared_usage"] = is_shared_usage
|
||||||
|
conversation_data["shared_token"] = shared_token
|
||||||
|
agent = self.agents_collection.find_one({"key": api_key})
|
||||||
|
if agent:
|
||||||
|
conversation_data["api_key"] = agent["key"]
|
||||||
|
result = self.conversations_collection.insert_one(conversation_data)
|
||||||
|
return str(result.inserted_id)
|
||||||
97
application/api/answer/services/prompt_renderer.py
Normal file
97
application/api/answer/services/prompt_renderer.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from application.templates.namespaces import NamespaceManager
|
||||||
|
|
||||||
|
from application.templates.template_engine import TemplateEngine, TemplateRenderError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptRenderer:
|
||||||
|
"""Service for rendering prompts with dynamic context using namespaces"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.template_engine = TemplateEngine()
|
||||||
|
self.namespace_manager = NamespaceManager()
|
||||||
|
|
||||||
|
def render_prompt(
|
||||||
|
self,
|
||||||
|
prompt_content: str,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
request_id: Optional[str] = None,
|
||||||
|
passthrough_data: Optional[Dict[str, Any]] = None,
|
||||||
|
docs: Optional[list] = None,
|
||||||
|
docs_together: Optional[str] = None,
|
||||||
|
tools_data: Optional[Dict[str, Any]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Render prompt with full context from all namespaces.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_content: Raw prompt template string
|
||||||
|
user_id: Current user identifier
|
||||||
|
request_id: Unique request identifier
|
||||||
|
passthrough_data: Parameters from web request
|
||||||
|
docs: RAG retrieved documents
|
||||||
|
docs_together: Concatenated document content
|
||||||
|
tools_data: Pre-fetched tool results organized by tool name
|
||||||
|
**kwargs: Additional parameters for namespace builders
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rendered prompt string with all variables substituted
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TemplateRenderError: If template rendering fails
|
||||||
|
"""
|
||||||
|
if not prompt_content:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
uses_template = self._uses_template_syntax(prompt_content)
|
||||||
|
|
||||||
|
if not uses_template:
|
||||||
|
return self._apply_legacy_substitutions(prompt_content, docs_together)
|
||||||
|
|
||||||
|
try:
|
||||||
|
context = self.namespace_manager.build_context(
|
||||||
|
user_id=user_id,
|
||||||
|
request_id=request_id,
|
||||||
|
passthrough_data=passthrough_data,
|
||||||
|
docs=docs,
|
||||||
|
docs_together=docs_together,
|
||||||
|
tools_data=tools_data,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.template_engine.render(prompt_content, context)
|
||||||
|
except TemplateRenderError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Prompt rendering failed: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise TemplateRenderError(error_msg) from e
|
||||||
|
|
||||||
|
def _uses_template_syntax(self, prompt_content: str) -> bool:
|
||||||
|
"""Check if prompt uses Jinja2 template syntax"""
|
||||||
|
return "{{" in prompt_content and "}}" in prompt_content
|
||||||
|
|
||||||
|
def _apply_legacy_substitutions(
|
||||||
|
self, prompt_content: str, docs_together: Optional[str] = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Apply backward-compatible substitutions for old prompt format.
|
||||||
|
|
||||||
|
Handles legacy {summaries} and {query} placeholders during transition period.
|
||||||
|
"""
|
||||||
|
if docs_together:
|
||||||
|
prompt_content = prompt_content.replace("{summaries}", docs_together)
|
||||||
|
return prompt_content
|
||||||
|
|
||||||
|
def validate_template(self, prompt_content: str) -> bool:
|
||||||
|
"""Validate prompt template syntax"""
|
||||||
|
return self.template_engine.validate_template(prompt_content)
|
||||||
|
|
||||||
|
def extract_variables(self, prompt_content: str) -> set[str]:
|
||||||
|
"""Extract all variable names from prompt template"""
|
||||||
|
return self.template_engine.extract_variables(prompt_content)
|
||||||
642
application/api/answer/services/stream_processor.py
Normal file
642
application/api/answer/services/stream_processor.py
Normal file
@@ -0,0 +1,642 @@
|
|||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional, Set
|
||||||
|
|
||||||
|
from bson.dbref import DBRef
|
||||||
|
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
|
||||||
|
from application.agents.agent_creator import AgentCreator
|
||||||
|
from application.api.answer.services.conversation_service import ConversationService
|
||||||
|
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.retriever.retriever_creator import RetrieverCreator
|
||||||
|
from application.utils import (
|
||||||
|
calculate_doc_token_budget,
|
||||||
|
get_gpt_model,
|
||||||
|
limit_chat_history,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt(prompt_id: str, prompts_collection=None) -> str:
|
||||||
|
"""
|
||||||
|
Get a prompt by preset name or MongoDB ID
|
||||||
|
"""
|
||||||
|
current_dir = Path(__file__).resolve().parents[3]
|
||||||
|
prompts_dir = current_dir / "prompts"
|
||||||
|
|
||||||
|
preset_mapping = {
|
||||||
|
"default": "chat_combine_default.txt",
|
||||||
|
"creative": "chat_combine_creative.txt",
|
||||||
|
"strict": "chat_combine_strict.txt",
|
||||||
|
"reduce": "chat_reduce_prompt.txt",
|
||||||
|
}
|
||||||
|
|
||||||
|
if prompt_id in preset_mapping:
|
||||||
|
file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
|
||||||
|
try:
|
||||||
|
with open(file_path, "r") as f:
|
||||||
|
return f.read()
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise FileNotFoundError(f"Prompt file not found: {file_path}")
|
||||||
|
try:
|
||||||
|
if prompts_collection is None:
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
|
prompts_collection = db["prompts"]
|
||||||
|
prompt_doc = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
|
||||||
|
if not prompt_doc:
|
||||||
|
raise ValueError(f"Prompt with ID {prompt_id} not found")
|
||||||
|
return prompt_doc["content"]
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Invalid prompt ID: {prompt_id}") from e
|
||||||
|
|
||||||
|
|
||||||
|
class StreamProcessor:
|
||||||
|
def __init__(
|
||||||
|
self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]]
|
||||||
|
):
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
self.db = mongo[settings.MONGO_DB_NAME]
|
||||||
|
self.agents_collection = self.db["agents"]
|
||||||
|
self.attachments_collection = self.db["attachments"]
|
||||||
|
self.prompts_collection = self.db["prompts"]
|
||||||
|
|
||||||
|
self.data = request_data
|
||||||
|
self.decoded_token = decoded_token
|
||||||
|
self.initial_user_id = (
|
||||||
|
self.decoded_token.get("sub") if self.decoded_token is not None else None
|
||||||
|
)
|
||||||
|
self.conversation_id = self.data.get("conversation_id")
|
||||||
|
self.source = {}
|
||||||
|
self.all_sources = []
|
||||||
|
self.attachments = []
|
||||||
|
self.history = []
|
||||||
|
self.retrieved_docs = []
|
||||||
|
self.agent_config = {}
|
||||||
|
self.retriever_config = {}
|
||||||
|
self.is_shared_usage = False
|
||||||
|
self.shared_token = None
|
||||||
|
self.gpt_model = get_gpt_model()
|
||||||
|
self.conversation_service = ConversationService()
|
||||||
|
self.prompt_renderer = PromptRenderer()
|
||||||
|
self._prompt_content: Optional[str] = None
|
||||||
|
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
"""Initialize all required components for processing"""
|
||||||
|
self._configure_agent()
|
||||||
|
self._configure_source()
|
||||||
|
self._configure_retriever()
|
||||||
|
self._configure_agent()
|
||||||
|
self._load_conversation_history()
|
||||||
|
self._process_attachments()
|
||||||
|
|
||||||
|
def _load_conversation_history(self):
|
||||||
|
"""Load conversation history either from DB or request"""
|
||||||
|
if self.conversation_id and self.initial_user_id:
|
||||||
|
conversation = self.conversation_service.get_conversation(
|
||||||
|
self.conversation_id, self.initial_user_id
|
||||||
|
)
|
||||||
|
if not conversation:
|
||||||
|
raise ValueError("Conversation not found or unauthorized")
|
||||||
|
self.history = [
|
||||||
|
{"prompt": query["prompt"], "response": query["response"]}
|
||||||
|
for query in conversation.get("queries", [])
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
self.history = limit_chat_history(
|
||||||
|
json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
|
||||||
|
)
|
||||||
|
|
||||||
|
def _process_attachments(self):
|
||||||
|
"""Process any attachments in the request"""
|
||||||
|
attachment_ids = self.data.get("attachments", [])
|
||||||
|
self.attachments = self._get_attachments_content(
|
||||||
|
attachment_ids, self.initial_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_attachments_content(self, attachment_ids, user_id):
|
||||||
|
"""
|
||||||
|
Retrieve content from attachment documents based on their IDs.
|
||||||
|
"""
|
||||||
|
if not attachment_ids:
|
||||||
|
return []
|
||||||
|
attachments = []
|
||||||
|
for attachment_id in attachment_ids:
|
||||||
|
try:
|
||||||
|
attachment_doc = self.attachments_collection.find_one(
|
||||||
|
{"_id": ObjectId(attachment_id), "user": user_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
if attachment_doc:
|
||||||
|
attachments.append(attachment_doc)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
|
||||||
|
)
|
||||||
|
return attachments
|
||||||
|
|
||||||
|
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||||||
|
"""Get API key for agent with access control"""
|
||||||
|
if not agent_id:
|
||||||
|
return None, False, None
|
||||||
|
try:
|
||||||
|
agent = self.agents_collection.find_one({"_id": ObjectId(agent_id)})
|
||||||
|
if agent is None:
|
||||||
|
raise Exception("Agent not found")
|
||||||
|
is_owner = agent.get("user") == user_id
|
||||||
|
is_shared_with_user = agent.get(
|
||||||
|
"shared_publicly", False
|
||||||
|
) or user_id in agent.get("shared_with", [])
|
||||||
|
|
||||||
|
if not (is_owner or is_shared_with_user):
|
||||||
|
raise Exception("Unauthorized access to the agent")
|
||||||
|
if is_owner:
|
||||||
|
self.agents_collection.update_one(
|
||||||
|
{"_id": ObjectId(agent_id)},
|
||||||
|
{
|
||||||
|
"$set": {
|
||||||
|
"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return str(agent["key"]), not is_owner, agent.get("shared_token")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
|
||||||
|
data = self.agents_collection.find_one({"key": api_key})
|
||||||
|
if not data:
|
||||||
|
raise Exception("Invalid API Key, please generate a new key", 401)
|
||||||
|
source = data.get("source")
|
||||||
|
if isinstance(source, DBRef):
|
||||||
|
source_doc = self.db.dereference(source)
|
||||||
|
if source_doc:
|
||||||
|
data["source"] = str(source_doc["_id"])
|
||||||
|
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
||||||
|
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||||
|
else:
|
||||||
|
data["source"] = None
|
||||||
|
elif source == "default":
|
||||||
|
data["source"] = "default"
|
||||||
|
else:
|
||||||
|
data["source"] = None
|
||||||
|
# Handle multiple sources
|
||||||
|
|
||||||
|
sources = data.get("sources", [])
|
||||||
|
if sources and isinstance(sources, list):
|
||||||
|
sources_list = []
|
||||||
|
for i, source_ref in enumerate(sources):
|
||||||
|
if source_ref == "default":
|
||||||
|
processed_source = {
|
||||||
|
"id": "default",
|
||||||
|
"retriever": "classic",
|
||||||
|
"chunks": data.get("chunks", "2"),
|
||||||
|
}
|
||||||
|
sources_list.append(processed_source)
|
||||||
|
elif isinstance(source_ref, DBRef):
|
||||||
|
source_doc = self.db.dereference(source_ref)
|
||||||
|
if source_doc:
|
||||||
|
processed_source = {
|
||||||
|
"id": str(source_doc["_id"]),
|
||||||
|
"retriever": source_doc.get("retriever", "classic"),
|
||||||
|
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
|
||||||
|
}
|
||||||
|
sources_list.append(processed_source)
|
||||||
|
data["sources"] = sources_list
|
||||||
|
else:
|
||||||
|
data["sources"] = []
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _configure_source(self):
|
||||||
|
"""Configure the source based on agent data"""
|
||||||
|
api_key = self.data.get("api_key") or self.agent_key
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
agent_data = self._get_data_from_api_key(api_key)
|
||||||
|
|
||||||
|
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
|
||||||
|
source_ids = [
|
||||||
|
source["id"] for source in agent_data["sources"] if source.get("id")
|
||||||
|
]
|
||||||
|
if source_ids:
|
||||||
|
self.source = {"active_docs": source_ids}
|
||||||
|
else:
|
||||||
|
self.source = {}
|
||||||
|
self.all_sources = agent_data["sources"]
|
||||||
|
elif agent_data.get("source"):
|
||||||
|
self.source = {"active_docs": agent_data["source"]}
|
||||||
|
self.all_sources = [
|
||||||
|
{
|
||||||
|
"id": agent_data["source"],
|
||||||
|
"retriever": agent_data.get("retriever", "classic"),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
self.source = {}
|
||||||
|
self.all_sources = []
|
||||||
|
return
|
||||||
|
if "active_docs" in self.data:
|
||||||
|
self.source = {"active_docs": self.data["active_docs"]}
|
||||||
|
return
|
||||||
|
self.source = {}
|
||||||
|
self.all_sources = []
|
||||||
|
|
||||||
|
def _configure_agent(self):
|
||||||
|
"""Configure the agent based on request data"""
|
||||||
|
agent_id = self.data.get("agent_id")
|
||||||
|
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
|
||||||
|
agent_id, self.initial_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key = self.data.get("api_key")
|
||||||
|
if api_key:
|
||||||
|
data_key = self._get_data_from_api_key(api_key)
|
||||||
|
self.agent_config.update(
|
||||||
|
{
|
||||||
|
"prompt_id": data_key.get("prompt_id", "default"),
|
||||||
|
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||||
|
"user_api_key": api_key,
|
||||||
|
"json_schema": data_key.get("json_schema"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.initial_user_id = data_key.get("user")
|
||||||
|
self.decoded_token = {"sub": data_key.get("user")}
|
||||||
|
if data_key.get("source"):
|
||||||
|
self.source = {"active_docs": data_key["source"]}
|
||||||
|
if data_key.get("retriever"):
|
||||||
|
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||||
|
if data_key.get("chunks") is not None:
|
||||||
|
try:
|
||||||
|
self.retriever_config["chunks"] = int(data_key["chunks"])
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.warning(
|
||||||
|
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
||||||
|
)
|
||||||
|
self.retriever_config["chunks"] = 2
|
||||||
|
elif self.agent_key:
|
||||||
|
data_key = self._get_data_from_api_key(self.agent_key)
|
||||||
|
self.agent_config.update(
|
||||||
|
{
|
||||||
|
"prompt_id": data_key.get("prompt_id", "default"),
|
||||||
|
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||||
|
"user_api_key": self.agent_key,
|
||||||
|
"json_schema": data_key.get("json_schema"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.decoded_token = (
|
||||||
|
self.decoded_token
|
||||||
|
if self.is_shared_usage
|
||||||
|
else {"sub": data_key.get("user")}
|
||||||
|
)
|
||||||
|
if data_key.get("source"):
|
||||||
|
self.source = {"active_docs": data_key["source"]}
|
||||||
|
if data_key.get("retriever"):
|
||||||
|
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||||
|
if data_key.get("chunks") is not None:
|
||||||
|
try:
|
||||||
|
self.retriever_config["chunks"] = int(data_key["chunks"])
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.warning(
|
||||||
|
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
||||||
|
)
|
||||||
|
self.retriever_config["chunks"] = 2
|
||||||
|
else:
|
||||||
|
self.agent_config.update(
|
||||||
|
{
|
||||||
|
"prompt_id": self.data.get("prompt_id", "default"),
|
||||||
|
"agent_type": settings.AGENT_NAME,
|
||||||
|
"user_api_key": None,
|
||||||
|
"json_schema": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _configure_retriever(self):
|
||||||
|
history_token_limit = int(self.data.get("token_limit", 2000))
|
||||||
|
doc_token_limit = calculate_doc_token_budget(
|
||||||
|
gpt_model=self.gpt_model, history_token_limit=history_token_limit
|
||||||
|
)
|
||||||
|
|
||||||
|
self.retriever_config = {
|
||||||
|
"retriever_name": self.data.get("retriever", "classic"),
|
||||||
|
"chunks": int(self.data.get("chunks", 2)),
|
||||||
|
"doc_token_limit": doc_token_limit,
|
||||||
|
"history_token_limit": history_token_limit,
|
||||||
|
}
|
||||||
|
|
||||||
|
api_key = self.data.get("api_key") or self.agent_key
|
||||||
|
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||||
|
self.retriever_config["chunks"] = 0
|
||||||
|
|
||||||
|
def create_retriever(self):
|
||||||
|
return RetrieverCreator.create_retriever(
|
||||||
|
self.retriever_config["retriever_name"],
|
||||||
|
source=self.source,
|
||||||
|
chat_history=self.history,
|
||||||
|
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
|
||||||
|
chunks=self.retriever_config["chunks"],
|
||||||
|
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
|
||||||
|
gpt_model=self.gpt_model,
|
||||||
|
user_api_key=self.agent_config["user_api_key"],
|
||||||
|
decoded_token=self.decoded_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
def pre_fetch_docs(self, question: str) -> tuple[Optional[str], Optional[list]]:
|
||||||
|
"""Pre-fetch documents for template rendering before agent creation"""
|
||||||
|
if self.data.get("isNoneDoc", False):
|
||||||
|
logger.info("Pre-fetch skipped: isNoneDoc=True")
|
||||||
|
return None, None
|
||||||
|
try:
|
||||||
|
retriever = self.create_retriever()
|
||||||
|
logger.info(
|
||||||
|
f"Pre-fetching docs with chunks={retriever.chunks}, doc_token_limit={retriever.doc_token_limit}"
|
||||||
|
)
|
||||||
|
docs = retriever.search(question)
|
||||||
|
logger.info(f"Pre-fetch retrieved {len(docs) if docs else 0} documents")
|
||||||
|
|
||||||
|
if not docs:
|
||||||
|
logger.info("Pre-fetch: No documents returned from search")
|
||||||
|
return None, None
|
||||||
|
self.retrieved_docs = docs
|
||||||
|
|
||||||
|
docs_with_filenames = []
|
||||||
|
for doc in docs:
|
||||||
|
filename = doc.get("filename") or doc.get("title") or doc.get("source")
|
||||||
|
if filename:
|
||||||
|
chunk_header = str(filename)
|
||||||
|
docs_with_filenames.append(f"{chunk_header}\n{doc['text']}")
|
||||||
|
else:
|
||||||
|
docs_with_filenames.append(doc["text"])
|
||||||
|
docs_together = "\n\n".join(docs_with_filenames)
|
||||||
|
|
||||||
|
logger.info(f"Pre-fetch docs_together size: {len(docs_together)} chars")
|
||||||
|
|
||||||
|
return docs_together, docs
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to pre-fetch docs: {str(e)}", exc_info=True)
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def pre_fetch_tools(self) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Pre-fetch tool data for template rendering before agent creation
|
||||||
|
|
||||||
|
Can be controlled via:
|
||||||
|
1. Global setting: ENABLE_TOOL_PREFETCH in .env
|
||||||
|
2. Per-request: disable_tool_prefetch in request data
|
||||||
|
"""
|
||||||
|
if not settings.ENABLE_TOOL_PREFETCH:
|
||||||
|
logger.info(
|
||||||
|
"Tool pre-fetching disabled globally via ENABLE_TOOL_PREFETCH setting"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self.data.get("disable_tool_prefetch", False):
|
||||||
|
logger.info("Tool pre-fetching disabled for this request")
|
||||||
|
return None
|
||||||
|
|
||||||
|
required_tool_actions = self._get_required_tool_actions()
|
||||||
|
filtering_enabled = required_tool_actions is not None
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_tools_collection = self.db["user_tools"]
|
||||||
|
user_id = self.initial_user_id or "local"
|
||||||
|
|
||||||
|
user_tools = list(
|
||||||
|
user_tools_collection.find({"user": user_id, "status": True})
|
||||||
|
)
|
||||||
|
|
||||||
|
if not user_tools:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tools_data = {}
|
||||||
|
|
||||||
|
for tool_doc in user_tools:
|
||||||
|
tool_name = tool_doc.get("name")
|
||||||
|
tool_id = str(tool_doc.get("_id"))
|
||||||
|
|
||||||
|
if filtering_enabled:
|
||||||
|
required_actions_by_name = required_tool_actions.get(
|
||||||
|
tool_name, set()
|
||||||
|
)
|
||||||
|
required_actions_by_id = required_tool_actions.get(tool_id, set())
|
||||||
|
|
||||||
|
required_actions = required_actions_by_name | required_actions_by_id
|
||||||
|
|
||||||
|
if not required_actions:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
required_actions = None
|
||||||
|
|
||||||
|
tool_data = self._fetch_tool_data(tool_doc, required_actions)
|
||||||
|
if tool_data:
|
||||||
|
tools_data[tool_name] = tool_data
|
||||||
|
tools_data[tool_id] = tool_data
|
||||||
|
|
||||||
|
return tools_data if tools_data else None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to pre-fetch tools: {type(e).__name__}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _fetch_tool_data(
|
||||||
|
self,
|
||||||
|
tool_doc: Dict[str, Any],
|
||||||
|
required_actions: Optional[Set[Optional[str]]],
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Fetch and execute tool actions with saved parameters"""
|
||||||
|
try:
|
||||||
|
from application.agents.tools.tool_manager import ToolManager
|
||||||
|
|
||||||
|
tool_name = tool_doc.get("name")
|
||||||
|
tool_config = tool_doc.get("config", {}).copy()
|
||||||
|
tool_config["tool_id"] = str(tool_doc["_id"])
|
||||||
|
|
||||||
|
tool_manager = ToolManager(config={tool_name: tool_config})
|
||||||
|
user_id = self.initial_user_id or "local"
|
||||||
|
tool = tool_manager.load_tool(tool_name, tool_config, user_id=user_id)
|
||||||
|
|
||||||
|
if not tool:
|
||||||
|
logger.debug(f"Tool '{tool_name}' failed to load")
|
||||||
|
return None
|
||||||
|
|
||||||
|
tool_actions = tool.get_actions_metadata()
|
||||||
|
if not tool_actions:
|
||||||
|
logger.debug(f"Tool '{tool_name}' has no actions")
|
||||||
|
return None
|
||||||
|
|
||||||
|
saved_actions = tool_doc.get("actions", [])
|
||||||
|
|
||||||
|
include_all_actions = required_actions is None or (
|
||||||
|
required_actions and None in required_actions
|
||||||
|
)
|
||||||
|
allowed_actions: Set[str] = (
|
||||||
|
{action for action in required_actions if isinstance(action, str)}
|
||||||
|
if required_actions
|
||||||
|
else set()
|
||||||
|
)
|
||||||
|
|
||||||
|
action_results = {}
|
||||||
|
for action_meta in tool_actions:
|
||||||
|
action_name = action_meta.get("name")
|
||||||
|
if action_name is None:
|
||||||
|
continue
|
||||||
|
if (
|
||||||
|
not include_all_actions
|
||||||
|
and allowed_actions
|
||||||
|
and action_name not in allowed_actions
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
saved_action = None
|
||||||
|
for sa in saved_actions:
|
||||||
|
if sa.get("name") == action_name:
|
||||||
|
saved_action = sa
|
||||||
|
break
|
||||||
|
|
||||||
|
action_params = action_meta.get("parameters", {})
|
||||||
|
properties = action_params.get("properties", {})
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
for param_name, param_spec in properties.items():
|
||||||
|
if saved_action:
|
||||||
|
saved_props = saved_action.get("parameters", {}).get(
|
||||||
|
"properties", {}
|
||||||
|
)
|
||||||
|
if param_name in saved_props:
|
||||||
|
param_value = saved_props[param_name].get("value")
|
||||||
|
if param_value is not None:
|
||||||
|
kwargs[param_name] = param_value
|
||||||
|
continue
|
||||||
|
|
||||||
|
if param_name in tool_config:
|
||||||
|
kwargs[param_name] = tool_config[param_name]
|
||||||
|
elif "default" in param_spec:
|
||||||
|
kwargs[param_name] = param_spec["default"]
|
||||||
|
|
||||||
|
result = tool.execute_action(action_name, **kwargs)
|
||||||
|
action_results[action_name] = result
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(
|
||||||
|
f"Action '{action_name}' execution failed: {type(e).__name__}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return action_results if action_results else None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Tool pre-fetch failed for '{tool_name}': {type(e).__name__}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_prompt_content(self) -> Optional[str]:
|
||||||
|
"""Retrieve and cache the raw prompt content for the current agent configuration."""
|
||||||
|
if self._prompt_content is not None:
|
||||||
|
return self._prompt_content
|
||||||
|
prompt_id = (
|
||||||
|
self.agent_config.get("prompt_id")
|
||||||
|
if isinstance(self.agent_config, dict)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if not prompt_id:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
self._prompt_content = get_prompt(prompt_id, self.prompts_collection)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.debug(f"Invalid prompt ID '{prompt_id}': {str(e)}")
|
||||||
|
self._prompt_content = None
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to fetch prompt '{prompt_id}': {type(e).__name__}")
|
||||||
|
self._prompt_content = None
|
||||||
|
return self._prompt_content
|
||||||
|
|
||||||
|
def _get_required_tool_actions(self) -> Optional[Dict[str, Set[Optional[str]]]]:
|
||||||
|
"""Determine which tool actions are referenced in the prompt template"""
|
||||||
|
if self._required_tool_actions is not None:
|
||||||
|
return self._required_tool_actions
|
||||||
|
|
||||||
|
prompt_content = self._get_prompt_content()
|
||||||
|
if prompt_content is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if "{{" not in prompt_content or "}}" not in prompt_content:
|
||||||
|
self._required_tool_actions = {}
|
||||||
|
return self._required_tool_actions
|
||||||
|
|
||||||
|
try:
|
||||||
|
from application.templates.template_engine import TemplateEngine
|
||||||
|
|
||||||
|
template_engine = TemplateEngine()
|
||||||
|
usages = template_engine.extract_tool_usages(prompt_content)
|
||||||
|
self._required_tool_actions = usages
|
||||||
|
return self._required_tool_actions
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to extract tool usages: {type(e).__name__}")
|
||||||
|
self._required_tool_actions = {}
|
||||||
|
return self._required_tool_actions
|
||||||
|
|
||||||
|
def _fetch_memory_tool_data(
|
||||||
|
self, tool_doc: Dict[str, Any]
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Fetch memory tool data for pre-injection into prompt"""
|
||||||
|
try:
|
||||||
|
tool_config = tool_doc.get("config", {}).copy()
|
||||||
|
tool_config["tool_id"] = str(tool_doc["_id"])
|
||||||
|
|
||||||
|
from application.agents.tools.memory import MemoryTool
|
||||||
|
|
||||||
|
memory_tool = MemoryTool(tool_config, self.initial_user_id)
|
||||||
|
|
||||||
|
root_view = memory_tool.execute_action("view", path="/")
|
||||||
|
|
||||||
|
if "Error:" in root_view or not root_view.strip():
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {"root": root_view, "available": True}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch memory tool data: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def create_agent(
|
||||||
|
self,
|
||||||
|
docs_together: Optional[str] = None,
|
||||||
|
docs: Optional[list] = None,
|
||||||
|
tools_data: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
"""Create and return the configured agent with rendered prompt"""
|
||||||
|
raw_prompt = self._get_prompt_content()
|
||||||
|
if raw_prompt is None:
|
||||||
|
raw_prompt = get_prompt(
|
||||||
|
self.agent_config["prompt_id"], self.prompts_collection
|
||||||
|
)
|
||||||
|
self._prompt_content = raw_prompt
|
||||||
|
|
||||||
|
rendered_prompt = self.prompt_renderer.render_prompt(
|
||||||
|
prompt_content=raw_prompt,
|
||||||
|
user_id=self.initial_user_id,
|
||||||
|
request_id=self.data.get("request_id"),
|
||||||
|
passthrough_data=self.data.get("passthrough"),
|
||||||
|
docs=docs,
|
||||||
|
docs_together=docs_together,
|
||||||
|
tools_data=tools_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentCreator.create_agent(
|
||||||
|
self.agent_config["agent_type"],
|
||||||
|
endpoint="stream",
|
||||||
|
llm_name=settings.LLM_PROVIDER,
|
||||||
|
gpt_model=self.gpt_model,
|
||||||
|
api_key=settings.API_KEY,
|
||||||
|
user_api_key=self.agent_config["user_api_key"],
|
||||||
|
prompt=rendered_prompt,
|
||||||
|
chat_history=self.history,
|
||||||
|
retrieved_docs=self.retrieved_docs,
|
||||||
|
decoded_token=self.decoded_token,
|
||||||
|
attachments=self.attachments,
|
||||||
|
json_schema=self.agent_config.get("json_schema"),
|
||||||
|
)
|
||||||
489
application/api/connector/routes.py
Normal file
489
application/api/connector/routes.py
Normal file
@@ -0,0 +1,489 @@
|
|||||||
|
import base64
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
from flask import (
|
||||||
|
Blueprint,
|
||||||
|
current_app,
|
||||||
|
jsonify,
|
||||||
|
make_response,
|
||||||
|
request
|
||||||
|
)
|
||||||
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
|
|
||||||
|
from application.api.user.tasks import (
|
||||||
|
ingest_connector_task,
|
||||||
|
)
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.api import api
|
||||||
|
|
||||||
|
|
||||||
|
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||||
|
|
||||||
|
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
|
sources_collection = db["sources"]
|
||||||
|
sessions_collection = db["connector_sessions"]
|
||||||
|
|
||||||
|
connector = Blueprint("connector", __name__)
|
||||||
|
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
|
||||||
|
api.add_namespace(connectors_ns)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@connectors_ns.route("/api/connectors/auth")
|
||||||
|
class ConnectorAuth(Resource):
|
||||||
|
@api.doc(description="Get connector OAuth authorization URL", params={"provider": "Connector provider (e.g., google_drive)"})
|
||||||
|
def get(self):
|
||||||
|
try:
|
||||||
|
provider = request.args.get('provider') or request.args.get('source')
|
||||||
|
if not provider:
|
||||||
|
return make_response(jsonify({"success": False, "error": "Missing provider"}), 400)
|
||||||
|
|
||||||
|
if not ConnectorCreator.is_supported(provider):
|
||||||
|
return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
|
||||||
|
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||||
|
user_id = decoded_token.get('sub')
|
||||||
|
|
||||||
|
now = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
result = sessions_collection.insert_one({
|
||||||
|
"provider": provider,
|
||||||
|
"user": user_id,
|
||||||
|
"status": "pending",
|
||||||
|
"created_at": now
|
||||||
|
})
|
||||||
|
state_dict = {
|
||||||
|
"provider": provider,
|
||||||
|
"object_id": str(result.inserted_id)
|
||||||
|
}
|
||||||
|
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
|
||||||
|
|
||||||
|
auth = ConnectorCreator.create_auth(provider)
|
||||||
|
authorization_url = auth.get_authorization_url(state=state)
|
||||||
|
return make_response(jsonify({
|
||||||
|
"success": True,
|
||||||
|
"authorization_url": authorization_url,
|
||||||
|
"state": state
|
||||||
|
}), 200)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error generating connector auth URL: {e}")
|
||||||
|
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@connectors_ns.route("/api/connectors/callback")
|
||||||
|
class ConnectorsCallback(Resource):
|
||||||
|
@api.doc(description="Handle OAuth callback for external connectors")
|
||||||
|
def get(self):
|
||||||
|
"""Handle OAuth callback for external connectors"""
|
||||||
|
try:
|
||||||
|
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||||
|
from flask import request, redirect
|
||||||
|
|
||||||
|
authorization_code = request.args.get('code')
|
||||||
|
state = request.args.get('state')
|
||||||
|
error = request.args.get('error')
|
||||||
|
|
||||||
|
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||||
|
provider = state_dict["provider"]
|
||||||
|
state_object_id = state_dict["object_id"]
|
||||||
|
|
||||||
|
if error:
|
||||||
|
if error == "access_denied":
|
||||||
|
return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}")
|
||||||
|
else:
|
||||||
|
current_app.logger.warning(f"OAuth error in callback: {error}")
|
||||||
|
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
||||||
|
|
||||||
|
if not authorization_code:
|
||||||
|
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
auth = ConnectorCreator.create_auth(provider)
|
||||||
|
token_info = auth.exchange_code_for_tokens(authorization_code)
|
||||||
|
|
||||||
|
session_token = str(uuid.uuid4())
|
||||||
|
|
||||||
|
try:
|
||||||
|
credentials = auth.create_credentials_from_token_info(token_info)
|
||||||
|
service = auth.build_drive_service(credentials)
|
||||||
|
user_info = service.about().get(fields="user").execute()
|
||||||
|
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.warning(f"Could not get user info: {e}")
|
||||||
|
user_email = 'Connected User'
|
||||||
|
|
||||||
|
sanitized_token_info = {
|
||||||
|
"access_token": token_info.get("access_token"),
|
||||||
|
"refresh_token": token_info.get("refresh_token"),
|
||||||
|
"token_uri": token_info.get("token_uri"),
|
||||||
|
"expiry": token_info.get("expiry")
|
||||||
|
}
|
||||||
|
|
||||||
|
sessions_collection.find_one_and_update(
|
||||||
|
{"_id": ObjectId(state_object_id), "provider": provider},
|
||||||
|
{
|
||||||
|
"$set": {
|
||||||
|
"session_token": session_token,
|
||||||
|
"token_info": sanitized_token_info,
|
||||||
|
"user_email": user_email,
|
||||||
|
"status": "authorized"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redirect to success page with session token and user email
|
||||||
|
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
|
||||||
|
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error handling connector callback: {e}")
|
||||||
|
return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.")
|
||||||
|
|
||||||
|
|
||||||
|
@connectors_ns.route("/api/connectors/files")
|
||||||
|
class ConnectorFiles(Resource):
|
||||||
|
@api.expect(api.model("ConnectorFilesModel", {
|
||||||
|
"provider": fields.String(required=True),
|
||||||
|
"session_token": fields.String(required=True),
|
||||||
|
"folder_id": fields.String(required=False),
|
||||||
|
"limit": fields.Integer(required=False),
|
||||||
|
"page_token": fields.String(required=False),
|
||||||
|
"search_query": fields.String(required=False)
|
||||||
|
}))
|
||||||
|
@api.doc(description="List files from a connector provider (supports pagination and search)")
|
||||||
|
def post(self):
|
||||||
|
try:
|
||||||
|
data = request.get_json()
|
||||||
|
provider = data.get('provider')
|
||||||
|
session_token = data.get('session_token')
|
||||||
|
folder_id = data.get('folder_id')
|
||||||
|
limit = data.get('limit', 10)
|
||||||
|
page_token = data.get('page_token')
|
||||||
|
search_query = data.get('search_query')
|
||||||
|
|
||||||
|
if not provider or not session_token:
|
||||||
|
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||||
|
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||||
|
user = decoded_token.get('sub')
|
||||||
|
session = sessions_collection.find_one({"session_token": session_token, "user": user})
|
||||||
|
if not session:
|
||||||
|
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
|
||||||
|
|
||||||
|
loader = ConnectorCreator.create_connector(provider, session_token)
|
||||||
|
input_config = {
|
||||||
|
'limit': limit,
|
||||||
|
'list_only': True,
|
||||||
|
'session_token': session_token,
|
||||||
|
'folder_id': folder_id,
|
||||||
|
'page_token': page_token
|
||||||
|
}
|
||||||
|
if search_query:
|
||||||
|
input_config['search_query'] = search_query
|
||||||
|
|
||||||
|
documents = loader.load_data(input_config)
|
||||||
|
|
||||||
|
files = []
|
||||||
|
for doc in documents[:limit]:
|
||||||
|
metadata = doc.extra_info
|
||||||
|
modified_time = metadata.get('modified_time')
|
||||||
|
if modified_time:
|
||||||
|
date_part = modified_time.split('T')[0]
|
||||||
|
time_part = modified_time.split('T')[1].split('.')[0].split('Z')[0]
|
||||||
|
formatted_time = f"{date_part} {time_part}"
|
||||||
|
else:
|
||||||
|
formatted_time = None
|
||||||
|
|
||||||
|
files.append({
|
||||||
|
'id': doc.doc_id,
|
||||||
|
'name': metadata.get('file_name', 'Unknown File'),
|
||||||
|
'type': metadata.get('mime_type', 'unknown'),
|
||||||
|
'size': metadata.get('size', None),
|
||||||
|
'modifiedTime': formatted_time,
|
||||||
|
'isFolder': metadata.get('is_folder', False)
|
||||||
|
})
|
||||||
|
|
||||||
|
next_token = getattr(loader, 'next_page_token', None)
|
||||||
|
has_more = bool(next_token)
|
||||||
|
|
||||||
|
return make_response(jsonify({
|
||||||
|
"success": True,
|
||||||
|
"files": files,
|
||||||
|
"total": len(files),
|
||||||
|
"next_page_token": next_token,
|
||||||
|
"has_more": has_more
|
||||||
|
}), 200)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error loading connector files: {e}")
|
||||||
|
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@connectors_ns.route("/api/connectors/validate-session")
|
||||||
|
class ConnectorValidateSession(Resource):
|
||||||
|
@api.expect(api.model("ConnectorValidateSessionModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True)}))
|
||||||
|
@api.doc(description="Validate connector session token and return user info and access token")
|
||||||
|
def post(self):
|
||||||
|
try:
|
||||||
|
data = request.get_json()
|
||||||
|
provider = data.get('provider')
|
||||||
|
session_token = data.get('session_token')
|
||||||
|
if not provider or not session_token:
|
||||||
|
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||||
|
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||||
|
user = decoded_token.get('sub')
|
||||||
|
|
||||||
|
session = sessions_collection.find_one({"session_token": session_token, "user": user})
|
||||||
|
if not session or "token_info" not in session:
|
||||||
|
return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
|
||||||
|
|
||||||
|
token_info = session["token_info"]
|
||||||
|
auth = ConnectorCreator.create_auth(provider)
|
||||||
|
is_expired = auth.is_token_expired(token_info)
|
||||||
|
|
||||||
|
if is_expired and token_info.get('refresh_token'):
|
||||||
|
try:
|
||||||
|
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
|
||||||
|
sanitized_token_info = {
|
||||||
|
"access_token": refreshed_token_info.get("access_token"),
|
||||||
|
"refresh_token": refreshed_token_info.get("refresh_token"),
|
||||||
|
"token_uri": refreshed_token_info.get("token_uri"),
|
||||||
|
"expiry": refreshed_token_info.get("expiry")
|
||||||
|
}
|
||||||
|
sessions_collection.update_one(
|
||||||
|
{"session_token": session_token},
|
||||||
|
{"$set": {"token_info": sanitized_token_info}}
|
||||||
|
)
|
||||||
|
token_info = sanitized_token_info
|
||||||
|
is_expired = False
|
||||||
|
except Exception as refresh_error:
|
||||||
|
current_app.logger.error(f"Failed to refresh token: {refresh_error}")
|
||||||
|
|
||||||
|
if is_expired:
|
||||||
|
return make_response(jsonify({
|
||||||
|
"success": False,
|
||||||
|
"expired": True,
|
||||||
|
"error": "Session token has expired. Please reconnect."
|
||||||
|
}), 401)
|
||||||
|
|
||||||
|
return make_response(jsonify({
|
||||||
|
"success": True,
|
||||||
|
"expired": False,
|
||||||
|
"user_email": session.get('user_email', 'Connected User'),
|
||||||
|
"access_token": token_info.get('access_token')
|
||||||
|
}), 200)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error validating connector session: {e}")
|
||||||
|
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@connectors_ns.route("/api/connectors/disconnect")
|
||||||
|
class ConnectorDisconnect(Resource):
|
||||||
|
@api.expect(api.model("ConnectorDisconnectModel", {"provider": fields.String(required=True), "session_token": fields.String(required=False)}))
|
||||||
|
@api.doc(description="Disconnect a connector session")
|
||||||
|
def post(self):
|
||||||
|
try:
|
||||||
|
data = request.get_json()
|
||||||
|
provider = data.get('provider')
|
||||||
|
session_token = data.get('session_token')
|
||||||
|
if not provider:
|
||||||
|
return make_response(jsonify({"success": False, "error": "provider is required"}), 400)
|
||||||
|
|
||||||
|
|
||||||
|
if session_token:
|
||||||
|
sessions_collection.delete_one({"session_token": session_token})
|
||||||
|
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error disconnecting connector session: {e}")
|
||||||
|
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@connectors_ns.route("/api/connectors/sync")
|
||||||
|
class ConnectorSync(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"ConnectorSyncModel",
|
||||||
|
{
|
||||||
|
"source_id": fields.String(required=True, description="Source ID to sync"),
|
||||||
|
"session_token": fields.String(required=True, description="Authentication token")
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(description="Sync connector source to check for modifications")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = request.get_json()
|
||||||
|
source_id = data.get('source_id')
|
||||||
|
session_token = data.get('session_token')
|
||||||
|
|
||||||
|
if not all([source_id, session_token]):
|
||||||
|
return make_response(
|
||||||
|
jsonify({
|
||||||
|
"success": False,
|
||||||
|
"error": "source_id and session_token are required"
|
||||||
|
}),
|
||||||
|
400
|
||||||
|
)
|
||||||
|
source = sources_collection.find_one({"_id": ObjectId(source_id)})
|
||||||
|
if not source:
|
||||||
|
return make_response(
|
||||||
|
jsonify({
|
||||||
|
"success": False,
|
||||||
|
"error": "Source not found"
|
||||||
|
}),
|
||||||
|
404
|
||||||
|
)
|
||||||
|
|
||||||
|
if source.get('user') != decoded_token.get('sub'):
|
||||||
|
return make_response(
|
||||||
|
jsonify({
|
||||||
|
"success": False,
|
||||||
|
"error": "Unauthorized access to source"
|
||||||
|
}),
|
||||||
|
403
|
||||||
|
)
|
||||||
|
|
||||||
|
remote_data = {}
|
||||||
|
try:
|
||||||
|
if source.get('remote_data'):
|
||||||
|
remote_data = json.loads(source.get('remote_data'))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
|
||||||
|
remote_data = {}
|
||||||
|
|
||||||
|
source_type = remote_data.get('provider')
|
||||||
|
if not source_type:
|
||||||
|
return make_response(
|
||||||
|
jsonify({
|
||||||
|
"success": False,
|
||||||
|
"error": "Source provider not found in remote_data"
|
||||||
|
}),
|
||||||
|
400
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract configuration from remote_data
|
||||||
|
file_ids = remote_data.get('file_ids', [])
|
||||||
|
folder_ids = remote_data.get('folder_ids', [])
|
||||||
|
recursive = remote_data.get('recursive', True)
|
||||||
|
|
||||||
|
# Start the sync task
|
||||||
|
task = ingest_connector_task.delay(
|
||||||
|
job_name=source.get('name'),
|
||||||
|
user=decoded_token.get('sub'),
|
||||||
|
source_type=source_type,
|
||||||
|
session_token=session_token,
|
||||||
|
file_ids=file_ids,
|
||||||
|
folder_ids=folder_ids,
|
||||||
|
recursive=recursive,
|
||||||
|
retriever=source.get('retriever', 'classic'),
|
||||||
|
operation_mode="sync",
|
||||||
|
doc_id=source_id,
|
||||||
|
sync_frequency=source.get('sync_frequency', 'never')
|
||||||
|
)
|
||||||
|
|
||||||
|
return make_response(
|
||||||
|
jsonify({
|
||||||
|
"success": True,
|
||||||
|
"task_id": task.id
|
||||||
|
}),
|
||||||
|
200
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error syncing connector source: {err}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify({
|
||||||
|
"success": False,
|
||||||
|
"error": str(err)
|
||||||
|
}),
|
||||||
|
400
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@connectors_ns.route("/api/connectors/callback-status")
|
||||||
|
class ConnectorCallbackStatus(Resource):
|
||||||
|
@api.doc(description="Return HTML page with connector authentication status")
|
||||||
|
def get(self):
|
||||||
|
"""Return HTML page with connector authentication status"""
|
||||||
|
try:
|
||||||
|
status = request.args.get('status', 'error')
|
||||||
|
message = request.args.get('message', '')
|
||||||
|
provider = request.args.get('provider', 'connector')
|
||||||
|
session_token = request.args.get('session_token', '')
|
||||||
|
user_email = request.args.get('user_email', '')
|
||||||
|
|
||||||
|
html_content = f"""
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>{provider.replace('_', ' ').title()} Authentication</title>
|
||||||
|
<style>
|
||||||
|
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
|
||||||
|
.container {{ max-width: 600px; margin: 0 auto; }}
|
||||||
|
.success {{ color: #4CAF50; }}
|
||||||
|
.error {{ color: #F44336; }}
|
||||||
|
.cancelled {{ color: #FF9800; }}
|
||||||
|
</style>
|
||||||
|
<script>
|
||||||
|
window.onload = function() {{
|
||||||
|
const status = "{status}";
|
||||||
|
const sessionToken = "{session_token}";
|
||||||
|
const userEmail = "{user_email}";
|
||||||
|
|
||||||
|
if (status === "success" && window.opener) {{
|
||||||
|
window.opener.postMessage({{
|
||||||
|
type: '{provider}_auth_success',
|
||||||
|
session_token: sessionToken,
|
||||||
|
user_email: userEmail
|
||||||
|
}}, '*');
|
||||||
|
|
||||||
|
setTimeout(() => window.close(), 3000);
|
||||||
|
}} else if (status === "cancelled" || status === "error") {{
|
||||||
|
setTimeout(() => window.close(), 3000);
|
||||||
|
}}
|
||||||
|
}};
|
||||||
|
</script>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<h2>{provider.replace('_', ' ').title()} Authentication</h2>
|
||||||
|
<div class="{status}">
|
||||||
|
<p>{message}</p>
|
||||||
|
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
|
||||||
|
</div>
|
||||||
|
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
return make_response(html_content, 200, {'Content-Type': 'text/html'})
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error rendering callback status page: {e}")
|
||||||
|
return make_response("Authentication error occurred", 500, {'Content-Type': 'text/html'})
|
||||||
|
|
||||||
|
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import datetime
|
import datetime
|
||||||
|
import json
|
||||||
from flask import Blueprint, request, send_from_directory
|
from flask import Blueprint, request, send_from_directory
|
||||||
from werkzeug.utils import secure_filename
|
from werkzeug.utils import secure_filename
|
||||||
from bson.objectid import ObjectId
|
from bson.objectid import ObjectId
|
||||||
@@ -37,16 +38,28 @@ def upload_index_files():
|
|||||||
"""Upload two files(index.faiss, index.pkl) to the user's folder."""
|
"""Upload two files(index.faiss, index.pkl) to the user's folder."""
|
||||||
if "user" not in request.form:
|
if "user" not in request.form:
|
||||||
return {"status": "no user"}
|
return {"status": "no user"}
|
||||||
user = secure_filename(request.form["user"])
|
user = request.form["user"]
|
||||||
if "name" not in request.form:
|
if "name" not in request.form:
|
||||||
return {"status": "no name"}
|
return {"status": "no name"}
|
||||||
job_name = secure_filename(request.form["name"])
|
job_name = request.form["name"]
|
||||||
tokens = secure_filename(request.form["tokens"])
|
tokens = request.form["tokens"]
|
||||||
retriever = secure_filename(request.form["retriever"])
|
retriever = request.form["retriever"]
|
||||||
id = secure_filename(request.form["id"])
|
id = request.form["id"]
|
||||||
type = secure_filename(request.form["type"])
|
type = request.form["type"]
|
||||||
remote_data = request.form["remote_data"] if "remote_data" in request.form else None
|
remote_data = request.form["remote_data"] if "remote_data" in request.form else None
|
||||||
sync_frequency = secure_filename(request.form["sync_frequency"]) if "sync_frequency" in request.form else None
|
sync_frequency = request.form["sync_frequency"] if "sync_frequency" in request.form else None
|
||||||
|
|
||||||
|
file_path = request.form.get("file_path")
|
||||||
|
directory_structure = request.form.get("directory_structure")
|
||||||
|
|
||||||
|
if directory_structure:
|
||||||
|
try:
|
||||||
|
directory_structure = json.loads(directory_structure)
|
||||||
|
except Exception:
|
||||||
|
logger.error("Error parsing directory_structure")
|
||||||
|
directory_structure = {}
|
||||||
|
else:
|
||||||
|
directory_structure = {}
|
||||||
|
|
||||||
storage = StorageCreator.get_storage()
|
storage = StorageCreator.get_storage()
|
||||||
index_base_path = f"indexes/{id}"
|
index_base_path = f"indexes/{id}"
|
||||||
@@ -64,10 +77,13 @@ def upload_index_files():
|
|||||||
file_pkl = request.files["file_pkl"]
|
file_pkl = request.files["file_pkl"]
|
||||||
if file_pkl.filename == "":
|
if file_pkl.filename == "":
|
||||||
return {"status": "no file name"}
|
return {"status": "no file name"}
|
||||||
|
|
||||||
# Save index files to storage
|
# Save index files to storage
|
||||||
storage.save_file(file_faiss, f"{index_base_path}/index.faiss")
|
faiss_storage_path = f"{index_base_path}/index.faiss"
|
||||||
storage.save_file(file_pkl, f"{index_base_path}/index.pkl")
|
pkl_storage_path = f"{index_base_path}/index.pkl"
|
||||||
|
storage.save_file(file_faiss, faiss_storage_path)
|
||||||
|
storage.save_file(file_pkl, pkl_storage_path)
|
||||||
|
|
||||||
|
|
||||||
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
|
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
|
||||||
if existing_entry:
|
if existing_entry:
|
||||||
@@ -85,6 +101,8 @@ def upload_index_files():
|
|||||||
"retriever": retriever,
|
"retriever": retriever,
|
||||||
"remote_data": remote_data,
|
"remote_data": remote_data,
|
||||||
"sync_frequency": sync_frequency,
|
"sync_frequency": sync_frequency,
|
||||||
|
"file_path": file_path,
|
||||||
|
"directory_structure": directory_structure,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -102,6 +120,8 @@ def upload_index_files():
|
|||||||
"retriever": retriever,
|
"retriever": retriever,
|
||||||
"remote_data": remote_data,
|
"remote_data": remote_data,
|
||||||
"sync_frequency": sync_frequency,
|
"sync_frequency": sync_frequency,
|
||||||
|
"file_path": file_path,
|
||||||
|
"directory_structure": directory_structure,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
"""User API module - provides all user-related API endpoints"""
|
||||||
|
|
||||||
|
from .routes import user
|
||||||
|
|
||||||
|
__all__ = ["user"]
|
||||||
|
|||||||
7
application/api/user/agents/__init__.py
Normal file
7
application/api/user/agents/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""Agents module."""
|
||||||
|
|
||||||
|
from .routes import agents_ns
|
||||||
|
from .sharing import agents_sharing_ns
|
||||||
|
from .webhooks import agents_webhooks_ns
|
||||||
|
|
||||||
|
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns"]
|
||||||
1111
application/api/user/agents/routes.py
Normal file
1111
application/api/user/agents/routes.py
Normal file
File diff suppressed because it is too large
Load Diff
263
application/api/user/agents/sharing.py
Normal file
263
application/api/user/agents/sharing.py
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
"""Agent management sharing functionality."""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from bson import DBRef
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
from flask import current_app, jsonify, make_response, request
|
||||||
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.api.user.base import (
|
||||||
|
agents_collection,
|
||||||
|
db,
|
||||||
|
ensure_user_doc,
|
||||||
|
resolve_tool_details,
|
||||||
|
user_tools_collection,
|
||||||
|
users_collection,
|
||||||
|
)
|
||||||
|
from application.utils import generate_image_url
|
||||||
|
|
||||||
|
agents_sharing_ns = Namespace(
|
||||||
|
"agents", description="Agent management operations", path="/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@agents_sharing_ns.route("/shared_agent")
|
||||||
|
class SharedAgent(Resource):
|
||||||
|
@api.doc(
|
||||||
|
params={
|
||||||
|
"token": "Shared token of the agent",
|
||||||
|
},
|
||||||
|
description="Get a shared agent by token or ID",
|
||||||
|
)
|
||||||
|
def get(self):
|
||||||
|
shared_token = request.args.get("token")
|
||||||
|
|
||||||
|
if not shared_token:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Token or ID is required"}), 400
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
query = {
|
||||||
|
"shared_publicly": True,
|
||||||
|
"shared_token": shared_token,
|
||||||
|
}
|
||||||
|
shared_agent = agents_collection.find_one(query)
|
||||||
|
if not shared_agent:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Shared agent not found"}),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
agent_id = str(shared_agent["_id"])
|
||||||
|
data = {
|
||||||
|
"id": agent_id,
|
||||||
|
"user": shared_agent.get("user", ""),
|
||||||
|
"name": shared_agent.get("name", ""),
|
||||||
|
"image": (
|
||||||
|
generate_image_url(shared_agent["image"])
|
||||||
|
if shared_agent.get("image")
|
||||||
|
else ""
|
||||||
|
),
|
||||||
|
"description": shared_agent.get("description", ""),
|
||||||
|
"source": (
|
||||||
|
str(source_doc["_id"])
|
||||||
|
if isinstance(shared_agent.get("source"), DBRef)
|
||||||
|
and (source_doc := db.dereference(shared_agent.get("source")))
|
||||||
|
else ""
|
||||||
|
),
|
||||||
|
"chunks": shared_agent.get("chunks", "0"),
|
||||||
|
"retriever": shared_agent.get("retriever", "classic"),
|
||||||
|
"prompt_id": shared_agent.get("prompt_id", "default"),
|
||||||
|
"tools": shared_agent.get("tools", []),
|
||||||
|
"tool_details": resolve_tool_details(shared_agent.get("tools", [])),
|
||||||
|
"agent_type": shared_agent.get("agent_type", ""),
|
||||||
|
"status": shared_agent.get("status", ""),
|
||||||
|
"json_schema": shared_agent.get("json_schema"),
|
||||||
|
"limited_token_mode": shared_agent.get("limited_token_mode", False),
|
||||||
|
"token_limit": shared_agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
|
||||||
|
"limited_request_mode": shared_agent.get("limited_request_mode", False),
|
||||||
|
"request_limit": shared_agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
|
||||||
|
"created_at": shared_agent.get("createdAt", ""),
|
||||||
|
"updated_at": shared_agent.get("updatedAt", ""),
|
||||||
|
"shared": shared_agent.get("shared_publicly", False),
|
||||||
|
"shared_token": shared_agent.get("shared_token", ""),
|
||||||
|
"shared_metadata": shared_agent.get("shared_metadata", {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
if data["tools"]:
|
||||||
|
enriched_tools = []
|
||||||
|
for tool in data["tools"]:
|
||||||
|
tool_data = user_tools_collection.find_one({"_id": ObjectId(tool)})
|
||||||
|
if tool_data:
|
||||||
|
enriched_tools.append(tool_data.get("name", ""))
|
||||||
|
data["tools"] = enriched_tools
|
||||||
|
decoded_token = getattr(request, "decoded_token", None)
|
||||||
|
if decoded_token:
|
||||||
|
user_id = decoded_token.get("sub")
|
||||||
|
owner_id = shared_agent.get("user")
|
||||||
|
|
||||||
|
if user_id != owner_id:
|
||||||
|
ensure_user_doc(user_id)
|
||||||
|
users_collection.update_one(
|
||||||
|
{"user_id": user_id},
|
||||||
|
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
|
||||||
|
)
|
||||||
|
return make_response(jsonify(data), 200)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error retrieving shared agent: {err}")
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
|
||||||
|
|
||||||
|
@agents_sharing_ns.route("/shared_agents")
|
||||||
|
class SharedAgents(Resource):
|
||||||
|
@api.doc(description="Get shared agents explicitly shared with the user")
|
||||||
|
def get(self):
|
||||||
|
try:
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user_id = decoded_token.get("sub")
|
||||||
|
|
||||||
|
user_doc = ensure_user_doc(user_id)
|
||||||
|
shared_with_ids = user_doc.get("agent_preferences", {}).get(
|
||||||
|
"shared_with_me", []
|
||||||
|
)
|
||||||
|
shared_object_ids = [ObjectId(id) for id in shared_with_ids]
|
||||||
|
|
||||||
|
shared_agents_cursor = agents_collection.find(
|
||||||
|
{"_id": {"$in": shared_object_ids}, "shared_publicly": True}
|
||||||
|
)
|
||||||
|
shared_agents = list(shared_agents_cursor)
|
||||||
|
|
||||||
|
found_ids_set = {str(agent["_id"]) for agent in shared_agents}
|
||||||
|
stale_ids = [id for id in shared_with_ids if id not in found_ids_set]
|
||||||
|
if stale_ids:
|
||||||
|
users_collection.update_one(
|
||||||
|
{"user_id": user_id},
|
||||||
|
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
|
||||||
|
)
|
||||||
|
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
|
||||||
|
|
||||||
|
list_shared_agents = [
|
||||||
|
{
|
||||||
|
"id": str(agent["_id"]),
|
||||||
|
"name": agent.get("name", ""),
|
||||||
|
"description": agent.get("description", ""),
|
||||||
|
"image": (
|
||||||
|
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||||
|
),
|
||||||
|
"tools": agent.get("tools", []),
|
||||||
|
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||||
|
"agent_type": agent.get("agent_type", ""),
|
||||||
|
"status": agent.get("status", ""),
|
||||||
|
"json_schema": agent.get("json_schema"),
|
||||||
|
"limited_token_mode": agent.get("limited_token_mode", False),
|
||||||
|
"token_limit": agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
|
||||||
|
"limited_request_mode": agent.get("limited_request_mode", False),
|
||||||
|
"request_limit": agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
|
||||||
|
"created_at": agent.get("createdAt", ""),
|
||||||
|
"updated_at": agent.get("updatedAt", ""),
|
||||||
|
"pinned": str(agent["_id"]) in pinned_ids,
|
||||||
|
"shared": agent.get("shared_publicly", False),
|
||||||
|
"shared_token": agent.get("shared_token", ""),
|
||||||
|
"shared_metadata": agent.get("shared_metadata", {}),
|
||||||
|
}
|
||||||
|
for agent in shared_agents
|
||||||
|
]
|
||||||
|
|
||||||
|
return make_response(jsonify(list_shared_agents), 200)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error retrieving shared agents: {err}")
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
|
||||||
|
|
||||||
|
@agents_sharing_ns.route("/share_agent")
|
||||||
|
class ShareAgent(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"ShareAgentModel",
|
||||||
|
{
|
||||||
|
"id": fields.String(required=True, description="ID of the agent"),
|
||||||
|
"shared": fields.Boolean(
|
||||||
|
required=True, description="Share or unshare the agent"
|
||||||
|
),
|
||||||
|
"username": fields.String(
|
||||||
|
required=False, description="Name of the user"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(description="Share or unshare an agent")
|
||||||
|
def put(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
|
||||||
|
data = request.get_json()
|
||||||
|
if not data:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Missing JSON body"}), 400
|
||||||
|
)
|
||||||
|
agent_id = data.get("id")
|
||||||
|
shared = data.get("shared")
|
||||||
|
username = data.get("username", "")
|
||||||
|
|
||||||
|
if not agent_id:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "ID is required"}), 400
|
||||||
|
)
|
||||||
|
if shared is None:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "Shared parameter is required and must be true or false",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
agent_oid = ObjectId(agent_id)
|
||||||
|
except Exception:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid agent ID"}), 400
|
||||||
|
)
|
||||||
|
agent = agents_collection.find_one({"_id": agent_oid, "user": user})
|
||||||
|
if not agent:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||||
|
)
|
||||||
|
if shared:
|
||||||
|
shared_metadata = {
|
||||||
|
"shared_by": username,
|
||||||
|
"shared_at": datetime.datetime.now(datetime.timezone.utc),
|
||||||
|
}
|
||||||
|
shared_token = secrets.token_urlsafe(32)
|
||||||
|
agents_collection.update_one(
|
||||||
|
{"_id": agent_oid, "user": user},
|
||||||
|
{
|
||||||
|
"$set": {
|
||||||
|
"shared_publicly": shared,
|
||||||
|
"shared_metadata": shared_metadata,
|
||||||
|
"shared_token": shared_token,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
agents_collection.update_one(
|
||||||
|
{"_id": agent_oid, "user": user},
|
||||||
|
{"$set": {"shared_publicly": shared, "shared_token": None}},
|
||||||
|
{"$unset": {"shared_metadata": ""}},
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error sharing/unsharing agent: {err}")
|
||||||
|
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||||
|
shared_token = shared_token if shared else None
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": True, "shared_token": shared_token}), 200
|
||||||
|
)
|
||||||
119
application/api/user/agents/webhooks.py
Normal file
119
application/api/user/agents/webhooks.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
"""Agent management webhook handlers."""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
from flask import current_app, jsonify, make_response, request
|
||||||
|
from flask_restx import Namespace, Resource
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
from application.api.user.base import agents_collection, require_agent
|
||||||
|
from application.api.user.tasks import process_agent_webhook
|
||||||
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
agents_webhooks_ns = Namespace(
|
||||||
|
"agents", description="Agent management operations", path="/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@agents_webhooks_ns.route("/agent_webhook")
|
||||||
|
class AgentWebhook(Resource):
|
||||||
|
@api.doc(
|
||||||
|
params={"id": "ID of the agent"},
|
||||||
|
description="Generate webhook URL for the agent",
|
||||||
|
)
|
||||||
|
def get(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
agent_id = request.args.get("id")
|
||||||
|
if not agent_id:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "ID is required"}), 400
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
agent = agents_collection.find_one(
|
||||||
|
{"_id": ObjectId(agent_id), "user": user}
|
||||||
|
)
|
||||||
|
if not agent:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||||
|
)
|
||||||
|
webhook_token = agent.get("incoming_webhook_token")
|
||||||
|
if not webhook_token:
|
||||||
|
webhook_token = secrets.token_urlsafe(32)
|
||||||
|
agents_collection.update_one(
|
||||||
|
{"_id": ObjectId(agent_id), "user": user},
|
||||||
|
{"$set": {"incoming_webhook_token": webhook_token}},
|
||||||
|
)
|
||||||
|
base_url = settings.API_URL.rstrip("/")
|
||||||
|
full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error generating webhook URL: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Error generating webhook URL"}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": True, "webhook_url": full_webhook_url}), 200
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@agents_webhooks_ns.route("/webhooks/agents/<string:webhook_token>")
|
||||||
|
class AgentWebhookListener(Resource):
|
||||||
|
method_decorators = [require_agent]
|
||||||
|
|
||||||
|
def _enqueue_webhook_task(self, agent_id_str, payload, source_method):
|
||||||
|
if not payload:
|
||||||
|
current_app.logger.warning(
|
||||||
|
f"Webhook ({source_method}) received for agent {agent_id_str} with empty payload."
|
||||||
|
)
|
||||||
|
current_app.logger.info(
|
||||||
|
f"Incoming {source_method} webhook for agent {agent_id_str}. Enqueuing task with payload: {payload}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
task = process_agent_webhook.delay(
|
||||||
|
agent_id=agent_id_str,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
current_app.logger.info(
|
||||||
|
f"Task {task.id} enqueued for agent {agent_id_str} ({source_method})."
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error enqueuing webhook task ({source_method}) for agent {agent_id_str}: {err}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Error processing webhook"}), 500
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.doc(
|
||||||
|
description="Webhook listener for agent events (POST). Expects JSON payload, which is used to trigger processing.",
|
||||||
|
)
|
||||||
|
def post(self, webhook_token, agent, agent_id_str):
|
||||||
|
payload = request.get_json()
|
||||||
|
if payload is None:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "Invalid or missing JSON data in request body",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
return self._enqueue_webhook_task(agent_id_str, payload, source_method="POST")
|
||||||
|
|
||||||
|
@api.doc(
|
||||||
|
description="Webhook listener for agent events (GET). Uses URL query parameters as payload to trigger processing.",
|
||||||
|
)
|
||||||
|
def get(self, webhook_token, agent, agent_id_str):
|
||||||
|
payload = request.args.to_dict(flat=True)
|
||||||
|
return self._enqueue_webhook_task(agent_id_str, payload, source_method="GET")
|
||||||
5
application/api/user/analytics/__init__.py
Normal file
5
application/api/user/analytics/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Analytics module."""
|
||||||
|
|
||||||
|
from .routes import analytics_ns
|
||||||
|
|
||||||
|
__all__ = ["analytics_ns"]
|
||||||
540
application/api/user/analytics/routes.py
Normal file
540
application/api/user/analytics/routes.py
Normal file
@@ -0,0 +1,540 @@
|
|||||||
|
"""Analytics and reporting routes."""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
from flask import current_app, jsonify, make_response, request
|
||||||
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
from application.api.user.base import (
|
||||||
|
agents_collection,
|
||||||
|
conversations_collection,
|
||||||
|
generate_date_range,
|
||||||
|
generate_hourly_range,
|
||||||
|
generate_minute_range,
|
||||||
|
token_usage_collection,
|
||||||
|
user_logs_collection,
|
||||||
|
)
|
||||||
|
|
||||||
|
analytics_ns = Namespace(
|
||||||
|
"analytics", description="Analytics and reporting operations", path="/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@analytics_ns.route("/get_message_analytics")
|
||||||
|
class GetMessageAnalytics(Resource):
|
||||||
|
get_message_analytics_model = api.model(
|
||||||
|
"GetMessageAnalyticsModel",
|
||||||
|
{
|
||||||
|
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||||
|
"filter_option": fields.String(
|
||||||
|
required=False,
|
||||||
|
description="Filter option for analytics",
|
||||||
|
default="last_30_days",
|
||||||
|
enum=[
|
||||||
|
"last_hour",
|
||||||
|
"last_24_hour",
|
||||||
|
"last_7_days",
|
||||||
|
"last_15_days",
|
||||||
|
"last_30_days",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(get_message_analytics_model)
|
||||||
|
@api.doc(description="Get message analytics based on filter option")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
api_key_id = data.get("api_key_id")
|
||||||
|
filter_option = data.get("filter_option", "last_30_days")
|
||||||
|
|
||||||
|
try:
|
||||||
|
api_key = (
|
||||||
|
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||||
|
"key"
|
||||||
|
]
|
||||||
|
if api_key_id
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
|
||||||
|
if filter_option == "last_hour":
|
||||||
|
start_date = end_date - datetime.timedelta(hours=1)
|
||||||
|
group_format = "%Y-%m-%d %H:%M:00"
|
||||||
|
elif filter_option == "last_24_hour":
|
||||||
|
start_date = end_date - datetime.timedelta(hours=24)
|
||||||
|
group_format = "%Y-%m-%d %H:00"
|
||||||
|
else:
|
||||||
|
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||||
|
filter_days = (
|
||||||
|
6
|
||||||
|
if filter_option == "last_7_days"
|
||||||
|
else 14 if filter_option == "last_15_days" else 29
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||||
|
)
|
||||||
|
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||||
|
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
end_date = end_date.replace(
|
||||||
|
hour=23, minute=59, second=59, microsecond=999999
|
||||||
|
)
|
||||||
|
group_format = "%Y-%m-%d"
|
||||||
|
try:
|
||||||
|
match_stage = {
|
||||||
|
"$match": {
|
||||||
|
"user": user,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if api_key:
|
||||||
|
match_stage["$match"]["api_key"] = api_key
|
||||||
|
pipeline = [
|
||||||
|
match_stage,
|
||||||
|
{"$unwind": "$queries"},
|
||||||
|
{
|
||||||
|
"$match": {
|
||||||
|
"queries.timestamp": {"$gte": start_date, "$lte": end_date}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$group": {
|
||||||
|
"_id": {
|
||||||
|
"$dateToString": {
|
||||||
|
"format": group_format,
|
||||||
|
"date": "$queries.timestamp",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"count": {"$sum": 1},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"$sort": {"_id": 1}},
|
||||||
|
]
|
||||||
|
|
||||||
|
message_data = conversations_collection.aggregate(pipeline)
|
||||||
|
|
||||||
|
if filter_option == "last_hour":
|
||||||
|
intervals = generate_minute_range(start_date, end_date)
|
||||||
|
elif filter_option == "last_24_hour":
|
||||||
|
intervals = generate_hourly_range(start_date, end_date)
|
||||||
|
else:
|
||||||
|
intervals = generate_date_range(start_date, end_date)
|
||||||
|
daily_messages = {interval: 0 for interval in intervals}
|
||||||
|
|
||||||
|
for entry in message_data:
|
||||||
|
daily_messages[entry["_id"]] = entry["count"]
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error getting message analytics: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": True, "messages": daily_messages}), 200
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@analytics_ns.route("/get_token_analytics")
|
||||||
|
class GetTokenAnalytics(Resource):
|
||||||
|
get_token_analytics_model = api.model(
|
||||||
|
"GetTokenAnalyticsModel",
|
||||||
|
{
|
||||||
|
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||||
|
"filter_option": fields.String(
|
||||||
|
required=False,
|
||||||
|
description="Filter option for analytics",
|
||||||
|
default="last_30_days",
|
||||||
|
enum=[
|
||||||
|
"last_hour",
|
||||||
|
"last_24_hour",
|
||||||
|
"last_7_days",
|
||||||
|
"last_15_days",
|
||||||
|
"last_30_days",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(get_token_analytics_model)
|
||||||
|
@api.doc(description="Get token analytics data")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
api_key_id = data.get("api_key_id")
|
||||||
|
filter_option = data.get("filter_option", "last_30_days")
|
||||||
|
|
||||||
|
try:
|
||||||
|
api_key = (
|
||||||
|
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||||
|
"key"
|
||||||
|
]
|
||||||
|
if api_key_id
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
|
||||||
|
if filter_option == "last_hour":
|
||||||
|
start_date = end_date - datetime.timedelta(hours=1)
|
||||||
|
group_format = "%Y-%m-%d %H:%M:00"
|
||||||
|
group_stage = {
|
||||||
|
"$group": {
|
||||||
|
"_id": {
|
||||||
|
"minute": {
|
||||||
|
"$dateToString": {
|
||||||
|
"format": group_format,
|
||||||
|
"date": "$timestamp",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"total_tokens": {
|
||||||
|
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
elif filter_option == "last_24_hour":
|
||||||
|
start_date = end_date - datetime.timedelta(hours=24)
|
||||||
|
group_format = "%Y-%m-%d %H:00"
|
||||||
|
group_stage = {
|
||||||
|
"$group": {
|
||||||
|
"_id": {
|
||||||
|
"hour": {
|
||||||
|
"$dateToString": {
|
||||||
|
"format": group_format,
|
||||||
|
"date": "$timestamp",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"total_tokens": {
|
||||||
|
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||||
|
filter_days = (
|
||||||
|
6
|
||||||
|
if filter_option == "last_7_days"
|
||||||
|
else (14 if filter_option == "last_15_days" else 29)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||||
|
)
|
||||||
|
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||||
|
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
end_date = end_date.replace(
|
||||||
|
hour=23, minute=59, second=59, microsecond=999999
|
||||||
|
)
|
||||||
|
group_format = "%Y-%m-%d"
|
||||||
|
group_stage = {
|
||||||
|
"$group": {
|
||||||
|
"_id": {
|
||||||
|
"day": {
|
||||||
|
"$dateToString": {
|
||||||
|
"format": group_format,
|
||||||
|
"date": "$timestamp",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"total_tokens": {
|
||||||
|
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
match_stage = {
|
||||||
|
"$match": {
|
||||||
|
"user_id": user,
|
||||||
|
"timestamp": {"$gte": start_date, "$lte": end_date},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if api_key:
|
||||||
|
match_stage["$match"]["api_key"] = api_key
|
||||||
|
token_usage_data = token_usage_collection.aggregate(
|
||||||
|
[
|
||||||
|
match_stage,
|
||||||
|
group_stage,
|
||||||
|
{"$sort": {"_id": 1}},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if filter_option == "last_hour":
|
||||||
|
intervals = generate_minute_range(start_date, end_date)
|
||||||
|
elif filter_option == "last_24_hour":
|
||||||
|
intervals = generate_hourly_range(start_date, end_date)
|
||||||
|
else:
|
||||||
|
intervals = generate_date_range(start_date, end_date)
|
||||||
|
daily_token_usage = {interval: 0 for interval in intervals}
|
||||||
|
|
||||||
|
for entry in token_usage_data:
|
||||||
|
if filter_option == "last_hour":
|
||||||
|
daily_token_usage[entry["_id"]["minute"]] = entry["total_tokens"]
|
||||||
|
elif filter_option == "last_24_hour":
|
||||||
|
daily_token_usage[entry["_id"]["hour"]] = entry["total_tokens"]
|
||||||
|
else:
|
||||||
|
daily_token_usage[entry["_id"]["day"]] = entry["total_tokens"]
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error getting token analytics: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": True, "token_usage": daily_token_usage}), 200
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@analytics_ns.route("/get_feedback_analytics")
|
||||||
|
class GetFeedbackAnalytics(Resource):
|
||||||
|
get_feedback_analytics_model = api.model(
|
||||||
|
"GetFeedbackAnalyticsModel",
|
||||||
|
{
|
||||||
|
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||||
|
"filter_option": fields.String(
|
||||||
|
required=False,
|
||||||
|
description="Filter option for analytics",
|
||||||
|
default="last_30_days",
|
||||||
|
enum=[
|
||||||
|
"last_hour",
|
||||||
|
"last_24_hour",
|
||||||
|
"last_7_days",
|
||||||
|
"last_15_days",
|
||||||
|
"last_30_days",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(get_feedback_analytics_model)
|
||||||
|
@api.doc(description="Get feedback analytics data")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
api_key_id = data.get("api_key_id")
|
||||||
|
filter_option = data.get("filter_option", "last_30_days")
|
||||||
|
|
||||||
|
try:
|
||||||
|
api_key = (
|
||||||
|
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||||
|
"key"
|
||||||
|
]
|
||||||
|
if api_key_id
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
|
||||||
|
if filter_option == "last_hour":
|
||||||
|
start_date = end_date - datetime.timedelta(hours=1)
|
||||||
|
group_format = "%Y-%m-%d %H:%M:00"
|
||||||
|
date_field = {
|
||||||
|
"$dateToString": {
|
||||||
|
"format": group_format,
|
||||||
|
"date": "$queries.feedback_timestamp",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
elif filter_option == "last_24_hour":
|
||||||
|
start_date = end_date - datetime.timedelta(hours=24)
|
||||||
|
group_format = "%Y-%m-%d %H:00"
|
||||||
|
date_field = {
|
||||||
|
"$dateToString": {
|
||||||
|
"format": group_format,
|
||||||
|
"date": "$queries.feedback_timestamp",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||||
|
filter_days = (
|
||||||
|
6
|
||||||
|
if filter_option == "last_7_days"
|
||||||
|
else (14 if filter_option == "last_15_days" else 29)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||||
|
)
|
||||||
|
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||||
|
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
end_date = end_date.replace(
|
||||||
|
hour=23, minute=59, second=59, microsecond=999999
|
||||||
|
)
|
||||||
|
group_format = "%Y-%m-%d"
|
||||||
|
date_field = {
|
||||||
|
"$dateToString": {
|
||||||
|
"format": group_format,
|
||||||
|
"date": "$queries.feedback_timestamp",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
match_stage = {
|
||||||
|
"$match": {
|
||||||
|
"queries.feedback_timestamp": {
|
||||||
|
"$gte": start_date,
|
||||||
|
"$lte": end_date,
|
||||||
|
},
|
||||||
|
"queries.feedback": {"$exists": True},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if api_key:
|
||||||
|
match_stage["$match"]["api_key"] = api_key
|
||||||
|
pipeline = [
|
||||||
|
match_stage,
|
||||||
|
{"$unwind": "$queries"},
|
||||||
|
{"$match": {"queries.feedback": {"$exists": True}}},
|
||||||
|
{
|
||||||
|
"$group": {
|
||||||
|
"_id": {"time": date_field, "feedback": "$queries.feedback"},
|
||||||
|
"count": {"$sum": 1},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$group": {
|
||||||
|
"_id": "$_id.time",
|
||||||
|
"positive": {
|
||||||
|
"$sum": {
|
||||||
|
"$cond": [
|
||||||
|
{"$eq": ["$_id.feedback", "LIKE"]},
|
||||||
|
"$count",
|
||||||
|
0,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"negative": {
|
||||||
|
"$sum": {
|
||||||
|
"$cond": [
|
||||||
|
{"$eq": ["$_id.feedback", "DISLIKE"]},
|
||||||
|
"$count",
|
||||||
|
0,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"$sort": {"_id": 1}},
|
||||||
|
]
|
||||||
|
|
||||||
|
feedback_data = conversations_collection.aggregate(pipeline)
|
||||||
|
|
||||||
|
if filter_option == "last_hour":
|
||||||
|
intervals = generate_minute_range(start_date, end_date)
|
||||||
|
elif filter_option == "last_24_hour":
|
||||||
|
intervals = generate_hourly_range(start_date, end_date)
|
||||||
|
else:
|
||||||
|
intervals = generate_date_range(start_date, end_date)
|
||||||
|
daily_feedback = {
|
||||||
|
interval: {"positive": 0, "negative": 0} for interval in intervals
|
||||||
|
}
|
||||||
|
|
||||||
|
for entry in feedback_data:
|
||||||
|
daily_feedback[entry["_id"]] = {
|
||||||
|
"positive": entry["positive"],
|
||||||
|
"negative": entry["negative"],
|
||||||
|
}
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error getting feedback analytics: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": True, "feedback": daily_feedback}), 200
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@analytics_ns.route("/get_user_logs")
|
||||||
|
class GetUserLogs(Resource):
|
||||||
|
get_user_logs_model = api.model(
|
||||||
|
"GetUserLogsModel",
|
||||||
|
{
|
||||||
|
"page": fields.Integer(
|
||||||
|
required=False,
|
||||||
|
description="Page number for pagination",
|
||||||
|
default=1,
|
||||||
|
),
|
||||||
|
"api_key_id": fields.String(required=False, description="API Key ID"),
|
||||||
|
"page_size": fields.Integer(
|
||||||
|
required=False,
|
||||||
|
description="Number of logs per page",
|
||||||
|
default=10,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(get_user_logs_model)
|
||||||
|
@api.doc(description="Get user logs with pagination")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
page = int(data.get("page", 1))
|
||||||
|
api_key_id = data.get("api_key_id")
|
||||||
|
page_size = int(data.get("page_size", 10))
|
||||||
|
skip = (page - 1) * page_size
|
||||||
|
|
||||||
|
try:
|
||||||
|
api_key = (
|
||||||
|
agents_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
|
||||||
|
if api_key_id
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
query = {"user": user}
|
||||||
|
if api_key:
|
||||||
|
query = {"api_key": api_key}
|
||||||
|
items_cursor = (
|
||||||
|
user_logs_collection.find(query)
|
||||||
|
.sort("timestamp", -1)
|
||||||
|
.skip(skip)
|
||||||
|
.limit(page_size + 1)
|
||||||
|
)
|
||||||
|
items = list(items_cursor)
|
||||||
|
|
||||||
|
results = [
|
||||||
|
{
|
||||||
|
"id": str(item.get("_id")),
|
||||||
|
"action": item.get("action"),
|
||||||
|
"level": item.get("level"),
|
||||||
|
"user": item.get("user"),
|
||||||
|
"question": item.get("question"),
|
||||||
|
"sources": item.get("sources"),
|
||||||
|
"retriever_params": item.get("retriever_params"),
|
||||||
|
"timestamp": item.get("timestamp"),
|
||||||
|
}
|
||||||
|
for item in items[:page_size]
|
||||||
|
]
|
||||||
|
|
||||||
|
has_more = len(items) > page_size
|
||||||
|
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"logs": results,
|
||||||
|
"page": page,
|
||||||
|
"page_size": page_size,
|
||||||
|
"has_more": has_more,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
200,
|
||||||
|
)
|
||||||
5
application/api/user/attachments/__init__.py
Normal file
5
application/api/user/attachments/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Attachments module."""
|
||||||
|
|
||||||
|
from .routes import attachments_ns
|
||||||
|
|
||||||
|
__all__ = ["attachments_ns"]
|
||||||
154
application/api/user/attachments/routes.py
Normal file
154
application/api/user/attachments/routes.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
"""File attachments and media routes."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
from flask import current_app, jsonify, make_response, request
|
||||||
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
from application.api.user.base import agents_collection, storage
|
||||||
|
from application.api.user.tasks import store_attachment
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.tts.tts_creator import TTSCreator
|
||||||
|
from application.utils import safe_filename
|
||||||
|
|
||||||
|
|
||||||
|
attachments_ns = Namespace(
|
||||||
|
"attachments", description="File attachments and media operations", path="/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@attachments_ns.route("/store_attachment")
|
||||||
|
class StoreAttachment(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"AttachmentModel",
|
||||||
|
{
|
||||||
|
"file": fields.Raw(required=True, description="File to upload"),
|
||||||
|
"api_key": fields.String(
|
||||||
|
required=False, description="API key (optional)"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(
|
||||||
|
description="Stores a single attachment without vectorization or training. Supports user or API key authentication."
|
||||||
|
)
|
||||||
|
def post(self):
|
||||||
|
decoded_token = getattr(request, "decoded_token", None)
|
||||||
|
api_key = request.form.get("api_key") or request.args.get("api_key")
|
||||||
|
file = request.files.get("file")
|
||||||
|
|
||||||
|
if not file or file.filename == "":
|
||||||
|
return make_response(
|
||||||
|
jsonify({"status": "error", "message": "Missing file"}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
user = None
|
||||||
|
if decoded_token:
|
||||||
|
user = safe_filename(decoded_token.get("sub"))
|
||||||
|
elif api_key:
|
||||||
|
agent = agents_collection.find_one({"key": api_key})
|
||||||
|
if not agent:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid API key"}), 401
|
||||||
|
)
|
||||||
|
user = safe_filename(agent.get("user"))
|
||||||
|
else:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Authentication required"}), 401
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
attachment_id = ObjectId()
|
||||||
|
original_filename = safe_filename(os.path.basename(file.filename))
|
||||||
|
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
|
||||||
|
|
||||||
|
metadata = storage.save_file(file, relative_path)
|
||||||
|
|
||||||
|
file_info = {
|
||||||
|
"filename": original_filename,
|
||||||
|
"attachment_id": str(attachment_id),
|
||||||
|
"path": relative_path,
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
task = store_attachment.delay(file_info, user)
|
||||||
|
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"task_id": task.id,
|
||||||
|
"message": "File uploaded successfully. Processing started.",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False, "error": str(err)}), 400)
|
||||||
|
|
||||||
|
|
||||||
|
@attachments_ns.route("/images/<path:image_path>")
|
||||||
|
class ServeImage(Resource):
|
||||||
|
@api.doc(description="Serve an image from storage")
|
||||||
|
def get(self, image_path):
|
||||||
|
try:
|
||||||
|
file_obj = storage.get_file(image_path)
|
||||||
|
extension = image_path.split(".")[-1].lower()
|
||||||
|
content_type = f"image/{extension}"
|
||||||
|
if extension == "jpg":
|
||||||
|
content_type = "image/jpeg"
|
||||||
|
response = make_response(file_obj.read())
|
||||||
|
response.headers.set("Content-Type", content_type)
|
||||||
|
response.headers.set("Cache-Control", "max-age=86400")
|
||||||
|
|
||||||
|
return response
|
||||||
|
except FileNotFoundError:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Image not found"}), 404
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error serving image: {e}")
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Error retrieving image"}), 500
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@attachments_ns.route("/tts")
|
||||||
|
class TextToSpeech(Resource):
|
||||||
|
tts_model = api.model(
|
||||||
|
"TextToSpeechModel",
|
||||||
|
{
|
||||||
|
"text": fields.String(
|
||||||
|
required=True, description="Text to be synthesized as audio"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(tts_model)
|
||||||
|
@api.doc(description="Synthesize audio speech from text")
|
||||||
|
def post(self):
|
||||||
|
from application.utils import clean_text_for_tts
|
||||||
|
|
||||||
|
data = request.get_json()
|
||||||
|
text = data["text"]
|
||||||
|
cleaned_text = clean_text_for_tts(text)
|
||||||
|
|
||||||
|
try:
|
||||||
|
tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
|
||||||
|
audio_base64, detected_language = tts_instance.text_to_speech(cleaned_text)
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"audio_base64": audio_base64,
|
||||||
|
"lang": detected_language,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error synthesizing audio: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
222
application/api/user/base.py
Normal file
222
application/api/user/base.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
"""
|
||||||
|
Shared utilities, database connections, and helper functions for user API routes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
from flask import current_app, jsonify, make_response, Response
|
||||||
|
from pymongo import ReturnDocument
|
||||||
|
from werkzeug.utils import secure_filename
|
||||||
|
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.storage.storage_creator import StorageCreator
|
||||||
|
from application.vectorstore.vector_creator import VectorCreator
|
||||||
|
|
||||||
|
|
||||||
|
storage = StorageCreator.get_storage()
|
||||||
|
|
||||||
|
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
|
|
||||||
|
|
||||||
|
conversations_collection = db["conversations"]
|
||||||
|
sources_collection = db["sources"]
|
||||||
|
prompts_collection = db["prompts"]
|
||||||
|
feedback_collection = db["feedback"]
|
||||||
|
agents_collection = db["agents"]
|
||||||
|
token_usage_collection = db["token_usage"]
|
||||||
|
shared_conversations_collections = db["shared_conversations"]
|
||||||
|
users_collection = db["users"]
|
||||||
|
user_logs_collection = db["user_logs"]
|
||||||
|
user_tools_collection = db["user_tools"]
|
||||||
|
attachments_collection = db["attachments"]
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
agents_collection.create_index(
|
||||||
|
[("shared", 1)],
|
||||||
|
name="shared_index",
|
||||||
|
background=True,
|
||||||
|
)
|
||||||
|
users_collection.create_index("user_id", unique=True)
|
||||||
|
except Exception as e:
|
||||||
|
print("Error creating indexes:", e)
|
||||||
|
current_dir = os.path.dirname(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_minute_range(start_date, end_date):
|
||||||
|
"""Generate a dictionary with minute-level time ranges."""
|
||||||
|
return {
|
||||||
|
(start_date + datetime.timedelta(minutes=i)).strftime("%Y-%m-%d %H:%M:00"): 0
|
||||||
|
for i in range(int((end_date - start_date).total_seconds() // 60) + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def generate_hourly_range(start_date, end_date):
|
||||||
|
"""Generate a dictionary with hourly time ranges."""
|
||||||
|
return {
|
||||||
|
(start_date + datetime.timedelta(hours=i)).strftime("%Y-%m-%d %H:00"): 0
|
||||||
|
for i in range(int((end_date - start_date).total_seconds() // 3600) + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def generate_date_range(start_date, end_date):
|
||||||
|
"""Generate a dictionary with daily date ranges."""
|
||||||
|
return {
|
||||||
|
(start_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d"): 0
|
||||||
|
for i in range((end_date - start_date).days + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_user_doc(user_id):
|
||||||
|
"""
|
||||||
|
Ensure user document exists with proper agent preferences structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID to ensure
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The user document
|
||||||
|
"""
|
||||||
|
default_prefs = {
|
||||||
|
"pinned": [],
|
||||||
|
"shared_with_me": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
user_doc = users_collection.find_one_and_update(
|
||||||
|
{"user_id": user_id},
|
||||||
|
{"$setOnInsert": {"agent_preferences": default_prefs}},
|
||||||
|
upsert=True,
|
||||||
|
return_document=ReturnDocument.AFTER,
|
||||||
|
)
|
||||||
|
|
||||||
|
prefs = user_doc.get("agent_preferences", {})
|
||||||
|
updates = {}
|
||||||
|
if "pinned" not in prefs:
|
||||||
|
updates["agent_preferences.pinned"] = []
|
||||||
|
if "shared_with_me" not in prefs:
|
||||||
|
updates["agent_preferences.shared_with_me"] = []
|
||||||
|
if updates:
|
||||||
|
users_collection.update_one({"user_id": user_id}, {"$set": updates})
|
||||||
|
user_doc = users_collection.find_one({"user_id": user_id})
|
||||||
|
return user_doc
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_tool_details(tool_ids):
|
||||||
|
"""
|
||||||
|
Resolve tool IDs to their details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_ids: List of tool IDs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tool details with id, name, and display_name
|
||||||
|
"""
|
||||||
|
tools = user_tools_collection.find(
|
||||||
|
{"_id": {"$in": [ObjectId(tid) for tid in tool_ids]}}
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": str(tool["_id"]),
|
||||||
|
"name": tool.get("name", ""),
|
||||||
|
"display_name": tool.get("displayName", tool.get("name", "")),
|
||||||
|
}
|
||||||
|
for tool in tools
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_vector_store(source_id):
|
||||||
|
"""
|
||||||
|
Get the Vector Store for a given source ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_id (str): source id of the document
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Vector store instance
|
||||||
|
"""
|
||||||
|
store = VectorCreator.create_vectorstore(
|
||||||
|
settings.VECTOR_STORE,
|
||||||
|
source_id=source_id,
|
||||||
|
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||||
|
)
|
||||||
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
def handle_image_upload(
|
||||||
|
request, existing_url: str, user: str, storage, base_path: str = "attachments/"
|
||||||
|
) -> Tuple[str, Optional[Response]]:
|
||||||
|
"""
|
||||||
|
Handle image file upload from request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Flask request object
|
||||||
|
existing_url: Existing image URL (fallback)
|
||||||
|
user: User ID
|
||||||
|
storage: Storage instance
|
||||||
|
base_path: Base path for upload
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (image_url, error_response)
|
||||||
|
"""
|
||||||
|
image_url = existing_url
|
||||||
|
|
||||||
|
if "image" in request.files:
|
||||||
|
file = request.files["image"]
|
||||||
|
if file.filename != "":
|
||||||
|
filename = secure_filename(file.filename)
|
||||||
|
upload_path = f"{settings.UPLOAD_FOLDER.rstrip('/')}/{user}/{base_path.rstrip('/')}/{uuid.uuid4()}_{filename}"
|
||||||
|
try:
|
||||||
|
storage.save_file(file, upload_path, storage_class="STANDARD")
|
||||||
|
image_url = upload_path
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error uploading image: {e}")
|
||||||
|
return None, make_response(
|
||||||
|
jsonify({"success": False, "message": "Image upload failed"}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
return image_url, None
|
||||||
|
|
||||||
|
|
||||||
|
def require_agent(func):
|
||||||
|
"""
|
||||||
|
Decorator to require valid agent webhook token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: Function to decorate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Wrapped function
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
webhook_token = kwargs.get("webhook_token")
|
||||||
|
if not webhook_token:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Webhook token missing"}), 400
|
||||||
|
)
|
||||||
|
agent = agents_collection.find_one(
|
||||||
|
{"incoming_webhook_token": webhook_token}, {"_id": 1}
|
||||||
|
)
|
||||||
|
if not agent:
|
||||||
|
current_app.logger.warning(
|
||||||
|
f"Webhook attempt with invalid token: {webhook_token}"
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||||
|
)
|
||||||
|
kwargs["agent"] = agent
|
||||||
|
kwargs["agent_id_str"] = str(agent["_id"])
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
5
application/api/user/conversations/__init__.py
Normal file
5
application/api/user/conversations/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Conversation management module."""
|
||||||
|
|
||||||
|
from .routes import conversations_ns
|
||||||
|
|
||||||
|
__all__ = ["conversations_ns"]
|
||||||
280
application/api/user/conversations/routes.py
Normal file
280
application/api/user/conversations/routes.py
Normal file
@@ -0,0 +1,280 @@
|
|||||||
|
"""Conversation management routes."""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
from flask import current_app, jsonify, make_response, request
|
||||||
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
from application.api.user.base import attachments_collection, conversations_collection
|
||||||
|
from application.utils import check_required_fields
|
||||||
|
|
||||||
|
conversations_ns = Namespace(
|
||||||
|
"conversations", description="Conversation management operations", path="/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@conversations_ns.route("/delete_conversation")
|
||||||
|
class DeleteConversation(Resource):
|
||||||
|
@api.doc(
|
||||||
|
description="Deletes a conversation by ID",
|
||||||
|
params={"id": "The ID of the conversation to delete"},
|
||||||
|
)
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
conversation_id = request.args.get("id")
|
||||||
|
if not conversation_id:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "ID is required"}), 400
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
conversations_collection.delete_one(
|
||||||
|
{"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error deleting conversation: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@conversations_ns.route("/delete_all_conversations")
|
||||||
|
class DeleteAllConversations(Resource):
|
||||||
|
@api.doc(
|
||||||
|
description="Deletes all conversations for a specific user",
|
||||||
|
)
|
||||||
|
def get(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user_id = decoded_token.get("sub")
|
||||||
|
try:
|
||||||
|
conversations_collection.delete_many({"user": user_id})
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error deleting all conversations: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@conversations_ns.route("/get_conversations")
|
||||||
|
class GetConversations(Resource):
|
||||||
|
@api.doc(
|
||||||
|
description="Retrieve a list of the latest 30 conversations (excluding API key conversations)",
|
||||||
|
)
|
||||||
|
def get(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
try:
|
||||||
|
conversations = (
|
||||||
|
conversations_collection.find(
|
||||||
|
{
|
||||||
|
"$or": [
|
||||||
|
{"api_key": {"$exists": False}},
|
||||||
|
{"agent_id": {"$exists": True}},
|
||||||
|
],
|
||||||
|
"user": decoded_token.get("sub"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.sort("date", -1)
|
||||||
|
.limit(30)
|
||||||
|
)
|
||||||
|
|
||||||
|
list_conversations = [
|
||||||
|
{
|
||||||
|
"id": str(conversation["_id"]),
|
||||||
|
"name": conversation["name"],
|
||||||
|
"agent_id": conversation.get("agent_id", None),
|
||||||
|
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||||
|
"shared_token": conversation.get("shared_token", None),
|
||||||
|
}
|
||||||
|
for conversation in conversations
|
||||||
|
]
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error retrieving conversations: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify(list_conversations), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@conversations_ns.route("/get_single_conversation")
|
||||||
|
class GetSingleConversation(Resource):
|
||||||
|
@api.doc(
|
||||||
|
description="Retrieve a single conversation by ID",
|
||||||
|
params={"id": "The conversation ID"},
|
||||||
|
)
|
||||||
|
def get(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
conversation_id = request.args.get("id")
|
||||||
|
if not conversation_id:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "ID is required"}), 400
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
conversation = conversations_collection.find_one(
|
||||||
|
{"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
|
||||||
|
)
|
||||||
|
if not conversation:
|
||||||
|
return make_response(jsonify({"status": "not found"}), 404)
|
||||||
|
# Process queries to include attachment names
|
||||||
|
|
||||||
|
queries = conversation["queries"]
|
||||||
|
for query in queries:
|
||||||
|
if "attachments" in query and query["attachments"]:
|
||||||
|
attachment_details = []
|
||||||
|
for attachment_id in query["attachments"]:
|
||||||
|
try:
|
||||||
|
attachment = attachments_collection.find_one(
|
||||||
|
{"_id": ObjectId(attachment_id)}
|
||||||
|
)
|
||||||
|
if attachment:
|
||||||
|
attachment_details.append(
|
||||||
|
{
|
||||||
|
"id": str(attachment["_id"]),
|
||||||
|
"fileName": attachment.get(
|
||||||
|
"filename", "Unknown file"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error retrieving attachment {attachment_id}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
query["attachments"] = attachment_details
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error retrieving conversation: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
data = {
|
||||||
|
"queries": queries,
|
||||||
|
"agent_id": conversation.get("agent_id"),
|
||||||
|
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||||
|
"shared_token": conversation.get("shared_token", None),
|
||||||
|
}
|
||||||
|
return make_response(jsonify(data), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@conversations_ns.route("/update_conversation_name")
|
||||||
|
class UpdateConversationName(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"UpdateConversationModel",
|
||||||
|
{
|
||||||
|
"id": fields.String(required=True, description="Conversation ID"),
|
||||||
|
"name": fields.String(
|
||||||
|
required=True, description="New name of the conversation"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(
|
||||||
|
description="Updates the name of a conversation",
|
||||||
|
)
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = ["id", "name"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
try:
|
||||||
|
conversations_collection.update_one(
|
||||||
|
{"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
|
||||||
|
{"$set": {"name": data["name"]}},
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error updating conversation name: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@conversations_ns.route("/feedback")
|
||||||
|
class SubmitFeedback(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"FeedbackModel",
|
||||||
|
{
|
||||||
|
"question": fields.String(
|
||||||
|
required=False, description="The user question"
|
||||||
|
),
|
||||||
|
"answer": fields.String(required=False, description="The AI answer"),
|
||||||
|
"feedback": fields.String(required=True, description="User feedback"),
|
||||||
|
"question_index": fields.Integer(
|
||||||
|
required=True,
|
||||||
|
description="The question number in that particular conversation",
|
||||||
|
),
|
||||||
|
"conversation_id": fields.String(
|
||||||
|
required=True, description="id of the particular conversation"
|
||||||
|
),
|
||||||
|
"api_key": fields.String(description="Optional API key"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(
|
||||||
|
description="Submit feedback for a conversation",
|
||||||
|
)
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = ["feedback", "conversation_id", "question_index"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
try:
|
||||||
|
if data["feedback"] is None:
|
||||||
|
# Remove feedback and feedback_timestamp if feedback is null
|
||||||
|
|
||||||
|
conversations_collection.update_one(
|
||||||
|
{
|
||||||
|
"_id": ObjectId(data["conversation_id"]),
|
||||||
|
"user": decoded_token.get("sub"),
|
||||||
|
f"queries.{data['question_index']}": {"$exists": True},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$unset": {
|
||||||
|
f"queries.{data['question_index']}.feedback": "",
|
||||||
|
f"queries.{data['question_index']}.feedback_timestamp": "",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Set feedback and feedback_timestamp if feedback has a value
|
||||||
|
|
||||||
|
conversations_collection.update_one(
|
||||||
|
{
|
||||||
|
"_id": ObjectId(data["conversation_id"]),
|
||||||
|
"user": decoded_token.get("sub"),
|
||||||
|
f"queries.{data['question_index']}": {"$exists": True},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$set": {
|
||||||
|
f"queries.{data['question_index']}.feedback": data[
|
||||||
|
"feedback"
|
||||||
|
],
|
||||||
|
f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
|
||||||
|
datetime.timezone.utc
|
||||||
|
),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
5
application/api/user/prompts/__init__.py
Normal file
5
application/api/user/prompts/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Prompts module."""
|
||||||
|
|
||||||
|
from .routes import prompts_ns
|
||||||
|
|
||||||
|
__all__ = ["prompts_ns"]
|
||||||
191
application/api/user/prompts/routes.py
Normal file
191
application/api/user/prompts/routes.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
"""Prompt management routes."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
from flask import current_app, jsonify, make_response, request
|
||||||
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
from application.api.user.base import current_dir, prompts_collection
|
||||||
|
from application.utils import check_required_fields
|
||||||
|
|
||||||
|
prompts_ns = Namespace(
|
||||||
|
"prompts", description="Prompt management operations", path="/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@prompts_ns.route("/create_prompt")
|
||||||
|
class CreatePrompt(Resource):
|
||||||
|
create_prompt_model = api.model(
|
||||||
|
"CreatePromptModel",
|
||||||
|
{
|
||||||
|
"content": fields.String(
|
||||||
|
required=True, description="Content of the prompt"
|
||||||
|
),
|
||||||
|
"name": fields.String(required=True, description="Name of the prompt"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(create_prompt_model)
|
||||||
|
@api.doc(description="Create a new prompt")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = ["content", "name"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
try:
|
||||||
|
|
||||||
|
resp = prompts_collection.insert_one(
|
||||||
|
{
|
||||||
|
"name": data["name"],
|
||||||
|
"content": data["content"],
|
||||||
|
"user": user,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
new_id = str(resp.inserted_id)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"id": new_id}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@prompts_ns.route("/get_prompts")
|
||||||
|
class GetPrompts(Resource):
|
||||||
|
@api.doc(description="Get all prompts for the user")
|
||||||
|
def get(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
try:
|
||||||
|
prompts = prompts_collection.find({"user": user})
|
||||||
|
list_prompts = [
|
||||||
|
{"id": "default", "name": "default", "type": "public"},
|
||||||
|
{"id": "creative", "name": "creative", "type": "public"},
|
||||||
|
{"id": "strict", "name": "strict", "type": "public"},
|
||||||
|
]
|
||||||
|
|
||||||
|
for prompt in prompts:
|
||||||
|
list_prompts.append(
|
||||||
|
{
|
||||||
|
"id": str(prompt["_id"]),
|
||||||
|
"name": prompt["name"],
|
||||||
|
"type": "private",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error retrieving prompts: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify(list_prompts), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@prompts_ns.route("/get_single_prompt")
|
||||||
|
class GetSinglePrompt(Resource):
|
||||||
|
@api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID")
|
||||||
|
def get(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
prompt_id = request.args.get("id")
|
||||||
|
if not prompt_id:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "ID is required"}), 400
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
if prompt_id == "default":
|
||||||
|
with open(
|
||||||
|
os.path.join(current_dir, "prompts", "chat_combine_default.txt"),
|
||||||
|
"r",
|
||||||
|
) as f:
|
||||||
|
chat_combine_template = f.read()
|
||||||
|
return make_response(jsonify({"content": chat_combine_template}), 200)
|
||||||
|
elif prompt_id == "creative":
|
||||||
|
with open(
|
||||||
|
os.path.join(current_dir, "prompts", "chat_combine_creative.txt"),
|
||||||
|
"r",
|
||||||
|
) as f:
|
||||||
|
chat_reduce_creative = f.read()
|
||||||
|
return make_response(jsonify({"content": chat_reduce_creative}), 200)
|
||||||
|
elif prompt_id == "strict":
|
||||||
|
with open(
|
||||||
|
os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r"
|
||||||
|
) as f:
|
||||||
|
chat_reduce_strict = f.read()
|
||||||
|
return make_response(jsonify({"content": chat_reduce_strict}), 200)
|
||||||
|
prompt = prompts_collection.find_one(
|
||||||
|
{"_id": ObjectId(prompt_id), "user": user}
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error retrieving prompt: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"content": prompt["content"]}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@prompts_ns.route("/delete_prompt")
|
||||||
|
class DeletePrompt(Resource):
|
||||||
|
delete_prompt_model = api.model(
|
||||||
|
"DeletePromptModel",
|
||||||
|
{"id": fields.String(required=True, description="Prompt ID to delete")},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(delete_prompt_model)
|
||||||
|
@api.doc(description="Delete a prompt by ID")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = ["id"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
try:
|
||||||
|
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@prompts_ns.route("/update_prompt")
|
||||||
|
class UpdatePrompt(Resource):
|
||||||
|
update_prompt_model = api.model(
|
||||||
|
"UpdatePromptModel",
|
||||||
|
{
|
||||||
|
"id": fields.String(required=True, description="Prompt ID to update"),
|
||||||
|
"name": fields.String(required=True, description="New name of the prompt"),
|
||||||
|
"content": fields.String(
|
||||||
|
required=True, description="New content of the prompt"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(update_prompt_model)
|
||||||
|
@api.doc(description="Update an existing prompt")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = ["id", "name", "content"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
try:
|
||||||
|
prompts_collection.update_one(
|
||||||
|
{"_id": ObjectId(data["id"]), "user": user},
|
||||||
|
{"$set": {"name": data["name"], "content": data["content"]}},
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
File diff suppressed because it is too large
Load Diff
5
application/api/user/sharing/__init__.py
Normal file
5
application/api/user/sharing/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Sharing module."""
|
||||||
|
|
||||||
|
from .routes import sharing_ns
|
||||||
|
|
||||||
|
__all__ = ["sharing_ns"]
|
||||||
301
application/api/user/sharing/routes.py
Normal file
301
application/api/user/sharing/routes.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
"""Conversation sharing routes."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from bson.binary import Binary, UuidRepresentation
|
||||||
|
from bson.dbref import DBRef
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
from flask import current_app, jsonify, make_response, request
|
||||||
|
from flask_restx import fields, inputs, Namespace, Resource
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
from application.api.user.base import (
|
||||||
|
agents_collection,
|
||||||
|
attachments_collection,
|
||||||
|
conversations_collection,
|
||||||
|
db,
|
||||||
|
shared_conversations_collections,
|
||||||
|
)
|
||||||
|
from application.utils import check_required_fields
|
||||||
|
|
||||||
|
sharing_ns = Namespace(
|
||||||
|
"sharing", description="Conversation sharing operations", path="/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@sharing_ns.route("/share")
|
||||||
|
class ShareConversation(Resource):
|
||||||
|
share_conversation_model = api.model(
|
||||||
|
"ShareConversationModel",
|
||||||
|
{
|
||||||
|
"conversation_id": fields.String(
|
||||||
|
required=True, description="Conversation ID"
|
||||||
|
),
|
||||||
|
"user": fields.String(description="User ID (optional)"),
|
||||||
|
"prompt_id": fields.String(description="Prompt ID (optional)"),
|
||||||
|
"chunks": fields.Integer(description="Chunks count (optional)"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(share_conversation_model)
|
||||||
|
@api.doc(description="Share a conversation")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = ["conversation_id"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
is_promptable = request.args.get("isPromptable", type=inputs.boolean)
|
||||||
|
if is_promptable is None:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "isPromptable is required"}), 400
|
||||||
|
)
|
||||||
|
conversation_id = data["conversation_id"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
conversation = conversations_collection.find_one(
|
||||||
|
{"_id": ObjectId(conversation_id)}
|
||||||
|
)
|
||||||
|
if conversation is None:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"status": "error",
|
||||||
|
"message": "Conversation does not exist",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
current_n_queries = len(conversation["queries"])
|
||||||
|
explicit_binary = Binary.from_uuid(
|
||||||
|
uuid.uuid4(), UuidRepresentation.STANDARD
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_promptable:
|
||||||
|
prompt_id = data.get("prompt_id", "default")
|
||||||
|
chunks = data.get("chunks", "2")
|
||||||
|
|
||||||
|
name = conversation["name"] + "(shared)"
|
||||||
|
new_api_key_data = {
|
||||||
|
"prompt_id": prompt_id,
|
||||||
|
"chunks": chunks,
|
||||||
|
"user": user,
|
||||||
|
}
|
||||||
|
|
||||||
|
if "source" in data and ObjectId.is_valid(data["source"]):
|
||||||
|
new_api_key_data["source"] = DBRef(
|
||||||
|
"sources", ObjectId(data["source"])
|
||||||
|
)
|
||||||
|
if "retriever" in data:
|
||||||
|
new_api_key_data["retriever"] = data["retriever"]
|
||||||
|
pre_existing_api_document = agents_collection.find_one(new_api_key_data)
|
||||||
|
if pre_existing_api_document:
|
||||||
|
api_uuid = pre_existing_api_document["key"]
|
||||||
|
pre_existing = shared_conversations_collections.find_one(
|
||||||
|
{
|
||||||
|
"conversation_id": DBRef(
|
||||||
|
"conversations", ObjectId(conversation_id)
|
||||||
|
),
|
||||||
|
"isPromptable": is_promptable,
|
||||||
|
"first_n_queries": current_n_queries,
|
||||||
|
"user": user,
|
||||||
|
"api_key": api_uuid,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if pre_existing is not None:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"identifier": str(pre_existing["uuid"].as_uuid()),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
shared_conversations_collections.insert_one(
|
||||||
|
{
|
||||||
|
"uuid": explicit_binary,
|
||||||
|
"conversation_id": {
|
||||||
|
"$ref": "conversations",
|
||||||
|
"$id": ObjectId(conversation_id),
|
||||||
|
},
|
||||||
|
"isPromptable": is_promptable,
|
||||||
|
"first_n_queries": current_n_queries,
|
||||||
|
"user": user,
|
||||||
|
"api_key": api_uuid,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"identifier": str(explicit_binary.as_uuid()),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
201,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
api_uuid = str(uuid.uuid4())
|
||||||
|
new_api_key_data["key"] = api_uuid
|
||||||
|
new_api_key_data["name"] = name
|
||||||
|
|
||||||
|
if "source" in data and ObjectId.is_valid(data["source"]):
|
||||||
|
new_api_key_data["source"] = DBRef(
|
||||||
|
"sources", ObjectId(data["source"])
|
||||||
|
)
|
||||||
|
if "retriever" in data:
|
||||||
|
new_api_key_data["retriever"] = data["retriever"]
|
||||||
|
agents_collection.insert_one(new_api_key_data)
|
||||||
|
shared_conversations_collections.insert_one(
|
||||||
|
{
|
||||||
|
"uuid": explicit_binary,
|
||||||
|
"conversation_id": {
|
||||||
|
"$ref": "conversations",
|
||||||
|
"$id": ObjectId(conversation_id),
|
||||||
|
},
|
||||||
|
"isPromptable": is_promptable,
|
||||||
|
"first_n_queries": current_n_queries,
|
||||||
|
"user": user,
|
||||||
|
"api_key": api_uuid,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"identifier": str(explicit_binary.as_uuid()),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
201,
|
||||||
|
)
|
||||||
|
pre_existing = shared_conversations_collections.find_one(
|
||||||
|
{
|
||||||
|
"conversation_id": DBRef(
|
||||||
|
"conversations", ObjectId(conversation_id)
|
||||||
|
),
|
||||||
|
"isPromptable": is_promptable,
|
||||||
|
"first_n_queries": current_n_queries,
|
||||||
|
"user": user,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if pre_existing is not None:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"identifier": str(pre_existing["uuid"].as_uuid()),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
shared_conversations_collections.insert_one(
|
||||||
|
{
|
||||||
|
"uuid": explicit_binary,
|
||||||
|
"conversation_id": {
|
||||||
|
"$ref": "conversations",
|
||||||
|
"$id": ObjectId(conversation_id),
|
||||||
|
},
|
||||||
|
"isPromptable": is_promptable,
|
||||||
|
"first_n_queries": current_n_queries,
|
||||||
|
"user": user,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{"success": True, "identifier": str(explicit_binary.as_uuid())}
|
||||||
|
),
|
||||||
|
201,
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error sharing conversation: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
|
||||||
|
|
||||||
|
@sharing_ns.route("/shared_conversation/<string:identifier>")
|
||||||
|
class GetPubliclySharedConversations(Resource):
|
||||||
|
@api.doc(description="Get publicly shared conversations by identifier")
|
||||||
|
def get(self, identifier: str):
|
||||||
|
try:
|
||||||
|
query_uuid = Binary.from_uuid(
|
||||||
|
uuid.UUID(identifier), UuidRepresentation.STANDARD
|
||||||
|
)
|
||||||
|
shared = shared_conversations_collections.find_one({"uuid": query_uuid})
|
||||||
|
conversation_queries = []
|
||||||
|
|
||||||
|
if (
|
||||||
|
shared
|
||||||
|
and "conversation_id" in shared
|
||||||
|
and isinstance(shared["conversation_id"], DBRef)
|
||||||
|
):
|
||||||
|
conversation_ref = shared["conversation_id"]
|
||||||
|
conversation = db.dereference(conversation_ref)
|
||||||
|
if conversation is None:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": "might have broken url or the conversation does not exist",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
conversation_queries = conversation["queries"][
|
||||||
|
: (shared["first_n_queries"])
|
||||||
|
]
|
||||||
|
|
||||||
|
for query in conversation_queries:
|
||||||
|
if "attachments" in query and query["attachments"]:
|
||||||
|
attachment_details = []
|
||||||
|
for attachment_id in query["attachments"]:
|
||||||
|
try:
|
||||||
|
attachment = attachments_collection.find_one(
|
||||||
|
{"_id": ObjectId(attachment_id)}
|
||||||
|
)
|
||||||
|
if attachment:
|
||||||
|
attachment_details.append(
|
||||||
|
{
|
||||||
|
"id": str(attachment["_id"]),
|
||||||
|
"fileName": attachment.get(
|
||||||
|
"filename", "Unknown file"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error retrieving attachment {attachment_id}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
query["attachments"] = attachment_details
|
||||||
|
else:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": "might have broken url or the conversation does not exist",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
date = conversation["_id"].generation_time.isoformat()
|
||||||
|
res = {
|
||||||
|
"success": True,
|
||||||
|
"queries": conversation_queries,
|
||||||
|
"title": conversation["name"],
|
||||||
|
"timestamp": date,
|
||||||
|
}
|
||||||
|
if shared["isPromptable"] and "api_key" in shared:
|
||||||
|
res["api_key"] = shared["api_key"]
|
||||||
|
return make_response(jsonify(res), 200)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error getting shared conversation: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
7
application/api/user/sources/__init__.py
Normal file
7
application/api/user/sources/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""Sources module."""
|
||||||
|
|
||||||
|
from .chunks import sources_chunks_ns
|
||||||
|
from .routes import sources_ns
|
||||||
|
from .upload import sources_upload_ns
|
||||||
|
|
||||||
|
__all__ = ["sources_ns", "sources_chunks_ns", "sources_upload_ns"]
|
||||||
278
application/api/user/sources/chunks.py
Normal file
278
application/api/user/sources/chunks.py
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
"""Source document management chunk management."""
|
||||||
|
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
from flask import current_app, jsonify, make_response, request
|
||||||
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
from application.api.user.base import get_vector_store, sources_collection
|
||||||
|
from application.utils import check_required_fields, num_tokens_from_string
|
||||||
|
|
||||||
|
sources_chunks_ns = Namespace(
|
||||||
|
"sources", description="Source document management operations", path="/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_chunks_ns.route("/get_chunks")
|
||||||
|
class GetChunks(Resource):
|
||||||
|
@api.doc(
|
||||||
|
description="Retrieves chunks from a document, optionally filtered by file path and search term",
|
||||||
|
params={
|
||||||
|
"id": "The document ID",
|
||||||
|
"page": "Page number for pagination",
|
||||||
|
"per_page": "Number of chunks per page",
|
||||||
|
"path": "Optional: Filter chunks by relative file path",
|
||||||
|
"search": "Optional: Search term to filter chunks by title or content",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def get(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
doc_id = request.args.get("id")
|
||||||
|
page = int(request.args.get("page", 1))
|
||||||
|
per_page = int(request.args.get("per_page", 10))
|
||||||
|
path = request.args.get("path")
|
||||||
|
search_term = request.args.get("search", "").strip().lower()
|
||||||
|
|
||||||
|
if not ObjectId.is_valid(doc_id):
|
||||||
|
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||||
|
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||||
|
if not doc:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": "Document not found or access denied"}), 404
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
store = get_vector_store(doc_id)
|
||||||
|
chunks = store.get_chunks()
|
||||||
|
|
||||||
|
filtered_chunks = []
|
||||||
|
for chunk in chunks:
|
||||||
|
metadata = chunk.get("metadata", {})
|
||||||
|
|
||||||
|
# Filter by path if provided
|
||||||
|
|
||||||
|
if path:
|
||||||
|
chunk_source = metadata.get("source", "")
|
||||||
|
# Check if the chunk's source matches the requested path
|
||||||
|
|
||||||
|
if not chunk_source or not chunk_source.endswith(path):
|
||||||
|
continue
|
||||||
|
# Filter by search term if provided
|
||||||
|
|
||||||
|
if search_term:
|
||||||
|
text_match = search_term in chunk.get("text", "").lower()
|
||||||
|
title_match = search_term in metadata.get("title", "").lower()
|
||||||
|
|
||||||
|
if not (text_match or title_match):
|
||||||
|
continue
|
||||||
|
filtered_chunks.append(chunk)
|
||||||
|
chunks = filtered_chunks
|
||||||
|
|
||||||
|
total_chunks = len(chunks)
|
||||||
|
start = (page - 1) * per_page
|
||||||
|
end = start + per_page
|
||||||
|
paginated_chunks = chunks[start:end]
|
||||||
|
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"page": page,
|
||||||
|
"per_page": per_page,
|
||||||
|
"total": total_chunks,
|
||||||
|
"chunks": paginated_chunks,
|
||||||
|
"path": path if path else None,
|
||||||
|
"search": search_term if search_term else None,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error getting chunks: {e}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_chunks_ns.route("/add_chunk")
|
||||||
|
class AddChunk(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"AddChunkModel",
|
||||||
|
{
|
||||||
|
"id": fields.String(required=True, description="Document ID"),
|
||||||
|
"text": fields.String(required=True, description="Text of the chunk"),
|
||||||
|
"metadata": fields.Raw(
|
||||||
|
required=False,
|
||||||
|
description="Metadata associated with the chunk",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(
|
||||||
|
description="Adds a new chunk to the document",
|
||||||
|
)
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = ["id", "text"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
doc_id = data.get("id")
|
||||||
|
text = data.get("text")
|
||||||
|
metadata = data.get("metadata", {})
|
||||||
|
token_count = num_tokens_from_string(text)
|
||||||
|
metadata["token_count"] = token_count
|
||||||
|
|
||||||
|
if not ObjectId.is_valid(doc_id):
|
||||||
|
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||||
|
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||||
|
if not doc:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": "Document not found or access denied"}), 404
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
store = get_vector_store(doc_id)
|
||||||
|
chunk_id = store.add_chunk(text, metadata)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"message": "Chunk added successfully", "chunk_id": chunk_id}),
|
||||||
|
201,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error adding chunk: {e}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_chunks_ns.route("/delete_chunk")
|
||||||
|
class DeleteChunk(Resource):
|
||||||
|
@api.doc(
|
||||||
|
description="Deletes a specific chunk from the document.",
|
||||||
|
params={"id": "The document ID", "chunk_id": "The ID of the chunk to delete"},
|
||||||
|
)
|
||||||
|
def delete(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
doc_id = request.args.get("id")
|
||||||
|
chunk_id = request.args.get("chunk_id")
|
||||||
|
|
||||||
|
if not ObjectId.is_valid(doc_id):
|
||||||
|
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||||
|
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||||
|
if not doc:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": "Document not found or access denied"}), 404
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
store = get_vector_store(doc_id)
|
||||||
|
deleted = store.delete_chunk(chunk_id)
|
||||||
|
if deleted:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"message": "Chunk deleted successfully"}), 200
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"message": "Chunk not found or could not be deleted"}),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error deleting chunk: {e}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 500)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_chunks_ns.route("/update_chunk")
|
||||||
|
class UpdateChunk(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"UpdateChunkModel",
|
||||||
|
{
|
||||||
|
"id": fields.String(required=True, description="Document ID"),
|
||||||
|
"chunk_id": fields.String(
|
||||||
|
required=True, description="Chunk ID to update"
|
||||||
|
),
|
||||||
|
"text": fields.String(
|
||||||
|
required=False, description="New text of the chunk"
|
||||||
|
),
|
||||||
|
"metadata": fields.Raw(
|
||||||
|
required=False,
|
||||||
|
description="Updated metadata associated with the chunk",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(
|
||||||
|
description="Updates an existing chunk in the document.",
|
||||||
|
)
|
||||||
|
def put(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = ["id", "chunk_id"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
doc_id = data.get("id")
|
||||||
|
chunk_id = data.get("chunk_id")
|
||||||
|
text = data.get("text")
|
||||||
|
metadata = data.get("metadata")
|
||||||
|
|
||||||
|
if text is not None:
|
||||||
|
token_count = num_tokens_from_string(text)
|
||||||
|
if metadata is None:
|
||||||
|
metadata = {}
|
||||||
|
metadata["token_count"] = token_count
|
||||||
|
if not ObjectId.is_valid(doc_id):
|
||||||
|
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||||
|
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||||
|
if not doc:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": "Document not found or access denied"}), 404
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
store = get_vector_store(doc_id)
|
||||||
|
|
||||||
|
chunks = store.get_chunks()
|
||||||
|
existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)
|
||||||
|
if not existing_chunk:
|
||||||
|
return make_response(jsonify({"error": "Chunk not found"}), 404)
|
||||||
|
new_text = text if text is not None else existing_chunk["text"]
|
||||||
|
|
||||||
|
if metadata is not None:
|
||||||
|
new_metadata = existing_chunk["metadata"].copy()
|
||||||
|
new_metadata.update(metadata)
|
||||||
|
else:
|
||||||
|
new_metadata = existing_chunk["metadata"].copy()
|
||||||
|
if text is not None:
|
||||||
|
new_metadata["token_count"] = num_tokens_from_string(new_text)
|
||||||
|
try:
|
||||||
|
new_chunk_id = store.add_chunk(new_text, new_metadata)
|
||||||
|
|
||||||
|
deleted = store.delete_chunk(chunk_id)
|
||||||
|
if not deleted:
|
||||||
|
current_app.logger.warning(
|
||||||
|
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"message": "Chunk updated successfully",
|
||||||
|
"chunk_id": new_chunk_id,
|
||||||
|
"original_chunk_id": chunk_id,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
except Exception as add_error:
|
||||||
|
current_app.logger.error(f"Failed to add updated chunk: {add_error}")
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": "Failed to update chunk - addition failed"}), 500
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error updating chunk: {e}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 500)
|
||||||
323
application/api/user/sources/routes.py
Normal file
323
application/api/user/sources/routes.py
Normal file
@@ -0,0 +1,323 @@
|
|||||||
|
"""Source document management routes."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
from flask import current_app, jsonify, make_response, redirect, request
|
||||||
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
from application.api.user.base import sources_collection
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.storage.storage_creator import StorageCreator
|
||||||
|
from application.utils import check_required_fields
|
||||||
|
from application.vectorstore.vector_creator import VectorCreator
|
||||||
|
|
||||||
|
|
||||||
|
sources_ns = Namespace(
|
||||||
|
"sources", description="Source document management operations", path="/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_ns.route("/sources")
|
||||||
|
class CombinedJson(Resource):
|
||||||
|
@api.doc(description="Provide JSON file with combined available indexes")
|
||||||
|
def get(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = [
|
||||||
|
{
|
||||||
|
"name": "Default",
|
||||||
|
"date": "default",
|
||||||
|
"model": settings.EMBEDDINGS_NAME,
|
||||||
|
"location": "remote",
|
||||||
|
"tokens": "",
|
||||||
|
"retriever": "classic",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
for index in sources_collection.find({"user": user}).sort("date", -1):
|
||||||
|
data.append(
|
||||||
|
{
|
||||||
|
"id": str(index["_id"]),
|
||||||
|
"name": index.get("name"),
|
||||||
|
"date": index.get("date"),
|
||||||
|
"model": settings.EMBEDDINGS_NAME,
|
||||||
|
"location": "local",
|
||||||
|
"tokens": index.get("tokens", ""),
|
||||||
|
"retriever": index.get("retriever", "classic"),
|
||||||
|
"syncFrequency": index.get("sync_frequency", ""),
|
||||||
|
"is_nested": bool(index.get("directory_structure")),
|
||||||
|
"type": index.get(
|
||||||
|
"type", "file"
|
||||||
|
), # Add type field with default "file"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error retrieving sources: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify(data), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_ns.route("/sources/paginated")
|
||||||
|
class PaginatedSources(Resource):
|
||||||
|
@api.doc(description="Get document with pagination, sorting and filtering")
|
||||||
|
def get(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
sort_field = request.args.get("sort", "date") # Default to 'date'
|
||||||
|
sort_order = request.args.get("order", "desc") # Default to 'desc'
|
||||||
|
page = int(request.args.get("page", 1)) # Default to 1
|
||||||
|
rows_per_page = int(request.args.get("rows", 10)) # Default to 10
|
||||||
|
# add .strip() to remove leading and trailing whitespaces
|
||||||
|
|
||||||
|
search_term = request.args.get(
|
||||||
|
"search", ""
|
||||||
|
).strip() # add search for filter documents
|
||||||
|
|
||||||
|
# Prepare query for filtering
|
||||||
|
|
||||||
|
query = {"user": user}
|
||||||
|
if search_term:
|
||||||
|
query["name"] = {
|
||||||
|
"$regex": search_term,
|
||||||
|
"$options": "i", # using case-insensitive search
|
||||||
|
}
|
||||||
|
total_documents = sources_collection.count_documents(query)
|
||||||
|
total_pages = max(1, math.ceil(total_documents / rows_per_page))
|
||||||
|
page = min(
|
||||||
|
max(1, page), total_pages
|
||||||
|
) # add this to make sure page inbound is within the range
|
||||||
|
sort_order = 1 if sort_order == "asc" else -1
|
||||||
|
skip = (page - 1) * rows_per_page
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents = (
|
||||||
|
sources_collection.find(query)
|
||||||
|
.sort(sort_field, sort_order)
|
||||||
|
.skip(skip)
|
||||||
|
.limit(rows_per_page)
|
||||||
|
)
|
||||||
|
|
||||||
|
paginated_docs = []
|
||||||
|
for doc in documents:
|
||||||
|
doc_data = {
|
||||||
|
"id": str(doc["_id"]),
|
||||||
|
"name": doc.get("name", ""),
|
||||||
|
"date": doc.get("date", ""),
|
||||||
|
"model": settings.EMBEDDINGS_NAME,
|
||||||
|
"location": "local",
|
||||||
|
"tokens": doc.get("tokens", ""),
|
||||||
|
"retriever": doc.get("retriever", "classic"),
|
||||||
|
"syncFrequency": doc.get("sync_frequency", ""),
|
||||||
|
"isNested": bool(doc.get("directory_structure")),
|
||||||
|
"type": doc.get("type", "file"),
|
||||||
|
}
|
||||||
|
paginated_docs.append(doc_data)
|
||||||
|
response = {
|
||||||
|
"total": total_documents,
|
||||||
|
"totalPages": total_pages,
|
||||||
|
"currentPage": page,
|
||||||
|
"paginated": paginated_docs,
|
||||||
|
}
|
||||||
|
return make_response(jsonify(response), 200)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error retrieving paginated sources: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_ns.route("/delete_by_ids")
|
||||||
|
class DeleteByIds(Resource):
|
||||||
|
@api.doc(
|
||||||
|
description="Deletes documents from the vector store by IDs",
|
||||||
|
params={"path": "Comma-separated list of IDs"},
|
||||||
|
)
|
||||||
|
def get(self):
|
||||||
|
ids = request.args.get("path")
|
||||||
|
if not ids:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Missing required fields"}), 400
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
result = sources_collection.delete_index(ids=ids)
|
||||||
|
if result:
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error deleting indexes: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_ns.route("/delete_old")
|
||||||
|
class DeleteOldIndexes(Resource):
|
||||||
|
@api.doc(
|
||||||
|
description="Deletes old indexes and associated files",
|
||||||
|
params={"source_id": "The source ID to delete"},
|
||||||
|
)
|
||||||
|
def get(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
source_id = request.args.get("source_id")
|
||||||
|
if not source_id:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Missing required fields"}), 400
|
||||||
|
)
|
||||||
|
doc = sources_collection.find_one(
|
||||||
|
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
|
||||||
|
)
|
||||||
|
if not doc:
|
||||||
|
return make_response(jsonify({"status": "not found"}), 404)
|
||||||
|
storage = StorageCreator.get_storage()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Delete vector index
|
||||||
|
|
||||||
|
if settings.VECTOR_STORE == "faiss":
|
||||||
|
index_path = f"indexes/{str(doc['_id'])}"
|
||||||
|
if storage.file_exists(f"{index_path}/index.faiss"):
|
||||||
|
storage.delete_file(f"{index_path}/index.faiss")
|
||||||
|
if storage.file_exists(f"{index_path}/index.pkl"):
|
||||||
|
storage.delete_file(f"{index_path}/index.pkl")
|
||||||
|
else:
|
||||||
|
vectorstore = VectorCreator.create_vectorstore(
|
||||||
|
settings.VECTOR_STORE, source_id=str(doc["_id"])
|
||||||
|
)
|
||||||
|
vectorstore.delete_index()
|
||||||
|
if "file_path" in doc and doc["file_path"]:
|
||||||
|
file_path = doc["file_path"]
|
||||||
|
if storage.is_directory(file_path):
|
||||||
|
files = storage.list_files(file_path)
|
||||||
|
for f in files:
|
||||||
|
storage.delete_file(f)
|
||||||
|
else:
|
||||||
|
storage.delete_file(file_path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error deleting files and indexes: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
sources_collection.delete_one({"_id": ObjectId(source_id)})
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_ns.route("/combine")
|
||||||
|
class RedirectToSources(Resource):
|
||||||
|
@api.doc(
|
||||||
|
description="Redirects /api/combine to /api/sources for backward compatibility"
|
||||||
|
)
|
||||||
|
def get(self):
|
||||||
|
return redirect("/api/sources", code=301)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_ns.route("/manage_sync")
|
||||||
|
class ManageSync(Resource):
|
||||||
|
manage_sync_model = api.model(
|
||||||
|
"ManageSyncModel",
|
||||||
|
{
|
||||||
|
"source_id": fields.String(required=True, description="Source ID"),
|
||||||
|
"sync_frequency": fields.String(
|
||||||
|
required=True,
|
||||||
|
description="Sync frequency (never, daily, weekly, monthly)",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(manage_sync_model)
|
||||||
|
@api.doc(description="Manage sync frequency for sources")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = ["source_id", "sync_frequency"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
source_id = data["source_id"]
|
||||||
|
sync_frequency = data["sync_frequency"]
|
||||||
|
|
||||||
|
if sync_frequency not in ["never", "daily", "weekly", "monthly"]:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid frequency"}), 400
|
||||||
|
)
|
||||||
|
update_data = {"$set": {"sync_frequency": sync_frequency}}
|
||||||
|
try:
|
||||||
|
sources_collection.update_one(
|
||||||
|
{
|
||||||
|
"_id": ObjectId(source_id),
|
||||||
|
"user": user,
|
||||||
|
},
|
||||||
|
update_data,
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error updating sync frequency: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_ns.route("/directory_structure")
|
||||||
|
class DirectoryStructure(Resource):
|
||||||
|
@api.doc(
|
||||||
|
description="Get the directory structure for a document",
|
||||||
|
params={"id": "The document ID"},
|
||||||
|
)
|
||||||
|
def get(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
doc_id = request.args.get("id")
|
||||||
|
|
||||||
|
if not doc_id:
|
||||||
|
return make_response(jsonify({"error": "Document ID is required"}), 400)
|
||||||
|
if not ObjectId.is_valid(doc_id):
|
||||||
|
return make_response(jsonify({"error": "Invalid document ID"}), 400)
|
||||||
|
try:
|
||||||
|
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||||
|
if not doc:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": "Document not found or access denied"}), 404
|
||||||
|
)
|
||||||
|
directory_structure = doc.get("directory_structure", {})
|
||||||
|
base_path = doc.get("file_path", "")
|
||||||
|
|
||||||
|
provider = None
|
||||||
|
remote_data = doc.get("remote_data")
|
||||||
|
try:
|
||||||
|
if isinstance(remote_data, str) and remote_data:
|
||||||
|
remote_data_obj = json.loads(remote_data)
|
||||||
|
provider = remote_data_obj.get("provider")
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.warning(
|
||||||
|
f"Failed to parse remote_data for doc {doc_id}: {e}"
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"directory_structure": directory_structure,
|
||||||
|
"base_path": base_path,
|
||||||
|
"provider": provider,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error retrieving directory structure: {e}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||||
583
application/api/user/sources/upload.py
Normal file
583
application/api/user/sources/upload.py
Normal file
@@ -0,0 +1,583 @@
|
|||||||
|
"""Source document management upload functionality."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
from flask import current_app, jsonify, make_response, request
|
||||||
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
from application.api.user.base import sources_collection
|
||||||
|
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||||
|
from application.storage.storage_creator import StorageCreator
|
||||||
|
from application.utils import check_required_fields, safe_filename
|
||||||
|
|
||||||
|
|
||||||
|
sources_upload_ns = Namespace(
|
||||||
|
"sources", description="Source document management operations", path="/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_upload_ns.route("/upload")
|
||||||
|
class UploadFile(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"UploadModel",
|
||||||
|
{
|
||||||
|
"user": fields.String(required=True, description="User ID"),
|
||||||
|
"name": fields.String(required=True, description="Job name"),
|
||||||
|
"file": fields.Raw(required=True, description="File(s) to upload"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(
|
||||||
|
description="Uploads a file to be vectorized and indexed",
|
||||||
|
)
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
data = request.form
|
||||||
|
files = request.files.getlist("file")
|
||||||
|
required_fields = ["user", "name"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields or not files or all(file.filename == "" for file in files):
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"status": "error",
|
||||||
|
"message": "Missing required fields or files",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
job_name = request.form["name"]
|
||||||
|
|
||||||
|
# Create safe versions for filesystem operations
|
||||||
|
|
||||||
|
safe_user = safe_filename(user)
|
||||||
|
dir_name = safe_filename(job_name)
|
||||||
|
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
storage = StorageCreator.get_storage()
|
||||||
|
|
||||||
|
for file in files:
|
||||||
|
original_filename = file.filename
|
||||||
|
safe_file = safe_filename(original_filename)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
temp_file_path = os.path.join(temp_dir, safe_file)
|
||||||
|
file.save(temp_file_path)
|
||||||
|
|
||||||
|
if zipfile.is_zipfile(temp_file_path):
|
||||||
|
try:
|
||||||
|
with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
|
||||||
|
zip_ref.extractall(path=temp_dir)
|
||||||
|
|
||||||
|
# Walk through extracted files and upload them
|
||||||
|
|
||||||
|
for root, _, files in os.walk(temp_dir):
|
||||||
|
for extracted_file in files:
|
||||||
|
if (
|
||||||
|
os.path.join(root, extracted_file)
|
||||||
|
== temp_file_path
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
rel_path = os.path.relpath(
|
||||||
|
os.path.join(root, extracted_file), temp_dir
|
||||||
|
)
|
||||||
|
storage_path = f"{base_path}/{rel_path}"
|
||||||
|
|
||||||
|
with open(
|
||||||
|
os.path.join(root, extracted_file), "rb"
|
||||||
|
) as f:
|
||||||
|
storage.save_file(f, storage_path)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error extracting zip: {e}", exc_info=True
|
||||||
|
)
|
||||||
|
# If zip extraction fails, save the original zip file
|
||||||
|
|
||||||
|
file_path = f"{base_path}/{safe_file}"
|
||||||
|
with open(temp_file_path, "rb") as f:
|
||||||
|
storage.save_file(f, file_path)
|
||||||
|
else:
|
||||||
|
# For non-zip files, save directly
|
||||||
|
|
||||||
|
file_path = f"{base_path}/{safe_file}"
|
||||||
|
with open(temp_file_path, "rb") as f:
|
||||||
|
storage.save_file(f, file_path)
|
||||||
|
task = ingest.delay(
|
||||||
|
settings.UPLOAD_FOLDER,
|
||||||
|
[
|
||||||
|
".rst",
|
||||||
|
".md",
|
||||||
|
".pdf",
|
||||||
|
".txt",
|
||||||
|
".docx",
|
||||||
|
".csv",
|
||||||
|
".epub",
|
||||||
|
".html",
|
||||||
|
".mdx",
|
||||||
|
".json",
|
||||||
|
".xlsx",
|
||||||
|
".pptx",
|
||||||
|
".png",
|
||||||
|
".jpg",
|
||||||
|
".jpeg",
|
||||||
|
],
|
||||||
|
job_name,
|
||||||
|
user,
|
||||||
|
file_path=base_path,
|
||||||
|
filename=dir_name,
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_upload_ns.route("/remote")
|
||||||
|
class UploadRemote(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"RemoteUploadModel",
|
||||||
|
{
|
||||||
|
"user": fields.String(required=True, description="User ID"),
|
||||||
|
"source": fields.String(
|
||||||
|
required=True, description="Source of the data"
|
||||||
|
),
|
||||||
|
"name": fields.String(required=True, description="Job name"),
|
||||||
|
"data": fields.String(required=True, description="Data to process"),
|
||||||
|
"repo_url": fields.String(description="GitHub repository URL"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(
|
||||||
|
description="Uploads remote source for vectorization",
|
||||||
|
)
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
data = request.form
|
||||||
|
required_fields = ["user", "source", "name", "data"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
try:
|
||||||
|
config = json.loads(data["data"])
|
||||||
|
source_data = None
|
||||||
|
|
||||||
|
if data["source"] == "github":
|
||||||
|
source_data = config.get("repo_url")
|
||||||
|
elif data["source"] in ["crawler", "url"]:
|
||||||
|
source_data = config.get("url")
|
||||||
|
elif data["source"] == "reddit":
|
||||||
|
source_data = config
|
||||||
|
elif data["source"] in ConnectorCreator.get_supported_connectors():
|
||||||
|
session_token = config.get("session_token")
|
||||||
|
if not session_token:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": f"Missing session_token in {data['source']} configuration",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
# Process file_ids
|
||||||
|
|
||||||
|
file_ids = config.get("file_ids", [])
|
||||||
|
if isinstance(file_ids, str):
|
||||||
|
file_ids = [id.strip() for id in file_ids.split(",") if id.strip()]
|
||||||
|
elif not isinstance(file_ids, list):
|
||||||
|
file_ids = []
|
||||||
|
# Process folder_ids
|
||||||
|
|
||||||
|
folder_ids = config.get("folder_ids", [])
|
||||||
|
if isinstance(folder_ids, str):
|
||||||
|
folder_ids = [
|
||||||
|
id.strip() for id in folder_ids.split(",") if id.strip()
|
||||||
|
]
|
||||||
|
elif not isinstance(folder_ids, list):
|
||||||
|
folder_ids = []
|
||||||
|
config["file_ids"] = file_ids
|
||||||
|
config["folder_ids"] = folder_ids
|
||||||
|
|
||||||
|
task = ingest_connector_task.delay(
|
||||||
|
job_name=data["name"],
|
||||||
|
user=decoded_token.get("sub"),
|
||||||
|
source_type=data["source"],
|
||||||
|
session_token=session_token,
|
||||||
|
file_ids=file_ids,
|
||||||
|
folder_ids=folder_ids,
|
||||||
|
recursive=config.get("recursive", False),
|
||||||
|
retriever=config.get("retriever", "classic"),
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": True, "task_id": task.id}), 200
|
||||||
|
)
|
||||||
|
task = ingest_remote.delay(
|
||||||
|
source_data=source_data,
|
||||||
|
job_name=data["name"],
|
||||||
|
user=decoded_token.get("sub"),
|
||||||
|
loader=data["source"],
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error uploading remote source: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_upload_ns.route("/manage_source_files")
|
||||||
|
class ManageSourceFiles(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"ManageSourceFilesModel",
|
||||||
|
{
|
||||||
|
"source_id": fields.String(
|
||||||
|
required=True, description="Source ID to modify"
|
||||||
|
),
|
||||||
|
"operation": fields.String(
|
||||||
|
required=True,
|
||||||
|
description="Operation: 'add', 'remove', or 'remove_directory'",
|
||||||
|
),
|
||||||
|
"file_paths": fields.List(
|
||||||
|
fields.String,
|
||||||
|
required=False,
|
||||||
|
description="File paths to remove (for remove operation)",
|
||||||
|
),
|
||||||
|
"directory_path": fields.String(
|
||||||
|
required=False,
|
||||||
|
description="Directory path to remove (for remove_directory operation)",
|
||||||
|
),
|
||||||
|
"file": fields.Raw(
|
||||||
|
required=False, description="Files to add (for add operation)"
|
||||||
|
),
|
||||||
|
"parent_dir": fields.String(
|
||||||
|
required=False,
|
||||||
|
description="Parent directory path relative to source root",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(
|
||||||
|
description="Add files, remove files, or remove directories from an existing source",
|
||||||
|
)
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Unauthorized"}), 401
|
||||||
|
)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
source_id = request.form.get("source_id")
|
||||||
|
operation = request.form.get("operation")
|
||||||
|
|
||||||
|
if not source_id or not operation:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "source_id and operation are required",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
if operation not in ["add", "remove", "remove_directory"]:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "operation must be 'add', 'remove', or 'remove_directory'",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
ObjectId(source_id)
|
||||||
|
except Exception:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid source ID format"}), 400
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
source = sources_collection.find_one(
|
||||||
|
{"_id": ObjectId(source_id), "user": user}
|
||||||
|
)
|
||||||
|
if not source:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "Source not found or access denied",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error finding source: {err}", exc_info=True)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Database error"}), 500
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
storage = StorageCreator.get_storage()
|
||||||
|
source_file_path = source.get("file_path", "")
|
||||||
|
parent_dir = request.form.get("parent_dir", "")
|
||||||
|
|
||||||
|
if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{"success": False, "message": "Invalid parent directory path"}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
if operation == "add":
|
||||||
|
files = request.files.getlist("file")
|
||||||
|
if not files or all(file.filename == "" for file in files):
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "No files provided for add operation",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
added_files = []
|
||||||
|
|
||||||
|
target_dir = source_file_path
|
||||||
|
if parent_dir:
|
||||||
|
target_dir = f"{source_file_path}/{parent_dir}"
|
||||||
|
for file in files:
|
||||||
|
if file.filename:
|
||||||
|
safe_filename_str = safe_filename(file.filename)
|
||||||
|
file_path = f"{target_dir}/{safe_filename_str}"
|
||||||
|
|
||||||
|
# Save file to storage
|
||||||
|
|
||||||
|
storage.save_file(file, file_path)
|
||||||
|
added_files.append(safe_filename_str)
|
||||||
|
# Trigger re-ingestion pipeline
|
||||||
|
|
||||||
|
from application.api.user.tasks import reingest_source_task
|
||||||
|
|
||||||
|
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||||
|
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"message": f"Added {len(added_files)} files",
|
||||||
|
"added_files": added_files,
|
||||||
|
"parent_dir": parent_dir,
|
||||||
|
"reingest_task_id": task.id,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
elif operation == "remove":
|
||||||
|
file_paths_str = request.form.get("file_paths")
|
||||||
|
if not file_paths_str:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "file_paths required for remove operation",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
file_paths = (
|
||||||
|
json.loads(file_paths_str)
|
||||||
|
if isinstance(file_paths_str, str)
|
||||||
|
else file_paths_str
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{"success": False, "message": "Invalid file_paths format"}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
# Remove files from storage and directory structure
|
||||||
|
|
||||||
|
removed_files = []
|
||||||
|
for file_path in file_paths:
|
||||||
|
full_path = f"{source_file_path}/{file_path}"
|
||||||
|
|
||||||
|
# Remove from storage
|
||||||
|
|
||||||
|
if storage.file_exists(full_path):
|
||||||
|
storage.delete_file(full_path)
|
||||||
|
removed_files.append(file_path)
|
||||||
|
# Trigger re-ingestion pipeline
|
||||||
|
|
||||||
|
from application.api.user.tasks import reingest_source_task
|
||||||
|
|
||||||
|
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||||
|
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"message": f"Removed {len(removed_files)} files",
|
||||||
|
"removed_files": removed_files,
|
||||||
|
"reingest_task_id": task.id,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
elif operation == "remove_directory":
|
||||||
|
directory_path = request.form.get("directory_path")
|
||||||
|
if not directory_path:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "directory_path required for remove_directory operation",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
# Validate directory path (prevent path traversal)
|
||||||
|
|
||||||
|
if directory_path.startswith("/") or ".." in directory_path:
|
||||||
|
current_app.logger.warning(
|
||||||
|
f"Invalid directory path attempted for removal. "
|
||||||
|
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}"
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{"success": False, "message": "Invalid directory path"}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
full_directory_path = (
|
||||||
|
f"{source_file_path}/{directory_path}"
|
||||||
|
if directory_path
|
||||||
|
else source_file_path
|
||||||
|
)
|
||||||
|
|
||||||
|
if not storage.is_directory(full_directory_path):
|
||||||
|
current_app.logger.warning(
|
||||||
|
f"Directory not found or is not a directory for removal. "
|
||||||
|
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||||
|
f"Full path: {full_directory_path}"
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "Directory not found or is not a directory",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
success = storage.remove_directory(full_directory_path)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Failed to remove directory from storage. "
|
||||||
|
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||||
|
f"Full path: {full_directory_path}"
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{"success": False, "message": "Failed to remove directory"}
|
||||||
|
),
|
||||||
|
500,
|
||||||
|
)
|
||||||
|
current_app.logger.info(
|
||||||
|
f"Successfully removed directory. "
|
||||||
|
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||||
|
f"Full path: {full_directory_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Trigger re-ingestion pipeline
|
||||||
|
|
||||||
|
from application.api.user.tasks import reingest_source_task
|
||||||
|
|
||||||
|
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||||
|
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"message": f"Successfully removed directory: {directory_path}",
|
||||||
|
"removed_directory": directory_path,
|
||||||
|
"reingest_task_id": task.id,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
error_context = f"operation={operation}, user={user}, source_id={source_id}"
|
||||||
|
if operation == "remove_directory":
|
||||||
|
directory_path = request.form.get("directory_path", "")
|
||||||
|
error_context += f", directory_path={directory_path}"
|
||||||
|
elif operation == "remove":
|
||||||
|
file_paths_str = request.form.get("file_paths", "")
|
||||||
|
error_context += f", file_paths={file_paths_str}"
|
||||||
|
elif operation == "add":
|
||||||
|
parent_dir = request.form.get("parent_dir", "")
|
||||||
|
error_context += f", parent_dir={parent_dir}"
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error managing source files: {err} ({error_context})", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Operation failed"}), 500
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_upload_ns.route("/task_status")
|
||||||
|
class TaskStatus(Resource):
|
||||||
|
task_status_model = api.model(
|
||||||
|
"TaskStatusModel",
|
||||||
|
{"task_id": fields.String(required=True, description="Task ID")},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(task_status_model)
|
||||||
|
@api.doc(description="Get celery job status")
|
||||||
|
def get(self):
|
||||||
|
task_id = request.args.get("task_id")
|
||||||
|
if not task_id:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Task ID is required"}), 400
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
from application.celery_init import celery
|
||||||
|
|
||||||
|
task = celery.AsyncResult(task_id)
|
||||||
|
task_meta = task.info
|
||||||
|
print(f"Task status: {task.status}")
|
||||||
|
|
||||||
|
if task.status == "PENDING":
|
||||||
|
inspect = celery.control.inspect()
|
||||||
|
active_workers = inspect.ping()
|
||||||
|
if not active_workers:
|
||||||
|
raise ConnectionError("Service unavailable")
|
||||||
|
|
||||||
|
if not isinstance(
|
||||||
|
task_meta, (dict, list, str, int, float, bool, type(None))
|
||||||
|
):
|
||||||
|
task_meta = str(task_meta) # Convert to a string representation
|
||||||
|
except ConnectionError as err:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": str(err)}), 503
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"status": task.status, "result": task_meta}), 200)
|
||||||
@@ -5,14 +5,16 @@ from application.worker import (
|
|||||||
agent_webhook_worker,
|
agent_webhook_worker,
|
||||||
attachment_worker,
|
attachment_worker,
|
||||||
ingest_worker,
|
ingest_worker,
|
||||||
|
mcp_oauth,
|
||||||
|
mcp_oauth_status,
|
||||||
remote_worker,
|
remote_worker,
|
||||||
sync_worker,
|
sync_worker,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@celery.task(bind=True)
|
@celery.task(bind=True)
|
||||||
def ingest(self, directory, formats, name_job, filename, user):
|
def ingest(self, directory, formats, job_name, user, file_path, filename):
|
||||||
resp = ingest_worker(self, directory, formats, name_job, filename, user)
|
resp = ingest_worker(self, directory, formats, job_name, file_path, filename, user)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@@ -22,6 +24,14 @@ def ingest_remote(self, source_data, job_name, user, loader):
|
|||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@celery.task(bind=True)
|
||||||
|
def reingest_source_task(self, source_id, user):
|
||||||
|
from application.worker import reingest_source_worker
|
||||||
|
|
||||||
|
resp = reingest_source_worker(self, source_id, user)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@celery.task(bind=True)
|
@celery.task(bind=True)
|
||||||
def schedule_syncs(self, frequency):
|
def schedule_syncs(self, frequency):
|
||||||
resp = sync_worker(self, frequency)
|
resp = sync_worker(self, frequency)
|
||||||
@@ -40,6 +50,40 @@ def process_agent_webhook(self, agent_id, payload):
|
|||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@celery.task(bind=True)
|
||||||
|
def ingest_connector_task(
|
||||||
|
self,
|
||||||
|
job_name,
|
||||||
|
user,
|
||||||
|
source_type,
|
||||||
|
session_token=None,
|
||||||
|
file_ids=None,
|
||||||
|
folder_ids=None,
|
||||||
|
recursive=True,
|
||||||
|
retriever="classic",
|
||||||
|
operation_mode="upload",
|
||||||
|
doc_id=None,
|
||||||
|
sync_frequency="never",
|
||||||
|
):
|
||||||
|
from application.worker import ingest_connector
|
||||||
|
|
||||||
|
resp = ingest_connector(
|
||||||
|
self,
|
||||||
|
job_name,
|
||||||
|
user,
|
||||||
|
source_type,
|
||||||
|
session_token=session_token,
|
||||||
|
file_ids=file_ids,
|
||||||
|
folder_ids=folder_ids,
|
||||||
|
recursive=recursive,
|
||||||
|
retriever=retriever,
|
||||||
|
operation_mode=operation_mode,
|
||||||
|
doc_id=doc_id,
|
||||||
|
sync_frequency=sync_frequency,
|
||||||
|
)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@celery.on_after_configure.connect
|
@celery.on_after_configure.connect
|
||||||
def setup_periodic_tasks(sender, **kwargs):
|
def setup_periodic_tasks(sender, **kwargs):
|
||||||
sender.add_periodic_task(
|
sender.add_periodic_task(
|
||||||
@@ -54,3 +98,15 @@ def setup_periodic_tasks(sender, **kwargs):
|
|||||||
timedelta(days=30),
|
timedelta(days=30),
|
||||||
schedule_syncs.s("monthly"),
|
schedule_syncs.s("monthly"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@celery.task(bind=True)
|
||||||
|
def mcp_oauth_task(self, config, user):
|
||||||
|
resp = mcp_oauth(self, config, user)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@celery.task(bind=True)
|
||||||
|
def mcp_oauth_status_task(self, task_id):
|
||||||
|
resp = mcp_oauth_status(self, task_id)
|
||||||
|
return resp
|
||||||
|
|||||||
6
application/api/user/tools/__init__.py
Normal file
6
application/api/user/tools/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""Tools module."""
|
||||||
|
|
||||||
|
from .mcp import tools_mcp_ns
|
||||||
|
from .routes import tools_ns
|
||||||
|
|
||||||
|
__all__ = ["tools_ns", "tools_mcp_ns"]
|
||||||
333
application/api/user/tools/mcp.py
Normal file
333
application/api/user/tools/mcp.py
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
"""Tool management MCP server integration."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from email.quoprimime import unquote
|
||||||
|
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
from flask import current_app, jsonify, make_response, redirect, request
|
||||||
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
|
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
|
||||||
|
from application.api import api
|
||||||
|
from application.api.user.base import user_tools_collection
|
||||||
|
from application.cache import get_redis_instance
|
||||||
|
from application.security.encryption import encrypt_credentials
|
||||||
|
from application.utils import check_required_fields
|
||||||
|
|
||||||
|
tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")
|
||||||
|
|
||||||
|
|
||||||
|
@tools_mcp_ns.route("/mcp_server/test")
|
||||||
|
class TestMCPServerConfig(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"MCPServerTestModel",
|
||||||
|
{
|
||||||
|
"config": fields.Raw(
|
||||||
|
required=True, description="MCP server configuration to test"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(description="Test MCP server connection with provided configuration")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
|
||||||
|
required_fields = ["config"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
try:
|
||||||
|
config = data["config"]
|
||||||
|
|
||||||
|
auth_credentials = {}
|
||||||
|
auth_type = config.get("auth_type", "none")
|
||||||
|
|
||||||
|
if auth_type == "api_key" and "api_key" in config:
|
||||||
|
auth_credentials["api_key"] = config["api_key"]
|
||||||
|
if "api_key_header" in config:
|
||||||
|
auth_credentials["api_key_header"] = config["api_key_header"]
|
||||||
|
elif auth_type == "bearer" and "bearer_token" in config:
|
||||||
|
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||||
|
elif auth_type == "basic":
|
||||||
|
if "username" in config:
|
||||||
|
auth_credentials["username"] = config["username"]
|
||||||
|
if "password" in config:
|
||||||
|
auth_credentials["password"] = config["password"]
|
||||||
|
test_config = config.copy()
|
||||||
|
test_config["auth_credentials"] = auth_credentials
|
||||||
|
|
||||||
|
mcp_tool = MCPTool(config=test_config, user_id=user)
|
||||||
|
result = mcp_tool.test_connection()
|
||||||
|
|
||||||
|
return make_response(jsonify(result), 200)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{"success": False, "error": f"Connection test failed: {str(e)}"}
|
||||||
|
),
|
||||||
|
500,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tools_mcp_ns.route("/mcp_server/save")
|
||||||
|
class MCPServerSave(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"MCPServerSaveModel",
|
||||||
|
{
|
||||||
|
"id": fields.String(
|
||||||
|
required=False, description="Tool ID for updates (optional)"
|
||||||
|
),
|
||||||
|
"displayName": fields.String(
|
||||||
|
required=True, description="Display name for the MCP server"
|
||||||
|
),
|
||||||
|
"config": fields.Raw(
|
||||||
|
required=True, description="MCP server configuration"
|
||||||
|
),
|
||||||
|
"status": fields.Boolean(
|
||||||
|
required=False, default=True, description="Tool status"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(description="Create or update MCP server with automatic tool discovery")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
|
||||||
|
required_fields = ["displayName", "config"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
try:
|
||||||
|
config = data["config"]
|
||||||
|
|
||||||
|
auth_credentials = {}
|
||||||
|
auth_type = config.get("auth_type", "none")
|
||||||
|
if auth_type == "api_key":
|
||||||
|
if "api_key" in config and config["api_key"]:
|
||||||
|
auth_credentials["api_key"] = config["api_key"]
|
||||||
|
if "api_key_header" in config:
|
||||||
|
auth_credentials["api_key_header"] = config["api_key_header"]
|
||||||
|
elif auth_type == "bearer":
|
||||||
|
if "bearer_token" in config and config["bearer_token"]:
|
||||||
|
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||||
|
elif auth_type == "basic":
|
||||||
|
if "username" in config and config["username"]:
|
||||||
|
auth_credentials["username"] = config["username"]
|
||||||
|
if "password" in config and config["password"]:
|
||||||
|
auth_credentials["password"] = config["password"]
|
||||||
|
mcp_config = config.copy()
|
||||||
|
mcp_config["auth_credentials"] = auth_credentials
|
||||||
|
|
||||||
|
if auth_type == "oauth":
|
||||||
|
if not config.get("oauth_task_id"):
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": "Connection not authorized. Please complete the OAuth authorization first.",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
redis_client = get_redis_instance()
|
||||||
|
manager = MCPOAuthManager(redis_client)
|
||||||
|
result = manager.get_oauth_status(config["oauth_task_id"])
|
||||||
|
if not result.get("status") == "completed":
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": "OAuth failed or not completed. Please try authorizing again.",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
actions_metadata = result.get("tools", [])
|
||||||
|
elif auth_type == "none" or auth_credentials:
|
||||||
|
mcp_tool = MCPTool(config=mcp_config, user_id=user)
|
||||||
|
mcp_tool.discover_tools()
|
||||||
|
actions_metadata = mcp_tool.get_actions_metadata()
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
"No valid credentials provided for the selected authentication type"
|
||||||
|
)
|
||||||
|
storage_config = config.copy()
|
||||||
|
if auth_credentials:
|
||||||
|
encrypted_credentials_string = encrypt_credentials(
|
||||||
|
auth_credentials, user
|
||||||
|
)
|
||||||
|
storage_config["encrypted_credentials"] = encrypted_credentials_string
|
||||||
|
for field in [
|
||||||
|
"api_key",
|
||||||
|
"bearer_token",
|
||||||
|
"username",
|
||||||
|
"password",
|
||||||
|
"api_key_header",
|
||||||
|
]:
|
||||||
|
storage_config.pop(field, None)
|
||||||
|
transformed_actions = []
|
||||||
|
for action in actions_metadata:
|
||||||
|
action["active"] = True
|
||||||
|
if "parameters" in action:
|
||||||
|
if "properties" in action["parameters"]:
|
||||||
|
for param_name, param_details in action["parameters"][
|
||||||
|
"properties"
|
||||||
|
].items():
|
||||||
|
param_details["filled_by_llm"] = True
|
||||||
|
param_details["value"] = ""
|
||||||
|
transformed_actions.append(action)
|
||||||
|
tool_data = {
|
||||||
|
"name": "mcp_tool",
|
||||||
|
"displayName": data["displayName"],
|
||||||
|
"customName": data["displayName"],
|
||||||
|
"description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
|
||||||
|
"config": storage_config,
|
||||||
|
"actions": transformed_actions,
|
||||||
|
"status": data.get("status", True),
|
||||||
|
"user": user,
|
||||||
|
}
|
||||||
|
|
||||||
|
tool_id = data.get("id")
|
||||||
|
if tool_id:
|
||||||
|
result = user_tools_collection.update_one(
|
||||||
|
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
|
||||||
|
{"$set": {k: v for k, v in tool_data.items() if k != "user"}},
|
||||||
|
)
|
||||||
|
if result.matched_count == 0:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": "Tool not found or access denied",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
response_data = {
|
||||||
|
"success": True,
|
||||||
|
"id": tool_id,
|
||||||
|
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
|
||||||
|
"tools_count": len(transformed_actions),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
result = user_tools_collection.insert_one(tool_data)
|
||||||
|
tool_id = str(result.inserted_id)
|
||||||
|
response_data = {
|
||||||
|
"success": True,
|
||||||
|
"id": tool_id,
|
||||||
|
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
|
||||||
|
"tools_count": len(transformed_actions),
|
||||||
|
}
|
||||||
|
return make_response(jsonify(response_data), 200)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{"success": False, "error": f"Failed to save MCP server: {str(e)}"}
|
||||||
|
),
|
||||||
|
500,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tools_mcp_ns.route("/mcp_server/callback")
|
||||||
|
class MCPOAuthCallback(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"MCPServerCallbackModel",
|
||||||
|
{
|
||||||
|
"code": fields.String(required=True, description="Authorization code"),
|
||||||
|
"state": fields.String(required=True, description="State parameter"),
|
||||||
|
"error": fields.String(
|
||||||
|
required=False, description="Error message (if any)"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(
|
||||||
|
description="Handle OAuth callback by providing the authorization code and state"
|
||||||
|
)
|
||||||
|
def get(self):
|
||||||
|
code = request.args.get("code")
|
||||||
|
state = request.args.get("state")
|
||||||
|
error = request.args.get("error")
|
||||||
|
|
||||||
|
if error:
|
||||||
|
return redirect(
|
||||||
|
f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
|
||||||
|
)
|
||||||
|
if not code or not state:
|
||||||
|
return redirect(
|
||||||
|
"/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
redis_client = get_redis_instance()
|
||||||
|
if not redis_client:
|
||||||
|
return redirect(
|
||||||
|
"/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
|
||||||
|
)
|
||||||
|
code = unquote(code)
|
||||||
|
manager = MCPOAuthManager(redis_client)
|
||||||
|
success = manager.handle_oauth_callback(state, code, error)
|
||||||
|
if success:
|
||||||
|
return redirect(
|
||||||
|
"/api/connectors/callback-status?status=success&message=Authorization+code+received+successfully.+You+can+close+this+window.&provider=mcp_tool"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return redirect(
|
||||||
|
"/api/connectors/callback-status?status=error&message=OAuth+callback+failed.&provider=mcp_tool"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
|
||||||
|
)
|
||||||
|
return redirect(
|
||||||
|
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
|
||||||
|
class MCPOAuthStatus(Resource):
|
||||||
|
def get(self, task_id):
|
||||||
|
"""
|
||||||
|
Get current status of OAuth flow.
|
||||||
|
Frontend should poll this endpoint periodically.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
redis_client = get_redis_instance()
|
||||||
|
status_key = f"mcp_oauth_status:{task_id}"
|
||||||
|
status_data = redis_client.get(status_key)
|
||||||
|
|
||||||
|
if status_data:
|
||||||
|
status = json.loads(status_data)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": True, "task_id": task_id, **status})
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": "Task not found or expired",
|
||||||
|
"task_id": task_id,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error getting OAuth status for task {task_id}: {str(e)}"
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
|
||||||
|
)
|
||||||
416
application/api/user/tools/routes.py
Normal file
416
application/api/user/tools/routes.py
Normal file
@@ -0,0 +1,416 @@
|
|||||||
|
"""Tool management routes."""
|
||||||
|
|
||||||
|
from bson.objectid import ObjectId
|
||||||
|
from flask import current_app, jsonify, make_response, request
|
||||||
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
|
from application.agents.tools.tool_manager import ToolManager
|
||||||
|
from application.api import api
|
||||||
|
from application.api.user.base import user_tools_collection
|
||||||
|
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||||
|
from application.utils import check_required_fields, validate_function_name
|
||||||
|
|
||||||
|
tool_config = {}
|
||||||
|
tool_manager = ToolManager(config=tool_config)
|
||||||
|
|
||||||
|
|
||||||
|
tools_ns = Namespace("tools", description="Tool management operations", path="/api")
|
||||||
|
|
||||||
|
|
||||||
|
@tools_ns.route("/available_tools")
|
||||||
|
class AvailableTools(Resource):
|
||||||
|
@api.doc(description="Get available tools for a user")
|
||||||
|
def get(self):
|
||||||
|
try:
|
||||||
|
tools_metadata = []
|
||||||
|
for tool_name, tool_instance in tool_manager.tools.items():
|
||||||
|
doc = tool_instance.__doc__.strip()
|
||||||
|
lines = doc.split("\n", 1)
|
||||||
|
name = lines[0].strip()
|
||||||
|
description = lines[1].strip() if len(lines) > 1 else ""
|
||||||
|
tools_metadata.append(
|
||||||
|
{
|
||||||
|
"name": tool_name,
|
||||||
|
"displayName": name,
|
||||||
|
"description": description,
|
||||||
|
"configRequirements": tool_instance.get_config_requirements(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error getting available tools: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True, "data": tools_metadata}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@tools_ns.route("/get_tools")
|
||||||
|
class GetTools(Resource):
|
||||||
|
@api.doc(description="Get tools created by a user")
|
||||||
|
def get(self):
|
||||||
|
try:
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
tools = user_tools_collection.find({"user": user})
|
||||||
|
user_tools = []
|
||||||
|
for tool in tools:
|
||||||
|
tool_copy = {**tool}
|
||||||
|
tool_copy["id"] = str(tool["_id"])
|
||||||
|
tool_copy.pop("_id", None)
|
||||||
|
user_tools.append(tool_copy)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True, "tools": user_tools}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@tools_ns.route("/create_tool")
|
||||||
|
class CreateTool(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"CreateToolModel",
|
||||||
|
{
|
||||||
|
"name": fields.String(required=True, description="Name of the tool"),
|
||||||
|
"displayName": fields.String(
|
||||||
|
required=True, description="Display name for the tool"
|
||||||
|
),
|
||||||
|
"description": fields.String(
|
||||||
|
required=True, description="Tool description"
|
||||||
|
),
|
||||||
|
"config": fields.Raw(
|
||||||
|
required=True, description="Configuration of the tool"
|
||||||
|
),
|
||||||
|
"customName": fields.String(
|
||||||
|
required=False, description="Custom name for the tool"
|
||||||
|
),
|
||||||
|
"status": fields.Boolean(
|
||||||
|
required=True, description="Status of the tool"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(description="Create a new tool")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = [
|
||||||
|
"name",
|
||||||
|
"displayName",
|
||||||
|
"description",
|
||||||
|
"config",
|
||||||
|
"status",
|
||||||
|
]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
try:
|
||||||
|
tool_instance = tool_manager.tools.get(data["name"])
|
||||||
|
if not tool_instance:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Tool not found"}), 404
|
||||||
|
)
|
||||||
|
actions_metadata = tool_instance.get_actions_metadata()
|
||||||
|
transformed_actions = []
|
||||||
|
for action in actions_metadata:
|
||||||
|
action["active"] = True
|
||||||
|
if "parameters" in action:
|
||||||
|
if "properties" in action["parameters"]:
|
||||||
|
for param_name, param_details in action["parameters"][
|
||||||
|
"properties"
|
||||||
|
].items():
|
||||||
|
param_details["filled_by_llm"] = True
|
||||||
|
param_details["value"] = ""
|
||||||
|
transformed_actions.append(action)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error getting tool actions: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
try:
|
||||||
|
new_tool = {
|
||||||
|
"user": user,
|
||||||
|
"name": data["name"],
|
||||||
|
"displayName": data["displayName"],
|
||||||
|
"description": data["description"],
|
||||||
|
"customName": data.get("customName", ""),
|
||||||
|
"actions": transformed_actions,
|
||||||
|
"config": data["config"],
|
||||||
|
"status": data["status"],
|
||||||
|
}
|
||||||
|
resp = user_tools_collection.insert_one(new_tool)
|
||||||
|
new_id = str(resp.inserted_id)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"id": new_id}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@tools_ns.route("/update_tool")
|
||||||
|
class UpdateTool(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"UpdateToolModel",
|
||||||
|
{
|
||||||
|
"id": fields.String(required=True, description="Tool ID"),
|
||||||
|
"name": fields.String(description="Name of the tool"),
|
||||||
|
"displayName": fields.String(description="Display name for the tool"),
|
||||||
|
"customName": fields.String(description="Custom name for the tool"),
|
||||||
|
"description": fields.String(description="Tool description"),
|
||||||
|
"config": fields.Raw(description="Configuration of the tool"),
|
||||||
|
"actions": fields.List(
|
||||||
|
fields.Raw, description="Actions the tool can perform"
|
||||||
|
),
|
||||||
|
"status": fields.Boolean(description="Status of the tool"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(description="Update a tool by ID")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = ["id"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
try:
|
||||||
|
update_data = {}
|
||||||
|
if "name" in data:
|
||||||
|
update_data["name"] = data["name"]
|
||||||
|
if "displayName" in data:
|
||||||
|
update_data["displayName"] = data["displayName"]
|
||||||
|
if "customName" in data:
|
||||||
|
update_data["customName"] = data["customName"]
|
||||||
|
if "description" in data:
|
||||||
|
update_data["description"] = data["description"]
|
||||||
|
if "actions" in data:
|
||||||
|
update_data["actions"] = data["actions"]
|
||||||
|
if "config" in data:
|
||||||
|
if "actions" in data["config"]:
|
||||||
|
for action_name in list(data["config"]["actions"].keys()):
|
||||||
|
if not validate_function_name(action_name):
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.",
|
||||||
|
"param": "tools[].function.name",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
tool_doc = user_tools_collection.find_one(
|
||||||
|
{"_id": ObjectId(data["id"]), "user": user}
|
||||||
|
)
|
||||||
|
if tool_doc and tool_doc.get("name") == "mcp_tool":
|
||||||
|
config = data["config"]
|
||||||
|
existing_config = tool_doc.get("config", {})
|
||||||
|
storage_config = existing_config.copy()
|
||||||
|
|
||||||
|
storage_config.update(config)
|
||||||
|
existing_credentials = {}
|
||||||
|
if "encrypted_credentials" in existing_config:
|
||||||
|
existing_credentials = decrypt_credentials(
|
||||||
|
existing_config["encrypted_credentials"], user
|
||||||
|
)
|
||||||
|
auth_credentials = existing_credentials.copy()
|
||||||
|
auth_type = storage_config.get("auth_type", "none")
|
||||||
|
if auth_type == "api_key":
|
||||||
|
if "api_key" in config and config["api_key"]:
|
||||||
|
auth_credentials["api_key"] = config["api_key"]
|
||||||
|
if "api_key_header" in config:
|
||||||
|
auth_credentials["api_key_header"] = config[
|
||||||
|
"api_key_header"
|
||||||
|
]
|
||||||
|
elif auth_type == "bearer":
|
||||||
|
if "bearer_token" in config and config["bearer_token"]:
|
||||||
|
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||||
|
elif "encrypted_token" in config and config["encrypted_token"]:
|
||||||
|
auth_credentials["bearer_token"] = config["encrypted_token"]
|
||||||
|
elif auth_type == "basic":
|
||||||
|
if "username" in config and config["username"]:
|
||||||
|
auth_credentials["username"] = config["username"]
|
||||||
|
if "password" in config and config["password"]:
|
||||||
|
auth_credentials["password"] = config["password"]
|
||||||
|
if auth_type != "none" and auth_credentials:
|
||||||
|
encrypted_credentials_string = encrypt_credentials(
|
||||||
|
auth_credentials, user
|
||||||
|
)
|
||||||
|
storage_config["encrypted_credentials"] = (
|
||||||
|
encrypted_credentials_string
|
||||||
|
)
|
||||||
|
elif auth_type == "none":
|
||||||
|
storage_config.pop("encrypted_credentials", None)
|
||||||
|
for field in [
|
||||||
|
"api_key",
|
||||||
|
"bearer_token",
|
||||||
|
"encrypted_token",
|
||||||
|
"username",
|
||||||
|
"password",
|
||||||
|
"api_key_header",
|
||||||
|
]:
|
||||||
|
storage_config.pop(field, None)
|
||||||
|
update_data["config"] = storage_config
|
||||||
|
else:
|
||||||
|
update_data["config"] = data["config"]
|
||||||
|
if "status" in data:
|
||||||
|
update_data["status"] = data["status"]
|
||||||
|
user_tools_collection.update_one(
|
||||||
|
{"_id": ObjectId(data["id"]), "user": user},
|
||||||
|
{"$set": update_data},
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error updating tool: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@tools_ns.route("/update_tool_config")
|
||||||
|
class UpdateToolConfig(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"UpdateToolConfigModel",
|
||||||
|
{
|
||||||
|
"id": fields.String(required=True, description="Tool ID"),
|
||||||
|
"config": fields.Raw(
|
||||||
|
required=True, description="Configuration of the tool"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(description="Update the configuration of a tool")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = ["id", "config"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
try:
|
||||||
|
user_tools_collection.update_one(
|
||||||
|
{"_id": ObjectId(data["id"]), "user": user},
|
||||||
|
{"$set": {"config": data["config"]}},
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error updating tool config: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@tools_ns.route("/update_tool_actions")
|
||||||
|
class UpdateToolActions(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"UpdateToolActionsModel",
|
||||||
|
{
|
||||||
|
"id": fields.String(required=True, description="Tool ID"),
|
||||||
|
"actions": fields.List(
|
||||||
|
fields.Raw,
|
||||||
|
required=True,
|
||||||
|
description="Actions the tool can perform",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(description="Update the actions of a tool")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = ["id", "actions"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
try:
|
||||||
|
user_tools_collection.update_one(
|
||||||
|
{"_id": ObjectId(data["id"]), "user": user},
|
||||||
|
{"$set": {"actions": data["actions"]}},
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error updating tool actions: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@tools_ns.route("/update_tool_status")
|
||||||
|
class UpdateToolStatus(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"UpdateToolStatusModel",
|
||||||
|
{
|
||||||
|
"id": fields.String(required=True, description="Tool ID"),
|
||||||
|
"status": fields.Boolean(
|
||||||
|
required=True, description="Status of the tool"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(description="Update the status of a tool")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = ["id", "status"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
try:
|
||||||
|
user_tools_collection.update_one(
|
||||||
|
{"_id": ObjectId(data["id"]), "user": user},
|
||||||
|
{"$set": {"status": data["status"]}},
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error updating tool status: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@tools_ns.route("/delete_tool")
|
||||||
|
class DeleteTool(Resource):
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"DeleteToolModel",
|
||||||
|
{"id": fields.String(required=True, description="Tool ID")},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@api.doc(description="Delete a tool by ID")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = ["id"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
try:
|
||||||
|
result = user_tools_collection.delete_one(
|
||||||
|
{"_id": ObjectId(data["id"]), "user": user}
|
||||||
|
)
|
||||||
|
if result.deleted_count == 0:
|
||||||
|
return {"success": False, "message": "Tool not found"}, 404
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
|
||||||
|
return {"success": False}, 400
|
||||||
|
return {"success": True}, 200
|
||||||
@@ -12,25 +12,26 @@ from application.core.logging_config import setup_logging
|
|||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
|
|
||||||
from application.api.answer.routes import answer # noqa: E402
|
from application.api import api # noqa: E402
|
||||||
|
from application.api.answer import answer # noqa: E402
|
||||||
from application.api.internal.routes import internal # noqa: E402
|
from application.api.internal.routes import internal # noqa: E402
|
||||||
from application.api.user.routes import user # noqa: E402
|
from application.api.user.routes import user # noqa: E402
|
||||||
|
from application.api.connector.routes import connector # noqa: E402
|
||||||
from application.celery_init import celery # noqa: E402
|
from application.celery_init import celery # noqa: E402
|
||||||
from application.core.settings import settings # noqa: E402
|
from application.core.settings import settings # noqa: E402
|
||||||
from application.extensions import api # noqa: E402
|
|
||||||
|
|
||||||
|
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
pathlib.PosixPath = pathlib.WindowsPath
|
pathlib.PosixPath = pathlib.WindowsPath
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.register_blueprint(user)
|
app.register_blueprint(user)
|
||||||
app.register_blueprint(answer)
|
app.register_blueprint(answer)
|
||||||
app.register_blueprint(internal)
|
app.register_blueprint(internal)
|
||||||
|
app.register_blueprint(connector)
|
||||||
app.config.update(
|
app.config.update(
|
||||||
UPLOAD_FOLDER="inputs",
|
UPLOAD_FOLDER="inputs",
|
||||||
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
|
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
|
||||||
@@ -52,7 +53,6 @@ if settings.AUTH_TYPE in ("simple_jwt", "session_jwt") and not settings.JWT_SECR
|
|||||||
settings.JWT_SECRET_KEY = new_key
|
settings.JWT_SECRET_KEY = new_key
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to setup JWT_SECRET_KEY: {e}")
|
raise RuntimeError(f"Failed to setup JWT_SECRET_KEY: {e}")
|
||||||
|
|
||||||
SIMPLE_JWT_TOKEN = None
|
SIMPLE_JWT_TOKEN = None
|
||||||
if settings.AUTH_TYPE == "simple_jwt":
|
if settings.AUTH_TYPE == "simple_jwt":
|
||||||
payload = {"sub": "local"}
|
payload = {"sub": "local"}
|
||||||
@@ -92,7 +92,6 @@ def generate_token():
|
|||||||
def authenticate_request():
|
def authenticate_request():
|
||||||
if request.method == "OPTIONS":
|
if request.method == "OPTIONS":
|
||||||
return "", 200
|
return "", 200
|
||||||
|
|
||||||
decoded_token = handle_auth(request)
|
decoded_token = handle_auth(request)
|
||||||
if not decoded_token:
|
if not decoded_token:
|
||||||
request.decoded_token = None
|
request.decoded_token = None
|
||||||
|
|||||||
@@ -10,31 +10,61 @@ current_dir = os.path.dirname(
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
AUTH_TYPE: Optional[str] = None
|
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
|
||||||
LLM_NAME: str = "docsgpt"
|
LLM_PROVIDER: str = "docsgpt"
|
||||||
MODEL_NAME: Optional[str] = (
|
LLM_NAME: Optional[str] = (
|
||||||
None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
|
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
|
||||||
)
|
)
|
||||||
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
|
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
|
||||||
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
||||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||||
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
|
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
|
||||||
MONGO_DB_NAME: str = "docsgpt"
|
MONGO_DB_NAME: str = "docsgpt"
|
||||||
MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
|
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
|
||||||
DEFAULT_MAX_HISTORY: int = 150
|
DEFAULT_MAX_HISTORY: int = 150
|
||||||
MODEL_TOKEN_LIMITS: dict = {
|
LLM_TOKEN_LIMITS: dict = {
|
||||||
|
"gpt-4o": 128000,
|
||||||
"gpt-4o-mini": 128000,
|
"gpt-4o-mini": 128000,
|
||||||
|
"gpt-4": 8192,
|
||||||
"gpt-3.5-turbo": 4096,
|
"gpt-3.5-turbo": 4096,
|
||||||
"claude-2": 1e5,
|
"claude-2": int(1e5),
|
||||||
"gemini-2.0-flash-exp": 1e6,
|
"gemini-2.5-flash": int(1e6),
|
||||||
|
}
|
||||||
|
DEFAULT_LLM_TOKEN_LIMIT: int = 128000
|
||||||
|
RESERVED_TOKENS: dict = {
|
||||||
|
"system_prompt": 500,
|
||||||
|
"current_query": 500,
|
||||||
|
"safety_buffer": 1000,
|
||||||
|
}
|
||||||
|
DEFAULT_AGENT_LIMITS: dict = {
|
||||||
|
"token_limit": 50000,
|
||||||
|
"request_limit": 500,
|
||||||
}
|
}
|
||||||
UPLOAD_FOLDER: str = "inputs"
|
UPLOAD_FOLDER: str = "inputs"
|
||||||
PARSE_PDF_AS_IMAGE: bool = False
|
PARSE_PDF_AS_IMAGE: bool = False
|
||||||
|
PARSE_IMAGE_REMOTE: bool = False
|
||||||
VECTOR_STORE: str = (
|
VECTOR_STORE: str = (
|
||||||
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
|
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
|
||||||
)
|
)
|
||||||
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
|
RETRIEVERS_ENABLED: list = ["classic_rag"]
|
||||||
AGENT_NAME: str = "classic"
|
AGENT_NAME: str = "classic"
|
||||||
|
FALLBACK_LLM_PROVIDER: Optional[str] = None # provider for fallback llm
|
||||||
|
FALLBACK_LLM_NAME: Optional[str] = None # model name for fallback llm
|
||||||
|
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
|
||||||
|
|
||||||
|
# Google Drive integration
|
||||||
|
GOOGLE_CLIENT_ID: Optional[str] = (
|
||||||
|
None # Replace with your actual Google OAuth client ID
|
||||||
|
)
|
||||||
|
GOOGLE_CLIENT_SECRET: Optional[str] = (
|
||||||
|
None # Replace with your actual Google OAuth client secret
|
||||||
|
)
|
||||||
|
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = (
|
||||||
|
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
|
||||||
|
)
|
||||||
|
|
||||||
|
# GitHub source
|
||||||
|
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||||
|
|
||||||
# LLM Cache
|
# LLM Cache
|
||||||
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
|
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
|
||||||
@@ -86,6 +116,8 @@ class Settings(BaseSettings):
|
|||||||
QDRANT_PATH: Optional[str] = None
|
QDRANT_PATH: Optional[str] = None
|
||||||
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
||||||
|
|
||||||
|
# PGVector vectorstore config
|
||||||
|
PGVECTOR_CONNECTION_STRING: Optional[str] = None
|
||||||
# Milvus vectorstore config
|
# Milvus vectorstore config
|
||||||
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
|
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
|
||||||
MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default
|
MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default
|
||||||
@@ -96,14 +128,21 @@ class Settings(BaseSettings):
|
|||||||
LANCEDB_TABLE_NAME: Optional[str] = (
|
LANCEDB_TABLE_NAME: Optional[str] = (
|
||||||
"docsgpts" # Name of the table to use for storing vectors
|
"docsgpts" # Name of the table to use for storing vectors
|
||||||
)
|
)
|
||||||
BRAVE_SEARCH_API_KEY: Optional[str] = None
|
|
||||||
|
|
||||||
FLASK_DEBUG_MODE: bool = False
|
FLASK_DEBUG_MODE: bool = False
|
||||||
STORAGE_TYPE: str = "local" # local or s3
|
STORAGE_TYPE: str = "local" # local or s3
|
||||||
|
URL_STRATEGY: str = "backend" # backend or s3
|
||||||
|
|
||||||
JWT_SECRET_KEY: str = ""
|
JWT_SECRET_KEY: str = ""
|
||||||
|
|
||||||
|
# Encryption settings
|
||||||
|
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
|
||||||
|
|
||||||
|
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
|
||||||
|
ELEVENLABS_API_KEY: Optional[str] = None
|
||||||
|
|
||||||
|
# Tool pre-fetch settings
|
||||||
|
ENABLE_TOOL_PREFETCH: bool = True
|
||||||
|
|
||||||
path = Path(__file__).parent.parent.absolute()
|
path = Path(__file__).parent.parent.absolute()
|
||||||
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
|
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
from flask_restx import Api
|
|
||||||
|
|
||||||
api = Api(
|
|
||||||
version="1.0",
|
|
||||||
title="DocsGPT API",
|
|
||||||
description="API for DocsGPT",
|
|
||||||
)
|
|
||||||
@@ -46,5 +46,9 @@ class AnthropicLLM(BaseLLM):
|
|||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
for completion in stream_response:
|
try:
|
||||||
yield completion.completion
|
for completion in stream_response:
|
||||||
|
yield completion.completion
|
||||||
|
finally:
|
||||||
|
if hasattr(stream_response, 'close'):
|
||||||
|
stream_response.close()
|
||||||
|
|||||||
@@ -1,53 +1,123 @@
|
|||||||
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from application.cache import gen_cache, stream_cache
|
from application.cache import gen_cache, stream_cache
|
||||||
|
|
||||||
|
from application.core.settings import settings
|
||||||
from application.usage import gen_token_usage, stream_token_usage
|
from application.usage import gen_token_usage, stream_token_usage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BaseLLM(ABC):
|
class BaseLLM(ABC):
|
||||||
def __init__(self, decoded_token=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
decoded_token=None,
|
||||||
|
):
|
||||||
self.decoded_token = decoded_token
|
self.decoded_token = decoded_token
|
||||||
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||||
|
self.fallback_provider = settings.FALLBACK_LLM_PROVIDER
|
||||||
|
self.fallback_model_name = settings.FALLBACK_LLM_NAME
|
||||||
|
self.fallback_llm_api_key = settings.FALLBACK_LLM_API_KEY
|
||||||
|
self._fallback_llm = None
|
||||||
|
|
||||||
def _apply_decorator(self, method, decorators, *args, **kwargs):
|
@property
|
||||||
for decorator in decorators:
|
def fallback_llm(self):
|
||||||
method = decorator(method)
|
"""Lazy-loaded fallback LLM instance."""
|
||||||
return method(self, *args, **kwargs)
|
if (
|
||||||
|
self._fallback_llm is None
|
||||||
|
and self.fallback_provider
|
||||||
|
and self.fallback_model_name
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
from application.llm.llm_creator import LLMCreator
|
||||||
|
|
||||||
|
self._fallback_llm = LLMCreator.create_llm(
|
||||||
|
self.fallback_provider,
|
||||||
|
self.fallback_llm_api_key,
|
||||||
|
None,
|
||||||
|
self.decoded_token,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to initialize fallback LLM: {str(e)}", exc_info=True
|
||||||
|
)
|
||||||
|
return self._fallback_llm
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _remove_null_values(args_dict):
|
||||||
|
if not isinstance(args_dict, dict):
|
||||||
|
return args_dict
|
||||||
|
return {k: v for k, v in args_dict.items() if v is not None}
|
||||||
|
|
||||||
|
def _execute_with_fallback(
|
||||||
|
self, method_name: str, decorators: list, *args, **kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Unified method execution with fallback support.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method_name: Name of the raw method ('_raw_gen' or '_raw_gen_stream')
|
||||||
|
decorators: List of decorators to apply
|
||||||
|
*args: Positional arguments
|
||||||
|
**kwargs: Keyword arguments
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorated_method():
|
||||||
|
method = getattr(self, method_name)
|
||||||
|
for decorator in decorators:
|
||||||
|
method = decorator(method)
|
||||||
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return decorated_method()
|
||||||
|
except Exception as e:
|
||||||
|
if not self.fallback_llm:
|
||||||
|
logger.error(f"Primary LLM failed and no fallback available: {str(e)}")
|
||||||
|
raise
|
||||||
|
logger.warning(
|
||||||
|
f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
fallback_method = getattr(
|
||||||
|
self.fallback_llm, method_name.replace("_raw_", "")
|
||||||
|
)
|
||||||
|
return fallback_method(*args, **kwargs)
|
||||||
|
|
||||||
|
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
|
||||||
|
decorators = [gen_token_usage, gen_cache]
|
||||||
|
return self._execute_with_fallback(
|
||||||
|
"_raw_gen",
|
||||||
|
decorators,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
stream=stream,
|
||||||
|
tools=tools,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
|
||||||
|
decorators = [stream_cache, stream_token_usage]
|
||||||
|
return self._execute_with_fallback(
|
||||||
|
"_raw_gen_stream",
|
||||||
|
decorators,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
stream=stream,
|
||||||
|
tools=tools,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _raw_gen(self, model, messages, stream, tools, *args, **kwargs):
|
def _raw_gen(self, model, messages, stream, tools, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
|
|
||||||
decorators = [gen_token_usage, gen_cache]
|
|
||||||
return self._apply_decorator(
|
|
||||||
self._raw_gen,
|
|
||||||
decorators=decorators,
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
stream=stream,
|
|
||||||
tools=tools,
|
|
||||||
*args,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
|
def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
|
|
||||||
decorators = [stream_cache, stream_token_usage]
|
|
||||||
return self._apply_decorator(
|
|
||||||
self._raw_gen_stream,
|
|
||||||
decorators=decorators,
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
stream=stream,
|
|
||||||
tools=tools,
|
|
||||||
*args,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def supports_tools(self):
|
def supports_tools(self):
|
||||||
return hasattr(self, "_supports_tools") and callable(
|
return hasattr(self, "_supports_tools") and callable(
|
||||||
getattr(self, "_supports_tools")
|
getattr(self, "_supports_tools")
|
||||||
@@ -55,12 +125,26 @@ class BaseLLM(ABC):
|
|||||||
|
|
||||||
def _supports_tools(self):
|
def _supports_tools(self):
|
||||||
raise NotImplementedError("Subclass must implement _supports_tools method")
|
raise NotImplementedError("Subclass must implement _supports_tools method")
|
||||||
|
|
||||||
|
def supports_structured_output(self):
|
||||||
|
"""Check if the LLM supports structured output/JSON schema enforcement"""
|
||||||
|
return hasattr(self, "_supports_structured_output") and callable(
|
||||||
|
getattr(self, "_supports_structured_output")
|
||||||
|
)
|
||||||
|
|
||||||
|
def _supports_structured_output(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def prepare_structured_output_format(self, json_schema):
|
||||||
|
"""Prepare structured output format specific to the LLM provider"""
|
||||||
|
_ = json_schema
|
||||||
|
return None
|
||||||
|
|
||||||
def get_supported_attachment_types(self):
|
def get_supported_attachment_types(self):
|
||||||
"""
|
"""
|
||||||
Return a list of MIME types supported by this LLM for file uploads.
|
Return a list of MIME types supported by this LLM for file uploads.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: List of supported MIME types
|
list: List of supported MIME types
|
||||||
"""
|
"""
|
||||||
return [] # Default: no attachments supported
|
return []
|
||||||
|
|||||||
@@ -33,14 +33,15 @@ class DocsGPTAPILLM(BaseLLM):
|
|||||||
{"role": role, "content": item["text"]}
|
{"role": role, "content": item["text"]}
|
||||||
)
|
)
|
||||||
elif "function_call" in item:
|
elif "function_call" in item:
|
||||||
|
cleaned_args = self._remove_null_values(
|
||||||
|
item["function_call"]["args"]
|
||||||
|
)
|
||||||
tool_call = {
|
tool_call = {
|
||||||
"id": item["function_call"]["call_id"],
|
"id": item["function_call"]["call_id"],
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": item["function_call"]["name"],
|
"name": item["function_call"]["name"],
|
||||||
"arguments": json.dumps(
|
"arguments": json.dumps(cleaned_args),
|
||||||
item["function_call"]["args"]
|
|
||||||
),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cleaned_messages.append(
|
cleaned_messages.append(
|
||||||
@@ -121,11 +122,19 @@ class DocsGPTAPILLM(BaseLLM):
|
|||||||
model="docsgpt", messages=messages, stream=stream, **kwargs
|
model="docsgpt", messages=messages, stream=stream, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
for line in response:
|
try:
|
||||||
if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
|
for line in response:
|
||||||
yield line.choices[0].delta.content
|
if (
|
||||||
elif len(line.choices) > 0:
|
len(line.choices) > 0
|
||||||
yield line.choices[0]
|
and line.choices[0].delta.content is not None
|
||||||
|
and len(line.choices[0].delta.content) > 0
|
||||||
|
):
|
||||||
|
yield line.choices[0].delta.content
|
||||||
|
elif len(line.choices) > 0:
|
||||||
|
yield line.choices[0]
|
||||||
|
finally:
|
||||||
|
if hasattr(response, 'close'):
|
||||||
|
response.close()
|
||||||
|
|
||||||
def _supports_tools(self):
|
def _supports_tools(self):
|
||||||
return True
|
return True
|
||||||
@@ -1,11 +1,13 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
import logging
|
|
||||||
import json
|
from application.core.settings import settings
|
||||||
|
|
||||||
from application.llm.base import BaseLLM
|
from application.llm.base import BaseLLM
|
||||||
from application.storage.storage_creator import StorageCreator
|
from application.storage.storage_creator import StorageCreator
|
||||||
from application.core.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleLLM(BaseLLM):
|
class GoogleLLM(BaseLLM):
|
||||||
@@ -24,12 +26,12 @@ class GoogleLLM(BaseLLM):
|
|||||||
list: List of supported MIME types
|
list: List of supported MIME types
|
||||||
"""
|
"""
|
||||||
return [
|
return [
|
||||||
'application/pdf',
|
"application/pdf",
|
||||||
'image/png',
|
"image/png",
|
||||||
'image/jpeg',
|
"image/jpeg",
|
||||||
'image/jpg',
|
"image/jpg",
|
||||||
'image/webp',
|
"image/webp",
|
||||||
'image/gif'
|
"image/gif",
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare_messages_with_attachments(self, messages, attachments=None):
|
def prepare_messages_with_attachments(self, messages, attachments=None):
|
||||||
@@ -70,26 +72,30 @@ class GoogleLLM(BaseLLM):
|
|||||||
|
|
||||||
files = []
|
files = []
|
||||||
for attachment in attachments:
|
for attachment in attachments:
|
||||||
mime_type = attachment.get('mime_type')
|
mime_type = attachment.get("mime_type")
|
||||||
|
|
||||||
if mime_type in self.get_supported_attachment_types():
|
if mime_type in self.get_supported_attachment_types():
|
||||||
try:
|
try:
|
||||||
file_uri = self._upload_file_to_google(attachment)
|
file_uri = self._upload_file_to_google(attachment)
|
||||||
logging.info(f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}")
|
logging.info(
|
||||||
|
f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}"
|
||||||
|
)
|
||||||
files.append({"file_uri": file_uri, "mime_type": mime_type})
|
files.append({"file_uri": file_uri, "mime_type": mime_type})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"GoogleLLM: Error uploading file: {e}", exc_info=True)
|
logging.error(
|
||||||
if 'content' in attachment:
|
f"GoogleLLM: Error uploading file: {e}", exc_info=True
|
||||||
prepared_messages[user_message_index]["content"].append({
|
)
|
||||||
"type": "text",
|
if "content" in attachment:
|
||||||
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]"
|
prepared_messages[user_message_index]["content"].append(
|
||||||
})
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if files:
|
if files:
|
||||||
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
|
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
|
||||||
prepared_messages[user_message_index]["content"].append({
|
prepared_messages[user_message_index]["content"].append({"files": files})
|
||||||
"files": files
|
|
||||||
})
|
|
||||||
|
|
||||||
return prepared_messages
|
return prepared_messages
|
||||||
|
|
||||||
@@ -103,10 +109,10 @@ class GoogleLLM(BaseLLM):
|
|||||||
Returns:
|
Returns:
|
||||||
str: Google AI file URI for the uploaded file.
|
str: Google AI file URI for the uploaded file.
|
||||||
"""
|
"""
|
||||||
if 'google_file_uri' in attachment:
|
if "google_file_uri" in attachment:
|
||||||
return attachment['google_file_uri']
|
return attachment["google_file_uri"]
|
||||||
|
|
||||||
file_path = attachment.get('path')
|
file_path = attachment.get("path")
|
||||||
if not file_path:
|
if not file_path:
|
||||||
raise ValueError("No file path provided in attachment")
|
raise ValueError("No file path provided in attachment")
|
||||||
|
|
||||||
@@ -116,17 +122,19 @@ class GoogleLLM(BaseLLM):
|
|||||||
try:
|
try:
|
||||||
file_uri = self.storage.process_file(
|
file_uri = self.storage.process_file(
|
||||||
file_path,
|
file_path,
|
||||||
lambda local_path, **kwargs: self.client.files.upload(file=local_path).uri
|
lambda local_path, **kwargs: self.client.files.upload(
|
||||||
|
file=local_path
|
||||||
|
).uri,
|
||||||
)
|
)
|
||||||
|
|
||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
|
|
||||||
mongo = MongoDB.get_client()
|
mongo = MongoDB.get_client()
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
attachments_collection = db["attachments"]
|
attachments_collection = db["attachments"]
|
||||||
if '_id' in attachment:
|
if "_id" in attachment:
|
||||||
attachments_collection.update_one(
|
attachments_collection.update_one(
|
||||||
{"_id": attachment['_id']},
|
{"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
|
||||||
{"$set": {"google_file_uri": file_uri}}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return file_uri
|
return file_uri
|
||||||
@@ -135,6 +143,7 @@ class GoogleLLM(BaseLLM):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def _clean_messages_google(self, messages):
|
def _clean_messages_google(self, messages):
|
||||||
|
"""Convert OpenAI format messages to Google AI format."""
|
||||||
cleaned_messages = []
|
cleaned_messages = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
role = message.get("role")
|
role = message.get("role")
|
||||||
@@ -142,6 +151,8 @@ class GoogleLLM(BaseLLM):
|
|||||||
|
|
||||||
if role == "assistant":
|
if role == "assistant":
|
||||||
role = "model"
|
role = "model"
|
||||||
|
elif role == "tool":
|
||||||
|
role = "model"
|
||||||
|
|
||||||
parts = []
|
parts = []
|
||||||
if role and content is not None:
|
if role and content is not None:
|
||||||
@@ -152,10 +163,14 @@ class GoogleLLM(BaseLLM):
|
|||||||
if "text" in item:
|
if "text" in item:
|
||||||
parts.append(types.Part.from_text(text=item["text"]))
|
parts.append(types.Part.from_text(text=item["text"]))
|
||||||
elif "function_call" in item:
|
elif "function_call" in item:
|
||||||
|
# Remove null values from args to avoid API errors
|
||||||
|
cleaned_args = self._remove_null_values(
|
||||||
|
item["function_call"]["args"]
|
||||||
|
)
|
||||||
parts.append(
|
parts.append(
|
||||||
types.Part.from_function_call(
|
types.Part.from_function_call(
|
||||||
name=item["function_call"]["name"],
|
name=item["function_call"]["name"],
|
||||||
args=item["function_call"]["args"],
|
args=cleaned_args,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif "function_response" in item:
|
elif "function_response" in item:
|
||||||
@@ -166,13 +181,13 @@ class GoogleLLM(BaseLLM):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif "files" in item:
|
elif "files" in item:
|
||||||
for file_data in item["files"]:
|
for file_data in item["files"]:
|
||||||
parts.append(
|
parts.append(
|
||||||
types.Part.from_uri(
|
types.Part.from_uri(
|
||||||
file_uri=file_data["file_uri"],
|
file_uri=file_data["file_uri"],
|
||||||
mime_type=file_data["mime_type"]
|
mime_type=file_data["mime_type"],
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unexpected content dictionary format:{item}"
|
f"Unexpected content dictionary format:{item}"
|
||||||
@@ -180,11 +195,63 @@ class GoogleLLM(BaseLLM):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||||
|
|
||||||
cleaned_messages.append(types.Content(role=role, parts=parts))
|
if parts:
|
||||||
|
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||||
|
|
||||||
return cleaned_messages
|
return cleaned_messages
|
||||||
|
|
||||||
|
def _clean_schema(self, schema_obj):
|
||||||
|
"""
|
||||||
|
Recursively remove unsupported fields from schema objects
|
||||||
|
and validate required properties.
|
||||||
|
"""
|
||||||
|
if not isinstance(schema_obj, dict):
|
||||||
|
return schema_obj
|
||||||
|
allowed_fields = {
|
||||||
|
"type",
|
||||||
|
"description",
|
||||||
|
"items",
|
||||||
|
"properties",
|
||||||
|
"required",
|
||||||
|
"enum",
|
||||||
|
"pattern",
|
||||||
|
"minimum",
|
||||||
|
"maximum",
|
||||||
|
"nullable",
|
||||||
|
"default",
|
||||||
|
}
|
||||||
|
|
||||||
|
cleaned = {}
|
||||||
|
for key, value in schema_obj.items():
|
||||||
|
if key not in allowed_fields:
|
||||||
|
continue
|
||||||
|
elif key == "type" and isinstance(value, str):
|
||||||
|
cleaned[key] = value.upper()
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
cleaned[key] = self._clean_schema(value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
cleaned[key] = [self._clean_schema(item) for item in value]
|
||||||
|
else:
|
||||||
|
cleaned[key] = value
|
||||||
|
|
||||||
|
# Validate that required properties actually exist in properties
|
||||||
|
if "required" in cleaned and "properties" in cleaned:
|
||||||
|
valid_required = []
|
||||||
|
properties_keys = set(cleaned["properties"].keys())
|
||||||
|
for required_prop in cleaned["required"]:
|
||||||
|
if required_prop in properties_keys:
|
||||||
|
valid_required.append(required_prop)
|
||||||
|
if valid_required:
|
||||||
|
cleaned["required"] = valid_required
|
||||||
|
else:
|
||||||
|
cleaned.pop("required", None)
|
||||||
|
elif "required" in cleaned and "properties" not in cleaned:
|
||||||
|
cleaned.pop("required", None)
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
||||||
def _clean_tools_format(self, tools_list):
|
def _clean_tools_format(self, tools_list):
|
||||||
|
"""Convert OpenAI format tools to Google AI format."""
|
||||||
genai_tools = []
|
genai_tools = []
|
||||||
for tool_data in tools_list:
|
for tool_data in tools_list:
|
||||||
if tool_data["type"] == "function":
|
if tool_data["type"] == "function":
|
||||||
@@ -193,18 +260,16 @@ class GoogleLLM(BaseLLM):
|
|||||||
properties = parameters.get("properties", {})
|
properties = parameters.get("properties", {})
|
||||||
|
|
||||||
if properties:
|
if properties:
|
||||||
|
cleaned_properties = {}
|
||||||
|
for k, v in properties.items():
|
||||||
|
cleaned_properties[k] = self._clean_schema(v)
|
||||||
|
|
||||||
genai_function = dict(
|
genai_function = dict(
|
||||||
name=function["name"],
|
name=function["name"],
|
||||||
description=function["description"],
|
description=function["description"],
|
||||||
parameters={
|
parameters={
|
||||||
"type": "OBJECT",
|
"type": "OBJECT",
|
||||||
"properties": {
|
"properties": cleaned_properties,
|
||||||
k: {
|
|
||||||
**v,
|
|
||||||
"type": v["type"].upper() if v["type"] else None,
|
|
||||||
}
|
|
||||||
for k, v in properties.items()
|
|
||||||
},
|
|
||||||
"required": (
|
"required": (
|
||||||
parameters["required"]
|
parameters["required"]
|
||||||
if "required" in parameters
|
if "required" in parameters
|
||||||
@@ -231,8 +296,10 @@ class GoogleLLM(BaseLLM):
|
|||||||
stream=False,
|
stream=False,
|
||||||
tools=None,
|
tools=None,
|
||||||
formatting="openai",
|
formatting="openai",
|
||||||
|
response_schema=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""Generate content using Google AI API without streaming."""
|
||||||
client = genai.Client(api_key=self.api_key)
|
client = genai.Client(api_key=self.api_key)
|
||||||
if formatting == "openai":
|
if formatting == "openai":
|
||||||
messages = self._clean_messages_google(messages)
|
messages = self._clean_messages_google(messages)
|
||||||
@@ -244,16 +311,21 @@ class GoogleLLM(BaseLLM):
|
|||||||
if tools:
|
if tools:
|
||||||
cleaned_tools = self._clean_tools_format(tools)
|
cleaned_tools = self._clean_tools_format(tools)
|
||||||
config.tools = cleaned_tools
|
config.tools = cleaned_tools
|
||||||
response = client.models.generate_content(
|
|
||||||
model=model,
|
# Add response schema for structured output if provided
|
||||||
contents=messages,
|
if response_schema:
|
||||||
config=config,
|
config.response_schema = response_schema
|
||||||
)
|
config.response_mime_type = "application/json"
|
||||||
|
|
||||||
|
response = client.models.generate_content(
|
||||||
|
model=model,
|
||||||
|
contents=messages,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tools:
|
||||||
return response
|
return response
|
||||||
else:
|
else:
|
||||||
response = client.models.generate_content(
|
|
||||||
model=model, contents=messages, config=config
|
|
||||||
)
|
|
||||||
return response.text
|
return response.text
|
||||||
|
|
||||||
def _raw_gen_stream(
|
def _raw_gen_stream(
|
||||||
@@ -264,8 +336,10 @@ class GoogleLLM(BaseLLM):
|
|||||||
stream=True,
|
stream=True,
|
||||||
tools=None,
|
tools=None,
|
||||||
formatting="openai",
|
formatting="openai",
|
||||||
|
response_schema=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""Generate content using Google AI API with streaming."""
|
||||||
client = genai.Client(api_key=self.api_key)
|
client = genai.Client(api_key=self.api_key)
|
||||||
if formatting == "openai":
|
if formatting == "openai":
|
||||||
messages = self._clean_messages_google(messages)
|
messages = self._clean_messages_google(messages)
|
||||||
@@ -278,17 +352,24 @@ class GoogleLLM(BaseLLM):
|
|||||||
cleaned_tools = self._clean_tools_format(tools)
|
cleaned_tools = self._clean_tools_format(tools)
|
||||||
config.tools = cleaned_tools
|
config.tools = cleaned_tools
|
||||||
|
|
||||||
|
# Add response schema for structured output if provided
|
||||||
|
if response_schema:
|
||||||
|
config.response_schema = response_schema
|
||||||
|
config.response_mime_type = "application/json"
|
||||||
|
|
||||||
# Check if we have both tools and file attachments
|
# Check if we have both tools and file attachments
|
||||||
has_attachments = False
|
has_attachments = False
|
||||||
for message in messages:
|
for message in messages:
|
||||||
for part in message.parts:
|
for part in message.parts:
|
||||||
if hasattr(part, 'file_data') and part.file_data is not None:
|
if hasattr(part, "file_data") and part.file_data is not None:
|
||||||
has_attachments = True
|
has_attachments = True
|
||||||
break
|
break
|
||||||
if has_attachments:
|
if has_attachments:
|
||||||
break
|
break
|
||||||
|
|
||||||
logging.info(f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}")
|
logging.info(
|
||||||
|
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
|
||||||
|
)
|
||||||
|
|
||||||
response = client.models.generate_content_stream(
|
response = client.models.generate_content_stream(
|
||||||
model=model,
|
model=model,
|
||||||
@@ -296,18 +377,96 @@ class GoogleLLM(BaseLLM):
|
|||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if hasattr(chunk, "candidates") and chunk.candidates:
|
if hasattr(chunk, "candidates") and chunk.candidates:
|
||||||
for candidate in chunk.candidates:
|
for candidate in chunk.candidates:
|
||||||
if candidate.content and candidate.content.parts:
|
if candidate.content and candidate.content.parts:
|
||||||
for part in candidate.content.parts:
|
for part in candidate.content.parts:
|
||||||
if part.function_call:
|
if part.function_call:
|
||||||
yield part
|
yield part
|
||||||
elif part.text:
|
elif part.text:
|
||||||
yield part.text
|
yield part.text
|
||||||
elif hasattr(chunk, "text"):
|
elif hasattr(chunk, "text"):
|
||||||
yield chunk.text
|
yield chunk.text
|
||||||
|
finally:
|
||||||
|
if hasattr(response, "close"):
|
||||||
|
response.close()
|
||||||
|
|
||||||
def _supports_tools(self):
|
def _supports_tools(self):
|
||||||
|
"""Return whether this LLM supports function calling."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _supports_structured_output(self):
|
||||||
|
"""Return whether this LLM supports structured JSON output."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def prepare_structured_output_format(self, json_schema):
|
||||||
|
"""Convert JSON schema to Google AI structured output format."""
|
||||||
|
if not json_schema:
|
||||||
|
return None
|
||||||
|
|
||||||
|
type_map = {
|
||||||
|
"object": "OBJECT",
|
||||||
|
"array": "ARRAY",
|
||||||
|
"string": "STRING",
|
||||||
|
"integer": "INTEGER",
|
||||||
|
"number": "NUMBER",
|
||||||
|
"boolean": "BOOLEAN",
|
||||||
|
}
|
||||||
|
|
||||||
|
def convert(schema):
|
||||||
|
if not isinstance(schema, dict):
|
||||||
|
return schema
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
schema_type = schema.get("type")
|
||||||
|
if schema_type:
|
||||||
|
result["type"] = type_map.get(schema_type.lower(), schema_type.upper())
|
||||||
|
|
||||||
|
for key in [
|
||||||
|
"description",
|
||||||
|
"nullable",
|
||||||
|
"enum",
|
||||||
|
"minItems",
|
||||||
|
"maxItems",
|
||||||
|
"required",
|
||||||
|
"propertyOrdering",
|
||||||
|
]:
|
||||||
|
if key in schema:
|
||||||
|
result[key] = schema[key]
|
||||||
|
|
||||||
|
if "format" in schema:
|
||||||
|
format_value = schema["format"]
|
||||||
|
if schema_type == "string":
|
||||||
|
if format_value == "date":
|
||||||
|
result["format"] = "date-time"
|
||||||
|
elif format_value in ["enum", "date-time"]:
|
||||||
|
result["format"] = format_value
|
||||||
|
else:
|
||||||
|
result["format"] = format_value
|
||||||
|
|
||||||
|
if "properties" in schema:
|
||||||
|
result["properties"] = {
|
||||||
|
k: convert(v) for k, v in schema["properties"].items()
|
||||||
|
}
|
||||||
|
if "propertyOrdering" not in result and result.get("type") == "OBJECT":
|
||||||
|
result["propertyOrdering"] = list(result["properties"].keys())
|
||||||
|
|
||||||
|
if "items" in schema:
|
||||||
|
result["items"] = convert(schema["items"])
|
||||||
|
|
||||||
|
for field in ["anyOf", "oneOf", "allOf"]:
|
||||||
|
if field in schema:
|
||||||
|
result[field] = [convert(s) for s in schema[field]]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
try:
|
||||||
|
return convert(json_schema)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(
|
||||||
|
f"Error preparing structured output format for Google: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|||||||
0
application/llm/handlers/__init__.py
Normal file
0
application/llm/handlers/__init__.py
Normal file
351
application/llm/handlers/base.py
Normal file
351
application/llm/handlers/base.py
Normal file
@@ -0,0 +1,351 @@
|
|||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, Generator, List, Optional, Union
|
||||||
|
|
||||||
|
from application.logging import build_stack_data
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolCall:
|
||||||
|
"""Represents a tool/function call from the LLM."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
arguments: Union[str, Dict]
|
||||||
|
index: Optional[int] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict) -> "ToolCall":
|
||||||
|
"""Create ToolCall from dictionary."""
|
||||||
|
return cls(
|
||||||
|
id=data.get("id", ""),
|
||||||
|
name=data.get("name", ""),
|
||||||
|
arguments=data.get("arguments", {}),
|
||||||
|
index=data.get("index"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMResponse:
|
||||||
|
"""Represents a response from the LLM."""
|
||||||
|
|
||||||
|
content: str
|
||||||
|
tool_calls: List[ToolCall]
|
||||||
|
finish_reason: str
|
||||||
|
raw_response: Any
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_tool_call(self) -> bool:
|
||||||
|
"""Check if the response requires tool calls."""
|
||||||
|
return bool(self.tool_calls) and self.finish_reason == "tool_calls"
|
||||||
|
|
||||||
|
|
||||||
|
class LLMHandler(ABC):
|
||||||
|
"""Abstract base class for LLM handlers."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.llm_calls = []
|
||||||
|
self.tool_calls = []
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse_response(self, response: Any) -> LLMResponse:
|
||||||
|
"""Parse raw LLM response into standardized format."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||||
|
"""Create a tool result message for the conversation history."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _iterate_stream(self, response: Any) -> Generator:
|
||||||
|
"""Iterate through streaming response chunks."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def process_message_flow(
|
||||||
|
self,
|
||||||
|
agent,
|
||||||
|
initial_response,
|
||||||
|
tools_dict: Dict,
|
||||||
|
messages: List[Dict],
|
||||||
|
attachments: Optional[List] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
) -> Union[str, Generator]:
|
||||||
|
"""
|
||||||
|
Main orchestration method for processing LLM message flow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance
|
||||||
|
initial_response: Initial LLM response
|
||||||
|
tools_dict: Dictionary of available tools
|
||||||
|
messages: Conversation history
|
||||||
|
attachments: Optional attachments
|
||||||
|
stream: Whether to use streaming
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Final response or generator for streaming
|
||||||
|
"""
|
||||||
|
messages = self.prepare_messages(agent, messages, attachments)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self.handle_streaming(agent, initial_response, tools_dict, messages)
|
||||||
|
else:
|
||||||
|
return self.handle_non_streaming(
|
||||||
|
agent, initial_response, tools_dict, messages
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_messages(
|
||||||
|
self, agent, messages: List[Dict], attachments: Optional[List] = None
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Prepare messages with attachments and provider-specific formatting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance
|
||||||
|
messages: Original messages
|
||||||
|
attachments: List of attachments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Prepared messages list
|
||||||
|
"""
|
||||||
|
if not attachments:
|
||||||
|
return messages
|
||||||
|
logger.info(f"Preparing messages with {len(attachments)} attachments")
|
||||||
|
supported_types = agent.llm.get_supported_attachment_types()
|
||||||
|
|
||||||
|
supported_attachments = [
|
||||||
|
a for a in attachments if a.get("mime_type") in supported_types
|
||||||
|
]
|
||||||
|
unsupported_attachments = [
|
||||||
|
a for a in attachments if a.get("mime_type") not in supported_types
|
||||||
|
]
|
||||||
|
|
||||||
|
# Process supported attachments with the LLM's custom method
|
||||||
|
|
||||||
|
if supported_attachments:
|
||||||
|
logger.info(
|
||||||
|
f"Processing {len(supported_attachments)} supported attachments"
|
||||||
|
)
|
||||||
|
messages = agent.llm.prepare_messages_with_attachments(
|
||||||
|
messages, supported_attachments
|
||||||
|
)
|
||||||
|
# Process unsupported attachments with default method
|
||||||
|
|
||||||
|
if unsupported_attachments:
|
||||||
|
logger.info(
|
||||||
|
f"Processing {len(unsupported_attachments)} unsupported attachments"
|
||||||
|
)
|
||||||
|
messages = self._append_unsupported_attachments(
|
||||||
|
messages, unsupported_attachments
|
||||||
|
)
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def _append_unsupported_attachments(
|
||||||
|
self, messages: List[Dict], attachments: List[Dict]
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Default method to append unsupported attachment content to system prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Current messages
|
||||||
|
attachments: List of unsupported attachments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated messages list
|
||||||
|
"""
|
||||||
|
prepared_messages = messages.copy()
|
||||||
|
attachment_texts = []
|
||||||
|
|
||||||
|
for attachment in attachments:
|
||||||
|
logger.info(f"Adding attachment {attachment.get('id')} to context")
|
||||||
|
if "content" in attachment:
|
||||||
|
attachment_texts.append(
|
||||||
|
f"Attached file content:\n\n{attachment['content']}"
|
||||||
|
)
|
||||||
|
if attachment_texts:
|
||||||
|
combined_text = "\n\n".join(attachment_texts)
|
||||||
|
|
||||||
|
system_msg = next(
|
||||||
|
(msg for msg in prepared_messages if msg.get("role") == "system"),
|
||||||
|
{"role": "system", "content": ""},
|
||||||
|
)
|
||||||
|
|
||||||
|
if system_msg not in prepared_messages:
|
||||||
|
prepared_messages.insert(0, system_msg)
|
||||||
|
system_msg["content"] += f"\n\n{combined_text}"
|
||||||
|
return prepared_messages
|
||||||
|
|
||||||
|
def handle_tool_calls(
|
||||||
|
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
|
||||||
|
) -> Generator:
|
||||||
|
"""
|
||||||
|
Execute tool calls and update conversation history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance
|
||||||
|
tool_calls: List of tool calls to execute
|
||||||
|
tools_dict: Available tools dictionary
|
||||||
|
messages: Current conversation history
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated messages list
|
||||||
|
"""
|
||||||
|
updated_messages = messages.copy()
|
||||||
|
|
||||||
|
for call in tool_calls:
|
||||||
|
try:
|
||||||
|
self.tool_calls.append(call)
|
||||||
|
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
yield next(tool_executor_gen)
|
||||||
|
except StopIteration as e:
|
||||||
|
tool_response, call_id = e.value
|
||||||
|
break
|
||||||
|
updated_messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"function_call": {
|
||||||
|
"name": call.name,
|
||||||
|
"args": call.arguments,
|
||||||
|
"call_id": call_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_messages.append(self.create_tool_message(call, tool_response))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
|
||||||
|
error_call = ToolCall(
|
||||||
|
id=call.id, name=call.name, arguments=call.arguments
|
||||||
|
)
|
||||||
|
error_response = f"Error executing tool: {str(e)}"
|
||||||
|
error_message = self.create_tool_message(error_call, error_response)
|
||||||
|
updated_messages.append(error_message)
|
||||||
|
|
||||||
|
call_parts = call.name.split("_")
|
||||||
|
if len(call_parts) >= 2:
|
||||||
|
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
|
||||||
|
action_name = "_".join(call_parts[:-1])
|
||||||
|
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
|
||||||
|
full_action_name = f"{action_name}_{tool_id}"
|
||||||
|
else:
|
||||||
|
tool_name = "unknown_tool"
|
||||||
|
action_name = call.name
|
||||||
|
full_action_name = call.name
|
||||||
|
yield {
|
||||||
|
"type": "tool_call",
|
||||||
|
"data": {
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"call_id": call.id,
|
||||||
|
"action_name": full_action_name,
|
||||||
|
"arguments": call.arguments,
|
||||||
|
"error": error_response,
|
||||||
|
"status": "error",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return updated_messages
|
||||||
|
|
||||||
|
def handle_non_streaming(
|
||||||
|
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
|
||||||
|
) -> Generator:
|
||||||
|
"""
|
||||||
|
Handle non-streaming response flow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance
|
||||||
|
response: Current LLM response
|
||||||
|
tools_dict: Available tools dictionary
|
||||||
|
messages: Conversation history
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Final response after processing all tool calls
|
||||||
|
"""
|
||||||
|
parsed = self.parse_response(response)
|
||||||
|
self.llm_calls.append(build_stack_data(agent.llm))
|
||||||
|
|
||||||
|
while parsed.requires_tool_call:
|
||||||
|
tool_handler_gen = self.handle_tool_calls(
|
||||||
|
agent, parsed.tool_calls, tools_dict, messages
|
||||||
|
)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
yield next(tool_handler_gen)
|
||||||
|
except StopIteration as e:
|
||||||
|
messages = e.value
|
||||||
|
break
|
||||||
|
response = agent.llm.gen(
|
||||||
|
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||||
|
)
|
||||||
|
parsed = self.parse_response(response)
|
||||||
|
self.llm_calls.append(build_stack_data(agent.llm))
|
||||||
|
return parsed.content
|
||||||
|
|
||||||
|
def handle_streaming(
|
||||||
|
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
|
||||||
|
) -> Generator:
|
||||||
|
"""
|
||||||
|
Handle streaming response flow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance
|
||||||
|
response: Current LLM response
|
||||||
|
tools_dict: Available tools dictionary
|
||||||
|
messages: Conversation history
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Streaming response chunks
|
||||||
|
"""
|
||||||
|
buffer = ""
|
||||||
|
tool_calls = {}
|
||||||
|
|
||||||
|
for chunk in self._iterate_stream(response):
|
||||||
|
if isinstance(chunk, str):
|
||||||
|
yield chunk
|
||||||
|
continue
|
||||||
|
parsed = self.parse_response(chunk)
|
||||||
|
|
||||||
|
if parsed.tool_calls:
|
||||||
|
for call in parsed.tool_calls:
|
||||||
|
if call.index not in tool_calls:
|
||||||
|
tool_calls[call.index] = call
|
||||||
|
else:
|
||||||
|
existing = tool_calls[call.index]
|
||||||
|
if call.id:
|
||||||
|
existing.id = call.id
|
||||||
|
if call.name:
|
||||||
|
existing.name = call.name
|
||||||
|
if call.arguments:
|
||||||
|
existing.arguments += call.arguments
|
||||||
|
if parsed.finish_reason == "tool_calls":
|
||||||
|
tool_handler_gen = self.handle_tool_calls(
|
||||||
|
agent, list(tool_calls.values()), tools_dict, messages
|
||||||
|
)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
yield next(tool_handler_gen)
|
||||||
|
except StopIteration as e:
|
||||||
|
messages = e.value
|
||||||
|
break
|
||||||
|
tool_calls = {}
|
||||||
|
|
||||||
|
response = agent.llm.gen_stream(
|
||||||
|
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||||
|
)
|
||||||
|
self.llm_calls.append(build_stack_data(agent.llm))
|
||||||
|
|
||||||
|
yield from self.handle_streaming(agent, response, tools_dict, messages)
|
||||||
|
return
|
||||||
|
if parsed.content:
|
||||||
|
buffer += parsed.content
|
||||||
|
yield buffer
|
||||||
|
buffer = ""
|
||||||
|
if parsed.finish_reason == "stop":
|
||||||
|
return
|
||||||
78
application/llm/handlers/google.py
Normal file
78
application/llm/handlers/google.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, Generator
|
||||||
|
|
||||||
|
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleLLMHandler(LLMHandler):
|
||||||
|
"""Handler for Google's GenAI API."""
|
||||||
|
|
||||||
|
def parse_response(self, response: Any) -> LLMResponse:
|
||||||
|
"""Parse Google response into standardized format."""
|
||||||
|
|
||||||
|
if isinstance(response, str):
|
||||||
|
return LLMResponse(
|
||||||
|
content=response,
|
||||||
|
tool_calls=[],
|
||||||
|
finish_reason="stop",
|
||||||
|
raw_response=response,
|
||||||
|
)
|
||||||
|
if hasattr(response, "candidates"):
|
||||||
|
parts = response.candidates[0].content.parts if response.candidates else []
|
||||||
|
tool_calls = [
|
||||||
|
ToolCall(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
name=part.function_call.name,
|
||||||
|
arguments=part.function_call.args,
|
||||||
|
)
|
||||||
|
for part in parts
|
||||||
|
if hasattr(part, "function_call") and part.function_call is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
content = " ".join(
|
||||||
|
part.text
|
||||||
|
for part in parts
|
||||||
|
if hasattr(part, "text") and part.text is not None
|
||||||
|
)
|
||||||
|
return LLMResponse(
|
||||||
|
content=content,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
finish_reason="tool_calls" if tool_calls else "stop",
|
||||||
|
raw_response=response,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tool_calls = []
|
||||||
|
if hasattr(response, "function_call"):
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
name=response.function_call.name,
|
||||||
|
arguments=response.function_call.args,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return LLMResponse(
|
||||||
|
content=response.text if hasattr(response, "text") else "",
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
finish_reason="tool_calls" if tool_calls else "stop",
|
||||||
|
raw_response=response,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||||
|
"""Create Google-style tool message."""
|
||||||
|
|
||||||
|
return {
|
||||||
|
"role": "model",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"function_response": {
|
||||||
|
"name": tool_call.name,
|
||||||
|
"response": {"result": result},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def _iterate_stream(self, response: Any) -> Generator:
|
||||||
|
"""Iterate through Google streaming response."""
|
||||||
|
for chunk in response:
|
||||||
|
yield chunk
|
||||||
18
application/llm/handlers/handler_creator.py
Normal file
18
application/llm/handlers/handler_creator.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from application.llm.handlers.base import LLMHandler
|
||||||
|
from application.llm.handlers.google import GoogleLLMHandler
|
||||||
|
from application.llm.handlers.openai import OpenAILLMHandler
|
||||||
|
|
||||||
|
|
||||||
|
class LLMHandlerCreator:
|
||||||
|
handlers = {
|
||||||
|
"openai": OpenAILLMHandler,
|
||||||
|
"google": GoogleLLMHandler,
|
||||||
|
"default": OpenAILLMHandler,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_handler(cls, llm_type: str, *args, **kwargs) -> LLMHandler:
|
||||||
|
handler_class = cls.handlers.get(llm_type.lower())
|
||||||
|
if not handler_class:
|
||||||
|
handler_class = OpenAILLMHandler
|
||||||
|
return handler_class(*args, **kwargs)
|
||||||
57
application/llm/handlers/openai.py
Normal file
57
application/llm/handlers/openai.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
from typing import Any, Dict, Generator
|
||||||
|
|
||||||
|
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAILLMHandler(LLMHandler):
|
||||||
|
"""Handler for OpenAI API."""
|
||||||
|
|
||||||
|
def parse_response(self, response: Any) -> LLMResponse:
|
||||||
|
"""Parse OpenAI response into standardized format."""
|
||||||
|
if isinstance(response, str):
|
||||||
|
return LLMResponse(
|
||||||
|
content=response,
|
||||||
|
tool_calls=[],
|
||||||
|
finish_reason="stop",
|
||||||
|
raw_response=response,
|
||||||
|
)
|
||||||
|
|
||||||
|
message = getattr(response, "message", None) or getattr(response, "delta", None)
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
if hasattr(message, "tool_calls"):
|
||||||
|
tool_calls = [
|
||||||
|
ToolCall(
|
||||||
|
id=getattr(tc, "id", ""),
|
||||||
|
name=getattr(tc.function, "name", ""),
|
||||||
|
arguments=getattr(tc.function, "arguments", ""),
|
||||||
|
index=getattr(tc, "index", None),
|
||||||
|
)
|
||||||
|
for tc in message.tool_calls or []
|
||||||
|
]
|
||||||
|
return LLMResponse(
|
||||||
|
content=getattr(message, "content", ""),
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
finish_reason=getattr(response, "finish_reason", ""),
|
||||||
|
raw_response=response,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||||
|
"""Create OpenAI-style tool message."""
|
||||||
|
return {
|
||||||
|
"role": "tool",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"function_response": {
|
||||||
|
"name": tool_call.name,
|
||||||
|
"response": {"result": result},
|
||||||
|
"call_id": tool_call.id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def _iterate_stream(self, response: Any) -> Generator:
|
||||||
|
"""Iterate through OpenAI streaming response."""
|
||||||
|
for chunk in response:
|
||||||
|
yield chunk
|
||||||
@@ -2,6 +2,7 @@ from application.llm.base import BaseLLM
|
|||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
|
||||||
class LlamaSingleton:
|
class LlamaSingleton:
|
||||||
_instances = {}
|
_instances = {}
|
||||||
_lock = threading.Lock() # Add a lock for thread synchronization
|
_lock = threading.Lock() # Add a lock for thread synchronization
|
||||||
@@ -29,7 +30,7 @@ class LlamaCpp(BaseLLM):
|
|||||||
self,
|
self,
|
||||||
api_key=None,
|
api_key=None,
|
||||||
user_api_key=None,
|
user_api_key=None,
|
||||||
llm_name=settings.MODEL_PATH,
|
llm_name=settings.LLM_PATH,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -42,14 +43,18 @@ class LlamaCpp(BaseLLM):
|
|||||||
context = messages[0]["content"]
|
context = messages[0]["content"]
|
||||||
user_question = messages[-1]["content"]
|
user_question = messages[-1]["content"]
|
||||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||||
result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False)
|
result = LlamaSingleton.query_model(
|
||||||
|
self.llama, prompt, max_tokens=150, echo=False
|
||||||
|
)
|
||||||
return result["choices"][0]["text"].split("### Answer \n")[-1]
|
return result["choices"][0]["text"].split("### Answer \n")[-1]
|
||||||
|
|
||||||
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
|
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
|
||||||
context = messages[0]["content"]
|
context = messages[0]["content"]
|
||||||
user_question = messages[-1]["content"]
|
user_question = messages[-1]["content"]
|
||||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||||
result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False, stream=stream)
|
result = LlamaSingleton.query_model(
|
||||||
|
self.llama, prompt, max_tokens=150, echo=False, stream=stream
|
||||||
|
)
|
||||||
for item in result:
|
for item in result:
|
||||||
for choice in item["choices"]:
|
for choice in item["choices"]:
|
||||||
yield choice["text"]
|
yield choice["text"]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import json
|
|
||||||
import base64
|
import base64
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
@@ -13,7 +13,10 @@ class OpenAILLM(BaseLLM):
|
|||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
if isinstance(settings.OPENAI_BASE_URL, str) and settings.OPENAI_BASE_URL.strip():
|
if (
|
||||||
|
isinstance(settings.OPENAI_BASE_URL, str)
|
||||||
|
and settings.OPENAI_BASE_URL.strip()
|
||||||
|
):
|
||||||
self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
|
self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
|
||||||
else:
|
else:
|
||||||
DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
|
DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
|
||||||
@@ -41,14 +44,15 @@ class OpenAILLM(BaseLLM):
|
|||||||
{"role": role, "content": item["text"]}
|
{"role": role, "content": item["text"]}
|
||||||
)
|
)
|
||||||
elif "function_call" in item:
|
elif "function_call" in item:
|
||||||
|
cleaned_args = self._remove_null_values(
|
||||||
|
item["function_call"]["args"]
|
||||||
|
)
|
||||||
tool_call = {
|
tool_call = {
|
||||||
"id": item["function_call"]["call_id"],
|
"id": item["function_call"]["call_id"],
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": item["function_call"]["name"],
|
"name": item["function_call"]["name"],
|
||||||
"arguments": json.dumps(
|
"arguments": json.dumps(cleaned_args),
|
||||||
item["function_call"]["args"]
|
|
||||||
),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cleaned_messages.append(
|
cleaned_messages.append(
|
||||||
@@ -73,14 +77,30 @@ class OpenAILLM(BaseLLM):
|
|||||||
elif isinstance(item, dict):
|
elif isinstance(item, dict):
|
||||||
content_parts = []
|
content_parts = []
|
||||||
if "text" in item:
|
if "text" in item:
|
||||||
content_parts.append({"type": "text", "text": item["text"]})
|
content_parts.append(
|
||||||
elif "type" in item and item["type"] == "text" and "text" in item:
|
{"type": "text", "text": item["text"]}
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
"type" in item
|
||||||
|
and item["type"] == "text"
|
||||||
|
and "text" in item
|
||||||
|
):
|
||||||
content_parts.append(item)
|
content_parts.append(item)
|
||||||
elif "type" in item and item["type"] == "file" and "file" in item:
|
elif (
|
||||||
|
"type" in item
|
||||||
|
and item["type"] == "file"
|
||||||
|
and "file" in item
|
||||||
|
):
|
||||||
content_parts.append(item)
|
content_parts.append(item)
|
||||||
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
|
elif (
|
||||||
|
"type" in item
|
||||||
|
and item["type"] == "image_url"
|
||||||
|
and "image_url" in item
|
||||||
|
):
|
||||||
content_parts.append(item)
|
content_parts.append(item)
|
||||||
cleaned_messages.append({"role": role, "content": content_parts})
|
cleaned_messages.append(
|
||||||
|
{"role": role, "content": content_parts}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unexpected content dictionary format: {item}"
|
f"Unexpected content dictionary format: {item}"
|
||||||
@@ -98,22 +118,29 @@ class OpenAILLM(BaseLLM):
|
|||||||
stream=False,
|
stream=False,
|
||||||
tools=None,
|
tools=None,
|
||||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||||
|
response_format=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
messages = self._clean_messages_openai(messages)
|
messages = self._clean_messages_openai(messages)
|
||||||
|
|
||||||
|
request_params = {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"stream": stream,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
request_params["tools"] = tools
|
||||||
|
|
||||||
|
if response_format:
|
||||||
|
request_params["response_format"] = response_format
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(**request_params)
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
response = self.client.chat.completions.create(
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
stream=stream,
|
|
||||||
tools=tools,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
return response.choices[0]
|
return response.choices[0]
|
||||||
else:
|
else:
|
||||||
response = self.client.chat.completions.create(
|
|
||||||
model=model, messages=messages, stream=stream, **kwargs
|
|
||||||
)
|
|
||||||
return response.choices[0].message.content
|
return response.choices[0].message.content
|
||||||
|
|
||||||
def _raw_gen_stream(
|
def _raw_gen_stream(
|
||||||
@@ -124,31 +151,103 @@ class OpenAILLM(BaseLLM):
|
|||||||
stream=True,
|
stream=True,
|
||||||
tools=None,
|
tools=None,
|
||||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||||
|
response_format=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
messages = self._clean_messages_openai(messages)
|
messages = self._clean_messages_openai(messages)
|
||||||
if tools:
|
|
||||||
response = self.client.chat.completions.create(
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
stream=stream,
|
|
||||||
tools=tools,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = self.client.chat.completions.create(
|
|
||||||
model=model, messages=messages, stream=stream, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
for line in response:
|
request_params = {
|
||||||
if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
|
"model": model,
|
||||||
yield line.choices[0].delta.content
|
"messages": messages,
|
||||||
elif len(line.choices) > 0:
|
"stream": stream,
|
||||||
yield line.choices[0]
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
request_params["tools"] = tools
|
||||||
|
|
||||||
|
if response_format:
|
||||||
|
request_params["response_format"] = response_format
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(**request_params)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for line in response:
|
||||||
|
if (
|
||||||
|
len(line.choices) > 0
|
||||||
|
and line.choices[0].delta.content is not None
|
||||||
|
and len(line.choices[0].delta.content) > 0
|
||||||
|
):
|
||||||
|
yield line.choices[0].delta.content
|
||||||
|
elif len(line.choices) > 0:
|
||||||
|
yield line.choices[0]
|
||||||
|
finally:
|
||||||
|
if hasattr(response, "close"):
|
||||||
|
response.close()
|
||||||
|
|
||||||
def _supports_tools(self):
|
def _supports_tools(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _supports_structured_output(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def prepare_structured_output_format(self, json_schema):
|
||||||
|
if not json_schema:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
def add_additional_properties_false(schema_obj):
|
||||||
|
if isinstance(schema_obj, dict):
|
||||||
|
schema_copy = schema_obj.copy()
|
||||||
|
|
||||||
|
if schema_copy.get("type") == "object":
|
||||||
|
schema_copy["additionalProperties"] = False
|
||||||
|
# Ensure 'required' includes all properties for OpenAI strict mode
|
||||||
|
if "properties" in schema_copy:
|
||||||
|
schema_copy["required"] = list(
|
||||||
|
schema_copy["properties"].keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
for key, value in schema_copy.items():
|
||||||
|
if key == "properties" and isinstance(value, dict):
|
||||||
|
schema_copy[key] = {
|
||||||
|
prop_name: add_additional_properties_false(prop_schema)
|
||||||
|
for prop_name, prop_schema in value.items()
|
||||||
|
}
|
||||||
|
elif key == "items" and isinstance(value, dict):
|
||||||
|
schema_copy[key] = add_additional_properties_false(value)
|
||||||
|
elif key in ["anyOf", "oneOf", "allOf"] and isinstance(
|
||||||
|
value, list
|
||||||
|
):
|
||||||
|
schema_copy[key] = [
|
||||||
|
add_additional_properties_false(sub_schema)
|
||||||
|
for sub_schema in value
|
||||||
|
]
|
||||||
|
|
||||||
|
return schema_copy
|
||||||
|
return schema_obj
|
||||||
|
|
||||||
|
processed_schema = add_additional_properties_false(json_schema)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"type": "json_schema",
|
||||||
|
"json_schema": {
|
||||||
|
"name": processed_schema.get("name", "response"),
|
||||||
|
"description": processed_schema.get(
|
||||||
|
"description", "Structured response"
|
||||||
|
),
|
||||||
|
"schema": processed_schema,
|
||||||
|
"strict": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error preparing structured output format: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def get_supported_attachment_types(self):
|
def get_supported_attachment_types(self):
|
||||||
"""
|
"""
|
||||||
Return a list of MIME types supported by OpenAI for file uploads.
|
Return a list of MIME types supported by OpenAI for file uploads.
|
||||||
@@ -157,12 +256,12 @@ class OpenAILLM(BaseLLM):
|
|||||||
list: List of supported MIME types
|
list: List of supported MIME types
|
||||||
"""
|
"""
|
||||||
return [
|
return [
|
||||||
'application/pdf',
|
"application/pdf",
|
||||||
'image/png',
|
"image/png",
|
||||||
'image/jpeg',
|
"image/jpeg",
|
||||||
'image/jpg',
|
"image/jpg",
|
||||||
'image/webp',
|
"image/webp",
|
||||||
'image/gif'
|
"image/gif",
|
||||||
]
|
]
|
||||||
|
|
||||||
def prepare_messages_with_attachments(self, messages, attachments=None):
|
def prepare_messages_with_attachments(self, messages, attachments=None):
|
||||||
@@ -202,39 +301,46 @@ class OpenAILLM(BaseLLM):
|
|||||||
prepared_messages[user_message_index]["content"] = []
|
prepared_messages[user_message_index]["content"] = []
|
||||||
|
|
||||||
for attachment in attachments:
|
for attachment in attachments:
|
||||||
mime_type = attachment.get('mime_type')
|
mime_type = attachment.get("mime_type")
|
||||||
|
|
||||||
if mime_type and mime_type.startswith('image/'):
|
if mime_type and mime_type.startswith("image/"):
|
||||||
try:
|
try:
|
||||||
base64_image = self._get_base64_image(attachment)
|
base64_image = self._get_base64_image(attachment)
|
||||||
prepared_messages[user_message_index]["content"].append({
|
prepared_messages[user_message_index]["content"].append(
|
||||||
"type": "image_url",
|
{
|
||||||
"image_url": {
|
"type": "image_url",
|
||||||
"url": f"data:{mime_type};base64,{base64_image}"
|
"image_url": {
|
||||||
|
"url": f"data:{mime_type};base64,{base64_image}"
|
||||||
|
},
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error processing image attachment: {e}", exc_info=True)
|
logging.error(
|
||||||
if 'content' in attachment:
|
f"Error processing image attachment: {e}", exc_info=True
|
||||||
prepared_messages[user_message_index]["content"].append({
|
)
|
||||||
"type": "text",
|
if "content" in attachment:
|
||||||
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]"
|
prepared_messages[user_message_index]["content"].append(
|
||||||
})
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]",
|
||||||
|
}
|
||||||
|
)
|
||||||
# Handle PDFs using the file API
|
# Handle PDFs using the file API
|
||||||
elif mime_type == 'application/pdf':
|
elif mime_type == "application/pdf":
|
||||||
try:
|
try:
|
||||||
file_id = self._upload_file_to_openai(attachment)
|
file_id = self._upload_file_to_openai(attachment)
|
||||||
prepared_messages[user_message_index]["content"].append({
|
prepared_messages[user_message_index]["content"].append(
|
||||||
"type": "file",
|
{"type": "file", "file": {"file_id": file_id}}
|
||||||
"file": {"file_id": file_id}
|
)
|
||||||
})
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error uploading PDF to OpenAI: {e}", exc_info=True)
|
logging.error(f"Error uploading PDF to OpenAI: {e}", exc_info=True)
|
||||||
if 'content' in attachment:
|
if "content" in attachment:
|
||||||
prepared_messages[user_message_index]["content"].append({
|
prepared_messages[user_message_index]["content"].append(
|
||||||
"type": "text",
|
{
|
||||||
"text": f"File content:\n\n{attachment['content']}"
|
"type": "text",
|
||||||
})
|
"text": f"File content:\n\n{attachment['content']}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return prepared_messages
|
return prepared_messages
|
||||||
|
|
||||||
@@ -248,13 +354,13 @@ class OpenAILLM(BaseLLM):
|
|||||||
Returns:
|
Returns:
|
||||||
str: Base64-encoded image data.
|
str: Base64-encoded image data.
|
||||||
"""
|
"""
|
||||||
file_path = attachment.get('path')
|
file_path = attachment.get("path")
|
||||||
if not file_path:
|
if not file_path:
|
||||||
raise ValueError("No file path provided in attachment")
|
raise ValueError("No file path provided in attachment")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with self.storage.get_file(file_path) as image_file:
|
with self.storage.get_file(file_path) as image_file:
|
||||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise FileNotFoundError(f"File not found: {file_path}")
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
|
|
||||||
@@ -273,10 +379,10 @@ class OpenAILLM(BaseLLM):
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
if 'openai_file_id' in attachment:
|
if "openai_file_id" in attachment:
|
||||||
return attachment['openai_file_id']
|
return attachment["openai_file_id"]
|
||||||
|
|
||||||
file_path = attachment.get('path')
|
file_path = attachment.get("path")
|
||||||
|
|
||||||
if not self.storage.file_exists(file_path):
|
if not self.storage.file_exists(file_path):
|
||||||
raise FileNotFoundError(f"File not found: {file_path}")
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
@@ -285,19 +391,18 @@ class OpenAILLM(BaseLLM):
|
|||||||
file_id = self.storage.process_file(
|
file_id = self.storage.process_file(
|
||||||
file_path,
|
file_path,
|
||||||
lambda local_path, **kwargs: self.client.files.create(
|
lambda local_path, **kwargs: self.client.files.create(
|
||||||
file=open(local_path, 'rb'),
|
file=open(local_path, "rb"), purpose="assistants"
|
||||||
purpose="assistants"
|
).id,
|
||||||
).id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
|
|
||||||
mongo = MongoDB.get_client()
|
mongo = MongoDB.get_client()
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
attachments_collection = db["attachments"]
|
attachments_collection = db["attachments"]
|
||||||
if '_id' in attachment:
|
if "_id" in attachment:
|
||||||
attachments_collection.update_one(
|
attachments_collection.update_one(
|
||||||
{"_id": attachment['_id']},
|
{"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
|
||||||
{"$set": {"openai_file_id": file_id}}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return file_id
|
return file_id
|
||||||
@@ -308,9 +413,7 @@ class OpenAILLM(BaseLLM):
|
|||||||
|
|
||||||
class AzureOpenAILLM(OpenAILLM):
|
class AzureOpenAILLM(OpenAILLM):
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, api_key, user_api_key, *args, **kwargs):
|
||||||
self, api_key, user_api_key, *args, **kwargs
|
|
||||||
):
|
|
||||||
|
|
||||||
super().__init__(api_key)
|
super().__init__(api_key)
|
||||||
self.api_base = (settings.OPENAI_API_BASE,)
|
self.api_base = (settings.OPENAI_API_BASE,)
|
||||||
@@ -321,5 +424,5 @@ class AzureOpenAILLM(OpenAILLM):
|
|||||||
self.client = AzureOpenAI(
|
self.client = AzureOpenAI(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_version=settings.OPENAI_API_VERSION,
|
api_version=settings.OPENAI_API_VERSION,
|
||||||
azure_endpoint=settings.OPENAI_API_BASE
|
azure_endpoint=settings.OPENAI_API_BASE,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -136,6 +136,8 @@ def _log_to_mongodb(
|
|||||||
mongo = MongoDB.get_client()
|
mongo = MongoDB.get_client()
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
user_logs_collection = db["stack_logs"]
|
user_logs_collection = db["stack_logs"]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
log_entry = {
|
log_entry = {
|
||||||
"endpoint": endpoint,
|
"endpoint": endpoint,
|
||||||
@@ -147,6 +149,11 @@ def _log_to_mongodb(
|
|||||||
"stacks": stacks,
|
"stacks": stacks,
|
||||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||||
}
|
}
|
||||||
|
# clean up text fields to be no longer than 10000 characters
|
||||||
|
for key, value in log_entry.items():
|
||||||
|
if isinstance(value, str) and len(value) > 10000:
|
||||||
|
log_entry[key] = value[:10000]
|
||||||
|
|
||||||
user_logs_collection.insert_one(log_entry)
|
user_logs_collection.insert_one(log_entry)
|
||||||
logging.debug(f"Logged activity to MongoDB: {activity_id}")
|
logging.debug(f"Logged activity to MongoDB: {activity_id}")
|
||||||
|
|
||||||
|
|||||||
@@ -32,16 +32,7 @@ class Chunker:
|
|||||||
header, body = "", text # No header, treat entire text as body
|
header, body = "", text # No header, treat entire text as body
|
||||||
return header, body
|
return header, body
|
||||||
|
|
||||||
def combine_documents(self, doc: Document, next_doc: Document) -> Document:
|
|
||||||
combined_text = doc.text + " " + next_doc.text
|
|
||||||
combined_token_count = len(self.encoding.encode(combined_text))
|
|
||||||
new_doc = Document(
|
|
||||||
text=combined_text,
|
|
||||||
doc_id=doc.doc_id,
|
|
||||||
embedding=doc.embedding,
|
|
||||||
extra_info={**(doc.extra_info or {}), "token_count": combined_token_count}
|
|
||||||
)
|
|
||||||
return new_doc
|
|
||||||
|
|
||||||
def split_document(self, doc: Document) -> List[Document]:
|
def split_document(self, doc: Document) -> List[Document]:
|
||||||
split_docs = []
|
split_docs = []
|
||||||
@@ -82,26 +73,11 @@ class Chunker:
|
|||||||
processed_docs.append(doc)
|
processed_docs.append(doc)
|
||||||
i += 1
|
i += 1
|
||||||
elif token_count < self.min_tokens:
|
elif token_count < self.min_tokens:
|
||||||
if i + 1 < len(documents):
|
|
||||||
next_doc = documents[i + 1]
|
doc.extra_info = doc.extra_info or {}
|
||||||
next_tokens = self.encoding.encode(next_doc.text)
|
doc.extra_info["token_count"] = token_count
|
||||||
if token_count + len(next_tokens) <= self.max_tokens:
|
processed_docs.append(doc)
|
||||||
# Combine small documents
|
i += 1
|
||||||
combined_doc = self.combine_documents(doc, next_doc)
|
|
||||||
processed_docs.append(combined_doc)
|
|
||||||
i += 2
|
|
||||||
else:
|
|
||||||
# Keep the small document as is if adding next_doc would exceed max_tokens
|
|
||||||
doc.extra_info = doc.extra_info or {}
|
|
||||||
doc.extra_info["token_count"] = token_count
|
|
||||||
processed_docs.append(doc)
|
|
||||||
i += 1
|
|
||||||
else:
|
|
||||||
# No next document to combine with; add the small document as is
|
|
||||||
doc.extra_info = doc.extra_info or {}
|
|
||||||
doc.extra_info["token_count"] = token_count
|
|
||||||
processed_docs.append(doc)
|
|
||||||
i += 1
|
|
||||||
else:
|
else:
|
||||||
# Split large documents
|
# Split large documents
|
||||||
processed_docs.extend(self.split_document(doc))
|
processed_docs.extend(self.split_document(doc))
|
||||||
|
|||||||
18
application/parser/connectors/__init__.py
Normal file
18
application/parser/connectors/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""
|
||||||
|
External knowledge base connectors for DocsGPT.
|
||||||
|
|
||||||
|
This module contains connectors for external knowledge bases and document storage systems
|
||||||
|
that require authentication and specialized handling, separate from simple web scrapers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import BaseConnectorAuth, BaseConnectorLoader
|
||||||
|
from .connector_creator import ConnectorCreator
|
||||||
|
from .google_drive import GoogleDriveAuth, GoogleDriveLoader
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'BaseConnectorAuth',
|
||||||
|
'BaseConnectorLoader',
|
||||||
|
'ConnectorCreator',
|
||||||
|
'GoogleDriveAuth',
|
||||||
|
'GoogleDriveLoader'
|
||||||
|
]
|
||||||
129
application/parser/connectors/base.py
Normal file
129
application/parser/connectors/base.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
"""
|
||||||
|
Base classes for external knowledge base connectors.
|
||||||
|
|
||||||
|
This module provides minimal abstract base classes that define the essential
|
||||||
|
interface for external knowledge base connectors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from application.parser.schema.base import Document
|
||||||
|
|
||||||
|
|
||||||
|
class BaseConnectorAuth(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for connector authentication.
|
||||||
|
|
||||||
|
Defines the minimal interface that all connector authentication
|
||||||
|
implementations must follow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_authorization_url(self, state: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
Generate authorization URL for OAuth flows.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Optional state parameter for CSRF protection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Authorization URL
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Exchange authorization code for access tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
authorization_code: Authorization code from OAuth callback
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing token information
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Refresh an expired access token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
refresh_token: Refresh token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing refreshed token information
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a token is expired.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_info: Token information dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if token is expired, False otherwise
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BaseConnectorLoader(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for connector loaders.
|
||||||
|
|
||||||
|
Defines the minimal interface that all connector loader
|
||||||
|
implementations must follow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, session_token: str):
|
||||||
|
"""
|
||||||
|
Initialize the connector loader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_token: Authentication session token
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||||
|
"""
|
||||||
|
Load documents from the external knowledge base.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Configuration dictionary containing:
|
||||||
|
- file_ids: Optional list of specific file IDs to load
|
||||||
|
- folder_ids: Optional list of folder IDs to browse/download
|
||||||
|
- limit: Maximum number of items to return
|
||||||
|
- list_only: If True, return metadata without content
|
||||||
|
- recursive: Whether to recursively process folders
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Document objects
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Download files/folders to a local directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_dir: Local directory path to download files to
|
||||||
|
source_config: Configuration for what to download
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing download results:
|
||||||
|
- files_downloaded: Number of files downloaded
|
||||||
|
- directory_path: Path where files were downloaded
|
||||||
|
- empty_result: Whether no files were downloaded
|
||||||
|
- source_type: Type of connector
|
||||||
|
- config_used: Configuration that was used
|
||||||
|
- error: Error message if download failed (optional)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
81
application/parser/connectors/connector_creator.py
Normal file
81
application/parser/connectors/connector_creator.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
||||||
|
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorCreator:
|
||||||
|
"""
|
||||||
|
Factory class for creating external knowledge base connectors and auth providers.
|
||||||
|
|
||||||
|
These are different from remote loaders as they typically require
|
||||||
|
authentication and connect to external document storage systems.
|
||||||
|
"""
|
||||||
|
|
||||||
|
connectors = {
|
||||||
|
"google_drive": GoogleDriveLoader,
|
||||||
|
}
|
||||||
|
|
||||||
|
auth_providers = {
|
||||||
|
"google_drive": GoogleDriveAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_connector(cls, connector_type, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Create a connector instance for the specified type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector_type: Type of connector to create (e.g., 'google_drive')
|
||||||
|
*args, **kwargs: Arguments to pass to the connector constructor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Connector instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If connector type is not supported
|
||||||
|
"""
|
||||||
|
connector_class = cls.connectors.get(connector_type.lower())
|
||||||
|
if not connector_class:
|
||||||
|
raise ValueError(f"No connector class found for type {connector_type}")
|
||||||
|
return connector_class(*args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_auth(cls, connector_type):
|
||||||
|
"""
|
||||||
|
Create an auth provider instance for the specified connector type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector_type: Type of connector auth to create (e.g., 'google_drive')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Auth provider instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If connector type is not supported for auth
|
||||||
|
"""
|
||||||
|
auth_class = cls.auth_providers.get(connector_type.lower())
|
||||||
|
if not auth_class:
|
||||||
|
raise ValueError(f"No auth class found for type {connector_type}")
|
||||||
|
return auth_class()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_supported_connectors(cls):
|
||||||
|
"""
|
||||||
|
Get list of supported connector types.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of supported connector type strings
|
||||||
|
"""
|
||||||
|
return list(cls.connectors.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_supported(cls, connector_type):
|
||||||
|
"""
|
||||||
|
Check if a connector type is supported.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector_type: Type of connector to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if supported, False otherwise
|
||||||
|
"""
|
||||||
|
return connector_type.lower() in cls.connectors
|
||||||
10
application/parser/connectors/google_drive/__init__.py
Normal file
10
application/parser/connectors/google_drive/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
Google Drive connector for DocsGPT.
|
||||||
|
|
||||||
|
This module provides authentication and document loading capabilities for Google Drive.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .auth import GoogleDriveAuth
|
||||||
|
from .loader import GoogleDriveLoader
|
||||||
|
|
||||||
|
__all__ = ['GoogleDriveAuth', 'GoogleDriveLoader']
|
||||||
267
application/parser/connectors/google_drive/auth.py
Normal file
267
application/parser/connectors/google_drive/auth.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
import logging
|
||||||
|
import datetime
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
from google.oauth2.credentials import Credentials
|
||||||
|
from google_auth_oauthlib.flow import Flow
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.parser.connectors.base import BaseConnectorAuth
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleDriveAuth(BaseConnectorAuth):
|
||||||
|
"""
|
||||||
|
Handles Google OAuth 2.0 authentication for Google Drive access.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SCOPES = [
|
||||||
|
'https://www.googleapis.com/auth/drive.file'
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.client_id = settings.GOOGLE_CLIENT_ID
|
||||||
|
self.client_secret = settings.GOOGLE_CLIENT_SECRET
|
||||||
|
self.redirect_uri = f"{settings.CONNECTOR_REDIRECT_BASE_URI}"
|
||||||
|
|
||||||
|
if not self.client_id or not self.client_secret:
|
||||||
|
raise ValueError("Google OAuth credentials not configured. Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET in settings.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_authorization_url(self, state: Optional[str] = None) -> str:
|
||||||
|
try:
|
||||||
|
flow = Flow.from_client_config(
|
||||||
|
{
|
||||||
|
"web": {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"client_secret": self.client_secret,
|
||||||
|
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||||
|
"token_uri": "https://oauth2.googleapis.com/token",
|
||||||
|
"redirect_uris": [self.redirect_uri]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
scopes=self.SCOPES
|
||||||
|
)
|
||||||
|
flow.redirect_uri = self.redirect_uri
|
||||||
|
|
||||||
|
authorization_url, _ = flow.authorization_url(
|
||||||
|
access_type='offline',
|
||||||
|
prompt='consent',
|
||||||
|
include_granted_scopes='false',
|
||||||
|
state=state
|
||||||
|
)
|
||||||
|
|
||||||
|
return authorization_url
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error generating authorization URL: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
if not authorization_code:
|
||||||
|
raise ValueError("Authorization code is required")
|
||||||
|
|
||||||
|
flow = Flow.from_client_config(
|
||||||
|
{
|
||||||
|
"web": {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"client_secret": self.client_secret,
|
||||||
|
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||||
|
"token_uri": "https://oauth2.googleapis.com/token",
|
||||||
|
"redirect_uris": [self.redirect_uri]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
scopes=self.SCOPES
|
||||||
|
)
|
||||||
|
flow.redirect_uri = self.redirect_uri
|
||||||
|
|
||||||
|
flow.fetch_token(code=authorization_code)
|
||||||
|
|
||||||
|
credentials = flow.credentials
|
||||||
|
|
||||||
|
if not credentials.refresh_token:
|
||||||
|
logging.warning("OAuth flow did not return a refresh_token.")
|
||||||
|
if not credentials.token:
|
||||||
|
raise ValueError("OAuth flow did not return an access token")
|
||||||
|
|
||||||
|
if not credentials.token_uri:
|
||||||
|
credentials.token_uri = "https://oauth2.googleapis.com/token"
|
||||||
|
|
||||||
|
if not credentials.client_id:
|
||||||
|
credentials.client_id = self.client_id
|
||||||
|
|
||||||
|
if not credentials.client_secret:
|
||||||
|
credentials.client_secret = self.client_secret
|
||||||
|
|
||||||
|
if not credentials.refresh_token:
|
||||||
|
raise ValueError(
|
||||||
|
"No refresh token received. This typically happens when offline access wasn't granted. "
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'access_token': credentials.token,
|
||||||
|
'refresh_token': credentials.refresh_token,
|
||||||
|
'token_uri': credentials.token_uri,
|
||||||
|
'client_id': credentials.client_id,
|
||||||
|
'client_secret': credentials.client_secret,
|
||||||
|
'scopes': credentials.scopes,
|
||||||
|
'expiry': credentials.expiry.isoformat() if credentials.expiry else None
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error exchanging code for tokens: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
if not refresh_token:
|
||||||
|
raise ValueError("Refresh token is required")
|
||||||
|
|
||||||
|
credentials = Credentials(
|
||||||
|
token=None,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
token_uri="https://oauth2.googleapis.com/token",
|
||||||
|
client_id=self.client_id,
|
||||||
|
client_secret=self.client_secret
|
||||||
|
)
|
||||||
|
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
credentials.refresh(Request())
|
||||||
|
|
||||||
|
return {
|
||||||
|
'access_token': credentials.token,
|
||||||
|
'refresh_token': refresh_token,
|
||||||
|
'token_uri': credentials.token_uri,
|
||||||
|
'client_id': credentials.client_id,
|
||||||
|
'client_secret': credentials.client_secret,
|
||||||
|
'scopes': credentials.scopes,
|
||||||
|
'expiry': credentials.expiry.isoformat() if credentials.expiry else None
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error refreshing access token: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def create_credentials_from_token_info(self, token_info: Dict[str, Any]) -> Credentials:
|
||||||
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
access_token = token_info.get('access_token')
|
||||||
|
if not access_token:
|
||||||
|
raise ValueError("No access token found in token_info")
|
||||||
|
|
||||||
|
credentials = Credentials(
|
||||||
|
token=access_token,
|
||||||
|
refresh_token=token_info.get('refresh_token'),
|
||||||
|
token_uri= 'https://oauth2.googleapis.com/token',
|
||||||
|
client_id=settings.GOOGLE_CLIENT_ID,
|
||||||
|
client_secret=settings.GOOGLE_CLIENT_SECRET,
|
||||||
|
scopes=token_info.get('scopes', ['https://www.googleapis.com/auth/drive.readonly'])
|
||||||
|
)
|
||||||
|
|
||||||
|
if not credentials.token:
|
||||||
|
raise ValueError("Credentials created without valid access token")
|
||||||
|
|
||||||
|
return credentials
|
||||||
|
|
||||||
|
def build_drive_service(self, credentials: Credentials):
|
||||||
|
try:
|
||||||
|
if not credentials:
|
||||||
|
raise ValueError("No credentials provided")
|
||||||
|
|
||||||
|
if not credentials.token and not credentials.refresh_token:
|
||||||
|
raise ValueError("No access token or refresh token available. User must re-authorize with offline access.")
|
||||||
|
|
||||||
|
needs_refresh = credentials.expired or not credentials.token
|
||||||
|
if needs_refresh:
|
||||||
|
if credentials.refresh_token:
|
||||||
|
try:
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
credentials.refresh(Request())
|
||||||
|
except Exception as refresh_error:
|
||||||
|
raise ValueError(f"Failed to refresh credentials: {refresh_error}")
|
||||||
|
else:
|
||||||
|
raise ValueError("No access token or refresh token available. User must re-authorize with offline access.")
|
||||||
|
|
||||||
|
return build('drive', 'v3', credentials=credentials)
|
||||||
|
|
||||||
|
except HttpError as e:
|
||||||
|
raise ValueError(f"Failed to build Google Drive service: HTTP {e.resp.status}")
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Failed to build Google Drive service: {str(e)}")
|
||||||
|
|
||||||
|
def is_token_expired(self, token_info):
|
||||||
|
if 'expiry' in token_info and token_info['expiry']:
|
||||||
|
try:
|
||||||
|
from dateutil import parser
|
||||||
|
# Google Drive provides timezone-aware ISO8601 dates
|
||||||
|
expiry_dt = parser.parse(token_info['expiry'])
|
||||||
|
current_time = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
return current_time >= expiry_dt - datetime.timedelta(seconds=60)
|
||||||
|
except Exception:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if 'access_token' in token_info and token_info['access_token']:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
|
|
||||||
|
sessions_collection = db["connector_sessions"]
|
||||||
|
session = sessions_collection.find_one({"session_token": session_token})
|
||||||
|
if not session:
|
||||||
|
raise ValueError(f"Invalid session token: {session_token}")
|
||||||
|
|
||||||
|
if "token_info" not in session:
|
||||||
|
raise ValueError("Session missing token information")
|
||||||
|
|
||||||
|
token_info = session["token_info"]
|
||||||
|
if not token_info:
|
||||||
|
raise ValueError("Invalid token information")
|
||||||
|
|
||||||
|
required_fields = ["access_token", "refresh_token"]
|
||||||
|
missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
|
||||||
|
if missing_fields:
|
||||||
|
raise ValueError(f"Missing required token fields: {missing_fields}")
|
||||||
|
|
||||||
|
if 'client_id' not in token_info:
|
||||||
|
token_info['client_id'] = settings.GOOGLE_CLIENT_ID
|
||||||
|
if 'client_secret' not in token_info:
|
||||||
|
token_info['client_secret'] = settings.GOOGLE_CLIENT_SECRET
|
||||||
|
if 'token_uri' not in token_info:
|
||||||
|
token_info['token_uri'] = 'https://oauth2.googleapis.com/token'
|
||||||
|
|
||||||
|
return token_info
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Failed to retrieve Google Drive token information: {str(e)}")
|
||||||
|
|
||||||
|
def validate_credentials(self, credentials: Credentials) -> bool:
|
||||||
|
"""
|
||||||
|
Validate Google Drive credentials by making a test API call.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
credentials: Google credentials object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if credentials are valid, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
service = self.build_drive_service(credentials)
|
||||||
|
service.about().get(fields="user").execute()
|
||||||
|
return True
|
||||||
|
|
||||||
|
except HttpError as e:
|
||||||
|
logging.error(f"HTTP error validating credentials: {e}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error validating credentials: {e}")
|
||||||
|
return False
|
||||||
559
application/parser/connectors/google_drive/loader.py
Normal file
559
application/parser/connectors/google_drive/loader.py
Normal file
@@ -0,0 +1,559 @@
|
|||||||
|
"""
|
||||||
|
Google Drive loader for DocsGPT.
|
||||||
|
Loads documents from Google Drive using Google Drive API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
from googleapiclient.http import MediaIoBaseDownload
|
||||||
|
from googleapiclient.errors import HttpError
|
||||||
|
|
||||||
|
from application.parser.connectors.base import BaseConnectorLoader
|
||||||
|
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||||
|
from application.parser.schema.base import Document
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleDriveLoader(BaseConnectorLoader):
|
||||||
|
|
||||||
|
SUPPORTED_MIME_TYPES = {
|
||||||
|
'application/pdf': '.pdf',
|
||||||
|
'application/vnd.google-apps.document': '.docx',
|
||||||
|
'application/vnd.google-apps.presentation': '.pptx',
|
||||||
|
'application/vnd.google-apps.spreadsheet': '.xlsx',
|
||||||
|
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
|
||||||
|
'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
|
||||||
|
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
|
||||||
|
'application/msword': '.doc',
|
||||||
|
'application/vnd.ms-powerpoint': '.ppt',
|
||||||
|
'application/vnd.ms-excel': '.xls',
|
||||||
|
'text/plain': '.txt',
|
||||||
|
'text/csv': '.csv',
|
||||||
|
'text/html': '.html',
|
||||||
|
'text/markdown': '.md',
|
||||||
|
'text/x-rst': '.rst',
|
||||||
|
'application/json': '.json',
|
||||||
|
'application/epub+zip': '.epub',
|
||||||
|
'application/rtf': '.rtf',
|
||||||
|
'image/jpeg': '.jpg',
|
||||||
|
'image/jpg': '.jpg',
|
||||||
|
'image/png': '.png',
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPORT_FORMATS = {
|
||||||
|
'application/vnd.google-apps.document': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||||
|
'application/vnd.google-apps.presentation': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||||
|
'application/vnd.google-apps.spreadsheet': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, session_token: str):
|
||||||
|
self.auth = GoogleDriveAuth()
|
||||||
|
self.session_token = session_token
|
||||||
|
|
||||||
|
token_info = self.auth.get_token_info_from_session(session_token)
|
||||||
|
self.credentials = self.auth.create_credentials_from_token_info(token_info)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.service = self.auth.build_drive_service(self.credentials)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Could not build Google Drive service: {e}")
|
||||||
|
self.service = None
|
||||||
|
|
||||||
|
self.next_page_token = None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]:
|
||||||
|
try:
|
||||||
|
file_id = file_metadata.get('id')
|
||||||
|
file_name = file_metadata.get('name', 'Unknown')
|
||||||
|
mime_type = file_metadata.get('mimeType', 'application/octet-stream')
|
||||||
|
|
||||||
|
if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
|
||||||
|
return None
|
||||||
|
if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
|
||||||
|
logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}")
|
||||||
|
return None
|
||||||
|
# Google Drive provides timezone-aware ISO8601 dates
|
||||||
|
doc_metadata = {
|
||||||
|
'file_name': file_name,
|
||||||
|
'mime_type': mime_type,
|
||||||
|
'size': file_metadata.get('size', None),
|
||||||
|
'created_time': file_metadata.get('createdTime'),
|
||||||
|
'modified_time': file_metadata.get('modifiedTime'),
|
||||||
|
'parents': file_metadata.get('parents', []),
|
||||||
|
'source': 'google_drive'
|
||||||
|
}
|
||||||
|
|
||||||
|
if not load_content:
|
||||||
|
return Document(
|
||||||
|
text="",
|
||||||
|
doc_id=file_id,
|
||||||
|
extra_info=doc_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
content = self._download_file_content(file_id, mime_type)
|
||||||
|
if content is None:
|
||||||
|
logging.warning(f"Could not load content for file {file_name} ({file_id})")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return Document(
|
||||||
|
text=content,
|
||||||
|
doc_id=file_id,
|
||||||
|
extra_info=doc_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error processing file: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||||
|
session_token = inputs.get('session_token')
|
||||||
|
if session_token and session_token != self.session_token:
|
||||||
|
logging.warning("Session token in inputs differs from loader's session token. Using loader's session token.")
|
||||||
|
self.config = inputs
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents: List[Document] = []
|
||||||
|
|
||||||
|
folder_id = inputs.get('folder_id')
|
||||||
|
file_ids = inputs.get('file_ids', [])
|
||||||
|
limit = inputs.get('limit', 100)
|
||||||
|
list_only = inputs.get('list_only', False)
|
||||||
|
load_content = not list_only
|
||||||
|
page_token = inputs.get('page_token')
|
||||||
|
search_query = inputs.get('search_query')
|
||||||
|
self.next_page_token = None
|
||||||
|
|
||||||
|
if file_ids:
|
||||||
|
# Specific files requested: load them
|
||||||
|
for file_id in file_ids:
|
||||||
|
try:
|
||||||
|
doc = self._load_file_by_id(file_id, load_content=load_content)
|
||||||
|
if doc:
|
||||||
|
if not search_query or (
|
||||||
|
search_query.lower() in doc.extra_info.get('file_name', '').lower()
|
||||||
|
):
|
||||||
|
documents.append(doc)
|
||||||
|
elif hasattr(self, '_credential_refreshed') and self._credential_refreshed:
|
||||||
|
self._credential_refreshed = False
|
||||||
|
logging.info(f"Retrying load of file {file_id} after credential refresh")
|
||||||
|
doc = self._load_file_by_id(file_id, load_content=load_content)
|
||||||
|
if doc and (
|
||||||
|
not search_query or
|
||||||
|
search_query.lower() in doc.extra_info.get('file_name', '').lower()
|
||||||
|
):
|
||||||
|
documents.append(doc)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error loading file {file_id}: {e}")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Browsing mode: list immediate children of provided folder or root
|
||||||
|
parent_id = folder_id if folder_id else 'root'
|
||||||
|
documents = self._list_items_in_parent(
|
||||||
|
parent_id,
|
||||||
|
limit=limit,
|
||||||
|
load_content=load_content,
|
||||||
|
page_token=page_token,
|
||||||
|
search_query=search_query
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(f"Loaded {len(documents)} documents from Google Drive")
|
||||||
|
return documents
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error loading data from Google Drive: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]:
|
||||||
|
self._ensure_service()
|
||||||
|
|
||||||
|
try:
|
||||||
|
file_metadata = self.service.files().get(
|
||||||
|
fileId=file_id,
|
||||||
|
fields='id,name,mimeType,size,createdTime,modifiedTime,parents'
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
return self._process_file(file_metadata, load_content=load_content)
|
||||||
|
|
||||||
|
except HttpError as e:
|
||||||
|
logging.error(f"HTTP error loading file {file_id}: {e.resp.status} - {e.content}")
|
||||||
|
|
||||||
|
if e.resp.status in [401, 403]:
|
||||||
|
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
|
||||||
|
try:
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
self.credentials.refresh(Request())
|
||||||
|
self._ensure_service()
|
||||||
|
return None
|
||||||
|
except Exception as refresh_error:
|
||||||
|
raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}")
|
||||||
|
else:
|
||||||
|
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error loading file {file_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
|
||||||
|
self._ensure_service()
|
||||||
|
|
||||||
|
documents: List[Document] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
query = f"'{parent_id}' in parents and trashed=false"
|
||||||
|
|
||||||
|
if search_query:
|
||||||
|
safe_search = search_query.replace("'", "\\'")
|
||||||
|
query += f" and name contains '{safe_search}'"
|
||||||
|
|
||||||
|
next_token_out: Optional[str] = None
|
||||||
|
|
||||||
|
while True:
|
||||||
|
page_size = 100
|
||||||
|
if limit:
|
||||||
|
remaining = max(0, limit - len(documents))
|
||||||
|
if remaining == 0:
|
||||||
|
break
|
||||||
|
page_size = min(100, remaining)
|
||||||
|
|
||||||
|
results = self.service.files().list(
|
||||||
|
q=query,
|
||||||
|
fields='nextPageToken,files(id,name,mimeType,size,createdTime,modifiedTime,parents)',
|
||||||
|
pageToken=page_token,
|
||||||
|
pageSize=page_size,
|
||||||
|
orderBy='name'
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
items = results.get('files', [])
|
||||||
|
for item in items:
|
||||||
|
mime_type = item.get('mimeType')
|
||||||
|
if mime_type == 'application/vnd.google-apps.folder':
|
||||||
|
doc_metadata = {
|
||||||
|
'file_name': item.get('name', 'Unknown'),
|
||||||
|
'mime_type': mime_type,
|
||||||
|
'size': item.get('size', None),
|
||||||
|
'created_time': item.get('createdTime'),
|
||||||
|
'modified_time': item.get('modifiedTime'),
|
||||||
|
'parents': item.get('parents', []),
|
||||||
|
'source': 'google_drive',
|
||||||
|
'is_folder': True
|
||||||
|
}
|
||||||
|
documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata))
|
||||||
|
else:
|
||||||
|
doc = self._process_file(item, load_content=load_content)
|
||||||
|
if doc:
|
||||||
|
documents.append(doc)
|
||||||
|
|
||||||
|
if limit and len(documents) >= limit:
|
||||||
|
self.next_page_token = results.get('nextPageToken')
|
||||||
|
return documents
|
||||||
|
|
||||||
|
page_token = results.get('nextPageToken')
|
||||||
|
next_token_out = page_token
|
||||||
|
if not page_token:
|
||||||
|
break
|
||||||
|
|
||||||
|
self.next_page_token = next_token_out
|
||||||
|
return documents
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error listing items under parent {parent_id}: {e}")
|
||||||
|
return documents
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _download_file_content(self, file_id: str, mime_type: str) -> Optional[str]:
|
||||||
|
if not self.credentials.token:
|
||||||
|
logging.warning("No access token in credentials, attempting to refresh")
|
||||||
|
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
|
||||||
|
try:
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
self.credentials.refresh(Request())
|
||||||
|
logging.info("Credentials refreshed successfully")
|
||||||
|
self._ensure_service()
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to refresh credentials: {e}")
|
||||||
|
raise ValueError("Authentication failed and cannot be refreshed: missing or invalid refresh_token")
|
||||||
|
else:
|
||||||
|
logging.error("No access token and no refresh_token available")
|
||||||
|
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
|
||||||
|
|
||||||
|
if self.credentials.expired:
|
||||||
|
logging.warning("Credentials are expired, attempting to refresh")
|
||||||
|
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
|
||||||
|
try:
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
self.credentials.refresh(Request())
|
||||||
|
logging.info("Credentials refreshed successfully")
|
||||||
|
self._ensure_service()
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to refresh expired credentials: {e}")
|
||||||
|
raise ValueError("Authentication failed and cannot be refreshed: expired credentials")
|
||||||
|
else:
|
||||||
|
logging.error("Credentials expired and no refresh_token available")
|
||||||
|
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if mime_type in self.EXPORT_FORMATS:
|
||||||
|
export_mime_type = self.EXPORT_FORMATS[mime_type]
|
||||||
|
request = self.service.files().export_media(
|
||||||
|
fileId=file_id,
|
||||||
|
mimeType=export_mime_type
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
request = self.service.files().get_media(fileId=file_id)
|
||||||
|
|
||||||
|
file_io = io.BytesIO()
|
||||||
|
downloader = MediaIoBaseDownload(file_io, request)
|
||||||
|
|
||||||
|
done = False
|
||||||
|
while done is False:
|
||||||
|
try:
|
||||||
|
_, done = downloader.next_chunk()
|
||||||
|
except HttpError as e:
|
||||||
|
logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error during download of file {file_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
content_bytes = file_io.getvalue()
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = content_bytes.decode('utf-8')
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
try:
|
||||||
|
content = content_bytes.decode('latin-1')
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
logging.error(f"Could not decode file {file_id} as text")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
except HttpError as e:
|
||||||
|
logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")
|
||||||
|
|
||||||
|
if e.resp.status in [401, 403]:
|
||||||
|
logging.error(f"Authentication error downloading file {file_id}")
|
||||||
|
|
||||||
|
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
|
||||||
|
logging.info(f"Attempting to refresh credentials for file {file_id}")
|
||||||
|
try:
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
self.credentials.refresh(Request())
|
||||||
|
logging.info("Credentials refreshed successfully")
|
||||||
|
self._credential_refreshed = True
|
||||||
|
self._ensure_service()
|
||||||
|
return None
|
||||||
|
except Exception as refresh_error:
|
||||||
|
logging.error(f"Error refreshing credentials: {refresh_error}")
|
||||||
|
raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}")
|
||||||
|
else:
|
||||||
|
logging.error("Cannot refresh credentials: missing refresh_token")
|
||||||
|
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error downloading file {file_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool:
|
||||||
|
try:
|
||||||
|
self._ensure_service()
|
||||||
|
return self._download_single_file(file_id, local_dir)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error downloading file {file_id}: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _ensure_service(self):
|
||||||
|
if not self.service:
|
||||||
|
try:
|
||||||
|
self.service = self.auth.build_drive_service(self.credentials)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Cannot access Google Drive: {e}")
|
||||||
|
|
||||||
|
def _download_single_file(self, file_id: str, local_dir: str) -> bool:
|
||||||
|
file_metadata = self.service.files().get(
|
||||||
|
fileId=file_id,
|
||||||
|
fields='name,mimeType'
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
file_name = file_metadata['name']
|
||||||
|
mime_type = file_metadata['mimeType']
|
||||||
|
|
||||||
|
if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
|
||||||
|
return False
|
||||||
|
|
||||||
|
os.makedirs(local_dir, exist_ok=True)
|
||||||
|
full_path = os.path.join(local_dir, file_name)
|
||||||
|
|
||||||
|
if mime_type in self.EXPORT_FORMATS:
|
||||||
|
export_mime_type = self.EXPORT_FORMATS[mime_type]
|
||||||
|
request = self.service.files().export_media(
|
||||||
|
fileId=file_id,
|
||||||
|
mimeType=export_mime_type
|
||||||
|
)
|
||||||
|
extension = self._get_extension_for_mime_type(export_mime_type)
|
||||||
|
if not full_path.endswith(extension):
|
||||||
|
full_path += extension
|
||||||
|
else:
|
||||||
|
request = self.service.files().get_media(fileId=file_id)
|
||||||
|
|
||||||
|
with open(full_path, 'wb') as f:
|
||||||
|
downloader = MediaIoBaseDownload(f, request)
|
||||||
|
done = False
|
||||||
|
while not done:
|
||||||
|
_, done = downloader.next_chunk()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
|
||||||
|
files_downloaded = 0
|
||||||
|
try:
|
||||||
|
os.makedirs(local_dir, exist_ok=True)
|
||||||
|
|
||||||
|
query = f"'{folder_id}' in parents and trashed=false"
|
||||||
|
page_token = None
|
||||||
|
|
||||||
|
while True:
|
||||||
|
results = self.service.files().list(
|
||||||
|
q=query,
|
||||||
|
fields='nextPageToken, files(id, name, mimeType)',
|
||||||
|
pageToken=page_token,
|
||||||
|
pageSize=1000
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
items = results.get('files', [])
|
||||||
|
logging.info(f"Found {len(items)} items in folder {folder_id}")
|
||||||
|
|
||||||
|
for item in items:
|
||||||
|
item_name = item['name']
|
||||||
|
item_id = item['id']
|
||||||
|
mime_type = item['mimeType']
|
||||||
|
|
||||||
|
if mime_type == 'application/vnd.google-apps.folder':
|
||||||
|
if recursive:
|
||||||
|
# Create subfolder and recurse
|
||||||
|
subfolder_path = os.path.join(local_dir, item_name)
|
||||||
|
os.makedirs(subfolder_path, exist_ok=True)
|
||||||
|
subfolder_files = self._download_folder_recursive(
|
||||||
|
item_id,
|
||||||
|
subfolder_path,
|
||||||
|
recursive
|
||||||
|
)
|
||||||
|
files_downloaded += subfolder_files
|
||||||
|
logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}")
|
||||||
|
else:
|
||||||
|
# Download file
|
||||||
|
success = self._download_single_file(item_id, local_dir)
|
||||||
|
if success:
|
||||||
|
files_downloaded += 1
|
||||||
|
logging.info(f"Downloaded file: {item_name}")
|
||||||
|
else:
|
||||||
|
logging.warning(f"Failed to download file: {item_name}")
|
||||||
|
|
||||||
|
page_token = results.get('nextPageToken')
|
||||||
|
if not page_token:
|
||||||
|
break
|
||||||
|
|
||||||
|
return files_downloaded
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True)
|
||||||
|
return files_downloaded
|
||||||
|
|
||||||
|
def _get_extension_for_mime_type(self, mime_type: str) -> str:
|
||||||
|
extensions = {
|
||||||
|
'application/pdf': '.pdf',
|
||||||
|
'text/plain': '.txt',
|
||||||
|
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
|
||||||
|
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
|
||||||
|
'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
|
||||||
|
'text/html': '.html',
|
||||||
|
'text/markdown': '.md',
|
||||||
|
}
|
||||||
|
return extensions.get(mime_type, '.bin')
|
||||||
|
|
||||||
|
def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
|
||||||
|
try:
|
||||||
|
self._ensure_service()
|
||||||
|
return self._download_folder_recursive(folder_id, local_dir, recursive)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def download_to_directory(self, local_dir: str, source_config: dict = None) -> dict:
|
||||||
|
if source_config is None:
|
||||||
|
source_config = {}
|
||||||
|
|
||||||
|
config = source_config if source_config else getattr(self, 'config', {})
|
||||||
|
files_downloaded = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
folder_ids = config.get('folder_ids', [])
|
||||||
|
file_ids = config.get('file_ids', [])
|
||||||
|
recursive = config.get('recursive', True)
|
||||||
|
|
||||||
|
self._ensure_service()
|
||||||
|
|
||||||
|
if file_ids:
|
||||||
|
if isinstance(file_ids, str):
|
||||||
|
file_ids = [file_ids]
|
||||||
|
|
||||||
|
for file_id in file_ids:
|
||||||
|
if self._download_file_to_directory(file_id, local_dir):
|
||||||
|
files_downloaded += 1
|
||||||
|
|
||||||
|
# Process folders
|
||||||
|
if folder_ids:
|
||||||
|
if isinstance(folder_ids, str):
|
||||||
|
folder_ids = [folder_ids]
|
||||||
|
|
||||||
|
for folder_id in folder_ids:
|
||||||
|
try:
|
||||||
|
folder_metadata = self.service.files().get(
|
||||||
|
fileId=folder_id,
|
||||||
|
fields='name'
|
||||||
|
).execute()
|
||||||
|
folder_name = folder_metadata.get('name', '')
|
||||||
|
folder_path = os.path.join(local_dir, folder_name)
|
||||||
|
os.makedirs(folder_path, exist_ok=True)
|
||||||
|
|
||||||
|
folder_files = self._download_folder_recursive(
|
||||||
|
folder_id,
|
||||||
|
folder_path,
|
||||||
|
recursive
|
||||||
|
)
|
||||||
|
files_downloaded += folder_files
|
||||||
|
logging.info(f"Downloaded {folder_files} files from folder {folder_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
if not file_ids and not folder_ids:
|
||||||
|
raise ValueError("No folder_ids or file_ids provided for download")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"files_downloaded": files_downloaded,
|
||||||
|
"directory_path": local_dir,
|
||||||
|
"empty_result": files_downloaded == 0,
|
||||||
|
"source_type": "google_drive",
|
||||||
|
"config_used": config
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"files_downloaded": files_downloaded,
|
||||||
|
"directory_path": local_dir,
|
||||||
|
"empty_result": True,
|
||||||
|
"source_type": "google_drive",
|
||||||
|
"config_used": config,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
@@ -1,21 +1,43 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
from typing import List, Any
|
||||||
from retry import retry
|
from retry import retry
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.vectorstore.vector_creator import VectorCreator
|
from application.vectorstore.vector_creator import VectorCreator
|
||||||
|
|
||||||
|
|
||||||
@retry(tries=10, delay=60)
|
def sanitize_content(content: str) -> str:
|
||||||
def add_text_to_store_with_retry(store, doc, source_id):
|
|
||||||
"""
|
"""
|
||||||
Add a document's text and metadata to the vector store with retry logic.
|
Remove NUL characters that can cause vector store ingestion to fail.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content (str): Raw content that may contain NUL characters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Sanitized content with NUL characters removed
|
||||||
|
"""
|
||||||
|
if not content:
|
||||||
|
return content
|
||||||
|
return content.replace('\x00', '')
|
||||||
|
|
||||||
|
|
||||||
|
@retry(tries=10, delay=60)
|
||||||
|
def add_text_to_store_with_retry(store: Any, doc: Any, source_id: str) -> None:
|
||||||
|
"""Add a document's text and metadata to the vector store with retry logic.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
store: The vector store object.
|
store: The vector store object.
|
||||||
doc: The document to be added.
|
doc: The document to be added.
|
||||||
source_id: Unique identifier for the source.
|
source_id: Unique identifier for the source.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If document addition fails after all retry attempts.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Sanitize content to remove NUL characters that cause ingestion failures
|
||||||
|
doc.page_content = sanitize_content(doc.page_content)
|
||||||
|
|
||||||
doc.metadata["source_id"] = str(source_id)
|
doc.metadata["source_id"] = str(source_id)
|
||||||
store.add_texts([doc.page_content], metadatas=[doc.metadata])
|
store.add_texts([doc.page_content], metadatas=[doc.metadata])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -23,18 +45,21 @@ def add_text_to_store_with_retry(store, doc, source_id):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def embed_and_store_documents(docs, folder_name, source_id, task_status):
|
def embed_and_store_documents(docs: List[Any], folder_name: str, source_id: str, task_status: Any) -> None:
|
||||||
"""
|
"""Embeds documents and stores them in a vector store.
|
||||||
Embeds documents and stores them in a vector store.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
docs (list): List of documents to be embedded and stored.
|
docs: List of documents to be embedded and stored.
|
||||||
folder_name (str): Directory to save the vector store.
|
folder_name: Directory to save the vector store.
|
||||||
source_id (str): Unique identifier for the source.
|
source_id: Unique identifier for the source.
|
||||||
task_status: Task state manager for progress updates.
|
task_status: Task state manager for progress updates.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OSError: If unable to create folder or save vector store.
|
||||||
|
Exception: If vector store creation or document embedding fails.
|
||||||
"""
|
"""
|
||||||
# Ensure the folder exists
|
# Ensure the folder exists
|
||||||
if not os.path.exists(folder_name):
|
if not os.path.exists(folder_name):
|
||||||
@@ -46,7 +71,7 @@ def embed_and_store_documents(docs, folder_name, source_id, task_status):
|
|||||||
store = VectorCreator.create_vectorstore(
|
store = VectorCreator.create_vectorstore(
|
||||||
settings.VECTOR_STORE,
|
settings.VECTOR_STORE,
|
||||||
docs_init=docs_init,
|
docs_init=docs_init,
|
||||||
source_id=folder_name,
|
source_id=source_id,
|
||||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -77,10 +102,21 @@ def embed_and_store_documents(docs, folder_name, source_id, task_status):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error embedding document {idx}: {e}", exc_info=True)
|
logging.error(f"Error embedding document {idx}: {e}", exc_info=True)
|
||||||
logging.info(f"Saving progress at document {idx} out of {total_docs}")
|
logging.info(f"Saving progress at document {idx} out of {total_docs}")
|
||||||
store.save_local(folder_name)
|
try:
|
||||||
|
store.save_local(folder_name)
|
||||||
|
logging.info("Progress saved successfully")
|
||||||
|
except Exception as save_error:
|
||||||
|
logging.error(f"CRITICAL: Failed to save progress: {save_error}", exc_info=True)
|
||||||
|
# Continue without breaking to attempt final save
|
||||||
break
|
break
|
||||||
|
|
||||||
# Save the vector store
|
# Save the vector store
|
||||||
if settings.VECTOR_STORE == "faiss":
|
if settings.VECTOR_STORE == "faiss":
|
||||||
store.save_local(folder_name)
|
try:
|
||||||
logging.info("Vector store saved successfully.")
|
store.save_local(folder_name)
|
||||||
|
logging.info("Vector store saved successfully.")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"CRITICAL: Failed to save final vector store: {e}", exc_info=True)
|
||||||
|
raise OSError(f"Unable to save vector store to {folder_name}: {e}") from e
|
||||||
|
else:
|
||||||
|
logging.info("Vector store saved successfully.")
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from application.parser.file.json_parser import JSONParser
|
|||||||
from application.parser.file.pptx_parser import PPTXParser
|
from application.parser.file.pptx_parser import PPTXParser
|
||||||
from application.parser.file.image_parser import ImageParser
|
from application.parser.file.image_parser import ImageParser
|
||||||
from application.parser.schema.base import Document
|
from application.parser.schema.base import Document
|
||||||
|
from application.utils import num_tokens_from_string
|
||||||
|
|
||||||
DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = {
|
DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = {
|
||||||
".pdf": PDFParser(),
|
".pdf": PDFParser(),
|
||||||
@@ -141,11 +142,12 @@ class SimpleDirectoryReader(BaseReader):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Document]: A list of documents.
|
List[Document]: A list of documents.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
data: Union[str, List[str]] = ""
|
data: Union[str, List[str]] = ""
|
||||||
data_list: List[str] = []
|
data_list: List[str] = []
|
||||||
metadata_list = []
|
metadata_list = []
|
||||||
|
self.file_token_counts = {}
|
||||||
|
|
||||||
for input_file in self.input_files:
|
for input_file in self.input_files:
|
||||||
if input_file.suffix in self.file_extractor:
|
if input_file.suffix in self.file_extractor:
|
||||||
parser = self.file_extractor[input_file.suffix]
|
parser = self.file_extractor[input_file.suffix]
|
||||||
@@ -156,24 +158,48 @@ class SimpleDirectoryReader(BaseReader):
|
|||||||
# do standard read
|
# do standard read
|
||||||
with open(input_file, "r", errors=self.errors) as f:
|
with open(input_file, "r", errors=self.errors) as f:
|
||||||
data = f.read()
|
data = f.read()
|
||||||
# Prepare metadata for this file
|
|
||||||
if self.file_metadata is not None:
|
# Calculate token count for this file
|
||||||
file_metadata = self.file_metadata(input_file.name)
|
if isinstance(data, List):
|
||||||
|
file_tokens = sum(num_tokens_from_string(str(d)) for d in data)
|
||||||
else:
|
else:
|
||||||
# Provide a default empty metadata
|
file_tokens = num_tokens_from_string(str(data))
|
||||||
file_metadata = {'title': '', 'store': ''}
|
|
||||||
# TODO: Find a case with no metadata and check if breaks anything
|
full_path = str(input_file.resolve())
|
||||||
|
self.file_token_counts[full_path] = file_tokens
|
||||||
|
|
||||||
|
base_metadata = {
|
||||||
|
'title': input_file.name,
|
||||||
|
'token_count': file_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasattr(self, 'input_dir'):
|
||||||
|
try:
|
||||||
|
relative_path = str(input_file.relative_to(self.input_dir))
|
||||||
|
base_metadata['source'] = relative_path
|
||||||
|
except ValueError:
|
||||||
|
base_metadata['source'] = str(input_file)
|
||||||
|
else:
|
||||||
|
base_metadata['source'] = str(input_file)
|
||||||
|
|
||||||
|
if self.file_metadata is not None:
|
||||||
|
custom_metadata = self.file_metadata(input_file.name)
|
||||||
|
base_metadata.update(custom_metadata)
|
||||||
|
|
||||||
if isinstance(data, List):
|
if isinstance(data, List):
|
||||||
# Extend data_list with each item in the data list
|
# Extend data_list with each item in the data list
|
||||||
data_list.extend([str(d) for d in data])
|
data_list.extend([str(d) for d in data])
|
||||||
# For each item in the data list, add the file's metadata to metadata_list
|
metadata_list.extend([base_metadata for _ in data])
|
||||||
metadata_list.extend([file_metadata for _ in data])
|
|
||||||
else:
|
else:
|
||||||
# Add the single piece of data to data_list
|
|
||||||
data_list.append(str(data))
|
data_list.append(str(data))
|
||||||
# Add the file's metadata to metadata_list
|
metadata_list.append(base_metadata)
|
||||||
metadata_list.append(file_metadata)
|
|
||||||
|
# Build directory structure if input_dir is provided
|
||||||
|
if hasattr(self, 'input_dir'):
|
||||||
|
self.directory_structure = self.build_directory_structure(self.input_dir)
|
||||||
|
logging.info("Directory structure built successfully")
|
||||||
|
else:
|
||||||
|
self.directory_structure = {}
|
||||||
|
|
||||||
if concatenate:
|
if concatenate:
|
||||||
return [Document("\n".join(data_list))]
|
return [Document("\n".join(data_list))]
|
||||||
@@ -181,3 +207,48 @@ class SimpleDirectoryReader(BaseReader):
|
|||||||
return [Document(d, extra_info=m) for d, m in zip(data_list, metadata_list)]
|
return [Document(d, extra_info=m) for d, m in zip(data_list, metadata_list)]
|
||||||
else:
|
else:
|
||||||
return [Document(d) for d in data_list]
|
return [Document(d) for d in data_list]
|
||||||
|
|
||||||
|
def build_directory_structure(self, base_path):
|
||||||
|
"""Build a dictionary representing the directory structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_path: The base path to start building the structure from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A nested dictionary representing the directory structure.
|
||||||
|
"""
|
||||||
|
import mimetypes
|
||||||
|
|
||||||
|
def build_tree(path):
|
||||||
|
"""Helper function to recursively build the directory tree."""
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
for item in path.iterdir():
|
||||||
|
if self.exclude_hidden and item.name.startswith('.'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if item.is_dir():
|
||||||
|
subtree = build_tree(item)
|
||||||
|
if subtree:
|
||||||
|
result[item.name] = subtree
|
||||||
|
else:
|
||||||
|
if self.required_exts is not None and item.suffix not in self.required_exts:
|
||||||
|
continue
|
||||||
|
|
||||||
|
full_path = str(item.resolve())
|
||||||
|
file_size_bytes = item.stat().st_size
|
||||||
|
mime_type = mimetypes.guess_type(item.name)[0] or "application/octet-stream"
|
||||||
|
|
||||||
|
file_info = {
|
||||||
|
"type": mime_type,
|
||||||
|
"size_bytes": file_size_bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasattr(self, 'file_token_counts') and full_path in self.file_token_counts:
|
||||||
|
file_info["token_count"] = self.file_token_counts[full_path]
|
||||||
|
|
||||||
|
result[item.name] = file_info
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return build_tree(Path(base_path))
|
||||||
@@ -8,6 +8,7 @@ import requests
|
|||||||
from typing import Dict, Union
|
from typing import Dict, Union
|
||||||
|
|
||||||
from application.parser.file.base_parser import BaseParser
|
from application.parser.file.base_parser import BaseParser
|
||||||
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
|
||||||
class ImageParser(BaseParser):
|
class ImageParser(BaseParser):
|
||||||
@@ -18,10 +19,13 @@ class ImageParser(BaseParser):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, list[str]]:
|
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, list[str]]:
|
||||||
doc2md_service = "https://llm.arc53.com/doc2md"
|
if settings.PARSE_IMAGE_REMOTE:
|
||||||
# alternatively you can use local vision capable LLM
|
doc2md_service = "https://llm.arc53.com/doc2md"
|
||||||
with open(file, "rb") as file_loaded:
|
# alternatively you can use local vision capable LLM
|
||||||
files = {'file': file_loaded}
|
with open(file, "rb") as file_loaded:
|
||||||
response = requests.post(doc2md_service, files=files)
|
files = {'file': file_loaded}
|
||||||
data = response.json()["markdown"]
|
response = requests.post(doc2md_service, files=files)
|
||||||
|
data = response.json()["markdown"]
|
||||||
|
else:
|
||||||
|
data = ""
|
||||||
return data
|
return data
|
||||||
|
|||||||
@@ -1,44 +1,135 @@
|
|||||||
import base64
|
import base64
|
||||||
import requests
|
import requests
|
||||||
from typing import List
|
import time
|
||||||
|
from typing import List, Optional
|
||||||
from application.parser.remote.base import BaseRemote
|
from application.parser.remote.base import BaseRemote
|
||||||
from langchain_core.documents import Document
|
from application.parser.schema.base import Document
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
from application.core.settings import settings
|
||||||
|
|
||||||
class GitHubLoader(BaseRemote):
|
class GitHubLoader(BaseRemote):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.access_token = None
|
self.access_token = settings.GITHUB_ACCESS_TOKEN
|
||||||
self.headers = {
|
self.headers = {
|
||||||
"Authorization": f"token {self.access_token}"
|
"Authorization": f"token {self.access_token}",
|
||||||
} if self.access_token else {}
|
"Accept": "application/vnd.github.v3+json"
|
||||||
|
} if self.access_token else {
|
||||||
|
"Accept": "application/vnd.github.v3+json"
|
||||||
|
}
|
||||||
return
|
return
|
||||||
|
|
||||||
def fetch_file_content(self, repo_url: str, file_path: str) -> str:
|
def is_text_file(self, file_path: str) -> bool:
|
||||||
|
"""Determine if a file is a text file based on extension."""
|
||||||
|
# Common text file extensions
|
||||||
|
text_extensions = {
|
||||||
|
'.txt', '.md', '.markdown', '.rst', '.json', '.xml', '.yaml', '.yml',
|
||||||
|
'.py', '.js', '.ts', '.jsx', '.tsx', '.java', '.c', '.cpp', '.h', '.hpp',
|
||||||
|
'.cs', '.go', '.rs', '.rb', '.php', '.swift', '.kt', '.scala',
|
||||||
|
'.html', '.css', '.scss', '.sass', '.less',
|
||||||
|
'.sh', '.bash', '.zsh', '.fish',
|
||||||
|
'.sql', '.r', '.m', '.mat',
|
||||||
|
'.ini', '.cfg', '.conf', '.config', '.env',
|
||||||
|
'.gitignore', '.dockerignore', '.editorconfig',
|
||||||
|
'.log', '.csv', '.tsv'
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get file extension
|
||||||
|
file_lower = file_path.lower()
|
||||||
|
for ext in text_extensions:
|
||||||
|
if file_lower.endswith(ext):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Also check MIME type
|
||||||
|
mime_type, _ = mimetypes.guess_type(file_path)
|
||||||
|
if mime_type and (mime_type.startswith("text") or mime_type in ["application/json", "application/xml"]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def fetch_file_content(self, repo_url: str, file_path: str) -> Optional[str]:
|
||||||
|
"""Fetch file content. Returns None if file should be skipped (binary files or empty files)."""
|
||||||
url = f"https://api.github.com/repos/{repo_url}/contents/{file_path}"
|
url = f"https://api.github.com/repos/{repo_url}/contents/{file_path}"
|
||||||
response = requests.get(url, headers=self.headers)
|
response = self._make_request(url)
|
||||||
|
|
||||||
if response.status_code == 200:
|
content = response.json()
|
||||||
content = response.json()
|
|
||||||
mime_type, _ = mimetypes.guess_type(file_path) # Guess the MIME type based on the file extension
|
|
||||||
|
|
||||||
if content.get("encoding") == "base64":
|
if content.get("encoding") == "base64":
|
||||||
if mime_type and mime_type.startswith("text"): # Handle only text files
|
if self.is_text_file(file_path): # Handle only text files
|
||||||
try:
|
try:
|
||||||
decoded_content = base64.b64decode(content["content"]).decode("utf-8")
|
decoded_content = base64.b64decode(content["content"]).decode("utf-8").strip()
|
||||||
return f"Filename: {file_path}\n\n{decoded_content}"
|
# Skip empty files
|
||||||
except Exception as e:
|
if not decoded_content:
|
||||||
raise e
|
return None
|
||||||
else:
|
return decoded_content
|
||||||
return f"Filename: {file_path} is a binary file and was skipped."
|
except Exception:
|
||||||
|
# If decoding fails, it's probably a binary file
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
return f"Filename: {file_path}\n\n{content['content']}"
|
# Skip binary files by returning None
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
response.raise_for_status()
|
file_content = content['content'].strip()
|
||||||
|
# Skip empty files
|
||||||
|
if not file_content:
|
||||||
|
return None
|
||||||
|
return file_content
|
||||||
|
|
||||||
|
def _make_request(self, url: str, max_retries: int = 3) -> requests.Response:
|
||||||
|
"""Make a request with retry logic for rate limiting"""
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
response = requests.get(url, headers=self.headers)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response
|
||||||
|
elif response.status_code == 403:
|
||||||
|
# Check if it's a rate limit issue
|
||||||
|
try:
|
||||||
|
error_data = response.json()
|
||||||
|
error_msg = error_data.get("message", "")
|
||||||
|
|
||||||
|
# Check rate limit headers
|
||||||
|
remaining = response.headers.get("X-RateLimit-Remaining", "unknown")
|
||||||
|
reset_time = response.headers.get("X-RateLimit-Reset", "unknown")
|
||||||
|
|
||||||
|
print(f"GitHub API 403 Error: {error_msg}")
|
||||||
|
print(f"Rate limit remaining: {remaining}, Reset time: {reset_time}")
|
||||||
|
|
||||||
|
if "rate limit" in error_msg.lower():
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
wait_time = 2 ** attempt # Exponential backoff
|
||||||
|
print(f"Rate limit hit, waiting {wait_time} seconds before retry...")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Provide helpful error message
|
||||||
|
if remaining == "0":
|
||||||
|
raise Exception(f"GitHub API rate limit exceeded. Please set GITHUB_ACCESS_TOKEN environment variable. Reset time: {reset_time}")
|
||||||
|
else:
|
||||||
|
raise Exception(f"GitHub API error: {error_msg}. This may require authentication - set GITHUB_ACCESS_TOKEN environment variable.")
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, Exception) and "GitHub API" in str(e):
|
||||||
|
raise
|
||||||
|
# If we can't parse the response, raise the original error
|
||||||
|
response.raise_for_status()
|
||||||
|
else:
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
def fetch_repo_files(self, repo_url: str, path: str = "") -> List[str]:
|
def fetch_repo_files(self, repo_url: str, path: str = "") -> List[str]:
|
||||||
url = f"https://api.github.com/repos/{repo_url}/contents/{path}"
|
url = f"https://api.github.com/repos/{repo_url}/contents/{path}"
|
||||||
response = requests.get(url, headers={**self.headers, "Accept": "application/vnd.github.v3.raw"})
|
response = self._make_request(url)
|
||||||
|
|
||||||
contents = response.json()
|
contents = response.json()
|
||||||
|
|
||||||
|
# Handle error responses from GitHub API
|
||||||
|
if isinstance(contents, dict) and "message" in contents:
|
||||||
|
raise Exception(f"GitHub API error: {contents.get('message')}")
|
||||||
|
|
||||||
|
# Ensure contents is a list
|
||||||
|
if not isinstance(contents, list):
|
||||||
|
raise TypeError(f"Expected list from GitHub API, got {type(contents).__name__}: {contents}")
|
||||||
|
|
||||||
files = []
|
files = []
|
||||||
for item in contents:
|
for item in contents:
|
||||||
if item["type"] == "file":
|
if item["type"] == "file":
|
||||||
@@ -53,6 +144,15 @@ class GitHubLoader(BaseRemote):
|
|||||||
documents = []
|
documents = []
|
||||||
for file_path in files:
|
for file_path in files:
|
||||||
content = self.fetch_file_content(repo_name, file_path)
|
content = self.fetch_file_content(repo_name, file_path)
|
||||||
documents.append(Document(page_content=content, metadata={"title": file_path,
|
# Skip binary files (content is None)
|
||||||
"source": f"https://github.com/{repo_name}/blob/main/{file_path}"}))
|
if content is None:
|
||||||
|
continue
|
||||||
|
documents.append(Document(
|
||||||
|
text=content,
|
||||||
|
doc_id=file_path,
|
||||||
|
extra_info={
|
||||||
|
"title": file_path,
|
||||||
|
"source": f"https://github.com/{repo_name}/blob/main/{file_path}"
|
||||||
|
}
|
||||||
|
))
|
||||||
return documents
|
return documents
|
||||||
|
|||||||
@@ -6,6 +6,16 @@ from application.parser.remote.github_loader import GitHubLoader
|
|||||||
|
|
||||||
|
|
||||||
class RemoteCreator:
|
class RemoteCreator:
|
||||||
|
"""
|
||||||
|
Factory class for creating remote content loaders.
|
||||||
|
|
||||||
|
These loaders fetch content from remote web sources like URLs,
|
||||||
|
sitemaps, web crawlers, social media platforms, etc.
|
||||||
|
|
||||||
|
For external knowledge base connectors (like Google Drive),
|
||||||
|
use ConnectorCreator instead.
|
||||||
|
"""
|
||||||
|
|
||||||
loaders = {
|
loaders = {
|
||||||
"url": WebLoader,
|
"url": WebLoader,
|
||||||
"sitemap": SitemapLoader,
|
"sitemap": SitemapLoader,
|
||||||
@@ -18,5 +28,5 @@ class RemoteCreator:
|
|||||||
def create_loader(cls, type, *args, **kwargs):
|
def create_loader(cls, type, *args, **kwargs):
|
||||||
loader_class = cls.loaders.get(type.lower())
|
loader_class = cls.loaders.get(type.lower())
|
||||||
if not loader_class:
|
if not loader_class:
|
||||||
raise ValueError(f"No LLM class found for type {type}")
|
raise ValueError(f"No loader class found for type {type}")
|
||||||
return loader_class(*args, **kwargs)
|
return loader_class(*args, **kwargs)
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
You are an AI assistant and talk like you're thinking out loud. Given the following query, outline a concise thought process that includes key steps and considerations necessary for effective analysis and response. Avoid pointwise formatting. The goal is to break down the query into manageable components without excessive detail, focusing on clarity and logical progression.
|
You are an AI assistant and talk like you're thinking out loud. Given the following query, outline a concise thought process that includes key steps and considerations necessary for effective analysis and response. Avoid pointwise formatting. The goal is to break down the query into manageable components without excessive detail, focusing on clarity and logical progression.
|
||||||
|
|
||||||
Include the following elements in your thought process:
|
Include the following elements in your thought and execution process:
|
||||||
1. Identify the main objective of the query.
|
1. Identify the main objective of the query.
|
||||||
2. Determine any relevant context or background information needed to understand the query.
|
2. Determine any relevant context or background information needed to understand the query.
|
||||||
3. List potential approaches or methods to address the query.
|
3. List potential approaches or methods to address the query.
|
||||||
4. Highlight any critical factors or constraints that may influence the outcome.
|
4. Highlight any critical factors or constraints that may influence the outcome.
|
||||||
|
5. Plan with available tools to help you with the analysis but dont execute them. Tools will be executed by another AI.
|
||||||
|
|
||||||
Query: {query}
|
Query: {query}
|
||||||
Summaries: {summaries}
|
Summaries: {summaries}
|
||||||
|
Prompt: {prompt}
|
||||||
|
Observations(potentially previous tool calls): {observations}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ anthropic==0.49.0
|
|||||||
boto3==1.38.18
|
boto3==1.38.18
|
||||||
beautifulsoup4==4.13.4
|
beautifulsoup4==4.13.4
|
||||||
celery==5.4.0
|
celery==5.4.0
|
||||||
|
cryptography==42.0.8
|
||||||
dataclasses-json==0.6.7
|
dataclasses-json==0.6.7
|
||||||
docx2txt==0.8
|
docx2txt==0.8
|
||||||
duckduckgo-search==7.5.2
|
duckduckgo-search==7.5.2
|
||||||
@@ -9,10 +10,15 @@ ebooklib==0.18
|
|||||||
escodegen==1.0.11
|
escodegen==1.0.11
|
||||||
esprima==4.0.1
|
esprima==4.0.1
|
||||||
esutils==1.0.1
|
esutils==1.0.1
|
||||||
|
elevenlabs==2.17.0
|
||||||
Flask==3.1.1
|
Flask==3.1.1
|
||||||
faiss-cpu==1.9.0.post1
|
faiss-cpu==1.9.0.post1
|
||||||
|
fastmcp==2.11.0
|
||||||
flask-restx==1.3.0
|
flask-restx==1.3.0
|
||||||
google-genai==1.3.0
|
google-genai==1.3.0
|
||||||
|
google-api-python-client==2.179.0
|
||||||
|
google-auth-httplib2==0.2.0
|
||||||
|
google-auth-oauthlib==1.2.2
|
||||||
gTTS==2.5.4
|
gTTS==2.5.4
|
||||||
gunicorn==23.0.0
|
gunicorn==23.0.0
|
||||||
javalang==0.13.0
|
javalang==0.13.0
|
||||||
@@ -41,28 +47,28 @@ numpy==2.2.1
|
|||||||
openai==1.78.1
|
openai==1.78.1
|
||||||
openapi3-parser==1.1.21
|
openapi3-parser==1.1.21
|
||||||
orjson==3.10.14
|
orjson==3.10.14
|
||||||
packaging==25.0
|
packaging==24.2
|
||||||
pandas==2.2.3
|
pandas==2.2.3
|
||||||
openpyxl==3.1.5
|
openpyxl==3.1.5
|
||||||
pathable==0.4.4
|
pathable==0.4.4
|
||||||
pillow==11.1.0
|
pillow==11.1.0
|
||||||
portalocker==3.1.1
|
portalocker>=2.7.0,<3.0.0
|
||||||
prance==23.6.21.0
|
prance==23.6.21.0
|
||||||
prompt-toolkit==3.0.51
|
prompt-toolkit==3.0.51
|
||||||
protobuf==5.29.3
|
protobuf==5.29.3
|
||||||
psycopg2-binary==2.9.10
|
psycopg2-binary==2.9.10
|
||||||
py==1.11.0
|
py==1.11.0
|
||||||
pydantic==2.10.6
|
pydantic
|
||||||
pydantic-core==2.27.2
|
pydantic-core
|
||||||
pydantic-settings==2.7.1
|
pydantic-settings
|
||||||
pymongo==4.11.3
|
pymongo==4.11.3
|
||||||
pypdf==5.5.0
|
pypdf==5.5.0
|
||||||
python-dateutil==2.9.0.post0
|
python-dateutil==2.9.0.post0
|
||||||
python-dotenv==1.0.1
|
python-dotenv
|
||||||
python-jose==3.4.0
|
python-jose==3.4.0
|
||||||
python-pptx==1.0.2
|
python-pptx==1.0.2
|
||||||
redis==5.2.1
|
redis==5.2.1
|
||||||
referencing==0.36.2
|
referencing>=0.28.0,<0.31.0
|
||||||
regex==2024.11.6
|
regex==2024.11.6
|
||||||
requests==2.32.3
|
requests==2.32.3
|
||||||
retry==0.9.2
|
retry==0.9.2
|
||||||
@@ -78,7 +84,7 @@ tzdata==2024.2
|
|||||||
urllib3==2.3.0
|
urllib3==2.3.0
|
||||||
vine==5.1.0
|
vine==5.1.0
|
||||||
wcwidth==0.2.13
|
wcwidth==0.2.13
|
||||||
werkzeug==3.1.3
|
werkzeug>=3.1.0,<3.1.2
|
||||||
yarl==1.20.0
|
yarl==1.20.0
|
||||||
markdownify==1.1.0
|
markdownify==1.1.0
|
||||||
tldextract==5.1.3
|
tldextract==5.1.3
|
||||||
|
|||||||
@@ -5,14 +5,6 @@ class BaseRetriever(ABC):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def gen(self, *args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def search(self, *args, **kwargs):
|
def search(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_params(self):
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -1,112 +0,0 @@
|
|||||||
import json
|
|
||||||
|
|
||||||
from langchain_community.tools import BraveSearch
|
|
||||||
|
|
||||||
from application.core.settings import settings
|
|
||||||
from application.llm.llm_creator import LLMCreator
|
|
||||||
from application.retriever.base import BaseRetriever
|
|
||||||
|
|
||||||
|
|
||||||
class BraveRetSearch(BaseRetriever):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
source,
|
|
||||||
chat_history,
|
|
||||||
prompt,
|
|
||||||
chunks=2,
|
|
||||||
token_limit=150,
|
|
||||||
gpt_model="docsgpt",
|
|
||||||
user_api_key=None,
|
|
||||||
decoded_token=None,
|
|
||||||
):
|
|
||||||
self.question = ""
|
|
||||||
self.source = source
|
|
||||||
self.chat_history = chat_history
|
|
||||||
self.prompt = prompt
|
|
||||||
self.chunks = chunks
|
|
||||||
self.gpt_model = gpt_model
|
|
||||||
self.token_limit = (
|
|
||||||
token_limit
|
|
||||||
if token_limit
|
|
||||||
< settings.MODEL_TOKEN_LIMITS.get(
|
|
||||||
self.gpt_model, settings.DEFAULT_MAX_HISTORY
|
|
||||||
)
|
|
||||||
else settings.MODEL_TOKEN_LIMITS.get(
|
|
||||||
self.gpt_model, settings.DEFAULT_MAX_HISTORY
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.user_api_key = user_api_key
|
|
||||||
self.decoded_token = decoded_token
|
|
||||||
|
|
||||||
def _get_data(self):
|
|
||||||
if self.chunks == 0:
|
|
||||||
docs = []
|
|
||||||
else:
|
|
||||||
search = BraveSearch.from_api_key(
|
|
||||||
api_key=settings.BRAVE_SEARCH_API_KEY,
|
|
||||||
search_kwargs={"count": int(self.chunks)},
|
|
||||||
)
|
|
||||||
results = search.run(self.question)
|
|
||||||
results = json.loads(results)
|
|
||||||
|
|
||||||
docs = []
|
|
||||||
for i in results:
|
|
||||||
try:
|
|
||||||
title = i["title"]
|
|
||||||
link = i["link"]
|
|
||||||
snippet = i["snippet"]
|
|
||||||
docs.append({"text": snippet, "title": title, "link": link})
|
|
||||||
except IndexError:
|
|
||||||
pass
|
|
||||||
if settings.LLM_NAME == "llama.cpp":
|
|
||||||
docs = [docs[0]]
|
|
||||||
|
|
||||||
return docs
|
|
||||||
|
|
||||||
def gen(self):
|
|
||||||
docs = self._get_data()
|
|
||||||
|
|
||||||
# join all page_content together with a newline
|
|
||||||
docs_together = "\n".join([doc["text"] for doc in docs])
|
|
||||||
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
|
|
||||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
|
||||||
for doc in docs:
|
|
||||||
yield {"source": doc}
|
|
||||||
|
|
||||||
if len(self.chat_history) > 0:
|
|
||||||
for i in self.chat_history:
|
|
||||||
if "prompt" in i and "response" in i:
|
|
||||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
|
||||||
messages_combine.append(
|
|
||||||
{"role": "assistant", "content": i["response"]}
|
|
||||||
)
|
|
||||||
messages_combine.append({"role": "user", "content": self.question})
|
|
||||||
|
|
||||||
llm = LLMCreator.create_llm(
|
|
||||||
settings.LLM_NAME,
|
|
||||||
api_key=settings.API_KEY,
|
|
||||||
user_api_key=self.user_api_key,
|
|
||||||
decoded_token=self.decoded_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
|
|
||||||
for line in completion:
|
|
||||||
yield {"answer": str(line)}
|
|
||||||
|
|
||||||
def search(self, query: str = ""):
|
|
||||||
if query:
|
|
||||||
self.question = query
|
|
||||||
return self._get_data()
|
|
||||||
|
|
||||||
def get_params(self):
|
|
||||||
return {
|
|
||||||
"question": self.question,
|
|
||||||
"source": self.source,
|
|
||||||
"chat_history": self.chat_history,
|
|
||||||
"prompt": self.prompt,
|
|
||||||
"chunks": self.chunks,
|
|
||||||
"token_limit": self.token_limit,
|
|
||||||
"gpt_model": self.gpt_model,
|
|
||||||
"user_api_key": self.user_api_key,
|
|
||||||
}
|
|
||||||
@@ -1,8 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.llm.llm_creator import LLMCreator
|
from application.llm.llm_creator import LLMCreator
|
||||||
from application.retriever.base import BaseRetriever
|
from application.retriever.base import BaseRetriever
|
||||||
|
from application.utils import num_tokens_from_string
|
||||||
from application.vectorstore.vector_creator import VectorCreator
|
from application.vectorstore.vector_creator import VectorCreator
|
||||||
|
|
||||||
|
|
||||||
@@ -13,28 +15,33 @@ class ClassicRAG(BaseRetriever):
|
|||||||
chat_history=None,
|
chat_history=None,
|
||||||
prompt="",
|
prompt="",
|
||||||
chunks=2,
|
chunks=2,
|
||||||
token_limit=150,
|
doc_token_limit=50000,
|
||||||
gpt_model="docsgpt",
|
gpt_model="docsgpt",
|
||||||
user_api_key=None,
|
user_api_key=None,
|
||||||
llm_name=settings.LLM_NAME,
|
llm_name=settings.LLM_PROVIDER,
|
||||||
api_key=settings.API_KEY,
|
api_key=settings.API_KEY,
|
||||||
decoded_token=None,
|
decoded_token=None,
|
||||||
):
|
):
|
||||||
self.original_question = ""
|
self.original_question = source.get("question", "")
|
||||||
self.chat_history = chat_history if chat_history is not None else []
|
self.chat_history = chat_history if chat_history is not None else []
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
self.chunks = chunks
|
if isinstance(chunks, str):
|
||||||
self.gpt_model = gpt_model
|
try:
|
||||||
self.token_limit = (
|
self.chunks = int(chunks)
|
||||||
token_limit
|
except ValueError:
|
||||||
if token_limit
|
logging.warning(
|
||||||
< settings.MODEL_TOKEN_LIMITS.get(
|
f"Invalid chunks value '{chunks}', using default value 2"
|
||||||
self.gpt_model, settings.DEFAULT_MAX_HISTORY
|
)
|
||||||
)
|
self.chunks = 2
|
||||||
else settings.MODEL_TOKEN_LIMITS.get(
|
else:
|
||||||
self.gpt_model, settings.DEFAULT_MAX_HISTORY
|
self.chunks = chunks
|
||||||
)
|
user_identifier = user_api_key if user_api_key else "default"
|
||||||
|
logging.info(
|
||||||
|
f"ClassicRAG initialized with chunks={self.chunks}, user_api_key={user_identifier}, "
|
||||||
|
f"sources={'active_docs' in source and source['active_docs'] is not None}"
|
||||||
)
|
)
|
||||||
|
self.gpt_model = gpt_model
|
||||||
|
self.doc_token_limit = doc_token_limit
|
||||||
self.user_api_key = user_api_key
|
self.user_api_key = user_api_key
|
||||||
self.llm_name = llm_name
|
self.llm_name = llm_name
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
@@ -44,24 +51,48 @@ class ClassicRAG(BaseRetriever):
|
|||||||
user_api_key=self.user_api_key,
|
user_api_key=self.user_api_key,
|
||||||
decoded_token=decoded_token,
|
decoded_token=decoded_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "active_docs" in source and source["active_docs"] is not None:
|
||||||
|
if isinstance(source["active_docs"], list):
|
||||||
|
self.vectorstores = source["active_docs"]
|
||||||
|
else:
|
||||||
|
self.vectorstores = [source["active_docs"]]
|
||||||
|
else:
|
||||||
|
self.vectorstores = []
|
||||||
self.question = self._rephrase_query()
|
self.question = self._rephrase_query()
|
||||||
self.vectorstore = source["active_docs"] if "active_docs" in source else None
|
|
||||||
self.decoded_token = decoded_token
|
self.decoded_token = decoded_token
|
||||||
|
self._validate_vectorstore_config()
|
||||||
|
|
||||||
|
def _validate_vectorstore_config(self):
|
||||||
|
"""Validate vectorstore IDs and remove any empty/invalid entries"""
|
||||||
|
if not self.vectorstores:
|
||||||
|
logging.warning("No vectorstores configured for retrieval")
|
||||||
|
return
|
||||||
|
invalid_ids = [
|
||||||
|
vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip()
|
||||||
|
]
|
||||||
|
if invalid_ids:
|
||||||
|
logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}")
|
||||||
|
self.vectorstores = [
|
||||||
|
vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip()
|
||||||
|
]
|
||||||
|
|
||||||
def _rephrase_query(self):
|
def _rephrase_query(self):
|
||||||
|
"""Rephrase user query with chat history context for better retrieval"""
|
||||||
if (
|
if (
|
||||||
not self.original_question
|
not self.original_question
|
||||||
or not self.chat_history
|
or not self.chat_history
|
||||||
or self.chat_history == []
|
or self.chat_history == []
|
||||||
|
or self.chunks == 0
|
||||||
|
or not self.vectorstores
|
||||||
):
|
):
|
||||||
return self.original_question
|
return self.original_question
|
||||||
|
prompt = (
|
||||||
prompt = f"""Given the following conversation history:
|
"Given the following conversation history:\n"
|
||||||
{self.chat_history}
|
f"{self.chat_history}\n\n"
|
||||||
|
"Rephrase the following user question to be a standalone search query "
|
||||||
Rephrase the following user question to be a standalone search query
|
"that captures all relevant context from the conversation:\n"
|
||||||
that captures all relevant context from the conversation:
|
)
|
||||||
"""
|
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": prompt},
|
{"role": "system", "content": prompt},
|
||||||
@@ -77,46 +108,93 @@ class ClassicRAG(BaseRetriever):
|
|||||||
return self.original_question
|
return self.original_question
|
||||||
|
|
||||||
def _get_data(self):
|
def _get_data(self):
|
||||||
if self.chunks == 0:
|
if self.chunks == 0 or not self.vectorstores:
|
||||||
docs = []
|
logging.info(
|
||||||
else:
|
f"ClassicRAG._get_data: Skipping retrieval - chunks={self.chunks}, "
|
||||||
docsearch = VectorCreator.create_vectorstore(
|
f"vectorstores_count={len(self.vectorstores) if self.vectorstores else 0}"
|
||||||
settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
|
|
||||||
)
|
)
|
||||||
docs_temp = docsearch.search(self.question, k=self.chunks)
|
return []
|
||||||
docs = [
|
|
||||||
{
|
|
||||||
"title": i.metadata.get(
|
|
||||||
"title", i.metadata.get("post_title", i.page_content)
|
|
||||||
).split("/")[-1],
|
|
||||||
"text": i.page_content,
|
|
||||||
"source": (
|
|
||||||
i.metadata.get("source")
|
|
||||||
if i.metadata.get("source")
|
|
||||||
else "local"
|
|
||||||
),
|
|
||||||
}
|
|
||||||
for i in docs_temp
|
|
||||||
]
|
|
||||||
|
|
||||||
return docs
|
all_docs = []
|
||||||
|
chunks_per_source = max(1, self.chunks // len(self.vectorstores))
|
||||||
|
token_budget = max(int(self.doc_token_limit * 0.9), 100)
|
||||||
|
cumulative_tokens = 0
|
||||||
|
|
||||||
def gen():
|
for vectorstore_id in self.vectorstores:
|
||||||
pass
|
if vectorstore_id:
|
||||||
|
try:
|
||||||
|
docsearch = VectorCreator.create_vectorstore(
|
||||||
|
settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY
|
||||||
|
)
|
||||||
|
docs_temp = docsearch.search(
|
||||||
|
self.question, k=max(chunks_per_source * 2, 20)
|
||||||
|
)
|
||||||
|
|
||||||
|
for doc in docs_temp:
|
||||||
|
if cumulative_tokens >= token_budget:
|
||||||
|
break
|
||||||
|
|
||||||
|
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
|
||||||
|
page_content = doc.page_content
|
||||||
|
metadata = doc.metadata
|
||||||
|
else:
|
||||||
|
page_content = doc.get("text", doc.get("page_content", ""))
|
||||||
|
metadata = doc.get("metadata", {})
|
||||||
|
|
||||||
|
title = metadata.get(
|
||||||
|
"title", metadata.get("post_title", page_content)
|
||||||
|
)
|
||||||
|
if not isinstance(title, str):
|
||||||
|
title = str(title)
|
||||||
|
title = title.split("/")[-1]
|
||||||
|
|
||||||
|
filename = (
|
||||||
|
metadata.get("filename")
|
||||||
|
or metadata.get("file_name")
|
||||||
|
or metadata.get("source")
|
||||||
|
)
|
||||||
|
if isinstance(filename, str):
|
||||||
|
filename = os.path.basename(filename) or filename
|
||||||
|
else:
|
||||||
|
filename = title
|
||||||
|
if not filename:
|
||||||
|
filename = title
|
||||||
|
source_path = metadata.get("source") or vectorstore_id
|
||||||
|
|
||||||
|
doc_text_with_header = f"{filename}\n{page_content}"
|
||||||
|
doc_tokens = num_tokens_from_string(doc_text_with_header)
|
||||||
|
|
||||||
|
if cumulative_tokens + doc_tokens < token_budget:
|
||||||
|
all_docs.append(
|
||||||
|
{
|
||||||
|
"title": title,
|
||||||
|
"text": page_content,
|
||||||
|
"source": source_path,
|
||||||
|
"filename": filename,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
cumulative_tokens += doc_tokens
|
||||||
|
|
||||||
|
if cumulative_tokens >= token_budget:
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(
|
||||||
|
f"Error searching vectorstore {vectorstore_id}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"ClassicRAG._get_data: Retrieval complete - retrieved {len(all_docs)} documents "
|
||||||
|
f"(requested chunks={self.chunks}, chunks_per_source={chunks_per_source}, "
|
||||||
|
f"cumulative_tokens={cumulative_tokens}/{token_budget})"
|
||||||
|
)
|
||||||
|
return all_docs
|
||||||
|
|
||||||
def search(self, query: str = ""):
|
def search(self, query: str = ""):
|
||||||
|
"""Search for documents using optional query override"""
|
||||||
if query:
|
if query:
|
||||||
self.original_question = query
|
self.original_question = query
|
||||||
self.question = self._rephrase_query()
|
self.question = self._rephrase_query()
|
||||||
return self._get_data()
|
return self._get_data()
|
||||||
|
|
||||||
def get_params(self):
|
|
||||||
return {
|
|
||||||
"question": self.original_question,
|
|
||||||
"rephrased_question": self.question,
|
|
||||||
"source": self.vectorstore,
|
|
||||||
"chunks": self.chunks,
|
|
||||||
"token_limit": self.token_limit,
|
|
||||||
"gpt_model": self.gpt_model,
|
|
||||||
"user_api_key": self.user_api_key,
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,111 +0,0 @@
|
|||||||
from langchain_community.tools import DuckDuckGoSearchResults
|
|
||||||
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
|
|
||||||
|
|
||||||
from application.core.settings import settings
|
|
||||||
from application.llm.llm_creator import LLMCreator
|
|
||||||
from application.retriever.base import BaseRetriever
|
|
||||||
|
|
||||||
|
|
||||||
class DuckDuckSearch(BaseRetriever):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
source,
|
|
||||||
chat_history,
|
|
||||||
prompt,
|
|
||||||
chunks=2,
|
|
||||||
token_limit=150,
|
|
||||||
gpt_model="docsgpt",
|
|
||||||
user_api_key=None,
|
|
||||||
decoded_token=None,
|
|
||||||
):
|
|
||||||
self.question = ""
|
|
||||||
self.source = source
|
|
||||||
self.chat_history = chat_history
|
|
||||||
self.prompt = prompt
|
|
||||||
self.chunks = chunks
|
|
||||||
self.gpt_model = gpt_model
|
|
||||||
self.token_limit = (
|
|
||||||
token_limit
|
|
||||||
if token_limit
|
|
||||||
< settings.MODEL_TOKEN_LIMITS.get(
|
|
||||||
self.gpt_model, settings.DEFAULT_MAX_HISTORY
|
|
||||||
)
|
|
||||||
else settings.MODEL_TOKEN_LIMITS.get(
|
|
||||||
self.gpt_model, settings.DEFAULT_MAX_HISTORY
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.user_api_key = user_api_key
|
|
||||||
self.decoded_token = decoded_token
|
|
||||||
|
|
||||||
def _get_data(self):
|
|
||||||
if self.chunks == 0:
|
|
||||||
docs = []
|
|
||||||
else:
|
|
||||||
wrapper = DuckDuckGoSearchAPIWrapper(max_results=self.chunks)
|
|
||||||
search = DuckDuckGoSearchResults(api_wrapper=wrapper, output_format="list")
|
|
||||||
results = search.run(self.question)
|
|
||||||
|
|
||||||
docs = []
|
|
||||||
for i in results:
|
|
||||||
try:
|
|
||||||
docs.append(
|
|
||||||
{
|
|
||||||
"text": i.get("snippet", "").strip(),
|
|
||||||
"title": i.get("title", "").strip(),
|
|
||||||
"link": i.get("link", "").strip(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except IndexError:
|
|
||||||
pass
|
|
||||||
if settings.LLM_NAME == "llama.cpp":
|
|
||||||
docs = [docs[0]]
|
|
||||||
|
|
||||||
return docs
|
|
||||||
|
|
||||||
def gen(self):
|
|
||||||
docs = self._get_data()
|
|
||||||
|
|
||||||
# join all page_content together with a newline
|
|
||||||
docs_together = "\n".join([doc["text"] for doc in docs])
|
|
||||||
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
|
|
||||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
|
||||||
for doc in docs:
|
|
||||||
yield {"source": doc}
|
|
||||||
|
|
||||||
if len(self.chat_history) > 0:
|
|
||||||
for i in self.chat_history:
|
|
||||||
if "prompt" in i and "response" in i:
|
|
||||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
|
||||||
messages_combine.append(
|
|
||||||
{"role": "assistant", "content": i["response"]}
|
|
||||||
)
|
|
||||||
messages_combine.append({"role": "user", "content": self.question})
|
|
||||||
|
|
||||||
llm = LLMCreator.create_llm(
|
|
||||||
settings.LLM_NAME,
|
|
||||||
api_key=settings.API_KEY,
|
|
||||||
user_api_key=self.user_api_key,
|
|
||||||
decoded_token=self.decoded_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
|
|
||||||
for line in completion:
|
|
||||||
yield {"answer": str(line)}
|
|
||||||
|
|
||||||
def search(self, query: str = ""):
|
|
||||||
if query:
|
|
||||||
self.question = query
|
|
||||||
return self._get_data()
|
|
||||||
|
|
||||||
def get_params(self):
|
|
||||||
return {
|
|
||||||
"question": self.question,
|
|
||||||
"source": self.source,
|
|
||||||
"chat_history": self.chat_history,
|
|
||||||
"prompt": self.prompt,
|
|
||||||
"chunks": self.chunks,
|
|
||||||
"token_limit": self.token_limit,
|
|
||||||
"gpt_model": self.gpt_model,
|
|
||||||
"user_api_key": self.user_api_key,
|
|
||||||
}
|
|
||||||
@@ -1,20 +1,16 @@
|
|||||||
from application.retriever.classic_rag import ClassicRAG
|
from application.retriever.classic_rag import ClassicRAG
|
||||||
from application.retriever.duckduck_search import DuckDuckSearch
|
|
||||||
from application.retriever.brave_search import BraveRetSearch
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RetrieverCreator:
|
class RetrieverCreator:
|
||||||
retrievers = {
|
retrievers = {
|
||||||
'classic': ClassicRAG,
|
"classic": ClassicRAG,
|
||||||
'duckduck_search': DuckDuckSearch,
|
"default": ClassicRAG,
|
||||||
'brave_search': BraveRetSearch,
|
|
||||||
'default': ClassicRAG
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_retriever(cls, type, *args, **kwargs):
|
def create_retriever(cls, type, *args, **kwargs):
|
||||||
retiever_class = cls.retrievers.get(type.lower())
|
retriever_type = (type or "default").lower()
|
||||||
|
retiever_class = cls.retrievers.get(retriever_type)
|
||||||
if not retiever_class:
|
if not retiever_class:
|
||||||
raise ValueError(f"No retievers class found for type {type}")
|
raise ValueError(f"No retievers class found for type {type}")
|
||||||
return retiever_class(*args, **kwargs)
|
return retiever_class(*args, **kwargs)
|
||||||
|
|||||||
0
application/security/__init__.py
Normal file
0
application/security/__init__.py
Normal file
85
application/security/encryption.py
Normal file
85
application/security/encryption.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from cryptography.hazmat.primitives import hashes
|
||||||
|
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
|
||||||
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||||
|
|
||||||
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
def _derive_key(user_id: str, salt: bytes) -> bytes:
|
||||||
|
app_secret = settings.ENCRYPTION_SECRET_KEY
|
||||||
|
|
||||||
|
password = f"{app_secret}#{user_id}".encode()
|
||||||
|
|
||||||
|
kdf = PBKDF2HMAC(
|
||||||
|
algorithm=hashes.SHA256(),
|
||||||
|
length=32,
|
||||||
|
salt=salt,
|
||||||
|
iterations=100000,
|
||||||
|
backend=default_backend(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return kdf.derive(password)
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_credentials(credentials: dict, user_id: str) -> str:
|
||||||
|
if not credentials:
|
||||||
|
return ""
|
||||||
|
try:
|
||||||
|
salt = os.urandom(16)
|
||||||
|
iv = os.urandom(16)
|
||||||
|
key = _derive_key(user_id, salt)
|
||||||
|
|
||||||
|
json_str = json.dumps(credentials)
|
||||||
|
|
||||||
|
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||||
|
encryptor = cipher.encryptor()
|
||||||
|
|
||||||
|
padded_data = _pad_data(json_str.encode())
|
||||||
|
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
||||||
|
|
||||||
|
result = salt + iv + encrypted_data
|
||||||
|
return base64.b64encode(result).decode()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Failed to encrypt credentials: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_credentials(encrypted_data: str, user_id: str) -> dict:
|
||||||
|
if not encrypted_data:
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
data = base64.b64decode(encrypted_data.encode())
|
||||||
|
|
||||||
|
salt = data[:16]
|
||||||
|
iv = data[16:32]
|
||||||
|
encrypted_content = data[32:]
|
||||||
|
|
||||||
|
key = _derive_key(user_id, salt)
|
||||||
|
|
||||||
|
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||||
|
decryptor = cipher.decryptor()
|
||||||
|
|
||||||
|
decrypted_padded = decryptor.update(encrypted_content) + decryptor.finalize()
|
||||||
|
decrypted_data = _unpad_data(decrypted_padded)
|
||||||
|
|
||||||
|
return json.loads(decrypted_data.decode())
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Failed to decrypt credentials: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _pad_data(data: bytes) -> bytes:
|
||||||
|
block_size = 16
|
||||||
|
padding_len = block_size - (len(data) % block_size)
|
||||||
|
padding = bytes([padding_len]) * padding_len
|
||||||
|
return data + padding
|
||||||
|
|
||||||
|
|
||||||
|
def _unpad_data(data: bytes) -> bytes:
|
||||||
|
padding_len = data[-1]
|
||||||
|
return data[:-padding_len]
|
||||||
0
application/seed/__init__.py
Normal file
0
application/seed/__init__.py
Normal file
26
application/seed/commands.py
Normal file
26
application/seed/commands.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import click
|
||||||
|
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.seed.seeder import DatabaseSeeder
|
||||||
|
|
||||||
|
|
||||||
|
@click.group()
|
||||||
|
def seed():
|
||||||
|
"""Database seeding commands"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@seed.command()
|
||||||
|
@click.option("--force", is_flag=True, help="Force reseeding even if data exists")
|
||||||
|
def init(force):
|
||||||
|
"""Initialize database with seed data"""
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
|
|
||||||
|
seeder = DatabaseSeeder(db)
|
||||||
|
seeder.seed_initial_data(force=force)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
seed()
|
||||||
36
application/seed/config/agents_template.yaml
Normal file
36
application/seed/config/agents_template.yaml
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# Configuration for Premade Agents
|
||||||
|
# This file contains template agents that will be seeded into the database
|
||||||
|
|
||||||
|
agents:
|
||||||
|
# Basic Agent Template
|
||||||
|
- name: "Agent Name" # Required: Unique name for the agent
|
||||||
|
description: "What this agent does" # Required: Brief description of the agent's purpose
|
||||||
|
image: "URL_TO_IMAGE" # Optional: URL to agent's avatar/image
|
||||||
|
agent_type: "classic" # Required: Type of agent (e.g., classic, react, etc.)
|
||||||
|
prompt_id: "default" # Optional: Reference to prompt template
|
||||||
|
prompt: # Optional: Define new prompt
|
||||||
|
name: "New Prompt"
|
||||||
|
content: "You are new agent with cool new prompt."
|
||||||
|
chunks: "0" # Optional: Chunking strategy for documents
|
||||||
|
retriever: "" # Optional: Retriever type for document search
|
||||||
|
|
||||||
|
# Source Configuration (where the agent gets its knowledge)
|
||||||
|
source: # Optional: Select a source to link with agent
|
||||||
|
name: "Source Display Name" # Human-readable name for the source
|
||||||
|
url: "https://example.com/data-source" # URL or path to knowledge source
|
||||||
|
loader: "url" # Type of loader (url, pdf, txt, etc.)
|
||||||
|
|
||||||
|
# Tools Configuration (what capabilities the agent has)
|
||||||
|
tools: # Optional: Remove if agent doesn't need tools
|
||||||
|
- name: "tool_name" # Must match a supported tool name
|
||||||
|
display_name: "Tool Display Name" # Optional: Human-readable name for the tool
|
||||||
|
config:
|
||||||
|
# Tool-specific configuration
|
||||||
|
# Example for DuckDuckGo:
|
||||||
|
# token: "${DDG_API_KEY}" # ${} denotes environment variable
|
||||||
|
|
||||||
|
# Add more tools as needed
|
||||||
|
# - name: "another_tool"
|
||||||
|
# config:
|
||||||
|
# param1: "value1"
|
||||||
|
# param2: "${ENV_VAR}"
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user