Compare commits

..

427 Commits

Author SHA1 Message Date
Alex
3352d42414 fix(frontend): use bracket notation for tool variable paths (#2176) 2025-11-26 19:12:02 +02:00
JustACodeA
899b30da5e feat: add German translation (#2170)
Adds complete German (Deutsch) language support to DocsGPT.

Changes:
- Add de.json with full German translations
- Register German in i18n configuration
- Add German to language selector dropdown

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-26 12:52:23 +02:00
Alex
dc2faf7a7e fix: webhooks (#2175) 2025-11-25 16:08:22 +02:00
Alex
67e0d222d1 fix: model in agents via api (#2174) 2025-11-25 13:54:34 +02:00
Alex
17698ce774 feat: context compression (#2173)
* feat: context compression

* fix: ruff
2025-11-24 12:44:19 +02:00
Alex
7d1c8c008b Update README.md 2025-11-22 16:42:25 +02:00
Alex
9e58eb02b3 Update .env.development 2025-11-14 19:53:19 +02:00
Siddhant Rai
3f7de867cc feat: model registry and capabilities for multi-provider support (#2158)
* feat: Implement model registry and capabilities for multi-provider support

- Added ModelRegistry to manage available models and their capabilities.
- Introduced ModelProvider enum for different LLM providers.
- Created ModelCapabilities dataclass to define model features.
- Implemented methods to load models based on API keys and settings.
- Added utility functions for model management in model_utils.py.
- Updated settings.py to include provider-specific API keys.
- Refactored LLM classes (Anthropic, OpenAI, Google, etc.) to utilize new model registry.
- Enhanced utility functions to handle token limits and model validation.
- Improved code structure and logging for better maintainability.

* feat: Add model selection feature with API integration and UI component

* feat: Add model selection and default model functionality in agent management

* test: Update assertions and formatting in stream processing tests

* refactor(llm): Standardize model identifier to model_id

* fix tests

---------

Co-authored-by: Alex <a@tushynski.me>
2025-11-14 13:13:19 +02:00
Manish Madan
fbf7cf874b chore(dependabot): add react-widget npm dependency updates (#2146) 2025-11-07 17:17:46 +02:00
Manish Madan
ba7278b80f Merge pull request #2140 from arc53/dependabot/npm_and_yarn/frontend/husky-9.1.7
chore(deps-dev): bump husky from 8.0.3 to 9.1.7 in /frontend
2025-11-07 03:02:52 +05:30
ManishMadan2882
9d649de6f9 chore(eslint): migrate to ESLint 9 flat config format
- Add eslint.config.js with ESLint 9 flat config format
- Remove deprecated .eslintrc.cjs file
- Remove deprecated .eslintignore file (replaced by ignores in config)
- Maintain all existing ESLint rules and configurations
- Ensure compatibility with Husky 9.1.7
2025-11-07 02:59:51 +05:30
dependabot[bot]
7929afbf58 chore(deps-dev): bump husky from 8.0.3 to 9.1.7 in /frontend
Bumps [husky](https://github.com/typicode/husky) from 8.0.3 to 9.1.7.
- [Release notes](https://github.com/typicode/husky/releases)
- [Commits](https://github.com/typicode/husky/compare/v8.0.3...v9.1.7)

---
updated-dependencies:
- dependency-name: husky
  dependency-version: 9.1.7
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-06 21:27:39 +00:00
Manish Madan
ceaf942e70 Merge pull request #2139 from arc53/dependabot/npm_and_yarn/frontend/eslint-9.39.1
chore(deps-dev): bump eslint from 8.57.1 to 9.39.1 in /frontend
2025-11-07 02:33:32 +05:30
dependabot[bot]
f355601a44 chore(deps-dev): bump eslint from 8.57.1 to 9.39.1 in /frontend
Bumps [eslint](https://github.com/eslint/eslint) from 8.57.1 to 9.39.1.
- [Release notes](https://github.com/eslint/eslint/releases)
- [Commits](https://github.com/eslint/eslint/compare/v8.57.1...v9.39.1)

---
updated-dependencies:
- dependency-name: eslint
  dependency-version: 9.39.1
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-06 20:00:14 +00:00
Manish Madan
4ff99a1e86 Merge pull request #2138 from arc53/dependabot/npm_and_yarn/frontend/reduxjs/toolkit-2.10.1
chore(deps): bump @reduxjs/toolkit from 2.9.2 to 2.10.1 in /frontend
2025-11-07 01:28:58 +05:30
dependabot[bot]
129084ba92 chore(deps): bump @reduxjs/toolkit from 2.9.2 to 2.10.1 in /frontend
Bumps [@reduxjs/toolkit](https://github.com/reduxjs/redux-toolkit) from 2.9.2 to 2.10.1.
- [Release notes](https://github.com/reduxjs/redux-toolkit/releases)
- [Commits](https://github.com/reduxjs/redux-toolkit/compare/v2.9.2...v2.10.1)

---
updated-dependencies:
- dependency-name: "@reduxjs/toolkit"
  dependency-version: 2.10.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-06 19:56:28 +00:00
Manish Madan
2288df1293 Merge pull request #2141 from arc53/dependabot/npm_and_yarn/frontend/vite-7.2.0
chore(deps-dev): bump vite from 7.1.12 to 7.2.0 in /frontend
2025-11-07 01:05:29 +05:30
Manish Madan
d9dfac55e7 Merge pull request #2134 from arc53/dependabot/npm_and_yarn/frontend/types/mermaid-9.2.0
chore(deps-dev): bump @types/mermaid from 9.1.0 to 9.2.0 in /frontend
2025-11-06 17:46:59 +05:30
Nick
404cf4b7c7 Update quickstart.mdx (#2142)
Added missing **
2025-11-06 12:37:27 +02:00
dependabot[bot]
f1c1fc123b chore(deps-dev): bump vite from 7.1.12 to 7.2.0 in /frontend
Bumps [vite](https://github.com/vitejs/vite/tree/HEAD/packages/vite) from 7.1.12 to 7.2.0.
- [Release notes](https://github.com/vitejs/vite/releases)
- [Changelog](https://github.com/vitejs/vite/blob/main/packages/vite/CHANGELOG.md)
- [Commits](https://github.com/vitejs/vite/commits/v7.2.0/packages/vite)

---
updated-dependencies:
- dependency-name: vite
  dependency-version: 7.2.0
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-05 20:08:29 +00:00
ManishMadan2882
9f19c7ee4c Remove deprecated @types/mermaid dependency - mermaid provides its own types 2025-11-05 20:43:47 +05:30
dependabot[bot]
155e74eca1 chore(deps-dev): bump @types/mermaid from 9.1.0 to 9.2.0 in /frontend
Bumps [@types/mermaid](https://github.com/DefinitelyTyped/DefinitelyTyped/tree/HEAD/types/mermaid) from 9.1.0 to 9.2.0.
- [Release notes](https://github.com/DefinitelyTyped/DefinitelyTyped/releases)
- [Commits](https://github.com/DefinitelyTyped/DefinitelyTyped/commits/HEAD/types/mermaid)

---
updated-dependencies:
- dependency-name: "@types/mermaid"
  dependency-version: 9.2.0
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-05 15:10:19 +00:00
Manish Madan
ea2dc4dbcb Merge pull request #2133 from arc53/dependabot/npm_and_yarn/frontend/react-i18next-16.2.4
chore(deps): bump react-i18next from 15.7.4 to 16.2.4 in /frontend
2025-11-05 20:23:15 +05:30
dependabot[bot]
616edc97de chore(deps): bump react-i18next from 15.7.4 to 16.2.4 in /frontend
Bumps [react-i18next](https://github.com/i18next/react-i18next) from 15.7.4 to 16.2.4.
- [Changelog](https://github.com/i18next/react-i18next/blob/master/CHANGELOG.md)
- [Commits](https://github.com/i18next/react-i18next/compare/v15.7.4...v16.2.4)

---
updated-dependencies:
- dependency-name: react-i18next
  dependency-version: 16.2.4
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-05 14:48:28 +00:00
Manish Madan
b017e99c79 Merge pull request #2132 from arc53/dependabot/npm_and_yarn/frontend/eslint-plugin-n-17.23.1
chore(deps-dev): bump eslint-plugin-n from 16.6.2 to 17.23.1 in /frontend
2025-11-05 20:14:18 +05:30
dependabot[bot]
f698e9d3e1 chore(deps-dev): bump eslint-plugin-n in /frontend
Bumps [eslint-plugin-n](https://github.com/eslint-community/eslint-plugin-n) from 16.6.2 to 17.23.1.
- [Release notes](https://github.com/eslint-community/eslint-plugin-n/releases)
- [Changelog](https://github.com/eslint-community/eslint-plugin-n/blob/master/CHANGELOG.md)
- [Commits](https://github.com/eslint-community/eslint-plugin-n/compare/16.6.2...v17.23.1)

---
updated-dependencies:
- dependency-name: eslint-plugin-n
  dependency-version: 17.23.1
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-05 14:35:17 +00:00
Manish Madan
d366502850 Merge pull request #2131 from arc53/dependabot/npm_and_yarn/frontend/typescript-eslint/parser-8.46.3
chore(deps-dev): bump @typescript-eslint/parser from 6.21.0 to 8.46.3 in /frontend
2025-11-05 20:03:59 +05:30
ManishMadan2882
3d6757c170 (chore:lint) relax rules, build fix 2025-11-05 20:02:01 +05:30
Manish Madan
cb8302add8 Fixes shared conversation on cloud version (#2135)
* (fix:shared) conv as id, not dbref

* (chore) script to migrate dbref to id

* (chore): ruff fix

---------

Co-authored-by: GH Action - Upstream Sync <action@github.com>
2025-11-05 16:08:10 +02:00
dependabot[bot]
9d266e9fad chore(deps-dev): bump @typescript-eslint/parser in /frontend
Bumps [@typescript-eslint/parser](https://github.com/typescript-eslint/typescript-eslint/tree/HEAD/packages/parser) from 6.21.0 to 8.46.3.
- [Release notes](https://github.com/typescript-eslint/typescript-eslint/releases)
- [Changelog](https://github.com/typescript-eslint/typescript-eslint/blob/main/packages/parser/CHANGELOG.md)
- [Commits](https://github.com/typescript-eslint/typescript-eslint/commits/v8.46.3/packages/parser)

---
updated-dependencies:
- dependency-name: "@typescript-eslint/parser"
  dependency-version: 8.46.3
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-05 13:45:18 +00:00
Manish Madan
ae94c9d31e Merge pull request #2130 from arc53/dependabot/npm_and_yarn/frontend/vite-7.1.12
chore(deps-dev): bump vite from 6.4.1 to 7.1.12 in /frontend
2025-11-05 19:13:59 +05:30
ManishMadan2882
83ab232dcd (chore:fe) pkg lock 2025-11-05 19:12:20 +05:30
dependabot[bot]
eea85772a3 chore(deps-dev): bump vite from 6.4.1 to 7.1.12 in /frontend
Bumps [vite](https://github.com/vitejs/vite/tree/HEAD/packages/vite) from 6.4.1 to 7.1.12.
- [Release notes](https://github.com/vitejs/vite/releases)
- [Changelog](https://github.com/vitejs/vite/blob/v7.1.12/packages/vite/CHANGELOG.md)
- [Commits](https://github.com/vitejs/vite/commits/v7.1.12/packages/vite)

---
updated-dependencies:
- dependency-name: vite
  dependency-version: 7.1.12
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-05 19:10:27 +05:30
Alex
0fe7e223cc fix: update Discord invite link across documentation and navigation 2025-11-04 09:27:22 +00:00
Heisenberg Vader
3789d2eb03 Updated the technique for handling multiple file uploads from the user (#2126)
* Fixed multiple file uploads to be sent through a single request to backend for further processing and storing

* Fixed multiple file uploads to be sent through a single request to backend for further processing and storing

* Fixed multiple file uploads to be sent through a single request to backend for further processing and storing

* Made duplicate multiple keyword fixes

* Added back drag and drop functionality and it keeps the multiple file uploads
2025-11-04 01:12:35 +02:00
Manish Madan
d54469532e fix: adjust ESLint rules to warnings for strict type checking (#2129)
- Changed @typescript-eslint/no-explicit-any from error to warning
- Changed @typescript-eslint/no-unused-vars from error to warning
- Allows codebase to pass linting while maintaining code quality checks
- These rules can be gradually enforced as code is refactored
- Verified with npm run build - successful
2025-11-04 01:09:39 +02:00
Manish Madan
9884e51836 Merge pull request #2122 from arc53/dependabot/npm_and_yarn/frontend/prettier-plugin-tailwindcss-0.7.1
chore(deps-dev): bump prettier-plugin-tailwindcss from 0.6.13 to 0.7.1 in /frontend
2025-11-03 19:31:30 +05:30
Alex
6626723180 feat: enhance prompt variable handling and add system variable options in prompts modal (#2128) 2025-11-03 15:54:13 +02:00
Manish Madan
0c251e066b Merge pull request #2124 from arc53/dependabot/npm_and_yarn/frontend/eslint-plugin-n-17.23.1
chore(deps-dev): bump eslint-plugin-n from 15.7.0 to 17.23.1 in /frontend
2025-11-03 19:22:22 +05:30
dependabot[bot]
0957034bfa chore(deps-dev): bump prettier-plugin-tailwindcss in /frontend
Bumps [prettier-plugin-tailwindcss](https://github.com/tailwindlabs/prettier-plugin-tailwindcss) from 0.6.13 to 0.7.1.
- [Release notes](https://github.com/tailwindlabs/prettier-plugin-tailwindcss/releases)
- [Changelog](https://github.com/tailwindlabs/prettier-plugin-tailwindcss/blob/main/CHANGELOG.md)
- [Commits](https://github.com/tailwindlabs/prettier-plugin-tailwindcss/compare/v0.6.13...v0.7.1)

---
updated-dependencies:
- dependency-name: prettier-plugin-tailwindcss
  dependency-version: 0.7.1
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-03 13:49:34 +00:00
ManishMadan2882
44521cd893 fix: resolve peer dependency conflict with eslint-plugin-n
- Downgrade eslint-plugin-n from ^17.23.1 to ^16.6.2
- Ensure compatibility with eslint-config-standard-with-typescript@43.0.1
- eslint-config-standard-with-typescript requires eslint-plugin-n@^15.0.0 || ^16.0.0
- Verified with successful npm install and vite build
2025-11-03 19:19:02 +05:30
dependabot[bot]
b17f846730 chore(deps-dev): bump eslint-plugin-n in /frontend
Bumps [eslint-plugin-n](https://github.com/eslint-community/eslint-plugin-n) from 15.7.0 to 17.23.1.
- [Release notes](https://github.com/eslint-community/eslint-plugin-n/releases)
- [Changelog](https://github.com/eslint-community/eslint-plugin-n/blob/master/CHANGELOG.md)
- [Commits](https://github.com/eslint-community/eslint-plugin-n/compare/15.7.0...v17.23.1)

---
updated-dependencies:
- dependency-name: eslint-plugin-n
  dependency-version: 17.23.1
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-03 13:45:27 +00:00
Manish Madan
6dd32fd4ca Merge pull request #2111 from arc53/dependabot/npm_and_yarn/frontend/mermaid-11.12.1
chore(deps): bump mermaid from 11.12.0 to 11.12.1 in /frontend
2025-11-03 19:14:00 +05:30
dependabot[bot]
b17b1c70b5 chore(deps): bump mermaid from 11.12.0 to 11.12.1 in /frontend
Bumps [mermaid](https://github.com/mermaid-js/mermaid) from 11.12.0 to 11.12.1.
- [Release notes](https://github.com/mermaid-js/mermaid/releases)
- [Commits](https://github.com/mermaid-js/mermaid/compare/mermaid@11.12.0...mermaid@11.12.1)

---
updated-dependencies:
- dependency-name: mermaid
  dependency-version: 11.12.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-03 19:07:58 +05:30
Manish Madan
3f5b31fb5f Merge pull request #2084 from arc53/dependabot/npm_and_yarn/frontend/typescript-eslint/eslint-plugin-8.46.2
chore(deps-dev): bump @typescript-eslint/eslint-plugin from 5.51.0 to 8.46.2 in /frontend
2025-11-03 18:51:23 +05:30
ManishMadan2882
06bda6bd55 fix: resolve peer dependency conflicts in eslint and typescript-eslint packages
- Update @typescript-eslint/eslint-plugin from ^8.46.2 to ^6.21.0
- Update @typescript-eslint/parser from ^5.62.0 to ^6.21.0
- Update eslint-config-standard-with-typescript from ^34.0.0 to ^43.0.1
- Ensure all dependencies are compatible without requiring --legacy-peer-deps
- Verified with successful npm install and vite build
2025-11-03 18:47:34 +05:30
Christine Belzie
7dd97821a8 feat: Installing vale(redo) (#2104)
* docs: setup Vale

Signed-off-by: Christine Belzie <shecoder30@gmail.com>

* adding more content

* chore: add Vale configuration and custom dictionary

* chore: clean up spelling configuration and remove unused vocabularies

* fix: correct file path format for Vale linter configuration

---------

Signed-off-by: Christine Belzie <shecoder30@gmail.com>
Co-authored-by: Alex <a@tushynski.me>
2025-10-31 18:00:09 +02:00
Harshit Ranjan
695191d888 added error saving vector store (#2081)
* added error saving vector store

* fixed code formating

* added tests for embedding pipeline
2025-10-31 16:29:35 +02:00
dependabot[bot]
1dbcef24c7 chore(deps-dev): bump @typescript-eslint/eslint-plugin in /frontend
Bumps [@typescript-eslint/eslint-plugin](https://github.com/typescript-eslint/typescript-eslint/tree/HEAD/packages/eslint-plugin) from 5.51.0 to 8.46.2.
- [Release notes](https://github.com/typescript-eslint/typescript-eslint/releases)
- [Changelog](https://github.com/typescript-eslint/typescript-eslint/blob/main/packages/eslint-plugin/CHANGELOG.md)
- [Commits](https://github.com/typescript-eslint/typescript-eslint/commits/v8.46.2/packages/eslint-plugin)

---
updated-dependencies:
- dependency-name: "@typescript-eslint/eslint-plugin"
  dependency-version: 8.46.2
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-10-31 14:04:19 +00:00
Manish Madan
e086c79da0 Merge pull request #1933 from arc53/dependabot/npm_and_yarn/frontend/npm_and_yarn-d56b2ef021
chore(deps): bump mermaid from 11.6.0 to 11.10.0 in /frontend in the npm_and_yarn group
2025-10-31 19:32:57 +05:30
dependabot[bot]
6ae8d34b27 chore(deps): bump mermaid in /frontend in the npm_and_yarn group
Bumps the npm_and_yarn group in /frontend with 1 update: [mermaid](https://github.com/mermaid-js/mermaid).

Updates `mermaid` from 11.6.0 to 11.10.0
- [Release notes](https://github.com/mermaid-js/mermaid/releases)
- [Commits](https://github.com/mermaid-js/mermaid/compare/mermaid@11.6.0...mermaid@11.10.0)

---
updated-dependencies:
- dependency-name: mermaid
  dependency-version: 11.10.0
  dependency-type: direct:production
  dependency-group: npm_and_yarn
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-10-31 19:27:31 +05:30
Manish Madan
2e23e547d3 Merge pull request #1916 from arc53/dependabot/npm_and_yarn/frontend/eslint-plugin-prettier-5.5.4
build(deps-dev): bump eslint-plugin-prettier from 5.2.1 to 5.5.4 in /frontend
2025-10-31 19:24:38 +05:30
dependabot[bot]
fa11dc9828 build(deps-dev): bump eslint-plugin-prettier in /frontend
Bumps [eslint-plugin-prettier](https://github.com/prettier/eslint-plugin-prettier) from 5.2.1 to 5.5.4.
- [Release notes](https://github.com/prettier/eslint-plugin-prettier/releases)
- [Changelog](https://github.com/prettier/eslint-plugin-prettier/blob/main/CHANGELOG.md)
- [Commits](https://github.com/prettier/eslint-plugin-prettier/compare/v5.2.1...v5.5.4)

---
updated-dependencies:
- dependency-name: eslint-plugin-prettier
  dependency-version: 5.5.4
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-10-31 19:18:44 +05:30
Manish Madan
673fa70bc5 Merge pull request #1903 from arc53/dependabot/npm_and_yarn/frontend/multi-6fb5dc7d23
build(deps): bump react-dom and @types/react-dom in /frontend
2025-10-31 19:16:17 +05:30
dependabot[bot]
a0660a54c1 build(deps): bump react-dom and @types/react-dom in /frontend
Bumps [react-dom](https://github.com/facebook/react/tree/HEAD/packages/react-dom) and [@types/react-dom](https://github.com/DefinitelyTyped/DefinitelyTyped/tree/HEAD/types/react-dom). These dependencies needed to be updated together.

Updates `react-dom` from 19.1.0 to 19.1.1
- [Release notes](https://github.com/facebook/react/releases)
- [Changelog](https://github.com/facebook/react/blob/main/CHANGELOG.md)
- [Commits](https://github.com/facebook/react/commits/v19.1.1/packages/react-dom)

Updates `@types/react-dom` from 19.1.6 to 19.1.7
- [Release notes](https://github.com/DefinitelyTyped/DefinitelyTyped/releases)
- [Commits](https://github.com/DefinitelyTyped/DefinitelyTyped/commits/HEAD/types/react-dom)

---
updated-dependencies:
- dependency-name: react-dom
  dependency-version: 19.1.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
- dependency-name: "@types/react-dom"
  dependency-version: 19.1.7
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-10-31 19:10:51 +05:30
Manish Madan
1137bf4280 Merge pull request #1864 from arc53/dependabot/npm_and_yarn/frontend/i18next-browser-languagedetector-8.2.0
build(deps): bump i18next-browser-languagedetector from 8.0.2 to 8.2.0 in /frontend
2025-10-31 19:10:26 +05:30
dependabot[bot]
da41c898d8 build(deps): bump i18next-browser-languagedetector in /frontend
Bumps [i18next-browser-languagedetector](https://github.com/i18next/i18next-browser-languageDetector) from 8.0.2 to 8.2.0.
- [Changelog](https://github.com/i18next/i18next-browser-languageDetector/blob/master/CHANGELOG.md)
- [Commits](https://github.com/i18next/i18next-browser-languageDetector/compare/v8.0.2...v8.2.0)

---
updated-dependencies:
- dependency-name: i18next-browser-languagedetector
  dependency-version: 8.2.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-10-31 18:59:21 +05:30
Siddhant Rai
21e5c261ef feat: template-based prompt rendering with dynamic namespace injection (#2091)
* feat: template-based prompt rendering with dynamic namespace injection

* refactor: improve template engine initialization with clearer formatting

* refactor: streamline ReActAgent methods and improve content extraction logic

feat: enhance error handling in NamespaceManager and TemplateEngine

fix: update NewAgent component to ensure consistent form data submission

test: modify tests for ReActAgent and prompt renderer to reflect method changes and improve coverage

* feat: tools namespace + three-tier token budget

* refactor: remove unused variable assignment in message building tests

* Enhance prompt customization and tool pre-fetching functionality

* ruff lint fix

* refactor: cleaner error handling and reduce code clutter

---------

Co-authored-by: Alex <a@tushynski.me>
2025-10-31 12:47:44 +00:00
Aqsa Aqeel
a7d61b9d59 feat: implementing the new custom modal design (#2090)
* feat: implementing the new custom modal design

* feat: added tool variable dropdown

* fix: ui fixes and link fixes

* feat: implemented redisgn for edit prompt modal

* (feat:prompts) matching figma

* (fix:prompts) tool vars

* (fix:promptsModal) responsive; disable save on text

---------

Co-authored-by: Aqsa Aqeel <aqsa.aqeel17@example.com>
Co-authored-by: ManishMadan2882 <manishmadan321@gmail.com>
2025-10-31 12:18:13 +02:00
dorkdiaries9
c5fe25c149 Enhance migration script with logging and error handling (#2103)
Added logging for migration steps and error handling.
2025-10-29 01:49:47 +02:00
Manish Madan
6a4cb617f9 Frontend audit: Bug fixes and refinements (#2112)
* (fix:attachements) sep id for redux ops

* (fix:ui) popups, toast, share modal

* (feat:agentsPreview) stable preview, ui fixes

* (fix:ui) light theme icon, sleek scroll

* (chore:i18n) missin keys

* (chore:i18n) missing keys

* (feat:preferrenceSlice) autoclear invalid source from storage

* (fix:general) delete all conv close btn

* (fix:tts) play one at a time

* (fix:tts) gracefully unmount

* (feat:tts) audio LRU cache

* (feat:tts) pointer on hovered area

* (feat:tts) clean text for speach

---------

Co-authored-by: GH Action - Upstream Sync <action@github.com>
2025-10-29 01:47:26 +02:00
Alex
94f70e6de5 refactor(tests): update todo tool tests to use simplified action names and improve key handling 2025-10-28 10:39:35 +00:00
Hanzalah Waheed
ab4ebf9a9d feat: add bg blur for modals (#2110)
* feat: add bg blur for modals

* feat: adjust darkness for lightmode
2025-10-28 12:14:37 +02:00
Nikunj Kohli
9f7945fcf5 [UI/UX] Improve image upload experience — add preview & drag-to-reord… (#2095)
* [UI/UX] Improve image upload experience — add preview & drag-to-reorder in chat section

* chore(chat): remove image previews, keep drag-to-reorder

* chore(chat): prevent attachment drag from triggering upload dropzone

* Revert "chore(chat): prevent attachment drag from triggering upload dropzone"

This reverts commit dd4b96256c.

* (feat:conv) rmv drag-drop on sources

* (feat:msg-input) drop attachments

---------

Co-authored-by: ManishMadan2882 <manishmadan321@gmail.com>
2025-10-27 21:53:18 +02:00
Gayatri K
d8ec3c008c todo tool feature added to tools (#1977)
* todo tool feature added to tools

* removed configs

* fix: require user_id on TodoListTool, normalize timestamps, add tests

* lint and tests fixes

* refactor: support multiple todos per user/tool by indexing with todo_id

* modified todo_id to use auto-increamenting integer instead of UUID

* test-case fixes

* feat: fix todos

---------

Co-authored-by: Alex <a@tushynski.me>
2025-10-27 19:09:32 +02:00
Pavel
2f00691246 Merge pull request #2096 from arc53/hacktoberfest-t-shirt-image
Update HACKTOBERFEST.md with T-shirt image
2025-10-24 13:15:30 +01:00
Pavel
9b2383b074 Update HACKTOBERFEST.md with T-shirt image
Added t-shirt image.
2025-10-24 13:09:23 +01:00
Ritoban Dutta
e4e9910575 fix(ui): use dedicated sidebar open/close icons for better visual feedback of actions (#2088)
* fix(ui): use dedicated icons for sidebar toggle (panel-left-open/close)

* fix: update sidebar toggle icon colors
2025-10-24 00:25:17 +03:00
Nihar
f448e4a615 add configurable provider in settings.py and update ElevenLabs Api (#2065) (#2074) 2025-10-22 19:07:21 +03:00
Manish Madan
c4e8daf50e Frontend audit: refinements (#2083)
* (fix:attachements) sep id for redux ops

* (fix:ui) popups, toast, share modal

* (feat:agentsPreview) stable preview, ui fixes

* (fix:ui) light theme icon, sleek scroll

---------

Co-authored-by: GH Action - Upstream Sync <action@github.com>
2025-10-22 12:12:05 +03:00
Hanzalah Waheed
5aa4ec1b9f fix: cleanup ConversationBubble and fix CopyButton (#2073)
* fix: rm states for hovering. use tailwind classes instead

* fix: use group hover css intead of states

* chore: no point in having separate defaults if cant be customised

* fix: move default bg colors into conditionals
2025-10-18 21:59:15 +03:00
Siddhant Rai
125ce0aad3 test: implement full API test suite with mongomock and centralized fixtures (#2068) 2025-10-17 12:01:14 +03:00
TinTin
ababc9ae04 fix: reduce large margins between list items in answer rendering (#2058) 2025-10-16 13:59:45 +03:00
Alex
62ac90746e fix: improve error handling and loading state in fetchChunks function 2025-10-15 17:33:13 +01:00
Alex
096f6d91a2 fix: handle potential undefined value for selectedDocs in fetchAnswer 2025-10-15 17:12:52 +01:00
jay98
d28ef6b094 Refactor: use async/await in fetchChunks for correct error handling (typescript:S4822) (#2066) 2025-10-15 15:52:55 +03:00
Anurag Yadav
8fb945ab09 feedback button to show after message (#2064) 2025-10-14 18:52:31 +03:00
jay98
835d71727c fix: remove redundant conditional operator for file assignment (#2060) 2025-10-14 17:35:41 +03:00
Ali Arda Fincan
ce32dd2907 Feat: Agent Token or Request Limiting (#2041)
* Update routes.py, added token and request limits to create/update agent operations

* added usage limit check to api endpoints

cannot create agents with usage limit right now that will be implemented

* implemented api limiting as either token limiting or request limiting modes

* minor typo & bug fix
2025-10-13 21:32:46 +03:00
Manish Madan
72bc24a490 Chore: deleted unused files, dead code; minor fixes (#2059)
* (feat:pause-stream) generator exit

* (feat:pause-stream) close request

* (feat:pause-stream) finally close; google anthropic

* (feat:task_status)communicate failure

* (clean:connector) unused routes

* (feat:file-table) missing skeletons

* (chore:fe) purge unused

* (fix:apiKeys) build err

* (chore:settings) clean unused

* merge from main

* (chore:purge) unused fe assets

* (clean:check_docs) unused logic

* (feat:selectedDocs) replace null type

---------

Co-authored-by: GH Action - Upstream Sync <action@github.com>
2025-10-13 19:11:24 +03:00
Siddhant Rai
d6c49bdbf0 test: add agent test coverage and standardize test suite (#2051)
- Add 104 comprehensive tests for agent system
- Integrate agent tests into CI/CD pipeline
- Standardize tests with @pytest.mark.unit markers
- Fix cross-platform path compatibility
- Clean up unused imports and dependencies
2025-10-13 14:43:35 +03:00
beKool.sh
1805292528 Update README.md (#2057) 2025-10-13 13:00:13 +03:00
Nirjas Jakilim
d09ce7e1f7 fixed broken link (#2054) 2025-10-12 15:26:05 +03:00
Marco Ponce
a8d2024791 Windows deployment powershell and renamed LLM_PROVIDER and runtime (#2050)
* Windows deployment powershell and  renamed LLM_PROVIDER and runtime

* added LLM_NAME back

* revert changes on docker-compose-hub.yaml
2025-10-12 15:25:42 +03:00
Manish Madan
f0b954dbfb Upload: communicate failure, minor frontend updates (#2048)
* (feat:pause-stream) generator exit

* (feat:pause-stream) close request

* (feat:pause-stream) finally close; google anthropic

* (feat:task_status)communicate failure

* (clean:connector) unused routes

* (feat:file-table) missing skeletons

* (fix:apiKeys) build err

---------

Co-authored-by: GH Action - Upstream Sync <action@github.com>
2025-10-10 17:34:02 +03:00
Rahul
50bee7c2b0 feat: Add button to cancel LLM response (#1978)
* feat: Add button to cancel LLM response
- Replace text area with cancel button when loading.
- Add useEffect to change elipsis in cancel button text.
- Add new SVG icon for cancel response.
- Button colors match Figma designs.

* fix: Cancel button UI matches new design
- Delete cancel-response svg.
- Change previous cancel button to match the new Figma design.
- Remove console log in handleCancel function.

* fix: Adjust cancel button rounding

* feat: Update UI for send button
- Add SendArrowIcon component, enables dynamic svg color changes
- Replace original icon
- Update colors and hover effects

* (fix:send-button) minor blink in transition

---------

Co-authored-by: Manish Madan <manishmadan321@gmail.com>
2025-10-09 12:01:25 +03:00
Mariam Saeed
e7b15b316e Feat: Notification section (#2033)
* Feature/Notification-section

* fix notification ui and add local storage variable to save the state

* add notification component to app.tsx
2025-10-09 01:26:10 +03:00
Manish Madan
a4507008c1 complete_stream: Stop response streaming (#2031)
* (feat:pause-stream) generator exit

* (feat:pause-stream) close request

* (feat:pause-stream) finally close; google anthropic

---------

Co-authored-by: GH Action - Upstream Sync <action@github.com>
2025-10-08 20:37:30 +03:00
Hanzalah Waheed
c5ba85f929 fix(ui): create a var to check for shared metadata obj (#2040) 2025-10-08 18:18:54 +03:00
Manish Madan
2e636bd67e Merge pull request #1970 from arc53/dependabot/npm_and_yarn/frontend/mermaid-11.12.0
chore(deps): bump mermaid from 11.6.0 to 11.12.0 in /frontend
2025-10-08 20:42:32 +05:30
dependabot[bot]
4a039f1abf chore(deps): bump mermaid from 11.6.0 to 11.12.0 in /frontend
Bumps [mermaid](https://github.com/mermaid-js/mermaid) from 11.6.0 to 11.12.0.
- [Release notes](https://github.com/mermaid-js/mermaid/releases)
- [Commits](https://github.com/mermaid-js/mermaid/compare/mermaid@11.6.0...mermaid@11.12.0)

---
updated-dependencies:
- dependency-name: mermaid
  dependency-version: 11.12.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-10-08 15:06:36 +00:00
JeevaRamanathan
434d8e2070 fix spinner to match theme (dark/light) in conversation (#2044) 2025-10-08 15:16:44 +03:00
Nihar
160ad2dc79 correct path for storing embeddings model (#2035) 2025-10-07 23:01:39 +03:00
Anshuman Payasi
0ec86c2c71 fix: adjust share modal size and spacing (#2027)
* fix/share-modal-spacing

* fix(share-modal): spacing adjusted for mobile view
2025-10-07 17:11:03 +03:00
Alex
03452ffd9f feat: add GitHub access token support and fix file content fetching logic (#2032) 2025-10-07 16:53:14 +03:00
Siddhant Rai
da6317a242 feat: agent templates and seeding premade agents (#1910)
* feat: agent templates and seeding premade agents

* fix: ensure ObjectId is used for source reference in agent configuration

* fix: improve source handling in DatabaseSeeder and update tool config processing

* feat: add prompt handling in DatabaseSeeder for agent configuration

* Docs premade agents

* link to prescraped docs

* feat: add template agent retrieval and adopt agent functionality

* feat: simplify agent descriptions in premade_agents.yaml  added docs

---------

Co-authored-by: Pavel <pabin@yandex.ru>
Co-authored-by: Alex <a@tushynski.me>
2025-10-07 13:00:14 +03:00
Nihar
8b8e616557 fix: handle missing kwargs in local save_file (#2017)
Previously, the local save_file function didn’t accept kwargs, causing
a crash when passed extra params. Added support to maintain consistency
with AWS version.

Fixes #2009
2025-10-06 23:55:49 +03:00
Marco Ponce
d260f1a1a6 Made changes to how the documentation is represented including the new 5th option when forking and launching DocsGPT on the inviduals device, as well as updated README.md which has miswording issues saying only 4 options, but now includes the 5th option, detailing and giving an explanation for what that option does and documentation provided. (#2020) 2025-10-06 23:50:16 +03:00
Alex
9d452e3b04 feat: enhance MemoryTool and NotesTool with tool_id management and directory renaming tests (#2026) 2025-10-06 23:45:47 +03:00
Alex
e012189672 feat: implement MemoryTool with CRUD operations and integrate with MongoDB 2025-10-06 21:09:22 +01:00
Dhairya Parikh
4c31e9a8b1 Add MongoDB-backed NotesTool with CRUD actions and unit tests (#1982)
* Add MongoDB-backed NotesTool with CRUD actions and unit tests

* NotesTool: enforce single note per user, require decoded_token['sub'] user_id, fix tests

* chore: remove accidentally committed results.txt and ignore it

* fix: remove results.txt, enforce single note per user, add tests

* refactor: update NotesTool actions and tests for clarity and consistency

* refactor: update NotesTool docstring for clarity

* refactor: simplify MCPTool docstring and remove redundant import in test_notes_tool

* lint: fix test

* refactor: remove unused import from test_notes_tool.py

---------

Co-authored-by: Alex <a@tushynski.me>
2025-10-06 19:24:03 +03:00
Alex
7cfc230316 Update README.md 2025-10-06 18:30:47 +03:00
Alex
9605e85f1c Merge pull request #2004 from Lanthoiba2022/AgentImageFix1
Fix #1983: Agent image fallback added
2025-10-06 14:47:26 +01:00
Alex
498e2b772c fix: update save_file method to accept additional keyword arguments; enhance AgentImage component with useEffect for dynamic source updates 2025-10-06 14:41:51 +01:00
Manish Madan
dad897da51 Merge pull request #1997 from arc53/dependabot/npm_and_yarn/frontend/i18next-25.5.3
build(deps): bump i18next from 24.2.0 to 25.5.3 in /frontend
2025-10-06 16:44:15 +05:30
dependabot[bot]
02ad5f062e build(deps): bump i18next from 24.2.0 to 25.5.3 in /frontend
Bumps [i18next](https://github.com/i18next/i18next) from 24.2.0 to 25.5.3.
- [Release notes](https://github.com/i18next/i18next/releases)
- [Changelog](https://github.com/i18next/i18next/blob/master/CHANGELOG.md)
- [Commits](https://github.com/i18next/i18next/compare/v24.2.0...v25.5.3)

---
updated-dependencies:
- dependency-name: i18next
  dependency-version: 25.5.3
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-10-06 11:03:35 +00:00
Manish Madan
4eb9471b4f Merge pull request #2021 from M4cr0Chen/enhance-edit-button-and-response-ui
adjusted edit button padding and response box padding & roundness
2025-10-06 13:13:00 +05:30
Zhenghong Chen
b505d207d7 adjusted edit button padding and response box padding & roundness 2025-10-05 16:34:17 -04:00
Alex
3c954bd07f Merge pull request #2007 from ManishMadan2882/upload-toast
Upload Sources: Tasks are notified with UI toasts
2025-10-05 13:49:11 +01:00
ManishMadan2882
c00b6459dc (fix:upload) set docs 2025-10-05 18:07:21 +05:30
ManishMadan2882
eb4d776784 (feat:upload) wording, icons and rmv clear button 2025-10-05 16:12:58 +05:30
Alex
5d7a890533 Merge pull request #2015 from JeevaRamanathan/fix/twitter-navigation-link
chore: update twitter.com to x.com
2025-10-04 21:50:47 +01:00
JeevaRamanathan
9c6aefef1e chore: update twitter.com to x.com for avoiding redirection
Signed-off-by: JeevaRamanathan <jeevaramanathan.m@infosys.com>
2025-10-05 02:00:57 +05:30
ManishMadan2882
e4554d6c09 Merge branch 'main' of https://github.com/arc53/DocsGPT into upload-toast 2025-10-03 23:26:32 +05:30
ManishMadan2882
c184b63df8 (feat:upload) i18n 2025-10-03 20:32:43 +05:30
ManishMadan2882
6bb4195393 (feat:upload) dismiss, but notify on completion 2025-10-03 19:55:41 +05:30
Alex
7827a4d40d Merge pull request #1960 from jayamrutiya/letmecheck/bug-return-value-not-be-ignored
Fix: replace map with for...of loop to avoid ignoring return value (S2201)
2025-10-03 14:15:39 +01:00
ManishMadan2882
f09fa8231a (feat:upload) new toast UI 2025-10-03 17:08:39 +05:30
Alex
96ff10000d Merge pull request #1999 from hanzalahwaheed/feat/ui-enhancements
Feat: UI enhancements
2025-10-03 10:15:54 +01:00
Alex
9460636867 Merge pull request #2005 from siiddhantt/feat/routes-refactor
refactor: modularize user API routes into domain-driven structure
2025-10-03 10:03:33 +01:00
Siddhant Rai
6c43245295 refactor: organize import statements for clarity and consistency 2025-10-03 12:41:03 +05:30
Pavel
266b6cf638 Fix typo in HACKTOBERFEST.md 2025-10-03 08:09:22 +01:00
Siddhant Rai
70183e234a refactor: break up monolithic routes.py into modular domain structure 2025-10-03 12:30:04 +05:30
Hanzalah Waheed
17b9c359ca fix: use hexcode for purple taupe and rm extra props 2025-10-03 00:22:36 +04:00
Lanthoiba22
045630b8a5 Agent image fallback added 2025-10-02 21:53:38 +05:30
Alex
55ff7dd640 Merge pull request #2003 from arc53/hacktoberfest-rules
Update HACKTOBERFEST.md
2025-10-02 16:58:22 +01:00
Alex
e6d64f71f2 Update HACKTOBERFEST.md 2025-10-02 16:58:01 +01:00
Alex
e72313ebdd Merge pull request #2002 from ManishMadan2882/main
Frontend Lint
2025-10-02 12:35:59 +01:00
ManishMadan2882
65d5bd72cd Merge branch 'main' of https://github.com/manishmadan2882/docsgpt 2025-10-02 16:59:13 +05:30
ManishMadan2882
dc0cbb41f0 (fix:table) remove display name 2025-10-02 16:50:59 +05:30
ManishMadan2882
c4a54a85be (lint) fe 2025-10-02 16:30:08 +05:30
Hanzalah Waheed
5b2738aec9 fix (ui): is_copied true, disable hover state 2025-10-02 13:44:05 +04:00
Hanzalah Waheed
892312fc08 fix: hover bg color change fixed using correct css var 2025-10-02 12:23:17 +04:00
Alex
c2ccf2c72c Merge pull request #2000 from ManishMadan2882/main
Test coverage:  TTS, Security and Storage layers
2025-10-02 00:11:22 +01:00
ManishMadan2882
80aaecb5f0 ruff-fix 2025-10-02 02:48:16 +05:30
ManishMadan2882
946865a335 test:TTS 2025-10-02 02:40:30 +05:30
ManishMadan2882
5de15c8413 (feat:11Labs) just functional part 2025-10-02 02:39:53 +05:30
ManishMadan2882
67268fd35a tests:security 2025-10-02 02:34:05 +05:30
ManishMadan2882
42fc771833 tests:storage 2025-10-02 02:33:12 +05:30
Hanzalah Waheed
444b1a0b65 feat(ui): add transition animation to navigation sidebar 2025-10-01 23:16:24 +04:00
Hanzalah Waheed
814ea1c016 fix(ui): add text-32px for smaller than lg view for agent headings 2025-10-01 21:21:23 +04:00
Alex
4d34dc4234 Merge pull request #1996 from arc53/pr/1988
Pr/1988
2025-10-01 10:29:20 +01:00
Alex
d567399f2b Merge pull request #1995 from siiddhantt/fix/agent-sources
fix: agent sources and other issues
2025-10-01 10:19:47 +01:00
Alex
77f4f8d8b0 (refactor) update launch configurations and remove obsolete tasks.json 2025-10-01 10:17:52 +01:00
Alex
a2d04beaa1 (refactor) remove redundant source validation in CreateAgent 2025-10-01 10:17:26 +01:00
Siddhant Rai
ba49eea23d Refactor agent creation and update logic to improve error handling and default values; enhance logging for better traceability 2025-10-01 13:56:31 +05:30
Alex
82beafc086 Merge pull request #1991 from ManishMadan2882/tester
Tests: coverage for application/llm/* and application/llm/handlers/* ; fix on parsers/test_markdown
2025-09-30 23:19:38 +01:00
ManishMadan2882
7d8ed2d102 ruff-fix 2025-10-01 02:23:53 +05:30
ManishMadan2882
aab8d3a4f1 Merge branch 'tester' of https://github.com/manishmadan2882/docsgpt into tester 2025-10-01 01:58:52 +05:30
Alex
76658d50a0 Update HACKTOBERFEST.md 2025-09-30 21:46:09 +03:00
Alex
88ba22342c Update README with Hacktoberfest details and demo
Added Hacktoberfest information and a demo GIF.
2025-09-30 21:45:55 +03:00
Alex
11a1460af9 Add Hacktoberfest participation details to HACKTOBERFEST.md
This document outlines the participation of DocsGPT in Hacktoberfest, encouraging contributors to submit meaningful pull requests for a chance to win a T-shirt. It includes guidelines for contributions and links to resources.
2025-09-30 21:42:13 +03:00
Alex
2cd4c41316 Merge pull request #1992 from arc53/fix-api-answer-tool-call
fix: api answer tool call event
2025-09-30 14:49:57 +01:00
Alex
b910f308f2 fix: api answer tool call event 2025-09-30 14:42:54 +01:00
ManishMadan2882
763aa73ea4 (tests:llm) llms, handlers 2025-09-30 16:03:02 +05:30
Manish Madan
30c79e92d4 Merge branch 'arc53:main' into tester 2025-09-30 15:52:23 +05:30
ManishMadan2882
402d5e054b (fix:test_markdown) patch tokenizer 2025-09-30 13:35:03 +05:30
Alex
0e211df206 Merge pull request #1988 from ManishMadan2882/tester
Test coverage for parsers
2025-09-29 17:42:27 +01:00
ManishMadan2882
e24a0ac686 (test:parsers) github, reddit 2025-09-29 20:33:05 +05:30
ManishMadan2882
8c91b1c527 (tests:parsers) remote 2025-09-29 19:39:24 +05:30
ManishMadan2882
2b38f80d04 (test:files) epub, image, rst 2025-09-29 17:39:20 +05:30
ManishMadan2882
282bd35f52 (test:files) html, md, json 2025-09-29 17:23:15 +05:30
Alex
cc9b4c2bcb Merge pull request #1964 from siiddhantt/refine/mcp-tool
refactor: oauth + use fastmcp client for handling SSE and different transports in remote mcp
2025-09-26 13:54:40 +01:00
Alex
068ce4970a Merge branch 'refine/mcp-tool' of https://github.com/siiddhantt/DocsGPT into pr/1964 2025-09-26 13:42:34 +01:00
Alex
cf19165ad8 fix(lint): ruff 2025-09-26 13:42:08 +01:00
Alex
68c479f3a5 Merge branch 'main' into refine/mcp-tool 2025-09-26 13:40:12 +01:00
ManishMadan2882
ba496a772b (test) doc parsers coverage 2025-09-26 16:07:12 +05:30
Siddhant Rai
3b27db36f2 feat: Implement OAuth flow for MCP server integration
- Added MCPOAuthManager to handle OAuth authorization.
- Updated MCPServerSave resource to manage OAuth status and callback.
- Introduced new endpoints for OAuth status and callback handling.
- Enhanced user interface to support OAuth authentication type.
- Implemented polling mechanism for OAuth status in MCPServerModal.
- Updated frontend services and endpoints to accommodate new OAuth features.
- Improved error handling and user feedback for OAuth processes.
2025-09-26 02:44:08 +05:30
Alex
f803def69b Merge pull request #1976 from ManishMadan2882/main
OAuth fix for connector
2025-09-25 10:15:13 +01:00
ManishMadan2882
52065e69a4 ruff 2025-09-25 05:05:59 +05:30
ManishMadan2882
50f5e8a955 Merge branch 'main' of https://github.com/manishmadan2882/docsgpt 2025-09-25 04:58:05 +05:30
ManishMadan2882
2d0e97b66d (feat:oauth) provider as state, not args 2025-09-25 04:55:25 +05:30
Manish Madan
5f3cc5a392 Merge branch 'arc53:main' into main 2025-09-25 03:50:50 +05:30
ManishMadan2882
ac66d77512 (fix:oauth) handle access denied 2025-09-25 03:45:12 +05:30
Alex
50cf653d4a Merge pull request #1975 from arc53/chunk-fix
fix: chunking
2025-09-24 23:05:41 +01:00
Alex
56256051d2 fix: chunking 2025-09-24 22:59:53 +01:00
ManishMadan2882
c0361ff03d (fix:oauth) user field null 2025-09-25 03:27:16 +05:30
Alex
f153435c08 Merge pull request #1974 from ManishMadan2882/tester
Test app fix
2025-09-24 10:12:29 +01:00
Alex
9aa7f22fa6 Merge pull request #1961 from NewAi25/feature-letmecheck-fix-Identica-Sub-Expression
Added fix in frontend/src/conversation/ConversationMessages.tsx line…
2025-09-24 09:35:12 +01:00
ManishMadan2882
52b7bda5f8 Merge branch 'main' of https://github.com/manishmadan2882/docsgpt into tester 2025-09-24 02:37:20 +05:30
ManishMadan2882
21aefa2778 (fix:tests.application) attr err 2025-09-23 23:43:50 +05:30
Alex
a89ff71c9e Update README.md 2025-09-23 12:48:49 +03:00
Alex
4c275816be Merge pull request #1954 from ManishMadan2882/main
Refinements: Files uploaded via Connectors
2025-09-22 21:54:48 +01:00
ManishMadan2882
f8dfbcfc80 (docs) guide gdrive 2025-09-22 20:20:51 +05:30
ManishMadan2882
d317f6473d (feat:gdrive) upload files only 2025-09-22 20:19:56 +05:30
Siddhant Rai
00b4e133d4 feat: implement OAuth 2.1 integration with custom handlers for fastmcp 2025-09-22 01:31:09 +05:30
ManishMadan2882
b6349e4efb (fix:gdrive) no api keyneded 2025-09-20 19:55:18 +05:30
ManishMadan2882
6ca3d9585c (fix:connector) db entry 2025-09-20 19:52:12 +05:30
ManishMadan2882
5935a0283a (fix:ui)adjust 2025-09-20 19:50:58 +05:30
ManishMadan2882
5400a6ec06 (feat:googlePicker) frontend envars 2025-09-20 00:13:03 +05:30
ManishMadan2882
6574d9cc84 (feat:pickers) ux, code refactor 2025-09-20 00:04:27 +05:30
ManishMadan2882
42b83c5994 (refactor:upload) single schema with validation 2025-09-19 23:55:40 +05:30
ManishMadan2882
896612a5a3 (fix:auth) refresh drive token 2025-09-19 22:57:09 +05:30
ManishMadan2882
0ee875bee4 (fix:gdrive) missing appId in picker 2025-09-18 20:25:20 +05:30
Siddhant Rai
8ce345cd94 feat: refactor MCPTool to use FastMCP client and improve async handling, update transport and authentication configurations 2025-09-17 20:51:32 +05:30
ManishMadan2882
da2f8477e6 (feat:drive) oauth for drive.file scope, picker 2025-09-17 19:37:01 +05:30
jane
82b47b5673 Added fix in frontend/src/conversation/ConversationMessages.tsx line 213 2025-09-16 23:53:06 +05:30
jayamrutiya
7c15a4c7ff Fix: replace map with forEach to avoid ignoring return value (S2201) 2025-09-16 22:11:55 +05:30
Siddhant Rai
3369b910b4 feat: update MCP request ID generation and enhance response handling for event streams 2025-09-16 20:43:04 +05:30
ManishMadan2882
ec0c4c3b84 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-09-16 14:59:18 +05:30
ManishMadan2882
f74e2c9da1 (feat:filepciker) load inside table, ui 2025-09-16 14:50:58 +05:30
Alex
e26ad3c475 Merge pull request #1947 from siiddhantt/feat/remote-mcp
feat: remote mcp
2025-09-15 09:31:36 +01:00
ManishMadan2882
145c3b8ad0 (feat:file picker) table ui 2025-09-15 12:58:45 +05:30
Siddhant Rai
0ff6c6a154 Merge branch 'main' into feat/remote-mcp 2025-09-15 09:53:58 +05:30
Siddhant Rai
641cf5a4c1 feat: skip empty fields in mcp tool call + improve error handling and response 2025-09-11 19:04:10 +05:30
Siddhant Rai
09b9576eef feat: enhance message and schema cleaning for Google AI integration 2025-09-11 17:54:46 +05:30
ManishMadan2882
18b71ca2f2 (feat:upload) cards for upload types 2025-09-11 13:27:55 +05:30
Alex
e0eb7f456e Merge pull request #1930 from Ankit-Matth/feature/multi-select-sources
Added support for Multi select sources
2025-09-10 17:48:02 +01:00
Siddhant Rai
188d118fc0 refactor: remove unused logging import from routes.py 2025-09-10 22:14:31 +05:30
Siddhant Rai
adcdce8d76 fix: handle invalid chunks value in StreamProcessor and ClassicRAG 2025-09-10 22:10:11 +05:30
Siddhant Rai
b865a7aec1 Merge branch 'main' of https://github.com/siiddhantt/DocsGPT into pr/1930 2025-09-10 20:15:20 +05:30
ManishMadan2882
cec8c72b46 (refactor:uploads) YAGNI 2025-09-10 19:19:40 +05:30
Siddhant Rai
b052e32805 feat: enhance MCP tool configuration handling and authentication logic 2025-09-10 14:15:51 +05:30
Siddhant Rai
816f660be3 fix: enhance MCPTool request handling and tool discovery logic 2025-09-10 13:08:14 +05:30
ManishMadan2882
fc8be45d5a (feat:sync) confirmation check 2025-09-09 13:08:38 +05:30
ManishMadan2882
e749c936c9 (refactor:uploads) for new ui 2025-09-09 13:07:26 +05:30
ManishMadan2882
b2b9670a23 (feat:connectors) type as connector:file 2025-09-09 00:00:58 +05:30
Siddhant Rai
2f88890c94 feat: add support for multiple sources in agent configuration and update related components 2025-09-08 22:10:08 +05:30
ManishMadan2882
6366663f03 (refactor:uploads) separate file picker 2025-09-08 19:09:19 +05:30
Manish Madan
20fe7dc6d1 Merge branch 'main' into main 2025-09-08 14:19:32 +05:30
Alex
4b9153069e Merge pull request #1928 from Ankit-Matth/fix/ui-ux-improvements
bugfix(ui): several UI/UX updates and fixes
2025-09-05 13:54:09 +01:00
ManishMadan2882
80406d0753 (feat:drive) debounce search, ui/ux 2025-09-05 13:25:36 +05:30
ManishMadan2882
35f4c11784 Merge branch 'main' of https://github.com/manishmadan2882/docsgpt 2025-09-05 11:43:18 +05:30
ManishMadan2882
7896526f19 (feat:load_files) search feature 2025-09-05 10:35:23 +05:30
Alex
f7db22edff Merge pull request #1937 from ManishMadan2882/main
Connectors: Google Drive Ingestion
2025-09-04 12:05:15 +01:00
Siddhant Rai
0e4196f036 fix: remove unused tool labels from localization file 2025-09-04 15:19:58 +05:30
Siddhant Rai
1bf6af6eeb feat: finalize remote mcp 2025-09-04 15:10:12 +05:30
ManishMadan2882
5a9bc6d2bf (feat:connector) infinite scroll file pick 2025-09-04 08:35:41 +05:30
ManishMadan2882
f7f6042579 (feat:connector) paginate files 2025-09-04 07:58:12 +05:30
ManishMadan2882
c4a598f3d3 (lint-fix) ruff 2025-09-03 19:29:34 +05:30
Siddhant Rai
7c23f43c63 feat: Add MCP Server management functionality
- Implemented encryption utility for securely storing sensitive credentials.
- Added MCPServerModal component for managing MCP server configurations.
- Updated AddToolModal to include MCP server management.
- Enhanced localization with new strings for MCP server features.
- Introduced SVG icons for MCP tools in the frontend.
- Created a new settings section for MCP server configurations in the application.
2025-09-03 15:41:59 +05:30
ManishMadan2882
7e2cbdd88c (feat:connector) redirect url as backend overhead 2025-09-03 09:57:13 +05:30
ManishMadan2882
3b3a04a249 (feat:connector) sync fixes UI, minor refactor 2025-09-02 20:28:23 +05:30
ManishMadan2882
f9b2c95695 (feat:connector) sync, simply re-ingest 2025-09-02 18:06:04 +05:30
ManishMadan2882
c2c18e8319 (feat:connector,fe) sync api, notification 2025-09-02 13:36:41 +05:30
ManishMadan2882
384ad3e0ac (feat:connector) raw sync flow 2025-09-02 13:34:31 +05:30
ManishMadan2882
8c986aaa7f Merge branch 'main' of https://github.com/manishmadan2882/docsgpt 2025-09-01 12:05:17 +05:30
ManishMadan2882
bb4ea76d30 (fix:connectorTree) path navigation fn 2025-09-01 12:04:58 +05:30
ManishMadan2882
2868e47cf8 (feat:connector) provider metadata, separate fe nested display 2025-08-29 18:05:58 +05:30
GH Action - Upstream Sync
e0adc3e5d5 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-08-29 01:36:09 +00:00
ManishMadan2882
e55d1a5865 (feat:connector,auth) consider user_id 2025-08-29 02:13:51 +05:30
ManishMadan2882
018273c6b2 (feat:connector) refactor, updated routes FE 2025-08-29 01:06:40 +05:30
Alex
44b8a11c04 Merge pull request #1936 from Ankit-Matth/feature/load-containers-from-dockerhub
Speed up scripts by loading containers from docker hub
2025-08-28 14:48:08 +01:00
Siddhant Rai
56e5aba559 fix: correct frontend image name in docker-compose configuration 2025-08-28 18:21:54 +05:30
Alex
46904ccd54 feat: add theme color meta tags for light and dark modes 2025-08-28 11:36:42 +01:00
Siddhant Rai
5b7c7a4471 fix: update Docker images in docker-compose to use 'develop' tag 2025-08-28 12:11:06 +05:30
Siddhant Rai
9da4215d1f feat: implement Docker Hub integration for building and pushing images in CI/CD workflow 2025-08-28 12:01:04 +05:30
ManishMadan2882
f39ac9945f (feat:auth) follow connector-session 2025-08-28 00:53:19 +05:30
ManishMadan2882
a0cc2e4d46 Merge branch 'main' of https://github.com/manishmadan2882/docsgpt 2025-08-28 00:51:29 +05:30
ManishMadan2882
4065041a9f (feat:connectors) separate routes, namespace 2025-08-28 00:51:09 +05:30
GH Action - Upstream Sync
f08067a161 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-08-27 01:36:38 +00:00
Alex
545caacfa3 feat: prevent NUL character ingestion failures 2025-08-26 23:30:57 +01:00
Alex
a06f646637 feat: enhance tool call error handling 2025-08-26 22:37:21 +01:00
ManishMadan2882
578c68205a (feat:connectors) abstracting auth, base class 2025-08-26 02:46:36 +05:30
ManishMadan2882
f09f1433a9 (feat:connectors) separate layer 2025-08-26 01:38:36 +05:30
ManishMadan2882
15a9e97a1e (feat:ingest_connectors) spread config params 2025-08-26 00:56:39 +05:30
Ankit Matth
b3af4ee50b speed up scripts by using docker hub 2025-08-24 08:59:19 +05:30
Ankit Matth
07d59b6640 refactor: use list instead of string parsing 2025-08-23 20:25:29 +05:30
GH Action - Upstream Sync
e25b988dc8 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-08-23 01:35:35 +00:00
ManishMadan2882
2410bd8654 (fix:driveLoader) folder ingesting 2025-08-22 19:07:52 +05:30
Alex
44d21ab703 fix: passing sources and chunk if agent is shared 2025-08-22 13:36:31 +01:00
Alex
e283957c8f Fix source field retrieval in SharedAgent to handle DBRef correctly 2025-08-22 11:41:35 +01:00
ManishMadan2882
b1210c4902 Merge branch 'main' of https://github.com/manishmadan2882/docsgpt 2025-08-22 13:36:52 +05:30
ManishMadan2882
e7430f0fbc (feat:googleDrive,fe) file tree 2025-08-22 13:36:32 +05:30
ManishMadan2882
92d6ae54c3 (fix:google-oauth) no explicit datetime compare 2025-08-22 13:35:03 +05:30
ManishMadan2882
f82be23ca9 (feat:ingestion) external drive connect 2025-08-22 13:33:21 +05:30
ManishMadan2882
8c3f75e3e2 (feat:ingestion) google drive loader 2025-08-22 13:32:40 +05:30
GH Action - Upstream Sync
193d59f193 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-08-22 01:38:59 +00:00
ManishMadan2882
c2bebbaefa (feat:oauth/drive) raw fe integrate 2025-08-22 03:29:57 +05:30
Alex
7ae5a9c5a5 Refactor diagramId initialization to use a combination of Date.now() and random string for uniqueness 2025-08-21 14:50:37 +01:00
ManishMadan2882
3b69bea23d (chore:settings)addefault oath creds 2025-08-21 17:02:23 +05:30
ManishMadan2882
ab05726b99 Merge branch 'main' of https://github.com/manishmadan2882/docsgpt 2025-08-21 02:46:56 +05:30
ManishMadan2882
b2b04268e9 (feat:drive) oauth flow 2025-08-21 02:46:32 +05:30
Siddhant Rai
bd73fa9ae7 refactor: remove unused abstract method and improve retrievers 2025-08-20 22:25:31 +05:30
Alex
927d10d66e Update README.md 2025-08-17 12:44:47 +03:00
Alex
b67329623c Update README.md 2025-08-16 13:51:40 +03:00
Ankit Matth
6f47aa802b added support for multi select sources 2025-08-16 15:19:19 +05:30
Ankit Matth
3417c73011 fix(ui): resolve multiple UI/UX issues 2025-08-15 10:14:31 +05:30
Alex
6a02bcf15b Merge pull request #1873 from ManishMadan2882/main
Sources are the new Docs
2025-08-13 18:24:35 +01:00
Alex
cd0fbf79a3 Merge pull request #1924 from siiddhantt/feat/agent-schema-response
feat: add support for structured output and JSON schema validation
2025-08-13 17:33:29 +01:00
Alex
15d2d0115b Merge branch 'main' into feat/agent-schema-response 2025-08-13 17:12:26 +01:00
ManishMadan2882
d1a0fe6e91 Merge branch 'main' of https://github.com/manishmadan2882/docsgpt 2025-08-13 17:39:59 +05:30
ManishMadan2882
1db80d140f (fix) search dropdowns 2025-08-13 17:39:39 +05:30
Siddhant Rai
896dcf1f9e feat: add support for structured output and JSON schema validation 2025-08-13 13:29:51 +05:30
GH Action - Upstream Sync
819a12fb49 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-08-13 01:45:41 +00:00
Alex
c68273706c Merge pull request #1920 from hanzalahwaheed/fix/response-bubble-btns
fix: always show the response bubble buttons.
2025-08-13 00:14:01 +01:00
Hanzalah Waheed
6bb0cd535a fix: rm redundant states. track feedback state w prop var 2025-08-13 02:36:58 +04:00
Hanzalah Waheed
cb9ec69cf6 chore: refactor code to use ternary operator for error type check 2025-08-13 02:31:25 +04:00
Hanzalah Waheed
143854fa81 fix: show both like and dislike buttons 2025-08-13 02:11:47 +04:00
ManishMadan2882
2f48a3d7d5 (feat:chunks) consistent path header 2025-08-13 02:53:32 +05:30
ManishMadan2882
ec95dafe1e (feat:sources) matching the figma 2025-08-13 01:35:23 +05:30
ManishMadan2882
3d1fe724e5 Merge branch 'main' of https://github.com/manishmadan2882/docsgpt 2025-08-12 18:05:11 +05:30
ManishMadan2882
5c615d6f2d (feat:sources) card ui 2025-08-12 18:04:40 +05:30
GH Action - Upstream Sync
d72558eb36 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-08-12 01:43:57 +00:00
ManishMadan2882
65c33ad915 Merge branch 'main' of https://github.com/manishmadan2882/docsgpt 2025-08-12 03:09:13 +05:30
ManishMadan2882
9be128a963 (feat:sources) closer to figma,ux 2025-08-12 03:08:47 +05:30
Hanzalah Waheed
eb05132008 fix: always show the response bubble buttons. 2025-08-12 00:16:21 +04:00
Alex
f94a093e8c fix: truncate long text fields to prevent overflow in logs and sources 2025-08-11 14:56:31 +01:00
GH Action - Upstream Sync
0d0c2daf64 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-08-09 01:44:41 +00:00
ManishMadan2882
823d948b25 Merge branch 'main' of https://github.com/manishmadan2882/docsgpt 2025-08-08 21:59:01 +05:30
Alex
56831fbcf2 Merge pull request #1917 from ManishMadan2882/fix/agent_prompts
Fixes missing attributes on shared agents
2025-08-08 16:30:09 +01:00
ManishMadan2882
bf49b9cb88 (fix/shared_agent)missing main attr 2025-08-08 20:44:09 +05:30
GH Action - Upstream Sync
e01adffbad Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-08-08 01:54:52 +00:00
Alex
08a5d52d82 Update README.md 2025-08-07 17:20:45 +03:00
ManishMadan2882
fdae235742 (feat:sources) i18n 2025-08-07 12:53:12 +05:30
GH Action - Upstream Sync
9903fad1e9 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-08-07 01:55:18 +00:00
Alex
14bbd5338d Merge pull request #1909 from siiddhantt/main
fix: remove unnecessary parameter from fetchPreviewAnswer
2025-08-06 19:21:59 +01:00
Siddhant Rai
4a236c2f6f fix: remove unnecessary parameter from fetchPreviewAnswer 2025-08-06 22:34:06 +05:30
Alex
0a8cdbd7f1 fix: update token selector in FileTreeComponent 2025-08-06 12:44:21 +01:00
Alex
94c49843be Merge pull request #1906 from arc53/feat/pg-vector
feat: implement PGVectorStore for PostgreSQL vector storage
2025-08-06 10:42:34 +01:00
Alex
9281fac898 fix: improve error logging for index creation and add PARSE_IMAGE_REMOTE setting 2025-08-06 10:40:20 +01:00
GH Action - Upstream Sync
0b2736f454 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-08-06 01:55:04 +00:00
ManishMadan2882
ae116b0d0d (fix:agentPreview) build err 2025-08-06 02:55:39 +05:30
ManishMadan2882
ba260e3382 (fix:faiss) not save tmp dir 2025-08-06 02:53:39 +05:30
Alex
1282e7687f fix: add error handling for index creation in user and agents collections 2025-08-05 17:19:06 +01:00
Alex
b1d8266eef feat: implement PGVectorStore for PostgreSQL vector storage 2025-08-05 13:54:39 +01:00
Alex
7acae6935b Merge pull request #1905 from arc53/fix-qdrant
fix: qdrant issues
2025-08-05 12:26:24 +01:00
Alex
092c01cae7 fix: ruff lint 2025-08-05 12:22:33 +01:00
Alex
56a1066c30 fix: qdrant issues 2025-08-05 12:19:18 +01:00
ManishMadan2882
1356d71839 (lint) ruff fix 2025-08-05 15:37:39 +05:30
ManishMadan2882
1eb011e8c3 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-08-05 15:28:51 +05:30
ManishMadan2882
e349eb28b0 (fix:update_chunk) data integrity, uplod back faiss 2025-08-05 06:31:00 +05:30
ManishMadan2882
b000b235a2 (feat:sources) renamed docs,fe 2025-08-05 05:23:29 +05:30
ManishMadan2882
16fe92282e (fix:chunks) responsive, editing controls 2025-08-05 03:04:37 +05:30
ManishMadan2882
e218e88cf4 (fix:chunks) alignment and ui 2025-08-04 19:28:15 +05:30
ManishMadan2882
888ea81a32 (feat:fil_management) serialising updates, queue 2025-08-04 17:27:12 +05:30
ManishMadan2882
735fab7640 (feat:storage) sync base 2025-08-04 16:36:38 +05:30
ManishMadan2882
45745c2a47 (feat:docs) skeleton loader 2025-08-04 16:35:24 +05:30
Alex
4caff0fcf6 fix: enhance error logging for malformed request in stream route 2025-08-04 11:41:41 +01:00
Alex
762ea6ce7f Merge pull request #1866 from arc53/dependabot/npm_and_yarn/frontend/tailwindcss-4.1.11
build(deps-dev): bump tailwindcss from 4.1.10 to 4.1.11 in /frontend
2025-08-02 23:49:54 +01:00
ManishMadan2882
8b4f6553f3 (fix:menu) left and right 2025-08-02 02:08:35 +05:30
ManishMadan2882
a61e44d175 (feat:dir_tree) improvement 2025-08-02 01:48:43 +05:30
ManishMadan2882
e1b1558fc9 (feat:storage) rm dir 2025-08-02 00:54:09 +05:30
ManishMadan2882
53225bda4e (feat:reingestion) spit directories 2025-08-02 00:49:15 +05:30
ManishMadan2882
5212769848 (feat:reingest) UI, polling 2025-08-01 01:25:37 +05:30
ManishMadan2882
d5ded3c9f4 (feat:reingest) eat and spit specific chunks 2025-08-01 01:14:48 +05:30
ManishMadan2882
c92d778894 (feat:chunker) do not combine text 2025-07-31 02:13:55 +05:30
ManishMadan2882
829abd1ad6 (fix:textarea) consistent line indxs 2025-07-30 21:00:00 +05:30
ManishMadan2882
266d256a07 (feat:sources) management, simple re-ingest 2025-07-30 01:57:40 +05:30
Alex
8380cac3e7 Merge pull request #1900 from naaa760/docs/auth-type-configuration
Docs: Expand and Clarify AUTH_TYPE Configuration and Authentication Methods (#1882)
2025-07-28 23:07:40 +01:00
ManishMadan2882
a24652f901 (feat:chunks) update iff changed 2025-07-28 19:35:41 +05:30
ManishMadan2882
2d203d3c70 (fix:chunks)responsive 2025-07-28 18:01:51 +05:30
Alex
48d21600da Merge pull request #1896 from siiddhantt/feat/rework-answer-routes
feat: answer routes re-structure for better maintainability and reuse
2025-07-26 14:18:10 +01:00
ManishMadan2882
2508d0fbb3 (fix:chunks) preserve paths 2025-07-26 00:27:39 +05:30
ManishMadan2882
e90e80c289 (fix:chunks) also count tokens 2025-07-26 00:16:45 +05:30
ManishMadan2882
5e4748f9d9 (fix:faiss) rely on storage abstrct 2025-07-26 00:14:17 +05:30
Siddhant Rai
212952f3e9 fix: allow api call in stream route + get_prompt error 2025-07-25 16:17:18 +05:30
naaa760
f99b6496c5 update 2025-07-25 14:48:49 +05:30
ManishMadan2882
67423d51b9 (feat:chunks) ask to edit, ui 2025-07-25 04:05:06 +05:30
ManishMadan2882
58465ece65 (feat:chunks) server-side filter on search 2025-07-25 01:43:50 +05:30
ManishMadan2882
8ede3a0173 Merge branch 'main' of https://github.com/manishmadan2882/docsgpt 2025-07-23 20:48:00 +05:30
ManishMadan2882
ad2f0f8950 (chore:chunks) i18n 2025-07-23 20:47:36 +05:30
Siddhant Rai
76973a4b4c feat: answer routes re-structure for better maintainability and reuse 2025-07-23 20:07:42 +05:30
GH Action - Upstream Sync
b198e2e029 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-07-23 01:52:41 +00:00
ManishMadan2882
4d6ea401b5 (feat:chunks) line numbered editor 2025-07-23 03:36:18 +05:30
ManishMadan2882
b00c4cc3b6 (feat:chunk) editing mode 2025-07-23 02:22:56 +05:30
Alex
4185e64c65 Merge pull request #1893 from arc53/fix-attachment-bugs
fix: replace secure_filename with safe_filename for attachment handling
2025-07-22 16:24:55 +01:00
ManishMadan2882
6eb2c884a2 (refactor) separation in chunks/files view 2025-07-22 19:36:52 +05:30
Alex
6c0362a4cf fix: replace secure_filename with safe_filename for attachment handling 2025-07-22 12:56:17 +01:00
ManishMadan2882
50b1755a63 Merge branch 'main' of https://github.com/manishmadan2882/docsgpt 2025-07-21 16:31:57 +05:30
ManishMadan2882
ff3c7eb5fb (fix:delete_old) comply with storage abtrctn 2025-07-21 16:31:42 +05:30
ManishMadan2882
3755316d49 (fix:chunks) responsive design 2025-07-21 16:30:30 +05:30
GH Action - Upstream Sync
f952046847 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-07-19 01:47:09 +00:00
Alex
969cdb4a63 Merge pull request #1890 from siiddhantt/feat/enhance-agents
feat: enhance prompt selection in new agents
2025-07-18 13:07:45 +01:00
ManishMadan2882
f336d44595 (feat:chunks) search in dir 2025-07-18 15:03:23 +05:30
Siddhant Rai
a53f93c195 feat: enhance dropdown component and prompts integration 2025-07-18 14:02:29 +05:30
GH Action - Upstream Sync
fcb334ce33 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-07-17 01:50:59 +00:00
ManishMadan2882
8ddf04a904 (feat:chunks) use common header, navigate 2025-07-17 03:08:01 +05:30
ManishMadan2882
29698ca169 (feat:chunks) redesigned 2025-07-17 02:16:40 +05:30
ManishMadan2882
a9baf7436a Merge branch 'main' of https://github.com/manishmadan2882/docsgpt 2025-07-17 01:05:59 +05:30
ManishMadan2882
99a8962183 (fix/docs) menu event capture 2025-07-17 01:05:24 +05:30
Alex
afc5b15a6b Merge pull request #1887 from Krrish0902/fix-issue-1854
fix: Removed incorrect www from URL in DocsGPT Docs frontend (#1854)
2025-07-16 13:20:05 +01:00
GH Action - Upstream Sync
b6ab508e27 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-07-16 01:50:25 +00:00
ananthakrishnan
789e65557a fix: Removed incorrect www from URL in DocsGPT Docs frontend (#1854) 2025-07-15 22:02:02 +05:30
ManishMadan2882
8a7806ab2d (feat:nested source view) file tree, chunks display 2025-07-15 19:15:40 +05:30
Alex
493303e103 Merge pull request #1886 from arc53/copilot/fix-1878
🐛 Fix conversation summary prompt to use user query language
2025-07-15 12:57:14 +01:00
ManishMadan2882
1d9af05e9e (feat:storage) is dir fnc 2025-07-15 15:34:16 +05:30
ManishMadan2882
5b07c5f2e8 (feat:ingestion) unzip, extract and store 2025-07-15 15:31:26 +05:30
copilot-swe-agent[bot]
2a4ec0cf5b Fix conversation summary prompt to use user query language
Co-authored-by: dartpain <15183589+dartpain@users.noreply.github.com>
2025-07-15 09:33:52 +00:00
copilot-swe-agent[bot]
a00c44386e Initial plan 2025-07-15 09:29:22 +00:00
ManishMadan2882
a38d71bbfb (feat:get_chunks) filtered by relative path 2025-07-15 13:38:19 +05:30
ManishMadan2882
a24a3f868c (feat:dir-structure) adding route 2025-07-15 13:38:19 +05:30
ManishMadan2882
f60c516185 (feat:dir_tree) table with folder contents 2025-07-15 13:38:19 +05:30
Manish Madan
26f4646304 Merge branch 'arc53:main' into main 2025-07-14 23:02:36 +05:30
ananthakrishnan
3a351f67e6 fix: correct agent tools name when creating new agent (#1877) 2025-07-14 20:20:55 +05:30
dependabot[bot]
e7c09cb91e build(deps-dev): bump tailwindcss from 4.1.10 to 4.1.11 in /frontend
Bumps [tailwindcss](https://github.com/tailwindlabs/tailwindcss/tree/HEAD/packages/tailwindcss) from 4.1.10 to 4.1.11.
- [Release notes](https://github.com/tailwindlabs/tailwindcss/releases)
- [Changelog](https://github.com/tailwindlabs/tailwindcss/blob/main/CHANGELOG.md)
- [Commits](https://github.com/tailwindlabs/tailwindcss/commits/v4.1.11/packages/tailwindcss)

---
updated-dependencies:
- dependency-name: tailwindcss
  dependency-version: 4.1.11
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-07-14 09:02:48 +00:00
Alex
ae1a6ef303 Merge pull request #1883 from siiddhantt/feat/agent-validation-enhance
feat: improve interactivity of publishing/drafting of agents
2025-07-14 09:58:12 +01:00
Siddhant Rai
2ff477a339 feat(agent): enhance validation for agent creation by checking required and invalid fields 2025-07-14 12:53:09 +05:30
Siddhant Rai
793f3fb683 refactor(NewAgent): remove debug logs 2025-07-12 12:36:18 +05:30
Siddhant Rai
a472ee7602 feat: add validation for required fields and improve agent creation logic 2025-07-12 12:30:00 +05:30
GH Action - Upstream Sync
c62040e232 Merge branch 'main' of https://github.com/arc53/DocsGPT 2025-07-11 01:49:09 +00:00
Alex
2e7cb510ae Update README.md 2025-07-10 17:27:08 +03:00
Alex
dbe45904d7 Merge pull request #1881 from siiddhantt/fix/glacier-images
fix: s3 storage class for image upload
2025-07-10 12:00:55 +01:00
Siddhant Rai
5623734276 feat(storage): enhance save_file method to accept storage class parameter 2025-07-10 15:34:52 +05:30
ManishMadan2882
d3b592bffc Merge branch 'main' of https://github.com/manishmadan2882/docsgpt 2025-07-09 01:29:59 +05:30
ManishMadan2882
4fcbdae5bf (feat:docs) cards UI 2025-07-08 02:46:34 +05:30
ManishMadan2882
ca95d7275a (feat:dateTimeUtils) localise weekday format 2025-07-08 02:32:18 +05:30
Manish Madan
61baf3701c Merge branch 'arc53:main' into main 2025-07-04 02:26:24 +05:30
ManishMadan2882
bbce872ac5 (fix:chunker) combine metadata as well 2025-07-04 02:19:58 +05:30
ManishMadan2882
0f7ebcd8e4 (feat:dir-reader) store mime types, file size in db 2025-07-03 18:09:19 +05:30
ManishMadan2882
82fc19e7b7 (fix:dir-reader) conflict of same filename in dir 2025-07-03 17:28:12 +05:30
Alex
839a12bed4 Merge pull request #1869 from siiddhantt/refactor/tools-dict
refactor: update user tools dict to use enumeration based key
2025-07-03 12:51:33 +09:00
ManishMadan2882
2ef23fe1b3 (feat:dir-reader) maintain dir structure in db 2025-07-03 01:24:22 +05:30
ManishMadan2882
fd905b1a06 (feat:dir-reader) save tokens with filenames 2025-07-02 16:30:29 +05:30
Siddhant Rai
1372210004 refactor: update user tools dict to use enumeration based key 2025-07-02 10:37:32 +05:30
ManishMadan2882
ade704d065 (refactor:ingestion) pass file path once 2025-07-01 04:00:57 +05:30
Alex
42f48649b9 Merge pull request #1843 from arc53/dependabot/npm_and_yarn/frontend/tailwindcss-4.1.10
build(deps-dev): bump tailwindcss from 3.4.17 to 4.1.10 in /frontend
2025-06-28 13:22:47 +09:00
ManishMadan2882
0b08e8b617 (fix:nav) settings gear dimesions 2025-06-27 22:16:01 +05:30
ManishMadan2882
926b2f1a1b clean 2025-06-25 19:03:24 +05:30
ManishMadan2882
1770a1a45f Merge branch 'main' of https://github.com/arc53/DocsGPT into dependabot/npm_and_yarn/frontend/tailwindcss-4.1.10 2025-06-25 18:59:51 +05:30
ManishMadan2882
50ed2a64c6 (fix/ddropdown) use complete classNames, interpolation 2025-06-25 18:26:33 +05:30
Alex
2332344988 Merge pull request #1861 from arc53/docs-agents-update
Agent docs upd
2025-06-25 09:23:09 +01:00
Alex
7ccc8cdc58 Merge pull request #1855 from siiddhantt/refactor/ddg-brave-tools
refactor: ddg and brave tools with sources fix
2025-06-25 09:21:38 +01:00
ManishMadan2882
ecec9f913e (fix:messageInput) modern tailwind syntx 2025-06-25 02:15:03 +05:30
ManishMadan2882
777f40fc5e (fix:conflicting global css) layered styles 2025-06-25 01:11:39 +05:30
Pavel
327ae35420 Agent docs upd
1. Added a page about interacting with agent API.
2. Added a page about interacting with agent webhooks.
3. Fixed small bug with /api/answer
2025-06-24 16:48:12 +02:00
Siddhant Rai
0d48159da8 feat: enhance modal functionality with reset and confirmation handlers 2025-06-24 02:14:15 +05:30
Siddhant Rai
d36f12a4ea feat: add DuckDuckGo icon and remove sources skeleton 2025-06-24 02:11:58 +05:30
Siddhant Rai
709488beb1 fix: sources not getting set on stream end 2025-06-23 09:23:18 +05:30
Siddhant Rai
a9e4583695 refactor: use DuckDuckGo and Brave as tools instead of retrievers 2025-06-23 09:22:17 +05:30
ManishMadan2882
4702dec933 (chore:tailwindcss) via upgrade tool 2025-06-22 18:31:26 +05:30
ManishMadan2882
e6352dd691 Revert "build(deps-dev): bump tailwindcss from 3.4.17 to 4.1.10 in /frontend"
This reverts commit 240ea3b857.
2025-06-22 18:16:41 +05:30
dependabot[bot]
240ea3b857 build(deps-dev): bump tailwindcss from 3.4.17 to 4.1.10 in /frontend
Bumps [tailwindcss](https://github.com/tailwindlabs/tailwindcss/tree/HEAD/packages/tailwindcss) from 3.4.17 to 4.1.10.
- [Release notes](https://github.com/tailwindlabs/tailwindcss/releases)
- [Changelog](https://github.com/tailwindlabs/tailwindcss/blob/main/CHANGELOG.md)
- [Commits](https://github.com/tailwindlabs/tailwindcss/commits/v4.1.10/packages/tailwindcss)

---
updated-dependencies:
- dependency-name: tailwindcss
  dependency-version: 4.1.10
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-06-22 12:08:36 +00:00
Alex
f0908af3c0 Merge pull request #1829 from arc53/dependabot/npm_and_yarn/frontend/multi-d4a34be08c
build(deps): bump react and @types/react in /frontend
2025-06-22 12:07:44 +01:00
ManishMadan2882
6834961dd1 (fix:types) stricter in v19 2025-06-20 23:11:53 +05:30
ManishMadan2882
7077ca5e98 (chore/upgrade) migrate to react v19 2025-06-20 17:36:28 +05:30
dependabot[bot]
bab3ae809c build(deps): bump react and @types/react in /frontend
Bumps [react](https://github.com/facebook/react/tree/HEAD/packages/react) and [@types/react](https://github.com/DefinitelyTyped/DefinitelyTyped/tree/HEAD/types/react). These dependencies needed to be updated together.

Updates `react` from 18.3.1 to 19.1.0
- [Release notes](https://github.com/facebook/react/releases)
- [Changelog](https://github.com/facebook/react/blob/main/CHANGELOG.md)
- [Commits](https://github.com/facebook/react/commits/v19.1.0/packages/react)

Updates `@types/react` from 18.3.23 to 19.1.6
- [Release notes](https://github.com/DefinitelyTyped/DefinitelyTyped/releases)
- [Commits](https://github.com/DefinitelyTyped/DefinitelyTyped/commits/HEAD/types/react)

---
updated-dependencies:
- dependency-name: react
  dependency-version: 19.1.0
  dependency-type: direct:production
  update-type: version-update:semver-major
- dependency-name: "@types/react"
  dependency-version: 19.1.6
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-06-19 11:02:29 +00:00
383 changed files with 49519 additions and 14533 deletions

View File

@@ -13,7 +13,11 @@ updates:
directory: "/frontend" # Location of package manifests
schedule:
interval: "daily"
- package-ecosystem: "npm"
directory: "/extensions/react-widget"
schedule:
interval: "daily"
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "daily"
interval: "daily"

11
.github/styles/DocsGPT/Spelling.yml vendored Normal file
View File

@@ -0,0 +1,11 @@
extends: spelling
level: warning
message: "Did you really mean '%s'?"
ignore:
- "**/node_modules/**"
- "**/dist/**"
- "**/build/**"
- "**/coverage/**"
- "**/public/**"
- "**/static/**"
vocab: DocsGPT

View File

@@ -0,0 +1,46 @@
Ollama
Qdrant
Milvus
Chatwoot
Nextra
VSCode
npm
LLMs
APIs
Groq
SGLang
LMDeploy
OAuth
Vite
LLM
JSONPath
UIs
configs
uncomment
qdrant
vectorstore
docsgpt
llm
GPUs
kubectl
Lightsail
enqueues
chatbot
VSCode's
Shareability
feedbacks
automations
Premade
Signup
Repo
repo
env
URl
agentic
llama_cpp
parsable
SDKs
boolean
bool
hardcode
EOL

View File

@@ -16,15 +16,15 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pytest-cov
cd application
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
cd ../tests
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest and generate coverage report
run: |
python -m pytest --cov=application --cov-report=xml
python -m pytest --cov=application --cov-report=xml --cov-report=term-missing
- name: Upload coverage reports to Codecov
if: github.event_name == 'pull_request' && matrix.python-version == '3.12'
uses: codecov/codecov-action@v5
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

26
.github/workflows/vale.yml vendored Normal file
View File

@@ -0,0 +1,26 @@
name: Vale Documentation Linter
on:
pull_request:
paths:
- 'docs/**/*.md'
- 'docs/**/*.mdx'
- '**/*.md'
- '.vale.ini'
- '.github/styles/**'
jobs:
vale:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Vale linter
uses: errata-ai/vale-action@v2
with:
files: docs
fail_on_error: false
version: 3.0.5
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

2
.gitignore vendored
View File

@@ -2,7 +2,9 @@
__pycache__/
*.py[cod]
*$py.class
experiments/
experiments
# C extensions
*.so
*.next

5
.vale.ini Normal file
View File

@@ -0,0 +1,5 @@
MinAlertLevel = warning
StylesPath = .github/styles
[*.{md,mdx}]
BasedOnStyles = DocsGPT

33
.vscode/launch.json vendored
View File

@@ -2,15 +2,11 @@
"version": "0.2.0",
"configurations": [
{
"name": "Docker Debug Frontend",
"name": "Frontend Debug (npm)",
"type": "node-terminal",
"request": "launch",
"type": "chrome",
"preLaunchTask": "docker-compose: debug:frontend",
"url": "http://127.0.0.1:5173",
"webRoot": "${workspaceFolder}/frontend",
"skipFiles": [
"<node_internals>/**"
]
"command": "npm run dev",
"cwd": "${workspaceFolder}/frontend"
},
{
"name": "Flask Debugger",
@@ -49,6 +45,27 @@
"--pool=solo"
],
"cwd": "${workspaceFolder}"
},
{
"name": "Dev Containers (Mongo + Redis)",
"type": "node-terminal",
"request": "launch",
"command": "docker compose -f deployment/docker-compose-dev.yaml up --build",
"cwd": "${workspaceFolder}"
}
],
"compounds": [
{
"name": "DocsGPT: Full Stack",
"configurations": [
"Frontend Debug (npm)",
"Flask Debugger",
"Celery Debugger"
],
"presentation": {
"group": "DocsGPT",
"order": 1
}
}
]
}

21
.vscode/tasks.json vendored
View File

@@ -1,21 +0,0 @@
{
"version": "2.0.0",
"tasks": [
{
"type": "docker-compose",
"label": "docker-compose: debug:frontend",
"dockerCompose": {
"up": {
"detached": true,
"services": [
"frontend"
],
"build": true
},
"files": [
"${workspaceFolder}/docker-compose.yaml"
]
}
}
]
}

View File

@@ -147,5 +147,5 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
Thank you for considering contributing to DocsGPT! 🙏
## Questions/collaboration
Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
# Thank you so much for considering to contributing DocsGPT!🙏

39
HACKTOBERFEST.md Normal file
View File

@@ -0,0 +1,39 @@
# **🎉 Join the Hacktoberfest with DocsGPT and win a Free T-shirt for a meaningful PR! 🎉**
Welcome, contributors! We're excited to announce that DocsGPT is participating in Hacktoberfest. Get involved by submitting meaningful pull requests.
All Meaningful contributors with accepted PRs that were created for issues with the `hacktoberfest` label (set by our maintainer team: dartpain, siiddhantt, pabik, ManishMadan2882) will receive a cool T-shirt! 🤩.
<img width="1331" height="678" alt="hacktoberfest-mocks-preview" src="https://github.com/user-attachments/assets/633f6377-38db-48f5-b519-a8b3855a9eb4" />
Fill in [this form](https://forms.gle/Npaba4n9Epfyx56S8
) after your PR was merged please
If you are in doubt don't hesitate to ping us on discord, ping me - Alex (dartpain).
## 📜 Here's How to Contribute:
```text
🛠️ Code: This is the golden ticket! Make meaningful contributions through PRs.
🧩 API extension: Build an app utilising DocsGPT API. We prefer submissions that showcase original ideas and turn the API into an AI agent.
They can be a completely separate repos.
For example:
https://github.com/arc53/tg-bot-docsgpt-extenstion or
https://github.com/arc53/DocsGPT-cli
Non-Code Contributions:
📚 Wiki: Improve our documentation, create a guide.
🖥️ Design: Improve the UI/UX or design a new feature.
```
### 📝 Guidelines for Pull Requests:
- Familiarize yourself with the current contributions and our [Roadmap](https://github.com/orgs/arc53/projects/2).
- Before contributing check existing [issues](https://github.com/arc53/DocsGPT/issues) or [create](https://github.com/arc53/DocsGPT/issues/new/choose) an issue and wait to get assigned.
- Once you are finished with your contribution, please fill in this [form](https://forms.gle/Npaba4n9Epfyx56S8).
- Refer to the [Documentation](https://docs.docsgpt.cloud/).
- Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj) server. We're here to help newcomers, so don't hesitate to jump in! Join us [here](https://discord.gg/vN7YFfdMpj).
Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your contributions (not just simple typos) could earn you a stylish new t-shirt.
We will publish a t-shirt design later into the October.

View File

@@ -3,11 +3,11 @@
</h1>
<p align="center">
<strong>Open-Source RAG Assistant</strong>
<strong>Private AI for agents, assistants and enterprise search</strong>
</p>
<p align="left">
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source genAI tool that helps users get reliable answers from any knowledge source, while avoiding hallucinations. It enables quick and reliable information retrieval, with tooling and agentic system capability built in.
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
</p>
<div align="center">
@@ -16,16 +16,19 @@
<a href="https://github.com/arc53/DocsGPT">![link to main GitHub showing Forks number](https://img.shields.io/github/forks/arc53/docsgpt?style=social)</a>
<a href="https://github.com/arc53/DocsGPT/blob/main/LICENSE">![link to license file](https://img.shields.io/github/license/arc53/docsgpt)</a>
<a href="https://www.bestpractices.dev/projects/9907"><img src="https://www.bestpractices.dev/projects/9907/badge"></a>
<a href="https://discord.gg/n5BX8dh8rU">![link to discord](https://img.shields.io/discord/1070046503302877216)</a>
<a href="https://twitter.com/docsgptai">![X (formerly Twitter) URL](https://img.shields.io/twitter/follow/docsgptai)</a>
<a href="https://discord.gg/vN7YFfdMpj">![link to discord](https://img.shields.io/discord/1070046503302877216)</a>
<a href="https://x.com/docsgptai">![X (formerly Twitter) URL](https://img.shields.io/twitter/follow/docsgptai)</a>
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a><a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a><a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
<br>
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a><a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a><a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
<br>
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a><a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a><a href="https://discord.gg/vN7YFfdMpj">💬 Discord</a>
<br>
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a><a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a><a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
<br>
</div>
<div align="center">
<br>
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
</div>
<h3 align="left">
@@ -53,9 +56,13 @@
- [x] New input box in the conversation menu (April 2025)
- [x] Add triggerable actions / tools (webhook) (April 2025)
- [x] Agent optimisations (May 2025)
- [ ] Anthropic Tool compatibility (June 2025)
- [ ] MCP support (June 2025)
- [ ] Add OAuth 2.0 authentication for tools and sources (July 2025)
- [x] Filesystem sources update (July 2025)
- [x] Json Responses (August 2025)
- [x] MCP support (August 2025)
- [x] Google Drive integration (September 2025)
- [x] Add OAuth 2.0 authentication for MCP (September 2025)
- [ ] SharePoint integration (October 2025)
- [ ] Deep Agents (October 2025)
- [ ] Agent scheduling
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
@@ -70,11 +77,10 @@ We're eager to provide personalized assistance when deploying your DocsGPT to a
## Join the Lighthouse Program 🌟
Calling all developers and GenAI innovators! The **DocsGPT Lighthouse Program** connects technical leaders actively deploying or extending DocsGPT in real-world scenarios. Collaborate directly with our team to shape the roadmap, access priority support, and build enterprise-ready solutions with exclusive community insights.
Calling all developers and GenAI innovators! The **DocsGPT Lighthouse Program** connects technical leaders actively deploying or extending DocsGPT in real-world scenarios. Collaborate directly with our team to shape the roadmap, access priority support, and build enterprise-ready solutions with exclusive community insights.
[Learn More & Apply →](https://docs.google.com/forms/d/1KAADiJinUJ8EMQyfTXUIGyFbqINNClNR3jBNWq7DgTE)
## QuickStart
> [!Note]
@@ -105,7 +111,7 @@ A more detailed [Quickstart](https://docs.docsgpt.cloud/quickstart) is available
PowerShell -ExecutionPolicy Bypass -File .\setup.ps1
```
Either script will guide you through setting up DocsGPT. Four options available: using the public API, running locally, connecting to a local inference engine, or using a cloud API provider. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
Either script will guide you through setting up DocsGPT. Five options available: using the public API, running locally, connecting to a local inference engine, using a cloud API provider, or build the docker image locally. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
**Navigate to http://localhost:5173/**
@@ -114,6 +120,7 @@ To stop DocsGPT, open a terminal in the `DocsGPT` directory and run:
```bash
docker compose -f deployment/docker-compose.yaml down
```
(or use the specific `docker compose down` command shown after running the setup script).
> [!Note]
@@ -141,7 +148,6 @@ Please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file for information abou
We as members, contributors, and leaders, pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. Please refer to the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file for more information about contributing.
## Many Thanks To Our Contributors⚡
<a href="https://github.com/arc53/DocsGPT/graphs/contributors" alt="View Contributors">

View File

@@ -1,5 +1,8 @@
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
import logging
logger = logging.getLogger(__name__)
class AgentCreator:
@@ -13,4 +16,5 @@ class AgentCreator:
agent_class = cls.agents.get(type.lower())
if not agent_class:
raise ValueError(f"No agent class found for type {type}")
return agent_class(*args, **kwargs)

View File

@@ -1,3 +1,4 @@
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional
@@ -6,14 +7,13 @@ from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
from application.logging import build_stack_data, log_activity, LogContext
from application.retriever.base import BaseRetriever
logger = logging.getLogger(__name__)
class BaseAgent(ABC):
@@ -21,22 +21,29 @@ class BaseAgent(ABC):
self,
endpoint: str,
llm_name: str,
gpt_model: str,
model_id: str,
api_key: str,
user_api_key: Optional[str] = None,
prompt: str = "",
chat_history: Optional[List[Dict]] = None,
retrieved_docs: Optional[List[Dict]] = None,
decoded_token: Optional[Dict] = None,
attachments: Optional[List[Dict]] = None,
json_schema: Optional[Dict] = None,
limited_token_mode: Optional[bool] = False,
token_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["token_limit"],
limited_request_mode: Optional[bool] = False,
request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
compressed_summary: Optional[str] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
self.gpt_model = gpt_model
self.model_id = model_id
self.api_key = api_key
self.user_api_key = user_api_key
self.prompt = prompt
self.decoded_token = decoded_token or {}
self.user: str = decoded_token.get("sub")
self.user: str = self.decoded_token.get("sub")
self.tool_config: Dict = {}
self.tools: List[Dict] = []
self.tool_calls: List[Dict] = []
@@ -46,21 +53,31 @@ class BaseAgent(ABC):
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
)
self.retrieved_docs = retrieved_docs or []
self.llm_handler = LLMHandlerCreator.create_handler(
llm_name if llm_name else "default"
)
self.attachments = attachments or []
self.json_schema = json_schema
self.limited_token_mode = limited_token_mode
self.token_limit = token_limit
self.limited_request_mode = limited_request_mode
self.request_limit = request_limit
self.compressed_summary = compressed_summary
self.current_token_count = 0
self.context_limit_reached = False
@log_activity()
def gen(
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
self, query: str, log_context: LogContext = None
) -> Generator[Dict, None, None]:
yield from self._gen_inner(query, retriever, log_context)
yield from self._gen_inner(query, log_context)
@abstractmethod
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
pass
@@ -91,8 +108,8 @@ class BaseAgent(ABC):
user_tools_collection = db["user_tools"]
user_tools = user_tools_collection.find({"user": user, "status": True})
user_tools = list(user_tools)
tools_by_id = {str(tool["_id"]): tool for tool in user_tools}
return tools_by_id
return {str(i): tool for i, tool in enumerate(user_tools)}
def _build_tool_parameters(self, action):
params = {"type": "object", "properties": {}, "required": []}
@@ -137,6 +154,41 @@ class BaseAgent(ABC):
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
# Check if parsing failed
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": getattr(call, "name", "unknown"),
"arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return "Failed to parse tool call.", call_id
# Check if tool_id exists in available tools
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(error_message)
# Return error result
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return f"Tool with ID {tool_id} not found.", call_id
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
@@ -176,18 +228,26 @@ class BaseAgent(ABC):
):
target_dict[param] = value
tm = ToolManager(config={})
# Prepare tool_config and add tool_id for memory tools
if tool_data["name"] == "api_tool":
tool_config = {
"url": tool_data["config"]["actions"][action_name]["url"],
"method": tool_data["config"]["actions"][action_name]["method"],
"headers": headers,
"query_params": query_params,
}
else:
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
tool = tm.load_tool(
tool_data["name"],
tool_config=(
{
"url": tool_data["config"]["actions"][action_name]["url"],
"method": tool_data["config"]["actions"][action_name]["method"],
"headers": headers,
"query_params": query_params,
}
if tool_data["name"] == "api_tool"
else tool_data["config"]
),
tool_config=tool_config,
user_id=self.user, # Pass user ID for MCP tools credential decryption
)
if tool_data["name"] == "api_tool":
print(
@@ -220,20 +280,83 @@ class BaseAgent(ABC):
for tool_call in self.tool_calls
]
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
"""
Calculate total tokens in current context (messages).
Args:
messages: List of message dicts
Returns:
Total token count
"""
from application.api.answer.services.compression.token_counter import (
TokenCounter,
)
return TokenCounter.count_message_tokens(messages)
def _check_context_limit(self, messages: List[Dict]) -> bool:
"""
Check if we're approaching context limit (80%).
Args:
messages: Current message list
Returns:
True if at or above 80% of context limit
"""
from application.core.model_utils import get_token_limit
from application.core.settings import settings
try:
# Calculate current tokens
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
# Get context limit for model
context_limit = get_token_limit(self.model_id)
# Calculate threshold (80%)
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
# Check if we've reached the limit
if current_tokens >= threshold:
logger.warning(
f"Context limit approaching: {current_tokens}/{context_limit} tokens "
f"({(current_tokens/context_limit)*100:.1f}%)"
)
return True
return False
except Exception as e:
logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
return False
def _build_messages(
self,
system_prompt: str,
query: str,
retrieved_data: List[Dict],
) -> List[Dict]:
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
p_chat_combine = system_prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
"""Build messages using pre-rendered system prompt"""
# Append compression summary to system prompt if present
if self.compressed_summary:
compression_context = (
"\n\n---\n\n"
"This session is being continued from a previous conversation that "
"has been compressed to fit within context limits. "
"The conversation is summarized below:\n\n"
f"{self.compressed_summary}"
)
system_prompt = system_prompt + compression_context
messages = [{"role": "system", "content": system_prompt}]
for i in self.chat_history:
if "prompt" in i and "response" in i:
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append({"role": "assistant", "content": i["response"]})
messages.append({"role": "user", "content": i["prompt"]})
messages.append({"role": "assistant", "content": i["response"]})
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
@@ -253,29 +376,17 @@ class BaseAgent(ABC):
}
}
messages_combine.append(
messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages_combine.append(
messages.append(
{"role": "tool", "content": [function_response_dict]}
)
messages_combine.append({"role": "user", "content": query})
return messages_combine
def _retriever_search(
self,
retriever: BaseRetriever,
query: str,
log_context: Optional[LogContext] = None,
) -> List[Dict]:
retrieved_data = retriever.search(query)
if log_context:
data = build_stack_data(retriever, exclude_attributes=["llm"])
log_context.stacks.append({"component": "retriever", "data": data})
return retrieved_data
messages.append({"role": "user", "content": query})
return messages
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
gen_kwargs = {"model": self.gpt_model, "messages": messages}
gen_kwargs = {"model": self.model_id, "messages": messages}
if (
hasattr(self.llm, "_supports_tools")
@@ -283,6 +394,19 @@ class BaseAgent(ABC):
and self.tools
):
gen_kwargs["tools"] = self.tools
if (
self.json_schema
and hasattr(self.llm, "_supports_structured_output")
and self.llm._supports_structured_output()
):
structured_format = self.llm.prepare_structured_output_format(
self.json_schema
)
if structured_format:
if self.llm_name == "openai":
gen_kwargs["response_format"] = structured_format
elif self.llm_name == "google":
gen_kwargs["response_schema"] = structured_format
resp = self.llm.gen_stream(**gen_kwargs)
if log_context:
@@ -307,21 +431,42 @@ class BaseAgent(ABC):
return resp
def _handle_response(self, response, tools_dict, messages, log_context):
is_structured_output = (
self.json_schema is not None
and hasattr(self.llm, "_supports_structured_output")
and self.llm._supports_structured_output()
)
if isinstance(response, str):
yield {"answer": response}
answer_data = {"answer": response}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
return
if hasattr(response, "message") and getattr(response.message, "content", None):
yield {"answer": response.message.content}
answer_data = {"answer": response.message.content}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
return
processed_response_gen = self._llm_handler(
response, tools_dict, messages, log_context, self.attachments
)
for event in processed_response_gen:
if isinstance(event, str):
yield {"answer": event}
answer_data = {"answer": event}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
elif hasattr(event, "message") and getattr(event.message, "content", None):
yield {"answer": event.message.content}
answer_data = {"answer": event.message.content}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
elif isinstance(event, dict) and "type" in event:
yield event

View File

@@ -1,32 +1,20 @@
import logging
from typing import Dict, Generator
from application.agents.base import BaseAgent
from application.logging import LogContext
from application.retriever.base import BaseRetriever
import logging
logger = logging.getLogger(__name__)
class ClassicAgent(BaseAgent):
"""A simplified classic agent with clear execution flow.
Usage:
1. Processes a query through retrieval
2. Sets up available tools
3. Generates responses using LLM
4. Handles tool interactions if needed
5. Returns standardized outputs
Easy to extend by overriding specific steps.
"""
"""A simplified agent with clear execution flow"""
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
# Step 1: Retrieve relevant data
retrieved_data = self._retriever_search(retriever, query, log_context)
"""Core generator function for ClassicAgent execution flow"""
# Step 2: Prepare tools
tools_dict = (
self._get_user_tools(self.user)
if not self.user_api_key
@@ -34,20 +22,16 @@ class ClassicAgent(BaseAgent):
)
self._prepare_tools(tools_dict)
# Step 3: Build and process messages
messages = self._build_messages(self.prompt, query, retrieved_data)
messages = self._build_messages(self.prompt, query)
llm_response = self._llm_gen(messages, log_context)
# Step 4: Handle the response
yield from self._handle_response(
llm_response, tools_dict, messages, log_context
)
# Step 5: Return metadata
yield {"sources": retrieved_data}
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
# Log tool calls for debugging
log_context.stacks.append(
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
)

View File

@@ -1,229 +1,238 @@
import os
from typing import Dict, Generator, List, Any
import logging
import os
from typing import Any, Dict, Generator, List
from application.agents.base import BaseAgent
from application.logging import build_stack_data, LogContext
from application.retriever.base import BaseRetriever
logger = logging.getLogger(__name__)
MAX_ITERATIONS_REASONING = 10
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
with open(
os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"), "r"
) as f:
planning_prompt_template = f.read()
PLANNING_PROMPT_TEMPLATE = f.read()
with open(
os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"),
"r",
os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"), "r"
) as f:
final_prompt_template = f.read()
MAX_ITERATIONS_REASONING = 10
FINAL_PROMPT_TEMPLATE = f.read()
class ReActAgent(BaseAgent):
"""
Research and Action (ReAct) Agent - Advanced reasoning agent with iterative planning.
Implements a think-act-observe loop for complex problem-solving:
1. Creates a strategic plan based on the query
2. Executes tools and gathers observations
3. Iteratively refines approach until satisfied
4. Synthesizes final answer from all observations
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.plan: str = ""
self.observations: List[str] = []
def _extract_content_from_llm_response(self, resp: Any) -> str:
"""
Helper to extract string content from various LLM response types.
Handles strings, message objects (OpenAI-like), and streams.
Adapt stream handling for your specific LLM client if not OpenAI.
"""
collected_content = []
if isinstance(resp, str):
collected_content.append(resp)
elif ( # OpenAI non-streaming or Anthropic non-streaming (older SDK style)
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
collected_content.append(resp.message.content)
elif ( # OpenAI non-streaming (Pydantic model), Anthropic new SDK non-streaming
hasattr(resp, "choices") and resp.choices and
hasattr(resp.choices[0], "message") and
hasattr(resp.choices[0].message, "content") and
resp.choices[0].message.content is not None
):
collected_content.append(resp.choices[0].message.content) # OpenAI
elif ( # Anthropic new SDK non-streaming content block
hasattr(resp, "content") and isinstance(resp.content, list) and resp.content and
hasattr(resp.content[0], "text")
):
collected_content.append(resp.content[0].text) # Anthropic
else:
# Assume resp is a stream if not a recognized object
try:
for chunk in resp: # This will fail if resp is not iterable (e.g. a non-streaming response object)
content_piece = ""
# OpenAI-like stream
if hasattr(chunk, 'choices') and len(chunk.choices) > 0 and \
hasattr(chunk.choices[0], 'delta') and \
hasattr(chunk.choices[0].delta, 'content') and \
chunk.choices[0].delta.content is not None:
content_piece = chunk.choices[0].delta.content
# Anthropic-like stream (ContentBlockDelta)
elif hasattr(chunk, 'type') and chunk.type == 'content_block_delta' and \
hasattr(chunk, 'delta') and hasattr(chunk.delta, 'text'):
content_piece = chunk.delta.text
elif isinstance(chunk, str): # Simplest case: stream of strings
content_piece = chunk
if content_piece:
collected_content.append(content_piece)
except TypeError: # If resp is not iterable (e.g. a final response object that wasn't caught above)
logger.debug(f"Response type {type(resp)} could not be iterated as a stream. It might be a non-streaming object not handled by specific checks.")
except Exception as e:
logger.error(f"Error processing potential stream chunk: {e}, chunk was: {getattr(chunk, '__dict__', chunk)}")
return "".join(collected_content)
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
# Reset state for this generation call
self.plan = ""
self.observations = []
retrieved_data = self._retriever_search(retriever, query, log_context)
"""Execute ReAct reasoning loop with planning, action, and observation cycles"""
if self.user_api_key:
tools_dict = self._get_tools(self.user_api_key)
else:
tools_dict = self._get_user_tools(self.user)
self._reset_state()
tools_dict = (
self._get_tools(self.user_api_key)
if self.user_api_key
else self._get_user_tools(self.user)
)
self._prepare_tools(tools_dict)
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
iterating_reasoning = 0
while iterating_reasoning < MAX_ITERATIONS_REASONING:
iterating_reasoning += 1
# 1. Create Plan
logger.info("ReActAgent: Creating plan...")
plan_stream = self._create_plan(query, docs_together, log_context)
current_plan_parts = []
yield {"thought": f"Reasoning... (iteration {iterating_reasoning})\n\n"}
for line_chunk in plan_stream:
current_plan_parts.append(line_chunk)
yield {"thought": line_chunk}
self.plan = "".join(current_plan_parts)
if self.plan:
self.observations.append(f"Plan: {self.plan} Iteration: {iterating_reasoning}")
for iteration in range(1, MAX_ITERATIONS_REASONING + 1):
yield {"thought": f"Reasoning... (iteration {iteration})\n\n"}
yield from self._planning_phase(query, log_context)
max_obs_len = 20000
obs_str = "\n".join(self.observations)
if len(obs_str) > max_obs_len:
obs_str = obs_str[:max_obs_len] + "\n...[observations truncated]"
execution_prompt_str = (
(self.prompt or "")
+ f"\n\nFollow this plan:\n{self.plan}"
+ f"\n\nObservations:\n{obs_str}"
+ f"\n\nIf there is enough data to complete user query '{query}', Respond with 'SATISFIED' only. Otherwise, continue. Dont Menstion 'SATISFIED' in your response if you are not ready. "
)
messages = self._build_messages(execution_prompt_str, query, retrieved_data)
resp_from_llm_gen = self._llm_gen(messages, log_context)
initial_llm_thought_content = self._extract_content_from_llm_response(resp_from_llm_gen)
if initial_llm_thought_content:
self.observations.append(f"Initial thought/response: {initial_llm_thought_content}")
else:
logger.info("ReActAgent: Initial LLM response (before handler) had no textual content (might be only tool calls).")
resp_after_handler = self._llm_handler(resp_from_llm_gen, tools_dict, messages, log_context)
for tool_call_info in self.tool_calls: # Iterate over self.tool_calls populated by _llm_handler
observation_string = (
f"Executed Action: Tool '{tool_call_info.get('tool_name', 'N/A')}' "
f"with arguments '{tool_call_info.get('arguments', '{}')}'. Result: '{str(tool_call_info.get('result', ''))[:200]}...'"
if not self.plan:
logger.warning(
f"ReActAgent: No plan generated in iteration {iteration}"
)
self.observations.append(observation_string)
content_after_handler = self._extract_content_from_llm_response(resp_after_handler)
if content_after_handler:
self.observations.append(f"Response after tool execution: {content_after_handler}")
else:
logger.info("ReActAgent: LLM response after handler had no textual content.")
if log_context:
log_context.stacks.append(
{"component": "agent_tool_calls", "data": {"tool_calls": self.tool_calls.copy()}}
)
yield {"sources": retrieved_data}
display_tool_calls = []
for tc in self.tool_calls:
cleaned_tc = tc.copy()
if len(str(cleaned_tc.get("result", ""))) > 50:
cleaned_tc["result"] = str(cleaned_tc["result"])[:50] + "..."
display_tool_calls.append(cleaned_tc)
if display_tool_calls:
yield {"tool_calls": display_tool_calls}
if "SATISFIED" in content_after_handler:
logger.info("ReActAgent: LLM satisfied with the plan and data. Stopping reasoning.")
break
self.observations.append(f"Plan (iteration {iteration}): {self.plan}")
# 3. Create Final Answer based on all observations
final_answer_stream = self._create_final_answer(query, self.observations, log_context)
for answer_chunk in final_answer_stream:
yield {"answer": answer_chunk}
logger.info("ReActAgent: Finished generating final answer.")
satisfied = yield from self._execution_phase(query, tools_dict, log_context)
def _create_plan(
self, query: str, docs_data: str, log_context: LogContext = None
) -> Generator[str, None, None]:
plan_prompt_filled = planning_prompt_template.replace("{query}", query)
if "{summaries}" in plan_prompt_filled:
summaries = docs_data if docs_data else "No documents retrieved."
plan_prompt_filled = plan_prompt_filled.replace("{summaries}", summaries)
plan_prompt_filled = plan_prompt_filled.replace("{prompt}", self.prompt or "")
plan_prompt_filled = plan_prompt_filled.replace("{observations}", "\n".join(self.observations))
if satisfied:
logger.info("ReActAgent: Goal satisfied, stopping reasoning loop")
break
yield from self._synthesis_phase(query, log_context)
messages = [{"role": "user", "content": plan_prompt_filled}]
def _reset_state(self):
"""Reset agent state for new query"""
self.plan = ""
self.observations = []
plan_stream_from_llm = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=getattr(self, 'tools', None) # Use self.tools
def _planning_phase(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
"""Generate strategic plan for query"""
logger.info("ReActAgent: Creating plan...")
plan_prompt = self._build_planning_prompt(query)
messages = [{"role": "user", "content": plan_prompt}]
plan_stream = self.llm.gen_stream(
model=self.model_id,
messages=messages,
tools=self.tools if self.tools else None,
)
if log_context:
data = build_stack_data(self.llm)
log_context.stacks.append({"component": "planning_llm", "data": data})
log_context.stacks.append(
{"component": "planning_llm", "data": build_stack_data(self.llm)}
)
plan_parts = []
for chunk in plan_stream:
content = self._extract_content(chunk)
if content:
plan_parts.append(content)
yield {"thought": content}
self.plan = "".join(plan_parts)
for chunk in plan_stream_from_llm:
content_piece = self._extract_content_from_llm_response(chunk)
if content_piece:
yield content_piece
def _execution_phase(
self, query: str, tools_dict: Dict, log_context: LogContext
) -> Generator[bool, None, None]:
"""Execute plan with tool calls and observations"""
execution_prompt = self._build_execution_prompt(query)
messages = self._build_messages(execution_prompt, query)
def _create_final_answer(
self, query: str, observations: List[str], log_context: LogContext = None
) -> Generator[str, None, None]:
observation_string = "\n".join(observations)
max_obs_len = 10000
if len(observation_string) > max_obs_len:
observation_string = observation_string[:max_obs_len] + "\n...[observations truncated]"
logger.warning("ReActAgent: Truncated observations for final answer prompt due to length.")
llm_response = self._llm_gen(messages, log_context)
initial_content = self._extract_content(llm_response)
final_answer_prompt_filled = final_prompt_template.format(
query=query, observations=observation_string
if initial_content:
self.observations.append(f"Initial response: {initial_content}")
processed_response = self._llm_handler(
llm_response, tools_dict, messages, log_context
)
messages = [{"role": "user", "content": final_answer_prompt_filled}]
# Final answer should synthesize, not call tools.
final_answer_stream_from_llm = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=None
)
for tool_call in self.tool_calls:
observation = (
f"Executed: {tool_call.get('tool_name', 'Unknown')} "
f"with args {tool_call.get('arguments', {})}. "
f"Result: {str(tool_call.get('result', ''))[:200]}"
)
self.observations.append(observation)
final_content = self._extract_content(processed_response)
if final_content:
self.observations.append(f"Response after tools: {final_content}")
if log_context:
data = build_stack_data(self.llm)
log_context.stacks.append({"component": "final_answer_llm", "data": data})
log_context.stacks.append(
{
"component": "agent_tool_calls",
"data": {"tool_calls": self.tool_calls.copy()},
}
)
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
for chunk in final_answer_stream_from_llm:
content_piece = self._extract_content_from_llm_response(chunk)
if content_piece:
yield content_piece
return "SATISFIED" in (final_content or "")
def _synthesis_phase(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
"""Synthesize final answer from all observations"""
logger.info("ReActAgent: Generating final answer...")
final_prompt = self._build_final_answer_prompt(query)
messages = [{"role": "user", "content": final_prompt}]
final_stream = self.llm.gen_stream(
model=self.model_id, messages=messages, tools=None
)
if log_context:
log_context.stacks.append(
{"component": "final_answer_llm", "data": build_stack_data(self.llm)}
)
for chunk in final_stream:
content = self._extract_content(chunk)
if content:
yield {"answer": content}
def _build_planning_prompt(self, query: str) -> str:
"""Build planning phase prompt"""
prompt = PLANNING_PROMPT_TEMPLATE.replace("{query}", query)
prompt = prompt.replace("{prompt}", self.prompt or "")
prompt = prompt.replace("{summaries}", "")
prompt = prompt.replace("{observations}", "\n".join(self.observations))
return prompt
def _build_execution_prompt(self, query: str) -> str:
"""Build execution phase prompt with plan and observations"""
observations_str = "\n".join(self.observations)
if len(observations_str) > 20000:
observations_str = observations_str[:20000] + "\n...[truncated]"
return (
f"{self.prompt or ''}\n\n"
f"Follow this plan:\n{self.plan}\n\n"
f"Observations:\n{observations_str}\n\n"
f"If sufficient data exists to answer '{query}', respond with 'SATISFIED'. "
f"Otherwise, continue executing the plan."
)
def _build_final_answer_prompt(self, query: str) -> str:
"""Build final synthesis prompt"""
observations_str = "\n".join(self.observations)
if len(observations_str) > 10000:
observations_str = observations_str[:10000] + "\n...[truncated]"
logger.warning("ReActAgent: Observations truncated for final answer")
return FINAL_PROMPT_TEMPLATE.format(query=query, observations=observations_str)
def _extract_content(self, response: Any) -> str:
"""Extract text content from various LLM response formats"""
if not response:
return ""
collected = []
if isinstance(response, str):
return response
if hasattr(response, "message") and hasattr(response.message, "content"):
if response.message.content:
return response.message.content
if hasattr(response, "choices") and response.choices:
if hasattr(response.choices[0], "message"):
content = response.choices[0].message.content
if content:
return content
if hasattr(response, "content") and isinstance(response.content, list):
if response.content and hasattr(response.content[0], "text"):
return response.content[0].text
try:
for chunk in response:
content_piece = ""
if hasattr(chunk, "choices") and chunk.choices:
if hasattr(chunk.choices[0], "delta"):
delta_content = chunk.choices[0].delta.content
if delta_content:
content_piece = delta_content
elif hasattr(chunk, "type") and chunk.type == "content_block_delta":
if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
content_piece = chunk.delta.text
elif isinstance(chunk, str):
content_piece = chunk
if content_piece:
collected.append(content_piece)
except (TypeError, AttributeError):
logger.debug(
f"Response not iterable or unexpected format: {type(response)}"
)
except Exception as e:
logger.error(f"Error extracting content: {e}")
return "".join(collected)

View File

@@ -25,27 +25,35 @@ class BraveSearchTool(Tool):
else:
raise ValueError(f"Unknown action: {action_name}")
def _web_search(self, query, country="ALL", search_lang="en", count=10,
offset=0, safesearch="off", freshness=None,
result_filter=None, extra_snippets=False, summary=False):
def _web_search(
self,
query,
country="ALL",
search_lang="en",
count=10,
offset=0,
safesearch="off",
freshness=None,
result_filter=None,
extra_snippets=False,
summary=False,
):
"""
Performs a web search using the Brave Search API.
"""
print(f"Performing Brave web search for: {query}")
url = f"{self.base_url}/web/search"
# Build query parameters
params = {
"q": query,
"country": country,
"search_lang": search_lang,
"count": min(count, 20),
"offset": min(offset, 9),
"safesearch": safesearch
"safesearch": safesearch,
}
# Add optional parameters only if they have values
if freshness:
params["freshness"] = freshness
if result_filter:
@@ -54,68 +62,69 @@ class BraveSearchTool(Tool):
params["extra_snippets"] = 1
if summary:
params["summary"] = 1
# Set up headers
headers = {
"Accept": "application/json",
"Accept-Encoding": "gzip",
"X-Subscription-Token": self.token
"X-Subscription-Token": self.token,
}
# Make the request
response = requests.get(url, params=params, headers=headers)
if response.status_code == 200:
return {
"status_code": response.status_code,
"results": response.json(),
"message": "Search completed successfully."
"message": "Search completed successfully.",
}
else:
return {
"status_code": response.status_code,
"message": f"Search failed with status code: {response.status_code}."
"message": f"Search failed with status code: {response.status_code}.",
}
def _image_search(self, query, country="ALL", search_lang="en", count=5,
safesearch="off", spellcheck=False):
def _image_search(
self,
query,
country="ALL",
search_lang="en",
count=5,
safesearch="off",
spellcheck=False,
):
"""
Performs an image search using the Brave Search API.
"""
print(f"Performing Brave image search for: {query}")
url = f"{self.base_url}/images/search"
# Build query parameters
params = {
"q": query,
"country": country,
"search_lang": search_lang,
"count": min(count, 100), # API max is 100
"safesearch": safesearch,
"spellcheck": 1 if spellcheck else 0
"spellcheck": 1 if spellcheck else 0,
}
# Set up headers
headers = {
"Accept": "application/json",
"Accept-Encoding": "gzip",
"X-Subscription-Token": self.token
"X-Subscription-Token": self.token,
}
# Make the request
response = requests.get(url, params=params, headers=headers)
if response.status_code == 200:
return {
"status_code": response.status_code,
"results": response.json(),
"message": "Image search completed successfully."
"message": "Image search completed successfully.",
}
else:
return {
"status_code": response.status_code,
"message": f"Image search failed with status code: {response.status_code}."
"message": f"Image search failed with status code: {response.status_code}.",
}
def get_actions_metadata(self):
@@ -130,42 +139,14 @@ class BraveSearchTool(Tool):
"type": "string",
"description": "The search query (max 400 characters, 50 words)",
},
# "country": {
# "type": "string",
# "description": "The 2-character country code (default: US)",
# },
"search_lang": {
"type": "string",
"description": "The search language preference (default: en)",
},
# "count": {
# "type": "integer",
# "description": "Number of results to return (max 20, default: 10)",
# },
# "offset": {
# "type": "integer",
# "description": "Pagination offset (max 9, default: 0)",
# },
# "safesearch": {
# "type": "string",
# "description": "Filter level for adult content (off, moderate, strict)",
# },
"freshness": {
"type": "string",
"description": "Time filter for results (pd: last 24h, pw: last week, pm: last month, py: last year)",
},
# "result_filter": {
# "type": "string",
# "description": "Comma-delimited list of result types to include",
# },
# "extra_snippets": {
# "type": "boolean",
# "description": "Get additional excerpts from result pages",
# },
# "summary": {
# "type": "boolean",
# "description": "Enable summary generation in search results",
# }
},
"required": ["query"],
"additionalProperties": False,
@@ -181,37 +162,21 @@ class BraveSearchTool(Tool):
"type": "string",
"description": "The search query (max 400 characters, 50 words)",
},
# "country": {
# "type": "string",
# "description": "The 2-character country code (default: US)",
# },
# "search_lang": {
# "type": "string",
# "description": "The search language preference (default: en)",
# },
"count": {
"type": "integer",
"description": "Number of results to return (max 100, default: 5)",
},
# "safesearch": {
# "type": "string",
# "description": "Filter level for adult content (off, strict). Default: strict",
# },
# "spellcheck": {
# "type": "boolean",
# "description": "Whether to spellcheck provided query (default: true)",
# }
},
"required": ["query"],
"additionalProperties": False,
},
}
},
]
def get_config_requirements(self):
return {
"token": {
"type": "string",
"description": "Brave Search API key for authentication"
"type": "string",
"description": "Brave Search API key for authentication",
},
}
}

View File

@@ -0,0 +1,114 @@
from application.agents.tools.base import Tool
from duckduckgo_search import DDGS
class DuckDuckGoSearchTool(Tool):
"""
DuckDuckGo Search
A tool for performing web and image searches using DuckDuckGo.
"""
def __init__(self, config):
self.config = config
def execute_action(self, action_name, **kwargs):
actions = {
"ddg_web_search": self._web_search,
"ddg_image_search": self._image_search,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def _web_search(
self,
query,
max_results=5,
):
print(f"Performing DuckDuckGo web search for: {query}")
try:
results = DDGS().text(
query,
max_results=max_results,
)
return {
"status_code": 200,
"results": results,
"message": "Web search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Web search failed: {str(e)}",
}
def _image_search(
self,
query,
max_results=5,
):
print(f"Performing DuckDuckGo image search for: {query}")
try:
results = DDGS().images(
keywords=query,
max_results=max_results,
)
return {
"status_code": 200,
"results": results,
"message": "Image search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Image search failed: {str(e)}",
}
def get_actions_metadata(self):
return [
{
"name": "ddg_web_search",
"description": "Perform a web search using DuckDuckGo.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query",
},
"max_results": {
"type": "integer",
"description": "Number of results to return (default: 5)",
},
},
"required": ["query"],
},
},
{
"name": "ddg_image_search",
"description": "Perform an image search using DuckDuckGo.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query",
},
"max_results": {
"type": "integer",
"description": "Number of results to return (default: 5, max: 50)",
},
},
"required": ["query"],
},
},
]
def get_config_requirements(self):
return {}

View File

@@ -0,0 +1,861 @@
import asyncio
import base64
import json
import logging
import time
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
from fastmcp import Client
from fastmcp.client.auth import BearerAuth
from fastmcp.client.transports import (
SSETransport,
StdioTransport,
StreamableHttpTransport,
)
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
from pydantic import AnyHttpUrl, ValidationError
from redis import Redis
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
_mcp_clients_cache = {}
class MCPTool(Tool):
"""
MCP Tool
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources.
"""
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
"""
Initialize the MCP Tool with configuration.
Args:
config: Dictionary containing MCP server configuration:
- server_url: URL of the remote MCP server
- transport_type: Transport type (auto, sse, http, stdio)
- auth_type: Type of authentication (bearer, oauth, api_key, basic, none)
- encrypted_credentials: Encrypted credentials (if available)
- timeout: Request timeout in seconds (default: 30)
- headers: Custom headers for requests
- command: Command for STDIO transport
- args: Arguments for STDIO transport
- oauth_scopes: OAuth scopes for oauth auth type
- oauth_client_name: OAuth client name for oauth auth type
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
"""
self.config = config
self.user_id = user_id
self.server_url = config.get("server_url", "")
self.transport_type = config.get("transport_type", "auto")
self.auth_type = config.get("auth_type", "none")
self.timeout = config.get("timeout", 30)
self.custom_headers = config.get("headers", {})
self.auth_credentials = {}
if config.get("encrypted_credentials") and user_id:
self.auth_credentials = decrypt_credentials(
config["encrypted_credentials"], user_id
)
else:
self.auth_credentials = config.get("auth_credentials", {})
self.oauth_scopes = config.get("oauth_scopes", [])
self.oauth_task_id = config.get("oauth_task_id", None)
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
self.available_tools = []
self._cache_key = self._generate_cache_key()
self._client = None
# Only validate and setup if server_url is provided and not OAuth
if self.server_url and self.auth_type != "oauth":
self._setup_client()
def _generate_cache_key(self) -> str:
"""Generate a unique cache key for this MCP server configuration."""
auth_key = ""
if self.auth_type == "oauth":
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
elif self.auth_type in ["bearer"]:
token = self.auth_credentials.get(
"bearer_token", ""
) or self.auth_credentials.get("access_token", "")
auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
elif self.auth_type == "api_key":
api_key = self.auth_credentials.get("api_key", "")
auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
elif self.auth_type == "basic":
username = self.auth_credentials.get("username", "")
auth_key = f"basic:{username}"
else:
auth_key = "none"
return f"{self.server_url}#{self.transport_type}#{auth_key}"
def _setup_client(self):
"""Setup FastMCP client with proper transport and authentication."""
global _mcp_clients_cache
if self._cache_key in _mcp_clients_cache:
cached_data = _mcp_clients_cache[self._cache_key]
if time.time() - cached_data["created_at"] < 1800:
self._client = cached_data["client"]
return
else:
del _mcp_clients_cache[self._cache_key]
transport = self._create_transport()
auth = None
if self.auth_type == "oauth":
redis_client = get_redis_instance()
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
db=db,
user_id=self.user_id,
)
elif self.auth_type == "bearer":
token = self.auth_credentials.get(
"bearer_token", ""
) or self.auth_credentials.get("access_token", "")
if token:
auth = BearerAuth(token)
self._client = Client(transport, auth=auth)
_mcp_clients_cache[self._cache_key] = {
"client": self._client,
"created_at": time.time(),
}
def _create_transport(self):
"""Create appropriate transport based on configuration."""
headers = {"Content-Type": "application/json", "User-Agent": "DocsGPT-MCP/1.0"}
headers.update(self.custom_headers)
if self.auth_type == "api_key":
api_key = self.auth_credentials.get("api_key", "")
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
if api_key:
headers[header_name] = api_key
elif self.auth_type == "basic":
username = self.auth_credentials.get("username", "")
password = self.auth_credentials.get("password", "")
if username and password:
credentials = base64.b64encode(
f"{username}:{password}".encode()
).decode()
headers["Authorization"] = f"Basic {credentials}"
if self.transport_type == "auto":
if "sse" in self.server_url.lower() or self.server_url.endswith("/sse"):
transport_type = "sse"
else:
transport_type = "http"
else:
transport_type = self.transport_type
if transport_type == "sse":
headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"})
return SSETransport(url=self.server_url, headers=headers)
elif transport_type == "http":
return StreamableHttpTransport(url=self.server_url, headers=headers)
elif transport_type == "stdio":
command = self.config.get("command", "python")
args = self.config.get("args", [])
env = self.auth_credentials if self.auth_credentials else None
return StdioTransport(command=command, args=args, env=env)
else:
return StreamableHttpTransport(url=self.server_url, headers=headers)
def _format_tools(self, tools_response) -> List[Dict]:
"""Format tools response to match expected format."""
if hasattr(tools_response, "tools"):
tools = tools_response.tools
elif isinstance(tools_response, list):
tools = tools_response
else:
tools = []
tools_dict = []
for tool in tools:
if hasattr(tool, "name"):
tool_dict = {
"name": tool.name,
"description": tool.description,
}
if hasattr(tool, "inputSchema"):
tool_dict["inputSchema"] = tool.inputSchema
tools_dict.append(tool_dict)
elif isinstance(tool, dict):
tools_dict.append(tool)
else:
if hasattr(tool, "model_dump"):
tools_dict.append(tool.model_dump())
else:
tools_dict.append({"name": str(tool), "description": ""})
return tools_dict
async def _execute_with_client(self, operation: str, *args, **kwargs):
"""Execute operation with FastMCP client."""
if not self._client:
raise Exception("FastMCP client not initialized")
async with self._client:
if operation == "ping":
return await self._client.ping()
elif operation == "list_tools":
tools_response = await self._client.list_tools()
self.available_tools = self._format_tools(tools_response)
return self.available_tools
elif operation == "call_tool":
tool_name = args[0]
tool_args = kwargs
return await self._client.call_tool(tool_name, tool_args)
elif operation == "list_resources":
return await self._client.list_resources()
elif operation == "list_prompts":
return await self._client.list_prompts()
else:
raise Exception(f"Unknown operation: {operation}")
def _run_async_operation(self, operation: str, *args, **kwargs):
"""Run async operation in sync context."""
try:
try:
loop = asyncio.get_running_loop()
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result(timeout=self.timeout)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
loop.close()
except Exception as e:
print(f"Error occurred while running async operation: {e}")
raise
def discover_tools(self) -> List[Dict]:
"""
Discover available tools from the MCP server using FastMCP.
Returns:
List of tool definitions from the server
"""
if not self.server_url:
return []
if not self._client:
self._setup_client()
try:
tools = self._run_async_operation("list_tools")
self.available_tools = tools
return self.available_tools
except Exception as e:
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
def execute_action(self, action_name: str, **kwargs) -> Any:
"""
Execute an action on the remote MCP server using FastMCP.
Args:
action_name: Name of the action to execute
**kwargs: Parameters for the action
Returns:
Result from the MCP server
"""
if not self.server_url:
raise Exception("No MCP server configured")
if not self._client:
self._setup_client()
cleaned_kwargs = {}
for key, value in kwargs.items():
if value == "" or value is None:
continue
cleaned_kwargs[key] = value
try:
result = self._run_async_operation(
"call_tool", action_name, **cleaned_kwargs
)
return self._format_result(result)
except Exception as e:
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
def _format_result(self, result) -> Dict:
"""Format FastMCP result to match expected format."""
if hasattr(result, "content"):
content_list = []
for content_item in result.content:
if hasattr(content_item, "text"):
content_list.append({"type": "text", "text": content_item.text})
elif hasattr(content_item, "data"):
content_list.append({"type": "data", "data": content_item.data})
else:
content_list.append(
{"type": "unknown", "content": str(content_item)}
)
return {
"content": content_list,
"isError": getattr(result, "isError", False),
}
else:
return result
def test_connection(self) -> Dict:
"""
Test the connection to the MCP server and validate functionality.
Returns:
Dictionary with connection test results including tool count
"""
if not self.server_url:
return {
"success": False,
"message": "No MCP server URL configured",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": "ConfigurationError",
}
if not self._client:
self._setup_client()
try:
if self.auth_type == "oauth":
return self._test_oauth_connection()
else:
return self._test_regular_connection()
except Exception as e:
return {
"success": False,
"message": f"Connection failed: {str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
def _test_regular_connection(self) -> Dict:
"""Test connection for non-OAuth auth types."""
try:
self._run_async_operation("ping")
ping_success = True
except Exception:
ping_success = False
tools = self.discover_tools()
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
if not ping_success:
message += " (Ping not supported, but tool discovery worked)"
return {
"success": True,
"message": message,
"tools_count": len(tools),
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"ping_supported": ping_success,
"tools": [tool.get("name", "unknown") for tool in tools],
}
def _test_oauth_connection(self) -> Dict:
"""Test connection for OAuth auth type with proper async handling."""
try:
task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
if not task:
raise Exception("Failed to start OAuth authentication")
return {
"success": True,
"requires_oauth": True,
"task_id": task.id,
"status": "pending",
"message": "OAuth flow started",
}
except Exception as e:
return {
"success": False,
"message": f"OAuth connection failed: {str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
def get_actions_metadata(self) -> List[Dict]:
"""
Get metadata for all available actions.
Returns:
List of action metadata dictionaries
"""
actions = []
for tool in self.available_tools:
input_schema = (
tool.get("inputSchema")
or tool.get("input_schema")
or tool.get("schema")
or tool.get("parameters")
)
parameters_schema = {
"type": "object",
"properties": {},
"required": [],
}
if input_schema:
if isinstance(input_schema, dict):
if "properties" in input_schema:
parameters_schema = {
"type": input_schema.get("type", "object"),
"properties": input_schema.get("properties", {}),
"required": input_schema.get("required", []),
}
for key in ["additionalProperties", "description"]:
if key in input_schema:
parameters_schema[key] = input_schema[key]
else:
parameters_schema["properties"] = input_schema
action = {
"name": tool.get("name", ""),
"description": tool.get("description", ""),
"parameters": parameters_schema,
}
actions.append(action)
return actions
def get_config_requirements(self) -> Dict:
"""Get configuration requirements for the MCP tool."""
return {
"server_url": {
"type": "string",
"description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)",
"required": True,
},
"transport_type": {
"type": "string",
"description": "Transport type for connection",
"enum": ["auto", "sse", "http", "stdio"],
"default": "auto",
"required": False,
"help": {
"auto": "Automatically detect best transport",
"sse": "Server-Sent Events (for real-time streaming)",
"http": "HTTP streaming (recommended for production)",
"stdio": "Standard I/O (for local servers)",
},
},
"auth_type": {
"type": "string",
"description": "Authentication type",
"enum": ["none", "bearer", "oauth", "api_key", "basic"],
"default": "none",
"required": True,
"help": {
"none": "No authentication",
"bearer": "Bearer token authentication",
"oauth": "OAuth 2.1 authentication (with frontend integration)",
"api_key": "API key authentication",
"basic": "Basic authentication",
},
},
"auth_credentials": {
"type": "object",
"description": "Authentication credentials (varies by auth_type)",
"required": False,
"properties": {
"bearer_token": {
"type": "string",
"description": "Bearer token for bearer auth",
},
"access_token": {
"type": "string",
"description": "Access token for OAuth (if pre-obtained)",
},
"api_key": {
"type": "string",
"description": "API key for api_key auth",
},
"api_key_header": {
"type": "string",
"description": "Header name for API key (default: X-API-Key)",
},
"username": {
"type": "string",
"description": "Username for basic auth",
},
"password": {
"type": "string",
"description": "Password for basic auth",
},
},
},
"oauth_scopes": {
"type": "array",
"description": "OAuth scopes to request (for oauth auth_type)",
"items": {"type": "string"},
"required": False,
"default": [],
},
"oauth_client_name": {
"type": "string",
"description": "Client name for OAuth registration (for oauth auth_type)",
"default": "DocsGPT-MCP",
"required": False,
},
"headers": {
"type": "object",
"description": "Custom headers to send with requests",
"required": False,
},
"timeout": {
"type": "integer",
"description": "Request timeout in seconds",
"default": 30,
"minimum": 1,
"maximum": 300,
"required": False,
},
"command": {
"type": "string",
"description": "Command to run for STDIO transport (e.g., 'python')",
"required": False,
},
"args": {
"type": "array",
"description": "Arguments for STDIO command",
"items": {"type": "string"},
"required": False,
},
}
class DocsGPTOAuth(OAuthClientProvider):
"""
Custom OAuth handler for DocsGPT that uses frontend redirect instead of browser.
"""
def __init__(
self,
mcp_url: str,
redirect_uri: str,
redis_client: Redis | None = None,
redis_prefix: str = "mcp_oauth:",
task_id: str = None,
scopes: str | list[str] | None = None,
client_name: str = "DocsGPT-MCP",
user_id=None,
db=None,
additional_client_metadata: dict[str, Any] | None = None,
):
"""
Initialize custom OAuth client provider for DocsGPT.
Args:
mcp_url: Full URL to the MCP endpoint
redirect_uri: Custom redirect URI for DocsGPT frontend
redis_client: Redis client for storing auth state
redis_prefix: Prefix for Redis keys
task_id: Task ID for tracking auth status
scopes: OAuth scopes to request
client_name: Name for this client during registration
user_id: User ID for token storage
db: Database instance for token storage
additional_client_metadata: Extra fields for OAuthClientMetadata
"""
self.redirect_uri = redirect_uri
self.redis_client = redis_client
self.redis_prefix = redis_prefix
self.task_id = task_id
self.user_id = user_id
self.db = db
parsed_url = urlparse(mcp_url)
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
if isinstance(scopes, list):
scopes = " ".join(scopes)
client_metadata = OAuthClientMetadata(
client_name=client_name,
redirect_uris=[AnyHttpUrl(redirect_uri)],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
scope=scopes,
**(additional_client_metadata or {}),
)
storage = DBTokenStorage(
server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
)
super().__init__(
server_url=self.server_base_url,
client_metadata=client_metadata,
storage=storage,
redirect_handler=self.redirect_handler,
callback_handler=self.callback_handler,
)
self.auth_url = None
self.extracted_state = None
def _process_auth_url(self, authorization_url: str) -> tuple[str, str]:
"""Process authorization URL to extract state"""
try:
parsed_url = urlparse(authorization_url)
query_params = parse_qs(parsed_url.query)
state_params = query_params.get("state", [])
if state_params:
state = state_params[0]
else:
raise ValueError("No state in auth URL")
return authorization_url, state
except Exception as e:
raise Exception(f"Failed to process auth URL: {e}")
async def redirect_handler(self, authorization_url: str) -> None:
"""Store auth URL and state in Redis for frontend to use."""
auth_url, state = self._process_auth_url(authorization_url)
logging.info(
"[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
)
self.auth_url = auth_url
self.extracted_state = state
if self.redis_client and self.extracted_state:
key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
self.redis_client.setex(key, 600, auth_url)
logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "requires_redirect",
"message": "OAuth authorization required",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
async def callback_handler(self) -> tuple[str, str | None]:
"""Wait for auth code from Redis using the state value."""
if not self.redis_client or not self.extracted_state:
raise Exception("Redis client or state not configured for OAuth")
poll_interval = 1
max_wait_time = 300
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "awaiting_callback",
"message": "Waiting for OAuth callback...",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
start_time = time.time()
while time.time() - start_time < max_wait_time:
code_data = self.redis_client.get(code_key)
if code_data:
code = code_data.decode()
returned_state = self.extracted_state
self.redis_client.delete(code_key)
self.redis_client.delete(
f"{self.redis_prefix}auth_url:{self.extracted_state}"
)
self.redis_client.delete(
f"{self.redis_prefix}state:{self.extracted_state}"
)
if self.task_id:
status_data = {
"status": "callback_received",
"message": "OAuth callback received, completing authentication...",
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
return code, returned_state
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
error_data = self.redis_client.get(error_key)
if error_data:
error_msg = error_data.decode()
self.redis_client.delete(error_key)
self.redis_client.delete(
f"{self.redis_prefix}auth_url:{self.extracted_state}"
)
self.redis_client.delete(
f"{self.redis_prefix}state:{self.extracted_state}"
)
raise Exception(f"OAuth error: {error_msg}")
await asyncio.sleep(poll_interval)
self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
raise Exception("OAuth callback timeout: no code received within 5 minutes")
class DBTokenStorage(TokenStorage):
def __init__(self, server_url: str, user_id: str, db_client):
self.server_url = server_url
self.user_id = user_id
self.db_client = db_client
self.collection = db_client["connector_sessions"]
@staticmethod
def get_base_url(url: str) -> str:
parsed = urlparse(url)
return f"{parsed.scheme}://{parsed.netloc}"
def get_db_key(self) -> dict:
return {
"server_url": self.get_base_url(self.server_url),
"user_id": self.user_id,
}
async def get_tokens(self) -> OAuthToken | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "tokens" not in doc:
return None
try:
tokens = OAuthToken.model_validate(doc["tokens"])
return tokens
except ValidationError as e:
logging.error(f"Could not load tokens: {e}")
return None
async def set_tokens(self, tokens: OAuthToken) -> None:
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$set": {"tokens": tokens.model_dump()}},
True,
)
logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")
async def get_client_info(self) -> OAuthClientInformationFull | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "client_info" not in doc:
return None
try:
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
tokens = await self.get_tokens()
if tokens is None:
logging.debug(
"No tokens found, clearing client info to force fresh registration."
)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$unset": {"client_info": ""}},
)
return None
return client_info
except ValidationError as e:
logging.error(f"Could not load client info: {e}")
return None
def _serialize_client_info(self, info: dict) -> dict:
if "redirect_uris" in info and isinstance(info["redirect_uris"], list):
info["redirect_uris"] = [str(u) for u in info["redirect_uris"]]
return info
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
serialized_info = self._serialize_client_info(client_info.model_dump())
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$set": {"client_info": serialized_info}},
True,
)
logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")
async def clear(self) -> None:
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
@classmethod
async def clear_all(cls, db_client) -> None:
collection = db_client["connector_sessions"]
await asyncio.to_thread(collection.delete_many, {})
logging.info("Cleared all OAuth client cache data.")
class MCPOAuthManager:
"""Manager for handling MCP OAuth callbacks."""
def __init__(self, redis_client: Redis | None, redis_prefix: str = "mcp_oauth:"):
self.redis_client = redis_client
self.redis_prefix = redis_prefix
def handle_oauth_callback(
self, state: str, code: str, error: Optional[str] = None
) -> bool:
"""
Handle OAuth callback from provider.
Args:
state: The state parameter from OAuth callback
code: The authorization code from OAuth callback
error: Error message if OAuth failed
Returns:
True if successful, False otherwise
"""
try:
if not self.redis_client or not state:
raise Exception("Redis client or state not provided")
if error:
error_key = f"{self.redis_prefix}error:{state}"
self.redis_client.setex(error_key, 300, error)
raise Exception(f"OAuth error received: {error}")
code_key = f"{self.redis_prefix}code:{state}"
self.redis_client.setex(code_key, 300, code)
state_key = f"{self.redis_prefix}state:{state}"
self.redis_client.setex(state_key, 300, "completed")
return True
except Exception as e:
logging.error(f"Error handling OAuth callback: {e}")
return False
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
"""Get current status of OAuth flow using provided task_id."""
if not task_id:
return {"status": "not_started", "message": "OAuth flow not started"}
return mcp_oauth_status_task(task_id)

View File

@@ -0,0 +1,546 @@
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import re
import uuid
from .base import Tool
from application.core.mongo_db import MongoDB
from application.core.settings import settings
class MemoryTool(Tool):
"""Memory
Stores and retrieves information across conversations through a memory file directory.
"""
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
"""Initialize the tool.
Args:
tool_config: Optional tool configuration. Should include:
- tool_id: Unique identifier for this memory tool instance (from user_tools._id)
This ensures each user's tool configuration has isolated memories
user_id: The authenticated user's id (should come from decoded_token["sub"]).
"""
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
# Fallback for backward compatibility or testing
self.tool_id = f"default_{user_id}"
else:
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["memories"]
# -----------------------------
# Action implementations
# -----------------------------
def execute_action(self, action_name: str, **kwargs: Any) -> str:
"""Execute an action by name.
Args:
action_name: One of view, create, str_replace, insert, delete, rename.
**kwargs: Parameters for the action.
Returns:
A human-readable string result.
"""
if not self.user_id:
return "Error: MemoryTool requires a valid user_id."
if action_name == "view":
return self._view(
kwargs.get("path", "/"),
kwargs.get("view_range")
)
if action_name == "create":
return self._create(
kwargs.get("path", ""),
kwargs.get("file_text", "")
)
if action_name == "str_replace":
return self._str_replace(
kwargs.get("path", ""),
kwargs.get("old_str", ""),
kwargs.get("new_str", "")
)
if action_name == "insert":
return self._insert(
kwargs.get("path", ""),
kwargs.get("insert_line", 1),
kwargs.get("insert_text", "")
)
if action_name == "delete":
return self._delete(kwargs.get("path", ""))
if action_name == "rename":
return self._rename(
kwargs.get("old_path", ""),
kwargs.get("new_path", "")
)
return f"Unknown action: {action_name}"
def get_actions_metadata(self) -> List[Dict[str, Any]]:
"""Return JSON metadata describing supported actions for tool schemas."""
return [
{
"name": "view",
"description": "Shows directory contents or file contents with optional line ranges.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to file or directory (e.g., /notes.txt or /project/ or /)."
},
"view_range": {
"type": "array",
"items": {"type": "integer"},
"description": "Optional [start_line, end_line] to view specific lines (1-indexed)."
}
},
"required": ["path"]
},
},
{
"name": "create",
"description": "Create or overwrite a file.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "File path to create (e.g., /notes.txt or /project/task.txt)."
},
"file_text": {
"type": "string",
"description": "Content to write to the file."
}
},
"required": ["path", "file_text"]
},
},
{
"name": "str_replace",
"description": "Replace text in a file.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "File path (e.g., /notes.txt)."
},
"old_str": {
"type": "string",
"description": "String to find."
},
"new_str": {
"type": "string",
"description": "String to replace with."
}
},
"required": ["path", "old_str", "new_str"]
},
},
{
"name": "insert",
"description": "Insert text at a specific line in a file.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "File path (e.g., /notes.txt)."
},
"insert_line": {
"type": "integer",
"description": "Line number to insert at (1-indexed)."
},
"insert_text": {
"type": "string",
"description": "Text to insert."
}
},
"required": ["path", "insert_line", "insert_text"]
},
},
{
"name": "delete",
"description": "Delete a file or directory.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to delete (e.g., /notes.txt or /project/)."
}
},
"required": ["path"]
},
},
{
"name": "rename",
"description": "Rename or move a file/directory.",
"parameters": {
"type": "object",
"properties": {
"old_path": {
"type": "string",
"description": "Current path (e.g., /old.txt)."
},
"new_path": {
"type": "string",
"description": "New path (e.g., /new.txt)."
}
},
"required": ["old_path", "new_path"]
},
},
]
def get_config_requirements(self) -> Dict[str, Any]:
"""Return configuration requirements."""
return {}
# -----------------------------
# Path validation
# -----------------------------
def _validate_path(self, path: str) -> Optional[str]:
"""Validate and normalize path.
Args:
path: User-provided path.
Returns:
Normalized path or None if invalid.
"""
if not path:
return None
# Remove any leading/trailing whitespace
path = path.strip()
# Preserve whether path ends with / (indicates directory)
is_directory = path.endswith("/")
# Ensure path starts with / for consistency
if not path.startswith("/"):
path = "/" + path
# Check for directory traversal patterns
if ".." in path or path.count("//") > 0:
return None
# Normalize the path
try:
# Convert to Path object and resolve to canonical form
normalized = str(Path(path).as_posix())
# Ensure it still starts with /
if not normalized.startswith("/"):
return None
# Preserve trailing slash for directories
if is_directory and not normalized.endswith("/") and normalized != "/":
normalized = normalized + "/"
return normalized
except Exception:
return None
# -----------------------------
# Internal helpers
# -----------------------------
def _view(self, path: str, view_range: Optional[List[int]] = None) -> str:
"""View directory contents or file contents."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
# Check if viewing directory (ends with / or is root)
if validated_path == "/" or validated_path.endswith("/"):
return self._view_directory(validated_path)
# Otherwise view file
return self._view_file(validated_path, view_range)
def _view_directory(self, path: str) -> str:
"""List files in a directory."""
# Ensure path ends with / for proper prefix matching
search_path = path if path.endswith("/") else path + "/"
# Find all files that start with this directory path
query = {
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(search_path)}"}
}
docs = list(self.collection.find(query, {"path": 1}))
if not docs:
return f"Directory: {path}\n(empty)"
# Extract filenames relative to the directory
files = []
for doc in docs:
file_path = doc["path"]
# Remove the directory prefix
if file_path.startswith(search_path):
relative = file_path[len(search_path):]
if relative:
files.append(relative)
files.sort()
file_list = "\n".join(f"- {f}" for f in files)
return f"Directory: {path}\n{file_list}"
def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
"""View file contents with optional line range."""
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
if not doc or not doc.get("content"):
return f"Error: File not found: {path}"
content = str(doc["content"])
# Apply view_range if specified
if view_range and len(view_range) == 2:
lines = content.split("\n")
start, end = view_range
# Convert to 0-indexed
start_idx = max(0, start - 1)
end_idx = min(len(lines), end)
if start_idx >= len(lines):
return f"Error: Line range out of bounds. File has {len(lines)} lines."
selected_lines = lines[start_idx:end_idx]
# Add line numbers (enumerate with 1-based start)
numbered_lines = [f"{i}: {line}" for i, line in enumerate(selected_lines, start=start)]
return "\n".join(numbered_lines)
return content
def _create(self, path: str, file_text: str) -> str:
"""Create or overwrite a file."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if validated_path == "/" or validated_path.endswith("/"):
return "Error: Cannot create a file at directory path."
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
{
"$set": {
"content": file_text,
"updated_at": datetime.now()
}
},
upsert=True
)
return f"File created: {validated_path}"
def _str_replace(self, path: str, old_str: str, new_str: str) -> str:
"""Replace text in a file."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if not old_str:
return "Error: old_str is required."
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
current_content = str(doc["content"])
# Check if old_str exists (case-insensitive)
if old_str.lower() not in current_content.lower():
return f"Error: String '{old_str}' not found in file."
# Replace the string (case-insensitive)
import re as regex_module
updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
{
"$set": {
"content": updated_content,
"updated_at": datetime.now()
}
}
)
return f"File updated: {validated_path}"
def _insert(self, path: str, insert_line: int, insert_text: str) -> str:
"""Insert text at a specific line."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if not insert_text:
return "Error: insert_text is required."
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
current_content = str(doc["content"])
lines = current_content.split("\n")
# Convert to 0-indexed
index = insert_line - 1
if index < 0 or index > len(lines):
return f"Error: Invalid line number. File has {len(lines)} lines."
lines.insert(index, insert_text)
updated_content = "\n".join(lines)
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
{
"$set": {
"content": updated_content,
"updated_at": datetime.now()
}
}
)
return f"Text inserted at line {insert_line} in {validated_path}"
def _delete(self, path: str) -> str:
"""Delete a file or directory."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if validated_path == "/":
# Delete all files for this user and tool
result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id})
return f"Deleted {result.deleted_count} file(s) from memory."
# Check if it's a directory (ends with /)
if validated_path.endswith("/"):
# Delete all files in directory
result = self.collection.delete_many({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(validated_path)}"}
})
return f"Deleted directory and {result.deleted_count} file(s)."
# Try to delete as directory first (without trailing slash)
# Check if any files start with this path + /
search_path = validated_path + "/"
directory_result = self.collection.delete_many({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(search_path)}"}
})
if directory_result.deleted_count > 0:
return f"Deleted directory and {directory_result.deleted_count} file(s)."
# Delete single file
result = self.collection.delete_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_path
})
if result.deleted_count:
return f"Deleted: {validated_path}"
return f"Error: File not found: {validated_path}"
def _rename(self, old_path: str, new_path: str) -> str:
"""Rename or move a file/directory."""
validated_old = self._validate_path(old_path)
validated_new = self._validate_path(new_path)
if not validated_old or not validated_new:
return "Error: Invalid path."
if validated_old == "/" or validated_new == "/":
return "Error: Cannot rename root directory."
# Check if renaming a directory
if validated_old.endswith("/"):
# Ensure validated_new also ends with / for proper path replacement
if not validated_new.endswith("/"):
validated_new = validated_new + "/"
# Find all files in the old directory
docs = list(self.collection.find({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": {"$regex": f"^{re.escape(validated_old)}"}
}))
if not docs:
return f"Error: Directory not found: {validated_old}"
# Update paths for all files
for doc in docs:
old_file_path = doc["path"]
new_file_path = old_file_path.replace(validated_old, validated_new, 1)
self.collection.update_one(
{"_id": doc["_id"]},
{"$set": {"path": new_file_path, "updated_at": datetime.now()}}
)
return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"
# Rename single file
doc = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_old
})
if not doc:
return f"Error: File not found: {validated_old}"
# Check if new path already exists
existing = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_new
})
if existing:
return f"Error: File already exists at {validated_new}"
# Delete the old document and create a new one with the new path
self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old})
self.collection.insert_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"path": validated_new,
"content": doc.get("content", ""),
"updated_at": datetime.now()
})
return f"Renamed: {validated_old} -> {validated_new}"

View File

@@ -0,0 +1,199 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
import uuid
from .base import Tool
from application.core.mongo_db import MongoDB
from application.core.settings import settings
class NotesTool(Tool):
"""Notepad
Single note. Supports viewing, overwriting, string replacement.
"""
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
"""Initialize the tool.
Args:
tool_config: Optional tool configuration. Should include:
- tool_id: Unique identifier for this notes tool instance (from user_tools._id)
This ensures each user's tool configuration has isolated notes
user_id: The authenticated user's id (should come from decoded_token["sub"]).
"""
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
# Fallback for backward compatibility or testing
self.tool_id = f"default_{user_id}"
else:
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["notes"]
# -----------------------------
# Action implementations
# -----------------------------
def execute_action(self, action_name: str, **kwargs: Any) -> str:
"""Execute an action by name.
Args:
action_name: One of view, overwrite, str_replace, insert, delete.
**kwargs: Parameters for the action.
Returns:
A human-readable string result.
"""
if not self.user_id:
return "Error: NotesTool requires a valid user_id."
if action_name == "view":
return self._get_note()
if action_name == "overwrite":
return self._overwrite_note(kwargs.get("text", ""))
if action_name == "str_replace":
return self._str_replace(kwargs.get("old_str", ""), kwargs.get("new_str", ""))
if action_name == "insert":
return self._insert(kwargs.get("line_number", 1), kwargs.get("text", ""))
if action_name == "delete":
return self._delete_note()
return f"Unknown action: {action_name}"
def get_actions_metadata(self) -> List[Dict[str, Any]]:
"""Return JSON metadata describing supported actions for tool schemas."""
return [
{
"name": "view",
"description": "Retrieve the user's note.",
"parameters": {"type": "object", "properties": {}},
},
{
"name": "overwrite",
"description": "Replace the entire note content (creates if doesn't exist).",
"parameters": {
"type": "object",
"properties": {
"text": {"type": "string", "description": "New note content."}
},
"required": ["text"],
},
},
{
"name": "str_replace",
"description": "Replace occurrences of old_str with new_str in the note.",
"parameters": {
"type": "object",
"properties": {
"old_str": {"type": "string", "description": "String to find."},
"new_str": {"type": "string", "description": "String to replace with."}
},
"required": ["old_str", "new_str"],
},
},
{
"name": "insert",
"description": "Insert text at the specified line number (1-indexed).",
"parameters": {
"type": "object",
"properties": {
"line_number": {"type": "integer", "description": "Line number to insert at (1-indexed)."},
"text": {"type": "string", "description": "Text to insert."}
},
"required": ["line_number", "text"],
},
},
{
"name": "delete",
"description": "Delete the user's note.",
"parameters": {"type": "object", "properties": {}},
},
]
def get_config_requirements(self) -> Dict[str, Any]:
"""Return configuration requirements (none for now)."""
return {}
# -----------------------------
# Internal helpers (single-note)
# -----------------------------
def _get_note(self) -> str:
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
return str(doc["note"])
def _overwrite_note(self, content: str) -> str:
content = (content or "").strip()
if not content:
return "Note content required."
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
upsert=True, # ✅ create if missing
)
return "Note saved."
def _str_replace(self, old_str: str, new_str: str) -> str:
if not old_str:
return "old_str is required."
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
current_note = str(doc["note"])
# Case-insensitive search
if old_str.lower() not in current_note.lower():
return f"String '{old_str}' not found in note."
# Case-insensitive replacement
import re
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
)
return "Note updated."
def _insert(self, line_number: int, text: str) -> str:
if not text:
return "Text is required."
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
if not doc or not doc.get("note"):
return "No note found."
current_note = str(doc["note"])
lines = current_note.split("\n")
# Convert to 0-indexed and validate
index = line_number - 1
if index < 0 or index > len(lines):
return f"Invalid line number. Note has {len(lines)} lines."
lines.insert(index, text)
updated_note = "\n".join(lines)
self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
)
return "Text inserted."
def _delete_note(self) -> str:
res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
return "Note deleted." if res.deleted_count else "No note found to delete."

View File

@@ -0,0 +1,321 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
import uuid
from .base import Tool
from application.core.mongo_db import MongoDB
from application.core.settings import settings
class TodoListTool(Tool):
"""Todo List
Manages todo items for users. Supports creating, viewing, updating, and deleting todos.
"""
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
"""Initialize the tool.
Args:
tool_config: Optional tool configuration. Should include:
- tool_id: Unique identifier for this todo list tool instance (from user_tools._id)
This ensures each user's tool configuration has isolated todos
user_id: The authenticated user's id (should come from decoded_token["sub"]).
"""
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
# Fallback for backward compatibility or testing
self.tool_id = f"default_{user_id}"
else:
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
self.collection = db["todos"]
# -----------------------------
# Action implementations
# -----------------------------
def execute_action(self, action_name: str, **kwargs: Any) -> str:
"""Execute an action by name.
Args:
action_name: One of list, create, get, update, complete, delete.
**kwargs: Parameters for the action.
Returns:
A human-readable string result.
"""
if not self.user_id:
return "Error: TodoListTool requires a valid user_id."
if action_name == "list":
return self._list()
if action_name == "create":
return self._create(kwargs.get("title", ""))
if action_name == "get":
return self._get(kwargs.get("todo_id"))
if action_name == "update":
return self._update(
kwargs.get("todo_id"),
kwargs.get("title", "")
)
if action_name == "complete":
return self._complete(kwargs.get("todo_id"))
if action_name == "delete":
return self._delete(kwargs.get("todo_id"))
return f"Unknown action: {action_name}"
def get_actions_metadata(self) -> List[Dict[str, Any]]:
"""Return JSON metadata describing supported actions for tool schemas."""
return [
{
"name": "list",
"description": "List all todos for the user.",
"parameters": {"type": "object", "properties": {}},
},
{
"name": "create",
"description": "Create a new todo item.",
"parameters": {
"type": "object",
"properties": {
"title": {
"type": "string",
"description": "Title of the todo item."
}
},
"required": ["title"],
},
},
{
"name": "get",
"description": "Get a specific todo by ID.",
"parameters": {
"type": "object",
"properties": {
"todo_id": {
"type": "integer",
"description": "The ID of the todo to retrieve."
}
},
"required": ["todo_id"],
},
},
{
"name": "update",
"description": "Update a todo's title by ID.",
"parameters": {
"type": "object",
"properties": {
"todo_id": {
"type": "integer",
"description": "The ID of the todo to update."
},
"title": {
"type": "string",
"description": "The new title for the todo."
}
},
"required": ["todo_id", "title"],
},
},
{
"name": "complete",
"description": "Mark a todo as completed.",
"parameters": {
"type": "object",
"properties": {
"todo_id": {
"type": "integer",
"description": "The ID of the todo to mark as completed."
}
},
"required": ["todo_id"],
},
},
{
"name": "delete",
"description": "Delete a specific todo by ID.",
"parameters": {
"type": "object",
"properties": {
"todo_id": {
"type": "integer",
"description": "The ID of the todo to delete."
}
},
"required": ["todo_id"],
},
},
]
def get_config_requirements(self) -> Dict[str, Any]:
"""Return configuration requirements."""
return {}
# -----------------------------
# Internal helpers
# -----------------------------
def _coerce_todo_id(self, value: Optional[Any]) -> Optional[int]:
"""Convert todo identifiers to sequential integers."""
if value is None:
return None
if isinstance(value, int):
return value if value > 0 else None
if isinstance(value, str):
stripped = value.strip()
if stripped.isdigit():
numeric_value = int(stripped)
return numeric_value if numeric_value > 0 else None
return None
def _get_next_todo_id(self) -> int:
"""Get the next sequential todo_id for this user and tool.
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
With 5-10 todos max, scanning is negligible.
"""
# Find all todos for this user/tool and get their IDs
todos = list(self.collection.find(
{"user_id": self.user_id, "tool_id": self.tool_id},
{"todo_id": 1}
))
# Find the maximum todo_id
max_id = 0
for todo in todos:
todo_id = self._coerce_todo_id(todo.get("todo_id"))
if todo_id is not None:
max_id = max(max_id, todo_id)
return max_id + 1
def _list(self) -> str:
"""List all todos for the user."""
cursor = self.collection.find({"user_id": self.user_id, "tool_id": self.tool_id})
todos = list(cursor)
if not todos:
return "No todos found."
result_lines = ["Todos:"]
for doc in todos:
todo_id = doc.get("todo_id")
title = doc.get("title", "Untitled")
status = doc.get("status", "open")
line = f"[{todo_id}] {title} ({status})"
result_lines.append(line)
return "\n".join(result_lines)
def _create(self, title: str) -> str:
"""Create a new todo item."""
title = (title or "").strip()
if not title:
return "Error: Title is required."
now = datetime.now()
todo_id = self._get_next_todo_id()
doc = {
"todo_id": todo_id,
"user_id": self.user_id,
"tool_id": self.tool_id,
"title": title,
"status": "open",
"created_at": now,
"updated_at": now,
}
self.collection.insert_one(doc)
return f"Todo created with ID {todo_id}: {title}"
def _get(self, todo_id: Optional[Any]) -> str:
"""Get a specific todo by ID."""
parsed_todo_id = self._coerce_todo_id(todo_id)
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
doc = self.collection.find_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"todo_id": parsed_todo_id
})
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
title = doc.get("title", "Untitled")
status = doc.get("status", "open")
result = f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
return result
def _update(self, todo_id: Optional[Any], title: str) -> str:
"""Update a todo's title by ID."""
parsed_todo_id = self._coerce_todo_id(todo_id)
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
title = (title or "").strip()
if not title:
return "Error: Title is required."
result = self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
{"$set": {"title": title, "updated_at": datetime.now()}}
)
if result.matched_count == 0:
return f"Error: Todo with ID {parsed_todo_id} not found."
return f"Todo {parsed_todo_id} updated to: {title}"
def _complete(self, todo_id: Optional[Any]) -> str:
"""Mark a todo as completed."""
parsed_todo_id = self._coerce_todo_id(todo_id)
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
result = self.collection.update_one(
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
{"$set": {"status": "completed", "updated_at": datetime.now()}}
)
if result.matched_count == 0:
return f"Error: Todo with ID {parsed_todo_id} not found."
return f"Todo {parsed_todo_id} marked as completed."
def _delete(self, todo_id: Optional[Any]) -> str:
"""Delete a specific todo by ID."""
parsed_todo_id = self._coerce_todo_id(todo_id)
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
result = self.collection.delete_one({
"user_id": self.user_id,
"tool_id": self.tool_id,
"todo_id": parsed_todo_id
})
if result.deleted_count == 0:
return f"Error: Todo with ID {parsed_todo_id} not found."
return f"Todo {parsed_todo_id} deleted."

View File

@@ -19,9 +19,25 @@ class ToolActionParser:
def _parse_openai_llm(self, call):
try:
call_args = json.loads(call.arguments)
tool_id = call.name.split("_")[-1]
action_name = call.name.rsplit("_", 1)[0]
except (AttributeError, TypeError) as e:
tool_parts = call.name.split("_")
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2:
logger.warning(
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
)
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
# Validate that tool_id looks like a numerical ID
if not tool_id.isdigit():
logger.warning(
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
)
except (AttributeError, TypeError, json.JSONDecodeError) as e:
logger.error(f"Error parsing OpenAI LLM call: {e}")
return None, None, None
return tool_id, action_name, call_args
@@ -29,8 +45,24 @@ class ToolActionParser:
def _parse_google_llm(self, call):
try:
call_args = call.arguments
tool_id = call.name.split("_")[-1]
action_name = call.name.rsplit("_", 1)[0]
tool_parts = call.name.split("_")
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2:
logger.warning(
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
)
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
# Validate that tool_id looks like a numerical ID
if not tool_id.isdigit():
logger.warning(
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
)
except (AttributeError, TypeError) as e:
logger.error(f"Error parsing Google LLM call: {e}")
return None, None, None

View File

@@ -23,16 +23,23 @@ class ToolManager:
tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config)
def load_tool(self, tool_name, tool_config):
def load_tool(self, tool_name, tool_config, user_id=None):
self.config[tool_name] = tool_config
module = importlib.import_module(f"application.agents.tools.{tool_name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
return obj(tool_config)
if tool_name in {"mcp_tool", "notes", "memory", "todo_list"} and user_id:
return obj(tool_config, user_id)
else:
return obj(tool_config)
def execute_action(self, tool_name, action_name, **kwargs):
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
if tool_name in {"mcp_tool", "memory", "todo_list"} and user_id:
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)
return self.tools[tool_name].execute_action(action_name, **kwargs)
def get_all_actions_metadata(self):

View File

@@ -0,0 +1,7 @@
from flask_restx import Api
api = Api(
version="1.0",
title="DocsGPT API",
description="API for DocsGPT",
)

View File

@@ -0,0 +1,19 @@
from flask import Blueprint
from application.api import api
from application.api.answer.routes.answer import AnswerResource
from application.api.answer.routes.base import answer_ns
from application.api.answer.routes.stream import StreamResource
answer = Blueprint("answer", __name__)
api.add_namespace(answer_ns)
def init_answer_routes():
api.add_resource(StreamResource, "/stream")
api.add_resource(AnswerResource, "/api/answer")
init_answer_routes()

View File

@@ -1,914 +0,0 @@
import asyncio
import datetime
import json
import logging
import os
import traceback
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import Blueprint, make_response, request, Response
from flask_restx import fields, Namespace, Resource
from application.agents.agent_creator import AgentCreator
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.error import bad_request
from application.extensions import api
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import check_required_fields, limit_chat_history
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
conversations_collection = db["conversations"]
sources_collection = db["sources"]
prompts_collection = db["prompts"]
agents_collection = db["agents"]
user_logs_collection = db["user_logs"]
attachments_collection = db["attachments"]
answer = Blueprint("answer", __name__)
answer_ns = Namespace("answer", description="Answer related operations", path="/")
api.add_namespace(answer_ns)
gpt_model = ""
# to have some kind of default behaviour
if settings.LLM_PROVIDER == "openai":
gpt_model = "gpt-4o-mini"
elif settings.LLM_PROVIDER == "anthropic":
gpt_model = "claude-2"
elif settings.LLM_PROVIDER == "groq":
gpt_model = "llama3-8b-8192"
elif settings.LLM_PROVIDER == "novita":
gpt_model = "deepseek/deepseek-r1"
if settings.LLM_NAME: # in case there is particular model name configured
gpt_model = settings.LLM_NAME
# load the prompts
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
chat_combine_template = f.read()
with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f:
chat_reduce_template = f.read()
with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
chat_combine_creative = f.read()
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
chat_combine_strict = f.read()
api_key_set = settings.API_KEY is not None
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
async def async_generate(chain, question, chat_history):
result = await chain.arun({"question": question, "chat_history": chat_history})
return result
def run_async_chain(chain, question, chat_history):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = {}
try:
answer = loop.run_until_complete(async_generate(chain, question, chat_history))
finally:
loop.close()
result["answer"] = answer
return result
def get_agent_key(agent_id, user_id):
if not agent_id:
return None, False, None
try:
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
if agent is None:
raise Exception("Agent not found", 404)
is_owner = agent.get("user") == user_id
if is_owner:
agents_collection.update_one(
{"_id": ObjectId(agent_id)},
{"$set": {"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)}},
)
return str(agent["key"]), False, None
is_shared_with_user = agent.get(
"shared_publicly", False
) or user_id in agent.get("shared_with", [])
if is_shared_with_user:
return str(agent["key"]), True, agent.get("shared_token")
raise Exception("Unauthorized access to the agent", 403)
except Exception as e:
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
raise
def get_data_from_api_key(api_key):
data = agents_collection.find_one({"key": api_key})
if not data:
raise Exception("Invalid API Key, please generate a new key", 401)
source = data.get("source")
if isinstance(source, DBRef):
source_doc = db.dereference(source)
data["source"] = str(source_doc["_id"])
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
else:
data["source"] = {}
return data
def get_retriever(source_id: str):
doc = sources_collection.find_one({"_id": ObjectId(source_id)})
if doc is None:
raise Exception("Source document does not exist", 404)
retriever_name = None if "retriever" not in doc else doc["retriever"]
return retriever_name
def is_azure_configured():
return (
settings.OPENAI_API_BASE
and settings.OPENAI_API_VERSION
and settings.AZURE_DEPLOYMENT_NAME
)
def save_conversation(
conversation_id,
question,
response,
thought,
source_log_docs,
tool_calls,
llm,
decoded_token,
index=None,
api_key=None,
agent_id=None,
is_shared_usage=False,
shared_token=None,
attachment_ids=None,
):
current_time = datetime.datetime.now(datetime.timezone.utc)
if conversation_id is not None and index is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{
"$set": {
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.thought": thought,
f"queries.{index}.sources": source_log_docs,
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time,
f"queries.{index}.attachments": attachment_ids,
}
},
)
##remove following queries from the array
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
)
elif conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {
"queries": {
"prompt": question,
"response": response,
"thought": thought,
"sources": source_log_docs,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
}
}
},
)
else:
# create new conversation
# generate summary
messages_summary = [
{
"role": "assistant",
"content": "Summarise following conversation in no more than 3 "
"words, respond ONLY with the summary, use the same "
"language as the system",
},
{
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system \n\nUser: " + question + "\n\n" + "AI: " + response,
},
]
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
conversation_data = {
"user": decoded_token.get("sub"),
"date": datetime.datetime.utcnow(),
"name": completion,
"queries": [
{
"prompt": question,
"response": response,
"thought": thought,
"sources": source_log_docs,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
}
],
}
if api_key:
if agent_id:
conversation_data["agent_id"] = agent_id
if is_shared_usage:
conversation_data["is_shared_usage"] = is_shared_usage
conversation_data["shared_token"] = shared_token
api_key_doc = agents_collection.find_one({"key": api_key})
if api_key_doc:
conversation_data["api_key"] = api_key_doc["key"]
conversation_id = conversations_collection.insert_one(
conversation_data
).inserted_id
return conversation_id
def get_prompt(prompt_id):
if prompt_id == "default":
prompt = chat_combine_template
elif prompt_id == "creative":
prompt = chat_combine_creative
elif prompt_id == "strict":
prompt = chat_combine_strict
else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
return prompt
def complete_stream(
question,
agent,
retriever,
conversation_id,
user_api_key,
decoded_token,
isNoneDoc=False,
index=None,
should_save_conversation=True,
attachment_ids=None,
agent_id=None,
is_shared_usage=False,
shared_token=None,
):
try:
response_full, thought, source_log_docs, tool_calls = "", "", [], []
answer = agent.gen(query=question, retriever=retriever)
for line in answer:
if "answer" in line:
response_full += str(line["answer"])
data = json.dumps({"type": "answer", "answer": line["answer"]})
yield f"data: {data}\n\n"
elif "sources" in line:
truncated_sources = []
source_log_docs = line["sources"]
for source in line["sources"]:
truncated_source = source.copy()
if "text" in truncated_source:
truncated_source["text"] = (
truncated_source["text"][:100].strip() + "..."
)
truncated_sources.append(truncated_source)
if len(truncated_sources) > 0:
data = json.dumps({"type": "source", "source": truncated_sources})
yield f"data: {data}\n\n"
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
elif "thought" in line:
thought += line["thought"]
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
data = json.dumps(line)
yield f"data: {data}\n\n"
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
if should_save_conversation:
conversation_id = save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
decoded_token,
index,
api_key=user_api_key,
attachment_ids=attachment_ids,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
)
else:
conversation_id = None
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
yield f"data: {data}\n\n"
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "stream_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"attachments": attachment_ids,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except Exception as e:
logger.error(f"Error in stream: {str(e)}", exc_info=True)
data = json.dumps(
{
"type": "error",
"error": "Please try again later. We apologize for any inconvenience.",
}
)
yield f"data: {data}\n\n"
return
@answer_ns.route("/stream")
class Stream(Resource):
stream_model = api.model(
"StreamModel",
{
"question": fields.String(
required=True, description="Question to be asked"
),
"history": fields.List(
fields.String, required=False, description="Chat history"
),
"conversation_id": fields.String(
required=False, description="Conversation ID"
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"index": fields.Integer(
required=False, description="Index of the query to update"
),
"save_conversation": fields.Boolean(
required=False,
default=True,
description="Whether to save the conversation",
),
"attachments": fields.List(
fields.String, required=False, description="List of attachment IDs"
),
},
)
@api.expect(stream_model)
@api.doc(description="Stream a response based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
if "index" in data:
required_fields = ["question", "conversation_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
save_conv = data.get("save_conversation", True)
try:
question = data["question"]
history = limit_chat_history(
json.loads(data.get("history", "[]")), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
attachment_ids = data.get("attachments", [])
index = data.get("index", None)
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
agent_id = data.get("agent_id", None)
agent_type = settings.AGENT_NAME
decoded_token = getattr(request, "decoded_token", None)
user_sub = decoded_token.get("sub") if decoded_token else None
agent_key, is_shared_usage, shared_token = get_agent_key(agent_id, user_sub)
if agent_key:
data.update({"api_key": agent_key})
else:
agent_id = None
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
prompt_id = data_key.get("prompt_id", "default")
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
agent_type = data_key.get("agent_type", agent_type)
if is_shared_usage:
decoded_token = request.decoded_token
else:
decoded_token = {"sub": data_key.get("user")}
is_shared_usage = False
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
decoded_token = request.decoded_token
else:
source = {}
user_api_key = None
decoded_token = request.decoded_token
if not decoded_token:
return make_response({"error": "Unauthorized"}, 401)
attachments = get_attachments_content(
attachment_ids, decoded_token.get("sub")
)
logger.info(
f"/stream - request_data: {data}, source: {source}, attachments: {len(attachments)}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
prompt = get_prompt(prompt_id)
if "isNoneDoc" in data and data["isNoneDoc"] is True:
chunks = 0
agent = AgentCreator.create_agent(
agent_type,
endpoint="stream",
llm_name=settings.LLM_PROVIDER,
gpt_model=gpt_model,
api_key=settings.API_KEY,
user_api_key=user_api_key,
prompt=prompt,
chat_history=history,
decoded_token=decoded_token,
attachments=attachments,
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
return Response(
complete_stream(
question=question,
agent=agent,
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
decoded_token=decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=index,
should_save_conversation=save_conv,
attachment_ids=attachment_ids,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
),
mimetype="text/event-stream",
)
except ValueError:
message = "Malformed request body"
logger.error(f"/stream - error: {message}")
return Response(
error_stream_generate(message),
status=400,
mimetype="text/event-stream",
)
except Exception as e:
logger.error(
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
status_code = 400
return Response(
error_stream_generate("Unknown error occurred"),
status=status_code,
mimetype="text/event-stream",
)
def error_stream_generate(err_response):
data = json.dumps({"type": "error", "error": err_response})
yield f"data: {data}\n\n"
@answer_ns.route("/api/answer")
class Answer(Resource):
answer_model = api.model(
"AnswerModel",
{
"question": fields.String(
required=True, description="The question to answer"
),
"history": fields.List(
fields.String, required=False, description="Conversation history"
),
"conversation_id": fields.String(
required=False, description="Conversation ID"
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
},
)
@api.expect(answer_model)
@api.doc(description="Provide an answer based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = limit_chat_history(
json.loads(data.get("history", [])), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
agent_type = settings.AGENT_NAME
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
prompt_id = data_key.get("prompt_id", "default")
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
agent_type = data_key.get("agent_type", agent_type)
decoded_token = {"sub": data_key.get("user")}
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
decoded_token = request.decoded_token
else:
source = {}
user_api_key = None
decoded_token = request.decoded_token
if not decoded_token:
return make_response({"error": "Unauthorized"}, 401)
prompt = get_prompt(prompt_id)
logger.info(
f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
agent = AgentCreator.create_agent(
agent_type,
endpoint="api/answer",
llm_name=settings.LLM_PROVIDER,
gpt_model=gpt_model,
api_key=settings.API_KEY,
user_api_key=user_api_key,
prompt=prompt,
chat_history=history,
decoded_token=decoded_token,
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
response_full = ""
source_log_docs = []
tool_calls = []
stream_ended = False
thought = ""
for line in complete_stream(
question=question,
agent=agent,
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
decoded_token=decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=False,
):
try:
event_data = line.replace("data: ", "").strip()
event = json.loads(event_data)
if event["type"] == "answer":
response_full += event["answer"]
elif event["type"] == "source":
source_log_docs = event["source"]
elif event["type"] == "tool_calls":
tool_calls = event["tool_calls"]
elif event["type"] == "thought":
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return bad_request(500, event["error"])
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Error parsing stream event: {e}, line: {line}")
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return bad_request(500, "Stream ended unexpectedly.")
if data.get("isNoneDoc"):
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = str(
save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
decoded_token,
api_key=user_api_key,
)
)
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "api_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
except Exception as e:
logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return bad_request(500, str(e))
return make_response(result, 200)
@answer_ns.route("/api/search")
class Search(Resource):
search_model = api.model(
"SearchModel",
{
"question": fields.String(
required=True, description="The question to search"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"api_key": fields.String(
required=False, description="API key for authentication"
),
"active_docs": fields.String(
required=False, description="Active documents for retrieval"
),
"retriever": fields.String(required=False, description="Retriever type"),
"token_limit": fields.Integer(
required=False, description="Limit for tokens"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
},
)
@api.expect(search_model)
@api.doc(
description="Search for relevant documents based on the question and retriever"
)
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
source = {"active_docs": data_key.get("source")}
user_api_key = data["api_key"]
decoded_token = {"sub": data_key.get("user")}
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
user_api_key = None
decoded_token = request.decoded_token
else:
source = {}
user_api_key = None
decoded_token = request.decoded_token
if not decoded_token:
return make_response({"error": "Unauthorized"}, 401)
logger.info(
f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
source=source,
chat_history=[],
prompt="default",
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
docs = retriever.search(question)
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "api_search",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"sources": docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
if data.get("isNoneDoc"):
for doc in docs:
doc["source"] = "None"
except Exception as e:
logger.error(
f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return bad_request(500, str(e))
return make_response(docs, 200)
def get_attachments_content(attachment_ids, user):
"""
Retrieve content from attachment documents based on their IDs.
Args:
attachment_ids (list): List of attachment document IDs
user (str): User identifier to verify ownership
Returns:
list: List of dictionaries containing attachment content and metadata
"""
if not attachment_ids:
return []
attachments = []
for attachment_id in attachment_ids:
try:
attachment_doc = attachments_collection.find_one(
{"_id": ObjectId(attachment_id), "user": user}
)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
)
return attachments

View File

@@ -0,0 +1,142 @@
import logging
import traceback
from flask import make_response, request
from flask_restx import fields, Resource
from application.api import api
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
from application.api.answer.services.stream_processor import StreamProcessor
logger = logging.getLogger(__name__)
@answer_ns.route("/api/answer")
class AnswerResource(Resource, BaseAnswerResource):
def __init__(self, *args, **kwargs):
Resource.__init__(self, *args, **kwargs)
BaseAnswerResource.__init__(self)
answer_model = answer_ns.model(
"AnswerModel",
{
"question": fields.String(
required=True, description="Question to be asked"
),
"history": fields.List(
fields.String,
required=False,
description="Conversation history (only for new conversations)",
),
"conversation_id": fields.String(
required=False,
description="Existing conversation ID (loads history)",
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"save_conversation": fields.Boolean(
required=False,
default=True,
description="Whether to save the conversation",
),
"model_id": fields.String(
required=False,
description="Model ID to use for this request",
),
"passthrough": fields.Raw(
required=False,
description="Dynamic parameters to inject into prompt template",
),
},
)
@api.expect(answer_model)
@api.doc(description="Provide a response based on the question and retriever")
def post(self):
data = request.get_json()
if error := self.validate_request(data):
return error
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
processor.initialize()
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
docs_together, docs_list = processor.pre_fetch_docs(
data.get("question", "")
)
tools_data = processor.pre_fetch_tools()
agent = processor.create_agent(
docs_together=docs_together,
docs=docs_list,
tools_data=tools_data,
)
if error := self.check_usage(processor.agent_config):
return error
stream = self.complete_stream(
question=data["question"],
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
model_id=processor.model_id,
)
stream_result = self.process_response_stream(stream)
if len(stream_result) == 7:
(
conversation_id,
response,
sources,
tool_calls,
thought,
error,
structured_info,
) = stream_result
else:
conversation_id, response, sources, tool_calls, thought, error = (
stream_result
)
structured_info = None
if error:
return make_response({"error": error}, 400)
result = {
"conversation_id": conversation_id,
"answer": response,
"sources": sources,
"tool_calls": tool_calls,
"thought": thought,
}
if structured_info:
result.update(structured_info)
except Exception as e:
logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return make_response({"error": str(e)}, 500)
return make_response(result, 200)

View File

@@ -0,0 +1,442 @@
import datetime
import json
import logging
from typing import Any, Dict, Generator, List, Optional
from flask import jsonify, make_response, Response
from flask_restx import Namespace
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
get_default_model_id,
get_provider_from_model_id,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
answer_ns = Namespace("answer", description="Answer related operations", path="/")
class BaseAnswerResource:
"""Shared base class for answer endpoints"""
def __init__(self):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
self.db = db
self.user_logs_collection = db["user_logs"]
self.default_model_id = get_default_model_id()
self.conversation_service = ConversationService()
def validate_request(
self, data: Dict[str, Any], require_conversation_id: bool = False
) -> Optional[Response]:
"""Common request validation"""
required_fields = ["question"]
if require_conversation_id:
required_fields.append("conversation_id")
if missing_fields := check_required_fields(data, required_fields):
return missing_fields
return None
def check_usage(self, agent_config: Dict) -> Optional[Response]:
"""Check if there is a usage limit and if it is exceeded
Args:
agent_config: The config dict of agent instance
Returns:
None or Response if either of limits exceeded.
"""
api_key = agent_config.get("user_api_key")
if not api_key:
return None
agents_collection = self.db["agents"]
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
jsonify({"success": False, "message": "Invalid API key."}), 401
)
limited_token_mode_raw = agent.get("limited_token_mode", False)
limited_request_mode_raw = agent.get("limited_request_mode", False)
limited_token_mode = (
limited_token_mode_raw
if isinstance(limited_token_mode_raw, bool)
else limited_token_mode_raw == "True"
)
limited_request_mode = (
limited_request_mode_raw
if isinstance(limited_request_mode_raw, bool)
else limited_request_mode_raw == "True"
)
token_limit = int(
agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
)
request_limit = int(
agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
)
token_usage_collection = self.db["token_usage"]
end_date = datetime.datetime.now()
start_date = end_date - datetime.timedelta(hours=24)
match_query = {
"timestamp": {"$gte": start_date, "$lte": end_date},
"api_key": api_key,
}
if limited_token_mode:
token_pipeline = [
{"$match": match_query},
{
"$group": {
"_id": None,
"total_tokens": {
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
},
}
},
]
token_result = list(token_usage_collection.aggregate(token_pipeline))
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
else:
daily_token_usage = 0
if limited_request_mode:
daily_request_usage = token_usage_collection.count_documents(match_query)
else:
daily_request_usage = 0
if not limited_token_mode and not limited_request_mode:
return None
token_exceeded = (
limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
)
request_exceeded = (
limited_request_mode
and request_limit > 0
and daily_request_usage >= request_limit
)
if token_exceeded or request_exceeded:
return make_response(
jsonify(
{
"success": False,
"message": "Exceeding usage limit, please try again later.",
}
),
429,
)
return None
def complete_stream(
self,
question: str,
agent: Any,
conversation_id: Optional[str],
user_api_key: Optional[str],
decoded_token: Dict[str, Any],
isNoneDoc: bool = False,
index: Optional[int] = None,
should_save_conversation: bool = True,
attachment_ids: Optional[List[str]] = None,
agent_id: Optional[str] = None,
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
) -> Generator[str, None, None]:
"""
Generator function that streams the complete conversation response.
Args:
question: The user's question
agent: The agent instance
retriever: The retriever instance
conversation_id: Existing conversation ID
user_api_key: User's API key if any
decoded_token: Decoded JWT token
isNoneDoc: Flag for document-less responses
index: Index of message to update
should_save_conversation: Whether to persist the conversation
attachment_ids: List of attachment IDs
agent_id: ID of agent used
is_shared_usage: Flag for shared agent usage
shared_token: Token for shared agent
model_id: Model ID used for the request
retrieved_docs: Pre-fetched documents for sources (optional)
Yields:
Server-sent event strings
"""
try:
response_full, thought, source_log_docs, tool_calls = "", "", [], []
is_structured = False
schema_info = None
structured_chunks = []
for line in agent.gen(query=question):
if "answer" in line:
response_full += str(line["answer"])
if line.get("structured"):
is_structured = True
schema_info = line.get("schema")
structured_chunks.append(line["answer"])
else:
data = json.dumps({"type": "answer", "answer": line["answer"]})
yield f"data: {data}\n\n"
elif "sources" in line:
truncated_sources = []
source_log_docs = line["sources"]
for source in line["sources"]:
truncated_source = source.copy()
if "text" in truncated_source:
truncated_source["text"] = (
truncated_source["text"][:100].strip() + "..."
)
truncated_sources.append(truncated_source)
if truncated_sources:
data = json.dumps(
{"type": "source", "source": truncated_sources}
)
yield f"data: {data}\n\n"
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
yield f"data: {data}\n\n"
elif "thought" in line:
thought += line["thought"]
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
structured_data = {
"type": "structured_answer",
"answer": response_full,
"structured": True,
"schema": schema_info,
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
provider = (
get_provider_from_model_id(model_id)
if model_id
else settings.LLM_PROVIDER
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
llm = LLMCreator.create_llm(
provider or settings.LLM_PROVIDER,
api_key=system_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
)
if should_save_conversation:
conversation_id = self.conversation_service.save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
)
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
try:
self.conversation_service.update_compression_metadata(
conversation_id, compression_meta
)
self.conversation_service.append_compression_message(
conversation_id, compression_meta
)
agent.compression_saved = True
logger.info(
f"Persisted compression metadata for conversation {conversation_id}"
)
except Exception as e:
logger.error(
f"Failed to persist compression metadata: {str(e)}",
exc_info=True,
)
else:
conversation_id = None
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
log_data = {
"action": "stream_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"attachments": attachment_ids,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
if is_structured:
log_data["structured_output"] = True
if schema_info:
log_data["schema"] = schema_info
# Clean up text fields to be no longer than 10000 characters
for key, value in log_data.items():
if isinstance(value, str) and len(value) > 10000:
log_data[key] = value[:10000]
self.user_logs_collection.insert_one(log_data)
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except GeneratorExit:
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
# Save partial response
if should_save_conversation and response_full:
try:
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
self.conversation_service.save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
)
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
try:
self.conversation_service.update_compression_metadata(
conversation_id, compression_meta
)
self.conversation_service.append_compression_message(
conversation_id, compression_meta
)
agent.compression_saved = True
logger.info(
f"Persisted compression metadata for conversation {conversation_id} (partial stream)"
)
except Exception as e:
logger.error(
f"Failed to persist compression metadata (partial stream): {str(e)}",
exc_info=True,
)
except Exception as e:
logger.error(
f"Error saving partial response: {str(e)}", exc_info=True
)
raise
except Exception as e:
logger.error(f"Error in stream: {str(e)}", exc_info=True)
data = json.dumps(
{
"type": "error",
"error": "Please try again later. We apologize for any inconvenience.",
}
)
yield f"data: {data}\n\n"
return
def process_response_stream(self, stream):
"""Process the stream response for non-streaming endpoint"""
conversation_id = ""
response_full = ""
source_log_docs = []
tool_calls = []
thought = ""
stream_ended = False
is_structured = False
schema_info = None
for line in stream:
try:
event_data = line.replace("data: ", "").strip()
event = json.loads(event_data)
if event["type"] == "id":
conversation_id = event["id"]
elif event["type"] == "answer":
response_full += event["answer"]
elif event["type"] == "structured_answer":
response_full = event["answer"]
is_structured = True
schema_info = event.get("schema")
elif event["type"] == "source":
source_log_docs = event["source"]
elif event["type"] == "tool_calls":
tool_calls = event["tool_calls"]
elif event["type"] == "thought":
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return None, None, None, None, event["error"], None
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Error parsing stream event: {e}, line: {line}")
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return None, None, None, None, "Stream ended unexpectedly", None
result = (
conversation_id,
response_full,
source_log_docs,
tool_calls,
thought,
None,
)
if is_structured:
result = result + ({"structured": True, "schema": schema_info},)
return result
def error_stream_generate(self, err_response):
data = json.dumps({"type": "error", "error": err_response})
yield f"data: {data}\n\n"

View File

@@ -0,0 +1,132 @@
import logging
import traceback
from flask import request, Response
from flask_restx import fields, Resource
from application.api import api
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
from application.api.answer.services.stream_processor import StreamProcessor
logger = logging.getLogger(__name__)
@answer_ns.route("/stream")
class StreamResource(Resource, BaseAnswerResource):
def __init__(self, *args, **kwargs):
Resource.__init__(self, *args, **kwargs)
BaseAnswerResource.__init__(self)
stream_model = answer_ns.model(
"StreamModel",
{
"question": fields.String(
required=True, description="Question to be asked"
),
"history": fields.List(
fields.String,
required=False,
description="Conversation history (only for new conversations)",
),
"conversation_id": fields.String(
required=False,
description="Existing conversation ID (loads history)",
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"index": fields.Integer(
required=False, description="Index of the query to update"
),
"save_conversation": fields.Boolean(
required=False,
default=True,
description="Whether to save the conversation",
),
"model_id": fields.String(
required=False,
description="Model ID to use for this request",
),
"attachments": fields.List(
fields.String, required=False, description="List of attachment IDs"
),
"passthrough": fields.Raw(
required=False,
description="Dynamic parameters to inject into prompt template",
),
},
)
@api.expect(stream_model)
@api.doc(description="Stream a response based on the question and retriever")
def post(self):
data = request.get_json()
if error := self.validate_request(data, "index" in data):
return error
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
processor.initialize()
docs_together, docs_list = processor.pre_fetch_docs(data["question"])
tools_data = processor.pre_fetch_tools()
agent = processor.create_agent(
docs_together=docs_together, docs=docs_list, tools_data=tools_data
)
if error := self.check_usage(processor.agent_config):
return error
return Response(
self.complete_stream(
question=data["question"],
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=data.get("index"),
should_save_conversation=data.get("save_conversation", True),
attachment_ids=data.get("attachments", []),
agent_id=data.get("agent_id"),
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
),
mimetype="text/event-stream",
)
except ValueError as e:
message = "Malformed request body"
logger.error(
f"/stream - error: {message} - specific error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return Response(
self.error_stream_generate(message),
status=400,
mimetype="text/event-stream",
)
except Exception as e:
logger.error(
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return Response(
self.error_stream_generate("Unknown error occurred"),
status=400,
mimetype="text/event-stream",
)

View File

@@ -0,0 +1,20 @@
"""
Compression module for managing conversation context compression.
"""
from application.api.answer.services.compression.orchestrator import (
CompressionOrchestrator,
)
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.types import (
CompressionResult,
CompressionMetadata,
)
__all__ = [
"CompressionOrchestrator",
"CompressionService",
"CompressionResult",
"CompressionMetadata",
]

View File

@@ -0,0 +1,234 @@
"""Message reconstruction utilities for compression."""
import logging
import uuid
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
class MessageBuilder:
"""Builds message arrays from compressed context."""
@staticmethod
def build_from_compressed_context(
system_prompt: str,
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_tool_calls: bool = False,
context_type: str = "pre_request",
) -> List[Dict]:
"""
Build messages from compressed context.
Args:
system_prompt: Original system prompt
compressed_summary: Compressed summary (if any)
recent_queries: Recent uncompressed queries
include_tool_calls: Whether to include tool calls from history
context_type: Type of context ('pre_request' or 'mid_execution')
Returns:
List of message dicts ready for LLM
"""
# Append compression summary to system prompt if present
if compressed_summary:
system_prompt = MessageBuilder._append_compression_context(
system_prompt, compressed_summary, context_type
)
messages = [{"role": "system", "content": system_prompt}]
# Add recent history
for query in recent_queries:
if "prompt" in query and "response" in query:
messages.append({"role": "user", "content": query["prompt"]})
messages.append({"role": "assistant", "content": query["response"]})
# Add tool calls from history if present
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages.append(
{"role": "tool", "content": [function_response_dict]}
)
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
messages.append({
"role": "user",
"content": "Please continue with the remaining tasks based on the context above."
})
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
return messages
@staticmethod
def _append_compression_context(
system_prompt: str, compressed_summary: str, context_type: str = "pre_request"
) -> str:
"""
Append compression context to system prompt.
Args:
system_prompt: Original system prompt
compressed_summary: Summary to append
context_type: Type of compression context
Returns:
Updated system prompt
"""
# Remove existing compression context if present
if "This session is being continued" in system_prompt or "Context window limit reached" in system_prompt:
parts = system_prompt.split("\n\n---\n\n")
system_prompt = parts[0]
# Build appropriate context message based on type
if context_type == "mid_execution":
context_message = (
"\n\n---\n\n"
"Context window limit reached during execution. "
"Previous conversation has been compressed to fit within limits. "
"The conversation is summarized below:\n\n"
f"{compressed_summary}"
)
else: # pre_request
context_message = (
"\n\n---\n\n"
"This session is being continued from a previous conversation that "
"has been compressed to fit within context limits. "
"The conversation is summarized below:\n\n"
f"{compressed_summary}"
)
return system_prompt + context_message
@staticmethod
def rebuild_messages_after_compression(
messages: List[Dict],
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_current_execution: bool = False,
include_tool_calls: bool = False,
) -> Optional[List[Dict]]:
"""
Rebuild the message list after compression so tool execution can continue.
Args:
messages: Original message list
compressed_summary: Compressed summary
recent_queries: Recent uncompressed queries
include_current_execution: Whether to preserve current execution messages
include_tool_calls: Whether to include tool calls from history
Returns:
Rebuilt message list or None if failed
"""
# Find the system message
system_message = next(
(msg for msg in messages if msg.get("role") == "system"), None
)
if not system_message:
logger.warning("No system message found in messages list")
return None
# Update system message with compressed summary
if compressed_summary:
content = system_message.get("content", "")
system_message["content"] = MessageBuilder._append_compression_context(
content, compressed_summary, "mid_execution"
)
logger.info(
"Appended compression summary to system prompt (truncated): %s",
(
compressed_summary[:500] + "..."
if len(compressed_summary) > 500
else compressed_summary
),
)
rebuilt_messages = [system_message]
# Add recent history from compressed context
for query in recent_queries:
if "prompt" in query and "response" in query:
rebuilt_messages.append({"role": "user", "content": query["prompt"]})
rebuilt_messages.append(
{"role": "assistant", "content": query["response"]}
)
# Add tool calls from history if present
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
rebuilt_messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
rebuilt_messages.append(
{"role": "tool", "content": [function_response_dict]}
)
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
rebuilt_messages.append({
"role": "user",
"content": "Please continue with the remaining tasks based on the context above."
})
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
if include_current_execution:
# Preserve any messages that were added during the current execution cycle
recent_msg_count = 1 # system message
for query in recent_queries:
if "prompt" in query and "response" in query:
recent_msg_count += 2
if "tool_calls" in query:
recent_msg_count += len(query["tool_calls"]) * 2
if len(messages) > recent_msg_count:
current_execution_messages = messages[recent_msg_count:]
rebuilt_messages.extend(current_execution_messages)
logger.info(
f"Preserved {len(current_execution_messages)} messages from current execution cycle"
)
logger.info(
f"Messages rebuilt: {len(messages)}{len(rebuilt_messages)} messages. "
f"Ready to continue tool execution."
)
return rebuilt_messages

View File

@@ -0,0 +1,232 @@
"""High-level compression orchestration."""
import logging
from typing import Any, Dict, Optional
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.threshold_checker import (
CompressionThresholdChecker,
)
from application.api.answer.services.compression.types import CompressionResult
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
get_provider_from_model_id,
)
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
logger = logging.getLogger(__name__)
class CompressionOrchestrator:
"""
Facade for compression operations.
Coordinates between all compression components and provides
a simple interface for callers.
"""
def __init__(
self,
conversation_service: ConversationService,
threshold_checker: Optional[CompressionThresholdChecker] = None,
):
"""
Initialize orchestrator.
Args:
conversation_service: Service for DB operations
threshold_checker: Custom threshold checker (optional)
"""
self.conversation_service = conversation_service
self.threshold_checker = threshold_checker or CompressionThresholdChecker()
def compress_if_needed(
self,
conversation_id: str,
user_id: str,
model_id: str,
decoded_token: Dict[str, Any],
current_query_tokens: int = 500,
) -> CompressionResult:
"""
Check if compression is needed and perform it if so.
This is the main entry point for compression operations.
Args:
conversation_id: Conversation ID
user_id: User ID
model_id: Model being used for conversation
decoded_token: User's decoded JWT token
current_query_tokens: Estimated tokens for current query
Returns:
CompressionResult with summary and recent queries
"""
try:
# Load conversation
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
if not conversation:
logger.warning(
f"Conversation {conversation_id} not found for user {user_id}"
)
return CompressionResult.failure("Conversation not found")
# Check if compression is needed
if not self.threshold_checker.should_compress(
conversation, model_id, current_query_tokens
):
# No compression needed, return full history
queries = conversation.get("queries", [])
return CompressionResult.success_no_compression(queries)
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
)
except Exception as e:
logger.error(
f"Error in compress_if_needed: {str(e)}", exc_info=True
)
return CompressionResult.failure(str(e))
def _perform_compression(
self,
conversation_id: str,
conversation: Dict[str, Any],
model_id: str,
decoded_token: Dict[str, Any],
) -> CompressionResult:
"""
Perform the actual compression operation.
Args:
conversation_id: Conversation ID
conversation: Conversation document
model_id: Model ID for conversation
decoded_token: User token
Returns:
CompressionResult
"""
try:
# Determine which model to use for compression
compression_model = (
settings.COMPRESSION_MODEL_OVERRIDE
if settings.COMPRESSION_MODEL_OVERRIDE
else model_id
)
# Get provider and API key for compression model
provider = get_provider_from_model_id(compression_model)
api_key = get_api_key_for_provider(provider)
# Create compression LLM
compression_llm = LLMCreator.create_llm(
provider,
api_key=api_key,
user_api_key=None,
decoded_token=decoded_token,
model_id=compression_model,
)
# Create compression service with DB update capability
compression_service = CompressionService(
llm=compression_llm,
model_id=compression_model,
conversation_service=self.conversation_service,
)
# Compress all queries up to the latest
queries_count = len(conversation.get("queries", []))
compress_up_to = queries_count - 1
if compress_up_to < 0:
logger.warning("No queries to compress")
return CompressionResult.success_no_compression([])
logger.info(
f"Initiating compression for conversation {conversation_id}: "
f"compressing all {queries_count} queries (0-{compress_up_to})"
)
# Perform compression and save to DB
metadata = compression_service.compress_and_save(
conversation_id, conversation, compress_up_to
)
logger.info(
f"Compression successful - ratio: {metadata.compression_ratio:.1f}x, "
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
# Reload conversation with updated metadata
conversation = self.conversation_service.get_conversation(
conversation_id, user_id=decoded_token.get("sub")
)
# Get compressed context
compressed_summary, recent_queries = (
compression_service.get_compressed_context(conversation)
)
return CompressionResult.success_with_compression(
compressed_summary, recent_queries, metadata
)
except Exception as e:
logger.error(f"Error performing compression: {str(e)}", exc_info=True)
return CompressionResult.failure(str(e))
def compress_mid_execution(
self,
conversation_id: str,
user_id: str,
model_id: str,
decoded_token: Dict[str, Any],
current_conversation: Optional[Dict[str, Any]] = None,
) -> CompressionResult:
"""
Perform compression during tool execution.
Args:
conversation_id: Conversation ID
user_id: User ID
model_id: Model ID
decoded_token: User token
current_conversation: Pre-loaded conversation (optional)
Returns:
CompressionResult
"""
try:
# Load conversation if not provided
if current_conversation:
conversation = current_conversation
else:
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
if not conversation:
logger.warning(
f"Could not load conversation {conversation_id} for mid-execution compression"
)
return CompressionResult.failure("Conversation not found")
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
)
except Exception as e:
logger.error(
f"Error in mid-execution compression: {str(e)}", exc_info=True
)
return CompressionResult.failure(str(e))

View File

@@ -0,0 +1,149 @@
"""Compression prompt building logic."""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class CompressionPromptBuilder:
"""Builds prompts for LLM compression calls."""
def __init__(self, version: str = "v1.0"):
"""
Initialize prompt builder.
Args:
version: Prompt template version to use
"""
self.version = version
self.system_prompt = self._load_prompt(version)
def _load_prompt(self, version: str) -> str:
"""
Load prompt template from file.
Args:
version: Version string (e.g., 'v1.0')
Returns:
Prompt template content
Raises:
FileNotFoundError: If prompt template file doesn't exist
"""
current_dir = Path(__file__).resolve().parents[4]
prompt_path = current_dir / "prompts" / "compression" / f"{version}.txt"
try:
with open(prompt_path, "r") as f:
return f.read()
except FileNotFoundError:
logger.error(f"Compression prompt template not found: {prompt_path}")
raise FileNotFoundError(
f"Compression prompt template '{version}' not found at {prompt_path}. "
f"Please ensure the template file exists."
)
def build_prompt(
self,
queries: List[Dict[str, Any]],
existing_compressions: Optional[List[Dict[str, Any]]] = None,
) -> List[Dict[str, str]]:
"""
Build messages for compression LLM call.
Args:
queries: List of query objects to compress
existing_compressions: List of previous compression points
Returns:
List of message dicts for LLM
"""
# Build conversation text
conversation_text = self._format_conversation(queries)
# Add existing compression context if present
existing_compression_context = ""
if existing_compressions and len(existing_compressions) > 0:
existing_compression_context = (
"\n\nIMPORTANT: This conversation has been compressed before. "
"Previous compression summaries:\n\n"
)
for i, comp in enumerate(existing_compressions):
existing_compression_context += (
f"--- Compression {i + 1} (up to message {comp.get('query_index', 'unknown')}) ---\n"
f"{comp.get('compressed_summary', '')}\n\n"
)
existing_compression_context += (
"Your task is to create a NEW summary that incorporates the context from "
"previous compressions AND the new messages below. The final summary should "
"be comprehensive and include all important information from both previous "
"compressions and new messages.\n\n"
)
user_prompt = (
f"{existing_compression_context}"
f"Here is the conversation to summarize:\n\n"
f"{conversation_text}"
)
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": user_prompt},
]
return messages
def _format_conversation(self, queries: List[Dict[str, Any]]) -> str:
"""
Format conversation queries into readable text for compression.
Args:
queries: List of query objects
Returns:
Formatted conversation text
"""
conversation_lines = []
for i, query in enumerate(queries):
conversation_lines.append(f"--- Message {i + 1} ---")
conversation_lines.append(f"User: {query.get('prompt', '')}")
# Add tool calls if present
tool_calls = query.get("tool_calls", [])
if tool_calls:
conversation_lines.append("\nTool Calls:")
for tc in tool_calls:
tool_name = tc.get("tool_name", "unknown")
action_name = tc.get("action_name", "unknown")
arguments = tc.get("arguments", {})
result = tc.get("result", "")
if result is None:
result = ""
status = tc.get("status", "unknown")
# Include full tool result for complete compression context
conversation_lines.append(
f" - {tool_name}.{action_name}({arguments}) "
f"[{status}] → {result}"
)
# Add agent thought if present
thought = query.get("thought", "")
if thought:
conversation_lines.append(f"\nAgent Thought: {thought}")
# Add assistant response
conversation_lines.append(f"\nAssistant: {query.get('response', '')}")
# Add sources if present
sources = query.get("sources", [])
if sources:
conversation_lines.append(f"\nSources Used: {len(sources)} documents")
conversation_lines.append("") # Empty line between messages
return "\n".join(conversation_lines)

View File

@@ -0,0 +1,306 @@
"""Core compression service with simplified responsibilities."""
import logging
import re
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from application.api.answer.services.compression.prompt_builder import (
CompressionPromptBuilder,
)
from application.api.answer.services.compression.token_counter import TokenCounter
from application.api.answer.services.compression.types import (
CompressionMetadata,
)
from application.core.settings import settings
logger = logging.getLogger(__name__)
class CompressionService:
"""
Service for compressing conversation history.
Handles DB updates.
"""
def __init__(
self,
llm,
model_id: str,
conversation_service=None,
prompt_builder: Optional[CompressionPromptBuilder] = None,
):
"""
Initialize compression service.
Args:
llm: LLM instance to use for compression
model_id: Model ID for compression
conversation_service: Service for DB operations (optional, for DB updates)
prompt_builder: Custom prompt builder (optional)
"""
self.llm = llm
self.model_id = model_id
self.conversation_service = conversation_service
self.prompt_builder = prompt_builder or CompressionPromptBuilder(
version=settings.COMPRESSION_PROMPT_VERSION
)
def compress_conversation(
self,
conversation: Dict[str, Any],
compress_up_to_index: int,
) -> CompressionMetadata:
"""
Compress conversation history up to specified index.
Args:
conversation: Full conversation document
compress_up_to_index: Last query index to include in compression
Returns:
CompressionMetadata with compression details
Raises:
ValueError: If compress_up_to_index is invalid
"""
try:
queries = conversation.get("queries", [])
if compress_up_to_index < 0 or compress_up_to_index >= len(queries):
raise ValueError(
f"Invalid compress_up_to_index: {compress_up_to_index} "
f"(conversation has {len(queries)} queries)"
)
# Get queries to compress
queries_to_compress = queries[: compress_up_to_index + 1]
# Check if there are existing compressions
existing_compressions = conversation.get("compression_metadata", {}).get(
"compression_points", []
)
if existing_compressions:
logger.info(
f"Found {len(existing_compressions)} previous compression(s) - "
f"will incorporate into new summary"
)
# Calculate original token count
original_tokens = TokenCounter.count_query_tokens(queries_to_compress)
# Log tool call stats
self._log_tool_call_stats(queries_to_compress)
# Build compression prompt
messages = self.prompt_builder.build_prompt(
queries_to_compress, existing_compressions
)
# Call LLM to generate compression
logger.info(
f"Starting compression: {len(queries_to_compress)} queries "
f"(messages 0-{compress_up_to_index}, {original_tokens} tokens) "
f"using model {self.model_id}"
)
response = self.llm.gen(
model=self.model_id, messages=messages, max_tokens=4000
)
# Extract summary from response
compressed_summary = self._extract_summary(response)
# Calculate compressed token count
compressed_tokens = TokenCounter.count_message_tokens(
[{"content": compressed_summary}]
)
# Calculate compression ratio
compression_ratio = (
original_tokens / compressed_tokens if compressed_tokens > 0 else 0
)
logger.info(
f"Compression complete: {original_tokens}{compressed_tokens} tokens "
f"({compression_ratio:.1f}x compression)"
)
# Build compression metadata
compression_metadata = CompressionMetadata(
timestamp=datetime.now(timezone.utc),
query_index=compress_up_to_index,
compressed_summary=compressed_summary,
original_token_count=original_tokens,
compressed_token_count=compressed_tokens,
compression_ratio=compression_ratio,
model_used=self.model_id,
compression_prompt_version=self.prompt_builder.version,
)
return compression_metadata
except Exception as e:
logger.error(f"Error compressing conversation: {str(e)}", exc_info=True)
raise
def compress_and_save(
self,
conversation_id: str,
conversation: Dict[str, Any],
compress_up_to_index: int,
) -> CompressionMetadata:
"""
Compress conversation and save to database.
Args:
conversation_id: Conversation ID
conversation: Full conversation document
compress_up_to_index: Last query index to include
Returns:
CompressionMetadata
Raises:
ValueError: If conversation_service not provided or invalid index
"""
if not self.conversation_service:
raise ValueError(
"conversation_service required for compress_and_save operation"
)
# Perform compression
metadata = self.compress_conversation(conversation, compress_up_to_index)
# Save to database
self.conversation_service.update_compression_metadata(
conversation_id, metadata.to_dict()
)
logger.info(f"Compression metadata saved to database for {conversation_id}")
return metadata
def get_compressed_context(
self, conversation: Dict[str, Any]
) -> tuple[Optional[str], List[Dict[str, Any]]]:
"""
Get compressed summary + recent uncompressed messages.
Args:
conversation: Full conversation document
Returns:
(compressed_summary, recent_messages)
"""
try:
compression_metadata = conversation.get("compression_metadata", {})
if not compression_metadata.get("is_compressed"):
logger.debug("No compression metadata found - using full history")
queries = conversation.get("queries", [])
if queries is None:
logger.error("Conversation queries is None - returning empty list")
return None, []
return None, queries
compression_points = compression_metadata.get("compression_points", [])
if not compression_points:
logger.debug("No compression points found - using full history")
queries = conversation.get("queries", [])
if queries is None:
logger.error("Conversation queries is None - returning empty list")
return None, []
return None, queries
# Get the most recent compression point
latest_compression = compression_points[-1]
compressed_summary = latest_compression.get("compressed_summary")
last_compressed_index = latest_compression.get("query_index")
compressed_tokens = latest_compression.get("compressed_token_count", 0)
original_tokens = latest_compression.get("original_token_count", 0)
# Get only messages after compression point
queries = conversation.get("queries", [])
total_queries = len(queries)
recent_queries = queries[last_compressed_index + 1 :]
logger.info(
f"Using compressed context: summary ({compressed_tokens} tokens, "
f"compressed from {original_tokens}) + {len(recent_queries)} recent messages "
f"(messages {last_compressed_index + 1}-{total_queries - 1})"
)
return compressed_summary, recent_queries
except Exception as e:
logger.error(
f"Error getting compressed context: {str(e)}", exc_info=True
)
queries = conversation.get("queries", [])
if queries is None:
return None, []
return None, queries
def _extract_summary(self, llm_response: str) -> str:
"""
Extract clean summary from LLM response.
Args:
llm_response: Raw LLM response
Returns:
Cleaned summary text
"""
try:
# Try to extract content within <summary> tags
summary_match = re.search(
r"<summary>(.*?)</summary>", llm_response, re.DOTALL
)
if summary_match:
summary = summary_match.group(1).strip()
else:
# If no summary tags, remove analysis tags and use the rest
summary = re.sub(
r"<analysis>.*?</analysis>", "", llm_response, flags=re.DOTALL
).strip()
return summary
except Exception as e:
logger.warning(f"Error extracting summary: {str(e)}, using full response")
return llm_response
def _log_tool_call_stats(self, queries: List[Dict[str, Any]]) -> None:
"""Log statistics about tool calls in queries."""
total_tool_calls = 0
total_tool_result_chars = 0
tool_call_breakdown = {}
for q in queries:
for tc in q.get("tool_calls", []):
total_tool_calls += 1
tool_name = tc.get("tool_name", "unknown")
action_name = tc.get("action_name", "unknown")
key = f"{tool_name}.{action_name}"
tool_call_breakdown[key] = tool_call_breakdown.get(key, 0) + 1
# Track total tool result size
result = tc.get("result", "")
if result:
total_tool_result_chars += len(str(result))
if total_tool_calls > 0:
tool_breakdown_str = ", ".join(
f"{tool}({count})"
for tool, count in sorted(tool_call_breakdown.items())
)
tool_result_kb = total_tool_result_chars / 1024
logger.info(
f"Tool call breakdown: {tool_breakdown_str} "
f"(total result size: {tool_result_kb:.1f} KB, {total_tool_result_chars:,} chars)"
)

View File

@@ -0,0 +1,103 @@
"""Compression threshold checking logic."""
import logging
from typing import Any, Dict
from application.core.model_utils import get_token_limit
from application.core.settings import settings
from application.api.answer.services.compression.token_counter import TokenCounter
logger = logging.getLogger(__name__)
class CompressionThresholdChecker:
"""Determines if compression is needed based on token thresholds."""
def __init__(self, threshold_percentage: float = None):
"""
Initialize threshold checker.
Args:
threshold_percentage: Percentage of context to use as threshold
(defaults to settings.COMPRESSION_THRESHOLD_PERCENTAGE)
"""
self.threshold_percentage = (
threshold_percentage or settings.COMPRESSION_THRESHOLD_PERCENTAGE
)
def should_compress(
self,
conversation: Dict[str, Any],
model_id: str,
current_query_tokens: int = 500,
) -> bool:
"""
Determine if compression is needed.
Args:
conversation: Full conversation document
model_id: Target model for this request
current_query_tokens: Estimated tokens for current query
Returns:
True if tokens >= threshold% of context window
"""
try:
# Calculate total tokens in conversation
total_tokens = TokenCounter.count_conversation_tokens(conversation)
total_tokens += current_query_tokens
# Get context window limit for model
context_limit = get_token_limit(model_id)
# Calculate threshold
threshold = int(context_limit * self.threshold_percentage)
compression_needed = total_tokens >= threshold
percentage_used = (total_tokens / context_limit) * 100
if compression_needed:
logger.warning(
f"COMPRESSION TRIGGERED: {total_tokens} tokens / {context_limit} limit "
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%)"
)
else:
logger.info(
f"Compression check: {total_tokens}/{context_limit} tokens "
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%) - No compression needed"
)
return compression_needed
except Exception as e:
logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
return False
def check_message_tokens(self, messages: list, model_id: str) -> bool:
"""
Check if message list exceeds threshold.
Args:
messages: List of message dicts
model_id: Target model
Returns:
True if at or above threshold
"""
try:
current_tokens = TokenCounter.count_message_tokens(messages)
context_limit = get_token_limit(model_id)
threshold = int(context_limit * self.threshold_percentage)
if current_tokens >= threshold:
logger.warning(
f"Message context limit approaching: {current_tokens}/{context_limit} tokens "
f"({(current_tokens/context_limit)*100:.1f}%)"
)
return True
return False
except Exception as e:
logger.error(f"Error checking message tokens: {str(e)}", exc_info=True)
return False

View File

@@ -0,0 +1,103 @@
"""Token counting utilities for compression."""
import logging
from typing import Any, Dict, List
from application.utils import num_tokens_from_string
from application.core.settings import settings
logger = logging.getLogger(__name__)
class TokenCounter:
"""Centralized token counting for conversations and messages."""
@staticmethod
def count_message_tokens(messages: List[Dict]) -> int:
"""
Calculate total tokens in a list of messages.
Args:
messages: List of message dicts with 'content' field
Returns:
Total token count
"""
total_tokens = 0
for message in messages:
content = message.get("content", "")
if isinstance(content, str):
total_tokens += num_tokens_from_string(content)
elif isinstance(content, list):
# Handle structured content (tool calls, etc.)
for item in content:
if isinstance(item, dict):
total_tokens += num_tokens_from_string(str(item))
return total_tokens
@staticmethod
def count_query_tokens(
queries: List[Dict[str, Any]], include_tool_calls: bool = True
) -> int:
"""
Count tokens across multiple query objects.
Args:
queries: List of query objects from conversation
include_tool_calls: Whether to count tool call tokens
Returns:
Total token count
"""
total_tokens = 0
for query in queries:
# Count prompt and response tokens
if "prompt" in query:
total_tokens += num_tokens_from_string(query["prompt"])
if "response" in query:
total_tokens += num_tokens_from_string(query["response"])
if "thought" in query:
total_tokens += num_tokens_from_string(query.get("thought", ""))
# Count tool call tokens
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
tool_call_string = (
f"Tool: {tool_call.get('tool_name')} | "
f"Action: {tool_call.get('action_name')} | "
f"Args: {tool_call.get('arguments')} | "
f"Response: {tool_call.get('result')}"
)
total_tokens += num_tokens_from_string(tool_call_string)
return total_tokens
@staticmethod
def count_conversation_tokens(
conversation: Dict[str, Any], include_system_prompt: bool = False
) -> int:
"""
Calculate total tokens in a conversation.
Args:
conversation: Conversation document
include_system_prompt: Whether to include system prompt in count
Returns:
Total token count
"""
try:
queries = conversation.get("queries", [])
total_tokens = TokenCounter.count_query_tokens(queries)
# Add system prompt tokens if requested
if include_system_prompt:
# Rough estimate for system prompt
total_tokens += settings.RESERVED_TOKENS.get("system_prompt", 500)
return total_tokens
except Exception as e:
logger.error(f"Error calculating conversation tokens: {str(e)}")
return 0

View File

@@ -0,0 +1,83 @@
"""Type definitions for compression module."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
@dataclass
class CompressionMetadata:
"""Metadata about a compression operation."""
timestamp: datetime
query_index: int
compressed_summary: str
original_token_count: int
compressed_token_count: int
compression_ratio: float
model_used: str
compression_prompt_version: str
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for DB storage."""
return {
"timestamp": self.timestamp,
"query_index": self.query_index,
"compressed_summary": self.compressed_summary,
"original_token_count": self.original_token_count,
"compressed_token_count": self.compressed_token_count,
"compression_ratio": self.compression_ratio,
"model_used": self.model_used,
"compression_prompt_version": self.compression_prompt_version,
}
@dataclass
class CompressionResult:
"""Result of a compression operation."""
success: bool
compressed_summary: Optional[str] = None
recent_queries: List[Dict[str, Any]] = field(default_factory=list)
metadata: Optional[CompressionMetadata] = None
error: Optional[str] = None
compression_performed: bool = False
@classmethod
def success_with_compression(
cls, summary: str, queries: List[Dict], metadata: CompressionMetadata
) -> "CompressionResult":
"""Create a successful result with compression."""
return cls(
success=True,
compressed_summary=summary,
recent_queries=queries,
metadata=metadata,
compression_performed=True,
)
@classmethod
def success_no_compression(cls, queries: List[Dict]) -> "CompressionResult":
"""Create a successful result without compression needed."""
return cls(
success=True,
recent_queries=queries,
compression_performed=False,
)
@classmethod
def failure(cls, error: str) -> "CompressionResult":
"""Create a failure result."""
return cls(success=False, error=error, compression_performed=False)
def as_history(self) -> List[Dict[str, str]]:
"""
Convert recent queries to history format.
Returns:
List of prompt/response dicts
"""
return [
{"prompt": q["prompt"], "response": q["response"]}
for q in self.recent_queries
]

View File

@@ -0,0 +1,282 @@
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from bson import ObjectId
logger = logging.getLogger(__name__)
class ConversationService:
def __init__(self):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
self.conversations_collection = db["conversations"]
self.agents_collection = db["agents"]
def get_conversation(
self, conversation_id: str, user_id: str
) -> Optional[Dict[str, Any]]:
"""Retrieve a conversation with proper access control"""
if not conversation_id or not user_id:
return None
try:
conversation = self.conversations_collection.find_one(
{
"_id": ObjectId(conversation_id),
"$or": [{"user": user_id}, {"shared_with": user_id}],
}
)
if not conversation:
logger.warning(
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
)
return None
conversation["_id"] = str(conversation["_id"])
return conversation
except Exception as e:
logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
return None
def save_conversation(
self,
conversation_id: Optional[str],
question: str,
response: str,
thought: str,
sources: List[Dict[str, Any]],
tool_calls: List[Dict[str, Any]],
llm: Any,
model_id: str,
decoded_token: Dict[str, Any],
index: Optional[int] = None,
api_key: Optional[str] = None,
agent_id: Optional[str] = None,
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
attachment_ids: Optional[List[str]] = None,
) -> str:
"""Save or update a conversation in the database"""
user_id = decoded_token.get("sub")
if not user_id:
raise ValueError("User ID not found in token")
current_time = datetime.now(timezone.utc)
# clean up in sources array such that we save max 1k characters for text part
for source in sources:
if "text" in source and isinstance(source["text"], str):
source["text"] = source["text"][:1000]
if conversation_id is not None and index is not None:
# Update existing conversation with new query
result = self.conversations_collection.update_one(
{
"_id": ObjectId(conversation_id),
"user": user_id,
f"queries.{index}": {"$exists": True},
},
{
"$set": {
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.thought": thought,
f"queries.{index}.sources": sources,
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time,
f"queries.{index}.attachments": attachment_ids,
f"queries.{index}.model_id": model_id,
}
},
)
if result.matched_count == 0:
raise ValueError("Conversation not found or unauthorized")
self.conversations_collection.update_one(
{
"_id": ObjectId(conversation_id),
"user": user_id,
f"queries.{index}": {"$exists": True},
},
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
)
return conversation_id
elif conversation_id:
# Append new message to existing conversation
result = self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id), "user": user_id},
{
"$push": {
"queries": {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
}
},
)
if result.matched_count == 0:
raise ValueError("Conversation not found or unauthorized")
return conversation_id
else:
# Create new conversation
messages_summary = [
{
"role": "system",
"content": "You are a helpful assistant that creates concise conversation titles. "
"Summarize conversations in 3 words or less using the same language as the user.",
},
{
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"user query \n\nUser: " + question + "\n\n" + "AI: " + response,
},
]
completion = llm.gen(
model=model_id, messages=messages_summary, max_tokens=30
)
conversation_data = {
"user": user_id,
"date": current_time,
"name": completion,
"queries": [
{
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
],
}
if api_key:
if agent_id:
conversation_data["agent_id"] = agent_id
if is_shared_usage:
conversation_data["is_shared_usage"] = is_shared_usage
conversation_data["shared_token"] = shared_token
agent = self.agents_collection.find_one({"key": api_key})
if agent:
conversation_data["api_key"] = agent["key"]
result = self.conversations_collection.insert_one(conversation_data)
return str(result.inserted_id)
def update_compression_metadata(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""
Update conversation with compression metadata.
Uses $push with $slice to keep only the most recent compression points,
preventing unbounded array growth. Since each compression incorporates
previous compressions, older points become redundant.
Args:
conversation_id: Conversation ID
compression_metadata: Compression point data
"""
try:
self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$set": {
"compression_metadata.is_compressed": True,
"compression_metadata.last_compression_at": compression_metadata.get(
"timestamp"
),
},
"$push": {
"compression_metadata.compression_points": {
"$each": [compression_metadata],
"$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
}
},
},
)
logger.info(
f"Updated compression metadata for conversation {conversation_id}"
)
except Exception as e:
logger.error(
f"Error updating compression metadata: {str(e)}", exc_info=True
)
raise
def append_compression_message(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""
Append a synthetic compression summary entry into the conversation history.
This makes the summary visible in the DB alongside normal queries.
"""
try:
summary = compression_metadata.get("compressed_summary", "")
if not summary:
return
timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc))
self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {
"queries": {
"prompt": "[Context Compression Summary]",
"response": summary,
"thought": "",
"sources": [],
"tool_calls": [],
"timestamp": timestamp,
"attachments": [],
"model_id": compression_metadata.get("model_used"),
}
}
},
)
logger.info(f"Appended compression summary to conversation {conversation_id}")
except Exception as e:
logger.error(
f"Error appending compression summary: {str(e)}", exc_info=True
)
def get_compression_metadata(
self, conversation_id: str
) -> Optional[Dict[str, Any]]:
"""
Get compression metadata for a conversation.
Args:
conversation_id: Conversation ID
Returns:
Compression metadata dict or None
"""
try:
conversation = self.conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
)
return conversation.get("compression_metadata") if conversation else None
except Exception as e:
logger.error(
f"Error getting compression metadata: {str(e)}", exc_info=True
)
return None

View File

@@ -0,0 +1,97 @@
import logging
from typing import Any, Dict, Optional
from application.templates.namespaces import NamespaceManager
from application.templates.template_engine import TemplateEngine, TemplateRenderError
logger = logging.getLogger(__name__)
class PromptRenderer:
"""Service for rendering prompts with dynamic context using namespaces"""
def __init__(self):
self.template_engine = TemplateEngine()
self.namespace_manager = NamespaceManager()
def render_prompt(
self,
prompt_content: str,
user_id: Optional[str] = None,
request_id: Optional[str] = None,
passthrough_data: Optional[Dict[str, Any]] = None,
docs: Optional[list] = None,
docs_together: Optional[str] = None,
tools_data: Optional[Dict[str, Any]] = None,
**kwargs,
) -> str:
"""
Render prompt with full context from all namespaces.
Args:
prompt_content: Raw prompt template string
user_id: Current user identifier
request_id: Unique request identifier
passthrough_data: Parameters from web request
docs: RAG retrieved documents
docs_together: Concatenated document content
tools_data: Pre-fetched tool results organized by tool name
**kwargs: Additional parameters for namespace builders
Returns:
Rendered prompt string with all variables substituted
Raises:
TemplateRenderError: If template rendering fails
"""
if not prompt_content:
return ""
uses_template = self._uses_template_syntax(prompt_content)
if not uses_template:
return self._apply_legacy_substitutions(prompt_content, docs_together)
try:
context = self.namespace_manager.build_context(
user_id=user_id,
request_id=request_id,
passthrough_data=passthrough_data,
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
**kwargs,
)
return self.template_engine.render(prompt_content, context)
except TemplateRenderError:
raise
except Exception as e:
error_msg = f"Prompt rendering failed: {str(e)}"
logger.error(error_msg)
raise TemplateRenderError(error_msg) from e
def _uses_template_syntax(self, prompt_content: str) -> bool:
"""Check if prompt uses Jinja2 template syntax"""
return "{{" in prompt_content and "}}" in prompt_content
def _apply_legacy_substitutions(
self, prompt_content: str, docs_together: Optional[str] = None
) -> str:
"""
Apply backward-compatible substitutions for old prompt format.
Handles legacy {summaries} and {query} placeholders during transition period.
"""
if docs_together:
prompt_content = prompt_content.replace("{summaries}", docs_together)
return prompt_content
def validate_template(self, prompt_content: str) -> bool:
"""Validate prompt template syntax"""
return self.template_engine.validate_template(prompt_content)
def extract_variables(self, prompt_content: str) -> set[str]:
"""Extract all variable names from prompt template"""
return self.template_engine.extract_variables(prompt_content)

View File

@@ -0,0 +1,755 @@
import datetime
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional, Set
from bson.dbref import DBRef
from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.compression import CompressionOrchestrator
from application.api.answer.services.compression.token_counter import TokenCounter
from application.api.answer.services.conversation_service import ConversationService
from application.api.answer.services.prompt_renderer import PromptRenderer
from application.core.model_utils import (
get_api_key_for_provider,
get_default_model_id,
get_provider_from_model_id,
validate_model_id,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import (
calculate_doc_token_budget,
limit_chat_history,
)
logger = logging.getLogger(__name__)
def get_prompt(prompt_id: str, prompts_collection=None) -> str:
"""
Get a prompt by preset name or MongoDB ID
"""
current_dir = Path(__file__).resolve().parents[3]
prompts_dir = current_dir / "prompts"
preset_mapping = {
"default": "chat_combine_default.txt",
"creative": "chat_combine_creative.txt",
"strict": "chat_combine_strict.txt",
"reduce": "chat_reduce_prompt.txt",
}
if prompt_id in preset_mapping:
file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
try:
with open(file_path, "r") as f:
return f.read()
except FileNotFoundError:
raise FileNotFoundError(f"Prompt file not found: {file_path}")
try:
if prompts_collection is None:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
prompts_collection = db["prompts"]
prompt_doc = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
if not prompt_doc:
raise ValueError(f"Prompt with ID {prompt_id} not found")
return prompt_doc["content"]
except Exception as e:
raise ValueError(f"Invalid prompt ID: {prompt_id}") from e
class StreamProcessor:
def __init__(
self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]]
):
mongo = MongoDB.get_client()
self.db = mongo[settings.MONGO_DB_NAME]
self.agents_collection = self.db["agents"]
self.attachments_collection = self.db["attachments"]
self.prompts_collection = self.db["prompts"]
self.data = request_data
self.decoded_token = decoded_token
self.initial_user_id = (
self.decoded_token.get("sub") if self.decoded_token is not None else None
)
self.conversation_id = self.data.get("conversation_id")
self.source = {}
self.all_sources = []
self.attachments = []
self.history = []
self.retrieved_docs = []
self.agent_config = {}
self.retriever_config = {}
self.is_shared_usage = False
self.shared_token = None
self.model_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
self.conversation_service
)
self.prompt_renderer = PromptRenderer()
self._prompt_content: Optional[str] = None
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
self.compressed_summary: Optional[str] = None
self.compressed_summary_tokens: int = 0
def initialize(self):
"""Initialize all required components for processing"""
self._configure_agent()
self._validate_and_set_model()
self._configure_source()
self._configure_retriever()
self._load_conversation_history()
self._process_attachments()
def _load_conversation_history(self):
"""Load conversation history either from DB or request"""
if self.conversation_id and self.initial_user_id:
conversation = self.conversation_service.get_conversation(
self.conversation_id, self.initial_user_id
)
if not conversation:
raise ValueError("Conversation not found or unauthorized")
# Check if compression is enabled and needed
if settings.ENABLE_CONVERSATION_COMPRESSION:
self._handle_compression(conversation)
else:
# Original behavior - load all history
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
else:
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")), model_id=self.model_id
)
def _handle_compression(self, conversation: Dict[str, Any]):
"""
Handle conversation compression logic using orchestrator.
Args:
conversation: Full conversation document
"""
try:
# Use orchestrator to handle all compression logic
result = self.compression_orchestrator.compress_if_needed(
conversation_id=self.conversation_id,
user_id=self.initial_user_id,
model_id=self.model_id,
decoded_token=self.decoded_token,
)
if not result.success:
logger.error(
f"Compression failed: {result.error}, using full history"
)
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
return
# Set compressed summary if compression was performed
if result.compression_performed and result.compressed_summary:
self.compressed_summary = result.compressed_summary
self.compressed_summary_tokens = TokenCounter.count_message_tokens(
[{"content": result.compressed_summary}]
)
logger.info(
f"Using compressed summary ({self.compressed_summary_tokens} tokens) "
f"+ {len(result.recent_queries)} recent messages"
)
# Build history from recent queries
self.history = result.as_history()
except Exception as e:
logger.error(
f"Error handling compression, falling back to standard history: {str(e)}",
exc_info=True,
)
# Fallback to original behavior
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
def _process_attachments(self):
"""Process any attachments in the request"""
attachment_ids = self.data.get("attachments", [])
self.attachments = self._get_attachments_content(
attachment_ids, self.initial_user_id
)
def _get_attachments_content(self, attachment_ids, user_id):
"""
Retrieve content from attachment documents based on their IDs.
"""
if not attachment_ids:
return []
attachments = []
for attachment_id in attachment_ids:
try:
attachment_doc = self.attachments_collection.find_one(
{"_id": ObjectId(attachment_id), "user": user_id}
)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
)
return attachments
def _validate_and_set_model(self):
"""Validate and set model_id from request"""
from application.core.model_settings import ModelRegistry
requested_model = self.data.get("model_id")
if requested_model:
if not validate_model_id(requested_model):
registry = ModelRegistry.get_instance()
available_models = [m.id for m in registry.get_enabled_models()]
raise ValueError(
f"Invalid model_id '{requested_model}'. "
f"Available models: {', '.join(available_models[:5])}"
+ (f" and {len(available_models) - 5} more" if len(available_models) > 5 else "")
)
self.model_id = requested_model
else:
# Check if agent has a default model configured
agent_default_model = self.agent_config.get("default_model_id", "")
if agent_default_model and validate_model_id(agent_default_model):
self.model_id = agent_default_model
else:
self.model_id = get_default_model_id()
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control"""
if not agent_id:
return None, False, None
try:
agent = self.agents_collection.find_one({"_id": ObjectId(agent_id)})
if agent is None:
raise Exception("Agent not found")
is_owner = agent.get("user") == user_id
is_shared_with_user = agent.get(
"shared_publicly", False
) or user_id in agent.get("shared_with", [])
if not (is_owner or is_shared_with_user):
raise Exception("Unauthorized access to the agent")
if is_owner:
self.agents_collection.update_one(
{"_id": ObjectId(agent_id)},
{
"$set": {
"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)
}
},
)
return str(agent["key"]), not is_owner, agent.get("shared_token")
except Exception as e:
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
raise
def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
data = self.agents_collection.find_one({"key": api_key})
if not data:
raise Exception("Invalid API Key, please generate a new key", 401)
source = data.get("source")
if isinstance(source, DBRef):
source_doc = self.db.dereference(source)
if source_doc:
data["source"] = str(source_doc["_id"])
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
else:
data["source"] = None
elif source == "default":
data["source"] = "default"
else:
data["source"] = None
# Handle multiple sources
sources = data.get("sources", [])
if sources and isinstance(sources, list):
sources_list = []
for i, source_ref in enumerate(sources):
if source_ref == "default":
processed_source = {
"id": "default",
"retriever": "classic",
"chunks": data.get("chunks", "2"),
}
sources_list.append(processed_source)
elif isinstance(source_ref, DBRef):
source_doc = self.db.dereference(source_ref)
if source_doc:
processed_source = {
"id": str(source_doc["_id"]),
"retriever": source_doc.get("retriever", "classic"),
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
}
sources_list.append(processed_source)
data["sources"] = sources_list
else:
data["sources"] = []
# Preserve model configuration from agent
data["default_model_id"] = data.get("default_model_id", "")
return data
def _configure_source(self):
"""Configure the source based on agent data"""
api_key = self.data.get("api_key") or self.agent_key
if api_key:
agent_data = self._get_data_from_api_key(api_key)
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
source_ids = [
source["id"] for source in agent_data["sources"] if source.get("id")
]
if source_ids:
self.source = {"active_docs": source_ids}
else:
self.source = {}
self.all_sources = agent_data["sources"]
elif agent_data.get("source"):
self.source = {"active_docs": agent_data["source"]}
self.all_sources = [
{
"id": agent_data["source"],
"retriever": agent_data.get("retriever", "classic"),
}
]
else:
self.source = {}
self.all_sources = []
return
if "active_docs" in self.data:
self.source = {"active_docs": self.data["active_docs"]}
return
self.source = {}
self.all_sources = []
def _configure_agent(self):
"""Configure the agent based on request data"""
agent_id = self.data.get("agent_id")
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
agent_id, self.initial_user_id
)
api_key = self.data.get("api_key")
if api_key:
data_key = self._get_data_from_api_key(api_key)
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": api_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
}
)
self.initial_user_id = data_key.get("user")
self.decoded_token = {"sub": data_key.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
elif self.agent_key:
data_key = self._get_data_from_api_key(self.agent_key)
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": self.agent_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
}
)
self.decoded_token = (
self.decoded_token
if self.is_shared_usage
else {"sub": data_key.get("user")}
)
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
else:
self.agent_config.update(
{
"prompt_id": self.data.get("prompt_id", "default"),
"agent_type": settings.AGENT_NAME,
"user_api_key": None,
"json_schema": None,
"default_model_id": "",
}
)
def _configure_retriever(self):
history_token_limit = int(self.data.get("token_limit", 2000))
doc_token_limit = calculate_doc_token_budget(
model_id=self.model_id, history_token_limit=history_token_limit
)
self.retriever_config = {
"retriever_name": self.data.get("retriever", "classic"),
"chunks": int(self.data.get("chunks", 2)),
"doc_token_limit": doc_token_limit,
"history_token_limit": history_token_limit,
}
api_key = self.data.get("api_key") or self.agent_key
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
self.retriever_config["chunks"] = 0
def create_retriever(self):
return RetrieverCreator.create_retriever(
self.retriever_config["retriever_name"],
source=self.source,
chat_history=self.history,
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
chunks=self.retriever_config["chunks"],
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
model_id=self.model_id,
user_api_key=self.agent_config["user_api_key"],
decoded_token=self.decoded_token,
)
def pre_fetch_docs(self, question: str) -> tuple[Optional[str], Optional[list]]:
"""Pre-fetch documents for template rendering before agent creation"""
if self.data.get("isNoneDoc", False):
logger.info("Pre-fetch skipped: isNoneDoc=True")
return None, None
try:
retriever = self.create_retriever()
logger.info(
f"Pre-fetching docs with chunks={retriever.chunks}, doc_token_limit={retriever.doc_token_limit}"
)
docs = retriever.search(question)
logger.info(f"Pre-fetch retrieved {len(docs) if docs else 0} documents")
if not docs:
logger.info("Pre-fetch: No documents returned from search")
return None, None
self.retrieved_docs = docs
docs_with_filenames = []
for doc in docs:
filename = doc.get("filename") or doc.get("title") or doc.get("source")
if filename:
chunk_header = str(filename)
docs_with_filenames.append(f"{chunk_header}\n{doc['text']}")
else:
docs_with_filenames.append(doc["text"])
docs_together = "\n\n".join(docs_with_filenames)
logger.info(f"Pre-fetch docs_together size: {len(docs_together)} chars")
return docs_together, docs
except Exception as e:
logger.error(f"Failed to pre-fetch docs: {str(e)}", exc_info=True)
return None, None
def pre_fetch_tools(self) -> Optional[Dict[str, Any]]:
"""Pre-fetch tool data for template rendering before agent creation
Can be controlled via:
1. Global setting: ENABLE_TOOL_PREFETCH in .env
2. Per-request: disable_tool_prefetch in request data
"""
if not settings.ENABLE_TOOL_PREFETCH:
logger.info(
"Tool pre-fetching disabled globally via ENABLE_TOOL_PREFETCH setting"
)
return None
if self.data.get("disable_tool_prefetch", False):
logger.info("Tool pre-fetching disabled for this request")
return None
required_tool_actions = self._get_required_tool_actions()
filtering_enabled = required_tool_actions is not None
try:
user_tools_collection = self.db["user_tools"]
user_id = self.initial_user_id or "local"
user_tools = list(
user_tools_collection.find({"user": user_id, "status": True})
)
if not user_tools:
return None
tools_data = {}
for tool_doc in user_tools:
tool_name = tool_doc.get("name")
tool_id = str(tool_doc.get("_id"))
if filtering_enabled:
required_actions_by_name = required_tool_actions.get(
tool_name, set()
)
required_actions_by_id = required_tool_actions.get(tool_id, set())
required_actions = required_actions_by_name | required_actions_by_id
if not required_actions:
continue
else:
required_actions = None
tool_data = self._fetch_tool_data(tool_doc, required_actions)
if tool_data:
tools_data[tool_name] = tool_data
tools_data[tool_id] = tool_data
return tools_data if tools_data else None
except Exception as e:
logger.warning(f"Failed to pre-fetch tools: {type(e).__name__}")
return None
def _fetch_tool_data(
self,
tool_doc: Dict[str, Any],
required_actions: Optional[Set[Optional[str]]],
) -> Optional[Dict[str, Any]]:
"""Fetch and execute tool actions with saved parameters"""
try:
from application.agents.tools.tool_manager import ToolManager
tool_name = tool_doc.get("name")
tool_config = tool_doc.get("config", {}).copy()
tool_config["tool_id"] = str(tool_doc["_id"])
tool_manager = ToolManager(config={tool_name: tool_config})
user_id = self.initial_user_id or "local"
tool = tool_manager.load_tool(tool_name, tool_config, user_id=user_id)
if not tool:
logger.debug(f"Tool '{tool_name}' failed to load")
return None
tool_actions = tool.get_actions_metadata()
if not tool_actions:
logger.debug(f"Tool '{tool_name}' has no actions")
return None
saved_actions = tool_doc.get("actions", [])
include_all_actions = required_actions is None or (
required_actions and None in required_actions
)
allowed_actions: Set[str] = (
{action for action in required_actions if isinstance(action, str)}
if required_actions
else set()
)
action_results = {}
for action_meta in tool_actions:
action_name = action_meta.get("name")
if action_name is None:
continue
if (
not include_all_actions
and allowed_actions
and action_name not in allowed_actions
):
continue
try:
saved_action = None
for sa in saved_actions:
if sa.get("name") == action_name:
saved_action = sa
break
action_params = action_meta.get("parameters", {})
properties = action_params.get("properties", {})
kwargs = {}
for param_name, param_spec in properties.items():
if saved_action:
saved_props = saved_action.get("parameters", {}).get(
"properties", {}
)
if param_name in saved_props:
param_value = saved_props[param_name].get("value")
if param_value is not None:
kwargs[param_name] = param_value
continue
if param_name in tool_config:
kwargs[param_name] = tool_config[param_name]
elif "default" in param_spec:
kwargs[param_name] = param_spec["default"]
result = tool.execute_action(action_name, **kwargs)
action_results[action_name] = result
except Exception as e:
logger.debug(
f"Action '{action_name}' execution failed: {type(e).__name__}"
)
continue
return action_results if action_results else None
except Exception as e:
logger.debug(f"Tool pre-fetch failed for '{tool_name}': {type(e).__name__}")
return None
def _get_prompt_content(self) -> Optional[str]:
"""Retrieve and cache the raw prompt content for the current agent configuration."""
if self._prompt_content is not None:
return self._prompt_content
prompt_id = (
self.agent_config.get("prompt_id")
if isinstance(self.agent_config, dict)
else None
)
if not prompt_id:
return None
try:
self._prompt_content = get_prompt(prompt_id, self.prompts_collection)
except ValueError as e:
logger.debug(f"Invalid prompt ID '{prompt_id}': {str(e)}")
self._prompt_content = None
except Exception as e:
logger.debug(f"Failed to fetch prompt '{prompt_id}': {type(e).__name__}")
self._prompt_content = None
return self._prompt_content
def _get_required_tool_actions(self) -> Optional[Dict[str, Set[Optional[str]]]]:
"""Determine which tool actions are referenced in the prompt template"""
if self._required_tool_actions is not None:
return self._required_tool_actions
prompt_content = self._get_prompt_content()
if prompt_content is None:
return None
if "{{" not in prompt_content or "}}" not in prompt_content:
self._required_tool_actions = {}
return self._required_tool_actions
try:
from application.templates.template_engine import TemplateEngine
template_engine = TemplateEngine()
usages = template_engine.extract_tool_usages(prompt_content)
self._required_tool_actions = usages
return self._required_tool_actions
except Exception as e:
logger.debug(f"Failed to extract tool usages: {type(e).__name__}")
self._required_tool_actions = {}
return self._required_tool_actions
def _fetch_memory_tool_data(
self, tool_doc: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""Fetch memory tool data for pre-injection into prompt"""
try:
tool_config = tool_doc.get("config", {}).copy()
tool_config["tool_id"] = str(tool_doc["_id"])
from application.agents.tools.memory import MemoryTool
memory_tool = MemoryTool(tool_config, self.initial_user_id)
root_view = memory_tool.execute_action("view", path="/")
if "Error:" in root_view or not root_view.strip():
return None
return {"root": root_view, "available": True}
except Exception as e:
logger.warning(f"Failed to fetch memory tool data: {str(e)}")
return None
def create_agent(
self,
docs_together: Optional[str] = None,
docs: Optional[list] = None,
tools_data: Optional[Dict[str, Any]] = None,
):
"""Create and return the configured agent with rendered prompt"""
raw_prompt = self._get_prompt_content()
if raw_prompt is None:
raw_prompt = get_prompt(
self.agent_config["prompt_id"], self.prompts_collection
)
self._prompt_content = raw_prompt
rendered_prompt = self.prompt_renderer.render_prompt(
prompt_content=raw_prompt,
user_id=self.initial_user_id,
request_id=self.data.get("request_id"),
passthrough_data=self.data.get("passthrough"),
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
provider = (
get_provider_from_model_id(self.model_id)
if self.model_id
else settings.LLM_PROVIDER
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
agent = AgentCreator.create_agent(
self.agent_config["agent_type"],
endpoint="stream",
llm_name=provider or settings.LLM_PROVIDER,
model_id=self.model_id,
api_key=system_api_key,
user_api_key=self.agent_config["user_api_key"],
prompt=rendered_prompt,
chat_history=self.history,
retrieved_docs=self.retrieved_docs,
decoded_token=self.decoded_token,
attachments=self.attachments,
json_schema=self.agent_config.get("json_schema"),
compressed_summary=self.compressed_summary,
)
agent.conversation_id = self.conversation_id
agent.initial_user_id = self.initial_user_id
return agent

View File

@@ -0,0 +1,489 @@
import base64
import datetime
import json
import uuid
from bson.objectid import ObjectId
from flask import (
Blueprint,
current_app,
jsonify,
make_response,
request
)
from flask_restx import fields, Namespace, Resource
from application.api.user.tasks import (
ingest_connector_task,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.api import api
from application.parser.connectors.connector_creator import ConnectorCreator
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sources_collection = db["sources"]
sessions_collection = db["connector_sessions"]
connector = Blueprint("connector", __name__)
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
api.add_namespace(connectors_ns)
@connectors_ns.route("/api/connectors/auth")
class ConnectorAuth(Resource):
@api.doc(description="Get connector OAuth authorization URL", params={"provider": "Connector provider (e.g., google_drive)"})
def get(self):
try:
provider = request.args.get('provider') or request.args.get('source')
if not provider:
return make_response(jsonify({"success": False, "error": "Missing provider"}), 400)
if not ConnectorCreator.is_supported(provider):
return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user_id = decoded_token.get('sub')
now = datetime.datetime.now(datetime.timezone.utc)
result = sessions_collection.insert_one({
"provider": provider,
"user": user_id,
"status": "pending",
"created_at": now
})
state_dict = {
"provider": provider,
"object_id": str(result.inserted_id)
}
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
auth = ConnectorCreator.create_auth(provider)
authorization_url = auth.get_authorization_url(state=state)
return make_response(jsonify({
"success": True,
"authorization_url": authorization_url,
"state": state
}), 200)
except Exception as e:
current_app.logger.error(f"Error generating connector auth URL: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
@connectors_ns.route("/api/connectors/callback")
class ConnectorsCallback(Resource):
@api.doc(description="Handle OAuth callback for external connectors")
def get(self):
"""Handle OAuth callback for external connectors"""
try:
from application.parser.connectors.connector_creator import ConnectorCreator
from flask import request, redirect
authorization_code = request.args.get('code')
state = request.args.get('state')
error = request.args.get('error')
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
provider = state_dict["provider"]
state_object_id = state_dict["object_id"]
if error:
if error == "access_denied":
return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}")
else:
current_app.logger.warning(f"OAuth error in callback: {error}")
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
if not authorization_code:
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
try:
auth = ConnectorCreator.create_auth(provider)
token_info = auth.exchange_code_for_tokens(authorization_code)
session_token = str(uuid.uuid4())
try:
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
except Exception as e:
current_app.logger.warning(f"Could not get user info: {e}")
user_email = 'Connected User'
sanitized_token_info = {
"access_token": token_info.get("access_token"),
"refresh_token": token_info.get("refresh_token"),
"token_uri": token_info.get("token_uri"),
"expiry": token_info.get("expiry")
}
sessions_collection.find_one_and_update(
{"_id": ObjectId(state_object_id), "provider": provider},
{
"$set": {
"session_token": session_token,
"token_info": sanitized_token_info,
"user_email": user_email,
"status": "authorized"
}
}
)
# Redirect to success page with session token and user email
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
except Exception as e:
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
except Exception as e:
current_app.logger.error(f"Error handling connector callback: {e}")
return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.")
@connectors_ns.route("/api/connectors/files")
class ConnectorFiles(Resource):
@api.expect(api.model("ConnectorFilesModel", {
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"page_token": fields.String(required=False),
"search_query": fields.String(required=False)
}))
@api.doc(description="List files from a connector provider (supports pagination and search)")
def post(self):
try:
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
folder_id = data.get('folder_id')
limit = data.get('limit', 10)
page_token = data.get('page_token')
search_query = data.get('search_query')
if not provider or not session_token:
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user = decoded_token.get('sub')
session = sessions_collection.find_one({"session_token": session_token, "user": user})
if not session:
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
loader = ConnectorCreator.create_connector(provider, session_token)
input_config = {
'limit': limit,
'list_only': True,
'session_token': session_token,
'folder_id': folder_id,
'page_token': page_token
}
if search_query:
input_config['search_query'] = search_query
documents = loader.load_data(input_config)
files = []
for doc in documents[:limit]:
metadata = doc.extra_info
modified_time = metadata.get('modified_time')
if modified_time:
date_part = modified_time.split('T')[0]
time_part = modified_time.split('T')[1].split('.')[0].split('Z')[0]
formatted_time = f"{date_part} {time_part}"
else:
formatted_time = None
files.append({
'id': doc.doc_id,
'name': metadata.get('file_name', 'Unknown File'),
'type': metadata.get('mime_type', 'unknown'),
'size': metadata.get('size', None),
'modifiedTime': formatted_time,
'isFolder': metadata.get('is_folder', False)
})
next_token = getattr(loader, 'next_page_token', None)
has_more = bool(next_token)
return make_response(jsonify({
"success": True,
"files": files,
"total": len(files),
"next_page_token": next_token,
"has_more": has_more
}), 200)
except Exception as e:
current_app.logger.error(f"Error loading connector files: {e}")
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
@connectors_ns.route("/api/connectors/validate-session")
class ConnectorValidateSession(Resource):
@api.expect(api.model("ConnectorValidateSessionModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True)}))
@api.doc(description="Validate connector session token and return user info and access token")
def post(self):
try:
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
if not provider or not session_token:
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user = decoded_token.get('sub')
session = sessions_collection.find_one({"session_token": session_token, "user": user})
if not session or "token_info" not in session:
return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
token_info = session["token_info"]
auth = ConnectorCreator.create_auth(provider)
is_expired = auth.is_token_expired(token_info)
if is_expired and token_info.get('refresh_token'):
try:
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
sanitized_token_info = {
"access_token": refreshed_token_info.get("access_token"),
"refresh_token": refreshed_token_info.get("refresh_token"),
"token_uri": refreshed_token_info.get("token_uri"),
"expiry": refreshed_token_info.get("expiry")
}
sessions_collection.update_one(
{"session_token": session_token},
{"$set": {"token_info": sanitized_token_info}}
)
token_info = sanitized_token_info
is_expired = False
except Exception as refresh_error:
current_app.logger.error(f"Failed to refresh token: {refresh_error}")
if is_expired:
return make_response(jsonify({
"success": False,
"expired": True,
"error": "Session token has expired. Please reconnect."
}), 401)
return make_response(jsonify({
"success": True,
"expired": False,
"user_email": session.get('user_email', 'Connected User'),
"access_token": token_info.get('access_token')
}), 200)
except Exception as e:
current_app.logger.error(f"Error validating connector session: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
@connectors_ns.route("/api/connectors/disconnect")
class ConnectorDisconnect(Resource):
@api.expect(api.model("ConnectorDisconnectModel", {"provider": fields.String(required=True), "session_token": fields.String(required=False)}))
@api.doc(description="Disconnect a connector session")
def post(self):
try:
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
if not provider:
return make_response(jsonify({"success": False, "error": "provider is required"}), 400)
if session_token:
sessions_collection.delete_one({"session_token": session_token})
return make_response(jsonify({"success": True}), 200)
except Exception as e:
current_app.logger.error(f"Error disconnecting connector session: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
@connectors_ns.route("/api/connectors/sync")
class ConnectorSync(Resource):
@api.expect(
api.model(
"ConnectorSyncModel",
{
"source_id": fields.String(required=True, description="Source ID to sync"),
"session_token": fields.String(required=True, description="Authentication token")
},
)
)
@api.doc(description="Sync connector source to check for modifications")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
try:
data = request.get_json()
source_id = data.get('source_id')
session_token = data.get('session_token')
if not all([source_id, session_token]):
return make_response(
jsonify({
"success": False,
"error": "source_id and session_token are required"
}),
400
)
source = sources_collection.find_one({"_id": ObjectId(source_id)})
if not source:
return make_response(
jsonify({
"success": False,
"error": "Source not found"
}),
404
)
if source.get('user') != decoded_token.get('sub'):
return make_response(
jsonify({
"success": False,
"error": "Unauthorized access to source"
}),
403
)
remote_data = {}
try:
if source.get('remote_data'):
remote_data = json.loads(source.get('remote_data'))
except json.JSONDecodeError:
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
remote_data = {}
source_type = remote_data.get('provider')
if not source_type:
return make_response(
jsonify({
"success": False,
"error": "Source provider not found in remote_data"
}),
400
)
# Extract configuration from remote_data
file_ids = remote_data.get('file_ids', [])
folder_ids = remote_data.get('folder_ids', [])
recursive = remote_data.get('recursive', True)
# Start the sync task
task = ingest_connector_task.delay(
job_name=source.get('name'),
user=decoded_token.get('sub'),
source_type=source_type,
session_token=session_token,
file_ids=file_ids,
folder_ids=folder_ids,
recursive=recursive,
retriever=source.get('retriever', 'classic'),
operation_mode="sync",
doc_id=source_id,
sync_frequency=source.get('sync_frequency', 'never')
)
return make_response(
jsonify({
"success": True,
"task_id": task.id
}),
200
)
except Exception as err:
current_app.logger.error(
f"Error syncing connector source: {err}",
exc_info=True
)
return make_response(
jsonify({
"success": False,
"error": str(err)
}),
400
)
@connectors_ns.route("/api/connectors/callback-status")
class ConnectorCallbackStatus(Resource):
@api.doc(description="Return HTML page with connector authentication status")
def get(self):
"""Return HTML page with connector authentication status"""
try:
status = request.args.get('status', 'error')
message = request.args.get('message', '')
provider = request.args.get('provider', 'connector')
session_token = request.args.get('session_token', '')
user_email = request.args.get('user_email', '')
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>{provider.replace('_', ' ').title()} Authentication</title>
<style>
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
.container {{ max-width: 600px; margin: 0 auto; }}
.success {{ color: #4CAF50; }}
.error {{ color: #F44336; }}
.cancelled {{ color: #FF9800; }}
</style>
<script>
window.onload = function() {{
const status = "{status}";
const sessionToken = "{session_token}";
const userEmail = "{user_email}";
if (status === "success" && window.opener) {{
window.opener.postMessage({{
type: '{provider}_auth_success',
session_token: sessionToken,
user_email: userEmail
}}, '*');
setTimeout(() => window.close(), 3000);
}} else if (status === "cancelled" || status === "error") {{
setTimeout(() => window.close(), 3000);
}}
}};
</script>
</head>
<body>
<div class="container">
<h2>{provider.replace('_', ' ').title()} Authentication</h2>
<div class="{status}">
<p>{message}</p>
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
</div>
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
</div>
</body>
</html>
"""
return make_response(html_content, 200, {'Content-Type': 'text/html'})
except Exception as e:
current_app.logger.error(f"Error rendering callback status page: {e}")
return make_response("Authentication error occurred", 500, {'Content-Type': 'text/html'})

View File

@@ -1,5 +1,6 @@
import os
import datetime
import json
from flask import Blueprint, request, send_from_directory
from werkzeug.utils import secure_filename
from bson.objectid import ObjectId
@@ -48,7 +49,17 @@ def upload_index_files():
remote_data = request.form["remote_data"] if "remote_data" in request.form else None
sync_frequency = request.form["sync_frequency"] if "sync_frequency" in request.form else None
original_file_path = request.form.get("original_file_path")
file_path = request.form.get("file_path")
directory_structure = request.form.get("directory_structure")
if directory_structure:
try:
directory_structure = json.loads(directory_structure)
except Exception:
logger.error("Error parsing directory_structure")
directory_structure = {}
else:
directory_structure = {}
storage = StorageCreator.get_storage()
index_base_path = f"indexes/{id}"
@@ -66,10 +77,13 @@ def upload_index_files():
file_pkl = request.files["file_pkl"]
if file_pkl.filename == "":
return {"status": "no file name"}
# Save index files to storage
storage.save_file(file_faiss, f"{index_base_path}/index.faiss")
storage.save_file(file_pkl, f"{index_base_path}/index.pkl")
faiss_storage_path = f"{index_base_path}/index.faiss"
pkl_storage_path = f"{index_base_path}/index.pkl"
storage.save_file(file_faiss, faiss_storage_path)
storage.save_file(file_pkl, pkl_storage_path)
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
if existing_entry:
@@ -87,7 +101,8 @@ def upload_index_files():
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": original_file_path,
"file_path": file_path,
"directory_structure": directory_structure,
}
},
)
@@ -105,7 +120,8 @@ def upload_index_files():
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": original_file_path,
"file_path": file_path,
"directory_structure": directory_structure,
}
)
return {"status": "ok"}

View File

@@ -0,0 +1,5 @@
"""User API module - provides all user-related API endpoints"""
from .routes import user
__all__ = ["user"]

View File

@@ -0,0 +1,7 @@
"""Agents module."""
from .routes import agents_ns
from .sharing import agents_sharing_ns
from .webhooks import agents_webhooks_ns
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,263 @@
"""Agent management sharing functionality."""
import datetime
import secrets
from bson import DBRef
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.core.settings import settings
from application.api.user.base import (
agents_collection,
db,
ensure_user_doc,
resolve_tool_details,
user_tools_collection,
users_collection,
)
from application.utils import generate_image_url
agents_sharing_ns = Namespace(
"agents", description="Agent management operations", path="/api"
)
@agents_sharing_ns.route("/shared_agent")
class SharedAgent(Resource):
@api.doc(
params={
"token": "Shared token of the agent",
},
description="Get a shared agent by token or ID",
)
def get(self):
shared_token = request.args.get("token")
if not shared_token:
return make_response(
jsonify({"success": False, "message": "Token or ID is required"}), 400
)
try:
query = {
"shared_publicly": True,
"shared_token": shared_token,
}
shared_agent = agents_collection.find_one(query)
if not shared_agent:
return make_response(
jsonify({"success": False, "message": "Shared agent not found"}),
404,
)
agent_id = str(shared_agent["_id"])
data = {
"id": agent_id,
"user": shared_agent.get("user", ""),
"name": shared_agent.get("name", ""),
"image": (
generate_image_url(shared_agent["image"])
if shared_agent.get("image")
else ""
),
"description": shared_agent.get("description", ""),
"source": (
str(source_doc["_id"])
if isinstance(shared_agent.get("source"), DBRef)
and (source_doc := db.dereference(shared_agent.get("source")))
else ""
),
"chunks": shared_agent.get("chunks", "0"),
"retriever": shared_agent.get("retriever", "classic"),
"prompt_id": shared_agent.get("prompt_id", "default"),
"tools": shared_agent.get("tools", []),
"tool_details": resolve_tool_details(shared_agent.get("tools", [])),
"agent_type": shared_agent.get("agent_type", ""),
"status": shared_agent.get("status", ""),
"json_schema": shared_agent.get("json_schema"),
"limited_token_mode": shared_agent.get("limited_token_mode", False),
"token_limit": shared_agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
"limited_request_mode": shared_agent.get("limited_request_mode", False),
"request_limit": shared_agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
"created_at": shared_agent.get("createdAt", ""),
"updated_at": shared_agent.get("updatedAt", ""),
"shared": shared_agent.get("shared_publicly", False),
"shared_token": shared_agent.get("shared_token", ""),
"shared_metadata": shared_agent.get("shared_metadata", {}),
}
if data["tools"]:
enriched_tools = []
for tool in data["tools"]:
tool_data = user_tools_collection.find_one({"_id": ObjectId(tool)})
if tool_data:
enriched_tools.append(tool_data.get("name", ""))
data["tools"] = enriched_tools
decoded_token = getattr(request, "decoded_token", None)
if decoded_token:
user_id = decoded_token.get("sub")
owner_id = shared_agent.get("user")
if user_id != owner_id:
ensure_user_doc(user_id)
users_collection.update_one(
{"user_id": user_id},
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
)
return make_response(jsonify(data), 200)
except Exception as err:
current_app.logger.error(f"Error retrieving shared agent: {err}")
return make_response(jsonify({"success": False}), 400)
@agents_sharing_ns.route("/shared_agents")
class SharedAgents(Resource):
@api.doc(description="Get shared agents explicitly shared with the user")
def get(self):
try:
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
user_doc = ensure_user_doc(user_id)
shared_with_ids = user_doc.get("agent_preferences", {}).get(
"shared_with_me", []
)
shared_object_ids = [ObjectId(id) for id in shared_with_ids]
shared_agents_cursor = agents_collection.find(
{"_id": {"$in": shared_object_ids}, "shared_publicly": True}
)
shared_agents = list(shared_agents_cursor)
found_ids_set = {str(agent["_id"]) for agent in shared_agents}
stale_ids = [id for id in shared_with_ids if id not in found_ids_set]
if stale_ids:
users_collection.update_one(
{"user_id": user_id},
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
)
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
list_shared_agents = [
{
"id": str(agent["_id"]),
"name": agent.get("name", ""),
"description": agent.get("description", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"tools": agent.get("tools", []),
"tool_details": resolve_tool_details(agent.get("tools", [])),
"agent_type": agent.get("agent_type", ""),
"status": agent.get("status", ""),
"json_schema": agent.get("json_schema"),
"limited_token_mode": agent.get("limited_token_mode", False),
"token_limit": agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
"limited_request_mode": agent.get("limited_request_mode", False),
"request_limit": agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
"created_at": agent.get("createdAt", ""),
"updated_at": agent.get("updatedAt", ""),
"pinned": str(agent["_id"]) in pinned_ids,
"shared": agent.get("shared_publicly", False),
"shared_token": agent.get("shared_token", ""),
"shared_metadata": agent.get("shared_metadata", {}),
}
for agent in shared_agents
]
return make_response(jsonify(list_shared_agents), 200)
except Exception as err:
current_app.logger.error(f"Error retrieving shared agents: {err}")
return make_response(jsonify({"success": False}), 400)
@agents_sharing_ns.route("/share_agent")
class ShareAgent(Resource):
@api.expect(
api.model(
"ShareAgentModel",
{
"id": fields.String(required=True, description="ID of the agent"),
"shared": fields.Boolean(
required=True, description="Share or unshare the agent"
),
"username": fields.String(
required=False, description="Name of the user"
),
},
)
)
@api.doc(description="Share or unshare an agent")
def put(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data:
return make_response(
jsonify({"success": False, "message": "Missing JSON body"}), 400
)
agent_id = data.get("id")
shared = data.get("shared")
username = data.get("username", "")
if not agent_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
if shared is None:
return make_response(
jsonify(
{
"success": False,
"message": "Shared parameter is required and must be true or false",
}
),
400,
)
try:
try:
agent_oid = ObjectId(agent_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid agent ID"}), 400
)
agent = agents_collection.find_one({"_id": agent_oid, "user": user})
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
if shared:
shared_metadata = {
"shared_by": username,
"shared_at": datetime.datetime.now(datetime.timezone.utc),
}
shared_token = secrets.token_urlsafe(32)
agents_collection.update_one(
{"_id": agent_oid, "user": user},
{
"$set": {
"shared_publicly": shared,
"shared_metadata": shared_metadata,
"shared_token": shared_token,
}
},
)
else:
agents_collection.update_one(
{"_id": agent_oid, "user": user},
{"$set": {"shared_publicly": shared, "shared_token": None}},
{"$unset": {"shared_metadata": ""}},
)
except Exception as err:
current_app.logger.error(f"Error sharing/unsharing agent: {err}")
return make_response(jsonify({"success": False, "error": str(err)}), 400)
shared_token = shared_token if shared else None
return make_response(
jsonify({"success": True, "shared_token": shared_token}), 200
)

View File

@@ -0,0 +1,119 @@
"""Agent management webhook handlers."""
import secrets
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource
from application.api import api
from application.api.user.base import agents_collection, require_agent
from application.api.user.tasks import process_agent_webhook
from application.core.settings import settings
agents_webhooks_ns = Namespace(
"agents", description="Agent management operations", path="/api"
)
@agents_webhooks_ns.route("/agent_webhook")
class AgentWebhook(Resource):
@api.doc(
params={"id": "ID of the agent"},
description="Generate webhook URL for the agent",
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
agent_id = request.args.get("id")
if not agent_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "user": user}
)
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
webhook_token = agent.get("incoming_webhook_token")
if not webhook_token:
webhook_token = secrets.token_urlsafe(32)
agents_collection.update_one(
{"_id": ObjectId(agent_id), "user": user},
{"$set": {"incoming_webhook_token": webhook_token}},
)
base_url = settings.API_URL.rstrip("/")
full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
except Exception as err:
current_app.logger.error(
f"Error generating webhook URL: {err}", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Error generating webhook URL"}),
400,
)
return make_response(
jsonify({"success": True, "webhook_url": full_webhook_url}), 200
)
@agents_webhooks_ns.route("/webhooks/agents/<string:webhook_token>")
class AgentWebhookListener(Resource):
method_decorators = [require_agent]
def _enqueue_webhook_task(self, agent_id_str, payload, source_method):
if not payload:
current_app.logger.warning(
f"Webhook ({source_method}) received for agent {agent_id_str} with empty payload."
)
current_app.logger.info(
f"Incoming {source_method} webhook for agent {agent_id_str}. Enqueuing task with payload: {payload}"
)
try:
task = process_agent_webhook.delay(
agent_id=agent_id_str,
payload=payload,
)
current_app.logger.info(
f"Task {task.id} enqueued for agent {agent_id_str} ({source_method})."
)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
except Exception as err:
current_app.logger.error(
f"Error enqueuing webhook task ({source_method}) for agent {agent_id_str}: {err}",
exc_info=True,
)
return make_response(
jsonify({"success": False, "message": "Error processing webhook"}), 500
)
@api.doc(
description="Webhook listener for agent events (POST). Expects JSON payload, which is used to trigger processing.",
)
def post(self, webhook_token, agent, agent_id_str):
payload = request.get_json()
if payload is None:
return make_response(
jsonify(
{
"success": False,
"message": "Invalid or missing JSON data in request body",
}
),
400,
)
return self._enqueue_webhook_task(agent_id_str, payload, source_method="POST")
@api.doc(
description="Webhook listener for agent events (GET). Uses URL query parameters as payload to trigger processing.",
)
def get(self, webhook_token, agent, agent_id_str):
payload = request.args.to_dict(flat=True)
return self._enqueue_webhook_task(agent_id_str, payload, source_method="GET")

View File

@@ -0,0 +1,5 @@
"""Analytics module."""
from .routes import analytics_ns
__all__ = ["analytics_ns"]

View File

@@ -0,0 +1,540 @@
"""Analytics and reporting routes."""
import datetime
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import (
agents_collection,
conversations_collection,
generate_date_range,
generate_hourly_range,
generate_minute_range,
token_usage_collection,
user_logs_collection,
)
analytics_ns = Namespace(
"analytics", description="Analytics and reporting operations", path="/api"
)
@analytics_ns.route("/get_message_analytics")
class GetMessageAnalytics(Resource):
get_message_analytics_model = api.model(
"GetMessageAnalyticsModel",
{
"api_key_id": fields.String(required=False, description="API Key ID"),
"filter_option": fields.String(
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=[
"last_hour",
"last_24_hour",
"last_7_days",
"last_15_days",
"last_30_days",
],
),
},
)
@api.expect(get_message_analytics_model)
@api.doc(description="Get message analytics based on filter option")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
try:
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
"key"
]
if api_key_id
else None
)
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
end_date = datetime.datetime.now(datetime.timezone.utc)
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
group_format = "%Y-%m-%d %H:%M:00"
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
group_format = "%Y-%m-%d %H:00"
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
6
if filter_option == "last_7_days"
else 14 if filter_option == "last_15_days" else 29
)
else:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date = end_date - datetime.timedelta(days=filter_days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
group_format = "%Y-%m-%d"
try:
match_stage = {
"$match": {
"user": user,
}
}
if api_key:
match_stage["$match"]["api_key"] = api_key
pipeline = [
match_stage,
{"$unwind": "$queries"},
{
"$match": {
"queries.timestamp": {"$gte": start_date, "$lte": end_date}
}
},
{
"$group": {
"_id": {
"$dateToString": {
"format": group_format,
"date": "$queries.timestamp",
}
},
"count": {"$sum": 1},
}
},
{"$sort": {"_id": 1}},
]
message_data = conversations_collection.aggregate(pipeline)
if filter_option == "last_hour":
intervals = generate_minute_range(start_date, end_date)
elif filter_option == "last_24_hour":
intervals = generate_hourly_range(start_date, end_date)
else:
intervals = generate_date_range(start_date, end_date)
daily_messages = {interval: 0 for interval in intervals}
for entry in message_data:
daily_messages[entry["_id"]] = entry["count"]
except Exception as err:
current_app.logger.error(
f"Error getting message analytics: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(
jsonify({"success": True, "messages": daily_messages}), 200
)
@analytics_ns.route("/get_token_analytics")
class GetTokenAnalytics(Resource):
get_token_analytics_model = api.model(
"GetTokenAnalyticsModel",
{
"api_key_id": fields.String(required=False, description="API Key ID"),
"filter_option": fields.String(
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=[
"last_hour",
"last_24_hour",
"last_7_days",
"last_15_days",
"last_30_days",
],
),
},
)
@api.expect(get_token_analytics_model)
@api.doc(description="Get token analytics data")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
try:
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
"key"
]
if api_key_id
else None
)
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
end_date = datetime.datetime.now(datetime.timezone.utc)
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
group_format = "%Y-%m-%d %H:%M:00"
group_stage = {
"$group": {
"_id": {
"minute": {
"$dateToString": {
"format": group_format,
"date": "$timestamp",
}
}
},
"total_tokens": {
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
},
}
}
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
group_format = "%Y-%m-%d %H:00"
group_stage = {
"$group": {
"_id": {
"hour": {
"$dateToString": {
"format": group_format,
"date": "$timestamp",
}
}
},
"total_tokens": {
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
},
}
}
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
6
if filter_option == "last_7_days"
else (14 if filter_option == "last_15_days" else 29)
)
else:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date = end_date - datetime.timedelta(days=filter_days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
group_format = "%Y-%m-%d"
group_stage = {
"$group": {
"_id": {
"day": {
"$dateToString": {
"format": group_format,
"date": "$timestamp",
}
}
},
"total_tokens": {
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
},
}
}
try:
match_stage = {
"$match": {
"user_id": user,
"timestamp": {"$gte": start_date, "$lte": end_date},
}
}
if api_key:
match_stage["$match"]["api_key"] = api_key
token_usage_data = token_usage_collection.aggregate(
[
match_stage,
group_stage,
{"$sort": {"_id": 1}},
]
)
if filter_option == "last_hour":
intervals = generate_minute_range(start_date, end_date)
elif filter_option == "last_24_hour":
intervals = generate_hourly_range(start_date, end_date)
else:
intervals = generate_date_range(start_date, end_date)
daily_token_usage = {interval: 0 for interval in intervals}
for entry in token_usage_data:
if filter_option == "last_hour":
daily_token_usage[entry["_id"]["minute"]] = entry["total_tokens"]
elif filter_option == "last_24_hour":
daily_token_usage[entry["_id"]["hour"]] = entry["total_tokens"]
else:
daily_token_usage[entry["_id"]["day"]] = entry["total_tokens"]
except Exception as err:
current_app.logger.error(
f"Error getting token analytics: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(
jsonify({"success": True, "token_usage": daily_token_usage}), 200
)
@analytics_ns.route("/get_feedback_analytics")
class GetFeedbackAnalytics(Resource):
get_feedback_analytics_model = api.model(
"GetFeedbackAnalyticsModel",
{
"api_key_id": fields.String(required=False, description="API Key ID"),
"filter_option": fields.String(
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=[
"last_hour",
"last_24_hour",
"last_7_days",
"last_15_days",
"last_30_days",
],
),
},
)
@api.expect(get_feedback_analytics_model)
@api.doc(description="Get feedback analytics data")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
try:
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
"key"
]
if api_key_id
else None
)
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
end_date = datetime.datetime.now(datetime.timezone.utc)
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
group_format = "%Y-%m-%d %H:%M:00"
date_field = {
"$dateToString": {
"format": group_format,
"date": "$queries.feedback_timestamp",
}
}
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
group_format = "%Y-%m-%d %H:00"
date_field = {
"$dateToString": {
"format": group_format,
"date": "$queries.feedback_timestamp",
}
}
else:
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
filter_days = (
6
if filter_option == "last_7_days"
else (14 if filter_option == "last_15_days" else 29)
)
else:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date = end_date - datetime.timedelta(days=filter_days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
group_format = "%Y-%m-%d"
date_field = {
"$dateToString": {
"format": group_format,
"date": "$queries.feedback_timestamp",
}
}
try:
match_stage = {
"$match": {
"queries.feedback_timestamp": {
"$gte": start_date,
"$lte": end_date,
},
"queries.feedback": {"$exists": True},
}
}
if api_key:
match_stage["$match"]["api_key"] = api_key
pipeline = [
match_stage,
{"$unwind": "$queries"},
{"$match": {"queries.feedback": {"$exists": True}}},
{
"$group": {
"_id": {"time": date_field, "feedback": "$queries.feedback"},
"count": {"$sum": 1},
}
},
{
"$group": {
"_id": "$_id.time",
"positive": {
"$sum": {
"$cond": [
{"$eq": ["$_id.feedback", "LIKE"]},
"$count",
0,
]
}
},
"negative": {
"$sum": {
"$cond": [
{"$eq": ["$_id.feedback", "DISLIKE"]},
"$count",
0,
]
}
},
}
},
{"$sort": {"_id": 1}},
]
feedback_data = conversations_collection.aggregate(pipeline)
if filter_option == "last_hour":
intervals = generate_minute_range(start_date, end_date)
elif filter_option == "last_24_hour":
intervals = generate_hourly_range(start_date, end_date)
else:
intervals = generate_date_range(start_date, end_date)
daily_feedback = {
interval: {"positive": 0, "negative": 0} for interval in intervals
}
for entry in feedback_data:
daily_feedback[entry["_id"]] = {
"positive": entry["positive"],
"negative": entry["negative"],
}
except Exception as err:
current_app.logger.error(
f"Error getting feedback analytics: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(
jsonify({"success": True, "feedback": daily_feedback}), 200
)
@analytics_ns.route("/get_user_logs")
class GetUserLogs(Resource):
get_user_logs_model = api.model(
"GetUserLogsModel",
{
"page": fields.Integer(
required=False,
description="Page number for pagination",
default=1,
),
"api_key_id": fields.String(required=False, description="API Key ID"),
"page_size": fields.Integer(
required=False,
description="Number of logs per page",
default=10,
),
},
)
@api.expect(get_user_logs_model)
@api.doc(description="Get user logs with pagination")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
page = int(data.get("page", 1))
api_key_id = data.get("api_key_id")
page_size = int(data.get("page_size", 10))
skip = (page - 1) * page_size
try:
api_key = (
agents_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
if api_key_id
else None
)
except Exception as err:
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
query = {"user": user}
if api_key:
query = {"api_key": api_key}
items_cursor = (
user_logs_collection.find(query)
.sort("timestamp", -1)
.skip(skip)
.limit(page_size + 1)
)
items = list(items_cursor)
results = [
{
"id": str(item.get("_id")),
"action": item.get("action"),
"level": item.get("level"),
"user": item.get("user"),
"question": item.get("question"),
"sources": item.get("sources"),
"retriever_params": item.get("retriever_params"),
"timestamp": item.get("timestamp"),
}
for item in items[:page_size]
]
has_more = len(items) > page_size
return make_response(
jsonify(
{
"success": True,
"logs": results,
"page": page,
"page_size": page_size,
"has_more": has_more,
}
),
200,
)

View File

@@ -0,0 +1,5 @@
"""Attachments module."""
from .routes import attachments_ns
__all__ = ["attachments_ns"]

View File

@@ -0,0 +1,198 @@
"""File attachments and media routes."""
import os
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import agents_collection, storage
from application.api.user.tasks import store_attachment
from application.core.settings import settings
from application.tts.tts_creator import TTSCreator
from application.utils import safe_filename
attachments_ns = Namespace(
"attachments", description="File attachments and media operations", path="/api"
)
@attachments_ns.route("/store_attachment")
class StoreAttachment(Resource):
@api.expect(
api.model(
"AttachmentModel",
{
"file": fields.Raw(required=True, description="File(s) to upload"),
"api_key": fields.String(
required=False, description="API key (optional)"
),
},
)
)
@api.doc(
description="Stores one or multiple attachments without vectorization or training. Supports user or API key authentication."
)
def post(self):
decoded_token = getattr(request, "decoded_token", None)
api_key = request.form.get("api_key") or request.args.get("api_key")
files = request.files.getlist("file")
if not files:
single_file = request.files.get("file")
if single_file:
files = [single_file]
if not files or all(f.filename == "" for f in files):
return make_response(
jsonify({"status": "error", "message": "Missing file(s)"}),
400,
)
user = None
if decoded_token:
user = safe_filename(decoded_token.get("sub"))
elif api_key:
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
jsonify({"success": False, "message": "Invalid API key"}), 401
)
user = safe_filename(agent.get("user"))
else:
return make_response(
jsonify({"success": False, "message": "Authentication required"}), 401
)
try:
tasks = []
errors = []
original_file_count = len(files)
for idx, file in enumerate(files):
try:
attachment_id = ObjectId()
original_filename = safe_filename(os.path.basename(file.filename))
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
metadata = storage.save_file(file, relative_path)
file_info = {
"filename": original_filename,
"attachment_id": str(attachment_id),
"path": relative_path,
"metadata": metadata,
}
task = store_attachment.delay(file_info, user)
tasks.append({
"task_id": task.id,
"filename": original_filename,
"attachment_id": str(attachment_id),
})
except Exception as file_err:
current_app.logger.error(f"Error processing file {idx} ({file.filename}): {file_err}", exc_info=True)
errors.append({
"filename": file.filename,
"error": str(file_err)
})
if not tasks:
error_msg = "No valid files to upload"
if errors:
error_msg += f". Errors: {errors}"
return make_response(
jsonify({"status": "error", "message": error_msg, "errors": errors}),
400,
)
if original_file_count == 1 and len(tasks) == 1:
current_app.logger.info("Returning single task_id response")
return make_response(
jsonify(
{
"success": True,
"task_id": tasks[0]["task_id"],
"message": "File uploaded successfully. Processing started.",
}
),
200,
)
else:
response_data = {
"success": True,
"tasks": tasks,
"message": f"{len(tasks)} file(s) uploaded successfully. Processing started.",
}
if errors:
response_data["errors"] = errors
response_data["message"] += f" {len(errors)} file(s) failed."
return make_response(
jsonify(response_data),
200,
)
except Exception as err:
current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": str(err)}), 400)
@attachments_ns.route("/images/<path:image_path>")
class ServeImage(Resource):
@api.doc(description="Serve an image from storage")
def get(self, image_path):
try:
file_obj = storage.get_file(image_path)
extension = image_path.split(".")[-1].lower()
content_type = f"image/{extension}"
if extension == "jpg":
content_type = "image/jpeg"
response = make_response(file_obj.read())
response.headers.set("Content-Type", content_type)
response.headers.set("Cache-Control", "max-age=86400")
return response
except FileNotFoundError:
return make_response(
jsonify({"success": False, "message": "Image not found"}), 404
)
except Exception as e:
current_app.logger.error(f"Error serving image: {e}")
return make_response(
jsonify({"success": False, "message": "Error retrieving image"}), 500
)
@attachments_ns.route("/tts")
class TextToSpeech(Resource):
tts_model = api.model(
"TextToSpeechModel",
{
"text": fields.String(
required=True, description="Text to be synthesized as audio"
),
},
)
@api.expect(tts_model)
@api.doc(description="Synthesize audio speech from text")
def post(self):
data = request.get_json()
text = data["text"]
try:
tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
audio_base64, detected_language = tts_instance.text_to_speech(text)
return make_response(
jsonify(
{
"success": True,
"audio_base64": audio_base64,
"lang": detected_language,
}
),
200,
)
except Exception as err:
current_app.logger.error(f"Error synthesizing audio: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)

View File

@@ -0,0 +1,222 @@
"""
Shared utilities, database connections, and helper functions for user API routes.
"""
import datetime
import os
import uuid
from functools import wraps
from typing import Optional, Tuple
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, Response
from pymongo import ReturnDocument
from werkzeug.utils import secure_filename
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.storage_creator import StorageCreator
from application.vectorstore.vector_creator import VectorCreator
storage = StorageCreator.get_storage()
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
conversations_collection = db["conversations"]
sources_collection = db["sources"]
prompts_collection = db["prompts"]
feedback_collection = db["feedback"]
agents_collection = db["agents"]
token_usage_collection = db["token_usage"]
shared_conversations_collections = db["shared_conversations"]
users_collection = db["users"]
user_logs_collection = db["user_logs"]
user_tools_collection = db["user_tools"]
attachments_collection = db["attachments"]
try:
agents_collection.create_index(
[("shared", 1)],
name="shared_index",
background=True,
)
users_collection.create_index("user_id", unique=True)
except Exception as e:
print("Error creating indexes:", e)
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
def generate_minute_range(start_date, end_date):
"""Generate a dictionary with minute-level time ranges."""
return {
(start_date + datetime.timedelta(minutes=i)).strftime("%Y-%m-%d %H:%M:00"): 0
for i in range(int((end_date - start_date).total_seconds() // 60) + 1)
}
def generate_hourly_range(start_date, end_date):
"""Generate a dictionary with hourly time ranges."""
return {
(start_date + datetime.timedelta(hours=i)).strftime("%Y-%m-%d %H:00"): 0
for i in range(int((end_date - start_date).total_seconds() // 3600) + 1)
}
def generate_date_range(start_date, end_date):
"""Generate a dictionary with daily date ranges."""
return {
(start_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d"): 0
for i in range((end_date - start_date).days + 1)
}
def ensure_user_doc(user_id):
"""
Ensure user document exists with proper agent preferences structure.
Args:
user_id: The user ID to ensure
Returns:
The user document
"""
default_prefs = {
"pinned": [],
"shared_with_me": [],
}
user_doc = users_collection.find_one_and_update(
{"user_id": user_id},
{"$setOnInsert": {"agent_preferences": default_prefs}},
upsert=True,
return_document=ReturnDocument.AFTER,
)
prefs = user_doc.get("agent_preferences", {})
updates = {}
if "pinned" not in prefs:
updates["agent_preferences.pinned"] = []
if "shared_with_me" not in prefs:
updates["agent_preferences.shared_with_me"] = []
if updates:
users_collection.update_one({"user_id": user_id}, {"$set": updates})
user_doc = users_collection.find_one({"user_id": user_id})
return user_doc
def resolve_tool_details(tool_ids):
"""
Resolve tool IDs to their details.
Args:
tool_ids: List of tool IDs
Returns:
List of tool details with id, name, and display_name
"""
tools = user_tools_collection.find(
{"_id": {"$in": [ObjectId(tid) for tid in tool_ids]}}
)
return [
{
"id": str(tool["_id"]),
"name": tool.get("name", ""),
"display_name": tool.get("displayName", tool.get("name", "")),
}
for tool in tools
]
def get_vector_store(source_id):
"""
Get the Vector Store for a given source ID.
Args:
source_id (str): source id of the document
Returns:
Vector store instance
"""
store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
source_id=source_id,
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
)
return store
def handle_image_upload(
request, existing_url: str, user: str, storage, base_path: str = "attachments/"
) -> Tuple[str, Optional[Response]]:
"""
Handle image file upload from request.
Args:
request: Flask request object
existing_url: Existing image URL (fallback)
user: User ID
storage: Storage instance
base_path: Base path for upload
Returns:
Tuple of (image_url, error_response)
"""
image_url = existing_url
if "image" in request.files:
file = request.files["image"]
if file.filename != "":
filename = secure_filename(file.filename)
upload_path = f"{settings.UPLOAD_FOLDER.rstrip('/')}/{user}/{base_path.rstrip('/')}/{uuid.uuid4()}_{filename}"
try:
storage.save_file(file, upload_path, storage_class="STANDARD")
image_url = upload_path
except Exception as e:
current_app.logger.error(f"Error uploading image: {e}")
return None, make_response(
jsonify({"success": False, "message": "Image upload failed"}),
400,
)
return image_url, None
def require_agent(func):
"""
Decorator to require valid agent webhook token.
Args:
func: Function to decorate
Returns:
Wrapped function
"""
@wraps(func)
def wrapper(*args, **kwargs):
webhook_token = kwargs.get("webhook_token")
if not webhook_token:
return make_response(
jsonify({"success": False, "message": "Webhook token missing"}), 400
)
agent = agents_collection.find_one(
{"incoming_webhook_token": webhook_token}, {"_id": 1}
)
if not agent:
current_app.logger.warning(
f"Webhook attempt with invalid token: {webhook_token}"
)
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
kwargs["agent"] = agent
kwargs["agent_id_str"] = str(agent["_id"])
return func(*args, **kwargs)
return wrapper

View File

@@ -0,0 +1,5 @@
"""Conversation management module."""
from .routes import conversations_ns
__all__ = ["conversations_ns"]

View File

@@ -0,0 +1,280 @@
"""Conversation management routes."""
import datetime
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import attachments_collection, conversations_collection
from application.utils import check_required_fields
conversations_ns = Namespace(
"conversations", description="Conversation management operations", path="/api"
)
@conversations_ns.route("/delete_conversation")
class DeleteConversation(Resource):
@api.doc(
description="Deletes a conversation by ID",
params={"id": "The ID of the conversation to delete"},
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
conversation_id = request.args.get("id")
if not conversation_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
conversations_collection.delete_one(
{"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
)
except Exception as err:
current_app.logger.error(
f"Error deleting conversation: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@conversations_ns.route("/delete_all_conversations")
class DeleteAllConversations(Resource):
@api.doc(
description="Deletes all conversations for a specific user",
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
try:
conversations_collection.delete_many({"user": user_id})
except Exception as err:
current_app.logger.error(
f"Error deleting all conversations: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@conversations_ns.route("/get_conversations")
class GetConversations(Resource):
@api.doc(
description="Retrieve a list of the latest 30 conversations (excluding API key conversations)",
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
try:
conversations = (
conversations_collection.find(
{
"$or": [
{"api_key": {"$exists": False}},
{"agent_id": {"$exists": True}},
],
"user": decoded_token.get("sub"),
}
)
.sort("date", -1)
.limit(30)
)
list_conversations = [
{
"id": str(conversation["_id"]),
"name": conversation["name"],
"agent_id": conversation.get("agent_id", None),
"is_shared_usage": conversation.get("is_shared_usage", False),
"shared_token": conversation.get("shared_token", None),
}
for conversation in conversations
]
except Exception as err:
current_app.logger.error(
f"Error retrieving conversations: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(list_conversations), 200)
@conversations_ns.route("/get_single_conversation")
class GetSingleConversation(Resource):
@api.doc(
description="Retrieve a single conversation by ID",
params={"id": "The conversation ID"},
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
conversation_id = request.args.get("id")
if not conversation_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
conversation = conversations_collection.find_one(
{"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
)
if not conversation:
return make_response(jsonify({"status": "not found"}), 404)
# Process queries to include attachment names
queries = conversation["queries"]
for query in queries:
if "attachments" in query and query["attachments"]:
attachment_details = []
for attachment_id in query["attachments"]:
try:
attachment = attachments_collection.find_one(
{"_id": ObjectId(attachment_id)}
)
if attachment:
attachment_details.append(
{
"id": str(attachment["_id"]),
"fileName": attachment.get(
"filename", "Unknown file"
),
}
)
except Exception as e:
current_app.logger.error(
f"Error retrieving attachment {attachment_id}: {e}",
exc_info=True,
)
query["attachments"] = attachment_details
except Exception as err:
current_app.logger.error(
f"Error retrieving conversation: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
data = {
"queries": queries,
"agent_id": conversation.get("agent_id"),
"is_shared_usage": conversation.get("is_shared_usage", False),
"shared_token": conversation.get("shared_token", None),
}
return make_response(jsonify(data), 200)
@conversations_ns.route("/update_conversation_name")
class UpdateConversationName(Resource):
@api.expect(
api.model(
"UpdateConversationModel",
{
"id": fields.String(required=True, description="Conversation ID"),
"name": fields.String(
required=True, description="New name of the conversation"
),
},
)
)
@api.doc(
description="Updates the name of a conversation",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
data = request.get_json()
required_fields = ["id", "name"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
conversations_collection.update_one(
{"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
{"$set": {"name": data["name"]}},
)
except Exception as err:
current_app.logger.error(
f"Error updating conversation name: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@conversations_ns.route("/feedback")
class SubmitFeedback(Resource):
@api.expect(
api.model(
"FeedbackModel",
{
"question": fields.String(
required=False, description="The user question"
),
"answer": fields.String(required=False, description="The AI answer"),
"feedback": fields.String(required=True, description="User feedback"),
"question_index": fields.Integer(
required=True,
description="The question number in that particular conversation",
),
"conversation_id": fields.String(
required=True, description="id of the particular conversation"
),
"api_key": fields.String(description="Optional API key"),
},
)
)
@api.doc(
description="Submit feedback for a conversation",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
data = request.get_json()
required_fields = ["feedback", "conversation_id", "question_index"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
if data["feedback"] is None:
# Remove feedback and feedback_timestamp if feedback is null
conversations_collection.update_one(
{
"_id": ObjectId(data["conversation_id"]),
"user": decoded_token.get("sub"),
f"queries.{data['question_index']}": {"$exists": True},
},
{
"$unset": {
f"queries.{data['question_index']}.feedback": "",
f"queries.{data['question_index']}.feedback_timestamp": "",
}
},
)
else:
# Set feedback and feedback_timestamp if feedback has a value
conversations_collection.update_one(
{
"_id": ObjectId(data["conversation_id"]),
"user": decoded_token.get("sub"),
f"queries.{data['question_index']}": {"$exists": True},
},
{
"$set": {
f"queries.{data['question_index']}.feedback": data[
"feedback"
],
f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
datetime.timezone.utc
),
}
},
)
except Exception as err:
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)

View File

@@ -0,0 +1,3 @@
from .routes import models_ns
__all__ = ["models_ns"]

View File

@@ -0,0 +1,25 @@
from flask import current_app, jsonify, make_response
from flask_restx import Namespace, Resource
from application.core.model_settings import ModelRegistry
models_ns = Namespace("models", description="Available models", path="/api")
@models_ns.route("/models")
class ModelsListResource(Resource):
def get(self):
"""Get list of available models with their capabilities."""
try:
registry = ModelRegistry.get_instance()
models = registry.get_enabled_models()
response = {
"models": [model.to_dict() for model in models],
"default_model_id": registry.default_model_id,
"count": len(models),
}
except Exception as err:
current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
return make_response(jsonify(response), 200)

View File

@@ -0,0 +1,5 @@
"""Prompts module."""
from .routes import prompts_ns
__all__ = ["prompts_ns"]

View File

@@ -0,0 +1,191 @@
"""Prompt management routes."""
import os
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import current_dir, prompts_collection
from application.utils import check_required_fields
prompts_ns = Namespace(
"prompts", description="Prompt management operations", path="/api"
)
@prompts_ns.route("/create_prompt")
class CreatePrompt(Resource):
create_prompt_model = api.model(
"CreatePromptModel",
{
"content": fields.String(
required=True, description="Content of the prompt"
),
"name": fields.String(required=True, description="Name of the prompt"),
},
)
@api.expect(create_prompt_model)
@api.doc(description="Create a new prompt")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
data = request.get_json()
required_fields = ["content", "name"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
user = decoded_token.get("sub")
try:
resp = prompts_collection.insert_one(
{
"name": data["name"],
"content": data["content"],
"user": user,
}
)
new_id = str(resp.inserted_id)
except Exception as err:
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"id": new_id}), 200)
@prompts_ns.route("/get_prompts")
class GetPrompts(Resource):
@api.doc(description="Get all prompts for the user")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
prompts = prompts_collection.find({"user": user})
list_prompts = [
{"id": "default", "name": "default", "type": "public"},
{"id": "creative", "name": "creative", "type": "public"},
{"id": "strict", "name": "strict", "type": "public"},
]
for prompt in prompts:
list_prompts.append(
{
"id": str(prompt["_id"]),
"name": prompt["name"],
"type": "private",
}
)
except Exception as err:
current_app.logger.error(f"Error retrieving prompts: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(list_prompts), 200)
@prompts_ns.route("/get_single_prompt")
class GetSinglePrompt(Resource):
@api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
prompt_id = request.args.get("id")
if not prompt_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
if prompt_id == "default":
with open(
os.path.join(current_dir, "prompts", "chat_combine_default.txt"),
"r",
) as f:
chat_combine_template = f.read()
return make_response(jsonify({"content": chat_combine_template}), 200)
elif prompt_id == "creative":
with open(
os.path.join(current_dir, "prompts", "chat_combine_creative.txt"),
"r",
) as f:
chat_reduce_creative = f.read()
return make_response(jsonify({"content": chat_reduce_creative}), 200)
elif prompt_id == "strict":
with open(
os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r"
) as f:
chat_reduce_strict = f.read()
return make_response(jsonify({"content": chat_reduce_strict}), 200)
prompt = prompts_collection.find_one(
{"_id": ObjectId(prompt_id), "user": user}
)
except Exception as err:
current_app.logger.error(f"Error retrieving prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"content": prompt["content"]}), 200)
@prompts_ns.route("/delete_prompt")
class DeletePrompt(Resource):
delete_prompt_model = api.model(
"DeletePromptModel",
{"id": fields.String(required=True, description="Prompt ID to delete")},
)
@api.expect(delete_prompt_model)
@api.doc(description="Delete a prompt by ID")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
except Exception as err:
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@prompts_ns.route("/update_prompt")
class UpdatePrompt(Resource):
update_prompt_model = api.model(
"UpdatePromptModel",
{
"id": fields.String(required=True, description="Prompt ID to update"),
"name": fields.String(required=True, description="New name of the prompt"),
"content": fields.String(
required=True, description="New content of the prompt"
),
},
)
@api.expect(update_prompt_model)
@api.doc(description="Update an existing prompt")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "name", "content"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
prompts_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"name": data["name"], "content": data["content"]}},
)
except Exception as err:
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,5 @@
"""Sharing module."""
from .routes import sharing_ns
__all__ = ["sharing_ns"]

View File

@@ -0,0 +1,289 @@
"""Conversation sharing routes."""
import uuid
from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, inputs, Namespace, Resource
from application.api import api
from application.api.user.base import (
agents_collection,
attachments_collection,
conversations_collection,
shared_conversations_collections,
)
from application.utils import check_required_fields
sharing_ns = Namespace(
"sharing", description="Conversation sharing operations", path="/api"
)
@sharing_ns.route("/share")
class ShareConversation(Resource):
share_conversation_model = api.model(
"ShareConversationModel",
{
"conversation_id": fields.String(
required=True, description="Conversation ID"
),
"user": fields.String(description="User ID (optional)"),
"prompt_id": fields.String(description="Prompt ID (optional)"),
"chunks": fields.Integer(description="Chunks count (optional)"),
},
)
@api.expect(share_conversation_model)
@api.doc(description="Share a conversation")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["conversation_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
is_promptable = request.args.get("isPromptable", type=inputs.boolean)
if is_promptable is None:
return make_response(
jsonify({"success": False, "message": "isPromptable is required"}), 400
)
conversation_id = data["conversation_id"]
try:
conversation = conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}
)
if conversation is None:
return make_response(
jsonify(
{
"status": "error",
"message": "Conversation does not exist",
}
),
404,
)
current_n_queries = len(conversation["queries"])
explicit_binary = Binary.from_uuid(
uuid.uuid4(), UuidRepresentation.STANDARD
)
if is_promptable:
prompt_id = data.get("prompt_id", "default")
chunks = data.get("chunks", "2")
name = conversation["name"] + "(shared)"
new_api_key_data = {
"prompt_id": prompt_id,
"chunks": chunks,
"user": user,
}
if "source" in data and ObjectId.is_valid(data["source"]):
new_api_key_data["source"] = DBRef(
"sources", ObjectId(data["source"])
)
if "retriever" in data:
new_api_key_data["retriever"] = data["retriever"]
pre_existing_api_document = agents_collection.find_one(new_api_key_data)
if pre_existing_api_document:
api_uuid = pre_existing_api_document["key"]
pre_existing = shared_conversations_collections.find_one(
{
"conversation_id": ObjectId(conversation_id),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
"api_key": api_uuid,
}
)
if pre_existing is not None:
return make_response(
jsonify(
{
"success": True,
"identifier": str(pre_existing["uuid"].as_uuid()),
}
),
200,
)
else:
shared_conversations_collections.insert_one(
{
"uuid": explicit_binary,
"conversation_id": ObjectId(conversation_id),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
"api_key": api_uuid,
}
)
return make_response(
jsonify(
{
"success": True,
"identifier": str(explicit_binary.as_uuid()),
}
),
201,
)
else:
api_uuid = str(uuid.uuid4())
new_api_key_data["key"] = api_uuid
new_api_key_data["name"] = name
if "source" in data and ObjectId.is_valid(data["source"]):
new_api_key_data["source"] = DBRef(
"sources", ObjectId(data["source"])
)
if "retriever" in data:
new_api_key_data["retriever"] = data["retriever"]
agents_collection.insert_one(new_api_key_data)
shared_conversations_collections.insert_one(
{
"uuid": explicit_binary,
"conversation_id": ObjectId(conversation_id),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
"api_key": api_uuid,
}
)
return make_response(
jsonify(
{
"success": True,
"identifier": str(explicit_binary.as_uuid()),
}
),
201,
)
pre_existing = shared_conversations_collections.find_one(
{
"conversation_id": ObjectId(conversation_id),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
}
)
if pre_existing is not None:
return make_response(
jsonify(
{
"success": True,
"identifier": str(pre_existing["uuid"].as_uuid()),
}
),
200,
)
else:
shared_conversations_collections.insert_one(
{
"uuid": explicit_binary,
"conversation_id": ObjectId(conversation_id),
"isPromptable": is_promptable,
"first_n_queries": current_n_queries,
"user": user,
}
)
return make_response(
jsonify(
{"success": True, "identifier": str(explicit_binary.as_uuid())}
),
201,
)
except Exception as err:
current_app.logger.error(
f"Error sharing conversation: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
@sharing_ns.route("/shared_conversation/<string:identifier>")
class GetPubliclySharedConversations(Resource):
@api.doc(description="Get publicly shared conversations by identifier")
def get(self, identifier: str):
try:
query_uuid = Binary.from_uuid(
uuid.UUID(identifier), UuidRepresentation.STANDARD
)
shared = shared_conversations_collections.find_one({"uuid": query_uuid})
conversation_queries = []
if (
shared
and "conversation_id" in shared
):
# conversation_id is now stored as an ObjectId, not a DBRef
conversation_id = shared["conversation_id"]
conversation = conversations_collection.find_one(
{"_id": conversation_id}
)
if conversation is None:
return make_response(
jsonify(
{
"success": False,
"error": "might have broken url or the conversation does not exist",
}
),
404,
)
conversation_queries = conversation["queries"][
: (shared["first_n_queries"])
]
for query in conversation_queries:
if "attachments" in query and query["attachments"]:
attachment_details = []
for attachment_id in query["attachments"]:
try:
attachment = attachments_collection.find_one(
{"_id": ObjectId(attachment_id)}
)
if attachment:
attachment_details.append(
{
"id": str(attachment["_id"]),
"fileName": attachment.get(
"filename", "Unknown file"
),
}
)
except Exception as e:
current_app.logger.error(
f"Error retrieving attachment {attachment_id}: {e}",
exc_info=True,
)
query["attachments"] = attachment_details
else:
return make_response(
jsonify(
{
"success": False,
"error": "might have broken url or the conversation does not exist",
}
),
404,
)
date = conversation["_id"].generation_time.isoformat()
res = {
"success": True,
"queries": conversation_queries,
"title": conversation["name"],
"timestamp": date,
}
if shared["isPromptable"] and "api_key" in shared:
res["api_key"] = shared["api_key"]
return make_response(jsonify(res), 200)
except Exception as err:
current_app.logger.error(
f"Error getting shared conversation: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)

View File

@@ -0,0 +1,7 @@
"""Sources module."""
from .chunks import sources_chunks_ns
from .routes import sources_ns
from .upload import sources_upload_ns
__all__ = ["sources_ns", "sources_chunks_ns", "sources_upload_ns"]

View File

@@ -0,0 +1,278 @@
"""Source document management chunk management."""
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import get_vector_store, sources_collection
from application.utils import check_required_fields, num_tokens_from_string
sources_chunks_ns = Namespace(
"sources", description="Source document management operations", path="/api"
)
@sources_chunks_ns.route("/get_chunks")
class GetChunks(Resource):
@api.doc(
description="Retrieves chunks from a document, optionally filtered by file path and search term",
params={
"id": "The document ID",
"page": "Page number for pagination",
"per_page": "Number of chunks per page",
"path": "Optional: Filter chunks by relative file path",
"search": "Optional: Search term to filter chunks by title or content",
},
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
doc_id = request.args.get("id")
page = int(request.args.get("page", 1))
per_page = int(request.args.get("per_page", 10))
path = request.args.get("path")
search_term = request.args.get("search", "").strip().lower()
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
try:
store = get_vector_store(doc_id)
chunks = store.get_chunks()
filtered_chunks = []
for chunk in chunks:
metadata = chunk.get("metadata", {})
# Filter by path if provided
if path:
chunk_source = metadata.get("source", "")
# Check if the chunk's source matches the requested path
if not chunk_source or not chunk_source.endswith(path):
continue
# Filter by search term if provided
if search_term:
text_match = search_term in chunk.get("text", "").lower()
title_match = search_term in metadata.get("title", "").lower()
if not (text_match or title_match):
continue
filtered_chunks.append(chunk)
chunks = filtered_chunks
total_chunks = len(chunks)
start = (page - 1) * per_page
end = start + per_page
paginated_chunks = chunks[start:end]
return make_response(
jsonify(
{
"page": page,
"per_page": per_page,
"total": total_chunks,
"chunks": paginated_chunks,
"path": path if path else None,
"search": search_term if search_term else None,
}
),
200,
)
except Exception as e:
current_app.logger.error(f"Error getting chunks: {e}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
@sources_chunks_ns.route("/add_chunk")
class AddChunk(Resource):
@api.expect(
api.model(
"AddChunkModel",
{
"id": fields.String(required=True, description="Document ID"),
"text": fields.String(required=True, description="Text of the chunk"),
"metadata": fields.Raw(
required=False,
description="Metadata associated with the chunk",
),
},
)
)
@api.doc(
description="Adds a new chunk to the document",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "text"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
doc_id = data.get("id")
text = data.get("text")
metadata = data.get("metadata", {})
token_count = num_tokens_from_string(text)
metadata["token_count"] = token_count
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
try:
store = get_vector_store(doc_id)
chunk_id = store.add_chunk(text, metadata)
return make_response(
jsonify({"message": "Chunk added successfully", "chunk_id": chunk_id}),
201,
)
except Exception as e:
current_app.logger.error(f"Error adding chunk: {e}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
@sources_chunks_ns.route("/delete_chunk")
class DeleteChunk(Resource):
@api.doc(
description="Deletes a specific chunk from the document.",
params={"id": "The document ID", "chunk_id": "The ID of the chunk to delete"},
)
def delete(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
doc_id = request.args.get("id")
chunk_id = request.args.get("chunk_id")
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
try:
store = get_vector_store(doc_id)
deleted = store.delete_chunk(chunk_id)
if deleted:
return make_response(
jsonify({"message": "Chunk deleted successfully"}), 200
)
else:
return make_response(
jsonify({"message": "Chunk not found or could not be deleted"}),
404,
)
except Exception as e:
current_app.logger.error(f"Error deleting chunk: {e}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
@sources_chunks_ns.route("/update_chunk")
class UpdateChunk(Resource):
@api.expect(
api.model(
"UpdateChunkModel",
{
"id": fields.String(required=True, description="Document ID"),
"chunk_id": fields.String(
required=True, description="Chunk ID to update"
),
"text": fields.String(
required=False, description="New text of the chunk"
),
"metadata": fields.Raw(
required=False,
description="Updated metadata associated with the chunk",
),
},
)
)
@api.doc(
description="Updates an existing chunk in the document.",
)
def put(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "chunk_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
doc_id = data.get("id")
chunk_id = data.get("chunk_id")
text = data.get("text")
metadata = data.get("metadata")
if text is not None:
token_count = num_tokens_from_string(text)
if metadata is None:
metadata = {}
metadata["token_count"] = token_count
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
try:
store = get_vector_store(doc_id)
chunks = store.get_chunks()
existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)
if not existing_chunk:
return make_response(jsonify({"error": "Chunk not found"}), 404)
new_text = text if text is not None else existing_chunk["text"]
if metadata is not None:
new_metadata = existing_chunk["metadata"].copy()
new_metadata.update(metadata)
else:
new_metadata = existing_chunk["metadata"].copy()
if text is not None:
new_metadata["token_count"] = num_tokens_from_string(new_text)
try:
new_chunk_id = store.add_chunk(new_text, new_metadata)
deleted = store.delete_chunk(chunk_id)
if not deleted:
current_app.logger.warning(
f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
)
return make_response(
jsonify(
{
"message": "Chunk updated successfully",
"chunk_id": new_chunk_id,
"original_chunk_id": chunk_id,
}
),
200,
)
except Exception as add_error:
current_app.logger.error(f"Failed to add updated chunk: {add_error}")
return make_response(
jsonify({"error": "Failed to update chunk - addition failed"}), 500
)
except Exception as e:
current_app.logger.error(f"Error updating chunk: {e}", exc_info=True)
return make_response(jsonify({"success": False}), 500)

View File

@@ -0,0 +1,323 @@
"""Source document management routes."""
import json
import math
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, redirect, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import sources_collection
from application.core.settings import settings
from application.storage.storage_creator import StorageCreator
from application.utils import check_required_fields
from application.vectorstore.vector_creator import VectorCreator
sources_ns = Namespace(
"sources", description="Source document management operations", path="/api"
)
@sources_ns.route("/sources")
class CombinedJson(Resource):
@api.doc(description="Provide JSON file with combined available indexes")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = [
{
"name": "Default",
"date": "default",
"model": settings.EMBEDDINGS_NAME,
"location": "remote",
"tokens": "",
"retriever": "classic",
}
]
try:
for index in sources_collection.find({"user": user}).sort("date", -1):
data.append(
{
"id": str(index["_id"]),
"name": index.get("name"),
"date": index.get("date"),
"model": settings.EMBEDDINGS_NAME,
"location": "local",
"tokens": index.get("tokens", ""),
"retriever": index.get("retriever", "classic"),
"syncFrequency": index.get("sync_frequency", ""),
"is_nested": bool(index.get("directory_structure")),
"type": index.get(
"type", "file"
), # Add type field with default "file"
}
)
except Exception as err:
current_app.logger.error(f"Error retrieving sources: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(data), 200)
@sources_ns.route("/sources/paginated")
class PaginatedSources(Resource):
@api.doc(description="Get document with pagination, sorting and filtering")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
sort_field = request.args.get("sort", "date") # Default to 'date'
sort_order = request.args.get("order", "desc") # Default to 'desc'
page = int(request.args.get("page", 1)) # Default to 1
rows_per_page = int(request.args.get("rows", 10)) # Default to 10
# add .strip() to remove leading and trailing whitespaces
search_term = request.args.get(
"search", ""
).strip() # add search for filter documents
# Prepare query for filtering
query = {"user": user}
if search_term:
query["name"] = {
"$regex": search_term,
"$options": "i", # using case-insensitive search
}
total_documents = sources_collection.count_documents(query)
total_pages = max(1, math.ceil(total_documents / rows_per_page))
page = min(
max(1, page), total_pages
) # add this to make sure page inbound is within the range
sort_order = 1 if sort_order == "asc" else -1
skip = (page - 1) * rows_per_page
try:
documents = (
sources_collection.find(query)
.sort(sort_field, sort_order)
.skip(skip)
.limit(rows_per_page)
)
paginated_docs = []
for doc in documents:
doc_data = {
"id": str(doc["_id"]),
"name": doc.get("name", ""),
"date": doc.get("date", ""),
"model": settings.EMBEDDINGS_NAME,
"location": "local",
"tokens": doc.get("tokens", ""),
"retriever": doc.get("retriever", "classic"),
"syncFrequency": doc.get("sync_frequency", ""),
"isNested": bool(doc.get("directory_structure")),
"type": doc.get("type", "file"),
}
paginated_docs.append(doc_data)
response = {
"total": total_documents,
"totalPages": total_pages,
"currentPage": page,
"paginated": paginated_docs,
}
return make_response(jsonify(response), 200)
except Exception as err:
current_app.logger.error(
f"Error retrieving paginated sources: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
@sources_ns.route("/delete_by_ids")
class DeleteByIds(Resource):
@api.doc(
description="Deletes documents from the vector store by IDs",
params={"path": "Comma-separated list of IDs"},
)
def get(self):
ids = request.args.get("path")
if not ids:
return make_response(
jsonify({"success": False, "message": "Missing required fields"}), 400
)
try:
result = sources_collection.delete_index(ids=ids)
if result:
return make_response(jsonify({"success": True}), 200)
except Exception as err:
current_app.logger.error(f"Error deleting indexes: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": False}), 400)
@sources_ns.route("/delete_old")
class DeleteOldIndexes(Resource):
@api.doc(
description="Deletes old indexes and associated files",
params={"source_id": "The source ID to delete"},
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
source_id = request.args.get("source_id")
if not source_id:
return make_response(
jsonify({"success": False, "message": "Missing required fields"}), 400
)
doc = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
)
if not doc:
return make_response(jsonify({"status": "not found"}), 404)
storage = StorageCreator.get_storage()
try:
# Delete vector index
if settings.VECTOR_STORE == "faiss":
index_path = f"indexes/{str(doc['_id'])}"
if storage.file_exists(f"{index_path}/index.faiss"):
storage.delete_file(f"{index_path}/index.faiss")
if storage.file_exists(f"{index_path}/index.pkl"):
storage.delete_file(f"{index_path}/index.pkl")
else:
vectorstore = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, source_id=str(doc["_id"])
)
vectorstore.delete_index()
if "file_path" in doc and doc["file_path"]:
file_path = doc["file_path"]
if storage.is_directory(file_path):
files = storage.list_files(file_path)
for f in files:
storage.delete_file(f)
else:
storage.delete_file(file_path)
except FileNotFoundError:
pass
except Exception as err:
current_app.logger.error(
f"Error deleting files and indexes: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
sources_collection.delete_one({"_id": ObjectId(source_id)})
return make_response(jsonify({"success": True}), 200)
@sources_ns.route("/combine")
class RedirectToSources(Resource):
@api.doc(
description="Redirects /api/combine to /api/sources for backward compatibility"
)
def get(self):
return redirect("/api/sources", code=301)
@sources_ns.route("/manage_sync")
class ManageSync(Resource):
manage_sync_model = api.model(
"ManageSyncModel",
{
"source_id": fields.String(required=True, description="Source ID"),
"sync_frequency": fields.String(
required=True,
description="Sync frequency (never, daily, weekly, monthly)",
),
},
)
@api.expect(manage_sync_model)
@api.doc(description="Manage sync frequency for sources")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["source_id", "sync_frequency"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
source_id = data["source_id"]
sync_frequency = data["sync_frequency"]
if sync_frequency not in ["never", "daily", "weekly", "monthly"]:
return make_response(
jsonify({"success": False, "message": "Invalid frequency"}), 400
)
update_data = {"$set": {"sync_frequency": sync_frequency}}
try:
sources_collection.update_one(
{
"_id": ObjectId(source_id),
"user": user,
},
update_data,
)
except Exception as err:
current_app.logger.error(
f"Error updating sync frequency: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@sources_ns.route("/directory_structure")
class DirectoryStructure(Resource):
@api.doc(
description="Get the directory structure for a document",
params={"id": "The document ID"},
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
doc_id = request.args.get("id")
if not doc_id:
return make_response(jsonify({"error": "Document ID is required"}), 400)
if not ObjectId.is_valid(doc_id):
return make_response(jsonify({"error": "Invalid document ID"}), 400)
try:
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
if not doc:
return make_response(
jsonify({"error": "Document not found or access denied"}), 404
)
directory_structure = doc.get("directory_structure", {})
base_path = doc.get("file_path", "")
provider = None
remote_data = doc.get("remote_data")
try:
if isinstance(remote_data, str) and remote_data:
remote_data_obj = json.loads(remote_data)
provider = remote_data_obj.get("provider")
except Exception as e:
current_app.logger.warning(
f"Failed to parse remote_data for doc {doc_id}: {e}"
)
return make_response(
jsonify(
{
"success": True,
"directory_structure": directory_structure,
"base_path": base_path,
"provider": provider,
}
),
200,
)
except Exception as e:
current_app.logger.error(
f"Error retrieving directory structure: {e}", exc_info=True
)
return make_response(jsonify({"success": False, "error": str(e)}), 500)

View File

@@ -0,0 +1,583 @@
"""Source document management upload functionality."""
import json
import os
import tempfile
import zipfile
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import sources_collection
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
from application.core.settings import settings
from application.parser.connectors.connector_creator import ConnectorCreator
from application.storage.storage_creator import StorageCreator
from application.utils import check_required_fields, safe_filename
sources_upload_ns = Namespace(
"sources", description="Source document management operations", path="/api"
)
@sources_upload_ns.route("/upload")
class UploadFile(Resource):
@api.expect(
api.model(
"UploadModel",
{
"user": fields.String(required=True, description="User ID"),
"name": fields.String(required=True, description="Job name"),
"file": fields.Raw(required=True, description="File(s) to upload"),
},
)
)
@api.doc(
description="Uploads a file to be vectorized and indexed",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
data = request.form
files = request.files.getlist("file")
required_fields = ["user", "name"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields or not files or all(file.filename == "" for file in files):
return make_response(
jsonify(
{
"status": "error",
"message": "Missing required fields or files",
}
),
400,
)
user = decoded_token.get("sub")
job_name = request.form["name"]
# Create safe versions for filesystem operations
safe_user = safe_filename(user)
dir_name = safe_filename(job_name)
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
try:
storage = StorageCreator.get_storage()
for file in files:
original_filename = file.filename
safe_file = safe_filename(original_filename)
with tempfile.TemporaryDirectory() as temp_dir:
temp_file_path = os.path.join(temp_dir, safe_file)
file.save(temp_file_path)
if zipfile.is_zipfile(temp_file_path):
try:
with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
zip_ref.extractall(path=temp_dir)
# Walk through extracted files and upload them
for root, _, files in os.walk(temp_dir):
for extracted_file in files:
if (
os.path.join(root, extracted_file)
== temp_file_path
):
continue
rel_path = os.path.relpath(
os.path.join(root, extracted_file), temp_dir
)
storage_path = f"{base_path}/{rel_path}"
with open(
os.path.join(root, extracted_file), "rb"
) as f:
storage.save_file(f, storage_path)
except Exception as e:
current_app.logger.error(
f"Error extracting zip: {e}", exc_info=True
)
# If zip extraction fails, save the original zip file
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
else:
# For non-zip files, save directly
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
task = ingest.delay(
settings.UPLOAD_FOLDER,
[
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
job_name,
user,
file_path=base_path,
filename=dir_name,
)
except Exception as err:
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_upload_ns.route("/remote")
class UploadRemote(Resource):
@api.expect(
api.model(
"RemoteUploadModel",
{
"user": fields.String(required=True, description="User ID"),
"source": fields.String(
required=True, description="Source of the data"
),
"name": fields.String(required=True, description="Job name"),
"data": fields.String(required=True, description="Data to process"),
"repo_url": fields.String(description="GitHub repository URL"),
},
)
)
@api.doc(
description="Uploads remote source for vectorization",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
data = request.form
required_fields = ["user", "source", "name", "data"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
config = json.loads(data["data"])
source_data = None
if data["source"] == "github":
source_data = config.get("repo_url")
elif data["source"] in ["crawler", "url"]:
source_data = config.get("url")
elif data["source"] == "reddit":
source_data = config
elif data["source"] in ConnectorCreator.get_supported_connectors():
session_token = config.get("session_token")
if not session_token:
return make_response(
jsonify(
{
"success": False,
"error": f"Missing session_token in {data['source']} configuration",
}
),
400,
)
# Process file_ids
file_ids = config.get("file_ids", [])
if isinstance(file_ids, str):
file_ids = [id.strip() for id in file_ids.split(",") if id.strip()]
elif not isinstance(file_ids, list):
file_ids = []
# Process folder_ids
folder_ids = config.get("folder_ids", [])
if isinstance(folder_ids, str):
folder_ids = [
id.strip() for id in folder_ids.split(",") if id.strip()
]
elif not isinstance(folder_ids, list):
folder_ids = []
config["file_ids"] = file_ids
config["folder_ids"] = folder_ids
task = ingest_connector_task.delay(
job_name=data["name"],
user=decoded_token.get("sub"),
source_type=data["source"],
session_token=session_token,
file_ids=file_ids,
folder_ids=folder_ids,
recursive=config.get("recursive", False),
retriever=config.get("retriever", "classic"),
)
return make_response(
jsonify({"success": True, "task_id": task.id}), 200
)
task = ingest_remote.delay(
source_data=source_data,
job_name=data["name"],
user=decoded_token.get("sub"),
loader=data["source"],
)
except Exception as err:
current_app.logger.error(
f"Error uploading remote source: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_upload_ns.route("/manage_source_files")
class ManageSourceFiles(Resource):
@api.expect(
api.model(
"ManageSourceFilesModel",
{
"source_id": fields.String(
required=True, description="Source ID to modify"
),
"operation": fields.String(
required=True,
description="Operation: 'add', 'remove', or 'remove_directory'",
),
"file_paths": fields.List(
fields.String,
required=False,
description="File paths to remove (for remove operation)",
),
"directory_path": fields.String(
required=False,
description="Directory path to remove (for remove_directory operation)",
),
"file": fields.Raw(
required=False, description="Files to add (for add operation)"
),
"parent_dir": fields.String(
required=False,
description="Parent directory path relative to source root",
),
},
)
)
@api.doc(
description="Add files, remove files, or remove directories from an existing source",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(
jsonify({"success": False, "message": "Unauthorized"}), 401
)
user = decoded_token.get("sub")
source_id = request.form.get("source_id")
operation = request.form.get("operation")
if not source_id or not operation:
return make_response(
jsonify(
{
"success": False,
"message": "source_id and operation are required",
}
),
400,
)
if operation not in ["add", "remove", "remove_directory"]:
return make_response(
jsonify(
{
"success": False,
"message": "operation must be 'add', 'remove', or 'remove_directory'",
}
),
400,
)
try:
ObjectId(source_id)
except Exception:
return make_response(
jsonify({"success": False, "message": "Invalid source ID format"}), 400
)
try:
source = sources_collection.find_one(
{"_id": ObjectId(source_id), "user": user}
)
if not source:
return make_response(
jsonify(
{
"success": False,
"message": "Source not found or access denied",
}
),
404,
)
except Exception as err:
current_app.logger.error(f"Error finding source: {err}", exc_info=True)
return make_response(
jsonify({"success": False, "message": "Database error"}), 500
)
try:
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
parent_dir = request.form.get("parent_dir", "")
if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
return make_response(
jsonify(
{"success": False, "message": "Invalid parent directory path"}
),
400,
)
if operation == "add":
files = request.files.getlist("file")
if not files or all(file.filename == "" for file in files):
return make_response(
jsonify(
{
"success": False,
"message": "No files provided for add operation",
}
),
400,
)
added_files = []
target_dir = source_file_path
if parent_dir:
target_dir = f"{source_file_path}/{parent_dir}"
for file in files:
if file.filename:
safe_filename_str = safe_filename(file.filename)
file_path = f"{target_dir}/{safe_filename_str}"
# Save file to storage
storage.save_file(file, file_path)
added_files.append(safe_filename_str)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=source_id, user=user)
return make_response(
jsonify(
{
"success": True,
"message": f"Added {len(added_files)} files",
"added_files": added_files,
"parent_dir": parent_dir,
"reingest_task_id": task.id,
}
),
200,
)
elif operation == "remove":
file_paths_str = request.form.get("file_paths")
if not file_paths_str:
return make_response(
jsonify(
{
"success": False,
"message": "file_paths required for remove operation",
}
),
400,
)
try:
file_paths = (
json.loads(file_paths_str)
if isinstance(file_paths_str, str)
else file_paths_str
)
except Exception:
return make_response(
jsonify(
{"success": False, "message": "Invalid file_paths format"}
),
400,
)
# Remove files from storage and directory structure
removed_files = []
for file_path in file_paths:
full_path = f"{source_file_path}/{file_path}"
# Remove from storage
if storage.file_exists(full_path):
storage.delete_file(full_path)
removed_files.append(file_path)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=source_id, user=user)
return make_response(
jsonify(
{
"success": True,
"message": f"Removed {len(removed_files)} files",
"removed_files": removed_files,
"reingest_task_id": task.id,
}
),
200,
)
elif operation == "remove_directory":
directory_path = request.form.get("directory_path")
if not directory_path:
return make_response(
jsonify(
{
"success": False,
"message": "directory_path required for remove_directory operation",
}
),
400,
)
# Validate directory path (prevent path traversal)
if directory_path.startswith("/") or ".." in directory_path:
current_app.logger.warning(
f"Invalid directory path attempted for removal. "
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}"
)
return make_response(
jsonify(
{"success": False, "message": "Invalid directory path"}
),
400,
)
full_directory_path = (
f"{source_file_path}/{directory_path}"
if directory_path
else source_file_path
)
if not storage.is_directory(full_directory_path):
current_app.logger.warning(
f"Directory not found or is not a directory for removal. "
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
f"Full path: {full_directory_path}"
)
return make_response(
jsonify(
{
"success": False,
"message": "Directory not found or is not a directory",
}
),
404,
)
success = storage.remove_directory(full_directory_path)
if not success:
current_app.logger.error(
f"Failed to remove directory from storage. "
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
f"Full path: {full_directory_path}"
)
return make_response(
jsonify(
{"success": False, "message": "Failed to remove directory"}
),
500,
)
current_app.logger.info(
f"Successfully removed directory. "
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
f"Full path: {full_directory_path}"
)
# Trigger re-ingestion pipeline
from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=source_id, user=user)
return make_response(
jsonify(
{
"success": True,
"message": f"Successfully removed directory: {directory_path}",
"removed_directory": directory_path,
"reingest_task_id": task.id,
}
),
200,
)
except Exception as err:
error_context = f"operation={operation}, user={user}, source_id={source_id}"
if operation == "remove_directory":
directory_path = request.form.get("directory_path", "")
error_context += f", directory_path={directory_path}"
elif operation == "remove":
file_paths_str = request.form.get("file_paths", "")
error_context += f", file_paths={file_paths_str}"
elif operation == "add":
parent_dir = request.form.get("parent_dir", "")
error_context += f", parent_dir={parent_dir}"
current_app.logger.error(
f"Error managing source files: {err} ({error_context})", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Operation failed"}), 500
)
@sources_upload_ns.route("/task_status")
class TaskStatus(Resource):
task_status_model = api.model(
"TaskStatusModel",
{"task_id": fields.String(required=True, description="Task ID")},
)
@api.expect(task_status_model)
@api.doc(description="Get celery job status")
def get(self):
task_id = request.args.get("task_id")
if not task_id:
return make_response(
jsonify({"success": False, "message": "Task ID is required"}), 400
)
try:
from application.celery_init import celery
task = celery.AsyncResult(task_id)
task_meta = task.info
print(f"Task status: {task.status}")
if task.status == "PENDING":
inspect = celery.control.inspect()
active_workers = inspect.ping()
if not active_workers:
raise ConnectionError("Service unavailable")
if not isinstance(
task_meta, (dict, list, str, int, float, bool, type(None))
):
task_meta = str(task_meta) # Convert to a string representation
except ConnectionError as err:
return make_response(
jsonify({"success": False, "message": str(err)}), 503
)
except Exception as err:
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"status": task.status, "result": task_meta}), 200)

View File

@@ -5,14 +5,16 @@ from application.worker import (
agent_webhook_worker,
attachment_worker,
ingest_worker,
mcp_oauth,
mcp_oauth_status,
remote_worker,
sync_worker,
)
@celery.task(bind=True)
def ingest(self, directory, formats, job_name, filename, user, dir_name, user_dir):
resp = ingest_worker(self, directory, formats, job_name, filename, user, dir_name, user_dir)
def ingest(self, directory, formats, job_name, user, file_path, filename):
resp = ingest_worker(self, directory, formats, job_name, file_path, filename, user)
return resp
@@ -22,6 +24,14 @@ def ingest_remote(self, source_data, job_name, user, loader):
return resp
@celery.task(bind=True)
def reingest_source_task(self, source_id, user):
from application.worker import reingest_source_worker
resp = reingest_source_worker(self, source_id, user)
return resp
@celery.task(bind=True)
def schedule_syncs(self, frequency):
resp = sync_worker(self, frequency)
@@ -40,6 +50,40 @@ def process_agent_webhook(self, agent_id, payload):
return resp
@celery.task(bind=True)
def ingest_connector_task(
self,
job_name,
user,
source_type,
session_token=None,
file_ids=None,
folder_ids=None,
recursive=True,
retriever="classic",
operation_mode="upload",
doc_id=None,
sync_frequency="never",
):
from application.worker import ingest_connector
resp = ingest_connector(
self,
job_name,
user,
source_type,
session_token=session_token,
file_ids=file_ids,
folder_ids=folder_ids,
recursive=recursive,
retriever=retriever,
operation_mode=operation_mode,
doc_id=doc_id,
sync_frequency=sync_frequency,
)
return resp
@celery.on_after_configure.connect
def setup_periodic_tasks(sender, **kwargs):
sender.add_periodic_task(
@@ -54,3 +98,15 @@ def setup_periodic_tasks(sender, **kwargs):
timedelta(days=30),
schedule_syncs.s("monthly"),
)
@celery.task(bind=True)
def mcp_oauth_task(self, config, user):
resp = mcp_oauth(self, config, user)
return resp
@celery.task(bind=True)
def mcp_oauth_status_task(self, task_id):
resp = mcp_oauth_status(self, task_id)
return resp

View File

@@ -0,0 +1,6 @@
"""Tools module."""
from .mcp import tools_mcp_ns
from .routes import tools_ns
__all__ = ["tools_ns", "tools_mcp_ns"]

View File

@@ -0,0 +1,333 @@
"""Tool management MCP server integration."""
import json
from email.quoprimime import unquote
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, redirect, request
from flask_restx import fields, Namespace, Resource
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
from application.api import api
from application.api.user.base import user_tools_collection
from application.cache import get_redis_instance
from application.security.encryption import encrypt_credentials
from application.utils import check_required_fields
tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")
@tools_mcp_ns.route("/mcp_server/test")
class TestMCPServerConfig(Resource):
@api.expect(
api.model(
"MCPServerTestModel",
{
"config": fields.Raw(
required=True, description="MCP server configuration to test"
),
},
)
)
@api.doc(description="Test MCP server connection with provided configuration")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["config"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
config = data["config"]
auth_credentials = {}
auth_type = config.get("auth_type", "none")
if auth_type == "api_key" and "api_key" in config:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer" and "bearer_token" in config:
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if "username" in config:
auth_credentials["username"] = config["username"]
if "password" in config:
auth_credentials["password"] = config["password"]
test_config = config.copy()
test_config["auth_credentials"] = auth_credentials
mcp_tool = MCPTool(config=test_config, user_id=user)
result = mcp_tool.test_connection()
return make_response(jsonify(result), 200)
except Exception as e:
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
return make_response(
jsonify(
{"success": False, "error": f"Connection test failed: {str(e)}"}
),
500,
)
@tools_mcp_ns.route("/mcp_server/save")
class MCPServerSave(Resource):
@api.expect(
api.model(
"MCPServerSaveModel",
{
"id": fields.String(
required=False, description="Tool ID for updates (optional)"
),
"displayName": fields.String(
required=True, description="Display name for the MCP server"
),
"config": fields.Raw(
required=True, description="MCP server configuration"
),
"status": fields.Boolean(
required=False, default=True, description="Tool status"
),
},
)
)
@api.doc(description="Create or update MCP server with automatic tool discovery")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["displayName", "config"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
config = data["config"]
auth_credentials = {}
auth_type = config.get("auth_type", "none")
if auth_type == "api_key":
if "api_key" in config and config["api_key"]:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer":
if "bearer_token" in config and config["bearer_token"]:
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if "username" in config and config["username"]:
auth_credentials["username"] = config["username"]
if "password" in config and config["password"]:
auth_credentials["password"] = config["password"]
mcp_config = config.copy()
mcp_config["auth_credentials"] = auth_credentials
if auth_type == "oauth":
if not config.get("oauth_task_id"):
return make_response(
jsonify(
{
"success": False,
"error": "Connection not authorized. Please complete the OAuth authorization first.",
}
),
400,
)
redis_client = get_redis_instance()
manager = MCPOAuthManager(redis_client)
result = manager.get_oauth_status(config["oauth_task_id"])
if not result.get("status") == "completed":
return make_response(
jsonify(
{
"success": False,
"error": "OAuth failed or not completed. Please try authorizing again.",
}
),
400,
)
actions_metadata = result.get("tools", [])
elif auth_type == "none" or auth_credentials:
mcp_tool = MCPTool(config=mcp_config, user_id=user)
mcp_tool.discover_tools()
actions_metadata = mcp_tool.get_actions_metadata()
else:
raise Exception(
"No valid credentials provided for the selected authentication type"
)
storage_config = config.copy()
if auth_credentials:
encrypted_credentials_string = encrypt_credentials(
auth_credentials, user
)
storage_config["encrypted_credentials"] = encrypted_credentials_string
for field in [
"api_key",
"bearer_token",
"username",
"password",
"api_key_header",
]:
storage_config.pop(field, None)
transformed_actions = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
tool_data = {
"name": "mcp_tool",
"displayName": data["displayName"],
"customName": data["displayName"],
"description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
"config": storage_config,
"actions": transformed_actions,
"status": data.get("status", True),
"user": user,
}
tool_id = data.get("id")
if tool_id:
result = user_tools_collection.update_one(
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
{"$set": {k: v for k, v in tool_data.items() if k != "user"}},
)
if result.matched_count == 0:
return make_response(
jsonify(
{
"success": False,
"error": "Tool not found or access denied",
}
),
404,
)
response_data = {
"success": True,
"id": tool_id,
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
else:
result = user_tools_collection.insert_one(tool_data)
tool_id = str(result.inserted_id)
response_data = {
"success": True,
"id": tool_id,
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
"tools_count": len(transformed_actions),
}
return make_response(jsonify(response_data), 200)
except Exception as e:
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
return make_response(
jsonify(
{"success": False, "error": f"Failed to save MCP server: {str(e)}"}
),
500,
)
@tools_mcp_ns.route("/mcp_server/callback")
class MCPOAuthCallback(Resource):
@api.expect(
api.model(
"MCPServerCallbackModel",
{
"code": fields.String(required=True, description="Authorization code"),
"state": fields.String(required=True, description="State parameter"),
"error": fields.String(
required=False, description="Error message (if any)"
),
},
)
)
@api.doc(
description="Handle OAuth callback by providing the authorization code and state"
)
def get(self):
code = request.args.get("code")
state = request.args.get("state")
error = request.args.get("error")
if error:
return redirect(
f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
)
if not code or not state:
return redirect(
"/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
)
try:
redis_client = get_redis_instance()
if not redis_client:
return redirect(
"/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
)
code = unquote(code)
manager = MCPOAuthManager(redis_client)
success = manager.handle_oauth_callback(state, code, error)
if success:
return redirect(
"/api/connectors/callback-status?status=success&message=Authorization+code+received+successfully.+You+can+close+this+window.&provider=mcp_tool"
)
else:
return redirect(
"/api/connectors/callback-status?status=error&message=OAuth+callback+failed.&provider=mcp_tool"
)
except Exception as e:
current_app.logger.error(
f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
)
return redirect(
f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
)
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
def get(self, task_id):
"""
Get current status of OAuth flow.
Frontend should poll this endpoint periodically.
"""
try:
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
status_data = redis_client.get(status_key)
if status_data:
status = json.loads(status_data)
return make_response(
jsonify({"success": True, "task_id": task_id, **status})
)
else:
return make_response(
jsonify(
{
"success": False,
"error": "Task not found or expired",
"task_id": task_id,
}
),
404,
)
except Exception as e:
current_app.logger.error(
f"Error getting OAuth status for task {task_id}: {str(e)}"
)
return make_response(
jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
)

View File

@@ -0,0 +1,416 @@
"""Tool management routes."""
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.base import user_tools_collection
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.utils import check_required_fields, validate_function_name
tool_config = {}
tool_manager = ToolManager(config=tool_config)
tools_ns = Namespace("tools", description="Tool management operations", path="/api")
@tools_ns.route("/available_tools")
class AvailableTools(Resource):
@api.doc(description="Get available tools for a user")
def get(self):
try:
tools_metadata = []
for tool_name, tool_instance in tool_manager.tools.items():
doc = tool_instance.__doc__.strip()
lines = doc.split("\n", 1)
name = lines[0].strip()
description = lines[1].strip() if len(lines) > 1 else ""
tools_metadata.append(
{
"name": tool_name,
"displayName": name,
"description": description,
"configRequirements": tool_instance.get_config_requirements(),
}
)
except Exception as err:
current_app.logger.error(
f"Error getting available tools: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "data": tools_metadata}), 200)
@tools_ns.route("/get_tools")
class GetTools(Resource):
@api.doc(description="Get tools created by a user")
def get(self):
try:
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
tools = user_tools_collection.find({"user": user})
user_tools = []
for tool in tools:
tool_copy = {**tool}
tool_copy["id"] = str(tool["_id"])
tool_copy.pop("_id", None)
user_tools.append(tool_copy)
except Exception as err:
current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "tools": user_tools}), 200)
@tools_ns.route("/create_tool")
class CreateTool(Resource):
@api.expect(
api.model(
"CreateToolModel",
{
"name": fields.String(required=True, description="Name of the tool"),
"displayName": fields.String(
required=True, description="Display name for the tool"
),
"description": fields.String(
required=True, description="Tool description"
),
"config": fields.Raw(
required=True, description="Configuration of the tool"
),
"customName": fields.String(
required=False, description="Custom name for the tool"
),
"status": fields.Boolean(
required=True, description="Status of the tool"
),
},
)
)
@api.doc(description="Create a new tool")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = [
"name",
"displayName",
"description",
"config",
"status",
]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
tool_instance = tool_manager.tools.get(data["name"])
if not tool_instance:
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404
)
actions_metadata = tool_instance.get_actions_metadata()
transformed_actions = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
except Exception as err:
current_app.logger.error(
f"Error getting tool actions: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
try:
new_tool = {
"user": user,
"name": data["name"],
"displayName": data["displayName"],
"description": data["description"],
"customName": data.get("customName", ""),
"actions": transformed_actions,
"config": data["config"],
"status": data["status"],
}
resp = user_tools_collection.insert_one(new_tool)
new_id = str(resp.inserted_id)
except Exception as err:
current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"id": new_id}), 200)
@tools_ns.route("/update_tool")
class UpdateTool(Resource):
@api.expect(
api.model(
"UpdateToolModel",
{
"id": fields.String(required=True, description="Tool ID"),
"name": fields.String(description="Name of the tool"),
"displayName": fields.String(description="Display name for the tool"),
"customName": fields.String(description="Custom name for the tool"),
"description": fields.String(description="Tool description"),
"config": fields.Raw(description="Configuration of the tool"),
"actions": fields.List(
fields.Raw, description="Actions the tool can perform"
),
"status": fields.Boolean(description="Status of the tool"),
},
)
)
@api.doc(description="Update a tool by ID")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
update_data = {}
if "name" in data:
update_data["name"] = data["name"]
if "displayName" in data:
update_data["displayName"] = data["displayName"]
if "customName" in data:
update_data["customName"] = data["customName"]
if "description" in data:
update_data["description"] = data["description"]
if "actions" in data:
update_data["actions"] = data["actions"]
if "config" in data:
if "actions" in data["config"]:
for action_name in list(data["config"]["actions"].keys()):
if not validate_function_name(action_name):
return make_response(
jsonify(
{
"success": False,
"message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.",
"param": "tools[].function.name",
}
),
400,
)
tool_doc = user_tools_collection.find_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if tool_doc and tool_doc.get("name") == "mcp_tool":
config = data["config"]
existing_config = tool_doc.get("config", {})
storage_config = existing_config.copy()
storage_config.update(config)
existing_credentials = {}
if "encrypted_credentials" in existing_config:
existing_credentials = decrypt_credentials(
existing_config["encrypted_credentials"], user
)
auth_credentials = existing_credentials.copy()
auth_type = storage_config.get("auth_type", "none")
if auth_type == "api_key":
if "api_key" in config and config["api_key"]:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config[
"api_key_header"
]
elif auth_type == "bearer":
if "bearer_token" in config and config["bearer_token"]:
auth_credentials["bearer_token"] = config["bearer_token"]
elif "encrypted_token" in config and config["encrypted_token"]:
auth_credentials["bearer_token"] = config["encrypted_token"]
elif auth_type == "basic":
if "username" in config and config["username"]:
auth_credentials["username"] = config["username"]
if "password" in config and config["password"]:
auth_credentials["password"] = config["password"]
if auth_type != "none" and auth_credentials:
encrypted_credentials_string = encrypt_credentials(
auth_credentials, user
)
storage_config["encrypted_credentials"] = (
encrypted_credentials_string
)
elif auth_type == "none":
storage_config.pop("encrypted_credentials", None)
for field in [
"api_key",
"bearer_token",
"encrypted_token",
"username",
"password",
"api_key_header",
]:
storage_config.pop(field, None)
update_data["config"] = storage_config
else:
update_data["config"] = data["config"]
if "status" in data:
update_data["status"] = data["status"]
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": update_data},
)
except Exception as err:
current_app.logger.error(f"Error updating tool: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@tools_ns.route("/update_tool_config")
class UpdateToolConfig(Resource):
@api.expect(
api.model(
"UpdateToolConfigModel",
{
"id": fields.String(required=True, description="Tool ID"),
"config": fields.Raw(
required=True, description="Configuration of the tool"
),
},
)
)
@api.doc(description="Update the configuration of a tool")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "config"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"config": data["config"]}},
)
except Exception as err:
current_app.logger.error(
f"Error updating tool config: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@tools_ns.route("/update_tool_actions")
class UpdateToolActions(Resource):
@api.expect(
api.model(
"UpdateToolActionsModel",
{
"id": fields.String(required=True, description="Tool ID"),
"actions": fields.List(
fields.Raw,
required=True,
description="Actions the tool can perform",
),
},
)
)
@api.doc(description="Update the actions of a tool")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "actions"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"actions": data["actions"]}},
)
except Exception as err:
current_app.logger.error(
f"Error updating tool actions: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@tools_ns.route("/update_tool_status")
class UpdateToolStatus(Resource):
@api.expect(
api.model(
"UpdateToolStatusModel",
{
"id": fields.String(required=True, description="Tool ID"),
"status": fields.Boolean(
required=True, description="Status of the tool"
),
},
)
)
@api.doc(description="Update the status of a tool")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "status"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"status": data["status"]}},
)
except Exception as err:
current_app.logger.error(
f"Error updating tool status: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@tools_ns.route("/delete_tool")
class DeleteTool(Resource):
@api.expect(
api.model(
"DeleteToolModel",
{"id": fields.String(required=True, description="Tool ID")},
)
)
@api.doc(description="Delete a tool by ID")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
result = user_tools_collection.delete_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if result.deleted_count == 0:
return {"success": False, "message": "Tool not found"}, 404
except Exception as err:
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
return {"success": False}, 400
return {"success": True}, 200

View File

@@ -12,25 +12,26 @@ from application.core.logging_config import setup_logging
setup_logging()
from application.api.answer.routes import answer # noqa: E402
from application.api import api # noqa: E402
from application.api.answer import answer # noqa: E402
from application.api.internal.routes import internal # noqa: E402
from application.api.user.routes import user # noqa: E402
from application.api.connector.routes import connector # noqa: E402
from application.celery_init import celery # noqa: E402
from application.core.settings import settings # noqa: E402
from application.extensions import api # noqa: E402
if platform.system() == "Windows":
import pathlib
pathlib.PosixPath = pathlib.WindowsPath
dotenv.load_dotenv()
app = Flask(__name__)
app.register_blueprint(user)
app.register_blueprint(answer)
app.register_blueprint(internal)
app.register_blueprint(connector)
app.config.update(
UPLOAD_FOLDER="inputs",
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
@@ -52,7 +53,6 @@ if settings.AUTH_TYPE in ("simple_jwt", "session_jwt") and not settings.JWT_SECR
settings.JWT_SECRET_KEY = new_key
except Exception as e:
raise RuntimeError(f"Failed to setup JWT_SECRET_KEY: {e}")
SIMPLE_JWT_TOKEN = None
if settings.AUTH_TYPE == "simple_jwt":
payload = {"sub": "local"}
@@ -92,7 +92,6 @@ def generate_token():
def authenticate_request():
if request.method == "OPTIONS":
return "", 200
decoded_token = handle_auth(request)
if not decoded_token:
request.decoded_token = None

View File

@@ -0,0 +1,189 @@
"""
Model configurations for all supported LLM providers.
"""
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
)
OPENAI_ATTACHMENTS = [
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
GOOGLE_ATTACHMENTS = [
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
OPENAI_MODELS = [
AvailableModel(
id="gpt-5.1",
provider=ModelProvider.OPENAI,
display_name="GPT-5.1",
description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="gpt-5-mini",
provider=ModelProvider.OPENAI,
display_name="GPT-5 Mini",
description="Faster, cost-effective variant of GPT-5.1",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=200000,
),
)
]
ANTHROPIC_MODELS = [
AvailableModel(
id="claude-3-5-sonnet-20241022",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3.5 Sonnet (Latest)",
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=200000,
),
),
AvailableModel(
id="claude-3-5-sonnet",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3.5 Sonnet",
description="Balanced performance and capability",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=200000,
),
),
AvailableModel(
id="claude-3-opus",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3 Opus",
description="Most capable Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=200000,
),
),
AvailableModel(
id="claude-3-haiku",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3 Haiku",
description="Fastest Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=200000,
),
),
]
GOOGLE_MODELS = [
AvailableModel(
id="gemini-flash-latest",
provider=ModelProvider.GOOGLE,
display_name="Gemini Flash (Latest)",
description="Latest experimental Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=int(1e6),
),
),
AvailableModel(
id="gemini-flash-lite-latest",
provider=ModelProvider.GOOGLE,
display_name="Gemini Flash Lite (Latest)",
description="Fast with huge context window",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=int(1e6),
),
),
AvailableModel(
id="gemini-3-pro-preview",
provider=ModelProvider.GOOGLE,
display_name="Gemini 3 Pro",
description="Most capable Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=2000000,
),
),
]
GROQ_MODELS = [
AvailableModel(
id="llama-3.3-70b-versatile",
provider=ModelProvider.GROQ,
display_name="Llama 3.3 70B",
description="Latest Llama model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
AvailableModel(
id="llama-3.1-8b-instant",
provider=ModelProvider.GROQ,
display_name="Llama 3.1 8B",
description="Ultra-fast inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
AvailableModel(
id="mixtral-8x7b-32768",
provider=ModelProvider.GROQ,
display_name="Mixtral 8x7B",
description="High-speed inference with tools",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=32768,
),
),
]
AZURE_OPENAI_MODELS = [
AvailableModel(
id="azure-gpt-4",
provider=ModelProvider.AZURE_OPENAI,
display_name="Azure OpenAI GPT-4",
description="Azure-hosted GPT model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=8192,
),
),
]

View File

@@ -0,0 +1,236 @@
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
class ModelProvider(str, Enum):
OPENAI = "openai"
AZURE_OPENAI = "azure_openai"
ANTHROPIC = "anthropic"
GROQ = "groq"
GOOGLE = "google"
HUGGINGFACE = "huggingface"
LLAMA_CPP = "llama.cpp"
DOCSGPT = "docsgpt"
PREMAI = "premai"
SAGEMAKER = "sagemaker"
NOVITA = "novita"
@dataclass
class ModelCapabilities:
supports_tools: bool = False
supports_structured_output: bool = False
supports_streaming: bool = True
supported_attachment_types: List[str] = field(default_factory=list)
context_window: int = 128000
input_cost_per_token: Optional[float] = None
output_cost_per_token: Optional[float] = None
@dataclass
class AvailableModel:
id: str
provider: ModelProvider
display_name: str
description: str = ""
capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
enabled: bool = True
base_url: Optional[str] = None
def to_dict(self) -> Dict:
result = {
"id": self.id,
"provider": self.provider.value,
"display_name": self.display_name,
"description": self.description,
"supported_attachment_types": self.capabilities.supported_attachment_types,
"supports_tools": self.capabilities.supports_tools,
"supports_structured_output": self.capabilities.supports_structured_output,
"supports_streaming": self.capabilities.supports_streaming,
"context_window": self.capabilities.context_window,
"enabled": self.enabled,
}
if self.base_url:
result["base_url"] = self.base_url
return result
class ModelRegistry:
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not ModelRegistry._initialized:
self.models: Dict[str, AvailableModel] = {}
self.default_model_id: Optional[str] = None
self._load_models()
ModelRegistry._initialized = True
@classmethod
def get_instance(cls) -> "ModelRegistry":
return cls()
def _load_models(self):
from application.core.settings import settings
self.models.clear()
self._add_docsgpt_models(settings)
if settings.OPENAI_API_KEY or (
settings.LLM_PROVIDER == "openai" and settings.API_KEY
):
self._add_openai_models(settings)
if settings.OPENAI_API_BASE or (
settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
):
self._add_azure_openai_models(settings)
if settings.ANTHROPIC_API_KEY or (
settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
):
self._add_anthropic_models(settings)
if settings.GOOGLE_API_KEY or (
settings.LLM_PROVIDER == "google" and settings.API_KEY
):
self._add_google_models(settings)
if settings.GROQ_API_KEY or (
settings.LLM_PROVIDER == "groq" and settings.API_KEY
):
self._add_groq_models(settings)
if settings.HUGGINGFACE_API_KEY or (
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
):
self._add_huggingface_models(settings)
# Default model selection
if settings.LLM_NAME and settings.LLM_NAME in self.models:
self.default_model_id = settings.LLM_NAME
elif settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
self.default_model_id = model_id
break
else:
self.default_model_id = next(iter(self.models.keys()))
logger.info(
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
)
def _add_openai_models(self, settings):
from application.core.model_configs import OPENAI_MODELS
if settings.OPENAI_API_KEY:
for model in OPENAI_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "openai" and settings.LLM_NAME:
for model in OPENAI_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in OPENAI_MODELS:
self.models[model.id] = model
def _add_azure_openai_models(self, settings):
from application.core.model_configs import AZURE_OPENAI_MODELS
if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
for model in AZURE_OPENAI_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in AZURE_OPENAI_MODELS:
self.models[model.id] = model
def _add_anthropic_models(self, settings):
from application.core.model_configs import ANTHROPIC_MODELS
if settings.ANTHROPIC_API_KEY:
for model in ANTHROPIC_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
for model in ANTHROPIC_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in ANTHROPIC_MODELS:
self.models[model.id] = model
def _add_google_models(self, settings):
from application.core.model_configs import GOOGLE_MODELS
if settings.GOOGLE_API_KEY:
for model in GOOGLE_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
for model in GOOGLE_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in GOOGLE_MODELS:
self.models[model.id] = model
def _add_groq_models(self, settings):
from application.core.model_configs import GROQ_MODELS
if settings.GROQ_API_KEY:
for model in GROQ_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
for model in GROQ_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in GROQ_MODELS:
self.models[model.id] = model
def _add_docsgpt_models(self, settings):
model_id = "docsgpt-local"
model = AvailableModel(
id=model_id,
provider=ModelProvider.DOCSGPT,
display_name="DocsGPT Model",
description="Local model",
capabilities=ModelCapabilities(
supports_tools=False,
supported_attachment_types=[],
),
)
self.models[model_id] = model
def _add_huggingface_models(self, settings):
model_id = "huggingface-local"
model = AvailableModel(
id=model_id,
provider=ModelProvider.HUGGINGFACE,
display_name="Hugging Face Model",
description="Local Hugging Face model",
capabilities=ModelCapabilities(
supports_tools=False,
supported_attachment_types=[],
),
)
self.models[model_id] = model
def get_model(self, model_id: str) -> Optional[AvailableModel]:
return self.models.get(model_id)
def get_all_models(self) -> List[AvailableModel]:
return list(self.models.values())
def get_enabled_models(self) -> List[AvailableModel]:
return [m for m in self.models.values() if m.enabled]
def model_exists(self, model_id: str) -> bool:
return model_id in self.models

View File

@@ -0,0 +1,91 @@
from typing import Any, Dict, Optional
from application.core.model_settings import ModelRegistry
def get_api_key_for_provider(provider: str) -> Optional[str]:
"""Get the appropriate API key for a provider"""
from application.core.settings import settings
provider_key_map = {
"openai": settings.OPENAI_API_KEY,
"anthropic": settings.ANTHROPIC_API_KEY,
"google": settings.GOOGLE_API_KEY,
"groq": settings.GROQ_API_KEY,
"huggingface": settings.HUGGINGFACE_API_KEY,
"azure_openai": settings.API_KEY,
"docsgpt": None,
"llama.cpp": None,
}
provider_key = provider_key_map.get(provider)
if provider_key:
return provider_key
return settings.API_KEY
def get_all_available_models() -> Dict[str, Dict[str, Any]]:
"""Get all available models with metadata for API response"""
registry = ModelRegistry.get_instance()
return {model.id: model.to_dict() for model in registry.get_enabled_models()}
def validate_model_id(model_id: str) -> bool:
"""Check if a model ID exists in registry"""
registry = ModelRegistry.get_instance()
return registry.model_exists(model_id)
def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
"""Get capabilities for a specific model"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
if model:
return {
"supported_attachment_types": model.capabilities.supported_attachment_types,
"supports_tools": model.capabilities.supports_tools,
"supports_structured_output": model.capabilities.supports_structured_output,
"context_window": model.capabilities.context_window,
}
return None
def get_default_model_id() -> str:
"""Get the system default model ID"""
registry = ModelRegistry.get_instance()
return registry.default_model_id
def get_provider_from_model_id(model_id: str) -> Optional[str]:
"""Get the provider name for a given model_id"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
if model:
return model.provider.value
return None
def get_token_limit(model_id: str) -> int:
"""
Get context window (token limit) for a model.
Returns model's context_window or default 128000 if model not found.
"""
from application.core.settings import settings
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
if model:
return model.capabilities.context_window
return settings.DEFAULT_LLM_TOKEN_LIMIT
def get_base_url_for_model(model_id: str) -> Optional[str]:
"""
Get the custom base_url for a specific model if configured.
Returns None if no custom base_url is set.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
if model:
return model.base_url
return None

View File

@@ -10,42 +10,68 @@ current_dir = os.path.dirname(
class Settings(BaseSettings):
AUTH_TYPE: Optional[str] = None
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
LLM_PROVIDER: str = "docsgpt"
LLM_NAME: Optional[str] = (
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
)
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
EMBEDDINGS_PATH: Optional[str] = "./models/all-mpnet-base-v2" # Set None for SentenceTransformer to manage download
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
MONGO_DB_NAME: str = "docsgpt"
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
LLM_TOKEN_LIMITS: dict = {
"gpt-4o-mini": 128000,
"gpt-3.5-turbo": 4096,
"claude-2": 1e5,
"gemini-2.0-flash-exp": 1e6,
DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry
RESERVED_TOKENS: dict = {
"system_prompt": 500,
"current_query": 500,
"safety_buffer": 1000,
}
DEFAULT_AGENT_LIMITS: dict = {
"token_limit": 50000,
"request_limit": 500,
}
UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False
PARSE_IMAGE_REMOTE: bool = False
VECTOR_STORE: str = (
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
)
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
RETRIEVERS_ENABLED: list = ["classic_rag"]
AGENT_NAME: str = "classic"
FALLBACK_LLM_PROVIDER: Optional[str] = None # provider for fallback llm
FALLBACK_LLM_NAME: Optional[str] = None # model name for fallback llm
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
# Google Drive integration
GOOGLE_CLIENT_ID: Optional[str] = (
None # Replace with your actual Google OAuth client ID
)
GOOGLE_CLIENT_SECRET: Optional[str] = (
None # Replace with your actual Google OAuth client secret
)
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = (
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
)
# GitHub source
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
# LLM Cache
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
API_URL: str = "http://localhost:7091" # backend url for celery worker
API_KEY: Optional[str] = None # LLM api key
API_KEY: Optional[str] = None # LLM api key (used by LLM_PROVIDER)
# Provider-specific API keys (for multi-model support)
OPENAI_API_KEY: Optional[str] = None
ANTHROPIC_API_KEY: Optional[str] = None
GOOGLE_API_KEY: Optional[str] = None
GROQ_API_KEY: Optional[str] = None
HUGGINGFACE_API_KEY: Optional[str] = None
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
@@ -90,6 +116,8 @@ class Settings(BaseSettings):
QDRANT_PATH: Optional[str] = None
QDRANT_DISTANCE_FUNC: str = "Cosine"
# PGVector vectorstore config
PGVECTOR_CONNECTION_STRING: Optional[str] = None
# Milvus vectorstore config
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default
@@ -100,7 +128,6 @@ class Settings(BaseSettings):
LANCEDB_TABLE_NAME: Optional[str] = (
"docsgpts" # Name of the table to use for storing vectors
)
BRAVE_SEARCH_API_KEY: Optional[str] = None
FLASK_DEBUG_MODE: bool = False
STORAGE_TYPE: str = "local" # local or s3
@@ -108,6 +135,22 @@ class Settings(BaseSettings):
JWT_SECRET_KEY: str = ""
# Encryption settings
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
ELEVENLABS_API_KEY: Optional[str] = None
# Tool pre-fetch settings
ENABLE_TOOL_PREFETCH: bool = True
# Conversation Compression Settings
ENABLE_CONVERSATION_COMPRESSION: bool = True
COMPRESSION_THRESHOLD_PERCENTAGE: float = 0.8 # Trigger at 80% of context
COMPRESSION_MODEL_OVERRIDE: Optional[str] = None # Use different model for compression
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")

View File

@@ -1,7 +0,0 @@
from flask_restx import Api
api = Api(
version="1.0",
title="DocsGPT API",
description="API for DocsGPT",
)

View File

@@ -1,30 +1,41 @@
from application.llm.base import BaseLLM
from anthropic import AI_PROMPT, Anthropic, HUMAN_PROMPT
from application.core.settings import settings
from application.llm.base import BaseLLM
class AnthropicLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = (
api_key or settings.ANTHROPIC_API_KEY
) # If not provided, use a default from settings
self.api_key = api_key or settings.ANTHROPIC_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
self.anthropic = Anthropic(api_key=self.api_key)
# Use custom base_url if provided
if base_url:
self.anthropic = Anthropic(api_key=self.api_key, base_url=base_url)
else:
self.anthropic = Anthropic(api_key=self.api_key)
self.HUMAN_PROMPT = HUMAN_PROMPT
self.AI_PROMPT = AI_PROMPT
def _raw_gen(
self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs
self,
baseself,
model,
messages,
stream=False,
tools=None,
max_tokens=300,
**kwargs,
):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Context \n {context} \n ### Question \n {user_question}"
if stream:
return self.gen_stream(model, prompt, stream, max_tokens, **kwargs)
completion = self.anthropic.completions.create(
model=model,
max_tokens_to_sample=max_tokens,
@@ -34,7 +45,14 @@ class AnthropicLLM(BaseLLM):
return completion.completion
def _raw_gen_stream(
self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs
self,
baseself,
model,
messages,
stream=True,
tools=None,
max_tokens=300,
**kwargs,
):
context = messages[0]["content"]
user_question = messages[-1]["content"]
@@ -46,5 +64,9 @@ class AnthropicLLM(BaseLLM):
stream=True,
)
for completion in stream_response:
yield completion.completion
try:
for completion in stream_response:
yield completion.completion
finally:
if hasattr(stream_response, "close"):
stream_response.close()

View File

@@ -13,30 +13,32 @@ class BaseLLM(ABC):
def __init__(
self,
decoded_token=None,
model_id=None,
base_url=None,
):
self.decoded_token = decoded_token
self.model_id = model_id
self.base_url = base_url
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
self.fallback_provider = settings.FALLBACK_LLM_PROVIDER
self.fallback_model_name = settings.FALLBACK_LLM_NAME
self.fallback_llm_api_key = settings.FALLBACK_LLM_API_KEY
self._fallback_llm = None
self._fallback_sequence_index = 0
@property
def fallback_llm(self):
"""Lazy-loaded fallback LLM instance."""
if (
self._fallback_llm is None
and self.fallback_provider
and self.fallback_model_name
):
"""Lazy-loaded fallback LLM from FALLBACK_* settings."""
if self._fallback_llm is None and settings.FALLBACK_LLM_PROVIDER:
try:
from application.llm.llm_creator import LLMCreator
self._fallback_llm = LLMCreator.create_llm(
self.fallback_provider,
self.fallback_llm_api_key,
None,
self.decoded_token,
settings.FALLBACK_LLM_PROVIDER,
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
user_api_key=None,
decoded_token=self.decoded_token,
model_id=settings.FALLBACK_LLM_NAME,
)
logger.info(
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
)
except Exception as e:
logger.error(
@@ -44,11 +46,17 @@ class BaseLLM(ABC):
)
return self._fallback_llm
@staticmethod
def _remove_null_values(args_dict):
if not isinstance(args_dict, dict):
return args_dict
return {k: v for k, v in args_dict.items() if v is not None}
def _execute_with_fallback(
self, method_name: str, decorators: list, *args, **kwargs
):
"""
Unified method execution with fallback support.
Execute method with fallback support.
Args:
method_name: Name of the raw method ('_raw_gen' or '_raw_gen_stream')
@@ -67,10 +75,10 @@ class BaseLLM(ABC):
return decorated_method()
except Exception as e:
if not self.fallback_llm:
logger.error(f"Primary LLM failed and no fallback available: {str(e)}")
logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
raise
logger.warning(
f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}"
f"Primary LLM failed. Falling back to {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}. Error: {str(e)}"
)
fallback_method = getattr(
@@ -120,6 +128,20 @@ class BaseLLM(ABC):
def _supports_tools(self):
raise NotImplementedError("Subclass must implement _supports_tools method")
def supports_structured_output(self):
"""Check if the LLM supports structured output/JSON schema enforcement"""
return hasattr(self, "_supports_structured_output") and callable(
getattr(self, "_supports_structured_output")
)
def _supports_structured_output(self):
return False
def prepare_structured_output_format(self, json_schema):
"""Prepare structured output format specific to the LLM provider"""
_ = json_schema
return None
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by this LLM for file uploads.
@@ -127,4 +149,4 @@ class BaseLLM(ABC):
Returns:
list: List of supported MIME types
"""
return [] # Default: no attachments supported
return []

View File

@@ -1,5 +1,7 @@
import json
from openai import OpenAI
from application.core.settings import settings
from application.llm.base import BaseLLM
@@ -7,12 +9,11 @@ from application.llm.base import BaseLLM
class DocsGPTAPILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from openai import OpenAI
super().__init__(*args, **kwargs)
self.client = OpenAI(api_key="sk-docsgpt-public", base_url="https://oai.arc53.com")
self.api_key = "sk-docsgpt-public"
self.client = OpenAI(api_key=self.api_key, base_url="https://oai.arc53.com")
self.user_api_key = user_api_key
self.api_key = api_key
def _clean_messages_openai(self, messages):
cleaned_messages = []
@@ -22,7 +23,6 @@ class DocsGPTAPILLM(BaseLLM):
if role == "model":
role = "assistant"
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
@@ -33,14 +33,15 @@ class DocsGPTAPILLM(BaseLLM):
{"role": role, "content": item["text"]}
)
elif "function_call" in item:
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
tool_call = {
"id": item["function_call"]["call_id"],
"type": "function",
"function": {
"name": item["function_call"]["name"],
"arguments": json.dumps(
item["function_call"]["args"]
),
"arguments": json.dumps(cleaned_args),
},
}
cleaned_messages.append(
@@ -68,7 +69,6 @@ class DocsGPTAPILLM(BaseLLM):
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
return cleaned_messages
def _raw_gen(
@@ -120,12 +120,19 @@ class DocsGPTAPILLM(BaseLLM):
response = self.client.chat.completions.create(
model="docsgpt", messages=messages, stream=stream, **kwargs
)
for line in response:
if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
yield line.choices[0].delta.content
elif len(line.choices) > 0:
yield line.choices[0]
try:
for line in response:
if (
len(line.choices) > 0
and line.choices[0].delta.content is not None
and len(line.choices[0].delta.content) > 0
):
yield line.choices[0].delta.content
elif len(line.choices) > 0:
yield line.choices[0]
finally:
if hasattr(response, "close"):
response.close()
def _supports_tools(self):
return True
return True

View File

@@ -1,18 +1,22 @@
import logging
from google import genai
from google.genai import types
import logging
import json
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
from application.core.settings import settings
class GoogleLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
def __init__(
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
):
super().__init__(*args, **kwargs)
self.api_key = api_key
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
self.client = genai.Client(api_key=self.api_key)
self.storage = StorageCreator.get_storage()
@@ -24,12 +28,18 @@ class GoogleLLM(BaseLLM):
list: List of supported MIME types
"""
return [
'application/pdf',
'image/png',
'image/jpeg',
'image/jpg',
'image/webp',
'image/gif'
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
def prepare_messages_with_attachments(self, messages, attachments=None):
@@ -45,21 +55,19 @@ class GoogleLLM(BaseLLM):
"""
if not attachments:
return messages
prepared_messages = messages.copy()
# Find the user message to attach files to the last one
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
@@ -67,30 +75,31 @@ class GoogleLLM(BaseLLM):
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
files = []
for attachment in attachments:
mime_type = attachment.get('mime_type')
mime_type = attachment.get("mime_type")
if mime_type in self.get_supported_attachment_types():
try:
file_uri = self._upload_file_to_google(attachment)
logging.info(f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}")
logging.info(
f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}"
)
files.append({"file_uri": file_uri, "mime_type": mime_type})
except Exception as e:
logging.error(f"GoogleLLM: Error uploading file: {e}", exc_info=True)
if 'content' in attachment:
prepared_messages[user_message_index]["content"].append({
"type": "text",
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]"
})
logging.error(
f"GoogleLLM: Error uploading file: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
}
)
if files:
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
prepared_messages[user_message_index]["content"].append({
"files": files
})
prepared_messages[user_message_index]["content"].append({"files": files})
return prepared_messages
def _upload_file_to_google(self, attachment):
@@ -103,46 +112,72 @@ class GoogleLLM(BaseLLM):
Returns:
str: Google AI file URI for the uploaded file.
"""
if 'google_file_uri' in attachment:
return attachment['google_file_uri']
file_path = attachment.get('path')
if "google_file_uri" in attachment:
return attachment["google_file_uri"]
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
if not self.storage.file_exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
try:
file_uri = self.storage.process_file(
file_path,
lambda local_path, **kwargs: self.client.files.upload(file=local_path).uri
lambda local_path, **kwargs: self.client.files.upload(
file=local_path
).uri,
)
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
attachments_collection = db["attachments"]
if '_id' in attachment:
if "_id" in attachment:
attachments_collection.update_one(
{"_id": attachment['_id']},
{"$set": {"google_file_uri": file_uri}}
{"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
)
return file_uri
except Exception as e:
logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
raise
def _clean_messages_google(self, messages):
"""
Convert OpenAI format messages to Google AI format and collect system prompts.
Returns:
tuple[list[types.Content], Optional[str]]: cleaned messages and optional
combined system instruction.
"""
cleaned_messages = []
system_instructions = []
def _extract_system_text(content):
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict) and "text" in item and item["text"] is not None:
parts.append(item["text"])
return "\n".join(parts)
return ""
for message in messages:
role = message.get("role")
content = message.get("content")
# Gemini only accepts user/model in the contents list.
if role == "system":
sys_text = _extract_system_text(content)
if sys_text:
system_instructions.append(sys_text)
continue
if role == "assistant":
role = "model"
elif role == "tool":
role = "model"
parts = []
if role and content is not None:
if isinstance(content, str):
@@ -152,12 +187,32 @@ class GoogleLLM(BaseLLM):
if "text" in item:
parts.append(types.Part.from_text(text=item["text"]))
elif "function_call" in item:
parts.append(
types.Part.from_function_call(
name=item["function_call"]["name"],
args=item["function_call"]["args"],
)
# Remove null values from args to avoid API errors
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
# Create function call part with thought_signature if present
# For Gemini 3 models, we need to include thought_signature
if "thought_signature" in item:
# Use Part constructor with functionCall and thoughtSignature
parts.append(
types.Part(
functionCall=types.FunctionCall(
name=item["function_call"]["name"],
args=cleaned_args,
),
thoughtSignature=item["thought_signature"],
)
)
else:
# Use helper method when no thought_signature
parts.append(
types.Part.from_function_call(
name=item["function_call"]["name"],
args=cleaned_args,
)
)
elif "function_response" in item:
parts.append(
types.Part.from_function_response(
@@ -166,25 +221,75 @@ class GoogleLLM(BaseLLM):
)
)
elif "files" in item:
for file_data in item["files"]:
parts.append(
types.Part.from_uri(
file_uri=file_data["file_uri"],
mime_type=file_data["mime_type"]
)
for file_data in item["files"]:
parts.append(
types.Part.from_uri(
file_uri=file_data["file_uri"],
mime_type=file_data["mime_type"],
)
)
else:
raise ValueError(
f"Unexpected content dictionary format:{item}"
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
system_instruction = "\n\n".join(system_instructions) if system_instructions else None
return cleaned_messages, system_instruction
cleaned_messages.append(types.Content(role=role, parts=parts))
def _clean_schema(self, schema_obj):
"""
Recursively remove unsupported fields from schema objects
and validate required properties.
"""
if not isinstance(schema_obj, dict):
return schema_obj
allowed_fields = {
"type",
"description",
"items",
"properties",
"required",
"enum",
"pattern",
"minimum",
"maximum",
"nullable",
"default",
}
return cleaned_messages
cleaned = {}
for key, value in schema_obj.items():
if key not in allowed_fields:
continue
elif key == "type" and isinstance(value, str):
cleaned[key] = value.upper()
elif isinstance(value, dict):
cleaned[key] = self._clean_schema(value)
elif isinstance(value, list):
cleaned[key] = [self._clean_schema(item) for item in value]
else:
cleaned[key] = value
# Validate that required properties actually exist in properties
if "required" in cleaned and "properties" in cleaned:
valid_required = []
properties_keys = set(cleaned["properties"].keys())
for required_prop in cleaned["required"]:
if required_prop in properties_keys:
valid_required.append(required_prop)
if valid_required:
cleaned["required"] = valid_required
else:
cleaned.pop("required", None)
elif "required" in cleaned and "properties" not in cleaned:
cleaned.pop("required", None)
return cleaned
def _clean_tools_format(self, tools_list):
"""Convert OpenAI format tools to Google AI format."""
genai_tools = []
for tool_data in tools_list:
if tool_data["type"] == "function":
@@ -193,18 +298,15 @@ class GoogleLLM(BaseLLM):
properties = parameters.get("properties", {})
if properties:
cleaned_properties = {}
for k, v in properties.items():
cleaned_properties[k] = self._clean_schema(v)
genai_function = dict(
name=function["name"],
description=function["description"],
parameters={
"type": "OBJECT",
"properties": {
k: {
**v,
"type": v["type"].upper() if v["type"] else None,
}
for k, v in properties.items()
},
"properties": cleaned_properties,
"required": (
parameters["required"]
if "required" in parameters
@@ -217,12 +319,65 @@ class GoogleLLM(BaseLLM):
name=function["name"],
description=function["description"],
)
genai_tool = types.Tool(function_declarations=[genai_function])
genai_tools.append(genai_tool)
return genai_tools
def _extract_preview_from_message(self, message):
"""Get a short, human-readable preview from the last message."""
try:
if hasattr(message, "parts"):
for part in reversed(message.parts):
if getattr(part, "text", None):
return part.text
function_call = getattr(part, "function_call", None)
if function_call:
name = getattr(function_call, "name", "") or "function_call"
return f"function_call:{name}"
function_response = getattr(part, "function_response", None)
if function_response:
name = getattr(function_response, "name", "") or "function_response"
return f"function_response:{name}"
if isinstance(message, dict):
content = message.get("content")
if isinstance(content, str):
return content
if isinstance(content, list):
for item in reversed(content):
if isinstance(item, str):
return item
if isinstance(item, dict):
if item.get("text"):
return item["text"]
if item.get("function_call"):
fn = item["function_call"]
if isinstance(fn, dict):
name = fn.get("name") or "function_call"
return f"function_call:{name}"
return "function_call"
if item.get("function_response"):
resp = item["function_response"]
if isinstance(resp, dict):
name = resp.get("name") or "function_response"
return f"function_response:{name}"
return "function_response"
if "text" in message and isinstance(message["text"], str):
return message["text"]
except Exception:
pass
return str(message)
def _summarize_messages_for_log(self, messages, preview_chars=20):
"""Return a compact summary for logging to avoid huge payloads."""
message_count = len(messages) if messages else 0
last_preview = ""
if messages:
last_preview = self._extract_preview_from_message(messages[-1]) or ""
last_preview = str(last_preview).replace("\n", " ")
if len(last_preview) > preview_chars:
last_preview = f"{last_preview[:preview_chars]}..."
return f"count={message_count}, last='{last_preview}'"
def _raw_gen(
self,
baseself,
@@ -231,29 +386,34 @@ class GoogleLLM(BaseLLM):
stream=False,
tools=None,
formatting="openai",
response_schema=None,
**kwargs,
):
"""Generate content using Google AI API without streaming."""
client = genai.Client(api_key=self.api_key)
system_instruction = None
if formatting == "openai":
messages = self._clean_messages_google(messages)
messages, system_instruction = self._clean_messages_google(messages)
config = types.GenerateContentConfig()
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
if system_instruction:
config.system_instruction = system_instruction
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
response = client.models.generate_content(
model=model,
contents=messages,
config=config,
)
# Add response schema for structured output if provided
if response_schema:
config.response_schema = response_schema
config.response_mime_type = "application/json"
response = client.models.generate_content(
model=model,
contents=messages,
config=config,
)
if tools:
return response
else:
response = client.models.generate_content(
model=model, contents=messages, config=config
)
return response.text
def _raw_gen_stream(
@@ -264,31 +424,42 @@ class GoogleLLM(BaseLLM):
stream=True,
tools=None,
formatting="openai",
response_schema=None,
**kwargs,
):
"""Generate content using Google AI API with streaming."""
client = genai.Client(api_key=self.api_key)
system_instruction = None
if formatting == "openai":
messages = self._clean_messages_google(messages)
messages, system_instruction = self._clean_messages_google(messages)
config = types.GenerateContentConfig()
if messages[0].role == "system":
config.system_instruction = messages[0].parts[0].text
messages = messages[1:]
if system_instruction:
config.system_instruction = system_instruction
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
# Add response schema for structured output if provided
if response_schema:
config.response_schema = response_schema
config.response_mime_type = "application/json"
# Check if we have both tools and file attachments
has_attachments = False
for message in messages:
for part in message.parts:
if hasattr(part, 'file_data') and part.file_data is not None:
if hasattr(part, "file_data") and part.file_data is not None:
has_attachments = True
break
if has_attachments:
break
logging.info(f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}")
messages_summary = self._summarize_messages_for_log(messages)
logging.info(
"GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
model,
messages_summary,
has_attachments,
)
response = client.models.generate_content_stream(
model=model,
@@ -296,18 +467,88 @@ class GoogleLLM(BaseLLM):
config=config,
)
for chunk in response:
if hasattr(chunk, "candidates") and chunk.candidates:
for candidate in chunk.candidates:
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if part.function_call:
yield part
elif part.text:
yield part.text
elif hasattr(chunk, "text"):
yield chunk.text
try:
for chunk in response:
if hasattr(chunk, "candidates") and chunk.candidates:
for candidate in chunk.candidates:
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if part.function_call:
yield part
elif part.text:
yield part.text
elif hasattr(chunk, "text"):
yield chunk.text
finally:
if hasattr(response, "close"):
response.close()
def _supports_tools(self):
"""Return whether this LLM supports function calling."""
return True
def _supports_structured_output(self):
"""Return whether this LLM supports structured JSON output."""
return True
def prepare_structured_output_format(self, json_schema):
"""Convert JSON schema to Google AI structured output format."""
if not json_schema:
return None
type_map = {
"object": "OBJECT",
"array": "ARRAY",
"string": "STRING",
"integer": "INTEGER",
"number": "NUMBER",
"boolean": "BOOLEAN",
}
def convert(schema):
if not isinstance(schema, dict):
return schema
result = {}
schema_type = schema.get("type")
if schema_type:
result["type"] = type_map.get(schema_type.lower(), schema_type.upper())
for key in [
"description",
"nullable",
"enum",
"minItems",
"maxItems",
"required",
"propertyOrdering",
]:
if key in schema:
result[key] = schema[key]
if "format" in schema:
format_value = schema["format"]
if schema_type == "string":
if format_value == "date":
result["format"] = "date-time"
elif format_value in ["enum", "date-time"]:
result["format"] = format_value
else:
result["format"] = format_value
if "properties" in schema:
result["properties"] = {
k: convert(v) for k, v in schema["properties"].items()
}
if "propertyOrdering" not in result and result.get("type") == "OBJECT":
result["propertyOrdering"] = list(result["properties"].keys())
if "items" in schema:
result["items"] = convert(schema["items"])
for field in ["anyOf", "oneOf", "allOf"]:
if field in schema:
result[field] = [convert(s) for s in schema[field]]
return result
try:
return convert(json_schema)
except Exception as e:
logging.error(
f"Error preparing structured output format for Google: {e}",
exc_info=True,
)
return None

View File

@@ -1,13 +1,18 @@
from application.llm.base import BaseLLM
from openai import OpenAI
from application.core.settings import settings
from application.llm.base import BaseLLM
class GroqLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
self.api_key = api_key
self.api_key = api_key or settings.GROQ_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
self.client = OpenAI(
api_key=self.api_key, base_url="https://api.groq.com/openai/v1"
)
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
if tools:

View File

@@ -1,4 +1,5 @@
import logging
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Union
@@ -16,6 +17,7 @@ class ToolCall:
name: str
arguments: Union[str, Dict]
index: Optional[int] = None
thought_signature: Optional[str] = None
@classmethod
def from_dict(cls, data: Dict) -> "ToolCall":
@@ -178,6 +180,406 @@ class LLMHandler(ABC):
system_msg["content"] += f"\n\n{combined_text}"
return prepared_messages
def _prune_messages_minimal(self, messages: List[Dict]) -> Optional[List[Dict]]:
"""
Build a minimal context: system prompt + latest user message only.
Drops all tool/function messages to shrink context aggressively.
"""
system_message = next((m for m in messages if m.get("role") == "system"), None)
if not system_message:
logger.warning("Cannot prune messages minimally: missing system message.")
return None
last_non_system = None
for m in reversed(messages):
if m.get("role") == "user":
last_non_system = m
break
if not last_non_system and m.get("role") not in ("system", None):
last_non_system = m
if not last_non_system:
logger.warning("Cannot prune messages minimally: missing user/assistant messages.")
return None
logger.info("Pruning context to system + latest user/assistant message to proceed.")
return [system_message, last_non_system]
def _extract_text_from_content(self, content: Any) -> str:
"""
Convert message content (str or list of parts) to plain text for compression.
"""
if isinstance(content, str):
return content
if isinstance(content, list):
parts_text = []
for item in content:
if isinstance(item, dict):
if "text" in item and item["text"] is not None:
parts_text.append(str(item["text"]))
elif "function_call" in item or "function_response" in item:
# Keep serialized function calls/responses so the compressor sees actions
parts_text.append(str(item))
elif "files" in item:
parts_text.append(str(item))
return "\n".join(parts_text)
return ""
def _build_conversation_from_messages(self, messages: List[Dict]) -> Optional[Dict]:
"""
Build a conversation-like dict from current messages so we can compress
even when the conversation isn't persisted yet. Includes tool calls/results.
"""
queries = []
current_prompt = None
current_tool_calls = {}
def _commit_query(response_text: str):
nonlocal current_prompt, current_tool_calls
if current_prompt is None and not response_text:
return
tool_calls_list = list(current_tool_calls.values())
queries.append(
{
"prompt": current_prompt or "",
"response": response_text,
"tool_calls": tool_calls_list,
}
)
current_prompt = None
current_tool_calls = {}
for message in messages:
role = message.get("role")
content = message.get("content")
if role == "user":
current_prompt = self._extract_text_from_content(content)
elif role in {"assistant", "model"}:
# If this assistant turn contains tool calls, collect them; otherwise commit a response.
if isinstance(content, list):
for item in content:
if "function_call" in item:
fc = item["function_call"]
call_id = fc.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": fc.get("name"),
"arguments": fc.get("args"),
"result": None,
"status": "called",
"call_id": call_id,
}
elif "function_response" in item:
fr = item["function_response"]
call_id = fr.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": fr.get("name"),
"arguments": None,
"result": fr.get("response", {}).get("result"),
"status": "completed",
"call_id": call_id,
}
# No direct assistant text here; continue to next message
continue
response_text = self._extract_text_from_content(content)
_commit_query(response_text)
elif role == "tool":
# Attach tool outputs to the latest pending tool call if possible
tool_text = self._extract_text_from_content(content)
# Attempt to parse function_response style
call_id = None
if isinstance(content, list):
for item in content:
if "function_response" in item and item["function_response"].get("call_id"):
call_id = item["function_response"]["call_id"]
break
if call_id and call_id in current_tool_calls:
current_tool_calls[call_id]["result"] = tool_text
current_tool_calls[call_id]["status"] = "completed"
elif queries:
queries[-1].setdefault("tool_calls", []).append(
{
"tool_name": "unknown_tool",
"action_name": "unknown_action",
"arguments": {},
"result": tool_text,
"status": "completed",
}
)
# If there's an unfinished prompt with tool_calls but no response yet, commit it
if current_prompt is not None or current_tool_calls:
_commit_query(response_text="")
if not queries:
return None
return {
"queries": queries,
"compression_metadata": {
"is_compressed": False,
"compression_points": [],
},
}
def _rebuild_messages_after_compression(
self,
messages: List[Dict],
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_current_execution: bool = False,
include_tool_calls: bool = False,
) -> Optional[List[Dict]]:
"""
Rebuild the message list after compression so tool execution can continue.
Delegates to MessageBuilder for the actual reconstruction.
"""
from application.api.answer.services.compression.message_builder import (
MessageBuilder,
)
return MessageBuilder.rebuild_messages_after_compression(
messages=messages,
compressed_summary=compressed_summary,
recent_queries=recent_queries,
include_current_execution=include_current_execution,
include_tool_calls=include_tool_calls,
)
def _perform_mid_execution_compression(
self, agent, messages: List[Dict]
) -> tuple[bool, Optional[List[Dict]]]:
"""
Perform compression during tool execution and rebuild messages.
Uses the new orchestrator for simplified compression.
Args:
agent: The agent instance
messages: Current conversation messages
Returns:
(success: bool, rebuilt_messages: Optional[List[Dict]])
"""
try:
from application.api.answer.services.compression import (
CompressionOrchestrator,
)
from application.api.answer.services.conversation_service import (
ConversationService,
)
conversation_service = ConversationService()
orchestrator = CompressionOrchestrator(conversation_service)
# Get conversation from database (may be None for new sessions)
conversation = conversation_service.get_conversation(
agent.conversation_id, agent.initial_user_id
)
if conversation:
# Merge current in-flight messages (including tool calls)
conversation_from_msgs = self._build_conversation_from_messages(messages)
if conversation_from_msgs:
conversation = conversation_from_msgs
else:
logger.warning(
"Could not load conversation for compression; attempting in-memory compression"
)
return self._perform_in_memory_compression(agent, messages)
# Use orchestrator to perform compression
result = orchestrator.compress_mid_execution(
conversation_id=agent.conversation_id,
user_id=agent.initial_user_id,
model_id=agent.model_id,
decoded_token=getattr(agent, "decoded_token", {}),
current_conversation=conversation,
)
if not result.success:
logger.warning(f"Mid-execution compression failed: {result.error}")
# Try minimal pruning as fallback
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
if not result.compression_performed:
logger.warning("Compression not performed")
return False, None
# Check if compression actually reduced tokens
if result.metadata:
if result.metadata.compressed_token_count >= result.metadata.original_token_count:
logger.warning(
"Compression did not reduce token count; falling back to minimal pruning"
)
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
logger.info(
f"Mid-execution compression successful - ratio: {result.metadata.compression_ratio:.1f}x, "
f"saved {result.metadata.original_token_count - result.metadata.compressed_token_count} tokens"
)
# Also store the compression summary as a visible message
if result.metadata:
conversation_service.append_compression_message(
agent.conversation_id, result.metadata.to_dict()
)
# Update agent's compressed summary for downstream persistence
agent.compressed_summary = result.compressed_summary
agent.compression_metadata = result.metadata.to_dict() if result.metadata else None
agent.compression_saved = False
# Reset the context limit flag so tools can continue
agent.context_limit_reached = False
agent.current_token_count = 0
# Rebuild messages
rebuilt_messages = self._rebuild_messages_after_compression(
messages,
result.compressed_summary,
result.recent_queries,
include_current_execution=False,
include_tool_calls=False,
)
if rebuilt_messages is None:
return False, None
return True, rebuilt_messages
except Exception as e:
logger.error(
f"Error performing mid-execution compression: {str(e)}", exc_info=True
)
return False, None
def _perform_in_memory_compression(
self, agent, messages: List[Dict]
) -> tuple[bool, Optional[List[Dict]]]:
"""
Fallback compression path when the conversation is not yet persisted.
Uses CompressionService directly without DB persistence.
"""
try:
from application.api.answer.services.compression.service import (
CompressionService,
)
from application.core.model_utils import (
get_api_key_for_provider,
get_provider_from_model_id,
)
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
conversation = self._build_conversation_from_messages(messages)
if not conversation:
logger.warning(
"Cannot perform in-memory compression: no user/assistant turns found"
)
return False, None
compression_model = (
settings.COMPRESSION_MODEL_OVERRIDE
if settings.COMPRESSION_MODEL_OVERRIDE
else agent.model_id
)
provider = get_provider_from_model_id(compression_model)
api_key = get_api_key_for_provider(provider)
compression_llm = LLMCreator.create_llm(
provider,
api_key,
getattr(agent, "user_api_key", None),
getattr(agent, "decoded_token", None),
model_id=compression_model,
)
# Create service without DB persistence capability
compression_service = CompressionService(
llm=compression_llm,
model_id=compression_model,
conversation_service=None, # No DB updates for in-memory
)
queries_count = len(conversation.get("queries", []))
compress_up_to = queries_count - 1
if compress_up_to < 0 or queries_count == 0:
logger.warning("Not enough queries to compress in-memory context")
return False, None
metadata = compression_service.compress_conversation(
conversation,
compress_up_to_index=compress_up_to,
)
# If compression doesn't reduce tokens, fall back to minimal pruning
if (
metadata.compressed_token_count
>= metadata.original_token_count
):
logger.warning(
"In-memory compression did not reduce token count; falling back to minimal pruning"
)
pruned = self._prune_messages_minimal(messages)
if pruned:
agent.context_limit_reached = False
agent.current_token_count = 0
return True, pruned
return False, None
# Attach metadata to synthetic conversation
conversation["compression_metadata"] = {
"is_compressed": True,
"compression_points": [metadata.to_dict()],
}
compressed_summary, recent_queries = (
compression_service.get_compressed_context(conversation)
)
agent.compressed_summary = compressed_summary
agent.compression_metadata = metadata.to_dict()
agent.compression_saved = False
agent.context_limit_reached = False
agent.current_token_count = 0
rebuilt_messages = self._rebuild_messages_after_compression(
messages,
compressed_summary,
recent_queries,
include_current_execution=False,
include_tool_calls=False,
)
if rebuilt_messages is None:
return False, None
logger.info(
f"In-memory compression successful - ratio: {metadata.compression_ratio:.1f}x, "
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
return True, rebuilt_messages
except Exception as e:
logger.error(
f"Error performing in-memory compression: {str(e)}", exc_info=True
)
return False, None
def handle_tool_calls(
self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
) -> Generator:
@@ -195,7 +597,110 @@ class LLMHandler(ABC):
"""
updated_messages = messages.copy()
for call in tool_calls:
for i, call in enumerate(tool_calls):
# Check context limit before executing tool call
if hasattr(agent, '_check_context_limit') and agent._check_context_limit(updated_messages):
# Context limit reached - attempt mid-execution compression
compression_attempted = False
compression_successful = False
try:
from application.core.settings import settings
compression_enabled = settings.ENABLE_CONVERSATION_COMPRESSION
except Exception:
compression_enabled = False
if compression_enabled:
compression_attempted = True
try:
logger.info(
f"Context limit reached with {len(tool_calls) - i} remaining tool calls. "
f"Attempting mid-execution compression..."
)
# Trigger mid-execution compression (DB-backed if available, otherwise in-memory)
compression_successful, rebuilt_messages = self._perform_mid_execution_compression(
agent, updated_messages
)
if compression_successful and rebuilt_messages is not None:
# Update the messages list with rebuilt compressed version
updated_messages = rebuilt_messages
# Yield compression success message
yield {
"type": "info",
"data": {
"message": "Context window limit reached. Compressed conversation history to continue processing."
}
}
logger.info(
f"Mid-execution compression successful. Continuing with {len(tool_calls) - i} remaining tool calls."
)
# Proceed to execute the current tool call with the reduced context
else:
logger.warning("Mid-execution compression attempted but failed. Skipping remaining tools.")
except Exception as e:
logger.error(f"Error during mid-execution compression: {str(e)}", exc_info=True)
compression_attempted = True
compression_successful = False
# If compression wasn't attempted or failed, skip remaining tools
if not compression_successful:
if i == 0:
# Special case: limit reached before executing any tools
# This can happen when previous tool responses pushed context over limit
if compression_attempted:
logger.warning(
f"Context limit reached before executing any tools. "
f"Compression attempted but failed. "
f"Skipping all {len(tool_calls)} pending tool call(s). "
f"This typically occurs when previous tool responses contained large amounts of data."
)
else:
logger.warning(
f"Context limit reached before executing any tools. "
f"Skipping all {len(tool_calls)} pending tool call(s). "
f"This typically occurs when previous tool responses contained large amounts of data. "
f"Consider enabling compression or using a model with larger context window."
)
else:
# Normal case: executed some tools, now stopping
tool_word = "tool call" if i == 1 else "tool calls"
remaining = len(tool_calls) - i
remaining_word = "tool call" if remaining == 1 else "tool calls"
if compression_attempted:
logger.warning(
f"Context limit reached after executing {i} {tool_word}. "
f"Compression attempted but failed. "
f"Skipping remaining {remaining} {remaining_word}."
)
else:
logger.warning(
f"Context limit reached after executing {i} {tool_word}. "
f"Skipping remaining {remaining} {remaining_word}. "
f"Consider enabling compression or using a model with larger context window."
)
# Mark remaining tools as skipped
for remaining_call in tool_calls[i:]:
skip_message = {
"type": "tool_call",
"data": {
"tool_name": "system",
"call_id": remaining_call.id,
"action_name": remaining_call.name,
"arguments": {},
"result": "Skipped: Context limit reached. Too many tool calls in conversation.",
"status": "skipped"
}
}
yield skip_message
# Set flag on agent
agent.context_limit_reached = True
break
try:
self.tool_calls.append(call)
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
@@ -205,34 +710,57 @@ class LLMHandler(ABC):
except StopIteration as e:
tool_response, call_id = e.value
break
function_call_content = {
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
# Include thought_signature for Google Gemini 3 models
# It should be at the same level as function_call, not inside it
if call.thought_signature:
function_call_content["thought_signature"] = call.thought_signature
updated_messages.append(
{
"role": "assistant",
"content": [
{
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
],
"content": [function_call_content],
}
)
updated_messages.append(self.create_tool_message(call, tool_response))
except Exception as e:
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
updated_messages.append(
{
"role": "tool",
"content": f"Error executing tool: {str(e)}",
"tool_call_id": call.id,
}
error_call = ToolCall(
id=call.id, name=call.name, arguments=call.arguments
)
error_response = f"Error executing tool: {str(e)}"
error_message = self.create_tool_message(error_call, error_response)
updated_messages.append(error_message)
call_parts = call.name.split("_")
if len(call_parts) >= 2:
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
action_name = "_".join(call_parts[:-1])
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
full_action_name = f"{action_name}_{tool_id}"
else:
tool_name = "unknown_tool"
action_name = call.name
full_action_name = call.name
yield {
"type": "tool_call",
"data": {
"tool_name": tool_name,
"call_id": call.id,
"action_name": full_action_name,
"arguments": call.arguments,
"error": error_response,
"status": "error",
},
}
return updated_messages
def handle_non_streaming(
@@ -263,13 +791,11 @@ class LLMHandler(ABC):
except StopIteration as e:
messages = e.value
break
response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
model=agent.model_id, messages=messages, tools=agent.tools
)
parsed = self.parse_response(response)
self.llm_calls.append(build_stack_data(agent.llm))
return parsed.content
def handle_streaming(
@@ -308,6 +834,9 @@ class LLMHandler(ABC):
existing.name = call.name
if call.arguments:
existing.arguments += call.arguments
# Preserve thought_signature for Google Gemini 3 models
if call.thought_signature:
existing.thought_signature = call.thought_signature
if parsed.finish_reason == "tool_calls":
tool_handler_gen = self.handle_tool_calls(
agent, list(tool_calls.values()), tools_dict, messages
@@ -320,8 +849,21 @@ class LLMHandler(ABC):
break
tool_calls = {}
# Check if context limit was reached during tool execution
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
# Add system message warning about context limit
messages.append({
"role": "system",
"content": (
"WARNING: Context window limit has been reached. "
"Please provide a final response to the user without making additional tool calls. "
"Summarize the work completed so far."
)
})
logger.info("Context limit reached - instructing agent to wrap up")
response = agent.llm.gen_stream(
model=agent.gpt_model, messages=messages, tools=agent.tools
model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None
)
self.llm_calls.append(build_stack_data(agent.llm))

View File

@@ -17,18 +17,22 @@ class GoogleLLMHandler(LLMHandler):
finish_reason="stop",
raw_response=response,
)
if hasattr(response, "candidates"):
parts = response.candidates[0].content.parts if response.candidates else []
tool_calls = [
ToolCall(
id=str(uuid.uuid4()),
name=part.function_call.name,
arguments=part.function_call.args,
)
for part in parts
if hasattr(part, "function_call") and part.function_call is not None
]
tool_calls = []
for idx, part in enumerate(parts):
if hasattr(part, "function_call") and part.function_call is not None:
has_sig = hasattr(part, "thought_signature") and part.thought_signature is not None
thought_sig = part.thought_signature if has_sig else None
tool_calls.append(
ToolCall(
id=str(uuid.uuid4()),
name=part.function_call.name,
arguments=part.function_call.args,
index=idx,
thought_signature=thought_sig,
)
)
content = " ".join(
part.text
@@ -41,15 +45,18 @@ class GoogleLLMHandler(LLMHandler):
finish_reason="tool_calls" if tool_calls else "stop",
raw_response=response,
)
else:
# This branch handles individual Part objects from streaming responses
tool_calls = []
if hasattr(response, "function_call"):
if hasattr(response, "function_call") and response.function_call is not None:
has_sig = hasattr(response, "thought_signature") and response.thought_signature is not None
thought_sig = response.thought_signature if has_sig else None
tool_calls.append(
ToolCall(
id=str(uuid.uuid4()),
name=response.function_call.name,
arguments=response.function_call.args,
thought_signature=thought_sig,
)
)
return LLMResponse(
@@ -61,14 +68,16 @@ class GoogleLLMHandler(LLMHandler):
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
"""Create Google-style tool message."""
from google.genai import types
return {
"role": "tool",
"role": "model",
"content": [
types.Part.from_function_response(
name=tool_call.name, response={"result": result}
).to_json_dict()
{
"function_response": {
"name": tool_call.name,
"response": {"result": result},
}
}
],
}

View File

@@ -1,13 +1,17 @@
from application.llm.groq import GroqLLM
from application.llm.openai import OpenAILLM, AzureOpenAILLM
from application.llm.sagemaker import SagemakerAPILLM
from application.llm.huggingface import HuggingFaceLLM
from application.llm.llama_cpp import LlamaCpp
import logging
from application.llm.anthropic import AnthropicLLM
from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.premai import PremAILLM
from application.llm.google_ai import GoogleLLM
from application.llm.groq import GroqLLM
from application.llm.huggingface import HuggingFaceLLM
from application.llm.llama_cpp import LlamaCpp
from application.llm.novita import NovitaLLM
from application.llm.openai import AzureOpenAILLM, OpenAILLM
from application.llm.premai import PremAILLM
from application.llm.sagemaker import SagemakerAPILLM
logger = logging.getLogger(__name__)
class LLMCreator:
@@ -26,10 +30,26 @@ class LLMCreator:
}
@classmethod
def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
def create_llm(
cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
):
from application.core.model_utils import get_base_url_for_model
llm_class = cls.llms.get(type.lower())
if not llm_class:
raise ValueError(f"No LLM class found for type {type}")
# Extract base_url from model configuration if model_id is provided
base_url = None
if model_id:
base_url = get_base_url_for_model(model_id)
return llm_class(
api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
api_key,
user_api_key,
decoded_token=decoded_token,
model_id=model_id,
base_url=base_url,
*args,
**kwargs,
)

View File

@@ -1,7 +1,9 @@
import json
import base64
import json
import logging
from openai import OpenAI
from application.core.settings import settings
from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
@@ -9,17 +11,25 @@ from application.storage.storage_creator import StorageCreator
class OpenAILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from openai import OpenAI
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(*args, **kwargs)
if isinstance(settings.OPENAI_BASE_URL, str) and settings.OPENAI_BASE_URL.strip():
self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
else:
DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
self.client = OpenAI(api_key=api_key, base_url=DEFAULT_OPENAI_API_BASE)
self.api_key = api_key
self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
self.user_api_key = user_api_key
# Priority: 1) Parameter base_url, 2) Settings OPENAI_BASE_URL, 3) Default
effective_base_url = None
if base_url and isinstance(base_url, str) and base_url.strip():
effective_base_url = base_url
elif (
isinstance(settings.OPENAI_BASE_URL, str)
and settings.OPENAI_BASE_URL.strip()
):
effective_base_url = settings.OPENAI_BASE_URL
else:
effective_base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
self.storage = StorageCreator.get_storage()
def _clean_messages_openai(self, messages):
@@ -30,7 +40,6 @@ class OpenAILLM(BaseLLM):
if role == "model":
role = "assistant"
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
@@ -41,14 +50,15 @@ class OpenAILLM(BaseLLM):
{"role": role, "content": item["text"]}
)
elif "function_call" in item:
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
tool_call = {
"id": item["function_call"]["call_id"],
"type": "function",
"function": {
"name": item["function_call"]["name"],
"arguments": json.dumps(
item["function_call"]["args"]
),
"arguments": json.dumps(cleaned_args),
},
}
cleaned_messages.append(
@@ -73,21 +83,36 @@ class OpenAILLM(BaseLLM):
elif isinstance(item, dict):
content_parts = []
if "text" in item:
content_parts.append({"type": "text", "text": item["text"]})
elif "type" in item and item["type"] == "text" and "text" in item:
content_parts.append(
{"type": "text", "text": item["text"]}
)
elif (
"type" in item
and item["type"] == "text"
and "text" in item
):
content_parts.append(item)
elif "type" in item and item["type"] == "file" and "file" in item:
elif (
"type" in item
and item["type"] == "file"
and "file" in item
):
content_parts.append(item)
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
elif (
"type" in item
and item["type"] == "image_url"
and "image_url" in item
):
content_parts.append(item)
cleaned_messages.append({"role": role, "content": content_parts})
cleaned_messages.append(
{"role": role, "content": content_parts}
)
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
)
else:
raise ValueError(f"Unexpected content type: {type(content)}")
return cleaned_messages
def _raw_gen(
@@ -98,22 +123,31 @@ class OpenAILLM(BaseLLM):
stream=False,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
messages = self._clean_messages_openai(messages)
# Convert max_tokens to max_completion_tokens for newer models
if "max_tokens" in kwargs:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
request_params = {
"model": model,
"messages": messages,
"stream": stream,
**kwargs,
}
if tools:
request_params["tools"] = tools
if response_format:
request_params["response_format"] = response_format
response = self.client.chat.completions.create(**request_params)
if tools:
response = self.client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
@@ -124,31 +158,102 @@ class OpenAILLM(BaseLLM):
stream=True,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
messages = self._clean_messages_openai(messages)
if tools:
response = self.client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
for line in response:
if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
yield line.choices[0].delta.content
elif len(line.choices) > 0:
yield line.choices[0]
# Convert max_tokens to max_completion_tokens for newer models
if "max_tokens" in kwargs:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
request_params = {
"model": model,
"messages": messages,
"stream": stream,
**kwargs,
}
if tools:
request_params["tools"] = tools
if response_format:
request_params["response_format"] = response_format
response = self.client.chat.completions.create(**request_params)
try:
for line in response:
if (
len(line.choices) > 0
and line.choices[0].delta.content is not None
and len(line.choices[0].delta.content) > 0
):
yield line.choices[0].delta.content
elif len(line.choices) > 0:
yield line.choices[0]
finally:
if hasattr(response, "close"):
response.close()
def _supports_tools(self):
return True
def _supports_structured_output(self):
return True
def prepare_structured_output_format(self, json_schema):
if not json_schema:
return None
try:
def add_additional_properties_false(schema_obj):
if isinstance(schema_obj, dict):
schema_copy = schema_obj.copy()
if schema_copy.get("type") == "object":
schema_copy["additionalProperties"] = False
# Ensure 'required' includes all properties for OpenAI strict mode
if "properties" in schema_copy:
schema_copy["required"] = list(
schema_copy["properties"].keys()
)
for key, value in schema_copy.items():
if key == "properties" and isinstance(value, dict):
schema_copy[key] = {
prop_name: add_additional_properties_false(prop_schema)
for prop_name, prop_schema in value.items()
}
elif key == "items" and isinstance(value, dict):
schema_copy[key] = add_additional_properties_false(value)
elif key in ["anyOf", "oneOf", "allOf"] and isinstance(
value, list
):
schema_copy[key] = [
add_additional_properties_false(sub_schema)
for sub_schema in value
]
return schema_copy
return schema_obj
processed_schema = add_additional_properties_false(json_schema)
result = {
"type": "json_schema",
"json_schema": {
"name": processed_schema.get("name", "response"),
"description": processed_schema.get(
"description", "Structured response"
),
"schema": processed_schema,
"strict": True,
},
}
return result
except Exception as e:
logging.error(f"Error preparing structured output format: {e}")
return None
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by OpenAI for file uploads.
@@ -157,12 +262,12 @@ class OpenAILLM(BaseLLM):
list: List of supported MIME types
"""
return [
'application/pdf',
'image/png',
'image/jpeg',
'image/jpg',
'image/webp',
'image/gif'
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
def prepare_messages_with_attachments(self, messages, attachments=None):
@@ -178,21 +283,19 @@ class OpenAILLM(BaseLLM):
"""
if not attachments:
return messages
prepared_messages = messages.copy()
# Find the user message to attach file_id to the last one
user_message_index = None
for i in range(len(prepared_messages) - 1, -1, -1):
if prepared_messages[i].get("role") == "user":
user_message_index = i
break
if user_message_index is None:
user_message = {"role": "user", "content": []}
prepared_messages.append(user_message)
user_message_index = len(prepared_messages) - 1
if isinstance(prepared_messages[user_message_index].get("content"), str):
text_content = prepared_messages[user_message_index]["content"]
prepared_messages[user_message_index]["content"] = [
@@ -200,42 +303,48 @@ class OpenAILLM(BaseLLM):
]
elif not isinstance(prepared_messages[user_message_index].get("content"), list):
prepared_messages[user_message_index]["content"] = []
for attachment in attachments:
mime_type = attachment.get('mime_type')
mime_type = attachment.get("mime_type")
if mime_type and mime_type.startswith('image/'):
if mime_type and mime_type.startswith("image/"):
try:
base64_image = self._get_base64_image(attachment)
prepared_messages[user_message_index]["content"].append({
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{base64_image}"
prepared_messages[user_message_index]["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{base64_image}"
},
}
})
)
except Exception as e:
logging.error(f"Error processing image attachment: {e}", exc_info=True)
if 'content' in attachment:
prepared_messages[user_message_index]["content"].append({
"type": "text",
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]"
})
logging.error(
f"Error processing image attachment: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]",
}
)
# Handle PDFs using the file API
elif mime_type == 'application/pdf':
elif mime_type == "application/pdf":
try:
file_id = self._upload_file_to_openai(attachment)
prepared_messages[user_message_index]["content"].append({
"type": "file",
"file": {"file_id": file_id}
})
prepared_messages[user_message_index]["content"].append(
{"type": "file", "file": {"file_id": file_id}}
)
except Exception as e:
logging.error(f"Error uploading PDF to OpenAI: {e}", exc_info=True)
if 'content' in attachment:
prepared_messages[user_message_index]["content"].append({
"type": "text",
"text": f"File content:\n\n{attachment['content']}"
})
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"File content:\n\n{attachment['content']}",
}
)
return prepared_messages
def _get_base64_image(self, attachment):
@@ -248,13 +357,12 @@ class OpenAILLM(BaseLLM):
Returns:
str: Base64-encoded image data.
"""
file_path = attachment.get('path')
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
try:
with self.storage.get_file(file_path) as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
return base64.b64encode(image_file.read()).decode("utf-8")
except FileNotFoundError:
raise FileNotFoundError(f"File not found: {file_path}")
@@ -273,33 +381,29 @@ class OpenAILLM(BaseLLM):
"""
import logging
if 'openai_file_id' in attachment:
return attachment['openai_file_id']
file_path = attachment.get('path')
if "openai_file_id" in attachment:
return attachment["openai_file_id"]
file_path = attachment.get("path")
if not self.storage.file_exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
try:
file_id = self.storage.process_file(
file_path,
lambda local_path, **kwargs: self.client.files.create(
file=open(local_path, 'rb'),
purpose="assistants"
).id
file=open(local_path, "rb"), purpose="assistants"
).id,
)
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
attachments_collection = db["attachments"]
if '_id' in attachment:
if "_id" in attachment:
attachments_collection.update_one(
{"_id": attachment['_id']},
{"$set": {"openai_file_id": file_id}}
{"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
)
return file_id
except Exception as e:
logging.error(f"Error uploading file to OpenAI: {e}", exc_info=True)
@@ -308,9 +412,7 @@ class OpenAILLM(BaseLLM):
class AzureOpenAILLM(OpenAILLM):
def __init__(
self, api_key, user_api_key, *args, **kwargs
):
def __init__(self, api_key, user_api_key, *args, **kwargs):
super().__init__(api_key)
self.api_base = (settings.OPENAI_API_BASE,)
@@ -321,5 +423,5 @@ class AzureOpenAILLM(OpenAILLM):
self.client = AzureOpenAI(
api_key=api_key,
api_version=settings.OPENAI_API_VERSION,
azure_endpoint=settings.OPENAI_API_BASE
azure_endpoint=settings.OPENAI_API_BASE,
)

View File

@@ -136,6 +136,8 @@ def _log_to_mongodb(
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
user_logs_collection = db["stack_logs"]
log_entry = {
"endpoint": endpoint,
@@ -147,6 +149,11 @@ def _log_to_mongodb(
"stacks": stacks,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
# clean up text fields to be no longer than 10000 characters
for key, value in log_entry.items():
if isinstance(value, str) and len(value) > 10000:
log_entry[key] = value[:10000]
user_logs_collection.insert_one(log_entry)
logging.debug(f"Logged activity to MongoDB: {activity_id}")

View File

@@ -32,16 +32,7 @@ class Chunker:
header, body = "", text # No header, treat entire text as body
return header, body
def combine_documents(self, doc: Document, next_doc: Document) -> Document:
combined_text = doc.text + " " + next_doc.text
combined_token_count = len(self.encoding.encode(combined_text))
new_doc = Document(
text=combined_text,
doc_id=doc.doc_id,
embedding=doc.embedding,
extra_info={**(doc.extra_info or {}), "token_count": combined_token_count}
)
return new_doc
def split_document(self, doc: Document) -> List[Document]:
split_docs = []
@@ -82,26 +73,11 @@ class Chunker:
processed_docs.append(doc)
i += 1
elif token_count < self.min_tokens:
if i + 1 < len(documents):
next_doc = documents[i + 1]
next_tokens = self.encoding.encode(next_doc.text)
if token_count + len(next_tokens) <= self.max_tokens:
# Combine small documents
combined_doc = self.combine_documents(doc, next_doc)
processed_docs.append(combined_doc)
i += 2
else:
# Keep the small document as is if adding next_doc would exceed max_tokens
doc.extra_info = doc.extra_info or {}
doc.extra_info["token_count"] = token_count
processed_docs.append(doc)
i += 1
else:
# No next document to combine with; add the small document as is
doc.extra_info = doc.extra_info or {}
doc.extra_info["token_count"] = token_count
processed_docs.append(doc)
i += 1
doc.extra_info = doc.extra_info or {}
doc.extra_info["token_count"] = token_count
processed_docs.append(doc)
i += 1
else:
# Split large documents
processed_docs.extend(self.split_document(doc))

View File

@@ -0,0 +1,18 @@
"""
External knowledge base connectors for DocsGPT.
This module contains connectors for external knowledge bases and document storage systems
that require authentication and specialized handling, separate from simple web scrapers.
"""
from .base import BaseConnectorAuth, BaseConnectorLoader
from .connector_creator import ConnectorCreator
from .google_drive import GoogleDriveAuth, GoogleDriveLoader
__all__ = [
'BaseConnectorAuth',
'BaseConnectorLoader',
'ConnectorCreator',
'GoogleDriveAuth',
'GoogleDriveLoader'
]

View File

@@ -0,0 +1,129 @@
"""
Base classes for external knowledge base connectors.
This module provides minimal abstract base classes that define the essential
interface for external knowledge base connectors.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from application.parser.schema.base import Document
class BaseConnectorAuth(ABC):
"""
Abstract base class for connector authentication.
Defines the minimal interface that all connector authentication
implementations must follow.
"""
@abstractmethod
def get_authorization_url(self, state: Optional[str] = None) -> str:
"""
Generate authorization URL for OAuth flows.
Args:
state: Optional state parameter for CSRF protection
Returns:
Authorization URL
"""
pass
@abstractmethod
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
"""
Exchange authorization code for access tokens.
Args:
authorization_code: Authorization code from OAuth callback
Returns:
Dictionary containing token information
"""
pass
@abstractmethod
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
"""
Refresh an expired access token.
Args:
refresh_token: Refresh token
Returns:
Dictionary containing refreshed token information
"""
pass
@abstractmethod
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
"""
Check if a token is expired.
Args:
token_info: Token information dictionary
Returns:
True if token is expired, False otherwise
"""
pass
class BaseConnectorLoader(ABC):
"""
Abstract base class for connector loaders.
Defines the minimal interface that all connector loader
implementations must follow.
"""
@abstractmethod
def __init__(self, session_token: str):
"""
Initialize the connector loader.
Args:
session_token: Authentication session token
"""
pass
@abstractmethod
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
"""
Load documents from the external knowledge base.
Args:
inputs: Configuration dictionary containing:
- file_ids: Optional list of specific file IDs to load
- folder_ids: Optional list of folder IDs to browse/download
- limit: Maximum number of items to return
- list_only: If True, return metadata without content
- recursive: Whether to recursively process folders
Returns:
List of Document objects
"""
pass
@abstractmethod
def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
"""
Download files/folders to a local directory.
Args:
local_dir: Local directory path to download files to
source_config: Configuration for what to download
Returns:
Dictionary containing download results:
- files_downloaded: Number of files downloaded
- directory_path: Path where files were downloaded
- empty_result: Whether no files were downloaded
- source_type: Type of connector
- config_used: Configuration that was used
- error: Error message if download failed (optional)
"""
pass

View File

@@ -0,0 +1,81 @@
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
class ConnectorCreator:
"""
Factory class for creating external knowledge base connectors and auth providers.
These are different from remote loaders as they typically require
authentication and connect to external document storage systems.
"""
connectors = {
"google_drive": GoogleDriveLoader,
}
auth_providers = {
"google_drive": GoogleDriveAuth,
}
@classmethod
def create_connector(cls, connector_type, *args, **kwargs):
"""
Create a connector instance for the specified type.
Args:
connector_type: Type of connector to create (e.g., 'google_drive')
*args, **kwargs: Arguments to pass to the connector constructor
Returns:
Connector instance
Raises:
ValueError: If connector type is not supported
"""
connector_class = cls.connectors.get(connector_type.lower())
if not connector_class:
raise ValueError(f"No connector class found for type {connector_type}")
return connector_class(*args, **kwargs)
@classmethod
def create_auth(cls, connector_type):
"""
Create an auth provider instance for the specified connector type.
Args:
connector_type: Type of connector auth to create (e.g., 'google_drive')
Returns:
Auth provider instance
Raises:
ValueError: If connector type is not supported for auth
"""
auth_class = cls.auth_providers.get(connector_type.lower())
if not auth_class:
raise ValueError(f"No auth class found for type {connector_type}")
return auth_class()
@classmethod
def get_supported_connectors(cls):
"""
Get list of supported connector types.
Returns:
List of supported connector type strings
"""
return list(cls.connectors.keys())
@classmethod
def is_supported(cls, connector_type):
"""
Check if a connector type is supported.
Args:
connector_type: Type of connector to check
Returns:
True if supported, False otherwise
"""
return connector_type.lower() in cls.connectors

View File

@@ -0,0 +1,10 @@
"""
Google Drive connector for DocsGPT.
This module provides authentication and document loading capabilities for Google Drive.
"""
from .auth import GoogleDriveAuth
from .loader import GoogleDriveLoader
__all__ = ['GoogleDriveAuth', 'GoogleDriveLoader']

View File

@@ -0,0 +1,267 @@
import logging
import datetime
from typing import Optional, Dict, Any
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import Flow
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from application.core.settings import settings
from application.parser.connectors.base import BaseConnectorAuth
class GoogleDriveAuth(BaseConnectorAuth):
"""
Handles Google OAuth 2.0 authentication for Google Drive access.
"""
SCOPES = [
'https://www.googleapis.com/auth/drive.file'
]
def __init__(self):
self.client_id = settings.GOOGLE_CLIENT_ID
self.client_secret = settings.GOOGLE_CLIENT_SECRET
self.redirect_uri = f"{settings.CONNECTOR_REDIRECT_BASE_URI}"
if not self.client_id or not self.client_secret:
raise ValueError("Google OAuth credentials not configured. Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET in settings.")
def get_authorization_url(self, state: Optional[str] = None) -> str:
try:
flow = Flow.from_client_config(
{
"web": {
"client_id": self.client_id,
"client_secret": self.client_secret,
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"redirect_uris": [self.redirect_uri]
}
},
scopes=self.SCOPES
)
flow.redirect_uri = self.redirect_uri
authorization_url, _ = flow.authorization_url(
access_type='offline',
prompt='consent',
include_granted_scopes='false',
state=state
)
return authorization_url
except Exception as e:
logging.error(f"Error generating authorization URL: {e}")
raise
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
try:
if not authorization_code:
raise ValueError("Authorization code is required")
flow = Flow.from_client_config(
{
"web": {
"client_id": self.client_id,
"client_secret": self.client_secret,
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"redirect_uris": [self.redirect_uri]
}
},
scopes=self.SCOPES
)
flow.redirect_uri = self.redirect_uri
flow.fetch_token(code=authorization_code)
credentials = flow.credentials
if not credentials.refresh_token:
logging.warning("OAuth flow did not return a refresh_token.")
if not credentials.token:
raise ValueError("OAuth flow did not return an access token")
if not credentials.token_uri:
credentials.token_uri = "https://oauth2.googleapis.com/token"
if not credentials.client_id:
credentials.client_id = self.client_id
if not credentials.client_secret:
credentials.client_secret = self.client_secret
if not credentials.refresh_token:
raise ValueError(
"No refresh token received. This typically happens when offline access wasn't granted. "
)
return {
'access_token': credentials.token,
'refresh_token': credentials.refresh_token,
'token_uri': credentials.token_uri,
'client_id': credentials.client_id,
'client_secret': credentials.client_secret,
'scopes': credentials.scopes,
'expiry': credentials.expiry.isoformat() if credentials.expiry else None
}
except Exception as e:
logging.error(f"Error exchanging code for tokens: {e}")
raise
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
try:
if not refresh_token:
raise ValueError("Refresh token is required")
credentials = Credentials(
token=None,
refresh_token=refresh_token,
token_uri="https://oauth2.googleapis.com/token",
client_id=self.client_id,
client_secret=self.client_secret
)
from google.auth.transport.requests import Request
credentials.refresh(Request())
return {
'access_token': credentials.token,
'refresh_token': refresh_token,
'token_uri': credentials.token_uri,
'client_id': credentials.client_id,
'client_secret': credentials.client_secret,
'scopes': credentials.scopes,
'expiry': credentials.expiry.isoformat() if credentials.expiry else None
}
except Exception as e:
logging.error(f"Error refreshing access token: {e}", exc_info=True)
raise
def create_credentials_from_token_info(self, token_info: Dict[str, Any]) -> Credentials:
from application.core.settings import settings
access_token = token_info.get('access_token')
if not access_token:
raise ValueError("No access token found in token_info")
credentials = Credentials(
token=access_token,
refresh_token=token_info.get('refresh_token'),
token_uri= 'https://oauth2.googleapis.com/token',
client_id=settings.GOOGLE_CLIENT_ID,
client_secret=settings.GOOGLE_CLIENT_SECRET,
scopes=token_info.get('scopes', ['https://www.googleapis.com/auth/drive.readonly'])
)
if not credentials.token:
raise ValueError("Credentials created without valid access token")
return credentials
def build_drive_service(self, credentials: Credentials):
try:
if not credentials:
raise ValueError("No credentials provided")
if not credentials.token and not credentials.refresh_token:
raise ValueError("No access token or refresh token available. User must re-authorize with offline access.")
needs_refresh = credentials.expired or not credentials.token
if needs_refresh:
if credentials.refresh_token:
try:
from google.auth.transport.requests import Request
credentials.refresh(Request())
except Exception as refresh_error:
raise ValueError(f"Failed to refresh credentials: {refresh_error}")
else:
raise ValueError("No access token or refresh token available. User must re-authorize with offline access.")
return build('drive', 'v3', credentials=credentials)
except HttpError as e:
raise ValueError(f"Failed to build Google Drive service: HTTP {e.resp.status}")
except Exception as e:
raise ValueError(f"Failed to build Google Drive service: {str(e)}")
def is_token_expired(self, token_info):
if 'expiry' in token_info and token_info['expiry']:
try:
from dateutil import parser
# Google Drive provides timezone-aware ISO8601 dates
expiry_dt = parser.parse(token_info['expiry'])
current_time = datetime.datetime.now(datetime.timezone.utc)
return current_time >= expiry_dt - datetime.timedelta(seconds=60)
except Exception:
return True
if 'access_token' in token_info and token_info['access_token']:
return False
return True
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
try:
from application.core.mongo_db import MongoDB
from application.core.settings import settings
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sessions_collection = db["connector_sessions"]
session = sessions_collection.find_one({"session_token": session_token})
if not session:
raise ValueError(f"Invalid session token: {session_token}")
if "token_info" not in session:
raise ValueError("Session missing token information")
token_info = session["token_info"]
if not token_info:
raise ValueError("Invalid token information")
required_fields = ["access_token", "refresh_token"]
missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
if missing_fields:
raise ValueError(f"Missing required token fields: {missing_fields}")
if 'client_id' not in token_info:
token_info['client_id'] = settings.GOOGLE_CLIENT_ID
if 'client_secret' not in token_info:
token_info['client_secret'] = settings.GOOGLE_CLIENT_SECRET
if 'token_uri' not in token_info:
token_info['token_uri'] = 'https://oauth2.googleapis.com/token'
return token_info
except Exception as e:
raise ValueError(f"Failed to retrieve Google Drive token information: {str(e)}")
def validate_credentials(self, credentials: Credentials) -> bool:
"""
Validate Google Drive credentials by making a test API call.
Args:
credentials: Google credentials object
Returns:
True if credentials are valid, False otherwise
"""
try:
service = self.build_drive_service(credentials)
service.about().get(fields="user").execute()
return True
except HttpError as e:
logging.error(f"HTTP error validating credentials: {e}")
return False
except Exception as e:
logging.error(f"Error validating credentials: {e}")
return False

View File

@@ -0,0 +1,559 @@
"""
Google Drive loader for DocsGPT.
Loads documents from Google Drive using Google Drive API.
"""
import io
import logging
import os
from typing import List, Dict, Any, Optional
from googleapiclient.http import MediaIoBaseDownload
from googleapiclient.errors import HttpError
from application.parser.connectors.base import BaseConnectorLoader
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
from application.parser.schema.base import Document
class GoogleDriveLoader(BaseConnectorLoader):
SUPPORTED_MIME_TYPES = {
'application/pdf': '.pdf',
'application/vnd.google-apps.document': '.docx',
'application/vnd.google-apps.presentation': '.pptx',
'application/vnd.google-apps.spreadsheet': '.xlsx',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
'application/msword': '.doc',
'application/vnd.ms-powerpoint': '.ppt',
'application/vnd.ms-excel': '.xls',
'text/plain': '.txt',
'text/csv': '.csv',
'text/html': '.html',
'text/markdown': '.md',
'text/x-rst': '.rst',
'application/json': '.json',
'application/epub+zip': '.epub',
'application/rtf': '.rtf',
'image/jpeg': '.jpg',
'image/jpg': '.jpg',
'image/png': '.png',
}
EXPORT_FORMATS = {
'application/vnd.google-apps.document': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'application/vnd.google-apps.presentation': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
'application/vnd.google-apps.spreadsheet': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
}
def __init__(self, session_token: str):
self.auth = GoogleDriveAuth()
self.session_token = session_token
token_info = self.auth.get_token_info_from_session(session_token)
self.credentials = self.auth.create_credentials_from_token_info(token_info)
try:
self.service = self.auth.build_drive_service(self.credentials)
except Exception as e:
logging.warning(f"Could not build Google Drive service: {e}")
self.service = None
self.next_page_token = None
def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]:
try:
file_id = file_metadata.get('id')
file_name = file_metadata.get('name', 'Unknown')
mime_type = file_metadata.get('mimeType', 'application/octet-stream')
if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
return None
if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}")
return None
# Google Drive provides timezone-aware ISO8601 dates
doc_metadata = {
'file_name': file_name,
'mime_type': mime_type,
'size': file_metadata.get('size', None),
'created_time': file_metadata.get('createdTime'),
'modified_time': file_metadata.get('modifiedTime'),
'parents': file_metadata.get('parents', []),
'source': 'google_drive'
}
if not load_content:
return Document(
text="",
doc_id=file_id,
extra_info=doc_metadata
)
content = self._download_file_content(file_id, mime_type)
if content is None:
logging.warning(f"Could not load content for file {file_name} ({file_id})")
return None
return Document(
text=content,
doc_id=file_id,
extra_info=doc_metadata
)
except Exception as e:
logging.error(f"Error processing file: {e}")
return None
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
session_token = inputs.get('session_token')
if session_token and session_token != self.session_token:
logging.warning("Session token in inputs differs from loader's session token. Using loader's session token.")
self.config = inputs
try:
documents: List[Document] = []
folder_id = inputs.get('folder_id')
file_ids = inputs.get('file_ids', [])
limit = inputs.get('limit', 100)
list_only = inputs.get('list_only', False)
load_content = not list_only
page_token = inputs.get('page_token')
search_query = inputs.get('search_query')
self.next_page_token = None
if file_ids:
# Specific files requested: load them
for file_id in file_ids:
try:
doc = self._load_file_by_id(file_id, load_content=load_content)
if doc:
if not search_query or (
search_query.lower() in doc.extra_info.get('file_name', '').lower()
):
documents.append(doc)
elif hasattr(self, '_credential_refreshed') and self._credential_refreshed:
self._credential_refreshed = False
logging.info(f"Retrying load of file {file_id} after credential refresh")
doc = self._load_file_by_id(file_id, load_content=load_content)
if doc and (
not search_query or
search_query.lower() in doc.extra_info.get('file_name', '').lower()
):
documents.append(doc)
except Exception as e:
logging.error(f"Error loading file {file_id}: {e}")
continue
else:
# Browsing mode: list immediate children of provided folder or root
parent_id = folder_id if folder_id else 'root'
documents = self._list_items_in_parent(
parent_id,
limit=limit,
load_content=load_content,
page_token=page_token,
search_query=search_query
)
logging.info(f"Loaded {len(documents)} documents from Google Drive")
return documents
except Exception as e:
logging.error(f"Error loading data from Google Drive: {e}", exc_info=True)
raise
def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]:
self._ensure_service()
try:
file_metadata = self.service.files().get(
fileId=file_id,
fields='id,name,mimeType,size,createdTime,modifiedTime,parents'
).execute()
return self._process_file(file_metadata, load_content=load_content)
except HttpError as e:
logging.error(f"HTTP error loading file {file_id}: {e.resp.status} - {e.content}")
if e.resp.status in [401, 403]:
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
try:
from google.auth.transport.requests import Request
self.credentials.refresh(Request())
self._ensure_service()
return None
except Exception as refresh_error:
raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}")
else:
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
return None
except Exception as e:
logging.error(f"Error loading file {file_id}: {e}")
return None
def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
self._ensure_service()
documents: List[Document] = []
try:
query = f"'{parent_id}' in parents and trashed=false"
if search_query:
safe_search = search_query.replace("'", "\\'")
query += f" and name contains '{safe_search}'"
next_token_out: Optional[str] = None
while True:
page_size = 100
if limit:
remaining = max(0, limit - len(documents))
if remaining == 0:
break
page_size = min(100, remaining)
results = self.service.files().list(
q=query,
fields='nextPageToken,files(id,name,mimeType,size,createdTime,modifiedTime,parents)',
pageToken=page_token,
pageSize=page_size,
orderBy='name'
).execute()
items = results.get('files', [])
for item in items:
mime_type = item.get('mimeType')
if mime_type == 'application/vnd.google-apps.folder':
doc_metadata = {
'file_name': item.get('name', 'Unknown'),
'mime_type': mime_type,
'size': item.get('size', None),
'created_time': item.get('createdTime'),
'modified_time': item.get('modifiedTime'),
'parents': item.get('parents', []),
'source': 'google_drive',
'is_folder': True
}
documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata))
else:
doc = self._process_file(item, load_content=load_content)
if doc:
documents.append(doc)
if limit and len(documents) >= limit:
self.next_page_token = results.get('nextPageToken')
return documents
page_token = results.get('nextPageToken')
next_token_out = page_token
if not page_token:
break
self.next_page_token = next_token_out
return documents
except Exception as e:
logging.error(f"Error listing items under parent {parent_id}: {e}")
return documents
def _download_file_content(self, file_id: str, mime_type: str) -> Optional[str]:
if not self.credentials.token:
logging.warning("No access token in credentials, attempting to refresh")
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
try:
from google.auth.transport.requests import Request
self.credentials.refresh(Request())
logging.info("Credentials refreshed successfully")
self._ensure_service()
except Exception as e:
logging.error(f"Failed to refresh credentials: {e}")
raise ValueError("Authentication failed and cannot be refreshed: missing or invalid refresh_token")
else:
logging.error("No access token and no refresh_token available")
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
if self.credentials.expired:
logging.warning("Credentials are expired, attempting to refresh")
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
try:
from google.auth.transport.requests import Request
self.credentials.refresh(Request())
logging.info("Credentials refreshed successfully")
self._ensure_service()
except Exception as e:
logging.error(f"Failed to refresh expired credentials: {e}")
raise ValueError("Authentication failed and cannot be refreshed: expired credentials")
else:
logging.error("Credentials expired and no refresh_token available")
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
try:
if mime_type in self.EXPORT_FORMATS:
export_mime_type = self.EXPORT_FORMATS[mime_type]
request = self.service.files().export_media(
fileId=file_id,
mimeType=export_mime_type
)
else:
request = self.service.files().get_media(fileId=file_id)
file_io = io.BytesIO()
downloader = MediaIoBaseDownload(file_io, request)
done = False
while done is False:
try:
_, done = downloader.next_chunk()
except HttpError as e:
logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")
return None
except Exception as e:
logging.error(f"Error during download of file {file_id}: {e}")
return None
content_bytes = file_io.getvalue()
try:
content = content_bytes.decode('utf-8')
except UnicodeDecodeError:
try:
content = content_bytes.decode('latin-1')
except UnicodeDecodeError:
logging.error(f"Could not decode file {file_id} as text")
return None
return content
except HttpError as e:
logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")
if e.resp.status in [401, 403]:
logging.error(f"Authentication error downloading file {file_id}")
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
logging.info(f"Attempting to refresh credentials for file {file_id}")
try:
from google.auth.transport.requests import Request
self.credentials.refresh(Request())
logging.info("Credentials refreshed successfully")
self._credential_refreshed = True
self._ensure_service()
return None
except Exception as refresh_error:
logging.error(f"Error refreshing credentials: {refresh_error}")
raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}")
else:
logging.error("Cannot refresh credentials: missing refresh_token")
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
return None
except Exception as e:
logging.error(f"Error downloading file {file_id}: {e}")
return None
def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool:
try:
self._ensure_service()
return self._download_single_file(file_id, local_dir)
except Exception as e:
logging.error(f"Error downloading file {file_id}: {e}", exc_info=True)
return False
def _ensure_service(self):
if not self.service:
try:
self.service = self.auth.build_drive_service(self.credentials)
except Exception as e:
raise ValueError(f"Cannot access Google Drive: {e}")
def _download_single_file(self, file_id: str, local_dir: str) -> bool:
file_metadata = self.service.files().get(
fileId=file_id,
fields='name,mimeType'
).execute()
file_name = file_metadata['name']
mime_type = file_metadata['mimeType']
if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
return False
os.makedirs(local_dir, exist_ok=True)
full_path = os.path.join(local_dir, file_name)
if mime_type in self.EXPORT_FORMATS:
export_mime_type = self.EXPORT_FORMATS[mime_type]
request = self.service.files().export_media(
fileId=file_id,
mimeType=export_mime_type
)
extension = self._get_extension_for_mime_type(export_mime_type)
if not full_path.endswith(extension):
full_path += extension
else:
request = self.service.files().get_media(fileId=file_id)
with open(full_path, 'wb') as f:
downloader = MediaIoBaseDownload(f, request)
done = False
while not done:
_, done = downloader.next_chunk()
return True
def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
files_downloaded = 0
try:
os.makedirs(local_dir, exist_ok=True)
query = f"'{folder_id}' in parents and trashed=false"
page_token = None
while True:
results = self.service.files().list(
q=query,
fields='nextPageToken, files(id, name, mimeType)',
pageToken=page_token,
pageSize=1000
).execute()
items = results.get('files', [])
logging.info(f"Found {len(items)} items in folder {folder_id}")
for item in items:
item_name = item['name']
item_id = item['id']
mime_type = item['mimeType']
if mime_type == 'application/vnd.google-apps.folder':
if recursive:
# Create subfolder and recurse
subfolder_path = os.path.join(local_dir, item_name)
os.makedirs(subfolder_path, exist_ok=True)
subfolder_files = self._download_folder_recursive(
item_id,
subfolder_path,
recursive
)
files_downloaded += subfolder_files
logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}")
else:
# Download file
success = self._download_single_file(item_id, local_dir)
if success:
files_downloaded += 1
logging.info(f"Downloaded file: {item_name}")
else:
logging.warning(f"Failed to download file: {item_name}")
page_token = results.get('nextPageToken')
if not page_token:
break
return files_downloaded
except Exception as e:
logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True)
return files_downloaded
def _get_extension_for_mime_type(self, mime_type: str) -> str:
extensions = {
'application/pdf': '.pdf',
'text/plain': '.txt',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
'text/html': '.html',
'text/markdown': '.md',
}
return extensions.get(mime_type, '.bin')
def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
try:
self._ensure_service()
return self._download_folder_recursive(folder_id, local_dir, recursive)
except Exception as e:
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
return 0
def download_to_directory(self, local_dir: str, source_config: dict = None) -> dict:
if source_config is None:
source_config = {}
config = source_config if source_config else getattr(self, 'config', {})
files_downloaded = 0
try:
folder_ids = config.get('folder_ids', [])
file_ids = config.get('file_ids', [])
recursive = config.get('recursive', True)
self._ensure_service()
if file_ids:
if isinstance(file_ids, str):
file_ids = [file_ids]
for file_id in file_ids:
if self._download_file_to_directory(file_id, local_dir):
files_downloaded += 1
# Process folders
if folder_ids:
if isinstance(folder_ids, str):
folder_ids = [folder_ids]
for folder_id in folder_ids:
try:
folder_metadata = self.service.files().get(
fileId=folder_id,
fields='name'
).execute()
folder_name = folder_metadata.get('name', '')
folder_path = os.path.join(local_dir, folder_name)
os.makedirs(folder_path, exist_ok=True)
folder_files = self._download_folder_recursive(
folder_id,
folder_path,
recursive
)
files_downloaded += folder_files
logging.info(f"Downloaded {folder_files} files from folder {folder_name}")
except Exception as e:
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
if not file_ids and not folder_ids:
raise ValueError("No folder_ids or file_ids provided for download")
return {
"files_downloaded": files_downloaded,
"directory_path": local_dir,
"empty_result": files_downloaded == 0,
"source_type": "google_drive",
"config_used": config
}
except Exception as e:
return {
"files_downloaded": files_downloaded,
"directory_path": local_dir,
"empty_result": True,
"source_type": "google_drive",
"config_used": config,
"error": str(e)
}

View File

@@ -1,21 +1,43 @@
import os
import logging
from typing import List, Any
from retry import retry
from tqdm import tqdm
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
@retry(tries=10, delay=60)
def add_text_to_store_with_retry(store, doc, source_id):
def sanitize_content(content: str) -> str:
"""
Add a document's text and metadata to the vector store with retry logic.
Remove NUL characters that can cause vector store ingestion to fail.
Args:
content (str): Raw content that may contain NUL characters
Returns:
str: Sanitized content with NUL characters removed
"""
if not content:
return content
return content.replace('\x00', '')
@retry(tries=10, delay=60)
def add_text_to_store_with_retry(store: Any, doc: Any, source_id: str) -> None:
"""Add a document's text and metadata to the vector store with retry logic.
Args:
store: The vector store object.
doc: The document to be added.
source_id: Unique identifier for the source.
Raises:
Exception: If document addition fails after all retry attempts.
"""
try:
# Sanitize content to remove NUL characters that cause ingestion failures
doc.page_content = sanitize_content(doc.page_content)
doc.metadata["source_id"] = str(source_id)
store.add_texts([doc.page_content], metadatas=[doc.metadata])
except Exception as e:
@@ -23,18 +45,21 @@ def add_text_to_store_with_retry(store, doc, source_id):
raise
def embed_and_store_documents(docs, folder_name, source_id, task_status):
"""
Embeds documents and stores them in a vector store.
def embed_and_store_documents(docs: List[Any], folder_name: str, source_id: str, task_status: Any) -> None:
"""Embeds documents and stores them in a vector store.
Args:
docs (list): List of documents to be embedded and stored.
folder_name (str): Directory to save the vector store.
source_id (str): Unique identifier for the source.
docs: List of documents to be embedded and stored.
folder_name: Directory to save the vector store.
source_id: Unique identifier for the source.
task_status: Task state manager for progress updates.
Returns:
None
Raises:
OSError: If unable to create folder or save vector store.
Exception: If vector store creation or document embedding fails.
"""
# Ensure the folder exists
if not os.path.exists(folder_name):
@@ -46,7 +71,7 @@ def embed_and_store_documents(docs, folder_name, source_id, task_status):
store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
docs_init=docs_init,
source_id=folder_name,
source_id=source_id,
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
)
else:
@@ -77,10 +102,21 @@ def embed_and_store_documents(docs, folder_name, source_id, task_status):
except Exception as e:
logging.error(f"Error embedding document {idx}: {e}", exc_info=True)
logging.info(f"Saving progress at document {idx} out of {total_docs}")
store.save_local(folder_name)
try:
store.save_local(folder_name)
logging.info("Progress saved successfully")
except Exception as save_error:
logging.error(f"CRITICAL: Failed to save progress: {save_error}", exc_info=True)
# Continue without breaking to attempt final save
break
# Save the vector store
if settings.VECTOR_STORE == "faiss":
store.save_local(folder_name)
logging.info("Vector store saved successfully.")
try:
store.save_local(folder_name)
logging.info("Vector store saved successfully.")
except Exception as e:
logging.error(f"CRITICAL: Failed to save final vector store: {e}", exc_info=True)
raise OSError(f"Unable to save vector store to {folder_name}: {e}") from e
else:
logging.info("Vector store saved successfully.")

View File

@@ -15,6 +15,7 @@ from application.parser.file.json_parser import JSONParser
from application.parser.file.pptx_parser import PPTXParser
from application.parser.file.image_parser import ImageParser
from application.parser.schema.base import Document
from application.utils import num_tokens_from_string
DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = {
".pdf": PDFParser(),
@@ -141,11 +142,12 @@ class SimpleDirectoryReader(BaseReader):
Returns:
List[Document]: A list of documents.
"""
data: Union[str, List[str]] = ""
data_list: List[str] = []
metadata_list = []
self.file_token_counts = {}
for input_file in self.input_files:
if input_file.suffix in self.file_extractor:
parser = self.file_extractor[input_file.suffix]
@@ -156,24 +158,48 @@ class SimpleDirectoryReader(BaseReader):
# do standard read
with open(input_file, "r", errors=self.errors) as f:
data = f.read()
# Prepare metadata for this file
if self.file_metadata is not None:
file_metadata = self.file_metadata(input_file.name)
# Calculate token count for this file
if isinstance(data, List):
file_tokens = sum(num_tokens_from_string(str(d)) for d in data)
else:
# Provide a default empty metadata
file_metadata = {'title': '', 'store': ''}
# TODO: Find a case with no metadata and check if breaks anything
file_tokens = num_tokens_from_string(str(data))
full_path = str(input_file.resolve())
self.file_token_counts[full_path] = file_tokens
base_metadata = {
'title': input_file.name,
'token_count': file_tokens,
}
if hasattr(self, 'input_dir'):
try:
relative_path = str(input_file.relative_to(self.input_dir))
base_metadata['source'] = relative_path
except ValueError:
base_metadata['source'] = str(input_file)
else:
base_metadata['source'] = str(input_file)
if self.file_metadata is not None:
custom_metadata = self.file_metadata(input_file.name)
base_metadata.update(custom_metadata)
if isinstance(data, List):
# Extend data_list with each item in the data list
data_list.extend([str(d) for d in data])
# For each item in the data list, add the file's metadata to metadata_list
metadata_list.extend([file_metadata for _ in data])
metadata_list.extend([base_metadata for _ in data])
else:
# Add the single piece of data to data_list
data_list.append(str(data))
# Add the file's metadata to metadata_list
metadata_list.append(file_metadata)
metadata_list.append(base_metadata)
# Build directory structure if input_dir is provided
if hasattr(self, 'input_dir'):
self.directory_structure = self.build_directory_structure(self.input_dir)
logging.info("Directory structure built successfully")
else:
self.directory_structure = {}
if concatenate:
return [Document("\n".join(data_list))]
@@ -181,3 +207,48 @@ class SimpleDirectoryReader(BaseReader):
return [Document(d, extra_info=m) for d, m in zip(data_list, metadata_list)]
else:
return [Document(d) for d in data_list]
def build_directory_structure(self, base_path):
"""Build a dictionary representing the directory structure.
Args:
base_path: The base path to start building the structure from.
Returns:
dict: A nested dictionary representing the directory structure.
"""
import mimetypes
def build_tree(path):
"""Helper function to recursively build the directory tree."""
result = {}
for item in path.iterdir():
if self.exclude_hidden and item.name.startswith('.'):
continue
if item.is_dir():
subtree = build_tree(item)
if subtree:
result[item.name] = subtree
else:
if self.required_exts is not None and item.suffix not in self.required_exts:
continue
full_path = str(item.resolve())
file_size_bytes = item.stat().st_size
mime_type = mimetypes.guess_type(item.name)[0] or "application/octet-stream"
file_info = {
"type": mime_type,
"size_bytes": file_size_bytes
}
if hasattr(self, 'file_token_counts') and full_path in self.file_token_counts:
file_info["token_count"] = self.file_token_counts[full_path]
result[item.name] = file_info
return result
return build_tree(Path(base_path))

View File

@@ -8,6 +8,7 @@ import requests
from typing import Dict, Union
from application.parser.file.base_parser import BaseParser
from application.core.settings import settings
class ImageParser(BaseParser):
@@ -18,10 +19,13 @@ class ImageParser(BaseParser):
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, list[str]]:
doc2md_service = "https://llm.arc53.com/doc2md"
# alternatively you can use local vision capable LLM
with open(file, "rb") as file_loaded:
files = {'file': file_loaded}
response = requests.post(doc2md_service, files=files)
data = response.json()["markdown"]
if settings.PARSE_IMAGE_REMOTE:
doc2md_service = "https://llm.arc53.com/doc2md"
# alternatively you can use local vision capable LLM
with open(file, "rb") as file_loaded:
files = {'file': file_loaded}
response = requests.post(doc2md_service, files=files)
data = response.json()["markdown"]
else:
data = ""
return data

View File

@@ -1,44 +1,135 @@
import base64
import requests
from typing import List
import time
from typing import List, Optional
from application.parser.remote.base import BaseRemote
from langchain_core.documents import Document
from application.parser.schema.base import Document
import mimetypes
from application.core.settings import settings
class GitHubLoader(BaseRemote):
def __init__(self):
self.access_token = None
self.access_token = settings.GITHUB_ACCESS_TOKEN
self.headers = {
"Authorization": f"token {self.access_token}"
} if self.access_token else {}
"Authorization": f"token {self.access_token}",
"Accept": "application/vnd.github.v3+json"
} if self.access_token else {
"Accept": "application/vnd.github.v3+json"
}
return
def fetch_file_content(self, repo_url: str, file_path: str) -> str:
def is_text_file(self, file_path: str) -> bool:
"""Determine if a file is a text file based on extension."""
# Common text file extensions
text_extensions = {
'.txt', '.md', '.markdown', '.rst', '.json', '.xml', '.yaml', '.yml',
'.py', '.js', '.ts', '.jsx', '.tsx', '.java', '.c', '.cpp', '.h', '.hpp',
'.cs', '.go', '.rs', '.rb', '.php', '.swift', '.kt', '.scala',
'.html', '.css', '.scss', '.sass', '.less',
'.sh', '.bash', '.zsh', '.fish',
'.sql', '.r', '.m', '.mat',
'.ini', '.cfg', '.conf', '.config', '.env',
'.gitignore', '.dockerignore', '.editorconfig',
'.log', '.csv', '.tsv'
}
# Get file extension
file_lower = file_path.lower()
for ext in text_extensions:
if file_lower.endswith(ext):
return True
# Also check MIME type
mime_type, _ = mimetypes.guess_type(file_path)
if mime_type and (mime_type.startswith("text") or mime_type in ["application/json", "application/xml"]):
return True
return False
def fetch_file_content(self, repo_url: str, file_path: str) -> Optional[str]:
"""Fetch file content. Returns None if file should be skipped (binary files or empty files)."""
url = f"https://api.github.com/repos/{repo_url}/contents/{file_path}"
response = requests.get(url, headers=self.headers)
response = self._make_request(url)
if response.status_code == 200:
content = response.json()
mime_type, _ = mimetypes.guess_type(file_path) # Guess the MIME type based on the file extension
content = response.json()
if content.get("encoding") == "base64":
if mime_type and mime_type.startswith("text"): # Handle only text files
try:
decoded_content = base64.b64decode(content["content"]).decode("utf-8")
return f"Filename: {file_path}\n\n{decoded_content}"
except Exception as e:
raise e
else:
return f"Filename: {file_path} is a binary file and was skipped."
if content.get("encoding") == "base64":
if self.is_text_file(file_path): # Handle only text files
try:
decoded_content = base64.b64decode(content["content"]).decode("utf-8").strip()
# Skip empty files
if not decoded_content:
return None
return decoded_content
except Exception:
# If decoding fails, it's probably a binary file
return None
else:
return f"Filename: {file_path}\n\n{content['content']}"
# Skip binary files by returning None
return None
else:
response.raise_for_status()
file_content = content['content'].strip()
# Skip empty files
if not file_content:
return None
return file_content
def _make_request(self, url: str, max_retries: int = 3) -> requests.Response:
"""Make a request with retry logic for rate limiting"""
for attempt in range(max_retries):
response = requests.get(url, headers=self.headers)
if response.status_code == 200:
return response
elif response.status_code == 403:
# Check if it's a rate limit issue
try:
error_data = response.json()
error_msg = error_data.get("message", "")
# Check rate limit headers
remaining = response.headers.get("X-RateLimit-Remaining", "unknown")
reset_time = response.headers.get("X-RateLimit-Reset", "unknown")
print(f"GitHub API 403 Error: {error_msg}")
print(f"Rate limit remaining: {remaining}, Reset time: {reset_time}")
if "rate limit" in error_msg.lower():
if attempt < max_retries - 1:
wait_time = 2 ** attempt # Exponential backoff
print(f"Rate limit hit, waiting {wait_time} seconds before retry...")
time.sleep(wait_time)
continue
# Provide helpful error message
if remaining == "0":
raise Exception(f"GitHub API rate limit exceeded. Please set GITHUB_ACCESS_TOKEN environment variable. Reset time: {reset_time}")
else:
raise Exception(f"GitHub API error: {error_msg}. This may require authentication - set GITHUB_ACCESS_TOKEN environment variable.")
except Exception as e:
if isinstance(e, Exception) and "GitHub API" in str(e):
raise
# If we can't parse the response, raise the original error
response.raise_for_status()
else:
response.raise_for_status()
return response
def fetch_repo_files(self, repo_url: str, path: str = "") -> List[str]:
url = f"https://api.github.com/repos/{repo_url}/contents/{path}"
response = requests.get(url, headers={**self.headers, "Accept": "application/vnd.github.v3.raw"})
response = self._make_request(url)
contents = response.json()
# Handle error responses from GitHub API
if isinstance(contents, dict) and "message" in contents:
raise Exception(f"GitHub API error: {contents.get('message')}")
# Ensure contents is a list
if not isinstance(contents, list):
raise TypeError(f"Expected list from GitHub API, got {type(contents).__name__}: {contents}")
files = []
for item in contents:
if item["type"] == "file":
@@ -53,6 +144,15 @@ class GitHubLoader(BaseRemote):
documents = []
for file_path in files:
content = self.fetch_file_content(repo_name, file_path)
documents.append(Document(page_content=content, metadata={"title": file_path,
"source": f"https://github.com/{repo_name}/blob/main/{file_path}"}))
# Skip binary files (content is None)
if content is None:
continue
documents.append(Document(
text=content,
doc_id=file_path,
extra_info={
"title": file_path,
"source": f"https://github.com/{repo_name}/blob/main/{file_path}"
}
))
return documents

View File

@@ -6,6 +6,16 @@ from application.parser.remote.github_loader import GitHubLoader
class RemoteCreator:
"""
Factory class for creating remote content loaders.
These loaders fetch content from remote web sources like URLs,
sitemaps, web crawlers, social media platforms, etc.
For external knowledge base connectors (like Google Drive),
use ConnectorCreator instead.
"""
loaders = {
"url": WebLoader,
"sitemap": SitemapLoader,
@@ -18,5 +28,5 @@ class RemoteCreator:
def create_loader(cls, type, *args, **kwargs):
loader_class = cls.loaders.get(type.lower())
if not loader_class:
raise ValueError(f"No LLM class found for type {type}")
raise ValueError(f"No loader class found for type {type}")
return loader_class(*args, **kwargs)

Some files were not shown because too many files have changed in this diff Show More