mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
Compare commits
375 Commits
dependabot
...
hacktoberf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e6d64f71f2 | ||
|
|
e72313ebdd | ||
|
|
65d5bd72cd | ||
|
|
dc0cbb41f0 | ||
|
|
c4a54a85be | ||
|
|
c2ccf2c72c | ||
|
|
80aaecb5f0 | ||
|
|
946865a335 | ||
|
|
5de15c8413 | ||
|
|
67268fd35a | ||
|
|
42fc771833 | ||
|
|
4d34dc4234 | ||
|
|
d567399f2b | ||
|
|
77f4f8d8b0 | ||
|
|
a2d04beaa1 | ||
|
|
ba49eea23d | ||
|
|
82beafc086 | ||
|
|
7d8ed2d102 | ||
|
|
aab8d3a4f1 | ||
|
|
76658d50a0 | ||
|
|
88ba22342c | ||
|
|
11a1460af9 | ||
|
|
2cd4c41316 | ||
|
|
b910f308f2 | ||
|
|
763aa73ea4 | ||
|
|
30c79e92d4 | ||
|
|
402d5e054b | ||
|
|
0e211df206 | ||
|
|
e24a0ac686 | ||
|
|
8c91b1c527 | ||
|
|
2b38f80d04 | ||
|
|
282bd35f52 | ||
|
|
cc9b4c2bcb | ||
|
|
068ce4970a | ||
|
|
cf19165ad8 | ||
|
|
68c479f3a5 | ||
|
|
ba496a772b | ||
|
|
3b27db36f2 | ||
|
|
f803def69b | ||
|
|
52065e69a4 | ||
|
|
50f5e8a955 | ||
|
|
2d0e97b66d | ||
|
|
5f3cc5a392 | ||
|
|
ac66d77512 | ||
|
|
50cf653d4a | ||
|
|
56256051d2 | ||
|
|
c0361ff03d | ||
|
|
f153435c08 | ||
|
|
9aa7f22fa6 | ||
|
|
52b7bda5f8 | ||
|
|
21aefa2778 | ||
|
|
a89ff71c9e | ||
|
|
4c275816be | ||
|
|
f8dfbcfc80 | ||
|
|
d317f6473d | ||
|
|
00b4e133d4 | ||
|
|
b6349e4efb | ||
|
|
6ca3d9585c | ||
|
|
5935a0283a | ||
|
|
5400a6ec06 | ||
|
|
6574d9cc84 | ||
|
|
42b83c5994 | ||
|
|
896612a5a3 | ||
|
|
0ee875bee4 | ||
|
|
8ce345cd94 | ||
|
|
da2f8477e6 | ||
|
|
82b47b5673 | ||
|
|
3369b910b4 | ||
|
|
ec0c4c3b84 | ||
|
|
f74e2c9da1 | ||
|
|
e26ad3c475 | ||
|
|
145c3b8ad0 | ||
|
|
0ff6c6a154 | ||
|
|
641cf5a4c1 | ||
|
|
09b9576eef | ||
|
|
18b71ca2f2 | ||
|
|
e0eb7f456e | ||
|
|
188d118fc0 | ||
|
|
adcdce8d76 | ||
|
|
b865a7aec1 | ||
|
|
cec8c72b46 | ||
|
|
b052e32805 | ||
|
|
816f660be3 | ||
|
|
fc8be45d5a | ||
|
|
e749c936c9 | ||
|
|
b2b9670a23 | ||
|
|
2f88890c94 | ||
|
|
6366663f03 | ||
|
|
20fe7dc6d1 | ||
|
|
4b9153069e | ||
|
|
80406d0753 | ||
|
|
35f4c11784 | ||
|
|
7896526f19 | ||
|
|
f7db22edff | ||
|
|
0e4196f036 | ||
|
|
1bf6af6eeb | ||
|
|
5a9bc6d2bf | ||
|
|
f7f6042579 | ||
|
|
c4a598f3d3 | ||
|
|
7c23f43c63 | ||
|
|
7e2cbdd88c | ||
|
|
3b3a04a249 | ||
|
|
f9b2c95695 | ||
|
|
c2c18e8319 | ||
|
|
384ad3e0ac | ||
|
|
8c986aaa7f | ||
|
|
bb4ea76d30 | ||
|
|
2868e47cf8 | ||
|
|
e0adc3e5d5 | ||
|
|
e55d1a5865 | ||
|
|
018273c6b2 | ||
|
|
44b8a11c04 | ||
|
|
56e5aba559 | ||
|
|
46904ccd54 | ||
|
|
5b7c7a4471 | ||
|
|
9da4215d1f | ||
|
|
f39ac9945f | ||
|
|
a0cc2e4d46 | ||
|
|
4065041a9f | ||
|
|
f08067a161 | ||
|
|
545caacfa3 | ||
|
|
a06f646637 | ||
|
|
578c68205a | ||
|
|
f09f1433a9 | ||
|
|
15a9e97a1e | ||
|
|
b3af4ee50b | ||
|
|
07d59b6640 | ||
|
|
e25b988dc8 | ||
|
|
2410bd8654 | ||
|
|
44d21ab703 | ||
|
|
e283957c8f | ||
|
|
b1210c4902 | ||
|
|
e7430f0fbc | ||
|
|
92d6ae54c3 | ||
|
|
f82be23ca9 | ||
|
|
8c3f75e3e2 | ||
|
|
193d59f193 | ||
|
|
c2bebbaefa | ||
|
|
7ae5a9c5a5 | ||
|
|
3b69bea23d | ||
|
|
ab05726b99 | ||
|
|
b2b04268e9 | ||
|
|
bd73fa9ae7 | ||
|
|
927d10d66e | ||
|
|
b67329623c | ||
|
|
6f47aa802b | ||
|
|
3417c73011 | ||
|
|
6a02bcf15b | ||
|
|
cd0fbf79a3 | ||
|
|
15d2d0115b | ||
|
|
d1a0fe6e91 | ||
|
|
1db80d140f | ||
|
|
896dcf1f9e | ||
|
|
819a12fb49 | ||
|
|
c68273706c | ||
|
|
6bb0cd535a | ||
|
|
cb9ec69cf6 | ||
|
|
143854fa81 | ||
|
|
2f48a3d7d5 | ||
|
|
ec95dafe1e | ||
|
|
3d1fe724e5 | ||
|
|
5c615d6f2d | ||
|
|
d72558eb36 | ||
|
|
65c33ad915 | ||
|
|
9be128a963 | ||
|
|
eb05132008 | ||
|
|
f94a093e8c | ||
|
|
0d0c2daf64 | ||
|
|
823d948b25 | ||
|
|
56831fbcf2 | ||
|
|
bf49b9cb88 | ||
|
|
e01adffbad | ||
|
|
08a5d52d82 | ||
|
|
fdae235742 | ||
|
|
9903fad1e9 | ||
|
|
14bbd5338d | ||
|
|
4a236c2f6f | ||
|
|
0a8cdbd7f1 | ||
|
|
94c49843be | ||
|
|
9281fac898 | ||
|
|
0b2736f454 | ||
|
|
ae116b0d0d | ||
|
|
ba260e3382 | ||
|
|
1282e7687f | ||
|
|
b1d8266eef | ||
|
|
7acae6935b | ||
|
|
092c01cae7 | ||
|
|
56a1066c30 | ||
|
|
1356d71839 | ||
|
|
1eb011e8c3 | ||
|
|
e349eb28b0 | ||
|
|
b000b235a2 | ||
|
|
16fe92282e | ||
|
|
e218e88cf4 | ||
|
|
888ea81a32 | ||
|
|
735fab7640 | ||
|
|
45745c2a47 | ||
|
|
4caff0fcf6 | ||
|
|
762ea6ce7f | ||
|
|
8b4f6553f3 | ||
|
|
a61e44d175 | ||
|
|
e1b1558fc9 | ||
|
|
53225bda4e | ||
|
|
5212769848 | ||
|
|
d5ded3c9f4 | ||
|
|
c92d778894 | ||
|
|
829abd1ad6 | ||
|
|
266d256a07 | ||
|
|
8380cac3e7 | ||
|
|
a24652f901 | ||
|
|
2d203d3c70 | ||
|
|
48d21600da | ||
|
|
2508d0fbb3 | ||
|
|
e90e80c289 | ||
|
|
5e4748f9d9 | ||
|
|
212952f3e9 | ||
|
|
f99b6496c5 | ||
|
|
67423d51b9 | ||
|
|
58465ece65 | ||
|
|
8ede3a0173 | ||
|
|
ad2f0f8950 | ||
|
|
76973a4b4c | ||
|
|
b198e2e029 | ||
|
|
4d6ea401b5 | ||
|
|
b00c4cc3b6 | ||
|
|
4185e64c65 | ||
|
|
6eb2c884a2 | ||
|
|
6c0362a4cf | ||
|
|
50b1755a63 | ||
|
|
ff3c7eb5fb | ||
|
|
3755316d49 | ||
|
|
f952046847 | ||
|
|
969cdb4a63 | ||
|
|
f336d44595 | ||
|
|
a53f93c195 | ||
|
|
fcb334ce33 | ||
|
|
8ddf04a904 | ||
|
|
29698ca169 | ||
|
|
a9baf7436a | ||
|
|
99a8962183 | ||
|
|
afc5b15a6b | ||
|
|
b6ab508e27 | ||
|
|
789e65557a | ||
|
|
8a7806ab2d | ||
|
|
493303e103 | ||
|
|
1d9af05e9e | ||
|
|
5b07c5f2e8 | ||
|
|
2a4ec0cf5b | ||
|
|
a00c44386e | ||
|
|
a38d71bbfb | ||
|
|
a24a3f868c | ||
|
|
f60c516185 | ||
|
|
26f4646304 | ||
|
|
3a351f67e6 | ||
|
|
e7c09cb91e | ||
|
|
ae1a6ef303 | ||
|
|
2ff477a339 | ||
|
|
793f3fb683 | ||
|
|
a472ee7602 | ||
|
|
c62040e232 | ||
|
|
2e7cb510ae | ||
|
|
dbe45904d7 | ||
|
|
5623734276 | ||
|
|
d3b592bffc | ||
|
|
4fcbdae5bf | ||
|
|
ca95d7275a | ||
|
|
61baf3701c | ||
|
|
bbce872ac5 | ||
|
|
0f7ebcd8e4 | ||
|
|
82fc19e7b7 | ||
|
|
839a12bed4 | ||
|
|
2ef23fe1b3 | ||
|
|
fd905b1a06 | ||
|
|
1372210004 | ||
|
|
ade704d065 | ||
|
|
42f48649b9 | ||
|
|
0b08e8b617 | ||
|
|
926b2f1a1b | ||
|
|
1770a1a45f | ||
|
|
50ed2a64c6 | ||
|
|
2332344988 | ||
|
|
7ccc8cdc58 | ||
|
|
ecec9f913e | ||
|
|
777f40fc5e | ||
|
|
327ae35420 | ||
|
|
0d48159da8 | ||
|
|
d36f12a4ea | ||
|
|
709488beb1 | ||
|
|
a9e4583695 | ||
|
|
4702dec933 | ||
|
|
e6352dd691 | ||
|
|
240ea3b857 | ||
|
|
f0908af3c0 | ||
|
|
6834961dd1 | ||
|
|
b404162364 | ||
|
|
e879ef805f | ||
|
|
7077ca5e98 | ||
|
|
a1e6978c8f | ||
|
|
584391dd59 | ||
|
|
bab3ae809c | ||
|
|
c78518baf0 | ||
|
|
556d7e0497 | ||
|
|
2d27936dab | ||
|
|
0cc22de545 | ||
|
|
63f6127049 | ||
|
|
f34e00c986 | ||
|
|
55f60a9fe1 | ||
|
|
7da3618e0c | ||
|
|
56bfa98633 | ||
|
|
96f6188722 | ||
|
|
aa9d359039 | ||
|
|
cef5731028 | ||
|
|
5bc28bd4fd | ||
|
|
55a1d867c3 | ||
|
|
6c3a79802e | ||
|
|
c35c5e0793 | ||
|
|
7bc83caa99 | ||
|
|
3aceca63c6 | ||
|
|
9bc166ffd4 | ||
|
|
fc01b90007 | ||
|
|
e35f1d70e4 | ||
|
|
cab1f3787a | ||
|
|
bb42f4cbc1 | ||
|
|
98dc418a51 | ||
|
|
322b4eb18c | ||
|
|
7f1cc30ed8 | ||
|
|
7b45a6b956 | ||
|
|
e36769e70f | ||
|
|
bd4a4cc4af | ||
|
|
8343fe63cb | ||
|
|
7d89fb8461 | ||
|
|
098955d230 | ||
|
|
d254d14928 | ||
|
|
0a3e8ca535 | ||
|
|
b8a10e0962 | ||
|
|
0aceda96e4 | ||
|
|
44b6ec25a2 | ||
|
|
1b84d1fa9d | ||
|
|
78d5ed2ed2 | ||
|
|
142477ab9b | ||
|
|
b414f79bc5 | ||
|
|
6e08fe21d0 | ||
|
|
9b839655a7 | ||
|
|
3353c0ee1d | ||
|
|
aaecf52c99 | ||
|
|
8b3e960be0 | ||
|
|
3351f71813 | ||
|
|
7490256303 | ||
|
|
041d600e45 | ||
|
|
b4e2588a24 | ||
|
|
68dc14c5a1 | ||
|
|
ef35864e16 | ||
|
|
c0d385b983 | ||
|
|
b2df431fa4 | ||
|
|
69a4bd415a | ||
|
|
4862548e65 | ||
|
|
50248cc9ea | ||
|
|
430822bae3 | ||
|
|
dd9d18208d | ||
|
|
e5b1a71659 | ||
|
|
35f4b13237 | ||
|
|
5f5c31cd5b | ||
|
|
e9530d5ec5 | ||
|
|
143f4aa886 | ||
|
|
ece5c8bb31 | ||
|
|
3bae30c70c | ||
|
|
12b18c6bd1 | ||
|
|
787d9e3bf5 | ||
|
|
f325b54895 | ||
|
|
c5616705b0 | ||
|
|
e90fe117ec | ||
|
|
7cab5b3b09 | ||
|
|
9f911cb5cb | ||
|
|
3da7cba06c | ||
|
|
92c3c707e1 |
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
# Auto detect text files and perform LF normalization
|
||||
* text=auto
|
||||
33
.vscode/launch.json
vendored
33
.vscode/launch.json
vendored
@@ -2,15 +2,11 @@
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Docker Debug Frontend",
|
||||
"name": "Frontend Debug (npm)",
|
||||
"type": "node-terminal",
|
||||
"request": "launch",
|
||||
"type": "chrome",
|
||||
"preLaunchTask": "docker-compose: debug:frontend",
|
||||
"url": "http://127.0.0.1:5173",
|
||||
"webRoot": "${workspaceFolder}/frontend",
|
||||
"skipFiles": [
|
||||
"<node_internals>/**"
|
||||
]
|
||||
"command": "npm run dev",
|
||||
"cwd": "${workspaceFolder}/frontend"
|
||||
},
|
||||
{
|
||||
"name": "Flask Debugger",
|
||||
@@ -49,6 +45,27 @@
|
||||
"--pool=solo"
|
||||
],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"name": "Dev Containers (Mongo + Redis)",
|
||||
"type": "node-terminal",
|
||||
"request": "launch",
|
||||
"command": "docker compose -f deployment/docker-compose-dev.yaml up --build",
|
||||
"cwd": "${workspaceFolder}"
|
||||
}
|
||||
],
|
||||
"compounds": [
|
||||
{
|
||||
"name": "DocsGPT: Full Stack",
|
||||
"configurations": [
|
||||
"Frontend Debug (npm)",
|
||||
"Flask Debugger",
|
||||
"Celery Debugger"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "DocsGPT",
|
||||
"order": 1
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
21
.vscode/tasks.json
vendored
21
.vscode/tasks.json
vendored
@@ -1,21 +0,0 @@
|
||||
{
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"type": "docker-compose",
|
||||
"label": "docker-compose: debug:frontend",
|
||||
"dockerCompose": {
|
||||
"up": {
|
||||
"detached": true,
|
||||
"services": [
|
||||
"frontend"
|
||||
],
|
||||
"build": true
|
||||
},
|
||||
"files": [
|
||||
"${workspaceFolder}/docker-compose.yaml"
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
38
HACKTOBERFEST.md
Normal file
38
HACKTOBERFEST.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# **🎉 Join the Hacktoberfest with DocsGPT and win a Free T-shirt for a meaningful PR! 🎉**
|
||||
|
||||
Welcome, contributors! We're excited to announce that DocsGPT is participating in Hacktoberfest. Get involved by submitting meaningful pull requests.
|
||||
|
||||
All Meaningful contributors with accepted PRs that were created for issues with the `hacktoberfest` label (set by our maintainer team: dartpain, siiddhantt, pabik, ManishMadan2882) will receive a cool T-shirt! 🤩.
|
||||
|
||||
Fill in [this form](https://forms.gle/Npaba4n9Epfyx56S8
|
||||
) after your PR was merged please
|
||||
|
||||
If you are in doubt don't hesitate to ping us on discord, ping me - Alex (dartpain).
|
||||
|
||||
## 📜 Here's How to Contribute:
|
||||
```text
|
||||
🛠️ Code: This is the golden ticket! Make meaningful contributions through PRs.
|
||||
|
||||
🧩 API extension: Build an app utilising DocsGPT API. We prefer submissions that showcase original ideas and turn the API into an AI agent.
|
||||
They can be a completely separate repos.
|
||||
For example:
|
||||
https://github.com/arc53/tg-bot-docsgpt-extenstion or
|
||||
https://github.com/arc53/DocsGPT-cli
|
||||
|
||||
Non-Code Contributions:
|
||||
|
||||
📚 Wiki: Improve our documentation, create a guide.
|
||||
|
||||
🖥️ Design: Improve the UI/UX or design a new feature.
|
||||
```
|
||||
|
||||
### 📝 Guidelines for Pull Requests:
|
||||
- Familiarize yourself with the current contributions and our [Roadmap](https://github.com/orgs/arc53/projects/2).
|
||||
- Before contributing check existing [issues](https://github.com/arc53/DocsGPT/issues) or [create](https://github.com/arc53/DocsGPT/issues/new/choose) an issue and wait to get assigned.
|
||||
- Once you are finished with your contribution, please fill in this [form](https://forms.gle/Npaba4n9Epfyx56S8).
|
||||
- Refer to the [Documentation](https://docs.docsgpt.cloud/).
|
||||
- Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU) server. We're here to help newcomers, so don't hesitate to jump in! Join us [here](https://discord.gg/n5BX8dh8rU).
|
||||
|
||||
Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your contributions (not just simple typos) could earn you a stylish new t-shirt.
|
||||
|
||||
We will publish a t-shirt desing later into the October.
|
||||
39
README.md
39
README.md
@@ -3,11 +3,11 @@
|
||||
</h1>
|
||||
|
||||
<p align="center">
|
||||
<strong>Open-Source RAG Assistant</strong>
|
||||
<strong>Private AI for agents, assistants and enterprise search</strong>
|
||||
</p>
|
||||
|
||||
<p align="left">
|
||||
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source genAI tool that helps users get reliable answers from any knowledge source, while avoiding hallucinations. It enables quick and reliable information retrieval, with tooling and agentic system capability built in.
|
||||
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
@@ -19,13 +19,23 @@
|
||||
<a href="https://discord.gg/n5BX8dh8rU"></a>
|
||||
<a href="https://twitter.com/docsgptai"></a>
|
||||
|
||||
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
|
||||
<br>
|
||||
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a> • <a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a> • <a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
|
||||
<br>
|
||||
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
|
||||
<br>
|
||||
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a> • <a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a> • <a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
|
||||
<br>
|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
<br>
|
||||
🎃 <a href="https://github.com/arc53/DocsGPT/blob/main/HACKTOBERFEST.md"> Hacktoberfest Prizes, Rules & Q&A </a> 🎃
|
||||
<br>
|
||||
<br>
|
||||
</div>
|
||||
|
||||
|
||||
<div align="center">
|
||||
<br>
|
||||
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
|
||||
</div>
|
||||
<h3 align="left">
|
||||
@@ -52,8 +62,14 @@
|
||||
- [x] Chatbots menu re-design to handle tools, agent types, and more (April 2025)
|
||||
- [x] New input box in the conversation menu (April 2025)
|
||||
- [x] Add triggerable actions / tools (webhook) (April 2025)
|
||||
- [ ] Anthropic Tool compatibility (May 2025)
|
||||
- [ ] Add OAuth 2.0 authentication for tools and sources
|
||||
- [x] Agent optimisations (May 2025)
|
||||
- [x] Filesystem sources update (July 2025)
|
||||
- [x] Json Responses (August 2025)
|
||||
- [x] MCP support (August 2025)
|
||||
- [x] Google Drive integration (September 2025)
|
||||
- [ ] Add OAuth 2.0 authentication for MCP (September 2025)
|
||||
- [ ] Sharepoint integration (October 2025)
|
||||
- [ ] Deep Agents (October 2025)
|
||||
- [ ] Agent scheduling
|
||||
|
||||
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
|
||||
@@ -68,11 +84,10 @@ We're eager to provide personalized assistance when deploying your DocsGPT to a
|
||||
|
||||
## Join the Lighthouse Program 🌟
|
||||
|
||||
Calling all developers and GenAI innovators! The **DocsGPT Lighthouse Program** connects technical leaders actively deploying or extending DocsGPT in real-world scenarios. Collaborate directly with our team to shape the roadmap, access priority support, and build enterprise-ready solutions with exclusive community insights.
|
||||
Calling all developers and GenAI innovators! The **DocsGPT Lighthouse Program** connects technical leaders actively deploying or extending DocsGPT in real-world scenarios. Collaborate directly with our team to shape the roadmap, access priority support, and build enterprise-ready solutions with exclusive community insights.
|
||||
|
||||
[Learn More & Apply →](https://docs.google.com/forms/d/1KAADiJinUJ8EMQyfTXUIGyFbqINNClNR3jBNWq7DgTE)
|
||||
|
||||
|
||||
## QuickStart
|
||||
|
||||
> [!Note]
|
||||
@@ -103,7 +118,7 @@ A more detailed [Quickstart](https://docs.docsgpt.cloud/quickstart) is available
|
||||
PowerShell -ExecutionPolicy Bypass -File .\setup.ps1
|
||||
```
|
||||
|
||||
Either script will guide you through setting up DocsGPT. Four options available: using the public API, running locally, connecting to a local inference engine, or using a cloud API provider. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
|
||||
Either script will guide you through setting up DocsGPT. Four options available: using the public API, running locally, connecting to a local inference engine, or using a cloud API provider. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
|
||||
|
||||
**Navigate to http://localhost:5173/**
|
||||
|
||||
@@ -112,6 +127,7 @@ To stop DocsGPT, open a terminal in the `DocsGPT` directory and run:
|
||||
```bash
|
||||
docker compose -f deployment/docker-compose.yaml down
|
||||
```
|
||||
|
||||
(or use the specific `docker compose down` command shown after running the setup script).
|
||||
|
||||
> [!Note]
|
||||
@@ -139,7 +155,6 @@ Please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file for information abou
|
||||
|
||||
We as members, contributors, and leaders, pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. Please refer to the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file for more information about contributing.
|
||||
|
||||
|
||||
## Many Thanks To Our Contributors⚡
|
||||
|
||||
<a href="https://github.com/arc53/DocsGPT/graphs/contributors" alt="View Contributors">
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Generator, List, Optional
|
||||
|
||||
from application.agents.llm_handler import get_llm_handler
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.logging import build_stack_data, log_activity, LogContext
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.core.settings import settings
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
@@ -26,6 +29,7 @@ class BaseAgent(ABC):
|
||||
chat_history: Optional[List[Dict]] = None,
|
||||
decoded_token: Optional[Dict] = None,
|
||||
attachments: Optional[List[Dict]] = None,
|
||||
json_schema: Optional[Dict] = None,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.llm_name = llm_name
|
||||
@@ -45,8 +49,11 @@ class BaseAgent(ABC):
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
self.llm_handler = get_llm_handler(llm_name)
|
||||
self.llm_handler = LLMHandlerCreator.create_handler(
|
||||
llm_name if llm_name else "default"
|
||||
)
|
||||
self.attachments = attachments or []
|
||||
self.json_schema = json_schema
|
||||
|
||||
@log_activity()
|
||||
def gen(
|
||||
@@ -87,8 +94,8 @@ class BaseAgent(ABC):
|
||||
user_tools_collection = db["user_tools"]
|
||||
user_tools = user_tools_collection.find({"user": user, "status": True})
|
||||
user_tools = list(user_tools)
|
||||
tools_by_id = {str(tool["_id"]): tool for tool in user_tools}
|
||||
return tools_by_id
|
||||
|
||||
return {str(i): tool for i, tool in enumerate(user_tools)}
|
||||
|
||||
def _build_tool_parameters(self, action):
|
||||
params = {"type": "object", "properties": {}, "required": []}
|
||||
@@ -132,6 +139,49 @@ class BaseAgent(ABC):
|
||||
parser = ToolActionParser(self.llm.__class__.__name__)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||
|
||||
# Check if parsing failed
|
||||
if tool_id is None or action_name is None:
|
||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
|
||||
logger.error(error_message)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
"call_id": call_id,
|
||||
"action_name": getattr(call, "name", "unknown"),
|
||||
"arguments": call_args or {},
|
||||
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return "Failed to parse tool call.", call_id
|
||||
|
||||
# Check if tool_id exists in available tools
|
||||
if tool_id not in tools_dict:
|
||||
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
|
||||
logger.error(error_message)
|
||||
|
||||
# Return error result
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
"call_id": call_id,
|
||||
"action_name": f"{action_name}_{tool_id}",
|
||||
"arguments": call_args,
|
||||
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return f"Tool with ID {tool_id} not found.", call_id
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": tools_dict[tool_id]["name"],
|
||||
"call_id": call_id,
|
||||
"action_name": f"{action_name}_{tool_id}",
|
||||
"arguments": call_args,
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
action_data = (
|
||||
tool_data["config"]["actions"][action_name]
|
||||
@@ -175,6 +225,7 @@ class BaseAgent(ABC):
|
||||
if tool_data["name"] == "api_tool"
|
||||
else tool_data["config"]
|
||||
),
|
||||
user_id=self.user, # Pass user ID for MCP tools credential decryption
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
print(
|
||||
@@ -184,26 +235,44 @@ class BaseAgent(ABC):
|
||||
else:
|
||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||
result = tool.execute_action(action_name, **parameters)
|
||||
call_id = getattr(call, "id", None)
|
||||
tool_call_data["result"] = (
|
||||
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
|
||||
)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": tool_data["name"],
|
||||
"call_id": call_id if call_id is not None else "None",
|
||||
"action_name": f"{action_name}_{tool_id}",
|
||||
"arguments": call_args,
|
||||
"result": result,
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
|
||||
return result, call_id
|
||||
|
||||
def _get_truncated_tool_calls(self):
|
||||
return [
|
||||
{
|
||||
**tool_call,
|
||||
"result": (
|
||||
f"{str(tool_call['result'])[:50]}..."
|
||||
if len(str(tool_call["result"])) > 50
|
||||
else tool_call["result"]
|
||||
),
|
||||
"status": "completed",
|
||||
}
|
||||
for tool_call in self.tool_calls
|
||||
]
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
system_prompt: str,
|
||||
query: str,
|
||||
retrieved_data: List[Dict],
|
||||
) -> List[Dict]:
|
||||
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
|
||||
docs_with_filenames = []
|
||||
for doc in retrieved_data:
|
||||
filename = doc.get("filename") or doc.get("title") or doc.get("source")
|
||||
if filename:
|
||||
chunk_header = str(filename)
|
||||
docs_with_filenames.append(f"{chunk_header}\n{doc['text']}")
|
||||
else:
|
||||
docs_with_filenames.append(doc["text"])
|
||||
docs_together = "\n\n".join(docs_with_filenames)
|
||||
p_chat_combine = system_prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
|
||||
@@ -252,9 +321,31 @@ class BaseAgent(ABC):
|
||||
return retrieved_data
|
||||
|
||||
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
|
||||
resp = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages, tools=self.tools
|
||||
)
|
||||
gen_kwargs = {"model": self.gpt_model, "messages": messages}
|
||||
|
||||
if (
|
||||
hasattr(self.llm, "_supports_tools")
|
||||
and self.llm._supports_tools
|
||||
and self.tools
|
||||
):
|
||||
gen_kwargs["tools"] = self.tools
|
||||
|
||||
if (
|
||||
self.json_schema
|
||||
and hasattr(self.llm, "_supports_structured_output")
|
||||
and self.llm._supports_structured_output()
|
||||
):
|
||||
structured_format = self.llm.prepare_structured_output_format(
|
||||
self.json_schema
|
||||
)
|
||||
if structured_format:
|
||||
if self.llm_name == "openai":
|
||||
gen_kwargs["response_format"] = structured_format
|
||||
elif self.llm_name == "google":
|
||||
gen_kwargs["response_schema"] = structured_format
|
||||
|
||||
resp = self.llm.gen_stream(**gen_kwargs)
|
||||
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm, exclude_attributes=["client"])
|
||||
log_context.stacks.append({"component": "llm", "data": data})
|
||||
@@ -268,10 +359,51 @@ class BaseAgent(ABC):
|
||||
log_context: Optional[LogContext] = None,
|
||||
attachments: Optional[List[Dict]] = None,
|
||||
):
|
||||
resp = self.llm_handler.handle_response(
|
||||
self, resp, tools_dict, messages, attachments
|
||||
resp = self.llm_handler.process_message_flow(
|
||||
self, resp, tools_dict, messages, attachments, True
|
||||
)
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm_handler, exclude_attributes=["tool_calls"])
|
||||
log_context.stacks.append({"component": "llm_handler", "data": data})
|
||||
return resp
|
||||
|
||||
def _handle_response(self, response, tools_dict, messages, log_context):
|
||||
is_structured_output = (
|
||||
self.json_schema is not None
|
||||
and hasattr(self.llm, "_supports_structured_output")
|
||||
and self.llm._supports_structured_output()
|
||||
)
|
||||
|
||||
if isinstance(response, str):
|
||||
answer_data = {"answer": response}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
return
|
||||
if hasattr(response, "message") and getattr(response.message, "content", None):
|
||||
answer_data = {"answer": response.message.content}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
return
|
||||
processed_response_gen = self._llm_handler(
|
||||
response, tools_dict, messages, log_context, self.attachments
|
||||
)
|
||||
|
||||
for event in processed_response_gen:
|
||||
if isinstance(event, str):
|
||||
answer_data = {"answer": event}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
elif hasattr(event, "message") and getattr(event.message, "content", None):
|
||||
answer_data = {"answer": event.message.content}
|
||||
if is_structured_output:
|
||||
answer_data["structured"] = True
|
||||
answer_data["schema"] = self.json_schema
|
||||
yield answer_data
|
||||
elif isinstance(event, dict) and "type" in event:
|
||||
yield event
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from typing import Dict, Generator
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.logging import LogContext
|
||||
|
||||
from application.retriever.base import BaseRetriever
|
||||
import logging
|
||||
|
||||
@@ -10,55 +8,46 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClassicAgent(BaseAgent):
|
||||
"""A simplified agent with clear execution flow.
|
||||
|
||||
Usage:
|
||||
1. Processes a query through retrieval
|
||||
2. Sets up available tools
|
||||
3. Generates responses using LLM
|
||||
4. Handles tool interactions if needed
|
||||
5. Returns standardized outputs
|
||||
|
||||
Easy to extend by overriding specific steps.
|
||||
"""
|
||||
|
||||
def _gen_inner(
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
# Step 1: Retrieve relevant data
|
||||
retrieved_data = self._retriever_search(retriever, query, log_context)
|
||||
if self.user_api_key:
|
||||
tools_dict = self._get_tools(self.user_api_key)
|
||||
else:
|
||||
tools_dict = self._get_user_tools(self.user)
|
||||
|
||||
# Step 2: Prepare tools
|
||||
tools_dict = (
|
||||
self._get_user_tools(self.user)
|
||||
if not self.user_api_key
|
||||
else self._get_tools(self.user_api_key)
|
||||
)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
# Step 3: Build and process messages
|
||||
messages = self._build_messages(self.prompt, query, retrieved_data)
|
||||
llm_response = self._llm_gen(messages, log_context)
|
||||
|
||||
resp = self._llm_gen(messages, log_context)
|
||||
# Step 4: Handle the response
|
||||
yield from self._handle_response(
|
||||
llm_response, tools_dict, messages, log_context
|
||||
)
|
||||
|
||||
attachments = self.attachments
|
||||
|
||||
if isinstance(resp, str):
|
||||
yield {"answer": resp}
|
||||
return
|
||||
if (
|
||||
hasattr(resp, "message")
|
||||
and hasattr(resp.message, "content")
|
||||
and resp.message.content is not None
|
||||
):
|
||||
yield {"answer": resp.message.content}
|
||||
return
|
||||
|
||||
resp = self._llm_handler(resp, tools_dict, messages, log_context, attachments)
|
||||
|
||||
if isinstance(resp, str):
|
||||
yield {"answer": resp}
|
||||
elif (
|
||||
hasattr(resp, "message")
|
||||
and hasattr(resp.message, "content")
|
||||
and resp.message.content is not None
|
||||
):
|
||||
yield {"answer": resp.message.content}
|
||||
else:
|
||||
for line in resp:
|
||||
if isinstance(line, str):
|
||||
yield {"answer": line}
|
||||
# Step 5: Return metadata
|
||||
yield {"sources": retrieved_data}
|
||||
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
|
||||
# Log tool calls for debugging
|
||||
log_context.stacks.append(
|
||||
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||
)
|
||||
|
||||
yield {"sources": retrieved_data}
|
||||
# clean tool_call_data only send first 50 characters of tool_call['result']
|
||||
for tool_call in self.tool_calls:
|
||||
if len(str(tool_call["result"])) > 50:
|
||||
tool_call["result"] = str(tool_call["result"])[:50] + "..."
|
||||
yield {"tool_calls": self.tool_calls.copy()}
|
||||
|
||||
@@ -1,351 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from application.logging import build_stack_data
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMHandler(ABC):
|
||||
def __init__(self):
|
||||
self.llm_calls = []
|
||||
self.tool_calls = []
|
||||
|
||||
@abstractmethod
|
||||
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, **kwargs):
|
||||
pass
|
||||
|
||||
def prepare_messages_with_attachments(self, agent, messages, attachments=None):
|
||||
"""
|
||||
Prepare messages with attachment content if available.
|
||||
|
||||
Args:
|
||||
agent: The current agent instance.
|
||||
messages (list): List of message dictionaries.
|
||||
attachments (list): List of attachment dictionaries with content.
|
||||
|
||||
Returns:
|
||||
list: Messages with attachment context added to the system prompt.
|
||||
"""
|
||||
if not attachments:
|
||||
return messages
|
||||
|
||||
logger.info(f"Preparing messages with {len(attachments)} attachments")
|
||||
|
||||
supported_types = agent.llm.get_supported_attachment_types()
|
||||
|
||||
supported_attachments = []
|
||||
unsupported_attachments = []
|
||||
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.get('mime_type')
|
||||
if mime_type in supported_types:
|
||||
supported_attachments.append(attachment)
|
||||
else:
|
||||
unsupported_attachments.append(attachment)
|
||||
|
||||
# Process supported attachments with the LLM's custom method
|
||||
prepared_messages = messages
|
||||
if supported_attachments:
|
||||
logger.info(f"Processing {len(supported_attachments)} supported attachments with {agent.llm.__class__.__name__}'s method")
|
||||
prepared_messages = agent.llm.prepare_messages_with_attachments(messages, supported_attachments)
|
||||
|
||||
# Process unsupported attachments with the default method
|
||||
if unsupported_attachments:
|
||||
logger.info(f"Processing {len(unsupported_attachments)} unsupported attachments with default method")
|
||||
prepared_messages = self._append_attachment_content_to_system(prepared_messages, unsupported_attachments)
|
||||
|
||||
return prepared_messages
|
||||
|
||||
def _append_attachment_content_to_system(self, messages, attachments):
|
||||
"""
|
||||
Default method to append attachment content to the system prompt.
|
||||
|
||||
Args:
|
||||
messages (list): List of message dictionaries.
|
||||
attachments (list): List of attachment dictionaries with content.
|
||||
|
||||
Returns:
|
||||
list: Messages with attachment context added to the system prompt.
|
||||
"""
|
||||
prepared_messages = messages.copy()
|
||||
|
||||
attachment_texts = []
|
||||
for attachment in attachments:
|
||||
logger.info(f"Adding attachment {attachment.get('id')} to context")
|
||||
if 'content' in attachment:
|
||||
attachment_texts.append(f"Attached file content:\n\n{attachment['content']}")
|
||||
|
||||
if attachment_texts:
|
||||
combined_attachment_text = "\n\n".join(attachment_texts)
|
||||
|
||||
system_found = False
|
||||
for i in range(len(prepared_messages)):
|
||||
if prepared_messages[i].get("role") == "system":
|
||||
prepared_messages[i]["content"] += f"\n\n{combined_attachment_text}"
|
||||
system_found = True
|
||||
break
|
||||
|
||||
if not system_found:
|
||||
prepared_messages.insert(0, {"role": "system", "content": combined_attachment_text})
|
||||
|
||||
return prepared_messages
|
||||
|
||||
class OpenAILLMHandler(LLMHandler):
|
||||
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True):
|
||||
|
||||
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
|
||||
logger.info(f"Messages with attachments: {messages}")
|
||||
if not stream:
|
||||
while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls":
|
||||
message = json.loads(resp.model_dump_json())["message"]
|
||||
keys_to_remove = {"audio", "function_call", "refusal"}
|
||||
filtered_data = {
|
||||
k: v for k, v in message.items() if k not in keys_to_remove
|
||||
}
|
||||
messages.append(filtered_data)
|
||||
|
||||
tool_calls = resp.message.tool_calls
|
||||
for call in tool_calls:
|
||||
try:
|
||||
self.tool_calls.append(call)
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, call
|
||||
)
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": call.function.name,
|
||||
"args": call.function.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": call.function.name,
|
||||
"response": {"result": tool_response},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
messages.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
|
||||
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
|
||||
except Exception as e:
|
||||
logging.error(f"Error executing tool: {str(e)}", exc_info=True)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": f"Error executing tool: {str(e)}",
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
)
|
||||
resp = agent.llm.gen_stream(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
return resp
|
||||
|
||||
else:
|
||||
text_buffer = ""
|
||||
while True:
|
||||
tool_calls = {}
|
||||
for chunk in resp:
|
||||
if isinstance(chunk, str) and len(chunk) > 0:
|
||||
yield chunk
|
||||
continue
|
||||
elif hasattr(chunk, "delta"):
|
||||
chunk_delta = chunk.delta
|
||||
|
||||
if (
|
||||
hasattr(chunk_delta, "tool_calls")
|
||||
and chunk_delta.tool_calls is not None
|
||||
):
|
||||
for tool_call in chunk_delta.tool_calls:
|
||||
index = tool_call.index
|
||||
if index not in tool_calls:
|
||||
tool_calls[index] = {
|
||||
"id": "",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
|
||||
current = tool_calls[index]
|
||||
if tool_call.id:
|
||||
current["id"] = tool_call.id
|
||||
if tool_call.function.name:
|
||||
current["function"][
|
||||
"name"
|
||||
] = tool_call.function.name
|
||||
if tool_call.function.arguments:
|
||||
current["function"][
|
||||
"arguments"
|
||||
] += tool_call.function.arguments
|
||||
tool_calls[index] = current
|
||||
|
||||
if (
|
||||
hasattr(chunk, "finish_reason")
|
||||
and chunk.finish_reason == "tool_calls"
|
||||
):
|
||||
for index in sorted(tool_calls.keys()):
|
||||
call = tool_calls[index]
|
||||
try:
|
||||
self.tool_calls.append(call)
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, call
|
||||
)
|
||||
if isinstance(call["function"]["arguments"], str):
|
||||
call["function"]["arguments"] = json.loads(call["function"]["arguments"])
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": call["function"]["name"],
|
||||
"args": call["function"]["arguments"],
|
||||
"call_id": call["id"],
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": call["function"]["name"],
|
||||
"response": {"result": tool_response},
|
||||
"call_id": call["id"],
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [function_call_dict],
|
||||
}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [function_response_dict],
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error executing tool: {str(e)}", exc_info=True)
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Error executing tool: {str(e)}",
|
||||
}
|
||||
)
|
||||
tool_calls = {}
|
||||
if hasattr(chunk_delta, "content") and chunk_delta.content:
|
||||
# Add to buffer or yield immediately based on your preference
|
||||
text_buffer += chunk_delta.content
|
||||
yield text_buffer
|
||||
text_buffer = ""
|
||||
|
||||
if (
|
||||
hasattr(chunk, "finish_reason")
|
||||
and chunk.finish_reason == "stop"
|
||||
):
|
||||
return resp
|
||||
elif isinstance(chunk, str) and len(chunk) == 0:
|
||||
continue
|
||||
|
||||
logger.info(f"Regenerating with messages: {messages}")
|
||||
resp = agent.llm.gen_stream(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
|
||||
class GoogleLLMHandler(LLMHandler):
|
||||
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True):
|
||||
from google.genai import types
|
||||
|
||||
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
|
||||
|
||||
while True:
|
||||
if not stream:
|
||||
response = agent.llm.gen(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
if response.candidates and response.candidates[0].content.parts:
|
||||
tool_call_found = False
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part.function_call:
|
||||
tool_call_found = True
|
||||
self.tool_calls.append(part.function_call)
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, part.function_call
|
||||
)
|
||||
function_response_part = types.Part.from_function_response(
|
||||
name=part.function_call.name,
|
||||
response={"result": tool_response},
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{"role": "model", "content": [part.to_json_dict()]}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [function_response_part.to_json_dict()],
|
||||
}
|
||||
)
|
||||
|
||||
if (
|
||||
not tool_call_found
|
||||
and response.candidates[0].content.parts
|
||||
and response.candidates[0].content.parts[0].text
|
||||
):
|
||||
return response.candidates[0].content.parts[0].text
|
||||
elif not tool_call_found:
|
||||
return response.candidates[0].content.parts
|
||||
|
||||
else:
|
||||
return response
|
||||
|
||||
else:
|
||||
response = agent.llm.gen_stream(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
tool_call_found = False
|
||||
for result in response:
|
||||
if hasattr(result, "function_call"):
|
||||
tool_call_found = True
|
||||
self.tool_calls.append(result.function_call)
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, result.function_call
|
||||
)
|
||||
function_response_part = types.Part.from_function_response(
|
||||
name=result.function_call.name,
|
||||
response={"result": tool_response},
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{"role": "model", "content": [result.to_json_dict()]}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [function_response_part.to_json_dict()],
|
||||
}
|
||||
)
|
||||
else:
|
||||
tool_call_found = False
|
||||
yield result
|
||||
|
||||
if not tool_call_found:
|
||||
return response
|
||||
|
||||
|
||||
def get_llm_handler(llm_type):
|
||||
handlers = {
|
||||
"openai": OpenAILLMHandler(),
|
||||
"google": GoogleLLMHandler(),
|
||||
}
|
||||
return handlers.get(llm_type, OpenAILLMHandler())
|
||||
@@ -25,27 +25,35 @@ class BraveSearchTool(Tool):
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
|
||||
def _web_search(self, query, country="ALL", search_lang="en", count=10,
|
||||
offset=0, safesearch="off", freshness=None,
|
||||
result_filter=None, extra_snippets=False, summary=False):
|
||||
def _web_search(
|
||||
self,
|
||||
query,
|
||||
country="ALL",
|
||||
search_lang="en",
|
||||
count=10,
|
||||
offset=0,
|
||||
safesearch="off",
|
||||
freshness=None,
|
||||
result_filter=None,
|
||||
extra_snippets=False,
|
||||
summary=False,
|
||||
):
|
||||
"""
|
||||
Performs a web search using the Brave Search API.
|
||||
"""
|
||||
print(f"Performing Brave web search for: {query}")
|
||||
|
||||
|
||||
url = f"{self.base_url}/web/search"
|
||||
|
||||
# Build query parameters
|
||||
|
||||
params = {
|
||||
"q": query,
|
||||
"country": country,
|
||||
"search_lang": search_lang,
|
||||
"count": min(count, 20),
|
||||
"offset": min(offset, 9),
|
||||
"safesearch": safesearch
|
||||
"safesearch": safesearch,
|
||||
}
|
||||
|
||||
# Add optional parameters only if they have values
|
||||
|
||||
if freshness:
|
||||
params["freshness"] = freshness
|
||||
if result_filter:
|
||||
@@ -54,68 +62,69 @@ class BraveSearchTool(Tool):
|
||||
params["extra_snippets"] = 1
|
||||
if summary:
|
||||
params["summary"] = 1
|
||||
|
||||
# Set up headers
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Accept-Encoding": "gzip",
|
||||
"X-Subscription-Token": self.token
|
||||
"X-Subscription-Token": self.token,
|
||||
}
|
||||
|
||||
# Make the request
|
||||
|
||||
response = requests.get(url, params=params, headers=headers)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"results": response.json(),
|
||||
"message": "Search completed successfully."
|
||||
"message": "Search completed successfully.",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"Search failed with status code: {response.status_code}."
|
||||
"message": f"Search failed with status code: {response.status_code}.",
|
||||
}
|
||||
|
||||
def _image_search(self, query, country="ALL", search_lang="en", count=5,
|
||||
safesearch="off", spellcheck=False):
|
||||
|
||||
def _image_search(
|
||||
self,
|
||||
query,
|
||||
country="ALL",
|
||||
search_lang="en",
|
||||
count=5,
|
||||
safesearch="off",
|
||||
spellcheck=False,
|
||||
):
|
||||
"""
|
||||
Performs an image search using the Brave Search API.
|
||||
"""
|
||||
print(f"Performing Brave image search for: {query}")
|
||||
|
||||
|
||||
url = f"{self.base_url}/images/search"
|
||||
|
||||
# Build query parameters
|
||||
|
||||
params = {
|
||||
"q": query,
|
||||
"country": country,
|
||||
"search_lang": search_lang,
|
||||
"count": min(count, 100), # API max is 100
|
||||
"safesearch": safesearch,
|
||||
"spellcheck": 1 if spellcheck else 0
|
||||
"spellcheck": 1 if spellcheck else 0,
|
||||
}
|
||||
|
||||
# Set up headers
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Accept-Encoding": "gzip",
|
||||
"X-Subscription-Token": self.token
|
||||
"X-Subscription-Token": self.token,
|
||||
}
|
||||
|
||||
# Make the request
|
||||
|
||||
response = requests.get(url, params=params, headers=headers)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"results": response.json(),
|
||||
"message": "Image search completed successfully."
|
||||
"message": "Image search completed successfully.",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"Image search failed with status code: {response.status_code}."
|
||||
"message": f"Image search failed with status code: {response.status_code}.",
|
||||
}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
@@ -130,42 +139,14 @@ class BraveSearchTool(Tool):
|
||||
"type": "string",
|
||||
"description": "The search query (max 400 characters, 50 words)",
|
||||
},
|
||||
# "country": {
|
||||
# "type": "string",
|
||||
# "description": "The 2-character country code (default: US)",
|
||||
# },
|
||||
"search_lang": {
|
||||
"type": "string",
|
||||
"description": "The search language preference (default: en)",
|
||||
},
|
||||
# "count": {
|
||||
# "type": "integer",
|
||||
# "description": "Number of results to return (max 20, default: 10)",
|
||||
# },
|
||||
# "offset": {
|
||||
# "type": "integer",
|
||||
# "description": "Pagination offset (max 9, default: 0)",
|
||||
# },
|
||||
# "safesearch": {
|
||||
# "type": "string",
|
||||
# "description": "Filter level for adult content (off, moderate, strict)",
|
||||
# },
|
||||
"freshness": {
|
||||
"type": "string",
|
||||
"description": "Time filter for results (pd: last 24h, pw: last week, pm: last month, py: last year)",
|
||||
},
|
||||
# "result_filter": {
|
||||
# "type": "string",
|
||||
# "description": "Comma-delimited list of result types to include",
|
||||
# },
|
||||
# "extra_snippets": {
|
||||
# "type": "boolean",
|
||||
# "description": "Get additional excerpts from result pages",
|
||||
# },
|
||||
# "summary": {
|
||||
# "type": "boolean",
|
||||
# "description": "Enable summary generation in search results",
|
||||
# }
|
||||
},
|
||||
"required": ["query"],
|
||||
"additionalProperties": False,
|
||||
@@ -181,37 +162,21 @@ class BraveSearchTool(Tool):
|
||||
"type": "string",
|
||||
"description": "The search query (max 400 characters, 50 words)",
|
||||
},
|
||||
# "country": {
|
||||
# "type": "string",
|
||||
# "description": "The 2-character country code (default: US)",
|
||||
# },
|
||||
# "search_lang": {
|
||||
# "type": "string",
|
||||
# "description": "The search language preference (default: en)",
|
||||
# },
|
||||
"count": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return (max 100, default: 5)",
|
||||
},
|
||||
# "safesearch": {
|
||||
# "type": "string",
|
||||
# "description": "Filter level for adult content (off, strict). Default: strict",
|
||||
# },
|
||||
# "spellcheck": {
|
||||
# "type": "boolean",
|
||||
# "description": "Whether to spellcheck provided query (default: true)",
|
||||
# }
|
||||
},
|
||||
"required": ["query"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {
|
||||
"token": {
|
||||
"type": "string",
|
||||
"description": "Brave Search API key for authentication"
|
||||
"type": "string",
|
||||
"description": "Brave Search API key for authentication",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
114
application/agents/tools/duckduckgo.py
Normal file
114
application/agents/tools/duckduckgo.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from application.agents.tools.base import Tool
|
||||
from duckduckgo_search import DDGS
|
||||
|
||||
|
||||
class DuckDuckGoSearchTool(Tool):
|
||||
"""
|
||||
DuckDuckGo Search
|
||||
A tool for performing web and image searches using DuckDuckGo.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
actions = {
|
||||
"ddg_web_search": self._web_search,
|
||||
"ddg_image_search": self._image_search,
|
||||
}
|
||||
|
||||
if action_name in actions:
|
||||
return actions[action_name](**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
|
||||
def _web_search(
|
||||
self,
|
||||
query,
|
||||
max_results=5,
|
||||
):
|
||||
print(f"Performing DuckDuckGo web search for: {query}")
|
||||
|
||||
try:
|
||||
results = DDGS().text(
|
||||
query,
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
return {
|
||||
"status_code": 200,
|
||||
"results": results,
|
||||
"message": "Web search completed successfully.",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status_code": 500,
|
||||
"message": f"Web search failed: {str(e)}",
|
||||
}
|
||||
|
||||
def _image_search(
|
||||
self,
|
||||
query,
|
||||
max_results=5,
|
||||
):
|
||||
print(f"Performing DuckDuckGo image search for: {query}")
|
||||
|
||||
try:
|
||||
results = DDGS().images(
|
||||
keywords=query,
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
return {
|
||||
"status_code": 200,
|
||||
"results": results,
|
||||
"message": "Image search completed successfully.",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status_code": 500,
|
||||
"message": f"Image search failed: {str(e)}",
|
||||
}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
return [
|
||||
{
|
||||
"name": "ddg_web_search",
|
||||
"description": "Perform a web search using DuckDuckGo.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return (default: 5)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "ddg_image_search",
|
||||
"description": "Perform an image search using DuckDuckGo.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return (default: 5, max: 50)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {}
|
||||
861
application/agents/tools/mcp_tool.py
Normal file
861
application/agents/tools/mcp_tool.py
Normal file
@@ -0,0 +1,861 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
|
||||
from application.cache import get_redis_instance
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
from application.security.encryption import decrypt_credentials
|
||||
from fastmcp import Client
|
||||
from fastmcp.client.auth import BearerAuth
|
||||
from fastmcp.client.transports import (
|
||||
SSETransport,
|
||||
StdioTransport,
|
||||
StreamableHttpTransport,
|
||||
)
|
||||
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
||||
|
||||
from pydantic import AnyHttpUrl, ValidationError
|
||||
from redis import Redis
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
_mcp_clients_cache = {}
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
"""
|
||||
MCP Tool
|
||||
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. Supports various authentication methods and provides secure access to external services through the MCP protocol.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
|
||||
"""
|
||||
Initialize the MCP Tool with configuration.
|
||||
|
||||
Args:
|
||||
config: Dictionary containing MCP server configuration:
|
||||
- server_url: URL of the remote MCP server
|
||||
- transport_type: Transport type (auto, sse, http, stdio)
|
||||
- auth_type: Type of authentication (bearer, oauth, api_key, basic, none)
|
||||
- encrypted_credentials: Encrypted credentials (if available)
|
||||
- timeout: Request timeout in seconds (default: 30)
|
||||
- headers: Custom headers for requests
|
||||
- command: Command for STDIO transport
|
||||
- args: Arguments for STDIO transport
|
||||
- oauth_scopes: OAuth scopes for oauth auth type
|
||||
- oauth_client_name: OAuth client name for oauth auth type
|
||||
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
|
||||
"""
|
||||
self.config = config
|
||||
self.user_id = user_id
|
||||
self.server_url = config.get("server_url", "")
|
||||
self.transport_type = config.get("transport_type", "auto")
|
||||
self.auth_type = config.get("auth_type", "none")
|
||||
self.timeout = config.get("timeout", 30)
|
||||
self.custom_headers = config.get("headers", {})
|
||||
|
||||
self.auth_credentials = {}
|
||||
if config.get("encrypted_credentials") and user_id:
|
||||
self.auth_credentials = decrypt_credentials(
|
||||
config["encrypted_credentials"], user_id
|
||||
)
|
||||
else:
|
||||
self.auth_credentials = config.get("auth_credentials", {})
|
||||
self.oauth_scopes = config.get("oauth_scopes", [])
|
||||
self.oauth_task_id = config.get("oauth_task_id", None)
|
||||
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
|
||||
self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
|
||||
|
||||
self.available_tools = []
|
||||
self._cache_key = self._generate_cache_key()
|
||||
self._client = None
|
||||
|
||||
# Only validate and setup if server_url is provided and not OAuth
|
||||
|
||||
if self.server_url and self.auth_type != "oauth":
|
||||
self._setup_client()
|
||||
|
||||
def _generate_cache_key(self) -> str:
|
||||
"""Generate a unique cache key for this MCP server configuration."""
|
||||
auth_key = ""
|
||||
if self.auth_type == "oauth":
|
||||
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
|
||||
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
|
||||
elif self.auth_type in ["bearer"]:
|
||||
token = self.auth_credentials.get(
|
||||
"bearer_token", ""
|
||||
) or self.auth_credentials.get("access_token", "")
|
||||
auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
|
||||
elif self.auth_type == "api_key":
|
||||
api_key = self.auth_credentials.get("api_key", "")
|
||||
auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
|
||||
elif self.auth_type == "basic":
|
||||
username = self.auth_credentials.get("username", "")
|
||||
auth_key = f"basic:{username}"
|
||||
else:
|
||||
auth_key = "none"
|
||||
return f"{self.server_url}#{self.transport_type}#{auth_key}"
|
||||
|
||||
def _setup_client(self):
|
||||
"""Setup FastMCP client with proper transport and authentication."""
|
||||
global _mcp_clients_cache
|
||||
if self._cache_key in _mcp_clients_cache:
|
||||
cached_data = _mcp_clients_cache[self._cache_key]
|
||||
if time.time() - cached_data["created_at"] < 1800:
|
||||
self._client = cached_data["client"]
|
||||
return
|
||||
else:
|
||||
del _mcp_clients_cache[self._cache_key]
|
||||
transport = self._create_transport()
|
||||
auth = None
|
||||
|
||||
if self.auth_type == "oauth":
|
||||
redis_client = get_redis_instance()
|
||||
auth = DocsGPTOAuth(
|
||||
mcp_url=self.server_url,
|
||||
scopes=self.oauth_scopes,
|
||||
redis_client=redis_client,
|
||||
redirect_uri=self.redirect_uri,
|
||||
task_id=self.oauth_task_id,
|
||||
db=db,
|
||||
user_id=self.user_id,
|
||||
)
|
||||
elif self.auth_type == "bearer":
|
||||
token = self.auth_credentials.get(
|
||||
"bearer_token", ""
|
||||
) or self.auth_credentials.get("access_token", "")
|
||||
if token:
|
||||
auth = BearerAuth(token)
|
||||
self._client = Client(transport, auth=auth)
|
||||
_mcp_clients_cache[self._cache_key] = {
|
||||
"client": self._client,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
|
||||
def _create_transport(self):
|
||||
"""Create appropriate transport based on configuration."""
|
||||
headers = {"Content-Type": "application/json", "User-Agent": "DocsGPT-MCP/1.0"}
|
||||
headers.update(self.custom_headers)
|
||||
|
||||
if self.auth_type == "api_key":
|
||||
api_key = self.auth_credentials.get("api_key", "")
|
||||
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
|
||||
if api_key:
|
||||
headers[header_name] = api_key
|
||||
elif self.auth_type == "basic":
|
||||
username = self.auth_credentials.get("username", "")
|
||||
password = self.auth_credentials.get("password", "")
|
||||
if username and password:
|
||||
credentials = base64.b64encode(
|
||||
f"{username}:{password}".encode()
|
||||
).decode()
|
||||
headers["Authorization"] = f"Basic {credentials}"
|
||||
if self.transport_type == "auto":
|
||||
if "sse" in self.server_url.lower() or self.server_url.endswith("/sse"):
|
||||
transport_type = "sse"
|
||||
else:
|
||||
transport_type = "http"
|
||||
else:
|
||||
transport_type = self.transport_type
|
||||
if transport_type == "sse":
|
||||
headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"})
|
||||
return SSETransport(url=self.server_url, headers=headers)
|
||||
elif transport_type == "http":
|
||||
return StreamableHttpTransport(url=self.server_url, headers=headers)
|
||||
elif transport_type == "stdio":
|
||||
command = self.config.get("command", "python")
|
||||
args = self.config.get("args", [])
|
||||
env = self.auth_credentials if self.auth_credentials else None
|
||||
return StdioTransport(command=command, args=args, env=env)
|
||||
else:
|
||||
return StreamableHttpTransport(url=self.server_url, headers=headers)
|
||||
|
||||
def _format_tools(self, tools_response) -> List[Dict]:
|
||||
"""Format tools response to match expected format."""
|
||||
if hasattr(tools_response, "tools"):
|
||||
tools = tools_response.tools
|
||||
elif isinstance(tools_response, list):
|
||||
tools = tools_response
|
||||
else:
|
||||
tools = []
|
||||
tools_dict = []
|
||||
for tool in tools:
|
||||
if hasattr(tool, "name"):
|
||||
tool_dict = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
}
|
||||
if hasattr(tool, "inputSchema"):
|
||||
tool_dict["inputSchema"] = tool.inputSchema
|
||||
tools_dict.append(tool_dict)
|
||||
elif isinstance(tool, dict):
|
||||
tools_dict.append(tool)
|
||||
else:
|
||||
if hasattr(tool, "model_dump"):
|
||||
tools_dict.append(tool.model_dump())
|
||||
else:
|
||||
tools_dict.append({"name": str(tool), "description": ""})
|
||||
return tools_dict
|
||||
|
||||
async def _execute_with_client(self, operation: str, *args, **kwargs):
|
||||
"""Execute operation with FastMCP client."""
|
||||
if not self._client:
|
||||
raise Exception("FastMCP client not initialized")
|
||||
async with self._client:
|
||||
if operation == "ping":
|
||||
return await self._client.ping()
|
||||
elif operation == "list_tools":
|
||||
tools_response = await self._client.list_tools()
|
||||
self.available_tools = self._format_tools(tools_response)
|
||||
return self.available_tools
|
||||
elif operation == "call_tool":
|
||||
tool_name = args[0]
|
||||
tool_args = kwargs
|
||||
return await self._client.call_tool(tool_name, tool_args)
|
||||
elif operation == "list_resources":
|
||||
return await self._client.list_resources()
|
||||
elif operation == "list_prompts":
|
||||
return await self._client.list_prompts()
|
||||
else:
|
||||
raise Exception(f"Unknown operation: {operation}")
|
||||
|
||||
def _run_async_operation(self, operation: str, *args, **kwargs):
|
||||
"""Run async operation in sync context."""
|
||||
try:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(
|
||||
self._execute_with_client(operation, *args, **kwargs)
|
||||
)
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
return future.result(timeout=self.timeout)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(
|
||||
self._execute_with_client(operation, *args, **kwargs)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
except Exception as e:
|
||||
print(f"Error occurred while running async operation: {e}")
|
||||
raise
|
||||
|
||||
def discover_tools(self) -> List[Dict]:
|
||||
"""
|
||||
Discover available tools from the MCP server using FastMCP.
|
||||
|
||||
Returns:
|
||||
List of tool definitions from the server
|
||||
"""
|
||||
if not self.server_url:
|
||||
return []
|
||||
if not self._client:
|
||||
self._setup_client()
|
||||
try:
|
||||
tools = self._run_async_operation("list_tools")
|
||||
self.available_tools = tools
|
||||
return self.available_tools
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
|
||||
|
||||
def execute_action(self, action_name: str, **kwargs) -> Any:
|
||||
"""
|
||||
Execute an action on the remote MCP server using FastMCP.
|
||||
|
||||
Args:
|
||||
action_name: Name of the action to execute
|
||||
**kwargs: Parameters for the action
|
||||
|
||||
Returns:
|
||||
Result from the MCP server
|
||||
"""
|
||||
if not self.server_url:
|
||||
raise Exception("No MCP server configured")
|
||||
if not self._client:
|
||||
self._setup_client()
|
||||
cleaned_kwargs = {}
|
||||
for key, value in kwargs.items():
|
||||
if value == "" or value is None:
|
||||
continue
|
||||
cleaned_kwargs[key] = value
|
||||
try:
|
||||
result = self._run_async_operation(
|
||||
"call_tool", action_name, **cleaned_kwargs
|
||||
)
|
||||
return self._format_result(result)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
|
||||
|
||||
def _format_result(self, result) -> Dict:
|
||||
"""Format FastMCP result to match expected format."""
|
||||
if hasattr(result, "content"):
|
||||
content_list = []
|
||||
for content_item in result.content:
|
||||
if hasattr(content_item, "text"):
|
||||
content_list.append({"type": "text", "text": content_item.text})
|
||||
elif hasattr(content_item, "data"):
|
||||
content_list.append({"type": "data", "data": content_item.data})
|
||||
else:
|
||||
content_list.append(
|
||||
{"type": "unknown", "content": str(content_item)}
|
||||
)
|
||||
return {
|
||||
"content": content_list,
|
||||
"isError": getattr(result, "isError", False),
|
||||
}
|
||||
else:
|
||||
return result
|
||||
|
||||
def test_connection(self) -> Dict:
|
||||
"""
|
||||
Test the connection to the MCP server and validate functionality.
|
||||
|
||||
Returns:
|
||||
Dictionary with connection test results including tool count
|
||||
"""
|
||||
if not self.server_url:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "No MCP server URL configured",
|
||||
"tools_count": 0,
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"error_type": "ConfigurationError",
|
||||
}
|
||||
if not self._client:
|
||||
self._setup_client()
|
||||
try:
|
||||
if self.auth_type == "oauth":
|
||||
return self._test_oauth_connection()
|
||||
else:
|
||||
return self._test_regular_connection()
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Connection failed: {str(e)}",
|
||||
"tools_count": 0,
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
|
||||
def _test_regular_connection(self) -> Dict:
|
||||
"""Test connection for non-OAuth auth types."""
|
||||
try:
|
||||
self._run_async_operation("ping")
|
||||
ping_success = True
|
||||
except Exception:
|
||||
ping_success = False
|
||||
tools = self.discover_tools()
|
||||
|
||||
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
|
||||
if not ping_success:
|
||||
message += " (Ping not supported, but tool discovery worked)"
|
||||
return {
|
||||
"success": True,
|
||||
"message": message,
|
||||
"tools_count": len(tools),
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"ping_supported": ping_success,
|
||||
"tools": [tool.get("name", "unknown") for tool in tools],
|
||||
}
|
||||
|
||||
def _test_oauth_connection(self) -> Dict:
|
||||
"""Test connection for OAuth auth type with proper async handling."""
|
||||
try:
|
||||
task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
|
||||
if not task:
|
||||
raise Exception("Failed to start OAuth authentication")
|
||||
return {
|
||||
"success": True,
|
||||
"requires_oauth": True,
|
||||
"task_id": task.id,
|
||||
"status": "pending",
|
||||
"message": "OAuth flow started",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"OAuth connection failed: {str(e)}",
|
||||
"tools_count": 0,
|
||||
"transport_type": self.transport_type,
|
||||
"auth_type": self.auth_type,
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict]:
|
||||
"""
|
||||
Get metadata for all available actions.
|
||||
|
||||
Returns:
|
||||
List of action metadata dictionaries
|
||||
"""
|
||||
actions = []
|
||||
for tool in self.available_tools:
|
||||
input_schema = (
|
||||
tool.get("inputSchema")
|
||||
or tool.get("input_schema")
|
||||
or tool.get("schema")
|
||||
or tool.get("parameters")
|
||||
)
|
||||
|
||||
parameters_schema = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
if input_schema:
|
||||
if isinstance(input_schema, dict):
|
||||
if "properties" in input_schema:
|
||||
parameters_schema = {
|
||||
"type": input_schema.get("type", "object"),
|
||||
"properties": input_schema.get("properties", {}),
|
||||
"required": input_schema.get("required", []),
|
||||
}
|
||||
|
||||
for key in ["additionalProperties", "description"]:
|
||||
if key in input_schema:
|
||||
parameters_schema[key] = input_schema[key]
|
||||
else:
|
||||
parameters_schema["properties"] = input_schema
|
||||
action = {
|
||||
"name": tool.get("name", ""),
|
||||
"description": tool.get("description", ""),
|
||||
"parameters": parameters_schema,
|
||||
}
|
||||
actions.append(action)
|
||||
return actions
|
||||
|
||||
def get_config_requirements(self) -> Dict:
|
||||
"""Get configuration requirements for the MCP tool."""
|
||||
return {
|
||||
"server_url": {
|
||||
"type": "string",
|
||||
"description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)",
|
||||
"required": True,
|
||||
},
|
||||
"transport_type": {
|
||||
"type": "string",
|
||||
"description": "Transport type for connection",
|
||||
"enum": ["auto", "sse", "http", "stdio"],
|
||||
"default": "auto",
|
||||
"required": False,
|
||||
"help": {
|
||||
"auto": "Automatically detect best transport",
|
||||
"sse": "Server-Sent Events (for real-time streaming)",
|
||||
"http": "HTTP streaming (recommended for production)",
|
||||
"stdio": "Standard I/O (for local servers)",
|
||||
},
|
||||
},
|
||||
"auth_type": {
|
||||
"type": "string",
|
||||
"description": "Authentication type",
|
||||
"enum": ["none", "bearer", "oauth", "api_key", "basic"],
|
||||
"default": "none",
|
||||
"required": True,
|
||||
"help": {
|
||||
"none": "No authentication",
|
||||
"bearer": "Bearer token authentication",
|
||||
"oauth": "OAuth 2.1 authentication (with frontend integration)",
|
||||
"api_key": "API key authentication",
|
||||
"basic": "Basic authentication",
|
||||
},
|
||||
},
|
||||
"auth_credentials": {
|
||||
"type": "object",
|
||||
"description": "Authentication credentials (varies by auth_type)",
|
||||
"required": False,
|
||||
"properties": {
|
||||
"bearer_token": {
|
||||
"type": "string",
|
||||
"description": "Bearer token for bearer auth",
|
||||
},
|
||||
"access_token": {
|
||||
"type": "string",
|
||||
"description": "Access token for OAuth (if pre-obtained)",
|
||||
},
|
||||
"api_key": {
|
||||
"type": "string",
|
||||
"description": "API key for api_key auth",
|
||||
},
|
||||
"api_key_header": {
|
||||
"type": "string",
|
||||
"description": "Header name for API key (default: X-API-Key)",
|
||||
},
|
||||
"username": {
|
||||
"type": "string",
|
||||
"description": "Username for basic auth",
|
||||
},
|
||||
"password": {
|
||||
"type": "string",
|
||||
"description": "Password for basic auth",
|
||||
},
|
||||
},
|
||||
},
|
||||
"oauth_scopes": {
|
||||
"type": "array",
|
||||
"description": "OAuth scopes to request (for oauth auth_type)",
|
||||
"items": {"type": "string"},
|
||||
"required": False,
|
||||
"default": [],
|
||||
},
|
||||
"oauth_client_name": {
|
||||
"type": "string",
|
||||
"description": "Client name for OAuth registration (for oauth auth_type)",
|
||||
"default": "DocsGPT-MCP",
|
||||
"required": False,
|
||||
},
|
||||
"headers": {
|
||||
"type": "object",
|
||||
"description": "Custom headers to send with requests",
|
||||
"required": False,
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Request timeout in seconds",
|
||||
"default": 30,
|
||||
"minimum": 1,
|
||||
"maximum": 300,
|
||||
"required": False,
|
||||
},
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Command to run for STDIO transport (e.g., 'python')",
|
||||
"required": False,
|
||||
},
|
||||
"args": {
|
||||
"type": "array",
|
||||
"description": "Arguments for STDIO command",
|
||||
"items": {"type": "string"},
|
||||
"required": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class DocsGPTOAuth(OAuthClientProvider):
|
||||
"""
|
||||
Custom OAuth handler for DocsGPT that uses frontend redirect instead of browser.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_url: str,
|
||||
redirect_uri: str,
|
||||
redis_client: Redis | None = None,
|
||||
redis_prefix: str = "mcp_oauth:",
|
||||
task_id: str = None,
|
||||
scopes: str | list[str] | None = None,
|
||||
client_name: str = "DocsGPT-MCP",
|
||||
user_id=None,
|
||||
db=None,
|
||||
additional_client_metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize custom OAuth client provider for DocsGPT.
|
||||
|
||||
Args:
|
||||
mcp_url: Full URL to the MCP endpoint
|
||||
redirect_uri: Custom redirect URI for DocsGPT frontend
|
||||
redis_client: Redis client for storing auth state
|
||||
redis_prefix: Prefix for Redis keys
|
||||
task_id: Task ID for tracking auth status
|
||||
scopes: OAuth scopes to request
|
||||
client_name: Name for this client during registration
|
||||
user_id: User ID for token storage
|
||||
db: Database instance for token storage
|
||||
additional_client_metadata: Extra fields for OAuthClientMetadata
|
||||
"""
|
||||
|
||||
self.redirect_uri = redirect_uri
|
||||
self.redis_client = redis_client
|
||||
self.redis_prefix = redis_prefix
|
||||
self.task_id = task_id
|
||||
self.user_id = user_id
|
||||
self.db = db
|
||||
|
||||
parsed_url = urlparse(mcp_url)
|
||||
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
if isinstance(scopes, list):
|
||||
scopes = " ".join(scopes)
|
||||
client_metadata = OAuthClientMetadata(
|
||||
client_name=client_name,
|
||||
redirect_uris=[AnyHttpUrl(redirect_uri)],
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
scope=scopes,
|
||||
**(additional_client_metadata or {}),
|
||||
)
|
||||
|
||||
storage = DBTokenStorage(
|
||||
server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
server_url=self.server_base_url,
|
||||
client_metadata=client_metadata,
|
||||
storage=storage,
|
||||
redirect_handler=self.redirect_handler,
|
||||
callback_handler=self.callback_handler,
|
||||
)
|
||||
|
||||
self.auth_url = None
|
||||
self.extracted_state = None
|
||||
|
||||
def _process_auth_url(self, authorization_url: str) -> tuple[str, str]:
|
||||
"""Process authorization URL to extract state"""
|
||||
try:
|
||||
parsed_url = urlparse(authorization_url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
|
||||
state_params = query_params.get("state", [])
|
||||
if state_params:
|
||||
state = state_params[0]
|
||||
else:
|
||||
raise ValueError("No state in auth URL")
|
||||
return authorization_url, state
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to process auth URL: {e}")
|
||||
|
||||
async def redirect_handler(self, authorization_url: str) -> None:
|
||||
"""Store auth URL and state in Redis for frontend to use."""
|
||||
auth_url, state = self._process_auth_url(authorization_url)
|
||||
logging.info(
|
||||
"[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
|
||||
)
|
||||
self.auth_url = auth_url
|
||||
self.extracted_state = state
|
||||
|
||||
if self.redis_client and self.extracted_state:
|
||||
key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
|
||||
self.redis_client.setex(key, 600, auth_url)
|
||||
logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)
|
||||
|
||||
if self.task_id:
|
||||
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||
status_data = {
|
||||
"status": "requires_redirect",
|
||||
"message": "OAuth authorization required",
|
||||
"authorization_url": self.auth_url,
|
||||
"state": self.extracted_state,
|
||||
"requires_oauth": True,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
|
||||
async def callback_handler(self) -> tuple[str, str | None]:
|
||||
"""Wait for auth code from Redis using the state value."""
|
||||
if not self.redis_client or not self.extracted_state:
|
||||
raise Exception("Redis client or state not configured for OAuth")
|
||||
poll_interval = 1
|
||||
max_wait_time = 300
|
||||
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
|
||||
|
||||
if self.task_id:
|
||||
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||
status_data = {
|
||||
"status": "awaiting_callback",
|
||||
"message": "Waiting for OAuth callback...",
|
||||
"authorization_url": self.auth_url,
|
||||
"state": self.extracted_state,
|
||||
"requires_oauth": True,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
code_data = self.redis_client.get(code_key)
|
||||
if code_data:
|
||||
code = code_data.decode()
|
||||
returned_state = self.extracted_state
|
||||
|
||||
self.redis_client.delete(code_key)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}auth_url:{self.extracted_state}"
|
||||
)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}state:{self.extracted_state}"
|
||||
)
|
||||
|
||||
if self.task_id:
|
||||
status_data = {
|
||||
"status": "callback_received",
|
||||
"message": "OAuth callback received, completing authentication...",
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
return code, returned_state
|
||||
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
|
||||
error_data = self.redis_client.get(error_key)
|
||||
if error_data:
|
||||
error_msg = error_data.decode()
|
||||
self.redis_client.delete(error_key)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}auth_url:{self.extracted_state}"
|
||||
)
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}state:{self.extracted_state}"
|
||||
)
|
||||
raise Exception(f"OAuth error: {error_msg}")
|
||||
await asyncio.sleep(poll_interval)
|
||||
self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
|
||||
self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
|
||||
raise Exception("OAuth callback timeout: no code received within 5 minutes")
|
||||
|
||||
|
||||
class DBTokenStorage(TokenStorage):
|
||||
def __init__(self, server_url: str, user_id: str, db_client):
|
||||
self.server_url = server_url
|
||||
self.user_id = user_id
|
||||
self.db_client = db_client
|
||||
self.collection = db_client["connector_sessions"]
|
||||
|
||||
@staticmethod
|
||||
def get_base_url(url: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
return f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
def get_db_key(self) -> dict:
|
||||
return {
|
||||
"server_url": self.get_base_url(self.server_url),
|
||||
"user_id": self.user_id,
|
||||
}
|
||||
|
||||
async def get_tokens(self) -> OAuthToken | None:
|
||||
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
|
||||
if not doc or "tokens" not in doc:
|
||||
return None
|
||||
try:
|
||||
tokens = OAuthToken.model_validate(doc["tokens"])
|
||||
return tokens
|
||||
except ValidationError as e:
|
||||
logging.error(f"Could not load tokens: {e}")
|
||||
return None
|
||||
|
||||
async def set_tokens(self, tokens: OAuthToken) -> None:
|
||||
await asyncio.to_thread(
|
||||
self.collection.update_one,
|
||||
self.get_db_key(),
|
||||
{"$set": {"tokens": tokens.model_dump()}},
|
||||
True,
|
||||
)
|
||||
logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")
|
||||
|
||||
async def get_client_info(self) -> OAuthClientInformationFull | None:
|
||||
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
|
||||
if not doc or "client_info" not in doc:
|
||||
return None
|
||||
try:
|
||||
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
|
||||
tokens = await self.get_tokens()
|
||||
if tokens is None:
|
||||
logging.debug(
|
||||
"No tokens found, clearing client info to force fresh registration."
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
self.collection.update_one,
|
||||
self.get_db_key(),
|
||||
{"$unset": {"client_info": ""}},
|
||||
)
|
||||
return None
|
||||
return client_info
|
||||
except ValidationError as e:
|
||||
logging.error(f"Could not load client info: {e}")
|
||||
return None
|
||||
|
||||
def _serialize_client_info(self, info: dict) -> dict:
|
||||
if "redirect_uris" in info and isinstance(info["redirect_uris"], list):
|
||||
info["redirect_uris"] = [str(u) for u in info["redirect_uris"]]
|
||||
return info
|
||||
|
||||
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
|
||||
serialized_info = self._serialize_client_info(client_info.model_dump())
|
||||
await asyncio.to_thread(
|
||||
self.collection.update_one,
|
||||
self.get_db_key(),
|
||||
{"$set": {"client_info": serialized_info}},
|
||||
True,
|
||||
)
|
||||
logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")
|
||||
|
||||
async def clear(self) -> None:
|
||||
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
|
||||
logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
|
||||
|
||||
@classmethod
|
||||
async def clear_all(cls, db_client) -> None:
|
||||
collection = db_client["connector_sessions"]
|
||||
await asyncio.to_thread(collection.delete_many, {})
|
||||
logging.info("Cleared all OAuth client cache data.")
|
||||
|
||||
|
||||
class MCPOAuthManager:
|
||||
"""Manager for handling MCP OAuth callbacks."""
|
||||
|
||||
def __init__(self, redis_client: Redis | None, redis_prefix: str = "mcp_oauth:"):
|
||||
self.redis_client = redis_client
|
||||
self.redis_prefix = redis_prefix
|
||||
|
||||
def handle_oauth_callback(
|
||||
self, state: str, code: str, error: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Handle OAuth callback from provider.
|
||||
|
||||
Args:
|
||||
state: The state parameter from OAuth callback
|
||||
code: The authorization code from OAuth callback
|
||||
error: Error message if OAuth failed
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
if not self.redis_client or not state:
|
||||
raise Exception("Redis client or state not provided")
|
||||
if error:
|
||||
error_key = f"{self.redis_prefix}error:{state}"
|
||||
self.redis_client.setex(error_key, 300, error)
|
||||
raise Exception(f"OAuth error received: {error}")
|
||||
code_key = f"{self.redis_prefix}code:{state}"
|
||||
self.redis_client.setex(code_key, 300, code)
|
||||
|
||||
state_key = f"{self.redis_prefix}state:{state}"
|
||||
self.redis_client.setex(state_key, 300, "completed")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"Error handling OAuth callback: {e}")
|
||||
return False
|
||||
|
||||
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""Get current status of OAuth flow using provided task_id."""
|
||||
if not task_id:
|
||||
return {"status": "not_started", "message": "OAuth flow not started"}
|
||||
return mcp_oauth_status_task(task_id)
|
||||
@@ -17,26 +17,45 @@ class ToolActionParser:
|
||||
return parser(call)
|
||||
|
||||
def _parse_openai_llm(self, call):
|
||||
if isinstance(call, dict):
|
||||
try:
|
||||
call_args = json.loads(call["function"]["arguments"])
|
||||
tool_id = call["function"]["name"].split("_")[-1]
|
||||
action_name = call["function"]["name"].rsplit("_", 1)[0]
|
||||
except (KeyError, TypeError) as e:
|
||||
logger.error(f"Error parsing OpenAI LLM call: {e}")
|
||||
return None, None, None
|
||||
else:
|
||||
try:
|
||||
call_args = json.loads(call.function.arguments)
|
||||
tool_id = call.function.name.split("_")[-1]
|
||||
action_name = call.function.name.rsplit("_", 1)[0]
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.error(f"Error parsing OpenAI LLM call: {e}")
|
||||
try:
|
||||
call_args = json.loads(call.arguments)
|
||||
tool_parts = call.name.split("_")
|
||||
|
||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
||||
if len(tool_parts) < 2:
|
||||
logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id")
|
||||
return None, None, None
|
||||
|
||||
tool_id = tool_parts[-1]
|
||||
action_name = "_".join(tool_parts[:-1])
|
||||
|
||||
# Validate that tool_id looks like a numerical ID
|
||||
if not tool_id.isdigit():
|
||||
logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.")
|
||||
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.error(f"Error parsing OpenAI LLM call: {e}")
|
||||
return None, None, None
|
||||
return tool_id, action_name, call_args
|
||||
|
||||
def _parse_google_llm(self, call):
|
||||
call_args = call.args
|
||||
tool_id = call.name.split("_")[-1]
|
||||
action_name = call.name.rsplit("_", 1)[0]
|
||||
try:
|
||||
call_args = call.arguments
|
||||
tool_parts = call.name.split("_")
|
||||
|
||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
||||
if len(tool_parts) < 2:
|
||||
logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id")
|
||||
return None, None, None
|
||||
|
||||
tool_id = tool_parts[-1]
|
||||
action_name = "_".join(tool_parts[:-1])
|
||||
|
||||
# Validate that tool_id looks like a numerical ID
|
||||
if not tool_id.isdigit():
|
||||
logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.")
|
||||
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.error(f"Error parsing Google LLM call: {e}")
|
||||
return None, None, None
|
||||
return tool_id, action_name, call_args
|
||||
|
||||
@@ -23,16 +23,23 @@ class ToolManager:
|
||||
tool_config = self.config.get(name, {})
|
||||
self.tools[name] = obj(tool_config)
|
||||
|
||||
def load_tool(self, tool_name, tool_config):
|
||||
def load_tool(self, tool_name, tool_config, user_id=None):
|
||||
self.config[tool_name] = tool_config
|
||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
return obj(tool_config)
|
||||
if tool_name == "mcp_tool" and user_id:
|
||||
return obj(tool_config, user_id)
|
||||
else:
|
||||
return obj(tool_config)
|
||||
|
||||
def execute_action(self, tool_name, action_name, **kwargs):
|
||||
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
if tool_name == "mcp_tool" and user_id:
|
||||
tool_config = self.config.get(tool_name, {})
|
||||
tool = self.load_tool(tool_name, tool_config, user_id)
|
||||
return tool.execute_action(action_name, **kwargs)
|
||||
return self.tools[tool_name].execute_action(action_name, **kwargs)
|
||||
|
||||
def get_all_actions_metadata(self):
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
from flask_restx import Api
|
||||
|
||||
api = Api(
|
||||
version="1.0",
|
||||
title="DocsGPT API",
|
||||
description="API for DocsGPT",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
from flask import Blueprint
|
||||
|
||||
from application.api import api
|
||||
from application.api.answer.routes.answer import AnswerResource
|
||||
from application.api.answer.routes.base import answer_ns
|
||||
from application.api.answer.routes.stream import StreamResource
|
||||
|
||||
|
||||
answer = Blueprint("answer", __name__)
|
||||
|
||||
api.add_namespace(answer_ns)
|
||||
|
||||
|
||||
def init_answer_routes():
|
||||
api.add_resource(StreamResource, "/stream")
|
||||
api.add_resource(AnswerResource, "/api/answer")
|
||||
|
||||
|
||||
init_answer_routes()
|
||||
|
||||
@@ -1,916 +0,0 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import Blueprint, make_response, request, Response
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.error import bad_request
|
||||
from application.extensions import api
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import check_required_fields, limit_chat_history
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
conversations_collection = db["conversations"]
|
||||
sources_collection = db["sources"]
|
||||
prompts_collection = db["prompts"]
|
||||
agents_collection = db["agents"]
|
||||
user_logs_collection = db["user_logs"]
|
||||
attachments_collection = db["attachments"]
|
||||
|
||||
answer = Blueprint("answer", __name__)
|
||||
answer_ns = Namespace("answer", description="Answer related operations", path="/")
|
||||
api.add_namespace(answer_ns)
|
||||
|
||||
gpt_model = ""
|
||||
# to have some kind of default behaviour
|
||||
if settings.LLM_NAME == "openai":
|
||||
gpt_model = "gpt-4o-mini"
|
||||
elif settings.LLM_NAME == "anthropic":
|
||||
gpt_model = "claude-2"
|
||||
elif settings.LLM_NAME == "groq":
|
||||
gpt_model = "llama3-8b-8192"
|
||||
elif settings.LLM_NAME == "novita":
|
||||
gpt_model = "deepseek/deepseek-r1"
|
||||
|
||||
if settings.MODEL_NAME: # in case there is particular model name configured
|
||||
gpt_model = settings.MODEL_NAME
|
||||
|
||||
# load the prompts
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
|
||||
chat_combine_template = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f:
|
||||
chat_reduce_template = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
|
||||
chat_combine_creative = f.read()
|
||||
|
||||
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
|
||||
chat_combine_strict = f.read()
|
||||
|
||||
api_key_set = settings.API_KEY is not None
|
||||
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
|
||||
|
||||
|
||||
async def async_generate(chain, question, chat_history):
|
||||
result = await chain.arun({"question": question, "chat_history": chat_history})
|
||||
return result
|
||||
|
||||
|
||||
def run_async_chain(chain, question, chat_history):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
result = {}
|
||||
try:
|
||||
answer = loop.run_until_complete(async_generate(chain, question, chat_history))
|
||||
finally:
|
||||
loop.close()
|
||||
result["answer"] = answer
|
||||
return result
|
||||
|
||||
|
||||
def get_agent_key(agent_id, user_id):
|
||||
if not agent_id:
|
||||
return None, False, None
|
||||
|
||||
try:
|
||||
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
|
||||
if agent is None:
|
||||
raise Exception("Agent not found", 404)
|
||||
|
||||
is_owner = agent.get("user") == user_id
|
||||
|
||||
if is_owner:
|
||||
agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id)},
|
||||
{"$set": {"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)}},
|
||||
)
|
||||
return str(agent["key"]), False, None
|
||||
|
||||
is_shared_with_user = agent.get(
|
||||
"shared_publicly", False
|
||||
) or user_id in agent.get("shared_with", [])
|
||||
|
||||
if is_shared_with_user:
|
||||
return str(agent["key"]), True, agent.get("shared_token")
|
||||
|
||||
raise Exception("Unauthorized access to the agent", 403)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def get_data_from_api_key(api_key):
|
||||
data = agents_collection.find_one({"key": api_key})
|
||||
if not data:
|
||||
raise Exception("Invalid API Key, please generate a new key", 401)
|
||||
|
||||
source = data.get("source")
|
||||
if isinstance(source, DBRef):
|
||||
source_doc = db.dereference(source)
|
||||
data["source"] = str(source_doc["_id"])
|
||||
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
||||
else:
|
||||
data["source"] = {}
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def get_retriever(source_id: str):
|
||||
doc = sources_collection.find_one({"_id": ObjectId(source_id)})
|
||||
if doc is None:
|
||||
raise Exception("Source document does not exist", 404)
|
||||
retriever_name = None if "retriever" not in doc else doc["retriever"]
|
||||
return retriever_name
|
||||
|
||||
|
||||
def is_azure_configured():
|
||||
return (
|
||||
settings.OPENAI_API_BASE
|
||||
and settings.OPENAI_API_VERSION
|
||||
and settings.AZURE_DEPLOYMENT_NAME
|
||||
)
|
||||
|
||||
|
||||
def save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
decoded_token,
|
||||
index=None,
|
||||
api_key=None,
|
||||
agent_id=None,
|
||||
is_shared_usage=False,
|
||||
shared_token=None,
|
||||
):
|
||||
current_time = datetime.datetime.now(datetime.timezone.utc)
|
||||
if conversation_id is not None and index is not None:
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{index}.prompt": question,
|
||||
f"queries.{index}.response": response,
|
||||
f"queries.{index}.thought": thought,
|
||||
f"queries.{index}.sources": source_log_docs,
|
||||
f"queries.{index}.tool_calls": tool_calls,
|
||||
f"queries.{index}.timestamp": current_time,
|
||||
}
|
||||
},
|
||||
)
|
||||
##remove following queries from the array
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
|
||||
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
||||
)
|
||||
elif conversation_id is not None and conversation_id != "None":
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{
|
||||
"$push": {
|
||||
"queries": {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
else:
|
||||
# create new conversation
|
||||
# generate summary
|
||||
messages_summary = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Summarise following conversation in no more than 3 "
|
||||
"words, respond ONLY with the summary, use the same "
|
||||
"language as the system",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Summarise following conversation in no more than 3 words, "
|
||||
"respond ONLY with the summary, use the same language as the "
|
||||
"system \n\nUser: " + question + "\n\n" + "AI: " + response,
|
||||
},
|
||||
]
|
||||
|
||||
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
|
||||
conversation_data = {
|
||||
"user": decoded_token.get("sub"),
|
||||
"date": datetime.datetime.utcnow(),
|
||||
"name": completion,
|
||||
"queries": [
|
||||
{
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
}
|
||||
],
|
||||
}
|
||||
if api_key:
|
||||
if agent_id:
|
||||
conversation_data["agent_id"] = agent_id
|
||||
if is_shared_usage:
|
||||
conversation_data["is_shared_usage"] = is_shared_usage
|
||||
conversation_data["shared_token"] = shared_token
|
||||
api_key_doc = agents_collection.find_one({"key": api_key})
|
||||
if api_key_doc:
|
||||
conversation_data["api_key"] = api_key_doc["key"]
|
||||
conversation_id = conversations_collection.insert_one(
|
||||
conversation_data
|
||||
).inserted_id
|
||||
return conversation_id
|
||||
|
||||
|
||||
def get_prompt(prompt_id):
|
||||
if prompt_id == "default":
|
||||
prompt = chat_combine_template
|
||||
elif prompt_id == "creative":
|
||||
prompt = chat_combine_creative
|
||||
elif prompt_id == "strict":
|
||||
prompt = chat_combine_strict
|
||||
else:
|
||||
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
|
||||
return prompt
|
||||
|
||||
|
||||
def complete_stream(
|
||||
question,
|
||||
agent,
|
||||
retriever,
|
||||
conversation_id,
|
||||
user_api_key,
|
||||
decoded_token,
|
||||
isNoneDoc=False,
|
||||
index=None,
|
||||
should_save_conversation=True,
|
||||
attachments=None,
|
||||
agent_id=None,
|
||||
is_shared_usage=False,
|
||||
shared_token=None,
|
||||
):
|
||||
try:
|
||||
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
||||
attachment_ids = []
|
||||
|
||||
if attachments:
|
||||
attachment_ids = [attachment["id"] for attachment in attachments]
|
||||
logger.info(
|
||||
f"Processing request with {len(attachments)} attachments: {attachment_ids}"
|
||||
)
|
||||
|
||||
answer = agent.gen(query=question, retriever=retriever)
|
||||
|
||||
for line in answer:
|
||||
if "answer" in line:
|
||||
response_full += str(line["answer"])
|
||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "sources" in line:
|
||||
truncated_sources = []
|
||||
source_log_docs = line["sources"]
|
||||
for source in line["sources"]:
|
||||
truncated_source = source.copy()
|
||||
if "text" in truncated_source:
|
||||
truncated_source["text"] = (
|
||||
truncated_source["text"][:100].strip() + "..."
|
||||
)
|
||||
truncated_sources.append(truncated_source)
|
||||
if len(truncated_sources) > 0:
|
||||
data = json.dumps({"type": "source", "source": truncated_sources})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "tool_calls" in line:
|
||||
tool_calls = line["tool_calls"]
|
||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "thought" in line:
|
||||
thought += line["thought"]
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
if should_save_conversation:
|
||||
conversation_id = save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
decoded_token,
|
||||
index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
)
|
||||
else:
|
||||
conversation_id = None
|
||||
|
||||
# send data.type = "end" to indicate that the stream has ended as json
|
||||
data = json.dumps({"type": "id", "id": str(conversation_id)})
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
retriever_params = retriever.get_params()
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
"action": "stream_answer",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"retriever_params": retriever_params,
|
||||
"attachments": attachment_ids,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
)
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
||||
data = json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Please try again later. We apologize for any inconvenience.",
|
||||
}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
|
||||
@answer_ns.route("/stream")
|
||||
class Stream(Resource):
|
||||
stream_model = api.model(
|
||||
"StreamModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="Question to be asked"
|
||||
),
|
||||
"history": fields.List(
|
||||
fields.String, required=False, description="Chat history"
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=False, description="Conversation ID"
|
||||
),
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
"index": fields.Integer(
|
||||
required=False, description="Index of the query to update"
|
||||
),
|
||||
"save_conversation": fields.Boolean(
|
||||
required=False,
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
"attachments": fields.List(
|
||||
fields.String, required=False, description="List of attachment IDs"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(stream_model)
|
||||
@api.doc(description="Stream a response based on the question and retriever")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
if "index" in data:
|
||||
required_fields = ["question", "conversation_id"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
save_conv = data.get("save_conversation", True)
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
history = limit_chat_history(
|
||||
json.loads(data.get("history", "[]")), gpt_model=gpt_model
|
||||
)
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
attachment_ids = data.get("attachments", [])
|
||||
|
||||
index = data.get("index", None)
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
agent_id = data.get("agent_id", None)
|
||||
agent_type = settings.AGENT_NAME
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
user_sub = decoded_token.get("sub") if decoded_token else None
|
||||
agent_key, is_shared_usage, shared_token = get_agent_key(
|
||||
agent_id, user_sub
|
||||
)
|
||||
|
||||
if agent_key:
|
||||
data.update({"api_key": agent_key})
|
||||
else:
|
||||
agent_id = None
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
prompt_id = data_key.get("prompt_id", "default")
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
retriever_name = data_key.get("retriever", retriever_name)
|
||||
user_api_key = data["api_key"]
|
||||
agent_type = data_key.get("agent_type", agent_type)
|
||||
if is_shared_usage:
|
||||
decoded_token = request.decoded_token
|
||||
else:
|
||||
decoded_token = {"sub": data_key.get("user")}
|
||||
is_shared_usage = False
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
if not decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
attachments = get_attachments_content(
|
||||
attachment_ids, decoded_token.get("sub")
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"/stream - request_data: {data}, source: {source}, attachments: {len(attachments)}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
if "isNoneDoc" in data and data["isNoneDoc"] is True:
|
||||
chunks = 0
|
||||
|
||||
agent = AgentCreator.create_agent(
|
||||
agent_type,
|
||||
endpoint="stream",
|
||||
llm_name=settings.LLM_NAME,
|
||||
gpt_model=gpt_model,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
chat_history=history,
|
||||
decoded_token=decoded_token,
|
||||
attachments=attachments,
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
return Response(
|
||||
complete_stream(
|
||||
question=question,
|
||||
agent=agent,
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=index,
|
||||
should_save_conversation=save_conv,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
message = "Malformed request body"
|
||||
logger.error(f"/stream - error: {message}")
|
||||
return Response(
|
||||
error_stream_generate(message),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
status_code = 400
|
||||
return Response(
|
||||
error_stream_generate("Unknown error occurred"),
|
||||
status=status_code,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
def error_stream_generate(err_response):
|
||||
data = json.dumps({"type": "error", "error": err_response})
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
|
||||
@answer_ns.route("/api/answer")
|
||||
class Answer(Resource):
|
||||
answer_model = api.model(
|
||||
"AnswerModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="The question to answer"
|
||||
),
|
||||
"history": fields.List(
|
||||
fields.String, required=False, description="Conversation history"
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=False, description="Conversation ID"
|
||||
),
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(answer_model)
|
||||
@api.doc(description="Provide an answer based on the question and retriever")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
history = limit_chat_history(
|
||||
json.loads(data.get("history", [])), gpt_model=gpt_model
|
||||
)
|
||||
conversation_id = data.get("conversation_id")
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
agent_type = settings.AGENT_NAME
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
prompt_id = data_key.get("prompt_id", "default")
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
retriever_name = data_key.get("retriever", retriever_name)
|
||||
user_api_key = data["api_key"]
|
||||
agent_type = data_key.get("agent_type", agent_type)
|
||||
decoded_token = {"sub": data_key.get("user")}
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
retriever_name = get_retriever(data["active_docs"]) or retriever_name
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
if not decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
prompt = get_prompt(prompt_id)
|
||||
|
||||
logger.info(
|
||||
f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
agent = AgentCreator.create_agent(
|
||||
agent_type,
|
||||
endpoint="api/answer",
|
||||
llm_name=settings.LLM_NAME,
|
||||
gpt_model=gpt_model,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
chat_history=history,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
source=source,
|
||||
chat_history=history,
|
||||
prompt=prompt,
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
tool_calls = []
|
||||
stream_ended = False
|
||||
thought = ""
|
||||
|
||||
for line in complete_stream(
|
||||
question=question,
|
||||
agent=agent,
|
||||
retriever=retriever,
|
||||
conversation_id=conversation_id,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=False,
|
||||
):
|
||||
try:
|
||||
event_data = line.replace("data: ", "").strip()
|
||||
event = json.loads(event_data)
|
||||
|
||||
if event["type"] == "answer":
|
||||
response_full += event["answer"]
|
||||
elif event["type"] == "source":
|
||||
source_log_docs = event["source"]
|
||||
elif event["type"] == "tool_calls":
|
||||
tool_calls = event["tool_calls"]
|
||||
elif event["type"] == "thought":
|
||||
thought = event["thought"]
|
||||
elif event["type"] == "error":
|
||||
logger.error(f"Error from stream: {event['error']}")
|
||||
return bad_request(500, event["error"])
|
||||
elif event["type"] == "end":
|
||||
stream_ended = True
|
||||
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"Error parsing stream event: {e}, line: {line}")
|
||||
continue
|
||||
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return bad_request(500, "Stream ended unexpectedly.")
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
result = {"answer": response_full, "sources": source_log_docs}
|
||||
result["conversation_id"] = str(
|
||||
save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
decoded_token,
|
||||
api_key=user_api_key,
|
||||
)
|
||||
)
|
||||
|
||||
retriever_params = retriever.get_params()
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
"action": "api_answer",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"retriever_params": retriever_params,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return bad_request(500, str(e))
|
||||
|
||||
return make_response(result, 200)
|
||||
|
||||
|
||||
@answer_ns.route("/api/search")
|
||||
class Search(Resource):
|
||||
search_model = api.model(
|
||||
"SearchModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="The question to search"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"api_key": fields.String(
|
||||
required=False, description="API key for authentication"
|
||||
),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents for retrieval"
|
||||
),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"token_limit": fields.Integer(
|
||||
required=False, description="Limit for tokens"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(search_model)
|
||||
@api.doc(
|
||||
description="Search for relevant documents based on the question and retriever"
|
||||
)
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
required_fields = ["question"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
|
||||
try:
|
||||
question = data["question"]
|
||||
chunks = int(data.get("chunks", 2))
|
||||
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
|
||||
retriever_name = data.get("retriever", "classic")
|
||||
|
||||
if "api_key" in data:
|
||||
data_key = get_data_from_api_key(data["api_key"])
|
||||
chunks = int(data_key.get("chunks", 2))
|
||||
source = {"active_docs": data_key.get("source")}
|
||||
user_api_key = data["api_key"]
|
||||
decoded_token = {"sub": data_key.get("user")}
|
||||
|
||||
elif "active_docs" in data:
|
||||
source = {"active_docs": data["active_docs"]}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
else:
|
||||
source = {}
|
||||
user_api_key = None
|
||||
decoded_token = request.decoded_token
|
||||
|
||||
if not decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
logger.info(
|
||||
f"/api/answer - request_data: {data}, source: {source}",
|
||||
extra={"data": json.dumps({"request_data": data, "source": source})},
|
||||
)
|
||||
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
retriever_name,
|
||||
source=source,
|
||||
chat_history=[],
|
||||
prompt="default",
|
||||
chunks=chunks,
|
||||
token_limit=token_limit,
|
||||
gpt_model=gpt_model,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
docs = retriever.search(question)
|
||||
retriever_params = retriever.get_params()
|
||||
|
||||
user_logs_collection.insert_one(
|
||||
{
|
||||
"action": "api_search",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"sources": docs,
|
||||
"retriever_params": retriever_params,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
)
|
||||
|
||||
if data.get("isNoneDoc"):
|
||||
for doc in docs:
|
||||
doc["source"] = "None"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return bad_request(500, str(e))
|
||||
|
||||
return make_response(docs, 200)
|
||||
|
||||
|
||||
def get_attachments_content(attachment_ids, user):
|
||||
"""
|
||||
Retrieve content from attachment documents based on their IDs.
|
||||
|
||||
Args:
|
||||
attachment_ids (list): List of attachment document IDs
|
||||
user (str): User identifier to verify ownership
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing attachment content and metadata
|
||||
"""
|
||||
if not attachment_ids:
|
||||
return []
|
||||
|
||||
attachments = []
|
||||
for attachment_id in attachment_ids:
|
||||
try:
|
||||
attachment_doc = attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id), "user": user}
|
||||
)
|
||||
|
||||
if attachment_doc:
|
||||
attachments.append(attachment_doc)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
|
||||
)
|
||||
|
||||
return attachments
|
||||
0
application/api/answer/routes/__init__.py
Normal file
0
application/api/answer/routes/__init__.py
Normal file
122
application/api/answer/routes/answer.py
Normal file
122
application/api/answer/routes/answer.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_restx import fields, Resource
|
||||
|
||||
from application.api import api
|
||||
|
||||
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
|
||||
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@answer_ns.route("/api/answer")
|
||||
class AnswerResource(Resource, BaseAnswerResource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
Resource.__init__(self, *args, **kwargs)
|
||||
BaseAnswerResource.__init__(self)
|
||||
|
||||
answer_model = answer_ns.model(
|
||||
"AnswerModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="Question to be asked"
|
||||
),
|
||||
"history": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="Conversation history (only for new conversations)",
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=False,
|
||||
description="Existing conversation ID (loads history)",
|
||||
),
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
"save_conversation": fields.Boolean(
|
||||
required=False,
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(answer_model)
|
||||
@api.doc(description="Provide a response based on the question and retriever")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
if error := self.validate_request(data):
|
||||
return error
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
try:
|
||||
processor.initialize()
|
||||
if not processor.decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
agent = processor.create_agent()
|
||||
retriever = processor.create_retriever()
|
||||
|
||||
stream = self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
retriever=retriever,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
)
|
||||
stream_result = self.process_response_stream(stream)
|
||||
|
||||
if len(stream_result) == 7:
|
||||
(
|
||||
conversation_id,
|
||||
response,
|
||||
sources,
|
||||
tool_calls,
|
||||
thought,
|
||||
error,
|
||||
structured_info,
|
||||
) = stream_result
|
||||
else:
|
||||
conversation_id, response, sources, tool_calls, thought, error = (
|
||||
stream_result
|
||||
)
|
||||
structured_info = None
|
||||
|
||||
if error:
|
||||
return make_response({"error": error}, 400)
|
||||
result = {
|
||||
"conversation_id": conversation_id,
|
||||
"answer": response,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"thought": thought,
|
||||
}
|
||||
|
||||
if structured_info:
|
||||
result.update(structured_info)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return make_response({"error": str(e)}, 500)
|
||||
return make_response(result, 200)
|
||||
265
application/api/answer/routes/base.py
Normal file
265
application/api/answer/routes/base.py
Normal file
@@ -0,0 +1,265 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Namespace
|
||||
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.utils import check_required_fields, get_gpt_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
answer_ns = Namespace("answer", description="Answer related operations", path="/")
|
||||
|
||||
|
||||
class BaseAnswerResource:
|
||||
"""Shared base class for answer endpoints"""
|
||||
|
||||
def __init__(self):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
self.user_logs_collection = db["user_logs"]
|
||||
self.gpt_model = get_gpt_model()
|
||||
self.conversation_service = ConversationService()
|
||||
|
||||
def validate_request(
|
||||
self, data: Dict[str, Any], require_conversation_id: bool = False
|
||||
) -> Optional[Response]:
|
||||
"""Common request validation"""
|
||||
required_fields = ["question"]
|
||||
if require_conversation_id:
|
||||
required_fields.append("conversation_id")
|
||||
if missing_fields := check_required_fields(data, required_fields):
|
||||
return missing_fields
|
||||
return None
|
||||
|
||||
def complete_stream(
|
||||
self,
|
||||
question: str,
|
||||
agent: Any,
|
||||
retriever: Any,
|
||||
conversation_id: Optional[str],
|
||||
user_api_key: Optional[str],
|
||||
decoded_token: Dict[str, Any],
|
||||
isNoneDoc: bool = False,
|
||||
index: Optional[int] = None,
|
||||
should_save_conversation: bool = True,
|
||||
attachment_ids: Optional[List[str]] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Generator function that streams the complete conversation response.
|
||||
|
||||
Args:
|
||||
question: The user's question
|
||||
agent: The agent instance
|
||||
retriever: The retriever instance
|
||||
conversation_id: Existing conversation ID
|
||||
user_api_key: User's API key if any
|
||||
decoded_token: Decoded JWT token
|
||||
isNoneDoc: Flag for document-less responses
|
||||
index: Index of message to update
|
||||
should_save_conversation: Whether to persist the conversation
|
||||
attachment_ids: List of attachment IDs
|
||||
agent_id: ID of agent used
|
||||
is_shared_usage: Flag for shared agent usage
|
||||
shared_token: Token for shared agent
|
||||
|
||||
Yields:
|
||||
Server-sent event strings
|
||||
"""
|
||||
try:
|
||||
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
|
||||
for line in agent.gen(query=question, retriever=retriever):
|
||||
if "answer" in line:
|
||||
response_full += str(line["answer"])
|
||||
if line.get("structured"):
|
||||
is_structured = True
|
||||
schema_info = line.get("schema")
|
||||
structured_chunks.append(line["answer"])
|
||||
else:
|
||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "sources" in line:
|
||||
truncated_sources = []
|
||||
source_log_docs = line["sources"]
|
||||
for source in line["sources"]:
|
||||
truncated_source = source.copy()
|
||||
if "text" in truncated_source:
|
||||
truncated_source["text"] = (
|
||||
truncated_source["text"][:100].strip() + "..."
|
||||
)
|
||||
truncated_sources.append(truncated_source)
|
||||
if truncated_sources:
|
||||
data = json.dumps(
|
||||
{"type": "source", "source": truncated_sources}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
elif "tool_calls" in line:
|
||||
tool_calls = line["tool_calls"]
|
||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "thought" in line:
|
||||
thought += line["thought"]
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "type" in line:
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if is_structured and structured_chunks:
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
"answer": response_full,
|
||||
"structured": True,
|
||||
"schema": schema_info,
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_PROVIDER,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
if should_save_conversation:
|
||||
conversation_id = self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
self.gpt_model,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
else:
|
||||
conversation_id = None
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
retriever_params = retriever.get_params()
|
||||
log_data = {
|
||||
"action": "stream_answer",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"retriever_params": retriever_params,
|
||||
"attachments": attachment_ids,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
if is_structured:
|
||||
log_data["structured_output"] = True
|
||||
if schema_info:
|
||||
log_data["schema"] = schema_info
|
||||
|
||||
# clean up text fields to be no longer than 10000 characters
|
||||
for key, value in log_data.items():
|
||||
if isinstance(value, str) and len(value) > 10000:
|
||||
log_data[key] = value[:10000]
|
||||
|
||||
self.user_logs_collection.insert_one(log_data)
|
||||
|
||||
# End of stream
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
||||
data = json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Please try again later. We apologize for any inconvenience.",
|
||||
}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
def process_response_stream(self, stream):
|
||||
"""Process the stream response for non-streaming endpoint"""
|
||||
conversation_id = ""
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
tool_calls = []
|
||||
thought = ""
|
||||
stream_ended = False
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
|
||||
for line in stream:
|
||||
try:
|
||||
event_data = line.replace("data: ", "").strip()
|
||||
event = json.loads(event_data)
|
||||
|
||||
if event["type"] == "id":
|
||||
conversation_id = event["id"]
|
||||
elif event["type"] == "answer":
|
||||
response_full += event["answer"]
|
||||
elif event["type"] == "structured_answer":
|
||||
response_full = event["answer"]
|
||||
is_structured = True
|
||||
schema_info = event.get("schema")
|
||||
elif event["type"] == "source":
|
||||
source_log_docs = event["source"]
|
||||
elif event["type"] == "tool_calls":
|
||||
tool_calls = event["tool_calls"]
|
||||
elif event["type"] == "thought":
|
||||
thought = event["thought"]
|
||||
elif event["type"] == "error":
|
||||
logger.error(f"Error from stream: {event['error']}")
|
||||
return None, None, None, None, event["error"]
|
||||
elif event["type"] == "end":
|
||||
stream_ended = True
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"Error parsing stream event: {e}, line: {line}")
|
||||
continue
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return None, None, None, None, "Stream ended unexpectedly"
|
||||
|
||||
result = (
|
||||
conversation_id,
|
||||
response_full,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
thought,
|
||||
None,
|
||||
)
|
||||
|
||||
if is_structured:
|
||||
result = result + ({"structured": True, "schema": schema_info},)
|
||||
|
||||
return result
|
||||
|
||||
def error_stream_generate(self, err_response):
|
||||
data = json.dumps({"type": "error", "error": err_response})
|
||||
yield f"data: {data}\n\n"
|
||||
117
application/api/answer/routes/stream.py
Normal file
117
application/api/answer/routes/stream.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from flask import request, Response
|
||||
from flask_restx import fields, Resource
|
||||
|
||||
from application.api import api
|
||||
|
||||
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
|
||||
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@answer_ns.route("/stream")
|
||||
class StreamResource(Resource, BaseAnswerResource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
Resource.__init__(self, *args, **kwargs)
|
||||
BaseAnswerResource.__init__(self)
|
||||
|
||||
stream_model = answer_ns.model(
|
||||
"StreamModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="Question to be asked"
|
||||
),
|
||||
"history": fields.List(
|
||||
fields.String,
|
||||
required=False,
|
||||
description="Conversation history (only for new conversations)",
|
||||
),
|
||||
"conversation_id": fields.String(
|
||||
required=False,
|
||||
description="Existing conversation ID (loads history)",
|
||||
),
|
||||
"prompt_id": fields.String(
|
||||
required=False, default="default", description="Prompt ID"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
"isNoneDoc": fields.Boolean(
|
||||
required=False, description="Flag indicating if no document is used"
|
||||
),
|
||||
"index": fields.Integer(
|
||||
required=False, description="Index of the query to update"
|
||||
),
|
||||
"save_conversation": fields.Boolean(
|
||||
required=False,
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
"attachments": fields.List(
|
||||
fields.String, required=False, description="List of attachment IDs"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@api.expect(stream_model)
|
||||
@api.doc(description="Stream a response based on the question and retriever")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
if error := self.validate_request(data, "index" in data):
|
||||
return error
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
try:
|
||||
processor.initialize()
|
||||
agent = processor.create_agent()
|
||||
retriever = processor.create_retriever()
|
||||
|
||||
return Response(
|
||||
self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
retriever=retriever,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=data.get("index"),
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
attachment_ids=data.get("attachments", []),
|
||||
agent_id=data.get("agent_id"),
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except ValueError as e:
|
||||
message = "Malformed request body"
|
||||
logger.error(
|
||||
f"/stream - error: {message} - specific error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return Response(
|
||||
self.error_stream_generate(message),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return Response(
|
||||
self.error_stream_generate("Unknown error occurred"),
|
||||
status=400,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
0
application/api/answer/services/__init__.py
Normal file
0
application/api/answer/services/__init__.py
Normal file
180
application/api/answer/services/conversation_service.py
Normal file
180
application/api/answer/services/conversation_service.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
from application.core.settings import settings
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationService:
|
||||
def __init__(self):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
self.conversations_collection = db["conversations"]
|
||||
self.agents_collection = db["agents"]
|
||||
|
||||
def get_conversation(
|
||||
self, conversation_id: str, user_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieve a conversation with proper access control"""
|
||||
if not conversation_id or not user_id:
|
||||
return None
|
||||
try:
|
||||
conversation = self.conversations_collection.find_one(
|
||||
{
|
||||
"_id": ObjectId(conversation_id),
|
||||
"$or": [{"user": user_id}, {"shared_with": user_id}],
|
||||
}
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
logger.warning(
|
||||
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
|
||||
)
|
||||
return None
|
||||
conversation["_id"] = str(conversation["_id"])
|
||||
return conversation
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
def save_conversation(
|
||||
self,
|
||||
conversation_id: Optional[str],
|
||||
question: str,
|
||||
response: str,
|
||||
thought: str,
|
||||
sources: List[Dict[str, Any]],
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
llm: Any,
|
||||
gpt_model: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
index: Optional[int] = None,
|
||||
api_key: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
attachment_ids: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""Save or update a conversation in the database"""
|
||||
user_id = decoded_token.get("sub")
|
||||
if not user_id:
|
||||
raise ValueError("User ID not found in token")
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# clean up in sources array such that we save max 1k characters for text part
|
||||
for source in sources:
|
||||
if "text" in source and isinstance(source["text"], str):
|
||||
source["text"] = source["text"][:1000]
|
||||
|
||||
if conversation_id is not None and index is not None:
|
||||
# Update existing conversation with new query
|
||||
|
||||
result = self.conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(conversation_id),
|
||||
"user": user_id,
|
||||
f"queries.{index}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{index}.prompt": question,
|
||||
f"queries.{index}.response": response,
|
||||
f"queries.{index}.thought": thought,
|
||||
f"queries.{index}.sources": sources,
|
||||
f"queries.{index}.tool_calls": tool_calls,
|
||||
f"queries.{index}.timestamp": current_time,
|
||||
f"queries.{index}.attachments": attachment_ids,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if result.matched_count == 0:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
self.conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(conversation_id),
|
||||
"user": user_id,
|
||||
f"queries.{index}": {"$exists": True},
|
||||
},
|
||||
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
||||
)
|
||||
return conversation_id
|
||||
elif conversation_id:
|
||||
# Append new message to existing conversation
|
||||
|
||||
result = self.conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id), "user": user_id},
|
||||
{
|
||||
"$push": {
|
||||
"queries": {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if result.matched_count == 0:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
return conversation_id
|
||||
else:
|
||||
# Create new conversation
|
||||
|
||||
messages_summary = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Summarise following conversation in no more than 3 "
|
||||
"words, respond ONLY with the summary, use the same "
|
||||
"language as the user query",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Summarise following conversation in no more than 3 words, "
|
||||
"respond ONLY with the summary, use the same language as the "
|
||||
"user query \n\nUser: " + question + "\n\n" + "AI: " + response,
|
||||
},
|
||||
]
|
||||
|
||||
completion = llm.gen(
|
||||
model=gpt_model, messages=messages_summary, max_tokens=30
|
||||
)
|
||||
|
||||
conversation_data = {
|
||||
"user": user_id,
|
||||
"date": current_time,
|
||||
"name": completion,
|
||||
"queries": [
|
||||
{
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
if api_key:
|
||||
if agent_id:
|
||||
conversation_data["agent_id"] = agent_id
|
||||
if is_shared_usage:
|
||||
conversation_data["is_shared_usage"] = is_shared_usage
|
||||
conversation_data["shared_token"] = shared_token
|
||||
agent = self.agents_collection.find_one({"key": api_key})
|
||||
if agent:
|
||||
conversation_data["api_key"] = agent["key"]
|
||||
result = self.conversations_collection.insert_one(conversation_data)
|
||||
return str(result.inserted_id)
|
||||
353
application/api/answer/services/stream_processor.py
Normal file
353
application/api/answer/services/stream_processor.py
Normal file
@@ -0,0 +1,353 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from bson.dbref import DBRef
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import get_gpt_model, limit_chat_history
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_prompt(prompt_id: str, prompts_collection=None) -> str:
|
||||
"""
|
||||
Get a prompt by preset name or MongoDB ID
|
||||
"""
|
||||
current_dir = Path(__file__).resolve().parents[3]
|
||||
prompts_dir = current_dir / "prompts"
|
||||
|
||||
preset_mapping = {
|
||||
"default": "chat_combine_default.txt",
|
||||
"creative": "chat_combine_creative.txt",
|
||||
"strict": "chat_combine_strict.txt",
|
||||
"reduce": "chat_reduce_prompt.txt",
|
||||
}
|
||||
|
||||
if prompt_id in preset_mapping:
|
||||
file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
|
||||
try:
|
||||
with open(file_path, "r") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"Prompt file not found: {file_path}")
|
||||
try:
|
||||
if prompts_collection is None:
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
prompts_collection = db["prompts"]
|
||||
prompt_doc = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
|
||||
if not prompt_doc:
|
||||
raise ValueError(f"Prompt with ID {prompt_id} not found")
|
||||
return prompt_doc["content"]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid prompt ID: {prompt_id}") from e
|
||||
|
||||
|
||||
class StreamProcessor:
|
||||
def __init__(
|
||||
self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]]
|
||||
):
|
||||
mongo = MongoDB.get_client()
|
||||
self.db = mongo[settings.MONGO_DB_NAME]
|
||||
self.agents_collection = self.db["agents"]
|
||||
self.attachments_collection = self.db["attachments"]
|
||||
self.prompts_collection = self.db["prompts"]
|
||||
|
||||
self.data = request_data
|
||||
self.decoded_token = decoded_token
|
||||
self.initial_user_id = (
|
||||
self.decoded_token.get("sub") if self.decoded_token is not None else None
|
||||
)
|
||||
self.conversation_id = self.data.get("conversation_id")
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
self.attachments = []
|
||||
self.history = []
|
||||
self.agent_config = {}
|
||||
self.retriever_config = {}
|
||||
self.is_shared_usage = False
|
||||
self.shared_token = None
|
||||
self.gpt_model = get_gpt_model()
|
||||
self.conversation_service = ConversationService()
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize all required components for processing"""
|
||||
self._configure_agent()
|
||||
self._configure_source()
|
||||
self._configure_retriever()
|
||||
self._configure_agent()
|
||||
self._load_conversation_history()
|
||||
self._process_attachments()
|
||||
|
||||
def _load_conversation_history(self):
|
||||
"""Load conversation history either from DB or request"""
|
||||
if self.conversation_id and self.initial_user_id:
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
self.conversation_id, self.initial_user_id
|
||||
)
|
||||
if not conversation:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
self.history = [
|
||||
{"prompt": query["prompt"], "response": query["response"]}
|
||||
for query in conversation.get("queries", [])
|
||||
]
|
||||
else:
|
||||
self.history = limit_chat_history(
|
||||
json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
|
||||
)
|
||||
|
||||
def _process_attachments(self):
|
||||
"""Process any attachments in the request"""
|
||||
attachment_ids = self.data.get("attachments", [])
|
||||
self.attachments = self._get_attachments_content(
|
||||
attachment_ids, self.initial_user_id
|
||||
)
|
||||
|
||||
def _get_attachments_content(self, attachment_ids, user_id):
|
||||
"""
|
||||
Retrieve content from attachment documents based on their IDs.
|
||||
"""
|
||||
if not attachment_ids:
|
||||
return []
|
||||
attachments = []
|
||||
for attachment_id in attachment_ids:
|
||||
try:
|
||||
attachment_doc = self.attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id), "user": user_id}
|
||||
)
|
||||
|
||||
if attachment_doc:
|
||||
attachments.append(attachment_doc)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
|
||||
)
|
||||
return attachments
|
||||
|
||||
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||||
"""Get API key for agent with access control"""
|
||||
if not agent_id:
|
||||
return None, False, None
|
||||
try:
|
||||
agent = self.agents_collection.find_one({"_id": ObjectId(agent_id)})
|
||||
if agent is None:
|
||||
raise Exception("Agent not found")
|
||||
is_owner = agent.get("user") == user_id
|
||||
is_shared_with_user = agent.get(
|
||||
"shared_publicly", False
|
||||
) or user_id in agent.get("shared_with", [])
|
||||
|
||||
if not (is_owner or is_shared_with_user):
|
||||
raise Exception("Unauthorized access to the agent")
|
||||
if is_owner:
|
||||
self.agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id)},
|
||||
{
|
||||
"$set": {
|
||||
"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)
|
||||
}
|
||||
},
|
||||
)
|
||||
return str(agent["key"]), not is_owner, agent.get("shared_token")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
|
||||
data = self.agents_collection.find_one({"key": api_key})
|
||||
if not data:
|
||||
raise Exception("Invalid API Key, please generate a new key", 401)
|
||||
source = data.get("source")
|
||||
if isinstance(source, DBRef):
|
||||
source_doc = self.db.dereference(source)
|
||||
if source_doc:
|
||||
data["source"] = str(source_doc["_id"])
|
||||
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
||||
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||
else:
|
||||
data["source"] = None
|
||||
elif source == "default":
|
||||
data["source"] = "default"
|
||||
else:
|
||||
data["source"] = None
|
||||
# Handle multiple sources
|
||||
|
||||
sources = data.get("sources", [])
|
||||
if sources and isinstance(sources, list):
|
||||
sources_list = []
|
||||
for i, source_ref in enumerate(sources):
|
||||
if source_ref == "default":
|
||||
processed_source = {
|
||||
"id": "default",
|
||||
"retriever": "classic",
|
||||
"chunks": data.get("chunks", "2"),
|
||||
}
|
||||
sources_list.append(processed_source)
|
||||
elif isinstance(source_ref, DBRef):
|
||||
source_doc = self.db.dereference(source_ref)
|
||||
if source_doc:
|
||||
processed_source = {
|
||||
"id": str(source_doc["_id"]),
|
||||
"retriever": source_doc.get("retriever", "classic"),
|
||||
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
|
||||
}
|
||||
sources_list.append(processed_source)
|
||||
data["sources"] = sources_list
|
||||
else:
|
||||
data["sources"] = []
|
||||
return data
|
||||
|
||||
def _configure_source(self):
|
||||
"""Configure the source based on agent data"""
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
|
||||
if api_key:
|
||||
agent_data = self._get_data_from_api_key(api_key)
|
||||
|
||||
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
|
||||
source_ids = [
|
||||
source["id"] for source in agent_data["sources"] if source.get("id")
|
||||
]
|
||||
if source_ids:
|
||||
self.source = {"active_docs": source_ids}
|
||||
else:
|
||||
self.source = {}
|
||||
self.all_sources = agent_data["sources"]
|
||||
elif agent_data.get("source"):
|
||||
self.source = {"active_docs": agent_data["source"]}
|
||||
self.all_sources = [
|
||||
{
|
||||
"id": agent_data["source"],
|
||||
"retriever": agent_data.get("retriever", "classic"),
|
||||
}
|
||||
]
|
||||
else:
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
return
|
||||
if "active_docs" in self.data:
|
||||
self.source = {"active_docs": self.data["active_docs"]}
|
||||
return
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
|
||||
def _configure_agent(self):
|
||||
"""Configure the agent based on request data"""
|
||||
agent_id = self.data.get("agent_id")
|
||||
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
|
||||
agent_id, self.initial_user_id
|
||||
)
|
||||
|
||||
api_key = self.data.get("api_key")
|
||||
if api_key:
|
||||
data_key = self._get_data_from_api_key(api_key)
|
||||
self.agent_config.update(
|
||||
{
|
||||
"prompt_id": data_key.get("prompt_id", "default"),
|
||||
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||
"user_api_key": api_key,
|
||||
"json_schema": data_key.get("json_schema"),
|
||||
}
|
||||
)
|
||||
self.initial_user_id = data_key.get("user")
|
||||
self.decoded_token = {"sub": data_key.get("user")}
|
||||
if data_key.get("source"):
|
||||
self.source = {"active_docs": data_key["source"]}
|
||||
if data_key.get("retriever"):
|
||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||
if data_key.get("chunks") is not None:
|
||||
try:
|
||||
self.retriever_config["chunks"] = int(data_key["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
||||
)
|
||||
self.retriever_config["chunks"] = 2
|
||||
elif self.agent_key:
|
||||
data_key = self._get_data_from_api_key(self.agent_key)
|
||||
self.agent_config.update(
|
||||
{
|
||||
"prompt_id": data_key.get("prompt_id", "default"),
|
||||
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||
"user_api_key": self.agent_key,
|
||||
"json_schema": data_key.get("json_schema"),
|
||||
}
|
||||
)
|
||||
self.decoded_token = (
|
||||
self.decoded_token
|
||||
if self.is_shared_usage
|
||||
else {"sub": data_key.get("user")}
|
||||
)
|
||||
if data_key.get("source"):
|
||||
self.source = {"active_docs": data_key["source"]}
|
||||
if data_key.get("retriever"):
|
||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||
if data_key.get("chunks") is not None:
|
||||
try:
|
||||
self.retriever_config["chunks"] = int(data_key["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
||||
)
|
||||
self.retriever_config["chunks"] = 2
|
||||
else:
|
||||
self.agent_config.update(
|
||||
{
|
||||
"prompt_id": self.data.get("prompt_id", "default"),
|
||||
"agent_type": settings.AGENT_NAME,
|
||||
"user_api_key": None,
|
||||
"json_schema": None,
|
||||
}
|
||||
)
|
||||
|
||||
def _configure_retriever(self):
|
||||
"""Configure the retriever based on request data"""
|
||||
self.retriever_config = {
|
||||
"retriever_name": self.data.get("retriever", "classic"),
|
||||
"chunks": int(self.data.get("chunks", 2)),
|
||||
"token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
|
||||
}
|
||||
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||
self.retriever_config["chunks"] = 0
|
||||
|
||||
def create_agent(self):
|
||||
"""Create and return the configured agent"""
|
||||
return AgentCreator.create_agent(
|
||||
self.agent_config["agent_type"],
|
||||
endpoint="stream",
|
||||
llm_name=settings.LLM_PROVIDER,
|
||||
gpt_model=self.gpt_model,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
|
||||
chat_history=self.history,
|
||||
decoded_token=self.decoded_token,
|
||||
attachments=self.attachments,
|
||||
json_schema=self.agent_config.get("json_schema"),
|
||||
)
|
||||
|
||||
def create_retriever(self):
|
||||
"""Create and return the configured retriever"""
|
||||
return RetrieverCreator.create_retriever(
|
||||
self.retriever_config["retriever_name"],
|
||||
source=self.source,
|
||||
chat_history=self.history,
|
||||
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
|
||||
chunks=self.retriever_config["chunks"],
|
||||
token_limit=self.retriever_config["token_limit"],
|
||||
gpt_model=self.gpt_model,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
695
application/api/connector/routes.py
Normal file
695
application/api/connector/routes.py
Normal file
@@ -0,0 +1,695 @@
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import (
|
||||
Blueprint,
|
||||
current_app,
|
||||
jsonify,
|
||||
make_response,
|
||||
request
|
||||
)
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
|
||||
from application.api.user.tasks import (
|
||||
ingest_connector_task,
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.api import api
|
||||
|
||||
|
||||
from application.utils import (
|
||||
check_required_fields
|
||||
)
|
||||
|
||||
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
|
||||
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
sources_collection = db["sources"]
|
||||
sessions_collection = db["connector_sessions"]
|
||||
|
||||
connector = Blueprint("connector", __name__)
|
||||
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
|
||||
api.add_namespace(connectors_ns)
|
||||
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/upload")
|
||||
class UploadConnector(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ConnectorUploadModel",
|
||||
{
|
||||
"user": fields.String(required=True, description="User ID"),
|
||||
"source": fields.String(
|
||||
required=True, description="Source type (google_drive, github, etc.)"
|
||||
),
|
||||
"name": fields.String(required=True, description="Job name"),
|
||||
"data": fields.String(required=True, description="Configuration data"),
|
||||
"repo_url": fields.String(description="GitHub repository URL"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads connector source for vectorization",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.form
|
||||
required_fields = ["user", "source", "name", "data"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
config = json.loads(data["data"])
|
||||
source_data = None
|
||||
sync_frequency = config.get("sync_frequency", "never")
|
||||
|
||||
if data["source"] == "github":
|
||||
source_data = config.get("repo_url")
|
||||
elif data["source"] in ["crawler", "url"]:
|
||||
source_data = config.get("url")
|
||||
elif data["source"] == "reddit":
|
||||
source_data = config
|
||||
elif data["source"] in ConnectorCreator.get_supported_connectors():
|
||||
session_token = config.get("session_token")
|
||||
if not session_token:
|
||||
return make_response(jsonify({
|
||||
"success": False,
|
||||
"error": f"Missing session_token in {data['source']} configuration"
|
||||
}), 400)
|
||||
|
||||
file_ids = config.get("file_ids", [])
|
||||
if isinstance(file_ids, str):
|
||||
file_ids = [id.strip() for id in file_ids.split(',') if id.strip()]
|
||||
elif not isinstance(file_ids, list):
|
||||
file_ids = []
|
||||
|
||||
folder_ids = config.get("folder_ids", [])
|
||||
if isinstance(folder_ids, str):
|
||||
folder_ids = [id.strip() for id in folder_ids.split(',') if id.strip()]
|
||||
elif not isinstance(folder_ids, list):
|
||||
folder_ids = []
|
||||
|
||||
config["file_ids"] = file_ids
|
||||
config["folder_ids"] = folder_ids
|
||||
|
||||
task = ingest_connector_task.delay(
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
source_type=data["source"],
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=config.get("recursive", False),
|
||||
retriever=config.get("retriever", "classic"),
|
||||
sync_frequency=sync_frequency
|
||||
)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
task = ingest_connector_task.delay(
|
||||
source_data=source_data,
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
loader=data["source"],
|
||||
sync_frequency=sync_frequency
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error uploading connector source: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/task_status")
|
||||
class ConnectorTaskStatus(Resource):
|
||||
task_status_model = api.model(
|
||||
"ConnectorTaskStatusModel",
|
||||
{"task_id": fields.String(required=True, description="Task ID")},
|
||||
)
|
||||
|
||||
@api.expect(task_status_model)
|
||||
@api.doc(description="Get connector task status")
|
||||
def get(self):
|
||||
task_id = request.args.get("task_id")
|
||||
if not task_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Task ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
from application.celery_init import celery
|
||||
|
||||
task = celery.AsyncResult(task_id)
|
||||
task_meta = task.info
|
||||
print(f"Task status: {task.status}")
|
||||
if not isinstance(
|
||||
task_meta, (dict, list, str, int, float, bool, type(None))
|
||||
):
|
||||
task_meta = str(task_meta)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"status": task.status, "result": task_meta}), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/sources")
|
||||
class ConnectorSources(Resource):
|
||||
@api.doc(description="Get connector sources")
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
sources = sources_collection.find({"user": user, "type": "connector:file"}).sort("date", -1)
|
||||
connector_sources = []
|
||||
for source in sources:
|
||||
connector_sources.append({
|
||||
"id": str(source["_id"]),
|
||||
"name": source.get("name"),
|
||||
"date": source.get("date"),
|
||||
"type": source.get("type"),
|
||||
"source": source.get("source"),
|
||||
"tokens": source.get("tokens", ""),
|
||||
"retriever": source.get("retriever", "classic"),
|
||||
"syncFrequency": source.get("sync_frequency", ""),
|
||||
})
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving connector sources: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(connector_sources), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/delete")
|
||||
class DeleteConnectorSource(Resource):
|
||||
@api.doc(
|
||||
description="Delete a connector source",
|
||||
params={"source_id": "The source ID to delete"},
|
||||
)
|
||||
def delete(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
source_id = request.args.get("source_id")
|
||||
if not source_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "source_id is required"}), 400
|
||||
)
|
||||
try:
|
||||
result = sources_collection.delete_one(
|
||||
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
if result.deleted_count == 0:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Source not found"}), 404
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting connector source: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/auth")
|
||||
class ConnectorAuth(Resource):
|
||||
@api.doc(description="Get connector OAuth authorization URL", params={"provider": "Connector provider (e.g., google_drive)"})
|
||||
def get(self):
|
||||
try:
|
||||
provider = request.args.get('provider') or request.args.get('source')
|
||||
if not provider:
|
||||
return make_response(jsonify({"success": False, "error": "Missing provider"}), 400)
|
||||
|
||||
if not ConnectorCreator.is_supported(provider):
|
||||
return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user_id = decoded_token.get('sub')
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
result = sessions_collection.insert_one({
|
||||
"provider": provider,
|
||||
"user": user_id,
|
||||
"status": "pending",
|
||||
"created_at": now
|
||||
})
|
||||
state_dict = {
|
||||
"provider": provider,
|
||||
"object_id": str(result.inserted_id)
|
||||
}
|
||||
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
|
||||
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
authorization_url = auth.get_authorization_url(state=state)
|
||||
return make_response(jsonify({
|
||||
"success": True,
|
||||
"authorization_url": authorization_url,
|
||||
"state": state
|
||||
}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error generating connector auth URL: {e}")
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/callback")
|
||||
class ConnectorsCallback(Resource):
|
||||
@api.doc(description="Handle OAuth callback for external connectors")
|
||||
def get(self):
|
||||
"""Handle OAuth callback for external connectors"""
|
||||
try:
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from flask import request, redirect
|
||||
|
||||
authorization_code = request.args.get('code')
|
||||
state = request.args.get('state')
|
||||
error = request.args.get('error')
|
||||
|
||||
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
provider = state_dict["provider"]
|
||||
state_object_id = state_dict["object_id"]
|
||||
|
||||
if error:
|
||||
if error == "access_denied":
|
||||
return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}")
|
||||
else:
|
||||
current_app.logger.warning(f"OAuth error in callback: {error}")
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
||||
|
||||
if not authorization_code:
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
||||
|
||||
try:
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
token_info = auth.exchange_code_for_tokens(authorization_code)
|
||||
|
||||
session_token = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
credentials = auth.create_credentials_from_token_info(token_info)
|
||||
service = auth.build_drive_service(credentials)
|
||||
user_info = service.about().get(fields="user").execute()
|
||||
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
|
||||
except Exception as e:
|
||||
current_app.logger.warning(f"Could not get user info: {e}")
|
||||
user_email = 'Connected User'
|
||||
|
||||
sanitized_token_info = {
|
||||
"access_token": token_info.get("access_token"),
|
||||
"refresh_token": token_info.get("refresh_token"),
|
||||
"token_uri": token_info.get("token_uri"),
|
||||
"expiry": token_info.get("expiry")
|
||||
}
|
||||
|
||||
sessions_collection.find_one_and_update(
|
||||
{"_id": ObjectId(state_object_id), "provider": provider},
|
||||
{
|
||||
"$set": {
|
||||
"session_token": session_token,
|
||||
"token_info": sanitized_token_info,
|
||||
"user_email": user_email,
|
||||
"status": "authorized"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Redirect to success page with session token and user email
|
||||
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error handling connector callback: {e}")
|
||||
return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.")
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/refresh")
|
||||
class ConnectorRefresh(Resource):
|
||||
@api.expect(api.model("ConnectorRefreshModel", {"provider": fields.String(required=True), "refresh_token": fields.String(required=True)}))
|
||||
@api.doc(description="Refresh connector access token")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
refresh_token = data.get('refresh_token')
|
||||
|
||||
if not provider or not refresh_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and refresh_token are required"}), 400)
|
||||
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
token_info = auth.refresh_access_token(refresh_token)
|
||||
return make_response(jsonify({"success": True, "token_info": token_info}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error refreshing token for connector: {e}")
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/files")
|
||||
class ConnectorFiles(Resource):
|
||||
@api.expect(api.model("ConnectorFilesModel", {
|
||||
"provider": fields.String(required=True),
|
||||
"session_token": fields.String(required=True),
|
||||
"folder_id": fields.String(required=False),
|
||||
"limit": fields.Integer(required=False),
|
||||
"page_token": fields.String(required=False),
|
||||
"search_query": fields.String(required=False)
|
||||
}))
|
||||
@api.doc(description="List files from a connector provider (supports pagination and search)")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
session_token = data.get('session_token')
|
||||
folder_id = data.get('folder_id')
|
||||
limit = data.get('limit', 10)
|
||||
page_token = data.get('page_token')
|
||||
search_query = data.get('search_query')
|
||||
|
||||
if not provider or not session_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user = decoded_token.get('sub')
|
||||
session = sessions_collection.find_one({"session_token": session_token, "user": user})
|
||||
if not session:
|
||||
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
|
||||
|
||||
loader = ConnectorCreator.create_connector(provider, session_token)
|
||||
input_config = {
|
||||
'limit': limit,
|
||||
'list_only': True,
|
||||
'session_token': session_token,
|
||||
'folder_id': folder_id,
|
||||
'page_token': page_token
|
||||
}
|
||||
if search_query:
|
||||
input_config['search_query'] = search_query
|
||||
|
||||
documents = loader.load_data(input_config)
|
||||
|
||||
files = []
|
||||
for doc in documents[:limit]:
|
||||
metadata = doc.extra_info
|
||||
modified_time = metadata.get('modified_time')
|
||||
if modified_time:
|
||||
date_part = modified_time.split('T')[0]
|
||||
time_part = modified_time.split('T')[1].split('.')[0].split('Z')[0]
|
||||
formatted_time = f"{date_part} {time_part}"
|
||||
else:
|
||||
formatted_time = None
|
||||
|
||||
files.append({
|
||||
'id': doc.doc_id,
|
||||
'name': metadata.get('file_name', 'Unknown File'),
|
||||
'type': metadata.get('mime_type', 'unknown'),
|
||||
'size': metadata.get('size', None),
|
||||
'modifiedTime': formatted_time,
|
||||
'isFolder': metadata.get('is_folder', False)
|
||||
})
|
||||
|
||||
next_token = getattr(loader, 'next_page_token', None)
|
||||
has_more = bool(next_token)
|
||||
|
||||
return make_response(jsonify({
|
||||
"success": True,
|
||||
"files": files,
|
||||
"total": len(files),
|
||||
"next_page_token": next_token,
|
||||
"has_more": has_more
|
||||
}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error loading connector files: {e}")
|
||||
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/validate-session")
|
||||
class ConnectorValidateSession(Resource):
|
||||
@api.expect(api.model("ConnectorValidateSessionModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True)}))
|
||||
@api.doc(description="Validate connector session token and return user info and access token")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
session_token = data.get('session_token')
|
||||
if not provider or not session_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user = decoded_token.get('sub')
|
||||
|
||||
session = sessions_collection.find_one({"session_token": session_token, "user": user})
|
||||
if not session or "token_info" not in session:
|
||||
return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
|
||||
|
||||
token_info = session["token_info"]
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
is_expired = auth.is_token_expired(token_info)
|
||||
|
||||
if is_expired and token_info.get('refresh_token'):
|
||||
try:
|
||||
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
|
||||
sanitized_token_info = {
|
||||
"access_token": refreshed_token_info.get("access_token"),
|
||||
"refresh_token": refreshed_token_info.get("refresh_token"),
|
||||
"token_uri": refreshed_token_info.get("token_uri"),
|
||||
"expiry": refreshed_token_info.get("expiry")
|
||||
}
|
||||
sessions_collection.update_one(
|
||||
{"session_token": session_token},
|
||||
{"$set": {"token_info": sanitized_token_info}}
|
||||
)
|
||||
token_info = sanitized_token_info
|
||||
is_expired = False
|
||||
except Exception as refresh_error:
|
||||
current_app.logger.error(f"Failed to refresh token: {refresh_error}")
|
||||
|
||||
if is_expired:
|
||||
return make_response(jsonify({
|
||||
"success": False,
|
||||
"expired": True,
|
||||
"error": "Session token has expired. Please reconnect."
|
||||
}), 401)
|
||||
|
||||
return make_response(jsonify({
|
||||
"success": True,
|
||||
"expired": False,
|
||||
"user_email": session.get('user_email', 'Connected User'),
|
||||
"access_token": token_info.get('access_token')
|
||||
}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error validating connector session: {e}")
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/disconnect")
|
||||
class ConnectorDisconnect(Resource):
|
||||
@api.expect(api.model("ConnectorDisconnectModel", {"provider": fields.String(required=True), "session_token": fields.String(required=False)}))
|
||||
@api.doc(description="Disconnect a connector session")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
session_token = data.get('session_token')
|
||||
if not provider:
|
||||
return make_response(jsonify({"success": False, "error": "provider is required"}), 400)
|
||||
|
||||
|
||||
if session_token:
|
||||
sessions_collection.delete_one({"session_token": session_token})
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error disconnecting connector session: {e}")
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/sync")
|
||||
class ConnectorSync(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ConnectorSyncModel",
|
||||
{
|
||||
"source_id": fields.String(required=True, description="Source ID to sync"),
|
||||
"session_token": fields.String(required=True, description="Authentication token")
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(description="Sync connector source to check for modifications")
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
|
||||
try:
|
||||
data = request.get_json()
|
||||
source_id = data.get('source_id')
|
||||
session_token = data.get('session_token')
|
||||
|
||||
if not all([source_id, session_token]):
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "source_id and session_token are required"
|
||||
}),
|
||||
400
|
||||
)
|
||||
source = sources_collection.find_one({"_id": ObjectId(source_id)})
|
||||
if not source:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Source not found"
|
||||
}),
|
||||
404
|
||||
)
|
||||
|
||||
if source.get('user') != decoded_token.get('sub'):
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Unauthorized access to source"
|
||||
}),
|
||||
403
|
||||
)
|
||||
|
||||
remote_data = {}
|
||||
try:
|
||||
if source.get('remote_data'):
|
||||
remote_data = json.loads(source.get('remote_data'))
|
||||
except json.JSONDecodeError:
|
||||
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
|
||||
remote_data = {}
|
||||
|
||||
source_type = remote_data.get('provider')
|
||||
if not source_type:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Source provider not found in remote_data"
|
||||
}),
|
||||
400
|
||||
)
|
||||
|
||||
# Extract configuration from remote_data
|
||||
file_ids = remote_data.get('file_ids', [])
|
||||
folder_ids = remote_data.get('folder_ids', [])
|
||||
recursive = remote_data.get('recursive', True)
|
||||
|
||||
# Start the sync task
|
||||
task = ingest_connector_task.delay(
|
||||
job_name=source.get('name'),
|
||||
user=decoded_token.get('sub'),
|
||||
source_type=source_type,
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=recursive,
|
||||
retriever=source.get('retriever', 'classic'),
|
||||
operation_mode="sync",
|
||||
doc_id=source_id,
|
||||
sync_frequency=source.get('sync_frequency', 'never')
|
||||
)
|
||||
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": True,
|
||||
"task_id": task.id
|
||||
}),
|
||||
200
|
||||
)
|
||||
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error syncing connector source: {err}",
|
||||
exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": str(err)
|
||||
}),
|
||||
400
|
||||
)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/callback-status")
|
||||
class ConnectorCallbackStatus(Resource):
|
||||
@api.doc(description="Return HTML page with connector authentication status")
|
||||
def get(self):
|
||||
"""Return HTML page with connector authentication status"""
|
||||
try:
|
||||
status = request.args.get('status', 'error')
|
||||
message = request.args.get('message', '')
|
||||
provider = request.args.get('provider', 'connector')
|
||||
session_token = request.args.get('session_token', '')
|
||||
user_email = request.args.get('user_email', '')
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>{provider.replace('_', ' ').title()} Authentication</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
|
||||
.container {{ max-width: 600px; margin: 0 auto; }}
|
||||
.success {{ color: #4CAF50; }}
|
||||
.error {{ color: #F44336; }}
|
||||
.cancelled {{ color: #FF9800; }}
|
||||
</style>
|
||||
<script>
|
||||
window.onload = function() {{
|
||||
const status = "{status}";
|
||||
const sessionToken = "{session_token}";
|
||||
const userEmail = "{user_email}";
|
||||
|
||||
if (status === "success" && window.opener) {{
|
||||
window.opener.postMessage({{
|
||||
type: '{provider}_auth_success',
|
||||
session_token: sessionToken,
|
||||
user_email: userEmail
|
||||
}}, '*');
|
||||
|
||||
setTimeout(() => window.close(), 3000);
|
||||
}} else if (status === "cancelled" || status === "error") {{
|
||||
setTimeout(() => window.close(), 3000);
|
||||
}}
|
||||
}};
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h2>{provider.replace('_', ' ').title()} Authentication</h2>
|
||||
<div class="{status}">
|
||||
<p>{message}</p>
|
||||
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
|
||||
</div>
|
||||
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return make_response(html_content, 200, {'Content-Type': 'text/html'})
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error rendering callback status page: {e}")
|
||||
return make_response("Authentication error occurred", 500, {'Content-Type': 'text/html'})
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import datetime
|
||||
import json
|
||||
from flask import Blueprint, request, send_from_directory
|
||||
from werkzeug.utils import secure_filename
|
||||
from bson.objectid import ObjectId
|
||||
@@ -37,16 +38,28 @@ def upload_index_files():
|
||||
"""Upload two files(index.faiss, index.pkl) to the user's folder."""
|
||||
if "user" not in request.form:
|
||||
return {"status": "no user"}
|
||||
user = secure_filename(request.form["user"])
|
||||
user = request.form["user"]
|
||||
if "name" not in request.form:
|
||||
return {"status": "no name"}
|
||||
job_name = secure_filename(request.form["name"])
|
||||
tokens = secure_filename(request.form["tokens"])
|
||||
retriever = secure_filename(request.form["retriever"])
|
||||
id = secure_filename(request.form["id"])
|
||||
type = secure_filename(request.form["type"])
|
||||
job_name = request.form["name"]
|
||||
tokens = request.form["tokens"]
|
||||
retriever = request.form["retriever"]
|
||||
id = request.form["id"]
|
||||
type = request.form["type"]
|
||||
remote_data = request.form["remote_data"] if "remote_data" in request.form else None
|
||||
sync_frequency = secure_filename(request.form["sync_frequency"]) if "sync_frequency" in request.form else None
|
||||
sync_frequency = request.form["sync_frequency"] if "sync_frequency" in request.form else None
|
||||
|
||||
file_path = request.form.get("file_path")
|
||||
directory_structure = request.form.get("directory_structure")
|
||||
|
||||
if directory_structure:
|
||||
try:
|
||||
directory_structure = json.loads(directory_structure)
|
||||
except Exception:
|
||||
logger.error("Error parsing directory_structure")
|
||||
directory_structure = {}
|
||||
else:
|
||||
directory_structure = {}
|
||||
|
||||
storage = StorageCreator.get_storage()
|
||||
index_base_path = f"indexes/{id}"
|
||||
@@ -64,10 +77,13 @@ def upload_index_files():
|
||||
file_pkl = request.files["file_pkl"]
|
||||
if file_pkl.filename == "":
|
||||
return {"status": "no file name"}
|
||||
|
||||
|
||||
# Save index files to storage
|
||||
storage.save_file(file_faiss, f"{index_base_path}/index.faiss")
|
||||
storage.save_file(file_pkl, f"{index_base_path}/index.pkl")
|
||||
faiss_storage_path = f"{index_base_path}/index.faiss"
|
||||
pkl_storage_path = f"{index_base_path}/index.pkl"
|
||||
storage.save_file(file_faiss, faiss_storage_path)
|
||||
storage.save_file(file_pkl, pkl_storage_path)
|
||||
|
||||
|
||||
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
|
||||
if existing_entry:
|
||||
@@ -85,6 +101,8 @@ def upload_index_files():
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
},
|
||||
)
|
||||
@@ -102,6 +120,8 @@ def upload_index_files():
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,14 +5,16 @@ from application.worker import (
|
||||
agent_webhook_worker,
|
||||
attachment_worker,
|
||||
ingest_worker,
|
||||
mcp_oauth,
|
||||
mcp_oauth_status,
|
||||
remote_worker,
|
||||
sync_worker,
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def ingest(self, directory, formats, name_job, filename, user):
|
||||
resp = ingest_worker(self, directory, formats, name_job, filename, user)
|
||||
def ingest(self, directory, formats, job_name, user, file_path, filename):
|
||||
resp = ingest_worker(self, directory, formats, job_name, file_path, filename, user)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -22,6 +24,14 @@ def ingest_remote(self, source_data, job_name, user, loader):
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def reingest_source_task(self, source_id, user):
|
||||
from application.worker import reingest_source_worker
|
||||
|
||||
resp = reingest_source_worker(self, source_id, user)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def schedule_syncs(self, frequency):
|
||||
resp = sync_worker(self, frequency)
|
||||
@@ -40,6 +50,40 @@ def process_agent_webhook(self, agent_id, payload):
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def ingest_connector_task(
|
||||
self,
|
||||
job_name,
|
||||
user,
|
||||
source_type,
|
||||
session_token=None,
|
||||
file_ids=None,
|
||||
folder_ids=None,
|
||||
recursive=True,
|
||||
retriever="classic",
|
||||
operation_mode="upload",
|
||||
doc_id=None,
|
||||
sync_frequency="never",
|
||||
):
|
||||
from application.worker import ingest_connector
|
||||
|
||||
resp = ingest_connector(
|
||||
self,
|
||||
job_name,
|
||||
user,
|
||||
source_type,
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=recursive,
|
||||
retriever=retriever,
|
||||
operation_mode=operation_mode,
|
||||
doc_id=doc_id,
|
||||
sync_frequency=sync_frequency,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.on_after_configure.connect
|
||||
def setup_periodic_tasks(sender, **kwargs):
|
||||
sender.add_periodic_task(
|
||||
@@ -54,3 +98,15 @@ def setup_periodic_tasks(sender, **kwargs):
|
||||
timedelta(days=30),
|
||||
schedule_syncs.s("monthly"),
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def mcp_oauth_task(self, config, user):
|
||||
resp = mcp_oauth(self, config, user)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def mcp_oauth_status_task(self, task_id):
|
||||
resp = mcp_oauth_status(self, task_id)
|
||||
return resp
|
||||
|
||||
@@ -12,25 +12,26 @@ from application.core.logging_config import setup_logging
|
||||
|
||||
setup_logging()
|
||||
|
||||
from application.api.answer.routes import answer # noqa: E402
|
||||
from application.api import api # noqa: E402
|
||||
from application.api.answer import answer # noqa: E402
|
||||
from application.api.internal.routes import internal # noqa: E402
|
||||
from application.api.user.routes import user # noqa: E402
|
||||
from application.api.connector.routes import connector # noqa: E402
|
||||
from application.celery_init import celery # noqa: E402
|
||||
from application.core.settings import settings # noqa: E402
|
||||
from application.extensions import api # noqa: E402
|
||||
|
||||
|
||||
if platform.system() == "Windows":
|
||||
import pathlib
|
||||
|
||||
pathlib.PosixPath = pathlib.WindowsPath
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(user)
|
||||
app.register_blueprint(answer)
|
||||
app.register_blueprint(internal)
|
||||
app.register_blueprint(connector)
|
||||
app.config.update(
|
||||
UPLOAD_FOLDER="inputs",
|
||||
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
|
||||
@@ -52,7 +53,6 @@ if settings.AUTH_TYPE in ("simple_jwt", "session_jwt") and not settings.JWT_SECR
|
||||
settings.JWT_SECRET_KEY = new_key
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to setup JWT_SECRET_KEY: {e}")
|
||||
|
||||
SIMPLE_JWT_TOKEN = None
|
||||
if settings.AUTH_TYPE == "simple_jwt":
|
||||
payload = {"sub": "local"}
|
||||
@@ -92,7 +92,6 @@ def generate_token():
|
||||
def authenticate_request():
|
||||
if request.method == "OPTIONS":
|
||||
return "", 200
|
||||
|
||||
decoded_token = handle_auth(request)
|
||||
if not decoded_token:
|
||||
request.decoded_token = None
|
||||
|
||||
@@ -10,31 +10,41 @@ current_dir = os.path.dirname(
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
AUTH_TYPE: Optional[str] = None
|
||||
LLM_NAME: str = "docsgpt"
|
||||
MODEL_NAME: Optional[str] = (
|
||||
None # if LLM_NAME is openai, MODEL_NAME can be gpt-4 or gpt-3.5-turbo
|
||||
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
|
||||
LLM_PROVIDER: str = "docsgpt"
|
||||
LLM_NAME: Optional[str] = (
|
||||
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
|
||||
)
|
||||
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
|
||||
MONGO_DB_NAME: str = "docsgpt"
|
||||
MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
|
||||
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
|
||||
DEFAULT_MAX_HISTORY: int = 150
|
||||
MODEL_TOKEN_LIMITS: dict = {
|
||||
LLM_TOKEN_LIMITS: dict = {
|
||||
"gpt-4o-mini": 128000,
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"claude-2": 1e5,
|
||||
"gemini-2.0-flash-exp": 1e6,
|
||||
"gemini-2.5-flash": 1e6,
|
||||
}
|
||||
UPLOAD_FOLDER: str = "inputs"
|
||||
PARSE_PDF_AS_IMAGE: bool = False
|
||||
PARSE_IMAGE_REMOTE: bool = False
|
||||
VECTOR_STORE: str = (
|
||||
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
|
||||
)
|
||||
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
|
||||
RETRIEVERS_ENABLED: list = ["classic_rag"]
|
||||
AGENT_NAME: str = "classic"
|
||||
FALLBACK_LLM_PROVIDER: Optional[str] = None # provider for fallback llm
|
||||
FALLBACK_LLM_NAME: Optional[str] = None # model name for fallback llm
|
||||
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
|
||||
|
||||
# Google Drive integration
|
||||
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = None# Replace with your actual Google OAuth client secret
|
||||
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = "http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
|
||||
|
||||
|
||||
# LLM Cache
|
||||
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
|
||||
@@ -86,6 +96,8 @@ class Settings(BaseSettings):
|
||||
QDRANT_PATH: Optional[str] = None
|
||||
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
||||
|
||||
# PGVector vectorstore config
|
||||
PGVECTOR_CONNECTION_STRING: Optional[str] = None
|
||||
# Milvus vectorstore config
|
||||
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
|
||||
MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default
|
||||
@@ -96,14 +108,17 @@ class Settings(BaseSettings):
|
||||
LANCEDB_TABLE_NAME: Optional[str] = (
|
||||
"docsgpts" # Name of the table to use for storing vectors
|
||||
)
|
||||
BRAVE_SEARCH_API_KEY: Optional[str] = None
|
||||
|
||||
FLASK_DEBUG_MODE: bool = False
|
||||
STORAGE_TYPE: str = "local" # local or s3
|
||||
|
||||
STORAGE_TYPE: str = "local" # local or s3
|
||||
URL_STRATEGY: str = "backend" # backend or s3
|
||||
|
||||
JWT_SECRET_KEY: str = ""
|
||||
|
||||
# Encryption settings
|
||||
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
|
||||
|
||||
ELEVENLABS_API_KEY: Optional[str] = None
|
||||
|
||||
path = Path(__file__).parent.parent.absolute()
|
||||
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from flask_restx import Api
|
||||
|
||||
api = Api(
|
||||
version="1.0",
|
||||
title="DocsGPT API",
|
||||
description="API for DocsGPT",
|
||||
)
|
||||
@@ -1,53 +1,117 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from application.cache import gen_cache, stream_cache
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.usage import gen_token_usage, stream_token_usage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
def __init__(self, decoded_token=None):
|
||||
def __init__(
|
||||
self,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.decoded_token = decoded_token
|
||||
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
self.fallback_provider = settings.FALLBACK_LLM_PROVIDER
|
||||
self.fallback_model_name = settings.FALLBACK_LLM_NAME
|
||||
self.fallback_llm_api_key = settings.FALLBACK_LLM_API_KEY
|
||||
self._fallback_llm = None
|
||||
|
||||
def _apply_decorator(self, method, decorators, *args, **kwargs):
|
||||
for decorator in decorators:
|
||||
method = decorator(method)
|
||||
return method(self, *args, **kwargs)
|
||||
@property
|
||||
def fallback_llm(self):
|
||||
"""Lazy-loaded fallback LLM instance."""
|
||||
if (
|
||||
self._fallback_llm is None
|
||||
and self.fallback_provider
|
||||
and self.fallback_model_name
|
||||
):
|
||||
try:
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
self._fallback_llm = LLMCreator.create_llm(
|
||||
self.fallback_provider,
|
||||
self.fallback_llm_api_key,
|
||||
None,
|
||||
self.decoded_token,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to initialize fallback LLM: {str(e)}", exc_info=True
|
||||
)
|
||||
return self._fallback_llm
|
||||
|
||||
def _execute_with_fallback(
|
||||
self, method_name: str, decorators: list, *args, **kwargs
|
||||
):
|
||||
"""
|
||||
Unified method execution with fallback support.
|
||||
|
||||
Args:
|
||||
method_name: Name of the raw method ('_raw_gen' or '_raw_gen_stream')
|
||||
decorators: List of decorators to apply
|
||||
*args: Positional arguments
|
||||
**kwargs: Keyword arguments
|
||||
"""
|
||||
|
||||
def decorated_method():
|
||||
method = getattr(self, method_name)
|
||||
for decorator in decorators:
|
||||
method = decorator(method)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
try:
|
||||
return decorated_method()
|
||||
except Exception as e:
|
||||
if not self.fallback_llm:
|
||||
logger.error(f"Primary LLM failed and no fallback available: {str(e)}")
|
||||
raise
|
||||
logger.warning(
|
||||
f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}"
|
||||
)
|
||||
|
||||
fallback_method = getattr(
|
||||
self.fallback_llm, method_name.replace("_raw_", "")
|
||||
)
|
||||
return fallback_method(*args, **kwargs)
|
||||
|
||||
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
|
||||
decorators = [gen_token_usage, gen_cache]
|
||||
return self._execute_with_fallback(
|
||||
"_raw_gen",
|
||||
decorators,
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
|
||||
decorators = [stream_cache, stream_token_usage]
|
||||
return self._execute_with_fallback(
|
||||
"_raw_gen_stream",
|
||||
decorators,
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _raw_gen(self, model, messages, stream, tools, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
|
||||
decorators = [gen_token_usage, gen_cache]
|
||||
return self._apply_decorator(
|
||||
self._raw_gen,
|
||||
decorators=decorators,
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
|
||||
decorators = [stream_cache, stream_token_usage]
|
||||
return self._apply_decorator(
|
||||
self._raw_gen_stream,
|
||||
decorators=decorators,
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def supports_tools(self):
|
||||
return hasattr(self, "_supports_tools") and callable(
|
||||
getattr(self, "_supports_tools")
|
||||
@@ -55,12 +119,26 @@ class BaseLLM(ABC):
|
||||
|
||||
def _supports_tools(self):
|
||||
raise NotImplementedError("Subclass must implement _supports_tools method")
|
||||
|
||||
|
||||
def supports_structured_output(self):
|
||||
"""Check if the LLM supports structured output/JSON schema enforcement"""
|
||||
return hasattr(self, "_supports_structured_output") and callable(
|
||||
getattr(self, "_supports_structured_output")
|
||||
)
|
||||
|
||||
def _supports_structured_output(self):
|
||||
return False
|
||||
|
||||
def prepare_structured_output_format(self, json_schema):
|
||||
"""Prepare structured output format specific to the LLM provider"""
|
||||
_ = json_schema
|
||||
return None
|
||||
|
||||
def get_supported_attachment_types(self):
|
||||
"""
|
||||
Return a list of MIME types supported by this LLM for file uploads.
|
||||
|
||||
|
||||
Returns:
|
||||
list: List of supported MIME types
|
||||
"""
|
||||
return [] # Default: no attachments supported
|
||||
return []
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
import logging
|
||||
import json
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
from application.llm.base import BaseLLM
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class GoogleLLM(BaseLLM):
|
||||
@@ -24,12 +26,12 @@ class GoogleLLM(BaseLLM):
|
||||
list: List of supported MIME types
|
||||
"""
|
||||
return [
|
||||
'application/pdf',
|
||||
'image/png',
|
||||
'image/jpeg',
|
||||
'image/jpg',
|
||||
'image/webp',
|
||||
'image/gif'
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
def prepare_messages_with_attachments(self, messages, attachments=None):
|
||||
@@ -70,26 +72,30 @@ class GoogleLLM(BaseLLM):
|
||||
|
||||
files = []
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.get('mime_type')
|
||||
mime_type = attachment.get("mime_type")
|
||||
|
||||
if mime_type in self.get_supported_attachment_types():
|
||||
try:
|
||||
file_uri = self._upload_file_to_google(attachment)
|
||||
logging.info(f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}")
|
||||
logging.info(
|
||||
f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}"
|
||||
)
|
||||
files.append({"file_uri": file_uri, "mime_type": mime_type})
|
||||
except Exception as e:
|
||||
logging.error(f"GoogleLLM: Error uploading file: {e}", exc_info=True)
|
||||
if 'content' in attachment:
|
||||
prepared_messages[user_message_index]["content"].append({
|
||||
"type": "text",
|
||||
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]"
|
||||
})
|
||||
logging.error(
|
||||
f"GoogleLLM: Error uploading file: {e}", exc_info=True
|
||||
)
|
||||
if "content" in attachment:
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
|
||||
}
|
||||
)
|
||||
|
||||
if files:
|
||||
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
|
||||
prepared_messages[user_message_index]["content"].append({
|
||||
"files": files
|
||||
})
|
||||
prepared_messages[user_message_index]["content"].append({"files": files})
|
||||
|
||||
return prepared_messages
|
||||
|
||||
@@ -103,10 +109,10 @@ class GoogleLLM(BaseLLM):
|
||||
Returns:
|
||||
str: Google AI file URI for the uploaded file.
|
||||
"""
|
||||
if 'google_file_uri' in attachment:
|
||||
return attachment['google_file_uri']
|
||||
if "google_file_uri" in attachment:
|
||||
return attachment["google_file_uri"]
|
||||
|
||||
file_path = attachment.get('path')
|
||||
file_path = attachment.get("path")
|
||||
if not file_path:
|
||||
raise ValueError("No file path provided in attachment")
|
||||
|
||||
@@ -116,17 +122,19 @@ class GoogleLLM(BaseLLM):
|
||||
try:
|
||||
file_uri = self.storage.process_file(
|
||||
file_path,
|
||||
lambda local_path, **kwargs: self.client.files.upload(file=local_path).uri
|
||||
lambda local_path, **kwargs: self.client.files.upload(
|
||||
file=local_path
|
||||
).uri,
|
||||
)
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
attachments_collection = db["attachments"]
|
||||
if '_id' in attachment:
|
||||
if "_id" in attachment:
|
||||
attachments_collection.update_one(
|
||||
{"_id": attachment['_id']},
|
||||
{"$set": {"google_file_uri": file_uri}}
|
||||
{"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
|
||||
)
|
||||
|
||||
return file_uri
|
||||
@@ -135,6 +143,7 @@ class GoogleLLM(BaseLLM):
|
||||
raise
|
||||
|
||||
def _clean_messages_google(self, messages):
|
||||
"""Convert OpenAI format messages to Google AI format."""
|
||||
cleaned_messages = []
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
@@ -142,6 +151,8 @@ class GoogleLLM(BaseLLM):
|
||||
|
||||
if role == "assistant":
|
||||
role = "model"
|
||||
elif role == "tool":
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
if role and content is not None:
|
||||
@@ -166,13 +177,13 @@ class GoogleLLM(BaseLLM):
|
||||
)
|
||||
)
|
||||
elif "files" in item:
|
||||
for file_data in item["files"]:
|
||||
parts.append(
|
||||
types.Part.from_uri(
|
||||
file_uri=file_data["file_uri"],
|
||||
mime_type=file_data["mime_type"]
|
||||
)
|
||||
for file_data in item["files"]:
|
||||
parts.append(
|
||||
types.Part.from_uri(
|
||||
file_uri=file_data["file_uri"],
|
||||
mime_type=file_data["mime_type"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected content dictionary format:{item}"
|
||||
@@ -180,11 +191,63 @@ class GoogleLLM(BaseLLM):
|
||||
else:
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
|
||||
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||
if parts:
|
||||
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||
|
||||
return cleaned_messages
|
||||
|
||||
def _clean_schema(self, schema_obj):
|
||||
"""
|
||||
Recursively remove unsupported fields from schema objects
|
||||
and validate required properties.
|
||||
"""
|
||||
if not isinstance(schema_obj, dict):
|
||||
return schema_obj
|
||||
allowed_fields = {
|
||||
"type",
|
||||
"description",
|
||||
"items",
|
||||
"properties",
|
||||
"required",
|
||||
"enum",
|
||||
"pattern",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"nullable",
|
||||
"default",
|
||||
}
|
||||
|
||||
cleaned = {}
|
||||
for key, value in schema_obj.items():
|
||||
if key not in allowed_fields:
|
||||
continue
|
||||
elif key == "type" and isinstance(value, str):
|
||||
cleaned[key] = value.upper()
|
||||
elif isinstance(value, dict):
|
||||
cleaned[key] = self._clean_schema(value)
|
||||
elif isinstance(value, list):
|
||||
cleaned[key] = [self._clean_schema(item) for item in value]
|
||||
else:
|
||||
cleaned[key] = value
|
||||
|
||||
# Validate that required properties actually exist in properties
|
||||
if "required" in cleaned and "properties" in cleaned:
|
||||
valid_required = []
|
||||
properties_keys = set(cleaned["properties"].keys())
|
||||
for required_prop in cleaned["required"]:
|
||||
if required_prop in properties_keys:
|
||||
valid_required.append(required_prop)
|
||||
if valid_required:
|
||||
cleaned["required"] = valid_required
|
||||
else:
|
||||
cleaned.pop("required", None)
|
||||
elif "required" in cleaned and "properties" not in cleaned:
|
||||
cleaned.pop("required", None)
|
||||
|
||||
return cleaned
|
||||
|
||||
def _clean_tools_format(self, tools_list):
|
||||
"""Convert OpenAI format tools to Google AI format."""
|
||||
genai_tools = []
|
||||
for tool_data in tools_list:
|
||||
if tool_data["type"] == "function":
|
||||
@@ -193,18 +256,16 @@ class GoogleLLM(BaseLLM):
|
||||
properties = parameters.get("properties", {})
|
||||
|
||||
if properties:
|
||||
cleaned_properties = {}
|
||||
for k, v in properties.items():
|
||||
cleaned_properties[k] = self._clean_schema(v)
|
||||
|
||||
genai_function = dict(
|
||||
name=function["name"],
|
||||
description=function["description"],
|
||||
parameters={
|
||||
"type": "OBJECT",
|
||||
"properties": {
|
||||
k: {
|
||||
**v,
|
||||
"type": v["type"].upper() if v["type"] else None,
|
||||
}
|
||||
for k, v in properties.items()
|
||||
},
|
||||
"properties": cleaned_properties,
|
||||
"required": (
|
||||
parameters["required"]
|
||||
if "required" in parameters
|
||||
@@ -231,8 +292,10 @@ class GoogleLLM(BaseLLM):
|
||||
stream=False,
|
||||
tools=None,
|
||||
formatting="openai",
|
||||
response_schema=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Generate content using Google AI API without streaming."""
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
if formatting == "openai":
|
||||
messages = self._clean_messages_google(messages)
|
||||
@@ -244,16 +307,21 @@ class GoogleLLM(BaseLLM):
|
||||
if tools:
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
response = client.models.generate_content(
|
||||
model=model,
|
||||
contents=messages,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Add response schema for structured output if provided
|
||||
if response_schema:
|
||||
config.response_schema = response_schema
|
||||
config.response_mime_type = "application/json"
|
||||
|
||||
response = client.models.generate_content(
|
||||
model=model,
|
||||
contents=messages,
|
||||
config=config,
|
||||
)
|
||||
|
||||
if tools:
|
||||
return response
|
||||
else:
|
||||
response = client.models.generate_content(
|
||||
model=model, contents=messages, config=config
|
||||
)
|
||||
return response.text
|
||||
|
||||
def _raw_gen_stream(
|
||||
@@ -264,8 +332,10 @@ class GoogleLLM(BaseLLM):
|
||||
stream=True,
|
||||
tools=None,
|
||||
formatting="openai",
|
||||
response_schema=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Generate content using Google AI API with streaming."""
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
if formatting == "openai":
|
||||
messages = self._clean_messages_google(messages)
|
||||
@@ -278,17 +348,24 @@ class GoogleLLM(BaseLLM):
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
|
||||
# Add response schema for structured output if provided
|
||||
if response_schema:
|
||||
config.response_schema = response_schema
|
||||
config.response_mime_type = "application/json"
|
||||
|
||||
# Check if we have both tools and file attachments
|
||||
has_attachments = False
|
||||
for message in messages:
|
||||
for part in message.parts:
|
||||
if hasattr(part, 'file_data') and part.file_data is not None:
|
||||
if hasattr(part, "file_data") and part.file_data is not None:
|
||||
has_attachments = True
|
||||
break
|
||||
if has_attachments:
|
||||
break
|
||||
|
||||
logging.info(f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}")
|
||||
logging.info(
|
||||
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
|
||||
)
|
||||
|
||||
response = client.models.generate_content_stream(
|
||||
model=model,
|
||||
@@ -296,7 +373,6 @@ class GoogleLLM(BaseLLM):
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
for chunk in response:
|
||||
if hasattr(chunk, "candidates") and chunk.candidates:
|
||||
for candidate in chunk.candidates:
|
||||
@@ -310,4 +386,79 @@ class GoogleLLM(BaseLLM):
|
||||
yield chunk.text
|
||||
|
||||
def _supports_tools(self):
|
||||
"""Return whether this LLM supports function calling."""
|
||||
return True
|
||||
|
||||
def _supports_structured_output(self):
|
||||
"""Return whether this LLM supports structured JSON output."""
|
||||
return True
|
||||
|
||||
def prepare_structured_output_format(self, json_schema):
|
||||
"""Convert JSON schema to Google AI structured output format."""
|
||||
if not json_schema:
|
||||
return None
|
||||
|
||||
type_map = {
|
||||
"object": "OBJECT",
|
||||
"array": "ARRAY",
|
||||
"string": "STRING",
|
||||
"integer": "INTEGER",
|
||||
"number": "NUMBER",
|
||||
"boolean": "BOOLEAN",
|
||||
}
|
||||
|
||||
def convert(schema):
|
||||
if not isinstance(schema, dict):
|
||||
return schema
|
||||
|
||||
result = {}
|
||||
schema_type = schema.get("type")
|
||||
if schema_type:
|
||||
result["type"] = type_map.get(schema_type.lower(), schema_type.upper())
|
||||
|
||||
for key in [
|
||||
"description",
|
||||
"nullable",
|
||||
"enum",
|
||||
"minItems",
|
||||
"maxItems",
|
||||
"required",
|
||||
"propertyOrdering",
|
||||
]:
|
||||
if key in schema:
|
||||
result[key] = schema[key]
|
||||
|
||||
if "format" in schema:
|
||||
format_value = schema["format"]
|
||||
if schema_type == "string":
|
||||
if format_value == "date":
|
||||
result["format"] = "date-time"
|
||||
elif format_value in ["enum", "date-time"]:
|
||||
result["format"] = format_value
|
||||
else:
|
||||
result["format"] = format_value
|
||||
|
||||
if "properties" in schema:
|
||||
result["properties"] = {
|
||||
k: convert(v) for k, v in schema["properties"].items()
|
||||
}
|
||||
if "propertyOrdering" not in result and result.get("type") == "OBJECT":
|
||||
result["propertyOrdering"] = list(result["properties"].keys())
|
||||
|
||||
if "items" in schema:
|
||||
result["items"] = convert(schema["items"])
|
||||
|
||||
for field in ["anyOf", "oneOf", "allOf"]:
|
||||
if field in schema:
|
||||
result[field] = [convert(s) for s in schema[field]]
|
||||
|
||||
return result
|
||||
|
||||
try:
|
||||
return convert(json_schema)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error preparing structured output format for Google: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
0
application/llm/handlers/__init__.py
Normal file
0
application/llm/handlers/__init__.py
Normal file
351
application/llm/handlers/base.py
Normal file
351
application/llm/handlers/base.py
Normal file
@@ -0,0 +1,351 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
|
||||
from application.logging import build_stack_data
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""Represents a tool/function call from the LLM."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
arguments: Union[str, Dict]
|
||||
index: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "ToolCall":
|
||||
"""Create ToolCall from dictionary."""
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
name=data.get("name", ""),
|
||||
arguments=data.get("arguments", {}),
|
||||
index=data.get("index"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Represents a response from the LLM."""
|
||||
|
||||
content: str
|
||||
tool_calls: List[ToolCall]
|
||||
finish_reason: str
|
||||
raw_response: Any
|
||||
|
||||
@property
|
||||
def requires_tool_call(self) -> bool:
|
||||
"""Check if the response requires tool calls."""
|
||||
return bool(self.tool_calls) and self.finish_reason == "tool_calls"
|
||||
|
||||
|
||||
class LLMHandler(ABC):
|
||||
"""Abstract base class for LLM handlers."""
|
||||
|
||||
def __init__(self):
|
||||
self.llm_calls = []
|
||||
self.tool_calls = []
|
||||
|
||||
@abstractmethod
|
||||
def parse_response(self, response: Any) -> LLMResponse:
|
||||
"""Parse raw LLM response into standardized format."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||
"""Create a tool result message for the conversation history."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _iterate_stream(self, response: Any) -> Generator:
|
||||
"""Iterate through streaming response chunks."""
|
||||
pass
|
||||
|
||||
def process_message_flow(
|
||||
self,
|
||||
agent,
|
||||
initial_response,
|
||||
tools_dict: Dict,
|
||||
messages: List[Dict],
|
||||
attachments: Optional[List] = None,
|
||||
stream: bool = False,
|
||||
) -> Union[str, Generator]:
|
||||
"""
|
||||
Main orchestration method for processing LLM message flow.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
initial_response: Initial LLM response
|
||||
tools_dict: Dictionary of available tools
|
||||
messages: Conversation history
|
||||
attachments: Optional attachments
|
||||
stream: Whether to use streaming
|
||||
|
||||
Returns:
|
||||
Final response or generator for streaming
|
||||
"""
|
||||
messages = self.prepare_messages(agent, messages, attachments)
|
||||
|
||||
if stream:
|
||||
return self.handle_streaming(agent, initial_response, tools_dict, messages)
|
||||
else:
|
||||
return self.handle_non_streaming(
|
||||
agent, initial_response, tools_dict, messages
|
||||
)
|
||||
|
||||
def prepare_messages(
|
||||
self, agent, messages: List[Dict], attachments: Optional[List] = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Prepare messages with attachments and provider-specific formatting.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
messages: Original messages
|
||||
attachments: List of attachments
|
||||
|
||||
Returns:
|
||||
Prepared messages list
|
||||
"""
|
||||
if not attachments:
|
||||
return messages
|
||||
logger.info(f"Preparing messages with {len(attachments)} attachments")
|
||||
supported_types = agent.llm.get_supported_attachment_types()
|
||||
|
||||
supported_attachments = [
|
||||
a for a in attachments if a.get("mime_type") in supported_types
|
||||
]
|
||||
unsupported_attachments = [
|
||||
a for a in attachments if a.get("mime_type") not in supported_types
|
||||
]
|
||||
|
||||
# Process supported attachments with the LLM's custom method
|
||||
|
||||
if supported_attachments:
|
||||
logger.info(
|
||||
f"Processing {len(supported_attachments)} supported attachments"
|
||||
)
|
||||
messages = agent.llm.prepare_messages_with_attachments(
|
||||
messages, supported_attachments
|
||||
)
|
||||
# Process unsupported attachments with default method
|
||||
|
||||
if unsupported_attachments:
|
||||
logger.info(
|
||||
f"Processing {len(unsupported_attachments)} unsupported attachments"
|
||||
)
|
||||
messages = self._append_unsupported_attachments(
|
||||
messages, unsupported_attachments
|
||||
)
|
||||
return messages
|
||||
|
||||
def _append_unsupported_attachments(
|
||||
self, messages: List[Dict], attachments: List[Dict]
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Default method to append unsupported attachment content to system prompt.
|
||||
|
||||
Args:
|
||||
messages: Current messages
|
||||
attachments: List of unsupported attachments
|
||||
|
||||
Returns:
|
||||
Updated messages list
|
||||
"""
|
||||
prepared_messages = messages.copy()
|
||||
attachment_texts = []
|
||||
|
||||
for attachment in attachments:
|
||||
logger.info(f"Adding attachment {attachment.get('id')} to context")
|
||||
if "content" in attachment:
|
||||
attachment_texts.append(
|
||||
f"Attached file content:\n\n{attachment['content']}"
|
||||
)
|
||||
if attachment_texts:
|
||||
combined_text = "\n\n".join(attachment_texts)
|
||||
|
||||
system_msg = next(
|
||||
(msg for msg in prepared_messages if msg.get("role") == "system"),
|
||||
{"role": "system", "content": ""},
|
||||
)
|
||||
|
||||
if system_msg not in prepared_messages:
|
||||
prepared_messages.insert(0, system_msg)
|
||||
system_msg["content"] += f"\n\n{combined_text}"
|
||||
return prepared_messages
|
||||
|
||||
def handle_tool_calls(
|
||||
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
|
||||
) -> Generator:
|
||||
"""
|
||||
Execute tool calls and update conversation history.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
tool_calls: List of tool calls to execute
|
||||
tools_dict: Available tools dictionary
|
||||
messages: Current conversation history
|
||||
|
||||
Returns:
|
||||
Updated messages list
|
||||
"""
|
||||
updated_messages = messages.copy()
|
||||
|
||||
for call in tool_calls:
|
||||
try:
|
||||
self.tool_calls.append(call)
|
||||
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
|
||||
while True:
|
||||
try:
|
||||
yield next(tool_executor_gen)
|
||||
except StopIteration as e:
|
||||
tool_response, call_id = e.value
|
||||
break
|
||||
updated_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"function_call": {
|
||||
"name": call.name,
|
||||
"args": call.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
updated_messages.append(self.create_tool_message(call, tool_response))
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
|
||||
error_call = ToolCall(
|
||||
id=call.id, name=call.name, arguments=call.arguments
|
||||
)
|
||||
error_response = f"Error executing tool: {str(e)}"
|
||||
error_message = self.create_tool_message(error_call, error_response)
|
||||
updated_messages.append(error_message)
|
||||
|
||||
call_parts = call.name.split("_")
|
||||
if len(call_parts) >= 2:
|
||||
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
|
||||
action_name = "_".join(call_parts[:-1])
|
||||
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
|
||||
full_action_name = f"{action_name}_{tool_id}"
|
||||
else:
|
||||
tool_name = "unknown_tool"
|
||||
action_name = call.name
|
||||
full_action_name = call.name
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": tool_name,
|
||||
"call_id": call.id,
|
||||
"action_name": full_action_name,
|
||||
"arguments": call.arguments,
|
||||
"error": error_response,
|
||||
"status": "error",
|
||||
},
|
||||
}
|
||||
return updated_messages
|
||||
|
||||
def handle_non_streaming(
|
||||
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle non-streaming response flow.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
response: Current LLM response
|
||||
tools_dict: Available tools dictionary
|
||||
messages: Conversation history
|
||||
|
||||
Returns:
|
||||
Final response after processing all tool calls
|
||||
"""
|
||||
parsed = self.parse_response(response)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
while parsed.requires_tool_call:
|
||||
tool_handler_gen = self.handle_tool_calls(
|
||||
agent, parsed.tool_calls, tools_dict, messages
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
yield next(tool_handler_gen)
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
break
|
||||
response = agent.llm.gen(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
parsed = self.parse_response(response)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
return parsed.content
|
||||
|
||||
def handle_streaming(
|
||||
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle streaming response flow.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
response: Current LLM response
|
||||
tools_dict: Available tools dictionary
|
||||
messages: Conversation history
|
||||
|
||||
Yields:
|
||||
Streaming response chunks
|
||||
"""
|
||||
buffer = ""
|
||||
tool_calls = {}
|
||||
|
||||
for chunk in self._iterate_stream(response):
|
||||
if isinstance(chunk, str):
|
||||
yield chunk
|
||||
continue
|
||||
parsed = self.parse_response(chunk)
|
||||
|
||||
if parsed.tool_calls:
|
||||
for call in parsed.tool_calls:
|
||||
if call.index not in tool_calls:
|
||||
tool_calls[call.index] = call
|
||||
else:
|
||||
existing = tool_calls[call.index]
|
||||
if call.id:
|
||||
existing.id = call.id
|
||||
if call.name:
|
||||
existing.name = call.name
|
||||
if call.arguments:
|
||||
existing.arguments += call.arguments
|
||||
if parsed.finish_reason == "tool_calls":
|
||||
tool_handler_gen = self.handle_tool_calls(
|
||||
agent, list(tool_calls.values()), tools_dict, messages
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
yield next(tool_handler_gen)
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
break
|
||||
tool_calls = {}
|
||||
|
||||
response = agent.llm.gen_stream(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
yield from self.handle_streaming(agent, response, tools_dict, messages)
|
||||
return
|
||||
if parsed.content:
|
||||
buffer += parsed.content
|
||||
yield buffer
|
||||
buffer = ""
|
||||
if parsed.finish_reason == "stop":
|
||||
return
|
||||
78
application/llm/handlers/google.py
Normal file
78
application/llm/handlers/google.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import uuid
|
||||
from typing import Any, Dict, Generator
|
||||
|
||||
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
|
||||
|
||||
|
||||
class GoogleLLMHandler(LLMHandler):
|
||||
"""Handler for Google's GenAI API."""
|
||||
|
||||
def parse_response(self, response: Any) -> LLMResponse:
|
||||
"""Parse Google response into standardized format."""
|
||||
|
||||
if isinstance(response, str):
|
||||
return LLMResponse(
|
||||
content=response,
|
||||
tool_calls=[],
|
||||
finish_reason="stop",
|
||||
raw_response=response,
|
||||
)
|
||||
if hasattr(response, "candidates"):
|
||||
parts = response.candidates[0].content.parts if response.candidates else []
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
name=part.function_call.name,
|
||||
arguments=part.function_call.args,
|
||||
)
|
||||
for part in parts
|
||||
if hasattr(part, "function_call") and part.function_call is not None
|
||||
]
|
||||
|
||||
content = " ".join(
|
||||
part.text
|
||||
for part in parts
|
||||
if hasattr(part, "text") and part.text is not None
|
||||
)
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason="tool_calls" if tool_calls else "stop",
|
||||
raw_response=response,
|
||||
)
|
||||
else:
|
||||
tool_calls = []
|
||||
if hasattr(response, "function_call"):
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
name=response.function_call.name,
|
||||
arguments=response.function_call.args,
|
||||
)
|
||||
)
|
||||
return LLMResponse(
|
||||
content=response.text if hasattr(response, "text") else "",
|
||||
tool_calls=tool_calls,
|
||||
finish_reason="tool_calls" if tool_calls else "stop",
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||
"""Create Google-style tool message."""
|
||||
|
||||
return {
|
||||
"role": "model",
|
||||
"content": [
|
||||
{
|
||||
"function_response": {
|
||||
"name": tool_call.name,
|
||||
"response": {"result": result},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
def _iterate_stream(self, response: Any) -> Generator:
|
||||
"""Iterate through Google streaming response."""
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
18
application/llm/handlers/handler_creator.py
Normal file
18
application/llm/handlers/handler_creator.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from application.llm.handlers.base import LLMHandler
|
||||
from application.llm.handlers.google import GoogleLLMHandler
|
||||
from application.llm.handlers.openai import OpenAILLMHandler
|
||||
|
||||
|
||||
class LLMHandlerCreator:
|
||||
handlers = {
|
||||
"openai": OpenAILLMHandler,
|
||||
"google": GoogleLLMHandler,
|
||||
"default": OpenAILLMHandler,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_handler(cls, llm_type: str, *args, **kwargs) -> LLMHandler:
|
||||
handler_class = cls.handlers.get(llm_type.lower())
|
||||
if not handler_class:
|
||||
handler_class = OpenAILLMHandler
|
||||
return handler_class(*args, **kwargs)
|
||||
57
application/llm/handlers/openai.py
Normal file
57
application/llm/handlers/openai.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Any, Dict, Generator
|
||||
|
||||
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
|
||||
|
||||
|
||||
class OpenAILLMHandler(LLMHandler):
|
||||
"""Handler for OpenAI API."""
|
||||
|
||||
def parse_response(self, response: Any) -> LLMResponse:
|
||||
"""Parse OpenAI response into standardized format."""
|
||||
if isinstance(response, str):
|
||||
return LLMResponse(
|
||||
content=response,
|
||||
tool_calls=[],
|
||||
finish_reason="stop",
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
message = getattr(response, "message", None) or getattr(response, "delta", None)
|
||||
|
||||
tool_calls = []
|
||||
if hasattr(message, "tool_calls"):
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=getattr(tc, "id", ""),
|
||||
name=getattr(tc.function, "name", ""),
|
||||
arguments=getattr(tc.function, "arguments", ""),
|
||||
index=getattr(tc, "index", None),
|
||||
)
|
||||
for tc in message.tool_calls or []
|
||||
]
|
||||
return LLMResponse(
|
||||
content=getattr(message, "content", ""),
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=getattr(response, "finish_reason", ""),
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||
"""Create OpenAI-style tool message."""
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": [
|
||||
{
|
||||
"function_response": {
|
||||
"name": tool_call.name,
|
||||
"response": {"result": result},
|
||||
"call_id": tool_call.id,
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
def _iterate_stream(self, response: Any) -> Generator:
|
||||
"""Iterate through OpenAI streaming response."""
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
@@ -2,6 +2,7 @@ from application.llm.base import BaseLLM
|
||||
from application.core.settings import settings
|
||||
import threading
|
||||
|
||||
|
||||
class LlamaSingleton:
|
||||
_instances = {}
|
||||
_lock = threading.Lock() # Add a lock for thread synchronization
|
||||
@@ -29,7 +30,7 @@ class LlamaCpp(BaseLLM):
|
||||
self,
|
||||
api_key=None,
|
||||
user_api_key=None,
|
||||
llm_name=settings.MODEL_PATH,
|
||||
llm_name=settings.LLM_PATH,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -42,14 +43,18 @@ class LlamaCpp(BaseLLM):
|
||||
context = messages[0]["content"]
|
||||
user_question = messages[-1]["content"]
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False)
|
||||
result = LlamaSingleton.query_model(
|
||||
self.llama, prompt, max_tokens=150, echo=False
|
||||
)
|
||||
return result["choices"][0]["text"].split("### Answer \n")[-1]
|
||||
|
||||
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
|
||||
context = messages[0]["content"]
|
||||
user_question = messages[-1]["content"]
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
result = LlamaSingleton.query_model(self.llama, prompt, max_tokens=150, echo=False, stream=stream)
|
||||
result = LlamaSingleton.query_model(
|
||||
self.llama, prompt, max_tokens=150, echo=False, stream=stream
|
||||
)
|
||||
for item in result:
|
||||
for choice in item["choices"]:
|
||||
yield choice["text"]
|
||||
yield choice["text"]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
|
||||
from application.core.settings import settings
|
||||
@@ -13,7 +13,10 @@ class OpenAILLM(BaseLLM):
|
||||
from openai import OpenAI
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
if isinstance(settings.OPENAI_BASE_URL, str) and settings.OPENAI_BASE_URL.strip():
|
||||
if (
|
||||
isinstance(settings.OPENAI_BASE_URL, str)
|
||||
and settings.OPENAI_BASE_URL.strip()
|
||||
):
|
||||
self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
|
||||
else:
|
||||
DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
|
||||
@@ -73,14 +76,30 @@ class OpenAILLM(BaseLLM):
|
||||
elif isinstance(item, dict):
|
||||
content_parts = []
|
||||
if "text" in item:
|
||||
content_parts.append({"type": "text", "text": item["text"]})
|
||||
elif "type" in item and item["type"] == "text" and "text" in item:
|
||||
content_parts.append(
|
||||
{"type": "text", "text": item["text"]}
|
||||
)
|
||||
elif (
|
||||
"type" in item
|
||||
and item["type"] == "text"
|
||||
and "text" in item
|
||||
):
|
||||
content_parts.append(item)
|
||||
elif "type" in item and item["type"] == "file" and "file" in item:
|
||||
elif (
|
||||
"type" in item
|
||||
and item["type"] == "file"
|
||||
and "file" in item
|
||||
):
|
||||
content_parts.append(item)
|
||||
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
|
||||
elif (
|
||||
"type" in item
|
||||
and item["type"] == "image_url"
|
||||
and "image_url" in item
|
||||
):
|
||||
content_parts.append(item)
|
||||
cleaned_messages.append({"role": role, "content": content_parts})
|
||||
cleaned_messages.append(
|
||||
{"role": role, "content": content_parts}
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected content dictionary format: {item}"
|
||||
@@ -98,22 +117,29 @@ class OpenAILLM(BaseLLM):
|
||||
stream=False,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
response_format=None,
|
||||
**kwargs,
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
|
||||
request_params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
if tools:
|
||||
request_params["tools"] = tools
|
||||
|
||||
if response_format:
|
||||
request_params["response_format"] = response_format
|
||||
|
||||
response = self.client.chat.completions.create(**request_params)
|
||||
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
return response.choices[0]
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
def _raw_gen_stream(
|
||||
@@ -124,24 +150,32 @@ class OpenAILLM(BaseLLM):
|
||||
stream=True,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
response_format=None,
|
||||
**kwargs,
|
||||
):
|
||||
messages = self._clean_messages_openai(messages)
|
||||
|
||||
request_params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
request_params["tools"] = tools
|
||||
|
||||
if response_format:
|
||||
request_params["response_format"] = response_format
|
||||
|
||||
response = self.client.chat.completions.create(**request_params)
|
||||
|
||||
for line in response:
|
||||
if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
|
||||
if (
|
||||
len(line.choices) > 0
|
||||
and line.choices[0].delta.content is not None
|
||||
and len(line.choices[0].delta.content) > 0
|
||||
):
|
||||
yield line.choices[0].delta.content
|
||||
elif len(line.choices) > 0:
|
||||
yield line.choices[0]
|
||||
@@ -149,6 +183,66 @@ class OpenAILLM(BaseLLM):
|
||||
def _supports_tools(self):
|
||||
return True
|
||||
|
||||
def _supports_structured_output(self):
|
||||
return True
|
||||
|
||||
def prepare_structured_output_format(self, json_schema):
|
||||
if not json_schema:
|
||||
return None
|
||||
|
||||
try:
|
||||
|
||||
def add_additional_properties_false(schema_obj):
|
||||
if isinstance(schema_obj, dict):
|
||||
schema_copy = schema_obj.copy()
|
||||
|
||||
if schema_copy.get("type") == "object":
|
||||
schema_copy["additionalProperties"] = False
|
||||
# Ensure 'required' includes all properties for OpenAI strict mode
|
||||
if "properties" in schema_copy:
|
||||
schema_copy["required"] = list(
|
||||
schema_copy["properties"].keys()
|
||||
)
|
||||
|
||||
for key, value in schema_copy.items():
|
||||
if key == "properties" and isinstance(value, dict):
|
||||
schema_copy[key] = {
|
||||
prop_name: add_additional_properties_false(prop_schema)
|
||||
for prop_name, prop_schema in value.items()
|
||||
}
|
||||
elif key == "items" and isinstance(value, dict):
|
||||
schema_copy[key] = add_additional_properties_false(value)
|
||||
elif key in ["anyOf", "oneOf", "allOf"] and isinstance(
|
||||
value, list
|
||||
):
|
||||
schema_copy[key] = [
|
||||
add_additional_properties_false(sub_schema)
|
||||
for sub_schema in value
|
||||
]
|
||||
|
||||
return schema_copy
|
||||
return schema_obj
|
||||
|
||||
processed_schema = add_additional_properties_false(json_schema)
|
||||
|
||||
result = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": processed_schema.get("name", "response"),
|
||||
"description": processed_schema.get(
|
||||
"description", "Structured response"
|
||||
),
|
||||
"schema": processed_schema,
|
||||
"strict": True,
|
||||
},
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error preparing structured output format: {e}")
|
||||
return None
|
||||
|
||||
def get_supported_attachment_types(self):
|
||||
"""
|
||||
Return a list of MIME types supported by OpenAI for file uploads.
|
||||
@@ -157,12 +251,12 @@ class OpenAILLM(BaseLLM):
|
||||
list: List of supported MIME types
|
||||
"""
|
||||
return [
|
||||
'application/pdf',
|
||||
'image/png',
|
||||
'image/jpeg',
|
||||
'image/jpg',
|
||||
'image/webp',
|
||||
'image/gif'
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
def prepare_messages_with_attachments(self, messages, attachments=None):
|
||||
@@ -202,39 +296,46 @@ class OpenAILLM(BaseLLM):
|
||||
prepared_messages[user_message_index]["content"] = []
|
||||
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.get('mime_type')
|
||||
mime_type = attachment.get("mime_type")
|
||||
|
||||
if mime_type and mime_type.startswith('image/'):
|
||||
if mime_type and mime_type.startswith("image/"):
|
||||
try:
|
||||
base64_image = self._get_base64_image(attachment)
|
||||
prepared_messages[user_message_index]["content"].append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{mime_type};base64,{base64_image}"
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{mime_type};base64,{base64_image}"
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing image attachment: {e}", exc_info=True)
|
||||
if 'content' in attachment:
|
||||
prepared_messages[user_message_index]["content"].append({
|
||||
"type": "text",
|
||||
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]"
|
||||
})
|
||||
logging.error(
|
||||
f"Error processing image attachment: {e}", exc_info=True
|
||||
)
|
||||
if "content" in attachment:
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]",
|
||||
}
|
||||
)
|
||||
# Handle PDFs using the file API
|
||||
elif mime_type == 'application/pdf':
|
||||
elif mime_type == "application/pdf":
|
||||
try:
|
||||
file_id = self._upload_file_to_openai(attachment)
|
||||
prepared_messages[user_message_index]["content"].append({
|
||||
"type": "file",
|
||||
"file": {"file_id": file_id}
|
||||
})
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{"type": "file", "file": {"file_id": file_id}}
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error uploading PDF to OpenAI: {e}", exc_info=True)
|
||||
if 'content' in attachment:
|
||||
prepared_messages[user_message_index]["content"].append({
|
||||
"type": "text",
|
||||
"text": f"File content:\n\n{attachment['content']}"
|
||||
})
|
||||
if "content" in attachment:
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"File content:\n\n{attachment['content']}",
|
||||
}
|
||||
)
|
||||
|
||||
return prepared_messages
|
||||
|
||||
@@ -248,13 +349,13 @@ class OpenAILLM(BaseLLM):
|
||||
Returns:
|
||||
str: Base64-encoded image data.
|
||||
"""
|
||||
file_path = attachment.get('path')
|
||||
file_path = attachment.get("path")
|
||||
if not file_path:
|
||||
raise ValueError("No file path provided in attachment")
|
||||
|
||||
try:
|
||||
with self.storage.get_file(file_path) as image_file:
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
@@ -273,10 +374,10 @@ class OpenAILLM(BaseLLM):
|
||||
"""
|
||||
import logging
|
||||
|
||||
if 'openai_file_id' in attachment:
|
||||
return attachment['openai_file_id']
|
||||
if "openai_file_id" in attachment:
|
||||
return attachment["openai_file_id"]
|
||||
|
||||
file_path = attachment.get('path')
|
||||
file_path = attachment.get("path")
|
||||
|
||||
if not self.storage.file_exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
@@ -285,19 +386,18 @@ class OpenAILLM(BaseLLM):
|
||||
file_id = self.storage.process_file(
|
||||
file_path,
|
||||
lambda local_path, **kwargs: self.client.files.create(
|
||||
file=open(local_path, 'rb'),
|
||||
purpose="assistants"
|
||||
).id
|
||||
file=open(local_path, "rb"), purpose="assistants"
|
||||
).id,
|
||||
)
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
attachments_collection = db["attachments"]
|
||||
if '_id' in attachment:
|
||||
if "_id" in attachment:
|
||||
attachments_collection.update_one(
|
||||
{"_id": attachment['_id']},
|
||||
{"$set": {"openai_file_id": file_id}}
|
||||
{"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
|
||||
)
|
||||
|
||||
return file_id
|
||||
@@ -308,9 +408,7 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
class AzureOpenAILLM(OpenAILLM):
|
||||
|
||||
def __init__(
|
||||
self, api_key, user_api_key, *args, **kwargs
|
||||
):
|
||||
def __init__(self, api_key, user_api_key, *args, **kwargs):
|
||||
|
||||
super().__init__(api_key)
|
||||
self.api_base = (settings.OPENAI_API_BASE,)
|
||||
@@ -321,5 +419,5 @@ class AzureOpenAILLM(OpenAILLM):
|
||||
self.client = AzureOpenAI(
|
||||
api_key=api_key,
|
||||
api_version=settings.OPENAI_API_VERSION,
|
||||
azure_endpoint=settings.OPENAI_API_BASE
|
||||
azure_endpoint=settings.OPENAI_API_BASE,
|
||||
)
|
||||
|
||||
@@ -136,6 +136,8 @@ def _log_to_mongodb(
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
user_logs_collection = db["stack_logs"]
|
||||
|
||||
|
||||
|
||||
log_entry = {
|
||||
"endpoint": endpoint,
|
||||
@@ -147,6 +149,11 @@ def _log_to_mongodb(
|
||||
"stacks": stacks,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
# clean up text fields to be no longer than 10000 characters
|
||||
for key, value in log_entry.items():
|
||||
if isinstance(value, str) and len(value) > 10000:
|
||||
log_entry[key] = value[:10000]
|
||||
|
||||
user_logs_collection.insert_one(log_entry)
|
||||
logging.debug(f"Logged activity to MongoDB: {activity_id}")
|
||||
|
||||
|
||||
@@ -32,16 +32,7 @@ class Chunker:
|
||||
header, body = "", text # No header, treat entire text as body
|
||||
return header, body
|
||||
|
||||
def combine_documents(self, doc: Document, next_doc: Document) -> Document:
|
||||
combined_text = doc.text + " " + next_doc.text
|
||||
combined_token_count = len(self.encoding.encode(combined_text))
|
||||
new_doc = Document(
|
||||
text=combined_text,
|
||||
doc_id=doc.doc_id,
|
||||
embedding=doc.embedding,
|
||||
extra_info={**(doc.extra_info or {}), "token_count": combined_token_count}
|
||||
)
|
||||
return new_doc
|
||||
|
||||
|
||||
def split_document(self, doc: Document) -> List[Document]:
|
||||
split_docs = []
|
||||
@@ -82,26 +73,11 @@ class Chunker:
|
||||
processed_docs.append(doc)
|
||||
i += 1
|
||||
elif token_count < self.min_tokens:
|
||||
if i + 1 < len(documents):
|
||||
next_doc = documents[i + 1]
|
||||
next_tokens = self.encoding.encode(next_doc.text)
|
||||
if token_count + len(next_tokens) <= self.max_tokens:
|
||||
# Combine small documents
|
||||
combined_doc = self.combine_documents(doc, next_doc)
|
||||
processed_docs.append(combined_doc)
|
||||
i += 2
|
||||
else:
|
||||
# Keep the small document as is if adding next_doc would exceed max_tokens
|
||||
doc.extra_info = doc.extra_info or {}
|
||||
doc.extra_info["token_count"] = token_count
|
||||
processed_docs.append(doc)
|
||||
i += 1
|
||||
else:
|
||||
# No next document to combine with; add the small document as is
|
||||
doc.extra_info = doc.extra_info or {}
|
||||
doc.extra_info["token_count"] = token_count
|
||||
processed_docs.append(doc)
|
||||
i += 1
|
||||
|
||||
doc.extra_info = doc.extra_info or {}
|
||||
doc.extra_info["token_count"] = token_count
|
||||
processed_docs.append(doc)
|
||||
i += 1
|
||||
else:
|
||||
# Split large documents
|
||||
processed_docs.extend(self.split_document(doc))
|
||||
|
||||
18
application/parser/connectors/__init__.py
Normal file
18
application/parser/connectors/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
External knowledge base connectors for DocsGPT.
|
||||
|
||||
This module contains connectors for external knowledge bases and document storage systems
|
||||
that require authentication and specialized handling, separate from simple web scrapers.
|
||||
"""
|
||||
|
||||
from .base import BaseConnectorAuth, BaseConnectorLoader
|
||||
from .connector_creator import ConnectorCreator
|
||||
from .google_drive import GoogleDriveAuth, GoogleDriveLoader
|
||||
|
||||
__all__ = [
|
||||
'BaseConnectorAuth',
|
||||
'BaseConnectorLoader',
|
||||
'ConnectorCreator',
|
||||
'GoogleDriveAuth',
|
||||
'GoogleDriveLoader'
|
||||
]
|
||||
129
application/parser/connectors/base.py
Normal file
129
application/parser/connectors/base.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
Base classes for external knowledge base connectors.
|
||||
|
||||
This module provides minimal abstract base classes that define the essential
|
||||
interface for external knowledge base connectors.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
class BaseConnectorAuth(ABC):
|
||||
"""
|
||||
Abstract base class for connector authentication.
|
||||
|
||||
Defines the minimal interface that all connector authentication
|
||||
implementations must follow.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_authorization_url(self, state: Optional[str] = None) -> str:
|
||||
"""
|
||||
Generate authorization URL for OAuth flows.
|
||||
|
||||
Args:
|
||||
state: Optional state parameter for CSRF protection
|
||||
|
||||
Returns:
|
||||
Authorization URL
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Exchange authorization code for access tokens.
|
||||
|
||||
Args:
|
||||
authorization_code: Authorization code from OAuth callback
|
||||
|
||||
Returns:
|
||||
Dictionary containing token information
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Refresh an expired access token.
|
||||
|
||||
Args:
|
||||
refresh_token: Refresh token
|
||||
|
||||
Returns:
|
||||
Dictionary containing refreshed token information
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if a token is expired.
|
||||
|
||||
Args:
|
||||
token_info: Token information dictionary
|
||||
|
||||
Returns:
|
||||
True if token is expired, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BaseConnectorLoader(ABC):
|
||||
"""
|
||||
Abstract base class for connector loaders.
|
||||
|
||||
Defines the minimal interface that all connector loader
|
||||
implementations must follow.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, session_token: str):
|
||||
"""
|
||||
Initialize the connector loader.
|
||||
|
||||
Args:
|
||||
session_token: Authentication session token
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
"""
|
||||
Load documents from the external knowledge base.
|
||||
|
||||
Args:
|
||||
inputs: Configuration dictionary containing:
|
||||
- file_ids: Optional list of specific file IDs to load
|
||||
- folder_ids: Optional list of folder IDs to browse/download
|
||||
- limit: Maximum number of items to return
|
||||
- list_only: If True, return metadata without content
|
||||
- recursive: Whether to recursively process folders
|
||||
|
||||
Returns:
|
||||
List of Document objects
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Download files/folders to a local directory.
|
||||
|
||||
Args:
|
||||
local_dir: Local directory path to download files to
|
||||
source_config: Configuration for what to download
|
||||
|
||||
Returns:
|
||||
Dictionary containing download results:
|
||||
- files_downloaded: Number of files downloaded
|
||||
- directory_path: Path where files were downloaded
|
||||
- empty_result: Whether no files were downloaded
|
||||
- source_type: Type of connector
|
||||
- config_used: Configuration that was used
|
||||
- error: Error message if download failed (optional)
|
||||
"""
|
||||
pass
|
||||
81
application/parser/connectors/connector_creator.py
Normal file
81
application/parser/connectors/connector_creator.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
|
||||
|
||||
class ConnectorCreator:
|
||||
"""
|
||||
Factory class for creating external knowledge base connectors and auth providers.
|
||||
|
||||
These are different from remote loaders as they typically require
|
||||
authentication and connect to external document storage systems.
|
||||
"""
|
||||
|
||||
connectors = {
|
||||
"google_drive": GoogleDriveLoader,
|
||||
}
|
||||
|
||||
auth_providers = {
|
||||
"google_drive": GoogleDriveAuth,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_connector(cls, connector_type, *args, **kwargs):
|
||||
"""
|
||||
Create a connector instance for the specified type.
|
||||
|
||||
Args:
|
||||
connector_type: Type of connector to create (e.g., 'google_drive')
|
||||
*args, **kwargs: Arguments to pass to the connector constructor
|
||||
|
||||
Returns:
|
||||
Connector instance
|
||||
|
||||
Raises:
|
||||
ValueError: If connector type is not supported
|
||||
"""
|
||||
connector_class = cls.connectors.get(connector_type.lower())
|
||||
if not connector_class:
|
||||
raise ValueError(f"No connector class found for type {connector_type}")
|
||||
return connector_class(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_auth(cls, connector_type):
|
||||
"""
|
||||
Create an auth provider instance for the specified connector type.
|
||||
|
||||
Args:
|
||||
connector_type: Type of connector auth to create (e.g., 'google_drive')
|
||||
|
||||
Returns:
|
||||
Auth provider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If connector type is not supported for auth
|
||||
"""
|
||||
auth_class = cls.auth_providers.get(connector_type.lower())
|
||||
if not auth_class:
|
||||
raise ValueError(f"No auth class found for type {connector_type}")
|
||||
return auth_class()
|
||||
|
||||
@classmethod
|
||||
def get_supported_connectors(cls):
|
||||
"""
|
||||
Get list of supported connector types.
|
||||
|
||||
Returns:
|
||||
List of supported connector type strings
|
||||
"""
|
||||
return list(cls.connectors.keys())
|
||||
|
||||
@classmethod
|
||||
def is_supported(cls, connector_type):
|
||||
"""
|
||||
Check if a connector type is supported.
|
||||
|
||||
Args:
|
||||
connector_type: Type of connector to check
|
||||
|
||||
Returns:
|
||||
True if supported, False otherwise
|
||||
"""
|
||||
return connector_type.lower() in cls.connectors
|
||||
10
application/parser/connectors/google_drive/__init__.py
Normal file
10
application/parser/connectors/google_drive/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Google Drive connector for DocsGPT.
|
||||
|
||||
This module provides authentication and document loading capabilities for Google Drive.
|
||||
"""
|
||||
|
||||
from .auth import GoogleDriveAuth
|
||||
from .loader import GoogleDriveLoader
|
||||
|
||||
__all__ = ['GoogleDriveAuth', 'GoogleDriveLoader']
|
||||
267
application/parser/connectors/google_drive/auth.py
Normal file
267
application/parser/connectors/google_drive/auth.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import logging
|
||||
import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.base import BaseConnectorAuth
|
||||
|
||||
|
||||
class GoogleDriveAuth(BaseConnectorAuth):
|
||||
"""
|
||||
Handles Google OAuth 2.0 authentication for Google Drive access.
|
||||
"""
|
||||
|
||||
SCOPES = [
|
||||
'https://www.googleapis.com/auth/drive.file'
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self.client_id = settings.GOOGLE_CLIENT_ID
|
||||
self.client_secret = settings.GOOGLE_CLIENT_SECRET
|
||||
self.redirect_uri = f"{settings.CONNECTOR_REDIRECT_BASE_URI}"
|
||||
|
||||
if not self.client_id or not self.client_secret:
|
||||
raise ValueError("Google OAuth credentials not configured. Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET in settings.")
|
||||
|
||||
|
||||
|
||||
def get_authorization_url(self, state: Optional[str] = None) -> str:
|
||||
try:
|
||||
flow = Flow.from_client_config(
|
||||
{
|
||||
"web": {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"redirect_uris": [self.redirect_uri]
|
||||
}
|
||||
},
|
||||
scopes=self.SCOPES
|
||||
)
|
||||
flow.redirect_uri = self.redirect_uri
|
||||
|
||||
authorization_url, _ = flow.authorization_url(
|
||||
access_type='offline',
|
||||
prompt='consent',
|
||||
include_granted_scopes='false',
|
||||
state=state
|
||||
)
|
||||
|
||||
return authorization_url
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error generating authorization URL: {e}")
|
||||
raise
|
||||
|
||||
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
|
||||
try:
|
||||
if not authorization_code:
|
||||
raise ValueError("Authorization code is required")
|
||||
|
||||
flow = Flow.from_client_config(
|
||||
{
|
||||
"web": {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"redirect_uris": [self.redirect_uri]
|
||||
}
|
||||
},
|
||||
scopes=self.SCOPES
|
||||
)
|
||||
flow.redirect_uri = self.redirect_uri
|
||||
|
||||
flow.fetch_token(code=authorization_code)
|
||||
|
||||
credentials = flow.credentials
|
||||
|
||||
if not credentials.refresh_token:
|
||||
logging.warning("OAuth flow did not return a refresh_token.")
|
||||
if not credentials.token:
|
||||
raise ValueError("OAuth flow did not return an access token")
|
||||
|
||||
if not credentials.token_uri:
|
||||
credentials.token_uri = "https://oauth2.googleapis.com/token"
|
||||
|
||||
if not credentials.client_id:
|
||||
credentials.client_id = self.client_id
|
||||
|
||||
if not credentials.client_secret:
|
||||
credentials.client_secret = self.client_secret
|
||||
|
||||
if not credentials.refresh_token:
|
||||
raise ValueError(
|
||||
"No refresh token received. This typically happens when offline access wasn't granted. "
|
||||
)
|
||||
|
||||
return {
|
||||
'access_token': credentials.token,
|
||||
'refresh_token': credentials.refresh_token,
|
||||
'token_uri': credentials.token_uri,
|
||||
'client_id': credentials.client_id,
|
||||
'client_secret': credentials.client_secret,
|
||||
'scopes': credentials.scopes,
|
||||
'expiry': credentials.expiry.isoformat() if credentials.expiry else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error exchanging code for tokens: {e}")
|
||||
raise
|
||||
|
||||
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||
try:
|
||||
if not refresh_token:
|
||||
raise ValueError("Refresh token is required")
|
||||
|
||||
credentials = Credentials(
|
||||
token=None,
|
||||
refresh_token=refresh_token,
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret
|
||||
)
|
||||
|
||||
from google.auth.transport.requests import Request
|
||||
credentials.refresh(Request())
|
||||
|
||||
return {
|
||||
'access_token': credentials.token,
|
||||
'refresh_token': refresh_token,
|
||||
'token_uri': credentials.token_uri,
|
||||
'client_id': credentials.client_id,
|
||||
'client_secret': credentials.client_secret,
|
||||
'scopes': credentials.scopes,
|
||||
'expiry': credentials.expiry.isoformat() if credentials.expiry else None
|
||||
}
|
||||
except Exception as e:
|
||||
logging.error(f"Error refreshing access token: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def create_credentials_from_token_info(self, token_info: Dict[str, Any]) -> Credentials:
|
||||
from application.core.settings import settings
|
||||
|
||||
access_token = token_info.get('access_token')
|
||||
if not access_token:
|
||||
raise ValueError("No access token found in token_info")
|
||||
|
||||
credentials = Credentials(
|
||||
token=access_token,
|
||||
refresh_token=token_info.get('refresh_token'),
|
||||
token_uri= 'https://oauth2.googleapis.com/token',
|
||||
client_id=settings.GOOGLE_CLIENT_ID,
|
||||
client_secret=settings.GOOGLE_CLIENT_SECRET,
|
||||
scopes=token_info.get('scopes', ['https://www.googleapis.com/auth/drive.readonly'])
|
||||
)
|
||||
|
||||
if not credentials.token:
|
||||
raise ValueError("Credentials created without valid access token")
|
||||
|
||||
return credentials
|
||||
|
||||
def build_drive_service(self, credentials: Credentials):
|
||||
try:
|
||||
if not credentials:
|
||||
raise ValueError("No credentials provided")
|
||||
|
||||
if not credentials.token and not credentials.refresh_token:
|
||||
raise ValueError("No access token or refresh token available. User must re-authorize with offline access.")
|
||||
|
||||
needs_refresh = credentials.expired or not credentials.token
|
||||
if needs_refresh:
|
||||
if credentials.refresh_token:
|
||||
try:
|
||||
from google.auth.transport.requests import Request
|
||||
credentials.refresh(Request())
|
||||
except Exception as refresh_error:
|
||||
raise ValueError(f"Failed to refresh credentials: {refresh_error}")
|
||||
else:
|
||||
raise ValueError("No access token or refresh token available. User must re-authorize with offline access.")
|
||||
|
||||
return build('drive', 'v3', credentials=credentials)
|
||||
|
||||
except HttpError as e:
|
||||
raise ValueError(f"Failed to build Google Drive service: HTTP {e.resp.status}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to build Google Drive service: {str(e)}")
|
||||
|
||||
def is_token_expired(self, token_info):
|
||||
if 'expiry' in token_info and token_info['expiry']:
|
||||
try:
|
||||
from dateutil import parser
|
||||
# Google Drive provides timezone-aware ISO8601 dates
|
||||
expiry_dt = parser.parse(token_info['expiry'])
|
||||
current_time = datetime.datetime.now(datetime.timezone.utc)
|
||||
return current_time >= expiry_dt - datetime.timedelta(seconds=60)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
if 'access_token' in token_info and token_info['access_token']:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
|
||||
try:
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
sessions_collection = db["connector_sessions"]
|
||||
session = sessions_collection.find_one({"session_token": session_token})
|
||||
if not session:
|
||||
raise ValueError(f"Invalid session token: {session_token}")
|
||||
|
||||
if "token_info" not in session:
|
||||
raise ValueError("Session missing token information")
|
||||
|
||||
token_info = session["token_info"]
|
||||
if not token_info:
|
||||
raise ValueError("Invalid token information")
|
||||
|
||||
required_fields = ["access_token", "refresh_token"]
|
||||
missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required token fields: {missing_fields}")
|
||||
|
||||
if 'client_id' not in token_info:
|
||||
token_info['client_id'] = settings.GOOGLE_CLIENT_ID
|
||||
if 'client_secret' not in token_info:
|
||||
token_info['client_secret'] = settings.GOOGLE_CLIENT_SECRET
|
||||
if 'token_uri' not in token_info:
|
||||
token_info['token_uri'] = 'https://oauth2.googleapis.com/token'
|
||||
|
||||
return token_info
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to retrieve Google Drive token information: {str(e)}")
|
||||
|
||||
def validate_credentials(self, credentials: Credentials) -> bool:
|
||||
"""
|
||||
Validate Google Drive credentials by making a test API call.
|
||||
|
||||
Args:
|
||||
credentials: Google credentials object
|
||||
|
||||
Returns:
|
||||
True if credentials are valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
service = self.build_drive_service(credentials)
|
||||
service.about().get(fields="user").execute()
|
||||
return True
|
||||
|
||||
except HttpError as e:
|
||||
logging.error(f"HTTP error validating credentials: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.error(f"Error validating credentials: {e}")
|
||||
return False
|
||||
559
application/parser/connectors/google_drive/loader.py
Normal file
559
application/parser/connectors/google_drive/loader.py
Normal file
@@ -0,0 +1,559 @@
|
||||
"""
|
||||
Google Drive loader for DocsGPT.
|
||||
Loads documents from Google Drive using Google Drive API.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from googleapiclient.http import MediaIoBaseDownload
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from application.parser.connectors.base import BaseConnectorLoader
|
||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||
from application.parser.schema.base import Document
|
||||
|
||||
|
||||
class GoogleDriveLoader(BaseConnectorLoader):
|
||||
|
||||
SUPPORTED_MIME_TYPES = {
|
||||
'application/pdf': '.pdf',
|
||||
'application/vnd.google-apps.document': '.docx',
|
||||
'application/vnd.google-apps.presentation': '.pptx',
|
||||
'application/vnd.google-apps.spreadsheet': '.xlsx',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
|
||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
|
||||
'application/msword': '.doc',
|
||||
'application/vnd.ms-powerpoint': '.ppt',
|
||||
'application/vnd.ms-excel': '.xls',
|
||||
'text/plain': '.txt',
|
||||
'text/csv': '.csv',
|
||||
'text/html': '.html',
|
||||
'text/markdown': '.md',
|
||||
'text/x-rst': '.rst',
|
||||
'application/json': '.json',
|
||||
'application/epub+zip': '.epub',
|
||||
'application/rtf': '.rtf',
|
||||
'image/jpeg': '.jpg',
|
||||
'image/jpg': '.jpg',
|
||||
'image/png': '.png',
|
||||
}
|
||||
|
||||
EXPORT_FORMATS = {
|
||||
'application/vnd.google-apps.document': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'application/vnd.google-apps.presentation': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
'application/vnd.google-apps.spreadsheet': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
|
||||
}
|
||||
|
||||
def __init__(self, session_token: str):
|
||||
self.auth = GoogleDriveAuth()
|
||||
self.session_token = session_token
|
||||
|
||||
token_info = self.auth.get_token_info_from_session(session_token)
|
||||
self.credentials = self.auth.create_credentials_from_token_info(token_info)
|
||||
|
||||
try:
|
||||
self.service = self.auth.build_drive_service(self.credentials)
|
||||
except Exception as e:
|
||||
logging.warning(f"Could not build Google Drive service: {e}")
|
||||
self.service = None
|
||||
|
||||
self.next_page_token = None
|
||||
|
||||
|
||||
|
||||
def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]:
|
||||
try:
|
||||
file_id = file_metadata.get('id')
|
||||
file_name = file_metadata.get('name', 'Unknown')
|
||||
mime_type = file_metadata.get('mimeType', 'application/octet-stream')
|
||||
|
||||
if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
|
||||
return None
|
||||
if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
|
||||
logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}")
|
||||
return None
|
||||
# Google Drive provides timezone-aware ISO8601 dates
|
||||
doc_metadata = {
|
||||
'file_name': file_name,
|
||||
'mime_type': mime_type,
|
||||
'size': file_metadata.get('size', None),
|
||||
'created_time': file_metadata.get('createdTime'),
|
||||
'modified_time': file_metadata.get('modifiedTime'),
|
||||
'parents': file_metadata.get('parents', []),
|
||||
'source': 'google_drive'
|
||||
}
|
||||
|
||||
if not load_content:
|
||||
return Document(
|
||||
text="",
|
||||
doc_id=file_id,
|
||||
extra_info=doc_metadata
|
||||
)
|
||||
|
||||
content = self._download_file_content(file_id, mime_type)
|
||||
if content is None:
|
||||
logging.warning(f"Could not load content for file {file_name} ({file_id})")
|
||||
return None
|
||||
|
||||
return Document(
|
||||
text=content,
|
||||
doc_id=file_id,
|
||||
extra_info=doc_metadata
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing file: {e}")
|
||||
return None
|
||||
|
||||
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||
session_token = inputs.get('session_token')
|
||||
if session_token and session_token != self.session_token:
|
||||
logging.warning("Session token in inputs differs from loader's session token. Using loader's session token.")
|
||||
self.config = inputs
|
||||
|
||||
try:
|
||||
documents: List[Document] = []
|
||||
|
||||
folder_id = inputs.get('folder_id')
|
||||
file_ids = inputs.get('file_ids', [])
|
||||
limit = inputs.get('limit', 100)
|
||||
list_only = inputs.get('list_only', False)
|
||||
load_content = not list_only
|
||||
page_token = inputs.get('page_token')
|
||||
search_query = inputs.get('search_query')
|
||||
self.next_page_token = None
|
||||
|
||||
if file_ids:
|
||||
# Specific files requested: load them
|
||||
for file_id in file_ids:
|
||||
try:
|
||||
doc = self._load_file_by_id(file_id, load_content=load_content)
|
||||
if doc:
|
||||
if not search_query or (
|
||||
search_query.lower() in doc.extra_info.get('file_name', '').lower()
|
||||
):
|
||||
documents.append(doc)
|
||||
elif hasattr(self, '_credential_refreshed') and self._credential_refreshed:
|
||||
self._credential_refreshed = False
|
||||
logging.info(f"Retrying load of file {file_id} after credential refresh")
|
||||
doc = self._load_file_by_id(file_id, load_content=load_content)
|
||||
if doc and (
|
||||
not search_query or
|
||||
search_query.lower() in doc.extra_info.get('file_name', '').lower()
|
||||
):
|
||||
documents.append(doc)
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading file {file_id}: {e}")
|
||||
continue
|
||||
else:
|
||||
# Browsing mode: list immediate children of provided folder or root
|
||||
parent_id = folder_id if folder_id else 'root'
|
||||
documents = self._list_items_in_parent(
|
||||
parent_id,
|
||||
limit=limit,
|
||||
load_content=load_content,
|
||||
page_token=page_token,
|
||||
search_query=search_query
|
||||
)
|
||||
|
||||
logging.info(f"Loaded {len(documents)} documents from Google Drive")
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading data from Google Drive: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]:
|
||||
self._ensure_service()
|
||||
|
||||
try:
|
||||
file_metadata = self.service.files().get(
|
||||
fileId=file_id,
|
||||
fields='id,name,mimeType,size,createdTime,modifiedTime,parents'
|
||||
).execute()
|
||||
|
||||
return self._process_file(file_metadata, load_content=load_content)
|
||||
|
||||
except HttpError as e:
|
||||
logging.error(f"HTTP error loading file {file_id}: {e.resp.status} - {e.content}")
|
||||
|
||||
if e.resp.status in [401, 403]:
|
||||
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
|
||||
try:
|
||||
from google.auth.transport.requests import Request
|
||||
self.credentials.refresh(Request())
|
||||
self._ensure_service()
|
||||
return None
|
||||
except Exception as refresh_error:
|
||||
raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}")
|
||||
else:
|
||||
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading file {file_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
|
||||
self._ensure_service()
|
||||
|
||||
documents: List[Document] = []
|
||||
|
||||
try:
|
||||
query = f"'{parent_id}' in parents and trashed=false"
|
||||
|
||||
if search_query:
|
||||
safe_search = search_query.replace("'", "\\'")
|
||||
query += f" and name contains '{safe_search}'"
|
||||
|
||||
next_token_out: Optional[str] = None
|
||||
|
||||
while True:
|
||||
page_size = 100
|
||||
if limit:
|
||||
remaining = max(0, limit - len(documents))
|
||||
if remaining == 0:
|
||||
break
|
||||
page_size = min(100, remaining)
|
||||
|
||||
results = self.service.files().list(
|
||||
q=query,
|
||||
fields='nextPageToken,files(id,name,mimeType,size,createdTime,modifiedTime,parents)',
|
||||
pageToken=page_token,
|
||||
pageSize=page_size,
|
||||
orderBy='name'
|
||||
).execute()
|
||||
|
||||
items = results.get('files', [])
|
||||
for item in items:
|
||||
mime_type = item.get('mimeType')
|
||||
if mime_type == 'application/vnd.google-apps.folder':
|
||||
doc_metadata = {
|
||||
'file_name': item.get('name', 'Unknown'),
|
||||
'mime_type': mime_type,
|
||||
'size': item.get('size', None),
|
||||
'created_time': item.get('createdTime'),
|
||||
'modified_time': item.get('modifiedTime'),
|
||||
'parents': item.get('parents', []),
|
||||
'source': 'google_drive',
|
||||
'is_folder': True
|
||||
}
|
||||
documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata))
|
||||
else:
|
||||
doc = self._process_file(item, load_content=load_content)
|
||||
if doc:
|
||||
documents.append(doc)
|
||||
|
||||
if limit and len(documents) >= limit:
|
||||
self.next_page_token = results.get('nextPageToken')
|
||||
return documents
|
||||
|
||||
page_token = results.get('nextPageToken')
|
||||
next_token_out = page_token
|
||||
if not page_token:
|
||||
break
|
||||
|
||||
self.next_page_token = next_token_out
|
||||
return documents
|
||||
except Exception as e:
|
||||
logging.error(f"Error listing items under parent {parent_id}: {e}")
|
||||
return documents
|
||||
|
||||
|
||||
|
||||
|
||||
def _download_file_content(self, file_id: str, mime_type: str) -> Optional[str]:
|
||||
if not self.credentials.token:
|
||||
logging.warning("No access token in credentials, attempting to refresh")
|
||||
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
|
||||
try:
|
||||
from google.auth.transport.requests import Request
|
||||
self.credentials.refresh(Request())
|
||||
logging.info("Credentials refreshed successfully")
|
||||
self._ensure_service()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to refresh credentials: {e}")
|
||||
raise ValueError("Authentication failed and cannot be refreshed: missing or invalid refresh_token")
|
||||
else:
|
||||
logging.error("No access token and no refresh_token available")
|
||||
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
|
||||
|
||||
if self.credentials.expired:
|
||||
logging.warning("Credentials are expired, attempting to refresh")
|
||||
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
|
||||
try:
|
||||
from google.auth.transport.requests import Request
|
||||
self.credentials.refresh(Request())
|
||||
logging.info("Credentials refreshed successfully")
|
||||
self._ensure_service()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to refresh expired credentials: {e}")
|
||||
raise ValueError("Authentication failed and cannot be refreshed: expired credentials")
|
||||
else:
|
||||
logging.error("Credentials expired and no refresh_token available")
|
||||
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
|
||||
|
||||
try:
|
||||
if mime_type in self.EXPORT_FORMATS:
|
||||
export_mime_type = self.EXPORT_FORMATS[mime_type]
|
||||
request = self.service.files().export_media(
|
||||
fileId=file_id,
|
||||
mimeType=export_mime_type
|
||||
)
|
||||
else:
|
||||
request = self.service.files().get_media(fileId=file_id)
|
||||
|
||||
file_io = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(file_io, request)
|
||||
|
||||
done = False
|
||||
while done is False:
|
||||
try:
|
||||
_, done = downloader.next_chunk()
|
||||
except HttpError as e:
|
||||
logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error during download of file {file_id}: {e}")
|
||||
return None
|
||||
|
||||
content_bytes = file_io.getvalue()
|
||||
|
||||
try:
|
||||
content = content_bytes.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
content = content_bytes.decode('latin-1')
|
||||
except UnicodeDecodeError:
|
||||
logging.error(f"Could not decode file {file_id} as text")
|
||||
return None
|
||||
|
||||
return content
|
||||
|
||||
except HttpError as e:
|
||||
logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")
|
||||
|
||||
if e.resp.status in [401, 403]:
|
||||
logging.error(f"Authentication error downloading file {file_id}")
|
||||
|
||||
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
|
||||
logging.info(f"Attempting to refresh credentials for file {file_id}")
|
||||
try:
|
||||
from google.auth.transport.requests import Request
|
||||
self.credentials.refresh(Request())
|
||||
logging.info("Credentials refreshed successfully")
|
||||
self._credential_refreshed = True
|
||||
self._ensure_service()
|
||||
return None
|
||||
except Exception as refresh_error:
|
||||
logging.error(f"Error refreshing credentials: {refresh_error}")
|
||||
raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}")
|
||||
else:
|
||||
logging.error("Cannot refresh credentials: missing refresh_token")
|
||||
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading file {file_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool:
|
||||
try:
|
||||
self._ensure_service()
|
||||
return self._download_single_file(file_id, local_dir)
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading file {file_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def _ensure_service(self):
|
||||
if not self.service:
|
||||
try:
|
||||
self.service = self.auth.build_drive_service(self.credentials)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Cannot access Google Drive: {e}")
|
||||
|
||||
def _download_single_file(self, file_id: str, local_dir: str) -> bool:
|
||||
file_metadata = self.service.files().get(
|
||||
fileId=file_id,
|
||||
fields='name,mimeType'
|
||||
).execute()
|
||||
|
||||
file_name = file_metadata['name']
|
||||
mime_type = file_metadata['mimeType']
|
||||
|
||||
if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
|
||||
return False
|
||||
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
full_path = os.path.join(local_dir, file_name)
|
||||
|
||||
if mime_type in self.EXPORT_FORMATS:
|
||||
export_mime_type = self.EXPORT_FORMATS[mime_type]
|
||||
request = self.service.files().export_media(
|
||||
fileId=file_id,
|
||||
mimeType=export_mime_type
|
||||
)
|
||||
extension = self._get_extension_for_mime_type(export_mime_type)
|
||||
if not full_path.endswith(extension):
|
||||
full_path += extension
|
||||
else:
|
||||
request = self.service.files().get_media(fileId=file_id)
|
||||
|
||||
with open(full_path, 'wb') as f:
|
||||
downloader = MediaIoBaseDownload(f, request)
|
||||
done = False
|
||||
while not done:
|
||||
_, done = downloader.next_chunk()
|
||||
|
||||
return True
|
||||
|
||||
def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
|
||||
files_downloaded = 0
|
||||
try:
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
|
||||
query = f"'{folder_id}' in parents and trashed=false"
|
||||
page_token = None
|
||||
|
||||
while True:
|
||||
results = self.service.files().list(
|
||||
q=query,
|
||||
fields='nextPageToken, files(id, name, mimeType)',
|
||||
pageToken=page_token,
|
||||
pageSize=1000
|
||||
).execute()
|
||||
|
||||
items = results.get('files', [])
|
||||
logging.info(f"Found {len(items)} items in folder {folder_id}")
|
||||
|
||||
for item in items:
|
||||
item_name = item['name']
|
||||
item_id = item['id']
|
||||
mime_type = item['mimeType']
|
||||
|
||||
if mime_type == 'application/vnd.google-apps.folder':
|
||||
if recursive:
|
||||
# Create subfolder and recurse
|
||||
subfolder_path = os.path.join(local_dir, item_name)
|
||||
os.makedirs(subfolder_path, exist_ok=True)
|
||||
subfolder_files = self._download_folder_recursive(
|
||||
item_id,
|
||||
subfolder_path,
|
||||
recursive
|
||||
)
|
||||
files_downloaded += subfolder_files
|
||||
logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}")
|
||||
else:
|
||||
# Download file
|
||||
success = self._download_single_file(item_id, local_dir)
|
||||
if success:
|
||||
files_downloaded += 1
|
||||
logging.info(f"Downloaded file: {item_name}")
|
||||
else:
|
||||
logging.warning(f"Failed to download file: {item_name}")
|
||||
|
||||
page_token = results.get('nextPageToken')
|
||||
if not page_token:
|
||||
break
|
||||
|
||||
return files_downloaded
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True)
|
||||
return files_downloaded
|
||||
|
||||
def _get_extension_for_mime_type(self, mime_type: str) -> str:
|
||||
extensions = {
|
||||
'application/pdf': '.pdf',
|
||||
'text/plain': '.txt',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
|
||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
|
||||
'text/html': '.html',
|
||||
'text/markdown': '.md',
|
||||
}
|
||||
return extensions.get(mime_type, '.bin')
|
||||
|
||||
def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
|
||||
try:
|
||||
self._ensure_service()
|
||||
return self._download_folder_recursive(folder_id, local_dir, recursive)
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
|
||||
return 0
|
||||
|
||||
def download_to_directory(self, local_dir: str, source_config: dict = None) -> dict:
|
||||
if source_config is None:
|
||||
source_config = {}
|
||||
|
||||
config = source_config if source_config else getattr(self, 'config', {})
|
||||
files_downloaded = 0
|
||||
|
||||
try:
|
||||
folder_ids = config.get('folder_ids', [])
|
||||
file_ids = config.get('file_ids', [])
|
||||
recursive = config.get('recursive', True)
|
||||
|
||||
self._ensure_service()
|
||||
|
||||
if file_ids:
|
||||
if isinstance(file_ids, str):
|
||||
file_ids = [file_ids]
|
||||
|
||||
for file_id in file_ids:
|
||||
if self._download_file_to_directory(file_id, local_dir):
|
||||
files_downloaded += 1
|
||||
|
||||
# Process folders
|
||||
if folder_ids:
|
||||
if isinstance(folder_ids, str):
|
||||
folder_ids = [folder_ids]
|
||||
|
||||
for folder_id in folder_ids:
|
||||
try:
|
||||
folder_metadata = self.service.files().get(
|
||||
fileId=folder_id,
|
||||
fields='name'
|
||||
).execute()
|
||||
folder_name = folder_metadata.get('name', '')
|
||||
folder_path = os.path.join(local_dir, folder_name)
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
|
||||
folder_files = self._download_folder_recursive(
|
||||
folder_id,
|
||||
folder_path,
|
||||
recursive
|
||||
)
|
||||
files_downloaded += folder_files
|
||||
logging.info(f"Downloaded {folder_files} files from folder {folder_name}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
|
||||
|
||||
if not file_ids and not folder_ids:
|
||||
raise ValueError("No folder_ids or file_ids provided for download")
|
||||
|
||||
return {
|
||||
"files_downloaded": files_downloaded,
|
||||
"directory_path": local_dir,
|
||||
"empty_result": files_downloaded == 0,
|
||||
"source_type": "google_drive",
|
||||
"config_used": config
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"files_downloaded": files_downloaded,
|
||||
"directory_path": local_dir,
|
||||
"empty_result": True,
|
||||
"source_type": "google_drive",
|
||||
"config_used": config,
|
||||
"error": str(e)
|
||||
}
|
||||
@@ -6,6 +6,21 @@ from application.core.settings import settings
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
def sanitize_content(content: str) -> str:
|
||||
"""
|
||||
Remove NUL characters that can cause vector store ingestion to fail.
|
||||
|
||||
Args:
|
||||
content (str): Raw content that may contain NUL characters
|
||||
|
||||
Returns:
|
||||
str: Sanitized content with NUL characters removed
|
||||
"""
|
||||
if not content:
|
||||
return content
|
||||
return content.replace('\x00', '')
|
||||
|
||||
|
||||
@retry(tries=10, delay=60)
|
||||
def add_text_to_store_with_retry(store, doc, source_id):
|
||||
"""
|
||||
@@ -16,6 +31,9 @@ def add_text_to_store_with_retry(store, doc, source_id):
|
||||
source_id: Unique identifier for the source.
|
||||
"""
|
||||
try:
|
||||
# Sanitize content to remove NUL characters that cause ingestion failures
|
||||
doc.page_content = sanitize_content(doc.page_content)
|
||||
|
||||
doc.metadata["source_id"] = str(source_id)
|
||||
store.add_texts([doc.page_content], metadatas=[doc.metadata])
|
||||
except Exception as e:
|
||||
@@ -46,7 +64,7 @@ def embed_and_store_documents(docs, folder_name, source_id, task_status):
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
docs_init=docs_init,
|
||||
source_id=folder_name,
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -15,6 +15,7 @@ from application.parser.file.json_parser import JSONParser
|
||||
from application.parser.file.pptx_parser import PPTXParser
|
||||
from application.parser.file.image_parser import ImageParser
|
||||
from application.parser.schema.base import Document
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = {
|
||||
".pdf": PDFParser(),
|
||||
@@ -141,11 +142,12 @@ class SimpleDirectoryReader(BaseReader):
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of documents.
|
||||
|
||||
"""
|
||||
data: Union[str, List[str]] = ""
|
||||
data_list: List[str] = []
|
||||
metadata_list = []
|
||||
self.file_token_counts = {}
|
||||
|
||||
for input_file in self.input_files:
|
||||
if input_file.suffix in self.file_extractor:
|
||||
parser = self.file_extractor[input_file.suffix]
|
||||
@@ -156,24 +158,48 @@ class SimpleDirectoryReader(BaseReader):
|
||||
# do standard read
|
||||
with open(input_file, "r", errors=self.errors) as f:
|
||||
data = f.read()
|
||||
# Prepare metadata for this file
|
||||
if self.file_metadata is not None:
|
||||
file_metadata = self.file_metadata(input_file.name)
|
||||
|
||||
# Calculate token count for this file
|
||||
if isinstance(data, List):
|
||||
file_tokens = sum(num_tokens_from_string(str(d)) for d in data)
|
||||
else:
|
||||
# Provide a default empty metadata
|
||||
file_metadata = {'title': '', 'store': ''}
|
||||
# TODO: Find a case with no metadata and check if breaks anything
|
||||
file_tokens = num_tokens_from_string(str(data))
|
||||
|
||||
full_path = str(input_file.resolve())
|
||||
self.file_token_counts[full_path] = file_tokens
|
||||
|
||||
base_metadata = {
|
||||
'title': input_file.name,
|
||||
'token_count': file_tokens,
|
||||
}
|
||||
|
||||
if hasattr(self, 'input_dir'):
|
||||
try:
|
||||
relative_path = str(input_file.relative_to(self.input_dir))
|
||||
base_metadata['source'] = relative_path
|
||||
except ValueError:
|
||||
base_metadata['source'] = str(input_file)
|
||||
else:
|
||||
base_metadata['source'] = str(input_file)
|
||||
|
||||
if self.file_metadata is not None:
|
||||
custom_metadata = self.file_metadata(input_file.name)
|
||||
base_metadata.update(custom_metadata)
|
||||
|
||||
if isinstance(data, List):
|
||||
# Extend data_list with each item in the data list
|
||||
data_list.extend([str(d) for d in data])
|
||||
# For each item in the data list, add the file's metadata to metadata_list
|
||||
metadata_list.extend([file_metadata for _ in data])
|
||||
metadata_list.extend([base_metadata for _ in data])
|
||||
else:
|
||||
# Add the single piece of data to data_list
|
||||
data_list.append(str(data))
|
||||
# Add the file's metadata to metadata_list
|
||||
metadata_list.append(file_metadata)
|
||||
metadata_list.append(base_metadata)
|
||||
|
||||
# Build directory structure if input_dir is provided
|
||||
if hasattr(self, 'input_dir'):
|
||||
self.directory_structure = self.build_directory_structure(self.input_dir)
|
||||
logging.info("Directory structure built successfully")
|
||||
else:
|
||||
self.directory_structure = {}
|
||||
|
||||
if concatenate:
|
||||
return [Document("\n".join(data_list))]
|
||||
@@ -181,3 +207,48 @@ class SimpleDirectoryReader(BaseReader):
|
||||
return [Document(d, extra_info=m) for d, m in zip(data_list, metadata_list)]
|
||||
else:
|
||||
return [Document(d) for d in data_list]
|
||||
|
||||
def build_directory_structure(self, base_path):
|
||||
"""Build a dictionary representing the directory structure.
|
||||
|
||||
Args:
|
||||
base_path: The base path to start building the structure from.
|
||||
|
||||
Returns:
|
||||
dict: A nested dictionary representing the directory structure.
|
||||
"""
|
||||
import mimetypes
|
||||
|
||||
def build_tree(path):
|
||||
"""Helper function to recursively build the directory tree."""
|
||||
result = {}
|
||||
|
||||
for item in path.iterdir():
|
||||
if self.exclude_hidden and item.name.startswith('.'):
|
||||
continue
|
||||
|
||||
if item.is_dir():
|
||||
subtree = build_tree(item)
|
||||
if subtree:
|
||||
result[item.name] = subtree
|
||||
else:
|
||||
if self.required_exts is not None and item.suffix not in self.required_exts:
|
||||
continue
|
||||
|
||||
full_path = str(item.resolve())
|
||||
file_size_bytes = item.stat().st_size
|
||||
mime_type = mimetypes.guess_type(item.name)[0] or "application/octet-stream"
|
||||
|
||||
file_info = {
|
||||
"type": mime_type,
|
||||
"size_bytes": file_size_bytes
|
||||
}
|
||||
|
||||
if hasattr(self, 'file_token_counts') and full_path in self.file_token_counts:
|
||||
file_info["token_count"] = self.file_token_counts[full_path]
|
||||
|
||||
result[item.name] = file_info
|
||||
|
||||
return result
|
||||
|
||||
return build_tree(Path(base_path))
|
||||
@@ -8,6 +8,7 @@ import requests
|
||||
from typing import Dict, Union
|
||||
|
||||
from application.parser.file.base_parser import BaseParser
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class ImageParser(BaseParser):
|
||||
@@ -18,10 +19,13 @@ class ImageParser(BaseParser):
|
||||
return {}
|
||||
|
||||
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, list[str]]:
|
||||
doc2md_service = "https://llm.arc53.com/doc2md"
|
||||
# alternatively you can use local vision capable LLM
|
||||
with open(file, "rb") as file_loaded:
|
||||
files = {'file': file_loaded}
|
||||
response = requests.post(doc2md_service, files=files)
|
||||
data = response.json()["markdown"]
|
||||
if settings.PARSE_IMAGE_REMOTE:
|
||||
doc2md_service = "https://llm.arc53.com/doc2md"
|
||||
# alternatively you can use local vision capable LLM
|
||||
with open(file, "rb") as file_loaded:
|
||||
files = {'file': file_loaded}
|
||||
response = requests.post(doc2md_service, files=files)
|
||||
data = response.json()["markdown"]
|
||||
else:
|
||||
data = ""
|
||||
return data
|
||||
|
||||
@@ -6,6 +6,16 @@ from application.parser.remote.github_loader import GitHubLoader
|
||||
|
||||
|
||||
class RemoteCreator:
|
||||
"""
|
||||
Factory class for creating remote content loaders.
|
||||
|
||||
These loaders fetch content from remote web sources like URLs,
|
||||
sitemaps, web crawlers, social media platforms, etc.
|
||||
|
||||
For external knowledge base connectors (like Google Drive),
|
||||
use ConnectorCreator instead.
|
||||
"""
|
||||
|
||||
loaders = {
|
||||
"url": WebLoader,
|
||||
"sitemap": SitemapLoader,
|
||||
@@ -18,5 +28,5 @@ class RemoteCreator:
|
||||
def create_loader(cls, type, *args, **kwargs):
|
||||
loader_class = cls.loaders.get(type.lower())
|
||||
if not loader_class:
|
||||
raise ValueError(f"No LLM class found for type {type}")
|
||||
raise ValueError(f"No loader class found for type {type}")
|
||||
return loader_class(*args, **kwargs)
|
||||
|
||||
@@ -2,6 +2,7 @@ anthropic==0.49.0
|
||||
boto3==1.38.18
|
||||
beautifulsoup4==4.13.4
|
||||
celery==5.4.0
|
||||
cryptography==42.0.8
|
||||
dataclasses-json==0.6.7
|
||||
docx2txt==0.8
|
||||
duckduckgo-search==7.5.2
|
||||
@@ -11,8 +12,12 @@ esprima==4.0.1
|
||||
esutils==1.0.1
|
||||
Flask==3.1.1
|
||||
faiss-cpu==1.9.0.post1
|
||||
fastmcp==2.11.0
|
||||
flask-restx==1.3.0
|
||||
google-genai==1.3.0
|
||||
google-api-python-client==2.179.0
|
||||
google-auth-httplib2==0.2.0
|
||||
google-auth-oauthlib==1.2.2
|
||||
gTTS==2.5.4
|
||||
gunicorn==23.0.0
|
||||
javalang==0.13.0
|
||||
@@ -52,13 +57,13 @@ prompt-toolkit==3.0.51
|
||||
protobuf==5.29.3
|
||||
psycopg2-binary==2.9.10
|
||||
py==1.11.0
|
||||
pydantic==2.10.6
|
||||
pydantic-core==2.27.2
|
||||
pydantic-settings==2.7.1
|
||||
pydantic
|
||||
pydantic-core
|
||||
pydantic-settings
|
||||
pymongo==4.11.3
|
||||
pypdf==5.5.0
|
||||
python-dateutil==2.9.0.post0
|
||||
python-dotenv==1.0.1
|
||||
python-dotenv
|
||||
python-jose==3.4.0
|
||||
python-pptx==1.0.2
|
||||
redis==5.2.1
|
||||
@@ -78,7 +83,7 @@ tzdata==2024.2
|
||||
urllib3==2.3.0
|
||||
vine==5.1.0
|
||||
wcwidth==0.2.13
|
||||
werkzeug==3.1.3
|
||||
werkzeug>=3.1.0,<3.1.2
|
||||
yarl==1.20.0
|
||||
markdownify==1.1.0
|
||||
tldextract==5.1.3
|
||||
|
||||
@@ -5,10 +5,6 @@ class BaseRetriever(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def gen(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
import json
|
||||
|
||||
from langchain_community.tools import BraveSearch
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.base import BaseRetriever
|
||||
|
||||
|
||||
class BraveRetSearch(BaseRetriever):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source,
|
||||
chat_history,
|
||||
prompt,
|
||||
chunks=2,
|
||||
token_limit=150,
|
||||
gpt_model="docsgpt",
|
||||
user_api_key=None,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.question = ""
|
||||
self.source = source
|
||||
self.chat_history = chat_history
|
||||
self.prompt = prompt
|
||||
self.chunks = chunks
|
||||
self.gpt_model = gpt_model
|
||||
self.token_limit = (
|
||||
token_limit
|
||||
if token_limit
|
||||
< settings.MODEL_TOKEN_LIMITS.get(
|
||||
self.gpt_model, settings.DEFAULT_MAX_HISTORY
|
||||
)
|
||||
else settings.MODEL_TOKEN_LIMITS.get(
|
||||
self.gpt_model, settings.DEFAULT_MAX_HISTORY
|
||||
)
|
||||
)
|
||||
self.user_api_key = user_api_key
|
||||
self.decoded_token = decoded_token
|
||||
|
||||
def _get_data(self):
|
||||
if self.chunks == 0:
|
||||
docs = []
|
||||
else:
|
||||
search = BraveSearch.from_api_key(
|
||||
api_key=settings.BRAVE_SEARCH_API_KEY,
|
||||
search_kwargs={"count": int(self.chunks)},
|
||||
)
|
||||
results = search.run(self.question)
|
||||
results = json.loads(results)
|
||||
|
||||
docs = []
|
||||
for i in results:
|
||||
try:
|
||||
title = i["title"]
|
||||
link = i["link"]
|
||||
snippet = i["snippet"]
|
||||
docs.append({"text": snippet, "title": title, "link": link})
|
||||
except IndexError:
|
||||
pass
|
||||
if settings.LLM_NAME == "llama.cpp":
|
||||
docs = [docs[0]]
|
||||
|
||||
return docs
|
||||
|
||||
def gen(self):
|
||||
docs = self._get_data()
|
||||
|
||||
# join all page_content together with a newline
|
||||
docs_together = "\n".join([doc["text"] for doc in docs])
|
||||
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
for doc in docs:
|
||||
yield {"source": doc}
|
||||
|
||||
if len(self.chat_history) > 0:
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": i["response"]}
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=self.user_api_key,
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
|
||||
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
|
||||
for line in completion:
|
||||
yield {"answer": str(line)}
|
||||
|
||||
def search(self, query: str = ""):
|
||||
if query:
|
||||
self.question = query
|
||||
return self._get_data()
|
||||
|
||||
def get_params(self):
|
||||
return {
|
||||
"question": self.question,
|
||||
"source": self.source,
|
||||
"chat_history": self.chat_history,
|
||||
"prompt": self.prompt,
|
||||
"chunks": self.chunks,
|
||||
"token_limit": self.token_limit,
|
||||
"gpt_model": self.gpt_model,
|
||||
"user_api_key": self.user_api_key,
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.base import BaseRetriever
|
||||
@@ -16,22 +18,37 @@ class ClassicRAG(BaseRetriever):
|
||||
token_limit=150,
|
||||
gpt_model="docsgpt",
|
||||
user_api_key=None,
|
||||
llm_name=settings.LLM_NAME,
|
||||
llm_name=settings.LLM_PROVIDER,
|
||||
api_key=settings.API_KEY,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.original_question = ""
|
||||
"""Initialize ClassicRAG retriever with vectorstore sources and LLM configuration"""
|
||||
self.original_question = source.get("question", "")
|
||||
self.chat_history = chat_history if chat_history is not None else []
|
||||
self.prompt = prompt
|
||||
self.chunks = chunks
|
||||
if isinstance(chunks, str):
|
||||
try:
|
||||
self.chunks = int(chunks)
|
||||
except ValueError:
|
||||
logging.warning(
|
||||
f"Invalid chunks value '{chunks}', using default value 2"
|
||||
)
|
||||
self.chunks = 2
|
||||
else:
|
||||
self.chunks = chunks
|
||||
user_identifier = user_api_key if user_api_key else "default"
|
||||
logging.info(
|
||||
f"ClassicRAG initialized with chunks={self.chunks}, user_api_key={user_identifier}, "
|
||||
f"sources={'active_docs' in source and source['active_docs'] is not None}"
|
||||
)
|
||||
self.gpt_model = gpt_model
|
||||
self.token_limit = (
|
||||
token_limit
|
||||
if token_limit
|
||||
< settings.MODEL_TOKEN_LIMITS.get(
|
||||
< settings.LLM_TOKEN_LIMITS.get(
|
||||
self.gpt_model, settings.DEFAULT_MAX_HISTORY
|
||||
)
|
||||
else settings.MODEL_TOKEN_LIMITS.get(
|
||||
else settings.LLM_TOKEN_LIMITS.get(
|
||||
self.gpt_model, settings.DEFAULT_MAX_HISTORY
|
||||
)
|
||||
)
|
||||
@@ -44,26 +61,48 @@ class ClassicRAG(BaseRetriever):
|
||||
user_api_key=self.user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
self.vectorstore = source["active_docs"] if "active_docs" in source else None
|
||||
|
||||
if "active_docs" in source and source["active_docs"] is not None:
|
||||
if isinstance(source["active_docs"], list):
|
||||
self.vectorstores = source["active_docs"]
|
||||
else:
|
||||
self.vectorstores = [source["active_docs"]]
|
||||
else:
|
||||
self.vectorstores = []
|
||||
self.question = self._rephrase_query()
|
||||
self.decoded_token = decoded_token
|
||||
self._validate_vectorstore_config()
|
||||
|
||||
def _validate_vectorstore_config(self):
|
||||
"""Validate vectorstore IDs and remove any empty/invalid entries"""
|
||||
if not self.vectorstores:
|
||||
logging.warning("No vectorstores configured for retrieval")
|
||||
return
|
||||
invalid_ids = [
|
||||
vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip()
|
||||
]
|
||||
if invalid_ids:
|
||||
logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}")
|
||||
self.vectorstores = [
|
||||
vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip()
|
||||
]
|
||||
|
||||
def _rephrase_query(self):
|
||||
"""Rephrase user query with chat history context for better retrieval"""
|
||||
if (
|
||||
not self.original_question
|
||||
or not self.chat_history
|
||||
or self.chat_history == []
|
||||
or self.chunks == 0
|
||||
or self.vectorstore is None
|
||||
or not self.vectorstores
|
||||
):
|
||||
return self.original_question
|
||||
|
||||
prompt = f"""Given the following conversation history:
|
||||
{self.chat_history}
|
||||
|
||||
Rephrase the following user question to be a standalone search query
|
||||
that captures all relevant context from the conversation:
|
||||
"""
|
||||
prompt = (
|
||||
"Given the following conversation history:\n"
|
||||
f"{self.chat_history}\n\n"
|
||||
"Rephrase the following user question to be a standalone search query "
|
||||
"that captures all relevant context from the conversation:\n"
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": prompt},
|
||||
@@ -79,44 +118,89 @@ class ClassicRAG(BaseRetriever):
|
||||
return self.original_question
|
||||
|
||||
def _get_data(self):
|
||||
if self.chunks == 0 or self.vectorstore is None:
|
||||
docs = []
|
||||
else:
|
||||
docsearch = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
|
||||
"""Retrieve relevant documents from configured vectorstores"""
|
||||
if self.chunks == 0 or not self.vectorstores:
|
||||
logging.info(
|
||||
f"ClassicRAG._get_data: Skipping retrieval - chunks={self.chunks}, "
|
||||
f"vectorstores_count={len(self.vectorstores) if self.vectorstores else 0}"
|
||||
)
|
||||
docs_temp = docsearch.search(self.question, k=self.chunks)
|
||||
docs = [
|
||||
{
|
||||
"title": i.metadata.get(
|
||||
"title", i.metadata.get("post_title", i.page_content)
|
||||
).split("/")[-1],
|
||||
"text": i.page_content,
|
||||
"source": (
|
||||
i.metadata.get("source")
|
||||
if i.metadata.get("source")
|
||||
else "local"
|
||||
),
|
||||
}
|
||||
for i in docs_temp
|
||||
]
|
||||
return []
|
||||
all_docs = []
|
||||
chunks_per_source = max(1, self.chunks // len(self.vectorstores))
|
||||
|
||||
return docs
|
||||
logging.info(
|
||||
f"ClassicRAG._get_data: Starting retrieval with chunks={self.chunks}, "
|
||||
f"vectorstores={self.vectorstores}, chunks_per_source={chunks_per_source}, "
|
||||
f"query='{self.question[:50]}...'"
|
||||
)
|
||||
|
||||
def gen():
|
||||
pass
|
||||
for vectorstore_id in self.vectorstores:
|
||||
if vectorstore_id:
|
||||
try:
|
||||
docsearch = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY
|
||||
)
|
||||
docs_temp = docsearch.search(self.question, k=chunks_per_source)
|
||||
|
||||
for doc in docs_temp:
|
||||
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
|
||||
page_content = doc.page_content
|
||||
metadata = doc.metadata
|
||||
else:
|
||||
page_content = doc.get("text", doc.get("page_content", ""))
|
||||
metadata = doc.get("metadata", {})
|
||||
title = metadata.get(
|
||||
"title", metadata.get("post_title", page_content)
|
||||
)
|
||||
if not isinstance(title, str):
|
||||
title = str(title)
|
||||
title = title.split("/")[-1]
|
||||
|
||||
filename = (
|
||||
metadata.get("filename")
|
||||
or metadata.get("file_name")
|
||||
or metadata.get("source")
|
||||
)
|
||||
if isinstance(filename, str):
|
||||
filename = os.path.basename(filename) or filename
|
||||
else:
|
||||
filename = title
|
||||
if not filename:
|
||||
filename = title
|
||||
source_path = metadata.get("source") or vectorstore_id
|
||||
all_docs.append(
|
||||
{
|
||||
"title": title,
|
||||
"text": page_content,
|
||||
"source": source_path,
|
||||
"filename": filename,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error searching vectorstore {vectorstore_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
logging.info(
|
||||
f"ClassicRAG._get_data: Retrieval complete - retrieved {len(all_docs)} documents "
|
||||
f"(requested chunks={self.chunks}, chunks_per_source={chunks_per_source})"
|
||||
)
|
||||
return all_docs
|
||||
|
||||
def search(self, query: str = ""):
|
||||
"""Search for documents using optional query override"""
|
||||
if query:
|
||||
self.original_question = query
|
||||
self.question = self._rephrase_query()
|
||||
return self._get_data()
|
||||
|
||||
def get_params(self):
|
||||
"""Return current retriever configuration parameters"""
|
||||
return {
|
||||
"question": self.original_question,
|
||||
"rephrased_question": self.question,
|
||||
"source": self.vectorstore,
|
||||
"sources": self.vectorstores,
|
||||
"chunks": self.chunks,
|
||||
"token_limit": self.token_limit,
|
||||
"gpt_model": self.gpt_model,
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
from langchain_community.tools import DuckDuckGoSearchResults
|
||||
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.base import BaseRetriever
|
||||
|
||||
|
||||
class DuckDuckSearch(BaseRetriever):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source,
|
||||
chat_history,
|
||||
prompt,
|
||||
chunks=2,
|
||||
token_limit=150,
|
||||
gpt_model="docsgpt",
|
||||
user_api_key=None,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.question = ""
|
||||
self.source = source
|
||||
self.chat_history = chat_history
|
||||
self.prompt = prompt
|
||||
self.chunks = chunks
|
||||
self.gpt_model = gpt_model
|
||||
self.token_limit = (
|
||||
token_limit
|
||||
if token_limit
|
||||
< settings.MODEL_TOKEN_LIMITS.get(
|
||||
self.gpt_model, settings.DEFAULT_MAX_HISTORY
|
||||
)
|
||||
else settings.MODEL_TOKEN_LIMITS.get(
|
||||
self.gpt_model, settings.DEFAULT_MAX_HISTORY
|
||||
)
|
||||
)
|
||||
self.user_api_key = user_api_key
|
||||
self.decoded_token = decoded_token
|
||||
|
||||
def _get_data(self):
|
||||
if self.chunks == 0:
|
||||
docs = []
|
||||
else:
|
||||
wrapper = DuckDuckGoSearchAPIWrapper(max_results=self.chunks)
|
||||
search = DuckDuckGoSearchResults(api_wrapper=wrapper, output_format="list")
|
||||
results = search.run(self.question)
|
||||
|
||||
docs = []
|
||||
for i in results:
|
||||
try:
|
||||
docs.append(
|
||||
{
|
||||
"text": i.get("snippet", "").strip(),
|
||||
"title": i.get("title", "").strip(),
|
||||
"link": i.get("link", "").strip(),
|
||||
}
|
||||
)
|
||||
except IndexError:
|
||||
pass
|
||||
if settings.LLM_NAME == "llama.cpp":
|
||||
docs = [docs[0]]
|
||||
|
||||
return docs
|
||||
|
||||
def gen(self):
|
||||
docs = self._get_data()
|
||||
|
||||
# join all page_content together with a newline
|
||||
docs_together = "\n".join([doc["text"] for doc in docs])
|
||||
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
for doc in docs:
|
||||
yield {"source": doc}
|
||||
|
||||
if len(self.chat_history) > 0:
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": i["response"]}
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_NAME,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=self.user_api_key,
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
|
||||
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
|
||||
for line in completion:
|
||||
yield {"answer": str(line)}
|
||||
|
||||
def search(self, query: str = ""):
|
||||
if query:
|
||||
self.question = query
|
||||
return self._get_data()
|
||||
|
||||
def get_params(self):
|
||||
return {
|
||||
"question": self.question,
|
||||
"source": self.source,
|
||||
"chat_history": self.chat_history,
|
||||
"prompt": self.prompt,
|
||||
"chunks": self.chunks,
|
||||
"token_limit": self.token_limit,
|
||||
"gpt_model": self.gpt_model,
|
||||
"user_api_key": self.user_api_key,
|
||||
}
|
||||
@@ -1,13 +1,9 @@
|
||||
from application.retriever.classic_rag import ClassicRAG
|
||||
from application.retriever.duckduck_search import DuckDuckSearch
|
||||
from application.retriever.brave_search import BraveRetSearch
|
||||
|
||||
|
||||
class RetrieverCreator:
|
||||
retrievers = {
|
||||
"classic": ClassicRAG,
|
||||
"duckduck_search": DuckDuckSearch,
|
||||
"brave_search": BraveRetSearch,
|
||||
"default": ClassicRAG,
|
||||
}
|
||||
|
||||
|
||||
0
application/security/__init__.py
Normal file
0
application/security/__init__.py
Normal file
85
application/security/encryption.py
Normal file
85
application/security/encryption.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
def _derive_key(user_id: str, salt: bytes) -> bytes:
|
||||
app_secret = settings.ENCRYPTION_SECRET_KEY
|
||||
|
||||
password = f"{app_secret}#{user_id}".encode()
|
||||
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=salt,
|
||||
iterations=100000,
|
||||
backend=default_backend(),
|
||||
)
|
||||
|
||||
return kdf.derive(password)
|
||||
|
||||
|
||||
def encrypt_credentials(credentials: dict, user_id: str) -> str:
|
||||
if not credentials:
|
||||
return ""
|
||||
try:
|
||||
salt = os.urandom(16)
|
||||
iv = os.urandom(16)
|
||||
key = _derive_key(user_id, salt)
|
||||
|
||||
json_str = json.dumps(credentials)
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
|
||||
padded_data = _pad_data(json_str.encode())
|
||||
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
||||
|
||||
result = salt + iv + encrypted_data
|
||||
return base64.b64encode(result).decode()
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to encrypt credentials: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def decrypt_credentials(encrypted_data: str, user_id: str) -> dict:
|
||||
if not encrypted_data:
|
||||
return {}
|
||||
try:
|
||||
data = base64.b64decode(encrypted_data.encode())
|
||||
|
||||
salt = data[:16]
|
||||
iv = data[16:32]
|
||||
encrypted_content = data[32:]
|
||||
|
||||
key = _derive_key(user_id, salt)
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
decryptor = cipher.decryptor()
|
||||
|
||||
decrypted_padded = decryptor.update(encrypted_content) + decryptor.finalize()
|
||||
decrypted_data = _unpad_data(decrypted_padded)
|
||||
|
||||
return json.loads(decrypted_data.decode())
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to decrypt credentials: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def _pad_data(data: bytes) -> bytes:
|
||||
block_size = 16
|
||||
padding_len = block_size - (len(data) % block_size)
|
||||
padding = bytes([padding_len]) * padding_len
|
||||
return data + padding
|
||||
|
||||
|
||||
def _unpad_data(data: bytes) -> bytes:
|
||||
padding_len = data[-1]
|
||||
return data[:-padding_len]
|
||||
0
application/storage/__init__.py
Normal file
0
application/storage/__init__.py
Normal file
@@ -1,4 +1,5 @@
|
||||
"""Base storage class for file system abstraction."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import BinaryIO, List, Callable
|
||||
|
||||
@@ -7,7 +8,7 @@ class BaseStorage(ABC):
|
||||
"""Abstract base class for storage implementations."""
|
||||
|
||||
@abstractmethod
|
||||
def save_file(self, file_data: BinaryIO, path: str) -> dict:
|
||||
def save_file(self, file_data: BinaryIO, path: str, **kwargs) -> dict:
|
||||
"""
|
||||
Save a file to storage.
|
||||
|
||||
@@ -92,3 +93,32 @@ class BaseStorage(ABC):
|
||||
List[str]: List of file paths
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_directory(self, path: str) -> bool:
|
||||
"""
|
||||
Check if a path is a directory.
|
||||
|
||||
Args:
|
||||
path: Path to check
|
||||
|
||||
Returns:
|
||||
bool: True if the path is a directory
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_directory(self, directory: str) -> bool:
|
||||
"""
|
||||
Remove a directory and all its contents.
|
||||
|
||||
For local storage, this removes the directory and all files/subdirectories within it.
|
||||
For S3 storage, this removes all objects with the directory path as a prefix.
|
||||
|
||||
Args:
|
||||
directory: Directory path to remove
|
||||
|
||||
Returns:
|
||||
bool: True if removal was successful, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -101,3 +101,40 @@ class LocalStorage(BaseStorage):
|
||||
raise FileNotFoundError(f"File not found: {full_path}")
|
||||
|
||||
return processor_func(local_path=full_path, **kwargs)
|
||||
|
||||
def is_directory(self, path: str) -> bool:
|
||||
"""
|
||||
Check if a path is a directory in local storage.
|
||||
|
||||
Args:
|
||||
path: Path to check
|
||||
|
||||
Returns:
|
||||
bool: True if the path is a directory, False otherwise
|
||||
"""
|
||||
full_path = self._get_full_path(path)
|
||||
return os.path.isdir(full_path)
|
||||
|
||||
def remove_directory(self, directory: str) -> bool:
|
||||
"""
|
||||
Remove a directory and all its contents from local storage.
|
||||
|
||||
Args:
|
||||
directory: Directory path to remove
|
||||
|
||||
Returns:
|
||||
bool: True if removal was successful, False otherwise
|
||||
"""
|
||||
full_path = self._get_full_path(directory)
|
||||
|
||||
if not os.path.exists(full_path):
|
||||
return False
|
||||
|
||||
if not os.path.isdir(full_path):
|
||||
return False
|
||||
|
||||
try:
|
||||
shutil.rmtree(full_path)
|
||||
return True
|
||||
except (OSError, PermissionError):
|
||||
return False
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
"""S3 storage implementation."""
|
||||
|
||||
import io
|
||||
from typing import BinaryIO, List, Callable
|
||||
import os
|
||||
from typing import BinaryIO, Callable, List
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
from application.core.settings import settings
|
||||
|
||||
from application.storage.base import BaseStorage
|
||||
from application.core.settings import settings
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
|
||||
class S3Storage(BaseStorage):
|
||||
@@ -20,38 +21,48 @@ class S3Storage(BaseStorage):
|
||||
Args:
|
||||
bucket_name: S3 bucket name (optional, defaults to settings)
|
||||
"""
|
||||
self.bucket_name = bucket_name or getattr(settings, "S3_BUCKET_NAME", "docsgpt-test-bucket")
|
||||
self.bucket_name = bucket_name or getattr(
|
||||
settings, "S3_BUCKET_NAME", "docsgpt-test-bucket"
|
||||
)
|
||||
|
||||
# Get credentials from settings
|
||||
|
||||
aws_access_key_id = getattr(settings, "SAGEMAKER_ACCESS_KEY", None)
|
||||
aws_secret_access_key = getattr(settings, "SAGEMAKER_SECRET_KEY", None)
|
||||
region_name = getattr(settings, "SAGEMAKER_REGION", None)
|
||||
|
||||
self.s3 = boto3.client(
|
||||
's3',
|
||||
"s3",
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=region_name
|
||||
region_name=region_name,
|
||||
)
|
||||
|
||||
def save_file(self, file_data: BinaryIO, path: str) -> dict:
|
||||
def save_file(
|
||||
self,
|
||||
file_data: BinaryIO,
|
||||
path: str,
|
||||
storage_class: str = "INTELLIGENT_TIERING",
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""Save a file to S3 storage."""
|
||||
self.s3.upload_fileobj(file_data, self.bucket_name, path)
|
||||
self.s3.upload_fileobj(
|
||||
file_data, self.bucket_name, path, ExtraArgs={"StorageClass": storage_class}
|
||||
)
|
||||
|
||||
region = getattr(settings, "SAGEMAKER_REGION", None)
|
||||
|
||||
return {
|
||||
'storage_type': 's3',
|
||||
'bucket_name': self.bucket_name,
|
||||
'uri': f's3://{self.bucket_name}/{path}',
|
||||
'region': region
|
||||
"storage_type": "s3",
|
||||
"bucket_name": self.bucket_name,
|
||||
"uri": f"s3://{self.bucket_name}/{path}",
|
||||
"region": region,
|
||||
}
|
||||
|
||||
def get_file(self, path: str) -> BinaryIO:
|
||||
"""Get a file from S3 storage."""
|
||||
if not self.file_exists(path):
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
|
||||
file_obj = io.BytesIO()
|
||||
self.s3.download_fileobj(self.bucket_name, path, file_obj)
|
||||
file_obj.seek(0)
|
||||
@@ -76,18 +87,17 @@ class S3Storage(BaseStorage):
|
||||
def list_files(self, directory: str) -> List[str]:
|
||||
"""List all files in a directory in S3 storage."""
|
||||
# Ensure directory ends with a slash if it's not empty
|
||||
if directory and not directory.endswith('/'):
|
||||
directory += '/'
|
||||
|
||||
if directory and not directory.endswith("/"):
|
||||
directory += "/"
|
||||
result = []
|
||||
paginator = self.s3.get_paginator('list_objects_v2')
|
||||
paginator = self.s3.get_paginator("list_objects_v2")
|
||||
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=directory)
|
||||
|
||||
for page in pages:
|
||||
if 'Contents' in page:
|
||||
for obj in page['Contents']:
|
||||
result.append(obj['Key'])
|
||||
|
||||
if "Contents" in page:
|
||||
for obj in page["Contents"]:
|
||||
result.append(obj["Key"])
|
||||
return result
|
||||
|
||||
def process_file(self, path: str, processor_func: Callable, **kwargs):
|
||||
@@ -98,23 +108,99 @@ class S3Storage(BaseStorage):
|
||||
path: Path to the file
|
||||
processor_func: Function that processes the file
|
||||
**kwargs: Additional arguments to pass to the processor function
|
||||
|
||||
|
||||
Returns:
|
||||
The result of the processor function
|
||||
"""
|
||||
import tempfile
|
||||
import logging
|
||||
|
||||
import tempfile
|
||||
|
||||
if not self.file_exists(path):
|
||||
raise FileNotFoundError(f"File not found in S3: {path}")
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=os.path.splitext(path)[1], delete=True) as temp_file:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=os.path.splitext(path)[1], delete=True
|
||||
) as temp_file:
|
||||
try:
|
||||
# Download the file from S3 to the temporary file
|
||||
|
||||
self.s3.download_fileobj(self.bucket_name, path, temp_file)
|
||||
temp_file.flush()
|
||||
|
||||
|
||||
return processor_func(local_path=temp_file.name, **kwargs)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing S3 file {path}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def is_directory(self, path: str) -> bool:
|
||||
"""
|
||||
Check if a path is a directory in S3 storage.
|
||||
|
||||
In S3, directories are virtual concepts. A path is considered a directory
|
||||
if there are objects with the path as a prefix.
|
||||
|
||||
Args:
|
||||
path: Path to check
|
||||
|
||||
Returns:
|
||||
bool: True if the path is a directory, False otherwise
|
||||
"""
|
||||
# Ensure path ends with a slash if not empty
|
||||
if path and not path.endswith('/'):
|
||||
path += '/'
|
||||
|
||||
response = self.s3.list_objects_v2(
|
||||
Bucket=self.bucket_name,
|
||||
Prefix=path,
|
||||
MaxKeys=1
|
||||
)
|
||||
|
||||
return 'Contents' in response
|
||||
|
||||
def remove_directory(self, directory: str) -> bool:
|
||||
"""
|
||||
Remove a directory and all its contents from S3 storage.
|
||||
|
||||
In S3, this removes all objects with the directory path as a prefix.
|
||||
Since S3 doesn't have actual directories, this effectively removes
|
||||
all files within the virtual directory structure.
|
||||
|
||||
Args:
|
||||
directory: Directory path to remove
|
||||
|
||||
Returns:
|
||||
bool: True if removal was successful, False otherwise
|
||||
"""
|
||||
# Ensure directory ends with a slash if not empty
|
||||
if directory and not directory.endswith('/'):
|
||||
directory += '/'
|
||||
|
||||
try:
|
||||
# Get all objects with the directory prefix
|
||||
objects_to_delete = []
|
||||
paginator = self.s3.get_paginator('list_objects_v2')
|
||||
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=directory)
|
||||
|
||||
for page in pages:
|
||||
if 'Contents' in page:
|
||||
for obj in page['Contents']:
|
||||
objects_to_delete.append({'Key': obj['Key']})
|
||||
|
||||
if not objects_to_delete:
|
||||
return False
|
||||
|
||||
batch_size = 1000
|
||||
for i in range(0, len(objects_to_delete), batch_size):
|
||||
batch = objects_to_delete[i:i + batch_size]
|
||||
|
||||
response = self.s3.delete_objects(
|
||||
Bucket=self.bucket_name,
|
||||
Delete={'Objects': batch}
|
||||
)
|
||||
|
||||
if 'Errors' in response and response['Errors']:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except ClientError:
|
||||
return False
|
||||
|
||||
@@ -1,84 +1,30 @@
|
||||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
import base64
|
||||
from io import BytesIO
|
||||
import base64
|
||||
from application.tts.base import BaseTTS
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class ElevenlabsTTS(BaseTTS):
|
||||
def __init__(self):
|
||||
self.api_key = 'ELEVENLABS_API_KEY'# here you should put your api key
|
||||
self.model = "eleven_flash_v2_5"
|
||||
self.voice = "VOICE_ID" # this is the hash code for the voice not the name!
|
||||
self.write_audio = 1
|
||||
def __init__(self):
|
||||
from elevenlabs.client import ElevenLabs
|
||||
|
||||
self.client = ElevenLabs(
|
||||
api_key=settings.ELEVENLABS_API_KEY,
|
||||
)
|
||||
|
||||
|
||||
def text_to_speech(self, text):
|
||||
asyncio.run(self._text_to_speech_websocket(text))
|
||||
lang = "en"
|
||||
audio = self.client.generate(
|
||||
text=text,
|
||||
model="eleven_multilingual_v2",
|
||||
voice="Brian",
|
||||
)
|
||||
audio_data = BytesIO()
|
||||
for chunk in audio:
|
||||
audio_data.write(chunk)
|
||||
audio_bytes = audio_data.getvalue()
|
||||
|
||||
async def _text_to_speech_websocket(self, text):
|
||||
uri = f"wss://api.elevenlabs.io/v1/text-to-speech/{self.voice}/stream-input?model_id={self.model}"
|
||||
websocket = await websockets.connect(uri)
|
||||
payload = {
|
||||
"text": " ",
|
||||
"voice_settings": {
|
||||
"stability": 0.5,
|
||||
"similarity_boost": 0.8,
|
||||
},
|
||||
"xi_api_key": self.api_key,
|
||||
}
|
||||
|
||||
await websocket.send(json.dumps(payload))
|
||||
|
||||
async def listen():
|
||||
while 1:
|
||||
try:
|
||||
msg = await websocket.recv()
|
||||
data = json.loads(msg)
|
||||
|
||||
if data.get("audio"):
|
||||
print("audio received")
|
||||
yield base64.b64decode(data["audio"])
|
||||
elif data.get("isFinal"):
|
||||
break
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
print("websocket closed")
|
||||
break
|
||||
listen_task = asyncio.create_task(self.stream(listen()))
|
||||
|
||||
await websocket.send(json.dumps({"text": text}))
|
||||
# this is to signal the end of the text, either use this or flush
|
||||
await websocket.send(json.dumps({"text": ""}))
|
||||
|
||||
await listen_task
|
||||
|
||||
async def stream(self, audio_stream):
|
||||
if self.write_audio:
|
||||
audio_bytes = BytesIO()
|
||||
async for chunk in audio_stream:
|
||||
if chunk:
|
||||
audio_bytes.write(chunk)
|
||||
with open("output_audio.mp3", "wb") as f:
|
||||
f.write(audio_bytes.getvalue())
|
||||
|
||||
else:
|
||||
async for chunk in audio_stream:
|
||||
pass # depends on the streamer!
|
||||
|
||||
|
||||
def test_elevenlabs_websocket():
|
||||
"""
|
||||
Tests the ElevenlabsTTS text_to_speech method with a sample prompt.
|
||||
Prints out the base64-encoded result and writes it to 'output_audio.mp3'.
|
||||
"""
|
||||
# Instantiate your TTS class
|
||||
tts = ElevenlabsTTS()
|
||||
|
||||
# Call the method with some sample text
|
||||
tts.text_to_speech("Hello from ElevenLabs WebSocket!")
|
||||
|
||||
print("Saved audio to output_audio.mp3.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_elevenlabs_websocket()
|
||||
# Encode to base64
|
||||
audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
|
||||
return audio_base64, lang
|
||||
@@ -1,8 +1,13 @@
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
|
||||
import tiktoken
|
||||
from flask import jsonify, make_response
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
_encoding = None
|
||||
@@ -15,6 +20,41 @@ def get_encoding():
|
||||
return _encoding
|
||||
|
||||
|
||||
def get_gpt_model() -> str:
|
||||
"""Get the appropriate GPT model based on provider"""
|
||||
model_map = {
|
||||
"openai": "gpt-4o-mini",
|
||||
"anthropic": "claude-2",
|
||||
"groq": "llama3-8b-8192",
|
||||
"novita": "deepseek/deepseek-r1",
|
||||
}
|
||||
return settings.LLM_NAME or model_map.get(settings.LLM_PROVIDER, "")
|
||||
|
||||
|
||||
def safe_filename(filename):
|
||||
"""
|
||||
Creates a safe filename that preserves the original extension.
|
||||
Uses secure_filename, but ensures a proper filename is returned even with non-Latin characters.
|
||||
|
||||
Args:
|
||||
filename (str): The original filename
|
||||
|
||||
Returns:
|
||||
str: A safe filename that can be used for storage
|
||||
"""
|
||||
if not filename:
|
||||
return str(uuid.uuid4())
|
||||
_, extension = os.path.splitext(filename)
|
||||
|
||||
safe_name = secure_filename(filename)
|
||||
|
||||
# If secure_filename returns just the extension or an empty string
|
||||
|
||||
if not safe_name or safe_name == extension.lstrip("."):
|
||||
return f"{str(uuid.uuid4())}{extension}"
|
||||
return safe_name
|
||||
|
||||
|
||||
def num_tokens_from_string(string: str) -> int:
|
||||
encoding = get_encoding()
|
||||
if isinstance(string, str):
|
||||
@@ -39,7 +79,6 @@ def count_tokens_docs(docs):
|
||||
docs_content = ""
|
||||
for doc in docs:
|
||||
docs_content += doc.page_content
|
||||
|
||||
tokens = num_tokens_from_string(docs_content)
|
||||
return tokens
|
||||
|
||||
@@ -51,7 +90,7 @@ def check_required_fields(data, required_fields):
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Missing fields: {', '.join(missing_fields)}",
|
||||
"message": f"Missing required fields: {', '.join(missing_fields)}",
|
||||
}
|
||||
),
|
||||
400,
|
||||
@@ -59,6 +98,27 @@ def check_required_fields(data, required_fields):
|
||||
return None
|
||||
|
||||
|
||||
def validate_required_fields(data, required_fields):
|
||||
missing_fields = []
|
||||
empty_fields = []
|
||||
|
||||
for field in required_fields:
|
||||
if field not in data:
|
||||
missing_fields.append(field)
|
||||
elif not data[field]:
|
||||
empty_fields.append(field)
|
||||
errors = []
|
||||
if missing_fields:
|
||||
errors.append(f"Missing required fields: {', '.join(missing_fields)}")
|
||||
if empty_fields:
|
||||
errors.append(f"Empty values in required fields: {', '.join(empty_fields)}")
|
||||
if errors:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": " | ".join(errors)}), 400
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_hash(data):
|
||||
return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()
|
||||
|
||||
@@ -74,13 +134,12 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
|
||||
max_token_limit
|
||||
if max_token_limit
|
||||
and max_token_limit
|
||||
< settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
|
||||
else settings.MODEL_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
|
||||
< settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
|
||||
else settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_MAX_HISTORY)
|
||||
)
|
||||
|
||||
if not history:
|
||||
return []
|
||||
|
||||
trimmed_history = []
|
||||
tokens_current_history = 0
|
||||
|
||||
@@ -89,18 +148,15 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
|
||||
if "prompt" in message and "response" in message:
|
||||
tokens_batch += num_tokens_from_string(message["prompt"])
|
||||
tokens_batch += num_tokens_from_string(message["response"])
|
||||
|
||||
if "tool_calls" in message:
|
||||
for tool_call in message["tool_calls"]:
|
||||
tool_call_string = f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}"
|
||||
tokens_batch += num_tokens_from_string(tool_call_string)
|
||||
|
||||
if tokens_current_history + tokens_batch < max_token_limit:
|
||||
tokens_current_history += tokens_batch
|
||||
trimmed_history.insert(0, message)
|
||||
else:
|
||||
break
|
||||
|
||||
return trimmed_history
|
||||
|
||||
|
||||
@@ -109,3 +165,14 @@ def validate_function_name(function_name):
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def generate_image_url(image_path):
|
||||
strategy = getattr(settings, "URL_STRATEGY", "backend")
|
||||
if strategy == "s3":
|
||||
bucket_name = getattr(settings, "S3_BUCKET_NAME", "docsgpt-test-bucket")
|
||||
region_name = getattr(settings, "SAGEMAKER_REGION", "eu-central-1")
|
||||
return f"https://{bucket_name}.s3.{region_name}.amazonaws.com/{image_path}"
|
||||
else:
|
||||
base_url = getattr(settings, "API_URL", "http://localhost:7091")
|
||||
return f"{base_url}/api/images/{image_path}"
|
||||
|
||||
@@ -1,20 +1,43 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
import os
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class EmbeddingsWrapper:
|
||||
def __init__(self, model_name, *args, **kwargs):
|
||||
self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs)
|
||||
self.dimension = self.model.get_sentence_embedding_dimension()
|
||||
logging.info(f"Initializing EmbeddingsWrapper with model: {model_name}")
|
||||
try:
|
||||
kwargs.setdefault("trust_remote_code", True)
|
||||
self.model = SentenceTransformer(
|
||||
model_name,
|
||||
config_kwargs={"allow_dangerous_deserialization": True},
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
if self.model is None or self.model._first_module() is None:
|
||||
raise ValueError(
|
||||
f"SentenceTransformer model failed to load properly for: {model_name}"
|
||||
)
|
||||
self.dimension = self.model.get_sentence_embedding_dimension()
|
||||
logging.info(f"Successfully loaded model with dimension: {self.dimension}")
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Failed to initialize SentenceTransformer with model {model_name}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
def embed_query(self, query: str):
|
||||
return self.model.encode(query).tolist()
|
||||
|
||||
|
||||
def embed_documents(self, documents: list):
|
||||
return self.model.encode(documents).tolist()
|
||||
|
||||
|
||||
def __call__(self, text):
|
||||
if isinstance(text, str):
|
||||
return self.embed_query(text)
|
||||
@@ -24,15 +47,14 @@ class EmbeddingsWrapper:
|
||||
raise ValueError("Input must be a string or a list of strings")
|
||||
|
||||
|
||||
|
||||
class EmbeddingsSingleton:
|
||||
_instances = {}
|
||||
|
||||
@staticmethod
|
||||
def get_instance(embeddings_name, *args, **kwargs):
|
||||
if embeddings_name not in EmbeddingsSingleton._instances:
|
||||
EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance(
|
||||
embeddings_name, *args, **kwargs
|
||||
EmbeddingsSingleton._instances[embeddings_name] = (
|
||||
EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs)
|
||||
)
|
||||
return EmbeddingsSingleton._instances[embeddings_name]
|
||||
|
||||
@@ -40,9 +62,15 @@ class EmbeddingsSingleton:
|
||||
def _create_instance(embeddings_name, *args, **kwargs):
|
||||
embeddings_factory = {
|
||||
"openai_text-embedding-ada-002": OpenAIEmbeddings,
|
||||
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
|
||||
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
|
||||
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"),
|
||||
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper(
|
||||
"sentence-transformers/all-mpnet-base-v2"
|
||||
),
|
||||
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper(
|
||||
"sentence-transformers/all-mpnet-base-v2"
|
||||
),
|
||||
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper(
|
||||
"hkunlp/instructor-large"
|
||||
),
|
||||
}
|
||||
|
||||
if embeddings_name in embeddings_factory:
|
||||
@@ -50,41 +78,83 @@ class EmbeddingsSingleton:
|
||||
else:
|
||||
return EmbeddingsWrapper(embeddings_name, *args, **kwargs)
|
||||
|
||||
|
||||
class BaseVectorStore(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, *args, **kwargs):
|
||||
"""Search for similar documents/chunks in the vectorstore"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_texts(self, texts, metadatas=None, *args, **kwargs):
|
||||
"""Add texts with their embeddings to the vectorstore"""
|
||||
pass
|
||||
|
||||
def delete_index(self, *args, **kwargs):
|
||||
"""Delete the entire index/collection"""
|
||||
pass
|
||||
|
||||
def save_local(self, *args, **kwargs):
|
||||
"""Save vectorstore to local storage"""
|
||||
pass
|
||||
|
||||
def get_chunks(self, *args, **kwargs):
|
||||
"""Get all chunks from the vectorstore"""
|
||||
pass
|
||||
|
||||
def add_chunk(self, text, metadata=None, *args, **kwargs):
|
||||
"""Add a single chunk to the vectorstore"""
|
||||
pass
|
||||
|
||||
def delete_chunk(self, chunk_id, *args, **kwargs):
|
||||
"""Delete a specific chunk from the vectorstore"""
|
||||
pass
|
||||
|
||||
def is_azure_configured(self):
|
||||
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
|
||||
return (
|
||||
settings.OPENAI_API_BASE
|
||||
and settings.OPENAI_API_VERSION
|
||||
and settings.AZURE_DEPLOYMENT_NAME
|
||||
)
|
||||
|
||||
def _get_embeddings(self, embeddings_name, embeddings_key=None):
|
||||
if embeddings_name == "openai_text-embedding-ada-002":
|
||||
if self.is_azure_configured():
|
||||
os.environ["OPENAI_API_TYPE"] = "azure"
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||
embeddings_name,
|
||||
model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
|
||||
embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
|
||||
)
|
||||
else:
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||
embeddings_name,
|
||||
openai_api_key=embeddings_key
|
||||
embeddings_name, openai_api_key=embeddings_key
|
||||
)
|
||||
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
||||
if os.path.exists("./models/all-mpnet-base-v2"):
|
||||
possible_paths = [
|
||||
"/app/models/all-mpnet-base-v2", # Docker absolute path
|
||||
"./models/all-mpnet-base-v2", # Relative path
|
||||
]
|
||||
local_model_path = None
|
||||
for path in possible_paths:
|
||||
if os.path.exists(path):
|
||||
local_model_path = path
|
||||
logging.info(f"Found local model at path: {path}")
|
||||
break
|
||||
else:
|
||||
logging.info(f"Path does not exist: {path}")
|
||||
if local_model_path:
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||
embeddings_name = "./models/all-mpnet-base-v2",
|
||||
local_model_path,
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
f"Local model not found in any of the paths: {possible_paths}. Falling back to HuggingFace download."
|
||||
)
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||
embeddings_name,
|
||||
)
|
||||
else:
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
|
||||
|
||||
return embedding_instance
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import tempfile
|
||||
import io
|
||||
|
||||
from langchain_community.vectorstores import FAISS
|
||||
|
||||
@@ -66,8 +67,37 @@ class FaissStore(BaseVectorStore):
|
||||
def add_texts(self, *args, **kwargs):
|
||||
return self.docsearch.add_texts(*args, **kwargs)
|
||||
|
||||
def save_local(self, *args, **kwargs):
|
||||
return self.docsearch.save_local(*args, **kwargs)
|
||||
def _save_to_storage(self):
|
||||
"""
|
||||
Save the FAISS index to storage using temporary directory pattern.
|
||||
Works consistently for both local and S3 storage.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
self.docsearch.save_local(temp_dir)
|
||||
|
||||
faiss_path = os.path.join(temp_dir, "index.faiss")
|
||||
pkl_path = os.path.join(temp_dir, "index.pkl")
|
||||
|
||||
with open(faiss_path, "rb") as f_faiss:
|
||||
faiss_data = f_faiss.read()
|
||||
|
||||
with open(pkl_path, "rb") as f_pkl:
|
||||
pkl_data = f_pkl.read()
|
||||
|
||||
storage_path = get_vectorstore(self.source_id)
|
||||
self.storage.save_file(io.BytesIO(faiss_data), f"{storage_path}/index.faiss")
|
||||
self.storage.save_file(io.BytesIO(pkl_data), f"{storage_path}/index.pkl")
|
||||
|
||||
return True
|
||||
|
||||
def save_local(self, path=None):
|
||||
if path:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
self.docsearch.save_local(path)
|
||||
|
||||
self._save_to_storage()
|
||||
|
||||
return True
|
||||
|
||||
def delete_index(self, *args, **kwargs):
|
||||
return self.docsearch.delete(*args, **kwargs)
|
||||
@@ -103,13 +133,17 @@ class FaissStore(BaseVectorStore):
|
||||
return chunks
|
||||
|
||||
def add_chunk(self, text, metadata=None):
|
||||
"""Add a new chunk and save to storage."""
|
||||
metadata = metadata or {}
|
||||
doc = Document(text=text, extra_info=metadata).to_langchain_format()
|
||||
doc_id = self.docsearch.add_documents([doc])
|
||||
self.save_local(self.path)
|
||||
self._save_to_storage()
|
||||
return doc_id
|
||||
|
||||
|
||||
|
||||
def delete_chunk(self, chunk_id):
|
||||
"""Delete a chunk and save to storage."""
|
||||
self.delete_index([chunk_id])
|
||||
self.save_local(self.path)
|
||||
self._save_to_storage()
|
||||
return True
|
||||
|
||||
303
application/vectorstore/pgvector.py
Normal file
303
application/vectorstore/pgvector.py
Normal file
@@ -0,0 +1,303 @@
|
||||
import logging
|
||||
from typing import List, Optional, Any, Dict
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.base import BaseVectorStore
|
||||
from application.vectorstore.document_class import Document
|
||||
|
||||
|
||||
class PGVectorStore(BaseVectorStore):
|
||||
def __init__(
|
||||
self,
|
||||
source_id: str = "",
|
||||
embeddings_key: str = "embeddings",
|
||||
table_name: str = "documents",
|
||||
vector_column: str = "embedding",
|
||||
text_column: str = "text",
|
||||
metadata_column: str = "metadata",
|
||||
connection_string: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
# Store the source_id for use in add_chunk
|
||||
self._source_id = str(source_id).replace("application/indexes/", "").rstrip("/")
|
||||
self._embeddings_key = embeddings_key
|
||||
self._table_name = table_name
|
||||
self._vector_column = vector_column
|
||||
self._text_column = text_column
|
||||
self._metadata_column = metadata_column
|
||||
self._embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)
|
||||
|
||||
# Use provided connection string or fall back to settings
|
||||
self._connection_string = connection_string or getattr(settings, 'PGVECTOR_CONNECTION_STRING', None)
|
||||
|
||||
if not self._connection_string:
|
||||
raise ValueError(
|
||||
"PostgreSQL connection string is required. "
|
||||
"Set PGVECTOR_CONNECTION_STRING in settings or pass connection_string parameter."
|
||||
)
|
||||
|
||||
try:
|
||||
import psycopg2
|
||||
from psycopg2.extras import Json
|
||||
import pgvector.psycopg2
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import required packages. "
|
||||
"Please install with `pip install psycopg2-binary pgvector`."
|
||||
)
|
||||
|
||||
self._psycopg2 = psycopg2
|
||||
self._Json = Json
|
||||
self._pgvector = pgvector.psycopg2
|
||||
self._connection = None
|
||||
self._ensure_table_exists()
|
||||
|
||||
def _get_connection(self):
|
||||
"""Get or create database connection"""
|
||||
if self._connection is None or self._connection.closed:
|
||||
self._connection = self._psycopg2.connect(self._connection_string)
|
||||
# Register pgvector types
|
||||
self._pgvector.register_vector(self._connection)
|
||||
return self._connection
|
||||
|
||||
def _ensure_table_exists(self):
|
||||
"""Create table and enable pgvector extension if they don't exist"""
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
# Enable pgvector extension
|
||||
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||
|
||||
# Get embedding dimension
|
||||
embedding_dim = getattr(self._embedding, 'dimension', 1536) # Default to OpenAI dimension
|
||||
|
||||
# Create table with vector column
|
||||
create_table_query = f"""
|
||||
CREATE TABLE IF NOT EXISTS {self._table_name} (
|
||||
id SERIAL PRIMARY KEY,
|
||||
{self._text_column} TEXT NOT NULL,
|
||||
{self._vector_column} vector({embedding_dim}),
|
||||
{self._metadata_column} JSONB,
|
||||
source_id TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
"""
|
||||
cursor.execute(create_table_query)
|
||||
|
||||
# Create index for vector similarity search
|
||||
index_query = f"""
|
||||
CREATE INDEX IF NOT EXISTS {self._table_name}_{self._vector_column}_idx
|
||||
ON {self._table_name} USING ivfflat ({self._vector_column} vector_cosine_ops)
|
||||
WITH (lists = 100);
|
||||
"""
|
||||
cursor.execute(index_query)
|
||||
|
||||
# Create index for source_id filtering
|
||||
source_index_query = f"""
|
||||
CREATE INDEX IF NOT EXISTS {self._table_name}_source_id_idx
|
||||
ON {self._table_name} (source_id);
|
||||
"""
|
||||
cursor.execute(source_index_query)
|
||||
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logging.error(f"Error creating table: {e}")
|
||||
raise
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def search(self, question: str, k: int = 2, *args, **kwargs) -> List[Document]:
|
||||
"""Search for similar documents using vector similarity"""
|
||||
query_vector = self._embedding.embed_query(question)
|
||||
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
# Use cosine distance for similarity search with proper vector formatting
|
||||
search_query = f"""
|
||||
SELECT {self._text_column}, {self._metadata_column},
|
||||
({self._vector_column} <=> %s::vector) as distance
|
||||
FROM {self._table_name}
|
||||
WHERE source_id = %s
|
||||
ORDER BY {self._vector_column} <=> %s::vector
|
||||
LIMIT %s;
|
||||
"""
|
||||
|
||||
cursor.execute(search_query, (query_vector, self._source_id, query_vector, k))
|
||||
results = cursor.fetchall()
|
||||
|
||||
|
||||
documents = []
|
||||
for text, metadata, distance in results:
|
||||
metadata = metadata or {}
|
||||
documents.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error searching documents: {e}", exc_info=True)
|
||||
return []
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: List[str],
|
||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> List[str]:
|
||||
"""Add texts with their embeddings to the vector store"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
embeddings = self._embedding.embed_documents(texts)
|
||||
metadatas = metadatas or [{}] * len(texts)
|
||||
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
insert_query = f"""
|
||||
INSERT INTO {self._table_name} ({self._text_column}, {self._vector_column}, {self._metadata_column}, source_id)
|
||||
VALUES (%s, %s, %s, %s)
|
||||
RETURNING id;
|
||||
"""
|
||||
|
||||
inserted_ids = []
|
||||
for text, embedding, metadata in zip(texts, embeddings, metadatas):
|
||||
cursor.execute(
|
||||
insert_query,
|
||||
(text, embedding, self._Json(metadata), self._source_id)
|
||||
)
|
||||
inserted_id = cursor.fetchone()[0]
|
||||
inserted_ids.append(str(inserted_id))
|
||||
|
||||
conn.commit()
|
||||
return inserted_ids
|
||||
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logging.error(f"Error adding texts: {e}")
|
||||
raise
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def delete_index(self, *args, **kwargs):
|
||||
"""Delete all documents for this source_id"""
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
delete_query = f"DELETE FROM {self._table_name} WHERE source_id = %s;"
|
||||
cursor.execute(delete_query, (self._source_id,))
|
||||
conn.commit()
|
||||
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logging.error(f"Error deleting index: {e}")
|
||||
raise
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def save_local(self, *args, **kwargs):
|
||||
"""No-op for PostgreSQL - data is already persisted"""
|
||||
pass
|
||||
|
||||
def get_chunks(self) -> List[Dict[str, Any]]:
|
||||
"""Get all chunks for this source_id"""
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
select_query = f"""
|
||||
SELECT id, {self._text_column}, {self._metadata_column}
|
||||
FROM {self._table_name}
|
||||
WHERE source_id = %s;
|
||||
"""
|
||||
cursor.execute(select_query, (self._source_id,))
|
||||
results = cursor.fetchall()
|
||||
|
||||
chunks = []
|
||||
for doc_id, text, metadata in results:
|
||||
chunks.append({
|
||||
"doc_id": str(doc_id),
|
||||
"text": text,
|
||||
"metadata": metadata or {}
|
||||
})
|
||||
|
||||
return chunks
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting chunks: {e}")
|
||||
return []
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def add_chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> str:
|
||||
"""Add a single chunk to the vector store"""
|
||||
metadata = metadata or {}
|
||||
|
||||
# Create a copy to avoid modifying the original metadata
|
||||
final_metadata = metadata.copy()
|
||||
|
||||
# Ensure the source_id is in the metadata so the chunk can be found by filters
|
||||
final_metadata["source_id"] = self._source_id
|
||||
|
||||
embeddings = self._embedding.embed_documents([text])
|
||||
|
||||
if not embeddings:
|
||||
raise ValueError("Could not generate embedding for chunk")
|
||||
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
insert_query = f"""
|
||||
INSERT INTO {self._table_name} ({self._text_column}, {self._vector_column}, {self._metadata_column}, source_id)
|
||||
VALUES (%s, %s, %s, %s)
|
||||
RETURNING id;
|
||||
"""
|
||||
|
||||
cursor.execute(
|
||||
insert_query,
|
||||
(text, embeddings[0], self._Json(final_metadata), self._source_id)
|
||||
)
|
||||
inserted_id = cursor.fetchone()[0]
|
||||
conn.commit()
|
||||
|
||||
return str(inserted_id)
|
||||
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logging.error(f"Error adding chunk: {e}")
|
||||
raise
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def delete_chunk(self, chunk_id: str) -> bool:
|
||||
"""Delete a specific chunk by its ID"""
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
delete_query = f"DELETE FROM {self._table_name} WHERE id = %s AND source_id = %s;"
|
||||
cursor.execute(delete_query, (int(chunk_id), self._source_id))
|
||||
deleted_count = cursor.rowcount
|
||||
conn.commit()
|
||||
|
||||
return deleted_count > 0
|
||||
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
logging.error(f"Error deleting chunk: {e}")
|
||||
return False
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def __del__(self):
|
||||
"""Close database connection when object is destroyed"""
|
||||
if hasattr(self, '_connection') and self._connection and not self._connection.closed:
|
||||
self._connection.close()
|
||||
@@ -1,5 +1,7 @@
|
||||
import logging
|
||||
from application.vectorstore.base import BaseVectorStore
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.document_class import Document
|
||||
|
||||
|
||||
class QdrantStore(BaseVectorStore):
|
||||
@@ -7,18 +9,22 @@ class QdrantStore(BaseVectorStore):
|
||||
from qdrant_client import models
|
||||
from langchain_community.vectorstores.qdrant import Qdrant
|
||||
|
||||
# Store the source_id for use in add_chunk
|
||||
self._source_id = str(source_id).replace("application/indexes/", "").rstrip("/")
|
||||
|
||||
self._filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.source_id",
|
||||
match=models.MatchValue(value=source_id.replace("application/indexes/", "").rstrip("/")),
|
||||
match=models.MatchValue(value=self._source_id),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
embedding=self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)
|
||||
self._docsearch = Qdrant.construct_instance(
|
||||
["TEXT_TO_OBTAIN_EMBEDDINGS_DIMENSION"],
|
||||
embedding=self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key),
|
||||
embedding=embedding,
|
||||
collection_name=settings.QDRANT_COLLECTION_NAME,
|
||||
location=settings.QDRANT_LOCATION,
|
||||
url=settings.QDRANT_URL,
|
||||
@@ -32,6 +38,32 @@ class QdrantStore(BaseVectorStore):
|
||||
path=settings.QDRANT_PATH,
|
||||
distance_func=settings.QDRANT_DISTANCE_FUNC,
|
||||
)
|
||||
try:
|
||||
collections = self._docsearch.client.get_collections()
|
||||
collection_exists = settings.QDRANT_COLLECTION_NAME in [
|
||||
collection.name for collection in collections.collections
|
||||
]
|
||||
|
||||
if not collection_exists:
|
||||
self._docsearch.client.recreate_collection(
|
||||
collection_name=settings.QDRANT_COLLECTION_NAME,
|
||||
vectors_config=models.VectorParams(size=embedding.client[1].word_embedding_dimension, distance=models.Distance.COSINE),
|
||||
)
|
||||
|
||||
# Ensure the required index exists for metadata.source_id
|
||||
try:
|
||||
self._docsearch.client.create_payload_index(
|
||||
collection_name=settings.QDRANT_COLLECTION_NAME,
|
||||
field_name="metadata.source_id",
|
||||
field_schema=models.PayloadSchemaType.KEYWORD,
|
||||
)
|
||||
except Exception as index_error:
|
||||
# Index might already exist, which is fine
|
||||
if "already exists" not in str(index_error).lower():
|
||||
logging.warning(f"Could not create index for metadata.source_id: {index_error}")
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Could not check for collection: {e}")
|
||||
|
||||
def search(self, *args, **kwargs):
|
||||
return self._docsearch.similarity_search(filter=self._filter, *args, **kwargs)
|
||||
@@ -46,3 +78,59 @@ class QdrantStore(BaseVectorStore):
|
||||
return self._docsearch.client.delete(
|
||||
collection_name=settings.QDRANT_COLLECTION_NAME, points_selector=self._filter
|
||||
)
|
||||
|
||||
def get_chunks(self):
|
||||
try:
|
||||
|
||||
chunks = []
|
||||
offset = None
|
||||
while True:
|
||||
records, offset = self._docsearch.client.scroll(
|
||||
collection_name=settings.QDRANT_COLLECTION_NAME,
|
||||
scroll_filter=self._filter,
|
||||
limit=10,
|
||||
with_payload=True,
|
||||
with_vectors=False,
|
||||
offset=offset,
|
||||
)
|
||||
for record in records:
|
||||
doc_id = record.id
|
||||
text = record.payload.get("page_content")
|
||||
metadata = record.payload.get("metadata")
|
||||
chunks.append(
|
||||
{"doc_id": doc_id, "text": text, "metadata": metadata}
|
||||
)
|
||||
if offset is None:
|
||||
break
|
||||
return chunks
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting chunks: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def add_chunk(self, text, metadata=None):
|
||||
import uuid
|
||||
metadata = metadata or {}
|
||||
|
||||
# Create a copy to avoid modifying the original metadata
|
||||
final_metadata = metadata.copy()
|
||||
|
||||
# Ensure the source_id is in the metadata so the chunk can be found by filters
|
||||
final_metadata["source_id"] = self._source_id
|
||||
|
||||
doc = Document(page_content=text, metadata=final_metadata)
|
||||
# Generate a unique ID for the document
|
||||
doc_id = str(uuid.uuid4())
|
||||
doc.id = doc_id
|
||||
doc_ids = self._docsearch.add_documents([doc])
|
||||
return doc_ids[0] if doc_ids else doc_id
|
||||
|
||||
def delete_chunk(self, chunk_id):
|
||||
try:
|
||||
self._docsearch.client.delete(
|
||||
collection_name=settings.QDRANT_COLLECTION_NAME,
|
||||
points_selector=[chunk_id],
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting chunk: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@@ -3,6 +3,7 @@ from application.vectorstore.elasticsearch import ElasticsearchStore
|
||||
from application.vectorstore.milvus import MilvusStore
|
||||
from application.vectorstore.mongodb import MongoDBVectorStore
|
||||
from application.vectorstore.qdrant import QdrantStore
|
||||
from application.vectorstore.pgvector import PGVectorStore
|
||||
|
||||
|
||||
class VectorCreator:
|
||||
@@ -12,6 +13,7 @@ class VectorCreator:
|
||||
"mongodb": MongoDBVectorStore,
|
||||
"qdrant": QdrantStore,
|
||||
"milvus": MilvusStore,
|
||||
"pgvector": PGVectorStore
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,4 @@
|
||||
name: docsgpt-oss
|
||||
services:
|
||||
|
||||
redis:
|
||||
|
||||
75
deployment/docker-compose-hub.yaml
Normal file
75
deployment/docker-compose-hub.yaml
Normal file
@@ -0,0 +1,75 @@
|
||||
name: docsgpt-oss
|
||||
services:
|
||||
|
||||
frontend:
|
||||
image: arc53/docsgpt-fe:develop
|
||||
environment:
|
||||
- VITE_API_HOST=http://localhost:7091
|
||||
- VITE_API_STREAMING=$VITE_API_STREAMING
|
||||
- VITE_GOOGLE_CLIENT_ID=$VITE_GOOGLE_CLIENT_ID
|
||||
ports:
|
||||
- "5173:5173"
|
||||
depends_on:
|
||||
- backend
|
||||
|
||||
|
||||
backend:
|
||||
user: root
|
||||
image: arc53/docsgpt:develop
|
||||
environment:
|
||||
- API_KEY=$API_KEY
|
||||
- EMBEDDINGS_KEY=$API_KEY
|
||||
- LLM_PROVIDER=$LLM_PROVIDER
|
||||
- LLM_NAME=$LLM_NAME
|
||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||
- CELERY_RESULT_BACKEND=redis://redis:6379/1
|
||||
- MONGO_URI=mongodb://mongo:27017/docsgpt
|
||||
- CACHE_REDIS_URL=redis://redis:6379/2
|
||||
- OPENAI_BASE_URL=$OPENAI_BASE_URL
|
||||
ports:
|
||||
- "7091:7091"
|
||||
volumes:
|
||||
- ../application/indexes:/app/indexes
|
||||
- ../application/inputs:/app/inputs
|
||||
- ../application/vectors:/app/vectors
|
||||
depends_on:
|
||||
- redis
|
||||
- mongo
|
||||
|
||||
|
||||
worker:
|
||||
user: root
|
||||
image: arc53/docsgpt:develop
|
||||
command: celery -A application.app.celery worker -l INFO -B
|
||||
environment:
|
||||
- API_KEY=$API_KEY
|
||||
- EMBEDDINGS_KEY=$API_KEY
|
||||
- LLM_PROVIDER=$LLM_PROVIDER
|
||||
- LLM_NAME=$LLM_NAME
|
||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||
- CELERY_RESULT_BACKEND=redis://redis:6379/1
|
||||
- MONGO_URI=mongodb://mongo:27017/docsgpt
|
||||
- API_URL=http://backend:7091
|
||||
- CACHE_REDIS_URL=redis://redis:6379/2
|
||||
volumes:
|
||||
- ../application/indexes:/app/indexes
|
||||
- ../application/inputs:/app/inputs
|
||||
- ../application/vectors:/app/vectors
|
||||
depends_on:
|
||||
- redis
|
||||
- mongo
|
||||
|
||||
redis:
|
||||
image: redis:6-alpine
|
||||
ports:
|
||||
- 6379:6379
|
||||
|
||||
mongo:
|
||||
image: mongo:6
|
||||
ports:
|
||||
- 27017:27017
|
||||
volumes:
|
||||
- mongodb_data_container:/data/db
|
||||
|
||||
volumes:
|
||||
mongodb_data_container:
|
||||
@@ -1,3 +1,4 @@
|
||||
name: docsgpt-oss
|
||||
services:
|
||||
frontend:
|
||||
build: ../frontend
|
||||
@@ -6,6 +7,7 @@ services:
|
||||
environment:
|
||||
- VITE_API_HOST=http://localhost:7091
|
||||
- VITE_API_STREAMING=$VITE_API_STREAMING
|
||||
- VITE_GOOGLE_CLIENT_ID=$VITE_GOOGLE_CLIENT_ID
|
||||
ports:
|
||||
- "5173:5173"
|
||||
depends_on:
|
||||
@@ -17,19 +19,19 @@ services:
|
||||
environment:
|
||||
- API_KEY=$API_KEY
|
||||
- EMBEDDINGS_KEY=$API_KEY
|
||||
- LLM_PROVIDER=$LLM_PROVIDER
|
||||
- LLM_NAME=$LLM_NAME
|
||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||
- CELERY_RESULT_BACKEND=redis://redis:6379/1
|
||||
- MONGO_URI=mongodb://mongo:27017/docsgpt
|
||||
- CACHE_REDIS_URL=redis://redis:6379/2
|
||||
- OPENAI_BASE_URL=$OPENAI_BASE_URL
|
||||
- MODEL_NAME=$MODEL_NAME
|
||||
ports:
|
||||
- "7091:7091"
|
||||
volumes:
|
||||
- ../application/indexes:/app/application/indexes
|
||||
- ../application/indexes:/app/indexes
|
||||
- ../application/inputs:/app/inputs
|
||||
- ../application/vectors:/app/application/vectors
|
||||
- ../application/vectors:/app/vectors
|
||||
depends_on:
|
||||
- redis
|
||||
- mongo
|
||||
@@ -41,6 +43,7 @@ services:
|
||||
environment:
|
||||
- API_KEY=$API_KEY
|
||||
- EMBEDDINGS_KEY=$API_KEY
|
||||
- LLM_PROVIDER=$LLM_PROVIDER
|
||||
- LLM_NAME=$LLM_NAME
|
||||
- CELERY_BROKER_URL=redis://redis:6379/0
|
||||
- CELERY_RESULT_BACKEND=redis://redis:6379/1
|
||||
@@ -48,9 +51,9 @@ services:
|
||||
- API_URL=http://backend:7091
|
||||
- CACHE_REDIS_URL=redis://redis:6379/2
|
||||
volumes:
|
||||
- ../application/indexes:/app/application/indexes
|
||||
- ../application/indexes:/app/indexes
|
||||
- ../application/inputs:/app/inputs
|
||||
- ../application/vectors:/app/application/vectors
|
||||
- ../application/vectors:/app/vectors
|
||||
depends_on:
|
||||
- redis
|
||||
- mongo
|
||||
|
||||
@@ -4,7 +4,7 @@ metadata:
|
||||
name: docsgpt-secrets
|
||||
type: Opaque
|
||||
data:
|
||||
LLM_NAME: ZG9jc2dwdA==
|
||||
LLM_PROVIDER: ZG9jc2dwdA==
|
||||
INTERNAL_KEY: aW50ZXJuYWw=
|
||||
CELERY_BROKER_URL: cmVkaXM6Ly9yZWRpcy1zZXJ2aWNlOjYzNzkvMA==
|
||||
CELERY_RESULT_BACKEND: cmVkaXM6Ly9yZWRpcy1zZXJ2aWNlOjYzNzkvMA==
|
||||
|
||||
@@ -2,5 +2,13 @@
|
||||
"basics": {
|
||||
"title": "🤖 Agent Basics",
|
||||
"href": "/Agents/basics"
|
||||
},
|
||||
"api": {
|
||||
"title": "🔌 Agent API",
|
||||
"href": "/Agents/api"
|
||||
},
|
||||
"webhooks": {
|
||||
"title": "🪝 Agent Webhooks",
|
||||
"href": "/Agents/webhooks"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
227
docs/pages/Agents/api.mdx
Normal file
227
docs/pages/Agents/api.mdx
Normal file
@@ -0,0 +1,227 @@
|
||||
---
|
||||
title: Interacting with Agents via API
|
||||
description: Learn how to programmatically interact with DocsGPT Agents using the streaming and non-streaming API endpoints.
|
||||
---
|
||||
|
||||
import { Callout, Tabs } from 'nextra/components';
|
||||
|
||||
# Interacting with Agents via API
|
||||
|
||||
DocsGPT Agents can be accessed programmatically through a dedicated API, allowing you to integrate their specialized capabilities into your own applications, scripts, and workflows. This guide covers the two primary methods for interacting with an agent: the streaming API for real-time responses and the non-streaming API for a single, consolidated answer.
|
||||
|
||||
When you use an API key generated for a specific agent, you do not need to pass `prompt`, `tools` etc. The agent's configuration (including its prompt, selected tools, and knowledge sources) is already associated with its unique API key.
|
||||
|
||||
### API Endpoints
|
||||
|
||||
- **Non-Streaming:** `http://localhost:7091/api/answer`
|
||||
- **Streaming:** `http://localhost:7091/stream`
|
||||
|
||||
<Callout type="info">
|
||||
For DocsGPT Cloud, use `https://gptcloud.arc53.com/` as the base URL.
|
||||
</Callout>
|
||||
|
||||
For more technical details, you can explore the API swagger documentation available for the cloud version or your local instance.
|
||||
|
||||
---
|
||||
|
||||
## Non-Streaming API (`/api/answer`)
|
||||
|
||||
This is a standard synchronous endpoint. It waits for the agent to fully process the request and returns a single JSON object with the complete answer. This is the simplest method and is ideal for backend processes where a real-time feed is not required.
|
||||
|
||||
### Request
|
||||
|
||||
- **Endpoint:** `/api/answer`
|
||||
- **Method:** `POST`
|
||||
- **Payload:**
|
||||
- `question` (string, required): The user's query or input for the agent.
|
||||
- `api_key` (string, required): The unique API key for the agent you wish to interact with.
|
||||
- `history` (string, optional): A JSON string representing the conversation history, e.g., `[{\"prompt\": \"first question\", \"answer\": \"first answer\"}]`.
|
||||
|
||||
### Response
|
||||
|
||||
A single JSON object containing:
|
||||
- `answer`: The complete, final answer from the agent.
|
||||
- `sources`: A list of sources the agent consulted.
|
||||
- `conversation_id`: The unique ID for the interaction.
|
||||
|
||||
### Examples
|
||||
|
||||
<Tabs items={['cURL', 'Python', 'JavaScript']}>
|
||||
<Tabs.Tab>
|
||||
```bash
|
||||
curl -X POST http://localhost:7091/api/answer \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"question": "your question here",
|
||||
"api_key": "your_agent_api_key"
|
||||
}'
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```python
|
||||
import requests
|
||||
|
||||
API_URL = "http://localhost:7091/api/answer"
|
||||
API_KEY = "your_agent_api_key"
|
||||
QUESTION = "your question here"
|
||||
|
||||
response = requests.post(
|
||||
API_URL,
|
||||
json={"question": QUESTION, "api_key": API_KEY}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
print(response.json())
|
||||
else:
|
||||
print(f"Error: {response.status_code}")
|
||||
print(response.text)
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```javascript
|
||||
const apiUrl = 'http://localhost:7091/api/answer';
|
||||
const apiKey = 'your_agent_api_key';
|
||||
const question = 'your question here';
|
||||
|
||||
async function getAnswer() {
|
||||
try {
|
||||
const response = await fetch(apiUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ question, api_key: apiKey }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! Status: ${response.status}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log(data);
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch answer:", error);
|
||||
}
|
||||
}
|
||||
|
||||
getAnswer();
|
||||
```
|
||||
</Tabs.Tab>
|
||||
</Tabs>
|
||||
|
||||
---
|
||||
|
||||
## Streaming API (`/stream`)
|
||||
|
||||
The `/stream` endpoint uses Server-Sent Events (SSE) to push data in real-time. This is ideal for applications where you want to display the response as it's being generated, such as in a live chatbot interface.
|
||||
|
||||
### Request
|
||||
|
||||
- **Endpoint:** `/stream`
|
||||
- **Method:** `POST`
|
||||
- **Payload:** Same as the non-streaming API.
|
||||
|
||||
### Response (SSE Stream)
|
||||
|
||||
The stream consists of multiple `data:` events, each containing a JSON object. Your client should listen for these events and process them based on their `type`.
|
||||
|
||||
**Event Types:**
|
||||
- `answer`: A chunk of the agent's final answer.
|
||||
- `source`: A document or source used by the agent.
|
||||
- `thought`: A reasoning step from the agent (for ReAct agents).
|
||||
- `id`: The unique `conversation_id` for the interaction.
|
||||
- `error`: An error message.
|
||||
- `end`: A final message indicating the stream has concluded.
|
||||
|
||||
### Examples
|
||||
|
||||
<Tabs items={['cURL', 'Python', 'JavaScript']}>
|
||||
<Tabs.Tab>
|
||||
```bash
|
||||
curl -X POST http://localhost:7091/stream \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Accept: text/event-stream" \
|
||||
-d '{
|
||||
"question": "your question here",
|
||||
"api_key": "your_agent_api_key"
|
||||
}'
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```python
|
||||
import requests
|
||||
import json
|
||||
|
||||
API_URL = "http://localhost:7091/stream"
|
||||
payload = {
|
||||
"question": "your question here",
|
||||
"api_key": "your_agent_api_key"
|
||||
}
|
||||
|
||||
with requests.post(API_URL, json=payload, stream=True) as r:
|
||||
for line in r.iter_lines():
|
||||
if line:
|
||||
decoded_line = line.decode('utf-8')
|
||||
if decoded_line.startswith('data: '):
|
||||
try:
|
||||
data = json.loads(decoded_line[6:])
|
||||
print(data)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```javascript
|
||||
const apiUrl = 'http://localhost:7091/stream';
|
||||
const apiKey = 'your_agent_api_key';
|
||||
const question = 'your question here';
|
||||
|
||||
async function getStream() {
|
||||
try {
|
||||
const response = await fetch(apiUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'text/event-stream'
|
||||
},
|
||||
// Corrected line: 'apiKey' is changed to 'api_key'
|
||||
body: JSON.stringify({ question, api_key: apiKey }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! Status: ${response.status}`);
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
// Note: This parsing method assumes each chunk contains whole lines.
|
||||
// For a more robust production implementation, buffer the chunks
|
||||
// and process them line by line.
|
||||
const lines = chunk.split('\n');
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ')) {
|
||||
try {
|
||||
const data = JSON.parse(line.substring(6));
|
||||
console.log(data);
|
||||
} catch (e) {
|
||||
console.error("Failed to parse JSON from SSE event:", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch stream:", error);
|
||||
}
|
||||
}
|
||||
|
||||
getStream();
|
||||
```
|
||||
</Tabs.Tab>
|
||||
</Tabs>
|
||||
152
docs/pages/Agents/webhooks.mdx
Normal file
152
docs/pages/Agents/webhooks.mdx
Normal file
@@ -0,0 +1,152 @@
|
||||
---
|
||||
title: Triggering Agents with Webhooks
|
||||
description: Learn how to automate and integrate DocsGPT Agents using webhooks for asynchronous task execution.
|
||||
---
|
||||
|
||||
import { Callout, Tabs } from 'nextra/components';
|
||||
|
||||
# Triggering Agents with Webhooks
|
||||
|
||||
Agent Webhooks provide a powerful mechanism to trigger an agent's execution from external systems. Unlike the direct API which provides an immediate response, webhooks are designed for **asynchronous** operations. When you call a webhook, DocsGPT enqueues the agent's task for background processing and immediately returns a `task_id`. You then use this ID to poll for the result.
|
||||
|
||||
This workflow is ideal for integrating with services that expect a quick initial response (e.g., form submissions) or for triggering long-running tasks without tying up a client connection.
|
||||
|
||||
Each agent has its own unique webhook URL, which can be generated from the agent's edit page in the DocsGPT UI. This URL includes a secure token for authentication.
|
||||
|
||||
### API Endpoints
|
||||
|
||||
- **Webhook URL:** `http://localhost:7091/api/webhooks/agents/{AGENT_WEBHOOK_TOKEN}`
|
||||
- **Task Status URL:** `http://localhost:7091/api/task_status`
|
||||
|
||||
<Callout type="info">
|
||||
For DocsGPT Cloud, use `https://gptcloud.arc53.com/` as the base URL.
|
||||
</Callout>
|
||||
|
||||
For more technical details, you can explore the API swagger documentation available for the cloud version or your local instance.
|
||||
|
||||
---
|
||||
|
||||
## The Webhook Workflow
|
||||
|
||||
The process involves two main steps: triggering the task and polling for the result.
|
||||
|
||||
### Step 1: Trigger the Webhook
|
||||
|
||||
Send an HTTP `POST` request to the agent's unique webhook URL with the required payload. The structure of this payload should match what the agent's prompt and tools are designed to handle.
|
||||
|
||||
- **Method:** `POST`
|
||||
- **Response:** A JSON object with a `task_id`. `{"task_id": "a1b2c3d4-e5f6-..."}`
|
||||
|
||||
<Tabs items={['cURL', 'Python', 'JavaScript']}>
|
||||
<Tabs.Tab>
|
||||
```bash
|
||||
curl -X POST \
|
||||
http://localhost:7091/api/webhooks/agents/your_webhook_token \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"question": "Your message to agent"}'
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```python
|
||||
import requests
|
||||
|
||||
WEBHOOK_URL = "http://localhost:7091/api/webhooks/agents/your_webhook_token"
|
||||
payload = {"question": "Your message to agent"}
|
||||
|
||||
try:
|
||||
response = requests.post(WEBHOOK_URL, json=payload)
|
||||
response.raise_for_status()
|
||||
task_id = response.json().get("task_id")
|
||||
print(f"Task successfully created with ID: {task_id}")
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"Error triggering webhook: {e}")
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```javascript
|
||||
const webhookUrl = 'http://localhost:7091/api/webhooks/agents/your_webhook_token';
|
||||
const payload = { question: 'Your message to agent' };
|
||||
|
||||
async function triggerWebhook() {
|
||||
try {
|
||||
const response = await fetch(webhookUrl, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(payload)
|
||||
});
|
||||
if (!response.ok) throw new Error(`HTTP error! ${response.status}`);
|
||||
const data = await response.json();
|
||||
console.log(`Task successfully created with ID: ${data.task_id}`);
|
||||
return data.task_id;
|
||||
} catch (error) {
|
||||
console.error('Error triggering webhook:', error);
|
||||
}
|
||||
}
|
||||
|
||||
triggerWebhook();
|
||||
```
|
||||
</Tabs.Tab>
|
||||
</Tabs>
|
||||
|
||||
### Step 2: Poll for the Result
|
||||
|
||||
Once you have the `task_id`, periodically send a `GET` request to the `/api/task_status` endpoint until the task `status` is `SUCCESS` or `FAILURE`.
|
||||
|
||||
- **`status`**: The current state of the task (`PENDING`, `STARTED`, `SUCCESS`, `FAILURE`).
|
||||
- **`result`**: The final output from the agent, available when the status is `SUCCESS` or `FAILURE`.
|
||||
|
||||
<Tabs items={['cURL', 'Python', 'JavaScript']}>
|
||||
<Tabs.Tab>
|
||||
```bash
|
||||
# Replace the task_id with the one you received
|
||||
curl http://localhost:7091/api/task_status?task_id=YOUR_TASK_ID
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```python
|
||||
import requests
|
||||
import time
|
||||
|
||||
STATUS_URL = "http://localhost:7091/api/task_status"
|
||||
task_id = "YOUR_TASK_ID"
|
||||
|
||||
while True:
|
||||
response = requests.get(STATUS_URL, params={"task_id": task_id})
|
||||
data = response.json()
|
||||
status = data.get("status")
|
||||
print(f"Current task status: {status}")
|
||||
|
||||
if status in ["SUCCESS", "FAILURE"]:
|
||||
print("Final Result:")
|
||||
print(data.get("result"))
|
||||
break
|
||||
|
||||
time.sleep(2)
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```javascript
|
||||
const statusUrl = 'http://localhost:7091/api/task_status';
|
||||
const taskId = 'YOUR_TASK_ID';
|
||||
|
||||
const sleep = (ms) => new Promise(resolve => setTimeout(resolve, ms));
|
||||
|
||||
async function pollForResult() {
|
||||
while (true) {
|
||||
const response = await fetch(`${statusUrl}?task_id=${taskId}`);
|
||||
const data = await response.json();
|
||||
const status = data.status;
|
||||
console.log(`Current task status: ${status}`);
|
||||
|
||||
if (status === 'SUCCESS' || status === 'FAILURE') {
|
||||
console.log('Final Result:', data.result);
|
||||
break;
|
||||
}
|
||||
await sleep(2000);
|
||||
}
|
||||
}
|
||||
|
||||
pollForResult();
|
||||
```
|
||||
</Tabs.Tab>
|
||||
</Tabs>
|
||||
@@ -37,7 +37,7 @@ The fastest way to try out DocsGPT is by using the public API endpoint. This req
|
||||
Open the `.env` file and add the following lines:
|
||||
|
||||
```
|
||||
LLM_NAME=docsgpt
|
||||
LLM_PROVIDER=docsgpt
|
||||
VITE_API_STREAMING=true
|
||||
```
|
||||
|
||||
@@ -93,16 +93,16 @@ There are two Ollama optional files:
|
||||
|
||||
3. **Pull the Ollama Model:**
|
||||
|
||||
**Crucially, after launching with Ollama, you need to pull the desired model into the Ollama container.** Find the `MODEL_NAME` you configured in your `.env` file (e.g., `llama3.2:1b`). Then execute the following command to pull the model *inside* the running Ollama container:
|
||||
**Crucially, after launching with Ollama, you need to pull the desired model into the Ollama container.** Find the `LLM_NAME` you configured in your `.env` file (e.g., `llama3.2:1b`). Then execute the following command to pull the model *inside* the running Ollama container:
|
||||
|
||||
```bash
|
||||
docker compose -f deployment/docker-compose.yaml -f deployment/optional/docker-compose.optional.ollama-cpu.yaml exec -it ollama ollama pull <MODEL_NAME>
|
||||
docker compose -f deployment/docker-compose.yaml -f deployment/optional/docker-compose.optional.ollama-cpu.yaml exec -it ollama ollama pull <LLM_NAME>
|
||||
```
|
||||
or (for GPU):
|
||||
```bash
|
||||
docker compose -f deployment/docker-compose.yaml -f deployment/optional/docker-compose.optional.ollama-gpu.yaml exec -it ollama ollama pull <MODEL_NAME>
|
||||
docker compose -f deployment/docker-compose.yaml -f deployment/optional/docker-compose.optional.ollama-gpu.yaml exec -it ollama ollama pull <LLM_NAME>
|
||||
```
|
||||
Replace `<MODEL_NAME>` with the actual model name from your `.env` file.
|
||||
Replace `<LLM_NAME>` with the actual model name from your `.env` file.
|
||||
|
||||
4. **Access DocsGPT in your browser:**
|
||||
|
||||
|
||||
@@ -20,9 +20,9 @@ The easiest and recommended way to configure basic settings is by using a `.env`
|
||||
**Example `.env` file structure:**
|
||||
|
||||
```
|
||||
LLM_NAME=openai
|
||||
LLM_PROVIDER=openai
|
||||
API_KEY=YOUR_OPENAI_API_KEY
|
||||
MODEL_NAME=gpt-4o
|
||||
LLM_NAME=gpt-4o
|
||||
```
|
||||
|
||||
### 2. Configuration via `settings.py` file (Advanced)
|
||||
@@ -37,33 +37,33 @@ While modifying `settings.py` offers more flexibility, it's generally recommende
|
||||
|
||||
Here are some of the most fundamental settings you'll likely want to configure:
|
||||
|
||||
- **`LLM_NAME`**: This setting determines which Large Language Model (LLM) provider DocsGPT will use. It tells DocsGPT which API to interact with.
|
||||
- **`LLM_PROVIDER`**: This setting determines which Large Language Model (LLM) provider DocsGPT will use. It tells DocsGPT which API to interact with.
|
||||
|
||||
- **Common values:**
|
||||
- `docsgpt`: Use the DocsGPT Public API Endpoint (simple and free, as offered in `setup.sh` option 1).
|
||||
- `openai`: Use OpenAI's API (requires an API key).
|
||||
- `google`: Use Google's Vertex AI or Gemini models.
|
||||
- `anthropic`: Use Anthropic's Claude models.
|
||||
- `groq`: Use Groq's models.
|
||||
- `huggingface`: Use HuggingFace Inference API.
|
||||
- `azure_openai`: Use Azure OpenAI Service.
|
||||
- `openai` (when using local inference engines like Ollama, Llama.cpp, TGI, etc.): This signals DocsGPT to use an OpenAI-compatible API format, even if the actual LLM is running locally.
|
||||
- **Common values:**
|
||||
- `docsgpt`: Use the DocsGPT Public API Endpoint (simple and free, as offered in `setup.sh` option 1).
|
||||
- `openai`: Use OpenAI's API (requires an API key).
|
||||
- `google`: Use Google's Vertex AI or Gemini models.
|
||||
- `anthropic`: Use Anthropic's Claude models.
|
||||
- `groq`: Use Groq's models.
|
||||
- `huggingface`: Use HuggingFace Inference API.
|
||||
- `azure_openai`: Use Azure OpenAI Service.
|
||||
- `openai` (when using local inference engines like Ollama, Llama.cpp, TGI, etc.): This signals DocsGPT to use an OpenAI-compatible API format, even if the actual LLM is running locally.
|
||||
|
||||
- **`MODEL_NAME`**: Specifies the specific model to use from the chosen LLM provider. The available models depend on the `LLM_NAME` you've selected.
|
||||
- **`LLM_NAME`**: Specifies the specific model to use from the chosen LLM provider. The available models depend on the `LLM_PROVIDER` you've selected.
|
||||
|
||||
- **Examples:**
|
||||
- For `LLM_NAME=openai`: `gpt-4o`
|
||||
- For `LLM_NAME=google`: `gemini-2.0-flash`
|
||||
- For local models (e.g., Ollama): `llama3.2:1b` (or any model name available in your setup).
|
||||
- **Examples:**
|
||||
- For `LLM_PROVIDER=openai`: `gpt-4o`
|
||||
- For `LLM_PROVIDER=google`: `gemini-2.0-flash`
|
||||
- For local models (e.g., Ollama): `llama3.2:1b` (or any model name available in your setup).
|
||||
|
||||
- **`EMBEDDINGS_NAME`**: This setting defines which embedding model DocsGPT will use to generate vector embeddings for your documents. Embeddings are numerical representations of text that allow DocsGPT to understand the semantic meaning of your documents for efficient search and retrieval.
|
||||
- **`EMBEDDINGS_NAME`**: This setting defines which embedding model DocsGPT will use to generate vector embeddings for your documents. Embeddings are numerical representations of text that allow DocsGPT to understand the semantic meaning of your documents for efficient search and retrieval.
|
||||
|
||||
- **Default value:** `huggingface_sentence-transformers/all-mpnet-base-v2` (a good general-purpose embedding model).
|
||||
- **Other options:** You can explore other embedding models from Hugging Face Sentence Transformers or other providers if needed.
|
||||
- **Default value:** `huggingface_sentence-transformers/all-mpnet-base-v2` (a good general-purpose embedding model).
|
||||
- **Other options:** You can explore other embedding models from Hugging Face Sentence Transformers or other providers if needed.
|
||||
|
||||
- **`API_KEY`**: Required for most cloud-based LLM providers. This is your authentication key to access the LLM provider's API. You'll need to obtain this key from your chosen provider's platform.
|
||||
- **`API_KEY`**: Required for most cloud-based LLM providers. This is your authentication key to access the LLM provider's API. You'll need to obtain this key from your chosen provider's platform.
|
||||
|
||||
- **`OPENAI_BASE_URL`**: Specifically used when `LLM_NAME` is set to `openai` but you are connecting to a local inference engine (like Ollama, Llama.cpp, etc.) that exposes an OpenAI-compatible API. This setting tells DocsGPT where to find your local LLM server.
|
||||
- **`OPENAI_BASE_URL`**: Specifically used when `LLM_PROVIDER` is set to `openai` but you are connecting to a local inference engine (like Ollama, Llama.cpp, etc.) that exposes an OpenAI-compatible API. This setting tells DocsGPT where to find your local LLM server.
|
||||
|
||||
## Configuration Examples
|
||||
|
||||
@@ -74,9 +74,9 @@ Let's look at some concrete examples of how to configure these settings in your
|
||||
To use OpenAI's `gpt-4o` model, you would configure your `.env` file like this:
|
||||
|
||||
```
|
||||
LLM_NAME=openai
|
||||
LLM_PROVIDER=openai
|
||||
API_KEY=YOUR_OPENAI_API_KEY # Replace with your actual OpenAI API key
|
||||
MODEL_NAME=gpt-4o
|
||||
LLM_NAME=gpt-4o
|
||||
```
|
||||
|
||||
Make sure to replace `YOUR_OPENAI_API_KEY` with your actual OpenAI API key.
|
||||
@@ -86,58 +86,89 @@ Make sure to replace `YOUR_OPENAI_API_KEY` with your actual OpenAI API key.
|
||||
To use a local Ollama server with the `llama3.2:1b` model, you would configure your `.env` file like this:
|
||||
|
||||
```
|
||||
LLM_NAME=openai # Using OpenAI compatible API format for local models
|
||||
LLM_PROVIDER=openai # Using OpenAI compatible API format for local models
|
||||
API_KEY=None # API Key is not needed for local Ollama
|
||||
MODEL_NAME=llama3.2:1b
|
||||
LLM_NAME=llama3.2:1b
|
||||
OPENAI_BASE_URL=http://host.docker.internal:11434/v1 # Default Ollama API URL within Docker
|
||||
EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2 # You can also run embeddings locally if needed
|
||||
```
|
||||
|
||||
In this case, even though you are using Ollama locally, `LLM_NAME` is set to `openai` because Ollama (and many other local inference engines) are designed to be API-compatible with OpenAI. `OPENAI_BASE_URL` points DocsGPT to the local Ollama server.
|
||||
In this case, even though you are using Ollama locally, `LLM_PROVIDER` is set to `openai` because Ollama (and many other local inference engines) are designed to be API-compatible with OpenAI. `OPENAI_BASE_URL` points DocsGPT to the local Ollama server.
|
||||
|
||||
## Authentication Settings
|
||||
|
||||
DocsGPT includes a JWT (JSON Web Token) based authentication feature for managing sessions or securing local deployments while allowing access.
|
||||
|
||||
- **`AUTH_TYPE`**: This setting in your `.env` file or `settings.py` determines the authentication method.
|
||||
|
||||
- **Possible values:**
|
||||
- `None` (or not set): No authentication is used.
|
||||
- `simple_jwt`: A single, long-lived JWT token is generated and used for all authenticated requests. This is useful for securing a local deployment with a shared secret.
|
||||
- `session_jwt`: Unique JWT tokens are generated for sessions, typically for individual users or temporary access.
|
||||
- If `AUTH_TYPE` is set to `simple_jwt` or `session_jwt`, then a `JWT_SECRET_KEY` is required.
|
||||
- **`JWT_SECRET_KEY`**: This is a crucial secret key used to sign and verify JWTs.
|
||||
|
||||
- It can be set directly in your `.env` file or `settings.py`.
|
||||
- **Automatic Key Generation**: If `AUTH_TYPE` is `simple_jwt` or `session_jwt` and `JWT_SECRET_KEY` is _not_ set in your environment variables or `settings.py`, DocsGPT will attempt to:
|
||||
1. Read the key from a file named `.jwt_secret_key` in the project's root directory.
|
||||
2. If the file doesn't exist, it will generate a new 32-byte random key, save it to `.jwt_secret_key`, and use it for the session. This ensures that the key persists across application restarts.
|
||||
- **Security Note**: It's vital to keep this key secure. If you set it manually, choose a strong, random string.
|
||||
### `AUTH_TYPE` Overview
|
||||
|
||||
**How it works:**
|
||||
The `AUTH_TYPE` setting in your `.env` file or `settings.py` determines the authentication method used by DocsGPT. This allows you to control how users authenticate with your DocsGPT instance.
|
||||
|
||||
- When `AUTH_TYPE` is set to `simple_jwt`, a token is generated at startup (if not already present or configured) and printed to the console. This token should be included in the `Authorization` header of your API requests as a Bearer token (e.g., `Authorization: Bearer YOUR_SIMPLE_JWT_TOKEN`).
|
||||
- When `AUTH_TYPE` is set to `session_jwt`:
|
||||
- Clients can request a new token from the `/api/generate_token` endpoint.
|
||||
- This token should then be included in the `Authorization` header for subsequent requests.
|
||||
- The backend verifies the JWT token provided in the `Authorization` header for protected routes.
|
||||
- The `/api/config` endpoint can be used to check the current `auth_type` and whether authentication is required.
|
||||
| Value | Description |
|
||||
| ------------- | ------------------------------------------------------------------------------------------- |
|
||||
| `None` | No authentication is used. Anyone can access the app. |
|
||||
| `simple_jwt` | A single, long-lived JWT token is generated at startup. All requests use this shared token. |
|
||||
| `session_jwt` | Unique JWT tokens are generated for each session/user. |
|
||||
|
||||
**Frontend Token Input for `simple_jwt`:**
|
||||
#### How to Configure
|
||||
|
||||
<img
|
||||
src="/jwt-input.png"
|
||||
alt="Frontend prompt for JWT Token"
|
||||
style={{
|
||||
width: '500px',
|
||||
maxWidth: '100%',
|
||||
display: 'block',
|
||||
margin: '1em auto'
|
||||
}}
|
||||
Add the following to your `.env` file (or set in `settings.py`):
|
||||
|
||||
```env
|
||||
# No authentication (default)
|
||||
AUTH_TYPE=None
|
||||
|
||||
# OR: Simple JWT (shared token)
|
||||
AUTH_TYPE=simple_jwt
|
||||
JWT_SECRET_KEY=your_secret_key_here
|
||||
|
||||
# OR: Session JWT (per-user/session tokens)
|
||||
AUTH_TYPE=session_jwt
|
||||
JWT_SECRET_KEY=your_secret_key_here
|
||||
```
|
||||
|
||||
- If `AUTH_TYPE` is set to `simple_jwt` or `session_jwt`, a `JWT_SECRET_KEY` is required.
|
||||
- If `JWT_SECRET_KEY` is not set, DocsGPT will generate one and store it in `.jwt_secret_key` in the project root.
|
||||
|
||||
#### How Each Method Works
|
||||
|
||||
- **None**: No authentication. All API and UI access is open.
|
||||
- **simple_jwt**:
|
||||
- A single JWT token is generated at startup and printed to the console.
|
||||
- Use this token in the `Authorization` header for all API requests:
|
||||
```http
|
||||
Authorization: Bearer <SIMPLE_JWT_TOKEN>
|
||||
```
|
||||
- The frontend will prompt for this token if not already set.
|
||||
- **session_jwt**:
|
||||
- Clients can request a new token from `/api/generate_token`.
|
||||
- Use the received token in the `Authorization` header for subsequent requests.
|
||||
- Each user/session gets a unique token.
|
||||
|
||||
#### Security Notes
|
||||
|
||||
- Always keep your `JWT_SECRET_KEY` secure and private.
|
||||
- If you set it manually, use a strong, random string.
|
||||
- If not set, DocsGPT will generate a secure key and persist it in `.jwt_secret_key`.
|
||||
|
||||
#### Checking Current Auth Type
|
||||
|
||||
- Use the `/api/config` endpoint to check the current `auth_type` and whether authentication is required.
|
||||
|
||||
#### Frontend Token Input for `simple_jwt`
|
||||
|
||||
If you have configured `AUTH_TYPE=simple_jwt`, the DocsGPT frontend will prompt you to enter the JWT token if it's not already set or is invalid. Paste the `SIMPLE_JWT_TOKEN` (printed to your console when the backend starts) into this field to access the application.
|
||||
|
||||
<img
|
||||
src="/jwt-input.png"
|
||||
alt="Frontend prompt for JWT Token"
|
||||
style={{
|
||||
width: "500px",
|
||||
maxWidth: "100%",
|
||||
display: "block",
|
||||
margin: "1em auto",
|
||||
}}
|
||||
/>
|
||||
|
||||
If you have configured `AUTH_TYPE=simple_jwt`, the DocsGPT frontend will prompt you to enter the JWT token if it's not already set or is invalid. You'll need to paste the `SIMPLE_JWT_TOKEN` (which is printed to your console when the backend starts) into this field to access the application.
|
||||
|
||||
## Exploring More Settings
|
||||
|
||||
These are just the basic settings to get you started. The `settings.py` file contains many more advanced options that you can explore to further customize DocsGPT, such as:
|
||||
@@ -147,4 +178,4 @@ These are just the basic settings to get you started. The `settings.py` file con
|
||||
- Cache settings (`CACHE_REDIS_URL`)
|
||||
- And many more!
|
||||
|
||||
For a complete list of available settings and their descriptions, refer to the `settings.py` file in `application/core`. Remember to restart your Docker containers after making changes to your `.env` file or `settings.py` for the changes to take effect.
|
||||
For a complete list of available settings and their descriptions, refer to the `settings.py` file in `application/core`. Remember to restart your Docker containers after making changes to your `.env` file or `settings.py` for the changes to take effect.
|
||||
|
||||
@@ -32,9 +32,9 @@ Choose the LLM of your choice.
|
||||
### For Open source llm change:
|
||||
<Steps>
|
||||
### Step 1
|
||||
For open source version please edit `LLM_NAME`, `MODEL_NAME` and others in the .env file. Refer to [⚙️ App Configuration](/Deploying/DocsGPT-Settings) for more information.
|
||||
For open source version please edit `LLM_PROVIDER`, `LLM_NAME` and others in the .env file. Refer to [⚙️ App Configuration](/Deploying/DocsGPT-Settings) for more information.
|
||||
### Step 2
|
||||
Visit [☁️ Cloud Providers](/Models/cloud-providers) for the updated list of online models. Make sure you have the right API_KEY and correct LLM_NAME.
|
||||
Visit [☁️ Cloud Providers](/Models/cloud-providers) for the updated list of online models. Make sure you have the right API_KEY and correct LLM_PROVIDER.
|
||||
For self-hosted please visit [🖥️ Local Inference](/Models/local-inference).
|
||||
</Steps>
|
||||
|
||||
|
||||
6
docs/pages/Guides/Integrations/_meta.json
Normal file
6
docs/pages/Guides/Integrations/_meta.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"google-drive-connector": {
|
||||
"title": "🔗 Google Drive",
|
||||
"href": "/Guides/Integrations/google-drive-connector"
|
||||
}
|
||||
}
|
||||
212
docs/pages/Guides/Integrations/google-drive-connector.mdx
Normal file
212
docs/pages/Guides/Integrations/google-drive-connector.mdx
Normal file
@@ -0,0 +1,212 @@
|
||||
---
|
||||
title: Google Drive Connector
|
||||
description: Connect your Google Drive as an external knowledge base to upload and process files directly from your Google Drive account.
|
||||
---
|
||||
|
||||
import { Callout } from 'nextra/components'
|
||||
import { Steps } from 'nextra/components'
|
||||
|
||||
# Google Drive Connector
|
||||
|
||||
The Google Drive Connector allows you to seamlessly connect your Google Drive account as an external knowledge base. This integration enables you to upload and process files directly from your Google Drive without manually downloading and uploading them to DocsGPT.
|
||||
|
||||
## Features
|
||||
|
||||
- **Direct File Access**: Browse and select files directly from your Google Drive
|
||||
- **Comprehensive File Support**: Supports all major document formats including:
|
||||
- Google Workspace files (Docs, Sheets, Slides)
|
||||
- Microsoft Office files (.docx, .xlsx, .pptx, .doc, .ppt, .xls)
|
||||
- PDF documents
|
||||
- Text files (.txt, .md, .rst, .html, .rtf)
|
||||
- Data files (.csv, .json)
|
||||
- Image files (.png, .jpg, .jpeg)
|
||||
- E-books (.epub)
|
||||
- **Secure Authentication**: Uses OAuth 2.0 for secure access to your Google Drive
|
||||
- **Real-time Sync**: Process files directly from Google Drive without local downloads
|
||||
|
||||
<Callout type="info" emoji="ℹ️">
|
||||
The Google Drive Connector requires proper configuration of Google API credentials. Follow the setup instructions below to enable this feature.
|
||||
</Callout>
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before setting up the Google Drive Connector, you'll need:
|
||||
|
||||
1. A Google Cloud Platform (GCP) project
|
||||
2. Google Drive API enabled
|
||||
3. OAuth 2.0 credentials configured
|
||||
4. DocsGPT instance with proper environment variables
|
||||
|
||||
## Setup Instructions
|
||||
|
||||
<Steps>
|
||||
|
||||
### Step 1: Create a Google Cloud Project
|
||||
|
||||
1. Go to the [Google Cloud Console](https://console.cloud.google.com/)
|
||||
2. Create a new project or select an existing one
|
||||
3. Note down your Project ID for later use
|
||||
|
||||
### Step 2: Enable Google Drive API
|
||||
|
||||
1. In the Google Cloud Console, navigate to **APIs & Services** > **Library**
|
||||
2. Search for "Google Drive API"
|
||||
3. Click on "Google Drive API" and click **Enable**
|
||||
|
||||
### Step 3: Create OAuth 2.0 Credentials
|
||||
|
||||
1. Go to **APIs & Services** > **Credentials**
|
||||
2. Click **Create Credentials** > **OAuth client ID**
|
||||
3. If prompted, configure the OAuth consent screen:
|
||||
- Choose **External** user type (unless you're using Google Workspace)
|
||||
- Fill in the required fields (App name, User support email, Developer contact)
|
||||
- Add your domain to **Authorized domains** if deploying publicly
|
||||
4. For Application type, select **Web application**
|
||||
5. Add your DocsGPT frontend URL to **Authorized JavaScript origins**:
|
||||
- For local development: `http://localhost:3000`
|
||||
- For production: `https://yourdomain.com`
|
||||
6. Add your DocsGPT callback URL to **Authorized redirect URIs**:
|
||||
- For local development: `http://localhost:7091/api/connectors/callback?provider=google_drive`
|
||||
- For production: `https://yourdomain.com/api/connectors/callback?provider=google_drive`
|
||||
7. Click **Create** and note down the **Client ID** and **Client Secret**
|
||||
|
||||
|
||||
|
||||
### Step 4: Configure Backend Environment Variables
|
||||
|
||||
Add the following environment variables to your backend configuration:
|
||||
|
||||
**For Docker deployment**, add to your `.env` file in the root directory:
|
||||
|
||||
```env
|
||||
# Google Drive Connector Configuration
|
||||
GOOGLE_CLIENT_ID=your_google_client_id_here
|
||||
GOOGLE_CLIENT_SECRET=your_google_client_secret_here
|
||||
```
|
||||
|
||||
**For manual deployment**, set these environment variables in your system or application configuration.
|
||||
|
||||
### Step 5: Configure Frontend Environment Variables
|
||||
|
||||
Add the following environment variables to your frontend `.env` file:
|
||||
|
||||
```env
|
||||
# Google Drive Frontend Configuration
|
||||
VITE_GOOGLE_CLIENT_ID=your_google_client_id_here
|
||||
```
|
||||
|
||||
<Callout type="warning" emoji="⚠️">
|
||||
Make sure to use the same Google Client ID in both backend and frontend configurations.
|
||||
</Callout>
|
||||
|
||||
### Step 6: Restart Your Application
|
||||
|
||||
After configuring the environment variables:
|
||||
|
||||
1. **For Docker**: Restart your Docker containers
|
||||
```bash
|
||||
docker-compose down
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
2. **For manual deployment**: Restart both backend and frontend services
|
||||
|
||||
</Steps>
|
||||
|
||||
## Using the Google Drive Connector
|
||||
|
||||
Once configured, you can use the Google Drive Connector to upload files:
|
||||
|
||||
<Steps>
|
||||
|
||||
### Step 1: Access the Upload Interface
|
||||
|
||||
1. Navigate to the DocsGPT interface
|
||||
2. Go to the upload/training section
|
||||
3. You should now see "Google Drive" as an available upload option
|
||||
|
||||
### Step 2: Connect Your Google Account
|
||||
|
||||
1. Select "Google Drive" as your upload method
|
||||
2. Click "Connect to Google Drive"
|
||||
3. You'll be redirected to Google's OAuth consent screen
|
||||
4. Grant the necessary permissions to DocsGPT
|
||||
5. You'll be redirected back to DocsGPT with a successful connection
|
||||
|
||||
### Step 3: Select Files
|
||||
|
||||
1. Once connected, click "Select Files"
|
||||
2. The Google Drive picker will open
|
||||
3. Browse your Google Drive and select the files you want to process
|
||||
4. Click "Select" to confirm your choices
|
||||
|
||||
### Step 4: Process Files
|
||||
|
||||
1. Review your selected files
|
||||
2. Click "Train" or "Upload" to process the files
|
||||
3. DocsGPT will download and process the files from your Google Drive
|
||||
4. Once processing is complete, the files will be available in your knowledge base
|
||||
|
||||
</Steps>
|
||||
|
||||
## Supported File Types
|
||||
|
||||
The Google Drive Connector supports the following file types:
|
||||
|
||||
| File Type | Extensions | Description |
|
||||
|-----------|------------|-------------|
|
||||
| **Google Workspace** | - | Google Docs, Sheets, Slides (automatically converted) |
|
||||
| **Microsoft Office** | .docx, .xlsx, .pptx | Modern Office formats |
|
||||
| **Legacy Office** | .doc, .ppt, .xls | Older Office formats |
|
||||
| **PDF Documents** | .pdf | Portable Document Format |
|
||||
| **Text Files** | .txt, .md, .rst, .html, .rtf | Various text formats |
|
||||
| **Data Files** | .csv, .json | Structured data formats |
|
||||
| **Images** | .png, .jpg, .jpeg | Image files (with OCR if enabled) |
|
||||
| **E-books** | .epub | Electronic publication format |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**"Google Drive option not appearing"**
|
||||
- Verify that `VITE_GOOGLE_CLIENT_ID` is set in frontend environment
|
||||
- Check that `VITE_GOOGLE_CLIENT_ID` environment variable is present in your frontend configuration
|
||||
- Check browser console for any JavaScript errors
|
||||
- Ensure the frontend has been restarted after adding environment variables
|
||||
|
||||
**"Authentication failed"**
|
||||
- Verify that your OAuth 2.0 credentials are correctly configured
|
||||
- Check that the redirect URI `http://<your-domain>/api/connectors/callback?provider=google_drive` is correctly added in GCP console
|
||||
- Ensure the Google Drive API is enabled in your GCP project
|
||||
|
||||
**"Permission denied" errors**
|
||||
- Verify that the OAuth consent screen is properly configured
|
||||
- Check that your Google account has access to the files you're trying to select
|
||||
- Ensure the required scopes are granted during authentication
|
||||
|
||||
**"Files not processing"**
|
||||
- Check that the backend environment variables are correctly set
|
||||
- Verify that the OAuth credentials have the necessary permissions
|
||||
- Check the backend logs for any error messages
|
||||
|
||||
### Environment Variable Checklist
|
||||
|
||||
**Backend (.env in root directory):**
|
||||
- ✅ `GOOGLE_CLIENT_ID`
|
||||
- ✅ `GOOGLE_CLIENT_SECRET`
|
||||
|
||||
**Frontend (.env in frontend directory):**
|
||||
- ✅ `VITE_GOOGLE_CLIENT_ID`
|
||||
|
||||
### Security Considerations
|
||||
|
||||
- Keep your Google Client Secret secure and never expose it in frontend code
|
||||
- Regularly rotate your OAuth credentials
|
||||
- Use HTTPS in production to protect authentication tokens
|
||||
- Ensure proper OAuth consent screen configuration for production use
|
||||
|
||||
<Callout type="tip" emoji="💡">
|
||||
For production deployments, make sure to add your actual domain to the OAuth consent screen and authorized origins/redirect URIs.
|
||||
</Callout>
|
||||
|
||||
|
||||
@@ -20,5 +20,8 @@
|
||||
"Architecture": {
|
||||
"title": "🏗️ Architecture",
|
||||
"href": "/Guides/Architecture"
|
||||
},
|
||||
"Integrations": {
|
||||
"title": "🔗 Integrations"
|
||||
}
|
||||
}
|
||||
@@ -13,15 +13,15 @@ The primary method for configuring your LLM provider in DocsGPT is through the `
|
||||
|
||||
To connect to a cloud LLM provider, you will typically need to configure the following basic settings in your `.env` file:
|
||||
|
||||
* **`LLM_NAME`**: This setting is essential and identifies the specific cloud provider you wish to use (e.g., `openai`, `google`, `anthropic`).
|
||||
* **`MODEL_NAME`**: Specifies the exact model you want to utilize from your chosen provider (e.g., `gpt-4o`, `gemini-2.0-flash`, `claude-3-5-sonnet-latest`). Refer to your provider's documentation for a list of available models.
|
||||
* **`LLM_PROVIDER`**: This setting is essential and identifies the specific cloud provider you wish to use (e.g., `openai`, `google`, `anthropic`).
|
||||
* **`LLM_NAME`**: Specifies the exact model you want to utilize from your chosen provider (e.g., `gpt-4o`, `gemini-2.0-flash`, `claude-3-5-sonnet-latest`). Refer to your provider's documentation for a list of available models.
|
||||
* **`API_KEY`**: Almost all cloud LLM providers require an API key for authentication. Obtain your API key from your chosen provider's platform and securely store it in your `.env` file.
|
||||
|
||||
## Explicitly Supported Cloud Providers
|
||||
|
||||
DocsGPT offers direct, streamlined support for the following cloud LLM providers, making configuration straightforward. The table below outlines the `LLM_NAME` and example `MODEL_NAME` values to use for each provider in your `.env` file.
|
||||
DocsGPT offers direct, streamlined support for the following cloud LLM providers, making configuration straightforward. The table below outlines the `LLM_PROVIDER` and example `LLM_NAME` values to use for each provider in your `.env` file.
|
||||
|
||||
| Provider | `LLM_NAME` | Example `MODEL_NAME` |
|
||||
| Provider | `LLM_PROVIDER` | Example `LLM_NAME` |
|
||||
| :--------------------------- | :------------- | :-------------------------- |
|
||||
| DocsGPT Public API | `docsgpt` | `None` |
|
||||
| OpenAI | `openai` | `gpt-4o` |
|
||||
@@ -35,16 +35,16 @@ DocsGPT offers direct, streamlined support for the following cloud LLM providers
|
||||
|
||||
DocsGPT's flexible architecture allows you to connect to any cloud provider that offers an API compatible with the OpenAI API standard. This opens up a vast ecosystem of LLM services.
|
||||
|
||||
To connect to an OpenAI-compatible cloud provider, you will still use `LLM_NAME=openai` in your `.env` file. However, you will also need to specify the API endpoint of your chosen provider using the `OPENAI_BASE_URL` setting. You will also likely need to provide an `API_KEY` and `MODEL_NAME` as required by that provider.
|
||||
To connect to an OpenAI-compatible cloud provider, you will still use `LLM_PROVIDER=openai` in your `.env` file. However, you will also need to specify the API endpoint of your chosen provider using the `OPENAI_BASE_URL` setting. You will also likely need to provide an `API_KEY` and `LLM_NAME` as required by that provider.
|
||||
|
||||
**Example for DeepSeek (OpenAI-Compatible API):**
|
||||
|
||||
To connect to DeepSeek, which offers an OpenAI-compatible API, your `.env` file could be configured as follows:
|
||||
|
||||
```
|
||||
LLM_NAME=openai
|
||||
LLM_PROVIDER=openai
|
||||
API_KEY=YOUR_API_KEY # Your DeepSeek API key
|
||||
MODEL_NAME=deepseek-chat # Or your desired DeepSeek model name
|
||||
LLM_NAME=deepseek-chat # Or your desired DeepSeek model name
|
||||
OPENAI_BASE_URL=https://api.deepseek.com/v1 # DeepSeek's OpenAI API URL
|
||||
```
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ To use OpenAI's `text-embedding-ada-002` embedding model, you need to set `EMBED
|
||||
**Example `.env` configuration for OpenAI Embeddings:**
|
||||
|
||||
```
|
||||
LLM_NAME=openai
|
||||
LLM_PROVIDER=openai
|
||||
API_KEY=YOUR_OPENAI_API_KEY # Your OpenAI API Key
|
||||
EMBEDDINGS_NAME=openai_text-embedding-ada-002
|
||||
```
|
||||
|
||||
@@ -15,8 +15,8 @@ Setting up a local inference engine with DocsGPT is configured through environme
|
||||
|
||||
To connect to a local inference engine, you will generally need to configure these settings in your `.env` file:
|
||||
|
||||
* **`LLM_NAME`**: Crucially set this to `openai`. This tells DocsGPT to use the OpenAI-compatible API format for communication, even though the LLM is local.
|
||||
* **`MODEL_NAME`**: Specify the model name as recognized by your local inference engine. This might be a model identifier or left as `None` if the engine doesn't require explicit model naming in the API request.
|
||||
* **`LLM_PROVIDER`**: Crucially set this to `openai`. This tells DocsGPT to use the OpenAI-compatible API format for communication, even though the LLM is local.
|
||||
* **`LLM_NAME`**: Specify the model name as recognized by your local inference engine. This might be a model identifier or left as `None` if the engine doesn't require explicit model naming in the API request.
|
||||
* **`OPENAI_BASE_URL`**: This is essential. Set this to the base URL of your local inference engine's API endpoint. This tells DocsGPT where to find your local LLM server.
|
||||
* **`API_KEY`**: Generally, for local inference engines, you can set `API_KEY=None` as authentication is usually not required in local setups.
|
||||
|
||||
@@ -24,16 +24,16 @@ To connect to a local inference engine, you will generally need to configure the
|
||||
|
||||
DocsGPT is readily configurable to work with the following local inference engines, all communicating via the OpenAI API format. Here are example `OPENAI_BASE_URL` values for each, based on default setups:
|
||||
|
||||
| Inference Engine | `LLM_NAME` | `OPENAI_BASE_URL` |
|
||||
| :---------------------------- | :--------- | :------------------------- |
|
||||
| LLaMa.cpp | `openai` | `http://localhost:8000/v1` |
|
||||
| Ollama | `openai` | `http://localhost:11434/v1` |
|
||||
| Text Generation Inference (TGI)| `openai` | `http://localhost:8080/v1` |
|
||||
| SGLang | `openai` | `http://localhost:30000/v1` |
|
||||
| vLLM | `openai` | `http://localhost:8000/v1` |
|
||||
| Aphrodite | `openai` | `http://localhost:2242/v1` |
|
||||
| FriendliAI | `openai` | `http://localhost:8997/v1` |
|
||||
| LMDeploy | `openai` | `http://localhost:23333/v1` |
|
||||
| Inference Engine | `LLM_PROVIDER` | `OPENAI_BASE_URL` |
|
||||
| :---------------------------- | :------------- | :------------------------- |
|
||||
| LLaMa.cpp | `openai` | `http://localhost:8000/v1` |
|
||||
| Ollama | `openai` | `http://localhost:11434/v1` |
|
||||
| Text Generation Inference (TGI)| `openai` | `http://localhost:8080/v1` |
|
||||
| SGLang | `openai` | `http://localhost:30000/v1` |
|
||||
| vLLM | `openai` | `http://localhost:8000/v1` |
|
||||
| Aphrodite | `openai` | `http://localhost:2242/v1` |
|
||||
| FriendliAI | `openai` | `http://localhost:8997/v1` |
|
||||
| LMDeploy | `openai` | `http://localhost:23333/v1` |
|
||||
|
||||
**Important Note on `localhost` vs `host.docker.internal`:**
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ const config = {
|
||||
GitHub
|
||||
</a>
|
||||
{' | '}
|
||||
<a href="https://www.blog.docsgpt.cloud/" target="_blank">
|
||||
<a href="https://blog.docsgpt.cloud/" target="_blank">
|
||||
Blog
|
||||
</a>
|
||||
</div>
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0,viewport-fit=cover" />
|
||||
<meta name="apple-mobile-web-app-capable" content="yes">
|
||||
<meta name="theme-color" content="#fbfbfb" media="(prefers-color-scheme: light)" />
|
||||
<meta name="theme-color" content="#161616" media="(prefers-color-scheme: dark)" />
|
||||
<title>DocsGPT</title>
|
||||
<link rel="shortcut icon" type="image/x-icon" href="/favicon.ico" />
|
||||
</head>
|
||||
|
||||
1732
frontend/package-lock.json
generated
1732
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -19,19 +19,20 @@
|
||||
]
|
||||
},
|
||||
"dependencies": {
|
||||
"@reduxjs/toolkit": "^2.5.1",
|
||||
"@reduxjs/toolkit": "^2.8.2",
|
||||
"chart.js": "^4.4.4",
|
||||
"clsx": "^2.1.1",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"i18next": "^24.2.0",
|
||||
"i18next-browser-languagedetector": "^8.0.2",
|
||||
"lodash": "^4.17.21",
|
||||
"mermaid": "^11.6.0",
|
||||
"prop-types": "^15.8.1",
|
||||
"react": "^18.2.0",
|
||||
"react": "^19.1.0",
|
||||
"react-chartjs-2": "^5.3.0",
|
||||
"react-copy-to-clipboard": "^5.1.0",
|
||||
"react-dom": "^18.3.1",
|
||||
"react-dropzone": "^14.3.5",
|
||||
"react-helmet": "^6.1.0",
|
||||
"react-dom": "^19.0.0",
|
||||
"react-dropzone": "^14.3.8",
|
||||
"react-google-drive-picker": "^1.2.2",
|
||||
"react-i18next": "^15.4.0",
|
||||
"react-markdown": "^9.0.1",
|
||||
"react-redux": "^9.2.0",
|
||||
@@ -39,18 +40,19 @@
|
||||
"react-syntax-highlighter": "^15.6.1",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"remark-gfm": "^4.0.0",
|
||||
"remark-math": "^6.0.0"
|
||||
"remark-math": "^6.0.0",
|
||||
"tailwind-merge": "^3.3.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@tailwindcss/postcss": "^4.1.10",
|
||||
"@types/lodash": "^4.17.20",
|
||||
"@types/mermaid": "^9.1.0",
|
||||
"@types/react": "^18.0.27",
|
||||
"@types/react-dom": "^18.3.0",
|
||||
"@types/react-helmet": "^6.1.11",
|
||||
"@types/react": "^19.1.8",
|
||||
"@types/react-dom": "^19.0.0",
|
||||
"@types/react-syntax-highlighter": "^15.5.13",
|
||||
"@typescript-eslint/eslint-plugin": "^5.51.0",
|
||||
"@typescript-eslint/parser": "^5.62.0",
|
||||
"@vitejs/plugin-react": "^4.3.4",
|
||||
"autoprefixer": "^10.4.13",
|
||||
"eslint": "^8.57.1",
|
||||
"eslint-config-prettier": "^10.1.5",
|
||||
"eslint-config-standard-with-typescript": "^34.0.0",
|
||||
@@ -64,8 +66,8 @@
|
||||
"lint-staged": "^15.3.0",
|
||||
"postcss": "^8.4.49",
|
||||
"prettier": "^3.5.3",
|
||||
"prettier-plugin-tailwindcss": "^0.6.11",
|
||||
"tailwindcss": "^3.4.17",
|
||||
"prettier-plugin-tailwindcss": "^0.6.13",
|
||||
"tailwindcss": "^4.1.11",
|
||||
"typescript": "^5.8.3",
|
||||
"vite": "^6.3.5",
|
||||
"vite-plugin-svgr": "^4.3.0"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
module.exports = {
|
||||
plugins: {
|
||||
tailwindcss: {},
|
||||
autoprefixer: {},
|
||||
'@tailwindcss/postcss': {},
|
||||
},
|
||||
}
|
||||
|
||||
1
frontend/public/toolIcons/tool_duckduckgo.svg
Normal file
1
frontend/public/toolIcons/tool_duckduckgo.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 122.88 122.88"><defs><style>.a{fill:#d53;}.b{fill:#fff;}.c{fill:#ddd;}.d{fill:#fc0;}.e{fill:#6b5;}.f{fill:#4a4;}.g{fill:#148;}</style></defs><title>duckduckgo</title><path class="a" d="M122.88,61.44a61.44,61.44,0,1,0-61.44,61.44,61.44,61.44,0,0,0,61.44-61.44Z"/><path class="b" d="M114.37,61.44a52.92,52.92,0,1,0-15.5,37.43,52.76,52.76,0,0,0,15.5-37.43Zm-13.12-39.8A56.29,56.29,0,1,1,61.44,5.15a56.12,56.12,0,0,1,39.81,16.49Z"/><path class="c" d="M43.24,30.15C26.17,34.13,32.43,58,32.43,58l10.81,52.9,4,1.71-4-82.49Zm-4-10.24H34.7L41,22.19s-6.26,0-6.26,4C48.36,25.6,54.61,29,54.61,29l-15.36-9.1Zm0,0Z"/><path class="b" d="M75.66,115.48S62,93.87,62,79.64c0-26.73,17.63-4,17.63-25S62,28.44,62,28.44c-8.53-10.8-25-8.53-25-8.53l4,2.28s-4,1.13-5.12,2.27,10.81-1.7,15.93,2.85C30.72,29,34.13,46.08,34.13,46.08l11.95,68.27,29.58,1.13Zm0,0Z"/><path class="d" d="M75.66,60.87l21.62-5.69C116.62,58,80.78,68.84,78.51,68.27c-17.07-2.85-12,11.37,8.53,6.82s5.12,11.38-13.65,5.12c-26.74-7.39-12.52-20.48,2.27-19.34Z"/><path class="e" d="M70,105.81l1.14-1.7c12.52,4.55,13.09,6.25,12.52-5.12s0-11.38-13.09-1.71c0-2.84-7.39-1.71-8.53,0-11.95-5.12-13.09-6.83-12.52,1.14,1.14,16.5.57,13.65,11.95,8l8.53-.57Zm0,0Z"/><path class="f" d="M60.87,99.56v6.82c.57,1.14,9.67,1.14,9.67-1.14s-4.55,1.71-7.39.57S62,98.42,62,98.42l-1.14,1.14Zm0,0Z"/><path class="g" d="M48.36,43.24c-2.85-3.42-10.24-.57-8.54,4,.57-2.28,4.55-5.69,8.54-4Zm18.2,0c.57-3.42,6.26-4,8-.57a8,8,0,0,0-8,.57Zm-18.77,9.1a1.14,1.14,0,1,1,0,.57v-.57Zm-4.55,2.27a4,4,0,1,0,0-.57v.57Zm29.58-4a1.14,1.14,0,1,1,0,.57v-.57ZM69.4,52.91a3.42,3.42,0,1,0,0-.57v.57Zm0,0Z"/></svg>
|
||||
|
After Width: | Height: | Size: 1.6 KiB |
4
frontend/public/toolIcons/tool_mcp_tool.svg
Normal file
4
frontend/public/toolIcons/tool_mcp_tool.svg
Normal file
@@ -0,0 +1,4 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="64" height="64" color="#000000" fill="none">
|
||||
<path d="M3.49994 11.7501L11.6717 3.57855C12.7762 2.47398 14.5672 2.47398 15.6717 3.57855C16.7762 4.68312 16.7762 6.47398 15.6717 7.57855M15.6717 7.57855L9.49994 13.7501M15.6717 7.57855C16.7762 6.47398 18.5672 6.47398 19.6717 7.57855C20.7762 8.68312 20.7762 10.474 19.6717 11.5785L12.7072 18.543C12.3167 18.9335 12.3167 19.5667 12.7072 19.9572L13.9999 21.2499" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path>
|
||||
<path d="M17.4999 9.74921L11.3282 15.921C10.2237 17.0255 8.43272 17.0255 7.32823 15.921C6.22373 14.8164 6.22373 13.0255 7.32823 11.921L13.4999 5.74939" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 831 B |
Binary file not shown.
@@ -33,7 +33,7 @@ function MainLayout() {
|
||||
const [navOpen, setNavOpen] = useState(!(isMobile || isTablet));
|
||||
|
||||
return (
|
||||
<div className="relative h-screen overflow-hidden dark:bg-raisin-black">
|
||||
<div className="dark:bg-raisin-black relative h-screen overflow-hidden">
|
||||
<Navigation navOpen={navOpen} setNavOpen={setNavOpen} />
|
||||
<ActionButtons showNewChat={true} showShare={true} />
|
||||
<div
|
||||
|
||||
@@ -19,9 +19,9 @@ export default function Hero({
|
||||
}>;
|
||||
|
||||
return (
|
||||
<div className="flex h-full w-full flex-col items-center justify-between text-black-1000 dark:text-bright-gray">
|
||||
<div className="text-black-1000 dark:text-bright-gray flex h-full w-full flex-col items-center justify-between">
|
||||
{/* Header Section */}
|
||||
<div className="flex flex-grow flex-col items-center justify-center pt-8 md:pt-0">
|
||||
<div className="flex grow flex-col items-center justify-center pt-8 md:pt-0">
|
||||
<div className="mb-4 flex items-center">
|
||||
<span className="text-4xl font-semibold">DocsGPT</span>
|
||||
<img className="mb-1 inline w-14" src={DocsGPT3} alt="docsgpt" />
|
||||
@@ -29,7 +29,7 @@ export default function Hero({
|
||||
</div>
|
||||
|
||||
{/* Demo Buttons Section */}
|
||||
<div className="mb-8 w-full max-w-full md:mb-16">
|
||||
<div className="mb-3 w-full max-w-full md:mb-3">
|
||||
<div className="grid grid-cols-1 gap-3 text-xs md:grid-cols-1 md:gap-4 lg:grid-cols-2">
|
||||
{demos?.map(
|
||||
(demo: { header: string; query: string }, key: number) =>
|
||||
@@ -38,9 +38,9 @@ export default function Hero({
|
||||
<button
|
||||
key={key}
|
||||
onClick={() => handleQuestion({ question: demo.query })}
|
||||
className={`w-full rounded-[66px] border border-dark-gray bg-transparent px-6 py-[14px] text-left text-just-black transition-colors hover:bg-cultured dark:border-dim-gray dark:text-chinese-white dark:hover:bg-charleston-green ${key >= 2 ? 'hidden md:block' : ''} // Show only 2 buttons on mobile`}
|
||||
className={`border-dark-gray text-just-black hover:bg-cultured dark:border-dim-gray dark:text-chinese-white dark:hover:bg-charleston-green w-full rounded-[66px] border bg-transparent px-6 py-[14px] text-left transition-colors ${key >= 2 ? 'hidden md:block' : ''} // Show only 2 buttons on mobile`}
|
||||
>
|
||||
<p className="mb-2 font-semibold text-black-1000 dark:text-bright-gray">
|
||||
<p className="text-black-1000 dark:text-bright-gray mb-2 font-semibold">
|
||||
{demo.header}
|
||||
</p>
|
||||
<span className="line-clamp-2 text-gray-700 opacity-60 dark:text-gray-300">
|
||||
|
||||
@@ -10,7 +10,7 @@ import Add from './assets/add.svg';
|
||||
import DocsGPT3 from './assets/cute_docsgpt3.svg';
|
||||
import Discord from './assets/discord.svg';
|
||||
import Expand from './assets/expand.svg';
|
||||
import Github from './assets/github.svg';
|
||||
import Github from './assets/git_nav.svg';
|
||||
import Hamburger from './assets/hamburger.svg';
|
||||
import openNewChat from './assets/openNewChat.svg';
|
||||
import Pin from './assets/pin.svg';
|
||||
@@ -48,6 +48,7 @@ import {
|
||||
setConversations,
|
||||
setModalStateDeleteConv,
|
||||
setSelectedAgent,
|
||||
setSharedAgents,
|
||||
} from './preferences/preferenceSlice';
|
||||
import Upload from './upload/Upload';
|
||||
|
||||
@@ -80,8 +81,27 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
useState<ActiveState>('INACTIVE');
|
||||
const [recentAgents, setRecentAgents] = useState<Agent[]>([]);
|
||||
|
||||
const navRef = useRef(null);
|
||||
const navRef = useRef<HTMLDivElement>(null);
|
||||
useEffect(() => {
|
||||
function handleClickOutside(event: MouseEvent) {
|
||||
if (
|
||||
navRef.current &&
|
||||
!navRef.current.contains(event.target as Node) &&
|
||||
(isMobile || isTablet) &&
|
||||
navOpen
|
||||
) {
|
||||
setNavOpen(false);
|
||||
}
|
||||
}
|
||||
|
||||
//event listener only for mobile/tablet when nav is open
|
||||
if ((isMobile || isTablet) && navOpen) {
|
||||
document.addEventListener('mousedown', handleClickOutside);
|
||||
return () => {
|
||||
document.removeEventListener('mousedown', handleClickOutside);
|
||||
};
|
||||
}
|
||||
}, [navOpen, isMobile, isTablet, setNavOpen]);
|
||||
async function fetchRecentAgents() {
|
||||
try {
|
||||
const response = await userService.getPinnedAgents(token);
|
||||
@@ -169,73 +189,65 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
const handleTogglePin = (agent: Agent) => {
|
||||
userService.togglePinAgent(agent.id ?? '', token).then((response) => {
|
||||
if (response.ok) {
|
||||
const updatedAgents = agents?.map((a) =>
|
||||
a.id === agent.id ? { ...a, pinned: !a.pinned } : a,
|
||||
);
|
||||
dispatch(setAgents(updatedAgents));
|
||||
const updatePinnedStatus = (a: Agent) =>
|
||||
a.id === agent.id ? { ...a, pinned: !a.pinned } : a;
|
||||
dispatch(setAgents(agents?.map(updatePinnedStatus)));
|
||||
dispatch(setSharedAgents(sharedAgents?.map(updatePinnedStatus)));
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
const handleConversationClick = (index: string) => {
|
||||
dispatch(setSelectedAgent(null));
|
||||
conversationService
|
||||
.getConversation(index, token)
|
||||
.then((response) => {
|
||||
if (!response.ok) {
|
||||
navigate('/');
|
||||
dispatch(setSelectedAgent(null));
|
||||
return null;
|
||||
}
|
||||
return response.json();
|
||||
})
|
||||
.then((data) => {
|
||||
if (!data) return;
|
||||
dispatch(setConversation(data.queries));
|
||||
dispatch(
|
||||
updateConversationId({
|
||||
query: { conversationId: index },
|
||||
}),
|
||||
const handleConversationClick = async (index: string) => {
|
||||
try {
|
||||
dispatch(setSelectedAgent(null));
|
||||
|
||||
const response = await conversationService.getConversation(index, token);
|
||||
if (!response.ok) {
|
||||
navigate('/');
|
||||
return;
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
if (!data) return;
|
||||
|
||||
dispatch(setConversation(data.queries));
|
||||
dispatch(updateConversationId({ query: { conversationId: index } }));
|
||||
|
||||
if (!data.agent_id) {
|
||||
navigate('/');
|
||||
return;
|
||||
}
|
||||
|
||||
let agent: Agent;
|
||||
if (data.is_shared_usage) {
|
||||
const sharedResponse = await userService.getSharedAgent(
|
||||
data.shared_token,
|
||||
token,
|
||||
);
|
||||
if (isMobile || isTablet) {
|
||||
setNavOpen(false);
|
||||
}
|
||||
if (data.agent_id) {
|
||||
if (data.is_shared_usage) {
|
||||
userService
|
||||
.getSharedAgent(data.shared_token, token)
|
||||
.then((response) => {
|
||||
if (!response.ok) {
|
||||
navigate('/');
|
||||
dispatch(setSelectedAgent(null));
|
||||
return;
|
||||
}
|
||||
response.json().then((agent: Agent) => {
|
||||
navigate(`/agents/shared/${agent.shared_token}`);
|
||||
});
|
||||
});
|
||||
} else {
|
||||
userService.getAgent(data.agent_id, token).then((response) => {
|
||||
if (!response.ok) {
|
||||
navigate('/');
|
||||
dispatch(setSelectedAgent(null));
|
||||
return;
|
||||
}
|
||||
response.json().then((agent: Agent) => {
|
||||
if (agent.shared_token)
|
||||
navigate(`/agents/shared/${agent.shared_token}`);
|
||||
else {
|
||||
dispatch(setSelectedAgent(agent));
|
||||
navigate('/');
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
} else {
|
||||
if (!sharedResponse.ok) {
|
||||
navigate('/');
|
||||
dispatch(setSelectedAgent(null));
|
||||
return;
|
||||
}
|
||||
});
|
||||
agent = await sharedResponse.json();
|
||||
navigate(`/agents/shared/${agent.shared_token}`);
|
||||
} else {
|
||||
const agentResponse = await userService.getAgent(data.agent_id, token);
|
||||
if (!agentResponse.ok) {
|
||||
navigate('/');
|
||||
return;
|
||||
}
|
||||
agent = await agentResponse.json();
|
||||
if (agent.shared_token) {
|
||||
navigate(`/agents/shared/${agent.shared_token}`);
|
||||
} else {
|
||||
await Promise.resolve(dispatch(setSelectedAgent(agent)));
|
||||
navigate('/');
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error handling conversation click:', error);
|
||||
navigate('/');
|
||||
}
|
||||
};
|
||||
|
||||
const resetConversation = () => {
|
||||
@@ -281,7 +293,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
return (
|
||||
<>
|
||||
{!navOpen && (
|
||||
<div className="duration-25 absolute left-3 top-3 z-20 hidden transition-all lg:block">
|
||||
<div className="absolute top-3 left-3 z-20 hidden transition-all duration-25 lg:block">
|
||||
<div className="flex items-center gap-3">
|
||||
<button
|
||||
onClick={() => {
|
||||
@@ -309,7 +321,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
/>
|
||||
</button>
|
||||
)}
|
||||
<div className="text-[20px] font-medium text-[#949494]">
|
||||
<div className="text-gray-4000 text-[20px] font-medium">
|
||||
DocsGPT
|
||||
</div>
|
||||
</div>
|
||||
@@ -318,8 +330,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
<div
|
||||
ref={navRef}
|
||||
className={`${
|
||||
!navOpen && '-ml-96 md:-ml-[18rem]'
|
||||
} duration-20 fixed top-0 z-20 flex h-full w-72 flex-col border-b-0 border-r-[1px] bg-lotion transition-all dark:border-r-purple-taupe dark:bg-chinese-black dark:text-white`}
|
||||
!navOpen && '-ml-96 md:-ml-72'
|
||||
} bg-lotion dark:border-r-purple-taupe dark:bg-chinese-black fixed top-0 z-20 flex h-full w-72 flex-col border-r border-b-0 transition-all duration-20 dark:text-white`}
|
||||
>
|
||||
<div
|
||||
className={'visible mt-2 flex h-[6vh] w-full justify-between md:h-12'}
|
||||
@@ -363,7 +375,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
className={({ isActive }) =>
|
||||
`${
|
||||
isActive ? 'bg-transparent' : ''
|
||||
} group sticky mx-4 mt-4 flex cursor-pointer gap-2.5 rounded-3xl border border-silver p-3 hover:border-rainy-gray hover:bg-transparent dark:border-purple-taupe dark:text-white`
|
||||
} group border-silver hover:border-rainy-gray dark:border-purple-taupe sticky mx-4 mt-4 flex cursor-pointer gap-2.5 rounded-3xl border p-3 hover:bg-transparent dark:text-white`
|
||||
}
|
||||
>
|
||||
<img
|
||||
@@ -371,16 +383,16 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
alt="Create new chat"
|
||||
className="opacity-80 group-hover:opacity-100"
|
||||
/>
|
||||
<p className="text-sm text-dove-gray group-hover:text-neutral-600 dark:text-chinese-silver dark:group-hover:text-bright-gray">
|
||||
<p className="text-dove-gray dark:text-chinese-silver dark:group-hover:text-bright-gray text-sm group-hover:text-neutral-600">
|
||||
{t('newChat')}
|
||||
</p>
|
||||
</NavLink>
|
||||
<div
|
||||
id="conversationsMainDiv"
|
||||
className="mb-auto h-[78vh] overflow-y-auto overflow-x-hidden dark:text-white"
|
||||
className="mb-auto h-[78vh] overflow-x-hidden overflow-y-auto dark:text-white"
|
||||
>
|
||||
{conversations?.loading && !isDeletingConversation && (
|
||||
<div className="absolute left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2 transform">
|
||||
<div className="absolute top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 transform">
|
||||
<img
|
||||
src={isDarkTheme ? SpinnerDark : Spinner}
|
||||
className="animate-spin cursor-pointer bg-transparent"
|
||||
@@ -391,14 +403,14 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
{recentAgents?.length > 0 ? (
|
||||
<div>
|
||||
<div className="mx-4 my-auto mt-2 flex h-6 items-center">
|
||||
<p className="ml-4 mt-1 text-sm font-semibold">Agents</p>
|
||||
<p className="mt-1 ml-4 text-sm font-semibold">Agents</p>
|
||||
</div>
|
||||
<div className="agents-container">
|
||||
<div>
|
||||
{recentAgents.map((agent, idx) => (
|
||||
<div
|
||||
key={idx}
|
||||
className={`group mx-4 my-auto mt-4 flex h-9 cursor-pointer items-center justify-between rounded-3xl pl-4 hover:bg-bright-gray dark:hover:bg-dark-charcoal ${
|
||||
className={`group hover:bg-bright-gray dark:hover:bg-dark-charcoal mx-4 my-auto mt-4 flex h-9 cursor-pointer items-center justify-between rounded-3xl pl-4 ${
|
||||
agent.id === selectedAgent?.id && !conversationId
|
||||
? 'bg-bright-gray dark:bg-dark-charcoal'
|
||||
: ''
|
||||
@@ -408,12 +420,16 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="flex w-6 justify-center">
|
||||
<img
|
||||
src={agent.image ?? Robot}
|
||||
src={
|
||||
agent.image && agent.image.trim() !== ''
|
||||
? agent.image
|
||||
: Robot
|
||||
}
|
||||
alt="agent-logo"
|
||||
className="h-6 w-6"
|
||||
className="h-6 w-6 rounded-full object-contain"
|
||||
/>
|
||||
</div>
|
||||
<p className="overflow-hidden overflow-ellipsis whitespace-nowrap text-sm leading-6 text-eerie-black dark:text-bright-gray">
|
||||
<p className="text-eerie-black dark:text-bright-gray overflow-hidden text-sm leading-6 text-ellipsis whitespace-nowrap">
|
||||
{agent.name}
|
||||
</p>
|
||||
</div>
|
||||
@@ -437,7 +453,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
))}
|
||||
</div>
|
||||
<div
|
||||
className="mx-4 my-auto mt-2 flex h-9 cursor-pointer items-center gap-2 rounded-3xl pl-4 hover:bg-bright-gray dark:hover:bg-dark-charcoal"
|
||||
className="hover:bg-bright-gray dark:hover:bg-dark-charcoal mx-4 my-auto mt-2 flex h-9 cursor-pointer items-center gap-2 rounded-3xl pl-4"
|
||||
onClick={() => {
|
||||
dispatch(setSelectedAgent(null));
|
||||
if (isMobile || isTablet) {
|
||||
@@ -453,7 +469,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
className="h-[18px] w-[18px]"
|
||||
/>
|
||||
</div>
|
||||
<p className="overflow-hidden overflow-ellipsis whitespace-nowrap text-sm leading-6 text-eerie-black dark:text-bright-gray">
|
||||
<p className="text-eerie-black dark:text-bright-gray overflow-hidden text-sm leading-6 text-ellipsis whitespace-nowrap">
|
||||
{t('manageAgents')}
|
||||
</p>
|
||||
</div>
|
||||
@@ -461,7 +477,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
</div>
|
||||
) : (
|
||||
<div
|
||||
className="mx-4 my-auto mt-2 flex h-9 cursor-pointer items-center gap-2 rounded-3xl pl-4 hover:bg-bright-gray dark:hover:bg-dark-charcoal"
|
||||
className="hover:bg-bright-gray dark:hover:bg-dark-charcoal mx-4 my-auto mt-2 flex h-9 cursor-pointer items-center gap-2 rounded-3xl pl-4"
|
||||
onClick={() => {
|
||||
if (isMobile || isTablet) {
|
||||
setNavOpen(false);
|
||||
@@ -477,7 +493,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
className="h-[18px] w-[18px]"
|
||||
/>
|
||||
</div>
|
||||
<p className="overflow-hidden overflow-ellipsis whitespace-nowrap text-sm leading-6 text-eerie-black dark:text-bright-gray">
|
||||
<p className="text-eerie-black dark:text-bright-gray overflow-hidden text-sm leading-6 text-ellipsis whitespace-nowrap">
|
||||
{t('manageAgents')}
|
||||
</p>
|
||||
</div>
|
||||
@@ -485,7 +501,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
{conversations?.data && conversations.data.length > 0 ? (
|
||||
<div className="mt-7">
|
||||
<div className="mx-4 my-auto mt-2 flex h-6 items-center justify-between gap-4 rounded-3xl">
|
||||
<p className="ml-4 mt-1 text-sm font-semibold">{t('chats')}</p>
|
||||
<p className="mt-1 ml-4 text-sm font-semibold">{t('chats')}</p>
|
||||
</div>
|
||||
<div className="conversations-container">
|
||||
{conversations.data?.map((conversation) => (
|
||||
@@ -510,8 +526,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
<></>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex h-auto flex-col justify-end text-eerie-black dark:text-white">
|
||||
<div className="flex flex-col gap-2 border-b-[1px] py-2 dark:border-b-purple-taupe">
|
||||
<div className="text-eerie-black flex h-auto flex-col justify-end dark:text-white">
|
||||
<div className="dark:border-b-purple-taupe flex flex-col gap-2 border-b py-2">
|
||||
<NavLink
|
||||
onClick={() => {
|
||||
if (isMobile || isTablet) {
|
||||
@@ -521,7 +537,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
}}
|
||||
to="/settings"
|
||||
className={({ isActive }) =>
|
||||
`mx-4 my-auto flex h-9 cursor-pointer gap-4 rounded-3xl hover:bg-gray-100 dark:hover:bg-[#28292E] ${
|
||||
`mx-4 my-auto flex h-9 cursor-pointer items-center gap-4 rounded-3xl hover:bg-gray-100 dark:hover:bg-[#28292E] ${
|
||||
isActive ? 'bg-gray-3000 dark:bg-transparent' : ''
|
||||
}`
|
||||
}
|
||||
@@ -529,14 +545,16 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
<img
|
||||
src={SettingGear}
|
||||
alt="Settings"
|
||||
className="w- ml-2 filter dark:invert"
|
||||
width={21}
|
||||
height={21}
|
||||
className="my-auto ml-2 filter dark:invert"
|
||||
/>
|
||||
<p className="my-auto text-sm text-eerie-black dark:text-white">
|
||||
<p className="text-eerie-black text-sm dark:text-white">
|
||||
{t('settings.label')}
|
||||
</p>
|
||||
</NavLink>
|
||||
</div>
|
||||
<div className="flex flex-col justify-end text-eerie-black dark:text-white">
|
||||
<div className="text-eerie-black flex flex-col justify-end dark:text-white">
|
||||
<div className="flex items-center justify-between py-1">
|
||||
<Help />
|
||||
|
||||
@@ -550,6 +568,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
>
|
||||
<img
|
||||
src={Discord}
|
||||
width={24}
|
||||
height={24}
|
||||
alt="Join Discord community"
|
||||
className="m-2 w-6 self-center filter dark:invert"
|
||||
/>
|
||||
@@ -563,8 +583,10 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
>
|
||||
<img
|
||||
src={Twitter}
|
||||
width={20}
|
||||
height={20}
|
||||
alt="Follow us on Twitter"
|
||||
className="m-2 w-5 self-center filter dark:invert"
|
||||
className="m-2 self-center filter dark:invert"
|
||||
/>
|
||||
</NavLink>
|
||||
<NavLink
|
||||
@@ -577,7 +599,9 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
<img
|
||||
src={Github}
|
||||
alt="View on GitHub"
|
||||
className="m-2 w-6 self-center filter dark:invert"
|
||||
width={28}
|
||||
height={28}
|
||||
className="m-2 self-center filter dark:invert"
|
||||
/>
|
||||
</NavLink>
|
||||
</div>
|
||||
@@ -585,7 +609,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="sticky z-10 h-16 w-full border-b-2 bg-gray-50 dark:border-b-purple-taupe dark:bg-chinese-black lg:hidden">
|
||||
<div className="dark:border-b-purple-taupe dark:bg-chinese-black sticky z-10 h-16 w-full border-b-2 bg-gray-50 lg:hidden">
|
||||
<div className="ml-6 flex h-full items-center gap-6">
|
||||
<button
|
||||
className="h-6 w-6 lg:hidden"
|
||||
@@ -597,7 +621,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
className="w-7 filter dark:invert"
|
||||
/>
|
||||
</button>
|
||||
<div className="text-[20px] font-medium text-[#949494]">DocsGPT</div>
|
||||
<div className="text-gray-4000 text-[20px] font-medium">DocsGPT</div>
|
||||
</div>
|
||||
</div>
|
||||
<DeleteConvModal
|
||||
|
||||
@@ -2,11 +2,11 @@ import { Link } from 'react-router-dom';
|
||||
|
||||
export default function PageNotFound() {
|
||||
return (
|
||||
<div className="grid min-h-screen dark:bg-raisin-black">
|
||||
<p className="mx-auto my-auto mt-20 flex w-full max-w-6xl flex-col place-items-center gap-6 rounded-3xl bg-gray-100 p-6 text-jet dark:bg-outer-space dark:text-gray-100 lg:p-10 xl:p-16">
|
||||
<div className="dark:bg-raisin-black grid min-h-screen">
|
||||
<p className="text-jet dark:bg-outer-space mx-auto my-auto mt-20 flex w-full max-w-6xl flex-col place-items-center gap-6 rounded-3xl bg-gray-100 p-6 lg:p-10 xl:p-16 dark:text-gray-100">
|
||||
<h1>404</h1>
|
||||
<p>The page you are looking for does not exist.</p>
|
||||
<button className="pointer-cursor mr-4 flex cursor-pointer items-center justify-center rounded-full bg-blue-1000 px-4 py-2 text-white transition-colors duration-100 hover:bg-blue-3000">
|
||||
<button className="pointer-cursor bg-blue-1000 hover:bg-blue-3000 mr-4 flex cursor-pointer items-center justify-center rounded-full px-4 py-2 text-white transition-colors duration-100">
|
||||
<Link to="/">Go Back Home</Link>
|
||||
</button>
|
||||
</p>
|
||||
|
||||
@@ -54,7 +54,7 @@ export default function AgentCard({
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`relative flex h-44 w-48 flex-col justify-between rounded-[1.2rem] bg-[#F6F6F6] px-6 py-5 hover:bg-[#ECECEC] dark:bg-[#383838] hover:dark:bg-[#383838]/80 ${
|
||||
className={`relative flex h-44 w-48 flex-col justify-between rounded-[1.2rem] bg-[#F6F6F6] px-6 py-5 hover:bg-[#ECECEC] dark:bg-[#383838] dark:hover:bg-[#383838]/80 ${
|
||||
agent.status === 'published' ? 'cursor-pointer' : ''
|
||||
}`}
|
||||
onClick={handleCardClick}
|
||||
@@ -65,7 +65,7 @@ export default function AgentCard({
|
||||
e.stopPropagation();
|
||||
setIsMenuOpen(true);
|
||||
}}
|
||||
className="absolute right-4 top-4 z-10 cursor-pointer"
|
||||
className="absolute top-4 right-4 z-10 cursor-pointer"
|
||||
>
|
||||
<img src={ThreeDots} alt="options" className="h-[19px] w-[19px]" />
|
||||
{menuOptions && (
|
||||
@@ -83,9 +83,9 @@ export default function AgentCard({
|
||||
<div className="w-full">
|
||||
<div className="flex w-full items-center gap-1 px-1">
|
||||
<img
|
||||
src={agent.image ?? Robot}
|
||||
src={agent.image && agent.image.trim() !== '' ? agent.image : Robot}
|
||||
alt={`${agent.name}`}
|
||||
className="h-7 w-7 rounded-full"
|
||||
className="h-7 w-7 rounded-full object-contain"
|
||||
/>
|
||||
{agent.status === 'draft' && (
|
||||
<p className="text-xs text-black opacity-50 dark:text-[#E0E0E0]">
|
||||
@@ -96,11 +96,11 @@ export default function AgentCard({
|
||||
<div className="mt-2">
|
||||
<p
|
||||
title={agent.name}
|
||||
className="truncate px-1 text-[13px] font-semibold capitalize leading-relaxed text-[#020617] dark:text-[#E0E0E0]"
|
||||
className="truncate px-1 text-[13px] leading-relaxed font-semibold text-[#020617] capitalize dark:text-[#E0E0E0]"
|
||||
>
|
||||
{agent.name}
|
||||
</p>
|
||||
<p className="mt-1 h-20 overflow-auto px-1 text-[12px] leading-relaxed text-[#64748B] dark:text-sonic-silver-light">
|
||||
<p className="dark:text-sonic-silver-light mt-1 h-20 overflow-auto px-1 text-[12px] leading-relaxed text-[#64748B]">
|
||||
{agent.description}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user