Mirror of https://github.com/arc53/DocsGPT.git (synced 2026-05-07 14:34:32 +00:00)

Compare commits: feat-model...feat-bring (2 commits)
| Author | SHA1 | Date |
|---|---|---|
| | e0a8cc178b | |
| | af618de13d | |
@@ -35,8 +35,5 @@ MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
 #Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
 MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}

 # User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration).
 # Standard Postgres URI — `postgres://` and `postgresql://` both work.
 # Leave unset while the migration is still being rolled out; the app will
 # fall back to MongoDB for user data until POSTGRES_URI is configured.

 # POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
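Since both URI schemes are accepted, here is a minimal sketch (an assumption for illustration, not DocsGPT code) of canonicalizing the scheme before handing the DSN to a SQLAlchemy-style driver that only accepts the long form:

    def normalize_postgres_uri(uri: str) -> str:
        # libpq accepts both schemes; SQLAlchemy-style drivers want the
        # long form, so canonicalize "postgres://" to "postgresql://".
        if uri.startswith("postgres://"):
            return "postgresql://" + uri[len("postgres://"):]
        return uri

    assert normalize_postgres_uri("postgres://u:p@localhost:5432/db") == (
        "postgresql://u:p@localhost:5432/db"
    )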
@@ -42,6 +42,7 @@ class BaseAgent(ABC):
        llm_handler=None,
        tool_executor: Optional[ToolExecutor] = None,
        backup_models: Optional[List[str]] = None,
+       model_user_id: Optional[str] = None,
    ):
        self.endpoint = endpoint
        self.llm_name = llm_name
@@ -52,10 +53,13 @@ class BaseAgent(ABC):
        self.prompt = prompt
        self.decoded_token = decoded_token or {}
        self.user: str = self.decoded_token.get("sub")
+       # BYOM-resolution scope: owner for shared agents, caller for
+       # caller-owned BYOM, None for built-ins. Falls back to self.user
+       # for worker/legacy callers that don't thread model_user_id.
+       self.model_user_id = model_user_id
        self.tools: List[Dict] = []
        self.chat_history: List[Dict] = chat_history if chat_history is not None else []

        # Dependency injection for LLM — fall back to creating if not provided
        if llm is not None:
            self.llm = llm
        else:
@@ -67,8 +71,16 @@ class BaseAgent(ABC):
                model_id=model_id,
                agent_id=agent_id,
                backup_models=backup_models,
+               model_user_id=model_user_id,
            )

+       # For BYOM, registry id (UUID) differs from upstream model id
+       # (e.g. ``mistral-large-latest``). LLMCreator resolved this onto
+       # the LLM instance; cache it for subsequent gen calls.
+       self.upstream_model_id = (
+           getattr(self.llm, "model_id", None) or model_id
+       )
+
        self.retrieved_docs = retrieved_docs or []

        if llm_handler is not None:
@@ -306,7 +318,9 @@ class BaseAgent(ABC):
        try:
            current_tokens = self._calculate_current_context_tokens(messages)
            self.current_token_count = current_tokens
-           context_limit = get_token_limit(self.model_id)
+           context_limit = get_token_limit(
+               self.model_id, user_id=self.model_user_id or self.user
+           )
            threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)

            if current_tokens >= threshold:
@@ -325,7 +339,9 @@ class BaseAgent(ABC):

            current_tokens = self._calculate_current_context_tokens(messages)
            self.current_token_count = current_tokens
-           context_limit = get_token_limit(self.model_id)
+           context_limit = get_token_limit(
+               self.model_id, user_id=self.model_user_id or self.user
+           )
            percentage = (current_tokens / context_limit) * 100

            if current_tokens >= context_limit:
@@ -387,7 +403,9 @@ class BaseAgent(ABC):
            )
            system_prompt = system_prompt + compression_context

-       context_limit = get_token_limit(self.model_id)
+       context_limit = get_token_limit(
+           self.model_id, user_id=self.model_user_id or self.user
+       )
        system_tokens = num_tokens_from_string(system_prompt)

        safety_buffer = int(context_limit * 0.1)
@@ -497,7 +515,10 @@ class BaseAgent(ABC):
    def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
        self._validate_context_size(messages)

-       gen_kwargs = {"model": self.model_id, "messages": messages}
+       # Use the upstream id resolved by LLMCreator (see __init__).
+       # Built-in models: same as self.model_id. BYOM: the user's
+       # typed model name, not the internal UUID.
+       gen_kwargs = {"model": self.upstream_model_id, "messages": messages}
        if self.attachments:
            gen_kwargs["_usage_attachments"] = self.attachments
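To make the registry-id/upstream-id split concrete, a self-contained sketch (the class and values are illustrative stand-ins, not LLMCreator internals):

    class _StubLLM:
        def __init__(self, model_id):
            # For BYOM, LLMCreator stores the provider-facing name here.
            self.model_id = model_id

    registry_id = "8f14e45f-ceea-4e17-9f1b-2b6c4c0e7a11"  # made-up BYOM UUID
    llm = _StubLLM("mistral-large-latest")
    upstream_model_id = getattr(llm, "model_id", None) or registry_id
    assert upstream_model_id == "mistral-large-latest"  # what the provider API receives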
@@ -312,7 +312,7 @@ class ResearchAgent(BaseAgent):

        try:
            response = self.llm.gen(
-               model=self.model_id,
+               model=self.upstream_model_id,
                messages=messages,
                tools=None,
                response_format={"type": "json_object"},
@@ -390,7 +390,7 @@ class ResearchAgent(BaseAgent):

        try:
            response = self.llm.gen(
-               model=self.model_id,
+               model=self.upstream_model_id,
                messages=messages,
                tools=None,
                response_format={"type": "json_object"},
@@ -506,7 +506,7 @@ class ResearchAgent(BaseAgent):

        try:
            response = self.llm.gen(
-               model=self.model_id,
+               model=self.upstream_model_id,
                messages=messages,
                tools=self.tools if self.tools else None,
            )
@@ -537,7 +537,7 @@ class ResearchAgent(BaseAgent):
            )
            try:
                response = self.llm.gen(
-                   model=self.model_id, messages=messages, tools=None
+                   model=self.upstream_model_id, messages=messages, tools=None
                )
                self._track_tokens(self._snapshot_llm_tokens())
                text = self._extract_text(response)
@@ -664,7 +664,7 @@ class ResearchAgent(BaseAgent):
        ]

        llm_response = self.llm.gen_stream(
-           model=self.model_id, messages=messages, tools=None
+           model=self.upstream_model_id, messages=messages, tools=None
        )

        if log_context:
@@ -39,6 +39,7 @@ class InternalSearchTool(Tool):
            chunks=int(self.config.get("chunks", 2)),
            doc_token_limit=int(self.config.get("doc_token_limit", 50000)),
            model_id=self.config.get("model_id", "docsgpt-local"),
+           model_user_id=self.config.get("model_user_id"),
            user_api_key=self.config.get("user_api_key"),
            agent_id=self.config.get("agent_id"),
            llm_name=self.config.get("llm_name", settings.LLM_PROVIDER),
@@ -435,6 +436,7 @@ def build_internal_tool_config(
    chunks: int = 2,
    doc_token_limit: int = 50000,
    model_id: str = "docsgpt-local",
+   model_user_id: Optional[str] = None,
    user_api_key: Optional[str] = None,
    agent_id: Optional[str] = None,
    llm_name: str = None,
@@ -449,6 +451,7 @@ def build_internal_tool_config(
        "chunks": chunks,
        "doc_token_limit": doc_token_limit,
        "model_id": model_id,
+       "model_user_id": model_user_id,
        "user_api_key": user_api_key,
        "agent_id": agent_id,
        "llm_name": llm_name or settings.LLM_PROVIDER,
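One subtlety in the builder above: `llm_name` defaults to None and the config stores `llm_name or settings.LLM_PROVIDER`, so an explicit None (or empty string) silently falls back to the deployment provider. A tiny standalone illustration (with a stand-in settings object, since the real one lives in application.core.settings):

    class _Settings:               # stand-in for the real settings module
        LLM_PROVIDER = "openai"

    settings = _Settings()

    def pick_provider(llm_name=None):
        # Mirrors '"llm_name": llm_name or settings.LLM_PROVIDER' above.
        return llm_name or settings.LLM_PROVIDER

    assert pick_provider() == "openai"
    assert pick_provider("") == "openai"        # falsy strings also fall back
    assert pick_provider("anthropic") == "anthropic"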
@@ -211,15 +211,26 @@ class WorkflowEngine:
                node_config.json_schema, node.title
            )
            node_model_id = node_config.model_id or self.agent.model_id
+           # Inherit BYOM scope from parent agent so owner-stored BYOM
+           # resolves on shared workflows.
+           node_user_id = getattr(self.agent, "model_user_id", None) or (
+               self.agent.decoded_token.get("sub")
+               if isinstance(self.agent.decoded_token, dict)
+               else None
+           )
            node_llm_name = (
                node_config.llm_name
-               or get_provider_from_model_id(node_model_id or "")
+               or get_provider_from_model_id(
+                   node_model_id or "", user_id=node_user_id
+               )
                or self.agent.llm_name
            )
            node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key

            if node_json_schema and node_model_id:
-               model_capabilities = get_model_capabilities(node_model_id)
+               model_capabilities = get_model_capabilities(
+                   node_model_id, user_id=node_user_id
+               )
                if model_capabilities and not model_capabilities.get(
                    "supports_structured_output", False
                ):
@@ -232,6 +243,7 @@ class WorkflowEngine:
                "endpoint": self.agent.endpoint,
                "llm_name": node_llm_name,
                "model_id": node_model_id,
+               "model_user_id": getattr(self.agent, "model_user_id", None),
                "api_key": node_api_key,
                "tool_ids": node_config.tools,
                "prompt": node_config.system_prompt,
application/alembic/versions/0003_user_custom_models.py (new file, 65 lines)
@@ -0,0 +1,65 @@
"""0003 user_custom_models — per-user OpenAI-compatible model registrations.

Revision ID: 0003_user_custom_models
Revises: 0002_app_metadata
"""

from typing import Sequence, Union

from alembic import op


revision: str = "0003_user_custom_models"
down_revision: Union[str, None] = "0002_app_metadata"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.execute(
        """
        CREATE TABLE user_custom_models (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            upstream_model_id TEXT NOT NULL,
            display_name TEXT NOT NULL,
            description TEXT NOT NULL DEFAULT '',
            base_url TEXT NOT NULL,
            api_key_encrypted TEXT NOT NULL,
            capabilities JSONB NOT NULL DEFAULT '{}'::jsonb,
            enabled BOOLEAN NOT NULL DEFAULT true,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """
    )
    op.execute(
        "CREATE INDEX user_custom_models_user_id_idx "
        "ON user_custom_models (user_id);"
    )

    # Mirror the project-wide invariants set up in 0001_initial:
    # * user_id FK with ON DELETE RESTRICT (deferrable),
    # * ensure_user_exists() trigger so the parent users row autocreates,
    # * set_updated_at() trigger.
    op.execute(
        "ALTER TABLE user_custom_models "
        "ADD CONSTRAINT user_custom_models_user_id_fk "
        "FOREIGN KEY (user_id) REFERENCES users(user_id) "
        "ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
    )
    op.execute(
        "CREATE TRIGGER user_custom_models_ensure_user "
        "BEFORE INSERT OR UPDATE OF user_id ON user_custom_models "
        "FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
    )
    op.execute(
        "CREATE TRIGGER user_custom_models_set_updated_at "
        "BEFORE UPDATE ON user_custom_models "
        "FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
        "EXECUTE FUNCTION set_updated_at();"
    )


def downgrade() -> None:
    op.execute("DROP TABLE IF EXISTS user_custom_models;")
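A portability note on the DDL above: gen_random_uuid() is built into PostgreSQL 13 and later; on older servers it is provided by the pgcrypto extension. A hedged sketch of a pre-flight an operator might add (an assumption; it is not part of the migration):

    def _ensure_uuid_function() -> None:
        # Hypothetical guard, not in 0003: PostgreSQL < 13 only ships
        # gen_random_uuid() via the pgcrypto extension.
        op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;")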
@@ -177,6 +177,7 @@ class BaseAnswerResource:
        is_shared_usage: bool = False,
        shared_token: Optional[str] = None,
        model_id: Optional[str] = None,
+       model_user_id: Optional[str] = None,
        _continuation: Optional[Dict] = None,
    ) -> Generator[str, None, None]:
        """
@@ -289,8 +290,18 @@ class BaseAnswerResource:
        # conversation if this is the first turn.
        if not conversation_id and should_save_conversation:
            try:
+               # Use model-owner scope so shared-agent
+               # owner-BYOM resolves to its registered plugin.
                provider = (
-                   get_provider_from_model_id(model_id)
+                   get_provider_from_model_id(
+                       model_id,
+                       user_id=model_user_id
+                       or (
+                           decoded_token.get("sub")
+                           if decoded_token
+                           else None
+                       ),
+                   )
                    if model_id
                    else settings.LLM_PROVIDER
                )
@@ -304,6 +315,7 @@ class BaseAnswerResource:
                    decoded_token=decoded_token,
                    model_id=model_id,
                    agent_id=agent_id,
+                   model_user_id=model_user_id,
                )
                conversation_id = (
                    self.conversation_service.save_conversation(
@@ -340,6 +352,9 @@ class BaseAnswerResource:
                tool_schemas=getattr(agent, "tools", []),
                agent_config={
                    "model_id": model_id or self.default_model_id,
+                   # Persist BYOM scope so resume doesn't
+                   # fall back to caller's layer.
+                   "model_user_id": model_user_id,
                    "llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
                    "api_key": getattr(agent, "api_key", None),
                    "user_api_key": user_api_key,
@@ -370,8 +385,14 @@ class BaseAnswerResource:
            if isNoneDoc:
                for doc in source_log_docs:
                    doc["source"] = "None"
+           # Run under model-owner scope so title-gen LLM inside
+           # save_conversation uses the owner's BYOM provider/key.
            provider = (
-               get_provider_from_model_id(model_id)
+               get_provider_from_model_id(
+                   model_id,
+                   user_id=model_user_id
+                   or (decoded_token.get("sub") if decoded_token else None),
+               )
                if model_id
                else settings.LLM_PROVIDER
            )
@@ -384,6 +405,7 @@ class BaseAnswerResource:
                decoded_token=decoded_token,
                model_id=model_id,
                agent_id=agent_id,
+               model_user_id=model_user_id,
            )

            if should_save_conversation:
@@ -481,12 +503,34 @@ class BaseAnswerResource:
            if isNoneDoc:
                for doc in source_log_docs:
                    doc["source"] = "None"
+           # Mirror the normal-path provider resolution so the
+           # partial-save title LLM uses the model-owner's BYOM
+           # registration (shared-agent dispatch) rather than
+           # the deployment default with the instance api key.
+           provider = (
+               get_provider_from_model_id(
+                   model_id,
+                   user_id=model_user_id
+                   or (
+                       decoded_token.get("sub")
+                       if decoded_token
+                       else None
+                   ),
+               )
+               if model_id
+               else settings.LLM_PROVIDER
+           )
+           sys_api_key = get_api_key_for_provider(
+               provider or settings.LLM_PROVIDER
+           )
            llm = LLMCreator.create_llm(
-               settings.LLM_PROVIDER,
-               api_key=settings.API_KEY,
+               provider or settings.LLM_PROVIDER,
+               api_key=sys_api_key,
                user_api_key=user_api_key,
                decoded_token=decoded_token,
                model_id=model_id,
                agent_id=agent_id,
+               model_user_id=model_user_id,
            )
            self.conversation_service.save_conversation(
                conversation_id,
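The scope expression `model_user_id or (decoded_token.get("sub") if decoded_token else None)` recurs in all three resolution sites above; a standalone illustration of its fallback behavior (values assumed):

    def resolve_scope(model_user_id, decoded_token):
        # Owner scope wins; otherwise fall back to the caller's sub claim.
        return model_user_id or (
            decoded_token.get("sub") if decoded_token else None
        )

    assert resolve_scope("owner-7", {"sub": "caller-1"}) == "owner-7"
    assert resolve_scope(None, {"sub": "caller-1"}) == "caller-1"
    assert resolve_scope(None, None) is None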
@@ -109,6 +109,7 @@ class StreamResource(Resource, BaseAnswerResource):
                decoded_token=processor.decoded_token,
                agent_id=processor.agent_id,
                model_id=processor.model_id,
+               model_user_id=processor.model_user_id,
                _continuation={
                    "messages": messages,
                    "tools_dict": tools_dict,
@@ -145,6 +146,7 @@ class StreamResource(Resource, BaseAnswerResource):
                is_shared_usage=processor.is_shared_usage,
                shared_token=processor.shared_token,
                model_id=processor.model_id,
+               model_user_id=processor.model_user_id,
            ),
            mimetype="text/event-stream",
        )
@@ -49,6 +49,7 @@ class CompressionOrchestrator:
        model_id: str,
        decoded_token: Dict[str, Any],
        current_query_tokens: int = 500,
+       model_user_id: Optional[str] = None,
    ) -> CompressionResult:
        """
        Check if compression is needed and perform it if so.
@@ -57,16 +58,18 @@ class CompressionOrchestrator:

        Args:
            conversation_id: Conversation ID
-           user_id: User ID
+           user_id: Caller's user id — used for conversation access checks
            model_id: Model being used for conversation
            decoded_token: User's decoded JWT token
            current_query_tokens: Estimated tokens for current query
+           model_user_id: BYOM-resolution scope (model owner); defaults
+               to ``user_id`` for built-in / caller-owned models.

        Returns:
            CompressionResult with summary and recent queries
        """
        try:
-           # Load conversation
+           # Conversation row is owned by the caller, not the model owner.
            conversation = self.conversation_service.get_conversation(
                conversation_id, user_id
            )
@@ -77,9 +80,14 @@ class CompressionOrchestrator:
                )
                return CompressionResult.failure("Conversation not found")

-           # Check if compression is needed
+           # Use model-owner scope so per-user BYOM context windows
+           # (e.g. 8k) compute the threshold against the right limit.
+           registry_user_id = model_user_id or user_id
            if not self.threshold_checker.should_compress(
-               conversation, model_id, current_query_tokens
+               conversation,
+               model_id,
+               current_query_tokens,
+               user_id=registry_user_id,
            ):
                # No compression needed, return full history
                queries = conversation.get("queries", [])
@@ -87,7 +95,12 @@ class CompressionOrchestrator:

            # Perform compression
            return self._perform_compression(
-               conversation_id, conversation, model_id, decoded_token
+               conversation_id,
+               conversation,
+               model_id,
+               decoded_token,
+               user_id=user_id,
+               model_user_id=model_user_id,
            )

        except Exception as e:
@@ -102,6 +115,8 @@ class CompressionOrchestrator:
        conversation: Dict[str, Any],
        model_id: str,
        decoded_token: Dict[str, Any],
+       user_id: Optional[str] = None,
+       model_user_id: Optional[str] = None,
    ) -> CompressionResult:
        """
        Perform the actual compression operation.
@@ -111,6 +126,8 @@ class CompressionOrchestrator:
            conversation: Conversation document
            model_id: Model ID for conversation
            decoded_token: User token
+           user_id: Caller's id (for conversation reload after compression)
+           model_user_id: BYOM-resolution scope (model owner)

        Returns:
            CompressionResult
@@ -123,11 +140,17 @@ class CompressionOrchestrator:
            else model_id
        )

        # Get provider and API key for compression model
-       provider = get_provider_from_model_id(compression_model)
+       # Use model-owner scope so provider/api_key resolves to the
+       # owner's BYOM record (shared-agent dispatch).
+       caller_user_id = user_id
+       if caller_user_id is None and isinstance(decoded_token, dict):
+           caller_user_id = decoded_token.get("sub")
+       registry_user_id = model_user_id or caller_user_id
+       provider = get_provider_from_model_id(
+           compression_model, user_id=registry_user_id
+       )
        api_key = get_api_key_for_provider(provider)

        # Create compression LLM
        compression_llm = LLMCreator.create_llm(
            provider,
            api_key=api_key,
@@ -135,6 +158,7 @@ class CompressionOrchestrator:
            decoded_token=decoded_token,
            model_id=compression_model,
            agent_id=conversation.get("agent_id"),
+           model_user_id=registry_user_id,
        )

        # Create compression service with DB update capability
@@ -167,9 +191,12 @@ class CompressionOrchestrator:
            f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
        )

-       # Reload conversation with updated metadata
+       # Reload under caller (conversation is owned by caller).
+       reload_user_id = caller_user_id
+       if reload_user_id is None and isinstance(decoded_token, dict):
+           reload_user_id = decoded_token.get("sub")
        conversation = self.conversation_service.get_conversation(
-           conversation_id, user_id=decoded_token.get("sub")
+           conversation_id, user_id=reload_user_id
        )

        # Get compressed context
@@ -192,16 +219,21 @@ class CompressionOrchestrator:
        model_id: str,
        decoded_token: Dict[str, Any],
        current_conversation: Optional[Dict[str, Any]] = None,
+       model_user_id: Optional[str] = None,
    ) -> CompressionResult:
        """
        Perform compression during tool execution.

        Args:
            conversation_id: Conversation ID
-           user_id: User ID
+           user_id: Caller's user id — used for conversation access checks
            model_id: Model ID
            decoded_token: User token
            current_conversation: Pre-loaded conversation (optional)
+           model_user_id: BYOM-resolution scope (model owner). For
+               shared-agent dispatch this is the agent owner; defaults
+               to ``user_id`` so built-in / caller-owned models are
+               unaffected.

        Returns:
            CompressionResult
@@ -223,7 +255,12 @@ class CompressionOrchestrator:

            # Perform compression
            return self._perform_compression(
-               conversation_id, conversation, model_id, decoded_token
+               conversation_id,
+               conversation,
+               model_id,
+               decoded_token,
+               user_id=user_id,
+               model_user_id=model_user_id,
            )

        except Exception as e:
@@ -106,8 +106,13 @@ class CompressionService:
            f"using model {self.model_id}"
        )

+       # See note in conversation_service.py: ``self.model_id`` is
+       # the registry id (UUID for BYOM); the LLM's own model_id is
+       # what the provider's API actually expects.
        response = self.llm.gen(
-           model=self.model_id, messages=messages, max_tokens=4000
+           model=getattr(self.llm, "model_id", None) or self.model_id,
+           messages=messages,
+           max_tokens=4000,
        )

        # Extract summary from response
@@ -30,6 +30,7 @@ class CompressionThresholdChecker:
        conversation: Dict[str, Any],
        model_id: str,
        current_query_tokens: int = 500,
+       user_id: str | None = None,
    ) -> bool:
        """
        Determine if compression is needed.
@@ -38,6 +39,8 @@ class CompressionThresholdChecker:
            conversation: Full conversation document
            model_id: Target model for this request
            current_query_tokens: Estimated tokens for current query
+           user_id: Owner — needed so per-user BYOM custom-model UUIDs
+               resolve when looking up the context window.

        Returns:
            True if tokens >= threshold% of context window
@@ -48,7 +51,7 @@ class CompressionThresholdChecker:
            total_tokens += current_query_tokens

            # Get context window limit for model
-           context_limit = get_token_limit(model_id)
+           context_limit = get_token_limit(model_id, user_id=user_id)

            # Calculate threshold
            threshold = int(context_limit * self.threshold_percentage)
@@ -73,20 +76,24 @@ class CompressionThresholdChecker:
            logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
            return False

-   def check_message_tokens(self, messages: list, model_id: str) -> bool:
+   def check_message_tokens(
+       self, messages: list, model_id: str, user_id: str | None = None
+   ) -> bool:
        """
        Check if message list exceeds threshold.

        Args:
            messages: List of message dicts
            model_id: Target model
+           user_id: Owner — needed so per-user BYOM custom-model UUIDs
+               resolve when looking up the context window.

        Returns:
            True if at or above threshold
        """
        try:
            current_tokens = TokenCounter.count_message_tokens(messages)
-           context_limit = get_token_limit(model_id)
+           context_limit = get_token_limit(model_id, user_id=user_id)
            threshold = int(context_limit * self.threshold_percentage)

            if current_tokens >= threshold:
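As a worked example of the arithmetic above (numbers assumed: the 8k BYOM context window mentioned in the orchestrator comments, with a 0.8 threshold):

    context_limit = 8_000                    # user's BYOM context_window
    threshold_percentage = 0.8               # assumed configured value
    threshold = int(context_limit * threshold_percentage)
    assert threshold == 6_400                # compression kicks in from here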
@@ -136,8 +136,14 @@ class ConversationService:
            },
        ]

+       # ``model_id`` here is the registry id (a UUID for BYOM
+       # records). The LLM's own ``model_id`` is the upstream name
+       # LLMCreator resolved at construction time — that's what
+       # the provider's API expects. Built-ins are unaffected.
        completion = llm.gen(
-           model=model_id, messages=messages_summary, max_tokens=500
+           model=getattr(llm, "model_id", None) or model_id,
+           messages=messages_summary,
+           max_tokens=500,
        )

        if not completion or not completion.strip():
@@ -121,6 +121,8 @@ class StreamProcessor:
        self.agent_id = self.data.get("agent_id")
        self.agent_key = None
        self.model_id: Optional[str] = None
+       # BYOM-resolution scope, set by _validate_and_set_model.
+       self.model_user_id: Optional[str] = None
        self.conversation_service = ConversationService()
        self.compression_orchestrator = CompressionOrchestrator(
            self.conversation_service
@@ -191,16 +193,23 @@ class StreamProcessor:
                for query in conversation.get("queries", [])
            ]
        else:
+           # model_user_id keeps history trim aligned with the BYOM's
+           # actual context window instead of the default 128k.
            self.history = limit_chat_history(
-               json.loads(self.data.get("history", "[]")), model_id=self.model_id
+               json.loads(self.data.get("history", "[]")),
+               model_id=self.model_id,
+               user_id=self.model_user_id,
            )

    def _handle_compression(self, conversation: Dict[str, Any]):
        """Handle conversation compression logic using orchestrator."""
        try:
+           # initial_user_id for conversation access; model_user_id
+           # for BYOM context-window / provider lookups.
            result = self.compression_orchestrator.compress_if_needed(
                conversation_id=self.conversation_id,
                user_id=self.initial_user_id,
+               model_user_id=self.model_user_id,
                model_id=self.model_id,
                decoded_token=self.decoded_token,
            )
@@ -284,11 +293,18 @@ class StreamProcessor:
        from application.core.model_settings import ModelRegistry

        requested_model = self.data.get("model_id")
+       # Caller picks from their own BYOM layer; agent defaults resolve
+       # under the owner's layer (shared agents have caller != owner).
+       caller_user_id = self.initial_user_id
+       owner_user_id = self.agent_config.get("user_id") or caller_user_id

        if requested_model:
-           if not validate_model_id(requested_model):
+           if not validate_model_id(requested_model, user_id=caller_user_id):
                registry = ModelRegistry.get_instance()
-               available_models = [m.id for m in registry.get_enabled_models()]
+               available_models = [
+                   m.id
+                   for m in registry.get_enabled_models(user_id=caller_user_id)
+               ]
                raise ValueError(
                    f"Invalid model_id '{requested_model}'. "
                    f"Available models: {', '.join(available_models[:5])}"
@@ -299,12 +315,17 @@ class StreamProcessor:
                )
            )
            self.model_id = requested_model
+           self.model_user_id = caller_user_id
        else:
            agent_default_model = self.agent_config.get("default_model_id", "")
-           if agent_default_model and validate_model_id(agent_default_model):
+           if agent_default_model and validate_model_id(
+               agent_default_model, user_id=owner_user_id
+           ):
                self.model_id = agent_default_model
+               self.model_user_id = owner_user_id
            else:
                self.model_id = get_default_model_id()
+               self.model_user_id = None

    def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
        """Get API key for agent with access control."""
@@ -514,6 +535,10 @@ class StreamProcessor:
                "allow_system_prompt_override": self._agent_data.get(
                    "allow_system_prompt_override", False
                ),
+               # Owner identity — _validate_and_set_model reads this to
+               # resolve owner-stored BYOM default_model_id against the
+               # owner's per-user model layer rather than the caller's.
+               "user_id": self._agent_data.get("user"),
            }
        )
@@ -561,7 +586,13 @@ class StreamProcessor:

    def _configure_retriever(self):
        """Assemble retriever config with precedence: request > agent > default."""
-       doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
+       # BYOM scope: owner for shared-agent BYOM, caller for own BYOM,
+       # None for built-ins. Without ``user_id`` here, the doc budget
+       # falls back to settings.DEFAULT_LLM_TOKEN_LIMIT and overfills
+       # the upstream context window for any small (e.g. 8k/32k) BYOM.
+       doc_token_limit = calculate_doc_token_budget(
+           model_id=self.model_id, user_id=self.model_user_id
+       )

        # Start with defaults
        retriever_name = "classic"
@@ -612,6 +643,7 @@ class StreamProcessor:
            chunks=self.retriever_config["chunks"],
            doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
            model_id=self.model_id,
+           model_user_id=self.model_user_id,
            user_api_key=self.agent_config["user_api_key"],
            agent_id=self.agent_id,
            decoded_token=self.decoded_token,
@@ -903,6 +935,11 @@ class StreamProcessor:
        agent_config = state["agent_config"]

        model_id = agent_config.get("model_id")
+       # BYOM scope captured at initial dispatch. None for built-ins or
+       # caller-owned BYOM where decoded_token['sub'] is already the
+       # right scope; non-None for shared-agent owner BYOM where the
+       # caller's identity differs from the model owner's.
+       model_user_id = agent_config.get("model_user_id")
        llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
        api_key = agent_config.get("api_key")
        user_api_key = agent_config.get("user_api_key")
@@ -920,6 +957,7 @@ class StreamProcessor:
            decoded_token=self.decoded_token,
            model_id=model_id,
            agent_id=agent_id,
+           model_user_id=model_user_id,
        )
        llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
        tool_executor = ToolExecutor(
@@ -949,6 +987,7 @@ class StreamProcessor:
            "endpoint": "stream",
            "llm_name": llm_name,
            "model_id": model_id,
+           "model_user_id": model_user_id,
            "api_key": system_api_key,
            "agent_id": agent_id,
            "user_api_key": user_api_key,
@@ -971,6 +1010,15 @@ class StreamProcessor:

        # Store config for the route layer
        self.model_id = model_id
+       # Mirror ``model_user_id`` back onto the processor so the route
+       # layer (StreamResource) reads the owner scope captured at
+       # initial dispatch. Without this, ``processor.model_user_id``
+       # stays at the __init__ default (None) and complete_stream
+       # falls back to the caller's sub: the post-resume title-LLM
+       # save misses the owner's BYOM layer, and any second tool
+       # pause persists ``model_user_id=None`` — losing owner scope
+       # for every subsequent resume of this conversation.
+       self.model_user_id = model_user_id
        self.agent_id = agent_id
        self.agent_config["user_api_key"] = user_api_key
        self.conversation_id = conversation_id
@@ -1022,8 +1070,11 @@ class StreamProcessor:
            tools_data=tools_data,
        )

+       # Use the user_id that resolved the model so owner-scoped BYOM
+       # records dispatch correctly on shared-agent requests.
+       model_user_id = getattr(self, "model_user_id", self.initial_user_id)
        provider = (
-           get_provider_from_model_id(self.model_id)
+           get_provider_from_model_id(self.model_id, user_id=model_user_id)
            if self.model_id
            else settings.LLM_PROVIDER
        )
@@ -1048,6 +1099,8 @@ class StreamProcessor:
            model_id=self.model_id,
            agent_id=self.agent_id,
            backup_models=backup_models,
+           # Owner-scope on shared-agent BYOM dispatch.
+           model_user_id=model_user_id,
        )
        llm_handler = LLMHandlerCreator.create_handler(
            provider if provider else "default"
@@ -1070,6 +1123,7 @@ class StreamProcessor:
            "endpoint": "stream",
            "llm_name": provider or settings.LLM_PROVIDER,
            "model_id": self.model_id,
+           "model_user_id": self.model_user_id,
            "api_key": system_api_key,
            "agent_id": self.agent_id,
            "user_api_key": self.agent_config["user_api_key"],
@@ -1097,6 +1151,7 @@ class StreamProcessor:
                "doc_token_limit", 50000
            ),
            "model_id": self.model_id,
+           "model_user_id": self.model_user_id,
            "user_api_key": self.agent_config["user_api_key"],
            "agent_id": self.agent_id,
            "llm_name": provider or settings.LLM_PROVIDER,
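Taken together, the _validate_and_set_model hunks above choose between three scopes. A hedged condensation as a pure function (the helper name and shape are hypothetical; the real code does this inline and also validates each candidate with validate_model_id):

    def resolve_model_scope(requested_model, agent_default, caller_id, owner_id):
        if requested_model:
            return requested_model, caller_id   # caller's own BYOM layer
        if agent_default:
            return agent_default, owner_id      # owner's layer (shared agents)
        return "default-model", None            # built-in catalog, no user layer

    assert resolve_model_scope(None, "uuid-123", "caller", "owner") == (
        "uuid-123", "owner"
    )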
@@ -1,18 +1,135 @@
-from flask import current_app, jsonify, make_response
+"""Model routes.
+
+- ``GET /api/models`` — list available models for the current user.
+  Combines the built-in catalog with the user's BYOM records.
+- ``GET/POST/PATCH/DELETE /api/user/models[/<id>]`` — CRUD for the
+  user's own OpenAI-compatible model registrations (BYOM).
+- ``POST /api/user/models/<id>/test`` — sanity-check the upstream
+  endpoint with a tiny request.
+
+Every BYOM endpoint is user-scoped at the repository layer
+(every query filters on ``user_id`` from ``request.decoded_token``).
+"""
+
+from __future__ import annotations
+
+import logging
+
+import requests
+from flask import current_app, jsonify, make_response, request
 from flask_restx import Namespace, Resource

-from application.core.model_settings import ModelRegistry
+from application.api import api
+from application.core.model_registry import ModelRegistry
+from application.security.safe_url import (
+    UnsafeUserUrlError,
+    pinned_post,
+    validate_user_base_url,
+)
+from application.storage.db.repositories.user_custom_models import (
+    UserCustomModelsRepository,
+)
+from application.storage.db.session import db_readonly, db_session
+from application.utils import check_required_fields
+
+
+logger = logging.getLogger(__name__)


 models_ns = Namespace("models", description="Available models", path="/api")


+_CONTEXT_WINDOW_MIN = 1_000
+_CONTEXT_WINDOW_MAX = 10_000_000
+
+
+def _user_id_or_401():
+    decoded_token = request.decoded_token
+    if not decoded_token:
+        return None, make_response(jsonify({"success": False}), 401)
+    user_id = decoded_token.get("sub")
+    if not user_id:
+        return None, make_response(jsonify({"success": False}), 401)
+    return user_id, None
+
+
+def _normalize_capabilities(raw) -> dict:
+    """Coerce + bound the user-supplied capabilities payload."""
+    raw = raw or {}
+    out = {}
+    if "supports_tools" in raw:
+        out["supports_tools"] = bool(raw["supports_tools"])
+    if "supports_structured_output" in raw:
+        out["supports_structured_output"] = bool(raw["supports_structured_output"])
+    if "supports_streaming" in raw:
+        out["supports_streaming"] = bool(raw["supports_streaming"])
+    if "attachments" in raw:
+        atts = raw["attachments"] or []
+        if not isinstance(atts, list):
+            raise ValueError("'capabilities.attachments' must be a list")
+        coerced = [str(a) for a in atts]
+        # Reject unknown aliases at the API boundary so bad payloads
+        # never reach the registry layer (where lenient expansion just
+        # drops them). Raw MIME types (containing ``/``) pass through
+        # unchanged for parity with the built-in YAML schema.
+        from application.core.model_yaml import builtin_attachment_aliases
+
+        aliases = builtin_attachment_aliases()
+        for entry in coerced:
+            if "/" in entry:
+                continue
+            if entry not in aliases:
+                valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
+                raise ValueError(
+                    f"unknown attachment alias '{entry}' in "
+                    f"'capabilities.attachments'. Valid aliases: {valid}, "
+                    f"or use a raw MIME type like 'image/png'."
+                )
+        out["attachments"] = coerced
+    if "context_window" in raw:
+        try:
+            cw = int(raw["context_window"])
+        except (TypeError, ValueError):
+            raise ValueError("'capabilities.context_window' must be an integer")
+        if not (_CONTEXT_WINDOW_MIN <= cw <= _CONTEXT_WINDOW_MAX):
+            raise ValueError(
+                f"'capabilities.context_window' must be between "
+                f"{_CONTEXT_WINDOW_MIN} and {_CONTEXT_WINDOW_MAX}"
+            )
+        out["context_window"] = cw
+    return out
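For orientation, a payload the validator above would accept (illustrative values):

    caps = _normalize_capabilities({
        "supports_tools": True,
        "supports_streaming": 1,           # coerced to bool
        "attachments": ["image/png"],      # raw MIME types pass the alias check
        "context_window": "8000",          # coerced to int and bounds-checked
    })
    assert caps == {
        "supports_tools": True,
        "supports_streaming": True,
        "attachments": ["image/png"],
        "context_window": 8000,
    }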
+def _row_to_response(row: dict) -> dict:
+    """Wire-format projection — never includes the API key."""
+    return {
+        "id": str(row["id"]),
+        "upstream_model_id": row["upstream_model_id"],
+        "display_name": row["display_name"],
+        "description": row.get("description") or "",
+        "base_url": row["base_url"],
+        "capabilities": row.get("capabilities") or {},
+        "enabled": bool(row.get("enabled", True)),
+        "source": "user",
+    }


 @models_ns.route("/models")
 class ModelsListResource(Resource):
     def get(self):
-        """Get list of available models with their capabilities."""
+        """Get list of available models with their capabilities.
+
+        When the request is authenticated, the response includes the
+        user's own BYOM registrations alongside the built-in catalog.
+        """
         try:
+            user_id = None
+            decoded_token = getattr(request, "decoded_token", None)
+            if decoded_token:
+                user_id = decoded_token.get("sub")
+
             registry = ModelRegistry.get_instance()
-            models = registry.get_enabled_models()
+            models = registry.get_enabled_models(user_id=user_id)

             response = {
                 "models": [model.to_dict() for model in models],
@@ -23,3 +140,382 @@ class ModelsListResource(Resource):
            current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 500)
        return make_response(jsonify(response), 200)


@models_ns.route("/user/models")
class UserModelsCollectionResource(Resource):
    @api.doc(description="List the current user's BYOM custom models")
    def get(self):
        user_id, err = _user_id_or_401()
        if err:
            return err
        try:
            with db_readonly() as conn:
                rows = UserCustomModelsRepository(conn).list_for_user(user_id)
            return make_response(
                jsonify({"models": [_row_to_response(r) for r in rows]}), 200
            )
        except Exception as e:
            current_app.logger.error(
                f"Error listing user custom models: {e}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 500)

    @api.doc(description="Register a new BYOM custom model")
    def post(self):
        user_id, err = _user_id_or_401()
        if err:
            return err

        data = request.get_json() or {}
        missing = check_required_fields(
            data,
            ["upstream_model_id", "display_name", "base_url", "api_key"],
        )
        if missing:
            return missing

        # SECURITY: reject blank api_key — would leak instance API key
        # to the user-supplied base_url via LLMCreator fallback.
        for required_nonblank in (
            "upstream_model_id",
            "display_name",
            "base_url",
            "api_key",
        ):
            value = data.get(required_nonblank)
            if not isinstance(value, str) or not value.strip():
                return make_response(
                    jsonify(
                        {
                            "success": False,
                            "error": f"'{required_nonblank}' must be a non-empty string",
                        }
                    ),
                    400,
                )

        # SSRF guard at create time. Re-runs at dispatch time (LLMCreator)
        # as defense in depth against DNS rebinding and pre-guard rows.
        try:
            validate_user_base_url(data["base_url"])
        except UnsafeUserUrlError as e:
            return make_response(
                jsonify({"success": False, "error": str(e)}), 400
            )

        try:
            capabilities = _normalize_capabilities(data.get("capabilities"))
        except ValueError as e:
            return make_response(
                jsonify({"success": False, "error": str(e)}), 400
            )

        try:
            with db_session() as conn:
                row = UserCustomModelsRepository(conn).create(
                    user_id=user_id,
                    upstream_model_id=data["upstream_model_id"],
                    display_name=data["display_name"],
                    description=data.get("description") or "",
                    base_url=data["base_url"],
                    api_key_plaintext=data["api_key"],
                    capabilities=capabilities,
                    enabled=bool(data.get("enabled", True)),
                )
        except Exception as e:
            current_app.logger.error(
                f"Error creating user custom model: {e}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 500)

        ModelRegistry.invalidate_user(user_id)
        return make_response(jsonify(_row_to_response(row)), 201)


@models_ns.route("/user/models/<string:model_id>")
class UserModelResource(Resource):
    @api.doc(description="Get one BYOM custom model")
    def get(self, model_id):
        user_id, err = _user_id_or_401()
        if err:
            return err
        try:
            with db_readonly() as conn:
                row = UserCustomModelsRepository(conn).get(model_id, user_id)
        except Exception as e:
            current_app.logger.error(
                f"Error fetching user custom model: {e}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 500)
        if row is None:
            return make_response(jsonify({"success": False}), 404)
        return make_response(jsonify(_row_to_response(row)), 200)

    @api.doc(description="Update a BYOM custom model (partial)")
    def patch(self, model_id):
        user_id, err = _user_id_or_401()
        if err:
            return err

        data = request.get_json() or {}

        # Reject present-but-blank values for fields where blank doesn't
        # mean "no change". (The api_key special case — blank means "keep
        # existing" — is handled below.)
        for required_nonblank in (
            "upstream_model_id",
            "display_name",
            "base_url",
        ):
            if required_nonblank in data:
                value = data[required_nonblank]
                if not isinstance(value, str) or not value.strip():
                    return make_response(
                        jsonify(
                            {
                                "success": False,
                                "error": f"'{required_nonblank}' cannot be blank",
                            }
                        ),
                        400,
                    )

        if "base_url" in data and data["base_url"]:
            try:
                validate_user_base_url(data["base_url"])
            except UnsafeUserUrlError as e:
                return make_response(
                    jsonify({"success": False, "error": str(e)}), 400
                )

        update_fields: dict = {}
        for k in (
            "upstream_model_id",
            "display_name",
            "description",
            "base_url",
            "enabled",
        ):
            if k in data:
                update_fields[k] = data[k]

        if "capabilities" in data:
            try:
                update_fields["capabilities"] = _normalize_capabilities(
                    data["capabilities"]
                )
            except ValueError as e:
                return make_response(
                    jsonify({"success": False, "error": str(e)}), 400
                )

        # PATCH semantics: blank/missing api_key → keep the existing
        # ciphertext; non-empty api_key → re-encrypt and replace.
        if data.get("api_key"):
            update_fields["api_key_plaintext"] = data["api_key"]

        if not update_fields:
            return make_response(
                jsonify({"success": False, "error": "no updatable fields"}), 400
            )

        try:
            with db_session() as conn:
                ok = UserCustomModelsRepository(conn).update(
                    model_id, user_id, update_fields
                )
        except Exception as e:
            current_app.logger.error(
                f"Error updating user custom model: {e}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 500)

        if not ok:
            return make_response(jsonify({"success": False}), 404)

        ModelRegistry.invalidate_user(user_id)
        with db_readonly() as conn:
            row = UserCustomModelsRepository(conn).get(model_id, user_id)
        return make_response(jsonify(_row_to_response(row)), 200)

    @api.doc(description="Delete a BYOM custom model")
    def delete(self, model_id):
        user_id, err = _user_id_or_401()
        if err:
            return err
        try:
            with db_session() as conn:
                ok = UserCustomModelsRepository(conn).delete(model_id, user_id)
        except Exception as e:
            current_app.logger.error(
                f"Error deleting user custom model: {e}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 500)
        if not ok:
            return make_response(jsonify({"success": False}), 404)

        ModelRegistry.invalidate_user(user_id)
        return make_response(jsonify({"success": True}), 200)


def _run_connection_test(
    base_url: str, api_key: str, upstream_model_id: str
):
    """Send a 1-token chat-completion to verify a BYOM endpoint.

    Returns ``(body, http_status)``. Upstream errors return 200 with
    ``ok=False`` so the UI can render inline errors; only local SSRF
    rejection returns 400.
    """
    url = base_url.rstrip("/") + "/chat/completions"
    payload = {
        "model": upstream_model_id,
        "messages": [{"role": "user", "content": "hi"}],
        "max_tokens": 1,
        "stream": False,
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    try:
        # pinned_post closes the DNS-rebinding window. Redirects off
        # because 3xx could bounce to an internal address (the SSRF
        # guard only validates the supplied URL).
        resp = pinned_post(
            url,
            json=payload,
            headers=headers,
            timeout=5,
            allow_redirects=False,
        )
    except UnsafeUserUrlError as e:
        return {"ok": False, "error": str(e)}, 400
    except requests.RequestException as e:
        return {"ok": False, "error": f"connection error: {e}"}, 200

    if 300 <= resp.status_code < 400:
        return (
            {
                "ok": False,
                "error": (
                    f"upstream returned HTTP {resp.status_code} "
                    "redirect; refusing to follow"
                ),
            },
            200,
        )

    if resp.status_code >= 400:
        # Cap and only reflect JSON to avoid body-exfil via non-API responses.
        content_type = (resp.headers.get("Content-Type") or "").lower()
        if "application/json" in content_type:
            text = (resp.text or "")[:500]
            error_msg = f"upstream returned HTTP {resp.status_code}: {text}"
        else:
            error_msg = f"upstream returned HTTP {resp.status_code}"
        return {"ok": False, "error": error_msg}, 200

    return {"ok": True}, 200
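An illustrative call (placeholder endpoint and key, not real values):

    body, status = _run_connection_test(
        base_url="https://llm.example.com/v1",   # placeholder
        api_key="sk-placeholder",
        upstream_model_id="my-model",
    )
    # status is 200 for both success and upstream failure; the UI keys off
    # body["ok"]. Only a locally rejected (unsafe) URL returns 400.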
@models_ns.route("/user/models/test")
class UserModelTestPayloadResource(Resource):
    @api.doc(
        description=(
            "Test an arbitrary BYOM payload (display_name / model id / "
            "base_url / api_key) without saving. Used by the UI's 'Test "
            "connection' button so the user can validate before they "
            "Save. Same SSRF guard, same 1-token request, same 5s "
            "timeout as the by-id variant."
        )
    )
    def post(self):
        user_id, err = _user_id_or_401()
        if err:
            return err

        data = request.get_json() or {}
        missing = check_required_fields(
            data, ["base_url", "api_key", "upstream_model_id"]
        )
        if missing:
            return missing

        body, status = _run_connection_test(
            data["base_url"], data["api_key"], data["upstream_model_id"]
        )
        return make_response(jsonify(body), status)


@models_ns.route("/user/models/<string:model_id>/test")
class UserModelTestResource(Resource):
    @api.doc(
        description=(
            "Test a saved BYOM record. Defaults to the stored "
            "base_url / upstream_model_id / encrypted api_key, but "
            "any of those can be overridden via the request body so "
            "the UI can test in-flight edits before saving. Used by "
            "the 'Test connection' button in edit mode."
        )
    )
    def post(self, model_id):
        user_id, err = _user_id_or_401()
        if err:
            return err

        data = request.get_json() or {}
        # Per-field overrides; blank/missing falls back to stored value.
        override_base_url = (data.get("base_url") or "").strip() or None
        override_upstream_model_id = (
            data.get("upstream_model_id") or ""
        ).strip() or None
        override_api_key = (data.get("api_key") or "").strip() or None

        try:
            with db_readonly() as conn:
                repo = UserCustomModelsRepository(conn)
                row = repo.get(model_id, user_id)
                if row is None:
                    return make_response(jsonify({"success": False}), 404)
                stored_api_key = (
                    repo._decrypt_api_key(
                        row.get("api_key_encrypted", ""), user_id
                    )
                    if not override_api_key
                    else None
                )
        except Exception as e:
            current_app.logger.error(
                f"Error loading user custom model for test: {e}", exc_info=True
            )
            return make_response(
                jsonify({"ok": False, "error": "internal error loading model"}),
                500,
            )

        api_key = override_api_key or stored_api_key
        if not api_key:
            return make_response(
                jsonify(
                    {
                        "ok": False,
                        "error": (
                            "Stored API key could not be decrypted. The "
                            "encryption secret may have rotated. Re-save "
                            "the model with the API key to recover."
                        ),
                    }
                ),
                400,
            )

        base_url = override_base_url or row["base_url"]
        upstream_model_id = (
            override_upstream_model_id or row["upstream_model_id"]
        )
        body, status = _run_connection_test(
            base_url, api_key, upstream_model_id
        )
        return make_response(jsonify(body), status)
@@ -198,8 +198,14 @@ def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
    return normalized_nodes


-def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
-    """Validate workflow graph structure."""
+def validate_workflow_structure(
+    nodes: List[Dict], edges: List[Dict], user_id: str | None = None
+) -> List[str]:
+    """Validate workflow graph structure.
+
+    ``user_id`` is required so per-user BYOM custom-model UUIDs resolve
+    when checking each agent node's structured-output capability.
+    """
    errors = []

    if not nodes:
@@ -343,7 +349,7 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st

        model_id = raw_config.get("model_id")
        if has_json_schema and isinstance(model_id, str) and model_id.strip():
-           capabilities = get_model_capabilities(model_id.strip())
+           capabilities = get_model_capabilities(model_id.strip(), user_id=user_id)
            if capabilities and not capabilities.get("supports_structured_output", False):
                errors.append(
                    f"Agent node '{agent_title}' selected model does not support structured output"
@@ -389,7 +395,9 @@ class WorkflowList(Resource):
        nodes_data = data.get("nodes", [])
        edges_data = data.get("edges", [])

-       validation_errors = validate_workflow_structure(nodes_data, edges_data)
+       validation_errors = validate_workflow_structure(
+           nodes_data, edges_data, user_id=user_id
+       )
        if validation_errors:
            return error_response(
                "Workflow validation failed", errors=validation_errors
@@ -451,7 +459,9 @@ class WorkflowDetail(Resource):
        nodes_data = data.get("nodes", [])
        edges_data = data.get("edges", [])

-       validation_errors = validate_workflow_structure(nodes_data, edges_data)
+       validation_errors = validate_workflow_structure(
+           nodes_data, edges_data, user_id=user_id
+       )
        if validation_errors:
            return error_response(
                "Workflow validation failed", errors=validation_errors
@@ -213,6 +213,7 @@ def _stream_response(
        decoded_token=processor.decoded_token,
        agent_id=processor.agent_id,
        model_id=processor.model_id,
+       model_user_id=processor.model_user_id,
        should_save_conversation=should_save_conversation,
        _continuation=continuation,
    )
@@ -257,6 +258,7 @@ def _non_stream_response(
        decoded_token=processor.decoded_token,
        agent_id=processor.agent_id,
        model_id=processor.model_id,
+       model_user_id=processor.model_user_id,
        should_save_conversation=should_save_conversation,
        _continuation=continuation,
    )
@@ -153,7 +153,7 @@ def after_request(response: Response) -> Response:
    """Add CORS headers for the pure Flask development entrypoint."""
    response.headers["Access-Control-Allow-Origin"] = "*"
    response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
-   response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
+   response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
    return response
@@ -24,7 +24,7 @@ asgi_app = Starlette(
    Middleware(
        CORSMiddleware,
        allow_origins=["*"],
-       allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
+       allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
        allow_headers=["Content-Type", "Authorization", "Mcp-Session-Id"],
        expose_headers=["Mcp-Session-Id"],
    ),
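A quick hedged check that the widened method list reaches the browser (the URL, port, and origin are assumptions for a local dev setup, not documented values):

    import requests

    resp = requests.options(
        "http://localhost:7091/api/user/models/some-id",   # assumed dev URL
        headers={
            "Origin": "http://localhost:5173",             # assumed frontend origin
            "Access-Control-Request-Method": "PATCH",
        },
    )
    assert "PATCH" in resp.headers.get("Access-Control-Allow-Methods", "")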
application/core/model_configs.py (deleted, 266 lines)
@@ -1,266 +0,0 @@
"""
Model configurations for all supported LLM providers.
"""

from application.core.model_settings import (
    AvailableModel,
    ModelCapabilities,
    ModelProvider,
)

# Base image attachment types supported by most vision-capable LLMs
IMAGE_ATTACHMENTS = [
    "image/png",
    "image/jpeg",
    "image/jpg",
    "image/webp",
    "image/gif",
]

# PDF excluded: most OpenAI-compatible endpoints don't support native PDF uploads.
# When excluded, PDFs are synthetically processed by converting pages to images.
OPENAI_ATTACHMENTS = IMAGE_ATTACHMENTS

GOOGLE_ATTACHMENTS = ["application/pdf"] + IMAGE_ATTACHMENTS

ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS

OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS

NOVITA_ATTACHMENTS = IMAGE_ATTACHMENTS


OPENAI_MODELS = [
    AvailableModel(
        id="gpt-5.1",
        provider=ModelProvider.OPENAI,
        display_name="GPT-5.1",
        description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=OPENAI_ATTACHMENTS,
            context_window=200000,
        ),
    ),
    AvailableModel(
        id="gpt-5-mini",
        provider=ModelProvider.OPENAI,
        display_name="GPT-5 Mini",
        description="Faster, cost-effective variant of GPT-5.1",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=OPENAI_ATTACHMENTS,
            context_window=200000,
        ),
    )
]


ANTHROPIC_MODELS = [
    AvailableModel(
        id="claude-3-5-sonnet-20241022",
        provider=ModelProvider.ANTHROPIC,
        display_name="Claude 3.5 Sonnet (Latest)",
        description="Latest Claude 3.5 Sonnet with enhanced capabilities",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supported_attachment_types=ANTHROPIC_ATTACHMENTS,
            context_window=200000,
        ),
    ),
    AvailableModel(
        id="claude-3-5-sonnet",
        provider=ModelProvider.ANTHROPIC,
        display_name="Claude 3.5 Sonnet",
        description="Balanced performance and capability",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supported_attachment_types=ANTHROPIC_ATTACHMENTS,
            context_window=200000,
        ),
    ),
    AvailableModel(
        id="claude-3-opus",
        provider=ModelProvider.ANTHROPIC,
        display_name="Claude 3 Opus",
        description="Most capable Claude model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supported_attachment_types=ANTHROPIC_ATTACHMENTS,
            context_window=200000,
        ),
    ),
    AvailableModel(
        id="claude-3-haiku",
        provider=ModelProvider.ANTHROPIC,
        display_name="Claude 3 Haiku",
        description="Fastest Claude model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supported_attachment_types=ANTHROPIC_ATTACHMENTS,
            context_window=200000,
        ),
    ),
]


GOOGLE_MODELS = [
    AvailableModel(
        id="gemini-flash-latest",
        provider=ModelProvider.GOOGLE,
        display_name="Gemini Flash (Latest)",
        description="Latest experimental Gemini model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=GOOGLE_ATTACHMENTS,
            context_window=int(1e6),
        ),
    ),
    AvailableModel(
        id="gemini-flash-lite-latest",
        provider=ModelProvider.GOOGLE,
        display_name="Gemini Flash Lite (Latest)",
        description="Fast with huge context window",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=GOOGLE_ATTACHMENTS,
            context_window=int(1e6),
        ),
    ),
    AvailableModel(
        id="gemini-3-pro-preview",
        provider=ModelProvider.GOOGLE,
        display_name="Gemini 3 Pro",
        description="Most capable Gemini model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=GOOGLE_ATTACHMENTS,
            context_window=2000000,
        ),
    ),
]


GROQ_MODELS = [
    AvailableModel(
        id="llama-3.3-70b-versatile",
        provider=ModelProvider.GROQ,
        display_name="Llama 3.3 70B",
        description="Latest Llama model with high-speed inference",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=128000,
        ),
    ),
    AvailableModel(
        id="openai/gpt-oss-120b",
        provider=ModelProvider.GROQ,
        display_name="GPT-OSS 120B",
        description="Open-source GPT model optimized for speed",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=128000,
        ),
    ),
]


OPENROUTER_MODELS = [
    AvailableModel(
        id="qwen/qwen3-coder:free",
        provider=ModelProvider.OPENROUTER,
        display_name="Qwen 3 Coder",
        description="Latest Qwen model with high-speed inference",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=128000,
            supported_attachment_types=OPENROUTER_ATTACHMENTS,
        ),
    ),
    AvailableModel(
        id="google/gemma-3-27b-it:free",
        provider=ModelProvider.OPENROUTER,
        display_name="Gemma 3 27B",
        description="Latest Gemma model with high-speed inference",
        capabilities=ModelCapabilities(
            supports_tools=True,
            context_window=128000,
            supported_attachment_types=OPENROUTER_ATTACHMENTS,
        ),
    ),
]

NOVITA_MODELS = [
    AvailableModel(
        id="moonshotai/kimi-k2.5",
        provider=ModelProvider.NOVITA,
        display_name="Kimi K2.5",
        description="MoE model with function calling, structured output, reasoning, and vision",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=NOVITA_ATTACHMENTS,
            context_window=262144,
        ),
    ),
    AvailableModel(
        id="zai-org/glm-5",
        provider=ModelProvider.NOVITA,
        display_name="GLM-5",
        description="MoE model with function calling, structured output, and reasoning",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=[],
            context_window=202800,
        ),
    ),
    AvailableModel(
        id="minimax/minimax-m2.5",
        provider=ModelProvider.NOVITA,
        display_name="MiniMax M2.5",
        description="MoE model with function calling, structured output, and reasoning",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=[],
            context_window=204800,
        ),
    ),
]


AZURE_OPENAI_MODELS = [
    AvailableModel(
        id="azure-gpt-4",
        provider=ModelProvider.AZURE_OPENAI,
        display_name="Azure OpenAI GPT-4",
        description="Azure-hosted GPT model",
        capabilities=ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=OPENAI_ATTACHMENTS,
            context_window=8192,
        ),
    ),
]


def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
    """Create a custom OpenAI-compatible model (e.g., LM Studio, Ollama)."""
    return AvailableModel(
        id=model_name,
        provider=ModelProvider.OPENAI,
        display_name=model_name,
        description=f"Custom OpenAI-compatible model at {base_url}",
        base_url=base_url,
        capabilities=ModelCapabilities(
            supports_tools=True,
            supported_attachment_types=OPENAI_ATTACHMENTS,
        ),
    )

application/core/model_registry.py (new file, 385 lines)
@@ -0,0 +1,385 @@
"""Layered model registry.
|
||||
|
||||
Loads model catalogs from YAML files (built-in + operator-supplied),
|
||||
groups them by provider name, then for each registered provider plugin
|
||||
calls ``get_models`` to produce the final per-provider model list.
|
||||
|
||||
End-user BYOM (per-user model records in Postgres) is layered on top:
|
||||
when a lookup arrives with a ``user_id``, the registry consults a
|
||||
per-user cache first (loaded from the ``user_custom_models`` table on
|
||||
miss) and falls through to the built-in catalog.
|
||||
|
||||
Cross-process invalidation: ``ModelRegistry`` is a per-process
|
||||
singleton, so a CRUD write only evicts the cache in the process that
|
||||
served it. Other gunicorn workers and Celery workers would otherwise
|
||||
keep using a deleted/disabled/key-rotated BYOM record indefinitely.
|
||||
``invalidate_user`` therefore both drops the local layer *and* bumps a
|
||||
Redis-side version counter; other processes notice the bump on their
|
||||
next access (after the local TTL window) and reload from Postgres. If
|
||||
Redis is unreachable the per-process TTL still bounds staleness — pure
|
||||
TTL semantics, no regression.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from application.core.model_settings import AvailableModel
|
||||
from application.core.model_yaml import (
|
||||
BUILTIN_MODELS_DIR,
|
||||
ProviderCatalog,
|
||||
load_model_yamls,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_USER_CACHE_TTL_SECONDS = 60.0
|
||||
_USER_VERSION_KEY_PREFIX = "byom:registry_version:"
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
"""Singleton registry of available models."""
|
||||
|
||||
_instance: Optional["ModelRegistry"] = None
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not ModelRegistry._initialized:
|
||||
self.models: Dict[str, AvailableModel] = {}
|
||||
self.default_model_id: Optional[str] = None
|
||||
# Per-user BYOM cache. Each entry is
|
||||
# ``(layer, version_at_load, loaded_at_monotonic)``:
|
||||
# * ``layer`` — {model_id: AvailableModel}
|
||||
# * ``version_at_load`` — Redis-side counter snapshot at
|
||||
# reload time, or ``None`` if Redis was unreachable
|
||||
# * ``loaded_at_monotonic`` — for TTL bookkeeping
|
||||
# Populated lazily, evicted by TTL + cross-process
|
||||
# invalidation (see ``invalidate_user``).
|
||||
self._user_models: Dict[
|
||||
str,
|
||||
Tuple[Dict[str, AvailableModel], Optional[int], float],
|
||||
] = {}
|
||||
self._load_models()
|
||||
ModelRegistry._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "ModelRegistry":
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def reset(cls) -> None:
|
||||
"""Clear the singleton. Intended for test fixtures."""
|
||||
cls._instance = None
|
||||
cls._initialized = False
|
||||
|
||||
@classmethod
|
||||
def invalidate_user(cls, user_id: str) -> None:
|
||||
"""Drop the cached per-user model layer for ``user_id``.
|
||||
|
||||
Called by the BYOM REST routes after every create/update/delete.
|
||||
Two effects:
|
||||
|
||||
* Local: pop the entry from this process's cache so the next
|
||||
lookup re-reads from Postgres immediately.
|
||||
* Cross-process: ``INCR`` a Redis-side version counter for this
|
||||
user. Other gunicorn/Celery processes notice the counter
|
||||
changed on their next TTL-driven recheck (see
|
||||
``_user_models_for``) and reload. If Redis is unreachable we
|
||||
log and continue — local invalidation still happened, and
|
||||
peers fall back to TTL-only staleness bounds.
|
||||
"""
|
||||
if cls._instance is not None:
|
||||
cls._instance._user_models.pop(user_id, None)
|
||||
try:
|
||||
from application.cache import get_redis_instance
|
||||
|
||||
client = get_redis_instance()
|
||||
if client is not None:
|
||||
client.incr(_USER_VERSION_KEY_PREFIX + user_id)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"BYOM invalidate: failed to publish version bump for "
|
||||
"user %s (Redis unreachable?): %s",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _read_user_version(cls, user_id: str) -> Optional[int]:
|
||||
"""Return the Redis-side invalidation counter for ``user_id``.
|
||||
|
||||
``0`` if the key has never been bumped; ``None`` if Redis is
|
||||
unreachable or the read failed (callers fall back to TTL-only
|
||||
staleness in that case).
|
||||
"""
|
||||
try:
|
||||
from application.cache import get_redis_instance
|
||||
|
||||
client = get_redis_instance()
|
||||
if client is None:
|
||||
return None
|
||||
raw = client.get(_USER_VERSION_KEY_PREFIX + user_id)
|
||||
if raw is None:
|
||||
return 0
|
||||
return int(raw)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _load_models(self) -> None:
|
||||
from pathlib import Path
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.providers import ALL_PROVIDERS
|
||||
|
||||
directories = [BUILTIN_MODELS_DIR]
|
||||
operator_dir = getattr(settings, "MODELS_CONFIG_DIR", None)
|
||||
if operator_dir:
|
||||
op_path = Path(operator_dir)
|
||||
if not op_path.exists():
|
||||
logger.warning(
|
||||
"MODELS_CONFIG_DIR=%s does not exist; no operator "
|
||||
"model YAMLs will be loaded.",
|
||||
operator_dir,
|
||||
)
|
||||
elif not op_path.is_dir():
|
||||
logger.warning(
|
||||
"MODELS_CONFIG_DIR=%s is not a directory; no operator "
|
||||
"model YAMLs will be loaded.",
|
||||
operator_dir,
|
||||
)
|
||||
else:
|
||||
directories.append(op_path)
|
||||
|
||||
catalogs = load_model_yamls(directories)
|
||||
|
||||
# Validate every catalog targets a known plugin before doing any
|
||||
# registry work, so an unknown provider name in YAML aborts boot
|
||||
# with a clear error.
|
||||
plugin_names = {p.name for p in ALL_PROVIDERS}
|
||||
for c in catalogs:
|
||||
if c.provider not in plugin_names:
|
||||
raise ValueError(
|
||||
f"{c.source_path}: YAML declares unknown provider "
|
||||
f"{c.provider!r}; no Provider plugin is registered "
|
||||
f"under that name. Known: {sorted(plugin_names)}"
|
||||
)
|
||||
|
||||
catalogs_by_provider: Dict[str, List[ProviderCatalog]] = defaultdict(list)
|
||||
for c in catalogs:
|
||||
catalogs_by_provider[c.provider].append(c)
|
||||
|
||||
self.models.clear()
|
||||
for provider in ALL_PROVIDERS:
|
||||
if not provider.is_enabled(settings):
|
||||
continue
|
||||
for model in provider.get_models(
|
||||
settings, catalogs_by_provider.get(provider.name, [])
|
||||
):
|
||||
self.models[model.id] = model
|
||||
|
||||
self.default_model_id = self._resolve_default(settings)
|
||||
|
||||
logger.info(
|
||||
"ModelRegistry loaded %d models, default: %s",
|
||||
len(self.models),
|
||||
self.default_model_id,
|
||||
)
|
||||
|
||||
def _resolve_default(self, settings) -> Optional[str]:
|
||||
if settings.LLM_NAME:
|
||||
for name in self._parse_model_names(settings.LLM_NAME):
|
||||
if name in self.models:
|
||||
return name
|
||||
if settings.LLM_NAME in self.models:
|
||||
return settings.LLM_NAME
|
||||
|
||||
if settings.LLM_PROVIDER and settings.API_KEY:
|
||||
for model_id, model in self.models.items():
|
||||
if model.provider.value == settings.LLM_PROVIDER:
|
||||
return model_id
|
||||
|
||||
if self.models:
|
||||
return next(iter(self.models.keys()))
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _parse_model_names(llm_name: str) -> List[str]:
|
||||
if not llm_name:
|
||||
return []
|
||||
return [name.strip() for name in llm_name.split(",") if name.strip()]
|
||||
|
||||
# Per-user (BYOM) layer
|
||||
|
||||
def _user_models_for(self, user_id: str) -> Dict[str, AvailableModel]:
|
||||
"""Return the user's BYOM models keyed by registry id (UUID).
|
||||
|
||||
Loaded lazily from Postgres on first access; cached subject to
|
||||
a per-process TTL (``_USER_CACHE_TTL_SECONDS``) and a Redis-
|
||||
backed version counter for cross-process invalidation. The TTL
|
||||
bounds staleness even when Redis is unreachable, while the
|
||||
version stamp lets peers refresh without a DB read on the
|
||||
common case (no invalidation since last load). Decryption
|
||||
failures and DB errors yield an empty layer (logged) — the
|
||||
user simply doesn't see their custom models on this request,
|
||||
never a 500.
|
||||
"""
|
||||
cached = self._user_models.get(user_id)
|
||||
now = time.monotonic()
|
||||
|
||||
if cached is not None:
|
||||
layer, cached_version, loaded_at = cached
|
||||
if (now - loaded_at) < _USER_CACHE_TTL_SECONDS:
|
||||
return layer
|
||||
# TTL elapsed: peek at the cross-process counter. If it
|
||||
# matches what we saw at load time, no invalidation has
|
||||
# happened — extend the TTL without touching Postgres. If
|
||||
# Redis is unreachable (``current_version is None``) we
|
||||
# fall through to a real reload, which keeps staleness
|
||||
# bounded to the TTL.
|
||||
current_version = self._read_user_version(user_id)
|
||||
if (
|
||||
current_version is not None
|
||||
and cached_version is not None
|
||||
and current_version == cached_version
|
||||
):
|
||||
self._user_models[user_id] = (layer, cached_version, now)
|
||||
return layer
|
||||
|
||||
# Capture the counter *before* the DB read so a CRUD that lands
|
||||
# mid-reload doesn't get masked: the next access will see a
|
||||
# newer version and reload again.
|
||||
version_before_read = self._read_user_version(user_id)
|
||||
|
||||
layer: Dict[str, AvailableModel] = {}
|
||||
try:
|
||||
from application.core.model_settings import (
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
)
|
||||
from application.storage.db.repositories.user_custom_models import (
|
||||
UserCustomModelsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_readonly
|
||||
|
||||
with db_readonly() as conn:
|
||||
repo = UserCustomModelsRepository(conn)
|
||||
rows = repo.list_for_user(user_id)
|
||||
for row in rows:
|
||||
api_key = repo._decrypt_api_key(
|
||||
row.get("api_key_encrypted", ""), user_id
|
||||
)
|
||||
if not api_key:
|
||||
# SECURITY: do NOT register an unroutable BYOM
|
||||
# record. If we did, LLMCreator would fall back
|
||||
# to the caller-passed api_key (settings.API_KEY
|
||||
# for openai_compatible) and POST it to the
|
||||
# user-supplied base_url — leaking the instance
|
||||
# credential to the user's chosen endpoint.
|
||||
# Most likely cause is ENCRYPTION_SECRET_KEY
|
||||
# having rotated; user must re-save the model.
|
||||
logger.warning(
|
||||
"user_custom_models: skipping model %s for "
|
||||
"user %s — api_key could not be decrypted "
|
||||
"(rotated ENCRYPTION_SECRET_KEY?). Re-save "
|
||||
"the model to recover.",
|
||||
row.get("id"),
|
||||
user_id,
|
||||
)
|
||||
continue
|
||||
caps_raw = row.get("capabilities") or {}
|
||||
# Stored attachments may be aliases (``image``) or
|
||||
# raw MIME types. Built-in YAML models expand at
|
||||
# load time; mirror that here so downstream MIME-
|
||||
# type comparisons (handlers/base.prepare_messages)
|
||||
# match concrete types like ``image/png`` rather
|
||||
# than the bare alias.
|
||||
from application.core.model_yaml import (
|
||||
expand_attachments_lenient,
|
||||
)
|
||||
|
||||
raw_attachments = caps_raw.get("attachments", []) or []
|
||||
expanded_attachments = expand_attachments_lenient(
|
||||
raw_attachments,
|
||||
f"user_custom_models[user={user_id}, model={row.get('id')}]",
|
||||
)
|
||||
caps = ModelCapabilities(
|
||||
supports_tools=bool(caps_raw.get("supports_tools", False)),
|
||||
supports_structured_output=bool(
|
||||
caps_raw.get("supports_structured_output", False)
|
||||
),
|
||||
supports_streaming=bool(
|
||||
caps_raw.get("supports_streaming", True)
|
||||
),
|
||||
supported_attachment_types=expanded_attachments,
|
||||
context_window=int(
|
||||
caps_raw.get("context_window") or 128000
|
||||
),
|
||||
)
|
||||
model_id = str(row["id"])
|
||||
layer[model_id] = AvailableModel(
|
||||
id=model_id,
|
||||
provider=ModelProvider.OPENAI_COMPATIBLE,
|
||||
display_name=row["display_name"],
|
||||
description=row.get("description") or "",
|
||||
capabilities=caps,
|
||||
enabled=bool(row.get("enabled", True)),
|
||||
base_url=row["base_url"],
|
||||
upstream_model_id=row["upstream_model_id"],
|
||||
source="user",
|
||||
api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"user_custom_models: failed to load layer for user %s: %s",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
layer = {}
|
||||
|
||||
self._user_models[user_id] = (layer, version_before_read, now)
|
||||
return layer
|
||||
|
||||
# Lookup API. ``user_id`` enables the BYOM per-user layer; without
|
||||
# it, callers see only the built-in + operator catalog.
|
||||
|
||||
def get_model(
|
||||
self, model_id: str, user_id: Optional[str] = None
|
||||
) -> Optional[AvailableModel]:
|
||||
if user_id:
|
||||
user_layer = self._user_models_for(user_id)
|
||||
if model_id in user_layer:
|
||||
return user_layer[model_id]
|
||||
return self.models.get(model_id)
|
||||
|
||||
def get_all_models(
|
||||
self, user_id: Optional[str] = None
|
||||
) -> List[AvailableModel]:
|
||||
out = list(self.models.values())
|
||||
if user_id:
|
||||
out.extend(self._user_models_for(user_id).values())
|
||||
return out
|
||||
|
||||
def get_enabled_models(
|
||||
self, user_id: Optional[str] = None
|
||||
) -> List[AvailableModel]:
|
||||
out = [m for m in self.models.values() if m.enabled]
|
||||
if user_id:
|
||||
out.extend(
|
||||
m for m in self._user_models_for(user_id).values() if m.enabled
|
||||
)
|
||||
return out
|
||||
|
||||
def model_exists(
|
||||
self, model_id: str, user_id: Optional[str] = None
|
||||
) -> bool:
|
||||
if user_id and model_id in self._user_models_for(user_id):
|
||||
return True
|
||||
return model_id in self.models
|
||||
@@ -5,9 +5,16 @@ from typing import Dict, List, Optional

logger = logging.getLogger(__name__)

# Re-exported here so existing call sites (and tests) that do
# ``from application.core.model_settings import ModelRegistry`` keep
# working. The implementation lives in ``application/core/model_registry.py``.
# Imported lazily inside ``__getattr__`` to avoid an import cycle with
# ``model_yaml`` → ``model_settings`` (this file).


class ModelProvider(str, Enum):
    OPENAI = "openai"
    OPENAI_COMPATIBLE = "openai_compatible"
    OPENROUTER = "openrouter"
    AZURE_OPENAI = "azure_openai"
    ANTHROPIC = "anthropic"
@@ -41,11 +48,21 @@ class AvailableModel:
|
||||
capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
|
||||
enabled: bool = True
|
||||
base_url: Optional[str] = None
|
||||
# User-facing label distinct from dispatch provider (e.g. mistral
|
||||
# routed through openai_compatible).
|
||||
display_provider: Optional[str] = None
|
||||
# Sent in the API call's ``model`` field; falls back to ``self.id``
|
||||
# for built-ins where id IS the upstream name.
|
||||
upstream_model_id: Optional[str] = None
|
||||
# "builtin" for catalog YAMLs, "user" for BYOM records.
|
||||
source: str = "builtin"
|
||||
# Decrypted/resolved at registry-merge time. Never serialized.
|
||||
api_key: Optional[str] = field(default=None, repr=False, compare=False)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
result = {
|
||||
"id": self.id,
|
||||
"provider": self.provider.value,
|
||||
"provider": self.display_provider or self.provider.value,
|
||||
"display_name": self.display_name,
|
||||
"description": self.description,
|
||||
"supported_attachment_types": self.capabilities.supported_attachment_types,
|
||||
@@ -54,261 +71,21 @@ class AvailableModel:
            "supports_streaming": self.capabilities.supports_streaming,
            "context_window": self.capabilities.context_window,
            "enabled": self.enabled,
            "source": self.source,
        }
        if self.base_url:
            result["base_url"] = self.base_url
        return result


class ModelRegistry:
    _instance = None
    _initialized = False
def __getattr__(name):
    """Lazy re-export of ``ModelRegistry`` from ``model_registry.py``.

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
    Done lazily to avoid an import cycle: ``model_registry`` imports
    ``model_yaml`` which imports the dataclasses from this file.
    """
    if name == "ModelRegistry":
        from application.core.model_registry import ModelRegistry as _MR

    def __init__(self):
        if not ModelRegistry._initialized:
            self.models: Dict[str, AvailableModel] = {}
            self.default_model_id: Optional[str] = None
            self._load_models()
            ModelRegistry._initialized = True

    @classmethod
    def get_instance(cls) -> "ModelRegistry":
        return cls()

    def _load_models(self):
        from application.core.settings import settings

        self.models.clear()

        # Skip DocsGPT model if using custom OpenAI-compatible endpoint
        if not settings.OPENAI_BASE_URL:
            self._add_docsgpt_models(settings)
        if (
            settings.OPENAI_API_KEY
            or (settings.LLM_PROVIDER == "openai" and settings.API_KEY)
            or settings.OPENAI_BASE_URL
        ):
            self._add_openai_models(settings)
        if settings.OPENAI_API_BASE or (
            settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
        ):
            self._add_azure_openai_models(settings)
        if settings.ANTHROPIC_API_KEY or (
            settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
        ):
            self._add_anthropic_models(settings)
        if settings.GOOGLE_API_KEY or (
            settings.LLM_PROVIDER == "google" and settings.API_KEY
        ):
            self._add_google_models(settings)
        if settings.GROQ_API_KEY or (
            settings.LLM_PROVIDER == "groq" and settings.API_KEY
        ):
            self._add_groq_models(settings)
        if settings.OPEN_ROUTER_API_KEY or (
            settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
        ):
            self._add_openrouter_models(settings)
        if settings.NOVITA_API_KEY or (
            settings.LLM_PROVIDER == "novita" and settings.API_KEY
        ):
            self._add_novita_models(settings)
        if settings.HUGGINGFACE_API_KEY or (
            settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
        ):
            self._add_huggingface_models(settings)
        # Default model selection
        if settings.LLM_NAME:
            # Parse LLM_NAME (may be comma-separated)
            model_names = self._parse_model_names(settings.LLM_NAME)
            # First model in the list becomes default
            for model_name in model_names:
                if model_name in self.models:
                    self.default_model_id = model_name
                    break
            # Backward compat: try exact match if no parsed model found
            if not self.default_model_id and settings.LLM_NAME in self.models:
                self.default_model_id = settings.LLM_NAME

        if not self.default_model_id:
            if settings.LLM_PROVIDER and settings.API_KEY:
                for model_id, model in self.models.items():
                    if model.provider.value == settings.LLM_PROVIDER:
                        self.default_model_id = model_id
                        break

        if not self.default_model_id and self.models:
            self.default_model_id = next(iter(self.models.keys()))
        logger.info(
            f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
        )

    def _add_openai_models(self, settings):
        from application.core.model_configs import (
            OPENAI_MODELS,
            create_custom_openai_model,
        )

        # Check if using local OpenAI-compatible endpoint (Ollama, LM Studio, etc.)
        using_local_endpoint = bool(
            settings.OPENAI_BASE_URL and settings.OPENAI_BASE_URL.strip()
        )

        if using_local_endpoint:
            # When OPENAI_BASE_URL is set, ONLY register custom models from LLM_NAME
            # Do NOT add standard OpenAI models (gpt-5.1, etc.)
            if settings.LLM_NAME:
                model_names = self._parse_model_names(settings.LLM_NAME)
                for model_name in model_names:
                    custom_model = create_custom_openai_model(
                        model_name, settings.OPENAI_BASE_URL
                    )
                    self.models[model_name] = custom_model
                    logger.info(
                        f"Registered custom OpenAI model: {model_name} at {settings.OPENAI_BASE_URL}"
                    )
        else:
            # Standard OpenAI API usage - add standard models if API key is valid
            if settings.OPENAI_API_KEY:
                for model in OPENAI_MODELS:
                    self.models[model.id] = model

    def _add_azure_openai_models(self, settings):
        from application.core.model_configs import AZURE_OPENAI_MODELS

        if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
            for model in AZURE_OPENAI_MODELS:
                if model.id == settings.LLM_NAME:
                    self.models[model.id] = model
            return
        for model in AZURE_OPENAI_MODELS:
            self.models[model.id] = model

    def _add_anthropic_models(self, settings):
        from application.core.model_configs import ANTHROPIC_MODELS

        if settings.ANTHROPIC_API_KEY:
            for model in ANTHROPIC_MODELS:
                self.models[model.id] = model
            return
        if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
            for model in ANTHROPIC_MODELS:
                if model.id == settings.LLM_NAME:
                    self.models[model.id] = model
            return
        for model in ANTHROPIC_MODELS:
            self.models[model.id] = model

    def _add_google_models(self, settings):
        from application.core.model_configs import GOOGLE_MODELS

        if settings.GOOGLE_API_KEY:
            for model in GOOGLE_MODELS:
                self.models[model.id] = model
            return
        if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
            for model in GOOGLE_MODELS:
                if model.id == settings.LLM_NAME:
                    self.models[model.id] = model
            return
        for model in GOOGLE_MODELS:
            self.models[model.id] = model

    def _add_groq_models(self, settings):
        from application.core.model_configs import GROQ_MODELS

        if settings.GROQ_API_KEY:
            for model in GROQ_MODELS:
                self.models[model.id] = model
            return
        if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
            for model in GROQ_MODELS:
                if model.id == settings.LLM_NAME:
                    self.models[model.id] = model
            return
        for model in GROQ_MODELS:
            self.models[model.id] = model

    def _add_openrouter_models(self, settings):
        from application.core.model_configs import OPENROUTER_MODELS

        if settings.OPEN_ROUTER_API_KEY:
            for model in OPENROUTER_MODELS:
                self.models[model.id] = model
            return
        if settings.LLM_PROVIDER == "openrouter" and settings.LLM_NAME:
            for model in OPENROUTER_MODELS:
                if model.id == settings.LLM_NAME:
                    self.models[model.id] = model
            return
        for model in OPENROUTER_MODELS:
            self.models[model.id] = model

    def _add_novita_models(self, settings):
        from application.core.model_configs import NOVITA_MODELS

        if settings.NOVITA_API_KEY:
            for model in NOVITA_MODELS:
                self.models[model.id] = model
            return
        if settings.LLM_PROVIDER == "novita" and settings.LLM_NAME:
            for model in NOVITA_MODELS:
                if model.id == settings.LLM_NAME:
                    self.models[model.id] = model
            return
        for model in NOVITA_MODELS:
            self.models[model.id] = model

    def _add_docsgpt_models(self, settings):
        model_id = "docsgpt-local"
        model = AvailableModel(
            id=model_id,
            provider=ModelProvider.DOCSGPT,
            display_name="DocsGPT Model",
            description="Local model",
            capabilities=ModelCapabilities(
                supports_tools=False,
                supported_attachment_types=[],
            ),
        )
        self.models[model_id] = model

    def _add_huggingface_models(self, settings):
        model_id = "huggingface-local"
        model = AvailableModel(
            id=model_id,
            provider=ModelProvider.HUGGINGFACE,
            display_name="Hugging Face Model",
            description="Local Hugging Face model",
            capabilities=ModelCapabilities(
                supports_tools=False,
                supported_attachment_types=[],
            ),
        )
        self.models[model_id] = model

    def _parse_model_names(self, llm_name: str) -> List[str]:
        """
        Parse LLM_NAME which may contain comma-separated model names.
        E.g., 'deepseek-r1:1.5b,gemma:2b' -> ['deepseek-r1:1.5b', 'gemma:2b']
        """
        if not llm_name:
            return []
        return [name.strip() for name in llm_name.split(",") if name.strip()]

    def get_model(self, model_id: str) -> Optional[AvailableModel]:
        return self.models.get(model_id)

    def get_all_models(self) -> List[AvailableModel]:
        return list(self.models.values())

    def get_enabled_models(self) -> List[AvailableModel]:
        return [m for m in self.models.values() if m.enabled]

    def model_exists(self, model_id: str) -> bool:
        return model_id in self.models
        return _MR
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

@@ -1,47 +1,59 @@
from typing import Any, Dict, Optional

from application.core.model_settings import ModelRegistry
from application.core.model_registry import ModelRegistry


def get_api_key_for_provider(provider: str) -> Optional[str]:
    """Get the appropriate API key for a provider"""
    """Get the appropriate API key for a provider.

    Delegates to the provider plugin's ``get_api_key``. Falls back to the
    generic ``settings.API_KEY`` for unknown providers.
    """
    from application.core.settings import settings
    from application.llm.providers import PROVIDERS_BY_NAME

    provider_key_map = {
        "openai": settings.OPENAI_API_KEY,
        "openrouter": settings.OPEN_ROUTER_API_KEY,
        "novita": settings.NOVITA_API_KEY,
        "anthropic": settings.ANTHROPIC_API_KEY,
        "google": settings.GOOGLE_API_KEY,
        "groq": settings.GROQ_API_KEY,
        "huggingface": settings.HUGGINGFACE_API_KEY,
        "azure_openai": settings.API_KEY,
        "docsgpt": None,
        "llama.cpp": None,
    }

    provider_key = provider_key_map.get(provider)
    if provider_key:
        return provider_key
    plugin = PROVIDERS_BY_NAME.get(provider)
    if plugin is not None:
        key = plugin.get_api_key(settings)
        if key:
            return key
    return settings.API_KEY


def get_all_available_models() -> Dict[str, Dict[str, Any]]:
    """Get all available models with metadata for API response"""
def get_all_available_models(
    user_id: Optional[str] = None,
) -> Dict[str, Dict[str, Any]]:
    """Get all available models with metadata for API response.

    When ``user_id`` is supplied, the user's BYOM custom-model records
    are merged into the result alongside the built-in catalog.
    """
    registry = ModelRegistry.get_instance()
    return {model.id: model.to_dict() for model in registry.get_enabled_models()}
    return {
        model.id: model.to_dict()
        for model in registry.get_enabled_models(user_id=user_id)
    }


def validate_model_id(model_id: str) -> bool:
    """Check if a model ID exists in registry"""
def validate_model_id(model_id: str, user_id: Optional[str] = None) -> bool:
    """Check if a model ID exists in registry.

    ``user_id`` enables resolution of per-user BYOM records (UUIDs).
    Without it, only built-in catalog ids resolve.
    """
    registry = ModelRegistry.get_instance()
    return registry.model_exists(model_id)
    return registry.model_exists(model_id, user_id=user_id)


def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
    """Get capabilities for a specific model"""
def get_model_capabilities(
    model_id: str, user_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """Get capabilities for a specific model.

    ``user_id`` enables resolution of per-user BYOM records.
    """
    registry = ModelRegistry.get_instance()
    model = registry.get_model(model_id)
    model = registry.get_model(model_id, user_id=user_id)
    if model:
        return {
            "supported_attachment_types": model.capabilities.supported_attachment_types,
@@ -58,36 +70,68 @@ def get_default_model_id() -> str:
    return registry.default_model_id


def get_provider_from_model_id(model_id: str) -> Optional[str]:
    """Get the provider name for a given model_id"""
def get_provider_from_model_id(
    model_id: str, user_id: Optional[str] = None
) -> Optional[str]:
    """Get the provider name for a given model_id.

    ``user_id`` enables resolution of per-user BYOM records (UUIDs).
    Without it, BYOM model ids return ``None`` and the caller falls
    back to the deployment default.
    """
    registry = ModelRegistry.get_instance()
    model = registry.get_model(model_id)
    model = registry.get_model(model_id, user_id=user_id)
    if model:
        return model.provider.value
    return None


def get_token_limit(model_id: str) -> int:
    """
    Get context window (token limit) for a model.
    Returns model's context_window or default 128000 if model not found.
def get_token_limit(model_id: str, user_id: Optional[str] = None) -> int:
    """Get context window (token limit) for a model.

    Returns the model's ``context_window`` or ``DEFAULT_LLM_TOKEN_LIMIT``
    if not found. ``user_id`` enables resolution of per-user BYOM records.
    """
    from application.core.settings import settings

    registry = ModelRegistry.get_instance()
    model = registry.get_model(model_id)
    model = registry.get_model(model_id, user_id=user_id)
    if model:
        return model.capabilities.context_window
    return settings.DEFAULT_LLM_TOKEN_LIMIT


def get_base_url_for_model(model_id: str) -> Optional[str]:
    """
    Get the custom base_url for a specific model if configured.
    Returns None if no custom base_url is set.
def get_base_url_for_model(
    model_id: str, user_id: Optional[str] = None
) -> Optional[str]:
    """Get the custom base_url for a specific model if configured.

    Returns ``None`` if no custom base_url is set. ``user_id`` enables
    resolution of per-user BYOM records.
    """
    registry = ModelRegistry.get_instance()
    model = registry.get_model(model_id)
    model = registry.get_model(model_id, user_id=user_id)
    if model:
        return model.base_url
    return None


def get_api_key_for_model(
    model_id: str, user_id: Optional[str] = None
) -> Optional[str]:
    """Resolve the API key to use when invoking ``model_id``.

    Priority:
    1. The model record's own ``api_key`` (BYOM records and
       ``openai_compatible`` YAMLs populate this).
    2. The provider plugin's settings-based key.

    ``user_id`` enables resolution of per-user BYOM records.
    """
    registry = ModelRegistry.get_instance()
    model = registry.get_model(model_id, user_id=user_id)
    if model is not None and model.api_key:
        return model.api_key
    if model is not None:
        return get_api_key_for_provider(model.provider.value)
    return None

application/core/model_yaml.py (new file, 358 lines)
@@ -0,0 +1,358 @@
"""YAML loader for model catalog files under ``application/core/models/``.
|
||||
|
||||
Each ``*.yaml`` file declares one provider's static model catalog. Files
|
||||
are validated with Pydantic at load time; any parse, schema, or alias
|
||||
error aborts startup with the offending file path in the message.
|
||||
|
||||
For most providers, one YAML maps to one catalog. The
|
||||
``openai_compatible`` provider is special: each YAML file represents a
|
||||
distinct logical endpoint (Mistral, Together, Ollama, ...) with its own
|
||||
``api_key_env`` and ``base_url``. The loader returns a flat list so the
|
||||
registry can distinguish multiple files with the same ``provider:`` value.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BUILTIN_MODELS_DIR = Path(__file__).parent / "models"
|
||||
DEFAULTS_FILENAME = "_defaults.yaml"
|
||||
|
||||
|
||||
class _DefaultsFile(BaseModel):
|
||||
"""Schema for ``_defaults.yaml``. Currently just attachment aliases."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
attachment_aliases: Dict[str, List[str]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class _CapabilityFields(BaseModel):
|
||||
"""Capability fields shared between provider ``defaults:`` and per-model overrides.
|
||||
|
||||
All fields are optional so a per-model override can selectively replace
|
||||
a single field from the provider-level defaults.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
supports_tools: Optional[bool] = None
|
||||
supports_structured_output: Optional[bool] = None
|
||||
supports_streaming: Optional[bool] = None
|
||||
attachments: Optional[List[str]] = None
|
||||
context_window: Optional[int] = None
|
||||
input_cost_per_token: Optional[float] = None
|
||||
output_cost_per_token: Optional[float] = None
|
||||
|
||||
|
||||
class _ModelEntry(_CapabilityFields):
|
||||
"""Schema for one model row inside a YAML's ``models:`` list."""
|
||||
|
||||
id: str
|
||||
display_name: Optional[str] = None
|
||||
description: str = ""
|
||||
enabled: bool = True
|
||||
base_url: Optional[str] = None
|
||||
aliases: List[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("id")
|
||||
@classmethod
|
||||
def _id_nonempty(cls, v: str) -> str:
|
||||
if not v or not v.strip():
|
||||
raise ValueError("model id must be a non-empty string")
|
||||
return v
|
||||
|
||||
|
||||
class _ProviderFile(BaseModel):
|
||||
"""Schema for one ``<provider>.yaml`` catalog file."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
provider: str
|
||||
defaults: _CapabilityFields = Field(default_factory=_CapabilityFields)
|
||||
models: List[_ModelEntry] = Field(default_factory=list)
|
||||
# openai_compatible metadata. Optional for other providers.
|
||||
display_provider: Optional[str] = None
|
||||
api_key_env: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
|
||||
|
||||
class ProviderCatalog(BaseModel):
|
||||
"""One YAML file's parsed contents, ready for the registry.
|
||||
|
||||
For most providers, multiple catalogs with the same ``provider`` get
|
||||
merged later by the registry. The ``openai_compatible`` provider is
|
||||
the exception: each catalog is treated as a distinct endpoint, with
|
||||
its own ``api_key_env`` and ``base_url``.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
models: List[AvailableModel]
|
||||
source_path: Optional[Path] = None
|
||||
display_provider: Optional[str] = None
|
||||
api_key_env: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ModelYAMLError(ValueError):
|
||||
"""Raised when a model YAML fails parsing, schema, or alias validation."""
|
||||
|
||||
|
||||
def _expand_attachments(
|
||||
attachments: Sequence[str], aliases: Dict[str, List[str]], source: str
|
||||
) -> List[str]:
|
||||
"""Resolve attachment shorthands (``image``, ``pdf``) to MIME types.
|
||||
|
||||
Raw MIME-typed entries (containing ``/``) pass through unchanged.
|
||||
Unknown aliases raise ``ModelYAMLError``.
|
||||
"""
|
||||
expanded: List[str] = []
|
||||
seen: set = set()
|
||||
for entry in attachments:
|
||||
if "/" in entry:
|
||||
if entry not in seen:
|
||||
expanded.append(entry)
|
||||
seen.add(entry)
|
||||
continue
|
||||
if entry not in aliases:
|
||||
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
|
||||
raise ModelYAMLError(
|
||||
f"{source}: unknown attachment alias '{entry}'. "
|
||||
f"Valid aliases: {valid}. "
|
||||
"(Or use a raw MIME type like 'image/png'.)"
|
||||
)
|
||||
for mime in aliases[entry]:
|
||||
if mime not in seen:
|
||||
expanded.append(mime)
|
||||
seen.add(mime)
|
||||
return expanded
|
||||
|
||||
|
||||
def _load_defaults(directory: Path) -> Dict[str, List[str]]:
|
||||
"""Load ``_defaults.yaml`` from ``directory`` if it exists."""
|
||||
path = directory / DEFAULTS_FILENAME
|
||||
if not path.exists():
|
||||
return {}
|
||||
try:
|
||||
raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as e:
|
||||
raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
|
||||
try:
|
||||
parsed = _DefaultsFile.model_validate(raw)
|
||||
except Exception as e:
|
||||
raise ModelYAMLError(f"{path}: schema error: {e}") from e
|
||||
return parsed.attachment_aliases
|
||||
|
||||
|
||||
def _resolve_provider_enum(name: str, source: Path) -> ModelProvider:
|
||||
try:
|
||||
return ModelProvider(name)
|
||||
except ValueError as e:
|
||||
valid = ", ".join(p.value for p in ModelProvider)
|
||||
raise ModelYAMLError(
|
||||
f"{source}: unknown provider '{name}'. Valid: {valid}"
|
||||
) from e
|
||||
|
||||
|
||||
def _build_model(
|
||||
entry: _ModelEntry,
|
||||
defaults: _CapabilityFields,
|
||||
provider: ModelProvider,
|
||||
aliases: Dict[str, List[str]],
|
||||
source: Path,
|
||||
display_provider: Optional[str] = None,
|
||||
) -> AvailableModel:
|
||||
"""Merge defaults + per-model overrides into a final ``AvailableModel``."""
|
||||
|
||||
def pick(field_name: str, fallback):
|
||||
v = getattr(entry, field_name)
|
||||
if v is not None:
|
||||
return v
|
||||
d = getattr(defaults, field_name)
|
||||
if d is not None:
|
||||
return d
|
||||
return fallback
|
||||
|
||||
raw_attachments = entry.attachments
|
||||
if raw_attachments is None:
|
||||
raw_attachments = defaults.attachments
|
||||
if raw_attachments is None:
|
||||
raw_attachments = []
|
||||
expanded = _expand_attachments(
|
||||
raw_attachments, aliases, f"{source} [model={entry.id}]"
|
||||
)
|
||||
|
||||
caps = ModelCapabilities(
|
||||
supports_tools=pick("supports_tools", False),
|
||||
supports_structured_output=pick("supports_structured_output", False),
|
||||
supports_streaming=pick("supports_streaming", True),
|
||||
supported_attachment_types=expanded,
|
||||
context_window=pick("context_window", 128000),
|
||||
input_cost_per_token=pick("input_cost_per_token", None),
|
||||
output_cost_per_token=pick("output_cost_per_token", None),
|
||||
)
|
||||
|
||||
return AvailableModel(
|
||||
id=entry.id,
|
||||
provider=provider,
|
||||
display_name=entry.display_name or entry.id,
|
||||
description=entry.description,
|
||||
capabilities=caps,
|
||||
enabled=entry.enabled,
|
||||
base_url=entry.base_url,
|
||||
display_provider=display_provider,
|
||||
)
|
||||
|
||||
|
||||
def _load_one_yaml(
|
||||
path: Path, aliases: Dict[str, List[str]]
|
||||
) -> ProviderCatalog:
|
||||
try:
|
||||
raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as e:
|
||||
raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
|
||||
try:
|
||||
parsed = _ProviderFile.model_validate(raw)
|
||||
except Exception as e:
|
||||
raise ModelYAMLError(f"{path}: schema error: {e}") from e
|
||||
|
||||
provider_enum = _resolve_provider_enum(parsed.provider, path)
|
||||
models = [
|
||||
_build_model(
|
||||
entry,
|
||||
parsed.defaults,
|
||||
provider_enum,
|
||||
aliases,
|
||||
path,
|
||||
display_provider=parsed.display_provider,
|
||||
)
|
||||
for entry in parsed.models
|
||||
]
|
||||
|
||||
return ProviderCatalog(
|
||||
provider=parsed.provider,
|
||||
models=models,
|
||||
source_path=path,
|
||||
display_provider=parsed.display_provider,
|
||||
api_key_env=parsed.api_key_env,
|
||||
base_url=parsed.base_url,
|
||||
)
|
||||
|
||||
|
||||
_BUILTIN_ALIASES_CACHE: Optional[Dict[str, List[str]]] = None
|
||||
|
||||
|
||||
def builtin_attachment_aliases() -> Dict[str, List[str]]:
|
||||
"""Return the built-in attachment alias map from ``_defaults.yaml``.
|
||||
|
||||
Cached after first read so repeat calls are cheap.
|
||||
"""
|
||||
global _BUILTIN_ALIASES_CACHE
|
||||
if _BUILTIN_ALIASES_CACHE is None:
|
||||
_BUILTIN_ALIASES_CACHE = _load_defaults(BUILTIN_MODELS_DIR)
|
||||
return _BUILTIN_ALIASES_CACHE
|
||||
|
||||
|
||||
def resolve_attachment_alias(alias: str) -> List[str]:
|
||||
"""Resolve a single attachment alias (e.g. ``"image"``) to its
|
||||
canonical MIME-type list. Raises ``ModelYAMLError`` if unknown.
|
||||
"""
|
||||
aliases = builtin_attachment_aliases()
|
||||
if alias not in aliases:
|
||||
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
|
||||
raise ModelYAMLError(
|
||||
f"Unknown attachment alias '{alias}'. Valid: {valid}"
|
||||
)
|
||||
return list(aliases[alias])
|
||||
|
||||
|
||||
def expand_attachments_lenient(
|
||||
attachments: Sequence[str], source: str
|
||||
) -> List[str]:
|
||||
"""Expand attachment aliases to MIME types, tolerating unknowns.
|
||||
|
||||
Mirrors ``_expand_attachments`` but logs+skips unknown aliases
|
||||
rather than raising. Used for runtime call sites (BYOM registry
|
||||
load) where an operator-side alias-map edit must not drop the
|
||||
entire user's BYOM layer; the strict raise still happens at the
|
||||
API validation boundary.
|
||||
"""
|
||||
aliases = builtin_attachment_aliases()
|
||||
expanded: List[str] = []
|
||||
seen: set = set()
|
||||
for entry in attachments:
|
||||
if "/" in entry:
|
||||
if entry not in seen:
|
||||
expanded.append(entry)
|
||||
seen.add(entry)
|
||||
continue
|
||||
mime_list = aliases.get(entry)
|
||||
if mime_list is None:
|
||||
logger.warning(
|
||||
"%s: skipping unknown attachment alias %r", source, entry,
|
||||
)
|
||||
continue
|
||||
for mime in mime_list:
|
||||
if mime not in seen:
|
||||
expanded.append(mime)
|
||||
seen.add(mime)
|
||||
return expanded
|
||||
|
||||
|
||||
def load_model_yamls(directories: Sequence[Path]) -> List[ProviderCatalog]:
|
||||
"""Load every ``*.yaml`` file (excluding ``_defaults.yaml``) under each
|
||||
directory in order and return a flat list of catalogs.
|
||||
|
||||
Caller is responsible for merging multiple catalogs that target the
|
||||
same provider plugin. The flat-list shape lets ``openai_compatible``
|
||||
keep each file separate (one logical endpoint per file).
|
||||
|
||||
When the same model ``id`` appears in more than one YAML across the
|
||||
directory list, a warning is logged. Order in the returned list
|
||||
preserves load order, so the registry's "later wins" merge gives the
|
||||
later directory's definition.
|
||||
"""
|
||||
catalogs: List[ProviderCatalog] = []
|
||||
seen_ids: Dict[str, Path] = {}
|
||||
|
||||
aliases: Dict[str, List[str]] = {}
|
||||
for d in directories:
|
||||
if not d or not d.exists():
|
||||
continue
|
||||
aliases.update(_load_defaults(d))
|
||||
|
||||
for d in directories:
|
||||
if not d or not d.exists():
|
||||
continue
|
||||
for path in sorted(d.glob("*.yaml")):
|
||||
if path.name == DEFAULTS_FILENAME:
|
||||
continue
|
||||
catalog = _load_one_yaml(path, aliases)
|
||||
catalogs.append(catalog)
|
||||
for m in catalog.models:
|
||||
prior = seen_ids.get(m.id)
|
||||
if prior is not None and prior != path:
|
||||
logger.warning(
|
||||
"Model id %r redefined: %s overrides %s (later wins)",
|
||||
m.id,
|
||||
path,
|
||||
prior,
|
||||
)
|
||||
seen_ids[m.id] = path
|
||||
|
||||
return catalogs
|
||||
application/core/models/README.md (new file, 213 lines)
@@ -0,0 +1,213 @@
# Model catalogs

Each `*.yaml` file in this directory declares one provider's model
catalog. The registry loads every YAML at boot and joins it to the
matching provider plugin under `application/llm/providers/`.

To add or edit models, you almost always only touch a YAML here — no
Python code required.

## Add a model to an existing provider

Open the provider's YAML (e.g. `anthropic.yaml`) and append two lines
under `models:`:

```yaml
models:
  - id: claude-3-7-sonnet
    display_name: Claude 3.7 Sonnet
```

Capabilities default to the provider's `defaults:` block. Override
per-model only when needed:

```yaml
  - id: claude-3-7-sonnet
    display_name: Claude 3.7 Sonnet
    context_window: 500000
```

Restart the app. The new model appears in `/api/models`.

> The model `id` is what gets stored in agent / workflow records. Once
> users start picking the model, **don't rename it** — agent and
> workflow rows reference it as a free-form string and silently fall
> back to the system default if the id disappears.

## Add an OpenAI-compatible provider (zero Python)

Drop a YAML in this directory (or in your `MODELS_CONFIG_DIR`) that uses
the `openai_compatible` plugin. Set the env var named in `api_key_env`
and you're done — no Python, no settings.py edit, no LLMCreator change:

```yaml
# mistral.yaml
provider: openai_compatible
display_provider: mistral        # shown in /api/models response
api_key_env: MISTRAL_API_KEY     # env var the plugin reads at boot
base_url: https://api.mistral.ai/v1
defaults:
  supports_tools: true
  context_window: 128000
models:
  - id: mistral-large-latest
    display_name: Mistral Large
  - id: mistral-small-latest
    display_name: Mistral Small
```

`MISTRAL_API_KEY=sk-... ; restart` — Mistral models appear in
`/api/models` with `provider: "mistral"`. They route through the OpenAI
wire format (it's `OpenAILLM` under the hood) but with Mistral's
endpoint and key.

Multiple `openai_compatible` YAMLs coexist: each file is one logical
endpoint with its own `api_key_env` and `base_url`. Drop in
`together.yaml`, `fireworks.yaml`, etc. side by side. If an env var
isn't set, that catalog is silently skipped at boot (logged at INFO) —
no error.

Working example: `examples/mistral.yaml.example`. Files inside
`examples/` aren't loaded by the registry; the glob only picks up
`*.yaml` at the top level.

## Add a provider with its own SDK

For a provider that doesn't speak OpenAI's wire format, add one Python
file to `application/llm/providers/<name>.py`:

```python
from application.llm.providers.base import Provider
from application.llm.my_provider import MyLLM


class MyProvider(Provider):
    name = "my_provider"
    llm_class = MyLLM

    def get_api_key(self, settings):
        return settings.MY_PROVIDER_API_KEY
```

Register it in `application/llm/providers/__init__.py` (one line in
`ALL_PROVIDERS`), add `MY_PROVIDER_API_KEY` to `settings.py`, and create
`my_provider.yaml` here with the model catalog.
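
A minimal sketch of that registration step, assuming `ALL_PROVIDERS` is a
plain list of plugin instances and `PROVIDERS_BY_NAME` is derived from it
(the actual contents of `__init__.py` aren't shown in this diff, so the
surrounding lines are illustrative):

```python
# application/llm/providers/__init__.py (sketch; existing entries elided)
from application.llm.providers.my_provider import MyProvider  # hypothetical module path

ALL_PROVIDERS = [
    # ... existing provider plugins ...
    MyProvider(),  # the one-line registration
]

# Name-keyed lookup used by key-resolution helpers like get_api_key_for_provider()
PROVIDERS_BY_NAME = {p.name: p for p in ALL_PROVIDERS}
```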
|
||||
|
||||
## Schema reference
|
||||
|
||||
```yaml
|
||||
provider: <string, required> # matches the Provider plugin's `name`
|
||||
|
||||
# openai_compatible only — required for that provider, ignored for others
|
||||
display_provider: <string> # label shown in /api/models response
|
||||
api_key_env: <string> # name of the env var carrying the key
|
||||
base_url: <string> # endpoint URL
|
||||
|
||||
defaults: # optional, applied to every model below
|
||||
supports_tools: bool # default false
|
||||
supports_structured_output: bool # default false
|
||||
supports_streaming: bool # default true
|
||||
attachments: [<alias-or-mime>, ...] # default []
|
||||
context_window: int # default 128000
|
||||
input_cost_per_token: float # default null
|
||||
output_cost_per_token: float # default null
|
||||
|
||||
models: # required
|
||||
- id: <string, required> # the value persisted in agent records
|
||||
display_name: <string> # default: id
|
||||
description: <string> # default: ""
|
||||
enabled: bool # default true; false hides from /api/models
|
||||
base_url: <string> # optional custom endpoint for this model
|
||||
# All `defaults:` fields above can be overridden here per-model.
|
||||
```
|
||||
|
||||
### Attachment aliases
|
||||
|
||||
The `attachments:` list can mix human-readable aliases with raw MIME
|
||||
types. Aliases are defined in `_defaults.yaml`:
|
||||
|
||||
| Alias | Expands to |
|
||||
|---|---|
|
||||
| `image` | `image/png`, `image/jpeg`, `image/jpg`, `image/webp`, `image/gif` |
|
||||
| `pdf` | `application/pdf` |
|
||||
| `audio` | `audio/mpeg`, `audio/wav`, `audio/ogg` |
|
||||
|
||||
Use raw MIME types when you need surgical control:
|
||||
|
||||
```yaml
|
||||
attachments: [image/png, image/webp] # only these two
|
||||
```
|
||||
|
||||
## Operator-supplied YAMLs (`MODELS_CONFIG_DIR`)
|
||||
|
||||
Set the `MODELS_CONFIG_DIR` env var (or `.env` entry) to a directory
|
||||
path. Every `*.yaml` in that directory is loaded **after** the built-in
|
||||
catalog under `application/core/models/`. Operators use this to:
|
||||
|
||||
- Add new `openai_compatible` providers (Mistral, Together, Fireworks,
  Ollama, ...) without forking the repo.
- Extend an existing provider's catalog with extra models — append
  models under `provider: anthropic` and they show up alongside the
  built-ins.
- Override a built-in model's capabilities — declare the same `id`
  with different fields (e.g. a higher `context_window`). Later wins;
  the override is logged as a `WARNING` so you can audit it (see the
  sketch after this list).

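A minimal sketch of that last-wins merge policy; the function name and
log wording are illustrative, not the registry's actual code:

```python
# Sketch only: illustrates the "later wins, logged as WARNING" policy.
import logging

logger = logging.getLogger(__name__)


def merge_catalogs(builtin: dict, operator: dict) -> dict:
    merged = dict(builtin)
    for model_id, model in operator.items():
        if model_id in merged:
            # Duplicate id: the operator entry replaces the built-in,
            # loudly enough that the override can be audited later.
            logger.warning("model %r overridden by operator YAML", model_id)
        merged[model_id] = model
    return merged
```
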
Things you cannot do via `MODELS_CONFIG_DIR`:

- Add a brand-new non-OpenAI provider — that needs a Python plugin
  under `application/llm/providers/` (see "Add a provider with its own
  SDK" above). Operator YAMLs may only target a `provider:` value that
  already has a registered plugin.

### Example: Docker

Mount your model YAMLs into the container and point the env var at the
mount path:

```yaml
# docker-compose.yml
services:
  app:
    image: arc53/docsgpt
    environment:
      MODELS_CONFIG_DIR: /etc/docsgpt/models
      MISTRAL_API_KEY: ${MISTRAL_API_KEY}
    volumes:
      - ./my-models:/etc/docsgpt/models:ro
```

Then `./my-models/mistral.yaml` (the file from
`examples/mistral.yaml.example`) gets picked up at boot.

### Example: Kubernetes

Mount a `ConfigMap` containing your YAMLs at a known path and set
`MODELS_CONFIG_DIR` on the deployment. The same `examples/mistral.yaml.example`
becomes a key in the ConfigMap.

### Misconfiguration

If `MODELS_CONFIG_DIR` is set but the path doesn't exist (or isn't a
directory), the app logs a `WARNING` at boot and continues with just
the built-in catalog. The app does *not* fail to start — a stale or
missing config directory won't take down the service — but the warning
is loud enough to surface in any reasonable log aggregator.

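A sketch of that warn-and-continue shape, with illustrative names:

```python
# Sketch only: the real boot check may differ in naming and wording.
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


def operator_yaml_paths(models_config_dir: str | None) -> list[Path]:
    if not models_config_dir:
        return []
    root = Path(models_config_dir)
    if not root.is_dir():
        # Warn loudly, but keep booting with the built-in catalog only.
        logger.warning(
            "MODELS_CONFIG_DIR %r is not a directory; skipping",
            models_config_dir,
        )
        return []
    return sorted(root.glob("*.yaml"))
```
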
## Validation

YAMLs are parsed with Pydantic at boot. The app fails to start with a
clear error message if:

- a top-level key is unknown
- a model is missing `id`
- an attachment alias isn't defined
- the `provider:` value isn't registered as a plugin

This is intentional — silent fallbacks would mean users don't notice
their model picks broke until they hit the API.

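The fail-fast shape this implies, as a minimal Pydantic sketch; these
are not the project's actual schema classes:

```python
# Sketch only: NOT the project's schema classes, just the behavior.
from pydantic import BaseModel, ConfigDict, ValidationError


class ModelEntry(BaseModel):
    model_config = ConfigDict(extra="forbid")
    id: str                          # missing id -> ValidationError
    display_name: str | None = None


class Catalog(BaseModel):
    model_config = ConfigDict(extra="forbid")  # unknown keys -> error
    provider: str
    models: list[ModelEntry]


try:
    Catalog.model_validate({"provider": "anthropic", "modles": []})
except ValidationError as e:
    raise SystemExit(f"refusing to boot: {e}")  # clear error at startup
```
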
## Reserved fields (not yet implemented)

- `aliases:` on a model — old IDs that resolve to this model. Reserved
  for future renames; the schema accepts the field but it is not yet
  acted on.

application/core/models/_defaults.yaml (Normal file, 18 lines)
@@ -0,0 +1,18 @@
# Global defaults applied across every model YAML in this directory.
# Keep this file sparse — per-provider `defaults:` blocks are clearer
# than a deep global default chain. This file is for things that
# genuinely never vary, like the meaning of "image".

attachment_aliases:
  image:
    - image/png
    - image/jpeg
    - image/jpg
    - image/webp
    - image/gif
  pdf:
    - application/pdf
  audio:
    - audio/mpeg
    - audio/wav
    - audio/ogg

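A sketch of how expansion against this table behaves; the real
implementation is `resolve_attachment_alias` in
`application/core/model_yaml`, whose exact signature isn't shown here:

```python
# Sketch only: mirrors the _defaults.yaml table above.
ATTACHMENT_ALIASES = {
    "image": ["image/png", "image/jpeg", "image/jpg", "image/webp", "image/gif"],
    "pdf": ["application/pdf"],
    "audio": ["audio/mpeg", "audio/wav", "audio/ogg"],
}


def resolve_attachments(entries: list[str]) -> list[str]:
    """Expand aliases; raw MIME types pass through untouched."""
    out: list[str] = []
    for entry in entries:
        out.extend(ATTACHMENT_ALIASES.get(entry, [entry]))
    return out


assert resolve_attachments(["pdf", "image/webp"]) == [
    "application/pdf",
    "image/webp",
]
```
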
application/core/models/anthropic.yaml (Normal file, 23 lines)
@@ -0,0 +1,23 @@
provider: anthropic
defaults:
  supports_tools: true
  attachments: [image]
  context_window: 200000

models:
  - id: claude-opus-4-7
    display_name: Claude Opus 4.7
    description: Most capable Claude model for complex reasoning and agentic coding
    context_window: 1000000
    supports_structured_output: true

  - id: claude-sonnet-4-6
    display_name: Claude Sonnet 4.6
    description: Best balance of speed and intelligence with extended thinking
    context_window: 1000000
    supports_structured_output: true

  - id: claude-haiku-4-5
    display_name: Claude Haiku 4.5
    description: Fastest Claude model with near-frontier intelligence
    supports_structured_output: true

application/core/models/azure_openai.yaml (Normal file, 31 lines)
@@ -0,0 +1,31 @@
# Azure OpenAI catalog.
#
# IMPORTANT: For Azure OpenAI, the `id` field is the **deployment name**, not
# a model name. Deployment names are arbitrary strings the operator chooses
# in Azure portal (or via ARM/Bicep/Terraform) when they create a deployment
# for a given underlying model + version.
#
# The IDs below are sensible defaults that mirror the underlying OpenAI
# model name (prefixed with `azure-`). Operators almost always need to
# override them via `MODELS_CONFIG_DIR` to match the deployment names that
# actually exist in their Azure resource. The `display_name`, capability
# flags, and `context_window` reflect the underlying OpenAI model.
provider: azure_openai

defaults:
  supports_tools: true
  supports_structured_output: true
  attachments: [image]
  context_window: 400000

models:
  - id: azure-gpt-5.5
    display_name: Azure OpenAI GPT-5.5
    description: Azure-hosted flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
    context_window: 1050000
  - id: azure-gpt-5.4-mini
    display_name: Azure OpenAI GPT-5.4 Mini
    description: Azure-hosted cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
  - id: azure-gpt-5.4-nano
    display_name: Azure OpenAI GPT-5.4 Nano
    description: Azure-hosted cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most

application/core/models/docsgpt.yaml (Normal file, 7 lines)
@@ -0,0 +1,7 @@
provider: docsgpt

models:
  - id: docsgpt-local
    display_name: DocsGPT Model
    description: Local model
    supports_tools: false

application/core/models/examples/mistral.yaml.example (Normal file, 31 lines)
@@ -0,0 +1,31 @@
# EXAMPLE — copy this file to ../mistral.yaml (or to your
# MODELS_CONFIG_DIR) and set MISTRAL_API_KEY in your environment.
#
# This is the entire integration. No Python required: the
# `openai_compatible` plugin reads `api_key_env` and `base_url` from
# the file and routes calls through the OpenAI wire format.
#
# Files in this `examples/` directory are NOT loaded by the registry
# (the loader globs *.yaml at the top level only).

provider: openai_compatible
display_provider: mistral             # shown in /api/models response
api_key_env: MISTRAL_API_KEY          # env var the plugin reads
base_url: https://api.mistral.ai/v1   # OpenAI-compatible endpoint

defaults:
  supports_tools: true
  context_window: 128000

models:
  - id: mistral-large-latest
    display_name: Mistral Large
    description: Top-tier reasoning model

  - id: mistral-small-latest
    display_name: Mistral Small
    description: Fast, cost-efficient

  - id: codestral-latest
    display_name: Codestral
    description: Code-specialized model

application/core/models/google.yaml (Normal file, 17 lines)
@@ -0,0 +1,17 @@
provider: google
defaults:
  supports_tools: true
  supports_structured_output: true
  attachments: [pdf, image]
  context_window: 1048576

models:
  - id: gemini-3.1-pro-preview
    display_name: Gemini 3.1 Pro
    description: Most capable Gemini 3 model with advanced reasoning and agentic coding (preview)
  - id: gemini-3-flash-preview
    display_name: Gemini 3 Flash
    description: Frontier-class performance for low-latency, high-volume tasks (preview)
  - id: gemini-3.1-flash-lite-preview
    display_name: Gemini 3.1 Flash-Lite
    description: Cost-efficient frontier-class multimodal model for high-throughput workloads (preview)

application/core/models/groq.yaml (Normal file, 16 lines)
@@ -0,0 +1,16 @@
provider: groq
defaults:
  supports_tools: true
  context_window: 131072

models:
  - id: openai/gpt-oss-120b
    display_name: GPT-OSS 120B
    description: OpenAI's open-weight 120B flagship served on Groq's LPU hardware; strong general reasoning with strict structured output support
    supports_structured_output: true
  - id: llama-3.3-70b-versatile
    display_name: Llama 3.3 70B Versatile
    description: Meta's Llama 3.3 70B for general-purpose chat with parallel tool use
  - id: llama-3.1-8b-instant
    display_name: Llama 3.1 8B Instant
    description: Small, very low-latency Llama model (~560 tok/s) with parallel tool use

application/core/models/huggingface.yaml (Normal file, 7 lines)
@@ -0,0 +1,7 @@
provider: huggingface

models:
  - id: huggingface-local
    display_name: Hugging Face Model
    description: Local Hugging Face model
    supports_tools: false

application/core/models/novita.yaml (Normal file, 21 lines)
@@ -0,0 +1,21 @@
provider: novita
defaults:
  supports_tools: true
  supports_structured_output: true

models:
  - id: deepseek/deepseek-v4-pro
    display_name: DeepSeek V4 Pro
    description: 1.6T MoE (49B active) with 1M context, hybrid CSA/HCA attention, top-tier reasoning and agentic coding
    context_window: 1048576

  - id: moonshotai/kimi-k2.6
    display_name: Kimi K2.6
    description: 1T-parameter open-weight MoE with native vision/video, multi-step tool calling, and agentic long-horizon execution
    attachments: [image]
    context_window: 262144

  - id: zai-org/glm-5
    display_name: GLM-5
    description: Z.AI 754B-parameter MoE with strong general reasoning, function calling, and structured output
    context_window: 202800

application/core/models/openai.yaml (Normal file, 18 lines)
@@ -0,0 +1,18 @@
provider: openai
defaults:
  supports_tools: true
  supports_structured_output: true
  attachments: [image]
  context_window: 400000

models:
  - id: gpt-5.5
    display_name: GPT-5.5
    description: Flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
    context_window: 1050000
  - id: gpt-5.4-mini
    display_name: GPT-5.4 Mini
    description: Cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
  - id: gpt-5.4-nano
    display_name: GPT-5.4 Nano
    description: Cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most

application/core/models/openrouter.yaml (Normal file, 25 lines)
@@ -0,0 +1,25 @@
provider: openrouter
defaults:
  supports_tools: true
  attachments: [image]
  context_window: 128000

models:
  - id: qwen/qwen3-coder:free
    display_name: Qwen3 Coder (free)
    description: Free-tier 480B MoE coder model with strong agentic tool use; rate-limited
    context_window: 262000
    attachments: []

  - id: deepseek/deepseek-v3.2
    display_name: DeepSeek V3.2
    description: Open-weights reasoning model, very low cost (~$0.25 in / $0.38 out per 1M)
    context_window: 131072
    attachments: []
    supports_structured_output: true

  - id: anthropic/claude-sonnet-4.6
    display_name: Claude Sonnet 4.6 (via OpenRouter)
    description: Frontier Sonnet-class model with 1M context, vision, and extended thinking
    context_window: 1000000
    supports_structured_output: true

@@ -23,6 +23,10 @@ class Settings(BaseSettings):
    EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
    EMBEDDINGS_BASE_URL: Optional[str] = None  # Remote embeddings API URL (OpenAI-compatible)
    EMBEDDINGS_KEY: Optional[str] = None  # api key for embeddings (if using openai, just copy API_KEY)
+   # Optional directory of operator-supplied model YAMLs, loaded after the
+   # built-in catalog under application/core/models/. Later wins on
+   # duplicate model id. See application/core/models/README.md.
+   MODELS_CONFIG_DIR: Optional[str] = None

    CELERY_BROKER_URL: str = "redis://localhost:6379/0"
    CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"

@@ -17,6 +17,8 @@ class BaseLLM(ABC):
        model_id=None,
        base_url=None,
        backup_models=None,
+       model_user_id=None,
+       capabilities=None,
    ):
        self.decoded_token = decoded_token
        self.agent_id = str(agent_id) if agent_id else None

@@ -25,6 +27,12 @@ class BaseLLM(ABC):
        self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
        self._backup_models = backup_models or []
        self._fallback_llm = None
+       # Registry-resolved per-model capability overrides (BYOM caps,
+       # operator YAML). None falls back to provider-class defaults.
+       self.capabilities = capabilities
+       # BYOM-resolution scope captured at LLM creation time so backup
+       # / fallback lookups hit the same per-user layer as the primary.
+       self.model_user_id = model_user_id

    @property
    def fallback_llm(self):

@@ -39,10 +47,19 @@
            get_api_key_for_provider,
        )

        # Try per-agent backup models first
+       # model_user_id (BYOM scope) takes precedence over the caller's
+       # sub so shared-agent backups resolve under the owner's layer.
+       caller_sub = (
+           self.decoded_token.get("sub")
+           if isinstance(self.decoded_token, dict)
+           else None
+       )
+       backup_user_id = self.model_user_id or caller_sub
        for backup_model_id in self._backup_models:
            try:
-               provider = get_provider_from_model_id(backup_model_id)
+               provider = get_provider_from_model_id(
+                   backup_model_id, user_id=backup_user_id
+               )
                if not provider:
                    logger.warning(
                        f"Could not resolve provider for backup model: {backup_model_id}"

@@ -56,6 +73,7 @@
                    decoded_token=self.decoded_token,
                    model_id=backup_model_id,
                    agent_id=self.agent_id,
+                   model_user_id=self.model_user_id,
                )
                logger.info(
                    f"Fallback LLM initialized from agent backup model: "

@@ -68,7 +86,10 @@
                )
                continue

-       # Fall back to global FALLBACK_* settings
+       # Fall back to global FALLBACK_* settings. Forward
+       # ``model_user_id`` here too: deployments can configure
+       # ``FALLBACK_LLM_NAME`` to a BYOM UUID, and that UUID is owned
+       # by the same user the primary model was resolved under.
        if settings.FALLBACK_LLM_PROVIDER:
            try:
                self._fallback_llm = LLMCreator.create_llm(

@@ -78,6 +99,7 @@
                    decoded_token=self.decoded_token,
                    model_id=settings.FALLBACK_LLM_NAME,
                    agent_id=self.agent_id,
+                   model_user_id=self.model_user_id,
                )
                logger.info(
                    f"Fallback LLM initialized from global settings: "

@@ -470,10 +470,14 @@ class LLMHandler(ABC):
            )
            return self._perform_in_memory_compression(agent, messages)

-       # Use orchestrator to perform compression
+       # Use orchestrator to perform compression. ``model_user_id``
+       # keeps BYOM registry resolution scoped to the model owner
+       # (shared-agent dispatch) while ``user_id`` stays the caller
+       # for the conversation access check.
        result = orchestrator.compress_mid_execution(
            conversation_id=agent.conversation_id,
            user_id=agent.initial_user_id,
+           model_user_id=getattr(agent, "model_user_id", None),
            model_id=agent.model_id,
            decoded_token=getattr(agent, "decoded_token", {}),
            current_conversation=conversation,

@@ -577,7 +581,20 @@ class LLMHandler(ABC):
            if settings.COMPRESSION_MODEL_OVERRIDE
            else agent.model_id
        )
-       provider = get_provider_from_model_id(compression_model)
+       agent_decoded = getattr(agent, "decoded_token", None)
+       caller_sub = (
+           agent_decoded.get("sub")
+           if isinstance(agent_decoded, dict)
+           else None
+       )
+       # Use model-owner scope (mirrors orchestrator path) so
+       # shared-agent owner-BYOM resolves under the owner's layer.
+       compression_user_id = (
+           getattr(agent, "model_user_id", None) or caller_sub
+       )
+       provider = get_provider_from_model_id(
+           compression_model, user_id=compression_user_id
+       )
        api_key = get_api_key_for_provider(provider)
        compression_llm = LLMCreator.create_llm(
            provider,

@@ -586,6 +603,7 @@ class LLMHandler(ABC):
            getattr(agent, "decoded_token", None),
            model_id=compression_model,
            agent_id=getattr(agent, "agent_id", None),
+           model_user_id=compression_user_id,
        )

        # Create service without DB persistence capability

@@ -921,8 +939,15 @@ class LLMHandler(ABC):
            }
            return ""

+       # ``agent.model_id`` is the registry id (a UUID for BYOM
+       # records). Use the LLM's own model_id, which LLMCreator
+       # already resolved to the upstream model name. Built-ins:
+       # the two are equal; BYOM: the upstream name like
+       # "mistral-large-latest" instead of the UUID.
        response = agent.llm.gen(
-           model=agent.model_id, messages=messages, tools=agent.tools
+           model=getattr(agent.llm, "model_id", None) or agent.model_id,
+           messages=messages,
+           tools=agent.tools,
        )
        parsed = self.parse_response(response)
        self.llm_calls.append(build_stack_data(agent.llm))

@@ -1011,8 +1036,11 @@ class LLMHandler(ABC):
            })
            logger.info("Context limit reached - instructing agent to wrap up")

+       # See note above on agent.model_id vs llm.model_id.
        response = agent.llm.gen_stream(
-           model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None
+           model=getattr(agent.llm, "model_id", None) or agent.model_id,
+           messages=messages,
+           tools=agent.tools if not agent.context_limit_reached else None,
        )
        self.llm_calls.append(build_stack_data(agent.llm))

@@ -1,34 +1,11 @@
import logging

-from application.llm.anthropic import AnthropicLLM
-from application.llm.docsgpt_provider import DocsGPTAPILLM
-from application.llm.google_ai import GoogleLLM
-from application.llm.groq import GroqLLM
-from application.llm.llama_cpp import LlamaCpp
-from application.llm.novita import NovitaLLM
-from application.llm.openai import AzureOpenAILLM, OpenAILLM
-from application.llm.premai import PremAILLM
-from application.llm.sagemaker import SagemakerAPILLM
-from application.llm.open_router import OpenRouterLLM
+from application.llm.providers import PROVIDERS_BY_NAME

logger = logging.getLogger(__name__)


class LLMCreator:
-   llms = {
-       "openai": OpenAILLM,
-       "azure_openai": AzureOpenAILLM,
-       "sagemaker": SagemakerAPILLM,
-       "llama.cpp": LlamaCpp,
-       "anthropic": AnthropicLLM,
-       "docsgpt": DocsGPTAPILLM,
-       "premai": PremAILLM,
-       "groq": GroqLLM,
-       "google": GoogleLLM,
-       "novita": NovitaLLM,
-       "openrouter": OpenRouterLLM,
-   }

    @classmethod
    def create_llm(
        cls,

@@ -39,28 +16,111 @@ class LLMCreator:
        model_id=None,
        agent_id=None,
        backup_models=None,
+       model_user_id=None,
        *args,
        **kwargs,
    ):
-       from application.core.model_utils import get_base_url_for_model
+       """Construct an LLM for the given provider ``type``.
+
+       ``model_user_id`` is the BYOM-resolution scope. Defaults to
+       ``decoded_token['sub']`` (the caller). Pass it explicitly when
+       the model record belongs to a *different* user — most notably
+       for shared-agent dispatch, where the agent's stored
+       ``default_model_id`` is the owner's BYOM UUID but
+       ``decoded_token`` represents the caller.
+       """
+       from application.core.model_registry import ModelRegistry
+       from application.security.safe_url import (
+           UnsafeUserUrlError,
+           pinned_httpx_client,
+           validate_user_base_url,
+       )

-       llm_class = cls.llms.get(type.lower())
-       if not llm_class:
+       plugin = PROVIDERS_BY_NAME.get(type.lower())
+       if plugin is None or plugin.llm_class is None:
            raise ValueError(f"No LLM class found for type {type}")

-       # Extract base_url from model configuration if model_id is provided
+       # Prefer per-model endpoint config from the registry. This is what
+       # makes openai_compatible AND end-user BYOM work without changing
+       # every call site: if the registered AvailableModel carries its
+       # own api_key / base_url, they win over whatever the caller
+       # resolved via the provider plugin.
+       #
+       # End-user BYOM lookups need the user_id from decoded_token to
+       # find the user's per-user models layer (built-in models resolve
+       # without it, so this stays back-compat).
        base_url = None
+       upstream_model_id = model_id
+       capabilities = None
        if model_id:
-           base_url = get_base_url_for_model(model_id)
+           user_id = model_user_id
+           if user_id is None:
+               user_id = (
+                   (decoded_token or {}).get("sub") if decoded_token else None
+               )
+           model = ModelRegistry.get_instance().get_model(model_id, user_id=user_id)
+           if model is not None:
+               # Forward registry caps so the LLM enforces them at
+               # dispatch (built-in classes hard-code True otherwise).
+               capabilities = getattr(model, "capabilities", None)
+               # SECURITY: refuse user-source dispatch without its own
+               # api_key (would leak settings.API_KEY to base_url).
+               if (
+                   getattr(model, "source", "builtin") == "user"
+                   and not model.api_key
+               ):
+                   raise ValueError(
+                       f"Custom model {model_id!r} has no usable API key "
+                       "(decryption may have failed). Re-save the model "
+                       "in settings to dispatch it."
+                   )
+               if model.api_key:
+                   api_key = model.api_key
+               if model.base_url:
+                   base_url = model.base_url
+               # For BYOM the registry id is a UUID; the upstream API
+               # call needs the user's typed model name instead.
+               if model.upstream_model_id:
+                   upstream_model_id = model.upstream_model_id

-       return llm_class(
+           # SECURITY: re-validate at dispatch (defense in depth
+           # for pre-guard rows / YAML-supplied entries). The
+           # pinned httpx.Client below is what actually closes the
+           # DNS-rebinding TOCTOU window.
+           if base_url and getattr(model, "source", "builtin") == "user":
+               try:
+                   validate_user_base_url(base_url)
+               except UnsafeUserUrlError as e:
+                   raise ValueError(
+                       f"Refusing to dispatch model {model_id!r}: {e}"
+                   ) from e
+               # Pinned httpx.Client: resolves once, validates, and
+               # binds the SDK's outbound socket to the validated IP
+               # (preserves Host / SNI). Future BYOM providers must
+               # opt in explicitly — only openai_compatible takes
+               # http_client today.
+               if plugin.name == "openai_compatible":
+                   try:
+                       kwargs["http_client"] = pinned_httpx_client(
+                           base_url
+                       )
+                   except UnsafeUserUrlError as e:
+                       raise ValueError(
+                           f"Refusing to dispatch model {model_id!r}: {e}"
+                       ) from e

+       # Forward model_user_id so backup/fallback resolves under the
+       # owner's scope on shared-agent dispatch.
+       return plugin.llm_class(
            api_key,
            user_api_key,
            decoded_token=decoded_token,
-           model_id=model_id,
+           model_id=upstream_model_id,
            agent_id=agent_id,
            base_url=base_url,
            backup_models=backup_models,
+           model_user_id=model_user_id,
+           capabilities=capabilities,
            *args,
            **kwargs,
        )

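A hedged usage sketch of the new signature for shared-agent dispatch;
all identifier values are illustrative placeholders, not real records:

```python
# Illustrative values only. A shared agent stores the owner's BYOM UUID,
# so model_user_id carries the owner while decoded_token stays the caller.
llm = LLMCreator.create_llm(
    "openai_compatible",
    api_key=None,                         # registry-resolved BYOM key wins
    user_api_key=None,
    decoded_token={"sub": "caller-sub"},  # the user invoking the agent
    model_id="byom-registry-uuid",        # owner's registry UUID (hypothetical)
    agent_id="agent-123",
    model_user_id="owner-sub",            # BYOM lookup under the owner's layer
)
```
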
@@ -62,7 +62,15 @@ def _truncate_base64_for_logging(messages):

class OpenAILLM(BaseLLM):

-   def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
+   def __init__(
+       self,
+       api_key=None,
+       user_api_key=None,
+       base_url=None,
+       http_client=None,
+       *args,
+       **kwargs,
+   ):

        super().__init__(*args, **kwargs)
        self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY

@@ -80,7 +88,18 @@ class OpenAILLM(BaseLLM):
        else:
            effective_base_url = "https://api.openai.com/v1"

-       self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
+       # http_client (set by LLMCreator for BYOM) is a DNS-rebinding-safe
+       # httpx.Client; without it the SDK re-resolves DNS per request.
+       if http_client is not None:
+           self.client = OpenAI(
+               api_key=self.api_key,
+               base_url=effective_base_url,
+               http_client=http_client,
+           )
+       else:
+           self.client = OpenAI(
+               api_key=self.api_key, base_url=effective_base_url
+           )
        self.storage = StorageCreator.get_storage()

    def _clean_messages_openai(self, messages):

@@ -243,6 +262,13 @@ class OpenAILLM(BaseLLM):
        if "max_tokens" in kwargs:
            kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")

+       # Defense-in-depth: drop tools / response_format if the
+       # registry's capability flags deny them.
+       if tools and not self._supports_tools():
+           tools = None
+       if response_format and not self._supports_structured_output():
+           response_format = None

        request_params = {
            "model": model,
            "messages": messages,

@@ -279,6 +305,13 @@ class OpenAILLM(BaseLLM):
        if "max_tokens" in kwargs:
            kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")

+       # See _raw_gen for rationale — drop tools/response_format when the
+       # registry-provided capabilities say the model doesn't support them.
+       if tools and not self._supports_tools():
+           tools = None
+       if response_format and not self._supports_structured_output():
+           response_format = None

        request_params = {
            "model": model,
            "messages": messages,

@@ -320,9 +353,17 @@ class OpenAILLM(BaseLLM):
            response.close()

    def _supports_tools(self):
+       # When the LLM was constructed via LLMCreator with a registered
+       # AvailableModel, ``self.capabilities`` is the per-model record.
+       # BYOM users can disable tool support; respect that. Otherwise
+       # OpenAI's API supports tools by default.
+       if self.capabilities is not None:
+           return bool(self.capabilities.supports_tools)
        return True

    def _supports_structured_output(self):
+       if self.capabilities is not None:
+           return bool(self.capabilities.supports_structured_output)
        return True

    def prepare_structured_output_format(self, json_schema):

@@ -389,8 +430,14 @@ class OpenAILLM(BaseLLM):
        Returns:
            list: List of supported MIME types
        """
-       from application.core.model_configs import OPENAI_ATTACHMENTS
-       return OPENAI_ATTACHMENTS
+       # Per-model caps from the registry win when present — a BYOM
+       # endpoint that doesn't accept images would otherwise still be
+       # sent base64 image parts because the OpenAI default below
+       # advertises the image alias unconditionally.
+       if self.capabilities is not None:
+           return list(self.capabilities.supported_attachment_types or [])
+       from application.core.model_yaml import resolve_attachment_alias
+       return resolve_attachment_alias("image")

    def prepare_messages_with_attachments(self, messages, attachments=None):
        """

application/llm/providers/__init__.py (Normal file, 51 lines)
@@ -0,0 +1,51 @@
"""Provider plugin registry.

Plugins are imported eagerly so import errors surface at app boot rather
than at first request. ``ALL_PROVIDERS`` is the canonical ordered list;
``PROVIDERS_BY_NAME`` is a name-keyed lookup for LLMCreator and the
model registry.
"""

from __future__ import annotations

from typing import Dict, List

from application.llm.providers.anthropic import AnthropicProvider
from application.llm.providers.azure_openai import AzureOpenAIProvider
from application.llm.providers.base import Provider
from application.llm.providers.docsgpt import DocsGPTProvider
from application.llm.providers.google import GoogleProvider
from application.llm.providers.groq import GroqProvider
from application.llm.providers.huggingface import HuggingFaceProvider
from application.llm.providers.llama_cpp import LlamaCppProvider
from application.llm.providers.novita import NovitaProvider
from application.llm.providers.openai import OpenAIProvider
from application.llm.providers.openai_compatible import OpenAICompatibleProvider
from application.llm.providers.openrouter import OpenRouterProvider
from application.llm.providers.premai import PremAIProvider
from application.llm.providers.sagemaker import SagemakerProvider

# Order here is the order the registry iterates providers (and therefore
# the order ``/api/models`` reports them). Match the historical order
# from the old ModelRegistry._load_models for byte-stable output during
# the migration. ``openai_compatible`` slots in right after ``openai``
# so legacy ``OPENAI_BASE_URL`` models keep landing in the same place.
ALL_PROVIDERS: List[Provider] = [
    DocsGPTProvider(),
    OpenAIProvider(),
    OpenAICompatibleProvider(),
    AzureOpenAIProvider(),
    AnthropicProvider(),
    GoogleProvider(),
    GroqProvider(),
    OpenRouterProvider(),
    NovitaProvider(),
    HuggingFaceProvider(),
    LlamaCppProvider(),
    PremAIProvider(),
    SagemakerProvider(),
]

PROVIDERS_BY_NAME: Dict[str, Provider] = {p.name: p for p in ALL_PROVIDERS}

__all__ = ["ALL_PROVIDERS", "PROVIDERS_BY_NAME", "Provider"]

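A quick lookup example against this registry; the `settings` import
path is assumed from the rest of the codebase and may differ:

```python
# Assumed import path for the settings singleton; adjust if it differs.
from application.core.settings import settings
from application.llm.providers import PROVIDERS_BY_NAME

plugin = PROVIDERS_BY_NAME["anthropic"]
print(plugin.llm_class)              # -> AnthropicLLM
print(plugin.get_api_key(settings))  # ANTHROPIC_API_KEY, or the generic
                                     # API_KEY when LLM_PROVIDER == "anthropic"
```
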
application/llm/providers/_apikey_or_llm_name.py (Normal file, 51 lines)
@@ -0,0 +1,51 @@
"""Shared helper for providers that follow the
``<X>_API_KEY or (LLM_PROVIDER==X and API_KEY)`` pattern.

This is the dominant pattern across Anthropic, Google, Groq, OpenRouter,
and Novita. Extracted here so each plugin stays a few lines long.
"""

from __future__ import annotations

from typing import List, Optional

from application.core.model_settings import AvailableModel


def get_api_key(
    settings,
    provider_name: str,
    provider_specific_key: Optional[str],
) -> Optional[str]:
    if provider_specific_key:
        return provider_specific_key
    if settings.LLM_PROVIDER == provider_name and settings.API_KEY:
        return settings.API_KEY
    return None


def filter_models_by_llm_name(
    settings,
    provider_name: str,
    provider_specific_key: Optional[str],
    models: List[AvailableModel],
) -> List[AvailableModel]:
    """Mirrors the historical ``_add_<X>_models`` selection logic.

    Behavior:
    - If the provider-specific API key is set → load all models.
    - Else if ``LLM_PROVIDER`` matches and ``LLM_NAME`` matches a known
      model → load just that model.
    - Otherwise → load all models (preserved "load anyway" branch from
      the original methods).
    """
    if provider_specific_key:
        return models
    if (
        settings.LLM_PROVIDER == provider_name
        and settings.LLM_NAME
    ):
        named = [m for m in models if m.id == settings.LLM_NAME]
        if named:
            return named
    return models

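The three branches in behavior, as a small self-contained check using
stand-in settings and model stubs; values are illustrative only:

```python
# Behavior sketch for filter_models_by_llm_name's three branches.
from types import SimpleNamespace

from application.llm.providers._apikey_or_llm_name import filter_models_by_llm_name

models = [SimpleNamespace(id="claude-haiku-4-5"), SimpleNamespace(id="claude-opus-4-7")]

# 1. Provider-specific key set -> whole catalog loads.
s = SimpleNamespace(LLM_PROVIDER="openai", LLM_NAME=None)
assert filter_models_by_llm_name(s, "anthropic", "sk-ant-placeholder", models) == models

# 2. LLM_PROVIDER matches and LLM_NAME names a known model -> just that one.
s = SimpleNamespace(LLM_PROVIDER="anthropic", LLM_NAME="claude-haiku-4-5")
assert [m.id for m in filter_models_by_llm_name(s, "anthropic", None, models)] == [
    "claude-haiku-4-5"
]

# 3. No key, no matching name -> the "load anyway" branch returns everything.
s = SimpleNamespace(LLM_PROVIDER="anthropic", LLM_NAME="unknown-model")
assert filter_models_by_llm_name(s, "anthropic", None, models) == models
```
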
application/llm/providers/anthropic.py (Normal file, 23 lines)
@@ -0,0 +1,23 @@
from __future__ import annotations

from typing import Optional

from application.llm.anthropic import AnthropicLLM
from application.llm.providers._apikey_or_llm_name import (
    filter_models_by_llm_name,
    get_api_key,
)
from application.llm.providers.base import Provider


class AnthropicProvider(Provider):
    name = "anthropic"
    llm_class = AnthropicLLM

    def get_api_key(self, settings) -> Optional[str]:
        return get_api_key(settings, self.name, settings.ANTHROPIC_API_KEY)

    def filter_yaml_models(self, settings, models):
        return filter_models_by_llm_name(
            settings, self.name, settings.ANTHROPIC_API_KEY, models
        )

application/llm/providers/azure_openai.py (Normal file, 30 lines)
@@ -0,0 +1,30 @@
from __future__ import annotations

from typing import Optional

from application.llm.openai import AzureOpenAILLM
from application.llm.providers.base import Provider


class AzureOpenAIProvider(Provider):
    name = "azure_openai"
    llm_class = AzureOpenAILLM

    def get_api_key(self, settings) -> Optional[str]:
        # Azure historically uses the generic API_KEY field.
        return settings.API_KEY

    def is_enabled(self, settings) -> bool:
        if settings.OPENAI_API_BASE:
            return True
        return settings.LLM_PROVIDER == self.name and bool(settings.API_KEY)

    def filter_yaml_models(self, settings, models):
        # Mirrors _add_azure_openai_models: when LLM_PROVIDER==azure_openai
        # and LLM_NAME matches a known model, narrow to that one model.
        # Otherwise load the entire catalog.
        if settings.LLM_PROVIDER == self.name and settings.LLM_NAME:
            named = [m for m in models if m.id == settings.LLM_NAME]
            if named:
                return named
        return models

application/llm/providers/base.py (Normal file, 74 lines)
@@ -0,0 +1,74 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, ClassVar, List, Optional, Type

if TYPE_CHECKING:
    from application.core.model_settings import AvailableModel
    from application.core.model_yaml import ProviderCatalog
    from application.core.settings import Settings
    from application.llm.base import BaseLLM


class Provider(ABC):
    """Owns the *behavior* of an LLM provider.

    Concrete providers declare their name, the LLM class to instantiate,
    and how to resolve credentials from settings. Static model catalogs
    live in YAML under ``application/core/models/`` and are joined to the
    provider by name at registry load time.

    Most plugins receive zero or one catalog at registry-build time. The
    ``openai_compatible`` plugin is the exception: it receives one catalog
    per matching YAML file, each with its own ``api_key_env`` and
    ``base_url``. Plugins that need per-catalog metadata override
    ``get_models``; the default implementation merges catalogs and routes
    through ``filter_yaml_models`` + ``extra_models``.
    """

    name: ClassVar[str]
    # ``None`` means the provider appears in the catalog but isn't
    # dispatchable through LLMCreator (e.g. Hugging Face today, where the
    # original LLMCreator dict had no entry).
    llm_class: ClassVar[Optional[Type["BaseLLM"]]] = None

    @abstractmethod
    def get_api_key(self, settings: "Settings") -> Optional[str]:
        """Return the API key for this provider, or None if unavailable."""

    def is_enabled(self, settings: "Settings") -> bool:
        """Whether this provider should contribute models to the registry."""
        return bool(self.get_api_key(settings))

    def filter_yaml_models(
        self, settings: "Settings", models: List["AvailableModel"]
    ) -> List["AvailableModel"]:
        """Hook to filter YAML-loaded models. Default: return all."""
        return models

    def extra_models(self, settings: "Settings") -> List["AvailableModel"]:
        """Hook to add dynamic models not declared in YAML. Default: none."""
        return []

    def get_models(
        self,
        settings: "Settings",
        catalogs: List["ProviderCatalog"],
    ) -> List["AvailableModel"]:
        """Final list of models this plugin contributes.

        Default: merge the models across all matched catalogs (later
        catalog wins on duplicate id), filter via ``filter_yaml_models``,
        then append ``extra_models``. Override when per-catalog metadata
        matters (see ``OpenAICompatibleProvider``).
        """
        merged: List["AvailableModel"] = []
        seen: dict = {}
        for c in catalogs:
            for m in c.models:
                if m.id in seen:
                    merged[seen[m.id]] = m
                else:
                    seen[m.id] = len(merged)
                    merged.append(m)
        return self.filter_yaml_models(settings, merged) + self.extra_models(settings)

application/llm/providers/docsgpt.py (Normal file, 22 lines)
@@ -0,0 +1,22 @@
from __future__ import annotations

from typing import Optional

from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.providers.base import Provider


class DocsGPTProvider(Provider):
    name = "docsgpt"
    llm_class = DocsGPTAPILLM

    def get_api_key(self, settings) -> Optional[str]:
        # No provider-specific key; the LLM class can use the generic
        # API_KEY fallback if it needs one. Mirrors model_utils' historical
        # behavior of returning settings.API_KEY when no specific key exists.
        return settings.API_KEY

    def is_enabled(self, settings) -> bool:
        # The hosted DocsGPT model is hidden when the deployment is
        # pointed at a custom OpenAI-compatible endpoint.
        return not settings.OPENAI_BASE_URL

application/llm/providers/google.py (Normal file, 23 lines)
@@ -0,0 +1,23 @@
from __future__ import annotations

from typing import Optional

from application.llm.google_ai import GoogleLLM
from application.llm.providers._apikey_or_llm_name import (
    filter_models_by_llm_name,
    get_api_key,
)
from application.llm.providers.base import Provider


class GoogleProvider(Provider):
    name = "google"
    llm_class = GoogleLLM

    def get_api_key(self, settings) -> Optional[str]:
        return get_api_key(settings, self.name, settings.GOOGLE_API_KEY)

    def filter_yaml_models(self, settings, models):
        return filter_models_by_llm_name(
            settings, self.name, settings.GOOGLE_API_KEY, models
        )

application/llm/providers/groq.py (Normal file, 23 lines)
@@ -0,0 +1,23 @@
from __future__ import annotations

from typing import Optional

from application.llm.groq import GroqLLM
from application.llm.providers._apikey_or_llm_name import (
    filter_models_by_llm_name,
    get_api_key,
)
from application.llm.providers.base import Provider


class GroqProvider(Provider):
    name = "groq"
    llm_class = GroqLLM

    def get_api_key(self, settings) -> Optional[str]:
        return get_api_key(settings, self.name, settings.GROQ_API_KEY)

    def filter_yaml_models(self, settings, models):
        return filter_models_by_llm_name(
            settings, self.name, settings.GROQ_API_KEY, models
        )

application/llm/providers/huggingface.py (Normal file, 25 lines)
@@ -0,0 +1,25 @@
from __future__ import annotations

from typing import Optional

from application.llm.providers._apikey_or_llm_name import (
    get_api_key as shared_get_api_key,
)
from application.llm.providers.base import Provider


class HuggingFaceProvider(Provider):
    """Surfaces ``huggingface-local`` to the model catalog.

    Not dispatchable through LLMCreator — historically there was no
    HuggingFaceLLM entry in ``LLMCreator.llms``, and calling ``create_llm``
    with ``"huggingface"`` raised ``ValueError``. We preserve that
    behavior: the model appears in ``/api/models`` but selecting it
    surfaces the same error it always did.
    """

    name = "huggingface"
    llm_class = None  # not dispatchable

    def get_api_key(self, settings) -> Optional[str]:
        return shared_get_api_key(settings, self.name, settings.HUGGINGFACE_API_KEY)

application/llm/providers/llama_cpp.py (Normal file, 19 lines)
@@ -0,0 +1,19 @@
from __future__ import annotations

from typing import Optional

from application.llm.llama_cpp import LlamaCpp
from application.llm.providers.base import Provider


class LlamaCppProvider(Provider):
    """LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog."""

    name = "llama.cpp"
    llm_class = LlamaCpp

    def get_api_key(self, settings) -> Optional[str]:
        return settings.API_KEY

    def is_enabled(self, settings) -> bool:
        return False

application/llm/providers/novita.py (Normal file, 23 lines)
@@ -0,0 +1,23 @@
from __future__ import annotations

from typing import Optional

from application.llm.novita import NovitaLLM
from application.llm.providers._apikey_or_llm_name import (
    filter_models_by_llm_name,
    get_api_key,
)
from application.llm.providers.base import Provider


class NovitaProvider(Provider):
    name = "novita"
    llm_class = NovitaLLM

    def get_api_key(self, settings) -> Optional[str]:
        return get_api_key(settings, self.name, settings.NOVITA_API_KEY)

    def filter_yaml_models(self, settings, models):
        return filter_models_by_llm_name(
            settings, self.name, settings.NOVITA_API_KEY, models
        )

application/llm/providers/openai.py (Normal file, 37 lines)
@@ -0,0 +1,37 @@
from __future__ import annotations

from typing import Optional

from application.llm.openai import OpenAILLM
from application.llm.providers.base import Provider


class OpenAIProvider(Provider):
    name = "openai"
    llm_class = OpenAILLM

    def get_api_key(self, settings) -> Optional[str]:
        if settings.OPENAI_API_KEY:
            return settings.OPENAI_API_KEY
        if settings.LLM_PROVIDER == self.name and settings.API_KEY:
            return settings.API_KEY
        return None

    def is_enabled(self, settings) -> bool:
        # When the deployment is pointed at a custom OpenAI-compatible
        # endpoint (Ollama, LM Studio, ...), the cloud-OpenAI catalog is
        # suppressed but ``is_enabled`` stays True — necessary so the
        # filter below still gets to drop the catalog (rather than the
        # registry skipping the provider entirely and missing the rule).
        if settings.OPENAI_BASE_URL:
            return True
        return bool(self.get_api_key(settings))

    def filter_yaml_models(self, settings, models):
        # Legacy local-endpoint mode hides the cloud catalog. The
        # corresponding dynamic models live in OpenAICompatibleProvider.
        if settings.OPENAI_BASE_URL:
            return []
        if not settings.OPENAI_API_KEY:
            return []
        return models

application/llm/providers/openai_compatible.py (Normal file, 149 lines)
@@ -0,0 +1,149 @@
"""Generic provider for OpenAI-wire-compatible endpoints.

Each ``openai_compatible`` YAML file describes one logical endpoint
(Mistral, Together, Fireworks, Ollama, ...) with its own
``api_key_env`` and ``base_url``. Multiple files can coexist; the
plugin produces one set of models per file, each pre-configured with
the right credentials and URL.

The plugin also handles the **legacy** ``OPENAI_BASE_URL`` + ``LLM_NAME``
local-endpoint pattern that previously lived in ``OpenAIProvider``. That
path generates models dynamically from ``LLM_NAME``, using
``OPENAI_BASE_URL`` and ``OPENAI_API_KEY`` as the endpoint config.
"""

from __future__ import annotations

import logging
import os
from typing import List, Optional

from application.core.model_settings import (
    AvailableModel,
    ModelCapabilities,
    ModelProvider,
)
from application.llm.openai import OpenAILLM
from application.llm.providers.base import Provider

logger = logging.getLogger(__name__)


def _parse_model_names(llm_name: Optional[str]) -> List[str]:
    if not llm_name:
        return []
    return [name.strip() for name in llm_name.split(",") if name.strip()]


class OpenAICompatibleProvider(Provider):
    name = "openai_compatible"
    llm_class = OpenAILLM

    def get_api_key(self, settings) -> Optional[str]:
        # Per-model: each catalog supplies its own ``api_key_env``. There
        # is no single plugin-wide key. LLMCreator reads the per-model
        # ``api_key`` set during catalog materialization.
        return None

    def is_enabled(self, settings) -> bool:
        # Concrete enablement happens per catalog (in ``get_models``).
        # Returning True lets the registry call ``get_models`` so we can
        # decide per-file whether to contribute models.
        return True

    def get_models(self, settings, catalogs) -> List[AvailableModel]:
        out: List[AvailableModel] = []

        for catalog in catalogs:
            out.extend(self._materialize_yaml_catalog(catalog))

        if settings.OPENAI_BASE_URL and settings.LLM_NAME:
            out.extend(self._materialize_legacy_local_endpoint(settings))

        return out

    def _materialize_yaml_catalog(self, catalog) -> List[AvailableModel]:
        """Resolve one openai_compatible YAML into ready-to-dispatch models.

        Skipped (with an INFO-level log) if ``api_key_env`` resolves to
        nothing — no point publishing models the user can't actually
        call. INFO rather than WARNING because operators may legitimately
        drop multiple provider YAMLs as templates and only set the env
        vars for the ones they actually use; a missing key is ambiguous,
        not necessarily a misconfig.
        """
        if not catalog.base_url:
            raise ValueError(
                f"{catalog.source_path}: openai_compatible YAML must set "
                "'base_url'."
            )
        if not catalog.api_key_env:
            raise ValueError(
                f"{catalog.source_path}: openai_compatible YAML must set "
                "'api_key_env'."
            )

        api_key = os.environ.get(catalog.api_key_env)
        if not api_key:
            logger.info(
                "openai_compatible catalog %s skipped: env var %s is not set",
                catalog.source_path,
                catalog.api_key_env,
            )
            return []

        out: List[AvailableModel] = []
        for m in catalog.models:
            out.append(self._with_endpoint(m, catalog.base_url, api_key))
        return out

    def _materialize_legacy_local_endpoint(self, settings) -> List[AvailableModel]:
        """Generate AvailableModels from ``LLM_NAME`` for the legacy
        ``OPENAI_BASE_URL`` deployment pattern (Ollama, LM Studio, ...).

        Preserves the historical ``provider="openai"`` display behavior
        by setting ``display_provider="openai"``.
        """
        from application.core.model_yaml import resolve_attachment_alias

        attachments = resolve_attachment_alias("image")
        api_key = settings.OPENAI_API_KEY or settings.API_KEY
        out: List[AvailableModel] = []
        for model_name in _parse_model_names(settings.LLM_NAME):
            out.append(
                AvailableModel(
                    id=model_name,
                    provider=ModelProvider.OPENAI_COMPATIBLE,
                    display_name=model_name,
                    description=f"Custom OpenAI-compatible model at {settings.OPENAI_BASE_URL}",
                    base_url=settings.OPENAI_BASE_URL,
                    capabilities=ModelCapabilities(
                        supports_tools=True,
                        supported_attachment_types=attachments,
                    ),
                    api_key=api_key,
                    display_provider="openai",
                )
            )
        return out

    @staticmethod
    def _with_endpoint(
        model: AvailableModel, base_url: str, api_key: str
    ) -> AvailableModel:
        """Return a copy of ``model`` carrying the catalog's endpoint config.

        The catalog-level ``base_url`` is the default; an explicit
        per-model ``base_url`` in the YAML wins.
        """
        return AvailableModel(
            id=model.id,
            provider=model.provider,
            display_name=model.display_name,
            description=model.description,
            capabilities=model.capabilities,
            enabled=model.enabled,
            base_url=model.base_url or base_url,
            display_provider=model.display_provider,
            api_key=api_key,
        )

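For reference, the module-level helper's behavior as a quick check with
illustrative inputs:

```python
# Quick behavioral check of _parse_model_names (illustrative inputs):
# None yields an empty list; comma-separated names are stripped and
# empty fragments dropped.
assert _parse_model_names(None) == []
assert _parse_model_names("llama3.2, qwen2.5-coder , ") == [
    "llama3.2",
    "qwen2.5-coder",
]
```
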
application/llm/providers/openrouter.py (Normal file, 23 lines)
@@ -0,0 +1,23 @@
from __future__ import annotations

from typing import Optional

from application.llm.open_router import OpenRouterLLM
from application.llm.providers._apikey_or_llm_name import (
    filter_models_by_llm_name,
    get_api_key,
)
from application.llm.providers.base import Provider


class OpenRouterProvider(Provider):
    name = "openrouter"
    llm_class = OpenRouterLLM

    def get_api_key(self, settings) -> Optional[str]:
        return get_api_key(settings, self.name, settings.OPEN_ROUTER_API_KEY)

    def filter_yaml_models(self, settings, models):
        return filter_models_by_llm_name(
            settings, self.name, settings.OPEN_ROUTER_API_KEY, models
        )

application/llm/providers/premai.py (Normal file, 19 lines)
@@ -0,0 +1,19 @@
from __future__ import annotations

from typing import Optional

from application.llm.premai import PremAILLM
from application.llm.providers.base import Provider


class PremAIProvider(Provider):
    """LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog."""

    name = "premai"
    llm_class = PremAILLM

    def get_api_key(self, settings) -> Optional[str]:
        return settings.API_KEY

    def is_enabled(self, settings) -> bool:
        return False

application/llm/providers/sagemaker.py (Normal file, 24 lines)
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import Optional

from application.llm.sagemaker import SagemakerAPILLM
from application.llm.providers.base import Provider


class SagemakerProvider(Provider):
    """LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog.

    SageMaker reads its credentials from ``SAGEMAKER_*`` settings inside
    the LLM class itself; this plugin's ``get_api_key`` exists only for
    LLMCreator's symmetry.
    """

    name = "sagemaker"
    llm_class = SagemakerAPILLM

    def get_api_key(self, settings) -> Optional[str]:
        return settings.API_KEY

    def is_enabled(self, settings) -> bool:
        return False

@@ -82,6 +82,7 @@ python-dateutil==2.9.0.post0
python-dotenv
python-jose==3.5.0
python-pptx==1.0.2
+PyYAML
redis==7.4.0
referencing>=0.28.0,<0.38.0
regex==2026.4.4

@@ -22,6 +22,7 @@ class ClassicRAG(BaseRetriever):
        llm_name=settings.LLM_PROVIDER,
        api_key=settings.API_KEY,
        decoded_token=None,
+       model_user_id=None,
    ):
        self.original_question = source.get("question", "")
        self.chat_history = chat_history if chat_history is not None else []

@@ -42,17 +43,22 @@ class ClassicRAG(BaseRetriever):
            f"sources={'active_docs' in source and source['active_docs'] is not None}"
        )
        self.model_id = model_id
+       self.model_user_id = model_user_id
        self.doc_token_limit = doc_token_limit
        self.user_api_key = user_api_key
        self.agent_id = agent_id
        self.llm_name = llm_name
        self.api_key = api_key
+       # Forward model_id + model_user_id so LLMCreator resolves BYOM
+       # base_url / api_key / upstream id for the rephrase client.
        self.llm = LLMCreator.create_llm(
            self.llm_name,
            api_key=self.api_key,
            user_api_key=self.user_api_key,
            decoded_token=decoded_token,
+           model_id=self.model_id,
            agent_id=self.agent_id,
+           model_user_id=self.model_user_id,
        )

        if "active_docs" in source and source["active_docs"] is not None:

@@ -103,7 +109,11 @@ class ClassicRAG(BaseRetriever):
        ]

        try:
-           rephrased_query = self.llm.gen(model=self.model_id, messages=messages)
+           # Send upstream id (resolved by LLMCreator), not registry UUID.
+           rephrased_query = self.llm.gen(
+               model=getattr(self.llm, "model_id", None) or self.model_id,
+               messages=messages,
+           )
            print(f"Rephrased query: {rephrased_query}")
            return rephrased_query if rephrased_query else self.original_question
        except Exception as e:

464
application/security/safe_url.py
Normal file
464
application/security/safe_url.py
Normal file
@@ -0,0 +1,464 @@
|
||||
"""SSRF protection for user-supplied OpenAI-compatible base URLs.
|
||||
|
||||
This module is the single chokepoint for validating any URL that a user
|
||||
provides as an OpenAI-compatible ``base_url`` ("Bring Your Own Model").
|
||||
The backend will later issue outbound HTTP requests to that URL on the
|
||||
user's behalf, so we must reject anything that could be used to reach
|
||||
internal-network resources (cloud metadata services, RFC 1918 ranges,
|
||||
loopback, link-local, etc.).
|
||||
|
||||
Three entry points:
|
||||
|
||||
* :func:`validate_user_base_url` — called at create/update time on REST
|
||||
routes that persist the URL, to give the user immediate feedback.
|
||||
* :func:`pinned_post` — called at dispatch time when the caller drives
|
||||
``requests`` directly (e.g. the ``/api/models/test`` endpoint).
|
||||
Resolves once, dials the IP literal, preserves the original hostname
|
||||
in the ``Host`` header and via SNI / cert verification for HTTPS.
|
||||
* :func:`pinned_httpx_client` — called at dispatch time when the caller
|
||||
hands an ``httpx.Client`` to a third-party SDK (e.g. the OpenAI
|
||||
Python SDK via ``OpenAI(http_client=...)``). Same DNS-rebinding
|
||||
closure on the httpx transport layer.
|
||||
|
||||
Why all three: the OpenAI / httpx ecosystem performs its own DNS lookup
|
||||
inside ``socket.getaddrinfo`` when a connection opens, so a hostile DNS
|
||||
server can hand a public IP to the validator and a loopback / link-local
|
||||
address to the HTTP client. Validate-then-construct-SDK is unsafe; the
|
||||
pinned variants close that TOCTOU window by resolving exactly once and
|
||||
dialing the chosen IP literal directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations

import ipaddress
import socket
from typing import Any, Iterable
from urllib.parse import SplitResult, urlsplit, urlunsplit

import httpx
import requests
from requests.adapters import HTTPAdapter

# Allowed URL schemes. Anything else (file, gopher, ftp, data, ...) is
# rejected outright because it either bypasses HTTP entirely or enables
# protocol smuggling against the proxy stack.
_ALLOWED_SCHEMES: frozenset[str] = frozenset({"http", "https"})

# Hostnames that resolve to a loopback / metadata / unspecified address
# but which we want to reject *by name* as well, so the rejection
# message is unambiguous and so we never accidentally call DNS on them.
_BLOCKED_HOSTNAMES: frozenset[str] = frozenset(
    {
        "localhost",
        "localhost.localdomain",
        "0.0.0.0",
        "::",
        "::1",
        "ip6-localhost",
        "ip6-loopback",
        # GCP metadata service. AWS/Azure use 169.254.169.254 which the
        # IP-range check below already covers via the link-local range,
        # but Google's hostname does not always resolve to a link-local
        # IP from every VPC, so we hard-deny the string too.
        "metadata.google.internal",
    }
)

# Carrier-grade NAT (RFC 6598). Python's ``ipaddress`` module does NOT
# classify this range as ``is_private``, so we must check it explicitly.
_CGNAT_NETWORK_V4: ipaddress.IPv4Network = ipaddress.IPv4Network("100.64.0.0/10")


class UnsafeUserUrlError(ValueError):
    """Raised when a user-supplied URL fails SSRF validation.

    Subclasses :class:`ValueError` so call sites that already treat
    invalid input as a 400-class error continue to work. The string
    message names the specific reason (scheme, hostname, resolved IP,
    DNS failure, ...) so that it can be surfaced to the user verbatim.
    """


def _strip_ipv6_brackets(host: str) -> str:
    """Return ``host`` with surrounding ``[`` / ``]`` removed if present."""

    if host.startswith("[") and host.endswith("]"):
        return host[1:-1]
    return host


def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
    """Return ``True`` if ``ip`` falls in any range we refuse to dial.

    This is the single source of truth for the IP-range policy:

    * loopback (``127.0.0.0/8``, ``::1``)
    * private (RFC 1918, ULA ``fc00::/7``)
    * link-local (``169.254.0.0/16``, ``fe80::/10``)
    * multicast (``224.0.0.0/4``, ``ff00::/8``)
    * unspecified (``0.0.0.0``, ``::``)
    * reserved (``240.0.0.0/4``, etc.)
    * carrier-grade NAT (``100.64.0.0/10``) — not covered by ``is_private``
    """

    if (
        ip.is_loopback
        or ip.is_private
        or ip.is_link_local
        or ip.is_multicast
        or ip.is_unspecified
        or ip.is_reserved
    ):
        return True
    if isinstance(ip, ipaddress.IPv4Address) and ip in _CGNAT_NETWORK_V4:
        return True
    return False


def _resolve(host: str) -> Iterable[ipaddress.IPv4Address | ipaddress.IPv6Address]:
    """Resolve ``host`` to every A/AAAA record returned by the system.

    Returning *all* addresses (rather than the first one) is critical:
    a hostile DNS server can return a public IP first followed by a
    private IP, and the underlying HTTP client may fail over to the
    private one on connect. We treat the set as unsafe if any element
    is unsafe.
    """

    try:
        results = socket.getaddrinfo(host, None)
    except socket.gaierror as exc:  # noqa: PERF203 — re-raise as our own type
        raise UnsafeUserUrlError(f"could not resolve hostname {host!r}: {exc}") from exc

    addresses: list[ipaddress.IPv4Address | ipaddress.IPv6Address] = []
    for entry in results:
        sockaddr = entry[4]
        # IPv4 sockaddr: (host, port). IPv6 sockaddr: (host, port, flowinfo, scope_id).
        ip_str = sockaddr[0]
        # Strip IPv6 zone-id ("fe80::1%lo0") before parsing.
        if "%" in ip_str:
            ip_str = ip_str.split("%", 1)[0]
        try:
            addresses.append(ipaddress.ip_address(ip_str))
        except ValueError:
            # An entry we can't parse is itself suspicious; treat as unsafe.
            raise UnsafeUserUrlError(
                f"hostname {host!r} resolved to unparseable address {ip_str!r}"
            ) from None
    return addresses


def _validate_and_pick_ip(
    url: str,
) -> tuple[str, ipaddress.IPv4Address | ipaddress.IPv6Address, SplitResult]:
    """Run the SSRF guard and return the data needed to dial safely.

    Performs every check :func:`validate_user_base_url` performs, but
    additionally returns ``(hostname, ip, parts)`` where ``ip`` is one
    of the validated addresses (the first record returned by the
    resolver, or the literal itself if the URL already used an IP) and
    ``parts`` is the :func:`urllib.parse.urlsplit` result so callers do
    not have to re-parse the URL.

    Raises :class:`UnsafeUserUrlError` on the same conditions as
    :func:`validate_user_base_url`.
    """

    if not isinstance(url, str) or not url.strip():
        raise UnsafeUserUrlError("url must be a non-empty string")

    try:
        parts = urlsplit(url)
    except ValueError as exc:
        raise UnsafeUserUrlError(f"could not parse url {url!r}: {exc}") from exc

    scheme = parts.scheme.lower()
    if scheme not in _ALLOWED_SCHEMES:
        raise UnsafeUserUrlError(
            f"scheme {scheme!r} is not allowed; only http and https are permitted"
        )

    # ``urlsplit`` returns the bracketed form for IPv6 in ``netloc`` but
    # the bare form in ``hostname``. Normalize via lower() because
    # hostnames are case-insensitive and we compare against a lowercase
    # blocklist.
    raw_host = parts.hostname
    if not raw_host:
        raise UnsafeUserUrlError(f"url {url!r} has no hostname")

    host = raw_host.lower()

    # Check the literal-string blocklist first. urlsplit().hostname strips
    # IPv6 brackets, so we also test the bracketed form for completeness
    # (matches the public-spec note about ``[::]``).
    bracketed = f"[{host}]"
    if host in _BLOCKED_HOSTNAMES or bracketed in _BLOCKED_HOSTNAMES:
        raise UnsafeUserUrlError(
            f"hostname {raw_host!r} is not allowed (matches internal-only name)"
        )

    # If the host is already an IP literal (with or without IPv6 brackets),
    # check it directly without going to DNS — DNS for an IP literal is a
    # no-op but it's clearer to short-circuit and gives a better message.
    candidate = _strip_ipv6_brackets(host)
    try:
        literal = ipaddress.ip_address(candidate)
    except ValueError:
        literal = None

    if literal is not None:
        if _is_blocked_ip(literal):
            raise UnsafeUserUrlError(
                f"hostname {raw_host!r} resolves to blocked address {literal} "
                f"(loopback/private/link-local/multicast/reserved/CGNAT)"
            )
        return host, literal, parts

    # Hostname (not an IP literal) — resolve and validate every record.
    addresses = list(_resolve(host))
    for ip in addresses:
        if _is_blocked_ip(ip):
            raise UnsafeUserUrlError(
                f"hostname {raw_host!r} resolves to blocked address {ip} "
                f"(loopback/private/link-local/multicast/reserved/CGNAT)"
            )
    if not addresses:
        # ``getaddrinfo`` would normally raise instead of returning an
        # empty list, but treat the degenerate case as unsafe too — we
        # have nothing to bind a connection to.
        raise UnsafeUserUrlError(
            f"hostname {raw_host!r} returned no addresses from DNS"
        )
    return host, addresses[0], parts


def validate_user_base_url(url: str) -> None:
    """Validate that ``url`` is safe to use as an outbound base URL.

    Resolve the URL's hostname to one or more IPs and reject if any
    resolved IP is private/loopback/link-local/multicast/reserved, or if
    the URL uses a non-http(s) scheme, or if the hostname is one of the
    known dangerous strings (``localhost``, ``0.0.0.0``, ``[::]``).

    Raises :class:`UnsafeUserUrlError` on rejection. Returns ``None`` on
    success.

    This function is the create/update-time check. At dispatch time use
    :func:`pinned_post` instead, which performs the same validation
    *and* pins the outbound connection to the validated IP so a DNS
    rebinder cannot flip the resolution between check and connect.

    Args:
        url: The user-supplied URL to validate. Expected to be an
            absolute URL with an ``http`` or ``https`` scheme.

    Raises:
        UnsafeUserUrlError: If the URL fails to parse, uses a forbidden
            scheme, has an empty/blocklisted hostname, fails DNS
            resolution, or resolves to any IP in a blocked range.
    """

    _validate_and_pick_ip(url)


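# Quick sanity examples (illustrative only; the Mistral URL assumes it
# resolves to public addresses at call time):
#
#   validate_user_base_url("https://api.mistral.ai/v1")   # returns None
#   validate_user_base_url("http://169.254.169.254/")     # raises: link-local IP

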
class _PinnedHostAdapter(HTTPAdapter):
    """HTTPS adapter that performs SNI and cert verification against a
    fixed hostname even when the URL connects to an IP literal.

    Used by :func:`pinned_post` so that resolving the user-supplied
    hostname once and dialing the resolved IP doesn't break TLS.
    Without this, ``urllib3`` would default ``server_hostname`` /
    ``assert_hostname`` to the connect host (the IP) and either send the
    wrong SNI or fail cert verification — the cert is for the original
    hostname, not the IP literal.
    """

    def __init__(self, server_hostname: str, *args: Any, **kwargs: Any) -> None:
        self._server_hostname = server_hostname
        super().__init__(*args, **kwargs)

    def init_poolmanager(self, *args: Any, **kwargs: Any) -> None:
        kwargs["server_hostname"] = self._server_hostname
        kwargs["assert_hostname"] = self._server_hostname
        super().init_poolmanager(*args, **kwargs)


def _ip_to_url_host(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> str:
    """Return ``ip`` formatted for use in a URL netloc (brackets for v6)."""

    if isinstance(ip, ipaddress.IPv6Address):
        return f"[{ip}]"
    return str(ip)


def pinned_post(
    url: str,
    *,
    json: Any = None,
    headers: dict[str, str] | None = None,
    timeout: float = 5.0,
    allow_redirects: bool = False,
) -> requests.Response:
    """POST to ``url`` with the outbound connection pinned to a single
    validated IP, closing the DNS-rebinding TOCTOU window left by the
    naive validate-then-``requests.post`` pattern.

    The URL's hostname is resolved exactly once. Every returned address
    must pass the same SSRF guard as :func:`validate_user_base_url`. The
    outbound request is issued against the chosen IP literal (so
    ``urllib3`` cannot ask the resolver again and receive a different
    answer); the original hostname is preserved in the ``Host`` header
    and, for HTTPS, via :class:`_PinnedHostAdapter` for SNI and cert
    verification.

    Args:
        url: Absolute http(s) URL to POST to.
        json: JSON-serializable payload — passed through to ``requests``.
        headers: Caller-supplied headers. Any caller-supplied ``Host``
            entry is overwritten so the in-flight request matches what
            was validated.
        timeout: Per-request timeout (seconds).
        allow_redirects: Forwarded to ``requests``. Defaults to
            ``False`` because the SSRF guard only inspects the supplied
            URL — following redirects would let a hostile upstream
            bounce the request to an internal address.

    Raises:
        UnsafeUserUrlError: If the URL fails the SSRF guard.
        requests.RequestException: For network-level failures.
    """

    host, ip, parts = _validate_and_pick_ip(url)

    netloc = _ip_to_url_host(ip)
    if parts.port is not None:
        netloc = f"{netloc}:{parts.port}"
    pinned_url = urlunsplit(
        (parts.scheme, netloc, parts.path, parts.query, parts.fragment)
    )

    request_headers = dict(headers or {})
    host_header = host if parts.port is None else f"{host}:{parts.port}"
    request_headers["Host"] = host_header

    session = requests.Session()
    if parts.scheme == "https":
        session.mount("https://", _PinnedHostAdapter(host))
    try:
        return session.post(
            pinned_url,
            json=json,
            headers=request_headers,
            timeout=timeout,
            allow_redirects=allow_redirects,
        )
    finally:
        session.close()


class _PinnedHTTPSTransport(httpx.HTTPTransport):
    """``httpx`` transport pinned to a single validated IP literal.

    Closes the DNS-rebinding TOCTOU window that
    :func:`validate_user_base_url` cannot close on its own. The OpenAI
    Python SDK (and any other SDK that uses ``httpx``) re-resolves the
    hostname inside ``socket.getaddrinfo`` at request time, so a
    hostile DNS server can return a public IP at validation time and a
    private IP at request time. This transport rewrites every outgoing
    request's URL host to the validated IP literal so ``httpcore``
    dials that IP without a fresh lookup.

    The original hostname is preserved in two places:

    1. ``Host`` header — ``httpx.Request._prepare`` sets it from the URL
       netloc *before* this transport runs, so it carries the hostname,
       not the IP literal. We deliberately do not touch headers here.
    2. TLS SNI / cert verification — set via the
       ``request.extensions["sni_hostname"]`` extension, which
       ``httpcore`` feeds into ``start_tls``'s ``server_hostname``
       parameter. Without this, ``urllib3``-equivalent code would use
       the IP literal as SNI and cert verification would fail (the
       cert is for the original hostname, not the IP).
    """

    def __init__(
        self,
        validated_host: str,
        validated_ip: ipaddress.IPv4Address | ipaddress.IPv6Address,
        **kwargs: Any,
    ) -> None:
        # http2=False (the httpx default) — defense in depth against
        # HTTP/2 connection coalescing (RFC 7540 §9.1.1), where a
        # client may reuse a TCP connection for any host whose cert
        # covers it. Per-IP pinning never shares connections across
        # hosts, but explicit is safer than relying on the default.
        kwargs.setdefault("http2", False)
        super().__init__(**kwargs)
        self._host = validated_host
        self._ip_netloc = _ip_to_url_host(validated_ip)

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        # Defense in depth: refuse if the request URL's host doesn't
        # match what we validated. Catches any future SDK regression
        # that rewrites the URL between Request construction and dial,
        # and any rare case where the SDK reuses our pinned client for
        # a different host (which it shouldn't, but assert it anyway).
        if request.url.host != self._host:
            raise UnsafeUserUrlError(
                f"pinned transport bound to {self._host!r}, refused "
                f"request for {request.url.host!r}"
            )
        # SNI/server_hostname for TLS verification. httpcore reads this
        # extension in its _sync/connection.py and feeds it into
        # start_tls's server_hostname argument. Set before the URL host
        # is rewritten so cert validation continues to use the original
        # hostname even though TCP dials the IP literal.
        request.extensions = {
            **request.extensions,
            "sni_hostname": self._host.encode("ascii"),
        }
        request.url = request.url.copy_with(host=self._ip_netloc)
        return super().handle_request(request)


def pinned_httpx_client(
    base_url: str,
    *,
    timeout: float = 600.0,
) -> httpx.Client:
    """Return an :class:`httpx.Client` whose connections are pinned to
    one validated IP, closing the DNS-rebinding TOCTOU window the naive
    ``OpenAI(base_url=...)`` flow leaves open.

    The hostname in ``base_url`` is resolved exactly once. Every
    returned address must pass :func:`_validate_and_pick_ip`'s SSRF
    guard (loopback, RFC 1918, link-local, multicast, reserved, CGNAT,
    cloud metadata names). The chosen IP becomes the URL host on every
    outgoing request so ``httpcore`` cannot ask the resolver again.

    Pass via ``OpenAI(http_client=pinned_httpx_client(base_url))`` (or
    any other SDK that accepts an ``httpx.Client``) to make BYOM
    dispatch immune to DNS-rebinding TOCTOU.

    Args:
        base_url: User-supplied http(s) URL. Validated through the same
            SSRF guard as :func:`validate_user_base_url`.
        timeout: Per-request timeout (seconds). Defaults to 600 to
            match the OpenAI SDK's default; callers should override
            for non-LLM workloads.

    Raises:
        UnsafeUserUrlError: If ``base_url`` fails the SSRF guard.
    """

    host, ip, _parts = _validate_and_pick_ip(base_url)
    transport = _PinnedHTTPSTransport(host, ip)
    # follow_redirects=False — the SSRF guard only inspects the
    # supplied URL; following 3xx would let a hostile upstream bounce
    # the in-network request to an internal address (cloud metadata,
    # RFC1918, loopback) carrying whatever credentials the SDK adds.
    return httpx.Client(
        transport=transport,
        timeout=timeout,
        follow_redirects=False,
    )
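

# Putting the three entry points together (an illustrative sketch; the
# OpenAI import and the probe payload are examples, not part of this
# module's API):
#
#   validate_user_base_url(base_url)                         # create/update time
#   pinned_post(f"{base_url}/chat/completions", json=probe)  # direct requests probe
#   client = OpenAI(base_url=base_url,
#                   http_client=pinned_httpx_client(base_url))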
@@ -203,6 +203,24 @@ agents_table = Table(
    Column("legacy_mongo_id", Text),
)

user_custom_models_table = Table(
    "user_custom_models",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("upstream_model_id", Text, nullable=False),
    Column("display_name", Text, nullable=False),
    Column("description", Text, nullable=False, server_default=""),
    Column("base_url", Text, nullable=False),
    # AES-CBC ciphertext (base64) keyed via per-user PBKDF2 in
    # application.security.encryption.encrypt_credentials.
    Column("api_key_encrypted", Text, nullable=False),
    Column("capabilities", JSONB, nullable=False, server_default="{}"),
    Column("enabled", Boolean, nullable=False, server_default="true"),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)

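# Roughly the DDL this table maps to (an illustrative sketch, not the
# project's actual migration):
#
#   CREATE TABLE user_custom_models (
#       id                uuid PRIMARY KEY DEFAULT gen_random_uuid(),
#       user_id           text NOT NULL,
#       upstream_model_id text NOT NULL,
#       display_name      text NOT NULL,
#       description       text NOT NULL DEFAULT '',
#       base_url          text NOT NULL,
#       api_key_encrypted text NOT NULL,
#       capabilities      jsonb NOT NULL DEFAULT '{}',
#       enabled           boolean NOT NULL DEFAULT true,
#       created_at        timestamptz NOT NULL DEFAULT now(),
#       updated_at        timestamptz NOT NULL DEFAULT now()
#   );
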
attachments_table = Table(
    "attachments",
    metadata,

@@ -1,7 +1,6 @@
"""Repository for the ``agents`` table.

This is the most complex Phase 2 repository. Covers every write operation
the legacy Mongo code performs on ``agents_collection``:
Covers every write operation the legacy Mongo code performs on ``agents_collection``:

- create, update, delete
- find by key (API key lookup)

199	application/storage/db/repositories/user_custom_models.py	Normal file
@@ -0,0 +1,199 @@
"""Repository for the ``user_custom_models`` table.
|
||||
|
||||
Backs the end-user "Bring Your Own Model" feature. Each row is one
|
||||
user-supplied OpenAI-compatible endpoint (Mistral, Together, vLLM, ...).
|
||||
The ``id`` UUID is the internal DocsGPT identifier (what agents store
|
||||
in ``default_model_id``); ``upstream_model_id`` is what we send verbatim
|
||||
to the provider's API.
|
||||
|
||||
API key handling: callers pass plaintext via ``api_key_plaintext``;
|
||||
this module wraps the existing ``application.security.encryption``
|
||||
helper (AES-CBC + per-user PBKDF2 salt) and writes the base64 ciphertext
|
||||
to the ``api_key_encrypted`` column. Decryption is the caller's
|
||||
responsibility (they hold the ``user_id``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, func, text
|
||||
|
||||
from application.security.encryption import (
|
||||
decrypt_credentials,
|
||||
encrypt_credentials,
|
||||
)
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.models import user_custom_models_table
|
||||
|
||||
|
||||
_ALLOWED_CAPABILITY_KEYS = frozenset(
    {
        "supports_tools",
        "supports_structured_output",
        "supports_streaming",
        "attachments",
        "context_window",
    }
)


class UserCustomModelsRepository:
    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    # ------------------------------------------------------------------ #
    # Encryption wrappers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _encrypt_api_key(api_key_plaintext: str, user_id: str) -> str:
        """Encrypt ``api_key_plaintext`` with the per-user PBKDF2 scheme."""
        return encrypt_credentials({"api_key": api_key_plaintext}, user_id)

    @staticmethod
    def _decrypt_api_key(api_key_encrypted: str, user_id: str) -> Optional[str]:
        """Decrypt the API key. Returns None on failure (which the caller
        should surface as a configuration error rather than silently
        proceeding with the upstream call)."""
        if not api_key_encrypted:
            return None
        creds = decrypt_credentials(api_key_encrypted, user_id)
        return creds.get("api_key") if creds else None

    @staticmethod
    def _normalize_capabilities(caps: Optional[dict]) -> dict:
        """Drop unknown keys; nothing else is forced. Callers (the route
        layer) are responsible for value validation (numeric ranges,
        attachment alias resolution)."""
        if not caps:
            return {}
        return {k: v for k, v in caps.items() if k in _ALLOWED_CAPABILITY_KEYS}

    # ------------------------------------------------------------------ #
    # CRUD
    # ------------------------------------------------------------------ #

    def create(
        self,
        user_id: str,
        upstream_model_id: str,
        display_name: str,
        base_url: str,
        api_key_plaintext: str,
        description: str = "",
        capabilities: Optional[dict] = None,
        enabled: bool = True,
    ) -> dict:
        values = {
            "user_id": user_id,
            "upstream_model_id": upstream_model_id,
            "display_name": display_name,
            "description": description or "",
            "base_url": base_url,
            "api_key_encrypted": self._encrypt_api_key(api_key_plaintext, user_id),
            "capabilities": self._normalize_capabilities(capabilities),
            "enabled": bool(enabled),
        }
        from sqlalchemy.dialects.postgresql import insert as pg_insert

        stmt = (
            pg_insert(user_custom_models_table)
            .values(**values)
            .returning(user_custom_models_table)
        )
        result = self._conn.execute(stmt)
        return row_to_dict(result.fetchone())

    def get(self, model_id: str, user_id: str) -> Optional[dict]:
        result = self._conn.execute(
            text(
                "SELECT * FROM user_custom_models "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": str(model_id), "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        result = self._conn.execute(
            text(
                "SELECT * FROM user_custom_models "
                "WHERE user_id = :user_id ORDER BY created_at DESC"
            ),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def update(self, model_id: str, user_id: str, fields: dict) -> bool:
        """Apply a partial update.

        Special-cases ``api_key_plaintext``: when present, it is encrypted
        and stored in ``api_key_encrypted``. When absent (or empty), the
        existing ciphertext is kept untouched. This is the wire-shape
        ``PATCH`` expects (the UI sends a blank password field when the
        operator wants to keep the existing key).
        """
        allowed = {
            "upstream_model_id",
            "display_name",
            "description",
            "base_url",
            "capabilities",
            "enabled",
        }
        values: dict[str, Any] = {}
        for col, val in fields.items():
            if col not in allowed or val is None:
                continue
            if col == "capabilities":
                values[col] = self._normalize_capabilities(val)
            elif col == "enabled":
                values[col] = bool(val)
            else:
                values[col] = val

        api_key_plaintext = fields.get("api_key_plaintext")
        if api_key_plaintext:
            values["api_key_encrypted"] = self._encrypt_api_key(
                api_key_plaintext, user_id
            )

        if not values:
            return False
        values["updated_at"] = func.now()

        t = user_custom_models_table
        stmt = (
            t.update()
            .where(t.c.id == str(model_id))
            .where(t.c.user_id == user_id)
            .values(**values)
        )
        result = self._conn.execute(stmt)
        return result.rowcount > 0

    def delete(self, model_id: str, user_id: str) -> bool:
        result = self._conn.execute(
            text(
                "DELETE FROM user_custom_models "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": str(model_id), "user_id": user_id},
        )
        return result.rowcount > 0

    # ------------------------------------------------------------------ #
    # Decryption helpers exposed to the registry layer
    # ------------------------------------------------------------------ #

    def get_decrypted_api_key(
        self, model_id: str, user_id: str
    ) -> Optional[str]:
        """Convenience: fetch the row and return the decrypted API key,
        or ``None`` if the row is missing or decryption fails."""
        row = self.get(model_id, user_id)
        if row is None:
            return None
        return self._decrypt_api_key(row.get("api_key_encrypted", ""), user_id)
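

# Illustrative round-trip (a sketch; connection management and the
# example values are assumptions, not fixed call sites):
#
#   repo = UserCustomModelsRepository(conn)
#   row = repo.create(user_id, "mistral-large-latest", "My Mistral",
#                     "https://api.mistral.ai/v1", "sk-...")
#   api_key = repo.get_decrypted_api_key(row["id"], user_id)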
@@ -83,9 +83,9 @@ def count_tokens_docs(docs):


def calculate_doc_token_budget(
    model_id: str = "gpt-4o"
    model_id: str = "gpt-4o", user_id: str | None = None
) -> int:
    total_context = get_token_limit(model_id)
    total_context = get_token_limit(model_id, user_id=user_id)
    reserved = sum(settings.RESERVED_TOKENS.values())
    doc_budget = total_context - reserved
    return max(doc_budget, 1000)
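# Worked example (illustrative numbers, not the shipped defaults): a
# 128_000-token context window with RESERVED_TOKENS summing to 8_000
# yields a 120_000-token doc budget; the max(..., 1000) floor guards
# against tiny or misconfigured context windows.
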
@@ -150,9 +150,11 @@ def get_hash(data):
    return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()


def limit_chat_history(history, max_token_limit=None, model_id="docsgpt-local"):
def limit_chat_history(
    history, max_token_limit=None, model_id="docsgpt-local", user_id=None
):
    """Limit chat history to fit within token limit."""
    model_token_limit = get_token_limit(model_id)
    model_token_limit = get_token_limit(model_id, user_id=user_id)
    max_token_limit = (
        max_token_limit
        if max_token_limit and max_token_limit < model_token_limit
@@ -204,7 +206,9 @@ def generate_image_url(image_path):


def calculate_compression_threshold(
    model_id: str, threshold_percentage: float = 0.8
    model_id: str,
    threshold_percentage: float = 0.8,
    user_id: str | None = None,
) -> int:
    """
    Calculate token threshold for triggering compression.
@@ -212,11 +216,13 @@ def calculate_compression_threshold(
    Args:
        model_id: Model identifier
        threshold_percentage: Percentage of context window (default 80%)
        user_id: When set, BYOM custom-model records (UUID-keyed) can be
            resolved for the context-window lookup.

    Returns:
        Token count threshold
    """
    total_context = get_token_limit(model_id)
    total_context = get_token_limit(model_id, user_id=user_id)
    threshold = int(total_context * threshold_percentage)
    return threshold


@@ -344,18 +344,34 @@ def run_agent_logic(agent_config, input_data):

    # Determine model_id: check agent's default_model_id, fallback to system default
    agent_default_model = agent_config.get("default_model_id", "")
    if agent_default_model and validate_model_id(agent_default_model):
    if agent_default_model and validate_model_id(
        agent_default_model, user_id=owner
    ):
        model_id = agent_default_model
    else:
        model_id = get_default_model_id()
        if agent_default_model:
            # Stored model_id no longer resolves in the registry. Log so
            # operators can detect bad YAML edits before users complain;
            # behavior matches the historical silent fallback.
            logging.warning(
                "Agent %s references unknown model_id %r; falling back to %r",
                agent_id,
                agent_default_model,
                model_id,
            )

    # Get provider and API key for the selected model
    provider = get_provider_from_model_id(model_id) if model_id else settings.LLM_PROVIDER
    provider = (
        get_provider_from_model_id(model_id, user_id=owner)
        if model_id
        else settings.LLM_PROVIDER
    )
    system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)

    # Calculate proper doc_token_limit based on model's context window
    doc_token_limit = calculate_doc_token_budget(
        model_id=model_id
        model_id=model_id, user_id=owner
    )

    retriever = RetrieverCreator.create_retriever(

@@ -99,6 +99,82 @@ EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2 # You can al

In this case, even though you are using Ollama locally, `LLM_PROVIDER` is set to `openai` because Ollama (and many other local inference engines) are designed to be API-compatible with OpenAI. `OPENAI_BASE_URL` points DocsGPT to the local Ollama server.

## Adding Custom Models (`MODELS_CONFIG_DIR`)

DocsGPT ships with a built-in catalog of models for the providers it
supports out of the box (OpenAI, Anthropic, Google, Groq, OpenRouter,
Novita, Azure OpenAI, Hugging Face, DocsGPT). To add **your own
models** without forking the repo — for example, a Mistral or Together
account, a self-hosted vLLM endpoint, or any other OpenAI-compatible
API — point `MODELS_CONFIG_DIR` at a directory of YAML files.

```
MODELS_CONFIG_DIR=/etc/docsgpt/models
MISTRAL_API_KEY=sk-...
```

A minimal YAML for one provider:

```yaml
# /etc/docsgpt/models/mistral.yaml
provider: openai_compatible
display_provider: mistral
api_key_env: MISTRAL_API_KEY
base_url: https://api.mistral.ai/v1
defaults:
  supports_tools: true
  context_window: 128000
models:
  - id: mistral-large-latest
    display_name: Mistral Large
  - id: mistral-small-latest
    display_name: Mistral Small
```

After restart, those models appear in `/api/models` and are selectable
in the UI. A working template lives at
`application/core/models/examples/mistral.yaml.example`.

**What you can do:**

- Add new `openai_compatible` providers (Mistral, Together, Fireworks,
  Ollama, vLLM, ...) — one YAML per provider, each with its own
  `api_key_env` and `base_url`.
- Extend an existing provider's catalog by dropping a YAML with the
  same `provider:` value as the built-in (e.g. `provider: anthropic`
  with extra models).
- Override a built-in model's capabilities by re-declaring the same
  `id` — later wins, override is logged at `WARNING`; see the sketch
  below.
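
For example, a capability override might look like this (a hedged
sketch: the provider key, model id, and the assumption that
`context_window` may be set per-model follow the schema above, but
check the built-in entry you want to override):

```yaml
# /etc/docsgpt/models/overrides.yaml
provider: openai
models:
  - id: gpt-4o            # same id as the built-in entry; later wins
    display_name: GPT-4o (64k cap)
    context_window: 64000
```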

**What you cannot do via `MODELS_CONFIG_DIR`:** add a brand-new
non-OpenAI provider. That requires a Python plugin under
`application/llm/providers/`. See
`application/core/models/README.md` for the full schema reference.

### Docker

Mount the directory and set the env var:

```yaml
# docker-compose.yml
services:
  app:
    image: arc53/docsgpt
    environment:
      MODELS_CONFIG_DIR: /etc/docsgpt/models
      MISTRAL_API_KEY: ${MISTRAL_API_KEY}
    volumes:
      - ./my-models:/etc/docsgpt/models:ro
```

### Misconfiguration

If `MODELS_CONFIG_DIR` is set but the path doesn't exist (or isn't a
directory), the app logs a `WARNING` at boot and continues with just
the built-in catalog — it does **not** fail to start. If a YAML
declares an unknown provider name or has a schema error, the app
**does** fail to start, with the offending file path in the message.

## Speech-to-Text Settings

DocsGPT can transcribe audio in two places:

@@ -448,7 +448,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
    setUserTools(tools);
  };
  const getModels = async () => {
    const response = await modelService.getModels(null);
    const response = await modelService.getModels(token);
    if (!response.ok) throw new Error('Failed to fetch models');
    const data = await response.json();
    const transformed = modelService.transformModels(data.models || []);
@@ -1041,10 +1041,24 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
        isOpen={isModelsPopupOpen}
        onClose={() => setIsModelsPopupOpen(false)}
        anchorRef={modelAnchorButtonRef}
        options={availableModels.map((model) => ({
          id: model.id,
          label: model.display_name,
        }))}
        options={(() => {
          const builtinLabel = t(
            'settings.customModels.modelsGroup.builtin',
          );
          const userLabel = t('settings.customModels.modelsGroup.user');
          const builtin: OptionType[] = [];
          const user: OptionType[] = [];
          availableModels.forEach((model) => {
            const opt: OptionType = {
              id: model.id,
              label: model.display_name,
              group: model.source === 'user' ? userLabel : builtinLabel,
            };
            if (model.source === 'user') user.push(opt);
            else builtin.push(opt);
          });
          return [...builtin, ...user];
        })()}
        selectedIds={selectedModelIds}
        onSelectionChange={(newSelectedIds: Set<string | number>) =>
          setSelectedModelIds(

@@ -42,7 +42,9 @@ import { MultiSelect } from '@/components/ui/multi-select';
import {
  Select,
  SelectContent,
  SelectGroup,
  SelectItem,
  SelectLabel,
  SelectTrigger,
  SelectValue,
} from '@/components/ui/select';
@@ -706,7 +708,7 @@ function WorkflowBuilderInner() {
  useEffect(() => {
    const loadModelsAndTools = async () => {
      try {
        const modelsResponse = await modelService.getModels(null);
        const modelsResponse = await modelService.getModels(token);
        if (modelsResponse.ok) {
          const modelsData = await modelsResponse.json();
          const transformedModels = modelService.transformModels(
@@ -732,7 +734,7 @@ function WorkflowBuilderInner() {
      }
    };
    loadModelsAndTools();
  }, []);
  }, [token]);

  useEffect(() => {
    if (!selectedNode || selectedNode.type !== 'agent') return;
@@ -1847,15 +1849,54 @@ function WorkflowBuilderInner() {
                <SelectValue placeholder="Select a model" />
              </SelectTrigger>
              <SelectContent>
                {availableModels.map((model) => (
                  <SelectItem
                    key={model.id}
                    value={model.id}
                  >
                    {model.display_name} ·{' '}
                    {model.provider}
                  </SelectItem>
                ))}
                {(() => {
                  const builtin = availableModels.filter(
                    (m) => m.source !== 'user',
                  );
                  const user = availableModels.filter(
                    (m) => m.source === 'user',
                  );
                  return (
                    <>
                      {builtin.length > 0 && (
                        <SelectGroup>
                          <SelectLabel>
                            {t(
                              'settings.customModels.modelsGroup.builtin',
                            )}
                          </SelectLabel>
                          {builtin.map((model) => (
                            <SelectItem
                              key={model.id}
                              value={model.id}
                            >
                              {model.display_name} ·{' '}
                              {model.provider}
                            </SelectItem>
                          ))}
                        </SelectGroup>
                      )}
                      {user.length > 0 && (
                        <SelectGroup>
                          <SelectLabel>
                            {t(
                              'settings.customModels.modelsGroup.user',
                            )}
                          </SelectLabel>
                          {user.map((model) => (
                            <SelectItem
                              key={model.id}
                              value={model.id}
                            >
                              {model.display_name} ·{' '}
                              {model.provider}
                            </SelectItem>
                          ))}
                        </SelectGroup>
                      )}
                    </>
                  );
                })()}
              </SelectContent>
            </Select>
          </div>

@@ -80,6 +80,22 @@ const apiClient = {
      return response;
    }),

  patch: (
    url: string,
    data: any,
    token: string | null,
    headers = {},
    signal?: AbortSignal,
  ): Promise<any> =>
    fetch(`${baseURL}${url}`, {
      method: 'PATCH',
      headers: getHeaders(token, headers),
      body: JSON.stringify(data),
      signal,
    }).then((response) => {
      return response;
    }),

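  // Illustrative call (used below by customModelsService.updateCustomModel;
  // the payload here is only an example):
  //   await apiClient.patch(`/api/user/models/${id}`, { enabled: false }, token);
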
  putFormData: (
    url: string,
    formData: FormData,

@@ -76,6 +76,10 @@ const endpoints = {
    GET_ARTIFACT: (artifactId: string) => `/api/artifact/${artifactId}`,
    WORKFLOWS: '/api/workflows',
    WORKFLOW: (id: string) => `/api/workflows/${id}`,
    CUSTOM_MODELS: '/api/user/models',
    CUSTOM_MODEL: (id: string) => `/api/user/models/${id}`,
    CUSTOM_MODEL_TEST: (id: string) => `/api/user/models/${id}/test`,
    CUSTOM_MODEL_TEST_PAYLOAD: '/api/user/models/test',
  },
  V1: {
    CHAT_COMPLETIONS: '/v1/chat/completions',

162	frontend/src/api/services/customModelsService.ts	Normal file
@@ -0,0 +1,162 @@
import apiClient from '../client';
import endpoints from '../endpoints';

import type {
  CreateCustomModelPayload,
  CustomModel,
  CustomModelTestResult,
} from '../../models/types';

const parseJsonOrError = async (response: Response): Promise<any> => {
  const text = await response.text();
  let body: any = null;
  if (text) {
    try {
      body = JSON.parse(text);
    } catch {
      body = null;
    }
  }
  if (!response.ok) {
    const message =
      (body && (body.error || body.message)) ||
      `Request failed with status ${response.status}`;
    const err = new Error(message) as Error & {
      status?: number;
      payload?: unknown;
    };
    err.status = response.status;
    err.payload = body;
    throw err;
  }
  return body;
};

const customModelsService = {
  listCustomModels: async (token: string | null): Promise<CustomModel[]> => {
    const response = await apiClient.get(endpoints.USER.CUSTOM_MODELS, token);
    const data = await parseJsonOrError(response);
    if (Array.isArray(data)) return data as CustomModel[];
    if (data && Array.isArray(data.models)) return data.models as CustomModel[];
    return [];
  },

  createCustomModel: async (
    payload: CreateCustomModelPayload,
    token: string | null,
  ): Promise<CustomModel> => {
    const response = await apiClient.post(
      endpoints.USER.CUSTOM_MODELS,
      payload,
      token,
    );
    return (await parseJsonOrError(response)) as CustomModel;
  },

  updateCustomModel: async (
    id: string,
    payload: Partial<CreateCustomModelPayload>,
    token: string | null,
  ): Promise<CustomModel> => {
    const response = await apiClient.patch(
      endpoints.USER.CUSTOM_MODEL(id),
      payload,
      token,
    );
    return (await parseJsonOrError(response)) as CustomModel;
  },

  deleteCustomModel: async (
    id: string,
    token: string | null,
  ): Promise<void> => {
    const response = await apiClient.delete(
      endpoints.USER.CUSTOM_MODEL(id),
      token,
    );
    if (!response.ok) {
      await parseJsonOrError(response);
    }
  },

  testCustomModelPayload: async (
    payload: {
      base_url: string;
      api_key: string;
      upstream_model_id: string;
    },
    token: string | null,
  ): Promise<CustomModelTestResult> => {
    const response = await apiClient.post(
      endpoints.USER.CUSTOM_MODEL_TEST_PAYLOAD,
      payload,
      token,
    );
    const text = await response.text();
    let body: any = null;
    if (text) {
      try {
        body = JSON.parse(text);
      } catch {
        body = null;
      }
    }
    if (!response.ok) {
      return {
        ok: false,
        error:
          (body && (body.error || body.message)) ||
          `Test failed with status ${response.status}`,
      };
    }
    if (body && typeof body.ok === 'boolean') {
      return body as CustomModelTestResult;
    }
    return { ok: true };
  },

  testCustomModel: async (
    id: string,
    token: string | null,
    overrides: {
      base_url?: string;
      api_key?: string;
      upstream_model_id?: string;
    } = {},
  ): Promise<CustomModelTestResult> => {
    // Send only non-empty overrides; server falls back to stored values.
    const requestBody: Record<string, string> = {};
    if (overrides.base_url) requestBody.base_url = overrides.base_url;
    if (overrides.api_key) requestBody.api_key = overrides.api_key;
    if (overrides.upstream_model_id)
      requestBody.upstream_model_id = overrides.upstream_model_id;
    const response = await apiClient.post(
      endpoints.USER.CUSTOM_MODEL_TEST(id),
      requestBody,
      token,
    );
    const text = await response.text();
    let body: any = null;
    if (text) {
      try {
        body = JSON.parse(text);
      } catch {
        body = null;
      }
    }
    if (!response.ok) {
      return {
        ok: false,
        error:
          (body && (body.error || body.message)) ||
          `Test failed with status ${response.status}`,
      };
    }
    if (body && typeof body.ok === 'boolean') {
      return body as CustomModelTestResult;
    }
    return { ok: true };
  },
};

export default customModelsService;
@@ -19,6 +19,7 @@ const modelService = {
      supports_tools: model.supports_tools,
      supports_structured_output: model.supports_structured_output,
      supports_streaming: model.supports_streaming,
      source: model.source,
    })),
  };


@@ -7,6 +7,7 @@ import RoundedTick from '../assets/rounded-tick.svg';
import {
  selectAvailableModels,
  selectSelectedModel,
  selectToken,
  setAvailableModels,
  setModelsLoading,
  setSelectedModel,
@@ -18,17 +19,26 @@ export default function DropdownModel() {
  const dispatch = useDispatch();
  const selectedModel = useSelector(selectSelectedModel);
  const availableModels = useSelector(selectAvailableModels);
  const token = useSelector(selectToken);
  const dropdownRef = React.useRef<HTMLDivElement>(null);
  // Tracks which token the cached availableModels were loaded for.
  // Without this, the early-return below pins the anonymous/built-in
  // list forever once it's populated — login/logout never refetches
  // and a user's BYOM models stay invisible.
  const lastLoadedTokenRef = React.useRef<string | null | undefined>(undefined);
  const [isOpen, setIsOpen] = React.useState(false);

  useEffect(() => {
    const loadModels = async () => {
      if ((availableModels?.length ?? 0) > 0) {
      if (
        (availableModels?.length ?? 0) > 0 &&
        lastLoadedTokenRef.current === token
      ) {
        return;
      }
      dispatch(setModelsLoading(true));
      try {
        const response = await modelService.getModels(null);
        const response = await modelService.getModels(token);
        if (!response.ok) {
          throw new Error(`API error: ${response.status}`);
        }
@@ -37,6 +47,7 @@ export default function DropdownModel() {
        const transformed = modelService.transformModels(models);

        dispatch(setAvailableModels(transformed));
        lastLoadedTokenRef.current = token;
        if (!selectedModel && transformed.length > 0) {
          const defaultModel =
            transformed.find((m) => m.id === data.default_model_id) ||
@@ -59,7 +70,7 @@ export default function DropdownModel() {
    };

    loadModels();
  }, [availableModels?.length, dispatch, selectedModel]);
  }, [availableModels?.length, dispatch, selectedModel, token]);

  const handleClickOutside = (event: MouseEvent) => {
    if (

@@ -11,6 +11,7 @@ export type OptionType = {
  id: string | number;
  label: string;
  icon?: string | React.ReactNode;
  group?: string;
  [key: string]: any;
};

@@ -227,43 +228,75 @@ export default function MultiSelectPopup({
            </p>
          </div>
        ) : (
          filteredOptions.map((option) => {
            const isSelected = selectedIds.has(option.id);
            return (
              <div
                key={option.id}
                onClick={() => handleOptionClick(option.id)}
                className="dark:border-border dark:hover:bg-accent hover:bg-accent flex cursor-pointer items-center justify-between border-b border-[#D9D9D9] p-3 last:border-b-0"
                role="option"
                aria-selected={isSelected}
              >
                <div className="mr-3 flex grow items-center overflow-hidden">
                  {option.icon && renderIcon(option.icon)}
                  <p
                    className="overflow-hidden text-sm font-medium text-ellipsis whitespace-nowrap text-gray-900 dark:text-white"
                    title={option.label}
                  >
                    {option.label}
                  </p>
                </div>
                <div className="shrink-0">
                  <div
                    className={`border-border bg-card flex h-4 w-4 items-center justify-center rounded-xs border-2`}
                    aria-hidden="true"
                  >
                    {isSelected && (
                      <img
                        src={CheckmarkIcon}
                        alt="checkmark"
                        width={10}
                        height={10}
                      />
                    )}
          (() => {
            const hasGroups = filteredOptions.some((o) => !!o.group);
            const renderOption = (option: OptionType) => {
              const isSelected = selectedIds.has(option.id);
              return (
                <div
                  key={option.id}
                  onClick={() => handleOptionClick(option.id)}
                  className="dark:border-border dark:hover:bg-accent hover:bg-accent flex cursor-pointer items-center justify-between border-b border-[#D9D9D9] p-3 last:border-b-0"
                  role="option"
                  aria-selected={isSelected}
                >
                  <div className="mr-3 flex grow items-center overflow-hidden">
                    {option.icon && renderIcon(option.icon)}
                    <p
                      className="overflow-hidden text-sm font-medium text-ellipsis whitespace-nowrap text-gray-900 dark:text-white"
                      title={option.label}
                    >
                      {option.label}
                    </p>
                  </div>
                  <div className="shrink-0">
                    <div
                      className={`border-border bg-card flex h-4 w-4 items-center justify-center rounded-xs border-2`}
                      aria-hidden="true"
                    >
                      {isSelected && (
                        <img
                          src={CheckmarkIcon}
                          alt="checkmark"
                          width={10}
                          height={10}
                        />
                      )}
                    </div>
                  </div>
                </div>
              );
            };

            if (!hasGroups) {
              return filteredOptions.map(renderOption);
            }

            const groupOrder: string[] = [];
            const groupMap = new Map<string, OptionType[]>();
            filteredOptions.forEach((opt) => {
              const key = opt.group || '';
              if (!groupMap.has(key)) {
                groupOrder.push(key);
                groupMap.set(key, []);
              }
              groupMap.get(key)!.push(opt);
            });

            return groupOrder.map((groupKey) => (
              <div key={`group-${groupKey || 'ungrouped'}`}>
                {groupKey && (
                  <div
                    className="bg-muted/50 dark:bg-card text-muted-foreground sticky top-0 z-10 border-b border-[#D9D9D9] px-3 py-1.5 text-xs font-semibold uppercase dark:border-[#2E2E2E]"
                    role="presentation"
                  >
                    {groupKey}
                  </div>
                )}
                {(groupMap.get(groupKey) || []).map(renderOption)}
              </div>
            );
          })
            ));
          })()
        )}
      </div>
    )}

@@ -14,6 +14,7 @@ const useTabs = () => {
    t('settings.analytics.label'),
    t('settings.logs.label'),
    t('settings.tools.label'),
    t('settings.customModels.label'),
  ];
  return tabs;
};

@@ -244,6 +244,71 @@
        }
      }
    },
    "customModels": {
      "label": "Eigene Modelle",
      "subtitle": "Verbinde dein eigenes OpenAI-kompatibles Modell.",
      "addModel": "Modell hinzufügen",
      "empty": "Keine Modelle gefunden",
      "searchPlaceholder": "Modelle suchen...",
      "modalSubtitle": "Verbinde einen beliebigen OpenAI-kompatiblen Endpunkt (Mistral, Together, vLLM usw.).",
      "addTitle": "Eigenes Modell hinzufügen",
      "editTitle": "Eigenes Modell bearbeiten",
      "save": "Speichern",
      "saving": "Speichern",
      "testConnection": "Verbindung testen",
      "testing": "Test läuft",
      "testSuccess": "Verbindung erfolgreich.",
      "testHintNew": "Fülle Basis-URL, Modell-ID und API-Schlüssel aus, um Verbindungstests zu aktivieren.",
      "disabledBadge": "Deaktiviert",
      "deleteWarning": "Möchtest du das eigene Modell \"{{modelName}}\" wirklich löschen?",
      "actionsMenuAria": "Aktionen für {{modelName}}",
      "modelsGroup": {
        "builtin": "DocsGPT-Modelle",
        "user": "Meine Modelle"
      },
      "actions": {
        "edit": "Bearbeiten",
        "delete": "Löschen"
      },
      "fields": {
        "displayName": "Anzeigename",
        "modelId": "Modell-ID",
        "description": "Beschreibung",
        "baseUrl": "Basis-URL",
        "apiKey": "API-Schlüssel"
      },
      "placeholders": {
        "displayName": "Mein Mistral",
        "modelId": "z. B. mistral-large-latest",
        "description": "Optionale Beschreibung",
        "baseUrl": "https://api.mistral.ai/v1",
        "apiKey": "sk-...",
        "apiKeyEdit": "Leer lassen, um den vorhandenen Schlüssel beizubehalten"
      },
      "hints": {
        "apiKeyEdit": "Leer lassen, um den vorhandenen Schlüssel beizubehalten."
      },
      "capabilities": {
        "title": "Fähigkeiten",
        "contextWindowShort": "Kontextfenster",
        "chips": {
          "tools": "Tools",
          "structuredOutput": "Strukturierte Ausgabe",
          "images": "Bilder"
        }
      },
      "errors": {
        "displayNameRequired": "Anzeigename ist erforderlich",
        "modelIdRequired": "Modell-ID ist erforderlich",
        "baseUrlRequired": "Basis-URL ist erforderlich",
        "baseUrlScheme": "Basis-URL muss mit http:// oder https:// beginnen",
        "baseUrlInvalid": "Bitte gib eine gültige URL ein",
        "apiKeyRequired": "API-Schlüssel ist erforderlich",
        "contextWindowRange": "Kontextfenster muss zwischen 1.000 und 10.000.000 liegen",
        "saveFailed": "Eigenes Modell konnte nicht gespeichert werden",
        "testFailed": "Verbindungstest fehlgeschlagen"
      }
    },
    "scrollTabsLeft": "Tabs nach links scrollen",
    "tabsAriaLabel": "Einstellungs-Tabs",
    "scrollTabsRight": "Tabs nach rechts scrollen"

@@ -256,6 +256,71 @@
        }
      }
    },
    "customModels": {
      "label": "Custom Models",
      "subtitle": "Bring your own OpenAI-compatible model.",
      "addModel": "Add Model",
      "empty": "No models found",
      "searchPlaceholder": "Search models...",
      "modalSubtitle": "Connect any OpenAI-compatible endpoint (Mistral, Together, vLLM, etc).",
      "addTitle": "Add Custom Model",
      "editTitle": "Edit Custom Model",
      "save": "Save",
      "saving": "Saving",
      "testConnection": "Test connection",
      "testing": "Testing",
      "testSuccess": "Connection successful.",
      "testHintNew": "Fill in Base URL, Model ID, and API key to enable connection tests.",
      "disabledBadge": "Disabled",
      "deleteWarning": "Are you sure you want to delete the custom model \"{{modelName}}\"?",
      "actionsMenuAria": "Actions for {{modelName}}",
      "modelsGroup": {
        "builtin": "DocsGPT Models",
        "user": "My Models"
      },
      "actions": {
        "edit": "Edit",
        "delete": "Delete"
      },
      "fields": {
        "displayName": "Display name",
        "modelId": "Model ID",
        "description": "Description",
        "baseUrl": "Base URL",
        "apiKey": "API key"
      },
      "placeholders": {
        "displayName": "My Mistral",
        "modelId": "e.g. mistral-large-latest",
        "description": "Optional description",
        "baseUrl": "https://api.mistral.ai/v1",
        "apiKey": "sk-...",
        "apiKeyEdit": "Leave blank to keep existing"
      },
      "hints": {
        "apiKeyEdit": "Leave blank to keep existing key."
      },
      "capabilities": {
        "title": "Capabilities",
        "contextWindowShort": "Context window",
        "chips": {
          "tools": "Tools",
          "structuredOutput": "Structured output",
          "images": "Images"
        }
      },
      "errors": {
        "displayNameRequired": "Display name is required",
        "modelIdRequired": "Model ID is required",
        "baseUrlRequired": "Base URL is required",
        "baseUrlScheme": "Base URL must start with http:// or https://",
        "baseUrlInvalid": "Please enter a valid URL",
        "apiKeyRequired": "API key is required",
        "contextWindowRange": "Context window must be between 1,000 and 10,000,000",
        "saveFailed": "Failed to save custom model",
        "testFailed": "Connection test failed"
      }
    },
    "scrollTabsLeft": "Scroll tabs left",
    "tabsAriaLabel": "Settings tabs",
    "scrollTabsRight": "Scroll tabs right"

@@ -244,6 +244,71 @@
        }
      }
    },
    "customModels": {
      "label": "Modelos personalizados",
      "subtitle": "Trae tu propio modelo compatible con OpenAI.",
      "addModel": "Añadir modelo",
      "empty": "No se encontraron modelos",
      "searchPlaceholder": "Buscar modelos...",
      "modalSubtitle": "Conecta cualquier endpoint compatible con OpenAI (Mistral, Together, vLLM, etc.).",
      "addTitle": "Añadir modelo personalizado",
      "editTitle": "Editar modelo personalizado",
      "save": "Guardar",
      "saving": "Guardando",
      "testConnection": "Probar conexión",
      "testing": "Probando",
      "testSuccess": "Conexión correcta.",
      "testHintNew": "Completa URL base, ID de modelo y clave API para habilitar las pruebas de conexión.",
      "disabledBadge": "Deshabilitado",
      "deleteWarning": "¿Seguro que quieres eliminar el modelo personalizado \"{{modelName}}\"?",
      "actionsMenuAria": "Acciones para {{modelName}}",
      "modelsGroup": {
        "builtin": "Modelos de DocsGPT",
        "user": "Mis modelos"
      },
      "actions": {
        "edit": "Editar",
        "delete": "Eliminar"
      },
      "fields": {
        "displayName": "Nombre para mostrar",
        "modelId": "ID del modelo",
        "description": "Descripción",
        "baseUrl": "URL base",
        "apiKey": "Clave API"
      },
      "placeholders": {
        "displayName": "Mi Mistral",
        "modelId": "p. ej. mistral-large-latest",
        "description": "Descripción opcional",
        "baseUrl": "https://api.mistral.ai/v1",
        "apiKey": "sk-...",
        "apiKeyEdit": "Déjalo en blanco para mantener la actual"
      },
      "hints": {
        "apiKeyEdit": "Déjalo en blanco para mantener la clave actual."
      },
      "capabilities": {
        "title": "Capacidades",
        "contextWindowShort": "Ventana de contexto",
        "chips": {
          "tools": "Herramientas",
          "structuredOutput": "Salida estructurada",
          "images": "Imágenes"
        }
      },
      "errors": {
        "displayNameRequired": "El nombre para mostrar es obligatorio",
        "modelIdRequired": "El ID del modelo es obligatorio",
        "baseUrlRequired": "La URL base es obligatoria",
        "baseUrlScheme": "La URL base debe empezar por http:// o https://",
        "baseUrlInvalid": "Introduce una URL válida",
        "apiKeyRequired": "La clave API es obligatoria",
        "contextWindowRange": "La ventana de contexto debe estar entre 1.000 y 10.000.000",
        "saveFailed": "No se pudo guardar el modelo personalizado",
        "testFailed": "La prueba de conexión falló"
      }
    },
    "scrollTabsLeft": "Desplazar pestañas a la izquierda",
    "tabsAriaLabel": "Pestañas de configuración",
    "scrollTabsRight": "Desplazar pestañas a la derecha"

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
  "label": "カスタムモデル",
  "subtitle": "OpenAI 互換のモデルを自分で接続できます。",
  "addModel": "モデルを追加",
  "empty": "モデルが見つかりません",
  "searchPlaceholder": "モデルを検索...",
  "modalSubtitle": "OpenAI 互換のエンドポイントを接続できます (Mistral、Together、vLLM など)。",
  "addTitle": "カスタムモデルを追加",
  "editTitle": "カスタムモデルを編集",
  "save": "保存",
  "saving": "保存中",
  "testConnection": "接続をテスト",
  "testing": "テスト中",
  "testSuccess": "接続に成功しました。",
  "testHintNew": "接続テストを有効にするには、ベース URL、モデル ID、API キーを入力してください。",
  "disabledBadge": "無効",
  "deleteWarning": "カスタムモデル「{{modelName}}」を削除してもよろしいですか?",
  "actionsMenuAria": "{{modelName}} のアクション",
  "modelsGroup": {
    "builtin": "DocsGPT モデル",
    "user": "マイモデル"
  },
  "actions": {
    "edit": "編集",
    "delete": "削除"
  },
  "fields": {
    "displayName": "表示名",
    "modelId": "モデル ID",
    "description": "説明",
    "baseUrl": "ベース URL",
    "apiKey": "API キー"
  },
  "placeholders": {
    "displayName": "My Mistral",
    "modelId": "例: mistral-large-latest",
    "description": "説明 (任意)",
    "baseUrl": "https://api.mistral.ai/v1",
    "apiKey": "sk-...",
    "apiKeyEdit": "既存のキーを保持する場合は空欄のまま"
  },
  "hints": {
    "apiKeyEdit": "既存のキーを保持するには空欄のままにしてください。"
  },
  "capabilities": {
    "title": "機能",
    "contextWindowShort": "コンテキストウィンドウ",
    "chips": {
      "tools": "ツール",
      "structuredOutput": "構造化出力",
      "images": "画像"
    }
  },
  "errors": {
    "displayNameRequired": "表示名は必須です",
    "modelIdRequired": "モデル ID は必須です",
    "baseUrlRequired": "ベース URL は必須です",
    "baseUrlScheme": "ベース URL は http:// または https:// で始まる必要があります",
    "baseUrlInvalid": "有効な URL を入力してください",
    "apiKeyRequired": "API キーは必須です",
    "contextWindowRange": "コンテキストウィンドウは 1,000 から 10,000,000 の範囲で指定してください",
    "saveFailed": "カスタムモデルの保存に失敗しました",
    "testFailed": "接続テストに失敗しました"
  }
},
"scrollTabsLeft": "タブを左にスクロール",
"tabsAriaLabel": "設定タブ",
"scrollTabsRight": "タブを右にスクロール"

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
  "label": "Свои модели",
  "subtitle": "Подключите свою модель, совместимую с OpenAI.",
  "addModel": "Добавить модель",
  "empty": "Модели не найдены",
  "searchPlaceholder": "Поиск моделей...",
  "modalSubtitle": "Подключите любой OpenAI-совместимый эндпоинт (Mistral, Together, vLLM и т. п.).",
  "addTitle": "Добавить свою модель",
  "editTitle": "Изменить свою модель",
  "save": "Сохранить",
  "saving": "Сохранение",
  "testConnection": "Проверить соединение",
  "testing": "Проверка",
  "testSuccess": "Соединение установлено.",
  "testHintNew": "Заполните Base URL, ID модели и API-ключ, чтобы включить проверку соединения.",
  "disabledBadge": "Отключено",
  "deleteWarning": "Удалить свою модель «{{modelName}}»?",
  "actionsMenuAria": "Действия для {{modelName}}",
  "modelsGroup": {
    "builtin": "Модели DocsGPT",
    "user": "Мои модели"
  },
  "actions": {
    "edit": "Изменить",
    "delete": "Удалить"
  },
  "fields": {
    "displayName": "Отображаемое имя",
    "modelId": "ID модели",
    "description": "Описание",
    "baseUrl": "Base URL",
    "apiKey": "API-ключ"
  },
  "placeholders": {
    "displayName": "Мой Mistral",
    "modelId": "напр. mistral-large-latest",
    "description": "Необязательное описание",
    "baseUrl": "https://api.mistral.ai/v1",
    "apiKey": "sk-...",
    "apiKeyEdit": "Оставьте пустым, чтобы сохранить текущий"
  },
  "hints": {
    "apiKeyEdit": "Оставьте пустым, чтобы сохранить текущий ключ."
  },
  "capabilities": {
    "title": "Возможности",
    "contextWindowShort": "Контекстное окно",
    "chips": {
      "tools": "Инструменты",
      "structuredOutput": "Структурированный вывод",
      "images": "Изображения"
    }
  },
  "errors": {
    "displayNameRequired": "Укажите отображаемое имя",
    "modelIdRequired": "Укажите ID модели",
    "baseUrlRequired": "Укажите Base URL",
    "baseUrlScheme": "Base URL должен начинаться с http:// или https://",
    "baseUrlInvalid": "Введите корректный URL",
    "apiKeyRequired": "Укажите API-ключ",
    "contextWindowRange": "Контекстное окно должно быть от 1 000 до 10 000 000",
    "saveFailed": "Не удалось сохранить свою модель",
    "testFailed": "Проверка соединения не удалась"
  }
},
"scrollTabsLeft": "Прокрутить вкладки влево",
"tabsAriaLabel": "Вкладки настроек",
"scrollTabsRight": "Прокрутить вкладки вправо"

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
  "label": "自訂模型",
  "subtitle": "接入你自己的 OpenAI 相容模型。",
  "addModel": "新增模型",
  "empty": "找不到模型",
  "searchPlaceholder": "搜尋模型...",
  "modalSubtitle": "可連接任何 OpenAI 相容端點(Mistral、Together、vLLM 等)。",
  "addTitle": "新增自訂模型",
  "editTitle": "編輯自訂模型",
  "save": "儲存",
  "saving": "儲存中",
  "testConnection": "測試連線",
  "testing": "測試中",
  "testSuccess": "連線成功。",
  "testHintNew": "請填寫 Base URL、模型 ID 與 API 金鑰以啟用連線測試。",
  "disabledBadge": "已停用",
  "deleteWarning": "確定要刪除自訂模型「{{modelName}}」嗎?",
  "actionsMenuAria": "{{modelName}} 的操作",
  "modelsGroup": {
    "builtin": "DocsGPT 模型",
    "user": "我的模型"
  },
  "actions": {
    "edit": "編輯",
    "delete": "刪除"
  },
  "fields": {
    "displayName": "顯示名稱",
    "modelId": "模型 ID",
    "description": "說明",
    "baseUrl": "Base URL",
    "apiKey": "API 金鑰"
  },
  "placeholders": {
    "displayName": "My Mistral",
    "modelId": "例:mistral-large-latest",
    "description": "選填說明",
    "baseUrl": "https://api.mistral.ai/v1",
    "apiKey": "sk-...",
    "apiKeyEdit": "留空則保留原金鑰"
  },
  "hints": {
    "apiKeyEdit": "留空以保留現有金鑰。"
  },
  "capabilities": {
    "title": "功能",
    "contextWindowShort": "脈絡視窗",
    "chips": {
      "tools": "工具",
      "structuredOutput": "結構化輸出",
      "images": "圖片"
    }
  },
  "errors": {
    "displayNameRequired": "顯示名稱為必填",
    "modelIdRequired": "模型 ID 為必填",
    "baseUrlRequired": "Base URL 為必填",
    "baseUrlScheme": "Base URL 必須以 http:// 或 https:// 開頭",
    "baseUrlInvalid": "請輸入有效的 URL",
    "apiKeyRequired": "API 金鑰為必填",
    "contextWindowRange": "脈絡視窗須介於 1,000 至 10,000,000 之間",
    "saveFailed": "儲存自訂模型失敗",
    "testFailed": "連線測試失敗"
  }
},
"scrollTabsLeft": "向左捲動標籤",
"tabsAriaLabel": "設定標籤",
"scrollTabsRight": "向右捲動標籤"

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
  "label": "自定义模型",
  "subtitle": "接入你自己的 OpenAI 兼容模型。",
  "addModel": "添加模型",
  "empty": "未找到模型",
  "searchPlaceholder": "搜索模型...",
  "modalSubtitle": "可连接任何 OpenAI 兼容的接口(Mistral、Together、vLLM 等)。",
  "addTitle": "添加自定义模型",
  "editTitle": "编辑自定义模型",
  "save": "保存",
  "saving": "保存中",
  "testConnection": "测试连接",
  "testing": "测试中",
  "testSuccess": "连接成功。",
  "testHintNew": "请填写 Base URL、模型 ID 和 API 密钥以启用连接测试。",
  "disabledBadge": "已禁用",
  "deleteWarning": "确定要删除自定义模型「{{modelName}}」吗?",
  "actionsMenuAria": "{{modelName}} 的操作",
  "modelsGroup": {
    "builtin": "DocsGPT 模型",
    "user": "我的模型"
  },
  "actions": {
    "edit": "编辑",
    "delete": "删除"
  },
  "fields": {
    "displayName": "显示名称",
    "modelId": "模型 ID",
    "description": "描述",
    "baseUrl": "Base URL",
    "apiKey": "API 密钥"
  },
  "placeholders": {
    "displayName": "My Mistral",
    "modelId": "如 mistral-large-latest",
    "description": "可选描述",
    "baseUrl": "https://api.mistral.ai/v1",
    "apiKey": "sk-...",
    "apiKeyEdit": "留空则保留原密钥"
  },
  "hints": {
    "apiKeyEdit": "留空以保留现有密钥。"
  },
  "capabilities": {
    "title": "功能",
    "contextWindowShort": "上下文窗口",
    "chips": {
      "tools": "工具",
      "structuredOutput": "结构化输出",
      "images": "图片"
    }
  },
  "errors": {
    "displayNameRequired": "显示名称必填",
    "modelIdRequired": "模型 ID 必填",
    "baseUrlRequired": "Base URL 必填",
    "baseUrlScheme": "Base URL 必须以 http:// 或 https:// 开头",
    "baseUrlInvalid": "请输入有效的 URL",
    "apiKeyRequired": "API 密钥必填",
    "contextWindowRange": "上下文窗口必须在 1,000 至 10,000,000 之间",
    "saveFailed": "保存自定义模型失败",
    "testFailed": "连接测试失败"
  }
},
"scrollTabsLeft": "向左滚动标签",
"tabsAriaLabel": "设置标签",
"scrollTabsRight": "向右滚动标签"

597
frontend/src/modals/CustomModelModal.tsx
Normal file
@@ -0,0 +1,597 @@
import { Check } from 'lucide-react';
import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelector } from 'react-redux';

import customModelsService from '../api/services/customModelsService';
import Spinner from '../components/Spinner';
import { Input } from '../components/ui/input';
import { Label } from '../components/ui/label';
import { ActiveState } from '../models/misc';
import { selectToken } from '../preferences/preferenceSlice';
import WrapperComponent from './WrapperModal';

import type { CreateCustomModelPayload, CustomModel } from '../models/types';

interface CustomModelModalProps {
  modalState: ActiveState;
  setModalState: (state: ActiveState) => void;
  model?: CustomModel | null;
  onSaved: (model: CustomModel) => void;
}

interface FormState {
  display_name: string;
  upstream_model_id: string;
  description: string;
  base_url: string;
  api_key: string;
  supports_tools: boolean;
  supports_structured_output: boolean;
  supports_images: boolean;
  context_window: number | '';
  enabled: boolean;
}

const DEFAULT_CONTEXT_WINDOW = 128000;
const MIN_CONTEXT_WINDOW = 1000;
const MAX_CONTEXT_WINDOW = 10_000_000;

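// Seeds the form; in edit mode api_key starts blank so the stored
// secret is never echoed back into the UI.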
const buildInitialFormState = (model?: CustomModel | null): FormState => {
  if (!model) {
    return {
      display_name: '',
      upstream_model_id: '',
      description: '',
      base_url: '',
      api_key: '',
      supports_tools: true,
      supports_structured_output: true,
      supports_images: false,
      context_window: DEFAULT_CONTEXT_WINDOW,
      enabled: true,
    };
  }
  const attachments = Array.isArray(model.capabilities?.attachments)
    ? model.capabilities.attachments
    : [];
  return {
    display_name: model.display_name || '',
    upstream_model_id: model.upstream_model_id || '',
    description: model.description || '',
    base_url: model.base_url || '',
    api_key: '',
    supports_tools: model.capabilities?.supports_tools ?? true,
    supports_structured_output:
      model.capabilities?.supports_structured_output ?? true,
    supports_images: attachments.includes('image'),
    context_window:
      model.capabilities?.context_window ?? DEFAULT_CONTEXT_WINDOW,
    enabled: model.enabled ?? true,
  };
};

export default function CustomModelModal({
  modalState,
  setModalState,
  model,
  onSaved,
}: CustomModelModalProps) {
  const { t } = useTranslation();
  const token = useSelector(selectToken);
  const isEditMode = !!model?.id;

  const [formData, setFormData] = useState<FormState>(() =>
    buildInitialFormState(model),
  );
  const [errors, setErrors] = useState<{ [key: string]: string }>({});
  const [saving, setSaving] = useState(false);
  const [testing, setTesting] = useState(false);
  const [testResult, setTestResult] = useState<{
    ok: boolean;
    message: string;
  } | null>(null);

  useEffect(() => {
    if (modalState === 'ACTIVE') {
      setFormData(buildInitialFormState(model));
      setErrors({});
      setTestResult(null);
      setSaving(false);
      setTesting(false);
    }
  }, [modalState, model]);

  const closeModal = () => {
    setModalState('INACTIVE');
  };

  const handleChange = <K extends keyof FormState>(
    name: K,
    value: FormState[K],
  ) => {
    setFormData((prev) => ({ ...prev, [name]: value }));
    if (errors[name as string] || errors.general) {
      setErrors((prev) => {
        const next = { ...prev };
        delete next[name as string];
        delete next.general;
        delete next.base_url_remote;
        return next;
      });
    }
    setTestResult(null);
  };

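  // Client-side checks mirror the backend's requirements; the server
  // still re-validates (the route tests below cover loopback/SSRF
  // rejection of the base URL).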
  const validate = (): boolean => {
    const newErrors: { [key: string]: string } = {};
    if (!formData.display_name.trim()) {
      newErrors.display_name = t(
        'settings.customModels.errors.displayNameRequired',
      );
    }
    if (!formData.upstream_model_id.trim()) {
      newErrors.upstream_model_id = t(
        'settings.customModels.errors.modelIdRequired',
      );
    }
    const trimmedUrl = formData.base_url.trim();
    if (!trimmedUrl) {
      newErrors.base_url = t('settings.customModels.errors.baseUrlRequired');
    } else if (!/^https?:\/\//i.test(trimmedUrl)) {
      newErrors.base_url = t('settings.customModels.errors.baseUrlScheme');
    } else {
      try {
        new URL(trimmedUrl);
      } catch {
        newErrors.base_url = t('settings.customModels.errors.baseUrlInvalid');
      }
    }
    if (!isEditMode && !formData.api_key.trim()) {
      newErrors.api_key = t('settings.customModels.errors.apiKeyRequired');
    }
    const ctxValue =
      formData.context_window === '' ? NaN : Number(formData.context_window);
    if (
      Number.isNaN(ctxValue) ||
      ctxValue < MIN_CONTEXT_WINDOW ||
      ctxValue > MAX_CONTEXT_WINDOW
    ) {
      newErrors.context_window = t(
        'settings.customModels.errors.contextWindowRange',
      );
    }

    setErrors(newErrors);
    return Object.keys(newErrors).length === 0;
  };

  const buildPayload = (): CreateCustomModelPayload => {
    const ctxValue =
      formData.context_window === ''
        ? DEFAULT_CONTEXT_WINDOW
        : Number(formData.context_window);
    const payload: CreateCustomModelPayload = {
      upstream_model_id: formData.upstream_model_id.trim(),
      display_name: formData.display_name.trim(),
      description: formData.description.trim(),
      base_url: formData.base_url.trim(),
      capabilities: {
        supports_tools: formData.supports_tools,
        supports_structured_output: formData.supports_structured_output,
        attachments: formData.supports_images ? ['image'] : [],
        context_window: ctxValue,
      },
      enabled: formData.enabled,
    };
    if (formData.api_key.trim()) {
      payload.api_key = formData.api_key.trim();
    }
    return payload;
  };

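  // Heuristic routing: connectivity/SSRF-style backend messages go to
  // the Base URL field; anything else surfaces as a general form error.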
  const mapErrorToField = (
    message: string,
  ): { field: string; message: string } => {
    const lower = message.toLowerCase();
    if (
      lower.includes('reachable') ||
      lower.includes('public internet') ||
      lower.includes('ssrf') ||
      lower.includes('url') ||
      lower.includes('host')
    ) {
      return { field: 'base_url_remote', message };
    }
    return { field: 'general', message };
  };

  const handleSave = async () => {
    if (!validate()) return;
    setSaving(true);
    setTestResult(null);
    try {
      const payload = buildPayload();
      const saved = isEditMode
        ? await customModelsService.updateCustomModel(model!.id, payload, token)
        : await customModelsService.createCustomModel(payload, token);
      onSaved(saved);
      closeModal();
    } catch (err) {
      const message =
        err instanceof Error
          ? err.message
          : t('settings.customModels.errors.saveFailed');
      const mapped = mapErrorToField(message);
      setErrors((prev) => ({ ...prev, [mapped.field]: mapped.message }));
    } finally {
      setSaving(false);
    }
  };

  // Edit mode allows blank api_key (by-id endpoint falls back to stored).
  const trimmedBaseUrl = formData.base_url.trim();
  const trimmedApiKey = formData.api_key.trim();
  const trimmedUpstreamId = formData.upstream_model_id.trim();
  const canTest = isEditMode
    ? !!(trimmedBaseUrl && trimmedUpstreamId)
    : !!(trimmedBaseUrl && trimmedApiKey && trimmedUpstreamId);
  const testDisabledHint = canTest
    ? undefined
    : t('settings.customModels.testHintNew');

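  // Edit mode tests via the by-id endpoint (which can reuse the stored
  // key); unsaved models test the raw payload endpoint instead.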
  const handleTest = async () => {
    if (!canTest) return;
    setTesting(true);
    setTestResult(null);
    try {
      const result =
        isEditMode && model?.id
          ? await customModelsService.testCustomModel(model.id, token, {
              base_url: trimmedBaseUrl,
              api_key: trimmedApiKey,
              upstream_model_id: trimmedUpstreamId,
            })
          : await customModelsService.testCustomModelPayload(
              {
                base_url: trimmedBaseUrl,
                api_key: trimmedApiKey,
                upstream_model_id: trimmedUpstreamId,
              },
              token,
            );
      if (result.ok) {
        setTestResult({
          ok: true,
          message: t('settings.customModels.testSuccess'),
        });
      } else {
        const message =
          result.error || t('settings.customModels.errors.testFailed');
        setTestResult({ ok: false, message });
        const mapped = mapErrorToField(message);
        if (mapped.field === 'base_url_remote') {
          setErrors((prev) => ({ ...prev, base_url_remote: message }));
        }
      }
    } catch (err) {
      const message =
        err instanceof Error
          ? err.message
          : t('settings.customModels.errors.testFailed');
      setTestResult({ ok: false, message });
    } finally {
      setTesting(false);
    }
  };

  if (modalState !== 'ACTIVE') return null;

  return (
    <WrapperComponent
      close={closeModal}
      isPerformingTask={saving}
      className="max-w-[600px] md:w-[80vw] lg:w-[60vw]"
    >
      <div className="flex h-full flex-col">
        <div className="px-2 py-2">
          <h2 className="text-foreground dark:text-foreground text-xl font-semibold">
            {isEditMode
              ? t('settings.customModels.editTitle')
              : t('settings.customModels.addTitle')}
          </h2>
          <p className="text-muted-foreground mt-2 text-sm">
            {t('settings.customModels.modalSubtitle')}
          </p>
        </div>
        <div className="flex-1 px-2">
          <div className="flex flex-col gap-4 px-0.5 py-4">
            {/* Row 1: Display name + Model ID side-by-side */}
            <div className="grid grid-cols-1 gap-4 sm:grid-cols-2">
              <div className="flex flex-col gap-1.5">
                <Label htmlFor="cm-display-name">
                  {t('settings.customModels.fields.displayName')}
                  <span className="text-red-500">*</span>
                </Label>
                <Input
                  id="cm-display-name"
                  type="text"
                  value={formData.display_name}
                  onChange={(e) => handleChange('display_name', e.target.value)}
                  placeholder={t(
                    'settings.customModels.placeholders.displayName',
                  )}
                  aria-invalid={!!errors.display_name || undefined}
                  className="rounded-xl"
                />
                {errors.display_name && (
                  <p className="text-destructive text-xs">
                    {errors.display_name}
                  </p>
                )}
              </div>

              <div className="flex flex-col gap-1.5">
                <Label htmlFor="cm-model-id">
                  {t('settings.customModels.fields.modelId')}
                  <span className="text-red-500">*</span>
                </Label>
                <Input
                  id="cm-model-id"
                  type="text"
                  value={formData.upstream_model_id}
                  onChange={(e) =>
                    handleChange('upstream_model_id', e.target.value)
                  }
                  placeholder={t('settings.customModels.placeholders.modelId')}
                  aria-invalid={!!errors.upstream_model_id || undefined}
                  className="rounded-xl"
                />
                {errors.upstream_model_id && (
                  <p className="text-destructive text-xs">
                    {errors.upstream_model_id}
                  </p>
                )}
              </div>
            </div>

            {/* Row 2: Base URL + API key side-by-side */}
            <div className="grid grid-cols-1 gap-4 sm:grid-cols-2">
              <div className="flex flex-col gap-1.5">
                <Label htmlFor="cm-base-url">
                  {t('settings.customModels.fields.baseUrl')}
                  <span className="text-red-500">*</span>
                </Label>
                <Input
                  id="cm-base-url"
                  type="url"
                  value={formData.base_url}
                  onChange={(e) => handleChange('base_url', e.target.value)}
                  placeholder={t('settings.customModels.placeholders.baseUrl')}
                  aria-invalid={
                    !!errors.base_url || !!errors.base_url_remote || undefined
                  }
                  className="rounded-xl"
                />
                {errors.base_url && (
                  <p className="text-destructive text-xs">{errors.base_url}</p>
                )}
                {errors.base_url_remote && (
                  <p className="text-destructive text-xs">
                    {errors.base_url_remote}
                  </p>
                )}
              </div>

              <div className="flex flex-col gap-1.5">
                <Label htmlFor="cm-api-key">
                  {t('settings.customModels.fields.apiKey')}
                  {!isEditMode && <span className="text-red-500">*</span>}
                </Label>
                <Input
                  id="cm-api-key"
                  type="password"
                  autoComplete="new-password"
                  value={formData.api_key}
                  onChange={(e) => handleChange('api_key', e.target.value)}
                  placeholder={
                    isEditMode
                      ? t('settings.customModels.placeholders.apiKeyEdit')
                      : t('settings.customModels.placeholders.apiKey')
                  }
                  aria-invalid={!!errors.api_key || undefined}
                  className="rounded-xl"
                />
                {isEditMode && (
                  <p className="text-muted-foreground text-xs">
                    {t('settings.customModels.hints.apiKeyEdit')}
                  </p>
                )}
                {errors.api_key && (
                  <p className="text-destructive text-xs">{errors.api_key}</p>
                )}
              </div>
            </div>

            {/* Row 3: Description (full width, optional) */}
            <div className="flex flex-col gap-1.5">
              <Label htmlFor="cm-description">
                {t('settings.customModels.fields.description')}
              </Label>
              <Input
                id="cm-description"
                type="text"
                value={formData.description}
                onChange={(e) => handleChange('description', e.target.value)}
                placeholder={t(
                  'settings.customModels.placeholders.description',
                )}
                className="rounded-xl"
              />
            </div>

            {/* Row 4: Capabilities — flat (no border), chips + inline ctx */}
            <div className="flex flex-col gap-2">
              <Label>{t('settings.customModels.capabilities.title')}</Label>
              <div className="flex flex-wrap gap-2">
                <CapabilityChip
                  label={t('settings.customModels.capabilities.chips.tools')}
                  active={formData.supports_tools}
                  onClick={() =>
                    handleChange('supports_tools', !formData.supports_tools)
                  }
                />
                <CapabilityChip
                  label={t(
                    'settings.customModels.capabilities.chips.structuredOutput',
                  )}
                  active={formData.supports_structured_output}
                  onClick={() =>
                    handleChange(
                      'supports_structured_output',
                      !formData.supports_structured_output,
                    )
                  }
                />
                <CapabilityChip
                  label={t('settings.customModels.capabilities.chips.images')}
                  active={formData.supports_images}
                  onClick={() =>
                    handleChange('supports_images', !formData.supports_images)
                  }
                />
              </div>
              <div className="mt-1 flex flex-col gap-1.5 sm:flex-row sm:items-center sm:gap-3">
                <Label
                  htmlFor="cm-context-window"
                  className="text-muted-foreground text-xs sm:mb-0"
                >
                  {t('settings.customModels.capabilities.contextWindowShort')}
                </Label>
                <Input
                  id="cm-context-window"
                  type="number"
                  value={formData.context_window}
                  min={MIN_CONTEXT_WINDOW}
                  max={MAX_CONTEXT_WINDOW}
                  step={1000}
                  onChange={(e) => {
                    const v = e.target.value;
                    if (v === '') {
                      handleChange('context_window', '');
                    } else {
                      const n = parseInt(v, 10);
                      if (!Number.isNaN(n)) {
                        handleChange('context_window', n);
                      }
                    }
                  }}
                  aria-invalid={!!errors.context_window || undefined}
                  className="w-full rounded-xl sm:w-40"
                />
                {errors.context_window && (
                  <p className="text-destructive text-xs">
                    {errors.context_window}
                  </p>
                )}
              </div>
            </div>

            {testResult && (
              <div
                className={`rounded-xl p-3 text-sm ${
                  testResult.ok
                    ? 'bg-green-50 text-green-700 dark:bg-green-900/40 dark:text-green-300'
                    : 'bg-red-50 text-red-700 dark:bg-red-900/40 dark:text-red-300'
                }`}
              >
                {testResult.message}
              </div>
            )}

            {errors.general && (
              <div className="rounded-xl bg-red-50 p-3 text-sm text-red-700 dark:bg-red-900/40 dark:text-red-300">
                {errors.general}
              </div>
            )}
          </div>
        </div>

        <div className="px-2 py-4">
          <div className="flex flex-col gap-3 sm:flex-row sm:justify-between">
            <button
              type="button"
              onClick={handleTest}
              disabled={!canTest || testing || saving}
              title={testDisabledHint}
              className="border-border dark:border-border dark:text-foreground hover:bg-accent dark:hover:bg-muted/50 w-full rounded-3xl border px-6 py-2 text-sm font-medium transition-all disabled:cursor-not-allowed disabled:opacity-50 sm:w-auto"
            >
              {testing ? (
                <div className="flex items-center justify-center">
                  <Spinner size="small" />
                  <span className="ml-2">
                    {t('settings.customModels.testing')}
                  </span>
                </div>
              ) : (
                t('settings.customModels.testConnection')
              )}
            </button>
            <div className="flex flex-col-reverse gap-3 sm:flex-row sm:gap-3">
              <button
                type="button"
                onClick={closeModal}
                disabled={saving}
                className="dark:text-foreground hover:bg-accent dark:hover:bg-muted/50 w-full cursor-pointer rounded-3xl px-6 py-2 text-sm font-medium disabled:opacity-50 sm:w-auto"
              >
                {t('cancel')}
              </button>
              <button
                type="button"
                onClick={handleSave}
                disabled={saving}
                className="bg-primary hover:bg-primary/90 w-full rounded-3xl px-6 py-2 text-sm font-medium text-white transition-all disabled:opacity-50 sm:w-auto"
              >
                {saving ? (
                  <div className="flex items-center justify-center">
                    <Spinner size="small" />
                    <span className="ml-2">
                      {t('settings.customModels.saving')}
                    </span>
                  </div>
                ) : (
                  t('settings.customModels.save')
                )}
              </button>
            </div>
          </div>
        </div>
      </div>
    </WrapperComponent>
  );
}

interface CapabilityChipProps {
  label: string;
  active: boolean;
  onClick: () => void;
}

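// Pill-style toggle exposed as role="switch" for screen readers; an
// active chip gains the check icon and emerald styling.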
function CapabilityChip({ label, active, onClick }: CapabilityChipProps) {
  return (
    <button
      type="button"
      role="switch"
      aria-checked={active}
      onClick={onClick}
      className={`inline-flex items-center gap-1.5 rounded-full border px-3 py-1.5 text-sm transition-colors ${
        active
          ? 'border-emerald-500/50 bg-emerald-500/10 text-emerald-700 dark:border-emerald-400/40 dark:bg-emerald-400/10 dark:text-emerald-300'
          : 'border-border text-muted-foreground hover:bg-accent dark:hover:bg-muted/40'
      }`}
    >
      {active && <Check size={14} strokeWidth={2.5} />}
      {label}
    </button>
  );
}
@@ -1,3 +1,5 @@
export type ModelSource = 'builtin' | 'user';

export interface AvailableModel {
  id: string;
  provider: string;
@@ -9,6 +11,7 @@ export interface AvailableModel {
  supports_structured_output: boolean;
  supports_streaming: boolean;
  enabled: boolean;
  source?: ModelSource;
}

export interface Model {
@@ -22,4 +25,38 @@ export interface Model {
  supports_tools: boolean;
  supports_structured_output: boolean;
  supports_streaming: boolean;
  source?: ModelSource;
}

export interface CustomModelCapabilities {
  supports_tools: boolean;
  supports_structured_output: boolean;
  attachments: string[];
  context_window: number;
}

export interface CustomModel {
  id: string;
  upstream_model_id: string;
  display_name: string;
  description?: string;
  base_url: string;
  capabilities: CustomModelCapabilities;
  enabled: boolean;
  source: 'user';
}

export interface CreateCustomModelPayload {
  upstream_model_id: string;
  display_name: string;
  description?: string;
  base_url: string;
  api_key?: string;
  capabilities: CustomModelCapabilities;
  enabled: boolean;
}

export interface CustomModelTestResult {
  ok: boolean;
  error?: string;
}

@@ -250,6 +250,22 @@ prefListenerMiddleware.startListening({
  },
});

// Reconcile selectedModel when availableModels changes so a deleted
// BYOM doesn't leave a stale id in localStorage.
prefListenerMiddleware.startListening({
  matcher: isAnyOf(setAvailableModels),
  effect: (_action, listenerApi) => {
    const state = listenerApi.getState() as RootState;
    const { selectedModel, availableModels } = state.preference;
    if (!availableModels.length) return;
    if (!selectedModel) return;
    const stillValid = availableModels.some((m) => m.id === selectedModel.id);
    if (!stillValid) {
      listenerApi.dispatch(setSelectedModel(availableModels[0]));
    }
  },
});

export const selectApiKey = (state: RootState) => state.preference.apiKey;
export const selectApiKeyStatus = (state: RootState) =>
  !!state.preference.apiKey;

313
frontend/src/settings/CustomModels.tsx
Normal file
@@ -0,0 +1,313 @@
import { Globe, Tag, Trash } from 'lucide-react';
import React from 'react';
import { useTranslation } from 'react-i18next';
import { useDispatch, useSelector } from 'react-redux';

import customModelsService from '../api/services/customModelsService';
import modelService from '../api/services/modelService';
import Edit from '../assets/edit.svg';
import NoFilesDarkIcon from '../assets/no-files-dark.svg';
import NoFilesIcon from '../assets/no-files.svg';
import SearchIcon from '../assets/search.svg';
import ThreeDotsIcon from '../assets/three-dots.svg';
import ContextMenu, { MenuOption } from '../components/ContextMenu';
import SkeletonLoader from '../components/SkeletonLoader';
import { useDarkTheme, useLoaderState } from '../hooks';
import ConfirmationModal from '../modals/ConfirmationModal';
import CustomModelModal from '../modals/CustomModelModal';
import { ActiveState } from '../models/misc';
import {
  selectToken,
  setAvailableModels,
} from '../preferences/preferenceSlice';

import type { CustomModel } from '../models/types';

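// Card display helper: show just the host of the base URL, falling
// back to naive string slicing when URL parsing fails.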
const formatBaseUrlHost = (baseUrl: string): string => {
  if (!baseUrl) return '';
  try {
    const url = new URL(baseUrl);
    return url.host || url.hostname || baseUrl;
  } catch {
    const stripped = baseUrl.replace(/^https?:\/\//i, '');
    const slashIdx = stripped.indexOf('/');
    return slashIdx >= 0 ? stripped.slice(0, slashIdx) : stripped;
  }
};

export default function CustomModels() {
  const { t } = useTranslation();
  const dispatch = useDispatch();
  const token = useSelector(selectToken);
  const [isDarkTheme] = useDarkTheme();

  const [models, setModels] = React.useState<CustomModel[]>([]);
  const [searchTerm, setSearchTerm] = React.useState('');
  const [loading, setLoading] = useLoaderState(false);
  const [modalState, setModalState] = React.useState<ActiveState>('INACTIVE');
  const [editingModel, setEditingModel] = React.useState<CustomModel | null>(
    null,
  );
  const [activeMenuId, setActiveMenuId] = React.useState<string | null>(null);
  const menuRefs = React.useRef<{
    [key: string]: React.RefObject<HTMLDivElement | null>;
  }>({});
  const [deleteState, setDeleteState] = React.useState<ActiveState>('INACTIVE');
  const [modelToDelete, setModelToDelete] = React.useState<CustomModel | null>(
    null,
  );

  // Ref instead of useCallback: useLoaderState returns a fresh setter
  // each render, which would loop the effect (thousands of req/s).
  const fetchModelsRef = React.useRef<() => Promise<void>>(async () => {});
  fetchModelsRef.current = async () => {
    setLoading(true);
    try {
      const data = await customModelsService.listCustomModels(token);
      setModels(data);
    } catch (err) {
      console.error('Failed to load custom models:', err);
      setModels([]);
    } finally {
      setLoading(false);
    }
  };

  React.useEffect(() => {
    fetchModelsRef.current();
  }, [token]);

  React.useEffect(() => {
    models.forEach((model) => {
      if (!menuRefs.current[model.id]) {
        menuRefs.current[model.id] = React.createRef<HTMLDivElement>();
      }
    });
  }, [models]);

  const openAddModal = () => {
    setEditingModel(null);
    setModalState('ACTIVE');
  };

  const openEditModal = (model: CustomModel) => {
    setEditingModel(model);
    setModalState('ACTIVE');
  };

  // Refresh Redux availableModels so the chat dropdown reconciles a
  // selectedModel UUID that was just deleted/disabled.
  const refreshGlobalAvailableModels = React.useCallback(async () => {
    try {
      const response = await modelService.getModels(token);
      if (!response.ok) return;
      const data = await response.json();
      const transformed = modelService.transformModels(data.models || []);
      dispatch(setAvailableModels(transformed));
    } catch (err) {
      console.error('Failed to refresh global available models:', err);
    }
  }, [dispatch, token]);

  const handleSaved = (saved: CustomModel) => {
    setModels((prev) => {
      const idx = prev.findIndex((m) => m.id === saved.id);
      if (idx === -1) return [saved, ...prev];
      const next = [...prev];
      next[idx] = saved;
      return next;
    });
    refreshGlobalAvailableModels();
  };

  const requestDelete = (model: CustomModel) => {
    setModelToDelete(model);
    setDeleteState('ACTIVE');
  };

  const confirmDelete = async () => {
    if (!modelToDelete) return;
    try {
      await customModelsService.deleteCustomModel(modelToDelete.id, token);
      setModels((prev) => prev.filter((m) => m.id !== modelToDelete.id));
      refreshGlobalAvailableModels();
    } catch (err) {
      console.error('Failed to delete custom model:', err);
    } finally {
      setModelToDelete(null);
      setDeleteState('INACTIVE');
    }
  };

  const getMenuOptions = (model: CustomModel): MenuOption[] => [
    {
      icon: Edit,
      label: t('settings.customModels.actions.edit'),
      onClick: () => openEditModal(model),
      variant: 'primary',
      iconWidth: 14,
      iconHeight: 14,
    },
    {
      icon: Trash,
      label: t('settings.customModels.actions.delete'),
      onClick: () => requestDelete(model),
      variant: 'danger',
      iconWidth: 16,
      iconHeight: 16,
    },
  ];

  const filteredModels = models.filter((model) => {
    const q = searchTerm.toLowerCase();
    return (
      model.display_name.toLowerCase().includes(q) ||
      model.upstream_model_id.toLowerCase().includes(q)
    );
  });

  const renderEmptyState = () => (
    <div className="flex w-full flex-col items-center justify-center py-12">
      <img
        src={isDarkTheme ? NoFilesDarkIcon : NoFilesIcon}
        alt={t('settings.customModels.empty')}
        className="mx-auto mb-6 h-32 w-32"
      />
      <p className="text-center text-lg text-gray-500 dark:text-gray-400">
        {t('settings.customModels.empty')}
      </p>
    </div>
  );

  return (
    <div className="mt-8">
      <div className="relative flex flex-col">
        <p className="text-muted-foreground mb-5 text-[15px] leading-6">
          {t('settings.customModels.subtitle')}
        </p>
        <div className="my-3 flex flex-col items-start gap-3 sm:flex-row sm:items-center sm:justify-between">
          <div className="relative w-full max-w-md">
            <img
              src={SearchIcon}
              alt=""
              className="absolute top-1/2 left-4 h-5 w-5 -translate-y-1/2 opacity-40"
            />
            <input
              maxLength={256}
              placeholder={t('settings.customModels.searchPlaceholder')}
              name="custom-models-search-input"
              type="text"
              id="custom-models-search-input"
              value={searchTerm}
              onChange={(e) => setSearchTerm(e.target.value)}
              className="border-border bg-card text-foreground placeholder:text-muted-foreground h-11 w-full rounded-full border py-2 pr-5 pl-11 text-sm shadow-[0_1px_4px_rgba(0,0,0,0.06)] transition-shadow outline-none focus:shadow-[0_2px_8px_rgba(0,0,0,0.1)] dark:shadow-none"
            />
          </div>
          <button
            className="bg-primary hover:bg-primary/90 flex h-11 min-w-[108px] items-center justify-center rounded-full px-4 text-sm whitespace-normal text-white"
            onClick={openAddModal}
          >
            {t('settings.customModels.addModel')}
          </button>
        </div>
        <div className="border-border dark:border-border mt-5 mb-8 border-b" />
        {loading ? (
          <div className="flex flex-wrap justify-center gap-4 sm:justify-start">
            <SkeletonLoader component="toolCards" count={3} />
          </div>
        ) : (
          <div className="flex flex-wrap justify-center gap-4 sm:justify-start">
            {filteredModels.length === 0
              ? renderEmptyState()
              : filteredModels.map((model) => (
                  <div
                    key={model.id}
                    className="bg-muted hover:bg-accent relative flex w-[300px] flex-col overflow-hidden rounded-2xl p-5"
                  >
                    <div
                      ref={menuRefs.current[model.id]}
                      onClick={(e) => {
                        e.stopPropagation();
                        setActiveMenuId(
                          activeMenuId === model.id ? null : model.id,
                        );
                      }}
                      className="absolute top-3 right-3 z-10 cursor-pointer"
                    >
                      <img
                        src={ThreeDotsIcon}
                        alt={t('settings.customModels.actionsMenuAria', {
                          modelName: model.display_name,
                        })}
                        className="h-[19px] w-[19px]"
                      />
                      <ContextMenu
                        isOpen={activeMenuId === model.id}
                        setIsOpen={(isOpen) => {
                          setActiveMenuId(isOpen ? model.id : null);
                        }}
                        options={getMenuOptions(model)}
                        anchorRef={menuRefs.current[model.id]}
                        position="bottom-right"
                        offset={{ x: 0, y: 0 }}
                      />
                    </div>
                    <div className="w-full pr-7">
                      <div className="flex items-center gap-2">
                        <p
                          title={model.display_name}
                          className="text-foreground dark:text-foreground truncate text-[15px] leading-snug font-semibold"
                        >
                          {model.display_name}
                        </p>
                        {!model.enabled && (
                          <span className="bg-muted-foreground/15 text-muted-foreground shrink-0 rounded-full px-2 py-0.5 text-[10px] leading-none font-medium">
                            {t('settings.customModels.disabledBadge')}
                          </span>
                        )}
                      </div>
                      <div className="mt-3 space-y-1.5">
                        <div
                          className="text-muted-foreground/80 flex items-center gap-1.5 text-xs leading-relaxed"
                          title={model.upstream_model_id}
                        >
                          <Tag className="h-3.5 w-3.5 shrink-0 opacity-70" />
                          <span className="truncate">
                            {model.upstream_model_id}
                          </span>
                        </div>
                        <div
                          className="text-muted-foreground/80 flex items-center gap-1.5 text-xs leading-relaxed"
                          title={model.base_url}
                        >
                          <Globe className="h-3.5 w-3.5 shrink-0 opacity-70" />
                          <span className="truncate">
                            {formatBaseUrlHost(model.base_url)}
                          </span>
                        </div>
                      </div>
                    </div>
                  </div>
                ))}
          </div>
        )}
      </div>
      <CustomModelModal
        modalState={modalState}
        setModalState={setModalState}
        model={editingModel}
        onSaved={handleSaved}
      />
      <ConfirmationModal
        message={t('settings.customModels.deleteWarning', {
          modelName: modelToDelete?.display_name || '',
        })}
        modalState={deleteState}
        setModalState={setDeleteState}
        handleSubmit={confirmDelete}
        submitLabel={t('settings.customModels.actions.delete')}
        variant="danger"
      />
    </div>
  );
}
@@ -21,6 +21,7 @@ import {
  setSourceDocs,
} from '../preferences/preferenceSlice';
import Analytics from './Analytics';
import CustomModels from './CustomModels';
import Sources from './Sources';
import General from './General';
import Logs from './Logs';
@@ -39,6 +40,8 @@ export default function Settings() {
      return t('settings.analytics.label');
    if (path.includes('/settings/logs')) return t('settings.logs.label');
    if (path.includes('/settings/tools')) return t('settings.tools.label');
    if (path.includes('/settings/custom-models'))
      return t('settings.customModels.label');
    return t('settings.general.label');
  };

@@ -52,6 +55,8 @@ export default function Settings() {
      navigate('/settings/analytics');
    else if (tab === t('settings.logs.label')) navigate('/settings/logs');
    else if (tab === t('settings.tools.label')) navigate('/settings/tools');
    else if (tab === t('settings.customModels.label'))
      navigate('/settings/custom-models');
  };

  React.useEffect(() => {
@@ -113,6 +118,7 @@ export default function Settings() {
          <Route path="analytics" element={<Analytics />} />
          <Route path="logs" element={<Logs />} />
          <Route path="tools" element={<Tools />} />
          <Route path="custom-models" element={<CustomModels />} />
          <Route path="*" element={<Navigate to="/settings" replace />} />
        </Routes>
      </div>

@@ -116,7 +116,7 @@ def test_execute_agent_node_normalizes_wrapped_schema_before_agent_create(monkey
    )
    monkeypatch.setattr(
        "application.core.model_utils.get_model_capabilities",
        lambda _model_id: {"supports_structured_output": True},
        lambda _model_id, **_kwargs: {"supports_structured_output": True},
    )

    list(engine._execute_agent_node(node))
@@ -242,7 +242,7 @@ def test_validate_workflow_structure_rejects_unsupported_structured_output_model
    monkeypatch.setattr(
        workflow_routes,
        "get_model_capabilities",
        lambda _model_id: {"supports_structured_output": False},
        lambda _model_id, **_kwargs: {"supports_structured_output": False},
    )

    nodes = [
@@ -296,7 +296,7 @@ def test_execute_agent_node_raises_when_structured_output_violates_schema(monkey
    )
    monkeypatch.setattr(
        "application.core.model_utils.get_model_capabilities",
        lambda _model_id: {"supports_structured_output": True},
        lambda _model_id, **_kwargs: {"supports_structured_output": True},
    )

    with pytest.raises(ValueError, match="Structured output did not match schema"):
@@ -322,7 +322,7 @@ def test_execute_agent_node_raises_when_schema_set_and_response_not_json(monkeyp
    )
    monkeypatch.setattr(
        "application.core.model_utils.get_model_capabilities",
        lambda _model_id: {"supports_structured_output": True},
        lambda _model_id, **_kwargs: {"supports_structured_output": True},
    )

    with pytest.raises(
@@ -360,11 +360,11 @@ class TestWorkflowEngineAdditionalCoverage:
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_provider_from_model_id",
            lambda _: None,
            lambda _, **_kwargs: None,
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_model_capabilities",
            lambda _: None,
            lambda _, **_kwargs: None,
        )

        list(engine._execute_agent_node(node))
@@ -390,11 +390,11 @@ class TestWorkflowEngineAdditionalCoverage:
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_provider_from_model_id",
            lambda _: "openai",
            lambda _, **_kwargs: "openai",
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_model_capabilities",
            lambda _: None,
            lambda _, **_kwargs: None,
        )

        list(engine._execute_agent_node(node))
@@ -416,11 +416,11 @@ class TestWorkflowEngineAdditionalCoverage:
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_provider_from_model_id",
            lambda _: "openai",
            lambda _, **_kwargs: "openai",
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_model_capabilities",
            lambda _: {"supports_structured_output": False},
            lambda _, **_kwargs: {"supports_structured_output": False},
        )

        with pytest.raises(ValueError, match="does not support structured output"):
@@ -450,11 +450,11 @@ class TestWorkflowEngineAdditionalCoverage:
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_provider_from_model_id",
            lambda _: None,
            lambda _, **_kwargs: None,
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_model_capabilities",
            lambda _: {"supports_structured_output": True},
            lambda _, **_kwargs: {"supports_structured_output": True},
        )

        list(engine._execute_agent_node(node))
@@ -481,11 +481,11 @@ class TestWorkflowEngineAdditionalCoverage:
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_provider_from_model_id",
            lambda _: None,
            lambda _, **_kwargs: None,
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_model_capabilities",
            lambda _: {"supports_structured_output": True},
            lambda _, **_kwargs: {"supports_structured_output": True},
        )

        with pytest.raises(

@@ -162,8 +162,10 @@ class TestCompressIfNeeded:
            current_query_tokens=1000,
        )

        # user_id flows through so BYOM custom-model UUIDs resolve to
        # the user's declared context window in the threshold check.
        mock_threshold_checker.should_compress.assert_called_once_with(
            sample_conversation, "gpt-4", 1000
            sample_conversation, "gpt-4", 1000, user_id="user1"
        )


@@ -270,8 +272,12 @@ class TestPerformCompression:
        )
        orch._perform_compression("c1", conversation, "gpt-4", decoded_token)

        # Verify the override model was used
        mock_get_provider.assert_called_with("gpt-3.5-turbo")
        # Verify the override model was used. user_id flows from
        # decoded_token['sub'] so per-user BYOM custom-model UUIDs
        # resolve.
        mock_get_provider.assert_called_with(
            "gpt-3.5-turbo", user_id=decoded_token["sub"]
        )

    @patch(
        "application.api.answer.services.compression.orchestrator.get_provider_from_model_id"
@@ -362,7 +368,12 @@ class TestCompressMidExecution:
        )

        mock_perform.assert_called_once_with(
            "conv1", sample_conversation, "gpt-4", decoded_token
            "conv1",
            sample_conversation,
            "gpt-4",
            decoded_token,
            user_id="user1",
            model_user_id=None,
        )

    def test_loads_conversation_when_not_provided(

@@ -200,7 +200,7 @@ class TestSetupPeriodicTasks:

        setup_periodic_tasks(sender)

        assert sender.add_periodic_task.call_count == 4
        assert sender.add_periodic_task.call_count == 5

        calls = sender.add_periodic_task.call_args_list

@@ -212,6 +212,8 @@ class TestSetupPeriodicTasks:
        assert calls[2][0][0] == timedelta(days=30)
        # pending_tool_state TTL cleanup (60s)
        assert calls[3][0][0] == timedelta(seconds=60)
        # version-check (every 7h)
        assert calls[4][0][0] == timedelta(hours=7)


class TestMcpOauthTask:

688
tests/api/user/test_user_custom_models_routes.py
Normal file
@@ -0,0 +1,688 @@
"""Tests for the BYOM REST API at /api/user/models."""

from __future__ import annotations

from contextlib import contextmanager
from unittest.mock import MagicMock, patch

import pytest
from flask import Flask


@pytest.fixture
def app():
    return Flask(__name__)


@contextmanager
def _patch_db(conn):
    """Patch the routes' db helpers to yield the given pg connection."""

    @contextmanager
    def _yield_conn():
        yield conn

    with patch(
        "application.api.user.models.routes.db_session", _yield_conn
    ), patch(
        "application.api.user.models.routes.db_readonly", _yield_conn
    ):
        yield


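# Autouse: reset the in-process model registry around every test so
# cached BYOM entries never leak between tests.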
@pytest.fixture(autouse=True)
|
||||
def _reset_registry():
|
||||
from application.core.model_registry import ModelRegistry
|
||||
|
||||
ModelRegistry.reset()
|
||||
yield
|
||||
ModelRegistry.reset()
|
||||
|
||||
|
||||
# Auth
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestAuth:
|
||||
def test_list_unauthenticated_returns_401(self, app):
|
||||
from application.api.user.models.routes import (
|
||||
UserModelsCollectionResource,
|
||||
)
|
||||
|
||||
with app.test_request_context("/api/user/models"):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
resp = UserModelsCollectionResource().get()
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_create_unauthenticated_returns_401(self, app):
|
||||
from application.api.user.models.routes import (
|
||||
UserModelsCollectionResource,
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/user/models",
|
||||
method="POST",
|
||||
json={
|
||||
"upstream_model_id": "x",
|
||||
"display_name": "x",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"api_key": "k",
|
||||
},
|
||||
):
|
||||
from flask import request
|
||||
|
||||
request.decoded_token = None
|
||||
resp = UserModelsCollectionResource().post()
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
# Create
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestCreate:
    def test_creates_and_returns_201_without_api_key(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        # Mock DNS so the SSRF check passes for api.mistral.ai without
        # hitting the network.
        with patch("application.security.safe_url.socket.getaddrinfo") as gai:
            gai.return_value = [
                (None, None, None, None, ("104.18.0.1", 0))
            ]
            with app.test_request_context(
                "/api/user/models",
                method="POST",
                json={
                    "upstream_model_id": "mistral-large-latest",
                    "display_name": "My Mistral",
                    "base_url": "https://api.mistral.ai/v1",
                    "api_key": "sk-mistral-test",
                    "capabilities": {
                        "supports_tools": True,
                        "context_window": 128000,
                    },
                },
            ):
                from flask import request

                request.decoded_token = {"sub": "user-1"}
                with _patch_db(pg_conn):
                    resp = UserModelsCollectionResource().post()

        assert resp.status_code == 201
        body = resp.get_json()
        assert body["upstream_model_id"] == "mistral-large-latest"
        assert body["source"] == "user"
        # Critical: api_key must NEVER appear in the response
        assert "api_key" not in body
        for v in body.values():
            assert v != "sk-mistral-test"

    def test_create_rejects_missing_required_fields(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        with app.test_request_context(
            "/api/user/models",
            method="POST",
            json={"upstream_model_id": "x"},  # missing the others
        ):
            from flask import request

            request.decoded_token = {"sub": "user-1"}
            with _patch_db(pg_conn):
                resp = UserModelsCollectionResource().post()
        assert resp.status_code == 400

    def test_create_rejects_loopback_url(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        with app.test_request_context(
            "/api/user/models",
            method="POST",
            json={
                "upstream_model_id": "x",
                "display_name": "x",
                "base_url": "https://127.0.0.1/v1",
                "api_key": "k",
            },
        ):
            from flask import request

            request.decoded_token = {"sub": "user-1"}
            with _patch_db(pg_conn):
                resp = UserModelsCollectionResource().post()
        assert resp.status_code == 400
        body = resp.get_json()
        assert "error" in body
    def test_create_rejects_unknown_attachment_alias(self, app, pg_conn):
        """The UI sends ``["image"]`` as an alias; bad strings ("video",
        typos) must be rejected at the boundary so the DB never holds
        garbage that the registry would later silently drop.
        """
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        with patch("application.security.safe_url.socket.getaddrinfo") as gai:
            gai.return_value = [
                (None, None, None, None, ("104.18.0.1", 0))
            ]
            with app.test_request_context(
                "/api/user/models",
                method="POST",
                json={
                    "upstream_model_id": "m",
                    "display_name": "M",
                    "base_url": "https://api.mistral.ai/v1",
                    "api_key": "k",
                    "capabilities": {"attachments": ["video"]},
                },
            ):
                from flask import request

                request.decoded_token = {"sub": "user-1"}
                with _patch_db(pg_conn):
                    resp = UserModelsCollectionResource().post()
        assert resp.status_code == 400
        body = resp.get_json()
        assert "video" in body["error"]

    def test_create_accepts_image_alias_and_raw_mime(self, app, pg_conn):
        """The known ``image`` alias and raw MIME types both pass."""
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        with patch("application.security.safe_url.socket.getaddrinfo") as gai:
            gai.return_value = [
                (None, None, None, None, ("104.18.0.1", 0))
            ]
            with app.test_request_context(
                "/api/user/models",
                method="POST",
                json={
                    "upstream_model_id": "m",
                    "display_name": "M",
                    "base_url": "https://api.mistral.ai/v1",
                    "api_key": "k",
                    "capabilities": {
                        "attachments": ["image", "application/pdf"],
                    },
                },
            ):
                from flask import request

                request.decoded_token = {"sub": "user-1"}
                with _patch_db(pg_conn):
                    resp = UserModelsCollectionResource().post()
        assert resp.status_code == 201

    def test_create_rejects_private_ip_dns(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        with patch("application.security.safe_url.socket.getaddrinfo") as gai:
            # Hostname resolves to a private IP only — must reject
            gai.return_value = [
                (None, None, None, None, ("10.0.0.5", 0))
            ]
            with app.test_request_context(
                "/api/user/models",
                method="POST",
                json={
                    "upstream_model_id": "x",
                    "display_name": "x",
                    "base_url": "https://evil.example.com/v1",
                    "api_key": "k",
                },
            ):
                from flask import request

                request.decoded_token = {"sub": "user-1"}
                with _patch_db(pg_conn):
                    resp = UserModelsCollectionResource().post()
        assert resp.status_code == 400
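
# The SSRF checks exercised above reduce to: resolve the hostname, then
# refuse any address that is loopback, private, or otherwise non-global.
# A minimal sketch of such a guard, with hypothetical names; the real
# application.security.safe_url module may differ in details:

import ipaddress
import socket
from urllib.parse import urlparse


def is_safe_base_url(url: str) -> bool:
    parsed = urlparse(url)
    if parsed.scheme != "https" or not parsed.hostname:
        return False
    try:
        infos = socket.getaddrinfo(parsed.hostname, parsed.port or 443)
    except socket.gaierror:
        return False
    for *_, sockaddr in infos:
        # "10.0.0.5" and "127.0.0.1" fail is_global; "104.18.0.1" passes,
        # matching the getaddrinfo fixtures used in the tests above.
        if not ipaddress.ip_address(sockaddr[0]).is_global:
            return False
    return bool(infos)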


# List / get / patch / delete


def _create_via_repo(pg_conn, user_id="user-1", **kwargs):
    from application.storage.db.repositories.user_custom_models import (
        UserCustomModelsRepository,
    )

    return UserCustomModelsRepository(pg_conn).create(
        user_id=user_id,
        upstream_model_id=kwargs.pop("upstream_model_id", "mistral-large-latest"),
        display_name=kwargs.pop("display_name", "My Mistral"),
        base_url=kwargs.pop("base_url", "https://api.mistral.ai/v1"),
        api_key_plaintext=kwargs.pop("api_key_plaintext", "sk-mistral-test"),
        **kwargs,
    )

@pytest.mark.unit
class TestList:
    def test_lists_only_users_own(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        _create_via_repo(pg_conn, user_id="alice", upstream_model_id="alice-1")
        _create_via_repo(pg_conn, user_id="bob", upstream_model_id="bob-1")

        with app.test_request_context("/api/user/models"):
            from flask import request

            request.decoded_token = {"sub": "alice"}
            with _patch_db(pg_conn):
                resp = UserModelsCollectionResource().get()
        assert resp.status_code == 200
        body = resp.get_json()
        upstream_ids = {m["upstream_model_id"] for m in body["models"]}
        assert upstream_ids == {"alice-1"}
        # Never expose the api_key
        for m in body["models"]:
            assert "api_key" not in m


@pytest.mark.unit
class TestGet:
    def test_returns_404_for_other_users_model(self, app, pg_conn):
        from application.api.user.models.routes import UserModelResource

        created = _create_via_repo(pg_conn, user_id="alice")
        with app.test_request_context(
            f"/api/user/models/{created['id']}"
        ):
            from flask import request

            request.decoded_token = {"sub": "bob"}
            with _patch_db(pg_conn):
                resp = UserModelResource().get(model_id=created["id"])
        assert resp.status_code == 404


@pytest.mark.unit
class TestPatch:
    def test_patch_updates_display_name(self, app, pg_conn):
        from application.api.user.models.routes import UserModelResource

        created = _create_via_repo(pg_conn, user_id="user-1")
        with app.test_request_context(
            f"/api/user/models/{created['id']}",
            method="PATCH",
            json={"display_name": "Renamed"},
        ):
            from flask import request

            request.decoded_token = {"sub": "user-1"}
            with _patch_db(pg_conn):
                resp = UserModelResource().patch(model_id=created["id"])
        assert resp.status_code == 200
        body = resp.get_json()
        assert body["display_name"] == "Renamed"

    def test_patch_blank_api_key_keeps_existing(self, app, pg_conn):
        """Critical PATCH semantic: empty/missing api_key in body must
        preserve the stored ciphertext (the UI sends a blank password
        field when the user wants to keep the existing key)."""
        from application.api.user.models.routes import UserModelResource
        from application.storage.db.repositories.user_custom_models import (
            UserCustomModelsRepository,
        )

        created = _create_via_repo(pg_conn, user_id="user-1")
        original_key_plaintext = "sk-mistral-test"

        with app.test_request_context(
            f"/api/user/models/{created['id']}",
            method="PATCH",
            json={"display_name": "Just rename me", "api_key": ""},
        ):
            from flask import request

            request.decoded_token = {"sub": "user-1"}
            with _patch_db(pg_conn):
                resp = UserModelResource().patch(model_id=created["id"])
        assert resp.status_code == 200

        # Decrypted key is unchanged
        repo = UserCustomModelsRepository(pg_conn)
        assert (
            repo.get_decrypted_api_key(created["id"], "user-1")
            == original_key_plaintext
        )
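
# A sketch of the PATCH semantics pinned down here and by the blank-field
# test further below (hypothetical route-side helper; names assumed):

def build_patch_update(payload: dict) -> dict:
    update = {}
    for field in ("display_name", "base_url", "upstream_model_id"):
        if field in payload:
            value = (payload[field] or "").strip()
            if not value:
                raise ValueError(f"{field} must not be blank")  # -> 400
            update[field] = value
    # A blank or missing api_key is different: it means "keep the stored
    # ciphertext", so it is skipped rather than rejected.
    api_key = (payload.get("api_key") or "").strip()
    if api_key:
        update["api_key"] = api_key
    return update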

@pytest.mark.unit
class TestDelete:
    def test_delete_removes_row_and_invalidates_cache(self, app, pg_conn):
        from application.api.user.models.routes import UserModelResource
        from application.core.model_registry import ModelRegistry

        created = _create_via_repo(pg_conn, user_id="user-1")
        # Warm the registry's per-user cache via a lookup
        with patch(
            "application.storage.db.session.db_readonly"
        ) as ro:
            @contextmanager
            def _y():
                yield pg_conn

            ro.side_effect = _y
            ModelRegistry.get_instance().get_model(created["id"], user_id="user-1")
        # Now delete via the route — invalidation must happen so a
        # subsequent lookup misses
        with app.test_request_context(
            f"/api/user/models/{created['id']}", method="DELETE"
        ):
            from flask import request

            request.decoded_token = {"sub": "user-1"}
            with _patch_db(pg_conn):
                resp = UserModelResource().delete(model_id=created["id"])
        assert resp.status_code == 200
        # Cache invalidated → next lookup re-queries DB and finds nothing
        assert "user-1" not in ModelRegistry.get_instance()._user_models
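
# The delete route pairs the row removal with a registry cache drop; a
# sketch consistent with the assertion on _user_models (helper names are
# assumed):

def delete_user_model(repo, registry, model_id: str, user_id: str) -> None:
    repo.delete(model_id, user_id)            # remove the Postgres row
    registry._user_models.pop(user_id, None)  # invalidate the per-user cache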


# /api/models combined view


@pytest.mark.unit
class TestSecurityCreateRejectsBlankFields:
    """P1 #1 partial: blank api_key on create must be rejected so we
    can never end up with an unroutable BYOM record that would cause
    LLMCreator to leak settings.API_KEY to the user-supplied URL."""

    def test_create_rejects_blank_api_key(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        with patch("application.security.safe_url.socket.getaddrinfo") as gai:
            gai.return_value = [(None, None, None, None, ("104.18.0.1", 0))]
            with app.test_request_context(
                "/api/user/models",
                method="POST",
                json={
                    "upstream_model_id": "x",
                    "display_name": "x",
                    "base_url": "https://api.mistral.ai/v1",
                    "api_key": " ",  # whitespace only
                },
            ):
                from flask import request

                request.decoded_token = {"sub": "u"}
                with _patch_db(pg_conn):
                    resp = UserModelsCollectionResource().post()
        assert resp.status_code == 400
        body = resp.get_json()
        assert "api_key" in (body.get("error") or "").lower()

    def test_patch_rejects_blank_required_field(self, app, pg_conn):
        from application.api.user.models.routes import UserModelResource
        from application.storage.db.repositories.user_custom_models import (
            UserCustomModelsRepository,
        )

        created = UserCustomModelsRepository(pg_conn).create(
            user_id="u",
            upstream_model_id="x",
            display_name="x",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="sk-x",
        )

        with app.test_request_context(
            f"/api/user/models/{created['id']}",
            method="PATCH",
            json={"base_url": " "},
        ):
            from flask import request

            request.decoded_token = {"sub": "u"}
            with _patch_db(pg_conn):
                resp = UserModelResource().patch(model_id=created["id"])
        assert resp.status_code == 400


@pytest.mark.unit
class TestPayloadConnectionTest:
    """Verifies the payload-based test endpoint. Lets the UI's 'Test
    connection' button work *before* the model is saved — operators
    expect to validate their endpoint + key before committing."""

    def test_payload_test_rejects_unsafe_url(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelTestPayloadResource,
        )

        with app.test_request_context(
            "/api/user/models/test",
            method="POST",
            json={
                "base_url": "https://127.0.0.1/v1",
                "api_key": "sk-anything",
                "upstream_model_id": "x",
            },
        ):
            from flask import request

            request.decoded_token = {"sub": "user-1"}
            with _patch_db(pg_conn):
                resp = UserModelTestPayloadResource().post()
        assert resp.status_code == 400
        body = resp.get_json()
        assert body["ok"] is False

    def test_payload_test_returns_ok_when_upstream_responds_2xx(
        self, app, pg_conn
    ):
        from application.api.user.models.routes import (
            UserModelTestPayloadResource,
        )

        # pinned_post is the IP-pinned dispatch helper. Patching it
        # bypasses both the SSRF guard and the network — the success
        # path we're verifying here is the route's response handling.
        with patch("application.api.user.models.routes.pinned_post") as rp:
            rp.return_value = MagicMock(
                status_code=200,
                headers={"Content-Type": "application/json"},
                text='{"ok": true}',
            )
            with app.test_request_context(
                "/api/user/models/test",
                method="POST",
                json={
                    "base_url": "https://api.mistral.ai/v1",
                    "api_key": "sk-mistral-test",
                    "upstream_model_id": "mistral-large-latest",
                },
            ):
                from flask import request

                request.decoded_token = {"sub": "user-1"}
                with _patch_db(pg_conn):
                    resp = UserModelTestPayloadResource().post()
        assert resp.status_code == 200
        body = resp.get_json()
        assert body["ok"] is True
        # Verify the upstream call carried the user's submitted key (not
        # whatever's in the DB) and the right model name.
        call_args = rp.call_args
        assert call_args.kwargs["headers"]["Authorization"] == "Bearer sk-mistral-test"
        assert call_args.kwargs["json"]["model"] == "mistral-large-latest"

    def test_payload_test_unauthenticated_returns_401(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelTestPayloadResource,
        )

        with app.test_request_context(
            "/api/user/models/test",
            method="POST",
            json={
                "base_url": "https://api.mistral.ai/v1",
                "api_key": "k",
                "upstream_model_id": "x",
            },
        ):
            from flask import request

            request.decoded_token = None
            with _patch_db(pg_conn):
                resp = UserModelTestPayloadResource().post()
        assert resp.status_code == 401

    def test_payload_test_missing_fields_returns_400(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelTestPayloadResource,
        )

        with app.test_request_context(
            "/api/user/models/test",
            method="POST",
            json={"base_url": "https://api.mistral.ai/v1"},
        ):
            from flask import request

            request.decoded_token = {"sub": "user-1"}
            with _patch_db(pg_conn):
                resp = UserModelTestPayloadResource().post()
        assert resp.status_code == 400
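
# Judging by the assertions above, the probe the route sends upstream looks
# roughly like this (sketch; only the URL shape, the Authorization header,
# and json["model"] are pinned by the tests, the rest is unconstrained):

from application.api.user.models.routes import pinned_post


def probe_endpoint(base_url: str, api_key: str, upstream_model_id: str):
    return pinned_post(
        base_url.rstrip("/") + "/chat/completions",
        headers={"Authorization": f"Bearer {api_key}"},
        json={"model": upstream_model_id},
    )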

@pytest.mark.unit
class TestByIdConnectionTestAcceptsOverrides:
    """P3: in edit mode the modal sends current form state as overrides
    so the test reflects in-flight edits (not the saved record)."""

    def _make_row(self, pg_conn):
        from application.storage.db.repositories.user_custom_models import (
            UserCustomModelsRepository,
        )

        return UserCustomModelsRepository(pg_conn).create(
            user_id="u",
            upstream_model_id="stored-model",
            display_name="Stored",
            base_url="https://stored.example.com/v1",
            api_key_plaintext="sk-stored",
        )

    def _post_test(self, app, pg_conn, model_id, body):
        from application.api.user.models.routes import UserModelTestResource

        with patch("application.api.user.models.routes.pinned_post") as rp:
            rp.return_value = MagicMock(
                status_code=200,
                headers={"Content-Type": "application/json"},
                text='{"ok": true}',
            )
            with app.test_request_context(
                f"/api/user/models/{model_id}/test", method="POST", json=body
            ):
                from flask import request

                request.decoded_token = {"sub": "u"}
                with _patch_db(pg_conn):
                    UserModelTestResource().post(model_id=model_id)
        return rp.call_args

    def test_overrides_win_when_supplied(self, app, pg_conn):
        row = self._make_row(pg_conn)
        ca = self._post_test(
            app,
            pg_conn,
            row["id"],
            {
                "base_url": "https://new.example.com/v1",
                "api_key": "sk-new",
                "upstream_model_id": "new-model",
            },
        )
        assert ca.args[0] == "https://new.example.com/v1/chat/completions"
        assert ca.kwargs["headers"]["Authorization"] == "Bearer sk-new"
        assert ca.kwargs["json"]["model"] == "new-model"

    def test_blank_overrides_fall_back_to_stored(self, app, pg_conn):
        """The classic edit-mode flow: user changed base_url, left
        api_key blank — server uses the new URL but the stored key."""
        row = self._make_row(pg_conn)
        ca = self._post_test(
            app,
            pg_conn,
            row["id"],
            {
                "base_url": "https://new.example.com/v1",
                "api_key": "",
                "upstream_model_id": "",
            },
        )
        assert ca.args[0] == "https://new.example.com/v1/chat/completions"
        # Stored key (decrypted) was used.
        assert ca.kwargs["headers"]["Authorization"] == "Bearer sk-stored"
        # Stored upstream_model_id was used.
        assert ca.kwargs["json"]["model"] == "stored-model"

    def test_empty_body_uses_all_stored_values(self, app, pg_conn):
        row = self._make_row(pg_conn)
        ca = self._post_test(app, pg_conn, row["id"], {})
        assert ca.args[0] == "https://stored.example.com/v1/chat/completions"
        assert ca.kwargs["headers"]["Authorization"] == "Bearer sk-stored"
        assert ca.kwargs["json"]["model"] == "stored-model"
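
# Override resolution distilled from the three tests above: a non-empty
# field in the request body wins, otherwise the stored row's value is used
# (with the stored api_key decrypted). A sketch:

def resolve_test_params(body: dict, stored: dict, stored_key: str):
    def pick(field, fallback):
        return (body.get(field) or "").strip() or fallback

    return (
        pick("base_url", stored["base_url"]),
        pick("api_key", stored_key),
        pick("upstream_model_id", stored["upstream_model_id"]),
    )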

@pytest.mark.unit
class TestApiModelsListWithUser:
    def test_includes_user_models_when_authenticated(self, app, pg_conn):
        """GET /api/models with auth should surface the user's BYOM
        records alongside built-ins, each tagged with `source`."""
        from application.api.user.models.routes import ModelsListResource
        from application.core.model_registry import ModelRegistry

        created = _create_via_repo(
            pg_conn, user_id="user-1", display_name="My Mistral"
        )

        # Patch the *registry's* db_readonly so the per-user layer load
        # uses the test connection.
        @contextmanager
        def _yield():
            yield pg_conn

        with patch(
            "application.storage.db.session.db_readonly", _yield
        ):
            ModelRegistry.reset()
            with app.test_request_context("/api/models"):
                from flask import request

                request.decoded_token = {"sub": "user-1"}
                resp = ModelsListResource().get()

        assert resp.status_code == 200
        body = resp.get_json()
        ids = [m["id"] for m in body["models"]]
        assert created["id"] in ids
        # Source label tags it for the UI
        user_entries = [m for m in body["models"] if m["id"] == created["id"]]
        assert user_entries[0]["source"] == "user"
        # Built-ins still present
        assert any(m.get("source") == "builtin" for m in body["models"])
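
# The combined listing merges built-ins with the caller's BYOM rows and
# tags each entry with a `source` label. A sketch consistent with the
# assertions (get_user_models is a hypothetical accessor):

def combined_models(registry, user_id=None):
    models = [
        {**m.to_dict(), "source": "builtin"}
        for m in registry.get_enabled_models()
    ]
    if user_id:
        models += [
            {**m.to_dict(), "source": "user"}
            for m in registry.get_user_models(user_id)
        ]
    return models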
@@ -122,6 +122,11 @@ def mock_llm():
    llm._supports_tools = True
    llm._supports_structured_output = Mock(return_value=False)
    llm.__class__.__name__ = "MockLLM"
    # Mirror BaseLLM.__init__: real LLMCreator stores the resolved
    # upstream model name on self.model_id. Tests that build agents via
    # ``mock_llm_creator`` rely on the agent's ``upstream_model_id``
    # falling through to this attribute.
    llm.model_id = "gpt-4"
    return llm

1268  tests/core/test_byom_user_aware_helpers.py  Normal file
File diff suppressed because it is too large

306  tests/core/test_model_registry_yaml.py  Normal file
@@ -0,0 +1,306 @@
"""Phase 1 regression tests for the YAML-driven ModelRegistry.

These tests encode the contract that persisted agent / workflow /
conversation references depend on: every model id and core capability
that existed in the old ``model_configs.py`` lists must continue to be
produced by the new YAML-backed registry.

If a future YAML edit accidentally renames an id or changes a
capability, these tests fail at CI before merge — protecting agents and
workflows from silent fallback to the system default.
"""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest

from application.core.model_registry import ModelRegistry
from application.core.model_yaml import (
    BUILTIN_MODELS_DIR,
    load_model_yamls,
)


# ── Per-provider expected IDs ─────────────────────────────────────────────
# Snapshot of the current built-in catalog. If you intentionally change
# what models a provider's YAML lists, update this constant in the same
# commit. The test exists to catch *unintentional* renames (e.g. a typo
# in an upstream model id) that would silently break every agent that
# references the old id.
EXPECTED_IDS = {
    "openai": {"gpt-5.5", "gpt-5.4-mini", "gpt-5.4-nano"},
    "anthropic": {
        "claude-opus-4-7",
        "claude-sonnet-4-6",
        "claude-haiku-4-5",
    },
    "google": {
        "gemini-3.1-pro-preview",
        "gemini-3-flash-preview",
        "gemini-3.1-flash-lite-preview",
    },
    "groq": {
        "openai/gpt-oss-120b",
        "llama-3.3-70b-versatile",
        "llama-3.1-8b-instant",
    },
    "openrouter": {
        "qwen/qwen3-coder:free",
        "deepseek/deepseek-v3.2",
        "anthropic/claude-sonnet-4.6",
    },
    "novita": {
        "deepseek/deepseek-v4-pro",
        "moonshotai/kimi-k2.6",
        "zai-org/glm-5",
    },
    "azure_openai": {
        "azure-gpt-5.5",
        "azure-gpt-5.4-mini",
        "azure-gpt-5.4-nano",
    },
    "docsgpt": {"docsgpt-local"},
    "huggingface": {"huggingface-local"},
}


def _make_settings(**overrides):
    s = MagicMock()
    # All credential / mode flags off by default so each test opts in.
    s.OPENAI_BASE_URL = None
    s.OPENAI_API_KEY = None
    s.OPENAI_API_BASE = None
    s.ANTHROPIC_API_KEY = None
    s.GOOGLE_API_KEY = None
    s.GROQ_API_KEY = None
    s.OPEN_ROUTER_API_KEY = None
    s.NOVITA_API_KEY = None
    s.HUGGINGFACE_API_KEY = None
    s.LLM_PROVIDER = ""
    s.LLM_NAME = None
    s.API_KEY = None
    s.MODELS_CONFIG_DIR = None
    for k, v in overrides.items():
        setattr(s, k, v)
    return s


@pytest.fixture(autouse=True)
def _reset_registry():
    ModelRegistry.reset()
    yield
    ModelRegistry.reset()


# ── YAML schema / loader ─────────────────────────────────────────────────


def _by_provider(catalogs):
    """Group a list of catalogs by provider name. Mirrors the registry's
    own grouping; useful for asserting per-provider model sets in tests."""
    out = {}
    for c in catalogs:
        out.setdefault(c.provider, []).append(c)
    return out


@pytest.mark.unit
class TestYAMLLoader:
    def test_loader_produces_expected_provider_set(self):
        catalogs = load_model_yamls([BUILTIN_MODELS_DIR])
        providers = {c.provider for c in catalogs}
        assert providers == set(EXPECTED_IDS.keys())

    def test_each_provider_has_expected_ids(self):
        grouped = _by_provider(load_model_yamls([BUILTIN_MODELS_DIR]))
        for provider, expected in EXPECTED_IDS.items():
            actual = {m.id for c in grouped[provider] for m in c.models}
            assert actual == expected, f"{provider}: expected {expected}, got {actual}"

    def test_attachment_alias_image_expands_to_five_mime_types(self):
        grouped = _by_provider(load_model_yamls([BUILTIN_MODELS_DIR]))
        # OpenAI uses `attachments: [image]` in its defaults block.
        for c in grouped["openai"]:
            for m in c.models:
                assert "image/png" in m.capabilities.supported_attachment_types
                assert "image/jpeg" in m.capabilities.supported_attachment_types
                assert "image/webp" in m.capabilities.supported_attachment_types
                assert len(m.capabilities.supported_attachment_types) == 5

    def test_attachment_alias_pdf_plus_image_for_google(self):
        grouped = _by_provider(load_model_yamls([BUILTIN_MODELS_DIR]))
        for c in grouped["google"]:
            for m in c.models:
                assert "application/pdf" in m.capabilities.supported_attachment_types
                assert "image/png" in m.capabilities.supported_attachment_types
                assert len(m.capabilities.supported_attachment_types) == 6

    def test_per_model_context_window_overrides_provider_default(self):
        grouped = _by_provider(load_model_yamls([BUILTIN_MODELS_DIR]))
        openai = {m.id: m for c in grouped["openai"] for m in c.models}
        # Provider default is 400_000; gpt-5.5 overrides to 1_050_000.
        assert openai["gpt-5.4-mini"].capabilities.context_window == 400_000
        assert openai["gpt-5.5"].capabilities.context_window == 1_050_000
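
# The `image` alias the loader tests pin down expands to a fixed MIME set:
# five image types for OpenAI, plus application/pdf for Google's six. A
# sketch of the expansion; only png/jpeg/webp and the counts are pinned by
# the tests, so the remaining two image types here are assumptions:

IMAGE_MIME_TYPES = [
    "image/png",
    "image/jpeg",
    "image/webp",
    "image/gif",   # assumed
    "image/heic",  # assumed
]


def expand_attachments(entries):
    expanded = []
    for entry in entries:
        if entry == "image":
            expanded.extend(IMAGE_MIME_TYPES)
        elif entry == "pdf":
            expanded.append("application/pdf")
        else:
            expanded.append(entry)  # raw MIME types pass through unchanged
    return expanded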

# ── Registry × settings: every documented .env permutation ───────────────


@pytest.mark.unit
class TestRegistryPermutations:
    def test_openai_only(self):
        s = _make_settings(OPENAI_API_KEY="sk-test", LLM_PROVIDER="openai")
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            ids = {m.id for m in reg.get_all_models()}
            assert ids == EXPECTED_IDS["openai"] | EXPECTED_IDS["docsgpt"]

    def test_openai_base_url_replaces_catalog_with_dynamic(self):
        s = _make_settings(
            OPENAI_BASE_URL="http://localhost:11434/v1",
            OPENAI_API_KEY="sk-test",
            LLM_PROVIDER="openai",
            LLM_NAME="llama3,gemma",
        )
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            ids = {m.id for m in reg.get_all_models()}
            # Custom local endpoint suppresses both the openai catalog AND
            # the docsgpt model (matching legacy behavior).
            assert ids == {"llama3", "gemma"}

    def test_anthropic_only(self):
        s = _make_settings(ANTHROPIC_API_KEY="sk-ant")
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            ids = {m.id for m in reg.get_all_models()}
            assert ids == EXPECTED_IDS["anthropic"] | EXPECTED_IDS["docsgpt"]

    def test_anthropic_via_llm_provider_with_llm_name(self):
        # Mirrors the historical _add_anthropic_models filter: when only
        # API_KEY (not ANTHROPIC_API_KEY) is set and LLM_NAME matches a
        # known model, only that model is loaded.
        s = _make_settings(
            LLM_PROVIDER="anthropic", API_KEY="key", LLM_NAME="claude-haiku-4-5"
        )
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            anthropic_ids = {
                m.id for m in reg.get_all_models() if m.provider.value == "anthropic"
            }
            assert anthropic_ids == {"claude-haiku-4-5"}

    def test_google_only(self):
        s = _make_settings(GOOGLE_API_KEY="g-test")
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            ids = {m.id for m in reg.get_all_models()}
            assert ids == EXPECTED_IDS["google"] | EXPECTED_IDS["docsgpt"]

    def test_groq_only(self):
        s = _make_settings(GROQ_API_KEY="g-test")
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            ids = {m.id for m in reg.get_all_models()}
            assert ids == EXPECTED_IDS["groq"] | EXPECTED_IDS["docsgpt"]

    def test_openrouter_only(self):
        s = _make_settings(OPEN_ROUTER_API_KEY="or-test")
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            ids = {m.id for m in reg.get_all_models()}
            assert ids == EXPECTED_IDS["openrouter"] | EXPECTED_IDS["docsgpt"]

    def test_novita_only(self):
        s = _make_settings(NOVITA_API_KEY="n-test")
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            ids = {m.id for m in reg.get_all_models()}
            assert ids == EXPECTED_IDS["novita"] | EXPECTED_IDS["docsgpt"]

    def test_huggingface_only(self):
        s = _make_settings(HUGGINGFACE_API_KEY="hf-test")
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            ids = {m.id for m in reg.get_all_models()}
            assert ids == EXPECTED_IDS["huggingface"] | EXPECTED_IDS["docsgpt"]

    def test_no_credentials_only_docsgpt(self):
        s = _make_settings()
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            ids = {m.id for m in reg.get_all_models()}
            assert ids == EXPECTED_IDS["docsgpt"]

    def test_azure_via_provider(self):
        s = _make_settings(LLM_PROVIDER="azure_openai", API_KEY="key")
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            ids = {m.id for m in reg.get_all_models()}
            assert "azure-gpt-5.5" in ids

    def test_azure_via_api_base(self):
        s = _make_settings(OPENAI_API_BASE="https://x.openai.azure.com")
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            ids = {m.id for m in reg.get_all_models()}
            assert "azure-gpt-5.5" in ids

    def test_everything_set(self):
        s = _make_settings(
            OPENAI_API_KEY="x",
            ANTHROPIC_API_KEY="x",
            GOOGLE_API_KEY="x",
            GROQ_API_KEY="x",
            OPEN_ROUTER_API_KEY="x",
            NOVITA_API_KEY="x",
            HUGGINGFACE_API_KEY="x",
            OPENAI_API_BASE="x",
        )
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            ids = {m.id for m in reg.get_all_models()}
            all_expected = set()
            for v in EXPECTED_IDS.values():
                all_expected |= v
            assert ids == all_expected


# ── Default model resolution ─────────────────────────────────────────────


@pytest.mark.unit
class TestDefaultModelResolution:
    def test_llm_name_picks_default(self):
        s = _make_settings(
            ANTHROPIC_API_KEY="sk-ant", LLM_NAME="claude-opus-4-7"
        )
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            assert reg.default_model_id == "claude-opus-4-7"

    def test_falls_back_to_first_model_when_no_match(self):
        s = _make_settings()
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            assert reg.default_model_id is not None
            assert reg.default_model_id in reg.models


# ── Forward-compat: user_id parameter is accepted everywhere ─────────────


@pytest.mark.unit
class TestUserIdForwardCompat:
    def test_lookup_methods_accept_user_id(self):
        s = _make_settings(OPENAI_API_KEY="sk-test")
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            # All lookup methods must accept user_id (currently ignored,
            # reserved for end-user BYOM).
            assert reg.get_model("gpt-5.5", user_id="alice") is not None
            assert len(reg.get_all_models(user_id="alice")) > 0
            assert len(reg.get_enabled_models(user_id="alice")) > 0
            assert reg.model_exists("gpt-5.5", user_id="alice") is True
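
# Once the end-user BYOM layer lands, the user-aware lookups above resolve
# in two layers: the caller's own models first, then the built-in catalog.
# A sketch of that resolution order (the _user_models cache shape is
# assumed from the route tests earlier in this diff):

def resolve_model(registry, model_id, user_id=None):
    if user_id is not None:
        user_layer = registry._user_models.get(user_id, {})
        if model_id in user_layer:
            return user_layer[model_id]
    return registry.models.get(model_id)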
@@ -1,6 +1,17 @@
"""Tests for application/core/model_settings.py"""
"""Tests for application/core/model_settings.py.

from unittest.mock import MagicMock, patch

The provider-specific load logic that used to live in private
``_add_<X>_models`` methods now lives in plugin classes under
``application/llm/providers/`` and YAML catalogs under
``application/core/models/``. End-to-end coverage of the registry +
plugin pipeline is in ``tests/core/test_model_registry_yaml.py``.

This file covers the data classes (``AvailableModel``,
``ModelCapabilities``, ``ModelProvider``) and the singleton/lookup
contract on ``ModelRegistry``.
"""

from unittest.mock import patch

import pytest

@@ -13,7 +24,6 @@ from application.core.model_settings import (


class TestModelProvider:

    @pytest.mark.unit
    def test_all_providers_exist(self):
        assert ModelProvider.OPENAI == "openai"
@@ -31,7 +41,6 @@ class TestModelProvider:


class TestModelCapabilities:

    @pytest.mark.unit
    def test_defaults(self):
        caps = ModelCapabilities()
@@ -56,7 +65,6 @@ class TestModelCapabilities:


class TestAvailableModel:

    @pytest.mark.unit
    def test_to_dict_basic(self):
        model = AvailableModel(
@@ -78,35 +86,67 @@ class TestAvailableModel:
            id="local-model",
            provider=ModelProvider.OPENAI,
            display_name="Local",
            base_url="http://localhost:11434",
            base_url="http://localhost:11434/v1",
        )
        d = model.to_dict()
        assert d["base_url"] == "http://localhost:11434"
        assert d["base_url"] == "http://localhost:11434/v1"

    @pytest.mark.unit
    def test_to_dict_includes_capabilities(self):
        caps = ModelCapabilities(supports_tools=True, context_window=64000)
        caps = ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            context_window=200000,
            supported_attachment_types=["image/png"],
        )
        model = AvailableModel(
            id="m1",
            provider=ModelProvider.ANTHROPIC,
            display_name="M1",
            id="m",
            provider=ModelProvider.OPENAI,
            display_name="M",
            capabilities=caps,
        )
        d = model.to_dict()
        assert d["supports_tools"] is True
        assert d["context_window"] == 64000
        assert d["supports_structured_output"] is True
        assert d["context_window"] == 200000
        assert d["supported_attachment_types"] == ["image/png"]

    @pytest.mark.unit
    def test_to_dict_disabled_model(self):
        model = AvailableModel(
            id="disabled",
            provider=ModelProvider.OPENAI,
            display_name="Disabled",
            enabled=False,
        )
        d = model.to_dict()
        assert d["enabled"] is False

    @pytest.mark.unit
    def test_api_key_field_never_serialized(self):
        """Forward-compat hook: AvailableModel.api_key (reserved for the
        future end-user BYOM phase) must never leak into the wire format."""
        model = AvailableModel(
            id="byom",
            provider=ModelProvider.OPENAI,
            display_name="BYOM",
            api_key="secret-key-do-not-leak",
        )
        d = model.to_dict()
        assert "api_key" not in d
        for v in d.values():
            assert v != "secret-key-do-not-leak"
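
# One way to guarantee the invariant this test pins down is an explicit
# allow-list in to_dict() instead of dataclasses.asdict(); a sketch (field
# names beyond those asserted in this file are assumptions):

def serialize_model(m) -> dict:
    # api_key is deliberately never copied in, so it cannot leak even if
    # the dataclass later grows fields.
    return {
        "id": m.id,
        "provider": m.provider.value,
        "display_name": m.display_name,
        "enabled": m.enabled,
        "supports_tools": m.capabilities.supports_tools,
        "context_window": m.capabilities.context_window,
        "supported_attachment_types": m.capabilities.supported_attachment_types,
    }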

class TestModelRegistry:
class TestModelRegistryPublicAPI:
    """Covers the public lookup contract. Loading behavior is exercised
    end-to-end in tests/core/test_model_registry_yaml.py."""

    @pytest.fixture(autouse=True)
    def _reset_singleton(self):
        """Reset singleton between tests."""
        ModelRegistry._instance = None
        ModelRegistry._initialized = False
        ModelRegistry.reset()
        yield
        ModelRegistry._instance = None
        ModelRegistry._initialized = False
        ModelRegistry.reset()

    @pytest.mark.unit
    def test_singleton(self):
@@ -125,7 +165,9 @@ class TestModelRegistry:
    def test_get_model(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            model = AvailableModel(id="test", provider=ModelProvider.OPENAI, display_name="Test")
            model = AvailableModel(
                id="test", provider=ModelProvider.OPENAI, display_name="Test"
            )
            reg.models["test"] = model
            assert reg.get_model("test") is model
            assert reg.get_model("nonexistent") is None
@@ -134,16 +176,30 @@ class TestModelRegistry:
    def test_get_all_models(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
            reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.ANTHROPIC, display_name="M2")
            reg.models["m1"] = AvailableModel(
                id="m1", provider=ModelProvider.OPENAI, display_name="M1"
            )
            reg.models["m2"] = AvailableModel(
                id="m2", provider=ModelProvider.ANTHROPIC, display_name="M2"
            )
            assert len(reg.get_all_models()) == 2

    @pytest.mark.unit
    def test_get_enabled_models(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1", enabled=True)
            reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.OPENAI, display_name="M2", enabled=False)
            reg.models["m1"] = AvailableModel(
                id="m1",
                provider=ModelProvider.OPENAI,
                display_name="M1",
                enabled=True,
            )
            reg.models["m2"] = AvailableModel(
                id="m2",
                provider=ModelProvider.OPENAI,
                display_name="M2",
                enabled=False,
            )
            enabled = reg.get_enabled_models()
            assert len(enabled) == 1
            assert enabled[0].id == "m1"
@@ -152,652 +208,29 @@ class TestModelRegistry:
    def test_model_exists(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
            reg.models["m1"] = AvailableModel(
                id="m1", provider=ModelProvider.OPENAI, display_name="M1"
            )
            assert reg.model_exists("m1") is True
            assert reg.model_exists("m2") is False

    @pytest.mark.unit
    def test_parse_model_names(self):
    def test_lookups_accept_user_id_kwarg(self):
        """Reserved for the future end-user BYOM phase. Currently ignored."""
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            assert reg._parse_model_names("model1,model2") == ["model1", "model2"]
            assert reg._parse_model_names("model1 , model2 ") == ["model1", "model2"]
            assert reg._parse_model_names("single") == ["single"]
            assert reg._parse_model_names("") == []
            assert reg._parse_model_names(None) == []
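
# The removed test above documents _parse_model_names exactly; an
# implementation satisfying all five assertions is a one-line filter:

def _parse_model_names(raw):
    if not raw:
        return []
    return [name.strip() for name in raw.split(",") if name.strip()]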

    @pytest.mark.unit
    def test_add_docsgpt_models(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            reg._add_docsgpt_models(mock_settings)
            assert "docsgpt-local" in reg.models

    @pytest.mark.unit
    def test_add_huggingface_models(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            reg._add_huggingface_models(mock_settings)
            assert "huggingface-local" in reg.models

    @pytest.mark.unit
    def test_load_models_with_openai_key(self):
        mock_settings = MagicMock()
        mock_settings.OPENAI_BASE_URL = None
        mock_settings.OPENAI_API_KEY = "sk-test"
        mock_settings.OPENAI_API_BASE = None
        mock_settings.ANTHROPIC_API_KEY = None
        mock_settings.GOOGLE_API_KEY = None
        mock_settings.GROQ_API_KEY = None
        mock_settings.OPEN_ROUTER_API_KEY = None
        mock_settings.NOVITA_API_KEY = None
        mock_settings.HUGGINGFACE_API_KEY = None
        mock_settings.LLM_PROVIDER = "openai"
        mock_settings.LLM_NAME = ""
        mock_settings.API_KEY = None

        with patch("application.core.settings.settings", mock_settings):
            reg = ModelRegistry()
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_load_models_custom_openai_base_url(self):
        mock_settings = MagicMock()
        mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
        mock_settings.OPENAI_API_KEY = "sk-test"
        mock_settings.OPENAI_API_BASE = None
        mock_settings.ANTHROPIC_API_KEY = None
        mock_settings.GOOGLE_API_KEY = None
        mock_settings.GROQ_API_KEY = None
        mock_settings.OPEN_ROUTER_API_KEY = None
        mock_settings.NOVITA_API_KEY = None
        mock_settings.HUGGINGFACE_API_KEY = None
        mock_settings.LLM_PROVIDER = "openai"
        mock_settings.LLM_NAME = "llama3,gemma"
        mock_settings.API_KEY = None

        with patch("application.core.settings.settings", mock_settings):
            reg = ModelRegistry()
            assert "llama3" in reg.models
            assert "gemma" in reg.models

    @pytest.mark.unit
    def test_default_model_selection_from_llm_name(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {"gpt-4": AvailableModel(id="gpt-4", provider=ModelProvider.OPENAI, display_name="GPT-4")}
            reg.default_model_id = "gpt-4"
            assert reg.default_model_id == "gpt-4"

    @pytest.mark.unit
    def test_add_anthropic_models_with_key(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.ANTHROPIC_API_KEY = "sk-ant-test"
            mock_settings.LLM_PROVIDER = ""
            mock_settings.LLM_NAME = ""
            reg._add_anthropic_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_google_models_with_key(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.GOOGLE_API_KEY = "google-test"
            mock_settings.LLM_PROVIDER = ""
            mock_settings.LLM_NAME = ""
            reg._add_google_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_groq_models_with_key(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.GROQ_API_KEY = "groq-test"
            mock_settings.LLM_PROVIDER = ""
            mock_settings.LLM_NAME = ""
            reg._add_groq_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_openrouter_models_with_key(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.OPEN_ROUTER_API_KEY = "or-test"
            mock_settings.LLM_PROVIDER = ""
            mock_settings.LLM_NAME = ""
            reg._add_openrouter_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_novita_models_with_key(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.NOVITA_API_KEY = "novita-test"
            mock_settings.LLM_PROVIDER = ""
            mock_settings.LLM_NAME = ""
            reg._add_novita_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_azure_openai_models_specific(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.LLM_PROVIDER = "azure_openai"
            mock_settings.LLM_NAME = "nonexistent-model"
            reg._add_azure_openai_models(mock_settings)
            # Falls through to adding all azure models
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_anthropic_models_no_key_with_provider(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.ANTHROPIC_API_KEY = None
            mock_settings.LLM_PROVIDER = "anthropic"
            mock_settings.LLM_NAME = "nonexistent"
            reg._add_anthropic_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_default_model_fallback_to_first(self):
        mock_settings = MagicMock()
        mock_settings.OPENAI_BASE_URL = None
        mock_settings.OPENAI_API_KEY = None
        mock_settings.OPENAI_API_BASE = None
        mock_settings.ANTHROPIC_API_KEY = None
        mock_settings.GOOGLE_API_KEY = None
        mock_settings.GROQ_API_KEY = None
        mock_settings.OPEN_ROUTER_API_KEY = None
        mock_settings.NOVITA_API_KEY = None
        mock_settings.HUGGINGFACE_API_KEY = None
        mock_settings.LLM_PROVIDER = ""
        mock_settings.LLM_NAME = ""
        mock_settings.API_KEY = None

        with patch("application.core.settings.settings", mock_settings):
            reg = ModelRegistry()
            # Should have at least docsgpt-local
            assert reg.default_model_id is not None

    @pytest.mark.unit
    def test_default_model_from_provider_fallback(self):
        """When LLM_NAME is not set but LLM_PROVIDER and API_KEY are,
        default should be first model of that provider."""
        mock_settings = MagicMock()
        mock_settings.OPENAI_BASE_URL = None
        mock_settings.OPENAI_API_KEY = "sk-test"
        mock_settings.OPENAI_API_BASE = None
        mock_settings.ANTHROPIC_API_KEY = None
        mock_settings.GOOGLE_API_KEY = None
        mock_settings.GROQ_API_KEY = None
        mock_settings.OPEN_ROUTER_API_KEY = None
        mock_settings.NOVITA_API_KEY = None
        mock_settings.HUGGINGFACE_API_KEY = None
        mock_settings.LLM_PROVIDER = "openai"
        mock_settings.LLM_NAME = None
        mock_settings.API_KEY = "sk-test"

        with patch("application.core.settings.settings", mock_settings):
            reg = ModelRegistry()
            assert reg.default_model_id is not None

    @pytest.mark.unit
    def test_add_google_models_no_key_with_provider(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.GOOGLE_API_KEY = None
            mock_settings.LLM_PROVIDER = "google"
            mock_settings.LLM_NAME = "nonexistent"
            reg._add_google_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_groq_models_no_key_with_provider(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.GROQ_API_KEY = None
            mock_settings.LLM_PROVIDER = "groq"
            mock_settings.LLM_NAME = "nonexistent"
            reg._add_groq_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_openrouter_models_no_key_with_provider(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.OPEN_ROUTER_API_KEY = None
            mock_settings.LLM_PROVIDER = "openrouter"
            mock_settings.LLM_NAME = "nonexistent"
            reg._add_openrouter_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_novita_models_no_key_with_provider(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.NOVITA_API_KEY = None
            mock_settings.LLM_PROVIDER = "novita"
            mock_settings.LLM_NAME = "nonexistent"
            reg._add_novita_models(mock_settings)
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_to_dict_disabled_model(self):
        model = AvailableModel(
            id="disabled",
            provider=ModelProvider.OPENAI,
            display_name="Disabled",
            enabled=False,
        )
        d = model.to_dict()
        assert d["enabled"] is False

    @pytest.mark.unit
    def test_to_dict_with_attachment_types(self):
        caps = ModelCapabilities(
            supported_attachment_types=["image/png", "application/pdf"],
        )
        model = AvailableModel(
            id="vision",
            provider=ModelProvider.OPENAI,
            display_name="Vision",
            capabilities=caps,
        )
        d = model.to_dict()
        assert d["supported_attachment_types"] == ["image/png", "application/pdf"]

    # ----------------------------------------------------------------
    # Coverage for _add_* methods with matching LLM_NAME
    # Lines: 100, 105, 147, 171, 179, 186, 199-201, 204, 210, 213,
    # 218, 229, 233, 241, 250
    # ----------------------------------------------------------------

    @pytest.mark.unit
    def test_add_azure_openai_models_with_matching_name(self):
        """Cover line 186: azure model matching LLM_NAME returns early."""
        from application.core.model_configs import AZURE_OPENAI_MODELS

        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.LLM_PROVIDER = "azure_openai"
            if AZURE_OPENAI_MODELS:
                mock_settings.LLM_NAME = AZURE_OPENAI_MODELS[0].id
            else:
                mock_settings.LLM_NAME = "nonexistent"
            reg._add_azure_openai_models(mock_settings)
            # Should have added at least one model
            assert len(reg.models) >= 1

    @pytest.mark.unit
    def test_add_anthropic_no_key_no_provider_fallthrough(self):
        """Cover lines 199-204: no key, provider set but name not found -> add all."""
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.ANTHROPIC_API_KEY = None
            mock_settings.LLM_PROVIDER = "anthropic"
            mock_settings.LLM_NAME = "nonexistent-model"
            reg._add_anthropic_models(mock_settings)
            # Falls through to add all anthropic models
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_google_no_key_matching_name(self):
        """Cover lines 213-218: Google fallback with matching name."""
        from application.core.model_configs import GOOGLE_MODELS

        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.GOOGLE_API_KEY = None
            mock_settings.LLM_PROVIDER = "google"
            if GOOGLE_MODELS:
                mock_settings.LLM_NAME = GOOGLE_MODELS[0].id
            else:
                mock_settings.LLM_NAME = "nonexistent"
            reg._add_google_models(mock_settings)
            assert len(reg.models) >= 1

    @pytest.mark.unit
    def test_add_groq_no_key_matching_name(self):
        """Cover lines 229-233: Groq fallback with matching name."""
        from application.core.model_configs import GROQ_MODELS

        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.GROQ_API_KEY = None
            mock_settings.LLM_PROVIDER = "groq"
            if GROQ_MODELS:
                mock_settings.LLM_NAME = GROQ_MODELS[0].id
            else:
                mock_settings.LLM_NAME = "nonexistent"
            reg._add_groq_models(mock_settings)
            assert len(reg.models) >= 1

    @pytest.mark.unit
    def test_add_openrouter_no_key_matching_name(self):
        """Cover lines 241-250: OpenRouter fallback with matching name."""
        from application.core.model_configs import OPENROUTER_MODELS

        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.OPEN_ROUTER_API_KEY = None
            mock_settings.LLM_PROVIDER = "openrouter"
            if OPENROUTER_MODELS:
                mock_settings.LLM_NAME = OPENROUTER_MODELS[0].id
            else:
                mock_settings.LLM_NAME = "nonexistent"
            reg._add_openrouter_models(mock_settings)
            assert len(reg.models) >= 1

    @pytest.mark.unit
    def test_add_novita_no_key_matching_name(self):
        """Cover novita fallback with matching name."""
        from application.core.model_configs import NOVITA_MODELS

        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.NOVITA_API_KEY = None
            mock_settings.LLM_PROVIDER = "novita"
            if NOVITA_MODELS:
                mock_settings.LLM_NAME = NOVITA_MODELS[0].id
            else:
                mock_settings.LLM_NAME = "nonexistent"
            reg._add_novita_models(mock_settings)
            assert len(reg.models) >= 1

    @pytest.mark.unit
    def test_load_models_default_from_llm_name_exact_match(self):
        """Cover line 136/147: exact LLM_NAME match for default model."""
        mock_settings = MagicMock()
        mock_settings.OPENAI_BASE_URL = None
        mock_settings.OPENAI_API_KEY = "sk-test"
        mock_settings.OPENAI_API_BASE = None
        mock_settings.ANTHROPIC_API_KEY = None
        mock_settings.GOOGLE_API_KEY = None
        mock_settings.GROQ_API_KEY = None
        mock_settings.OPEN_ROUTER_API_KEY = None
        mock_settings.NOVITA_API_KEY = None
        mock_settings.HUGGINGFACE_API_KEY = None
        mock_settings.LLM_PROVIDER = "openai"
        mock_settings.API_KEY = None

        from application.core.model_configs import OPENAI_MODELS

        if OPENAI_MODELS:
            mock_settings.LLM_NAME = OPENAI_MODELS[0].id
        else:
            mock_settings.LLM_NAME = "gpt-4o"

        with patch("application.core.settings.settings", mock_settings):
            reg = ModelRegistry()
            assert reg.default_model_id is not None

    @pytest.mark.unit
    def test_add_openai_models_local_endpoint_no_name(self):
        """Cover line 171: local endpoint without LLM_NAME adds nothing."""
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
            mock_settings.OPENAI_API_KEY = "sk-test"
            mock_settings.LLM_NAME = None
            reg._add_openai_models(mock_settings)
            assert len(reg.models) == 0

    @pytest.mark.unit
    def test_add_openai_standard_no_api_key(self):
        """Cover line 179: standard OpenAI without API key adds nothing."""
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.OPENAI_BASE_URL = None
            mock_settings.OPENAI_API_KEY = None
            reg._add_openai_models(mock_settings)
            assert len(reg.models) == 0


# ---------------------------------------------------------------------------
# Coverage — additional uncovered lines: 100, 105, 147, 171, 179, 186, 250
# ---------------------------------------------------------------------------


@pytest.mark.unit
class TestModelRegistryAdditionalCoverage:

    def test_add_azure_openai_models_specific_name(self):
        """Cover line 186: azure_openai with specific LLM_NAME match."""
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.LLM_PROVIDER = "azure_openai"
            mock_settings.LLM_NAME = "gpt-4o"

            # Create a fake model that matches
            fake_model = MagicMock()
            fake_model.id = "gpt-4o"
            with patch(
                "application.core.model_configs.AZURE_OPENAI_MODELS",
                [fake_model],
            ):
                reg._add_azure_openai_models(mock_settings)
                assert "gpt-4o" in reg.models

    def test_add_anthropic_models_with_api_key(self):
        """Cover line 100: anthropic with API key."""
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.ANTHROPIC_API_KEY = "sk-test"
            mock_settings.LLM_PROVIDER = "anthropic"
            reg._add_anthropic_models(mock_settings)
            assert len(reg.models) > 0

    def test_add_google_models_with_api_key(self):
        """Cover line 105: google with API key."""
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.GOOGLE_API_KEY = "test-key"
            mock_settings.LLM_PROVIDER = "google"
            reg._add_google_models(mock_settings)
            assert len(reg.models) > 0

    def test_default_model_from_provider(self):
        """Cover line 147: default model selected from provider."""
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            reg.default_model_id = None

            fake_model = MagicMock()
            fake_model.provider = MagicMock()
            fake_model.provider.value = "openai"
            reg.models["gpt-4o"] = fake_model

            mock_settings = MagicMock()
            mock_settings.LLM_NAME = None
            mock_settings.LLM_PROVIDER = "openai"
            mock_settings.API_KEY = "key"

            # Simulate the default selection logic
            if not reg.default_model_id:
                for model_id, model in reg.models.items():
                    if model.provider.value == mock_settings.LLM_PROVIDER:
                        reg.default_model_id = model_id
                        break

            assert reg.default_model_id == "gpt-4o"

    def test_add_openai_local_endpoint_with_llm_name(self):
        """Cover line 171: local endpoint registers custom models from LLM_NAME."""
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
            mock_settings.OPENAI_API_KEY = "sk-test"
            mock_settings.LLM_NAME = "llama3,phi3"
            reg._add_openai_models(mock_settings)
            assert "llama3" in reg.models
            assert "phi3" in reg.models

    def test_add_openai_standard_with_api_key(self):
        """Cover line 179: standard OpenAI with API key adds models."""
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.OPENAI_BASE_URL = None
            mock_settings.OPENAI_API_KEY = "sk-real-key"
            reg._add_openai_models(mock_settings)
            assert len(reg.models) > 0

    def test_add_openrouter_models(self):
        """Cover line 250: openrouter models added."""
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            mock_settings = MagicMock()
            mock_settings.OPEN_ROUTER_API_KEY = "or-key"
            mock_settings.LLM_PROVIDER = "openrouter"
            reg._add_openrouter_models(mock_settings)
            assert len(reg.models) > 0


# ---------------------------------------------------------------------------
# Additional coverage for model_settings.py
# Lines: 135-136 (backward compat LLM_NAME), 138-143 (provider fallback),
# 145-146 (first model as default)
# ---------------------------------------------------------------------------
|
||||
# Imports already at the top of the file; no additional imports needed
|
||||
|
||||
|
||||
@pytest.mark.unit
class TestDefaultModelSelectionBackwardCompat:
    """Cover lines 135-136: backward compat exact match on LLM_NAME."""

    def test_llm_name_exact_match_as_default(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            reg.default_model_id = None
            # Add a model with composite ID
            model = AvailableModel(
                id="my-composite-model",
                provider=ModelProvider.OPENAI,
                display_name="Composite",
                description="test",
                capabilities=ModelCapabilities(),
            )
            reg.models["my-composite-model"] = model

            # Simulate _parse_model_names returning something different
            # so that the first for-loop doesn't match
            mock_settings = MagicMock()
            mock_settings.LLM_NAME = "my-composite-model"
            mock_settings.LLM_PROVIDER = None
            mock_settings.API_KEY = None

            # Call the logic directly
            model_names = reg._parse_model_names(mock_settings.LLM_NAME)
            for mn in model_names:
                if mn in reg.models:
                    reg.default_model_id = mn
                    break

            assert reg.default_model_id == "my-composite-model"


@pytest.mark.unit
class TestDefaultModelSelectionByProvider:
    """Cover lines 138-143: default model by provider when LLM_NAME doesn't match."""

    def test_default_by_provider(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            reg.default_model_id = None
            model = AvailableModel(
                id="gpt-4",
                provider=ModelProvider.OPENAI,
                display_name="GPT-4",
                description="test",
                capabilities=ModelCapabilities(),
            )
            reg.models["gpt-4"] = model

            # Simulate: LLM_NAME doesn't exist/match, but LLM_PROVIDER + API_KEY set
            if not reg.default_model_id:
                for model_id, m in reg.models.items():
                    if m.provider.value == "openai":
                        reg.default_model_id = model_id
                        break

            assert reg.default_model_id == "gpt-4"


@pytest.mark.unit
class TestDefaultModelSelectionFirstModel:
    """Cover lines 145-146: first model as default when nothing else matches."""

    def test_first_model_as_default(self):
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
            reg.models = {}
            reg.default_model_id = None
            model = AvailableModel(
                id="fallback-model",
                provider=ModelProvider.OPENAI,
                display_name="Fallback",
                description="test",
                capabilities=ModelCapabilities(),
            )
            reg.models["fallback-model"] = model

            if not reg.default_model_id and reg.models:
                reg.default_model_id = next(iter(reg.models.keys()))

            assert reg.default_model_id == "fallback-model"

    @pytest.mark.unit
    def test_reset(self):
        r1 = ModelRegistry()
        ModelRegistry.reset()
        r2 = ModelRegistry()
        assert r1 is not r2


tests/core/test_models_config_dir.py (new file, 208 lines)
@@ -0,0 +1,208 @@
"""Phase 3 tests: operator MODELS_CONFIG_DIR.
|
||||
|
||||
Covers the operator-supplied directory of model YAMLs that's loaded
|
||||
after the built-in catalog. Operators use this to add new
|
||||
``openai_compatible`` providers, extend an existing provider's catalog
|
||||
with extra models, or override a built-in model's capabilities — all
|
||||
without forking the repo.
|
||||
"""

from __future__ import annotations

import logging
from textwrap import dedent
from unittest.mock import MagicMock, patch

import pytest

from application.core.model_registry import ModelRegistry


def _make_settings(**overrides):
    s = MagicMock()
    s.OPENAI_BASE_URL = None
    s.OPENAI_API_KEY = None
    s.OPENAI_API_BASE = None
    s.ANTHROPIC_API_KEY = None
    s.GOOGLE_API_KEY = None
    s.GROQ_API_KEY = None
    s.OPEN_ROUTER_API_KEY = None
    s.NOVITA_API_KEY = None
    s.HUGGINGFACE_API_KEY = None
    s.LLM_PROVIDER = ""
    s.LLM_NAME = None
    s.API_KEY = None
    s.MODELS_CONFIG_DIR = None
    for k, v in overrides.items():
        setattr(s, k, v)
    return s


@pytest.fixture(autouse=True)
def _reset_registry():
    ModelRegistry.reset()
    yield
    ModelRegistry.reset()


# ── New provider via openai_compatible ───────────────────────────────────


@pytest.mark.unit
class TestOperatorAddsNewProvider:
    def test_drop_in_yaml_appears_in_registry(
        self, tmp_path, monkeypatch
    ):
        (tmp_path / "fireworks.yaml").write_text(dedent("""
            provider: openai_compatible
            display_provider: fireworks
            api_key_env: FIREWORKS_API_KEY
            base_url: https://api.fireworks.ai/inference/v1
            defaults:
              supports_tools: true
            models:
              - id: accounts/fireworks/models/llama-v3p3-70b-instruct
                display_name: Llama 3.3 70B (Fireworks)
        """))
        monkeypatch.setenv("FIREWORKS_API_KEY", "fw-key")

        s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()

        m = reg.get_model("accounts/fireworks/models/llama-v3p3-70b-instruct")
        assert m is not None
        assert m.api_key == "fw-key"
        assert m.base_url == "https://api.fireworks.ai/inference/v1"
        assert m.display_provider == "fireworks"


# ── Extending an existing provider's catalog ─────────────────────────────


@pytest.mark.unit
class TestOperatorExtendsExistingProvider:
    def test_operator_adds_anthropic_model_to_builtin_catalog(
        self, tmp_path
    ):
        (tmp_path / "anthropic-extra.yaml").write_text(dedent("""
            provider: anthropic
            defaults:
              supports_tools: true
              context_window: 200000
            models:
              - id: claude-haiku-5-0-future
                display_name: Claude Haiku 5.0
        """))

        s = _make_settings(
            ANTHROPIC_API_KEY="sk-ant",
            MODELS_CONFIG_DIR=str(tmp_path),
        )
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()

        # Built-in models still present
        assert reg.get_model("claude-sonnet-4-6") is not None
        assert reg.get_model("claude-opus-4-7") is not None
        # Operator-added model also present
        added = reg.get_model("claude-haiku-5-0-future")
        assert added is not None
        assert added.display_name == "Claude Haiku 5.0"


# ── Overriding a built-in model's capabilities ───────────────────────────


@pytest.mark.unit
class TestOperatorOverridesBuiltinCapabilities:
    def test_operator_yaml_overrides_builtin_context_window(
        self, tmp_path, caplog
    ):
        # Override anthropic claude-haiku-4-5 to claim a 1M context window
        (tmp_path / "anthropic-override.yaml").write_text(dedent("""
            provider: anthropic
            defaults:
              supports_tools: true
              attachments: [image]
              context_window: 1000000
            models:
              - id: claude-haiku-4-5
                display_name: Claude Haiku 4.5 (extended)
                description: Operator-overridden capabilities
        """))

        s = _make_settings(
            ANTHROPIC_API_KEY="sk-ant",
            MODELS_CONFIG_DIR=str(tmp_path),
        )
        with caplog.at_level(logging.WARNING):
            with patch("application.core.settings.settings", s):
                reg = ModelRegistry()

        m = reg.get_model("claude-haiku-4-5")
        assert m.display_name == "Claude Haiku 4.5 (extended)"
        assert m.description == "Operator-overridden capabilities"
        assert m.capabilities.context_window == 1_000_000

        # And the override warning fires so the operator can audit it
        assert any(
            "claude-haiku-4-5" in rec.message and "redefined" in rec.message
            for rec in caplog.records
        )


# ── Misconfigured MODELS_CONFIG_DIR ──────────────────────────────────────


@pytest.mark.unit
class TestMisconfiguredOperatorDir:
    def test_missing_dir_logs_warning_and_continues(
        self, tmp_path, caplog
    ):
        bogus = tmp_path / "does-not-exist"
        s = _make_settings(MODELS_CONFIG_DIR=str(bogus))

        with caplog.at_level(logging.WARNING):
            with patch("application.core.settings.settings", s):
                reg = ModelRegistry()

        # Built-in catalog still loaded
        assert reg.get_model("docsgpt-local") is not None
        # And the operator was warned
        assert any("does not exist" in rec.message for rec in caplog.records)

    def test_path_is_a_file_logs_warning(self, tmp_path, caplog):
        afile = tmp_path / "not-a-dir.yaml"
        afile.write_text("provider: anthropic\nmodels: []")

        s = _make_settings(MODELS_CONFIG_DIR=str(afile))
        with caplog.at_level(logging.WARNING):
            with patch("application.core.settings.settings", s):
                reg = ModelRegistry()

        assert reg.get_model("docsgpt-local") is not None
        assert any("not a directory" in rec.message for rec in caplog.records)


# ── Validation: unknown provider rejected ────────────────────────────────


@pytest.mark.unit
class TestOperatorValidation:
    def test_unknown_provider_in_operator_yaml_aborts_boot(self, tmp_path):
        (tmp_path / "bogus.yaml").write_text(dedent("""
            provider: not_a_real_provider
            models:
              - id: x
                display_name: X
        """))

        s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
        with patch("application.core.settings.settings", s):
            with pytest.raises(Exception) as exc_info:
                ModelRegistry()
        # Could be ModelYAMLError (enum check) or ValueError (registry check);
        # either way the message must surface what's wrong.
        msg = str(exc_info.value)
        assert "not_a_real_provider" in msg

tests/core/test_openai_compatible.py (new file, 298 lines)
@@ -0,0 +1,298 @@
"""Phase 2 tests for the openai_compatible provider.
|
||||
|
||||
Covers YAML loading from a temp directory, multiple coexisting catalogs
|
||||
(Mistral + Together), env-var-based credential resolution, the legacy
|
||||
OPENAI_BASE_URL + LLM_NAME fallback, and end-to-end model dispatch
|
||||
through LLMCreator.
|
||||
"""

from __future__ import annotations

from pathlib import Path
from textwrap import dedent
from unittest.mock import MagicMock, patch

import pytest

from application.core.model_registry import ModelRegistry
from application.core.model_settings import ModelProvider


def _make_settings(**overrides):
    s = MagicMock()
    s.OPENAI_BASE_URL = None
    s.OPENAI_API_KEY = None
    s.OPENAI_API_BASE = None
    s.ANTHROPIC_API_KEY = None
    s.GOOGLE_API_KEY = None
    s.GROQ_API_KEY = None
    s.OPEN_ROUTER_API_KEY = None
    s.NOVITA_API_KEY = None
    s.HUGGINGFACE_API_KEY = None
    s.LLM_PROVIDER = ""
    s.LLM_NAME = None
    s.API_KEY = None
    s.MODELS_CONFIG_DIR = None
    for k, v in overrides.items():
        setattr(s, k, v)
    return s


def _write_mistral_yaml(directory: Path) -> Path:
    path = directory / "mistral.yaml"
    path.write_text(dedent("""
        provider: openai_compatible
        display_provider: mistral
        api_key_env: MISTRAL_API_KEY
        base_url: https://api.mistral.ai/v1
        defaults:
          supports_tools: true
          context_window: 128000
        models:
          - id: mistral-large-latest
            display_name: Mistral Large
          - id: mistral-small-latest
            display_name: Mistral Small
    """))
    return path


def _write_together_yaml(directory: Path) -> Path:
    path = directory / "together.yaml"
    path.write_text(dedent("""
        provider: openai_compatible
        display_provider: together
        api_key_env: TOGETHER_API_KEY
        base_url: https://api.together.xyz/v1
        defaults:
          supports_tools: true
        models:
          - id: meta-llama/Llama-3.3-70B-Instruct-Turbo
            display_name: Llama 3.3 70B (Together)
    """))
    return path


@pytest.fixture(autouse=True)
def _reset_registry():
    ModelRegistry.reset()
    yield
    ModelRegistry.reset()


# ── YAML-driven catalogs ─────────────────────────────────────────────────


@pytest.mark.unit
class TestYAMLCompatibleProvider:
    def test_mistral_yaml_loads_with_env_key(
        self, tmp_path, monkeypatch
    ):
        _write_mistral_yaml(tmp_path)
        monkeypatch.setenv("MISTRAL_API_KEY", "sk-mistral-test")

        s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()

        m = reg.get_model("mistral-large-latest")
        assert m is not None
        assert m.provider == ModelProvider.OPENAI_COMPATIBLE
        assert m.display_provider == "mistral"
        assert m.base_url == "https://api.mistral.ai/v1"
        assert m.api_key == "sk-mistral-test"
        assert m.capabilities.supports_tools is True
        assert m.capabilities.context_window == 128000

    def test_yaml_skipped_when_env_var_missing(
        self, tmp_path, monkeypatch
    ):
        _write_mistral_yaml(tmp_path)
        monkeypatch.delenv("MISTRAL_API_KEY", raising=False)

        s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()

        # Catalog skipped when no key — no Mistral models in the registry
        assert reg.get_model("mistral-large-latest") is None

    def test_two_compatible_catalogs_coexist_with_separate_keys(
        self, tmp_path, monkeypatch
    ):
        _write_mistral_yaml(tmp_path)
        _write_together_yaml(tmp_path)
        monkeypatch.setenv("MISTRAL_API_KEY", "sk-mistral")
        monkeypatch.setenv("TOGETHER_API_KEY", "sk-together")

        s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()

        mistral = reg.get_model("mistral-large-latest")
        together = reg.get_model("meta-llama/Llama-3.3-70B-Instruct-Turbo")

        assert mistral.api_key == "sk-mistral"
        assert mistral.base_url == "https://api.mistral.ai/v1"
        assert mistral.display_provider == "mistral"

        assert together.api_key == "sk-together"
        assert together.base_url == "https://api.together.xyz/v1"
        assert together.display_provider == "together"

    def test_one_catalog_enabled_other_skipped(
        self, tmp_path, monkeypatch
    ):
        _write_mistral_yaml(tmp_path)
        _write_together_yaml(tmp_path)
        monkeypatch.setenv("MISTRAL_API_KEY", "sk-mistral")
        monkeypatch.delenv("TOGETHER_API_KEY", raising=False)

        s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()

        assert reg.get_model("mistral-large-latest") is not None
        assert reg.get_model("meta-llama/Llama-3.3-70B-Instruct-Turbo") is None

    def test_missing_base_url_raises(self, tmp_path, monkeypatch):
        bad = tmp_path / "broken.yaml"
        bad.write_text(dedent("""
            provider: openai_compatible
            api_key_env: SOME_KEY
            models:
              - id: x
                display_name: X
        """))
        monkeypatch.setenv("SOME_KEY", "k")

        s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
        with patch("application.core.settings.settings", s):
            with pytest.raises(ValueError, match="must set 'base_url'"):
                ModelRegistry()

    def test_missing_api_key_env_raises(self, tmp_path, monkeypatch):
        bad = tmp_path / "broken.yaml"
        bad.write_text(dedent("""
            provider: openai_compatible
            base_url: https://x/v1
            models:
              - id: x
                display_name: X
        """))

        s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
        with patch("application.core.settings.settings", s):
            with pytest.raises(ValueError, match="must set 'api_key_env'"):
                ModelRegistry()

    def test_to_dict_uses_display_provider(
        self, tmp_path, monkeypatch
    ):
        _write_mistral_yaml(tmp_path)
        monkeypatch.setenv("MISTRAL_API_KEY", "sk")

        s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()

        d = reg.get_model("mistral-large-latest").to_dict()
        # /api/models response shows "mistral", not "openai_compatible"
        assert d["provider"] == "mistral"
        # api_key never leaks into the wire format
        assert "api_key" not in d
        for v in d.values():
            assert v != "sk"


# ── Legacy OPENAI_BASE_URL fallback ──────────────────────────────────────


@pytest.mark.unit
class TestLegacyOpenAIBaseURLPath:
    def test_legacy_models_now_provided_by_openai_compatible(self):
        s = _make_settings(
            OPENAI_BASE_URL="http://localhost:11434/v1",
            OPENAI_API_KEY="sk-local",
            LLM_PROVIDER="openai",
            LLM_NAME="llama3,gemma",
        )
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()

        ids = {m.id for m in reg.get_all_models()}
        assert ids == {"llama3", "gemma"}

        llama = reg.get_model("llama3")
        assert llama.base_url == "http://localhost:11434/v1"
        assert llama.api_key == "sk-local"
        assert llama.provider == ModelProvider.OPENAI_COMPATIBLE
        # Display provider preserves the historical "openai" label
        assert llama.display_provider == "openai"
        assert llama.to_dict()["provider"] == "openai"

    def test_legacy_uses_api_key_fallback_when_openai_api_key_missing(self):
        s = _make_settings(
            OPENAI_BASE_URL="http://localhost:11434/v1",
            OPENAI_API_KEY=None,
            API_KEY="sk-generic",
            LLM_PROVIDER="openai",
            LLM_NAME="llama3",
        )
        with patch("application.core.settings.settings", s):
            reg = ModelRegistry()
            assert reg.get_model("llama3").api_key == "sk-generic"


# ── Dispatch through LLMCreator ──────────────────────────────────────────


@pytest.mark.unit
class TestLLMCreatorDispatch:
    def test_llmcreator_uses_per_model_api_key_and_base_url(
        self, tmp_path, monkeypatch
    ):
        """End-to-end: when an openai_compatible model is dispatched, the
        per-model api_key + base_url from the registry must override
        whatever the caller passed."""
        _write_mistral_yaml(tmp_path)
        monkeypatch.setenv("MISTRAL_API_KEY", "sk-mistral-real")

        s = _make_settings(MODELS_CONFIG_DIR=str(tmp_path))

        captured = {}

        class _FakeLLM:
            def __init__(
                self, api_key, user_api_key, *args, **kwargs
            ):
                captured["api_key"] = api_key
                captured["base_url"] = kwargs.get("base_url")
                captured["model_id"] = kwargs.get("model_id")

        with patch("application.core.settings.settings", s):
            ModelRegistry.reset()
            ModelRegistry()  # warm up the registry under patched settings

            # Now patch the OpenAI plugin's class so we can capture the
            # constructor args without spinning up the real OpenAILLM.
            from application.llm.providers import PROVIDERS_BY_NAME

            with patch.object(
                PROVIDERS_BY_NAME["openai_compatible"],
                "llm_class",
                _FakeLLM,
            ):
                from application.llm.llm_creator import LLMCreator

                LLMCreator.create_llm(
                    type="openai_compatible",
                    api_key="caller-passed-WRONG-key",
                    user_api_key=None,
                    decoded_token={"sub": "u"},
                    model_id="mistral-large-latest",
                )

        assert captured["api_key"] == "sk-mistral-real"
        assert captured["base_url"] == "https://api.mistral.ai/v1"
        assert captured["model_id"] == "mistral-large-latest"

tests/core/test_registry_user_layer.py (new file, 505 lines)
@@ -0,0 +1,505 @@
"""Tests for the BYOM per-user layer on ModelRegistry.
|
||||
|
||||
Covers: per-user lookups don't leak across users, lookups without
|
||||
user_id stay built-in only, get_all_models / get_enabled_models /
|
||||
model_exists all consult the user layer when given user_id, and the
|
||||
explicit invalidate_user clears the cache.
|
||||
"""

from __future__ import annotations

from contextlib import contextmanager
from unittest.mock import MagicMock, patch

import pytest

from application.core.model_registry import ModelRegistry
from application.core.model_settings import ModelProvider
from application.storage.db.repositories.user_custom_models import (
    UserCustomModelsRepository,
)


@pytest.fixture(autouse=True)
def _reset_registry():
    ModelRegistry.reset()
    yield
    ModelRegistry.reset()


def _make_settings(**overrides):
    s = MagicMock()
    s.OPENAI_BASE_URL = None
    s.OPENAI_API_KEY = None
    s.OPENAI_API_BASE = None
    s.ANTHROPIC_API_KEY = None
    s.GOOGLE_API_KEY = None
    s.GROQ_API_KEY = None
    s.OPEN_ROUTER_API_KEY = None
    s.NOVITA_API_KEY = None
    s.HUGGINGFACE_API_KEY = None
    s.LLM_PROVIDER = ""
    s.LLM_NAME = None
    s.API_KEY = None
    s.MODELS_CONFIG_DIR = None
    for k, v in overrides.items():
        setattr(s, k, v)
    return s


@contextmanager
def _yield(conn):
    yield conn
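

# ``db_readonly`` (patched throughout these tests) is the app's read-only
# session context manager; ``_yield`` stands in for it so every registry
# read goes through the fixture's Postgres connection.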


@pytest.mark.unit
class TestPerUserLayer:
    def test_user_models_isolated_per_user(self, pg_conn):
        """Alice's BYOM model must not appear in Bob's lookups."""
        repo = UserCustomModelsRepository(pg_conn)
        alice_model = repo.create(
            user_id="alice",
            upstream_model_id="alice-mistral",
            display_name="Alice Mistral",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="sk-alice",
        )

        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ):
            reg = ModelRegistry()
            assert reg.get_model(alice_model["id"], user_id="alice") is not None
            assert reg.get_model(alice_model["id"], user_id="bob") is None
            # And without a user_id at all, the per-user layer is invisible
            assert reg.get_model(alice_model["id"]) is None

    def test_get_all_models_includes_user_models(self, pg_conn):
        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-1",
            upstream_model_id="mistral-large-latest",
            display_name="My Mistral",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="sk-test",
        )

        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ):
            reg = ModelRegistry()
            ids_no_user = {m.id for m in reg.get_all_models()}
            ids_with_user = {
                m.id for m in reg.get_all_models(user_id="user-1")
            }
            assert created["id"] not in ids_no_user
            assert created["id"] in ids_with_user

    def test_user_models_carry_decrypted_api_key_and_upstream_id(self, pg_conn):
        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-1",
            upstream_model_id="mistral-large-latest",
            display_name="My Mistral",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="sk-test-XYZ",
        )

        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ):
            reg = ModelRegistry()
            m = reg.get_model(created["id"], user_id="user-1")

        assert m is not None
        assert m.provider == ModelProvider.OPENAI_COMPATIBLE
        assert m.upstream_model_id == "mistral-large-latest"
        assert m.api_key == "sk-test-XYZ"
        assert m.base_url == "https://api.mistral.ai/v1"
        assert m.source == "user"
        # The wire format never leaks the api_key
        d = m.to_dict()
        assert "api_key" not in d
        for v in d.values():
            assert v != "sk-test-XYZ"

    def test_invalidate_user_clears_cache(self, pg_conn):
        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-1",
            upstream_model_id="x",
            display_name="X",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="k",
        )

        s = _make_settings()
        # Stub Redis so invalidate_user can publish its version bump
        # without hitting a real broker. The P1 fix calls ``incr`` on
        # invalidate; here we just need it not to raise.
        fake_redis = MagicMock()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ), patch(
            "application.cache.get_redis_instance", return_value=fake_redis
        ):
            reg = ModelRegistry()
            assert reg.get_model(created["id"], user_id="user-1") is not None
            # Cache populated
            assert "user-1" in reg._user_models
            ModelRegistry.invalidate_user("user-1")
            assert "user-1" not in reg._user_models
            # Re-lookup repopulates
            reg.get_model(created["id"], user_id="user-1")
            assert "user-1" in reg._user_models


@pytest.mark.unit
class TestLLMCreatorDispatchUsesUpstreamModelId:
    def test_llmcreator_sends_upstream_id_not_uuid(self, pg_conn):
        """End-to-end: LLMCreator with a BYOM uuid must construct the
        OpenAILLM with the user's upstream model name (e.g.
        ``mistral-large-latest``), not the registry uuid."""
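        # i.e. the stored row maps a registry UUID to an upstream name,
        # and the constructed LLM must see the upstream name:
        #   model_id=created["id"] (UUID) -> _FakeLLM(model_id="mistral-large-latest")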
        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-1",
            upstream_model_id="mistral-large-latest",
            display_name="My Mistral",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="sk-mistral-real",
        )

        captured = {}

        class _FakeLLM:
            def __init__(self, api_key, user_api_key, *args, **kwargs):
                captured["api_key"] = api_key
                captured["base_url"] = kwargs.get("base_url")
                captured["model_id"] = kwargs.get("model_id")

        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ):
            ModelRegistry()
            from application.llm.providers import PROVIDERS_BY_NAME

            with patch.object(
                PROVIDERS_BY_NAME["openai_compatible"], "llm_class", _FakeLLM
            ):
                from application.llm.llm_creator import LLMCreator

                LLMCreator.create_llm(
                    type="openai_compatible",
                    api_key="caller-passed-WRONG",
                    user_api_key=None,
                    decoded_token={"sub": "user-1"},
                    model_id=created["id"],
                )

        assert captured["api_key"] == "sk-mistral-real"
        assert captured["base_url"] == "https://api.mistral.ai/v1"
        assert captured["model_id"] == "mistral-large-latest"  # NOT the uuid!

    def test_llmcreator_forwards_byom_capabilities(self, pg_conn):
        """LLMCreator must thread the registry-resolved ``capabilities``
        into the LLM. Without it the OpenAILLM hard-codes ``True`` for
        tools/structured output and advertises image attachments
        unconditionally, leaking unsupported features to BYOMs that
        disabled them."""
        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-2",
            upstream_model_id="my-text-only-model",
            display_name="Text-only BYOM",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="sk-real",
            capabilities={
                "supports_tools": False,
                "supports_structured_output": False,
                "attachments": [],
                "context_window": 8192,
            },
        )

        captured = {}

        class _FakeLLM:
            def __init__(self, api_key, user_api_key, *args, **kwargs):
                captured["capabilities"] = kwargs.get("capabilities")

        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ):
            ModelRegistry()
            from application.llm.providers import PROVIDERS_BY_NAME

            with patch.object(
                PROVIDERS_BY_NAME["openai_compatible"], "llm_class", _FakeLLM
            ):
                from application.llm.llm_creator import LLMCreator

                LLMCreator.create_llm(
                    type="openai_compatible",
                    api_key="ignored",
                    user_api_key=None,
                    decoded_token={"sub": "user-2"},
                    model_id=created["id"],
                )

        caps = captured["capabilities"]
        assert caps is not None
        assert caps.supports_tools is False
        assert caps.supports_structured_output is False
        assert caps.supported_attachment_types == []

    def test_byom_image_alias_expands_to_mime_types(self, pg_conn):
        """A BYOM stored with ``attachments: ["image"]`` (the alias the
        UI sends) must surface as concrete MIME types on the registry
        record, matching the built-in YAML expansion. Without this,
        handlers/base.prepare_messages compares ``image/png`` against
        the bare alias and filters every image upload as unsupported.
        """
        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-1",
            upstream_model_id="my-vision-model",
            display_name="Vision BYOM",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="sk-real",
            capabilities={"attachments": ["image"]},
        )

        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ):
            reg = ModelRegistry()
            model = reg.get_model(created["id"], user_id="user-1")

        assert model is not None
        types = model.capabilities.supported_attachment_types
        assert "image" not in types, (
            "alias must be expanded, not stored verbatim"
        )
        assert any(t.startswith("image/") for t in types)
        # Must include at least the common web image types so any image
        # an end user uploads has a chance to match.
        assert "image/png" in types
        assert "image/jpeg" in types

    def test_byom_unknown_alias_is_skipped_at_runtime(self, pg_conn):
        """Operator alias-map edits could orphan a stored alias. The
        registry must drop the unknown entry rather than the entire
        layer (which would hide every BYOM the user has).
        """
        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-1",
            upstream_model_id="m",
            display_name="M",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="k",
            # Bypass the route validation: write a bad alias straight
            # to the row to simulate the post-edit orphan case.
            capabilities={"attachments": ["image", "not-a-real-alias"]},
        )

        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ):
            reg = ModelRegistry()
            model = reg.get_model(created["id"], user_id="user-1")

        assert model is not None
        types = model.capabilities.supported_attachment_types
        assert "not-a-real-alias" not in types
        assert any(t.startswith("image/") for t in types)


@pytest.mark.unit
class TestCrossProcessInvalidation:
    """The BYOM cache lives per-process. Without the P1 fix, a CRUD on
    web-1 would leave the decrypted record (with old api_key/base_url)
    cached forever in web-2 / Celery. These tests pin down that:

    * ``invalidate_user`` publishes a version bump to Redis
    * peers reload when the version they saw at load time is stale
    * the local TTL bounds staleness even when Redis is unreachable
    * unchanged version + expired TTL extends the entry without a
      DB read (the common-case fast path)
    """

    def test_invalidate_user_publishes_redis_version_bump(self, pg_conn):
        repo = UserCustomModelsRepository(pg_conn)
        repo.create(
            user_id="user-1",
            upstream_model_id="m1",
            display_name="M1",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="k",
        )

        fake_redis = MagicMock()
        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ), patch(
            "application.cache.get_redis_instance", return_value=fake_redis
        ):
            ModelRegistry().get_model("anything", user_id="user-1")
            ModelRegistry.invalidate_user("user-1")

        fake_redis.incr.assert_called_once_with("byom:registry_version:user-1")

    def test_peer_reloads_when_redis_version_changes(self, pg_conn):
        """Two-process simulation. Peer loads at version=0; another
        process's CRUD bumps the version and updates Postgres; peer's
        next post-TTL access sees the version mismatch and reloads,
        picking up the rotated key it never invalidated locally."""
        from application.core import model_registry as registry_mod

        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-1",
            upstream_model_id="m-orig",
            display_name="orig",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="orig-key",
        )

        state = {"version": 0}

        class _FakeRedis:
            def get(self, key):
                if key == "byom:registry_version:user-1":
                    return str(state["version"]).encode()
                return None

            def incr(self, key):
                if key == "byom:registry_version:user-1":
                    state["version"] += 1

        s = _make_settings()
        # Force TTL to 0 so any subsequent access takes the post-TTL
        # path without waiting 60s.
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ), patch(
            "application.cache.get_redis_instance", return_value=_FakeRedis()
        ), patch.object(
            registry_mod, "_USER_CACHE_TTL_SECONDS", 0.0
        ):
            reg = ModelRegistry()
            assert (
                reg.get_model(created["id"], user_id="user-1").api_key
                == "orig-key"
            )

            # Another process's CRUD: bump Redis counter + mutate the
            # row. Note we deliberately do NOT call ``invalidate_user``
            # in this process — that's the whole point of the test.
            state["version"] += 1
            repo.update(
                created["id"],
                "user-1",
                {"api_key_plaintext": "rotated-key"},
            )

            assert (
                reg.get_model(created["id"], user_id="user-1").api_key
                == "rotated-key"
            )

    def test_ttl_bounds_staleness_when_redis_unavailable(self, pg_conn):
        """Redis down → fall back to TTL-only invalidation. After the
        TTL elapses, peers reload regardless."""
        from application.core import model_registry as registry_mod

        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-1",
            upstream_model_id="m",
            display_name="m",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="orig",
        )

        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ), patch(
            "application.cache.get_redis_instance", return_value=None
        ), patch.object(
            registry_mod, "_USER_CACHE_TTL_SECONDS", 0.0
        ):
            reg = ModelRegistry()
            first = reg.get_model(created["id"], user_id="user-1")
            assert first.api_key == "orig"

            repo.update(
                created["id"],
                "user-1",
                {"api_key_plaintext": "rotated"},
            )

            second = reg.get_model(created["id"], user_id="user-1")
            assert second.api_key == "rotated"

    def test_unchanged_version_extends_ttl_without_db_read(self, pg_conn):
        """Hot path: TTL expires but Redis says no invalidation
        happened — extend the entry without re-reading Postgres."""
        from application.core import model_registry as registry_mod

        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-1",
            upstream_model_id="m",
            display_name="m",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="k",
        )

        fake_redis = MagicMock()
        fake_redis.get.return_value = b"7"  # constant version

        db_open_count = {"n": 0}

        @contextmanager
        def _counting_db_readonly():
            db_open_count["n"] += 1
            yield pg_conn

        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            _counting_db_readonly,
        ), patch(
            "application.cache.get_redis_instance", return_value=fake_redis
        ), patch.object(
            registry_mod, "_USER_CACHE_TTL_SECONDS", 0.0
        ):
            reg = ModelRegistry()
            reg.get_model(created["id"], user_id="user-1")
            first_open = db_open_count["n"]
            # TTL has expired, but Redis returns the same version we
            # captured at load time → no DB reload.
            reg.get_model(created["id"], user_id="user-1")
            reg.get_model(created["id"], user_id="user-1")
            assert db_open_count["n"] == first_open
@@ -1239,6 +1239,131 @@ class TestPerformInMemoryCompression:
        assert messages is not None
        assert agent.compressed_summary == "summary"

    def test_uses_agent_model_user_id_for_byom_owner_scope(self):
        """Shared-agent BYOM dispatch: the agent's model_id is the owner's
        BYOM UUID and ``decoded_token['sub']`` is a different (caller) user.
        Provider lookup and LLMCreator must run under the OWNER scope so
        the registry hits the owner's per-user layer instead of missing
        and falling back to the deployment default."""
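        # Expected scope threading (illustrative):
        #   scope = agent.model_user_id or agent.decoded_token["sub"]  # -> "owner-id"
        #   get_provider_from_model_id(agent.model_id, user_id=scope)
        #   LLMCreator.create_llm(..., model_user_id=scope)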
        handler = ConcreteHandler()
        agent = Mock()
        agent.model_id = "byom-uuid-owner"
        agent.model_user_id = "owner-id"
        agent.user_api_key = None
        agent.decoded_token = {"sub": "caller-id"}
        agent.agent_id = None
        agent.context_limit_reached = False
        agent.current_token_count = 0

        mock_metadata = Mock()
        mock_metadata.compressed_token_count = 100
        mock_metadata.original_token_count = 1000
        mock_metadata.compression_ratio = 10.0
        mock_metadata.to_dict.return_value = {"ratio": 10.0}

        mock_service = Mock()
        mock_service.compress_conversation.return_value = mock_metadata
        mock_service.get_compressed_context.return_value = (
            "summary",
            [{"prompt": "q", "response": "a"}],
        )

        provider_lookup = MagicMock(return_value="openai_compatible")
        create_llm_spy = MagicMock(return_value=Mock())

        with patch.object(
            handler,
            "_build_conversation_from_messages",
            return_value={"queries": [{"prompt": "q", "response": "a"}]},
        ), patch(
            "application.core.model_utils.get_provider_from_model_id",
            provider_lookup,
        ), patch(
            "application.core.model_utils.get_api_key_for_provider",
            return_value="key",
        ), patch(
            "application.llm.llm_creator.LLMCreator.create_llm",
            create_llm_spy,
        ), patch(
            "application.api.answer.services.compression.service.CompressionService",
            return_value=mock_service,
        ), patch.object(
            handler,
            "_rebuild_messages_after_compression",
            return_value=[{"role": "system", "content": "rebuilt"}],
        ), patch(
            "application.core.settings.settings",
            MagicMock(COMPRESSION_MODEL_OVERRIDE=None),
        ):
            success, _ = handler._perform_in_memory_compression(
                agent, [{"role": "user", "content": "hi"}]
            )

        assert success is True
        # Provider lookup must use the owner scope, not the caller's sub.
        assert provider_lookup.call_args.kwargs["user_id"] == "owner-id"
        # LLMCreator must receive the owner scope as model_user_id;
        # without this the BYOM UUID resolves under the caller and the
        # owner's registered base_url + api_key are missed.
        assert create_llm_spy.call_args.kwargs["model_user_id"] == "owner-id"

    def test_falls_back_to_caller_sub_when_no_model_user_id(self):
        """Built-in or caller-owned BYOM: ``agent.model_user_id`` is None.
        Compression should resolve under the caller's sub, matching the
        pre-fix behavior."""
        handler = ConcreteHandler()
        agent = Mock()
        agent.model_id = "gpt-4"
        agent.model_user_id = None
        agent.user_api_key = None
        agent.decoded_token = {"sub": "caller-id"}
        agent.agent_id = None
        agent.context_limit_reached = False
        agent.current_token_count = 0

        mock_metadata = Mock()
        mock_metadata.compressed_token_count = 100
        mock_metadata.original_token_count = 1000
        mock_metadata.to_dict.return_value = {}

        mock_service = Mock()
        mock_service.compress_conversation.return_value = mock_metadata
        mock_service.get_compressed_context.return_value = ("summary", [])

        provider_lookup = MagicMock(return_value="openai")
        create_llm_spy = MagicMock(return_value=Mock())

        with patch.object(
            handler,
            "_build_conversation_from_messages",
            return_value={"queries": [{"prompt": "q", "response": "a"}]},
        ), patch(
            "application.core.model_utils.get_provider_from_model_id",
            provider_lookup,
        ), patch(
            "application.core.model_utils.get_api_key_for_provider",
            return_value="key",
        ), patch(
            "application.llm.llm_creator.LLMCreator.create_llm",
            create_llm_spy,
        ), patch(
            "application.api.answer.services.compression.service.CompressionService",
            return_value=mock_service,
        ), patch.object(
            handler,
            "_rebuild_messages_after_compression",
            return_value=[],
        ), patch(
            "application.core.settings.settings",
            MagicMock(COMPRESSION_MODEL_OVERRIDE=None),
        ):
            handler._perform_in_memory_compression(
                agent, [{"role": "user", "content": "hi"}]
            )

        assert provider_lookup.call_args.kwargs["user_id"] == "caller-id"
        assert create_llm_spy.call_args.kwargs["model_user_id"] == "caller-id"


# ---------------------------------------------------------------------------
# _perform_mid_execution_compression — additional edge cases

@@ -197,7 +197,7 @@ class TestFallbackLLMResolution:
        mock_fallback = StubLLM()
        monkeypatch.setattr(
            "application.core.model_utils.get_provider_from_model_id",
-           lambda mid: "openai",
+           lambda mid, **_kwargs: "openai",
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_api_key_for_provider",
@@ -223,7 +223,7 @@ class TestFallbackLLMResolution:

        monkeypatch.setattr(
            "application.core.model_utils.get_provider_from_model_id",
-           lambda mid: "openai",
+           lambda mid, **_kwargs: "openai",
        )
        monkeypatch.setattr(
            "application.core.model_utils.get_api_key_for_provider",
@@ -262,7 +262,7 @@ class TestFallbackLLMResolution:
    def test_backup_provider_not_found_skipped(self, monkeypatch):
        monkeypatch.setattr(
            "application.core.model_utils.get_provider_from_model_id",
-           lambda mid: None,
+           lambda mid, **_kwargs: None,
        )
        monkeypatch.setattr(
            "application.llm.base.settings",

@@ -11,9 +11,7 @@ import pytest
from application.llm.base import BaseLLM


# ---------------------------------------------------------------------------
# Concrete LLM stubs
# ---------------------------------------------------------------------------


class FakeLLM(BaseLLM):
@@ -59,9 +57,7 @@ class FakeLLM(BaseLLM):
        return super().gen_stream(*args, **kwargs)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _noop_decorator(func):
@@ -121,9 +117,7 @@ def patch_model_utils(monkeypatch):
CALL_ARGS = dict(model="test-model", messages=[{"role": "user", "content": "hi"}])


# ---------------------------------------------------------------------------
# Tests — fallback_llm property resolution
# ---------------------------------------------------------------------------


@pytest.mark.integration
@@ -135,7 +129,7 @@ class TestFallbackLLMResolution:
        backup_llm = FakeLLM(responses=["backup response"])

        patch_model_utils(
-           get_provider=lambda mid: "openai",
+           get_provider=lambda mid, **_kwargs: "openai",
            get_api_key=lambda prov: "fake-key",
            create_llm=lambda type, **kw: backup_llm,
        )
@@ -174,7 +168,7 @@ class TestFallbackLLMResolution:
        good_backup = FakeLLM(responses=["good backup"])
        call_count = {"n": 0}

-       def fake_get_provider(model_id):
+       def fake_get_provider(model_id, **_kwargs):
            call_count["n"] += 1
            if model_id == "bad-model":
                return None  # unresolvable
@@ -202,9 +196,7 @@ class TestFallbackLLMResolution:
        assert primary.fallback_llm is None


# ---------------------------------------------------------------------------
# Tests — non-streaming fallback (gen)
# ---------------------------------------------------------------------------


@pytest.mark.integration
@@ -221,7 +213,7 @@ class TestNonStreamingFallback:
        backup = FakeLLM(responses=["backup ok"])

        patch_model_utils(
-           get_provider=lambda mid: "openai",
+           get_provider=lambda mid, **_kwargs: "openai",
            get_api_key=lambda p: "k",
            create_llm=lambda type, **kw: backup,
        )
@@ -242,9 +234,7 @@ class TestNonStreamingFallback:
        primary.gen(**CALL_ARGS)


# ---------------------------------------------------------------------------
# Tests — streaming fallback (gen_stream)
# ---------------------------------------------------------------------------


@pytest.mark.integration
@@ -261,7 +251,7 @@ class TestStreamingFallback:
        backup = FakeLLM(stream_chunks=["fallback1", "fallback2"])

        patch_model_utils(
-           get_provider=lambda m: "openai",
+           get_provider=lambda m, **_kwargs: "openai",
            get_api_key=lambda p: "k",
            create_llm=lambda type, **kw: backup,
        )
@@ -280,7 +270,7 @@ class TestStreamingFallback:
        backup = FakeLLM(stream_chunks=["recovery1", "recovery2"])

        patch_model_utils(
-           get_provider=lambda m: "openai",
+           get_provider=lambda m, **_kwargs: "openai",
            get_api_key=lambda p: "k",
            create_llm=lambda type, **kw: backup,
        )
@@ -305,9 +295,7 @@ class TestStreamingFallback:
        list(primary.gen_stream(**CALL_ARGS))


# ---------------------------------------------------------------------------
# Tests — backup model priority over global fallback
# ---------------------------------------------------------------------------


@pytest.mark.integration
@@ -323,7 +311,7 @@ class TestBackupModelPriority:
            return backup

        patch_model_utils(
-           get_provider=lambda m: "openai",
+           get_provider=lambda m, **_kwargs: "openai",
            get_api_key=lambda p: "k",
            create_llm=fake_create_llm,
        )
@@ -346,7 +334,7 @@ class TestBackupModelPriority:
            return backup

        patch_model_utils(
-           get_provider=lambda m: "openai",
+           get_provider=lambda m, **_kwargs: "openai",
            get_api_key=lambda p: "k",
            create_llm=fake_create_llm,
        )
@@ -366,7 +354,7 @@ class TestBackupModelPriority:
        global_fallback = FakeLLM(responses=["global ok"])
        call_order = []

-       def fake_get_provider(mid):
+       def fake_get_provider(mid, **_kwargs):
            if mid == "broken-backup":
                return "nonexistent_provider"
            return "openai"
@@ -401,9 +389,7 @@ class TestBackupModelPriority:
        assert call_order == ["broken-backup", "global-model"]


# ---------------------------------------------------------------------------
# Tests — fallback uses its own model_id, not the primary's
# ---------------------------------------------------------------------------


@pytest.mark.integration
@@ -419,7 +405,7 @@ class TestFallbackModelIdOverride:
        )

        patch_model_utils(
-           get_provider=lambda m: "groq",
+           get_provider=lambda m, **_kwargs: "groq",
            get_api_key=lambda p: "k",
            create_llm=lambda type, **kw: backup,
        )
@@ -441,7 +427,7 @@ class TestFallbackModelIdOverride:
        )

        patch_model_utils(
-           get_provider=lambda m: "groq",
+           get_provider=lambda m, **_kwargs: "groq",
            get_api_key=lambda p: "k",
            create_llm=lambda type, **kw: backup,
        )
@@ -465,7 +451,7 @@ class TestFallbackModelIdOverride:
        )

        patch_model_utils(
-           get_provider=lambda m: "groq",
+           get_provider=lambda m, **_kwargs: "groq",
            get_api_key=lambda p: "k",
            create_llm=lambda type, **kw: backup,
        )
@@ -480,3 +466,161 @@ class TestFallbackModelIdOverride:

        assert chunks == ["partial1", "partial2", "recovered"]
        assert backup.last_model_received == "groq-gpt-oss-120b"


# Tests — model_user_id (BYOM owner scope) propagates into fallback resolution


@pytest.mark.integration
class TestFallbackModelUserIdScope:
    """A shared agent dispatched by user B but owned by user A stores
    A's BYOM UUIDs as backup_models. Without the P2 fix the fallback
    property looks up those UUIDs against ``decoded_token['sub']`` (B,
    the caller), which can't see A's per-user layer — backups are
    silently skipped and the global FALLBACK_* settings are used
    instead. These tests pin down that ``model_user_id`` (the owner)
    is used both for the registry lookup and for the recursive
    ``LLMCreator.create_llm`` call."""
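
    # Scope resolution pinned down below (illustrative):
    #   scope = self.model_user_id or self.decoded_token.get("sub")
    #   get_provider_from_model_id(backup_model_id, user_id=scope)
    #   LLMCreator.create_llm(..., model_id=backup_model_id, model_user_id=scope)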
|
||||
|
||||
def test_backup_lookup_uses_model_user_id_not_caller(
|
||||
self, patch_model_utils
|
||||
):
|
||||
captured = {"user_id": None}
|
||||
|
||||
def fake_get_provider(model_id, **kwargs):
|
||||
captured["user_id"] = kwargs.get("user_id")
|
||||
return "openai"
|
||||
|
||||
backup = FakeLLM(responses=["ok"])
|
||||
patch_model_utils(
|
||||
get_provider=fake_get_provider,
|
||||
get_api_key=lambda p: "k",
|
||||
create_llm=lambda type, **kw: backup,
|
||||
)
|
||||
|
||||
primary = FakeLLM(
|
||||
decoded_token={"sub": "caller-bob"},
|
||||
model_user_id="owner-alice",
|
||||
backup_models=["alice-byom-uuid"],
|
||||
)
|
||||
_ = primary.fallback_llm
|
||||
assert captured["user_id"] == "owner-alice"
|
||||
|
||||
def test_backup_create_llm_receives_model_user_id(self, patch_model_utils):
|
||||
backup = FakeLLM(responses=["ok"])
|
||||
captured = {}
|
||||
|
||||
def fake_create_llm(type, **kw):
|
||||
captured["model_user_id"] = kw.get("model_user_id")
|
||||
captured["model_id"] = kw.get("model_id")
|
||||
return backup
|
||||
|
||||
patch_model_utils(
|
||||
get_provider=lambda m, **_kwargs: "openai",
|
||||
get_api_key=lambda p: "k",
|
||||
create_llm=fake_create_llm,
|
||||
)
|
||||
|
||||
primary = FakeLLM(
|
||||
decoded_token={"sub": "caller-bob"},
|
||||
model_user_id="owner-alice",
|
||||
backup_models=["alice-byom-uuid"],
|
||||
)
|
||||
_ = primary.fallback_llm
|
||||
assert captured["model_user_id"] == "owner-alice"
|
||||
assert captured["model_id"] == "alice-byom-uuid"
|
||||
|
||||
def test_global_fallback_create_llm_receives_model_user_id(
|
||||
self, monkeypatch, patch_model_utils
|
||||
):
|
||||
"""The global FALLBACK_LLM_NAME path must also forward
|
||||
``model_user_id`` — operators can configure it to a BYOM UUID
|
||||
that's owned by the same user as the primary model."""
|
||||
backup = FakeLLM(responses=["ok"])
|
||||
captured = {}
|
||||
|
||||
def fake_create_llm(type, **kw):
|
||||
captured["model_user_id"] = kw.get("model_user_id")
|
||||
return backup
|
||||
|
||||
patch_model_utils(create_llm=fake_create_llm)
|
||||
monkeypatch.setattr(
|
||||
"application.llm.base.settings",
|
||||
MagicMock(
|
||||
FALLBACK_LLM_PROVIDER="openai",
|
||||
FALLBACK_LLM_NAME="some-uuid",
|
||||
FALLBACK_LLM_API_KEY="k",
|
||||
API_KEY="k",
|
||||
),
|
||||
)
|
||||
|
||||
primary = FakeLLM(
|
||||
decoded_token={"sub": "caller-bob"},
|
||||
model_user_id="owner-alice",
|
||||
backup_models=[],
|
||||
)
|
||||
_ = primary.fallback_llm
|
||||
assert captured["model_user_id"] == "owner-alice"
|
||||
|
||||
def test_falls_back_to_caller_when_model_user_id_unset(
|
||||
self, patch_model_utils
|
||||
):
|
||||
"""Built-in models / pre-P2 callers don't pass model_user_id.
|
||||
In that case the caller's sub is still used — preserving
|
||||
existing behaviour."""
|
||||
captured = {}
|
||||
|
||||
def fake_get_provider(model_id, **kwargs):
|
||||
captured["user_id"] = kwargs.get("user_id")
|
||||
return "openai"
|
||||
|
||||
patch_model_utils(
|
||||
get_provider=fake_get_provider,
|
||||
get_api_key=lambda p: "k",
|
||||
create_llm=lambda type, **kw: FakeLLM(responses=["ok"]),
|
||||
)
|
||||
|
||||
primary = FakeLLM(
|
||||
decoded_token={"sub": "caller-bob"},
|
||||
model_user_id=None,
|
||||
backup_models=["some-builtin-id"],
|
||||
)
|
||||
_ = primary.fallback_llm
|
||||
assert captured["user_id"] == "caller-bob"
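
# Aside (sketch, not part of the diff): taken together, the four tests
# above pin the resolution order expected of the ``fallback_llm``
# property. Roughly, with the patched helper names standing in for the
# real ones:
def _sketch_resolve_fallback(self):
    # Owner scope when set (shared-agent BYOM), otherwise the caller's
    # sub: the pre-P2 behaviour preserved by the last test.
    scope = self.model_user_id or self.decoded_token.get("sub")
    for backup_id in self.backup_models or []:
        try:
            provider = get_provider(backup_id, user_id=scope)
            return create_llm(provider, model_id=backup_id, model_user_id=scope)
        except Exception:
            continue  # a broken backup is skipped, not fatal
    # Global FALLBACK_* settings are the last resort and still forward
    # the owner scope (third test above).
    return create_llm(
        settings.FALLBACK_LLM_PROVIDER,
        model_id=settings.FALLBACK_LLM_NAME,
        model_user_id=scope,
    )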


# Tests — LLMCreator wires model_user_id through to BaseLLM


@pytest.mark.unit
class TestLLMCreatorPassesModelUserId:
    """End-to-end through ``LLMCreator.create_llm``: the constructed
    LLM must store ``model_user_id`` so its fallback property can
    resolve under the right scope."""

    def test_model_user_id_set_on_constructed_llm(self, monkeypatch):
        from application.llm.llm_creator import LLMCreator
        from application.llm.providers import PROVIDERS_BY_NAME

        captured = {}

        class _CapturingLLM:
            def __init__(self, api_key, user_api_key, *args, **kwargs):
                captured["model_user_id"] = kwargs.get("model_user_id")

        # Pick any registered provider — we only need the constructor
        # call to land in our fake.
        monkeypatch.setattr(
            PROVIDERS_BY_NAME["openai"], "llm_class", _CapturingLLM
        )

        LLMCreator.create_llm(
            type="openai",
            api_key="k",
            user_api_key=None,
            decoded_token={"sub": "caller-bob"},
            model_id=None,
            model_user_id="owner-alice",
        )

        assert captured["model_user_id"] == "owner-alice"
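
# Aside (usage sketch, not part of the diff): whatever scope the API
# layer resolves must survive ``LLMCreator.create_llm`` unchanged.
# "alice-byom-uuid" below is a hypothetical registry id.
def _sketch_creator_usage():
    from application.llm.llm_creator import LLMCreator

    llm = LLMCreator.create_llm(
        type="openai",
        api_key="k",
        user_api_key=None,
        decoded_token={"sub": "caller-bob"},  # the dispatching user
        model_id="alice-byom-uuid",
        model_user_id="owner-alice",  # owner scope for BYOM resolution
    )
    return llm.model_user_id  # expected: "owner-alice"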