Compare commits

...

2 Commits

Author  SHA1        Message                Date
Alex    e0a8cc178b  feat: BYOM             2026-04-27 21:50:45 +01:00
Alex    af618de13d  Feat models (#2432)    2026-04-26 00:58:29 +01:00
                    * feat: simplified model structure
                    * fix: test
                    * fix: mini docstring stuff
107 changed files with 10225 additions and 1489 deletions

View File

@@ -35,8 +35,5 @@ MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
# User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration).
# Standard Postgres URI — `postgres://` and `postgresql://` both work.
# Leave unset while the migration is still being rolled out; the app will
# fall back to MongoDB for user data until POSTGRES_URI is configured.
# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt

View File

@@ -42,6 +42,7 @@ class BaseAgent(ABC):
llm_handler=None,
tool_executor: Optional[ToolExecutor] = None,
backup_models: Optional[List[str]] = None,
model_user_id: Optional[str] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
@@ -52,10 +53,13 @@ class BaseAgent(ABC):
self.prompt = prompt
self.decoded_token = decoded_token or {}
self.user: str = self.decoded_token.get("sub")
# BYOM-resolution scope: owner for shared agents, caller for
# caller-owned BYOM, None for built-ins. Falls back to self.user
# for worker/legacy callers that don't thread model_user_id.
self.model_user_id = model_user_id
self.tools: List[Dict] = []
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
# Dependency injection for LLM — fall back to creating if not provided
if llm is not None:
self.llm = llm
else:
@@ -67,8 +71,16 @@ class BaseAgent(ABC):
model_id=model_id,
agent_id=agent_id,
backup_models=backup_models,
model_user_id=model_user_id,
)
# For BYOM, registry id (UUID) differs from upstream model id
# (e.g. ``mistral-large-latest``). LLMCreator resolved this onto
# the LLM instance; cache it for subsequent gen calls.
self.upstream_model_id = (
getattr(self.llm, "model_id", None) or model_id
)
self.retrieved_docs = retrieved_docs or []
if llm_handler is not None:
@@ -306,7 +318,9 @@ class BaseAgent(ABC):
try:
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(self.model_id)
context_limit = get_token_limit(
self.model_id, user_id=self.model_user_id or self.user
)
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
if current_tokens >= threshold:
@@ -325,7 +339,9 @@ class BaseAgent(ABC):
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(self.model_id)
context_limit = get_token_limit(
self.model_id, user_id=self.model_user_id or self.user
)
percentage = (current_tokens / context_limit) * 100
if current_tokens >= context_limit:
@@ -387,7 +403,9 @@ class BaseAgent(ABC):
)
system_prompt = system_prompt + compression_context
context_limit = get_token_limit(self.model_id)
context_limit = get_token_limit(
self.model_id, user_id=self.model_user_id or self.user
)
system_tokens = num_tokens_from_string(system_prompt)
safety_buffer = int(context_limit * 0.1)
@@ -497,7 +515,10 @@ class BaseAgent(ABC):
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
self._validate_context_size(messages)
gen_kwargs = {"model": self.model_id, "messages": messages}
# Use the upstream id resolved by LLMCreator (see __init__).
# Built-in models: same as self.model_id. BYOM: the user's
# typed model name, not the internal UUID.
gen_kwargs = {"model": self.upstream_model_id, "messages": messages}
if self.attachments:
gen_kwargs["_usage_attachments"] = self.attachments

View File

@@ -312,7 +312,7 @@ class ResearchAgent(BaseAgent):
try:
response = self.llm.gen(
model=self.model_id,
model=self.upstream_model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
@@ -390,7 +390,7 @@ class ResearchAgent(BaseAgent):
try:
response = self.llm.gen(
model=self.model_id,
model=self.upstream_model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
@@ -506,7 +506,7 @@ class ResearchAgent(BaseAgent):
try:
response = self.llm.gen(
model=self.model_id,
model=self.upstream_model_id,
messages=messages,
tools=self.tools if self.tools else None,
)
@@ -537,7 +537,7 @@ class ResearchAgent(BaseAgent):
)
try:
response = self.llm.gen(
model=self.model_id, messages=messages, tools=None
model=self.upstream_model_id, messages=messages, tools=None
)
self._track_tokens(self._snapshot_llm_tokens())
text = self._extract_text(response)
@@ -664,7 +664,7 @@ class ResearchAgent(BaseAgent):
]
llm_response = self.llm.gen_stream(
model=self.model_id, messages=messages, tools=None
model=self.upstream_model_id, messages=messages, tools=None
)
if log_context:

View File

@@ -39,6 +39,7 @@ class InternalSearchTool(Tool):
chunks=int(self.config.get("chunks", 2)),
doc_token_limit=int(self.config.get("doc_token_limit", 50000)),
model_id=self.config.get("model_id", "docsgpt-local"),
model_user_id=self.config.get("model_user_id"),
user_api_key=self.config.get("user_api_key"),
agent_id=self.config.get("agent_id"),
llm_name=self.config.get("llm_name", settings.LLM_PROVIDER),
@@ -435,6 +436,7 @@ def build_internal_tool_config(
chunks: int = 2,
doc_token_limit: int = 50000,
model_id: str = "docsgpt-local",
model_user_id: Optional[str] = None,
user_api_key: Optional[str] = None,
agent_id: Optional[str] = None,
llm_name: str = None,
@@ -449,6 +451,7 @@ def build_internal_tool_config(
"chunks": chunks,
"doc_token_limit": doc_token_limit,
"model_id": model_id,
"model_user_id": model_user_id,
"user_api_key": user_api_key,
"agent_id": agent_id,
"llm_name": llm_name or settings.LLM_PROVIDER,

View File

@@ -211,15 +211,26 @@ class WorkflowEngine:
node_config.json_schema, node.title
)
node_model_id = node_config.model_id or self.agent.model_id
# Inherit BYOM scope from parent agent so owner-stored BYOM
# resolves on shared workflows.
node_user_id = getattr(self.agent, "model_user_id", None) or (
self.agent.decoded_token.get("sub")
if isinstance(self.agent.decoded_token, dict)
else None
)
node_llm_name = (
node_config.llm_name
or get_provider_from_model_id(node_model_id or "")
or get_provider_from_model_id(
node_model_id or "", user_id=node_user_id
)
or self.agent.llm_name
)
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
if node_json_schema and node_model_id:
model_capabilities = get_model_capabilities(node_model_id)
model_capabilities = get_model_capabilities(
node_model_id, user_id=node_user_id
)
if model_capabilities and not model_capabilities.get(
"supports_structured_output", False
):
@@ -232,6 +243,7 @@ class WorkflowEngine:
"endpoint": self.agent.endpoint,
"llm_name": node_llm_name,
"model_id": node_model_id,
"model_user_id": getattr(self.agent, "model_user_id", None),
"api_key": node_api_key,
"tool_ids": node_config.tools,
"prompt": node_config.system_prompt,

View File

@@ -0,0 +1,65 @@
"""0003 user_custom_models — per-user OpenAI-compatible model registrations.
Revision ID: 0003_user_custom_models
Revises: 0002_app_metadata
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0003_user_custom_models"
down_revision: Union[str, None] = "0002_app_metadata"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
CREATE TABLE user_custom_models (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
upstream_model_id TEXT NOT NULL,
display_name TEXT NOT NULL,
description TEXT NOT NULL DEFAULT '',
base_url TEXT NOT NULL,
api_key_encrypted TEXT NOT NULL,
capabilities JSONB NOT NULL DEFAULT '{}'::jsonb,
enabled BOOLEAN NOT NULL DEFAULT true,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
op.execute(
"CREATE INDEX user_custom_models_user_id_idx "
"ON user_custom_models (user_id);"
)
# Mirror the project-wide invariants set up in 0001_initial:
# * user_id FK with ON DELETE RESTRICT (deferrable),
# * ensure_user_exists() trigger so the parent users row autocreates,
# * set_updated_at() trigger.
op.execute(
"ALTER TABLE user_custom_models "
"ADD CONSTRAINT user_custom_models_user_id_fk "
"FOREIGN KEY (user_id) REFERENCES users(user_id) "
"ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
)
op.execute(
"CREATE TRIGGER user_custom_models_ensure_user "
"BEFORE INSERT OR UPDATE OF user_id ON user_custom_models "
"FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
)
op.execute(
"CREATE TRIGGER user_custom_models_set_updated_at "
"BEFORE UPDATE ON user_custom_models "
"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
"EXECUTE FUNCTION set_updated_at();"
)
def downgrade() -> None:
op.execute("DROP TABLE IF EXISTS user_custom_models;")

View File

@@ -177,6 +177,7 @@ class BaseAnswerResource:
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
model_user_id: Optional[str] = None,
_continuation: Optional[Dict] = None,
) -> Generator[str, None, None]:
"""
@@ -289,8 +290,18 @@ class BaseAnswerResource:
# conversation if this is the first turn.
if not conversation_id and should_save_conversation:
try:
# Use model-owner scope so shared-agent
# owner-BYOM resolves to its registered plugin.
provider = (
get_provider_from_model_id(model_id)
get_provider_from_model_id(
model_id,
user_id=model_user_id
or (
decoded_token.get("sub")
if decoded_token
else None
),
)
if model_id
else settings.LLM_PROVIDER
)
@@ -304,6 +315,7 @@ class BaseAnswerResource:
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
conversation_id = (
self.conversation_service.save_conversation(
@@ -340,6 +352,9 @@ class BaseAnswerResource:
tool_schemas=getattr(agent, "tools", []),
agent_config={
"model_id": model_id or self.default_model_id,
# Persist BYOM scope so resume doesn't
# fall back to caller's layer.
"model_user_id": model_user_id,
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
"api_key": getattr(agent, "api_key", None),
"user_api_key": user_api_key,
@@ -370,8 +385,14 @@ class BaseAnswerResource:
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
# Run under model-owner scope so title-gen LLM inside
# save_conversation uses the owner's BYOM provider/key.
provider = (
get_provider_from_model_id(model_id)
get_provider_from_model_id(
model_id,
user_id=model_user_id
or (decoded_token.get("sub") if decoded_token else None),
)
if model_id
else settings.LLM_PROVIDER
)
@@ -384,6 +405,7 @@ class BaseAnswerResource:
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
if should_save_conversation:
@@ -481,12 +503,34 @@ class BaseAnswerResource:
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
# Mirror the normal-path provider resolution so the
# partial-save title LLM uses the model-owner's BYOM
# registration (shared-agent dispatch) rather than
# the deployment default with the instance api key.
provider = (
get_provider_from_model_id(
model_id,
user_id=model_user_id
or (
decoded_token.get("sub")
if decoded_token
else None
),
)
if model_id
else settings.LLM_PROVIDER
)
sys_api_key = get_api_key_for_provider(
provider or settings.LLM_PROVIDER
)
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
provider or settings.LLM_PROVIDER,
api_key=sys_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
self.conversation_service.save_conversation(
conversation_id,
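The same fallback chain repeats three times in this resource; a compact sketch of the order it resolves in (identifiers mirror the hunks above, values are illustrative):

# BYOM-resolution scope:
#   model_user_id                    # owner scope, when threaded from dispatch
#   or decoded_token.get("sub")      # caller, for own BYOM / legacy callers
#   or None                          # anonymous callers -> built-ins only
# provider for the title/partial-save LLM:
#   get_provider_from_model_id(model_id, user_id=scope) if model_id
#   else settings.LLM_PROVIDER
# system api key:
#   get_api_key_for_provider(provider or settings.LLM_PROVIDER)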

View File

@@ -109,6 +109,7 @@ class StreamResource(Resource, BaseAnswerResource):
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
@@ -145,6 +146,7 @@ class StreamResource(Resource, BaseAnswerResource):
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
),
mimetype="text/event-stream",
)

View File

@@ -49,6 +49,7 @@ class CompressionOrchestrator:
model_id: str,
decoded_token: Dict[str, Any],
current_query_tokens: int = 500,
model_user_id: Optional[str] = None,
) -> CompressionResult:
"""
Check if compression is needed and perform it if so.
@@ -57,16 +58,18 @@ class CompressionOrchestrator:
Args:
conversation_id: Conversation ID
user_id: User ID
user_id: Caller's user id — used for conversation access checks
model_id: Model being used for conversation
decoded_token: User's decoded JWT token
current_query_tokens: Estimated tokens for current query
model_user_id: BYOM-resolution scope (model owner); defaults
to ``user_id`` for built-in / caller-owned models.
Returns:
CompressionResult with summary and recent queries
"""
try:
# Load conversation
# Conversation row is owned by the caller, not the model owner.
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
@@ -77,9 +80,14 @@ class CompressionOrchestrator:
)
return CompressionResult.failure("Conversation not found")
# Check if compression is needed
# Use model-owner scope so per-user BYOM context windows
# (e.g. 8k) compute the threshold against the right limit.
registry_user_id = model_user_id or user_id
if not self.threshold_checker.should_compress(
conversation, model_id, current_query_tokens
conversation,
model_id,
current_query_tokens,
user_id=registry_user_id,
):
# No compression needed, return full history
queries = conversation.get("queries", [])
@@ -87,7 +95,12 @@ class CompressionOrchestrator:
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
conversation_id,
conversation,
model_id,
decoded_token,
user_id=user_id,
model_user_id=model_user_id,
)
except Exception as e:
@@ -102,6 +115,8 @@ class CompressionOrchestrator:
conversation: Dict[str, Any],
model_id: str,
decoded_token: Dict[str, Any],
user_id: Optional[str] = None,
model_user_id: Optional[str] = None,
) -> CompressionResult:
"""
Perform the actual compression operation.
@@ -111,6 +126,8 @@ class CompressionOrchestrator:
conversation: Conversation document
model_id: Model ID for conversation
decoded_token: User token
user_id: Caller's id (for conversation reload after compression)
model_user_id: BYOM-resolution scope (model owner)
Returns:
CompressionResult
@@ -123,11 +140,17 @@ class CompressionOrchestrator:
else model_id
)
# Get provider and API key for compression model
provider = get_provider_from_model_id(compression_model)
# Use model-owner scope so provider/api_key resolves to the
# owner's BYOM record (shared-agent dispatch).
caller_user_id = user_id
if caller_user_id is None and isinstance(decoded_token, dict):
caller_user_id = decoded_token.get("sub")
registry_user_id = model_user_id or caller_user_id
provider = get_provider_from_model_id(
compression_model, user_id=registry_user_id
)
api_key = get_api_key_for_provider(provider)
# Create compression LLM
compression_llm = LLMCreator.create_llm(
provider,
api_key=api_key,
@@ -135,6 +158,7 @@ class CompressionOrchestrator:
decoded_token=decoded_token,
model_id=compression_model,
agent_id=conversation.get("agent_id"),
model_user_id=registry_user_id,
)
# Create compression service with DB update capability
@@ -167,9 +191,12 @@ class CompressionOrchestrator:
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
# Reload conversation with updated metadata
# Reload under caller (conversation is owned by caller).
reload_user_id = caller_user_id
if reload_user_id is None and isinstance(decoded_token, dict):
reload_user_id = decoded_token.get("sub")
conversation = self.conversation_service.get_conversation(
conversation_id, user_id=decoded_token.get("sub")
conversation_id, user_id=reload_user_id
)
# Get compressed context
@@ -192,16 +219,21 @@ class CompressionOrchestrator:
model_id: str,
decoded_token: Dict[str, Any],
current_conversation: Optional[Dict[str, Any]] = None,
model_user_id: Optional[str] = None,
) -> CompressionResult:
"""
Perform compression during tool execution.
Args:
conversation_id: Conversation ID
user_id: User ID
user_id: Caller's user id — used for conversation access checks
model_id: Model ID
decoded_token: User token
current_conversation: Pre-loaded conversation (optional)
model_user_id: BYOM-resolution scope (model owner). For
shared-agent dispatch this is the agent owner; defaults
to ``user_id`` so built-in / caller-owned models are
unaffected.
Returns:
CompressionResult
@@ -223,7 +255,12 @@ class CompressionOrchestrator:
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
conversation_id,
conversation,
model_id,
decoded_token,
user_id=user_id,
model_user_id=model_user_id,
)
except Exception as e:
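A hedged usage sketch of the caller/owner split this class now threads (the ``orchestrator`` variable and user ids are illustrative; keyword names match the signature above):

# Shared agent: the conversation belongs to the caller, the BYOM record to
# the agent owner. Access checks use user_id; registry lookups use model_user_id.
result = orchestrator.compress_if_needed(
    conversation_id=conversation_id,
    user_id="caller-sub",           # loads/reloads the caller's conversation
    model_id=model_id,              # BYOM registry UUID or built-in id
    decoded_token=decoded_token,
    model_user_id="owner-sub",      # resolves provider / api key / context window
)
# Caller-owned BYOM or built-in: omit model_user_id and the orchestrator
# falls back to user_id (registry_user_id = model_user_id or user_id).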

View File

@@ -106,8 +106,13 @@ class CompressionService:
f"using model {self.model_id}"
)
# See note in conversation_service.py: ``self.model_id`` is
# the registry id (UUID for BYOM); the LLM's own model_id is
# what the provider's API actually expects.
response = self.llm.gen(
model=self.model_id, messages=messages, max_tokens=4000
model=getattr(self.llm, "model_id", None) or self.model_id,
messages=messages,
max_tokens=4000,
)
# Extract summary from response

View File

@@ -30,6 +30,7 @@ class CompressionThresholdChecker:
conversation: Dict[str, Any],
model_id: str,
current_query_tokens: int = 500,
user_id: str | None = None,
) -> bool:
"""
Determine if compression is needed.
@@ -38,6 +39,8 @@ class CompressionThresholdChecker:
conversation: Full conversation document
model_id: Target model for this request
current_query_tokens: Estimated tokens for current query
user_id: Owner — needed so per-user BYOM custom-model UUIDs
resolve when looking up the context window.
Returns:
True if tokens >= threshold% of context window
@@ -48,7 +51,7 @@ class CompressionThresholdChecker:
total_tokens += current_query_tokens
# Get context window limit for model
context_limit = get_token_limit(model_id)
context_limit = get_token_limit(model_id, user_id=user_id)
# Calculate threshold
threshold = int(context_limit * self.threshold_percentage)
@@ -73,20 +76,24 @@ class CompressionThresholdChecker:
logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
return False
def check_message_tokens(self, messages: list, model_id: str) -> bool:
def check_message_tokens(
self, messages: list, model_id: str, user_id: str | None = None
) -> bool:
"""
Check if message list exceeds threshold.
Args:
messages: List of message dicts
model_id: Target model
user_id: Owner — needed so per-user BYOM custom-model UUIDs
resolve when looking up the context window.
Returns:
True if at or above threshold
"""
try:
current_tokens = TokenCounter.count_message_tokens(messages)
context_limit = get_token_limit(model_id)
context_limit = get_token_limit(model_id, user_id=user_id)
threshold = int(context_limit * self.threshold_percentage)
if current_tokens >= threshold:
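A worked example of why ``user_id`` matters here (numbers are illustrative; the threshold percentage and the fallback limit depend on deployment settings):

# BYOM record registered with capabilities.context_window = 8000:
#   context_limit = get_token_limit(model_id, user_id=user_id)   # -> 8000
#   threshold     = int(8000 * 0.8)                              # -> 6400 at an 80% threshold
# Without user_id the custom-model UUID does not resolve, the limit falls
# back to the deployment default (e.g. 128k), and compression would only
# trigger long after the real 8k upstream window had overflowed.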

View File

@@ -136,8 +136,14 @@ class ConversationService:
},
]
# ``model_id`` here is the registry id (a UUID for BYOM
# records). The LLM's own ``model_id`` is the upstream name
# LLMCreator resolved at construction time — that's what
# the provider's API expects. Built-ins are unaffected.
completion = llm.gen(
model=model_id, messages=messages_summary, max_tokens=500
model=getattr(llm, "model_id", None) or model_id,
messages=messages_summary,
max_tokens=500,
)
if not completion or not completion.strip():

View File

@@ -121,6 +121,8 @@ class StreamProcessor:
self.agent_id = self.data.get("agent_id")
self.agent_key = None
self.model_id: Optional[str] = None
# BYOM-resolution scope, set by _validate_and_set_model.
self.model_user_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
self.conversation_service
@@ -191,16 +193,23 @@ class StreamProcessor:
for query in conversation.get("queries", [])
]
else:
# model_user_id keeps history trim aligned with the BYOM's
# actual context window instead of the default 128k.
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")), model_id=self.model_id
json.loads(self.data.get("history", "[]")),
model_id=self.model_id,
user_id=self.model_user_id,
)
def _handle_compression(self, conversation: Dict[str, Any]):
"""Handle conversation compression logic using orchestrator."""
try:
# initial_user_id for conversation access; model_user_id
# for BYOM context-window / provider lookups.
result = self.compression_orchestrator.compress_if_needed(
conversation_id=self.conversation_id,
user_id=self.initial_user_id,
model_user_id=self.model_user_id,
model_id=self.model_id,
decoded_token=self.decoded_token,
)
@@ -284,11 +293,18 @@ class StreamProcessor:
from application.core.model_settings import ModelRegistry
requested_model = self.data.get("model_id")
# Caller picks from their own BYOM layer; agent defaults resolve
# under the owner's layer (shared agents have caller != owner).
caller_user_id = self.initial_user_id
owner_user_id = self.agent_config.get("user_id") or caller_user_id
if requested_model:
if not validate_model_id(requested_model):
if not validate_model_id(requested_model, user_id=caller_user_id):
registry = ModelRegistry.get_instance()
available_models = [m.id for m in registry.get_enabled_models()]
available_models = [
m.id
for m in registry.get_enabled_models(user_id=caller_user_id)
]
raise ValueError(
f"Invalid model_id '{requested_model}'. "
f"Available models: {', '.join(available_models[:5])}"
@@ -299,12 +315,17 @@ class StreamProcessor:
)
)
self.model_id = requested_model
self.model_user_id = caller_user_id
else:
agent_default_model = self.agent_config.get("default_model_id", "")
if agent_default_model and validate_model_id(agent_default_model):
if agent_default_model and validate_model_id(
agent_default_model, user_id=owner_user_id
):
self.model_id = agent_default_model
self.model_user_id = owner_user_id
else:
self.model_id = get_default_model_id()
self.model_user_id = None
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control."""
@@ -514,6 +535,10 @@ class StreamProcessor:
"allow_system_prompt_override": self._agent_data.get(
"allow_system_prompt_override", False
),
# Owner identity — _validate_and_set_model reads this to
# resolve owner-stored BYOM default_model_id against the
# owner's per-user model layer rather than the caller's.
"user_id": self._agent_data.get("user"),
}
)
@@ -561,7 +586,13 @@ class StreamProcessor:
def _configure_retriever(self):
"""Assemble retriever config with precedence: request > agent > default."""
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
# BYOM scope: owner for shared-agent BYOM, caller for own BYOM,
# None for built-ins. Without ``user_id`` here, the doc budget
# falls back to settings.DEFAULT_LLM_TOKEN_LIMIT and overfills
# the upstream context window for any small (e.g. 8k/32k) BYOM.
doc_token_limit = calculate_doc_token_budget(
model_id=self.model_id, user_id=self.model_user_id
)
# Start with defaults
retriever_name = "classic"
@@ -612,6 +643,7 @@ class StreamProcessor:
chunks=self.retriever_config["chunks"],
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
model_id=self.model_id,
model_user_id=self.model_user_id,
user_api_key=self.agent_config["user_api_key"],
agent_id=self.agent_id,
decoded_token=self.decoded_token,
@@ -903,6 +935,11 @@ class StreamProcessor:
agent_config = state["agent_config"]
model_id = agent_config.get("model_id")
# BYOM scope captured at initial dispatch. None for built-ins or
# caller-owned BYOM where decoded_token['sub'] is already the
# right scope; non-None for shared-agent owner BYOM where the
# caller's identity differs from the model owner's.
model_user_id = agent_config.get("model_user_id")
llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
api_key = agent_config.get("api_key")
user_api_key = agent_config.get("user_api_key")
@@ -920,6 +957,7 @@ class StreamProcessor:
decoded_token=self.decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
tool_executor = ToolExecutor(
@@ -949,6 +987,7 @@ class StreamProcessor:
"endpoint": "stream",
"llm_name": llm_name,
"model_id": model_id,
"model_user_id": model_user_id,
"api_key": system_api_key,
"agent_id": agent_id,
"user_api_key": user_api_key,
@@ -971,6 +1010,15 @@ class StreamProcessor:
# Store config for the route layer
self.model_id = model_id
# Mirror ``model_user_id`` back onto the processor so the route
# layer (StreamResource) reads the owner scope captured at
# initial dispatch. Without this, ``processor.model_user_id``
# stays at the __init__ default (None) and complete_stream
# falls back to the caller's sub: the post-resume title-LLM
# save misses the owner's BYOM layer, and any second tool
# pause persists ``model_user_id=None`` — losing owner scope
# for every subsequent resume of this conversation.
self.model_user_id = model_user_id
self.agent_id = agent_id
self.agent_config["user_api_key"] = user_api_key
self.conversation_id = conversation_id
@@ -1022,8 +1070,11 @@ class StreamProcessor:
tools_data=tools_data,
)
# Use the user_id that resolved the model so owner-scoped BYOM
# records dispatch correctly on shared-agent requests.
model_user_id = getattr(self, "model_user_id", self.initial_user_id)
provider = (
get_provider_from_model_id(self.model_id)
get_provider_from_model_id(self.model_id, user_id=model_user_id)
if self.model_id
else settings.LLM_PROVIDER
)
@@ -1048,6 +1099,8 @@ class StreamProcessor:
model_id=self.model_id,
agent_id=self.agent_id,
backup_models=backup_models,
# Owner-scope on shared-agent BYOM dispatch.
model_user_id=model_user_id,
)
llm_handler = LLMHandlerCreator.create_handler(
provider if provider else "default"
@@ -1070,6 +1123,7 @@ class StreamProcessor:
"endpoint": "stream",
"llm_name": provider or settings.LLM_PROVIDER,
"model_id": self.model_id,
"model_user_id": self.model_user_id,
"api_key": system_api_key,
"agent_id": self.agent_id,
"user_api_key": self.agent_config["user_api_key"],
@@ -1097,6 +1151,7 @@ class StreamProcessor:
"doc_token_limit", 50000
),
"model_id": self.model_id,
"model_user_id": self.model_user_id,
"user_api_key": self.agent_config["user_api_key"],
"agent_id": self.agent_id,
"llm_name": provider or settings.LLM_PROVIDER,

View File

@@ -1,18 +1,135 @@
from flask import current_app, jsonify, make_response
"""Model routes.
- ``GET /api/models`` — list available models for the current user.
Combines the built-in catalog with the user's BYOM records.
- ``GET/POST/PATCH/DELETE /api/user/models[/<id>]`` — CRUD for the
user's own OpenAI-compatible model registrations (BYOM).
- ``POST /api/user/models/<id>/test`` — sanity-check the upstream
endpoint with a tiny request.
Every BYOM endpoint is user-scoped at the repository layer
(every query filters on ``user_id`` from ``request.decoded_token``).
"""
from __future__ import annotations
import logging
import requests
from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource
from application.core.model_settings import ModelRegistry
from application.api import api
from application.core.model_registry import ModelRegistry
from application.security.safe_url import (
UnsafeUserUrlError,
pinned_post,
validate_user_base_url,
)
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
models_ns = Namespace("models", description="Available models", path="/api")
_CONTEXT_WINDOW_MIN = 1_000
_CONTEXT_WINDOW_MAX = 10_000_000
def _user_id_or_401():
decoded_token = request.decoded_token
if not decoded_token:
return None, make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
if not user_id:
return None, make_response(jsonify({"success": False}), 401)
return user_id, None
def _normalize_capabilities(raw) -> dict:
"""Coerce + bound the user-supplied capabilities payload."""
raw = raw or {}
out = {}
if "supports_tools" in raw:
out["supports_tools"] = bool(raw["supports_tools"])
if "supports_structured_output" in raw:
out["supports_structured_output"] = bool(raw["supports_structured_output"])
if "supports_streaming" in raw:
out["supports_streaming"] = bool(raw["supports_streaming"])
if "attachments" in raw:
atts = raw["attachments"] or []
if not isinstance(atts, list):
raise ValueError("'capabilities.attachments' must be a list")
coerced = [str(a) for a in atts]
# Reject unknown aliases at the API boundary so bad payloads
# never reach the registry layer (where lenient expansion just
# drops them). Raw MIME types (containing ``/``) pass through
# unchanged for parity with the built-in YAML schema.
from application.core.model_yaml import builtin_attachment_aliases
aliases = builtin_attachment_aliases()
for entry in coerced:
if "/" in entry:
continue
if entry not in aliases:
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
raise ValueError(
f"unknown attachment alias '{entry}' in "
f"'capabilities.attachments'. Valid aliases: {valid}, "
f"or use a raw MIME type like 'image/png'."
)
out["attachments"] = coerced
if "context_window" in raw:
try:
cw = int(raw["context_window"])
except (TypeError, ValueError):
raise ValueError("'capabilities.context_window' must be an integer")
if not (_CONTEXT_WINDOW_MIN <= cw <= _CONTEXT_WINDOW_MAX):
raise ValueError(
f"'capabilities.context_window' must be between "
f"{_CONTEXT_WINDOW_MIN} and {_CONTEXT_WINDOW_MAX}"
)
out["context_window"] = cw
return out
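An illustrative round-trip through the validator (payload values are made-up, and "image" is assumed to be a defined alias in the built-in YAML schema):

# _normalize_capabilities({
#     "supports_tools": 1,                          # truthy -> True
#     "attachments": ["image", "application/pdf"],  # alias kept if defined; raw MIME passes
#     "context_window": "8000",                     # coerced to int, bounds-checked
# })
# -> {"supports_tools": True,
#     "attachments": ["image", "application/pdf"],
#     "context_window": 8000}
# An unknown alias (no "/") raises ValueError, as does a context_window
# outside 1_000..10_000_000; keys that are absent are simply not emitted.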
def _row_to_response(row: dict) -> dict:
"""Wire-format projection — never includes the API key."""
return {
"id": str(row["id"]),
"upstream_model_id": row["upstream_model_id"],
"display_name": row["display_name"],
"description": row.get("description") or "",
"base_url": row["base_url"],
"capabilities": row.get("capabilities") or {},
"enabled": bool(row.get("enabled", True)),
"source": "user",
}
@models_ns.route("/models")
class ModelsListResource(Resource):
def get(self):
"""Get list of available models with their capabilities."""
"""Get list of available models with their capabilities.
When the request is authenticated, the response includes the
user's own BYOM registrations alongside the built-in catalog.
"""
try:
user_id = None
decoded_token = getattr(request, "decoded_token", None)
if decoded_token:
user_id = decoded_token.get("sub")
registry = ModelRegistry.get_instance()
models = registry.get_enabled_models()
models = registry.get_enabled_models(user_id=user_id)
response = {
"models": [model.to_dict() for model in models],
@@ -23,3 +140,382 @@ class ModelsListResource(Resource):
current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
return make_response(jsonify(response), 200)
@models_ns.route("/user/models")
class UserModelsCollectionResource(Resource):
@api.doc(description="List the current user's BYOM custom models")
def get(self):
user_id, err = _user_id_or_401()
if err:
return err
try:
with db_readonly() as conn:
rows = UserCustomModelsRepository(conn).list_for_user(user_id)
return make_response(
jsonify({"models": [_row_to_response(r) for r in rows]}), 200
)
except Exception as e:
current_app.logger.error(
f"Error listing user custom models: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
@api.doc(description="Register a new BYOM custom model")
def post(self):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
missing = check_required_fields(
data,
["upstream_model_id", "display_name", "base_url", "api_key"],
)
if missing:
return missing
# SECURITY: reject blank api_key — would leak instance API key
# to the user-supplied base_url via LLMCreator fallback.
for required_nonblank in (
"upstream_model_id",
"display_name",
"base_url",
"api_key",
):
value = data.get(required_nonblank)
if not isinstance(value, str) or not value.strip():
return make_response(
jsonify(
{
"success": False,
"error": f"'{required_nonblank}' must be a non-empty string",
}
),
400,
)
# SSRF guard at create time. Re-runs at dispatch time (LLMCreator)
# as defense in depth against DNS rebinding and pre-guard rows.
try:
validate_user_base_url(data["base_url"])
except UnsafeUserUrlError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
try:
capabilities = _normalize_capabilities(data.get("capabilities"))
except ValueError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
try:
with db_session() as conn:
row = UserCustomModelsRepository(conn).create(
user_id=user_id,
upstream_model_id=data["upstream_model_id"],
display_name=data["display_name"],
description=data.get("description") or "",
base_url=data["base_url"],
api_key_plaintext=data["api_key"],
capabilities=capabilities,
enabled=bool(data.get("enabled", True)),
)
except Exception as e:
current_app.logger.error(
f"Error creating user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
ModelRegistry.invalidate_user(user_id)
return make_response(jsonify(_row_to_response(row)), 201)
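A minimal client-side sketch of registering a model against this endpoint (host, token, and field values are illustrative, not taken from this PR):

import requests

resp = requests.post(
    "http://localhost:7091/api/user/models",    # assumed local API base
    headers={"Authorization": "Bearer <jwt>"},  # whatever auth populates decoded_token
    json={
        "upstream_model_id": "mistral-large-latest",
        "display_name": "My Mistral",
        "base_url": "https://api.example.com/v1",
        "api_key": "sk-...",
        "capabilities": {"supports_tools": True, "context_window": 8000},
        "enabled": True,
    },
    timeout=10,
)
# 201 -> the stored record (the api key is never echoed back); 400 -> blank
# required field, unsafe base_url, or bad capabilities; 401 -> missing token.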
@models_ns.route("/user/models/<string:model_id>")
class UserModelResource(Resource):
@api.doc(description="Get one BYOM custom model")
def get(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
try:
with db_readonly() as conn:
row = UserCustomModelsRepository(conn).get(model_id, user_id)
except Exception as e:
current_app.logger.error(
f"Error fetching user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
if row is None:
return make_response(jsonify({"success": False}), 404)
return make_response(jsonify(_row_to_response(row)), 200)
@api.doc(description="Update a BYOM custom model (partial)")
def patch(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
# Reject present-but-blank values for fields where blank doesn't
# mean "no change". (The api_key special case — blank means "keep
# existing" — is handled below.)
for required_nonblank in (
"upstream_model_id",
"display_name",
"base_url",
):
if required_nonblank in data:
value = data[required_nonblank]
if not isinstance(value, str) or not value.strip():
return make_response(
jsonify(
{
"success": False,
"error": f"'{required_nonblank}' cannot be blank",
}
),
400,
)
if "base_url" in data and data["base_url"]:
try:
validate_user_base_url(data["base_url"])
except UnsafeUserUrlError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
update_fields: dict = {}
for k in (
"upstream_model_id",
"display_name",
"description",
"base_url",
"enabled",
):
if k in data:
update_fields[k] = data[k]
if "capabilities" in data:
try:
update_fields["capabilities"] = _normalize_capabilities(
data["capabilities"]
)
except ValueError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
# PATCH semantics: blank/missing api_key → keep the existing
# ciphertext; non-empty api_key → re-encrypt and replace.
if data.get("api_key"):
update_fields["api_key_plaintext"] = data["api_key"]
if not update_fields:
return make_response(
jsonify({"success": False, "error": "no updatable fields"}), 400
)
try:
with db_session() as conn:
ok = UserCustomModelsRepository(conn).update(
model_id, user_id, update_fields
)
except Exception as e:
current_app.logger.error(
f"Error updating user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
if not ok:
return make_response(jsonify({"success": False}), 404)
ModelRegistry.invalidate_user(user_id)
with db_readonly() as conn:
row = UserCustomModelsRepository(conn).get(model_id, user_id)
return make_response(jsonify(_row_to_response(row)), 200)
@api.doc(description="Delete a BYOM custom model")
def delete(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
try:
with db_session() as conn:
ok = UserCustomModelsRepository(conn).delete(model_id, user_id)
except Exception as e:
current_app.logger.error(
f"Error deleting user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
if not ok:
return make_response(jsonify({"success": False}), 404)
ModelRegistry.invalidate_user(user_id)
return make_response(jsonify({"success": True}), 200)
def _run_connection_test(
base_url: str, api_key: str, upstream_model_id: str
):
"""Send a 1-token chat-completion to verify a BYOM endpoint.
Returns ``(body, http_status)``. Upstream errors return 200 with
``ok=False`` so the UI can render inline errors; only local SSRF
rejection returns 400.
"""
url = base_url.rstrip("/") + "/chat/completions"
payload = {
"model": upstream_model_id,
"messages": [{"role": "user", "content": "hi"}],
"max_tokens": 1,
"stream": False,
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
try:
# pinned_post closes the DNS-rebinding window. Redirects off
# because 3xx could bounce to an internal address (the SSRF
# guard only validates the supplied URL).
resp = pinned_post(
url,
json=payload,
headers=headers,
timeout=5,
allow_redirects=False,
)
except UnsafeUserUrlError as e:
return {"ok": False, "error": str(e)}, 400
except requests.RequestException as e:
return {"ok": False, "error": f"connection error: {e}"}, 200
if 300 <= resp.status_code < 400:
return (
{
"ok": False,
"error": (
f"upstream returned HTTP {resp.status_code} "
"redirect; refusing to follow"
),
},
200,
)
if resp.status_code >= 400:
# Cap and only reflect JSON to avoid body-exfil via non-API responses.
content_type = (resp.headers.get("Content-Type") or "").lower()
if "application/json" in content_type:
text = (resp.text or "")[:500]
error_msg = f"upstream returned HTTP {resp.status_code}: {text}"
else:
error_msg = f"upstream returned HTTP {resp.status_code}"
return {"ok": False, "error": error_msg}, 200
return {"ok": True}, 200
@models_ns.route("/user/models/test")
class UserModelTestPayloadResource(Resource):
@api.doc(
description=(
"Test an arbitrary BYOM payload (display_name / model id / "
"base_url / api_key) without saving. Used by the UI's 'Test "
"connection' button so the user can validate before they "
"Save. Same SSRF guard, same 1-token request, same 5s "
"timeout as the by-id variant."
)
)
def post(self):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
missing = check_required_fields(
data, ["base_url", "api_key", "upstream_model_id"]
)
if missing:
return missing
body, status = _run_connection_test(
data["base_url"], data["api_key"], data["upstream_model_id"]
)
return make_response(jsonify(body), status)
@models_ns.route("/user/models/<string:model_id>/test")
class UserModelTestResource(Resource):
@api.doc(
description=(
"Test a saved BYOM record. Defaults to the stored "
"base_url / upstream_model_id / encrypted api_key, but "
"any of those can be overridden via the request body so "
"the UI can test in-flight edits before saving. Used by "
"the 'Test connection' button in edit mode."
)
)
def post(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
# Per-field overrides; blank/missing falls back to stored value.
override_base_url = (data.get("base_url") or "").strip() or None
override_upstream_model_id = (
data.get("upstream_model_id") or ""
).strip() or None
override_api_key = (data.get("api_key") or "").strip() or None
try:
with db_readonly() as conn:
repo = UserCustomModelsRepository(conn)
row = repo.get(model_id, user_id)
if row is None:
return make_response(jsonify({"success": False}), 404)
stored_api_key = (
repo._decrypt_api_key(
row.get("api_key_encrypted", ""), user_id
)
if not override_api_key
else None
)
except Exception as e:
current_app.logger.error(
f"Error loading user custom model for test: {e}", exc_info=True
)
return make_response(
jsonify({"ok": False, "error": "internal error loading model"}),
500,
)
api_key = override_api_key or stored_api_key
if not api_key:
return make_response(
jsonify(
{
"ok": False,
"error": (
"Stored API key could not be decrypted. The "
"encryption secret may have rotated. Re-save "
"the model with the API key to recover."
),
}
),
400,
)
base_url = override_base_url or row["base_url"]
upstream_model_id = (
override_upstream_model_id or row["upstream_model_id"]
)
body, status = _run_connection_test(
base_url, api_key, upstream_model_id
)
return make_response(jsonify(body), status)

View File

@@ -198,8 +198,14 @@ def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
return normalized_nodes
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
"""Validate workflow graph structure."""
def validate_workflow_structure(
nodes: List[Dict], edges: List[Dict], user_id: str | None = None
) -> List[str]:
"""Validate workflow graph structure.
``user_id`` is required so per-user BYOM custom-model UUIDs resolve
when checking each agent node's structured-output capability.
"""
errors = []
if not nodes:
@@ -343,7 +349,7 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
model_id = raw_config.get("model_id")
if has_json_schema and isinstance(model_id, str) and model_id.strip():
capabilities = get_model_capabilities(model_id.strip())
capabilities = get_model_capabilities(model_id.strip(), user_id=user_id)
if capabilities and not capabilities.get("supports_structured_output", False):
errors.append(
f"Agent node '{agent_title}' selected model does not support structured output"
@@ -389,7 +395,9 @@ class WorkflowList(Resource):
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
validation_errors = validate_workflow_structure(
nodes_data, edges_data, user_id=user_id
)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
@@ -451,7 +459,9 @@ class WorkflowDetail(Resource):
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
validation_errors = validate_workflow_structure(
nodes_data, edges_data, user_id=user_id
)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors

View File

@@ -213,6 +213,7 @@ def _stream_response(
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)
@@ -257,6 +258,7 @@ def _non_stream_response(
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)

View File

@@ -153,7 +153,7 @@ def after_request(response: Response) -> Response:
"""Add CORS headers for the pure Flask development entrypoint."""
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
return response

View File

@@ -24,7 +24,7 @@ asgi_app = Starlette(
Middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "Mcp-Session-Id"],
expose_headers=["Mcp-Session-Id"],
),

View File

@@ -1,266 +0,0 @@
"""
Model configurations for all supported LLM providers.
"""
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
)
# Base image attachment types supported by most vision-capable LLMs
IMAGE_ATTACHMENTS = [
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
# PDF excluded: most OpenAI-compatible endpoints don't support native PDF uploads.
# When excluded, PDFs are synthetically processed by converting pages to images.
OPENAI_ATTACHMENTS = IMAGE_ATTACHMENTS
GOOGLE_ATTACHMENTS = ["application/pdf"] + IMAGE_ATTACHMENTS
ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
NOVITA_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENAI_MODELS = [
AvailableModel(
id="gpt-5.1",
provider=ModelProvider.OPENAI,
display_name="GPT-5.1",
description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="gpt-5-mini",
provider=ModelProvider.OPENAI,
display_name="GPT-5 Mini",
description="Faster, cost-effective variant of GPT-5.1",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=200000,
),
)
]
ANTHROPIC_MODELS = [
AvailableModel(
id="claude-3-5-sonnet-20241022",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3.5 Sonnet (Latest)",
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="claude-3-5-sonnet",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3.5 Sonnet",
description="Balanced performance and capability",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="claude-3-opus",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3 Opus",
description="Most capable Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="claude-3-haiku",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3 Haiku",
description="Fastest Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
]
GOOGLE_MODELS = [
AvailableModel(
id="gemini-flash-latest",
provider=ModelProvider.GOOGLE,
display_name="Gemini Flash (Latest)",
description="Latest experimental Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=int(1e6),
),
),
AvailableModel(
id="gemini-flash-lite-latest",
provider=ModelProvider.GOOGLE,
display_name="Gemini Flash Lite (Latest)",
description="Fast with huge context window",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=int(1e6),
),
),
AvailableModel(
id="gemini-3-pro-preview",
provider=ModelProvider.GOOGLE,
display_name="Gemini 3 Pro",
description="Most capable Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=2000000,
),
),
]
GROQ_MODELS = [
AvailableModel(
id="llama-3.3-70b-versatile",
provider=ModelProvider.GROQ,
display_name="Llama 3.3 70B",
description="Latest Llama model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
AvailableModel(
id="openai/gpt-oss-120b",
provider=ModelProvider.GROQ,
display_name="GPT-OSS 120B",
description="Open-source GPT model optimized for speed",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
]
OPENROUTER_MODELS = [
AvailableModel(
id="qwen/qwen3-coder:free",
provider=ModelProvider.OPENROUTER,
display_name="Qwen 3 Coder",
description="Latest Qwen model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
AvailableModel(
id="google/gemma-3-27b-it:free",
provider=ModelProvider.OPENROUTER,
display_name="Gemma 3 27B",
description="Latest Gemma model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
]
NOVITA_MODELS = [
AvailableModel(
id="moonshotai/kimi-k2.5",
provider=ModelProvider.NOVITA,
display_name="Kimi K2.5",
description="MoE model with function calling, structured output, reasoning, and vision",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=NOVITA_ATTACHMENTS,
context_window=262144,
),
),
AvailableModel(
id="zai-org/glm-5",
provider=ModelProvider.NOVITA,
display_name="GLM-5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=202800,
),
),
AvailableModel(
id="minimax/minimax-m2.5",
provider=ModelProvider.NOVITA,
display_name="MiniMax M2.5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=204800,
),
),
]
AZURE_OPENAI_MODELS = [
AvailableModel(
id="azure-gpt-4",
provider=ModelProvider.AZURE_OPENAI,
display_name="Azure OpenAI GPT-4",
description="Azure-hosted GPT model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=8192,
),
),
]
def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
"""Create a custom OpenAI-compatible model (e.g., LM Studio, Ollama)."""
return AvailableModel(
id=model_name,
provider=ModelProvider.OPENAI,
display_name=model_name,
description=f"Custom OpenAI-compatible model at {base_url}",
base_url=base_url,
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
),
)

View File

@@ -0,0 +1,385 @@
"""Layered model registry.
Loads model catalogs from YAML files (built-in + operator-supplied),
groups them by provider name, then for each registered provider plugin
calls ``get_models`` to produce the final per-provider model list.
End-user BYOM (per-user model records in Postgres) is layered on top:
when a lookup arrives with a ``user_id``, the registry consults a
per-user cache first (loaded from the ``user_custom_models`` table on
miss) and falls through to the built-in catalog.
Cross-process invalidation: ``ModelRegistry`` is a per-process
singleton, so a CRUD write only evicts the cache in the process that
served it. Other gunicorn workers and Celery workers would otherwise
keep using a deleted/disabled/key-rotated BYOM record indefinitely.
``invalidate_user`` therefore both drops the local layer *and* bumps a
Redis-side version counter; other processes notice the bump on their
next access (after the local TTL window) and reload from Postgres. If
Redis is unreachable the per-process TTL still bounds staleness — pure
TTL semantics, no regression.
"""
from __future__ import annotations
import logging
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from application.core.model_settings import AvailableModel
from application.core.model_yaml import (
BUILTIN_MODELS_DIR,
ProviderCatalog,
load_model_yamls,
)
logger = logging.getLogger(__name__)
_USER_CACHE_TTL_SECONDS = 60.0
_USER_VERSION_KEY_PREFIX = "byom:registry_version:"
class ModelRegistry:
"""Singleton registry of available models."""
_instance: Optional["ModelRegistry"] = None
_initialized: bool = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not ModelRegistry._initialized:
self.models: Dict[str, AvailableModel] = {}
self.default_model_id: Optional[str] = None
# Per-user BYOM cache. Each entry is
# ``(layer, version_at_load, loaded_at_monotonic)``:
# * ``layer`` — {model_id: AvailableModel}
# * ``version_at_load`` — Redis-side counter snapshot at
# reload time, or ``None`` if Redis was unreachable
# * ``loaded_at_monotonic`` — for TTL bookkeeping
# Populated lazily, evicted by TTL + cross-process
# invalidation (see ``invalidate_user``).
self._user_models: Dict[
str,
Tuple[Dict[str, AvailableModel], Optional[int], float],
] = {}
self._load_models()
ModelRegistry._initialized = True
@classmethod
def get_instance(cls) -> "ModelRegistry":
return cls()
@classmethod
def reset(cls) -> None:
"""Clear the singleton. Intended for test fixtures."""
cls._instance = None
cls._initialized = False
@classmethod
def invalidate_user(cls, user_id: str) -> None:
"""Drop the cached per-user model layer for ``user_id``.
Called by the BYOM REST routes after every create/update/delete.
Two effects:
* Local: pop the entry from this process's cache so the next
lookup re-reads from Postgres immediately.
* Cross-process: ``INCR`` a Redis-side version counter for this
user. Other gunicorn/Celery processes notice the counter
changed on their next TTL-driven recheck (see
``_user_models_for``) and reload. If Redis is unreachable we
log and continue — local invalidation still happened, and
peers fall back to TTL-only staleness bounds.
"""
if cls._instance is not None:
cls._instance._user_models.pop(user_id, None)
try:
from application.cache import get_redis_instance
client = get_redis_instance()
if client is not None:
client.incr(_USER_VERSION_KEY_PREFIX + user_id)
except Exception as e:
logger.warning(
"BYOM invalidate: failed to publish version bump for "
"user %s (Redis unreachable?): %s",
user_id,
e,
)
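A sketch of the cross-process contract the docstring describes (process names and the user id are illustrative):

# Web worker A handles a BYOM create/update/delete:
ModelRegistry.invalidate_user("user-123")
#   -> pops A's cached layer, so A's next lookup re-reads Postgres immediately
#   -> INCRs byom:registry_version:user-123 in Redis
# Celery worker B keeps serving its cached layer until its 60s TTL lapses;
# on the next access it compares the Redis counter with the version captured
# at load time, sees the bump, and reloads from Postgres. If Redis is down,
# B still reloads once the TTL lapses (TTL-only staleness bound).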
@classmethod
def _read_user_version(cls, user_id: str) -> Optional[int]:
"""Return the Redis-side invalidation counter for ``user_id``.
``0`` if the key has never been bumped; ``None`` if Redis is
unreachable or the read failed (callers fall back to TTL-only
staleness in that case).
"""
try:
from application.cache import get_redis_instance
client = get_redis_instance()
if client is None:
return None
raw = client.get(_USER_VERSION_KEY_PREFIX + user_id)
if raw is None:
return 0
return int(raw)
except Exception:
return None
def _load_models(self) -> None:
from pathlib import Path
from application.core.settings import settings
from application.llm.providers import ALL_PROVIDERS
directories = [BUILTIN_MODELS_DIR]
operator_dir = getattr(settings, "MODELS_CONFIG_DIR", None)
if operator_dir:
op_path = Path(operator_dir)
if not op_path.exists():
logger.warning(
"MODELS_CONFIG_DIR=%s does not exist; no operator "
"model YAMLs will be loaded.",
operator_dir,
)
elif not op_path.is_dir():
logger.warning(
"MODELS_CONFIG_DIR=%s is not a directory; no operator "
"model YAMLs will be loaded.",
operator_dir,
)
else:
directories.append(op_path)
catalogs = load_model_yamls(directories)
# Validate every catalog targets a known plugin before doing any
# registry work, so an unknown provider name in YAML aborts boot
# with a clear error.
plugin_names = {p.name for p in ALL_PROVIDERS}
for c in catalogs:
if c.provider not in plugin_names:
raise ValueError(
f"{c.source_path}: YAML declares unknown provider "
f"{c.provider!r}; no Provider plugin is registered "
f"under that name. Known: {sorted(plugin_names)}"
)
catalogs_by_provider: Dict[str, List[ProviderCatalog]] = defaultdict(list)
for c in catalogs:
catalogs_by_provider[c.provider].append(c)
self.models.clear()
for provider in ALL_PROVIDERS:
if not provider.is_enabled(settings):
continue
for model in provider.get_models(
settings, catalogs_by_provider.get(provider.name, [])
):
self.models[model.id] = model
self.default_model_id = self._resolve_default(settings)
logger.info(
"ModelRegistry loaded %d models, default: %s",
len(self.models),
self.default_model_id,
)
def _resolve_default(self, settings) -> Optional[str]:
if settings.LLM_NAME:
for name in self._parse_model_names(settings.LLM_NAME):
if name in self.models:
return name
if settings.LLM_NAME in self.models:
return settings.LLM_NAME
if settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
return model_id
if self.models:
return next(iter(self.models.keys()))
return None
@staticmethod
def _parse_model_names(llm_name: str) -> List[str]:
if not llm_name:
return []
return [name.strip() for name in llm_name.split(",") if name.strip()]
# Per-user (BYOM) layer
def _user_models_for(self, user_id: str) -> Dict[str, AvailableModel]:
"""Return the user's BYOM models keyed by registry id (UUID).
Loaded lazily from Postgres on first access; cached subject to
a per-process TTL (``_USER_CACHE_TTL_SECONDS``) and a Redis-
backed version counter for cross-process invalidation. The TTL
bounds staleness even when Redis is unreachable, while the
version stamp lets peers refresh without a DB read on the
common case (no invalidation since last load). Decryption
failures and DB errors yield an empty layer (logged) — the
user simply doesn't see their custom models on this request,
never a 500.
"""
cached = self._user_models.get(user_id)
now = time.monotonic()
if cached is not None:
layer, cached_version, loaded_at = cached
if (now - loaded_at) < _USER_CACHE_TTL_SECONDS:
return layer
# TTL elapsed: peek at the cross-process counter. If it
# matches what we saw at load time, no invalidation has
# happened — extend the TTL without touching Postgres. If
# Redis is unreachable (``current_version is None``) we
# fall through to a real reload, which keeps staleness
# bounded to the TTL.
current_version = self._read_user_version(user_id)
if (
current_version is not None
and cached_version is not None
and current_version == cached_version
):
self._user_models[user_id] = (layer, cached_version, now)
return layer
# Capture the counter *before* the DB read so a CRUD that lands
# mid-reload doesn't get masked: the next access will see a
# newer version and reload again.
version_before_read = self._read_user_version(user_id)
layer: Dict[str, AvailableModel] = {}
try:
from application.core.model_settings import (
ModelCapabilities,
ModelProvider,
)
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
from application.storage.db.session import db_readonly
with db_readonly() as conn:
repo = UserCustomModelsRepository(conn)
rows = repo.list_for_user(user_id)
for row in rows:
api_key = repo._decrypt_api_key(
row.get("api_key_encrypted", ""), user_id
)
if not api_key:
# SECURITY: do NOT register an unroutable BYOM
# record. If we did, LLMCreator would fall back
# to the caller-passed api_key (settings.API_KEY
# for openai_compatible) and POST it to the
# user-supplied base_url — leaking the instance
# credential to the user's chosen endpoint.
# Most likely cause is ENCRYPTION_SECRET_KEY
# having rotated; user must re-save the model.
logger.warning(
"user_custom_models: skipping model %s for "
"user %s — api_key could not be decrypted "
"(rotated ENCRYPTION_SECRET_KEY?). Re-save "
"the model to recover.",
row.get("id"),
user_id,
)
continue
caps_raw = row.get("capabilities") or {}
# Stored attachments may be aliases (``image``) or
# raw MIME types. Built-in YAML models expand at
# load time; mirror that here so downstream MIME-
# type comparisons (handlers/base.prepare_messages)
# match concrete types like ``image/png`` rather
# than the bare alias.
from application.core.model_yaml import (
expand_attachments_lenient,
)
raw_attachments = caps_raw.get("attachments", []) or []
expanded_attachments = expand_attachments_lenient(
raw_attachments,
f"user_custom_models[user={user_id}, model={row.get('id')}]",
)
caps = ModelCapabilities(
supports_tools=bool(caps_raw.get("supports_tools", False)),
supports_structured_output=bool(
caps_raw.get("supports_structured_output", False)
),
supports_streaming=bool(
caps_raw.get("supports_streaming", True)
),
supported_attachment_types=expanded_attachments,
context_window=int(
caps_raw.get("context_window") or 128000
),
)
model_id = str(row["id"])
layer[model_id] = AvailableModel(
id=model_id,
provider=ModelProvider.OPENAI_COMPATIBLE,
display_name=row["display_name"],
description=row.get("description") or "",
capabilities=caps,
enabled=bool(row.get("enabled", True)),
base_url=row["base_url"],
upstream_model_id=row["upstream_model_id"],
source="user",
api_key=api_key,
)
except Exception as e:
logger.warning(
"user_custom_models: failed to load layer for user %s: %s",
user_id,
e,
)
layer = {}
self._user_models[user_id] = (layer, version_before_read, now)
return layer
# Lookup API. ``user_id`` enables the BYOM per-user layer; without
# it, callers see only the built-in + operator catalog.
def get_model(
self, model_id: str, user_id: Optional[str] = None
) -> Optional[AvailableModel]:
if user_id:
user_layer = self._user_models_for(user_id)
if model_id in user_layer:
return user_layer[model_id]
return self.models.get(model_id)
def get_all_models(
self, user_id: Optional[str] = None
) -> List[AvailableModel]:
out = list(self.models.values())
if user_id:
out.extend(self._user_models_for(user_id).values())
return out
def get_enabled_models(
self, user_id: Optional[str] = None
) -> List[AvailableModel]:
out = [m for m in self.models.values() if m.enabled]
if user_id:
out.extend(
m for m in self._user_models_for(user_id).values() if m.enabled
)
return out
def model_exists(
self, model_id: str, user_id: Optional[str] = None
) -> bool:
if user_id and model_id in self._user_models_for(user_id):
return True
return model_id in self.models

View File

@@ -5,9 +5,16 @@ from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
# Re-exported here so existing call sites (and tests) that do
# ``from application.core.model_settings import ModelRegistry`` keep
# working. The implementation lives in ``application/core/model_registry.py``.
# Imported lazily inside ``__getattr__`` to avoid an import cycle with
# ``model_yaml`` → ``model_settings`` (this file).
class ModelProvider(str, Enum):
OPENAI = "openai"
OPENAI_COMPATIBLE = "openai_compatible"
OPENROUTER = "openrouter"
AZURE_OPENAI = "azure_openai"
ANTHROPIC = "anthropic"
@@ -41,11 +48,21 @@ class AvailableModel:
capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
enabled: bool = True
base_url: Optional[str] = None
# User-facing label distinct from dispatch provider (e.g. mistral
# routed through openai_compatible).
display_provider: Optional[str] = None
# Sent in the API call's ``model`` field; falls back to ``self.id``
# for built-ins where id IS the upstream name.
upstream_model_id: Optional[str] = None
# "builtin" for catalog YAMLs, "user" for BYOM records.
source: str = "builtin"
# Decrypted/resolved at registry-merge time. Never serialized.
api_key: Optional[str] = field(default=None, repr=False, compare=False)
def to_dict(self) -> Dict:
result = {
"id": self.id,
"provider": self.provider.value,
"provider": self.display_provider or self.provider.value,
"display_name": self.display_name,
"description": self.description,
"supported_attachment_types": self.capabilities.supported_attachment_types,
@@ -54,261 +71,21 @@ class AvailableModel:
"supports_streaming": self.capabilities.supports_streaming,
"context_window": self.capabilities.context_window,
"enabled": self.enabled,
"source": self.source,
}
if self.base_url:
result["base_url"] = self.base_url
return result
class ModelRegistry:
_instance = None
_initialized = False
def __getattr__(name):
"""Lazy re-export of ``ModelRegistry`` from ``model_registry.py``.
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
Done lazily to avoid an import cycle: ``model_registry`` imports
``model_yaml`` which imports the dataclasses from this file.
"""
if name == "ModelRegistry":
from application.core.model_registry import ModelRegistry as _MR
def __init__(self):
if not ModelRegistry._initialized:
self.models: Dict[str, AvailableModel] = {}
self.default_model_id: Optional[str] = None
self._load_models()
ModelRegistry._initialized = True
@classmethod
def get_instance(cls) -> "ModelRegistry":
return cls()
def _load_models(self):
from application.core.settings import settings
self.models.clear()
# Skip DocsGPT model if using custom OpenAI-compatible endpoint
if not settings.OPENAI_BASE_URL:
self._add_docsgpt_models(settings)
if (
settings.OPENAI_API_KEY
or (settings.LLM_PROVIDER == "openai" and settings.API_KEY)
or settings.OPENAI_BASE_URL
):
self._add_openai_models(settings)
if settings.OPENAI_API_BASE or (
settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
):
self._add_azure_openai_models(settings)
if settings.ANTHROPIC_API_KEY or (
settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
):
self._add_anthropic_models(settings)
if settings.GOOGLE_API_KEY or (
settings.LLM_PROVIDER == "google" and settings.API_KEY
):
self._add_google_models(settings)
if settings.GROQ_API_KEY or (
settings.LLM_PROVIDER == "groq" and settings.API_KEY
):
self._add_groq_models(settings)
if settings.OPEN_ROUTER_API_KEY or (
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
):
self._add_openrouter_models(settings)
if settings.NOVITA_API_KEY or (
settings.LLM_PROVIDER == "novita" and settings.API_KEY
):
self._add_novita_models(settings)
if settings.HUGGINGFACE_API_KEY or (
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
):
self._add_huggingface_models(settings)
# Default model selection
if settings.LLM_NAME:
# Parse LLM_NAME (may be comma-separated)
model_names = self._parse_model_names(settings.LLM_NAME)
# First model in the list becomes default
for model_name in model_names:
if model_name in self.models:
self.default_model_id = model_name
break
# Backward compat: try exact match if no parsed model found
if not self.default_model_id and settings.LLM_NAME in self.models:
self.default_model_id = settings.LLM_NAME
if not self.default_model_id:
if settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
self.default_model_id = model_id
break
if not self.default_model_id and self.models:
self.default_model_id = next(iter(self.models.keys()))
logger.info(
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
)
def _add_openai_models(self, settings):
from application.core.model_configs import (
OPENAI_MODELS,
create_custom_openai_model,
)
# Check if using local OpenAI-compatible endpoint (Ollama, LM Studio, etc.)
using_local_endpoint = bool(
settings.OPENAI_BASE_URL and settings.OPENAI_BASE_URL.strip()
)
if using_local_endpoint:
# When OPENAI_BASE_URL is set, ONLY register custom models from LLM_NAME
# Do NOT add standard OpenAI models (gpt-5.1, etc.)
if settings.LLM_NAME:
model_names = self._parse_model_names(settings.LLM_NAME)
for model_name in model_names:
custom_model = create_custom_openai_model(
model_name, settings.OPENAI_BASE_URL
)
self.models[model_name] = custom_model
logger.info(
f"Registered custom OpenAI model: {model_name} at {settings.OPENAI_BASE_URL}"
)
else:
# Standard OpenAI API usage - add standard models if API key is valid
if settings.OPENAI_API_KEY:
for model in OPENAI_MODELS:
self.models[model.id] = model
def _add_azure_openai_models(self, settings):
from application.core.model_configs import AZURE_OPENAI_MODELS
if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
for model in AZURE_OPENAI_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in AZURE_OPENAI_MODELS:
self.models[model.id] = model
def _add_anthropic_models(self, settings):
from application.core.model_configs import ANTHROPIC_MODELS
if settings.ANTHROPIC_API_KEY:
for model in ANTHROPIC_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
for model in ANTHROPIC_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in ANTHROPIC_MODELS:
self.models[model.id] = model
def _add_google_models(self, settings):
from application.core.model_configs import GOOGLE_MODELS
if settings.GOOGLE_API_KEY:
for model in GOOGLE_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
for model in GOOGLE_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in GOOGLE_MODELS:
self.models[model.id] = model
def _add_groq_models(self, settings):
from application.core.model_configs import GROQ_MODELS
if settings.GROQ_API_KEY:
for model in GROQ_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
for model in GROQ_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in GROQ_MODELS:
self.models[model.id] = model
def _add_openrouter_models(self, settings):
from application.core.model_configs import OPENROUTER_MODELS
if settings.OPEN_ROUTER_API_KEY:
for model in OPENROUTER_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "openrouter" and settings.LLM_NAME:
for model in OPENROUTER_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in OPENROUTER_MODELS:
self.models[model.id] = model
def _add_novita_models(self, settings):
from application.core.model_configs import NOVITA_MODELS
if settings.NOVITA_API_KEY:
for model in NOVITA_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "novita" and settings.LLM_NAME:
for model in NOVITA_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in NOVITA_MODELS:
self.models[model.id] = model
def _add_docsgpt_models(self, settings):
model_id = "docsgpt-local"
model = AvailableModel(
id=model_id,
provider=ModelProvider.DOCSGPT,
display_name="DocsGPT Model",
description="Local model",
capabilities=ModelCapabilities(
supports_tools=False,
supported_attachment_types=[],
),
)
self.models[model_id] = model
def _add_huggingface_models(self, settings):
model_id = "huggingface-local"
model = AvailableModel(
id=model_id,
provider=ModelProvider.HUGGINGFACE,
display_name="Hugging Face Model",
description="Local Hugging Face model",
capabilities=ModelCapabilities(
supports_tools=False,
supported_attachment_types=[],
),
)
self.models[model_id] = model
def _parse_model_names(self, llm_name: str) -> List[str]:
"""
Parse LLM_NAME which may contain comma-separated model names.
E.g., 'deepseek-r1:1.5b,gemma:2b' -> ['deepseek-r1:1.5b', 'gemma:2b']
"""
if not llm_name:
return []
return [name.strip() for name in llm_name.split(",") if name.strip()]
def get_model(self, model_id: str) -> Optional[AvailableModel]:
return self.models.get(model_id)
def get_all_models(self) -> List[AvailableModel]:
return list(self.models.values())
def get_enabled_models(self) -> List[AvailableModel]:
return [m for m in self.models.values() if m.enabled]
def model_exists(self, model_id: str) -> bool:
return model_id in self.models
return _MR
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -1,47 +1,59 @@
from typing import Any, Dict, Optional
from application.core.model_settings import ModelRegistry
from application.core.model_registry import ModelRegistry
def get_api_key_for_provider(provider: str) -> Optional[str]:
"""Get the appropriate API key for a provider"""
"""Get the appropriate API key for a provider.
Delegates to the provider plugin's ``get_api_key``. Falls back to the
generic ``settings.API_KEY`` for unknown providers.
"""
from application.core.settings import settings
from application.llm.providers import PROVIDERS_BY_NAME
provider_key_map = {
"openai": settings.OPENAI_API_KEY,
"openrouter": settings.OPEN_ROUTER_API_KEY,
"novita": settings.NOVITA_API_KEY,
"anthropic": settings.ANTHROPIC_API_KEY,
"google": settings.GOOGLE_API_KEY,
"groq": settings.GROQ_API_KEY,
"huggingface": settings.HUGGINGFACE_API_KEY,
"azure_openai": settings.API_KEY,
"docsgpt": None,
"llama.cpp": None,
}
provider_key = provider_key_map.get(provider)
if provider_key:
return provider_key
plugin = PROVIDERS_BY_NAME.get(provider)
if plugin is not None:
key = plugin.get_api_key(settings)
if key:
return key
return settings.API_KEY
def get_all_available_models() -> Dict[str, Dict[str, Any]]:
"""Get all available models with metadata for API response"""
def get_all_available_models(
user_id: Optional[str] = None,
) -> Dict[str, Dict[str, Any]]:
"""Get all available models with metadata for API response.
When ``user_id`` is supplied, the user's BYOM custom-model records
are merged into the result alongside the built-in catalog.
"""
registry = ModelRegistry.get_instance()
return {model.id: model.to_dict() for model in registry.get_enabled_models()}
return {
model.id: model.to_dict()
for model in registry.get_enabled_models(user_id=user_id)
}
def validate_model_id(model_id: str) -> bool:
"""Check if a model ID exists in registry"""
def validate_model_id(model_id: str, user_id: Optional[str] = None) -> bool:
"""Check if a model ID exists in registry.
``user_id`` enables resolution of per-user BYOM records (UUIDs).
Without it, only built-in catalog ids resolve.
"""
registry = ModelRegistry.get_instance()
return registry.model_exists(model_id)
return registry.model_exists(model_id, user_id=user_id)
def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
"""Get capabilities for a specific model"""
def get_model_capabilities(
model_id: str, user_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Get capabilities for a specific model.
``user_id`` enables resolution of per-user BYOM records.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
model = registry.get_model(model_id, user_id=user_id)
if model:
return {
"supported_attachment_types": model.capabilities.supported_attachment_types,
@@ -58,36 +70,68 @@ def get_default_model_id() -> str:
return registry.default_model_id
def get_provider_from_model_id(model_id: str) -> Optional[str]:
"""Get the provider name for a given model_id"""
def get_provider_from_model_id(
model_id: str, user_id: Optional[str] = None
) -> Optional[str]:
"""Get the provider name for a given model_id.
``user_id`` enables resolution of per-user BYOM records (UUIDs).
Without it, BYOM model ids return ``None`` and the caller falls
back to the deployment default.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
model = registry.get_model(model_id, user_id=user_id)
if model:
return model.provider.value
return None
def get_token_limit(model_id: str) -> int:
"""
Get context window (token limit) for a model.
Returns model's context_window or default 128000 if model not found.
def get_token_limit(model_id: str, user_id: Optional[str] = None) -> int:
"""Get context window (token limit) for a model.
Returns the model's ``context_window`` or ``DEFAULT_LLM_TOKEN_LIMIT``
if not found. ``user_id`` enables resolution of per-user BYOM records.
"""
from application.core.settings import settings
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
model = registry.get_model(model_id, user_id=user_id)
if model:
return model.capabilities.context_window
return settings.DEFAULT_LLM_TOKEN_LIMIT
def get_base_url_for_model(model_id: str) -> Optional[str]:
"""
Get the custom base_url for a specific model if configured.
Returns None if no custom base_url is set.
def get_base_url_for_model(
model_id: str, user_id: Optional[str] = None
) -> Optional[str]:
"""Get the custom base_url for a specific model if configured.
Returns ``None`` if no custom base_url is set. ``user_id`` enables
resolution of per-user BYOM records.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
model = registry.get_model(model_id, user_id=user_id)
if model:
return model.base_url
return None
def get_api_key_for_model(
model_id: str, user_id: Optional[str] = None
) -> Optional[str]:
"""Resolve the API key to use when invoking ``model_id``.
Priority:
1. The model record's own ``api_key`` (BYOM records and
``openai_compatible`` YAMLs populate this).
2. The provider plugin's settings-based key.
``user_id`` enables resolution of per-user BYOM records.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id, user_id=user_id)
if model is not None and model.api_key:
return model.api_key
if model is not None:
return get_api_key_for_provider(model.provider.value)
return None

View File

@@ -0,0 +1,358 @@
"""YAML loader for model catalog files under ``application/core/models/``.
Each ``*.yaml`` file declares one provider's static model catalog. Files
are validated with Pydantic at load time; any parse, schema, or alias
error aborts startup with the offending file path in the message.
For most providers, one YAML maps to one catalog. The
``openai_compatible`` provider is special: each YAML file represents a
distinct logical endpoint (Mistral, Together, Ollama, ...) with its own
``api_key_env`` and ``base_url``. The loader returns a flat list so the
registry can distinguish multiple files with the same ``provider:`` value.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, Optional, Sequence
import yaml
from pydantic import BaseModel, ConfigDict, Field, field_validator
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
)
logger = logging.getLogger(__name__)
BUILTIN_MODELS_DIR = Path(__file__).parent / "models"
DEFAULTS_FILENAME = "_defaults.yaml"
class _DefaultsFile(BaseModel):
"""Schema for ``_defaults.yaml``. Currently just attachment aliases."""
model_config = ConfigDict(extra="forbid")
attachment_aliases: Dict[str, List[str]] = Field(default_factory=dict)
class _CapabilityFields(BaseModel):
"""Capability fields shared between provider ``defaults:`` and per-model overrides.
All fields are optional so a per-model override can selectively replace
a single field from the provider-level defaults.
"""
model_config = ConfigDict(extra="forbid")
supports_tools: Optional[bool] = None
supports_structured_output: Optional[bool] = None
supports_streaming: Optional[bool] = None
attachments: Optional[List[str]] = None
context_window: Optional[int] = None
input_cost_per_token: Optional[float] = None
output_cost_per_token: Optional[float] = None
class _ModelEntry(_CapabilityFields):
"""Schema for one model row inside a YAML's ``models:`` list."""
id: str
display_name: Optional[str] = None
description: str = ""
enabled: bool = True
base_url: Optional[str] = None
aliases: List[str] = Field(default_factory=list)
@field_validator("id")
@classmethod
def _id_nonempty(cls, v: str) -> str:
if not v or not v.strip():
raise ValueError("model id must be a non-empty string")
return v
class _ProviderFile(BaseModel):
"""Schema for one ``<provider>.yaml`` catalog file."""
model_config = ConfigDict(extra="forbid")
provider: str
defaults: _CapabilityFields = Field(default_factory=_CapabilityFields)
models: List[_ModelEntry] = Field(default_factory=list)
# openai_compatible metadata. Optional for other providers.
display_provider: Optional[str] = None
api_key_env: Optional[str] = None
base_url: Optional[str] = None
class ProviderCatalog(BaseModel):
"""One YAML file's parsed contents, ready for the registry.
For most providers, multiple catalogs with the same ``provider`` get
merged later by the registry. The ``openai_compatible`` provider is
the exception: each catalog is treated as a distinct endpoint, with
its own ``api_key_env`` and ``base_url``.
"""
provider: str
models: List[AvailableModel]
source_path: Optional[Path] = None
display_provider: Optional[str] = None
api_key_env: Optional[str] = None
base_url: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
class ModelYAMLError(ValueError):
"""Raised when a model YAML fails parsing, schema, or alias validation."""
def _expand_attachments(
attachments: Sequence[str], aliases: Dict[str, List[str]], source: str
) -> List[str]:
"""Resolve attachment shorthands (``image``, ``pdf``) to MIME types.
Raw MIME-typed entries (containing ``/``) pass through unchanged.
Unknown aliases raise ``ModelYAMLError``.
"""
expanded: List[str] = []
seen: set = set()
for entry in attachments:
if "/" in entry:
if entry not in seen:
expanded.append(entry)
seen.add(entry)
continue
if entry not in aliases:
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
raise ModelYAMLError(
f"{source}: unknown attachment alias '{entry}'. "
f"Valid aliases: {valid}. "
"(Or use a raw MIME type like 'image/png'.)"
)
for mime in aliases[entry]:
if mime not in seen:
expanded.append(mime)
seen.add(mime)
return expanded
def _load_defaults(directory: Path) -> Dict[str, List[str]]:
"""Load ``_defaults.yaml`` from ``directory`` if it exists."""
path = directory / DEFAULTS_FILENAME
if not path.exists():
return {}
try:
raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as e:
raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
try:
parsed = _DefaultsFile.model_validate(raw)
except Exception as e:
raise ModelYAMLError(f"{path}: schema error: {e}") from e
return parsed.attachment_aliases
def _resolve_provider_enum(name: str, source: Path) -> ModelProvider:
try:
return ModelProvider(name)
except ValueError as e:
valid = ", ".join(p.value for p in ModelProvider)
raise ModelYAMLError(
f"{source}: unknown provider '{name}'. Valid: {valid}"
) from e
def _build_model(
entry: _ModelEntry,
defaults: _CapabilityFields,
provider: ModelProvider,
aliases: Dict[str, List[str]],
source: Path,
display_provider: Optional[str] = None,
) -> AvailableModel:
"""Merge defaults + per-model overrides into a final ``AvailableModel``."""
def pick(field_name: str, fallback):
v = getattr(entry, field_name)
if v is not None:
return v
d = getattr(defaults, field_name)
if d is not None:
return d
return fallback
raw_attachments = entry.attachments
if raw_attachments is None:
raw_attachments = defaults.attachments
if raw_attachments is None:
raw_attachments = []
expanded = _expand_attachments(
raw_attachments, aliases, f"{source} [model={entry.id}]"
)
caps = ModelCapabilities(
supports_tools=pick("supports_tools", False),
supports_structured_output=pick("supports_structured_output", False),
supports_streaming=pick("supports_streaming", True),
supported_attachment_types=expanded,
context_window=pick("context_window", 128000),
input_cost_per_token=pick("input_cost_per_token", None),
output_cost_per_token=pick("output_cost_per_token", None),
)
return AvailableModel(
id=entry.id,
provider=provider,
display_name=entry.display_name or entry.id,
description=entry.description,
capabilities=caps,
enabled=entry.enabled,
base_url=entry.base_url,
display_provider=display_provider,
)
def _load_one_yaml(
path: Path, aliases: Dict[str, List[str]]
) -> ProviderCatalog:
try:
raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as e:
raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
try:
parsed = _ProviderFile.model_validate(raw)
except Exception as e:
raise ModelYAMLError(f"{path}: schema error: {e}") from e
provider_enum = _resolve_provider_enum(parsed.provider, path)
models = [
_build_model(
entry,
parsed.defaults,
provider_enum,
aliases,
path,
display_provider=parsed.display_provider,
)
for entry in parsed.models
]
return ProviderCatalog(
provider=parsed.provider,
models=models,
source_path=path,
display_provider=parsed.display_provider,
api_key_env=parsed.api_key_env,
base_url=parsed.base_url,
)
_BUILTIN_ALIASES_CACHE: Optional[Dict[str, List[str]]] = None
def builtin_attachment_aliases() -> Dict[str, List[str]]:
"""Return the built-in attachment alias map from ``_defaults.yaml``.
Cached after first read so repeat calls are cheap.
"""
global _BUILTIN_ALIASES_CACHE
if _BUILTIN_ALIASES_CACHE is None:
_BUILTIN_ALIASES_CACHE = _load_defaults(BUILTIN_MODELS_DIR)
return _BUILTIN_ALIASES_CACHE
def resolve_attachment_alias(alias: str) -> List[str]:
"""Resolve a single attachment alias (e.g. ``"image"``) to its
canonical MIME-type list. Raises ``ModelYAMLError`` if unknown.
"""
aliases = builtin_attachment_aliases()
if alias not in aliases:
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
raise ModelYAMLError(
f"Unknown attachment alias '{alias}'. Valid: {valid}"
)
return list(aliases[alias])
def expand_attachments_lenient(
attachments: Sequence[str], source: str
) -> List[str]:
"""Expand attachment aliases to MIME types, tolerating unknowns.
Mirrors ``_expand_attachments`` but logs+skips unknown aliases
rather than raising. Used for runtime call sites (BYOM registry
load) where an operator-side alias-map edit must not drop the
entire user's BYOM layer; the strict raise still happens at the
API validation boundary.
"""
aliases = builtin_attachment_aliases()
expanded: List[str] = []
seen: set = set()
for entry in attachments:
if "/" in entry:
if entry not in seen:
expanded.append(entry)
seen.add(entry)
continue
mime_list = aliases.get(entry)
if mime_list is None:
logger.warning(
"%s: skipping unknown attachment alias %r", source, entry,
)
continue
for mime in mime_list:
if mime not in seen:
expanded.append(mime)
seen.add(mime)
return expanded
def load_model_yamls(directories: Sequence[Path]) -> List[ProviderCatalog]:
"""Load every ``*.yaml`` file (excluding ``_defaults.yaml``) under each
directory in order and return a flat list of catalogs.
Caller is responsible for merging multiple catalogs that target the
same provider plugin. The flat-list shape lets ``openai_compatible``
keep each file separate (one logical endpoint per file).
When the same model ``id`` appears in more than one YAML across the
directory list, a warning is logged. Order in the returned list
preserves load order, so the registry's "later wins" merge gives the
later directory's definition.
"""
catalogs: List[ProviderCatalog] = []
seen_ids: Dict[str, Path] = {}
aliases: Dict[str, List[str]] = {}
for d in directories:
if not d or not d.exists():
continue
aliases.update(_load_defaults(d))
for d in directories:
if not d or not d.exists():
continue
for path in sorted(d.glob("*.yaml")):
if path.name == DEFAULTS_FILENAME:
continue
catalog = _load_one_yaml(path, aliases)
catalogs.append(catalog)
for m in catalog.models:
prior = seen_ids.get(m.id)
if prior is not None and prior != path:
logger.warning(
"Model id %r redefined: %s overrides %s (later wins)",
m.id,
path,
prior,
)
seen_ids[m.id] = path
return catalogs

View File

@@ -0,0 +1,213 @@
# Model catalogs
Each `*.yaml` file in this directory declares one provider's model
catalog. The registry loads every YAML at boot and joins it to the
matching provider plugin under `application/llm/providers/`.
To add or edit models, you almost always only touch a YAML here — no
Python code required.
## Add a model to an existing provider
Open the provider's YAML (e.g. `anthropic.yaml`) and append two lines
under `models:`:
```yaml
models:
- id: claude-3-7-sonnet
display_name: Claude 3.7 Sonnet
```
Capabilities default to the provider's `defaults:` block. Override
per-model only when needed:
```yaml
- id: claude-3-7-sonnet
display_name: Claude 3.7 Sonnet
context_window: 500000
```
Restart the app. The new model appears in `/api/models`.
> The model `id` is what gets stored in agent / workflow records. Once
> users start picking the model, **don't rename it** — agent and
> workflow rows reference it as a free-form string and silently fall
> back to the system default if the id disappears.
## Add an OpenAI-compatible provider (zero Python)
Drop a YAML in this directory (or in your `MODELS_CONFIG_DIR`) that uses
the `openai_compatible` plugin. Set the env var named in `api_key_env`
and you're done — no Python, no settings.py edit, no LLMCreator change:
```yaml
# mistral.yaml
provider: openai_compatible
display_provider: mistral # shown in /api/models response
api_key_env: MISTRAL_API_KEY # env var the plugin reads at boot
base_url: https://api.mistral.ai/v1
defaults:
supports_tools: true
context_window: 128000
models:
- id: mistral-large-latest
display_name: Mistral Large
- id: mistral-small-latest
display_name: Mistral Small
```
`MISTRAL_API_KEY=sk-... ; restart` — Mistral models appear in
`/api/models` with `provider: "mistral"`. They route through the OpenAI
wire format (it's `OpenAILLM` under the hood) but with Mistral's
endpoint and key.
Multiple `openai_compatible` YAMLs coexist: each file is one logical
endpoint with its own `api_key_env` and `base_url`. Drop in
`together.yaml`, `fireworks.yaml`, etc. side by side. If a file's env var
isn't set, that catalog is skipped at boot (a line is logged at INFO)
rather than treated as an error.
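For instance, a second endpoint file dropped alongside `mistral.yaml` might
look like the sketch below (the `TOGETHER_API_KEY` name, base URL, and model
id are illustrative, not checked against Together's current catalog):
```yaml
# together.yaml (hypothetical second openai_compatible endpoint)
provider: openai_compatible
display_provider: together       # label shown in /api/models
api_key_env: TOGETHER_API_KEY    # catalog is skipped at boot if unset
base_url: https://api.together.xyz/v1
defaults:
  supports_tools: true
  context_window: 131072
models:
  - id: meta-llama/Llama-3.3-70B-Instruct-Turbo
    display_name: Llama 3.3 70B (Together)
```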
Working example: `examples/mistral.yaml.example`. Files inside
`examples/` aren't loaded by the registry; the glob only picks up
`*.yaml` at the top level.
## Add a provider with its own SDK
For a provider that doesn't speak OpenAI's wire format, add one Python
file to `application/llm/providers/<name>.py`:
```python
from application.llm.providers.base import Provider
from application.llm.my_provider import MyLLM
class MyProvider(Provider):
name = "my_provider"
llm_class = MyLLM
def get_api_key(self, settings):
return settings.MY_PROVIDER_API_KEY
```
Register it in `application/llm/providers/__init__.py` (one line in
`ALL_PROVIDERS`), add `MY_PROVIDER_API_KEY` to `settings.py`, and create
`my_provider.yaml` here with the model catalog.
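A matching catalog file for that plugin could be a sketch like this
(`my-model-v1` and its capability values are placeholders):
```yaml
# my_provider.yaml (hypothetical catalog paired with the plugin above)
provider: my_provider            # must match the Provider plugin's `name`
defaults:
  supports_tools: false
  context_window: 32768
models:
  - id: my-model-v1
    display_name: My Model v1
    description: Placeholder entry; use the provider's real model ids
```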
## Schema reference
```yaml
provider: <string, required> # matches the Provider plugin's `name`
# openai_compatible only — required for that provider, ignored for others
display_provider: <string> # label shown in /api/models response
api_key_env: <string> # name of the env var carrying the key
base_url: <string> # endpoint URL
defaults: # optional, applied to every model below
supports_tools: bool # default false
supports_structured_output: bool # default false
supports_streaming: bool # default true
attachments: [<alias-or-mime>, ...] # default []
context_window: int # default 128000
input_cost_per_token: float # default null
output_cost_per_token: float # default null
models: # required
- id: <string, required> # the value persisted in agent records
display_name: <string> # default: id
description: <string> # default: ""
enabled: bool # default true; false hides from /api/models
base_url: <string> # optional custom endpoint for this model
# All `defaults:` fields above can be overridden here per-model.
```
### Attachment aliases
The `attachments:` list can mix human-readable aliases with raw MIME
types. Aliases are defined in `_defaults.yaml`:
| Alias | Expands to |
|---|---|
| `image` | `image/png`, `image/jpeg`, `image/jpg`, `image/webp`, `image/gif` |
| `pdf` | `application/pdf` |
| `audio` | `audio/mpeg`, `audio/wav`, `audio/ogg` |
Use raw MIME types when you need surgical control:
```yaml
attachments: [image/png, image/webp] # only these two
```
## Operator-supplied YAMLs (`MODELS_CONFIG_DIR`)
Set the `MODELS_CONFIG_DIR` env var (or `.env` entry) to a directory
path. Every `*.yaml` in that directory is loaded **after** the built-in
catalog under `application/core/models/`. Operators use this to:
- Add new `openai_compatible` providers (Mistral, Together, Fireworks,
Ollama, ...) without forking the repo.
- Extend an existing provider's catalog with extra models — append
models under `provider: anthropic` and they show up alongside the
built-ins.
- Override a built-in model's capabilities — declare the same `id`
with different fields (e.g. a higher `context_window`); see the
sketch after this list. Later wins; the override is logged as a
`WARNING` so you can audit it.
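A concrete override sketch, assuming the built-in `claude-haiku-4-5` entry
(the `context_window` value is illustrative; since later-wins takes the later
file's whole definition, capability fields worth keeping are restated):
```yaml
# $MODELS_CONFIG_DIR/anthropic-overrides.yaml
provider: anthropic
defaults:
  supports_tools: true           # restate provider defaults you still want
  attachments: [image]
models:
  - id: claude-haiku-4-5         # same id as the built-in catalog entry
    display_name: Claude Haiku 4.5
    supports_structured_output: true
    context_window: 500000       # illustrative bump over the built-in value
```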
Things you cannot do via `MODELS_CONFIG_DIR`:
- Add a brand-new non-OpenAI provider — that needs a Python plugin
under `application/llm/providers/` (see "Add a provider with its own
SDK" above). Operator YAMLs may only target a `provider:` value that
already has a registered plugin.
### Example: Docker
Mount your model YAMLs into the container and point the env var at the
mount path:
```yaml
# docker-compose.yml
services:
app:
image: arc53/docsgpt
environment:
MODELS_CONFIG_DIR: /etc/docsgpt/models
MISTRAL_API_KEY: ${MISTRAL_API_KEY}
volumes:
- ./my-models:/etc/docsgpt/models:ro
```
Then `./my-models/mistral.yaml` (the file from
`examples/mistral.yaml.example`) gets picked up at boot.
### Example: Kubernetes
Mount a `ConfigMap` containing your YAMLs at a known path and set
`MODELS_CONFIG_DIR` on the deployment. The same `examples/mistral.yaml.example`
becomes a key in the ConfigMap.
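A minimal sketch (resource names, namespace, and mount path are placeholders):
```yaml
# ConfigMap holding the operator catalog
apiVersion: v1
kind: ConfigMap
metadata:
  name: docsgpt-models
data:
  mistral.yaml: |
    provider: openai_compatible
    display_provider: mistral
    api_key_env: MISTRAL_API_KEY
    base_url: https://api.mistral.ai/v1
    models:
      - id: mistral-large-latest
        display_name: Mistral Large
---
# Deployment fragment: mount the ConfigMap, point MODELS_CONFIG_DIR at it
spec:
  template:
    spec:
      containers:
        - name: app
          env:
            - name: MODELS_CONFIG_DIR
              value: /etc/docsgpt/models
          volumeMounts:
            - name: models
              mountPath: /etc/docsgpt/models
              readOnly: true
      volumes:
        - name: models
          configMap:
            name: docsgpt-models
```
The API key itself (`MISTRAL_API_KEY` in this sketch) belongs in a `Secret`
on the deployment, not in the ConfigMap alongside the YAMLs.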
### Misconfiguration
If `MODELS_CONFIG_DIR` is set but the path doesn't exist (or isn't a
directory), the app logs a `WARNING` at boot and continues with just
the built-in catalog. The app does *not* fail to start — operators can
ship config drift without taking down the service — but the warning is
loud enough to surface in any reasonable log aggregator.
## Validation
YAMLs are parsed with Pydantic at boot. The app fails to start with a
clear error message if:
- a top-level key is unknown
- a model is missing `id`
- an attachment alias isn't defined
- the `provider:` value isn't registered as a plugin
This is intentional — silent fallbacks would mean users don't notice
their model picks broke until they hit the API.
## Reserved fields (not yet implemented)
- `aliases:` on a model — old IDs that resolve to this model. Reserved
for future renames; the schema accepts the field but it is not yet
acted on.

View File

@@ -0,0 +1,18 @@
# Global defaults applied across every model YAML in this directory.
# Keep this file sparse — per-provider `defaults:` blocks are clearer
# than a deep global default chain. This file is for things that
# genuinely never vary, like the meaning of "image".
attachment_aliases:
image:
- image/png
- image/jpeg
- image/jpg
- image/webp
- image/gif
pdf:
- application/pdf
audio:
- audio/mpeg
- audio/wav
- audio/ogg

View File

@@ -0,0 +1,23 @@
provider: anthropic
defaults:
supports_tools: true
attachments: [image]
context_window: 200000
models:
- id: claude-opus-4-7
display_name: Claude Opus 4.7
description: Most capable Claude model for complex reasoning and agentic coding
context_window: 1000000
supports_structured_output: true
- id: claude-sonnet-4-6
display_name: Claude Sonnet 4.6
description: Best balance of speed and intelligence with extended thinking
context_window: 1000000
supports_structured_output: true
- id: claude-haiku-4-5
display_name: Claude Haiku 4.5
description: Fastest Claude model with near-frontier intelligence
supports_structured_output: true

View File

@@ -0,0 +1,31 @@
# Azure OpenAI catalog.
#
# IMPORTANT: For Azure OpenAI, the `id` field is the **deployment name**, not
# a model name. Deployment names are arbitrary strings the operator chooses
# in Azure portal (or via ARM/Bicep/Terraform) when they create a deployment
# for a given underlying model + version.
#
# The IDs below are sensible defaults that mirror the underlying OpenAI
# model name (prefixed with `azure-`). Operators almost always need to
# override them via `MODELS_CONFIG_DIR` to match the deployment names that
# actually exist in their Azure resource. The `display_name`, capability
# flags, and `context_window` reflect the underlying OpenAI model.
provider: azure_openai
defaults:
supports_tools: true
supports_structured_output: true
attachments: [image]
context_window: 400000
models:
- id: azure-gpt-5.5
display_name: Azure OpenAI GPT-5.5
description: Azure-hosted flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
context_window: 1050000
- id: azure-gpt-5.4-mini
display_name: Azure OpenAI GPT-5.4 Mini
description: Azure-hosted cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
- id: azure-gpt-5.4-nano
display_name: Azure OpenAI GPT-5.4 Nano
description: Azure-hosted cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most

View File

@@ -0,0 +1,7 @@
provider: docsgpt
models:
- id: docsgpt-local
display_name: DocsGPT Model
description: Local model
supports_tools: false

View File

@@ -0,0 +1,31 @@
# EXAMPLE — copy this file to ../mistral.yaml (or to your
# MODELS_CONFIG_DIR) and set MISTRAL_API_KEY in your environment.
#
# This is the entire integration. No Python required: the
# `openai_compatible` plugin reads `api_key_env` and `base_url` from
# the file and routes calls through the OpenAI wire format.
#
# Files in this `examples/` directory are NOT loaded by the registry
# (the loader globs *.yaml at the top level only).
provider: openai_compatible
display_provider: mistral # shown in /api/models response
api_key_env: MISTRAL_API_KEY # env var the plugin reads
base_url: https://api.mistral.ai/v1 # OpenAI-compatible endpoint
defaults:
supports_tools: true
context_window: 128000
models:
- id: mistral-large-latest
display_name: Mistral Large
description: Top-tier reasoning model
- id: mistral-small-latest
display_name: Mistral Small
description: Fast, cost-efficient
- id: codestral-latest
display_name: Codestral
description: Code-specialized model

View File

@@ -0,0 +1,17 @@
provider: google
defaults:
supports_tools: true
supports_structured_output: true
attachments: [pdf, image]
context_window: 1048576
models:
- id: gemini-3.1-pro-preview
display_name: Gemini 3.1 Pro
description: Most capable Gemini 3 model with advanced reasoning and agentic coding (preview)
- id: gemini-3-flash-preview
display_name: Gemini 3 Flash
description: Frontier-class performance for low-latency, high-volume tasks (preview)
- id: gemini-3.1-flash-lite-preview
display_name: Gemini 3.1 Flash-Lite
description: Cost-efficient frontier-class multimodal model for high-throughput workloads (preview)

View File

@@ -0,0 +1,16 @@
provider: groq
defaults:
supports_tools: true
context_window: 131072
models:
- id: openai/gpt-oss-120b
display_name: GPT-OSS 120B
description: OpenAI's open-weight 120B flagship served on Groq's LPU hardware; strong general reasoning with strict structured output support
supports_structured_output: true
- id: llama-3.3-70b-versatile
display_name: Llama 3.3 70B Versatile
description: Meta's Llama 3.3 70B for general-purpose chat with parallel tool use
- id: llama-3.1-8b-instant
display_name: Llama 3.1 8B Instant
description: Small, very low-latency Llama model (~560 tok/s) with parallel tool use

View File

@@ -0,0 +1,7 @@
provider: huggingface
models:
- id: huggingface-local
display_name: Hugging Face Model
description: Local Hugging Face model
supports_tools: false

View File

@@ -0,0 +1,21 @@
provider: novita
defaults:
supports_tools: true
supports_structured_output: true
models:
- id: deepseek/deepseek-v4-pro
display_name: DeepSeek V4 Pro
description: 1.6T MoE (49B active) with 1M context, hybrid CSA/HCA attention, top-tier reasoning and agentic coding
context_window: 1048576
- id: moonshotai/kimi-k2.6
display_name: Kimi K2.6
description: 1T-parameter open-weight MoE with native vision/video, multi-step tool calling, and agentic long-horizon execution
attachments: [image]
context_window: 262144
- id: zai-org/glm-5
display_name: GLM-5
description: Z.AI 754B-parameter MoE with strong general reasoning, function calling, and structured output
context_window: 202800

View File

@@ -0,0 +1,18 @@
provider: openai
defaults:
supports_tools: true
supports_structured_output: true
attachments: [image]
context_window: 400000
models:
- id: gpt-5.5
display_name: GPT-5.5
description: Flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
context_window: 1050000
- id: gpt-5.4-mini
display_name: GPT-5.4 Mini
description: Cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
- id: gpt-5.4-nano
display_name: GPT-5.4 Nano
description: Cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most

View File

@@ -0,0 +1,25 @@
provider: openrouter
defaults:
supports_tools: true
attachments: [image]
context_window: 128000
models:
- id: qwen/qwen3-coder:free
display_name: Qwen3 Coder (free)
description: Free-tier 480B MoE coder model with strong agentic tool use; rate-limited
context_window: 262000
attachments: []
- id: deepseek/deepseek-v3.2
display_name: DeepSeek V3.2
description: Open-weights reasoning model, very low cost (~$0.25 in / $0.38 out per 1M)
context_window: 131072
attachments: []
supports_structured_output: true
- id: anthropic/claude-sonnet-4.6
display_name: Claude Sonnet 4.6 (via OpenRouter)
description: Frontier Sonnet-class model with 1M context, vision, and extended thinking
context_window: 1000000
supports_structured_output: true

View File

@@ -23,6 +23,10 @@ class Settings(BaseSettings):
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
# Optional directory of operator-supplied model YAMLs, loaded after the
# built-in catalog under application/core/models/. Later wins on
# duplicate model id. See application/core/models/README.md.
MODELS_CONFIG_DIR: Optional[str] = None
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"

View File

@@ -17,6 +17,8 @@ class BaseLLM(ABC):
model_id=None,
base_url=None,
backup_models=None,
model_user_id=None,
capabilities=None,
):
self.decoded_token = decoded_token
self.agent_id = str(agent_id) if agent_id else None
@@ -25,6 +27,12 @@ class BaseLLM(ABC):
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
self._backup_models = backup_models or []
self._fallback_llm = None
# Registry-resolved per-model capability overrides (BYOM caps,
# operator YAML). None falls back to provider-class defaults.
self.capabilities = capabilities
# BYOM-resolution scope captured at LLM creation time so backup
# / fallback lookups hit the same per-user layer as the primary.
self.model_user_id = model_user_id
@property
def fallback_llm(self):
@@ -39,10 +47,19 @@ class BaseLLM(ABC):
get_api_key_for_provider,
)
# Try per-agent backup models first
# model_user_id (BYOM scope) takes precedence over the caller's
# sub so shared-agent backups resolve under the owner's layer.
caller_sub = (
self.decoded_token.get("sub")
if isinstance(self.decoded_token, dict)
else None
)
backup_user_id = self.model_user_id or caller_sub
for backup_model_id in self._backup_models:
try:
provider = get_provider_from_model_id(backup_model_id)
provider = get_provider_from_model_id(
backup_model_id, user_id=backup_user_id
)
if not provider:
logger.warning(
f"Could not resolve provider for backup model: {backup_model_id}"
@@ -56,6 +73,7 @@ class BaseLLM(ABC):
decoded_token=self.decoded_token,
model_id=backup_model_id,
agent_id=self.agent_id,
model_user_id=self.model_user_id,
)
logger.info(
f"Fallback LLM initialized from agent backup model: "
@@ -68,7 +86,10 @@ class BaseLLM(ABC):
)
continue
# Fall back to global FALLBACK_* settings
# Fall back to global FALLBACK_* settings. Forward
# ``model_user_id`` here too: deployments can configure
# ``FALLBACK_LLM_NAME`` to a BYOM UUID, and that UUID is owned
# by the same user the primary model was resolved under.
if settings.FALLBACK_LLM_PROVIDER:
try:
self._fallback_llm = LLMCreator.create_llm(
@@ -78,6 +99,7 @@ class BaseLLM(ABC):
decoded_token=self.decoded_token,
model_id=settings.FALLBACK_LLM_NAME,
agent_id=self.agent_id,
model_user_id=self.model_user_id,
)
logger.info(
f"Fallback LLM initialized from global settings: "

View File

@@ -470,10 +470,14 @@ class LLMHandler(ABC):
)
return self._perform_in_memory_compression(agent, messages)
# Use orchestrator to perform compression
# Use orchestrator to perform compression. ``model_user_id``
# keeps BYOM registry resolution scoped to the model owner
# (shared-agent dispatch) while ``user_id`` stays the caller
# for the conversation access check.
result = orchestrator.compress_mid_execution(
conversation_id=agent.conversation_id,
user_id=agent.initial_user_id,
model_user_id=getattr(agent, "model_user_id", None),
model_id=agent.model_id,
decoded_token=getattr(agent, "decoded_token", {}),
current_conversation=conversation,
@@ -577,7 +581,20 @@ class LLMHandler(ABC):
if settings.COMPRESSION_MODEL_OVERRIDE
else agent.model_id
)
provider = get_provider_from_model_id(compression_model)
agent_decoded = getattr(agent, "decoded_token", None)
caller_sub = (
agent_decoded.get("sub")
if isinstance(agent_decoded, dict)
else None
)
# Use model-owner scope (mirrors orchestrator path) so
# shared-agent owner-BYOM resolves under the owner's layer.
compression_user_id = (
getattr(agent, "model_user_id", None) or caller_sub
)
provider = get_provider_from_model_id(
compression_model, user_id=compression_user_id
)
api_key = get_api_key_for_provider(provider)
compression_llm = LLMCreator.create_llm(
provider,
@@ -586,6 +603,7 @@ class LLMHandler(ABC):
getattr(agent, "decoded_token", None),
model_id=compression_model,
agent_id=getattr(agent, "agent_id", None),
model_user_id=compression_user_id,
)
# Create service without DB persistence capability
@@ -921,8 +939,15 @@ class LLMHandler(ABC):
}
return ""
# ``agent.model_id`` is the registry id (a UUID for BYOM
# records). Use the LLM's own model_id, which LLMCreator
# already resolved to the upstream model name. Built-ins:
# the two are equal; BYOM: the upstream name like
# "mistral-large-latest" instead of the UUID.
response = agent.llm.gen(
model=agent.model_id, messages=messages, tools=agent.tools
model=getattr(agent.llm, "model_id", None) or agent.model_id,
messages=messages,
tools=agent.tools,
)
parsed = self.parse_response(response)
self.llm_calls.append(build_stack_data(agent.llm))
@@ -1011,8 +1036,11 @@ class LLMHandler(ABC):
})
logger.info("Context limit reached - instructing agent to wrap up")
# See note above on agent.model_id vs llm.model_id.
response = agent.llm.gen_stream(
model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None
model=getattr(agent.llm, "model_id", None) or agent.model_id,
messages=messages,
tools=agent.tools if not agent.context_limit_reached else None,
)
self.llm_calls.append(build_stack_data(agent.llm))

View File

@@ -1,34 +1,11 @@
import logging
from application.llm.anthropic import AnthropicLLM
from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.google_ai import GoogleLLM
from application.llm.groq import GroqLLM
from application.llm.llama_cpp import LlamaCpp
from application.llm.novita import NovitaLLM
from application.llm.openai import AzureOpenAILLM, OpenAILLM
from application.llm.premai import PremAILLM
from application.llm.sagemaker import SagemakerAPILLM
from application.llm.open_router import OpenRouterLLM
from application.llm.providers import PROVIDERS_BY_NAME
logger = logging.getLogger(__name__)
class LLMCreator:
llms = {
"openai": OpenAILLM,
"azure_openai": AzureOpenAILLM,
"sagemaker": SagemakerAPILLM,
"llama.cpp": LlamaCpp,
"anthropic": AnthropicLLM,
"docsgpt": DocsGPTAPILLM,
"premai": PremAILLM,
"groq": GroqLLM,
"google": GoogleLLM,
"novita": NovitaLLM,
"openrouter": OpenRouterLLM,
}
@classmethod
def create_llm(
cls,
@@ -39,28 +16,111 @@ class LLMCreator:
model_id=None,
agent_id=None,
backup_models=None,
model_user_id=None,
*args,
**kwargs,
):
from application.core.model_utils import get_base_url_for_model
"""Construct an LLM for the given provider ``type``.
llm_class = cls.llms.get(type.lower())
if not llm_class:
``model_user_id`` is the BYOM-resolution scope. Defaults to
``decoded_token['sub']`` (the caller). Pass it explicitly when
the model record belongs to a *different* user — most notably
for shared-agent dispatch, where the agent's stored
``default_model_id`` is the owner's BYOM UUID but
``decoded_token`` represents the caller.
"""
from application.core.model_registry import ModelRegistry
from application.security.safe_url import (
UnsafeUserUrlError,
pinned_httpx_client,
validate_user_base_url,
)
plugin = PROVIDERS_BY_NAME.get(type.lower())
if plugin is None or plugin.llm_class is None:
raise ValueError(f"No LLM class found for type {type}")
# Extract base_url from model configuration if model_id is provided
# Prefer per-model endpoint config from the registry. This is what
# makes openai_compatible AND end-user BYOM work without changing
# every call site: if the registered AvailableModel carries its
# own api_key / base_url, they win over whatever the caller
# resolved via the provider plugin.
#
# End-user BYOM lookups need the user_id from decoded_token to
# find the user's per-user models layer (built-in models resolve
# without it, so this stays back-compat).
base_url = None
upstream_model_id = model_id
capabilities = None
if model_id:
base_url = get_base_url_for_model(model_id)
user_id = model_user_id
if user_id is None:
user_id = (
(decoded_token or {}).get("sub") if decoded_token else None
)
model = ModelRegistry.get_instance().get_model(model_id, user_id=user_id)
if model is not None:
# Forward registry caps so the LLM enforces them at
# dispatch (built-in classes hard-code True otherwise).
capabilities = getattr(model, "capabilities", None)
# SECURITY: refuse user-source dispatch without its own
# api_key (would leak settings.API_KEY to base_url).
if (
getattr(model, "source", "builtin") == "user"
and not model.api_key
):
raise ValueError(
f"Custom model {model_id!r} has no usable API key "
"(decryption may have failed). Re-save the model "
"in settings to dispatch it."
)
if model.api_key:
api_key = model.api_key
if model.base_url:
base_url = model.base_url
# For BYOM the registry id is a UUID; the upstream API
# call needs the user's typed model name instead.
if model.upstream_model_id:
upstream_model_id = model.upstream_model_id
return llm_class(
# SECURITY: re-validate at dispatch (defense in depth
# for pre-guard rows / YAML-supplied entries). The
# pinned httpx.Client below is what actually closes the
# DNS-rebinding TOCTOU window.
if base_url and getattr(model, "source", "builtin") == "user":
try:
validate_user_base_url(base_url)
except UnsafeUserUrlError as e:
raise ValueError(
f"Refusing to dispatch model {model_id!r}: {e}"
) from e
# Pinned httpx.Client: resolves once, validates, and
# binds the SDK's outbound socket to the validated IP
# (preserves Host / SNI). Future BYOM providers must
# opt in explicitly — only openai_compatible takes
# http_client today.
if plugin.name == "openai_compatible":
try:
kwargs["http_client"] = pinned_httpx_client(
base_url
)
except UnsafeUserUrlError as e:
raise ValueError(
f"Refusing to dispatch model {model_id!r}: {e}"
) from e
# Forward model_user_id so backup/fallback resolves under the
# owner's scope on shared-agent dispatch.
return plugin.llm_class(
api_key,
user_api_key,
decoded_token=decoded_token,
model_id=model_id,
model_id=upstream_model_id,
agent_id=agent_id,
base_url=base_url,
backup_models=backup_models,
model_user_id=model_user_id,
capabilities=capabilities,
*args,
**kwargs,
)

View File

@@ -62,7 +62,15 @@ def _truncate_base64_for_logging(messages):
class OpenAILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
def __init__(
self,
api_key=None,
user_api_key=None,
base_url=None,
http_client=None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
@@ -80,7 +88,18 @@ class OpenAILLM(BaseLLM):
else:
effective_base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
# http_client (set by LLMCreator for BYOM) is a DNS-rebinding-safe
# httpx.Client; without it the SDK re-resolves DNS per request.
if http_client is not None:
self.client = OpenAI(
api_key=self.api_key,
base_url=effective_base_url,
http_client=http_client,
)
else:
self.client = OpenAI(
api_key=self.api_key, base_url=effective_base_url
)
self.storage = StorageCreator.get_storage()
def _clean_messages_openai(self, messages):
@@ -243,6 +262,13 @@ class OpenAILLM(BaseLLM):
if "max_tokens" in kwargs:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
# Defense-in-depth: drop tools / response_format if the
# registry's capability flags deny them.
if tools and not self._supports_tools():
tools = None
if response_format and not self._supports_structured_output():
response_format = None
request_params = {
"model": model,
"messages": messages,
@@ -279,6 +305,13 @@ class OpenAILLM(BaseLLM):
if "max_tokens" in kwargs:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
# See _raw_gen for rationale — drop tools/response_format when the
# registry-provided capabilities say the model doesn't support them.
if tools and not self._supports_tools():
tools = None
if response_format and not self._supports_structured_output():
response_format = None
request_params = {
"model": model,
"messages": messages,
@@ -320,9 +353,17 @@ class OpenAILLM(BaseLLM):
response.close()
def _supports_tools(self):
# When the LLM was constructed via LLMCreator with a registered
# AvailableModel, ``self.capabilities`` is the per-model record.
# BYOM users can disable tool support; respect that. Otherwise
# OpenAI's API supports tools by default.
if self.capabilities is not None:
return bool(self.capabilities.supports_tools)
return True
def _supports_structured_output(self):
if self.capabilities is not None:
return bool(self.capabilities.supports_structured_output)
return True
def prepare_structured_output_format(self, json_schema):
@@ -389,8 +430,14 @@ class OpenAILLM(BaseLLM):
Returns:
list: List of supported MIME types
"""
from application.core.model_configs import OPENAI_ATTACHMENTS
return OPENAI_ATTACHMENTS
# Per-model caps from the registry win when present — a BYOM
# endpoint that doesn't accept images would otherwise still be
# sent base64 image parts because the OpenAI default below
# advertises the image alias unconditionally.
if self.capabilities is not None:
return list(self.capabilities.supported_attachment_types or [])
from application.core.model_yaml import resolve_attachment_alias
return resolve_attachment_alias("image")
def prepare_messages_with_attachments(self, messages, attachments=None):
"""

View File

@@ -0,0 +1,51 @@
"""Provider plugin registry.
Plugins are imported eagerly so import errors surface at app boot rather
than at first request. ``ALL_PROVIDERS`` is the canonical ordered list;
``PROVIDERS_BY_NAME`` is a name-keyed lookup for LLMCreator and the
model registry.
"""
from __future__ import annotations
from typing import Dict, List
from application.llm.providers.anthropic import AnthropicProvider
from application.llm.providers.azure_openai import AzureOpenAIProvider
from application.llm.providers.base import Provider
from application.llm.providers.docsgpt import DocsGPTProvider
from application.llm.providers.google import GoogleProvider
from application.llm.providers.groq import GroqProvider
from application.llm.providers.huggingface import HuggingFaceProvider
from application.llm.providers.llama_cpp import LlamaCppProvider
from application.llm.providers.novita import NovitaProvider
from application.llm.providers.openai import OpenAIProvider
from application.llm.providers.openai_compatible import OpenAICompatibleProvider
from application.llm.providers.openrouter import OpenRouterProvider
from application.llm.providers.premai import PremAIProvider
from application.llm.providers.sagemaker import SagemakerProvider
# Order here is the order the registry iterates providers (and therefore
# the order ``/api/models`` reports them). Match the historical order
# from the old ModelRegistry._load_models for byte-stable output during
# the migration. ``openai_compatible`` slots in right after ``openai``
# so legacy ``OPENAI_BASE_URL`` models keep landing in the same place.
ALL_PROVIDERS: List[Provider] = [
DocsGPTProvider(),
OpenAIProvider(),
OpenAICompatibleProvider(),
AzureOpenAIProvider(),
AnthropicProvider(),
GoogleProvider(),
GroqProvider(),
OpenRouterProvider(),
NovitaProvider(),
HuggingFaceProvider(),
LlamaCppProvider(),
PremAIProvider(),
SagemakerProvider(),
]
PROVIDERS_BY_NAME: Dict[str, Provider] = {p.name: p for p in ALL_PROVIDERS}
__all__ = ["ALL_PROVIDERS", "PROVIDERS_BY_NAME", "Provider"]

View File

@@ -0,0 +1,51 @@
"""Shared helper for providers that follow the
``<X>_API_KEY or (LLM_PROVIDER==X and API_KEY)`` pattern.
This is the dominant pattern across Anthropic, Google, Groq, OpenRouter,
and Novita. Extracted here so each plugin stays a few lines long.
"""
from __future__ import annotations
from typing import List, Optional
from application.core.model_settings import AvailableModel
def get_api_key(
settings,
provider_name: str,
provider_specific_key: Optional[str],
) -> Optional[str]:
if provider_specific_key:
return provider_specific_key
if settings.LLM_PROVIDER == provider_name and settings.API_KEY:
return settings.API_KEY
return None
def filter_models_by_llm_name(
settings,
provider_name: str,
provider_specific_key: Optional[str],
models: List[AvailableModel],
) -> List[AvailableModel]:
"""Mirrors the historical ``_add_<X>_models`` selection logic.
Behavior:
- If the provider-specific API key is set → load all models.
- Else if ``LLM_PROVIDER`` matches and ``LLM_NAME`` matches a known
model → load just that model.
- Otherwise → load all models (preserved "load anyway" branch from
the original methods).
"""
if provider_specific_key:
return models
if (
settings.LLM_PROVIDER == provider_name
and settings.LLM_NAME
):
named = [m for m in models if m.id == settings.LLM_NAME]
if named:
return named
return models

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from typing import Optional
from application.llm.anthropic import AnthropicLLM
from application.llm.providers._apikey_or_llm_name import (
filter_models_by_llm_name,
get_api_key,
)
from application.llm.providers.base import Provider
class AnthropicProvider(Provider):
name = "anthropic"
llm_class = AnthropicLLM
def get_api_key(self, settings) -> Optional[str]:
return get_api_key(settings, self.name, settings.ANTHROPIC_API_KEY)
def filter_yaml_models(self, settings, models):
return filter_models_by_llm_name(
settings, self.name, settings.ANTHROPIC_API_KEY, models
)

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from typing import Optional
from application.llm.openai import AzureOpenAILLM
from application.llm.providers.base import Provider
class AzureOpenAIProvider(Provider):
name = "azure_openai"
llm_class = AzureOpenAILLM
def get_api_key(self, settings) -> Optional[str]:
# Azure historically uses the generic API_KEY field.
return settings.API_KEY
def is_enabled(self, settings) -> bool:
if settings.OPENAI_API_BASE:
return True
return settings.LLM_PROVIDER == self.name and bool(settings.API_KEY)
def filter_yaml_models(self, settings, models):
# Mirrors _add_azure_openai_models: when LLM_PROVIDER==azure_openai
# and LLM_NAME matches a known model, narrow to that one model.
# Otherwise load the entire catalog.
if settings.LLM_PROVIDER == self.name and settings.LLM_NAME:
named = [m for m in models if m.id == settings.LLM_NAME]
if named:
return named
return models

View File

@@ -0,0 +1,74 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, ClassVar, List, Optional, Type
if TYPE_CHECKING:
from application.core.model_settings import AvailableModel
from application.core.model_yaml import ProviderCatalog
from application.core.settings import Settings
from application.llm.base import BaseLLM
class Provider(ABC):
"""Owns the *behavior* of an LLM provider.
Concrete providers declare their name, the LLM class to instantiate,
and how to resolve credentials from settings. Static model catalogs
live in YAML under ``application/core/models/`` and are joined to the
provider by name at registry load time.
Most plugins receive zero or one catalog at registry-build time. The
``openai_compatible`` plugin is the exception: it receives one catalog
per matching YAML file, each with its own ``api_key_env`` and
``base_url``. Plugins that need per-catalog metadata override
``get_models``; the default implementation merges catalogs and routes
through ``filter_yaml_models`` + ``extra_models``.
"""
name: ClassVar[str]
# ``None`` means the provider appears in the catalog but isn't
# dispatchable through LLMCreator (e.g. Hugging Face today, where the
# original LLMCreator dict had no entry).
llm_class: ClassVar[Optional[Type["BaseLLM"]]] = None
@abstractmethod
def get_api_key(self, settings: "Settings") -> Optional[str]:
"""Return the API key for this provider, or None if unavailable."""
def is_enabled(self, settings: "Settings") -> bool:
"""Whether this provider should contribute models to the registry."""
return bool(self.get_api_key(settings))
def filter_yaml_models(
self, settings: "Settings", models: List["AvailableModel"]
) -> List["AvailableModel"]:
"""Hook to filter YAML-loaded models. Default: return all."""
return models
def extra_models(self, settings: "Settings") -> List["AvailableModel"]:
"""Hook to add dynamic models not declared in YAML. Default: none."""
return []
def get_models(
self,
settings: "Settings",
catalogs: List["ProviderCatalog"],
) -> List["AvailableModel"]:
"""Final list of models this plugin contributes.
Default: merge the models across all matched catalogs (later
catalog wins on duplicate id), filter via ``filter_yaml_models``,
then append ``extra_models``. Override when per-catalog metadata
matters (see ``OpenAICompatibleProvider``).
"""
merged: List["AvailableModel"] = []
seen: dict = {}
for c in catalogs:
for m in c.models:
if m.id in seen:
merged[seen[m.id]] = m
else:
seen[m.id] = len(merged)
merged.append(m)
return self.filter_yaml_models(settings, merged) + self.extra_models(settings)

View File

@@ -0,0 +1,22 @@
from __future__ import annotations
from typing import Optional
from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.providers.base import Provider
class DocsGPTProvider(Provider):
name = "docsgpt"
llm_class = DocsGPTAPILLM
def get_api_key(self, settings) -> Optional[str]:
# No provider-specific key; the LLM class can use the generic
# API_KEY fallback if it needs one. Mirrors model_utils' historical
# behavior of returning settings.API_KEY when no specific key exists.
return settings.API_KEY
def is_enabled(self, settings) -> bool:
# The hosted DocsGPT model is hidden when the deployment is
# pointed at a custom OpenAI-compatible endpoint.
return not settings.OPENAI_BASE_URL

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from typing import Optional
from application.llm.google_ai import GoogleLLM
from application.llm.providers._apikey_or_llm_name import (
filter_models_by_llm_name,
get_api_key,
)
from application.llm.providers.base import Provider
class GoogleProvider(Provider):
name = "google"
llm_class = GoogleLLM
def get_api_key(self, settings) -> Optional[str]:
return get_api_key(settings, self.name, settings.GOOGLE_API_KEY)
def filter_yaml_models(self, settings, models):
return filter_models_by_llm_name(
settings, self.name, settings.GOOGLE_API_KEY, models
)

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from typing import Optional
from application.llm.groq import GroqLLM
from application.llm.providers._apikey_or_llm_name import (
filter_models_by_llm_name,
get_api_key,
)
from application.llm.providers.base import Provider
class GroqProvider(Provider):
name = "groq"
llm_class = GroqLLM
def get_api_key(self, settings) -> Optional[str]:
return get_api_key(settings, self.name, settings.GROQ_API_KEY)
def filter_yaml_models(self, settings, models):
return filter_models_by_llm_name(
settings, self.name, settings.GROQ_API_KEY, models
)

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
from typing import Optional
from application.llm.providers._apikey_or_llm_name import (
get_api_key as shared_get_api_key,
)
from application.llm.providers.base import Provider
class HuggingFaceProvider(Provider):
"""Surfaces ``huggingface-local`` to the model catalog.
Not dispatchable through LLMCreator — historically there was no
HuggingFaceLLM entry in ``LLMCreator.llms``, and calling ``create_llm``
with ``"huggingface"`` raised ``ValueError``. We preserve that
behavior: the model appears in ``/api/models`` but selecting it
surfaces the same error it always did.
"""
name = "huggingface"
llm_class = None # not dispatchable
def get_api_key(self, settings) -> Optional[str]:
return shared_get_api_key(settings, self.name, settings.HUGGINGFACE_API_KEY)

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from typing import Optional
from application.llm.llama_cpp import LlamaCpp
from application.llm.providers.base import Provider
class LlamaCppProvider(Provider):
"""LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog."""
name = "llama.cpp"
llm_class = LlamaCpp
def get_api_key(self, settings) -> Optional[str]:
return settings.API_KEY
def is_enabled(self, settings) -> bool:
return False

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from typing import Optional
from application.llm.novita import NovitaLLM
from application.llm.providers._apikey_or_llm_name import (
filter_models_by_llm_name,
get_api_key,
)
from application.llm.providers.base import Provider
class NovitaProvider(Provider):
name = "novita"
llm_class = NovitaLLM
def get_api_key(self, settings) -> Optional[str]:
return get_api_key(settings, self.name, settings.NOVITA_API_KEY)
def filter_yaml_models(self, settings, models):
return filter_models_by_llm_name(
settings, self.name, settings.NOVITA_API_KEY, models
)

View File

@@ -0,0 +1,37 @@
from __future__ import annotations
from typing import Optional
from application.llm.openai import OpenAILLM
from application.llm.providers.base import Provider
class OpenAIProvider(Provider):
name = "openai"
llm_class = OpenAILLM
def get_api_key(self, settings) -> Optional[str]:
if settings.OPENAI_API_KEY:
return settings.OPENAI_API_KEY
if settings.LLM_PROVIDER == self.name and settings.API_KEY:
return settings.API_KEY
return None
def is_enabled(self, settings) -> bool:
# When the deployment is pointed at a custom OpenAI-compatible
# endpoint (Ollama, LM Studio, ...), the cloud-OpenAI catalog is
# suppressed but ``is_enabled`` stays True — necessary so the
# filter below still gets to drop the catalog (rather than the
# registry skipping the provider entirely and missing the rule).
if settings.OPENAI_BASE_URL:
return True
return bool(self.get_api_key(settings))
def filter_yaml_models(self, settings, models):
# Legacy local-endpoint mode hides the cloud catalog. The
# corresponding dynamic models live in OpenAICompatibleProvider.
if settings.OPENAI_BASE_URL:
return []
if not settings.OPENAI_API_KEY:
return []
return models

View File

@@ -0,0 +1,149 @@
"""Generic provider for OpenAI-wire-compatible endpoints.
Each ``openai_compatible`` YAML file describes one logical endpoint
(Mistral, Together, Fireworks, Ollama, ...) with its own
``api_key_env`` and ``base_url``. Multiple files can coexist; the
plugin produces one set of models per file, each pre-configured with
the right credentials and URL.
The plugin also handles the **legacy** ``OPENAI_BASE_URL`` + ``LLM_NAME``
local-endpoint pattern that previously lived in ``OpenAIProvider``. That
path generates models dynamically from ``LLM_NAME``, using
``OPENAI_BASE_URL`` and ``OPENAI_API_KEY`` as the endpoint config.
"""
from __future__ import annotations
import logging
import os
from typing import List, Optional
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
)
from application.llm.openai import OpenAILLM
from application.llm.providers.base import Provider
logger = logging.getLogger(__name__)
def _parse_model_names(llm_name: Optional[str]) -> List[str]:
if not llm_name:
return []
return [name.strip() for name in llm_name.split(",") if name.strip()]
class OpenAICompatibleProvider(Provider):
name = "openai_compatible"
llm_class = OpenAILLM
def get_api_key(self, settings) -> Optional[str]:
# Per-model: each catalog supplies its own ``api_key_env``. There
# is no single plugin-wide key. LLMCreator reads the per-model
# ``api_key`` set during catalog materialization.
return None
def is_enabled(self, settings) -> bool:
# Concrete enablement happens per catalog (in ``get_models``).
# Returning True lets the registry call ``get_models`` so we can
# decide per-file whether to contribute models.
return True
def get_models(self, settings, catalogs) -> List[AvailableModel]:
out: List[AvailableModel] = []
for catalog in catalogs:
out.extend(self._materialize_yaml_catalog(catalog))
if settings.OPENAI_BASE_URL and settings.LLM_NAME:
out.extend(self._materialize_legacy_local_endpoint(settings))
return out
def _materialize_yaml_catalog(self, catalog) -> List[AvailableModel]:
"""Resolve one openai_compatible YAML into ready-to-dispatch models.
Skipped (with an INFO-level log) if ``api_key_env`` resolves to
nothing — no point publishing models the user can't actually
call. INFO rather than WARNING because operators may legitimately
drop multiple provider YAMLs as templates and only set the env
vars for the ones they actually use; a missing key is ambiguous,
not necessarily a misconfig.
"""
if not catalog.base_url:
raise ValueError(
f"{catalog.source_path}: openai_compatible YAML must set "
"'base_url'."
)
if not catalog.api_key_env:
raise ValueError(
f"{catalog.source_path}: openai_compatible YAML must set "
"'api_key_env'."
)
api_key = os.environ.get(catalog.api_key_env)
if not api_key:
logger.info(
"openai_compatible catalog %s skipped: env var %s is not set",
catalog.source_path,
catalog.api_key_env,
)
return []
out: List[AvailableModel] = []
for m in catalog.models:
out.append(self._with_endpoint(m, catalog.base_url, api_key))
return out
def _materialize_legacy_local_endpoint(self, settings) -> List[AvailableModel]:
"""Generate AvailableModels from ``LLM_NAME`` for the legacy
``OPENAI_BASE_URL`` deployment pattern (Ollama, LM Studio, ...).
Preserves the historical ``provider="openai"`` display behavior
by setting ``display_provider="openai"``.
"""
from application.core.model_yaml import resolve_attachment_alias
attachments = resolve_attachment_alias("image")
api_key = settings.OPENAI_API_KEY or settings.API_KEY
out: List[AvailableModel] = []
for model_name in _parse_model_names(settings.LLM_NAME):
out.append(
AvailableModel(
id=model_name,
provider=ModelProvider.OPENAI_COMPATIBLE,
display_name=model_name,
description=f"Custom OpenAI-compatible model at {settings.OPENAI_BASE_URL}",
base_url=settings.OPENAI_BASE_URL,
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=attachments,
),
api_key=api_key,
display_provider="openai",
)
)
return out
@staticmethod
def _with_endpoint(
model: AvailableModel, base_url: str, api_key: str
) -> AvailableModel:
"""Return a copy of ``model`` carrying the catalog's endpoint config.
The catalog-level ``base_url`` is the default; an explicit
per-model ``base_url`` in the YAML wins.
"""
return AvailableModel(
id=model.id,
provider=model.provider,
display_name=model.display_name,
description=model.description,
capabilities=model.capabilities,
enabled=model.enabled,
base_url=model.base_url or base_url,
display_provider=model.display_provider,
api_key=api_key,
)

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from typing import Optional
from application.llm.open_router import OpenRouterLLM
from application.llm.providers._apikey_or_llm_name import (
filter_models_by_llm_name,
get_api_key,
)
from application.llm.providers.base import Provider
class OpenRouterProvider(Provider):
name = "openrouter"
llm_class = OpenRouterLLM
def get_api_key(self, settings) -> Optional[str]:
return get_api_key(settings, self.name, settings.OPEN_ROUTER_API_KEY)
def filter_yaml_models(self, settings, models):
return filter_models_by_llm_name(
settings, self.name, settings.OPEN_ROUTER_API_KEY, models
)

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from typing import Optional
from application.llm.premai import PremAILLM
from application.llm.providers.base import Provider
class PremAIProvider(Provider):
"""LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog."""
name = "premai"
llm_class = PremAILLM
def get_api_key(self, settings) -> Optional[str]:
return settings.API_KEY
def is_enabled(self, settings) -> bool:
return False

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
from typing import Optional
from application.llm.sagemaker import SagemakerAPILLM
from application.llm.providers.base import Provider
class SagemakerProvider(Provider):
"""LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog.
SageMaker reads its credentials from ``SAGEMAKER_*`` settings inside
the LLM class itself; this plugin's ``get_api_key`` exists only for
LLMCreator's symmetry.
"""
name = "sagemaker"
llm_class = SagemakerAPILLM
def get_api_key(self, settings) -> Optional[str]:
return settings.API_KEY
def is_enabled(self, settings) -> bool:
return False

View File

@@ -82,6 +82,7 @@ python-dateutil==2.9.0.post0
python-dotenv
python-jose==3.5.0
python-pptx==1.0.2
PyYAML
redis==7.4.0
referencing>=0.28.0,<0.38.0
regex==2026.4.4

View File

@@ -22,6 +22,7 @@ class ClassicRAG(BaseRetriever):
llm_name=settings.LLM_PROVIDER,
api_key=settings.API_KEY,
decoded_token=None,
model_user_id=None,
):
self.original_question = source.get("question", "")
self.chat_history = chat_history if chat_history is not None else []
@@ -42,17 +43,22 @@ class ClassicRAG(BaseRetriever):
f"sources={'active_docs' in source and source['active_docs'] is not None}"
)
self.model_id = model_id
self.model_user_id = model_user_id
self.doc_token_limit = doc_token_limit
self.user_api_key = user_api_key
self.agent_id = agent_id
self.llm_name = llm_name
self.api_key = api_key
# Forward model_id + model_user_id so LLMCreator resolves BYOM
# base_url / api_key / upstream id for the rephrase client.
self.llm = LLMCreator.create_llm(
self.llm_name,
api_key=self.api_key,
user_api_key=self.user_api_key,
decoded_token=decoded_token,
model_id=self.model_id,
agent_id=self.agent_id,
model_user_id=self.model_user_id,
)
if "active_docs" in source and source["active_docs"] is not None:
@@ -103,7 +109,11 @@ class ClassicRAG(BaseRetriever):
]
try:
rephrased_query = self.llm.gen(model=self.model_id, messages=messages)
# Send upstream id (resolved by LLMCreator), not registry UUID.
rephrased_query = self.llm.gen(
model=getattr(self.llm, "model_id", None) or self.model_id,
messages=messages,
)
print(f"Rephrased query: {rephrased_query}")
return rephrased_query if rephrased_query else self.original_question
except Exception as e:

View File

@@ -0,0 +1,464 @@
"""SSRF protection for user-supplied OpenAI-compatible base URLs.
This module is the single chokepoint for validating any URL that a user
provides as an OpenAI-compatible ``base_url`` ("Bring Your Own Model").
The backend will later issue outbound HTTP requests to that URL on the
user's behalf, so we must reject anything that could be used to reach
internal-network resources (cloud metadata services, RFC 1918 ranges,
loopback, link-local, etc.).
Three entry points:
* :func:`validate_user_base_url` — called at create/update time on REST
routes that persist the URL, to give the user immediate feedback.
* :func:`pinned_post` — called at dispatch time when the caller drives
``requests`` directly (e.g. the ``/api/models/test`` endpoint).
Resolves once, dials the IP literal, preserves the original hostname
in the ``Host`` header and via SNI / cert verification for HTTPS.
* :func:`pinned_httpx_client` — called at dispatch time when the caller
hands an ``httpx.Client`` to a third-party SDK (e.g. the OpenAI
Python SDK via ``OpenAI(http_client=...)``). Same DNS-rebinding
closure on the httpx transport layer.
Why all three: the OpenAI / httpx ecosystem performs its own DNS lookup
inside ``socket.getaddrinfo`` when a connection opens, so a hostile DNS
server can hand a public IP to the validator and a loopback / link-local
address to the HTTP client. Validate-then-construct-SDK is unsafe; the
pinned variants close that TOCTOU window by resolving exactly once and
dialing the chosen IP literal directly.
"""
from __future__ import annotations
import ipaddress
import socket
from typing import Any, Iterable
from urllib.parse import SplitResult, urlsplit, urlunsplit
import httpx
import requests
from requests.adapters import HTTPAdapter
# Allowed URL schemes. Anything else (file, gopher, ftp, data, ...) is
# rejected outright because it either bypasses HTTP entirely or enables
# protocol smuggling against the proxy stack.
_ALLOWED_SCHEMES: frozenset[str] = frozenset({"http", "https"})
# Hostnames that resolve to a loopback / metadata / unspecified address
# but which we want to reject *by name* as well, so the rejection
# message is unambiguous and so we never accidentally call DNS on them.
_BLOCKED_HOSTNAMES: frozenset[str] = frozenset(
{
"localhost",
"localhost.localdomain",
"0.0.0.0",
"::",
"::1",
"ip6-localhost",
"ip6-loopback",
# GCP metadata service. AWS/Azure use 169.254.169.254 which the
# IP-range check below already covers via the link-local range,
# but Google's hostname does not always resolve to a link-local
# IP from every VPC, so we hard-deny the string too.
"metadata.google.internal",
}
)
# Carrier-grade NAT (RFC 6598). Python's ``ipaddress`` module does NOT
# classify this range as ``is_private``, so we must check it explicitly.
_CGNAT_NETWORK_V4: ipaddress.IPv4Network = ipaddress.IPv4Network("100.64.0.0/10")
class UnsafeUserUrlError(ValueError):
"""Raised when a user-supplied URL fails SSRF validation.
Subclasses :class:`ValueError` so call sites that already treat
invalid input as a 400-class error continue to work. The string
message names the specific reason (scheme, hostname, resolved IP,
DNS failure, ...) so that it can be surfaced to the user verbatim.
"""
def _strip_ipv6_brackets(host: str) -> str:
"""Return ``host`` with surrounding ``[`` / ``]`` removed if present."""
if host.startswith("[") and host.endswith("]"):
return host[1:-1]
return host
def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
"""Return ``True`` if ``ip`` falls in any range we refuse to dial.
This is the single source of truth for the IP-range policy:
* loopback (``127.0.0.0/8``, ``::1``)
* private (RFC 1918, ULA ``fc00::/7``)
* link-local (``169.254.0.0/16``, ``fe80::/10``)
* multicast (``224.0.0.0/4``, ``ff00::/8``)
* unspecified (``0.0.0.0``, ``::``)
* reserved (``240.0.0.0/4``, etc.)
* carrier-grade NAT (``100.64.0.0/10``) — not covered by ``is_private``
"""
if (
ip.is_loopback
or ip.is_private
or ip.is_link_local
or ip.is_multicast
or ip.is_unspecified
or ip.is_reserved
):
return True
if isinstance(ip, ipaddress.IPv4Address) and ip in _CGNAT_NETWORK_V4:
return True
return False
def _resolve(host: str) -> Iterable[ipaddress.IPv4Address | ipaddress.IPv6Address]:
"""Resolve ``host`` to every A/AAAA record returned by the system.
Returning *all* addresses (rather than the first one) is critical:
a hostile DNS server can return a public IP first followed by a
private IP, and the underlying HTTP client may fail over to the
private one on connect. We treat the set as unsafe if any element
is unsafe.
"""
try:
results = socket.getaddrinfo(host, None)
except socket.gaierror as exc: # noqa: PERF203 — re-raise as our own type
raise UnsafeUserUrlError(f"could not resolve hostname {host!r}: {exc}") from exc
addresses: list[ipaddress.IPv4Address | ipaddress.IPv6Address] = []
for entry in results:
sockaddr = entry[4]
# IPv4 sockaddr: (host, port). IPv6 sockaddr: (host, port, flowinfo, scope_id).
ip_str = sockaddr[0]
# Strip IPv6 zone-id ("fe80::1%lo0") before parsing.
if "%" in ip_str:
ip_str = ip_str.split("%", 1)[0]
try:
addresses.append(ipaddress.ip_address(ip_str))
except ValueError:
# An entry we can't parse is itself suspicious; treat as unsafe.
raise UnsafeUserUrlError(
f"hostname {host!r} resolved to unparseable address {ip_str!r}"
) from None
return addresses
def _validate_and_pick_ip(
    url: str,
) -> tuple[str, ipaddress.IPv4Address | ipaddress.IPv6Address, SplitResult]:
"""Run the SSRF guard and return the data needed to dial safely.
Performs every check :func:`validate_user_base_url` performs, but
additionally returns ``(hostname, ip, parts)`` where ``ip`` is one
of the validated addresses (the first record returned by the
resolver, or the literal itself if the URL already used an IP) and
``parts`` is the :func:`urllib.parse.urlsplit` result so callers do
not have to re-parse the URL.
Raises :class:`UnsafeUserUrlError` on the same conditions as
:func:`validate_user_base_url`.
"""
if not isinstance(url, str) or not url.strip():
raise UnsafeUserUrlError("url must be a non-empty string")
try:
parts = urlsplit(url)
except ValueError as exc:
raise UnsafeUserUrlError(f"could not parse url {url!r}: {exc}") from exc
scheme = parts.scheme.lower()
if scheme not in _ALLOWED_SCHEMES:
raise UnsafeUserUrlError(
f"scheme {scheme!r} is not allowed; only http and https are permitted"
)
# ``urlsplit`` returns the bracketed form for IPv6 in ``netloc`` but
# the bare form in ``hostname``. Normalize via lower() because
# hostnames are case-insensitive and we compare against a lowercase
# blocklist.
raw_host = parts.hostname
if not raw_host:
raise UnsafeUserUrlError(f"url {url!r} has no hostname")
host = raw_host.lower()
# Check the literal-string blocklist first. urlsplit().hostname strips
# IPv6 brackets, so we also test the bracketed form for completeness
# (matches the public-spec note about ``[::]``).
bracketed = f"[{host}]"
if host in _BLOCKED_HOSTNAMES or bracketed in _BLOCKED_HOSTNAMES:
raise UnsafeUserUrlError(
f"hostname {raw_host!r} is not allowed (matches internal-only name)"
)
# If the host is already an IP literal (with or without IPv6 brackets),
# check it directly without going to DNS — DNS for an IP literal is a
# no-op but it's clearer to short-circuit and gives a better message.
candidate = _strip_ipv6_brackets(host)
try:
literal = ipaddress.ip_address(candidate)
except ValueError:
literal = None
if literal is not None:
if _is_blocked_ip(literal):
raise UnsafeUserUrlError(
f"hostname {raw_host!r} resolves to blocked address {literal} "
f"(loopback/private/link-local/multicast/reserved/CGNAT)"
)
return host, literal, parts
# Hostname (not an IP literal) — resolve and validate every record.
addresses = list(_resolve(host))
for ip in addresses:
if _is_blocked_ip(ip):
raise UnsafeUserUrlError(
f"hostname {raw_host!r} resolves to blocked address {ip} "
f"(loopback/private/link-local/multicast/reserved/CGNAT)"
)
if not addresses:
# ``getaddrinfo`` would normally raise instead of returning an
# empty list, but treat the degenerate case as unsafe too — we
# have nothing to bind a connection to.
raise UnsafeUserUrlError(
f"hostname {raw_host!r} returned no addresses from DNS"
)
return host, addresses[0], parts
def validate_user_base_url(url: str) -> None:
"""Validate that ``url`` is safe to use as an outbound base URL.
Resolve the URL's hostname to one or more IPs and reject if any
resolved IP is private/loopback/link-local/multicast/reserved, or if
the URL uses a non-http(s) scheme, or if the hostname is one of the
known dangerous strings (``localhost``, ``0.0.0.0``, ``[::]``).
Raises :class:`UnsafeUserUrlError` on rejection. Returns ``None`` on
success.
This function is the create/update-time check. At dispatch time use
:func:`pinned_post` instead, which performs the same validation
*and* pins the outbound connection to the validated IP so a DNS
rebinder cannot flip the resolution between check and connect.
Args:
url: The user-supplied URL to validate. Expected to be an
absolute URL with an ``http`` or ``https`` scheme.
Raises:
UnsafeUserUrlError: If the URL fails to parse, uses a forbidden
scheme, has an empty/blocklisted hostname, fails DNS
resolution, or resolves to any IP in a blocked range.
"""
_validate_and_pick_ip(url)
class _PinnedHostAdapter(HTTPAdapter):
"""HTTPS adapter that performs SNI and cert verification against a
fixed hostname even when the URL connects to an IP literal.
Used by :func:`pinned_post` so that resolving the user-supplied
hostname once and dialing the resolved IP doesn't break TLS.
Without this, ``urllib3`` would default ``server_hostname`` /
``assert_hostname`` to the connect host (the IP) and either send the
wrong SNI or fail cert verification — the cert is for the original
hostname, not the IP literal.
"""
def __init__(self, server_hostname: str, *args: Any, **kwargs: Any) -> None:
self._server_hostname = server_hostname
super().__init__(*args, **kwargs)
def init_poolmanager(self, *args: Any, **kwargs: Any) -> None:
kwargs["server_hostname"] = self._server_hostname
kwargs["assert_hostname"] = self._server_hostname
super().init_poolmanager(*args, **kwargs)
def _ip_to_url_host(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> str:
"""Return ``ip`` formatted for use in a URL netloc (brackets for v6)."""
if isinstance(ip, ipaddress.IPv6Address):
return f"[{ip}]"
return str(ip)
def pinned_post(
url: str,
*,
json: Any = None,
headers: dict[str, str] | None = None,
timeout: float = 5.0,
allow_redirects: bool = False,
) -> requests.Response:
"""POST to ``url`` with the outbound connection pinned to a single
validated IP, closing the DNS-rebinding TOCTOU window left by the
naive validate-then-``requests.post`` pattern.
The URL's hostname is resolved exactly once. Every returned address
must pass the same SSRF guard as :func:`validate_user_base_url`. The
outbound request is issued against the chosen IP literal (so
``urllib3`` cannot ask the resolver again and receive a different
answer); the original hostname is preserved in the ``Host`` header
and, for HTTPS, via :class:`_PinnedHostAdapter` for SNI and cert
verification.
Args:
url: Absolute http(s) URL to POST to.
json: JSON-serializable payload — passed through to ``requests``.
headers: Caller-supplied headers. Any caller-supplied ``Host``
entry is overwritten so the in-flight request matches what
was validated.
timeout: Per-request timeout (seconds).
allow_redirects: Forwarded to ``requests``. Defaults to
``False`` because the SSRF guard only inspects the supplied
URL — following redirects would let a hostile upstream
bounce the request to an internal address.
Raises:
UnsafeUserUrlError: If the URL fails the SSRF guard.
requests.RequestException: For network-level failures.
"""
host, ip, parts = _validate_and_pick_ip(url)
netloc = _ip_to_url_host(ip)
if parts.port is not None:
netloc = f"{netloc}:{parts.port}"
pinned_url = urlunsplit(
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
)
request_headers = dict(headers or {})
host_header = host if parts.port is None else f"{host}:{parts.port}"
request_headers["Host"] = host_header
session = requests.Session()
if parts.scheme == "https":
session.mount("https://", _PinnedHostAdapter(host))
try:
return session.post(
pinned_url,
json=json,
headers=request_headers,
timeout=timeout,
allow_redirects=allow_redirects,
)
finally:
session.close()
class _PinnedHTTPSTransport(httpx.HTTPTransport):
"""``httpx`` transport pinned to a single validated IP literal.
Closes the DNS-rebinding TOCTOU window that
:func:`validate_user_base_url` cannot close on its own. The OpenAI
Python SDK (and any other SDK that uses ``httpx``) re-resolves the
hostname inside ``socket.getaddrinfo`` at request time, so a
hostile DNS server can return a public IP at validation time and a
private IP at request time. This transport rewrites every outgoing
request's URL host to the validated IP literal so ``httpcore``
dials that IP without a fresh lookup.
The original hostname is preserved in two places:
1. ``Host`` header — ``httpx.Request._prepare`` set it from the URL
netloc *before* this transport runs, so it carries the hostname
not the IP literal. We deliberately do not touch headers here.
2. TLS SNI / cert verification — set via the
``request.extensions["sni_hostname"]`` extension which
``httpcore`` feeds into ``start_tls``'s ``server_hostname``
parameter. Without this, ``urllib3``-equivalent code would use
the IP literal as SNI and cert verification would fail (the
cert is for the original hostname, not the IP).
"""
def __init__(
self,
validated_host: str,
validated_ip: ipaddress.IPv4Address | ipaddress.IPv6Address,
**kwargs: Any,
) -> None:
# http2=False (the httpx default) — defense in depth against
# HTTP/2 connection coalescing (RFC 7540 §9.1.1), where a
# client may reuse a TCP connection for any host whose cert
# covers it. Per-IP pinning never shares connections across
# hosts, but explicit is safer than relying on the default.
kwargs.setdefault("http2", False)
super().__init__(**kwargs)
self._host = validated_host
self._ip_netloc = _ip_to_url_host(validated_ip)
def handle_request(self, request: httpx.Request) -> httpx.Response:
# Defense in depth: refuse if the request URL's host doesn't
# match what we validated. Catches any future SDK regression
# that rewrites the URL between Request construction and dial,
# and any rare case where the SDK reuses our pinned client for
# a different host (which it shouldn't, but assert it anyway).
if request.url.host != self._host:
raise UnsafeUserUrlError(
f"pinned transport bound to {self._host!r}, refused "
f"request for {request.url.host!r}"
)
# SNI/server_hostname for TLS verification. httpcore reads this
# extension at _sync/connection.py and feeds it into
# start_tls's server_hostname argument. Set before the URL host
# is rewritten so cert validation continues to use the original
# hostname even though TCP dials the IP literal.
request.extensions = {
**request.extensions,
"sni_hostname": self._host.encode("ascii"),
}
request.url = request.url.copy_with(host=self._ip_netloc)
return super().handle_request(request)
def pinned_httpx_client(
base_url: str,
*,
timeout: float = 600.0,
) -> httpx.Client:
"""Return an :class:`httpx.Client` whose connections are pinned to
one validated IP, closing the DNS-rebinding TOCTOU window the naive
``OpenAI(base_url=...)`` flow leaves open.
The hostname in ``base_url`` is resolved exactly once. Every
returned address must pass :func:`_validate_and_pick_ip`'s SSRF
guard (loopback, RFC 1918, link-local, multicast, reserved, CGNAT,
cloud metadata names). The chosen IP becomes the URL host on every
outgoing request so ``httpcore`` cannot ask the resolver again.
Pass via ``OpenAI(http_client=pinned_httpx_client(base_url))`` (or
any other SDK that accepts an ``httpx.Client``) to make BYOM
dispatch immune to DNS-rebinding TOCTOU.
Args:
base_url: User-supplied http(s) URL. Validated through the same
SSRF guard as :func:`validate_user_base_url`.
timeout: Per-request timeout (seconds). Defaults to 600 to
match the OpenAI SDK's default; callers should override
for non-LLM workloads.
Raises:
UnsafeUserUrlError: If ``base_url`` fails the SSRF guard.
"""
host, ip, _parts = _validate_and_pick_ip(base_url)
transport = _PinnedHTTPSTransport(host, ip)
# follow_redirects=False — the SSRF guard only inspects the
# supplied URL; following 3xx would let a hostile upstream bounce
# the in-network request to an internal address (cloud metadata,
# RFC1918, loopback) carrying whatever credentials the SDK adds.
return httpx.Client(
transport=transport,
timeout=timeout,
follow_redirects=False,
)
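# --------------------------------------------------------------------- #
# Usage sketch (illustrative only; the endpoint URL and API key below are
# placeholders, and the OpenAI client stands in for any SDK that accepts
# an ``http_client``).
# --------------------------------------------------------------------- #
def _example_byom_dispatch() -> None:  # pragma: no cover - documentation sketch
    from openai import OpenAI

    base_url = "https://llm.example.com/v1"  # hypothetical user-supplied URL
    # Create/update time: reject unsafe URLs before persisting them.
    validate_user_base_url(base_url)
    # Dispatch time: pin the resolution so a DNS rebinder cannot swap in a
    # private address between validation and connect.
    client = OpenAI(
        api_key="sk-placeholder",
        base_url=base_url,
        http_client=pinned_httpx_client(base_url),
    )
    client.chat.completions.create(
        model="my-upstream-model",
        messages=[{"role": "user", "content": "ping"}],
    )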

View File

@@ -203,6 +203,24 @@ agents_table = Table(
Column("legacy_mongo_id", Text),
)
user_custom_models_table = Table(
"user_custom_models",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text, nullable=False),
Column("upstream_model_id", Text, nullable=False),
Column("display_name", Text, nullable=False),
Column("description", Text, nullable=False, server_default=""),
Column("base_url", Text, nullable=False),
# AES-CBC ciphertext (base64) keyed via per-user PBKDF2 in
# application.security.encryption.encrypt_credentials.
Column("api_key_encrypted", Text, nullable=False),
Column("capabilities", JSONB, nullable=False, server_default="{}"),
Column("enabled", Boolean, nullable=False, server_default="true"),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
attachments_table = Table(
"attachments",
metadata,

View File

@@ -1,7 +1,6 @@
"""Repository for the ``agents`` table.
This is the most complex Phase 2 repository. Covers every write operation
the legacy Mongo code performs on ``agents_collection``:
Covers every write operation the legacy Mongo code performs on ``agents_collection``:
- create, update, delete
- find by key (API key lookup)

View File

@@ -0,0 +1,199 @@
"""Repository for the ``user_custom_models`` table.
Backs the end-user "Bring Your Own Model" feature. Each row is one
user-supplied OpenAI-compatible endpoint (Mistral, Together, vLLM, ...).
The ``id`` UUID is the internal DocsGPT identifier (what agents store
in ``default_model_id``); ``upstream_model_id`` is what we send verbatim
to the provider's API.
API key handling: callers pass plaintext via ``api_key_plaintext``;
this module wraps the existing ``application.security.encryption``
helper (AES-CBC + per-user PBKDF2 salt) and writes the base64 ciphertext
to the ``api_key_encrypted`` column. Decryption is the caller's
responsibility (they hold the ``user_id``).
"""
from __future__ import annotations
from typing import Any, Optional
from sqlalchemy import Connection, func, text
from application.security.encryption import (
decrypt_credentials,
encrypt_credentials,
)
from application.storage.db.base_repository import row_to_dict
from application.storage.db.models import user_custom_models_table
_ALLOWED_CAPABILITY_KEYS = frozenset(
{
"supports_tools",
"supports_structured_output",
"supports_streaming",
"attachments",
"context_window",
}
)
class UserCustomModelsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
# ------------------------------------------------------------------ #
# Encryption wrappers
# ------------------------------------------------------------------ #
@staticmethod
def _encrypt_api_key(api_key_plaintext: str, user_id: str) -> str:
"""Encrypt ``api_key_plaintext`` with the per-user PBKDF2 scheme."""
return encrypt_credentials({"api_key": api_key_plaintext}, user_id)
@staticmethod
def _decrypt_api_key(api_key_encrypted: str, user_id: str) -> Optional[str]:
"""Decrypt the API key. Returns None on failure (which the caller
should surface as a configuration error rather than silently
proceeding with the upstream call)."""
if not api_key_encrypted:
return None
creds = decrypt_credentials(api_key_encrypted, user_id)
return creds.get("api_key") if creds else None
@staticmethod
def _normalize_capabilities(caps: Optional[dict]) -> dict:
"""Drop unknown keys; nothing else is forced. Callers (the route
layer) are responsible for value validation (numeric ranges,
attachment alias resolution)."""
if not caps:
return {}
return {k: v for k, v in caps.items() if k in _ALLOWED_CAPABILITY_KEYS}
# ------------------------------------------------------------------ #
# CRUD
# ------------------------------------------------------------------ #
def create(
self,
user_id: str,
upstream_model_id: str,
display_name: str,
base_url: str,
api_key_plaintext: str,
description: str = "",
capabilities: Optional[dict] = None,
enabled: bool = True,
) -> dict:
values = {
"user_id": user_id,
"upstream_model_id": upstream_model_id,
"display_name": display_name,
"description": description or "",
"base_url": base_url,
"api_key_encrypted": self._encrypt_api_key(api_key_plaintext, user_id),
"capabilities": self._normalize_capabilities(capabilities),
"enabled": bool(enabled),
}
from sqlalchemy.dialects.postgresql import insert as pg_insert
stmt = (
pg_insert(user_custom_models_table)
.values(**values)
.returning(user_custom_models_table)
)
result = self._conn.execute(stmt)
return row_to_dict(result.fetchone())
def get(self, model_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
text(
"SELECT * FROM user_custom_models "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
{"id": str(model_id), "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str) -> list[dict]:
result = self._conn.execute(
text(
"SELECT * FROM user_custom_models "
"WHERE user_id = :user_id ORDER BY created_at DESC"
),
{"user_id": user_id},
)
return [row_to_dict(r) for r in result.fetchall()]
def update(self, model_id: str, user_id: str, fields: dict) -> bool:
"""Apply a partial update.
Special-cases ``api_key_plaintext``: when present, it is encrypted
and stored in ``api_key_encrypted``. When absent (or empty), the
existing ciphertext is kept untouched. This is the wire-shape
``PATCH`` expects (the UI sends a blank password field when the
operator wants to keep the existing key).
"""
allowed = {
"upstream_model_id",
"display_name",
"description",
"base_url",
"capabilities",
"enabled",
}
values: dict[str, Any] = {}
for col, val in fields.items():
if col not in allowed or val is None:
continue
if col == "capabilities":
values[col] = self._normalize_capabilities(val)
elif col == "enabled":
values[col] = bool(val)
else:
values[col] = val
api_key_plaintext = fields.get("api_key_plaintext")
if api_key_plaintext:
values["api_key_encrypted"] = self._encrypt_api_key(
api_key_plaintext, user_id
)
if not values:
return False
values["updated_at"] = func.now()
t = user_custom_models_table
stmt = (
t.update()
.where(t.c.id == str(model_id))
.where(t.c.user_id == user_id)
.values(**values)
)
result = self._conn.execute(stmt)
return result.rowcount > 0
def delete(self, model_id: str, user_id: str) -> bool:
result = self._conn.execute(
text(
"DELETE FROM user_custom_models "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
{"id": str(model_id), "user_id": user_id},
)
return result.rowcount > 0
# ------------------------------------------------------------------ #
# Decryption helpers exposed to the registry layer
# ------------------------------------------------------------------ #
def get_decrypted_api_key(
self, model_id: str, user_id: str
) -> Optional[str]:
"""Convenience: fetch the row and return the decrypted API key,
or ``None`` if the row is missing or decryption fails."""
row = self.get(model_id, user_id)
if row is None:
return None
return self._decrypt_api_key(row.get("api_key_encrypted", ""), user_id)

View File

@@ -83,9 +83,9 @@ def count_tokens_docs(docs):
def calculate_doc_token_budget(
model_id: str = "gpt-4o"
model_id: str = "gpt-4o", user_id: str | None = None
) -> int:
total_context = get_token_limit(model_id)
total_context = get_token_limit(model_id, user_id=user_id)
reserved = sum(settings.RESERVED_TOKENS.values())
doc_budget = total_context - reserved
return max(doc_budget, 1000)
@@ -150,9 +150,11 @@ def get_hash(data):
return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()
def limit_chat_history(history, max_token_limit=None, model_id="docsgpt-local"):
def limit_chat_history(
history, max_token_limit=None, model_id="docsgpt-local", user_id=None
):
"""Limit chat history to fit within token limit."""
model_token_limit = get_token_limit(model_id)
model_token_limit = get_token_limit(model_id, user_id=user_id)
max_token_limit = (
max_token_limit
if max_token_limit and max_token_limit < model_token_limit
@@ -204,7 +206,9 @@ def generate_image_url(image_path):
def calculate_compression_threshold(
model_id: str, threshold_percentage: float = 0.8
model_id: str,
threshold_percentage: float = 0.8,
user_id: str | None = None,
) -> int:
"""
Calculate token threshold for triggering compression.
@@ -212,11 +216,13 @@ def calculate_compression_threshold(
Args:
model_id: Model identifier
threshold_percentage: Percentage of context window (default 80%)
        user_id: When set, UUID-keyed BYOM custom-model records are
            resolved under this user's scope for the context-window lookup.
Returns:
Token count threshold
"""
total_context = get_token_limit(model_id)
total_context = get_token_limit(model_id, user_id=user_id)
threshold = int(total_context * threshold_percentage)
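    # Example: a 128000-token context window at the default 0.8 threshold
    # triggers compression once the running total reaches 102400 tokens.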
return threshold

View File

@@ -344,18 +344,34 @@ def run_agent_logic(agent_config, input_data):
# Determine model_id: check agent's default_model_id, fallback to system default
agent_default_model = agent_config.get("default_model_id", "")
if agent_default_model and validate_model_id(agent_default_model):
if agent_default_model and validate_model_id(
agent_default_model, user_id=owner
):
model_id = agent_default_model
else:
model_id = get_default_model_id()
if agent_default_model:
# Stored model_id no longer resolves in the registry. Log so
# operators can detect bad YAML edits before users complain;
# behavior matches the historical silent fallback.
logging.warning(
"Agent %s references unknown model_id %r; falling back to %r",
agent_id,
agent_default_model,
model_id,
)
# Get provider and API key for the selected model
provider = get_provider_from_model_id(model_id) if model_id else settings.LLM_PROVIDER
provider = (
get_provider_from_model_id(model_id, user_id=owner)
if model_id
else settings.LLM_PROVIDER
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
# Calculate proper doc_token_limit based on model's context window
doc_token_limit = calculate_doc_token_budget(
model_id=model_id
model_id=model_id, user_id=owner
)
retriever = RetrieverCreator.create_retriever(

View File

@@ -99,6 +99,82 @@ EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2 # You can al
In this case, even though you are using Ollama locally, `LLM_PROVIDER` is set to `openai` because Ollama (and many other local inference engines) are designed to be API-compatible with OpenAI. `OPENAI_BASE_URL` points DocsGPT to the local Ollama server.
## Adding Custom Models (`MODELS_CONFIG_DIR`)
DocsGPT ships with a built-in catalog of models for the providers it
supports out of the box (OpenAI, Anthropic, Google, Groq, OpenRouter,
Novita, Azure OpenAI, Hugging Face, DocsGPT). To add **your own
models** without forking the repo — for example, a Mistral or Together
account, a self-hosted vLLM endpoint, or any other OpenAI-compatible
API — point `MODELS_CONFIG_DIR` at a directory of YAML files.
```
MODELS_CONFIG_DIR=/etc/docsgpt/models
MISTRAL_API_KEY=sk-...
```
A minimal YAML for one provider:
```yaml
# /etc/docsgpt/models/mistral.yaml
provider: openai_compatible
display_provider: mistral
api_key_env: MISTRAL_API_KEY
base_url: https://api.mistral.ai/v1
defaults:
supports_tools: true
context_window: 128000
models:
- id: mistral-large-latest
display_name: Mistral Large
- id: mistral-small-latest
display_name: Mistral Small
```
After restart, those models appear in `/api/models` and are selectable
in the UI. A working template lives at
`application/core/models/examples/mistral.yaml.example`.
**What you can do:**
- Add new `openai_compatible` providers (Mistral, Together, Fireworks,
Ollama, vLLM, ...) — one YAML per provider, each with its own
`api_key_env` and `base_url`.
- Extend an existing provider's catalog by dropping a YAML with the
same `provider:` value as the built-in (e.g. `provider: anthropic`
with extra models).
- Override a built-in model's capabilities by re-declaring the same
`id` — later wins, override is logged at `WARNING`.
**What you cannot do via `MODELS_CONFIG_DIR`:** add a brand-new provider
whose API is not OpenAI-wire-compatible. That requires a Python plugin under
`application/llm/providers/`. See
`application/core/models/README.md` for the full schema reference.
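If you do need such a provider, the plugin surface is intentionally small. A
minimal sketch, assuming a hypothetical `AcmeLLM` class (your `BaseLLM`
subclass) and an `ACME_API_KEY` setting:

```python
# application/llm/providers/acme.py (illustrative; the Acme* names are placeholders)
from __future__ import annotations

from typing import Optional

from application.llm.acme import AcmeLLM
from application.llm.providers.base import Provider


class AcmeProvider(Provider):
    name = "acme"
    llm_class = AcmeLLM

    def get_api_key(self, settings) -> Optional[str]:
        return settings.ACME_API_KEY or None
```

Register the new class in `ALL_PROVIDERS` (in
`application/llm/providers/__init__.py`); if it should also contribute catalog
entries, add a YAML whose `provider:` value matches the plugin's `name`.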
### Docker
Mount the directory and set the env var:
```yaml
# docker-compose.yml
services:
app:
image: arc53/docsgpt
environment:
MODELS_CONFIG_DIR: /etc/docsgpt/models
MISTRAL_API_KEY: ${MISTRAL_API_KEY}
volumes:
- ./my-models:/etc/docsgpt/models:ro
```
### Misconfiguration
If `MODELS_CONFIG_DIR` is set but the path doesn't exist (or isn't a
directory), the app logs a `WARNING` at boot and continues with just
the built-in catalog — it does **not** fail to start. If a YAML
declares an unknown provider name or has a schema error, the app
**does** fail to start, with the offending file path in the message.
## Speech-to-Text Settings
DocsGPT can transcribe audio in two places:

View File

@@ -448,7 +448,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
setUserTools(tools);
};
const getModels = async () => {
const response = await modelService.getModels(null);
const response = await modelService.getModels(token);
if (!response.ok) throw new Error('Failed to fetch models');
const data = await response.json();
const transformed = modelService.transformModels(data.models || []);
@@ -1041,10 +1041,24 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
isOpen={isModelsPopupOpen}
onClose={() => setIsModelsPopupOpen(false)}
anchorRef={modelAnchorButtonRef}
options={availableModels.map((model) => ({
id: model.id,
label: model.display_name,
}))}
options={(() => {
const builtinLabel = t(
'settings.customModels.modelsGroup.builtin',
);
const userLabel = t('settings.customModels.modelsGroup.user');
const builtin: OptionType[] = [];
const user: OptionType[] = [];
availableModels.forEach((model) => {
const opt: OptionType = {
id: model.id,
label: model.display_name,
group: model.source === 'user' ? userLabel : builtinLabel,
};
if (model.source === 'user') user.push(opt);
else builtin.push(opt);
});
return [...builtin, ...user];
})()}
selectedIds={selectedModelIds}
onSelectionChange={(newSelectedIds: Set<string | number>) =>
setSelectedModelIds(

View File

@@ -42,7 +42,9 @@ import { MultiSelect } from '@/components/ui/multi-select';
import {
Select,
SelectContent,
SelectGroup,
SelectItem,
SelectLabel,
SelectTrigger,
SelectValue,
} from '@/components/ui/select';
@@ -706,7 +708,7 @@ function WorkflowBuilderInner() {
useEffect(() => {
const loadModelsAndTools = async () => {
try {
const modelsResponse = await modelService.getModels(null);
const modelsResponse = await modelService.getModels(token);
if (modelsResponse.ok) {
const modelsData = await modelsResponse.json();
const transformedModels = modelService.transformModels(
@@ -732,7 +734,7 @@ function WorkflowBuilderInner() {
}
};
loadModelsAndTools();
}, []);
}, [token]);
useEffect(() => {
if (!selectedNode || selectedNode.type !== 'agent') return;
@@ -1847,15 +1849,54 @@ function WorkflowBuilderInner() {
<SelectValue placeholder="Select a model" />
</SelectTrigger>
<SelectContent>
{availableModels.map((model) => (
<SelectItem
key={model.id}
value={model.id}
>
{model.display_name} ·{' '}
{model.provider}
</SelectItem>
))}
{(() => {
const builtin = availableModels.filter(
(m) => m.source !== 'user',
);
const user = availableModels.filter(
(m) => m.source === 'user',
);
return (
<>
{builtin.length > 0 && (
<SelectGroup>
<SelectLabel>
{t(
'settings.customModels.modelsGroup.builtin',
)}
</SelectLabel>
{builtin.map((model) => (
<SelectItem
key={model.id}
value={model.id}
>
{model.display_name} ·{' '}
{model.provider}
</SelectItem>
))}
</SelectGroup>
)}
{user.length > 0 && (
<SelectGroup>
<SelectLabel>
{t(
'settings.customModels.modelsGroup.user',
)}
</SelectLabel>
{user.map((model) => (
<SelectItem
key={model.id}
value={model.id}
>
{model.display_name} ·{' '}
{model.provider}
</SelectItem>
))}
</SelectGroup>
)}
</>
);
})()}
</SelectContent>
</Select>
</div>

View File

@@ -80,6 +80,22 @@ const apiClient = {
return response;
}),
patch: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'PATCH',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}).then((response) => {
return response;
}),
putFormData: (
url: string,
formData: FormData,

View File

@@ -76,6 +76,10 @@ const endpoints = {
GET_ARTIFACT: (artifactId: string) => `/api/artifact/${artifactId}`,
WORKFLOWS: '/api/workflows',
WORKFLOW: (id: string) => `/api/workflows/${id}`,
CUSTOM_MODELS: '/api/user/models',
CUSTOM_MODEL: (id: string) => `/api/user/models/${id}`,
CUSTOM_MODEL_TEST: (id: string) => `/api/user/models/${id}/test`,
CUSTOM_MODEL_TEST_PAYLOAD: '/api/user/models/test',
},
V1: {
CHAT_COMPLETIONS: '/v1/chat/completions',

View File

@@ -0,0 +1,162 @@
import apiClient from '../client';
import endpoints from '../endpoints';
import type {
CreateCustomModelPayload,
CustomModel,
CustomModelTestResult,
} from '../../models/types';
const parseJsonOrError = async (response: Response): Promise<any> => {
const text = await response.text();
let body: any = null;
if (text) {
try {
body = JSON.parse(text);
} catch {
body = null;
}
}
if (!response.ok) {
const message =
(body && (body.error || body.message)) ||
`Request failed with status ${response.status}`;
const err = new Error(message) as Error & {
status?: number;
payload?: unknown;
};
err.status = response.status;
err.payload = body;
throw err;
}
return body;
};
const customModelsService = {
listCustomModels: async (token: string | null): Promise<CustomModel[]> => {
const response = await apiClient.get(endpoints.USER.CUSTOM_MODELS, token);
const data = await parseJsonOrError(response);
if (Array.isArray(data)) return data as CustomModel[];
if (data && Array.isArray(data.models)) return data.models as CustomModel[];
return [];
},
createCustomModel: async (
payload: CreateCustomModelPayload,
token: string | null,
): Promise<CustomModel> => {
const response = await apiClient.post(
endpoints.USER.CUSTOM_MODELS,
payload,
token,
);
return (await parseJsonOrError(response)) as CustomModel;
},
updateCustomModel: async (
id: string,
payload: Partial<CreateCustomModelPayload>,
token: string | null,
): Promise<CustomModel> => {
const response = await apiClient.patch(
endpoints.USER.CUSTOM_MODEL(id),
payload,
token,
);
return (await parseJsonOrError(response)) as CustomModel;
},
deleteCustomModel: async (
id: string,
token: string | null,
): Promise<void> => {
const response = await apiClient.delete(
endpoints.USER.CUSTOM_MODEL(id),
token,
);
if (!response.ok) {
await parseJsonOrError(response);
}
},
testCustomModelPayload: async (
payload: {
base_url: string;
api_key: string;
upstream_model_id: string;
},
token: string | null,
): Promise<CustomModelTestResult> => {
const response = await apiClient.post(
endpoints.USER.CUSTOM_MODEL_TEST_PAYLOAD,
payload,
token,
);
const text = await response.text();
let body: any = null;
if (text) {
try {
body = JSON.parse(text);
} catch {
body = null;
}
}
if (!response.ok) {
return {
ok: false,
error:
(body && (body.error || body.message)) ||
`Test failed with status ${response.status}`,
};
}
if (body && typeof body.ok === 'boolean') {
return body as CustomModelTestResult;
}
return { ok: true };
},
testCustomModel: async (
id: string,
token: string | null,
overrides: {
base_url?: string;
api_key?: string;
upstream_model_id?: string;
} = {},
): Promise<CustomModelTestResult> => {
// Send only non-empty overrides; server falls back to stored values.
const requestBody: Record<string, string> = {};
if (overrides.base_url) requestBody.base_url = overrides.base_url;
if (overrides.api_key) requestBody.api_key = overrides.api_key;
if (overrides.upstream_model_id)
requestBody.upstream_model_id = overrides.upstream_model_id;
const response = await apiClient.post(
endpoints.USER.CUSTOM_MODEL_TEST(id),
requestBody,
token,
);
const text = await response.text();
let body: any = null;
if (text) {
try {
body = JSON.parse(text);
} catch {
body = null;
}
}
if (!response.ok) {
return {
ok: false,
error:
(body && (body.error || body.message)) ||
`Test failed with status ${response.status}`,
};
}
if (body && typeof body.ok === 'boolean') {
return body as CustomModelTestResult;
}
return { ok: true };
},
};
export default customModelsService;
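
A minimal sketch of the two test paths this service exposes and that the modal further down drives (the import path, ids and credentials are placeholders, not part of the diff): new models go through the payload endpoint and need all three fields, while existing ones go through the by-id endpoint, which accepts partial overrides and falls back to stored values.

import customModelsService from './customModelsService'; // illustrative path

async function testBeforeSave(token: string | null, existingId?: string) {
  if (existingId) {
    // Edit mode: send only what changed; api_key and upstream_model_id
    // fall back to the values stored for this model.
    return customModelsService.testCustomModel(existingId, token, {
      base_url: 'https://api.mistral.ai/v1',
    });
  }
  // Create mode: nothing is stored yet, so all three fields are required.
  return customModelsService.testCustomModelPayload(
    {
      base_url: 'https://api.mistral.ai/v1',
      api_key: 'sk-...',
      upstream_model_id: 'mistral-large-latest',
    },
    token,
  );
}

Both paths resolve to { ok: boolean; error?: string }; non-2xx responses come back as { ok: false, error } rather than throwing, so callers only need a try/catch for network failures.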

View File

@@ -19,6 +19,7 @@ const modelService = {
supports_tools: model.supports_tools,
supports_structured_output: model.supports_structured_output,
supports_streaming: model.supports_streaming,
source: model.source,
})),
};
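
The new source field is what the builtin/user grouping in the WorkflowBuilder select above keys on. A small sketch of that split, assuming the Model type declared in models/types.ts below; anything without source === 'user', including older API responses where the field is absent, lands in the built-in bucket, which is the same rule the select applies:

import type { Model } from '../models/types'; // illustrative path

function splitBySource(models: Model[]): { builtin: Model[]; user: Model[] } {
  return {
    builtin: models.filter((m) => m.source !== 'user'),
    user: models.filter((m) => m.source === 'user'),
  };
}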

View File

@@ -7,6 +7,7 @@ import RoundedTick from '../assets/rounded-tick.svg';
import {
selectAvailableModels,
selectSelectedModel,
selectToken,
setAvailableModels,
setModelsLoading,
setSelectedModel,
@@ -18,17 +19,26 @@ export default function DropdownModel() {
const dispatch = useDispatch();
const selectedModel = useSelector(selectSelectedModel);
const availableModels = useSelector(selectAvailableModels);
const token = useSelector(selectToken);
const dropdownRef = React.useRef<HTMLDivElement>(null);
// Tracks which token the cached availableModels were loaded for.
// Without this, the early-return below pins the anonymous/built-in
// list forever once it's populated — login/logout never refetches
// and a user's BYOM models stay invisible.
const lastLoadedTokenRef = React.useRef<string | null | undefined>(undefined);
const [isOpen, setIsOpen] = React.useState(false);
useEffect(() => {
const loadModels = async () => {
if ((availableModels?.length ?? 0) > 0) {
if (
(availableModels?.length ?? 0) > 0 &&
lastLoadedTokenRef.current === token
) {
return;
}
dispatch(setModelsLoading(true));
try {
const response = await modelService.getModels(null);
const response = await modelService.getModels(token);
if (!response.ok) {
throw new Error(`API error: ${response.status}`);
}
@@ -37,6 +47,7 @@ export default function DropdownModel() {
const transformed = modelService.transformModels(models);
dispatch(setAvailableModels(transformed));
lastLoadedTokenRef.current = token;
if (!selectedModel && transformed.length > 0) {
const defaultModel =
transformed.find((m) => m.id === data.default_model_id) ||
@@ -59,7 +70,7 @@ export default function DropdownModel() {
};
loadModels();
}, [availableModels?.length, dispatch, selectedModel]);
}, [availableModels?.length, dispatch, selectedModel, token]);
const handleClickOutside = (event: MouseEvent) => {
if (

View File

@@ -11,6 +11,7 @@ export type OptionType = {
id: string | number;
label: string;
icon?: string | React.ReactNode;
group?: string;
[key: string]: any;
};
@@ -227,43 +228,75 @@ export default function MultiSelectPopup({
</p>
</div>
) : (
filteredOptions.map((option) => {
const isSelected = selectedIds.has(option.id);
return (
<div
key={option.id}
onClick={() => handleOptionClick(option.id)}
className="dark:border-border dark:hover:bg-accent hover:bg-accent flex cursor-pointer items-center justify-between border-b border-[#D9D9D9] p-3 last:border-b-0"
role="option"
aria-selected={isSelected}
>
<div className="mr-3 flex grow items-center overflow-hidden">
{option.icon && renderIcon(option.icon)}
<p
className="overflow-hidden text-sm font-medium text-ellipsis whitespace-nowrap text-gray-900 dark:text-white"
title={option.label}
>
{option.label}
</p>
</div>
<div className="shrink-0">
<div
className={`border-border bg-card flex h-4 w-4 items-center justify-center rounded-xs border-2`}
aria-hidden="true"
>
{isSelected && (
<img
src={CheckmarkIcon}
alt="checkmark"
width={10}
height={10}
/>
)}
(() => {
const hasGroups = filteredOptions.some((o) => !!o.group);
const renderOption = (option: OptionType) => {
const isSelected = selectedIds.has(option.id);
return (
<div
key={option.id}
onClick={() => handleOptionClick(option.id)}
className="dark:border-border dark:hover:bg-accent hover:bg-accent flex cursor-pointer items-center justify-between border-b border-[#D9D9D9] p-3 last:border-b-0"
role="option"
aria-selected={isSelected}
>
<div className="mr-3 flex grow items-center overflow-hidden">
{option.icon && renderIcon(option.icon)}
<p
className="overflow-hidden text-sm font-medium text-ellipsis whitespace-nowrap text-gray-900 dark:text-white"
title={option.label}
>
{option.label}
</p>
</div>
<div className="shrink-0">
<div
className={`border-border bg-card flex h-4 w-4 items-center justify-center rounded-xs border-2`}
aria-hidden="true"
>
{isSelected && (
<img
src={CheckmarkIcon}
alt="checkmark"
width={10}
height={10}
/>
)}
</div>
</div>
</div>
);
};
if (!hasGroups) {
return filteredOptions.map(renderOption);
}
const groupOrder: string[] = [];
const groupMap = new Map<string, OptionType[]>();
filteredOptions.forEach((opt) => {
const key = opt.group || '';
if (!groupMap.has(key)) {
groupOrder.push(key);
groupMap.set(key, []);
}
groupMap.get(key)!.push(opt);
});
return groupOrder.map((groupKey) => (
<div key={`group-${groupKey || 'ungrouped'}`}>
{groupKey && (
<div
className="bg-muted/50 dark:bg-card text-muted-foreground sticky top-0 z-10 border-b border-[#D9D9D9] px-3 py-1.5 text-xs font-semibold uppercase dark:border-[#2E2E2E]"
role="presentation"
>
{groupKey}
</div>
)}
{(groupMap.get(groupKey) || []).map(renderOption)}
</div>
);
})
));
})()
)}
</div>
)}

View File

@@ -14,6 +14,7 @@ const useTabs = () => {
t('settings.analytics.label'),
t('settings.logs.label'),
t('settings.tools.label'),
t('settings.customModels.label'),
];
return tabs;
};

View File

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
"label": "Eigene Modelle",
"subtitle": "Verbinde dein eigenes OpenAI-kompatibles Modell.",
"addModel": "Modell hinzufügen",
"empty": "Keine Modelle gefunden",
"searchPlaceholder": "Modelle suchen...",
"modalSubtitle": "Verbinde einen beliebigen OpenAI-kompatiblen Endpunkt (Mistral, Together, vLLM usw.).",
"addTitle": "Eigenes Modell hinzufügen",
"editTitle": "Eigenes Modell bearbeiten",
"save": "Speichern",
"saving": "Speichern",
"testConnection": "Verbindung testen",
"testing": "Test läuft",
"testSuccess": "Verbindung erfolgreich.",
"testHintNew": "Fülle Basis-URL, Modell-ID und API-Schlüssel aus, um Verbindungstests zu aktivieren.",
"disabledBadge": "Deaktiviert",
"deleteWarning": "Möchtest du das eigene Modell \"{{modelName}}\" wirklich löschen?",
"actionsMenuAria": "Aktionen für {{modelName}}",
"modelsGroup": {
"builtin": "DocsGPT-Modelle",
"user": "Meine Modelle"
},
"actions": {
"edit": "Bearbeiten",
"delete": "Löschen"
},
"fields": {
"displayName": "Anzeigename",
"modelId": "Modell-ID",
"description": "Beschreibung",
"baseUrl": "Basis-URL",
"apiKey": "API-Schlüssel"
},
"placeholders": {
"displayName": "Mein Mistral",
"modelId": "z. B. mistral-large-latest",
"description": "Optionale Beschreibung",
"baseUrl": "https://api.mistral.ai/v1",
"apiKey": "sk-...",
"apiKeyEdit": "Leer lassen, um den vorhandenen Schlüssel beizubehalten"
},
"hints": {
"apiKeyEdit": "Leer lassen, um den vorhandenen Schlüssel beizubehalten."
},
"capabilities": {
"title": "Fähigkeiten",
"contextWindowShort": "Kontextfenster",
"chips": {
"tools": "Tools",
"structuredOutput": "Strukturierte Ausgabe",
"images": "Bilder"
}
},
"errors": {
"displayNameRequired": "Anzeigename ist erforderlich",
"modelIdRequired": "Modell-ID ist erforderlich",
"baseUrlRequired": "Basis-URL ist erforderlich",
"baseUrlScheme": "Basis-URL muss mit http:// oder https:// beginnen",
"baseUrlInvalid": "Bitte gib eine gültige URL ein",
"apiKeyRequired": "API-Schlüssel ist erforderlich",
"contextWindowRange": "Kontextfenster muss zwischen 1.000 und 10.000.000 liegen",
"saveFailed": "Eigenes Modell konnte nicht gespeichert werden",
"testFailed": "Verbindungstest fehlgeschlagen"
}
},
"scrollTabsLeft": "Tabs nach links scrollen",
"tabsAriaLabel": "Einstellungs-Tabs",
"scrollTabsRight": "Tabs nach rechts scrollen"

View File

@@ -256,6 +256,71 @@
}
}
},
"customModels": {
"label": "Custom Models",
"subtitle": "Bring your own OpenAI-compatible model.",
"addModel": "Add Model",
"empty": "No models found",
"searchPlaceholder": "Search models...",
"modalSubtitle": "Connect any OpenAI-compatible endpoint (Mistral, Together, vLLM, etc).",
"addTitle": "Add Custom Model",
"editTitle": "Edit Custom Model",
"save": "Save",
"saving": "Saving",
"testConnection": "Test connection",
"testing": "Testing",
"testSuccess": "Connection successful.",
"testHintNew": "Fill in Base URL, Model ID, and API key to enable connection tests.",
"disabledBadge": "Disabled",
"deleteWarning": "Are you sure you want to delete the custom model \"{{modelName}}\"?",
"actionsMenuAria": "Actions for {{modelName}}",
"modelsGroup": {
"builtin": "DocsGPT Models",
"user": "My Models"
},
"actions": {
"edit": "Edit",
"delete": "Delete"
},
"fields": {
"displayName": "Display name",
"modelId": "Model ID",
"description": "Description",
"baseUrl": "Base URL",
"apiKey": "API key"
},
"placeholders": {
"displayName": "My Mistral",
"modelId": "e.g. mistral-large-latest",
"description": "Optional description",
"baseUrl": "https://api.mistral.ai/v1",
"apiKey": "sk-...",
"apiKeyEdit": "Leave blank to keep existing"
},
"hints": {
"apiKeyEdit": "Leave blank to keep existing key."
},
"capabilities": {
"title": "Capabilities",
"contextWindowShort": "Context window",
"chips": {
"tools": "Tools",
"structuredOutput": "Structured output",
"images": "Images"
}
},
"errors": {
"displayNameRequired": "Display name is required",
"modelIdRequired": "Model ID is required",
"baseUrlRequired": "Base URL is required",
"baseUrlScheme": "Base URL must start with http:// or https://",
"baseUrlInvalid": "Please enter a valid URL",
"apiKeyRequired": "API key is required",
"contextWindowRange": "Context window must be between 1,000 and 10,000,000",
"saveFailed": "Failed to save custom model",
"testFailed": "Connection test failed"
}
},
"scrollTabsLeft": "Scroll tabs left",
"tabsAriaLabel": "Settings tabs",
"scrollTabsRight": "Scroll tabs right"

View File

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
"label": "Modelos personalizados",
"subtitle": "Trae tu propio modelo compatible con OpenAI.",
"addModel": "Añadir modelo",
"empty": "No se encontraron modelos",
"searchPlaceholder": "Buscar modelos...",
"modalSubtitle": "Conecta cualquier endpoint compatible con OpenAI (Mistral, Together, vLLM, etc.).",
"addTitle": "Añadir modelo personalizado",
"editTitle": "Editar modelo personalizado",
"save": "Guardar",
"saving": "Guardando",
"testConnection": "Probar conexión",
"testing": "Probando",
"testSuccess": "Conexión correcta.",
"testHintNew": "Completa URL base, ID de modelo y clave API para habilitar las pruebas de conexión.",
"disabledBadge": "Deshabilitado",
"deleteWarning": "¿Seguro que quieres eliminar el modelo personalizado \"{{modelName}}\"?",
"actionsMenuAria": "Acciones para {{modelName}}",
"modelsGroup": {
"builtin": "Modelos de DocsGPT",
"user": "Mis modelos"
},
"actions": {
"edit": "Editar",
"delete": "Eliminar"
},
"fields": {
"displayName": "Nombre para mostrar",
"modelId": "ID del modelo",
"description": "Descripción",
"baseUrl": "URL base",
"apiKey": "Clave API"
},
"placeholders": {
"displayName": "Mi Mistral",
"modelId": "p. ej. mistral-large-latest",
"description": "Descripción opcional",
"baseUrl": "https://api.mistral.ai/v1",
"apiKey": "sk-...",
"apiKeyEdit": "Déjalo en blanco para mantener la actual"
},
"hints": {
"apiKeyEdit": "Déjalo en blanco para mantener la clave actual."
},
"capabilities": {
"title": "Capacidades",
"contextWindowShort": "Ventana de contexto",
"chips": {
"tools": "Herramientas",
"structuredOutput": "Salida estructurada",
"images": "Imágenes"
}
},
"errors": {
"displayNameRequired": "El nombre para mostrar es obligatorio",
"modelIdRequired": "El ID del modelo es obligatorio",
"baseUrlRequired": "La URL base es obligatoria",
"baseUrlScheme": "La URL base debe empezar por http:// o https://",
"baseUrlInvalid": "Introduce una URL válida",
"apiKeyRequired": "La clave API es obligatoria",
"contextWindowRange": "La ventana de contexto debe estar entre 1.000 y 10.000.000",
"saveFailed": "No se pudo guardar el modelo personalizado",
"testFailed": "La prueba de conexión falló"
}
},
"scrollTabsLeft": "Desplazar pestañas a la izquierda",
"tabsAriaLabel": "Pestañas de configuración",
"scrollTabsRight": "Desplazar pestañas a la derecha"

View File

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
"label": "カスタムモデル",
"subtitle": "OpenAI 互換のモデルを自分で接続できます。",
"addModel": "モデルを追加",
"empty": "モデルが見つかりません",
"searchPlaceholder": "モデルを検索...",
"modalSubtitle": "OpenAI 互換のエンドポイントを接続できます (Mistral、Together、vLLM など)。",
"addTitle": "カスタムモデルを追加",
"editTitle": "カスタムモデルを編集",
"save": "保存",
"saving": "保存中",
"testConnection": "接続をテスト",
"testing": "テスト中",
"testSuccess": "接続に成功しました。",
"testHintNew": "接続テストを有効にするには、ベース URL、モデル ID、API キーを入力してください。",
"disabledBadge": "無効",
"deleteWarning": "カスタムモデル「{{modelName}}」を削除してもよろしいですか?",
"actionsMenuAria": "{{modelName}} のアクション",
"modelsGroup": {
"builtin": "DocsGPT モデル",
"user": "マイモデル"
},
"actions": {
"edit": "編集",
"delete": "削除"
},
"fields": {
"displayName": "表示名",
"modelId": "モデル ID",
"description": "説明",
"baseUrl": "ベース URL",
"apiKey": "API キー"
},
"placeholders": {
"displayName": "My Mistral",
"modelId": "例: mistral-large-latest",
"description": "説明 (任意)",
"baseUrl": "https://api.mistral.ai/v1",
"apiKey": "sk-...",
"apiKeyEdit": "既存のキーを保持する場合は空欄のまま"
},
"hints": {
"apiKeyEdit": "既存のキーを保持するには空欄のままにしてください。"
},
"capabilities": {
"title": "機能",
"contextWindowShort": "コンテキストウィンドウ",
"chips": {
"tools": "ツール",
"structuredOutput": "構造化出力",
"images": "画像"
}
},
"errors": {
"displayNameRequired": "表示名は必須です",
"modelIdRequired": "モデル ID は必須です",
"baseUrlRequired": "ベース URL は必須です",
"baseUrlScheme": "ベース URL は http:// または https:// で始まる必要があります",
"baseUrlInvalid": "有効な URL を入力してください",
"apiKeyRequired": "API キーは必須です",
"contextWindowRange": "コンテキストウィンドウは 1,000 から 10,000,000 の範囲で指定してください",
"saveFailed": "カスタムモデルの保存に失敗しました",
"testFailed": "接続テストに失敗しました"
}
},
"scrollTabsLeft": "タブを左にスクロール",
"tabsAriaLabel": "設定タブ",
"scrollTabsRight": "タブを右にスクロール"

View File

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
"label": "Свои модели",
"subtitle": "Подключите свою модель, совместимую с OpenAI.",
"addModel": "Добавить модель",
"empty": "Модели не найдены",
"searchPlaceholder": "Поиск моделей...",
"modalSubtitle": "Подключите любой OpenAI-совместимый эндпоинт (Mistral, Together, vLLM и т. п.).",
"addTitle": "Добавить свою модель",
"editTitle": "Изменить свою модель",
"save": "Сохранить",
"saving": "Сохранение",
"testConnection": "Проверить соединение",
"testing": "Проверка",
"testSuccess": "Соединение установлено.",
"testHintNew": "Заполните Base URL, ID модели и API-ключ, чтобы включить проверку соединения.",
"disabledBadge": "Отключено",
"deleteWarning": "Удалить свою модель «{{modelName}}»?",
"actionsMenuAria": "Действия для {{modelName}}",
"modelsGroup": {
"builtin": "Модели DocsGPT",
"user": "Мои модели"
},
"actions": {
"edit": "Изменить",
"delete": "Удалить"
},
"fields": {
"displayName": "Отображаемое имя",
"modelId": "ID модели",
"description": "Описание",
"baseUrl": "Base URL",
"apiKey": "API-ключ"
},
"placeholders": {
"displayName": "Мой Mistral",
"modelId": "напр. mistral-large-latest",
"description": "Необязательное описание",
"baseUrl": "https://api.mistral.ai/v1",
"apiKey": "sk-...",
"apiKeyEdit": "Оставьте пустым, чтобы сохранить текущий"
},
"hints": {
"apiKeyEdit": "Оставьте пустым, чтобы сохранить текущий ключ."
},
"capabilities": {
"title": "Возможности",
"contextWindowShort": "Контекстное окно",
"chips": {
"tools": "Инструменты",
"structuredOutput": "Структурированный вывод",
"images": "Изображения"
}
},
"errors": {
"displayNameRequired": "Укажите отображаемое имя",
"modelIdRequired": "Укажите ID модели",
"baseUrlRequired": "Укажите Base URL",
"baseUrlScheme": "Base URL должен начинаться с http:// или https://",
"baseUrlInvalid": "Введите корректный URL",
"apiKeyRequired": "Укажите API-ключ",
"contextWindowRange": "Контекстное окно должно быть от 1 000 до 10 000 000",
"saveFailed": "Не удалось сохранить свою модель",
"testFailed": "Проверка соединения не удалась"
}
},
"scrollTabsLeft": "Прокрутить вкладки влево",
"tabsAriaLabel": "Вкладки настроек",
"scrollTabsRight": "Прокрутить вкладки вправо"

View File

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
"label": "自訂模型",
"subtitle": "接入你自己的 OpenAI 相容模型。",
"addModel": "新增模型",
"empty": "找不到模型",
"searchPlaceholder": "搜尋模型...",
"modalSubtitle": "可連接任何 OpenAI 相容端點(Mistral、Together、vLLM 等)。",
"addTitle": "新增自訂模型",
"editTitle": "編輯自訂模型",
"save": "儲存",
"saving": "儲存中",
"testConnection": "測試連線",
"testing": "測試中",
"testSuccess": "連線成功。",
"testHintNew": "請填寫 Base URL、模型 ID 與 API 金鑰以啟用連線測試。",
"disabledBadge": "已停用",
"deleteWarning": "確定要刪除自訂模型「{{modelName}}」嗎?",
"actionsMenuAria": "{{modelName}} 的操作",
"modelsGroup": {
"builtin": "DocsGPT 模型",
"user": "我的模型"
},
"actions": {
"edit": "編輯",
"delete": "刪除"
},
"fields": {
"displayName": "顯示名稱",
"modelId": "模型 ID",
"description": "說明",
"baseUrl": "Base URL",
"apiKey": "API 金鑰"
},
"placeholders": {
"displayName": "My Mistral",
"modelId": "例:mistral-large-latest",
"description": "選填說明",
"baseUrl": "https://api.mistral.ai/v1",
"apiKey": "sk-...",
"apiKeyEdit": "留空則保留原金鑰"
},
"hints": {
"apiKeyEdit": "留空以保留現有金鑰。"
},
"capabilities": {
"title": "功能",
"contextWindowShort": "脈絡視窗",
"chips": {
"tools": "工具",
"structuredOutput": "結構化輸出",
"images": "圖片"
}
},
"errors": {
"displayNameRequired": "顯示名稱為必填",
"modelIdRequired": "模型 ID 為必填",
"baseUrlRequired": "Base URL 為必填",
"baseUrlScheme": "Base URL 必須以 http:// 或 https:// 開頭",
"baseUrlInvalid": "請輸入有效的 URL",
"apiKeyRequired": "API 金鑰為必填",
"contextWindowRange": "脈絡視窗須介於 1,000 至 10,000,000 之間",
"saveFailed": "儲存自訂模型失敗",
"testFailed": "連線測試失敗"
}
},
"scrollTabsLeft": "向左捲動標籤",
"tabsAriaLabel": "設定標籤",
"scrollTabsRight": "向右捲動標籤"

View File

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
"label": "自定义模型",
"subtitle": "接入你自己的 OpenAI 兼容模型。",
"addModel": "添加模型",
"empty": "未找到模型",
"searchPlaceholder": "搜索模型...",
"modalSubtitle": "可连接任何 OpenAI 兼容的接口(Mistral、Together、vLLM 等)。",
"addTitle": "添加自定义模型",
"editTitle": "编辑自定义模型",
"save": "保存",
"saving": "保存中",
"testConnection": "测试连接",
"testing": "测试中",
"testSuccess": "连接成功。",
"testHintNew": "请填写 Base URL、模型 ID 和 API 密钥以启用连接测试。",
"disabledBadge": "已禁用",
"deleteWarning": "确定要删除自定义模型「{{modelName}}」吗?",
"actionsMenuAria": "{{modelName}} 的操作",
"modelsGroup": {
"builtin": "DocsGPT 模型",
"user": "我的模型"
},
"actions": {
"edit": "编辑",
"delete": "删除"
},
"fields": {
"displayName": "显示名称",
"modelId": "模型 ID",
"description": "描述",
"baseUrl": "Base URL",
"apiKey": "API 密钥"
},
"placeholders": {
"displayName": "My Mistral",
"modelId": "如 mistral-large-latest",
"description": "可选描述",
"baseUrl": "https://api.mistral.ai/v1",
"apiKey": "sk-...",
"apiKeyEdit": "留空则保留原密钥"
},
"hints": {
"apiKeyEdit": "留空以保留现有密钥。"
},
"capabilities": {
"title": "功能",
"contextWindowShort": "上下文窗口",
"chips": {
"tools": "工具",
"structuredOutput": "结构化输出",
"images": "图片"
}
},
"errors": {
"displayNameRequired": "显示名称必填",
"modelIdRequired": "模型 ID 必填",
"baseUrlRequired": "Base URL 必填",
"baseUrlScheme": "Base URL 必须以 http:// 或 https:// 开头",
"baseUrlInvalid": "请输入有效的 URL",
"apiKeyRequired": "API 密钥必填",
"contextWindowRange": "上下文窗口必须在 1,000 至 10,000,000 之间",
"saveFailed": "保存自定义模型失败",
"testFailed": "连接测试失败"
}
},
"scrollTabsLeft": "向左滚动标签",
"tabsAriaLabel": "设置标签",
"scrollTabsRight": "向右滚动标签"

View File

@@ -0,0 +1,597 @@
import { Check } from 'lucide-react';
import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelector } from 'react-redux';
import customModelsService from '../api/services/customModelsService';
import Spinner from '../components/Spinner';
import { Input } from '../components/ui/input';
import { Label } from '../components/ui/label';
import { ActiveState } from '../models/misc';
import { selectToken } from '../preferences/preferenceSlice';
import WrapperComponent from './WrapperModal';
import type { CreateCustomModelPayload, CustomModel } from '../models/types';
interface CustomModelModalProps {
modalState: ActiveState;
setModalState: (state: ActiveState) => void;
model?: CustomModel | null;
onSaved: (model: CustomModel) => void;
}
interface FormState {
display_name: string;
upstream_model_id: string;
description: string;
base_url: string;
api_key: string;
supports_tools: boolean;
supports_structured_output: boolean;
supports_images: boolean;
context_window: number | '';
enabled: boolean;
}
const DEFAULT_CONTEXT_WINDOW = 128000;
const MIN_CONTEXT_WINDOW = 1000;
const MAX_CONTEXT_WINDOW = 10_000_000;
const buildInitialFormState = (model?: CustomModel | null): FormState => {
if (!model) {
return {
display_name: '',
upstream_model_id: '',
description: '',
base_url: '',
api_key: '',
supports_tools: true,
supports_structured_output: true,
supports_images: false,
context_window: DEFAULT_CONTEXT_WINDOW,
enabled: true,
};
}
const attachments = Array.isArray(model.capabilities?.attachments)
? model.capabilities.attachments
: [];
return {
display_name: model.display_name || '',
upstream_model_id: model.upstream_model_id || '',
description: model.description || '',
base_url: model.base_url || '',
api_key: '',
supports_tools: model.capabilities?.supports_tools ?? true,
supports_structured_output:
model.capabilities?.supports_structured_output ?? true,
supports_images: attachments.includes('image'),
context_window:
model.capabilities?.context_window ?? DEFAULT_CONTEXT_WINDOW,
enabled: model.enabled ?? true,
};
};
export default function CustomModelModal({
modalState,
setModalState,
model,
onSaved,
}: CustomModelModalProps) {
const { t } = useTranslation();
const token = useSelector(selectToken);
const isEditMode = !!model?.id;
const [formData, setFormData] = useState<FormState>(() =>
buildInitialFormState(model),
);
const [errors, setErrors] = useState<{ [key: string]: string }>({});
const [saving, setSaving] = useState(false);
const [testing, setTesting] = useState(false);
const [testResult, setTestResult] = useState<{
ok: boolean;
message: string;
} | null>(null);
useEffect(() => {
if (modalState === 'ACTIVE') {
setFormData(buildInitialFormState(model));
setErrors({});
setTestResult(null);
setSaving(false);
setTesting(false);
}
}, [modalState, model]);
const closeModal = () => {
setModalState('INACTIVE');
};
const handleChange = <K extends keyof FormState>(
name: K,
value: FormState[K],
) => {
setFormData((prev) => ({ ...prev, [name]: value }));
if (errors[name as string] || errors.general) {
setErrors((prev) => {
const next = { ...prev };
delete next[name as string];
delete next.general;
delete next.base_url_remote;
return next;
});
}
setTestResult(null);
};
const validate = (): boolean => {
const newErrors: { [key: string]: string } = {};
if (!formData.display_name.trim()) {
newErrors.display_name = t(
'settings.customModels.errors.displayNameRequired',
);
}
if (!formData.upstream_model_id.trim()) {
newErrors.upstream_model_id = t(
'settings.customModels.errors.modelIdRequired',
);
}
const trimmedUrl = formData.base_url.trim();
if (!trimmedUrl) {
newErrors.base_url = t('settings.customModels.errors.baseUrlRequired');
} else if (!/^https?:\/\//i.test(trimmedUrl)) {
newErrors.base_url = t('settings.customModels.errors.baseUrlScheme');
} else {
try {
new URL(trimmedUrl);
} catch {
newErrors.base_url = t('settings.customModels.errors.baseUrlInvalid');
}
}
if (!isEditMode && !formData.api_key.trim()) {
newErrors.api_key = t('settings.customModels.errors.apiKeyRequired');
}
const ctxValue =
formData.context_window === '' ? NaN : Number(formData.context_window);
if (
Number.isNaN(ctxValue) ||
ctxValue < MIN_CONTEXT_WINDOW ||
ctxValue > MAX_CONTEXT_WINDOW
) {
newErrors.context_window = t(
'settings.customModels.errors.contextWindowRange',
);
}
setErrors(newErrors);
return Object.keys(newErrors).length === 0;
};
const buildPayload = (): CreateCustomModelPayload => {
const ctxValue =
formData.context_window === ''
? DEFAULT_CONTEXT_WINDOW
: Number(formData.context_window);
const payload: CreateCustomModelPayload = {
upstream_model_id: formData.upstream_model_id.trim(),
display_name: formData.display_name.trim(),
description: formData.description.trim(),
base_url: formData.base_url.trim(),
capabilities: {
supports_tools: formData.supports_tools,
supports_structured_output: formData.supports_structured_output,
attachments: formData.supports_images ? ['image'] : [],
context_window: ctxValue,
},
enabled: formData.enabled,
};
if (formData.api_key.trim()) {
payload.api_key = formData.api_key.trim();
}
return payload;
};
const mapErrorToField = (
message: string,
): { field: string; message: string } => {
const lower = message.toLowerCase();
if (
lower.includes('reachable') ||
lower.includes('public internet') ||
lower.includes('ssrf') ||
lower.includes('url') ||
lower.includes('host')
) {
return { field: 'base_url_remote', message };
}
return { field: 'general', message };
};
const handleSave = async () => {
if (!validate()) return;
setSaving(true);
setTestResult(null);
try {
const payload = buildPayload();
const saved = isEditMode
? await customModelsService.updateCustomModel(model!.id, payload, token)
: await customModelsService.createCustomModel(payload, token);
onSaved(saved);
closeModal();
} catch (err) {
const message =
err instanceof Error
? err.message
: t('settings.customModels.errors.saveFailed');
const mapped = mapErrorToField(message);
setErrors((prev) => ({ ...prev, [mapped.field]: mapped.message }));
} finally {
setSaving(false);
}
};
// Edit mode allows blank api_key (by-id endpoint falls back to stored).
const trimmedBaseUrl = formData.base_url.trim();
const trimmedApiKey = formData.api_key.trim();
const trimmedUpstreamId = formData.upstream_model_id.trim();
const canTest = isEditMode
? !!(trimmedBaseUrl && trimmedUpstreamId)
: !!(trimmedBaseUrl && trimmedApiKey && trimmedUpstreamId);
const testDisabledHint = canTest
? undefined
: t('settings.customModels.testHintNew');
const handleTest = async () => {
if (!canTest) return;
setTesting(true);
setTestResult(null);
try {
const result =
isEditMode && model?.id
? await customModelsService.testCustomModel(model.id, token, {
base_url: trimmedBaseUrl,
api_key: trimmedApiKey,
upstream_model_id: trimmedUpstreamId,
})
: await customModelsService.testCustomModelPayload(
{
base_url: trimmedBaseUrl,
api_key: trimmedApiKey,
upstream_model_id: trimmedUpstreamId,
},
token,
);
if (result.ok) {
setTestResult({
ok: true,
message: t('settings.customModels.testSuccess'),
});
} else {
const message =
result.error || t('settings.customModels.errors.testFailed');
setTestResult({ ok: false, message });
const mapped = mapErrorToField(message);
if (mapped.field === 'base_url_remote') {
setErrors((prev) => ({ ...prev, base_url_remote: message }));
}
}
} catch (err) {
const message =
err instanceof Error
? err.message
: t('settings.customModels.errors.testFailed');
setTestResult({ ok: false, message });
} finally {
setTesting(false);
}
};
if (modalState !== 'ACTIVE') return null;
return (
<WrapperComponent
close={closeModal}
isPerformingTask={saving}
className="max-w-[600px] md:w-[80vw] lg:w-[60vw]"
>
<div className="flex h-full flex-col">
<div className="px-2 py-2">
<h2 className="text-foreground dark:text-foreground text-xl font-semibold">
{isEditMode
? t('settings.customModels.editTitle')
: t('settings.customModels.addTitle')}
</h2>
<p className="text-muted-foreground mt-2 text-sm">
{t('settings.customModels.modalSubtitle')}
</p>
</div>
<div className="flex-1 px-2">
<div className="flex flex-col gap-4 px-0.5 py-4">
{/* Row 1: Display name + Model ID side-by-side */}
<div className="grid grid-cols-1 gap-4 sm:grid-cols-2">
<div className="flex flex-col gap-1.5">
<Label htmlFor="cm-display-name">
{t('settings.customModels.fields.displayName')}
<span className="text-red-500">*</span>
</Label>
<Input
id="cm-display-name"
type="text"
value={formData.display_name}
onChange={(e) => handleChange('display_name', e.target.value)}
placeholder={t(
'settings.customModels.placeholders.displayName',
)}
aria-invalid={!!errors.display_name || undefined}
className="rounded-xl"
/>
{errors.display_name && (
<p className="text-destructive text-xs">
{errors.display_name}
</p>
)}
</div>
<div className="flex flex-col gap-1.5">
<Label htmlFor="cm-model-id">
{t('settings.customModels.fields.modelId')}
<span className="text-red-500">*</span>
</Label>
<Input
id="cm-model-id"
type="text"
value={formData.upstream_model_id}
onChange={(e) =>
handleChange('upstream_model_id', e.target.value)
}
placeholder={t('settings.customModels.placeholders.modelId')}
aria-invalid={!!errors.upstream_model_id || undefined}
className="rounded-xl"
/>
{errors.upstream_model_id && (
<p className="text-destructive text-xs">
{errors.upstream_model_id}
</p>
)}
</div>
</div>
{/* Row 2: Base URL + API key side-by-side */}
<div className="grid grid-cols-1 gap-4 sm:grid-cols-2">
<div className="flex flex-col gap-1.5">
<Label htmlFor="cm-base-url">
{t('settings.customModels.fields.baseUrl')}
<span className="text-red-500">*</span>
</Label>
<Input
id="cm-base-url"
type="url"
value={formData.base_url}
onChange={(e) => handleChange('base_url', e.target.value)}
placeholder={t('settings.customModels.placeholders.baseUrl')}
aria-invalid={
!!errors.base_url || !!errors.base_url_remote || undefined
}
className="rounded-xl"
/>
{errors.base_url && (
<p className="text-destructive text-xs">{errors.base_url}</p>
)}
{errors.base_url_remote && (
<p className="text-destructive text-xs">
{errors.base_url_remote}
</p>
)}
</div>
<div className="flex flex-col gap-1.5">
<Label htmlFor="cm-api-key">
{t('settings.customModels.fields.apiKey')}
{!isEditMode && <span className="text-red-500">*</span>}
</Label>
<Input
id="cm-api-key"
type="password"
autoComplete="new-password"
value={formData.api_key}
onChange={(e) => handleChange('api_key', e.target.value)}
placeholder={
isEditMode
? t('settings.customModels.placeholders.apiKeyEdit')
: t('settings.customModels.placeholders.apiKey')
}
aria-invalid={!!errors.api_key || undefined}
className="rounded-xl"
/>
{isEditMode && (
<p className="text-muted-foreground text-xs">
{t('settings.customModels.hints.apiKeyEdit')}
</p>
)}
{errors.api_key && (
<p className="text-destructive text-xs">{errors.api_key}</p>
)}
</div>
</div>
{/* Row 3: Description (full width, optional) */}
<div className="flex flex-col gap-1.5">
<Label htmlFor="cm-description">
{t('settings.customModels.fields.description')}
</Label>
<Input
id="cm-description"
type="text"
value={formData.description}
onChange={(e) => handleChange('description', e.target.value)}
placeholder={t(
'settings.customModels.placeholders.description',
)}
className="rounded-xl"
/>
</div>
{/* Row 4: Capabilities — flat (no border), chips + inline ctx */}
<div className="flex flex-col gap-2">
<Label>{t('settings.customModels.capabilities.title')}</Label>
<div className="flex flex-wrap gap-2">
<CapabilityChip
label={t('settings.customModels.capabilities.chips.tools')}
active={formData.supports_tools}
onClick={() =>
handleChange('supports_tools', !formData.supports_tools)
}
/>
<CapabilityChip
label={t(
'settings.customModels.capabilities.chips.structuredOutput',
)}
active={formData.supports_structured_output}
onClick={() =>
handleChange(
'supports_structured_output',
!formData.supports_structured_output,
)
}
/>
<CapabilityChip
label={t('settings.customModels.capabilities.chips.images')}
active={formData.supports_images}
onClick={() =>
handleChange('supports_images', !formData.supports_images)
}
/>
</div>
<div className="mt-1 flex flex-col gap-1.5 sm:flex-row sm:items-center sm:gap-3">
<Label
htmlFor="cm-context-window"
className="text-muted-foreground text-xs sm:mb-0"
>
{t('settings.customModels.capabilities.contextWindowShort')}
</Label>
<Input
id="cm-context-window"
type="number"
value={formData.context_window}
min={MIN_CONTEXT_WINDOW}
max={MAX_CONTEXT_WINDOW}
step={1000}
onChange={(e) => {
const v = e.target.value;
if (v === '') {
handleChange('context_window', '');
} else {
const n = parseInt(v, 10);
if (!Number.isNaN(n)) {
handleChange('context_window', n);
}
}
}}
aria-invalid={!!errors.context_window || undefined}
className="w-full rounded-xl sm:w-40"
/>
{errors.context_window && (
<p className="text-destructive text-xs">
{errors.context_window}
</p>
)}
</div>
</div>
{testResult && (
<div
className={`rounded-xl p-3 text-sm ${
testResult.ok
? 'bg-green-50 text-green-700 dark:bg-green-900/40 dark:text-green-300'
: 'bg-red-50 text-red-700 dark:bg-red-900/40 dark:text-red-300'
}`}
>
{testResult.message}
</div>
)}
{errors.general && (
<div className="rounded-xl bg-red-50 p-3 text-sm text-red-700 dark:bg-red-900/40 dark:text-red-300">
{errors.general}
</div>
)}
</div>
</div>
<div className="px-2 py-4">
<div className="flex flex-col gap-3 sm:flex-row sm:justify-between">
<button
type="button"
onClick={handleTest}
disabled={!canTest || testing || saving}
title={testDisabledHint}
className="border-border dark:border-border dark:text-foreground hover:bg-accent dark:hover:bg-muted/50 w-full rounded-3xl border px-6 py-2 text-sm font-medium transition-all disabled:cursor-not-allowed disabled:opacity-50 sm:w-auto"
>
{testing ? (
<div className="flex items-center justify-center">
<Spinner size="small" />
<span className="ml-2">
{t('settings.customModels.testing')}
</span>
</div>
) : (
t('settings.customModels.testConnection')
)}
</button>
<div className="flex flex-col-reverse gap-3 sm:flex-row sm:gap-3">
<button
type="button"
onClick={closeModal}
disabled={saving}
className="dark:text-foreground hover:bg-accent dark:hover:bg-muted/50 w-full cursor-pointer rounded-3xl px-6 py-2 text-sm font-medium disabled:opacity-50 sm:w-auto"
>
{t('cancel')}
</button>
<button
type="button"
onClick={handleSave}
disabled={saving}
className="bg-primary hover:bg-primary/90 w-full rounded-3xl px-6 py-2 text-sm font-medium text-white transition-all disabled:opacity-50 sm:w-auto"
>
{saving ? (
<div className="flex items-center justify-center">
<Spinner size="small" />
<span className="ml-2">
{t('settings.customModels.saving')}
</span>
</div>
) : (
t('settings.customModels.save')
)}
</button>
</div>
</div>
</div>
</div>
</WrapperComponent>
);
}
interface CapabilityChipProps {
label: string;
active: boolean;
onClick: () => void;
}
function CapabilityChip({ label, active, onClick }: CapabilityChipProps) {
return (
<button
type="button"
role="switch"
aria-checked={active}
onClick={onClick}
className={`inline-flex items-center gap-1.5 rounded-full border px-3 py-1.5 text-sm transition-colors ${
active
? 'border-emerald-500/50 bg-emerald-500/10 text-emerald-700 dark:border-emerald-400/40 dark:bg-emerald-400/10 dark:text-emerald-300'
: 'border-border text-muted-foreground hover:bg-accent dark:hover:bg-muted/40'
}`}
>
{active && <Check size={14} strokeWidth={2.5} />}
{label}
</button>
);
}

View File

@@ -1,3 +1,5 @@
export type ModelSource = 'builtin' | 'user';
export interface AvailableModel {
id: string;
provider: string;
@@ -9,6 +11,7 @@ export interface AvailableModel {
supports_structured_output: boolean;
supports_streaming: boolean;
enabled: boolean;
source?: ModelSource;
}
export interface Model {
@@ -22,4 +25,38 @@ export interface Model {
supports_tools: boolean;
supports_structured_output: boolean;
supports_streaming: boolean;
source?: ModelSource;
}
export interface CustomModelCapabilities {
supports_tools: boolean;
supports_structured_output: boolean;
attachments: string[];
context_window: number;
}
export interface CustomModel {
id: string;
upstream_model_id: string;
display_name: string;
description?: string;
base_url: string;
capabilities: CustomModelCapabilities;
enabled: boolean;
source: 'user';
}
export interface CreateCustomModelPayload {
upstream_model_id: string;
display_name: string;
description?: string;
base_url: string;
api_key?: string;
capabilities: CustomModelCapabilities;
enabled: boolean;
}
export interface CustomModelTestResult {
ok: boolean;
error?: string;
}
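
For reference, a payload literal that satisfies CreateCustomModelPayload; the values are placeholders, but the shape matches what CustomModelModal's buildPayload assembles and what the /api/user/models create endpoint accepts:

import type { CreateCustomModelPayload } from './types'; // illustrative path

const examplePayload: CreateCustomModelPayload = {
  upstream_model_id: 'mistral-large-latest',
  display_name: 'My Mistral',
  description: 'Optional description',
  base_url: 'https://api.mistral.ai/v1',
  api_key: 'sk-...', // omitted on edit when the stored key should be kept
  capabilities: {
    supports_tools: true,
    supports_structured_output: true,
    attachments: ['image'], // UI alias; raw MIME types such as 'application/pdf' also pass the API validation
    context_window: 128000,
  },
  enabled: true,
};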

View File

@@ -250,6 +250,22 @@ prefListenerMiddleware.startListening({
},
});
// Reconcile selectedModel when availableModels changes so a deleted
// BYOM doesn't leave a stale id in localStorage.
prefListenerMiddleware.startListening({
matcher: isAnyOf(setAvailableModels),
effect: (_action, listenerApi) => {
const state = listenerApi.getState() as RootState;
const { selectedModel, availableModels } = state.preference;
if (!availableModels.length) return;
if (!selectedModel) return;
const stillValid = availableModels.some((m) => m.id === selectedModel.id);
if (!stillValid) {
listenerApi.dispatch(setSelectedModel(availableModels[0]));
}
},
});
export const selectApiKey = (state: RootState) => state.preference.apiKey;
export const selectApiKeyStatus = (state: RootState) =>
!!state.preference.apiKey;

View File

@@ -0,0 +1,313 @@
import { Globe, Tag, Trash } from 'lucide-react';
import React from 'react';
import { useTranslation } from 'react-i18next';
import { useDispatch, useSelector } from 'react-redux';
import customModelsService from '../api/services/customModelsService';
import modelService from '../api/services/modelService';
import Edit from '../assets/edit.svg';
import NoFilesDarkIcon from '../assets/no-files-dark.svg';
import NoFilesIcon from '../assets/no-files.svg';
import SearchIcon from '../assets/search.svg';
import ThreeDotsIcon from '../assets/three-dots.svg';
import ContextMenu, { MenuOption } from '../components/ContextMenu';
import SkeletonLoader from '../components/SkeletonLoader';
import { useDarkTheme, useLoaderState } from '../hooks';
import ConfirmationModal from '../modals/ConfirmationModal';
import CustomModelModal from '../modals/CustomModelModal';
import { ActiveState } from '../models/misc';
import {
selectToken,
setAvailableModels,
} from '../preferences/preferenceSlice';
import type { CustomModel } from '../models/types';
const formatBaseUrlHost = (baseUrl: string): string => {
if (!baseUrl) return '';
try {
const url = new URL(baseUrl);
return url.host || url.hostname || baseUrl;
} catch {
const stripped = baseUrl.replace(/^https?:\/\//i, '');
const slashIdx = stripped.indexOf('/');
return slashIdx >= 0 ? stripped.slice(0, slashIdx) : stripped;
}
};
export default function CustomModels() {
const { t } = useTranslation();
const dispatch = useDispatch();
const token = useSelector(selectToken);
const [isDarkTheme] = useDarkTheme();
const [models, setModels] = React.useState<CustomModel[]>([]);
const [searchTerm, setSearchTerm] = React.useState('');
const [loading, setLoading] = useLoaderState(false);
const [modalState, setModalState] = React.useState<ActiveState>('INACTIVE');
const [editingModel, setEditingModel] = React.useState<CustomModel | null>(
null,
);
const [activeMenuId, setActiveMenuId] = React.useState<string | null>(null);
const menuRefs = React.useRef<{
[key: string]: React.RefObject<HTMLDivElement | null>;
}>({});
const [deleteState, setDeleteState] = React.useState<ActiveState>('INACTIVE');
const [modelToDelete, setModelToDelete] = React.useState<CustomModel | null>(
null,
);
  // Ref instead of useCallback: useLoaderState returns a fresh setter on
  // every render, which would retrigger the effect in a loop (thousands of req/s).
const fetchModelsRef = React.useRef<() => Promise<void>>(async () => {});
fetchModelsRef.current = async () => {
setLoading(true);
try {
const data = await customModelsService.listCustomModels(token);
setModels(data);
} catch (err) {
console.error('Failed to load custom models:', err);
setModels([]);
} finally {
setLoading(false);
}
};
React.useEffect(() => {
fetchModelsRef.current();
}, [token]);
React.useEffect(() => {
models.forEach((model) => {
if (!menuRefs.current[model.id]) {
menuRefs.current[model.id] = React.createRef<HTMLDivElement>();
}
});
}, [models]);
const openAddModal = () => {
setEditingModel(null);
setModalState('ACTIVE');
};
const openEditModal = (model: CustomModel) => {
setEditingModel(model);
setModalState('ACTIVE');
};
// Refresh Redux availableModels so the chat dropdown reconciles a
// selectedModel UUID that was just deleted/disabled.
const refreshGlobalAvailableModels = React.useCallback(async () => {
try {
const response = await modelService.getModels(token);
if (!response.ok) return;
const data = await response.json();
const transformed = modelService.transformModels(data.models || []);
dispatch(setAvailableModels(transformed));
} catch (err) {
console.error('Failed to refresh global available models:', err);
}
}, [dispatch, token]);
const handleSaved = (saved: CustomModel) => {
setModels((prev) => {
const idx = prev.findIndex((m) => m.id === saved.id);
if (idx === -1) return [saved, ...prev];
const next = [...prev];
next[idx] = saved;
return next;
});
refreshGlobalAvailableModels();
};
const requestDelete = (model: CustomModel) => {
setModelToDelete(model);
setDeleteState('ACTIVE');
};
const confirmDelete = async () => {
if (!modelToDelete) return;
try {
await customModelsService.deleteCustomModel(modelToDelete.id, token);
setModels((prev) => prev.filter((m) => m.id !== modelToDelete.id));
refreshGlobalAvailableModels();
} catch (err) {
console.error('Failed to delete custom model:', err);
} finally {
setModelToDelete(null);
setDeleteState('INACTIVE');
}
};
const getMenuOptions = (model: CustomModel): MenuOption[] => [
{
icon: Edit,
label: t('settings.customModels.actions.edit'),
onClick: () => openEditModal(model),
variant: 'primary',
iconWidth: 14,
iconHeight: 14,
},
{
icon: Trash,
label: t('settings.customModels.actions.delete'),
onClick: () => requestDelete(model),
variant: 'danger',
iconWidth: 16,
iconHeight: 16,
},
];
const filteredModels = models.filter((model) => {
const q = searchTerm.toLowerCase();
return (
model.display_name.toLowerCase().includes(q) ||
model.upstream_model_id.toLowerCase().includes(q)
);
});
const renderEmptyState = () => (
<div className="flex w-full flex-col items-center justify-center py-12">
<img
src={isDarkTheme ? NoFilesDarkIcon : NoFilesIcon}
alt={t('settings.customModels.empty')}
className="mx-auto mb-6 h-32 w-32"
/>
<p className="text-center text-lg text-gray-500 dark:text-gray-400">
{t('settings.customModels.empty')}
</p>
</div>
);
return (
<div className="mt-8">
<div className="relative flex flex-col">
<p className="text-muted-foreground mb-5 text-[15px] leading-6">
{t('settings.customModels.subtitle')}
</p>
<div className="my-3 flex flex-col items-start gap-3 sm:flex-row sm:items-center sm:justify-between">
<div className="relative w-full max-w-md">
<img
src={SearchIcon}
alt=""
className="absolute top-1/2 left-4 h-5 w-5 -translate-y-1/2 opacity-40"
/>
<input
maxLength={256}
placeholder={t('settings.customModels.searchPlaceholder')}
name="custom-models-search-input"
type="text"
id="custom-models-search-input"
value={searchTerm}
onChange={(e) => setSearchTerm(e.target.value)}
className="border-border bg-card text-foreground placeholder:text-muted-foreground h-11 w-full rounded-full border py-2 pr-5 pl-11 text-sm shadow-[0_1px_4px_rgba(0,0,0,0.06)] transition-shadow outline-none focus:shadow-[0_2px_8px_rgba(0,0,0,0.1)] dark:shadow-none"
/>
</div>
<button
className="bg-primary hover:bg-primary/90 flex h-11 min-w-[108px] items-center justify-center rounded-full px-4 text-sm whitespace-normal text-white"
onClick={openAddModal}
>
{t('settings.customModels.addModel')}
</button>
</div>
<div className="border-border dark:border-border mt-5 mb-8 border-b" />
{loading ? (
<div className="flex flex-wrap justify-center gap-4 sm:justify-start">
<SkeletonLoader component="toolCards" count={3} />
</div>
) : (
<div className="flex flex-wrap justify-center gap-4 sm:justify-start">
{filteredModels.length === 0
? renderEmptyState()
: filteredModels.map((model) => (
<div
key={model.id}
className="bg-muted hover:bg-accent relative flex w-[300px] flex-col overflow-hidden rounded-2xl p-5"
>
<div
ref={menuRefs.current[model.id]}
onClick={(e) => {
e.stopPropagation();
setActiveMenuId(
activeMenuId === model.id ? null : model.id,
);
}}
className="absolute top-3 right-3 z-10 cursor-pointer"
>
<img
src={ThreeDotsIcon}
alt={t('settings.customModels.actionsMenuAria', {
modelName: model.display_name,
})}
className="h-[19px] w-[19px]"
/>
<ContextMenu
isOpen={activeMenuId === model.id}
setIsOpen={(isOpen) => {
setActiveMenuId(isOpen ? model.id : null);
}}
options={getMenuOptions(model)}
anchorRef={menuRefs.current[model.id]}
position="bottom-right"
offset={{ x: 0, y: 0 }}
/>
</div>
<div className="w-full pr-7">
<div className="flex items-center gap-2">
<p
title={model.display_name}
className="text-foreground dark:text-foreground truncate text-[15px] leading-snug font-semibold"
>
{model.display_name}
</p>
{!model.enabled && (
<span className="bg-muted-foreground/15 text-muted-foreground shrink-0 rounded-full px-2 py-0.5 text-[10px] leading-none font-medium">
{t('settings.customModels.disabledBadge')}
</span>
)}
</div>
<div className="mt-3 space-y-1.5">
<div
className="text-muted-foreground/80 flex items-center gap-1.5 text-xs leading-relaxed"
title={model.upstream_model_id}
>
<Tag className="h-3.5 w-3.5 shrink-0 opacity-70" />
<span className="truncate">
{model.upstream_model_id}
</span>
</div>
<div
className="text-muted-foreground/80 flex items-center gap-1.5 text-xs leading-relaxed"
title={model.base_url}
>
<Globe className="h-3.5 w-3.5 shrink-0 opacity-70" />
<span className="truncate">
{formatBaseUrlHost(model.base_url)}
</span>
</div>
</div>
</div>
</div>
))}
</div>
)}
</div>
<CustomModelModal
modalState={modalState}
setModalState={setModalState}
model={editingModel}
onSaved={handleSaved}
/>
<ConfirmationModal
message={t('settings.customModels.deleteWarning', {
modelName: modelToDelete?.display_name || '',
})}
modalState={deleteState}
setModalState={setDeleteState}
handleSubmit={confirmDelete}
submitLabel={t('settings.customModels.actions.delete')}
variant="danger"
/>
</div>
);
}
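
The fetchModelsRef trick above is worth a standalone sketch: when a hook hands back a setter with a fresh identity on every render (as useLoaderState does here), wrapping the fetcher in useCallback and listing it as an effect dependency re-runs the effect on every render. A stripped-down version of the pattern, under those assumptions (the hook name is made up for illustration):

import React from 'react';

function useStableFetcher(
  loadData: () => Promise<void>,
  deps: React.DependencyList,
) {
  const ref = React.useRef(loadData);
  // Reassigned every render so the ref always sees the latest closure...
  ref.current = loadData;
  // ...while the effect only re-runs when the declared deps (e.g. token) change.
  React.useEffect(() => {
    ref.current();
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, deps);
}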

View File

@@ -21,6 +21,7 @@ import {
setSourceDocs,
} from '../preferences/preferenceSlice';
import Analytics from './Analytics';
import CustomModels from './CustomModels';
import Sources from './Sources';
import General from './General';
import Logs from './Logs';
@@ -39,6 +40,8 @@ export default function Settings() {
return t('settings.analytics.label');
if (path.includes('/settings/logs')) return t('settings.logs.label');
if (path.includes('/settings/tools')) return t('settings.tools.label');
if (path.includes('/settings/custom-models'))
return t('settings.customModels.label');
return t('settings.general.label');
};
@@ -52,6 +55,8 @@ export default function Settings() {
navigate('/settings/analytics');
else if (tab === t('settings.logs.label')) navigate('/settings/logs');
else if (tab === t('settings.tools.label')) navigate('/settings/tools');
else if (tab === t('settings.customModels.label'))
navigate('/settings/custom-models');
};
React.useEffect(() => {
@@ -113,6 +118,7 @@ export default function Settings() {
<Route path="analytics" element={<Analytics />} />
<Route path="logs" element={<Logs />} />
<Route path="tools" element={<Tools />} />
<Route path="custom-models" element={<CustomModels />} />
<Route path="*" element={<Navigate to="/settings" replace />} />
</Routes>
</div>

View File

@@ -116,7 +116,7 @@ def test_execute_agent_node_normalizes_wrapped_schema_before_agent_create(monkey
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _model_id: {"supports_structured_output": True},
lambda _model_id, **_kwargs: {"supports_structured_output": True},
)
list(engine._execute_agent_node(node))
@@ -242,7 +242,7 @@ def test_validate_workflow_structure_rejects_unsupported_structured_output_model
monkeypatch.setattr(
workflow_routes,
"get_model_capabilities",
lambda _model_id: {"supports_structured_output": False},
lambda _model_id, **_kwargs: {"supports_structured_output": False},
)
nodes = [
@@ -296,7 +296,7 @@ def test_execute_agent_node_raises_when_structured_output_violates_schema(monkey
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _model_id: {"supports_structured_output": True},
lambda _model_id, **_kwargs: {"supports_structured_output": True},
)
with pytest.raises(ValueError, match="Structured output did not match schema"):
@@ -322,7 +322,7 @@ def test_execute_agent_node_raises_when_schema_set_and_response_not_json(monkeyp
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _model_id: {"supports_structured_output": True},
lambda _model_id, **_kwargs: {"supports_structured_output": True},
)
with pytest.raises(
@@ -360,11 +360,11 @@ class TestWorkflowEngineAdditionalCoverage:
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda _: None,
lambda _, **_kwargs: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _: None,
lambda _, **_kwargs: None,
)
list(engine._execute_agent_node(node))
@@ -390,11 +390,11 @@ class TestWorkflowEngineAdditionalCoverage:
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda _: "openai",
lambda _, **_kwargs: "openai",
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _: None,
lambda _, **_kwargs: None,
)
list(engine._execute_agent_node(node))
@@ -416,11 +416,11 @@ class TestWorkflowEngineAdditionalCoverage:
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda _: "openai",
lambda _, **_kwargs: "openai",
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _: {"supports_structured_output": False},
lambda _, **_kwargs: {"supports_structured_output": False},
)
with pytest.raises(ValueError, match="does not support structured output"):
@@ -450,11 +450,11 @@ class TestWorkflowEngineAdditionalCoverage:
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda _: None,
lambda _, **_kwargs: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _: {"supports_structured_output": True},
lambda _, **_kwargs: {"supports_structured_output": True},
)
list(engine._execute_agent_node(node))
@@ -481,11 +481,11 @@ class TestWorkflowEngineAdditionalCoverage:
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda _: None,
lambda _, **_kwargs: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _: {"supports_structured_output": True},
lambda _, **_kwargs: {"supports_structured_output": True},
)
with pytest.raises(

View File

@@ -162,8 +162,10 @@ class TestCompressIfNeeded:
current_query_tokens=1000,
)
# user_id flows through so BYOM custom-model UUIDs resolve to
# the user's declared context window in the threshold check.
mock_threshold_checker.should_compress.assert_called_once_with(
sample_conversation, "gpt-4", 1000
sample_conversation, "gpt-4", 1000, user_id="user1"
)
@@ -270,8 +272,12 @@ class TestPerformCompression:
)
orch._perform_compression("c1", conversation, "gpt-4", decoded_token)
# Verify the override model was used
mock_get_provider.assert_called_with("gpt-3.5-turbo")
# Verify the override model was used. user_id flows from
# decoded_token['sub'] so per-user BYOM custom-model UUIDs
# resolve.
mock_get_provider.assert_called_with(
"gpt-3.5-turbo", user_id=decoded_token["sub"]
)
@patch(
"application.api.answer.services.compression.orchestrator.get_provider_from_model_id"
@@ -362,7 +368,12 @@ class TestCompressMidExecution:
)
mock_perform.assert_called_once_with(
"conv1", sample_conversation, "gpt-4", decoded_token
"conv1",
sample_conversation,
"gpt-4",
decoded_token,
user_id="user1",
model_user_id=None,
)
def test_loads_conversation_when_not_provided(

View File

@@ -200,7 +200,7 @@ class TestSetupPeriodicTasks:
setup_periodic_tasks(sender)
assert sender.add_periodic_task.call_count == 4
assert sender.add_periodic_task.call_count == 5
calls = sender.add_periodic_task.call_args_list
@@ -212,6 +212,8 @@ class TestSetupPeriodicTasks:
assert calls[2][0][0] == timedelta(days=30)
# pending_tool_state TTL cleanup (60s)
assert calls[3][0][0] == timedelta(seconds=60)
# version-check (every 7h)
assert calls[4][0][0] == timedelta(hours=7)
class TestMcpOauthTask:

View File

@@ -0,0 +1,688 @@
"""Tests for the BYOM REST API at /api/user/models."""
from __future__ import annotations
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
@pytest.fixture
def app():
return Flask(__name__)
@contextmanager
def _patch_db(conn):
"""Patch the routes' db helpers to yield the given pg connection."""
@contextmanager
def _yield_conn():
yield conn
with patch(
"application.api.user.models.routes.db_session", _yield_conn
), patch(
"application.api.user.models.routes.db_readonly", _yield_conn
):
yield
@pytest.fixture(autouse=True)
def _reset_registry():
from application.core.model_registry import ModelRegistry
ModelRegistry.reset()
yield
ModelRegistry.reset()
# Auth
@pytest.mark.unit
class TestAuth:
def test_list_unauthenticated_returns_401(self, app):
from application.api.user.models.routes import (
UserModelsCollectionResource,
)
with app.test_request_context("/api/user/models"):
from flask import request
request.decoded_token = None
resp = UserModelsCollectionResource().get()
assert resp.status_code == 401
def test_create_unauthenticated_returns_401(self, app):
from application.api.user.models.routes import (
UserModelsCollectionResource,
)
with app.test_request_context(
"/api/user/models",
method="POST",
json={
"upstream_model_id": "x",
"display_name": "x",
"base_url": "https://api.openai.com/v1",
"api_key": "k",
},
):
from flask import request
request.decoded_token = None
resp = UserModelsCollectionResource().post()
assert resp.status_code == 401
# Create
@pytest.mark.unit
class TestCreate:
def test_creates_and_returns_201_without_api_key(self, app, pg_conn):
from application.api.user.models.routes import (
UserModelsCollectionResource,
)
# Mock DNS so the SSRF check passes for api.mistral.ai without
# hitting the network.
with patch("application.security.safe_url.socket.getaddrinfo") as gai:
gai.return_value = [
(None, None, None, None, ("104.18.0.1", 0))
]
with app.test_request_context(
"/api/user/models",
method="POST",
json={
"upstream_model_id": "mistral-large-latest",
"display_name": "My Mistral",
"base_url": "https://api.mistral.ai/v1",
"api_key": "sk-mistral-test",
"capabilities": {
"supports_tools": True,
"context_window": 128000,
},
},
):
from flask import request
request.decoded_token = {"sub": "user-1"}
with _patch_db(pg_conn):
resp = UserModelsCollectionResource().post()
assert resp.status_code == 201
body = resp.get_json()
assert body["upstream_model_id"] == "mistral-large-latest"
assert body["source"] == "user"
# Critical: api_key must NEVER appear in the response
assert "api_key" not in body
for v in body.values():
assert v != "sk-mistral-test"
def test_create_rejects_missing_required_fields(self, app, pg_conn):
from application.api.user.models.routes import (
UserModelsCollectionResource,
)
with app.test_request_context(
"/api/user/models",
method="POST",
json={"upstream_model_id": "x"}, # missing the others
):
from flask import request
request.decoded_token = {"sub": "user-1"}
with _patch_db(pg_conn):
resp = UserModelsCollectionResource().post()
assert resp.status_code == 400
def test_create_rejects_loopback_url(self, app, pg_conn):
from application.api.user.models.routes import (
UserModelsCollectionResource,
)
with app.test_request_context(
"/api/user/models",
method="POST",
json={
"upstream_model_id": "x",
"display_name": "x",
"base_url": "https://127.0.0.1/v1",
"api_key": "k",
},
):
from flask import request
request.decoded_token = {"sub": "user-1"}
with _patch_db(pg_conn):
resp = UserModelsCollectionResource().post()
assert resp.status_code == 400
body = resp.get_json()
assert "error" in body
def test_create_rejects_unknown_attachment_alias(self, app, pg_conn):
"""The UI sends ``["image"]`` as an alias; bad strings ("video",
typos) must reject at the boundary so the DB never holds
garbage that the registry would later silently drop.
"""
from application.api.user.models.routes import (
UserModelsCollectionResource,
)
with patch("application.security.safe_url.socket.getaddrinfo") as gai:
gai.return_value = [
(None, None, None, None, ("104.18.0.1", 0))
]
with app.test_request_context(
"/api/user/models",
method="POST",
json={
"upstream_model_id": "m",
"display_name": "M",
"base_url": "https://api.mistral.ai/v1",
"api_key": "k",
"capabilities": {"attachments": ["video"]},
},
):
from flask import request
request.decoded_token = {"sub": "user-1"}
with _patch_db(pg_conn):
resp = UserModelsCollectionResource().post()
assert resp.status_code == 400
body = resp.get_json()
assert "video" in body["error"]
def test_create_accepts_image_alias_and_raw_mime(self, app, pg_conn):
"""The known ``image`` alias and raw MIME types both pass."""
from application.api.user.models.routes import (
UserModelsCollectionResource,
)
with patch("application.security.safe_url.socket.getaddrinfo") as gai:
gai.return_value = [
(None, None, None, None, ("104.18.0.1", 0))
]
with app.test_request_context(
"/api/user/models",
method="POST",
json={
"upstream_model_id": "m",
"display_name": "M",
"base_url": "https://api.mistral.ai/v1",
"api_key": "k",
"capabilities": {
"attachments": ["image", "application/pdf"],
},
},
):
from flask import request
request.decoded_token = {"sub": "user-1"}
with _patch_db(pg_conn):
resp = UserModelsCollectionResource().post()
assert resp.status_code == 201
def test_create_rejects_private_ip_dns(self, app, pg_conn):
from application.api.user.models.routes import (
UserModelsCollectionResource,
)
with patch("application.security.safe_url.socket.getaddrinfo") as gai:
# Hostname resolves to a private IP only — must reject
gai.return_value = [
(None, None, None, None, ("10.0.0.5", 0))
]
with app.test_request_context(
"/api/user/models",
method="POST",
json={
"upstream_model_id": "x",
"display_name": "x",
"base_url": "https://evil.example.com/v1",
"api_key": "k",
},
):
from flask import request
request.decoded_token = {"sub": "user-1"}
with _patch_db(pg_conn):
resp = UserModelsCollectionResource().post()
assert resp.status_code == 400
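# Minimal sketches (assumed, not the route's actual helpers) of the two
# create-time boundary checks exercised above: the SSRF guard resolves the
# hostname and refuses URLs whose addresses are loopback/private, and the
# attachment validator accepts only known aliases or raw MIME strings.
import ipaddress
import socket
from urllib.parse import urlsplit

def _base_url_is_safe(base_url: str) -> bool:
    host = urlsplit(base_url).hostname
    if not host:
        return False
    try:
        infos = socket.getaddrinfo(host, None)
    except socket.gaierror:
        return False
    # Every resolved address must be globally routable (no 127.x, 10.x, ...).
    return all(ipaddress.ip_address(info[4][0]).is_global for info in infos)

_KNOWN_ATTACHMENT_ALIASES = {"image"}  # illustrative; the real set lives in the loader

def _validate_attachments(entries):
    for entry in entries:
        if entry in _KNOWN_ATTACHMENT_ALIASES or "/" in entry:  # raw MIME passes through
            continue
        raise ValueError(f"unknown attachment type: {entry}")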
# List / get / patch / delete
def _create_via_repo(pg_conn, user_id="user-1", **kwargs):
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
return UserCustomModelsRepository(pg_conn).create(
user_id=user_id,
upstream_model_id=kwargs.pop("upstream_model_id", "mistral-large-latest"),
display_name=kwargs.pop("display_name", "My Mistral"),
base_url=kwargs.pop("base_url", "https://api.mistral.ai/v1"),
api_key_plaintext=kwargs.pop("api_key_plaintext", "sk-mistral-test"),
**kwargs,
)
@pytest.mark.unit
class TestList:
def test_lists_only_users_own(self, app, pg_conn):
from application.api.user.models.routes import (
UserModelsCollectionResource,
)
_create_via_repo(pg_conn, user_id="alice", upstream_model_id="alice-1")
_create_via_repo(pg_conn, user_id="bob", upstream_model_id="bob-1")
with app.test_request_context("/api/user/models"):
from flask import request
request.decoded_token = {"sub": "alice"}
with _patch_db(pg_conn):
resp = UserModelsCollectionResource().get()
assert resp.status_code == 200
body = resp.get_json()
upstream_ids = {m["upstream_model_id"] for m in body["models"]}
assert upstream_ids == {"alice-1"}
# Never expose the api_key
for m in body["models"]:
assert "api_key" not in m
@pytest.mark.unit
class TestGet:
def test_returns_404_for_other_users_model(self, app, pg_conn):
from application.api.user.models.routes import UserModelResource
created = _create_via_repo(pg_conn, user_id="alice")
with app.test_request_context(
f"/api/user/models/{created['id']}"
):
from flask import request
request.decoded_token = {"sub": "bob"}
with _patch_db(pg_conn):
resp = UserModelResource().get(model_id=created["id"])
assert resp.status_code == 404
@pytest.mark.unit
class TestPatch:
def test_patch_updates_display_name(self, app, pg_conn):
from application.api.user.models.routes import UserModelResource
created = _create_via_repo(pg_conn, user_id="user-1")
with app.test_request_context(
f"/api/user/models/{created['id']}",
method="PATCH",
json={"display_name": "Renamed"},
):
from flask import request
request.decoded_token = {"sub": "user-1"}
with _patch_db(pg_conn):
resp = UserModelResource().patch(model_id=created["id"])
assert resp.status_code == 200
body = resp.get_json()
assert body["display_name"] == "Renamed"
def test_patch_blank_api_key_keeps_existing(self, app, pg_conn):
"""Critical PATCH semantic: empty/missing api_key in body must
preserve the stored ciphertext (the UI sends a blank password
field when the user wants to keep the existing key)."""
from application.api.user.models.routes import UserModelResource
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
created = _create_via_repo(pg_conn, user_id="user-1")
original_key_plaintext = "sk-mistral-test"
with app.test_request_context(
f"/api/user/models/{created['id']}",
method="PATCH",
json={"display_name": "Just rename me", "api_key": ""},
):
from flask import request
request.decoded_token = {"sub": "user-1"}
with _patch_db(pg_conn):
resp = UserModelResource().patch(model_id=created["id"])
assert resp.status_code == 200
# Decrypted key is unchanged
repo = UserCustomModelsRepository(pg_conn)
assert (
repo.get_decrypted_api_key(created["id"], "user-1")
== original_key_plaintext
)
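# Sketch of the PATCH key rule the test above pins down (assumed helper, not
# the actual route code): a blank or missing api_key means "keep the stored
# ciphertext"; only a non-blank submission is re-encrypted and written.
def _api_key_update(body: dict, encrypt):
    submitted = (body.get("api_key") or "").strip()
    if not submitted:
        return None  # caller leaves the api_key column untouched
    return encrypt(submitted)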
@pytest.mark.unit
class TestDelete:
def test_delete_removes_row_and_invalidates_cache(self, app, pg_conn):
from application.api.user.models.routes import UserModelResource
from application.core.model_registry import ModelRegistry
created = _create_via_repo(pg_conn, user_id="user-1")
# Warm the registry's per-user cache via a lookup
with patch(
"application.storage.db.session.db_readonly"
) as ro:
@contextmanager
def _y():
yield pg_conn
ro.side_effect = _y
ModelRegistry.get_instance().get_model(created["id"], user_id="user-1")
# Now delete via the route — invalidation must happen so a
# subsequent lookup misses
with app.test_request_context(
f"/api/user/models/{created['id']}", method="DELETE"
):
from flask import request
request.decoded_token = {"sub": "user-1"}
with _patch_db(pg_conn):
resp = UserModelResource().delete(model_id=created["id"])
assert resp.status_code == 200
# Cache invalidated → next lookup re-queries DB and finds nothing
assert "user-1" not in ModelRegistry.get_instance()._user_models
# /api/models combined view
@pytest.mark.unit
class TestSecurityCreateRejectsBlankFields:
"""P1 #1 partial: blank api_key on create must be rejected so we
can never end up with an unroutable BYOM record that would cause
LLMCreator to leak settings.API_KEY to the user-supplied URL."""
def test_create_rejects_blank_api_key(self, app, pg_conn):
from application.api.user.models.routes import (
UserModelsCollectionResource,
)
with patch("application.security.safe_url.socket.getaddrinfo") as gai:
gai.return_value = [(None, None, None, None, ("104.18.0.1", 0))]
with app.test_request_context(
"/api/user/models",
method="POST",
json={
"upstream_model_id": "x",
"display_name": "x",
"base_url": "https://api.mistral.ai/v1",
"api_key": " ", # whitespace only
},
):
from flask import request
request.decoded_token = {"sub": "u"}
with _patch_db(pg_conn):
resp = UserModelsCollectionResource().post()
assert resp.status_code == 400
body = resp.get_json()
assert "api_key" in (body.get("error") or "").lower()
def test_patch_rejects_blank_required_field(self, app, pg_conn):
from application.api.user.models.routes import UserModelResource
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
created = UserCustomModelsRepository(pg_conn).create(
user_id="u",
upstream_model_id="x",
display_name="x",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="sk-x",
)
with app.test_request_context(
f"/api/user/models/{created['id']}",
method="PATCH",
json={"base_url": " "},
):
from flask import request
request.decoded_token = {"sub": "u"}
with _patch_db(pg_conn):
resp = UserModelResource().patch(model_id=created["id"])
assert resp.status_code == 400
@pytest.mark.unit
class TestPayloadConnectionTest:
"""Verifies the payload-based test endpoint. Lets the UI's 'Test
connection' button work *before* the model is saved — operators
expect to validate their endpoint + key before committing."""
def test_payload_test_rejects_unsafe_url(self, app, pg_conn):
from application.api.user.models.routes import (
UserModelTestPayloadResource,
)
with app.test_request_context(
"/api/user/models/test",
method="POST",
json={
"base_url": "https://127.0.0.1/v1",
"api_key": "sk-anything",
"upstream_model_id": "x",
},
):
from flask import request
request.decoded_token = {"sub": "user-1"}
with _patch_db(pg_conn):
resp = UserModelTestPayloadResource().post()
assert resp.status_code == 400
body = resp.get_json()
assert body["ok"] is False
def test_payload_test_returns_ok_when_upstream_responds_2xx(
self, app, pg_conn
):
from application.api.user.models.routes import (
UserModelTestPayloadResource,
)
# pinned_post is the IP-pinned dispatch helper. Patching it
# bypasses both the SSRF guard and the network — the success
# path we're verifying here is the route's response handling.
with patch("application.api.user.models.routes.pinned_post") as rp:
rp.return_value = MagicMock(
status_code=200,
headers={"Content-Type": "application/json"},
text='{"ok": true}',
)
with app.test_request_context(
"/api/user/models/test",
method="POST",
json={
"base_url": "https://api.mistral.ai/v1",
"api_key": "sk-mistral-test",
"upstream_model_id": "mistral-large-latest",
},
):
from flask import request
request.decoded_token = {"sub": "user-1"}
with _patch_db(pg_conn):
resp = UserModelTestPayloadResource().post()
assert resp.status_code == 200
body = resp.get_json()
assert body["ok"] is True
# Verify the upstream call carried the user's submitted key (not
# whatever's in the DB) and the right model name.
call_args = rp.call_args
assert call_args.kwargs["headers"]["Authorization"] == "Bearer sk-mistral-test"
assert call_args.kwargs["json"]["model"] == "mistral-large-latest"
def test_payload_test_unauthenticated_returns_401(self, app, pg_conn):
from application.api.user.models.routes import (
UserModelTestPayloadResource,
)
with app.test_request_context(
"/api/user/models/test",
method="POST",
json={
"base_url": "https://api.mistral.ai/v1",
"api_key": "k",
"upstream_model_id": "x",
},
):
from flask import request
request.decoded_token = None
with _patch_db(pg_conn):
resp = UserModelTestPayloadResource().post()
assert resp.status_code == 401
def test_payload_test_missing_fields_returns_400(self, app, pg_conn):
from application.api.user.models.routes import (
UserModelTestPayloadResource,
)
with app.test_request_context(
"/api/user/models/test",
method="POST",
json={"base_url": "https://api.mistral.ai/v1"},
):
from flask import request
request.decoded_token = {"sub": "user-1"}
with _patch_db(pg_conn):
resp = UserModelTestPayloadResource().post()
assert resp.status_code == 400
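# Rough shape of the upstream probe the payload-test route is expected to
# issue (assumed; the real call goes through pinned_post, which additionally
# pins DNS to a pre-validated IP before dispatching). The URL path, Bearer
# header and model field match what the tests above assert; the message body
# and timeout are illustrative assumptions.
def _probe_chat_completions(post, base_url, api_key, upstream_model_id):
    return post(
        f"{base_url.rstrip('/')}/chat/completions",
        headers={"Authorization": f"Bearer {api_key}"},
        json={
            "model": upstream_model_id,
            "messages": [{"role": "user", "content": "ping"}],
            "max_tokens": 1,
        },
        timeout=10,
    )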
@pytest.mark.unit
class TestByIdConnectionTestAcceptsOverrides:
"""P3: in edit mode the modal sends current form state as overrides
so the test reflects in-flight edits (not the saved record)."""
def _make_row(self, pg_conn):
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
return UserCustomModelsRepository(pg_conn).create(
user_id="u",
upstream_model_id="stored-model",
display_name="Stored",
base_url="https://stored.example.com/v1",
api_key_plaintext="sk-stored",
)
def _post_test(self, app, pg_conn, model_id, body):
from application.api.user.models.routes import UserModelTestResource
with patch("application.api.user.models.routes.pinned_post") as rp:
rp.return_value = MagicMock(
status_code=200,
headers={"Content-Type": "application/json"},
text='{"ok": true}',
)
with app.test_request_context(
f"/api/user/models/{model_id}/test", method="POST", json=body
):
from flask import request
request.decoded_token = {"sub": "u"}
with _patch_db(pg_conn):
UserModelTestResource().post(model_id=model_id)
return rp.call_args
def test_overrides_win_when_supplied(self, app, pg_conn):
row = self._make_row(pg_conn)
ca = self._post_test(
app,
pg_conn,
row["id"],
{
"base_url": "https://new.example.com/v1",
"api_key": "sk-new",
"upstream_model_id": "new-model",
},
)
assert ca.args[0] == "https://new.example.com/v1/chat/completions"
assert ca.kwargs["headers"]["Authorization"] == "Bearer sk-new"
assert ca.kwargs["json"]["model"] == "new-model"
def test_blank_overrides_fall_back_to_stored(self, app, pg_conn):
"""The classic edit-mode flow: user changed base_url, left
api_key blank — server uses the new URL but the stored key."""
row = self._make_row(pg_conn)
ca = self._post_test(
app,
pg_conn,
row["id"],
{
"base_url": "https://new.example.com/v1",
"api_key": "",
"upstream_model_id": "",
},
)
assert ca.args[0] == "https://new.example.com/v1/chat/completions"
# Stored key (decrypted) was used.
assert ca.kwargs["headers"]["Authorization"] == "Bearer sk-stored"
# Stored upstream_model_id was used.
assert ca.kwargs["json"]["model"] == "stored-model"
def test_empty_body_uses_all_stored_values(self, app, pg_conn):
row = self._make_row(pg_conn)
ca = self._post_test(app, pg_conn, row["id"], {})
assert ca.args[0] == "https://stored.example.com/v1/chat/completions"
assert ca.kwargs["headers"]["Authorization"] == "Bearer sk-stored"
assert ca.kwargs["json"]["model"] == "stored-model"
@pytest.mark.unit
class TestApiModelsListWithUser:
def test_includes_user_models_when_authenticated(self, app, pg_conn):
"""GET /api/models with auth should surface the user's BYOM
records alongside built-ins, each tagged with `source`."""
from application.api.user.models.routes import ModelsListResource
from application.core.model_registry import ModelRegistry
created = _create_via_repo(
pg_conn, user_id="user-1", display_name="My Mistral"
)
# Patch the *registry's* db_readonly so the per-user layer load
# uses the test connection.
@contextmanager
def _yield():
yield pg_conn
with patch(
"application.storage.db.session.db_readonly", _yield
):
ModelRegistry.reset()
with app.test_request_context("/api/models"):
from flask import request
request.decoded_token = {"sub": "user-1"}
resp = ModelsListResource().get()
assert resp.status_code == 200
body = resp.get_json()
ids = [m["id"] for m in body["models"]]
assert created["id"] in ids
# Source label tags it for the UI
user_entries = [m for m in body["models"] if m["id"] == created["id"]]
assert user_entries[0]["source"] == "user"
# Built-ins still present
assert any(m.get("source") == "builtin" for m in body["models"])

View File

@@ -122,6 +122,11 @@ def mock_llm():
llm._supports_tools = True
llm._supports_structured_output = Mock(return_value=False)
llm.__class__.__name__ = "MockLLM"
# Mirror BaseLLM.__init__: real LLMCreator stores the resolved
# upstream model name on self.model_id. Tests that build agents via
# ``mock_llm_creator`` rely on the agent's ``upstream_model_id``
# falling through to this attribute.
llm.model_id = "gpt-4"
return llm

File diff suppressed because it is too large

View File

@@ -0,0 +1,306 @@
"""Phase 1 regression tests for the YAML-driven ModelRegistry.
These tests encode the contract that persisted agent / workflow /
conversation references depend on: every model id and core capability
that existed in the old ``model_configs.py`` lists must continue to be
produced by the new YAML-backed registry.
If a future YAML edit accidentally renames an id or changes a
capability, these tests fail at CI before merge — protecting agents and
workflows from silent fallback to the system default.
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from application.core.model_registry import ModelRegistry
from application.core.model_yaml import (
BUILTIN_MODELS_DIR,
load_model_yamls,
)
# ── Per-provider expected IDs ─────────────────────────────────────────────
# Snapshot of the current built-in catalog. If you intentionally change
# what models a provider's YAML lists, update this constant in the same
# commit. The test exists to catch *unintentional* renames (e.g. a typo
# in an upstream model id) that would silently break every agent that
# references the old id.
EXPECTED_IDS = {
"openai": {"gpt-5.5", "gpt-5.4-mini", "gpt-5.4-nano"},
"anthropic": {
"claude-opus-4-7",
"claude-sonnet-4-6",
"claude-haiku-4-5",
},
"google": {
"gemini-3.1-pro-preview",
"gemini-3-flash-preview",
"gemini-3.1-flash-lite-preview",
},
"groq": {
"openai/gpt-oss-120b",
"llama-3.3-70b-versatile",
"llama-3.1-8b-instant",
},
"openrouter": {
"qwen/qwen3-coder:free",
"deepseek/deepseek-v3.2",
"anthropic/claude-sonnet-4.6",
},
"novita": {
"deepseek/deepseek-v4-pro",
"moonshotai/kimi-k2.6",
"zai-org/glm-5",
},
"azure_openai": {
"azure-gpt-5.5",
"azure-gpt-5.4-mini",
"azure-gpt-5.4-nano",
},
"docsgpt": {"docsgpt-local"},
"huggingface": {"huggingface-local"},
}
def _make_settings(**overrides):
s = MagicMock()
# All credential / mode flags off by default so each test opts in.
s.OPENAI_BASE_URL = None
s.OPENAI_API_KEY = None
s.OPENAI_API_BASE = None
s.ANTHROPIC_API_KEY = None
s.GOOGLE_API_KEY = None
s.GROQ_API_KEY = None
s.OPEN_ROUTER_API_KEY = None
s.NOVITA_API_KEY = None
s.HUGGINGFACE_API_KEY = None
s.LLM_PROVIDER = ""
s.LLM_NAME = None
s.API_KEY = None
s.MODELS_CONFIG_DIR = None
for k, v in overrides.items():
setattr(s, k, v)
return s
@pytest.fixture(autouse=True)
def _reset_registry():
ModelRegistry.reset()
yield
ModelRegistry.reset()
# ── YAML schema / loader ─────────────────────────────────────────────────
def _by_provider(catalogs):
"""Group a list of catalogs by provider name. Mirrors the registry's
own grouping; useful for asserting per-provider model sets in tests."""
out = {}
for c in catalogs:
out.setdefault(c.provider, []).append(c)
return out
@pytest.mark.unit
class TestYAMLLoader:
def test_loader_produces_expected_provider_set(self):
catalogs = load_model_yamls([BUILTIN_MODELS_DIR])
providers = {c.provider for c in catalogs}
assert providers == set(EXPECTED_IDS.keys())
def test_each_provider_has_expected_ids(self):
grouped = _by_provider(load_model_yamls([BUILTIN_MODELS_DIR]))
for provider, expected in EXPECTED_IDS.items():
actual = {m.id for c in grouped[provider] for m in c.models}
assert actual == expected, f"{provider}: expected {expected}, got {actual}"
def test_attachment_alias_image_expands_to_five_mime_types(self):
grouped = _by_provider(load_model_yamls([BUILTIN_MODELS_DIR]))
# OpenAI uses `attachments: [image]` in its defaults block.
for c in grouped["openai"]:
for m in c.models:
assert "image/png" in m.capabilities.supported_attachment_types
assert "image/jpeg" in m.capabilities.supported_attachment_types
assert "image/webp" in m.capabilities.supported_attachment_types
assert len(m.capabilities.supported_attachment_types) == 5
def test_attachment_alias_pdf_plus_image_for_google(self):
grouped = _by_provider(load_model_yamls([BUILTIN_MODELS_DIR]))
for c in grouped["google"]:
for m in c.models:
assert "application/pdf" in m.capabilities.supported_attachment_types
assert "image/png" in m.capabilities.supported_attachment_types
assert len(m.capabilities.supported_attachment_types) == 6
def test_per_model_context_window_overrides_provider_default(self):
grouped = _by_provider(load_model_yamls([BUILTIN_MODELS_DIR]))
openai = {m.id: m for c in grouped["openai"] for m in c.models}
# Provider default is 400_000; gpt-5.5 overrides to 1_050_000.
assert openai["gpt-5.4-mini"].capabilities.context_window == 400_000
assert openai["gpt-5.5"].capabilities.context_window == 1_050_000
# ── Registry × settings: every documented .env permutation ───────────────
@pytest.mark.unit
class TestRegistryPermutations:
def test_openai_only(self):
s = _make_settings(OPENAI_API_KEY="sk-test", LLM_PROVIDER="openai")
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
ids = {m.id for m in reg.get_all_models()}
assert ids == EXPECTED_IDS["openai"] | EXPECTED_IDS["docsgpt"]
def test_openai_base_url_replaces_catalog_with_dynamic(self):
s = _make_settings(
OPENAI_BASE_URL="http://localhost:11434/v1",
OPENAI_API_KEY="sk-test",
LLM_PROVIDER="openai",
LLM_NAME="llama3,gemma",
)
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
ids = {m.id for m in reg.get_all_models()}
# Custom local endpoint suppresses both the openai catalog AND
# the docsgpt model (matching legacy behavior).
assert ids == {"llama3", "gemma"}
def test_anthropic_only(self):
s = _make_settings(ANTHROPIC_API_KEY="sk-ant")
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
ids = {m.id for m in reg.get_all_models()}
assert ids == EXPECTED_IDS["anthropic"] | EXPECTED_IDS["docsgpt"]
def test_anthropic_via_llm_provider_with_llm_name(self):
# Mirrors the historical _add_anthropic_models filter: when only
# API_KEY (not ANTHROPIC_API_KEY) is set and LLM_NAME matches a
# known model, only that model is loaded.
s = _make_settings(
LLM_PROVIDER="anthropic", API_KEY="key", LLM_NAME="claude-haiku-4-5"
)
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
anthropic_ids = {
m.id for m in reg.get_all_models() if m.provider.value == "anthropic"
}
assert anthropic_ids == {"claude-haiku-4-5"}
def test_google_only(self):
s = _make_settings(GOOGLE_API_KEY="g-test")
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
ids = {m.id for m in reg.get_all_models()}
assert ids == EXPECTED_IDS["google"] | EXPECTED_IDS["docsgpt"]
def test_groq_only(self):
s = _make_settings(GROQ_API_KEY="g-test")
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
ids = {m.id for m in reg.get_all_models()}
assert ids == EXPECTED_IDS["groq"] | EXPECTED_IDS["docsgpt"]
def test_openrouter_only(self):
s = _make_settings(OPEN_ROUTER_API_KEY="or-test")
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
ids = {m.id for m in reg.get_all_models()}
assert ids == EXPECTED_IDS["openrouter"] | EXPECTED_IDS["docsgpt"]
def test_novita_only(self):
s = _make_settings(NOVITA_API_KEY="n-test")
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
ids = {m.id for m in reg.get_all_models()}
assert ids == EXPECTED_IDS["novita"] | EXPECTED_IDS["docsgpt"]
def test_huggingface_only(self):
s = _make_settings(HUGGINGFACE_API_KEY="hf-test")
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
ids = {m.id for m in reg.get_all_models()}
assert ids == EXPECTED_IDS["huggingface"] | EXPECTED_IDS["docsgpt"]
def test_no_credentials_only_docsgpt(self):
s = _make_settings()
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
ids = {m.id for m in reg.get_all_models()}
assert ids == EXPECTED_IDS["docsgpt"]
def test_azure_via_provider(self):
s = _make_settings(LLM_PROVIDER="azure_openai", API_KEY="key")
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
ids = {m.id for m in reg.get_all_models()}
assert "azure-gpt-5.5" in ids
def test_azure_via_api_base(self):
s = _make_settings(OPENAI_API_BASE="https://x.openai.azure.com")
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
ids = {m.id for m in reg.get_all_models()}
assert "azure-gpt-5.5" in ids
def test_everything_set(self):
s = _make_settings(
OPENAI_API_KEY="x",
ANTHROPIC_API_KEY="x",
GOOGLE_API_KEY="x",
GROQ_API_KEY="x",
OPEN_ROUTER_API_KEY="x",
NOVITA_API_KEY="x",
HUGGINGFACE_API_KEY="x",
OPENAI_API_BASE="x",
)
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
ids = {m.id for m in reg.get_all_models()}
all_expected = set()
for v in EXPECTED_IDS.values():
all_expected |= v
assert ids == all_expected
# ── Default model resolution ─────────────────────────────────────────────
@pytest.mark.unit
class TestDefaultModelResolution:
def test_llm_name_picks_default(self):
s = _make_settings(
ANTHROPIC_API_KEY="sk-ant", LLM_NAME="claude-opus-4-7"
)
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
assert reg.default_model_id == "claude-opus-4-7"
def test_falls_back_to_first_model_when_no_match(self):
s = _make_settings()
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
assert reg.default_model_id is not None
assert reg.default_model_id in reg.models
# ── Forward-compat: user_id parameter is accepted everywhere ─────────────
@pytest.mark.unit
class TestUserIdForwardCompat:
def test_lookup_methods_accept_user_id(self):
s = _make_settings(OPENAI_API_KEY="sk-test")
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
# All lookup methods must accept user_id (currently ignored,
# reserved for end-user BYOM).
assert reg.get_model("gpt-5.5", user_id="alice") is not None
assert len(reg.get_all_models(user_id="alice")) > 0
assert len(reg.get_enabled_models(user_id="alice")) > 0
assert reg.model_exists("gpt-5.5", user_id="alice") is True

View File

@@ -1,6 +1,17 @@
"""Tests for application/core/model_settings.py"""
"""Tests for application/core/model_settings.py.
from unittest.mock import MagicMock, patch
The provider-specific load logic that used to live in private
``_add_<X>_models`` methods now lives in plugin classes under
``application/llm/providers/`` and YAML catalogs under
``application/core/models/``. End-to-end coverage of the registry +
plugin pipeline is in ``tests/core/test_model_registry_yaml.py``.
This file covers the data classes (``AvailableModel``,
``ModelCapabilities``, ``ModelProvider``) and the singleton/lookup
contract on ``ModelRegistry``.
"""
from unittest.mock import patch
import pytest
@@ -13,7 +24,6 @@ from application.core.model_settings import (
class TestModelProvider:
@pytest.mark.unit
def test_all_providers_exist(self):
assert ModelProvider.OPENAI == "openai"
@@ -31,7 +41,6 @@ class TestModelProvider:
class TestModelCapabilities:
@pytest.mark.unit
def test_defaults(self):
caps = ModelCapabilities()
@@ -56,7 +65,6 @@ class TestModelCapabilities:
class TestAvailableModel:
@pytest.mark.unit
def test_to_dict_basic(self):
model = AvailableModel(
@@ -78,35 +86,67 @@ class TestAvailableModel:
id="local-model",
provider=ModelProvider.OPENAI,
display_name="Local",
base_url="http://localhost:11434",
base_url="http://localhost:11434/v1",
)
d = model.to_dict()
assert d["base_url"] == "http://localhost:11434"
assert d["base_url"] == "http://localhost:11434/v1"
@pytest.mark.unit
def test_to_dict_includes_capabilities(self):
caps = ModelCapabilities(supports_tools=True, context_window=64000)
caps = ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
context_window=200000,
supported_attachment_types=["image/png"],
)
model = AvailableModel(
id="m1",
provider=ModelProvider.ANTHROPIC,
display_name="M1",
id="m",
provider=ModelProvider.OPENAI,
display_name="M",
capabilities=caps,
)
d = model.to_dict()
assert d["supports_tools"] is True
assert d["context_window"] == 64000
assert d["supports_structured_output"] is True
assert d["context_window"] == 200000
assert d["supported_attachment_types"] == ["image/png"]
@pytest.mark.unit
def test_to_dict_disabled_model(self):
model = AvailableModel(
id="disabled",
provider=ModelProvider.OPENAI,
display_name="Disabled",
enabled=False,
)
d = model.to_dict()
assert d["enabled"] is False
@pytest.mark.unit
def test_api_key_field_never_serialized(self):
"""Forward-compat hook: AvailableModel.api_key (reserved for the
future end-user BYOM phase) must never leak into the wire format."""
model = AvailableModel(
id="byom",
provider=ModelProvider.OPENAI,
display_name="BYOM",
api_key="secret-key-do-not-leak",
)
d = model.to_dict()
assert "api_key" not in d
for v in d.values():
assert v != "secret-key-do-not-leak"
class TestModelRegistry:
class TestModelRegistryPublicAPI:
"""Covers the public lookup contract. Loading behavior is exercised
end-to-end in tests/core/test_model_registry_yaml.py."""
@pytest.fixture(autouse=True)
def _reset_singleton(self):
"""Reset singleton between tests."""
ModelRegistry._instance = None
ModelRegistry._initialized = False
ModelRegistry.reset()
yield
ModelRegistry._instance = None
ModelRegistry._initialized = False
ModelRegistry.reset()
@pytest.mark.unit
def test_singleton(self):
@@ -125,7 +165,9 @@ class TestModelRegistry:
def test_get_model(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
model = AvailableModel(id="test", provider=ModelProvider.OPENAI, display_name="Test")
model = AvailableModel(
id="test", provider=ModelProvider.OPENAI, display_name="Test"
)
reg.models["test"] = model
assert reg.get_model("test") is model
assert reg.get_model("nonexistent") is None
@@ -134,16 +176,30 @@ class TestModelRegistry:
def test_get_all_models(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.ANTHROPIC, display_name="M2")
reg.models["m1"] = AvailableModel(
id="m1", provider=ModelProvider.OPENAI, display_name="M1"
)
reg.models["m2"] = AvailableModel(
id="m2", provider=ModelProvider.ANTHROPIC, display_name="M2"
)
assert len(reg.get_all_models()) == 2
@pytest.mark.unit
def test_get_enabled_models(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1", enabled=True)
reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.OPENAI, display_name="M2", enabled=False)
reg.models["m1"] = AvailableModel(
id="m1",
provider=ModelProvider.OPENAI,
display_name="M1",
enabled=True,
)
reg.models["m2"] = AvailableModel(
id="m2",
provider=ModelProvider.OPENAI,
display_name="M2",
enabled=False,
)
enabled = reg.get_enabled_models()
assert len(enabled) == 1
assert enabled[0].id == "m1"
@@ -152,652 +208,29 @@ class TestModelRegistry:
def test_model_exists(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
reg.models["m1"] = AvailableModel(
id="m1", provider=ModelProvider.OPENAI, display_name="M1"
)
assert reg.model_exists("m1") is True
assert reg.model_exists("m2") is False
@pytest.mark.unit
def test_parse_model_names(self):
def test_lookups_accept_user_id_kwarg(self):
"""Reserved for the future end-user BYOM phase. Currently ignored."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
assert reg._parse_model_names("model1,model2") == ["model1", "model2"]
assert reg._parse_model_names("model1 , model2 ") == ["model1", "model2"]
assert reg._parse_model_names("single") == ["single"]
assert reg._parse_model_names("") == []
assert reg._parse_model_names(None) == []
@pytest.mark.unit
def test_add_docsgpt_models(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
reg._add_docsgpt_models(mock_settings)
assert "docsgpt-local" in reg.models
@pytest.mark.unit
def test_add_huggingface_models(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
reg._add_huggingface_models(mock_settings)
assert "huggingface-local" in reg.models
@pytest.mark.unit
def test_load_models_with_openai_key(self):
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = None
mock_settings.OPENAI_API_KEY = "sk-test"
mock_settings.OPENAI_API_BASE = None
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.GOOGLE_API_KEY = None
mock_settings.GROQ_API_KEY = None
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.NOVITA_API_KEY = None
mock_settings.HUGGINGFACE_API_KEY = None
mock_settings.LLM_PROVIDER = "openai"
mock_settings.LLM_NAME = ""
mock_settings.API_KEY = None
with patch("application.core.settings.settings", mock_settings):
reg = ModelRegistry()
assert len(reg.models) > 0
@pytest.mark.unit
def test_load_models_custom_openai_base_url(self):
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
mock_settings.OPENAI_API_KEY = "sk-test"
mock_settings.OPENAI_API_BASE = None
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.GOOGLE_API_KEY = None
mock_settings.GROQ_API_KEY = None
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.NOVITA_API_KEY = None
mock_settings.HUGGINGFACE_API_KEY = None
mock_settings.LLM_PROVIDER = "openai"
mock_settings.LLM_NAME = "llama3,gemma"
mock_settings.API_KEY = None
with patch("application.core.settings.settings", mock_settings):
reg = ModelRegistry()
assert "llama3" in reg.models
assert "gemma" in reg.models
@pytest.mark.unit
def test_default_model_selection_from_llm_name(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {"gpt-4": AvailableModel(id="gpt-4", provider=ModelProvider.OPENAI, display_name="GPT-4")}
reg.default_model_id = "gpt-4"
assert reg.default_model_id == "gpt-4"
@pytest.mark.unit
def test_add_anthropic_models_with_key(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.ANTHROPIC_API_KEY = "sk-ant-test"
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
reg._add_anthropic_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_google_models_with_key(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GOOGLE_API_KEY = "google-test"
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
reg._add_google_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_groq_models_with_key(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GROQ_API_KEY = "groq-test"
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
reg._add_groq_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_openrouter_models_with_key(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPEN_ROUTER_API_KEY = "or-test"
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
reg._add_openrouter_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_novita_models_with_key(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.NOVITA_API_KEY = "novita-test"
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
reg._add_novita_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_azure_openai_models_specific(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.LLM_PROVIDER = "azure_openai"
mock_settings.LLM_NAME = "nonexistent-model"
reg._add_azure_openai_models(mock_settings)
# Falls through to adding all azure models
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_anthropic_models_no_key_with_provider(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.LLM_PROVIDER = "anthropic"
mock_settings.LLM_NAME = "nonexistent"
reg._add_anthropic_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_default_model_fallback_to_first(self):
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = None
mock_settings.OPENAI_API_KEY = None
mock_settings.OPENAI_API_BASE = None
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.GOOGLE_API_KEY = None
mock_settings.GROQ_API_KEY = None
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.NOVITA_API_KEY = None
mock_settings.HUGGINGFACE_API_KEY = None
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
mock_settings.API_KEY = None
with patch("application.core.settings.settings", mock_settings):
reg = ModelRegistry()
# Should have at least docsgpt-local
assert reg.default_model_id is not None
@pytest.mark.unit
def test_default_model_from_provider_fallback(self):
"""When LLM_NAME is not set but LLM_PROVIDER and API_KEY are,
default should be first model of that provider."""
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = None
mock_settings.OPENAI_API_KEY = "sk-test"
mock_settings.OPENAI_API_BASE = None
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.GOOGLE_API_KEY = None
mock_settings.GROQ_API_KEY = None
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.NOVITA_API_KEY = None
mock_settings.HUGGINGFACE_API_KEY = None
mock_settings.LLM_PROVIDER = "openai"
mock_settings.LLM_NAME = None
mock_settings.API_KEY = "sk-test"
with patch("application.core.settings.settings", mock_settings):
reg = ModelRegistry()
assert reg.default_model_id is not None
@pytest.mark.unit
def test_add_google_models_no_key_with_provider(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GOOGLE_API_KEY = None
mock_settings.LLM_PROVIDER = "google"
mock_settings.LLM_NAME = "nonexistent"
reg._add_google_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_groq_models_no_key_with_provider(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GROQ_API_KEY = None
mock_settings.LLM_PROVIDER = "groq"
mock_settings.LLM_NAME = "nonexistent"
reg._add_groq_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_openrouter_models_no_key_with_provider(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.LLM_PROVIDER = "openrouter"
mock_settings.LLM_NAME = "nonexistent"
reg._add_openrouter_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_novita_models_no_key_with_provider(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.NOVITA_API_KEY = None
mock_settings.LLM_PROVIDER = "novita"
mock_settings.LLM_NAME = "nonexistent"
reg._add_novita_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_to_dict_disabled_model(self):
model = AvailableModel(
id="disabled",
provider=ModelProvider.OPENAI,
display_name="Disabled",
enabled=False,
)
d = model.to_dict()
assert d["enabled"] is False
@pytest.mark.unit
def test_to_dict_with_attachment_types(self):
caps = ModelCapabilities(
supported_attachment_types=["image/png", "application/pdf"],
)
model = AvailableModel(
id="vision",
provider=ModelProvider.OPENAI,
display_name="Vision",
capabilities=caps,
)
d = model.to_dict()
assert d["supported_attachment_types"] == ["image/png", "application/pdf"]
# ----------------------------------------------------------------
# Coverage for _add_* methods with matching LLM_NAME
# Lines: 100, 105, 147, 171, 179, 186, 199-201, 204, 210, 213,
# 218, 229, 233, 241, 250
# ----------------------------------------------------------------
@pytest.mark.unit
def test_add_azure_openai_models_with_matching_name(self):
"""Cover line 186: azure model matching LLM_NAME returns early."""
from application.core.model_configs import AZURE_OPENAI_MODELS
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.LLM_PROVIDER = "azure_openai"
if AZURE_OPENAI_MODELS:
mock_settings.LLM_NAME = AZURE_OPENAI_MODELS[0].id
else:
mock_settings.LLM_NAME = "nonexistent"
reg._add_azure_openai_models(mock_settings)
# Should have added at least one model
assert len(reg.models) >= 1
@pytest.mark.unit
def test_add_anthropic_no_key_no_provider_fallthrough(self):
"""Cover lines 199-204: no key, provider set but name not found -> add all."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.LLM_PROVIDER = "anthropic"
mock_settings.LLM_NAME = "nonexistent-model"
reg._add_anthropic_models(mock_settings)
# Falls through to add all anthropic models
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_google_no_key_matching_name(self):
"""Cover lines 213-218: Google fallback with matching name."""
from application.core.model_configs import GOOGLE_MODELS
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GOOGLE_API_KEY = None
mock_settings.LLM_PROVIDER = "google"
if GOOGLE_MODELS:
mock_settings.LLM_NAME = GOOGLE_MODELS[0].id
else:
mock_settings.LLM_NAME = "nonexistent"
reg._add_google_models(mock_settings)
assert len(reg.models) >= 1
@pytest.mark.unit
def test_add_groq_no_key_matching_name(self):
"""Cover lines 229-233: Groq fallback with matching name."""
from application.core.model_configs import GROQ_MODELS
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GROQ_API_KEY = None
mock_settings.LLM_PROVIDER = "groq"
if GROQ_MODELS:
mock_settings.LLM_NAME = GROQ_MODELS[0].id
else:
mock_settings.LLM_NAME = "nonexistent"
reg._add_groq_models(mock_settings)
assert len(reg.models) >= 1
@pytest.mark.unit
def test_add_openrouter_no_key_matching_name(self):
"""Cover lines 241-250: OpenRouter fallback with matching name."""
from application.core.model_configs import OPENROUTER_MODELS
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.LLM_PROVIDER = "openrouter"
if OPENROUTER_MODELS:
mock_settings.LLM_NAME = OPENROUTER_MODELS[0].id
else:
mock_settings.LLM_NAME = "nonexistent"
reg._add_openrouter_models(mock_settings)
assert len(reg.models) >= 1
@pytest.mark.unit
def test_add_novita_no_key_matching_name(self):
"""Cover novita fallback with matching name."""
from application.core.model_configs import NOVITA_MODELS
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.NOVITA_API_KEY = None
mock_settings.LLM_PROVIDER = "novita"
if NOVITA_MODELS:
mock_settings.LLM_NAME = NOVITA_MODELS[0].id
else:
mock_settings.LLM_NAME = "nonexistent"
reg._add_novita_models(mock_settings)
assert len(reg.models) >= 1
@pytest.mark.unit
def test_load_models_default_from_llm_name_exact_match(self):
"""Cover line 136/147: exact LLM_NAME match for default model."""
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = None
mock_settings.OPENAI_API_KEY = "sk-test"
mock_settings.OPENAI_API_BASE = None
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.GOOGLE_API_KEY = None
mock_settings.GROQ_API_KEY = None
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.NOVITA_API_KEY = None
mock_settings.HUGGINGFACE_API_KEY = None
mock_settings.LLM_PROVIDER = "openai"
mock_settings.API_KEY = None
from application.core.model_configs import OPENAI_MODELS
if OPENAI_MODELS:
mock_settings.LLM_NAME = OPENAI_MODELS[0].id
else:
mock_settings.LLM_NAME = "gpt-4o"
with patch("application.core.settings.settings", mock_settings):
reg = ModelRegistry()
assert reg.default_model_id is not None
@pytest.mark.unit
def test_add_openai_models_local_endpoint_no_name(self):
"""Cover line 171: local endpoint without LLM_NAME adds nothing."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
mock_settings.OPENAI_API_KEY = "sk-test"
mock_settings.LLM_NAME = None
reg._add_openai_models(mock_settings)
assert len(reg.models) == 0
@pytest.mark.unit
def test_add_openai_standard_no_api_key(self):
"""Cover line 179: standard OpenAI without API key adds nothing."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = None
mock_settings.OPENAI_API_KEY = None
reg._add_openai_models(mock_settings)
assert len(reg.models) == 0
# ---------------------------------------------------------------------------
# Coverage — additional uncovered lines: 100, 105, 147, 171, 179, 186, 250
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestModelRegistryAdditionalCoverage:
def test_add_azure_openai_models_specific_name(self):
"""Cover line 186: azure_openai with specific LLM_NAME match."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.LLM_PROVIDER = "azure_openai"
mock_settings.LLM_NAME = "gpt-4o"
# Create a fake model that matches
fake_model = MagicMock()
fake_model.id = "gpt-4o"
with patch(
"application.core.model_configs.AZURE_OPENAI_MODELS",
[fake_model],
):
reg._add_azure_openai_models(mock_settings)
assert "gpt-4o" in reg.models
def test_add_anthropic_models_with_api_key(self):
"""Cover line 100: anthropic with API key."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.ANTHROPIC_API_KEY = "sk-test"
mock_settings.LLM_PROVIDER = "anthropic"
reg._add_anthropic_models(mock_settings)
assert len(reg.models) > 0
def test_add_google_models_with_api_key(self):
"""Cover line 105: google with API key."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GOOGLE_API_KEY = "test-key"
mock_settings.LLM_PROVIDER = "google"
reg._add_google_models(mock_settings)
assert len(reg.models) > 0
def test_default_model_from_provider(self):
"""Cover line 147: default model selected from provider."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
reg.default_model_id = None
fake_model = MagicMock()
fake_model.provider = MagicMock()
fake_model.provider.value = "openai"
reg.models["gpt-4o"] = fake_model
mock_settings = MagicMock()
mock_settings.LLM_NAME = None
mock_settings.LLM_PROVIDER = "openai"
mock_settings.API_KEY = "key"
# Simulate the default selection logic
if not reg.default_model_id:
for model_id, model in reg.models.items():
if model.provider.value == mock_settings.LLM_PROVIDER:
reg.default_model_id = model_id
break
assert reg.default_model_id == "gpt-4o"
def test_add_openai_local_endpoint_with_llm_name(self):
"""Cover line 171: local endpoint registers custom models from LLM_NAME."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
mock_settings.OPENAI_API_KEY = "sk-test"
mock_settings.LLM_NAME = "llama3,phi3"
reg._add_openai_models(mock_settings)
assert "llama3" in reg.models
assert "phi3" in reg.models
def test_add_openai_standard_with_api_key(self):
"""Cover line 179: standard OpenAI with API key adds models."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = None
mock_settings.OPENAI_API_KEY = "sk-real-key"
reg._add_openai_models(mock_settings)
assert len(reg.models) > 0
def test_add_openrouter_models(self):
"""Cover line 250: openrouter models added."""
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPEN_ROUTER_API_KEY = "or-key"
mock_settings.LLM_PROVIDER = "openrouter"
reg._add_openrouter_models(mock_settings)
assert len(reg.models) > 0
# ---------------------------------------------------------------------------
# Additional coverage for model_settings.py
# Lines: 135-136 (backward compat LLM_NAME), 138-143 (provider fallback),
# 145-146 (first model as default)
# ---------------------------------------------------------------------------
# Imports already at the top of the file; no additional imports needed
@pytest.mark.unit
class TestDefaultModelSelectionBackwardCompat:
"""Cover lines 135-136: backward compat exact match on LLM_NAME."""
def test_llm_name_exact_match_as_default(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
reg.default_model_id = None
# Add a model with composite ID
model = AvailableModel(
id="my-composite-model",
provider=ModelProvider.OPENAI,
display_name="Composite",
description="test",
capabilities=ModelCapabilities(),
reg.models["m1"] = AvailableModel(
id="m1", provider=ModelProvider.OPENAI, display_name="M1"
)
reg.models["my-composite-model"] = model
assert reg.get_model("m1", user_id="alice") is not None
assert reg.model_exists("m1", user_id="alice") is True
assert len(reg.get_all_models(user_id="alice")) == 1
assert len(reg.get_enabled_models(user_id="alice")) == 1
# Simulate _parse_model_names returning something different
# so that the first for-loop doesn't match
mock_settings = MagicMock()
mock_settings.LLM_NAME = "my-composite-model"
mock_settings.LLM_PROVIDER = None
mock_settings.API_KEY = None
# Call the logic directly
model_names = reg._parse_model_names(mock_settings.LLM_NAME)
for mn in model_names:
if mn in reg.models:
reg.default_model_id = mn
break
assert reg.default_model_id == "my-composite-model"
@pytest.mark.unit
class TestDefaultModelSelectionByProvider:
"""Cover lines 138-143: default model by provider when LLM_NAME doesn't match."""
def test_default_by_provider(self):
@pytest.mark.unit
def test_reset(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
reg.default_model_id = None
model = AvailableModel(
id="gpt-4",
provider=ModelProvider.OPENAI,
display_name="GPT-4",
description="test",
capabilities=ModelCapabilities(),
)
reg.models["gpt-4"] = model
# Simulate: LLM_NAME doesn't exist/match, but LLM_PROVIDER + API_KEY set
if not reg.default_model_id:
for model_id, m in reg.models.items():
if m.provider.value == "openai":
reg.default_model_id = model_id
break
assert reg.default_model_id == "gpt-4"
@pytest.mark.unit
class TestDefaultModelSelectionFirstModel:
"""Cover lines 145-146: first model as default when nothing else matches."""
def test_first_model_as_default(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
reg.default_model_id = None
model = AvailableModel(
id="fallback-model",
provider=ModelProvider.OPENAI,
display_name="Fallback",
description="test",
capabilities=ModelCapabilities(),
)
reg.models["fallback-model"] = model
if not reg.default_model_id and reg.models:
reg.default_model_id = next(iter(reg.models.keys()))
assert reg.default_model_id == "fallback-model"
r1 = ModelRegistry()
ModelRegistry.reset()
r2 = ModelRegistry()
assert r1 is not r2

View File

@@ -0,0 +1,208 @@
"""Phase 3 tests: operator MODELS_CONFIG_DIR.
Covers the operator-supplied directory of model YAMLs that's loaded
after the built-in catalog. Operators use this to add new
``openai_compatible`` providers, extend an existing provider's catalog
with extra models, or override a built-in model's capabilities — all
without forking the repo.
"""
"""
from __future__ import annotations
import logging
from textwrap import dedent
from unittest.mock import MagicMock, patch
import pytest
from application.core.model_registry import ModelRegistry
def _make_settings(**overrides):
s = MagicMock()
s.OPENAI_BASE_URL = None
s.OPENAI_API_KEY = None
s.OPENAI_API_BASE = None
s.ANTHROPIC_API_KEY = None
s.GOOGLE_API_KEY = None
s.GROQ_API_KEY = None
s.OPEN_ROUTER_API_KEY = None
s.NOVITA_API_KEY = None
s.HUGGINGFACE_API_KEY = None
s.LLM_PROVIDER = ""
s.LLM_NAME = None
s.API_KEY = None
s.MODELS_CONFIG_DIR = None
for k, v in overrides.items():
setattr(s, k, v)
return s
@pytest.fixture(autouse=True)
def _reset_registry():
ModelRegistry.reset()
yield
ModelRegistry.reset()
# ── New provider via openai_compatible ───────────────────────────────────
@pytest.mark.unit
class TestOperatorAddsNewProvider:
def test_drop_in_yaml_appears_in_registry(
self, tmp_path, monkeypatch
):
(tmp_path / "fireworks.yaml").write_text(dedent("""
provider: openai_compatible
display_provider: fireworks
api_key_env: FIREWORKS_API_KEY
base_url: https://api.fireworks.ai/inference/v1
defaults:
supports_tools: true
models:
- id: accounts/fireworks/models/llama-v3p3-70b-instruct
display_name: Llama 3.3 70B (Fireworks)
"""))
monkeypatch.setenv("FIREWORKS_API_KEY", "fw-key")
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
m = reg.get_model("accounts/fireworks/models/llama-v3p3-70b-instruct")
assert m is not None
assert m.api_key == "fw-key"
assert m.base_url == "https://api.fireworks.ai/inference/v1"
assert m.display_provider == "fireworks"
# ── Extending an existing provider's catalog ─────────────────────────────
@pytest.mark.unit
class TestOperatorExtendsExistingProvider:
def test_operator_adds_anthropic_model_to_builtin_catalog(
self, tmp_path
):
(tmp_path / "anthropic-extra.yaml").write_text(dedent("""
provider: anthropic
defaults:
supports_tools: true
context_window: 200000
models:
- id: claude-haiku-5-0-future
display_name: Claude Haiku 5.0
"""))
s = _make_settings(
ANTHROPIC_API_KEY="sk-ant",
MODELS_CONFIG_DIR=str(tmp_path),
)
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
# Built-in models still present
assert reg.get_model("claude-sonnet-4-6") is not None
assert reg.get_model("claude-opus-4-7") is not None
# Operator-added model also present
added = reg.get_model("claude-haiku-5-0-future")
assert added is not None
assert added.display_name == "Claude Haiku 5.0"
# ── Overriding a built-in model's capabilities ───────────────────────────
@pytest.mark.unit
class TestOperatorOverridesBuiltinCapabilities:
def test_operator_yaml_overrides_builtin_context_window(
self, tmp_path, caplog
):
# Override anthropic claude-haiku-4-5 to claim a 1M context window
(tmp_path / "anthropic-override.yaml").write_text(dedent("""
provider: anthropic
defaults:
supports_tools: true
attachments: [image]
context_window: 1000000
models:
- id: claude-haiku-4-5
display_name: Claude Haiku 4.5 (extended)
description: Operator-overridden capabilities
"""))
s = _make_settings(
ANTHROPIC_API_KEY="sk-ant",
MODELS_CONFIG_DIR=str(tmp_path),
)
with caplog.at_level(logging.WARNING):
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
m = reg.get_model("claude-haiku-4-5")
assert m.display_name == "Claude Haiku 4.5 (extended)"
assert m.description == "Operator-overridden capabilities"
assert m.capabilities.context_window == 1_000_000
# And the override warning fires so the operator can audit it
assert any(
"claude-haiku-4-5" in rec.message and "redefined" in rec.message
for rec in caplog.records
)
# ── Misconfigured MODELS_CONFIG_DIR ──────────────────────────────────────
@pytest.mark.unit
class TestMisconfiguredOperatorDir:
def test_missing_dir_logs_warning_and_continues(
self, tmp_path, caplog
):
bogus = tmp_path / "does-not-exist"
s = _make_settings(MODELS_CONFIG_DIR=str(bogus))
with caplog.at_level(logging.WARNING):
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
# Built-in catalog still loaded
assert reg.get_model("docsgpt-local") is not None
# And the operator was warned
assert any("does not exist" in rec.message for rec in caplog.records)
def test_path_is_a_file_logs_warning(self, tmp_path, caplog):
afile = tmp_path / "not-a-dir.yaml"
afile.write_text("provider: anthropic\nmodels: []")
s = _make_settings(MODELS_CONFIG_DIR=str(afile))
with caplog.at_level(logging.WARNING):
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
assert reg.get_model("docsgpt-local") is not None
assert any("not a directory" in rec.message for rec in caplog.records)
# ── Validation: unknown provider rejected ────────────────────────────────
@pytest.mark.unit
class TestOperatorValidation:
def test_unknown_provider_in_operator_yaml_aborts_boot(self, tmp_path):
(tmp_path / "bogus.yaml").write_text(dedent("""
provider: not_a_real_provider
models:
- id: x
display_name: X
"""))
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
with patch("application.core.settings.settings", s):
with pytest.raises(Exception) as exc_info:
ModelRegistry()
# Could be ModelYAMLError (enum check) or ValueError (registry check);
# either way the message must surface what's wrong.
msg = str(exc_info.value)
assert "not_a_real_provider" in msg

View File

@@ -0,0 +1,298 @@
"""Phase 2 tests for the openai_compatible provider.
Covers YAML loading from a temp directory, multiple coexisting catalogs
(Mistral + Together), env-var-based credential resolution, the legacy
OPENAI_BASE_URL + LLM_NAME fallback, and end-to-end model dispatch
through LLMCreator.
"""
from __future__ import annotations
from pathlib import Path
from textwrap import dedent
from unittest.mock import MagicMock, patch
import pytest
from application.core.model_registry import ModelRegistry
from application.core.model_settings import ModelProvider
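# Every provider key is nulled out explicitly so the MagicMock settings object
# doesn't hand back auto-created (truthy) attributes that would enable
# built-in catalogs; each test overrides only the fields it needs.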
def _make_settings(**overrides):
s = MagicMock()
s.OPENAI_BASE_URL = None
s.OPENAI_API_KEY = None
s.OPENAI_API_BASE = None
s.ANTHROPIC_API_KEY = None
s.GOOGLE_API_KEY = None
s.GROQ_API_KEY = None
s.OPEN_ROUTER_API_KEY = None
s.NOVITA_API_KEY = None
s.HUGGINGFACE_API_KEY = None
s.LLM_PROVIDER = ""
s.LLM_NAME = None
s.API_KEY = None
s.MODELS_CONFIG_DIR = None
for k, v in overrides.items():
setattr(s, k, v)
return s
def _write_mistral_yaml(directory: Path) -> Path:
path = directory / "mistral.yaml"
path.write_text(dedent("""
provider: openai_compatible
display_provider: mistral
api_key_env: MISTRAL_API_KEY
base_url: https://api.mistral.ai/v1
defaults:
supports_tools: true
context_window: 128000
models:
- id: mistral-large-latest
display_name: Mistral Large
- id: mistral-small-latest
display_name: Mistral Small
"""))
return path
def _write_together_yaml(directory: Path) -> Path:
path = directory / "together.yaml"
path.write_text(dedent("""
provider: openai_compatible
display_provider: together
api_key_env: TOGETHER_API_KEY
base_url: https://api.together.xyz/v1
defaults:
supports_tools: true
models:
- id: meta-llama/Llama-3.3-70B-Instruct-Turbo
display_name: Llama 3.3 70B (Together)
"""))
return path
@pytest.fixture(autouse=True)
def _reset_registry():
ModelRegistry.reset()
yield
ModelRegistry.reset()
# ── YAML-driven catalogs ─────────────────────────────────────────────────
@pytest.mark.unit
class TestYAMLCompatibleProvider:
def test_mistral_yaml_loads_with_env_key(
self, tmp_path, monkeypatch
):
_write_mistral_yaml(tmp_path)
monkeypatch.setenv("MISTRAL_API_KEY", "sk-mistral-test")
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
m = reg.get_model("mistral-large-latest")
assert m is not None
assert m.provider == ModelProvider.OPENAI_COMPATIBLE
assert m.display_provider == "mistral"
assert m.base_url == "https://api.mistral.ai/v1"
assert m.api_key == "sk-mistral-test"
assert m.capabilities.supports_tools is True
assert m.capabilities.context_window == 128000
def test_yaml_skipped_when_env_var_missing(
self, tmp_path, monkeypatch
):
_write_mistral_yaml(tmp_path)
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
# Catalog skipped when no key — no Mistral models in the registry
assert reg.get_model("mistral-large-latest") is None
def test_two_compatible_catalogs_coexist_with_separate_keys(
self, tmp_path, monkeypatch
):
_write_mistral_yaml(tmp_path)
_write_together_yaml(tmp_path)
monkeypatch.setenv("MISTRAL_API_KEY", "sk-mistral")
monkeypatch.setenv("TOGETHER_API_KEY", "sk-together")
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
mistral = reg.get_model("mistral-large-latest")
together = reg.get_model("meta-llama/Llama-3.3-70B-Instruct-Turbo")
assert mistral.api_key == "sk-mistral"
assert mistral.base_url == "https://api.mistral.ai/v1"
assert mistral.display_provider == "mistral"
assert together.api_key == "sk-together"
assert together.base_url == "https://api.together.xyz/v1"
assert together.display_provider == "together"
def test_one_catalog_enabled_other_skipped(
self, tmp_path, monkeypatch
):
_write_mistral_yaml(tmp_path)
_write_together_yaml(tmp_path)
monkeypatch.setenv("MISTRAL_API_KEY", "sk-mistral")
monkeypatch.delenv("TOGETHER_API_KEY", raising=False)
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
assert reg.get_model("mistral-large-latest") is not None
assert reg.get_model("meta-llama/Llama-3.3-70B-Instruct-Turbo") is None
def test_missing_base_url_raises(self, tmp_path, monkeypatch):
bad = tmp_path / "broken.yaml"
bad.write_text(dedent("""
provider: openai_compatible
api_key_env: SOME_KEY
models:
- id: x
display_name: X
"""))
monkeypatch.setenv("SOME_KEY", "k")
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
with patch("application.core.settings.settings", s):
with pytest.raises(ValueError, match="must set 'base_url'"):
ModelRegistry()
def test_missing_api_key_env_raises(self, tmp_path, monkeypatch):
bad = tmp_path / "broken.yaml"
bad.write_text(dedent("""
provider: openai_compatible
base_url: https://x/v1
models:
- id: x
display_name: X
"""))
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
with patch("application.core.settings.settings", s):
with pytest.raises(ValueError, match="must set 'api_key_env'"):
ModelRegistry()
def test_to_dict_uses_display_provider(
self, tmp_path, monkeypatch
):
_write_mistral_yaml(tmp_path)
monkeypatch.setenv("MISTRAL_API_KEY", "sk")
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
d = reg.get_model("mistral-large-latest").to_dict()
# /api/models response shows "mistral", not "openai_compatible"
assert d["provider"] == "mistral"
# api_key never leaks into the wire format
assert "api_key" not in d
for v in d.values():
assert v != "sk"
# ── Legacy OPENAI_BASE_URL fallback ──────────────────────────────────────
@pytest.mark.unit
class TestLegacyOpenAIBaseURLPath:
def test_legacy_models_now_provided_by_openai_compatible(self):
s = _make_settings(
OPENAI_BASE_URL="http://localhost:11434/v1",
OPENAI_API_KEY="sk-local",
LLM_PROVIDER="openai",
LLM_NAME="llama3,gemma",
)
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
ids = {m.id for m in reg.get_all_models()}
assert ids == {"llama3", "gemma"}
llama = reg.get_model("llama3")
assert llama.base_url == "http://localhost:11434/v1"
assert llama.api_key == "sk-local"
assert llama.provider == ModelProvider.OPENAI_COMPATIBLE
# Display provider preserves the historical "openai" label
assert llama.display_provider == "openai"
assert llama.to_dict()["provider"] == "openai"
def test_legacy_uses_api_key_fallback_when_openai_api_key_missing(self):
s = _make_settings(
OPENAI_BASE_URL="http://localhost:11434/v1",
OPENAI_API_KEY=None,
API_KEY="sk-generic",
LLM_PROVIDER="openai",
LLM_NAME="llama3",
)
with patch("application.core.settings.settings", s):
reg = ModelRegistry()
assert reg.get_model("llama3").api_key == "sk-generic"
# ── Dispatch through LLMCreator ──────────────────────────────────────────
@pytest.mark.unit
class TestLLMCreatorDispatch:
def test_llmcreator_uses_per_model_api_key_and_base_url(
self, tmp_path, monkeypatch
):
"""End-to-end: when an openai_compatible model is dispatched, the
per-model api_key + base_url from the registry must override
whatever the caller passed."""
_write_mistral_yaml(tmp_path)
monkeypatch.setenv("MISTRAL_API_KEY", "sk-mistral-real")
s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
captured = {}
class _FakeLLM:
def __init__(
self, api_key, user_api_key, *args, **kwargs
):
captured["api_key"] = api_key
captured["base_url"] = kwargs.get("base_url")
captured["model_id"] = kwargs.get("model_id")
with patch("application.core.settings.settings", s):
ModelRegistry.reset()
ModelRegistry() # warm up the registry under patched settings
# Now patch the OpenAI plugin's class so we can capture the
# constructor args without spinning up the real OpenAILLM.
from application.llm.providers import PROVIDERS_BY_NAME
with patch.object(
PROVIDERS_BY_NAME["openai_compatible"],
"llm_class",
_FakeLLM,
):
from application.llm.llm_creator import LLMCreator
LLMCreator.create_llm(
type="openai_compatible",
api_key="caller-passed-WRONG-key",
user_api_key=None,
decoded_token={"sub": "u"},
model_id="mistral-large-latest",
)
assert captured["api_key"] == "sk-mistral-real"
assert captured["base_url"] == "https://api.mistral.ai/v1"
assert captured["model_id"] == "mistral-large-latest"

View File

@@ -0,0 +1,505 @@
"""Tests for the BYOM per-user layer on ModelRegistry.
Covers: per-user lookups don't leak across users, lookups without
user_id stay built-in only, get_all_models / get_enabled_models /
model_exists all consult the user layer when given user_id, and the
explicit invalidate_user clears the cache.
"""
from __future__ import annotations
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from application.core.model_registry import ModelRegistry
from application.core.model_settings import ModelProvider
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
@pytest.fixture(autouse=True)
def _reset_registry():
ModelRegistry.reset()
yield
ModelRegistry.reset()
def _make_settings(**overrides):
s = MagicMock()
s.OPENAI_BASE_URL = None
s.OPENAI_API_KEY = None
s.OPENAI_API_BASE = None
s.ANTHROPIC_API_KEY = None
s.GOOGLE_API_KEY = None
s.GROQ_API_KEY = None
s.OPEN_ROUTER_API_KEY = None
s.NOVITA_API_KEY = None
s.HUGGINGFACE_API_KEY = None
s.LLM_PROVIDER = ""
s.LLM_NAME = None
s.API_KEY = None
s.MODELS_CONFIG_DIR = None
for k, v in overrides.items():
setattr(s, k, v)
return s
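# ``db_readonly`` is consumed as a context manager by the registry's per-user
# loader; the tests patch it with ``lambda: _yield(pg_conn)`` so every read
# runs against the fixture's Postgres connection.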
@contextmanager
def _yield(conn):
yield conn
@pytest.mark.unit
class TestPerUserLayer:
def test_user_models_isolated_per_user(self, pg_conn):
"""Alice's BYOM model must not appear in Bob's lookups."""
repo = UserCustomModelsRepository(pg_conn)
alice_model = repo.create(
user_id="alice",
upstream_model_id="alice-mistral",
display_name="Alice Mistral",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="sk-alice",
)
s = _make_settings()
with patch("application.core.settings.settings", s), patch(
"application.storage.db.session.db_readonly",
lambda: _yield(pg_conn),
):
reg = ModelRegistry()
assert reg.get_model(alice_model["id"], user_id="alice") is not None
assert reg.get_model(alice_model["id"], user_id="bob") is None
# And without a user_id at all, the per-user layer is invisible
assert reg.get_model(alice_model["id"]) is None
def test_get_all_models_includes_user_models(self, pg_conn):
repo = UserCustomModelsRepository(pg_conn)
created = repo.create(
user_id="user-1",
upstream_model_id="mistral-large-latest",
display_name="My Mistral",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="sk-test",
)
s = _make_settings()
with patch("application.core.settings.settings", s), patch(
"application.storage.db.session.db_readonly",
lambda: _yield(pg_conn),
):
reg = ModelRegistry()
ids_no_user = {m.id for m in reg.get_all_models()}
ids_with_user = {
m.id for m in reg.get_all_models(user_id="user-1")
}
assert created["id"] not in ids_no_user
assert created["id"] in ids_with_user
def test_user_models_carry_decrypted_api_key_and_upstream_id(self, pg_conn):
repo = UserCustomModelsRepository(pg_conn)
created = repo.create(
user_id="user-1",
upstream_model_id="mistral-large-latest",
display_name="My Mistral",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="sk-test-XYZ",
)
s = _make_settings()
with patch("application.core.settings.settings", s), patch(
"application.storage.db.session.db_readonly",
lambda: _yield(pg_conn),
):
reg = ModelRegistry()
m = reg.get_model(created["id"], user_id="user-1")
assert m is not None
assert m.provider == ModelProvider.OPENAI_COMPATIBLE
assert m.upstream_model_id == "mistral-large-latest"
assert m.api_key == "sk-test-XYZ"
assert m.base_url == "https://api.mistral.ai/v1"
assert m.source == "user"
# The wire format never leaks the api_key
d = m.to_dict()
assert "api_key" not in d
for v in d.values():
assert v != "sk-test-XYZ"
def test_invalidate_user_clears_cache(self, pg_conn):
repo = UserCustomModelsRepository(pg_conn)
created = repo.create(
user_id="user-1",
upstream_model_id="x",
display_name="X",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="k",
)
s = _make_settings()
# Stub Redis so invalidate_user can publish its version bump
# without hitting a real broker. The P1 fix calls ``incr`` on
# invalidate; here we just need it not to raise.
fake_redis = MagicMock()
with patch("application.core.settings.settings", s), patch(
"application.storage.db.session.db_readonly",
lambda: _yield(pg_conn),
), patch(
"application.cache.get_redis_instance", return_value=fake_redis
):
reg = ModelRegistry()
assert reg.get_model(created["id"], user_id="user-1") is not None
# Cache populated
assert "user-1" in reg._user_models
ModelRegistry.invalidate_user("user-1")
assert "user-1" not in reg._user_models
# Re-lookup repopulates
reg.get_model(created["id"], user_id="user-1")
assert "user-1" in reg._user_models
@pytest.mark.unit
class TestLLMCreatorDispatchUsesUpstreamModelId:
def test_llmcreator_sends_upstream_id_not_uuid(self, pg_conn):
"""End-to-end: LLMCreator with a BYOM uuid must construct the
OpenAILLM with the user's upstream model name (e.g.
``mistral-large-latest``), not the registry uuid."""
repo = UserCustomModelsRepository(pg_conn)
created = repo.create(
user_id="user-1",
upstream_model_id="mistral-large-latest",
display_name="My Mistral",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="sk-mistral-real",
)
captured = {}
class _FakeLLM:
def __init__(self, api_key, user_api_key, *args, **kwargs):
captured["api_key"] = api_key
captured["base_url"] = kwargs.get("base_url")
captured["model_id"] = kwargs.get("model_id")
s = _make_settings()
with patch("application.core.settings.settings", s), patch(
"application.storage.db.session.db_readonly",
lambda: _yield(pg_conn),
):
ModelRegistry()
from application.llm.providers import PROVIDERS_BY_NAME
with patch.object(
PROVIDERS_BY_NAME["openai_compatible"], "llm_class", _FakeLLM
):
from application.llm.llm_creator import LLMCreator
LLMCreator.create_llm(
type="openai_compatible",
api_key="caller-passed-WRONG",
user_api_key=None,
decoded_token={"sub": "user-1"},
model_id=created["id"],
)
assert captured["api_key"] == "sk-mistral-real"
assert captured["base_url"] == "https://api.mistral.ai/v1"
assert captured["model_id"] == "mistral-large-latest" # NOT the uuid!
def test_llmcreator_forwards_byom_capabilities(self, pg_conn):
"""LLMCreator must thread the registry-resolved ``capabilities``
into the LLM. Without it the OpenAILLM hard-codes ``True`` for
tools/structured output and advertises image attachments
unconditionally, leaking unsupported features to BYOMs that
disabled them."""
repo = UserCustomModelsRepository(pg_conn)
created = repo.create(
user_id="user-2",
upstream_model_id="my-text-only-model",
display_name="Text-only BYOM",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="sk-real",
capabilities={
"supports_tools": False,
"supports_structured_output": False,
"attachments": [],
"context_window": 8192,
},
)
captured = {}
class _FakeLLM:
def __init__(self, api_key, user_api_key, *args, **kwargs):
captured["capabilities"] = kwargs.get("capabilities")
s = _make_settings()
with patch("application.core.settings.settings", s), patch(
"application.storage.db.session.db_readonly",
lambda: _yield(pg_conn),
):
ModelRegistry()
from application.llm.providers import PROVIDERS_BY_NAME
with patch.object(
PROVIDERS_BY_NAME["openai_compatible"], "llm_class", _FakeLLM
):
from application.llm.llm_creator import LLMCreator
LLMCreator.create_llm(
type="openai_compatible",
api_key="ignored",
user_api_key=None,
decoded_token={"sub": "user-2"},
model_id=created["id"],
)
caps = captured["capabilities"]
assert caps is not None
assert caps.supports_tools is False
assert caps.supports_structured_output is False
assert caps.supported_attachment_types == []
def test_byom_image_alias_expands_to_mime_types(self, pg_conn):
"""A BYOM stored with ``attachments: ["image"]`` (the alias the
UI sends) must surface as concrete MIME types on the registry
record, matching the built-in YAML expansion. Without this,
handlers/base.prepare_messages compares ``image/png`` against
the bare alias and filters every image upload as unsupported.
"""
repo = UserCustomModelsRepository(pg_conn)
created = repo.create(
user_id="user-1",
upstream_model_id="my-vision-model",
display_name="Vision BYOM",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="sk-real",
capabilities={"attachments": ["image"]},
)
s = _make_settings()
with patch("application.core.settings.settings", s), patch(
"application.storage.db.session.db_readonly",
lambda: _yield(pg_conn),
):
reg = ModelRegistry()
model = reg.get_model(created["id"], user_id="user-1")
assert model is not None
types = model.capabilities.supported_attachment_types
assert "image" not in types, (
"alias must be expanded, not stored verbatim"
)
assert any(t.startswith("image/") for t in types)
# Must include at least the common web image types so any image
# an end user uploads has a chance to match.
assert "image/png" in types
assert "image/jpeg" in types
def test_byom_unknown_alias_is_skipped_at_runtime(self, pg_conn):
"""Operator alias-map edits could orphan a stored alias. The
registry must drop the unknown entry rather than the entire
layer (which would hide every BYOM the user has).
"""
repo = UserCustomModelsRepository(pg_conn)
created = repo.create(
user_id="user-1",
upstream_model_id="m",
display_name="M",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="k",
# Bypass the route validation: write a bad alias straight
# to the row to simulate the post-edit orphan case.
capabilities={"attachments": ["image", "not-a-real-alias"]},
)
s = _make_settings()
with patch("application.core.settings.settings", s), patch(
"application.storage.db.session.db_readonly",
lambda: _yield(pg_conn),
):
reg = ModelRegistry()
model = reg.get_model(created["id"], user_id="user-1")
assert model is not None
types = model.capabilities.supported_attachment_types
assert "not-a-real-alias" not in types
assert any(t.startswith("image/") for t in types)
@pytest.mark.unit
class TestCrossProcessInvalidation:
"""The BYOM cache lives per-process. Without the P1 fix, a CRUD on
web-1 would leave the decrypted record (with old api_key/base_url)
cached forever in web-2 / Celery. These tests pin down that:
* ``invalidate_user`` publishes a version bump to Redis
* peers reload when the version they saw at load time is stale
* the local TTL bounds staleness even when Redis is unreachable
* unchanged version + expired TTL extends the entry without a
DB read (the common-case fast path)
"""
def test_invalidate_user_publishes_redis_version_bump(self, pg_conn):
repo = UserCustomModelsRepository(pg_conn)
repo.create(
user_id="user-1",
upstream_model_id="m1",
display_name="M1",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="k",
)
fake_redis = MagicMock()
s = _make_settings()
with patch("application.core.settings.settings", s), patch(
"application.storage.db.session.db_readonly",
lambda: _yield(pg_conn),
), patch(
"application.cache.get_redis_instance", return_value=fake_redis
):
ModelRegistry().get_model("anything", user_id="user-1")
ModelRegistry.invalidate_user("user-1")
fake_redis.incr.assert_called_once_with("byom:registry_version:user-1")
def test_peer_reloads_when_redis_version_changes(self, pg_conn):
"""Two-process simulation. Peer loads at version=0; another
process's CRUD bumps the version and updates Postgres; peer's
next post-TTL access sees the version mismatch and reloads,
picking up the rotated key it never invalidated locally."""
from application.core import model_registry as registry_mod
repo = UserCustomModelsRepository(pg_conn)
created = repo.create(
user_id="user-1",
upstream_model_id="m-orig",
display_name="orig",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="orig-key",
)
state = {"version": 0}
class _FakeRedis:
def get(self, key):
if key == "byom:registry_version:user-1":
return str(state["version"]).encode()
return None
def incr(self, key):
if key == "byom:registry_version:user-1":
state["version"] += 1
s = _make_settings()
# Force TTL to 0 so any subsequent access takes the post-TTL
# path without waiting 60s.
with patch("application.core.settings.settings", s), patch(
"application.storage.db.session.db_readonly",
lambda: _yield(pg_conn),
), patch(
"application.cache.get_redis_instance", return_value=_FakeRedis()
), patch.object(
registry_mod, "_USER_CACHE_TTL_SECONDS", 0.0
):
reg = ModelRegistry()
assert (
reg.get_model(created["id"], user_id="user-1").api_key
== "orig-key"
)
# Another process's CRUD: bump Redis counter + mutate the
# row. Note we deliberately do NOT call ``invalidate_user``
# in this process — that's the whole point of the test.
state["version"] += 1
repo.update(
created["id"],
"user-1",
{"api_key_plaintext": "rotated-key"},
)
assert (
reg.get_model(created["id"], user_id="user-1").api_key
== "rotated-key"
)
def test_ttl_bounds_staleness_when_redis_unavailable(self, pg_conn):
"""Redis down → fall back to TTL-only invalidation. After the
TTL elapses, peers reload regardless."""
from application.core import model_registry as registry_mod
repo = UserCustomModelsRepository(pg_conn)
created = repo.create(
user_id="user-1",
upstream_model_id="m",
display_name="m",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="orig",
)
s = _make_settings()
with patch("application.core.settings.settings", s), patch(
"application.storage.db.session.db_readonly",
lambda: _yield(pg_conn),
), patch(
"application.cache.get_redis_instance", return_value=None
), patch.object(
registry_mod, "_USER_CACHE_TTL_SECONDS", 0.0
):
reg = ModelRegistry()
first = reg.get_model(created["id"], user_id="user-1")
assert first.api_key == "orig"
repo.update(
created["id"],
"user-1",
{"api_key_plaintext": "rotated"},
)
second = reg.get_model(created["id"], user_id="user-1")
assert second.api_key == "rotated"
def test_unchanged_version_extends_ttl_without_db_read(self, pg_conn):
"""Hot path: TTL expires but Redis says no invalidation
happened — extend the entry without re-reading Postgres."""
from application.core import model_registry as registry_mod
repo = UserCustomModelsRepository(pg_conn)
created = repo.create(
user_id="user-1",
upstream_model_id="m",
display_name="m",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="k",
)
fake_redis = MagicMock()
fake_redis.get.return_value = b"7" # constant version
db_open_count = {"n": 0}
@contextmanager
def _counting_db_readonly():
db_open_count["n"] += 1
yield pg_conn
s = _make_settings()
with patch("application.core.settings.settings", s), patch(
"application.storage.db.session.db_readonly",
_counting_db_readonly,
), patch(
"application.cache.get_redis_instance", return_value=fake_redis
), patch.object(
registry_mod, "_USER_CACHE_TTL_SECONDS", 0.0
):
reg = ModelRegistry()
reg.get_model(created["id"], user_id="user-1")
first_open = db_open_count["n"]
# TTL has expired, but Redis returns the same version we
# captured at load time → no DB reload.
reg.get_model(created["id"], user_id="user-1")
reg.get_model(created["id"], user_id="user-1")
assert db_open_count["n"] == first_open

View File

@@ -1239,6 +1239,131 @@ class TestPerformInMemoryCompression:
assert messages is not None
assert agent.compressed_summary == "summary"
def test_uses_agent_model_user_id_for_byom_owner_scope(self):
"""Shared-agent BYOM dispatch: the agent's model_id is the owner's
BYOM UUID and ``decoded_token['sub']`` is a different (caller) user.
Provider lookup and LLMCreator must run under the OWNER scope so
the registry hits the owner's per-user layer instead of missing
and falling back to the deployment default."""
handler = ConcreteHandler()
agent = Mock()
agent.model_id = "byom-uuid-owner"
agent.model_user_id = "owner-id"
agent.user_api_key = None
agent.decoded_token = {"sub": "caller-id"}
agent.agent_id = None
agent.context_limit_reached = False
agent.current_token_count = 0
mock_metadata = Mock()
mock_metadata.compressed_token_count = 100
mock_metadata.original_token_count = 1000
mock_metadata.compression_ratio = 10.0
mock_metadata.to_dict.return_value = {"ratio": 10.0}
mock_service = Mock()
mock_service.compress_conversation.return_value = mock_metadata
mock_service.get_compressed_context.return_value = (
"summary",
[{"prompt": "q", "response": "a"}],
)
provider_lookup = MagicMock(return_value="openai_compatible")
create_llm_spy = MagicMock(return_value=Mock())
with patch.object(
handler,
"_build_conversation_from_messages",
return_value={"queries": [{"prompt": "q", "response": "a"}]},
), patch(
"application.core.model_utils.get_provider_from_model_id",
provider_lookup,
), patch(
"application.core.model_utils.get_api_key_for_provider",
return_value="key",
), patch(
"application.llm.llm_creator.LLMCreator.create_llm",
create_llm_spy,
), patch(
"application.api.answer.services.compression.service.CompressionService",
return_value=mock_service,
), patch.object(
handler,
"_rebuild_messages_after_compression",
return_value=[{"role": "system", "content": "rebuilt"}],
), patch(
"application.core.settings.settings",
MagicMock(COMPRESSION_MODEL_OVERRIDE=None),
):
success, _ = handler._perform_in_memory_compression(
agent, [{"role": "user", "content": "hi"}]
)
assert success is True
# Provider lookup must use the owner scope, not the caller's sub.
assert provider_lookup.call_args.kwargs["user_id"] == "owner-id"
# LLMCreator must receive the owner scope as model_user_id;
# without this the BYOM UUID resolves under the caller and the
# owner's registered base_url + api_key are missed.
assert create_llm_spy.call_args.kwargs["model_user_id"] == "owner-id"
def test_falls_back_to_caller_sub_when_no_model_user_id(self):
"""Built-in or caller-owned BYOM: ``agent.model_user_id`` is None.
Compression should resolve under the caller's sub, matching the
pre-fix behavior."""
handler = ConcreteHandler()
agent = Mock()
agent.model_id = "gpt-4"
agent.model_user_id = None
agent.user_api_key = None
agent.decoded_token = {"sub": "caller-id"}
agent.agent_id = None
agent.context_limit_reached = False
agent.current_token_count = 0
mock_metadata = Mock()
mock_metadata.compressed_token_count = 100
mock_metadata.original_token_count = 1000
mock_metadata.to_dict.return_value = {}
mock_service = Mock()
mock_service.compress_conversation.return_value = mock_metadata
mock_service.get_compressed_context.return_value = ("summary", [])
provider_lookup = MagicMock(return_value="openai")
create_llm_spy = MagicMock(return_value=Mock())
with patch.object(
handler,
"_build_conversation_from_messages",
return_value={"queries": [{"prompt": "q", "response": "a"}]},
), patch(
"application.core.model_utils.get_provider_from_model_id",
provider_lookup,
), patch(
"application.core.model_utils.get_api_key_for_provider",
return_value="key",
), patch(
"application.llm.llm_creator.LLMCreator.create_llm",
create_llm_spy,
), patch(
"application.api.answer.services.compression.service.CompressionService",
return_value=mock_service,
), patch.object(
handler,
"_rebuild_messages_after_compression",
return_value=[],
), patch(
"application.core.settings.settings",
MagicMock(COMPRESSION_MODEL_OVERRIDE=None),
):
handler._perform_in_memory_compression(
agent, [{"role": "user", "content": "hi"}]
)
assert provider_lookup.call_args.kwargs["user_id"] == "caller-id"
assert create_llm_spy.call_args.kwargs["model_user_id"] == "caller-id"
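# Taken together, the two tests above pin down the compression scope rule:
# the effective user for model resolution is ``agent.model_user_id`` when it
# is set, otherwise ``decoded_token["sub"]``, and that value feeds both the
# provider lookup and the LLMCreator call.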
# ---------------------------------------------------------------------------
# _perform_mid_execution_compression — additional edge cases

View File

@@ -197,7 +197,7 @@ class TestFallbackLLMResolution:
mock_fallback = StubLLM()
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda mid: "openai",
lambda mid, **_kwargs: "openai",
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
@@ -223,7 +223,7 @@ class TestFallbackLLMResolution:
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda mid: "openai",
lambda mid, **_kwargs: "openai",
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
@@ -262,7 +262,7 @@ class TestFallbackLLMResolution:
def test_backup_provider_not_found_skipped(self, monkeypatch):
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda mid: None,
lambda mid, **_kwargs: None,
)
monkeypatch.setattr(
"application.llm.base.settings",

View File

@@ -11,9 +11,7 @@ import pytest
from application.llm.base import BaseLLM
# ---------------------------------------------------------------------------
# Concrete LLM stubs
# ---------------------------------------------------------------------------
class FakeLLM(BaseLLM):
@@ -59,9 +57,7 @@ class FakeLLM(BaseLLM):
return super().gen_stream(*args, **kwargs)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _noop_decorator(func):
@@ -121,9 +117,7 @@ def patch_model_utils(monkeypatch):
CALL_ARGS = dict(model="test-model", messages=[{"role": "user", "content": "hi"}])
# ---------------------------------------------------------------------------
# Tests — fallback_llm property resolution
# ---------------------------------------------------------------------------
@pytest.mark.integration
@@ -135,7 +129,7 @@ class TestFallbackLLMResolution:
backup_llm = FakeLLM(responses=["backup response"])
patch_model_utils(
get_provider=lambda mid: "openai",
get_provider=lambda mid, **_kwargs: "openai",
get_api_key=lambda prov: "fake-key",
create_llm=lambda type, **kw: backup_llm,
)
@@ -174,7 +168,7 @@ class TestFallbackLLMResolution:
good_backup = FakeLLM(responses=["good backup"])
call_count = {"n": 0}
def fake_get_provider(model_id):
def fake_get_provider(model_id, **_kwargs):
call_count["n"] += 1
if model_id == "bad-model":
return None # unresolvable
@@ -202,9 +196,7 @@ class TestFallbackLLMResolution:
assert primary.fallback_llm is None
# ---------------------------------------------------------------------------
# Tests — non-streaming fallback (gen)
# ---------------------------------------------------------------------------
@pytest.mark.integration
@@ -221,7 +213,7 @@ class TestNonStreamingFallback:
backup = FakeLLM(responses=["backup ok"])
patch_model_utils(
get_provider=lambda mid: "openai",
get_provider=lambda mid, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
@@ -242,9 +234,7 @@ class TestNonStreamingFallback:
primary.gen(**CALL_ARGS)
# ---------------------------------------------------------------------------
# Tests — streaming fallback (gen_stream)
# ---------------------------------------------------------------------------
@pytest.mark.integration
@@ -261,7 +251,7 @@ class TestStreamingFallback:
backup = FakeLLM(stream_chunks=["fallback1", "fallback2"])
patch_model_utils(
get_provider=lambda m: "openai",
get_provider=lambda m, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
@@ -280,7 +270,7 @@ class TestStreamingFallback:
backup = FakeLLM(stream_chunks=["recovery1", "recovery2"])
patch_model_utils(
get_provider=lambda m: "openai",
get_provider=lambda m, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
@@ -305,9 +295,7 @@ class TestStreamingFallback:
list(primary.gen_stream(**CALL_ARGS))
# ---------------------------------------------------------------------------
# Tests — backup model priority over global fallback
# ---------------------------------------------------------------------------
@pytest.mark.integration
@@ -323,7 +311,7 @@ class TestBackupModelPriority:
return backup
patch_model_utils(
get_provider=lambda m: "openai",
get_provider=lambda m, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=fake_create_llm,
)
@@ -346,7 +334,7 @@ class TestBackupModelPriority:
return backup
patch_model_utils(
get_provider=lambda m: "openai",
get_provider=lambda m, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=fake_create_llm,
)
@@ -366,7 +354,7 @@ class TestBackupModelPriority:
global_fallback = FakeLLM(responses=["global ok"])
call_order = []
def fake_get_provider(mid):
def fake_get_provider(mid, **_kwargs):
if mid == "broken-backup":
return "nonexistent_provider"
return "openai"
@@ -401,9 +389,7 @@ class TestBackupModelPriority:
assert call_order == ["broken-backup", "global-model"]
# ---------------------------------------------------------------------------
# Tests — fallback uses its own model_id, not the primary's
# ---------------------------------------------------------------------------
@pytest.mark.integration
@@ -419,7 +405,7 @@ class TestFallbackModelIdOverride:
)
patch_model_utils(
get_provider=lambda m: "groq",
get_provider=lambda m, **_kwargs: "groq",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
@@ -441,7 +427,7 @@ class TestFallbackModelIdOverride:
)
patch_model_utils(
get_provider=lambda m: "groq",
get_provider=lambda m, **_kwargs: "groq",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
@@ -465,7 +451,7 @@ class TestFallbackModelIdOverride:
)
patch_model_utils(
get_provider=lambda m: "groq",
get_provider=lambda m, **_kwargs: "groq",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
@@ -480,3 +466,161 @@ class TestFallbackModelIdOverride:
assert chunks == ["partial1", "partial2", "recovered"]
assert backup.last_model_received == "groq-gpt-oss-120b"
# Tests — model_user_id (BYOM owner scope) propagates into fallback resolution
@pytest.mark.integration
class TestFallbackModelUserIdScope:
"""A shared agent dispatched by user B but owned by user A stores
A's BYOM UUIDs as backup_models. Without the P2 fix the fallback
property looks up those UUIDs against ``decoded_token['sub']`` (B,
the caller), which can't see A's per-user layer — backups are
silently skipped and the global FALLBACK_* settings are used
instead. These tests pin down that ``model_user_id`` (the owner)
is used both for the registry lookup and for the recursive
``LLMCreator.create_llm`` call."""
def test_backup_lookup_uses_model_user_id_not_caller(
self, patch_model_utils
):
captured = {"user_id": None}
def fake_get_provider(model_id, **kwargs):
captured["user_id"] = kwargs.get("user_id")
return "openai"
backup = FakeLLM(responses=["ok"])
patch_model_utils(
get_provider=fake_get_provider,
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
primary = FakeLLM(
decoded_token={"sub": "caller-bob"},
model_user_id="owner-alice",
backup_models=["alice-byom-uuid"],
)
_ = primary.fallback_llm
assert captured["user_id"] == "owner-alice"
def test_backup_create_llm_receives_model_user_id(self, patch_model_utils):
backup = FakeLLM(responses=["ok"])
captured = {}
def fake_create_llm(type, **kw):
captured["model_user_id"] = kw.get("model_user_id")
captured["model_id"] = kw.get("model_id")
return backup
patch_model_utils(
get_provider=lambda m, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=fake_create_llm,
)
primary = FakeLLM(
decoded_token={"sub": "caller-bob"},
model_user_id="owner-alice",
backup_models=["alice-byom-uuid"],
)
_ = primary.fallback_llm
assert captured["model_user_id"] == "owner-alice"
assert captured["model_id"] == "alice-byom-uuid"
def test_global_fallback_create_llm_receives_model_user_id(
self, monkeypatch, patch_model_utils
):
"""The global FALLBACK_LLM_NAME path must also forward
``model_user_id`` — operators can configure it to a BYOM UUID
that's owned by the same user as the primary model."""
backup = FakeLLM(responses=["ok"])
captured = {}
def fake_create_llm(type, **kw):
captured["model_user_id"] = kw.get("model_user_id")
return backup
patch_model_utils(create_llm=fake_create_llm)
monkeypatch.setattr(
"application.llm.base.settings",
MagicMock(
FALLBACK_LLM_PROVIDER="openai",
FALLBACK_LLM_NAME="some-uuid",
FALLBACK_LLM_API_KEY="k",
API_KEY="k",
),
)
primary = FakeLLM(
decoded_token={"sub": "caller-bob"},
model_user_id="owner-alice",
backup_models=[],
)
_ = primary.fallback_llm
assert captured["model_user_id"] == "owner-alice"
def test_falls_back_to_caller_when_model_user_id_unset(
self, patch_model_utils
):
"""Built-in models / pre-P2 callers don't pass model_user_id.
In that case the caller's sub is still used — preserving
existing behaviour."""
captured = {}
def fake_get_provider(model_id, **kwargs):
captured["user_id"] = kwargs.get("user_id")
return "openai"
patch_model_utils(
get_provider=fake_get_provider,
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: FakeLLM(responses=["ok"]),
)
primary = FakeLLM(
decoded_token={"sub": "caller-bob"},
model_user_id=None,
backup_models=["some-builtin-id"],
)
_ = primary.fallback_llm
assert captured["user_id"] == "caller-bob"
# Tests — LLMCreator wires model_user_id through to BaseLLM
@pytest.mark.unit
class TestLLMCreatorPassesModelUserId:
"""End-to-end through ``LLMCreator.create_llm``: the constructed
LLM must store ``model_user_id`` so its fallback property can
resolve under the right scope."""
def test_model_user_id_set_on_constructed_llm(self, monkeypatch):
from application.llm.llm_creator import LLMCreator
from application.llm.providers import PROVIDERS_BY_NAME
captured = {}
class _CapturingLLM:
def __init__(self, api_key, user_api_key, *args, **kwargs):
captured["model_user_id"] = kwargs.get("model_user_id")
# Pick any registered provider — we only need the constructor
# call to land in our fake.
monkeypatch.setattr(
PROVIDERS_BY_NAME["openai"], "llm_class", _CapturingLLM
)
LLMCreator.create_llm(
type="openai",
api_key="k",
user_api_key=None,
decoded_token={"sub": "caller-bob"},
model_id=None,
model_user_id="owner-alice",
)
assert captured["model_user_id"] == "owner-alice"

Some files were not shown because too many files have changed in this diff