mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
feat: model registry and capabilities for multi-provider support (#2158)
* feat: Implement model registry and capabilities for multi-provider support - Added ModelRegistry to manage available models and their capabilities. - Introduced ModelProvider enum for different LLM providers. - Created ModelCapabilities dataclass to define model features. - Implemented methods to load models based on API keys and settings. - Added utility functions for model management in model_utils.py. - Updated settings.py to include provider-specific API keys. - Refactored LLM classes (Anthropic, OpenAI, Google, etc.) to utilize new model registry. - Enhanced utility functions to handle token limits and model validation. - Improved code structure and logging for better maintainability. * feat: Add model selection feature with API integration and UI component * feat: Add model selection and default model functionality in agent management * test: Update assertions and formatting in stream processing tests * refactor(llm): Standardize model identifier to model_id * fix tests --------- Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
236
application/core/model_settings.py
Normal file
236
application/core/model_settings.py
Normal file
@@ -0,0 +1,236 @@
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelProvider(str, Enum):
|
||||
OPENAI = "openai"
|
||||
AZURE_OPENAI = "azure_openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
GROQ = "groq"
|
||||
GOOGLE = "google"
|
||||
HUGGINGFACE = "huggingface"
|
||||
LLAMA_CPP = "llama.cpp"
|
||||
DOCSGPT = "docsgpt"
|
||||
PREMAI = "premai"
|
||||
SAGEMAKER = "sagemaker"
|
||||
NOVITA = "novita"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelCapabilities:
|
||||
supports_tools: bool = False
|
||||
supports_structured_output: bool = False
|
||||
supports_streaming: bool = True
|
||||
supported_attachment_types: List[str] = field(default_factory=list)
|
||||
context_window: int = 128000
|
||||
input_cost_per_token: Optional[float] = None
|
||||
output_cost_per_token: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AvailableModel:
|
||||
id: str
|
||||
provider: ModelProvider
|
||||
display_name: str
|
||||
description: str = ""
|
||||
capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
|
||||
enabled: bool = True
|
||||
base_url: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
result = {
|
||||
"id": self.id,
|
||||
"provider": self.provider.value,
|
||||
"display_name": self.display_name,
|
||||
"description": self.description,
|
||||
"supported_attachment_types": self.capabilities.supported_attachment_types,
|
||||
"supports_tools": self.capabilities.supports_tools,
|
||||
"supports_structured_output": self.capabilities.supports_structured_output,
|
||||
"supports_streaming": self.capabilities.supports_streaming,
|
||||
"context_window": self.capabilities.context_window,
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
if self.base_url:
|
||||
result["base_url"] = self.base_url
|
||||
return result
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not ModelRegistry._initialized:
|
||||
self.models: Dict[str, AvailableModel] = {}
|
||||
self.default_model_id: Optional[str] = None
|
||||
self._load_models()
|
||||
ModelRegistry._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "ModelRegistry":
|
||||
return cls()
|
||||
|
||||
def _load_models(self):
|
||||
from application.core.settings import settings
|
||||
|
||||
self.models.clear()
|
||||
|
||||
self._add_docsgpt_models(settings)
|
||||
if settings.OPENAI_API_KEY or (
|
||||
settings.LLM_PROVIDER == "openai" and settings.API_KEY
|
||||
):
|
||||
self._add_openai_models(settings)
|
||||
if settings.OPENAI_API_BASE or (
|
||||
settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
|
||||
):
|
||||
self._add_azure_openai_models(settings)
|
||||
if settings.ANTHROPIC_API_KEY or (
|
||||
settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
|
||||
):
|
||||
self._add_anthropic_models(settings)
|
||||
if settings.GOOGLE_API_KEY or (
|
||||
settings.LLM_PROVIDER == "google" and settings.API_KEY
|
||||
):
|
||||
self._add_google_models(settings)
|
||||
if settings.GROQ_API_KEY or (
|
||||
settings.LLM_PROVIDER == "groq" and settings.API_KEY
|
||||
):
|
||||
self._add_groq_models(settings)
|
||||
if settings.HUGGINGFACE_API_KEY or (
|
||||
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
|
||||
):
|
||||
self._add_huggingface_models(settings)
|
||||
# Default model selection
|
||||
|
||||
if settings.LLM_NAME and settings.LLM_NAME in self.models:
|
||||
self.default_model_id = settings.LLM_NAME
|
||||
elif settings.LLM_PROVIDER and settings.API_KEY:
|
||||
for model_id, model in self.models.items():
|
||||
if model.provider.value == settings.LLM_PROVIDER:
|
||||
self.default_model_id = model_id
|
||||
break
|
||||
else:
|
||||
self.default_model_id = next(iter(self.models.keys()))
|
||||
logger.info(
|
||||
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
|
||||
)
|
||||
|
||||
def _add_openai_models(self, settings):
|
||||
from application.core.model_configs import OPENAI_MODELS
|
||||
|
||||
if settings.OPENAI_API_KEY:
|
||||
for model in OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "openai" and settings.LLM_NAME:
|
||||
for model in OPENAI_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_azure_openai_models(self, settings):
|
||||
from application.core.model_configs import AZURE_OPENAI_MODELS
|
||||
|
||||
if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
|
||||
for model in AZURE_OPENAI_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in AZURE_OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_anthropic_models(self, settings):
|
||||
from application.core.model_configs import ANTHROPIC_MODELS
|
||||
|
||||
if settings.ANTHROPIC_API_KEY:
|
||||
for model in ANTHROPIC_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
|
||||
for model in ANTHROPIC_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in ANTHROPIC_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_google_models(self, settings):
|
||||
from application.core.model_configs import GOOGLE_MODELS
|
||||
|
||||
if settings.GOOGLE_API_KEY:
|
||||
for model in GOOGLE_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
|
||||
for model in GOOGLE_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in GOOGLE_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_groq_models(self, settings):
|
||||
from application.core.model_configs import GROQ_MODELS
|
||||
|
||||
if settings.GROQ_API_KEY:
|
||||
for model in GROQ_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
|
||||
for model in GROQ_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in GROQ_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_docsgpt_models(self, settings):
|
||||
model_id = "docsgpt-local"
|
||||
model = AvailableModel(
|
||||
id=model_id,
|
||||
provider=ModelProvider.DOCSGPT,
|
||||
display_name="DocsGPT Model",
|
||||
description="Local model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=False,
|
||||
supported_attachment_types=[],
|
||||
),
|
||||
)
|
||||
self.models[model_id] = model
|
||||
|
||||
def _add_huggingface_models(self, settings):
|
||||
model_id = "huggingface-local"
|
||||
model = AvailableModel(
|
||||
id=model_id,
|
||||
provider=ModelProvider.HUGGINGFACE,
|
||||
display_name="Hugging Face Model",
|
||||
description="Local Hugging Face model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=False,
|
||||
supported_attachment_types=[],
|
||||
),
|
||||
)
|
||||
self.models[model_id] = model
|
||||
|
||||
def get_model(self, model_id: str) -> Optional[AvailableModel]:
|
||||
return self.models.get(model_id)
|
||||
|
||||
def get_all_models(self) -> List[AvailableModel]:
|
||||
return list(self.models.values())
|
||||
|
||||
def get_enabled_models(self) -> List[AvailableModel]:
|
||||
return [m for m in self.models.values() if m.enabled]
|
||||
|
||||
def model_exists(self, model_id: str) -> bool:
|
||||
return model_id in self.models
|
||||
Reference in New Issue
Block a user