Compare commits

...

37 Commits

Author SHA1 Message Date
dependabot[bot]
32c70081c1 chore(deps): bump h11
Bumps the pip group with 1 update in the /extensions/slack-bot directory: [h11](https://github.com/python-hyper/h11).


Updates `h11` from 0.14.0 to 0.16.0
- [Commits](https://github.com/python-hyper/h11/compare/v0.14.0...v0.16.0)

---
updated-dependencies:
- dependency-name: h11
  dependency-version: 0.16.0
  dependency-type: direct:production
  dependency-group: pip
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-20 18:08:30 +00:00
Alex
9e58eb02b3 Update .env.development 2025-11-14 19:53:19 +02:00
Siddhant Rai
3f7de867cc feat: model registry and capabilities for multi-provider support (#2158)
* feat: Implement model registry and capabilities for multi-provider support

- Added ModelRegistry to manage available models and their capabilities.
- Introduced ModelProvider enum for different LLM providers.
- Created ModelCapabilities dataclass to define model features.
- Implemented methods to load models based on API keys and settings.
- Added utility functions for model management in model_utils.py.
- Updated settings.py to include provider-specific API keys.
- Refactored LLM classes (Anthropic, OpenAI, Google, etc.) to utilize new model registry.
- Enhanced utility functions to handle token limits and model validation.
- Improved code structure and logging for better maintainability.
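The registry described in the bullets above could be sketched roughly as follows. This is an illustrative sketch only: the names `ModelRegistry`, `ModelProvider`, `ModelCapabilities`, and `load_models` come from the commit message, but the fields, signatures, and provider list here are assumptions, not the actual DocsGPT implementation.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional


class ModelProvider(Enum):
    # Illustrative subset of the providers named in the commit
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GOOGLE = "google"


@dataclass
class ModelCapabilities:
    # Feature flags a caller can query before building a request
    supports_tools: bool = False
    supports_streaming: bool = True
    max_input_tokens: int = 8192


@dataclass
class ModelRegistry:
    _models: Dict[str, ModelCapabilities] = field(default_factory=dict)
    _providers: Dict[str, ModelProvider] = field(default_factory=dict)

    def register(self, model_id: str, provider: ModelProvider,
                 caps: ModelCapabilities) -> None:
        self._models[model_id] = caps
        self._providers[model_id] = provider

    def load_models(self, api_keys: Dict[ModelProvider, str]) -> List[str]:
        # Only expose models whose provider has an API key configured,
        # mirroring "load models based on API keys and settings" above
        return [m for m, p in self._providers.items() if api_keys.get(p)]

    def capabilities(self, model_id: str) -> Optional[ModelCapabilities]:
        return self._models.get(model_id)
```

The point of the pattern is that LLM classes no longer hard-code model lists; they ask the registry which models are usable given the configured keys.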

* feat: Add model selection feature with API integration and UI component

* feat: Add model selection and default model functionality in agent management

* test: Update assertions and formatting in stream processing tests

* refactor(llm): Standardize model identifier to model_id

* fix tests

---------

Co-authored-by: Alex <a@tushynski.me>
2025-11-14 13:13:19 +02:00
Manish Madan
fbf7cf874b chore(dependabot): add react-widget npm dependency updates (#2146) 2025-11-07 17:17:46 +02:00
Manish Madan
ba7278b80f Merge pull request #2140 from arc53/dependabot/npm_and_yarn/frontend/husky-9.1.7
chore(deps-dev): bump husky from 8.0.3 to 9.1.7 in /frontend
2025-11-07 03:02:52 +05:30
ManishMadan2882
9d649de6f9 chore(eslint): migrate to ESLint 9 flat config format
- Add eslint.config.js with ESLint 9 flat config format
- Remove deprecated .eslintrc.cjs file
- Remove deprecated .eslintignore file (replaced by ignores in config)
- Maintain all existing ESLint rules and configurations
- Ensure compatibility with Husky 9.1.7
2025-11-07 02:59:51 +05:30
dependabot[bot]
7929afbf58 chore(deps-dev): bump husky from 8.0.3 to 9.1.7 in /frontend
Bumps [husky](https://github.com/typicode/husky) from 8.0.3 to 9.1.7.
- [Release notes](https://github.com/typicode/husky/releases)
- [Commits](https://github.com/typicode/husky/compare/v8.0.3...v9.1.7)

---
updated-dependencies:
- dependency-name: husky
  dependency-version: 9.1.7
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-06 21:27:39 +00:00
Manish Madan
ceaf942e70 Merge pull request #2139 from arc53/dependabot/npm_and_yarn/frontend/eslint-9.39.1
chore(deps-dev): bump eslint from 8.57.1 to 9.39.1 in /frontend
2025-11-07 02:33:32 +05:30
dependabot[bot]
f355601a44 chore(deps-dev): bump eslint from 8.57.1 to 9.39.1 in /frontend
Bumps [eslint](https://github.com/eslint/eslint) from 8.57.1 to 9.39.1.
- [Release notes](https://github.com/eslint/eslint/releases)
- [Commits](https://github.com/eslint/eslint/compare/v8.57.1...v9.39.1)

---
updated-dependencies:
- dependency-name: eslint
  dependency-version: 9.39.1
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-06 20:00:14 +00:00
Manish Madan
4ff99a1e86 Merge pull request #2138 from arc53/dependabot/npm_and_yarn/frontend/reduxjs/toolkit-2.10.1
chore(deps): bump @reduxjs/toolkit from 2.9.2 to 2.10.1 in /frontend
2025-11-07 01:28:58 +05:30
dependabot[bot]
129084ba92 chore(deps): bump @reduxjs/toolkit from 2.9.2 to 2.10.1 in /frontend
Bumps [@reduxjs/toolkit](https://github.com/reduxjs/redux-toolkit) from 2.9.2 to 2.10.1.
- [Release notes](https://github.com/reduxjs/redux-toolkit/releases)
- [Commits](https://github.com/reduxjs/redux-toolkit/compare/v2.9.2...v2.10.1)

---
updated-dependencies:
- dependency-name: "@reduxjs/toolkit"
  dependency-version: 2.10.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-06 19:56:28 +00:00
Manish Madan
2288df1293 Merge pull request #2141 from arc53/dependabot/npm_and_yarn/frontend/vite-7.2.0
chore(deps-dev): bump vite from 7.1.12 to 7.2.0 in /frontend
2025-11-07 01:05:29 +05:30
Manish Madan
d9dfac55e7 Merge pull request #2134 from arc53/dependabot/npm_and_yarn/frontend/types/mermaid-9.2.0
chore(deps-dev): bump @types/mermaid from 9.1.0 to 9.2.0 in /frontend
2025-11-06 17:46:59 +05:30
Nick
404cf4b7c7 Update quickstart.mdx (#2142)
Added missing **
2025-11-06 12:37:27 +02:00
dependabot[bot]
f1c1fc123b chore(deps-dev): bump vite from 7.1.12 to 7.2.0 in /frontend
Bumps [vite](https://github.com/vitejs/vite/tree/HEAD/packages/vite) from 7.1.12 to 7.2.0.
- [Release notes](https://github.com/vitejs/vite/releases)
- [Changelog](https://github.com/vitejs/vite/blob/main/packages/vite/CHANGELOG.md)
- [Commits](https://github.com/vitejs/vite/commits/v7.2.0/packages/vite)

---
updated-dependencies:
- dependency-name: vite
  dependency-version: 7.2.0
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-05 20:08:29 +00:00
ManishMadan2882
9f19c7ee4c Remove deprecated @types/mermaid dependency - mermaid provides its own types 2025-11-05 20:43:47 +05:30
dependabot[bot]
155e74eca1 chore(deps-dev): bump @types/mermaid from 9.1.0 to 9.2.0 in /frontend
Bumps [@types/mermaid](https://github.com/DefinitelyTyped/DefinitelyTyped/tree/HEAD/types/mermaid) from 9.1.0 to 9.2.0.
- [Release notes](https://github.com/DefinitelyTyped/DefinitelyTyped/releases)
- [Commits](https://github.com/DefinitelyTyped/DefinitelyTyped/commits/HEAD/types/mermaid)

---
updated-dependencies:
- dependency-name: "@types/mermaid"
  dependency-version: 9.2.0
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-05 15:10:19 +00:00
Manish Madan
ea2dc4dbcb Merge pull request #2133 from arc53/dependabot/npm_and_yarn/frontend/react-i18next-16.2.4
chore(deps): bump react-i18next from 15.7.4 to 16.2.4 in /frontend
2025-11-05 20:23:15 +05:30
dependabot[bot]
616edc97de chore(deps): bump react-i18next from 15.7.4 to 16.2.4 in /frontend
Bumps [react-i18next](https://github.com/i18next/react-i18next) from 15.7.4 to 16.2.4.
- [Changelog](https://github.com/i18next/react-i18next/blob/master/CHANGELOG.md)
- [Commits](https://github.com/i18next/react-i18next/compare/v15.7.4...v16.2.4)

---
updated-dependencies:
- dependency-name: react-i18next
  dependency-version: 16.2.4
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-05 14:48:28 +00:00
Manish Madan
b017e99c79 Merge pull request #2132 from arc53/dependabot/npm_and_yarn/frontend/eslint-plugin-n-17.23.1
chore(deps-dev): bump eslint-plugin-n from 16.6.2 to 17.23.1 in /frontend
2025-11-05 20:14:18 +05:30
dependabot[bot]
f698e9d3e1 chore(deps-dev): bump eslint-plugin-n in /frontend
Bumps [eslint-plugin-n](https://github.com/eslint-community/eslint-plugin-n) from 16.6.2 to 17.23.1.
- [Release notes](https://github.com/eslint-community/eslint-plugin-n/releases)
- [Changelog](https://github.com/eslint-community/eslint-plugin-n/blob/master/CHANGELOG.md)
- [Commits](https://github.com/eslint-community/eslint-plugin-n/compare/16.6.2...v17.23.1)

---
updated-dependencies:
- dependency-name: eslint-plugin-n
  dependency-version: 17.23.1
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-05 14:35:17 +00:00
Manish Madan
d366502850 Merge pull request #2131 from arc53/dependabot/npm_and_yarn/frontend/typescript-eslint/parser-8.46.3
chore(deps-dev): bump @typescript-eslint/parser from 6.21.0 to 8.46.3 in /frontend
2025-11-05 20:03:59 +05:30
ManishMadan2882
3d6757c170 (chore:lint) relax rules, build fix 2025-11-05 20:02:01 +05:30
Manish Madan
cb8302add8 Fixes shared conversation on cloud version (#2135)
* (fix:shared) conv as id, not dbref

* (chore) script to migrate dbref to id

* (chore): ruff fix

---------

Co-authored-by: GH Action - Upstream Sync <action@github.com>
2025-11-05 16:08:10 +02:00
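The "script to migrate dbref to id" step above presumably walks shared-conversation documents and replaces a stored DBRef with the referenced conversation's bare id. A stdlib-only sketch of that per-document transformation (the field name `conversation` is hypothetical; a real script would iterate the collection with pymongo and handle `bson.DBRef` values directly):

```python
def migrate_conversation_ref(doc: dict) -> dict:
    """Replace a DBRef-style 'conversation' value with its bare id.

    A DBRef serializes as {"$ref": <collection>, "$id": <id>}; storing
    just the id lets the cloud version look conversations up directly.
    Documents already migrated are returned unchanged.
    """
    ref = doc.get("conversation")
    if isinstance(ref, dict) and "$id" in ref:
        doc = {**doc, "conversation": ref["$id"]}
    return doc
```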
dependabot[bot]
9d266e9fad chore(deps-dev): bump @typescript-eslint/parser in /frontend
Bumps [@typescript-eslint/parser](https://github.com/typescript-eslint/typescript-eslint/tree/HEAD/packages/parser) from 6.21.0 to 8.46.3.
- [Release notes](https://github.com/typescript-eslint/typescript-eslint/releases)
- [Changelog](https://github.com/typescript-eslint/typescript-eslint/blob/main/packages/parser/CHANGELOG.md)
- [Commits](https://github.com/typescript-eslint/typescript-eslint/commits/v8.46.3/packages/parser)

---
updated-dependencies:
- dependency-name: "@typescript-eslint/parser"
  dependency-version: 8.46.3
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-05 13:45:18 +00:00
Manish Madan
ae94c9d31e Merge pull request #2130 from arc53/dependabot/npm_and_yarn/frontend/vite-7.1.12
chore(deps-dev): bump vite from 6.4.1 to 7.1.12 in /frontend
2025-11-05 19:13:59 +05:30
ManishMadan2882
83ab232dcd (chore:fe) pkg lock 2025-11-05 19:12:20 +05:30
dependabot[bot]
eea85772a3 chore(deps-dev): bump vite from 6.4.1 to 7.1.12 in /frontend
Bumps [vite](https://github.com/vitejs/vite/tree/HEAD/packages/vite) from 6.4.1 to 7.1.12.
- [Release notes](https://github.com/vitejs/vite/releases)
- [Changelog](https://github.com/vitejs/vite/blob/v7.1.12/packages/vite/CHANGELOG.md)
- [Commits](https://github.com/vitejs/vite/commits/v7.1.12/packages/vite)

---
updated-dependencies:
- dependency-name: vite
  dependency-version: 7.1.12
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-05 19:10:27 +05:30
Alex
0fe7e223cc fix: update Discord invite link across documentation and navigation 2025-11-04 09:27:22 +00:00
Heisenberg Vader
3789d2eb03 Updated the technique for handling multiple file uploads from the user (#2126)
* Fixed multiple file uploads to be sent through a single request to backend for further processing and storing

* Fixed multiple file uploads to be sent through a single request to backend for further processing and storing

* Fixed multiple file uploads to be sent through a single request to backend for further processing and storing

* Made duplicate multiple keyword fixes

* Added back drag and drop functionality and it keeps the multiple file uploads
2025-11-04 01:12:35 +02:00
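The change above batches every selected file into one request instead of one request per file. Framework aside, the server-side handling of such a batch can be sketched as a plain function over the received file list (the function name and duplicate handling here are illustrative; DocsGPT's actual endpoint and storage layer are not shown in this log — with Flask, the list would come from `request.files.getlist(...)`):

```python
from typing import Dict, List, Tuple


def handle_batch_upload(files: List[Tuple[str, bytes]]) -> Dict:
    """Process all files arriving in a single multipart request."""
    if not files:
        return {"ok": False, "error": "no files provided"}
    stored: List[str] = []
    total_bytes = 0
    for name, data in files:
        # Skip duplicate filenames within the same batch
        if name in stored:
            continue
        stored.append(name)
        total_bytes += len(data)
    return {"ok": True, "stored": stored, "total_bytes": total_bytes}
```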
Manish Madan
d54469532e fix: adjust ESLint rules to warnings for strict type checking (#2129)
- Changed @typescript-eslint/no-explicit-any from error to warning
- Changed @typescript-eslint/no-unused-vars from error to warning
- Allows codebase to pass linting while maintaining code quality checks
- These rules can be gradually enforced as code is refactored
- Verified with npm run build - successful
2025-11-04 01:09:39 +02:00
Manish Madan
9884e51836 Merge pull request #2122 from arc53/dependabot/npm_and_yarn/frontend/prettier-plugin-tailwindcss-0.7.1
chore(deps-dev): bump prettier-plugin-tailwindcss from 0.6.13 to 0.7.1 in /frontend
2025-11-03 19:31:30 +05:30
Alex
6626723180 feat: enhance prompt variable handling and add system variable options in prompts modal (#2128) 2025-11-03 15:54:13 +02:00
Manish Madan
0c251e066b Merge pull request #2124 from arc53/dependabot/npm_and_yarn/frontend/eslint-plugin-n-17.23.1
chore(deps-dev): bump eslint-plugin-n from 15.7.0 to 17.23.1 in /frontend
2025-11-03 19:22:22 +05:30
dependabot[bot]
0957034bfa chore(deps-dev): bump prettier-plugin-tailwindcss in /frontend
Bumps [prettier-plugin-tailwindcss](https://github.com/tailwindlabs/prettier-plugin-tailwindcss) from 0.6.13 to 0.7.1.
- [Release notes](https://github.com/tailwindlabs/prettier-plugin-tailwindcss/releases)
- [Changelog](https://github.com/tailwindlabs/prettier-plugin-tailwindcss/blob/main/CHANGELOG.md)
- [Commits](https://github.com/tailwindlabs/prettier-plugin-tailwindcss/compare/v0.6.13...v0.7.1)

---
updated-dependencies:
- dependency-name: prettier-plugin-tailwindcss
  dependency-version: 0.7.1
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-03 13:49:34 +00:00
ManishMadan2882
44521cd893 fix: resolve peer dependency conflict with eslint-plugin-n
- Downgrade eslint-plugin-n from ^17.23.1 to ^16.6.2
- Ensure compatibility with eslint-config-standard-with-typescript@43.0.1
- eslint-config-standard-with-typescript requires eslint-plugin-n@^15.0.0 || ^16.0.0
- Verified with successful npm install and vite build
2025-11-03 19:19:02 +05:30
dependabot[bot]
b17f846730 chore(deps-dev): bump eslint-plugin-n in /frontend
Bumps [eslint-plugin-n](https://github.com/eslint-community/eslint-plugin-n) from 15.7.0 to 17.23.1.
- [Release notes](https://github.com/eslint-community/eslint-plugin-n/releases)
- [Changelog](https://github.com/eslint-community/eslint-plugin-n/blob/master/CHANGELOG.md)
- [Commits](https://github.com/eslint-community/eslint-plugin-n/compare/15.7.0...v17.23.1)

---
updated-dependencies:
- dependency-name: eslint-plugin-n
  dependency-version: 17.23.1
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-11-03 13:45:27 +00:00
74 changed files with 3071 additions and 1171 deletions

View File

@@ -13,7 +13,11 @@ updates:
     directory: "/frontend" # Location of package manifests
     schedule:
       interval: "daily"
+  - package-ecosystem: "npm"
+    directory: "/extensions/react-widget"
+    schedule:
+      interval: "daily"
   - package-ecosystem: "github-actions"
     directory: "/"
     schedule:
       interval: "daily"

View File

@@ -147,5 +147,5 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
 Thank you for considering contributing to DocsGPT! 🙏
 
 ## Questions/collaboration
-Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
+Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
 # Thank you so much for considering to contributing DocsGPT!🙏

View File

@@ -32,7 +32,7 @@ Non-Code Contributions:
 - Before contributing check existing [issues](https://github.com/arc53/DocsGPT/issues) or [create](https://github.com/arc53/DocsGPT/issues/new/choose) an issue and wait to get assigned.
 - Once you are finished with your contribution, please fill in this [form](https://forms.gle/Npaba4n9Epfyx56S8).
 - Refer to the [Documentation](https://docs.docsgpt.cloud/).
-- Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU) server. We're here to help newcomers, so don't hesitate to jump in! Join us [here](https://discord.gg/n5BX8dh8rU).
+- Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj) server. We're here to help newcomers, so don't hesitate to jump in! Join us [here](https://discord.gg/vN7YFfdMpj).
 
 Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your contributions (not just simple typos) could earn you a stylish new t-shirt.

View File

@@ -16,10 +16,10 @@
 <a href="https://github.com/arc53/DocsGPT">![link to main GitHub showing Forks number](https://img.shields.io/github/forks/arc53/docsgpt?style=social)</a>
 <a href="https://github.com/arc53/DocsGPT/blob/main/LICENSE">![link to license file](https://img.shields.io/github/license/arc53/docsgpt)</a>
 <a href="https://www.bestpractices.dev/projects/9907"><img src="https://www.bestpractices.dev/projects/9907/badge"></a>
-<a href="https://discord.gg/n5BX8dh8rU">![link to discord](https://img.shields.io/discord/1070046503302877216)</a>
+<a href="https://discord.gg/vN7YFfdMpj">![link to discord](https://img.shields.io/discord/1070046503302877216)</a>
 <a href="https://x.com/docsgptai">![X (formerly Twitter) URL](https://img.shields.io/twitter/follow/docsgptai)</a>
 
-<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a><a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a><a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
+<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a><a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a><a href="https://discord.gg/vN7YFfdMpj">💬 Discord</a>
 <br>
 <a href="https://docs.docsgpt.cloud/">📖 Documentation</a><a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a><a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
 <br>

View File

@@ -1,5 +1,8 @@
from application.agents.classic_agent import ClassicAgent from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent from application.agents.react_agent import ReActAgent
import logging
logger = logging.getLogger(__name__)
class AgentCreator: class AgentCreator:
@@ -13,4 +16,5 @@ class AgentCreator:
agent_class = cls.agents.get(type.lower()) agent_class = cls.agents.get(type.lower())
if not agent_class: if not agent_class:
raise ValueError(f"No agent class found for type {type}") raise ValueError(f"No agent class found for type {type}")
return agent_class(*args, **kwargs) return agent_class(*args, **kwargs)

View File

@@ -21,7 +21,7 @@ class BaseAgent(ABC):
         self,
         endpoint: str,
         llm_name: str,
-        gpt_model: str,
+        model_id: str,
         api_key: str,
         user_api_key: Optional[str] = None,
         prompt: str = "",
@@ -37,7 +37,7 @@ class BaseAgent(ABC):
     ):
         self.endpoint = endpoint
         self.llm_name = llm_name
-        self.gpt_model = gpt_model
+        self.model_id = model_id
         self.api_key = api_key
         self.user_api_key = user_api_key
         self.prompt = prompt
@@ -52,6 +52,7 @@ class BaseAgent(ABC):
             api_key=api_key,
             user_api_key=user_api_key,
             decoded_token=decoded_token,
+            model_id=model_id,
         )
         self.retrieved_docs = retrieved_docs or []
         self.llm_handler = LLMHandlerCreator.create_handler(
@@ -316,7 +317,7 @@ class BaseAgent(ABC):
         return messages
 
     def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
-        gen_kwargs = {"model": self.gpt_model, "messages": messages}
+        gen_kwargs = {"model": self.model_id, "messages": messages}
         if (
             hasattr(self.llm, "_supports_tools")

View File

@@ -86,7 +86,7 @@ class ReActAgent(BaseAgent):
         messages = [{"role": "user", "content": plan_prompt}]
 
         plan_stream = self.llm.gen_stream(
-            model=self.gpt_model,
+            model=self.model_id,
             messages=messages,
             tools=self.tools if self.tools else None,
         )
@@ -151,7 +151,7 @@ class ReActAgent(BaseAgent):
         messages = [{"role": "user", "content": final_prompt}]
 
         final_stream = self.llm.gen_stream(
-            model=self.gpt_model, messages=messages, tools=None
+            model=self.model_id, messages=messages, tools=None
         )
 
         if log_context:

View File

@@ -54,6 +54,10 @@ class AnswerResource(Resource, BaseAnswerResource):
             default=True,
             description="Whether to save the conversation",
         ),
+        "model_id": fields.String(
+            required=False,
+            description="Model ID to use for this request",
+        ),
         "passthrough": fields.Raw(
             required=False,
             description="Dynamic parameters to inject into prompt template",
@@ -97,6 +101,7 @@ class AnswerResource(Resource, BaseAnswerResource):
             isNoneDoc=data.get("isNoneDoc"),
             index=None,
             should_save_conversation=data.get("save_conversation", True),
+            model_id=processor.model_id,
         )
 
         stream_result = self.process_response_stream(stream)

View File

@@ -7,11 +7,16 @@ from flask import jsonify, make_response, Response
 from flask_restx import Namespace
 
 from application.api.answer.services.conversation_service import ConversationService
+from application.core.model_utils import (
+    get_api_key_for_provider,
+    get_default_model_id,
+    get_provider_from_model_id,
+)
 from application.core.mongo_db import MongoDB
 from application.core.settings import settings
 from application.llm.llm_creator import LLMCreator
-from application.utils import check_required_fields, get_gpt_model
+from application.utils import check_required_fields
 
 logger = logging.getLogger(__name__)
@@ -27,7 +32,7 @@ class BaseAnswerResource:
         db = mongo[settings.MONGO_DB_NAME]
         self.db = db
         self.user_logs_collection = db["user_logs"]
-        self.gpt_model = get_gpt_model()
+        self.default_model_id = get_default_model_id()
         self.conversation_service = ConversationService()
 
     def validate_request(
@@ -54,7 +59,6 @@
         api_key = agent_config.get("user_api_key")
         if not api_key:
             return None
-
         agents_collection = self.db["agents"]
         agent = agents_collection.find_one({"key": api_key})
@@ -62,7 +66,6 @@
             return make_response(
                 jsonify({"success": False, "message": "Invalid API key."}), 401
             )
-
         limited_token_mode_raw = agent.get("limited_token_mode", False)
         limited_request_mode_raw = agent.get("limited_request_mode", False)
@@ -110,15 +113,12 @@
             daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
         else:
             daily_token_usage = 0
-
         if limited_request_mode:
             daily_request_usage = token_usage_collection.count_documents(match_query)
         else:
             daily_request_usage = 0
-
         if not limited_token_mode and not limited_request_mode:
             return None
-
         token_exceeded = (
             limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
         )
@@ -138,7 +138,6 @@
             ),
             429,
         )
-
         return None
 
     def complete_stream(
@@ -155,6 +154,7 @@
         agent_id: Optional[str] = None,
         is_shared_usage: bool = False,
         shared_token: Optional[str] = None,
+        model_id: Optional[str] = None,
     ) -> Generator[str, None, None]:
         """
         Generator function that streams the complete conversation response.
@@ -173,6 +173,7 @@
             agent_id: ID of agent used
             is_shared_usage: Flag for shared agent usage
             shared_token: Token for shared agent
+            model_id: Model ID used for the request
             retrieved_docs: Pre-fetched documents for sources (optional)
 
         Yields:
@@ -220,7 +221,6 @@
                 elif "type" in line:
                     data = json.dumps(line)
                     yield f"data: {data}\n\n"
-
             if is_structured and structured_chunks:
                 structured_data = {
                     "type": "structured_answer",
@@ -230,15 +230,22 @@
                 }
                 data = json.dumps(structured_data)
                 yield f"data: {data}\n\n"
 
             if isNoneDoc:
                 for doc in source_log_docs:
                     doc["source"] = "None"
 
+            provider = (
+                get_provider_from_model_id(model_id)
+                if model_id
+                else settings.LLM_PROVIDER
+            )
+            system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
             llm = LLMCreator.create_llm(
-                settings.LLM_PROVIDER,
-                api_key=settings.API_KEY,
+                provider or settings.LLM_PROVIDER,
+                api_key=system_api_key,
                 user_api_key=user_api_key,
                 decoded_token=decoded_token,
+                model_id=model_id,
             )
 
             if should_save_conversation:
@@ -250,7 +257,7 @@
                     source_log_docs,
                     tool_calls,
                     llm,
-                    self.gpt_model,
+                    model_id or self.default_model_id,
                     decoded_token,
                     index=index,
                     api_key=user_api_key,
@@ -280,12 +287,11 @@
                     log_data["structured_output"] = True
                     if schema_info:
                         log_data["schema"] = schema_info
-
                 # Clean up text fields to be no longer than 10000 characters
                 for key, value in log_data.items():
                     if isinstance(value, str) and len(value) > 10000:
                         log_data[key] = value[:10000]
-
                 self.user_logs_collection.insert_one(log_data)
 
             data = json.dumps({"type": "end"})
@@ -293,6 +299,7 @@
         except GeneratorExit:
             logger.info(f"Stream aborted by client for question: {question[:50]}... ")
             # Save partial response
+
             if should_save_conversation and response_full:
                 try:
                     if isNoneDoc:
@@ -312,7 +319,7 @@
                         source_log_docs,
                         tool_calls,
                         llm,
-                        self.gpt_model,
+                        model_id or self.default_model_id,
                         decoded_token,
                         index=index,
                         api_key=user_api_key,
@@ -369,7 +376,7 @@
                     thought = event["thought"]
                 elif event["type"] == "error":
                     logger.error(f"Error from stream: {event['error']}")
-                    return None, None, None, None, event["error"]
+                    return None, None, None, None, event["error"], None
                 elif event["type"] == "end":
                     stream_ended = True
             except (json.JSONDecodeError, KeyError) as e:
@@ -377,8 +384,7 @@
                 continue
 
         if not stream_ended:
             logger.error("Stream ended unexpectedly without an 'end' event.")
-            return None, None, None, None, "Stream ended unexpectedly"
+            return None, None, None, None, "Stream ended unexpectedly", None
 
         result = (
             conversation_id,
             response_full,
@@ -390,7 +396,6 @@
         if is_structured:
             result = result + ({"structured": True, "schema": schema_info},)
         return result
-
     def error_stream_generate(self, err_response):
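The diff above resolves a provider and API key from the requested `model_id` before constructing the LLM. The helpers imported from `application.core.model_utils` are not shown in this changeset; a plausible minimal version might look like the following — the prefix convention and settings names are guesses for illustration, not the real implementation:

```python
# Hypothetical settings stand-in; DocsGPT reads these from application.core.settings
SETTINGS = {"OPENAI_API_KEY": "sk-...", "ANTHROPIC_API_KEY": None, "LLM_PROVIDER": "openai"}

# Map a model-id prefix to its provider (illustrative convention)
PREFIX_TO_PROVIDER = {"gpt": "openai", "claude": "anthropic", "gemini": "google"}


def get_provider_from_model_id(model_id: str):
    prefix = model_id.split("-", 1)[0].lower()
    return PREFIX_TO_PROVIDER.get(prefix)


def get_api_key_for_provider(provider: str):
    # Fall back to None when no key is configured for the provider
    return SETTINGS.get(f"{provider.upper()}_API_KEY")
```

This matches the fallback in the hunk: when the prefix is unknown the caller falls back to `settings.LLM_PROVIDER` via `provider or settings.LLM_PROVIDER`.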

View File

@@ -57,6 +57,10 @@ class StreamResource(Resource, BaseAnswerResource):
             default=True,
             description="Whether to save the conversation",
         ),
+        "model_id": fields.String(
+            required=False,
+            description="Model ID to use for this request",
+        ),
         "attachments": fields.List(
             fields.String, required=False, description="List of attachment IDs"
         ),
@@ -101,6 +105,7 @@ class StreamResource(Resource, BaseAnswerResource):
                 agent_id=data.get("agent_id"),
                 is_shared_usage=processor.is_shared_usage,
                 shared_token=processor.shared_token,
+                model_id=processor.model_id,
             ),
             mimetype="text/event-stream",
         )

View File

@@ -52,7 +52,7 @@ class ConversationService:
        sources: List[Dict[str, Any]],
        tool_calls: List[Dict[str, Any]],
        llm: Any,
-       gpt_model: str,
+       model_id: str,
        decoded_token: Dict[str, Any],
        index: Optional[int] = None,
        api_key: Optional[str] = None,
@@ -66,7 +66,7 @@ class ConversationService:
        if not user_id:
            raise ValueError("User ID not found in token")
        current_time = datetime.now(timezone.utc)
        # clean up in sources array such that we save max 1k characters for text part
        for source in sources:
            if "text" in source and isinstance(source["text"], str):
@@ -90,6 +90,7 @@ class ConversationService:
                        f"queries.{index}.tool_calls": tool_calls,
                        f"queries.{index}.timestamp": current_time,
                        f"queries.{index}.attachments": attachment_ids,
+                       f"queries.{index}.model_id": model_id,
                    }
                },
            )
@@ -120,6 +121,7 @@ class ConversationService:
                            "tool_calls": tool_calls,
                            "timestamp": current_time,
                            "attachments": attachment_ids,
+                           "model_id": model_id,
                        }
                    }
                },
@@ -146,7 +148,7 @@ class ConversationService:
            ]
            completion = llm.gen(
-               model=gpt_model, messages=messages_summary, max_tokens=30
+               model=model_id, messages=messages_summary, max_tokens=30
            )
            conversation_data = {
@@ -162,6 +164,7 @@ class ConversationService:
                        "tool_calls": tool_calls,
                        "timestamp": current_time,
                        "attachments": attachment_ids,
+                       "model_id": model_id,
                    }
                ],
            }
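The `queries.{index}.model_id` keys above use MongoDB's dotted-path notation to update a single element of the `queries` array in place. A self-contained sketch of how such a `$set` document is assembled — the helper function is hypothetical, only the field names mirror the hunk:

```python
def build_query_update(index, tool_calls, current_time, attachment_ids, model_id):
    # Dotted keys like "queries.2.model_id" address one array element positionally.
    return {
        "$set": {
            f"queries.{index}.tool_calls": tool_calls,
            f"queries.{index}.timestamp": current_time,
            f"queries.{index}.attachments": attachment_ids,
            f"queries.{index}.model_id": model_id,
        }
    }

update = build_query_update(2, [], "2025-01-01T00:00:00Z", [], "gpt-4o")
```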


@@ -12,12 +12,17 @@ from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.conversation_service import ConversationService
from application.api.answer.services.prompt_renderer import PromptRenderer
+from application.core.model_utils import (
+    get_api_key_for_provider,
+    get_default_model_id,
+    get_provider_from_model_id,
+    validate_model_id,
+)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import (
    calculate_doc_token_budget,
-    get_gpt_model,
    limit_chat_history,
)
@@ -83,7 +88,7 @@ class StreamProcessor:
        self.retriever_config = {}
        self.is_shared_usage = False
        self.shared_token = None
-       self.gpt_model = get_gpt_model()
+       self.model_id: Optional[str] = None
        self.conversation_service = ConversationService()
        self.prompt_renderer = PromptRenderer()
        self._prompt_content: Optional[str] = None
@@ -91,6 +96,7 @@ class StreamProcessor:
    def initialize(self):
        """Initialize all required components for processing"""
+       self._validate_and_set_model()
        self._configure_agent()
        self._configure_source()
        self._configure_retriever()
@@ -112,7 +118,7 @@ class StreamProcessor:
            ]
        else:
            self.history = limit_chat_history(
-               json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
+               json.loads(self.data.get("history", "[]")), model_id=self.model_id
            )

    def _process_attachments(self):
@@ -143,6 +149,25 @@ class StreamProcessor:
            )
        return attachments

+   def _validate_and_set_model(self):
+       """Validate and set model_id from request"""
+       from application.core.model_settings import ModelRegistry
+
+       requested_model = self.data.get("model_id")
+       if requested_model:
+           if not validate_model_id(requested_model):
+               registry = ModelRegistry.get_instance()
+               available_models = [m.id for m in registry.get_enabled_models()]
+               raise ValueError(
+                   f"Invalid model_id '{requested_model}'. "
+                   f"Available models: {', '.join(available_models[:5])}"
+                   + (f" and {len(available_models) - 5} more" if len(available_models) > 5 else "")
+               )
+           self.model_id = requested_model
+       else:
+           self.model_id = get_default_model_id()
+
    def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
        """Get API key for agent with access control"""
        if not agent_id:
@@ -322,7 +347,7 @@ class StreamProcessor:
    def _configure_retriever(self):
        history_token_limit = int(self.data.get("token_limit", 2000))
        doc_token_limit = calculate_doc_token_budget(
-           gpt_model=self.gpt_model, history_token_limit=history_token_limit
+           model_id=self.model_id, history_token_limit=history_token_limit
        )

        self.retriever_config = {
@@ -344,7 +369,7 @@ class StreamProcessor:
            prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
            chunks=self.retriever_config["chunks"],
            doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
-           gpt_model=self.gpt_model,
+           model_id=self.model_id,
            user_api_key=self.agent_config["user_api_key"],
            decoded_token=self.decoded_token,
        )
@@ -626,12 +651,19 @@ class StreamProcessor:
            tools_data=tools_data,
        )

+       provider = (
+           get_provider_from_model_id(self.model_id)
+           if self.model_id
+           else settings.LLM_PROVIDER
+       )
+       system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
+
        return AgentCreator.create_agent(
            self.agent_config["agent_type"],
            endpoint="stream",
-           llm_name=settings.LLM_PROVIDER,
-           gpt_model=self.gpt_model,
-           api_key=settings.API_KEY,
+           llm_name=provider or settings.LLM_PROVIDER,
+           model_id=self.model_id,
+           api_key=system_api_key,
            user_api_key=self.agent_config["user_api_key"],
            prompt=rendered_prompt,
            chat_history=self.history,
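`_validate_and_set_model` above caps its error message at five model IDs and summarizes the rest. A standalone approximation of that message shaping, with the registry replaced by a plain list (the helper itself is an assumption, not part of the PR):

```python
def invalid_model_message(requested, available):
    # Show at most five IDs, then summarize the remainder, as in the hunk above.
    shown = available[:5]
    msg = (
        f"Invalid model_id '{requested}'. "
        f"Available models: {', '.join(shown)}"
    )
    if len(available) > 5:
        msg += f" and {len(available) - 5} more"
    return msg

models = [f"model-{i}" for i in range(8)]
```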


@@ -95,6 +95,8 @@ class GetAgent(Resource):
                "shared": agent.get("shared_publicly", False),
                "shared_metadata": agent.get("shared_metadata", {}),
                "shared_token": agent.get("shared_token", ""),
+               "models": agent.get("models", []),
+               "default_model_id": agent.get("default_model_id", ""),
            }
            return make_response(jsonify(data), 200)
        except Exception as e:
@@ -172,6 +174,8 @@ class GetAgents(Resource):
                    "shared": agent.get("shared_publicly", False),
                    "shared_metadata": agent.get("shared_metadata", {}),
                    "shared_token": agent.get("shared_token", ""),
+                   "models": agent.get("models", []),
+                   "default_model_id": agent.get("default_model_id", ""),
                }
                for agent in agents
                if "source" in agent or "retriever" in agent
@@ -230,6 +234,14 @@ class CreateAgent(Resource):
                    required=False,
                    description="Request limit for the agent in limited mode",
                ),
+               "models": fields.List(
+                   fields.String,
+                   required=False,
+                   description="List of available model IDs for this agent",
+               ),
+               "default_model_id": fields.String(
+                   required=False, description="Default model ID for this agent"
+               ),
            },
        )
@@ -258,6 +270,11 @@ class CreateAgent(Resource):
                data["json_schema"] = json.loads(data["json_schema"])
            except json.JSONDecodeError:
                data["json_schema"] = None
+       if "models" in data:
+           try:
+               data["models"] = json.loads(data["models"])
+           except json.JSONDecodeError:
+               data["models"] = []
        print(f"Received data: {data}")

        # Validate JSON schema if provided
@@ -399,6 +416,8 @@ class CreateAgent(Resource):
            "updatedAt": datetime.datetime.now(datetime.timezone.utc),
            "lastUsedAt": None,
            "key": key,
+           "models": data.get("models", []),
+           "default_model_id": data.get("default_model_id", ""),
        }
        if new_agent["chunks"] == "":
            new_agent["chunks"] = "2"
@@ -464,6 +483,14 @@ class UpdateAgent(Resource):
                    required=False,
                    description="Request limit for the agent in limited mode",
                ),
+               "models": fields.List(
+                   fields.String,
+                   required=False,
+                   description="List of available model IDs for this agent",
+               ),
+               "default_model_id": fields.String(
+                   required=False, description="Default model ID for this agent"
+               ),
            },
        )
@@ -487,7 +514,7 @@ class UpdateAgent(Resource):
            data = request.get_json()
        else:
            data = request.form.to_dict()
-       json_fields = ["tools", "sources", "json_schema"]
+       json_fields = ["tools", "sources", "json_schema", "models"]
        for field in json_fields:
            if field in data and data[field]:
                try:
@@ -555,6 +582,8 @@ class UpdateAgent(Resource):
            "token_limit",
            "limited_request_mode",
            "request_limit",
+           "models",
+           "default_model_id",
        ]
        for field in allowed_fields:


@@ -25,7 +25,7 @@ class StoreAttachment(Resource):
        api.model(
            "AttachmentModel",
            {
-               "file": fields.Raw(required=True, description="File to upload"),
+               "file": fields.Raw(required=True, description="File(s) to upload"),
                "api_key": fields.String(
                    required=False, description="API key (optional)"
                ),
@@ -33,18 +33,24 @@ class StoreAttachment(Resource):
        )
    )
    @api.doc(
-       description="Stores a single attachment without vectorization or training. Supports user or API key authentication."
+       description="Stores one or multiple attachments without vectorization or training. Supports user or API key authentication."
    )
    def post(self):
        decoded_token = getattr(request, "decoded_token", None)
        api_key = request.form.get("api_key") or request.args.get("api_key")
-       file = request.files.get("file")
-       if not file or file.filename == "":
+       files = request.files.getlist("file")
+       if not files:
+           single_file = request.files.get("file")
+           if single_file:
+               files = [single_file]
+       if not files or all(f.filename == "" for f in files):
            return make_response(
-               jsonify({"status": "error", "message": "Missing file"}),
+               jsonify({"status": "error", "message": "Missing file(s)"}),
                400,
            )
        user = None
        if decoded_token:
            user = safe_filename(decoded_token.get("sub"))
@@ -59,32 +65,74 @@ class StoreAttachment(Resource):
            return make_response(
                jsonify({"success": False, "message": "Authentication required"}), 401
            )
        try:
-           attachment_id = ObjectId()
-           original_filename = safe_filename(os.path.basename(file.filename))
-           relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
-           metadata = storage.save_file(file, relative_path)
-           file_info = {
-               "filename": original_filename,
-               "attachment_id": str(attachment_id),
-               "path": relative_path,
-               "metadata": metadata,
-           }
-           task = store_attachment.delay(file_info, user)
-           return make_response(
-               jsonify(
-                   {
-                       "success": True,
-                       "task_id": task.id,
-                       "message": "File uploaded successfully. Processing started.",
-                   }
-               ),
-               200,
-           )
+           tasks = []
+           errors = []
+           original_file_count = len(files)
+           for idx, file in enumerate(files):
+               try:
+                   attachment_id = ObjectId()
+                   original_filename = safe_filename(os.path.basename(file.filename))
+                   relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
+                   metadata = storage.save_file(file, relative_path)
+                   file_info = {
+                       "filename": original_filename,
+                       "attachment_id": str(attachment_id),
+                       "path": relative_path,
+                       "metadata": metadata,
+                   }
+                   task = store_attachment.delay(file_info, user)
+                   tasks.append({
+                       "task_id": task.id,
+                       "filename": original_filename,
+                       "attachment_id": str(attachment_id),
+                   })
+               except Exception as file_err:
+                   current_app.logger.error(
+                       f"Error processing file {idx} ({file.filename}): {file_err}",
+                       exc_info=True,
+                   )
+                   errors.append({
+                       "filename": file.filename,
+                       "error": str(file_err),
+                   })
+           if not tasks:
+               error_msg = "No valid files to upload"
+               if errors:
+                   error_msg += f". Errors: {errors}"
+               return make_response(
+                   jsonify({"status": "error", "message": error_msg, "errors": errors}),
+                   400,
+               )
+           if original_file_count == 1 and len(tasks) == 1:
+               current_app.logger.info("Returning single task_id response")
+               return make_response(
+                   jsonify(
+                       {
+                           "success": True,
+                           "task_id": tasks[0]["task_id"],
+                           "message": "File uploaded successfully. Processing started.",
+                       }
+                   ),
+                   200,
+               )
+           else:
+               response_data = {
+                   "success": True,
+                   "tasks": tasks,
+                   "message": f"{len(tasks)} file(s) uploaded successfully. Processing started.",
+               }
+               if errors:
+                   response_data["errors"] = errors
+                   response_data["message"] += f" {len(errors)} file(s) failed."
+               return make_response(
+                   jsonify(response_data),
+                   200,
+               )
        except Exception as err:
            current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
            return make_response(jsonify({"success": False, "error": str(err)}), 400)
@@ -130,15 +178,11 @@ class TextToSpeech(Resource):
    @api.expect(tts_model)
    @api.doc(description="Synthesize audio speech from text")
    def post(self):
-       from application.utils import clean_text_for_tts
-
        data = request.get_json()
        text = data["text"]
-       cleaned_text = clean_text_for_tts(text)
        try:
            tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
-           audio_base64, detected_language = tts_instance.text_to_speech(cleaned_text)
+           audio_base64, detected_language = tts_instance.text_to_speech(text)
            return make_response(
                jsonify(
                    {
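The upload handler above keeps the legacy single-file contract (a bare `task_id`) while adding a multi-file shape. A self-contained sketch of that response selection, with the Flask/Celery machinery stripped out (function and status-code pairing are illustrative):

```python
def shape_upload_response(tasks, errors, original_file_count):
    # Mirrors the branching in the diff: all-failed, legacy single-file, or multi-file.
    if not tasks:
        return 400, {"status": "error", "message": "No valid files to upload", "errors": errors}
    if original_file_count == 1 and len(tasks) == 1:
        # Preserve the old single-file contract: one bare task_id.
        return 200, {"success": True, "task_id": tasks[0]["task_id"]}
    body = {
        "success": True,
        "tasks": tasks,
        "message": f"{len(tasks)} file(s) uploaded successfully. Processing started.",
    }
    if errors:
        body["errors"] = errors
        body["message"] += f" {len(errors)} file(s) failed."
    return 200, body
```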


@@ -0,0 +1,3 @@
from .routes import models_ns
__all__ = ["models_ns"]


@@ -0,0 +1,25 @@
from flask import current_app, jsonify, make_response
from flask_restx import Namespace, Resource
from application.core.model_settings import ModelRegistry
models_ns = Namespace("models", description="Available models", path="/api")
@models_ns.route("/models")
class ModelsListResource(Resource):
def get(self):
"""Get list of available models with their capabilities."""
try:
registry = ModelRegistry.get_instance()
models = registry.get_enabled_models()
response = {
"models": [model.to_dict() for model in models],
"default_model_id": registry.default_model_id,
"count": len(models),
}
except Exception as err:
current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
return make_response(jsonify(response), 200)


@@ -10,6 +10,7 @@ from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns
from .analytics import analytics_ns
from .attachments import attachments_ns
from .conversations import conversations_ns
+from .models import models_ns
from .prompts import prompts_ns
from .sharing import sharing_ns
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
@@ -27,6 +28,9 @@ api.add_namespace(attachments_ns)
# Conversations
api.add_namespace(conversations_ns)

+# Models
+api.add_namespace(models_ns)
+
# Agents (main, sharing, webhooks)
api.add_namespace(agents_ns)
api.add_namespace(agents_sharing_ns)


@@ -13,7 +13,6 @@ from application.api.user.base import (
    agents_collection,
    attachments_collection,
    conversations_collection,
-   db,
    shared_conversations_collections,
)
from application.utils import check_required_fields
@@ -97,9 +96,7 @@ class ShareConversation(Resource):
            api_uuid = pre_existing_api_document["key"]
            pre_existing = shared_conversations_collections.find_one(
                {
-                   "conversation_id": DBRef(
-                       "conversations", ObjectId(conversation_id)
-                   ),
+                   "conversation_id": ObjectId(conversation_id),
                    "isPromptable": is_promptable,
                    "first_n_queries": current_n_queries,
                    "user": user,
@@ -120,10 +117,7 @@ class ShareConversation(Resource):
                shared_conversations_collections.insert_one(
                    {
                        "uuid": explicit_binary,
-                       "conversation_id": {
-                           "$ref": "conversations",
-                           "$id": ObjectId(conversation_id),
-                       },
+                       "conversation_id": ObjectId(conversation_id),
                        "isPromptable": is_promptable,
                        "first_n_queries": current_n_queries,
                        "user": user,
@@ -154,10 +148,7 @@ class ShareConversation(Resource):
                shared_conversations_collections.insert_one(
                    {
                        "uuid": explicit_binary,
-                       "conversation_id": {
-                           "$ref": "conversations",
-                           "$id": ObjectId(conversation_id),
-                       },
+                       "conversation_id": ObjectId(conversation_id),
                        "isPromptable": is_promptable,
                        "first_n_queries": current_n_queries,
                        "user": user,
@@ -175,9 +166,7 @@ class ShareConversation(Resource):
            )
            pre_existing = shared_conversations_collections.find_one(
                {
-                   "conversation_id": DBRef(
-                       "conversations", ObjectId(conversation_id)
-                   ),
+                   "conversation_id": ObjectId(conversation_id),
                    "isPromptable": is_promptable,
                    "first_n_queries": current_n_queries,
                    "user": user,
@@ -197,10 +186,7 @@ class ShareConversation(Resource):
                shared_conversations_collections.insert_one(
                    {
                        "uuid": explicit_binary,
-                       "conversation_id": {
-                           "$ref": "conversations",
-                           "$id": ObjectId(conversation_id),
-                       },
+                       "conversation_id": ObjectId(conversation_id),
                        "isPromptable": is_promptable,
                        "first_n_queries": current_n_queries,
                        "user": user,
@@ -233,10 +219,12 @@ class GetPubliclySharedConversations(Resource):
        if (
            shared
            and "conversation_id" in shared
-           and isinstance(shared["conversation_id"], DBRef)
        ):
-           conversation_ref = shared["conversation_id"]
-           conversation = db.dereference(conversation_ref)
+           # conversation_id is now stored as an ObjectId, not a DBRef
+           conversation_id = shared["conversation_id"]
+           conversation = conversations_collection.find_one(
+               {"_id": conversation_id}
+           )
            if conversation is None:
                return make_response(
                    jsonify(
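Documents written before this change still hold DBRef-shaped `conversation_id` values (`{"$ref": "conversations", "$id": …}`); new writes store the id directly. A hedged sketch of read-side normalization that tolerates both shapes — plain dicts and strings stand in for bson types, and this helper is an assumption, not part of the PR:

```python
def normalize_conversation_id(value):
    # Legacy records stored {"$ref": "conversations", "$id": <id>}; new ones store the id itself.
    if isinstance(value, dict) and "$id" in value:
        return value["$id"]
    return value

legacy = {"$ref": "conversations", "$id": "64f0c0ffee"}
```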


@@ -56,9 +56,10 @@ class GetTools(Resource):
            tools = user_tools_collection.find({"user": user})
            user_tools = []
            for tool in tools:
-               tool["id"] = str(tool["_id"])
-               tool.pop("_id")
-               user_tools.append(tool)
+               tool_copy = {**tool}
+               tool_copy["id"] = str(tool["_id"])
+               tool_copy.pop("_id", None)
+               user_tools.append(tool_copy)
        except Exception as err:
            current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)


@@ -0,0 +1,223 @@
"""
Model configurations for all supported LLM providers.
"""
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
)
OPENAI_ATTACHMENTS = [
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
GOOGLE_ATTACHMENTS = [
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
OPENAI_MODELS = [
AvailableModel(
id="gpt-4o",
provider=ModelProvider.OPENAI,
display_name="GPT-4 Omni",
description="Latest and most capable model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=128000,
),
),
AvailableModel(
id="gpt-4o-mini",
provider=ModelProvider.OPENAI,
display_name="GPT-4 Omni Mini",
description="Fast and efficient",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=128000,
),
),
AvailableModel(
id="gpt-4-turbo",
provider=ModelProvider.OPENAI,
display_name="GPT-4 Turbo",
description="Fast GPT-4 with 128k context",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=128000,
),
),
AvailableModel(
id="gpt-4",
provider=ModelProvider.OPENAI,
display_name="GPT-4",
description="Most capable model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=8192,
),
),
AvailableModel(
id="gpt-3.5-turbo",
provider=ModelProvider.OPENAI,
display_name="GPT-3.5 Turbo",
description="Fast and cost-effective",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=4096,
),
),
]
ANTHROPIC_MODELS = [
AvailableModel(
id="claude-3-5-sonnet-20241022",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3.5 Sonnet (Latest)",
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=200000,
),
),
AvailableModel(
id="claude-3-5-sonnet",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3.5 Sonnet",
description="Balanced performance and capability",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=200000,
),
),
AvailableModel(
id="claude-3-opus",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3 Opus",
description="Most capable Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=200000,
),
),
AvailableModel(
id="claude-3-haiku",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3 Haiku",
description="Fastest Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=200000,
),
),
]
GOOGLE_MODELS = [
AvailableModel(
id="gemini-flash-latest",
provider=ModelProvider.GOOGLE,
display_name="Gemini Flash (Latest)",
description="Latest experimental Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=int(1e6),
),
),
AvailableModel(
id="gemini-flash-lite-latest",
provider=ModelProvider.GOOGLE,
display_name="Gemini Flash Lite (Latest)",
description="Fast with huge context window",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=int(1e6),
),
),
AvailableModel(
id="gemini-2.5-pro",
provider=ModelProvider.GOOGLE,
display_name="Gemini 2.5 Pro",
description="Most capable Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=2000000,
),
),
]
GROQ_MODELS = [
AvailableModel(
id="llama-3.3-70b-versatile",
provider=ModelProvider.GROQ,
display_name="Llama 3.3 70B",
description="Latest Llama model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
AvailableModel(
id="llama-3.1-8b-instant",
provider=ModelProvider.GROQ,
display_name="Llama 3.1 8B",
description="Ultra-fast inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
AvailableModel(
id="mixtral-8x7b-32768",
provider=ModelProvider.GROQ,
display_name="Mixtral 8x7B",
description="High-speed inference with tools",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=32768,
),
),
]
AZURE_OPENAI_MODELS = [
AvailableModel(
id="azure-gpt-4",
provider=ModelProvider.AZURE_OPENAI,
display_name="Azure OpenAI GPT-4",
description="Azure-hosted GPT model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=8192,
),
),
]
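Given the capability flags declared above, a consumer can filter the catalogue by feature. A sketch with a pared-down dataclass standing in for `AvailableModel`/`ModelCapabilities` (the real classes live in `application.core.model_settings`; the flat shape here is a simplification):

```python
from dataclasses import dataclass

@dataclass
class Model:
    id: str
    supports_tools: bool = False
    supports_structured_output: bool = False

CATALOG = [
    Model("gpt-4o", supports_tools=True, supports_structured_output=True),
    Model("gpt-3.5-turbo", supports_tools=True),
    Model("docsgpt-local"),
]

def structured_output_models(models):
    # Only models that declare structured-output support qualify.
    return [m.id for m in models if m.supports_structured_output]
```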


@@ -0,0 +1,236 @@
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
class ModelProvider(str, Enum):
OPENAI = "openai"
AZURE_OPENAI = "azure_openai"
ANTHROPIC = "anthropic"
GROQ = "groq"
GOOGLE = "google"
HUGGINGFACE = "huggingface"
LLAMA_CPP = "llama.cpp"
DOCSGPT = "docsgpt"
PREMAI = "premai"
SAGEMAKER = "sagemaker"
NOVITA = "novita"
@dataclass
class ModelCapabilities:
supports_tools: bool = False
supports_structured_output: bool = False
supports_streaming: bool = True
supported_attachment_types: List[str] = field(default_factory=list)
context_window: int = 128000
input_cost_per_token: Optional[float] = None
output_cost_per_token: Optional[float] = None
@dataclass
class AvailableModel:
id: str
provider: ModelProvider
display_name: str
description: str = ""
capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
enabled: bool = True
base_url: Optional[str] = None
def to_dict(self) -> Dict:
result = {
"id": self.id,
"provider": self.provider.value,
"display_name": self.display_name,
"description": self.description,
"supported_attachment_types": self.capabilities.supported_attachment_types,
"supports_tools": self.capabilities.supports_tools,
"supports_structured_output": self.capabilities.supports_structured_output,
"supports_streaming": self.capabilities.supports_streaming,
"context_window": self.capabilities.context_window,
"enabled": self.enabled,
}
if self.base_url:
result["base_url"] = self.base_url
return result
class ModelRegistry:
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not ModelRegistry._initialized:
self.models: Dict[str, AvailableModel] = {}
self.default_model_id: Optional[str] = None
self._load_models()
ModelRegistry._initialized = True
@classmethod
def get_instance(cls) -> "ModelRegistry":
return cls()
def _load_models(self):
from application.core.settings import settings
self.models.clear()
self._add_docsgpt_models(settings)
if settings.OPENAI_API_KEY or (
settings.LLM_PROVIDER == "openai" and settings.API_KEY
):
self._add_openai_models(settings)
if settings.OPENAI_API_BASE or (
settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
):
self._add_azure_openai_models(settings)
if settings.ANTHROPIC_API_KEY or (
settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
):
self._add_anthropic_models(settings)
if settings.GOOGLE_API_KEY or (
settings.LLM_PROVIDER == "google" and settings.API_KEY
):
self._add_google_models(settings)
if settings.GROQ_API_KEY or (
settings.LLM_PROVIDER == "groq" and settings.API_KEY
):
self._add_groq_models(settings)
if settings.HUGGINGFACE_API_KEY or (
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
):
self._add_huggingface_models(settings)
# Default model selection
if settings.LLM_NAME and settings.LLM_NAME in self.models:
self.default_model_id = settings.LLM_NAME
elif settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
self.default_model_id = model_id
break
else:
self.default_model_id = next(iter(self.models.keys()))
logger.info(
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
)
def _add_openai_models(self, settings):
from application.core.model_configs import OPENAI_MODELS
if settings.OPENAI_API_KEY:
for model in OPENAI_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "openai" and settings.LLM_NAME:
for model in OPENAI_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in OPENAI_MODELS:
self.models[model.id] = model
def _add_azure_openai_models(self, settings):
from application.core.model_configs import AZURE_OPENAI_MODELS
if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
for model in AZURE_OPENAI_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in AZURE_OPENAI_MODELS:
self.models[model.id] = model
def _add_anthropic_models(self, settings):
from application.core.model_configs import ANTHROPIC_MODELS
if settings.ANTHROPIC_API_KEY:
for model in ANTHROPIC_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
for model in ANTHROPIC_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in ANTHROPIC_MODELS:
self.models[model.id] = model
def _add_google_models(self, settings):
from application.core.model_configs import GOOGLE_MODELS
if settings.GOOGLE_API_KEY:
for model in GOOGLE_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
for model in GOOGLE_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in GOOGLE_MODELS:
self.models[model.id] = model
def _add_groq_models(self, settings):
from application.core.model_configs import GROQ_MODELS
if settings.GROQ_API_KEY:
for model in GROQ_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
for model in GROQ_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in GROQ_MODELS:
self.models[model.id] = model
def _add_docsgpt_models(self, settings):
model_id = "docsgpt-local"
model = AvailableModel(
id=model_id,
provider=ModelProvider.DOCSGPT,
display_name="DocsGPT Model",
description="Local model",
capabilities=ModelCapabilities(
supports_tools=False,
supported_attachment_types=[],
),
)
self.models[model_id] = model
def _add_huggingface_models(self, settings):
model_id = "huggingface-local"
model = AvailableModel(
id=model_id,
provider=ModelProvider.HUGGINGFACE,
display_name="Hugging Face Model",
description="Local Hugging Face model",
capabilities=ModelCapabilities(
supports_tools=False,
supported_attachment_types=[],
),
)
self.models[model_id] = model
def get_model(self, model_id: str) -> Optional[AvailableModel]:
return self.models.get(model_id)
def get_all_models(self) -> List[AvailableModel]:
return list(self.models.values())
def get_enabled_models(self) -> List[AvailableModel]:
return [m for m in self.models.values() if m.enabled]
def model_exists(self, model_id: str) -> bool:
return model_id in self.models
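The default-model selection order above (explicit `LLM_NAME`, then the first model from the configured provider, then the first registered model) can be sketched standalone; the names below are illustrative, not the project's API:

```python
# Minimal sketch of the default-model precedence in ModelRegistry.
# `Model`, `pick_default`, and the sample ids are hypothetical.
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class Model:
    id: str
    provider: str


def pick_default(
    models: Dict[str, Model],
    llm_name: Optional[str] = None,
    llm_provider: Optional[str] = None,
    api_key: Optional[str] = None,
) -> str:
    # 1) An explicitly configured model wins, if it is registered.
    if llm_name and llm_name in models:
        return llm_name
    # 2) Otherwise, the first model from the configured provider.
    if llm_provider and api_key:
        for model_id, model in models.items():
            if model.provider == llm_provider:
                return model_id
    # 3) Fall back to the first registered model (dict insertion order).
    return next(iter(models))


models = {
    "gpt-4o": Model("gpt-4o", "openai"),
    "claude-3": Model("claude-3", "anthropic"),
}
```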


@@ -0,0 +1,91 @@
from typing import Any, Dict, Optional
from application.core.model_settings import ModelRegistry
def get_api_key_for_provider(provider: str) -> Optional[str]:
"""Get the appropriate API key for a provider"""
from application.core.settings import settings
provider_key_map = {
"openai": settings.OPENAI_API_KEY,
"anthropic": settings.ANTHROPIC_API_KEY,
"google": settings.GOOGLE_API_KEY,
"groq": settings.GROQ_API_KEY,
"huggingface": settings.HUGGINGFACE_API_KEY,
"azure_openai": settings.API_KEY,
"docsgpt": None,
"llama.cpp": None,
}
provider_key = provider_key_map.get(provider)
if provider_key:
return provider_key
return settings.API_KEY
def get_all_available_models() -> Dict[str, Dict[str, Any]]:
"""Get all available models with metadata for API response"""
registry = ModelRegistry.get_instance()
return {model.id: model.to_dict() for model in registry.get_enabled_models()}
def validate_model_id(model_id: str) -> bool:
"""Check if a model ID exists in registry"""
registry = ModelRegistry.get_instance()
return registry.model_exists(model_id)
def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
"""Get capabilities for a specific model"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
if model:
return {
"supported_attachment_types": model.capabilities.supported_attachment_types,
"supports_tools": model.capabilities.supports_tools,
"supports_structured_output": model.capabilities.supports_structured_output,
"context_window": model.capabilities.context_window,
}
return None
def get_default_model_id() -> str:
"""Get the system default model ID"""
registry = ModelRegistry.get_instance()
return registry.default_model_id
def get_provider_from_model_id(model_id: str) -> Optional[str]:
"""Get the provider name for a given model_id"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
if model:
return model.provider.value
return None
def get_token_limit(model_id: str) -> int:
"""
Get context window (token limit) for a model.
Returns model's context_window or default 128000 if model not found.
"""
from application.core.settings import settings
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
if model:
return model.capabilities.context_window
return settings.DEFAULT_LLM_TOKEN_LIMIT
def get_base_url_for_model(model_id: str) -> Optional[str]:
"""
Get the custom base_url for a specific model if configured.
Returns None if no custom base_url is set.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
if model:
return model.base_url
return None
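The provider-to-key lookup in `get_api_key_for_provider` reduces to a small sketch (hypothetical helper, stdlib only): a missing or `None` provider key falls through to the generic key, which is why local providers such as `docsgpt` can map to `None` and still resolve.

```python
# Illustrative standalone version of the provider key fallback above.
from typing import Dict, Optional


def api_key_for(
    provider: str, keys: Dict[str, Optional[str]], generic_key: Optional[str]
) -> Optional[str]:
    key = keys.get(provider)
    # Any falsy value (missing provider or explicit None) falls back
    # to the generic key, mirroring the settings.API_KEY fallback.
    return key if key else generic_key


keys = {"openai": "sk-openai", "docsgpt": None}
```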


@@ -22,15 +22,7 @@ class Settings(BaseSettings):
     MONGO_DB_NAME: str = "docsgpt"
     LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
     DEFAULT_MAX_HISTORY: int = 150
-    LLM_TOKEN_LIMITS: dict = {
-        "gpt-4o": 128000,
-        "gpt-4o-mini": 128000,
-        "gpt-4": 8192,
-        "gpt-3.5-turbo": 4096,
-        "claude-2": int(1e5),
-        "gemini-2.5-flash": int(1e6),
-    }
-    DEFAULT_LLM_TOKEN_LIMIT: int = 128000
+    DEFAULT_LLM_TOKEN_LIMIT: int = 128000  # Fallback when model not found in registry
     RESERVED_TOKENS: dict = {
         "system_prompt": 500,
         "current_query": 500,
@@ -64,14 +56,22 @@ class Settings(BaseSettings):
     )

     # GitHub source
     GITHUB_ACCESS_TOKEN: Optional[str] = None  # PAT token with read repo access

     # LLM Cache
     CACHE_REDIS_URL: str = "redis://localhost:6379/2"

     API_URL: str = "http://localhost:7091"  # backend url for celery worker
-    API_KEY: Optional[str] = None  # LLM api key
+    API_KEY: Optional[str] = None  # LLM api key (used by LLM_PROVIDER)
+
+    # Provider-specific API keys (for multi-model support)
+    OPENAI_API_KEY: Optional[str] = None
+    ANTHROPIC_API_KEY: Optional[str] = None
+    GOOGLE_API_KEY: Optional[str] = None
+    GROQ_API_KEY: Optional[str] = None
+    HUGGINGFACE_API_KEY: Optional[str] = None
+
     EMBEDDINGS_KEY: Optional[str] = (
         None  # api key for embeddings (if using openai, just copy API_KEY)
     )
@@ -138,11 +138,12 @@ class Settings(BaseSettings):
     # Encryption settings
     ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"

     TTS_PROVIDER: str = "google_tts"  # google_tts or elevenlabs
     ELEVENLABS_API_KEY: Optional[str] = None

     # Tool pre-fetch settings
     ENABLE_TOOL_PREFETCH: bool = True


 path = Path(__file__).parent.parent.absolute()
 settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")


@@ -1,30 +1,41 @@
-from application.llm.base import BaseLLM
+from anthropic import AI_PROMPT, Anthropic, HUMAN_PROMPT
+
 from application.core.settings import settings
+from application.llm.base import BaseLLM


 class AnthropicLLM(BaseLLM):

-    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
-        from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
+    def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.api_key = (
-            api_key or settings.ANTHROPIC_API_KEY
-        )  # If not provided, use a default from settings
+        self.api_key = api_key or settings.ANTHROPIC_API_KEY or settings.API_KEY
         self.user_api_key = user_api_key
-        self.anthropic = Anthropic(api_key=self.api_key)
+        # Use custom base_url if provided
+        if base_url:
+            self.anthropic = Anthropic(api_key=self.api_key, base_url=base_url)
+        else:
+            self.anthropic = Anthropic(api_key=self.api_key)
         self.HUMAN_PROMPT = HUMAN_PROMPT
         self.AI_PROMPT = AI_PROMPT

     def _raw_gen(
-        self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs
+        self,
+        baseself,
+        model,
+        messages,
+        stream=False,
+        tools=None,
+        max_tokens=300,
+        **kwargs,
     ):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
         prompt = f"### Context \n {context} \n ### Question \n {user_question}"
         if stream:
             return self.gen_stream(model, prompt, stream, max_tokens, **kwargs)
         completion = self.anthropic.completions.create(
             model=model,
             max_tokens_to_sample=max_tokens,
@@ -34,7 +45,14 @@ class AnthropicLLM(BaseLLM):
         return completion.completion

     def _raw_gen_stream(
-        self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs
+        self,
+        baseself,
+        model,
+        messages,
+        stream=True,
+        tools=None,
+        max_tokens=300,
+        **kwargs,
     ):
         context = messages[0]["content"]
         user_question = messages[-1]["content"]
@@ -50,5 +68,5 @@ class AnthropicLLM(BaseLLM):
             for completion in stream_response:
                 yield completion.completion
         finally:
-            if hasattr(stream_response, 'close'):
+            if hasattr(stream_response, "close"):
                 stream_response.close()
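The refactored clients share one `base_url` precedence: an explicit constructor argument wins, then a settings value, then the provider default. A standalone sketch (names are assumptions for illustration):

```python
# Illustrative sketch of the base_url precedence used by the LLM clients.
from typing import Optional

DEFAULT_API_BASE = "https://api.openai.com/v1"  # provider default (example)


def effective_base_url(arg: Optional[str], settings_url: Optional[str]) -> str:
    # 1) Explicit, non-blank constructor argument wins.
    if arg and isinstance(arg, str) and arg.strip():
        return arg
    # 2) Then a non-blank value from settings.
    if settings_url and isinstance(settings_url, str) and settings_url.strip():
        return settings_url
    # 3) Otherwise the provider default.
    return DEFAULT_API_BASE
```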


@@ -13,30 +13,32 @@ class BaseLLM(ABC):
     def __init__(
         self,
         decoded_token=None,
+        model_id=None,
+        base_url=None,
     ):
         self.decoded_token = decoded_token
+        self.model_id = model_id
+        self.base_url = base_url
         self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
-        self.fallback_provider = settings.FALLBACK_LLM_PROVIDER
-        self.fallback_model_name = settings.FALLBACK_LLM_NAME
-        self.fallback_llm_api_key = settings.FALLBACK_LLM_API_KEY
         self._fallback_llm = None
-        self._fallback_sequence_index = 0

     @property
     def fallback_llm(self):
-        """Lazy-loaded fallback LLM instance."""
-        if (
-            self._fallback_llm is None
-            and self.fallback_provider
-            and self.fallback_model_name
-        ):
+        """Lazy-loaded fallback LLM from FALLBACK_* settings."""
+        if self._fallback_llm is None and settings.FALLBACK_LLM_PROVIDER:
             try:
                 from application.llm.llm_creator import LLMCreator

                 self._fallback_llm = LLMCreator.create_llm(
-                    self.fallback_provider,
-                    self.fallback_llm_api_key,
-                    None,
-                    self.decoded_token,
+                    settings.FALLBACK_LLM_PROVIDER,
+                    api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
+                    user_api_key=None,
+                    decoded_token=self.decoded_token,
+                    model_id=settings.FALLBACK_LLM_NAME,
+                )
+                logger.info(
+                    f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
                 )
             except Exception as e:
                 logger.error(
@@ -54,7 +56,7 @@ class BaseLLM(ABC):
         self, method_name: str, decorators: list, *args, **kwargs
     ):
         """
-        Unified method execution with fallback support.
+        Execute method with fallback support.

         Args:
             method_name: Name of the raw method ('_raw_gen' or '_raw_gen_stream')
@@ -73,10 +75,10 @@ class BaseLLM(ABC):
             return decorated_method()
         except Exception as e:
             if not self.fallback_llm:
-                logger.error(f"Primary LLM failed and no fallback available: {str(e)}")
+                logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
                 raise
             logger.warning(
-                f"Falling back to {self.fallback_provider}/{self.fallback_model_name}. Error: {str(e)}"
+                f"Primary LLM failed. Falling back to {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}. Error: {str(e)}"
            )
            fallback_method = getattr(
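The lazy-fallback pattern in `BaseLLM` (construct the fallback only on first use, delegate only after the primary fails, re-raise when no fallback is configured) can be sketched without the project's classes; all names below are illustrative:

```python
# Minimal sketch of lazy fallback delegation. Primary/Fallback/WithFallback
# are hypothetical stand-ins for the project's LLM classes.
class Primary:
    def gen(self, prompt):
        raise RuntimeError("rate limited")  # simulate a primary failure


class Fallback:
    def gen(self, prompt):
        return f"fallback: {prompt}"


class WithFallback:
    def __init__(self, primary, fallback_factory=None):
        self.primary = primary
        self._fallback_factory = fallback_factory
        self._fallback = None

    @property
    def fallback(self):
        # Constructed lazily, only when first needed.
        if self._fallback is None and self._fallback_factory:
            self._fallback = self._fallback_factory()
        return self._fallback

    def gen(self, prompt):
        try:
            return self.primary.gen(prompt)
        except Exception:
            if not self.fallback:
                raise  # no fallback configured: surface the original error
            return self.fallback.gen(prompt)
```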


@@ -1,5 +1,7 @@
 import json

+from openai import OpenAI
+
 from application.core.settings import settings
 from application.llm.base import BaseLLM
@@ -7,12 +9,11 @@ from application.llm.base import BaseLLM
 class DocsGPTAPILLM(BaseLLM):

     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
-        from openai import OpenAI
         super().__init__(*args, **kwargs)
-        self.client = OpenAI(api_key="sk-docsgpt-public", base_url="https://oai.arc53.com")
+        self.api_key = "sk-docsgpt-public"
+        self.client = OpenAI(api_key=self.api_key, base_url="https://oai.arc53.com")
         self.user_api_key = user_api_key
-        self.api_key = api_key

     def _clean_messages_openai(self, messages):
         cleaned_messages = []
@@ -22,7 +23,6 @@ class DocsGPTAPILLM(BaseLLM):
             if role == "model":
                 role = "assistant"
             if role and content is not None:
                 if isinstance(content, str):
                     cleaned_messages.append({"role": role, "content": content})
@@ -69,7 +69,6 @@ class DocsGPTAPILLM(BaseLLM):
                 )
             else:
                 raise ValueError(f"Unexpected content type: {type(content)}")
         return cleaned_messages

     def _raw_gen(
@@ -121,7 +120,6 @@ class DocsGPTAPILLM(BaseLLM):
         response = self.client.chat.completions.create(
             model="docsgpt", messages=messages, stream=stream, **kwargs
         )
         try:
             for line in response:
                 if (
@@ -133,8 +131,8 @@ class DocsGPTAPILLM(BaseLLM):
                 elif len(line.choices) > 0:
                     yield line.choices[0]
         finally:
-            if hasattr(response, 'close'):
+            if hasattr(response, "close"):
                 response.close()

     def _supports_tools(self):
         return True


@@ -13,8 +13,9 @@ from application.storage.storage_creator import StorageCreator
 class GoogleLLM(BaseLLM):
     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.api_key = api_key
+        self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
         self.user_api_key = user_api_key
         self.client = genai.Client(api_key=self.api_key)
         self.storage = StorageCreator.get_storage()
@@ -47,21 +48,19 @@ class GoogleLLM(BaseLLM):
         """
         if not attachments:
             return messages
         prepared_messages = messages.copy()
         # Find the user message to attach files to the last one
         user_message_index = None
         for i in range(len(prepared_messages) - 1, -1, -1):
             if prepared_messages[i].get("role") == "user":
                 user_message_index = i
                 break
         if user_message_index is None:
             user_message = {"role": "user", "content": []}
             prepared_messages.append(user_message)
             user_message_index = len(prepared_messages) - 1
         if isinstance(prepared_messages[user_message_index].get("content"), str):
             text_content = prepared_messages[user_message_index]["content"]
             prepared_messages[user_message_index]["content"] = [
@@ -69,7 +68,6 @@ class GoogleLLM(BaseLLM):
             ]
         elif not isinstance(prepared_messages[user_message_index].get("content"), list):
             prepared_messages[user_message_index]["content"] = []
         files = []
         for attachment in attachments:
             mime_type = attachment.get("mime_type")
@@ -92,11 +90,9 @@ class GoogleLLM(BaseLLM):
                         "text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
                     }
                 )
         if files:
             logging.info(f"GoogleLLM: Adding {len(files)} files to message")
             prepared_messages[user_message_index]["content"].append({"files": files})
         return prepared_messages

     def _upload_file_to_google(self, attachment):
@@ -111,14 +107,11 @@ class GoogleLLM(BaseLLM):
         """
         if "google_file_uri" in attachment:
             return attachment["google_file_uri"]
         file_path = attachment.get("path")
         if not file_path:
             raise ValueError("No file path provided in attachment")
         if not self.storage.file_exists(file_path):
             raise FileNotFoundError(f"File not found: {file_path}")
         try:
             file_uri = self.storage.process_file(
                 file_path,
@@ -136,7 +129,6 @@ class GoogleLLM(BaseLLM):
             attachments_collection.update_one(
                 {"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
             )
             return file_uri
         except Exception as e:
             logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
@@ -153,7 +145,6 @@ class GoogleLLM(BaseLLM):
                 role = "model"
             elif role == "tool":
                 role = "model"
             parts = []
             if role and content is not None:
                 if isinstance(content, str):
@@ -164,6 +155,7 @@ class GoogleLLM(BaseLLM):
                         parts.append(types.Part.from_text(text=item["text"]))
                     elif "function_call" in item:
                         # Remove null values from args to avoid API errors
                         cleaned_args = self._remove_null_values(
                             item["function_call"]["args"]
                         )
@@ -194,10 +186,8 @@ class GoogleLLM(BaseLLM):
                     )
                 else:
                     raise ValueError(f"Unexpected content type: {type(content)}")
             if parts:
                 cleaned_messages.append(types.Content(role=role, parts=parts))
         return cleaned_messages

     def _clean_schema(self, schema_obj):
@@ -233,8 +223,8 @@ class GoogleLLM(BaseLLM):
                 cleaned[key] = [self._clean_schema(item) for item in value]
             else:
                 cleaned[key] = value
         # Validate that required properties actually exist in properties
         if "required" in cleaned and "properties" in cleaned:
             valid_required = []
             properties_keys = set(cleaned["properties"].keys())
@@ -247,7 +237,6 @@ class GoogleLLM(BaseLLM):
                 cleaned.pop("required", None)
         elif "required" in cleaned and "properties" not in cleaned:
             cleaned.pop("required", None)
         return cleaned

     def _clean_tools_format(self, tools_list):
@@ -263,7 +252,6 @@ class GoogleLLM(BaseLLM):
                 cleaned_properties = {}
                 for k, v in properties.items():
                     cleaned_properties[k] = self._clean_schema(v)
                 genai_function = dict(
                     name=function["name"],
                     description=function["description"],
@@ -282,10 +270,8 @@ class GoogleLLM(BaseLLM):
                     name=function["name"],
                     description=function["description"],
                 )
             genai_tool = types.Tool(function_declarations=[genai_function])
             genai_tools.append(genai_tool)
         return genai_tools

     def _raw_gen(
@@ -307,16 +293,14 @@ class GoogleLLM(BaseLLM):
         if messages[0].role == "system":
             config.system_instruction = messages[0].parts[0].text
             messages = messages[1:]
         if tools:
             cleaned_tools = self._clean_tools_format(tools)
             config.tools = cleaned_tools
         # Add response schema for structured output if provided
         if response_schema:
             config.response_schema = response_schema
             config.response_mime_type = "application/json"
         response = client.models.generate_content(
             model=model,
             contents=messages,
@@ -347,17 +331,16 @@ class GoogleLLM(BaseLLM):
         if messages[0].role == "system":
             config.system_instruction = messages[0].parts[0].text
             messages = messages[1:]
         if tools:
             cleaned_tools = self._clean_tools_format(tools)
             config.tools = cleaned_tools
         # Add response schema for structured output if provided
         if response_schema:
             config.response_schema = response_schema
             config.response_mime_type = "application/json"
         # Check if we have both tools and file attachments
         has_attachments = False
         for message in messages:
             for part in message.parts:
@@ -366,7 +349,6 @@ class GoogleLLM(BaseLLM):
                     break
             if has_attachments:
                 break
         logging.info(
             f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
         )
@@ -405,7 +387,6 @@ class GoogleLLM(BaseLLM):
         """Convert JSON schema to Google AI structured output format."""
         if not json_schema:
             return None
         type_map = {
             "object": "OBJECT",
             "array": "ARRAY",
@@ -418,12 +399,10 @@ class GoogleLLM(BaseLLM):
         def convert(schema):
             if not isinstance(schema, dict):
                 return schema
             result = {}
             schema_type = schema.get("type")
             if schema_type:
                 result["type"] = type_map.get(schema_type.lower(), schema_type.upper())
             for key in [
                 "description",
                 "nullable",
@@ -435,7 +414,6 @@ class GoogleLLM(BaseLLM):
             ]:
                 if key in schema:
                     result[key] = schema[key]
             if "format" in schema:
                 format_value = schema["format"]
                 if schema_type == "string":
@@ -445,21 +423,17 @@ class GoogleLLM(BaseLLM):
                         result["format"] = format_value
                 else:
                     result["format"] = format_value
             if "properties" in schema:
                 result["properties"] = {
                     k: convert(v) for k, v in schema["properties"].items()
                 }
                 if "propertyOrdering" not in result and result.get("type") == "OBJECT":
                     result["propertyOrdering"] = list(result["properties"].keys())
             if "items" in schema:
                 result["items"] = convert(schema["items"])
             for field in ["anyOf", "oneOf", "allOf"]:
                 if field in schema:
                     result[field] = [convert(s) for s in schema[field]]
             return result

         try:
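A condensed, hypothetical version of the `convert` helper above, showing just the type-name mapping and the recursion into `properties` and `items` (the full method also carries over `description`, `format`, `anyOf`, and friends):

```python
# Illustrative sketch of JSON-schema -> Google AI structured-output conversion.
TYPE_MAP = {
    "object": "OBJECT",
    "array": "ARRAY",
    "string": "STRING",
    "integer": "INTEGER",
    "number": "NUMBER",
    "boolean": "BOOLEAN",
}


def convert(schema):
    if not isinstance(schema, dict):
        return schema
    out = {}
    if "type" in schema:
        t = schema["type"]
        # Known lowercase types map to the uppercase enum; unknown
        # types are uppercased as a best effort.
        out["type"] = TYPE_MAP.get(t.lower(), t.upper())
    if "properties" in schema:
        out["properties"] = {k: convert(v) for k, v in schema["properties"].items()}
        out["propertyOrdering"] = list(out["properties"].keys())
    if "items" in schema:
        out["items"] = convert(schema["items"])
    return out


result = convert({"type": "object", "properties": {"name": {"type": "string"}}})
```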


@@ -1,13 +1,18 @@
 from openai import OpenAI
+
+from application.core.settings import settings
-from application.llm.base import BaseLLM
+from application.llm.base import BaseLLM


 class GroqLLM(BaseLLM):

     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
-        self.api_key = api_key
+        self.api_key = api_key or settings.GROQ_API_KEY or settings.API_KEY
         self.user_api_key = user_api_key
+        self.client = OpenAI(
+            api_key=self.api_key, base_url="https://api.groq.com/openai/v1"
+        )

     def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
         if tools:


@@ -282,7 +282,7 @@ class LLMHandler(ABC):
                     messages = e.value
                     break
         response = agent.llm.gen(
-            model=agent.gpt_model, messages=messages, tools=agent.tools
+            model=agent.model_id, messages=messages, tools=agent.tools
         )
         parsed = self.parse_response(response)
         self.llm_calls.append(build_stack_data(agent.llm))
@@ -337,7 +337,7 @@ class LLMHandler(ABC):
         tool_calls = {}
         response = agent.llm.gen_stream(
-            model=agent.gpt_model, messages=messages, tools=agent.tools
+            model=agent.model_id, messages=messages, tools=agent.tools
         )
         self.llm_calls.append(build_stack_data(agent.llm))


@@ -1,13 +1,17 @@
+import logging
+
-from application.llm.groq import GroqLLM
-from application.llm.openai import OpenAILLM, AzureOpenAILLM
-from application.llm.sagemaker import SagemakerAPILLM
-from application.llm.huggingface import HuggingFaceLLM
-from application.llm.llama_cpp import LlamaCpp
 from application.llm.anthropic import AnthropicLLM
 from application.llm.docsgpt_provider import DocsGPTAPILLM
-from application.llm.premai import PremAILLM
 from application.llm.google_ai import GoogleLLM
+from application.llm.groq import GroqLLM
+from application.llm.huggingface import HuggingFaceLLM
+from application.llm.llama_cpp import LlamaCpp
 from application.llm.novita import NovitaLLM
+from application.llm.openai import AzureOpenAILLM, OpenAILLM
+from application.llm.premai import PremAILLM
+from application.llm.sagemaker import SagemakerAPILLM
+
+logger = logging.getLogger(__name__)


 class LLMCreator:
@@ -26,10 +30,26 @@ class LLMCreator:
     }

     @classmethod
-    def create_llm(cls, type, api_key, user_api_key, decoded_token, *args, **kwargs):
+    def create_llm(
+        cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
+    ):
+        from application.core.model_utils import get_base_url_for_model
+
         llm_class = cls.llms.get(type.lower())
         if not llm_class:
             raise ValueError(f"No LLM class found for type {type}")
+
+        # Extract base_url from model configuration if model_id is provided
+        base_url = None
+        if model_id:
+            base_url = get_base_url_for_model(model_id)
+
         return llm_class(
-            api_key, user_api_key, decoded_token=decoded_token, *args, **kwargs
+            api_key,
+            user_api_key,
+            decoded_token=decoded_token,
+            model_id=model_id,
+            base_url=base_url,
+            *args,
+            **kwargs,
        )
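A toy version of the factory pattern in `LLMCreator.create_llm` (all names below are illustrative): resolve the class from a registry keyed by lowercase provider name, look up a per-model `base_url` override, and raise on unknown providers.

```python
# Hypothetical stand-ins for the project's LLM classes and registries.
class EchoLLM:
    def __init__(self, api_key, model_id=None, base_url=None):
        self.api_key = api_key
        self.model_id = model_id
        self.base_url = base_url


REGISTRY = {"echo": EchoLLM}
BASE_URLS = {"echo-large": "https://echo.example/v1"}  # per-model overrides


def create_llm(provider: str, api_key: str, model_id=None):
    cls = REGISTRY.get(provider.lower())  # case-insensitive lookup
    if not cls:
        raise ValueError(f"No LLM class found for type {provider}")
    # Enrich construction with model-level configuration when available.
    return cls(api_key, model_id=model_id, base_url=BASE_URLS.get(model_id))
```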


@@ -2,6 +2,8 @@ import base64
 import json
 import logging
+
+from openai import OpenAI

 from application.core.settings import settings
 from application.llm.base import BaseLLM
 from application.storage.storage_creator import StorageCreator
@@ -9,20 +11,25 @@ from application.storage.storage_creator import StorageCreator
 class OpenAILLM(BaseLLM):
-    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
-        from openai import OpenAI
-
+    def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        if (
+        self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
+        self.user_api_key = user_api_key
+
+        # Priority: 1) Parameter base_url, 2) Settings OPENAI_BASE_URL, 3) Default
+        effective_base_url = None
+        if base_url and isinstance(base_url, str) and base_url.strip():
+            effective_base_url = base_url
+        elif (
             isinstance(settings.OPENAI_BASE_URL, str)
             and settings.OPENAI_BASE_URL.strip()
         ):
-            self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
+            effective_base_url = settings.OPENAI_BASE_URL
         else:
-            DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
-            self.client = OpenAI(api_key=api_key, base_url=DEFAULT_OPENAI_API_BASE)
-        self.api_key = api_key
-        self.user_api_key = user_api_key
+            effective_base_url = "https://api.openai.com/v1"
+
+        self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
         self.storage = StorageCreator.get_storage()
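The rewritten constructor resolves the client's base URL with a three-step priority: explicit parameter, then the `OPENAI_BASE_URL` setting, then the public default. Pulled out of the class, the resolution logic is roughly (a sketch; the settings lookup is replaced by a plain argument):

```python
DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"


def resolve_base_url(param_base_url, settings_base_url):
    """Priority: explicit parameter, then settings value, then public default."""
    if param_base_url and isinstance(param_base_url, str) and param_base_url.strip():
        return param_base_url
    if isinstance(settings_base_url, str) and settings_base_url.strip():
        return settings_base_url
    return DEFAULT_OPENAI_API_BASE
```

Note that a whitespace-only value at either level falls through to the next one, which is what the `strip()` checks in the hunk guard against.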
     def _clean_messages_openai(self, messages):
@@ -33,7 +40,6 @@ class OpenAILLM(BaseLLM):
             if role == "model":
                 role = "assistant"
             if role and content is not None:
                 if isinstance(content, str):
                     cleaned_messages.append({"role": role, "content": content})
@@ -107,7 +113,6 @@ class OpenAILLM(BaseLLM):
                 )
             else:
                 raise ValueError(f"Unexpected content type: {type(content)}")
         return cleaned_messages
     def _raw_gen(
@@ -132,10 +137,8 @@ class OpenAILLM(BaseLLM):
         if tools:
             request_params["tools"] = tools
         if response_format:
             request_params["response_format"] = response_format
         response = self.client.chat.completions.create(**request_params)
         if tools:
@@ -165,10 +168,8 @@ class OpenAILLM(BaseLLM):
         if tools:
             request_params["tools"] = tools
         if response_format:
             request_params["response_format"] = response_format
         response = self.client.chat.completions.create(**request_params)
         try:
@@ -194,7 +195,6 @@ class OpenAILLM(BaseLLM):
     def prepare_structured_output_format(self, json_schema):
         if not json_schema:
             return None
         try:

             def add_additional_properties_false(schema_obj):
@@ -204,11 +204,11 @@ class OpenAILLM(BaseLLM):
                 if schema_copy.get("type") == "object":
                     schema_copy["additionalProperties"] = False
                     # Ensure 'required' includes all properties for OpenAI strict mode
                     if "properties" in schema_copy:
                         schema_copy["required"] = list(
                             schema_copy["properties"].keys()
                         )
                 for key, value in schema_copy.items():
                     if key == "properties" and isinstance(value, dict):
                         schema_copy[key] = {
@@ -224,7 +224,6 @@ class OpenAILLM(BaseLLM):
                             add_additional_properties_false(sub_schema)
                             for sub_schema in value
                         ]
                 return schema_copy
             return schema_obj
@@ -243,7 +242,6 @@ class OpenAILLM(BaseLLM):
             }
             return result
         except Exception as e:
             logging.error(f"Error preparing structured output format: {e}")
             return None
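The strict-mode normalization above walks the schema recursively, pinning every object node. A standalone sketch of the same idea (simplified to plain dicts, not the project's exact helper):

```python
import copy


def add_additional_properties_false(schema_obj):
    """Recursively pin objects for OpenAI strict mode: forbid extra keys
    and require every declared property."""
    if not isinstance(schema_obj, dict):
        return schema_obj
    schema_copy = copy.deepcopy(schema_obj)
    if schema_copy.get("type") == "object":
        schema_copy["additionalProperties"] = False
        # Strict mode requires 'required' to list every property.
        if "properties" in schema_copy:
            schema_copy["required"] = list(schema_copy["properties"].keys())
    for key, value in schema_copy.items():
        if key == "properties" and isinstance(value, dict):
            schema_copy[key] = {
                k: add_additional_properties_false(v) for k, v in value.items()
            }
        elif isinstance(value, dict):
            schema_copy[key] = add_additional_properties_false(value)
        elif isinstance(value, list):
            schema_copy[key] = [add_additional_properties_false(v) for v in value]
    return schema_copy


schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "meta": {"type": "object", "properties": {"id": {"type": "integer"}}},
    },
}
out = add_additional_properties_false(schema)
```

Nested objects (here `meta`) get the same `additionalProperties: false` and full `required` list as the root.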
@@ -277,21 +275,19 @@
         """
         if not attachments:
             return messages
         prepared_messages = messages.copy()
         # Find the user message to attach file_id to the last one
         user_message_index = None
         for i in range(len(prepared_messages) - 1, -1, -1):
             if prepared_messages[i].get("role") == "user":
                 user_message_index = i
                 break
         if user_message_index is None:
             user_message = {"role": "user", "content": []}
             prepared_messages.append(user_message)
             user_message_index = len(prepared_messages) - 1
         if isinstance(prepared_messages[user_message_index].get("content"), str):
             text_content = prepared_messages[user_message_index]["content"]
             prepared_messages[user_message_index]["content"] = [
@@ -299,7 +295,6 @@ class OpenAILLM(BaseLLM):
             ]
         elif not isinstance(prepared_messages[user_message_index].get("content"), list):
             prepared_messages[user_message_index]["content"] = []
         for attachment in attachments:
             mime_type = attachment.get("mime_type")
@@ -326,6 +321,7 @@ class OpenAILLM(BaseLLM):
                     }
                 )
+            # Handle PDFs using the file API
             elif mime_type == "application/pdf":
                 try:
                     file_id = self._upload_file_to_openai(attachment)
@@ -341,7 +337,6 @@ class OpenAILLM(BaseLLM):
                             "text": f"File content:\n\n{attachment['content']}",
                         }
                     )
         return prepared_messages
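The attachment hunk above locates the last user message and normalizes its content into a parts list before appending attachments. The core of that normalization, sketched on plain dicts (helper name is illustrative, not from the codebase):

```python
def ensure_user_parts(messages):
    """Return (messages_copy, index) where index points at the last user
    message and its content is guaranteed to be a list of parts."""
    prepared = [dict(m) for m in messages]
    user_idx = None
    # Scan backwards so attachments land on the most recent user turn.
    for i in range(len(prepared) - 1, -1, -1):
        if prepared[i].get("role") == "user":
            user_idx = i
            break
    if user_idx is None:
        prepared.append({"role": "user", "content": []})
        user_idx = len(prepared) - 1
    content = prepared[user_idx].get("content")
    if isinstance(content, str):
        # Plain-string content becomes a single text part.
        prepared[user_idx]["content"] = [{"type": "text", "text": content}]
    elif not isinstance(content, list):
        prepared[user_idx]["content"] = []
    return prepared, user_idx
```

After this step, attachments can be appended to `content` uniformly regardless of how the message originally arrived.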
     def _get_base64_image(self, attachment):
@@ -357,7 +352,6 @@ class OpenAILLM(BaseLLM):
         file_path = attachment.get("path")
         if not file_path:
             raise ValueError("No file path provided in attachment")
         try:
             with self.storage.get_file(file_path) as image_file:
                 return base64.b64encode(image_file.read()).decode("utf-8")
@@ -381,12 +375,10 @@ class OpenAILLM(BaseLLM):
         if "openai_file_id" in attachment:
             return attachment["openai_file_id"]
         file_path = attachment.get("path")
         if not self.storage.file_exists(file_path):
             raise FileNotFoundError(f"File not found: {file_path}")
         try:
             file_id = self.storage.process_file(
                 file_path,
@@ -404,7 +396,6 @@ class OpenAILLM(BaseLLM):
             attachments_collection.update_one(
                 {"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
             )
             return file_id
         except Exception as e:
             logging.error(f"Error uploading file to OpenAI: {e}", exc_info=True)

View File

@@ -16,7 +16,7 @@ class ClassicRAG(BaseRetriever):
         prompt="",
         chunks=2,
         doc_token_limit=50000,
-        gpt_model="docsgpt",
+        model_id="docsgpt-local",
         user_api_key=None,
         llm_name=settings.LLM_PROVIDER,
         api_key=settings.API_KEY,
@@ -40,7 +40,7 @@ class ClassicRAG(BaseRetriever):
             f"ClassicRAG initialized with chunks={self.chunks}, user_api_key={user_identifier}, "
             f"sources={'active_docs' in source and source['active_docs'] is not None}"
         )
-        self.gpt_model = gpt_model
+        self.model_id = model_id
         self.doc_token_limit = doc_token_limit
         self.user_api_key = user_api_key
         self.llm_name = llm_name
@@ -100,7 +100,7 @@ class ClassicRAG(BaseRetriever):
         ]
         try:
-            rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages)
+            rephrased_query = self.llm.gen(model=self.model_id, messages=messages)
             print(f"Rephrased query: {rephrased_query}")
             return rephrased_query if rephrased_query else self.original_question
         except Exception as e:

View File

@@ -7,6 +7,8 @@ import tiktoken
 from flask import jsonify, make_response
 from werkzeug.utils import secure_filename

+from application.core.model_utils import get_token_limit
 from application.core.settings import settings
@@ -75,11 +77,9 @@ def count_tokens_docs(docs):
 def calculate_doc_token_budget(
-    gpt_model: str = "gpt-4o", history_token_limit: int = 2000
+    model_id: str = "gpt-4o", history_token_limit: int = 2000
 ) -> int:
-    total_context = settings.LLM_TOKEN_LIMITS.get(
-        gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT
-    )
+    total_context = get_token_limit(model_id)
     reserved = sum(settings.RESERVED_TOKENS.values())
     doc_budget = total_context - history_token_limit - reserved
     return max(doc_budget, 1000)
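The budget arithmetic above is easy to check in isolation: whatever remains of the model context after the history allowance and reserved tokens goes to documents, with a 1,000-token floor. A sketch with the registry lookup replaced by a plain argument (the reserved figure here is illustrative, not the project's default):

```python
def calculate_doc_token_budget(total_context, history_token_limit=2000, reserved=3000):
    # Context left after history and reserved tokens goes to documents,
    # but never less than a 1,000-token floor.
    doc_budget = total_context - history_token_limit - reserved
    return max(doc_budget, 1000)
```

The floor matters for small-context models, where the subtraction can go negative.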
@@ -144,16 +144,13 @@ def get_hash(data):
     return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()

-def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
+def limit_chat_history(history, max_token_limit=None, model_id="docsgpt-local"):
     """Limit chat history to fit within token limit."""
-    from application.core.settings import settings
-
+    model_token_limit = get_token_limit(model_id)
     max_token_limit = (
         max_token_limit
-        if max_token_limit
-        and max_token_limit
-        < settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT)
-        else settings.LLM_TOKEN_LIMITS.get(gpt_model, settings.DEFAULT_LLM_TOKEN_LIMIT)
+        if max_token_limit and max_token_limit < model_token_limit
+        else model_token_limit
     )

     if not history:
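The rewritten cap logic honors a caller-supplied limit only when it is set and stricter than the model's own limit; otherwise the model limit wins. In isolation (with the registry lookup replaced by an argument; the helper name is illustrative):

```python
def effective_history_limit(requested, model_token_limit):
    # A requested limit applies only when set and below the model's limit;
    # otherwise fall back to the model limit itself.
    return (
        requested
        if requested and requested < model_token_limit
        else model_token_limit
    )
```

This prevents a misconfigured or oversized request limit from exceeding what the model can actually accept.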
@@ -205,37 +202,44 @@ def clean_text_for_tts(text: str) -> str:
     clean text for Text-to-Speech processing.
     """
     # Handle code blocks and links
     text = re.sub(r"```mermaid[\s\S]*?```", " flowchart, ", text)  ## ```mermaid...```
     text = re.sub(r"```[\s\S]*?```", " code block, ", text)  ## ```code```
     text = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text)  ## [text](url)
     text = re.sub(r"!\[([^\]]*)\]\([^\)]+\)", "", text)  ## ![alt](url)

     # Remove markdown formatting
     text = re.sub(r"`([^`]+)`", r"\1", text)  ## `code`
     text = re.sub(r"\{([^}]*)\}", r" \1 ", text)  ## {text}
     text = re.sub(r"[{}]", " ", text)  ## unmatched {}
     text = re.sub(r"\[([^\]]+)\]", r" \1 ", text)  ## [text]
     text = re.sub(r"[\[\]]", " ", text)  ## unmatched []
     text = re.sub(r"(\*\*|__)(.*?)\1", r"\2", text)  ## **bold** __bold__
     text = re.sub(r"(\*|_)(.*?)\1", r"\2", text)  ## *italic* _italic_
     text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)  ## # headers
     text = re.sub(r"^>\s+", "", text, flags=re.MULTILINE)  ## > blockquotes
     text = re.sub(r"^[\s]*[-\*\+]\s+", "", text, flags=re.MULTILINE)  ## - * + lists
     text = re.sub(r"^[\s]*\d+\.\s+", "", text, flags=re.MULTILINE)  ## 1. numbered lists
     text = re.sub(
         r"^[\*\-_]{3,}\s*$", "", text, flags=re.MULTILINE
     )  ## --- *** ___ rules
     text = re.sub(r"<[^>]*>", "", text)  ## <html> tags

     # Remove non-ASCII (emojis, special Unicode)
     text = re.sub(r"[^\x20-\x7E\n\r\t]", "", text)

     # Replace special sequences
     text = re.sub(r"-->", ", ", text)  ## -->
     text = re.sub(r"<--", ", ", text)  ## <--
     text = re.sub(r"=>", ", ", text)  ## =>
     text = re.sub(r"::", " ", text)  ## ::

     # Normalize whitespace
     text = re.sub(r"\s+", " ", text)
     text = text.strip()
     return text
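A quick pass over a markdown snippet shows the intended effect of the pipeline above. This mini version re-implements only a few of the substitutions (code fences, links, bold, whitespace) to illustrate the ordering:

```python
import re


def clean_text_for_tts_mini(text):
    # Reduced version of the TTS-cleaning pipeline: fences first, then
    # links, then bold, then whitespace normalization.
    text = re.sub(r"```[\s\S]*?```", " code block, ", text)
    text = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text)
    text = re.sub(r"(\*\*|__)(.*?)\1", r"\2", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()
```

Fences must be replaced before the inline-`code` and bracket rules run, or their contents would be partially unwrapped instead of dropped.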

View File

@@ -165,7 +165,7 @@ def run_agent_logic(agent_config, input_data):
         agent_type,
         endpoint="webhook",
         llm_name=settings.LLM_PROVIDER,
-        gpt_model=settings.LLM_NAME,
+        model_id=settings.LLM_NAME,
         api_key=settings.API_KEY,
         user_api_key=user_api_key,
         prompt=prompt,
@@ -180,7 +180,7 @@ def run_agent_logic(agent_config, input_data):
         prompt=prompt,
         chunks=chunks,
         token_limit=settings.DEFAULT_MAX_HISTORY,
-        gpt_model=settings.LLM_NAME,
+        model_id=settings.LLM_NAME,
         user_api_key=user_api_key,
         decoded_token=decoded_token,
     )

View File

@@ -57,7 +57,7 @@ The easiest way to launch DocsGPT is using the provided `setup.sh` script. This
 * **4) Connect Cloud API Provider:** This option lets you connect DocsGPT to a commercial Cloud API provider such as OpenAI, Google (Vertex AI/Gemini), Anthropic (Claude), Groq, HuggingFace Inference API, or Azure OpenAI. You will need an API key from your chosen provider. Select this if you prefer to use a powerful cloud-based LLM.
-* **5) Modify DocsGPT's source code and rebuild the Docker images locally. Instead of pulling prebuilt images from Docker Hub or using the hosted/public API, you build the entire backend and frontend from source, customizing how DocsGPT works internally, or run it in an environment without internet access.
+* **5) Modify DocsGPT's source code and rebuild the Docker images locally.** Instead of pulling prebuilt images from Docker Hub or using the hosted/public API, you build the entire backend and frontend from source, customizing how DocsGPT works internally, or run it in an environment without internet access.

 After selecting an option and providing any required information (like API keys or model names), the script will configure your `.env` file and start DocsGPT using Docker Compose.

@@ -119,4 +119,4 @@ If you prefer a more manual approach, you can follow our [Docker Deployment docu
 For more advanced customization of DocsGPT settings, such as configuring vector stores, embedding models, and other parameters, please refer to the [DocsGPT Settings documentation](/Deploying/DocsGPT-Settings). This guide explains how to modify the `.env` file or `settings.py` for deeper configuration.

 Enjoy using DocsGPT!

View File

@@ -1,6 +1,6 @@
 aiohttp>=3,<4
 certifi==2024.7.4
-h11==0.14.0
+h11==0.16.0
 httpcore==1.0.5
 httpx==0.27.0
 idna==3.7

View File

@@ -3,4 +3,4 @@ VITE_BASE_URL=http://localhost:5173
 VITE_API_HOST=http://127.0.0.1:7091
 VITE_API_STREAMING=true
 VITE_NOTIFICATION_TEXT="What's new in 0.14.0 — Changelog"
-VITE_NOTIFICATION_LINK="#"
+VITE_NOTIFICATION_LINK="https://blog.docsgpt.cloud/docsgpt-0-14-agents-automate-integrate-and-innovate/"

View File

@@ -1,17 +0,0 @@
node_modules/
dist/
prettier.config.cjs
.eslintrc.cjs
env.d.ts
public/
assets/
vite-env.d.ts
.prettierignore
package-lock.json
package.json
postcss.config.cjs
prettier.config.cjs
tailwind.config.cjs
tsconfig.json
tsconfig.node.json
vite.config.ts

View File

@@ -1,45 +0,0 @@
module.exports = {
env: {
browser: true,
es2021: true,
node: true,
},
extends: [
'eslint:recommended',
'plugin:@typescript-eslint/recommended',
'plugin:react/recommended',
'plugin:prettier/recommended',
],
overrides: [],
parser: '@typescript-eslint/parser',
parserOptions: {
ecmaVersion: 'latest',
sourceType: 'module',
},
plugins: ['react', 'unused-imports'],
rules: {
'react/prop-types': 'off',
'unused-imports/no-unused-imports': 'error',
'react/react-in-jsx-scope': 'off',
'prettier/prettier': [
'error',
{
endOfLine: 'auto',
},
],
},
settings: {
'import/parsers': {
'@typescript-eslint/parser': ['.ts', '.tsx'],
},
react: {
version: 'detect',
},
'import/resolver': {
node: {
paths: ['src'],
extensions: ['.js', '.jsx', '.ts', '.tsx'],
},
},
},
};

frontend/eslint.config.js Normal file
View File

@@ -0,0 +1,78 @@
import js from '@eslint/js'
import tsParser from '@typescript-eslint/parser'
import tsPlugin from '@typescript-eslint/eslint-plugin'
import react from 'eslint-plugin-react'
import unusedImports from 'eslint-plugin-unused-imports'
import prettier from 'eslint-plugin-prettier'
import globals from 'globals'
export default [
{
ignores: [
'node_modules/',
'dist/',
'prettier.config.cjs',
'.eslintrc.cjs',
'env.d.ts',
'public/',
'assets/',
'vite-env.d.ts',
'.prettierignore',
'package-lock.json',
'package.json',
'postcss.config.cjs',
'tailwind.config.cjs',
'tsconfig.json',
'tsconfig.node.json',
'vite.config.ts',
],
},
{
files: ['**/*.{js,jsx,ts,tsx}'],
languageOptions: {
ecmaVersion: 'latest',
sourceType: 'module',
parser: tsParser,
parserOptions: {
ecmaFeatures: {
jsx: true,
},
},
globals: {
...globals.browser,
...globals.es2021,
...globals.node,
},
},
plugins: {
'@typescript-eslint': tsPlugin,
react,
'unused-imports': unusedImports,
prettier,
},
rules: {
...js.configs.recommended.rules,
...tsPlugin.configs.recommended.rules,
...react.configs.recommended.rules,
...prettier.configs.recommended.rules,
'react/prop-types': 'off',
'unused-imports/no-unused-imports': 'error',
'react/react-in-jsx-scope': 'off',
'no-undef': 'off',
'@typescript-eslint/no-explicit-any': 'warn',
'@typescript-eslint/no-unused-vars': 'warn',
'@typescript-eslint/no-unused-expressions': 'warn',
'prettier/prettier': [
'error',
{
endOfLine: 'auto',
},
],
},
settings: {
react: {
version: 'detect',
},
},
},
]

File diff suppressed because it is too large

View File

@@ -19,7 +19,7 @@
     ]
   },
   "dependencies": {
-    "@reduxjs/toolkit": "^2.8.2",
+    "@reduxjs/toolkit": "^2.10.1",
     "chart.js": "^4.4.4",
     "clsx": "^2.1.1",
     "copy-to-clipboard": "^3.3.3",
@@ -33,7 +33,7 @@
     "react-dom": "^19.1.1",
     "react-dropzone": "^14.3.8",
     "react-google-drive-picker": "^1.2.2",
-    "react-i18next": "^15.4.0",
+    "react-i18next": "^16.2.4",
     "react-markdown": "^9.0.1",
     "react-redux": "^9.2.0",
     "react-router-dom": "^7.6.1",
@@ -46,30 +46,28 @@
   "devDependencies": {
     "@tailwindcss/postcss": "^4.1.10",
     "@types/lodash": "^4.17.20",
-    "@types/mermaid": "^9.1.0",
     "@types/react": "^19.1.8",
     "@types/react-dom": "^19.1.7",
     "@types/react-syntax-highlighter": "^15.5.13",
-    "@typescript-eslint/eslint-plugin": "^6.21.0",
+    "@typescript-eslint/eslint-plugin": "^8.46.3",
-    "@typescript-eslint/parser": "^6.21.0",
+    "@typescript-eslint/parser": "^8.46.3",
     "@vitejs/plugin-react": "^4.3.4",
-    "eslint": "^8.57.1",
+    "eslint": "^9.39.1",
     "eslint-config-prettier": "^10.1.5",
-    "eslint-config-standard-with-typescript": "^43.0.1",
     "eslint-plugin-import": "^2.31.0",
-    "eslint-plugin-n": "^15.7.0",
+    "eslint-plugin-n": "^17.23.1",
     "eslint-plugin-prettier": "^5.5.4",
     "eslint-plugin-promise": "^6.6.0",
     "eslint-plugin-react": "^7.37.5",
     "eslint-plugin-unused-imports": "^4.1.4",
-    "husky": "^8.0.0",
+    "husky": "^9.1.7",
     "lint-staged": "^15.3.0",
     "postcss": "^8.4.49",
     "prettier": "^3.5.3",
-    "prettier-plugin-tailwindcss": "^0.6.13",
+    "prettier-plugin-tailwindcss": "^0.7.1",
     "tailwindcss": "^4.1.11",
     "typescript": "^5.8.3",
-    "vite": "^6.3.5",
+    "vite": "^7.2.0",
     "vite-plugin-svgr": "^4.3.0"
   }
 }

View File

@@ -1,6 +1,8 @@
-import DocsGPT3 from './assets/cute_docsgpt3.svg';
 import { useTranslation } from 'react-i18next';
+
+import DocsGPT3 from './assets/cute_docsgpt3.svg';
+import DropdownModel from './components/DropdownModel';

 export default function Hero({
   handleQuestion,
 }: {
@@ -26,6 +28,10 @@ export default function Hero({
           <span className="text-4xl font-semibold">DocsGPT</span>
           <img className="mb-1 inline w-14" src={DocsGPT3} alt="docsgpt" />
         </div>
+        {/* Model Selector */}
+        <div className="relative w-72">
+          <DropdownModel />
+        </div>
       </div>

       {/* Demo Buttons Section */}
@@ -38,7 +44,7 @@ export default function Hero({
         <button
           key={key}
           onClick={() => handleQuestion({ question: demo.query })}
-          className={`border-dark-gray text-just-black hover:bg-cultured dark:border-dim-gray dark:text-chinese-white dark:hover:bg-charleston-green w-full rounded-[66px] border bg-transparent px-6 py-[14px] text-left transition-colors ${key >= 2 ? 'hidden md:block' : ''} // Show only 2 buttons on mobile`}
+          className={`border-dark-gray text-just-black hover:bg-cultured dark:border-dim-gray dark:text-chinese-white dark:hover:bg-charleston-green w-full rounded-[66px] border bg-transparent px-6 py-[14px] text-left transition-colors ${key >= 2 ? 'hidden md:block' : ''}`}
         >
           <p className="text-black-1000 dark:text-bright-gray mb-2 font-semibold">
             {demo.header}

View File

@@ -567,7 +567,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
         <div className="flex items-center gap-1 pr-4">
           <NavLink
             target="_blank"
-            to={'https://discord.gg/WHJdfbQDR4'}
+            to={'https://discord.gg/vN7YFfdMpj'}
             className={
               'rounded-full hover:bg-gray-100 dark:hover:bg-[#28292E]'
             }

View File

@@ -4,6 +4,7 @@ import { useTranslation } from 'react-i18next';
 import { useDispatch, useSelector } from 'react-redux';
 import { useNavigate, useParams } from 'react-router-dom';

+import modelService from '../api/services/modelService';
 import userService from '../api/services/userService';
 import ArrowLeft from '../assets/arrow-left.svg';
 import SourceIcon from '../assets/source.svg';
@@ -26,6 +27,7 @@ import { UserToolType } from '../settings/types';
 import AgentPreview from './AgentPreview';
 import { Agent, ToolSummary } from './types';
+import type { Model } from '../models/types';

 const embeddingsName =
   import.meta.env.VITE_EMBEDDINGS_NAME ||
   'huggingface_sentence-transformers/all-mpnet-base-v2';
@@ -59,18 +61,25 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
     token_limit: undefined,
     limited_request_mode: false,
     request_limit: undefined,
+    models: [],
+    default_model_id: '',
   });
   const [imageFile, setImageFile] = useState<File | null>(null);
   const [prompts, setPrompts] = useState<
     { name: string; id: string; type: string }[]
   >([]);
   const [userTools, setUserTools] = useState<OptionType[]>([]);
+  const [availableModels, setAvailableModels] = useState<Model[]>([]);
   const [isSourcePopupOpen, setIsSourcePopupOpen] = useState(false);
   const [isToolsPopupOpen, setIsToolsPopupOpen] = useState(false);
+  const [isModelsPopupOpen, setIsModelsPopupOpen] = useState(false);
   const [selectedSourceIds, setSelectedSourceIds] = useState<
     Set<string | number>
   >(new Set());
   const [selectedTools, setSelectedTools] = useState<ToolSummary[]>([]);
+  const [selectedModelIds, setSelectedModelIds] = useState<Set<string>>(
+    new Set(),
+  );
   const [deleteConfirmation, setDeleteConfirmation] =
     useState<ActiveState>('INACTIVE');
   const [agentDetails, setAgentDetails] = useState<ActiveState>('INACTIVE');
@@ -86,6 +95,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
   const initialAgentRef = useRef<Agent | null>(null);
   const sourceAnchorButtonRef = useRef<HTMLButtonElement>(null);
   const toolAnchorButtonRef = useRef<HTMLButtonElement>(null);
+  const modelAnchorButtonRef = useRef<HTMLButtonElement>(null);

   const modeConfig = {
     new: {
@@ -224,6 +234,13 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
       formData.append('json_schema', JSON.stringify(agent.json_schema));
     }
+    if (agent.models && agent.models.length > 0) {
+      formData.append('models', JSON.stringify(agent.models));
+    }
+    if (agent.default_model_id) {
+      formData.append('default_model_id', agent.default_model_id);
+    }
     try {
       setDraftLoading(true);
       const response =
@@ -320,6 +337,13 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
       formData.append('request_limit', '0');
     }
+    if (agent.models && agent.models.length > 0) {
+      formData.append('models', JSON.stringify(agent.models));
+    }
+    if (agent.default_model_id) {
+      formData.append('default_model_id', agent.default_model_id);
+    }
     try {
       setPublishLoading(true);
       const response =
@@ -388,8 +412,16 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
       const data = await response.json();
       setPrompts(data);
     };
+
+    const getModels = async () => {
+      const response = await modelService.getModels(null);
+      if (!response.ok) throw new Error('Failed to fetch models');
+      const data = await response.json();
+      const transformed = modelService.transformModels(data.models || []);
+      setAvailableModels(transformed);
+    };
+
     getTools();
     getPrompts();
+    getModels();
   }, [token]);

   // Auto-select default source if none selected
@@ -462,6 +494,34 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
     }
   }, [agentId, mode, token]);

+  useEffect(() => {
+    if (agent.models && agent.models.length > 0 && availableModels.length > 0) {
+      const agentModelIds = new Set(agent.models);
+      if (agentModelIds.size > 0 && selectedModelIds.size === 0) {
+        setSelectedModelIds(agentModelIds);
+      }
+    }
+  }, [agent.models, availableModels.length]);
+
+  useEffect(() => {
+    const modelsArray = Array.from(selectedModelIds);
+    if (modelsArray.length > 0) {
+      setAgent((prev) => ({
+        ...prev,
+        models: modelsArray,
+        default_model_id: modelsArray.includes(prev.default_model_id || '')
+          ? prev.default_model_id
+          : modelsArray[0],
+      }));
+    } else {
+      setAgent((prev) => ({
+        ...prev,
+        models: [],
+        default_model_id: '',
+      }));
+    }
+  }, [selectedModelIds]);
+
   useEffect(() => {
     const selectedSources = Array.from(selectedSourceIds)
       .map((id) =>
@@ -882,6 +942,82 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
             />
           </div>
         </div>
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
<h2 className="text-lg font-semibold">
{t('agents.form.sections.models')}
</h2>
<div className="mt-3 flex flex-col gap-3">
<button
ref={modelAnchorButtonRef}
onClick={() => setIsModelsPopupOpen(!isModelsPopupOpen)}
className={`border-silver dark:bg-raisin-black w-full truncate rounded-3xl border bg-white px-5 py-3 text-left text-sm dark:border-[#7E7E7E] ${
selectedModelIds.size > 0
? 'text-jet dark:text-bright-gray'
: 'dark:text-silver text-gray-400'
}`}
>
{selectedModelIds.size > 0
? availableModels
.filter((m) => selectedModelIds.has(m.id))
.map((m) => m.display_name)
.join(', ')
: t('agents.form.placeholders.selectModels')}
</button>
<MultiSelectPopup
isOpen={isModelsPopupOpen}
onClose={() => setIsModelsPopupOpen(false)}
anchorRef={modelAnchorButtonRef}
options={availableModels.map((model) => ({
id: model.id,
label: model.display_name,
}))}
selectedIds={selectedModelIds}
onSelectionChange={(newSelectedIds: Set<string | number>) =>
setSelectedModelIds(
new Set(Array.from(newSelectedIds).map(String)),
)
}
title={t('agents.form.modelsPopup.title')}
searchPlaceholder={t(
'agents.form.modelsPopup.searchPlaceholder',
)}
noOptionsMessage={t('agents.form.modelsPopup.noOptionsMessage')}
/>
{selectedModelIds.size > 0 && (
<div>
<label className="mb-2 block text-sm font-medium">
{t('agents.form.labels.defaultModel')}
</label>
<Dropdown
options={availableModels
.filter((m) => selectedModelIds.has(m.id))
.map((m) => ({
label: m.display_name,
value: m.id,
}))}
selectedValue={
availableModels.find(
(m) => m.id === agent.default_model_id,
)?.display_name || null
}
onSelect={(option: { label: string; value: string }) =>
setAgent({ ...agent, default_model_id: option.value })
}
size="w-full"
rounded="3xl"
border="border"
buttonClassName="bg-white dark:bg-[#222327] border-silver dark:border-[#7E7E7E]"
optionsClassName="bg-white dark:bg-[#383838] border-silver dark:border-[#7E7E7E]"
placeholder={t(
'agents.form.placeholders.selectDefaultModel',
)}
placeholderClassName="text-gray-400 dark:text-silver"
contentSize="text-sm"
/>
</div>
)}
</div>
</div>
<div className="dark:bg-raisin-black rounded-[30px] bg-white px-6 py-3 dark:text-[#E0E0E0]">
<button
onClick={() =>

View File

@@ -52,6 +52,10 @@ export const fetchPreviewAnswer = createAsyncThunk<
}
if (state.preference) {
const modelId =
state.preference.selectedAgent?.default_model_id ||
state.preference.selectedModel?.id;
if (API_STREAMING) {
await handleFetchAnswerSteaming(
question,
@@ -120,22 +124,23 @@ export const fetchPreviewAnswer = createAsyncThunk<
indx,
state.preference.selectedAgent?.id,
attachmentIds,
false,
modelId,
);
} else {
const answer = await handleFetchAnswer(
question,
signal,
state.preference.token,
state.preference.selectedDocs,
null,
state.preference.prompt.id,
state.preference.chunks,
state.preference.token_limit,
state.preference.selectedAgent?.id,
attachmentIds,
false,
modelId,
);
if (answer) {

View File

@@ -32,4 +32,6 @@ export type Agent = {
token_limit?: number;
limited_request_mode?: boolean;
request_limit?: number;
models?: string[];
default_model_id?: string;
};

View File

@@ -2,6 +2,7 @@ const endpoints = {
USER: {
CONFIG: '/api/config',
NEW_TOKEN: '/api/generate_token',
MODELS: '/api/models',
DOCS: '/api/sources',
DOCS_PAGINATED: '/api/sources/paginated',
API_KEYS: '/api/get_api_keys',

View File

@@ -0,0 +1,25 @@
import apiClient from '../client';
import endpoints from '../endpoints';
import type { AvailableModel, Model } from '../../models/types';
const modelService = {
getModels: (token: string | null): Promise<Response> =>
apiClient.get(endpoints.USER.MODELS, token, {}),
transformModels: (models: AvailableModel[]): Model[] =>
models.map((model) => ({
id: model.id,
value: model.id,
provider: model.provider,
display_name: model.display_name,
description: model.description,
context_window: model.context_window,
supported_attachment_types: model.supported_attachment_types,
supports_tools: model.supports_tools,
supports_structured_output: model.supports_structured_output,
supports_streaming: model.supports_streaming,
})),
};
export default modelService;
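The `transformModels` helper above maps API records to UI model objects, duplicating `id` into `value` for dropdown components. A standalone sketch of that mapping (types trimmed to the fields used here, not the full shapes from `models/types`):

```typescript
// Trimmed stand-ins for AvailableModel / Model, keeping only the fields
// needed to show the transform.
type AvailableModel = { id: string; provider: string; display_name: string };
type Model = AvailableModel & { value: string };

// Same shape as modelService.transformModels: copy the record and mirror
// `id` into `value` so dropdown components can use it directly.
const transformModels = (models: AvailableModel[]): Model[] =>
  models.map((m) => ({ ...m, value: m.id }));
```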

View File

@@ -0,0 +1,3 @@
<svg width="20" height="21" viewBox="0 0 20 21" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M10 0.75C4.62391 0.75 0.25 5.12391 0.25 10.5C0.25 15.8761 4.62391 20.25 10 20.25C15.3761 20.25 19.75 15.8761 19.75 10.5C19.75 5.12391 15.3761 0.75 10 0.75ZM15.0742 7.23234L8.77422 14.7323C8.70511 14.8147 8.61912 14.8812 8.52207 14.9273C8.42502 14.9735 8.31918 14.9983 8.21172 15H8.19906C8.09394 15 7.99 14.9778 7.89398 14.935C7.79797 14.8922 7.71202 14.8297 7.64172 14.7516L4.94172 11.7516C4.87315 11.6788 4.81981 11.5931 4.78483 11.4995C4.74986 11.4059 4.73395 11.3062 4.73805 11.2063C4.74215 11.1064 4.76617 11.0084 4.8087 10.9179C4.85124 10.8275 4.91142 10.7464 4.98572 10.6796C5.06002 10.6127 5.14694 10.5614 5.24136 10.5286C5.33579 10.4958 5.43581 10.4822 5.53556 10.4886C5.63531 10.495 5.73277 10.5213 5.82222 10.5659C5.91166 10.6106 5.99128 10.6726 6.05641 10.7484L8.17938 13.1072L13.9258 6.26766C14.0547 6.11863 14.237 6.02631 14.4335 6.01066C14.6299 5.99501 14.8246 6.05728 14.9754 6.18402C15.1263 6.31075 15.2212 6.49176 15.2397 6.68793C15.2582 6.8841 15.1988 7.07966 15.0742 7.23234Z" fill="#B5B5B5"/>
</svg>


View File

@@ -0,0 +1,138 @@
import React, { useEffect } from 'react';
import { useDispatch, useSelector } from 'react-redux';
import modelService from '../api/services/modelService';
import Arrow2 from '../assets/dropdown-arrow.svg';
import RoundedTick from '../assets/rounded-tick.svg';
import {
selectAvailableModels,
selectSelectedModel,
setAvailableModels,
setModelsLoading,
setSelectedModel,
} from '../preferences/preferenceSlice';
import type { Model } from '../models/types';
export default function DropdownModel() {
const dispatch = useDispatch();
const selectedModel = useSelector(selectSelectedModel);
const availableModels = useSelector(selectAvailableModels);
const dropdownRef = React.useRef<HTMLDivElement>(null);
const [isOpen, setIsOpen] = React.useState(false);
useEffect(() => {
const loadModels = async () => {
if ((availableModels?.length ?? 0) > 0) {
return;
}
dispatch(setModelsLoading(true));
try {
const response = await modelService.getModels(null);
if (!response.ok) {
throw new Error(`API error: ${response.status}`);
}
const data = await response.json();
const models = data.models || [];
const transformed = modelService.transformModels(models);
dispatch(setAvailableModels(transformed));
if (!selectedModel && transformed.length > 0) {
const defaultModel =
transformed.find((m) => m.id === data.default_model_id) ||
transformed[0];
dispatch(setSelectedModel(defaultModel));
} else if (selectedModel && transformed.length > 0) {
const isValid = transformed.find((m) => m.id === selectedModel.id);
if (!isValid) {
const defaultModel =
transformed.find((m) => m.id === data.default_model_id) ||
transformed[0];
dispatch(setSelectedModel(defaultModel));
}
}
} catch (error) {
console.error('Failed to load models:', error);
} finally {
dispatch(setModelsLoading(false));
}
};
loadModels();
}, [availableModels?.length, dispatch, selectedModel]);
const handleClickOutside = (event: MouseEvent) => {
if (
dropdownRef.current &&
!dropdownRef.current.contains(event.target as Node)
) {
setIsOpen(false);
}
};
useEffect(() => {
document.addEventListener('mousedown', handleClickOutside);
return () => {
document.removeEventListener('mousedown', handleClickOutside);
};
}, []);
return (
<div ref={dropdownRef}>
<div
className={`bg-gray-1000 dark:bg-dark-charcoal mx-auto flex w-full cursor-pointer justify-between p-1 dark:text-white ${isOpen ? 'rounded-t-3xl' : 'rounded-3xl'}`}
onClick={() => setIsOpen(!isOpen)}
>
{selectedModel?.display_name ? (
<p className="mx-4 my-3 truncate overflow-hidden whitespace-nowrap">
{selectedModel.display_name}
</p>
) : (
<p className="mx-4 my-3 truncate overflow-hidden whitespace-nowrap">
Select Model
</p>
)}
<img
src={Arrow2}
alt="arrow"
className={`${
isOpen ? 'rotate-360' : 'rotate-270'
} mr-3 w-3 transition-all select-none`}
/>
</div>
{isOpen && (
<div className="no-scrollbar dark:bg-dark-charcoal absolute right-0 left-0 z-20 -mt-1 max-h-52 w-full overflow-y-auto rounded-b-3xl bg-white shadow-md">
{availableModels && (availableModels?.length ?? 0) > 0 ? (
availableModels.map((model: Model) => (
<div
key={model.id}
onClick={() => {
dispatch(setSelectedModel(model));
setIsOpen(false);
}}
className={`border-gray-3000/75 dark:border-purple-taupe/50 hover:bg-gray-3000/75 dark:hover:bg-purple-taupe flex h-10 w-full cursor-pointer items-center justify-between border-t`}
>
<div className="flex w-full items-center justify-between">
<p className="overflow-hidden py-3 pr-2 pl-5 overflow-ellipsis whitespace-nowrap">
{model.display_name}
</p>
{model.id === selectedModel?.id ? (
<img
src={RoundedTick}
alt="selected"
className="mr-3.5 h-4 w-4"
/>
) : null}
</div>
</div>
))
) : (
<div className="h-10 w-full border-x-2 border-b-2">
<p className="ml-5 py-3 text-gray-500">No models available</p>
</div>
)}
</div>
)}
</div>
);
}
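The `loadModels` effect in `DropdownModel` above selects the server's `default_model_id` when nothing is selected, and replaces a stale selection that no longer exists in the list. A hypothetical distillation of that fallback as a pure function (name and trimmed type are illustrative only):

```typescript
type UIModel = { id: string };

// Prefer a still-valid current selection; otherwise pick the server default,
// falling back to the first available model. Returns the current selection
// unchanged when the list is empty.
function resolveSelectedModel(
  models: UIModel[],
  defaultModelId: string | undefined,
  current: UIModel | null,
): UIModel | null {
  if (models.length === 0) return current;
  if (current && models.some((m) => m.id === current.id)) return current;
  return models.find((m) => m.id === defaultModelId) ?? models[0];
}
```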

View File

@@ -19,8 +19,8 @@ import {
removeAttachment,
selectAttachments,
updateAttachment,
reorderAttachments,
} from '../upload/uploadSlice';
import { ActiveState } from '../models/misc';
import {
@@ -77,7 +77,7 @@ export default function MessageInput({
(browserOS === 'mac' && event.metaKey && event.key === 'k')
) {
event.preventDefault();
setIsSourcesPopupOpen((s) => !s);
}
};
@@ -89,8 +89,198 @@ export default function MessageInput({
const uploadFiles = useCallback(
(files: File[]) => {
if (!files || files.length === 0) return;
const apiHost = import.meta.env.VITE_API_HOST;
if (files.length > 1) {
const formData = new FormData();
const indexToUiId: Record<number, string> = {};
files.forEach((file, i) => {
formData.append('file', file);
const uiId = crypto.randomUUID();
indexToUiId[i] = uiId;
dispatch(
addAttachment({
id: uiId,
fileName: file.name,
progress: 0,
status: 'uploading' as const,
taskId: '',
}),
);
});
const xhr = new XMLHttpRequest();
xhr.upload.addEventListener('progress', (event) => {
if (event.lengthComputable) {
const progress = Math.round((event.loaded / event.total) * 100);
Object.values(indexToUiId).forEach((uiId) =>
dispatch(
updateAttachment({
id: uiId,
updates: { progress },
}),
),
);
}
});
xhr.onload = () => {
const status = xhr.status;
if (status === 200) {
try {
const response = JSON.parse(xhr.responseText);
if (Array.isArray(response?.tasks)) {
const tasks = response.tasks as Array<{
task_id?: string;
filename?: string;
attachment_id?: string;
path?: string;
}>;
tasks.forEach((t, idx) => {
const uiId = indexToUiId[idx];
if (!uiId) return;
if (t?.task_id) {
dispatch(
updateAttachment({
id: uiId,
updates: {
taskId: t.task_id,
status: 'processing',
progress: 10,
},
}),
);
} else {
dispatch(
updateAttachment({
id: uiId,
updates: { status: 'failed' },
}),
);
}
});
if (tasks.length < files.length) {
for (let i = tasks.length; i < files.length; i++) {
const uiId = indexToUiId[i];
if (uiId) {
dispatch(
updateAttachment({
id: uiId,
updates: { status: 'failed' },
}),
);
}
}
}
} else if (response?.task_id) {
if (files.length === 1) {
const uiId = indexToUiId[0];
if (uiId) {
dispatch(
updateAttachment({
id: uiId,
updates: {
taskId: response.task_id,
status: 'processing',
progress: 10,
},
}),
);
}
} else {
console.warn(
'Server returned a single task_id for multiple files. Update backend to return tasks[].',
);
const firstUi = indexToUiId[0];
if (firstUi) {
dispatch(
updateAttachment({
id: firstUi,
updates: {
taskId: response.task_id,
status: 'processing',
progress: 10,
},
}),
);
}
for (let i = 1; i < files.length; i++) {
const uiId = indexToUiId[i];
if (uiId) {
dispatch(
updateAttachment({
id: uiId,
updates: { status: 'failed' },
}),
);
}
}
}
} else {
console.error('Unexpected upload response shape', response);
Object.values(indexToUiId).forEach((id) =>
dispatch(
updateAttachment({
id,
updates: { status: 'failed' },
}),
),
);
}
} catch (err) {
console.error(
'Failed to parse upload response',
err,
xhr.responseText,
);
Object.values(indexToUiId).forEach((id) =>
dispatch(
updateAttachment({
id,
updates: { status: 'failed' },
}),
),
);
}
} else {
console.error('Upload failed', status, xhr.responseText);
Object.values(indexToUiId).forEach((id) =>
dispatch(
updateAttachment({
id,
updates: { status: 'failed' },
}),
),
);
}
};
xhr.onerror = () => {
console.error('Upload network error');
Object.values(indexToUiId).forEach((id) =>
dispatch(
updateAttachment({
id,
updates: { status: 'failed' },
}),
),
);
};
xhr.open('POST', `${apiHost}${endpoints.USER.STORE_ATTACHMENT}`);
if (token) xhr.setRequestHeader('Authorization', `Bearer ${token}`);
xhr.send(formData);
return;
}
// Single-file path: upload each file individually (original repo behavior)
files.forEach((file) => {
const formData = new FormData();
formData.append('file', file);
@@ -121,16 +311,54 @@ export default function MessageInput({
xhr.onload = () => {
if (xhr.status === 200) {
try {
const response = JSON.parse(xhr.responseText);
if (response.task_id) {
dispatch(
updateAttachment({
id: uniqueId,
updates: {
taskId: response.task_id,
status: 'processing',
progress: 10,
},
}),
);
} else {
// If backend returned tasks[] for single-file, handle gracefully:
if (
Array.isArray(response?.tasks) &&
response.tasks[0]?.task_id
) {
dispatch(
updateAttachment({
id: uniqueId,
updates: {
taskId: response.tasks[0].task_id,
status: 'processing',
progress: 10,
},
}),
);
} else {
dispatch(
updateAttachment({
id: uniqueId,
updates: { status: 'failed' },
}),
);
}
}
} catch (err) {
console.error(
'Failed to parse upload response',
err,
xhr.responseText,
);
dispatch(
updateAttachment({
id: uniqueId,
updates: { status: 'failed' },
}),
);
}
@@ -154,7 +382,7 @@ export default function MessageInput({
};
xhr.open('POST', `${apiHost}${endpoints.USER.STORE_ATTACHMENT}`);
if (token) xhr.setRequestHeader('Authorization', `Bearer ${token}`);
xhr.send(formData);
});
},
@@ -163,15 +391,13 @@ export default function MessageInput({
const handleFileAttachment = (e: React.ChangeEvent<HTMLInputElement>) => {
if (!e.target.files || e.target.files.length === 0) return;
const files = Array.from(e.target.files);
uploadFiles(files);
// clear input so same file can be selected again
e.target.value = '';
};
// Drag & drop via react-dropzone
const onDrop = useCallback(
(acceptedFiles: File[]) => {
uploadFiles(acceptedFiles);
@@ -321,11 +547,8 @@ export default function MessageInput({
handleAbort();
};
const [draggingId, setDraggingId] = useState<string | null>(null);
const findIndexById = (id: string) =>
attachments.findIndex((a) => a.id === id);
@@ -359,7 +582,9 @@ export default function MessageInput({
return (
<div {...getRootProps()} className="flex w-full flex-col">
{/* react-dropzone input (for drag/drop) */}
<input {...getInputProps()} />
<div className="border-dark-gray bg-lotion dark:border-grey relative flex w-full flex-col rounded-[23px] border dark:bg-transparent">
<div className="flex flex-wrap gap-1.5 px-2 py-2 sm:gap-2 sm:px-3">
{attachments.map((attachment) => {
@@ -374,7 +599,11 @@ export default function MessageInput({
attachment.status !== 'completed'
? 'opacity-70'
: 'opacity-100'
} ${
draggingId === attachment.id
? 'ring-dashed opacity-60 ring-2 ring-purple-200'
: ''
}`}
title={attachment.fileName}
>
<div className="bg-purple-30 mr-2 flex h-8 w-8 items-center justify-center rounded-md p-1">

View File

@@ -15,6 +15,7 @@ export function handleFetchAnswer(
agentId?: string, agentId?: string,
attachments?: string[], attachments?: string[],
save_conversation = true, save_conversation = true,
modelId?: string,
): Promise<
| {
result: any;
@@ -47,6 +48,10 @@ export function handleFetchAnswer(
save_conversation: save_conversation,
};
if (modelId) {
payload.model_id = modelId;
}
// Add attachments to payload if they exist
if (attachments && attachments.length > 0) {
payload.attachments = attachments;
@@ -101,6 +106,7 @@ export function handleFetchAnswerSteaming(
agentId?: string,
attachments?: string[],
save_conversation = true,
modelId?: string,
): Promise<Answer> {
const payload: RetrievalPayload = {
question: question,
@@ -114,6 +120,10 @@ export function handleFetchAnswerSteaming(
save_conversation: save_conversation,
};
if (modelId) {
payload.model_id = modelId;
}
// Add attachments to payload if they exist
if (attachments && attachments.length > 0) {
payload.attachments = attachments;

View File

@@ -65,4 +65,5 @@ export interface RetrievalPayload {
agent_id?: string;
attachments?: string[];
save_conversation?: boolean;
model_id?: string;
}
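With `model_id` added to `RetrievalPayload`, both fetch helpers attach it only when a model was actually resolved. A minimal sketch of that guard (the `buildPayload` name and trimmed interface are illustrative, not part of the diff):

```typescript
// Trimmed payload type; the real RetrievalPayload has many more fields.
interface RetrievalPayload {
  question: string;
  save_conversation?: boolean;
  model_id?: string;
}

// Mirrors the `if (modelId) payload.model_id = modelId;` guard in the
// fetch helpers: the key is omitted entirely when no model is resolved.
function buildPayload(question: string, modelId?: string): RetrievalPayload {
  const payload: RetrievalPayload = { question, save_conversation: true };
  if (modelId) payload.model_id = modelId;
  return payload;
}
```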

View File

@@ -49,6 +49,9 @@ export const fetchAnswer = createAsyncThunk<
}
const currentConversationId = state.conversation.conversationId;
const modelId =
state.preference.selectedAgent?.default_model_id ||
state.preference.selectedModel?.id;
if (state.preference) {
if (API_STREAMING) {
@@ -156,7 +159,8 @@ export const fetchAnswer = createAsyncThunk<
indx,
state.preference.selectedAgent?.id,
attachmentIds,
true,
modelId,
);
} else {
const answer = await handleFetchAnswer(
@@ -170,7 +174,8 @@ export const fetchAnswer = createAsyncThunk<
state.preference.token_limit,
state.preference.selectedAgent?.id,
attachmentIds,
true,
modelId,
);
if (answer) {
let sourcesPrepped = [];

View File

@@ -225,6 +225,16 @@ layer(base);
}
@layer base {
.prompt-variable-highlight {
background-color: rgba(106, 77, 244, 0.18);
border-radius: 0.375rem;
padding: 0 0.25rem;
}
.dark .prompt-variable-highlight {
background-color: rgba(106, 77, 244, 0.32);
}
/*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */
/* Document

View File

@@ -396,6 +396,18 @@
"variablesDescription": "Click to insert into prompt",
"systemVariables": "Click to insert into prompt",
"toolVariables": "Tool Variables",
"systemVariablesDropdownLabel": "System Variables",
"systemVariableOptions": {
"sourceContent": "Sources content",
"sourceSummaries": "Alias for content (backward compatible)",
"sourceDocuments": "Document objects list",
"sourceCount": "Number of retrieved documents",
"systemDate": "Current date (YYYY-MM-DD)",
"systemTime": "Current time (HH:MM:SS)",
"systemTimestamp": "ISO 8601 timestamp",
"systemRequestId": "Unique request identifier",
"systemUserId": "Current user ID"
},
"learnAboutPrompts": "Learn about Prompts →",
"publicPromptEditDisabled": "Public prompts cannot be edited",
"promptTypePublic": "public",
@@ -518,6 +530,7 @@
"prompt": "Prompt",
"tools": "Tools",
"agentType": "Agent type",
"models": "Models",
"advanced": "Advanced",
"preview": "Preview"
},
@@ -528,6 +541,8 @@
"chunksPerQuery": "Chunks per query",
"selectType": "Select type",
"selectTools": "Select tools",
"selectModels": "Select models for this agent",
"selectDefaultModel": "Select default model",
"enterTokenLimit": "Enter token limit",
"enterRequestLimit": "Enter request limit"
},
@@ -541,6 +556,11 @@
"searchPlaceholder": "Search tools...",
"noOptionsMessage": "No tools available"
},
"modelsPopup": {
"title": "Select Models",
"searchPlaceholder": "Search models...",
"noOptionsMessage": "No models available"
},
"upload": {
"clickToUpload": "Click to upload",
"dragAndDrop": " or drag and drop"
@@ -549,6 +569,9 @@
"classic": "Classic",
"react": "ReAct"
},
"labels": {
"defaultModel": "Default Model"
},
"advanced": {
"jsonSchema": "JSON response schema",
"jsonSchemaDescription": "Define a JSON schema to enforce structured output format",

View File

@@ -396,6 +396,18 @@
"variablesDescription": "Haz clic para insertar en el prompt",
"systemVariables": "Variables del sistema",
"toolVariables": "Variables de herramientas",
"systemVariablesDropdownLabel": "Variables del sistema",
"systemVariableOptions": {
"sourceContent": "Contenido de las fuentes",
"sourceSummaries": "Alias del contenido (compatibilidad retroactiva)",
"sourceDocuments": "Lista de objetos de documentos",
"sourceCount": "Número de documentos recuperados",
"systemDate": "Fecha actual (YYYY-MM-DD)",
"systemTime": "Hora actual (HH:MM:SS)",
"systemTimestamp": "Marca de tiempo ISO 8601",
"systemRequestId": "Identificador único de solicitud",
"systemUserId": "ID del usuario actual"
},
"learnAboutPrompts": "Aprende sobre los Prompts →",
"publicPromptEditDisabled": "Los prompts públicos no se pueden editar",
"promptTypePublic": "público",
@@ -518,6 +530,7 @@
"prompt": "Prompt",
"tools": "Herramientas",
"agentType": "Tipo de agente",
"models": "Modelos",
"advanced": "Avanzado",
"preview": "Vista previa"
},
@@ -528,6 +541,8 @@
"chunksPerQuery": "Fragmentos por consulta",
"selectType": "Seleccionar tipo",
"selectTools": "Seleccionar herramientas",
"selectModels": "Seleccionar modelos para este agente",
"selectDefaultModel": "Seleccionar modelo predeterminado",
"enterTokenLimit": "Ingresar límite de tokens",
"enterRequestLimit": "Ingresar límite de solicitudes"
},
@@ -541,6 +556,11 @@
"searchPlaceholder": "Buscar herramientas...",
"noOptionsMessage": "No hay herramientas disponibles"
},
"modelsPopup": {
"title": "Seleccionar Modelos",
"searchPlaceholder": "Buscar modelos...",
"noOptionsMessage": "No hay modelos disponibles"
},
"upload": {
"clickToUpload": "Haz clic para subir",
"dragAndDrop": " o arrastra y suelta"
@@ -549,6 +569,9 @@
"classic": "Clásico",
"react": "ReAct"
},
"labels": {
"defaultModel": "Modelo Predeterminado"
},
"advanced": {
"jsonSchema": "Esquema de respuesta JSON",
"jsonSchemaDescription": "Define un esquema JSON para aplicar formato de salida estructurado",

View File

@@ -396,6 +396,18 @@
"variablesDescription": "クリックしてプロンプトに挿入",
"systemVariables": "システム変数",
"toolVariables": "ツール変数",
"systemVariablesDropdownLabel": "System Variables",
"systemVariableOptions": {
"sourceContent": "Sources content",
"sourceSummaries": "Alias for content (backward compatible)",
"sourceDocuments": "Document objects list",
"sourceCount": "Number of retrieved documents",
"systemDate": "Current date (YYYY-MM-DD)",
"systemTime": "Current time (HH:MM:SS)",
"systemTimestamp": "ISO 8601 timestamp",
"systemRequestId": "Unique request identifier",
"systemUserId": "Current user ID"
},
"learnAboutPrompts": "プロンプトについて学ぶ →",
"publicPromptEditDisabled": "公開プロンプトは編集できません",
"promptTypePublic": "公開",
@@ -518,6 +530,7 @@
"prompt": "プロンプト",
"tools": "ツール",
"agentType": "エージェントタイプ",
"models": "モデル",
"advanced": "詳細設定",
"preview": "プレビュー"
},
@@ -528,6 +541,8 @@
"chunksPerQuery": "クエリごとのチャンク数",
"selectType": "タイプを選択",
"selectTools": "ツールを選択",
"selectModels": "このエージェントのモデルを選択",
"selectDefaultModel": "デフォルトモデルを選択",
"enterTokenLimit": "トークン制限を入力",
"enterRequestLimit": "リクエスト制限を入力"
},
@@ -541,6 +556,11 @@
"searchPlaceholder": "ツールを検索...",
"noOptionsMessage": "利用可能なツールがありません"
},
"modelsPopup": {
"title": "モデルを選択",
"searchPlaceholder": "モデルを検索...",
"noOptionsMessage": "利用可能なモデルがありません"
},
"upload": {
"clickToUpload": "クリックしてアップロード",
"dragAndDrop": " またはドラッグ&ドロップ"
@@ -549,6 +569,9 @@
"classic": "クラシック",
"react": "ReAct"
},
"labels": {
"defaultModel": "デフォルトモデル"
},
"advanced": {
"jsonSchema": "JSON応答スキーマ",
"jsonSchemaDescription": "構造化された出力形式を適用するためのJSONスキーマを定義します",

View File

@@ -396,6 +396,18 @@
"variablesDescription": "Нажмите, чтобы вставить в промпт",
"systemVariables": "Системные переменные",
"toolVariables": "Переменные инструментов",
"systemVariablesDropdownLabel": "Системные переменные",
"systemVariableOptions": {
"sourceContent": "Содержимое источников",
"sourceSummaries": "Псевдоним содержимого (обратная совместимость)",
"sourceDocuments": "Список объектов документов",
"sourceCount": "Количество полученных документов",
"systemDate": "Текущая дата (ГГГГ-ММ-ДД)",
"systemTime": "Текущее время (ЧЧ:ММ:СС)",
"systemTimestamp": "Отметка времени ISO 8601",
"systemRequestId": "Уникальный идентификатор запроса",
"systemUserId": "Идентификатор текущего пользователя"
},
"learnAboutPrompts": "Узнать о промптах →",
"publicPromptEditDisabled": "Публичные промпты нельзя редактировать",
"promptTypePublic": "публичный",
@@ -518,6 +530,7 @@
"prompt": "Промпт",
"tools": "Инструменты",
"agentType": "Тип агента",
"models": "Модели",
"advanced": "Расширенные",
"preview": "Предпросмотр"
},
@@ -528,6 +541,8 @@
"chunksPerQuery": "Фрагментов на запрос",
"selectType": "Выберите тип",
"selectTools": "Выберите инструменты",
"selectModels": "Выберите модели для этого агента",
"selectDefaultModel": "Выберите модель по умолчанию",
"enterTokenLimit": "Введите лимит токенов",
"enterRequestLimit": "Введите лимит запросов"
},
@@ -541,6 +556,11 @@
"searchPlaceholder": "Поиск инструментов...",
"noOptionsMessage": "Нет доступных инструментов"
},
"modelsPopup": {
"title": "Выберите Модели",
"searchPlaceholder": "Поиск моделей...",
"noOptionsMessage": "Нет доступных моделей"
},
"upload": {
"clickToUpload": "Нажмите для загрузки",
"dragAndDrop": " или перетащите"
@@ -549,6 +569,9 @@
"classic": "Классический",
"react": "ReAct"
},
"labels": {
"defaultModel": "Модель по умолчанию"
},
"advanced": {
"jsonSchema": "Схема ответа JSON",
"jsonSchemaDescription": "Определите схему JSON для применения структурированного формата вывода",


@@ -396,6 +396,18 @@
"variablesDescription": "點擊以插入到提示中",
"systemVariables": "點擊以插入提示中",
"toolVariables": "工具變數",
"systemVariablesDropdownLabel": "系統變數",
"systemVariableOptions": {
"sourceContent": "來源內容",
"sourceSummaries": "內容別名(向後相容)",
"sourceDocuments": "文件物件列表",
"sourceCount": "擷取的文件數量",
"systemDate": "目前日期 (YYYY-MM-DD)",
"systemTime": "目前時間 (HH:MM:SS)",
"systemTimestamp": "ISO 8601 時間戳記",
"systemRequestId": "唯一請求識別碼",
"systemUserId": "目前使用者 ID"
},
"learnAboutPrompts": "了解提示 →",
"publicPromptEditDisabled": "公共提示無法編輯",
"promptTypePublic": "公共",
@@ -518,6 +530,7 @@
"prompt": "提示詞",
"tools": "工具",
"agentType": "代理類型",
"models": "模型",
"advanced": "進階",
"preview": "預覽"
},
@@ -528,6 +541,8 @@
"chunksPerQuery": "每次查詢的區塊數",
"selectType": "選擇類型",
"selectTools": "選擇工具",
"selectModels": "為此代理選擇模型",
"selectDefaultModel": "選擇預設模型",
"enterTokenLimit": "輸入權杖限制",
"enterRequestLimit": "輸入請求限制"
},
@@ -541,6 +556,11 @@
"searchPlaceholder": "搜尋工具...",
"noOptionsMessage": "沒有可用的工具"
},
"modelsPopup": {
"title": "選擇模型",
"searchPlaceholder": "搜尋模型...",
"noOptionsMessage": "沒有可用的模型"
},
"upload": {
"clickToUpload": "點擊上傳",
"dragAndDrop": " 或拖放"
@@ -549,6 +569,9 @@
"classic": "經典",
"react": "ReAct"
},
"labels": {
"defaultModel": "預設模型"
},
"advanced": {
"jsonSchema": "JSON回應架構",
"jsonSchemaDescription": "定義JSON架構以強制執行結構化輸出格式",


@@ -396,6 +396,18 @@
"variablesDescription": "點擊以插入到提示中",
"systemVariables": "點擊以插入提示中",
"toolVariables": "工具變數",
"systemVariablesDropdownLabel": "系統變數",
"systemVariableOptions": {
"sourceContent": "來源內容",
"sourceSummaries": "內容別名(向後相容)",
"sourceDocuments": "文件物件列表",
"sourceCount": "擷取的文件數量",
"systemDate": "目前日期 (YYYY-MM-DD)",
"systemTime": "目前時間 (HH:MM:SS)",
"systemTimestamp": "ISO 8601 時間戳記",
"systemRequestId": "唯一請求識別碼",
"systemUserId": "目前使用者 ID"
},
"learnAboutPrompts": "了解提示 →",
"publicPromptEditDisabled": "公共提示無法編輯",
"promptTypePublic": "公共",
@@ -518,6 +530,7 @@
"prompt": "提示词",
"tools": "工具",
"agentType": "代理类型",
"models": "模型",
"advanced": "高级",
"preview": "预览"
},
@@ -528,6 +541,8 @@
"chunksPerQuery": "每次查询的块数",
"selectType": "选择类型",
"selectTools": "选择工具",
"selectModels": "为此代理选择模型",
"selectDefaultModel": "选择默认模型",
"enterTokenLimit": "输入令牌限制",
"enterRequestLimit": "输入请求限制"
},
@@ -541,6 +556,11 @@
"searchPlaceholder": "搜索工具...",
"noOptionsMessage": "没有可用的工具"
},
"modelsPopup": {
"title": "选择模型",
"searchPlaceholder": "搜索模型...",
"noOptionsMessage": "没有可用的模型"
},
"upload": {
"clickToUpload": "点击上传",
"dragAndDrop": " 或拖放"
@@ -549,6 +569,9 @@
"classic": "经典",
"react": "ReAct"
},
"labels": {
"defaultModel": "默认模型"
},
"advanced": {
"jsonSchema": "JSON响应架构",
"jsonSchemaDescription": "定义JSON架构以强制执行结构化输出格式",


@@ -0,0 +1,25 @@
export interface AvailableModel {
id: string;
provider: string;
display_name: string;
description?: string;
context_window: number;
supported_attachment_types: string[];
supports_tools: boolean;
supports_structured_output: boolean;
supports_streaming: boolean;
enabled: boolean;
}
export interface Model {
id: string;
value: string;
provider: string;
display_name: string;
description?: string;
context_window: number;
supported_attachment_types: string[];
supports_tools: boolean;
supports_structured_output: boolean;
supports_streaming: boolean;
}
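The two interfaces above differ only in that `Model` drops `enabled` and adds a `value` field. A hypothetical adapter between the two shapes — mirroring `id` into `value` is an assumption for illustration, not confirmed by this diff:

```typescript
// Sketch: adapt an API AvailableModel into the UI Model shape.
// Assumption: `value` mirrors `id` (the dropdown's option value).
interface AvailableModel {
  id: string;
  provider: string;
  display_name: string;
  description?: string;
  context_window: number;
  supported_attachment_types: string[];
  supports_tools: boolean;
  supports_structured_output: boolean;
  supports_streaming: boolean;
  enabled: boolean;
}

interface Model extends Omit<AvailableModel, 'enabled'> {
  value: string;
}

// Strip `enabled` via rest destructuring and duplicate `id` into `value`.
const toModel = ({ enabled: _enabled, ...rest }: AvailableModel): Model => ({
  ...rest,
  value: rest.id,
});
```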


@@ -12,6 +12,141 @@ import userService from '../api/services/userService';
import { selectToken } from '../preferences/preferenceSlice';
import { UserToolType } from '../settings/types';
const variablePattern = /(\{\{\s*[^{}]+\s*\}\}|\{(?!\{)[^{}]+\})/g;
const escapeHtml = (value: string) =>
value
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/"/g, '&quot;')
.replace(/'/g, '&#39;');
const highlightPromptVariables = (text: string) => {
if (!text) {
return '&#8203;';
}
variablePattern.lastIndex = 0;
let result = '';
let lastIndex = 0;
let match: RegExpExecArray | null;
while ((match = variablePattern.exec(text)) !== null) {
const precedingText = text.slice(lastIndex, match.index);
if (precedingText) {
result += escapeHtml(precedingText);
}
result += `<span class="prompt-variable-highlight">${escapeHtml(match[0])}</span>`;
lastIndex = match.index + match[0].length;
}
const remainingText = text.slice(lastIndex);
if (remainingText) {
result += escapeHtml(remainingText);
}
return result || '&#8203;';
};
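The highlighter above HTML-escapes the text and wraps `{{ var }}` / `{var}` tokens in a highlight span. A standalone sketch of the same logic, with the pattern and escaping copied from the component:

```typescript
// Matches {{ var }} or single-brace {var} tokens (but not a lone `{` of `{{`).
const variablePattern = /(\{\{\s*[^{}]+\s*\}\}|\{(?!\{)[^{}]+\})/g;

const escapeHtml = (value: string) =>
  value
    .replace(/&/g, '&amp;')
    .replace(/</g, '&lt;')
    .replace(/>/g, '&gt;')
    .replace(/"/g, '&quot;')
    .replace(/'/g, '&#39;');

const highlight = (text: string): string => {
  variablePattern.lastIndex = 0; // reset the /g cursor between calls
  let result = '';
  let lastIndex = 0;
  let match: RegExpExecArray | null;
  while ((match = variablePattern.exec(text)) !== null) {
    result += escapeHtml(text.slice(lastIndex, match.index));
    result += `<span class="prompt-variable-highlight">${escapeHtml(match[0])}</span>`;
    lastIndex = match.index + match[0].length;
  }
  result += escapeHtml(text.slice(lastIndex));
  // Zero-width space keeps the empty overlay from collapsing.
  return result || '&#8203;';
};
```

For example, `highlight('Hi {{ system.date }}')` wraps only the token and leaves `Hi ` as escaped plain text.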
const systemVariableOptionDefinitions = [
{
labelKey: 'modals.prompts.systemVariableOptions.sourceContent',
value: 'source.content',
},
{
labelKey: 'modals.prompts.systemVariableOptions.sourceSummaries',
value: 'source.summaries',
},
{
labelKey: 'modals.prompts.systemVariableOptions.sourceDocuments',
value: 'source.documents',
},
{
labelKey: 'modals.prompts.systemVariableOptions.sourceCount',
value: 'source.count',
},
{
labelKey: 'modals.prompts.systemVariableOptions.systemDate',
value: 'system.date',
},
{
labelKey: 'modals.prompts.systemVariableOptions.systemTime',
value: 'system.time',
},
{
labelKey: 'modals.prompts.systemVariableOptions.systemTimestamp',
value: 'system.timestamp',
},
{
labelKey: 'modals.prompts.systemVariableOptions.systemRequestId',
value: 'system.request_id',
},
{
labelKey: 'modals.prompts.systemVariableOptions.systemUserId',
value: 'system.user_id',
},
];
const buildSystemVariableOptions = (translate: (key: string) => string) =>
systemVariableOptionDefinitions.map(({ value, labelKey }) => ({
value,
label: translate(labelKey),
}));
type PromptTextareaProps = {
id: string;
value: string;
onChange: (event: React.ChangeEvent<HTMLTextAreaElement>) => void;
ariaLabel: string;
};
function PromptTextarea({
id,
value,
onChange,
ariaLabel,
}: PromptTextareaProps) {
const [scrollOffsets, setScrollOffsets] = React.useState({ top: 0, left: 0 });
const highlightedValue = React.useMemo(
() => highlightPromptVariables(value),
[value],
);
const handleScroll = (event: React.UIEvent<HTMLTextAreaElement>) => {
const { scrollTop, scrollLeft } = event.currentTarget;
setScrollOffsets({
top: scrollTop,
left: scrollLeft,
});
};
return (
<>
<div
className="pointer-events-none absolute inset-0 z-0 overflow-hidden rounded bg-white px-3 py-2 dark:bg-[#26272E]"
aria-hidden="true"
>
<div
className="min-h-full text-base leading-[1.5] break-words whitespace-pre-wrap text-transparent"
style={{
transform: `translate(${-scrollOffsets.left}px, ${-scrollOffsets.top}px)`,
}}
dangerouslySetInnerHTML={{ __html: highlightedValue }}
/>
</div>
<textarea
id={id}
className="peer border-silver dark:border-silver/40 relative z-10 h-48 w-full resize-none rounded border-2 bg-transparent px-3 py-2 text-base text-gray-800 outline-none dark:bg-transparent dark:text-white"
value={value}
onChange={onChange}
onScroll={handleScroll}
placeholder=" "
aria-label={ariaLabel}
/>
</>
);
}
// Custom hook for fetching tool variables
const useToolVariables = () => {
const token = useSelector(selectToken);
@@ -50,9 +185,13 @@ const useToolVariables = () => {
);
if (canUseAction) {
const toolIdentifier = tool.id ?? tool.name;
if (!toolIdentifier) {
return;
}
filteredActions.push({
label: `${action.name} (${tool.displayName || tool.name})`,
-value: `tools.${tool.name}.${action.name}`,
+value: `tools.${toolIdentifier}.${action.name}`,
});
}
}
@@ -91,6 +230,10 @@ function AddPrompt({
disableSave: boolean;
}) {
const { t } = useTranslation();
const systemVariableOptions = React.useMemo(
() => buildSystemVariableOptions(t),
[t],
);
const toolVariables = useToolVariables();
return (
@@ -115,17 +258,15 @@ function AddPrompt({
/>
<div className="relative w-full">
-<textarea
+<PromptTextarea
id="new-prompt-content"
-className="peer border-silver dark:border-silver/40 h-48 w-full resize-none rounded border-2 bg-white px-3 py-2 text-base text-gray-800 outline-none dark:bg-[#26272E] dark:text-white"
value={newPromptContent}
onChange={(e) => setNewPromptContent(e.target.value)}
-placeholder=" "
+ariaLabel={t('prompts.textAriaLabel')}
-aria-label={t('prompts.textAriaLabel')}
/>
<label
htmlFor="new-prompt-content"
-className={`absolute select-none ${
+className={`absolute z-20 select-none ${
newPromptContent ? '-top-2.5 left-3 text-xs' : ''
} text-gray-4000 pointer-events-none max-w-[calc(100%-24px)] cursor-none overflow-hidden bg-white px-2 text-ellipsis whitespace-nowrap transition-all peer-placeholder-shown:top-2.5 peer-placeholder-shown:left-3 peer-placeholder-shown:text-base peer-focus:-top-2.5 peer-focus:left-3 peer-focus:text-xs dark:bg-[#26272E] dark:text-gray-400`}
>
@@ -146,8 +287,8 @@ function AddPrompt({
<div className="flex flex-wrap items-center gap-2 sm:gap-3">
<Dropdown
-options={[{ label: 'Summaries', value: 'summaries' }]}
+options={systemVariableOptions}
-selectedValue={'System Variables'}
+selectedValue={t('modals.prompts.systemVariablesDropdownLabel')}
onSelect={(option) => {
const textarea = document.getElementById(
'new-prompt-content',
@@ -165,7 +306,7 @@ function AddPrompt({
const newText =
textBefore +
(needsSpace ? ' ' : '') +
-`{${option.value}}` +
+`{{ ${option.value} }}` +
textAfter;
setNewPromptContent(newText);
@@ -174,17 +315,17 @@ function AddPrompt({
textarea.setSelectionRange(
cursorPosition +
option.value.length +
-2 +
+6 +
(needsSpace ? 1 : 0),
cursorPosition +
option.value.length +
-2 +
+6 +
(needsSpace ? 1 : 0),
);
}, 0);
}
}}
-placeholder="System Variables"
+placeholder={t('modals.prompts.systemVariablesDropdownLabel')}
size="w-[140px] sm:w-[185px]"
rounded="3xl"
border="border"
@@ -298,6 +439,10 @@ function EditPrompt({
disableSave: boolean;
}) {
const { t } = useTranslation();
const systemVariableOptions = React.useMemo(
() => buildSystemVariableOptions(t),
[t],
);
const toolVariables = useToolVariables();
return (
@@ -322,17 +467,15 @@ function EditPrompt({
/>
<div className="relative w-full">
-<textarea
+<PromptTextarea
id="edit-prompt-content"
-className="peer border-silver dark:border-silver/40 h-48 w-full resize-none rounded border-2 bg-white px-3 py-2 text-base text-gray-800 outline-none dark:bg-[#26272E] dark:text-white"
value={editPromptContent}
onChange={(e) => setEditPromptContent(e.target.value)}
-placeholder=" "
+ariaLabel={t('prompts.textAriaLabel')}
-aria-label={t('prompts.textAriaLabel')}
/>
<label
htmlFor="edit-prompt-content"
-className={`absolute select-none ${
+className={`absolute z-20 select-none ${
editPromptContent ? '-top-2.5 left-3 text-xs' : ''
} text-gray-4000 pointer-events-none max-w-[calc(100%-24px)] cursor-none overflow-hidden bg-white px-2 text-ellipsis whitespace-nowrap transition-all peer-placeholder-shown:top-2.5 peer-placeholder-shown:left-3 peer-placeholder-shown:text-base peer-focus:-top-2.5 peer-focus:left-3 peer-focus:text-xs dark:bg-[#26272E] dark:text-gray-400`}
>
@@ -353,8 +496,8 @@ function EditPrompt({
<div className="flex flex-wrap items-center gap-2 sm:gap-3">
<Dropdown
-options={[{ label: 'Summaries', value: 'summaries' }]}
+options={systemVariableOptions}
-selectedValue={'System Variables'}
+selectedValue={t('modals.prompts.systemVariablesDropdownLabel')}
onSelect={(option) => {
const textarea = document.getElementById(
'edit-prompt-content',
@@ -372,7 +515,7 @@ function EditPrompt({
const newText =
textBefore +
(needsSpace ? ' ' : '') +
-`{${option.value}}` +
+`{{ ${option.value} }}` +
textAfter;
setEditPromptContent(newText);
@@ -381,17 +524,17 @@ function EditPrompt({
textarea.setSelectionRange(
cursorPosition +
option.value.length +
-2 +
+6 +
(needsSpace ? 1 : 0),
cursorPosition +
option.value.length +
-2 +
+6 +
(needsSpace ? 1 : 0),
);
}, 0);
}
}}
-placeholder="System Variables"
+placeholder={t('modals.prompts.systemVariablesDropdownLabel')}
size="w-[140px] sm:w-[185px]"
rounded="3xl"
border="border"


@@ -9,11 +9,12 @@ import { Agent } from '../agents/types';
import { ActiveState, Doc } from '../models/misc';
import { RootState } from '../store';
import {
getLocalRecentDocs,
setLocalApiKey,
setLocalRecentDocs,
-getLocalRecentDocs,
} from './preferenceApi';
import type { Model } from '../models/types';
export interface Preference {
apiKey: string;
prompt: { name: string; id: string; type: string };
@@ -32,6 +33,9 @@ export interface Preference {
agents: Agent[] | null;
sharedAgents: Agent[] | null;
selectedAgent: Agent | null;
selectedModel: Model | null;
availableModels: Model[];
modelsLoading: boolean;
}
const initialState: Preference = {
@@ -61,6 +65,9 @@ const initialState: Preference = {
agents: null,
sharedAgents: null,
selectedAgent: null,
selectedModel: null,
availableModels: [],
modelsLoading: false,
};
export const prefSlice = createSlice({
@@ -109,6 +116,15 @@ export const prefSlice = createSlice({
setSelectedAgent: (state, action) => {
state.selectedAgent = action.payload;
},
setSelectedModel: (state, action: PayloadAction<Model | null>) => {
state.selectedModel = action.payload;
},
setAvailableModels: (state, action: PayloadAction<Model[]>) => {
state.availableModels = action.payload;
},
setModelsLoading: (state, action: PayloadAction<boolean>) => {
state.modelsLoading = action.payload;
},
},
});
@@ -127,6 +143,9 @@ export const {
setAgents,
setSharedAgents,
setSelectedAgent,
setSelectedModel,
setAvailableModels,
setModelsLoading,
} = prefSlice.actions;
export default prefSlice.reducer;
@@ -198,6 +217,19 @@ prefListenerMiddleware.startListening({
},
});
prefListenerMiddleware.startListening({
matcher: isAnyOf(setSelectedModel),
effect: (action, listenerApi) => {
const model = (listenerApi.getState() as RootState).preference
.selectedModel;
if (model) {
localStorage.setItem('DocsGPTSelectedModel', JSON.stringify(model));
} else {
localStorage.removeItem('DocsGPTSelectedModel');
}
},
});
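The listener above persists the selection under the `DocsGPTSelectedModel` key whenever `setSelectedModel` fires, and the store preload parses it back on startup. A minimal sketch of that round trip, with storage injected so it runs outside the browser (the `KVStore` interface and helper names are illustrative):

```typescript
// The subset of a stored model we care about for the sketch.
type StoredModel = { id: string; provider: string; display_name: string };

// A localStorage-shaped interface so the logic is testable without a DOM.
interface KVStore {
  getItem(key: string): string | null;
  setItem(key: string, value: string): void;
  removeItem(key: string): void;
}

const KEY = 'DocsGPTSelectedModel';

// Mirrors the listener effect: serialize on select, clear on deselect.
const persistModel = (store: KVStore, model: StoredModel | null) => {
  if (model) store.setItem(KEY, JSON.stringify(model));
  else store.removeItem(KEY);
};

// Mirrors the preloadedState read: parse if present, else null.
const restoreModel = (store: KVStore): StoredModel | null => {
  const raw = store.getItem(KEY);
  return raw ? JSON.parse(raw) : null;
};
```

Persisting the whole serialized model (rather than just its id) lets the preload restore a usable selection before the model list has been fetched.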
export const selectApiKey = (state: RootState) => state.preference.apiKey;
export const selectApiKeyStatus = (state: RootState) =>
!!state.preference.apiKey;
@@ -227,3 +259,9 @@ export const selectSharedAgents = (state: RootState) =>
state.preference.sharedAgents;
export const selectSelectedAgent = (state: RootState) =>
state.preference.selectedAgent;
export const selectSelectedModel = (state: RootState) =>
state.preference.selectedModel;
export const selectAvailableModels = (state: RootState) =>
state.preference.availableModels;
export const selectModelsLoading = (state: RootState) =>
state.preference.modelsLoading;


@@ -15,6 +15,7 @@ const prompt = localStorage.getItem('DocsGPTPrompt');
const chunks = localStorage.getItem('DocsGPTChunks');
const token_limit = localStorage.getItem('DocsGPTTokenLimit');
const doc = localStorage.getItem('DocsGPTRecentDocs');
const selectedModel = localStorage.getItem('DocsGPTSelectedModel');
const preloadedState: { preference: Preference } = {
preference: {
@@ -47,6 +48,9 @@ const preloadedState: { preference: Preference } = {
agents: null,
sharedAgents: null,
selectedAgent: null,
selectedModel: selectedModel ? JSON.parse(selectedModel) : null,
availableModels: [],
modelsLoading: false,
},
};
const store = configureStore({


@@ -0,0 +1,114 @@
#!/usr/bin/env python3
"""
Migration script to convert conversation_id from DBRef to ObjectId in shared_conversations collection.
"""
import pymongo
import logging
from tqdm import tqdm
from bson.dbref import DBRef
from bson.objectid import ObjectId
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger()
# Configuration
MONGO_URI = "mongodb://localhost:27017/"
DB_NAME = "docsgpt"
def backup_collection(collection, backup_collection_name):
"""Backup collection before migration."""
logger.info(f"Backing up collection {collection.name} to {backup_collection_name}")
collection.aggregate([{"$out": backup_collection_name}])
logger.info("Backup completed")
def migrate_conversation_id_dbref_to_objectid():
"""Migrate conversation_id from DBRef to ObjectId."""
client = pymongo.MongoClient(MONGO_URI)
db = client[DB_NAME]
shared_conversations_collection = db["shared_conversations"]
try:
# Backup collection before migration
backup_collection(shared_conversations_collection, "shared_conversations_backup")
# Find all documents and filter for DBRef conversation_id in Python
all_documents = list(shared_conversations_collection.find({}))
documents_with_dbref = []
for doc in all_documents:
conversation_id_field = doc.get("conversation_id")
if isinstance(conversation_id_field, DBRef):
documents_with_dbref.append(doc)
if not documents_with_dbref:
logger.info("No documents with DBRef conversation_id found. Migration not needed.")
return
logger.info(f"Found {len(documents_with_dbref)} documents with DBRef conversation_id")
# Process each document
migrated_count = 0
error_count = 0
for doc in tqdm(documents_with_dbref, desc="Migrating conversation_id"):
try:
conversation_id_field = doc.get("conversation_id")
# Extract the ObjectId from the DBRef
dbref_id = conversation_id_field.id
if dbref_id and ObjectId.is_valid(dbref_id):
# Update the document to use direct ObjectId
result = shared_conversations_collection.update_one(
{"_id": doc["_id"]},
{"$set": {"conversation_id": dbref_id}}
)
if result.modified_count > 0:
migrated_count += 1
logger.debug(f"Successfully migrated document {doc['_id']}")
else:
error_count += 1
logger.warning(f"Failed to update document {doc['_id']}")
else:
error_count += 1
logger.warning(f"Invalid ObjectId in DBRef for document {doc['_id']}: {dbref_id}")
except Exception as e:
error_count += 1
logger.error(f"Error migrating document {doc['_id']}: {e}")
# Final verification
all_docs_after = list(shared_conversations_collection.find({}))
remaining_dbref = 0
for doc in all_docs_after:
if isinstance(doc.get("conversation_id"), DBRef):
remaining_dbref += 1
logger.info("Migration completed:")
logger.info(f" - Total documents processed: {len(documents_with_dbref)}")
logger.info(f" - Successfully migrated: {migrated_count}")
logger.info(f" - Errors encountered: {error_count}")
logger.info(f" - Remaining DBRef documents: {remaining_dbref}")
if remaining_dbref == 0:
logger.info("✅ Migration successful: All DBRef conversation_id fields have been converted to ObjectId")
else:
logger.warning(f"⚠️ Migration incomplete: {remaining_dbref} DBRef documents still exist")
except Exception as e:
logger.error(f"Migration failed: {e}")
raise
finally:
client.close()
if __name__ == "__main__":
try:
logger.info("Starting conversation_id DBRef to ObjectId migration...")
migrate_conversation_id_dbref_to_objectid()
logger.info("Migration completed successfully!")
except Exception as e:
logger.error(f"Migration failed due to error: {e}")
logger.warning("Please verify database state or restore from backups if necessary.")


@@ -12,7 +12,7 @@ class TestAgentCreator:
assert isinstance(agent, ClassicAgent)
assert agent.endpoint == agent_base_params["endpoint"]
assert agent.llm_name == agent_base_params["llm_name"]
-assert agent.gpt_model == agent_base_params["gpt_model"]
+assert agent.model_id == agent_base_params["model_id"]
def test_create_react_agent(self, agent_base_params):
agent = AgentCreator.create_agent("react", **agent_base_params)


@@ -15,7 +15,7 @@ class TestBaseAgentInitialization:
assert agent.endpoint == agent_base_params["endpoint"]
assert agent.llm_name == agent_base_params["llm_name"]
-assert agent.gpt_model == agent_base_params["gpt_model"]
+assert agent.model_id == agent_base_params["model_id"]
assert agent.api_key == agent_base_params["api_key"]
assert agent.prompt == agent_base_params["prompt"]
assert agent.user == agent_base_params["decoded_token"]["sub"]
@@ -480,7 +480,7 @@ class TestBaseAgentLLMGeneration:
mock_llm.gen_stream.assert_called_once()
call_args = mock_llm.gen_stream.call_args[1]
-assert call_args["model"] == agent.gpt_model
+assert call_args["model"] == agent.model_id
assert call_args["messages"] == messages
def test_llm_gen_with_tools(


@@ -23,7 +23,7 @@ class TestReActAgent:
assert agent.endpoint == agent_base_params["endpoint"]
assert agent.llm_name == agent_base_params["llm_name"]
-assert agent.gpt_model == agent_base_params["gpt_model"]
+assert agent.model_id == agent_base_params["model_id"]
@pytest.mark.unit


@@ -274,8 +274,8 @@ class TestGPTModelRetrieval:
with flask_app.app_context():
resource = BaseAnswerResource()
-assert hasattr(resource, "gpt_model")
+assert hasattr(resource, "default_model_id")
-assert resource.gpt_model is not None
+assert resource.default_model_id is not None
@pytest.mark.unit
@@ -412,7 +412,7 @@ class TestCompleteStreamMethod:
resource.complete_stream(
question="Test?",
agent=mock_agent,
conversation_id=None,
user_api_key=None,
decoded_token=decoded_token,
should_save_conversation=True,
@@ -500,9 +500,10 @@ class TestProcessResponseStream:
result = resource.process_response_stream(iter(stream))
-assert len(result) == 5
+assert len(result) == 6
assert result[0] is None
assert result[4] == "Test error"
assert result[5] is None
def test_handles_malformed_stream_data(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource


@@ -108,7 +108,7 @@ class TestConversationServiceSave:
sources=[],
tool_calls=[],
llm=mock_llm,
-gpt_model="gpt-4",
+model_id="gpt-4",
decoded_token={},  # No 'sub' key
)
@@ -136,7 +136,7 @@ class TestConversationServiceSave:
sources=sources,
tool_calls=[],
llm=mock_llm,
-gpt_model="gpt-4",
+model_id="gpt-4",
decoded_token={"sub": "user_123"},
)
@@ -167,7 +167,7 @@ class TestConversationServiceSave:
sources=[],
tool_calls=[],
llm=mock_llm,
-gpt_model="gpt-4",
+model_id="gpt-4",
decoded_token={"sub": "user_123"},
)
@@ -208,7 +208,7 @@ class TestConversationServiceSave:
sources=[],
tool_calls=[],
llm=mock_llm,
-gpt_model="gpt-4",
+model_id="gpt-4",
decoded_token={"sub": "user_123"},
)
@@ -237,6 +237,6 @@ class TestConversationServiceSave:
sources=[],
tool_calls=[],
llm=mock_llm,
-gpt_model="gpt-4",
+model_id="gpt-4",
decoded_token={"sub": "hacker_456"},
)


@@ -150,7 +150,7 @@ def agent_base_params(decoded_token):
return {
"endpoint": "https://api.example.com",
"llm_name": "openai",
-"gpt_model": "gpt-4",
+"model_id": "gpt-4",
"api_key": "test_api_key",
"user_api_key": None,
"prompt": "You are a helpful assistant.",


@@ -1,11 +1,14 @@
import sys
import types
import pytest
class _FakeCompletion:
def __init__(self, text):
self.completion = text
class _FakeCompletions:
def __init__(self):
self.last_kwargs = None
@@ -17,6 +20,7 @@ class _FakeCompletions:
return self._stream
return _FakeCompletion("final")
class _FakeAnthropic:
def __init__(self, api_key=None):
self.api_key = api_key
@@ -29,9 +33,19 @@ def patch_anthropic(monkeypatch):
fake.Anthropic = _FakeAnthropic
fake.HUMAN_PROMPT = "<HUMAN>"
fake.AI_PROMPT = "<AI>"
modules_to_remove = [key for key in sys.modules if key.startswith("anthropic")]
for key in modules_to_remove:
sys.modules.pop(key, None)
sys.modules["anthropic"] = fake
if "application.llm.anthropic" in sys.modules:
del sys.modules["application.llm.anthropic"]
yield
sys.modules.pop("anthropic", None)
if "application.llm.anthropic" in sys.modules:
del sys.modules["application.llm.anthropic"]
def test_anthropic_raw_gen_builds_prompt_and_returns_completion():
@@ -42,7 +56,9 @@ def test_anthropic_raw_gen_builds_prompt_and_returns_completion():
{"content": "ctx"},
{"content": "q"},
]
-out = llm._raw_gen(llm, model="claude-2", messages=msgs, stream=False, max_tokens=55)
+out = llm._raw_gen(
llm, model="claude-2", messages=msgs, stream=False, max_tokens=55
)
assert out == "final"
last = llm.anthropic.completions.last_kwargs
assert last["model"] == "claude-2"
@@ -59,7 +75,8 @@ def test_anthropic_raw_gen_stream_yields_chunks():
{"content": "ctx"},
{"content": "q"},
]
-gen = llm._raw_gen_stream(llm, model="claude", messages=msgs, stream=True, max_tokens=10)
+gen = llm._raw_gen_stream(
llm, model="claude", messages=msgs, stream=True, max_tokens=10
)
chunks = list(gen)
assert chunks == ["s1", "s2"]