mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 06:30:03 +00:00
Compare commits
3 Commits
feat-bring
...
feat-model
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ddd5704c49 | ||
|
|
d54e6d8b34 | ||
|
|
2806825959 |
@@ -35,8 +35,5 @@ MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
|
||||
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
|
||||
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
|
||||
|
||||
# User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration).
|
||||
# Standard Postgres URI — `postgres://` and `postgresql://` both work.
|
||||
# Leave unset while the migration is still being rolled out; the app will
|
||||
# fall back to MongoDB for user data until POSTGRES_URI is configured.
|
||||
|
||||
# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
|
||||
|
||||
@@ -1,266 +0,0 @@
|
||||
"""
|
||||
Model configurations for all supported LLM providers.
|
||||
"""
|
||||
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
)
|
||||
|
||||
# Base image attachment types supported by most vision-capable LLMs
|
||||
IMAGE_ATTACHMENTS = [
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
# PDF excluded: most OpenAI-compatible endpoints don't support native PDF uploads.
|
||||
# When excluded, PDFs are synthetically processed by converting pages to images.
|
||||
OPENAI_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
GOOGLE_ATTACHMENTS = ["application/pdf"] + IMAGE_ATTACHMENTS
|
||||
|
||||
ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
NOVITA_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
|
||||
OPENAI_MODELS = [
|
||||
AvailableModel(
|
||||
id="gpt-5.1",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-5.1",
|
||||
description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gpt-5-mini",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-5 Mini",
|
||||
description="Faster, cost-effective variant of GPT-5.1",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
ANTHROPIC_MODELS = [
|
||||
AvailableModel(
|
||||
id="claude-3-5-sonnet-20241022",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3.5 Sonnet (Latest)",
|
||||
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="claude-3-5-sonnet",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3.5 Sonnet",
|
||||
description="Balanced performance and capability",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="claude-3-opus",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3 Opus",
|
||||
description="Most capable Claude model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="claude-3-haiku",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3 Haiku",
|
||||
description="Fastest Claude model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
GOOGLE_MODELS = [
|
||||
AvailableModel(
|
||||
id="gemini-flash-latest",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini Flash (Latest)",
|
||||
description="Latest experimental Gemini model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||
context_window=int(1e6),
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gemini-flash-lite-latest",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini Flash Lite (Latest)",
|
||||
description="Fast with huge context window",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||
context_window=int(1e6),
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gemini-3-pro-preview",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini 3 Pro",
|
||||
description="Most capable Gemini model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||
context_window=2000000,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
GROQ_MODELS = [
|
||||
AvailableModel(
|
||||
id="llama-3.3-70b-versatile",
|
||||
provider=ModelProvider.GROQ,
|
||||
display_name="Llama 3.3 70B",
|
||||
description="Latest Llama model with high-speed inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="openai/gpt-oss-120b",
|
||||
provider=ModelProvider.GROQ,
|
||||
display_name="GPT-OSS 120B",
|
||||
description="Open-source GPT model optimized for speed",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
OPENROUTER_MODELS = [
|
||||
AvailableModel(
|
||||
id="qwen/qwen3-coder:free",
|
||||
provider=ModelProvider.OPENROUTER,
|
||||
display_name="Qwen 3 Coder",
|
||||
description="Latest Qwen model with high-speed inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
supported_attachment_types=OPENROUTER_ATTACHMENTS
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="google/gemma-3-27b-it:free",
|
||||
provider=ModelProvider.OPENROUTER,
|
||||
display_name="Gemma 3 27B",
|
||||
description="Latest Gemma model with high-speed inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
supported_attachment_types=OPENROUTER_ATTACHMENTS
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
NOVITA_MODELS = [
|
||||
AvailableModel(
|
||||
id="moonshotai/kimi-k2.5",
|
||||
provider=ModelProvider.NOVITA,
|
||||
display_name="Kimi K2.5",
|
||||
description="MoE model with function calling, structured output, reasoning, and vision",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=NOVITA_ATTACHMENTS,
|
||||
context_window=262144,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="zai-org/glm-5",
|
||||
provider=ModelProvider.NOVITA,
|
||||
display_name="GLM-5",
|
||||
description="MoE model with function calling, structured output, and reasoning",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=[],
|
||||
context_window=202800,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="minimax/minimax-m2.5",
|
||||
provider=ModelProvider.NOVITA,
|
||||
display_name="MiniMax M2.5",
|
||||
description="MoE model with function calling, structured output, and reasoning",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=[],
|
||||
context_window=204800,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
AZURE_OPENAI_MODELS = [
|
||||
AvailableModel(
|
||||
id="azure-gpt-4",
|
||||
provider=ModelProvider.AZURE_OPENAI,
|
||||
display_name="Azure OpenAI GPT-4",
|
||||
description="Azure-hosted GPT model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=8192,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
|
||||
"""Create a custom OpenAI-compatible model (e.g., LM Studio, Ollama)."""
|
||||
return AvailableModel(
|
||||
id=model_name,
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name=model_name,
|
||||
description=f"Custom OpenAI-compatible model at {base_url}",
|
||||
base_url=base_url,
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
),
|
||||
)
|
||||
164
application/core/model_registry.py
Normal file
164
application/core/model_registry.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Layered model registry.
|
||||
|
||||
Loads model catalogs from YAML files (built-in + operator-supplied),
|
||||
groups them by provider name, then for each registered provider plugin
|
||||
calls ``get_models`` to produce the final per-provider model list.
|
||||
|
||||
The ``user_id`` parameter on lookup methods is reserved for the future
|
||||
end-user BYOM (per-user model records in Postgres). It is currently
|
||||
ignored — defaulted to ``None`` everywhere — so call sites can be
|
||||
threaded through without a wide refactor when BYOM lands.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from application.core.model_settings import AvailableModel
|
||||
from application.core.model_yaml import (
|
||||
BUILTIN_MODELS_DIR,
|
||||
ProviderCatalog,
|
||||
load_model_yamls,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
"""Singleton registry of available models."""
|
||||
|
||||
_instance: Optional["ModelRegistry"] = None
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not ModelRegistry._initialized:
|
||||
self.models: Dict[str, AvailableModel] = {}
|
||||
self.default_model_id: Optional[str] = None
|
||||
self._load_models()
|
||||
ModelRegistry._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "ModelRegistry":
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def reset(cls) -> None:
|
||||
"""Clear the singleton. Intended for test fixtures."""
|
||||
cls._instance = None
|
||||
cls._initialized = False
|
||||
|
||||
def _load_models(self) -> None:
|
||||
from pathlib import Path
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.providers import ALL_PROVIDERS
|
||||
|
||||
directories = [BUILTIN_MODELS_DIR]
|
||||
operator_dir = getattr(settings, "MODELS_CONFIG_DIR", None)
|
||||
if operator_dir:
|
||||
op_path = Path(operator_dir)
|
||||
if not op_path.exists():
|
||||
logger.warning(
|
||||
"MODELS_CONFIG_DIR=%s does not exist; no operator "
|
||||
"model YAMLs will be loaded.",
|
||||
operator_dir,
|
||||
)
|
||||
elif not op_path.is_dir():
|
||||
logger.warning(
|
||||
"MODELS_CONFIG_DIR=%s is not a directory; no operator "
|
||||
"model YAMLs will be loaded.",
|
||||
operator_dir,
|
||||
)
|
||||
else:
|
||||
directories.append(op_path)
|
||||
|
||||
catalogs = load_model_yamls(directories)
|
||||
|
||||
# Validate every catalog targets a known plugin before doing any
|
||||
# registry work, so an unknown provider name in YAML aborts boot
|
||||
# with a clear error.
|
||||
plugin_names = {p.name for p in ALL_PROVIDERS}
|
||||
for c in catalogs:
|
||||
if c.provider not in plugin_names:
|
||||
raise ValueError(
|
||||
f"{c.source_path}: YAML declares unknown provider "
|
||||
f"{c.provider!r}; no Provider plugin is registered "
|
||||
f"under that name. Known: {sorted(plugin_names)}"
|
||||
)
|
||||
|
||||
catalogs_by_provider: Dict[str, List[ProviderCatalog]] = defaultdict(list)
|
||||
for c in catalogs:
|
||||
catalogs_by_provider[c.provider].append(c)
|
||||
|
||||
self.models.clear()
|
||||
for provider in ALL_PROVIDERS:
|
||||
if not provider.is_enabled(settings):
|
||||
continue
|
||||
for model in provider.get_models(
|
||||
settings, catalogs_by_provider.get(provider.name, [])
|
||||
):
|
||||
self.models[model.id] = model
|
||||
|
||||
self.default_model_id = self._resolve_default(settings)
|
||||
|
||||
logger.info(
|
||||
"ModelRegistry loaded %d models, default: %s",
|
||||
len(self.models),
|
||||
self.default_model_id,
|
||||
)
|
||||
|
||||
def _resolve_default(self, settings) -> Optional[str]:
|
||||
if settings.LLM_NAME:
|
||||
for name in self._parse_model_names(settings.LLM_NAME):
|
||||
if name in self.models:
|
||||
return name
|
||||
if settings.LLM_NAME in self.models:
|
||||
return settings.LLM_NAME
|
||||
|
||||
if settings.LLM_PROVIDER and settings.API_KEY:
|
||||
for model_id, model in self.models.items():
|
||||
if model.provider.value == settings.LLM_PROVIDER:
|
||||
return model_id
|
||||
|
||||
if self.models:
|
||||
return next(iter(self.models.keys()))
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _parse_model_names(llm_name: str) -> List[str]:
|
||||
if not llm_name:
|
||||
return []
|
||||
return [name.strip() for name in llm_name.split(",") if name.strip()]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lookup API. ``user_id`` is reserved for the future BYOM and
|
||||
# is ignored today — but threading it through every call site now
|
||||
# means BYOM doesn't require a wide refactor when we build it.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_model(
|
||||
self, model_id: str, user_id: Optional[str] = None
|
||||
) -> Optional[AvailableModel]:
|
||||
return self.models.get(model_id)
|
||||
|
||||
def get_all_models(
|
||||
self, user_id: Optional[str] = None
|
||||
) -> List[AvailableModel]:
|
||||
return list(self.models.values())
|
||||
|
||||
def get_enabled_models(
|
||||
self, user_id: Optional[str] = None
|
||||
) -> List[AvailableModel]:
|
||||
return [m for m in self.models.values() if m.enabled]
|
||||
|
||||
def model_exists(
|
||||
self, model_id: str, user_id: Optional[str] = None
|
||||
) -> bool:
|
||||
return model_id in self.models
|
||||
@@ -5,9 +5,16 @@ from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Re-exported here so existing call sites (and tests) that do
|
||||
# ``from application.core.model_settings import ModelRegistry`` keep
|
||||
# working. The implementation lives in ``application/core/model_registry.py``.
|
||||
# Imported lazily inside ``__getattr__`` to avoid an import cycle with
|
||||
# ``model_yaml`` → ``model_settings`` (this file).
|
||||
|
||||
|
||||
class ModelProvider(str, Enum):
|
||||
OPENAI = "openai"
|
||||
OPENAI_COMPATIBLE = "openai_compatible"
|
||||
OPENROUTER = "openrouter"
|
||||
AZURE_OPENAI = "azure_openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
@@ -41,11 +48,20 @@ class AvailableModel:
|
||||
capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
|
||||
enabled: bool = True
|
||||
base_url: Optional[str] = None
|
||||
# User-facing label distinct from the dispatch ``provider``. Used by
|
||||
# openai_compatible YAMLs so a Mistral model shows "mistral" in the
|
||||
# API response while still routing through the OpenAI wire format.
|
||||
display_provider: Optional[str] = None
|
||||
# Per-record API key. Operator YAMLs leave this None; populated for
|
||||
# openai_compatible models (resolved from the YAML's ``api_key_env``)
|
||||
# and reserved for the future end-user BYOM phase. Never serialized
|
||||
# into to_dict().
|
||||
api_key: Optional[str] = field(default=None, repr=False, compare=False)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
result = {
|
||||
"id": self.id,
|
||||
"provider": self.provider.value,
|
||||
"provider": self.display_provider or self.provider.value,
|
||||
"display_name": self.display_name,
|
||||
"description": self.description,
|
||||
"supported_attachment_types": self.capabilities.supported_attachment_types,
|
||||
@@ -60,255 +76,14 @@ class AvailableModel:
|
||||
return result
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
_instance = None
|
||||
_initialized = False
|
||||
def __getattr__(name):
|
||||
"""Lazy re-export of ``ModelRegistry`` from ``model_registry.py``.
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
Done lazily to avoid an import cycle: ``model_registry`` imports
|
||||
``model_yaml`` which imports the dataclasses from this file.
|
||||
"""
|
||||
if name == "ModelRegistry":
|
||||
from application.core.model_registry import ModelRegistry as _MR
|
||||
|
||||
def __init__(self):
|
||||
if not ModelRegistry._initialized:
|
||||
self.models: Dict[str, AvailableModel] = {}
|
||||
self.default_model_id: Optional[str] = None
|
||||
self._load_models()
|
||||
ModelRegistry._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "ModelRegistry":
|
||||
return cls()
|
||||
|
||||
def _load_models(self):
|
||||
from application.core.settings import settings
|
||||
|
||||
self.models.clear()
|
||||
|
||||
# Skip DocsGPT model if using custom OpenAI-compatible endpoint
|
||||
if not settings.OPENAI_BASE_URL:
|
||||
self._add_docsgpt_models(settings)
|
||||
if (
|
||||
settings.OPENAI_API_KEY
|
||||
or (settings.LLM_PROVIDER == "openai" and settings.API_KEY)
|
||||
or settings.OPENAI_BASE_URL
|
||||
):
|
||||
self._add_openai_models(settings)
|
||||
if settings.OPENAI_API_BASE or (
|
||||
settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
|
||||
):
|
||||
self._add_azure_openai_models(settings)
|
||||
if settings.ANTHROPIC_API_KEY or (
|
||||
settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
|
||||
):
|
||||
self._add_anthropic_models(settings)
|
||||
if settings.GOOGLE_API_KEY or (
|
||||
settings.LLM_PROVIDER == "google" and settings.API_KEY
|
||||
):
|
||||
self._add_google_models(settings)
|
||||
if settings.GROQ_API_KEY or (
|
||||
settings.LLM_PROVIDER == "groq" and settings.API_KEY
|
||||
):
|
||||
self._add_groq_models(settings)
|
||||
if settings.OPEN_ROUTER_API_KEY or (
|
||||
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
|
||||
):
|
||||
self._add_openrouter_models(settings)
|
||||
if settings.NOVITA_API_KEY or (
|
||||
settings.LLM_PROVIDER == "novita" and settings.API_KEY
|
||||
):
|
||||
self._add_novita_models(settings)
|
||||
if settings.HUGGINGFACE_API_KEY or (
|
||||
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
|
||||
):
|
||||
self._add_huggingface_models(settings)
|
||||
# Default model selection
|
||||
if settings.LLM_NAME:
|
||||
# Parse LLM_NAME (may be comma-separated)
|
||||
model_names = self._parse_model_names(settings.LLM_NAME)
|
||||
# First model in the list becomes default
|
||||
for model_name in model_names:
|
||||
if model_name in self.models:
|
||||
self.default_model_id = model_name
|
||||
break
|
||||
# Backward compat: try exact match if no parsed model found
|
||||
if not self.default_model_id and settings.LLM_NAME in self.models:
|
||||
self.default_model_id = settings.LLM_NAME
|
||||
|
||||
if not self.default_model_id:
|
||||
if settings.LLM_PROVIDER and settings.API_KEY:
|
||||
for model_id, model in self.models.items():
|
||||
if model.provider.value == settings.LLM_PROVIDER:
|
||||
self.default_model_id = model_id
|
||||
break
|
||||
|
||||
if not self.default_model_id and self.models:
|
||||
self.default_model_id = next(iter(self.models.keys()))
|
||||
logger.info(
|
||||
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
|
||||
)
|
||||
|
||||
def _add_openai_models(self, settings):
|
||||
from application.core.model_configs import (
|
||||
OPENAI_MODELS,
|
||||
create_custom_openai_model,
|
||||
)
|
||||
|
||||
# Check if using local OpenAI-compatible endpoint (Ollama, LM Studio, etc.)
|
||||
using_local_endpoint = bool(
|
||||
settings.OPENAI_BASE_URL and settings.OPENAI_BASE_URL.strip()
|
||||
)
|
||||
|
||||
if using_local_endpoint:
|
||||
# When OPENAI_BASE_URL is set, ONLY register custom models from LLM_NAME
|
||||
# Do NOT add standard OpenAI models (gpt-5.1, etc.)
|
||||
if settings.LLM_NAME:
|
||||
model_names = self._parse_model_names(settings.LLM_NAME)
|
||||
for model_name in model_names:
|
||||
custom_model = create_custom_openai_model(
|
||||
model_name, settings.OPENAI_BASE_URL
|
||||
)
|
||||
self.models[model_name] = custom_model
|
||||
logger.info(
|
||||
f"Registered custom OpenAI model: {model_name} at {settings.OPENAI_BASE_URL}"
|
||||
)
|
||||
else:
|
||||
# Standard OpenAI API usage - add standard models if API key is valid
|
||||
if settings.OPENAI_API_KEY:
|
||||
for model in OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_azure_openai_models(self, settings):
|
||||
from application.core.model_configs import AZURE_OPENAI_MODELS
|
||||
|
||||
if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
|
||||
for model in AZURE_OPENAI_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in AZURE_OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_anthropic_models(self, settings):
|
||||
from application.core.model_configs import ANTHROPIC_MODELS
|
||||
|
||||
if settings.ANTHROPIC_API_KEY:
|
||||
for model in ANTHROPIC_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
|
||||
for model in ANTHROPIC_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in ANTHROPIC_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_google_models(self, settings):
|
||||
from application.core.model_configs import GOOGLE_MODELS
|
||||
|
||||
if settings.GOOGLE_API_KEY:
|
||||
for model in GOOGLE_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
|
||||
for model in GOOGLE_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in GOOGLE_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_groq_models(self, settings):
|
||||
from application.core.model_configs import GROQ_MODELS
|
||||
|
||||
if settings.GROQ_API_KEY:
|
||||
for model in GROQ_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
|
||||
for model in GROQ_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in GROQ_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_openrouter_models(self, settings):
|
||||
from application.core.model_configs import OPENROUTER_MODELS
|
||||
|
||||
if settings.OPEN_ROUTER_API_KEY:
|
||||
for model in OPENROUTER_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "openrouter" and settings.LLM_NAME:
|
||||
for model in OPENROUTER_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in OPENROUTER_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_novita_models(self, settings):
|
||||
from application.core.model_configs import NOVITA_MODELS
|
||||
|
||||
if settings.NOVITA_API_KEY:
|
||||
for model in NOVITA_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "novita" and settings.LLM_NAME:
|
||||
for model in NOVITA_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in NOVITA_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_docsgpt_models(self, settings):
|
||||
model_id = "docsgpt-local"
|
||||
model = AvailableModel(
|
||||
id=model_id,
|
||||
provider=ModelProvider.DOCSGPT,
|
||||
display_name="DocsGPT Model",
|
||||
description="Local model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=False,
|
||||
supported_attachment_types=[],
|
||||
),
|
||||
)
|
||||
self.models[model_id] = model
|
||||
|
||||
def _add_huggingface_models(self, settings):
|
||||
model_id = "huggingface-local"
|
||||
model = AvailableModel(
|
||||
id=model_id,
|
||||
provider=ModelProvider.HUGGINGFACE,
|
||||
display_name="Hugging Face Model",
|
||||
description="Local Hugging Face model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=False,
|
||||
supported_attachment_types=[],
|
||||
),
|
||||
)
|
||||
self.models[model_id] = model
|
||||
|
||||
def _parse_model_names(self, llm_name: str) -> List[str]:
|
||||
"""
|
||||
Parse LLM_NAME which may contain comma-separated model names.
|
||||
E.g., 'deepseek-r1:1.5b,gemma:2b' -> ['deepseek-r1:1.5b', 'gemma:2b']
|
||||
"""
|
||||
if not llm_name:
|
||||
return []
|
||||
return [name.strip() for name in llm_name.split(",") if name.strip()]
|
||||
|
||||
def get_model(self, model_id: str) -> Optional[AvailableModel]:
|
||||
return self.models.get(model_id)
|
||||
|
||||
def get_all_models(self) -> List[AvailableModel]:
|
||||
return list(self.models.values())
|
||||
|
||||
def get_enabled_models(self) -> List[AvailableModel]:
|
||||
return [m for m in self.models.values() if m.enabled]
|
||||
|
||||
def model_exists(self, model_id: str) -> bool:
|
||||
return model_id in self.models
|
||||
return _MR
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -1,28 +1,22 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.core.model_settings import ModelRegistry
|
||||
from application.core.model_registry import ModelRegistry
|
||||
|
||||
|
||||
def get_api_key_for_provider(provider: str) -> Optional[str]:
|
||||
"""Get the appropriate API key for a provider"""
|
||||
"""Get the appropriate API key for a provider.
|
||||
|
||||
Delegates to the provider plugin's ``get_api_key``. Falls back to the
|
||||
generic ``settings.API_KEY`` for unknown providers.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
from application.llm.providers import PROVIDERS_BY_NAME
|
||||
|
||||
provider_key_map = {
|
||||
"openai": settings.OPENAI_API_KEY,
|
||||
"openrouter": settings.OPEN_ROUTER_API_KEY,
|
||||
"novita": settings.NOVITA_API_KEY,
|
||||
"anthropic": settings.ANTHROPIC_API_KEY,
|
||||
"google": settings.GOOGLE_API_KEY,
|
||||
"groq": settings.GROQ_API_KEY,
|
||||
"huggingface": settings.HUGGINGFACE_API_KEY,
|
||||
"azure_openai": settings.API_KEY,
|
||||
"docsgpt": None,
|
||||
"llama.cpp": None,
|
||||
}
|
||||
|
||||
provider_key = provider_key_map.get(provider)
|
||||
if provider_key:
|
||||
return provider_key
|
||||
plugin = PROVIDERS_BY_NAME.get(provider)
|
||||
if plugin is not None:
|
||||
key = plugin.get_api_key(settings)
|
||||
if key:
|
||||
return key
|
||||
return settings.API_KEY
|
||||
|
||||
|
||||
@@ -91,3 +85,21 @@ def get_base_url_for_model(model_id: str) -> Optional[str]:
|
||||
if model:
|
||||
return model.base_url
|
||||
return None
|
||||
|
||||
|
||||
def get_api_key_for_model(model_id: str) -> Optional[str]:
|
||||
"""
|
||||
Resolve the API key to use when invoking ``model_id``.
|
||||
|
||||
Priority:
|
||||
1. The model record's own ``api_key`` (reserved for future end-user
|
||||
BYOM where credentials travel with the record).
|
||||
2. The provider plugin's settings-based key.
|
||||
"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
if model is not None and model.api_key:
|
||||
return model.api_key
|
||||
if model is not None:
|
||||
return get_api_key_for_provider(model.provider.value)
|
||||
return None
|
||||
|
||||
325
application/core/model_yaml.py
Normal file
325
application/core/model_yaml.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""YAML loader for model catalog files under ``application/core/models/``.
|
||||
|
||||
Each ``*.yaml`` file declares one provider's static model catalog. Files
|
||||
are validated with Pydantic at load time; any parse, schema, or alias
|
||||
error aborts startup with the offending file path in the message.
|
||||
|
||||
For most providers, one YAML maps to one catalog. The
|
||||
``openai_compatible`` provider is special: each YAML file represents a
|
||||
distinct logical endpoint (Mistral, Together, Ollama, ...) with its own
|
||||
``api_key_env`` and ``base_url``. The loader returns a flat list so the
|
||||
registry can distinguish multiple files with the same ``provider:`` value.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BUILTIN_MODELS_DIR = Path(__file__).parent / "models"
|
||||
DEFAULTS_FILENAME = "_defaults.yaml"
|
||||
|
||||
|
||||
class _DefaultsFile(BaseModel):
|
||||
"""Schema for ``_defaults.yaml``. Currently just attachment aliases."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
attachment_aliases: Dict[str, List[str]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class _CapabilityFields(BaseModel):
|
||||
"""Capability fields shared between provider ``defaults:`` and per-model overrides.
|
||||
|
||||
All fields are optional so a per-model override can selectively replace
|
||||
a single field from the provider-level defaults.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
supports_tools: Optional[bool] = None
|
||||
supports_structured_output: Optional[bool] = None
|
||||
supports_streaming: Optional[bool] = None
|
||||
attachments: Optional[List[str]] = None
|
||||
context_window: Optional[int] = None
|
||||
input_cost_per_token: Optional[float] = None
|
||||
output_cost_per_token: Optional[float] = None
|
||||
|
||||
|
||||
class _ModelEntry(_CapabilityFields):
|
||||
"""Schema for one model row inside a YAML's ``models:`` list."""
|
||||
|
||||
id: str
|
||||
display_name: Optional[str] = None
|
||||
description: str = ""
|
||||
enabled: bool = True
|
||||
base_url: Optional[str] = None
|
||||
aliases: List[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("id")
|
||||
@classmethod
|
||||
def _id_nonempty(cls, v: str) -> str:
|
||||
if not v or not v.strip():
|
||||
raise ValueError("model id must be a non-empty string")
|
||||
return v
|
||||
|
||||
|
||||
class _ProviderFile(BaseModel):
|
||||
"""Schema for one ``<provider>.yaml`` catalog file."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
provider: str
|
||||
defaults: _CapabilityFields = Field(default_factory=_CapabilityFields)
|
||||
models: List[_ModelEntry] = Field(default_factory=list)
|
||||
# openai_compatible metadata. Optional for other providers.
|
||||
display_provider: Optional[str] = None
|
||||
api_key_env: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
|
||||
|
||||
class ProviderCatalog(BaseModel):
    """One YAML file's parsed contents, ready for the registry.

    For most providers, multiple catalogs with the same ``provider`` get
    merged later by the registry. The ``openai_compatible`` provider is
    the exception: each catalog is treated as a distinct endpoint, with
    its own ``api_key_env`` and ``base_url``.
    """

    provider: str
    models: List[AvailableModel]
    # File this catalog was loaded from; useful for duplicate-id diagnostics.
    source_path: Optional[Path] = None
    # openai_compatible endpoint metadata; None for other providers.
    display_provider: Optional[str] = None
    api_key_env: Optional[str] = None
    base_url: Optional[str] = None

    # NOTE(review): arbitrary types are allowed presumably because
    # AvailableModel is not a pydantic model — confirm against its definition.
    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ModelYAMLError(ValueError):
    """Raised when a model YAML fails parsing, schema, or alias validation.

    Subclasses ``ValueError``; messages name the offending file (and,
    where relevant, the model id) so the failure is actionable at boot.
    """
|
||||
|
||||
|
||||
def _expand_attachments(
|
||||
attachments: Sequence[str], aliases: Dict[str, List[str]], source: str
|
||||
) -> List[str]:
|
||||
"""Resolve attachment shorthands (``image``, ``pdf``) to MIME types.
|
||||
|
||||
Raw MIME-typed entries (containing ``/``) pass through unchanged.
|
||||
Unknown aliases raise ``ModelYAMLError``.
|
||||
"""
|
||||
expanded: List[str] = []
|
||||
seen: set = set()
|
||||
for entry in attachments:
|
||||
if "/" in entry:
|
||||
if entry not in seen:
|
||||
expanded.append(entry)
|
||||
seen.add(entry)
|
||||
continue
|
||||
if entry not in aliases:
|
||||
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
|
||||
raise ModelYAMLError(
|
||||
f"{source}: unknown attachment alias '{entry}'. "
|
||||
f"Valid aliases: {valid}. "
|
||||
"(Or use a raw MIME type like 'image/png'.)"
|
||||
)
|
||||
for mime in aliases[entry]:
|
||||
if mime not in seen:
|
||||
expanded.append(mime)
|
||||
seen.add(mime)
|
||||
return expanded
|
||||
|
||||
|
||||
def _load_defaults(directory: Path) -> Dict[str, List[str]]:
    """Read ``_defaults.yaml`` under *directory* and return its alias map.

    A missing file is not an error — an empty mapping is returned. Invalid
    YAML or a schema violation raises ``ModelYAMLError`` naming the file.
    """
    defaults_path = directory / DEFAULTS_FILENAME
    if not defaults_path.exists():
        return {}

    try:
        content = yaml.safe_load(defaults_path.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as e:
        raise ModelYAMLError(f"{defaults_path}: invalid YAML: {e}") from e

    try:
        validated = _DefaultsFile.model_validate(content)
    except Exception as e:
        raise ModelYAMLError(f"{defaults_path}: schema error: {e}") from e

    return validated.attachment_aliases
|
||||
|
||||
|
||||
def _resolve_provider_enum(name: str, source: Path) -> ModelProvider:
    """Map a YAML ``provider:`` string onto the ``ModelProvider`` enum.

    Raises ``ModelYAMLError`` listing every valid value when *name* is not
    a member, so a typo in a catalog file fails loudly at load time.
    """
    try:
        return ModelProvider(name)
    except ValueError as e:
        valid = ", ".join(p.value for p in ModelProvider)
        raise ModelYAMLError(
            f"{source}: unknown provider '{name}'. Valid: {valid}"
        ) from e
|
||||
|
||||
|
||||
def _build_model(
    entry: _ModelEntry,
    defaults: _CapabilityFields,
    provider: ModelProvider,
    aliases: Dict[str, List[str]],
    source: Path,
    display_provider: Optional[str] = None,
) -> AvailableModel:
    """Combine provider-level defaults with one model row into an ``AvailableModel``.

    Resolution order for every capability field: per-model override, then
    the provider ``defaults:`` block, then a hard-coded fallback.
    """

    def merged(field_name: str, fallback):
        # First non-None value wins: model override, provider default, fallback.
        for candidate in (getattr(entry, field_name), getattr(defaults, field_name)):
            if candidate is not None:
                return candidate
        return fallback

    # Attachments use the same override chain, then aliases ("image",
    # "pdf", ...) expand into concrete MIME types. The source label names
    # the model so an unknown alias is easy to trace.
    attachment_spec = merged("attachments", [])
    mime_types = _expand_attachments(
        attachment_spec, aliases, f"{source} [model={entry.id}]"
    )

    capabilities = ModelCapabilities(
        supports_tools=merged("supports_tools", False),
        supports_structured_output=merged("supports_structured_output", False),
        supports_streaming=merged("supports_streaming", True),
        supported_attachment_types=mime_types,
        context_window=merged("context_window", 128000),
        input_cost_per_token=merged("input_cost_per_token", None),
        output_cost_per_token=merged("output_cost_per_token", None),
    )

    return AvailableModel(
        id=entry.id,
        provider=provider,
        display_name=entry.display_name or entry.id,
        description=entry.description,
        capabilities=capabilities,
        enabled=entry.enabled,
        base_url=entry.base_url,
        display_provider=display_provider,
    )
|
||||
|
||||
|
||||
def _load_one_yaml(
    path: Path, aliases: Dict[str, List[str]]
) -> ProviderCatalog:
    """Parse and validate one provider YAML into a ``ProviderCatalog``.

    Raises ``ModelYAMLError`` (naming *path*) on invalid YAML, a schema
    violation, or an unknown provider.
    """
    try:
        raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as e:
        raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e

    try:
        parsed = _ProviderFile.model_validate(raw)
    except Exception as e:
        raise ModelYAMLError(f"{path}: schema error: {e}") from e

    provider_enum = _resolve_provider_enum(parsed.provider, path)

    def _to_model(entry: _ModelEntry) -> AvailableModel:
        # Every row shares this file's defaults, provider, and alias map.
        return _build_model(
            entry,
            parsed.defaults,
            provider_enum,
            aliases,
            path,
            display_provider=parsed.display_provider,
        )

    return ProviderCatalog(
        provider=parsed.provider,
        models=[_to_model(entry) for entry in parsed.models],
        source_path=path,
        display_provider=parsed.display_provider,
        api_key_env=parsed.api_key_env,
        base_url=parsed.base_url,
    )
|
||||
|
||||
|
||||
# Lazily-populated cache for the built-in alias map; see below.
_BUILTIN_ALIASES_CACHE: Optional[Dict[str, List[str]]] = None


def builtin_attachment_aliases() -> Dict[str, List[str]]:
    """Return the built-in attachment alias map from ``_defaults.yaml``.

    The file is read at most once per process; subsequent calls return
    the cached mapping.
    """
    global _BUILTIN_ALIASES_CACHE
    cached = _BUILTIN_ALIASES_CACHE
    if cached is None:
        cached = _load_defaults(BUILTIN_MODELS_DIR)
        _BUILTIN_ALIASES_CACHE = cached
    return cached
|
||||
|
||||
|
||||
def resolve_attachment_alias(alias: str) -> List[str]:
    """Resolve one attachment alias (e.g. ``"image"``) to its MIME-type list.

    Returns a fresh copy of the canonical list. Raises ``ModelYAMLError``
    listing the known aliases when *alias* is not defined.
    """
    alias_map = builtin_attachment_aliases()
    if alias in alias_map:
        return list(alias_map[alias])
    valid = ", ".join(sorted(alias_map.keys())) or "<none defined>"
    raise ModelYAMLError(
        f"Unknown attachment alias '{alias}'. Valid: {valid}"
    )
|
||||
|
||||
|
||||
def load_model_yamls(directories: Sequence[Path]) -> List[ProviderCatalog]:
    """Load every ``*.yaml`` (excluding ``_defaults.yaml``) from each
    directory, in order, and return the catalogs as a flat list.

    Merging catalogs that target the same provider plugin is the caller's
    job. Keeping the list flat lets ``openai_compatible`` treat each file
    as its own logical endpoint.

    A model ``id`` defined in more than one YAML across the directory list
    logs a warning. The returned list preserves load order, so the
    registry's "later wins" merge resolves to the later directory's
    definition.
    """
    usable_dirs = [d for d in directories if d and d.exists()]

    # Pass 1: merge attachment aliases from every _defaults.yaml, in order,
    # so later directories can extend/override earlier alias definitions.
    aliases: Dict[str, List[str]] = {}
    for directory in usable_dirs:
        aliases.update(_load_defaults(directory))

    # Pass 2: load the provider catalogs using the merged alias map.
    catalogs: List[ProviderCatalog] = []
    seen_in: Dict[str, Path] = {}
    for directory in usable_dirs:
        for yaml_path in sorted(directory.glob("*.yaml")):
            if yaml_path.name == DEFAULTS_FILENAME:
                continue
            catalog = _load_one_yaml(yaml_path, aliases)
            catalogs.append(catalog)
            for model in catalog.models:
                earlier = seen_in.get(model.id)
                # Warn only for cross-file redefinitions, and keep tracking
                # the most recent file so chained overrides stay auditable.
                if earlier is not None and earlier != yaml_path:
                    logger.warning(
                        "Model id %r redefined: %s overrides %s (later wins)",
                        model.id,
                        yaml_path,
                        earlier,
                    )
                seen_in[model.id] = yaml_path

    return catalogs
|
||||
213
application/core/models/README.md
Normal file
213
application/core/models/README.md
Normal file
@@ -0,0 +1,213 @@
|
||||
# Model catalogs
|
||||
|
||||
Each `*.yaml` file in this directory declares one provider's model
|
||||
catalog. The registry loads every YAML at boot and joins it to the
|
||||
matching provider plugin under `application/llm/providers/`.
|
||||
|
||||
To add or edit models, you almost always only touch a YAML here — no
|
||||
Python code required.
|
||||
|
||||
## Add a model to an existing provider
|
||||
|
||||
Open the provider's YAML (e.g. `anthropic.yaml`) and append two lines
|
||||
under `models:`:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
- id: claude-3-7-sonnet
|
||||
display_name: Claude 3.7 Sonnet
|
||||
```
|
||||
|
||||
Capabilities default to the provider's `defaults:` block. Override
|
||||
per-model only when needed:
|
||||
|
||||
```yaml
|
||||
- id: claude-3-7-sonnet
|
||||
display_name: Claude 3.7 Sonnet
|
||||
context_window: 500000
|
||||
```
|
||||
|
||||
Restart the app. The new model appears in `/api/models`.
|
||||
|
||||
> The model `id` is what gets stored in agent / workflow records. Once
|
||||
> users start picking the model, **don't rename it** — agent and
|
||||
> workflow rows reference it as a free-form string and silently fall
|
||||
> back to the system default if the id disappears.
|
||||
|
||||
## Add an OpenAI-compatible provider (zero Python)
|
||||
|
||||
Drop a YAML in this directory (or in your `MODELS_CONFIG_DIR`) that uses
|
||||
the `openai_compatible` plugin. Set the env var named in `api_key_env`
|
||||
and you're done — no Python, no settings.py edit, no LLMCreator change:
|
||||
|
||||
```yaml
|
||||
# mistral.yaml
|
||||
provider: openai_compatible
|
||||
display_provider: mistral # shown in /api/models response
|
||||
api_key_env: MISTRAL_API_KEY # env var the plugin reads at boot
|
||||
base_url: https://api.mistral.ai/v1
|
||||
defaults:
|
||||
supports_tools: true
|
||||
context_window: 128000
|
||||
models:
|
||||
- id: mistral-large-latest
|
||||
display_name: Mistral Large
|
||||
- id: mistral-small-latest
|
||||
display_name: Mistral Small
|
||||
```
|
||||
|
||||
`MISTRAL_API_KEY=sk-... ; restart` — Mistral models appear in
|
||||
`/api/models` with `provider: "mistral"`. They route through the OpenAI
|
||||
wire format (it's `OpenAILLM` under the hood) but with Mistral's
|
||||
endpoint and key.
|
||||
|
||||
Multiple `openai_compatible` YAMLs coexist: each file is one logical
|
||||
endpoint with its own `api_key_env` and `base_url`. Drop in
|
||||
`together.yaml`, `fireworks.yaml`, etc. side by side. If an env var
|
||||
isn't set, that catalog is skipped at boot (logged at INFO) —
|
||||
no error.
|
||||
|
||||
Working example: `examples/mistral.yaml.example`. Files inside
|
||||
`examples/` aren't loaded by the registry; the glob only picks up
|
||||
`*.yaml` at the top level.
|
||||
|
||||
## Add a provider with its own SDK
|
||||
|
||||
For a provider that doesn't speak OpenAI's wire format, add one Python
|
||||
file to `application/llm/providers/<name>.py`:
|
||||
|
||||
```python
|
||||
from application.llm.providers.base import Provider
|
||||
from application.llm.my_provider import MyLLM
|
||||
|
||||
class MyProvider(Provider):
|
||||
name = "my_provider"
|
||||
llm_class = MyLLM
|
||||
|
||||
def get_api_key(self, settings):
|
||||
return settings.MY_PROVIDER_API_KEY
|
||||
```
|
||||
|
||||
Register it in `application/llm/providers/__init__.py` (one line in
|
||||
`ALL_PROVIDERS`), add `MY_PROVIDER_API_KEY` to `settings.py`, and create
|
||||
`my_provider.yaml` here with the model catalog.
|
||||
|
||||
## Schema reference
|
||||
|
||||
```yaml
|
||||
provider: <string, required> # matches the Provider plugin's `name`
|
||||
|
||||
# openai_compatible only — required for that provider, ignored for others
|
||||
display_provider: <string> # label shown in /api/models response
|
||||
api_key_env: <string> # name of the env var carrying the key
|
||||
base_url: <string> # endpoint URL
|
||||
|
||||
defaults: # optional, applied to every model below
|
||||
supports_tools: bool # default false
|
||||
supports_structured_output: bool # default false
|
||||
supports_streaming: bool # default true
|
||||
attachments: [<alias-or-mime>, ...] # default []
|
||||
context_window: int # default 128000
|
||||
input_cost_per_token: float # default null
|
||||
output_cost_per_token: float # default null
|
||||
|
||||
models: # required
|
||||
- id: <string, required> # the value persisted in agent records
|
||||
display_name: <string> # default: id
|
||||
description: <string> # default: ""
|
||||
enabled: bool # default true; false hides from /api/models
|
||||
base_url: <string> # optional custom endpoint for this model
|
||||
# All `defaults:` fields above can be overridden here per-model.
|
||||
```
|
||||
|
||||
### Attachment aliases
|
||||
|
||||
The `attachments:` list can mix human-readable aliases with raw MIME
|
||||
types. Aliases are defined in `_defaults.yaml`:
|
||||
|
||||
| Alias | Expands to |
|
||||
|---|---|
|
||||
| `image` | `image/png`, `image/jpeg`, `image/jpg`, `image/webp`, `image/gif` |
|
||||
| `pdf` | `application/pdf` |
|
||||
| `audio` | `audio/mpeg`, `audio/wav`, `audio/ogg` |
|
||||
|
||||
Use raw MIME types when you need surgical control:
|
||||
|
||||
```yaml
|
||||
attachments: [image/png, image/webp] # only these two
|
||||
```
|
||||
|
||||
## Operator-supplied YAMLs (`MODELS_CONFIG_DIR`)
|
||||
|
||||
Set the `MODELS_CONFIG_DIR` env var (or `.env` entry) to a directory
|
||||
path. Every `*.yaml` in that directory is loaded **after** the built-in
|
||||
catalog under `application/core/models/`. Operators use this to:
|
||||
|
||||
- Add new `openai_compatible` providers (Mistral, Together, Fireworks,
|
||||
Ollama, ...) without forking the repo.
|
||||
- Extend an existing provider's catalog with extra models — append
|
||||
models under `provider: anthropic` and they show up alongside the
|
||||
built-ins.
|
||||
- Override a built-in model's capabilities — declare the same `id`
|
||||
with different fields (e.g. a higher `context_window`). Later wins;
|
||||
the override is logged as a `WARNING` so you can audit it.
|
||||
|
||||
Things you cannot do via `MODELS_CONFIG_DIR`:
|
||||
|
||||
- Add a brand-new non-OpenAI provider — that needs a Python plugin
|
||||
under `application/llm/providers/` (see "Add a provider with its own
|
||||
SDK" above). Operator YAMLs may only target a `provider:` value that
|
||||
already has a registered plugin.
|
||||
|
||||
### Example: Docker
|
||||
|
||||
Mount your model YAMLs into the container and point the env var at the
|
||||
mount path:
|
||||
|
||||
```yaml
|
||||
# docker-compose.yml
|
||||
services:
|
||||
app:
|
||||
image: arc53/docsgpt
|
||||
environment:
|
||||
MODELS_CONFIG_DIR: /etc/docsgpt/models
|
||||
MISTRAL_API_KEY: ${MISTRAL_API_KEY}
|
||||
volumes:
|
||||
- ./my-models:/etc/docsgpt/models:ro
|
||||
```
|
||||
|
||||
Then `./my-models/mistral.yaml` (the file from
|
||||
`examples/mistral.yaml.example`) gets picked up at boot.
|
||||
|
||||
### Example: Kubernetes
|
||||
|
||||
Mount a `ConfigMap` containing your YAMLs at a known path and set
|
||||
`MODELS_CONFIG_DIR` on the deployment. The same `examples/mistral.yaml.example`
|
||||
becomes a key in the ConfigMap.
|
||||
|
||||
### Misconfiguration
|
||||
|
||||
If `MODELS_CONFIG_DIR` is set but the path doesn't exist (or isn't a
|
||||
directory), the app logs a `WARNING` at boot and continues with just
|
||||
the built-in catalog. The app does *not* fail to start — operators can
|
||||
ship config drift without taking down the service — but the warning is
|
||||
loud enough to surface in any reasonable log aggregator.
|
||||
|
||||
## Validation
|
||||
|
||||
YAMLs are parsed with Pydantic at boot. The app fails to start with a
|
||||
clear error message if:
|
||||
|
||||
- a top-level key is unknown
|
||||
- a model is missing `id`
|
||||
- an attachment alias isn't defined
|
||||
- the `provider:` value isn't registered as a plugin
|
||||
|
||||
This is intentional — silent fallbacks would mean users don't notice
|
||||
their model picks broke until they hit the API.
|
||||
|
||||
## Reserved fields (not yet implemented)
|
||||
|
||||
- `aliases:` on a model — old IDs that resolve to this model. Reserved
|
||||
for future renames; the schema accepts the field but it is not yet
|
||||
acted on.
|
||||
18
application/core/models/_defaults.yaml
Normal file
18
application/core/models/_defaults.yaml
Normal file
@@ -0,0 +1,18 @@
|
||||
# Global defaults applied across every model YAML in this directory.
|
||||
# Keep this file sparse — per-provider `defaults:` blocks are clearer
|
||||
# than a deep global default chain. This file is for things that
|
||||
# genuinely never vary, like the meaning of "image".
|
||||
|
||||
attachment_aliases:
|
||||
image:
|
||||
- image/png
|
||||
- image/jpeg
|
||||
- image/jpg
|
||||
- image/webp
|
||||
- image/gif
|
||||
pdf:
|
||||
- application/pdf
|
||||
audio:
|
||||
- audio/mpeg
|
||||
- audio/wav
|
||||
- audio/ogg
|
||||
23
application/core/models/anthropic.yaml
Normal file
23
application/core/models/anthropic.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
provider: anthropic
|
||||
defaults:
|
||||
supports_tools: true
|
||||
attachments: [image]
|
||||
context_window: 200000
|
||||
|
||||
models:
|
||||
- id: claude-opus-4-7
|
||||
display_name: Claude Opus 4.7
|
||||
description: Most capable Claude model for complex reasoning and agentic coding
|
||||
context_window: 1000000
|
||||
supports_structured_output: true
|
||||
|
||||
- id: claude-sonnet-4-6
|
||||
display_name: Claude Sonnet 4.6
|
||||
description: Best balance of speed and intelligence with extended thinking
|
||||
context_window: 1000000
|
||||
supports_structured_output: true
|
||||
|
||||
- id: claude-haiku-4-5
|
||||
display_name: Claude Haiku 4.5
|
||||
description: Fastest Claude model with near-frontier intelligence
|
||||
supports_structured_output: true
|
||||
31
application/core/models/azure_openai.yaml
Normal file
31
application/core/models/azure_openai.yaml
Normal file
@@ -0,0 +1,31 @@
|
||||
# Azure OpenAI catalog.
|
||||
#
|
||||
# IMPORTANT: For Azure OpenAI, the `id` field is the **deployment name**, not
|
||||
# a model name. Deployment names are arbitrary strings the operator chooses
|
||||
# in Azure portal (or via ARM/Bicep/Terraform) when they create a deployment
|
||||
# for a given underlying model + version.
|
||||
#
|
||||
# The IDs below are sensible defaults that mirror the underlying OpenAI
|
||||
# model name (prefixed with `azure-`). Operators almost always need to
|
||||
# override them via `MODELS_CONFIG_DIR` to match the deployment names that
|
||||
# actually exist in their Azure resource. The `display_name`, capability
|
||||
# flags, and `context_window` reflect the underlying OpenAI model.
|
||||
provider: azure_openai
|
||||
|
||||
defaults:
|
||||
supports_tools: true
|
||||
supports_structured_output: true
|
||||
attachments: [image]
|
||||
context_window: 400000
|
||||
|
||||
models:
|
||||
- id: azure-gpt-5.5
|
||||
display_name: Azure OpenAI GPT-5.5
|
||||
description: Azure-hosted flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
|
||||
context_window: 1050000
|
||||
- id: azure-gpt-5.4-mini
|
||||
display_name: Azure OpenAI GPT-5.4 Mini
|
||||
description: Azure-hosted cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
|
||||
- id: azure-gpt-5.4-nano
|
||||
display_name: Azure OpenAI GPT-5.4 Nano
|
||||
description: Azure-hosted cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most
|
||||
7
application/core/models/docsgpt.yaml
Normal file
7
application/core/models/docsgpt.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
provider: docsgpt
|
||||
|
||||
models:
|
||||
- id: docsgpt-local
|
||||
display_name: DocsGPT Model
|
||||
description: Local model
|
||||
supports_tools: false
|
||||
31
application/core/models/examples/mistral.yaml.example
Normal file
31
application/core/models/examples/mistral.yaml.example
Normal file
@@ -0,0 +1,31 @@
|
||||
# EXAMPLE — copy this file to ../mistral.yaml (or to your
|
||||
# MODELS_CONFIG_DIR) and set MISTRAL_API_KEY in your environment.
|
||||
#
|
||||
# This is the entire integration. No Python required: the
|
||||
# `openai_compatible` plugin reads `api_key_env` and `base_url` from
|
||||
# the file and routes calls through the OpenAI wire format.
|
||||
#
|
||||
# Files in this `examples/` directory are NOT loaded by the registry
|
||||
# (the loader globs *.yaml at the top level only).
|
||||
|
||||
provider: openai_compatible
|
||||
display_provider: mistral # shown in /api/models response
|
||||
api_key_env: MISTRAL_API_KEY # env var the plugin reads
|
||||
base_url: https://api.mistral.ai/v1 # OpenAI-compatible endpoint
|
||||
|
||||
defaults:
|
||||
supports_tools: true
|
||||
context_window: 128000
|
||||
|
||||
models:
|
||||
- id: mistral-large-latest
|
||||
display_name: Mistral Large
|
||||
description: Top-tier reasoning model
|
||||
|
||||
- id: mistral-small-latest
|
||||
display_name: Mistral Small
|
||||
description: Fast, cost-efficient
|
||||
|
||||
- id: codestral-latest
|
||||
display_name: Codestral
|
||||
description: Code-specialized model
|
||||
17
application/core/models/google.yaml
Normal file
17
application/core/models/google.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
provider: google
|
||||
defaults:
|
||||
supports_tools: true
|
||||
supports_structured_output: true
|
||||
attachments: [pdf, image]
|
||||
context_window: 1048576
|
||||
|
||||
models:
|
||||
- id: gemini-3.1-pro-preview
|
||||
display_name: Gemini 3.1 Pro
|
||||
description: Most capable Gemini 3 model with advanced reasoning and agentic coding (preview)
|
||||
- id: gemini-3-flash-preview
|
||||
display_name: Gemini 3 Flash
|
||||
description: Frontier-class performance for low-latency, high-volume tasks (preview)
|
||||
- id: gemini-3.1-flash-lite-preview
|
||||
display_name: Gemini 3.1 Flash-Lite
|
||||
description: Cost-efficient frontier-class multimodal model for high-throughput workloads (preview)
|
||||
16
application/core/models/groq.yaml
Normal file
16
application/core/models/groq.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
provider: groq
|
||||
defaults:
|
||||
supports_tools: true
|
||||
context_window: 131072
|
||||
|
||||
models:
|
||||
- id: openai/gpt-oss-120b
|
||||
display_name: GPT-OSS 120B
|
||||
description: OpenAI's open-weight 120B flagship served on Groq's LPU hardware; strong general reasoning with strict structured output support
|
||||
supports_structured_output: true
|
||||
- id: llama-3.3-70b-versatile
|
||||
display_name: Llama 3.3 70B Versatile
|
||||
description: Meta's Llama 3.3 70B for general-purpose chat with parallel tool use
|
||||
- id: llama-3.1-8b-instant
|
||||
display_name: Llama 3.1 8B Instant
|
||||
description: Small, very low-latency Llama model (~560 tok/s) with parallel tool use
|
||||
7
application/core/models/huggingface.yaml
Normal file
7
application/core/models/huggingface.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
provider: huggingface
|
||||
|
||||
models:
|
||||
- id: huggingface-local
|
||||
display_name: Hugging Face Model
|
||||
description: Local Hugging Face model
|
||||
supports_tools: false
|
||||
21
application/core/models/novita.yaml
Normal file
21
application/core/models/novita.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
provider: novita
|
||||
defaults:
|
||||
supports_tools: true
|
||||
supports_structured_output: true
|
||||
|
||||
models:
|
||||
- id: deepseek/deepseek-v4-pro
|
||||
display_name: DeepSeek V4 Pro
|
||||
description: 1.6T MoE (49B active) with 1M context, hybrid CSA/HCA attention, top-tier reasoning and agentic coding
|
||||
context_window: 1048576
|
||||
|
||||
- id: moonshotai/kimi-k2.6
|
||||
display_name: Kimi K2.6
|
||||
description: 1T-parameter open-weight MoE with native vision/video, multi-step tool calling, and agentic long-horizon execution
|
||||
attachments: [image]
|
||||
context_window: 262144
|
||||
|
||||
- id: zai-org/glm-5
|
||||
display_name: GLM-5
|
||||
description: Z.AI 754B-parameter MoE with strong general reasoning, function calling, and structured output
|
||||
context_window: 202800
|
||||
18
application/core/models/openai.yaml
Normal file
18
application/core/models/openai.yaml
Normal file
@@ -0,0 +1,18 @@
|
||||
provider: openai
|
||||
defaults:
|
||||
supports_tools: true
|
||||
supports_structured_output: true
|
||||
attachments: [image]
|
||||
context_window: 400000
|
||||
|
||||
models:
|
||||
- id: gpt-5.5
|
||||
display_name: GPT-5.5
|
||||
description: Flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
|
||||
context_window: 1050000
|
||||
- id: gpt-5.4-mini
|
||||
display_name: GPT-5.4 Mini
|
||||
description: Cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
|
||||
- id: gpt-5.4-nano
|
||||
display_name: GPT-5.4 Nano
|
||||
description: Cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most
|
||||
25
application/core/models/openrouter.yaml
Normal file
25
application/core/models/openrouter.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
provider: openrouter
|
||||
defaults:
|
||||
supports_tools: true
|
||||
attachments: [image]
|
||||
context_window: 128000
|
||||
|
||||
models:
|
||||
- id: qwen/qwen3-coder:free
|
||||
display_name: Qwen3 Coder (free)
|
||||
description: Free-tier 480B MoE coder model with strong agentic tool use; rate-limited
|
||||
context_window: 262000
|
||||
attachments: []
|
||||
|
||||
- id: deepseek/deepseek-v3.2
|
||||
display_name: DeepSeek V3.2
|
||||
description: Open-weights reasoning model, very low cost (~$0.25 in / $0.38 out per 1M)
|
||||
context_window: 131072
|
||||
attachments: []
|
||||
supports_structured_output: true
|
||||
|
||||
- id: anthropic/claude-sonnet-4.6
|
||||
display_name: Claude Sonnet 4.6 (via OpenRouter)
|
||||
description: Frontier Sonnet-class model with 1M context, vision, and extended thinking
|
||||
context_window: 1000000
|
||||
supports_structured_output: true
|
||||
@@ -23,6 +23,10 @@ class Settings(BaseSettings):
|
||||
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
|
||||
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
|
||||
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
|
||||
# Optional directory of operator-supplied model YAMLs, loaded after the
|
||||
# built-in catalog under application/core/models/. Later wins on
|
||||
# duplicate model id. See application/core/models/README.md.
|
||||
MODELS_CONFIG_DIR: Optional[str] = None
|
||||
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
|
||||
@@ -1,34 +1,11 @@
|
||||
import logging
|
||||
|
||||
from application.llm.anthropic import AnthropicLLM
|
||||
from application.llm.docsgpt_provider import DocsGPTAPILLM
|
||||
from application.llm.google_ai import GoogleLLM
|
||||
from application.llm.groq import GroqLLM
|
||||
from application.llm.llama_cpp import LlamaCpp
|
||||
from application.llm.novita import NovitaLLM
|
||||
from application.llm.openai import AzureOpenAILLM, OpenAILLM
|
||||
from application.llm.premai import PremAILLM
|
||||
from application.llm.sagemaker import SagemakerAPILLM
|
||||
from application.llm.open_router import OpenRouterLLM
|
||||
from application.llm.providers import PROVIDERS_BY_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMCreator:
|
||||
llms = {
|
||||
"openai": OpenAILLM,
|
||||
"azure_openai": AzureOpenAILLM,
|
||||
"sagemaker": SagemakerAPILLM,
|
||||
"llama.cpp": LlamaCpp,
|
||||
"anthropic": AnthropicLLM,
|
||||
"docsgpt": DocsGPTAPILLM,
|
||||
"premai": PremAILLM,
|
||||
"groq": GroqLLM,
|
||||
"google": GoogleLLM,
|
||||
"novita": NovitaLLM,
|
||||
"openrouter": OpenRouterLLM,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_llm(
|
||||
cls,
|
||||
@@ -42,18 +19,27 @@ class LLMCreator:
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
from application.core.model_utils import get_base_url_for_model
|
||||
from application.core.model_registry import ModelRegistry
|
||||
|
||||
llm_class = cls.llms.get(type.lower())
|
||||
if not llm_class:
|
||||
plugin = PROVIDERS_BY_NAME.get(type.lower())
|
||||
if plugin is None or plugin.llm_class is None:
|
||||
raise ValueError(f"No LLM class found for type {type}")
|
||||
|
||||
# Extract base_url from model configuration if model_id is provided
|
||||
# Prefer per-model endpoint config from the registry. This is what
|
||||
# makes openai_compatible (and the future end-user BYOM phase)
|
||||
# work without changing every call site: if the registered
|
||||
# AvailableModel carries its own api_key / base_url, they win
|
||||
# over whatever the caller resolved via the provider plugin.
|
||||
base_url = None
|
||||
if model_id:
|
||||
base_url = get_base_url_for_model(model_id)
|
||||
model = ModelRegistry.get_instance().get_model(model_id)
|
||||
if model is not None:
|
||||
if model.api_key:
|
||||
api_key = model.api_key
|
||||
if model.base_url:
|
||||
base_url = model.base_url
|
||||
|
||||
return llm_class(
|
||||
return plugin.llm_class(
|
||||
api_key,
|
||||
user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
|
||||
@@ -389,8 +389,8 @@ class OpenAILLM(BaseLLM):
|
||||
Returns:
|
||||
list: List of supported MIME types
|
||||
"""
|
||||
from application.core.model_configs import OPENAI_ATTACHMENTS
|
||||
return OPENAI_ATTACHMENTS
|
||||
from application.core.model_yaml import resolve_attachment_alias
|
||||
return resolve_attachment_alias("image")
|
||||
|
||||
def prepare_messages_with_attachments(self, messages, attachments=None):
|
||||
"""
|
||||
|
||||
51
application/llm/providers/__init__.py
Normal file
51
application/llm/providers/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Provider plugin registry.
|
||||
|
||||
Plugins are imported eagerly so import errors surface at app boot rather
|
||||
than at first request. ``ALL_PROVIDERS`` is the canonical ordered list;
|
||||
``PROVIDERS_BY_NAME`` is a name-keyed lookup for LLMCreator and the
|
||||
model registry.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from application.llm.providers.anthropic import AnthropicProvider
|
||||
from application.llm.providers.azure_openai import AzureOpenAIProvider
|
||||
from application.llm.providers.base import Provider
|
||||
from application.llm.providers.docsgpt import DocsGPTProvider
|
||||
from application.llm.providers.google import GoogleProvider
|
||||
from application.llm.providers.groq import GroqProvider
|
||||
from application.llm.providers.huggingface import HuggingFaceProvider
|
||||
from application.llm.providers.llama_cpp import LlamaCppProvider
|
||||
from application.llm.providers.novita import NovitaProvider
|
||||
from application.llm.providers.openai import OpenAIProvider
|
||||
from application.llm.providers.openai_compatible import OpenAICompatibleProvider
|
||||
from application.llm.providers.openrouter import OpenRouterProvider
|
||||
from application.llm.providers.premai import PremAIProvider
|
||||
from application.llm.providers.sagemaker import SagemakerProvider
|
||||
|
||||
# Order here is the order the registry iterates providers (and therefore
|
||||
# the order ``/api/models`` reports them). Match the historical order
|
||||
# from the old ModelRegistry._load_models for byte-stable output during
|
||||
# the migration. ``openai_compatible`` slots in right after ``openai``
|
||||
# so legacy ``OPENAI_BASE_URL`` models keep landing in the same place.
|
||||
ALL_PROVIDERS: List[Provider] = [
|
||||
DocsGPTProvider(),
|
||||
OpenAIProvider(),
|
||||
OpenAICompatibleProvider(),
|
||||
AzureOpenAIProvider(),
|
||||
AnthropicProvider(),
|
||||
GoogleProvider(),
|
||||
GroqProvider(),
|
||||
OpenRouterProvider(),
|
||||
NovitaProvider(),
|
||||
HuggingFaceProvider(),
|
||||
LlamaCppProvider(),
|
||||
PremAIProvider(),
|
||||
SagemakerProvider(),
|
||||
]
|
||||
|
||||
PROVIDERS_BY_NAME: Dict[str, Provider] = {p.name: p for p in ALL_PROVIDERS}
|
||||
|
||||
__all__ = ["ALL_PROVIDERS", "PROVIDERS_BY_NAME", "Provider"]
|
||||
51
application/llm/providers/_apikey_or_llm_name.py
Normal file
51
application/llm/providers/_apikey_or_llm_name.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Shared helper for providers that follow the
|
||||
``<X>_API_KEY or (LLM_PROVIDER==X and API_KEY)`` pattern.
|
||||
|
||||
This is the dominant pattern across Anthropic, Google, Groq, OpenRouter,
|
||||
and Novita. Extracted here so each plugin stays a few lines long.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from application.core.model_settings import AvailableModel
|
||||
|
||||
|
||||
def get_api_key(
|
||||
settings,
|
||||
provider_name: str,
|
||||
provider_specific_key: Optional[str],
|
||||
) -> Optional[str]:
|
||||
if provider_specific_key:
|
||||
return provider_specific_key
|
||||
if settings.LLM_PROVIDER == provider_name and settings.API_KEY:
|
||||
return settings.API_KEY
|
||||
return None
|
||||
|
||||
|
||||
def filter_models_by_llm_name(
|
||||
settings,
|
||||
provider_name: str,
|
||||
provider_specific_key: Optional[str],
|
||||
models: List[AvailableModel],
|
||||
) -> List[AvailableModel]:
|
||||
"""Mirrors the historical ``_add_<X>_models`` selection logic.
|
||||
|
||||
Behavior:
|
||||
- If the provider-specific API key is set → load all models.
|
||||
- Else if ``LLM_PROVIDER`` matches and ``LLM_NAME`` matches a known
|
||||
model → load just that model.
|
||||
- Otherwise → load all models (preserved "load anyway" branch from
|
||||
the original methods).
|
||||
"""
|
||||
if provider_specific_key:
|
||||
return models
|
||||
if (
|
||||
settings.LLM_PROVIDER == provider_name
|
||||
and settings.LLM_NAME
|
||||
):
|
||||
named = [m for m in models if m.id == settings.LLM_NAME]
|
||||
if named:
|
||||
return named
|
||||
return models
|
||||
23
application/llm/providers/anthropic.py
Normal file
23
application/llm/providers/anthropic.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.anthropic import AnthropicLLM
|
||||
from application.llm.providers._apikey_or_llm_name import (
|
||||
filter_models_by_llm_name,
|
||||
get_api_key,
|
||||
)
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class AnthropicProvider(Provider):
|
||||
name = "anthropic"
|
||||
llm_class = AnthropicLLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return get_api_key(settings, self.name, settings.ANTHROPIC_API_KEY)
|
||||
|
||||
def filter_yaml_models(self, settings, models):
|
||||
return filter_models_by_llm_name(
|
||||
settings, self.name, settings.ANTHROPIC_API_KEY, models
|
||||
)
|
||||
30
application/llm/providers/azure_openai.py
Normal file
30
application/llm/providers/azure_openai.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.openai import AzureOpenAILLM
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class AzureOpenAIProvider(Provider):
|
||||
name = "azure_openai"
|
||||
llm_class = AzureOpenAILLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
# Azure historically uses the generic API_KEY field.
|
||||
return settings.API_KEY
|
||||
|
||||
def is_enabled(self, settings) -> bool:
|
||||
if settings.OPENAI_API_BASE:
|
||||
return True
|
||||
return settings.LLM_PROVIDER == self.name and bool(settings.API_KEY)
|
||||
|
||||
def filter_yaml_models(self, settings, models):
|
||||
# Mirrors _add_azure_openai_models: when LLM_PROVIDER==azure_openai
|
||||
# and LLM_NAME matches a known model, narrow to that one model.
|
||||
# Otherwise load the entire catalog.
|
||||
if settings.LLM_PROVIDER == self.name and settings.LLM_NAME:
|
||||
named = [m for m in models if m.id == settings.LLM_NAME]
|
||||
if named:
|
||||
return named
|
||||
return models
|
||||
74
application/llm/providers/base.py
Normal file
74
application/llm/providers/base.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from application.core.model_settings import AvailableModel
|
||||
from application.core.model_yaml import ProviderCatalog
|
||||
from application.core.settings import Settings
|
||||
from application.llm.base import BaseLLM
|
||||
|
||||
|
||||
class Provider(ABC):
|
||||
"""Owns the *behavior* of an LLM provider.
|
||||
|
||||
Concrete providers declare their name, the LLM class to instantiate,
|
||||
and how to resolve credentials from settings. Static model catalogs
|
||||
live in YAML under ``application/core/models/`` and are joined to the
|
||||
provider by name at registry load time.
|
||||
|
||||
Most plugins receive zero or one catalog at registry-build time. The
|
||||
``openai_compatible`` plugin is the exception: it receives one catalog
|
||||
per matching YAML file, each with its own ``api_key_env`` and
|
||||
``base_url``. Plugins that need per-catalog metadata override
|
||||
``get_models``; the default implementation merges catalogs and routes
|
||||
through ``filter_yaml_models`` + ``extra_models``.
|
||||
"""
|
||||
|
||||
name: ClassVar[str]
|
||||
# ``None`` means the provider appears in the catalog but isn't
|
||||
# dispatchable through LLMCreator (e.g. Hugging Face today, where the
|
||||
# original LLMCreator dict had no entry).
|
||||
llm_class: ClassVar[Optional[Type["BaseLLM"]]] = None
|
||||
|
||||
@abstractmethod
|
||||
def get_api_key(self, settings: "Settings") -> Optional[str]:
|
||||
"""Return the API key for this provider, or None if unavailable."""
|
||||
|
||||
def is_enabled(self, settings: "Settings") -> bool:
|
||||
"""Whether this provider should contribute models to the registry."""
|
||||
return bool(self.get_api_key(settings))
|
||||
|
||||
def filter_yaml_models(
|
||||
self, settings: "Settings", models: List["AvailableModel"]
|
||||
) -> List["AvailableModel"]:
|
||||
"""Hook to filter YAML-loaded models. Default: return all."""
|
||||
return models
|
||||
|
||||
def extra_models(self, settings: "Settings") -> List["AvailableModel"]:
|
||||
"""Hook to add dynamic models not declared in YAML. Default: none."""
|
||||
return []
|
||||
|
||||
def get_models(
|
||||
self,
|
||||
settings: "Settings",
|
||||
catalogs: List["ProviderCatalog"],
|
||||
) -> List["AvailableModel"]:
|
||||
"""Final list of models this plugin contributes.
|
||||
|
||||
Default: merge the models across all matched catalogs (later
|
||||
catalog wins on duplicate id), filter via ``filter_yaml_models``,
|
||||
then append ``extra_models``. Override when per-catalog metadata
|
||||
matters (see ``OpenAICompatibleProvider``).
|
||||
"""
|
||||
merged: List["AvailableModel"] = []
|
||||
seen: dict = {}
|
||||
for c in catalogs:
|
||||
for m in c.models:
|
||||
if m.id in seen:
|
||||
merged[seen[m.id]] = m
|
||||
else:
|
||||
seen[m.id] = len(merged)
|
||||
merged.append(m)
|
||||
return self.filter_yaml_models(settings, merged) + self.extra_models(settings)
|
||||
22
application/llm/providers/docsgpt.py
Normal file
22
application/llm/providers/docsgpt.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.docsgpt_provider import DocsGPTAPILLM
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class DocsGPTProvider(Provider):
|
||||
name = "docsgpt"
|
||||
llm_class = DocsGPTAPILLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
# No provider-specific key; the LLM class can use the generic
|
||||
# API_KEY fallback if it needs one. Mirrors model_utils' historical
|
||||
# behavior of returning settings.API_KEY when no specific key exists.
|
||||
return settings.API_KEY
|
||||
|
||||
def is_enabled(self, settings) -> bool:
|
||||
# The hosted DocsGPT model is hidden when the deployment is
|
||||
# pointed at a custom OpenAI-compatible endpoint.
|
||||
return not settings.OPENAI_BASE_URL
|
||||
23
application/llm/providers/google.py
Normal file
23
application/llm/providers/google.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.google_ai import GoogleLLM
|
||||
from application.llm.providers._apikey_or_llm_name import (
|
||||
filter_models_by_llm_name,
|
||||
get_api_key,
|
||||
)
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class GoogleProvider(Provider):
|
||||
name = "google"
|
||||
llm_class = GoogleLLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return get_api_key(settings, self.name, settings.GOOGLE_API_KEY)
|
||||
|
||||
def filter_yaml_models(self, settings, models):
|
||||
return filter_models_by_llm_name(
|
||||
settings, self.name, settings.GOOGLE_API_KEY, models
|
||||
)
|
||||
23
application/llm/providers/groq.py
Normal file
23
application/llm/providers/groq.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.groq import GroqLLM
|
||||
from application.llm.providers._apikey_or_llm_name import (
|
||||
filter_models_by_llm_name,
|
||||
get_api_key,
|
||||
)
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class GroqProvider(Provider):
|
||||
name = "groq"
|
||||
llm_class = GroqLLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return get_api_key(settings, self.name, settings.GROQ_API_KEY)
|
||||
|
||||
def filter_yaml_models(self, settings, models):
|
||||
return filter_models_by_llm_name(
|
||||
settings, self.name, settings.GROQ_API_KEY, models
|
||||
)
|
||||
25
application/llm/providers/huggingface.py
Normal file
25
application/llm/providers/huggingface.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.providers._apikey_or_llm_name import (
|
||||
get_api_key as shared_get_api_key,
|
||||
)
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class HuggingFaceProvider(Provider):
|
||||
"""Surfaces ``huggingface-local`` to the model catalog.
|
||||
|
||||
Not dispatchable through LLMCreator — historically there was no
|
||||
HuggingFaceLLM entry in ``LLMCreator.llms``, and calling ``create_llm``
|
||||
with ``"huggingface"`` raised ``ValueError``. We preserve that
|
||||
behavior: the model appears in ``/api/models`` but selecting it
|
||||
surfaces the same error it always did.
|
||||
"""
|
||||
|
||||
name = "huggingface"
|
||||
llm_class = None # not dispatchable
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return shared_get_api_key(settings, self.name, settings.HUGGINGFACE_API_KEY)
|
||||
19
application/llm/providers/llama_cpp.py
Normal file
19
application/llm/providers/llama_cpp.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.llama_cpp import LlamaCpp
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class LlamaCppProvider(Provider):
|
||||
"""LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog."""
|
||||
|
||||
name = "llama.cpp"
|
||||
llm_class = LlamaCpp
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return settings.API_KEY
|
||||
|
||||
def is_enabled(self, settings) -> bool:
|
||||
return False
|
||||
23
application/llm/providers/novita.py
Normal file
23
application/llm/providers/novita.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.novita import NovitaLLM
|
||||
from application.llm.providers._apikey_or_llm_name import (
|
||||
filter_models_by_llm_name,
|
||||
get_api_key,
|
||||
)
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class NovitaProvider(Provider):
|
||||
name = "novita"
|
||||
llm_class = NovitaLLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return get_api_key(settings, self.name, settings.NOVITA_API_KEY)
|
||||
|
||||
def filter_yaml_models(self, settings, models):
|
||||
return filter_models_by_llm_name(
|
||||
settings, self.name, settings.NOVITA_API_KEY, models
|
||||
)
|
||||
37
application/llm/providers/openai.py
Normal file
37
application/llm/providers/openai.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.openai import OpenAILLM
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class OpenAIProvider(Provider):
|
||||
name = "openai"
|
||||
llm_class = OpenAILLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
if settings.OPENAI_API_KEY:
|
||||
return settings.OPENAI_API_KEY
|
||||
if settings.LLM_PROVIDER == self.name and settings.API_KEY:
|
||||
return settings.API_KEY
|
||||
return None
|
||||
|
||||
def is_enabled(self, settings) -> bool:
|
||||
# When the deployment is pointed at a custom OpenAI-compatible
|
||||
# endpoint (Ollama, LM Studio, ...), the cloud-OpenAI catalog is
|
||||
# suppressed but ``is_enabled`` stays True — necessary so the
|
||||
# filter below still gets to drop the catalog (rather than the
|
||||
# registry skipping the provider entirely and missing the rule).
|
||||
if settings.OPENAI_BASE_URL:
|
||||
return True
|
||||
return bool(self.get_api_key(settings))
|
||||
|
||||
def filter_yaml_models(self, settings, models):
|
||||
# Legacy local-endpoint mode hides the cloud catalog. The
|
||||
# corresponding dynamic models live in OpenAICompatibleProvider.
|
||||
if settings.OPENAI_BASE_URL:
|
||||
return []
|
||||
if not settings.OPENAI_API_KEY:
|
||||
return []
|
||||
return models
|
||||
149
application/llm/providers/openai_compatible.py
Normal file
149
application/llm/providers/openai_compatible.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Generic provider for OpenAI-wire-compatible endpoints.
|
||||
|
||||
Each ``openai_compatible`` YAML file describes one logical endpoint
|
||||
(Mistral, Together, Fireworks, Ollama, ...) with its own
|
||||
``api_key_env`` and ``base_url``. Multiple files can coexist; the
|
||||
plugin produces one set of models per file, each pre-configured with
|
||||
the right credentials and URL.
|
||||
|
||||
The plugin also handles the **legacy** ``OPENAI_BASE_URL`` + ``LLM_NAME``
|
||||
local-endpoint pattern that previously lived in ``OpenAIProvider``. That
|
||||
path generates models dynamically from ``LLM_NAME``, using
|
||||
``OPENAI_BASE_URL`` and ``OPENAI_API_KEY`` as the endpoint config.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
)
|
||||
from application.llm.openai import OpenAILLM
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_model_names(llm_name: Optional[str]) -> List[str]:
|
||||
if not llm_name:
|
||||
return []
|
||||
return [name.strip() for name in llm_name.split(",") if name.strip()]
|
||||
|
||||
|
||||
class OpenAICompatibleProvider(Provider):
|
||||
name = "openai_compatible"
|
||||
llm_class = OpenAILLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
# Per-model: each catalog supplies its own ``api_key_env``. There
|
||||
# is no single plugin-wide key. LLMCreator reads the per-model
|
||||
# ``api_key`` set during catalog materialization.
|
||||
return None
|
||||
|
||||
def is_enabled(self, settings) -> bool:
|
||||
# Concrete enablement happens per catalog (in ``get_models``).
|
||||
# Returning True lets the registry call ``get_models`` so we can
|
||||
# decide per-file whether to contribute models.
|
||||
return True
|
||||
|
||||
def get_models(self, settings, catalogs) -> List[AvailableModel]:
|
||||
out: List[AvailableModel] = []
|
||||
|
||||
for catalog in catalogs:
|
||||
out.extend(self._materialize_yaml_catalog(catalog))
|
||||
|
||||
if settings.OPENAI_BASE_URL and settings.LLM_NAME:
|
||||
out.extend(self._materialize_legacy_local_endpoint(settings))
|
||||
|
||||
return out
|
||||
|
||||
def _materialize_yaml_catalog(self, catalog) -> List[AvailableModel]:
|
||||
"""Resolve one openai_compatible YAML into ready-to-dispatch models.
|
||||
|
||||
Skipped (with an INFO-level log) if ``api_key_env`` resolves to
|
||||
nothing — no point publishing models the user can't actually
|
||||
call. INFO rather than WARNING because operators may legitimately
|
||||
drop multiple provider YAMLs as templates and only set the env
|
||||
vars for the ones they actually use; a missing key is ambiguous,
|
||||
not necessarily a misconfig.
|
||||
"""
|
||||
if not catalog.base_url:
|
||||
raise ValueError(
|
||||
f"{catalog.source_path}: openai_compatible YAML must set "
|
||||
"'base_url'."
|
||||
)
|
||||
if not catalog.api_key_env:
|
||||
raise ValueError(
|
||||
f"{catalog.source_path}: openai_compatible YAML must set "
|
||||
"'api_key_env'."
|
||||
)
|
||||
|
||||
api_key = os.environ.get(catalog.api_key_env)
|
||||
if not api_key:
|
||||
logger.info(
|
||||
"openai_compatible catalog %s skipped: env var %s is not set",
|
||||
catalog.source_path,
|
||||
catalog.api_key_env,
|
||||
)
|
||||
return []
|
||||
|
||||
out: List[AvailableModel] = []
|
||||
for m in catalog.models:
|
||||
out.append(self._with_endpoint(m, catalog.base_url, api_key))
|
||||
return out
|
||||
|
||||
def _materialize_legacy_local_endpoint(self, settings) -> List[AvailableModel]:
|
||||
"""Generate AvailableModels from ``LLM_NAME`` for the legacy
|
||||
``OPENAI_BASE_URL`` deployment pattern (Ollama, LM Studio, ...).
|
||||
|
||||
Preserves the historical ``provider="openai"`` display behavior
|
||||
by setting ``display_provider="openai"``.
|
||||
"""
|
||||
from application.core.model_yaml import resolve_attachment_alias
|
||||
|
||||
attachments = resolve_attachment_alias("image")
|
||||
api_key = settings.OPENAI_API_KEY or settings.API_KEY
|
||||
out: List[AvailableModel] = []
|
||||
for model_name in _parse_model_names(settings.LLM_NAME):
|
||||
out.append(
|
||||
AvailableModel(
|
||||
id=model_name,
|
||||
provider=ModelProvider.OPENAI_COMPATIBLE,
|
||||
display_name=model_name,
|
||||
description=f"Custom OpenAI-compatible model at {settings.OPENAI_BASE_URL}",
|
||||
base_url=settings.OPENAI_BASE_URL,
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=attachments,
|
||||
),
|
||||
api_key=api_key,
|
||||
display_provider="openai",
|
||||
)
|
||||
)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _with_endpoint(
|
||||
model: AvailableModel, base_url: str, api_key: str
|
||||
) -> AvailableModel:
|
||||
"""Return a copy of ``model`` carrying the catalog's endpoint config.
|
||||
|
||||
The catalog-level ``base_url`` is the default; an explicit
|
||||
per-model ``base_url`` in the YAML wins.
|
||||
"""
|
||||
return AvailableModel(
|
||||
id=model.id,
|
||||
provider=model.provider,
|
||||
display_name=model.display_name,
|
||||
description=model.description,
|
||||
capabilities=model.capabilities,
|
||||
enabled=model.enabled,
|
||||
base_url=model.base_url or base_url,
|
||||
display_provider=model.display_provider,
|
||||
api_key=api_key,
|
||||
)
|
||||
23
application/llm/providers/openrouter.py
Normal file
23
application/llm/providers/openrouter.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.open_router import OpenRouterLLM
|
||||
from application.llm.providers._apikey_or_llm_name import (
|
||||
filter_models_by_llm_name,
|
||||
get_api_key,
|
||||
)
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class OpenRouterProvider(Provider):
|
||||
name = "openrouter"
|
||||
llm_class = OpenRouterLLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return get_api_key(settings, self.name, settings.OPEN_ROUTER_API_KEY)
|
||||
|
||||
def filter_yaml_models(self, settings, models):
|
||||
return filter_models_by_llm_name(
|
||||
settings, self.name, settings.OPEN_ROUTER_API_KEY, models
|
||||
)
|
||||
19
application/llm/providers/premai.py
Normal file
19
application/llm/providers/premai.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.premai import PremAILLM
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class PremAIProvider(Provider):
|
||||
"""LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog."""
|
||||
|
||||
name = "premai"
|
||||
llm_class = PremAILLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return settings.API_KEY
|
||||
|
||||
def is_enabled(self, settings) -> bool:
|
||||
return False
|
||||
24
application/llm/providers/sagemaker.py
Normal file
24
application/llm/providers/sagemaker.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.sagemaker import SagemakerAPILLM
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class SagemakerProvider(Provider):
|
||||
"""LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog.
|
||||
|
||||
SageMaker reads its credentials from ``SAGEMAKER_*`` settings inside
|
||||
the LLM class itself; this plugin's ``get_api_key`` exists only for
|
||||
LLMCreator's symmetry.
|
||||
"""
|
||||
|
||||
name = "sagemaker"
|
||||
llm_class = SagemakerAPILLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return settings.API_KEY
|
||||
|
||||
def is_enabled(self, settings) -> bool:
|
||||
return False
|
||||
@@ -82,6 +82,7 @@ python-dateutil==2.9.0.post0
|
||||
python-dotenv
|
||||
python-jose==3.5.0
|
||||
python-pptx==1.0.2
|
||||
PyYAML
|
||||
redis==7.4.0
|
||||
referencing>=0.28.0,<0.38.0
|
||||
regex==2026.4.4
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Repository for the ``agents`` table.
|
||||
|
||||
This is the most complex Phase 2 repository. Covers every write operation
|
||||
the legacy Mongo code performs on ``agents_collection``:
|
||||
Covers every write operation the legacy Mongo code performs on ``agents_collection``:
|
||||
|
||||
- create, update, delete
|
||||
- find by key (API key lookup)
|
||||
|
||||
@@ -348,6 +348,16 @@ def run_agent_logic(agent_config, input_data):
|
||||
model_id = agent_default_model
|
||||
else:
|
||||
model_id = get_default_model_id()
|
||||
if agent_default_model:
|
||||
# Stored model_id no longer resolves in the registry. Log so
|
||||
# operators can detect bad YAML edits before users complain;
|
||||
# behavior matches the historical silent fallback.
|
||||
logging.warning(
|
||||
"Agent %s references unknown model_id %r; falling back to %r",
|
||||
agent_id,
|
||||
agent_default_model,
|
||||
model_id,
|
||||
)
|
||||
|
||||
# Get provider and API key for the selected model
|
||||
provider = get_provider_from_model_id(model_id) if model_id else settings.LLM_PROVIDER
|
||||
|
||||
@@ -99,6 +99,82 @@ EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2 # You can al
|
||||
|
||||
In this case, even though you are using Ollama locally, `LLM_PROVIDER` is set to `openai` because Ollama (and many other local inference engines) are designed to be API-compatible with OpenAI. `OPENAI_BASE_URL` points DocsGPT to the local Ollama server.
|
||||
|
||||
## Adding Custom Models (`MODELS_CONFIG_DIR`)
|
||||
|
||||
DocsGPT ships with a built-in catalog of models for the providers it
|
||||
supports out of the box (OpenAI, Anthropic, Google, Groq, OpenRouter,
|
||||
Novita, Azure OpenAI, Hugging Face, DocsGPT). To add **your own
|
||||
models** without forking the repo — for example, a Mistral or Together
|
||||
account, a self-hosted vLLM endpoint, or any other OpenAI-compatible
|
||||
API — point `MODELS_CONFIG_DIR` at a directory of YAML files.
|
||||
|
||||
```
|
||||
MODELS_CONFIG_DIR=/etc/docsgpt/models
|
||||
MISTRAL_API_KEY=sk-...
|
||||
```
|
||||
|
||||
A minimal YAML for one provider:
|
||||
|
||||
```yaml
|
||||
# /etc/docsgpt/models/mistral.yaml
|
||||
provider: openai_compatible
|
||||
display_provider: mistral
|
||||
api_key_env: MISTRAL_API_KEY
|
||||
base_url: https://api.mistral.ai/v1
|
||||
defaults:
|
||||
supports_tools: true
|
||||
context_window: 128000
|
||||
models:
|
||||
- id: mistral-large-latest
|
||||
display_name: Mistral Large
|
||||
- id: mistral-small-latest
|
||||
display_name: Mistral Small
|
||||
```
|
||||
|
||||
After restart, those models appear in `/api/models` and are selectable
|
||||
in the UI. A working template lives at
|
||||
`application/core/models/examples/mistral.yaml.example`.
|
||||
|
||||
**What you can do:**
|
||||
|
||||
- Add new `openai_compatible` providers (Mistral, Together, Fireworks,
|
||||
Ollama, vLLM, ...) — one YAML per provider, each with its own
|
||||
`api_key_env` and `base_url`.
|
||||
- Extend an existing provider's catalog by dropping a YAML with the
|
||||
same `provider:` value as the built-in (e.g. `provider: anthropic`
|
||||
with extra models).
|
||||
- Override a built-in model's capabilities by re-declaring the same
|
||||
`id` — later wins, override is logged at `WARNING`.
|
||||
|
||||
**What you cannot do via `MODELS_CONFIG_DIR`:** add a brand-new
|
||||
non-OpenAI provider. That requires a Python plugin under
|
||||
`application/llm/providers/`. See
|
||||
`application/core/models/README.md` for the full schema reference.
|
||||
|
||||
### Docker
|
||||
|
||||
Mount the directory and set the env var:
|
||||
|
||||
```yaml
|
||||
# docker-compose.yml
|
||||
services:
|
||||
app:
|
||||
image: arc53/docsgpt
|
||||
environment:
|
||||
MODELS_CONFIG_DIR: /etc/docsgpt/models
|
||||
MISTRAL_API_KEY: ${MISTRAL_API_KEY}
|
||||
volumes:
|
||||
- ./my-models:/etc/docsgpt/models:ro
|
||||
```
|
||||
|
||||
### Misconfiguration
|
||||
|
||||
If `MODELS_CONFIG_DIR` is set but the path doesn't exist (or isn't a
|
||||
directory), the app logs a `WARNING` at boot and continues with just
|
||||
the built-in catalog — it does **not** fail to start. If a YAML
|
||||
declares an unknown provider name or has a schema error, the app
|
||||
**does** fail to start, with the offending file path in the message.
|
||||
|
||||
## Speech-to-Text Settings
|
||||
|
||||
DocsGPT can transcribe audio in two places:
|
||||
|
||||
@@ -200,7 +200,7 @@ class TestSetupPeriodicTasks:
|
||||
|
||||
setup_periodic_tasks(sender)
|
||||
|
||||
assert sender.add_periodic_task.call_count == 4
|
||||
assert sender.add_periodic_task.call_count == 5
|
||||
|
||||
calls = sender.add_periodic_task.call_args_list
|
||||
|
||||
@@ -212,6 +212,8 @@ class TestSetupPeriodicTasks:
|
||||
assert calls[2][0][0] == timedelta(days=30)
|
||||
# pending_tool_state TTL cleanup (60s)
|
||||
assert calls[3][0][0] == timedelta(seconds=60)
|
||||
# version-check (every 7h)
|
||||
assert calls[4][0][0] == timedelta(hours=7)
|
||||
|
||||
|
||||
class TestMcpOauthTask:
|
||||
|
||||
306
tests/core/test_model_registry_yaml.py
Normal file
306
tests/core/test_model_registry_yaml.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""Phase 1 regression tests for the YAML-driven ModelRegistry.
|
||||
|
||||
These tests encode the contract that persisted agent / workflow /
|
||||
conversation references depend on: every model id and core capability
|
||||
that existed in the old ``model_configs.py`` lists must continue to be
|
||||
produced by the new YAML-backed registry.
|
||||
|
||||
If a future YAML edit accidentally renames an id or changes a
|
||||
capability, these tests fail at CI before merge — protecting agents and
|
||||
workflows from silent fallback to the system default.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.core.model_registry import ModelRegistry
|
||||
from application.core.model_yaml import (
|
||||
BUILTIN_MODELS_DIR,
|
||||
load_model_yamls,
|
||||
)
|
||||
|
||||
|
||||
# ── Per-provider expected IDs ─────────────────────────────────────────────
|
||||
# Snapshot of the current built-in catalog. If you intentionally change
|
||||
# what models a provider's YAML lists, update this constant in the same
|
||||
# commit. The test exists to catch *unintentional* renames (e.g. a typo
|
||||
# in an upstream model id) that would silently break every agent that
|
||||
# references the old id.
|
||||
EXPECTED_IDS = {
|
||||
"openai": {"gpt-5.5", "gpt-5.4-mini", "gpt-5.4-nano"},
|
||||
"anthropic": {
|
||||
"claude-opus-4-7",
|
||||
"claude-sonnet-4-6",
|
||||
"claude-haiku-4-5",
|
||||
},
|
||||
"google": {
|
||||
"gemini-3.1-pro-preview",
|
||||
"gemini-3-flash-preview",
|
||||
"gemini-3.1-flash-lite-preview",
|
||||
},
|
||||
"groq": {
|
||||
"openai/gpt-oss-120b",
|
||||
"llama-3.3-70b-versatile",
|
||||
"llama-3.1-8b-instant",
|
||||
},
|
||||
"openrouter": {
|
||||
"qwen/qwen3-coder:free",
|
||||
"deepseek/deepseek-v3.2",
|
||||
"anthropic/claude-sonnet-4.6",
|
||||
},
|
||||
"novita": {
|
||||
"deepseek/deepseek-v4-pro",
|
||||
"moonshotai/kimi-k2.6",
|
||||
"zai-org/glm-5",
|
||||
},
|
||||
"azure_openai": {
|
||||
"azure-gpt-5.5",
|
||||
"azure-gpt-5.4-mini",
|
||||
"azure-gpt-5.4-nano",
|
||||
},
|
||||
"docsgpt": {"docsgpt-local"},
|
||||
"huggingface": {"huggingface-local"},
|
||||
}
|
||||
|
||||
|
||||
def _make_settings(**overrides):
|
||||
s = MagicMock()
|
||||
# All credential / mode flags off by default so each test opts in.
|
||||
s.OPENAI_BASE_URL = None
|
||||
s.OPENAI_API_KEY = None
|
||||
s.OPENAI_API_BASE = None
|
||||
s.ANTHROPIC_API_KEY = None
|
||||
s.GOOGLE_API_KEY = None
|
||||
s.GROQ_API_KEY = None
|
||||
s.OPEN_ROUTER_API_KEY = None
|
||||
s.NOVITA_API_KEY = None
|
||||
s.HUGGINGFACE_API_KEY = None
|
||||
s.LLM_PROVIDER = ""
|
||||
s.LLM_NAME = None
|
||||
s.API_KEY = None
|
||||
s.MODELS_CONFIG_DIR = None
|
||||
for k, v in overrides.items():
|
||||
setattr(s, k, v)
|
||||
return s
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_registry():
|
||||
ModelRegistry.reset()
|
||||
yield
|
||||
ModelRegistry.reset()
|
||||
|
||||
|
||||
# ── YAML schema / loader ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _by_provider(catalogs):
|
||||
"""Group a list of catalogs by provider name. Mirrors the registry's
|
||||
own grouping; useful for asserting per-provider model sets in tests."""
|
||||
out = {}
|
||||
for c in catalogs:
|
||||
out.setdefault(c.provider, []).append(c)
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestYAMLLoader:
|
||||
def test_loader_produces_expected_provider_set(self):
|
||||
catalogs = load_model_yamls([BUILTIN_MODELS_DIR])
|
||||
providers = {c.provider for c in catalogs}
|
||||
assert providers == set(EXPECTED_IDS.keys())
|
||||
|
||||
def test_each_provider_has_expected_ids(self):
|
||||
grouped = _by_provider(load_model_yamls([BUILTIN_MODELS_DIR]))
|
||||
for provider, expected in EXPECTED_IDS.items():
|
||||
actual = {m.id for c in grouped[provider] for m in c.models}
|
||||
assert actual == expected, f"{provider}: expected {expected}, got {actual}"
|
||||
|
||||
def test_attachment_alias_image_expands_to_five_mime_types(self):
|
||||
grouped = _by_provider(load_model_yamls([BUILTIN_MODELS_DIR]))
|
||||
# OpenAI uses `attachments: [image]` in its defaults block.
|
||||
for c in grouped["openai"]:
|
||||
for m in c.models:
|
||||
assert "image/png" in m.capabilities.supported_attachment_types
|
||||
assert "image/jpeg" in m.capabilities.supported_attachment_types
|
||||
assert "image/webp" in m.capabilities.supported_attachment_types
|
||||
assert len(m.capabilities.supported_attachment_types) == 5
|
||||
|
||||
def test_attachment_alias_pdf_plus_image_for_google(self):
|
||||
grouped = _by_provider(load_model_yamls([BUILTIN_MODELS_DIR]))
|
||||
for c in grouped["google"]:
|
||||
for m in c.models:
|
||||
assert "application/pdf" in m.capabilities.supported_attachment_types
|
||||
assert "image/png" in m.capabilities.supported_attachment_types
|
||||
assert len(m.capabilities.supported_attachment_types) == 6
|
||||
|
||||
def test_per_model_context_window_overrides_provider_default(self):
|
||||
grouped = _by_provider(load_model_yamls([BUILTIN_MODELS_DIR]))
|
||||
openai = {m.id: m for c in grouped["openai"] for m in c.models}
|
||||
# Provider default is 400_000; gpt-5.5 overrides to 1_050_000.
|
||||
assert openai["gpt-5.4-mini"].capabilities.context_window == 400_000
|
||||
assert openai["gpt-5.5"].capabilities.context_window == 1_050_000
|
||||
|
||||
|
||||
# ── Registry × settings: every documented .env permutation ───────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRegistryPermutations:
|
||||
def test_openai_only(self):
|
||||
s = _make_settings(OPENAI_API_KEY="sk-test", LLM_PROVIDER="openai")
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
ids = {m.id for m in reg.get_all_models()}
|
||||
assert ids == EXPECTED_IDS["openai"] | EXPECTED_IDS["docsgpt"]
|
||||
|
||||
def test_openai_base_url_replaces_catalog_with_dynamic(self):
|
||||
s = _make_settings(
|
||||
OPENAI_BASE_URL="http://localhost:11434/v1",
|
||||
OPENAI_API_KEY="sk-test",
|
||||
LLM_PROVIDER="openai",
|
||||
LLM_NAME="llama3,gemma",
|
||||
)
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
ids = {m.id for m in reg.get_all_models()}
|
||||
# Custom local endpoint suppresses both the openai catalog AND
|
||||
# the docsgpt model (matching legacy behavior).
|
||||
assert ids == {"llama3", "gemma"}
|
||||
|
||||
def test_anthropic_only(self):
|
||||
s = _make_settings(ANTHROPIC_API_KEY="sk-ant")
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
ids = {m.id for m in reg.get_all_models()}
|
||||
assert ids == EXPECTED_IDS["anthropic"] | EXPECTED_IDS["docsgpt"]
|
||||
|
||||
def test_anthropic_via_llm_provider_with_llm_name(self):
|
||||
# Mirrors the historical _add_anthropic_models filter: when only
|
||||
# API_KEY (not ANTHROPIC_API_KEY) is set and LLM_NAME matches a
|
||||
# known model, only that model is loaded.
|
||||
s = _make_settings(
|
||||
LLM_PROVIDER="anthropic", API_KEY="key", LLM_NAME="claude-haiku-4-5"
|
||||
)
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
anthropic_ids = {
|
||||
m.id for m in reg.get_all_models() if m.provider.value == "anthropic"
|
||||
}
|
||||
assert anthropic_ids == {"claude-haiku-4-5"}
|
||||
|
||||
def test_google_only(self):
|
||||
s = _make_settings(GOOGLE_API_KEY="g-test")
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
ids = {m.id for m in reg.get_all_models()}
|
||||
assert ids == EXPECTED_IDS["google"] | EXPECTED_IDS["docsgpt"]
|
||||
|
||||
def test_groq_only(self):
|
||||
s = _make_settings(GROQ_API_KEY="g-test")
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
ids = {m.id for m in reg.get_all_models()}
|
||||
assert ids == EXPECTED_IDS["groq"] | EXPECTED_IDS["docsgpt"]
|
||||
|
||||
def test_openrouter_only(self):
|
||||
s = _make_settings(OPEN_ROUTER_API_KEY="or-test")
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
ids = {m.id for m in reg.get_all_models()}
|
||||
assert ids == EXPECTED_IDS["openrouter"] | EXPECTED_IDS["docsgpt"]
|
||||
|
||||
def test_novita_only(self):
|
||||
s = _make_settings(NOVITA_API_KEY="n-test")
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
ids = {m.id for m in reg.get_all_models()}
|
||||
assert ids == EXPECTED_IDS["novita"] | EXPECTED_IDS["docsgpt"]
|
||||
|
||||
def test_huggingface_only(self):
|
||||
s = _make_settings(HUGGINGFACE_API_KEY="hf-test")
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
ids = {m.id for m in reg.get_all_models()}
|
||||
assert ids == EXPECTED_IDS["huggingface"] | EXPECTED_IDS["docsgpt"]
|
||||
|
||||
def test_no_credentials_only_docsgpt(self):
|
||||
s = _make_settings()
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
ids = {m.id for m in reg.get_all_models()}
|
||||
assert ids == EXPECTED_IDS["docsgpt"]
|
||||
|
||||
def test_azure_via_provider(self):
|
||||
s = _make_settings(LLM_PROVIDER="azure_openai", API_KEY="key")
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
ids = {m.id for m in reg.get_all_models()}
|
||||
assert "azure-gpt-5.5" in ids
|
||||
|
||||
def test_azure_via_api_base(self):
|
||||
s = _make_settings(OPENAI_API_BASE="https://x.openai.azure.com")
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
ids = {m.id for m in reg.get_all_models()}
|
||||
assert "azure-gpt-5.5" in ids
|
||||
|
||||
def test_everything_set(self):
|
||||
s = _make_settings(
|
||||
OPENAI_API_KEY="x",
|
||||
ANTHROPIC_API_KEY="x",
|
||||
GOOGLE_API_KEY="x",
|
||||
GROQ_API_KEY="x",
|
||||
OPEN_ROUTER_API_KEY="x",
|
||||
NOVITA_API_KEY="x",
|
||||
HUGGINGFACE_API_KEY="x",
|
||||
OPENAI_API_BASE="x",
|
||||
)
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
ids = {m.id for m in reg.get_all_models()}
|
||||
all_expected = set()
|
||||
for v in EXPECTED_IDS.values():
|
||||
all_expected |= v
|
||||
assert ids == all_expected
|
||||
|
||||
|
||||
# ── Default model resolution ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDefaultModelResolution:
|
||||
def test_llm_name_picks_default(self):
|
||||
s = _make_settings(
|
||||
ANTHROPIC_API_KEY="sk-ant", LLM_NAME="claude-opus-4-7"
|
||||
)
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
assert reg.default_model_id == "claude-opus-4-7"
|
||||
|
||||
def test_falls_back_to_first_model_when_no_match(self):
|
||||
s = _make_settings()
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
assert reg.default_model_id is not None
|
||||
assert reg.default_model_id in reg.models
|
||||
|
||||
|
||||
# ── Forward-compat: user_id parameter is accepted everywhere ─────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestUserIdForwardCompat:
|
||||
def test_lookup_methods_accept_user_id(self):
|
||||
s = _make_settings(OPENAI_API_KEY="sk-test")
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
# All lookup methods must accept user_id (currently ignored,
|
||||
# reserved for end-user BYOM).
|
||||
assert reg.get_model("gpt-5.5", user_id="alice") is not None
|
||||
assert len(reg.get_all_models(user_id="alice")) > 0
|
||||
assert len(reg.get_enabled_models(user_id="alice")) > 0
|
||||
assert reg.model_exists("gpt-5.5", user_id="alice") is True
|
||||
@@ -1,6 +1,17 @@
|
||||
"""Tests for application/core/model_settings.py"""
|
||||
"""Tests for application/core/model_settings.py.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
The provider-specific load logic that used to live in private
|
||||
``_add_<X>_models`` methods now lives in plugin classes under
|
||||
``application/llm/providers/`` and YAML catalogs under
|
||||
``application/core/models/``. End-to-end coverage of the registry +
|
||||
plugin pipeline is in ``tests/core/test_model_registry_yaml.py``.
|
||||
|
||||
This file covers the data classes (``AvailableModel``,
|
||||
``ModelCapabilities``, ``ModelProvider``) and the singleton/lookup
|
||||
contract on ``ModelRegistry``.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -13,7 +24,6 @@ from application.core.model_settings import (
|
||||
|
||||
|
||||
class TestModelProvider:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_all_providers_exist(self):
|
||||
assert ModelProvider.OPENAI == "openai"
|
||||
@@ -31,7 +41,6 @@ class TestModelProvider:
|
||||
|
||||
|
||||
class TestModelCapabilities:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
caps = ModelCapabilities()
|
||||
@@ -56,7 +65,6 @@ class TestModelCapabilities:
|
||||
|
||||
|
||||
class TestAvailableModel:
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_basic(self):
|
||||
model = AvailableModel(
|
||||
@@ -78,35 +86,67 @@ class TestAvailableModel:
|
||||
id="local-model",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="Local",
|
||||
base_url="http://localhost:11434",
|
||||
base_url="http://localhost:11434/v1",
|
||||
)
|
||||
d = model.to_dict()
|
||||
assert d["base_url"] == "http://localhost:11434"
|
||||
assert d["base_url"] == "http://localhost:11434/v1"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_includes_capabilities(self):
|
||||
caps = ModelCapabilities(supports_tools=True, context_window=64000)
|
||||
caps = ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
context_window=200000,
|
||||
supported_attachment_types=["image/png"],
|
||||
)
|
||||
model = AvailableModel(
|
||||
id="m1",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="M1",
|
||||
id="m",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="M",
|
||||
capabilities=caps,
|
||||
)
|
||||
d = model.to_dict()
|
||||
assert d["supports_tools"] is True
|
||||
assert d["context_window"] == 64000
|
||||
assert d["supports_structured_output"] is True
|
||||
assert d["context_window"] == 200000
|
||||
assert d["supported_attachment_types"] == ["image/png"]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_disabled_model(self):
|
||||
model = AvailableModel(
|
||||
id="disabled",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="Disabled",
|
||||
enabled=False,
|
||||
)
|
||||
d = model.to_dict()
|
||||
assert d["enabled"] is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_api_key_field_never_serialized(self):
|
||||
"""Forward-compat hook: AvailableModel.api_key (reserved for the
|
||||
future end-user BYOM phase) must never leak into the wire format."""
|
||||
model = AvailableModel(
|
||||
id="byom",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="BYOM",
|
||||
api_key="secret-key-do-not-leak",
|
||||
)
|
||||
d = model.to_dict()
|
||||
assert "api_key" not in d
|
||||
for v in d.values():
|
||||
assert v != "secret-key-do-not-leak"
|
||||
|
||||
|
||||
class TestModelRegistry:
|
||||
class TestModelRegistryPublicAPI:
|
||||
"""Covers the public lookup contract. Loading behavior is exercised
|
||||
end-to-end in tests/core/test_model_registry_yaml.py."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_singleton(self):
|
||||
"""Reset singleton between tests."""
|
||||
ModelRegistry._instance = None
|
||||
ModelRegistry._initialized = False
|
||||
ModelRegistry.reset()
|
||||
yield
|
||||
ModelRegistry._instance = None
|
||||
ModelRegistry._initialized = False
|
||||
ModelRegistry.reset()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_singleton(self):
|
||||
@@ -125,7 +165,9 @@ class TestModelRegistry:
|
||||
def test_get_model(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
model = AvailableModel(id="test", provider=ModelProvider.OPENAI, display_name="Test")
|
||||
model = AvailableModel(
|
||||
id="test", provider=ModelProvider.OPENAI, display_name="Test"
|
||||
)
|
||||
reg.models["test"] = model
|
||||
assert reg.get_model("test") is model
|
||||
assert reg.get_model("nonexistent") is None
|
||||
@@ -134,16 +176,30 @@ class TestModelRegistry:
|
||||
def test_get_all_models(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
|
||||
reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.ANTHROPIC, display_name="M2")
|
||||
reg.models["m1"] = AvailableModel(
|
||||
id="m1", provider=ModelProvider.OPENAI, display_name="M1"
|
||||
)
|
||||
reg.models["m2"] = AvailableModel(
|
||||
id="m2", provider=ModelProvider.ANTHROPIC, display_name="M2"
|
||||
)
|
||||
assert len(reg.get_all_models()) == 2
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_enabled_models(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1", enabled=True)
|
||||
reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.OPENAI, display_name="M2", enabled=False)
|
||||
reg.models["m1"] = AvailableModel(
|
||||
id="m1",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="M1",
|
||||
enabled=True,
|
||||
)
|
||||
reg.models["m2"] = AvailableModel(
|
||||
id="m2",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="M2",
|
||||
enabled=False,
|
||||
)
|
||||
enabled = reg.get_enabled_models()
|
||||
assert len(enabled) == 1
|
||||
assert enabled[0].id == "m1"
|
||||
@@ -152,652 +208,29 @@ class TestModelRegistry:
|
||||
def test_model_exists(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
|
||||
reg.models["m1"] = AvailableModel(
|
||||
id="m1", provider=ModelProvider.OPENAI, display_name="M1"
|
||||
)
|
||||
assert reg.model_exists("m1") is True
|
||||
assert reg.model_exists("m2") is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_parse_model_names(self):
|
||||
def test_lookups_accept_user_id_kwarg(self):
|
||||
"""Reserved for the future end-user BYOM phase. Currently ignored."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
assert reg._parse_model_names("model1,model2") == ["model1", "model2"]
|
||||
assert reg._parse_model_names("model1 , model2 ") == ["model1", "model2"]
|
||||
assert reg._parse_model_names("single") == ["single"]
|
||||
assert reg._parse_model_names("") == []
|
||||
assert reg._parse_model_names(None) == []
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_docsgpt_models(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
reg._add_docsgpt_models(mock_settings)
|
||||
assert "docsgpt-local" in reg.models
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_huggingface_models(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
reg._add_huggingface_models(mock_settings)
|
||||
assert "huggingface-local" in reg.models
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_models_with_openai_key(self):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = None
|
||||
mock_settings.OPENAI_API_KEY = "sk-test"
|
||||
mock_settings.OPENAI_API_BASE = None
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.HUGGINGFACE_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "openai"
|
||||
mock_settings.LLM_NAME = ""
|
||||
mock_settings.API_KEY = None
|
||||
|
||||
with patch("application.core.settings.settings", mock_settings):
|
||||
reg = ModelRegistry()
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_models_custom_openai_base_url(self):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
|
||||
mock_settings.OPENAI_API_KEY = "sk-test"
|
||||
mock_settings.OPENAI_API_BASE = None
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.HUGGINGFACE_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "openai"
|
||||
mock_settings.LLM_NAME = "llama3,gemma"
|
||||
mock_settings.API_KEY = None
|
||||
|
||||
with patch("application.core.settings.settings", mock_settings):
|
||||
reg = ModelRegistry()
|
||||
assert "llama3" in reg.models
|
||||
assert "gemma" in reg.models
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_model_selection_from_llm_name(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {"gpt-4": AvailableModel(id="gpt-4", provider=ModelProvider.OPENAI, display_name="GPT-4")}
|
||||
reg.default_model_id = "gpt-4"
|
||||
assert reg.default_model_id == "gpt-4"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_anthropic_models_with_key(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.ANTHROPIC_API_KEY = "sk-ant-test"
|
||||
mock_settings.LLM_PROVIDER = ""
|
||||
mock_settings.LLM_NAME = ""
|
||||
reg._add_anthropic_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_google_models_with_key(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.GOOGLE_API_KEY = "google-test"
|
||||
mock_settings.LLM_PROVIDER = ""
|
||||
mock_settings.LLM_NAME = ""
|
||||
reg._add_google_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_groq_models_with_key(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.GROQ_API_KEY = "groq-test"
|
||||
mock_settings.LLM_PROVIDER = ""
|
||||
mock_settings.LLM_NAME = ""
|
||||
reg._add_groq_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_openrouter_models_with_key(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPEN_ROUTER_API_KEY = "or-test"
|
||||
mock_settings.LLM_PROVIDER = ""
|
||||
mock_settings.LLM_NAME = ""
|
||||
reg._add_openrouter_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_novita_models_with_key(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.NOVITA_API_KEY = "novita-test"
|
||||
mock_settings.LLM_PROVIDER = ""
|
||||
mock_settings.LLM_NAME = ""
|
||||
reg._add_novita_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_azure_openai_models_specific(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.LLM_PROVIDER = "azure_openai"
|
||||
mock_settings.LLM_NAME = "nonexistent-model"
|
||||
reg._add_azure_openai_models(mock_settings)
|
||||
# Falls through to adding all azure models
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_anthropic_models_no_key_with_provider(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "anthropic"
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_anthropic_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_model_fallback_to_first(self):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = None
|
||||
mock_settings.OPENAI_API_KEY = None
|
||||
mock_settings.OPENAI_API_BASE = None
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.HUGGINGFACE_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = ""
|
||||
mock_settings.LLM_NAME = ""
|
||||
mock_settings.API_KEY = None
|
||||
|
||||
with patch("application.core.settings.settings", mock_settings):
|
||||
reg = ModelRegistry()
|
||||
# Should have at least docsgpt-local
|
||||
assert reg.default_model_id is not None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_model_from_provider_fallback(self):
|
||||
"""When LLM_NAME is not set but LLM_PROVIDER and API_KEY are,
|
||||
default should be first model of that provider."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = None
|
||||
mock_settings.OPENAI_API_KEY = "sk-test"
|
||||
mock_settings.OPENAI_API_BASE = None
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.HUGGINGFACE_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "openai"
|
||||
mock_settings.LLM_NAME = None
|
||||
mock_settings.API_KEY = "sk-test"
|
||||
|
||||
with patch("application.core.settings.settings", mock_settings):
|
||||
reg = ModelRegistry()
|
||||
assert reg.default_model_id is not None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_google_models_no_key_with_provider(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "google"
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_google_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_groq_models_no_key_with_provider(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "groq"
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_groq_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_openrouter_models_no_key_with_provider(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "openrouter"
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_openrouter_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_novita_models_no_key_with_provider(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "novita"
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_novita_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_disabled_model(self):
|
||||
model = AvailableModel(
|
||||
id="disabled",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="Disabled",
|
||||
enabled=False,
|
||||
)
|
||||
d = model.to_dict()
|
||||
assert d["enabled"] is False
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_dict_with_attachment_types(self):
|
||||
caps = ModelCapabilities(
|
||||
supported_attachment_types=["image/png", "application/pdf"],
|
||||
)
|
||||
model = AvailableModel(
|
||||
id="vision",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="Vision",
|
||||
capabilities=caps,
|
||||
)
|
||||
d = model.to_dict()
|
||||
assert d["supported_attachment_types"] == ["image/png", "application/pdf"]
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Coverage for _add_* methods with matching LLM_NAME
|
||||
# Lines: 100, 105, 147, 171, 179, 186, 199-201, 204, 210, 213,
|
||||
# 218, 229, 233, 241, 250
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_azure_openai_models_with_matching_name(self):
|
||||
"""Cover line 186: azure model matching LLM_NAME returns early."""
|
||||
from application.core.model_configs import AZURE_OPENAI_MODELS
|
||||
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.LLM_PROVIDER = "azure_openai"
|
||||
if AZURE_OPENAI_MODELS:
|
||||
mock_settings.LLM_NAME = AZURE_OPENAI_MODELS[0].id
|
||||
else:
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_azure_openai_models(mock_settings)
|
||||
# Should have added at least one model
|
||||
assert len(reg.models) >= 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_anthropic_no_key_no_provider_fallthrough(self):
|
||||
"""Cover lines 199-204: no key, provider set but name not found -> add all."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "anthropic"
|
||||
mock_settings.LLM_NAME = "nonexistent-model"
|
||||
reg._add_anthropic_models(mock_settings)
|
||||
# Falls through to add all anthropic models
|
||||
assert len(reg.models) > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_google_no_key_matching_name(self):
|
||||
"""Cover lines 213-218: Google fallback with matching name."""
|
||||
from application.core.model_configs import GOOGLE_MODELS
|
||||
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "google"
|
||||
if GOOGLE_MODELS:
|
||||
mock_settings.LLM_NAME = GOOGLE_MODELS[0].id
|
||||
else:
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_google_models(mock_settings)
|
||||
assert len(reg.models) >= 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_groq_no_key_matching_name(self):
|
||||
"""Cover lines 229-233: Groq fallback with matching name."""
|
||||
from application.core.model_configs import GROQ_MODELS
|
||||
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "groq"
|
||||
if GROQ_MODELS:
|
||||
mock_settings.LLM_NAME = GROQ_MODELS[0].id
|
||||
else:
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_groq_models(mock_settings)
|
||||
assert len(reg.models) >= 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_openrouter_no_key_matching_name(self):
|
||||
"""Cover lines 241-250: OpenRouter fallback with matching name."""
|
||||
from application.core.model_configs import OPENROUTER_MODELS
|
||||
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "openrouter"
|
||||
if OPENROUTER_MODELS:
|
||||
mock_settings.LLM_NAME = OPENROUTER_MODELS[0].id
|
||||
else:
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_openrouter_models(mock_settings)
|
||||
assert len(reg.models) >= 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_novita_no_key_matching_name(self):
|
||||
"""Cover novita fallback with matching name."""
|
||||
from application.core.model_configs import NOVITA_MODELS
|
||||
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "novita"
|
||||
if NOVITA_MODELS:
|
||||
mock_settings.LLM_NAME = NOVITA_MODELS[0].id
|
||||
else:
|
||||
mock_settings.LLM_NAME = "nonexistent"
|
||||
reg._add_novita_models(mock_settings)
|
||||
assert len(reg.models) >= 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_models_default_from_llm_name_exact_match(self):
|
||||
"""Cover line 136/147: exact LLM_NAME match for default model."""
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = None
|
||||
mock_settings.OPENAI_API_KEY = "sk-test"
|
||||
mock_settings.OPENAI_API_BASE = None
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.HUGGINGFACE_API_KEY = None
|
||||
mock_settings.LLM_PROVIDER = "openai"
|
||||
mock_settings.API_KEY = None
|
||||
|
||||
from application.core.model_configs import OPENAI_MODELS
|
||||
|
||||
if OPENAI_MODELS:
|
||||
mock_settings.LLM_NAME = OPENAI_MODELS[0].id
|
||||
else:
|
||||
mock_settings.LLM_NAME = "gpt-4o"
|
||||
|
||||
with patch("application.core.settings.settings", mock_settings):
|
||||
reg = ModelRegistry()
|
||||
assert reg.default_model_id is not None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_openai_models_local_endpoint_no_name(self):
|
||||
"""Cover line 171: local endpoint without LLM_NAME adds nothing."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
|
||||
mock_settings.OPENAI_API_KEY = "sk-test"
|
||||
mock_settings.LLM_NAME = None
|
||||
reg._add_openai_models(mock_settings)
|
||||
assert len(reg.models) == 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_openai_standard_no_api_key(self):
|
||||
"""Cover line 179: standard OpenAI without API key adds nothing."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = None
|
||||
mock_settings.OPENAI_API_KEY = None
|
||||
reg._add_openai_models(mock_settings)
|
||||
assert len(reg.models) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Coverage — additional uncovered lines: 100, 105, 147, 171, 179, 186, 250
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestModelRegistryAdditionalCoverage:
|
||||
|
||||
def test_add_azure_openai_models_specific_name(self):
|
||||
"""Cover line 186: azure_openai with specific LLM_NAME match."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.LLM_PROVIDER = "azure_openai"
|
||||
mock_settings.LLM_NAME = "gpt-4o"
|
||||
|
||||
# Create a fake model that matches
|
||||
fake_model = MagicMock()
|
||||
fake_model.id = "gpt-4o"
|
||||
with patch(
|
||||
"application.core.model_configs.AZURE_OPENAI_MODELS",
|
||||
[fake_model],
|
||||
):
|
||||
reg._add_azure_openai_models(mock_settings)
|
||||
assert "gpt-4o" in reg.models
|
||||
|
||||
def test_add_anthropic_models_with_api_key(self):
|
||||
"""Cover line 100: anthropic with API key."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.ANTHROPIC_API_KEY = "sk-test"
|
||||
mock_settings.LLM_PROVIDER = "anthropic"
|
||||
reg._add_anthropic_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
def test_add_google_models_with_api_key(self):
|
||||
"""Cover line 105: google with API key."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.GOOGLE_API_KEY = "test-key"
|
||||
mock_settings.LLM_PROVIDER = "google"
|
||||
reg._add_google_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
def test_default_model_from_provider(self):
|
||||
"""Cover line 147: default model selected from provider."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
reg.default_model_id = None
|
||||
|
||||
fake_model = MagicMock()
|
||||
fake_model.provider = MagicMock()
|
||||
fake_model.provider.value = "openai"
|
||||
reg.models["gpt-4o"] = fake_model
|
||||
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.LLM_NAME = None
|
||||
mock_settings.LLM_PROVIDER = "openai"
|
||||
mock_settings.API_KEY = "key"
|
||||
|
||||
# Simulate the default selection logic
|
||||
if not reg.default_model_id:
|
||||
for model_id, model in reg.models.items():
|
||||
if model.provider.value == mock_settings.LLM_PROVIDER:
|
||||
reg.default_model_id = model_id
|
||||
break
|
||||
|
||||
assert reg.default_model_id == "gpt-4o"
|
||||
|
||||
def test_add_openai_local_endpoint_with_llm_name(self):
|
||||
"""Cover line 171: local endpoint registers custom models from LLM_NAME."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
|
||||
mock_settings.OPENAI_API_KEY = "sk-test"
|
||||
mock_settings.LLM_NAME = "llama3,phi3"
|
||||
reg._add_openai_models(mock_settings)
|
||||
assert "llama3" in reg.models
|
||||
assert "phi3" in reg.models
|
||||
|
||||
def test_add_openai_standard_with_api_key(self):
|
||||
"""Cover line 179: standard OpenAI with API key adds models."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPENAI_BASE_URL = None
|
||||
mock_settings.OPENAI_API_KEY = "sk-real-key"
|
||||
reg._add_openai_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
def test_add_openrouter_models(self):
|
||||
"""Cover line 250: openrouter models added."""
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.OPEN_ROUTER_API_KEY = "or-key"
|
||||
mock_settings.LLM_PROVIDER = "openrouter"
|
||||
reg._add_openrouter_models(mock_settings)
|
||||
assert len(reg.models) > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Additional coverage for model_settings.py
|
||||
# Lines: 135-136 (backward compat LLM_NAME), 138-143 (provider fallback),
|
||||
# 145-146 (first model as default)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Imports already at the top of the file; no additional imports needed
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDefaultModelSelectionBackwardCompat:
|
||||
"""Cover lines 135-136: backward compat exact match on LLM_NAME."""
|
||||
|
||||
def test_llm_name_exact_match_as_default(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
reg.default_model_id = None
|
||||
# Add a model with composite ID
|
||||
model = AvailableModel(
|
||||
id="my-composite-model",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="Composite",
|
||||
description="test",
|
||||
capabilities=ModelCapabilities(),
|
||||
reg.models["m1"] = AvailableModel(
|
||||
id="m1", provider=ModelProvider.OPENAI, display_name="M1"
|
||||
)
|
||||
reg.models["my-composite-model"] = model
|
||||
assert reg.get_model("m1", user_id="alice") is not None
|
||||
assert reg.model_exists("m1", user_id="alice") is True
|
||||
assert len(reg.get_all_models(user_id="alice")) == 1
|
||||
assert len(reg.get_enabled_models(user_id="alice")) == 1
|
||||
|
||||
# Simulate _parse_model_names returning something different
|
||||
# so that the first for-loop doesn't match
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.LLM_NAME = "my-composite-model"
|
||||
mock_settings.LLM_PROVIDER = None
|
||||
mock_settings.API_KEY = None
|
||||
|
||||
# Call the logic directly
|
||||
model_names = reg._parse_model_names(mock_settings.LLM_NAME)
|
||||
for mn in model_names:
|
||||
if mn in reg.models:
|
||||
reg.default_model_id = mn
|
||||
break
|
||||
|
||||
assert reg.default_model_id == "my-composite-model"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDefaultModelSelectionByProvider:
|
||||
"""Cover lines 138-143: default model by provider when LLM_NAME doesn't match."""
|
||||
|
||||
def test_default_by_provider(self):
|
||||
@pytest.mark.unit
|
||||
def test_reset(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
reg.default_model_id = None
|
||||
model = AvailableModel(
|
||||
id="gpt-4",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-4",
|
||||
description="test",
|
||||
capabilities=ModelCapabilities(),
|
||||
)
|
||||
reg.models["gpt-4"] = model
|
||||
|
||||
# Simulate: LLM_NAME doesn't exist/match, but LLM_PROVIDER + API_KEY set
|
||||
if not reg.default_model_id:
|
||||
for model_id, m in reg.models.items():
|
||||
if m.provider.value == "openai":
|
||||
reg.default_model_id = model_id
|
||||
break
|
||||
|
||||
assert reg.default_model_id == "gpt-4"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestDefaultModelSelectionFirstModel:
|
||||
"""Cover lines 145-146: first model as default when nothing else matches."""
|
||||
|
||||
def test_first_model_as_default(self):
|
||||
with patch.object(ModelRegistry, "_load_models"):
|
||||
reg = ModelRegistry()
|
||||
reg.models = {}
|
||||
reg.default_model_id = None
|
||||
model = AvailableModel(
|
||||
id="fallback-model",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="Fallback",
|
||||
description="test",
|
||||
capabilities=ModelCapabilities(),
|
||||
)
|
||||
reg.models["fallback-model"] = model
|
||||
|
||||
if not reg.default_model_id and reg.models:
|
||||
reg.default_model_id = next(iter(reg.models.keys()))
|
||||
|
||||
assert reg.default_model_id == "fallback-model"
|
||||
r1 = ModelRegistry()
|
||||
ModelRegistry.reset()
|
||||
r2 = ModelRegistry()
|
||||
assert r1 is not r2
|
||||
|
||||
208
tests/core/test_models_config_dir.py
Normal file
208
tests/core/test_models_config_dir.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Phase 3 tests: operator MODELS_CONFIG_DIR.
|
||||
|
||||
Covers the operator-supplied directory of model YAMLs that's loaded
|
||||
after the built-in catalog. Operators use this to add new
|
||||
``openai_compatible`` providers, extend an existing provider's catalog
|
||||
with extra models, or override a built-in model's capabilities — all
|
||||
without forking the repo.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from textwrap import dedent
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.core.model_registry import ModelRegistry
|
||||
|
||||
|
||||
def _make_settings(**overrides):
|
||||
s = MagicMock()
|
||||
s.OPENAI_BASE_URL = None
|
||||
s.OPENAI_API_KEY = None
|
||||
s.OPENAI_API_BASE = None
|
||||
s.ANTHROPIC_API_KEY = None
|
||||
s.GOOGLE_API_KEY = None
|
||||
s.GROQ_API_KEY = None
|
||||
s.OPEN_ROUTER_API_KEY = None
|
||||
s.NOVITA_API_KEY = None
|
||||
s.HUGGINGFACE_API_KEY = None
|
||||
s.LLM_PROVIDER = ""
|
||||
s.LLM_NAME = None
|
||||
s.API_KEY = None
|
||||
s.MODELS_CONFIG_DIR = None
|
||||
for k, v in overrides.items():
|
||||
setattr(s, k, v)
|
||||
return s
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_registry():
|
||||
ModelRegistry.reset()
|
||||
yield
|
||||
ModelRegistry.reset()
|
||||
|
||||
|
||||
# ── New provider via openai_compatible ───────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestOperatorAddsNewProvider:
|
||||
def test_drop_in_yaml_appears_in_registry(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
(tmp_path / "fireworks.yaml").write_text(dedent("""
|
||||
provider: openai_compatible
|
||||
display_provider: fireworks
|
||||
api_key_env: FIREWORKS_API_KEY
|
||||
base_url: https://api.fireworks.ai/inference/v1
|
||||
defaults:
|
||||
supports_tools: true
|
||||
models:
|
||||
- id: accounts/fireworks/models/llama-v3p3-70b-instruct
|
||||
display_name: Llama 3.3 70B (Fireworks)
|
||||
"""))
|
||||
monkeypatch.setenv("FIREWORKS_API_KEY", "fw-key")
|
||||
|
||||
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
|
||||
m = reg.get_model("accounts/fireworks/models/llama-v3p3-70b-instruct")
|
||||
assert m is not None
|
||||
assert m.api_key == "fw-key"
|
||||
assert m.base_url == "https://api.fireworks.ai/inference/v1"
|
||||
assert m.display_provider == "fireworks"
|
||||
|
||||
|
||||
# ── Extending an existing provider's catalog ─────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestOperatorExtendsExistingProvider:
|
||||
def test_operator_adds_anthropic_model_to_builtin_catalog(
|
||||
self, tmp_path
|
||||
):
|
||||
(tmp_path / "anthropic-extra.yaml").write_text(dedent("""
|
||||
provider: anthropic
|
||||
defaults:
|
||||
supports_tools: true
|
||||
context_window: 200000
|
||||
models:
|
||||
- id: claude-haiku-5-0-future
|
||||
display_name: Claude Haiku 5.0
|
||||
"""))
|
||||
|
||||
s = _make_settings(
|
||||
ANTHROPIC_API_KEY="sk-ant",
|
||||
MODELS_CONFIG_DIR=str(tmp_path),
|
||||
)
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
|
||||
# Built-in models still present
|
||||
assert reg.get_model("claude-sonnet-4-6") is not None
|
||||
assert reg.get_model("claude-opus-4-7") is not None
|
||||
# Operator-added model also present
|
||||
added = reg.get_model("claude-haiku-5-0-future")
|
||||
assert added is not None
|
||||
assert added.display_name == "Claude Haiku 5.0"
|
||||
|
||||
|
||||
# ── Overriding a built-in model's capabilities ───────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestOperatorOverridesBuiltinCapabilities:
|
||||
def test_operator_yaml_overrides_builtin_context_window(
|
||||
self, tmp_path, caplog
|
||||
):
|
||||
# Override anthropic claude-haiku-4-5 to claim a 1M context window
|
||||
(tmp_path / "anthropic-override.yaml").write_text(dedent("""
|
||||
provider: anthropic
|
||||
defaults:
|
||||
supports_tools: true
|
||||
attachments: [image]
|
||||
context_window: 1000000
|
||||
models:
|
||||
- id: claude-haiku-4-5
|
||||
display_name: Claude Haiku 4.5 (extended)
|
||||
description: Operator-overridden capabilities
|
||||
"""))
|
||||
|
||||
s = _make_settings(
|
||||
ANTHROPIC_API_KEY="sk-ant",
|
||||
MODELS_CONFIG_DIR=str(tmp_path),
|
||||
)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
|
||||
m = reg.get_model("claude-haiku-4-5")
|
||||
assert m.display_name == "Claude Haiku 4.5 (extended)"
|
||||
assert m.description == "Operator-overridden capabilities"
|
||||
assert m.capabilities.context_window == 1_000_000
|
||||
|
||||
# And the override warning fires so the operator can audit it
|
||||
assert any(
|
||||
"claude-haiku-4-5" in rec.message and "redefined" in rec.message
|
||||
for rec in caplog.records
|
||||
)
|
||||
|
||||
|
||||
# ── Misconfigured MODELS_CONFIG_DIR ──────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMisconfiguredOperatorDir:
|
||||
def test_missing_dir_logs_warning_and_continues(
|
||||
self, tmp_path, caplog
|
||||
):
|
||||
bogus = tmp_path / "does-not-exist"
|
||||
s = _make_settings(MODELS_CONFIG_DIR=str(bogus))
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
|
||||
# Built-in catalog still loaded
|
||||
assert reg.get_model("docsgpt-local") is not None
|
||||
# And the operator was warned
|
||||
assert any("does not exist" in rec.message for rec in caplog.records)
|
||||
|
||||
def test_path_is_a_file_logs_warning(self, tmp_path, caplog):
|
||||
afile = tmp_path / "not-a-dir.yaml"
|
||||
afile.write_text("provider: anthropic\nmodels: []")
|
||||
|
||||
s = _make_settings(MODELS_CONFIG_DIR=str(afile))
|
||||
with caplog.at_level(logging.WARNING):
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
|
||||
assert reg.get_model("docsgpt-local") is not None
|
||||
assert any("not a directory" in rec.message for rec in caplog.records)
|
||||
|
||||
|
||||
# ── Validation: unknown provider rejected ────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestOperatorValidation:
|
||||
def test_unknown_provider_in_operator_yaml_aborts_boot(self, tmp_path):
|
||||
(tmp_path / "bogus.yaml").write_text(dedent("""
|
||||
provider: not_a_real_provider
|
||||
models:
|
||||
- id: x
|
||||
display_name: X
|
||||
"""))
|
||||
|
||||
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
|
||||
with patch("application.core.settings.settings", s):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
ModelRegistry()
|
||||
# Could be ModelYAMLError (enum check) or ValueError (registry check);
|
||||
# either way the message must surface what's wrong.
|
||||
msg = str(exc_info.value)
|
||||
assert "not_a_real_provider" in msg
|
||||
298
tests/core/test_openai_compatible.py
Normal file
298
tests/core/test_openai_compatible.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""Phase 2 tests for the openai_compatible provider.
|
||||
|
||||
Covers YAML loading from a temp directory, multiple coexisting catalogs
|
||||
(Mistral + Together), env-var-based credential resolution, the legacy
|
||||
OPENAI_BASE_URL + LLM_NAME fallback, and end-to-end model dispatch
|
||||
through LLMCreator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.core.model_registry import ModelRegistry
|
||||
from application.core.model_settings import ModelProvider
|
||||
|
||||
|
||||
def _make_settings(**overrides):
|
||||
s = MagicMock()
|
||||
s.OPENAI_BASE_URL = None
|
||||
s.OPENAI_API_KEY = None
|
||||
s.OPENAI_API_BASE = None
|
||||
s.ANTHROPIC_API_KEY = None
|
||||
s.GOOGLE_API_KEY = None
|
||||
s.GROQ_API_KEY = None
|
||||
s.OPEN_ROUTER_API_KEY = None
|
||||
s.NOVITA_API_KEY = None
|
||||
s.HUGGINGFACE_API_KEY = None
|
||||
s.LLM_PROVIDER = ""
|
||||
s.LLM_NAME = None
|
||||
s.API_KEY = None
|
||||
s.MODELS_CONFIG_DIR = None
|
||||
for k, v in overrides.items():
|
||||
setattr(s, k, v)
|
||||
return s
|
||||
|
||||
|
||||
def _write_mistral_yaml(directory: Path) -> Path:
|
||||
path = directory / "mistral.yaml"
|
||||
path.write_text(dedent("""
|
||||
provider: openai_compatible
|
||||
display_provider: mistral
|
||||
api_key_env: MISTRAL_API_KEY
|
||||
base_url: https://api.mistral.ai/v1
|
||||
defaults:
|
||||
supports_tools: true
|
||||
context_window: 128000
|
||||
models:
|
||||
- id: mistral-large-latest
|
||||
display_name: Mistral Large
|
||||
- id: mistral-small-latest
|
||||
display_name: Mistral Small
|
||||
"""))
|
||||
return path
|
||||
|
||||
|
||||
def _write_together_yaml(directory: Path) -> Path:
|
||||
path = directory / "together.yaml"
|
||||
path.write_text(dedent("""
|
||||
provider: openai_compatible
|
||||
display_provider: together
|
||||
api_key_env: TOGETHER_API_KEY
|
||||
base_url: https://api.together.xyz/v1
|
||||
defaults:
|
||||
supports_tools: true
|
||||
models:
|
||||
- id: meta-llama/Llama-3.3-70B-Instruct-Turbo
|
||||
display_name: Llama 3.3 70B (Together)
|
||||
"""))
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_registry():
|
||||
ModelRegistry.reset()
|
||||
yield
|
||||
ModelRegistry.reset()
|
||||
|
||||
|
||||
# ── YAML-driven catalogs ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestYAMLCompatibleProvider:
|
||||
def test_mistral_yaml_loads_with_env_key(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
_write_mistral_yaml(tmp_path)
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "sk-mistral-test")
|
||||
|
||||
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
|
||||
m = reg.get_model("mistral-large-latest")
|
||||
assert m is not None
|
||||
assert m.provider == ModelProvider.OPENAI_COMPATIBLE
|
||||
assert m.display_provider == "mistral"
|
||||
assert m.base_url == "https://api.mistral.ai/v1"
|
||||
assert m.api_key == "sk-mistral-test"
|
||||
assert m.capabilities.supports_tools is True
|
||||
assert m.capabilities.context_window == 128000
|
||||
|
||||
def test_yaml_skipped_when_env_var_missing(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
_write_mistral_yaml(tmp_path)
|
||||
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
|
||||
|
||||
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
|
||||
# Catalog skipped when no key — no Mistral models in the registry
|
||||
assert reg.get_model("mistral-large-latest") is None
|
||||
|
||||
def test_two_compatible_catalogs_coexist_with_separate_keys(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
_write_mistral_yaml(tmp_path)
|
||||
_write_together_yaml(tmp_path)
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "sk-mistral")
|
||||
monkeypatch.setenv("TOGETHER_API_KEY", "sk-together")
|
||||
|
||||
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
|
||||
mistral = reg.get_model("mistral-large-latest")
|
||||
together = reg.get_model("meta-llama/Llama-3.3-70B-Instruct-Turbo")
|
||||
|
||||
assert mistral.api_key == "sk-mistral"
|
||||
assert mistral.base_url == "https://api.mistral.ai/v1"
|
||||
assert mistral.display_provider == "mistral"
|
||||
|
||||
assert together.api_key == "sk-together"
|
||||
assert together.base_url == "https://api.together.xyz/v1"
|
||||
assert together.display_provider == "together"
|
||||
|
||||
def test_one_catalog_enabled_other_skipped(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
_write_mistral_yaml(tmp_path)
|
||||
_write_together_yaml(tmp_path)
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "sk-mistral")
|
||||
monkeypatch.delenv("TOGETHER_API_KEY", raising=False)
|
||||
|
||||
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
|
||||
assert reg.get_model("mistral-large-latest") is not None
|
||||
assert reg.get_model("meta-llama/Llama-3.3-70B-Instruct-Turbo") is None
|
||||
|
||||
def test_missing_base_url_raises(self, tmp_path, monkeypatch):
|
||||
bad = tmp_path / "broken.yaml"
|
||||
bad.write_text(dedent("""
|
||||
provider: openai_compatible
|
||||
api_key_env: SOME_KEY
|
||||
models:
|
||||
- id: x
|
||||
display_name: X
|
||||
"""))
|
||||
monkeypatch.setenv("SOME_KEY", "k")
|
||||
|
||||
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
|
||||
with patch("application.core.settings.settings", s):
|
||||
with pytest.raises(ValueError, match="must set 'base_url'"):
|
||||
ModelRegistry()
|
||||
|
||||
def test_missing_api_key_env_raises(self, tmp_path, monkeypatch):
|
||||
bad = tmp_path / "broken.yaml"
|
||||
bad.write_text(dedent("""
|
||||
provider: openai_compatible
|
||||
base_url: https://x/v1
|
||||
models:
|
||||
- id: x
|
||||
display_name: X
|
||||
"""))
|
||||
|
||||
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
|
||||
with patch("application.core.settings.settings", s):
|
||||
with pytest.raises(ValueError, match="must set 'api_key_env'"):
|
||||
ModelRegistry()
|
||||
|
||||
def test_to_dict_uses_display_provider(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
_write_mistral_yaml(tmp_path)
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "sk")
|
||||
|
||||
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
|
||||
d = reg.get_model("mistral-large-latest").to_dict()
|
||||
# /api/models response shows "mistral", not "openai_compatible"
|
||||
assert d["provider"] == "mistral"
|
||||
# api_key never leaks into the wire format
|
||||
assert "api_key" not in d
|
||||
for v in d.values():
|
||||
assert v != "sk"
|
||||
|
||||
|
||||
# ── Legacy OPENAI_BASE_URL fallback ──────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLegacyOpenAIBaseURLPath:
|
||||
def test_legacy_models_now_provided_by_openai_compatible(self):
|
||||
s = _make_settings(
|
||||
OPENAI_BASE_URL="http://localhost:11434/v1",
|
||||
OPENAI_API_KEY="sk-local",
|
||||
LLM_PROVIDER="openai",
|
||||
LLM_NAME="llama3,gemma",
|
||||
)
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
|
||||
ids = {m.id for m in reg.get_all_models()}
|
||||
assert ids == {"llama3", "gemma"}
|
||||
|
||||
llama = reg.get_model("llama3")
|
||||
assert llama.base_url == "http://localhost:11434/v1"
|
||||
assert llama.api_key == "sk-local"
|
||||
assert llama.provider == ModelProvider.OPENAI_COMPATIBLE
|
||||
# Display provider preserves the historical "openai" label
|
||||
assert llama.display_provider == "openai"
|
||||
assert llama.to_dict()["provider"] == "openai"
|
||||
|
||||
def test_legacy_uses_api_key_fallback_when_openai_api_key_missing(self):
|
||||
s = _make_settings(
|
||||
OPENAI_BASE_URL="http://localhost:11434/v1",
|
||||
OPENAI_API_KEY=None,
|
||||
API_KEY="sk-generic",
|
||||
LLM_PROVIDER="openai",
|
||||
LLM_NAME="llama3",
|
||||
)
|
||||
with patch("application.core.settings.settings", s):
|
||||
reg = ModelRegistry()
|
||||
assert reg.get_model("llama3").api_key == "sk-generic"
|
||||
|
||||
|
||||
# ── Dispatch through LLMCreator ──────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLLMCreatorDispatch:
|
||||
def test_llmcreator_uses_per_model_api_key_and_base_url(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
"""End-to-end: when an openai_compatible model is dispatched, the
|
||||
per-model api_key + base_url from the registry must override
|
||||
whatever the caller passed."""
|
||||
_write_mistral_yaml(tmp_path)
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "sk-mistral-real")
|
||||
|
||||
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
|
||||
|
||||
captured = {}
|
||||
|
||||
class _FakeLLM:
|
||||
def __init__(
|
||||
self, api_key, user_api_key, *args, **kwargs
|
||||
):
|
||||
captured["api_key"] = api_key
|
||||
captured["base_url"] = kwargs.get("base_url")
|
||||
captured["model_id"] = kwargs.get("model_id")
|
||||
|
||||
with patch("application.core.settings.settings", s):
|
||||
ModelRegistry.reset()
|
||||
ModelRegistry() # warm up the registry under patched settings
|
||||
|
||||
# Now patch the OpenAI plugin's class so we can capture the
|
||||
# constructor args without spinning up the real OpenAILLM.
|
||||
from application.llm.providers import PROVIDERS_BY_NAME
|
||||
|
||||
with patch.object(
|
||||
PROVIDERS_BY_NAME["openai_compatible"],
|
||||
"llm_class",
|
||||
_FakeLLM,
|
||||
):
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
LLMCreator.create_llm(
|
||||
type="openai_compatible",
|
||||
api_key="caller-passed-WRONG-key",
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
model_id="mistral-large-latest",
|
||||
)
|
||||
|
||||
assert captured["api_key"] == "sk-mistral-real"
|
||||
assert captured["base_url"] == "https://api.mistral.ai/v1"
|
||||
assert captured["model_id"] == "mistral-large-latest"
|
||||
Reference in New Issue
Block a user