mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-30 09:03:15 +00:00
feat: model registry and capabilities for multi-provider support (#2158)
* feat: Implement model registry and capabilities for multi-provider support - Added ModelRegistry to manage available models and their capabilities. - Introduced ModelProvider enum for different LLM providers. - Created ModelCapabilities dataclass to define model features. - Implemented methods to load models based on API keys and settings. - Added utility functions for model management in model_utils.py. - Updated settings.py to include provider-specific API keys. - Refactored LLM classes (Anthropic, OpenAI, Google, etc.) to utilize new model registry. - Enhanced utility functions to handle token limits and model validation. - Improved code structure and logging for better maintainability. * feat: Add model selection feature with API integration and UI component * feat: Add model selection and default model functionality in agent management * test: Update assertions and formatting in stream processing tests * refactor(llm): Standardize model identifier to model_id * fix tests --------- Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
91
application/core/model_utils.py
Normal file
91
application/core/model_utils.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
|
||||
def get_api_key_for_provider(provider: str) -> Optional[str]:
|
||||
"""Get the appropriate API key for a provider"""
|
||||
from application.core.settings import settings
|
||||
|
||||
provider_key_map = {
|
||||
"openai": settings.OPENAI_API_KEY,
|
||||
"anthropic": settings.ANTHROPIC_API_KEY,
|
||||
"google": settings.GOOGLE_API_KEY,
|
||||
"groq": settings.GROQ_API_KEY,
|
||||
"huggingface": settings.HUGGINGFACE_API_KEY,
|
||||
"azure_openai": settings.API_KEY,
|
||||
"docsgpt": None,
|
||||
"llama.cpp": None,
|
||||
}
|
||||
|
||||
provider_key = provider_key_map.get(provider)
|
||||
if provider_key:
|
||||
return provider_key
|
||||
return settings.API_KEY
|
||||
|
||||
|
||||
def get_all_available_models() -> Dict[str, Dict[str, Any]]:
|
||||
"""Get all available models with metadata for API response"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
return {model.id: model.to_dict() for model in registry.get_enabled_models()}
|
||||
|
||||
|
||||
def validate_model_id(model_id: str) -> bool:
|
||||
"""Check if a model ID exists in registry"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
return registry.model_exists(model_id)
|
||||
|
||||
|
||||
def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get capabilities for a specific model"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
if model:
|
||||
return {
|
||||
"supported_attachment_types": model.capabilities.supported_attachment_types,
|
||||
"supports_tools": model.capabilities.supports_tools,
|
||||
"supports_structured_output": model.capabilities.supports_structured_output,
|
||||
"context_window": model.capabilities.context_window,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def get_default_model_id() -> str:
|
||||
"""Get the system default model ID"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
return registry.default_model_id
|
||||
|
||||
|
||||
def get_provider_from_model_id(model_id: str) -> Optional[str]:
|
||||
"""Get the provider name for a given model_id"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
if model:
|
||||
return model.provider.value
|
||||
return None
|
||||
|
||||
|
||||
def get_token_limit(model_id: str) -> int:
|
||||
"""
|
||||
Get context window (token limit) for a model.
|
||||
Returns model's context_window or default 128000 if model not found.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
if model:
|
||||
return model.capabilities.context_window
|
||||
return settings.DEFAULT_LLM_TOKEN_LIMIT
|
||||
|
||||
|
||||
def get_base_url_for_model(model_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get the custom base_url for a specific model if configured.
|
||||
Returns None if no custom base_url is set.
|
||||
"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
if model:
|
||||
return model.base_url
|
||||
return None
|
||||
Reference in New Issue
Block a user