Compare commits

...

1 Commits

Author SHA1 Message Date
Alex
e0a8cc178b feat: BYOM 2026-04-27 21:50:45 +01:00
68 changed files with 7645 additions and 282 deletions

View File

@@ -42,6 +42,7 @@ class BaseAgent(ABC):
llm_handler=None,
tool_executor: Optional[ToolExecutor] = None,
backup_models: Optional[List[str]] = None,
model_user_id: Optional[str] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
@@ -52,10 +53,13 @@ class BaseAgent(ABC):
self.prompt = prompt
self.decoded_token = decoded_token or {}
self.user: str = self.decoded_token.get("sub")
# BYOM-resolution scope: owner for shared agents, caller for
# caller-owned BYOM, None for built-ins. Falls back to self.user
# for worker/legacy callers that don't thread model_user_id.
self.model_user_id = model_user_id
self.tools: List[Dict] = []
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
# Dependency injection for LLM — fall back to creating if not provided
if llm is not None:
self.llm = llm
else:
@@ -67,8 +71,16 @@ class BaseAgent(ABC):
model_id=model_id,
agent_id=agent_id,
backup_models=backup_models,
model_user_id=model_user_id,
)
# For BYOM, registry id (UUID) differs from upstream model id
# (e.g. ``mistral-large-latest``). LLMCreator resolved this onto
# the LLM instance; cache it for subsequent gen calls.
self.upstream_model_id = (
getattr(self.llm, "model_id", None) or model_id
)
self.retrieved_docs = retrieved_docs or []
if llm_handler is not None:
@@ -306,7 +318,9 @@ class BaseAgent(ABC):
try:
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(self.model_id)
context_limit = get_token_limit(
self.model_id, user_id=self.model_user_id or self.user
)
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
if current_tokens >= threshold:
@@ -325,7 +339,9 @@ class BaseAgent(ABC):
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(self.model_id)
context_limit = get_token_limit(
self.model_id, user_id=self.model_user_id or self.user
)
percentage = (current_tokens / context_limit) * 100
if current_tokens >= context_limit:
@@ -387,7 +403,9 @@ class BaseAgent(ABC):
)
system_prompt = system_prompt + compression_context
context_limit = get_token_limit(self.model_id)
context_limit = get_token_limit(
self.model_id, user_id=self.model_user_id or self.user
)
system_tokens = num_tokens_from_string(system_prompt)
safety_buffer = int(context_limit * 0.1)
@@ -497,7 +515,10 @@ class BaseAgent(ABC):
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
self._validate_context_size(messages)
gen_kwargs = {"model": self.model_id, "messages": messages}
# Use the upstream id resolved by LLMCreator (see __init__).
# Built-in models: same as self.model_id. BYOM: the user's
# typed model name, not the internal UUID.
gen_kwargs = {"model": self.upstream_model_id, "messages": messages}
if self.attachments:
gen_kwargs["_usage_attachments"] = self.attachments

View File

@@ -312,7 +312,7 @@ class ResearchAgent(BaseAgent):
try:
response = self.llm.gen(
model=self.model_id,
model=self.upstream_model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
@@ -390,7 +390,7 @@ class ResearchAgent(BaseAgent):
try:
response = self.llm.gen(
model=self.model_id,
model=self.upstream_model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
@@ -506,7 +506,7 @@ class ResearchAgent(BaseAgent):
try:
response = self.llm.gen(
model=self.model_id,
model=self.upstream_model_id,
messages=messages,
tools=self.tools if self.tools else None,
)
@@ -537,7 +537,7 @@ class ResearchAgent(BaseAgent):
)
try:
response = self.llm.gen(
model=self.model_id, messages=messages, tools=None
model=self.upstream_model_id, messages=messages, tools=None
)
self._track_tokens(self._snapshot_llm_tokens())
text = self._extract_text(response)
@@ -664,7 +664,7 @@ class ResearchAgent(BaseAgent):
]
llm_response = self.llm.gen_stream(
model=self.model_id, messages=messages, tools=None
model=self.upstream_model_id, messages=messages, tools=None
)
if log_context:

View File

@@ -39,6 +39,7 @@ class InternalSearchTool(Tool):
chunks=int(self.config.get("chunks", 2)),
doc_token_limit=int(self.config.get("doc_token_limit", 50000)),
model_id=self.config.get("model_id", "docsgpt-local"),
model_user_id=self.config.get("model_user_id"),
user_api_key=self.config.get("user_api_key"),
agent_id=self.config.get("agent_id"),
llm_name=self.config.get("llm_name", settings.LLM_PROVIDER),
@@ -435,6 +436,7 @@ def build_internal_tool_config(
chunks: int = 2,
doc_token_limit: int = 50000,
model_id: str = "docsgpt-local",
model_user_id: Optional[str] = None,
user_api_key: Optional[str] = None,
agent_id: Optional[str] = None,
llm_name: str = None,
@@ -449,6 +451,7 @@ def build_internal_tool_config(
"chunks": chunks,
"doc_token_limit": doc_token_limit,
"model_id": model_id,
"model_user_id": model_user_id,
"user_api_key": user_api_key,
"agent_id": agent_id,
"llm_name": llm_name or settings.LLM_PROVIDER,

View File

@@ -211,15 +211,26 @@ class WorkflowEngine:
node_config.json_schema, node.title
)
node_model_id = node_config.model_id or self.agent.model_id
# Inherit BYOM scope from parent agent so owner-stored BYOM
# resolves on shared workflows.
node_user_id = getattr(self.agent, "model_user_id", None) or (
self.agent.decoded_token.get("sub")
if isinstance(self.agent.decoded_token, dict)
else None
)
node_llm_name = (
node_config.llm_name
or get_provider_from_model_id(node_model_id or "")
or get_provider_from_model_id(
node_model_id or "", user_id=node_user_id
)
or self.agent.llm_name
)
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
if node_json_schema and node_model_id:
model_capabilities = get_model_capabilities(node_model_id)
model_capabilities = get_model_capabilities(
node_model_id, user_id=node_user_id
)
if model_capabilities and not model_capabilities.get(
"supports_structured_output", False
):
@@ -232,6 +243,7 @@ class WorkflowEngine:
"endpoint": self.agent.endpoint,
"llm_name": node_llm_name,
"model_id": node_model_id,
"model_user_id": getattr(self.agent, "model_user_id", None),
"api_key": node_api_key,
"tool_ids": node_config.tools,
"prompt": node_config.system_prompt,

View File

@@ -0,0 +1,65 @@
"""0003 user_custom_models — per-user OpenAI-compatible model registrations.
Revision ID: 0003_user_custom_models
Revises: 0002_app_metadata
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0003_user_custom_models"
down_revision: Union[str, None] = "0002_app_metadata"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create ``user_custom_models`` with its index, FK and triggers.

    Order matters: the table must exist before the index, constraint
    and triggers reference it. ``ensure_user_exists()`` and
    ``set_updated_at()`` are assumed to be defined by migration
    0001_initial — TODO confirm against that revision.
    """
    op.execute(
        """
        CREATE TABLE user_custom_models (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            upstream_model_id TEXT NOT NULL,
            display_name TEXT NOT NULL,
            description TEXT NOT NULL DEFAULT '',
            base_url TEXT NOT NULL,
            api_key_encrypted TEXT NOT NULL,
            capabilities JSONB NOT NULL DEFAULT '{}'::jsonb,
            enabled BOOLEAN NOT NULL DEFAULT true,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """
    )
    # Per-user listing is the dominant query pattern for BYOM records.
    op.execute(
        "CREATE INDEX user_custom_models_user_id_idx "
        "ON user_custom_models (user_id);"
    )
    # Mirror the project-wide invariants set up in 0001_initial:
    # * user_id FK with ON DELETE RESTRICT (deferrable),
    # * ensure_user_exists() trigger so the parent users row autocreates,
    # * set_updated_at() trigger.
    op.execute(
        "ALTER TABLE user_custom_models "
        "ADD CONSTRAINT user_custom_models_user_id_fk "
        "FOREIGN KEY (user_id) REFERENCES users(user_id) "
        "ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
    )
    # BEFORE the FK check runs, autocreate the parent users row so a
    # first-time user's insert doesn't fail the constraint.
    op.execute(
        "CREATE TRIGGER user_custom_models_ensure_user "
        "BEFORE INSERT OR UPDATE OF user_id ON user_custom_models "
        "FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
    )
    # WHEN clause skips no-op updates so updated_at only moves on
    # actual row changes.
    op.execute(
        "CREATE TRIGGER user_custom_models_set_updated_at "
        "BEFORE UPDATE ON user_custom_models "
        "FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
        "EXECUTE FUNCTION set_updated_at();"
    )
def downgrade() -> None:
    """Drop the table; its index, FK and triggers are dropped with it."""
    op.execute("DROP TABLE IF EXISTS user_custom_models;")

View File

@@ -177,6 +177,7 @@ class BaseAnswerResource:
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
model_user_id: Optional[str] = None,
_continuation: Optional[Dict] = None,
) -> Generator[str, None, None]:
"""
@@ -289,8 +290,18 @@ class BaseAnswerResource:
# conversation if this is the first turn.
if not conversation_id and should_save_conversation:
try:
# Use model-owner scope so shared-agent
# owner-BYOM resolves to its registered plugin.
provider = (
get_provider_from_model_id(model_id)
get_provider_from_model_id(
model_id,
user_id=model_user_id
or (
decoded_token.get("sub")
if decoded_token
else None
),
)
if model_id
else settings.LLM_PROVIDER
)
@@ -304,6 +315,7 @@ class BaseAnswerResource:
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
conversation_id = (
self.conversation_service.save_conversation(
@@ -340,6 +352,9 @@ class BaseAnswerResource:
tool_schemas=getattr(agent, "tools", []),
agent_config={
"model_id": model_id or self.default_model_id,
# Persist BYOM scope so resume doesn't
# fall back to caller's layer.
"model_user_id": model_user_id,
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
"api_key": getattr(agent, "api_key", None),
"user_api_key": user_api_key,
@@ -370,8 +385,14 @@ class BaseAnswerResource:
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
# Run under model-owner scope so title-gen LLM inside
# save_conversation uses the owner's BYOM provider/key.
provider = (
get_provider_from_model_id(model_id)
get_provider_from_model_id(
model_id,
user_id=model_user_id
or (decoded_token.get("sub") if decoded_token else None),
)
if model_id
else settings.LLM_PROVIDER
)
@@ -384,6 +405,7 @@ class BaseAnswerResource:
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
if should_save_conversation:
@@ -481,12 +503,34 @@ class BaseAnswerResource:
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
# Mirror the normal-path provider resolution so the
# partial-save title LLM uses the model-owner's BYOM
# registration (shared-agent dispatch) rather than
# the deployment default with the instance api key.
provider = (
get_provider_from_model_id(
model_id,
user_id=model_user_id
or (
decoded_token.get("sub")
if decoded_token
else None
),
)
if model_id
else settings.LLM_PROVIDER
)
sys_api_key = get_api_key_for_provider(
provider or settings.LLM_PROVIDER
)
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
provider or settings.LLM_PROVIDER,
api_key=sys_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
self.conversation_service.save_conversation(
conversation_id,

View File

@@ -109,6 +109,7 @@ class StreamResource(Resource, BaseAnswerResource):
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
@@ -145,6 +146,7 @@ class StreamResource(Resource, BaseAnswerResource):
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
),
mimetype="text/event-stream",
)

View File

@@ -49,6 +49,7 @@ class CompressionOrchestrator:
model_id: str,
decoded_token: Dict[str, Any],
current_query_tokens: int = 500,
model_user_id: Optional[str] = None,
) -> CompressionResult:
"""
Check if compression is needed and perform it if so.
@@ -57,16 +58,18 @@ class CompressionOrchestrator:
Args:
conversation_id: Conversation ID
user_id: User ID
user_id: Caller's user id — used for conversation access checks
model_id: Model being used for conversation
decoded_token: User's decoded JWT token
current_query_tokens: Estimated tokens for current query
model_user_id: BYOM-resolution scope (model owner); defaults
to ``user_id`` for built-in / caller-owned models.
Returns:
CompressionResult with summary and recent queries
"""
try:
# Load conversation
# Conversation row is owned by the caller, not the model owner.
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
@@ -77,9 +80,14 @@ class CompressionOrchestrator:
)
return CompressionResult.failure("Conversation not found")
# Check if compression is needed
# Use model-owner scope so per-user BYOM context windows
# (e.g. 8k) compute the threshold against the right limit.
registry_user_id = model_user_id or user_id
if not self.threshold_checker.should_compress(
conversation, model_id, current_query_tokens
conversation,
model_id,
current_query_tokens,
user_id=registry_user_id,
):
# No compression needed, return full history
queries = conversation.get("queries", [])
@@ -87,7 +95,12 @@ class CompressionOrchestrator:
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
conversation_id,
conversation,
model_id,
decoded_token,
user_id=user_id,
model_user_id=model_user_id,
)
except Exception as e:
@@ -102,6 +115,8 @@ class CompressionOrchestrator:
conversation: Dict[str, Any],
model_id: str,
decoded_token: Dict[str, Any],
user_id: Optional[str] = None,
model_user_id: Optional[str] = None,
) -> CompressionResult:
"""
Perform the actual compression operation.
@@ -111,6 +126,8 @@ class CompressionOrchestrator:
conversation: Conversation document
model_id: Model ID for conversation
decoded_token: User token
user_id: Caller's id (for conversation reload after compression)
model_user_id: BYOM-resolution scope (model owner)
Returns:
CompressionResult
@@ -123,11 +140,17 @@ class CompressionOrchestrator:
else model_id
)
# Get provider and API key for compression model
provider = get_provider_from_model_id(compression_model)
# Use model-owner scope so provider/api_key resolves to the
# owner's BYOM record (shared-agent dispatch).
caller_user_id = user_id
if caller_user_id is None and isinstance(decoded_token, dict):
caller_user_id = decoded_token.get("sub")
registry_user_id = model_user_id or caller_user_id
provider = get_provider_from_model_id(
compression_model, user_id=registry_user_id
)
api_key = get_api_key_for_provider(provider)
# Create compression LLM
compression_llm = LLMCreator.create_llm(
provider,
api_key=api_key,
@@ -135,6 +158,7 @@ class CompressionOrchestrator:
decoded_token=decoded_token,
model_id=compression_model,
agent_id=conversation.get("agent_id"),
model_user_id=registry_user_id,
)
# Create compression service with DB update capability
@@ -167,9 +191,12 @@ class CompressionOrchestrator:
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
# Reload conversation with updated metadata
# Reload under caller (conversation is owned by caller).
reload_user_id = caller_user_id
if reload_user_id is None and isinstance(decoded_token, dict):
reload_user_id = decoded_token.get("sub")
conversation = self.conversation_service.get_conversation(
conversation_id, user_id=decoded_token.get("sub")
conversation_id, user_id=reload_user_id
)
# Get compressed context
@@ -192,16 +219,21 @@ class CompressionOrchestrator:
model_id: str,
decoded_token: Dict[str, Any],
current_conversation: Optional[Dict[str, Any]] = None,
model_user_id: Optional[str] = None,
) -> CompressionResult:
"""
Perform compression during tool execution.
Args:
conversation_id: Conversation ID
user_id: User ID
user_id: Caller's user id — used for conversation access checks
model_id: Model ID
decoded_token: User token
current_conversation: Pre-loaded conversation (optional)
model_user_id: BYOM-resolution scope (model owner). For
shared-agent dispatch this is the agent owner; defaults
to ``user_id`` so built-in / caller-owned models are
unaffected.
Returns:
CompressionResult
@@ -223,7 +255,12 @@ class CompressionOrchestrator:
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
conversation_id,
conversation,
model_id,
decoded_token,
user_id=user_id,
model_user_id=model_user_id,
)
except Exception as e:

View File

@@ -106,8 +106,13 @@ class CompressionService:
f"using model {self.model_id}"
)
# See note in conversation_service.py: ``self.model_id`` is
# the registry id (UUID for BYOM); the LLM's own model_id is
# what the provider's API actually expects.
response = self.llm.gen(
model=self.model_id, messages=messages, max_tokens=4000
model=getattr(self.llm, "model_id", None) or self.model_id,
messages=messages,
max_tokens=4000,
)
# Extract summary from response

View File

@@ -30,6 +30,7 @@ class CompressionThresholdChecker:
conversation: Dict[str, Any],
model_id: str,
current_query_tokens: int = 500,
user_id: str | None = None,
) -> bool:
"""
Determine if compression is needed.
@@ -38,6 +39,8 @@ class CompressionThresholdChecker:
conversation: Full conversation document
model_id: Target model for this request
current_query_tokens: Estimated tokens for current query
user_id: Owner — needed so per-user BYOM custom-model UUIDs
resolve when looking up the context window.
Returns:
True if tokens >= threshold% of context window
@@ -48,7 +51,7 @@ class CompressionThresholdChecker:
total_tokens += current_query_tokens
# Get context window limit for model
context_limit = get_token_limit(model_id)
context_limit = get_token_limit(model_id, user_id=user_id)
# Calculate threshold
threshold = int(context_limit * self.threshold_percentage)
@@ -73,20 +76,24 @@ class CompressionThresholdChecker:
logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
return False
def check_message_tokens(self, messages: list, model_id: str) -> bool:
def check_message_tokens(
self, messages: list, model_id: str, user_id: str | None = None
) -> bool:
"""
Check if message list exceeds threshold.
Args:
messages: List of message dicts
model_id: Target model
user_id: Owner — needed so per-user BYOM custom-model UUIDs
resolve when looking up the context window.
Returns:
True if at or above threshold
"""
try:
current_tokens = TokenCounter.count_message_tokens(messages)
context_limit = get_token_limit(model_id)
context_limit = get_token_limit(model_id, user_id=user_id)
threshold = int(context_limit * self.threshold_percentage)
if current_tokens >= threshold:

View File

@@ -136,8 +136,14 @@ class ConversationService:
},
]
# ``model_id`` here is the registry id (a UUID for BYOM
# records). The LLM's own ``model_id`` is the upstream name
# LLMCreator resolved at construction time — that's what
# the provider's API expects. Built-ins are unaffected.
completion = llm.gen(
model=model_id, messages=messages_summary, max_tokens=500
model=getattr(llm, "model_id", None) or model_id,
messages=messages_summary,
max_tokens=500,
)
if not completion or not completion.strip():

View File

@@ -121,6 +121,8 @@ class StreamProcessor:
self.agent_id = self.data.get("agent_id")
self.agent_key = None
self.model_id: Optional[str] = None
# BYOM-resolution scope, set by _validate_and_set_model.
self.model_user_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
self.conversation_service
@@ -191,16 +193,23 @@ class StreamProcessor:
for query in conversation.get("queries", [])
]
else:
# model_user_id keeps history trim aligned with the BYOM's
# actual context window instead of the default 128k.
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")), model_id=self.model_id
json.loads(self.data.get("history", "[]")),
model_id=self.model_id,
user_id=self.model_user_id,
)
def _handle_compression(self, conversation: Dict[str, Any]):
"""Handle conversation compression logic using orchestrator."""
try:
# initial_user_id for conversation access; model_user_id
# for BYOM context-window / provider lookups.
result = self.compression_orchestrator.compress_if_needed(
conversation_id=self.conversation_id,
user_id=self.initial_user_id,
model_user_id=self.model_user_id,
model_id=self.model_id,
decoded_token=self.decoded_token,
)
@@ -284,11 +293,18 @@ class StreamProcessor:
from application.core.model_settings import ModelRegistry
requested_model = self.data.get("model_id")
# Caller picks from their own BYOM layer; agent defaults resolve
# under the owner's layer (shared agents have caller != owner).
caller_user_id = self.initial_user_id
owner_user_id = self.agent_config.get("user_id") or caller_user_id
if requested_model:
if not validate_model_id(requested_model):
if not validate_model_id(requested_model, user_id=caller_user_id):
registry = ModelRegistry.get_instance()
available_models = [m.id for m in registry.get_enabled_models()]
available_models = [
m.id
for m in registry.get_enabled_models(user_id=caller_user_id)
]
raise ValueError(
f"Invalid model_id '{requested_model}'. "
f"Available models: {', '.join(available_models[:5])}"
@@ -299,12 +315,17 @@ class StreamProcessor:
)
)
self.model_id = requested_model
self.model_user_id = caller_user_id
else:
agent_default_model = self.agent_config.get("default_model_id", "")
if agent_default_model and validate_model_id(agent_default_model):
if agent_default_model and validate_model_id(
agent_default_model, user_id=owner_user_id
):
self.model_id = agent_default_model
self.model_user_id = owner_user_id
else:
self.model_id = get_default_model_id()
self.model_user_id = None
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control."""
@@ -514,6 +535,10 @@ class StreamProcessor:
"allow_system_prompt_override": self._agent_data.get(
"allow_system_prompt_override", False
),
# Owner identity — _validate_and_set_model reads this to
# resolve owner-stored BYOM default_model_id against the
# owner's per-user model layer rather than the caller's.
"user_id": self._agent_data.get("user"),
}
)
@@ -561,7 +586,13 @@ class StreamProcessor:
def _configure_retriever(self):
"""Assemble retriever config with precedence: request > agent > default."""
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
# BYOM scope: owner for shared-agent BYOM, caller for own BYOM,
# None for built-ins. Without ``user_id`` here, the doc budget
# falls back to settings.DEFAULT_LLM_TOKEN_LIMIT and overfills
# the upstream context window for any small (e.g. 8k/32k) BYOM.
doc_token_limit = calculate_doc_token_budget(
model_id=self.model_id, user_id=self.model_user_id
)
# Start with defaults
retriever_name = "classic"
@@ -612,6 +643,7 @@ class StreamProcessor:
chunks=self.retriever_config["chunks"],
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
model_id=self.model_id,
model_user_id=self.model_user_id,
user_api_key=self.agent_config["user_api_key"],
agent_id=self.agent_id,
decoded_token=self.decoded_token,
@@ -903,6 +935,11 @@ class StreamProcessor:
agent_config = state["agent_config"]
model_id = agent_config.get("model_id")
# BYOM scope captured at initial dispatch. None for built-ins or
# caller-owned BYOM where decoded_token['sub'] is already the
# right scope; non-None for shared-agent owner BYOM where the
# caller's identity differs from the model owner's.
model_user_id = agent_config.get("model_user_id")
llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
api_key = agent_config.get("api_key")
user_api_key = agent_config.get("user_api_key")
@@ -920,6 +957,7 @@ class StreamProcessor:
decoded_token=self.decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
tool_executor = ToolExecutor(
@@ -949,6 +987,7 @@ class StreamProcessor:
"endpoint": "stream",
"llm_name": llm_name,
"model_id": model_id,
"model_user_id": model_user_id,
"api_key": system_api_key,
"agent_id": agent_id,
"user_api_key": user_api_key,
@@ -971,6 +1010,15 @@ class StreamProcessor:
# Store config for the route layer
self.model_id = model_id
# Mirror ``model_user_id`` back onto the processor so the route
# layer (StreamResource) reads the owner scope captured at
# initial dispatch. Without this, ``processor.model_user_id``
# stays at the __init__ default (None) and complete_stream
# falls back to the caller's sub: the post-resume title-LLM
# save misses the owner's BYOM layer, and any second tool
# pause persists ``model_user_id=None`` — losing owner scope
# for every subsequent resume of this conversation.
self.model_user_id = model_user_id
self.agent_id = agent_id
self.agent_config["user_api_key"] = user_api_key
self.conversation_id = conversation_id
@@ -1022,8 +1070,11 @@ class StreamProcessor:
tools_data=tools_data,
)
# Use the user_id that resolved the model so owner-scoped BYOM
# records dispatch correctly on shared-agent requests.
model_user_id = getattr(self, "model_user_id", self.initial_user_id)
provider = (
get_provider_from_model_id(self.model_id)
get_provider_from_model_id(self.model_id, user_id=model_user_id)
if self.model_id
else settings.LLM_PROVIDER
)
@@ -1048,6 +1099,8 @@ class StreamProcessor:
model_id=self.model_id,
agent_id=self.agent_id,
backup_models=backup_models,
# Owner-scope on shared-agent BYOM dispatch.
model_user_id=model_user_id,
)
llm_handler = LLMHandlerCreator.create_handler(
provider if provider else "default"
@@ -1070,6 +1123,7 @@ class StreamProcessor:
"endpoint": "stream",
"llm_name": provider or settings.LLM_PROVIDER,
"model_id": self.model_id,
"model_user_id": self.model_user_id,
"api_key": system_api_key,
"agent_id": self.agent_id,
"user_api_key": self.agent_config["user_api_key"],
@@ -1097,6 +1151,7 @@ class StreamProcessor:
"doc_token_limit", 50000
),
"model_id": self.model_id,
"model_user_id": self.model_user_id,
"user_api_key": self.agent_config["user_api_key"],
"agent_id": self.agent_id,
"llm_name": provider or settings.LLM_PROVIDER,

View File

@@ -1,18 +1,135 @@
from flask import current_app, jsonify, make_response
"""Model routes.
- ``GET /api/models`` — list available models for the current user.
Combines the built-in catalog with the user's BYOM records.
- ``GET/POST/PATCH/DELETE /api/user/models[/<id>]`` — CRUD for the
user's own OpenAI-compatible model registrations (BYOM).
- ``POST /api/user/models/<id>/test`` — sanity-check the upstream
endpoint with a tiny request.
Every BYOM endpoint is user-scoped at the repository layer
(every query filters on ``user_id`` from ``request.decoded_token``).
"""
from __future__ import annotations
import logging
import requests
from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource
from application.core.model_settings import ModelRegistry
from application.api import api
from application.core.model_registry import ModelRegistry
from application.security.safe_url import (
UnsafeUserUrlError,
pinned_post,
validate_user_base_url,
)
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
models_ns = Namespace("models", description="Available models", path="/api")
_CONTEXT_WINDOW_MIN = 1_000
_CONTEXT_WINDOW_MAX = 10_000_000
def _user_id_or_401():
    """Resolve the caller's user id from the decoded request token.

    Returns ``(user_id, None)`` when the request carries a usable
    identity, otherwise ``(None, response)`` where ``response`` is a
    ready-made 401 for the route handler to return directly.
    """
    token = request.decoded_token
    user_id = token.get("sub") if token else None
    if user_id:
        return user_id, None
    return None, make_response(jsonify({"success": False}), 401)
def _normalize_capabilities(raw) -> dict:
"""Coerce + bound the user-supplied capabilities payload."""
raw = raw or {}
out = {}
if "supports_tools" in raw:
out["supports_tools"] = bool(raw["supports_tools"])
if "supports_structured_output" in raw:
out["supports_structured_output"] = bool(raw["supports_structured_output"])
if "supports_streaming" in raw:
out["supports_streaming"] = bool(raw["supports_streaming"])
if "attachments" in raw:
atts = raw["attachments"] or []
if not isinstance(atts, list):
raise ValueError("'capabilities.attachments' must be a list")
coerced = [str(a) for a in atts]
# Reject unknown aliases at the API boundary so bad payloads
# never reach the registry layer (where lenient expansion just
# drops them). Raw MIME types (containing ``/``) pass through
# unchanged for parity with the built-in YAML schema.
from application.core.model_yaml import builtin_attachment_aliases
aliases = builtin_attachment_aliases()
for entry in coerced:
if "/" in entry:
continue
if entry not in aliases:
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
raise ValueError(
f"unknown attachment alias '{entry}' in "
f"'capabilities.attachments'. Valid aliases: {valid}, "
f"or use a raw MIME type like 'image/png'."
)
out["attachments"] = coerced
if "context_window" in raw:
try:
cw = int(raw["context_window"])
except (TypeError, ValueError):
raise ValueError("'capabilities.context_window' must be an integer")
if not (_CONTEXT_WINDOW_MIN <= cw <= _CONTEXT_WINDOW_MAX):
raise ValueError(
f"'capabilities.context_window' must be between "
f"{_CONTEXT_WINDOW_MIN} and {_CONTEXT_WINDOW_MAX}"
)
out["context_window"] = cw
return out
def _row_to_response(row: dict) -> dict:
"""Wire-format projection — never includes the API key."""
return {
"id": str(row["id"]),
"upstream_model_id": row["upstream_model_id"],
"display_name": row["display_name"],
"description": row.get("description") or "",
"base_url": row["base_url"],
"capabilities": row.get("capabilities") or {},
"enabled": bool(row.get("enabled", True)),
"source": "user",
}
@models_ns.route("/models")
class ModelsListResource(Resource):
def get(self):
"""Get list of available models with their capabilities."""
"""Get list of available models with their capabilities.
When the request is authenticated, the response includes the
user's own BYOM registrations alongside the built-in catalog.
"""
try:
user_id = None
decoded_token = getattr(request, "decoded_token", None)
if decoded_token:
user_id = decoded_token.get("sub")
registry = ModelRegistry.get_instance()
models = registry.get_enabled_models()
models = registry.get_enabled_models(user_id=user_id)
response = {
"models": [model.to_dict() for model in models],
@@ -23,3 +140,382 @@ class ModelsListResource(Resource):
current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
return make_response(jsonify(response), 200)
@models_ns.route("/user/models")
class UserModelsCollectionResource(Resource):
    """Collection endpoint for the caller's own BYOM registrations.

    Both verbs are user-scoped: every repository call carries the
    ``user_id`` extracted from the request's decoded token.
    """

    @api.doc(description="List the current user's BYOM custom models")
    def get(self):
        # 401 before touching the DB if the request has no identity.
        user_id, err = _user_id_or_401()
        if err:
            return err
        try:
            with db_readonly() as conn:
                rows = UserCustomModelsRepository(conn).list_for_user(user_id)
            # _row_to_response never includes the stored API key.
            return make_response(
                jsonify({"models": [_row_to_response(r) for r in rows]}), 200
            )
        except Exception as e:
            current_app.logger.error(
                f"Error listing user custom models: {e}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 500)

    @api.doc(description="Register a new BYOM custom model")
    def post(self):
        user_id, err = _user_id_or_401()
        if err:
            return err
        data = request.get_json() or {}
        # Presence check first; blank-string check below is separate
        # because check_required_fields may accept empty strings —
        # TODO confirm its semantics.
        missing = check_required_fields(
            data,
            ["upstream_model_id", "display_name", "base_url", "api_key"],
        )
        if missing:
            return missing
        # SECURITY: reject blank api_key — would leak instance API key
        # to the user-supplied base_url via LLMCreator fallback.
        for required_nonblank in (
            "upstream_model_id",
            "display_name",
            "base_url",
            "api_key",
        ):
            value = data.get(required_nonblank)
            if not isinstance(value, str) or not value.strip():
                return make_response(
                    jsonify(
                        {
                            "success": False,
                            "error": f"'{required_nonblank}' must be a non-empty string",
                        }
                    ),
                    400,
                )
        # SSRF guard at create time. Re-runs at dispatch time (LLMCreator)
        # as defense in depth against DNS rebinding and pre-guard rows.
        try:
            validate_user_base_url(data["base_url"])
        except UnsafeUserUrlError as e:
            return make_response(
                jsonify({"success": False, "error": str(e)}), 400
            )
        # Normalisation raises ValueError on malformed entries → 400.
        try:
            capabilities = _normalize_capabilities(data.get("capabilities"))
        except ValueError as e:
            return make_response(
                jsonify({"success": False, "error": str(e)}), 400
            )
        try:
            with db_session() as conn:
                row = UserCustomModelsRepository(conn).create(
                    user_id=user_id,
                    upstream_model_id=data["upstream_model_id"],
                    display_name=data["display_name"],
                    description=data.get("description") or "",
                    base_url=data["base_url"],
                    api_key_plaintext=data["api_key"],
                    capabilities=capabilities,
                    enabled=bool(data.get("enabled", True)),
                )
        except Exception as e:
            current_app.logger.error(
                f"Error creating user custom model: {e}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 500)
        # Drop the cached per-user model layer so the new record shows
        # up in the very next GET /api/models.
        ModelRegistry.invalidate_user(user_id)
        return make_response(jsonify(_row_to_response(row)), 201)
@models_ns.route("/user/models/<string:model_id>")
class UserModelResource(Resource):
@api.doc(description="Get one BYOM custom model")
def get(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
try:
with db_readonly() as conn:
row = UserCustomModelsRepository(conn).get(model_id, user_id)
except Exception as e:
current_app.logger.error(
f"Error fetching user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
if row is None:
return make_response(jsonify({"success": False}), 404)
return make_response(jsonify(_row_to_response(row)), 200)
@api.doc(description="Update a BYOM custom model (partial)")
def patch(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
# Reject present-but-blank values for fields where blank doesn't
# mean "no change". (The api_key special case — blank means "keep
# existing" — is handled below.)
for required_nonblank in (
"upstream_model_id",
"display_name",
"base_url",
):
if required_nonblank in data:
value = data[required_nonblank]
if not isinstance(value, str) or not value.strip():
return make_response(
jsonify(
{
"success": False,
"error": f"'{required_nonblank}' cannot be blank",
}
),
400,
)
if "base_url" in data and data["base_url"]:
try:
validate_user_base_url(data["base_url"])
except UnsafeUserUrlError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
update_fields: dict = {}
for k in (
"upstream_model_id",
"display_name",
"description",
"base_url",
"enabled",
):
if k in data:
update_fields[k] = data[k]
if "capabilities" in data:
try:
update_fields["capabilities"] = _normalize_capabilities(
data["capabilities"]
)
except ValueError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
# PATCH semantics: blank/missing api_key → keep the existing
# ciphertext; non-empty api_key → re-encrypt and replace.
if data.get("api_key"):
update_fields["api_key_plaintext"] = data["api_key"]
if not update_fields:
return make_response(
jsonify({"success": False, "error": "no updatable fields"}), 400
)
try:
with db_session() as conn:
ok = UserCustomModelsRepository(conn).update(
model_id, user_id, update_fields
)
except Exception as e:
current_app.logger.error(
f"Error updating user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
if not ok:
return make_response(jsonify({"success": False}), 404)
ModelRegistry.invalidate_user(user_id)
with db_readonly() as conn:
row = UserCustomModelsRepository(conn).get(model_id, user_id)
return make_response(jsonify(_row_to_response(row)), 200)
@api.doc(description="Delete a BYOM custom model")
def delete(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
try:
with db_session() as conn:
ok = UserCustomModelsRepository(conn).delete(model_id, user_id)
except Exception as e:
current_app.logger.error(
f"Error deleting user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
if not ok:
return make_response(jsonify({"success": False}), 404)
ModelRegistry.invalidate_user(user_id)
return make_response(jsonify({"success": True}), 200)
def _run_connection_test(
    base_url: str, api_key: str, upstream_model_id: str
):
    """Probe a BYOM endpoint with a minimal 1-token chat completion.

    Returns ``(body, http_status)``. Upstream failures are reported as
    HTTP 200 with ``ok=False`` so the UI can render them inline; only a
    local SSRF rejection produces a 400.
    """
    endpoint = f"{base_url.rstrip('/')}/chat/completions"
    request_body = {
        "model": upstream_model_id,
        "messages": [{"role": "user", "content": "hi"}],
        "max_tokens": 1,
        "stream": False,
    }
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    try:
        # pinned_post closes the DNS-rebinding window. Redirects are
        # disabled because a 3xx could bounce to an internal address
        # (the SSRF guard only validates the supplied URL).
        resp = pinned_post(
            endpoint,
            json=request_body,
            headers=request_headers,
            timeout=5,
            allow_redirects=False,
        )
    except UnsafeUserUrlError as e:
        return {"ok": False, "error": str(e)}, 400
    except requests.RequestException as e:
        return {"ok": False, "error": f"connection error: {e}"}, 200
    status = resp.status_code
    if 300 <= status < 400:
        return (
            {
                "ok": False,
                "error": (
                    f"upstream returned HTTP {status} "
                    "redirect; refusing to follow"
                ),
            },
            200,
        )
    if status < 400:
        return {"ok": True}, 200
    # Cap and only reflect JSON to avoid body-exfil via non-API responses.
    content_type = (resp.headers.get("Content-Type") or "").lower()
    if "application/json" in content_type:
        snippet = (resp.text or "")[:500]
        return {
            "ok": False,
            "error": f"upstream returned HTTP {status}: {snippet}",
        }, 200
    return {"ok": False, "error": f"upstream returned HTTP {status}"}, 200
@models_ns.route("/user/models/test")
class UserModelTestPayloadResource(Resource):
@api.doc(
description=(
"Test an arbitrary BYOM payload (display_name / model id / "
"base_url / api_key) without saving. Used by the UI's 'Test "
"connection' button so the user can validate before they "
"Save. Same SSRF guard, same 1-token request, same 5s "
"timeout as the by-id variant."
)
)
def post(self):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
missing = check_required_fields(
data, ["base_url", "api_key", "upstream_model_id"]
)
if missing:
return missing
body, status = _run_connection_test(
data["base_url"], data["api_key"], data["upstream_model_id"]
)
return make_response(jsonify(body), status)
@models_ns.route("/user/models/<string:model_id>/test")
class UserModelTestResource(Resource):
@api.doc(
description=(
"Test a saved BYOM record. Defaults to the stored "
"base_url / upstream_model_id / encrypted api_key, but "
"any of those can be overridden via the request body so "
"the UI can test in-flight edits before saving. Used by "
"the 'Test connection' button in edit mode."
)
)
def post(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
# Per-field overrides; blank/missing falls back to stored value.
override_base_url = (data.get("base_url") or "").strip() or None
override_upstream_model_id = (
data.get("upstream_model_id") or ""
).strip() or None
override_api_key = (data.get("api_key") or "").strip() or None
try:
with db_readonly() as conn:
repo = UserCustomModelsRepository(conn)
row = repo.get(model_id, user_id)
if row is None:
return make_response(jsonify({"success": False}), 404)
stored_api_key = (
repo._decrypt_api_key(
row.get("api_key_encrypted", ""), user_id
)
if not override_api_key
else None
)
except Exception as e:
current_app.logger.error(
f"Error loading user custom model for test: {e}", exc_info=True
)
return make_response(
jsonify({"ok": False, "error": "internal error loading model"}),
500,
)
api_key = override_api_key or stored_api_key
if not api_key:
return make_response(
jsonify(
{
"ok": False,
"error": (
"Stored API key could not be decrypted. The "
"encryption secret may have rotated. Re-save "
"the model with the API key to recover."
),
}
),
400,
)
base_url = override_base_url or row["base_url"]
upstream_model_id = (
override_upstream_model_id or row["upstream_model_id"]
)
body, status = _run_connection_test(
base_url, api_key, upstream_model_id
)
return make_response(jsonify(body), status)

View File

@@ -198,8 +198,14 @@ def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
return normalized_nodes
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
"""Validate workflow graph structure."""
def validate_workflow_structure(
nodes: List[Dict], edges: List[Dict], user_id: str | None = None
) -> List[str]:
"""Validate workflow graph structure.
``user_id`` is required so per-user BYOM custom-model UUIDs resolve
when checking each agent node's structured-output capability.
"""
errors = []
if not nodes:
@@ -343,7 +349,7 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
model_id = raw_config.get("model_id")
if has_json_schema and isinstance(model_id, str) and model_id.strip():
capabilities = get_model_capabilities(model_id.strip())
capabilities = get_model_capabilities(model_id.strip(), user_id=user_id)
if capabilities and not capabilities.get("supports_structured_output", False):
errors.append(
f"Agent node '{agent_title}' selected model does not support structured output"
@@ -389,7 +395,9 @@ class WorkflowList(Resource):
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
validation_errors = validate_workflow_structure(
nodes_data, edges_data, user_id=user_id
)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
@@ -451,7 +459,9 @@ class WorkflowDetail(Resource):
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
validation_errors = validate_workflow_structure(
nodes_data, edges_data, user_id=user_id
)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors

View File

@@ -213,6 +213,7 @@ def _stream_response(
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)
@@ -257,6 +258,7 @@ def _non_stream_response(
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)

View File

@@ -153,7 +153,7 @@ def after_request(response: Response) -> Response:
"""Add CORS headers for the pure Flask development entrypoint."""
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
return response

View File

@@ -24,7 +24,7 @@ asgi_app = Starlette(
Middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "Mcp-Session-Id"],
expose_headers=["Mcp-Session-Id"],
),

View File

@@ -4,17 +4,28 @@ Loads model catalogs from YAML files (built-in + operator-supplied),
groups them by provider name, then for each registered provider plugin
calls ``get_models`` to produce the final per-provider model list.
The ``user_id`` parameter on lookup methods is reserved for the future
end-user BYOM (per-user model records in Postgres). It is currently
ignored — defaulted to ``None`` everywhere — so call sites can be
threaded through without a wide refactor when BYOM lands.
End-user BYOM (per-user model records in Postgres) is layered on top:
when a lookup arrives with a ``user_id``, the registry consults a
per-user cache first (loaded from the ``user_custom_models`` table on
miss) and falls through to the built-in catalog.
Cross-process invalidation: ``ModelRegistry`` is a per-process
singleton, so a CRUD write only evicts the cache in the process that
served it. Other gunicorn workers and Celery workers would otherwise
keep using a deleted/disabled/key-rotated BYOM record indefinitely.
``invalidate_user`` therefore both drops the local layer *and* bumps a
Redis-side version counter; other processes notice the bump on their
next access (after the local TTL window) and reload from Postgres. If
Redis is unreachable the per-process TTL still bounds staleness — pure
TTL semantics, no regression.
"""
from __future__ import annotations
import logging
import time
from collections import defaultdict
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
from application.core.model_settings import AvailableModel
from application.core.model_yaml import (
@@ -25,6 +36,9 @@ from application.core.model_yaml import (
logger = logging.getLogger(__name__)
_USER_CACHE_TTL_SECONDS = 60.0
_USER_VERSION_KEY_PREFIX = "byom:registry_version:"
class ModelRegistry:
"""Singleton registry of available models."""
@@ -41,6 +55,18 @@ class ModelRegistry:
if not ModelRegistry._initialized:
self.models: Dict[str, AvailableModel] = {}
self.default_model_id: Optional[str] = None
# Per-user BYOM cache. Each entry is
# ``(layer, version_at_load, loaded_at_monotonic)``:
# * ``layer`` — {model_id: AvailableModel}
# * ``version_at_load`` — Redis-side counter snapshot at
# reload time, or ``None`` if Redis was unreachable
# * ``loaded_at_monotonic`` — for TTL bookkeeping
# Populated lazily, evicted by TTL + cross-process
# invalidation (see ``invalidate_user``).
self._user_models: Dict[
str,
Tuple[Dict[str, AvailableModel], Optional[int], float],
] = {}
self._load_models()
ModelRegistry._initialized = True
@@ -54,6 +80,59 @@ class ModelRegistry:
cls._instance = None
cls._initialized = False
@classmethod
def invalidate_user(cls, user_id: str) -> None:
"""Drop the cached per-user model layer for ``user_id``.
Called by the BYOM REST routes after every create/update/delete.
Two effects:
* Local: pop the entry from this process's cache so the next
lookup re-reads from Postgres immediately.
* Cross-process: ``INCR`` a Redis-side version counter for this
user. Other gunicorn/Celery processes notice the counter
changed on their next TTL-driven recheck (see
``_user_models_for``) and reload. If Redis is unreachable we
log and continue — local invalidation still happened, and
peers fall back to TTL-only staleness bounds.
"""
if cls._instance is not None:
cls._instance._user_models.pop(user_id, None)
try:
from application.cache import get_redis_instance
client = get_redis_instance()
if client is not None:
client.incr(_USER_VERSION_KEY_PREFIX + user_id)
except Exception as e:
logger.warning(
"BYOM invalidate: failed to publish version bump for "
"user %s (Redis unreachable?): %s",
user_id,
e,
)
@classmethod
def _read_user_version(cls, user_id: str) -> Optional[int]:
"""Return the Redis-side invalidation counter for ``user_id``.
``0`` if the key has never been bumped; ``None`` if Redis is
unreachable or the read failed (callers fall back to TTL-only
staleness in that case).
"""
try:
from application.cache import get_redis_instance
client = get_redis_instance()
if client is None:
return None
raw = client.get(_USER_VERSION_KEY_PREFIX + user_id)
if raw is None:
return 0
return int(raw)
except Exception:
return None
def _load_models(self) -> None:
from pathlib import Path
@@ -137,28 +216,170 @@ class ModelRegistry:
return []
return [name.strip() for name in llm_name.split(",") if name.strip()]
# ------------------------------------------------------------------
# Lookup API. ``user_id`` is reserved for the future BYOM and
# is ignored today — but threading it through every call site now
# means BYOM doesn't require a wide refactor when we build it.
# ------------------------------------------------------------------
# Per-user (BYOM) layer
def _user_models_for(self, user_id: str) -> Dict[str, AvailableModel]:
"""Return the user's BYOM models keyed by registry id (UUID).
Loaded lazily from Postgres on first access; cached subject to
a per-process TTL (``_USER_CACHE_TTL_SECONDS``) and a Redis-
backed version counter for cross-process invalidation. The TTL
bounds staleness even when Redis is unreachable, while the
version stamp lets peers refresh without a DB read on the
common case (no invalidation since last load). Decryption
failures and DB errors yield an empty layer (logged) — the
user simply doesn't see their custom models on this request,
never a 500.
"""
cached = self._user_models.get(user_id)
now = time.monotonic()
if cached is not None:
layer, cached_version, loaded_at = cached
if (now - loaded_at) < _USER_CACHE_TTL_SECONDS:
return layer
# TTL elapsed: peek at the cross-process counter. If it
# matches what we saw at load time, no invalidation has
# happened — extend the TTL without touching Postgres. If
# Redis is unreachable (``current_version is None``) we
# fall through to a real reload, which keeps staleness
# bounded to the TTL.
current_version = self._read_user_version(user_id)
if (
current_version is not None
and cached_version is not None
and current_version == cached_version
):
self._user_models[user_id] = (layer, cached_version, now)
return layer
# Capture the counter *before* the DB read so a CRUD that lands
# mid-reload doesn't get masked: the next access will see a
# newer version and reload again.
version_before_read = self._read_user_version(user_id)
layer: Dict[str, AvailableModel] = {}
try:
from application.core.model_settings import (
ModelCapabilities,
ModelProvider,
)
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
from application.storage.db.session import db_readonly
with db_readonly() as conn:
repo = UserCustomModelsRepository(conn)
rows = repo.list_for_user(user_id)
for row in rows:
api_key = repo._decrypt_api_key(
row.get("api_key_encrypted", ""), user_id
)
if not api_key:
# SECURITY: do NOT register an unroutable BYOM
# record. If we did, LLMCreator would fall back
# to the caller-passed api_key (settings.API_KEY
# for openai_compatible) and POST it to the
# user-supplied base_url — leaking the instance
# credential to the user's chosen endpoint.
# Most likely cause is ENCRYPTION_SECRET_KEY
# having rotated; user must re-save the model.
logger.warning(
"user_custom_models: skipping model %s for "
"user %s — api_key could not be decrypted "
"(rotated ENCRYPTION_SECRET_KEY?). Re-save "
"the model to recover.",
row.get("id"),
user_id,
)
continue
caps_raw = row.get("capabilities") or {}
# Stored attachments may be aliases (``image``) or
# raw MIME types. Built-in YAML models expand at
# load time; mirror that here so downstream MIME-
# type comparisons (handlers/base.prepare_messages)
# match concrete types like ``image/png`` rather
# than the bare alias.
from application.core.model_yaml import (
expand_attachments_lenient,
)
raw_attachments = caps_raw.get("attachments", []) or []
expanded_attachments = expand_attachments_lenient(
raw_attachments,
f"user_custom_models[user={user_id}, model={row.get('id')}]",
)
caps = ModelCapabilities(
supports_tools=bool(caps_raw.get("supports_tools", False)),
supports_structured_output=bool(
caps_raw.get("supports_structured_output", False)
),
supports_streaming=bool(
caps_raw.get("supports_streaming", True)
),
supported_attachment_types=expanded_attachments,
context_window=int(
caps_raw.get("context_window") or 128000
),
)
model_id = str(row["id"])
layer[model_id] = AvailableModel(
id=model_id,
provider=ModelProvider.OPENAI_COMPATIBLE,
display_name=row["display_name"],
description=row.get("description") or "",
capabilities=caps,
enabled=bool(row.get("enabled", True)),
base_url=row["base_url"],
upstream_model_id=row["upstream_model_id"],
source="user",
api_key=api_key,
)
except Exception as e:
logger.warning(
"user_custom_models: failed to load layer for user %s: %s",
user_id,
e,
)
layer = {}
self._user_models[user_id] = (layer, version_before_read, now)
return layer
# Lookup API. ``user_id`` enables the BYOM per-user layer; without
# it, callers see only the built-in + operator catalog.
def get_model(
self, model_id: str, user_id: Optional[str] = None
) -> Optional[AvailableModel]:
if user_id:
user_layer = self._user_models_for(user_id)
if model_id in user_layer:
return user_layer[model_id]
return self.models.get(model_id)
def get_all_models(
self, user_id: Optional[str] = None
) -> List[AvailableModel]:
return list(self.models.values())
out = list(self.models.values())
if user_id:
out.extend(self._user_models_for(user_id).values())
return out
def get_enabled_models(
self, user_id: Optional[str] = None
) -> List[AvailableModel]:
return [m for m in self.models.values() if m.enabled]
out = [m for m in self.models.values() if m.enabled]
if user_id:
out.extend(
m for m in self._user_models_for(user_id).values() if m.enabled
)
return out
def model_exists(
self, model_id: str, user_id: Optional[str] = None
) -> bool:
if user_id and model_id in self._user_models_for(user_id):
return True
return model_id in self.models

View File

@@ -48,14 +48,15 @@ class AvailableModel:
capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
enabled: bool = True
base_url: Optional[str] = None
# User-facing label distinct from the dispatch ``provider``. Used by
# openai_compatible YAMLs so a Mistral model shows "mistral" in the
# API response while still routing through the OpenAI wire format.
# User-facing label distinct from dispatch provider (e.g. mistral
# routed through openai_compatible).
display_provider: Optional[str] = None
# Per-record API key. Operator YAMLs leave this None; populated for
# openai_compatible models (resolved from the YAML's ``api_key_env``)
# and reserved for the future end-user BYOM phase. Never serialized
# into to_dict().
# Sent in the API call's ``model`` field; falls back to ``self.id``
# for built-ins where id IS the upstream name.
upstream_model_id: Optional[str] = None
# "builtin" for catalog YAMLs, "user" for BYOM records.
source: str = "builtin"
# Decrypted/resolved at registry-merge time. Never serialized.
api_key: Optional[str] = field(default=None, repr=False, compare=False)
def to_dict(self) -> Dict:
@@ -70,6 +71,7 @@ class AvailableModel:
"supports_streaming": self.capabilities.supports_streaming,
"context_window": self.capabilities.context_window,
"enabled": self.enabled,
"source": self.source,
}
if self.base_url:
result["base_url"] = self.base_url

View File

@@ -20,22 +20,40 @@ def get_api_key_for_provider(provider: str) -> Optional[str]:
return settings.API_KEY
def get_all_available_models() -> Dict[str, Dict[str, Any]]:
"""Get all available models with metadata for API response"""
def get_all_available_models(
user_id: Optional[str] = None,
) -> Dict[str, Dict[str, Any]]:
"""Get all available models with metadata for API response.
When ``user_id`` is supplied, the user's BYOM custom-model records
are merged into the result alongside the built-in catalog.
"""
registry = ModelRegistry.get_instance()
return {model.id: model.to_dict() for model in registry.get_enabled_models()}
return {
model.id: model.to_dict()
for model in registry.get_enabled_models(user_id=user_id)
}
def validate_model_id(model_id: str) -> bool:
"""Check if a model ID exists in registry"""
def validate_model_id(model_id: str, user_id: Optional[str] = None) -> bool:
"""Check if a model ID exists in registry.
``user_id`` enables resolution of per-user BYOM records (UUIDs).
Without it, only built-in catalog ids resolve.
"""
registry = ModelRegistry.get_instance()
return registry.model_exists(model_id)
return registry.model_exists(model_id, user_id=user_id)
def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
"""Get capabilities for a specific model"""
def get_model_capabilities(
model_id: str, user_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Get capabilities for a specific model.
``user_id`` enables resolution of per-user BYOM records.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
model = registry.get_model(model_id, user_id=user_id)
if model:
return {
"supported_attachment_types": model.capabilities.supported_attachment_types,
@@ -52,52 +70,66 @@ def get_default_model_id() -> str:
return registry.default_model_id
def get_provider_from_model_id(model_id: str) -> Optional[str]:
"""Get the provider name for a given model_id"""
def get_provider_from_model_id(
model_id: str, user_id: Optional[str] = None
) -> Optional[str]:
"""Get the provider name for a given model_id.
``user_id`` enables resolution of per-user BYOM records (UUIDs).
Without it, BYOM model ids return ``None`` and the caller falls
back to the deployment default.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
model = registry.get_model(model_id, user_id=user_id)
if model:
return model.provider.value
return None
def get_token_limit(model_id: str) -> int:
"""
Get context window (token limit) for a model.
Returns model's context_window or default 128000 if model not found.
def get_token_limit(model_id: str, user_id: Optional[str] = None) -> int:
"""Get context window (token limit) for a model.
Returns the model's ``context_window`` or ``DEFAULT_LLM_TOKEN_LIMIT``
if not found. ``user_id`` enables resolution of per-user BYOM records.
"""
from application.core.settings import settings
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
model = registry.get_model(model_id, user_id=user_id)
if model:
return model.capabilities.context_window
return settings.DEFAULT_LLM_TOKEN_LIMIT
def get_base_url_for_model(model_id: str) -> Optional[str]:
"""
Get the custom base_url for a specific model if configured.
Returns None if no custom base_url is set.
def get_base_url_for_model(
model_id: str, user_id: Optional[str] = None
) -> Optional[str]:
"""Get the custom base_url for a specific model if configured.
Returns ``None`` if no custom base_url is set. ``user_id`` enables
resolution of per-user BYOM records.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
model = registry.get_model(model_id, user_id=user_id)
if model:
return model.base_url
return None
def get_api_key_for_model(model_id: str) -> Optional[str]:
"""
Resolve the API key to use when invoking ``model_id``.
def get_api_key_for_model(
model_id: str, user_id: Optional[str] = None
) -> Optional[str]:
"""Resolve the API key to use when invoking ``model_id``.
Priority:
1. The model record's own ``api_key`` (reserved for future end-user
BYOM where credentials travel with the record).
1. The model record's own ``api_key`` (BYOM records and
``openai_compatible`` YAMLs populate this).
2. The provider plugin's settings-based key.
``user_id`` enables resolution of per-user BYOM records.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
model = registry.get_model(model_id, user_id=user_id)
if model is not None and model.api_key:
return model.api_key
if model is not None:

View File

@@ -281,6 +281,39 @@ def resolve_attachment_alias(alias: str) -> List[str]:
return list(aliases[alias])
def expand_attachments_lenient(
    attachments: Sequence[str], source: str
) -> List[str]:
    """Expand attachment aliases to MIME types, tolerating unknowns.

    Mirrors ``_expand_attachments`` but logs+skips unknown aliases
    rather than raising. Used for runtime call sites (BYOM registry
    load) where an operator-side alias-map edit must not drop the
    entire user's BYOM layer; the strict raise still happens at the
    API validation boundary.
    """
    alias_map = builtin_attachment_aliases()
    result: List[str] = []
    seen: set = set()

    def _add(mime: str) -> None:
        # Preserve first-seen order while de-duplicating.
        if mime not in seen:
            seen.add(mime)
            result.append(mime)

    for entry in attachments:
        if "/" in entry:
            # Already a concrete MIME type — pass it through unchanged.
            _add(entry)
            continue
        expansion = alias_map.get(entry)
        if expansion is None:
            logger.warning(
                "%s: skipping unknown attachment alias %r", source, entry,
            )
            continue
        for mime in expansion:
            _add(mime)
    return result
def load_model_yamls(directories: Sequence[Path]) -> List[ProviderCatalog]:
"""Load every ``*.yaml`` file (excluding ``_defaults.yaml``) under each
directory in order and return a flat list of catalogs.

View File

@@ -17,6 +17,8 @@ class BaseLLM(ABC):
model_id=None,
base_url=None,
backup_models=None,
model_user_id=None,
capabilities=None,
):
self.decoded_token = decoded_token
self.agent_id = str(agent_id) if agent_id else None
@@ -25,6 +27,12 @@ class BaseLLM(ABC):
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
self._backup_models = backup_models or []
self._fallback_llm = None
# Registry-resolved per-model capability overrides (BYOM caps,
# operator YAML). None falls back to provider-class defaults.
self.capabilities = capabilities
# BYOM-resolution scope captured at LLM creation time so backup
# / fallback lookups hit the same per-user layer as the primary.
self.model_user_id = model_user_id
@property
def fallback_llm(self):
@@ -39,10 +47,19 @@ class BaseLLM(ABC):
get_api_key_for_provider,
)
# Try per-agent backup models first
# model_user_id (BYOM scope) takes precedence over the caller's
# sub so shared-agent backups resolve under the owner's layer.
caller_sub = (
self.decoded_token.get("sub")
if isinstance(self.decoded_token, dict)
else None
)
backup_user_id = self.model_user_id or caller_sub
for backup_model_id in self._backup_models:
try:
provider = get_provider_from_model_id(backup_model_id)
provider = get_provider_from_model_id(
backup_model_id, user_id=backup_user_id
)
if not provider:
logger.warning(
f"Could not resolve provider for backup model: {backup_model_id}"
@@ -56,6 +73,7 @@ class BaseLLM(ABC):
decoded_token=self.decoded_token,
model_id=backup_model_id,
agent_id=self.agent_id,
model_user_id=self.model_user_id,
)
logger.info(
f"Fallback LLM initialized from agent backup model: "
@@ -68,7 +86,10 @@ class BaseLLM(ABC):
)
continue
# Fall back to global FALLBACK_* settings
# Fall back to global FALLBACK_* settings. Forward
# ``model_user_id`` here too: deployments can configure
# ``FALLBACK_LLM_NAME`` to a BYOM UUID, and that UUID is owned
# by the same user the primary model was resolved under.
if settings.FALLBACK_LLM_PROVIDER:
try:
self._fallback_llm = LLMCreator.create_llm(
@@ -78,6 +99,7 @@ class BaseLLM(ABC):
decoded_token=self.decoded_token,
model_id=settings.FALLBACK_LLM_NAME,
agent_id=self.agent_id,
model_user_id=self.model_user_id,
)
logger.info(
f"Fallback LLM initialized from global settings: "

View File

@@ -470,10 +470,14 @@ class LLMHandler(ABC):
)
return self._perform_in_memory_compression(agent, messages)
# Use orchestrator to perform compression
# Use orchestrator to perform compression. ``model_user_id``
# keeps BYOM registry resolution scoped to the model owner
# (shared-agent dispatch) while ``user_id`` stays the caller
# for the conversation access check.
result = orchestrator.compress_mid_execution(
conversation_id=agent.conversation_id,
user_id=agent.initial_user_id,
model_user_id=getattr(agent, "model_user_id", None),
model_id=agent.model_id,
decoded_token=getattr(agent, "decoded_token", {}),
current_conversation=conversation,
@@ -577,7 +581,20 @@ class LLMHandler(ABC):
if settings.COMPRESSION_MODEL_OVERRIDE
else agent.model_id
)
provider = get_provider_from_model_id(compression_model)
agent_decoded = getattr(agent, "decoded_token", None)
caller_sub = (
agent_decoded.get("sub")
if isinstance(agent_decoded, dict)
else None
)
# Use model-owner scope (mirrors orchestrator path) so
# shared-agent owner-BYOM resolves under the owner's layer.
compression_user_id = (
getattr(agent, "model_user_id", None) or caller_sub
)
provider = get_provider_from_model_id(
compression_model, user_id=compression_user_id
)
api_key = get_api_key_for_provider(provider)
compression_llm = LLMCreator.create_llm(
provider,
@@ -586,6 +603,7 @@ class LLMHandler(ABC):
getattr(agent, "decoded_token", None),
model_id=compression_model,
agent_id=getattr(agent, "agent_id", None),
model_user_id=compression_user_id,
)
# Create service without DB persistence capability
@@ -921,8 +939,15 @@ class LLMHandler(ABC):
}
return ""
# ``agent.model_id`` is the registry id (a UUID for BYOM
# records). Use the LLM's own model_id, which LLMCreator
# already resolved to the upstream model name. Built-ins:
# the two are equal; BYOM: the upstream name like
# "mistral-large-latest" instead of the UUID.
response = agent.llm.gen(
model=agent.model_id, messages=messages, tools=agent.tools
model=getattr(agent.llm, "model_id", None) or agent.model_id,
messages=messages,
tools=agent.tools,
)
parsed = self.parse_response(response)
self.llm_calls.append(build_stack_data(agent.llm))
@@ -1011,8 +1036,11 @@ class LLMHandler(ABC):
})
logger.info("Context limit reached - instructing agent to wrap up")
# See note above on agent.model_id vs llm.model_id.
response = agent.llm.gen_stream(
model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None
model=getattr(agent.llm, "model_id", None) or agent.model_id,
messages=messages,
tools=agent.tools if not agent.context_limit_reached else None,
)
self.llm_calls.append(build_stack_data(agent.llm))

View File

@@ -16,37 +16,111 @@ class LLMCreator:
model_id=None,
agent_id=None,
backup_models=None,
model_user_id=None,
*args,
**kwargs,
):
"""Construct an LLM for the given provider ``type``.
``model_user_id`` is the BYOM-resolution scope. Defaults to
``decoded_token['sub']`` (the caller). Pass it explicitly when
the model record belongs to a *different* user — most notably
for shared-agent dispatch, where the agent's stored
``default_model_id`` is the owner's BYOM UUID but
``decoded_token`` represents the caller.
"""
from application.core.model_registry import ModelRegistry
from application.security.safe_url import (
UnsafeUserUrlError,
pinned_httpx_client,
validate_user_base_url,
)
plugin = PROVIDERS_BY_NAME.get(type.lower())
if plugin is None or plugin.llm_class is None:
raise ValueError(f"No LLM class found for type {type}")
# Prefer per-model endpoint config from the registry. This is what
# makes openai_compatible (and the future end-user BYOM phase)
# work without changing every call site: if the registered
# AvailableModel carries its own api_key / base_url, they win
# over whatever the caller resolved via the provider plugin.
# makes openai_compatible AND end-user BYOM work without changing
# every call site: if the registered AvailableModel carries its
# own api_key / base_url, they win over whatever the caller
# resolved via the provider plugin.
#
# End-user BYOM lookups need the user_id from decoded_token to
# find the user's per-user models layer (built-in models resolve
# without it, so this stays back-compat).
base_url = None
upstream_model_id = model_id
capabilities = None
if model_id:
model = ModelRegistry.get_instance().get_model(model_id)
user_id = model_user_id
if user_id is None:
user_id = (
(decoded_token or {}).get("sub") if decoded_token else None
)
model = ModelRegistry.get_instance().get_model(model_id, user_id=user_id)
if model is not None:
# Forward registry caps so the LLM enforces them at
# dispatch (built-in classes hard-code True otherwise).
capabilities = getattr(model, "capabilities", None)
# SECURITY: refuse user-source dispatch without its own
# api_key (would leak settings.API_KEY to base_url).
if (
getattr(model, "source", "builtin") == "user"
and not model.api_key
):
raise ValueError(
f"Custom model {model_id!r} has no usable API key "
"(decryption may have failed). Re-save the model "
"in settings to dispatch it."
)
if model.api_key:
api_key = model.api_key
if model.base_url:
base_url = model.base_url
# For BYOM the registry id is a UUID; the upstream API
# call needs the user's typed model name instead.
if model.upstream_model_id:
upstream_model_id = model.upstream_model_id
# SECURITY: re-validate at dispatch (defense in depth
# for pre-guard rows / YAML-supplied entries). The
# pinned httpx.Client below is what actually closes the
# DNS-rebinding TOCTOU window.
if base_url and getattr(model, "source", "builtin") == "user":
try:
validate_user_base_url(base_url)
except UnsafeUserUrlError as e:
raise ValueError(
f"Refusing to dispatch model {model_id!r}: {e}"
) from e
# Pinned httpx.Client: resolves once, validates, and
# binds the SDK's outbound socket to the validated IP
# (preserves Host / SNI). Future BYOM providers must
# opt in explicitly — only openai_compatible takes
# http_client today.
if plugin.name == "openai_compatible":
try:
kwargs["http_client"] = pinned_httpx_client(
base_url
)
except UnsafeUserUrlError as e:
raise ValueError(
f"Refusing to dispatch model {model_id!r}: {e}"
) from e
# Forward model_user_id so backup/fallback resolves under the
# owner's scope on shared-agent dispatch.
return plugin.llm_class(
api_key,
user_api_key,
decoded_token=decoded_token,
model_id=model_id,
model_id=upstream_model_id,
agent_id=agent_id,
base_url=base_url,
backup_models=backup_models,
model_user_id=model_user_id,
capabilities=capabilities,
*args,
**kwargs,
)

View File

@@ -62,7 +62,15 @@ def _truncate_base64_for_logging(messages):
class OpenAILLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
def __init__(
self,
api_key=None,
user_api_key=None,
base_url=None,
http_client=None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
@@ -80,7 +88,18 @@ class OpenAILLM(BaseLLM):
else:
effective_base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
# http_client (set by LLMCreator for BYOM) is a DNS-rebinding-safe
# httpx.Client; without it the SDK re-resolves DNS per request.
if http_client is not None:
self.client = OpenAI(
api_key=self.api_key,
base_url=effective_base_url,
http_client=http_client,
)
else:
self.client = OpenAI(
api_key=self.api_key, base_url=effective_base_url
)
self.storage = StorageCreator.get_storage()
def _clean_messages_openai(self, messages):
@@ -243,6 +262,13 @@ class OpenAILLM(BaseLLM):
if "max_tokens" in kwargs:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
# Defense-in-depth: drop tools / response_format if the
# registry's capability flags deny them.
if tools and not self._supports_tools():
tools = None
if response_format and not self._supports_structured_output():
response_format = None
request_params = {
"model": model,
"messages": messages,
@@ -279,6 +305,13 @@ class OpenAILLM(BaseLLM):
if "max_tokens" in kwargs:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
# See _raw_gen for rationale — drop tools/response_format when the
# registry-provided capabilities say the model doesn't support them.
if tools and not self._supports_tools():
tools = None
if response_format and not self._supports_structured_output():
response_format = None
request_params = {
"model": model,
"messages": messages,
@@ -320,9 +353,17 @@ class OpenAILLM(BaseLLM):
response.close()
def _supports_tools(self):
# When the LLM was constructed via LLMCreator with a registered
# AvailableModel, ``self.capabilities`` is the per-model record.
# BYOM users can disable tool support; respect that. Otherwise
# OpenAI's API supports tools by default.
if self.capabilities is not None:
return bool(self.capabilities.supports_tools)
return True
def _supports_structured_output(self):
if self.capabilities is not None:
return bool(self.capabilities.supports_structured_output)
return True
def prepare_structured_output_format(self, json_schema):
@@ -389,6 +430,12 @@ class OpenAILLM(BaseLLM):
Returns:
list: List of supported MIME types
"""
# Per-model caps from the registry win when present — a BYOM
# endpoint that doesn't accept images would otherwise still be
# sent base64 image parts because the OpenAI default below
# advertises the image alias unconditionally.
if self.capabilities is not None:
return list(self.capabilities.supported_attachment_types or [])
from application.core.model_yaml import resolve_attachment_alias
return resolve_attachment_alias("image")

View File

@@ -22,6 +22,7 @@ class ClassicRAG(BaseRetriever):
llm_name=settings.LLM_PROVIDER,
api_key=settings.API_KEY,
decoded_token=None,
model_user_id=None,
):
self.original_question = source.get("question", "")
self.chat_history = chat_history if chat_history is not None else []
@@ -42,17 +43,22 @@ class ClassicRAG(BaseRetriever):
f"sources={'active_docs' in source and source['active_docs'] is not None}"
)
self.model_id = model_id
self.model_user_id = model_user_id
self.doc_token_limit = doc_token_limit
self.user_api_key = user_api_key
self.agent_id = agent_id
self.llm_name = llm_name
self.api_key = api_key
# Forward model_id + model_user_id so LLMCreator resolves BYOM
# base_url / api_key / upstream id for the rephrase client.
self.llm = LLMCreator.create_llm(
self.llm_name,
api_key=self.api_key,
user_api_key=self.user_api_key,
decoded_token=decoded_token,
model_id=self.model_id,
agent_id=self.agent_id,
model_user_id=self.model_user_id,
)
if "active_docs" in source and source["active_docs"] is not None:
@@ -103,7 +109,11 @@ class ClassicRAG(BaseRetriever):
]
try:
rephrased_query = self.llm.gen(model=self.model_id, messages=messages)
# Send upstream id (resolved by LLMCreator), not registry UUID.
rephrased_query = self.llm.gen(
model=getattr(self.llm, "model_id", None) or self.model_id,
messages=messages,
)
print(f"Rephrased query: {rephrased_query}")
return rephrased_query if rephrased_query else self.original_question
except Exception as e:

View File

@@ -0,0 +1,464 @@
"""SSRF protection for user-supplied OpenAI-compatible base URLs.
This module is the single chokepoint for validating any URL that a user
provides as an OpenAI-compatible ``base_url`` ("Bring Your Own Model").
The backend will later issue outbound HTTP requests to that URL on the
user's behalf, so we must reject anything that could be used to reach
internal-network resources (cloud metadata services, RFC 1918 ranges,
loopback, link-local, etc.).
Three entry points:
* :func:`validate_user_base_url` — called at create/update time on REST
routes that persist the URL, to give the user immediate feedback.
* :func:`pinned_post` — called at dispatch time when the caller drives
``requests`` directly (e.g. the ``/api/models/test`` endpoint).
Resolves once, dials the IP literal, preserves the original hostname
in the ``Host`` header and via SNI / cert verification for HTTPS.
* :func:`pinned_httpx_client` — called at dispatch time when the caller
hands an ``httpx.Client`` to a third-party SDK (e.g. the OpenAI
Python SDK via ``OpenAI(http_client=...)``). Same DNS-rebinding
closure on the httpx transport layer.
Why all three: the OpenAI / httpx ecosystem performs its own DNS lookup
inside ``socket.getaddrinfo`` when a connection opens, so a hostile DNS
server can hand a public IP to the validator and a loopback / link-local
address to the HTTP client. Validate-then-construct-SDK is unsafe; the
pinned variants close that TOCTOU window by resolving exactly once and
dialing the chosen IP literal directly.
"""
from __future__ import annotations
import ipaddress
import socket
from typing import Any, Iterable
from urllib.parse import urlsplit, urlunsplit
import httpx
import requests
from requests.adapters import HTTPAdapter
# Allowed URL schemes. Anything else (file, gopher, ftp, data, ...) is
# rejected outright because it either bypasses HTTP entirely or enables
# protocol smuggling against the proxy stack.
_ALLOWED_SCHEMES: frozenset[str] = frozenset({"http", "https"})
# Hostnames that resolve to a loopback / metadata / unspecified address
# but which we want to reject *by name* as well, so the rejection
# message is unambiguous and so we never accidentally call DNS on them.
# NOTE: membership tests in _validate_and_pick_ip lowercase the URL's
# hostname first, so every entry here must be lowercase.
_BLOCKED_HOSTNAMES: frozenset[str] = frozenset(
    {
        "localhost",
        "localhost.localdomain",
        "0.0.0.0",
        "::",
        "::1",
        "ip6-localhost",
        "ip6-loopback",
        # GCP metadata service. AWS/Azure use 169.254.169.254 which the
        # IP-range check below already covers via the link-local range,
        # but Google's hostname does not always resolve to a link-local
        # IP from every VPC, so we hard-deny the string too.
        "metadata.google.internal",
    }
)
# Carrier-grade NAT (RFC 6598). Python's ``ipaddress`` module does NOT
# classify this range as ``is_private``, so we must check it explicitly.
_CGNAT_NETWORK_V4: ipaddress.IPv4Network = ipaddress.IPv4Network("100.64.0.0/10")
class UnsafeUserUrlError(ValueError):
"""Raised when a user-supplied URL fails SSRF validation.
Subclasses :class:`ValueError` so call sites that already treat
invalid input as a 400-class error continue to work. The string
message names the specific reason (scheme, hostname, resolved IP,
DNS failure, ...) so that it can be surfaced to the user verbatim.
"""
def _strip_ipv6_brackets(host: str) -> str:
"""Return ``host`` with surrounding ``[`` / ``]`` removed if present."""
if host.startswith("[") and host.endswith("]"):
return host[1:-1]
return host
def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
"""Return ``True`` if ``ip`` falls in any range we refuse to dial.
This is the single source of truth for the IP-range policy:
* loopback (``127.0.0.0/8``, ``::1``)
* private (RFC 1918, ULA ``fc00::/7``)
* link-local (``169.254.0.0/16``, ``fe80::/10``)
* multicast (``224.0.0.0/4``, ``ff00::/8``)
* unspecified (``0.0.0.0``, ``::``)
* reserved (``240.0.0.0/4``, etc.)
* carrier-grade NAT (``100.64.0.0/10``) — not covered by ``is_private``
"""
if (
ip.is_loopback
or ip.is_private
or ip.is_link_local
or ip.is_multicast
or ip.is_unspecified
or ip.is_reserved
):
return True
if isinstance(ip, ipaddress.IPv4Address) and ip in _CGNAT_NETWORK_V4:
return True
return False
def _resolve(host: str) -> Iterable[ipaddress.IPv4Address | ipaddress.IPv6Address]:
    """Resolve ``host`` to the full set of A/AAAA records.

    Every address the system resolver returns is parsed and handed
    back, not only the first one — that matters: a hostile DNS server
    can lead with a public IP and follow with a private one, and the
    underlying HTTP client may fail over to the private record on
    connect. Callers must treat the whole set as unsafe if any single
    element is unsafe.
    """
    try:
        entries = socket.getaddrinfo(host, None)
    except socket.gaierror as exc:  # noqa: PERF203 — re-raise as our own type
        raise UnsafeUserUrlError(f"could not resolve hostname {host!r}: {exc}") from exc

    parsed: list[ipaddress.IPv4Address | ipaddress.IPv6Address] = []
    for entry in entries:
        # sockaddr is (host, port) for IPv4 and
        # (host, port, flowinfo, scope_id) for IPv6; index 0 is the
        # textual address either way.
        ip_str = entry[4][0]
        # Drop any IPv6 zone-id suffix ("fe80::1%lo0") before parsing
        # (split is a no-op when no "%" is present).
        ip_str = ip_str.split("%", 1)[0]
        try:
            parsed.append(ipaddress.ip_address(ip_str))
        except ValueError:
            # A record we cannot parse is itself suspicious; treat as unsafe.
            raise UnsafeUserUrlError(
                f"hostname {host!r} resolved to unparseable address {ip_str!r}"
            ) from None
    return parsed
def _validate_and_pick_ip(
    url: str,
) -> tuple[str, ipaddress.IPv4Address | ipaddress.IPv6Address, "SplitResult"]:
    """Run the SSRF guard and return the data needed to dial safely.

    Performs every check :func:`validate_user_base_url` performs, but
    additionally returns ``(hostname, ip, parts)`` where ``ip`` is one
    of the validated addresses (the first record returned by the
    resolver, or the literal itself if the URL already used an IP) and
    ``parts`` is the :func:`urllib.parse.urlsplit` result so callers do
    not have to re-parse the URL.

    Raises :class:`UnsafeUserUrlError` on the same conditions as
    :func:`validate_user_base_url`.
    """
    # Reject non-strings and whitespace-only input before parsing.
    if not isinstance(url, str) or not url.strip():
        raise UnsafeUserUrlError("url must be a non-empty string")
    try:
        parts = urlsplit(url)
    except ValueError as exc:
        raise UnsafeUserUrlError(f"could not parse url {url!r}: {exc}") from exc
    scheme = parts.scheme.lower()
    if scheme not in _ALLOWED_SCHEMES:
        raise UnsafeUserUrlError(
            f"scheme {scheme!r} is not allowed; only http and https are permitted"
        )
    # ``urlsplit`` returns the bracketed form for IPv6 in ``netloc`` but
    # the bare form in ``hostname``. Normalize via lower() because
    # hostnames are case-insensitive and we compare against a lowercase
    # blocklist.
    raw_host = parts.hostname
    if not raw_host:
        raise UnsafeUserUrlError(f"url {url!r} has no hostname")
    host = raw_host.lower()
    # Check the literal-string blocklist first. urlsplit().hostname strips
    # IPv6 brackets, so we also test the bracketed form for completeness
    # (matches the public-spec note about ``[::]``).
    bracketed = f"[{host}]"
    if host in _BLOCKED_HOSTNAMES or bracketed in _BLOCKED_HOSTNAMES:
        raise UnsafeUserUrlError(
            f"hostname {raw_host!r} is not allowed (matches internal-only name)"
        )
    # If the host is already an IP literal (with or without IPv6 brackets),
    # check it directly without going to DNS — DNS for an IP literal is a
    # no-op but it's clearer to short-circuit and gives a better message.
    candidate = _strip_ipv6_brackets(host)
    try:
        literal = ipaddress.ip_address(candidate)
    except ValueError:
        # Not an IP literal — fall through to the DNS path below.
        literal = None
    if literal is not None:
        if _is_blocked_ip(literal):
            raise UnsafeUserUrlError(
                f"hostname {raw_host!r} resolves to blocked address {literal} "
                f"(loopback/private/link-local/multicast/reserved/CGNAT)"
            )
        return host, literal, parts
    # Hostname (not an IP literal) — resolve and validate every record.
    addresses = list(_resolve(host))
    for ip in addresses:
        if _is_blocked_ip(ip):
            raise UnsafeUserUrlError(
                f"hostname {raw_host!r} resolves to blocked address {ip} "
                f"(loopback/private/link-local/multicast/reserved/CGNAT)"
            )
    if not addresses:
        # ``getaddrinfo`` would normally raise instead of returning an
        # empty list, but treat the degenerate case as unsafe too — we
        # have nothing to bind a connection to.
        raise UnsafeUserUrlError(
            f"hostname {raw_host!r} returned no addresses from DNS"
        )
    return host, addresses[0], parts
def validate_user_base_url(url: str) -> None:
    """Check that ``url`` is a safe outbound base URL; return ``None``.

    This is the create/update-time guard for user-supplied base URLs.
    The hostname is resolved and every resulting IP must fall outside
    the blocked ranges (private / loopback / link-local / multicast /
    reserved); the scheme must be ``http`` or ``https``; hostnames on
    the literal blocklist (``localhost``, ``0.0.0.0``, ``[::]``, the
    GCP metadata name) are refused outright.

    At dispatch time prefer :func:`pinned_post` (or
    :func:`pinned_httpx_client`): those repeat this validation *and*
    pin the outbound socket to the validated IP, so a DNS rebinder
    cannot swap the resolution between check and connect.

    Args:
        url: Absolute user-supplied URL with an ``http``/``https``
            scheme.

    Raises:
        UnsafeUserUrlError: On parse failure, a forbidden scheme, an
            empty or blocklisted hostname, DNS failure, or any resolved
            IP landing in a blocked range.
    """
    _validate_and_pick_ip(url)
class _PinnedHostAdapter(HTTPAdapter):
    """HTTPS adapter that performs SNI and cert verification against a
    fixed hostname even when the URL connects to an IP literal.

    Used by :func:`pinned_post` so that resolving the user-supplied
    hostname once and dialing the resolved IP doesn't break TLS.
    Without this, ``urllib3`` would default ``server_hostname`` /
    ``assert_hostname`` to the connect host (the IP) and either send the
    wrong SNI or fail cert verification — the cert is for the original
    hostname, not the IP literal.
    """

    def __init__(self, server_hostname: str, *args: Any, **kwargs: Any) -> None:
        # The hostname TLS should verify against, regardless of the
        # IP-literal host actually dialed. Stored before super().__init__
        # because HTTPAdapter's constructor builds the pool manager.
        self._server_hostname = server_hostname
        super().__init__(*args, **kwargs)

    def init_poolmanager(self, *args: Any, **kwargs: Any) -> None:
        # Forwarded by urllib3 into the TLS handshake: server_hostname
        # sets the SNI value, assert_hostname pins certificate-name
        # verification to the original hostname.
        kwargs["server_hostname"] = self._server_hostname
        kwargs["assert_hostname"] = self._server_hostname
        super().init_poolmanager(*args, **kwargs)
def _ip_to_url_host(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> str:
"""Return ``ip`` formatted for use in a URL netloc (brackets for v6)."""
if isinstance(ip, ipaddress.IPv6Address):
return f"[{ip}]"
return str(ip)
def pinned_post(
    url: str,
    *,
    json: Any = None,
    headers: dict[str, str] | None = None,
    timeout: float = 5.0,
    allow_redirects: bool = False,
) -> requests.Response:
    """POST to ``url`` with the outbound connection pinned to a single
    validated IP, closing the DNS-rebinding TOCTOU window left by the
    naive validate-then-``requests.post`` pattern.

    The URL's hostname is resolved exactly once. Every returned address
    must pass the same SSRF guard as :func:`validate_user_base_url`. The
    outbound request is issued against the chosen IP literal (so
    ``urllib3`` cannot ask the resolver again and receive a different
    answer); the original hostname is preserved in the ``Host`` header
    and, for HTTPS, via :class:`_PinnedHostAdapter` for SNI and cert
    verification.

    Args:
        url: Absolute http(s) URL to POST to.
        json: JSON-serializable payload — passed through to ``requests``.
        headers: Caller-supplied headers. Any caller-supplied ``Host``
            entry is overwritten so the in-flight request matches what
            was validated.
        timeout: Per-request timeout (seconds).
        allow_redirects: Forwarded to ``requests``. Defaults to
            ``False`` because the SSRF guard only inspects the supplied
            URL — following redirects would let a hostile upstream
            bounce the request to an internal address.

    Raises:
        UnsafeUserUrlError: If the URL fails the SSRF guard.
        requests.RequestException: For network-level failures.
    """
    host, ip, parts = _validate_and_pick_ip(url)
    # Rebuild the URL with the validated IP literal as netloc so the
    # transport dials that IP directly and never consults DNS again.
    netloc = _ip_to_url_host(ip)
    if parts.port is not None:
        netloc = f"{netloc}:{parts.port}"
    pinned_url = urlunsplit(
        (parts.scheme, netloc, parts.path, parts.query, parts.fragment)
    )
    # Force the Host header to the validated hostname (with the
    # original explicit port, if any) — the rebuilt URL alone would
    # make requests send the IP literal as Host.
    request_headers = dict(headers or {})
    host_header = host if parts.port is None else f"{host}:{parts.port}"
    request_headers["Host"] = host_header
    session = requests.Session()
    if parts.scheme == "https":
        # SNI + certificate verification against the original hostname,
        # not the IP literal now present in the URL.
        session.mount("https://", _PinnedHostAdapter(host))
    try:
        return session.post(
            pinned_url,
            json=json,
            headers=request_headers,
            timeout=timeout,
            allow_redirects=allow_redirects,
        )
    finally:
        # One-shot session: release the connection pool even when the
        # POST itself raises.
        session.close()
class _PinnedHTTPSTransport(httpx.HTTPTransport):
    """``httpx`` transport pinned to a single validated IP literal.

    Closes the DNS-rebinding TOCTOU window that
    :func:`validate_user_base_url` cannot close on its own. The OpenAI
    Python SDK (and any other SDK that uses ``httpx``) re-resolves the
    hostname inside ``socket.getaddrinfo`` at request time, so a
    hostile DNS server can return a public IP at validation time and a
    private IP at request time. This transport rewrites every outgoing
    request's URL host to the validated IP literal so ``httpcore``
    dials that IP without a fresh lookup.

    The original hostname is preserved in two places:

    1. ``Host`` header — ``httpx.Request._prepare`` set it from the URL
       netloc *before* this transport runs, so it carries the hostname
       not the IP literal. We deliberately do not touch headers here.
    2. TLS SNI / cert verification — set via the
       ``request.extensions["sni_hostname"]`` extension which
       ``httpcore`` feeds into ``start_tls``'s ``server_hostname``
       parameter. Without this, ``urllib3``-equivalent code would use
       the IP literal as SNI and cert verification would fail (the
       cert is for the original hostname, not the IP).
    """

    def __init__(
        self,
        validated_host: str,
        validated_ip: ipaddress.IPv4Address | ipaddress.IPv6Address,
        **kwargs: Any,
    ) -> None:
        # http2=False (the httpx default) — defense in depth against
        # HTTP/2 connection coalescing (RFC 7540 §9.1.1), where a
        # client may reuse a TCP connection for any host whose cert
        # covers it. Per-IP pinning never shares connections across
        # hosts, but explicit is safer than relying on the default.
        kwargs.setdefault("http2", False)
        super().__init__(**kwargs)
        # Hostname the client validated and will assert on every request.
        self._host = validated_host
        # IP formatted for a URL netloc (bracketed for IPv6).
        self._ip_netloc = _ip_to_url_host(validated_ip)

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        # Defense in depth: refuse if the request URL's host doesn't
        # match what we validated. Catches any future SDK regression
        # that rewrites the URL between Request construction and dial,
        # and any rare case where the SDK reuses our pinned client for
        # a different host (which it shouldn't, but assert it anyway).
        if request.url.host != self._host:
            raise UnsafeUserUrlError(
                f"pinned transport bound to {self._host!r}, refused "
                f"request for {request.url.host!r}"
            )
        # SNI/server_hostname for TLS verification. httpcore reads this
        # extension at _sync/connection.py and feeds it into
        # start_tls's server_hostname argument. Set before the URL host
        # is rewritten so cert validation continues to use the original
        # hostname even though TCP dials the IP literal.
        request.extensions = {
            **request.extensions,
            "sni_hostname": self._host.encode("ascii"),
        }
        # NOTE(review): for IPv6 pins, ``_ip_netloc`` is the bracketed
        # form ("[2001:db8::1]"); confirm ``httpx.URL.copy_with(host=...)``
        # accepts the bracketed literal — if httpx expects the bare
        # address here, IPv6 endpoints would need ``str(ip)`` instead.
        request.url = request.url.copy_with(host=self._ip_netloc)
        return super().handle_request(request)
def pinned_httpx_client(
    base_url: str,
    *,
    timeout: float = 600.0,
) -> httpx.Client:
    """Return an :class:`httpx.Client` whose connections are pinned to
    one validated IP, closing the DNS-rebinding TOCTOU window the naive
    ``OpenAI(base_url=...)`` flow leaves open.

    The hostname in ``base_url`` is resolved exactly once. Every
    returned address must pass :func:`_validate_and_pick_ip`'s SSRF
    guard (loopback, RFC 1918, link-local, multicast, reserved, CGNAT,
    cloud metadata names). The chosen IP becomes the URL host on every
    outgoing request so ``httpcore`` cannot ask the resolver again.

    Pass via ``OpenAI(http_client=pinned_httpx_client(base_url))`` (or
    any other SDK that accepts an ``httpx.Client``) to make BYOM
    dispatch immune to DNS-rebinding TOCTOU.

    Args:
        base_url: User-supplied http(s) URL. Validated through the same
            SSRF guard as :func:`validate_user_base_url`.
        timeout: Per-request timeout (seconds). Defaults to 600 to
            match the OpenAI SDK's default; callers should override
            for non-LLM workloads.

    Raises:
        UnsafeUserUrlError: If ``base_url`` fails the SSRF guard.
    """
    # Only the hostname and one validated IP are needed here; the SDK
    # keeps supplying full request URLs, so the split parts are unused.
    host, ip, _parts = _validate_and_pick_ip(base_url)
    transport = _PinnedHTTPSTransport(host, ip)
    # follow_redirects=False — the SSRF guard only inspects the
    # supplied URL; following 3xx would let a hostile upstream bounce
    # the in-network request to an internal address (cloud metadata,
    # RFC1918, loopback) carrying whatever credentials the SDK adds.
    return httpx.Client(
        transport=transport,
        timeout=timeout,
        follow_redirects=False,
    )

View File

@@ -203,6 +203,24 @@ agents_table = Table(
Column("legacy_mongo_id", Text),
)
# One row per user-supplied OpenAI-compatible model ("Bring Your Own
# Model"); see UserCustomModelsRepository for the access layer.
user_custom_models_table = Table(
    "user_custom_models",
    metadata,
    # Internal DocsGPT identifier (what agents store in default_model_id).
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    # Owning user; repository queries filter on this for isolation.
    Column("user_id", Text, nullable=False),
    # Model name sent verbatim to the upstream provider's API.
    Column("upstream_model_id", Text, nullable=False),
    Column("display_name", Text, nullable=False),
    Column("description", Text, nullable=False, server_default=""),
    # User-supplied endpoint; see application.security.safe_url for the
    # SSRF guard applied to it.
    Column("base_url", Text, nullable=False),
    # AES-CBC ciphertext (base64) keyed via per-user PBKDF2 in
    # application.security.encryption.encrypt_credentials.
    Column("api_key_encrypted", Text, nullable=False),
    # Per-model capability flags; keys filtered by the repository layer.
    Column("capabilities", JSONB, nullable=False, server_default="{}"),
    Column("enabled", Boolean, nullable=False, server_default="true"),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
attachments_table = Table(
"attachments",
metadata,

View File

@@ -0,0 +1,199 @@
"""Repository for the ``user_custom_models`` table.
Backs the end-user "Bring Your Own Model" feature. Each row is one
user-supplied OpenAI-compatible endpoint (Mistral, Together, vLLM, ...).
The ``id`` UUID is the internal DocsGPT identifier (what agents store
in ``default_model_id``); ``upstream_model_id`` is what we send verbatim
to the provider's API.
API key handling: callers pass plaintext via ``api_key_plaintext``;
this module wraps the existing ``application.security.encryption``
helper (AES-CBC + per-user PBKDF2 salt) and writes the base64 ciphertext
to the ``api_key_encrypted`` column. Decryption is the caller's
responsibility (they hold the ``user_id``).
"""
from __future__ import annotations
from typing import Any, Optional
from sqlalchemy import Connection, func, text
from application.security.encryption import (
decrypt_credentials,
encrypt_credentials,
)
from application.storage.db.base_repository import row_to_dict
from application.storage.db.models import user_custom_models_table
# Capability keys a client may persist per model; anything else is
# silently dropped by _normalize_capabilities before hitting the DB.
_ALLOWED_CAPABILITY_KEYS = frozenset(
    {
        "supports_tools",
        "supports_structured_output",
        "supports_streaming",
        "attachments",
        "context_window",
    }
)
class UserCustomModelsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
# ------------------------------------------------------------------ #
# Encryption wrappers
# ------------------------------------------------------------------ #
@staticmethod
def _encrypt_api_key(api_key_plaintext: str, user_id: str) -> str:
"""Encrypt ``api_key_plaintext`` with the per-user PBKDF2 scheme."""
return encrypt_credentials({"api_key": api_key_plaintext}, user_id)
@staticmethod
def _decrypt_api_key(api_key_encrypted: str, user_id: str) -> Optional[str]:
"""Decrypt the API key. Returns None on failure (which the caller
should surface as a configuration error rather than silently
proceeding with the upstream call)."""
if not api_key_encrypted:
return None
creds = decrypt_credentials(api_key_encrypted, user_id)
return creds.get("api_key") if creds else None
@staticmethod
def _normalize_capabilities(caps: Optional[dict]) -> dict:
"""Drop unknown keys; nothing else is forced. Callers (the route
layer) are responsible for value validation (numeric ranges,
attachment alias resolution)."""
if not caps:
return {}
return {k: v for k, v in caps.items() if k in _ALLOWED_CAPABILITY_KEYS}
# ------------------------------------------------------------------ #
# CRUD
# ------------------------------------------------------------------ #
def create(
self,
user_id: str,
upstream_model_id: str,
display_name: str,
base_url: str,
api_key_plaintext: str,
description: str = "",
capabilities: Optional[dict] = None,
enabled: bool = True,
) -> dict:
values = {
"user_id": user_id,
"upstream_model_id": upstream_model_id,
"display_name": display_name,
"description": description or "",
"base_url": base_url,
"api_key_encrypted": self._encrypt_api_key(api_key_plaintext, user_id),
"capabilities": self._normalize_capabilities(capabilities),
"enabled": bool(enabled),
}
from sqlalchemy.dialects.postgresql import insert as pg_insert
stmt = (
pg_insert(user_custom_models_table)
.values(**values)
.returning(user_custom_models_table)
)
result = self._conn.execute(stmt)
return row_to_dict(result.fetchone())
def get(self, model_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
text(
"SELECT * FROM user_custom_models "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
{"id": str(model_id), "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str) -> list[dict]:
result = self._conn.execute(
text(
"SELECT * FROM user_custom_models "
"WHERE user_id = :user_id ORDER BY created_at DESC"
),
{"user_id": user_id},
)
return [row_to_dict(r) for r in result.fetchall()]
def update(self, model_id: str, user_id: str, fields: dict) -> bool:
    """Apply a partial update to a custom-model row.

    ``api_key_plaintext`` is special-cased: when present and non-empty it
    is encrypted into ``api_key_encrypted``; when absent or empty the
    stored ciphertext is left untouched. This matches the ``PATCH`` wire
    shape (the UI submits a blank password field when the operator wants
    to keep the existing key).

    Returns ``True`` when a row was actually updated.
    """
    updatable = {
        "upstream_model_id",
        "display_name",
        "description",
        "base_url",
        "capabilities",
        "enabled",
    }
    changes: dict[str, Any] = {}
    for name, value in fields.items():
        # None means "not provided" on the wire — never null a column.
        if name not in updatable or value is None:
            continue
        if name == "capabilities":
            changes[name] = self._normalize_capabilities(value)
        elif name == "enabled":
            changes[name] = bool(value)
        else:
            changes[name] = value

    plaintext_key = fields.get("api_key_plaintext")
    if plaintext_key:
        changes["api_key_encrypted"] = self._encrypt_api_key(
            plaintext_key, user_id
        )

    if not changes:
        return False

    changes["updated_at"] = func.now()
    table = user_custom_models_table
    stmt = (
        table.update()
        .where(table.c.id == str(model_id))
        .where(table.c.user_id == user_id)
        .values(**changes)
    )
    return self._conn.execute(stmt).rowcount > 0
def delete(self, model_id: str, user_id: str) -> bool:
    """Delete the row; return ``True`` when something was removed."""
    stmt = text(
        "DELETE FROM user_custom_models "
        "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
    )
    outcome = self._conn.execute(
        stmt, {"id": str(model_id), "user_id": user_id}
    )
    return outcome.rowcount > 0
# ------------------------------------------------------------------ #
# Decryption helpers exposed to the registry layer
# ------------------------------------------------------------------ #
def get_decrypted_api_key(
    self, model_id: str, user_id: str
) -> Optional[str]:
    """Fetch the row and return its decrypted API key.

    Returns ``None`` when the row does not exist or decryption fails.
    """
    record = self.get(model_id, user_id)
    if record is None:
        return None
    ciphertext = record.get("api_key_encrypted", "")
    return self._decrypt_api_key(ciphertext, user_id)

View File

@@ -83,9 +83,9 @@ def count_tokens_docs(docs):
def calculate_doc_token_budget(
model_id: str = "gpt-4o"
model_id: str = "gpt-4o", user_id: str | None = None
) -> int:
total_context = get_token_limit(model_id)
total_context = get_token_limit(model_id, user_id=user_id)
reserved = sum(settings.RESERVED_TOKENS.values())
doc_budget = total_context - reserved
return max(doc_budget, 1000)
@@ -150,9 +150,11 @@ def get_hash(data):
return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()
def limit_chat_history(history, max_token_limit=None, model_id="docsgpt-local"):
def limit_chat_history(
history, max_token_limit=None, model_id="docsgpt-local", user_id=None
):
"""Limit chat history to fit within token limit."""
model_token_limit = get_token_limit(model_id)
model_token_limit = get_token_limit(model_id, user_id=user_id)
max_token_limit = (
max_token_limit
if max_token_limit and max_token_limit < model_token_limit
@@ -204,7 +206,9 @@ def generate_image_url(image_path):
def calculate_compression_threshold(
model_id: str, threshold_percentage: float = 0.8
model_id: str,
threshold_percentage: float = 0.8,
user_id: str | None = None,
) -> int:
"""
Calculate token threshold for triggering compression.
@@ -212,11 +216,13 @@ def calculate_compression_threshold(
Args:
model_id: Model identifier
threshold_percentage: Percentage of context window (default 80%)
user_id: When set, BYOM custom-model records (UUID-keyed) resolve
for context-window lookup.
Returns:
Token count threshold
"""
total_context = get_token_limit(model_id)
total_context = get_token_limit(model_id, user_id=user_id)
threshold = int(total_context * threshold_percentage)
return threshold

View File

@@ -344,7 +344,9 @@ def run_agent_logic(agent_config, input_data):
# Determine model_id: check agent's default_model_id, fallback to system default
agent_default_model = agent_config.get("default_model_id", "")
if agent_default_model and validate_model_id(agent_default_model):
if agent_default_model and validate_model_id(
agent_default_model, user_id=owner
):
model_id = agent_default_model
else:
model_id = get_default_model_id()
@@ -360,12 +362,16 @@ def run_agent_logic(agent_config, input_data):
)
# Get provider and API key for the selected model
provider = get_provider_from_model_id(model_id) if model_id else settings.LLM_PROVIDER
provider = (
get_provider_from_model_id(model_id, user_id=owner)
if model_id
else settings.LLM_PROVIDER
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
# Calculate proper doc_token_limit based on model's context window
doc_token_limit = calculate_doc_token_budget(
model_id=model_id
model_id=model_id, user_id=owner
)
retriever = RetrieverCreator.create_retriever(

View File

@@ -448,7 +448,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
setUserTools(tools);
};
const getModels = async () => {
const response = await modelService.getModels(null);
const response = await modelService.getModels(token);
if (!response.ok) throw new Error('Failed to fetch models');
const data = await response.json();
const transformed = modelService.transformModels(data.models || []);
@@ -1041,10 +1041,24 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
isOpen={isModelsPopupOpen}
onClose={() => setIsModelsPopupOpen(false)}
anchorRef={modelAnchorButtonRef}
options={availableModels.map((model) => ({
id: model.id,
label: model.display_name,
}))}
options={(() => {
const builtinLabel = t(
'settings.customModels.modelsGroup.builtin',
);
const userLabel = t('settings.customModels.modelsGroup.user');
const builtin: OptionType[] = [];
const user: OptionType[] = [];
availableModels.forEach((model) => {
const opt: OptionType = {
id: model.id,
label: model.display_name,
group: model.source === 'user' ? userLabel : builtinLabel,
};
if (model.source === 'user') user.push(opt);
else builtin.push(opt);
});
return [...builtin, ...user];
})()}
selectedIds={selectedModelIds}
onSelectionChange={(newSelectedIds: Set<string | number>) =>
setSelectedModelIds(

View File

@@ -42,7 +42,9 @@ import { MultiSelect } from '@/components/ui/multi-select';
import {
Select,
SelectContent,
SelectGroup,
SelectItem,
SelectLabel,
SelectTrigger,
SelectValue,
} from '@/components/ui/select';
@@ -706,7 +708,7 @@ function WorkflowBuilderInner() {
useEffect(() => {
const loadModelsAndTools = async () => {
try {
const modelsResponse = await modelService.getModels(null);
const modelsResponse = await modelService.getModels(token);
if (modelsResponse.ok) {
const modelsData = await modelsResponse.json();
const transformedModels = modelService.transformModels(
@@ -732,7 +734,7 @@ function WorkflowBuilderInner() {
}
};
loadModelsAndTools();
}, []);
}, [token]);
useEffect(() => {
if (!selectedNode || selectedNode.type !== 'agent') return;
@@ -1847,15 +1849,54 @@ function WorkflowBuilderInner() {
<SelectValue placeholder="Select a model" />
</SelectTrigger>
<SelectContent>
{availableModels.map((model) => (
<SelectItem
key={model.id}
value={model.id}
>
{model.display_name} ·{' '}
{model.provider}
</SelectItem>
))}
{(() => {
const builtin = availableModels.filter(
(m) => m.source !== 'user',
);
const user = availableModels.filter(
(m) => m.source === 'user',
);
return (
<>
{builtin.length > 0 && (
<SelectGroup>
<SelectLabel>
{t(
'settings.customModels.modelsGroup.builtin',
)}
</SelectLabel>
{builtin.map((model) => (
<SelectItem
key={model.id}
value={model.id}
>
{model.display_name} ·{' '}
{model.provider}
</SelectItem>
))}
</SelectGroup>
)}
{user.length > 0 && (
<SelectGroup>
<SelectLabel>
{t(
'settings.customModels.modelsGroup.user',
)}
</SelectLabel>
{user.map((model) => (
<SelectItem
key={model.id}
value={model.id}
>
{model.display_name} ·{' '}
{model.provider}
</SelectItem>
))}
</SelectGroup>
)}
</>
);
})()}
</SelectContent>
</Select>
</div>

View File

@@ -80,6 +80,22 @@ const apiClient = {
return response;
}),
patch: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'PATCH',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}).then((response) => {
return response;
}),
putFormData: (
url: string,
formData: FormData,

View File

@@ -76,6 +76,10 @@ const endpoints = {
GET_ARTIFACT: (artifactId: string) => `/api/artifact/${artifactId}`,
WORKFLOWS: '/api/workflows',
WORKFLOW: (id: string) => `/api/workflows/${id}`,
CUSTOM_MODELS: '/api/user/models',
CUSTOM_MODEL: (id: string) => `/api/user/models/${id}`,
CUSTOM_MODEL_TEST: (id: string) => `/api/user/models/${id}/test`,
CUSTOM_MODEL_TEST_PAYLOAD: '/api/user/models/test',
},
V1: {
CHAT_COMPLETIONS: '/v1/chat/completions',

View File

@@ -0,0 +1,162 @@
import apiClient from '../client';
import endpoints from '../endpoints';
import type {
CreateCustomModelPayload,
CustomModel,
CustomModelTestResult,
} from '../../models/types';
const parseJsonOrError = async (response: Response): Promise<any> => {
  // Read the raw body once; tolerate empty and non-JSON payloads.
  const rawBody = await response.text();
  let parsed: any = null;
  if (rawBody) {
    try {
      parsed = JSON.parse(rawBody);
    } catch {
      parsed = null;
    }
  }
  if (response.ok) {
    return parsed;
  }
  // Non-2xx: surface the server's error/message when present, and attach
  // status + payload so callers can inspect the failure.
  const message =
    (parsed && (parsed.error || parsed.message)) ||
    `Request failed with status ${response.status}`;
  const err = new Error(message) as Error & {
    status?: number;
    payload?: unknown;
  };
  err.status = response.status;
  err.payload = parsed;
  throw err;
};
// Shared response shaping for the two connection-test endpoints. Both
// test calls return a CustomModelTestResult and must never throw: HTTP
// and parse failures are folded into `{ ok: false, error }`. (This
// replaces two previously duplicated copies of the same logic.)
const parseTestResponse = async (
  response: Response,
): Promise<CustomModelTestResult> => {
  const text = await response.text();
  let body: any = null;
  if (text) {
    try {
      body = JSON.parse(text);
    } catch {
      body = null;
    }
  }
  if (!response.ok) {
    return {
      ok: false,
      error:
        (body && (body.error || body.message)) ||
        `Test failed with status ${response.status}`,
    };
  }
  if (body && typeof body.ok === 'boolean') {
    return body as CustomModelTestResult;
  }
  return { ok: true };
};

const customModelsService = {
  /** List the caller's custom models. Tolerates both a bare array and a
   *  `{ models: [...] }` response shape; returns `[]` for anything else. */
  listCustomModels: async (token: string | null): Promise<CustomModel[]> => {
    const response = await apiClient.get(endpoints.USER.CUSTOM_MODELS, token);
    const data = await parseJsonOrError(response);
    if (Array.isArray(data)) return data as CustomModel[];
    if (data && Array.isArray(data.models)) return data.models as CustomModel[];
    return [];
  },

  /** Create a custom model; throws (via parseJsonOrError) on failure. */
  createCustomModel: async (
    payload: CreateCustomModelPayload,
    token: string | null,
  ): Promise<CustomModel> => {
    const response = await apiClient.post(
      endpoints.USER.CUSTOM_MODELS,
      payload,
      token,
    );
    return (await parseJsonOrError(response)) as CustomModel;
  },

  /** Partially update a custom model; throws on failure. */
  updateCustomModel: async (
    id: string,
    payload: Partial<CreateCustomModelPayload>,
    token: string | null,
  ): Promise<CustomModel> => {
    const response = await apiClient.patch(
      endpoints.USER.CUSTOM_MODEL(id),
      payload,
      token,
    );
    return (await parseJsonOrError(response)) as CustomModel;
  },

  /** Delete a custom model; resolves void on success, throws on failure. */
  deleteCustomModel: async (
    id: string,
    token: string | null,
  ): Promise<void> => {
    const response = await apiClient.delete(
      endpoints.USER.CUSTOM_MODEL(id),
      token,
    );
    if (!response.ok) {
      await parseJsonOrError(response);
    }
  },

  /** Test an unsaved connection payload (base URL / key / model id). */
  testCustomModelPayload: async (
    payload: {
      base_url: string;
      api_key: string;
      upstream_model_id: string;
    },
    token: string | null,
  ): Promise<CustomModelTestResult> => {
    const response = await apiClient.post(
      endpoints.USER.CUSTOM_MODEL_TEST_PAYLOAD,
      payload,
      token,
    );
    return parseTestResponse(response);
  },

  /** Test a saved model, optionally overriding stored connection fields. */
  testCustomModel: async (
    id: string,
    token: string | null,
    overrides: {
      base_url?: string;
      api_key?: string;
      upstream_model_id?: string;
    } = {},
  ): Promise<CustomModelTestResult> => {
    // Send only non-empty overrides; server falls back to stored values.
    const requestBody: Record<string, string> = {};
    if (overrides.base_url) requestBody.base_url = overrides.base_url;
    if (overrides.api_key) requestBody.api_key = overrides.api_key;
    if (overrides.upstream_model_id)
      requestBody.upstream_model_id = overrides.upstream_model_id;
    const response = await apiClient.post(
      endpoints.USER.CUSTOM_MODEL_TEST(id),
      requestBody,
      token,
    );
    return parseTestResponse(response);
  },
};

export default customModelsService;

View File

@@ -19,6 +19,7 @@ const modelService = {
supports_tools: model.supports_tools,
supports_structured_output: model.supports_structured_output,
supports_streaming: model.supports_streaming,
source: model.source,
})),
};

View File

@@ -7,6 +7,7 @@ import RoundedTick from '../assets/rounded-tick.svg';
import {
selectAvailableModels,
selectSelectedModel,
selectToken,
setAvailableModels,
setModelsLoading,
setSelectedModel,
@@ -18,17 +19,26 @@ export default function DropdownModel() {
const dispatch = useDispatch();
const selectedModel = useSelector(selectSelectedModel);
const availableModels = useSelector(selectAvailableModels);
const token = useSelector(selectToken);
const dropdownRef = React.useRef<HTMLDivElement>(null);
// Tracks which token the cached availableModels were loaded for.
// Without this, the early-return below pins the anonymous/built-in
// list forever once it's populated — login/logout never refetches
// and a user's BYOM models stay invisible.
const lastLoadedTokenRef = React.useRef<string | null | undefined>(undefined);
const [isOpen, setIsOpen] = React.useState(false);
useEffect(() => {
const loadModels = async () => {
if ((availableModels?.length ?? 0) > 0) {
if (
(availableModels?.length ?? 0) > 0 &&
lastLoadedTokenRef.current === token
) {
return;
}
dispatch(setModelsLoading(true));
try {
const response = await modelService.getModels(null);
const response = await modelService.getModels(token);
if (!response.ok) {
throw new Error(`API error: ${response.status}`);
}
@@ -37,6 +47,7 @@ export default function DropdownModel() {
const transformed = modelService.transformModels(models);
dispatch(setAvailableModels(transformed));
lastLoadedTokenRef.current = token;
if (!selectedModel && transformed.length > 0) {
const defaultModel =
transformed.find((m) => m.id === data.default_model_id) ||
@@ -59,7 +70,7 @@ export default function DropdownModel() {
};
loadModels();
}, [availableModels?.length, dispatch, selectedModel]);
}, [availableModels?.length, dispatch, selectedModel, token]);
const handleClickOutside = (event: MouseEvent) => {
if (

View File

@@ -11,6 +11,7 @@ export type OptionType = {
id: string | number;
label: string;
icon?: string | React.ReactNode;
group?: string;
[key: string]: any;
};
@@ -227,43 +228,75 @@ export default function MultiSelectPopup({
</p>
</div>
) : (
filteredOptions.map((option) => {
const isSelected = selectedIds.has(option.id);
return (
<div
key={option.id}
onClick={() => handleOptionClick(option.id)}
className="dark:border-border dark:hover:bg-accent hover:bg-accent flex cursor-pointer items-center justify-between border-b border-[#D9D9D9] p-3 last:border-b-0"
role="option"
aria-selected={isSelected}
>
<div className="mr-3 flex grow items-center overflow-hidden">
{option.icon && renderIcon(option.icon)}
<p
className="overflow-hidden text-sm font-medium text-ellipsis whitespace-nowrap text-gray-900 dark:text-white"
title={option.label}
>
{option.label}
</p>
</div>
<div className="shrink-0">
<div
className={`border-border bg-card flex h-4 w-4 items-center justify-center rounded-xs border-2`}
aria-hidden="true"
>
{isSelected && (
<img
src={CheckmarkIcon}
alt="checkmark"
width={10}
height={10}
/>
)}
(() => {
const hasGroups = filteredOptions.some((o) => !!o.group);
const renderOption = (option: OptionType) => {
const isSelected = selectedIds.has(option.id);
return (
<div
key={option.id}
onClick={() => handleOptionClick(option.id)}
className="dark:border-border dark:hover:bg-accent hover:bg-accent flex cursor-pointer items-center justify-between border-b border-[#D9D9D9] p-3 last:border-b-0"
role="option"
aria-selected={isSelected}
>
<div className="mr-3 flex grow items-center overflow-hidden">
{option.icon && renderIcon(option.icon)}
<p
className="overflow-hidden text-sm font-medium text-ellipsis whitespace-nowrap text-gray-900 dark:text-white"
title={option.label}
>
{option.label}
</p>
</div>
<div className="shrink-0">
<div
className={`border-border bg-card flex h-4 w-4 items-center justify-center rounded-xs border-2`}
aria-hidden="true"
>
{isSelected && (
<img
src={CheckmarkIcon}
alt="checkmark"
width={10}
height={10}
/>
)}
</div>
</div>
</div>
);
};
if (!hasGroups) {
return filteredOptions.map(renderOption);
}
const groupOrder: string[] = [];
const groupMap = new Map<string, OptionType[]>();
filteredOptions.forEach((opt) => {
const key = opt.group || '';
if (!groupMap.has(key)) {
groupOrder.push(key);
groupMap.set(key, []);
}
groupMap.get(key)!.push(opt);
});
return groupOrder.map((groupKey) => (
<div key={`group-${groupKey || 'ungrouped'}`}>
{groupKey && (
<div
className="bg-muted/50 dark:bg-card text-muted-foreground sticky top-0 z-10 border-b border-[#D9D9D9] px-3 py-1.5 text-xs font-semibold uppercase dark:border-[#2E2E2E]"
role="presentation"
>
{groupKey}
</div>
)}
{(groupMap.get(groupKey) || []).map(renderOption)}
</div>
);
})
));
})()
)}
</div>
)}

View File

@@ -14,6 +14,7 @@ const useTabs = () => {
t('settings.analytics.label'),
t('settings.logs.label'),
t('settings.tools.label'),
t('settings.customModels.label'),
];
return tabs;
};

View File

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
"label": "Eigene Modelle",
"subtitle": "Verbinde dein eigenes OpenAI-kompatibles Modell.",
"addModel": "Modell hinzufügen",
"empty": "Keine Modelle gefunden",
"searchPlaceholder": "Modelle suchen...",
"modalSubtitle": "Verbinde einen beliebigen OpenAI-kompatiblen Endpunkt (Mistral, Together, vLLM usw.).",
"addTitle": "Eigenes Modell hinzufügen",
"editTitle": "Eigenes Modell bearbeiten",
"save": "Speichern",
"saving": "Speichern",
"testConnection": "Verbindung testen",
"testing": "Test läuft",
"testSuccess": "Verbindung erfolgreich.",
"testHintNew": "Fülle Basis-URL, Modell-ID und API-Schlüssel aus, um Verbindungstests zu aktivieren.",
"disabledBadge": "Deaktiviert",
"deleteWarning": "Möchtest du das eigene Modell \"{{modelName}}\" wirklich löschen?",
"actionsMenuAria": "Aktionen für {{modelName}}",
"modelsGroup": {
"builtin": "DocsGPT-Modelle",
"user": "Meine Modelle"
},
"actions": {
"edit": "Bearbeiten",
"delete": "Löschen"
},
"fields": {
"displayName": "Anzeigename",
"modelId": "Modell-ID",
"description": "Beschreibung",
"baseUrl": "Basis-URL",
"apiKey": "API-Schlüssel"
},
"placeholders": {
"displayName": "Mein Mistral",
"modelId": "z. B. mistral-large-latest",
"description": "Optionale Beschreibung",
"baseUrl": "https://api.mistral.ai/v1",
"apiKey": "sk-...",
"apiKeyEdit": "Leer lassen, um den vorhandenen Schlüssel beizubehalten"
},
"hints": {
"apiKeyEdit": "Leer lassen, um den vorhandenen Schlüssel beizubehalten."
},
"capabilities": {
"title": "Fähigkeiten",
"contextWindowShort": "Kontextfenster",
"chips": {
"tools": "Tools",
"structuredOutput": "Strukturierte Ausgabe",
"images": "Bilder"
}
},
"errors": {
"displayNameRequired": "Anzeigename ist erforderlich",
"modelIdRequired": "Modell-ID ist erforderlich",
"baseUrlRequired": "Basis-URL ist erforderlich",
"baseUrlScheme": "Basis-URL muss mit http:// oder https:// beginnen",
"baseUrlInvalid": "Bitte gib eine gültige URL ein",
"apiKeyRequired": "API-Schlüssel ist erforderlich",
"contextWindowRange": "Kontextfenster muss zwischen 1.000 und 10.000.000 liegen",
"saveFailed": "Eigenes Modell konnte nicht gespeichert werden",
"testFailed": "Verbindungstest fehlgeschlagen"
}
},
"scrollTabsLeft": "Tabs nach links scrollen",
"tabsAriaLabel": "Einstellungs-Tabs",
"scrollTabsRight": "Tabs nach rechts scrollen"

View File

@@ -256,6 +256,71 @@
}
}
},
"customModels": {
"label": "Custom Models",
"subtitle": "Bring your own OpenAI-compatible model.",
"addModel": "Add Model",
"empty": "No models found",
"searchPlaceholder": "Search models...",
"modalSubtitle": "Connect any OpenAI-compatible endpoint (Mistral, Together, vLLM, etc).",
"addTitle": "Add Custom Model",
"editTitle": "Edit Custom Model",
"save": "Save",
"saving": "Saving",
"testConnection": "Test connection",
"testing": "Testing",
"testSuccess": "Connection successful.",
"testHintNew": "Fill in Base URL, Model ID, and API key to enable connection tests.",
"disabledBadge": "Disabled",
"deleteWarning": "Are you sure you want to delete the custom model \"{{modelName}}\"?",
"actionsMenuAria": "Actions for {{modelName}}",
"modelsGroup": {
"builtin": "DocsGPT Models",
"user": "My Models"
},
"actions": {
"edit": "Edit",
"delete": "Delete"
},
"fields": {
"displayName": "Display name",
"modelId": "Model ID",
"description": "Description",
"baseUrl": "Base URL",
"apiKey": "API key"
},
"placeholders": {
"displayName": "My Mistral",
"modelId": "e.g. mistral-large-latest",
"description": "Optional description",
"baseUrl": "https://api.mistral.ai/v1",
"apiKey": "sk-...",
"apiKeyEdit": "Leave blank to keep existing"
},
"hints": {
"apiKeyEdit": "Leave blank to keep existing key."
},
"capabilities": {
"title": "Capabilities",
"contextWindowShort": "Context window",
"chips": {
"tools": "Tools",
"structuredOutput": "Structured output",
"images": "Images"
}
},
"errors": {
"displayNameRequired": "Display name is required",
"modelIdRequired": "Model ID is required",
"baseUrlRequired": "Base URL is required",
"baseUrlScheme": "Base URL must start with http:// or https://",
"baseUrlInvalid": "Please enter a valid URL",
"apiKeyRequired": "API key is required",
"contextWindowRange": "Context window must be between 1,000 and 10,000,000",
"saveFailed": "Failed to save custom model",
"testFailed": "Connection test failed"
}
},
"scrollTabsLeft": "Scroll tabs left",
"tabsAriaLabel": "Settings tabs",
"scrollTabsRight": "Scroll tabs right"

View File

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
"label": "Modelos personalizados",
"subtitle": "Trae tu propio modelo compatible con OpenAI.",
"addModel": "Añadir modelo",
"empty": "No se encontraron modelos",
"searchPlaceholder": "Buscar modelos...",
"modalSubtitle": "Conecta cualquier endpoint compatible con OpenAI (Mistral, Together, vLLM, etc.).",
"addTitle": "Añadir modelo personalizado",
"editTitle": "Editar modelo personalizado",
"save": "Guardar",
"saving": "Guardando",
"testConnection": "Probar conexión",
"testing": "Probando",
"testSuccess": "Conexión correcta.",
"testHintNew": "Completa URL base, ID de modelo y clave API para habilitar las pruebas de conexión.",
"disabledBadge": "Deshabilitado",
"deleteWarning": "¿Seguro que quieres eliminar el modelo personalizado \"{{modelName}}\"?",
"actionsMenuAria": "Acciones para {{modelName}}",
"modelsGroup": {
"builtin": "Modelos de DocsGPT",
"user": "Mis modelos"
},
"actions": {
"edit": "Editar",
"delete": "Eliminar"
},
"fields": {
"displayName": "Nombre para mostrar",
"modelId": "ID del modelo",
"description": "Descripción",
"baseUrl": "URL base",
"apiKey": "Clave API"
},
"placeholders": {
"displayName": "Mi Mistral",
"modelId": "p. ej. mistral-large-latest",
"description": "Descripción opcional",
"baseUrl": "https://api.mistral.ai/v1",
"apiKey": "sk-...",
"apiKeyEdit": "Déjalo en blanco para mantener la actual"
},
"hints": {
"apiKeyEdit": "Déjalo en blanco para mantener la clave actual."
},
"capabilities": {
"title": "Capacidades",
"contextWindowShort": "Ventana de contexto",
"chips": {
"tools": "Herramientas",
"structuredOutput": "Salida estructurada",
"images": "Imágenes"
}
},
"errors": {
"displayNameRequired": "El nombre para mostrar es obligatorio",
"modelIdRequired": "El ID del modelo es obligatorio",
"baseUrlRequired": "La URL base es obligatoria",
"baseUrlScheme": "La URL base debe empezar por http:// o https://",
"baseUrlInvalid": "Introduce una URL válida",
"apiKeyRequired": "La clave API es obligatoria",
"contextWindowRange": "La ventana de contexto debe estar entre 1.000 y 10.000.000",
"saveFailed": "No se pudo guardar el modelo personalizado",
"testFailed": "La prueba de conexión falló"
}
},
"scrollTabsLeft": "Desplazar pestañas a la izquierda",
"tabsAriaLabel": "Pestañas de configuración",
"scrollTabsRight": "Desplazar pestañas a la derecha"

View File

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
"label": "カスタムモデル",
"subtitle": "OpenAI 互換のモデルを自分で接続できます。",
"addModel": "モデルを追加",
"empty": "モデルが見つかりません",
"searchPlaceholder": "モデルを検索...",
"modalSubtitle": "OpenAI 互換のエンドポイントを接続できます (Mistral、Together、vLLM など)。",
"addTitle": "カスタムモデルを追加",
"editTitle": "カスタムモデルを編集",
"save": "保存",
"saving": "保存中",
"testConnection": "接続をテスト",
"testing": "テスト中",
"testSuccess": "接続に成功しました。",
"testHintNew": "接続テストを有効にするには、ベース URL、モデル ID、API キーを入力してください。",
"disabledBadge": "無効",
"deleteWarning": "カスタムモデル「{{modelName}}」を削除してもよろしいですか?",
"actionsMenuAria": "{{modelName}} のアクション",
"modelsGroup": {
"builtin": "DocsGPT モデル",
"user": "マイモデル"
},
"actions": {
"edit": "編集",
"delete": "削除"
},
"fields": {
"displayName": "表示名",
"modelId": "モデル ID",
"description": "説明",
"baseUrl": "ベース URL",
"apiKey": "API キー"
},
"placeholders": {
"displayName": "My Mistral",
"modelId": "例: mistral-large-latest",
"description": "説明 (任意)",
"baseUrl": "https://api.mistral.ai/v1",
"apiKey": "sk-...",
"apiKeyEdit": "既存のキーを保持する場合は空欄のまま"
},
"hints": {
"apiKeyEdit": "既存のキーを保持するには空欄のままにしてください。"
},
"capabilities": {
"title": "機能",
"contextWindowShort": "コンテキストウィンドウ",
"chips": {
"tools": "ツール",
"structuredOutput": "構造化出力",
"images": "画像"
}
},
"errors": {
"displayNameRequired": "表示名は必須です",
"modelIdRequired": "モデル ID は必須です",
"baseUrlRequired": "ベース URL は必須です",
"baseUrlScheme": "ベース URL は http:// または https:// で始まる必要があります",
"baseUrlInvalid": "有効な URL を入力してください",
"apiKeyRequired": "API キーは必須です",
"contextWindowRange": "コンテキストウィンドウは 1,000 から 10,000,000 の範囲で指定してください",
"saveFailed": "カスタムモデルの保存に失敗しました",
"testFailed": "接続テストに失敗しました"
}
},
"scrollTabsLeft": "タブを左にスクロール",
"tabsAriaLabel": "設定タブ",
"scrollTabsRight": "タブを右にスクロール"

View File

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
"label": "Свои модели",
"subtitle": "Подключите свою модель, совместимую с OpenAI.",
"addModel": "Добавить модель",
"empty": "Модели не найдены",
"searchPlaceholder": "Поиск моделей...",
"modalSubtitle": "Подключите любой OpenAI-совместимый эндпоинт (Mistral, Together, vLLM и т. п.).",
"addTitle": "Добавить свою модель",
"editTitle": "Изменить свою модель",
"save": "Сохранить",
"saving": "Сохранение",
"testConnection": "Проверить соединение",
"testing": "Проверка",
"testSuccess": "Соединение установлено.",
"testHintNew": "Заполните Base URL, ID модели и API-ключ, чтобы включить проверку соединения.",
"disabledBadge": "Отключено",
"deleteWarning": "Удалить свою модель «{{modelName}}»?",
"actionsMenuAria": "Действия для {{modelName}}",
"modelsGroup": {
"builtin": "Модели DocsGPT",
"user": "Мои модели"
},
"actions": {
"edit": "Изменить",
"delete": "Удалить"
},
"fields": {
"displayName": "Отображаемое имя",
"modelId": "ID модели",
"description": "Описание",
"baseUrl": "Base URL",
"apiKey": "API-ключ"
},
"placeholders": {
"displayName": "Мой Mistral",
"modelId": "напр. mistral-large-latest",
"description": "Необязательное описание",
"baseUrl": "https://api.mistral.ai/v1",
"apiKey": "sk-...",
"apiKeyEdit": "Оставьте пустым, чтобы сохранить текущий"
},
"hints": {
"apiKeyEdit": "Оставьте пустым, чтобы сохранить текущий ключ."
},
"capabilities": {
"title": "Возможности",
"contextWindowShort": "Контекстное окно",
"chips": {
"tools": "Инструменты",
"structuredOutput": "Структурированный вывод",
"images": "Изображения"
}
},
"errors": {
"displayNameRequired": "Укажите отображаемое имя",
"modelIdRequired": "Укажите ID модели",
"baseUrlRequired": "Укажите Base URL",
"baseUrlScheme": "Base URL должен начинаться с http:// или https://",
"baseUrlInvalid": "Введите корректный URL",
"apiKeyRequired": "Укажите API-ключ",
"contextWindowRange": "Контекстное окно должно быть от 1 000 до 10 000 000",
"saveFailed": "Не удалось сохранить свою модель",
"testFailed": "Проверка соединения не удалась"
}
},
"scrollTabsLeft": "Прокрутить вкладки влево",
"tabsAriaLabel": "Вкладки настроек",
"scrollTabsRight": "Прокрутить вкладки вправо"

View File

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
"label": "自訂模型",
"subtitle": "接入你自己的 OpenAI 相容模型。",
"addModel": "新增模型",
"empty": "找不到模型",
"searchPlaceholder": "搜尋模型...",
"modalSubtitle": "可連接任何 OpenAI 相容端點(Mistral、Together、vLLM 等)。",
"addTitle": "新增自訂模型",
"editTitle": "編輯自訂模型",
"save": "儲存",
"saving": "儲存中",
"testConnection": "測試連線",
"testing": "測試中",
"testSuccess": "連線成功。",
"testHintNew": "請填寫 Base URL、模型 ID 與 API 金鑰以啟用連線測試。",
"disabledBadge": "已停用",
"deleteWarning": "確定要刪除自訂模型「{{modelName}}」嗎?",
"actionsMenuAria": "{{modelName}} 的操作",
"modelsGroup": {
"builtin": "DocsGPT 模型",
"user": "我的模型"
},
"actions": {
"edit": "編輯",
"delete": "刪除"
},
"fields": {
"displayName": "顯示名稱",
"modelId": "模型 ID",
"description": "說明",
"baseUrl": "Base URL",
"apiKey": "API 金鑰"
},
"placeholders": {
"displayName": "My Mistral",
"modelId": "例:mistral-large-latest",
"description": "選填說明",
"baseUrl": "https://api.mistral.ai/v1",
"apiKey": "sk-...",
"apiKeyEdit": "留空則保留原金鑰"
},
"hints": {
"apiKeyEdit": "留空以保留現有金鑰。"
},
"capabilities": {
"title": "功能",
"contextWindowShort": "脈絡視窗",
"chips": {
"tools": "工具",
"structuredOutput": "結構化輸出",
"images": "圖片"
}
},
"errors": {
"displayNameRequired": "顯示名稱為必填",
"modelIdRequired": "模型 ID 為必填",
"baseUrlRequired": "Base URL 為必填",
"baseUrlScheme": "Base URL 必須以 http:// 或 https:// 開頭",
"baseUrlInvalid": "請輸入有效的 URL",
"apiKeyRequired": "API 金鑰為必填",
"contextWindowRange": "脈絡視窗須介於 1,000 至 10,000,000 之間",
"saveFailed": "儲存自訂模型失敗",
"testFailed": "連線測試失敗"
}
},
"scrollTabsLeft": "向左捲動標籤",
"tabsAriaLabel": "設定標籤",
"scrollTabsRight": "向右捲動標籤"

View File

@@ -244,6 +244,71 @@
}
}
},
"customModels": {
"label": "自定义模型",
"subtitle": "接入你自己的 OpenAI 兼容模型。",
"addModel": "添加模型",
"empty": "未找到模型",
"searchPlaceholder": "搜索模型...",
"modalSubtitle": "可连接任何 OpenAI 兼容的接口(Mistral、Together、vLLM 等)。",
"addTitle": "添加自定义模型",
"editTitle": "编辑自定义模型",
"save": "保存",
"saving": "保存中",
"testConnection": "测试连接",
"testing": "测试中",
"testSuccess": "连接成功。",
"testHintNew": "请填写 Base URL、模型 ID 和 API 密钥以启用连接测试。",
"disabledBadge": "已禁用",
"deleteWarning": "确定要删除自定义模型「{{modelName}}」吗?",
"actionsMenuAria": "{{modelName}} 的操作",
"modelsGroup": {
"builtin": "DocsGPT 模型",
"user": "我的模型"
},
"actions": {
"edit": "编辑",
"delete": "删除"
},
"fields": {
"displayName": "显示名称",
"modelId": "模型 ID",
"description": "描述",
"baseUrl": "Base URL",
"apiKey": "API 密钥"
},
"placeholders": {
"displayName": "My Mistral",
"modelId": "如 mistral-large-latest",
"description": "可选描述",
"baseUrl": "https://api.mistral.ai/v1",
"apiKey": "sk-...",
"apiKeyEdit": "留空则保留原密钥"
},
"hints": {
"apiKeyEdit": "留空以保留现有密钥。"
},
"capabilities": {
"title": "功能",
"contextWindowShort": "上下文窗口",
"chips": {
"tools": "工具",
"structuredOutput": "结构化输出",
"images": "图片"
}
},
"errors": {
"displayNameRequired": "显示名称必填",
"modelIdRequired": "模型 ID 必填",
"baseUrlRequired": "Base URL 必填",
"baseUrlScheme": "Base URL 必须以 http:// 或 https:// 开头",
"baseUrlInvalid": "请输入有效的 URL",
"apiKeyRequired": "API 密钥必填",
"contextWindowRange": "上下文窗口必须在 1,000 至 10,000,000 之间",
"saveFailed": "保存自定义模型失败",
"testFailed": "连接测试失败"
}
},
"scrollTabsLeft": "向左滚动标签",
"tabsAriaLabel": "设置标签",
"scrollTabsRight": "向右滚动标签"

View File

@@ -0,0 +1,597 @@
import { Check } from 'lucide-react';
import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelector } from 'react-redux';
import customModelsService from '../api/services/customModelsService';
import Spinner from '../components/Spinner';
import { Input } from '../components/ui/input';
import { Label } from '../components/ui/label';
import { ActiveState } from '../models/misc';
import { selectToken } from '../preferences/preferenceSlice';
import WrapperComponent from './WrapperModal';
import type { CreateCustomModelPayload, CustomModel } from '../models/types';
interface CustomModelModalProps {
  modalState: ActiveState;
  setModalState: (state: ActiveState) => void;
  // When set, the modal edits this model; when null/undefined it creates
  // a new one (see buildInitialFormState).
  model?: CustomModel | null;
  // Invoked with the created/updated model after a successful save.
  onSaved: (model: CustomModel) => void;
}

// Local form state mirroring the create/update payload plus UI-only
// capability toggles.
interface FormState {
  display_name: string;
  upstream_model_id: string;
  description: string;
  base_url: string;
  // Always initialised blank in edit mode; submitting it blank keeps the
  // stored key (see buildInitialFormState / placeholders.apiKeyEdit).
  api_key: string;
  supports_tools: boolean;
  supports_structured_output: boolean;
  // UI shorthand for capabilities.attachments containing 'image'.
  supports_images: boolean;
  // '' when the numeric input is empty; otherwise a token count.
  context_window: number | '';
  enabled: boolean;
}
const DEFAULT_CONTEXT_WINDOW = 128000;
const MIN_CONTEXT_WINDOW = 1000;
const MAX_CONTEXT_WINDOW = 10_000_000;
/**
 * Derive the modal's initial form values.
 *
 * Without a model this yields the add-mode defaults; with a model it
 * mirrors the stored fields, always leaving api_key blank so the
 * existing secret is never echoed back into the UI.
 */
const buildInitialFormState = (model?: CustomModel | null): FormState => {
  const defaults: FormState = {
    display_name: '',
    upstream_model_id: '',
    description: '',
    base_url: '',
    api_key: '',
    supports_tools: true,
    supports_structured_output: true,
    supports_images: false,
    context_window: DEFAULT_CONTEXT_WINDOW,
    enabled: true,
  };
  if (!model) return defaults;
  const caps = model.capabilities;
  const rawAttachments = caps?.attachments;
  const attachmentList = Array.isArray(rawAttachments) ? rawAttachments : [];
  return {
    display_name: model.display_name || '',
    upstream_model_id: model.upstream_model_id || '',
    description: model.description || '',
    base_url: model.base_url || '',
    api_key: '',
    supports_tools: caps?.supports_tools ?? true,
    supports_structured_output: caps?.supports_structured_output ?? true,
    supports_images: attachmentList.includes('image'),
    context_window: caps?.context_window ?? DEFAULT_CONTEXT_WINDOW,
    enabled: model.enabled ?? true,
  };
};
/**
 * Add/edit dialog for a user-defined ("BYOM") custom model.
 *
 * Validates the form client-side, offers an optional connection test
 * against the upstream endpoint, and persists via customModelsService.
 * In edit mode the api_key field starts blank; leaving it blank keeps
 * the previously stored key on the server.
 */
export default function CustomModelModal({
  modalState,
  setModalState,
  model,
  onSaved,
}: CustomModelModalProps) {
  const { t } = useTranslation();
  const token = useSelector(selectToken);
  // An id on the incoming model means we are editing an existing entry.
  const isEditMode = !!model?.id;
  const [formData, setFormData] = useState<FormState>(() =>
    buildInitialFormState(model),
  );
  // Keyed by field name; special keys: 'general' (non-field failures)
  // and 'base_url_remote' (server-side URL/reachability rejections).
  const [errors, setErrors] = useState<{ [key: string]: string }>({});
  const [saving, setSaving] = useState(false);
  const [testing, setTesting] = useState(false);
  const [testResult, setTestResult] = useState<{
    ok: boolean;
    message: string;
  } | null>(null);
  // Re-seed all local state whenever the modal opens (or the target
  // model changes while open) so nothing leaks between sessions.
  useEffect(() => {
    if (modalState === 'ACTIVE') {
      setFormData(buildInitialFormState(model));
      setErrors({});
      setTestResult(null);
      setSaving(false);
      setTesting(false);
    }
  }, [modalState, model]);
  const closeModal = () => {
    setModalState('INACTIVE');
  };
  // Update one field and drop any error state that may no longer apply:
  // the field's own error, the cross-field general/remote errors, and
  // the last connection-test result.
  const handleChange = <K extends keyof FormState>(
    name: K,
    value: FormState[K],
  ) => {
    setFormData((prev) => ({ ...prev, [name]: value }));
    if (errors[name as string] || errors.general) {
      setErrors((prev) => {
        const next = { ...prev };
        delete next[name as string];
        delete next.general;
        delete next.base_url_remote;
        return next;
      });
    }
    setTestResult(null);
  };
  // Client-side validation; returns true when the form may be saved.
  const validate = (): boolean => {
    const newErrors: { [key: string]: string } = {};
    if (!formData.display_name.trim()) {
      newErrors.display_name = t(
        'settings.customModels.errors.displayNameRequired',
      );
    }
    if (!formData.upstream_model_id.trim()) {
      newErrors.upstream_model_id = t(
        'settings.customModels.errors.modelIdRequired',
      );
    }
    const trimmedUrl = formData.base_url.trim();
    if (!trimmedUrl) {
      newErrors.base_url = t('settings.customModels.errors.baseUrlRequired');
    } else if (!/^https?:\/\//i.test(trimmedUrl)) {
      newErrors.base_url = t('settings.customModels.errors.baseUrlScheme');
    } else {
      try {
        new URL(trimmedUrl);
      } catch {
        newErrors.base_url = t('settings.customModels.errors.baseUrlInvalid');
      }
    }
    // The key is only mandatory on create; edit falls back to the
    // stored secret.
    if (!isEditMode && !formData.api_key.trim()) {
      newErrors.api_key = t('settings.customModels.errors.apiKeyRequired');
    }
    const ctxValue =
      formData.context_window === '' ? NaN : Number(formData.context_window);
    if (
      Number.isNaN(ctxValue) ||
      ctxValue < MIN_CONTEXT_WINDOW ||
      ctxValue > MAX_CONTEXT_WINDOW
    ) {
      newErrors.context_window = t(
        'settings.customModels.errors.contextWindowRange',
      );
    }
    setErrors(newErrors);
    return Object.keys(newErrors).length === 0;
  };
  // Shape the form into the create/update payload; api_key is only
  // included when the user actually typed one.
  const buildPayload = (): CreateCustomModelPayload => {
    const ctxValue =
      formData.context_window === ''
        ? DEFAULT_CONTEXT_WINDOW
        : Number(formData.context_window);
    const payload: CreateCustomModelPayload = {
      upstream_model_id: formData.upstream_model_id.trim(),
      display_name: formData.display_name.trim(),
      description: formData.description.trim(),
      base_url: formData.base_url.trim(),
      capabilities: {
        supports_tools: formData.supports_tools,
        supports_structured_output: formData.supports_structured_output,
        attachments: formData.supports_images ? ['image'] : [],
        context_window: ctxValue,
      },
      enabled: formData.enabled,
    };
    if (formData.api_key.trim()) {
      payload.api_key = formData.api_key.trim();
    }
    return payload;
  };
  // Heuristically route a server error message to the Base URL field
  // (reachability/SSRF wording) or to the general error banner.
  const mapErrorToField = (
    message: string,
  ): { field: string; message: string } => {
    const lower = message.toLowerCase();
    if (
      lower.includes('reachable') ||
      lower.includes('public internet') ||
      lower.includes('ssrf') ||
      lower.includes('url') ||
      lower.includes('host')
    ) {
      return { field: 'base_url_remote', message };
    }
    return { field: 'general', message };
  };
  // Persist the model (create or update), then notify the parent and
  // close; failures are mapped onto the relevant form field.
  const handleSave = async () => {
    if (!validate()) return;
    setSaving(true);
    setTestResult(null);
    try {
      const payload = buildPayload();
      const saved = isEditMode
        ? await customModelsService.updateCustomModel(model!.id, payload, token)
        : await customModelsService.createCustomModel(payload, token);
      onSaved(saved);
      closeModal();
    } catch (err) {
      const message =
        err instanceof Error
          ? err.message
          : t('settings.customModels.errors.saveFailed');
      const mapped = mapErrorToField(message);
      setErrors((prev) => ({ ...prev, [mapped.field]: mapped.message }));
    } finally {
      setSaving(false);
    }
  };
  // Edit mode allows blank api_key (by-id endpoint falls back to stored).
  const trimmedBaseUrl = formData.base_url.trim();
  const trimmedApiKey = formData.api_key.trim();
  const trimmedUpstreamId = formData.upstream_model_id.trim();
  const canTest = isEditMode
    ? !!(trimmedBaseUrl && trimmedUpstreamId)
    : !!(trimmedBaseUrl && trimmedApiKey && trimmedUpstreamId);
  const testDisabledHint = canTest
    ? undefined
    : t('settings.customModels.testHintNew');
  // Run a connection test without saving; surfaces the outcome in
  // testResult and mirrors URL-related failures onto the Base URL field.
  const handleTest = async () => {
    if (!canTest) return;
    setTesting(true);
    setTestResult(null);
    try {
      const result =
        isEditMode && model?.id
          ? await customModelsService.testCustomModel(model.id, token, {
              base_url: trimmedBaseUrl,
              api_key: trimmedApiKey,
              upstream_model_id: trimmedUpstreamId,
            })
          : await customModelsService.testCustomModelPayload(
              {
                base_url: trimmedBaseUrl,
                api_key: trimmedApiKey,
                upstream_model_id: trimmedUpstreamId,
              },
              token,
            );
      if (result.ok) {
        setTestResult({
          ok: true,
          message: t('settings.customModels.testSuccess'),
        });
      } else {
        const message =
          result.error || t('settings.customModels.errors.testFailed');
        setTestResult({ ok: false, message });
        const mapped = mapErrorToField(message);
        if (mapped.field === 'base_url_remote') {
          setErrors((prev) => ({ ...prev, base_url_remote: message }));
        }
      }
    } catch (err) {
      const message =
        err instanceof Error
          ? err.message
          : t('settings.customModels.errors.testFailed');
      setTestResult({ ok: false, message });
    } finally {
      setTesting(false);
    }
  };
  if (modalState !== 'ACTIVE') return null;
  return (
    <WrapperComponent
      close={closeModal}
      isPerformingTask={saving}
      className="max-w-[600px] md:w-[80vw] lg:w-[60vw]"
    >
      <div className="flex h-full flex-col">
        <div className="px-2 py-2">
          <h2 className="text-foreground dark:text-foreground text-xl font-semibold">
            {isEditMode
              ? t('settings.customModels.editTitle')
              : t('settings.customModels.addTitle')}
          </h2>
          <p className="text-muted-foreground mt-2 text-sm">
            {t('settings.customModels.modalSubtitle')}
          </p>
        </div>
        <div className="flex-1 px-2">
          <div className="flex flex-col gap-4 px-0.5 py-4">
            {/* Row 1: Display name + Model ID side-by-side */}
            <div className="grid grid-cols-1 gap-4 sm:grid-cols-2">
              <div className="flex flex-col gap-1.5">
                <Label htmlFor="cm-display-name">
                  {t('settings.customModels.fields.displayName')}
                  <span className="text-red-500">*</span>
                </Label>
                <Input
                  id="cm-display-name"
                  type="text"
                  value={formData.display_name}
                  onChange={(e) => handleChange('display_name', e.target.value)}
                  placeholder={t(
                    'settings.customModels.placeholders.displayName',
                  )}
                  aria-invalid={!!errors.display_name || undefined}
                  className="rounded-xl"
                />
                {errors.display_name && (
                  <p className="text-destructive text-xs">
                    {errors.display_name}
                  </p>
                )}
              </div>
              <div className="flex flex-col gap-1.5">
                <Label htmlFor="cm-model-id">
                  {t('settings.customModels.fields.modelId')}
                  <span className="text-red-500">*</span>
                </Label>
                <Input
                  id="cm-model-id"
                  type="text"
                  value={formData.upstream_model_id}
                  onChange={(e) =>
                    handleChange('upstream_model_id', e.target.value)
                  }
                  placeholder={t('settings.customModels.placeholders.modelId')}
                  aria-invalid={!!errors.upstream_model_id || undefined}
                  className="rounded-xl"
                />
                {errors.upstream_model_id && (
                  <p className="text-destructive text-xs">
                    {errors.upstream_model_id}
                  </p>
                )}
              </div>
            </div>
            {/* Row 2: Base URL + API key side-by-side */}
            <div className="grid grid-cols-1 gap-4 sm:grid-cols-2">
              <div className="flex flex-col gap-1.5">
                <Label htmlFor="cm-base-url">
                  {t('settings.customModels.fields.baseUrl')}
                  <span className="text-red-500">*</span>
                </Label>
                <Input
                  id="cm-base-url"
                  type="url"
                  value={formData.base_url}
                  onChange={(e) => handleChange('base_url', e.target.value)}
                  placeholder={t('settings.customModels.placeholders.baseUrl')}
                  aria-invalid={
                    !!errors.base_url || !!errors.base_url_remote || undefined
                  }
                  className="rounded-xl"
                />
                {errors.base_url && (
                  <p className="text-destructive text-xs">{errors.base_url}</p>
                )}
                {errors.base_url_remote && (
                  <p className="text-destructive text-xs">
                    {errors.base_url_remote}
                  </p>
                )}
              </div>
              <div className="flex flex-col gap-1.5">
                <Label htmlFor="cm-api-key">
                  {t('settings.customModels.fields.apiKey')}
                  {!isEditMode && <span className="text-red-500">*</span>}
                </Label>
                <Input
                  id="cm-api-key"
                  type="password"
                  autoComplete="new-password"
                  value={formData.api_key}
                  onChange={(e) => handleChange('api_key', e.target.value)}
                  placeholder={
                    isEditMode
                      ? t('settings.customModels.placeholders.apiKeyEdit')
                      : t('settings.customModels.placeholders.apiKey')
                  }
                  aria-invalid={!!errors.api_key || undefined}
                  className="rounded-xl"
                />
                {isEditMode && (
                  <p className="text-muted-foreground text-xs">
                    {t('settings.customModels.hints.apiKeyEdit')}
                  </p>
                )}
                {errors.api_key && (
                  <p className="text-destructive text-xs">{errors.api_key}</p>
                )}
              </div>
            </div>
            {/* Row 3: Description (full width, optional) */}
            <div className="flex flex-col gap-1.5">
              <Label htmlFor="cm-description">
                {t('settings.customModels.fields.description')}
              </Label>
              <Input
                id="cm-description"
                type="text"
                value={formData.description}
                onChange={(e) => handleChange('description', e.target.value)}
                placeholder={t(
                  'settings.customModels.placeholders.description',
                )}
                className="rounded-xl"
              />
            </div>
            {/* Row 4: Capabilities — flat (no border), chips + inline ctx */}
            <div className="flex flex-col gap-2">
              <Label>{t('settings.customModels.capabilities.title')}</Label>
              <div className="flex flex-wrap gap-2">
                <CapabilityChip
                  label={t('settings.customModels.capabilities.chips.tools')}
                  active={formData.supports_tools}
                  onClick={() =>
                    handleChange('supports_tools', !formData.supports_tools)
                  }
                />
                <CapabilityChip
                  label={t(
                    'settings.customModels.capabilities.chips.structuredOutput',
                  )}
                  active={formData.supports_structured_output}
                  onClick={() =>
                    handleChange(
                      'supports_structured_output',
                      !formData.supports_structured_output,
                    )
                  }
                />
                <CapabilityChip
                  label={t('settings.customModels.capabilities.chips.images')}
                  active={formData.supports_images}
                  onClick={() =>
                    handleChange('supports_images', !formData.supports_images)
                  }
                />
              </div>
              <div className="mt-1 flex flex-col gap-1.5 sm:flex-row sm:items-center sm:gap-3">
                <Label
                  htmlFor="cm-context-window"
                  className="text-muted-foreground text-xs sm:mb-0"
                >
                  {t('settings.customModels.capabilities.contextWindowShort')}
                </Label>
                <Input
                  id="cm-context-window"
                  type="number"
                  value={formData.context_window}
                  min={MIN_CONTEXT_WINDOW}
                  max={MAX_CONTEXT_WINDOW}
                  step={1000}
                  onChange={(e) => {
                    const v = e.target.value;
                    if (v === '') {
                      handleChange('context_window', '');
                    } else {
                      const n = parseInt(v, 10);
                      if (!Number.isNaN(n)) {
                        handleChange('context_window', n);
                      }
                    }
                  }}
                  aria-invalid={!!errors.context_window || undefined}
                  className="w-full rounded-xl sm:w-40"
                />
                {errors.context_window && (
                  <p className="text-destructive text-xs">
                    {errors.context_window}
                  </p>
                )}
              </div>
            </div>
            {testResult && (
              <div
                className={`rounded-xl p-3 text-sm ${
                  testResult.ok
                    ? 'bg-green-50 text-green-700 dark:bg-green-900/40 dark:text-green-300'
                    : 'bg-red-50 text-red-700 dark:bg-red-900/40 dark:text-red-300'
                }`}
              >
                {testResult.message}
              </div>
            )}
            {errors.general && (
              <div className="rounded-xl bg-red-50 p-3 text-sm text-red-700 dark:bg-red-900/40 dark:text-red-300">
                {errors.general}
              </div>
            )}
          </div>
        </div>
        <div className="px-2 py-4">
          <div className="flex flex-col gap-3 sm:flex-row sm:justify-between">
            <button
              type="button"
              onClick={handleTest}
              disabled={!canTest || testing || saving}
              title={testDisabledHint}
              className="border-border dark:border-border dark:text-foreground hover:bg-accent dark:hover:bg-muted/50 w-full rounded-3xl border px-6 py-2 text-sm font-medium transition-all disabled:cursor-not-allowed disabled:opacity-50 sm:w-auto"
            >
              {testing ? (
                <div className="flex items-center justify-center">
                  <Spinner size="small" />
                  <span className="ml-2">
                    {t('settings.customModels.testing')}
                  </span>
                </div>
              ) : (
                t('settings.customModels.testConnection')
              )}
            </button>
            <div className="flex flex-col-reverse gap-3 sm:flex-row sm:gap-3">
              <button
                type="button"
                onClick={closeModal}
                disabled={saving}
                className="dark:text-foreground hover:bg-accent dark:hover:bg-muted/50 w-full cursor-pointer rounded-3xl px-6 py-2 text-sm font-medium disabled:opacity-50 sm:w-auto"
              >
                {t('cancel')}
              </button>
              <button
                type="button"
                onClick={handleSave}
                disabled={saving}
                className="bg-primary hover:bg-primary/90 w-full rounded-3xl px-6 py-2 text-sm font-medium text-white transition-all disabled:opacity-50 sm:w-auto"
              >
                {saving ? (
                  <div className="flex items-center justify-center">
                    <Spinner size="small" />
                    <span className="ml-2">
                      {t('settings.customModels.saving')}
                    </span>
                  </div>
                ) : (
                  t('settings.customModels.save')
                )}
              </button>
            </div>
          </div>
        </div>
      </div>
    </WrapperComponent>
  );
}
interface CapabilityChipProps {
  label: string;
  active: boolean;
  onClick: () => void;
}
/**
 * Toggleable pill for a boolean capability flag. Rendered as a button
 * with role="switch"/aria-checked so assistive tech announces state;
 * a check mark is shown only while active.
 */
function CapabilityChip({ label, active, onClick }: CapabilityChipProps) {
  return (
    <button
      type="button"
      role="switch"
      aria-checked={active}
      onClick={onClick}
      className={`inline-flex items-center gap-1.5 rounded-full border px-3 py-1.5 text-sm transition-colors ${
        active
          ? 'border-emerald-500/50 bg-emerald-500/10 text-emerald-700 dark:border-emerald-400/40 dark:bg-emerald-400/10 dark:text-emerald-300'
          : 'border-border text-muted-foreground hover:bg-accent dark:hover:bg-muted/40'
      }`}
    >
      {active && <Check size={14} strokeWidth={2.5} />}
      {label}
    </button>
  );
}

View File

@@ -1,3 +1,5 @@
export type ModelSource = 'builtin' | 'user';
export interface AvailableModel {
id: string;
provider: string;
@@ -9,6 +11,7 @@ export interface AvailableModel {
supports_structured_output: boolean;
supports_streaming: boolean;
enabled: boolean;
source?: ModelSource;
}
export interface Model {
@@ -22,4 +25,38 @@ export interface Model {
supports_tools: boolean;
supports_structured_output: boolean;
supports_streaming: boolean;
source?: ModelSource;
}
export interface CustomModelCapabilities {
  supports_tools: boolean;
  supports_structured_output: boolean;
  // Attachment kinds the model accepts, e.g. ['image'].
  attachments: string[];
  // Declared context window size in tokens.
  context_window: number;
}
// A user-defined (BYOM) model as returned by the API. Note the API key
// is deliberately absent — responses never echo the stored secret.
export interface CustomModel {
  id: string;
  upstream_model_id: string;
  display_name: string;
  description?: string;
  base_url: string;
  capabilities: CustomModelCapabilities;
  enabled: boolean;
  source: 'user';
}
// Request body for creating/updating a custom model; api_key is
// optional so updates can omit it and keep the stored secret.
export interface CreateCustomModelPayload {
  upstream_model_id: string;
  display_name: string;
  description?: string;
  base_url: string;
  api_key?: string;
  capabilities: CustomModelCapabilities;
  enabled: boolean;
}
// Outcome of a connection test; error carries the failure message.
export interface CustomModelTestResult {
  ok: boolean;
  error?: string;
}

View File

@@ -250,6 +250,22 @@ prefListenerMiddleware.startListening({
},
});
// Reconcile selectedModel when availableModels changes so a deleted
// BYOM doesn't leave a stale id in localStorage.
prefListenerMiddleware.startListening({
  matcher: isAnyOf(setAvailableModels),
  effect: (_action, listenerApi) => {
    const state = listenerApi.getState() as RootState;
    const { selectedModel, availableModels } = state.preference;
    // Nothing to reconcile against yet, or no selection to validate.
    if (!availableModels.length) return;
    if (!selectedModel) return;
    // If the selected id vanished from the refreshed list (e.g. the
    // custom model was deleted), fall back to the first entry.
    const stillValid = availableModels.some((m) => m.id === selectedModel.id);
    if (!stillValid) {
      listenerApi.dispatch(setSelectedModel(availableModels[0]));
    }
  },
});
export const selectApiKey = (state: RootState) => state.preference.apiKey;
export const selectApiKeyStatus = (state: RootState) =>
!!state.preference.apiKey;

View File

@@ -0,0 +1,313 @@
import { Globe, Tag, Trash } from 'lucide-react';
import React from 'react';
import { useTranslation } from 'react-i18next';
import { useDispatch, useSelector } from 'react-redux';
import customModelsService from '../api/services/customModelsService';
import modelService from '../api/services/modelService';
import Edit from '../assets/edit.svg';
import NoFilesDarkIcon from '../assets/no-files-dark.svg';
import NoFilesIcon from '../assets/no-files.svg';
import SearchIcon from '../assets/search.svg';
import ThreeDotsIcon from '../assets/three-dots.svg';
import ContextMenu, { MenuOption } from '../components/ContextMenu';
import SkeletonLoader from '../components/SkeletonLoader';
import { useDarkTheme, useLoaderState } from '../hooks';
import ConfirmationModal from '../modals/ConfirmationModal';
import CustomModelModal from '../modals/CustomModelModal';
import { ActiveState } from '../models/misc';
import {
selectToken,
setAvailableModels,
} from '../preferences/preferenceSlice';
import type { CustomModel } from '../models/types';
/**
 * Compact a base URL down to its host (including any port) for card
 * display. Unparseable input is handled by hand: drop the scheme and
 * anything after the first path separator.
 */
const formatBaseUrlHost = (baseUrl: string): string => {
  if (!baseUrl) return '';
  try {
    const parsed = new URL(baseUrl);
    return parsed.host || parsed.hostname || baseUrl;
  } catch {
    const withoutScheme = baseUrl.replace(/^https?:\/\//i, '');
    const cut = withoutScheme.indexOf('/');
    return cut === -1 ? withoutScheme : withoutScheme.slice(0, cut);
  }
};
/**
 * Settings panel listing the user's custom (BYOM) models: searchable
 * card grid with add/edit via CustomModelModal and delete behind a
 * confirmation dialog. Mutations also refresh the global Redux
 * availableModels so the chat model dropdown stays consistent.
 */
export default function CustomModels() {
  const { t } = useTranslation();
  const dispatch = useDispatch();
  const token = useSelector(selectToken);
  const [isDarkTheme] = useDarkTheme();
  const [models, setModels] = React.useState<CustomModel[]>([]);
  const [searchTerm, setSearchTerm] = React.useState('');
  const [loading, setLoading] = useLoaderState(false);
  const [modalState, setModalState] = React.useState<ActiveState>('INACTIVE');
  // null editingModel means the modal opens in add mode.
  const [editingModel, setEditingModel] = React.useState<CustomModel | null>(
    null,
  );
  // Id of the card whose context menu is currently open, if any.
  const [activeMenuId, setActiveMenuId] = React.useState<string | null>(null);
  // Per-card anchor refs for positioning each ContextMenu.
  const menuRefs = React.useRef<{
    [key: string]: React.RefObject<HTMLDivElement | null>;
  }>({});
  const [deleteState, setDeleteState] = React.useState<ActiveState>('INACTIVE');
  const [modelToDelete, setModelToDelete] = React.useState<CustomModel | null>(
    null,
  );
  // Ref instead of useCallback: useLoaderState returns a fresh setter
  // each render, which would loop the effect (thousands of req/s).
  const fetchModelsRef = React.useRef<() => Promise<void>>(async () => {});
  fetchModelsRef.current = async () => {
    setLoading(true);
    try {
      const data = await customModelsService.listCustomModels(token);
      setModels(data);
    } catch (err) {
      console.error('Failed to load custom models:', err);
      setModels([]);
    } finally {
      setLoading(false);
    }
  };
  React.useEffect(() => {
    fetchModelsRef.current();
  }, [token]);
  // Lazily create a menu anchor ref for every model we know about.
  React.useEffect(() => {
    models.forEach((model) => {
      if (!menuRefs.current[model.id]) {
        menuRefs.current[model.id] = React.createRef<HTMLDivElement>();
      }
    });
  }, [models]);
  const openAddModal = () => {
    setEditingModel(null);
    setModalState('ACTIVE');
  };
  const openEditModal = (model: CustomModel) => {
    setEditingModel(model);
    setModalState('ACTIVE');
  };
  // Refresh Redux availableModels so the chat dropdown reconciles a
  // selectedModel UUID that was just deleted/disabled.
  const refreshGlobalAvailableModels = React.useCallback(async () => {
    try {
      const response = await modelService.getModels(token);
      if (!response.ok) return;
      const data = await response.json();
      const transformed = modelService.transformModels(data.models || []);
      dispatch(setAvailableModels(transformed));
    } catch (err) {
      console.error('Failed to refresh global available models:', err);
    }
  }, [dispatch, token]);
  // Merge a created/updated model into the local list (new models are
  // prepended), then sync the global dropdown.
  const handleSaved = (saved: CustomModel) => {
    setModels((prev) => {
      const idx = prev.findIndex((m) => m.id === saved.id);
      if (idx === -1) return [saved, ...prev];
      const next = [...prev];
      next[idx] = saved;
      return next;
    });
    refreshGlobalAvailableModels();
  };
  const requestDelete = (model: CustomModel) => {
    setModelToDelete(model);
    setDeleteState('ACTIVE');
  };
  // Invoked by the confirmation dialog; failures only log, the dialog
  // closes either way.
  const confirmDelete = async () => {
    if (!modelToDelete) return;
    try {
      await customModelsService.deleteCustomModel(modelToDelete.id, token);
      setModels((prev) => prev.filter((m) => m.id !== modelToDelete.id));
      refreshGlobalAvailableModels();
    } catch (err) {
      console.error('Failed to delete custom model:', err);
    } finally {
      setModelToDelete(null);
      setDeleteState('INACTIVE');
    }
  };
  const getMenuOptions = (model: CustomModel): MenuOption[] => [
    {
      icon: Edit,
      label: t('settings.customModels.actions.edit'),
      onClick: () => openEditModal(model),
      variant: 'primary',
      iconWidth: 14,
      iconHeight: 14,
    },
    {
      icon: Trash,
      label: t('settings.customModels.actions.delete'),
      onClick: () => requestDelete(model),
      variant: 'danger',
      iconWidth: 16,
      iconHeight: 16,
    },
  ];
  // Case-insensitive filter over display name and upstream model id.
  const filteredModels = models.filter((model) => {
    const q = searchTerm.toLowerCase();
    return (
      model.display_name.toLowerCase().includes(q) ||
      model.upstream_model_id.toLowerCase().includes(q)
    );
  });
  const renderEmptyState = () => (
    <div className="flex w-full flex-col items-center justify-center py-12">
      <img
        src={isDarkTheme ? NoFilesDarkIcon : NoFilesIcon}
        alt={t('settings.customModels.empty')}
        className="mx-auto mb-6 h-32 w-32"
      />
      <p className="text-center text-lg text-gray-500 dark:text-gray-400">
        {t('settings.customModels.empty')}
      </p>
    </div>
  );
  return (
    <div className="mt-8">
      <div className="relative flex flex-col">
        <p className="text-muted-foreground mb-5 text-[15px] leading-6">
          {t('settings.customModels.subtitle')}
        </p>
        <div className="my-3 flex flex-col items-start gap-3 sm:flex-row sm:items-center sm:justify-between">
          <div className="relative w-full max-w-md">
            <img
              src={SearchIcon}
              alt=""
              className="absolute top-1/2 left-4 h-5 w-5 -translate-y-1/2 opacity-40"
            />
            <input
              maxLength={256}
              placeholder={t('settings.customModels.searchPlaceholder')}
              name="custom-models-search-input"
              type="text"
              id="custom-models-search-input"
              value={searchTerm}
              onChange={(e) => setSearchTerm(e.target.value)}
              className="border-border bg-card text-foreground placeholder:text-muted-foreground h-11 w-full rounded-full border py-2 pr-5 pl-11 text-sm shadow-[0_1px_4px_rgba(0,0,0,0.06)] transition-shadow outline-none focus:shadow-[0_2px_8px_rgba(0,0,0,0.1)] dark:shadow-none"
            />
          </div>
          <button
            className="bg-primary hover:bg-primary/90 flex h-11 min-w-[108px] items-center justify-center rounded-full px-4 text-sm whitespace-normal text-white"
            onClick={openAddModal}
          >
            {t('settings.customModels.addModel')}
          </button>
        </div>
        <div className="border-border dark:border-border mt-5 mb-8 border-b" />
        {loading ? (
          <div className="flex flex-wrap justify-center gap-4 sm:justify-start">
            <SkeletonLoader component="toolCards" count={3} />
          </div>
        ) : (
          <div className="flex flex-wrap justify-center gap-4 sm:justify-start">
            {filteredModels.length === 0
              ? renderEmptyState()
              : filteredModels.map((model) => (
                  <div
                    key={model.id}
                    className="bg-muted hover:bg-accent relative flex w-[300px] flex-col overflow-hidden rounded-2xl p-5"
                  >
                    <div
                      ref={menuRefs.current[model.id]}
                      onClick={(e) => {
                        e.stopPropagation();
                        setActiveMenuId(
                          activeMenuId === model.id ? null : model.id,
                        );
                      }}
                      className="absolute top-3 right-3 z-10 cursor-pointer"
                    >
                      <img
                        src={ThreeDotsIcon}
                        alt={t('settings.customModels.actionsMenuAria', {
                          modelName: model.display_name,
                        })}
                        className="h-[19px] w-[19px]"
                      />
                      <ContextMenu
                        isOpen={activeMenuId === model.id}
                        setIsOpen={(isOpen) => {
                          setActiveMenuId(isOpen ? model.id : null);
                        }}
                        options={getMenuOptions(model)}
                        anchorRef={menuRefs.current[model.id]}
                        position="bottom-right"
                        offset={{ x: 0, y: 0 }}
                      />
                    </div>
                    <div className="w-full pr-7">
                      <div className="flex items-center gap-2">
                        <p
                          title={model.display_name}
                          className="text-foreground dark:text-foreground truncate text-[15px] leading-snug font-semibold"
                        >
                          {model.display_name}
                        </p>
                        {!model.enabled && (
                          <span className="bg-muted-foreground/15 text-muted-foreground shrink-0 rounded-full px-2 py-0.5 text-[10px] leading-none font-medium">
                            {t('settings.customModels.disabledBadge')}
                          </span>
                        )}
                      </div>
                      <div className="mt-3 space-y-1.5">
                        <div
                          className="text-muted-foreground/80 flex items-center gap-1.5 text-xs leading-relaxed"
                          title={model.upstream_model_id}
                        >
                          <Tag className="h-3.5 w-3.5 shrink-0 opacity-70" />
                          <span className="truncate">
                            {model.upstream_model_id}
                          </span>
                        </div>
                        <div
                          className="text-muted-foreground/80 flex items-center gap-1.5 text-xs leading-relaxed"
                          title={model.base_url}
                        >
                          <Globe className="h-3.5 w-3.5 shrink-0 opacity-70" />
                          <span className="truncate">
                            {formatBaseUrlHost(model.base_url)}
                          </span>
                        </div>
                      </div>
                    </div>
                  </div>
                ))}
          </div>
        )}
      </div>
      <CustomModelModal
        modalState={modalState}
        setModalState={setModalState}
        model={editingModel}
        onSaved={handleSaved}
      />
      <ConfirmationModal
        message={t('settings.customModels.deleteWarning', {
          modelName: modelToDelete?.display_name || '',
        })}
        modalState={deleteState}
        setModalState={setDeleteState}
        handleSubmit={confirmDelete}
        submitLabel={t('settings.customModels.actions.delete')}
        variant="danger"
      />
    </div>
  );
}

View File

@@ -21,6 +21,7 @@ import {
setSourceDocs,
} from '../preferences/preferenceSlice';
import Analytics from './Analytics';
import CustomModels from './CustomModels';
import Sources from './Sources';
import General from './General';
import Logs from './Logs';
@@ -39,6 +40,8 @@ export default function Settings() {
return t('settings.analytics.label');
if (path.includes('/settings/logs')) return t('settings.logs.label');
if (path.includes('/settings/tools')) return t('settings.tools.label');
if (path.includes('/settings/custom-models'))
return t('settings.customModels.label');
return t('settings.general.label');
};
@@ -52,6 +55,8 @@ export default function Settings() {
navigate('/settings/analytics');
else if (tab === t('settings.logs.label')) navigate('/settings/logs');
else if (tab === t('settings.tools.label')) navigate('/settings/tools');
else if (tab === t('settings.customModels.label'))
navigate('/settings/custom-models');
};
React.useEffect(() => {
@@ -113,6 +118,7 @@ export default function Settings() {
<Route path="analytics" element={<Analytics />} />
<Route path="logs" element={<Logs />} />
<Route path="tools" element={<Tools />} />
<Route path="custom-models" element={<CustomModels />} />
<Route path="*" element={<Navigate to="/settings" replace />} />
</Routes>
</div>

View File

@@ -116,7 +116,7 @@ def test_execute_agent_node_normalizes_wrapped_schema_before_agent_create(monkey
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _model_id: {"supports_structured_output": True},
lambda _model_id, **_kwargs: {"supports_structured_output": True},
)
list(engine._execute_agent_node(node))
@@ -242,7 +242,7 @@ def test_validate_workflow_structure_rejects_unsupported_structured_output_model
monkeypatch.setattr(
workflow_routes,
"get_model_capabilities",
lambda _model_id: {"supports_structured_output": False},
lambda _model_id, **_kwargs: {"supports_structured_output": False},
)
nodes = [
@@ -296,7 +296,7 @@ def test_execute_agent_node_raises_when_structured_output_violates_schema(monkey
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _model_id: {"supports_structured_output": True},
lambda _model_id, **_kwargs: {"supports_structured_output": True},
)
with pytest.raises(ValueError, match="Structured output did not match schema"):
@@ -322,7 +322,7 @@ def test_execute_agent_node_raises_when_schema_set_and_response_not_json(monkeyp
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _model_id: {"supports_structured_output": True},
lambda _model_id, **_kwargs: {"supports_structured_output": True},
)
with pytest.raises(
@@ -360,11 +360,11 @@ class TestWorkflowEngineAdditionalCoverage:
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda _: None,
lambda _, **_kwargs: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _: None,
lambda _, **_kwargs: None,
)
list(engine._execute_agent_node(node))
@@ -390,11 +390,11 @@ class TestWorkflowEngineAdditionalCoverage:
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda _: "openai",
lambda _, **_kwargs: "openai",
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _: None,
lambda _, **_kwargs: None,
)
list(engine._execute_agent_node(node))
@@ -416,11 +416,11 @@ class TestWorkflowEngineAdditionalCoverage:
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda _: "openai",
lambda _, **_kwargs: "openai",
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _: {"supports_structured_output": False},
lambda _, **_kwargs: {"supports_structured_output": False},
)
with pytest.raises(ValueError, match="does not support structured output"):
@@ -450,11 +450,11 @@ class TestWorkflowEngineAdditionalCoverage:
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda _: None,
lambda _, **_kwargs: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _: {"supports_structured_output": True},
lambda _, **_kwargs: {"supports_structured_output": True},
)
list(engine._execute_agent_node(node))
@@ -481,11 +481,11 @@ class TestWorkflowEngineAdditionalCoverage:
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda _: None,
lambda _, **_kwargs: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _: {"supports_structured_output": True},
lambda _, **_kwargs: {"supports_structured_output": True},
)
with pytest.raises(

View File

@@ -162,8 +162,10 @@ class TestCompressIfNeeded:
current_query_tokens=1000,
)
# user_id flows through so BYOM custom-model UUIDs resolve to
# the user's declared context window in the threshold check.
mock_threshold_checker.should_compress.assert_called_once_with(
sample_conversation, "gpt-4", 1000
sample_conversation, "gpt-4", 1000, user_id="user1"
)
@@ -270,8 +272,12 @@ class TestPerformCompression:
)
orch._perform_compression("c1", conversation, "gpt-4", decoded_token)
# Verify the override model was used
mock_get_provider.assert_called_with("gpt-3.5-turbo")
# Verify the override model was used. user_id flows from
# decoded_token['sub'] so per-user BYOM custom-model UUIDs
# resolve.
mock_get_provider.assert_called_with(
"gpt-3.5-turbo", user_id=decoded_token["sub"]
)
@patch(
"application.api.answer.services.compression.orchestrator.get_provider_from_model_id"
@@ -362,7 +368,12 @@ class TestCompressMidExecution:
)
mock_perform.assert_called_once_with(
"conv1", sample_conversation, "gpt-4", decoded_token
"conv1",
sample_conversation,
"gpt-4",
decoded_token,
user_id="user1",
model_user_id=None,
)
def test_loads_conversation_when_not_provided(

View File

@@ -0,0 +1,688 @@
"""Tests for the BYOM REST API at /api/user/models."""
from __future__ import annotations
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
@pytest.fixture
def app():
    """Bare Flask app used only to build request contexts in these tests."""
    return Flask(__name__)
@contextmanager
def _patch_db(conn):
    """Patch the routes' db helpers to yield the given pg connection."""

    @contextmanager
    def _yield_conn():
        # Hand every caller the shared test connection instead of a
        # real pooled one.
        yield conn

    # Both the read/write (db_session) and read-only (db_readonly)
    # helpers get the same stand-in so every route code path is covered.
    with patch(
        "application.api.user.models.routes.db_session", _yield_conn
    ), patch(
        "application.api.user.models.routes.db_readonly", _yield_conn
    ):
        yield
@pytest.fixture(autouse=True)
def _reset_registry():
    """Reset the process-global ModelRegistry singleton around every test
    so cached per-user model state never leaks between tests."""
    from application.core.model_registry import ModelRegistry

    ModelRegistry.reset()
    yield
    ModelRegistry.reset()
# Auth
@pytest.mark.unit
class TestAuth:
    """Unauthenticated requests (decoded_token is None) must get 401."""

    def test_list_unauthenticated_returns_401(self, app):
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        with app.test_request_context("/api/user/models"):
            from flask import request

            # Simulate a request whose auth token failed to decode.
            request.decoded_token = None
            resp = UserModelsCollectionResource().get()
        assert resp.status_code == 401

    def test_create_unauthenticated_returns_401(self, app):
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        with app.test_request_context(
            "/api/user/models",
            method="POST",
            json={
                "upstream_model_id": "x",
                "display_name": "x",
                "base_url": "https://api.openai.com/v1",
                "api_key": "k",
            },
        ):
            from flask import request

            request.decoded_token = None
            resp = UserModelsCollectionResource().post()
        assert resp.status_code == 401
# Create
@pytest.mark.unit
class TestCreate:
    """POST /api/user/models — happy path, validation, and SSRF checks."""

    def test_creates_and_returns_201_without_api_key(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        # Mock DNS so the SSRF check passes for api.mistral.ai without
        # hitting the network.
        with patch("application.security.safe_url.socket.getaddrinfo") as gai:
            gai.return_value = [
                (None, None, None, None, ("104.18.0.1", 0))
            ]
            with app.test_request_context(
                "/api/user/models",
                method="POST",
                json={
                    "upstream_model_id": "mistral-large-latest",
                    "display_name": "My Mistral",
                    "base_url": "https://api.mistral.ai/v1",
                    "api_key": "sk-mistral-test",
                    "capabilities": {
                        "supports_tools": True,
                        "context_window": 128000,
                    },
                },
            ):
                from flask import request

                request.decoded_token = {"sub": "user-1"}
                with _patch_db(pg_conn):
                    resp = UserModelsCollectionResource().post()
        assert resp.status_code == 201
        body = resp.get_json()
        assert body["upstream_model_id"] == "mistral-large-latest"
        assert body["source"] == "user"
        # Critical: api_key must NEVER appear in the response
        assert "api_key" not in body
        for v in body.values():
            assert v != "sk-mistral-test"

    def test_create_rejects_missing_required_fields(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        with app.test_request_context(
            "/api/user/models",
            method="POST",
            json={"upstream_model_id": "x"},  # missing the others
        ):
            from flask import request

            request.decoded_token = {"sub": "user-1"}
            with _patch_db(pg_conn):
                resp = UserModelsCollectionResource().post()
        assert resp.status_code == 400

    def test_create_rejects_loopback_url(self, app, pg_conn):
        # Loopback base_url is an SSRF vector — must be rejected.
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        with app.test_request_context(
            "/api/user/models",
            method="POST",
            json={
                "upstream_model_id": "x",
                "display_name": "x",
                "base_url": "https://127.0.0.1/v1",
                "api_key": "k",
            },
        ):
            from flask import request

            request.decoded_token = {"sub": "user-1"}
            with _patch_db(pg_conn):
                resp = UserModelsCollectionResource().post()
        assert resp.status_code == 400
        body = resp.get_json()
        assert "error" in body

    def test_create_rejects_unknown_attachment_alias(self, app, pg_conn):
        """The UI sends ``["image"]`` as an alias; bad strings ("video",
        typos) must reject at the boundary so the DB never holds
        garbage that the registry would later silently drop.
        """
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        with patch("application.security.safe_url.socket.getaddrinfo") as gai:
            gai.return_value = [
                (None, None, None, None, ("104.18.0.1", 0))
            ]
            with app.test_request_context(
                "/api/user/models",
                method="POST",
                json={
                    "upstream_model_id": "m",
                    "display_name": "M",
                    "base_url": "https://api.mistral.ai/v1",
                    "api_key": "k",
                    "capabilities": {"attachments": ["video"]},
                },
            ):
                from flask import request

                request.decoded_token = {"sub": "user-1"}
                with _patch_db(pg_conn):
                    resp = UserModelsCollectionResource().post()
        assert resp.status_code == 400
        body = resp.get_json()
        # The offending alias is echoed back in the error message.
        assert "video" in body["error"]

    def test_create_accepts_image_alias_and_raw_mime(self, app, pg_conn):
        """The known ``image`` alias and raw MIME types both pass."""
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        with patch("application.security.safe_url.socket.getaddrinfo") as gai:
            gai.return_value = [
                (None, None, None, None, ("104.18.0.1", 0))
            ]
            with app.test_request_context(
                "/api/user/models",
                method="POST",
                json={
                    "upstream_model_id": "m",
                    "display_name": "M",
                    "base_url": "https://api.mistral.ai/v1",
                    "api_key": "k",
                    "capabilities": {
                        "attachments": ["image", "application/pdf"],
                    },
                },
            ):
                from flask import request

                request.decoded_token = {"sub": "user-1"}
                with _patch_db(pg_conn):
                    resp = UserModelsCollectionResource().post()
        assert resp.status_code == 201

    def test_create_rejects_private_ip_dns(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        with patch("application.security.safe_url.socket.getaddrinfo") as gai:
            # Hostname resolves to a private IP only — must reject
            gai.return_value = [
                (None, None, None, None, ("10.0.0.5", 0))
            ]
            with app.test_request_context(
                "/api/user/models",
                method="POST",
                json={
                    "upstream_model_id": "x",
                    "display_name": "x",
                    "base_url": "https://evil.example.com/v1",
                    "api_key": "k",
                },
            ):
                from flask import request

                request.decoded_token = {"sub": "user-1"}
                with _patch_db(pg_conn):
                    resp = UserModelsCollectionResource().post()
        assert resp.status_code == 400
# List / get / patch / delete
def _create_via_repo(pg_conn, user_id="user-1", **kwargs):
    """Insert a BYOM row straight through the repository layer.

    Any of the standard fields may be overridden via keyword arguments;
    everything else in ``kwargs`` is forwarded to ``create`` untouched.
    """
    from application.storage.db.repositories.user_custom_models import (
        UserCustomModelsRepository,
    )

    repo = UserCustomModelsRepository(pg_conn)
    return repo.create(
        user_id=user_id,
        upstream_model_id=kwargs.pop("upstream_model_id", "mistral-large-latest"),
        display_name=kwargs.pop("display_name", "My Mistral"),
        base_url=kwargs.pop("base_url", "https://api.mistral.ai/v1"),
        api_key_plaintext=kwargs.pop("api_key_plaintext", "sk-mistral-test"),
        **kwargs,
    )
@pytest.mark.unit
class TestList:
    """GET /api/user/models — rows are scoped to the caller only."""

    def test_lists_only_users_own(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        _create_via_repo(pg_conn, user_id="alice", upstream_model_id="alice-1")
        _create_via_repo(pg_conn, user_id="bob", upstream_model_id="bob-1")
        with app.test_request_context("/api/user/models"):
            from flask import request

            request.decoded_token = {"sub": "alice"}
            with _patch_db(pg_conn):
                resp = UserModelsCollectionResource().get()
        assert resp.status_code == 200
        body = resp.get_json()
        upstream_ids = {m["upstream_model_id"] for m in body["models"]}
        assert upstream_ids == {"alice-1"}
        # Never expose the api_key
        for m in body["models"]:
            assert "api_key" not in m
@pytest.mark.unit
class TestGet:
    """GET /api/user/models/<id> — cross-user access looks like 404."""

    def test_returns_404_for_other_users_model(self, app, pg_conn):
        from application.api.user.models.routes import UserModelResource

        created = _create_via_repo(pg_conn, user_id="alice")
        with app.test_request_context(
            f"/api/user/models/{created['id']}"
        ):
            from flask import request

            # Bob must not even learn the record exists.
            request.decoded_token = {"sub": "bob"}
            with _patch_db(pg_conn):
                resp = UserModelResource().get(model_id=created["id"])
        assert resp.status_code == 404
@pytest.mark.unit
class TestPatch:
    """PATCH /api/user/models/<id> — partial updates."""

    def test_patch_updates_display_name(self, app, pg_conn):
        from application.api.user.models.routes import UserModelResource

        created = _create_via_repo(pg_conn, user_id="user-1")
        with app.test_request_context(
            f"/api/user/models/{created['id']}",
            method="PATCH",
            json={"display_name": "Renamed"},
        ):
            from flask import request

            request.decoded_token = {"sub": "user-1"}
            with _patch_db(pg_conn):
                resp = UserModelResource().patch(model_id=created["id"])
        assert resp.status_code == 200
        body = resp.get_json()
        assert body["display_name"] == "Renamed"

    def test_patch_blank_api_key_keeps_existing(self, app, pg_conn):
        """Critical PATCH semantic: empty/missing api_key in body must
        preserve the stored ciphertext (the UI sends a blank password
        field when the user wants to keep the existing key)."""
        from application.api.user.models.routes import UserModelResource
        from application.storage.db.repositories.user_custom_models import (
            UserCustomModelsRepository,
        )

        created = _create_via_repo(pg_conn, user_id="user-1")
        # Key supplied by _create_via_repo's default.
        original_key_plaintext = "sk-mistral-test"
        with app.test_request_context(
            f"/api/user/models/{created['id']}",
            method="PATCH",
            json={"display_name": "Just rename me", "api_key": ""},
        ):
            from flask import request

            request.decoded_token = {"sub": "user-1"}
            with _patch_db(pg_conn):
                resp = UserModelResource().patch(model_id=created["id"])
        assert resp.status_code == 200
        # Decrypted key is unchanged
        repo = UserCustomModelsRepository(pg_conn)
        assert (
            repo.get_decrypted_api_key(created["id"], "user-1")
            == original_key_plaintext
        )
@pytest.mark.unit
class TestDelete:
    """DELETE /api/user/models/<id> — row removal plus cache eviction."""

    def test_delete_removes_row_and_invalidates_cache(self, app, pg_conn):
        from application.api.user.models.routes import UserModelResource
        from application.core.model_registry import ModelRegistry

        created = _create_via_repo(pg_conn, user_id="user-1")
        # Warm the registry's per-user cache via a lookup
        with patch(
            "application.storage.db.session.db_readonly"
        ) as ro:

            @contextmanager
            def _y():
                yield pg_conn

            ro.side_effect = _y
            ModelRegistry.get_instance().get_model(created["id"], user_id="user-1")
        # Now delete via the route — invalidation must happen so a
        # subsequent lookup misses
        with app.test_request_context(
            f"/api/user/models/{created['id']}", method="DELETE"
        ):
            from flask import request

            request.decoded_token = {"sub": "user-1"}
            with _patch_db(pg_conn):
                resp = UserModelResource().delete(model_id=created["id"])
        assert resp.status_code == 200
        # Cache invalidated → next lookup re-queries DB and finds nothing
        assert "user-1" not in ModelRegistry.get_instance()._user_models
# /api/models combined view
@pytest.mark.unit
class TestSecurityCreateRejectsBlankFields:
    """P1 #1 partial: blank api_key on create must be rejected so we
    can never end up with an unroutable BYOM record that would cause
    LLMCreator to leak settings.API_KEY to the user-supplied URL."""

    def test_create_rejects_blank_api_key(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelsCollectionResource,
        )

        with patch("application.security.safe_url.socket.getaddrinfo") as gai:
            gai.return_value = [(None, None, None, None, ("104.18.0.1", 0))]
            with app.test_request_context(
                "/api/user/models",
                method="POST",
                json={
                    "upstream_model_id": "x",
                    "display_name": "x",
                    "base_url": "https://api.mistral.ai/v1",
                    "api_key": " ",  # whitespace only
                },
            ):
                from flask import request

                request.decoded_token = {"sub": "u"}
                with _patch_db(pg_conn):
                    resp = UserModelsCollectionResource().post()
        assert resp.status_code == 400
        body = resp.get_json()
        # The error names the offending field.
        assert "api_key" in (body.get("error") or "").lower()

    def test_patch_rejects_blank_required_field(self, app, pg_conn):
        from application.api.user.models.routes import UserModelResource
        from application.storage.db.repositories.user_custom_models import (
            UserCustomModelsRepository,
        )

        created = UserCustomModelsRepository(pg_conn).create(
            user_id="u",
            upstream_model_id="x",
            display_name="x",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="sk-x",
        )
        with app.test_request_context(
            f"/api/user/models/{created['id']}",
            method="PATCH",
            json={"base_url": " "},
        ):
            from flask import request

            request.decoded_token = {"sub": "u"}
            with _patch_db(pg_conn):
                resp = UserModelResource().patch(model_id=created["id"])
        assert resp.status_code == 400
@pytest.mark.unit
class TestPayloadConnectionTest:
    """Verifies the payload-based test endpoint. Lets the UI's 'Test
    connection' button work *before* the model is saved — operators
    expect to validate their endpoint + key before committing."""

    def test_payload_test_rejects_unsafe_url(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelTestPayloadResource,
        )

        with app.test_request_context(
            "/api/user/models/test",
            method="POST",
            json={
                "base_url": "https://127.0.0.1/v1",
                "api_key": "sk-anything",
                "upstream_model_id": "x",
            },
        ):
            from flask import request

            request.decoded_token = {"sub": "user-1"}
            with _patch_db(pg_conn):
                resp = UserModelTestPayloadResource().post()
        assert resp.status_code == 400
        body = resp.get_json()
        assert body["ok"] is False

    def test_payload_test_returns_ok_when_upstream_responds_2xx(
        self, app, pg_conn
    ):
        from application.api.user.models.routes import (
            UserModelTestPayloadResource,
        )

        # pinned_post is the IP-pinned dispatch helper. Patching it
        # bypasses both the SSRF guard and the network — the success
        # path we're verifying here is the route's response handling.
        with patch("application.api.user.models.routes.pinned_post") as rp:
            rp.return_value = MagicMock(
                status_code=200,
                headers={"Content-Type": "application/json"},
                text='{"ok": true}',
            )
            with app.test_request_context(
                "/api/user/models/test",
                method="POST",
                json={
                    "base_url": "https://api.mistral.ai/v1",
                    "api_key": "sk-mistral-test",
                    "upstream_model_id": "mistral-large-latest",
                },
            ):
                from flask import request

                request.decoded_token = {"sub": "user-1"}
                with _patch_db(pg_conn):
                    resp = UserModelTestPayloadResource().post()
        assert resp.status_code == 200
        body = resp.get_json()
        assert body["ok"] is True
        # Verify the upstream call carried the user's submitted key (not
        # whatever's in the DB) and the right model name.
        call_args = rp.call_args
        assert call_args.kwargs["headers"]["Authorization"] == "Bearer sk-mistral-test"
        assert call_args.kwargs["json"]["model"] == "mistral-large-latest"

    def test_payload_test_unauthenticated_returns_401(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelTestPayloadResource,
        )

        with app.test_request_context(
            "/api/user/models/test",
            method="POST",
            json={
                "base_url": "https://api.mistral.ai/v1",
                "api_key": "k",
                "upstream_model_id": "x",
            },
        ):
            from flask import request

            request.decoded_token = None
            with _patch_db(pg_conn):
                resp = UserModelTestPayloadResource().post()
        assert resp.status_code == 401

    def test_payload_test_missing_fields_returns_400(self, app, pg_conn):
        from application.api.user.models.routes import (
            UserModelTestPayloadResource,
        )

        with app.test_request_context(
            "/api/user/models/test",
            method="POST",
            json={"base_url": "https://api.mistral.ai/v1"},
        ):
            from flask import request

            request.decoded_token = {"sub": "user-1"}
            with _patch_db(pg_conn):
                resp = UserModelTestPayloadResource().post()
        assert resp.status_code == 400
@pytest.mark.unit
class TestByIdConnectionTestAcceptsOverrides:
    """P3: in edit mode the modal sends current form state as overrides
    so the test reflects in-flight edits (not the saved record)."""

    def _make_row(self, pg_conn):
        # Seed one stored BYOM row whose values differ from every
        # override used below, so stored-vs-override is unambiguous.
        from application.storage.db.repositories.user_custom_models import (
            UserCustomModelsRepository,
        )

        return UserCustomModelsRepository(pg_conn).create(
            user_id="u",
            upstream_model_id="stored-model",
            display_name="Stored",
            base_url="https://stored.example.com/v1",
            api_key_plaintext="sk-stored",
        )

    def _post_test(self, app, pg_conn, model_id, body):
        # POST the by-id test endpoint with pinned_post mocked, and
        # return the mock's call_args for the caller to assert on.
        from application.api.user.models.routes import UserModelTestResource

        with patch("application.api.user.models.routes.pinned_post") as rp:
            rp.return_value = MagicMock(
                status_code=200,
                headers={"Content-Type": "application/json"},
                text='{"ok": true}',
            )
            with app.test_request_context(
                f"/api/user/models/{model_id}/test", method="POST", json=body
            ):
                from flask import request

                request.decoded_token = {"sub": "u"}
                with _patch_db(pg_conn):
                    UserModelTestResource().post(model_id=model_id)
        return rp.call_args

    def test_overrides_win_when_supplied(self, app, pg_conn):
        row = self._make_row(pg_conn)
        ca = self._post_test(
            app,
            pg_conn,
            row["id"],
            {
                "base_url": "https://new.example.com/v1",
                "api_key": "sk-new",
                "upstream_model_id": "new-model",
            },
        )
        assert ca.args[0] == "https://new.example.com/v1/chat/completions"
        assert ca.kwargs["headers"]["Authorization"] == "Bearer sk-new"
        assert ca.kwargs["json"]["model"] == "new-model"

    def test_blank_overrides_fall_back_to_stored(self, app, pg_conn):
        """The classic edit-mode flow: user changed base_url, left
        api_key blank — server uses the new URL but the stored key."""
        row = self._make_row(pg_conn)
        ca = self._post_test(
            app,
            pg_conn,
            row["id"],
            {
                "base_url": "https://new.example.com/v1",
                "api_key": "",
                "upstream_model_id": "",
            },
        )
        assert ca.args[0] == "https://new.example.com/v1/chat/completions"
        # Stored key (decrypted) was used.
        assert ca.kwargs["headers"]["Authorization"] == "Bearer sk-stored"
        # Stored upstream_model_id was used.
        assert ca.kwargs["json"]["model"] == "stored-model"

    def test_empty_body_uses_all_stored_values(self, app, pg_conn):
        row = self._make_row(pg_conn)
        ca = self._post_test(app, pg_conn, row["id"], {})
        assert ca.args[0] == "https://stored.example.com/v1/chat/completions"
        assert ca.kwargs["headers"]["Authorization"] == "Bearer sk-stored"
        assert ca.kwargs["json"]["model"] == "stored-model"
@pytest.mark.unit
class TestApiModelsListWithUser:
    """GET /api/models — the combined built-in + BYOM listing."""

    def test_includes_user_models_when_authenticated(self, app, pg_conn):
        """GET /api/models with auth should surface the user's BYOM
        records alongside built-ins, each tagged with `source`."""
        from application.api.user.models.routes import ModelsListResource
        from application.core.model_registry import ModelRegistry

        created = _create_via_repo(
            pg_conn, user_id="user-1", display_name="My Mistral"
        )

        # Patch the *registry's* db_readonly so the per-user layer load
        # uses the test connection.
        @contextmanager
        def _yield():
            yield pg_conn

        with patch(
            "application.storage.db.session.db_readonly", _yield
        ):
            ModelRegistry.reset()
            with app.test_request_context("/api/models"):
                from flask import request

                request.decoded_token = {"sub": "user-1"}
                resp = ModelsListResource().get()
        assert resp.status_code == 200
        body = resp.get_json()
        ids = [m["id"] for m in body["models"]]
        assert created["id"] in ids
        # Source label tags it for the UI
        user_entries = [m for m in body["models"] if m["id"] == created["id"]]
        assert user_entries[0]["source"] == "user"
        # Built-ins still present
        assert any(m.get("source") == "builtin" for m in body["models"])

View File

@@ -122,6 +122,11 @@ def mock_llm():
llm._supports_tools = True
llm._supports_structured_output = Mock(return_value=False)
llm.__class__.__name__ = "MockLLM"
# Mirror BaseLLM.__init__: real LLMCreator stores the resolved
# upstream model name on self.model_id. Tests that build agents via
# ``mock_llm_creator`` rely on the agent's ``upstream_model_id``
# falling through to this attribute.
llm.model_id = "gpt-4"
return llm

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,505 @@
"""Tests for the BYOM per-user layer on ModelRegistry.
Covers: per-user lookups don't leak across users, lookups without
user_id stay built-in only, get_all_models / get_enabled_models /
model_exists all consult the user layer when given user_id, and the
explicit invalidate_user clears the cache.
"""
from __future__ import annotations
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from application.core.model_registry import ModelRegistry
from application.core.model_settings import ModelProvider
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
@pytest.fixture(autouse=True)
def _reset_registry():
    # Drop the singleton before and after each test so cached per-user
    # entries from one test can't satisfy lookups in the next.
    ModelRegistry.reset()
    yield
    ModelRegistry.reset()
def _make_settings(**overrides):
s = MagicMock()
s.OPENAI_BASE_URL = None
s.OPENAI_API_KEY = None
s.OPENAI_API_BASE = None
s.ANTHROPIC_API_KEY = None
s.GOOGLE_API_KEY = None
s.GROQ_API_KEY = None
s.OPEN_ROUTER_API_KEY = None
s.NOVITA_API_KEY = None
s.HUGGINGFACE_API_KEY = None
s.LLM_PROVIDER = ""
s.LLM_NAME = None
s.API_KEY = None
s.MODELS_CONFIG_DIR = None
for k, v in overrides.items():
setattr(s, k, v)
return s
@contextmanager
def _yield(conn):
    # Minimal context manager wrapping an existing pg connection so it
    # can stand in for ``db_readonly`` in the patches below.
    yield conn
@pytest.mark.unit
class TestPerUserLayer:
    """Per-user BYOM layer on ModelRegistry: isolation, listing,
    decrypted fields, and explicit cache invalidation."""

    def test_user_models_isolated_per_user(self, pg_conn):
        """Alice's BYOM model must not appear in Bob's lookups."""
        repo = UserCustomModelsRepository(pg_conn)
        alice_model = repo.create(
            user_id="alice",
            upstream_model_id="alice-mistral",
            display_name="Alice Mistral",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="sk-alice",
        )
        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ):
            reg = ModelRegistry()
            assert reg.get_model(alice_model["id"], user_id="alice") is not None
            assert reg.get_model(alice_model["id"], user_id="bob") is None
            # And without a user_id at all, the per-user layer is invisible
            assert reg.get_model(alice_model["id"]) is None

    def test_get_all_models_includes_user_models(self, pg_conn):
        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-1",
            upstream_model_id="mistral-large-latest",
            display_name="My Mistral",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="sk-test",
        )
        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ):
            reg = ModelRegistry()
            ids_no_user = {m.id for m in reg.get_all_models()}
            ids_with_user = {
                m.id for m in reg.get_all_models(user_id="user-1")
            }
            # BYOM rows surface only when a user_id is supplied.
            assert created["id"] not in ids_no_user
            assert created["id"] in ids_with_user

    def test_user_models_carry_decrypted_api_key_and_upstream_id(self, pg_conn):
        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-1",
            upstream_model_id="mistral-large-latest",
            display_name="My Mistral",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="sk-test-XYZ",
        )
        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ):
            reg = ModelRegistry()
            m = reg.get_model(created["id"], user_id="user-1")
            assert m is not None
            assert m.provider == ModelProvider.OPENAI_COMPATIBLE
            assert m.upstream_model_id == "mistral-large-latest"
            assert m.api_key == "sk-test-XYZ"
            assert m.base_url == "https://api.mistral.ai/v1"
            assert m.source == "user"
            # The wire format never leaks the api_key
            d = m.to_dict()
            assert "api_key" not in d
            for v in d.values():
                assert v != "sk-test-XYZ"

    def test_invalidate_user_clears_cache(self, pg_conn):
        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-1",
            upstream_model_id="x",
            display_name="X",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="k",
        )
        s = _make_settings()
        # Stub Redis so invalidate_user can publish its version bump
        # without hitting a real broker. The P1 fix calls ``incr`` on
        # invalidate; here we just need it not to raise.
        fake_redis = MagicMock()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ), patch(
            "application.cache.get_redis_instance", return_value=fake_redis
        ):
            reg = ModelRegistry()
            assert reg.get_model(created["id"], user_id="user-1") is not None
            # Cache populated
            assert "user-1" in reg._user_models
            ModelRegistry.invalidate_user("user-1")
            assert "user-1" not in reg._user_models
            # Re-lookup repopulates
            reg.get_model(created["id"], user_id="user-1")
            assert "user-1" in reg._user_models
@pytest.mark.unit
class TestLLMCreatorDispatchUsesUpstreamModelId:
    """LLMCreator must dispatch BYOM calls with the registry-resolved
    upstream name, key, base_url and capabilities — never the UUID."""

    def test_llmcreator_sends_upstream_id_not_uuid(self, pg_conn):
        """End-to-end: LLMCreator with a BYOM uuid must construct the
        OpenAILLM with the user's upstream model name (e.g.
        ``mistral-large-latest``), not the registry uuid."""
        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-1",
            upstream_model_id="mistral-large-latest",
            display_name="My Mistral",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="sk-mistral-real",
        )
        # Capture the constructor kwargs the provider class receives.
        captured = {}

        class _FakeLLM:
            def __init__(self, api_key, user_api_key, *args, **kwargs):
                captured["api_key"] = api_key
                captured["base_url"] = kwargs.get("base_url")
                captured["model_id"] = kwargs.get("model_id")

        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ):
            ModelRegistry()
            from application.llm.providers import PROVIDERS_BY_NAME

            with patch.object(
                PROVIDERS_BY_NAME["openai_compatible"], "llm_class", _FakeLLM
            ):
                from application.llm.llm_creator import LLMCreator

                LLMCreator.create_llm(
                    type="openai_compatible",
                    api_key="caller-passed-WRONG",
                    user_api_key=None,
                    decoded_token={"sub": "user-1"},
                    model_id=created["id"],
                )
        assert captured["api_key"] == "sk-mistral-real"
        assert captured["base_url"] == "https://api.mistral.ai/v1"
        assert captured["model_id"] == "mistral-large-latest"  # NOT the uuid!

    def test_llmcreator_forwards_byom_capabilities(self, pg_conn):
        """LLMCreator must thread the registry-resolved ``capabilities``
        into the LLM. Without it the OpenAILLM hard-codes ``True`` for
        tools/structured output and advertises image attachments
        unconditionally, leaking unsupported features to BYOMs that
        disabled them."""
        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-2",
            upstream_model_id="my-text-only-model",
            display_name="Text-only BYOM",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="sk-real",
            capabilities={
                "supports_tools": False,
                "supports_structured_output": False,
                "attachments": [],
                "context_window": 8192,
            },
        )
        captured = {}

        class _FakeLLM:
            def __init__(self, api_key, user_api_key, *args, **kwargs):
                captured["capabilities"] = kwargs.get("capabilities")

        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ):
            ModelRegistry()
            from application.llm.providers import PROVIDERS_BY_NAME

            with patch.object(
                PROVIDERS_BY_NAME["openai_compatible"], "llm_class", _FakeLLM
            ):
                from application.llm.llm_creator import LLMCreator

                LLMCreator.create_llm(
                    type="openai_compatible",
                    api_key="ignored",
                    user_api_key=None,
                    decoded_token={"sub": "user-2"},
                    model_id=created["id"],
                )
        caps = captured["capabilities"]
        assert caps is not None
        assert caps.supports_tools is False
        assert caps.supports_structured_output is False
        assert caps.supported_attachment_types == []

    def test_byom_image_alias_expands_to_mime_types(self, pg_conn):
        """A BYOM stored with ``attachments: ["image"]`` (the alias the
        UI sends) must surface as concrete MIME types on the registry
        record, matching the built-in YAML expansion. Without this,
        handlers/base.prepare_messages compares ``image/png`` against
        the bare alias and filters every image upload as unsupported.
        """
        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-1",
            upstream_model_id="my-vision-model",
            display_name="Vision BYOM",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="sk-real",
            capabilities={"attachments": ["image"]},
        )
        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ):
            reg = ModelRegistry()
            model = reg.get_model(created["id"], user_id="user-1")
            assert model is not None
            types = model.capabilities.supported_attachment_types
            assert "image" not in types, (
                "alias must be expanded, not stored verbatim"
            )
            assert any(t.startswith("image/") for t in types)
            # Must include at least the common web image types so any image
            # an end user uploads has a chance to match.
            assert "image/png" in types
            assert "image/jpeg" in types

    def test_byom_unknown_alias_is_skipped_at_runtime(self, pg_conn):
        """Operator alias-map edits could orphan a stored alias. The
        registry must drop the unknown entry rather than the entire
        layer (which would hide every BYOM the user has).
        """
        repo = UserCustomModelsRepository(pg_conn)
        created = repo.create(
            user_id="user-1",
            upstream_model_id="m",
            display_name="M",
            base_url="https://api.mistral.ai/v1",
            api_key_plaintext="k",
            # Bypass the route validation: write a bad alias straight
            # to the row to simulate the post-edit orphan case.
            capabilities={"attachments": ["image", "not-a-real-alias"]},
        )
        s = _make_settings()
        with patch("application.core.settings.settings", s), patch(
            "application.storage.db.session.db_readonly",
            lambda: _yield(pg_conn),
        ):
            reg = ModelRegistry()
            model = reg.get_model(created["id"], user_id="user-1")
            assert model is not None
            types = model.capabilities.supported_attachment_types
            assert "not-a-real-alias" not in types
            assert any(t.startswith("image/") for t in types)
@pytest.mark.unit
class TestCrossProcessInvalidation:
"""The BYOM cache lives per-process. Without the P1 fix, a CRUD on
web-1 would leave the decrypted record (with old api_key/base_url)
cached forever in web-2 / Celery. These tests pin down that:
* ``invalidate_user`` publishes a version bump to Redis
* peers reload when the version they saw at load time is stale
* the local TTL bounds staleness even when Redis is unreachable
* unchanged version + expired TTL extends the entry without a
DB read (the common-case fast path)
"""
def test_invalidate_user_publishes_redis_version_bump(self, pg_conn):
repo = UserCustomModelsRepository(pg_conn)
repo.create(
user_id="user-1",
upstream_model_id="m1",
display_name="M1",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="k",
)
fake_redis = MagicMock()
s = _make_settings()
with patch("application.core.settings.settings", s), patch(
"application.storage.db.session.db_readonly",
lambda: _yield(pg_conn),
), patch(
"application.cache.get_redis_instance", return_value=fake_redis
):
ModelRegistry().get_model("anything", user_id="user-1")
ModelRegistry.invalidate_user("user-1")
fake_redis.incr.assert_called_once_with("byom:registry_version:user-1")
def test_peer_reloads_when_redis_version_changes(self, pg_conn):
"""Two-process simulation. Peer loads at version=0; another
process's CRUD bumps the version and updates Postgres; peer's
next post-TTL access sees the version mismatch and reloads,
picking up the rotated key it never invalidated locally."""
from application.core import model_registry as registry_mod
repo = UserCustomModelsRepository(pg_conn)
created = repo.create(
user_id="user-1",
upstream_model_id="m-orig",
display_name="orig",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="orig-key",
)
state = {"version": 0}
class _FakeRedis:
def get(self, key):
if key == "byom:registry_version:user-1":
return str(state["version"]).encode()
return None
def incr(self, key):
if key == "byom:registry_version:user-1":
state["version"] += 1
s = _make_settings()
# Force TTL to 0 so any subsequent access takes the post-TTL
# path without waiting 60s.
with patch("application.core.settings.settings", s), patch(
"application.storage.db.session.db_readonly",
lambda: _yield(pg_conn),
), patch(
"application.cache.get_redis_instance", return_value=_FakeRedis()
), patch.object(
registry_mod, "_USER_CACHE_TTL_SECONDS", 0.0
):
reg = ModelRegistry()
assert (
reg.get_model(created["id"], user_id="user-1").api_key
== "orig-key"
)
# Another process's CRUD: bump Redis counter + mutate the
# row. Note we deliberately do NOT call ``invalidate_user``
# in this process — that's the whole point of the test.
state["version"] += 1
repo.update(
created["id"],
"user-1",
{"api_key_plaintext": "rotated-key"},
)
assert (
reg.get_model(created["id"], user_id="user-1").api_key
== "rotated-key"
)
def test_ttl_bounds_staleness_when_redis_unavailable(self, pg_conn):
"""Redis down → fall back to TTL-only invalidation. After the
TTL elapses, peers reload regardless."""
from application.core import model_registry as registry_mod
repo = UserCustomModelsRepository(pg_conn)
created = repo.create(
user_id="user-1",
upstream_model_id="m",
display_name="m",
base_url="https://api.mistral.ai/v1",
api_key_plaintext="orig",
)
s = _make_settings()
with patch("application.core.settings.settings", s), patch(
"application.storage.db.session.db_readonly",
lambda: _yield(pg_conn),
), patch(
"application.cache.get_redis_instance", return_value=None
), patch.object(
registry_mod, "_USER_CACHE_TTL_SECONDS", 0.0
):
reg = ModelRegistry()
first = reg.get_model(created["id"], user_id="user-1")
assert first.api_key == "orig"
repo.update(
created["id"],
"user-1",
{"api_key_plaintext": "rotated"},
)
second = reg.get_model(created["id"], user_id="user-1")
assert second.api_key == "rotated"
def test_unchanged_version_extends_ttl_without_db_read(self, pg_conn):
    """Hot path: TTL expires but Redis says no invalidation
    happened — extend the entry without re-reading Postgres."""
    from application.core import model_registry as registry_mod

    repo = UserCustomModelsRepository(pg_conn)
    created = repo.create(
        user_id="user-1",
        upstream_model_id="m",
        display_name="m",
        base_url="https://api.mistral.ai/v1",
        api_key_plaintext="k",
    )
    # Redis always reports the same registry version, so a TTL-expired
    # cache entry should be revalidated cheaply rather than reloaded.
    fake_redis = MagicMock()
    fake_redis.get.return_value = b"7"  # constant version
    db_open_count = {"n": 0}

    # Count every read-only DB session the registry opens.
    @contextmanager
    def _counting_db_readonly():
        db_open_count["n"] += 1
        yield pg_conn

    s = _make_settings()
    with patch("application.core.settings.settings", s), patch(
        "application.storage.db.session.db_readonly",
        _counting_db_readonly,
    ), patch(
        "application.cache.get_redis_instance", return_value=fake_redis
    ), patch.object(
        registry_mod, "_USER_CACHE_TTL_SECONDS", 0.0
    ):
        reg = ModelRegistry()
        reg.get_model(created["id"], user_id="user-1")
        first_open = db_open_count["n"]
        # TTL has expired, but Redis returns the same version we
        # captured at load time → no DB reload.
        reg.get_model(created["id"], user_id="user-1")
        reg.get_model(created["id"], user_id="user-1")
        assert db_open_count["n"] == first_open

View File

@@ -1239,6 +1239,131 @@ class TestPerformInMemoryCompression:
assert messages is not None
assert agent.compressed_summary == "summary"
def test_uses_agent_model_user_id_for_byom_owner_scope(self):
    """Shared-agent BYOM dispatch: the agent's model_id is the owner's
    BYOM UUID and ``decoded_token['sub']`` is a different (caller) user.
    Provider lookup and LLMCreator must run under the OWNER scope so
    the registry hits the owner's per-user layer instead of missing
    and falling back to the deployment default."""
    handler = ConcreteHandler()
    # Minimal agent double: owner scope differs from the caller's sub.
    agent = Mock()
    agent.model_id = "byom-uuid-owner"
    agent.model_user_id = "owner-id"
    agent.user_api_key = None
    agent.decoded_token = {"sub": "caller-id"}
    agent.agent_id = None
    agent.context_limit_reached = False
    agent.current_token_count = 0
    # Canned compression result so the handler's happy path completes.
    mock_metadata = Mock()
    mock_metadata.compressed_token_count = 100
    mock_metadata.original_token_count = 1000
    mock_metadata.compression_ratio = 10.0
    mock_metadata.to_dict.return_value = {"ratio": 10.0}
    mock_service = Mock()
    mock_service.compress_conversation.return_value = mock_metadata
    mock_service.get_compressed_context.return_value = (
        "summary",
        [{"prompt": "q", "response": "a"}],
    )
    # Spies whose call kwargs we inspect after the run.
    provider_lookup = MagicMock(return_value="openai_compatible")
    create_llm_spy = MagicMock(return_value=Mock())
    with patch.object(
        handler,
        "_build_conversation_from_messages",
        return_value={"queries": [{"prompt": "q", "response": "a"}]},
    ), patch(
        "application.core.model_utils.get_provider_from_model_id",
        provider_lookup,
    ), patch(
        "application.core.model_utils.get_api_key_for_provider",
        return_value="key",
    ), patch(
        "application.llm.llm_creator.LLMCreator.create_llm",
        create_llm_spy,
    ), patch(
        "application.api.answer.services.compression.service.CompressionService",
        return_value=mock_service,
    ), patch.object(
        handler,
        "_rebuild_messages_after_compression",
        return_value=[{"role": "system", "content": "rebuilt"}],
    ), patch(
        "application.core.settings.settings",
        MagicMock(COMPRESSION_MODEL_OVERRIDE=None),
    ):
        success, _ = handler._perform_in_memory_compression(
            agent, [{"role": "user", "content": "hi"}]
        )
    assert success is True
    # Provider lookup must use the owner scope, not the caller's sub.
    assert provider_lookup.call_args.kwargs["user_id"] == "owner-id"
    # LLMCreator must receive the owner scope as model_user_id;
    # without this the BYOM UUID resolves under the caller and the
    # owner's registered base_url + api_key are missed.
    assert create_llm_spy.call_args.kwargs["model_user_id"] == "owner-id"
def test_falls_back_to_caller_sub_when_no_model_user_id(self):
    """Built-in or caller-owned BYOM: ``agent.model_user_id`` is None.
    Compression should resolve under the caller's sub, matching the
    pre-fix behavior."""
    handler = ConcreteHandler()
    # model_user_id deliberately unset — the legacy/built-in case.
    agent = Mock()
    agent.model_id = "gpt-4"
    agent.model_user_id = None
    agent.user_api_key = None
    agent.decoded_token = {"sub": "caller-id"}
    agent.agent_id = None
    agent.context_limit_reached = False
    agent.current_token_count = 0
    mock_metadata = Mock()
    mock_metadata.compressed_token_count = 100
    mock_metadata.original_token_count = 1000
    mock_metadata.to_dict.return_value = {}
    mock_service = Mock()
    mock_service.compress_conversation.return_value = mock_metadata
    mock_service.get_compressed_context.return_value = ("summary", [])
    provider_lookup = MagicMock(return_value="openai")
    create_llm_spy = MagicMock(return_value=Mock())
    with patch.object(
        handler,
        "_build_conversation_from_messages",
        return_value={"queries": [{"prompt": "q", "response": "a"}]},
    ), patch(
        "application.core.model_utils.get_provider_from_model_id",
        provider_lookup,
    ), patch(
        "application.core.model_utils.get_api_key_for_provider",
        return_value="key",
    ), patch(
        "application.llm.llm_creator.LLMCreator.create_llm",
        create_llm_spy,
    ), patch(
        "application.api.answer.services.compression.service.CompressionService",
        return_value=mock_service,
    ), patch.object(
        handler,
        "_rebuild_messages_after_compression",
        return_value=[],
    ), patch(
        "application.core.settings.settings",
        MagicMock(COMPRESSION_MODEL_OVERRIDE=None),
    ):
        handler._perform_in_memory_compression(
            agent, [{"role": "user", "content": "hi"}]
        )
    # With no owner scope, resolution happens under the caller's sub.
    assert provider_lookup.call_args.kwargs["user_id"] == "caller-id"
    assert create_llm_spy.call_args.kwargs["model_user_id"] == "caller-id"
# ---------------------------------------------------------------------------
# _perform_mid_execution_compression — additional edge cases

View File

@@ -197,7 +197,7 @@ class TestFallbackLLMResolution:
mock_fallback = StubLLM()
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda mid: "openai",
lambda mid, **_kwargs: "openai",
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
@@ -223,7 +223,7 @@ class TestFallbackLLMResolution:
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda mid: "openai",
lambda mid, **_kwargs: "openai",
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
@@ -262,7 +262,7 @@ class TestFallbackLLMResolution:
def test_backup_provider_not_found_skipped(self, monkeypatch):
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda mid: None,
lambda mid, **_kwargs: None,
)
monkeypatch.setattr(
"application.llm.base.settings",

View File

@@ -11,9 +11,7 @@ import pytest
from application.llm.base import BaseLLM
# ---------------------------------------------------------------------------
# Concrete LLM stubs
# ---------------------------------------------------------------------------
class FakeLLM(BaseLLM):
@@ -59,9 +57,7 @@ class FakeLLM(BaseLLM):
return super().gen_stream(*args, **kwargs)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _noop_decorator(func):
@@ -121,9 +117,7 @@ def patch_model_utils(monkeypatch):
CALL_ARGS = dict(model="test-model", messages=[{"role": "user", "content": "hi"}])
# ---------------------------------------------------------------------------
# Tests — fallback_llm property resolution
# ---------------------------------------------------------------------------
@pytest.mark.integration
@@ -135,7 +129,7 @@ class TestFallbackLLMResolution:
backup_llm = FakeLLM(responses=["backup response"])
patch_model_utils(
get_provider=lambda mid: "openai",
get_provider=lambda mid, **_kwargs: "openai",
get_api_key=lambda prov: "fake-key",
create_llm=lambda type, **kw: backup_llm,
)
@@ -174,7 +168,7 @@ class TestFallbackLLMResolution:
good_backup = FakeLLM(responses=["good backup"])
call_count = {"n": 0}
def fake_get_provider(model_id):
def fake_get_provider(model_id, **_kwargs):
call_count["n"] += 1
if model_id == "bad-model":
return None # unresolvable
@@ -202,9 +196,7 @@ class TestFallbackLLMResolution:
assert primary.fallback_llm is None
# ---------------------------------------------------------------------------
# Tests — non-streaming fallback (gen)
# ---------------------------------------------------------------------------
@pytest.mark.integration
@@ -221,7 +213,7 @@ class TestNonStreamingFallback:
backup = FakeLLM(responses=["backup ok"])
patch_model_utils(
get_provider=lambda mid: "openai",
get_provider=lambda mid, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
@@ -242,9 +234,7 @@ class TestNonStreamingFallback:
primary.gen(**CALL_ARGS)
# ---------------------------------------------------------------------------
# Tests — streaming fallback (gen_stream)
# ---------------------------------------------------------------------------
@pytest.mark.integration
@@ -261,7 +251,7 @@ class TestStreamingFallback:
backup = FakeLLM(stream_chunks=["fallback1", "fallback2"])
patch_model_utils(
get_provider=lambda m: "openai",
get_provider=lambda m, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
@@ -280,7 +270,7 @@ class TestStreamingFallback:
backup = FakeLLM(stream_chunks=["recovery1", "recovery2"])
patch_model_utils(
get_provider=lambda m: "openai",
get_provider=lambda m, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
@@ -305,9 +295,7 @@ class TestStreamingFallback:
list(primary.gen_stream(**CALL_ARGS))
# ---------------------------------------------------------------------------
# Tests — backup model priority over global fallback
# ---------------------------------------------------------------------------
@pytest.mark.integration
@@ -323,7 +311,7 @@ class TestBackupModelPriority:
return backup
patch_model_utils(
get_provider=lambda m: "openai",
get_provider=lambda m, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=fake_create_llm,
)
@@ -346,7 +334,7 @@ class TestBackupModelPriority:
return backup
patch_model_utils(
get_provider=lambda m: "openai",
get_provider=lambda m, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=fake_create_llm,
)
@@ -366,7 +354,7 @@ class TestBackupModelPriority:
global_fallback = FakeLLM(responses=["global ok"])
call_order = []
def fake_get_provider(mid):
def fake_get_provider(mid, **_kwargs):
if mid == "broken-backup":
return "nonexistent_provider"
return "openai"
@@ -401,9 +389,7 @@ class TestBackupModelPriority:
assert call_order == ["broken-backup", "global-model"]
# ---------------------------------------------------------------------------
# Tests — fallback uses its own model_id, not the primary's
# ---------------------------------------------------------------------------
@pytest.mark.integration
@@ -419,7 +405,7 @@ class TestFallbackModelIdOverride:
)
patch_model_utils(
get_provider=lambda m: "groq",
get_provider=lambda m, **_kwargs: "groq",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
@@ -441,7 +427,7 @@ class TestFallbackModelIdOverride:
)
patch_model_utils(
get_provider=lambda m: "groq",
get_provider=lambda m, **_kwargs: "groq",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
@@ -465,7 +451,7 @@ class TestFallbackModelIdOverride:
)
patch_model_utils(
get_provider=lambda m: "groq",
get_provider=lambda m, **_kwargs: "groq",
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
@@ -480,3 +466,161 @@ class TestFallbackModelIdOverride:
assert chunks == ["partial1", "partial2", "recovered"]
assert backup.last_model_received == "groq-gpt-oss-120b"
# Tests — model_user_id (BYOM owner scope) propagates into fallback resolution
@pytest.mark.integration
class TestFallbackModelUserIdScope:
"""A shared agent dispatched by user B but owned by user A stores
A's BYOM UUIDs as backup_models. Without the P2 fix the fallback
property looks up those UUIDs against ``decoded_token['sub']`` (B,
the caller), which can't see A's per-user layer — backups are
silently skipped and the global FALLBACK_* settings are used
instead. These tests pin down that ``model_user_id`` (the owner)
is used both for the registry lookup and for the recursive
``LLMCreator.create_llm`` call."""
def test_backup_lookup_uses_model_user_id_not_caller(
self, patch_model_utils
):
captured = {"user_id": None}
def fake_get_provider(model_id, **kwargs):
captured["user_id"] = kwargs.get("user_id")
return "openai"
backup = FakeLLM(responses=["ok"])
patch_model_utils(
get_provider=fake_get_provider,
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: backup,
)
primary = FakeLLM(
decoded_token={"sub": "caller-bob"},
model_user_id="owner-alice",
backup_models=["alice-byom-uuid"],
)
_ = primary.fallback_llm
assert captured["user_id"] == "owner-alice"
def test_backup_create_llm_receives_model_user_id(self, patch_model_utils):
backup = FakeLLM(responses=["ok"])
captured = {}
def fake_create_llm(type, **kw):
captured["model_user_id"] = kw.get("model_user_id")
captured["model_id"] = kw.get("model_id")
return backup
patch_model_utils(
get_provider=lambda m, **_kwargs: "openai",
get_api_key=lambda p: "k",
create_llm=fake_create_llm,
)
primary = FakeLLM(
decoded_token={"sub": "caller-bob"},
model_user_id="owner-alice",
backup_models=["alice-byom-uuid"],
)
_ = primary.fallback_llm
assert captured["model_user_id"] == "owner-alice"
assert captured["model_id"] == "alice-byom-uuid"
def test_global_fallback_create_llm_receives_model_user_id(
self, monkeypatch, patch_model_utils
):
"""The global FALLBACK_LLM_NAME path must also forward
``model_user_id`` — operators can configure it to a BYOM UUID
that's owned by the same user as the primary model."""
backup = FakeLLM(responses=["ok"])
captured = {}
def fake_create_llm(type, **kw):
captured["model_user_id"] = kw.get("model_user_id")
return backup
patch_model_utils(create_llm=fake_create_llm)
monkeypatch.setattr(
"application.llm.base.settings",
MagicMock(
FALLBACK_LLM_PROVIDER="openai",
FALLBACK_LLM_NAME="some-uuid",
FALLBACK_LLM_API_KEY="k",
API_KEY="k",
),
)
primary = FakeLLM(
decoded_token={"sub": "caller-bob"},
model_user_id="owner-alice",
backup_models=[],
)
_ = primary.fallback_llm
assert captured["model_user_id"] == "owner-alice"
def test_falls_back_to_caller_when_model_user_id_unset(
self, patch_model_utils
):
"""Built-in models / pre-P2 callers don't pass model_user_id.
In that case the caller's sub is still used — preserving
existing behaviour."""
captured = {}
def fake_get_provider(model_id, **kwargs):
captured["user_id"] = kwargs.get("user_id")
return "openai"
patch_model_utils(
get_provider=fake_get_provider,
get_api_key=lambda p: "k",
create_llm=lambda type, **kw: FakeLLM(responses=["ok"]),
)
primary = FakeLLM(
decoded_token={"sub": "caller-bob"},
model_user_id=None,
backup_models=["some-builtin-id"],
)
_ = primary.fallback_llm
assert captured["user_id"] == "caller-bob"
# Tests — LLMCreator wires model_user_id through to BaseLLM
@pytest.mark.unit
class TestLLMCreatorPassesModelUserId:
    """End-to-end through ``LLMCreator.create_llm``: the constructed
    LLM must store ``model_user_id`` so its fallback property can
    resolve under the right scope."""

    def test_model_user_id_set_on_constructed_llm(self, monkeypatch):
        from application.llm.llm_creator import LLMCreator
        from application.llm.providers import PROVIDERS_BY_NAME

        captured = {}

        # Fake LLM class whose constructor records the kwarg of interest.
        class _CapturingLLM:
            def __init__(self, api_key, user_api_key, *args, **kwargs):
                captured["model_user_id"] = kwargs.get("model_user_id")

        # Pick any registered provider — we only need the constructor
        # call to land in our fake.
        monkeypatch.setattr(
            PROVIDERS_BY_NAME["openai"], "llm_class", _CapturingLLM
        )
        LLMCreator.create_llm(
            type="openai",
            api_key="k",
            user_api_key=None,
            decoded_token={"sub": "caller-bob"},
            model_id=None,
            model_user_id="owner-alice",
        )
        assert captured["model_user_id"] == "owner-alice"

View File

@@ -23,9 +23,7 @@ import pytest
from application.llm.openai import OpenAILLM, _truncate_base64_for_logging
# ---------------------------------------------------------------------------
# Fake client helpers
# ---------------------------------------------------------------------------
class _Msg:
def __init__(self, content=None, tool_calls=None):
@@ -116,9 +114,7 @@ def llm():
return instance
# ---------------------------------------------------------------------------
# _truncate_base64_for_logging
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -170,9 +166,7 @@ class TestTruncateBase64ForLogging:
assert "BASE64_DATA_TRUNCATED" in result[0]["content"]["nested"]
# ---------------------------------------------------------------------------
# _normalize_reasoning_value
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -212,9 +206,7 @@ class TestNormalizeReasoningValue:
assert OpenAILLM._normalize_reasoning_value(val) == "ab"
# ---------------------------------------------------------------------------
# _extract_reasoning_text
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -240,9 +232,7 @@ class TestExtractReasoningText:
assert OpenAILLM._extract_reasoning_text(delta) == ""
# ---------------------------------------------------------------------------
# _clean_messages_openai
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -340,9 +330,7 @@ class TestCleanMessagesOpenai:
llm._clean_messages_openai(msgs)
# ---------------------------------------------------------------------------
# _raw_gen
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -392,9 +380,7 @@ class TestRawGen:
assert kwargs["tools"] == tools
# ---------------------------------------------------------------------------
# _raw_gen_stream
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -444,9 +430,7 @@ class TestRawGenStream:
assert closed["called"]
# ---------------------------------------------------------------------------
# _supports_tools / _supports_structured_output
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -459,9 +443,104 @@ class TestSupports:
assert llm._supports_structured_output() is True
# ---------------------------------------------------------------------------
# BYOM capability enforcement at dispatch
@pytest.mark.unit
class TestBYOMCapabilityEnforcement:
    """LLMCreator threads ``capabilities`` from the registry into the LLM.
    These tests verify that a BYOM with restrictive caps doesn't get tools,
    structured output, or unsupported attachment types at dispatch — even
    when the caller forwards them."""

    @staticmethod
    def _llm_with_caps(
        supports_tools=False,
        supports_structured_output=False,
        attachments=None,
    ):
        # Build an OpenAILLM with an explicit ModelCapabilities object
        # and a FakeClient so no network call is ever made.
        from application.core.model_settings import ModelCapabilities

        instance = OpenAILLM(
            api_key="sk-test",
            user_api_key=None,
            capabilities=ModelCapabilities(
                supports_tools=supports_tools,
                supports_structured_output=supports_structured_output,
                supported_attachment_types=attachments or [],
            ),
        )
        instance.client = FakeClient()
        return instance

    def test_supports_tools_respects_disabled_caps(self):
        llm = self._llm_with_caps(supports_tools=False)
        assert llm._supports_tools() is False

    def test_supports_tools_respects_enabled_caps(self):
        llm = self._llm_with_caps(supports_tools=True)
        assert llm._supports_tools() is True

    def test_supports_structured_output_respects_caps(self):
        llm_off = self._llm_with_caps(supports_structured_output=False)
        llm_on = self._llm_with_caps(supports_structured_output=True)
        assert llm_off._supports_structured_output() is False
        assert llm_on._supports_structured_output() is True

    def test_get_supported_attachment_types_respects_caps(self):
        llm = self._llm_with_caps(attachments=[])
        assert llm.get_supported_attachment_types() == []
        llm2 = self._llm_with_caps(attachments=["image/png"])
        assert llm2.get_supported_attachment_types() == ["image/png"]

    def test_raw_gen_drops_tools_when_caps_deny(self):
        llm = self._llm_with_caps(supports_tools=False)
        tools = [{"type": "function", "function": {"name": "t"}}]
        msgs = [{"role": "user", "content": "hi"}]
        # NOTE(review): the instance is also passed as the first
        # positional arg — this appears to mirror BaseLLM's internal
        # ``self._raw_gen(self, ...)`` calling convention; confirm
        # against BaseLLM if it looks off.
        llm._raw_gen(
            llm, model="gpt", messages=msgs, stream=False, tools=tools
        )
        # The fake client records the kwargs of the last completion call.
        kwargs = llm.client.chat.completions.last_kwargs
        assert "tools" not in kwargs

    def test_raw_gen_drops_response_format_when_caps_deny(self):
        llm = self._llm_with_caps(supports_structured_output=False)
        msgs = [{"role": "user", "content": "hi"}]
        llm._raw_gen(
            llm,
            model="gpt",
            messages=msgs,
            stream=False,
            response_format={"type": "json_object"},
        )
        kwargs = llm.client.chat.completions.last_kwargs
        assert "response_format" not in kwargs

    def test_raw_gen_stream_drops_tools_when_caps_deny(self):
        llm = self._llm_with_caps(supports_tools=False)
        tools = [{"type": "function", "function": {"name": "t"}}]
        msgs = [{"role": "user", "content": "hi"}]
        # Drain the generator so the underlying client call happens.
        list(
            llm._raw_gen_stream(
                llm, model="gpt", messages=msgs, stream=True, tools=tools
            )
        )
        kwargs = llm.client.chat.completions.last_kwargs
        assert "tools" not in kwargs

    def test_no_caps_keeps_provider_defaults(self, llm):
        # ``llm`` fixture builds an OpenAILLM with capabilities=None,
        # i.e. provider-class defaults. Tools/structured output should
        # pass through unchanged.
        tools = [{"type": "function", "function": {"name": "t"}}]
        msgs = [{"role": "user", "content": "hi"}]
        llm._raw_gen(
            llm, model="gpt", messages=msgs, stream=False, tools=tools
        )
        kwargs = llm.client.chat.completions.last_kwargs
        assert kwargs["tools"] == tools
# prepare_structured_output_format
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -546,9 +625,7 @@ class TestPrepareStructuredOutputFormat:
assert result["json_schema"]["description"] == "Structured response"
# ---------------------------------------------------------------------------
# get_supported_attachment_types
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -560,9 +637,7 @@ class TestGetSupportedAttachmentTypes:
assert len(result) > 0
# ---------------------------------------------------------------------------
# prepare_messages_with_attachments
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -654,9 +729,7 @@ class TestPrepareMessagesWithAttachments:
assert isinstance(user_msg["content"], list)
# ---------------------------------------------------------------------------
# _get_base64_image
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -678,9 +751,7 @@ class TestGetBase64Image:
llm._get_base64_image({"path": "/nonexistent"})
# ---------------------------------------------------------------------------
# AzureOpenAILLM
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -718,9 +789,7 @@ class TestAzureOpenAILLM:
assert issubclass(oai_mod.AzureOpenAILLM, oai_mod.OpenAILLM)
# ---------------------------------------------------------------------------
# _truncate_base64_for_logging — additional edges
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -744,9 +813,7 @@ class TestTruncateBase64ForLoggingAdditional:
assert result[0]["content"] == "no base64 here"
# ---------------------------------------------------------------------------
# _clean_messages_openai — additional edges
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -835,9 +902,7 @@ class TestCleanMessagesOpenaiAdditional:
llm._clean_messages_openai(msgs)
# ---------------------------------------------------------------------------
# _normalize_reasoning_value — additional edges
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -862,9 +927,7 @@ class TestNormalizeReasoningValueAdditional:
assert OpenAILLM._normalize_reasoning_value(obj) == ""
# ---------------------------------------------------------------------------
# _extract_reasoning_text — additional edges
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -881,9 +944,7 @@ class TestExtractReasoningTextAdditional:
assert OpenAILLM._extract_reasoning_text(delta) == "dict_thought"
# ---------------------------------------------------------------------------
# _raw_gen_stream — additional edges
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -933,9 +994,7 @@ class TestRawGenStreamAdditional:
assert any(hasattr(c, "finish_reason") for c in chunks)
# ---------------------------------------------------------------------------
# prepare_structured_output_format — additional edges
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -968,9 +1027,7 @@ class TestPrepareStructuredOutputAdditional:
assert one_of[0]["additionalProperties"] is False
# ---------------------------------------------------------------------------
# prepare_messages_with_attachments — additional edges
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -1022,9 +1079,7 @@ class TestPrepareMessagesWithAttachmentsAdditional:
assert len(text_parts) == 0
# ---------------------------------------------------------------------------
# _upload_file_to_openai — additional edges
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -1053,9 +1108,7 @@ class TestUploadFileToOpenai:
llm._upload_file_to_openai({"path": "/tmp/file.pdf"})
# ---------------------------------------------------------------------------
# OpenAILLM constructor — additional edges
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -1134,9 +1187,7 @@ class TestOpenAILLMConstructor:
)
# ---------------------------------------------------------------------------
# _upload_file_to_openai — coverage lines 489-517
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -1181,9 +1232,7 @@ class TestUploadFileToOpenai2:
llm._upload_file_to_openai({"path": "/file.pdf"})
# ---------------------------------------------------------------------------
# _normalize_reasoning_value — additional edges for line 155, 198
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -1206,9 +1255,7 @@ class TestNormalizeReasoningAdditional:
assert result == "ab"
# ---------------------------------------------------------------------------
# _extract_reasoning_text — additional edge for line 198
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -1227,9 +1274,7 @@ class TestExtractReasoningTextAdditional2:
assert result == ""
# ---------------------------------------------------------------------------
# prepare_structured_output_format — error path for line 348, 395
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -1247,9 +1292,7 @@ class TestPrepareStructuredOutputAdditional2:
assert result is None
# ---------------------------------------------------------------------------
# Coverage — remaining uncovered lines
# ---------------------------------------------------------------------------
@pytest.mark.unit
@@ -1408,13 +1451,11 @@ class TestUploadFileToOpenaiLines489To517:
assert result == "file-no-cache"
# ---------------------------------------------------------------------------
# Additional coverage for openai.py
# Lines: 49 (truncate_content v passthrough), 80-82 (default base_url),
# 137 (function_response content), 198 (delta get fallback),
# 304 (_supports_structured_output), 395 (no user_message append),
# 469 (_get_base64_image missing path), 489-517 (_upload_file_to_openai)
# ---------------------------------------------------------------------------
@pytest.mark.unit

View File

@@ -0,0 +1,710 @@
"""Unit tests for ``application.security.safe_url``.
These tests must run offline, so every "valid public host" case mocks
``socket.getaddrinfo`` to return a known public IP. Cases that test
IP-literal validation do not need DNS at all and rely on the
short-circuit path inside ``validate_user_base_url``.
"""
from __future__ import annotations
import socket
from unittest import mock
import pytest
import requests
from application.security.safe_url import (
UnsafeUserUrlError,
_PinnedHTTPSTransport,
pinned_httpx_client,
pinned_post,
validate_user_base_url,
)
def _addrinfo(*ips: str) -> list[tuple]:
"""Build a fake ``socket.getaddrinfo`` return value for ``ips``."""
out: list[tuple] = []
for ip in ips:
family = socket.AF_INET6 if ":" in ip else socket.AF_INET
port_tuple = (ip, 0, 0, 0) if family == socket.AF_INET6 else (ip, 0)
out.append((family, socket.SOCK_STREAM, 0, "", port_tuple))
return out
# Valid URLs (DNS mocked to a known-public IP)
@pytest.mark.unit
def test_allows_openai_api():
    # DNS is stubbed to a public IP so validation succeeds offline.
    fake_dns = _addrinfo("104.18.6.192")
    with mock.patch("socket.getaddrinfo", return_value=fake_dns):
        assert validate_user_base_url("https://api.openai.com/v1") is None


@pytest.mark.unit
def test_allows_mistral_api():
    fake_dns = _addrinfo("172.67.144.116")
    with mock.patch("socket.getaddrinfo", return_value=fake_dns):
        assert validate_user_base_url("https://api.mistral.ai/v1") is None


@pytest.mark.unit
def test_allows_http_with_port():
    # Plain http plus an explicit port must also be accepted.
    fake_dns = _addrinfo("93.184.216.34")
    with mock.patch("socket.getaddrinfo", return_value=fake_dns):
        assert validate_user_base_url("http://example.com:8080/v1") is None
# Scheme rejection
@pytest.mark.unit
def test_rejects_file_scheme():
    # Non-HTTP(S) schemes must be refused before any DNS work happens.
    url = "file:///etc/passwd"
    with pytest.raises(UnsafeUserUrlError, match="scheme"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_gopher_scheme():
    url = "gopher://example.com"
    with pytest.raises(UnsafeUserUrlError, match="scheme"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_ftp_scheme():
    url = "ftp://example.com"
    with pytest.raises(UnsafeUserUrlError, match="scheme"):
        validate_user_base_url(url)
# Hostname-string blocklist
@pytest.mark.unit
def test_rejects_localhost_hostname():
    url = "https://localhost/v1"
    with pytest.raises(UnsafeUserUrlError, match="not allowed"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_localhost_localdomain():
    url = "https://localhost.localdomain/v1"
    with pytest.raises(UnsafeUserUrlError, match="not allowed"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_localhost_uppercase():
    # Hostname check must be case-insensitive.
    url = "https://LocalHost/v1"
    with pytest.raises(UnsafeUserUrlError, match="not allowed"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_ip6_localhost_alias():
    url = "https://ip6-localhost/v1"
    with pytest.raises(UnsafeUserUrlError, match="not allowed"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_gcp_metadata_hostname():
    url = "https://metadata.google.internal/computeMetadata/v1/"
    with pytest.raises(UnsafeUserUrlError, match="not allowed"):
        validate_user_base_url(url)
# IP-literal rejection (no DNS hit needed; covered by short-circuit)
@pytest.mark.unit
def test_rejects_loopback_ipv4_literal():
    url = "https://127.0.0.1/v1"
    with pytest.raises(UnsafeUserUrlError, match="blocked address"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_loopback_ipv6_literal():
    url = "https://[::1]/v1"
    with pytest.raises(UnsafeUserUrlError, match="not allowed|blocked"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_private_10_8():
    url = "https://10.0.0.5/v1"
    with pytest.raises(UnsafeUserUrlError, match="blocked address"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_private_172_16_low_boundary():
    url = "https://172.16.0.5/v1"
    with pytest.raises(UnsafeUserUrlError, match="blocked address"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_private_172_16_high_boundary():
    url = "https://172.31.0.5/v1"
    with pytest.raises(UnsafeUserUrlError, match="blocked address"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_private_192_168():
    url = "https://192.168.1.1/v1"
    with pytest.raises(UnsafeUserUrlError, match="blocked address"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_aws_metadata_link_local():
    url = "https://169.254.169.254/latest/meta-data/"
    with pytest.raises(UnsafeUserUrlError, match="blocked address"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_unique_local_ipv6():
    url = "https://[fc00::1]/v1"
    with pytest.raises(UnsafeUserUrlError, match="blocked address"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_link_local_ipv6():
    url = "https://[fe80::1]/v1"
    with pytest.raises(UnsafeUserUrlError, match="blocked address"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_multicast_ipv4():
    url = "https://224.0.0.1/v1"
    with pytest.raises(UnsafeUserUrlError, match="blocked address"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_unspecified_zero_address():
    # 0.0.0.0 is in the literal hostname blocklist AND also caught as
    # unspecified; either error message is acceptable.
    url = "https://0.0.0.0/v1"
    with pytest.raises(UnsafeUserUrlError, match="not allowed|blocked"):
        validate_user_base_url(url)


@pytest.mark.unit
def test_rejects_carrier_grade_nat():
    # 100.64.0.0/10 is NOT covered by ``ipaddress.is_private``.
    url = "https://100.64.0.1/v1"
    with pytest.raises(UnsafeUserUrlError, match="blocked address"):
        validate_user_base_url(url)
# Parse / structural failures
@pytest.mark.unit
def test_rejects_garbage_string():
    """A string that is not a URL at all must be rejected outright."""
    with pytest.raises(UnsafeUserUrlError):
        validate_user_base_url("not a url")


@pytest.mark.unit
def test_rejects_empty_string():
    """The empty string fails the non-empty precondition."""
    with pytest.raises(UnsafeUserUrlError) as excinfo:
        validate_user_base_url("")
    assert "non-empty" in str(excinfo.value)


@pytest.mark.unit
def test_rejects_whitespace_only_string():
    """Whitespace-only input is treated the same as empty."""
    with pytest.raises(UnsafeUserUrlError) as excinfo:
        validate_user_base_url(" ")
    assert "non-empty" in str(excinfo.value)


@pytest.mark.unit
def test_rejects_url_without_hostname():
    """A scheme with no netloc (``https:///v1``) has no hostname to vet."""
    with pytest.raises(UnsafeUserUrlError):
        validate_user_base_url("https:///v1")
# DNS-mocking tests for hostnames (rebinding-style scenarios)
@pytest.mark.unit
def test_rejects_hostname_resolving_only_to_private():
    # Hostname itself looks fine; the resolved address set is what fails.
    with mock.patch("socket.getaddrinfo", return_value=_addrinfo("10.0.0.5")):
        with pytest.raises(UnsafeUserUrlError, match="blocked address"):
            validate_user_base_url("https://internal.example.com/v1")


@pytest.mark.unit
def test_rejects_hostname_with_mixed_public_and_private():
    # Public IP first, private second — must still reject because ANY
    # blocked address in the answer set is enough.
    with mock.patch(
        "socket.getaddrinfo",
        return_value=_addrinfo("93.184.216.34", "10.0.0.5"),
    ):
        with pytest.raises(UnsafeUserUrlError, match="blocked address"):
            validate_user_base_url("https://rebinding.example.com/v1")


@pytest.mark.unit
def test_rejects_hostname_when_dns_fails():
    # Unresolvable hostnames are rejected rather than passed through.
    with mock.patch(
        "socket.getaddrinfo",
        side_effect=socket.gaierror("nodename nor servname provided"),
    ):
        with pytest.raises(UnsafeUserUrlError, match="could not resolve"):
            validate_user_base_url("https://nonexistent.invalid/v1")


@pytest.mark.unit
def test_rejects_hostname_resolving_to_metadata_ip_via_dns():
    # Even if a hostname looks innocent, if DNS hands us 169.254.169.254
    # we must refuse — defense-in-depth at dispatch time.
    with mock.patch("socket.getaddrinfo", return_value=_addrinfo("169.254.169.254")):
        with pytest.raises(UnsafeUserUrlError, match="blocked address"):
            validate_user_base_url("https://innocent.example.com/v1")


@pytest.mark.unit
def test_rejects_hostname_resolving_to_ipv6_loopback_via_dns():
    # AAAA-only answers go through the same guard as A records.
    with mock.patch("socket.getaddrinfo", return_value=_addrinfo("::1")):
        with pytest.raises(UnsafeUserUrlError, match="blocked address"):
            validate_user_base_url("https://aaaa-only.example.com/v1")
# pinned_post — single-resolve, IP-pinned outbound HTTP
class _StubResponse:
"""Drop-in for a ``requests.Response`` that ``Session.send`` returns."""
def __init__(self, status_code: int = 200) -> None:
self.status_code = status_code
self.text = ""
self.headers = {"Content-Type": "application/json"}
def _capture_send(monkeypatch):
    """Replace ``requests.Session.send`` with a capturing stub.

    Returns a dict that is filled with ``prepared`` (the
    ``PreparedRequest``) and ``send_kwargs`` once a request is issued,
    plus ``adapters`` so callers can inspect what was mounted on the
    session.
    """
    seen: dict = {}

    def _recording_send(self, prepared, **kwargs):
        # Record everything a real send would have acted on, then
        # short-circuit with a canned 200 so no socket is opened.
        seen.update(
            prepared=prepared,
            send_kwargs=kwargs,
            adapters=dict(self.adapters),
        )
        return _StubResponse()

    monkeypatch.setattr(requests.Session, "send", _recording_send)
    return seen
@pytest.mark.unit
def test_pinned_post_rewrites_host_to_resolved_ipv4(monkeypatch):
    captured = _capture_send(monkeypatch)
    with mock.patch("socket.getaddrinfo", return_value=_addrinfo("104.18.6.192")):
        pinned_post(
            "https://api.openai.com/v1/chat/completions",
            json={"hi": True},
            headers={"Authorization": "Bearer sk-x"},
            timeout=5,
            allow_redirects=False,
        )
    prepared = captured["prepared"]
    # URL dials the resolved IP literal, not the hostname.
    assert prepared.url == "https://104.18.6.192/v1/chat/completions"
    # Host header carries the original hostname so vhost-routing and
    # SNI/cert verification target the right server.
    assert prepared.headers["Host"] == "api.openai.com"
    # Caller-supplied headers are preserved.
    assert prepared.headers["Authorization"] == "Bearer sk-x"
    # The body was sent as JSON.
    assert prepared.body == b'{"hi": true}'


@pytest.mark.unit
def test_pinned_post_brackets_ipv6_in_url(monkeypatch):
    # IPv6 literals must be bracketed in the URL netloc.
    captured = _capture_send(monkeypatch)
    with mock.patch(
        "socket.getaddrinfo", return_value=_addrinfo("2606:4700::6810:1234")
    ):
        pinned_post(
            "https://example.com/v1/x",
            json={},
            headers={},
            timeout=5,
            allow_redirects=False,
        )
    assert (
        captured["prepared"].url == "https://[2606:4700::6810:1234]/v1/x"
    )
    assert captured["prepared"].headers["Host"] == "example.com"


@pytest.mark.unit
def test_pinned_post_preserves_explicit_port(monkeypatch):
    captured = _capture_send(monkeypatch)
    with mock.patch("socket.getaddrinfo", return_value=_addrinfo("93.184.216.34")):
        pinned_post(
            "https://api.example.com:8443/v1/test",
            json={},
            headers={},
            timeout=5,
            allow_redirects=False,
        )
    # Non-default port survives the host → IP rewrite.
    assert captured["prepared"].url == "https://93.184.216.34:8443/v1/test"
    # Host header keeps the original :port — proxies and vhost routers
    # rely on this, and SNI conventionally carries the bare hostname.
    assert captured["prepared"].headers["Host"] == "api.example.com:8443"


@pytest.mark.unit
def test_pinned_post_handles_ip_literal_url_without_dns(monkeypatch):
    """If the URL already has an IP literal, no DNS lookup happens."""
    captured = _capture_send(monkeypatch)
    with mock.patch("socket.getaddrinfo") as gai:
        pinned_post(
            "https://93.184.216.34/v1/x",
            json={},
            headers={},
            timeout=5,
            allow_redirects=False,
        )
    assert gai.call_count == 0
    assert captured["prepared"].url == "https://93.184.216.34/v1/x"
    assert captured["prepared"].headers["Host"] == "93.184.216.34"
@pytest.mark.unit
def test_pinned_post_resolves_dns_exactly_once(monkeypatch):
    """The whole point of the helper: one resolution, one connection.

    A DNS-rebinding attacker wins by getting a second ``getaddrinfo``
    call after the first one was validated. If this assertion ever
    fails, the SSRF guard has regressed.
    """
    captured = _capture_send(monkeypatch)
    with mock.patch(
        "socket.getaddrinfo", return_value=_addrinfo("104.18.6.192")
    ) as gai:
        pinned_post(
            "https://api.openai.com/v1/chat/completions",
            json={},
            headers={},
            timeout=5,
            allow_redirects=False,
        )
    assert gai.call_count == 1
    assert captured["prepared"].url.startswith("https://104.18.6.192/")


@pytest.mark.unit
def test_pinned_post_mounts_pinned_adapter_for_https(monkeypatch):
    """For HTTPS, a custom adapter must be mounted that overrides
    ``server_hostname`` / ``assert_hostname`` so SNI and cert
    verification target the original hostname even though we connect
    to an IP literal.
    """
    captured = _capture_send(monkeypatch)
    with mock.patch("socket.getaddrinfo", return_value=_addrinfo("104.18.6.192")):
        pinned_post(
            "https://api.openai.com/v1/x",
            json={},
            headers={},
            timeout=5,
            allow_redirects=False,
        )
    https_adapter = captured["adapters"]["https://"]
    # _PinnedHostAdapter is the symbol we expect; also check it carries
    # the original hostname so SNI/cert verification line up.
    assert type(https_adapter).__name__ == "_PinnedHostAdapter"
    assert https_adapter._server_hostname == "api.openai.com"


@pytest.mark.unit
def test_pinned_post_does_not_mount_https_adapter_for_http(monkeypatch):
    """For HTTP, no SNI/cert logic is needed — the default adapter
    should remain in place; only the URL rewrite + Host header matter."""
    captured = _capture_send(monkeypatch)
    with mock.patch("socket.getaddrinfo", return_value=_addrinfo("93.184.216.34")):
        pinned_post(
            "http://example.com/v1/x",
            json={},
            headers={},
            timeout=5,
            allow_redirects=False,
        )
    https_adapter = captured["adapters"]["https://"]
    # requests' stock adapter class name, i.e. nothing custom mounted.
    assert type(https_adapter).__name__ == "HTTPAdapter"


@pytest.mark.unit
def test_pinned_post_raises_for_blocked_dns_result(monkeypatch):
    """A hostname that resolves to a private IP must be rejected
    *before* any HTTP request is dispatched."""
    captured = _capture_send(monkeypatch)
    with mock.patch("socket.getaddrinfo", return_value=_addrinfo("10.0.0.5")):
        with pytest.raises(UnsafeUserUrlError, match="blocked address"):
            pinned_post(
                "https://internal.example.com/v1/x",
                json={},
                headers={},
                timeout=5,
                allow_redirects=False,
            )
    # Nothing ever reached Session.send.
    assert "prepared" not in captured


@pytest.mark.unit
def test_pinned_post_raises_for_loopback_ip_literal(monkeypatch):
    captured = _capture_send(monkeypatch)
    with pytest.raises(UnsafeUserUrlError, match="blocked address"):
        pinned_post(
            "https://127.0.0.1/v1/x",
            json={},
            headers={},
            timeout=5,
            allow_redirects=False,
        )
    assert "prepared" not in captured


@pytest.mark.unit
def test_pinned_post_raises_for_disallowed_scheme(monkeypatch):
    captured = _capture_send(monkeypatch)
    with pytest.raises(UnsafeUserUrlError, match="scheme"):
        pinned_post(
            "ftp://example.com/v1/x",
            json={},
            headers={},
            timeout=5,
            allow_redirects=False,
        )
    assert "prepared" not in captured


@pytest.mark.unit
def test_pinned_post_overrides_caller_supplied_host_header(monkeypatch):
    """If the caller passes their own Host header, the helper must
    still set it to the resolved hostname so the in-flight request
    doesn't disagree with what was validated."""
    captured = _capture_send(monkeypatch)
    with mock.patch("socket.getaddrinfo", return_value=_addrinfo("104.18.6.192")):
        pinned_post(
            "https://api.openai.com/v1/x",
            json={},
            headers={"Host": "evil.example.com"},
            timeout=5,
            allow_redirects=False,
        )
    assert captured["prepared"].headers["Host"] == "api.openai.com"
# pinned_httpx_client — DNS-rebinding-safe httpx transport for SDK use
def _capture_httpx_handle_request(monkeypatch):
    """Patch ``httpx.HTTPTransport.handle_request`` to record what the
    parent transport receives, answering with a canned response.

    The pinned transport rewrites ``request.url.host`` and sets the
    ``sni_hostname`` extension *before* delegating to
    ``super().handle_request``, so the values captured here are exactly
    what would feed httpcore's connect (the host TCP would dial, the
    SNI that would be advertised) — without opening a real socket.
    """
    import httpx

    seen: dict = {}

    def _record(self, request):
        seen["url"] = request.url
        seen["sni"] = request.extensions.get("sni_hostname")
        seen["host_header"] = request.headers.get("host")
        return httpx.Response(200, content=b"ok")

    monkeypatch.setattr("httpx.HTTPTransport.handle_request", _record)
    return seen
@pytest.mark.unit
def test_pinned_httpx_client_returns_pinned_transport():
    """The factory must wire its transport in unchanged and bind it
    to the validated host and IP."""
    with mock.patch(
        "socket.getaddrinfo", return_value=_addrinfo("104.18.6.192")
    ):
        client = pinned_httpx_client("https://api.example.com/v1")
        try:
            assert isinstance(client._transport, _PinnedHTTPSTransport)
            assert client._transport._host == "api.example.com"
            assert client._transport._ip_netloc == "104.18.6.192"
        finally:
            client.close()


@pytest.mark.unit
def test_pinned_httpx_client_disables_redirects():
    """SSRF guard only inspects the supplied URL — following 3xx would
    let a hostile upstream bounce the in-network request to an
    internal address."""
    with mock.patch(
        "socket.getaddrinfo", return_value=_addrinfo("104.18.6.192")
    ):
        client = pinned_httpx_client("https://api.example.com/v1")
        try:
            assert client.follow_redirects is False
        finally:
            client.close()


@pytest.mark.unit
def test_pinned_httpx_transport_rewrites_url_to_validated_ip(monkeypatch):
    """The core invariant: every request reaching httpcore has its URL
    host pointed at the IP literal we validated, so TCP dials that IP
    rather than triggering a fresh DNS resolution."""
    captured = _capture_httpx_handle_request(monkeypatch)
    with mock.patch(
        "socket.getaddrinfo", return_value=_addrinfo("104.18.6.192")
    ):
        client = pinned_httpx_client("https://api.example.com/v1")
        try:
            client.get("https://api.example.com/v1/test")
        finally:
            client.close()
    assert captured["url"].host == "104.18.6.192"


@pytest.mark.unit
def test_pinned_httpx_transport_sets_sni_for_original_hostname(monkeypatch):
    """TLS SNI / cert verification must use the original hostname; the
    transport sets it via the ``sni_hostname`` extension that
    httpcore forwards to ``start_tls``'s ``server_hostname``."""
    captured = _capture_httpx_handle_request(monkeypatch)
    with mock.patch(
        "socket.getaddrinfo", return_value=_addrinfo("104.18.6.192")
    ):
        client = pinned_httpx_client("https://api.example.com/v1")
        try:
            client.get("https://api.example.com/v1/test")
        finally:
            client.close()
    assert captured["sni"] == b"api.example.com"


@pytest.mark.unit
def test_pinned_httpx_transport_preserves_host_header(monkeypatch):
    """``Host`` is auto-set by ``httpx.Request._prepare`` from the URL
    netloc *before* our transport rewrites the URL host. The header
    must still carry the original hostname."""
    captured = _capture_httpx_handle_request(monkeypatch)
    with mock.patch(
        "socket.getaddrinfo", return_value=_addrinfo("104.18.6.192")
    ):
        client = pinned_httpx_client("https://api.example.com/v1")
        try:
            client.get("https://api.example.com/v1/test")
        finally:
            client.close()
    assert captured["host_header"] == "api.example.com"
@pytest.mark.unit
def test_pinned_httpx_client_closes_dns_rebinding_window(monkeypatch):
    """The TOCTOU lock-in test: validate against a public IP, then
    have DNS rebind to a private (loopback) IP, then send a request.
    The transport must dial the *first* validated IP — not the
    rebound one — guaranteeing no second DNS lookup interferes."""
    captured = _capture_httpx_handle_request(monkeypatch)
    # First lookup (at validation time) returns a public IP.
    public = _addrinfo("104.18.6.192")
    # Subsequent lookups (which the transport must NEVER trigger)
    # would return loopback if a hostile resolver flipped them.
    private = _addrinfo("127.0.0.1")
    getaddrinfo_calls: list = []

    def fake_getaddrinfo(*args, **kwargs):
        getaddrinfo_calls.append((args, kwargs))
        # First call = validation; everything after = post-rebind.
        return public if len(getaddrinfo_calls) == 1 else private

    monkeypatch.setattr("socket.getaddrinfo", fake_getaddrinfo)
    client = pinned_httpx_client("https://api.attacker.example/v1")
    try:
        client.get("https://api.attacker.example/v1/test")
    finally:
        client.close()
    # Whatever happens below the transport, the URL handed to
    # httpcore must be the IP from the *first* getaddrinfo call.
    assert captured["url"].host == "104.18.6.192", (
        "pinned transport must dial the IP validated at construction, "
        "not whatever DNS returns at request time"
    )


@pytest.mark.unit
def test_pinned_httpx_client_rejects_blocked_dns_result():
    """If the validation lookup itself returns a private IP, the
    factory must refuse to construct a client at all."""
    with mock.patch(
        "socket.getaddrinfo", return_value=_addrinfo("169.254.169.254")
    ):
        with pytest.raises(UnsafeUserUrlError, match="link-local"):
            pinned_httpx_client("https://api.attacker.example/v1")


@pytest.mark.unit
def test_pinned_httpx_client_rejects_loopback_literal():
    """An IP literal in the supplied URL goes through the same guard
    even when DNS isn't called."""
    with pytest.raises(UnsafeUserUrlError):
        pinned_httpx_client("http://127.0.0.1:8080/v1")


@pytest.mark.unit
def test_pinned_httpx_transport_refuses_unexpected_host(monkeypatch):
    """Defense in depth: if the SDK ever rewrites the request URL to a
    different host between Request construction and dial, the
    transport refuses rather than silently dialing the validated IP
    with a different host's credentials."""
    import httpx

    with mock.patch(
        "socket.getaddrinfo", return_value=_addrinfo("104.18.6.192")
    ):
        client = pinned_httpx_client("https://api.example.com/v1")
        try:
            with pytest.raises(UnsafeUserUrlError, match="refused request"):
                client.get("https://other.example.com/v1/test")
        except httpx.RequestError:
            # Some httpx versions may wrap the transport error; accept
            # either path so long as the request didn't succeed.
            pass
        finally:
            client.close()

View File

@@ -0,0 +1,172 @@
"""Tests for UserCustomModelsRepository against a real Postgres instance."""
from __future__ import annotations
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
def _repo(conn) -> UserCustomModelsRepository:
    """Build a repository bound to the test-provided Postgres connection."""
    return UserCustomModelsRepository(conn)
def _make(repo, user="user-1", upstream="mistral-large-latest", **kwargs):
return repo.create(
user_id=user,
upstream_model_id=upstream,
display_name=kwargs.pop("display_name", "My Mistral"),
base_url=kwargs.pop("base_url", "https://api.mistral.ai/v1"),
api_key_plaintext=kwargs.pop("api_key_plaintext", "sk-mistral-test"),
**kwargs,
)
class TestCreate:
    def test_creates_minimal(self, pg_conn):
        repo = _repo(pg_conn)
        row = _make(repo)
        assert row["user_id"] == "user-1"
        assert row["upstream_model_id"] == "mistral-large-latest"
        assert row["display_name"] == "My Mistral"
        assert row["base_url"] == "https://api.mistral.ai/v1"
        # New rows are enabled by default and get a generated id.
        assert row["enabled"] is True
        assert row["id"] is not None
        # Plaintext key never lands in the row
        assert row["api_key_encrypted"] != "sk-mistral-test"
        assert "sk-mistral-test" not in row["api_key_encrypted"]

    def test_capabilities_normalized_drops_unknown_keys(self, pg_conn):
        repo = _repo(pg_conn)
        row = _make(
            repo,
            capabilities={
                "supports_tools": True,
                "context_window": 200_000,
                "garbage_key": "should be dropped",
            },
        )
        # Only recognised capability keys survive normalization.
        assert row["capabilities"] == {
            "supports_tools": True,
            "context_window": 200_000,
        }
class TestGet:
    def test_get_by_id_returns_row(self, pg_conn):
        repo = _repo(pg_conn)
        row = _make(repo)
        # Owner lookup round-trips the same record.
        found = repo.get(row["id"], "user-1")
        assert found is not None
        assert found["id"] == row["id"]

    def test_get_other_user_returns_none(self, pg_conn):
        repo = _repo(pg_conn)
        alice_row = _make(repo, user="alice")
        # Ownership scoping: a valid id is not enough for another user.
        assert repo.get(alice_row["id"], "bob") is None

    def test_get_missing_id_returns_none(self, pg_conn):
        repo = _repo(pg_conn)
        _make(repo)
        # A different (but well-formed) UUID that is not in the table.
        missing = "00000000-0000-0000-0000-000000000000"
        assert repo.get(missing, "user-1") is None
class TestListForUser:
    def test_lists_only_users_own(self, pg_conn):
        repo = _repo(pg_conn)
        _make(repo, user="alice", upstream="alice-1")
        _make(repo, user="alice", upstream="alice-2")
        _make(repo, user="bob", upstream="bob-1")
        # Each user sees exactly their own rows and nobody else's.
        alice = repo.list_for_user("alice")
        assert {r["upstream_model_id"] for r in alice} == {"alice-1", "alice-2"}
        bob = repo.list_for_user("bob")
        assert {r["upstream_model_id"] for r in bob} == {"bob-1"}
class TestUpdate:
    def test_update_partial(self, pg_conn):
        repo = _repo(pg_conn)
        created = _make(repo)
        # Only the supplied fields change; others are left alone.
        ok = repo.update(
            created["id"],
            "user-1",
            {"display_name": "Renamed", "enabled": False},
        )
        assert ok is True
        fetched = repo.get(created["id"], "user-1")
        assert fetched["display_name"] == "Renamed"
        assert fetched["enabled"] is False

    def test_update_capabilities_normalizes(self, pg_conn):
        repo = _repo(pg_conn)
        created = _make(repo)
        repo.update(
            created["id"],
            "user-1",
            {"capabilities": {"supports_tools": False, "garbage": 1}},
        )
        fetched = repo.get(created["id"], "user-1")
        # Unknown capability keys are dropped on update too.
        assert fetched["capabilities"] == {"supports_tools": False}

    def test_update_api_key_re_encrypts(self, pg_conn):
        repo = _repo(pg_conn)
        created = _make(repo)
        before = created["api_key_encrypted"]
        repo.update(
            created["id"],
            "user-1",
            {"api_key_plaintext": "sk-mistral-new"},
        )
        fetched = repo.get(created["id"], "user-1")
        # Ciphertext changed
        assert fetched["api_key_encrypted"] != before
        # Plaintext absent
        assert "sk-mistral-new" not in fetched["api_key_encrypted"]
        # And decrypts back to the new value
        plaintext = repo.get_decrypted_api_key(created["id"], "user-1")
        assert plaintext == "sk-mistral-new"

    def test_update_other_user_returns_false(self, pg_conn):
        repo = _repo(pg_conn)
        created = _make(repo, user="alice")
        # Bob can't update Alice's row
        ok = repo.update(created["id"], "bob", {"display_name": "Hacked"})
        assert ok is False
        # And Alice's row is untouched
        fetched = repo.get(created["id"], "alice")
        assert fetched["display_name"] == "My Mistral"
class TestDelete:
    def test_delete_removes_row(self, pg_conn):
        repo = _repo(pg_conn)
        row = _make(repo)
        assert repo.delete(row["id"], "user-1") is True
        # Deleted rows are gone even for the owner.
        assert repo.get(row["id"], "user-1") is None

    def test_delete_other_user_returns_false(self, pg_conn):
        repo = _repo(pg_conn)
        alice_row = _make(repo, user="alice")
        # Bob cannot delete Alice's record...
        assert repo.delete(alice_row["id"], "bob") is False
        # ...and Alice can still read it afterwards.
        assert repo.get(alice_row["id"], "alice") is not None
class TestEncryptionRoundtrip:
    def test_decrypted_matches_original(self, pg_conn):
        repo = _repo(pg_conn)
        created = _make(repo, api_key_plaintext="my-very-secret-key-12345")
        # Encrypt-then-decrypt is lossless for the owner.
        plaintext = repo.get_decrypted_api_key(created["id"], "user-1")
        assert plaintext == "my-very-secret-key-12345"

    def test_decryption_with_wrong_user_fails_silently(self, pg_conn):
        repo = _repo(pg_conn)
        # Per-user PBKDF2 salt: Alice's record can't be decrypted with
        # Bob's user_id even if Bob somehow has the row.
        created = _make(repo, user="alice", api_key_plaintext="alice-secret")
        # Manually call decrypt with the wrong user_id (simulates the
        # registry layer being given the wrong user context).
        wrong = repo._decrypt_api_key(created["api_key_encrypted"], "bob")
        assert wrong != "alice-secret"  # either None or garbage; not the secret

View File

@@ -112,7 +112,7 @@ class TestFlaskCors:
response = client.get("/api/health", headers={"Origin": "http://localhost:5173"})
assert response.headers["Access-Control-Allow-Origin"] == "*"
assert response.headers["Access-Control-Allow-Headers"] == "Content-Type, Authorization"
assert response.headers["Access-Control-Allow-Methods"] == "GET, POST, PUT, DELETE, OPTIONS"
assert response.headers["Access-Control-Allow-Methods"] == "GET, POST, PUT, PATCH, DELETE, OPTIONS"
@pytest.mark.unit
def test_cors_headers_on_flask_preflight(self, client):
@@ -127,4 +127,4 @@ class TestFlaskCors:
assert response.status_code == 200
assert response.headers["Access-Control-Allow-Origin"] == "*"
assert response.headers["Access-Control-Allow-Headers"] == "Content-Type, Authorization"
assert response.headers["Access-Control-Allow-Methods"] == "GET, POST, PUT, DELETE, OPTIONS"
assert response.headers["Access-Control-Allow-Methods"] == "GET, POST, PUT, PATCH, DELETE, OPTIONS"

View File

@@ -113,6 +113,28 @@ def test_cors_preflight_on_flask_route():
assert "GET" in r.headers.get("access-control-allow-methods", "")
@pytest.mark.unit
def test_cors_preflight_allows_patch():
    """PATCH must be in Access-Control-Allow-Methods. The frontend's
    apiClient.patch() (used to edit BYOM custom models via PATCH
    /api/user/models/<id>) is otherwise blocked at preflight by browsers."""
    from starlette.testclient import TestClient
    from application.asgi import asgi_app

    with TestClient(asgi_app) as client:
        r = client.options(
            "/api/health",
            headers={
                "Origin": "http://example.com",
                "Access-Control-Request-Method": "PATCH",
                "Access-Control-Request-Headers": "Content-Type, Authorization",
            },
        )
        # Preflight may legitimately answer 200 or 204.
        assert r.status_code in (200, 204)
        assert "PATCH" in r.headers.get("access-control-allow-methods", "")
@pytest.mark.unit
def test_cors_preflight_on_mcp_route():
"""Browser clients hitting /mcp should be allowed to send session headers."""

View File

@@ -212,6 +212,44 @@ class TestClassicRAGRephraseQuery:
assert rag.question == "original"
@pytest.mark.unit
class TestClassicRAGLLMCreatorWiring:
    """ClassicRAG must forward model_id + model_user_id to LLMCreator so
    the registry-resolution path runs (BYOM api_key/base_url overrides
    and upstream_model_id translation). Without these the rephrase
    client dispatches the registry UUID to the plugin's default endpoint
    with the instance API key."""

    def test_passes_model_id_and_user_id_to_llmcreator(self, mock_llm, monkeypatch):
        captured = Mock(return_value=mock_llm)
        monkeypatch.setattr(
            "application.retriever.classic_rag.LLMCreator.create_llm", captured
        )
        _make_rag(
            model_id="byom-uuid",
            model_user_id="owner",
            decoded_token={"sub": "caller"},
        )
        # Exactly one LLM is constructed during retriever setup.
        assert captured.call_count == 1
        kwargs = captured.call_args.kwargs
        assert kwargs["model_id"] == "byom-uuid"
        assert kwargs["model_user_id"] == "owner"
        # Caller identity still flows so non-BYOM paths keep working.
        assert kwargs["decoded_token"] == {"sub": "caller"}

    def test_default_model_user_id_is_none(self, mock_llm, monkeypatch):
        captured = Mock(return_value=mock_llm)
        monkeypatch.setattr(
            "application.retriever.classic_rag.LLMCreator.create_llm", captured
        )
        _make_rag()  # no model_user_id override
        # Without an explicit override, resolution scope stays unset.
        assert captured.call_args.kwargs["model_user_id"] is None
@pytest.mark.unit
class TestClassicRAGGetData:
def test_chunks_zero_returns_empty(self, _patch_llm_creator):

View File

@@ -99,11 +99,11 @@ class TestRunAgentLogic:
"application.core.model_utils.get_default_model_id", lambda: "gpt-4"
)
monkeypatch.setattr(
"application.core.model_utils.validate_model_id", lambda m: True
"application.core.model_utils.validate_model_id", lambda m, **_kwargs: True
)
monkeypatch.setattr(
"application.core.model_utils.get_provider_from_model_id",
lambda m: "openai",
lambda m, **_kwargs: "openai",
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
@@ -111,7 +111,7 @@ class TestRunAgentLogic:
)
monkeypatch.setattr(
"application.utils.calculate_doc_token_budget",
lambda model_id=None: 1000,
lambda model_id=None, **_kwargs: 1000,
)
monkeypatch.setattr(
"application.api.answer.services.stream_processor.get_prompt",