mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
feat: model registry and capabilities for multi-provider support (#2158)
* feat: Implement model registry and capabilities for multi-provider support - Added ModelRegistry to manage available models and their capabilities. - Introduced ModelProvider enum for different LLM providers. - Created ModelCapabilities dataclass to define model features. - Implemented methods to load models based on API keys and settings. - Added utility functions for model management in model_utils.py. - Updated settings.py to include provider-specific API keys. - Refactored LLM classes (Anthropic, OpenAI, Google, etc.) to utilize new model registry. - Enhanced utility functions to handle token limits and model validation. - Improved code structure and logging for better maintainability. * feat: Add model selection feature with API integration and UI component * feat: Add model selection and default model functionality in agent management * test: Update assertions and formatting in stream processing tests * refactor(llm): Standardize model identifier to model_id * fix tests --------- Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
@@ -12,12 +12,17 @@ from bson.objectid import ObjectId
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
get_provider_from_model_id,
|
||||
validate_model_id,
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import (
|
||||
calculate_doc_token_budget,
|
||||
get_gpt_model,
|
||||
limit_chat_history,
|
||||
)
|
||||
|
||||
@@ -83,7 +88,7 @@ class StreamProcessor:
|
||||
self.retriever_config = {}
|
||||
self.is_shared_usage = False
|
||||
self.shared_token = None
|
||||
self.gpt_model = get_gpt_model()
|
||||
self.model_id: Optional[str] = None
|
||||
self.conversation_service = ConversationService()
|
||||
self.prompt_renderer = PromptRenderer()
|
||||
self._prompt_content: Optional[str] = None
|
||||
@@ -91,6 +96,7 @@ class StreamProcessor:
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize all required components for processing"""
|
||||
self._validate_and_set_model()
|
||||
self._configure_agent()
|
||||
self._configure_source()
|
||||
self._configure_retriever()
|
||||
@@ -112,7 +118,7 @@ class StreamProcessor:
|
||||
]
|
||||
else:
|
||||
self.history = limit_chat_history(
|
||||
json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
|
||||
json.loads(self.data.get("history", "[]")), model_id=self.model_id
|
||||
)
|
||||
|
||||
def _process_attachments(self):
|
||||
@@ -143,6 +149,25 @@ class StreamProcessor:
|
||||
)
|
||||
return attachments
|
||||
|
||||
def _validate_and_set_model(self):
|
||||
"""Validate and set model_id from request"""
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
requested_model = self.data.get("model_id")
|
||||
|
||||
if requested_model:
|
||||
if not validate_model_id(requested_model):
|
||||
registry = ModelRegistry.get_instance()
|
||||
available_models = [m.id for m in registry.get_enabled_models()]
|
||||
raise ValueError(
|
||||
f"Invalid model_id '{requested_model}'. "
|
||||
f"Available models: {', '.join(available_models[:5])}"
|
||||
+ (f" and {len(available_models) - 5} more" if len(available_models) > 5 else "")
|
||||
)
|
||||
self.model_id = requested_model
|
||||
else:
|
||||
self.model_id = get_default_model_id()
|
||||
|
||||
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||||
"""Get API key for agent with access control"""
|
||||
if not agent_id:
|
||||
@@ -322,7 +347,7 @@ class StreamProcessor:
|
||||
def _configure_retriever(self):
|
||||
history_token_limit = int(self.data.get("token_limit", 2000))
|
||||
doc_token_limit = calculate_doc_token_budget(
|
||||
gpt_model=self.gpt_model, history_token_limit=history_token_limit
|
||||
model_id=self.model_id, history_token_limit=history_token_limit
|
||||
)
|
||||
|
||||
self.retriever_config = {
|
||||
@@ -344,7 +369,7 @@ class StreamProcessor:
|
||||
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
|
||||
chunks=self.retriever_config["chunks"],
|
||||
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
|
||||
gpt_model=self.gpt_model,
|
||||
model_id=self.model_id,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
@@ -626,12 +651,19 @@ class StreamProcessor:
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
provider = (
|
||||
get_provider_from_model_id(self.model_id)
|
||||
if self.model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
|
||||
|
||||
return AgentCreator.create_agent(
|
||||
self.agent_config["agent_type"],
|
||||
endpoint="stream",
|
||||
llm_name=settings.LLM_PROVIDER,
|
||||
gpt_model=self.gpt_model,
|
||||
api_key=settings.API_KEY,
|
||||
llm_name=provider or settings.LLM_PROVIDER,
|
||||
model_id=self.model_id,
|
||||
api_key=system_api_key,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
prompt=rendered_prompt,
|
||||
chat_history=self.history,
|
||||
|
||||
Reference in New Issue
Block a user