Compare commits

..

175 Commits

Author SHA1 Message Date
Alex
ed0063aada chore: handlers tests 2026-03-30 12:53:50 +01:00
Alex
3f6d6f15ea Merge pull request #2338 from arc53/tests-utils
chore: utils tests
2026-03-29 11:54:59 +01:00
Alex
126fa01b14 chore: utils tests 2026-03-29 11:49:35 +01:00
Alex
e06debad5f chore: connector tests 2026-03-29 10:32:48 +01:00
Alex
6492852f7d fix: lint on seeder test 2026-03-29 09:48:46 +01:00
Alex
00a621f33a tests: retriever and seeder 2026-03-29 09:46:07 +01:00
Alex
e92ffc6fdc Merge pull request #2337 from arc53/tests-api
chore: api and tool tests
2026-03-28 22:55:10 +00:00
Alex
fe185e5b8d chore: api and tool tests 2026-03-28 21:51:47 +00:00
Alex
9f3d9ab860 docs: agent types 2026-03-28 18:50:50 +00:00
Alex
1c0adde380 chore: docs update 2026-03-28 17:04:06 +00:00
Alex
3c56bd0d0b docs: fix conflicts 2026-03-28 16:16:15 +00:00
Alex
86664ebda2 Merge pull request #2320 from Alex-wuhu/novita-integration
feat: complete Novita AI provider integration
2026-03-28 13:22:02 +00:00
Alex
db18b743d1 fix tests 2 2026-03-28 13:10:41 +00:00
Alex
9e85cc9065 fix: test errors 2026-03-28 13:04:21 +00:00
Alex
aaaa6f002d Merge branch 'main' of https://github.com/arc53/DocsGPT 2026-03-28 12:03:27 +00:00
Alex
47dcbcb74b fix: tests and sources on workflow agent 2026-03-28 12:03:16 +00:00
Alex
ddbfd94193 Adjust demo GIF height in README
Update the height of the demo GIF in README.
2026-03-28 11:09:05 +00:00
Alex
8dec60ab8b Update demo GIF in README
Replaced demo video GIF in README with a new example.
2026-03-28 11:08:10 +00:00
Alex
84b2e4bab4 fix: end node multi input 2026-03-28 10:00:01 +00:00
Alex
2afdd7f026 fix: enable tools in workflow agents 2026-03-27 13:43:09 +00:00
Alex
f364475f64 Merge pull request #2335 from arc53/tests-worker
tests: worker coverage
2026-03-27 12:14:56 +00:00
Alex
b254de6ed6 tests: worker coverage 2026-03-27 12:03:55 +00:00
Alex
08dedcaf95 Merge pull request #2334 from arc53/tests-vectors
tests: vectors
2026-03-26 19:11:20 +00:00
Alex
c726eb8ebd fix: ruff 2026-03-26 18:48:53 +00:00
Alex
5f0d39e5f1 tests: vectors 2026-03-26 18:42:59 +00:00
Alex
8c82fc5495 Merge pull request #2333 from arc53/chore/bump-npm-v0.6.3
chore: bump npm libraries to v0.6.3
2026-03-26 14:17:09 +00:00
github-actions[bot]
6d81a15e97 chore: bump npm libraries to v0.6.3 2026-03-26 14:16:32 +00:00
Alex
5478e4234c chore: bump npm again 2026-03-26 14:06:05 +00:00
Alex
4056278fef chore: bump npm 2026-03-26 13:48:05 +00:00
Alex
ee6530fe00 chore: bump npm 2026-03-26 13:42:51 +00:00
Alex
7c1decbcc3 fix: npm publish 2026-03-26 13:39:09 +00:00
Alex
8a3c724b31 Merge pull request #2332 from arc53/fix-fallbacks-onstream
Fix fallbacks onstream
2026-03-26 13:18:32 +00:00
Alex
15d4e9dbf5 fix: strict json in openai base 2026-03-26 13:08:22 +00:00
Alex
f7bfd38b28 fix: proper fallback handling within agent during stream 2026-03-26 12:52:30 +00:00
Alex
187e5da61e Merge pull request #2328 from ManishMadan2882/main
Chore(React widget): automate publish via actions
2026-03-26 11:15:52 +00:00
Alex
175ed58d2e Merge pull request #2327 from arc53/research-agent
Research agent
2026-03-26 11:00:23 +00:00
Alex
820ee3a843 fix test 2026-03-25 22:53:09 +00:00
Alex
462f2e9494 mini refactors 2026-03-25 22:34:25 +00:00
ManishMadan2882
c4968a641e (chore)react-widget: add linter 2026-03-26 02:06:49 +05:30
Alex
c6ece177cd fix: small issues 2026-03-25 20:04:03 +00:00
ManishMadan2882
a3e6a5622d (chore)react-widget: build, publish act 2026-03-26 01:03:09 +05:30
Alex
e8d11fdfa6 feat: list files internal tool 2026-03-25 19:21:46 +00:00
Alex
72393dc369 feat: improve research 2026-03-25 17:42:24 +00:00
Alex
556b0a1da5 feat: research init 2026-03-25 15:16:18 +00:00
Alex
32c268a21e refactor: simplify agent architecture and remove ReActAgent 2026-03-25 12:47:17 +00:00
Alex
ed34c2b929 fix: tool name in agent builder 2026-03-25 11:29:31 +00:00
Alex
06e827573c fix: jinja 2026-03-25 00:03:42 +00:00
Manish Madan
74e76d4cda Merge pull request #2323 from arc53/dependabot/npm_and_yarn/extensions/react-widget/styled-components-6.3.12
chore(deps): bump styled-components from 6.1.11 to 6.3.12 in /extensions/react-widget
2026-03-25 01:44:55 +05:30
ManishMadan2882
db5c69ca76 (chore)react widget: build fix 2026-03-25 01:37:09 +05:30
dependabot[bot]
05aa9d7cca chore(deps): bump styled-components in /extensions/react-widget
Bumps [styled-components](https://github.com/styled-components/styled-components) from 6.1.11 to 6.3.12.
- [Release notes](https://github.com/styled-components/styled-components/releases)
- [Commits](https://github.com/styled-components/styled-components/compare/v6.1.11...styled-components@6.3.12)

---
updated-dependencies:
- dependency-name: styled-components
  dependency-version: 6.3.12
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-23 20:53:52 +00:00
Manish Madan
dcececd118 Merge pull request #2298 from arc53/dependabot/npm_and_yarn/extensions/react-widget/flow-bin-0.305.1
chore(deps): bump flow-bin from 0.305.0 to 0.305.1 in /extensions/react-widget
2026-03-23 19:26:21 +05:30
Alex-wuhu
eaf39bb15b feat: add Novita AI as LLM provider
Add Novita AI (https://novita.ai) as a new LLM provider option.
Novita offers OpenAI-compatible API endpoints with competitive pricing.
2026-03-23 10:52:26 +08:00
dependabot[bot]
6515481624 chore(deps): bump flow-bin in /extensions/react-widget
Bumps [flow-bin](https://github.com/flowtype/flow-bin) from 0.305.0 to 0.305.1.
- [Release notes](https://github.com/flowtype/flow-bin/releases)
- [Commits](https://github.com/flowtype/flow-bin/commits)

---
updated-dependencies:
- dependency-name: flow-bin
  dependency-version: 0.305.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-22 20:29:46 +00:00
Manish Madan
6a7e3b6d77 Merge pull request #2165 from arc53/dependabot/pip/extensions/slack-bot/pip-75f0befae6
chore(deps): bump h11 from 0.14.0 to 0.16.0 in /extensions/slack-bot in the pip group across 1 directory
2026-03-23 01:48:58 +05:30
dependabot[bot]
02804fecce chore(deps): bump h11
Bumps the pip group with 1 update in the /extensions/slack-bot directory: [h11](https://github.com/python-hyper/h11).


Updates `h11` from 0.14.0 to 0.16.0
- [Commits](https://github.com/python-hyper/h11/compare/v0.14.0...v0.16.0)

---
updated-dependencies:
- dependency-name: h11
  dependency-version: 0.16.0
  dependency-type: direct:production
  dependency-group: pip
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-22 19:41:56 +00:00
Alex
ce5cd5561a loading animation 2026-03-19 15:52:27 +00:00
Manish Madan
adeefce9aa (fix)widget: since v6 shouldForwardProp is no longer provided by default (#2312) 2026-03-18 15:32:36 +00:00
Alex
5ab43fd12c fix: lang overflows, sst (#2314) 2026-03-18 14:50:29 +00:00
Alex
5894e47189 Agents.md 2026-03-18 00:34:26 +00:00
Manish Madan
ca61d81f4a Merge pull request #2154 from arc53/dependabot/npm_and_yarn/extensions/react-widget/typescript-5.9.3
chore(deps-dev): bump typescript from 5.4.5 to 5.9.3 in /extensions/react-widget
2026-03-18 01:43:16 +05:30
Alex
b12d0ca7b1 screenshot memo on contributing md 2026-03-17 19:45:37 +00:00
Alex
21996af626 stt init (#2306)
* stt init

* fix: limits

* fix: errors

* fix: error messages
2026-03-17 14:27:48 +00:00
ManishMadan2882
cc3b174e5a Merge branch 'dependabot/npm_and_yarn/extensions/react-widget/typescript-5.9.3' of https://github.com/arc53/docsgpt into dependabot/npm_and_yarn/extensions/react-widget/typescript-5.9.3 2026-03-16 17:08:06 +05:30
dependabot[bot]
faee58fb1e chore(deps-dev): bump typescript in /extensions/react-widget
Bumps [typescript](https://github.com/microsoft/TypeScript) from 5.4.5 to 5.9.3.
- [Release notes](https://github.com/microsoft/TypeScript/releases)
- [Changelog](https://github.com/microsoft/TypeScript/blob/main/azure-pipelines.release-publish.yml)
- [Commits](https://github.com/microsoft/TypeScript/compare/v5.4.5...v5.9.3)

---
updated-dependencies:
- dependency-name: typescript
  dependency-version: 5.9.3
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-16 11:37:27 +00:00
Manish Madan
d439e48b39 Merge pull request #2151 from arc53/dependabot/npm_and_yarn/extensions/react-widget/flow-bin-0.290.0
chore(deps): bump flow-bin from 0.229.2 to 0.290.0 in /extensions/react-widget
2026-03-16 17:05:33 +05:30
ManishMadan2882
3f0f155d64 Merge branch 'dependabot/npm_and_yarn/extensions/react-widget/flow-bin-0.290.0' of https://github.com/arc53/docsgpt into dependabot/npm_and_yarn/extensions/react-widget/flow-bin-0.290.0 2026-03-16 17:02:35 +05:30
dependabot[bot]
d82d512319 chore(deps): bump flow-bin in /extensions/react-widget
Bumps [flow-bin](https://github.com/flowtype/flow-bin) from 0.229.2 to 0.290.0.
- [Release notes](https://github.com/flowtype/flow-bin/releases)
- [Commits](https://github.com/flowtype/flow-bin/commits)

---
updated-dependencies:
- dependency-name: flow-bin
  dependency-version: 0.290.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-16 11:29:26 +00:00
Manish Madan
76aea1716f Merge pull request #2150 from arc53/dependabot/npm_and_yarn/extensions/react-widget/babel-loader-10.0.0
chore(deps-dev): bump babel-loader from 8.3.0 to 10.0.0 in /extensions/react-widget
2026-03-16 16:57:38 +05:30
ManishMadan2882
586649b73f (chore) react-widget: peer deps update 2026-03-16 16:56:02 +05:30
dependabot[bot]
0349a79cb3 chore(deps-dev): bump babel-loader in /extensions/react-widget
Bumps [babel-loader](https://github.com/babel/babel-loader) from 8.3.0 to 10.0.0.
- [Release notes](https://github.com/babel/babel-loader/releases)
- [Changelog](https://github.com/babel/babel-loader/blob/main/CHANGELOG.md)
- [Commits](https://github.com/babel/babel-loader/compare/v8.3.0...v10.0.0)

---
updated-dependencies:
- dependency-name: babel-loader
  dependency-version: 10.0.0
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-15 21:30:12 +00:00
Manish Madan
78a255bdd7 Merge pull request #2148 from arc53/dependabot/npm_and_yarn/extensions/react-widget/radix-ui/react-icons-1.3.2
chore(deps): bump @radix-ui/react-icons from 1.3.0 to 1.3.2 in /extensions/react-widget
2026-03-16 02:58:35 +05:30
ManishMadan2882
5b30e71aa1 (chore)deps fe and react-widget: audit fix 2026-03-16 02:24:02 +05:30
ManishMadan2882
99d84aece9 Merge branch 'dependabot/npm_and_yarn/extensions/react-widget/radix-ui/react-icons-1.3.2' of https://github.com/arc53/docsgpt into dependabot/npm_and_yarn/extensions/react-widget/radix-ui/react-icons-1.3.2 2026-03-16 01:57:14 +05:30
dependabot[bot]
525d8eb66d chore(deps): bump @radix-ui/react-icons in /extensions/react-widget
Bumps @radix-ui/react-icons from 1.3.0 to 1.3.2.

---
updated-dependencies:
- dependency-name: "@radix-ui/react-icons"
  dependency-version: 1.3.2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-15 20:25:18 +00:00
Manish Madan
4c810108e0 Merge pull request #2145 from arc53/dependabot/npm_and_yarn/frontend/lint-staged-16.2.6
chore(deps-dev): bump lint-staged from 15.5.2 to 16.2.6 in /frontend
2026-03-16 01:41:34 +05:30
dependabot[bot]
fc03cdc76a chore(deps-dev): bump lint-staged from 15.5.2 to 16.2.6 in /frontend
Bumps [lint-staged](https://github.com/lint-staged/lint-staged) from 15.5.2 to 16.2.6.
- [Release notes](https://github.com/lint-staged/lint-staged/releases)
- [Changelog](https://github.com/lint-staged/lint-staged/blob/main/CHANGELOG.md)
- [Commits](https://github.com/lint-staged/lint-staged/compare/v15.5.2...v16.2.6)

---
updated-dependencies:
- dependency-name: lint-staged
  dependency-version: 16.2.6
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-15 20:09:03 +00:00
Manish Madan
9779a563f3 Merge pull request #2144 from arc53/dependabot/npm_and_yarn/frontend/vitejs/plugin-react-5.1.0
chore(deps-dev): bump @vitejs/plugin-react from 4.7.0 to 5.1.0 in /frontend
2026-03-16 01:36:30 +05:30
ManishMadan2882
6141c3c348 (chore:deps) fe, vite up 2026-03-16 01:34:24 +05:30
dependabot[bot]
c3726ddfc9 chore(deps-dev): bump @vitejs/plugin-react in /frontend
Bumps [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/tree/HEAD/packages/plugin-react) from 4.7.0 to 5.1.0.
- [Release notes](https://github.com/vitejs/vite-plugin-react/releases)
- [Changelog](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/CHANGELOG.md)
- [Commits](https://github.com/vitejs/vite-plugin-react/commits/plugin-react@5.1.0/packages/plugin-react)

---
updated-dependencies:
- dependency-name: "@vitejs/plugin-react"
  dependency-version: 5.1.0
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-15 19:03:36 +00:00
Manish Madan
10eaa8143e Merge pull request #2143 from arc53/dependabot/npm_and_yarn/frontend/tailwindcss-4.1.17
chore(deps-dev): bump tailwindcss from 4.1.16 to 4.1.17 in /frontend
2026-03-16 00:31:39 +05:30
dependabot[bot]
0c4f4e1f0c chore(deps-dev): bump tailwindcss from 4.1.16 to 4.1.17 in /frontend
Bumps [tailwindcss](https://github.com/tailwindlabs/tailwindcss/tree/HEAD/packages/tailwindcss) from 4.1.16 to 4.1.17.
- [Release notes](https://github.com/tailwindlabs/tailwindcss/releases)
- [Changelog](https://github.com/tailwindlabs/tailwindcss/blob/main/CHANGELOG.md)
- [Commits](https://github.com/tailwindlabs/tailwindcss/commits/v4.1.17/packages/tailwindcss)

---
updated-dependencies:
- dependency-name: tailwindcss
  dependency-version: 4.1.17
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-15 18:56:49 +00:00
Manish Madan
b225c3cd80 Merge pull request #2149 from arc53/dependabot/npm_and_yarn/frontend/i18next-25.6.1
chore(deps): bump i18next from 25.6.0 to 25.6.1 in /frontend
2026-03-16 00:24:57 +05:30
dependabot[bot]
b558645d6b chore(deps): bump i18next from 25.6.0 to 25.6.1 in /frontend
Bumps [i18next](https://github.com/i18next/i18next) from 25.6.0 to 25.6.1.
- [Release notes](https://github.com/i18next/i18next/releases)
- [Changelog](https://github.com/i18next/i18next/blob/master/CHANGELOG.md)
- [Commits](https://github.com/i18next/i18next/compare/v25.6.0...v25.6.1)

---
updated-dependencies:
- dependency-name: i18next
  dependency-version: 25.6.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-15 18:48:31 +00:00
Manish Madan
03b0889b15 Merge pull request #2152 from arc53/dependabot/npm_and_yarn/frontend/react-syntax-highlighter-16.1.0
chore(deps): bump react-syntax-highlighter from 15.6.6 to 16.1.0 in /frontend
2026-03-15 19:58:29 +05:30
dependabot[bot]
943fe3651c chore(deps): bump react-syntax-highlighter in /frontend
Bumps [react-syntax-highlighter](https://github.com/react-syntax-highlighter/react-syntax-highlighter) from 15.6.6 to 16.1.0.
- [Release notes](https://github.com/react-syntax-highlighter/react-syntax-highlighter/releases)
- [Changelog](https://github.com/react-syntax-highlighter/react-syntax-highlighter/blob/master/CHANGELOG.MD)
- [Commits](https://github.com/react-syntax-highlighter/react-syntax-highlighter/compare/v15.6.6...v16.1.0)

---
updated-dependencies:
- dependency-name: react-syntax-highlighter
  dependency-version: 16.1.0
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-13 22:58:44 +00:00
Siddhant Rai
65e57be4dd feat: dynamic config rendering + mcp tool enhancement (#2286)
* feat: enhance modal functionality and configuration handling

- Updated WrapperModal to improve click outside detection for closing the modal.
- Refactored ToolConfig to utilize ConfigFieldSpec for better configuration management.
- Added validation and dynamic handling of configuration fields in ToolConfig.
- Introduced reconnect functionality for MCP tools in the Tools component.
- Enhanced user experience with improved error handling and loading states.
- Updated types for better type safety and clarity in configuration requirements.

* refactor: reorganize imports and improve conditional formatting

* fix: revert API_URL to use backend service name in docker-compose

* feat: add MCP auth status endpoint and integrate into user service and tools

* feat: implement logging for Brave, Postgres, and Telegram tools; add transport sanitization and credential extraction for MCP

---------

Co-authored-by: Alex <a@tushynski.me>
2026-03-13 15:58:50 +00:00
Siddhant Rai
13ad3b5dce feat: enhance logging and error handling across various tools; update DuckDuckGo dependency (#2282)
Co-authored-by: Alex <a@tushynski.me>
2026-03-12 16:50:29 +00:00
Manish Madan
918bbf0369 Sharepoint (#2283)
* feat: add Microsoft Entra ID integration

- Updated .env-template and settings.py for Microsoft Entra ID configuration.
- Enhanced ConnectorsCallback to support SharePoint authentication.
- Introduced SharePointAuth and SharePointLoader classes.
- Added required dependencies in requirements.txt.

* feat: agent templates and seeding premade agents (#1910)

* feat: agent templates and seeding premade agents

* fix: ensure ObjectId is used for source reference in agent configuration

* fix: improve source handling in DatabaseSeeder and update tool config processing

* feat: add prompt handling in DatabaseSeeder for agent configuration

* Docs premade agents

* link to prescraped docs

* feat: add template agent retrieval and adopt agent functionality

* feat: simplify agent descriptions in premade_agents.yaml  added docs

---------

Co-authored-by: Pavel <pabin@yandex.ru>
Co-authored-by: Alex <a@tushynski.me>

* feat: add GitHub access token support and fix file content fetching logic (#2032)

* feat: add init for Share Point connector module

* chore(deps): bump mermaid from 11.6.0 to 11.12.0 in /frontend

Bumps [mermaid](https://github.com/mermaid-js/mermaid) from 11.6.0 to 11.12.0.
- [Release notes](https://github.com/mermaid-js/mermaid/releases)
- [Commits](https://github.com/mermaid-js/mermaid/compare/mermaid@11.6.0...mermaid@11.12.0)

---
updated-dependencies:
- dependency-name: mermaid
  dependency-version: 11.12.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Feat: Notification section (#2033)

* Feature/Notification-section

* fix notification ui and add local storage variable to save the state

* add notification component to app.tsx

* refactor: remove MICROSOFT_REDIRECT_URI and update SharePointAuth to use CONNECTOR_REDIRECT_BASE_URI

* feat: Add button to cancel LLM response (#1978)

* feat: Add button to cancel LLM response
- Replace text area with cancel button when loading.
- Add useEffect to change elipsis in cancel button text.
- Add new SVG icon for cancel response.
- Button colors match Figma designs.

* fix: Cancel button UI matches new design
- Delete cancel-response svg.
- Change previous cancel button to match the new Figma design.
- Remove console log in handleCancel function.

* fix: Adjust cancel button rounding

* feat: Update UI for send button
- Add SendArrowIcon component, enables dynamic svg color changes
- Replace original icon
- Update colors and hover effects

* (fix:send-button) minor blink in transition

---------

Co-authored-by: Manish Madan <manishmadan321@gmail.com>

* feat: add SharePoint integration with session validation and UI components

* (feat:oneDrive) file loading for ingestion

* feat(oneDrive): shared user files

* (feat:oneDrive) rm shared file support, as sharedWithMe is degraded

* (feat:sharepoint) shared files for work msa

* (feat:sharepoint) retry on auth failure, decorator

* (fix) tests/ruff

* test: fix sharepoint loader expecting client id

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: Abhishek Malviya <abfeb8@gmail.com>
Co-authored-by: Siddhant Rai <47355538+siiddhantt@users.noreply.github.com>
Co-authored-by: Pavel <pabin@yandex.ru>
Co-authored-by: Alex <a@tushynski.me>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Mariam Saeed <69825646+Mariam-Saeed@users.noreply.github.com>
Co-authored-by: Rahul <rahulgithub96@gmail.com>
2026-03-12 14:46:26 +00:00
Alex
5006271abb fix stream stuff (#2293) 2026-03-11 11:43:27 +00:00
Alex
a6625ec5de fix: mini workflow fixes 2026-02-22 11:10:42 +00:00
Alex
1a2104f474 fix: token calc (#2285) 2026-02-20 17:37:47 +00:00
Alex
444abb8283 fix search nextra 2026-02-18 18:03:03 +00:00
Alex
ee86537f21 docs: add llms.txt and enable copy code button in nextra 2026-02-18 17:54:25 +00:00
Alex
17a736a927 docs: migrate to Nextra 4 and Next.js App Router 2026-02-18 17:13:24 +00:00
Alex
6b5779054d Merge branch 'main' of https://github.com/arc53/DocsGPT 2026-02-17 18:46:35 +00:00
Alex
14296632ef build(docs): upgrade nextra to v3 and update config 2026-02-17 18:46:28 +00:00
Siddhant Rai
2a3f0e455a feat: condition node functionality with CEL evaluation in Workflows (#2280)
* feat: add condition node functionality with CEL evaluation

- Introduced ConditionNode to support conditional branching in workflows.
- Implemented CEL evaluation for state updates and condition expressions.
- Updated WorkflowEngine to handle condition nodes and their execution logic.
- Enhanced validation for workflows to ensure condition nodes have at least two outgoing edges and valid expressions.
- Modified frontend components to support new condition node type and its configuration.
- Added necessary types and interfaces for condition cases and state operations.
- Updated requirements to include cel-python for expression evaluation.

* mini-fixes

* feat(workflow): improve UX

---------

Co-authored-by: Alex <a@tushynski.me>
2026-02-17 17:29:48 +00:00
Pavel
8aa44c415b Advanced settings (#2281)
Add additional settings to setup scripts
2026-02-17 11:54:59 +00:00
Manish Madan
0566c41a32 Merge pull request #2279 from yusufazam225/main
Fix XSS vulnerability: replace dangerouslySetInnerHTML with safe React rendering in PromptsModal
2026-02-15 19:27:37 +05:30
Manish Madan
876b04c058 Artifacts-backed persistence for Agent “Self” tools (Notes / Todo) + streaming artifact_id support (#2267)
* (feat:memory) use fs/storage for files

* (feat:todo) artifact_id via sse

* (feat:notes) artifact id return

* (feat:artifact) add get endpoint, store todos with conv id

* (feat: artifacts) fe integration

* feat(artifacts): ui enhancements, notes as mkdwn

* chore(artifacts) updated artifact tests

* (feat:todo_tool) return all todo items

* (feat:tools) use specific tool names in bubble

* feat: add conversationId prop to artifact components in Conversation

* Revert "(feat:memory) use fs/storage for files"

This reverts commit d1ce3bea31.

* (fix:fe) build fail
2026-02-15 00:08:37 +00:00
Yusuf Azam
b49a5934e2 Fix XSS vulnerability: replace dangerouslySetInnerHTML with safe React rendering 2026-02-14 17:18:52 +05:30
Alex
5fb063914e fix: frontend lib 2026-02-12 16:01:49 +00:00
Pavel
b9941e29a9 Scrollbar normalize styles (#2277)
* normalize styles

* Fix agent subhead width

* Dark scrollbar color adjust

* different browser support

---------

Co-authored-by: Alex <a@tushynski.me>
2026-02-12 15:19:42 +00:00
Siddhant Rai
8ef321d784 feat: agent workflow builder (#2264)
* feat: implement WorkflowAgent and GraphExecutor for workflow management and execution

* refactor: workflow schemas and introduce WorkflowEngine

- Updated schemas in `schemas.py` to include new agent types and configurations.
- Created `WorkflowEngine` class in `workflow_engine.py` to manage workflow execution.
- Enhanced `StreamProcessor` to handle workflow-related data.
- Added new routes and utilities for managing workflows in the user API.
- Implemented validation and serialization functions for workflows.
- Established MongoDB collections and indexes for workflows and related entities.

* refactor: improve WorkflowAgent documentation and update type hints in WorkflowEngine

* feat: workflow builder and managing in frontend

- Added new endpoints for workflows in `endpoints.ts`.
- Implemented `getWorkflow`, `createWorkflow`, and `updateWorkflow` methods in `userService.ts`.
- Introduced new UI components for alerts, buttons, commands, dialogs, multi-select, popovers, and selects.
- Enhanced styling in `index.css` with new theme variables and animations.
- Refactored modal components for better layout and styling.
- Configured TypeScript paths and Vite aliases for cleaner imports.

* feat: add workflow preview component and related state management

- Implemented WorkflowPreview component for displaying workflow execution.
- Created WorkflowPreviewSlice for managing workflow preview state, including queries and execution steps.
- Added WorkflowMiniMap for visual representation of workflow nodes and their statuses.
- Integrated conversation handling with the ability to fetch answers and manage query states.
- Introduced reusable Sheet component for UI overlays.
- Updated Redux store to include workflowPreview reducer.

* feat: enhance workflow execution details and state management in WorkflowEngine and WorkflowPreview

* feat: enhance workflow components with improved UI and functionality

- Updated WorkflowPreview to allow text truncation for better display of long names.
- Enhanced BaseNode with connectable handles and improved styling for better visibility.
- Added MobileBlocker component to inform users about desktop requirements for the Workflow Builder.
- Introduced PromptTextArea component for improved variable insertion and search functionality, including upstream variable extraction and context addition.

* feat(workflow): add owner validation and graph version support

* fix: ruff lint

---------

Co-authored-by: Alex <a@tushynski.me>
2026-02-11 14:15:24 +00:00
Alex
8353f9c649 docs: add guide for OCR configuration and usage 2026-02-10 15:54:27 +00:00
Alex
cb6b3aa406 fix: test google ai 2026-02-09 14:37:36 +00:00
Alex
36c7bd9206 Thinking stream (#2276)
* feat: stream thinking tokens

* fix: retry bug

* fix test
2026-02-09 14:27:53 +00:00
Alex
fea94379d7 feat: stream thinking tokens (#2275) 2026-02-09 13:46:27 +00:00
Alex
e602d941ca fix: sources display (#2274)
* fix: sources display

* fix: sources display2
2026-02-05 19:40:35 +00:00
Alex
f41f69a268 docs: add specific Celery startup command for macOS users 2026-02-03 17:33:13 +00:00
Manish Madan
ff72251878 Merge pull request #2268 from IRjSI/frontend-fix
fix(frontend): fix the input styling on renaming chat
2026-02-03 16:57:43 +05:30
Alex
7751fb52dd fix: neon docs mention 2026-02-03 00:23:37 +00:00
Alex
87a44d101d Update link in README for Neon documentation 2026-02-03 00:16:49 +00:00
Alex
80148f25b6 Update image source in README.md 2026-02-03 00:11:53 +00:00
Alex
8e3e4a8b09 fix: stio mcp 2026-02-02 15:46:19 +00:00
–IRjSI
9389b4a1e8 fix(frontend): fix the input styling on renaming chat 2026-01-27 22:33:28 +05:30
Pavel
4245e5bd2e End 2 end tests (#2266)
* All endpoints covered

test_integration.py kept for backwards compatability.
tests/integration/run_all.py proposed as alternative to cover all endpoints.

* Linter fixes
2026-01-22 13:11:24 +02:00
Pavel
e7d2af2405 Setup plus env fixes (#2265)
* fixes setup scripts

fixes to env handling in setup script plus other minor fixes

* Remove var declarations

Declarations such as `LLM_PROVIDER=$LLM_PROVIDER` override .env variables in compose

Similar issue is present in the frontend - need to choose either to switch to separate frontend env or keep as is.

* Manage apikeys in settings

1. More pydantic management of api keys.
2. Clean up of variable declarations from docker compose files, used to block .env imports. Now should be managed ether by settings.py defaults or .env
2026-01-22 12:21:01 +02:00
Alex
4c32a96370 rearrange settings 2026-01-22 00:51:59 +02:00
Alex
f61d112cea feat: process pdfs synthetically im model does not support file natively (#2263)
* feat: process pdfs synthetically im model does not support file natively

* fix: small code optimisations
2026-01-15 02:30:33 +02:00
Alex
2c55c6cd9a fix: tiktoken import in markdown parser 2026-01-12 23:04:20 +00:00
Alex
f1d714b5c1 fix(frontend): replace crypto.randomUUID with custom ID generator 2026-01-12 14:58:37 +00:00
Alex
69d9dc672a chore(settings): disable Docling OCR by default for text parsing 2026-01-12 12:01:54 +00:00
Alex
9192e010e8 build(docker): add g++ and python3.12-dev to system dependencies 2026-01-12 11:59:37 +00:00
Alex
f24cea0877 docs(models): update provider examples and add native llama.cpp info 2026-01-12 11:56:10 +00:00
Alex
a29bfa7489 fix simple routing (#2261) 2026-01-12 13:51:19 +02:00
Manish Madan
2246866a09 Feat: Agents grouped under folders (#2245)
* chore(dependabot): add react-widget npm dependency updates

* refactor(prompts): init on load, mv to pref slice

* (refactor): searchable dropdowns are separate

* (fix/ui) prompts adjust

* feat(changelog): dancing stars

* (fix)conversation: re-blink bubble past stream

* (fix)endless GET sources, esling err

* (feat:Agents) folders metadata

* (feat:agents) create new folder

* (feat:agent-management) ui

* feat:(agent folders) nesting/sub-folders

* feat:(agent folders)- closer the figma, inline folder inputs

* fix(delete behaviour) refetch agents on delete

* (fix:search) folder context missing

* fix(newAgent) preserve folder context

* feat(agent folders) id preserved im query, navigate

* feat(agents) mobile responsive

* feat(search/agents) lookup for nested agents as well

* (fix/modals) close on outside click

---------

Co-authored-by: GH Action - Upstream Sync <action@github.com>
2026-01-08 18:46:40 +02:00
Alex
7b17fde34a bump reqs 2026-01-08 11:53:07 +00:00
Alex
df57053613 feat: improve crawlers and update chunk filtering (#2250) 2026-01-06 00:52:12 +02:00
Alex
5662be12b5 feat(agent): implement context validation and message truncation (#2249) 2026-01-05 19:49:28 +02:00
Ankit Matth
d3e9d66b07 Fixed issue in models name (#2247) 2026-01-05 02:02:54 +02:00
Alex
e0bdbcbe38 Update README.md 2026-01-01 16:36:42 +02:00
Alex
05c835ed02 feat: enable OCR for docling when parsing attachments and update file extractor (#2246) 2025-12-31 02:08:49 +02:00
Alex
9e7f1ad1c0 Add Amazon S3 support and synchronization features (#2244)
* Add Amazon S3 support and synchronization features

* refactor: remove unused variable in load_data test
2025-12-30 20:26:51 +02:00
Alex
f910a82683 feat: add unauthorized response handling in StreamResource and bump deps 2025-12-27 14:23:37 +00:00
Alex
d8b7e86f8d bump dump (#2233) 2025-12-26 17:49:03 +02:00
Alex
aef3e0b4bb chore: update workflow permissions and fix paths in settings (#2227)
* chore: update workflow permissions and fix paths in settings

* dep

* dep upgraes
2025-12-25 14:26:01 +02:00
Alex
b0eee7be24 Patches (#2225)
* feat: implement URL validation to prevent SSRF

* feat: add zip extraction security

* ruff fixes

* fix: standardize error messages across API responses

* fix: improve error handling and standardize error messages across multiple routes

* fix: enhance JavaScript string safety in ConnectorCallbackStatus

* fix: improve OAuth error handling and message formatting in MCPOAuthCallback
2025-12-25 02:57:25 +02:00
Alex
197e94302b Patches (#2219)
* feat: implement URL validation to prevent SSRF

* feat: add zip extraction security

* ruff fixes

* fix: standardize error messages across API responses
2025-12-24 18:35:57 +02:00
Alex
98e949d2fd Patches (#2218)
* feat: implement URL validation to prevent SSRF

* feat: add zip extraction security

* ruff fixes
2025-12-24 17:05:35 +02:00
Alex
83e7a928f1 bump deps 2025-12-23 23:36:15 +00:00
Alex
ccd29b7d4e feat: implement Docling parsers (#2202)
* feat: implement Docling parsers

* fix office

* docling-ocr-fix

* Docling smart ocr

* ruff fix

---------

Co-authored-by: Pavel <pabin@yandex.ru>
2025-12-23 18:33:51 +02:00
Siddhant Rai
5b6cfa6ecc feat: enhance API tool with body serialization and content type handling (#2192)
* feat: enhance API tool with body serialization and content type handling

* feat: enhance ToolConfig with import functionality and user action management

- Added ImportSpecModal to allow importing actions into the tool configuration.
- Implemented search functionality for user actions with expandable action details.
- Introduced method colors for better visual distinction of HTTP methods.
- Updated APIActionType and ParameterGroupType to include optional 'required' field.
- Refactored action rendering to improve usability and maintainability.

* feat: add base URL input to ImportSpecModal for action URL customization

* feat: update TestBaseAgentTools to include 'required' field for parameters

* feat: standardize API call timeout to DEFAULT_TIMEOUT constant

* feat: add import specification functionality and related translations for multiple languages

---------

Co-authored-by: Alex <a@tushynski.me>
2025-12-23 15:37:44 +02:00
Akash Bhadana
f91846ce2d docs: Update VECTOR_STORE comment to include pgvector (#2211)
Co-authored-by: root <root@MD-CG-010>
2025-12-22 18:53:30 +02:00
dependabot[bot]
87e24ab96e chore(deps): bump elevenlabs from 2.26.1 to 2.27.0 in /application (#2203)
Bumps [elevenlabs](https://github.com/elevenlabs/elevenlabs-python) from 2.26.1 to 2.27.0.
- [Release notes](https://github.com/elevenlabs/elevenlabs-python/releases)
- [Commits](https://github.com/elevenlabs/elevenlabs-python/compare/v2.26.1...v2.27.0)

---
updated-dependencies:
- dependency-name: elevenlabs
  dependency-version: 2.27.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-22 12:25:45 +02:00
Alex
40c3e5568c fix search (#2210)
* fix search

* fix ruff
2025-12-22 00:51:06 +02:00
Yash
7958d29e13 Fix: Import external-link.svg properly in AgentDetailsModal (#2191) 2025-12-19 19:08:56 +02:00
Rahul Badade
a6fafa6a4d Fix: Autoselect input text box on pageload and conversation reset (#2177) (#2194)
* Fix: Autoselect input text box on pageload and conversation reset

- Added autoFocus to useEffect dependency array in MessageInput
- Added key prop to MessageInput to force remount on conversation reset
- Implemented refocus after message submission
- Removed duplicate input clearing logic in handleKeyDown

Fixes #2177

* fix: optimize input handling

---------

Co-authored-by: Alex <a@tushynski.me>
2025-12-19 18:57:57 +02:00
JustACodeA
3ad38f53fd fix: update Node.js version to 22 for Vite compatibility (#2169)
Updates the frontend Dockerfile from Node 20.6.1 to Node 22 to resolve
compatibility issues with Vite dependencies.

Closes #2157

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-19 18:29:58 +02:00
JustACodeA
d90b1c57e5 feat: add hover animation to conversation context menu button (#2168)
* feat: add hover animation to conversation context menu button

Adds visual feedback when hovering over the three-dot menu button in conversation tiles.
This makes it clear that the submenu is being targeted rather than the parent item.

Changes:
- Added rounded hover background with smooth transition
- Increased clickable area for better UX
- Supports both light and dark themes

Closes #2097

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Update frontend/src/conversation/ConversationTile.tsx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-19 18:25:29 +02:00
Alex
a69a0e100f fix: update dependencies in requirements.txt (#2201) 2025-12-19 02:17:59 +02:00
Alex
b0d4576a95 fix: improve error handling in agent webhook worker 2025-12-18 13:27:40 +00:00
Alex
2a4ab3aca1 Fix history leftover (#2198)
* fix: history leftover

* fix: unbound result
2025-12-17 16:07:14 +02:00
Alex
e0fd11a86e fix: bump next 2025-12-17 14:06:53 +00:00
Alex
de369f8b5e fix: history leftover (#2197) 2025-12-17 13:07:50 +02:00
Alex
af3e16c4fc fix: count history tokens from chunks, remove old UI setting limit (#2196) 2025-12-17 03:34:17 +02:00
Alex
aacf281222 fix: improve remote embeds (#2193) 2025-12-16 13:59:17 +02:00
AbbasSalloum
6d8f083c6f Adding a feature to paste files you ctrl v (#2183) 2025-12-12 11:55:16 +00:00
Mohamed-Abuali
909bc421c0 Bugfix/docs gpt widget behavior (#2172)
* style(DocsGPTWidget): improve message bubbles and markdown styling

- Adjust max-width for message bubbles to 90% for answers and 80% for questions
- Add overflow-wrap to prevent text overflow in messages
- Update list styling with proper spacing and positioning
- Add responsive font sizing for headings using clamp()
- Implement custom table styling with proper borders and spacing
- Add custom markdown renderer rules for tables

* feat(widget): replace input with textarea for prompt input

Add support for multi-line input and custom scrollbar styling. Implement Enter key submission handling while allowing Shift+Enter for new lines.

* feat(widget): improve textarea auto-resizing and table styling

- Add auto-resizing functionality for prompt textarea with min/max height constraints
- Fix table cell markup (th/td) and improve scrollbar styling
- Add promptRef to manage textarea state and reset after submission

* fix(widget): correct table cell styling and prevent empty submissions

- Fix swapped td/th elements in markdown renderer
- Adjust font weights for table headers and cells
- Add validation to prevent empty message submissions

* (fix) name mkdwn rule as the returned element

---------

Co-authored-by: ManishMadan2882 <manishmadan321@gmail.com>
2025-12-11 01:35:55 +00:00
Alex
d14f04d79c Update requirements.txt 2025-12-11 00:54:58 +02:00
Alex
e0a9f08632 refactor and deps (#2184) 2025-12-10 23:53:59 +02:00
Manish Madan
09e7c1b97f Fixes: re-blink in converstaion, (refactor) prompts and validate LocalStorage prompts (#2181)
* chore(dependabot): add react-widget npm dependency updates

* refactor(prompts): init on load, mv to pref slice

* (refactor): searchable dropdowns are separate

* (fix/ui) prompts adjust

* feat(changelog): dancing stars

* (fix)conversation: re-blink bubble past stream

* (fix)endless GET sources, esling err

---------

Co-authored-by: GH Action - Upstream Sync <action@github.com>
2025-12-10 23:53:40 +02:00
Alex
4adffe762a Update README.md 2025-12-08 16:59:08 +02:00
Alex
9a937d2686 Feat/small optimisation (#2182)
* optimised ram use + celery

* Remove VITE_EMBEDDINGS_NAME

* fix: timeout on remote embeds
2025-12-05 20:57:39 +02:00
Alex
e68da34c13 feat: implement internal API authentication mechanism 2025-12-04 15:52:45 +00:00
Siddhant Rai
9b9f95710a feat: agent search functionality with filters and loading states (#2179)
* feat: implement agent search functionality with filters and loading states

* style: improve layout and styling of agent search input and description
2025-12-04 17:46:37 +02:00
Alex
3352d42414 fix(frontend): use bracket notation for tool variable paths (#2176) 2025-11-26 19:12:02 +02:00
JustACodeA
899b30da5e feat: add German translation (#2170)
Adds complete German (Deutsch) language support to DocsGPT.

Changes:
- Add de.json with full German translations
- Register German in i18n configuration
- Add German to language selector dropdown

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-26 12:52:23 +02:00
Alex
dc2faf7a7e fix: webhooks (#2175) 2025-11-25 16:08:22 +02:00
Alex
67e0d222d1 fix: model in agents via api (#2174) 2025-11-25 13:54:34 +02:00
Alex
17698ce774 feat: context compression (#2173)
* feat: context compression

* fix: ruff
2025-11-24 12:44:19 +02:00
Alex
7d1c8c008b Update README.md 2025-11-22 16:42:25 +02:00
dependabot[bot]
2c2bdd37d5 chore(deps-dev): bump typescript in /extensions/react-widget
Bumps [typescript](https://github.com/microsoft/TypeScript) from 5.4.5 to 5.9.3.
- [Release notes](https://github.com/microsoft/TypeScript/releases)
- [Changelog](https://github.com/microsoft/TypeScript/blob/main/azure-pipelines.release-publish.yml)
- [Commits](https://github.com/microsoft/TypeScript/compare/v5.4.5...v5.9.3)

---
updated-dependencies:
- dependency-name: typescript
  dependency-version: 5.9.3
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-07 15:19:11 +00:00
dependabot[bot]
6a00319c2d chore(deps): bump flow-bin in /extensions/react-widget
Bumps [flow-bin](https://github.com/flowtype/flow-bin) from 0.229.2 to 0.290.0.
- [Release notes](https://github.com/flowtype/flow-bin/releases)
- [Commits](https://github.com/flowtype/flow-bin/commits)

---
updated-dependencies:
- dependency-name: flow-bin
  dependency-version: 0.290.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-07 15:19:00 +00:00
dependabot[bot]
66870279d3 chore(deps): bump @radix-ui/react-icons in /extensions/react-widget
Bumps @radix-ui/react-icons from 1.3.0 to 1.3.2.

---
updated-dependencies:
- dependency-name: "@radix-ui/react-icons"
  dependency-version: 1.3.2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-07 15:18:48 +00:00
478 changed files with 82048 additions and 16185 deletions

View File

@@ -1,9 +1,36 @@
API_KEY=<LLM api key (for example, open ai key)>
LLM_NAME=docsgpt
VITE_API_STREAMING=true
INTERNAL_KEY=<internal key for worker-to-backend authentication>
# Provider-specific API keys (optional - use these to enable multiple providers)
# OPENAI_API_KEY=<your-openai-api-key>
# ANTHROPIC_API_KEY=<your-anthropic-api-key>
# GOOGLE_API_KEY=<your-google-api-key>
# GROQ_API_KEY=<your-groq-api-key>
# NOVITA_API_KEY=<your-novita-api-key>
# OPEN_ROUTER_API_KEY=<your-openrouter-api-key>
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
EMBEDDINGS_BASE_URL=
EMBEDDINGS_KEY=
#For Azure (you can delete it if you don't use Azure)
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
#Azure AD Application (client) ID
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
#Azure AD Application client secret
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
#Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
#If you are using a Microsoft Entra ID tenant,
#configure the AUTHORITY variable as
#"https://login.microsoftonline.com/TENANT_GUID"
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}

View File

@@ -18,7 +18,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5

View File

@@ -21,7 +21,7 @@ jobs:
contents: read
packages: write
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'

View File

@@ -21,7 +21,7 @@ jobs:
contents: read
packages: write
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'

View File

@@ -23,7 +23,7 @@ jobs:
contents: read
packages: write
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -23,7 +23,7 @@ jobs:
contents: read
packages: write
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'

View File

@@ -7,11 +7,14 @@ on:
pull_request:
types: [ opened, synchronize ]
permissions:
contents: read
jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- name: Lint with Ruff
uses: chartboost/ruff-action@v1

114
.github/workflows/npm-publish.yml vendored Normal file
View File

@@ -0,0 +1,114 @@
name: Publish npm libraries
on:
workflow_dispatch:
inputs:
version:
description: >
Version bump type (patch | minor | major) or explicit semver (e.g. 1.2.3).
Applies to both docsgpt and docsgpt-react.
required: true
default: patch
permissions:
contents: write
pull-requests: write
jobs:
publish:
runs-on: ubuntu-latest
environment: npm-release
defaults:
run:
working-directory: extensions/react-widget
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: 20
registry-url: https://registry.npmjs.org
- name: Install dependencies
run: npm ci
# ── docsgpt (HTML embedding bundle) ──────────────────────────────────
# Uses the `build` script (parcel build src/browser.tsx) and keeps
# the `targets` field so Parcel produces browser-optimised bundles.
- name: Set package name → docsgpt
run: jq --arg n "docsgpt" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json
- name: Bump version (docsgpt)
id: version_docsgpt
run: |
VERSION="${{ github.event.inputs.version }}"
NEW_VER=$(npm version "${VERSION:-patch}" --no-git-tag-version)
echo "version=${NEW_VER#v}" >> "$GITHUB_OUTPUT"
- name: Build docsgpt
run: npm run build
- name: Publish docsgpt
run: npm publish --verbose
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
# ── docsgpt-react (React library bundle) ─────────────────────────────
# Uses `build:react` script (parcel build src/index.ts) and strips
# the `targets` field so Parcel treats the output as a plain library
# without browser-specific target resolution, producing a smaller bundle.
- name: Reset package.json from source control
run: git checkout -- package.json
- name: Set package name → docsgpt-react
run: jq --arg n "docsgpt-react" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json
- name: Remove targets field (react library build)
run: jq 'del(.targets)' package.json > _tmp.json && mv _tmp.json package.json
- name: Bump version (docsgpt-react) to match docsgpt
run: npm version "${{ steps.version_docsgpt.outputs.version }}" --no-git-tag-version
- name: Clean dist before react build
run: rm -rf dist
- name: Build docsgpt-react
run: npm run build:react
- name: Publish docsgpt-react
run: npm publish --verbose
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
# ── Commit the bumped version back to the repository ─────────────────
- name: Reset package.json and write final version
run: |
git checkout -- package.json
jq --arg v "${{ steps.version_docsgpt.outputs.version }}" '.version=$v' \
package.json > _tmp.json && mv _tmp.json package.json
npm install --package-lock-only
- name: Commit version bump and create PR
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
BRANCH="chore/bump-npm-v${{ steps.version_docsgpt.outputs.version }}"
git checkout -b "$BRANCH"
git add package.json package-lock.json
git commit -m "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}"
git push origin "$BRANCH"
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Create PR
run: |
gh pr create \
--title "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}" \
--body "Automated version bump after npm publish." \
--base main
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -1,5 +1,9 @@
name: Run python tests with pytest
on: [push, pull_request]
permissions:
contents: read
jobs:
pytest_and_coverage:
name: Run tests and count coverage
@@ -8,7 +12,7 @@ jobs:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:

View File

@@ -0,0 +1,34 @@
name: React Widget Build
on:
push:
paths:
- 'extensions/react-widget/**'
pull_request:
paths:
- 'extensions/react-widget/**'
permissions:
contents: read
jobs:
build:
runs-on: ubuntu-latest
defaults:
run:
working-directory: extensions/react-widget
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: 20
cache: npm
cache-dependency-path: extensions/react-widget/package-lock.json
- name: Install dependencies
run: npm ci
- name: Build
run: npm run build

View File

@@ -17,7 +17,7 @@ jobs:
steps:
# Step 1: run a standard checkout action
- name: Checkout target repo
uses: actions/checkout@v6
uses: actions/checkout@v4
# Step 2: run the sync action
- name: Sync upstream changes

View File

@@ -9,12 +9,16 @@ on:
- '.vale.ini'
- '.github/styles/**'
permissions:
contents: read
pull-requests: write
jobs:
vale:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Vale linter
uses: errata-ai/vale-action@v2

7
.gitignore vendored
View File

@@ -2,6 +2,8 @@
__pycache__/
*.py[cod]
*$py.class
results.txt
experiments/
experiments
# C extensions
@@ -70,6 +72,7 @@ instance/
# Sphinx documentation
docs/_build/
docs/public/_pagefind/
# PyBuilder
target/
@@ -146,6 +149,10 @@ frontend/yarn-error.log*
frontend/pnpm-debug.log*
frontend/lerna-debug.log*
# Keep frontend utility helpers tracked (overrides global lib/ ignore)
!frontend/src/lib/
!frontend/src/lib/**
frontend/node_modules
frontend/dist
frontend/dist-ssr

View File

@@ -1,2 +1,6 @@
# Allow lines to be as long as 120 characters.
line-length = 120
line-length = 120
[lint.per-file-ignores]
# Integration tests use sys.path.insert() before imports for standalone execution
"tests/integration/*.py" = ["E402"]

134
AGENTS.md Normal file
View File

@@ -0,0 +1,134 @@
# AGENTS.md
- Read `CONTRIBUTING.md` before making non-trivial changes.
- For day-to-day development and feature work, follow the development-environment workflow rather than defaulting to `setup.sh` / `setup.ps1`.
- Avoid using the setup scripts during normal feature work unless the user explicitly asks for them. Users configure `.env` usually.
- Try to follow red/green TDD
### Check existing dev prerequisites first
For feature work, do **not** assume the environment needs to be recreated.
- Check whether the user already has a Python virtual environment such as `venv/` or `.venv/`.
- Check whether MongoDB is already running.
- Check whether Redis is already running.
- Reuse what is already working. Do not stop or recreate MongoDB, Redis, or the Python environment unless the task is environment setup or troubleshooting.
## Normal local development commands
Use these commands once the dev prerequisites above are satisfied.
### Backend
```bash
source .venv/bin/activate # macOS/Linux
uv pip install -r application/requirements.txt # or: pip install -r application/requirements.txt
```
Run the Flask API (if needed):
```bash
flask --app application/app.py run --host=0.0.0.0 --port=7091
```
Run the Celery worker in a separate terminal (if needed):
```bash
celery -A application.app.celery worker -l INFO
```
On macOS, prefer the solo pool for Celery:
```bash
python -m celery -A application.app.celery worker -l INFO --pool=solo
```
### Frontend
Install dependencies only when needed, then run the dev server:
```bash
cd frontend
npm install --include=dev
npm run dev
```
### Docs site
```bash
cd docs
npm install
```
### Python / backend changes validation
```bash
ruff check .
python -m pytest
```
### Frontend changes
```bash
cd frontend && npm run lint
cd frontend && npm run build
```
### Documentation changes
```bash
cd docs && npm run build
```
If Vale is installed locally and you edited prose, also run:
```bash
vale .
```
## Repository map
- `application/`: Flask backend, API routes, agent logic, retrieval, parsing, security, storage, Celery worker, and WSGI entrypoints.
- `tests/`: backend unit/integration tests and test-only Python dependencies.
- `frontend/`: Vite + React + TypeScript application.
- `frontend/src/`: main UI code, including `components`, `conversation`, `hooks`, `locale`, `settings`, `upload`, and Redux store wiring in `store.ts`.
- `docs/`: separate documentation site built with Next.js/Nextra.
- `extensions/`: integrations and widgets such as Chatwoot, Chrome, Discord, React widget, Slack bot, and web widget.
- `deployment/`: Docker Compose variants and Kubernetes manifests.
## Coding rules
### Backend
- Follow PEP 8 and keep Python line length at or under 120 characters.
- Use type hints for function arguments and return values.
- Add Google-style docstrings to new or substantially changed functions and classes.
- Add or update tests under `tests/` for backend behavior changes.
- Keep changes narrow in `api`, `auth`, `security`, `parser`, `retriever`, and `storage` areas.
### Backend Abstractions
- LLM providers implement a common interface in `application/llm/` (add new providers by extending the base class).
- Vector stores are abstracted in `application/vectorstore/`.
- Parsers live in `application/parser/` and handle different document formats in the ingestion stage.
- Agents and tools are in `application/agents/` and `application/agents/tools/`.
- Celery setup/config lives in `application/celery_init.py` and `application/celeryconfig.py`.
- Settings and env vars are managed via Pydantic in `application/core/settings.py`.
### Frontend
- Follow the existing ESLint + Prettier setup.
- Prefer small, reusable functional components and hooks.
- If shared state must be added, use Redux rather than introducing a new global state library.
- Avoid broad UI refactors unless the task explicitly asks for them.
- Do not re-create components if we already have some in the app.
## PR readiness
Before opening a PR:
- run the relevant validation commands above
- confirm backend changes still work end-to-end after ingesting sample data when applicable
- clearly summarize user-visible behavior changes
- mention any config, dependency, or deployment implications
- Ask your user to attach a screenshot or a video to it

View File

@@ -22,6 +22,11 @@ Thank you for choosing to contribute to DocsGPT! We are all very grateful!
- We have a frontend built on React (Vite) and a backend in Python.
> **Required for every PR:** Please attach screenshots or a short screen
> recording that shows the working version of your changes. This makes the
> requirement visible to reviewers and helps them quickly verify what you are
> submitting.
Before creating issues, please check out how the latest version of our app looks and works by launching it via [Quickstart](https://github.com/arc53/DocsGPT#quickstart) the version on our live demo is slightly modified with login. Your issues should relate to the version you can launch via [Quickstart](https://github.com/arc53/DocsGPT#quickstart).
@@ -125,7 +130,7 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
```
9. **Submit a Pull Request (PR):**
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes and reference any related issues.
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes, reference any related issues, and attach screenshots or a screen recording showing the working version.
10. **Collaborate:**
- Be responsive to comments and feedback on your PR.

View File

@@ -7,7 +7,7 @@
</p>
<p align="left">
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content, and audio), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
</p>
<div align="center">
@@ -26,23 +26,17 @@
</div>
<div align="center">
<br>
🎃 <a href="https://github.com/arc53/DocsGPT/blob/main/HACKTOBERFEST.md"> Hacktoberfest Prizes, Rules & Q&A </a> 🎃
<br>
<br>
</div>
<div align="center">
<br>
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demo-26.gif" alt="video-example-of-docs-gpt" width="800" height="480">
</div>
<h3 align="left">
<strong>Key Features:</strong>
</h3>
<ul align="left">
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, and images.</li>
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, images, and audio files such as MP3, WAV, M4A, OGG, and WebM.</li>
<li><strong>🎙️ Speech Workflows:</strong> Record voice input into chat, transcribe audio on the backend, and ingest meeting recordings or voice notes as searchable knowledge.</li>
<li><strong>🌐 Web & Data Integration:</strong> Ingests from URLs, sitemaps, Reddit, GitHub and web crawlers.</li>
<li><strong>✅ Reliable Answers:</strong> Get accurate, hallucination-free responses with source citations viewable in a clean UI.</li>
<li><strong>🔑 Streamlined API Keys:</strong> Generate keys linked to your settings, documents, and models, simplifying chatbot and integration setup.</li>
@@ -53,24 +47,11 @@
</ul>
## Roadmap
- [x] Full GoogleAI compatibility (Jan 2025)
- [x] Add tools (Jan 2025)
- [x] Manually updating chunks in the app UI (Feb 2025)
- [x] Devcontainer for easy development (Feb 2025)
- [x] ReACT agent (March 2025)
- [x] Chatbots menu re-design to handle tools, agent types, and more (April 2025)
- [x] New input box in the conversation menu (April 2025)
- [x] Add triggerable actions / tools (webhook) (April 2025)
- [x] Agent optimisations (May 2025)
- [x] Filesystem sources update (July 2025)
- [x] Json Responses (August 2025)
- [x] MCP support (August 2025)
- [x] Google Drive integration (September 2025)
- [x] Add OAuth 2.0 authentication for MCP (September 2025)
- [ ] SharePoint integration (October 2025)
- [ ] Deep Agents (October 2025)
- [ ] Agent scheduling
- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
- [x] Deep Agents ( October 2025 )
- [x] Prompt Templating ( October 2025 )
- [x] Full api tooling ( Dec 2025 )
- [ ] Agent scheduling ( Jan 2026 )
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
@@ -165,9 +146,16 @@ We as members, contributors, and leaders, pledge to make participation in our co
The source code license is [MIT](https://opensource.org/license/mit/), as described in the [LICENSE](LICENSE) file.
<p>This project is supported by:</p>
## This project is supported by:
<p>
<a href="https://www.digitalocean.com/?utm_medium=opensource&utm_source=DocsGPT">
<img src="https://opensource.nyc3.cdn.digitaloceanspaces.com/attribution/assets/SVG/DO_Logo_horizontal_blue.svg" width="201px">
</a>
</p>
<p>
<a href="https://get.neon.com/docsgpt">
<img width="201" alt="color" src="https://github.com/user-attachments/assets/7d9813b7-0e6d-403f-b5af-68af066b326f" />
</a>
</p>

View File

@@ -1,11 +0,0 @@
API_KEY=your_api_key
EMBEDDINGS_KEY=your_api_key
API_URL=http://localhost:7091
FLASK_APP=application/app.py
FLASK_DEBUG=true
#For OPENAI on Azure
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=

View File

@@ -7,7 +7,7 @@ RUN apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && \
apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.12 python3.12-venv && \
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
rm -rf /var/lib/apt/lists/*
# Verify Python installation and setup symlink
@@ -48,7 +48,12 @@ FROM ubuntu:24.04 as final
RUN apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && apt-get install -y --no-install-recommends python3.12 && \
apt-get update && apt-get install -y --no-install-recommends \
python3.12 \
libgl1 \
libglib2.0-0 \
poppler-utils \
&& \
ln -s /usr/bin/python3.12 /usr/bin/python && \
rm -rf /var/lib/apt/lists/*

View File

@@ -1,14 +1,20 @@
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
import logging
from application.agents.agentic_agent import AgenticAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.research_agent import ResearchAgent
from application.agents.workflow_agent import WorkflowAgent
logger = logging.getLogger(__name__)
class AgentCreator:
agents = {
"classic": ClassicAgent,
"react": ReActAgent,
"react": ClassicAgent, # backwards compat: react falls back to classic
"agentic": AgenticAgent,
"research": ResearchAgent,
"workflow": WorkflowAgent,
}
@classmethod
@@ -16,5 +22,4 @@ class AgentCreator:
agent_class = cls.agents.get(type.lower())
if not agent_class:
raise ValueError(f"No agent class found for type {type}")
return agent_class(*args, **kwargs)

View File

@@ -0,0 +1,63 @@
import logging
from typing import Dict, Generator, Optional
from application.agents.base import BaseAgent
from application.agents.tools.internal_search import (
INTERNAL_TOOL_ID,
add_internal_search_tool,
)
from application.logging import LogContext
logger = logging.getLogger(__name__)
class AgenticAgent(BaseAgent):
"""Agent where the LLM controls retrieval via tools.
Unlike ClassicAgent which pre-fetches docs into the prompt,
AgenticAgent gives the LLM an internal_search tool so it can
decide when, what, and whether to search.
"""
def __init__(
self,
retriever_config: Optional[Dict] = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.retriever_config = retriever_config or {}
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
tools_dict = self.tool_executor.get_tools()
add_internal_search_tool(tools_dict, self.retriever_config)
self._prepare_tools(tools_dict)
# 4. Build messages (prompt has NO pre-fetched docs)
messages = self._build_messages(self.prompt, query)
# 5. Call LLM — the handler manages the tool loop
llm_response = self._llm_gen(messages, log_context)
yield from self._handle_response(
llm_response, tools_dict, messages, log_context
)
# 6. Collect sources from internal search tool results
self._collect_internal_sources()
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
log_context.stacks.append(
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
)
def _collect_internal_sources(self):
"""Collect retrieved docs from the cached InternalSearchTool instance."""
cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
tool = self.tool_executor._loaded_tools.get(cache_key)
if tool and hasattr(tool, "retrieved_docs") and tool.retrieved_docs:
self.retrieved_docs = tool.retrieved_docs

View File

@@ -3,11 +3,11 @@ import uuid
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional
from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB
from application.agents.tool_executor import ToolExecutor
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.settings import settings
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
@@ -23,6 +23,7 @@ class BaseAgent(ABC):
llm_name: str,
model_id: str,
api_key: str,
agent_id: Optional[str] = None,
user_api_key: Optional[str] = None,
prompt: str = "",
chat_history: Optional[List[Dict]] = None,
@@ -34,36 +35,71 @@ class BaseAgent(ABC):
token_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["token_limit"],
limited_request_mode: Optional[bool] = False,
request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
compressed_summary: Optional[str] = None,
llm=None,
llm_handler=None,
tool_executor: Optional[ToolExecutor] = None,
backup_models: Optional[List[str]] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
self.model_id = model_id
self.api_key = api_key
self.agent_id = agent_id
self.user_api_key = user_api_key
self.prompt = prompt
self.decoded_token = decoded_token or {}
self.user: str = self.decoded_token.get("sub")
self.tool_config: Dict = {}
self.tools: List[Dict] = []
self.tool_calls: List[Dict] = []
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
self.llm = LLMCreator.create_llm(
llm_name,
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
)
# Dependency injection for LLM — fall back to creating if not provided
if llm is not None:
self.llm = llm
else:
self.llm = LLMCreator.create_llm(
llm_name,
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
backup_models=backup_models,
)
self.retrieved_docs = retrieved_docs or []
self.llm_handler = LLMHandlerCreator.create_handler(
llm_name if llm_name else "default"
)
if llm_handler is not None:
self.llm_handler = llm_handler
else:
self.llm_handler = LLMHandlerCreator.create_handler(
llm_name if llm_name else "default"
)
# Tool executor — injected or created
if tool_executor is not None:
self.tool_executor = tool_executor
else:
self.tool_executor = ToolExecutor(
user_api_key=user_api_key,
user=self.user,
decoded_token=decoded_token,
)
self.attachments = attachments or []
self.json_schema = json_schema
self.json_schema = None
if json_schema is not None:
try:
self.json_schema = normalize_json_schema_payload(json_schema)
except JsonSchemaValidationError as exc:
logger.warning("Ignoring invalid JSON schema payload: %s", exc)
self.limited_token_mode = limited_token_mode
self.token_limit = token_limit
self.limited_request_mode = limited_request_mode
self.request_limit = request_limit
self.compressed_summary = compressed_summary
self.current_token_count = 0
self.context_limit_reached = False
@log_activity()
def gen(
@@ -77,204 +113,111 @@ class BaseAgent(ABC):
) -> Generator[Dict, None, None]:
pass
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
@property
def tool_calls(self) -> List[Dict]:
return self.tool_executor.tool_calls
@tool_calls.setter
def tool_calls(self, value: List[Dict]):
self.tool_executor.tool_calls = value
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
tools_collection = db["user_tools"]
agent_data = agents_collection.find_one({"key": api_key or self.user_api_key})
tool_ids = agent_data.get("tools", []) if agent_data else []
tools = (
tools_collection.find(
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
)
if tool_ids
else []
)
tools = list(tools)
tools_by_id = {str(tool["_id"]): tool for tool in tools} if tools else {}
return tools_by_id
return self.tool_executor._get_tools_by_api_key(api_key or self.user_api_key)
def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
user_tools_collection = db["user_tools"]
user_tools = user_tools_collection.find({"user": user, "status": True})
user_tools = list(user_tools)
return {str(i): tool for i, tool in enumerate(user_tools)}
return self.tool_executor._get_user_tools(user)
def _build_tool_parameters(self, action):
params = {"type": "object", "properties": {}, "required": []}
for param_type in ["query_params", "headers", "body", "parameters"]:
if param_type in action and action[param_type].get("properties"):
for k, v in action[param_type]["properties"].items():
if v.get("filled_by_llm", True):
params["properties"][k] = {
key: value
for key, value in v.items()
if key != "filled_by_llm" and key != "value"
}
params["required"].append(k)
return params
return self.tool_executor._build_tool_parameters(action)
def _prepare_tools(self, tools_dict):
self.tools = [
{
"type": "function",
"function": {
"name": f"{action['name']}_{tool_id}",
"description": action["description"],
"parameters": self._build_tool_parameters(action),
},
}
for tool_id, tool in tools_dict.items()
if (
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
or (tool["name"] != "api_tool" and "actions" in tool)
)
for action in (
tool["config"]["actions"].values()
if tool["name"] == "api_tool"
else tool["actions"]
)
if action.get("active", True)
]
self.tools = self.tool_executor.prepare_tools_for_llm(tools_dict)
def _execute_tool_action(self, tools_dict, call):
parser = ToolActionParser(self.llm.__class__.__name__)
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
# Check if parsing failed
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": getattr(call, "name", "unknown"),
"arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return "Failed to parse tool call.", call_id
# Check if tool_id exists in available tools
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(error_message)
# Return error result
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return f"Tool with ID {tool_id} not found.", call_id
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
tool_data = tools_dict[tool_id]
action_data = (
tool_data["config"]["actions"][action_name]
if tool_data["name"] == "api_tool"
else next(
action
for action in tool_data["actions"]
if action["name"] == action_name
)
return self.tool_executor.execute(
tools_dict, call, self.llm.__class__.__name__
)
query_params, headers, body, parameters = {}, {}, {}, {}
param_types = {
"query_params": query_params,
"headers": headers,
"body": body,
"parameters": parameters,
}
for param_type, target_dict in param_types.items():
if param_type in action_data and action_data[param_type].get("properties"):
for param, details in action_data[param_type]["properties"].items():
if param not in call_args and "value" in details:
target_dict[param] = details["value"]
for param, value in call_args.items():
for param_type, target_dict in param_types.items():
if param_type in action_data and param in action_data[param_type].get(
"properties", {}
):
target_dict[param] = value
tm = ToolManager(config={})
# Prepare tool_config and add tool_id for memory tools
if tool_data["name"] == "api_tool":
tool_config = {
"url": tool_data["config"]["actions"][action_name]["url"],
"method": tool_data["config"]["actions"][action_name]["method"],
"headers": headers,
"query_params": query_params,
}
else:
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
tool = tm.load_tool(
tool_data["name"],
tool_config=tool_config,
user_id=self.user, # Pass user ID for MCP tools credential decryption
)
if tool_data["name"] == "api_tool":
print(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
print(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
tool_call_data["result"] = (
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
)
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
def _get_truncated_tool_calls(self):
return [
{
**tool_call,
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
else tool_call["result"]
),
"status": "completed",
}
for tool_call in self.tool_calls
]
return self.tool_executor.get_truncated_tool_calls()
# ---- Context / token management ----
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
from application.api.answer.services.compression.token_counter import (
TokenCounter,
)
return TokenCounter.count_message_tokens(messages)
def _check_context_limit(self, messages: List[Dict]) -> bool:
from application.core.model_utils import get_token_limit
try:
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(self.model_id)
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
if current_tokens >= threshold:
logger.warning(
f"Context limit approaching: {current_tokens}/{context_limit} tokens "
f"({(current_tokens/context_limit)*100:.1f}%)"
)
return True
return False
except Exception as e:
logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
return False
def _validate_context_size(self, messages: List[Dict]) -> None:
from application.core.model_utils import get_token_limit
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(self.model_id)
percentage = (current_tokens / context_limit) * 100
if current_tokens >= context_limit:
logger.warning(
f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
f"({percentage:.1f}%). Model: {self.model_id}"
)
elif current_tokens >= int(
context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE
):
logger.info(
f"Context approaching limit: {current_tokens:,}/{context_limit:,} tokens "
f"({percentage:.1f}%)"
)
def _truncate_text_middle(self, text: str, max_tokens: int) -> str:
from application.utils import num_tokens_from_string
current_tokens = num_tokens_from_string(text)
if current_tokens <= max_tokens:
return text
chars_per_token = len(text) / current_tokens if current_tokens > 0 else 4
target_chars = int(max_tokens * chars_per_token * 0.95)
if target_chars <= 0:
return ""
start_chars = int(target_chars * 0.4)
end_chars = int(target_chars * 0.4)
truncation_marker = "\n\n[... content truncated to fit context limit ...]\n\n"
truncated = text[:start_chars] + truncation_marker + text[-end_chars:]
logger.info(
f"Truncated text from {current_tokens:,} to ~{max_tokens:,} tokens "
f"(removed middle section)"
)
return truncated
# ---- Message building ----
def _build_messages(
self,
@@ -282,9 +225,42 @@ class BaseAgent(ABC):
query: str,
) -> List[Dict]:
"""Build messages using pre-rendered system prompt"""
from application.core.model_utils import get_token_limit
from application.utils import num_tokens_from_string
if self.compressed_summary:
compression_context = (
"\n\n---\n\n"
"This session is being continued from a previous conversation that "
"has been compressed to fit within context limits. "
"The conversation is summarized below:\n\n"
f"{self.compressed_summary}"
)
system_prompt = system_prompt + compression_context
context_limit = get_token_limit(self.model_id)
system_tokens = num_tokens_from_string(system_prompt)
safety_buffer = int(context_limit * 0.1)
available_after_system = context_limit - system_tokens - safety_buffer
max_query_tokens = int(available_after_system * 0.8)
query_tokens = num_tokens_from_string(query)
if query_tokens > max_query_tokens:
query = self._truncate_text_middle(query, max_query_tokens)
query_tokens = num_tokens_from_string(query)
available_for_history = max(available_after_system - query_tokens, 0)
working_history = self._truncate_history_to_fit(
self.chat_history,
available_for_history,
)
messages = [{"role": "system", "content": system_prompt}]
for i in self.chat_history:
for i in working_history:
if "prompt" in i and "response" in i:
messages.append({"role": "user", "content": i["prompt"]})
messages.append({"role": "assistant", "content": i["response"]})
@@ -316,8 +292,58 @@ class BaseAgent(ABC):
messages.append({"role": "user", "content": query})
return messages
def _truncate_history_to_fit(
self,
history: List[Dict],
max_tokens: int,
) -> List[Dict]:
from application.utils import num_tokens_from_string
if not history or max_tokens <= 0:
return []
truncated = []
current_tokens = 0
for message in reversed(history):
message_tokens = 0
if "prompt" in message and "response" in message:
message_tokens += num_tokens_from_string(message["prompt"])
message_tokens += num_tokens_from_string(message["response"])
if "tool_calls" in message:
for tool_call in message["tool_calls"]:
tool_str = (
f"Tool: {tool_call.get('tool_name')} | "
f"Action: {tool_call.get('action_name')} | "
f"Args: {tool_call.get('arguments')} | "
f"Response: {tool_call.get('result')}"
)
message_tokens += num_tokens_from_string(tool_str)
if current_tokens + message_tokens <= max_tokens:
current_tokens += message_tokens
truncated.insert(0, message)
else:
break
if len(truncated) < len(history):
logger.info(
f"Truncated chat history from {len(history)} to {len(truncated)} messages "
f"to fit within {max_tokens:,} token budget"
)
return truncated
# ---- LLM generation ----
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
self._validate_context_size(messages)
gen_kwargs = {"model": self.model_id, "messages": messages}
if self.attachments:
gen_kwargs["_usage_attachments"] = self.attachments
if (
hasattr(self.llm, "_supports_tools")

View File

@@ -15,11 +15,7 @@ class ClassicAgent(BaseAgent):
) -> Generator[Dict, None, None]:
"""Core generator function for ClassicAgent execution flow"""
tools_dict = (
self._get_user_tools(self.user)
if not self.user_api_key
else self._get_tools(self.user_api_key)
)
tools_dict = self.tool_executor.get_tools()
self._prepare_tools(tools_dict)
messages = self._build_messages(self.prompt, query)

View File

@@ -1,238 +0,0 @@
import logging
import os
from typing import Any, Dict, Generator, List
from application.agents.base import BaseAgent
from application.logging import build_stack_data, LogContext
logger = logging.getLogger(__name__)
MAX_ITERATIONS_REASONING = 10
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
with open(
os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"), "r"
) as f:
PLANNING_PROMPT_TEMPLATE = f.read()
with open(
os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"), "r"
) as f:
FINAL_PROMPT_TEMPLATE = f.read()
class ReActAgent(BaseAgent):
"""
Research and Action (ReAct) Agent - Advanced reasoning agent with iterative planning.
Implements a think-act-observe loop for complex problem-solving:
1. Creates a strategic plan based on the query
2. Executes tools and gathers observations
3. Iteratively refines approach until satisfied
4. Synthesizes final answer from all observations
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.plan: str = ""
self.observations: List[str] = []
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
"""Execute ReAct reasoning loop with planning, action, and observation cycles"""
self._reset_state()
tools_dict = (
self._get_tools(self.user_api_key)
if self.user_api_key
else self._get_user_tools(self.user)
)
self._prepare_tools(tools_dict)
for iteration in range(1, MAX_ITERATIONS_REASONING + 1):
yield {"thought": f"Reasoning... (iteration {iteration})\n\n"}
yield from self._planning_phase(query, log_context)
if not self.plan:
logger.warning(
f"ReActAgent: No plan generated in iteration {iteration}"
)
break
self.observations.append(f"Plan (iteration {iteration}): {self.plan}")
satisfied = yield from self._execution_phase(query, tools_dict, log_context)
if satisfied:
logger.info("ReActAgent: Goal satisfied, stopping reasoning loop")
break
yield from self._synthesis_phase(query, log_context)
def _reset_state(self):
"""Reset agent state for new query"""
self.plan = ""
self.observations = []
def _planning_phase(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
"""Generate strategic plan for query"""
logger.info("ReActAgent: Creating plan...")
plan_prompt = self._build_planning_prompt(query)
messages = [{"role": "user", "content": plan_prompt}]
plan_stream = self.llm.gen_stream(
model=self.model_id,
messages=messages,
tools=self.tools if self.tools else None,
)
if log_context:
log_context.stacks.append(
{"component": "planning_llm", "data": build_stack_data(self.llm)}
)
plan_parts = []
for chunk in plan_stream:
content = self._extract_content(chunk)
if content:
plan_parts.append(content)
yield {"thought": content}
self.plan = "".join(plan_parts)
def _execution_phase(
self, query: str, tools_dict: Dict, log_context: LogContext
) -> Generator[bool, None, None]:
"""Execute plan with tool calls and observations"""
execution_prompt = self._build_execution_prompt(query)
messages = self._build_messages(execution_prompt, query)
llm_response = self._llm_gen(messages, log_context)
initial_content = self._extract_content(llm_response)
if initial_content:
self.observations.append(f"Initial response: {initial_content}")
processed_response = self._llm_handler(
llm_response, tools_dict, messages, log_context
)
for tool_call in self.tool_calls:
observation = (
f"Executed: {tool_call.get('tool_name', 'Unknown')} "
f"with args {tool_call.get('arguments', {})}. "
f"Result: {str(tool_call.get('result', ''))[:200]}"
)
self.observations.append(observation)
final_content = self._extract_content(processed_response)
if final_content:
self.observations.append(f"Response after tools: {final_content}")
if log_context:
log_context.stacks.append(
{
"component": "agent_tool_calls",
"data": {"tool_calls": self.tool_calls.copy()},
}
)
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
return "SATISFIED" in (final_content or "")
def _synthesis_phase(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
"""Synthesize final answer from all observations"""
logger.info("ReActAgent: Generating final answer...")
final_prompt = self._build_final_answer_prompt(query)
messages = [{"role": "user", "content": final_prompt}]
final_stream = self.llm.gen_stream(
model=self.model_id, messages=messages, tools=None
)
if log_context:
log_context.stacks.append(
{"component": "final_answer_llm", "data": build_stack_data(self.llm)}
)
for chunk in final_stream:
content = self._extract_content(chunk)
if content:
yield {"answer": content}
def _build_planning_prompt(self, query: str) -> str:
"""Build planning phase prompt"""
prompt = PLANNING_PROMPT_TEMPLATE.replace("{query}", query)
prompt = prompt.replace("{prompt}", self.prompt or "")
prompt = prompt.replace("{summaries}", "")
prompt = prompt.replace("{observations}", "\n".join(self.observations))
return prompt
def _build_execution_prompt(self, query: str) -> str:
"""Build execution phase prompt with plan and observations"""
observations_str = "\n".join(self.observations)
if len(observations_str) > 20000:
observations_str = observations_str[:20000] + "\n...[truncated]"
return (
f"{self.prompt or ''}\n\n"
f"Follow this plan:\n{self.plan}\n\n"
f"Observations:\n{observations_str}\n\n"
f"If sufficient data exists to answer '{query}', respond with 'SATISFIED'. "
f"Otherwise, continue executing the plan."
)
def _build_final_answer_prompt(self, query: str) -> str:
"""Build final synthesis prompt"""
observations_str = "\n".join(self.observations)
if len(observations_str) > 10000:
observations_str = observations_str[:10000] + "\n...[truncated]"
logger.warning("ReActAgent: Observations truncated for final answer")
return FINAL_PROMPT_TEMPLATE.format(query=query, observations=observations_str)
def _extract_content(self, response: Any) -> str:
"""Extract text content from various LLM response formats"""
if not response:
return ""
collected = []
if isinstance(response, str):
return response
if hasattr(response, "message") and hasattr(response.message, "content"):
if response.message.content:
return response.message.content
if hasattr(response, "choices") and response.choices:
if hasattr(response.choices[0], "message"):
content = response.choices[0].message.content
if content:
return content
if hasattr(response, "content") and isinstance(response.content, list):
if response.content and hasattr(response.content[0], "text"):
return response.content[0].text
try:
for chunk in response:
content_piece = ""
if hasattr(chunk, "choices") and chunk.choices:
if hasattr(chunk.choices[0], "delta"):
delta_content = chunk.choices[0].delta.content
if delta_content:
content_piece = delta_content
elif hasattr(chunk, "type") and chunk.type == "content_block_delta":
if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
content_piece = chunk.delta.text
elif isinstance(chunk, str):
content_piece = chunk
if content_piece:
collected.append(content_piece)
except (TypeError, AttributeError):
logger.debug(
f"Response not iterable or unexpected format: {type(response)}"
)
except Exception as e:
logger.error(f"Error extracting content: {e}")
return "".join(collected)

View File

@@ -0,0 +1,692 @@
import json
import logging
import os
import time
from typing import Dict, Generator, List, Optional
from application.agents.base import BaseAgent
from application.agents.tool_executor import ToolExecutor
from application.agents.tools.internal_search import (
INTERNAL_TOOL_ID,
add_internal_search_tool,
)
from application.agents.tools.think import THINK_TOOL_ENTRY, THINK_TOOL_ID
from application.logging import LogContext
logger = logging.getLogger(__name__)
# Defaults (can be overridden via constructor)
DEFAULT_MAX_STEPS = 6
DEFAULT_MAX_SUB_ITERATIONS = 5
DEFAULT_TIMEOUT_SECONDS = 300 # 5 minutes
DEFAULT_TOKEN_BUDGET = 100_000
DEFAULT_PARALLEL_WORKERS = 3
# Adaptive depth caps per complexity level
COMPLEXITY_CAPS = {
"simple": 2,
"moderate": 4,
"complex": 6,
}
_PROMPTS_DIR = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"prompts",
"research",
)
def _load_prompt(name: str) -> str:
with open(os.path.join(_PROMPTS_DIR, name), "r") as f:
return f.read()
CLARIFICATION_PROMPT = _load_prompt("clarification.txt")
PLANNING_PROMPT = _load_prompt("planning.txt")
STEP_PROMPT = _load_prompt("step.txt")
SYNTHESIS_PROMPT = _load_prompt("synthesis.txt")
# ---------------------------------------------------------------------------
# CitationManager
# ---------------------------------------------------------------------------
class CitationManager:
"""Tracks and deduplicates citations across research steps."""
def __init__(self):
self.citations: Dict[int, Dict] = {}
self._counter = 0
def add(self, doc: Dict) -> int:
"""Register a source, return its citation number. Deduplicates by source."""
source = doc.get("source", "")
title = doc.get("title", "")
for num, existing in self.citations.items():
if existing.get("source") == source and existing.get("title") == title:
return num
self._counter += 1
self.citations[self._counter] = doc
return self._counter
def add_docs(self, docs: List[Dict]) -> str:
"""Register multiple docs, return formatted citation mapping text."""
mapping_lines = []
for doc in docs:
num = self.add(doc)
title = doc.get("title", "Untitled")
mapping_lines.append(f"[{num}] {title}")
return "\n".join(mapping_lines)
def format_references(self) -> str:
"""Generate [N] -> source mapping for report footer."""
if not self.citations:
return "No sources found."
lines = []
for num, doc in sorted(self.citations.items()):
title = doc.get("title", "Untitled")
source = doc.get("source", "Unknown")
filename = doc.get("filename", "")
display = filename or title
lines.append(f"[{num}] {display}{source}")
return "\n".join(lines)
def get_all_docs(self) -> List[Dict]:
return list(self.citations.values())
# ---------------------------------------------------------------------------
# ResearchAgent
# ---------------------------------------------------------------------------
class ResearchAgent(BaseAgent):
"""Multi-step research agent with parallel execution and budget controls.
Orchestrates: Plan -> Research (per step, optionally parallel) -> Synthesize.
"""
def __init__(
self,
retriever_config: Optional[Dict] = None,
max_steps: int = DEFAULT_MAX_STEPS,
max_sub_iterations: int = DEFAULT_MAX_SUB_ITERATIONS,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
token_budget: int = DEFAULT_TOKEN_BUDGET,
parallel_workers: int = DEFAULT_PARALLEL_WORKERS,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.retriever_config = retriever_config or {}
self.max_steps = max_steps
self.max_sub_iterations = max_sub_iterations
self.timeout_seconds = timeout_seconds
self.token_budget = token_budget
self.parallel_workers = parallel_workers
self.citations = CitationManager()
self._start_time: float = 0
self._tokens_used: int = 0
self._last_token_snapshot: int = 0
# ------------------------------------------------------------------
# Budget & timeout helpers
# ------------------------------------------------------------------
def _is_timed_out(self) -> bool:
return (time.monotonic() - self._start_time) >= self.timeout_seconds
def _elapsed(self) -> float:
return round(time.monotonic() - self._start_time, 1)
def _track_tokens(self, count: int):
self._tokens_used += count
def _budget_remaining(self) -> int:
return max(self.token_budget - self._tokens_used, 0)
def _is_over_budget(self) -> bool:
return self._tokens_used >= self.token_budget
def _snapshot_llm_tokens(self) -> int:
"""Read current token usage from LLM and return delta since last snapshot."""
current = self.llm.token_usage.get("prompt_tokens", 0) + self.llm.token_usage.get("generated_tokens", 0)
delta = current - self._last_token_snapshot
self._last_token_snapshot = current
return delta
# ------------------------------------------------------------------
# Main orchestration
# ------------------------------------------------------------------
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
self._start_time = time.monotonic()
tools_dict = self._setup_tools()
# Phase 0: Clarification (skip if user is responding to a prior clarification)
if not self._is_follow_up():
clarification = self._clarification_phase(query)
if clarification:
yield {"metadata": {"is_clarification": True}}
yield {"answer": clarification}
yield {"sources": []}
yield {"tool_calls": []}
log_context.stacks.append(
{"component": "agent", "data": {"clarification": True}}
)
return
# Phase 1: Planning (with adaptive depth)
yield {"type": "research_progress", "data": {"status": "planning"}}
plan, complexity = self._planning_phase(query)
if not plan:
logger.warning("ResearchAgent: Planning produced no steps, falling back")
plan = [{"query": query, "rationale": "Direct investigation"}]
complexity = "simple"
yield {
"type": "research_plan",
"data": {"steps": plan, "complexity": complexity},
}
# Phase 2: Research each step (yields progress events in real-time)
intermediate_reports = []
for i, step in enumerate(plan):
step_num = i + 1
step_query = step.get("query", query)
if self._is_timed_out():
logger.warning(
f"ResearchAgent: Timeout at step {step_num}/{len(plan)} "
f"({self._elapsed()}s)"
)
break
if self._is_over_budget():
logger.warning(
f"ResearchAgent: Token budget exhausted at step {step_num}/{len(plan)}"
)
break
yield {
"type": "research_progress",
"data": {
"step": step_num,
"total": len(plan),
"query": step_query,
"status": "researching",
},
}
report = self._research_step(step_query, tools_dict)
intermediate_reports.append({"step": step, "content": report})
yield {
"type": "research_progress",
"data": {
"step": step_num,
"total": len(plan),
"query": step_query,
"status": "complete",
},
}
# Phase 3: Synthesis (streaming)
if self._is_timed_out():
logger.warning(
f"ResearchAgent: Timeout ({self._elapsed()}s) before synthesis, "
f"synthesizing with {len(intermediate_reports)} reports"
)
yield {
"type": "research_progress",
"data": {
"status": "synthesizing",
"elapsed_seconds": self._elapsed(),
"tokens_used": self._tokens_used,
},
}
yield from self._synthesis_phase(
query, plan, intermediate_reports, tools_dict, log_context
)
# Sources and tool calls
self.retrieved_docs = self.citations.get_all_docs()
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
logger.info(
f"ResearchAgent completed: {len(intermediate_reports)}/{len(plan)} steps, "
f"{self._elapsed()}s, ~{self._tokens_used} tokens"
)
log_context.stacks.append(
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
)
# ------------------------------------------------------------------
# Tool setup
# ------------------------------------------------------------------
def _setup_tools(self) -> Dict:
"""Build tools_dict with user tools + internal search + think."""
tools_dict = self.tool_executor.get_tools()
add_internal_search_tool(tools_dict, self.retriever_config)
think_entry = dict(THINK_TOOL_ENTRY)
think_entry["config"] = {}
tools_dict[THINK_TOOL_ID] = think_entry
self._prepare_tools(tools_dict)
return tools_dict
# ------------------------------------------------------------------
# Phase 0: Clarification
# ------------------------------------------------------------------
def _is_follow_up(self) -> bool:
"""Check if the user is responding to a prior clarification.
Uses the metadata flag stored in the conversation DB — no string matching.
Only skip clarification when the last query was explicitly flagged
as a clarification by this agent.
"""
if not self.chat_history:
return False
last = self.chat_history[-1]
meta = last.get("metadata", {})
return bool(meta.get("is_clarification"))
def _clarification_phase(self, question: str) -> Optional[str]:
"""Ask the LLM whether the question needs clarification.
Returns formatted clarification text if needed, or None to proceed.
Uses response_format to force valid JSON output.
"""
messages = [
{"role": "system", "content": CLARIFICATION_PROMPT},
{"role": "user", "content": question},
]
try:
response = self.llm.gen(
model=self.model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
)
text = self._extract_text(response)
self._track_tokens(self._snapshot_llm_tokens())
logger.info(f"ResearchAgent clarification response: {text[:300]}")
data = self._parse_clarification_json(text)
if not data or not data.get("needs_clarification"):
return None
questions = data.get("questions", [])
if not questions:
return None
# Format as a friendly response
lines = [
"Before I begin researching, I'd like to clarify a few things:\n"
]
for i, q in enumerate(questions[:3], 1):
lines.append(f"{i}. {q}")
lines.append(
"\nPlease provide these details and I'll start the research."
)
return "\n".join(lines)
except Exception as e:
logger.error(f"Clarification phase failed: {e}", exc_info=True)
return None # proceed with research on failure
def _parse_clarification_json(self, text: str) -> Optional[Dict]:
"""Parse clarification JSON from LLM response."""
try:
return json.loads(text)
except json.JSONDecodeError:
pass
# Try extracting from code fences
for marker in ["```json", "```"]:
if marker in text:
start = text.index(marker) + len(marker)
end = text.index("```", start) if "```" in text[start:] else len(text)
try:
return json.loads(text[start:end].strip())
except (json.JSONDecodeError, ValueError):
pass
# Try finding JSON object
for i, ch in enumerate(text):
if ch == "{":
for j in range(len(text) - 1, i, -1):
if text[j] == "}":
try:
return json.loads(text[i : j + 1])
except json.JSONDecodeError:
continue
break
return None
# ------------------------------------------------------------------
# Phase 1: Planning (with adaptive depth)
# ------------------------------------------------------------------
def _planning_phase(self, question: str) -> tuple[List[Dict], str]:
"""Decompose the question into research steps via LLM.
Returns (steps, complexity) where complexity is simple/moderate/complex.
"""
messages = [
{"role": "system", "content": PLANNING_PROMPT},
{"role": "user", "content": question},
]
try:
response = self.llm.gen(
model=self.model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
)
text = self._extract_text(response)
self._track_tokens(self._snapshot_llm_tokens())
logger.info(f"ResearchAgent planning LLM response: {text[:500]}")
plan_data = self._parse_plan_json(text)
if isinstance(plan_data, dict):
complexity = plan_data.get("complexity", "moderate")
steps = plan_data.get("steps", [])
else:
complexity = "moderate"
steps = plan_data
# Adaptive depth: cap steps based on assessed complexity
cap = COMPLEXITY_CAPS.get(complexity, self.max_steps)
cap = min(cap, self.max_steps)
steps = steps[:cap]
logger.info(
f"ResearchAgent plan: complexity={complexity}, "
f"steps={len(steps)} (cap={cap})"
)
return steps, complexity
except Exception as e:
logger.error(f"Planning phase failed: {e}", exc_info=True)
return (
[{"query": question, "rationale": "Direct investigation (planning failed)"}],
"simple",
)
def _parse_plan_json(self, text: str):
"""Extract JSON plan from LLM response. Returns dict or list."""
# Try direct parse
try:
data = json.loads(text)
if isinstance(data, dict) and "steps" in data:
return data
if isinstance(data, list):
return data
except json.JSONDecodeError:
pass
# Try extracting from markdown code fences
for marker in ["```json", "```"]:
if marker in text:
start = text.index(marker) + len(marker)
end = text.index("```", start) if "```" in text[start:] else len(text)
try:
data = json.loads(text[start:end].strip())
if isinstance(data, dict) and "steps" in data:
return data
if isinstance(data, list):
return data
except (json.JSONDecodeError, ValueError):
pass
# Try finding JSON object in text
for i, ch in enumerate(text):
if ch == "{":
for j in range(len(text) - 1, i, -1):
if text[j] == "}":
try:
data = json.loads(text[i : j + 1])
if isinstance(data, dict) and "steps" in data:
return data
except json.JSONDecodeError:
continue
break
logger.warning(f"Could not parse plan JSON from: {text[:200]}")
return []
# ------------------------------------------------------------------
# Phase 2: Research step (core loop)
# ------------------------------------------------------------------
def _research_step(self, step_query: str, tools_dict: Dict) -> str:
"""Run a focused research loop for one sub-question (sequential path)."""
report = self._research_step_with_executor(
step_query, tools_dict, self.tool_executor
)
self._collect_step_sources()
return report
def _research_step_with_executor(
self, step_query: str, tools_dict: Dict, executor: ToolExecutor
) -> str:
"""Core research loop. Works with any ToolExecutor instance."""
system_prompt = STEP_PROMPT.replace("{step_query}", step_query)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": step_query},
]
last_search_empty = False
for iteration in range(self.max_sub_iterations):
# Check timeout and budget
if self._is_timed_out():
logger.info(
f"Research step '{step_query[:50]}' timed out at iteration {iteration}"
)
break
if self._is_over_budget():
logger.info(
f"Research step '{step_query[:50]}' hit token budget at iteration {iteration}"
)
break
try:
response = self.llm.gen(
model=self.model_id,
messages=messages,
tools=self.tools if self.tools else None,
)
self._track_tokens(self._snapshot_llm_tokens())
except Exception as e:
logger.error(
f"Research step LLM call failed (iteration {iteration}): {e}",
exc_info=True,
)
break
parsed = self.llm_handler.parse_response(response)
if not parsed.requires_tool_call:
return parsed.content or "No findings for this step."
# Execute tool calls
messages, last_search_empty = self._execute_step_tools_with_refinement(
parsed.tool_calls, tools_dict, messages, executor, last_search_empty
)
# Max iterations / timeout / budget — ask for summary
messages.append(
{
"role": "user",
"content": "Please summarize your findings so far based on the information gathered.",
}
)
try:
response = self.llm.gen(
model=self.model_id, messages=messages, tools=None
)
self._track_tokens(self._snapshot_llm_tokens())
text = self._extract_text(response)
return text or "Research step completed."
except Exception:
return "Research step completed."
def _execute_step_tools_with_refinement(
self,
tool_calls,
tools_dict: Dict,
messages: List[Dict],
executor: ToolExecutor,
last_search_empty: bool,
) -> tuple[List[Dict], bool]:
"""Execute tool calls with query refinement on empty results.
Returns (updated_messages, was_last_search_empty).
"""
search_returned_empty = False
for call in tool_calls:
gen = executor.execute(
tools_dict, call, self.llm.__class__.__name__
)
result = None
call_id = None
while True:
try:
event = next(gen)
# Log tool_call status events instead of discarding them
if isinstance(event, dict) and event.get("type") == "tool_call":
logger.debug(
"Tool %s status: %s",
event.get("data", {}).get("action_name", ""),
event.get("data", {}).get("status", ""),
)
except StopIteration as e:
result, call_id = e.value
break
# Detect empty search results for refinement
is_search = "search" in (call.name or "").lower()
result_str = str(result) if result else ""
if is_search and "No documents found" in result_str:
search_returned_empty = True
if last_search_empty:
# Two consecutive empty searches — inject refinement hint
result_str += (
"\n\nHint: Previous search also returned no results. "
"Try a very different query with different keywords, "
"or broaden your search terms."
)
result = result_str
function_call_content = {
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_content]}
)
tool_message = self.llm_handler.create_tool_message(call, result)
messages.append(tool_message)
return messages, search_returned_empty
def _collect_step_sources(self):
"""Collect sources from InternalSearchTool and register with CitationManager."""
cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
tool = self.tool_executor._loaded_tools.get(cache_key)
if tool and hasattr(tool, "retrieved_docs"):
for doc in tool.retrieved_docs:
self.citations.add(doc)
# ------------------------------------------------------------------
# Phase 3: Synthesis
# ------------------------------------------------------------------
def _synthesis_phase(
self,
question: str,
plan: List[Dict],
intermediate_reports: List[Dict],
tools_dict: Dict,
log_context: LogContext,
) -> Generator[Dict, None, None]:
"""Compile all findings into a final cited report (streaming)."""
plan_lines = []
for i, step in enumerate(plan, 1):
plan_lines.append(
f"{i}. {step.get('query', 'Unknown')}{step.get('rationale', '')}"
)
plan_summary = "\n".join(plan_lines)
findings_parts = []
for i, report in enumerate(intermediate_reports, 1):
step_query = report["step"].get("query", "Unknown")
content = report["content"]
findings_parts.append(
f"--- Step {i}: {step_query} ---\n{content}"
)
findings = "\n\n".join(findings_parts)
references = self.citations.format_references()
synthesis_prompt = SYNTHESIS_PROMPT.replace("{question}", question)
synthesis_prompt = synthesis_prompt.replace("{plan_summary}", plan_summary)
synthesis_prompt = synthesis_prompt.replace("{findings}", findings)
synthesis_prompt = synthesis_prompt.replace("{references}", references)
messages = [
{"role": "system", "content": synthesis_prompt},
{"role": "user", "content": f"Please write the research report for: {question}"},
]
llm_response = self.llm.gen_stream(
model=self.model_id, messages=messages, tools=None
)
if log_context:
from application.logging import build_stack_data
log_context.stacks.append(
{"component": "synthesis_llm", "data": build_stack_data(self.llm)}
)
yield from self._handle_response(
llm_response, tools_dict, messages, log_context
)
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _extract_text(self, response) -> str:
"""Extract text content from a non-streaming LLM response."""
if isinstance(response, str):
return response
if hasattr(response, "message") and hasattr(response.message, "content"):
return response.message.content or ""
if hasattr(response, "choices") and response.choices:
choice = response.choices[0]
if hasattr(choice, "message") and hasattr(choice.message, "content"):
return choice.message.content or ""
if hasattr(response, "content") and isinstance(response.content, list):
if response.content and hasattr(response.content[0], "text"):
return response.content[0].text or ""
return str(response) if response else ""

View File

@@ -0,0 +1,313 @@
import logging
import uuid
from typing import Dict, List, Optional
from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
class ToolExecutor:
"""Handles tool discovery, preparation, and execution.
Extracted from BaseAgent to separate concerns and enable tool caching.
"""
def __init__(
self,
user_api_key: Optional[str] = None,
user: Optional[str] = None,
decoded_token: Optional[Dict] = None,
):
self.user_api_key = user_api_key
self.user = user
self.decoded_token = decoded_token
self.tool_calls: List[Dict] = []
self._loaded_tools: Dict[str, object] = {}
self.conversation_id: Optional[str] = None
def get_tools(self) -> Dict[str, Dict]:
"""Load tool configs from DB based on user context."""
if self.user_api_key:
return self._get_tools_by_api_key(self.user_api_key)
return self._get_user_tools(self.user or "local")
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
tools_collection = db["user_tools"]
agent_data = agents_collection.find_one({"key": api_key})
tool_ids = agent_data.get("tools", []) if agent_data else []
tools = (
tools_collection.find(
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
)
if tool_ids
else []
)
tools = list(tools)
return {str(tool["_id"]): tool for tool in tools} if tools else {}
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict]:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
user_tools_collection = db["user_tools"]
user_tools = user_tools_collection.find({"user": user, "status": True})
user_tools = list(user_tools)
return {str(i): tool for i, tool in enumerate(user_tools)}
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
"""Convert tool configs to LLM function schemas."""
return [
{
"type": "function",
"function": {
"name": f"{action['name']}_{tool_id}",
"description": action["description"],
"parameters": self._build_tool_parameters(action),
},
}
for tool_id, tool in tools_dict.items()
if (
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
or (tool["name"] != "api_tool" and "actions" in tool)
)
for action in (
tool["config"]["actions"].values()
if tool["name"] == "api_tool"
else tool["actions"]
)
if action.get("active", True)
]
def _build_tool_parameters(self, action: Dict) -> Dict:
params = {"type": "object", "properties": {}, "required": []}
for param_type in ["query_params", "headers", "body", "parameters"]:
if param_type in action and action[param_type].get("properties"):
for k, v in action[param_type]["properties"].items():
if v.get("filled_by_llm", True):
params["properties"][k] = {
key: value
for key, value in v.items()
if key not in ("filled_by_llm", "value", "required")
}
if v.get("required", False):
params["required"].append(k)
return params
def execute(self, tools_dict: Dict, call, llm_class_name: str):
"""Execute a tool call. Yields status events, returns (result, call_id)."""
parser = ToolActionParser(llm_class_name)
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": getattr(call, "name", "unknown"),
"arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return "Failed to parse tool call.", call_id
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return f"Tool with ID {tool_id} not found.", call_id
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
tool_data = tools_dict[tool_id]
action_data = (
tool_data["config"]["actions"][action_name]
if tool_data["name"] == "api_tool"
else next(
action
for action in tool_data["actions"]
if action["name"] == action_name
)
)
query_params, headers, body, parameters = {}, {}, {}, {}
param_types = {
"query_params": query_params,
"headers": headers,
"body": body,
"parameters": parameters,
}
for param_type, target_dict in param_types.items():
if param_type in action_data and action_data[param_type].get("properties"):
for param, details in action_data[param_type]["properties"].items():
if (
param not in call_args
and "value" in details
and details["value"]
):
target_dict[param] = details["value"]
for param, value in call_args.items():
for param_type, target_dict in param_types.items():
if param_type in action_data and param in action_data[param_type].get(
"properties", {}
):
target_dict[param] = value
# Load tool (with caching)
tool = self._get_or_load_tool(
tool_data, tool_id, action_name,
headers=headers, query_params=query_params,
)
resolved_arguments = (
{"query_params": query_params, "headers": headers, "body": body}
if tool_data["name"] == "api_tool"
else parameters
)
if tool_data["name"] == "api_tool":
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
get_artifact_id = (
getattr(tool, "get_artifact_id", None)
if tool_data["name"] != "api_tool"
else None
)
artifact_id = None
if callable(get_artifact_id):
try:
artifact_id = get_artifact_id(action_name, **parameters)
except Exception:
logger.exception(
"Failed to extract artifact_id from tool %s for action %s",
tool_data["name"],
action_name,
)
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
if artifact_id:
tool_call_data["artifact_id"] = artifact_id
result_full = str(result)
tool_call_data["resolved_arguments"] = resolved_arguments
tool_call_data["result_full"] = result_full
tool_call_data["result"] = (
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
)
stream_tool_call_data = {
key: value
for key, value in tool_call_data.items()
if key not in {"result_full", "resolved_arguments"}
}
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
def _get_or_load_tool(
self, tool_data: Dict, tool_id: str, action_name: str,
headers: Optional[Dict] = None, query_params: Optional[Dict] = None,
):
"""Load a tool, using cache when possible."""
cache_key = f"{tool_data['name']}:{tool_id}:{self.user or ''}"
if cache_key in self._loaded_tools:
return self._loaded_tools[cache_key]
tm = ToolManager(config={})
if tool_data["name"] == "api_tool":
action_config = tool_data["config"]["actions"][action_name]
tool_config = {
"url": action_config["url"],
"method": action_config["method"],
"headers": headers or {},
"query_params": query_params or {},
}
if "body_content_type" in action_config:
tool_config["body_content_type"] = action_config.get(
"body_content_type", "application/json"
)
tool_config["body_encoding_rules"] = action_config.get(
"body_encoding_rules", {}
)
else:
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
if tool_config.get("encrypted_credentials") and self.user:
decrypted = decrypt_credentials(
tool_config["encrypted_credentials"], self.user
)
tool_config.update(decrypted)
tool_config["auth_credentials"] = decrypted
tool_config.pop("encrypted_credentials", None)
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
if self.conversation_id:
tool_config["conversation_id"] = self.conversation_id
if tool_data["name"] == "mcp_tool":
tool_config["query_mode"] = True
tool = tm.load_tool(
tool_data["name"],
tool_config=tool_config,
user_id=self.user,
)
# Don't cache api_tool since config varies by action
if tool_data["name"] != "api_tool":
self._loaded_tools[cache_key] = tool
return tool
def get_truncated_tool_calls(self) -> List[Dict]:
return [
{
"tool_name": tool_call.get("tool_name"),
"call_id": tool_call.get("call_id"),
"action_name": tool_call.get("action_name"),
"arguments": tool_call.get("arguments"),
"artifact_id": tool_call.get("artifact_id"),
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
else tool_call["result"]
),
"status": "completed",
}
for tool_call in self.tool_calls
]

View File

@@ -0,0 +1,323 @@
import base64
import json
import logging
from enum import Enum
from typing import Any, Dict, Optional, Union
from urllib.parse import quote, urlencode
logger = logging.getLogger(__name__)
class ContentType(str, Enum):
"""Supported content types for request bodies."""
JSON = "application/json"
FORM_URLENCODED = "application/x-www-form-urlencoded"
MULTIPART_FORM_DATA = "multipart/form-data"
TEXT_PLAIN = "text/plain"
XML = "application/xml"
OCTET_STREAM = "application/octet-stream"
class RequestBodySerializer:
"""Serializes request bodies according to content-type and OpenAPI 3.1 spec."""
@staticmethod
def serialize(
body_data: Dict[str, Any],
content_type: str = ContentType.JSON,
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[Union[str, bytes], Dict[str, str]]:
"""
Serialize body data to appropriate format.
Args:
body_data: Dictionary of body parameters
content_type: Content-Type header value
encoding_rules: OpenAPI Encoding Object rules per field
Returns:
Tuple of (serialized_body, updated_headers_dict)
Raises:
ValueError: If serialization fails
"""
if not body_data:
return None, {}
try:
content_type_lower = content_type.lower().split(";")[0].strip()
if content_type_lower == ContentType.JSON:
return RequestBodySerializer._serialize_json(body_data)
elif content_type_lower == ContentType.FORM_URLENCODED:
return RequestBodySerializer._serialize_form_urlencoded(
body_data, encoding_rules
)
elif content_type_lower == ContentType.MULTIPART_FORM_DATA:
return RequestBodySerializer._serialize_multipart_form_data(
body_data, encoding_rules
)
elif content_type_lower == ContentType.TEXT_PLAIN:
return RequestBodySerializer._serialize_text_plain(body_data)
elif content_type_lower == ContentType.XML:
return RequestBodySerializer._serialize_xml(body_data)
elif content_type_lower == ContentType.OCTET_STREAM:
return RequestBodySerializer._serialize_octet_stream(body_data)
else:
logger.warning(
f"Unknown content type: {content_type}, treating as JSON"
)
return RequestBodySerializer._serialize_json(body_data)
except Exception as e:
logger.error(f"Error serializing body: {str(e)}", exc_info=True)
raise ValueError(f"Failed to serialize request body: {str(e)}")
@staticmethod
def _serialize_json(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
"""Serialize body as JSON per OpenAPI spec."""
try:
serialized = json.dumps(
body_data, separators=(",", ":"), ensure_ascii=False
)
headers = {"Content-Type": ContentType.JSON.value}
return serialized, headers
except (TypeError, ValueError) as e:
raise ValueError(f"Failed to serialize JSON body: {str(e)}")
@staticmethod
def _serialize_form_urlencoded(
body_data: Dict[str, Any],
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[str, Dict[str, str]]:
"""Serialize body as application/x-www-form-urlencoded per RFC1866/RFC3986."""
encoding_rules = encoding_rules or {}
params = []
for key, value in body_data.items():
if value is None:
continue
rule = encoding_rules.get(key, {})
style = rule.get("style", "form")
explode = rule.get("explode", style == "form")
content_type = rule.get("contentType", "text/plain")
serialized_value = RequestBodySerializer._serialize_form_value(
value, style, explode, content_type, key
)
if isinstance(serialized_value, list):
for sv in serialized_value:
params.append((key, sv))
else:
params.append((key, serialized_value))
# Use standard urlencode (replaces space with +)
serialized = urlencode(params, safe="")
headers = {"Content-Type": ContentType.FORM_URLENCODED.value}
return serialized, headers
@staticmethod
def _serialize_form_value(
value: Any, style: str, explode: bool, content_type: str, key: str
) -> Union[str, list]:
"""Serialize individual form value with encoding rules."""
if isinstance(value, dict):
if content_type == "application/json":
return json.dumps(value, separators=(",", ":"))
elif content_type == "application/xml":
return RequestBodySerializer._dict_to_xml(value)
else:
if style == "deepObject" and explode:
return [
f"{RequestBodySerializer._percent_encode(str(v))}"
for v in value.values()
]
elif explode:
return [
f"{RequestBodySerializer._percent_encode(str(v))}"
for v in value.values()
]
else:
pairs = [f"{k},{v}" for k, v in value.items()]
return RequestBodySerializer._percent_encode(",".join(pairs))
elif isinstance(value, (list, tuple)):
if explode:
return [
RequestBodySerializer._percent_encode(str(item)) for item in value
]
else:
return RequestBodySerializer._percent_encode(
",".join(str(v) for v in value)
)
else:
return RequestBodySerializer._percent_encode(str(value))
@staticmethod
def _serialize_multipart_form_data(
body_data: Dict[str, Any],
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[bytes, Dict[str, str]]:
"""
Serialize body as multipart/form-data per RFC7578.
Supports file uploads and encoding rules.
"""
import secrets
encoding_rules = encoding_rules or {}
boundary = f"----DocsGPT{secrets.token_hex(16)}"
parts = []
for key, value in body_data.items():
if value is None:
continue
rule = encoding_rules.get(key, {})
content_type = rule.get("contentType", "text/plain")
headers_rule = rule.get("headers", {})
part = RequestBodySerializer._create_multipart_part(
key, value, content_type, headers_rule
)
parts.append(part)
body_bytes = f"--{boundary}\r\n".encode("utf-8")
body_bytes += f"--{boundary}\r\n".join(parts).encode("utf-8")
body_bytes += f"\r\n--{boundary}--\r\n".encode("utf-8")
headers = {
"Content-Type": f"multipart/form-data; boundary={boundary}",
}
return body_bytes, headers
@staticmethod
def _create_multipart_part(
name: str, value: Any, content_type: str, headers_rule: Dict[str, Any]
) -> str:
"""Create a single multipart/form-data part."""
headers = [
f'Content-Disposition: form-data; name="{RequestBodySerializer._percent_encode(name)}"'
]
if isinstance(value, bytes):
if content_type == "application/octet-stream":
value_encoded = base64.b64encode(value).decode("utf-8")
else:
value_encoded = value.decode("utf-8", errors="replace")
headers.append(f"Content-Type: {content_type}")
headers.append("Content-Transfer-Encoding: base64")
elif isinstance(value, dict):
if content_type == "application/json":
value_encoded = json.dumps(value, separators=(",", ":"))
elif content_type == "application/xml":
value_encoded = RequestBodySerializer._dict_to_xml(value)
else:
value_encoded = str(value)
headers.append(f"Content-Type: {content_type}")
elif isinstance(value, str) and content_type != "text/plain":
try:
if content_type == "application/json":
json.loads(value)
value_encoded = value
elif content_type == "application/xml":
value_encoded = value
else:
value_encoded = str(value)
except json.JSONDecodeError:
value_encoded = str(value)
headers.append(f"Content-Type: {content_type}")
else:
value_encoded = str(value)
if content_type != "text/plain":
headers.append(f"Content-Type: {content_type}")
part = "\r\n".join(headers) + "\r\n\r\n" + value_encoded + "\r\n"
return part
@staticmethod
def _serialize_text_plain(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
"""Serialize body as plain text."""
if len(body_data) == 1:
value = list(body_data.values())[0]
return str(value), {"Content-Type": ContentType.TEXT_PLAIN.value}
else:
text = "\n".join(f"{k}: {v}" for k, v in body_data.items())
return text, {"Content-Type": ContentType.TEXT_PLAIN.value}
@staticmethod
def _serialize_xml(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
"""Serialize body as XML."""
xml_str = RequestBodySerializer._dict_to_xml(body_data)
return xml_str, {"Content-Type": ContentType.XML.value}
@staticmethod
def _serialize_octet_stream(
body_data: Dict[str, Any],
) -> tuple[bytes, Dict[str, str]]:
"""Serialize body as binary octet stream."""
if isinstance(body_data, bytes):
return body_data, {"Content-Type": ContentType.OCTET_STREAM.value}
elif isinstance(body_data, str):
return body_data.encode("utf-8"), {
"Content-Type": ContentType.OCTET_STREAM.value
}
else:
serialized = json.dumps(body_data)
return serialized.encode("utf-8"), {
"Content-Type": ContentType.OCTET_STREAM.value
}
@staticmethod
def _percent_encode(value: str, safe_chars: str = "") -> str:
"""
Percent-encode per RFC3986.
Args:
value: String to encode
safe_chars: Additional characters to not encode
"""
return quote(value, safe=safe_chars)
@staticmethod
def _dict_to_xml(data: Dict[str, Any], root_name: str = "root") -> str:
"""
Convert dict to simple XML format.
"""
def build_xml(obj: Any, name: str) -> str:
if isinstance(obj, dict):
inner = "".join(build_xml(v, k) for k, v in obj.items())
return f"<{name}>{inner}</{name}>"
elif isinstance(obj, (list, tuple)):
items = "".join(
build_xml(item, f"{name[:-1] if name.endswith('s') else name}")
for item in obj
)
return items
else:
return f"<{name}>{RequestBodySerializer._escape_xml(str(obj))}</{name}>"
root = build_xml(data, root_name)
return f'<?xml version="1.0" encoding="UTF-8"?>{root}'
@staticmethod
def _escape_xml(value: str) -> str:
"""Escape XML special characters."""
return (
value.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
.replace("'", "&apos;")
)

View File

@@ -1,72 +1,280 @@
import json
import logging
import re
from typing import Any, Dict, Optional
from urllib.parse import urlencode
import requests
from application.agents.tools.api_body_serializer import (
ContentType,
RequestBodySerializer,
)
from application.agents.tools.base import Tool
from application.core.url_validation import validate_url, SSRFError
logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 90 # seconds
class APITool(Tool):
"""
API Tool
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs.
"""
def __init__(self, config):
self.config = config
self.url = config.get("url", "")
self.method = config.get("method", "GET")
self.headers = config.get("headers", {"Content-Type": "application/json"})
self.headers = config.get("headers", {})
self.query_params = config.get("query_params", {})
self.body_content_type = config.get("body_content_type", ContentType.JSON)
self.body_encoding_rules = config.get("body_encoding_rules", {})
def execute_action(self, action_name, **kwargs):
"""Execute an API action with the given arguments."""
return self._make_api_call(
self.url, self.method, self.headers, self.query_params, kwargs
self.url,
self.method,
self.headers,
self.query_params,
kwargs,
self.body_content_type,
self.body_encoding_rules,
)
def _make_api_call(self, url, method, headers, query_params, body):
if query_params:
url = f"{url}?{requests.compat.urlencode(query_params)}"
# if isinstance(body, dict):
# body = json.dumps(body)
def _make_api_call(
self,
url: str,
method: str,
headers: Dict[str, str],
query_params: Dict[str, Any],
body: Dict[str, Any],
content_type: str = ContentType.JSON,
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> Dict[str, Any]:
"""
Make an API call with proper body serialization and error handling.
Args:
url: API endpoint URL
method: HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS)
headers: Request headers dict
query_params: URL query parameters
body: Request body as dict
content_type: Content-Type for serialization
encoding_rules: OpenAPI encoding rules
Returns:
Dict with status_code, data, and message
"""
request_url = url
request_headers = headers.copy() if headers else {}
response = None
# Validate URL to prevent SSRF attacks
try:
print(f"Making API call: {method} {url} with body: {body}")
if body == "{}":
body = None
response = requests.request(method, url, headers=headers, data=body)
response.raise_for_status()
content_type = response.headers.get(
"Content-Type", "application/json"
).lower()
if "application/json" in content_type:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
try:
path_params_used = set()
if query_params:
for match in re.finditer(r"\{([^}]+)\}", request_url):
param_name = match.group(1)
if param_name in query_params:
request_url = request_url.replace(
f"{{{param_name}}}", str(query_params[param_name])
)
path_params_used.add(param_name)
remaining_params = {
k: v for k, v in query_params.items() if k not in path_params_used
}
if remaining_params:
query_string = urlencode(remaining_params)
separator = "&" if "?" in request_url else "?"
request_url = f"{request_url}{separator}{query_string}"
# Re-validate URL after parameter substitution to prevent SSRF via path params
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed after parameter substitution: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
# Serialize body based on content type
if body and body != {}:
try:
data = response.json()
except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}. Raw response: {response.text}")
serialized_body, body_headers = RequestBodySerializer.serialize(
body, content_type, encoding_rules
)
request_headers.update(body_headers)
except ValueError as e:
logger.error(f"Body serialization failed: {str(e)}")
return {
"status_code": response.status_code,
"message": f"API call returned invalid JSON. Error: {e}",
"data": response.text,
"status_code": None,
"message": f"Body serialization error: {str(e)}",
"data": None,
}
elif "text/" in content_type or "application/xml" in content_type:
data = response.text
elif not response.content:
data = None
else:
print(f"Unsupported content type: {content_type}")
data = response.content
serialized_body = None
if "Content-Type" not in request_headers and method not in [
"GET",
"HEAD",
"DELETE",
]:
request_headers["Content-Type"] = ContentType.JSON
logger.debug(
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
)
if method.upper() == "GET":
response = requests.get(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "POST":
response = requests.post(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "PUT":
response = requests.put(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "DELETE":
response = requests.delete(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "PATCH":
response = requests.patch(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "HEAD":
response = requests.head(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "OPTIONS":
response = requests.options(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
else:
return {
"status_code": None,
"message": f"Unsupported HTTP method: {method}",
"data": None,
}
response.raise_for_status()
data = self._parse_response(response)
return {
"status_code": response.status_code,
"data": data,
"message": "API call successful.",
}
except requests.exceptions.Timeout:
logger.error(f"Request timeout for {request_url}")
return {
"status_code": None,
"message": f"Request timeout ({DEFAULT_TIMEOUT}s exceeded)",
"data": None,
}
except requests.exceptions.ConnectionError as e:
logger.error(f"Connection error: {str(e)}")
return {
"status_code": None,
"message": f"Connection error: {str(e)}",
"data": None,
}
except requests.exceptions.HTTPError as e:
logger.error(f"HTTP error {response.status_code}: {str(e)}")
try:
error_data = response.json()
except (json.JSONDecodeError, ValueError):
error_data = response.text
return {
"status_code": response.status_code,
"message": f"HTTP Error {response.status_code}",
"data": error_data,
}
except requests.exceptions.RequestException as e:
logger.error(f"Request failed: {str(e)}")
return {
"status_code": response.status_code if response else None,
"message": f"API call failed: {str(e)}",
"data": None,
}
except Exception as e:
logger.error(f"Unexpected error in API call: {str(e)}", exc_info=True)
return {
"status_code": None,
"message": f"Unexpected error: {str(e)}",
"data": None,
}
def _parse_response(self, response: requests.Response) -> Any:
"""
Parse response based on Content-Type header.
Supports: JSON, XML, plain text, binary data.
"""
content_type = response.headers.get("Content-Type", "").lower()
if not response.content:
return None
# JSON response
if "application/json" in content_type:
try:
return response.json()
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse JSON response: {str(e)}")
return response.text
# XML response
elif "application/xml" in content_type or "text/xml" in content_type:
return response.text
# Plain text response
elif "text/plain" in content_type or "text/html" in content_type:
return response.text
# Binary/unknown response
else:
# Try to decode as text first, fall back to base64
try:
return response.text
except (UnicodeDecodeError, AttributeError):
import base64
return base64.b64encode(response.content).decode("utf-8")
def get_actions_metadata(self):
"""Return metadata for available actions (none for API Tool - actions are user-defined)."""
return []
def get_config_requirements(self):
"""Return configuration requirements for the tool."""
return {}

View File

@@ -1,6 +1,11 @@
import logging
import requests
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class BraveSearchTool(Tool):
"""
@@ -41,7 +46,7 @@ class BraveSearchTool(Tool):
"""
Performs a web search using the Brave Search API.
"""
print(f"Performing Brave web search for: {query}")
logger.debug("Performing Brave web search for: %s", query)
url = f"{self.base_url}/web/search"
@@ -94,7 +99,7 @@ class BraveSearchTool(Tool):
"""
Performs an image search using the Brave Search API.
"""
print(f"Performing Brave image search for: {query}")
logger.debug("Performing Brave image search for: %s", query)
url = f"{self.base_url}/images/search"
@@ -177,6 +182,10 @@ class BraveSearchTool(Tool):
return {
"token": {
"type": "string",
"label": "API Key",
"description": "Brave Search API key for authentication",
"required": True,
"secret": True,
"order": 1,
},
}

View File

@@ -1,5 +1,14 @@
import logging
import time
from typing import Any, Dict, Optional
from application.agents.tools.base import Tool
from duckduckgo_search import DDGS
logger = logging.getLogger(__name__)
MAX_RETRIES = 3
RETRY_DELAY = 2.0
DEFAULT_TIMEOUT = 15
class DuckDuckGoSearchTool(Tool):
@@ -10,71 +19,123 @@ class DuckDuckGoSearchTool(Tool):
def __init__(self, config):
self.config = config
self.timeout = config.get("timeout", DEFAULT_TIMEOUT)
def _get_ddgs_client(self):
from ddgs import DDGS
return DDGS(timeout=self.timeout)
def _execute_with_retry(self, operation, operation_name: str) -> Dict[str, Any]:
last_error = None
for attempt in range(1, MAX_RETRIES + 1):
try:
results = operation()
return {
"status_code": 200,
"results": list(results) if results else [],
"message": f"{operation_name} completed successfully.",
}
except Exception as e:
last_error = e
error_str = str(e).lower()
if "ratelimit" in error_str or "429" in error_str:
if attempt < MAX_RETRIES:
delay = RETRY_DELAY * attempt
logger.warning(
f"{operation_name} rate limited, retrying in {delay}s (attempt {attempt}/{MAX_RETRIES})"
)
time.sleep(delay)
continue
logger.error(f"{operation_name} failed: {e}")
break
return {
"status_code": 500,
"results": [],
"message": f"{operation_name} failed: {str(last_error)}",
}
def execute_action(self, action_name, **kwargs):
actions = {
"ddg_web_search": self._web_search,
"ddg_image_search": self._image_search,
"ddg_news_search": self._news_search,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
if action_name not in actions:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _web_search(
self,
query,
max_results=5,
):
print(f"Performing DuckDuckGo web search for: {query}")
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo web search: {query}")
try:
results = DDGS().text(
def operation():
client = self._get_ddgs_client()
return client.text(
query,
max_results=max_results,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 20),
)
return {
"status_code": 200,
"results": results,
"message": "Web search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Web search failed: {str(e)}",
}
return self._execute_with_retry(operation, "Web search")
def _image_search(
self,
query,
max_results=5,
):
print(f"Performing DuckDuckGo image search for: {query}")
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo image search: {query}")
try:
results = DDGS().images(
keywords=query,
max_results=max_results,
def operation():
client = self._get_ddgs_client()
return client.images(
query,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 50),
)
return {
"status_code": 200,
"results": results,
"message": "Image search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Image search failed: {str(e)}",
}
return self._execute_with_retry(operation, "Image search")
def _news_search(
self,
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo news search: {query}")
def operation():
client = self._get_ddgs_client()
return client.news(
query,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 20),
)
return self._execute_with_retry(operation, "News search")
def get_actions_metadata(self):
return [
{
"name": "ddg_web_search",
"description": "Perform a web search using DuckDuckGo.",
"description": "Search the web using DuckDuckGo. Returns titles, URLs, and snippets.",
"parameters": {
"type": "object",
"properties": {
@@ -84,7 +145,15 @@ class DuckDuckGoSearchTool(Tool):
},
"max_results": {
"type": "integer",
"description": "Number of results to return (default: 5)",
"description": "Number of results (default: 5, max: 20)",
},
"region": {
"type": "string",
"description": "Region code (default: wt-wt for worldwide, us-en for US)",
},
"timelimit": {
"type": "string",
"description": "Time filter: d (day), w (week), m (month), y (year)",
},
},
"required": ["query"],
@@ -92,17 +161,43 @@ class DuckDuckGoSearchTool(Tool):
},
{
"name": "ddg_image_search",
"description": "Perform an image search using DuckDuckGo.",
"description": "Search for images using DuckDuckGo. Returns image URLs and metadata.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query",
"description": "Image search query",
},
"max_results": {
"type": "integer",
"description": "Number of results to return (default: 5, max: 50)",
"description": "Number of results (default: 5, max: 50)",
},
"region": {
"type": "string",
"description": "Region code (default: wt-wt for worldwide)",
},
},
"required": ["query"],
},
},
{
"name": "ddg_news_search",
"description": "Search for news articles using DuckDuckGo. Returns recent news.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "News search query",
},
"max_results": {
"type": "integer",
"description": "Number of results (default: 5, max: 20)",
},
"timelimit": {
"type": "string",
"description": "Time filter: d (day), w (week), m (month)",
},
},
"required": ["query"],

View File

@@ -0,0 +1,436 @@
import json
import logging
from typing import Dict, List, Optional
from application.agents.tools.base import Tool
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
logger = logging.getLogger(__name__)
class InternalSearchTool(Tool):
"""Wraps the ClassicRAG retriever as an LLM-callable tool.
Instead of pre-fetching docs into the prompt, the LLM decides
when and what to search. Supports multiple searches per session.
Optional capabilities (enabled when sources have directory_structure):
- path_filter on search: restrict results to a specific file/folder
- list_files action: browse the file/folder structure
"""
def __init__(self, config: Dict):
self.config = config
self.retrieved_docs: List[Dict] = []
self._retriever = None
self._directory_structure: Optional[Dict] = None
self._dir_structure_loaded = False
def _get_retriever(self):
if self._retriever is None:
self._retriever = RetrieverCreator.create_retriever(
self.config.get("retriever_name", "classic"),
source=self.config.get("source", {}),
chat_history=[],
prompt="",
chunks=int(self.config.get("chunks", 2)),
doc_token_limit=int(self.config.get("doc_token_limit", 50000)),
model_id=self.config.get("model_id", "docsgpt-local"),
user_api_key=self.config.get("user_api_key"),
agent_id=self.config.get("agent_id"),
llm_name=self.config.get("llm_name", settings.LLM_PROVIDER),
api_key=self.config.get("api_key", settings.API_KEY),
decoded_token=self.config.get("decoded_token"),
)
return self._retriever
def _get_directory_structure(self) -> Optional[Dict]:
"""Load directory structure from MongoDB for the configured sources."""
if self._dir_structure_loaded:
return self._directory_structure
self._dir_structure_loaded = True
source = self.config.get("source", {})
active_docs = source.get("active_docs", [])
if not active_docs:
return None
try:
from bson.objectid import ObjectId
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sources_collection = db["sources"]
if isinstance(active_docs, str):
active_docs = [active_docs]
merged_structure = {}
for doc_id in active_docs:
try:
source_doc = sources_collection.find_one(
{"_id": ObjectId(doc_id)}
)
if not source_doc:
continue
dir_str = source_doc.get("directory_structure")
if dir_str:
if isinstance(dir_str, str):
dir_str = json.loads(dir_str)
source_name = source_doc.get("name", doc_id)
if len(active_docs) > 1:
merged_structure[source_name] = dir_str
else:
merged_structure = dir_str
except Exception as e:
logger.debug(f"Could not load dir structure for {doc_id}: {e}")
self._directory_structure = merged_structure if merged_structure else None
except Exception as e:
logger.debug(f"Failed to load directory structures: {e}")
return self._directory_structure
def execute_action(self, action_name: str, **kwargs):
if action_name == "search":
return self._execute_search(**kwargs)
elif action_name == "list_files":
return self._execute_list_files(**kwargs)
return f"Unknown action: {action_name}"
def _execute_search(self, **kwargs) -> str:
query = kwargs.get("query", "")
path_filter = kwargs.get("path_filter", "")
if not query:
return "Error: 'query' parameter is required."
try:
retriever = self._get_retriever()
docs = retriever.search(query)
except Exception as e:
logger.error(f"Internal search failed: {e}", exc_info=True)
return "Search failed: an internal error occurred."
if not docs:
return "No documents found matching your query."
# Apply path filter if specified
if path_filter:
path_lower = path_filter.lower()
docs = [
d
for d in docs
if path_lower in d.get("source", "").lower()
or path_lower in d.get("filename", "").lower()
or path_lower in d.get("title", "").lower()
]
if not docs:
return f"No documents found matching query '{query}' in path '{path_filter}'."
# Accumulate for source tracking
for doc in docs:
if doc not in self.retrieved_docs:
self.retrieved_docs.append(doc)
# Format results for the LLM
formatted = []
for i, doc in enumerate(docs, 1):
title = doc.get("title", "Untitled")
text = doc.get("text", "")
source = doc.get("source", "Unknown")
filename = doc.get("filename", "")
header = filename or title
formatted.append(f"[{i}] {header} (source: {source})\n{text}")
return "\n\n---\n\n".join(formatted)
def _execute_list_files(self, **kwargs) -> str:
path = kwargs.get("path", "")
dir_structure = self._get_directory_structure()
if not dir_structure:
return "No file structure available for the current sources."
# Navigate to the requested path
current = dir_structure
if path:
for part in path.strip("/").split("/"):
if not part:
continue
if isinstance(current, dict) and part in current:
current = current[part]
else:
return f"Path '{path}' not found in the file structure."
# Format the structure for the LLM
return self._format_structure(current, path or "/")
def _format_structure(self, node: Dict, current_path: str) -> str:
if not isinstance(node, dict):
return f"'{current_path}' is a file, not a directory."
lines = [f"File structure at '{current_path}':\n"]
folders = []
files = []
for name, value in sorted(node.items()):
if isinstance(value, dict):
# Check if it's a file metadata dict or a folder
if "type" in value or "size_bytes" in value or "token_count" in value:
# It's a file with metadata
size = value.get("token_count", "")
ftype = value.get("type", "")
info_parts = []
if ftype:
info_parts.append(ftype)
if size:
info_parts.append(f"{size} tokens")
info = f" ({', '.join(info_parts)})" if info_parts else ""
files.append(f" {name}{info}")
else:
# It's a folder
count = self._count_files(value)
folders.append(f" {name}/ ({count} items)")
else:
files.append(f" {name}")
if folders:
lines.append("Folders:")
lines.extend(folders)
if files:
lines.append("Files:")
lines.extend(files)
if not folders and not files:
lines.append(" (empty)")
return "\n".join(lines)
def _count_files(self, node: Dict) -> int:
count = 0
for value in node.values():
if isinstance(value, dict):
if "type" in value or "size_bytes" in value or "token_count" in value:
count += 1
else:
count += self._count_files(value)
else:
count += 1
return count
def get_actions_metadata(self):
actions = [
{
"name": "search",
"description": (
"Search the user's uploaded documents and knowledge base. "
"Use this to find relevant information before answering questions. "
"You can call this multiple times with different queries."
),
"parameters": {
"properties": {
"query": {
"type": "string",
"description": "The search query. Be specific and focused.",
"filled_by_llm": True,
"required": True,
},
}
},
}
]
# Add path_filter and list_files only if directory structure exists
has_structure = self.config.get("has_directory_structure", False)
if has_structure:
actions[0]["parameters"]["properties"]["path_filter"] = {
"type": "string",
"description": (
"Optional: filter results to a specific file or folder path. "
"Use list_files first to see available paths."
),
"filled_by_llm": True,
"required": False,
}
actions.append(
{
"name": "list_files",
"description": (
"Browse the file and folder structure of the knowledge base. "
"Use this to see what files are available before searching. "
"Optionally provide a path to browse a specific folder."
),
"parameters": {
"properties": {
"path": {
"type": "string",
"description": "Optional: folder path to browse. Leave empty for root.",
"filled_by_llm": True,
"required": False,
}
}
},
}
)
return actions
def get_config_requirements(self):
return {}
# Constants for building synthetic tools_dict entries
INTERNAL_TOOL_ID = "internal"
def build_internal_tool_entry(has_directory_structure: bool = False) -> Dict:
"""Build the tools_dict entry for InternalSearchTool.
Dynamically includes list_files and path_filter based on
whether the sources have directory structure.
"""
search_params = {
"properties": {
"query": {
"type": "string",
"description": "The search query. Be specific and focused.",
"filled_by_llm": True,
"required": True,
}
}
}
actions = [
{
"name": "search",
"description": (
"Search the user's uploaded documents and knowledge base. "
"Use this to find relevant information before answering questions. "
"You can call this multiple times with different queries."
),
"active": True,
"parameters": search_params,
}
]
if has_directory_structure:
search_params["properties"]["path_filter"] = {
"type": "string",
"description": (
"Optional: filter results to a specific file or folder path. "
"Use list_files first to see available paths."
),
"filled_by_llm": True,
"required": False,
}
actions.append(
{
"name": "list_files",
"description": (
"Browse the file and folder structure of the knowledge base. "
"Use this to see what files are available before searching. "
"Optionally provide a path to browse a specific folder."
),
"active": True,
"parameters": {
"properties": {
"path": {
"type": "string",
"description": "Optional: folder path to browse. Leave empty for root.",
"filled_by_llm": True,
"required": False,
}
}
},
}
)
return {"name": "internal_search", "actions": actions}
# Keep backward compat
INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False)
def sources_have_directory_structure(source: Dict) -> bool:
"""Check if any of the active sources have directory_structure in MongoDB."""
active_docs = source.get("active_docs", [])
if not active_docs:
return False
try:
from bson.objectid import ObjectId
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sources_collection = db["sources"]
if isinstance(active_docs, str):
active_docs = [active_docs]
for doc_id in active_docs:
try:
source_doc = sources_collection.find_one(
{"_id": ObjectId(doc_id)},
{"directory_structure": 1},
)
if source_doc and source_doc.get("directory_structure"):
return True
except Exception:
continue
except Exception as e:
logger.debug(f"Could not check directory structure: {e}")
return False
def add_internal_search_tool(tools_dict: Dict, retriever_config: Dict) -> None:
"""Add the internal search tool to tools_dict if sources are configured.
Shared by AgenticAgent and ResearchAgent to avoid duplicate setup logic.
Mutates tools_dict in place.
"""
source = retriever_config.get("source", {})
has_sources = bool(source.get("active_docs"))
if not retriever_config or not has_sources:
return
has_dir = sources_have_directory_structure(source)
internal_entry = build_internal_tool_entry(has_directory_structure=has_dir)
internal_entry["config"] = build_internal_tool_config(
**retriever_config,
has_directory_structure=has_dir,
)
tools_dict[INTERNAL_TOOL_ID] = internal_entry
def build_internal_tool_config(
source: Dict,
retriever_name: str = "classic",
chunks: int = 2,
doc_token_limit: int = 50000,
model_id: str = "docsgpt-local",
user_api_key: Optional[str] = None,
agent_id: Optional[str] = None,
llm_name: str = None,
api_key: str = None,
decoded_token: Optional[Dict] = None,
has_directory_structure: bool = False,
) -> Dict:
"""Build the config dict for InternalSearchTool."""
return {
"source": source,
"retriever_name": retriever_name,
"chunks": chunks,
"doc_token_limit": doc_token_limit,
"model_id": model_id,
"user_api_key": user_api_key,
"agent_id": agent_id,
"llm_name": llm_name or settings.LLM_PROVIDER,
"api_key": api_key or settings.API_KEY,
"decoded_token": decoded_token,
"has_directory_structure": has_directory_structure,
}

View File

@@ -1,20 +1,12 @@
import asyncio
import base64
import concurrent.futures
import json
import logging
import time
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
from fastmcp import Client
from fastmcp.client.auth import BearerAuth
from fastmcp.client.transports import (
@@ -24,10 +16,18 @@ from fastmcp.client.transports import (
)
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
from pydantic import AnyHttpUrl, ValidationError
from redis import Redis
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
@@ -56,6 +56,7 @@ class MCPTool(Tool):
- args: Arguments for STDIO transport
- oauth_scopes: OAuth scopes for oauth auth type
- oauth_client_name: OAuth client name for oauth auth type
- query_mode: If True, use non-interactive OAuth (fail-fast on 401)
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
"""
self.config = config
@@ -76,23 +77,40 @@ class MCPTool(Tool):
self.oauth_scopes = config.get("oauth_scopes", [])
self.oauth_task_id = config.get("oauth_task_id", None)
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))
self.available_tools = []
self._cache_key = self._generate_cache_key()
self._client = None
# Only validate and setup if server_url is provided and not OAuth
self.query_mode = config.get("query_mode", False)
if self.server_url and self.auth_type != "oauth":
self._setup_client()
def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
if configured_redirect_uri:
return configured_redirect_uri.rstrip("/")
explicit = getattr(settings, "MCP_OAUTH_REDIRECT_URI", None)
if explicit:
return explicit.rstrip("/")
connector_base = getattr(settings, "CONNECTOR_REDIRECT_BASE_URI", None)
if connector_base:
parsed = urlparse(connector_base)
if parsed.scheme and parsed.netloc:
return f"{parsed.scheme}://{parsed.netloc}/api/mcp_server/callback"
return f"{settings.API_URL.rstrip('/')}/api/mcp_server/callback"
def _generate_cache_key(self) -> str:
"""Generate a unique cache key for this MCP server configuration."""
auth_key = ""
if self.auth_type == "oauth":
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
auth_key = (
f"oauth:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
)
elif self.auth_type in ["bearer"]:
token = self.auth_credentials.get(
"bearer_token", ""
@@ -109,11 +127,10 @@ class MCPTool(Tool):
return f"{self.server_url}#{self.transport_type}#{auth_key}"
def _setup_client(self):
"""Setup FastMCP client with proper transport and authentication."""
global _mcp_clients_cache
if self._cache_key in _mcp_clients_cache:
cached_data = _mcp_clients_cache[self._cache_key]
if time.time() - cached_data["created_at"] < 1800:
if time.time() - cached_data["created_at"] < 300:
self._client = cached_data["client"]
return
else:
@@ -123,15 +140,25 @@ class MCPTool(Tool):
if self.auth_type == "oauth":
redis_client = get_redis_instance()
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
db=db,
user_id=self.user_id,
)
if self.query_mode:
auth = NonInteractiveOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
db=db,
user_id=self.user_id,
)
else:
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
db=db,
user_id=self.user_id,
)
elif self.auth_type == "bearer":
token = self.auth_credentials.get(
"bearer_token", ""
@@ -169,6 +196,8 @@ class MCPTool(Tool):
transport_type = "http"
else:
transport_type = self.transport_type
if transport_type == "stdio":
raise ValueError("STDIO transport is disabled")
if transport_type == "sse":
headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"})
return SSETransport(url=self.server_url, headers=headers)
@@ -231,38 +260,53 @@ class MCPTool(Tool):
else:
raise Exception(f"Unknown operation: {operation}")
_ERROR_MAP = [
(concurrent.futures.TimeoutError, lambda op, t, _: f"Timed out after {t}s"),
(ConnectionRefusedError, lambda *_: "Connection refused"),
]
_ERROR_PATTERNS = {
("403", "Forbidden"): "Access denied (403 Forbidden)",
("401", "Unauthorized"): "Authentication failed (401 Unauthorized)",
("ECONNREFUSED",): "Connection refused",
("SSL", "certificate"): "SSL/TLS error",
}
def _run_async_operation(self, operation: str, *args, **kwargs):
"""Run async operation in sync context."""
try:
try:
loop = asyncio.get_running_loop()
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
new_loop.close()
asyncio.get_running_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
future = executor.submit(
self._run_in_new_loop, operation, *args, **kwargs
)
return future.result(timeout=self.timeout)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
loop.close()
return self._run_in_new_loop(operation, *args, **kwargs)
except Exception as e:
print(f"Error occurred while running async operation: {e}")
raise
raise self._map_error(operation, e) from e
raise self._map_error(operation, e) from e
def _run_in_new_loop(self, operation, *args, **kwargs):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
loop.close()
def _map_error(self, operation: str, exc: Exception) -> Exception:
for exc_type, msg_fn in self._ERROR_MAP:
if isinstance(exc, exc_type):
return Exception(msg_fn(operation, self.timeout, exc))
error_msg = str(exc)
for patterns, friendly in self._ERROR_PATTERNS.items():
if any(p.lower() in error_msg.lower() for p in patterns):
return Exception(friendly)
logger.error("MCP %s failed: %s", operation, exc)
return exc
def discover_tools(self) -> List[Dict]:
"""
@@ -283,16 +327,6 @@ class MCPTool(Tool):
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
def execute_action(self, action_name: str, **kwargs) -> Any:
"""
Execute an action on the remote MCP server using FastMCP.
Args:
action_name: Name of the action to execute
**kwargs: Parameters for the action
Returns:
Result from the MCP server
"""
if not self.server_url:
raise Exception("No MCP server configured")
if not self._client:
@@ -308,7 +342,37 @@ class MCPTool(Tool):
)
return self._format_result(result)
except Exception as e:
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
error_msg = str(e)
lower_msg = error_msg.lower()
is_auth_error = (
"401" in error_msg
or "unauthorized" in lower_msg
or "session expired" in lower_msg
or "re-authorize" in lower_msg
)
if is_auth_error:
if self.auth_type == "oauth":
raise Exception(
f"Action '{action_name}' failed: OAuth session expired. "
"Please re-authorize this MCP server in tool settings."
) from e
global _mcp_clients_cache
_mcp_clients_cache.pop(self._cache_key, None)
self._client = None
self._setup_client()
try:
result = self._run_async_operation(
"call_tool", action_name, **cleaned_kwargs
)
return self._format_result(result)
except Exception as retry_e:
raise Exception(
f"Action '{action_name}' failed after re-auth attempt: {retry_e}. "
"Your credentials may have expired — please re-authorize in tool settings."
) from retry_e
raise Exception(
f"Failed to execute action '{action_name}': {error_msg}"
) from e
def _format_result(self, result) -> Dict:
"""Format FastMCP result to match expected format."""
@@ -331,23 +395,35 @@ class MCPTool(Tool):
return result
def test_connection(self) -> Dict:
"""
Test the connection to the MCP server and validate functionality.
Returns:
Dictionary with connection test results including tool count
"""
if not self.server_url:
return {
"success": False,
"message": "No MCP server URL configured",
"message": "No server URL configured",
"tools_count": 0,
}
try:
parsed = urlparse(self.server_url)
if parsed.scheme not in ("http", "https"):
return {
"success": False,
"message": f"Invalid URL scheme '{parsed.scheme}' — use http:// or https://",
"tools_count": 0,
}
except Exception:
return {
"success": False,
"message": "Invalid URL format",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": "ConfigurationError",
}
if not self._client:
self._setup_client()
try:
self._setup_client()
except Exception as e:
return {
"success": False,
"message": f"Client init failed: {str(e)}",
"tools_count": 0,
}
try:
if self.auth_type == "oauth":
return self._test_oauth_connection()
@@ -358,56 +434,94 @@ class MCPTool(Tool):
"success": False,
"message": f"Connection failed: {str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
def _test_regular_connection(self) -> Dict:
"""Test connection for non-OAuth auth types."""
ping_ok = False
ping_error = None
try:
self._run_async_operation("ping")
ping_success = True
except Exception:
ping_success = False
tools = self.discover_tools()
ping_ok = True
except Exception as e:
ping_error = str(e)
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
if not ping_success:
message += " (Ping not supported, but tool discovery worked)"
return {
"success": True,
"message": message,
"tools_count": len(tools),
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"ping_supported": ping_success,
"tools": [tool.get("name", "unknown") for tool in tools],
}
def _test_oauth_connection(self) -> Dict:
"""Test connection for OAuth auth type with proper async handling."""
try:
task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
if not task:
raise Exception("Failed to start OAuth authentication")
return {
"success": True,
"requires_oauth": True,
"task_id": task.id,
"status": "pending",
"message": "OAuth flow started",
}
tools = self.discover_tools()
except Exception as e:
return {
"success": False,
"message": f"OAuth connection failed: {str(e)}",
"message": f"Connection failed: {ping_error or str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
if not tools and not ping_ok:
return {
"success": False,
"message": f"Connection failed: {ping_error or 'No tools found'}",
"tools_count": 0,
}
return {
"success": True,
"message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
"tools_count": len(tools),
"tools": [
{
"name": tool.get("name", "unknown"),
"description": tool.get("description", ""),
}
for tool in tools
],
}
def _test_oauth_connection(self) -> Dict:
storage = DBTokenStorage(
server_url=self.server_url, user_id=self.user_id, db_client=db
)
loop = asyncio.new_event_loop()
try:
tokens = loop.run_until_complete(storage.get_tokens())
finally:
loop.close()
if tokens and tokens.access_token:
self.query_mode = True
_mcp_clients_cache.pop(self._cache_key, None)
self._client = None
self._setup_client()
try:
tools = self.discover_tools()
return {
"success": True,
"message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
"tools_count": len(tools),
"tools": [
{
"name": t.get("name", "unknown"),
"description": t.get("description", ""),
}
for t in tools
],
}
except Exception as e:
logger.warning("OAuth token validation failed: %s", e)
_mcp_clients_cache.pop(self._cache_key, None)
self._client = None
return self._start_oauth_task()
def _start_oauth_task(self) -> Dict:
task_config = self.config.copy()
task_config.pop("query_mode", None)
result = mcp_oauth_task.delay(task_config, self.user_id)
return {
"success": False,
"requires_oauth": True,
"task_id": result.id,
"message": "OAuth authorization required.",
"tools_count": 0,
}
def get_actions_metadata(self) -> List[Dict]:
"""
Get metadata for all available actions.
@@ -453,107 +567,88 @@ class MCPTool(Tool):
return actions
def get_config_requirements(self) -> Dict:
"""Get configuration requirements for the MCP tool."""
return {
"server_url": {
"type": "string",
"description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)",
"label": "Server URL",
"description": "URL of the remote MCP server",
"required": True,
},
"transport_type": {
"type": "string",
"description": "Transport type for connection",
"enum": ["auto", "sse", "http", "stdio"],
"default": "auto",
"required": False,
"help": {
"auto": "Automatically detect best transport",
"sse": "Server-Sent Events (for real-time streaming)",
"http": "HTTP streaming (recommended for production)",
"stdio": "Standard I/O (for local servers)",
},
"secret": False,
"order": 1,
},
"auth_type": {
"type": "string",
"description": "Authentication type",
"label": "Authentication Type",
"description": "Authentication method for the MCP server",
"enum": ["none", "bearer", "oauth", "api_key", "basic"],
"default": "none",
"required": True,
"help": {
"none": "No authentication",
"bearer": "Bearer token authentication",
"oauth": "OAuth 2.1 authentication (with frontend integration)",
"api_key": "API key authentication",
"basic": "Basic authentication",
},
"secret": False,
"order": 2,
},
"auth_credentials": {
"type": "object",
"description": "Authentication credentials (varies by auth_type)",
"api_key": {
"type": "string",
"label": "API Key",
"description": "API key for authentication",
"required": False,
"properties": {
"bearer_token": {
"type": "string",
"description": "Bearer token for bearer auth",
},
"access_token": {
"type": "string",
"description": "Access token for OAuth (if pre-obtained)",
},
"api_key": {
"type": "string",
"description": "API key for api_key auth",
},
"api_key_header": {
"type": "string",
"description": "Header name for API key (default: X-API-Key)",
},
"username": {
"type": "string",
"description": "Username for basic auth",
},
"password": {
"type": "string",
"description": "Password for basic auth",
},
},
"secret": True,
"order": 3,
"depends_on": {"auth_type": "api_key"},
},
"api_key_header": {
"type": "string",
"label": "API Key Header",
"description": "Header name for API key (default: X-API-Key)",
"default": "X-API-Key",
"required": False,
"secret": False,
"order": 4,
"depends_on": {"auth_type": "api_key"},
},
"bearer_token": {
"type": "string",
"label": "Bearer Token",
"description": "Bearer token for authentication",
"required": False,
"secret": True,
"order": 3,
"depends_on": {"auth_type": "bearer"},
},
"username": {
"type": "string",
"label": "Username",
"description": "Username for basic authentication",
"required": False,
"secret": False,
"order": 3,
"depends_on": {"auth_type": "basic"},
},
"password": {
"type": "string",
"label": "Password",
"description": "Password for basic authentication",
"required": False,
"secret": True,
"order": 4,
"depends_on": {"auth_type": "basic"},
},
"oauth_scopes": {
"type": "array",
"description": "OAuth scopes to request (for oauth auth_type)",
"items": {"type": "string"},
"required": False,
"default": [],
},
"oauth_client_name": {
"type": "string",
"description": "Client name for OAuth registration (for oauth auth_type)",
"default": "DocsGPT-MCP",
"required": False,
},
"headers": {
"type": "object",
"description": "Custom headers to send with requests",
"label": "OAuth Scopes",
"description": "Comma-separated OAuth scopes to request",
"required": False,
"secret": False,
"order": 3,
"depends_on": {"auth_type": "oauth"},
},
"timeout": {
"type": "integer",
"description": "Request timeout in seconds",
"type": "number",
"label": "Timeout (seconds)",
"description": "Request timeout in seconds (1-300)",
"default": 30,
"minimum": 1,
"maximum": 300,
"required": False,
},
"command": {
"type": "string",
"description": "Command to run for STDIO transport (e.g., 'python')",
"required": False,
},
"args": {
"type": "array",
"description": "Arguments for STDIO command",
"items": {"type": "string"},
"required": False,
"secret": False,
"order": 10,
},
}
@@ -575,23 +670,8 @@ class DocsGPTOAuth(OAuthClientProvider):
user_id=None,
db=None,
additional_client_metadata: dict[str, Any] | None = None,
skip_redirect_validation: bool = False,
):
"""
Initialize custom OAuth client provider for DocsGPT.
Args:
mcp_url: Full URL to the MCP endpoint
redirect_uri: Custom redirect URI for DocsGPT frontend
redis_client: Redis client for storing auth state
redis_prefix: Prefix for Redis keys
task_id: Task ID for tracking auth status
scopes: OAuth scopes to request
client_name: Name for this client during registration
user_id: User ID for token storage
db: Database instance for token storage
additional_client_metadata: Extra fields for OAuthClientMetadata
"""
self.redirect_uri = redirect_uri
self.redis_client = redis_client
self.redis_prefix = redis_prefix
@@ -614,7 +694,10 @@ class DocsGPTOAuth(OAuthClientProvider):
)
storage = DBTokenStorage(
server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
server_url=self.server_base_url,
user_id=self.user_id,
db_client=self.db,
expected_redirect_uri=None if skip_redirect_validation else redirect_uri,
)
super().__init__(
@@ -646,22 +729,20 @@ class DocsGPTOAuth(OAuthClientProvider):
async def redirect_handler(self, authorization_url: str) -> None:
"""Store auth URL and state in Redis for frontend to use."""
auth_url, state = self._process_auth_url(authorization_url)
logging.info(
"[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
)
logger.info("Processed auth_url: %s, state: %s", auth_url, state)
self.auth_url = auth_url
self.extracted_state = state
if self.redis_client and self.extracted_state:
key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
self.redis_client.setex(key, 600, auth_url)
logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)
logger.info("Stored auth_url in Redis: %s", key)
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "requires_redirect",
"message": "OAuth authorization required",
"message": "Authorization required",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
@@ -681,7 +762,7 @@ class DocsGPTOAuth(OAuthClientProvider):
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "awaiting_callback",
"message": "Waiting for OAuth callback...",
"message": "Waiting for authorization...",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
@@ -706,7 +787,7 @@ class DocsGPTOAuth(OAuthClientProvider):
if self.task_id:
status_data = {
"status": "callback_received",
"message": "OAuth callback received, completing authentication...",
"message": "Completing authentication...",
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
@@ -726,14 +807,44 @@ class DocsGPTOAuth(OAuthClientProvider):
await asyncio.sleep(poll_interval)
self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
raise Exception("OAuth callback timeout: no code received within 5 minutes")
raise Exception("OAuth timeout: no code received within 5 minutes")
class NonInteractiveOAuth(DocsGPTOAuth):
"""OAuth provider that fails fast on 401 instead of starting interactive auth.
Used during query execution to prevent the streaming response from blocking
while waiting for user authorization that will never come.
"""
def __init__(self, **kwargs):
kwargs.setdefault("task_id", None)
kwargs["skip_redirect_validation"] = True
super().__init__(**kwargs)
async def redirect_handler(self, authorization_url: str) -> None:
raise Exception(
"OAuth session expired — please re-authorize this MCP server in tool settings."
)
async def callback_handler(self) -> tuple[str, str | None]:
raise Exception(
"OAuth session expired — please re-authorize this MCP server in tool settings."
)
class DBTokenStorage(TokenStorage):
def __init__(self, server_url: str, user_id: str, db_client):
def __init__(
self,
server_url: str,
user_id: str,
db_client,
expected_redirect_uri: Optional[str] = None,
):
self.server_url = server_url
self.user_id = user_id
self.db_client = db_client
self.expected_redirect_uri = expected_redirect_uri
self.collection = db_client["connector_sessions"]
@staticmethod
@@ -752,10 +863,9 @@ class DBTokenStorage(TokenStorage):
if not doc or "tokens" not in doc:
return None
try:
tokens = OAuthToken.model_validate(doc["tokens"])
return tokens
return OAuthToken.model_validate(doc["tokens"])
except ValidationError as e:
logging.error(f"Could not load tokens: {e}")
logger.error("Could not load tokens: %s", e)
return None
async def set_tokens(self, tokens: OAuthToken) -> None:
@@ -765,28 +875,38 @@ class DBTokenStorage(TokenStorage):
{"$set": {"tokens": tokens.model_dump()}},
True,
)
logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")
logger.info("Saved tokens for %s", self.get_base_url(self.server_url))
async def get_client_info(self) -> OAuthClientInformationFull | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "client_info" not in doc:
logger.debug(
"No client_info in DB for %s", self.get_base_url(self.server_url)
)
return None
try:
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
tokens = await self.get_tokens()
if tokens is None:
logging.debug(
"No tokens found, clearing client info to force fresh registration."
)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$unset": {"client_info": ""}},
)
return None
if self.expected_redirect_uri:
stored_uris = [
str(uri).rstrip("/") for uri in client_info.redirect_uris
]
expected_uri = self.expected_redirect_uri.rstrip("/")
if expected_uri not in stored_uris:
logger.warning(
"Redirect URI mismatch for %s: expected=%s stored=%s — clearing.",
self.get_base_url(self.server_url),
expected_uri,
stored_uris,
)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$unset": {"client_info": "", "tokens": ""}},
)
return None
return client_info
except ValidationError as e:
logging.error(f"Could not load client info: {e}")
logger.error("Could not load client info: %s", e)
return None
def _serialize_client_info(self, info: dict) -> dict:
@@ -802,17 +922,17 @@ class DBTokenStorage(TokenStorage):
{"$set": {"client_info": serialized_info}},
True,
)
logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")
logger.info("Saved client info for %s", self.get_base_url(self.server_url))
async def clear(self) -> None:
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
logger.info("Cleared OAuth cache for %s", self.get_base_url(self.server_url))
@classmethod
async def clear_all(cls, db_client) -> None:
collection = db_client["connector_sessions"]
await asyncio.to_thread(collection.delete_many, {})
logging.info("Cleared all OAuth client cache data.")
logger.info("Cleared all OAuth client cache data.")
class MCPOAuthManager:
@@ -851,7 +971,7 @@ class MCPOAuthManager:
return True
except Exception as e:
logging.error(f"Error handling OAuth callback: {e}")
logger.error("Error handling OAuth callback: %s", e)
return False
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:

View File

@@ -38,6 +38,8 @@ class NotesTool(Tool):
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["notes"]
self._last_artifact_id: Optional[str] = None
# -----------------------------
# Action implementations
# -----------------------------
@@ -54,6 +56,8 @@ class NotesTool(Tool):
if not self.user_id:
return "Error: NotesTool requires a valid user_id."
self._last_artifact_id = None
if action_name == "view":
return self._get_note()
@@ -125,6 +129,9 @@ class NotesTool(Tool):
"""Return configuration requirements (none for now)."""
return {}
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
return self._last_artifact_id
# -----------------------------
# Internal helpers (single-note)
# -----------------------------
@@ -132,17 +139,22 @@ class NotesTool(Tool):
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return str(doc["note"])
def _overwrite_note(self, content: str) -> str:
content = (content or "").strip()
if not content:
return "Note content required."
self.collection.update_one(
result = self.collection.find_one_and_update(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
upsert=True, # ✅ create if missing
upsert=True,
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
return "Note saved."
def _str_replace(self, old_str: str, new_str: str) -> str:
@@ -163,10 +175,13 @@ class NotesTool(Tool):
import re
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
self.collection.update_one(
result = self.collection.find_one_and_update(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
return "Note updated."
def _insert(self, line_number: int, text: str) -> str:
@@ -188,12 +203,21 @@ class NotesTool(Tool):
lines.insert(index, text)
updated_note = "\n".join(lines)
self.collection.update_one(
result = self.collection.find_one_and_update(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
return_document=True,
)
if result and result.get("_id") is not None:
self._last_artifact_id = str(result.get("_id"))
return "Text inserted."
def _delete_note(self) -> str:
res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
return "Note deleted." if res.deleted_count else "No note found to delete."
doc = self.collection.find_one_and_delete(
{"user_id": self.user_id, "tool_id": self.tool_id}
)
if not doc:
return "No note found to delete."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return "Note deleted."

View File

@@ -116,12 +116,13 @@ class NtfyTool(Tool):
]
def get_config_requirements(self):
"""
Specify the configuration requirements.
Returns:
dict: Dictionary describing required config parameters.
"""
return {
"token": {"type": "string", "description": "Access token for authentication"},
"token": {
"type": "string",
"label": "Access Token",
"description": "Ntfy access token for authentication",
"required": True,
"secret": True,
"order": 1,
},
}

View File

@@ -1,6 +1,12 @@
import logging
import psycopg2
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class PostgresTool(Tool):
"""
PostgreSQL Database Tool
@@ -17,17 +23,15 @@ class PostgresTool(Tool):
"postgres_execute_sql": self._execute_sql,
"postgres_get_schema": self._get_schema,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
if action_name not in actions:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _execute_sql(self, sql_query):
"""
Executes an SQL query against the PostgreSQL database using a connection string.
"""
conn = None # Initialize conn to None for error handling
conn = None
try:
conn = psycopg2.connect(self.connection_string)
cur = conn.cursor()
@@ -35,7 +39,9 @@ class PostgresTool(Tool):
conn.commit()
if sql_query.strip().lower().startswith("select"):
column_names = [desc[0] for desc in cur.description] if cur.description else []
column_names = (
[desc[0] for desc in cur.description] if cur.description else []
)
results = []
rows = cur.fetchall()
for row in rows:
@@ -43,7 +49,9 @@ class PostgresTool(Tool):
response_data = {"data": results, "column_names": column_names}
else:
row_count = cur.rowcount
response_data = {"message": f"Query executed successfully, {row_count} rows affected."}
response_data = {
"message": f"Query executed successfully, {row_count} rows affected."
}
cur.close()
return {
@@ -54,26 +62,27 @@ class PostgresTool(Tool):
except psycopg2.Error as e:
error_message = f"Database error: {e}"
print(f"Database error: {e}")
logger.error("PostgreSQL execute_sql error: %s", e)
return {
"status_code": 500,
"message": "Failed to execute SQL query.",
"error": error_message,
}
finally:
if conn: # Ensure connection is closed even if errors occur
if conn:
conn.close()
def _get_schema(self, db_name):
"""
Retrieves the schema of the PostgreSQL database using a connection string.
"""
conn = None # Initialize conn to None for error handling
conn = None
try:
conn = psycopg2.connect(self.connection_string)
cur = conn.cursor()
cur.execute("""
cur.execute(
"""
SELECT
table_name,
column_name,
@@ -87,19 +96,22 @@ class PostgresTool(Tool):
ORDER BY
table_name,
ordinal_position;
""")
"""
)
schema_data = {}
for row in cur.fetchall():
table_name, column_name, data_type, column_default, is_nullable = row
if table_name not in schema_data:
schema_data[table_name] = []
schema_data[table_name].append({
"column_name": column_name,
"data_type": data_type,
"column_default": column_default,
"is_nullable": is_nullable
})
schema_data[table_name].append(
{
"column_name": column_name,
"data_type": data_type,
"column_default": column_default,
"is_nullable": is_nullable,
}
)
cur.close()
return {
@@ -110,14 +122,14 @@ class PostgresTool(Tool):
except psycopg2.Error as e:
error_message = f"Database error: {e}"
print(f"Database error: {e}")
logger.error("PostgreSQL get_schema error: %s", e)
return {
"status_code": 500,
"message": "Failed to retrieve database schema.",
"error": error_message,
}
finally:
if conn: # Ensure connection is closed even if errors occur
if conn:
conn.close()
def get_actions_metadata(self):
@@ -158,6 +170,10 @@ class PostgresTool(Tool):
return {
"token": {
"type": "string",
"description": "PostgreSQL database connection string (e.g., 'postgresql://user:password@host:port/dbname')",
"label": "Connection String",
"description": "PostgreSQL database connection string",
"required": True,
"secret": True,
"order": 1,
},
}
}

View File

@@ -1,7 +1,7 @@
import requests
from markdownify import markdownify
from application.agents.tools.base import Tool
from urllib.parse import urlparse
from application.core.url_validation import validate_url, SSRFError
class ReadWebpageTool(Tool):
"""
@@ -31,11 +31,12 @@ class ReadWebpageTool(Tool):
if not url:
return "Error: URL parameter is missing."
# Ensure the URL has a scheme (if not, default to http)
parsed_url = urlparse(url)
if not parsed_url.scheme:
url = "http://" + url
# Validate URL to prevent SSRF attacks
try:
url = validate_url(url)
except SSRFError as e:
return f"Error: URL validation failed - {e}"
try:
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)

View File

@@ -0,0 +1,342 @@
"""
API Specification Parser
Parses OpenAPI 3.x and Swagger 2.0 specifications and converts them
to API Tool action definitions for use in DocsGPT.
"""
import json
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
import yaml
logger = logging.getLogger(__name__)
SUPPORTED_METHODS = frozenset(
{"get", "post", "put", "delete", "patch", "head", "options"}
)
def parse_spec(spec_content: str) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
"""
Parse an API specification and convert operations to action definitions.
Supports OpenAPI 3.x and Swagger 2.0 formats in JSON or YAML.
Args:
spec_content: Raw specification content as string
Returns:
Tuple of (metadata dict, list of action dicts)
Raises:
ValueError: If the spec is invalid or uses an unsupported format
"""
spec = _load_spec(spec_content)
_validate_spec(spec)
is_swagger = "swagger" in spec
metadata = _extract_metadata(spec, is_swagger)
actions = _extract_actions(spec, is_swagger)
return metadata, actions
def _load_spec(content: str) -> Dict[str, Any]:
"""Parse spec content from JSON or YAML string."""
content = content.strip()
if not content:
raise ValueError("Empty specification content")
try:
if content.startswith("{"):
return json.loads(content)
return yaml.safe_load(content)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format: {e.msg}")
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML format: {e}")
def _validate_spec(spec: Dict[str, Any]) -> None:
"""Validate spec version and required fields."""
if not isinstance(spec, dict):
raise ValueError("Specification must be a valid object")
openapi_version = spec.get("openapi", "")
swagger_version = spec.get("swagger", "")
if not (openapi_version.startswith("3.") or swagger_version == "2.0"):
raise ValueError(
"Unsupported specification version. Expected OpenAPI 3.x or Swagger 2.0"
)
if "paths" not in spec or not spec["paths"]:
raise ValueError("No API paths defined in the specification")
def _extract_metadata(spec: Dict[str, Any], is_swagger: bool) -> Dict[str, Any]:
"""Extract API metadata from specification."""
info = spec.get("info", {})
base_url = _get_base_url(spec, is_swagger)
return {
"title": info.get("title", "Untitled API"),
"description": (info.get("description", "") or "")[:500],
"version": info.get("version", ""),
"base_url": base_url,
}
def _get_base_url(spec: Dict[str, Any], is_swagger: bool) -> str:
"""Extract base URL from spec (handles both OpenAPI 3.x and Swagger 2.0)."""
if is_swagger:
schemes = spec.get("schemes", ["https"])
host = spec.get("host", "")
base_path = spec.get("basePath", "")
if host:
scheme = schemes[0] if schemes else "https"
return f"{scheme}://{host}{base_path}".rstrip("/")
return ""
servers = spec.get("servers", [])
if servers and isinstance(servers, list) and servers[0].get("url"):
return servers[0]["url"].rstrip("/")
return ""
def _extract_actions(spec: Dict[str, Any], is_swagger: bool) -> List[Dict[str, Any]]:
"""Extract all API operations as action definitions."""
actions = []
paths = spec.get("paths", {})
base_url = _get_base_url(spec, is_swagger)
components = spec.get("components", {})
definitions = spec.get("definitions", {})
for path, path_item in paths.items():
if not isinstance(path_item, dict):
continue
path_params = path_item.get("parameters", [])
for method in SUPPORTED_METHODS:
operation = path_item.get(method)
if not isinstance(operation, dict):
continue
try:
action = _build_action(
path=path,
method=method,
operation=operation,
path_params=path_params,
base_url=base_url,
components=components,
definitions=definitions,
is_swagger=is_swagger,
)
actions.append(action)
except Exception as e:
logger.warning(
f"Failed to parse operation {method.upper()} {path}: {e}"
)
continue
return actions
def _build_action(
path: str,
method: str,
operation: Dict[str, Any],
path_params: List[Dict],
base_url: str,
components: Dict[str, Any],
definitions: Dict[str, Any],
is_swagger: bool,
) -> Dict[str, Any]:
"""Build a single action from an API operation."""
action_name = _generate_action_name(operation, method, path)
full_url = f"{base_url}{path}" if base_url else path
all_params = path_params + operation.get("parameters", [])
query_params, headers = _categorize_parameters(all_params, components, definitions)
body, body_content_type = _extract_request_body(
operation, components, definitions, is_swagger
)
description = operation.get("summary", "") or operation.get("description", "")
return {
"name": action_name,
"url": full_url,
"method": method.upper(),
"description": (description or "")[:500],
"query_params": {"type": "object", "properties": query_params},
"headers": {"type": "object", "properties": headers},
"body": {"type": "object", "properties": body},
"body_content_type": body_content_type,
"active": True,
}
def _generate_action_name(operation: Dict[str, Any], method: str, path: str) -> str:
"""Generate a valid action name from operationId or method+path."""
if operation.get("operationId"):
name = operation["operationId"]
else:
path_slug = re.sub(r"[{}]", "", path)
path_slug = re.sub(r"[^a-zA-Z0-9]", "_", path_slug)
path_slug = re.sub(r"_+", "_", path_slug).strip("_")
name = f"{method}_{path_slug}"
name = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
return name[:64]
def _categorize_parameters(
parameters: List[Dict],
components: Dict[str, Any],
definitions: Dict[str, Any],
) -> Tuple[Dict, Dict]:
"""Categorize parameters into query params and headers."""
query_params = {}
headers = {}
for param in parameters:
resolved = _resolve_ref(param, components, definitions)
if not resolved or "name" not in resolved:
continue
location = resolved.get("in", "query")
prop = _param_to_property(resolved)
if location in ("query", "path"):
query_params[resolved["name"]] = prop
elif location == "header":
headers[resolved["name"]] = prop
return query_params, headers
def _param_to_property(param: Dict) -> Dict[str, Any]:
"""Convert an API parameter to an action property definition."""
schema = param.get("schema", {})
param_type = schema.get("type", param.get("type", "string"))
mapped_type = "integer" if param_type in ("integer", "number") else "string"
return {
"type": mapped_type,
"description": (param.get("description", "") or "")[:200],
"value": "",
"filled_by_llm": param.get("required", False),
"required": param.get("required", False),
}
def _extract_request_body(
operation: Dict[str, Any],
components: Dict[str, Any],
definitions: Dict[str, Any],
is_swagger: bool,
) -> Tuple[Dict, str]:
"""Extract request body schema and content type."""
content_types = [
"application/json",
"application/x-www-form-urlencoded",
"multipart/form-data",
"text/plain",
"application/xml",
]
if is_swagger:
consumes = operation.get("consumes", [])
body_param = next(
(p for p in operation.get("parameters", []) if p.get("in") == "body"), None
)
if not body_param:
return {}, "application/json"
selected_type = consumes[0] if consumes else "application/json"
schema = body_param.get("schema", {})
else:
request_body = operation.get("requestBody", {})
if not request_body:
return {}, "application/json"
request_body = _resolve_ref(request_body, components, definitions)
content = request_body.get("content", {})
selected_type = "application/json"
schema = {}
for ct in content_types:
if ct in content:
selected_type = ct
schema = content[ct].get("schema", {})
break
if not schema and content:
first_type = next(iter(content))
selected_type = first_type
schema = content[first_type].get("schema", {})
properties = _schema_to_properties(schema, components, definitions)
return properties, selected_type
def _schema_to_properties(
schema: Dict,
components: Dict[str, Any],
definitions: Dict[str, Any],
depth: int = 0,
) -> Dict[str, Any]:
"""Convert schema to action body properties (limited depth to prevent recursion)."""
if depth > 3:
return {}
schema = _resolve_ref(schema, components, definitions)
if not schema or not isinstance(schema, dict):
return {}
properties = {}
schema_type = schema.get("type", "object")
if schema_type == "object":
required_fields = set(schema.get("required", []))
for prop_name, prop_schema in schema.get("properties", {}).items():
resolved = _resolve_ref(prop_schema, components, definitions)
if not isinstance(resolved, dict):
continue
prop_type = resolved.get("type", "string")
mapped_type = "integer" if prop_type in ("integer", "number") else "string"
properties[prop_name] = {
"type": mapped_type,
"description": (resolved.get("description", "") or "")[:200],
"value": "",
"filled_by_llm": prop_name in required_fields,
"required": prop_name in required_fields,
}
return properties
def _resolve_ref(
obj: Any,
components: Dict[str, Any],
definitions: Dict[str, Any],
) -> Optional[Dict]:
"""Resolve $ref references in the specification."""
if not isinstance(obj, dict):
return obj if isinstance(obj, dict) else None
if "$ref" not in obj:
return obj
ref_path = obj["$ref"]
if ref_path.startswith("#/components/"):
parts = ref_path.replace("#/components/", "").split("/")
return _traverse_path(components, parts)
elif ref_path.startswith("#/definitions/"):
parts = ref_path.replace("#/definitions/", "").split("/")
return _traverse_path(definitions, parts)
logger.debug(f"Unsupported ref path: {ref_path}")
return None
def _traverse_path(obj: Dict, parts: List[str]) -> Optional[Dict]:
"""Traverse a nested dictionary using path parts."""
try:
for part in parts:
obj = obj[part]
return obj if isinstance(obj, dict) else None
except (KeyError, TypeError):
return None

View File

@@ -1,6 +1,11 @@
import logging
import requests
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class TelegramTool(Tool):
"""
@@ -18,21 +23,19 @@ class TelegramTool(Tool):
"telegram_send_message": self._send_message,
"telegram_send_image": self._send_image,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
if action_name not in actions:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _send_message(self, text, chat_id):
print(f"Sending message: {text}")
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
payload = {"chat_id": chat_id, "text": text}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Message sent"}
def _send_image(self, image_url, chat_id):
print(f"Sending image: {image_url}")
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
payload = {"chat_id": chat_id, "photo": image_url}
response = requests.post(url, data=payload)
@@ -82,5 +85,12 @@ class TelegramTool(Tool):
def get_config_requirements(self):
return {
"token": {"type": "string", "description": "Bot token for authentication"},
"token": {
"type": "string",
"label": "Bot Token",
"description": "Telegram bot token for authentication",
"required": True,
"secret": True,
"order": 1,
},
}

View File

@@ -0,0 +1,68 @@
from application.agents.tools.base import Tool
THINK_TOOL_ID = "think"
THINK_TOOL_ENTRY = {
"name": "think",
"actions": [
{
"name": "reason",
"description": (
"Use this tool to think through your reasoning step by step "
"before deciding on your next action. Always reason before "
"searching or answering."
),
"active": True,
"parameters": {
"properties": {
"reasoning": {
"type": "string",
"description": "Your step-by-step reasoning and analysis",
"filled_by_llm": True,
"required": True,
}
}
},
}
],
}
class ThinkTool(Tool):
"""Pseudo-tool that captures chain-of-thought reasoning.
Returns a short acknowledgment so the LLM can continue.
The reasoning content is captured in tool_call data for transparency.
"""
def __init__(self, config=None):
pass
def execute_action(self, action_name: str, **kwargs):
return "Continue."
def get_actions_metadata(self):
return [
{
"name": "reason",
"description": (
"Use this tool to think through your reasoning step by step "
"before deciding on your next action. Always reason before "
"searching or answering."
),
"parameters": {
"properties": {
"reasoning": {
"type": "string",
"description": "Your step-by-step reasoning and analysis",
"filled_by_llm": True,
"required": True,
}
}
},
}
]
def get_config_requirements(self):
return {}

View File

@@ -38,6 +38,8 @@ class TodoListTool(Tool):
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["todos"]
self._last_artifact_id: Optional[str] = None
# -----------------------------
# Action implementations
# -----------------------------
@@ -54,6 +56,8 @@ class TodoListTool(Tool):
if not self.user_id:
return "Error: TodoListTool requires a valid user_id."
self._last_artifact_id = None
if action_name == "list":
return self._list()
@@ -165,6 +169,9 @@ class TodoListTool(Tool):
"""Return configuration requirements."""
return {}
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
return self._last_artifact_id
# -----------------------------
# Internal helpers
# -----------------------------
@@ -190,11 +197,8 @@ class TodoListTool(Tool):
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
With 5-10 todos max, scanning is negligible.
"""
# Find all todos for this user/tool and get their IDs
todos = list(self.collection.find(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"todo_id": 1}
))
query = {"user_id": self.user_id, "tool_id": self.tool_id}
todos = list(self.collection.find(query, {"todo_id": 1}))
# Find the maximum todo_id
max_id = 0
@@ -207,8 +211,8 @@ class TodoListTool(Tool):
def _list(self) -> str:
"""List all todos for the user."""
cursor = self.collection.find({"user_id": self.user_id, "tool_id": self.tool_id})
todos = list(cursor)
query = {"user_id": self.user_id, "tool_id": self.tool_id}
todos = list(self.collection.find(query))
if not todos:
return "No todos found."
@@ -242,7 +246,10 @@ class TodoListTool(Tool):
"created_at": now,
"updated_at": now,
}
self.collection.insert_one(doc)
insert_result = self.collection.insert_one(doc)
inserted_id = getattr(insert_result, "inserted_id", None) or doc.get("_id")
if inserted_id is not None:
self._last_artifact_id = str(inserted_id)
return f"Todo created with ID {todo_id}: {title}"
def _get(self, todo_id: Optional[Any]) -> str:
@@ -251,15 +258,15 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
doc = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"todo_id": parsed_todo_id
})
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one(query)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
title = doc.get("title", "Untitled")
status = doc.get("status", "open")
@@ -277,14 +284,17 @@ class TodoListTool(Tool):
if not title:
return "Error: Title is required."
result = self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
{"$set": {"title": title, "updated_at": datetime.now()}}
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_update(
query,
{"$set": {"title": title, "updated_at": datetime.now()}},
)
if result.matched_count == 0:
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return f"Todo {parsed_todo_id} updated to: {title}"
def _complete(self, todo_id: Optional[Any]) -> str:
@@ -293,14 +303,17 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
result = self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
{"$set": {"status": "completed", "updated_at": datetime.now()}}
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_update(
query,
{"$set": {"status": "completed", "updated_at": datetime.now()}},
)
if result.matched_count == 0:
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return f"Todo {parsed_todo_id} marked as completed."
def _delete(self, todo_id: Optional[Any]) -> str:
@@ -309,13 +322,12 @@ class TodoListTool(Tool):
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
result = self.collection.delete_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"todo_id": parsed_todo_id
})
if result.deleted_count == 0:
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
doc = self.collection.find_one_and_delete(query)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("_id") is not None:
self._last_artifact_id = str(doc.get("_id"))
return f"Todo {parsed_todo_id} deleted."

View File

@@ -36,7 +36,7 @@ class ToolManager:
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
if tool_name in {"mcp_tool", "memory", "todo_list"} and user_id:
if tool_name in {"mcp_tool", "memory", "todo_list", "notes"} and user_id:
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)

View File

@@ -0,0 +1,231 @@
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, Optional
from application.agents.base import BaseAgent
from application.agents.workflows.schemas import (
ExecutionStatus,
Workflow,
WorkflowEdge,
WorkflowGraph,
WorkflowNode,
WorkflowRun,
)
from application.agents.workflows.workflow_engine import WorkflowEngine
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.logging import log_activity, LogContext
logger = logging.getLogger(__name__)
class WorkflowAgent(BaseAgent):
"""A specialized agent that executes predefined workflows."""
def __init__(
self,
*args,
workflow_id: Optional[str] = None,
workflow: Optional[Dict[str, Any]] = None,
workflow_owner: Optional[str] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.workflow_id = workflow_id
self.workflow_owner = workflow_owner
self._workflow_data = workflow
self._engine: Optional[WorkflowEngine] = None
@log_activity()
def gen(
self, query: str, log_context: LogContext = None
) -> Generator[Dict[str, str], None, None]:
yield from self._gen_inner(query, log_context)
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict[str, str], None, None]:
graph = self._load_workflow_graph()
if not graph:
yield {"type": "error", "error": "Failed to load workflow configuration."}
return
self._engine = WorkflowEngine(graph, self)
yield from self._engine.execute({}, query)
self._save_workflow_run(query)
def _load_workflow_graph(self) -> Optional[WorkflowGraph]:
if self._workflow_data:
return self._parse_embedded_workflow()
if self.workflow_id:
return self._load_from_database()
return None
def _parse_embedded_workflow(self) -> Optional[WorkflowGraph]:
try:
nodes_data = self._workflow_data.get("nodes", [])
edges_data = self._workflow_data.get("edges", [])
workflow = Workflow(
name=self._workflow_data.get("name", "Embedded Workflow"),
description=self._workflow_data.get("description"),
)
nodes = []
for n in nodes_data:
node_config = n.get("data", {})
nodes.append(
WorkflowNode(
id=n["id"],
workflow_id=self.workflow_id or "embedded",
type=n["type"],
title=n.get("title", "Node"),
description=n.get("description"),
position=n.get("position", {"x": 0, "y": 0}),
config=node_config,
)
)
edges = []
for e in edges_data:
edges.append(
WorkflowEdge(
id=e["id"],
workflow_id=self.workflow_id or "embedded",
source=e.get("source") or e.get("source_id"),
target=e.get("target") or e.get("target_id"),
sourceHandle=e.get("sourceHandle") or e.get("source_handle"),
targetHandle=e.get("targetHandle") or e.get("target_handle"),
)
)
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
except Exception as e:
logger.error(f"Invalid embedded workflow: {e}")
return None
def _load_from_database(self) -> Optional[WorkflowGraph]:
try:
from bson.objectid import ObjectId
if not self.workflow_id or not ObjectId.is_valid(self.workflow_id):
logger.error(f"Invalid workflow ID: {self.workflow_id}")
return None
owner_id = self.workflow_owner
if not owner_id and isinstance(self.decoded_token, dict):
owner_id = self.decoded_token.get("sub")
if not owner_id:
logger.error(
f"Workflow owner not available for workflow load: {self.workflow_id}"
)
return None
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
workflows_coll = db["workflows"]
workflow_nodes_coll = db["workflow_nodes"]
workflow_edges_coll = db["workflow_edges"]
workflow_doc = workflows_coll.find_one(
{"_id": ObjectId(self.workflow_id), "user": owner_id}
)
if not workflow_doc:
logger.error(
f"Workflow {self.workflow_id} not found or inaccessible for user {owner_id}"
)
return None
workflow = Workflow(**workflow_doc)
graph_version = workflow_doc.get("current_graph_version", 1)
try:
graph_version = int(graph_version)
if graph_version <= 0:
graph_version = 1
except (ValueError, TypeError):
graph_version = 1
nodes_docs = list(
workflow_nodes_coll.find(
{"workflow_id": self.workflow_id, "graph_version": graph_version}
)
)
if not nodes_docs and graph_version == 1:
nodes_docs = list(
workflow_nodes_coll.find(
{
"workflow_id": self.workflow_id,
"graph_version": {"$exists": False},
}
)
)
nodes = [WorkflowNode(**doc) for doc in nodes_docs]
edges_docs = list(
workflow_edges_coll.find(
{"workflow_id": self.workflow_id, "graph_version": graph_version}
)
)
if not edges_docs and graph_version == 1:
edges_docs = list(
workflow_edges_coll.find(
{
"workflow_id": self.workflow_id,
"graph_version": {"$exists": False},
}
)
)
edges = [WorkflowEdge(**doc) for doc in edges_docs]
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
except Exception as e:
logger.error(f"Failed to load workflow from database: {e}")
return None
def _save_workflow_run(self, query: str) -> None:
if not self._engine:
return
try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
workflow_runs_coll = db["workflow_runs"]
run = WorkflowRun(
workflow_id=self.workflow_id or "unknown",
status=self._determine_run_status(),
inputs={"query": query},
outputs=self._serialize_state(self._engine.state),
steps=self._engine.get_execution_summary(),
created_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
)
workflow_runs_coll.insert_one(run.to_mongo_doc())
except Exception as e:
logger.error(f"Failed to save workflow run: {e}")
def _determine_run_status(self) -> ExecutionStatus:
if not self._engine or not self._engine.execution_log:
return ExecutionStatus.COMPLETED
for log in self._engine.execution_log:
if log.get("status") == ExecutionStatus.FAILED.value:
return ExecutionStatus.FAILED
return ExecutionStatus.COMPLETED
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
serialized: Dict[str, Any] = {}
for key, value in state.items():
serialized[key] = self._serialize_state_value(value)
return serialized
def _serialize_state_value(self, value: Any) -> Any:
if isinstance(value, dict):
return {
str(dict_key): self._serialize_state_value(dict_value)
for dict_key, dict_value in value.items()
}
if isinstance(value, list):
return [self._serialize_state_value(item) for item in value]
if isinstance(value, tuple):
return [self._serialize_state_value(item) for item in value]
if isinstance(value, datetime):
return value.isoformat()
if isinstance(value, (str, int, float, bool, type(None))):
return value
return str(value)

View File

@@ -0,0 +1,64 @@
from typing import Any, Dict
import celpy
import celpy.celtypes
class CelEvaluationError(Exception):
pass
def _convert_value(value: Any) -> Any:
if isinstance(value, bool):
return celpy.celtypes.BoolType(value)
if isinstance(value, int):
return celpy.celtypes.IntType(value)
if isinstance(value, float):
return celpy.celtypes.DoubleType(value)
if isinstance(value, str):
return celpy.celtypes.StringType(value)
if isinstance(value, list):
return celpy.celtypes.ListType([_convert_value(item) for item in value])
if isinstance(value, dict):
return celpy.celtypes.MapType(
{celpy.celtypes.StringType(k): _convert_value(v) for k, v in value.items()}
)
if value is None:
return celpy.celtypes.BoolType(False)
return celpy.celtypes.StringType(str(value))
def build_activation(state: Dict[str, Any]) -> Dict[str, Any]:
return {k: _convert_value(v) for k, v in state.items()}
def evaluate_cel(expression: str, state: Dict[str, Any]) -> Any:
if not expression or not expression.strip():
raise CelEvaluationError("Empty expression")
try:
env = celpy.Environment()
ast = env.compile(expression)
program = env.program(ast)
activation = build_activation(state)
result = program.evaluate(activation)
except celpy.CELEvalError as exc:
raise CelEvaluationError(f"CEL evaluation error: {exc}") from exc
except Exception as exc:
raise CelEvaluationError(f"CEL error: {exc}") from exc
return cel_to_python(result)
def cel_to_python(value: Any) -> Any:
if isinstance(value, celpy.celtypes.BoolType):
return bool(value)
if isinstance(value, celpy.celtypes.IntType):
return int(value)
if isinstance(value, celpy.celtypes.DoubleType):
return float(value)
if isinstance(value, celpy.celtypes.StringType):
return str(value)
if isinstance(value, celpy.celtypes.ListType):
return [cel_to_python(item) for item in value]
if isinstance(value, celpy.celtypes.MapType):
return {str(k): cel_to_python(v) for k, v in value.items()}
return value

View File

@@ -0,0 +1,104 @@
"""Workflow Node Agents - defines specialized agents for workflow nodes."""
from typing import Any, Dict, List, Optional, Type
from application.agents.agentic_agent import AgenticAgent
from application.agents.base import BaseAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.research_agent import ResearchAgent
from application.agents.workflows.schemas import AgentType
class ToolFilterMixin:
"""Mixin that filters fetched tools to only those specified in tool_ids."""
_allowed_tool_ids: List[str]
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict[str, Any]]:
all_tools = super()._get_user_tools(user)
if not self._allowed_tool_ids:
return {}
filtered_tools = {
tool_id: tool
for tool_id, tool in all_tools.items()
if str(tool.get("_id", "")) in self._allowed_tool_ids
}
return filtered_tools
def _get_tools(self, api_key: str = None) -> Dict[str, Dict[str, Any]]:
all_tools = super()._get_tools(api_key)
if not self._allowed_tool_ids:
return {}
filtered_tools = {
tool_id: tool
for tool_id, tool in all_tools.items()
if str(tool.get("_id", "")) in self._allowed_tool_ids
}
return filtered_tools
class _WorkflowNodeMixin:
"""Common __init__ for all workflow node agents."""
def __init__(
self,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
**kwargs,
)
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeClassicAgent(ToolFilterMixin, _WorkflowNodeMixin, ClassicAgent):
pass
class WorkflowNodeAgenticAgent(ToolFilterMixin, _WorkflowNodeMixin, AgenticAgent):
pass
class WorkflowNodeResearchAgent(ToolFilterMixin, _WorkflowNodeMixin, ResearchAgent):
pass
class WorkflowNodeAgentFactory:
_agents: Dict[AgentType, Type[BaseAgent]] = {
AgentType.CLASSIC: WorkflowNodeClassicAgent,
AgentType.REACT: WorkflowNodeClassicAgent, # backwards compat
AgentType.AGENTIC: WorkflowNodeAgenticAgent,
AgentType.RESEARCH: WorkflowNodeResearchAgent,
}
@classmethod
def create(
cls,
agent_type: AgentType,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
) -> BaseAgent:
agent_class = cls._agents.get(agent_type)
if not agent_class:
raise ValueError(f"Unsupported agent type: {agent_type}")
return agent_class(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
tool_ids=tool_ids,
**kwargs,
)

View File

@@ -0,0 +1,237 @@
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field, field_validator
class NodeType(str, Enum):
START = "start"
END = "end"
AGENT = "agent"
NOTE = "note"
STATE = "state"
CONDITION = "condition"
class AgentType(str, Enum):
CLASSIC = "classic"
REACT = "react"
AGENTIC = "agentic"
RESEARCH = "research"
class ExecutionStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class Position(BaseModel):
model_config = ConfigDict(extra="forbid")
x: float = 0.0
y: float = 0.0
class AgentNodeConfig(BaseModel):
model_config = ConfigDict(extra="allow")
agent_type: AgentType = AgentType.CLASSIC
llm_name: Optional[str] = None
system_prompt: str = "You are a helpful assistant."
prompt_template: str = ""
output_variable: Optional[str] = None
stream_to_user: bool = True
tools: List[str] = Field(default_factory=list)
sources: List[str] = Field(default_factory=list)
chunks: str = "2"
retriever: str = ""
model_id: Optional[str] = None
json_schema: Optional[Dict[str, Any]] = None
class ConditionCase(BaseModel):
model_config = ConfigDict(extra="forbid", populate_by_name=True)
name: Optional[str] = None
expression: str = ""
source_handle: str = Field(..., alias="sourceHandle")
class ConditionNodeConfig(BaseModel):
model_config = ConfigDict(extra="allow")
mode: Literal["simple", "advanced"] = "simple"
cases: List[ConditionCase] = Field(default_factory=list)
class StateOperation(BaseModel):
model_config = ConfigDict(extra="forbid")
expression: str = ""
target_variable: str = ""
class WorkflowEdgeCreate(BaseModel):
model_config = ConfigDict(populate_by_name=True)
id: str
workflow_id: str
source_id: str = Field(..., alias="source")
target_id: str = Field(..., alias="target")
source_handle: Optional[str] = Field(None, alias="sourceHandle")
target_handle: Optional[str] = Field(None, alias="targetHandle")
class WorkflowEdge(WorkflowEdgeCreate):
mongo_id: Optional[str] = Field(None, alias="_id")
@field_validator("mongo_id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"id": self.id,
"workflow_id": self.workflow_id,
"source_id": self.source_id,
"target_id": self.target_id,
"source_handle": self.source_handle,
"target_handle": self.target_handle,
}
class WorkflowNodeCreate(BaseModel):
model_config = ConfigDict(extra="allow")
id: str
workflow_id: str
type: NodeType
title: str = "Node"
description: Optional[str] = None
position: Position = Field(default_factory=Position)
config: Dict[str, Any] = Field(default_factory=dict)
@field_validator("position", mode="before")
@classmethod
def parse_position(cls, v: Union[Dict[str, float], Position]) -> Position:
if isinstance(v, dict):
return Position(**v)
return v
class WorkflowNode(WorkflowNodeCreate):
mongo_id: Optional[str] = Field(None, alias="_id")
@field_validator("mongo_id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"id": self.id,
"workflow_id": self.workflow_id,
"type": self.type.value,
"title": self.title,
"description": self.description,
"position": self.position.model_dump(),
"config": self.config,
}
class WorkflowCreate(BaseModel):
model_config = ConfigDict(extra="allow")
name: str = "New Workflow"
description: Optional[str] = None
user: Optional[str] = None
class Workflow(WorkflowCreate):
id: Optional[str] = Field(None, alias="_id")
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@field_validator("id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"name": self.name,
"description": self.description,
"user": self.user,
"created_at": self.created_at,
"updated_at": self.updated_at,
}
class WorkflowGraph(BaseModel):
workflow: Workflow
nodes: List[WorkflowNode] = Field(default_factory=list)
edges: List[WorkflowEdge] = Field(default_factory=list)
def get_node_by_id(self, node_id: str) -> Optional[WorkflowNode]:
for node in self.nodes:
if node.id == node_id:
return node
return None
def get_start_node(self) -> Optional[WorkflowNode]:
for node in self.nodes:
if node.type == NodeType.START:
return node
return None
def get_outgoing_edges(self, node_id: str) -> List[WorkflowEdge]:
return [edge for edge in self.edges if edge.source_id == node_id]
class NodeExecutionLog(BaseModel):
model_config = ConfigDict(extra="forbid")
node_id: str
node_type: str
status: ExecutionStatus
started_at: datetime
completed_at: Optional[datetime] = None
error: Optional[str] = None
state_snapshot: Dict[str, Any] = Field(default_factory=dict)
class WorkflowRunCreate(BaseModel):
workflow_id: str
inputs: Dict[str, str] = Field(default_factory=dict)
class WorkflowRun(BaseModel):
model_config = ConfigDict(extra="allow")
id: Optional[str] = Field(None, alias="_id")
workflow_id: str
status: ExecutionStatus = ExecutionStatus.PENDING
inputs: Dict[str, str] = Field(default_factory=dict)
outputs: Dict[str, Any] = Field(default_factory=dict)
steps: List[NodeExecutionLog] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
completed_at: Optional[datetime] = None
@field_validator("id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"workflow_id": self.workflow_id,
"status": self.status.value,
"inputs": self.inputs,
"outputs": self.outputs,
"steps": [step.model_dump() for step in self.steps],
"created_at": self.created_at,
"completed_at": self.completed_at,
}

View File

@@ -0,0 +1,470 @@
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
AgentNodeConfig,
AgentType,
ConditionNodeConfig,
ExecutionStatus,
NodeExecutionLog,
NodeType,
WorkflowGraph,
WorkflowNode,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.error import sanitize_api_error
from application.templates.namespaces import NamespaceManager
from application.templates.template_engine import TemplateEngine, TemplateRenderError
try:
import jsonschema
except ImportError: # pragma: no cover - optional dependency in some deployments.
jsonschema = None
if TYPE_CHECKING:
from application.agents.base import BaseAgent
logger = logging.getLogger(__name__)
StateValue = Any
WorkflowState = Dict[str, StateValue]
TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"}
class WorkflowEngine:
MAX_EXECUTION_STEPS = 50
def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"):
self.graph = graph
self.agent = agent
self.state: WorkflowState = {}
self.execution_log: List[Dict[str, Any]] = []
self._condition_result: Optional[str] = None
self._template_engine = TemplateEngine()
self._namespace_manager = NamespaceManager()
def execute(
self, initial_inputs: WorkflowState, query: str
) -> Generator[Dict[str, str], None, None]:
self._initialize_state(initial_inputs, query)
start_node = self.graph.get_start_node()
if not start_node:
yield {"type": "error", "error": "No start node found in workflow."}
return
current_node_id: Optional[str] = start_node.id
steps = 0
while current_node_id and steps < self.MAX_EXECUTION_STEPS:
node = self.graph.get_node_by_id(current_node_id)
if not node:
yield {"type": "error", "error": f"Node {current_node_id} not found."}
break
log_entry = self._create_log_entry(node)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "running",
}
try:
yield from self._execute_node(node)
log_entry["status"] = ExecutionStatus.COMPLETED.value
log_entry["completed_at"] = datetime.now(timezone.utc)
output_key = f"node_{node.id}_output"
node_output = self.state.get(output_key)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "completed",
"state_snapshot": dict(self.state),
"output": node_output,
}
except Exception as e:
logger.error(f"Error executing node {node.id}: {e}", exc_info=True)
log_entry["status"] = ExecutionStatus.FAILED.value
log_entry["error"] = str(e)
log_entry["completed_at"] = datetime.now(timezone.utc)
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
user_friendly_error = sanitize_api_error(e)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "failed",
"state_snapshot": dict(self.state),
"error": user_friendly_error,
}
yield {"type": "error", "error": user_friendly_error}
break
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
if node.type == NodeType.END:
break
current_node_id = self._get_next_node_id(current_node_id)
if current_node_id is None and node.type != NodeType.END:
logger.warning(
f"Branch ended at node '{node.title}' ({node.id}) without reaching an end node"
)
steps += 1
if steps >= self.MAX_EXECUTION_STEPS:
logger.warning(
f"Workflow reached max steps limit ({self.MAX_EXECUTION_STEPS})"
)
def _initialize_state(self, initial_inputs: WorkflowState, query: str) -> None:
self.state.update(initial_inputs)
self.state["query"] = query
self.state["chat_history"] = str(self.agent.chat_history)
def _create_log_entry(self, node: WorkflowNode) -> Dict[str, Any]:
return {
"node_id": node.id,
"node_type": node.type.value,
"started_at": datetime.now(timezone.utc),
"completed_at": None,
"status": ExecutionStatus.RUNNING.value,
"error": None,
"state_snapshot": {},
}
def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
node = self.graph.get_node_by_id(current_node_id)
edges = self.graph.get_outgoing_edges(current_node_id)
if not edges:
return None
if node and node.type == NodeType.CONDITION and self._condition_result:
target_handle = self._condition_result
self._condition_result = None
for edge in edges:
if edge.source_handle == target_handle:
return edge.target_id
return None
return edges[0].target_id
def _execute_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
logger.info(f"Executing node {node.id} ({node.type.value})")
node_handlers = {
NodeType.START: self._execute_start_node,
NodeType.NOTE: self._execute_note_node,
NodeType.AGENT: self._execute_agent_node,
NodeType.STATE: self._execute_state_node,
NodeType.CONDITION: self._execute_condition_node,
NodeType.END: self._execute_end_node,
}
handler = node_handlers.get(node.type)
if handler:
yield from handler(node)
def _execute_start_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
yield from ()
def _execute_note_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
yield from ()
def _execute_agent_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
from application.core.model_utils import (
get_api_key_for_provider,
get_model_capabilities,
get_provider_from_model_id,
)
node_config = AgentNodeConfig(**node.config.get("config", node.config))
if node_config.prompt_template:
formatted_prompt = self._format_template(node_config.prompt_template)
else:
formatted_prompt = self.state.get("query", "")
node_json_schema = self._normalize_node_json_schema(
node_config.json_schema, node.title
)
node_model_id = node_config.model_id or self.agent.model_id
node_llm_name = (
node_config.llm_name
or get_provider_from_model_id(node_model_id or "")
or self.agent.llm_name
)
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
if node_json_schema and node_model_id:
model_capabilities = get_model_capabilities(node_model_id)
if model_capabilities and not model_capabilities.get(
"supports_structured_output", False
):
raise ValueError(
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
)
factory_kwargs = {
"agent_type": node_config.agent_type,
"endpoint": self.agent.endpoint,
"llm_name": node_llm_name,
"model_id": node_model_id,
"api_key": node_api_key,
"tool_ids": node_config.tools,
"prompt": node_config.system_prompt,
"chat_history": self.agent.chat_history,
"decoded_token": self.agent.decoded_token,
"json_schema": node_json_schema,
}
# Agentic/research agents need retriever_config for on-demand search
if node_config.agent_type in (AgentType.AGENTIC, AgentType.RESEARCH):
factory_kwargs["retriever_config"] = {
"source": {"active_docs": node_config.sources} if node_config.sources else {},
"retriever_name": node_config.retriever or "classic",
"chunks": int(node_config.chunks) if node_config.chunks else 2,
"model_id": node_model_id,
"llm_name": node_llm_name,
"api_key": node_api_key,
"decoded_token": self.agent.decoded_token,
}
node_agent = WorkflowNodeAgentFactory.create(**factory_kwargs)
full_response_parts: List[str] = []
structured_response_parts: List[str] = []
has_structured_response = False
first_chunk = True
for event in node_agent.gen(formatted_prompt):
if "answer" in event:
chunk = str(event["answer"])
full_response_parts.append(chunk)
if event.get("structured"):
has_structured_response = True
structured_response_parts.append(chunk)
if node_config.stream_to_user:
if first_chunk and hasattr(self, "_has_streamed"):
yield {"answer": "\n\n"}
first_chunk = False
yield event
if node_config.stream_to_user:
self._has_streamed = True
full_response = "".join(full_response_parts).strip()
output_value: Any = full_response
if has_structured_response:
structured_response = "".join(structured_response_parts).strip()
response_to_parse = structured_response or full_response
parsed_success, parsed_structured = self._parse_structured_output(
response_to_parse
)
output_value = parsed_structured if parsed_success else response_to_parse
if node_json_schema:
self._validate_structured_output(node_json_schema, output_value)
elif node_json_schema:
parsed_success, parsed_structured = self._parse_structured_output(
full_response
)
if not parsed_success:
raise ValueError(
"Structured output was expected but response was not valid JSON"
)
output_value = parsed_structured
self._validate_structured_output(node_json_schema, output_value)
default_output_key = f"node_{node.id}_output"
self.state[default_output_key] = output_value
if node_config.output_variable:
self.state[node_config.output_variable] = output_value
def _execute_state_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config.get("config", node.config)
for op in config.get("operations", []):
expression = op.get("expression", "")
target_variable = op.get("target_variable", "")
if expression and target_variable:
self.state[target_variable] = evaluate_cel(expression, self.state)
yield from ()
def _execute_condition_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = ConditionNodeConfig(**node.config.get("config", node.config))
matched_handle = None
for case in config.cases:
if not case.expression.strip():
continue
try:
if evaluate_cel(case.expression, self.state):
matched_handle = case.source_handle
break
except CelEvaluationError:
continue
self._condition_result = matched_handle or "else"
yield from ()
def _execute_end_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config.get("config", node.config)
output_template = str(config.get("output_template", ""))
if output_template:
formatted_output = self._format_template(output_template)
yield {"answer": formatted_output}
def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
normalized_response = raw_response.strip()
if not normalized_response:
return False, None
try:
return True, json.loads(normalized_response)
except json.JSONDecodeError:
logger.warning(
"Workflow agent returned structured output that was not valid JSON"
)
return False, None
def _normalize_node_json_schema(
self, schema: Optional[Dict[str, Any]], node_title: str
) -> Optional[Dict[str, Any]]:
if schema is None:
return None
try:
return normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(
f'Invalid JSON schema for node "{node_title}": {exc}'
) from exc
def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
if jsonschema is None:
logger.warning(
"jsonschema package is not available, skipping structured output validation"
)
return
try:
normalized_schema = normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(f"Invalid JSON schema: {exc}") from exc
try:
jsonschema.validate(instance=output_value, schema=normalized_schema)
except jsonschema.exceptions.ValidationError as exc:
raise ValueError(f"Structured output did not match schema: {exc.message}") from exc
except jsonschema.exceptions.SchemaError as exc:
raise ValueError(f"Invalid JSON schema: {exc.message}") from exc
def _format_template(self, template: str) -> str:
context = self._build_template_context()
try:
return self._template_engine.render(template, context)
except TemplateRenderError as e:
logger.warning(
"Workflow template rendering failed, using raw template: %s", str(e)
)
return template
def _build_template_context(self) -> Dict[str, Any]:
docs, docs_together = self._get_source_template_data()
passthrough_data = (
self.state.get("passthrough")
if isinstance(self.state.get("passthrough"), dict)
else None
)
tools_data = (
self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None
)
context = self._namespace_manager.build_context(
user_id=getattr(self.agent, "user", None),
request_id=getattr(self.agent, "request_id", None),
passthrough_data=passthrough_data,
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
agent_context: Dict[str, Any] = {}
for key, value in self.state.items():
if not isinstance(key, str):
continue
normalized_key = key.strip()
if not normalized_key:
continue
agent_context[normalized_key] = value
context["agent"] = agent_context
# Keep legacy top-level variables working while namespaced variables are adopted.
for key, value in agent_context.items():
if key in TEMPLATE_RESERVED_NAMESPACES:
context[f"agent_{key}"] = value
continue
if key not in context:
context[key] = value
return context
def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
docs = getattr(self.agent, "retrieved_docs", None)
if not isinstance(docs, list) or len(docs) == 0:
return None, None
docs_together_parts: List[str] = []
for doc in docs:
if not isinstance(doc, dict):
continue
text = doc.get("text")
if not isinstance(text, str):
continue
filename = doc.get("filename") or doc.get("title") or doc.get("source")
if isinstance(filename, str) and filename.strip():
docs_together_parts.append(f"{filename}\n{text}")
else:
docs_together_parts.append(text)
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
return docs, docs_together
def get_execution_summary(self) -> List[NodeExecutionLog]:
return [
NodeExecutionLog(
node_id=log["node_id"],
node_type=log["node_type"],
status=ExecutionStatus(log["status"]),
started_at=log["started_at"],
completed_at=log.get("completed_at"),
error=log.get("error"),
state_snapshot=log.get("state_snapshot", {}),
)
for log in self.execution_log
]

View File

@@ -3,6 +3,7 @@ from flask import Blueprint
from application.api import api
from application.api.answer.routes.answer import AnswerResource
from application.api.answer.routes.base import answer_ns
from application.api.answer.routes.search import SearchResource
from application.api.answer.routes.stream import StreamResource
@@ -14,6 +15,7 @@ api.add_namespace(answer_ns)
def init_answer_routes():
api.add_resource(StreamResource, "/stream")
api.add_resource(AnswerResource, "/api/answer")
api.add_resource(SearchResource, "/api/search")
init_answer_routes()

View File

@@ -40,9 +40,9 @@ class AnswerResource(Resource, BaseAnswerResource):
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"agent_id": fields.String(required=False, description="Agent ID"),
"active_docs": fields.String(
required=False, description="Active documents"
),
@@ -74,21 +74,10 @@ class AnswerResource(Resource, BaseAnswerResource):
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
processor.initialize()
agent = processor.build_agent(data.get("question", ""))
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
docs_together, docs_list = processor.pre_fetch_docs(
data.get("question", "")
)
tools_data = processor.pre_fetch_tools()
agent = processor.create_agent(
docs_together=docs_together,
docs=docs_list,
tools_data=tools_data,
)
if error := self.check_usage(processor.agent_config):
return error
@@ -101,6 +90,9 @@ class AnswerResource(Resource, BaseAnswerResource):
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
stream_result = self.process_response_stream(stream)
@@ -138,5 +130,5 @@ class AnswerResource(Resource, BaseAnswerResource):
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return make_response({"error": str(e)}, 500)
return make_response({"error": "An error occurred processing your request"}, 500)
return make_response(result, 200)

View File

@@ -15,6 +15,7 @@ from application.core.model_utils import (
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.error import sanitize_api_error
from application.llm.llm_creator import LLMCreator
from application.utils import check_required_fields
@@ -46,6 +47,27 @@ class BaseAnswerResource:
return missing_fields
return None
@staticmethod
def _prepare_tool_calls_for_logging(
tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000
) -> List[Dict[str, Any]]:
if not tool_calls:
return []
prepared = []
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
prepared.append({"result": str(tool_call)[:max_chars]})
continue
item = dict(tool_call)
for key in ("result", "result_full"):
value = item.get(key)
if isinstance(value, str) and len(value) > max_chars:
item[key] = value[:max_chars]
prepared.append(item)
return prepared
def check_usage(self, agent_config: Dict) -> Optional[Response]:
"""Check if there is a usage limit and if it is exceeded
@@ -184,9 +206,12 @@ class BaseAnswerResource:
is_structured = False
schema_info = None
structured_chunks = []
query_metadata = {}
for line in agent.gen(query=question):
if "answer" in line:
if "metadata" in line:
query_metadata.update(line["metadata"])
elif "answer" in line:
response_full += str(line["answer"])
if line.get("structured"):
is_structured = True
@@ -219,7 +244,14 @@ class BaseAnswerResource:
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
data = json.dumps(line)
if line.get("type") == "error":
sanitized_error = {
"type": "error",
"error": sanitize_api_error(line.get("error", "An error occurred"))
}
data = json.dumps(sanitized_error)
else:
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
structured_data = {
@@ -246,6 +278,7 @@ class BaseAnswerResource:
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
if should_save_conversation:
@@ -265,21 +298,48 @@ class BaseAnswerResource:
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
try:
self.conversation_service.update_compression_metadata(
conversation_id, compression_meta
)
self.conversation_service.append_compression_message(
conversation_id, compression_meta
)
agent.compression_saved = True
logger.info(
f"Persisted compression metadata for conversation {conversation_id}"
)
except Exception as e:
logger.error(
f"Failed to persist compression metadata: {str(e)}",
exc_info=True,
)
else:
conversation_id = None
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
getattr(agent, "tool_calls", tool_calls) or tool_calls
)
log_data = {
"action": "stream_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"agent_id": agent_id,
"question": question,
"response": response_full,
"sources": source_log_docs,
"tool_calls": tool_calls_for_logging,
"attachments": attachment_ids,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
@@ -310,6 +370,7 @@ class BaseAnswerResource:
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
agent_id=agent_id,
)
self.conversation_service.save_conversation(
conversation_id,
@@ -327,7 +388,27 @@ class BaseAnswerResource:
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
try:
self.conversation_service.update_compression_metadata(
conversation_id, compression_meta
)
self.conversation_service.append_compression_message(
conversation_id, compression_meta
)
agent.compression_saved = True
logger.info(
f"Persisted compression metadata for conversation {conversation_id} (partial stream)"
)
except Exception as e:
logger.error(
f"Failed to persist compression metadata (partial stream): {str(e)}",
exc_info=True,
)
except Exception as e:
logger.error(
f"Error saving partial response: {str(e)}", exc_info=True

View File

@@ -0,0 +1,186 @@
import logging
from typing import Any, Dict, List
from flask import make_response, request
from flask_restx import fields, Resource
from bson.dbref import DBRef
from application.api.answer.routes.base import answer_ns
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
logger = logging.getLogger(__name__)
@answer_ns.route("/api/search")
class SearchResource(Resource):
"""Fast search endpoint for retrieving relevant documents"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
mongo = MongoDB.get_client()
self.db = mongo[settings.MONGO_DB_NAME]
self.agents_collection = self.db["agents"]
search_model = answer_ns.model(
"SearchModel",
{
"question": fields.String(
required=True, description="Search query"
),
"api_key": fields.String(
required=True, description="API key for authentication"
),
"chunks": fields.Integer(
required=False, default=5, description="Number of results to return"
),
},
)
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
"""Get source IDs connected to the API key/agent.
"""
agent_data = self.agents_collection.find_one({"key": api_key})
if not agent_data:
return []
source_ids = []
# Handle multiple sources (only if non-empty)
sources = agent_data.get("sources", [])
if sources and isinstance(sources, list) and len(sources) > 0:
for source_ref in sources:
# Skip "default" - it's a placeholder, not an actual vectorstore
if source_ref == "default":
continue
elif isinstance(source_ref, DBRef):
source_doc = self.db.dereference(source_ref)
if source_doc:
source_ids.append(str(source_doc["_id"]))
# Handle single source (legacy) - check if sources was empty or didn't yield results
if not source_ids:
source = agent_data.get("source")
if isinstance(source, DBRef):
source_doc = self.db.dereference(source)
if source_doc:
source_ids.append(str(source_doc["_id"]))
# Skip "default" - it's a placeholder, not an actual vectorstore
elif source and source != "default":
source_ids.append(source)
return source_ids
def _search_vectorstores(
self, query: str, source_ids: List[str], chunks: int
) -> List[Dict[str, Any]]:
"""Search across vectorstores and return results"""
if not source_ids:
return []
results = []
chunks_per_source = max(1, chunks // len(source_ids))
seen_texts = set()
for source_id in source_ids:
if not source_id or not source_id.strip():
continue
try:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
)
docs = docsearch.search(query, k=chunks_per_source * 2)
for doc in docs:
if len(results) >= chunks:
break
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
page_content = doc.page_content
metadata = doc.metadata
else:
page_content = doc.get("text", doc.get("page_content", ""))
metadata = doc.get("metadata", {})
# Skip duplicates
text_hash = hash(page_content[:200])
if text_hash in seen_texts:
continue
seen_texts.add(text_hash)
title = metadata.get(
"title", metadata.get("post_title", "")
)
if not isinstance(title, str):
title = str(title) if title else ""
# Clean up title
if title:
title = title.split("/")[-1]
else:
# Use filename or first part of content as title
title = metadata.get("filename", page_content[:50] + "...")
source = metadata.get("source", source_id)
results.append({
"text": page_content,
"title": title,
"source": source,
})
if len(results) >= chunks:
break
except Exception as e:
logger.error(
f"Error searching vectorstore {source_id}: {e}",
exc_info=True,
)
continue
return results[:chunks]
@answer_ns.expect(search_model)
@answer_ns.doc(description="Search for relevant documents based on query")
def post(self):
data = request.get_json()
question = data.get("question")
api_key = data.get("api_key")
chunks = data.get("chunks", 5)
if not question:
return make_response({"error": "question is required"}, 400)
if not api_key:
return make_response({"error": "api_key is required"}, 400)
# Validate API key
agent = self.agents_collection.find_one({"key": api_key})
if not agent:
return make_response({"error": "Invalid API key"}, 401)
try:
# Get sources connected to this API key
source_ids = self._get_sources_from_api_key(api_key)
if not source_ids:
return make_response([], 200)
# Perform search
results = self._search_vectorstores(question, source_ids, chunks)
return make_response(results, 200)
except Exception as e:
logger.error(
f"/api/search - error: {str(e)}",
extra={"error": str(e)},
exc_info=True,
)
return make_response({"error": "Search failed"}, 500)

View File

@@ -40,9 +40,9 @@ class StreamResource(Resource, BaseAnswerResource):
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"agent_id": fields.String(required=False, description="Agent ID"),
"active_docs": fields.String(
required=False, description="Active documents"
),
@@ -80,14 +80,13 @@ class StreamResource(Resource, BaseAnswerResource):
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
processor.initialize()
docs_together, docs_list = processor.pre_fetch_docs(data["question"])
tools_data = processor.pre_fetch_tools()
agent = processor.create_agent(
docs_together=docs_together, docs=docs_list, tools_data=tools_data
)
agent = processor.build_agent(data["question"])
if not processor.decoded_token:
return Response(
self.error_stream_generate("Unauthorized"),
status=401,
mimetype="text/event-stream",
)
if error := self.check_usage(processor.agent_config):
return error
@@ -102,7 +101,7 @@ class StreamResource(Resource, BaseAnswerResource):
index=data.get("index"),
should_save_conversation=data.get("save_conversation", True),
attachment_ids=data.get("attachments", []),
agent_id=data.get("agent_id"),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,

View File

@@ -0,0 +1,20 @@
"""
Compression module for managing conversation context compression.
"""
from application.api.answer.services.compression.orchestrator import (
CompressionOrchestrator,
)
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.types import (
CompressionResult,
CompressionMetadata,
)
__all__ = [
"CompressionOrchestrator",
"CompressionService",
"CompressionResult",
"CompressionMetadata",
]

View File

@@ -0,0 +1,234 @@
"""Message reconstruction utilities for compression."""
import logging
import uuid
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
class MessageBuilder:
"""Builds message arrays from compressed context."""
@staticmethod
def build_from_compressed_context(
system_prompt: str,
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_tool_calls: bool = False,
context_type: str = "pre_request",
) -> List[Dict]:
"""
Build messages from compressed context.
Args:
system_prompt: Original system prompt
compressed_summary: Compressed summary (if any)
recent_queries: Recent uncompressed queries
include_tool_calls: Whether to include tool calls from history
context_type: Type of context ('pre_request' or 'mid_execution')
Returns:
List of message dicts ready for LLM
"""
# Append compression summary to system prompt if present
if compressed_summary:
system_prompt = MessageBuilder._append_compression_context(
system_prompt, compressed_summary, context_type
)
messages = [{"role": "system", "content": system_prompt}]
# Add recent history
for query in recent_queries:
if "prompt" in query and "response" in query:
messages.append({"role": "user", "content": query["prompt"]})
messages.append({"role": "assistant", "content": query["response"]})
# Add tool calls from history if present
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
)
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
messages.append({
"role": "user",
"content": "Please continue with the remaining tasks based on the context above."
})
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
return messages
@staticmethod
def _append_compression_context(
system_prompt: str, compressed_summary: str, context_type: str = "pre_request"
) -> str:
"""
Append compression context to system prompt.
Args:
system_prompt: Original system prompt
compressed_summary: Summary to append
context_type: Type of compression context
Returns:
Updated system prompt
"""
# Remove existing compression context if present
if "This session is being continued" in system_prompt or "Context window limit reached" in system_prompt:
parts = system_prompt.split("\n\n---\n\n")
system_prompt = parts[0]
# Build appropriate context message based on type
if context_type == "mid_execution":
context_message = (
"\n\n---\n\n"
"Context window limit reached during execution. "
"Previous conversation has been compressed to fit within limits. "
"The conversation is summarized below:\n\n"
f"{compressed_summary}"
)
else: # pre_request
context_message = (
"\n\n---\n\n"
"This session is being continued from a previous conversation that "
"has been compressed to fit within context limits. "
"The conversation is summarized below:\n\n"
f"{compressed_summary}"
)
return system_prompt + context_message
@staticmethod
def rebuild_messages_after_compression(
messages: List[Dict],
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_current_execution: bool = False,
include_tool_calls: bool = False,
) -> Optional[List[Dict]]:
"""
Rebuild the message list after compression so tool execution can continue.
Args:
messages: Original message list
compressed_summary: Compressed summary
recent_queries: Recent uncompressed queries
include_current_execution: Whether to preserve current execution messages
include_tool_calls: Whether to include tool calls from history
Returns:
Rebuilt message list or None if failed
"""
# Find the system message
system_message = next(
(msg for msg in messages if msg.get("role") == "system"), None
)
if not system_message:
logger.warning("No system message found in messages list")
return None
# Update system message with compressed summary
if compressed_summary:
content = system_message.get("content", "")
system_message["content"] = MessageBuilder._append_compression_context(
content, compressed_summary, "mid_execution"
)
logger.info(
"Appended compression summary to system prompt (truncated): %s",
(
compressed_summary[:500] + "..."
if len(compressed_summary) > 500
else compressed_summary
),
)
rebuilt_messages = [system_message]
# Add recent history from compressed context
for query in recent_queries:
if "prompt" in query and "response" in query:
rebuilt_messages.append({"role": "user", "content": query["prompt"]})
rebuilt_messages.append(
{"role": "assistant", "content": query["response"]}
)
# Add tool calls from history if present
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
rebuilt_messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
rebuilt_messages.append(
{"role": "tool", "content": [function_response_dict]}
)
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
rebuilt_messages.append({
"role": "user",
"content": "Please continue with the remaining tasks based on the context above."
})
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
if include_current_execution:
# Preserve any messages that were added during the current execution cycle
recent_msg_count = 1 # system message
for query in recent_queries:
if "prompt" in query and "response" in query:
recent_msg_count += 2
if "tool_calls" in query:
recent_msg_count += len(query["tool_calls"]) * 2
if len(messages) > recent_msg_count:
current_execution_messages = messages[recent_msg_count:]
rebuilt_messages.extend(current_execution_messages)
logger.info(
f"Preserved {len(current_execution_messages)} messages from current execution cycle"
)
logger.info(
f"Messages rebuilt: {len(messages)}{len(rebuilt_messages)} messages. "
f"Ready to continue tool execution."
)
return rebuilt_messages

View File

@@ -0,0 +1,233 @@
"""High-level compression orchestration."""
import logging
from typing import Any, Dict, Optional
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.threshold_checker import (
CompressionThresholdChecker,
)
from application.api.answer.services.compression.types import CompressionResult
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
get_provider_from_model_id,
)
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
logger = logging.getLogger(__name__)
class CompressionOrchestrator:
"""
Facade for compression operations.
Coordinates between all compression components and provides
a simple interface for callers.
"""
def __init__(
self,
conversation_service: ConversationService,
threshold_checker: Optional[CompressionThresholdChecker] = None,
):
"""
Initialize orchestrator.
Args:
conversation_service: Service for DB operations
threshold_checker: Custom threshold checker (optional)
"""
self.conversation_service = conversation_service
self.threshold_checker = threshold_checker or CompressionThresholdChecker()
def compress_if_needed(
self,
conversation_id: str,
user_id: str,
model_id: str,
decoded_token: Dict[str, Any],
current_query_tokens: int = 500,
) -> CompressionResult:
"""
Check if compression is needed and perform it if so.
This is the main entry point for compression operations.
Args:
conversation_id: Conversation ID
user_id: User ID
model_id: Model being used for conversation
decoded_token: User's decoded JWT token
current_query_tokens: Estimated tokens for current query
Returns:
CompressionResult with summary and recent queries
"""
try:
# Load conversation
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
if not conversation:
logger.warning(
f"Conversation {conversation_id} not found for user {user_id}"
)
return CompressionResult.failure("Conversation not found")
# Check if compression is needed
if not self.threshold_checker.should_compress(
conversation, model_id, current_query_tokens
):
# No compression needed, return full history
queries = conversation.get("queries", [])
return CompressionResult.success_no_compression(queries)
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
)
except Exception as e:
logger.error(
f"Error in compress_if_needed: {str(e)}", exc_info=True
)
return CompressionResult.failure(str(e))
def _perform_compression(
self,
conversation_id: str,
conversation: Dict[str, Any],
model_id: str,
decoded_token: Dict[str, Any],
) -> CompressionResult:
"""
Perform the actual compression operation.
Args:
conversation_id: Conversation ID
conversation: Conversation document
model_id: Model ID for conversation
decoded_token: User token
Returns:
CompressionResult
"""
try:
# Determine which model to use for compression
compression_model = (
settings.COMPRESSION_MODEL_OVERRIDE
if settings.COMPRESSION_MODEL_OVERRIDE
else model_id
)
# Get provider and API key for compression model
provider = get_provider_from_model_id(compression_model)
api_key = get_api_key_for_provider(provider)
# Create compression LLM
compression_llm = LLMCreator.create_llm(
provider,
api_key=api_key,
user_api_key=None,
decoded_token=decoded_token,
model_id=compression_model,
agent_id=conversation.get("agent_id"),
)
# Create compression service with DB update capability
compression_service = CompressionService(
llm=compression_llm,
model_id=compression_model,
conversation_service=self.conversation_service,
)
# Compress all queries up to the latest
queries_count = len(conversation.get("queries", []))
compress_up_to = queries_count - 1
if compress_up_to < 0:
logger.warning("No queries to compress")
return CompressionResult.success_no_compression([])
logger.info(
f"Initiating compression for conversation {conversation_id}: "
f"compressing all {queries_count} queries (0-{compress_up_to})"
)
# Perform compression and save to DB
metadata = compression_service.compress_and_save(
conversation_id, conversation, compress_up_to
)
logger.info(
f"Compression successful - ratio: {metadata.compression_ratio:.1f}x, "
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
# Reload conversation with updated metadata
conversation = self.conversation_service.get_conversation(
conversation_id, user_id=decoded_token.get("sub")
)
# Get compressed context
compressed_summary, recent_queries = (
compression_service.get_compressed_context(conversation)
)
return CompressionResult.success_with_compression(
compressed_summary, recent_queries, metadata
)
except Exception as e:
logger.error(f"Error performing compression: {str(e)}", exc_info=True)
return CompressionResult.failure(str(e))
def compress_mid_execution(
self,
conversation_id: str,
user_id: str,
model_id: str,
decoded_token: Dict[str, Any],
current_conversation: Optional[Dict[str, Any]] = None,
) -> CompressionResult:
"""
Perform compression during tool execution.
Args:
conversation_id: Conversation ID
user_id: User ID
model_id: Model ID
decoded_token: User token
current_conversation: Pre-loaded conversation (optional)
Returns:
CompressionResult
"""
try:
# Load conversation if not provided
if current_conversation:
conversation = current_conversation
else:
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
if not conversation:
logger.warning(
f"Could not load conversation {conversation_id} for mid-execution compression"
)
return CompressionResult.failure("Conversation not found")
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
)
except Exception as e:
logger.error(
f"Error in mid-execution compression: {str(e)}", exc_info=True
)
return CompressionResult.failure(str(e))

View File

@@ -0,0 +1,149 @@
"""Compression prompt building logic."""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class CompressionPromptBuilder:
"""Builds prompts for LLM compression calls."""
def __init__(self, version: str = "v1.0"):
"""
Initialize prompt builder.
Args:
version: Prompt template version to use
"""
self.version = version
self.system_prompt = self._load_prompt(version)
def _load_prompt(self, version: str) -> str:
"""
Load prompt template from file.
Args:
version: Version string (e.g., 'v1.0')
Returns:
Prompt template content
Raises:
FileNotFoundError: If prompt template file doesn't exist
"""
current_dir = Path(__file__).resolve().parents[4]
prompt_path = current_dir / "prompts" / "compression" / f"{version}.txt"
try:
with open(prompt_path, "r") as f:
return f.read()
except FileNotFoundError:
logger.error(f"Compression prompt template not found: {prompt_path}")
raise FileNotFoundError(
f"Compression prompt template '{version}' not found at {prompt_path}. "
f"Please ensure the template file exists."
)
def build_prompt(
self,
queries: List[Dict[str, Any]],
existing_compressions: Optional[List[Dict[str, Any]]] = None,
) -> List[Dict[str, str]]:
"""
Build messages for compression LLM call.
Args:
queries: List of query objects to compress
existing_compressions: List of previous compression points
Returns:
List of message dicts for LLM
"""
# Build conversation text
conversation_text = self._format_conversation(queries)
# Add existing compression context if present
existing_compression_context = ""
if existing_compressions and len(existing_compressions) > 0:
existing_compression_context = (
"\n\nIMPORTANT: This conversation has been compressed before. "
"Previous compression summaries:\n\n"
)
for i, comp in enumerate(existing_compressions):
existing_compression_context += (
f"--- Compression {i + 1} (up to message {comp.get('query_index', 'unknown')}) ---\n"
f"{comp.get('compressed_summary', '')}\n\n"
)
existing_compression_context += (
"Your task is to create a NEW summary that incorporates the context from "
"previous compressions AND the new messages below. The final summary should "
"be comprehensive and include all important information from both previous "
"compressions and new messages.\n\n"
)
user_prompt = (
f"{existing_compression_context}"
f"Here is the conversation to summarize:\n\n"
f"{conversation_text}"
)
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": user_prompt},
]
return messages
def _format_conversation(self, queries: List[Dict[str, Any]]) -> str:
"""
Format conversation queries into readable text for compression.
Args:
queries: List of query objects
Returns:
Formatted conversation text
"""
conversation_lines = []
for i, query in enumerate(queries):
conversation_lines.append(f"--- Message {i + 1} ---")
conversation_lines.append(f"User: {query.get('prompt', '')}")
# Add tool calls if present
tool_calls = query.get("tool_calls", [])
if tool_calls:
conversation_lines.append("\nTool Calls:")
for tc in tool_calls:
tool_name = tc.get("tool_name", "unknown")
action_name = tc.get("action_name", "unknown")
arguments = tc.get("arguments", {})
result = tc.get("result", "")
if result is None:
result = ""
status = tc.get("status", "unknown")
# Include full tool result for complete compression context
conversation_lines.append(
f" - {tool_name}.{action_name}({arguments}) "
f"[{status}] → {result}"
)
# Add agent thought if present
thought = query.get("thought", "")
if thought:
conversation_lines.append(f"\nAgent Thought: {thought}")
# Add assistant response
conversation_lines.append(f"\nAssistant: {query.get('response', '')}")
# Add sources if present
sources = query.get("sources", [])
if sources:
conversation_lines.append(f"\nSources Used: {len(sources)} documents")
conversation_lines.append("") # Empty line between messages
return "\n".join(conversation_lines)

View File

@@ -0,0 +1,306 @@
"""Core compression service with simplified responsibilities."""
import logging
import re
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from application.api.answer.services.compression.prompt_builder import (
CompressionPromptBuilder,
)
from application.api.answer.services.compression.token_counter import TokenCounter
from application.api.answer.services.compression.types import (
CompressionMetadata,
)
from application.core.settings import settings
logger = logging.getLogger(__name__)
class CompressionService:
"""
Service for compressing conversation history.
Handles DB updates.
"""
def __init__(
self,
llm,
model_id: str,
conversation_service=None,
prompt_builder: Optional[CompressionPromptBuilder] = None,
):
"""
Initialize compression service.
Args:
llm: LLM instance to use for compression
model_id: Model ID for compression
conversation_service: Service for DB operations (optional, for DB updates)
prompt_builder: Custom prompt builder (optional)
"""
self.llm = llm
self.model_id = model_id
self.conversation_service = conversation_service
self.prompt_builder = prompt_builder or CompressionPromptBuilder(
version=settings.COMPRESSION_PROMPT_VERSION
)
def compress_conversation(
self,
conversation: Dict[str, Any],
compress_up_to_index: int,
) -> CompressionMetadata:
"""
Compress conversation history up to specified index.
Args:
conversation: Full conversation document
compress_up_to_index: Last query index to include in compression
Returns:
CompressionMetadata with compression details
Raises:
ValueError: If compress_up_to_index is invalid
"""
try:
queries = conversation.get("queries", [])
if compress_up_to_index < 0 or compress_up_to_index >= len(queries):
raise ValueError(
f"Invalid compress_up_to_index: {compress_up_to_index} "
f"(conversation has {len(queries)} queries)"
)
# Get queries to compress
queries_to_compress = queries[: compress_up_to_index + 1]
# Check if there are existing compressions
existing_compressions = conversation.get("compression_metadata", {}).get(
"compression_points", []
)
if existing_compressions:
logger.info(
f"Found {len(existing_compressions)} previous compression(s) - "
f"will incorporate into new summary"
)
# Calculate original token count
original_tokens = TokenCounter.count_query_tokens(queries_to_compress)
# Log tool call stats
self._log_tool_call_stats(queries_to_compress)
# Build compression prompt
messages = self.prompt_builder.build_prompt(
queries_to_compress, existing_compressions
)
# Call LLM to generate compression
logger.info(
f"Starting compression: {len(queries_to_compress)} queries "
f"(messages 0-{compress_up_to_index}, {original_tokens} tokens) "
f"using model {self.model_id}"
)
response = self.llm.gen(
model=self.model_id, messages=messages, max_tokens=4000
)
# Extract summary from response
compressed_summary = self._extract_summary(response)
# Calculate compressed token count
compressed_tokens = TokenCounter.count_message_tokens(
[{"content": compressed_summary}]
)
# Calculate compression ratio
compression_ratio = (
original_tokens / compressed_tokens if compressed_tokens > 0 else 0
)
logger.info(
f"Compression complete: {original_tokens}{compressed_tokens} tokens "
f"({compression_ratio:.1f}x compression)"
)
# Build compression metadata
compression_metadata = CompressionMetadata(
timestamp=datetime.now(timezone.utc),
query_index=compress_up_to_index,
compressed_summary=compressed_summary,
original_token_count=original_tokens,
compressed_token_count=compressed_tokens,
compression_ratio=compression_ratio,
model_used=self.model_id,
compression_prompt_version=self.prompt_builder.version,
)
return compression_metadata
except Exception as e:
logger.error(f"Error compressing conversation: {str(e)}", exc_info=True)
raise
def compress_and_save(
self,
conversation_id: str,
conversation: Dict[str, Any],
compress_up_to_index: int,
) -> CompressionMetadata:
"""
Compress conversation and save to database.
Args:
conversation_id: Conversation ID
conversation: Full conversation document
compress_up_to_index: Last query index to include
Returns:
CompressionMetadata
Raises:
ValueError: If conversation_service not provided or invalid index
"""
if not self.conversation_service:
raise ValueError(
"conversation_service required for compress_and_save operation"
)
# Perform compression
metadata = self.compress_conversation(conversation, compress_up_to_index)
# Save to database
self.conversation_service.update_compression_metadata(
conversation_id, metadata.to_dict()
)
logger.info(f"Compression metadata saved to database for {conversation_id}")
return metadata
def get_compressed_context(
self, conversation: Dict[str, Any]
) -> tuple[Optional[str], List[Dict[str, Any]]]:
"""
Get compressed summary + recent uncompressed messages.
Args:
conversation: Full conversation document
Returns:
(compressed_summary, recent_messages)
"""
try:
compression_metadata = conversation.get("compression_metadata", {})
if not compression_metadata.get("is_compressed"):
logger.debug("No compression metadata found - using full history")
queries = conversation.get("queries", [])
if queries is None:
logger.error("Conversation queries is None - returning empty list")
return None, []
return None, queries
compression_points = compression_metadata.get("compression_points", [])
if not compression_points:
logger.debug("No compression points found - using full history")
queries = conversation.get("queries", [])
if queries is None:
logger.error("Conversation queries is None - returning empty list")
return None, []
return None, queries
# Get the most recent compression point
latest_compression = compression_points[-1]
compressed_summary = latest_compression.get("compressed_summary")
last_compressed_index = latest_compression.get("query_index")
compressed_tokens = latest_compression.get("compressed_token_count", 0)
original_tokens = latest_compression.get("original_token_count", 0)
# Get only messages after compression point
queries = conversation.get("queries", [])
total_queries = len(queries)
recent_queries = queries[last_compressed_index + 1 :]
logger.info(
f"Using compressed context: summary ({compressed_tokens} tokens, "
f"compressed from {original_tokens}) + {len(recent_queries)} recent messages "
f"(messages {last_compressed_index + 1}-{total_queries - 1})"
)
return compressed_summary, recent_queries
except Exception as e:
logger.error(
f"Error getting compressed context: {str(e)}", exc_info=True
)
queries = conversation.get("queries", [])
if queries is None:
return None, []
return None, queries
def _extract_summary(self, llm_response: str) -> str:
"""
Extract clean summary from LLM response.
Args:
llm_response: Raw LLM response
Returns:
Cleaned summary text
"""
try:
# Try to extract content within <summary> tags
summary_match = re.search(
r"<summary>(.*?)</summary>", llm_response, re.DOTALL
)
if summary_match:
summary = summary_match.group(1).strip()
else:
# If no summary tags, remove analysis tags and use the rest
summary = re.sub(
r"<analysis>.*?</analysis>", "", llm_response, flags=re.DOTALL
).strip()
return summary
except Exception as e:
logger.warning(f"Error extracting summary: {str(e)}, using full response")
return llm_response
def _log_tool_call_stats(self, queries: List[Dict[str, Any]]) -> None:
"""Log statistics about tool calls in queries."""
total_tool_calls = 0
total_tool_result_chars = 0
tool_call_breakdown = {}
for q in queries:
for tc in q.get("tool_calls", []):
total_tool_calls += 1
tool_name = tc.get("tool_name", "unknown")
action_name = tc.get("action_name", "unknown")
key = f"{tool_name}.{action_name}"
tool_call_breakdown[key] = tool_call_breakdown.get(key, 0) + 1
# Track total tool result size
result = tc.get("result", "")
if result:
total_tool_result_chars += len(str(result))
if total_tool_calls > 0:
tool_breakdown_str = ", ".join(
f"{tool}({count})"
for tool, count in sorted(tool_call_breakdown.items())
)
tool_result_kb = total_tool_result_chars / 1024
logger.info(
f"Tool call breakdown: {tool_breakdown_str} "
f"(total result size: {tool_result_kb:.1f} KB, {total_tool_result_chars:,} chars)"
)

View File

@@ -0,0 +1,103 @@
"""Compression threshold checking logic."""
import logging
from typing import Any, Dict
from application.core.model_utils import get_token_limit
from application.core.settings import settings
from application.api.answer.services.compression.token_counter import TokenCounter
logger = logging.getLogger(__name__)
class CompressionThresholdChecker:
"""Determines if compression is needed based on token thresholds."""
def __init__(self, threshold_percentage: float = None):
"""
Initialize threshold checker.
Args:
threshold_percentage: Percentage of context to use as threshold
(defaults to settings.COMPRESSION_THRESHOLD_PERCENTAGE)
"""
self.threshold_percentage = (
threshold_percentage or settings.COMPRESSION_THRESHOLD_PERCENTAGE
)
def should_compress(
self,
conversation: Dict[str, Any],
model_id: str,
current_query_tokens: int = 500,
) -> bool:
"""
Determine if compression is needed.
Args:
conversation: Full conversation document
model_id: Target model for this request
current_query_tokens: Estimated tokens for current query
Returns:
True if tokens >= threshold% of context window
"""
try:
# Calculate total tokens in conversation
total_tokens = TokenCounter.count_conversation_tokens(conversation)
total_tokens += current_query_tokens
# Get context window limit for model
context_limit = get_token_limit(model_id)
# Calculate threshold
threshold = int(context_limit * self.threshold_percentage)
compression_needed = total_tokens >= threshold
percentage_used = (total_tokens / context_limit) * 100
if compression_needed:
logger.warning(
f"COMPRESSION TRIGGERED: {total_tokens} tokens / {context_limit} limit "
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%)"
)
else:
logger.info(
f"Compression check: {total_tokens}/{context_limit} tokens "
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%) - No compression needed"
)
return compression_needed
except Exception as e:
logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
return False
def check_message_tokens(self, messages: list, model_id: str) -> bool:
"""
Check if message list exceeds threshold.
Args:
messages: List of message dicts
model_id: Target model
Returns:
True if at or above threshold
"""
try:
current_tokens = TokenCounter.count_message_tokens(messages)
context_limit = get_token_limit(model_id)
threshold = int(context_limit * self.threshold_percentage)
if current_tokens >= threshold:
logger.warning(
f"Message context limit approaching: {current_tokens}/{context_limit} tokens "
f"({(current_tokens/context_limit)*100:.1f}%)"
)
return True
return False
except Exception as e:
logger.error(f"Error checking message tokens: {str(e)}", exc_info=True)
return False

View File

@@ -0,0 +1,103 @@
"""Token counting utilities for compression."""
import logging
from typing import Any, Dict, List
from application.utils import num_tokens_from_string
from application.core.settings import settings
logger = logging.getLogger(__name__)
class TokenCounter:
"""Centralized token counting for conversations and messages."""
@staticmethod
def count_message_tokens(messages: List[Dict]) -> int:
"""
Calculate total tokens in a list of messages.
Args:
messages: List of message dicts with 'content' field
Returns:
Total token count
"""
total_tokens = 0
for message in messages:
content = message.get("content", "")
if isinstance(content, str):
total_tokens += num_tokens_from_string(content)
elif isinstance(content, list):
# Handle structured content (tool calls, etc.)
for item in content:
if isinstance(item, dict):
total_tokens += num_tokens_from_string(str(item))
return total_tokens
@staticmethod
def count_query_tokens(
queries: List[Dict[str, Any]], include_tool_calls: bool = True
) -> int:
"""
Count tokens across multiple query objects.
Args:
queries: List of query objects from conversation
include_tool_calls: Whether to count tool call tokens
Returns:
Total token count
"""
total_tokens = 0
for query in queries:
# Count prompt and response tokens
if "prompt" in query:
total_tokens += num_tokens_from_string(query["prompt"])
if "response" in query:
total_tokens += num_tokens_from_string(query["response"])
if "thought" in query:
total_tokens += num_tokens_from_string(query.get("thought", ""))
# Count tool call tokens
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
tool_call_string = (
f"Tool: {tool_call.get('tool_name')} | "
f"Action: {tool_call.get('action_name')} | "
f"Args: {tool_call.get('arguments')} | "
f"Response: {tool_call.get('result')}"
)
total_tokens += num_tokens_from_string(tool_call_string)
return total_tokens
@staticmethod
def count_conversation_tokens(
conversation: Dict[str, Any], include_system_prompt: bool = False
) -> int:
"""
Calculate total tokens in a conversation.
Args:
conversation: Conversation document
include_system_prompt: Whether to include system prompt in count
Returns:
Total token count
"""
try:
queries = conversation.get("queries", [])
total_tokens = TokenCounter.count_query_tokens(queries)
# Add system prompt tokens if requested
if include_system_prompt:
# Rough estimate for system prompt
total_tokens += settings.RESERVED_TOKENS.get("system_prompt", 500)
return total_tokens
except Exception as e:
logger.error(f"Error calculating conversation tokens: {str(e)}")
return 0

View File

@@ -0,0 +1,83 @@
"""Type definitions for compression module."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
@dataclass
class CompressionMetadata:
"""Metadata about a compression operation."""
timestamp: datetime
query_index: int
compressed_summary: str
original_token_count: int
compressed_token_count: int
compression_ratio: float
model_used: str
compression_prompt_version: str
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for DB storage."""
return {
"timestamp": self.timestamp,
"query_index": self.query_index,
"compressed_summary": self.compressed_summary,
"original_token_count": self.original_token_count,
"compressed_token_count": self.compressed_token_count,
"compression_ratio": self.compression_ratio,
"model_used": self.model_used,
"compression_prompt_version": self.compression_prompt_version,
}
@dataclass
class CompressionResult:
"""Result of a compression operation."""
success: bool
compressed_summary: Optional[str] = None
recent_queries: List[Dict[str, Any]] = field(default_factory=list)
metadata: Optional[CompressionMetadata] = None
error: Optional[str] = None
compression_performed: bool = False
@classmethod
def success_with_compression(
cls, summary: str, queries: List[Dict], metadata: CompressionMetadata
) -> "CompressionResult":
"""Create a successful result with compression."""
return cls(
success=True,
compressed_summary=summary,
recent_queries=queries,
metadata=metadata,
compression_performed=True,
)
@classmethod
def success_no_compression(cls, queries: List[Dict]) -> "CompressionResult":
"""Create a successful result without compression needed."""
return cls(
success=True,
recent_queries=queries,
compression_performed=False,
)
@classmethod
def failure(cls, error: str) -> "CompressionResult":
"""Create a failure result."""
return cls(success=False, error=error, compression_performed=False)
def as_history(self) -> List[Dict[str, str]]:
"""
Convert recent queries to history format.
Returns:
List of prompt/response dicts
"""
return [
{"prompt": q["prompt"], "response": q["response"]}
for q in self.recent_queries
]

View File

@@ -60,8 +60,11 @@ class ConversationService:
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
attachment_ids: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Save or update a conversation in the database"""
if decoded_token is None:
raise ValueError("Invalid or missing authentication token")
user_id = decoded_token.get("sub")
if not user_id:
raise ValueError("User ID not found in token")
@@ -91,6 +94,11 @@ class ConversationService:
f"queries.{index}.timestamp": current_time,
f"queries.{index}.attachments": attachment_ids,
f"queries.{index}.model_id": model_id,
**(
{f"queries.{index}.metadata": metadata}
if metadata
else {}
),
}
},
)
@@ -122,6 +130,7 @@ class ConversationService:
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
**({"metadata": metadata} if metadata else {}),
}
}
},
@@ -148,25 +157,30 @@ class ConversationService:
]
completion = llm.gen(
model=model_id, messages=messages_summary, max_tokens=30
model=model_id, messages=messages_summary, max_tokens=500
)
if not completion or not completion.strip():
completion = question[:50] if question else "New Conversation"
query_doc = {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
if metadata:
query_doc["metadata"] = metadata
conversation_data = {
"user": user_id,
"date": current_time,
"name": completion,
"queries": [
{
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
],
"queries": [query_doc],
}
if api_key:
@@ -180,3 +194,103 @@ class ConversationService:
conversation_data["api_key"] = agent["key"]
result = self.conversations_collection.insert_one(conversation_data)
return str(result.inserted_id)
def update_compression_metadata(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""
Update conversation with compression metadata.
Uses $push with $slice to keep only the most recent compression points,
preventing unbounded array growth. Since each compression incorporates
previous compressions, older points become redundant.
Args:
conversation_id: Conversation ID
compression_metadata: Compression point data
"""
try:
self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$set": {
"compression_metadata.is_compressed": True,
"compression_metadata.last_compression_at": compression_metadata.get(
"timestamp"
),
},
"$push": {
"compression_metadata.compression_points": {
"$each": [compression_metadata],
"$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
}
},
},
)
logger.info(
f"Updated compression metadata for conversation {conversation_id}"
)
except Exception as e:
logger.error(
f"Error updating compression metadata: {str(e)}", exc_info=True
)
raise
def append_compression_message(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""
Append a synthetic compression summary entry into the conversation history.
This makes the summary visible in the DB alongside normal queries.
"""
try:
summary = compression_metadata.get("compressed_summary", "")
if not summary:
return
timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc))
self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {
"queries": {
"prompt": "[Context Compression Summary]",
"response": summary,
"thought": "",
"sources": [],
"tool_calls": [],
"timestamp": timestamp,
"attachments": [],
"model_id": compression_metadata.get("model_used"),
}
}
},
)
logger.info(f"Appended compression summary to conversation {conversation_id}")
except Exception as e:
logger.error(
f"Error appending compression summary: {str(e)}", exc_info=True
)
def get_compression_metadata(
self, conversation_id: str
) -> Optional[Dict[str, Any]]:
"""
Get compression metadata for a conversation.
Args:
conversation_id: Conversation ID
Returns:
Compression metadata dict or None
"""
try:
conversation = self.conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
)
return conversation.get("compression_metadata") if conversation else None
except Exception as e:
logger.error(
f"Error getting compression metadata: {str(e)}", exc_info=True
)
return None

View File

@@ -10,6 +10,8 @@ from bson.dbref import DBRef
from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.compression import CompressionOrchestrator
from application.api.answer.services.compression.token_counter import TokenCounter
from application.api.answer.services.conversation_service import ConversationService
from application.api.answer.services.prompt_renderer import PromptRenderer
from application.core.model_utils import (
@@ -36,13 +38,23 @@ def get_prompt(prompt_id: str, prompts_collection=None) -> str:
current_dir = Path(__file__).resolve().parents[3]
prompts_dir = current_dir / "prompts"
preset_mapping = {
# Maps for classic agent types
CLASSIC_PRESETS = {
"default": "chat_combine_default.txt",
"creative": "chat_combine_creative.txt",
"strict": "chat_combine_strict.txt",
"reduce": "chat_reduce_prompt.txt",
}
# Agentic counterparts — same styles, but with search tool instructions
AGENTIC_PRESETS = {
"default": "agentic/default.txt",
"creative": "agentic/creative.txt",
"strict": "agentic/strict.txt",
}
preset_mapping = {**CLASSIC_PRESETS, **{f"agentic_{k}": v for k, v in AGENTIC_PRESETS.items()}}
if prompt_id in preset_mapping:
file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
try:
@@ -88,22 +100,51 @@ class StreamProcessor:
self.retriever_config = {}
self.is_shared_usage = False
self.shared_token = None
self.agent_id = self.data.get("agent_id")
self.agent_key = None
self.model_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
self.conversation_service
)
self.prompt_renderer = PromptRenderer()
self._prompt_content: Optional[str] = None
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
self.compressed_summary: Optional[str] = None
self.compressed_summary_tokens: int = 0
def initialize(self):
"""Initialize all required components for processing"""
self._validate_and_set_model()
self._configure_agent()
self._validate_and_set_model()
self._configure_source()
self._configure_retriever()
self._configure_agent()
self._load_conversation_history()
self._process_attachments()
def build_agent(self, question: str):
"""One call to go from request data to a ready-to-run agent.
Combines initialize(), pre_fetch_docs(), pre_fetch_tools(), and
create_agent() into a single convenience method.
"""
self.initialize()
agent_type = self.agent_config.get("agent_type", "classic")
# Agentic/research agents skip pre-fetch — the LLM searches on-demand via tools
if agent_type in ("agentic", "research"):
tools_data = self.pre_fetch_tools()
return self.create_agent(tools_data=tools_data)
docs_together, docs_list = self.pre_fetch_docs(question)
tools_data = self.pre_fetch_tools()
return self.create_agent(
docs_together=docs_together,
docs=docs_list,
tools_data=tools_data,
)
def _load_conversation_history(self):
"""Load conversation history either from DB or request"""
if self.conversation_id and self.initial_user_id:
@@ -112,15 +153,85 @@ class StreamProcessor:
)
if not conversation:
raise ValueError("Conversation not found or unauthorized")
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
# Check if compression is enabled and needed
if settings.ENABLE_CONVERSATION_COMPRESSION:
self._handle_compression(conversation)
else:
# Original behavior - load all history (include metadata if present)
self.history = [
{
"prompt": query["prompt"],
"response": query["response"],
**(
{"metadata": query["metadata"]}
if "metadata" in query
else {}
),
}
for query in conversation.get("queries", [])
]
else:
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")), model_id=self.model_id
)
def _handle_compression(self, conversation: Dict[str, Any]):
"""Handle conversation compression logic using orchestrator."""
try:
result = self.compression_orchestrator.compress_if_needed(
conversation_id=self.conversation_id,
user_id=self.initial_user_id,
model_id=self.model_id,
decoded_token=self.decoded_token,
)
if not result.success:
logger.error(f"Compression failed: {result.error}, using full history")
self.history = [
{
"prompt": query["prompt"],
"response": query["response"],
**({"metadata": query["metadata"]} if "metadata" in query else {}),
}
for query in conversation.get("queries", [])
]
return
if result.compression_performed and result.compressed_summary:
self.compressed_summary = result.compressed_summary
self.compressed_summary_tokens = TokenCounter.count_message_tokens(
[{"content": result.compressed_summary}]
)
logger.info(
f"Using compressed summary ({self.compressed_summary_tokens} tokens) "
f"+ {len(result.recent_queries)} recent messages"
)
self.history = result.as_history()
# Preserve metadata from recent queries (as_history only has prompt/response)
recent = result.recent_queries if result.recent_queries else conversation.get("queries", [])
for i, entry in enumerate(self.history):
# Match by index from the end of recent queries
offset = len(recent) - len(self.history)
qi = offset + i
if 0 <= qi < len(recent) and "metadata" in recent[qi]:
entry["metadata"] = recent[qi]["metadata"]
except Exception as e:
logger.error(
f"Error handling compression, falling back to standard history: {str(e)}",
exc_info=True,
)
self.history = [
{
"prompt": query["prompt"],
"response": query["response"],
**({"metadata": query["metadata"]} if "metadata" in query else {}),
}
for query in conversation.get("queries", [])
]
def _process_attachments(self):
"""Process any attachments in the request"""
attachment_ids = self.data.get("attachments", [])
@@ -129,9 +240,6 @@ class StreamProcessor:
)
def _get_attachments_content(self, attachment_ids, user_id):
"""
Retrieve content from attachment documents based on their IDs.
"""
if not attachment_ids:
return []
attachments = []
@@ -140,7 +248,6 @@ class StreamProcessor:
attachment_doc = self.attachments_collection.find_one(
{"_id": ObjectId(attachment_id), "user": user_id}
)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
@@ -162,11 +269,19 @@ class StreamProcessor:
raise ValueError(
f"Invalid model_id '{requested_model}'. "
f"Available models: {', '.join(available_models[:5])}"
+ (f" and {len(available_models) - 5} more" if len(available_models) > 5 else "")
+ (
f" and {len(available_models) - 5} more"
if len(available_models) > 5
else ""
)
)
self.model_id = requested_model
else:
self.model_id = get_default_model_id()
agent_default_model = self.agent_config.get("default_model_id", "")
if agent_default_model and validate_model_id(agent_default_model):
self.model_id = agent_default_model
else:
self.model_id = get_default_model_id()
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control"""
@@ -214,7 +329,6 @@ class StreamProcessor:
data["source"] = "default"
else:
data["source"] = None
# Handle multiple sources
sources = data.get("sources", [])
if sources and isinstance(sources, list):
@@ -239,6 +353,9 @@ class StreamProcessor:
data["sources"] = sources_list
else:
data["sources"] = []
data["default_model_id"] = data.get("default_model_id", "")
return data
def _configure_source(self):
@@ -275,55 +392,79 @@ class StreamProcessor:
self.source = {}
self.all_sources = []
def _resolve_agent_id(self) -> Optional[str]:
"""Resolve agent_id from request, then fall back to conversation context."""
request_agent_id = self.data.get("agent_id")
if request_agent_id:
return str(request_agent_id)
if not self.conversation_id or not self.initial_user_id:
return None
try:
conversation = self.conversation_service.get_conversation(
self.conversation_id, self.initial_user_id
)
except Exception:
return None
if not conversation:
return None
conversation_agent_id = conversation.get("agent_id")
if conversation_agent_id:
return str(conversation_agent_id)
return None
def _configure_agent(self):
"""Configure the agent based on request data"""
agent_id = self.data.get("agent_id")
"""Configure the agent based on request data.
Unified flow: resolve the effective API key, then extract config once.
"""
agent_id = self._resolve_agent_id()
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
agent_id, self.initial_user_id
)
self.agent_id = str(agent_id) if agent_id else None
# Determine the effective API key (explicit > agent-derived)
effective_key = self.data.get("api_key") or self.agent_key
if effective_key:
data_key = self._get_data_from_api_key(effective_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
api_key = self.data.get("api_key")
if api_key:
data_key = self._get_data_from_api_key(api_key)
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": api_key,
"user_api_key": effective_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
"models": data_key.get("models", []),
}
)
self.initial_user_id = data_key.get("user")
self.decoded_token = {"sub": data_key.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
elif self.agent_key:
data_key = self._get_data_from_api_key(self.agent_key)
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": self.agent_key,
"json_schema": data_key.get("json_schema"),
}
)
self.decoded_token = (
self.decoded_token
if self.is_shared_usage
else {"sub": data_key.get("user")}
)
# Set identity context
if self.data.get("api_key"):
# External API key: use the key owner's identity
self.initial_user_id = data_key.get("user")
self.decoded_token = {"sub": data_key.get("user")}
elif self.is_shared_usage:
# Shared agent: keep the caller's identity
pass
else:
# Owner using their own agent
self.decoded_token = {"sub": data_key.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
@@ -335,26 +476,33 @@ class StreamProcessor:
)
self.retriever_config["chunks"] = 2
else:
# No API key — default/workflow configuration
agent_type = settings.AGENT_NAME
if self.data.get("workflow") and isinstance(
self.data.get("workflow"), dict
):
agent_type = "workflow"
self.agent_config["workflow"] = self.data["workflow"]
if isinstance(self.decoded_token, dict):
self.agent_config["workflow_owner"] = self.decoded_token.get("sub")
self.agent_config.update(
{
"prompt_id": self.data.get("prompt_id", "default"),
"agent_type": settings.AGENT_NAME,
"agent_type": agent_type,
"user_api_key": None,
"json_schema": None,
"default_model_id": "",
}
)
def _configure_retriever(self):
history_token_limit = int(self.data.get("token_limit", 2000))
doc_token_limit = calculate_doc_token_budget(
model_id=self.model_id, history_token_limit=history_token_limit
)
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
self.retriever_config = {
"retriever_name": self.data.get("retriever", "classic"),
"chunks": int(self.data.get("chunks", 2)),
"doc_token_limit": doc_token_limit,
"history_token_limit": history_token_limit,
}
api_key = self.data.get("api_key") or self.agent_key
@@ -371,12 +519,13 @@ class StreamProcessor:
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
model_id=self.model_id,
user_api_key=self.agent_config["user_api_key"],
agent_id=self.agent_id,
decoded_token=self.decoded_token,
)
def pre_fetch_docs(self, question: str) -> tuple[Optional[str], Optional[list]]:
"""Pre-fetch documents for template rendering before agent creation"""
if self.data.get("isNoneDoc", False):
if self.data.get("isNoneDoc", False) and not self.agent_id:
logger.info("Pre-fetch skipped: isNoneDoc=True")
return None, None
try:
@@ -410,12 +559,7 @@ class StreamProcessor:
return None, None
def pre_fetch_tools(self) -> Optional[Dict[str, Any]]:
"""Pre-fetch tool data for template rendering before agent creation
Can be controlled via:
1. Global setting: ENABLE_TOOL_PREFETCH in .env
2. Per-request: disable_tool_prefetch in request data
"""
"""Pre-fetch tool data for template rendering before agent creation"""
if not settings.ENABLE_TOOL_PREFETCH:
logger.info(
"Tool pre-fetching disabled globally via ENABLE_TOOL_PREFETCH setting"
@@ -634,11 +778,21 @@ class StreamProcessor:
tools_data: Optional[Dict[str, Any]] = None,
):
"""Create and return the configured agent with rendered prompt"""
agent_type = self.agent_config["agent_type"]
# For agentic agents, swap standard presets for their agentic
# counterparts (which include search tool instructions instead of
# {summaries}). Custom / user-provided prompts pass through as-is.
raw_prompt = self._get_prompt_content()
if raw_prompt is None:
raw_prompt = get_prompt(
self.agent_config["prompt_id"], self.prompts_collection
)
prompt_id = self.agent_config.get("prompt_id", "default")
agentic_presets = {"default", "creative", "strict"}
if agent_type in ("agentic", "research") and prompt_id in agentic_presets:
raw_prompt = get_prompt(
f"agentic_{prompt_id}", self.prompts_collection
)
else:
raw_prompt = get_prompt(prompt_id, self.prompts_collection)
self._prompt_content = raw_prompt
rendered_prompt = self.prompt_renderer.render_prompt(
@@ -658,17 +812,88 @@ class StreamProcessor:
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
return AgentCreator.create_agent(
self.agent_config["agent_type"],
endpoint="stream",
llm_name=provider or settings.LLM_PROVIDER,
model_id=self.model_id,
# Create LLM and handler (dependency injection)
from application.llm.llm_creator import LLMCreator
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.agents.tool_executor import ToolExecutor
# Compute backup models: agent's configured models minus the active one
agent_models = self.agent_config.get("models", [])
backup_models = [m for m in agent_models if m != self.model_id]
llm = LLMCreator.create_llm(
provider or settings.LLM_PROVIDER,
api_key=system_api_key,
user_api_key=self.agent_config["user_api_key"],
prompt=rendered_prompt,
chat_history=self.history,
retrieved_docs=self.retrieved_docs,
decoded_token=self.decoded_token,
attachments=self.attachments,
json_schema=self.agent_config.get("json_schema"),
model_id=self.model_id,
agent_id=self.agent_id,
backup_models=backup_models,
)
llm_handler = LLMHandlerCreator.create_handler(
provider if provider else "default"
)
user = self.decoded_token.get("sub") if self.decoded_token else None
tool_executor = ToolExecutor(
user_api_key=self.agent_config["user_api_key"],
user=user,
decoded_token=self.decoded_token,
)
tool_executor.conversation_id = self.conversation_id
# Base agent kwargs
agent_kwargs = {
"endpoint": "stream",
"llm_name": provider or settings.LLM_PROVIDER,
"model_id": self.model_id,
"api_key": system_api_key,
"agent_id": self.agent_id,
"user_api_key": self.agent_config["user_api_key"],
"prompt": rendered_prompt,
"chat_history": self.history,
"retrieved_docs": self.retrieved_docs,
"decoded_token": self.decoded_token,
"attachments": self.attachments,
"json_schema": self.agent_config.get("json_schema"),
"compressed_summary": self.compressed_summary,
"llm": llm,
"llm_handler": llm_handler,
"tool_executor": tool_executor,
}
# Type-specific kwargs
if agent_type in ("agentic", "research"):
agent_kwargs["retriever_config"] = {
"source": self.source,
"retriever_name": self.retriever_config.get(
"retriever_name", "classic"
),
"chunks": self.retriever_config.get("chunks", 2),
"doc_token_limit": self.retriever_config.get(
"doc_token_limit", 50000
),
"model_id": self.model_id,
"user_api_key": self.agent_config["user_api_key"],
"agent_id": self.agent_id,
"llm_name": provider or settings.LLM_PROVIDER,
"api_key": system_api_key,
"decoded_token": self.decoded_token,
}
elif agent_type == "workflow":
workflow_config = self.agent_config.get("workflow")
if isinstance(workflow_config, str):
agent_kwargs["workflow_id"] = workflow_config
elif isinstance(workflow_config, dict):
agent_kwargs["workflow"] = workflow_config
workflow_owner = self.agent_config.get("workflow_owner")
if workflow_owner:
agent_kwargs["workflow_owner"] = workflow_owner
agent = AgentCreator.create_agent(agent_type, **agent_kwargs)
agent.conversation_id = self.conversation_id
agent.initial_user_id = self.initial_user_id
return agent

View File

@@ -1,7 +1,9 @@
import base64
import datetime
import html
import json
import uuid
from urllib.parse import urlencode
from bson.objectid import ObjectId
@@ -35,6 +37,18 @@ connector = Blueprint("connector", __name__)
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
api.add_namespace(connectors_ns)
# Fixed callback status path to prevent open redirect
CALLBACK_STATUS_PATH = "/api/connectors/callback-status"
def build_callback_redirect(params: dict) -> str:
"""Build a safe redirect URL to the callback status page.
Uses a fixed path and properly URL-encodes all parameters
to prevent URL injection and open redirect vulnerabilities.
"""
return f"{CALLBACK_STATUS_PATH}?{urlencode(params)}"
@connectors_ns.route("/api/connectors/auth")
@@ -75,8 +89,8 @@ class ConnectorAuth(Resource):
"state": state
}), 200)
except Exception as e:
current_app.logger.error(f"Error generating connector auth URL: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
current_app.logger.error(f"Error generating connector auth URL: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to generate authorization URL"}), 500)
@connectors_ns.route("/api/connectors/callback")
@@ -93,18 +107,37 @@ class ConnectorsCallback(Resource):
error = request.args.get('error')
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
provider = state_dict["provider"]
state_object_id = state_dict["object_id"]
provider = state_dict.get("provider")
state_object_id = state_dict.get("object_id")
# Validate provider
if not provider or not isinstance(provider, str) or not ConnectorCreator.is_supported(provider):
return redirect(build_callback_redirect({
"status": "error",
"message": "Invalid provider"
}))
if error:
if error == "access_denied":
return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}")
return redirect(build_callback_redirect({
"status": "cancelled",
"message": "Authentication was cancelled. You can try again if you'd like to connect your account.",
"provider": provider
}))
else:
current_app.logger.warning(f"OAuth error in callback: {error}")
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
"provider": provider
}))
if not authorization_code:
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
"provider": provider
}))
try:
auth = ConnectorCreator.create_auth(provider)
@@ -113,20 +146,19 @@ class ConnectorsCallback(Resource):
session_token = str(uuid.uuid4())
try:
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
if provider == "google_drive":
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
else:
user_email = token_info.get('user_info', {}).get('email', 'Connected User')
except Exception as e:
current_app.logger.warning(f"Could not get user info: {e}")
user_email = 'Connected User'
sanitized_token_info = {
"access_token": token_info.get("access_token"),
"refresh_token": token_info.get("refresh_token"),
"token_uri": token_info.get("token_uri"),
"expiry": token_info.get("expiry")
}
sanitized_token_info = auth.sanitize_token_info(token_info)
sessions_collection.find_one_and_update(
{"_id": ObjectId(state_object_id), "provider": provider},
@@ -141,26 +173,39 @@ class ConnectorsCallback(Resource):
)
# Redirect to success page with session token and user email
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
return redirect(build_callback_redirect({
"status": "success",
"message": "Authentication successful",
"provider": provider,
"session_token": session_token,
"user_email": user_email
}))
except Exception as e:
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
"provider": provider
}))
except Exception as e:
current_app.logger.error(f"Error handling connector callback: {e}")
return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions."
}))
@connectors_ns.route("/api/connectors/files")
class ConnectorFiles(Resource):
@api.expect(api.model("ConnectorFilesModel", {
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"page_token": fields.String(required=False),
"search_query": fields.String(required=False)
"search_query": fields.String(required=False),
}))
@api.doc(description="List files from a connector provider (supports pagination and search)")
def post(self):
@@ -168,11 +213,8 @@ class ConnectorFiles(Resource):
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
folder_id = data.get('folder_id')
limit = data.get('limit', 10)
page_token = data.get('page_token')
search_query = data.get('search_query')
if not provider or not session_token:
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
@@ -185,15 +227,12 @@ class ConnectorFiles(Resource):
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
loader = ConnectorCreator.create_connector(provider, session_token)
generic_keys = {'provider', 'session_token'}
input_config = {
'limit': limit,
'list_only': True,
'session_token': session_token,
'folder_id': folder_id,
'page_token': page_token
k: v for k, v in data.items() if k not in generic_keys
}
if search_query:
input_config['search_query'] = search_query
input_config['list_only'] = True
documents = loader.load_data(input_config)
@@ -228,8 +267,8 @@ class ConnectorFiles(Resource):
"has_more": has_more
}), 200)
except Exception as e:
current_app.logger.error(f"Error loading connector files: {e}")
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
current_app.logger.error(f"Error loading connector files: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to load files"}), 500)
@connectors_ns.route("/api/connectors/validate-session")
@@ -260,12 +299,7 @@ class ConnectorValidateSession(Resource):
if is_expired and token_info.get('refresh_token'):
try:
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
sanitized_token_info = {
"access_token": refreshed_token_info.get("access_token"),
"refresh_token": refreshed_token_info.get("refresh_token"),
"token_uri": refreshed_token_info.get("token_uri"),
"expiry": refreshed_token_info.get("expiry")
}
sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
sessions_collection.update_one(
{"session_token": session_token},
{"$set": {"token_info": sanitized_token_info}}
@@ -282,15 +316,21 @@ class ConnectorValidateSession(Resource):
"error": "Session token has expired. Please reconnect."
}), 401)
return make_response(jsonify({
_base_fields = {"access_token", "refresh_token", "token_uri", "expiry"}
provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields}
response_data = {
"success": True,
"expired": False,
"user_email": session.get('user_email', 'Connected User'),
"access_token": token_info.get('access_token')
}), 200)
"access_token": token_info.get('access_token'),
**provider_extras,
}
return make_response(jsonify(response_data), 200)
except Exception as e:
current_app.logger.error(f"Error validating connector session: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
current_app.logger.error(f"Error validating connector session: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500)
@connectors_ns.route("/api/connectors/disconnect")
@@ -311,8 +351,8 @@ class ConnectorDisconnect(Resource):
return make_response(jsonify({"success": True}), 200)
except Exception as e:
current_app.logger.error(f"Error disconnecting connector session: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
current_app.logger.error(f"Error disconnecting connector session: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to disconnect session"}), 500)
@connectors_ns.route("/api/connectors/sync")
@@ -418,8 +458,8 @@ class ConnectorSync(Resource):
return make_response(
jsonify({
"success": False,
"error": str(err)
}),
"error": "Failed to sync connector source"
}),
400
)
@@ -430,17 +470,32 @@ class ConnectorCallbackStatus(Resource):
def get(self):
"""Return HTML page with connector authentication status"""
try:
status = request.args.get('status', 'error')
message = request.args.get('message', '')
provider = request.args.get('provider', 'connector')
# Validate and sanitize status to a known value
status_raw = request.args.get('status', 'error')
status = status_raw if status_raw in ('success', 'error', 'cancelled') else 'error'
# Escape all user-controlled values for HTML context
message = html.escape(request.args.get('message', ''))
provider_raw = request.args.get('provider', 'connector')
provider = html.escape(provider_raw.replace('_', ' ').title())
session_token = request.args.get('session_token', '')
user_email = request.args.get('user_email', '')
user_email = html.escape(request.args.get('user_email', ''))
def safe_js_string(value: str) -> str:
"""Safely encode a string for embedding in inline JavaScript."""
js_encoded = json.dumps(value)
return js_encoded.replace('</', '<\\/').replace('<!--', '<\\!--')
js_status = safe_js_string(status)
js_session_token = safe_js_string(session_token)
js_user_email = safe_js_string(user_email)
js_provider_type = safe_js_string(provider_raw)
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>{provider.replace('_', ' ').title()} Authentication</title>
<title>{provider} Authentication</title>
<style>
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
.container {{ max-width: 600px; margin: 0 auto; }}
@@ -450,13 +505,14 @@ class ConnectorCallbackStatus(Resource):
</style>
<script>
window.onload = function() {{
const status = "{status}";
const sessionToken = "{session_token}";
const userEmail = "{user_email}";
const status = {js_status};
const sessionToken = {js_session_token};
const userEmail = {js_user_email};
const providerType = {js_provider_type};
if (status === "success" && window.opener) {{
window.opener.postMessage({{
type: '{provider}_auth_success',
type: providerType + '_auth_success',
session_token: sessionToken,
user_email: userEmail
}}, '*');
@@ -470,17 +526,17 @@ class ConnectorCallbackStatus(Resource):
</head>
<body>
<div class="container">
<h2>{provider.replace('_', ' ').title()} Authentication</h2>
<h2>{provider} Authentication</h2>
<div class="{status}">
<p>{message}</p>
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
</div>
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
<p><small>You can close this window. {f"Your {provider} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
</div>
</body>
</html>
"""
return make_response(html_content, 200, {'Content-Type': 'text/html'})
except Exception as e:
current_app.logger.error(f"Error rendering callback status page: {e}")

View File

@@ -1,7 +1,7 @@
import os
import datetime
import json
from flask import Blueprint, request, send_from_directory
from flask import Blueprint, request, send_from_directory, jsonify
from werkzeug.utils import secure_filename
from bson.objectid import ObjectId
import logging
@@ -24,6 +24,16 @@ current_dir = os.path.dirname(
internal = Blueprint("internal", __name__)
@internal.before_request
def verify_internal_key():
"""Verify INTERNAL_KEY for all internal endpoint requests."""
if settings.INTERNAL_KEY:
internal_key = request.headers.get("X-Internal-Key")
if not internal_key or internal_key != settings.INTERNAL_KEY:
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
@internal.route("/api/download", methods=["get"])
def download_file():
user = secure_filename(request.args.get("user"))
@@ -51,6 +61,7 @@ def upload_index_files():
file_path = request.form.get("file_path")
directory_structure = request.form.get("directory_structure")
file_name_map = request.form.get("file_name_map")
if directory_structure:
try:
@@ -60,6 +71,14 @@ def upload_index_files():
directory_structure = {}
else:
directory_structure = {}
if file_name_map:
try:
file_name_map = json.loads(file_name_map)
except Exception:
logger.error("Error parsing file_name_map")
file_name_map = None
else:
file_name_map = None
storage = StorageCreator.get_storage()
index_base_path = f"indexes/{id}"
@@ -87,41 +106,43 @@ def upload_index_files():
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
if existing_entry:
update_fields = {
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
if file_name_map is not None:
update_fields["file_name_map"] = file_name_map
sources_collection.update_one(
{"_id": ObjectId(id)},
{
"$set": {
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
},
{"$set": update_fields},
)
else:
sources_collection.insert_one(
{
"_id": ObjectId(id),
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
)
insert_doc = {
"_id": ObjectId(id),
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
if file_name_map is not None:
insert_doc["file_name_map"] = file_name_map
sources_collection.insert_one(insert_doc)
return {"status": "ok"}

View File

@@ -3,5 +3,6 @@
from .routes import agents_ns
from .sharing import agents_sharing_ns
from .webhooks import agents_webhooks_ns
from .folders import agents_folders_ns
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns"]
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns", "agents_folders_ns"]

View File

@@ -0,0 +1,266 @@
"""
Agent folders management routes.
Provides virtual folder organization for agents (Google Drive-like structure).
"""
import datetime
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource, fields
from application.api import api
from application.api.user.base import (
agent_folders_collection,
agents_collection,
)
agents_folders_ns = Namespace(
"agents_folders", description="Agent folder management", path="/api/agents/folders"
)
def _folder_error_response(message: str, err: Exception):
current_app.logger.error(f"{message}: {err}", exc_info=True)
return make_response(jsonify({"success": False, "message": message}), 400)
@agents_folders_ns.route("/")
class AgentFolders(Resource):
@api.doc(description="Get all folders for the user")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
folders = list(agent_folders_collection.find({"user": user}))
result = [
{
"id": str(f["_id"]),
"name": f["name"],
"parent_id": f.get("parent_id"),
"created_at": f.get("created_at", "").isoformat() if f.get("created_at") else None,
"updated_at": f.get("updated_at", "").isoformat() if f.get("updated_at") else None,
}
for f in folders
]
return make_response(jsonify({"folders": result}), 200)
except Exception as err:
return _folder_error_response("Failed to fetch folders", err)
@api.doc(description="Create a new folder")
@api.expect(
api.model(
"CreateFolder",
{
"name": fields.String(required=True, description="Folder name"),
"parent_id": fields.String(required=False, description="Parent folder ID"),
},
)
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data or not data.get("name"):
return make_response(jsonify({"success": False, "message": "Folder name is required"}), 400)
parent_id = data.get("parent_id")
if parent_id:
parent = agent_folders_collection.find_one({"_id": ObjectId(parent_id), "user": user})
if not parent:
return make_response(jsonify({"success": False, "message": "Parent folder not found"}), 404)
try:
now = datetime.datetime.now(datetime.timezone.utc)
folder = {
"user": user,
"name": data["name"],
"parent_id": parent_id,
"created_at": now,
"updated_at": now,
}
result = agent_folders_collection.insert_one(folder)
return make_response(
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
201,
)
except Exception as err:
return _folder_error_response("Failed to create folder", err)
@agents_folders_ns.route("/<string:folder_id>")
class AgentFolder(Resource):
@api.doc(description="Get a specific folder with its agents")
def get(self, folder_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
agents = list(agents_collection.find({"user": user, "folder_id": folder_id}))
agents_list = [
{"id": str(a["_id"]), "name": a["name"], "description": a.get("description", "")}
for a in agents
]
subfolders = list(agent_folders_collection.find({"user": user, "parent_id": folder_id}))
subfolders_list = [{"id": str(sf["_id"]), "name": sf["name"]} for sf in subfolders]
return make_response(
jsonify({
"id": str(folder["_id"]),
"name": folder["name"],
"parent_id": folder.get("parent_id"),
"agents": agents_list,
"subfolders": subfolders_list,
}),
200,
)
except Exception as err:
return _folder_error_response("Failed to fetch folder", err)
@api.doc(description="Update a folder")
def put(self, folder_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data:
return make_response(jsonify({"success": False, "message": "No data provided"}), 400)
try:
update_fields = {"updated_at": datetime.datetime.now(datetime.timezone.utc)}
if "name" in data:
update_fields["name"] = data["name"]
if "parent_id" in data:
if data["parent_id"] == folder_id:
return make_response(jsonify({"success": False, "message": "Cannot set folder as its own parent"}), 400)
update_fields["parent_id"] = data["parent_id"]
result = agent_folders_collection.update_one(
{"_id": ObjectId(folder_id), "user": user}, {"$set": update_fields}
)
if result.matched_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to update folder", err)
@api.doc(description="Delete a folder")
def delete(self, folder_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
agents_collection.update_many(
{"user": user, "folder_id": folder_id}, {"$unset": {"folder_id": ""}}
)
agent_folders_collection.update_many(
{"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
)
result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
if result.deleted_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to delete folder", err)
@agents_folders_ns.route("/move_agent")
class MoveAgentToFolder(Resource):
@api.doc(description="Move an agent to a folder or remove from folder")
@api.expect(
api.model(
"MoveAgent",
{
"agent_id": fields.String(required=True, description="Agent ID to move"),
"folder_id": fields.String(required=False, description="Target folder ID (null to remove from folder)"),
},
)
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data or not data.get("agent_id"):
return make_response(jsonify({"success": False, "message": "Agent ID is required"}), 400)
agent_id = data["agent_id"]
folder_id = data.get("folder_id")
try:
agent = agents_collection.find_one({"_id": ObjectId(agent_id), "user": user})
if not agent:
return make_response(jsonify({"success": False, "message": "Agent not found"}), 404)
if folder_id:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
agents_collection.update_one(
{"_id": ObjectId(agent_id)}, {"$set": {"folder_id": folder_id}}
)
else:
agents_collection.update_one(
{"_id": ObjectId(agent_id)}, {"$unset": {"folder_id": ""}}
)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to move agent", err)
@agents_folders_ns.route("/bulk_move")
class BulkMoveAgents(Resource):
@api.doc(description="Move multiple agents to a folder")
@api.expect(
api.model(
"BulkMoveAgents",
{
"agent_ids": fields.List(fields.String, required=True, description="List of agent IDs"),
"folder_id": fields.String(required=False, description="Target folder ID"),
},
)
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data or not data.get("agent_ids"):
return make_response(jsonify({"success": False, "message": "Agent IDs are required"}), 400)
agent_ids = data["agent_ids"]
folder_id = data.get("folder_id")
try:
if folder_id:
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
if not folder:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
object_ids = [ObjectId(aid) for aid in agent_ids]
if folder_id:
agents_collection.update_many(
{"_id": {"$in": object_ids}, "user": user},
{"$set": {"folder_id": folder_id}},
)
else:
agents_collection.update_many(
{"_id": {"$in": object_ids}, "user": user},
{"$unset": {"folder_id": ""}},
)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to move agents", err)

View File

@@ -11,6 +11,7 @@ from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import (
agent_folders_collection,
agents_collection,
db,
ensure_user_doc,
@@ -18,6 +19,13 @@ from application.api.user.base import (
resolve_tool_details,
storage,
users_collection,
workflow_edges_collection,
workflow_nodes_collection,
workflows_collection,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.settings import settings
from application.utils import (
@@ -30,6 +38,191 @@ from application.utils import (
agents_ns = Namespace("agents", description="Agent management operations", path="/api")
AGENT_TYPE_SCHEMAS = {
"classic": {
"required_published": [
"name",
"description",
"chunks",
"retriever",
"prompt_id",
],
"required_draft": ["name"],
"validate_published": ["name", "description", "prompt_id"],
"validate_draft": [],
"require_source": True,
"fields": [
"user",
"name",
"description",
"agent_type",
"status",
"key",
"image",
"source",
"sources",
"chunks",
"retriever",
"prompt_id",
"tools",
"json_schema",
"models",
"default_model_id",
"folder_id",
"limited_token_mode",
"token_limit",
"limited_request_mode",
"request_limit",
"createdAt",
"updatedAt",
"lastUsedAt",
],
},
"workflow": {
"required_published": ["name", "workflow"],
"required_draft": ["name"],
"validate_published": ["name", "workflow"],
"validate_draft": [],
"fields": [
"user",
"name",
"description",
"agent_type",
"status",
"key",
"workflow",
"folder_id",
"limited_token_mode",
"token_limit",
"limited_request_mode",
"request_limit",
"createdAt",
"updatedAt",
"lastUsedAt",
],
},
}
AGENT_TYPE_SCHEMAS["react"] = AGENT_TYPE_SCHEMAS["classic"]
AGENT_TYPE_SCHEMAS["agentic"] = AGENT_TYPE_SCHEMAS["classic"]
AGENT_TYPE_SCHEMAS["research"] = AGENT_TYPE_SCHEMAS["classic"]
AGENT_TYPE_SCHEMAS["openai"] = AGENT_TYPE_SCHEMAS["classic"]
def normalize_workflow_reference(workflow_value):
"""Normalize workflow references from form/json payloads."""
if workflow_value is None:
return None
if isinstance(workflow_value, dict):
return (
workflow_value.get("id")
or workflow_value.get("_id")
or workflow_value.get("workflow_id")
)
if isinstance(workflow_value, str):
value = workflow_value.strip()
if not value:
return ""
try:
parsed = json.loads(value)
if isinstance(parsed, str):
return parsed.strip()
if isinstance(parsed, dict):
return (
parsed.get("id") or parsed.get("_id") or parsed.get("workflow_id")
)
except json.JSONDecodeError:
pass
return value
return str(workflow_value)
def validate_workflow_access(workflow_value, user, required=False):
"""Validate workflow reference and ensure ownership."""
workflow_id = normalize_workflow_reference(workflow_value)
if not workflow_id:
if required:
return None, make_response(
jsonify({"success": False, "message": "Workflow is required"}), 400
)
return None, None
if not ObjectId.is_valid(workflow_id):
return None, make_response(
jsonify({"success": False, "message": "Invalid workflow ID format"}), 400
)
workflow = workflows_collection.find_one({"_id": ObjectId(workflow_id), "user": user})
if not workflow:
return None, make_response(
jsonify({"success": False, "message": "Workflow not found"}), 404
)
return workflow_id, None
def build_agent_document(
data, user, key, agent_type, image_url=None, source_field=None, sources_list=None
):
"""Build agent document based on agent type schema."""
if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
agent_type = "classic"
schema = AGENT_TYPE_SCHEMAS.get(agent_type, AGENT_TYPE_SCHEMAS["classic"])
allowed_fields = set(schema["fields"])
now = datetime.datetime.now(datetime.timezone.utc)
base_doc = {
"user": user,
"name": data.get("name"),
"description": data.get("description", ""),
"agent_type": agent_type,
"status": data.get("status"),
"key": key,
"createdAt": now,
"updatedAt": now,
"lastUsedAt": None,
}
if agent_type == "workflow":
base_doc["workflow"] = data.get("workflow")
base_doc["folder_id"] = data.get("folder_id")
else:
base_doc.update(
{
"image": image_url or "",
"source": source_field or "",
"sources": sources_list or [],
"chunks": data.get("chunks", ""),
"retriever": data.get("retriever", ""),
"prompt_id": data.get("prompt_id", ""),
"tools": data.get("tools", []),
"json_schema": data.get("json_schema"),
"models": data.get("models", []),
"default_model_id": data.get("default_model_id", ""),
"folder_id": data.get("folder_id"),
}
)
if "limited_token_mode" in allowed_fields:
base_doc["limited_token_mode"] = (
data.get("limited_token_mode") == "True"
if isinstance(data.get("limited_token_mode"), str)
else bool(data.get("limited_token_mode", False))
)
if "token_limit" in allowed_fields:
base_doc["token_limit"] = int(
data.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
)
if "limited_request_mode" in allowed_fields:
base_doc["limited_request_mode"] = (
data.get("limited_request_mode") == "True"
if isinstance(data.get("limited_request_mode"), str)
else bool(data.get("limited_request_mode", False))
)
if "request_limit" in allowed_fields:
base_doc["request_limit"] = int(
data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
)
return {k: v for k, v in base_doc.items() if k in allowed_fields}
@agents_ns.route("/get_agent")
class GetAgent(Resource):
@api.doc(params={"id": "Agent ID"}, description="Get agent by ID")
@@ -67,7 +260,7 @@ class GetAgent(Resource):
if (isinstance(source_ref, DBRef) and db.dereference(source_ref))
or source_ref == "default"
],
"chunks": agent["chunks"],
"chunks": agent.get("chunks", "2"),
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
"tools": agent.get("tools", []),
@@ -97,6 +290,8 @@ class GetAgent(Resource):
"shared_token": agent.get("shared_token", ""),
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
}
return make_response(jsonify(data), 200)
except Exception as e:
@@ -146,7 +341,7 @@ class GetAgents(Resource):
isinstance(source_ref, DBRef) and db.dereference(source_ref)
)
],
"chunks": agent["chunks"],
"chunks": agent.get("chunks", "2"),
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
"tools": agent.get("tools", []),
@@ -176,9 +371,13 @@ class GetAgents(Resource):
"shared_token": agent.get("shared_token", ""),
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
}
for agent in agents
if "source" in agent or "retriever" in agent
if "source" in agent
or "retriever" in agent
or agent.get("agent_type") == "workflow"
]
except Exception as err:
current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True)
@@ -206,16 +405,22 @@ class CreateAgent(Resource):
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
"chunks": fields.Integer(required=False, description="Chunks count"),
"retriever": fields.String(required=False, description="Retriever ID"),
"prompt_id": fields.String(required=False, description="Prompt ID"),
"tools": fields.List(
fields.String, required=False, description="List of tool identifiers"
),
"agent_type": fields.String(required=True, description="Type of the agent"),
"agent_type": fields.String(
required=False,
description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
),
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"workflow": fields.String(
required=False, description="Workflow ID for workflow-type agents"
),
"json_schema": fields.Raw(
required=False,
description="JSON schema for enforcing structured output format",
@@ -242,6 +447,9 @@ class CreateAgent(Resource):
"default_model_id": fields.String(
required=False, description="Default model ID for this agent"
),
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
},
)
@@ -277,40 +485,15 @@ class CreateAgent(Resource):
data["models"] = []
print(f"Received data: {data}")
# Validate JSON schema if provided
if data.get("json_schema"):
# Validate and normalize JSON schema if provided
if "json_schema" in data:
try:
# Basic validation - ensure it's a valid JSON structure
json_schema = data.get("json_schema")
if not isinstance(json_schema, dict):
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must be a valid JSON object",
}
),
400,
)
# Validate that it has either a 'schema' property or is itself a schema
if "schema" not in json_schema and "type" not in json_schema:
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property",
}
),
400,
)
except Exception as e:
data["json_schema"] = normalize_json_schema_payload(
data.get("json_schema")
)
except JsonSchemaValidationError as exc:
return make_response(
jsonify(
{"success": False, "message": f"Invalid JSON schema: {str(e)}"}
),
jsonify({"success": False, "message": f"JSON schema {exc}"}),
400,
)
if data.get("status") not in ["draft", "published"]:
@@ -323,18 +506,34 @@ class CreateAgent(Resource):
),
400,
)
if data.get("status") == "published":
required_fields = [
"name",
"description",
"chunks",
"retriever",
"prompt_id",
"agent_type",
]
# Require either source or sources (but not both)
agent_type = data.get("agent_type", "")
# Default to classic schema for empty or unknown agent types
if not data.get("source") and not data.get("sources"):
if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
schema = AGENT_TYPE_SCHEMAS["classic"]
# Set agent_type to classic if it was empty
if not agent_type:
agent_type = "classic"
else:
schema = AGENT_TYPE_SCHEMAS[agent_type]
is_published = data.get("status") == "published"
if agent_type == "workflow":
workflow_id, workflow_error = validate_workflow_access(
data.get("workflow"), user, required=is_published
)
if workflow_error:
return workflow_error
data["workflow"] = workflow_id
if data.get("status") == "published":
required_fields = schema["required_published"]
validate_fields = schema["validate_published"]
if (
schema.get("require_source")
and not data.get("source")
and not data.get("sources")
):
return make_response(
jsonify(
{
@@ -344,10 +543,9 @@ class CreateAgent(Resource):
),
400,
)
validate_fields = ["name", "description", "prompt_id", "agent_type"]
else:
required_fields = ["name"]
validate_fields = []
required_fields = schema["required_draft"]
validate_fields = schema["validate_draft"]
missing_fields = check_required_fields(data, required_fields)
invalid_fields = validate_required_fields(data, validate_fields)
if missing_fields:
@@ -359,74 +557,50 @@ class CreateAgent(Resource):
return make_response(
jsonify({"success": False, "message": "Image upload failed"}), 400
)
folder_id = data.get("folder_id")
if folder_id:
if not ObjectId.is_valid(folder_id):
return make_response(
jsonify({"success": False, "message": "Invalid folder ID format"}),
400,
)
folder = agent_folders_collection.find_one(
{"_id": ObjectId(folder_id), "user": user}
)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}), 404
)
try:
key = str(uuid.uuid4()) if data.get("status") == "published" else ""
sources_list = []
source_field = ""
if data.get("sources") and len(data.get("sources", [])) > 0:
for source_id in data.get("sources", []):
if source_id == "default":
sources_list.append("default")
elif ObjectId.is_valid(source_id):
sources_list.append(DBRef("sources", ObjectId(source_id)))
source_field = ""
else:
source_value = data.get("source", "")
if source_value == "default":
source_field = "default"
elif ObjectId.is_valid(source_value):
source_field = DBRef("sources", ObjectId(source_value))
else:
source_field = ""
new_agent = {
"user": user,
"name": data.get("name"),
"description": data.get("description", ""),
"image": image_url,
"source": source_field,
"sources": sources_list,
"chunks": data.get("chunks", ""),
"retriever": data.get("retriever", ""),
"prompt_id": data.get("prompt_id", ""),
"tools": data.get("tools", []),
"agent_type": data.get("agent_type", ""),
"status": data.get("status"),
"json_schema": data.get("json_schema"),
"limited_token_mode": (
data.get("limited_token_mode") == "True"
if isinstance(data.get("limited_token_mode"), str)
else bool(data.get("limited_token_mode", False))
),
"token_limit": int(
data.get(
"token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
)
),
"limited_request_mode": (
data.get("limited_request_mode") == "True"
if isinstance(data.get("limited_request_mode"), str)
else bool(data.get("limited_request_mode", False))
),
"request_limit": int(
data.get(
"request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
)
),
"createdAt": datetime.datetime.now(datetime.timezone.utc),
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
"lastUsedAt": None,
"key": key,
"models": data.get("models", []),
"default_model_id": data.get("default_model_id", ""),
}
if new_agent["chunks"] == "":
new_agent["chunks"] = "2"
if (
new_agent["source"] == ""
and new_agent["retriever"] == ""
and not new_agent["sources"]
):
new_agent["retriever"] = "classic"
new_agent = build_agent_document(
data, user, key, agent_type, image_url, source_field, sources_list
)
if agent_type != "workflow":
if new_agent.get("chunks") == "":
new_agent["chunks"] = "2"
if (
new_agent.get("source") == ""
and new_agent.get("retriever") == ""
and not new_agent.get("sources")
):
new_agent["retriever"] = "classic"
resp = agents_collection.insert_one(new_agent)
new_id = str(resp.inserted_id)
except Exception as err:
@@ -455,16 +629,22 @@ class UpdateAgent(Resource):
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
"chunks": fields.Integer(required=False, description="Chunks count"),
"retriever": fields.String(required=False, description="Retriever ID"),
"prompt_id": fields.String(required=False, description="Prompt ID"),
"tools": fields.List(
fields.String, required=False, description="List of tool identifiers"
),
"agent_type": fields.String(required=True, description="Type of the agent"),
"agent_type": fields.String(
required=False,
description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
),
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"workflow": fields.String(
required=False, description="Workflow ID for workflow-type agents"
),
"json_schema": fields.Raw(
required=False,
description="JSON schema for enforcing structured output format",
@@ -491,6 +671,9 @@ class UpdateAgent(Resource):
"default_model_id": fields.String(
required=False, description="Default model ID for this agent"
),
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
},
)
@@ -529,6 +712,8 @@ class UpdateAgent(Resource):
),
400,
)
if data.get("json_schema") == "":
data["json_schema"] = None
except Exception as err:
current_app.logger.error(
f"Error parsing request data: {err}", exc_info=True
@@ -557,13 +742,7 @@ class UpdateAgent(Resource):
request, existing_agent.get("image", ""), user, storage
)
if error:
current_app.logger.error(
f"Image upload error for agent {agent_id}: {error}"
)
return make_response(
jsonify({"success": False, "message": f"Image upload failed: {error}"}),
400,
)
return error
update_fields = {}
allowed_fields = [
"name",
@@ -584,6 +763,8 @@ class UpdateAgent(Resource):
"request_limit",
"models",
"default_model_id",
"folder_id",
"workflow",
]
for field in allowed_fields:
@@ -687,17 +868,15 @@ class UpdateAgent(Resource):
elif field == "json_schema":
json_schema = data.get("json_schema")
if json_schema is not None:
if not isinstance(json_schema, dict):
try:
update_fields[field] = normalize_json_schema_payload(
json_schema
)
except JsonSchemaValidationError as exc:
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must be a valid object",
}
),
jsonify({"success": False, "message": f"JSON schema {exc}"}),
400,
)
update_fields[field] = json_schema
else:
update_fields[field] = None
elif field == "limited_token_mode":
@@ -740,10 +919,10 @@ class UpdateAgent(Resource):
)
elif field == "token_limit":
token_limit = data.get("token_limit")
# Convert to int and store
update_fields[field] = int(token_limit) if token_limit else 0
# Validate consistency with mode
if update_fields[field] > 0 and not data.get("limited_token_mode"):
return make_response(
jsonify(
@@ -768,6 +947,42 @@ class UpdateAgent(Resource):
),
400,
)
elif field == "folder_id":
folder_id = data.get("folder_id")
if folder_id:
if not ObjectId.is_valid(folder_id):
return make_response(
jsonify(
{
"success": False,
"message": "Invalid folder ID format",
}
),
400,
)
folder = agent_folders_collection.find_one(
{"_id": ObjectId(folder_id), "user": user}
)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
update_fields[field] = folder_id
else:
update_fields[field] = None
elif field == "workflow":
workflow_required = (
data.get("status", existing_agent.get("status")) == "published"
and data.get("agent_type", existing_agent.get("agent_type"))
== "workflow"
)
workflow_id, workflow_error = validate_workflow_access(
data.get("workflow"), user, required=workflow_required
)
if workflow_error:
return workflow_error
update_fields[field] = workflow_id
else:
value = data[field]
if field in ["name", "description", "prompt_id", "agent_type"]:
@@ -796,46 +1011,82 @@ class UpdateAgent(Resource):
)
newly_generated_key = None
final_status = update_fields.get("status", existing_agent.get("status"))
agent_type = update_fields.get("agent_type", existing_agent.get("agent_type"))
if final_status == "published":
required_published_fields = {
"name": "Agent name",
"description": "Agent description",
"chunks": "Chunks count",
"prompt_id": "Prompt",
"agent_type": "Agent type",
}
if agent_type == "workflow":
required_published_fields = {
"name": "Agent name",
}
missing_published_fields = []
for req_field, field_label in required_published_fields.items():
final_value = update_fields.get(
req_field, existing_agent.get(req_field)
)
if not final_value:
missing_published_fields.append(field_label)
missing_published_fields = []
for req_field, field_label in required_published_fields.items():
final_value = update_fields.get(
req_field, existing_agent.get(req_field)
workflow_id = update_fields.get("workflow", existing_agent.get("workflow"))
if not workflow_id:
missing_published_fields.append("Workflow")
elif not ObjectId.is_valid(workflow_id):
missing_published_fields.append("Valid workflow")
else:
workflow = workflows_collection.find_one(
{"_id": ObjectId(workflow_id), "user": user}
)
if not workflow:
missing_published_fields.append("Workflow access")
if missing_published_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Cannot publish workflow agent. Missing required fields: {', '.join(missing_published_fields)}",
}
),
400,
)
else:
required_published_fields = {
"name": "Agent name",
"description": "Agent description",
"chunks": "Chunks count",
"prompt_id": "Prompt",
"agent_type": "Agent type",
}
missing_published_fields = []
for req_field, field_label in required_published_fields.items():
final_value = update_fields.get(
req_field, existing_agent.get(req_field)
)
if not final_value:
missing_published_fields.append(field_label)
source_val = update_fields.get("source", existing_agent.get("source"))
sources_val = update_fields.get(
"sources", existing_agent.get("sources", [])
)
if not final_value:
missing_published_fields.append(field_label)
source_val = update_fields.get("source", existing_agent.get("source"))
sources_val = update_fields.get(
"sources", existing_agent.get("sources", [])
)
has_valid_source = (
isinstance(source_val, DBRef)
or source_val == "default"
or (isinstance(sources_val, list) and len(sources_val) > 0)
)
if not has_valid_source:
missing_published_fields.append("Source")
if missing_published_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
}
),
400,
has_valid_source = (
isinstance(source_val, DBRef)
or source_val == "default"
or (isinstance(sources_val, list) and len(sources_val) > 0)
)
if not has_valid_source:
missing_published_fields.append("Source")
if missing_published_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
}
),
400,
)
if not existing_agent.get("key"):
newly_generated_key = str(uuid.uuid4())
update_fields["key"] = newly_generated_key
@@ -907,6 +1158,29 @@ class DeleteAgent(Resource):
jsonify({"success": False, "message": "Agent not found"}), 404
)
deleted_id = str(deleted_agent["_id"])
if deleted_agent.get("agent_type") == "workflow" and deleted_agent.get(
"workflow"
):
workflow_id = normalize_workflow_reference(deleted_agent.get("workflow"))
if workflow_id and ObjectId.is_valid(workflow_id):
workflow_oid = ObjectId(workflow_id)
owned_workflow = workflows_collection.find_one(
{"_id": workflow_oid, "user": user}, {"_id": 1}
)
if owned_workflow:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow_oid, "user": user})
else:
current_app.logger.warning(
f"Skipping workflow cleanup for non-owned workflow {workflow_id}"
)
elif workflow_id:
current_app.logger.warning(
f"Skipping workflow cleanup for invalid workflow id {workflow_id}"
)
except Exception as err:
current_app.logger.error(f"Error deleting agent: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -1015,19 +1289,16 @@ class AdoptAgent(Resource):
def post(self):
if not (decoded_token := request.decoded_token):
return make_response(jsonify({"success": False}), 401)
if not (agent_id := request.args.get("id")):
return make_response(
jsonify({"success": False, "message": "ID required"}), 400
)
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "user": "system"}
)
if not agent:
return make_response(jsonify({"status": "Not found"}), 404)
new_agent = agent.copy()
new_agent.pop("_id", None)
new_agent["user"] = decoded_token["sub"]

View File

@@ -255,8 +255,8 @@ class ShareAgent(Resource):
{"$unset": {"shared_metadata": ""}},
)
except Exception as err:
current_app.logger.error(f"Error sharing/unsharing agent: {err}")
return make_response(jsonify({"success": False, "error": str(err)}), 400)
current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
shared_token = shared_token if shared else None
return make_response(
jsonify({"success": True, "shared_token": shared_token}), 200

View File

@@ -1,15 +1,36 @@
"""File attachments and media routes."""
import os
import tempfile
from pathlib import Path
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import agents_collection, storage
from application.api.user.tasks import store_attachment
from application.cache import get_redis_instance
from application.core.settings import settings
from application.stt.constants import (
SUPPORTED_AUDIO_EXTENSIONS,
SUPPORTED_AUDIO_MIME_TYPES,
)
from application.stt.upload_limits import (
AudioFileTooLargeError,
build_stt_file_size_limit_message,
enforce_audio_file_size_limit,
is_audio_filename,
)
from application.stt.live_session import (
apply_live_stt_hypothesis,
create_live_stt_session,
delete_live_stt_session,
finalize_live_stt_session,
get_live_stt_transcript_text,
load_live_stt_session,
save_live_stt_session,
)
from application.stt.stt_creator import STTCreator
from application.tts.tts_creator import TTSCreator
from application.utils import safe_filename
@@ -19,6 +40,74 @@ attachments_ns = Namespace(
)
def _resolve_authenticated_user():
decoded_token = getattr(request, "decoded_token", None)
api_key = request.form.get("api_key") or request.args.get("api_key")
if decoded_token:
return safe_filename(decoded_token.get("sub"))
if api_key:
from application.api.user.base import agents_collection
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
jsonify({"success": False, "message": "Invalid API key"}), 401
)
return safe_filename(agent.get("user"))
return None
def _get_uploaded_file_size(file) -> int:
try:
current_position = file.stream.tell()
file.stream.seek(0, os.SEEK_END)
size_bytes = file.stream.tell()
file.stream.seek(current_position)
return size_bytes
except Exception:
return 0
def _is_supported_audio_mimetype(mimetype: str) -> bool:
if not mimetype:
return True
normalized = mimetype.split(";")[0].strip().lower()
return normalized.startswith("audio/") or normalized in SUPPORTED_AUDIO_MIME_TYPES
def _enforce_uploaded_audio_size_limit(file, filename: str) -> None:
if not is_audio_filename(filename):
return
size_bytes = _get_uploaded_file_size(file)
if size_bytes:
enforce_audio_file_size_limit(size_bytes)
def _get_store_attachment_user_error(exc: Exception) -> str:
if isinstance(exc, AudioFileTooLargeError):
return build_stt_file_size_limit_message()
return "Failed to process file"
def _require_live_stt_redis():
redis_client = get_redis_instance()
if redis_client:
return redis_client
return make_response(
jsonify({"success": False, "message": "Live transcription is unavailable"}),
503,
)
def _parse_bool_form_value(value: str | None) -> bool:
if value is None:
return False
return value.strip().lower() in {"1", "true", "yes", "on"}
@attachments_ns.route("/store_attachment")
class StoreAttachment(Resource):
@api.expect(
@@ -36,8 +125,9 @@ class StoreAttachment(Resource):
description="Stores one or multiple attachments without vectorization or training. Supports user or API key authentication."
)
def post(self):
decoded_token = getattr(request, "decoded_token", None)
api_key = request.form.get("api_key") or request.args.get("api_key")
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
files = request.files.getlist("file")
if not files:
@@ -51,22 +141,16 @@ class StoreAttachment(Resource):
400,
)
user = None
if decoded_token:
user = safe_filename(decoded_token.get("sub"))
elif api_key:
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
jsonify({"success": False, "message": "Invalid API key"}), 401
)
user = safe_filename(agent.get("user"))
else:
user = auth_user
if not user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}), 401
)
try:
from application.api.user.tasks import store_attachment
from application.api.user.base import storage
tasks = []
errors = []
original_file_count = len(files)
@@ -75,6 +159,7 @@ class StoreAttachment(Resource):
try:
attachment_id = ObjectId()
original_filename = safe_filename(os.path.basename(file.filename))
_enforce_uploaded_audio_size_limit(file, original_filename)
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
metadata = storage.save_file(file, relative_path)
@@ -90,20 +175,33 @@ class StoreAttachment(Resource):
"task_id": task.id,
"filename": original_filename,
"attachment_id": str(attachment_id),
"upload_index": idx,
})
except Exception as file_err:
current_app.logger.error(f"Error processing file {idx} ({file.filename}): {file_err}", exc_info=True)
errors.append({
"upload_index": idx,
"filename": file.filename,
"error": str(file_err)
"error": _get_store_attachment_user_error(file_err),
})
if not tasks:
error_msg = "No valid files to upload"
if errors:
error_msg += f". Errors: {errors}"
if errors and all(
error.get("error") == build_stt_file_size_limit_message()
for error in errors
):
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
"errors": errors,
}
),
413,
)
return make_response(
jsonify({"status": "error", "message": error_msg, "errors": errors}),
jsonify({"status": "error", "message": "No valid files to upload"}),
400,
)
@@ -135,7 +233,379 @@ class StoreAttachment(Resource):
)
except Exception as err:
current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": False, "error": "Failed to store attachment"}), 400)
@attachments_ns.route("/stt")
class SpeechToText(Resource):
@api.expect(
api.model(
"SpeechToTextModel",
{
"file": fields.Raw(required=True, description="Audio file"),
"language": fields.String(
required=False, description="Optional transcription language hint"
),
},
)
)
@api.doc(description="Transcribe an uploaded audio file")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
file = request.files.get("file")
if not file or file.filename == "":
return make_response(
jsonify({"success": False, "message": "Missing file"}),
400,
)
filename = safe_filename(os.path.basename(file.filename))
suffix = Path(filename).suffix.lower()
if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
return make_response(
jsonify({"success": False, "message": "Unsupported audio format"}),
400,
)
if not _is_supported_audio_mimetype(file.mimetype or ""):
return make_response(
jsonify({"success": False, "message": "Unsupported audio MIME type"}),
400,
)
try:
_enforce_uploaded_audio_size_limit(file, filename)
except AudioFileTooLargeError:
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
temp_path = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
file.save(temp_file.name)
temp_path = Path(temp_file.name)
stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
transcript = stt_instance.transcribe(
temp_path,
language=request.form.get("language") or settings.STT_LANGUAGE,
timestamps=settings.STT_ENABLE_TIMESTAMPS,
diarize=settings.STT_ENABLE_DIARIZATION,
)
return make_response(jsonify({"success": True, **transcript}), 200)
except Exception as err:
current_app.logger.error(f"Error transcribing audio: {err}", exc_info=True)
return make_response(
jsonify({"success": False, "message": "Failed to transcribe audio"}),
400,
)
finally:
if temp_path and temp_path.exists():
temp_path.unlink()
@attachments_ns.route("/stt/live/start")
class LiveSpeechToTextStart(Resource):
@api.doc(description="Start a live speech-to-text session")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
redis_client = _require_live_stt_redis()
if hasattr(redis_client, "status_code"):
return redis_client
payload = request.get_json(silent=True) or {}
session_state = create_live_stt_session(
user=auth_user,
language=payload.get("language") or settings.STT_LANGUAGE,
)
save_live_stt_session(redis_client, session_state)
return make_response(
jsonify(
{
"success": True,
"session_id": session_state["session_id"],
"language": session_state.get("language"),
"committed_text": "",
"mutable_text": "",
"previous_hypothesis": "",
"latest_hypothesis": "",
"finalized_text": "",
"pending_text": "",
"transcript_text": "",
}
),
200,
)
@attachments_ns.route("/stt/live/chunk")
class LiveSpeechToTextChunk(Resource):
@api.expect(
api.model(
"LiveSpeechToTextChunkModel",
{
"session_id": fields.String(
required=True, description="Live transcription session ID"
),
"chunk_index": fields.Integer(
required=True, description="Sequential chunk index"
),
"is_silence": fields.Boolean(
required=False,
description="Whether the latest capture window was mostly silence",
),
"file": fields.Raw(required=True, description="Audio chunk"),
},
)
)
@api.doc(description="Transcribe a chunk for a live speech-to-text session")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
redis_client = _require_live_stt_redis()
if hasattr(redis_client, "status_code"):
return redis_client
session_id = request.form.get("session_id", "").strip()
if not session_id:
return make_response(
jsonify({"success": False, "message": "Missing session_id"}),
400,
)
session_state = load_live_stt_session(redis_client, session_id)
if not session_state:
return make_response(
jsonify(
{
"success": False,
"message": "Live transcription session not found",
}
),
404,
)
if safe_filename(str(session_state.get("user", ""))) != auth_user:
return make_response(
jsonify({"success": False, "message": "Forbidden"}),
403,
)
chunk_index_raw = request.form.get("chunk_index", "").strip()
if chunk_index_raw == "":
return make_response(
jsonify({"success": False, "message": "Missing chunk_index"}),
400,
)
try:
chunk_index = int(chunk_index_raw)
except ValueError:
return make_response(
jsonify({"success": False, "message": "Invalid chunk_index"}),
400,
)
is_silence = _parse_bool_form_value(request.form.get("is_silence"))
file = request.files.get("file")
if not file or file.filename == "":
return make_response(
jsonify({"success": False, "message": "Missing file"}),
400,
)
filename = safe_filename(os.path.basename(file.filename))
suffix = Path(filename).suffix.lower()
if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
return make_response(
jsonify({"success": False, "message": "Unsupported audio format"}),
400,
)
if not _is_supported_audio_mimetype(file.mimetype or ""):
return make_response(
jsonify({"success": False, "message": "Unsupported audio MIME type"}),
400,
)
try:
_enforce_uploaded_audio_size_limit(file, filename)
except AudioFileTooLargeError:
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
temp_path = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
file.save(temp_file.name)
temp_path = Path(temp_file.name)
session_language = session_state.get("language") or settings.STT_LANGUAGE
stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
transcript = stt_instance.transcribe(
temp_path,
language=session_language,
timestamps=False,
diarize=False,
)
if not session_state.get("language") and transcript.get("language"):
session_state["language"] = transcript["language"]
try:
apply_live_stt_hypothesis(
session_state,
str(transcript.get("text", "")),
chunk_index,
is_silence=is_silence,
)
except ValueError:
current_app.logger.warning(
"Invalid live transcription chunk",
exc_info=True,
)
return make_response(
jsonify(
{
"success": False,
"message": "Invalid live transcription chunk",
}
),
409,
)
save_live_stt_session(redis_client, session_state)
return make_response(
jsonify(
{
"success": True,
"session_id": session_id,
"chunk_index": chunk_index,
"chunk_text": transcript.get("text", ""),
"is_silence": is_silence,
"language": session_state.get("language"),
"committed_text": session_state.get("committed_text", ""),
"mutable_text": session_state.get("mutable_text", ""),
"previous_hypothesis": session_state.get(
"previous_hypothesis", ""
),
"latest_hypothesis": session_state.get(
"latest_hypothesis", ""
),
"finalized_text": session_state.get("committed_text", ""),
"pending_text": session_state.get("mutable_text", ""),
"transcript_text": get_live_stt_transcript_text(session_state),
}
),
200,
)
except Exception as err:
current_app.logger.error(
f"Error transcribing live audio chunk: {err}", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Failed to transcribe audio"}),
400,
)
finally:
if temp_path and temp_path.exists():
temp_path.unlink()
@attachments_ns.route("/stt/live/finish")
class LiveSpeechToTextFinish(Resource):
@api.doc(description="Finish a live speech-to-text session")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
redis_client = _require_live_stt_redis()
if hasattr(redis_client, "status_code"):
return redis_client
payload = request.get_json(silent=True) or {}
session_id = str(payload.get("session_id", "")).strip()
if not session_id:
return make_response(
jsonify({"success": False, "message": "Missing session_id"}),
400,
)
session_state = load_live_stt_session(redis_client, session_id)
if not session_state:
return make_response(
jsonify(
{
"success": False,
"message": "Live transcription session not found",
}
),
404,
)
if safe_filename(str(session_state.get("user", ""))) != auth_user:
return make_response(
jsonify({"success": False, "message": "Forbidden"}),
403,
)
final_text = finalize_live_stt_session(session_state)
delete_live_stt_session(redis_client, session_id)
return make_response(
jsonify(
{
"success": True,
"session_id": session_id,
"language": session_state.get("language"),
"text": final_text,
}
),
200,
)
@attachments_ns.route("/images/<path:image_path>")
@@ -143,6 +613,8 @@ class ServeImage(Resource):
@api.doc(description="Serve an image from storage")
def get(self, image_path):
try:
from application.api.user.base import storage
file_obj = storage.get_file(image_path)
extension = image_path.split(".")[-1].lower()
content_type = f"image/{extension}"

View File

@@ -31,12 +31,17 @@ sources_collection = db["sources"]
prompts_collection = db["prompts"]
feedback_collection = db["feedback"]
agents_collection = db["agents"]
agent_folders_collection = db["agent_folders"]
token_usage_collection = db["token_usage"]
shared_conversations_collections = db["shared_conversations"]
users_collection = db["users"]
user_logs_collection = db["user_logs"]
user_tools_collection = db["user_tools"]
attachments_collection = db["attachments"]
workflow_runs_collection = db["workflow_runs"]
workflows_collection = db["workflows"]
workflow_nodes_collection = db["workflow_nodes"]
workflow_edges_collection = db["workflow_edges"]
try:
@@ -46,6 +51,25 @@ try:
background=True,
)
users_collection.create_index("user_id", unique=True)
workflows_collection.create_index(
[("user", 1)], name="workflow_user_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1)], name="node_workflow_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="node_workflow_graph_version_index",
background=True,
)
workflow_edges_collection.create_index(
[("workflow_id", 1)], name="edge_workflow_index", background=True
)
workflow_edges_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="edge_workflow_graph_version_index",
background=True,
)
except Exception as e:
print("Error creating indexes:", e)
current_dir = os.path.dirname(
@@ -121,14 +145,22 @@ def resolve_tool_details(tool_ids):
Returns:
List of tool details with id, name, and display_name
"""
valid_ids = []
for tid in tool_ids:
try:
valid_ids.append(ObjectId(tid))
except Exception:
continue
tools = user_tools_collection.find(
{"_id": {"$in": [ObjectId(tid) for tid in tool_ids]}}
)
{"_id": {"$in": valid_ids}}
) if valid_ids else []
return [
{
"id": str(tool["_id"]),
"name": tool.get("name", ""),
"display_name": tool.get("displayName", tool.get("name", "")),
"display_name": tool.get("customName")
or tool.get("displayName")
or tool.get("name", ""),
}
for tool in tools
]

View File

@@ -5,8 +5,7 @@ Main user API routes - registers all namespace modules.
from flask import Blueprint
from application.api import api
from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns
from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns, agents_folders_ns
from .analytics import analytics_ns
from .attachments import attachments_ns
from .conversations import conversations_ns
@@ -15,6 +14,7 @@ from .prompts import prompts_ns
from .sharing import sharing_ns
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
from .tools import tools_mcp_ns, tools_ns
from .workflows import workflows_ns
user = Blueprint("user", __name__)
@@ -31,10 +31,11 @@ api.add_namespace(conversations_ns)
# Models
api.add_namespace(models_ns)
# Agents (main, sharing, webhooks)
# Agents (main, sharing, webhooks, folders)
api.add_namespace(agents_ns)
api.add_namespace(agents_sharing_ns)
api.add_namespace(agents_webhooks_ns)
api.add_namespace(agents_folders_ns)
# Prompts
api.add_namespace(prompts_ns)
@@ -50,3 +51,6 @@ api.add_namespace(sources_upload_ns)
# Tools (main, MCP)
api.add_namespace(tools_ns)
api.add_namespace(tools_mcp_ns)
# Workflows
api.add_namespace(workflows_ns)

View File

@@ -220,8 +220,23 @@ class GetPubliclySharedConversations(Resource):
shared
and "conversation_id" in shared
):
# conversation_id is now stored as an ObjectId, not a DBRef
# Handle DBRef (legacy), ObjectId, dict, and string formats for conversation_id
conversation_id = shared["conversation_id"]
if isinstance(conversation_id, DBRef):
conversation_id = conversation_id.id
elif isinstance(conversation_id, dict):
# Handle dict representation of DBRef (e.g., {"$ref": "...", "$id": "..."})
if "$id" in conversation_id:
conv_id = conversation_id["$id"]
# $id might be a dict like {"$oid": "..."} or a string
if isinstance(conv_id, dict) and "$oid" in conv_id:
conversation_id = ObjectId(conv_id["$oid"])
else:
conversation_id = ObjectId(conv_id)
elif "_id" in conversation_id:
conversation_id = ObjectId(conversation_id["_id"])
elif isinstance(conversation_id, str):
conversation_id = ObjectId(conversation_id)
conversation = conversations_collection.find_one(
{"_id": conversation_id}
)

View File

@@ -55,9 +55,14 @@ class GetChunks(Resource):
if path:
chunk_source = metadata.get("source", "")
# Check if the chunk's source matches the requested path
chunk_file_path = metadata.get("file_path", "")
# Check if the chunk matches the requested path
# For file uploads: source ends with path (e.g., "inputs/.../file.pdf" ends with "file.pdf")
# For crawlers: file_path ends with path (e.g., "guides/setup.md" ends with "setup.md")
source_match = chunk_source and chunk_source.endswith(path)
file_path_match = chunk_file_path and chunk_file_path.endswith(path)
if not chunk_source or not chunk_source.endswith(path):
if not (source_match or file_path_match):
continue
# Filter by search term if provided

View File

@@ -9,6 +9,7 @@ from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import sources_collection
from application.api.user.tasks import sync_source
from application.core.settings import settings
from application.storage.storage_creator import StorageCreator
from application.utils import check_required_fields
@@ -20,6 +21,21 @@ sources_ns = Namespace(
)
def _get_provider_from_remote_data(remote_data):
if not remote_data:
return None
if isinstance(remote_data, dict):
return remote_data.get("provider")
if isinstance(remote_data, str):
try:
remote_data_obj = json.loads(remote_data)
except Exception:
return None
if isinstance(remote_data_obj, dict):
return remote_data_obj.get("provider")
return None
@sources_ns.route("/sources")
class CombinedJson(Resource):
@api.doc(description="Provide JSON file with combined available indexes")
@@ -41,6 +57,7 @@ class CombinedJson(Resource):
try:
for index in sources_collection.find({"user": user}).sort("date", -1):
provider = _get_provider_from_remote_data(index.get("remote_data"))
data.append(
{
"id": str(index["_id"]),
@@ -51,6 +68,7 @@ class CombinedJson(Resource):
"tokens": index.get("tokens", ""),
"retriever": index.get("retriever", "classic"),
"syncFrequency": index.get("sync_frequency", ""),
"provider": provider,
"is_nested": bool(index.get("directory_structure")),
"type": index.get(
"type", "file"
@@ -107,6 +125,7 @@ class PaginatedSources(Resource):
paginated_docs = []
for doc in documents:
provider = _get_provider_from_remote_data(doc.get("remote_data"))
doc_data = {
"id": str(doc["_id"]),
"name": doc.get("name", ""),
@@ -116,6 +135,7 @@ class PaginatedSources(Resource):
"tokens": doc.get("tokens", ""),
"retriever": doc.get("retriever", "classic"),
"syncFrequency": doc.get("sync_frequency", ""),
"provider": provider,
"isNested": bool(doc.get("directory_structure")),
"type": doc.get("type", "file"),
}
@@ -240,7 +260,7 @@ class ManageSync(Resource):
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
data = request.get_json() or {}
required_fields = ["source_id", "sync_frequency"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
@@ -269,6 +289,72 @@ class ManageSync(Resource):
return make_response(jsonify({"success": True}), 200)
@sources_ns.route("/sync_source")
class SyncSource(Resource):
sync_source_model = api.model(
"SyncSourceModel",
{"source_id": fields.String(required=True, description="Source ID")},
)
@api.expect(sync_source_model)
@api.doc(description="Trigger an immediate sync for a source")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["source_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
source_id = data["source_id"]
if not ObjectId.is_valid(source_id):
return make_response(
jsonify({"success": False, "message": "Invalid source ID"}), 400
)
doc = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": user}
)
if not doc:
return make_response(
jsonify({"success": False, "message": "Source not found"}), 404
)
source_type = doc.get("type", "")
if source_type.startswith("connector"):
return make_response(
jsonify(
{
"success": False,
"message": "Connector sources must be synced via /api/connectors/sync",
}
),
400,
)
source_data = doc.get("remote_data")
if not source_data:
return make_response(
jsonify({"success": False, "message": "Source is not syncable"}), 400
)
try:
task = sync_source.delay(
source_data=source_data,
job_name=doc.get("name", ""),
user=user,
loader=source_type,
sync_frequency=doc.get("sync_frequency", "never"),
retriever=doc.get("retriever", "classic"),
doc_id=source_id,
)
except Exception as err:
current_app.logger.error(
f"Error starting sync for source {source_id}: {err}",
exc_info=True,
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_ns.route("/directory_structure")
class DirectoryStructure(Resource):
@api.doc(
@@ -320,4 +406,4 @@ class DirectoryStructure(Resource):
current_app.logger.error(
f"Error retrieving directory structure: {e}", exc_info=True
)
return make_response(jsonify({"success": False, "error": str(e)}), 500)
return make_response(jsonify({"success": False, "error": "Failed to retrieve directory structure"}), 500)

View File

@@ -14,7 +14,14 @@ from application.api.user.base import sources_collection
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
from application.core.settings import settings
from application.parser.connectors.connector_creator import ConnectorCreator
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
from application.storage.storage_creator import StorageCreator
from application.stt.upload_limits import (
AudioFileTooLargeError,
build_stt_file_size_limit_message,
enforce_audio_file_size_limit,
is_audio_filename,
)
from application.utils import check_required_fields, safe_filename
@@ -23,6 +30,12 @@ sources_upload_ns = Namespace(
)
def _enforce_audio_path_size_limit(file_path: str, filename: str) -> None:
if not is_audio_filename(filename):
return
enforce_audio_file_size_limit(os.path.getsize(file_path))
@sources_upload_ns.route("/upload")
class UploadFile(Resource):
@api.expect(
@@ -64,19 +77,28 @@ class UploadFile(Resource):
safe_user = safe_filename(user)
dir_name = safe_filename(job_name)
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
file_name_map = {}
try:
storage = StorageCreator.get_storage()
for file in files:
original_filename = file.filename
original_filename = os.path.basename(file.filename)
safe_file = safe_filename(original_filename)
if original_filename:
file_name_map[safe_file] = original_filename
with tempfile.TemporaryDirectory() as temp_dir:
temp_file_path = os.path.join(temp_dir, safe_file)
file.save(temp_file_path)
_enforce_audio_path_size_limit(temp_file_path, safe_file)
if zipfile.is_zipfile(temp_file_path):
# Only extract actual .zip files, not Office formats (.docx, .xlsx, .pptx)
# which are technically zip archives but should be processed as-is
is_office_format = safe_file.lower().endswith(
(".docx", ".xlsx", ".pptx", ".odt", ".ods", ".odp", ".epub")
)
if zipfile.is_zipfile(temp_file_path) and not is_office_format:
try:
with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
zip_ref.extractall(path=temp_dir)
@@ -94,6 +116,10 @@ class UploadFile(Resource):
os.path.join(root, extracted_file), temp_dir
)
storage_path = f"{base_path}/{rel_path}"
_enforce_audio_path_size_limit(
os.path.join(root, extracted_file),
extracted_file,
)
with open(
os.path.join(root, extracted_file), "rb"
@@ -116,27 +142,22 @@ class UploadFile(Resource):
storage.save_file(f, file_path)
task = ingest.delay(
settings.UPLOAD_FOLDER,
[
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
list(SUPPORTED_SOURCE_EXTENSIONS),
job_name,
user,
file_path=base_path,
filename=dir_name,
file_name_map=file_name_map,
)
except AudioFileTooLargeError:
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
except Exception as err:
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
@@ -182,6 +203,8 @@ class UploadRemote(Resource):
source_data = config.get("url")
elif data["source"] == "reddit":
source_data = config
elif data["source"] == "s3":
source_data = config
elif data["source"] in ConnectorCreator.get_supported_connectors():
session_token = config.get("session_token")
if not session_token:
@@ -334,6 +357,14 @@ class ManageSourceFiles(Resource):
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
parent_dir = request.form.get("parent_dir", "")
file_name_map = source.get("file_name_map") or {}
if isinstance(file_name_map, str):
try:
file_name_map = json.loads(file_name_map)
except Exception:
file_name_map = {}
if not isinstance(file_name_map, dict):
file_name_map = {}
if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
return make_response(
@@ -355,19 +386,35 @@ class ManageSourceFiles(Resource):
400,
)
added_files = []
map_updated = False
target_dir = source_file_path
if parent_dir:
target_dir = f"{source_file_path}/{parent_dir}"
for file in files:
if file.filename:
safe_filename_str = safe_filename(file.filename)
original_filename = os.path.basename(file.filename)
safe_filename_str = safe_filename(original_filename)
file_path = f"{target_dir}/{safe_filename_str}"
# Save file to storage
storage.save_file(file, file_path)
added_files.append(safe_filename_str)
if original_filename:
relative_key = (
f"{parent_dir}/{safe_filename_str}"
if parent_dir
else safe_filename_str
)
file_name_map[relative_key] = original_filename
map_updated = True
if map_updated:
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
@@ -414,6 +461,7 @@ class ManageSourceFiles(Resource):
# Remove files from storage and directory structure
removed_files = []
map_updated = False
for file_path in file_paths:
full_path = f"{source_file_path}/{file_path}"
@@ -422,6 +470,15 @@ class ManageSourceFiles(Resource):
if storage.file_exists(full_path):
storage.delete_file(full_path)
removed_files.append(file_path)
if file_path in file_name_map:
file_name_map.pop(file_path, None)
map_updated = True
if map_updated and isinstance(file_name_map, dict):
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
@@ -504,6 +561,20 @@ class ManageSourceFiles(Resource):
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
f"Full path: {full_directory_path}"
)
if directory_path and file_name_map:
prefix = f"{directory_path.rstrip('/')}/"
keys_to_remove = [
key
for key in file_name_map.keys()
if key == directory_path or key.startswith(prefix)
]
if keys_to_remove:
for key in keys_to_remove:
file_name_map.pop(key, None)
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{"$set": {"file_name_map": file_name_map}},
)
# Trigger re-ingestion pipeline
@@ -574,8 +645,9 @@ class TaskStatus(Resource):
):
task_meta = str(task_meta) # Convert to a string representation
except ConnectionError as err:
current_app.logger.error(f"Connection error getting task status: {err}")
return make_response(
jsonify({"success": False, "message": str(err)}), 503
jsonify({"success": False, "message": "Service unavailable"}), 503
)
except Exception as err:
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)

View File

@@ -8,13 +8,25 @@ from application.worker import (
mcp_oauth,
mcp_oauth_status,
remote_worker,
sync,
sync_worker,
)
@celery.task(bind=True)
def ingest(self, directory, formats, job_name, user, file_path, filename):
resp = ingest_worker(self, directory, formats, job_name, file_path, filename, user)
def ingest(
self, directory, formats, job_name, user, file_path, filename, file_name_map=None
):
resp = ingest_worker(
self,
directory,
formats,
job_name,
file_path,
filename,
user,
file_name_map=file_name_map,
)
return resp
@@ -38,6 +50,30 @@ def schedule_syncs(self, frequency):
return resp
@celery.task(bind=True)
def sync_source(
self,
source_data,
job_name,
user,
loader,
sync_frequency,
retriever,
doc_id,
):
resp = sync(
self,
source_data,
job_name,
user,
loader,
sync_frequency,
retriever,
doc_id,
)
return resp
@celery.task(bind=True)
def store_attachment(self, file_info, user):
resp = attachment_worker(self, file_info, user)

View File

@@ -1,21 +1,67 @@
"""Tool management MCP server integration."""
import json
from email.quoprimime import unquote
from urllib.parse import urlencode, urlparse
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, redirect, request
from flask_restx import fields, Namespace, Resource
from flask_restx import Namespace, Resource, fields
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
from application.api import api
from application.api.user.base import user_tools_collection
from application.api.user.tools.routes import transform_actions
from application.cache import get_redis_instance
from application.security.encryption import encrypt_credentials
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.utils import check_required_fields
tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")
_mongo = MongoDB.get_client()
_db = _mongo[settings.MONGO_DB_NAME]
_connector_sessions = _db["connector_sessions"]
_ALLOWED_TRANSPORTS = {"auto", "sse", "http"}
def _sanitize_mcp_transport(config):
"""Normalise and validate the transport_type field.
Strips ``command`` / ``args`` keys that are only valid for local STDIO
transports and returns the cleaned transport type string.
"""
transport_type = (config.get("transport_type") or "auto").lower()
if transport_type not in _ALLOWED_TRANSPORTS:
raise ValueError(f"Unsupported transport_type: {transport_type}")
config.pop("command", None)
config.pop("args", None)
config["transport_type"] = transport_type
return transport_type
def _extract_auth_credentials(config):
"""Build an ``auth_credentials`` dict from the raw MCP config."""
auth_credentials = {}
auth_type = config.get("auth_type", "none")
if auth_type == "api_key":
if config.get("api_key"):
auth_credentials["api_key"] = config["api_key"]
if config.get("api_key_header"):
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer":
if config.get("bearer_token"):
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if config.get("username"):
auth_credentials["username"] = config["username"]
if config.get("password"):
auth_credentials["password"] = config["password"]
return auth_credentials
@tools_mcp_ns.route("/mcp_server/test")
class TestMCPServerConfig(Resource):
@@ -43,34 +89,35 @@ class TestMCPServerConfig(Resource):
return missing_fields
try:
config = data["config"]
try:
_sanitize_mcp_transport(config)
except ValueError:
return make_response(
jsonify({"success": False, "error": "Unsupported transport_type"}),
400,
)
auth_credentials = {}
auth_type = config.get("auth_type", "none")
if auth_type == "api_key" and "api_key" in config:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer" and "bearer_token" in config:
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if "username" in config:
auth_credentials["username"] = config["username"]
if "password" in config:
auth_credentials["password"] = config["password"]
auth_credentials = _extract_auth_credentials(config)
test_config = config.copy()
test_config["auth_credentials"] = auth_credentials
mcp_tool = MCPTool(config=test_config, user_id=user)
result = mcp_tool.test_connection()
if result.get("requires_oauth"):
return make_response(jsonify(result), 200)
if not result.get("success") and "message" in result:
current_app.logger.error(
f"MCP connection test failed: {result.get('message')}"
)
result["message"] = "Connection test failed"
return make_response(jsonify(result), 200)
except Exception as e:
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
return make_response(
jsonify(
{"success": False, "error": f"Connection test failed: {str(e)}"}
),
jsonify({"success": False, "error": "Connection test failed"}),
500,
)
@@ -110,22 +157,16 @@ class MCPServerSave(Resource):
return missing_fields
try:
config = data["config"]
try:
_sanitize_mcp_transport(config)
except ValueError:
return make_response(
jsonify({"success": False, "error": "Unsupported transport_type"}),
400,
)
auth_credentials = {}
auth_credentials = _extract_auth_credentials(config)
auth_type = config.get("auth_type", "none")
if auth_type == "api_key":
if "api_key" in config and config["api_key"]:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer":
if "bearer_token" in config and config["bearer_token"]:
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if "username" in config and config["username"]:
auth_credentials["username"] = config["username"]
if "password" in config and config["password"]:
auth_credentials["password"] = config["password"]
mcp_config = config.copy()
mcp_config["auth_credentials"] = auth_credentials
@@ -163,30 +204,39 @@ class MCPServerSave(Resource):
"No valid credentials provided for the selected authentication type"
)
storage_config = config.copy()
tool_id = data.get("id")
existing_encrypted = None
if tool_id:
existing_doc = user_tools_collection.find_one(
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"}
)
if existing_doc:
existing_encrypted = existing_doc.get("config", {}).get(
"encrypted_credentials"
)
if auth_credentials:
encrypted_credentials_string = encrypt_credentials(
if existing_encrypted:
existing_secrets = decrypt_credentials(existing_encrypted, user)
existing_secrets.update(auth_credentials)
auth_credentials = existing_secrets
storage_config["encrypted_credentials"] = encrypt_credentials(
auth_credentials, user
)
storage_config["encrypted_credentials"] = encrypted_credentials_string
elif existing_encrypted:
storage_config["encrypted_credentials"] = existing_encrypted
for field in [
"api_key",
"bearer_token",
"username",
"password",
"api_key_header",
"redirect_uri",
]:
storage_config.pop(field, None)
transformed_actions = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
transformed_actions = transform_actions(actions_metadata)
tool_data = {
"name": "mcp_tool",
"displayName": data["displayName"],
@@ -198,7 +248,6 @@ class MCPServerSave(Resource):
"user": user,
}
tool_id = data.get("id")
if tool_id:
result = user_tools_collection.update_one(
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
@@ -233,9 +282,7 @@ class MCPServerSave(Resource):
except Exception as e:
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
return make_response(
jsonify(
{"success": False, "error": f"Failed to save MCP server: {str(e)}"}
),
jsonify({"success": False, "error": "Failed to save MCP server"}),
500,
)
@@ -263,9 +310,12 @@ class MCPOAuthCallback(Resource):
error = request.args.get("error")
if error:
return redirect(
f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
)
params = {
"status": "error",
"message": f"OAuth error: {error}. Please try again and make sure to grant all requested permissions, including offline access.",
"provider": "mcp_tool",
}
return redirect(f"/api/connectors/callback-status?{urlencode(params)}")
if not code or not state:
return redirect(
"/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
@@ -276,7 +326,6 @@ class MCPOAuthCallback(Resource):
return redirect(
"/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
)
code = unquote(code)
manager = MCPOAuthManager(redis_client)
success = manager.handle_oauth_callback(state, code, error)
if success:
@@ -292,17 +341,13 @@ class MCPOAuthCallback(Resource):
f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
)
return redirect(
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
"/api/connectors/callback-status?status=error&message=Internal+server+error.&provider=mcp_tool"
)
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
def get(self, task_id):
"""
Get current status of OAuth flow.
Frontend should poll this endpoint periodically.
"""
try:
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
@@ -310,6 +355,14 @@ class MCPOAuthStatus(Resource):
if status_data:
status = json.loads(status_data)
if "tools" in status and isinstance(status["tools"], list):
status["tools"] = [
{
"name": t.get("name", "unknown"),
"description": t.get("description", ""),
}
for t in status["tools"]
]
return make_response(
jsonify({"success": True, "task_id": task_id, **status})
)
@@ -317,17 +370,93 @@ class MCPOAuthStatus(Resource):
return make_response(
jsonify(
{
"success": False,
"error": "Task not found or expired",
"success": True,
"task_id": task_id,
"status": "pending",
"message": "Waiting for OAuth to start...",
}
),
404,
200,
)
except Exception as e:
current_app.logger.error(
f"Error getting OAuth status for task {task_id}: {str(e)}"
f"Error getting OAuth status for task {task_id}: {str(e)}",
exc_info=True,
)
return make_response(
jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
jsonify(
{
"success": False,
"error": "Failed to get OAuth status",
"task_id": task_id,
}
),
500,
)
@tools_mcp_ns.route("/mcp_server/auth_status")
class MCPAuthStatus(Resource):
@api.doc(
description="Batch check auth status for all MCP tools. "
"Lightweight DB-only check — no network calls to MCP servers."
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
mcp_tools = list(
user_tools_collection.find(
{"user": user, "name": "mcp_tool"},
{"_id": 1, "config": 1},
)
)
if not mcp_tools:
return make_response(jsonify({"success": True, "statuses": {}}), 200)
oauth_server_urls = {}
statuses = {}
for tool in mcp_tools:
tool_id = str(tool["_id"])
config = tool.get("config", {})
auth_type = config.get("auth_type", "none")
if auth_type == "oauth":
server_url = config.get("server_url", "")
if server_url:
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
oauth_server_urls[tool_id] = base_url
else:
statuses[tool_id] = "needs_auth"
else:
statuses[tool_id] = "configured"
if oauth_server_urls:
unique_urls = list(set(oauth_server_urls.values()))
sessions = list(
_connector_sessions.find(
{"user_id": user, "server_url": {"$in": unique_urls}},
{"server_url": 1, "tokens": 1},
)
)
url_has_tokens = {
doc["server_url"]: bool(doc.get("tokens", {}).get("access_token"))
for doc in sessions
}
for tool_id, base_url in oauth_server_urls.items():
if url_has_tokens.get(base_url):
statuses[tool_id] = "connected"
else:
statuses[tool_id] = "needs_auth"
return make_response(jsonify({"success": True, "statuses": statuses}), 200)
except Exception as e:
current_app.logger.error(
"Error checking MCP auth status: %s", e, exc_info=True
)
return make_response(
jsonify({"success": False, "error": "Failed to check auth status"}),
500,
)

View File

@@ -4,6 +4,7 @@ from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.agents.tools.spec_parser import parse_spec
from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.base import user_tools_collection
@@ -14,6 +15,114 @@ tool_config = {}
tool_manager = ToolManager(config=tool_config)
def _encrypt_secret_fields(config, config_requirements, user_id):
secret_keys = [
key for key, spec in config_requirements.items()
if spec.get("secret") and key in config and config[key]
]
if not secret_keys:
return config
storage_config = config.copy()
secret_values = {k: config[k] for k in secret_keys}
storage_config["encrypted_credentials"] = encrypt_credentials(secret_values, user_id)
for key in secret_keys:
storage_config.pop(key, None)
return storage_config
def _validate_config(config, config_requirements, has_existing_secrets=False):
errors = {}
for key, spec in config_requirements.items():
depends_on = spec.get("depends_on")
if depends_on:
if not all(config.get(dk) == dv for dk, dv in depends_on.items()):
continue
if spec.get("required") and not config.get(key):
if has_existing_secrets and spec.get("secret"):
continue
errors[key] = f"{spec.get('label', key)} is required"
value = config.get(key)
if value is not None and value != "":
if spec.get("type") == "number":
try:
num = float(value)
if key == "timeout" and (num < 1 or num > 300):
errors[key] = "Timeout must be between 1 and 300"
except (ValueError, TypeError):
errors[key] = f"{spec.get('label', key)} must be a number"
if spec.get("enum") and value not in spec["enum"]:
errors[key] = f"Invalid value for {spec.get('label', key)}"
return errors
def _merge_secrets_on_update(new_config, existing_config, config_requirements, user_id):
"""Merge incoming config with existing encrypted secrets and re-encrypt.
For updates, the client may omit unchanged secret values. This helper
decrypts any previously stored secrets, overlays whatever the client *did*
send, strips plain-text secrets from the stored config, and re-encrypts
the merged result.
Returns the final ``config`` dict ready for persistence.
"""
secret_keys = [
key for key, spec in config_requirements.items()
if spec.get("secret")
]
if not secret_keys:
return new_config
existing_secrets = {}
if "encrypted_credentials" in existing_config:
existing_secrets = decrypt_credentials(
existing_config["encrypted_credentials"], user_id
)
merged_secrets = existing_secrets.copy()
for key in secret_keys:
if key in new_config and new_config[key]:
merged_secrets[key] = new_config[key]
# Start from existing non-secret values, then overlay incoming non-secrets
storage_config = {
k: v for k, v in existing_config.items()
if k not in secret_keys and k != "encrypted_credentials"
}
storage_config.update(
{k: v for k, v in new_config.items() if k not in secret_keys}
)
if merged_secrets:
storage_config["encrypted_credentials"] = encrypt_credentials(
merged_secrets, user_id
)
else:
storage_config.pop("encrypted_credentials", None)
storage_config.pop("has_encrypted_credentials", None)
return storage_config
def transform_actions(actions_metadata):
"""Set default flags on action metadata for storage.
Marks each action as active, sets ``filled_by_llm`` and ``value`` on every
parameter property. Used by both the generic create_tool and MCP save routes.
"""
transformed = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
props = action["parameters"].get("properties", {})
for param_details in props.values():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed.append(action)
return transformed
tools_ns = Namespace("tools", description="Tool management operations", path="/api")
@@ -28,12 +137,15 @@ class AvailableTools(Resource):
lines = doc.split("\n", 1)
name = lines[0].strip()
description = lines[1].strip() if len(lines) > 1 else ""
config_req = tool_instance.get_config_requirements()
actions = tool_instance.get_actions_metadata()
tools_metadata.append(
{
"name": tool_name,
"displayName": name,
"description": description,
"configRequirements": tool_instance.get_config_requirements(),
"configRequirements": config_req,
"actions": actions,
}
)
except Exception as err:
@@ -59,6 +171,21 @@ class GetTools(Resource):
tool_copy = {**tool}
tool_copy["id"] = str(tool["_id"])
tool_copy.pop("_id", None)
config_req = tool_copy.get("configRequirements", {})
if not config_req:
tool_instance = tool_manager.tools.get(tool_copy.get("name"))
if tool_instance:
config_req = tool_instance.get_config_requirements()
tool_copy["configRequirements"] = config_req
has_secrets = any(
spec.get("secret") for spec in config_req.values()
) if config_req else False
if has_secrets and "encrypted_credentials" in tool_copy.get("config", {}):
tool_copy["config"]["has_encrypted_credentials"] = True
tool_copy["config"].pop("encrypted_credentials", None)
user_tools.append(tool_copy)
except Exception as err:
current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
@@ -115,23 +242,32 @@ class CreateTool(Resource):
jsonify({"success": False, "message": "Tool not found"}), 404
)
actions_metadata = tool_instance.get_actions_metadata()
transformed_actions = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
transformed_actions = transform_actions(actions_metadata)
except Exception as err:
current_app.logger.error(
f"Error getting tool actions: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
try:
config_requirements = tool_instance.get_config_requirements()
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements
)
if validation_errors:
return make_response(
jsonify(
{
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}
),
400,
)
storage_config = _encrypt_secret_fields(
data["config"], config_requirements, user
)
new_tool = {
"user": user,
"name": data["name"],
@@ -139,7 +275,8 @@ class CreateTool(Resource):
"description": data["description"],
"customName": data.get("customName", ""),
"actions": transformed_actions,
"config": data["config"],
"config": storage_config,
"configRequirements": config_requirements,
"status": data["status"],
}
resp = user_tools_collection.insert_one(new_tool)
@@ -209,57 +346,37 @@ class UpdateTool(Resource):
tool_doc = user_tools_collection.find_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if tool_doc and tool_doc.get("name") == "mcp_tool":
config = data["config"]
existing_config = tool_doc.get("config", {})
storage_config = existing_config.copy()
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}),
404,
)
tool_name = tool_doc.get("name", data.get("name"))
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements() if tool_instance else {}
)
existing_config = tool_doc.get("config", {})
has_existing_secrets = "encrypted_credentials" in existing_config
storage_config.update(config)
existing_credentials = {}
if "encrypted_credentials" in existing_config:
existing_credentials = decrypt_credentials(
existing_config["encrypted_credentials"], user
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements,
has_existing_secrets=has_existing_secrets,
)
if validation_errors:
return make_response(
jsonify({
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}),
400,
)
auth_credentials = existing_credentials.copy()
auth_type = storage_config.get("auth_type", "none")
if auth_type == "api_key":
if "api_key" in config and config["api_key"]:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config[
"api_key_header"
]
elif auth_type == "bearer":
if "bearer_token" in config and config["bearer_token"]:
auth_credentials["bearer_token"] = config["bearer_token"]
elif "encrypted_token" in config and config["encrypted_token"]:
auth_credentials["bearer_token"] = config["encrypted_token"]
elif auth_type == "basic":
if "username" in config and config["username"]:
auth_credentials["username"] = config["username"]
if "password" in config and config["password"]:
auth_credentials["password"] = config["password"]
if auth_type != "none" and auth_credentials:
encrypted_credentials_string = encrypt_credentials(
auth_credentials, user
)
storage_config["encrypted_credentials"] = (
encrypted_credentials_string
)
elif auth_type == "none":
storage_config.pop("encrypted_credentials", None)
for field in [
"api_key",
"bearer_token",
"encrypted_token",
"username",
"password",
"api_key_header",
]:
storage_config.pop(field, None)
update_data["config"] = storage_config
else:
update_data["config"] = data["config"]
update_data["config"] = _merge_secrets_on_update(
data["config"], existing_config, config_requirements, user
)
if "status" in data:
update_data["status"] = data["status"]
user_tools_collection.update_one(
@@ -297,9 +414,42 @@ class UpdateToolConfig(Resource):
if missing_fields:
return missing_fields
try:
tool_doc = user_tools_collection.find_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if not tool_doc:
return make_response(jsonify({"success": False}), 404)
tool_name = tool_doc.get("name")
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements() if tool_instance else {}
)
existing_config = tool_doc.get("config", {})
has_existing_secrets = "encrypted_credentials" in existing_config
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements,
has_existing_secrets=has_existing_secrets,
)
if validation_errors:
return make_response(
jsonify({
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}),
400,
)
final_config = _merge_secrets_on_update(
data["config"], existing_config, config_requirements, user
)
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"config": data["config"]}},
{"$set": {"config": final_config}},
)
except Exception as err:
current_app.logger.error(
@@ -409,8 +559,142 @@ class DeleteTool(Resource):
{"_id": ObjectId(data["id"]), "user": user}
)
if result.deleted_count == 0:
return {"success": False, "message": "Tool not found"}, 404
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404
)
except Exception as err:
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
return {"success": False}, 400
return {"success": True}, 200
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@tools_ns.route("/parse_spec")
class ParseSpec(Resource):
@api.doc(
description="Parse an API specification (OpenAPI 3.x or Swagger 2.0) and return actions"
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
if "file" in request.files:
file = request.files["file"]
if not file.filename:
return make_response(
jsonify({"success": False, "message": "No file selected"}), 400
)
try:
spec_content = file.read().decode("utf-8")
except UnicodeDecodeError:
return make_response(
jsonify({"success": False, "message": "Invalid file encoding"}), 400
)
elif request.is_json:
data = request.get_json()
spec_content = data.get("spec_content", "")
else:
return make_response(
jsonify({"success": False, "message": "No spec provided"}), 400
)
if not spec_content or not spec_content.strip():
return make_response(
jsonify({"success": False, "message": "Empty spec content"}), 400
)
try:
metadata, actions = parse_spec(spec_content)
return make_response(
jsonify(
{
"success": True,
"metadata": metadata,
"actions": actions,
}
),
200,
)
except ValueError as e:
current_app.logger.error(f"Spec validation error: {e}")
return make_response(jsonify({"success": False, "error": "Invalid specification format"}), 400)
except Exception as err:
current_app.logger.error(f"Error parsing spec: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to parse specification"}), 500)
@tools_ns.route("/artifact/<artifact_id>")
class GetArtifact(Resource):
@api.doc(description="Get artifact data by artifact ID. Returns all todos for the tool when fetching a todo artifact.")
def get(self, artifact_id: str):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
try:
obj_id = ObjectId(artifact_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid artifact ID"}), 400
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
note_doc = db["notes"].find_one({"_id": obj_id, "user_id": user_id})
if note_doc:
content = note_doc.get("note", "")
line_count = len(content.split("\n")) if content else 0
artifact = {
"artifact_type": "note",
"data": {
"content": content,
"line_count": line_count,
"updated_at": (
note_doc["updated_at"].isoformat()
if note_doc.get("updated_at")
else None
),
},
}
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
todo_doc = db["todos"].find_one({"_id": obj_id, "user_id": user_id})
if todo_doc:
tool_id = todo_doc.get("tool_id")
query = {"user_id": user_id, "tool_id": tool_id}
all_todos = list(db["todos"].find(query))
items = []
open_count = 0
completed_count = 0
for t in all_todos:
status = t.get("status", "open")
if status == "open":
open_count += 1
elif status == "completed":
completed_count += 1
items.append({
"todo_id": t.get("todo_id"),
"title": t.get("title", ""),
"status": status,
"created_at": (
t["created_at"].isoformat() if t.get("created_at") else None
),
"updated_at": (
t["updated_at"].isoformat() if t.get("updated_at") else None
),
})
artifact = {
"artifact_type": "todo_list",
"data": {
"items": items,
"total_count": len(items),
"open_count": open_count,
"completed_count": completed_count,
},
}
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
return make_response(
jsonify({"success": False, "message": "Artifact not found"}), 404
)

View File

@@ -0,0 +1,387 @@
"""Centralized utilities for API routes."""
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
from bson.errors import InvalidId
from bson.objectid import ObjectId
from flask import (
Response,
current_app,
has_app_context,
jsonify,
make_response,
request,
)
from pymongo.collection import Collection
def get_user_id() -> Optional[str]:
"""
Extract user ID from decoded JWT token.
Returns:
User ID string or None if not authenticated
"""
decoded_token = getattr(request, "decoded_token", None)
return decoded_token.get("sub") if decoded_token else None
def require_auth(func: Callable) -> Callable:
"""
Decorator to require authentication for route handlers.
Usage:
@require_auth
def get(self):
user_id = get_user_id()
...
"""
@wraps(func)
def wrapper(*args, **kwargs):
user_id = get_user_id()
if not user_id:
return error_response("Unauthorized", 401)
return func(*args, **kwargs)
return wrapper
def success_response(
data: Optional[Dict[str, Any]] = None, status: int = 200
) -> Response:
"""
Create a standardized success response.
Args:
data: Optional data dictionary to include in response
status: HTTP status code (default: 200)
Returns:
Flask Response object
Example:
return success_response({"users": [...], "total": 10})
"""
response = {"success": True}
if data:
response.update(data)
return make_response(jsonify(response), status)
def error_response(message: str, status: int = 400, **kwargs) -> Response:
"""
Create a standardized error response.
Args:
message: Error message string
status: HTTP status code (default: 400)
**kwargs: Additional fields to include in response
Returns:
Flask Response object
Example:
return error_response("Resource not found", 404)
return error_response("Invalid input", 400, errors=["field1", "field2"])
"""
response = {"success": False, "message": message}
response.update(kwargs)
return make_response(jsonify(response), status)
def validate_object_id(
id_string: str, resource_name: str = "Resource"
) -> Tuple[Optional[ObjectId], Optional[Response]]:
"""
Validate and convert string to ObjectId.
Args:
id_string: String to convert
resource_name: Name of resource for error message
Returns:
Tuple of (ObjectId or None, error_response or None)
Example:
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
"""
try:
return ObjectId(id_string), None
except (InvalidId, TypeError):
return None, error_response(f"Invalid {resource_name} ID format")
def validate_pagination(
default_limit: int = 20, max_limit: int = 100
) -> Tuple[int, int, Optional[Response]]:
"""
Extract and validate pagination parameters from request.
Args:
default_limit: Default items per page
max_limit: Maximum allowed items per page
Returns:
Tuple of (limit, skip, error_response or None)
Example:
limit, skip, error = validate_pagination()
if error:
return error
"""
try:
limit = min(int(request.args.get("limit", default_limit)), max_limit)
skip = int(request.args.get("skip", 0))
if limit < 1 or skip < 0:
return 0, 0, error_response("Invalid pagination parameters")
return limit, skip, None
except ValueError:
return 0, 0, error_response("Invalid pagination parameters")
def check_resource_ownership(
collection: Collection,
resource_id: ObjectId,
user_id: str,
resource_name: str = "Resource",
) -> Tuple[Optional[Dict], Optional[Response]]:
"""
Check if resource exists and belongs to user.
Args:
collection: MongoDB collection
resource_id: Resource ObjectId
user_id: User ID string
resource_name: Name of resource for error messages
Returns:
Tuple of (resource_dict or None, error_response or None)
Example:
workflow, error = check_resource_ownership(
workflows_collection,
workflow_id,
user_id,
"Workflow"
)
if error:
return error
"""
resource = collection.find_one({"_id": resource_id, "user": user_id})
if not resource:
return None, error_response(f"{resource_name} not found", 404)
return resource, None
def serialize_object_id(
obj: Dict[str, Any], id_field: str = "_id", new_field: str = "id"
) -> Dict[str, Any]:
"""
Convert ObjectId to string in a dictionary.
Args:
obj: Dictionary containing ObjectId
id_field: Field name containing ObjectId
new_field: New field name for string ID
Returns:
Modified dictionary
Example:
user = serialize_object_id(user_doc)
# user["id"] = "507f1f77bcf86cd799439011"
"""
if id_field in obj:
obj[new_field] = str(obj[id_field])
if id_field != new_field:
obj.pop(id_field, None)
return obj
def serialize_list(items: List[Dict], serializer: Callable[[Dict], Dict]) -> List[Dict]:
"""
Apply serializer function to list of items.
Args:
items: List of dictionaries
serializer: Function to apply to each item
Returns:
List of serialized items
Example:
workflows = serialize_list(workflow_docs, serialize_workflow)
"""
return [serializer(item) for item in items]
def paginated_response(
collection: Collection,
query: Dict[str, Any],
serializer: Callable[[Dict], Dict],
limit: int,
skip: int,
sort_field: str = "created_at",
sort_order: int = -1,
response_key: str = "items",
) -> Response:
"""
Create paginated response for collection query.
Args:
collection: MongoDB collection
query: Query dictionary
serializer: Function to serialize each item
limit: Items per page
skip: Number of items to skip
sort_field: Field to sort by
sort_order: Sort order (1=asc, -1=desc)
response_key: Key name for items in response
Returns:
Flask Response with paginated data
Example:
return paginated_response(
workflows_collection,
{"user": user_id},
serialize_workflow,
limit, skip,
response_key="workflows"
)
"""
items = list(
collection.find(query).sort(sort_field, sort_order).skip(skip).limit(limit)
)
total = collection.count_documents(query)
return success_response(
{
response_key: serialize_list(items, serializer),
"total": total,
"limit": limit,
"skip": skip,
}
)
def require_fields(required: List[str]) -> Callable:
"""
Decorator to validate required fields in request JSON.
Args:
required: List of required field names
Returns:
Decorator function
Example:
@require_fields(["name", "description"])
def post(self):
data = request.get_json()
...
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
data = request.get_json()
if not data:
return error_response("Request body required")
missing = [field for field in required if not data.get(field)]
if missing:
return error_response(f"Missing required fields: {', '.join(missing)}")
return func(*args, **kwargs)
return wrapper
return decorator
def safe_db_operation(
operation: Callable, error_message: str = "Database operation failed"
) -> Tuple[Any, Optional[Response]]:
"""
Safely execute database operation with error handling.
Args:
operation: Function to execute
error_message: Error message if operation fails
Returns:
Tuple of (result or None, error_response or None)
Example:
result, error = safe_db_operation(
lambda: collection.insert_one(doc),
"Failed to create resource"
)
if error:
return error
"""
try:
result = operation()
return result, None
except Exception as err:
if has_app_context():
current_app.logger.error(f"{error_message}: {err}", exc_info=True)
return None, error_response(error_message)
def validate_enum(
value: Any, allowed: List[Any], field_name: str
) -> Optional[Response]:
"""
Validate that value is in allowed list.
Args:
value: Value to validate
allowed: List of allowed values
field_name: Field name for error message
Returns:
error_response if invalid, None if valid
Example:
error = validate_enum(status, ["draft", "published"], "status")
if error:
return error
"""
if value not in allowed:
allowed_str = ", ".join(f"'{v}'" for v in allowed)
return error_response(f"Invalid {field_name}. Must be one of: {allowed_str}")
return None
def extract_sort_params(
default_field: str = "created_at",
default_order: str = "desc",
allowed_fields: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""
Extract and validate sort parameters from request.
Args:
default_field: Default sort field
default_order: Default sort order ("asc" or "desc")
allowed_fields: List of allowed sort fields (None = no validation)
Returns:
Tuple of (sort_field, sort_order)
Example:
sort_field, sort_order = extract_sort_params(
allowed_fields=["name", "date", "status"]
)
"""
sort_field = request.args.get("sort", default_field)
sort_order_str = request.args.get("order", default_order).lower()
if allowed_fields and sort_field not in allowed_fields:
sort_field = default_field
sort_order = -1 if sort_order_str == "desc" else 1
return sort_field, sort_order

View File

@@ -0,0 +1,3 @@
from .routes import workflows_ns
__all__ = ["workflows_ns"]

View File

@@ -0,0 +1,546 @@
"""Workflow management routes."""
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set
from flask import current_app, request
from flask_restx import Namespace, Resource
from application.api.user.base import (
workflow_edges_collection,
workflow_nodes_collection,
workflows_collection,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.model_utils import get_model_capabilities
from application.api.user.utils import (
check_resource_ownership,
error_response,
get_user_id,
require_auth,
require_fields,
safe_db_operation,
success_response,
validate_object_id,
)
workflows_ns = Namespace("workflows", path="/api")
def _workflow_error_response(message: str, err: Exception):
current_app.logger.error(f"{message}: {err}", exc_info=True)
return error_response(message)
def serialize_workflow(w: Dict) -> Dict:
"""Serialize workflow document to API response format."""
return {
"id": str(w["_id"]),
"name": w.get("name"),
"description": w.get("description"),
"created_at": w["created_at"].isoformat() if w.get("created_at") else None,
"updated_at": w["updated_at"].isoformat() if w.get("updated_at") else None,
}
def serialize_node(n: Dict) -> Dict:
"""Serialize workflow node document to API response format."""
return {
"id": n["id"],
"type": n["type"],
"title": n.get("title"),
"description": n.get("description"),
"position": n.get("position"),
"data": n.get("config", {}),
}
def serialize_edge(e: Dict) -> Dict:
"""Serialize workflow edge document to API response format."""
return {
"id": e["id"],
"source": e.get("source_id"),
"target": e.get("target_id"),
"sourceHandle": e.get("source_handle"),
"targetHandle": e.get("target_handle"),
}
def get_workflow_graph_version(workflow: Dict) -> int:
"""Get current graph version with legacy fallback."""
raw_version = workflow.get("current_graph_version", 1)
try:
version = int(raw_version)
return version if version > 0 else 1
except (ValueError, TypeError):
return 1
def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> List[Dict]:
"""Fetch graph docs for active version, with fallback for legacy unversioned data."""
docs = list(
collection.find({"workflow_id": workflow_id, "graph_version": graph_version})
)
if docs:
return docs
if graph_version == 1:
return list(
collection.find(
{"workflow_id": workflow_id, "graph_version": {"$exists": False}}
)
)
return docs
def validate_json_schema_payload(
json_schema: Any,
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
"""Validate and normalize optional JSON schema payload for structured output."""
if json_schema is None:
return None, None
try:
return normalize_json_schema_payload(json_schema), None
except JsonSchemaValidationError as exc:
return None, str(exc)
def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
"""Normalize agent-node JSON schema payloads before persistence."""
normalized_nodes: List[Dict] = []
for node in nodes:
if not isinstance(node, dict):
normalized_nodes.append(node)
continue
normalized_node = dict(node)
if normalized_node.get("type") != "agent":
normalized_nodes.append(normalized_node)
continue
raw_config = normalized_node.get("data")
if not isinstance(raw_config, dict) or "json_schema" not in raw_config:
normalized_nodes.append(normalized_node)
continue
normalized_config = dict(raw_config)
try:
normalized_config["json_schema"] = normalize_json_schema_payload(
raw_config.get("json_schema")
)
except JsonSchemaValidationError:
# Validation runs before normalization; keep original on unexpected shape.
normalized_config["json_schema"] = raw_config.get("json_schema")
normalized_node["data"] = normalized_config
normalized_nodes.append(normalized_node)
return normalized_nodes
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
"""Validate workflow graph structure."""
errors = []
if not nodes:
errors.append("Workflow must have at least one node")
return errors
start_nodes = [n for n in nodes if n.get("type") == "start"]
if len(start_nodes) != 1:
errors.append("Workflow must have exactly one start node")
end_nodes = [n for n in nodes if n.get("type") == "end"]
if not end_nodes:
errors.append("Workflow must have at least one end node")
node_ids = {n.get("id") for n in nodes}
node_map = {n.get("id"): n for n in nodes}
end_ids = {n.get("id") for n in end_nodes}
for edge in edges:
source_id = edge.get("source")
target_id = edge.get("target")
if source_id not in node_ids:
errors.append(f"Edge references non-existent source: {source_id}")
if target_id not in node_ids:
errors.append(f"Edge references non-existent target: {target_id}")
if start_nodes:
start_id = start_nodes[0].get("id")
if not any(e.get("source") == start_id for e in edges):
errors.append("Start node must have at least one outgoing edge")
condition_nodes = [n for n in nodes if n.get("type") == "condition"]
for cnode in condition_nodes:
cnode_id = cnode.get("id")
cnode_title = cnode.get("title", cnode_id)
outgoing = [e for e in edges if e.get("source") == cnode_id]
if len(outgoing) < 2:
errors.append(
f"Condition node '{cnode_title}' must have at least 2 outgoing edges"
)
node_data = cnode.get("data", {}) or {}
cases = node_data.get("cases", [])
if not isinstance(cases, list):
cases = []
if not cases or not any(
isinstance(c, dict) and str(c.get("expression", "")).strip() for c in cases
):
errors.append(
f"Condition node '{cnode_title}' must have at least one case with an expression"
)
case_handles: Set[str] = set()
duplicate_case_handles: Set[str] = set()
for case in cases:
if not isinstance(case, dict):
continue
raw_handle = case.get("sourceHandle", "")
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
if not handle:
errors.append(
f"Condition node '{cnode_title}' has a case without a branch handle"
)
continue
if handle in case_handles:
duplicate_case_handles.add(handle)
case_handles.add(handle)
for handle in duplicate_case_handles:
errors.append(
f"Condition node '{cnode_title}' has duplicate case handle '{handle}'"
)
outgoing_by_handle: Dict[str, List[Dict]] = {}
for out_edge in outgoing:
raw_handle = out_edge.get("sourceHandle", "")
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
outgoing_by_handle.setdefault(handle, []).append(out_edge)
for handle, handle_edges in outgoing_by_handle.items():
if not handle:
errors.append(
f"Condition node '{cnode_title}' has an outgoing edge without sourceHandle"
)
continue
if handle != "else" and handle not in case_handles:
errors.append(
f"Condition node '{cnode_title}' has a connection from unknown branch '{handle}'"
)
if len(handle_edges) > 1:
errors.append(
f"Condition node '{cnode_title}' has multiple outgoing edges from branch '{handle}'"
)
if "else" not in outgoing_by_handle:
errors.append(f"Condition node '{cnode_title}' must have an 'else' branch")
for case in cases:
if not isinstance(case, dict):
continue
raw_handle = case.get("sourceHandle", "")
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
if not handle:
continue
raw_expression = case.get("expression", "")
has_expression = isinstance(raw_expression, str) and bool(
raw_expression.strip()
)
has_outgoing = bool(outgoing_by_handle.get(handle))
if has_expression and not has_outgoing:
errors.append(
f"Condition node '{cnode_title}' case '{handle}' has an expression but no outgoing edge"
)
if not has_expression and has_outgoing:
errors.append(
f"Condition node '{cnode_title}' case '{handle}' has an outgoing edge but no expression"
)
for handle, handle_edges in outgoing_by_handle.items():
if not handle:
continue
for out_edge in handle_edges:
target = out_edge.get("target")
if target and not _can_reach_end(target, edges, node_map, end_ids):
errors.append(
f"Branch '{handle}' of condition '{cnode_title}' "
f"must eventually reach an end node"
)
agent_nodes = [n for n in nodes if n.get("type") == "agent"]
for agent_node in agent_nodes:
agent_title = agent_node.get("title", agent_node.get("id", "unknown"))
raw_config = agent_node.get("data", {}) or {}
if not isinstance(raw_config, dict):
errors.append(f"Agent node '{agent_title}' has invalid configuration")
continue
normalized_schema, schema_error = validate_json_schema_payload(
raw_config.get("json_schema")
)
has_json_schema = normalized_schema is not None
model_id = raw_config.get("model_id")
if has_json_schema and isinstance(model_id, str) and model_id.strip():
capabilities = get_model_capabilities(model_id.strip())
if capabilities and not capabilities.get("supports_structured_output", False):
errors.append(
f"Agent node '{agent_title}' selected model does not support structured output"
)
if schema_error:
errors.append(f"Agent node '{agent_title}' JSON schema {schema_error}")
for node in nodes:
if not node.get("id"):
errors.append("All nodes must have an id")
if not node.get("type"):
errors.append(f"Node {node.get('id', 'unknown')} must have a type")
return errors
def _can_reach_end(
node_id: str, edges: List[Dict], node_map: Dict, end_ids: set, visited: set = None
) -> bool:
if visited is None:
visited = set()
if node_id in end_ids:
return True
if node_id in visited or node_id not in node_map:
return False
visited.add(node_id)
outgoing = [e.get("target") for e in edges if e.get("source") == node_id]
return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)
def create_workflow_nodes(
workflow_id: str, nodes_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow nodes into database."""
if nodes_data:
workflow_nodes_collection.insert_many(
[
{
"id": n["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"type": n["type"],
"title": n.get("title", ""),
"description": n.get("description", ""),
"position": n.get("position", {"x": 0, "y": 0}),
"config": n.get("data", {}),
}
for n in nodes_data
]
)
def create_workflow_edges(
workflow_id: str, edges_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow edges into database."""
if edges_data:
workflow_edges_collection.insert_many(
[
{
"id": e["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"source_id": e.get("source"),
"target_id": e.get("target"),
"source_handle": e.get("sourceHandle"),
"target_handle": e.get("targetHandle"),
}
for e in edges_data
]
)
@workflows_ns.route("/workflows")
class WorkflowList(Resource):
@require_auth
@require_fields(["name"])
def post(self):
"""Create a new workflow with nodes and edges."""
user_id = get_user_id()
data = request.get_json()
name = data.get("name", "").strip()
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
)
nodes_data = normalize_agent_node_json_schemas(nodes_data)
now = datetime.now(timezone.utc)
workflow_doc = {
"name": name,
"description": data.get("description", ""),
"user": user_id,
"created_at": now,
"updated_at": now,
"current_graph_version": 1,
}
result, error = safe_db_operation(
lambda: workflows_collection.insert_one(workflow_doc),
"Failed to create workflow",
)
if error:
return error
workflow_id = str(result.inserted_id)
try:
create_workflow_nodes(workflow_id, nodes_data, 1)
create_workflow_edges(workflow_id, edges_data, 1)
except Exception as err:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": result.inserted_id})
return _workflow_error_response("Failed to create workflow structure", err)
return success_response({"id": workflow_id}, 201)
@workflows_ns.route("/workflows/<string:workflow_id>")
class WorkflowDetail(Resource):
@require_auth
def get(self, workflow_id: str):
"""Get workflow details with nodes and edges."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
graph_version = get_workflow_graph_version(workflow)
nodes = fetch_graph_documents(
workflow_nodes_collection, workflow_id, graph_version
)
edges = fetch_graph_documents(
workflow_edges_collection, workflow_id, graph_version
)
return success_response(
{
"workflow": serialize_workflow(workflow),
"nodes": [serialize_node(n) for n in nodes],
"edges": [serialize_edge(e) for e in edges],
}
)
@require_auth
@require_fields(["name"])
def put(self, workflow_id: str):
"""Update workflow and replace nodes/edges."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
data = request.get_json()
name = data.get("name", "").strip()
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
)
nodes_data = normalize_agent_node_json_schemas(nodes_data)
current_graph_version = get_workflow_graph_version(workflow)
next_graph_version = current_graph_version + 1
try:
create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
create_workflow_edges(workflow_id, edges_data, next_graph_version)
except Exception as err:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return _workflow_error_response("Failed to update workflow structure", err)
now = datetime.now(timezone.utc)
_, error = safe_db_operation(
lambda: workflows_collection.update_one(
{"_id": obj_id},
{
"$set": {
"name": name,
"description": data.get("description", ""),
"updated_at": now,
"current_graph_version": next_graph_version,
}
},
),
"Failed to update workflow",
)
if error:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return error
try:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
except Exception as cleanup_err:
current_app.logger.warning(
f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
)
return success_response()
@require_auth
def delete(self, workflow_id: str):
"""Delete workflow and its graph."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
try:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow["_id"], "user": user_id})
except Exception as err:
return _workflow_error_response("Failed to delete workflow", err)
return success_response()

View File

@@ -19,6 +19,10 @@ from application.api.user.routes import user # noqa: E402
from application.api.connector.routes import connector # noqa: E402
from application.celery_init import celery # noqa: E402
from application.core.settings import settings # noqa: E402
from application.stt.upload_limits import ( # noqa: E402
build_stt_file_size_limit_message,
should_reject_stt_request,
)
if platform.system() == "Windows":
@@ -68,6 +72,11 @@ def home():
return "Welcome to DocsGPT Backend!"
@app.route("/api/health")
def health():
return jsonify({"status": "ok"})
@app.route("/api/config")
def get_config():
response = {
@@ -88,6 +97,23 @@ def generate_token():
return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
@app.before_request
def enforce_stt_request_size_limits():
if request.method == "OPTIONS":
return None
if should_reject_stt_request(request.path, request.content_length):
return (
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
return None
@app.before_request
def authenticate_request():
if request.method == "OPTIONS":

View File

@@ -19,9 +19,9 @@ def handle_auth(request, data={}):
options={"verify_exp": False},
)
return decoded_token
except Exception as e:
except Exception:
return {
"message": f"Authentication error: {str(e)}",
"message": "Authentication error: invalid token",
"error": "invalid_token",
}
else:

View File

@@ -21,3 +21,4 @@ def config_loggers(*args, **kwargs):
celery = make_celery()
celery.config_from_object("application.celeryconfig")

View File

@@ -6,3 +6,6 @@ result_backend = os.getenv("CELERY_RESULT_BACKEND")
task_serializer = 'json'
result_serializer = 'json'
accept_content = ['json']
# Autodiscover tasks
imports = ('application.api.user.tasks',)

View File

@@ -0,0 +1,34 @@
from typing import Any, Dict, Optional
class JsonSchemaValidationError(ValueError):
"""Raised when a JSON schema payload is invalid."""
def normalize_json_schema_payload(json_schema: Any) -> Optional[Dict[str, Any]]:
"""
Normalize accepted JSON schema payload shapes to a plain schema object.
Accepted inputs:
- None
- A raw schema object with a top-level "type"
- A wrapped payload with a top-level "schema" object
"""
if json_schema is None:
return None
if not isinstance(json_schema, dict):
raise JsonSchemaValidationError("must be a valid JSON object")
wrapped_schema = json_schema.get("schema")
if wrapped_schema is not None:
if not isinstance(wrapped_schema, dict):
raise JsonSchemaValidationError('field "schema" must be a valid JSON object')
return wrapped_schema
if "type" not in json_schema:
raise JsonSchemaValidationError(
'must include either a "type" or "schema" field'
)
return json_schema

View File

@@ -8,8 +8,8 @@ from application.core.model_settings import (
ModelProvider,
)
OPENAI_ATTACHMENTS = [
"application/pdf",
# Base image attachment types supported by most vision-capable LLMs
IMAGE_ATTACHMENTS = [
"image/png",
"image/jpeg",
"image/jpg",
@@ -17,75 +17,44 @@ OPENAI_ATTACHMENTS = [
"image/gif",
]
GOOGLE_ATTACHMENTS = [
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
# PDF excluded: most OpenAI-compatible endpoints don't support native PDF uploads.
# When excluded, PDFs are synthetically processed by converting pages to images.
OPENAI_ATTACHMENTS = IMAGE_ATTACHMENTS
GOOGLE_ATTACHMENTS = ["application/pdf"] + IMAGE_ATTACHMENTS
ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
NOVITA_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENAI_MODELS = [
AvailableModel(
id="gpt-4o",
id="gpt-5.1",
provider=ModelProvider.OPENAI,
display_name="GPT-4 Omni",
description="Latest and most capable model",
display_name="GPT-5.1",
description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=128000,
context_window=200000,
),
),
AvailableModel(
id="gpt-4o-mini",
id="gpt-5-mini",
provider=ModelProvider.OPENAI,
display_name="GPT-4 Omni Mini",
description="Fast and efficient",
display_name="GPT-5 Mini",
description="Faster, cost-effective variant of GPT-5.1",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=128000,
context_window=200000,
),
),
AvailableModel(
id="gpt-4-turbo",
provider=ModelProvider.OPENAI,
display_name="GPT-4 Turbo",
description="Fast GPT-4 with 128k context",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=128000,
),
),
AvailableModel(
id="gpt-4",
provider=ModelProvider.OPENAI,
display_name="GPT-4",
description="Most capable model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=8192,
),
),
AvailableModel(
id="gpt-3.5-turbo",
provider=ModelProvider.OPENAI,
display_name="GPT-3.5 Turbo",
description="Fast and cost-effective",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=4096,
),
),
)
]
@@ -97,6 +66,7 @@ ANTHROPIC_MODELS = [
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
@@ -107,6 +77,7 @@ ANTHROPIC_MODELS = [
description="Balanced performance and capability",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
@@ -117,6 +88,7 @@ ANTHROPIC_MODELS = [
description="Most capable Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
@@ -127,6 +99,7 @@ ANTHROPIC_MODELS = [
description="Fastest Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
@@ -159,9 +132,9 @@ GOOGLE_MODELS = [
),
),
AvailableModel(
id="gemini-2.5-pro",
id="gemini-3-pro-preview",
provider=ModelProvider.GOOGLE,
display_name="Gemini 2.5 Pro",
display_name="Gemini 3 Pro",
description="Most capable Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
@@ -185,23 +158,78 @@ GROQ_MODELS = [
),
),
AvailableModel(
id="llama-3.1-8b-instant",
id="openai/gpt-oss-120b",
provider=ModelProvider.GROQ,
display_name="Llama 3.1 8B",
description="Ultra-fast inference",
display_name="GPT-OSS 120B",
description="Open-source GPT model optimized for speed",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
]
OPENROUTER_MODELS = [
AvailableModel(
id="mixtral-8x7b-32768",
provider=ModelProvider.GROQ,
display_name="Mixtral 8x7B",
description="High-speed inference with tools",
id="qwen/qwen3-coder:free",
provider=ModelProvider.OPENROUTER,
display_name="Qwen 3 Coder",
description="Latest Qwen model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=32768,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
AvailableModel(
id="google/gemma-3-27b-it:free",
provider=ModelProvider.OPENROUTER,
display_name="Gemma 3 27B",
description="Latest Gemma model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
]
NOVITA_MODELS = [
AvailableModel(
id="moonshotai/kimi-k2.5",
provider=ModelProvider.NOVITA,
display_name="Kimi K2.5",
description="MoE model with function calling, structured output, reasoning, and vision",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=NOVITA_ATTACHMENTS,
context_window=262144,
),
),
AvailableModel(
id="zai-org/glm-5",
provider=ModelProvider.NOVITA,
display_name="GLM-5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=202800,
),
),
AvailableModel(
id="minimax/minimax-m2.5",
provider=ModelProvider.NOVITA,
display_name="MiniMax M2.5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=204800,
),
),
]
@@ -221,3 +249,18 @@ AZURE_OPENAI_MODELS = [
),
),
]
def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
"""Create a custom OpenAI-compatible model (e.g., LM Studio, Ollama)."""
return AvailableModel(
id=model_name,
provider=ModelProvider.OPENAI,
display_name=model_name,
description=f"Custom OpenAI-compatible model at {base_url}",
base_url=base_url,
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
),
)

View File

@@ -8,6 +8,7 @@ logger = logging.getLogger(__name__)
class ModelProvider(str, Enum):
OPENAI = "openai"
OPENROUTER = "openrouter"
AZURE_OPENAI = "azure_openai"
ANTHROPIC = "anthropic"
GROQ = "groq"
@@ -84,9 +85,13 @@ class ModelRegistry:
self.models.clear()
self._add_docsgpt_models(settings)
if settings.OPENAI_API_KEY or (
settings.LLM_PROVIDER == "openai" and settings.API_KEY
# Skip DocsGPT model if using custom OpenAI-compatible endpoint
if not settings.OPENAI_BASE_URL:
self._add_docsgpt_models(settings)
if (
settings.OPENAI_API_KEY
or (settings.LLM_PROVIDER == "openai" and settings.API_KEY)
or settings.OPENAI_BASE_URL
):
self._add_openai_models(settings)
if settings.OPENAI_API_BASE or (
@@ -105,39 +110,73 @@ class ModelRegistry:
settings.LLM_PROVIDER == "groq" and settings.API_KEY
):
self._add_groq_models(settings)
if settings.OPEN_ROUTER_API_KEY or (
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
):
self._add_openrouter_models(settings)
if settings.NOVITA_API_KEY or (
settings.LLM_PROVIDER == "novita" and settings.API_KEY
):
self._add_novita_models(settings)
if settings.HUGGINGFACE_API_KEY or (
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
):
self._add_huggingface_models(settings)
# Default model selection
if settings.LLM_NAME and settings.LLM_NAME in self.models:
self.default_model_id = settings.LLM_NAME
elif settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
self.default_model_id = model_id
if settings.LLM_NAME:
# Parse LLM_NAME (may be comma-separated)
model_names = self._parse_model_names(settings.LLM_NAME)
# First model in the list becomes default
for model_name in model_names:
if model_name in self.models:
self.default_model_id = model_name
break
else:
# Backward compat: try exact match if no parsed model found
if not self.default_model_id and settings.LLM_NAME in self.models:
self.default_model_id = settings.LLM_NAME
if not self.default_model_id:
if settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
self.default_model_id = model_id
break
if not self.default_model_id and self.models:
self.default_model_id = next(iter(self.models.keys()))
logger.info(
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
)
def _add_openai_models(self, settings):
from application.core.model_configs import OPENAI_MODELS
from application.core.model_configs import (
OPENAI_MODELS,
create_custom_openai_model,
)
if settings.OPENAI_API_KEY:
for model in OPENAI_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "openai" and settings.LLM_NAME:
for model in OPENAI_MODELS:
if model.id == settings.LLM_NAME:
# Check if using local OpenAI-compatible endpoint (Ollama, LM Studio, etc.)
using_local_endpoint = bool(
settings.OPENAI_BASE_URL and settings.OPENAI_BASE_URL.strip()
)
if using_local_endpoint:
# When OPENAI_BASE_URL is set, ONLY register custom models from LLM_NAME
# Do NOT add standard OpenAI models (gpt-5.1, etc.)
if settings.LLM_NAME:
model_names = self._parse_model_names(settings.LLM_NAME)
for model_name in model_names:
custom_model = create_custom_openai_model(
model_name, settings.OPENAI_BASE_URL
)
self.models[model_name] = custom_model
logger.info(
f"Registered custom OpenAI model: {model_name} at {settings.OPENAI_BASE_URL}"
)
else:
# Standard OpenAI API usage - add standard models if API key is valid
if settings.OPENAI_API_KEY:
for model in OPENAI_MODELS:
self.models[model.id] = model
return
for model in OPENAI_MODELS:
self.models[model.id] = model
def _add_azure_openai_models(self, settings):
from application.core.model_configs import AZURE_OPENAI_MODELS
@@ -194,6 +233,36 @@ class ModelRegistry:
return
for model in GROQ_MODELS:
self.models[model.id] = model
def _add_openrouter_models(self, settings):
from application.core.model_configs import OPENROUTER_MODELS
if settings.OPEN_ROUTER_API_KEY:
for model in OPENROUTER_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "openrouter" and settings.LLM_NAME:
for model in OPENROUTER_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in OPENROUTER_MODELS:
self.models[model.id] = model
def _add_novita_models(self, settings):
from application.core.model_configs import NOVITA_MODELS
if settings.NOVITA_API_KEY:
for model in NOVITA_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "novita" and settings.LLM_NAME:
for model in NOVITA_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in NOVITA_MODELS:
self.models[model.id] = model
def _add_docsgpt_models(self, settings):
model_id = "docsgpt-local"
@@ -223,6 +292,15 @@ class ModelRegistry:
)
self.models[model_id] = model
def _parse_model_names(self, llm_name: str) -> List[str]:
"""
Parse LLM_NAME which may contain comma-separated model names.
E.g., 'deepseek-r1:1.5b,gemma:2b' -> ['deepseek-r1:1.5b', 'gemma:2b']
"""
if not llm_name:
return []
return [name.strip() for name in llm_name.split(",") if name.strip()]
def get_model(self, model_id: str) -> Optional[AvailableModel]:
return self.models.get(model_id)

View File

@@ -9,6 +9,8 @@ def get_api_key_for_provider(provider: str) -> Optional[str]:
provider_key_map = {
"openai": settings.OPENAI_API_KEY,
"openrouter": settings.OPEN_ROUTER_API_KEY,
"novita": settings.NOVITA_API_KEY,
"anthropic": settings.ANTHROPIC_API_KEY,
"google": settings.GOOGLE_API_KEY,
"groq": settings.GROQ_API_KEY,

View File

@@ -2,20 +2,22 @@ import os
from pathlib import Path
from typing import Optional
from pydantic_settings import BaseSettings
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
class Settings(BaseSettings):
model_config = SettingsConfigDict(extra="ignore")
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
LLM_PROVIDER: str = "docsgpt"
LLM_NAME: Optional[str] = (
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
)
LLM_NAME: Optional[str] = None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
@@ -35,9 +37,9 @@ class Settings(BaseSettings):
UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False
PARSE_IMAGE_REMOTE: bool = False
VECTOR_STORE: str = (
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
)
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
RETRIEVERS_ENABLED: list = ["classic_rag"]
AGENT_NAME: str = "classic"
FALLBACK_LLM_PROVIDER: Optional[str] = None # provider for fallback llm
@@ -45,16 +47,18 @@ class Settings(BaseSettings):
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
# Google Drive integration
GOOGLE_CLIENT_ID: Optional[str] = (
None # Replace with your actual Google OAuth client ID
)
GOOGLE_CLIENT_SECRET: Optional[str] = (
None # Replace with your actual Google OAuth client secret
)
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
GOOGLE_CLIENT_SECRET: Optional[str] = None # Replace with your actual Google OAuth client secret
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = (
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
)
# Microsoft Entra ID (Azure AD) integration
MICROSOFT_CLIENT_ID: Optional[str] = None # Azure AD Application (client) ID
MICROSOFT_CLIENT_SECRET: Optional[str] = None # Azure AD Application client secret
MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
# GitHub source
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
@@ -62,6 +66,8 @@ class Settings(BaseSettings):
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
API_URL: str = "http://localhost:7091" # backend url for celery worker
MCP_OAUTH_REDIRECT_URI: Optional[str] = None # public callback URL for MCP OAuth
INTERNAL_KEY: Optional[str] = None # internal api key for worker-to-backend auth
API_KEY: Optional[str] = None # LLM api key (used by LLM_PROVIDER)
@@ -71,19 +77,14 @@ class Settings(BaseSettings):
GOOGLE_API_KEY: Optional[str] = None
GROQ_API_KEY: Optional[str] = None
HUGGINGFACE_API_KEY: Optional[str] = None
OPEN_ROUTER_API_KEY: Optional[str] = None
NOVITA_API_KEY: Optional[str] = None
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = (
None # azure deployment name for embeddings
)
OPENAI_BASE_URL: Optional[str] = (
None # openai base url for open ai compatable models
)
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for embeddings
OPENAI_BASE_URL: Optional[str] = None # openai base url for open ai compatable models
# elasticsearch
ELASTIC_CLOUD_ID: Optional[str] = None # cloud id for elasticsearch
@@ -124,10 +125,8 @@ class Settings(BaseSettings):
MILVUS_TOKEN: Optional[str] = ""
# LanceDB vectorstore config
LANCEDB_PATH: str = "/tmp/lancedb" # Path where LanceDB stores its local data
LANCEDB_TABLE_NAME: Optional[str] = (
"docsgpts" # Name of the table to use for storing vectors
)
LANCEDB_PATH: str = "./data/lancedb" # Path where LanceDB stores its local data
LANCEDB_TABLE_NAME: Optional[str] = "docsgpts" # Name of the table to use for storing vectors
FLASK_DEBUG_MODE: bool = False
STORAGE_TYPE: str = "local" # local or s3
@@ -140,10 +139,55 @@ class Settings(BaseSettings):
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
ELEVENLABS_API_KEY: Optional[str] = None
STT_PROVIDER: str = "openai" # openai or faster_whisper
OPENAI_STT_MODEL: str = "gpt-4o-mini-transcribe"
STT_LANGUAGE: Optional[str] = None
STT_MAX_FILE_SIZE_MB: int = 50
STT_ENABLE_TIMESTAMPS: bool = False
STT_ENABLE_DIARIZATION: bool = False
# Tool pre-fetch settings
ENABLE_TOOL_PREFETCH: bool = True
# Conversation Compression Settings
ENABLE_CONVERSATION_COMPRESSION: bool = True
COMPRESSION_THRESHOLD_PERCENTAGE: float = 0.8 # Trigger at 80% of context
COMPRESSION_MODEL_OVERRIDE: Optional[str] = None # Use different model for compression
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
path = Path(__file__).parent.parent.absolute()
@field_validator(
"API_KEY",
"OPENAI_API_KEY",
"ANTHROPIC_API_KEY",
"GOOGLE_API_KEY",
"GROQ_API_KEY",
"HUGGINGFACE_API_KEY",
"NOVITA_API_KEY",
"EMBEDDINGS_KEY",
"FALLBACK_LLM_API_KEY",
"QDRANT_API_KEY",
"ELEVENLABS_API_KEY",
"INTERNAL_KEY",
mode="before",
)
@classmethod
def normalize_api_key(cls, v: Optional[str]) -> Optional[str]:
"""
Normalize API keys: convert 'None', 'none', empty strings,
and whitespace-only strings to actual None.
Handles Pydantic loading 'None' from .env as string "None".
"""
if v is None:
return None
if not isinstance(v, str):
return v
stripped = v.strip()
if stripped == "" or stripped.lower() == "none":
return None
return stripped
# Project root is one level above application/
path = Path(__file__).parent.parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")

View File

@@ -0,0 +1,181 @@
"""
URL validation utilities to prevent SSRF (Server-Side Request Forgery) attacks.
This module provides functions to validate URLs before making HTTP requests,
blocking access to internal networks, cloud metadata services, and other
potentially dangerous endpoints.
"""
import ipaddress
import socket
from urllib.parse import urlparse
from typing import Optional, Set
class SSRFError(Exception):
"""Raised when a URL fails SSRF validation."""
pass
# Blocked hostnames that should never be accessed
BLOCKED_HOSTNAMES: Set[str] = {
"localhost",
"localhost.localdomain",
"metadata.google.internal",
"metadata",
}
# Cloud metadata IP addresses (AWS, GCP, Azure, etc.)
METADATA_IPS: Set[str] = {
"169.254.169.254", # AWS, GCP, Azure metadata
"169.254.170.2", # AWS ECS task metadata
"fd00:ec2::254", # AWS IPv6 metadata
}
# Allowed schemes for external requests
ALLOWED_SCHEMES: Set[str] = {"http", "https"}
def is_private_ip(ip_str: str) -> bool:
"""
Check if an IP address is private, loopback, or link-local.
Args:
ip_str: IP address as a string
Returns:
True if the IP is private/internal, False otherwise
"""
try:
ip = ipaddress.ip_address(ip_str)
return (
ip.is_private or
ip.is_loopback or
ip.is_link_local or
ip.is_reserved or
ip.is_multicast or
ip.is_unspecified
)
except ValueError:
# If we can't parse it as an IP, return False
return False
def is_metadata_ip(ip_str: str) -> bool:
"""
Check if an IP address is a cloud metadata service IP.
Args:
ip_str: IP address as a string
Returns:
True if the IP is a metadata service, False otherwise
"""
return ip_str in METADATA_IPS
def resolve_hostname(hostname: str) -> Optional[str]:
"""
Resolve a hostname to an IP address.
Args:
hostname: The hostname to resolve
Returns:
The resolved IP address, or None if resolution fails
"""
try:
return socket.gethostbyname(hostname)
except socket.gaierror:
return None
def validate_url(url: str, allow_localhost: bool = False) -> str:
"""
Validate a URL to prevent SSRF attacks.
This function checks that:
1. The URL has an allowed scheme (http or https)
2. The hostname is not a blocked hostname
3. The resolved IP is not a private/internal IP
4. The resolved IP is not a cloud metadata service
Args:
url: The URL to validate
allow_localhost: If True, allow localhost connections (for testing only)
Returns:
The validated URL (with scheme added if missing)
Raises:
SSRFError: If the URL fails validation
"""
# Ensure URL has a scheme
if not urlparse(url).scheme:
url = "http://" + url
parsed = urlparse(url)
# Check scheme
if parsed.scheme not in ALLOWED_SCHEMES:
raise SSRFError(f"URL scheme '{parsed.scheme}' is not allowed. Only HTTP(S) is permitted.")
hostname = parsed.hostname
if not hostname:
raise SSRFError("URL must have a valid hostname.")
hostname_lower = hostname.lower()
# Check blocked hostnames
if hostname_lower in BLOCKED_HOSTNAMES and not allow_localhost:
raise SSRFError(f"Access to '{hostname}' is not allowed.")
# Check if hostname is an IP address directly
try:
ip = ipaddress.ip_address(hostname)
ip_str = str(ip)
if is_metadata_ip(ip_str):
raise SSRFError("Access to cloud metadata services is not allowed.")
if is_private_ip(ip_str) and not allow_localhost:
raise SSRFError("Access to private/internal IP addresses is not allowed.")
return url
except ValueError:
# Not an IP address, it's a hostname - resolve it
pass
# Resolve hostname and check the IP
resolved_ip = resolve_hostname(hostname)
if resolved_ip is None:
raise SSRFError(f"Unable to resolve hostname: {hostname}")
if is_metadata_ip(resolved_ip):
raise SSRFError("Access to cloud metadata services is not allowed.")
if is_private_ip(resolved_ip) and not allow_localhost:
raise SSRFError("Access to private/internal networks is not allowed.")
return url
def validate_url_safe(url: str, allow_localhost: bool = False) -> tuple[bool, str, Optional[str]]:
"""
Validate a URL and return a tuple with validation result.
This is a non-throwing version of validate_url for cases where
you want to handle validation failures gracefully.
Args:
url: The URL to validate
allow_localhost: If True, allow localhost connections (for testing only)
Returns:
Tuple of (is_valid, validated_url_or_original, error_message_or_none)
"""
try:
validated = validate_url(url, allow_localhost)
return (True, validated, None)
except SSRFError as e:
return (False, url, str(e))

View File

@@ -13,3 +13,25 @@ def response_error(code_status, message=None):
def bad_request(status_code=400, message=''):
return response_error(code_status=status_code, message=message)
def sanitize_api_error(error) -> str:
"""
Convert technical API errors to user-friendly messages.
Works with both Exception objects and error message strings.
"""
error_str = str(error).lower()
if "503" in error_str or "unavailable" in error_str or "high demand" in error_str:
return "The AI service is temporarily unavailable due to high demand. Please try again in a moment."
if "429" in error_str or "rate limit" in error_str or "quota" in error_str:
return "Rate limit exceeded. Please wait a moment before trying again."
if "401" in error_str or "unauthorized" in error_str or "invalid api key" in error_str:
return "Authentication error. Please check your API configuration."
if "timeout" in error_str or "timed out" in error_str:
return "The request timed out. Please try again."
if "connection" in error_str or "network" in error_str:
return "Network error. Please check your connection and try again."
original = str(error)
if len(original) > 200 or "{" in original or "traceback" in error_str:
return "An error occurred while processing your request. Please try again later."
return original

View File

@@ -1,7 +1,13 @@
import base64
import logging
from anthropic import AI_PROMPT, Anthropic, HUMAN_PROMPT
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
logger = logging.getLogger(__name__)
class AnthropicLLM(BaseLLM):
@@ -20,6 +26,7 @@ class AnthropicLLM(BaseLLM):
self.HUMAN_PROMPT = HUMAN_PROMPT
self.AI_PROMPT = AI_PROMPT
self.storage = StorageCreator.get_storage()
def _raw_gen(
self,
@@ -70,3 +77,115 @@ class AnthropicLLM(BaseLLM):
finally:
if hasattr(stream_response, "close"):
stream_response.close()
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by Anthropic Claude for file uploads.
Claude supports images but not PDFs natively.
PDFs are synthetically supported via PDF-to-image conversion in the handler.
Returns:
list: List of supported MIME types
"""
return [
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
def prepare_messages_with_attachments(self, messages, attachments=None):
"""
Process attachments for Anthropic Claude API.
Formats images using Claude's vision message format.
Args:
messages (list): List of message dictionaries.
attachments (list): List of attachment dictionaries with content and metadata.
Returns:
list: Messages formatted with image content for Claude API.
"""
if not attachments:
return messages
prepared_messages = messages.copy()
# Find the last user message to attach images to
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
# Convert content to list format if it's a string
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
{"type": "text", "text": text_content}
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
if mime_type and mime_type.startswith("image/"):
try:
# Check if this is a pre-converted image (from PDF-to-image conversion)
# These have 'data' key with base64 already
if "data" in attachment:
base64_image = attachment["data"]
else:
base64_image = self._get_base64_image(attachment)
# Claude uses a specific format for images
prepared_messages[user_message_index]["content"].append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_image,
},
}
)
except Exception as e:
logger.error(
f"Error processing image attachment: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]",
}
)
return prepared_messages
def _get_base64_image(self, attachment):
"""
Convert an image file to base64 encoding.
Args:
attachment (dict): Attachment dictionary with path and metadata.
Returns:
str: Base64-encoded image data.
"""
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
try:
with self.storage.get_file(file_path) as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
except FileNotFoundError:
raise FileNotFoundError(f"File not found: {file_path}")

View File

@@ -13,37 +13,81 @@ class BaseLLM(ABC):
def __init__(
self,
decoded_token=None,
agent_id=None,
model_id=None,
base_url=None,
backup_models=None,
):
self.decoded_token = decoded_token
self.agent_id = str(agent_id) if agent_id else None
self.model_id = model_id
self.base_url = base_url
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
self._backup_models = backup_models or []
self._fallback_llm = None
self._fallback_sequence_index = 0
@property
def fallback_llm(self):
"""Lazy-loaded fallback LLM from FALLBACK_* settings."""
if self._fallback_llm is None and settings.FALLBACK_LLM_PROVIDER:
try:
from application.llm.llm_creator import LLMCreator
"""Lazy-loaded fallback LLM: tries per-agent backup models first,
then the global FALLBACK_* settings."""
if self._fallback_llm is not None:
return self._fallback_llm
from application.llm.llm_creator import LLMCreator
from application.core.model_utils import (
get_provider_from_model_id,
get_api_key_for_provider,
)
# Try per-agent backup models first
for backup_model_id in self._backup_models:
try:
provider = get_provider_from_model_id(backup_model_id)
if not provider:
logger.warning(
f"Could not resolve provider for backup model: {backup_model_id}"
)
continue
api_key = get_api_key_for_provider(provider)
self._fallback_llm = LLMCreator.create_llm(
provider,
api_key=api_key,
user_api_key=getattr(self, "user_api_key", None),
decoded_token=self.decoded_token,
model_id=backup_model_id,
agent_id=self.agent_id,
)
logger.info(
f"Fallback LLM initialized from agent backup model: "
f"{provider}/{backup_model_id}"
)
return self._fallback_llm
except Exception as e:
logger.warning(
f"Failed to initialize backup model {backup_model_id}: {str(e)}"
)
continue
# Fall back to global FALLBACK_* settings
if settings.FALLBACK_LLM_PROVIDER:
try:
self._fallback_llm = LLMCreator.create_llm(
settings.FALLBACK_LLM_PROVIDER,
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
user_api_key=None,
user_api_key=getattr(self, "user_api_key", None),
decoded_token=self.decoded_token,
model_id=settings.FALLBACK_LLM_NAME,
agent_id=self.agent_id,
)
logger.info(
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
f"Fallback LLM initialized from global settings: "
f"{settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
)
except Exception as e:
logger.error(
f"Failed to initialize fallback LLM: {str(e)}", exc_info=True
)
return self._fallback_llm
@staticmethod
@@ -71,20 +115,60 @@ class BaseLLM(ABC):
method = decorator(method)
return method(self, *args, **kwargs)
is_stream = "stream" in method_name
if is_stream:
return self._stream_with_fallback(
decorated_method, method_name, *args, **kwargs
)
try:
return decorated_method()
except Exception as e:
if not self.fallback_llm:
logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
raise
fallback = self.fallback_llm
logger.warning(
f"Primary LLM failed. Falling back to {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}. Error: {str(e)}"
f"Primary LLM failed. Falling back to "
f"{fallback.model_id}. Error: {str(e)}"
)
fallback_method = getattr(
self.fallback_llm, method_name.replace("_raw_", "")
fallback, method_name.replace("_raw_", "")
)
return fallback_method(*args, **kwargs)
fallback_kwargs = {**kwargs, "model": fallback.model_id}
return fallback_method(*args, **fallback_kwargs)
def _stream_with_fallback(
self, decorated_method, method_name, *args, **kwargs
):
"""
Wrapper generator that catches mid-stream errors and falls back.
Unlike non-streaming calls where exceptions are raised immediately,
streaming generators raise exceptions during iteration. This wrapper
ensures that if the primary LLM fails at any point during streaming
(creation or mid-stream), we fall back to the backup model.
"""
try:
yield from decorated_method()
except Exception as e:
if not self.fallback_llm:
logger.error(
f"Primary LLM failed and no fallback configured: {str(e)}"
)
raise
fallback = self.fallback_llm
logger.warning(
f"Primary LLM failed mid-stream. Falling back to "
f"{fallback.model_id}. Error: {str(e)}"
)
fallback_method = getattr(
fallback, method_name.replace("_raw_", "")
)
fallback_kwargs = {**kwargs, "model": fallback.model_id}
yield from fallback_method(*args, **fallback_kwargs)
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
decorators = [gen_token_usage, gen_cache]

View File

@@ -1,75 +1,19 @@
import json
from openai import OpenAI
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.llm.openai import OpenAILLM
DOCSGPT_API_KEY = "sk-docsgpt-public"
DOCSGPT_BASE_URL = "https://oai.arc53.com"
DOCSGPT_MODEL = "docsgpt"
class DocsGPTAPILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = "sk-docsgpt-public"
self.client = OpenAI(api_key=self.api_key, base_url="https://oai.arc53.com")
self.user_api_key = user_api_key
def _clean_messages_openai(self, messages):
cleaned_messages = []
for message in messages:
role = message.get("role")
content = message.get("content")
if role == "model":
role = "assistant"
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
for item in content:
if "text" in item:
cleaned_messages.append(
{"role": role, "content": item["text"]}
)
elif "function_call" in item:
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
tool_call = {
"id": item["function_call"]["call_id"],
"type": "function",
"function": {
"name": item["function_call"]["name"],
"arguments": json.dumps(cleaned_args),
},
}
cleaned_messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
}
)
elif "function_response" in item:
cleaned_messages.append(
{
"role": "tool",
"tool_call_id": item["function_response"][
"call_id"
],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
}
)
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
return cleaned_messages
class DocsGPTAPILLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=DOCSGPT_API_KEY,
user_api_key=user_api_key,
base_url=DOCSGPT_BASE_URL,
*args,
**kwargs,
)
def _raw_gen(
self,
@@ -79,23 +23,19 @@ class DocsGPTAPILLM(BaseLLM):
stream=False,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
messages = self._clean_messages_openai(messages)
if tools:
response = self.client.chat.completions.create(
model="docsgpt",
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model="docsgpt", messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
return super()._raw_gen(
baseself,
DOCSGPT_MODEL,
messages,
stream=stream,
tools=tools,
engine=engine,
response_format=response_format,
**kwargs,
)
def _raw_gen_stream(
self,
@@ -105,34 +45,16 @@ class DocsGPTAPILLM(BaseLLM):
stream=True,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
messages = self._clean_messages_openai(messages)
if tools:
response = self.client.chat.completions.create(
model="docsgpt",
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
else:
response = self.client.chat.completions.create(
model="docsgpt", messages=messages, stream=stream, **kwargs
)
try:
for line in response:
if (
len(line.choices) > 0
and line.choices[0].delta.content is not None
and len(line.choices[0].delta.content) > 0
):
yield line.choices[0].delta.content
elif len(line.choices) > 0:
yield line.choices[0]
finally:
if hasattr(response, "close"):
response.close()
def _supports_tools(self):
return True
return super()._raw_gen_stream(
baseself,
DOCSGPT_MODEL,
messages,
stream=stream,
tools=tools,
engine=engine,
response_format=response_format,
**kwargs,
)

View File

@@ -1,4 +1,3 @@
import json
import logging
from google import genai
@@ -11,11 +10,13 @@ from application.storage.storage_creator import StorageCreator
class GoogleLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
):
super().__init__(decoded_token=decoded_token, *args, **kwargs)
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
self.client = genai.Client(api_key=self.api_key)
self.storage = StorageCreator.get_storage()
@@ -33,6 +34,12 @@ class GoogleLLM(BaseLLM):
"image/jpg",
"image/webp",
"image/gif",
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
def prepare_messages_with_attachments(self, messages, attachments=None):
@@ -135,12 +142,42 @@ class GoogleLLM(BaseLLM):
raise
def _clean_messages_google(self, messages):
"""Convert OpenAI format messages to Google AI format."""
"""
Convert OpenAI format messages to Google AI format and collect system prompts.
Returns:
tuple[list[types.Content], Optional[str]]: cleaned messages and optional
combined system instruction.
"""
cleaned_messages = []
system_instructions = []
def _extract_system_text(content):
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for item in content:
if (
isinstance(item, dict)
and "text" in item
and item["text"] is not None
):
parts.append(item["text"])
return "\n".join(parts)
return ""
for message in messages:
role = message.get("role")
content = message.get("content")
# Gemini only accepts user/model in the contents list.
if role == "system":
sys_text = _extract_system_text(content)
if sys_text:
system_instructions.append(sys_text)
continue
if role == "assistant":
role = "model"
elif role == "tool":
@@ -159,12 +196,27 @@ class GoogleLLM(BaseLLM):
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
parts.append(
types.Part.from_function_call(
name=item["function_call"]["name"],
args=cleaned_args,
# Create function call part with thought_signature if present
# For Gemini 3 models, we need to include thought_signature
if "thought_signature" in item:
# Use Part constructor with functionCall and thoughtSignature
parts.append(
types.Part(
functionCall=types.FunctionCall(
name=item["function_call"]["name"],
args=cleaned_args,
),
thoughtSignature=item["thought_signature"],
)
)
else:
# Use helper method when no thought_signature
parts.append(
types.Part.from_function_call(
name=item["function_call"]["name"],
args=cleaned_args,
)
)
)
elif "function_response" in item:
parts.append(
types.Part.from_function_response(
@@ -188,7 +240,10 @@ class GoogleLLM(BaseLLM):
raise ValueError(f"Unexpected content type: {type(content)}")
if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
return cleaned_messages
system_instruction = (
"\n\n".join(system_instructions) if system_instructions else None
)
return cleaned_messages, system_instruction
def _clean_schema(self, schema_obj):
"""
@@ -274,6 +329,80 @@ class GoogleLLM(BaseLLM):
genai_tools.append(genai_tool)
return genai_tools
def _extract_preview_from_message(self, message):
"""Get a short, human-readable preview from the last message."""
try:
if hasattr(message, "parts"):
for part in reversed(message.parts):
if getattr(part, "text", None):
return part.text
function_call = getattr(part, "function_call", None)
if function_call:
name = getattr(function_call, "name", "") or "function_call"
return f"function_call:{name}"
function_response = getattr(part, "function_response", None)
if function_response:
name = (
getattr(function_response, "name", "")
or "function_response"
)
return f"function_response:{name}"
if isinstance(message, dict):
content = message.get("content")
if isinstance(content, str):
return content
if isinstance(content, list):
for item in reversed(content):
if isinstance(item, str):
return item
if isinstance(item, dict):
if item.get("text"):
return item["text"]
if item.get("function_call"):
fn = item["function_call"]
if isinstance(fn, dict):
name = fn.get("name") or "function_call"
return f"function_call:{name}"
return "function_call"
if item.get("function_response"):
resp = item["function_response"]
if isinstance(resp, dict):
name = resp.get("name") or "function_response"
return f"function_response:{name}"
return "function_response"
if "text" in message and isinstance(message["text"], str):
return message["text"]
except Exception:
pass
return str(message)
def _summarize_messages_for_log(self, messages, preview_chars=20):
"""Return a compact summary for logging to avoid huge payloads."""
message_count = len(messages) if messages else 0
last_preview = ""
if messages:
last_preview = self._extract_preview_from_message(messages[-1]) or ""
last_preview = str(last_preview).replace("\n", " ")
if len(last_preview) > preview_chars:
last_preview = f"{last_preview[:preview_chars]}..."
return f"count={message_count}, last='{last_preview}'"
@staticmethod
def _get_text_value(part):
"""Get text from both SDK objects and dict-shaped test doubles."""
if isinstance(part, dict):
value = part.get("text")
return value if isinstance(value, str) else ""
value = getattr(part, "text", None)
return value if isinstance(value, str) else ""
@staticmethod
def _is_thought_part(part):
"""Detect Gemini thinking parts when available."""
if isinstance(part, dict):
return bool(part.get("thought"))
return bool(getattr(part, "thought", False))
def _raw_gen(
self,
baseself,
@@ -287,12 +416,12 @@ class GoogleLLM(BaseLLM):
):
"""Generate content using Google AI API without streaming."""
client = genai.Client(api_key=self.api_key)
system_instruction = None
if formatting == "openai":
messages = self._clean_messages_google(messages)
messages, system_instruction = self._clean_messages_google(messages)
config = types.GenerateContentConfig()
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
if system_instruction:
config.system_instruction = system_instruction
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
@@ -325,16 +454,15 @@ class GoogleLLM(BaseLLM):
):
"""Generate content using Google AI API with streaming."""
client = genai.Client(api_key=self.api_key)
system_instruction = None
if formatting == "openai":
messages = self._clean_messages_google(messages)
messages, system_instruction = self._clean_messages_google(messages)
config = types.GenerateContentConfig()
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
if system_instruction:
config.system_instruction = system_instruction
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
# Add response schema for structured output if provided
if response_schema:
config.response_schema = response_schema
@@ -349,8 +477,12 @@ class GoogleLLM(BaseLLM):
break
if has_attachments:
break
messages_summary = self._summarize_messages_for_log(messages)
logging.info(
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
"GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
model,
messages_summary,
has_attachments,
)
response = client.models.generate_content_stream(
@@ -367,10 +499,26 @@ class GoogleLLM(BaseLLM):
for part in candidate.content.parts:
if part.function_call:
yield part
elif part.text:
yield part.text
continue
part_text = self._get_text_value(part)
if not part_text:
continue
if self._is_thought_part(part):
yield {"type": "thought", "thought": part_text}
else:
yield part_text
elif hasattr(chunk, "text"):
yield chunk.text
chunk_text = self._get_text_value(chunk)
if chunk_text:
if self._is_thought_part(chunk):
yield {"type": "thought", "thought": chunk_text}
else:
yield chunk_text
except Exception as e:
logging.error(f"GoogleLLM: Stream error: {e}", exc_info=True)
raise
finally:
if hasattr(response, "close"):
response.close()

View File

@@ -1,37 +1,15 @@
from openai import OpenAI
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.llm.openai import OpenAILLM
GROQ_BASE_URL = "https://api.groq.com/openai/v1"
class GroqLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = api_key or settings.GROQ_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
self.client = OpenAI(
api_key=self.api_key, base_url="https://api.groq.com/openai/v1"
class GroqLLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.GROQ_API_KEY or settings.API_KEY,
user_api_key=user_api_key,
base_url=base_url or GROQ_BASE_URL,
*args,
**kwargs,
)
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self, baseself, model, messages, stream=True, tools=None, **kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
for line in response:
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content

View File

@@ -1,4 +1,5 @@
import logging
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Union
@@ -16,6 +17,7 @@ class ToolCall:
name: str
arguments: Union[str, Dict]
index: Optional[int] = None
thought_signature: Optional[str] = None
@classmethod
def from_dict(cls, data: Dict) -> "ToolCall":
@@ -103,6 +105,7 @@ class LLMHandler(ABC):
"""
Prepare messages with attachments and provider-specific formatting.
Args:
agent: The agent instance
messages: Original messages
@@ -116,11 +119,40 @@ class LLMHandler(ABC):
logger.info(f"Preparing messages with {len(attachments)} attachments")
supported_types = agent.llm.get_supported_attachment_types()
# Check if provider supports images but not PDF (synthetic PDF support)
supports_images = any(t.startswith("image/") for t in supported_types)
supports_pdf = "application/pdf" in supported_types
# Process attachments, converting PDFs to images if needed
processed_attachments = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
# Synthetic PDF support: convert PDF to images if LLM supports images but not PDF
if mime_type == "application/pdf" and supports_images and not supports_pdf:
logger.info(
f"Converting PDF to images for synthetic PDF support: {attachment.get('path', 'unknown')}"
)
try:
converted_images = self._convert_pdf_to_images(attachment)
processed_attachments.extend(converted_images)
logger.info(
f"Converted PDF to {len(converted_images)} images"
)
except Exception as e:
logger.error(
f"Failed to convert PDF to images, falling back to text: {e}"
)
# Fall back to treating as unsupported (text extraction)
processed_attachments.append(attachment)
else:
processed_attachments.append(attachment)
supported_attachments = [
a for a in attachments if a.get("mime_type") in supported_types
a for a in processed_attachments if a.get("mime_type") in supported_types
]
unsupported_attachments = [
a for a in attachments if a.get("mime_type") not in supported_types
a for a in processed_attachments if a.get("mime_type") not in supported_types
]
# Process supported attachments with the LLM's custom method
@@ -143,6 +175,37 @@ class LLMHandler(ABC):
)
return messages
def _convert_pdf_to_images(self, attachment: Dict) -> List[Dict]:
"""
Convert a PDF attachment to a list of image attachments.
This enables synthetic PDF support for LLMs that support images but not PDFs.
Args:
attachment: PDF attachment dictionary with 'path' and optional 'content'
Returns:
List of image attachment dictionaries with 'data', 'mime_type', and 'page'
"""
from application.utils import convert_pdf_to_images
from application.storage.storage_creator import StorageCreator
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in PDF attachment")
storage = StorageCreator.get_storage()
# Convert PDF to images
images_data = convert_pdf_to_images(
file_path=file_path,
storage=storage,
max_pages=20,
dpi=150,
)
return images_data
def _append_unsupported_attachments(
self, messages: List[Dict], attachments: List[Dict]
) -> List[Dict]:
@@ -178,6 +241,407 @@ class LLMHandler(ABC):
system_msg["content"] += f"\n\n{combined_text}"
return prepared_messages
def _prune_messages_minimal(self, messages: List[Dict]) -> Optional[List[Dict]]:
"""
Build a minimal context: system prompt + latest user message only.
Drops all tool/function messages to shrink context aggressively.
"""
system_message = next((m for m in messages if m.get("role") == "system"), None)
if not system_message:
logger.warning("Cannot prune messages minimally: missing system message.")
return None
last_non_system = None
for m in reversed(messages):
if m.get("role") == "user":
last_non_system = m
break
if not last_non_system and m.get("role") not in ("system", None):
last_non_system = m
if not last_non_system:
logger.warning("Cannot prune messages minimally: missing user/assistant messages.")
return None
logger.info("Pruning context to system + latest user/assistant message to proceed.")
return [system_message, last_non_system]
def _extract_text_from_content(self, content: Any) -> str:
"""
Convert message content (str or list of parts) to plain text for compression.
"""
if isinstance(content, str):
return content
if isinstance(content, list):
parts_text = []
for item in content:
if isinstance(item, dict):
if "text" in item and item["text"] is not None:
parts_text.append(str(item["text"]))
elif "function_call" in item or "function_response" in item:
# Keep serialized function calls/responses so the compressor sees actions
parts_text.append(str(item))
elif "files" in item:
parts_text.append(str(item))
return "\n".join(parts_text)
return ""
def _build_conversation_from_messages(self, messages: List[Dict]) -> Optional[Dict]:
"""
Build a conversation-like dict from current messages so we can compress
even when the conversation isn't persisted yet. Includes tool calls/results.
"""
queries = []
current_prompt = None
current_tool_calls = {}
def _commit_query(response_text: str):
nonlocal current_prompt, current_tool_calls
if current_prompt is None and not response_text:
return
tool_calls_list = list(current_tool_calls.values())
queries.append(
{
"prompt": current_prompt or "",
"response": response_text,
"tool_calls": tool_calls_list,
}
)
current_prompt = None
current_tool_calls = {}
for message in messages:
role = message.get("role")
content = message.get("content")
if role == "user":
current_prompt = self._extract_text_from_content(content)
elif role in {"assistant", "model"}:
# If this assistant turn contains tool calls, collect them; otherwise commit a response.
if isinstance(content, list):
for item in content:
if "function_call" in item:
fc = item["function_call"]
call_id = fc.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": fc.get("name"),
"arguments": fc.get("args"),
"result": None,
"status": "called",
"call_id": call_id,
}
elif "function_response" in item:
fr = item["function_response"]
call_id = fr.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": fr.get("name"),
"arguments": None,
"result": fr.get("response", {}).get("result"),
"status": "completed",
"call_id": call_id,
}
# No direct assistant text here; continue to next message
continue
response_text = self._extract_text_from_content(content)
_commit_query(response_text)
elif role == "tool":
# Attach tool outputs to the latest pending tool call if possible
tool_text = self._extract_text_from_content(content)
# Attempt to parse function_response style
call_id = None
if isinstance(content, list):
for item in content:
if "function_response" in item and item["function_response"].get("call_id"):
call_id = item["function_response"]["call_id"]
break
if call_id and call_id in current_tool_calls:
current_tool_calls[call_id]["result"] = tool_text
current_tool_calls[call_id]["status"] = "completed"
elif queries:
queries[-1].setdefault("tool_calls", []).append(
{
"tool_name": "unknown_tool",
"action_name": "unknown_action",
"arguments": {},
"result": tool_text,
"status": "completed",
}
)
# If there's an unfinished prompt with tool_calls but no response yet, commit it
if current_prompt is not None or current_tool_calls:
_commit_query(response_text="")
if not queries:
return None
return {
"queries": queries,
"compression_metadata": {
"is_compressed": False,
"compression_points": [],
},
}
def _rebuild_messages_after_compression(
self,
messages: List[Dict],
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_current_execution: bool = False,
include_tool_calls: bool = False,
) -> Optional[List[Dict]]:
"""
Rebuild the message list after compression so tool execution can continue.
Delegates to MessageBuilder for the actual reconstruction.
"""
from application.api.answer.services.compression.message_builder import (
MessageBuilder,
)
return MessageBuilder.rebuild_messages_after_compression(
messages=messages,
compressed_summary=compressed_summary,
recent_queries=recent_queries,
include_current_execution=include_current_execution,
include_tool_calls=include_tool_calls,
)
def _perform_mid_execution_compression(
self, agent, messages: List[Dict]
) -> tuple[bool, Optional[List[Dict]]]:
"""
Perform compression during tool execution and rebuild messages.
Uses the new orchestrator for simplified compression.
Args:
agent: The agent instance
messages: Current conversation messages
Returns:
(success: bool, rebuilt_messages: Optional[List[Dict]])
"""
try:
from application.api.answer.services.compression import (
CompressionOrchestrator,
)
from application.api.answer.services.conversation_service import (
ConversationService,
)
conversation_service = ConversationService()
orchestrator = CompressionOrchestrator(conversation_service)
# Get conversation from database (may be None for new sessions)
conversation = conversation_service.get_conversation(
agent.conversation_id, agent.initial_user_id
)
if conversation:
# Merge current in-flight messages (including tool calls)
conversation_from_msgs = self._build_conversation_from_messages(messages)
if conversation_from_msgs:
conversation = conversation_from_msgs
else:
logger.warning(
"Could not load conversation for compression; attempting in-memory compression"
)
return self._perform_in_memory_compression(agent, messages)
# Use orchestrator to perform compression
result = orchestrator.compress_mid_execution(
conversation_id=agent.conversation_id,
user_id=agent.initial_user_id,
model_id=agent.model_id,
decoded_token=getattr(agent, "decoded_token", {}),
current_conversation=conversation,
)
if not result.success:
logger.warning(f"Mid-execution compression failed: {result.error}")
# Try minimal pruning as fallback
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
if not result.compression_performed:
logger.warning("Compression not performed")
return False, None
# Check if compression actually reduced tokens
if result.metadata:
if result.metadata.compressed_token_count >= result.metadata.original_token_count:
logger.warning(
"Compression did not reduce token count; falling back to minimal pruning"
)
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
logger.info(
f"Mid-execution compression successful - ratio: {result.metadata.compression_ratio:.1f}x, "
f"saved {result.metadata.original_token_count - result.metadata.compressed_token_count} tokens"
)
# Also store the compression summary as a visible message
if result.metadata:
conversation_service.append_compression_message(
agent.conversation_id, result.metadata.to_dict()
)
# Update agent's compressed summary for downstream persistence
agent.compressed_summary = result.compressed_summary
agent.compression_metadata = result.metadata.to_dict() if result.metadata else None
agent.compression_saved = False
# Reset the context limit flag so tools can continue
agent.context_limit_reached = False
agent.current_token_count = 0
# Rebuild messages
rebuilt_messages = self._rebuild_messages_after_compression(
messages,
result.compressed_summary,
result.recent_queries,
include_current_execution=False,
include_tool_calls=False,
)
if rebuilt_messages is None:
return False, None
return True, rebuilt_messages
except Exception as e:
logger.error(
f"Error performing mid-execution compression: {str(e)}", exc_info=True
)
return False, None
def _perform_in_memory_compression(
self, agent, messages: List[Dict]
) -> tuple[bool, Optional[List[Dict]]]:
"""
Fallback compression path when the conversation is not yet persisted.
Uses CompressionService directly without DB persistence.
"""
try:
from application.api.answer.services.compression.service import (
CompressionService,
)
from application.core.model_utils import (
get_api_key_for_provider,
get_provider_from_model_id,
)
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
conversation = self._build_conversation_from_messages(messages)
if not conversation:
logger.warning(
"Cannot perform in-memory compression: no user/assistant turns found"
)
return False, None
compression_model = (
settings.COMPRESSION_MODEL_OVERRIDE
if settings.COMPRESSION_MODEL_OVERRIDE
else agent.model_id
)
provider = get_provider_from_model_id(compression_model)
api_key = get_api_key_for_provider(provider)
compression_llm = LLMCreator.create_llm(
provider,
api_key,
getattr(agent, "user_api_key", None),
getattr(agent, "decoded_token", None),
model_id=compression_model,
agent_id=getattr(agent, "agent_id", None),
)
# Create service without DB persistence capability
compression_service = CompressionService(
llm=compression_llm,
model_id=compression_model,
conversation_service=None, # No DB updates for in-memory
)
queries_count = len(conversation.get("queries", []))
compress_up_to = queries_count - 1
if compress_up_to < 0 or queries_count == 0:
logger.warning("Not enough queries to compress in-memory context")
return False, None
metadata = compression_service.compress_conversation(
conversation,
compress_up_to_index=compress_up_to,
)
# If compression doesn't reduce tokens, fall back to minimal pruning
if (
metadata.compressed_token_count
>= metadata.original_token_count
):
logger.warning(
"In-memory compression did not reduce token count; falling back to minimal pruning"
)
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
# Attach metadata to synthetic conversation
conversation["compression_metadata"] = {
"is_compressed": True,
"compression_points": [metadata.to_dict()],
}
compressed_summary, recent_queries = (
compression_service.get_compressed_context(conversation)
)
agent.compressed_summary = compressed_summary
agent.compression_metadata = metadata.to_dict()
agent.compression_saved = False
agent.context_limit_reached = False
agent.current_token_count = 0
rebuilt_messages = self._rebuild_messages_after_compression(
messages,
compressed_summary,
recent_queries,
include_current_execution=False,
include_tool_calls=False,
)
if rebuilt_messages is None:
return False, None
logger.info(
f"In-memory compression successful - ratio: {metadata.compression_ratio:.1f}x, "
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
return True, rebuilt_messages
except Exception as e:
logger.error(
f"Error performing in-memory compression: {str(e)}", exc_info=True
)
return False, None
def handle_tool_calls(
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
) -> Generator:
@@ -195,7 +659,110 @@ class LLMHandler(ABC):
"""
updated_messages = messages.copy()
for call in tool_calls:
for i, call in enumerate(tool_calls):
# Check context limit before executing tool call
if hasattr(agent, '_check_context_limit') and agent._check_context_limit(updated_messages):
# Context limit reached - attempt mid-execution compression
compression_attempted = False
compression_successful = False
try:
from application.core.settings import settings
compression_enabled = settings.ENABLE_CONVERSATION_COMPRESSION
except Exception:
compression_enabled = False
if compression_enabled:
compression_attempted = True
try:
logger.info(
f"Context limit reached with {len(tool_calls) - i} remaining tool calls. "
f"Attempting mid-execution compression..."
)
# Trigger mid-execution compression (DB-backed if available, otherwise in-memory)
compression_successful, rebuilt_messages = self._perform_mid_execution_compression(
agent, updated_messages
)
if compression_successful and rebuilt_messages is not None:
# Update the messages list with rebuilt compressed version
updated_messages = rebuilt_messages
# Yield compression success message
yield {
"type": "info",
"data": {
"message": "Context window limit reached. Compressed conversation history to continue processing."
}
}
logger.info(
f"Mid-execution compression successful. Continuing with {len(tool_calls) - i} remaining tool calls."
)
# Proceed to execute the current tool call with the reduced context
else:
logger.warning("Mid-execution compression attempted but failed. Skipping remaining tools.")
except Exception as e:
logger.error(f"Error during mid-execution compression: {str(e)}", exc_info=True)
compression_attempted = True
compression_successful = False
# If compression wasn't attempted or failed, skip remaining tools
if not compression_successful:
if i == 0:
# Special case: limit reached before executing any tools
# This can happen when previous tool responses pushed context over limit
if compression_attempted:
logger.warning(
f"Context limit reached before executing any tools. "
f"Compression attempted but failed. "
f"Skipping all {len(tool_calls)} pending tool call(s). "
f"This typically occurs when previous tool responses contained large amounts of data."
)
else:
logger.warning(
f"Context limit reached before executing any tools. "
f"Skipping all {len(tool_calls)} pending tool call(s). "
f"This typically occurs when previous tool responses contained large amounts of data. "
f"Consider enabling compression or using a model with larger context window."
)
else:
# Normal case: executed some tools, now stopping
tool_word = "tool call" if i == 1 else "tool calls"
remaining = len(tool_calls) - i
remaining_word = "tool call" if remaining == 1 else "tool calls"
if compression_attempted:
logger.warning(
f"Context limit reached after executing {i} {tool_word}. "
f"Compression attempted but failed. "
f"Skipping remaining {remaining} {remaining_word}."
)
else:
logger.warning(
f"Context limit reached after executing {i} {tool_word}. "
f"Skipping remaining {remaining} {remaining_word}. "
f"Consider enabling compression or using a model with larger context window."
)
# Mark remaining tools as skipped
for remaining_call in tool_calls[i:]:
skip_message = {
"type": "tool_call",
"data": {
"tool_name": "system",
"call_id": remaining_call.id,
"action_name": remaining_call.name,
"arguments": {},
"result": "Skipped: Context limit reached. Too many tool calls in conversation.",
"status": "skipped"
}
}
yield skip_message
# Set flag on agent
agent.context_limit_reached = True
break
try:
self.tool_calls.append(call)
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
@@ -205,21 +772,26 @@ class LLMHandler(ABC):
except StopIteration as e:
tool_response, call_id = e.value
break
function_call_content = {
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
# Include thought_signature for Google Gemini 3 models
# It should be at the same level as function_call, not inside it
if call.thought_signature:
function_call_content["thought_signature"] = call.thought_signature
updated_messages.append(
{
"role": "assistant",
"content": [
{
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
],
"content": [function_call_content],
}
)
updated_messages.append(self.create_tool_message(call, tool_response))
except Exception as e:
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
@@ -307,6 +879,9 @@ class LLMHandler(ABC):
tool_calls = {}
for chunk in self._iterate_stream(response):
if isinstance(chunk, dict) and chunk.get("type") == "thought":
yield chunk
continue
if isinstance(chunk, str):
yield chunk
continue
@@ -323,7 +898,13 @@ class LLMHandler(ABC):
if call.name:
existing.name = call.name
if call.arguments:
existing.arguments += call.arguments
if existing.arguments is None:
existing.arguments = call.arguments
else:
existing.arguments += call.arguments
# Preserve thought_signature for Google Gemini 3 models
if call.thought_signature:
existing.thought_signature = call.thought_signature
if parsed.finish_reason == "tool_calls":
tool_handler_gen = self.handle_tool_calls(
agent, list(tool_calls.values()), tools_dict, messages
@@ -336,8 +917,21 @@ class LLMHandler(ABC):
break
tool_calls = {}
# Check if context limit was reached during tool execution
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
# Add system message warning about context limit
messages.append({
"role": "system",
"content": (
"WARNING: Context window limit has been reached. "
"Please provide a final response to the user without making additional tool calls. "
"Summarize the work completed so far."
)
})
logger.info("Context limit reached - instructing agent to wrap up")
response = agent.llm.gen_stream(
model=agent.model_id, messages=messages, tools=agent.tools
model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None
)
self.llm_calls.append(build_stack_data(agent.llm))

View File

@@ -19,15 +19,20 @@ class GoogleLLMHandler(LLMHandler):
)
if hasattr(response, "candidates"):
parts = response.candidates[0].content.parts if response.candidates else []
tool_calls = [
ToolCall(
id=str(uuid.uuid4()),
name=part.function_call.name,
arguments=part.function_call.args,
)
for part in parts
if hasattr(part, "function_call") and part.function_call is not None
]
tool_calls = []
for idx, part in enumerate(parts):
if hasattr(part, "function_call") and part.function_call is not None:
has_sig = hasattr(part, "thought_signature") and part.thought_signature is not None
thought_sig = part.thought_signature if has_sig else None
tool_calls.append(
ToolCall(
id=str(uuid.uuid4()),
name=part.function_call.name,
arguments=part.function_call.args,
index=idx,
thought_signature=thought_sig,
)
)
content = " ".join(
part.text
@@ -41,13 +46,17 @@ class GoogleLLMHandler(LLMHandler):
raw_response=response,
)
else:
# This branch handles individual Part objects from streaming responses
tool_calls = []
if hasattr(response, "function_call"):
if hasattr(response, "function_call") and response.function_call is not None:
has_sig = hasattr(response, "thought_signature") and response.thought_signature is not None
thought_sig = response.thought_signature if has_sig else None
tool_calls.append(
ToolCall(
id=str(uuid.uuid4()),
name=response.function_call.name,
arguments=response.function_call.args,
thought_signature=thought_sig,
)
)
return LLMResponse(

View File

@@ -7,6 +7,7 @@ class LLMHandlerCreator:
handlers = {
"openai": OpenAILLMHandler,
"google": GoogleLLMHandler,
"novita": OpenAILLMHandler, # Novita uses OpenAI-compatible API
"default": OpenAILLMHandler,
}

View File

@@ -1,68 +0,0 @@
from application.llm.base import BaseLLM
class HuggingFaceLLM(BaseLLM):
def __init__(
self,
api_key=None,
user_api_key=None,
llm_name="Arc53/DocsGPT-7B",
q=False,
*args,
**kwargs,
):
global hf
from langchain.llms import HuggingFacePipeline
if q:
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
BitsAndBytesConfig,
)
tokenizer = AutoTokenizer.from_pretrained(llm_name)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
llm_name, quantization_config=bnb_config
)
else:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
tokenizer = AutoTokenizer.from_pretrained(llm_name)
model = AutoModelForCausalLM.from_pretrained(llm_name)
super().__init__(*args, **kwargs)
self.api_key = api_key
self.user_api_key = user_api_key
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=2000,
device_map="auto",
eos_token_id=tokenizer.eos_token_id,
)
hf = HuggingFacePipeline(pipeline=pipe)
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
result = hf(prompt)
return result.content
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
raise NotImplementedError("HuggingFaceLLM Streaming is not implemented yet.")

Some files were not shown because too many files have changed in this diff Show More