Compare commits

...

21 Commits

Author SHA1 Message Date
Alex
0d2a8e11f4 feat: better token serialiser 2026-04-28 02:36:40 +01:00
Alex
f0c39dec23 feat: more logs on stream finish 2026-04-28 02:27:02 +01:00
Alex
552bfe016a fix: better token counting and cache fixes 2026-04-28 01:47:53 +01:00
Alex
a6a5db631b chore: updated roadmap 2026-04-28 01:03:52 +01:00
Alex
8e9f661efc fix: attachments 2026-04-28 00:38:27 +01:00
Alex
82c71be819 feat: better logging 2026-04-28 00:14:43 +01:00
Alex
318de18d43 feat: BYOM (#2433) 2026-04-27 22:09:33 +01:00
Alex
af618de13d Feat models (#2432)
* feat: simplified model structure

* fix: test

* fix: mini docstring stuff
2026-04-26 00:58:29 +01:00
Alex
ef976eeb06 feat: make version check periodic 2026-04-25 14:57:37 +01:00
Alex
9c8ae9d540 feat: redbeat 2026-04-25 14:38:24 +01:00
Alex
7ca33b2b72 feat: OTEL 2026-04-25 13:38:03 +01:00
Manish Madan
d1b9798f62 Merge pull request #2422 from arc53/dependabot/npm_and_yarn/extensions/react-widget/babel/plugin-transform-flow-strip-types-7.27.1
chore(deps): bump @babel/plugin-transform-flow-strip-types from 7.24.6 to 7.27.1 in /extensions/react-widget
2026-04-24 05:13:00 +05:30
dependabot[bot]
ddc3adf3ab chore(deps): bump @babel/plugin-transform-flow-strip-types
Bumps [@babel/plugin-transform-flow-strip-types](https://github.com/babel/babel/tree/HEAD/packages/babel-plugin-transform-flow-strip-types) from 7.24.6 to 7.27.1.
- [Release notes](https://github.com/babel/babel/releases)
- [Changelog](https://github.com/babel/babel/blob/main/CHANGELOG.md)
- [Commits](https://github.com/babel/babel/commits/v7.27.1/packages/babel-plugin-transform-flow-strip-types)

---
updated-dependencies:
- dependency-name: "@babel/plugin-transform-flow-strip-types"
  dependency-version: 7.27.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-23 23:41:46 +00:00
Manish Madan
a4991d01ac Merge pull request #2421 from arc53/dependabot/npm_and_yarn/extensions/react-widget/typescript-6.0.3
chore(deps-dev): bump typescript from 5.9.3 to 6.0.3 in /extensions/react-widget
2026-04-24 05:09:58 +05:30
ManishMadan2882
87fd1bd359 chore(deps-dev): bump @typescript-eslint/{eslint-plugin,parser} to ^8.58.0 for TS 6 support 2026-04-24 05:08:17 +05:30
dependabot[bot]
c71e986d34 chore(deps-dev): bump typescript in /extensions/react-widget
Bumps [typescript](https://github.com/microsoft/TypeScript) from 5.9.3 to 6.0.3.
- [Release notes](https://github.com/microsoft/TypeScript/releases)
- [Commits](https://github.com/microsoft/TypeScript/compare/v5.9.3...v6.0.3)

---
updated-dependencies:
- dependency-name: typescript
  dependency-version: 6.0.3
  dependency-type: direct:development
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-23 23:18:29 +00:00
Manish Madan
a2a06c569e Merge pull request #2419 from arc53/dependabot/npm_and_yarn/extensions/react-widget/prettier-3.8.3
chore(deps-dev): bump prettier from 3.8.1 to 3.8.3 in /extensions/react-widget
2026-04-24 04:34:37 +05:30
Alex
c5f00a1d1b fix: remove old extension in external repos 2026-04-23 23:27:38 +01:00
Alex
2a15bb0102 chore: delete old extensions 2026-04-23 22:55:03 +01:00
Alex
c06888bc86 feat: asgi and search service (#2424)
* feat: asgi and search service

* feat: asgi and mcp tool server

* fix: asgi issues

* fix: mini cors hardening
2026-04-23 12:21:39 +01:00
dependabot[bot]
65460b0c03 chore(deps-dev): bump prettier in /extensions/react-widget
Bumps [prettier](https://github.com/prettier/prettier) from 3.8.1 to 3.8.3.
- [Release notes](https://github.com/prettier/prettier/releases)
- [Changelog](https://github.com/prettier/prettier/blob/main/CHANGELOG.md)
- [Commits](https://github.com/prettier/prettier/compare/3.8.1...3.8.3)

---
updated-dependencies:
- dependency-name: prettier
  dependency-version: 3.8.3
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-20 22:35:52 +00:00
200 changed files with 13663 additions and 15730 deletions

View File

@@ -35,8 +35,5 @@ MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
# User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration).
# Standard Postgres URI — `postgres://` and `postgresql://` both work.
# Leave unset while the migration is still being rolled out; the app will
# fall back to MongoDB for user data until POSTGRES_URI is configured.
# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
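A minimal sketch of the fallback described in these comments (the store classes are hypothetical stand-ins, not the app's actual wiring):
```python
import os

class MongoUserStore: ...          # stand-in for the existing Mongo path
class PostgresUserStore:           # stand-in for the new Postgres path
    def __init__(self, uri: str) -> None:
        self.uri = uri

def get_user_data_store():
    """Use Postgres when POSTGRES_URI is set; otherwise stay on MongoDB."""
    uri = os.environ.get("POSTGRES_URI")  # postgres:// or postgresql:// both work
    return PostgresUserStore(uri) if uri else MongoUserStore()
```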

View File

@@ -37,6 +37,22 @@ Run the Flask API (if needed):
flask --app application/app.py run --host=0.0.0.0 --port=7091
```
That's the fast inner-loop option — quick startup, the Werkzeug interactive
debugger still works, and it hot-reloads on source changes. It serves the
Flask routes only (`/api/*`, `/stream`, etc.).
If you need to exercise the full ASGI stack (the `/mcp` FastMCP endpoint) or
want to match the production runtime exactly, run the ASGI composition under
uvicorn instead:
```bash
uvicorn application.asgi:asgi_app --host 0.0.0.0 --port 7091 --reload
```
Production uses `gunicorn -k uvicorn_worker.UvicornWorker` against the same
`application.asgi:asgi_app` target; see `application/Dockerfile` for the
full flag set.
Run the Celery worker in a separate terminal (if needed):
```bash
@@ -99,7 +115,7 @@ vale .
- `frontend/`: Vite + React + TypeScript application.
- `frontend/src/`: main UI code, including `components`, `conversation`, `hooks`, `locale`, `settings`, `upload`, and Redux store wiring in `store.ts`.
- `docs/`: separate documentation site built with Next.js/Nextra.
- `extensions/`: integrations and widgets such as Chatwoot, Chrome, Discord, React widget, Slack bot, and web widget.
- `extensions/`: integrations and widgets — currently the Chatwoot webhook bridge and the React widget (published to npm as `docsgpt`). The Discord bot, Slack bot, and Chrome extension have been moved to their own repos under `arc53/`.
- `deployment/`: Docker Compose variants and Kubernetes manifests.
## Coding rules

View File

@@ -47,11 +47,13 @@
</ul>
## Roadmap
- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
- [x] Deep Agents ( October 2025 )
- [x] Prompt Templating ( October 2025 )
- [x] Full api tooling ( Dec 2025 )
- [ ] Agent scheduling ( Jan 2026 )
- [x] Agent Workflow Builder with conditional nodes ( February 2026 )
- [x] SharePoint & Confluence connectors ( March-April 2026 )
- [x] Research mode ( March 2026 )
- [x] Postgres migration for user data ( April 2026 )
- [x] OpenTelemetry observability ( April 2026 )
- [x] Bring Your Own Model (BYOM) ( April 2026 )
- [ ] Agent scheduling (RedBeat-backed) ( Q2 2026 )
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!

View File

@@ -88,5 +88,15 @@ EXPOSE 7091
# Switch to non-root user
USER appuser
# Start Gunicorn
CMD ["gunicorn", "-w", "1", "--timeout", "120", "--bind", "0.0.0.0:7091", "--preload", "application.wsgi:app"]
CMD ["gunicorn", \
"-w", "1", \
"-k", "uvicorn_worker.UvicornWorker", \
"--bind", "0.0.0.0:7091", \
"--timeout", "180", \
"--graceful-timeout", "120", \
"--keep-alive", "5", \
"--worker-tmp-dir", "/dev/shm", \
"--max-requests", "1000", \
"--max-requests-jitter", "100", \
"--config", "application/gunicorn_conf.py", \
"application.asgi:asgi_app"]

View File

@@ -42,6 +42,7 @@ class BaseAgent(ABC):
llm_handler=None,
tool_executor: Optional[ToolExecutor] = None,
backup_models: Optional[List[str]] = None,
model_user_id: Optional[str] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
@@ -52,10 +53,13 @@ class BaseAgent(ABC):
self.prompt = prompt
self.decoded_token = decoded_token or {}
self.user: str = self.decoded_token.get("sub")
# BYOM-resolution scope: owner for shared agents, caller for
# caller-owned BYOM, None for built-ins. Falls back to self.user
# for worker/legacy callers that don't thread model_user_id.
self.model_user_id = model_user_id
self.tools: List[Dict] = []
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
# Dependency injection for LLM — fall back to creating if not provided
if llm is not None:
self.llm = llm
else:
@@ -67,8 +71,16 @@ class BaseAgent(ABC):
model_id=model_id,
agent_id=agent_id,
backup_models=backup_models,
model_user_id=model_user_id,
)
# For BYOM, registry id (UUID) differs from upstream model id
# (e.g. ``mistral-large-latest``). LLMCreator resolved this onto
# the LLM instance; cache it for subsequent gen calls.
self.upstream_model_id = (
getattr(self.llm, "model_id", None) or model_id
)
self.retrieved_docs = retrieved_docs or []
if llm_handler is not None:
@@ -306,7 +318,9 @@ class BaseAgent(ABC):
try:
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(self.model_id)
context_limit = get_token_limit(
self.model_id, user_id=self.model_user_id or self.user
)
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
if current_tokens >= threshold:
@@ -325,7 +339,9 @@ class BaseAgent(ABC):
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(self.model_id)
context_limit = get_token_limit(
self.model_id, user_id=self.model_user_id or self.user
)
percentage = (current_tokens / context_limit) * 100
if current_tokens >= context_limit:
@@ -387,7 +403,9 @@ class BaseAgent(ABC):
)
system_prompt = system_prompt + compression_context
context_limit = get_token_limit(self.model_id)
context_limit = get_token_limit(
self.model_id, user_id=self.model_user_id or self.user
)
system_tokens = num_tokens_from_string(system_prompt)
safety_buffer = int(context_limit * 0.1)
@@ -497,7 +515,10 @@ class BaseAgent(ABC):
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
self._validate_context_size(messages)
gen_kwargs = {"model": self.model_id, "messages": messages}
# Use the upstream id resolved by LLMCreator (see __init__).
# Built-in models: same as self.model_id. BYOM: the user's
# typed model name, not the internal UUID.
gen_kwargs = {"model": self.upstream_model_id, "messages": messages}
if self.attachments:
gen_kwargs["_usage_attachments"] = self.attachments

View File

@@ -312,7 +312,7 @@ class ResearchAgent(BaseAgent):
try:
response = self.llm.gen(
model=self.model_id,
model=self.upstream_model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
@@ -390,7 +390,7 @@ class ResearchAgent(BaseAgent):
try:
response = self.llm.gen(
model=self.model_id,
model=self.upstream_model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
@@ -506,7 +506,7 @@ class ResearchAgent(BaseAgent):
try:
response = self.llm.gen(
model=self.model_id,
model=self.upstream_model_id,
messages=messages,
tools=self.tools if self.tools else None,
)
@@ -537,7 +537,7 @@ class ResearchAgent(BaseAgent):
)
try:
response = self.llm.gen(
model=self.model_id, messages=messages, tools=None
model=self.upstream_model_id, messages=messages, tools=None
)
self._track_tokens(self._snapshot_llm_tokens())
text = self._extract_text(response)
@@ -664,7 +664,7 @@ class ResearchAgent(BaseAgent):
]
llm_response = self.llm.gen_stream(
model=self.model_id, messages=messages, tools=None
model=self.upstream_model_id, messages=messages, tools=None
)
if log_context:

View File

@@ -274,7 +274,14 @@ class ToolExecutor:
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
logger.error(error_message)
logger.error(
"tool_call_parse_failed",
extra={
"llm_class_name": llm_class_name,
"llm_tool_name": llm_name,
"call_id": call_id,
},
)
tool_call_data = {
"tool_name": "unknown",
@@ -289,7 +296,15 @@ class ToolExecutor:
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(error_message)
logger.error(
"tool_id_not_found",
extra={
"tool_id": tool_id,
"llm_tool_name": llm_name,
"call_id": call_id,
"available_tool_count": len(tools_dict),
},
)
tool_call_data = {
"tool_name": "unknown",
@@ -356,7 +371,15 @@ class ToolExecutor:
f"Failed to load tool '{tool_data.get('name')}' (tool_id key={tool_id}): "
"missing 'id' on tool row."
)
logger.error(error_message)
logger.error(
"tool_load_failed",
extra={
"tool_name": tool_data.get("name"),
"tool_id": tool_id,
"action_name": action_name,
"call_id": call_id,
},
)
tool_call_data["result"] = error_message
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
@@ -451,10 +474,12 @@ class ToolExecutor:
row_id = tool_data.get("id")
if not row_id:
logger.error(
"Tool data missing 'id' for tool name=%s (enumerate-key tool_id=%s); "
"skipping load to avoid binding a non-UUID downstream.",
tool_data.get("name"),
tool_id,
"tool_missing_row_id",
extra={
"tool_name": tool_data.get("name"),
"tool_id": tool_id,
"action_name": action_name,
},
)
return None
tool_config["tool_id"] = str(row_id)

View File

@@ -39,6 +39,7 @@ class InternalSearchTool(Tool):
chunks=int(self.config.get("chunks", 2)),
doc_token_limit=int(self.config.get("doc_token_limit", 50000)),
model_id=self.config.get("model_id", "docsgpt-local"),
model_user_id=self.config.get("model_user_id"),
user_api_key=self.config.get("user_api_key"),
agent_id=self.config.get("agent_id"),
llm_name=self.config.get("llm_name", settings.LLM_PROVIDER),
@@ -435,6 +436,7 @@ def build_internal_tool_config(
chunks: int = 2,
doc_token_limit: int = 50000,
model_id: str = "docsgpt-local",
model_user_id: Optional[str] = None,
user_api_key: Optional[str] = None,
agent_id: Optional[str] = None,
llm_name: str = None,
@@ -449,6 +451,7 @@ def build_internal_tool_config(
"chunks": chunks,
"doc_token_limit": doc_token_limit,
"model_id": model_id,
"model_user_id": model_user_id,
"user_api_key": user_api_key,
"agent_id": agent_id,
"llm_name": llm_name or settings.LLM_PROVIDER,

View File

@@ -211,15 +211,26 @@ class WorkflowEngine:
node_config.json_schema, node.title
)
node_model_id = node_config.model_id or self.agent.model_id
# Inherit BYOM scope from parent agent so owner-stored BYOM
# resolves on shared workflows.
node_user_id = getattr(self.agent, "model_user_id", None) or (
self.agent.decoded_token.get("sub")
if isinstance(self.agent.decoded_token, dict)
else None
)
node_llm_name = (
node_config.llm_name
or get_provider_from_model_id(node_model_id or "")
or get_provider_from_model_id(
node_model_id or "", user_id=node_user_id
)
or self.agent.llm_name
)
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
if node_json_schema and node_model_id:
model_capabilities = get_model_capabilities(node_model_id)
model_capabilities = get_model_capabilities(
node_model_id, user_id=node_user_id
)
if model_capabilities and not model_capabilities.get(
"supports_structured_output", False
):
@@ -232,6 +243,7 @@ class WorkflowEngine:
"endpoint": self.agent.endpoint,
"llm_name": node_llm_name,
"model_id": node_model_id,
"model_user_id": getattr(self.agent, "model_user_id", None),
"api_key": node_api_key,
"tool_ids": node_config.tools,
"prompt": node_config.system_prompt,

View File

@@ -0,0 +1,65 @@
"""0003 user_custom_models — per-user OpenAI-compatible model registrations.
Revision ID: 0003_user_custom_models
Revises: 0002_app_metadata
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0003_user_custom_models"
down_revision: Union[str, None] = "0002_app_metadata"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
CREATE TABLE user_custom_models (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
upstream_model_id TEXT NOT NULL,
display_name TEXT NOT NULL,
description TEXT NOT NULL DEFAULT '',
base_url TEXT NOT NULL,
api_key_encrypted TEXT NOT NULL,
capabilities JSONB NOT NULL DEFAULT '{}'::jsonb,
enabled BOOLEAN NOT NULL DEFAULT true,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
op.execute(
"CREATE INDEX user_custom_models_user_id_idx "
"ON user_custom_models (user_id);"
)
# Mirror the project-wide invariants set up in 0001_initial:
# * user_id FK with ON DELETE RESTRICT (deferrable),
# * ensure_user_exists() trigger so the parent users row autocreates,
# * set_updated_at() trigger.
op.execute(
"ALTER TABLE user_custom_models "
"ADD CONSTRAINT user_custom_models_user_id_fk "
"FOREIGN KEY (user_id) REFERENCES users(user_id) "
"ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
)
op.execute(
"CREATE TRIGGER user_custom_models_ensure_user "
"BEFORE INSERT OR UPDATE OF user_id ON user_custom_models "
"FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
)
op.execute(
"CREATE TRIGGER user_custom_models_set_updated_at "
"BEFORE UPDATE ON user_custom_models "
"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
"EXECUTE FUNCTION set_updated_at();"
)
def downgrade() -> None:
op.execute("DROP TABLE IF EXISTS user_custom_models;")

View File

@@ -177,6 +177,7 @@ class BaseAnswerResource:
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
model_user_id: Optional[str] = None,
_continuation: Optional[Dict] = None,
) -> Generator[str, None, None]:
"""
@@ -289,8 +290,18 @@ class BaseAnswerResource:
# conversation if this is the first turn.
if not conversation_id and should_save_conversation:
try:
# Use model-owner scope so shared-agent
# owner-BYOM resolves to its registered plugin.
provider = (
get_provider_from_model_id(model_id)
get_provider_from_model_id(
model_id,
user_id=model_user_id
or (
decoded_token.get("sub")
if decoded_token
else None
),
)
if model_id
else settings.LLM_PROVIDER
)
@@ -304,6 +315,7 @@ class BaseAnswerResource:
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
conversation_id = (
self.conversation_service.save_conversation(
@@ -340,6 +352,9 @@ class BaseAnswerResource:
tool_schemas=getattr(agent, "tools", []),
agent_config={
"model_id": model_id or self.default_model_id,
# Persist BYOM scope so resume doesn't
# fall back to caller's layer.
"model_user_id": model_user_id,
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
"api_key": getattr(agent, "api_key", None),
"user_api_key": user_api_key,
@@ -370,8 +385,14 @@ class BaseAnswerResource:
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
# Run under model-owner scope so title-gen LLM inside
# save_conversation uses the owner's BYOM provider/key.
provider = (
get_provider_from_model_id(model_id)
get_provider_from_model_id(
model_id,
user_id=model_user_id
or (decoded_token.get("sub") if decoded_token else None),
)
if model_id
else settings.LLM_PROVIDER
)
@@ -384,6 +405,7 @@ class BaseAnswerResource:
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
if should_save_conversation:
@@ -481,12 +503,34 @@ class BaseAnswerResource:
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
# Mirror the normal-path provider resolution so the
# partial-save title LLM uses the model-owner's BYOM
# registration (shared-agent dispatch) rather than
# the deployment default with the instance api key.
provider = (
get_provider_from_model_id(
model_id,
user_id=model_user_id
or (
decoded_token.get("sub")
if decoded_token
else None
),
)
if model_id
else settings.LLM_PROVIDER
)
sys_api_key = get_api_key_for_provider(
provider or settings.LLM_PROVIDER
)
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
provider or settings.LLM_PROVIDER,
api_key=sys_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
self.conversation_service.save_conversation(
conversation_id,

View File

@@ -1,21 +1,21 @@
import logging
from typing import Any, Dict, List
from flask import make_response, request
from flask_restx import fields, Resource
from application.api.answer.routes.base import answer_ns
from application.core.settings import settings
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly
from application.vectorstore.vector_creator import VectorCreator
from application.services.search_service import (
InvalidAPIKey,
SearchFailed,
search,
)
logger = logging.getLogger(__name__)
@answer_ns.route("/api/search")
class SearchResource(Resource):
"""Fast search endpoint for retrieving relevant documents"""
"""Fast search endpoint for retrieving relevant documents."""
search_model = answer_ns.model(
"SearchModel",
@@ -32,102 +32,10 @@ class SearchResource(Resource):
},
)
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
"""Get source IDs connected to the API key/agent."""
with db_readonly() as conn:
agent_data = AgentsRepository(conn).find_by_key(api_key)
if not agent_data:
return []
source_ids: List[str] = []
# extra_source_ids is a PG ARRAY(UUID) of source UUIDs.
extra = agent_data.get("extra_source_ids") or []
for src in extra:
if src:
source_ids.append(str(src))
if not source_ids:
single = agent_data.get("source_id")
if single:
source_ids.append(str(single))
return source_ids
def _search_vectorstores(
self, query: str, source_ids: List[str], chunks: int
) -> List[Dict[str, Any]]:
"""Search across vectorstores and return results"""
if not source_ids:
return []
results = []
chunks_per_source = max(1, chunks // len(source_ids))
seen_texts = set()
for source_id in source_ids:
if not source_id or not source_id.strip():
continue
try:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
)
docs = docsearch.search(query, k=chunks_per_source * 2)
for doc in docs:
if len(results) >= chunks:
break
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
page_content = doc.page_content
metadata = doc.metadata
else:
page_content = doc.get("text", doc.get("page_content", ""))
metadata = doc.get("metadata", {})
# Skip duplicates
text_hash = hash(page_content[:200])
if text_hash in seen_texts:
continue
seen_texts.add(text_hash)
title = metadata.get(
"title", metadata.get("post_title", "")
)
if not isinstance(title, str):
title = str(title) if title else ""
# Clean up title
if title:
title = title.split("/")[-1]
else:
# Use filename or first part of content as title
title = metadata.get("filename", page_content[:50] + "...")
source = metadata.get("source", source_id)
results.append({
"text": page_content,
"title": title,
"source": source,
})
if len(results) >= chunks:
break
except Exception as e:
logger.error(
f"Error searching vectorstore {source_id}: {e}",
exc_info=True,
)
continue
return results[:chunks]
@answer_ns.expect(search_model)
@answer_ns.doc(description="Search for relevant documents based on query")
def post(self):
data = request.get_json()
data = request.get_json() or {}
question = data.get("question")
api_key = data.get("api_key")
@@ -135,32 +43,13 @@ class SearchResource(Resource):
if not question:
return make_response({"error": "question is required"}, 400)
if not api_key:
return make_response({"error": "api_key is required"}, 400)
# Validate API key
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
if not agent:
return make_response({"error": "Invalid API key"}, 401)
try:
# Get sources connected to this API key
source_ids = self._get_sources_from_api_key(api_key)
if not source_ids:
return make_response([], 200)
# Perform search
results = self._search_vectorstores(question, source_ids, chunks)
return make_response(results, 200)
except Exception as e:
logger.error(
f"/api/search - error: {str(e)}",
extra={"error": str(e)},
exc_info=True,
)
return make_response(search(api_key, question, chunks), 200)
except InvalidAPIKey:
return make_response({"error": "Invalid API key"}, 401)
except SearchFailed:
logger.exception("/api/search failed")
return make_response({"error": "Search failed"}, 500)

View File

@@ -109,6 +109,7 @@ class StreamResource(Resource, BaseAnswerResource):
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
@@ -145,6 +146,7 @@ class StreamResource(Resource, BaseAnswerResource):
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
),
mimetype="text/event-stream",
)

View File

@@ -49,6 +49,7 @@ class CompressionOrchestrator:
model_id: str,
decoded_token: Dict[str, Any],
current_query_tokens: int = 500,
model_user_id: Optional[str] = None,
) -> CompressionResult:
"""
Check if compression is needed and perform it if so.
@@ -57,16 +58,18 @@ class CompressionOrchestrator:
Args:
conversation_id: Conversation ID
user_id: User ID
user_id: Caller's user id — used for conversation access checks
model_id: Model being used for conversation
decoded_token: User's decoded JWT token
current_query_tokens: Estimated tokens for current query
model_user_id: BYOM-resolution scope (model owner); defaults
to ``user_id`` for built-in / caller-owned models.
Returns:
CompressionResult with summary and recent queries
"""
try:
# Load conversation
# Conversation row is owned by the caller, not the model owner.
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
@@ -77,9 +80,14 @@ class CompressionOrchestrator:
)
return CompressionResult.failure("Conversation not found")
# Check if compression is needed
# Use model-owner scope so per-user BYOM context windows
# (e.g. 8k) compute the threshold against the right limit.
registry_user_id = model_user_id or user_id
if not self.threshold_checker.should_compress(
conversation, model_id, current_query_tokens
conversation,
model_id,
current_query_tokens,
user_id=registry_user_id,
):
# No compression needed, return full history
queries = conversation.get("queries", [])
@@ -87,7 +95,12 @@ class CompressionOrchestrator:
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
conversation_id,
conversation,
model_id,
decoded_token,
user_id=user_id,
model_user_id=model_user_id,
)
except Exception as e:
@@ -102,6 +115,8 @@ class CompressionOrchestrator:
conversation: Dict[str, Any],
model_id: str,
decoded_token: Dict[str, Any],
user_id: Optional[str] = None,
model_user_id: Optional[str] = None,
) -> CompressionResult:
"""
Perform the actual compression operation.
@@ -111,6 +126,8 @@ class CompressionOrchestrator:
conversation: Conversation document
model_id: Model ID for conversation
decoded_token: User token
user_id: Caller's id (for conversation reload after compression)
model_user_id: BYOM-resolution scope (model owner)
Returns:
CompressionResult
@@ -123,11 +140,17 @@ class CompressionOrchestrator:
else model_id
)
# Get provider and API key for compression model
provider = get_provider_from_model_id(compression_model)
# Use model-owner scope so provider/api_key resolves to the
# owner's BYOM record (shared-agent dispatch).
caller_user_id = user_id
if caller_user_id is None and isinstance(decoded_token, dict):
caller_user_id = decoded_token.get("sub")
registry_user_id = model_user_id or caller_user_id
provider = get_provider_from_model_id(
compression_model, user_id=registry_user_id
)
api_key = get_api_key_for_provider(provider)
# Create compression LLM
compression_llm = LLMCreator.create_llm(
provider,
api_key=api_key,
@@ -135,6 +158,7 @@ class CompressionOrchestrator:
decoded_token=decoded_token,
model_id=compression_model,
agent_id=conversation.get("agent_id"),
model_user_id=registry_user_id,
)
# Create compression service with DB update capability
@@ -167,9 +191,12 @@ class CompressionOrchestrator:
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
# Reload conversation with updated metadata
# Reload under caller (conversation is owned by caller).
reload_user_id = caller_user_id
if reload_user_id is None and isinstance(decoded_token, dict):
reload_user_id = decoded_token.get("sub")
conversation = self.conversation_service.get_conversation(
conversation_id, user_id=decoded_token.get("sub")
conversation_id, user_id=reload_user_id
)
# Get compressed context
@@ -192,16 +219,21 @@ class CompressionOrchestrator:
model_id: str,
decoded_token: Dict[str, Any],
current_conversation: Optional[Dict[str, Any]] = None,
model_user_id: Optional[str] = None,
) -> CompressionResult:
"""
Perform compression during tool execution.
Args:
conversation_id: Conversation ID
user_id: User ID
user_id: Caller's user id — used for conversation access checks
model_id: Model ID
decoded_token: User token
current_conversation: Pre-loaded conversation (optional)
model_user_id: BYOM-resolution scope (model owner). For
shared-agent dispatch this is the agent owner; defaults
to ``user_id`` so built-in / caller-owned models are
unaffected.
Returns:
CompressionResult
@@ -223,7 +255,12 @@ class CompressionOrchestrator:
# Perform compression
return self._perform_compression(
conversation_id, conversation, model_id, decoded_token
conversation_id,
conversation,
model_id,
decoded_token,
user_id=user_id,
model_user_id=model_user_id,
)
except Exception as e:

View File

@@ -106,8 +106,13 @@ class CompressionService:
f"using model {self.model_id}"
)
# See note in conversation_service.py: ``self.model_id`` is
# the registry id (UUID for BYOM); the LLM's own model_id is
# what the provider's API actually expects.
response = self.llm.gen(
model=self.model_id, messages=messages, max_tokens=4000
model=getattr(self.llm, "model_id", None) or self.model_id,
messages=messages,
max_tokens=4000,
)
# Extract summary from response

View File

@@ -30,6 +30,7 @@ class CompressionThresholdChecker:
conversation: Dict[str, Any],
model_id: str,
current_query_tokens: int = 500,
user_id: str | None = None,
) -> bool:
"""
Determine if compression is needed.
@@ -38,6 +39,8 @@ class CompressionThresholdChecker:
conversation: Full conversation document
model_id: Target model for this request
current_query_tokens: Estimated tokens for current query
user_id: Owner — needed so per-user BYOM custom-model UUIDs
resolve when looking up the context window.
Returns:
True if tokens >= threshold% of context window
@@ -48,7 +51,7 @@ class CompressionThresholdChecker:
total_tokens += current_query_tokens
# Get context window limit for model
context_limit = get_token_limit(model_id)
context_limit = get_token_limit(model_id, user_id=user_id)
# Calculate threshold
threshold = int(context_limit * self.threshold_percentage)
@@ -73,20 +76,24 @@ class CompressionThresholdChecker:
logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
return False
def check_message_tokens(self, messages: list, model_id: str) -> bool:
def check_message_tokens(
self, messages: list, model_id: str, user_id: str | None = None
) -> bool:
"""
Check if message list exceeds threshold.
Args:
messages: List of message dicts
model_id: Target model
user_id: Owner — needed so per-user BYOM custom-model UUIDs
resolve when looking up the context window.
Returns:
True if at or above threshold
"""
try:
current_tokens = TokenCounter.count_message_tokens(messages)
context_limit = get_token_limit(model_id)
context_limit = get_token_limit(model_id, user_id=user_id)
threshold = int(context_limit * self.threshold_percentage)
if current_tokens >= threshold:
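# Worked example (illustrative, assuming threshold_percentage = 0.8):
# a default 128_000-token window triggers compression at >= 102_400
# tokens, whereas an 8_192-token BYOM window triggers at >= 6_553.
# This is why the user-scoped get_token_limit lookup above matters.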

View File

@@ -12,6 +12,12 @@ logger = logging.getLogger(__name__)
class TokenCounter:
"""Centralized token counting for conversations and messages."""
# Per-image token estimate. Provider tokenizers vary widely
# (Gemini ~258, GPT-4o 85-1500, Claude ~1500) and the actual cost
# depends on resolution/detail we can't see here. Errs slightly high
# so the threshold check stays conservative.
_IMAGE_PART_TOKEN_ESTIMATE = 1500
@staticmethod
def count_message_tokens(messages: List[Dict]) -> int:
"""
@@ -29,12 +35,36 @@ class TokenCounter:
if isinstance(content, str):
total_tokens += num_tokens_from_string(content)
elif isinstance(content, list):
# Handle structured content (tool calls, etc.)
# Handle structured content (tool calls, image parts, etc.)
for item in content:
if isinstance(item, dict):
total_tokens += num_tokens_from_string(str(item))
total_tokens += TokenCounter._count_content_part(item)
return total_tokens
@staticmethod
def _count_content_part(item: Dict) -> int:
# Image/file attachments are billed by the provider per image,
# not proportional to the inline bytes/base64 string.
# ``str(item)`` on a 1MB image inflates the count by ~10000x,
# which trips spurious compression and overflows downstream
# input limits.
item_type = item.get("type")
if "files" in item:
files = item.get("files")
count = len(files) if isinstance(files, list) and files else 1
return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE * count
if "image_url" in item or item_type in {
"image",
"image_url",
"input_image",
"file",
}:
return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE
return num_tokens_from_string(str(item))
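# Illustrative arithmetic: a content list with one text part and one
# image_url part now counts as num_tokens(text) + 1500, rather than
# num_tokens(str(image_part)), which for a 1MB base64 payload would be
# hundreds of thousands of tokens. A part like {"files": [f1, f2]}
# counts as 2 * 1500.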
@staticmethod
def count_query_tokens(
queries: List[Dict[str, Any]], include_tool_calls: bool = True

View File

@@ -136,8 +136,14 @@ class ConversationService:
},
]
# ``model_id`` here is the registry id (a UUID for BYOM
# records). The LLM's own ``model_id`` is the upstream name
# LLMCreator resolved at construction time — that's what
# the provider's API expects. Built-ins are unaffected.
completion = llm.gen(
model=model_id, messages=messages_summary, max_tokens=500
model=getattr(llm, "model_id", None) or model_id,
messages=messages_summary,
max_tokens=500,
)
if not completion or not completion.strip():

View File

@@ -121,6 +121,8 @@ class StreamProcessor:
self.agent_id = self.data.get("agent_id")
self.agent_key = None
self.model_id: Optional[str] = None
# BYOM-resolution scope, set by _validate_and_set_model.
self.model_user_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
self.conversation_service
@@ -191,16 +193,23 @@ class StreamProcessor:
for query in conversation.get("queries", [])
]
else:
# model_user_id keeps history trim aligned with the BYOM's
# actual context window instead of the default 128k.
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")), model_id=self.model_id
json.loads(self.data.get("history", "[]")),
model_id=self.model_id,
user_id=self.model_user_id,
)
def _handle_compression(self, conversation: Dict[str, Any]):
"""Handle conversation compression logic using orchestrator."""
try:
# initial_user_id for conversation access; model_user_id
# for BYOM context-window / provider lookups.
result = self.compression_orchestrator.compress_if_needed(
conversation_id=self.conversation_id,
user_id=self.initial_user_id,
model_user_id=self.model_user_id,
model_id=self.model_id,
decoded_token=self.decoded_token,
)
@@ -284,11 +293,18 @@ class StreamProcessor:
from application.core.model_settings import ModelRegistry
requested_model = self.data.get("model_id")
# Caller picks from their own BYOM layer; agent defaults resolve
# under the owner's layer (shared agents have caller != owner).
caller_user_id = self.initial_user_id
owner_user_id = self.agent_config.get("user_id") or caller_user_id
if requested_model:
if not validate_model_id(requested_model):
if not validate_model_id(requested_model, user_id=caller_user_id):
registry = ModelRegistry.get_instance()
available_models = [m.id for m in registry.get_enabled_models()]
available_models = [
m.id
for m in registry.get_enabled_models(user_id=caller_user_id)
]
raise ValueError(
f"Invalid model_id '{requested_model}'. "
f"Available models: {', '.join(available_models[:5])}"
@@ -299,12 +315,17 @@ class StreamProcessor:
)
)
self.model_id = requested_model
self.model_user_id = caller_user_id
else:
agent_default_model = self.agent_config.get("default_model_id", "")
if agent_default_model and validate_model_id(agent_default_model):
if agent_default_model and validate_model_id(
agent_default_model, user_id=owner_user_id
):
self.model_id = agent_default_model
self.model_user_id = owner_user_id
else:
self.model_id = get_default_model_id()
self.model_user_id = None
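# Net effect of the branches above (summary): an explicitly requested
# model resolves in the caller's layer; an agent default resolves in
# the owner's layer (shared agents can have caller != owner); the
# registry default is a built-in, so no per-user scope is needed.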
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control."""
@@ -514,6 +535,10 @@ class StreamProcessor:
"allow_system_prompt_override": self._agent_data.get(
"allow_system_prompt_override", False
),
# Owner identity — _validate_and_set_model reads this to
# resolve owner-stored BYOM default_model_id against the
# owner's per-user model layer rather than the caller's.
"user_id": self._agent_data.get("user"),
}
)
@@ -561,7 +586,13 @@ class StreamProcessor:
def _configure_retriever(self):
"""Assemble retriever config with precedence: request > agent > default."""
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
# BYOM scope: owner for shared-agent BYOM, caller for own BYOM,
# None for built-ins. Without ``user_id`` here, the doc budget
# falls back to settings.DEFAULT_LLM_TOKEN_LIMIT and overfills
# the upstream context window for any small (e.g. 8k/32k) BYOM.
doc_token_limit = calculate_doc_token_budget(
model_id=self.model_id, user_id=self.model_user_id
)
# Start with defaults
retriever_name = "classic"
@@ -612,6 +643,7 @@ class StreamProcessor:
chunks=self.retriever_config["chunks"],
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
model_id=self.model_id,
model_user_id=self.model_user_id,
user_api_key=self.agent_config["user_api_key"],
agent_id=self.agent_id,
decoded_token=self.decoded_token,
@@ -903,6 +935,11 @@ class StreamProcessor:
agent_config = state["agent_config"]
model_id = agent_config.get("model_id")
# BYOM scope captured at initial dispatch. None for built-ins or
# caller-owned BYOM where decoded_token['sub'] is already the
# right scope; non-None for shared-agent owner BYOM where the
# caller's identity differs from the model owner's.
model_user_id = agent_config.get("model_user_id")
llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
api_key = agent_config.get("api_key")
user_api_key = agent_config.get("user_api_key")
@@ -920,6 +957,7 @@ class StreamProcessor:
decoded_token=self.decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
tool_executor = ToolExecutor(
@@ -949,6 +987,7 @@ class StreamProcessor:
"endpoint": "stream",
"llm_name": llm_name,
"model_id": model_id,
"model_user_id": model_user_id,
"api_key": system_api_key,
"agent_id": agent_id,
"user_api_key": user_api_key,
@@ -971,6 +1010,15 @@ class StreamProcessor:
# Store config for the route layer
self.model_id = model_id
# Mirror ``model_user_id`` back onto the processor so the route
# layer (StreamResource) reads the owner scope captured at
# initial dispatch. Without this, ``processor.model_user_id``
# stays at the __init__ default (None) and complete_stream
# falls back to the caller's sub: the post-resume title-LLM
# save misses the owner's BYOM layer, and any second tool
# pause persists ``model_user_id=None`` — losing owner scope
# for every subsequent resume of this conversation.
self.model_user_id = model_user_id
self.agent_id = agent_id
self.agent_config["user_api_key"] = user_api_key
self.conversation_id = conversation_id
@@ -1022,8 +1070,11 @@ class StreamProcessor:
tools_data=tools_data,
)
# Use the user_id that resolved the model so owner-scoped BYOM
# records dispatch correctly on shared-agent requests.
model_user_id = getattr(self, "model_user_id", self.initial_user_id)
provider = (
get_provider_from_model_id(self.model_id)
get_provider_from_model_id(self.model_id, user_id=model_user_id)
if self.model_id
else settings.LLM_PROVIDER
)
@@ -1048,6 +1099,8 @@ class StreamProcessor:
model_id=self.model_id,
agent_id=self.agent_id,
backup_models=backup_models,
# Owner-scope on shared-agent BYOM dispatch.
model_user_id=model_user_id,
)
llm_handler = LLMHandlerCreator.create_handler(
provider if provider else "default"
@@ -1070,6 +1123,7 @@ class StreamProcessor:
"endpoint": "stream",
"llm_name": provider or settings.LLM_PROVIDER,
"model_id": self.model_id,
"model_user_id": self.model_user_id,
"api_key": system_api_key,
"agent_id": self.agent_id,
"user_api_key": self.agent_config["user_api_key"],
@@ -1097,6 +1151,7 @@ class StreamProcessor:
"doc_token_limit", 50000
),
"model_id": self.model_id,
"model_user_id": self.model_user_id,
"user_api_key": self.agent_config["user_api_key"],
"agent_id": self.agent_id,
"llm_name": provider or settings.LLM_PROVIDER,

View File

@@ -1,18 +1,135 @@
from flask import current_app, jsonify, make_response
"""Model routes.
- ``GET /api/models`` — list available models for the current user.
Combines the built-in catalog with the user's BYOM records.
- ``GET/POST/PATCH/DELETE /api/user/models[/<id>]`` — CRUD for the
user's own OpenAI-compatible model registrations (BYOM).
- ``POST /api/user/models/<id>/test`` — sanity-check the upstream
endpoint with a tiny request.
Every BYOM endpoint is user-scoped at the repository layer
(every query filters on ``user_id`` from ``request.decoded_token``).
"""
from __future__ import annotations
import logging
import requests
from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource
from application.core.model_settings import ModelRegistry
from application.api import api
from application.core.model_registry import ModelRegistry
from application.security.safe_url import (
UnsafeUserUrlError,
pinned_post,
validate_user_base_url,
)
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
models_ns = Namespace("models", description="Available models", path="/api")
_CONTEXT_WINDOW_MIN = 1_000
_CONTEXT_WINDOW_MAX = 10_000_000
def _user_id_or_401():
decoded_token = request.decoded_token
if not decoded_token:
return None, make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
if not user_id:
return None, make_response(jsonify({"success": False}), 401)
return user_id, None
def _normalize_capabilities(raw) -> dict:
"""Coerce + bound the user-supplied capabilities payload."""
raw = raw or {}
out = {}
if "supports_tools" in raw:
out["supports_tools"] = bool(raw["supports_tools"])
if "supports_structured_output" in raw:
out["supports_structured_output"] = bool(raw["supports_structured_output"])
if "supports_streaming" in raw:
out["supports_streaming"] = bool(raw["supports_streaming"])
if "attachments" in raw:
atts = raw["attachments"] or []
if not isinstance(atts, list):
raise ValueError("'capabilities.attachments' must be a list")
coerced = [str(a) for a in atts]
# Reject unknown aliases at the API boundary so bad payloads
# never reach the registry layer (where lenient expansion just
# drops them). Raw MIME types (containing ``/``) pass through
# unchanged for parity with the built-in YAML schema.
from application.core.model_yaml import builtin_attachment_aliases
aliases = builtin_attachment_aliases()
for entry in coerced:
if "/" in entry:
continue
if entry not in aliases:
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
raise ValueError(
f"unknown attachment alias '{entry}' in "
f"'capabilities.attachments'. Valid aliases: {valid}, "
f"or use a raw MIME type like 'image/png'."
)
out["attachments"] = coerced
if "context_window" in raw:
try:
cw = int(raw["context_window"])
except (TypeError, ValueError):
raise ValueError("'capabilities.context_window' must be an integer")
if not (_CONTEXT_WINDOW_MIN <= cw <= _CONTEXT_WINDOW_MAX):
raise ValueError(
f"'capabilities.context_window' must be between "
f"{_CONTEXT_WINDOW_MIN} and {_CONTEXT_WINDOW_MAX}"
)
out["context_window"] = cw
return out
def _row_to_response(row: dict) -> dict:
"""Wire-format projection — never includes the API key."""
return {
"id": str(row["id"]),
"upstream_model_id": row["upstream_model_id"],
"display_name": row["display_name"],
"description": row.get("description") or "",
"base_url": row["base_url"],
"capabilities": row.get("capabilities") or {},
"enabled": bool(row.get("enabled", True)),
"source": "user",
}
@models_ns.route("/models")
class ModelsListResource(Resource):
def get(self):
"""Get list of available models with their capabilities."""
"""Get list of available models with their capabilities.
When the request is authenticated, the response includes the
user's own BYOM registrations alongside the built-in catalog.
"""
try:
user_id = None
decoded_token = getattr(request, "decoded_token", None)
if decoded_token:
user_id = decoded_token.get("sub")
registry = ModelRegistry.get_instance()
models = registry.get_enabled_models()
models = registry.get_enabled_models(user_id=user_id)
response = {
"models": [model.to_dict() for model in models],
@@ -23,3 +140,382 @@ class ModelsListResource(Resource):
current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
return make_response(jsonify(response), 200)
@models_ns.route("/user/models")
class UserModelsCollectionResource(Resource):
@api.doc(description="List the current user's BYOM custom models")
def get(self):
user_id, err = _user_id_or_401()
if err:
return err
try:
with db_readonly() as conn:
rows = UserCustomModelsRepository(conn).list_for_user(user_id)
return make_response(
jsonify({"models": [_row_to_response(r) for r in rows]}), 200
)
except Exception as e:
current_app.logger.error(
f"Error listing user custom models: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
@api.doc(description="Register a new BYOM custom model")
def post(self):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
missing = check_required_fields(
data,
["upstream_model_id", "display_name", "base_url", "api_key"],
)
if missing:
return missing
# SECURITY: reject blank api_key — would leak instance API key
# to the user-supplied base_url via LLMCreator fallback.
for required_nonblank in (
"upstream_model_id",
"display_name",
"base_url",
"api_key",
):
value = data.get(required_nonblank)
if not isinstance(value, str) or not value.strip():
return make_response(
jsonify(
{
"success": False,
"error": f"'{required_nonblank}' must be a non-empty string",
}
),
400,
)
# SSRF guard at create time. Re-runs at dispatch time (LLMCreator)
# as defense in depth against DNS rebinding and pre-guard rows.
try:
validate_user_base_url(data["base_url"])
except UnsafeUserUrlError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
try:
capabilities = _normalize_capabilities(data.get("capabilities"))
except ValueError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
try:
with db_session() as conn:
row = UserCustomModelsRepository(conn).create(
user_id=user_id,
upstream_model_id=data["upstream_model_id"],
display_name=data["display_name"],
description=data.get("description") or "",
base_url=data["base_url"],
api_key_plaintext=data["api_key"],
capabilities=capabilities,
enabled=bool(data.get("enabled", True)),
)
except Exception as e:
current_app.logger.error(
f"Error creating user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
ModelRegistry.invalidate_user(user_id)
return make_response(jsonify(_row_to_response(row)), 201)
@models_ns.route("/user/models/<string:model_id>")
class UserModelResource(Resource):
@api.doc(description="Get one BYOM custom model")
def get(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
try:
with db_readonly() as conn:
row = UserCustomModelsRepository(conn).get(model_id, user_id)
except Exception as e:
current_app.logger.error(
f"Error fetching user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
if row is None:
return make_response(jsonify({"success": False}), 404)
return make_response(jsonify(_row_to_response(row)), 200)
@api.doc(description="Update a BYOM custom model (partial)")
def patch(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
# Reject present-but-blank values for fields where blank doesn't
# mean "no change". (The api_key special case — blank means "keep
# existing" — is handled below.)
for required_nonblank in (
"upstream_model_id",
"display_name",
"base_url",
):
if required_nonblank in data:
value = data[required_nonblank]
if not isinstance(value, str) or not value.strip():
return make_response(
jsonify(
{
"success": False,
"error": f"'{required_nonblank}' cannot be blank",
}
),
400,
)
if "base_url" in data and data["base_url"]:
try:
validate_user_base_url(data["base_url"])
except UnsafeUserUrlError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
update_fields: dict = {}
for k in (
"upstream_model_id",
"display_name",
"description",
"base_url",
"enabled",
):
if k in data:
update_fields[k] = data[k]
if "capabilities" in data:
try:
update_fields["capabilities"] = _normalize_capabilities(
data["capabilities"]
)
except ValueError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
# PATCH semantics: blank/missing api_key → keep the existing
# ciphertext; non-empty api_key → re-encrypt and replace.
if data.get("api_key"):
update_fields["api_key_plaintext"] = data["api_key"]
if not update_fields:
return make_response(
jsonify({"success": False, "error": "no updatable fields"}), 400
)
try:
with db_session() as conn:
ok = UserCustomModelsRepository(conn).update(
model_id, user_id, update_fields
)
except Exception as e:
current_app.logger.error(
f"Error updating user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
if not ok:
return make_response(jsonify({"success": False}), 404)
ModelRegistry.invalidate_user(user_id)
with db_readonly() as conn:
row = UserCustomModelsRepository(conn).get(model_id, user_id)
return make_response(jsonify(_row_to_response(row)), 200)
@api.doc(description="Delete a BYOM custom model")
def delete(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
try:
with db_session() as conn:
ok = UserCustomModelsRepository(conn).delete(model_id, user_id)
except Exception as e:
current_app.logger.error(
f"Error deleting user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
if not ok:
return make_response(jsonify({"success": False}), 404)
ModelRegistry.invalidate_user(user_id)
return make_response(jsonify({"success": True}), 200)
def _run_connection_test(
base_url: str, api_key: str, upstream_model_id: str
):
"""Send a 1-token chat-completion to verify a BYOM endpoint.
Returns ``(body, http_status)``. Upstream errors return 200 with
``ok=False`` so the UI can render inline errors; only local SSRF
rejection returns 400.
"""
url = base_url.rstrip("/") + "/chat/completions"
payload = {
"model": upstream_model_id,
"messages": [{"role": "user", "content": "hi"}],
"max_tokens": 1,
"stream": False,
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
try:
# pinned_post closes the DNS-rebinding window. Redirects off
# because 3xx could bounce to an internal address (the SSRF
# guard only validates the supplied URL).
resp = pinned_post(
url,
json=payload,
headers=headers,
timeout=5,
allow_redirects=False,
)
except UnsafeUserUrlError as e:
return {"ok": False, "error": str(e)}, 400
except requests.RequestException as e:
return {"ok": False, "error": f"connection error: {e}"}, 200
if 300 <= resp.status_code < 400:
return (
{
"ok": False,
"error": (
f"upstream returned HTTP {resp.status_code} "
"redirect; refusing to follow"
),
},
200,
)
if resp.status_code >= 400:
# Cap and only reflect JSON to avoid body-exfil via non-API responses.
content_type = (resp.headers.get("Content-Type") or "").lower()
if "application/json" in content_type:
text = (resp.text or "")[:500]
error_msg = f"upstream returned HTTP {resp.status_code}: {text}"
else:
error_msg = f"upstream returned HTTP {resp.status_code}"
return {"ok": False, "error": error_msg}, 200
return {"ok": True}, 200
@models_ns.route("/user/models/test")
class UserModelTestPayloadResource(Resource):
@api.doc(
description=(
"Test an arbitrary BYOM payload (display_name / model id / "
"base_url / api_key) without saving. Used by the UI's 'Test "
"connection' button so the user can validate before they "
"Save. Same SSRF guard, same 1-token request, same 5s "
"timeout as the by-id variant."
)
)
def post(self):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
missing = check_required_fields(
data, ["base_url", "api_key", "upstream_model_id"]
)
if missing:
return missing
body, status = _run_connection_test(
data["base_url"], data["api_key"], data["upstream_model_id"]
)
return make_response(jsonify(body), status)
@models_ns.route("/user/models/<string:model_id>/test")
class UserModelTestResource(Resource):
@api.doc(
description=(
"Test a saved BYOM record. Defaults to the stored "
"base_url / upstream_model_id / encrypted api_key, but "
"any of those can be overridden via the request body so "
"the UI can test in-flight edits before saving. Used by "
"the 'Test connection' button in edit mode."
)
)
def post(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
# Per-field overrides; blank/missing falls back to stored value.
override_base_url = (data.get("base_url") or "").strip() or None
override_upstream_model_id = (
data.get("upstream_model_id") or ""
).strip() or None
override_api_key = (data.get("api_key") or "").strip() or None
try:
with db_readonly() as conn:
repo = UserCustomModelsRepository(conn)
row = repo.get(model_id, user_id)
if row is None:
return make_response(jsonify({"success": False}), 404)
stored_api_key = (
repo._decrypt_api_key(
row.get("api_key_encrypted", ""), user_id
)
if not override_api_key
else None
)
except Exception as e:
current_app.logger.error(
f"Error loading user custom model for test: {e}", exc_info=True
)
return make_response(
jsonify({"ok": False, "error": "internal error loading model"}),
500,
)
api_key = override_api_key or stored_api_key
if not api_key:
return make_response(
jsonify(
{
"ok": False,
"error": (
"Stored API key could not be decrypted. The "
"encryption secret may have rotated. Re-save "
"the model with the API key to recover."
),
}
),
400,
)
base_url = override_base_url or row["base_url"]
upstream_model_id = (
override_upstream_model_id or row["upstream_model_id"]
)
body, status = _run_connection_test(
base_url, api_key, upstream_model_id
)
return make_response(jsonify(body), status)
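End to end, the CRUD-plus-test flow above looks like this from a client (host, token, and upstream values are placeholders; the fields match the handlers above):
```python
import requests

BASE = "http://localhost:7091"               # placeholder host
HEADERS = {"Authorization": "Bearer <jwt>"}  # placeholder token

# Validate before saving (POST /api/user/models/test).
probe = requests.post(
    f"{BASE}/api/user/models/test",
    headers=HEADERS,
    json={
        "base_url": "https://llm.example.com/v1",
        "api_key": "sk-...",
        "upstream_model_id": "mistral-large-latest",
    },
    timeout=10,
)
assert probe.json().get("ok"), probe.json().get("error")

# Register (POST /api/user/models). The response never echoes the key.
created = requests.post(
    f"{BASE}/api/user/models",
    headers=HEADERS,
    json={
        "upstream_model_id": "mistral-large-latest",
        "display_name": "Mistral Large (BYOM)",
        "base_url": "https://llm.example.com/v1",
        "api_key": "sk-...",
        "capabilities": {"context_window": 32000, "supports_streaming": True},
    },
    timeout=10,
).json()
# created["id"] is the registry UUID that appears as model_id elsewhere.
```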

View File

@@ -140,6 +140,11 @@ def setup_periodic_tasks(sender, **kwargs):
cleanup_pending_tool_state.s(),
name="cleanup-pending-tool-state",
)
sender.add_periodic_task(
timedelta(hours=7),
version_check_task.s(),
name="version-check",
)
@celery.task(bind=True)
@@ -176,3 +181,16 @@ def cleanup_pending_tool_state(self):
with engine.begin() as conn:
deleted = PendingToolStateRepository(conn).cleanup_expired()
return {"deleted": deleted}
@celery.task(bind=True)
def version_check_task(self):
"""Periodic anonymous version check.
Complements the ``worker_ready`` boot trigger so long-running
deployments (>6h cache TTL) still refresh advisories. ``run_check``
is fail-silent and coordinates across replicas via Redis lock +
cache (see ``application.updates.version_check``).
"""
from application.updates.version_check import run_check
run_check()

View File

@@ -198,8 +198,14 @@ def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
return normalized_nodes
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
"""Validate workflow graph structure."""
def validate_workflow_structure(
nodes: List[Dict], edges: List[Dict], user_id: str | None = None
) -> List[str]:
"""Validate workflow graph structure.
``user_id`` is required so per-user BYOM custom-model UUIDs resolve
when checking each agent node's structured-output capability.
"""
errors = []
if not nodes:
@@ -343,7 +349,7 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
model_id = raw_config.get("model_id")
if has_json_schema and isinstance(model_id, str) and model_id.strip():
capabilities = get_model_capabilities(model_id.strip())
capabilities = get_model_capabilities(model_id.strip(), user_id=user_id)
if capabilities and not capabilities.get("supports_structured_output", False):
errors.append(
f"Agent node '{agent_title}' selected model does not support structured output"
@@ -389,7 +395,9 @@ class WorkflowList(Resource):
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
validation_errors = validate_workflow_structure(
nodes_data, edges_data, user_id=user_id
)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
@@ -451,7 +459,9 @@ class WorkflowDetail(Resource):
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
validation_errors = validate_workflow_structure(
nodes_data, edges_data, user_id=user_id
)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors

View File

@@ -213,6 +213,7 @@ def _stream_response(
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)
@@ -257,6 +258,7 @@ def _non_stream_response(
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)

View File

@@ -4,11 +4,12 @@ import platform
import uuid
import dotenv
from flask import Flask, jsonify, redirect, request
from flask import Flask, Response, jsonify, redirect, request
from jose import jwt
from application.auth import handle_auth
from application.core import log_context
from application.core.logging_config import setup_logging
setup_logging()
@@ -112,6 +113,38 @@ def generate_token():
return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
_LOG_CTX_TOKEN_ATTR = "_log_ctx_token"
@app.before_request
def _bind_log_context():
"""Bind activity_id + endpoint for the duration of this request.
Runs before ``authenticate_request``; ``user_id`` is overlaid in a
follow-up handler once the JWT has been decoded.
"""
if request.method == "OPTIONS":
return None
activity_id = str(uuid.uuid4())
request.activity_id = activity_id
token = log_context.bind(
activity_id=activity_id,
endpoint=request.endpoint,
)
setattr(request, _LOG_CTX_TOKEN_ATTR, token)
return None
@app.teardown_request
def _reset_log_context(_exc):
# SSE streams keep yielding after teardown fires, but a2wsgi runs each
# request inside ``copy_context().run(...)``, so this reset doesn't
# leak into the stream's view of the context.
token = getattr(request, _LOG_CTX_TOKEN_ATTR, None)
if token is not None:
log_context.reset(token)
@app.before_request
def enforce_stt_request_size_limits():
if request.method == "OPTIONS":
@@ -148,13 +181,27 @@ def authenticate_request():
request.decoded_token = decoded_token
@app.before_request
def _bind_user_id_to_log_context():
# Registered after ``authenticate_request`` (Flask runs before_request
# handlers in registration order), so ``request.decoded_token`` is
# populated by the time we read it. ``teardown_request`` unwinds the
# whole request-level bind, so no separate reset token is needed here.
if request.method == "OPTIONS":
return None
decoded_token = getattr(request, "decoded_token", None)
user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
if user_id:
log_context.bind(user_id=user_id)
return None
@app.after_request
def after_request(response):
response.headers.add("Access-Control-Allow-Origin", "*")
response.headers.add("Access-Control-Allow-Headers", "Content-Type, Authorization")
response.headers.add(
"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
)
def after_request(response: Response) -> Response:
"""Add CORS headers for the pure Flask development entrypoint."""
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
return response

33
application/asgi.py Normal file
View File

@@ -0,0 +1,33 @@
"""ASGI entrypoint: Flask (WSGI) + FastMCP on the same process."""
from __future__ import annotations
from a2wsgi import WSGIMiddleware
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.routing import Mount
from application.app import app as flask_app
from application.mcp_server import mcp
_WSGI_THREADPOOL = 32
mcp_app = mcp.http_app(path="/")
asgi_app = Starlette(
routes=[
Mount("/mcp", app=mcp_app),
Mount("/", app=WSGIMiddleware(flask_app, workers=_WSGI_THREADPOOL)),
],
middleware=[
Middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "Mcp-Session-Id"],
expose_headers=["Mcp-Session-Id"],
),
],
lifespan=mcp_app.lifespan,
)
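A local-run sketch, assuming uvicorn is installed (the port is an example, not a project default):

```python
# Hypothetical local entrypoint: serve the Flask routes and /mcp together.
import uvicorn

from application.asgi import asgi_app

if __name__ == "__main__":
    uvicorn.run(asgi_app, host="0.0.0.0", port=7091)
```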

View File

@@ -1,3 +1,4 @@
import hashlib
import json
import logging
import time
@@ -10,6 +11,14 @@ from application.utils import get_hash
logger = logging.getLogger(__name__)
def _cache_default(value):
# Image attachments arrive inline as bytes (see GoogleLLM.prepare_messages_with_attachments);
# hash so the cache key stays bounded in size and stable across identical content.
if isinstance(value, (bytes, bytearray, memoryview)):
return f"<bytes:sha256:{hashlib.sha256(bytes(value)).hexdigest()}>"
return repr(value)
_redis_instance = None
_redis_creation_failed = False
_instance_lock = Lock()
@@ -36,7 +45,7 @@ def get_redis_instance():
def gen_cache_key(messages, model="docgpt", tools=None):
if not all(isinstance(msg, dict) for msg in messages):
raise ValueError("All messages must be dictionaries.")
messages_str = json.dumps(messages)
messages_str = json.dumps(messages, default=_cache_default)
tools_str = json.dumps(str(tools)) if tools else ""
combined = f"{model}_{messages_str}_{tools_str}"
cache_key = get_hash(combined)
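A quick property check of the new serialiser (sketch; the byte payload is arbitrary):

```python
from application.cache import gen_cache_key

# Identical inline image bytes now produce identical, bounded cache keys:
# _cache_default replaces the bytes with a fixed-length sha256 token
# before the JSON dump is hashed.
png_bytes = b"\x89PNG\r\n\x1a\n" + b"\x00" * 10_000
key_a = gen_cache_key([{"role": "user", "content": png_bytes}])
key_b = gen_cache_key([{"role": "user", "content": bytes(png_bytes)}])
assert key_a == key_b
```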

View File

@@ -1,8 +1,17 @@
import inspect
import logging
import threading
from celery import Celery
from application.core import log_context
from application.core.settings import settings
from celery.signals import setup_logging, worker_process_init, worker_ready
from celery.signals import (
setup_logging,
task_postrun,
task_prerun,
worker_process_init,
worker_ready,
)
def make_celery(app_name=__name__):
@@ -41,6 +50,54 @@ def _dispose_db_engine_on_fork(*args, **kwargs):
dispose_engine()
# Most tasks in this repo accept ``user`` where the log context wants
# ``user_id``; map task parameter names to context keys explicitly.
_TASK_PARAM_TO_CTX_KEY: dict[str, str] = {
"user": "user_id",
"user_id": "user_id",
"agent_id": "agent_id",
"conversation_id": "conversation_id",
}
_task_log_tokens: dict[str, object] = {}
@task_prerun.connect
def _bind_task_log_context(task_id, task, args, kwargs, **_):
# Resolve task args by parameter name — nearly every task in this repo
# is called positionally, so ``kwargs.get('user')`` would bind nothing.
ctx = {"activity_id": task_id}
try:
sig = inspect.signature(task.run)
bound = sig.bind_partial(*args, **kwargs).arguments
except (TypeError, ValueError):
bound = dict(kwargs)
for param_name, value in bound.items():
ctx_key = _TASK_PARAM_TO_CTX_KEY.get(param_name)
if ctx_key and value:
ctx[ctx_key] = value
_task_log_tokens[task_id] = log_context.bind(**ctx)
@task_postrun.connect
def _unbind_task_log_context(task_id, **_):
# ``task_postrun`` fires on both success and failure. Required for
# Celery: unlike the Flask path, tasks aren't isolated in their own
# ``copy_context().run(...)``, so a missing reset would leak the
# bind onto the next task on the same worker.
token = _task_log_tokens.pop(task_id, None)
if token is None:
return
try:
log_context.reset(token)
except ValueError:
# task_prerun and task_postrun ran on different threads (non-default
# Celery pool); the token isn't valid in this context. Drop it.
logging.getLogger(__name__).debug(
"log_context reset skipped for task %s", task_id
)
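For example, this is how ``bind_partial`` recovers parameter names from a positionally-called task (the task body is a stand-in):

```python
import inspect


def ingest(user, conversation_id=None):  # stand-in task body
    ...


bound = inspect.signature(ingest).bind_partial("u-7", "c-9").arguments
# bound == {"user": "u-7", "conversation_id": "c-9"}; the prerun hook then
# maps "user" -> "user_id" via _TASK_PARAM_TO_CTX_KEY before binding.
```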
@worker_ready.connect
def _run_version_check(*args, **kwargs):
"""Kick off the anonymous version check on worker startup.

View File

@@ -9,3 +9,8 @@ accept_content = ['json']
# Autodiscover tasks
imports = ('application.api.user.tasks',)
beat_scheduler = "redbeat.RedBeatScheduler"
redbeat_redis_url = broker_url
redbeat_key_prefix = "redbeat:docsgpt:"
redbeat_lock_timeout = 90

View File

@@ -0,0 +1,57 @@
"""Per-activity logging context backed by ``contextvars``.
The ``_ContextFilter`` installed by ``logging_config.setup_logging`` stamps
every ``LogRecord`` emitted inside a ``bind`` block with the bound keys, so
they land as first-class attributes on the OTLP log export rather than being
buried inside formatted message bodies.
A single ``ContextVar`` holds a dict so nested binds reset atomically (LIFO)
via the token returned by ``bind``.
"""
from __future__ import annotations
from contextvars import ContextVar, Token
from typing import Mapping
_CTX_KEYS: frozenset[str] = frozenset(
{
"activity_id",
"parent_activity_id",
"user_id",
"agent_id",
"conversation_id",
"endpoint",
"model",
}
)
_ctx: ContextVar[Mapping[str, str]] = ContextVar("log_ctx", default={})
def bind(**kwargs: object) -> Token:
"""Overlay the given keys onto the current context.
Returns a ``Token`` so the caller can ``reset`` in a ``finally`` block.
Keys outside :data:`_CTX_KEYS` are silently dropped (so a typo can't
stamp a stray field name onto every record), as are ``None`` values
(a missing attribute is more useful than the literal string ``"None"``).
"""
overlay = {
k: str(v)
for k, v in kwargs.items()
if k in _CTX_KEYS and v is not None
}
new = {**_ctx.get(), **overlay}
return _ctx.set(new)
def reset(token: Token) -> None:
"""Restore the context to the snapshot captured by the matching ``bind``."""
_ctx.reset(token)
def snapshot() -> Mapping[str, str]:
"""Return the current context dict. Treat as read-only; use :func:`bind`."""
return _ctx.get()
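A usage sketch of the LIFO token semantics described above:

```python
from application.core import log_context

outer = log_context.bind(activity_id="req-1")
inner = log_context.bind(user_id="u-7")  # overlays; keeps activity_id
assert dict(log_context.snapshot()) == {"activity_id": "req-1", "user_id": "u-7"}
log_context.reset(inner)  # back to activity_id only
log_context.reset(outer)  # back to the empty context
```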

View File

@@ -1,11 +1,75 @@
import logging
import os
from logging.config import dictConfig
def setup_logging():
from application.core.log_context import snapshot as _ctx_snapshot
# Loggers with ``propagate=False`` don't share root's handlers, so the
# context filter has to be installed on their handlers directly.
_NON_PROPAGATING_LOGGERS: tuple[str, ...] = (
"uvicorn",
"uvicorn.access",
"uvicorn.error",
"celery.app.trace",
"celery.worker.strategy",
"gunicorn.error",
"gunicorn.access",
)
class _ContextFilter(logging.Filter):
"""Stamp the current ``log_context`` snapshot onto every ``LogRecord``.
Must be installed on **handlers**, not loggers: Python skips logger-level
filters when a child logger's record propagates up. The ``hasattr`` guard
keeps an explicit ``logger.info(..., extra={...})`` from being overwritten.
"""
def filter(self, record: logging.LogRecord) -> bool:
for key, value in _ctx_snapshot().items():
if not hasattr(record, key):
setattr(record, key, value)
return True
def _otlp_logs_enabled() -> bool:
"""Return True when the user has opted in to OTLP log export.
Gated by the standard OTEL env vars so no project-specific knob is needed:
set ``OTEL_LOGS_EXPORTER=otlp`` (and leave ``OTEL_SDK_DISABLED`` unset or
false) to flip it on. When false, ``setup_logging`` keeps its original
console-only behavior.
"""
exporter = os.getenv("OTEL_LOGS_EXPORTER", "").strip().lower()
disabled = os.getenv("OTEL_SDK_DISABLED", "false").strip().lower() == "true"
return exporter == "otlp" and not disabled
def setup_logging() -> None:
"""Configure the root logger with a stdout console handler.
When OTLP log export is enabled, ``opentelemetry-instrument`` attaches a
``LoggingHandler`` to the root logger before this function runs. The
``dictConfig`` call below replaces ``root.handlers`` with the console
handler, which would silently drop the OTEL handler. To make OTLP log
export work without forcing every contributor to opt in, snapshot the
OTEL handlers up front and re-attach them after ``dictConfig``.
"""
preserved_handlers: list[logging.Handler] = []
if _otlp_logs_enabled():
preserved_handlers = [
h
for h in logging.getLogger().handlers
if h.__class__.__module__.startswith("opentelemetry")
]
dictConfig({
'version': 1,
'formatters': {
'default': {
'format': '[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": "[%(asctime)s] %(levelname)s in %(module)s: %(message)s",
}
},
"handlers": {
@@ -15,8 +79,34 @@ def setup_logging():
"formatter": "default",
}
},
'root': {
'level': 'INFO',
'handlers': ['console'],
"root": {
"level": "INFO",
"handlers": ["console"],
},
})
})
if preserved_handlers:
root = logging.getLogger()
for handler in preserved_handlers:
if handler not in root.handlers:
root.addHandler(handler)
_install_context_filter()
def _install_context_filter() -> None:
"""Attach :class:`_ContextFilter` to root's handlers + every handler on
the known non-propagating loggers. Skipping handlers that already carry
one keeps repeat ``setup_logging`` calls from stacking filters.
"""
def _has_ctx_filter(handler: logging.Handler) -> bool:
return any(isinstance(f, _ContextFilter) for f in handler.filters)
for handler in logging.getLogger().handlers:
if not _has_ctx_filter(handler):
handler.addFilter(_ContextFilter())
for name in _NON_PROPAGATING_LOGGERS:
for handler in logging.getLogger(name).handlers:
if not _has_ctx_filter(handler):
handler.addFilter(_ContextFilter())
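Putting the pieces together (a sketch; the bound values are placeholders):

```python
import logging

from application.core import log_context
from application.core.logging_config import setup_logging

setup_logging()
log = logging.getLogger(__name__)

token = log_context.bind(activity_id="abc-123", user_id="u-42")
try:
    # The handler-level _ContextFilter stamps activity_id and user_id onto
    # this record, so an OTLP exporter sees them as first-class attributes.
    log.info("loading user layer")
finally:
    log_context.reset(token)
```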

View File

@@ -1,266 +0,0 @@
"""
Model configurations for all supported LLM providers.
"""
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
)
# Base image attachment types supported by most vision-capable LLMs
IMAGE_ATTACHMENTS = [
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
]
# PDF excluded: most OpenAI-compatible endpoints don't support native PDF uploads.
# When excluded, PDFs are synthetically processed by converting pages to images.
OPENAI_ATTACHMENTS = IMAGE_ATTACHMENTS
GOOGLE_ATTACHMENTS = ["application/pdf"] + IMAGE_ATTACHMENTS
ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
NOVITA_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENAI_MODELS = [
AvailableModel(
id="gpt-5.1",
provider=ModelProvider.OPENAI,
display_name="GPT-5.1",
description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="gpt-5-mini",
provider=ModelProvider.OPENAI,
display_name="GPT-5 Mini",
description="Faster, cost-effective variant of GPT-5.1",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=200000,
),
)
]
ANTHROPIC_MODELS = [
AvailableModel(
id="claude-3-5-sonnet-20241022",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3.5 Sonnet (Latest)",
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="claude-3-5-sonnet",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3.5 Sonnet",
description="Balanced performance and capability",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="claude-3-opus",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3 Opus",
description="Most capable Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
AvailableModel(
id="claude-3-haiku",
provider=ModelProvider.ANTHROPIC,
display_name="Claude 3 Haiku",
description="Fastest Claude model",
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
context_window=200000,
),
),
]
GOOGLE_MODELS = [
AvailableModel(
id="gemini-flash-latest",
provider=ModelProvider.GOOGLE,
display_name="Gemini Flash (Latest)",
description="Latest experimental Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=int(1e6),
),
),
AvailableModel(
id="gemini-flash-lite-latest",
provider=ModelProvider.GOOGLE,
display_name="Gemini Flash Lite (Latest)",
description="Fast with huge context window",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=int(1e6),
),
),
AvailableModel(
id="gemini-3-pro-preview",
provider=ModelProvider.GOOGLE,
display_name="Gemini 3 Pro",
description="Most capable Gemini model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=GOOGLE_ATTACHMENTS,
context_window=2000000,
),
),
]
GROQ_MODELS = [
AvailableModel(
id="llama-3.3-70b-versatile",
provider=ModelProvider.GROQ,
display_name="Llama 3.3 70B",
description="Latest Llama model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
AvailableModel(
id="openai/gpt-oss-120b",
provider=ModelProvider.GROQ,
display_name="GPT-OSS 120B",
description="Open-source GPT model optimized for speed",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
),
),
]
OPENROUTER_MODELS = [
AvailableModel(
id="qwen/qwen3-coder:free",
provider=ModelProvider.OPENROUTER,
display_name="Qwen 3 Coder",
description="Latest Qwen model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
AvailableModel(
id="google/gemma-3-27b-it:free",
provider=ModelProvider.OPENROUTER,
display_name="Gemma 3 27B",
description="Latest Gemma model with high-speed inference",
capabilities=ModelCapabilities(
supports_tools=True,
context_window=128000,
supported_attachment_types=OPENROUTER_ATTACHMENTS
),
),
]
NOVITA_MODELS = [
AvailableModel(
id="moonshotai/kimi-k2.5",
provider=ModelProvider.NOVITA,
display_name="Kimi K2.5",
description="MoE model with function calling, structured output, reasoning, and vision",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=NOVITA_ATTACHMENTS,
context_window=262144,
),
),
AvailableModel(
id="zai-org/glm-5",
provider=ModelProvider.NOVITA,
display_name="GLM-5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=202800,
),
),
AvailableModel(
id="minimax/minimax-m2.5",
provider=ModelProvider.NOVITA,
display_name="MiniMax M2.5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=204800,
),
),
]
AZURE_OPENAI_MODELS = [
AvailableModel(
id="azure-gpt-4",
provider=ModelProvider.AZURE_OPENAI,
display_name="Azure OpenAI GPT-4",
description="Azure-hosted GPT model",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
context_window=8192,
),
),
]
def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
"""Create a custom OpenAI-compatible model (e.g., LM Studio, Ollama)."""
return AvailableModel(
id=model_name,
provider=ModelProvider.OPENAI,
display_name=model_name,
description=f"Custom OpenAI-compatible model at {base_url}",
base_url=base_url,
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=OPENAI_ATTACHMENTS,
),
)

View File

@@ -0,0 +1,385 @@
"""Layered model registry.
Loads model catalogs from YAML files (built-in + operator-supplied),
groups them by provider name, then for each registered provider plugin
calls ``get_models`` to produce the final per-provider model list.
End-user BYOM (per-user model records in Postgres) is layered on top:
when a lookup arrives with a ``user_id``, the registry consults a
per-user cache first (loaded from the ``user_custom_models`` table on
miss) and falls through to the built-in catalog.
Cross-process invalidation: ``ModelRegistry`` is a per-process
singleton, so a CRUD write only evicts the cache in the process that
served it. Other gunicorn workers and Celery workers would otherwise
keep using a deleted/disabled/key-rotated BYOM record indefinitely.
``invalidate_user`` therefore both drops the local layer *and* bumps a
Redis-side version counter; other processes notice the bump on their
next access (after the local TTL window) and reload from Postgres. If
Redis is unreachable the per-process TTL still bounds staleness — pure
TTL semantics, no regression.
"""
from __future__ import annotations
import logging
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from application.core.model_settings import AvailableModel
from application.core.model_yaml import (
BUILTIN_MODELS_DIR,
ProviderCatalog,
load_model_yamls,
)
logger = logging.getLogger(__name__)
_USER_CACHE_TTL_SECONDS = 60.0
_USER_VERSION_KEY_PREFIX = "byom:registry_version:"
class ModelRegistry:
"""Singleton registry of available models."""
_instance: Optional["ModelRegistry"] = None
_initialized: bool = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not ModelRegistry._initialized:
self.models: Dict[str, AvailableModel] = {}
self.default_model_id: Optional[str] = None
# Per-user BYOM cache. Each entry is
# ``(layer, version_at_load, loaded_at_monotonic)``:
# * ``layer`` — {model_id: AvailableModel}
# * ``version_at_load`` — Redis-side counter snapshot at
# reload time, or ``None`` if Redis was unreachable
# * ``loaded_at_monotonic`` — for TTL bookkeeping
# Populated lazily, evicted by TTL + cross-process
# invalidation (see ``invalidate_user``).
self._user_models: Dict[
str,
Tuple[Dict[str, AvailableModel], Optional[int], float],
] = {}
self._load_models()
ModelRegistry._initialized = True
@classmethod
def get_instance(cls) -> "ModelRegistry":
return cls()
@classmethod
def reset(cls) -> None:
"""Clear the singleton. Intended for test fixtures."""
cls._instance = None
cls._initialized = False
@classmethod
def invalidate_user(cls, user_id: str) -> None:
"""Drop the cached per-user model layer for ``user_id``.
Called by the BYOM REST routes after every create/update/delete.
Two effects:
* Local: pop the entry from this process's cache so the next
lookup re-reads from Postgres immediately.
* Cross-process: ``INCR`` a Redis-side version counter for this
user. Other gunicorn/Celery processes notice the counter
changed on their next TTL-driven recheck (see
``_user_models_for``) and reload. If Redis is unreachable we
log and continue — local invalidation still happened, and
peers fall back to TTL-only staleness bounds.
"""
if cls._instance is not None:
cls._instance._user_models.pop(user_id, None)
try:
from application.cache import get_redis_instance
client = get_redis_instance()
if client is not None:
client.incr(_USER_VERSION_KEY_PREFIX + user_id)
except Exception as e:
logger.warning(
"BYOM invalidate: failed to publish version bump for "
"user %s (Redis unreachable?): %s",
user_id,
e,
)
@classmethod
def _read_user_version(cls, user_id: str) -> Optional[int]:
"""Return the Redis-side invalidation counter for ``user_id``.
``0`` if the key has never been bumped; ``None`` if Redis is
unreachable or the read failed (callers fall back to TTL-only
staleness in that case).
"""
try:
from application.cache import get_redis_instance
client = get_redis_instance()
if client is None:
return None
raw = client.get(_USER_VERSION_KEY_PREFIX + user_id)
if raw is None:
return 0
return int(raw)
except Exception:
return None
def _load_models(self) -> None:
from pathlib import Path
from application.core.settings import settings
from application.llm.providers import ALL_PROVIDERS
directories = [BUILTIN_MODELS_DIR]
operator_dir = getattr(settings, "MODELS_CONFIG_DIR", None)
if operator_dir:
op_path = Path(operator_dir)
if not op_path.exists():
logger.warning(
"MODELS_CONFIG_DIR=%s does not exist; no operator "
"model YAMLs will be loaded.",
operator_dir,
)
elif not op_path.is_dir():
logger.warning(
"MODELS_CONFIG_DIR=%s is not a directory; no operator "
"model YAMLs will be loaded.",
operator_dir,
)
else:
directories.append(op_path)
catalogs = load_model_yamls(directories)
# Validate every catalog targets a known plugin before doing any
# registry work, so an unknown provider name in YAML aborts boot
# with a clear error.
plugin_names = {p.name for p in ALL_PROVIDERS}
for c in catalogs:
if c.provider not in plugin_names:
raise ValueError(
f"{c.source_path}: YAML declares unknown provider "
f"{c.provider!r}; no Provider plugin is registered "
f"under that name. Known: {sorted(plugin_names)}"
)
catalogs_by_provider: Dict[str, List[ProviderCatalog]] = defaultdict(list)
for c in catalogs:
catalogs_by_provider[c.provider].append(c)
self.models.clear()
for provider in ALL_PROVIDERS:
if not provider.is_enabled(settings):
continue
for model in provider.get_models(
settings, catalogs_by_provider.get(provider.name, [])
):
self.models[model.id] = model
self.default_model_id = self._resolve_default(settings)
logger.info(
"ModelRegistry loaded %d models, default: %s",
len(self.models),
self.default_model_id,
)
def _resolve_default(self, settings) -> Optional[str]:
if settings.LLM_NAME:
for name in self._parse_model_names(settings.LLM_NAME):
if name in self.models:
return name
if settings.LLM_NAME in self.models:
return settings.LLM_NAME
if settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
return model_id
if self.models:
return next(iter(self.models.keys()))
return None
@staticmethod
def _parse_model_names(llm_name: str) -> List[str]:
if not llm_name:
return []
return [name.strip() for name in llm_name.split(",") if name.strip()]
# Per-user (BYOM) layer
def _user_models_for(self, user_id: str) -> Dict[str, AvailableModel]:
"""Return the user's BYOM models keyed by registry id (UUID).
Loaded lazily from Postgres on first access; cached subject to
a per-process TTL (``_USER_CACHE_TTL_SECONDS``) and a Redis-
backed version counter for cross-process invalidation. The TTL
bounds staleness even when Redis is unreachable, while the
version stamp lets peers refresh without a DB read on the
common case (no invalidation since last load). Decryption
failures and DB errors yield an empty layer (logged) — the
user simply doesn't see their custom models on this request,
never a 500.
"""
cached = self._user_models.get(user_id)
now = time.monotonic()
if cached is not None:
layer, cached_version, loaded_at = cached
if (now - loaded_at) < _USER_CACHE_TTL_SECONDS:
return layer
# TTL elapsed: peek at the cross-process counter. If it
# matches what we saw at load time, no invalidation has
# happened — extend the TTL without touching Postgres. If
# Redis is unreachable (``current_version is None``) we
# fall through to a real reload, which keeps staleness
# bounded to the TTL.
current_version = self._read_user_version(user_id)
if (
current_version is not None
and cached_version is not None
and current_version == cached_version
):
self._user_models[user_id] = (layer, cached_version, now)
return layer
# Capture the counter *before* the DB read so a CRUD that lands
# mid-reload doesn't get masked: the next access will see a
# newer version and reload again.
version_before_read = self._read_user_version(user_id)
layer: Dict[str, AvailableModel] = {}
try:
from application.core.model_settings import (
ModelCapabilities,
ModelProvider,
)
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
from application.storage.db.session import db_readonly
with db_readonly() as conn:
repo = UserCustomModelsRepository(conn)
rows = repo.list_for_user(user_id)
for row in rows:
api_key = repo._decrypt_api_key(
row.get("api_key_encrypted", ""), user_id
)
if not api_key:
# SECURITY: do NOT register an unroutable BYOM
# record. If we did, LLMCreator would fall back
# to the caller-passed api_key (settings.API_KEY
# for openai_compatible) and POST it to the
# user-supplied base_url — leaking the instance
# credential to the user's chosen endpoint.
# Most likely cause is ENCRYPTION_SECRET_KEY
# having rotated; user must re-save the model.
logger.warning(
"user_custom_models: skipping model %s for "
"user %s — api_key could not be decrypted "
"(rotated ENCRYPTION_SECRET_KEY?). Re-save "
"the model to recover.",
row.get("id"),
user_id,
)
continue
caps_raw = row.get("capabilities") or {}
# Stored attachments may be aliases (``image``) or
# raw MIME types. Built-in YAML models expand at
# load time; mirror that here so downstream MIME-
# type comparisons (handlers/base.prepare_messages)
# match concrete types like ``image/png`` rather
# than the bare alias.
from application.core.model_yaml import (
expand_attachments_lenient,
)
raw_attachments = caps_raw.get("attachments", []) or []
expanded_attachments = expand_attachments_lenient(
raw_attachments,
f"user_custom_models[user={user_id}, model={row.get('id')}]",
)
caps = ModelCapabilities(
supports_tools=bool(caps_raw.get("supports_tools", False)),
supports_structured_output=bool(
caps_raw.get("supports_structured_output", False)
),
supports_streaming=bool(
caps_raw.get("supports_streaming", True)
),
supported_attachment_types=expanded_attachments,
context_window=int(
caps_raw.get("context_window") or 128000
),
)
model_id = str(row["id"])
layer[model_id] = AvailableModel(
id=model_id,
provider=ModelProvider.OPENAI_COMPATIBLE,
display_name=row["display_name"],
description=row.get("description") or "",
capabilities=caps,
enabled=bool(row.get("enabled", True)),
base_url=row["base_url"],
upstream_model_id=row["upstream_model_id"],
source="user",
api_key=api_key,
)
except Exception as e:
logger.warning(
"user_custom_models: failed to load layer for user %s: %s",
user_id,
e,
)
layer = {}
self._user_models[user_id] = (layer, version_before_read, now)
return layer
# Lookup API. ``user_id`` enables the BYOM per-user layer; without
# it, callers see only the built-in + operator catalog.
def get_model(
self, model_id: str, user_id: Optional[str] = None
) -> Optional[AvailableModel]:
if user_id:
user_layer = self._user_models_for(user_id)
if model_id in user_layer:
return user_layer[model_id]
return self.models.get(model_id)
def get_all_models(
self, user_id: Optional[str] = None
) -> List[AvailableModel]:
out = list(self.models.values())
if user_id:
out.extend(self._user_models_for(user_id).values())
return out
def get_enabled_models(
self, user_id: Optional[str] = None
) -> List[AvailableModel]:
out = [m for m in self.models.values() if m.enabled]
if user_id:
out.extend(
m for m in self._user_models_for(user_id).values() if m.enabled
)
return out
def model_exists(
self, model_id: str, user_id: Optional[str] = None
) -> bool:
if user_id and model_id in self._user_models_for(user_id):
return True
return model_id in self.models
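The end-to-end shape in use (the BYOM id and user id are placeholders):

```python
from application.core.model_registry import ModelRegistry

registry = ModelRegistry.get_instance()

# Lookup: per-user BYOM layer first, then the built-in/operator catalog.
model = registry.get_model("3f2a0c1e-byom-uuid", user_id="user-123")

# After any BYOM create/update/delete, evict everywhere:
ModelRegistry.invalidate_user("user-123")  # local pop + Redis INCR
```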

View File

@@ -5,9 +5,16 @@ from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
# Re-exported here so existing call sites (and tests) that do
# ``from application.core.model_settings import ModelRegistry`` keep
# working. The implementation lives in ``application/core/model_registry.py``.
# Imported lazily inside ``__getattr__`` to avoid an import cycle with
# ``model_yaml`` → ``model_settings`` (this file).
class ModelProvider(str, Enum):
OPENAI = "openai"
OPENAI_COMPATIBLE = "openai_compatible"
OPENROUTER = "openrouter"
AZURE_OPENAI = "azure_openai"
ANTHROPIC = "anthropic"
@@ -41,11 +48,21 @@ class AvailableModel:
capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
enabled: bool = True
base_url: Optional[str] = None
# User-facing label distinct from dispatch provider (e.g. mistral
# routed through openai_compatible).
display_provider: Optional[str] = None
# Sent in the API call's ``model`` field; falls back to ``self.id``
# for built-ins where id IS the upstream name.
upstream_model_id: Optional[str] = None
# "builtin" for catalog YAMLs, "user" for BYOM records.
source: str = "builtin"
# Decrypted/resolved at registry-merge time. Never serialized.
api_key: Optional[str] = field(default=None, repr=False, compare=False)
def to_dict(self) -> Dict:
result = {
"id": self.id,
"provider": self.provider.value,
"provider": self.display_provider or self.provider.value,
"display_name": self.display_name,
"description": self.description,
"supported_attachment_types": self.capabilities.supported_attachment_types,
@@ -54,261 +71,21 @@ class AvailableModel:
"supports_streaming": self.capabilities.supports_streaming,
"context_window": self.capabilities.context_window,
"enabled": self.enabled,
"source": self.source,
}
if self.base_url:
result["base_url"] = self.base_url
return result
class ModelRegistry:
_instance = None
_initialized = False
def __getattr__(name):
"""Lazy re-export of ``ModelRegistry`` from ``model_registry.py``.
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
Done lazily to avoid an import cycle: ``model_registry`` imports
``model_yaml`` which imports the dataclasses from this file.
"""
if name == "ModelRegistry":
from application.core.model_registry import ModelRegistry as _MR
def __init__(self):
if not ModelRegistry._initialized:
self.models: Dict[str, AvailableModel] = {}
self.default_model_id: Optional[str] = None
self._load_models()
ModelRegistry._initialized = True
@classmethod
def get_instance(cls) -> "ModelRegistry":
return cls()
def _load_models(self):
from application.core.settings import settings
self.models.clear()
# Skip DocsGPT model if using custom OpenAI-compatible endpoint
if not settings.OPENAI_BASE_URL:
self._add_docsgpt_models(settings)
if (
settings.OPENAI_API_KEY
or (settings.LLM_PROVIDER == "openai" and settings.API_KEY)
or settings.OPENAI_BASE_URL
):
self._add_openai_models(settings)
if settings.OPENAI_API_BASE or (
settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
):
self._add_azure_openai_models(settings)
if settings.ANTHROPIC_API_KEY or (
settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
):
self._add_anthropic_models(settings)
if settings.GOOGLE_API_KEY or (
settings.LLM_PROVIDER == "google" and settings.API_KEY
):
self._add_google_models(settings)
if settings.GROQ_API_KEY or (
settings.LLM_PROVIDER == "groq" and settings.API_KEY
):
self._add_groq_models(settings)
if settings.OPEN_ROUTER_API_KEY or (
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
):
self._add_openrouter_models(settings)
if settings.NOVITA_API_KEY or (
settings.LLM_PROVIDER == "novita" and settings.API_KEY
):
self._add_novita_models(settings)
if settings.HUGGINGFACE_API_KEY or (
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
):
self._add_huggingface_models(settings)
# Default model selection
if settings.LLM_NAME:
# Parse LLM_NAME (may be comma-separated)
model_names = self._parse_model_names(settings.LLM_NAME)
# First model in the list becomes default
for model_name in model_names:
if model_name in self.models:
self.default_model_id = model_name
break
# Backward compat: try exact match if no parsed model found
if not self.default_model_id and settings.LLM_NAME in self.models:
self.default_model_id = settings.LLM_NAME
if not self.default_model_id:
if settings.LLM_PROVIDER and settings.API_KEY:
for model_id, model in self.models.items():
if model.provider.value == settings.LLM_PROVIDER:
self.default_model_id = model_id
break
if not self.default_model_id and self.models:
self.default_model_id = next(iter(self.models.keys()))
logger.info(
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
)
def _add_openai_models(self, settings):
from application.core.model_configs import (
OPENAI_MODELS,
create_custom_openai_model,
)
# Check if using local OpenAI-compatible endpoint (Ollama, LM Studio, etc.)
using_local_endpoint = bool(
settings.OPENAI_BASE_URL and settings.OPENAI_BASE_URL.strip()
)
if using_local_endpoint:
# When OPENAI_BASE_URL is set, ONLY register custom models from LLM_NAME
# Do NOT add standard OpenAI models (gpt-5.1, etc.)
if settings.LLM_NAME:
model_names = self._parse_model_names(settings.LLM_NAME)
for model_name in model_names:
custom_model = create_custom_openai_model(
model_name, settings.OPENAI_BASE_URL
)
self.models[model_name] = custom_model
logger.info(
f"Registered custom OpenAI model: {model_name} at {settings.OPENAI_BASE_URL}"
)
else:
# Standard OpenAI API usage - add standard models if API key is valid
if settings.OPENAI_API_KEY:
for model in OPENAI_MODELS:
self.models[model.id] = model
def _add_azure_openai_models(self, settings):
from application.core.model_configs import AZURE_OPENAI_MODELS
if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
for model in AZURE_OPENAI_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in AZURE_OPENAI_MODELS:
self.models[model.id] = model
def _add_anthropic_models(self, settings):
from application.core.model_configs import ANTHROPIC_MODELS
if settings.ANTHROPIC_API_KEY:
for model in ANTHROPIC_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
for model in ANTHROPIC_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in ANTHROPIC_MODELS:
self.models[model.id] = model
def _add_google_models(self, settings):
from application.core.model_configs import GOOGLE_MODELS
if settings.GOOGLE_API_KEY:
for model in GOOGLE_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
for model in GOOGLE_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in GOOGLE_MODELS:
self.models[model.id] = model
def _add_groq_models(self, settings):
from application.core.model_configs import GROQ_MODELS
if settings.GROQ_API_KEY:
for model in GROQ_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
for model in GROQ_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in GROQ_MODELS:
self.models[model.id] = model
def _add_openrouter_models(self, settings):
from application.core.model_configs import OPENROUTER_MODELS
if settings.OPEN_ROUTER_API_KEY:
for model in OPENROUTER_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "openrouter" and settings.LLM_NAME:
for model in OPENROUTER_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in OPENROUTER_MODELS:
self.models[model.id] = model
def _add_novita_models(self, settings):
from application.core.model_configs import NOVITA_MODELS
if settings.NOVITA_API_KEY:
for model in NOVITA_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "novita" and settings.LLM_NAME:
for model in NOVITA_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in NOVITA_MODELS:
self.models[model.id] = model
def _add_docsgpt_models(self, settings):
model_id = "docsgpt-local"
model = AvailableModel(
id=model_id,
provider=ModelProvider.DOCSGPT,
display_name="DocsGPT Model",
description="Local model",
capabilities=ModelCapabilities(
supports_tools=False,
supported_attachment_types=[],
),
)
self.models[model_id] = model
def _add_huggingface_models(self, settings):
model_id = "huggingface-local"
model = AvailableModel(
id=model_id,
provider=ModelProvider.HUGGINGFACE,
display_name="Hugging Face Model",
description="Local Hugging Face model",
capabilities=ModelCapabilities(
supports_tools=False,
supported_attachment_types=[],
),
)
self.models[model_id] = model
def _parse_model_names(self, llm_name: str) -> List[str]:
"""
Parse LLM_NAME which may contain comma-separated model names.
E.g., 'deepseek-r1:1.5b,gemma:2b' -> ['deepseek-r1:1.5b', 'gemma:2b']
"""
if not llm_name:
return []
return [name.strip() for name in llm_name.split(",") if name.strip()]
def get_model(self, model_id: str) -> Optional[AvailableModel]:
return self.models.get(model_id)
def get_all_models(self) -> List[AvailableModel]:
return list(self.models.values())
def get_enabled_models(self) -> List[AvailableModel]:
return [m for m in self.models.values() if m.enabled]
def model_exists(self, model_id: str) -> bool:
return model_id in self.models
return _MR
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -1,47 +1,59 @@
from typing import Any, Dict, Optional
from application.core.model_settings import ModelRegistry
from application.core.model_registry import ModelRegistry
def get_api_key_for_provider(provider: str) -> Optional[str]:
"""Get the appropriate API key for a provider"""
"""Get the appropriate API key for a provider.
Delegates to the provider plugin's ``get_api_key``. Falls back to the
generic ``settings.API_KEY`` for unknown providers.
"""
from application.core.settings import settings
from application.llm.providers import PROVIDERS_BY_NAME
provider_key_map = {
"openai": settings.OPENAI_API_KEY,
"openrouter": settings.OPEN_ROUTER_API_KEY,
"novita": settings.NOVITA_API_KEY,
"anthropic": settings.ANTHROPIC_API_KEY,
"google": settings.GOOGLE_API_KEY,
"groq": settings.GROQ_API_KEY,
"huggingface": settings.HUGGINGFACE_API_KEY,
"azure_openai": settings.API_KEY,
"docsgpt": None,
"llama.cpp": None,
}
provider_key = provider_key_map.get(provider)
if provider_key:
return provider_key
plugin = PROVIDERS_BY_NAME.get(provider)
if plugin is not None:
key = plugin.get_api_key(settings)
if key:
return key
return settings.API_KEY
def get_all_available_models() -> Dict[str, Dict[str, Any]]:
"""Get all available models with metadata for API response"""
def get_all_available_models(
user_id: Optional[str] = None,
) -> Dict[str, Dict[str, Any]]:
"""Get all available models with metadata for API response.
When ``user_id`` is supplied, the user's BYOM custom-model records
are merged into the result alongside the built-in catalog.
"""
registry = ModelRegistry.get_instance()
return {model.id: model.to_dict() for model in registry.get_enabled_models()}
return {
model.id: model.to_dict()
for model in registry.get_enabled_models(user_id=user_id)
}
def validate_model_id(model_id: str) -> bool:
"""Check if a model ID exists in registry"""
def validate_model_id(model_id: str, user_id: Optional[str] = None) -> bool:
"""Check if a model ID exists in registry.
``user_id`` enables resolution of per-user BYOM records (UUIDs).
Without it, only built-in catalog ids resolve.
"""
registry = ModelRegistry.get_instance()
return registry.model_exists(model_id)
return registry.model_exists(model_id, user_id=user_id)
def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
"""Get capabilities for a specific model"""
def get_model_capabilities(
model_id: str, user_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Get capabilities for a specific model.
``user_id`` enables resolution of per-user BYOM records.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
model = registry.get_model(model_id, user_id=user_id)
if model:
return {
"supported_attachment_types": model.capabilities.supported_attachment_types,
@@ -58,36 +70,68 @@ def get_default_model_id() -> str:
return registry.default_model_id
def get_provider_from_model_id(model_id: str) -> Optional[str]:
"""Get the provider name for a given model_id"""
def get_provider_from_model_id(
model_id: str, user_id: Optional[str] = None
) -> Optional[str]:
"""Get the provider name for a given model_id.
``user_id`` enables resolution of per-user BYOM records (UUIDs).
Without it, BYOM model ids return ``None`` and the caller falls
back to the deployment default.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
model = registry.get_model(model_id, user_id=user_id)
if model:
return model.provider.value
return None
def get_token_limit(model_id: str) -> int:
"""
Get context window (token limit) for a model.
Returns model's context_window or default 128000 if model not found.
def get_token_limit(model_id: str, user_id: Optional[str] = None) -> int:
"""Get context window (token limit) for a model.
Returns the model's ``context_window`` or ``DEFAULT_LLM_TOKEN_LIMIT``
if not found. ``user_id`` enables resolution of per-user BYOM records.
"""
from application.core.settings import settings
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
model = registry.get_model(model_id, user_id=user_id)
if model:
return model.capabilities.context_window
return settings.DEFAULT_LLM_TOKEN_LIMIT
def get_base_url_for_model(model_id: str) -> Optional[str]:
"""
Get the custom base_url for a specific model if configured.
Returns None if no custom base_url is set.
def get_base_url_for_model(
model_id: str, user_id: Optional[str] = None
) -> Optional[str]:
"""Get the custom base_url for a specific model if configured.
Returns ``None`` if no custom base_url is set. ``user_id`` enables
resolution of per-user BYOM records.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id)
model = registry.get_model(model_id, user_id=user_id)
if model:
return model.base_url
return None
def get_api_key_for_model(
model_id: str, user_id: Optional[str] = None
) -> Optional[str]:
"""Resolve the API key to use when invoking ``model_id``.
Priority:
1. The model record's own ``api_key`` (BYOM records and
``openai_compatible`` YAMLs populate this).
2. The provider plugin's settings-based key.
``user_id`` enables resolution of per-user BYOM records.
"""
registry = ModelRegistry.get_instance()
model = registry.get_model(model_id, user_id=user_id)
if model is not None and model.api_key:
return model.api_key
if model is not None:
return get_api_key_for_provider(model.provider.value)
return None
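To make the priority chain concrete (ids are placeholders; the import path is assumed):

```python
from application.core.model_utils import get_api_key_for_model  # path assumed

# 1. BYOM record: the decrypted per-record key, never settings.API_KEY.
key = get_api_key_for_model("3f2a0c1e-byom-uuid", user_id="user-123")

# 2. Built-in model: falls through to the provider plugin's settings key.
key = get_api_key_for_model("gpt-5.1")

# 3. Unresolvable id (e.g. a BYOM UUID without user_id): None.
key = get_api_key_for_model("3f2a0c1e-byom-uuid")
```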

View File

@@ -0,0 +1,358 @@
"""YAML loader for model catalog files under ``application/core/models/``.
Each ``*.yaml`` file declares one provider's static model catalog. Files
are validated with Pydantic at load time; any parse, schema, or alias
error aborts startup with the offending file path in the message.
For most providers, one YAML maps to one catalog. The
``openai_compatible`` provider is special: each YAML file represents a
distinct logical endpoint (Mistral, Together, Ollama, ...) with its own
``api_key_env`` and ``base_url``. The loader returns a flat list so the
registry can distinguish multiple files with the same ``provider:`` value.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, Optional, Sequence
import yaml
from pydantic import BaseModel, ConfigDict, Field, field_validator
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
)
logger = logging.getLogger(__name__)
BUILTIN_MODELS_DIR = Path(__file__).parent / "models"
DEFAULTS_FILENAME = "_defaults.yaml"
class _DefaultsFile(BaseModel):
"""Schema for ``_defaults.yaml``. Currently just attachment aliases."""
model_config = ConfigDict(extra="forbid")
attachment_aliases: Dict[str, List[str]] = Field(default_factory=dict)
class _CapabilityFields(BaseModel):
"""Capability fields shared between provider ``defaults:`` and per-model overrides.
All fields are optional so a per-model override can selectively replace
a single field from the provider-level defaults.
"""
model_config = ConfigDict(extra="forbid")
supports_tools: Optional[bool] = None
supports_structured_output: Optional[bool] = None
supports_streaming: Optional[bool] = None
attachments: Optional[List[str]] = None
context_window: Optional[int] = None
input_cost_per_token: Optional[float] = None
output_cost_per_token: Optional[float] = None
class _ModelEntry(_CapabilityFields):
"""Schema for one model row inside a YAML's ``models:`` list."""
id: str
display_name: Optional[str] = None
description: str = ""
enabled: bool = True
base_url: Optional[str] = None
aliases: List[str] = Field(default_factory=list)
@field_validator("id")
@classmethod
def _id_nonempty(cls, v: str) -> str:
if not v or not v.strip():
raise ValueError("model id must be a non-empty string")
return v
class _ProviderFile(BaseModel):
"""Schema for one ``<provider>.yaml`` catalog file."""
model_config = ConfigDict(extra="forbid")
provider: str
defaults: _CapabilityFields = Field(default_factory=_CapabilityFields)
models: List[_ModelEntry] = Field(default_factory=list)
# openai_compatible metadata. Optional for other providers.
display_provider: Optional[str] = None
api_key_env: Optional[str] = None
base_url: Optional[str] = None
class ProviderCatalog(BaseModel):
"""One YAML file's parsed contents, ready for the registry.
For most providers, multiple catalogs with the same ``provider`` get
merged later by the registry. The ``openai_compatible`` provider is
the exception: each catalog is treated as a distinct endpoint, with
its own ``api_key_env`` and ``base_url``.
"""
provider: str
models: List[AvailableModel]
source_path: Optional[Path] = None
display_provider: Optional[str] = None
api_key_env: Optional[str] = None
base_url: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
class ModelYAMLError(ValueError):
"""Raised when a model YAML fails parsing, schema, or alias validation."""
def _expand_attachments(
attachments: Sequence[str], aliases: Dict[str, List[str]], source: str
) -> List[str]:
"""Resolve attachment shorthands (``image``, ``pdf``) to MIME types.
Raw MIME-typed entries (containing ``/``) pass through unchanged.
Unknown aliases raise ``ModelYAMLError``.
"""
expanded: List[str] = []
seen: set = set()
for entry in attachments:
if "/" in entry:
if entry not in seen:
expanded.append(entry)
seen.add(entry)
continue
if entry not in aliases:
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
raise ModelYAMLError(
f"{source}: unknown attachment alias '{entry}'. "
f"Valid aliases: {valid}. "
"(Or use a raw MIME type like 'image/png'.)"
)
for mime in aliases[entry]:
if mime not in seen:
expanded.append(mime)
seen.add(mime)
return expanded
def _load_defaults(directory: Path) -> Dict[str, List[str]]:
"""Load ``_defaults.yaml`` from ``directory`` if it exists."""
path = directory / DEFAULTS_FILENAME
if not path.exists():
return {}
try:
raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as e:
raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
try:
parsed = _DefaultsFile.model_validate(raw)
except Exception as e:
raise ModelYAMLError(f"{path}: schema error: {e}") from e
return parsed.attachment_aliases
def _resolve_provider_enum(name: str, source: Path) -> ModelProvider:
try:
return ModelProvider(name)
except ValueError as e:
valid = ", ".join(p.value for p in ModelProvider)
raise ModelYAMLError(
f"{source}: unknown provider '{name}'. Valid: {valid}"
) from e
def _build_model(
entry: _ModelEntry,
defaults: _CapabilityFields,
provider: ModelProvider,
aliases: Dict[str, List[str]],
source: Path,
display_provider: Optional[str] = None,
) -> AvailableModel:
"""Merge defaults + per-model overrides into a final ``AvailableModel``."""
def pick(field_name: str, fallback):
v = getattr(entry, field_name)
if v is not None:
return v
d = getattr(defaults, field_name)
if d is not None:
return d
return fallback
raw_attachments = entry.attachments
if raw_attachments is None:
raw_attachments = defaults.attachments
if raw_attachments is None:
raw_attachments = []
expanded = _expand_attachments(
raw_attachments, aliases, f"{source} [model={entry.id}]"
)
caps = ModelCapabilities(
supports_tools=pick("supports_tools", False),
supports_structured_output=pick("supports_structured_output", False),
supports_streaming=pick("supports_streaming", True),
supported_attachment_types=expanded,
context_window=pick("context_window", 128000),
input_cost_per_token=pick("input_cost_per_token", None),
output_cost_per_token=pick("output_cost_per_token", None),
)
return AvailableModel(
id=entry.id,
provider=provider,
display_name=entry.display_name or entry.id,
description=entry.description,
capabilities=caps,
enabled=entry.enabled,
base_url=entry.base_url,
display_provider=display_provider,
)
def _load_one_yaml(
path: Path, aliases: Dict[str, List[str]]
) -> ProviderCatalog:
try:
raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
except yaml.YAMLError as e:
raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
try:
parsed = _ProviderFile.model_validate(raw)
except Exception as e:
raise ModelYAMLError(f"{path}: schema error: {e}") from e
provider_enum = _resolve_provider_enum(parsed.provider, path)
models = [
_build_model(
entry,
parsed.defaults,
provider_enum,
aliases,
path,
display_provider=parsed.display_provider,
)
for entry in parsed.models
]
return ProviderCatalog(
provider=parsed.provider,
models=models,
source_path=path,
display_provider=parsed.display_provider,
api_key_env=parsed.api_key_env,
base_url=parsed.base_url,
)
_BUILTIN_ALIASES_CACHE: Optional[Dict[str, List[str]]] = None
def builtin_attachment_aliases() -> Dict[str, List[str]]:
"""Return the built-in attachment alias map from ``_defaults.yaml``.
Cached after first read so repeat calls are cheap.
"""
global _BUILTIN_ALIASES_CACHE
if _BUILTIN_ALIASES_CACHE is None:
_BUILTIN_ALIASES_CACHE = _load_defaults(BUILTIN_MODELS_DIR)
return _BUILTIN_ALIASES_CACHE
def resolve_attachment_alias(alias: str) -> List[str]:
"""Resolve a single attachment alias (e.g. ``"image"``) to its
canonical MIME-type list. Raises ``ModelYAMLError`` if unknown.
"""
aliases = builtin_attachment_aliases()
if alias not in aliases:
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
raise ModelYAMLError(
f"Unknown attachment alias '{alias}'. Valid: {valid}"
)
return list(aliases[alias])
def expand_attachments_lenient(
attachments: Sequence[str], source: str
) -> List[str]:
"""Expand attachment aliases to MIME types, tolerating unknowns.
Mirrors ``_expand_attachments`` but logs and skips unknown aliases
rather than raising. Used for runtime call sites (BYOM registry
load) where an operator-side alias-map edit must not drop the
entire user's BYOM layer; the strict raise still happens at the
API validation boundary.
"""
aliases = builtin_attachment_aliases()
expanded: List[str] = []
seen: set = set()
for entry in attachments:
if "/" in entry:
if entry not in seen:
expanded.append(entry)
seen.add(entry)
continue
mime_list = aliases.get(entry)
if mime_list is None:
logger.warning(
"%s: skipping unknown attachment alias %r", source, entry,
)
continue
for mime in mime_list:
if mime not in seen:
expanded.append(mime)
seen.add(mime)
return expanded
def load_model_yamls(directories: Sequence[Path]) -> List[ProviderCatalog]:
"""Load every ``*.yaml`` file (excluding ``_defaults.yaml``) under each
directory in order and return a flat list of catalogs.
Caller is responsible for merging multiple catalogs that target the
same provider plugin. The flat-list shape lets ``openai_compatible``
keep each file separate (one logical endpoint per file).
When the same model ``id`` appears in more than one YAML across the
directory list, a warning is logged. Order in the returned list
preserves load order, so the registry's "later wins" merge gives the
later directory's definition.
"""
catalogs: List[ProviderCatalog] = []
seen_ids: Dict[str, Path] = {}
aliases: Dict[str, List[str]] = {}
for d in directories:
if not d or not d.exists():
continue
aliases.update(_load_defaults(d))
for d in directories:
if not d or not d.exists():
continue
for path in sorted(d.glob("*.yaml")):
if path.name == DEFAULTS_FILENAME:
continue
catalog = _load_one_yaml(path, aliases)
catalogs.append(catalog)
for m in catalog.models:
prior = seen_ids.get(m.id)
if prior is not None and prior != path:
logger.warning(
"Model id %r redefined: %s overrides %s (later wins)",
m.id,
path,
prior,
)
seen_ids[m.id] = path
return catalogs
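A hedged demo of both expansion paths (the alias map is inlined for brevity; in practice it comes from ``_defaults.yaml``):

```python
from application.core.model_yaml import (
    _expand_attachments,
    expand_attachments_lenient,
)

aliases = {"image": ["image/png", "image/jpeg", "image/webp"]}
out = _expand_attachments(["image", "application/pdf"], aliases, "demo.yaml")
assert out == ["image/png", "image/jpeg", "image/webp", "application/pdf"]

# The lenient variant logs and skips instead of raising:
expand_attachments_lenient(["not-an-alias"], "demo")  # -> warning, []
```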

View File

@@ -0,0 +1,213 @@
# Model catalogs
Each `*.yaml` file in this directory declares one provider's model
catalog. The registry loads every YAML at boot and joins it to the
matching provider plugin under `application/llm/providers/`.
To add or edit models, you almost always touch only a YAML file here — no
Python code required.
## Add a model to an existing provider
Open the provider's YAML (e.g. `anthropic.yaml`) and append two lines
under `models:`:
```yaml
models:
- id: claude-3-7-sonnet
display_name: Claude 3.7 Sonnet
```
Capabilities default to the provider's `defaults:` block. Override
per-model only when needed:
```yaml
- id: claude-3-7-sonnet
display_name: Claude 3.7 Sonnet
context_window: 500000
```
Restart the app. The new model appears in `/api/models`.
> The model `id` is what gets stored in agent / workflow records. Once
> users start picking the model, **don't rename it** — agent and
> workflow rows reference it as a free-form string and silently fall
> back to the system default if the id disappears.
## Add an OpenAI-compatible provider (zero Python)
Drop a YAML in this directory (or in your `MODELS_CONFIG_DIR`) that uses
the `openai_compatible` plugin. Set the env var named in `api_key_env`
and you're done — no Python, no settings.py edit, no LLMCreator change:
```yaml
# mistral.yaml
provider: openai_compatible
display_provider: mistral # shown in /api/models response
api_key_env: MISTRAL_API_KEY # env var the plugin reads at boot
base_url: https://api.mistral.ai/v1
defaults:
supports_tools: true
context_window: 128000
models:
- id: mistral-large-latest
display_name: Mistral Large
- id: mistral-small-latest
display_name: Mistral Small
```
Set `MISTRAL_API_KEY=sk-...` and restart — Mistral models appear in
`/api/models` with `provider: "mistral"`. They route through the OpenAI
wire format (it's `OpenAILLM` under the hood) but with Mistral's
endpoint and key.
Multiple `openai_compatible` YAMLs coexist: each file is one logical
endpoint with its own `api_key_env` and `base_url`. Drop in
`together.yaml`, `fireworks.yaml`, etc. side by side. If an env var
isn't set, that catalog is skipped at boot with an INFO-level log —
no error.
Working example: `examples/mistral.yaml.example`. Files inside
`examples/` aren't loaded by the registry; the glob only picks up
`*.yaml` at the top level.
## Add a provider with its own SDK
For a provider that doesn't speak OpenAI's wire format, add one Python
file to `application/llm/providers/<name>.py`:
```python
from application.llm.providers.base import Provider
from application.llm.my_provider import MyLLM
class MyProvider(Provider):
name = "my_provider"
llm_class = MyLLM
def get_api_key(self, settings):
return settings.MY_PROVIDER_API_KEY
```
Register it in `application/llm/providers/__init__.py` (one line in
`ALL_PROVIDERS`), add `MY_PROVIDER_API_KEY` to `settings.py`, and create
`my_provider.yaml` here with the model catalog.
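A minimal `my_provider.yaml` for that plugin might look like this (the
provider name matches the sketch above; the model ids are placeholders):
```yaml
# my_provider.yaml: hypothetical catalog, joined to MyProvider by its name
provider: my_provider
defaults:
  supports_tools: true
models:
  - id: my-model-large
    display_name: My Model Large
```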
## Schema reference
```yaml
provider: <string, required> # matches the Provider plugin's `name`
# openai_compatible only — required for that provider, ignored for others
display_provider: <string> # label shown in /api/models response
api_key_env: <string> # name of the env var carrying the key
base_url: <string> # endpoint URL
defaults: # optional, applied to every model below
supports_tools: bool # default false
supports_structured_output: bool # default false
supports_streaming: bool # default true
attachments: [<alias-or-mime>, ...] # default []
context_window: int # default 128000
input_cost_per_token: float # default null
output_cost_per_token: float # default null
models: # required
- id: <string, required> # the value persisted in agent records
display_name: <string> # default: id
description: <string> # default: ""
enabled: bool # default true; false hides from /api/models
base_url: <string> # optional custom endpoint for this model
# All `defaults:` fields above can be overridden here per-model.
```
### Attachment aliases
The `attachments:` list can mix human-readable aliases with raw MIME
types. Aliases are defined in `_defaults.yaml`:
| Alias | Expands to |
|---|---|
| `image` | `image/png`, `image/jpeg`, `image/jpg`, `image/webp`, `image/gif` |
| `pdf` | `application/pdf` |
| `audio` | `audio/mpeg`, `audio/wav`, `audio/ogg` |
Use raw MIME types when you need surgical control:
```yaml
attachments: [image/png, image/webp] # only these two
```
## Operator-supplied YAMLs (`MODELS_CONFIG_DIR`)
Set the `MODELS_CONFIG_DIR` env var (or `.env` entry) to a directory
path. Every `*.yaml` in that directory is loaded **after** the built-in
catalog under `application/core/models/`. Operators use this to:
- Add new `openai_compatible` providers (Mistral, Together, Fireworks,
Ollama, ...) without forking the repo.
- Extend an existing provider's catalog with extra models — append
models under `provider: anthropic` and they show up alongside the
built-ins.
- Override a built-in model's capabilities — declare the same `id`
with different fields (e.g. a higher `context_window`). Later wins;
the override is logged as a `WARNING` so you can audit it.
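For example, a minimal override file reusing the built-in
`claude-haiku-4-5` id from `anthropic.yaml` could look like:
```yaml
# anthropic-overrides.yaml: hypothetical operator file in MODELS_CONFIG_DIR
provider: anthropic
models:
  - id: claude-haiku-4-5     # same id as the built-in, so this wins
    display_name: Claude Haiku 4.5
    context_window: 500000   # raised from the provider default of 200000
```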
Things you cannot do via `MODELS_CONFIG_DIR`:
- Add a brand-new non-OpenAI provider — that needs a Python plugin
under `application/llm/providers/` (see "Add a provider with its own
SDK" above). Operator YAMLs may only target a `provider:` value that
already has a registered plugin.
### Example: Docker
Mount your model YAMLs into the container and point the env var at the
mount path:
```yaml
# docker-compose.yml
services:
app:
image: arc53/docsgpt
environment:
MODELS_CONFIG_DIR: /etc/docsgpt/models
MISTRAL_API_KEY: ${MISTRAL_API_KEY}
volumes:
- ./my-models:/etc/docsgpt/models:ro
```
Then `./my-models/mistral.yaml` (the file from
`examples/mistral.yaml.example`) gets picked up at boot.
### Example: Kubernetes
Mount a `ConfigMap` containing your YAMLs at a known path and set
`MODELS_CONFIG_DIR` on the deployment. The same `examples/mistral.yaml.example`
becomes a key in the ConfigMap.
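A minimal sketch, assuming illustrative resource names:
```yaml
apiVersion: v1
kind: ConfigMap
metadata:
  name: docsgpt-models
data:
  mistral.yaml: |
    provider: openai_compatible
    display_provider: mistral
    api_key_env: MISTRAL_API_KEY
    base_url: https://api.mistral.ai/v1
    models:
      - id: mistral-large-latest
---
# Deployment fragment: mount the ConfigMap and point MODELS_CONFIG_DIR at it
spec:
  template:
    spec:
      containers:
        - name: app
          env:
            - name: MODELS_CONFIG_DIR
              value: /etc/docsgpt/models
          volumeMounts:
            - name: models
              mountPath: /etc/docsgpt/models
              readOnly: true
      volumes:
        - name: models
          configMap:
            name: docsgpt-models
```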
### Misconfiguration
If `MODELS_CONFIG_DIR` is set but the path doesn't exist (or isn't a
directory), the app logs a `WARNING` at boot and continues with just
the built-in catalog. The app does *not* fail to start — operators can
ship config drift without taking down the service — but the warning is
loud enough to surface in any reasonable log aggregator.
## Validation
YAMLs are parsed with Pydantic at boot. The app fails to start with a
clear error message if:
- a top-level key is unknown
- a model is missing `id`
- an attachment alias isn't defined
- the `provider:` value isn't registered as a plugin
This is intentional — silent fallbacks would mean users don't notice
their model picks broke until they hit the API.
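For example, a file with a mistyped top-level key aborts startup instead
of loading partially:
```yaml
provider: anthropic
modles:                # typo: unknown top-level key, rejected at boot
  - id: claude-haiku-4-5
```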
## Reserved fields (not yet implemented)
- `aliases:` on a model — old IDs that resolve to this model. Reserved
for future renames; the schema accepts the field but it is not yet
acted on.
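The field already parses, so a forward-compatible entry like the
following (the old id is hypothetical) is accepted but has no effect yet:
```yaml
- id: claude-haiku-4-5
  aliases: [claude-haiku-4-4]   # parsed, currently ignored
```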

View File

@@ -0,0 +1,18 @@
# Global defaults applied across every model YAML in this directory.
# Keep this file sparse — per-provider `defaults:` blocks are clearer
# than a deep global default chain. This file is for things that
# genuinely never vary, like the meaning of "image".
attachment_aliases:
image:
- image/png
- image/jpeg
- image/jpg
- image/webp
- image/gif
pdf:
- application/pdf
audio:
- audio/mpeg
- audio/wav
- audio/ogg

View File

@@ -0,0 +1,23 @@
provider: anthropic
defaults:
supports_tools: true
attachments: [image]
context_window: 200000
models:
- id: claude-opus-4-7
display_name: Claude Opus 4.7
description: Most capable Claude model for complex reasoning and agentic coding
context_window: 1000000
supports_structured_output: true
- id: claude-sonnet-4-6
display_name: Claude Sonnet 4.6
description: Best balance of speed and intelligence with extended thinking
context_window: 1000000
supports_structured_output: true
- id: claude-haiku-4-5
display_name: Claude Haiku 4.5
description: Fastest Claude model with near-frontier intelligence
supports_structured_output: true

View File

@@ -0,0 +1,31 @@
# Azure OpenAI catalog.
#
# IMPORTANT: For Azure OpenAI, the `id` field is the **deployment name**, not
# a model name. Deployment names are arbitrary strings the operator chooses
# in Azure portal (or via ARM/Bicep/Terraform) when they create a deployment
# for a given underlying model + version.
#
# The IDs below are sensible defaults that mirror the underlying OpenAI
# model name (prefixed with `azure-`). Operators almost always need to
# override them via `MODELS_CONFIG_DIR` to match the deployment names that
# actually exist in their Azure resource. The `display_name`, capability
# flags, and `context_window` reflect the underlying OpenAI model.
provider: azure_openai
defaults:
supports_tools: true
supports_structured_output: true
attachments: [image]
context_window: 400000
models:
- id: azure-gpt-5.5
display_name: Azure OpenAI GPT-5.5
description: Azure-hosted flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
context_window: 1050000
- id: azure-gpt-5.4-mini
display_name: Azure OpenAI GPT-5.4 Mini
description: Azure-hosted cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
- id: azure-gpt-5.4-nano
display_name: Azure OpenAI GPT-5.4 Nano
description: Azure-hosted cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most

View File

@@ -0,0 +1,7 @@
provider: docsgpt
models:
- id: docsgpt-local
display_name: DocsGPT Model
description: Local model
supports_tools: false

View File

@@ -0,0 +1,31 @@
# EXAMPLE — copy this file to ../mistral.yaml (or to your
# MODELS_CONFIG_DIR) and set MISTRAL_API_KEY in your environment.
#
# This is the entire integration. No Python required: the
# `openai_compatible` plugin reads `api_key_env` and `base_url` from
# the file and routes calls through the OpenAI wire format.
#
# Files in this `examples/` directory are NOT loaded by the registry
# (the loader globs *.yaml at the top level only).
provider: openai_compatible
display_provider: mistral # shown in /api/models response
api_key_env: MISTRAL_API_KEY # env var the plugin reads
base_url: https://api.mistral.ai/v1 # OpenAI-compatible endpoint
defaults:
supports_tools: true
context_window: 128000
models:
- id: mistral-large-latest
display_name: Mistral Large
description: Top-tier reasoning model
- id: mistral-small-latest
display_name: Mistral Small
description: Fast, cost-efficient
- id: codestral-latest
display_name: Codestral
description: Code-specialized model

View File

@@ -0,0 +1,17 @@
provider: google
defaults:
supports_tools: true
supports_structured_output: true
attachments: [pdf, image]
context_window: 1048576
models:
- id: gemini-3.1-pro-preview
display_name: Gemini 3.1 Pro
description: Most capable Gemini 3 model with advanced reasoning and agentic coding (preview)
- id: gemini-3-flash-preview
display_name: Gemini 3 Flash
description: Frontier-class performance for low-latency, high-volume tasks (preview)
- id: gemini-3.1-flash-lite-preview
display_name: Gemini 3.1 Flash-Lite
description: Cost-efficient frontier-class multimodal model for high-throughput workloads (preview)

View File

@@ -0,0 +1,16 @@
provider: groq
defaults:
supports_tools: true
context_window: 131072
models:
- id: openai/gpt-oss-120b
display_name: GPT-OSS 120B
description: OpenAI's open-weight 120B flagship served on Groq's LPU hardware; strong general reasoning with strict structured output support
supports_structured_output: true
- id: llama-3.3-70b-versatile
display_name: Llama 3.3 70B Versatile
description: Meta's Llama 3.3 70B for general-purpose chat with parallel tool use
- id: llama-3.1-8b-instant
display_name: Llama 3.1 8B Instant
description: Small, very low-latency Llama model (~560 tok/s) with parallel tool use

View File

@@ -0,0 +1,7 @@
provider: huggingface
models:
- id: huggingface-local
display_name: Hugging Face Model
description: Local Hugging Face model
supports_tools: false

View File

@@ -0,0 +1,21 @@
provider: novita
defaults:
supports_tools: true
supports_structured_output: true
models:
- id: deepseek/deepseek-v4-pro
display_name: DeepSeek V4 Pro
description: 1.6T MoE (49B active) with 1M context, hybrid CSA/HCA attention, top-tier reasoning and agentic coding
context_window: 1048576
- id: moonshotai/kimi-k2.6
display_name: Kimi K2.6
description: 1T-parameter open-weight MoE with native vision/video, multi-step tool calling, and agentic long-horizon execution
attachments: [image]
context_window: 262144
- id: zai-org/glm-5
display_name: GLM-5
description: Z.AI 754B-parameter MoE with strong general reasoning, function calling, and structured output
context_window: 202800

View File

@@ -0,0 +1,18 @@
provider: openai
defaults:
supports_tools: true
supports_structured_output: true
attachments: [image]
context_window: 400000
models:
- id: gpt-5.5
display_name: GPT-5.5
description: Flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
context_window: 1050000
- id: gpt-5.4-mini
display_name: GPT-5.4 Mini
description: Cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
- id: gpt-5.4-nano
display_name: GPT-5.4 Nano
description: Cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most

View File

@@ -0,0 +1,25 @@
provider: openrouter
defaults:
supports_tools: true
attachments: [image]
context_window: 128000
models:
- id: qwen/qwen3-coder:free
display_name: Qwen3 Coder (free)
description: Free-tier 480B MoE coder model with strong agentic tool use; rate-limited
context_window: 262000
attachments: []
- id: deepseek/deepseek-v3.2
display_name: DeepSeek V3.2
description: Open-weights reasoning model, very low cost (~$0.25 in / $0.38 out per 1M)
context_window: 131072
attachments: []
supports_structured_output: true
- id: anthropic/claude-sonnet-4.6
display_name: Claude Sonnet 4.6 (via OpenRouter)
description: Frontier Sonnet-class model with 1M context, vision, and extended thinking
context_window: 1000000
supports_structured_output: true

View File

@@ -23,6 +23,10 @@ class Settings(BaseSettings):
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
# Optional directory of operator-supplied model YAMLs, loaded after the
# built-in catalog under application/core/models/. Later wins on
# duplicate model id. See application/core/models/README.md.
MODELS_CONFIG_DIR: Optional[str] = None
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"

View File

@@ -0,0 +1,72 @@
"""Gunicorn config — keeps uvicorn's access log in NCSA format."""
from __future__ import annotations
import logging
import logging.config
# NCSA combined log format (the common format plus referer and user-agent):
# %(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"
# Uvicorn's access formatter exposes a ``client_addr``/``request_line``/
# ``status_code`` trio but not the full NCSA field set, so we re-derive
# what we can.
_NCSA_FMT = (
'%(client_addr)s - - [%(asctime)s] "%(request_line)s" %(status_code)s'
)
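# With this fmt/datefmt a request renders roughly as (illustrative values):
#   127.0.0.1 - - [28/Apr/2026:02:36:40 +0100] "GET /api/models HTTP/1.1" 200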
logconfig_dict = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"ncsa_access": {
"()": "uvicorn.logging.AccessFormatter",
"fmt": _NCSA_FMT,
"datefmt": "%d/%b/%Y:%H:%M:%S %z",
"use_colors": False,
},
"default": {
"format": "[%(asctime)s] [%(process)d] [%(levelname)s] %(name)s: %(message)s",
},
},
"handlers": {
"access": {
"class": "logging.StreamHandler",
"formatter": "ncsa_access",
"stream": "ext://sys.stdout",
},
"default": {
"class": "logging.StreamHandler",
"formatter": "default",
"stream": "ext://sys.stderr",
},
},
"loggers": {
"uvicorn": {"handlers": ["default"], "level": "INFO", "propagate": False},
"uvicorn.error": {
"handlers": ["default"],
"level": "INFO",
"propagate": False,
},
"uvicorn.access": {
"handlers": ["access"],
"level": "INFO",
"propagate": False,
},
"gunicorn.error": {
"handlers": ["default"],
"level": "INFO",
"propagate": False,
},
"gunicorn.access": {
"handlers": ["access"],
"level": "INFO",
"propagate": False,
},
},
"root": {"handlers": ["default"], "level": "INFO"},
}
def on_starting(server): # pragma: no cover — gunicorn hook
"""Ensure gunicorn's own loggers use the configured handlers."""
logging.config.dictConfig(logconfig_dict)

View File

@@ -11,6 +11,7 @@ logger = logging.getLogger(__name__)
class AnthropicLLM(BaseLLM):
provider_name = "anthropic"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):

View File

@@ -1,5 +1,6 @@
import logging
from abc import ABC, abstractmethod
from typing import ClassVar
from application.cache import gen_cache, stream_cache
@@ -10,6 +11,10 @@ logger = logging.getLogger(__name__)
class BaseLLM(ABC):
# Stamped onto the ``llm_stream_start`` event so dashboards can group
# calls by vendor. Subclasses override.
provider_name: ClassVar[str] = "unknown"
def __init__(
self,
decoded_token=None,
@@ -17,6 +22,8 @@ class BaseLLM(ABC):
model_id=None,
base_url=None,
backup_models=None,
model_user_id=None,
capabilities=None,
):
self.decoded_token = decoded_token
self.agent_id = str(agent_id) if agent_id else None
@@ -25,6 +32,12 @@ class BaseLLM(ABC):
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
self._backup_models = backup_models or []
self._fallback_llm = None
# Registry-resolved per-model capability overrides (BYOM caps,
# operator YAML). None falls back to provider-class defaults.
self.capabilities = capabilities
# BYOM-resolution scope captured at LLM creation time so backup
# / fallback lookups hit the same per-user layer as the primary.
self.model_user_id = model_user_id
@property
def fallback_llm(self):
@@ -39,10 +52,19 @@ class BaseLLM(ABC):
get_api_key_for_provider,
)
# Try per-agent backup models first
# model_user_id (BYOM scope) takes precedence over the caller's
# sub so shared-agent backups resolve under the owner's layer.
caller_sub = (
self.decoded_token.get("sub")
if isinstance(self.decoded_token, dict)
else None
)
backup_user_id = self.model_user_id or caller_sub
for backup_model_id in self._backup_models:
try:
provider = get_provider_from_model_id(backup_model_id)
provider = get_provider_from_model_id(
backup_model_id, user_id=backup_user_id
)
if not provider:
logger.warning(
f"Could not resolve provider for backup model: {backup_model_id}"
@@ -56,6 +78,7 @@ class BaseLLM(ABC):
decoded_token=self.decoded_token,
model_id=backup_model_id,
agent_id=self.agent_id,
model_user_id=self.model_user_id,
)
logger.info(
f"Fallback LLM initialized from agent backup model: "
@@ -68,7 +91,10 @@ class BaseLLM(ABC):
)
continue
# Fall back to global FALLBACK_* settings
# Fall back to global FALLBACK_* settings. Forward
# ``model_user_id`` here too: deployments can configure
# ``FALLBACK_LLM_NAME`` to a BYOM UUID, and that UUID is owned
# by the same user the primary model was resolved under.
if settings.FALLBACK_LLM_PROVIDER:
try:
self._fallback_llm = LLMCreator.create_llm(
@@ -78,6 +104,7 @@ class BaseLLM(ABC):
decoded_token=self.decoded_token,
model_id=settings.FALLBACK_LLM_NAME,
agent_id=self.agent_id,
model_user_id=self.model_user_id,
)
logger.info(
f"Fallback LLM initialized from global settings: "
@@ -96,6 +123,26 @@ class BaseLLM(ABC):
return args_dict
return {k: v for k, v in args_dict.items() if v is not None}
@staticmethod
def _is_non_retriable_client_error(exc: BaseException) -> bool:
"""4xx errors mean the request itself is malformed — retrying with
a different model fails identically and doubles the work. Only
transient/5xx/connection errors should trigger fallback."""
try:
from google.genai.errors import ClientError as _GenaiClientError
if isinstance(exc, _GenaiClientError):
return True
except ImportError:
pass
for attr in ("status_code", "code", "http_status"):
v = getattr(exc, attr, None)
if isinstance(v, int) and 400 <= v < 500:
return True
resp = getattr(exc, "response", None)
v = getattr(resp, "status_code", None)
return isinstance(v, int) and 400 <= v < 500
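# Example: an SDK exception exposing status_code == 401 (bad key), or one
# whose .response carries a 4xx status_code, short-circuits fallback; a bare
# ConnectionError has none of these attributes and stays retriable.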
def _execute_with_fallback(
self, method_name: str, decorators: list, *args, **kwargs
):
@@ -119,12 +166,18 @@ class BaseLLM(ABC):
if is_stream:
return self._stream_with_fallback(
decorated_method, method_name, *args, **kwargs
decorated_method, method_name, decorators, *args, **kwargs
)
try:
return decorated_method()
except Exception as e:
if self._is_non_retriable_client_error(e):
logger.error(
f"Primary LLM failed with non-retriable client error; "
f"skipping fallback: {str(e)}"
)
raise
if not self.fallback_llm:
logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
raise
@@ -134,14 +187,27 @@ class BaseLLM(ABC):
f"{fallback.model_id}. Error: {str(e)}"
)
fallback_method = getattr(
fallback, method_name.replace("_raw_", "")
)
# Apply decorators to fallback's raw method directly — calling
# fallback.gen() would re-enter the orchestrator and recurse via
# fallback.fallback_llm.
fallback_method = getattr(type(fallback), method_name)  # unbound: self is passed explicitly below
for decorator in decorators:
fallback_method = decorator(fallback_method)
fallback_kwargs = {**kwargs, "model": fallback.model_id}
return fallback_method(*args, **fallback_kwargs)
try:
return fallback_method(fallback, *args, **fallback_kwargs)
except Exception as e2:
if self._is_non_retriable_client_error(e2):
logger.error(
f"Fallback LLM failed with non-retriable client "
f"error; giving up: {str(e2)}"
)
else:
logger.error(f"Fallback LLM also failed; giving up: {str(e2)}")
raise
def _stream_with_fallback(
self, decorated_method, method_name, *args, **kwargs
self, decorated_method, method_name, decorators, *args, **kwargs
):
"""
Wrapper generator that catches mid-stream errors and falls back.
@@ -154,6 +220,12 @@ class BaseLLM(ABC):
try:
yield from decorated_method()
except Exception as e:
if self._is_non_retriable_client_error(e):
logger.error(
f"Primary LLM failed mid-stream with non-retriable client "
f"error; skipping fallback: {str(e)}"
)
raise
if not self.fallback_llm:
logger.error(
f"Primary LLM failed and no fallback configured: {str(e)}"
@@ -164,11 +236,37 @@ class BaseLLM(ABC):
f"Primary LLM failed mid-stream. Falling back to "
f"{fallback.model_id}. Error: {str(e)}"
)
fallback_method = getattr(
fallback, method_name.replace("_raw_", "")
# Apply decorators to fallback's raw stream method directly —
# calling fallback.gen_stream() would re-enter the orchestrator
# and recurse via fallback.fallback_llm. Emit the stream-start
# event manually so dashboards still see the fallback's
# provider/model when the response actually comes from it.
fallback._emit_stream_start_log(
fallback.model_id,
kwargs.get("messages"),
kwargs.get("tools"),
bool(
kwargs.get("_usage_attachments")
or kwargs.get("attachments")
),
)
fallback_method = getattr(type(fallback), method_name)  # unbound: self is passed explicitly below
for decorator in decorators:
fallback_method = decorator(fallback_method)
fallback_kwargs = {**kwargs, "model": fallback.model_id}
yield from fallback_method(*args, **fallback_kwargs)
try:
yield from fallback_method(fallback, *args, **fallback_kwargs)
except Exception as e2:
if self._is_non_retriable_client_error(e2):
logger.error(
f"Fallback LLM failed mid-stream with non-retriable "
f"client error; giving up: {str(e2)}"
)
else:
logger.error(
f"Fallback LLM also failed mid-stream; giving up: {str(e2)}"
)
raise
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
decorators = [gen_token_usage, gen_cache]
@@ -183,7 +281,58 @@ class BaseLLM(ABC):
**kwargs,
)
def _emit_stream_start_log(self, model, messages, tools, has_attachments):
# Stamped with ``self.provider_name`` so dashboards can group calls
# by vendor; the fallback path emits its own copy on the fallback
# instance so the actual responding provider is recorded.
logger.info(
"llm_stream_start",
extra={
"model": model,
"provider": self.provider_name,
"message_count": len(messages) if messages is not None else 0,
"has_attachments": bool(has_attachments),
"has_tools": bool(tools),
},
)
def _emit_stream_finished_log(
self,
model,
*,
prompt_tokens,
completion_tokens,
latency_ms,
cached_tokens=None,
error=None,
):
# Paired with ``llm_stream_start`` so cost dashboards can sum tokens
# by user/agent/provider. Token counts are client-side estimates
# from ``stream_token_usage``; vendor-reported counts (incl.
# ``cached_tokens`` for prompt caching) require per-provider
# extraction in each ``_raw_gen_stream`` and aren't wired yet.
extra = {
"model": model,
"provider": self.provider_name,
"prompt_tokens": int(prompt_tokens),
"completion_tokens": int(completion_tokens),
"latency_ms": int(latency_ms),
"status": "error" if error is not None else "ok",
}
if cached_tokens is not None:
extra["cached_tokens"] = int(cached_tokens)
if error is not None:
extra["error_class"] = type(error).__name__
logging.info("llm_stream_finished", extra=extra)
def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
# Attachments arrive as ``_usage_attachments`` from ``Agent._llm_gen``;
# the ``stream_token_usage`` decorator pops that key, but the log
# fires before the decorator runs so it's still in ``kwargs`` here.
has_attachments = bool(
kwargs.get("_usage_attachments") or kwargs.get("attachments")
)
self._emit_stream_start_log(model, messages, tools, has_attachments)
decorators = [stream_cache, stream_token_usage]
return self._execute_with_fallback(
"_raw_gen_stream",

View File

@@ -6,6 +6,8 @@ DOCSGPT_BASE_URL = "https://oai.arc53.com"
DOCSGPT_MODEL = "docsgpt"
class DocsGPTAPILLM(OpenAILLM):
provider_name = "docsgpt"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=DOCSGPT_API_KEY,

View File

@@ -10,6 +10,8 @@ from application.storage.storage_creator import StorageCreator
class GoogleLLM(BaseLLM):
provider_name = "google"
def __init__(
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
):
@@ -79,24 +81,39 @@ class GoogleLLM(BaseLLM):
for attachment in attachments:
mime_type = attachment.get("mime_type")
if mime_type in self.get_supported_attachment_types():
try:
if mime_type not in self.get_supported_attachment_types():
continue
try:
# Images go inline as bytes per Google's guidance for
# requests under 20MB; the Files API can return before
# the upload reaches ACTIVE state and yield an empty URI.
if mime_type.startswith("image/"):
file_bytes = self._read_attachment_bytes(attachment)
files.append(
{"file_bytes": file_bytes, "mime_type": mime_type}
)
else:
file_uri = self._upload_file_to_google(attachment)
if not file_uri:
raise ValueError(
f"Google Files API returned empty URI for "
f"{attachment.get('path', 'unknown')}"
)
logging.info(
f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}"
)
files.append({"file_uri": file_uri, "mime_type": mime_type})
except Exception as e:
logging.error(
f"GoogleLLM: Error uploading file: {e}", exc_info=True
except Exception as e:
logging.error(
f"GoogleLLM: Error processing attachment: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
}
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
}
)
if files:
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
prepared_messages[user_message_index]["content"].append({"files": files})
@@ -112,7 +129,9 @@ class GoogleLLM(BaseLLM):
Returns:
str: Google AI file URI for the uploaded file.
"""
if "google_file_uri" in attachment:
# Truthy check, not membership: a poisoned cache row of "" or
# None must be treated as a miss and trigger a fresh upload.
if attachment.get("google_file_uri"):
return attachment["google_file_uri"]
file_path = attachment.get("path")
if not file_path:
@@ -126,6 +145,10 @@ class GoogleLLM(BaseLLM):
file=local_path
).uri,
)
if not file_uri:
raise ValueError(
f"Google Files API upload returned empty URI for {file_path}"
)
# Cache the Google file URI on the attachment row so we don't
# re-upload on the next LLM call. Accept either a PG UUID
@@ -159,6 +182,26 @@ class GoogleLLM(BaseLLM):
logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
raise
def _read_attachment_bytes(self, attachment):
"""
Read attachment bytes from storage for inline transmission.
Args:
attachment (dict): Attachment dictionary with path and metadata.
Returns:
bytes: Raw file bytes.
"""
file_path = attachment.get("path")
if not file_path:
raise ValueError("No file path provided in attachment")
if not self.storage.file_exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
return self.storage.process_file(
file_path,
lambda local_path, **kwargs: open(local_path, "rb").read(),
)
def _clean_messages_google(self, messages):
"""
Convert OpenAI format messages to Google AI format and collect system prompts.
@@ -298,12 +341,24 @@ class GoogleLLM(BaseLLM):
)
elif "files" in item:
for file_data in item["files"]:
parts.append(
types.Part.from_uri(
file_uri=file_data["file_uri"],
mime_type=file_data["mime_type"],
if "file_bytes" in file_data:
parts.append(
types.Part.from_bytes(
data=file_data["file_bytes"],
mime_type=file_data["mime_type"],
)
)
elif file_data.get("file_uri"):
parts.append(
types.Part.from_uri(
file_uri=file_data["file_uri"],
mime_type=file_data["mime_type"],
)
)
else:
logging.warning(
"GoogleLLM: dropping file part with empty URI and no bytes"
)
)
else:
raise ValueError(
f"Unexpected content dictionary format:{item}"
@@ -541,22 +596,6 @@ class GoogleLLM(BaseLLM):
config.response_mime_type = "application/json"
# Check if we have both tools and file attachments
has_attachments = False
for message in messages:
for part in message.parts:
if hasattr(part, "file_data") and part.file_data is not None:
has_attachments = True
break
if has_attachments:
break
messages_summary = self._summarize_messages_for_log(messages)
logging.info(
"GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
model,
messages_summary,
has_attachments,
)
response = client.models.generate_content_stream(
model=model,
contents=messages,

View File

@@ -5,6 +5,8 @@ GROQ_BASE_URL = "https://api.groq.com/openai/v1"
class GroqLLM(OpenAILLM):
provider_name = "groq"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.GROQ_API_KEY or settings.API_KEY,

View File

@@ -280,7 +280,26 @@ class LLMHandler(ABC):
# Keep serialized function calls/responses so the compressor sees actions
parts_text.append(str(item))
elif "files" in item:
parts_text.append(str(item))
# Image attachments arrive with raw bytes / base64
# inline (see GoogleLLM.prepare_messages_with_attachments).
# ``str(item)`` would dump the whole byte/base64
# blob into the compression prompt and bust the
# compression LLM's input limit.
files = item.get("files") or []
descriptors = []
if isinstance(files, list):
for f in files:
if isinstance(f, dict):
descriptors.append(
f.get("mime_type") or "file"
)
elif isinstance(f, str):
descriptors.append(f)
if not descriptors:
descriptors = ["file"]
parts_text.append(
f"[attachment: {', '.join(descriptors)}]"
)
return "\n".join(parts_text)
return ""
@@ -470,10 +489,14 @@ class LLMHandler(ABC):
)
return self._perform_in_memory_compression(agent, messages)
# Use orchestrator to perform compression
# Use orchestrator to perform compression. ``model_user_id``
# keeps BYOM registry resolution scoped to the model owner
# (shared-agent dispatch) while ``user_id`` stays the caller
# for the conversation access check.
result = orchestrator.compress_mid_execution(
conversation_id=agent.conversation_id,
user_id=agent.initial_user_id,
model_user_id=getattr(agent, "model_user_id", None),
model_id=agent.model_id,
decoded_token=getattr(agent, "decoded_token", {}),
current_conversation=conversation,
@@ -577,7 +600,20 @@ class LLMHandler(ABC):
if settings.COMPRESSION_MODEL_OVERRIDE
else agent.model_id
)
provider = get_provider_from_model_id(compression_model)
agent_decoded = getattr(agent, "decoded_token", None)
caller_sub = (
agent_decoded.get("sub")
if isinstance(agent_decoded, dict)
else None
)
# Use model-owner scope (mirrors orchestrator path) so
# shared-agent owner-BYOM resolves under the owner's layer.
compression_user_id = (
getattr(agent, "model_user_id", None) or caller_sub
)
provider = get_provider_from_model_id(
compression_model, user_id=compression_user_id
)
api_key = get_api_key_for_provider(provider)
compression_llm = LLMCreator.create_llm(
provider,
@@ -586,6 +622,7 @@ class LLMHandler(ABC):
getattr(agent, "decoded_token", None),
model_id=compression_model,
agent_id=getattr(agent, "agent_id", None),
model_user_id=compression_user_id,
)
# Create service without DB persistence capability
@@ -921,8 +958,15 @@ class LLMHandler(ABC):
}
return ""
# ``agent.model_id`` is the registry id (a UUID for BYOM
# records). Use the LLM's own model_id, which LLMCreator
# already resolved to the upstream model name. Built-ins:
# the two are equal; BYOM: the upstream name like
# "mistral-large-latest" instead of the UUID.
response = agent.llm.gen(
model=agent.model_id, messages=messages, tools=agent.tools
model=getattr(agent.llm, "model_id", None) or agent.model_id,
messages=messages,
tools=agent.tools,
)
parsed = self.parse_response(response)
self.llm_calls.append(build_stack_data(agent.llm))
@@ -1011,8 +1055,11 @@ class LLMHandler(ABC):
})
logger.info("Context limit reached - instructing agent to wrap up")
# See note above on agent.model_id vs llm.model_id.
response = agent.llm.gen_stream(
model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None
model=getattr(agent.llm, "model_id", None) or agent.model_id,
messages=messages,
tools=agent.tools if not agent.context_limit_reached else None,
)
self.llm_calls.append(build_stack_data(agent.llm))

View File

@@ -26,6 +26,8 @@ class LlamaSingleton:
class LlamaCpp(BaseLLM):
provider_name = "llama_cpp"
def __init__(
self,
api_key=None,

View File

@@ -1,34 +1,11 @@
import logging
from application.llm.anthropic import AnthropicLLM
from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.google_ai import GoogleLLM
from application.llm.groq import GroqLLM
from application.llm.llama_cpp import LlamaCpp
from application.llm.novita import NovitaLLM
from application.llm.openai import AzureOpenAILLM, OpenAILLM
from application.llm.premai import PremAILLM
from application.llm.sagemaker import SagemakerAPILLM
from application.llm.open_router import OpenRouterLLM
from application.llm.providers import PROVIDERS_BY_NAME
logger = logging.getLogger(__name__)
class LLMCreator:
llms = {
"openai": OpenAILLM,
"azure_openai": AzureOpenAILLM,
"sagemaker": SagemakerAPILLM,
"llama.cpp": LlamaCpp,
"anthropic": AnthropicLLM,
"docsgpt": DocsGPTAPILLM,
"premai": PremAILLM,
"groq": GroqLLM,
"google": GoogleLLM,
"novita": NovitaLLM,
"openrouter": OpenRouterLLM,
}
@classmethod
def create_llm(
cls,
@@ -39,28 +16,111 @@ class LLMCreator:
model_id=None,
agent_id=None,
backup_models=None,
model_user_id=None,
*args,
**kwargs,
):
from application.core.model_utils import get_base_url_for_model
"""Construct an LLM for the given provider ``type``.
llm_class = cls.llms.get(type.lower())
if not llm_class:
``model_user_id`` is the BYOM-resolution scope. Defaults to
``decoded_token['sub']`` (the caller). Pass it explicitly when
the model record belongs to a *different* user — most notably
for shared-agent dispatch, where the agent's stored
``default_model_id`` is the owner's BYOM UUID but
``decoded_token`` represents the caller.
"""
from application.core.model_registry import ModelRegistry
from application.security.safe_url import (
UnsafeUserUrlError,
pinned_httpx_client,
validate_user_base_url,
)
plugin = PROVIDERS_BY_NAME.get(type.lower())
if plugin is None or plugin.llm_class is None:
raise ValueError(f"No LLM class found for type {type}")
# Extract base_url from model configuration if model_id is provided
# Prefer per-model endpoint config from the registry. This is what
# makes openai_compatible AND end-user BYOM work without changing
# every call site: if the registered AvailableModel carries its
# own api_key / base_url, they win over whatever the caller
# resolved via the provider plugin.
#
# End-user BYOM lookups need the user_id from decoded_token to
# find the user's per-user models layer (built-in models resolve
# without it, so this stays back-compat).
base_url = None
upstream_model_id = model_id
capabilities = None
if model_id:
base_url = get_base_url_for_model(model_id)
user_id = model_user_id
if user_id is None:
user_id = (
(decoded_token or {}).get("sub") if decoded_token else None
)
model = ModelRegistry.get_instance().get_model(model_id, user_id=user_id)
if model is not None:
# Forward registry caps so the LLM enforces them at
# dispatch (built-in classes hard-code True otherwise).
capabilities = getattr(model, "capabilities", None)
# SECURITY: refuse user-source dispatch without its own
# api_key (would leak settings.API_KEY to base_url).
if (
getattr(model, "source", "builtin") == "user"
and not model.api_key
):
raise ValueError(
f"Custom model {model_id!r} has no usable API key "
"(decryption may have failed). Re-save the model "
"in settings to dispatch it."
)
if model.api_key:
api_key = model.api_key
if model.base_url:
base_url = model.base_url
# For BYOM the registry id is a UUID; the upstream API
# call needs the user's typed model name instead.
if model.upstream_model_id:
upstream_model_id = model.upstream_model_id
return llm_class(
# SECURITY: re-validate at dispatch (defense in depth
# for pre-guard rows / YAML-supplied entries). The
# pinned httpx.Client below is what actually closes the
# DNS-rebinding TOCTOU window.
if base_url and getattr(model, "source", "builtin") == "user":
try:
validate_user_base_url(base_url)
except UnsafeUserUrlError as e:
raise ValueError(
f"Refusing to dispatch model {model_id!r}: {e}"
) from e
# Pinned httpx.Client: resolves once, validates, and
# binds the SDK's outbound socket to the validated IP
# (preserves Host / SNI). Future BYOM providers must
# opt in explicitly — only openai_compatible takes
# http_client today.
if plugin.name == "openai_compatible":
try:
kwargs["http_client"] = pinned_httpx_client(
base_url
)
except UnsafeUserUrlError as e:
raise ValueError(
f"Refusing to dispatch model {model_id!r}: {e}"
) from e
# Forward model_user_id so backup/fallback resolves under the
# owner's scope on shared-agent dispatch.
return plugin.llm_class(
api_key,
user_api_key,
decoded_token=decoded_token,
model_id=model_id,
model_id=upstream_model_id,
agent_id=agent_id,
base_url=base_url,
backup_models=backup_models,
model_user_id=model_user_id,
capabilities=capabilities,
*args,
**kwargs,
)

View File

@@ -5,6 +5,8 @@ NOVITA_BASE_URL = "https://api.novita.ai/openai"
class NovitaLLM(OpenAILLM):
provider_name = "novita"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.NOVITA_API_KEY or settings.API_KEY,

View File

@@ -5,6 +5,8 @@ OPEN_ROUTER_BASE_URL = "https://openrouter.ai/api/v1"
class OpenRouterLLM(OpenAILLM):
provider_name = "openrouter"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.OPEN_ROUTER_API_KEY or settings.API_KEY,

View File

@@ -61,8 +61,17 @@ def _truncate_base64_for_logging(messages):
class OpenAILLM(BaseLLM):
provider_name = "openai"
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
def __init__(
self,
api_key=None,
user_api_key=None,
base_url=None,
http_client=None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
@@ -80,7 +89,18 @@ class OpenAILLM(BaseLLM):
else:
effective_base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
# http_client (set by LLMCreator for BYOM) is a DNS-rebinding-safe
# httpx.Client; without it the SDK re-resolves DNS per request.
if http_client is not None:
self.client = OpenAI(
api_key=self.api_key,
base_url=effective_base_url,
http_client=http_client,
)
else:
self.client = OpenAI(
api_key=self.api_key, base_url=effective_base_url
)
self.storage = StorageCreator.get_storage()
def _clean_messages_openai(self, messages):
@@ -243,6 +263,13 @@ class OpenAILLM(BaseLLM):
if "max_tokens" in kwargs:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
# Defense-in-depth: drop tools / response_format if the
# registry's capability flags deny them.
if tools and not self._supports_tools():
tools = None
if response_format and not self._supports_structured_output():
response_format = None
request_params = {
"model": model,
"messages": messages,
@@ -279,6 +306,13 @@ class OpenAILLM(BaseLLM):
if "max_tokens" in kwargs:
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
# See _raw_gen for rationale — drop tools/response_format when the
# registry-provided capabilities say the model doesn't support them.
if tools and not self._supports_tools():
tools = None
if response_format and not self._supports_structured_output():
response_format = None
request_params = {
"model": model,
"messages": messages,
@@ -320,9 +354,17 @@ class OpenAILLM(BaseLLM):
response.close()
def _supports_tools(self):
# When the LLM was constructed via LLMCreator with a registered
# AvailableModel, ``self.capabilities`` is the per-model record.
# BYOM users can disable tool support; respect that. Otherwise
# OpenAI's API supports tools by default.
if self.capabilities is not None:
return bool(self.capabilities.supports_tools)
return True
def _supports_structured_output(self):
if self.capabilities is not None:
return bool(self.capabilities.supports_structured_output)
return True
def prepare_structured_output_format(self, json_schema):
@@ -389,8 +431,14 @@ class OpenAILLM(BaseLLM):
Returns:
list: List of supported MIME types
"""
from application.core.model_configs import OPENAI_ATTACHMENTS
return OPENAI_ATTACHMENTS
# Per-model caps from the registry win when present — a BYOM
# endpoint that doesn't accept images would otherwise still be
# sent base64 image parts because the OpenAI default below
# advertises the image alias unconditionally.
if self.capabilities is not None:
return list(self.capabilities.supported_attachment_types or [])
from application.core.model_yaml import resolve_attachment_alias
return resolve_attachment_alias("image")
def prepare_messages_with_attachments(self, messages, attachments=None):
"""

View File

@@ -3,6 +3,7 @@ from application.core.settings import settings
class PremAILLM(BaseLLM):
provider_name = "premai"
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from premai import Prem

View File

@@ -0,0 +1,51 @@
"""Provider plugin registry.
Plugins are imported eagerly so import errors surface at app boot rather
than at first request. ``ALL_PROVIDERS`` is the canonical ordered list;
``PROVIDERS_BY_NAME`` is a name-keyed lookup for LLMCreator and the
model registry.
"""
from __future__ import annotations
from typing import Dict, List
from application.llm.providers.anthropic import AnthropicProvider
from application.llm.providers.azure_openai import AzureOpenAIProvider
from application.llm.providers.base import Provider
from application.llm.providers.docsgpt import DocsGPTProvider
from application.llm.providers.google import GoogleProvider
from application.llm.providers.groq import GroqProvider
from application.llm.providers.huggingface import HuggingFaceProvider
from application.llm.providers.llama_cpp import LlamaCppProvider
from application.llm.providers.novita import NovitaProvider
from application.llm.providers.openai import OpenAIProvider
from application.llm.providers.openai_compatible import OpenAICompatibleProvider
from application.llm.providers.openrouter import OpenRouterProvider
from application.llm.providers.premai import PremAIProvider
from application.llm.providers.sagemaker import SagemakerProvider
# Order here is the order the registry iterates providers (and therefore
# the order ``/api/models`` reports them). Match the historical order
# from the old ModelRegistry._load_models for byte-stable output during
# the migration. ``openai_compatible`` slots in right after ``openai``
# so legacy ``OPENAI_BASE_URL`` models keep landing in the same place.
ALL_PROVIDERS: List[Provider] = [
DocsGPTProvider(),
OpenAIProvider(),
OpenAICompatibleProvider(),
AzureOpenAIProvider(),
AnthropicProvider(),
GoogleProvider(),
GroqProvider(),
OpenRouterProvider(),
NovitaProvider(),
HuggingFaceProvider(),
LlamaCppProvider(),
PremAIProvider(),
SagemakerProvider(),
]
PROVIDERS_BY_NAME: Dict[str, Provider] = {p.name: p for p in ALL_PROVIDERS}
__all__ = ["ALL_PROVIDERS", "PROVIDERS_BY_NAME", "Provider"]

View File

@@ -0,0 +1,51 @@
"""Shared helper for providers that follow the
``<X>_API_KEY or (LLM_PROVIDER==X and API_KEY)`` pattern.
This is the dominant pattern across Anthropic, Google, Groq, OpenRouter,
and Novita. Extracted here so each plugin stays a few lines long.
"""
from __future__ import annotations
from typing import List, Optional
from application.core.model_settings import AvailableModel
def get_api_key(
settings,
provider_name: str,
provider_specific_key: Optional[str],
) -> Optional[str]:
if provider_specific_key:
return provider_specific_key
if settings.LLM_PROVIDER == provider_name and settings.API_KEY:
return settings.API_KEY
return None
def filter_models_by_llm_name(
settings,
provider_name: str,
provider_specific_key: Optional[str],
models: List[AvailableModel],
) -> List[AvailableModel]:
"""Mirrors the historical ``_add_<X>_models`` selection logic.
Behavior:
- If the provider-specific API key is set → load all models.
- Else if ``LLM_PROVIDER`` matches and ``LLM_NAME`` matches a known
model → load just that model.
- Otherwise → load all models (preserved "load anyway" branch from
the original methods).
"""
if provider_specific_key:
return models
if (
settings.LLM_PROVIDER == provider_name
and settings.LLM_NAME
):
named = [m for m in models if m.id == settings.LLM_NAME]
if named:
return named
return models
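# Worked example (hypothetical settings): with LLM_PROVIDER="groq",
# LLM_NAME="llama-3.1-8b-instant" and GROQ_API_KEY unset, only that one
# model is published; once GROQ_API_KEY is set the full groq catalog loads
# and LLM_NAME no longer narrows it.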

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from typing import Optional
from application.llm.anthropic import AnthropicLLM
from application.llm.providers._apikey_or_llm_name import (
filter_models_by_llm_name,
get_api_key,
)
from application.llm.providers.base import Provider
class AnthropicProvider(Provider):
name = "anthropic"
llm_class = AnthropicLLM
def get_api_key(self, settings) -> Optional[str]:
return get_api_key(settings, self.name, settings.ANTHROPIC_API_KEY)
def filter_yaml_models(self, settings, models):
return filter_models_by_llm_name(
settings, self.name, settings.ANTHROPIC_API_KEY, models
)

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from typing import Optional
from application.llm.openai import AzureOpenAILLM
from application.llm.providers.base import Provider
class AzureOpenAIProvider(Provider):
name = "azure_openai"
llm_class = AzureOpenAILLM
def get_api_key(self, settings) -> Optional[str]:
# Azure historically uses the generic API_KEY field.
return settings.API_KEY
def is_enabled(self, settings) -> bool:
if settings.OPENAI_API_BASE:
return True
return settings.LLM_PROVIDER == self.name and bool(settings.API_KEY)
def filter_yaml_models(self, settings, models):
# Mirrors _add_azure_openai_models: when LLM_PROVIDER==azure_openai
# and LLM_NAME matches a known model, narrow to that one model.
# Otherwise load the entire catalog.
if settings.LLM_PROVIDER == self.name and settings.LLM_NAME:
named = [m for m in models if m.id == settings.LLM_NAME]
if named:
return named
return models

View File

@@ -0,0 +1,74 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, ClassVar, List, Optional, Type
if TYPE_CHECKING:
from application.core.model_settings import AvailableModel
from application.core.model_yaml import ProviderCatalog
from application.core.settings import Settings
from application.llm.base import BaseLLM
class Provider(ABC):
"""Owns the *behavior* of an LLM provider.
Concrete providers declare their name, the LLM class to instantiate,
and how to resolve credentials from settings. Static model catalogs
live in YAML under ``application/core/models/`` and are joined to the
provider by name at registry load time.
Most plugins receive zero or one catalog at registry-build time. The
``openai_compatible`` plugin is the exception: it receives one catalog
per matching YAML file, each with its own ``api_key_env`` and
``base_url``. Plugins that need per-catalog metadata override
``get_models``; the default implementation merges catalogs and routes
through ``filter_yaml_models`` + ``extra_models``.
"""
name: ClassVar[str]
# ``None`` means the provider appears in the catalog but isn't
# dispatchable through LLMCreator (e.g. Hugging Face today, where the
# original LLMCreator dict had no entry).
llm_class: ClassVar[Optional[Type["BaseLLM"]]] = None
@abstractmethod
def get_api_key(self, settings: "Settings") -> Optional[str]:
"""Return the API key for this provider, or None if unavailable."""
def is_enabled(self, settings: "Settings") -> bool:
"""Whether this provider should contribute models to the registry."""
return bool(self.get_api_key(settings))
def filter_yaml_models(
self, settings: "Settings", models: List["AvailableModel"]
) -> List["AvailableModel"]:
"""Hook to filter YAML-loaded models. Default: return all."""
return models
def extra_models(self, settings: "Settings") -> List["AvailableModel"]:
"""Hook to add dynamic models not declared in YAML. Default: none."""
return []
def get_models(
self,
settings: "Settings",
catalogs: List["ProviderCatalog"],
) -> List["AvailableModel"]:
"""Final list of models this plugin contributes.
Default: merge the models across all matched catalogs (later
catalog wins on duplicate id), filter via ``filter_yaml_models``,
then append ``extra_models``. Override when per-catalog metadata
matters (see ``OpenAICompatibleProvider``).
"""
merged: List["AvailableModel"] = []
seen: dict = {}
for c in catalogs:
for m in c.models:
if m.id in seen:
merged[seen[m.id]] = m
else:
seen[m.id] = len(merged)
merged.append(m)
return self.filter_yaml_models(settings, merged) + self.extra_models(settings)

View File

@@ -0,0 +1,22 @@
from __future__ import annotations
from typing import Optional
from application.llm.docsgpt_provider import DocsGPTAPILLM
from application.llm.providers.base import Provider
class DocsGPTProvider(Provider):
name = "docsgpt"
llm_class = DocsGPTAPILLM
def get_api_key(self, settings) -> Optional[str]:
# No provider-specific key; the LLM class can use the generic
# API_KEY fallback if it needs one. Mirrors model_utils' historical
# behavior of returning settings.API_KEY when no specific key exists.
return settings.API_KEY
def is_enabled(self, settings) -> bool:
# The hosted DocsGPT model is hidden when the deployment is
# pointed at a custom OpenAI-compatible endpoint.
return not settings.OPENAI_BASE_URL

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from typing import Optional
from application.llm.google_ai import GoogleLLM
from application.llm.providers._apikey_or_llm_name import (
filter_models_by_llm_name,
get_api_key,
)
from application.llm.providers.base import Provider
class GoogleProvider(Provider):
name = "google"
llm_class = GoogleLLM
def get_api_key(self, settings) -> Optional[str]:
return get_api_key(settings, self.name, settings.GOOGLE_API_KEY)
def filter_yaml_models(self, settings, models):
return filter_models_by_llm_name(
settings, self.name, settings.GOOGLE_API_KEY, models
)

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from typing import Optional
from application.llm.groq import GroqLLM
from application.llm.providers._apikey_or_llm_name import (
filter_models_by_llm_name,
get_api_key,
)
from application.llm.providers.base import Provider
class GroqProvider(Provider):
name = "groq"
llm_class = GroqLLM
def get_api_key(self, settings) -> Optional[str]:
return get_api_key(settings, self.name, settings.GROQ_API_KEY)
def filter_yaml_models(self, settings, models):
return filter_models_by_llm_name(
settings, self.name, settings.GROQ_API_KEY, models
)

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
from typing import Optional
from application.llm.providers._apikey_or_llm_name import (
get_api_key as shared_get_api_key,
)
from application.llm.providers.base import Provider
class HuggingFaceProvider(Provider):
"""Surfaces ``huggingface-local`` to the model catalog.
Not dispatchable through LLMCreator — historically there was no
HuggingFaceLLM entry in ``LLMCreator.llms``, and calling ``create_llm``
with ``"huggingface"`` raised ``ValueError``. We preserve that
behavior: the model appears in ``/api/models`` but selecting it
surfaces the same error it always did.
"""
name = "huggingface"
llm_class = None # not dispatchable
def get_api_key(self, settings) -> Optional[str]:
return shared_get_api_key(settings, self.name, settings.HUGGINGFACE_API_KEY)

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from typing import Optional
from application.llm.llama_cpp import LlamaCpp
from application.llm.providers.base import Provider
class LlamaCppProvider(Provider):
"""LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog."""
name = "llama.cpp"
llm_class = LlamaCpp
def get_api_key(self, settings) -> Optional[str]:
return settings.API_KEY
def is_enabled(self, settings) -> bool:
return False

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from typing import Optional
from application.llm.novita import NovitaLLM
from application.llm.providers._apikey_or_llm_name import (
filter_models_by_llm_name,
get_api_key,
)
from application.llm.providers.base import Provider
class NovitaProvider(Provider):
name = "novita"
llm_class = NovitaLLM
def get_api_key(self, settings) -> Optional[str]:
return get_api_key(settings, self.name, settings.NOVITA_API_KEY)
def filter_yaml_models(self, settings, models):
return filter_models_by_llm_name(
settings, self.name, settings.NOVITA_API_KEY, models
)

View File

@@ -0,0 +1,37 @@
from __future__ import annotations
from typing import Optional
from application.llm.openai import OpenAILLM
from application.llm.providers.base import Provider
class OpenAIProvider(Provider):
name = "openai"
llm_class = OpenAILLM
def get_api_key(self, settings) -> Optional[str]:
if settings.OPENAI_API_KEY:
return settings.OPENAI_API_KEY
if settings.LLM_PROVIDER == self.name and settings.API_KEY:
return settings.API_KEY
return None
def is_enabled(self, settings) -> bool:
# When the deployment is pointed at a custom OpenAI-compatible
# endpoint (Ollama, LM Studio, ...), the cloud-OpenAI catalog is
# suppressed but ``is_enabled`` stays True — necessary so the
# filter below still gets to drop the catalog (rather than the
# registry skipping the provider entirely and missing the rule).
if settings.OPENAI_BASE_URL:
return True
return bool(self.get_api_key(settings))
def filter_yaml_models(self, settings, models):
# Legacy local-endpoint mode hides the cloud catalog. The
# corresponding dynamic models live in OpenAICompatibleProvider.
if settings.OPENAI_BASE_URL:
return []
if not settings.OPENAI_API_KEY:
return []
return models

View File

@@ -0,0 +1,149 @@
"""Generic provider for OpenAI-wire-compatible endpoints.
Each ``openai_compatible`` YAML file describes one logical endpoint
(Mistral, Together, Fireworks, Ollama, ...) with its own
``api_key_env`` and ``base_url``. Multiple files can coexist; the
plugin produces one set of models per file, each pre-configured with
the right credentials and URL.
The plugin also handles the **legacy** ``OPENAI_BASE_URL`` + ``LLM_NAME``
local-endpoint pattern that previously lived in ``OpenAIProvider``. That
path generates models dynamically from ``LLM_NAME``, using
``OPENAI_BASE_URL`` and ``OPENAI_API_KEY`` as the endpoint config.
"""
from __future__ import annotations
import logging
import os
from typing import List, Optional
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
)
from application.llm.openai import OpenAILLM
from application.llm.providers.base import Provider
logger = logging.getLogger(__name__)
def _parse_model_names(llm_name: Optional[str]) -> List[str]:
if not llm_name:
return []
return [name.strip() for name in llm_name.split(",") if name.strip()]
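# e.g. _parse_model_names("llama3, mistral-small, ") -> ["llama3", "mistral-small"]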
class OpenAICompatibleProvider(Provider):
name = "openai_compatible"
llm_class = OpenAILLM
def get_api_key(self, settings) -> Optional[str]:
# Per-model: each catalog supplies its own ``api_key_env``. There
# is no single plugin-wide key. LLMCreator reads the per-model
# ``api_key`` set during catalog materialization.
return None
def is_enabled(self, settings) -> bool:
# Concrete enablement happens per catalog (in ``get_models``).
# Returning True lets the registry call ``get_models`` so we can
# decide per-file whether to contribute models.
return True
def get_models(self, settings, catalogs) -> List[AvailableModel]:
out: List[AvailableModel] = []
for catalog in catalogs:
out.extend(self._materialize_yaml_catalog(catalog))
if settings.OPENAI_BASE_URL and settings.LLM_NAME:
out.extend(self._materialize_legacy_local_endpoint(settings))
return out
def _materialize_yaml_catalog(self, catalog) -> List[AvailableModel]:
"""Resolve one openai_compatible YAML into ready-to-dispatch models.
Skipped (with an INFO-level log) if ``api_key_env`` resolves to
nothing — no point publishing models the user can't actually
call. INFO rather than WARNING because operators may legitimately
drop multiple provider YAMLs as templates and only set the env
vars for the ones they actually use; a missing key is ambiguous,
not necessarily a misconfig.
"""
if not catalog.base_url:
raise ValueError(
f"{catalog.source_path}: openai_compatible YAML must set "
"'base_url'."
)
if not catalog.api_key_env:
raise ValueError(
f"{catalog.source_path}: openai_compatible YAML must set "
"'api_key_env'."
)
api_key = os.environ.get(catalog.api_key_env)
if not api_key:
logger.info(
"openai_compatible catalog %s skipped: env var %s is not set",
catalog.source_path,
catalog.api_key_env,
)
return []
out: List[AvailableModel] = []
for m in catalog.models:
out.append(self._with_endpoint(m, catalog.base_url, api_key))
return out
def _materialize_legacy_local_endpoint(self, settings) -> List[AvailableModel]:
"""Generate AvailableModels from ``LLM_NAME`` for the legacy
``OPENAI_BASE_URL`` deployment pattern (Ollama, LM Studio, ...).
Preserves the historical ``provider="openai"`` display behavior
by setting ``display_provider="openai"``.
"""
from application.core.model_yaml import resolve_attachment_alias
attachments = resolve_attachment_alias("image")
api_key = settings.OPENAI_API_KEY or settings.API_KEY
out: List[AvailableModel] = []
for model_name in _parse_model_names(settings.LLM_NAME):
out.append(
AvailableModel(
id=model_name,
provider=ModelProvider.OPENAI_COMPATIBLE,
display_name=model_name,
description=f"Custom OpenAI-compatible model at {settings.OPENAI_BASE_URL}",
base_url=settings.OPENAI_BASE_URL,
capabilities=ModelCapabilities(
supports_tools=True,
supported_attachment_types=attachments,
),
api_key=api_key,
display_provider="openai",
)
)
return out
@staticmethod
def _with_endpoint(
model: AvailableModel, base_url: str, api_key: str
) -> AvailableModel:
"""Return a copy of ``model`` carrying the catalog's endpoint config.
The catalog-level ``base_url`` is the default; an explicit
per-model ``base_url`` in the YAML wins.
"""
return AvailableModel(
id=model.id,
provider=model.provider,
display_name=model.display_name,
description=model.description,
capabilities=model.capabilities,
enabled=model.enabled,
base_url=model.base_url or base_url,
display_provider=model.display_provider,
api_key=api_key,
)

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
from typing import Optional
from application.llm.open_router import OpenRouterLLM
from application.llm.providers._apikey_or_llm_name import (
filter_models_by_llm_name,
get_api_key,
)
from application.llm.providers.base import Provider
class OpenRouterProvider(Provider):
name = "openrouter"
llm_class = OpenRouterLLM
def get_api_key(self, settings) -> Optional[str]:
return get_api_key(settings, self.name, settings.OPEN_ROUTER_API_KEY)
def filter_yaml_models(self, settings, models):
return filter_models_by_llm_name(
settings, self.name, settings.OPEN_ROUTER_API_KEY, models
)

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from typing import Optional
from application.llm.premai import PremAILLM
from application.llm.providers.base import Provider
class PremAIProvider(Provider):
"""LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog."""
name = "premai"
llm_class = PremAILLM
def get_api_key(self, settings) -> Optional[str]:
return settings.API_KEY
def is_enabled(self, settings) -> bool:
return False

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
from typing import Optional
from application.llm.sagemaker import SagemakerAPILLM
from application.llm.providers.base import Provider
class SagemakerProvider(Provider):
"""LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog.
SageMaker reads its credentials from ``SAGEMAKER_*`` settings inside
the LLM class itself; this plugin's ``get_api_key`` exists only for
LLMCreator's symmetry.
"""
name = "sagemaker"
llm_class = SagemakerAPILLM
def get_api_key(self, settings) -> Optional[str]:
return settings.API_KEY
def is_enabled(self, settings) -> bool:
return False

View File

@@ -59,6 +59,7 @@ class LineIterator:
class SagemakerAPILLM(BaseLLM):
provider_name = "sagemaker"
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
import boto3

View File

@@ -1,11 +1,13 @@
import datetime
import functools
import inspect
import time
import logging
import uuid
from typing import Any, Callable, Dict, Generator, List
from application.core import log_context
from application.storage.db.repositories.stack_logs import StackLogsRepository
from application.storage.db.session import db_session
@@ -22,6 +24,15 @@ class LogContext:
self.api_key = api_key
self.query = query
self.stacks = []
# Per-activity response aggregates populated by ``_consume_and_log``
# while it forwards stream items, then flushed onto the
# ``activity_finished`` event so every Flask request gets the
# same summary that ``run_agent_logic`` used to log only for the
# Celery webhook path.
self.answer_length = 0
self.thought_length = 0
self.source_count = 0
self.tool_call_count = 0
def build_stack_data(
@@ -78,25 +89,125 @@ def log_activity() -> Callable:
user = data.get("user", "local")
api_key = data.get("user_api_key", "")
query = kwargs.get("query", getattr(args[0], "query", ""))
agent_id = getattr(args[0], "agent_id", None) or kwargs.get("agent_id")
conversation_id = (
kwargs.get("conversation_id")
or getattr(args[0], "conversation_id", None)
)
model = getattr(args[0], "gpt_model", None) or getattr(args[0], "model", None)
# Capture the surrounding activity_id before overlaying ours,
# so nested activities record the parent → child link.
parent_activity_id = log_context.snapshot().get("activity_id")
context = LogContext(endpoint, activity_id, user, api_key, query)
kwargs["log_context"] = context
logging.info(
f"Starting activity: {endpoint} - {activity_id} - User: {user}"
ctx_token = log_context.bind(
activity_id=activity_id,
parent_activity_id=parent_activity_id,
user_id=user,
agent_id=agent_id,
conversation_id=conversation_id,
endpoint=endpoint,
model=model,
)
generator = func(*args, **kwargs)
yield from _consume_and_log(generator, context)
started_at = time.monotonic()
logging.info(
"activity_started",
extra={
"activity_id": activity_id,
"parent_activity_id": parent_activity_id,
"user_id": user,
"agent_id": agent_id,
"conversation_id": conversation_id,
"endpoint": endpoint,
"model": model,
},
)
error: BaseException | None = None
try:
generator = func(*args, **kwargs)
yield from _consume_and_log(generator, context)
except Exception as exc:
# Only ``Exception`` counts as an activity error; ``GeneratorExit``
# (consumer disconnected mid-stream) and ``KeyboardInterrupt``
# flow through the finally as ``status="ok"``, matching
# ``_consume_and_log``.
error = exc
raise
finally:
_emit_activity_finished(
context=context,
parent_activity_id=parent_activity_id,
started_at=started_at,
error=error,
)
log_context.reset(ctx_token)
return wrapper
return decorator
def _emit_activity_finished(
*,
context: "LogContext",
parent_activity_id: str | None,
started_at: float,
error: BaseException | None,
) -> None:
"""Emit the paired ``activity_finished`` event with duration, outcome,
and per-activity response aggregates accumulated in ``_consume_and_log``.
"""
duration_ms = int((time.monotonic() - started_at) * 1000)
logging.info(
"activity_finished",
extra={
"activity_id": context.activity_id,
"parent_activity_id": parent_activity_id,
"user_id": context.user,
"endpoint": context.endpoint,
"duration_ms": duration_ms,
"status": "error" if error is not None else "ok",
"error_class": type(error).__name__ if error is not None else None,
"answer_length": context.answer_length,
"thought_length": context.thought_length,
"source_count": context.source_count,
"tool_call_count": context.tool_call_count,
},
)
def _accumulate_response_summary(item: Any, context: "LogContext") -> None:
"""Mirror the per-line aggregation that ``run_agent_logic`` did for the
Celery webhook path, but at the generator-consumption layer so every
``Agent.gen`` activity (Flask streaming, sub-agents, workflow agents)
gets the same summary.
"""
if not isinstance(item, dict):
return
if "answer" in item:
context.answer_length += len(str(item["answer"]))
return
if "thought" in item:
context.thought_length += len(str(item["thought"]))
return
sources = item.get("sources") if "sources" in item else None
if isinstance(sources, list):
context.source_count += len(sources)
return
tool_calls = item.get("tool_calls") if "tool_calls" in item else None
if isinstance(tool_calls, list):
context.tool_call_count += len(tool_calls)
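# Example stream items and their effect (shapes as produced by the
# agent stream; values illustrative):
#   {"answer": "Hi"}       -> answer_length += 2
#   {"thought": "hmm"}     -> thought_length += 3
#   {"sources": [s1, s2]}  -> source_count += 2
#   {"tool_calls": [t]}    -> tool_call_count += 1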
def _consume_and_log(generator: Generator, context: "LogContext"):
try:
for item in generator:
_accumulate_response_summary(item, context)
yield item
except Exception as e:
logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}")

application/mcp_server.py Normal file
View File

@@ -0,0 +1,59 @@
"""FastMCP server exposing DocsGPT retrieval over streamable HTTP.
Mounted at ``/mcp`` by ``application/asgi.py``. Bearer tokens are the
existing DocsGPT agent API keys — no new credential surface.
The tool reads the ``Authorization`` header directly via
``get_http_headers(include={"authorization"})``. The ``include`` kwarg
is required: by default ``get_http_headers`` strips ``authorization``
(and a handful of other hop-by-hop headers) so they aren't forwarded
to downstream services — since we deliberately want the caller's
token, we opt it back in.
"""
from __future__ import annotations
import asyncio
import logging
from fastmcp import FastMCP
from fastmcp.server.dependencies import get_http_headers
from application.services.search_service import (
InvalidAPIKey,
SearchFailed,
search,
)
logger = logging.getLogger(__name__)
mcp = FastMCP("docsgpt")
def _extract_bearer_token() -> str | None:
auth = get_http_headers(include={"authorization"}).get("authorization", "")
parts = auth.split(None, 1)
if len(parts) != 2 or parts[0].lower() != "bearer" or not parts[1]:
return None
return parts[1]
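# Hedged examples of the rule above (header values illustrative):
#   "Bearer abc123" -> "abc123"
#   "bearer abc123" -> "abc123"   (scheme match is case-insensitive)
#   "Basic abc123"  -> None
#   "Bearer"        -> None       (no token after the scheme)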
@mcp.tool
async def search_docs(query: str, chunks: int = 5) -> list[dict]:
"""Search the caller's DocsGPT knowledge base.
Authentication is via ``Authorization: Bearer <agent-api-key>`` on
the MCP request — the same opaque key that ``/api/search`` accepts
in its JSON body. Returns at most ``chunks`` hits, each a dict with
``text``, ``title``, ``source`` keys.
"""
api_key = _extract_bearer_token()
if not api_key:
raise PermissionError("Missing Bearer token")
try:
return await asyncio.to_thread(search, api_key, query, chunks)
except InvalidAPIKey as exc:
raise PermissionError("Invalid API key") from exc
except SearchFailed:
logger.exception("search_docs failed")
raise

View File

@@ -1,9 +1,12 @@
a2wsgi==1.10.10
alembic>=1.13,<2
anthropic==0.88.0
asgiref>=3.11.1
boto3==1.42.83
beautifulsoup4==4.14.3
cel-python==0.5.0
celery==5.6.3
celery-redbeat==2.3.3
cryptography==46.0.7
dataclasses-json==0.6.7
defusedxml==0.7.1
@@ -14,7 +17,7 @@ docx2txt==0.9
ddgs>=8.0.0
fast-ebook
elevenlabs==2.43.0
Flask==3.1.3
Flask==3.1.1
faiss-cpu==1.13.2
fastmcp==3.2.4
flask-restx==1.3.2
@@ -49,6 +52,16 @@ networkx==3.6.1
numpy==2.4.4
openai==2.32.0
openapi3-parser==1.1.22
opentelemetry-distro>=0.50b0,<1
opentelemetry-exporter-otlp>=1.29.0,<2
opentelemetry-instrumentation-celery>=0.50b0,<1
opentelemetry-instrumentation-flask>=0.50b0,<1
opentelemetry-instrumentation-logging>=0.50b0,<1
opentelemetry-instrumentation-psycopg>=0.50b0,<1
opentelemetry-instrumentation-redis>=0.50b0,<1
opentelemetry-instrumentation-requests>=0.50b0,<1
opentelemetry-instrumentation-sqlalchemy>=0.50b0,<1
opentelemetry-instrumentation-starlette>=0.50b0,<1
orjson==3.11.7
packaging==26.0
pandas==3.0.2
@@ -58,7 +71,7 @@ pdf2image>=1.17.0
pillow
portalocker>=2.7.0,<4.0.0
prompt-toolkit==3.0.52
protobuf==7.34.1
protobuf==6.33.6
psycopg[binary,pool]>=3.1,<4
py==1.11.0
pydantic
@@ -69,6 +82,7 @@ python-dateutil==2.9.0.post0
python-dotenv
python-jose==3.5.0
python-pptx==1.0.2
PyYAML
redis==7.4.0
referencing>=0.28.0,<0.38.0
regex==2026.4.4
@@ -76,6 +90,7 @@ requests==2.33.1
retry==0.9.2
sentence-transformers==5.3.0
sqlalchemy>=2.0,<3
starlette>=1.0,<2
tiktoken==0.12.0
tokenizers==0.22.2
torch==2.11.0
@@ -85,6 +100,8 @@ typing-extensions==4.15.0
typing-inspect==0.9.0
tzdata==2026.1
urllib3==2.6.3
uvicorn[standard]>=0.30,<1
uvicorn-worker>=0.4,<1
vine==5.1.0
wcwidth==0.6.0
werkzeug>=3.1.0

View File

@@ -22,6 +22,7 @@ class ClassicRAG(BaseRetriever):
llm_name=settings.LLM_PROVIDER,
api_key=settings.API_KEY,
decoded_token=None,
model_user_id=None,
):
self.original_question = source.get("question", "")
self.chat_history = chat_history if chat_history is not None else []
@@ -42,17 +43,22 @@ class ClassicRAG(BaseRetriever):
f"sources={'active_docs' in source and source['active_docs'] is not None}"
)
self.model_id = model_id
self.model_user_id = model_user_id
self.doc_token_limit = doc_token_limit
self.user_api_key = user_api_key
self.agent_id = agent_id
self.llm_name = llm_name
self.api_key = api_key
# Forward model_id + model_user_id so LLMCreator resolves BYOM
# base_url / api_key / upstream id for the rephrase client.
self.llm = LLMCreator.create_llm(
self.llm_name,
api_key=self.api_key,
user_api_key=self.user_api_key,
decoded_token=decoded_token,
model_id=self.model_id,
agent_id=self.agent_id,
model_user_id=self.model_user_id,
)
if "active_docs" in source and source["active_docs"] is not None:
@@ -103,7 +109,11 @@ class ClassicRAG(BaseRetriever):
]
try:
rephrased_query = self.llm.gen(model=self.model_id, messages=messages)
# Send upstream id (resolved by LLMCreator), not registry UUID.
rephrased_query = self.llm.gen(
model=getattr(self.llm, "model_id", None) or self.model_id,
messages=messages,
)
print(f"Rephrased query: {rephrased_query}")
return rephrased_query if rephrased_query else self.original_question
except Exception as e:

View File

@@ -0,0 +1,464 @@
"""SSRF protection for user-supplied OpenAI-compatible base URLs.
This module is the single chokepoint for validating any URL that a user
provides as an OpenAI-compatible ``base_url`` ("Bring Your Own Model").
The backend will later issue outbound HTTP requests to that URL on the
user's behalf, so we must reject anything that could be used to reach
internal-network resources (cloud metadata services, RFC 1918 ranges,
loopback, link-local, etc.).
Three entry points:
* :func:`validate_user_base_url` — called at create/update time on REST
routes that persist the URL, to give the user immediate feedback.
* :func:`pinned_post` — called at dispatch time when the caller drives
``requests`` directly (e.g. the ``/api/models/test`` endpoint).
Resolves once, dials the IP literal, preserves the original hostname
in the ``Host`` header and via SNI / cert verification for HTTPS.
* :func:`pinned_httpx_client` — called at dispatch time when the caller
hands an ``httpx.Client`` to a third-party SDK (e.g. the OpenAI
Python SDK via ``OpenAI(http_client=...)``). Same DNS-rebinding
closure on the httpx transport layer.
Why all three: the OpenAI / httpx ecosystem performs its own DNS lookup
inside ``socket.getaddrinfo`` when a connection opens, so a hostile DNS
server can hand a public IP to the validator and a loopback / link-local
address to the HTTP client. Validate-then-construct-SDK is unsafe; the
pinned variants close that TOCTOU window by resolving exactly once and
dialing the chosen IP literal directly.
"""
from __future__ import annotations
import ipaddress
import socket
from typing import Any, Iterable
from urllib.parse import SplitResult, urlsplit, urlunsplit
import httpx
import requests
from requests.adapters import HTTPAdapter
# Allowed URL schemes. Anything else (file, gopher, ftp, data, ...) is
# rejected outright because it either bypasses HTTP entirely or enables
# protocol smuggling against the proxy stack.
_ALLOWED_SCHEMES: frozenset[str] = frozenset({"http", "https"})
# Hostnames that resolve to a loopback / metadata / unspecified address
# but which we want to reject *by name* as well, so the rejection
# message is unambiguous and so we never accidentally call DNS on them.
_BLOCKED_HOSTNAMES: frozenset[str] = frozenset(
{
"localhost",
"localhost.localdomain",
"0.0.0.0",
"::",
"::1",
"ip6-localhost",
"ip6-loopback",
# GCP metadata service. AWS/Azure use 169.254.169.254 which the
# IP-range check below already covers via the link-local range,
# but Google's hostname does not always resolve to a link-local
# IP from every VPC, so we hard-deny the string too.
"metadata.google.internal",
}
)
# Carrier-grade NAT (RFC 6598). Python's ``ipaddress`` module does NOT
# classify this range as ``is_private``, so we must check it explicitly.
_CGNAT_NETWORK_V4: ipaddress.IPv4Network = ipaddress.IPv4Network("100.64.0.0/10")
class UnsafeUserUrlError(ValueError):
"""Raised when a user-supplied URL fails SSRF validation.
Subclasses :class:`ValueError` so call sites that already treat
invalid input as a 400-class error continue to work. The string
message names the specific reason (scheme, hostname, resolved IP,
DNS failure, ...) so that it can be surfaced to the user verbatim.
"""
def _strip_ipv6_brackets(host: str) -> str:
"""Return ``host`` with surrounding ``[`` / ``]`` removed if present."""
if host.startswith("[") and host.endswith("]"):
return host[1:-1]
return host
def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
"""Return ``True`` if ``ip`` falls in any range we refuse to dial.
This is the single source of truth for the IP-range policy:
* loopback (``127.0.0.0/8``, ``::1``)
* private (RFC 1918, ULA ``fc00::/7``)
* link-local (``169.254.0.0/16``, ``fe80::/10``)
* multicast (``224.0.0.0/4``, ``ff00::/8``)
* unspecified (``0.0.0.0``, ``::``)
* reserved (``240.0.0.0/4``, etc.)
* carrier-grade NAT (``100.64.0.0/10``) — not covered by ``is_private``
"""
if (
ip.is_loopback
or ip.is_private
or ip.is_link_local
or ip.is_multicast
or ip.is_unspecified
or ip.is_reserved
):
return True
if isinstance(ip, ipaddress.IPv4Address) and ip in _CGNAT_NETWORK_V4:
return True
return False
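# Spot-checks of the policy (addresses illustrative):
#   _is_blocked_ip(ipaddress.ip_address("127.0.0.1"))       -> True  (loopback)
#   _is_blocked_ip(ipaddress.ip_address("10.0.0.1"))        -> True  (RFC 1918)
#   _is_blocked_ip(ipaddress.ip_address("169.254.169.254")) -> True  (link-local / metadata)
#   _is_blocked_ip(ipaddress.ip_address("100.64.0.1"))      -> True  (CGNAT)
#   _is_blocked_ip(ipaddress.ip_address("93.184.216.34"))   -> False (public)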
def _resolve(host: str) -> Iterable[ipaddress.IPv4Address | ipaddress.IPv6Address]:
"""Resolve ``host`` to every A/AAAA record returned by the system.
Returning *all* addresses (rather than the first one) is critical:
a hostile DNS server can return a public IP first followed by a
private IP, and the underlying HTTP client may fail over to the
private one on connect. We treat the set as unsafe if any element
is unsafe.
"""
try:
results = socket.getaddrinfo(host, None)
except socket.gaierror as exc: # noqa: PERF203 — re-raise as our own type
raise UnsafeUserUrlError(f"could not resolve hostname {host!r}: {exc}") from exc
addresses: list[ipaddress.IPv4Address | ipaddress.IPv6Address] = []
for entry in results:
sockaddr = entry[4]
# IPv4 sockaddr: (host, port). IPv6 sockaddr: (host, port, flowinfo, scope_id).
ip_str = sockaddr[0]
# Strip IPv6 zone-id ("fe80::1%lo0") before parsing.
if "%" in ip_str:
ip_str = ip_str.split("%", 1)[0]
try:
addresses.append(ipaddress.ip_address(ip_str))
except ValueError:
# An entry we can't parse is itself suspicious; treat as unsafe.
raise UnsafeUserUrlError(
f"hostname {host!r} resolved to unparseable address {ip_str!r}"
) from None
return addresses
def _validate_and_pick_ip(
url: str,
) -> tuple[str, ipaddress.IPv4Address | ipaddress.IPv6Address, SplitResult]:
"""Run the SSRF guard and return the data needed to dial safely.
Performs every check :func:`validate_user_base_url` performs, but
additionally returns ``(hostname, ip, parts)`` where ``ip`` is one
of the validated addresses (the first record returned by the
resolver, or the literal itself if the URL already used an IP) and
``parts`` is the :func:`urllib.parse.urlsplit` result so callers do
not have to re-parse the URL.
Raises :class:`UnsafeUserUrlError` on the same conditions as
:func:`validate_user_base_url`.
"""
if not isinstance(url, str) or not url.strip():
raise UnsafeUserUrlError("url must be a non-empty string")
try:
parts = urlsplit(url)
except ValueError as exc:
raise UnsafeUserUrlError(f"could not parse url {url!r}: {exc}") from exc
scheme = parts.scheme.lower()
if scheme not in _ALLOWED_SCHEMES:
raise UnsafeUserUrlError(
f"scheme {scheme!r} is not allowed; only http and https are permitted"
)
# ``urlsplit`` returns the bracketed form for IPv6 in ``netloc`` but
# the bare form in ``hostname``. Normalize via lower() because
# hostnames are case-insensitive and we compare against a lowercase
# blocklist.
raw_host = parts.hostname
if not raw_host:
raise UnsafeUserUrlError(f"url {url!r} has no hostname")
host = raw_host.lower()
# Check the literal-string blocklist first. urlsplit().hostname strips
# IPv6 brackets, so we also test the bracketed form for completeness
# (matches the public-spec note about ``[::]``).
bracketed = f"[{host}]"
if host in _BLOCKED_HOSTNAMES or bracketed in _BLOCKED_HOSTNAMES:
raise UnsafeUserUrlError(
f"hostname {raw_host!r} is not allowed (matches internal-only name)"
)
# If the host is already an IP literal (with or without IPv6 brackets),
# check it directly without going to DNS — DNS for an IP literal is a
# no-op but it's clearer to short-circuit and gives a better message.
candidate = _strip_ipv6_brackets(host)
try:
literal = ipaddress.ip_address(candidate)
except ValueError:
literal = None
if literal is not None:
if _is_blocked_ip(literal):
raise UnsafeUserUrlError(
f"hostname {raw_host!r} resolves to blocked address {literal} "
f"(loopback/private/link-local/multicast/reserved/CGNAT)"
)
return host, literal, parts
# Hostname (not an IP literal) — resolve and validate every record.
addresses = list(_resolve(host))
for ip in addresses:
if _is_blocked_ip(ip):
raise UnsafeUserUrlError(
f"hostname {raw_host!r} resolves to blocked address {ip} "
f"(loopback/private/link-local/multicast/reserved/CGNAT)"
)
if not addresses:
# ``getaddrinfo`` would normally raise instead of returning an
# empty list, but treat the degenerate case as unsafe too — we
# have nothing to bind a connection to.
raise UnsafeUserUrlError(
f"hostname {raw_host!r} returned no addresses from DNS"
)
return host, addresses[0], parts
def validate_user_base_url(url: str) -> None:
"""Validate that ``url`` is safe to use as an outbound base URL.
Resolve the URL's hostname to one or more IPs and reject if any
resolved IP is private/loopback/link-local/multicast/reserved, or if
the URL uses a non-http(s) scheme, or if the hostname is one of the
known dangerous strings (``localhost``, ``0.0.0.0``, ``[::]``).
Raises :class:`UnsafeUserUrlError` on rejection. Returns ``None`` on
success.
This function is the create/update-time check. At dispatch time use
:func:`pinned_post` instead, which performs the same validation
*and* pins the outbound connection to the validated IP so a DNS
rebinder cannot flip the resolution between check and connect.
Args:
url: The user-supplied URL to validate. Expected to be an
absolute URL with an ``http`` or ``https`` scheme.
Raises:
UnsafeUserUrlError: If the URL fails to parse, uses a forbidden
scheme, has an empty/blocklisted hostname, fails DNS
resolution, or resolves to any IP in a blocked range.
"""
_validate_and_pick_ip(url)
class _PinnedHostAdapter(HTTPAdapter):
"""HTTPS adapter that performs SNI and cert verification against a
fixed hostname even when the URL connects to an IP literal.
Used by :func:`pinned_post` so that resolving the user-supplied
hostname once and dialing the resolved IP doesn't break TLS.
Without this, ``urllib3`` would default ``server_hostname`` /
``assert_hostname`` to the connect host (the IP) and either send the
wrong SNI or fail cert verification — the cert is for the original
hostname, not the IP literal.
"""
def __init__(self, server_hostname: str, *args: Any, **kwargs: Any) -> None:
self._server_hostname = server_hostname
super().__init__(*args, **kwargs)
def init_poolmanager(self, *args: Any, **kwargs: Any) -> None:
kwargs["server_hostname"] = self._server_hostname
kwargs["assert_hostname"] = self._server_hostname
super().init_poolmanager(*args, **kwargs)
def _ip_to_url_host(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> str:
"""Return ``ip`` formatted for use in a URL netloc (brackets for v6)."""
if isinstance(ip, ipaddress.IPv6Address):
return f"[{ip}]"
return str(ip)
def pinned_post(
url: str,
*,
json: Any = None,
headers: dict[str, str] | None = None,
timeout: float = 5.0,
allow_redirects: bool = False,
) -> requests.Response:
"""POST to ``url`` with the outbound connection pinned to a single
validated IP, closing the DNS-rebinding TOCTOU window left by the
naive validate-then-``requests.post`` pattern.
The URL's hostname is resolved exactly once. Every returned address
must pass the same SSRF guard as :func:`validate_user_base_url`. The
outbound request is issued against the chosen IP literal (so
``urllib3`` cannot ask the resolver again and receive a different
answer); the original hostname is preserved in the ``Host`` header
and, for HTTPS, via :class:`_PinnedHostAdapter` for SNI and cert
verification.
Args:
url: Absolute http(s) URL to POST to.
json: JSON-serializable payload — passed through to ``requests``.
headers: Caller-supplied headers. Any caller-supplied ``Host``
entry is overwritten so the in-flight request matches what
was validated.
timeout: Per-request timeout (seconds).
allow_redirects: Forwarded to ``requests``. Defaults to
``False`` because the SSRF guard only inspects the supplied
URL — following redirects would let a hostile upstream
bounce the request to an internal address.
Raises:
UnsafeUserUrlError: If the URL fails the SSRF guard.
requests.RequestException: For network-level failures.
"""
host, ip, parts = _validate_and_pick_ip(url)
netloc = _ip_to_url_host(ip)
if parts.port is not None:
netloc = f"{netloc}:{parts.port}"
pinned_url = urlunsplit(
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
)
request_headers = dict(headers or {})
host_header = host if parts.port is None else f"{host}:{parts.port}"
request_headers["Host"] = host_header
session = requests.Session()
if parts.scheme == "https":
session.mount("https://", _PinnedHostAdapter(host))
try:
return session.post(
pinned_url,
json=json,
headers=request_headers,
timeout=timeout,
allow_redirects=allow_redirects,
)
finally:
session.close()
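# Usage sketch (endpoint and payload are illustrative): a dispatch-time
# POST that cannot be DNS-rebound between validation and connect.
#
#   resp = pinned_post(
#       "https://models.example.com/v1/chat/completions",
#       json={"model": "my-model", "messages": []},
#       headers={"Authorization": "Bearer sk-..."},
#   )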
class _PinnedHTTPSTransport(httpx.HTTPTransport):
"""``httpx`` transport pinned to a single validated IP literal.
Closes the DNS-rebinding TOCTOU window that
:func:`validate_user_base_url` cannot close on its own. The OpenAI
Python SDK (and any other SDK that uses ``httpx``) re-resolves the
hostname inside ``socket.getaddrinfo`` at request time, so a
hostile DNS server can return a public IP at validation time and a
private IP at request time. This transport rewrites every outgoing
request's URL host to the validated IP literal so ``httpcore``
dials that IP without a fresh lookup.
The original hostname is preserved in two places:
1. ``Host`` header — ``httpx.Request._prepare`` set it from the URL
netloc *before* this transport runs, so it carries the hostname
not the IP literal. We deliberately do not touch headers here.
2. TLS SNI / cert verification — set via the
``request.extensions["sni_hostname"]`` extension which
``httpcore`` feeds into ``start_tls``'s ``server_hostname``
parameter. Without this, ``urllib3``-equivalent code would use
the IP literal as SNI and cert verification would fail (the
cert is for the original hostname, not the IP).
"""
def __init__(
self,
validated_host: str,
validated_ip: ipaddress.IPv4Address | ipaddress.IPv6Address,
**kwargs: Any,
) -> None:
# http2=False (the httpx default) — defense in depth against
# HTTP/2 connection coalescing (RFC 7540 §9.1.1), where a
# client may reuse a TCP connection for any host whose cert
# covers it. Per-IP pinning never shares connections across
# hosts, but explicit is safer than relying on the default.
kwargs.setdefault("http2", False)
super().__init__(**kwargs)
self._host = validated_host
self._ip_netloc = _ip_to_url_host(validated_ip)
def handle_request(self, request: httpx.Request) -> httpx.Response:
# Defense in depth: refuse if the request URL's host doesn't
# match what we validated. Catches any future SDK regression
# that rewrites the URL between Request construction and dial,
# and any rare case where the SDK reuses our pinned client for
# a different host (which it shouldn't, but assert it anyway).
if request.url.host != self._host:
raise UnsafeUserUrlError(
f"pinned transport bound to {self._host!r}, refused "
f"request for {request.url.host!r}"
)
# SNI/server_hostname for TLS verification. httpcore reads this
# extension at _sync/connection.py and feeds it into
# start_tls's server_hostname argument. Set before the URL host
# is rewritten so cert validation continues to use the original
# hostname even though TCP dials the IP literal.
request.extensions = {
**request.extensions,
"sni_hostname": self._host.encode("ascii"),
}
request.url = request.url.copy_with(host=self._ip_netloc)
return super().handle_request(request)
def pinned_httpx_client(
base_url: str,
*,
timeout: float = 600.0,
) -> httpx.Client:
"""Return an :class:`httpx.Client` whose connections are pinned to
one validated IP, closing the DNS-rebinding TOCTOU window the naive
``OpenAI(base_url=...)`` flow leaves open.
The hostname in ``base_url`` is resolved exactly once. Every
returned address must pass :func:`_validate_and_pick_ip`'s SSRF
guard (loopback, RFC 1918, link-local, multicast, reserved, CGNAT,
cloud metadata names). The chosen IP becomes the URL host on every
outgoing request so ``httpcore`` cannot ask the resolver again.
Pass via ``OpenAI(http_client=pinned_httpx_client(base_url))`` (or
any other SDK that accepts an ``httpx.Client``) to make BYOM
dispatch immune to DNS-rebinding TOCTOU.
Args:
base_url: User-supplied http(s) URL. Validated through the same
SSRF guard as :func:`validate_user_base_url`.
timeout: Per-request timeout (seconds). Defaults to 600 to
match the OpenAI SDK's default; callers should override
for non-LLM workloads.
Raises:
UnsafeUserUrlError: If ``base_url`` fails the SSRF guard.
"""
host, ip, _parts = _validate_and_pick_ip(base_url)
transport = _PinnedHTTPSTransport(host, ip)
# follow_redirects=False — the SSRF guard only inspects the
# supplied URL; following 3xx would let a hostile upstream bounce
# the in-network request to an internal address (cloud metadata,
# RFC1918, loopback) carrying whatever credentials the SDK adds.
return httpx.Client(
transport=transport,
timeout=timeout,
follow_redirects=False,
)
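# Usage sketch with the OpenAI SDK (pattern from the docstring above;
# the base_url is illustrative):
#
#   from openai import OpenAI
#   client = OpenAI(
#       base_url="https://models.example.com/v1",
#       api_key="sk-...",
#       http_client=pinned_httpx_client("https://models.example.com/v1"),
#   )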

View File

@@ -0,0 +1,153 @@
"""Shared retrieval service used by the HTTP search route and the MCP tool.
Flask-free. Raises domain exceptions (``InvalidAPIKey``, ``SearchFailed``)
that callers translate into their own wire protocol (HTTP status codes,
MCP error responses, etc.).
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List
from application.core.settings import settings
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly
from application.vectorstore.vector_creator import VectorCreator
logger = logging.getLogger(__name__)
class InvalidAPIKey(Exception):
"""The supplied ``api_key`` does not resolve to an agent."""
class SearchFailed(Exception):
"""Unexpected error during retrieval (e.g. DB outage). Caller maps to 5xx."""
def _collect_source_ids(agent: Dict[str, Any]) -> List[str]:
"""Extract the ordered list of source UUIDs to search.
Prefers ``extra_source_ids`` (PG ARRAY(UUID) of multi-source agents);
falls back to the legacy single ``source_id`` field.
"""
source_ids: List[str] = []
extra = agent.get("extra_source_ids") or []
for src in extra:
if src:
source_ids.append(str(src))
if not source_ids:
single = agent.get("source_id")
if single:
source_ids.append(str(single))
return source_ids
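# Examples (shapes illustrative): ``extra_source_ids`` wins over the
# legacy single ``source_id``:
#   {"extra_source_ids": ["a", "b"], "source_id": "c"} -> ["a", "b"]
#   {"extra_source_ids": [], "source_id": "c"}         -> ["c"]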
def _search_sources(
query: str, source_ids: List[str], chunks: int
) -> List[Dict[str, Any]]:
"""Search across each source's vectorstore and return up to ``chunks`` hits.
Per-source errors are logged and skipped so one broken index doesn't
take down the whole search. Results are de-duplicated by content hash.
"""
if chunks <= 0 or not source_ids:
return []
results: List[Dict[str, Any]] = []
chunks_per_source = max(1, chunks // len(source_ids))
seen_texts: set[int] = set()
for source_id in source_ids:
if not source_id or not source_id.strip():
continue
try:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
)
docs = docsearch.search(query, k=chunks_per_source * 2)
for doc in docs:
if len(results) >= chunks:
break
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
page_content = doc.page_content
metadata = doc.metadata
else:
page_content = doc.get("text", doc.get("page_content", ""))
metadata = doc.get("metadata", {})
text_hash = hash(page_content[:200])
if text_hash in seen_texts:
continue
seen_texts.add(text_hash)
title = metadata.get("title", metadata.get("post_title", ""))
if not isinstance(title, str):
title = str(title) if title else ""
if title:
title = title.split("/")[-1]
else:
title = metadata.get("filename", page_content[:50] + "...")
source = metadata.get("source", source_id)
results.append(
{
"text": page_content,
"title": title,
"source": source,
}
)
if len(results) >= chunks:
break
except Exception as e:
logger.error(
f"Error searching vectorstore {source_id}: {e}",
exc_info=True,
)
continue
return results[:chunks]
def search(api_key: str, query: str, chunks: int = 5) -> List[Dict[str, Any]]:
"""Resolve an agent by API key and search its sources.
Args:
api_key: Agent API key (the opaque string stored on
``agents.key`` in Postgres).
query: Free-text search query.
chunks: Max number of hits to return.
Returns:
List of hit dicts with ``text``, ``title``, ``source`` keys.
Empty list if the agent has no sources configured.
Raises:
InvalidAPIKey: if ``api_key`` does not resolve to an agent.
SearchFailed: on unexpected DB / infrastructure errors.
"""
if chunks <= 0:
return []
try:
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
except Exception as e:
raise SearchFailed("agent lookup failed") from e
if not agent:
raise InvalidAPIKey()
source_ids = _collect_source_ids(agent)
if not source_ids:
return []
return _search_sources(query, source_ids, chunks)
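# Caller-side sketch: translating the domain exceptions into a wire
# protocol, as the HTTP route does (status codes are illustrative):
#
#   try:
#       hits = search(api_key, query, chunks=5)
#   except InvalidAPIKey:
#       return {"error": "invalid API key"}, 401
#   except SearchFailed:
#       return {"error": "search failed"}, 500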

View File

@@ -203,6 +203,24 @@ agents_table = Table(
Column("legacy_mongo_id", Text),
)
user_custom_models_table = Table(
"user_custom_models",
metadata,
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
Column("user_id", Text, nullable=False),
Column("upstream_model_id", Text, nullable=False),
Column("display_name", Text, nullable=False),
Column("description", Text, nullable=False, server_default=""),
Column("base_url", Text, nullable=False),
# AES-CBC ciphertext (base64) keyed via per-user PBKDF2 in
# application.security.encryption.encrypt_credentials.
Column("api_key_encrypted", Text, nullable=False),
Column("capabilities", JSONB, nullable=False, server_default="{}"),
Column("enabled", Boolean, nullable=False, server_default="true"),
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
attachments_table = Table(
"attachments",
metadata,

View File

@@ -1,7 +1,6 @@
"""Repository for the ``agents`` table.
This is the most complex Phase 2 repository. Covers every write operation
the legacy Mongo code performs on ``agents_collection``:
Covers every write operation the legacy Mongo code performs on ``agents_collection``:
- create, update, delete
- find by key (API key lookup)

View File

@@ -17,6 +17,21 @@ _UPDATABLE_SCALARS = {
_UPDATABLE_JSONB = {"metadata"}
def _attachment_to_dict(row: Any) -> dict:
"""row_to_dict + ``upload_path``→``path`` alias.
Pre-Postgres, the Mongo attachment shape used ``path``. The PG column
is ``upload_path``; LLM provider code (google_ai/openai/anthropic and
handlers/base) still reads ``attachment.get("path")``. Mirroring the
``id``/``_id`` dual-emit in row_to_dict so consumers don't need to
know which storage backend produced the dict.
"""
out = row_to_dict(row)
if "upload_path" in out and out.get("path") is None:
out["path"] = out["upload_path"]
return out
class AttachmentsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
@@ -66,7 +81,7 @@ class AttachmentsRepository:
"legacy_mongo_id": legacy_mongo_id,
},
)
return row_to_dict(result.fetchone())
return _attachment_to_dict(result.fetchone())
def get(self, attachment_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
@@ -76,7 +91,7 @@ class AttachmentsRepository:
{"id": attachment_id, "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
return _attachment_to_dict(row) if row is not None else None
def get_any(self, attachment_id: str, user_id: str) -> Optional[dict]:
"""Resolve an attachment by either PG UUID or legacy Mongo ObjectId string."""
@@ -155,14 +170,14 @@ class AttachmentsRepository:
params["user_id"] = user_id
result = self._conn.execute(text(sql), params)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
return _attachment_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str) -> list[dict]:
result = self._conn.execute(
text("SELECT * FROM attachments WHERE user_id = :user_id ORDER BY created_at DESC"),
{"user_id": user_id},
)
return [row_to_dict(r) for r in result.fetchall()]
return [_attachment_to_dict(r) for r in result.fetchall()]
def update(self, attachment_id: str, user_id: str, fields: dict) -> bool:
"""Partial update. Used by the LLM providers to cache their

View File

@@ -0,0 +1,199 @@
"""Repository for the ``user_custom_models`` table.
Backs the end-user "Bring Your Own Model" feature. Each row is one
user-supplied OpenAI-compatible endpoint (Mistral, Together, vLLM, ...).
The ``id`` UUID is the internal DocsGPT identifier (what agents store
in ``default_model_id``); ``upstream_model_id`` is what we send verbatim
to the provider's API.
API key handling: callers pass plaintext via ``api_key_plaintext``;
this module wraps the existing ``application.security.encryption``
helper (AES-CBC + per-user PBKDF2 salt) and writes the base64 ciphertext
to the ``api_key_encrypted`` column. Decryption is the caller's
responsibility (they hold the ``user_id``).
"""
from __future__ import annotations
from typing import Any, Optional
from sqlalchemy import Connection, func, text
from application.security.encryption import (
decrypt_credentials,
encrypt_credentials,
)
from application.storage.db.base_repository import row_to_dict
from application.storage.db.models import user_custom_models_table
_ALLOWED_CAPABILITY_KEYS = frozenset(
{
"supports_tools",
"supports_structured_output",
"supports_streaming",
"attachments",
"context_window",
}
)
class UserCustomModelsRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
# ------------------------------------------------------------------ #
# Encryption wrappers
# ------------------------------------------------------------------ #
@staticmethod
def _encrypt_api_key(api_key_plaintext: str, user_id: str) -> str:
"""Encrypt ``api_key_plaintext`` with the per-user PBKDF2 scheme."""
return encrypt_credentials({"api_key": api_key_plaintext}, user_id)
@staticmethod
def _decrypt_api_key(api_key_encrypted: str, user_id: str) -> Optional[str]:
"""Decrypt the API key. Returns None on failure (which the caller
should surface as a configuration error rather than silently
proceeding with the upstream call)."""
if not api_key_encrypted:
return None
creds = decrypt_credentials(api_key_encrypted, user_id)
return creds.get("api_key") if creds else None
@staticmethod
def _normalize_capabilities(caps: Optional[dict]) -> dict:
"""Drop unknown keys; nothing else is forced. Callers (the route
layer) are responsible for value validation (numeric ranges,
attachment alias resolution)."""
if not caps:
return {}
return {k: v for k, v in caps.items() if k in _ALLOWED_CAPABILITY_KEYS}
# ------------------------------------------------------------------ #
# CRUD
# ------------------------------------------------------------------ #
def create(
self,
user_id: str,
upstream_model_id: str,
display_name: str,
base_url: str,
api_key_plaintext: str,
description: str = "",
capabilities: Optional[dict] = None,
enabled: bool = True,
) -> dict:
values = {
"user_id": user_id,
"upstream_model_id": upstream_model_id,
"display_name": display_name,
"description": description or "",
"base_url": base_url,
"api_key_encrypted": self._encrypt_api_key(api_key_plaintext, user_id),
"capabilities": self._normalize_capabilities(capabilities),
"enabled": bool(enabled),
}
from sqlalchemy.dialects.postgresql import insert as pg_insert
stmt = (
pg_insert(user_custom_models_table)
.values(**values)
.returning(user_custom_models_table)
)
result = self._conn.execute(stmt)
return row_to_dict(result.fetchone())
def get(self, model_id: str, user_id: str) -> Optional[dict]:
result = self._conn.execute(
text(
"SELECT * FROM user_custom_models "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
{"id": str(model_id), "user_id": user_id},
)
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def list_for_user(self, user_id: str) -> list[dict]:
result = self._conn.execute(
text(
"SELECT * FROM user_custom_models "
"WHERE user_id = :user_id ORDER BY created_at DESC"
),
{"user_id": user_id},
)
return [row_to_dict(r) for r in result.fetchall()]
def update(self, model_id: str, user_id: str, fields: dict) -> bool:
"""Apply a partial update.
Special-cases ``api_key_plaintext``: when present, it is encrypted
and stored in ``api_key_encrypted``. When absent (or empty), the
existing ciphertext is kept untouched. This is the wire-shape
``PATCH`` expects (the UI sends a blank password field when the
operator wants to keep the existing key).
"""
allowed = {
"upstream_model_id",
"display_name",
"description",
"base_url",
"capabilities",
"enabled",
}
values: dict[str, Any] = {}
for col, val in fields.items():
if col not in allowed or val is None:
continue
if col == "capabilities":
values[col] = self._normalize_capabilities(val)
elif col == "enabled":
values[col] = bool(val)
else:
values[col] = val
api_key_plaintext = fields.get("api_key_plaintext")
if api_key_plaintext:
values["api_key_encrypted"] = self._encrypt_api_key(
api_key_plaintext, user_id
)
if not values:
return False
values["updated_at"] = func.now()
t = user_custom_models_table
stmt = (
t.update()
.where(t.c.id == str(model_id))
.where(t.c.user_id == user_id)
.values(**values)
)
result = self._conn.execute(stmt)
return result.rowcount > 0
def delete(self, model_id: str, user_id: str) -> bool:
result = self._conn.execute(
text(
"DELETE FROM user_custom_models "
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
),
{"id": str(model_id), "user_id": user_id},
)
return result.rowcount > 0
# ------------------------------------------------------------------ #
# Decryption helpers exposed to the registry layer
# ------------------------------------------------------------------ #
def get_decrypted_api_key(
self, model_id: str, user_id: str
) -> Optional[str]:
"""Convenience: fetch the row and return the decrypted API key,
or ``None`` if the row is missing or decryption fails."""
row = self.get(model_id, user_id)
if row is None:
return None
return self._decrypt_api_key(row.get("api_key_encrypted", ""), user_id)
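# Round-trip sketch (connection and identifiers are hypothetical):
#
#   repo = UserCustomModelsRepository(conn)
#   row = repo.create(
#       user_id="u-123",
#       upstream_model_id="mistral-large-latest",
#       display_name="My Mistral",
#       base_url="https://api.mistral.ai/v1",
#       api_key_plaintext="sk-...",
#   )
#   repo.get_decrypted_api_key(row["id"], "u-123")  # -> "sk-..."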

View File

@@ -1,9 +1,11 @@
"""Anonymous startup version-check client.
"""Anonymous version-check client.
Called once per Celery worker boot (see ``application/celery_init.py``
``worker_ready`` handler). Posts the running version + anonymous
instance UUID to ``gptcloud.arc53.com/api/check``, caches the response
in Redis, and surfaces any advisories to stdout + logs.
Fired on every Celery worker boot (see ``application/celery_init.py``
``worker_ready`` handler) and on a 7h periodic schedule (see the
``version-check`` entry in ``application/api/user/tasks.py``). Posts
the running version + anonymous instance UUID to
``gptcloud.arc53.com/api/check``, caches the response in Redis, and
surfaces any advisories to stdout + logs.
Design invariants — all enforced by a broad ``try/except`` at the top
of :func:`run_check`:

View File

@@ -1,5 +1,6 @@
import sys
import logging
import time
from datetime import datetime
from application.storage.db.repositories.token_usage import TokenUsageRepository
@@ -20,6 +21,15 @@ def _serialize_for_token_count(value):
if value is None:
return ""
# Raw binary payloads (image/file attachments arrive as ``bytes`` from
# ``GoogleLLM.prepare_messages_with_attachments``) — without this
# branch they fall through to ``str(value)`` below, which produces a
# multi-megabyte ``"b'\\x89PNG...'"`` repr-string and inflates
# ``prompt_tokens`` by orders of magnitude. Same intent as the
# data-URL skip above.
if isinstance(value, (bytes, bytearray, memoryview)):
return ""
if isinstance(value, list):
return [_serialize_for_token_count(item) for item in value]
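# Illustrative effect of the bytes branch above:
#   _serialize_for_token_count(b"\x89PNG")      -> ""
#   _serialize_for_token_count(["hi", b"\x00"]) -> ["hi", ""]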
@@ -145,19 +155,44 @@ def stream_token_usage(func):
**kwargs,
)
batch = []
result = func(self, model, messages, stream, tools, **kwargs)
for r in result:
batch.append(r)
yield r
for line in batch:
call_usage["generated_tokens"] += _count_tokens(line)
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
update_token_usage(
self.decoded_token,
self.user_api_key,
call_usage,
getattr(self, "agent_id", None),
)
started_at = time.monotonic()
error: BaseException | None = None
try:
result = func(self, model, messages, stream, tools, **kwargs)
for r in result:
batch.append(r)
yield r
except Exception as exc:
# ``GeneratorExit`` (consumer disconnected) and KeyboardInterrupt
# flow through as ``status="ok"`` — same convention as
# ``application.logging._consume_and_log``.
error = exc
raise
finally:
for line in batch:
call_usage["generated_tokens"] += _count_tokens(line)
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
# Persist usage rows only on success: a partial mid-stream
# failure shouldn't bill the user for a response they never got.
if error is None:
update_token_usage(
self.decoded_token,
self.user_api_key,
call_usage,
getattr(self, "agent_id", None),
)
emit = getattr(self, "_emit_stream_finished_log", None)
if callable(emit):
try:
emit(
model,
prompt_tokens=call_usage["prompt_tokens"],
completion_tokens=call_usage["generated_tokens"],
latency_ms=int((time.monotonic() - started_at) * 1000),
error=error,
)
except Exception:
logger.exception("Failed to emit llm_stream_finished")
return wrapper

View File

@@ -83,9 +83,9 @@ def count_tokens_docs(docs):
def calculate_doc_token_budget(
model_id: str = "gpt-4o"
model_id: str = "gpt-4o", user_id: str | None = None
) -> int:
total_context = get_token_limit(model_id)
total_context = get_token_limit(model_id, user_id=user_id)
reserved = sum(settings.RESERVED_TOKENS.values())
doc_budget = total_context - reserved
return max(doc_budget, 1000)
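# Worked example (numbers illustrative): a 128_000-token context with
# RESERVED_TOKENS summing to 4_000 gives a doc budget of 124_000; a
# 2_000-token model floors at the 1_000 minimum.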
@@ -150,9 +150,11 @@ def get_hash(data):
return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()
def limit_chat_history(history, max_token_limit=None, model_id="docsgpt-local"):
def limit_chat_history(
history, max_token_limit=None, model_id="docsgpt-local", user_id=None
):
"""Limit chat history to fit within token limit."""
model_token_limit = get_token_limit(model_id)
model_token_limit = get_token_limit(model_id, user_id=user_id)
max_token_limit = (
max_token_limit
if max_token_limit and max_token_limit < model_token_limit
@@ -204,7 +206,9 @@ def generate_image_url(image_path):
def calculate_compression_threshold(
model_id: str, threshold_percentage: float = 0.8
model_id: str,
threshold_percentage: float = 0.8,
user_id: str | None = None,
) -> int:
"""
Calculate token threshold for triggering compression.
@@ -212,11 +216,13 @@ def calculate_compression_threshold(
Args:
model_id: Model identifier
threshold_percentage: Percentage of context window (default 80%)
user_id: When set, UUID-keyed BYOM custom-model records are
resolved for the context-window lookup.
Returns:
Token count threshold
"""
total_context = get_token_limit(model_id)
total_context = get_token_limit(model_id, user_id=user_id)
threshold = int(total_context * threshold_percentage)
return threshold

View File

@@ -344,18 +344,34 @@ def run_agent_logic(agent_config, input_data):
# Determine model_id: check agent's default_model_id, fallback to system default
agent_default_model = agent_config.get("default_model_id", "")
if agent_default_model and validate_model_id(agent_default_model):
if agent_default_model and validate_model_id(
agent_default_model, user_id=owner
):
model_id = agent_default_model
else:
model_id = get_default_model_id()
if agent_default_model:
# Stored model_id no longer resolves in the registry. Log so
# operators can detect bad YAML edits before users complain;
# behavior matches the historical silent fallback.
logging.warning(
"Agent %s references unknown model_id %r; falling back to %r",
agent_id,
agent_default_model,
model_id,
)
# Get provider and API key for the selected model
provider = get_provider_from_model_id(model_id) if model_id else settings.LLM_PROVIDER
provider = (
get_provider_from_model_id(model_id, user_id=owner)
if model_id
else settings.LLM_PROVIDER
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
# Calculate proper doc_token_limit based on model's context window
doc_token_limit = calculate_doc_token_budget(
model_id=model_id
model_id=model_id, user_id=owner
)
retriever = RetrieverCreator.create_retriever(
@@ -416,7 +432,10 @@ def run_agent_logic(agent_config, input_data):
"tool_calls": tool_calls,
"thought": thought,
}
logging.info(f"Agent response: {result}")
# Per-activity summary fields (answer_length, thought_length,
# source_count, tool_call_count) now ride on the inner
# ``activity_finished`` event emitted by ``log_activity`` around
# ``Agent.gen`` above; no separate ``agent_response`` log needed.
return result
except Exception as e:
logging.error(f"Error in run_agent_logic: {e}", exc_info=True)

View File

@@ -104,7 +104,15 @@ To run the DocsGPT backend locally, you'll need to set up a Python environment a
flask --app application/app.py run --host=0.0.0.0 --port=7091
```
This command will launch the backend server, making it accessible on `http://localhost:7091`.
This command will launch the backend server, making it accessible on `http://localhost:7091`. It's the fastest inner-loop option for day-to-day development — the Werkzeug interactive debugger still works and it hot-reloads on source changes. It serves the Flask routes only.
If you need to exercise the full ASGI stack — the `/mcp` endpoint (FastMCP server), or to match the production runtime — run the ASGI composition under uvicorn instead:
```bash
uvicorn application.asgi:asgi_app --host 0.0.0.0 --port 7091 --reload
```
Production uses `gunicorn -k uvicorn_worker.UvicornWorker` against the same `application.asgi:asgi_app` target.
6. **Start the Celery Worker:**

View File

@@ -99,6 +99,82 @@ EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2 # You can al
In this case, even though you are using Ollama locally, `LLM_PROVIDER` is set to `openai` because Ollama (like many other local inference engines) is designed to be API-compatible with OpenAI. `OPENAI_BASE_URL` points DocsGPT to the local Ollama server.
## Adding Custom Models (`MODELS_CONFIG_DIR`)
DocsGPT ships with a built-in catalog of models for the providers it
supports out of the box (OpenAI, Anthropic, Google, Groq, OpenRouter,
Novita, Azure OpenAI, Hugging Face, DocsGPT). To add **your own
models** without forking the repo — for example, a Mistral or Together
account, a self-hosted vLLM endpoint, or any other OpenAI-compatible
API — point `MODELS_CONFIG_DIR` at a directory of YAML files.
```bash
MODELS_CONFIG_DIR=/etc/docsgpt/models
MISTRAL_API_KEY=sk-...
```
A minimal YAML for one provider:
```yaml
# /etc/docsgpt/models/mistral.yaml
provider: openai_compatible
display_provider: mistral
api_key_env: MISTRAL_API_KEY
base_url: https://api.mistral.ai/v1
defaults:
supports_tools: true
context_window: 128000
models:
- id: mistral-large-latest
display_name: Mistral Large
- id: mistral-small-latest
display_name: Mistral Small
```
After restart, those models appear in `/api/models` and are selectable
in the UI. A working template lives at
`application/core/models/examples/mistral.yaml.example`.
**What you can do:**
- Add new `openai_compatible` providers (Mistral, Together, Fireworks,
Ollama, vLLM, ...) — one YAML per provider, each with its own
`api_key_env` and `base_url`.
- Extend an existing provider's catalog by dropping a YAML with the
same `provider:` value as the built-in (e.g. `provider: anthropic`
with extra models).
- Override a built-in model's capabilities by re-declaring the same
`id` — later wins, override is logged at `WARNING`.
**What you cannot do via `MODELS_CONFIG_DIR`:** add a brand-new
non-OpenAI provider. That requires a Python plugin under
`application/llm/providers/`. See
`application/core/models/README.md` for the full schema reference.
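For orientation, such a plugin is a small subclass of `Provider`. Below is a hedged sketch modeled on the built-in plugins; the `MyLLM` adapter class and the `MYPROVIDER_API_KEY` settings field are hypothetical placeholders:
```python
# application/llm/providers/myprovider.py (hypothetical file)
from __future__ import annotations
from typing import Optional

from application.llm.providers.base import Provider
from myproject.llm import MyLLM  # hypothetical BaseLLM subclass


class MyProvider(Provider):
    name = "myprovider"
    llm_class = MyLLM

    def get_api_key(self, settings) -> Optional[str]:
        # Hypothetical settings field; mirror the built-in providers.
        return settings.MYPROVIDER_API_KEY
```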
### Docker
Mount the directory and set the env var:
```yaml
# docker-compose.yml
services:
app:
image: arc53/docsgpt
environment:
MODELS_CONFIG_DIR: /etc/docsgpt/models
MISTRAL_API_KEY: ${MISTRAL_API_KEY}
volumes:
- ./my-models:/etc/docsgpt/models:ro
```
### Misconfiguration
If `MODELS_CONFIG_DIR` is set but the path doesn't exist (or isn't a
directory), the app logs a `WARNING` at boot and continues with just
the built-in catalog — it does **not** fail to start. If a YAML
declares an unknown provider name or has a schema error, the app
**does** fail to start, with the offending file path in the message.
## Speech-to-Text Settings
DocsGPT can transcribe audio in two places:

View File

@@ -0,0 +1,111 @@
---
title: Observability
description: Send traces, metrics, and logs from DocsGPT to any OpenTelemetry-compatible backend (Axiom, Honeycomb, Grafana, Datadog, Jaeger, etc.).
---
import { Callout } from 'nextra/components'
# Observability
DocsGPT bundles the OpenTelemetry SDK and auto-instrumentation packages
in `application/requirements.txt` — they install with the rest of the
backend deps. Telemetry is **off by default**; opt in by prefixing the
launch command with `opentelemetry-instrument` and setting OTLP env
vars.
Auto-instrumentation covers Flask, Starlette, Celery, SQLAlchemy,
psycopg, Redis, requests, and Python logging. LLM/retriever calls are
not captured at this layer — see *Going further* below.
## Enabling
Set these env vars in your `.env` (or compose `environment:` block):
```bash
OTEL_SDK_DISABLED=false
OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf
OTEL_EXPORTER_OTLP_ENDPOINT=https://your-collector.example.com
OTEL_EXPORTER_OTLP_HEADERS=Authorization=Bearer%20<token>
OTEL_TRACES_EXPORTER=otlp
OTEL_METRICS_EXPORTER=otlp
OTEL_LOGS_EXPORTER=otlp
OTEL_PYTHON_LOG_CORRELATION=true
OTEL_RESOURCE_ATTRIBUTES=service.name=docsgpt-backend,deployment.environment=prod
```
Then prefix the process command with `opentelemetry-instrument`. The
simplest way is a compose override (no image rebuild):
```yaml
# deployment/docker-compose.override.yaml
services:
  backend:
    command: >
      opentelemetry-instrument gunicorn -w 1 -k uvicorn_worker.UvicornWorker
      --bind 0.0.0.0:7091 --config application/gunicorn_conf.py
      application.asgi:asgi_app
    environment:
      - OTEL_SERVICE_NAME=docsgpt-backend
  worker:
    command: opentelemetry-instrument celery -A application.app.celery worker -l INFO -B
    environment:
      - OTEL_SERVICE_NAME=docsgpt-celery-worker
```
For local dev, prepend `dotenv run --` so the `OTEL_*` vars from `.env`
reach `opentelemetry-instrument` before it boots the SDK:
```bash
dotenv run -- opentelemetry-instrument flask --app application/app.py run --port=7091
dotenv run -- opentelemetry-instrument celery -A application.app.celery worker -l INFO --pool=solo
```
<Callout type="info" emoji="">
Logs are exported in-process when `OTEL_LOGS_EXPORTER=otlp` is set —
`application/core/logging_config.py` detects the flag and preserves
the OTEL log handler. Without it, `logging` writes only to stdout.
</Callout>
## Backend examples
### Axiom
```bash
OTEL_EXPORTER_OTLP_ENDPOINT=https://api.axiom.co
OTEL_EXPORTER_OTLP_HEADERS=Authorization=Bearer%20xaat-XXXX,X-Axiom-Dataset=docsgpt
OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf
```
`%20` is the URL-encoded space between `Bearer` and the token. Create
the dataset in the Axiom UI before sending.
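If the token lives in an env var, you can assemble the header value
like this (assuming `$AXIOM_TOKEN` holds the raw token):
```bash
# %20 encodes the space; commas separate multiple OTLP headers
export OTEL_EXPORTER_OTLP_HEADERS="Authorization=Bearer%20${AXIOM_TOKEN},X-Axiom-Dataset=docsgpt"
```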
### Self-hosted OTLP collector / Jaeger / Tempo
```bash
OTEL_EXPORTER_OTLP_ENDPOINT=http://otel-collector:4317
OTEL_EXPORTER_OTLP_PROTOCOL=grpc
```
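On the collector side, a minimal config that accepts OTLP over gRPC on
4317 and echoes telemetry to the collector's own logs might look like
this (a sketch for recent collector releases):
```yaml
receivers:
  otlp:
    protocols:
      grpc:
        endpoint: 0.0.0.0:4317
exporters:
  debug:
    verbosity: basic
service:
  pipelines:
    traces:
      receivers: [otlp]
      exporters: [debug]
```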
### Honeycomb / Grafana Cloud / Datadog
Each vendor publishes a single-line `OTEL_EXPORTER_OTLP_ENDPOINT` plus
`OTEL_EXPORTER_OTLP_HEADERS` recipe — drop them in alongside the
service-name override.
## Caveats
- The Dockerfile uses `gunicorn -w 1`. If you raise the worker count,
  move SDK init into a `post_worker_init` hook (see the sketch after
  this list) so each forked worker initializes its own exporter instead
  of sharing one set up before the fork.
- `asgi.py` wraps Flask in Starlette's `WSGIMiddleware`. Both
instrumentors are installed, so each request produces a Starlette
span enclosing a Flask span. Drop
`opentelemetry-instrumentation-flask` from `requirements.txt` if the
duplication is noisy.
- OTEL packages add ~50 MB to the image. They install on every build,
  but the runtime cost is zero unless you prefix the command with
  `opentelemetry-instrument` and set the OTLP env vars.
- The OTEL exporter ecosystem currently caps `protobuf` at `<7`, so the
backend runs on protobuf 6.x. This will catch up in a future OTEL
release.
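If you do raise the worker count, a per-worker init might look like
this (a sketch of manual SDK setup in `gunicorn_conf.py`, replacing the
`opentelemetry-instrument` wrapper; adjust to your exporter and
pipeline):
```python
# gunicorn_conf.py, sketch: initialize the OTEL SDK after each worker
# forks, so every process gets its own exporter thread rather than one
# inherited from the parent
def post_worker_init(worker):
    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import BatchSpanProcessor
    from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter

    # OTLPSpanExporter reads the OTEL_EXPORTER_OTLP_* env vars on its own
    provider = TracerProvider()
    provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter()))
    trace.set_tracer_provider(provider)
```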


@@ -23,6 +23,10 @@ export default {
"title": "🐘 PostgreSQL for User Data",
"href": "/Deploying/Postgres-Migration"
},
"Observability": {
"title": "🔭 Observability",
"href": "/Deploying/Observability"
},
"Amazon-Lightsail": {
"title": "Hosting DocsGPT on Amazon Lightsail",
"href": "/Deploying/Amazon-Lightsail",


@@ -1,38 +0,0 @@
---
title: Add DocsGPT Chrome Extension to Your Browser
description: Install the DocsGPT Chrome extension to access AI-powered document assistance directly from your browser for enhanced productivity.
---
import {Steps} from 'nextra/components'
import { Callout } from 'nextra/components'
## Chrome Extension Setup Guide
To enhance your DocsGPT experience, you can install the DocsGPT Chrome extension. Here's how:
<Steps>
### Step 1
In the DocsGPT GitHub repository, click on the **Code** button and select **Download ZIP**.
### Step 2
Unzip the downloaded file to a location you can easily access.
### Step 3
Open the Google Chrome browser and click on the three dots menu (upper right corner).
### Step 4
Select **More Tools** and then **Extensions**.
### Step 5
Turn on the **Developer mode** switch in the top right corner of the **Extensions page**.
### Step 6
Click on the **Load unpacked** button.
### Step 7
Select the **Chrome** folder where the DocsGPT files have been unzipped (docsgpt-main > extensions > chrome).
### Step 8
The extension should now be added to Google Chrome and can be managed on the Extensions page.
### Step 9
To disable or remove the extension, simply turn off the toggle switch on the extension card or click the **Remove** button.
</Steps>


@@ -11,10 +11,6 @@ export default {
"title": "🔎 Search Widget",
"href": "/Extensions/search-widget"
},
"Chrome-extension": {
"title": "🌐 Chrome Extension",
"href": "/Extensions/Chrome-extension"
},
"Chatwoot-extension": {
"title": "🗣️ Chatwoot Extension",
"href": "/Extensions/Chatwoot-extension"


@@ -1,66 +0,0 @@
{
"l10nTabName": {
"message":"Localization"
,"description":"name of the localization tab"
}
,"l10nHeader": {
"message":"It does localization too! (this whole tab is, actually)"
,"description":"Header text for the localization section"
}
,"l10nIntro": {
"message":"'L10n' refers to 'Localization' - 'L' an 'n' are obvious, and 10 comes from the number of letters between those two. It is the process/whatever of displaying something in the language of choice. It uses 'I18n', 'Internationalization', which refers to the tools / framework supporting L10n. I.e., something is internationalized if it has I18n support, and can be localized. Something is localized for you if it is in your language / dialect."
,"description":"introduce the basic idea."
}
,"l10nProd": {
"message":"You <strong>are</strong> planning to allow localization, right? You have <em>no idea</em> who will be using your extension! You have no idea who will be translating it! At least support the basics, it's not hard, and having the framework in place will let you transition much more easily later on."
,"description":"drive the point home. It's good for you."
}
,"l10nFirstParagraph": {
"message":"When the options page loads, elements decorated with <strong>data-l10n</strong> will automatically be localized!"
,"description":"inform that <el data-l10n='' /> elements will be localized on load"
}
,"l10nSecondParagraph": {
"message":"If you need more complex localization, you can also define <strong>data-l10n-args</strong>. This should contain <span class='code'>$containerType$</span> filled with <span class='code'>$dataType$</span>, which will be passed into Chrome's i18n API as <span class='code'>$functionArgs$</span>. In fact, this paragraph does just that, and wraps the args in mono-space font. Easy!"
,"description":"introduce the data-l10n-args attribute. End on a lame note."
,"placeholders": {
"containerType": {
"content":"$1"
,"example":"'array', 'list', or something similar"
,"description":"type of the args container"
}
,"dataType": {
"content":"$2"
,"example":"string"
,"description":"type of data in each array index"
}
,"functionArgs": {
"content":"$3"
,"example":"arguments"
,"description":"whatever you call what you pass into a function/method. args, params, etc."
}
}
}
,"l10nThirdParagraph": {
"message":"Message contents are passed right into innerHTML without processing - include any tags (or even scripts) that you feel like. If you have an input field, the placeholder will be set instead, and buttons will have the value attribute set."
,"description":"inform that we handle placeholders, buttons, and direct HTML input"
}
,"l10nButtonsBefore": {
"message":"Different types of buttons are handled as well. &lt;button&gt; elements have their html set:"
}
,"l10nButton": {
"message":"in a <strong>button</strong>"
}
,"l10nButtonsBetween": {
"message":"while &lt;input type='submit'&gt; and &lt;input type='button'&gt; get their 'value' set (note: no HTML):"
}
,"l10nSubmit": {
"message":"a <strong>submit</strong> value"
}
,"l10nButtonsAfter": {
"message":"Awesome, no?"
}
,"l10nExtras": {
"message":"You can even set <span class='code'>data-l10n</span> on things like the &lt;title&gt; tag, which lets you have translatable page titles, or fieldset &lt;legend&gt; tags, or anywhere else - the default <span class='code'>Boil.localize()</span> behavior will check every tag in the document, not just the body."
,"description":"inform about places which may not be obvious, like <title>, etc"
}
}

Some files were not shown because too many files have changed in this diff.