Mirror of https://github.com/arc53/DocsGPT.git — synced 2026-05-10 20:41:57 +00:00

Compare commits: 31 commits (dependabot … 0.17.1)

0d2a8e11f4, f0c39dec23, 552bfe016a, a6a5db631b, 8e9f661efc, 82c71be819,
318de18d43, af618de13d, ef976eeb06, 9c8ae9d540, 7ca33b2b72, d1b9798f62,
ddc3adf3ab, a4991d01ac, 87fd1bd359, c71e986d34, a2a06c569e, c5f00a1d1b,
2a15bb0102, c06888bc86, d4b1c1fd81, 2de84acf81, 2702750861, 2b5f20d0ec,
619b41dc5b, 76d8f49ccb, 65460b0c03, 9fe96fb50f, 08822c3379, 68ca8ff9ea,
81be3cdccc
@@ -35,8 +35,5 @@ MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
 #Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
 MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
 
-# User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration).
-# Standard Postgres URI — `postgres://` and `postgresql://` both work.
-# Leave unset while the migration is still being rolled out; the app will
-# fall back to MongoDB for user data until POSTGRES_URI is configured.
 # POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
AGENTS.md (18 changed lines)

@@ -37,6 +37,22 @@ Run the Flask API (if needed):
 flask --app application/app.py run --host=0.0.0.0 --port=7091
 ```
 
+That's the fast inner-loop option — quick startup, the Werkzeug interactive
+debugger still works, and it hot-reloads on source changes. It serves the
+Flask routes only (`/api/*`, `/stream`, etc.).
+
+If you need to exercise the full ASGI stack — the `/mcp` FastMCP endpoint,
+or to match the production runtime exactly — run the ASGI composition under
+uvicorn instead:
+
+```bash
+uvicorn application.asgi:asgi_app --host 0.0.0.0 --port 7091 --reload
+```
+
+Production uses `gunicorn -k uvicorn_worker.UvicornWorker` against the same
+`application.asgi:asgi_app` target; see `application/Dockerfile` for the
+full flag set.
+
 Run the Celery worker in a separate terminal (if needed):
 
 ```bash

@@ -99,7 +115,7 @@ vale .
 - `frontend/`: Vite + React + TypeScript application.
 - `frontend/src/`: main UI code, including `components`, `conversation`, `hooks`, `locale`, `settings`, `upload`, and Redux store wiring in `store.ts`.
 - `docs/`: separate documentation site built with Next.js/Nextra.
-- `extensions/`: integrations and widgets such as Chatwoot, Chrome, Discord, React widget, Slack bot, and web widget.
+- `extensions/`: integrations and widgets — currently the Chatwoot webhook bridge and the React widget (published to npm as `docsgpt`). The Discord bot, Slack bot, and Chrome extension have been moved to their own repos under `arc53/`.
 - `deployment/`: Docker Compose variants and Kubernetes manifests.
 
 ## Coding rules
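For orientation, here is a minimal sketch of the kind of ASGI composition the new AGENTS.md text describes. Only the `application.asgi:asgi_app` target name comes from the diff; the module contents, the FastMCP import path, and the mount layout below are assumptions, not the repository's actual code.

```python
# Hypothetical application/asgi.py, a sketch under stated assumptions.
from asgiref.wsgi import WsgiToAsgi          # real asgiref WSGI-to-ASGI bridge
from starlette.applications import Starlette
from starlette.routing import Mount

from application.app import app as flask_app  # assumed Flask app import path
from application.mcp import mcp_http_app      # assumed FastMCP ASGI app name

asgi_app = Starlette(
    routes=[
        Mount("/mcp", app=mcp_http_app),        # ASGI-only FastMCP endpoint
        Mount("/", app=WsgiToAsgi(flask_app)),  # Flask routes: /api/*, /stream
    ]
)
```

Under a layout like this, the plain `flask run` inner loop never loads the `/mcp` mount, which is why the diff tells contributors to switch to uvicorn when they need it.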
README.md (12 changed lines)

@@ -47,11 +47,13 @@
 </ul>
 
 ## Roadmap
-- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
-- [x] Deep Agents ( October 2025 )
-- [x] Prompt Templating ( October 2025 )
-- [x] Full api tooling ( Dec 2025 )
-- [ ] Agent scheduling ( Jan 2026 )
+- [x] Agent Workflow Builder with conditional nodes ( February 2026 )
+- [x] SharePoint & Confluence connectors ( March – April 2026 )
+- [x] Research mode ( March 2026 )
+- [x] Postgres migration for user data ( April 2026 )
+- [x] OpenTelemetry observability ( April 2026 )
+- [x] Bring Your Own Model (BYOM) ( April 2026 )
+- [ ] Agent scheduling (RedBeat-backed) ( Q2 2026 )
 
 You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
@@ -88,5 +88,15 @@ EXPOSE 7091
 # Switch to non-root user
 USER appuser
 
-# Start Gunicorn
-CMD ["gunicorn", "-w", "1", "--timeout", "120", "--bind", "0.0.0.0:7091", "--preload", "application.wsgi:app"]
+CMD ["gunicorn", \
+     "-w", "1", \
+     "-k", "uvicorn_worker.UvicornWorker", \
+     "--bind", "0.0.0.0:7091", \
+     "--timeout", "180", \
+     "--graceful-timeout", "120", \
+     "--keep-alive", "5", \
+     "--worker-tmp-dir", "/dev/shm", \
+     "--max-requests", "1000", \
+     "--max-requests-jitter", "100", \
+     "--config", "application/gunicorn_conf.py", \
+     "application.asgi:asgi_app"]
@@ -42,6 +42,7 @@ class BaseAgent(ABC):
         llm_handler=None,
         tool_executor: Optional[ToolExecutor] = None,
         backup_models: Optional[List[str]] = None,
+        model_user_id: Optional[str] = None,
     ):
         self.endpoint = endpoint
         self.llm_name = llm_name

@@ -52,10 +53,13 @@ class BaseAgent(ABC):
         self.prompt = prompt
         self.decoded_token = decoded_token or {}
         self.user: str = self.decoded_token.get("sub")
+        # BYOM-resolution scope: owner for shared agents, caller for
+        # caller-owned BYOM, None for built-ins. Falls back to self.user
+        # for worker/legacy callers that don't thread model_user_id.
+        self.model_user_id = model_user_id
         self.tools: List[Dict] = []
         self.chat_history: List[Dict] = chat_history if chat_history is not None else []
 
-        # Dependency injection for LLM — fall back to creating if not provided
         if llm is not None:
             self.llm = llm
         else:

@@ -67,8 +71,16 @@ class BaseAgent(ABC):
                 model_id=model_id,
                 agent_id=agent_id,
                 backup_models=backup_models,
+                model_user_id=model_user_id,
             )
 
+        # For BYOM, registry id (UUID) differs from upstream model id
+        # (e.g. ``mistral-large-latest``). LLMCreator resolved this onto
+        # the LLM instance; cache it for subsequent gen calls.
+        self.upstream_model_id = (
+            getattr(self.llm, "model_id", None) or model_id
+        )
+
         self.retrieved_docs = retrieved_docs or []
 
         if llm_handler is not None:

@@ -306,7 +318,9 @@ class BaseAgent(ABC):
         try:
             current_tokens = self._calculate_current_context_tokens(messages)
             self.current_token_count = current_tokens
-            context_limit = get_token_limit(self.model_id)
+            context_limit = get_token_limit(
+                self.model_id, user_id=self.model_user_id or self.user
+            )
             threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
 
             if current_tokens >= threshold:

@@ -325,7 +339,9 @@ class BaseAgent(ABC):
 
         current_tokens = self._calculate_current_context_tokens(messages)
         self.current_token_count = current_tokens
-        context_limit = get_token_limit(self.model_id)
+        context_limit = get_token_limit(
+            self.model_id, user_id=self.model_user_id or self.user
+        )
         percentage = (current_tokens / context_limit) * 100
 
         if current_tokens >= context_limit:

@@ -387,7 +403,9 @@ class BaseAgent(ABC):
             )
             system_prompt = system_prompt + compression_context
 
-        context_limit = get_token_limit(self.model_id)
+        context_limit = get_token_limit(
+            self.model_id, user_id=self.model_user_id or self.user
+        )
         system_tokens = num_tokens_from_string(system_prompt)
 
         safety_buffer = int(context_limit * 0.1)

@@ -497,7 +515,10 @@ class BaseAgent(ABC):
     def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
         self._validate_context_size(messages)
 
-        gen_kwargs = {"model": self.model_id, "messages": messages}
+        # Use the upstream id resolved by LLMCreator (see __init__).
+        # Built-in models: same as self.model_id. BYOM: the user's
+        # typed model name, not the internal UUID.
+        gen_kwargs = {"model": self.upstream_model_id, "messages": messages}
         if self.attachments:
             gen_kwargs["_usage_attachments"] = self.attachments
 
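The repeated `get_token_limit(self.model_id, user_id=...)` change implies a two-step lookup: per-user BYOM registrations first, built-ins second. The sketch below restates that implied behavior; the real helper is not part of this compare, and `lookup_user_custom_model`, `BUILTIN_LIMITS`, and the 128k default are illustrative names only.

```python
from typing import Optional

DEFAULT_TOKEN_LIMIT = 128_000                 # assumed deployment default
BUILTIN_LIMITS = {"docsgpt-local": 128_000}   # illustrative built-in table


def lookup_user_custom_model(user_id: str, model_id: str) -> Optional[dict]:
    """Hypothetical stand-in for a user_custom_models query (see the
    0003 migration later in this compare)."""
    return None  # the real version would read Postgres


def get_token_limit(model_id: str, user_id: Optional[str] = None) -> int:
    """Sketch: resolve a context window, trying the scoped user's BYOM
    registrations (keyed by registry UUID) before the built-in table."""
    if user_id is not None:
        row = lookup_user_custom_model(user_id, model_id)
        if row is not None:
            return row["capabilities"].get("context_window", DEFAULT_TOKEN_LIMIT)
    return BUILTIN_LIMITS.get(model_id, DEFAULT_TOKEN_LIMIT)
```

Without the `user_id` argument, a BYOM registry UUID misses every table and silently falls back to the default limit, which is the bug class these hunks close.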
@@ -312,7 +312,7 @@ class ResearchAgent(BaseAgent):
 
         try:
             response = self.llm.gen(
-                model=self.model_id,
+                model=self.upstream_model_id,
                 messages=messages,
                 tools=None,
                 response_format={"type": "json_object"},

@@ -390,7 +390,7 @@ class ResearchAgent(BaseAgent):
 
         try:
             response = self.llm.gen(
-                model=self.model_id,
+                model=self.upstream_model_id,
                 messages=messages,
                 tools=None,
                 response_format={"type": "json_object"},

@@ -506,7 +506,7 @@ class ResearchAgent(BaseAgent):
 
         try:
             response = self.llm.gen(
-                model=self.model_id,
+                model=self.upstream_model_id,
                 messages=messages,
                 tools=self.tools if self.tools else None,
             )

@@ -537,7 +537,7 @@ class ResearchAgent(BaseAgent):
             )
             try:
                 response = self.llm.gen(
-                    model=self.model_id, messages=messages, tools=None
+                    model=self.upstream_model_id, messages=messages, tools=None
                 )
                 self._track_tokens(self._snapshot_llm_tokens())
                 text = self._extract_text(response)

@@ -664,7 +664,7 @@ class ResearchAgent(BaseAgent):
         ]
 
         llm_response = self.llm.gen_stream(
-            model=self.model_id, messages=messages, tools=None
+            model=self.upstream_model_id, messages=messages, tools=None
        )
 
         if log_context:
@@ -274,7 +274,14 @@ class ToolExecutor:
 
         if tool_id is None or action_name is None:
             error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
-            logger.error(error_message)
+            logger.error(
+                "tool_call_parse_failed",
+                extra={
+                    "llm_class_name": llm_class_name,
+                    "llm_tool_name": llm_name,
+                    "call_id": call_id,
+                },
+            )
 
             tool_call_data = {
                 "tool_name": "unknown",

@@ -289,7 +296,15 @@ class ToolExecutor:
 
         if tool_id not in tools_dict:
             error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
-            logger.error(error_message)
+            logger.error(
+                "tool_id_not_found",
+                extra={
+                    "tool_id": tool_id,
+                    "llm_tool_name": llm_name,
+                    "call_id": call_id,
+                    "available_tool_count": len(tools_dict),
+                },
+            )
 
             tool_call_data = {
                 "tool_name": "unknown",

@@ -356,7 +371,15 @@ class ToolExecutor:
                 f"Failed to load tool '{tool_data.get('name')}' (tool_id key={tool_id}): "
                 "missing 'id' on tool row."
             )
-            logger.error(error_message)
+            logger.error(
+                "tool_load_failed",
+                extra={
+                    "tool_name": tool_data.get("name"),
+                    "tool_id": tool_id,
+                    "action_name": action_name,
+                    "call_id": call_id,
+                },
+            )
             tool_call_data["result"] = error_message
             yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
             self.tool_calls.append(tool_call_data)

@@ -451,10 +474,12 @@ class ToolExecutor:
         row_id = tool_data.get("id")
         if not row_id:
             logger.error(
-                "Tool data missing 'id' for tool name=%s (enumerate-key tool_id=%s); "
-                "skipping load to avoid binding a non-UUID downstream.",
-                tool_data.get("name"),
-                tool_id,
+                "tool_missing_row_id",
+                extra={
+                    "tool_name": tool_data.get("name"),
+                    "tool_id": tool_id,
+                    "action_name": action_name,
+                },
             )
             return None
         tool_config["tool_id"] = str(row_id)
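The `logger.error` rewrites above move from interpolated messages to stable event names plus an `extra` dict. That is plain stdlib behavior: keys in `extra` become attributes on the `LogRecord`, so any structured formatter can emit them as fields. A minimal self-contained illustration (the formatter here is ours for demonstration, not the repo's logging config):

```python
import logging


class KVFormatter(logging.Formatter):
    """Tiny illustrative formatter that appends selected extra fields."""

    _KEYS = {"tool_id", "llm_tool_name", "call_id"}

    def format(self, record: logging.LogRecord) -> str:
        base = super().format(record)
        extras = {k: v for k, v in record.__dict__.items() if k in self._KEYS}
        return f"{base} {extras}" if extras else base


handler = logging.StreamHandler()
handler.setFormatter(KVFormatter("%(levelname)s %(message)s"))
logger = logging.getLogger("tool_executor_demo")
logger.addHandler(handler)

logger.error(
    "tool_id_not_found",
    extra={"tool_id": "42", "llm_tool_name": "web_search", "call_id": "abc"},
)
# ERROR tool_id_not_found {'tool_id': '42', 'llm_tool_name': 'web_search', 'call_id': 'abc'}
```

Keeping the message a constant event name also makes log-based alerting and grouping reliable, since the variable parts no longer change the message string.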
@@ -39,6 +39,7 @@ class InternalSearchTool(Tool):
             chunks=int(self.config.get("chunks", 2)),
             doc_token_limit=int(self.config.get("doc_token_limit", 50000)),
             model_id=self.config.get("model_id", "docsgpt-local"),
+            model_user_id=self.config.get("model_user_id"),
             user_api_key=self.config.get("user_api_key"),
             agent_id=self.config.get("agent_id"),
             llm_name=self.config.get("llm_name", settings.LLM_PROVIDER),

@@ -435,6 +436,7 @@ def build_internal_tool_config(
     chunks: int = 2,
     doc_token_limit: int = 50000,
     model_id: str = "docsgpt-local",
+    model_user_id: Optional[str] = None,
     user_api_key: Optional[str] = None,
     agent_id: Optional[str] = None,
     llm_name: str = None,

@@ -449,6 +451,7 @@ def build_internal_tool_config(
         "chunks": chunks,
         "doc_token_limit": doc_token_limit,
         "model_id": model_id,
+        "model_user_id": model_user_id,
         "user_api_key": user_api_key,
         "agent_id": agent_id,
         "llm_name": llm_name or settings.LLM_PROVIDER,
@@ -211,15 +211,26 @@ class WorkflowEngine:
                 node_config.json_schema, node.title
             )
             node_model_id = node_config.model_id or self.agent.model_id
+            # Inherit BYOM scope from parent agent so owner-stored BYOM
+            # resolves on shared workflows.
+            node_user_id = getattr(self.agent, "model_user_id", None) or (
+                self.agent.decoded_token.get("sub")
+                if isinstance(self.agent.decoded_token, dict)
+                else None
+            )
             node_llm_name = (
                 node_config.llm_name
-                or get_provider_from_model_id(node_model_id or "")
+                or get_provider_from_model_id(
+                    node_model_id or "", user_id=node_user_id
+                )
                 or self.agent.llm_name
             )
             node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
 
             if node_json_schema and node_model_id:
-                model_capabilities = get_model_capabilities(node_model_id)
+                model_capabilities = get_model_capabilities(
+                    node_model_id, user_id=node_user_id
+                )
                 if model_capabilities and not model_capabilities.get(
                     "supports_structured_output", False
                 ):

@@ -232,6 +243,7 @@ class WorkflowEngine:
                 "endpoint": self.agent.endpoint,
                 "llm_name": node_llm_name,
                 "model_id": node_model_id,
+                "model_user_id": getattr(self.agent, "model_user_id", None),
                 "api_key": node_api_key,
                 "tool_ids": node_config.tools,
                 "prompt": node_config.system_prompt,
application/alembic/versions/0002_app_metadata.py (new file, 37 lines)

@@ -0,0 +1,37 @@
+"""0002 app_metadata — singleton key/value table for instance-wide state.
+
+Used by the startup version-check client to persist the anonymous
+instance UUID and a one-shot "notice shown" flag. Both values are tiny
+plain-text strings; this is a deliberate generic-config table rather
+than dedicated columns so future one-off settings (telemetry opt-in
+timestamps, feature-flag overrides, etc.) don't each need their own
+migration.
+
+Revision ID: 0002_app_metadata
+Revises: 0001_initial
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+
+
+revision: str = "0002_app_metadata"
+down_revision: Union[str, None] = "0001_initial"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    op.execute(
+        """
+        CREATE TABLE app_metadata (
+            key TEXT PRIMARY KEY,
+            value TEXT NOT NULL
+        );
+        """
+    )
+
+
+def downgrade() -> None:
+    op.execute("DROP TABLE IF EXISTS app_metadata;")
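A key/value singleton table like this is typically written with an upsert, so the instance UUID is created exactly once and the notice flag can be flipped idempotently. The snippet below is an assumed usage pattern, not code from this compare; `conn` is any DB-API connection to the same Postgres database, and the key name is invented for illustration.

```python
import uuid

# Assumed usage sketch (not from this compare).
UPSERT = """
    INSERT INTO app_metadata (key, value)
    VALUES (%s, %s)
    ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value;
"""


def ensure_instance_uuid(conn) -> str:
    """Return the stored instance UUID, creating it on first startup."""
    with conn.cursor() as cur:
        cur.execute(
            "SELECT value FROM app_metadata WHERE key = %s;",
            ("instance_uuid",),  # hypothetical key name
        )
        row = cur.fetchone()
        if row:
            return row[0]
        new_id = str(uuid.uuid4())
        cur.execute(UPSERT, ("instance_uuid", new_id))
        return new_id
```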
application/alembic/versions/0003_user_custom_models.py (new file, 65 lines)

@@ -0,0 +1,65 @@
+"""0003 user_custom_models — per-user OpenAI-compatible model registrations.
+
+Revision ID: 0003_user_custom_models
+Revises: 0002_app_metadata
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+
+
+revision: str = "0003_user_custom_models"
+down_revision: Union[str, None] = "0002_app_metadata"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    op.execute(
+        """
+        CREATE TABLE user_custom_models (
+            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+            user_id TEXT NOT NULL,
+            upstream_model_id TEXT NOT NULL,
+            display_name TEXT NOT NULL,
+            description TEXT NOT NULL DEFAULT '',
+            base_url TEXT NOT NULL,
+            api_key_encrypted TEXT NOT NULL,
+            capabilities JSONB NOT NULL DEFAULT '{}'::jsonb,
+            enabled BOOLEAN NOT NULL DEFAULT true,
+            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
+            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
+        );
+        """
+    )
+    op.execute(
+        "CREATE INDEX user_custom_models_user_id_idx "
+        "ON user_custom_models (user_id);"
+    )
+
+    # Mirror the project-wide invariants set up in 0001_initial:
+    # * user_id FK with ON DELETE RESTRICT (deferrable),
+    # * ensure_user_exists() trigger so the parent users row autocreates,
+    # * set_updated_at() trigger.
+    op.execute(
+        "ALTER TABLE user_custom_models "
+        "ADD CONSTRAINT user_custom_models_user_id_fk "
+        "FOREIGN KEY (user_id) REFERENCES users(user_id) "
+        "ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
+    )
+    op.execute(
+        "CREATE TRIGGER user_custom_models_ensure_user "
+        "BEFORE INSERT OR UPDATE OF user_id ON user_custom_models "
+        "FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
+    )
+    op.execute(
+        "CREATE TRIGGER user_custom_models_set_updated_at "
+        "BEFORE UPDATE ON user_custom_models "
+        "FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
+        "EXECUTE FUNCTION set_updated_at();"
+    )
+
+
+def downgrade() -> None:
+    op.execute("DROP TABLE IF EXISTS user_custom_models;")
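Two practical notes on this schema. First, `gen_random_uuid()` is built into PostgreSQL 13 and later; on older servers it requires the pgcrypto extension. Second, the split between the row's `id` and its `upstream_model_id` is exactly the registry-UUID versus upstream-name distinction the agent hunks above keep resolving. An illustrative row (every value here is invented) makes the split concrete:

```python
# Invented example row, to make the two-id split concrete:
byom_row = {
    "id": "7c9e6679-7425-40de-944b-e07fc1f90ae7",  # registry id: what the UI/API pass around
    "user_id": "auth0|owner",
    "upstream_model_id": "mistral-large-latest",   # what llm.gen(model=...) sends upstream
    "display_name": "Team Mistral",
    "base_url": "https://my-gateway.example.com/v1",
    "capabilities": {"context_window": 32_000, "supports_structured_output": True},
    "enabled": True,
}
```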
@@ -177,6 +177,7 @@ class BaseAnswerResource:
         is_shared_usage: bool = False,
         shared_token: Optional[str] = None,
         model_id: Optional[str] = None,
+        model_user_id: Optional[str] = None,
         _continuation: Optional[Dict] = None,
     ) -> Generator[str, None, None]:
         """

@@ -289,8 +290,18 @@ class BaseAnswerResource:
             # conversation if this is the first turn.
             if not conversation_id and should_save_conversation:
                 try:
+                    # Use model-owner scope so shared-agent
+                    # owner-BYOM resolves to its registered plugin.
                     provider = (
-                        get_provider_from_model_id(model_id)
+                        get_provider_from_model_id(
+                            model_id,
+                            user_id=model_user_id
+                            or (
+                                decoded_token.get("sub")
+                                if decoded_token
+                                else None
+                            ),
+                        )
                         if model_id
                         else settings.LLM_PROVIDER
                     )

@@ -304,6 +315,7 @@ class BaseAnswerResource:
                         decoded_token=decoded_token,
                         model_id=model_id,
                         agent_id=agent_id,
+                        model_user_id=model_user_id,
                     )
                     conversation_id = (
                         self.conversation_service.save_conversation(

@@ -340,6 +352,9 @@ class BaseAnswerResource:
                     tool_schemas=getattr(agent, "tools", []),
                     agent_config={
                         "model_id": model_id or self.default_model_id,
+                        # Persist BYOM scope so resume doesn't
+                        # fall back to caller's layer.
+                        "model_user_id": model_user_id,
                         "llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
                         "api_key": getattr(agent, "api_key", None),
                         "user_api_key": user_api_key,

@@ -370,8 +385,14 @@ class BaseAnswerResource:
             if isNoneDoc:
                 for doc in source_log_docs:
                     doc["source"] = "None"
+            # Run under model-owner scope so title-gen LLM inside
+            # save_conversation uses the owner's BYOM provider/key.
             provider = (
-                get_provider_from_model_id(model_id)
+                get_provider_from_model_id(
+                    model_id,
+                    user_id=model_user_id
+                    or (decoded_token.get("sub") if decoded_token else None),
+                )
                 if model_id
                 else settings.LLM_PROVIDER
             )

@@ -384,6 +405,7 @@ class BaseAnswerResource:
                 decoded_token=decoded_token,
                 model_id=model_id,
                 agent_id=agent_id,
+                model_user_id=model_user_id,
             )
 
             if should_save_conversation:

@@ -481,12 +503,34 @@ class BaseAnswerResource:
                 if isNoneDoc:
                     for doc in source_log_docs:
                         doc["source"] = "None"
+                # Mirror the normal-path provider resolution so the
+                # partial-save title LLM uses the model-owner's BYOM
+                # registration (shared-agent dispatch) rather than
+                # the deployment default with the instance api key.
+                provider = (
+                    get_provider_from_model_id(
+                        model_id,
+                        user_id=model_user_id
+                        or (
+                            decoded_token.get("sub")
+                            if decoded_token
+                            else None
+                        ),
+                    )
+                    if model_id
+                    else settings.LLM_PROVIDER
+                )
+                sys_api_key = get_api_key_for_provider(
+                    provider or settings.LLM_PROVIDER
+                )
                 llm = LLMCreator.create_llm(
-                    settings.LLM_PROVIDER,
-                    api_key=settings.API_KEY,
+                    provider or settings.LLM_PROVIDER,
+                    api_key=sys_api_key,
                     user_api_key=user_api_key,
                     decoded_token=decoded_token,
+                    model_id=model_id,
                     agent_id=agent_id,
+                    model_user_id=model_user_id,
                 )
                 self.conversation_service.save_conversation(
                     conversation_id,
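The expression `model_user_id or (decoded_token.get("sub") if decoded_token else None)` now appears three times in this file. A follow-up could factor it into a helper; the sketch below is a suggestion only, not part of the diff:

```python
from typing import Optional


def byom_scope(model_user_id: Optional[str],
               decoded_token: Optional[dict]) -> Optional[str]:
    """Owner scope when the dispatch threaded one, else the caller's sub."""
    if model_user_id:
        return model_user_id
    return decoded_token.get("sub") if decoded_token else None
```

With that helper, each call site reduces to `get_provider_from_model_id(model_id, user_id=byom_scope(model_user_id, decoded_token))`, and future call sites cannot drift from the precedence rule.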
@@ -1,21 +1,21 @@
 import logging
-from typing import Any, Dict, List
 
 from flask import make_response, request
 from flask_restx import fields, Resource
 
 from application.api.answer.routes.base import answer_ns
-from application.core.settings import settings
-from application.storage.db.repositories.agents import AgentsRepository
-from application.storage.db.session import db_readonly
-from application.vectorstore.vector_creator import VectorCreator
+from application.services.search_service import (
+    InvalidAPIKey,
+    SearchFailed,
+    search,
+)
 
 logger = logging.getLogger(__name__)
 
 
 @answer_ns.route("/api/search")
 class SearchResource(Resource):
-    """Fast search endpoint for retrieving relevant documents"""
+    """Fast search endpoint for retrieving relevant documents."""
 
     search_model = answer_ns.model(
         "SearchModel",

@@ -32,102 +32,10 @@ class SearchResource(Resource):
         },
     )
 
-    def _get_sources_from_api_key(self, api_key: str) -> List[str]:
-        """Get source IDs connected to the API key/agent."""
-        with db_readonly() as conn:
-            agent_data = AgentsRepository(conn).find_by_key(api_key)
-        if not agent_data:
-            return []
-
-        source_ids: List[str] = []
-        # extra_source_ids is a PG ARRAY(UUID) of source UUIDs.
-        extra = agent_data.get("extra_source_ids") or []
-        for src in extra:
-            if src:
-                source_ids.append(str(src))
-
-        if not source_ids:
-            single = agent_data.get("source_id")
-            if single:
-                source_ids.append(str(single))
-
-        return source_ids
-
-    def _search_vectorstores(
-        self, query: str, source_ids: List[str], chunks: int
-    ) -> List[Dict[str, Any]]:
-        """Search across vectorstores and return results"""
-        if not source_ids:
-            return []
-
-        results = []
-        chunks_per_source = max(1, chunks // len(source_ids))
-        seen_texts = set()
-
-        for source_id in source_ids:
-            if not source_id or not source_id.strip():
-                continue
-
-            try:
-                docsearch = VectorCreator.create_vectorstore(
-                    settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
-                )
-                docs = docsearch.search(query, k=chunks_per_source * 2)
-
-                for doc in docs:
-                    if len(results) >= chunks:
-                        break
-
-                    if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
-                        page_content = doc.page_content
-                        metadata = doc.metadata
-                    else:
-                        page_content = doc.get("text", doc.get("page_content", ""))
-                        metadata = doc.get("metadata", {})
-
-                    # Skip duplicates
-                    text_hash = hash(page_content[:200])
-                    if text_hash in seen_texts:
-                        continue
-                    seen_texts.add(text_hash)
-
-                    title = metadata.get(
-                        "title", metadata.get("post_title", "")
-                    )
-                    if not isinstance(title, str):
-                        title = str(title) if title else ""
-
-                    # Clean up title
-                    if title:
-                        title = title.split("/")[-1]
-                    else:
-                        # Use filename or first part of content as title
-                        title = metadata.get("filename", page_content[:50] + "...")
-
-                    source = metadata.get("source", source_id)
-
-                    results.append({
-                        "text": page_content,
-                        "title": title,
-                        "source": source,
-                    })
-
-                    if len(results) >= chunks:
-                        break
-
-            except Exception as e:
-                logger.error(
-                    f"Error searching vectorstore {source_id}: {e}",
-                    exc_info=True,
-                )
-                continue
-
-        return results[:chunks]
-
     @answer_ns.expect(search_model)
     @answer_ns.doc(description="Search for relevant documents based on query")
     def post(self):
-        data = request.get_json()
+        data = request.get_json() or {}
 
         question = data.get("question")
         api_key = data.get("api_key")

@@ -135,32 +43,13 @@ class SearchResource(Resource):
 
         if not question:
             return make_response({"error": "question is required"}, 400)
 
         if not api_key:
             return make_response({"error": "api_key is required"}, 400)
 
-        # Validate API key
-        with db_readonly() as conn:
-            agent = AgentsRepository(conn).find_by_key(api_key)
-        if not agent:
-            return make_response({"error": "Invalid API key"}, 401)
-
         try:
-            # Get sources connected to this API key
-            source_ids = self._get_sources_from_api_key(api_key)
-
-            if not source_ids:
-                return make_response([], 200)
-
-            # Perform search
-            results = self._search_vectorstores(question, source_ids, chunks)
-
-            return make_response(results, 200)
-
-        except Exception as e:
-            logger.error(
-                f"/api/search - error: {str(e)}",
-                extra={"error": str(e)},
-                exc_info=True,
-            )
+            return make_response(search(api_key, question, chunks), 200)
+        except InvalidAPIKey:
+            return make_response({"error": "Invalid API key"}, 401)
+        except SearchFailed:
+            logger.exception("/api/search failed")
             return make_response({"error": "Search failed"}, 500)
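The rewritten route delegates everything to `application.services.search_service`, whose body is not in this compare. Inferred from the call sites above, its contract would look roughly like this; a sketch of the interface, not the real implementation:

```python
# Assumed interface of application/services/search_service.py, inferred
# from the route above.
from typing import Dict, List


class InvalidAPIKey(Exception):
    """Raised when the api_key does not match any agent (route returns 401)."""


class SearchFailed(Exception):
    """Raised when retrieval errors out (route logs and returns 500)."""


def search(api_key: str, question: str, chunks: int) -> List[Dict[str, str]]:
    """Return up to `chunks` results shaped like the old route's payload:
    [{"text": ..., "title": ..., "source": ...}, ...]."""
    raise NotImplementedError  # real logic lives in the service module
```

Moving the key lookup, source resolution, dedup, and vector-store fan-out behind this boundary keeps the Flask resource to request parsing and status-code mapping.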
@@ -109,6 +109,7 @@ class StreamResource(Resource, BaseAnswerResource):
                 decoded_token=processor.decoded_token,
                 agent_id=processor.agent_id,
                 model_id=processor.model_id,
+                model_user_id=processor.model_user_id,
                 _continuation={
                     "messages": messages,
                     "tools_dict": tools_dict,

@@ -145,6 +146,7 @@ class StreamResource(Resource, BaseAnswerResource):
                 is_shared_usage=processor.is_shared_usage,
                 shared_token=processor.shared_token,
                 model_id=processor.model_id,
+                model_user_id=processor.model_user_id,
             ),
             mimetype="text/event-stream",
         )
@@ -49,6 +49,7 @@ class CompressionOrchestrator:
         model_id: str,
         decoded_token: Dict[str, Any],
         current_query_tokens: int = 500,
+        model_user_id: Optional[str] = None,
     ) -> CompressionResult:
         """
         Check if compression is needed and perform it if so.

@@ -57,16 +58,18 @@ class CompressionOrchestrator:
 
         Args:
             conversation_id: Conversation ID
-            user_id: User ID
+            user_id: Caller's user id — used for conversation access checks
             model_id: Model being used for conversation
             decoded_token: User's decoded JWT token
             current_query_tokens: Estimated tokens for current query
+            model_user_id: BYOM-resolution scope (model owner); defaults
+                to ``user_id`` for built-in / caller-owned models.
 
         Returns:
             CompressionResult with summary and recent queries
         """
         try:
-            # Load conversation
+            # Conversation row is owned by the caller, not the model owner.
             conversation = self.conversation_service.get_conversation(
                 conversation_id, user_id
             )

@@ -77,9 +80,14 @@ class CompressionOrchestrator:
                 )
                 return CompressionResult.failure("Conversation not found")
 
-            # Check if compression is needed
+            # Use model-owner scope so per-user BYOM context windows
+            # (e.g. 8k) compute the threshold against the right limit.
+            registry_user_id = model_user_id or user_id
             if not self.threshold_checker.should_compress(
-                conversation, model_id, current_query_tokens
+                conversation,
+                model_id,
+                current_query_tokens,
+                user_id=registry_user_id,
             ):
                 # No compression needed, return full history
                 queries = conversation.get("queries", [])

@@ -87,7 +95,12 @@ class CompressionOrchestrator:
 
             # Perform compression
             return self._perform_compression(
-                conversation_id, conversation, model_id, decoded_token
+                conversation_id,
+                conversation,
+                model_id,
+                decoded_token,
+                user_id=user_id,
+                model_user_id=model_user_id,
             )
 
         except Exception as e:

@@ -102,6 +115,8 @@ class CompressionOrchestrator:
         conversation: Dict[str, Any],
         model_id: str,
         decoded_token: Dict[str, Any],
+        user_id: Optional[str] = None,
+        model_user_id: Optional[str] = None,
     ) -> CompressionResult:
         """
         Perform the actual compression operation.

@@ -111,6 +126,8 @@ class CompressionOrchestrator:
             conversation: Conversation document
             model_id: Model ID for conversation
             decoded_token: User token
+            user_id: Caller's id (for conversation reload after compression)
+            model_user_id: BYOM-resolution scope (model owner)
 
         Returns:
             CompressionResult

@@ -123,11 +140,17 @@ class CompressionOrchestrator:
                 else model_id
             )
 
-            # Get provider and API key for compression model
-            provider = get_provider_from_model_id(compression_model)
+            # Use model-owner scope so provider/api_key resolves to the
+            # owner's BYOM record (shared-agent dispatch).
+            caller_user_id = user_id
+            if caller_user_id is None and isinstance(decoded_token, dict):
+                caller_user_id = decoded_token.get("sub")
+            registry_user_id = model_user_id or caller_user_id
+            provider = get_provider_from_model_id(
+                compression_model, user_id=registry_user_id
+            )
             api_key = get_api_key_for_provider(provider)
 
-            # Create compression LLM
             compression_llm = LLMCreator.create_llm(
                 provider,
                 api_key=api_key,

@@ -135,6 +158,7 @@ class CompressionOrchestrator:
                 decoded_token=decoded_token,
                 model_id=compression_model,
                 agent_id=conversation.get("agent_id"),
+                model_user_id=registry_user_id,
             )
 
             # Create compression service with DB update capability

@@ -167,9 +191,12 @@ class CompressionOrchestrator:
                 f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
             )
 
-            # Reload conversation with updated metadata
+            # Reload under caller (conversation is owned by caller).
+            reload_user_id = caller_user_id
+            if reload_user_id is None and isinstance(decoded_token, dict):
+                reload_user_id = decoded_token.get("sub")
             conversation = self.conversation_service.get_conversation(
-                conversation_id, user_id=decoded_token.get("sub")
+                conversation_id, user_id=reload_user_id
             )
 
             # Get compressed context

@@ -192,16 +219,21 @@ class CompressionOrchestrator:
         model_id: str,
         decoded_token: Dict[str, Any],
         current_conversation: Optional[Dict[str, Any]] = None,
+        model_user_id: Optional[str] = None,
     ) -> CompressionResult:
         """
         Perform compression during tool execution.
 
         Args:
             conversation_id: Conversation ID
-            user_id: User ID
+            user_id: Caller's user id — used for conversation access checks
             model_id: Model ID
             decoded_token: User token
             current_conversation: Pre-loaded conversation (optional)
+            model_user_id: BYOM-resolution scope (model owner). For
+                shared-agent dispatch this is the agent owner; defaults
+                to ``user_id`` so built-in / caller-owned models are
+                unaffected.
 
         Returns:
             CompressionResult

@@ -223,7 +255,12 @@ class CompressionOrchestrator:
 
             # Perform compression
             return self._perform_compression(
-                conversation_id, conversation, model_id, decoded_token
+                conversation_id,
+                conversation,
+                model_id,
+                decoded_token,
+                user_id=user_id,
+                model_user_id=model_user_id,
            )
 
         except Exception as e:
@@ -106,8 +106,13 @@ class CompressionService:
                 f"using model {self.model_id}"
             )
 
+            # See note in conversation_service.py: ``self.model_id`` is
+            # the registry id (UUID for BYOM); the LLM's own model_id is
+            # what the provider's API actually expects.
             response = self.llm.gen(
-                model=self.model_id, messages=messages, max_tokens=4000
+                model=getattr(self.llm, "model_id", None) or self.model_id,
+                messages=messages,
+                max_tokens=4000,
             )
 
             # Extract summary from response
@@ -30,6 +30,7 @@ class CompressionThresholdChecker:
         conversation: Dict[str, Any],
         model_id: str,
         current_query_tokens: int = 500,
+        user_id: str | None = None,
     ) -> bool:
         """
         Determine if compression is needed.

@@ -38,6 +39,8 @@ class CompressionThresholdChecker:
             conversation: Full conversation document
             model_id: Target model for this request
             current_query_tokens: Estimated tokens for current query
+            user_id: Owner — needed so per-user BYOM custom-model UUIDs
+                resolve when looking up the context window.
 
         Returns:
             True if tokens >= threshold% of context window

@@ -48,7 +51,7 @@ class CompressionThresholdChecker:
             total_tokens += current_query_tokens
 
             # Get context window limit for model
-            context_limit = get_token_limit(model_id)
+            context_limit = get_token_limit(model_id, user_id=user_id)
 
             # Calculate threshold
             threshold = int(context_limit * self.threshold_percentage)

@@ -73,20 +76,24 @@ class CompressionThresholdChecker:
             logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
             return False
 
-    def check_message_tokens(self, messages: list, model_id: str) -> bool:
+    def check_message_tokens(
+        self, messages: list, model_id: str, user_id: str | None = None
+    ) -> bool:
         """
         Check if message list exceeds threshold.
 
         Args:
             messages: List of message dicts
             model_id: Target model
+            user_id: Owner — needed so per-user BYOM custom-model UUIDs
+                resolve when looking up the context window.
 
         Returns:
             True if at or above threshold
         """
         try:
             current_tokens = TokenCounter.count_message_tokens(messages)
-            context_limit = get_token_limit(model_id)
+            context_limit = get_token_limit(model_id, user_id=user_id)
             threshold = int(context_limit * self.threshold_percentage)
 
             if current_tokens >= threshold:
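To see why threading `user_id` matters here, work the threshold arithmetic for a small BYOM model. The 70% threshold and the 128k fallback below are illustrative values; the real numbers come from settings:

```python
# With the user-scoped lookup, an 8k BYOM window drives the threshold:
context_limit = 8_000                      # resolved from the user's BYOM row
threshold = int(context_limit * 0.7)       # 5,600 tokens
print(6_000 >= threshold)                  # True -> compression fires

# Without user_id, a BYOM registry UUID misses the lookup and falls back
# to the deployment default, so compression effectively never triggers:
fallback_limit = 128_000                   # assumed default limit
print(6_000 >= int(fallback_limit * 0.7))  # False
```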
@@ -12,6 +12,12 @@ logger = logging.getLogger(__name__)
 class TokenCounter:
     """Centralized token counting for conversations and messages."""
 
+    # Per-image token estimate. Provider tokenizers vary widely
+    # (Gemini ~258, GPT-4o 85-1500, Claude ~1500) and the actual cost
+    # depends on resolution/detail we can't see here. Errs slightly high
+    # so the threshold check stays conservative.
+    _IMAGE_PART_TOKEN_ESTIMATE = 1500
+
     @staticmethod
     def count_message_tokens(messages: List[Dict]) -> int:
         """

@@ -29,12 +35,36 @@ class TokenCounter:
             if isinstance(content, str):
                 total_tokens += num_tokens_from_string(content)
             elif isinstance(content, list):
-                # Handle structured content (tool calls, etc.)
+                # Handle structured content (tool calls, image parts, etc.)
                 for item in content:
                     if isinstance(item, dict):
-                        total_tokens += num_tokens_from_string(str(item))
+                        total_tokens += TokenCounter._count_content_part(item)
         return total_tokens
 
+    @staticmethod
+    def _count_content_part(item: Dict) -> int:
+        # Image/file attachments are billed by the provider per image,
+        # not proportional to the inline bytes/base64 string.
+        # ``str(item)`` on a 1MB image inflates the count by ~10000x,
+        # which trips spurious compression and overflows downstream
+        # input limits.
+        item_type = item.get("type")
+
+        if "files" in item:
+            files = item.get("files")
+            count = len(files) if isinstance(files, list) and files else 1
+            return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE * count
+
+        if "image_url" in item or item_type in {
+            "image",
+            "image_url",
+            "input_image",
+            "file",
+        }:
+            return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE
+
+        return num_tokens_from_string(str(item))
+
     @staticmethod
     def count_query_tokens(
         queries: List[Dict[str, Any]], include_tool_calls: bool = True
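The magnitude problem `_count_content_part` fixes is easy to quantify. A 1 MB image becomes roughly 1.33 MB of base64 once inlined, and at the usual ~4 characters per token the old `str(item)` path counts it as hundreds of thousands of tokens rather than a flat per-image estimate:

```python
image_bytes = 1_000_000
base64_chars = image_bytes * 4 // 3       # ~1,333,333 characters inline
naive_tokens = base64_chars // 4          # ~333,333 "tokens" via str(item)
flat_estimate = 1500                      # _IMAGE_PART_TOKEN_ESTIMATE
print(naive_tokens // flat_estimate)      # ~222x the conservative estimate
# Measured against a provider that actually bills ~258 tokens per image,
# the naive count is off by more than three orders of magnitude.
```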
@@ -136,8 +136,14 @@ class ConversationService:
                 },
             ]
 
+            # ``model_id`` here is the registry id (a UUID for BYOM
+            # records). The LLM's own ``model_id`` is the upstream name
+            # LLMCreator resolved at construction time — that's what
+            # the provider's API expects. Built-ins are unaffected.
             completion = llm.gen(
-                model=model_id, messages=messages_summary, max_tokens=500
+                model=getattr(llm, "model_id", None) or model_id,
+                messages=messages_summary,
+                max_tokens=500,
             )
 
             if not completion or not completion.strip():
@@ -121,6 +121,8 @@ class StreamProcessor:
         self.agent_id = self.data.get("agent_id")
         self.agent_key = None
         self.model_id: Optional[str] = None
+        # BYOM-resolution scope, set by _validate_and_set_model.
+        self.model_user_id: Optional[str] = None
         self.conversation_service = ConversationService()
         self.compression_orchestrator = CompressionOrchestrator(
             self.conversation_service

@@ -191,16 +193,23 @@ class StreamProcessor:
                 for query in conversation.get("queries", [])
             ]
         else:
+            # model_user_id keeps history trim aligned with the BYOM's
+            # actual context window instead of the default 128k.
             self.history = limit_chat_history(
-                json.loads(self.data.get("history", "[]")), model_id=self.model_id
+                json.loads(self.data.get("history", "[]")),
+                model_id=self.model_id,
+                user_id=self.model_user_id,
            )
 
     def _handle_compression(self, conversation: Dict[str, Any]):
         """Handle conversation compression logic using orchestrator."""
         try:
+            # initial_user_id for conversation access; model_user_id
+            # for BYOM context-window / provider lookups.
             result = self.compression_orchestrator.compress_if_needed(
                 conversation_id=self.conversation_id,
                 user_id=self.initial_user_id,
+                model_user_id=self.model_user_id,
                 model_id=self.model_id,
                 decoded_token=self.decoded_token,
             )

@@ -284,11 +293,18 @@ class StreamProcessor:
         from application.core.model_settings import ModelRegistry
 
         requested_model = self.data.get("model_id")
+        # Caller picks from their own BYOM layer; agent defaults resolve
+        # under the owner's layer (shared agents have caller != owner).
+        caller_user_id = self.initial_user_id
+        owner_user_id = self.agent_config.get("user_id") or caller_user_id
+
         if requested_model:
-            if not validate_model_id(requested_model):
+            if not validate_model_id(requested_model, user_id=caller_user_id):
                 registry = ModelRegistry.get_instance()
-                available_models = [m.id for m in registry.get_enabled_models()]
+                available_models = [
+                    m.id
+                    for m in registry.get_enabled_models(user_id=caller_user_id)
+                ]
                 raise ValueError(
                     f"Invalid model_id '{requested_model}'. "
                     f"Available models: {', '.join(available_models[:5])}"

@@ -299,12 +315,17 @@ class StreamProcessor:
                     )
                 )
             self.model_id = requested_model
+            self.model_user_id = caller_user_id
         else:
             agent_default_model = self.agent_config.get("default_model_id", "")
-            if agent_default_model and validate_model_id(agent_default_model):
+            if agent_default_model and validate_model_id(
+                agent_default_model, user_id=owner_user_id
+            ):
                 self.model_id = agent_default_model
+                self.model_user_id = owner_user_id
             else:
                 self.model_id = get_default_model_id()
+                self.model_user_id = None
 
     def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
         """Get API key for agent with access control."""

@@ -514,6 +535,10 @@ class StreamProcessor:
                     "allow_system_prompt_override": self._agent_data.get(
                         "allow_system_prompt_override", False
                     ),
+                    # Owner identity — _validate_and_set_model reads this to
+                    # resolve owner-stored BYOM default_model_id against the
+                    # owner's per-user model layer rather than the caller's.
+                    "user_id": self._agent_data.get("user"),
                 }
             )
 

@@ -561,7 +586,13 @@ class StreamProcessor:
 
     def _configure_retriever(self):
         """Assemble retriever config with precedence: request > agent > default."""
-        doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
+        # BYOM scope: owner for shared-agent BYOM, caller for own BYOM,
+        # None for built-ins. Without ``user_id`` here, the doc budget
+        # falls back to settings.DEFAULT_LLM_TOKEN_LIMIT and overfills
+        # the upstream context window for any small (e.g. 8k/32k) BYOM.
+        doc_token_limit = calculate_doc_token_budget(
+            model_id=self.model_id, user_id=self.model_user_id
+        )
 
         # Start with defaults
         retriever_name = "classic"

@@ -612,6 +643,7 @@ class StreamProcessor:
             chunks=self.retriever_config["chunks"],
             doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
             model_id=self.model_id,
+            model_user_id=self.model_user_id,
             user_api_key=self.agent_config["user_api_key"],
             agent_id=self.agent_id,
             decoded_token=self.decoded_token,

@@ -903,6 +935,11 @@ class StreamProcessor:
         agent_config = state["agent_config"]
 
         model_id = agent_config.get("model_id")
+        # BYOM scope captured at initial dispatch. None for built-ins or
+        # caller-owned BYOM where decoded_token['sub'] is already the
+        # right scope; non-None for shared-agent owner BYOM where the
+        # caller's identity differs from the model owner's.
+        model_user_id = agent_config.get("model_user_id")
         llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
         api_key = agent_config.get("api_key")
||||||
user_api_key = agent_config.get("user_api_key")
|
user_api_key = agent_config.get("user_api_key")
|
||||||
@@ -920,6 +957,7 @@ class StreamProcessor:
|
|||||||
decoded_token=self.decoded_token,
|
decoded_token=self.decoded_token,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
|
model_user_id=model_user_id,
|
||||||
)
|
)
|
||||||
llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
|
llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
|
||||||
tool_executor = ToolExecutor(
|
tool_executor = ToolExecutor(
|
||||||
@@ -949,6 +987,7 @@ class StreamProcessor:
|
|||||||
"endpoint": "stream",
|
"endpoint": "stream",
|
||||||
"llm_name": llm_name,
|
"llm_name": llm_name,
|
||||||
"model_id": model_id,
|
"model_id": model_id,
|
||||||
|
"model_user_id": model_user_id,
|
||||||
"api_key": system_api_key,
|
"api_key": system_api_key,
|
||||||
"agent_id": agent_id,
|
"agent_id": agent_id,
|
||||||
"user_api_key": user_api_key,
|
"user_api_key": user_api_key,
|
||||||
@@ -971,6 +1010,15 @@ class StreamProcessor:
|
|||||||
|
|
||||||
# Store config for the route layer
|
# Store config for the route layer
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
|
# Mirror ``model_user_id`` back onto the processor so the route
|
||||||
|
# layer (StreamResource) reads the owner scope captured at
|
||||||
|
# initial dispatch. Without this, ``processor.model_user_id``
|
||||||
|
# stays at the __init__ default (None) and complete_stream
|
||||||
|
# falls back to the caller's sub: the post-resume title-LLM
|
||||||
|
# save misses the owner's BYOM layer, and any second tool
|
||||||
|
# pause persists ``model_user_id=None`` — losing owner scope
|
||||||
|
# for every subsequent resume of this conversation.
|
||||||
|
self.model_user_id = model_user_id
|
||||||
self.agent_id = agent_id
|
self.agent_id = agent_id
|
||||||
self.agent_config["user_api_key"] = user_api_key
|
self.agent_config["user_api_key"] = user_api_key
|
||||||
self.conversation_id = conversation_id
|
self.conversation_id = conversation_id
|
||||||
@@ -1022,8 +1070,11 @@ class StreamProcessor:
|
|||||||
tools_data=tools_data,
|
tools_data=tools_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Use the user_id that resolved the model so owner-scoped BYOM
|
||||||
|
# records dispatch correctly on shared-agent requests.
|
||||||
|
model_user_id = getattr(self, "model_user_id", self.initial_user_id)
|
||||||
provider = (
|
provider = (
|
||||||
get_provider_from_model_id(self.model_id)
|
get_provider_from_model_id(self.model_id, user_id=model_user_id)
|
||||||
if self.model_id
|
if self.model_id
|
||||||
else settings.LLM_PROVIDER
|
else settings.LLM_PROVIDER
|
||||||
)
|
)
|
||||||
@@ -1048,6 +1099,8 @@ class StreamProcessor:
|
|||||||
model_id=self.model_id,
|
model_id=self.model_id,
|
||||||
agent_id=self.agent_id,
|
agent_id=self.agent_id,
|
||||||
backup_models=backup_models,
|
backup_models=backup_models,
|
||||||
|
# Owner-scope on shared-agent BYOM dispatch.
|
||||||
|
model_user_id=model_user_id,
|
||||||
)
|
)
|
||||||
llm_handler = LLMHandlerCreator.create_handler(
|
llm_handler = LLMHandlerCreator.create_handler(
|
||||||
provider if provider else "default"
|
provider if provider else "default"
|
||||||
@@ -1070,6 +1123,7 @@ class StreamProcessor:
|
|||||||
"endpoint": "stream",
|
"endpoint": "stream",
|
||||||
"llm_name": provider or settings.LLM_PROVIDER,
|
"llm_name": provider or settings.LLM_PROVIDER,
|
||||||
"model_id": self.model_id,
|
"model_id": self.model_id,
|
||||||
|
"model_user_id": self.model_user_id,
|
||||||
"api_key": system_api_key,
|
"api_key": system_api_key,
|
||||||
"agent_id": self.agent_id,
|
"agent_id": self.agent_id,
|
||||||
"user_api_key": self.agent_config["user_api_key"],
|
"user_api_key": self.agent_config["user_api_key"],
|
||||||
@@ -1097,6 +1151,7 @@ class StreamProcessor:
|
|||||||
"doc_token_limit", 50000
|
"doc_token_limit", 50000
|
||||||
),
|
),
|
||||||
"model_id": self.model_id,
|
"model_id": self.model_id,
|
||||||
|
"model_user_id": self.model_user_id,
|
||||||
"user_api_key": self.agent_config["user_api_key"],
|
"user_api_key": self.agent_config["user_api_key"],
|
||||||
"agent_id": self.agent_id,
|
"agent_id": self.agent_id,
|
||||||
"llm_name": provider or settings.LLM_PROVIDER,
|
"llm_name": provider or settings.LLM_PROVIDER,
|
||||||
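
The caller/owner precedence implemented above is worth seeing in isolation. A minimal sketch under stated assumptions: `is_valid` stands in for `validate_model_id`, and the agent config is reduced to two strings; this is not the repo's actual helper, just the decision table.

```python
# Sketch of the model-scope precedence above (hypothetical stand-ins).
from typing import Callable, Optional, Tuple


def resolve_model_scope(
    requested_model: Optional[str],
    agent_default: Optional[str],
    caller_user_id: str,
    owner_user_id: Optional[str],
    is_valid: Callable[[str, Optional[str]], bool],  # stand-in for validate_model_id
    builtin_default: str = "builtin-default",       # illustrative fallback id
) -> Tuple[str, Optional[str]]:
    """Return (model_id, model_user_id) under the request > agent > default rule."""
    owner = owner_user_id or caller_user_id
    if requested_model:
        # Explicit request: resolve against the *caller's* BYOM layer.
        if not is_valid(requested_model, caller_user_id):
            raise ValueError(f"Invalid model_id {requested_model!r}")
        return requested_model, caller_user_id
    if agent_default and is_valid(agent_default, owner):
        # Agent default: resolve against the *owner's* BYOM layer, so a
        # shared agent keeps working for callers who never registered it.
        return agent_default, owner
    # Built-in fallback: no per-user scope needed.
    return builtin_default, None


# A shared agent owned by "bob", invoked by "alice" with no explicit model:
assert resolve_model_scope(
    None, "owner-byom-uuid", "alice", "bob", is_valid=lambda m, u: True
) == ("owner-byom-uuid", "bob")
```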
@@ -1,18 +1,135 @@
-from flask import current_app, jsonify, make_response
+"""Model routes.
+
+- ``GET /api/models`` — list available models for the current user.
+  Combines the built-in catalog with the user's BYOM records.
+- ``GET/POST/PATCH/DELETE /api/user/models[/<id>]`` — CRUD for the
+  user's own OpenAI-compatible model registrations (BYOM).
+- ``POST /api/user/models/<id>/test`` — sanity-check the upstream
+  endpoint with a tiny request.
+
+Every BYOM endpoint is user-scoped at the repository layer
+(every query filters on ``user_id`` from ``request.decoded_token``).
+"""
+
+from __future__ import annotations
+
+import logging
+
+import requests
+from flask import current_app, jsonify, make_response, request
 from flask_restx import Namespace, Resource
+
-from application.core.model_settings import ModelRegistry
+from application.api import api
+from application.core.model_registry import ModelRegistry
+from application.security.safe_url import (
+    UnsafeUserUrlError,
+    pinned_post,
+    validate_user_base_url,
+)
+from application.storage.db.repositories.user_custom_models import (
+    UserCustomModelsRepository,
+)
+from application.storage.db.session import db_readonly, db_session
+from application.utils import check_required_fields
+
+
+logger = logging.getLogger(__name__)
+

 models_ns = Namespace("models", description="Available models", path="/api")
+
+
+_CONTEXT_WINDOW_MIN = 1_000
+_CONTEXT_WINDOW_MAX = 10_000_000
+
+
+def _user_id_or_401():
+    decoded_token = request.decoded_token
+    if not decoded_token:
+        return None, make_response(jsonify({"success": False}), 401)
+    user_id = decoded_token.get("sub")
+    if not user_id:
+        return None, make_response(jsonify({"success": False}), 401)
+    return user_id, None
+
+
+def _normalize_capabilities(raw) -> dict:
+    """Coerce + bound the user-supplied capabilities payload."""
+    raw = raw or {}
+    out = {}
+    if "supports_tools" in raw:
+        out["supports_tools"] = bool(raw["supports_tools"])
+    if "supports_structured_output" in raw:
+        out["supports_structured_output"] = bool(raw["supports_structured_output"])
+    if "supports_streaming" in raw:
+        out["supports_streaming"] = bool(raw["supports_streaming"])
+    if "attachments" in raw:
+        atts = raw["attachments"] or []
+        if not isinstance(atts, list):
+            raise ValueError("'capabilities.attachments' must be a list")
+        coerced = [str(a) for a in atts]
+        # Reject unknown aliases at the API boundary so bad payloads
+        # never reach the registry layer (where lenient expansion just
+        # drops them). Raw MIME types (containing ``/``) pass through
+        # unchanged for parity with the built-in YAML schema.
+        from application.core.model_yaml import builtin_attachment_aliases
+
+        aliases = builtin_attachment_aliases()
+        for entry in coerced:
+            if "/" in entry:
+                continue
+            if entry not in aliases:
+                valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
+                raise ValueError(
+                    f"unknown attachment alias '{entry}' in "
+                    f"'capabilities.attachments'. Valid aliases: {valid}, "
+                    f"or use a raw MIME type like 'image/png'."
+                )
+        out["attachments"] = coerced
+    if "context_window" in raw:
+        try:
+            cw = int(raw["context_window"])
+        except (TypeError, ValueError):
+            raise ValueError("'capabilities.context_window' must be an integer")
+        if not (_CONTEXT_WINDOW_MIN <= cw <= _CONTEXT_WINDOW_MAX):
+            raise ValueError(
+                f"'capabilities.context_window' must be between "
+                f"{_CONTEXT_WINDOW_MIN} and {_CONTEXT_WINDOW_MAX}"
+            )
+        out["context_window"] = cw
+    return out
+
+
+def _row_to_response(row: dict) -> dict:
+    """Wire-format projection — never includes the API key."""
+    return {
+        "id": str(row["id"]),
+        "upstream_model_id": row["upstream_model_id"],
+        "display_name": row["display_name"],
+        "description": row.get("description") or "",
+        "base_url": row["base_url"],
+        "capabilities": row.get("capabilities") or {},
+        "enabled": bool(row.get("enabled", True)),
+        "source": "user",
+    }
+

 @models_ns.route("/models")
 class ModelsListResource(Resource):
     def get(self):
-        """Get list of available models with their capabilities."""
+        """Get list of available models with their capabilities.
+
+        When the request is authenticated, the response includes the
+        user's own BYOM registrations alongside the built-in catalog.
+        """
         try:
+            user_id = None
+            decoded_token = getattr(request, "decoded_token", None)
+            if decoded_token:
+                user_id = decoded_token.get("sub")
+
             registry = ModelRegistry.get_instance()
-            models = registry.get_enabled_models()
+            models = registry.get_enabled_models(user_id=user_id)
+
             response = {
                 "models": [model.to_dict() for model in models],

@@ -23,3 +140,382 @@ class ModelsListResource(Resource):
             current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
             return make_response(jsonify({"success": False}), 500)
         return make_response(jsonify(response), 200)
+
+
+@models_ns.route("/user/models")
+class UserModelsCollectionResource(Resource):
+    @api.doc(description="List the current user's BYOM custom models")
+    def get(self):
+        user_id, err = _user_id_or_401()
+        if err:
+            return err
+        try:
+            with db_readonly() as conn:
+                rows = UserCustomModelsRepository(conn).list_for_user(user_id)
+            return make_response(
+                jsonify({"models": [_row_to_response(r) for r in rows]}), 200
+            )
+        except Exception as e:
+            current_app.logger.error(
+                f"Error listing user custom models: {e}", exc_info=True
+            )
+            return make_response(jsonify({"success": False}), 500)
+
+    @api.doc(description="Register a new BYOM custom model")
+    def post(self):
+        user_id, err = _user_id_or_401()
+        if err:
+            return err
+
+        data = request.get_json() or {}
+        missing = check_required_fields(
+            data,
+            ["upstream_model_id", "display_name", "base_url", "api_key"],
+        )
+        if missing:
+            return missing
+
+        # SECURITY: reject blank api_key — would leak instance API key
+        # to the user-supplied base_url via LLMCreator fallback.
+        for required_nonblank in (
+            "upstream_model_id",
+            "display_name",
+            "base_url",
+            "api_key",
+        ):
+            value = data.get(required_nonblank)
+            if not isinstance(value, str) or not value.strip():
+                return make_response(
+                    jsonify(
+                        {
+                            "success": False,
+                            "error": f"'{required_nonblank}' must be a non-empty string",
+                        }
+                    ),
+                    400,
+                )
+
+        # SSRF guard at create time. Re-runs at dispatch time (LLMCreator)
+        # as defense in depth against DNS rebinding and pre-guard rows.
+        try:
+            validate_user_base_url(data["base_url"])
+        except UnsafeUserUrlError as e:
+            return make_response(
+                jsonify({"success": False, "error": str(e)}), 400
+            )
+
+        try:
+            capabilities = _normalize_capabilities(data.get("capabilities"))
+        except ValueError as e:
+            return make_response(
+                jsonify({"success": False, "error": str(e)}), 400
+            )
+
+        try:
+            with db_session() as conn:
+                row = UserCustomModelsRepository(conn).create(
+                    user_id=user_id,
+                    upstream_model_id=data["upstream_model_id"],
+                    display_name=data["display_name"],
+                    description=data.get("description") or "",
+                    base_url=data["base_url"],
+                    api_key_plaintext=data["api_key"],
+                    capabilities=capabilities,
+                    enabled=bool(data.get("enabled", True)),
+                )
+        except Exception as e:
+            current_app.logger.error(
+                f"Error creating user custom model: {e}", exc_info=True
+            )
+            return make_response(jsonify({"success": False}), 500)
+
+        ModelRegistry.invalidate_user(user_id)
+        return make_response(jsonify(_row_to_response(row)), 201)
+
+
+@models_ns.route("/user/models/<string:model_id>")
+class UserModelResource(Resource):
+    @api.doc(description="Get one BYOM custom model")
+    def get(self, model_id):
+        user_id, err = _user_id_or_401()
+        if err:
+            return err
+        try:
+            with db_readonly() as conn:
+                row = UserCustomModelsRepository(conn).get(model_id, user_id)
+        except Exception as e:
+            current_app.logger.error(
+                f"Error fetching user custom model: {e}", exc_info=True
+            )
+            return make_response(jsonify({"success": False}), 500)
+        if row is None:
+            return make_response(jsonify({"success": False}), 404)
+        return make_response(jsonify(_row_to_response(row)), 200)
+
+    @api.doc(description="Update a BYOM custom model (partial)")
+    def patch(self, model_id):
+        user_id, err = _user_id_or_401()
+        if err:
+            return err
+
+        data = request.get_json() or {}
+
+        # Reject present-but-blank values for fields where blank doesn't
+        # mean "no change". (The api_key special case — blank means "keep
+        # existing" — is handled below.)
+        for required_nonblank in (
+            "upstream_model_id",
+            "display_name",
+            "base_url",
+        ):
+            if required_nonblank in data:
+                value = data[required_nonblank]
+                if not isinstance(value, str) or not value.strip():
+                    return make_response(
+                        jsonify(
+                            {
+                                "success": False,
+                                "error": f"'{required_nonblank}' cannot be blank",
+                            }
+                        ),
+                        400,
+                    )
+
+        if "base_url" in data and data["base_url"]:
+            try:
+                validate_user_base_url(data["base_url"])
+            except UnsafeUserUrlError as e:
+                return make_response(
+                    jsonify({"success": False, "error": str(e)}), 400
+                )
+
+        update_fields: dict = {}
+        for k in (
+            "upstream_model_id",
+            "display_name",
+            "description",
+            "base_url",
+            "enabled",
+        ):
+            if k in data:
+                update_fields[k] = data[k]
+
+        if "capabilities" in data:
+            try:
+                update_fields["capabilities"] = _normalize_capabilities(
+                    data["capabilities"]
+                )
+            except ValueError as e:
+                return make_response(
+                    jsonify({"success": False, "error": str(e)}), 400
+                )
+
+        # PATCH semantics: blank/missing api_key → keep the existing
+        # ciphertext; non-empty api_key → re-encrypt and replace.
+        if data.get("api_key"):
+            update_fields["api_key_plaintext"] = data["api_key"]
+
+        if not update_fields:
+            return make_response(
+                jsonify({"success": False, "error": "no updatable fields"}), 400
+            )
+
+        try:
+            with db_session() as conn:
+                ok = UserCustomModelsRepository(conn).update(
+                    model_id, user_id, update_fields
+                )
+        except Exception as e:
+            current_app.logger.error(
+                f"Error updating user custom model: {e}", exc_info=True
+            )
+            return make_response(jsonify({"success": False}), 500)
+
+        if not ok:
+            return make_response(jsonify({"success": False}), 404)
+
+        ModelRegistry.invalidate_user(user_id)
+        with db_readonly() as conn:
+            row = UserCustomModelsRepository(conn).get(model_id, user_id)
+        return make_response(jsonify(_row_to_response(row)), 200)
+
+    @api.doc(description="Delete a BYOM custom model")
+    def delete(self, model_id):
+        user_id, err = _user_id_or_401()
+        if err:
+            return err
+        try:
+            with db_session() as conn:
+                ok = UserCustomModelsRepository(conn).delete(model_id, user_id)
+        except Exception as e:
+            current_app.logger.error(
+                f"Error deleting user custom model: {e}", exc_info=True
+            )
+            return make_response(jsonify({"success": False}), 500)
+        if not ok:
+            return make_response(jsonify({"success": False}), 404)
+
+        ModelRegistry.invalidate_user(user_id)
+        return make_response(jsonify({"success": True}), 200)
+
+
+def _run_connection_test(
+    base_url: str, api_key: str, upstream_model_id: str
+):
+    """Send a 1-token chat-completion to verify a BYOM endpoint.
+
+    Returns ``(body, http_status)``. Upstream errors return 200 with
+    ``ok=False`` so the UI can render inline errors; only local SSRF
+    rejection returns 400.
+    """
+    url = base_url.rstrip("/") + "/chat/completions"
+    payload = {
+        "model": upstream_model_id,
+        "messages": [{"role": "user", "content": "hi"}],
+        "max_tokens": 1,
+        "stream": False,
+    }
+    headers = {
+        "Authorization": f"Bearer {api_key}",
+        "Content-Type": "application/json",
+    }
+    try:
+        # pinned_post closes the DNS-rebinding window. Redirects off
+        # because 3xx could bounce to an internal address (the SSRF
+        # guard only validates the supplied URL).
+        resp = pinned_post(
+            url,
+            json=payload,
+            headers=headers,
+            timeout=5,
+            allow_redirects=False,
+        )
+    except UnsafeUserUrlError as e:
+        return {"ok": False, "error": str(e)}, 400
+    except requests.RequestException as e:
+        return {"ok": False, "error": f"connection error: {e}"}, 200
+
+    if 300 <= resp.status_code < 400:
+        return (
+            {
+                "ok": False,
+                "error": (
+                    f"upstream returned HTTP {resp.status_code} "
+                    "redirect; refusing to follow"
+                ),
+            },
+            200,
+        )
+
+    if resp.status_code >= 400:
+        # Cap and only reflect JSON to avoid body-exfil via non-API responses.
+        content_type = (resp.headers.get("Content-Type") or "").lower()
+        if "application/json" in content_type:
+            text = (resp.text or "")[:500]
+            error_msg = f"upstream returned HTTP {resp.status_code}: {text}"
+        else:
+            error_msg = f"upstream returned HTTP {resp.status_code}"
+        return {"ok": False, "error": error_msg}, 200
+
+    return {"ok": True}, 200
+
+
+@models_ns.route("/user/models/test")
+class UserModelTestPayloadResource(Resource):
+    @api.doc(
+        description=(
+            "Test an arbitrary BYOM payload (display_name / model id / "
+            "base_url / api_key) without saving. Used by the UI's 'Test "
+            "connection' button so the user can validate before they "
+            "Save. Same SSRF guard, same 1-token request, same 5s "
+            "timeout as the by-id variant."
+        )
+    )
+    def post(self):
+        user_id, err = _user_id_or_401()
+        if err:
+            return err
+
+        data = request.get_json() or {}
+        missing = check_required_fields(
+            data, ["base_url", "api_key", "upstream_model_id"]
+        )
+        if missing:
+            return missing
+
+        body, status = _run_connection_test(
+            data["base_url"], data["api_key"], data["upstream_model_id"]
+        )
+        return make_response(jsonify(body), status)
+
+
+@models_ns.route("/user/models/<string:model_id>/test")
+class UserModelTestResource(Resource):
+    @api.doc(
+        description=(
+            "Test a saved BYOM record. Defaults to the stored "
+            "base_url / upstream_model_id / encrypted api_key, but "
+            "any of those can be overridden via the request body so "
+            "the UI can test in-flight edits before saving. Used by "
+            "the 'Test connection' button in edit mode."
+        )
+    )
+    def post(self, model_id):
+        user_id, err = _user_id_or_401()
+        if err:
+            return err
+
+        data = request.get_json() or {}
+        # Per-field overrides; blank/missing falls back to stored value.
+        override_base_url = (data.get("base_url") or "").strip() or None
+        override_upstream_model_id = (
+            data.get("upstream_model_id") or ""
+        ).strip() or None
+        override_api_key = (data.get("api_key") or "").strip() or None
+
+        try:
+            with db_readonly() as conn:
+                repo = UserCustomModelsRepository(conn)
+                row = repo.get(model_id, user_id)
+                if row is None:
+                    return make_response(jsonify({"success": False}), 404)
+                stored_api_key = (
+                    repo._decrypt_api_key(
+                        row.get("api_key_encrypted", ""), user_id
+                    )
+                    if not override_api_key
+                    else None
+                )
+        except Exception as e:
+            current_app.logger.error(
+                f"Error loading user custom model for test: {e}", exc_info=True
+            )
+            return make_response(
+                jsonify({"ok": False, "error": "internal error loading model"}),
+                500,
+            )
+
+        api_key = override_api_key or stored_api_key
+        if not api_key:
+            return make_response(
+                jsonify(
+                    {
+                        "ok": False,
+                        "error": (
+                            "Stored API key could not be decrypted. The "
+                            "encryption secret may have rotated. Re-save "
+                            "the model with the API key to recover."
+                        ),
+                    }
+                ),
+                400,
+            )
+
+        base_url = override_base_url or row["base_url"]
+        upstream_model_id = (
+            override_upstream_model_id or row["upstream_model_id"]
+        )
+        body, status = _run_connection_test(
+            base_url, api_key, upstream_model_id
+        )
+        return make_response(jsonify(body), status)
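
For orientation, a hypothetical client-side walkthrough of the new endpoints using `requests`. The base URL, bearer token, gateway URL, and model name are all placeholders, not project fixtures.

```python
# Hypothetical BYOM registration flow against the routes above.
import requests

BASE = "http://localhost:8000/api"                 # placeholder deployment URL
headers = {"Authorization": "Bearer <token>"}      # placeholder JWT

# Validate the payload before saving (same SSRF guard + 1-token probe).
probe = requests.post(
    f"{BASE}/user/models/test",
    headers=headers,
    json={
        "base_url": "https://my-gateway.example.com/v1",   # placeholder
        "api_key": "sk-...",                               # placeholder
        "upstream_model_id": "llama-3.1-8b-instruct",      # placeholder
    },
)
assert probe.json().get("ok"), probe.json().get("error")

# Register it; the response projection never echoes the API key back.
created = requests.post(
    f"{BASE}/user/models",
    headers=headers,
    json={
        "upstream_model_id": "llama-3.1-8b-instruct",
        "display_name": "My 8B gateway",
        "base_url": "https://my-gateway.example.com/v1",
        "api_key": "sk-...",
        "capabilities": {"context_window": 8192, "supports_tools": False},
    },
)
model = created.json()
print(model["id"], model["source"])  # server-assigned id, "user"
```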
@@ -140,6 +140,11 @@ def setup_periodic_tasks(sender, **kwargs):
         cleanup_pending_tool_state.s(),
         name="cleanup-pending-tool-state",
     )
+    sender.add_periodic_task(
+        timedelta(hours=7),
+        version_check_task.s(),
+        name="version-check",
+    )


 @celery.task(bind=True)

@@ -176,3 +181,16 @@ def cleanup_pending_tool_state(self):
     with engine.begin() as conn:
         deleted = PendingToolStateRepository(conn).cleanup_expired()
     return {"deleted": deleted}
+
+
+@celery.task(bind=True)
+def version_check_task(self):
+    """Periodic anonymous version check.
+
+    Complements the ``worker_ready`` boot trigger so long-running
+    deployments (>6h cache TTL) still refresh advisories. ``run_check``
+    is fail-silent and coordinates across replicas via Redis lock +
+    cache (see ``application.updates.version_check``).
+    """
+    from application.updates.version_check import run_check
+    run_check()
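
`run_check` itself is not part of this diff; its docstring only promises fail-silent behavior coordinated across replicas via a Redis lock and cache. A minimal sketch of that pattern with redis-py, where the key names, TTLs, and `fetch_latest` callable are all illustrative, not the project's actual implementation:

```python
# Sketch of a Redis-lock-coordinated, fail-silent periodic check.
import redis

r = redis.Redis()
LOCK_KEY = "docsgpt:version-check:lock"     # illustrative key name
CACHE_KEY = "docsgpt:version-check:latest"  # illustrative key name


def run_check_sketch(fetch_latest) -> None:
    if r.get(CACHE_KEY) is not None:
        return  # another replica refreshed recently; nothing to do
    # NX lock: only one replica performs the outbound request.
    if not r.set(LOCK_KEY, "1", nx=True, ex=60):
        return
    try:
        latest = fetch_latest()                 # network call, may fail
        r.set(CACHE_KEY, latest, ex=6 * 3600)   # ~6h TTL per the docstring
    except Exception:
        pass  # fail-silent: never let a version ping crash the beat task
```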
@@ -198,8 +198,14 @@ def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
     return normalized_nodes


-def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
-    """Validate workflow graph structure."""
+def validate_workflow_structure(
+    nodes: List[Dict], edges: List[Dict], user_id: str | None = None
+) -> List[str]:
+    """Validate workflow graph structure.
+
+    ``user_id`` is required so per-user BYOM custom-model UUIDs resolve
+    when checking each agent node's structured-output capability.
+    """
     errors = []

     if not nodes:

@@ -343,7 +349,7 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st

         model_id = raw_config.get("model_id")
         if has_json_schema and isinstance(model_id, str) and model_id.strip():
-            capabilities = get_model_capabilities(model_id.strip())
+            capabilities = get_model_capabilities(model_id.strip(), user_id=user_id)
             if capabilities and not capabilities.get("supports_structured_output", False):
                 errors.append(
                     f"Agent node '{agent_title}' selected model does not support structured output"

@@ -389,7 +395,9 @@ class WorkflowList(Resource):
         nodes_data = data.get("nodes", [])
         edges_data = data.get("edges", [])

-        validation_errors = validate_workflow_structure(nodes_data, edges_data)
+        validation_errors = validate_workflow_structure(
+            nodes_data, edges_data, user_id=user_id
+        )
         if validation_errors:
             return error_response(
                 "Workflow validation failed", errors=validation_errors

@@ -451,7 +459,9 @@ class WorkflowDetail(Resource):
         nodes_data = data.get("nodes", [])
         edges_data = data.get("edges", [])

-        validation_errors = validate_workflow_structure(nodes_data, edges_data)
+        validation_errors = validate_workflow_structure(
+            nodes_data, edges_data, user_id=user_id
+        )
         if validation_errors:
             return error_response(
                 "Workflow validation failed", errors=validation_errors
@@ -213,6 +213,7 @@ def _stream_response(
         decoded_token=processor.decoded_token,
         agent_id=processor.agent_id,
         model_id=processor.model_id,
+        model_user_id=processor.model_user_id,
         should_save_conversation=should_save_conversation,
         _continuation=continuation,
     )

@@ -257,6 +258,7 @@ def _non_stream_response(
         decoded_token=processor.decoded_token,
         agent_id=processor.agent_id,
         model_id=processor.model_id,
+        model_user_id=processor.model_user_id,
         should_save_conversation=should_save_conversation,
         _continuation=continuation,
     )
@@ -4,11 +4,12 @@ import platform
 import uuid

 import dotenv
-from flask import Flask, jsonify, redirect, request
+from flask import Flask, Response, jsonify, redirect, request
 from jose import jwt

 from application.auth import handle_auth

+from application.core import log_context
 from application.core.logging_config import setup_logging

 setup_logging()

@@ -112,6 +113,38 @@ def generate_token():
     return jsonify({"error": "Token generation not allowed in current auth mode"}), 400


+_LOG_CTX_TOKEN_ATTR = "_log_ctx_token"
+
+
+@app.before_request
+def _bind_log_context():
+    """Bind activity_id + endpoint for the duration of this request.
+
+    Runs before ``authenticate_request``; ``user_id`` is overlaid in a
+    follow-up handler once the JWT has been decoded.
+    """
+    if request.method == "OPTIONS":
+        return None
+    activity_id = str(uuid.uuid4())
+    request.activity_id = activity_id
+    token = log_context.bind(
+        activity_id=activity_id,
+        endpoint=request.endpoint,
+    )
+    setattr(request, _LOG_CTX_TOKEN_ATTR, token)
+    return None
+
+
+@app.teardown_request
+def _reset_log_context(_exc):
+    # SSE streams keep yielding after teardown fires, but a2wsgi runs each
+    # request inside ``copy_context().run(...)``, so this reset doesn't
+    # leak into the stream's view of the context.
+    token = getattr(request, _LOG_CTX_TOKEN_ATTR, None)
+    if token is not None:
+        log_context.reset(token)
+
+
 @app.before_request
 def enforce_stt_request_size_limits():
     if request.method == "OPTIONS":

@@ -148,13 +181,27 @@ def authenticate_request():
     request.decoded_token = decoded_token


+@app.before_request
+def _bind_user_id_to_log_context():
+    # Registered after ``authenticate_request`` (Flask runs before_request
+    # handlers in registration order), so ``request.decoded_token`` is
+    # populated by the time we read it. ``teardown_request`` unwinds the
+    # whole request-level bind, so no separate reset token is needed here.
+    if request.method == "OPTIONS":
+        return None
+    decoded_token = getattr(request, "decoded_token", None)
+    user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
+    if user_id:
+        log_context.bind(user_id=user_id)
+    return None
+
+
 @app.after_request
-def after_request(response):
-    response.headers.add("Access-Control-Allow-Origin", "*")
-    response.headers.add("Access-Control-Allow-Headers", "Content-Type, Authorization")
-    response.headers.add(
-        "Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
-    )
+def after_request(response: Response) -> Response:
+    """Add CORS headers for the pure Flask development entrypoint."""
+    response.headers["Access-Control-Allow-Origin"] = "*"
+    response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
+    response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
     return response

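
The token discipline in `_bind_log_context` / `_reset_log_context` rests on `ContextVar` token semantics: `set` returns a token that restores the exact prior snapshot, so nested binds unwind LIFO. A standalone demonstration, no Flask required:

```python
# ContextVar token semantics behind the bind/reset pair above.
from contextvars import ContextVar

ctx: ContextVar[dict] = ContextVar("log_ctx", default={})

outer = ctx.set({**ctx.get(), "activity_id": "req-1"})   # request-level bind
inner = ctx.set({**ctx.get(), "user_id": "alice"})       # post-auth overlay
assert ctx.get() == {"activity_id": "req-1", "user_id": "alice"}

ctx.reset(inner)
assert ctx.get() == {"activity_id": "req-1"}  # user overlay unwound
ctx.reset(outer)
assert ctx.get() == {}                        # back to the pre-request snapshot
```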
application/asgi.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+"""ASGI entrypoint: Flask (WSGI) + FastMCP on the same process."""
+
+from __future__ import annotations
+
+from a2wsgi import WSGIMiddleware
+from starlette.applications import Starlette
+from starlette.middleware import Middleware
+from starlette.middleware.cors import CORSMiddleware
+from starlette.routing import Mount
+
+from application.app import app as flask_app
+from application.mcp_server import mcp
+
+_WSGI_THREADPOOL = 32
+
+mcp_app = mcp.http_app(path="/")
+
+asgi_app = Starlette(
+    routes=[
+        Mount("/mcp", app=mcp_app),
+        Mount("/", app=WSGIMiddleware(flask_app, workers=_WSGI_THREADPOOL)),
+    ],
+    middleware=[
+        Middleware(
+            CORSMiddleware,
+            allow_origins=["*"],
+            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
+            allow_headers=["Content-Type", "Authorization", "Mcp-Session-Id"],
+            expose_headers=["Mcp-Session-Id"],
+        ),
+    ],
+    lifespan=mcp_app.lifespan,
+)
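
A programmatic way to serve this composition during development, assuming `uvicorn` is installed; the module path is the one defined in the new file, the port is arbitrary:

```python
# Minimal launcher for the Starlette composition above.
import uvicorn

if __name__ == "__main__":
    uvicorn.run(
        "application.asgi:asgi_app",  # import string so reload can re-import
        host="0.0.0.0",
        port=8000,                    # arbitrary dev port
        reload=True,
    )
```

With this in place the Flask routes answer under `/` and the FastMCP app under `/mcp`, all in one process sharing one lifespan.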
@@ -1,3 +1,4 @@
+import hashlib
 import json
 import logging
 import time

@@ -10,6 +11,14 @@ from application.utils import get_hash

 logger = logging.getLogger(__name__)

+
+def _cache_default(value):
+    # Image attachments arrive inline as bytes (see GoogleLLM.prepare_messages_with_attachments);
+    # hash so the cache key stays bounded in size and stable across identical content.
+    if isinstance(value, (bytes, bytearray, memoryview)):
+        return f"<bytes:sha256:{hashlib.sha256(bytes(value)).hexdigest()}>"
+    return repr(value)
+
 _redis_instance = None
 _redis_creation_failed = False
 _instance_lock = Lock()

@@ -36,7 +45,7 @@ def get_redis_instance():
 def gen_cache_key(messages, model="docgpt", tools=None):
     if not all(isinstance(msg, dict) for msg in messages):
         raise ValueError("All messages must be dictionaries.")
-    messages_str = json.dumps(messages)
+    messages_str = json.dumps(messages, default=_cache_default)
     tools_str = json.dumps(str(tools)) if tools else ""
     combined = f"{model}_{messages_str}_{tools_str}"
     cache_key = get_hash(combined)
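
Why the `default=` hook matters: `json.dumps` raises `TypeError` on raw bytes, and even if it didn't, inlining megabytes of image data into a cache key would be wasteful. A self-contained check that the hook yields short, stable keys:

```python
# json.dumps only calls default= for values it can't serialize, so
# bytes payloads are replaced by a fixed-size sha256 placeholder.
import hashlib
import json


def _cache_default(value):
    if isinstance(value, (bytes, bytearray, memoryview)):
        return f"<bytes:sha256:{hashlib.sha256(bytes(value)).hexdigest()}>"
    return repr(value)


msg = [{"role": "user", "content": b"\x89PNG..." * 1000}]  # inline image stand-in
a = json.dumps(msg, default=_cache_default)
b = json.dumps(msg, default=_cache_default)
assert a == b            # identical bytes -> identical key material
assert len(a) < 200      # bounded regardless of attachment size
```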
@@ -1,6 +1,17 @@
+import inspect
+import logging
+import threading
+
 from celery import Celery
+from application.core import log_context
 from application.core.settings import settings
-from celery.signals import setup_logging, worker_process_init
+from celery.signals import (
+    setup_logging,
+    task_postrun,
+    task_prerun,
+    worker_process_init,
+    worker_ready,
+)


 def make_celery(app_name=__name__):

@@ -39,5 +50,73 @@ def _dispose_db_engine_on_fork(*args, **kwargs):
     dispose_engine()


+# Most tasks in this repo accept ``user`` where the log context wants
+# ``user_id``; map task parameter names to context keys explicitly.
+_TASK_PARAM_TO_CTX_KEY: dict[str, str] = {
+    "user": "user_id",
+    "user_id": "user_id",
+    "agent_id": "agent_id",
+    "conversation_id": "conversation_id",
+}
+
+_task_log_tokens: dict[str, object] = {}
+
+
+@task_prerun.connect
+def _bind_task_log_context(task_id, task, args, kwargs, **_):
+    # Resolve task args by parameter name — nearly every task in this repo
+    # is called positionally, so ``kwargs.get('user')`` would bind nothing.
+    ctx = {"activity_id": task_id}
+    try:
+        sig = inspect.signature(task.run)
+        bound = sig.bind_partial(*args, **kwargs).arguments
+    except (TypeError, ValueError):
+        bound = dict(kwargs)
+    for param_name, value in bound.items():
+        ctx_key = _TASK_PARAM_TO_CTX_KEY.get(param_name)
+        if ctx_key and value:
+            ctx[ctx_key] = value
+    _task_log_tokens[task_id] = log_context.bind(**ctx)
+
+
+@task_postrun.connect
+def _unbind_task_log_context(task_id, **_):
+    # ``task_postrun`` fires on both success and failure. Required for
+    # Celery: unlike the Flask path, tasks aren't isolated in their own
+    # ``copy_context().run(...)``, so a missing reset would leak the
+    # bind onto the next task on the same worker.
+    token = _task_log_tokens.pop(task_id, None)
+    if token is None:
+        return
+    try:
+        log_context.reset(token)
+    except ValueError:
+        # task_prerun and task_postrun ran on different threads (non-default
+        # Celery pool); the token isn't valid in this context. Drop it.
+        logging.getLogger(__name__).debug(
+            "log_context reset skipped for task %s", task_id
+        )
+
+
+@worker_ready.connect
+def _run_version_check(*args, **kwargs):
+    """Kick off the anonymous version check on worker startup.
+
+    Runs in a daemon thread so a slow endpoint or bad DNS never holds
+    up the worker becoming ready for tasks. The check itself is
+    fail-silent (see ``application.updates.version_check.run_check``);
+    this handler's only job is to launch it and get out of the way.
+
+    Import is lazy so the symbol resolution never fires at module
+    import time — consistent with the ``_dispose_db_engine_on_fork``
+    pattern above.
+    """
+    try:
+        from application.updates.version_check import run_check
+    except Exception:
+        return
+    threading.Thread(target=run_check, name="version-check", daemon=True).start()
+
+
 celery = make_celery()
 celery.config_from_object("application.celeryconfig")
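
The `bind_partial` trick in `_bind_task_log_context` is the load-bearing part: Celery hands the signal positional `args`, so `kwargs.get("user")` alone would usually find nothing. A quick demonstration with a stand-in task signature:

```python
# Resolving positional task arguments back to parameter names.
import inspect


def ingest(source_id, user, chunk_size=512):  # hypothetical task body
    ...


bound = inspect.signature(ingest).bind_partial("src-42", "alice").arguments
assert bound == {"source_id": "src-42", "user": "alice"}
# "user" then maps to the "user_id" context key via _TASK_PARAM_TO_CTX_KEY.
```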
@@ -9,3 +9,8 @@ accept_content = ['json']

 # Autodiscover tasks
 imports = ('application.api.user.tasks',)
+
+beat_scheduler = "redbeat.RedBeatScheduler"
+redbeat_redis_url = broker_url
+redbeat_key_prefix = "redbeat:docsgpt:"
+redbeat_lock_timeout = 90
application/core/log_context.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+"""Per-activity logging context backed by ``contextvars``.
+
+The ``_ContextFilter`` installed by ``logging_config.setup_logging`` stamps
+every ``LogRecord`` emitted inside a ``bind`` block with the bound keys, so
+they land as first-class attributes on the OTLP log export rather than being
+buried inside formatted message bodies.
+
+A single ``ContextVar`` holds a dict so nested binds reset atomically (LIFO)
+via the token returned by ``bind``.
+"""
+
+from __future__ import annotations
+
+from contextvars import ContextVar, Token
+from typing import Mapping
+
+
+_CTX_KEYS: frozenset[str] = frozenset(
+    {
+        "activity_id",
+        "parent_activity_id",
+        "user_id",
+        "agent_id",
+        "conversation_id",
+        "endpoint",
+        "model",
+    }
+)
+
+_ctx: ContextVar[Mapping[str, str]] = ContextVar("log_ctx", default={})
+
+
+def bind(**kwargs: object) -> Token:
+    """Overlay the given keys onto the current context.
+
+    Returns a ``Token`` so the caller can ``reset`` in a ``finally`` block.
+    Keys outside :data:`_CTX_KEYS` are silently dropped (so a typo can't
+    stamp a stray field name onto every record), as are ``None`` values
+    (a missing attribute is more useful than the literal string ``"None"``).
+    """
+    overlay = {
+        k: str(v)
+        for k, v in kwargs.items()
+        if k in _CTX_KEYS and v is not None
+    }
+    new = {**_ctx.get(), **overlay}
+    return _ctx.set(new)
+
+
+def reset(token: Token) -> None:
+    """Restore the context to the snapshot captured by the matching ``bind``."""
+    _ctx.reset(token)
+
+
+def snapshot() -> Mapping[str, str]:
+    """Return the current context dict. Treat as read-only; use :func:`bind`."""
+    return _ctx.get()
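
Typical call-site usage of the new module, sketched: bind around a unit of work and reset in a `finally` so the overlay cannot leak across requests. The `model` key used here is one of the allowed `_CTX_KEYS`.

```python
# Call-site discipline for the bind/reset pair defined above.
from application.core import log_context


def handle(conversation_id: str) -> None:
    token = log_context.bind(conversation_id=conversation_id, model="gpt-5-mini")
    try:
        ...  # every log record emitted in here carries both keys
    finally:
        log_context.reset(token)
```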
@@ -1,11 +1,75 @@
+import logging
+import os
 from logging.config import dictConfig

-def setup_logging():
+from application.core.log_context import snapshot as _ctx_snapshot
+
+
+# Loggers with ``propagate=False`` don't share root's handlers, so the
+# context filter has to be installed on their handlers directly.
+_NON_PROPAGATING_LOGGERS: tuple[str, ...] = (
+    "uvicorn",
+    "uvicorn.access",
+    "uvicorn.error",
+    "celery.app.trace",
+    "celery.worker.strategy",
+    "gunicorn.error",
+    "gunicorn.access",
+)
+
+
+class _ContextFilter(logging.Filter):
+    """Stamp the current ``log_context`` snapshot onto every ``LogRecord``.
+
+    Must be installed on **handlers**, not loggers: Python skips logger-level
+    filters when a child logger's record propagates up. The ``hasattr`` guard
+    keeps an explicit ``logger.info(..., extra={...})`` from being overwritten.
+    """
+
+    def filter(self, record: logging.LogRecord) -> bool:
+        for key, value in _ctx_snapshot().items():
+            if not hasattr(record, key):
+                setattr(record, key, value)
+        return True
+
+
+def _otlp_logs_enabled() -> bool:
+    """Return True when the user has opted in to OTLP log export.
+
+    Gated by the standard OTEL env vars so no project-specific knob is needed:
+    set ``OTEL_LOGS_EXPORTER=otlp`` (and leave ``OTEL_SDK_DISABLED`` unset or
+    false) to flip it on. When false, ``setup_logging`` keeps its original
+    console-only behavior.
+    """
+    exporter = os.getenv("OTEL_LOGS_EXPORTER", "").strip().lower()
+    disabled = os.getenv("OTEL_SDK_DISABLED", "false").strip().lower() == "true"
+    return exporter == "otlp" and not disabled
+
+
+def setup_logging() -> None:
+    """Configure the root logger with a stdout console handler.
+
+    When OTLP log export is enabled, ``opentelemetry-instrument`` attaches a
+    ``LoggingHandler`` to the root logger before this function runs. The
+    ``dictConfig`` call below replaces ``root.handlers`` with the console
+    handler, which would silently drop the OTEL handler. To make OTLP log
+    export work without forcing every contributor to opt in, snapshot the
+    OTEL handlers up front and re-attach them after ``dictConfig``.
+    """
+    preserved_handlers: list[logging.Handler] = []
+    if _otlp_logs_enabled():
+        preserved_handlers = [
+            h
+            for h in logging.getLogger().handlers
+            if h.__class__.__module__.startswith("opentelemetry")
+        ]
+
     dictConfig({
-        'version': 1,
-        'formatters': {
-            'default': {
-                'format': '[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
+        "version": 1,
+        "disable_existing_loggers": False,
+        "formatters": {
+            "default": {
+                "format": "[%(asctime)s] %(levelname)s in %(module)s: %(message)s",
             }
         },
         "handlers": {

@@ -15,8 +79,34 @@ def setup_logging():
                 "formatter": "default",
             }
         },
-        'root': {
-            'level': 'INFO',
-            'handlers': ['console'],
+        "root": {
+            "level": "INFO",
+            "handlers": ["console"],
         },
     })
+
+    if preserved_handlers:
+        root = logging.getLogger()
+        for handler in preserved_handlers:
+            if handler not in root.handlers:
+                root.addHandler(handler)
+
+    _install_context_filter()
+
+
+def _install_context_filter() -> None:
+    """Attach :class:`_ContextFilter` to root's handlers + every handler on
+    the known non-propagating loggers. Skipping handlers that already carry
+    one keeps repeat ``setup_logging`` calls from stacking filters.
+    """
+
+    def _has_ctx_filter(handler: logging.Handler) -> bool:
+        return any(isinstance(f, _ContextFilter) for f in handler.filters)
+
+    for handler in logging.getLogger().handlers:
+        if not _has_ctx_filter(handler):
+            handler.addFilter(_ContextFilter())
+    for name in _NON_PROPAGATING_LOGGERS:
+        for handler in logging.getLogger(name).handlers:
+            if not _has_ctx_filter(handler):
+                handler.addFilter(_ContextFilter())
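
The "install on handlers, not loggers" rule in `_ContextFilter`'s docstring can be verified in a few lines: a filter attached to the root logger never sees records propagated up from child loggers, while the same filter on root's handler does.

```python
# Propagated records bypass ancestor logger-level filters but still
# pass through ancestor handler-level filters.
import io
import logging

seen_by_logger, seen_by_handler = [], []

root = logging.getLogger()
handler = logging.StreamHandler(io.StringIO())  # swallow output
root.addHandler(handler)
root.setLevel(logging.INFO)

root.addFilter(lambda r: seen_by_logger.append(r.name) or True)
handler.addFilter(lambda r: seen_by_handler.append(r.name) or True)

logging.getLogger("child").info("hello")
assert "child" not in seen_by_logger   # logger-level filter was skipped
assert "child" in seen_by_handler      # handler-level filter ran
```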
@@ -1,266 +0,0 @@
-"""
-Model configurations for all supported LLM providers.
-"""
-
-from application.core.model_settings import (
-    AvailableModel,
-    ModelCapabilities,
-    ModelProvider,
-)
-
-# Base image attachment types supported by most vision-capable LLMs
-IMAGE_ATTACHMENTS = [
-    "image/png",
-    "image/jpeg",
-    "image/jpg",
-    "image/webp",
-    "image/gif",
-]
-
-# PDF excluded: most OpenAI-compatible endpoints don't support native PDF uploads.
-# When excluded, PDFs are synthetically processed by converting pages to images.
-OPENAI_ATTACHMENTS = IMAGE_ATTACHMENTS
-
-GOOGLE_ATTACHMENTS = ["application/pdf"] + IMAGE_ATTACHMENTS
-
-ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
-
-OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
-
-NOVITA_ATTACHMENTS = IMAGE_ATTACHMENTS
-
-
-OPENAI_MODELS = [
-    AvailableModel(
-        id="gpt-5.1",
-        provider=ModelProvider.OPENAI,
-        display_name="GPT-5.1",
-        description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
-        capabilities=ModelCapabilities(
-            supports_tools=True,
-            supports_structured_output=True,
-            supported_attachment_types=OPENAI_ATTACHMENTS,
-            context_window=200000,
-        ),
-    ),
-    AvailableModel(
-        id="gpt-5-mini",
-        provider=ModelProvider.OPENAI,
-        display_name="GPT-5 Mini",
-        description="Faster, cost-effective variant of GPT-5.1",
-        capabilities=ModelCapabilities(
-            supports_tools=True,
-            supports_structured_output=True,
-            supported_attachment_types=OPENAI_ATTACHMENTS,
-            context_window=200000,
-        ),
-    )
-]
-
-
-ANTHROPIC_MODELS = [
-    AvailableModel(
-        id="claude-3-5-sonnet-20241022",
-        provider=ModelProvider.ANTHROPIC,
-        display_name="Claude 3.5 Sonnet (Latest)",
-        description="Latest Claude 3.5 Sonnet with enhanced capabilities",
-        capabilities=ModelCapabilities(
-            supports_tools=True,
-            supported_attachment_types=ANTHROPIC_ATTACHMENTS,
-            context_window=200000,
-        ),
-    ),
-    AvailableModel(
-        id="claude-3-5-sonnet",
-        provider=ModelProvider.ANTHROPIC,
-        display_name="Claude 3.5 Sonnet",
-        description="Balanced performance and capability",
-        capabilities=ModelCapabilities(
-            supports_tools=True,
-            supported_attachment_types=ANTHROPIC_ATTACHMENTS,
-            context_window=200000,
-        ),
-    ),
-    AvailableModel(
-        id="claude-3-opus",
-        provider=ModelProvider.ANTHROPIC,
-        display_name="Claude 3 Opus",
-        description="Most capable Claude model",
-        capabilities=ModelCapabilities(
-            supports_tools=True,
-            supported_attachment_types=ANTHROPIC_ATTACHMENTS,
-            context_window=200000,
-        ),
-    ),
-    AvailableModel(
-        id="claude-3-haiku",
-        provider=ModelProvider.ANTHROPIC,
-        display_name="Claude 3 Haiku",
-        description="Fastest Claude model",
-        capabilities=ModelCapabilities(
-            supports_tools=True,
-            supported_attachment_types=ANTHROPIC_ATTACHMENTS,
-            context_window=200000,
-        ),
-    ),
-]
-
-
-GOOGLE_MODELS = [
-    AvailableModel(
-        id="gemini-flash-latest",
-        provider=ModelProvider.GOOGLE,
-        display_name="Gemini Flash (Latest)",
-        description="Latest experimental Gemini model",
-        capabilities=ModelCapabilities(
-            supports_tools=True,
-            supports_structured_output=True,
-            supported_attachment_types=GOOGLE_ATTACHMENTS,
-            context_window=int(1e6),
-        ),
-    ),
-    AvailableModel(
-        id="gemini-flash-lite-latest",
-        provider=ModelProvider.GOOGLE,
-        display_name="Gemini Flash Lite (Latest)",
-        description="Fast with huge context window",
-        capabilities=ModelCapabilities(
-            supports_tools=True,
-            supports_structured_output=True,
-            supported_attachment_types=GOOGLE_ATTACHMENTS,
-            context_window=int(1e6),
-        ),
-    ),
-    AvailableModel(
-        id="gemini-3-pro-preview",
-        provider=ModelProvider.GOOGLE,
-        display_name="Gemini 3 Pro",
-        description="Most capable Gemini model",
-        capabilities=ModelCapabilities(
-            supports_tools=True,
-            supports_structured_output=True,
-            supported_attachment_types=GOOGLE_ATTACHMENTS,
-            context_window=2000000,
-        ),
-    ),
-]
-
-
-GROQ_MODELS = [
-    AvailableModel(
-        id="llama-3.3-70b-versatile",
-        provider=ModelProvider.GROQ,
-        display_name="Llama 3.3 70B",
-        description="Latest Llama model with high-speed inference",
-        capabilities=ModelCapabilities(
-            supports_tools=True,
-            context_window=128000,
-        ),
-    ),
-    AvailableModel(
-        id="openai/gpt-oss-120b",
-        provider=ModelProvider.GROQ,
-        display_name="GPT-OSS 120B",
-        description="Open-source GPT model optimized for speed",
-        capabilities=ModelCapabilities(
-            supports_tools=True,
-            context_window=128000,
-        ),
-    ),
-]
-
-
-OPENROUTER_MODELS = [
-    AvailableModel(
-        id="qwen/qwen3-coder:free",
-        provider=ModelProvider.OPENROUTER,
-        display_name="Qwen 3 Coder",
-        description="Latest Qwen model with high-speed inference",
-        capabilities=ModelCapabilities(
-            supports_tools=True,
-            context_window=128000,
-            supported_attachment_types=OPENROUTER_ATTACHMENTS
-        ),
-    ),
-    AvailableModel(
-        id="google/gemma-3-27b-it:free",
-        provider=ModelProvider.OPENROUTER,
-        display_name="Gemma 3 27B",
-        description="Latest Gemma model with high-speed inference",
-        capabilities=ModelCapabilities(
-            supports_tools=True,
-            context_window=128000,
-            supported_attachment_types=OPENROUTER_ATTACHMENTS
-        ),
-    ),
-]
-
-NOVITA_MODELS = [
-    AvailableModel(
-        id="moonshotai/kimi-k2.5",
-        provider=ModelProvider.NOVITA,
-        display_name="Kimi K2.5",
-        description="MoE model with function calling, structured output, reasoning, and vision",
-        capabilities=ModelCapabilities(
-            supports_tools=True,
-            supports_structured_output=True,
-            supported_attachment_types=NOVITA_ATTACHMENTS,
-            context_window=262144,
-        ),
-    ),
-    AvailableModel(
-        id="zai-org/glm-5",
-        provider=ModelProvider.NOVITA,
-        display_name="GLM-5",
-        description="MoE model with function calling, structured output, and reasoning",
-        capabilities=ModelCapabilities(
-            supports_tools=True,
-            supports_structured_output=True,
-            supported_attachment_types=[],
-            context_window=202800,
-        ),
-    ),
-    AvailableModel(
-        id="minimax/minimax-m2.5",
-        provider=ModelProvider.NOVITA,
|
|
||||||
display_name="MiniMax M2.5",
|
|
||||||
description="MoE model with function calling, structured output, and reasoning",
|
|
||||||
capabilities=ModelCapabilities(
|
|
||||||
supports_tools=True,
|
|
||||||
supports_structured_output=True,
|
|
||||||
supported_attachment_types=[],
|
|
||||||
context_window=204800,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
AZURE_OPENAI_MODELS = [
|
|
||||||
AvailableModel(
|
|
||||||
id="azure-gpt-4",
|
|
||||||
provider=ModelProvider.AZURE_OPENAI,
|
|
||||||
display_name="Azure OpenAI GPT-4",
|
|
||||||
description="Azure-hosted GPT model",
|
|
||||||
capabilities=ModelCapabilities(
|
|
||||||
supports_tools=True,
|
|
||||||
supports_structured_output=True,
|
|
||||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
|
||||||
context_window=8192,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
|
|
||||||
"""Create a custom OpenAI-compatible model (e.g., LM Studio, Ollama)."""
|
|
||||||
return AvailableModel(
|
|
||||||
id=model_name,
|
|
||||||
provider=ModelProvider.OPENAI,
|
|
||||||
display_name=model_name,
|
|
||||||
description=f"Custom OpenAI-compatible model at {base_url}",
|
|
||||||
base_url=base_url,
|
|
||||||
capabilities=ModelCapabilities(
|
|
||||||
supports_tools=True,
|
|
||||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
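For orientation, a minimal sketch of calling the helper above for a local OpenAI-compatible endpoint; the model name and URL are placeholders, not values from this changeset:

```python
from application.core.model_configs import create_custom_openai_model

# Hypothetical local endpoint (e.g. Ollama) and model name, for illustration.
model = create_custom_openai_model("llama3:8b", "http://localhost:11434/v1")

assert model.id == "llama3:8b"
assert model.base_url == "http://localhost:11434/v1"
assert model.capabilities.supports_tools is True
```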
application/core/model_registry.py (new file, 385 lines)
@@ -0,0 +1,385 @@
"""Layered model registry.

Loads model catalogs from YAML files (built-in + operator-supplied),
groups them by provider name, then for each registered provider plugin
calls ``get_models`` to produce the final per-provider model list.

End-user BYOM (per-user model records in Postgres) is layered on top:
when a lookup arrives with a ``user_id``, the registry consults a
per-user cache first (loaded from the ``user_custom_models`` table on
miss) and falls through to the built-in catalog.

Cross-process invalidation: ``ModelRegistry`` is a per-process
singleton, so a CRUD write only evicts the cache in the process that
served it. Other gunicorn workers and Celery workers would otherwise
keep using a deleted/disabled/key-rotated BYOM record indefinitely.
``invalidate_user`` therefore both drops the local layer *and* bumps a
Redis-side version counter; other processes notice the bump on their
next access (after the local TTL window) and reload from Postgres. If
Redis is unreachable the per-process TTL still bounds staleness — pure
TTL semantics, no regression.
"""

from __future__ import annotations

import logging
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

from application.core.model_settings import AvailableModel
from application.core.model_yaml import (
    BUILTIN_MODELS_DIR,
    ProviderCatalog,
    load_model_yamls,
)

logger = logging.getLogger(__name__)

_USER_CACHE_TTL_SECONDS = 60.0
_USER_VERSION_KEY_PREFIX = "byom:registry_version:"


class ModelRegistry:
    """Singleton registry of available models."""

    _instance: Optional["ModelRegistry"] = None
    _initialized: bool = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if not ModelRegistry._initialized:
            self.models: Dict[str, AvailableModel] = {}
            self.default_model_id: Optional[str] = None
            # Per-user BYOM cache. Each entry is
            # ``(layer, version_at_load, loaded_at_monotonic)``:
            # * ``layer`` — {model_id: AvailableModel}
            # * ``version_at_load`` — Redis-side counter snapshot at
            #   reload time, or ``None`` if Redis was unreachable
            # * ``loaded_at_monotonic`` — for TTL bookkeeping
            # Populated lazily, evicted by TTL + cross-process
            # invalidation (see ``invalidate_user``).
            self._user_models: Dict[
                str,
                Tuple[Dict[str, AvailableModel], Optional[int], float],
            ] = {}
            self._load_models()
            ModelRegistry._initialized = True

    @classmethod
    def get_instance(cls) -> "ModelRegistry":
        return cls()

    @classmethod
    def reset(cls) -> None:
        """Clear the singleton. Intended for test fixtures."""
        cls._instance = None
        cls._initialized = False

    @classmethod
    def invalidate_user(cls, user_id: str) -> None:
        """Drop the cached per-user model layer for ``user_id``.

        Called by the BYOM REST routes after every create/update/delete.
        Two effects:

        * Local: pop the entry from this process's cache so the next
          lookup re-reads from Postgres immediately.
        * Cross-process: ``INCR`` a Redis-side version counter for this
          user. Other gunicorn/Celery processes notice the counter
          changed on their next TTL-driven recheck (see
          ``_user_models_for``) and reload. If Redis is unreachable we
          log and continue — local invalidation still happened, and
          peers fall back to TTL-only staleness bounds.
        """
        if cls._instance is not None:
            cls._instance._user_models.pop(user_id, None)
        try:
            from application.cache import get_redis_instance

            client = get_redis_instance()
            if client is not None:
                client.incr(_USER_VERSION_KEY_PREFIX + user_id)
        except Exception as e:
            logger.warning(
                "BYOM invalidate: failed to publish version bump for "
                "user %s (Redis unreachable?): %s",
                user_id,
                e,
            )

    @classmethod
    def _read_user_version(cls, user_id: str) -> Optional[int]:
        """Return the Redis-side invalidation counter for ``user_id``.

        ``0`` if the key has never been bumped; ``None`` if Redis is
        unreachable or the read failed (callers fall back to TTL-only
        staleness in that case).
        """
        try:
            from application.cache import get_redis_instance

            client = get_redis_instance()
            if client is None:
                return None
            raw = client.get(_USER_VERSION_KEY_PREFIX + user_id)
            if raw is None:
                return 0
            return int(raw)
        except Exception:
            return None

    def _load_models(self) -> None:
        from pathlib import Path

        from application.core.settings import settings
        from application.llm.providers import ALL_PROVIDERS

        directories = [BUILTIN_MODELS_DIR]
        operator_dir = getattr(settings, "MODELS_CONFIG_DIR", None)
        if operator_dir:
            op_path = Path(operator_dir)
            if not op_path.exists():
                logger.warning(
                    "MODELS_CONFIG_DIR=%s does not exist; no operator "
                    "model YAMLs will be loaded.",
                    operator_dir,
                )
            elif not op_path.is_dir():
                logger.warning(
                    "MODELS_CONFIG_DIR=%s is not a directory; no operator "
                    "model YAMLs will be loaded.",
                    operator_dir,
                )
            else:
                directories.append(op_path)

        catalogs = load_model_yamls(directories)

        # Validate every catalog targets a known plugin before doing any
        # registry work, so an unknown provider name in YAML aborts boot
        # with a clear error.
        plugin_names = {p.name for p in ALL_PROVIDERS}
        for c in catalogs:
            if c.provider not in plugin_names:
                raise ValueError(
                    f"{c.source_path}: YAML declares unknown provider "
                    f"{c.provider!r}; no Provider plugin is registered "
                    f"under that name. Known: {sorted(plugin_names)}"
                )

        catalogs_by_provider: Dict[str, List[ProviderCatalog]] = defaultdict(list)
        for c in catalogs:
            catalogs_by_provider[c.provider].append(c)

        self.models.clear()
        for provider in ALL_PROVIDERS:
            if not provider.is_enabled(settings):
                continue
            for model in provider.get_models(
                settings, catalogs_by_provider.get(provider.name, [])
            ):
                self.models[model.id] = model

        self.default_model_id = self._resolve_default(settings)

        logger.info(
            "ModelRegistry loaded %d models, default: %s",
            len(self.models),
            self.default_model_id,
        )

    def _resolve_default(self, settings) -> Optional[str]:
        if settings.LLM_NAME:
            for name in self._parse_model_names(settings.LLM_NAME):
                if name in self.models:
                    return name
            if settings.LLM_NAME in self.models:
                return settings.LLM_NAME

        if settings.LLM_PROVIDER and settings.API_KEY:
            for model_id, model in self.models.items():
                if model.provider.value == settings.LLM_PROVIDER:
                    return model_id

        if self.models:
            return next(iter(self.models.keys()))
        return None

    @staticmethod
    def _parse_model_names(llm_name: str) -> List[str]:
        if not llm_name:
            return []
        return [name.strip() for name in llm_name.split(",") if name.strip()]

    # Per-user (BYOM) layer

    def _user_models_for(self, user_id: str) -> Dict[str, AvailableModel]:
        """Return the user's BYOM models keyed by registry id (UUID).

        Loaded lazily from Postgres on first access; cached subject to
        a per-process TTL (``_USER_CACHE_TTL_SECONDS``) and a Redis-
        backed version counter for cross-process invalidation. The TTL
        bounds staleness even when Redis is unreachable, while the
        version stamp lets peers refresh without a DB read on the
        common case (no invalidation since last load). Decryption
        failures and DB errors yield an empty layer (logged) — the
        user simply doesn't see their custom models on this request,
        never a 500.
        """
        cached = self._user_models.get(user_id)
        now = time.monotonic()

        if cached is not None:
            layer, cached_version, loaded_at = cached
            if (now - loaded_at) < _USER_CACHE_TTL_SECONDS:
                return layer
            # TTL elapsed: peek at the cross-process counter. If it
            # matches what we saw at load time, no invalidation has
            # happened — extend the TTL without touching Postgres. If
            # Redis is unreachable (``current_version is None``) we
            # fall through to a real reload, which keeps staleness
            # bounded to the TTL.
            current_version = self._read_user_version(user_id)
            if (
                current_version is not None
                and cached_version is not None
                and current_version == cached_version
            ):
                self._user_models[user_id] = (layer, cached_version, now)
                return layer

        # Capture the counter *before* the DB read so a CRUD that lands
        # mid-reload doesn't get masked: the next access will see a
        # newer version and reload again.
        version_before_read = self._read_user_version(user_id)

        layer: Dict[str, AvailableModel] = {}
        try:
            from application.core.model_settings import (
                ModelCapabilities,
                ModelProvider,
            )
            from application.storage.db.repositories.user_custom_models import (
                UserCustomModelsRepository,
            )
            from application.storage.db.session import db_readonly

            with db_readonly() as conn:
                repo = UserCustomModelsRepository(conn)
                rows = repo.list_for_user(user_id)
                for row in rows:
                    api_key = repo._decrypt_api_key(
                        row.get("api_key_encrypted", ""), user_id
                    )
                    if not api_key:
                        # SECURITY: do NOT register an unroutable BYOM
                        # record. If we did, LLMCreator would fall back
                        # to the caller-passed api_key (settings.API_KEY
                        # for openai_compatible) and POST it to the
                        # user-supplied base_url — leaking the instance
                        # credential to the user's chosen endpoint.
                        # Most likely cause is ENCRYPTION_SECRET_KEY
                        # having rotated; user must re-save the model.
                        logger.warning(
                            "user_custom_models: skipping model %s for "
                            "user %s — api_key could not be decrypted "
                            "(rotated ENCRYPTION_SECRET_KEY?). Re-save "
                            "the model to recover.",
                            row.get("id"),
                            user_id,
                        )
                        continue
                    caps_raw = row.get("capabilities") or {}
                    # Stored attachments may be aliases (``image``) or
                    # raw MIME types. Built-in YAML models expand at
                    # load time; mirror that here so downstream MIME-
                    # type comparisons (handlers/base.prepare_messages)
                    # match concrete types like ``image/png`` rather
                    # than the bare alias.
                    from application.core.model_yaml import (
                        expand_attachments_lenient,
                    )

                    raw_attachments = caps_raw.get("attachments", []) or []
                    expanded_attachments = expand_attachments_lenient(
                        raw_attachments,
                        f"user_custom_models[user={user_id}, model={row.get('id')}]",
                    )
                    caps = ModelCapabilities(
                        supports_tools=bool(caps_raw.get("supports_tools", False)),
                        supports_structured_output=bool(
                            caps_raw.get("supports_structured_output", False)
                        ),
                        supports_streaming=bool(
                            caps_raw.get("supports_streaming", True)
                        ),
                        supported_attachment_types=expanded_attachments,
                        context_window=int(
                            caps_raw.get("context_window") or 128000
                        ),
                    )
                    model_id = str(row["id"])
                    layer[model_id] = AvailableModel(
                        id=model_id,
                        provider=ModelProvider.OPENAI_COMPATIBLE,
                        display_name=row["display_name"],
                        description=row.get("description") or "",
                        capabilities=caps,
                        enabled=bool(row.get("enabled", True)),
                        base_url=row["base_url"],
                        upstream_model_id=row["upstream_model_id"],
                        source="user",
                        api_key=api_key,
                    )
        except Exception as e:
            logger.warning(
                "user_custom_models: failed to load layer for user %s: %s",
                user_id,
                e,
            )
            layer = {}

        self._user_models[user_id] = (layer, version_before_read, now)
        return layer

    # Lookup API. ``user_id`` enables the BYOM per-user layer; without
    # it, callers see only the built-in + operator catalog.

    def get_model(
        self, model_id: str, user_id: Optional[str] = None
    ) -> Optional[AvailableModel]:
        if user_id:
            user_layer = self._user_models_for(user_id)
            if model_id in user_layer:
                return user_layer[model_id]
        return self.models.get(model_id)

    def get_all_models(
        self, user_id: Optional[str] = None
    ) -> List[AvailableModel]:
        out = list(self.models.values())
        if user_id:
            out.extend(self._user_models_for(user_id).values())
        return out

    def get_enabled_models(
        self, user_id: Optional[str] = None
    ) -> List[AvailableModel]:
        out = [m for m in self.models.values() if m.enabled]
        if user_id:
            out.extend(
                m for m in self._user_models_for(user_id).values() if m.enabled
            )
        return out

    def model_exists(
        self, model_id: str, user_id: Optional[str] = None
    ) -> bool:
        if user_id and model_id in self._user_models_for(user_id):
            return True
        return model_id in self.models
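A short usage sketch of the registry above; the model and user ids are hypothetical placeholders:

```python
from application.core.model_registry import ModelRegistry

registry = ModelRegistry.get_instance()

# Built-in/operator catalog lookup (no per-user layer consulted).
builtin = registry.get_model("gpt-5.1")

# With user_id, BYOM records (keyed by UUID) are checked first,
# then the lookup falls through to the built-in catalog.
byom = registry.get_model(
    "8c9f4e2a-1111-2222-3333-444455556666", user_id="user-123"
)

# After any BYOM create/update/delete, routes evict this process's
# layer and bump the Redis counter so peer workers reload.
ModelRegistry.invalidate_user("user-123")
```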
application/core/model_settings.py
@@ -5,9 +5,16 @@ from typing import Dict, List, Optional
 
 logger = logging.getLogger(__name__)
 
+# Re-exported here so existing call sites (and tests) that do
+# ``from application.core.model_settings import ModelRegistry`` keep
+# working. The implementation lives in ``application/core/model_registry.py``.
+# Imported lazily inside ``__getattr__`` to avoid an import cycle with
+# ``model_yaml`` → ``model_settings`` (this file).
+
 
 class ModelProvider(str, Enum):
     OPENAI = "openai"
+    OPENAI_COMPATIBLE = "openai_compatible"
     OPENROUTER = "openrouter"
     AZURE_OPENAI = "azure_openai"
     ANTHROPIC = "anthropic"
@@ -41,11 +48,21 @@ class AvailableModel:
     capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
     enabled: bool = True
     base_url: Optional[str] = None
+    # User-facing label distinct from dispatch provider (e.g. mistral
+    # routed through openai_compatible).
+    display_provider: Optional[str] = None
+    # Sent in the API call's ``model`` field; falls back to ``self.id``
+    # for built-ins where id IS the upstream name.
+    upstream_model_id: Optional[str] = None
+    # "builtin" for catalog YAMLs, "user" for BYOM records.
+    source: str = "builtin"
+    # Decrypted/resolved at registry-merge time. Never serialized.
+    api_key: Optional[str] = field(default=None, repr=False, compare=False)
 
     def to_dict(self) -> Dict:
         result = {
             "id": self.id,
-            "provider": self.provider.value,
+            "provider": self.display_provider or self.provider.value,
             "display_name": self.display_name,
             "description": self.description,
             "supported_attachment_types": self.capabilities.supported_attachment_types,
@@ -54,261 +71,21 @@ class AvailableModel:
             "supports_streaming": self.capabilities.supports_streaming,
             "context_window": self.capabilities.context_window,
             "enabled": self.enabled,
+            "source": self.source,
         }
         if self.base_url:
             result["base_url"] = self.base_url
         return result
 
 
-class ModelRegistry:
-    _instance = None
-    _initialized = False
-
-    def __new__(cls):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-        return cls._instance
-
-    def __init__(self):
-        if not ModelRegistry._initialized:
-            self.models: Dict[str, AvailableModel] = {}
-            self.default_model_id: Optional[str] = None
-            self._load_models()
-            ModelRegistry._initialized = True
-
-    @classmethod
-    def get_instance(cls) -> "ModelRegistry":
-        return cls()
-
-    def _load_models(self):
-        from application.core.settings import settings
-
-        self.models.clear()
-
-        # Skip DocsGPT model if using custom OpenAI-compatible endpoint
-        if not settings.OPENAI_BASE_URL:
-            self._add_docsgpt_models(settings)
-        if (
-            settings.OPENAI_API_KEY
-            or (settings.LLM_PROVIDER == "openai" and settings.API_KEY)
-            or settings.OPENAI_BASE_URL
-        ):
-            self._add_openai_models(settings)
-        if settings.OPENAI_API_BASE or (
-            settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
-        ):
-            self._add_azure_openai_models(settings)
-        if settings.ANTHROPIC_API_KEY or (
-            settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
-        ):
-            self._add_anthropic_models(settings)
-        if settings.GOOGLE_API_KEY or (
-            settings.LLM_PROVIDER == "google" and settings.API_KEY
-        ):
-            self._add_google_models(settings)
-        if settings.GROQ_API_KEY or (
-            settings.LLM_PROVIDER == "groq" and settings.API_KEY
-        ):
-            self._add_groq_models(settings)
-        if settings.OPEN_ROUTER_API_KEY or (
-            settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
-        ):
-            self._add_openrouter_models(settings)
-        if settings.NOVITA_API_KEY or (
-            settings.LLM_PROVIDER == "novita" and settings.API_KEY
-        ):
-            self._add_novita_models(settings)
-        if settings.HUGGINGFACE_API_KEY or (
-            settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
-        ):
-            self._add_huggingface_models(settings)
-        # Default model selection
-        if settings.LLM_NAME:
-            # Parse LLM_NAME (may be comma-separated)
-            model_names = self._parse_model_names(settings.LLM_NAME)
-            # First model in the list becomes default
-            for model_name in model_names:
-                if model_name in self.models:
-                    self.default_model_id = model_name
-                    break
-            # Backward compat: try exact match if no parsed model found
-            if not self.default_model_id and settings.LLM_NAME in self.models:
-                self.default_model_id = settings.LLM_NAME
-
-        if not self.default_model_id:
-            if settings.LLM_PROVIDER and settings.API_KEY:
-                for model_id, model in self.models.items():
-                    if model.provider.value == settings.LLM_PROVIDER:
-                        self.default_model_id = model_id
-                        break
-
-        if not self.default_model_id and self.models:
-            self.default_model_id = next(iter(self.models.keys()))
-        logger.info(
-            f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
-        )
-
-    def _add_openai_models(self, settings):
-        from application.core.model_configs import (
-            OPENAI_MODELS,
-            create_custom_openai_model,
-        )
-
-        # Check if using local OpenAI-compatible endpoint (Ollama, LM Studio, etc.)
-        using_local_endpoint = bool(
-            settings.OPENAI_BASE_URL and settings.OPENAI_BASE_URL.strip()
-        )
-
-        if using_local_endpoint:
-            # When OPENAI_BASE_URL is set, ONLY register custom models from LLM_NAME
-            # Do NOT add standard OpenAI models (gpt-5.1, etc.)
-            if settings.LLM_NAME:
-                model_names = self._parse_model_names(settings.LLM_NAME)
-                for model_name in model_names:
-                    custom_model = create_custom_openai_model(
-                        model_name, settings.OPENAI_BASE_URL
-                    )
-                    self.models[model_name] = custom_model
-                    logger.info(
-                        f"Registered custom OpenAI model: {model_name} at {settings.OPENAI_BASE_URL}"
-                    )
-        else:
-            # Standard OpenAI API usage - add standard models if API key is valid
-            if settings.OPENAI_API_KEY:
-                for model in OPENAI_MODELS:
-                    self.models[model.id] = model
-
-    def _add_azure_openai_models(self, settings):
-        from application.core.model_configs import AZURE_OPENAI_MODELS
-
-        if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
-            for model in AZURE_OPENAI_MODELS:
-                if model.id == settings.LLM_NAME:
-                    self.models[model.id] = model
-                    return
-        for model in AZURE_OPENAI_MODELS:
-            self.models[model.id] = model
-
-    def _add_anthropic_models(self, settings):
-        from application.core.model_configs import ANTHROPIC_MODELS
-
-        if settings.ANTHROPIC_API_KEY:
-            for model in ANTHROPIC_MODELS:
-                self.models[model.id] = model
-            return
-        if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
-            for model in ANTHROPIC_MODELS:
-                if model.id == settings.LLM_NAME:
-                    self.models[model.id] = model
-                    return
-        for model in ANTHROPIC_MODELS:
-            self.models[model.id] = model
-
-    def _add_google_models(self, settings):
-        from application.core.model_configs import GOOGLE_MODELS
-
-        if settings.GOOGLE_API_KEY:
-            for model in GOOGLE_MODELS:
-                self.models[model.id] = model
-            return
-        if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
-            for model in GOOGLE_MODELS:
-                if model.id == settings.LLM_NAME:
-                    self.models[model.id] = model
-                    return
-        for model in GOOGLE_MODELS:
-            self.models[model.id] = model
-
-    def _add_groq_models(self, settings):
-        from application.core.model_configs import GROQ_MODELS
-
-        if settings.GROQ_API_KEY:
-            for model in GROQ_MODELS:
-                self.models[model.id] = model
-            return
-        if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
-            for model in GROQ_MODELS:
-                if model.id == settings.LLM_NAME:
-                    self.models[model.id] = model
-                    return
-        for model in GROQ_MODELS:
-            self.models[model.id] = model
-
-    def _add_openrouter_models(self, settings):
-        from application.core.model_configs import OPENROUTER_MODELS
-
-        if settings.OPEN_ROUTER_API_KEY:
-            for model in OPENROUTER_MODELS:
-                self.models[model.id] = model
-            return
-        if settings.LLM_PROVIDER == "openrouter" and settings.LLM_NAME:
-            for model in OPENROUTER_MODELS:
-                if model.id == settings.LLM_NAME:
-                    self.models[model.id] = model
-                    return
-        for model in OPENROUTER_MODELS:
-            self.models[model.id] = model
-
-    def _add_novita_models(self, settings):
-        from application.core.model_configs import NOVITA_MODELS
-
-        if settings.NOVITA_API_KEY:
-            for model in NOVITA_MODELS:
-                self.models[model.id] = model
-            return
-        if settings.LLM_PROVIDER == "novita" and settings.LLM_NAME:
-            for model in NOVITA_MODELS:
-                if model.id == settings.LLM_NAME:
-                    self.models[model.id] = model
-                    return
-        for model in NOVITA_MODELS:
-            self.models[model.id] = model
-
-    def _add_docsgpt_models(self, settings):
-        model_id = "docsgpt-local"
-        model = AvailableModel(
-            id=model_id,
-            provider=ModelProvider.DOCSGPT,
-            display_name="DocsGPT Model",
-            description="Local model",
-            capabilities=ModelCapabilities(
-                supports_tools=False,
-                supported_attachment_types=[],
-            ),
-        )
-        self.models[model_id] = model
-
-    def _add_huggingface_models(self, settings):
-        model_id = "huggingface-local"
-        model = AvailableModel(
-            id=model_id,
-            provider=ModelProvider.HUGGINGFACE,
-            display_name="Hugging Face Model",
-            description="Local Hugging Face model",
-            capabilities=ModelCapabilities(
-                supports_tools=False,
-                supported_attachment_types=[],
-            ),
-        )
-        self.models[model_id] = model
-
-    def _parse_model_names(self, llm_name: str) -> List[str]:
-        """
-        Parse LLM_NAME which may contain comma-separated model names.
-        E.g., 'deepseek-r1:1.5b,gemma:2b' -> ['deepseek-r1:1.5b', 'gemma:2b']
-        """
-        if not llm_name:
-            return []
-        return [name.strip() for name in llm_name.split(",") if name.strip()]
-
-    def get_model(self, model_id: str) -> Optional[AvailableModel]:
-        return self.models.get(model_id)
-
-    def get_all_models(self) -> List[AvailableModel]:
-        return list(self.models.values())
-
-    def get_enabled_models(self) -> List[AvailableModel]:
-        return [m for m in self.models.values() if m.enabled]
-
-    def model_exists(self, model_id: str) -> bool:
-        return model_id in self.models
+def __getattr__(name):
+    """Lazy re-export of ``ModelRegistry`` from ``model_registry.py``.
+
+    Done lazily to avoid an import cycle: ``model_registry`` imports
+    ``model_yaml`` which imports the dataclasses from this file.
+    """
+    if name == "ModelRegistry":
+        from application.core.model_registry import ModelRegistry as _MR
+
+        return _MR
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
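The hunk above leans on PEP 562 module-level `__getattr__` (Python 3.7+): an attribute missing from the module is routed through the module's `__getattr__` on lookup, so legacy import sites keep working unchanged. A minimal sketch of the behavior:

```python
# This call site looks identical before and after the refactor; the
# import now triggers model_settings.__getattr__, which loads the real
# class from model_registry.py on first use and breaks the
# model_yaml -> model_settings import cycle.
from application.core.model_settings import ModelRegistry

registry = ModelRegistry.get_instance()
```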
@@ -1,47 +1,59 @@
 from typing import Any, Dict, Optional
 
-from application.core.model_settings import ModelRegistry
+from application.core.model_registry import ModelRegistry
 
 
 def get_api_key_for_provider(provider: str) -> Optional[str]:
-    """Get the appropriate API key for a provider"""
+    """Get the appropriate API key for a provider.
+
+    Delegates to the provider plugin's ``get_api_key``. Falls back to the
+    generic ``settings.API_KEY`` for unknown providers.
+    """
     from application.core.settings import settings
+    from application.llm.providers import PROVIDERS_BY_NAME
 
-    provider_key_map = {
-        "openai": settings.OPENAI_API_KEY,
-        "openrouter": settings.OPEN_ROUTER_API_KEY,
-        "novita": settings.NOVITA_API_KEY,
-        "anthropic": settings.ANTHROPIC_API_KEY,
-        "google": settings.GOOGLE_API_KEY,
-        "groq": settings.GROQ_API_KEY,
-        "huggingface": settings.HUGGINGFACE_API_KEY,
-        "azure_openai": settings.API_KEY,
-        "docsgpt": None,
-        "llama.cpp": None,
-    }
-
-    provider_key = provider_key_map.get(provider)
-    if provider_key:
-        return provider_key
+    plugin = PROVIDERS_BY_NAME.get(provider)
+    if plugin is not None:
+        key = plugin.get_api_key(settings)
+        if key:
+            return key
     return settings.API_KEY
 
 
-def get_all_available_models() -> Dict[str, Dict[str, Any]]:
-    """Get all available models with metadata for API response"""
+def get_all_available_models(
+    user_id: Optional[str] = None,
+) -> Dict[str, Dict[str, Any]]:
+    """Get all available models with metadata for API response.
+
+    When ``user_id`` is supplied, the user's BYOM custom-model records
+    are merged into the result alongside the built-in catalog.
+    """
     registry = ModelRegistry.get_instance()
-    return {model.id: model.to_dict() for model in registry.get_enabled_models()}
+    return {
+        model.id: model.to_dict()
+        for model in registry.get_enabled_models(user_id=user_id)
+    }
 
 
-def validate_model_id(model_id: str) -> bool:
-    """Check if a model ID exists in registry"""
+def validate_model_id(model_id: str, user_id: Optional[str] = None) -> bool:
+    """Check if a model ID exists in registry.
+
+    ``user_id`` enables resolution of per-user BYOM records (UUIDs).
+    Without it, only built-in catalog ids resolve.
+    """
     registry = ModelRegistry.get_instance()
-    return registry.model_exists(model_id)
+    return registry.model_exists(model_id, user_id=user_id)
 
 
-def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
-    """Get capabilities for a specific model"""
+def get_model_capabilities(
+    model_id: str, user_id: Optional[str] = None
+) -> Optional[Dict[str, Any]]:
+    """Get capabilities for a specific model.
+
+    ``user_id`` enables resolution of per-user BYOM records.
+    """
     registry = ModelRegistry.get_instance()
-    model = registry.get_model(model_id)
+    model = registry.get_model(model_id, user_id=user_id)
     if model:
         return {
             "supported_attachment_types": model.capabilities.supported_attachment_types,
@@ -58,36 +70,68 @@ def get_default_model_id() -> str:
     return registry.default_model_id
 
 
-def get_provider_from_model_id(model_id: str) -> Optional[str]:
-    """Get the provider name for a given model_id"""
+def get_provider_from_model_id(
+    model_id: str, user_id: Optional[str] = None
+) -> Optional[str]:
+    """Get the provider name for a given model_id.
+
+    ``user_id`` enables resolution of per-user BYOM records (UUIDs).
+    Without it, BYOM model ids return ``None`` and the caller falls
+    back to the deployment default.
+    """
     registry = ModelRegistry.get_instance()
-    model = registry.get_model(model_id)
+    model = registry.get_model(model_id, user_id=user_id)
     if model:
         return model.provider.value
     return None
 
 
-def get_token_limit(model_id: str) -> int:
-    """
-    Get context window (token limit) for a model.
-    Returns model's context_window or default 128000 if model not found.
-    """
+def get_token_limit(model_id: str, user_id: Optional[str] = None) -> int:
+    """Get context window (token limit) for a model.
+
+    Returns the model's ``context_window`` or ``DEFAULT_LLM_TOKEN_LIMIT``
+    if not found. ``user_id`` enables resolution of per-user BYOM records.
+    """
     from application.core.settings import settings
 
     registry = ModelRegistry.get_instance()
-    model = registry.get_model(model_id)
+    model = registry.get_model(model_id, user_id=user_id)
     if model:
         return model.capabilities.context_window
     return settings.DEFAULT_LLM_TOKEN_LIMIT
 
 
-def get_base_url_for_model(model_id: str) -> Optional[str]:
-    """
-    Get the custom base_url for a specific model if configured.
-    Returns None if no custom base_url is set.
-    """
+def get_base_url_for_model(
+    model_id: str, user_id: Optional[str] = None
+) -> Optional[str]:
+    """Get the custom base_url for a specific model if configured.
+
+    Returns ``None`` if no custom base_url is set. ``user_id`` enables
+    resolution of per-user BYOM records.
+    """
     registry = ModelRegistry.get_instance()
-    model = registry.get_model(model_id)
+    model = registry.get_model(model_id, user_id=user_id)
     if model:
         return model.base_url
     return None
+
+
+def get_api_key_for_model(
+    model_id: str, user_id: Optional[str] = None
+) -> Optional[str]:
+    """Resolve the API key to use when invoking ``model_id``.
+
+    Priority:
+    1. The model record's own ``api_key`` (BYOM records and
+       ``openai_compatible`` YAMLs populate this).
+    2. The provider plugin's settings-based key.
+
+    ``user_id`` enables resolution of per-user BYOM records.
+    """
+    registry = ModelRegistry.get_instance()
+    model = registry.get_model(model_id, user_id=user_id)
+    if model is not None and model.api_key:
+        return model.api_key
+    if model is not None:
+        return get_api_key_for_provider(model.provider.value)
+    return None
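A sketch of the resolution order `get_api_key_for_model` implements; the module path in the import and the ids shown are assumptions for illustration, not taken from this changeset:

```python
from application.core.model_utils import (  # module path assumed
    get_api_key_for_model,
)

# 1. A BYOM record carrying its own (decrypted) api_key wins outright.
key = get_api_key_for_model(
    "8c9f4e2a-1111-2222-3333-444455556666", user_id="user-123"
)

# 2. A catalog model with no per-record key falls through to the
#    provider plugin's settings-based key.
key = get_api_key_for_model("claude-3-opus")

# 3. An id that doesn't resolve at all yields None.
assert get_api_key_for_model("no-such-model") is None
```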
application/core/model_yaml.py (new file, 358 lines)
@@ -0,0 +1,358 @@
"""YAML loader for model catalog files under ``application/core/models/``.

Each ``*.yaml`` file declares one provider's static model catalog. Files
are validated with Pydantic at load time; any parse, schema, or alias
error aborts startup with the offending file path in the message.

For most providers, one YAML maps to one catalog. The
``openai_compatible`` provider is special: each YAML file represents a
distinct logical endpoint (Mistral, Together, Ollama, ...) with its own
``api_key_env`` and ``base_url``. The loader returns a flat list so the
registry can distinguish multiple files with the same ``provider:`` value.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Dict, List, Optional, Sequence

import yaml
from pydantic import BaseModel, ConfigDict, Field, field_validator

from application.core.model_settings import (
    AvailableModel,
    ModelCapabilities,
    ModelProvider,
)

logger = logging.getLogger(__name__)

BUILTIN_MODELS_DIR = Path(__file__).parent / "models"
DEFAULTS_FILENAME = "_defaults.yaml"


class _DefaultsFile(BaseModel):
    """Schema for ``_defaults.yaml``. Currently just attachment aliases."""

    model_config = ConfigDict(extra="forbid")

    attachment_aliases: Dict[str, List[str]] = Field(default_factory=dict)


class _CapabilityFields(BaseModel):
    """Capability fields shared between provider ``defaults:`` and per-model overrides.

    All fields are optional so a per-model override can selectively replace
    a single field from the provider-level defaults.
    """

    model_config = ConfigDict(extra="forbid")

    supports_tools: Optional[bool] = None
    supports_structured_output: Optional[bool] = None
    supports_streaming: Optional[bool] = None
    attachments: Optional[List[str]] = None
    context_window: Optional[int] = None
    input_cost_per_token: Optional[float] = None
    output_cost_per_token: Optional[float] = None


class _ModelEntry(_CapabilityFields):
    """Schema for one model row inside a YAML's ``models:`` list."""

    id: str
    display_name: Optional[str] = None
    description: str = ""
    enabled: bool = True
    base_url: Optional[str] = None
    aliases: List[str] = Field(default_factory=list)

    @field_validator("id")
    @classmethod
    def _id_nonempty(cls, v: str) -> str:
        if not v or not v.strip():
            raise ValueError("model id must be a non-empty string")
        return v


class _ProviderFile(BaseModel):
    """Schema for one ``<provider>.yaml`` catalog file."""

    model_config = ConfigDict(extra="forbid")

    provider: str
    defaults: _CapabilityFields = Field(default_factory=_CapabilityFields)
    models: List[_ModelEntry] = Field(default_factory=list)
    # openai_compatible metadata. Optional for other providers.
    display_provider: Optional[str] = None
    api_key_env: Optional[str] = None
    base_url: Optional[str] = None


class ProviderCatalog(BaseModel):
    """One YAML file's parsed contents, ready for the registry.

    For most providers, multiple catalogs with the same ``provider`` get
    merged later by the registry. The ``openai_compatible`` provider is
    the exception: each catalog is treated as a distinct endpoint, with
    its own ``api_key_env`` and ``base_url``.
    """

    provider: str
    models: List[AvailableModel]
    source_path: Optional[Path] = None
    display_provider: Optional[str] = None
    api_key_env: Optional[str] = None
    base_url: Optional[str] = None

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ModelYAMLError(ValueError):
    """Raised when a model YAML fails parsing, schema, or alias validation."""


def _expand_attachments(
    attachments: Sequence[str], aliases: Dict[str, List[str]], source: str
) -> List[str]:
    """Resolve attachment shorthands (``image``, ``pdf``) to MIME types.

    Raw MIME-typed entries (containing ``/``) pass through unchanged.
    Unknown aliases raise ``ModelYAMLError``.
    """
    expanded: List[str] = []
    seen: set = set()
    for entry in attachments:
        if "/" in entry:
            if entry not in seen:
                expanded.append(entry)
                seen.add(entry)
            continue
        if entry not in aliases:
            valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
            raise ModelYAMLError(
                f"{source}: unknown attachment alias '{entry}'. "
                f"Valid aliases: {valid}. "
                "(Or use a raw MIME type like 'image/png'.)"
            )
        for mime in aliases[entry]:
            if mime not in seen:
                expanded.append(mime)
                seen.add(mime)
    return expanded


def _load_defaults(directory: Path) -> Dict[str, List[str]]:
    """Load ``_defaults.yaml`` from ``directory`` if it exists."""
    path = directory / DEFAULTS_FILENAME
    if not path.exists():
        return {}
    try:
        raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as e:
        raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
    try:
        parsed = _DefaultsFile.model_validate(raw)
    except Exception as e:
        raise ModelYAMLError(f"{path}: schema error: {e}") from e
    return parsed.attachment_aliases


def _resolve_provider_enum(name: str, source: Path) -> ModelProvider:
    try:
        return ModelProvider(name)
    except ValueError as e:
        valid = ", ".join(p.value for p in ModelProvider)
        raise ModelYAMLError(
            f"{source}: unknown provider '{name}'. Valid: {valid}"
        ) from e


def _build_model(
    entry: _ModelEntry,
    defaults: _CapabilityFields,
    provider: ModelProvider,
    aliases: Dict[str, List[str]],
    source: Path,
    display_provider: Optional[str] = None,
) -> AvailableModel:
    """Merge defaults + per-model overrides into a final ``AvailableModel``."""

    def pick(field_name: str, fallback):
        v = getattr(entry, field_name)
        if v is not None:
            return v
        d = getattr(defaults, field_name)
        if d is not None:
            return d
        return fallback

    raw_attachments = entry.attachments
    if raw_attachments is None:
        raw_attachments = defaults.attachments
    if raw_attachments is None:
        raw_attachments = []
    expanded = _expand_attachments(
        raw_attachments, aliases, f"{source} [model={entry.id}]"
    )

    caps = ModelCapabilities(
        supports_tools=pick("supports_tools", False),
        supports_structured_output=pick("supports_structured_output", False),
        supports_streaming=pick("supports_streaming", True),
        supported_attachment_types=expanded,
        context_window=pick("context_window", 128000),
        input_cost_per_token=pick("input_cost_per_token", None),
        output_cost_per_token=pick("output_cost_per_token", None),
    )

    return AvailableModel(
        id=entry.id,
        provider=provider,
        display_name=entry.display_name or entry.id,
        description=entry.description,
        capabilities=caps,
        enabled=entry.enabled,
        base_url=entry.base_url,
        display_provider=display_provider,
    )


def _load_one_yaml(
    path: Path, aliases: Dict[str, List[str]]
) -> ProviderCatalog:
    try:
        raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError as e:
        raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
    try:
        parsed = _ProviderFile.model_validate(raw)
    except Exception as e:
        raise ModelYAMLError(f"{path}: schema error: {e}") from e

    provider_enum = _resolve_provider_enum(parsed.provider, path)
    models = [
        _build_model(
            entry,
            parsed.defaults,
            provider_enum,
            aliases,
            path,
            display_provider=parsed.display_provider,
        )
        for entry in parsed.models
    ]

    return ProviderCatalog(
        provider=parsed.provider,
        models=models,
        source_path=path,
        display_provider=parsed.display_provider,
        api_key_env=parsed.api_key_env,
        base_url=parsed.base_url,
    )


_BUILTIN_ALIASES_CACHE: Optional[Dict[str, List[str]]] = None


def builtin_attachment_aliases() -> Dict[str, List[str]]:
    """Return the built-in attachment alias map from ``_defaults.yaml``.

    Cached after first read so repeat calls are cheap.
    """
    global _BUILTIN_ALIASES_CACHE
    if _BUILTIN_ALIASES_CACHE is None:
        _BUILTIN_ALIASES_CACHE = _load_defaults(BUILTIN_MODELS_DIR)
    return _BUILTIN_ALIASES_CACHE


def resolve_attachment_alias(alias: str) -> List[str]:
    """Resolve a single attachment alias (e.g. ``"image"``) to its
    canonical MIME-type list. Raises ``ModelYAMLError`` if unknown.
    """
    aliases = builtin_attachment_aliases()
    if alias not in aliases:
        valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
        raise ModelYAMLError(
            f"Unknown attachment alias '{alias}'. Valid: {valid}"
        )
    return list(aliases[alias])


def expand_attachments_lenient(
    attachments: Sequence[str], source: str
) -> List[str]:
    """Expand attachment aliases to MIME types, tolerating unknowns.

    Mirrors ``_expand_attachments`` but logs+skips unknown aliases
    rather than raising. Used for runtime call sites (BYOM registry
    load) where an operator-side alias-map edit must not drop the
    entire user's BYOM layer; the strict raise still happens at the
    API validation boundary.
    """
    aliases = builtin_attachment_aliases()
    expanded: List[str] = []
    seen: set = set()
    for entry in attachments:
        if "/" in entry:
            if entry not in seen:
                expanded.append(entry)
                seen.add(entry)
            continue
        mime_list = aliases.get(entry)
        if mime_list is None:
            logger.warning(
                "%s: skipping unknown attachment alias %r", source, entry,
            )
            continue
        for mime in mime_list:
            if mime not in seen:
                expanded.append(mime)
                seen.add(mime)
    return expanded


def load_model_yamls(directories: Sequence[Path]) -> List[ProviderCatalog]:
    """Load every ``*.yaml`` file (excluding ``_defaults.yaml``) under each
    directory in order and return a flat list of catalogs.

    Caller is responsible for merging multiple catalogs that target the
    same provider plugin. The flat-list shape lets ``openai_compatible``
    keep each file separate (one logical endpoint per file).

    When the same model ``id`` appears in more than one YAML across the
    directory list, a warning is logged. Order in the returned list
    preserves load order, so the registry's "later wins" merge gives the
    later directory's definition.
    """
    catalogs: List[ProviderCatalog] = []
    seen_ids: Dict[str, Path] = {}

    aliases: Dict[str, List[str]] = {}
    for d in directories:
        if not d or not d.exists():
            continue
        aliases.update(_load_defaults(d))

    for d in directories:
        if not d or not d.exists():
            continue
        for path in sorted(d.glob("*.yaml")):
            if path.name == DEFAULTS_FILENAME:
                continue
            catalog = _load_one_yaml(path, aliases)
            catalogs.append(catalog)
            for m in catalog.models:
                prior = seen_ids.get(m.id)
                if prior is not None and prior != path:
                    logger.warning(
                        "Model id %r redefined: %s overrides %s (later wins)",
                        m.id,
                        path,
                        prior,
                    )
                seen_ids[m.id] = path

    return catalogs
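A worked example of the alias expansion implemented above, calling the internal helper directly and using a hypothetical alias map (the real map lives in `_defaults.yaml`, which this changeset doesn't show):

```python
from application.core.model_yaml import _expand_attachments

aliases = {"image": ["image/png", "image/jpeg", "image/webp"]}  # assumed map

expanded = _expand_attachments(
    ["image", "application/pdf", "image/png"], aliases, "demo.yaml"
)
# Aliases expand in order, raw MIME types pass through unchanged,
# and duplicates are dropped:
assert expanded == [
    "image/png",
    "image/jpeg",
    "image/webp",
    "application/pdf",
]
```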
application/core/models/README.md (new file, 213 lines)
@@ -0,0 +1,213 @@
|
# Model catalogs

Each `*.yaml` file in this directory declares one provider's model
catalog. The registry loads every YAML at boot and joins it to the
matching provider plugin under `application/llm/providers/`.

To add or edit models, you almost always only touch a YAML here — no
Python code required.

## Add a model to an existing provider

Open the provider's YAML (e.g. `anthropic.yaml`) and append two lines
under `models:`:

```yaml
models:
  - id: claude-3-7-sonnet
    display_name: Claude 3.7 Sonnet
```

Capabilities default to the provider's `defaults:` block. Override
per-model only when needed:

```yaml
  - id: claude-3-7-sonnet
    display_name: Claude 3.7 Sonnet
    context_window: 500000
```

Restart the app. The new model appears in `/api/models`.
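To sanity-check that the restart picked the new entry up, an illustrative check (assuming the API is on the local dev port, and substituting the id you just added):

```bash
# The endpoint is real; the grep target is whatever id you appended.
curl -s http://localhost:7091/api/models | grep claude-3-7-sonnet
```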
> The model `id` is what gets stored in agent / workflow records. Once
> users start picking the model, **don't rename it** — agent and
> workflow rows reference it as a free-form string and silently fall
> back to the system default if the id disappears.

## Add an OpenAI-compatible provider (zero Python)

Drop a YAML in this directory (or in your `MODELS_CONFIG_DIR`) that uses
the `openai_compatible` plugin. Set the env var named in `api_key_env`
and you're done — no Python, no settings.py edit, no LLMCreator change:

```yaml
# mistral.yaml
provider: openai_compatible
display_provider: mistral        # shown in /api/models response
api_key_env: MISTRAL_API_KEY     # env var the plugin reads at boot
base_url: https://api.mistral.ai/v1
defaults:
  supports_tools: true
  context_window: 128000
models:
  - id: mistral-large-latest
    display_name: Mistral Large
  - id: mistral-small-latest
    display_name: Mistral Small
```

`MISTRAL_API_KEY=sk-... ; restart` — Mistral models appear in
`/api/models` with `provider: "mistral"`. They route through the OpenAI
wire format (it's `OpenAILLM` under the hood) but with Mistral's
endpoint and key.

Multiple `openai_compatible` YAMLs coexist: each file is one logical
endpoint with its own `api_key_env` and `base_url`. Drop in
`together.yaml`, `fireworks.yaml`, etc. side by side. If an env var
isn't set, that catalog is silently skipped at boot (logged at INFO) —
no error.

Working example: `examples/mistral.yaml.example`. Files inside
`examples/` aren't loaded by the registry; the glob only picks up
`*.yaml` at the top level.

## Add a provider with its own SDK

For a provider that doesn't speak OpenAI's wire format, add one Python
file to `application/llm/providers/<name>.py`:

```python
from application.llm.providers.base import Provider
from application.llm.my_provider import MyLLM


class MyProvider(Provider):
    name = "my_provider"
    llm_class = MyLLM

    def get_api_key(self, settings):
        return settings.MY_PROVIDER_API_KEY
```

Register it in `application/llm/providers/__init__.py` (one line in
`ALL_PROVIDERS`), add `MY_PROVIDER_API_KEY` to `settings.py`, and create
`my_provider.yaml` here with the model catalog.
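The registration step looks roughly like this (a sketch only: check the actual shape of `ALL_PROVIDERS` and how `PROVIDERS_BY_NAME` is built in the real `__init__.py` before copying):

```python
# application/llm/providers/__init__.py (sketch, one added import + one added entry)
from application.llm.providers.my_provider import MyProvider

ALL_PROVIDERS = [
    # ... existing provider classes ...
    MyProvider,
]

# LLMCreator resolves plugins by their `name` attribute; whether the map
# holds classes or instances should mirror whatever the module already does.
PROVIDERS_BY_NAME = {p.name: p for p in ALL_PROVIDERS}
```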
## Schema reference

```yaml
provider: <string, required>         # matches the Provider plugin's `name`

# openai_compatible only — required for that provider, ignored for others
display_provider: <string>           # label shown in /api/models response
api_key_env: <string>                # name of the env var carrying the key
base_url: <string>                   # endpoint URL

defaults:                            # optional, applied to every model below
  supports_tools: bool               # default false
  supports_structured_output: bool   # default false
  supports_streaming: bool           # default true
  attachments: [<alias-or-mime>, ...]  # default []
  context_window: int                # default 128000
  input_cost_per_token: float        # default null
  output_cost_per_token: float       # default null

models:                              # required
  - id: <string, required>           # the value persisted in agent records
    display_name: <string>           # default: id
    description: <string>            # default: ""
    enabled: bool                    # default true; false hides from /api/models
    base_url: <string>               # optional custom endpoint for this model
    # All `defaults:` fields above can be overridden here per-model.
```

### Attachment aliases

The `attachments:` list can mix human-readable aliases with raw MIME
types. Aliases are defined in `_defaults.yaml`:

| Alias | Expands to |
|---|---|
| `image` | `image/png`, `image/jpeg`, `image/jpg`, `image/webp`, `image/gif` |
| `pdf` | `application/pdf` |
| `audio` | `audio/mpeg`, `audio/wav`, `audio/ogg` |

Use raw MIME types when you need surgical control:

```yaml
attachments: [image/png, image/webp]  # only these two
```

## Operator-supplied YAMLs (`MODELS_CONFIG_DIR`)

Set the `MODELS_CONFIG_DIR` env var (or `.env` entry) to a directory
path. Every `*.yaml` in that directory is loaded **after** the built-in
catalog under `application/core/models/`. Operators use this to:

- Add new `openai_compatible` providers (Mistral, Together, Fireworks,
  Ollama, ...) without forking the repo.
- Extend an existing provider's catalog with extra models — append
  models under `provider: anthropic` and they show up alongside the
  built-ins.
- Override a built-in model's capabilities — declare the same `id`
  with different fields (e.g. a higher `context_window`). Later wins;
  the override is logged as a `WARNING` so you can audit it (sketch
  after this list).
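A minimal override file, assuming a hypothetical `$MODELS_CONFIG_DIR/anthropic-extra.yaml`; the `context_window` value is illustrative:

```yaml
# Same id as the built-in entry, so "later wins" applies and the
# override is logged as a WARNING at boot.
provider: anthropic
models:
  - id: claude-sonnet-4-6
    display_name: Claude Sonnet 4.6
    context_window: 500000
```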
Things you cannot do via `MODELS_CONFIG_DIR`:

- Add a brand-new non-OpenAI provider — that needs a Python plugin
  under `application/llm/providers/` (see "Add a provider with its own
  SDK" above). Operator YAMLs may only target a `provider:` value that
  already has a registered plugin.

### Example: Docker

Mount your model YAMLs into the container and point the env var at the
mount path:

```yaml
# docker-compose.yml
services:
  app:
    image: arc53/docsgpt
    environment:
      MODELS_CONFIG_DIR: /etc/docsgpt/models
      MISTRAL_API_KEY: ${MISTRAL_API_KEY}
    volumes:
      - ./my-models:/etc/docsgpt/models:ro
```

Then `./my-models/mistral.yaml` (the file from
`examples/mistral.yaml.example`) gets picked up at boot.

### Example: Kubernetes

Mount a `ConfigMap` containing your YAMLs at a known path and set
`MODELS_CONFIG_DIR` on the deployment. The same `examples/mistral.yaml.example`
becomes a key in the ConfigMap.
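A minimal sketch of that wiring (resource and key names are illustrative, not prescribed by the app):

```yaml
apiVersion: v1
kind: ConfigMap
metadata:
  name: docsgpt-models
data:
  mistral.yaml: |
    # paste the contents of examples/mistral.yaml.example here
    provider: openai_compatible
    display_provider: mistral
    api_key_env: MISTRAL_API_KEY
    base_url: https://api.mistral.ai/v1
    models:
      - id: mistral-large-latest
        display_name: Mistral Large
---
# Deployment pod-spec fragment: mount the ConfigMap, point the app at it.
# spec.template.spec:
#   containers:
#     - name: app
#       env:
#         - name: MODELS_CONFIG_DIR
#           value: /etc/docsgpt/models
#       volumeMounts:
#         - name: models
#           mountPath: /etc/docsgpt/models
#           readOnly: true
#   volumes:
#     - name: models
#       configMap:
#         name: docsgpt-models
```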
### Misconfiguration

If `MODELS_CONFIG_DIR` is set but the path doesn't exist (or isn't a
directory), the app logs a `WARNING` at boot and continues with just
the built-in catalog. The app does *not* fail to start — operators can
ship config drift without taking down the service — but the warning is
loud enough to surface in any reasonable log aggregator.

## Validation

YAMLs are parsed with Pydantic at boot. The app fails to start with a
clear error message if:

- a top-level key is unknown
- a model is missing `id`
- an attachment alias isn't defined
- the `provider:` value isn't registered as a plugin

This is intentional — silent fallbacks would mean users don't notice
their model picks broke until they hit the API.
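For instance, a file like this (hypothetical) trips two of those checks and aborts startup:

```yaml
provider: anthropic
defualts:                 # typo: unknown top-level key, rejected by Pydantic
  supports_tools: true
models:
  - id: claude-sonnet-4-6
    attachments: [video]  # alias not defined in _defaults.yaml, rejected
```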
## Reserved fields (not yet implemented)

- `aliases:` on a model — old IDs that resolve to this model. Reserved
  for future renames; the schema accepts the field but it is not yet
  acted on.
18 application/core/models/_defaults.yaml Normal file
@@ -0,0 +1,18 @@
# Global defaults applied across every model YAML in this directory.
# Keep this file sparse — per-provider `defaults:` blocks are clearer
# than a deep global default chain. This file is for things that
# genuinely never vary, like the meaning of "image".

attachment_aliases:
  image:
    - image/png
    - image/jpeg
    - image/jpg
    - image/webp
    - image/gif
  pdf:
    - application/pdf
  audio:
    - audio/mpeg
    - audio/wav
    - audio/ogg
23 application/core/models/anthropic.yaml Normal file
@@ -0,0 +1,23 @@
provider: anthropic
defaults:
  supports_tools: true
  attachments: [image]
  context_window: 200000

models:
  - id: claude-opus-4-7
    display_name: Claude Opus 4.7
    description: Most capable Claude model for complex reasoning and agentic coding
    context_window: 1000000
    supports_structured_output: true

  - id: claude-sonnet-4-6
    display_name: Claude Sonnet 4.6
    description: Best balance of speed and intelligence with extended thinking
    context_window: 1000000
    supports_structured_output: true

  - id: claude-haiku-4-5
    display_name: Claude Haiku 4.5
    description: Fastest Claude model with near-frontier intelligence
    supports_structured_output: true
31 application/core/models/azure_openai.yaml Normal file
@@ -0,0 +1,31 @@
# Azure OpenAI catalog.
#
# IMPORTANT: For Azure OpenAI, the `id` field is the **deployment name**, not
# a model name. Deployment names are arbitrary strings the operator chooses
# in Azure portal (or via ARM/Bicep/Terraform) when they create a deployment
# for a given underlying model + version.
#
# The IDs below are sensible defaults that mirror the underlying OpenAI
# model name (prefixed with `azure-`). Operators almost always need to
# override them via `MODELS_CONFIG_DIR` to match the deployment names that
# actually exist in their Azure resource. The `display_name`, capability
# flags, and `context_window` reflect the underlying OpenAI model.
provider: azure_openai

defaults:
  supports_tools: true
  supports_structured_output: true
  attachments: [image]
  context_window: 400000

models:
  - id: azure-gpt-5.5
    display_name: Azure OpenAI GPT-5.5
    description: Azure-hosted flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
    context_window: 1050000
  - id: azure-gpt-5.4-mini
    display_name: Azure OpenAI GPT-5.4 Mini
    description: Azure-hosted cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
  - id: azure-gpt-5.4-nano
    display_name: Azure OpenAI GPT-5.4 Nano
    description: Azure-hosted cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most
7 application/core/models/docsgpt.yaml Normal file
@@ -0,0 +1,7 @@
provider: docsgpt

models:
  - id: docsgpt-local
    display_name: DocsGPT Model
    description: Local model
    supports_tools: false
31 application/core/models/examples/mistral.yaml.example Normal file
@@ -0,0 +1,31 @@
# EXAMPLE — copy this file to ../mistral.yaml (or to your
# MODELS_CONFIG_DIR) and set MISTRAL_API_KEY in your environment.
#
# This is the entire integration. No Python required: the
# `openai_compatible` plugin reads `api_key_env` and `base_url` from
# the file and routes calls through the OpenAI wire format.
#
# Files in this `examples/` directory are NOT loaded by the registry
# (the loader globs *.yaml at the top level only).

provider: openai_compatible
display_provider: mistral              # shown in /api/models response
api_key_env: MISTRAL_API_KEY           # env var the plugin reads
base_url: https://api.mistral.ai/v1    # OpenAI-compatible endpoint

defaults:
  supports_tools: true
  context_window: 128000

models:
  - id: mistral-large-latest
    display_name: Mistral Large
    description: Top-tier reasoning model

  - id: mistral-small-latest
    display_name: Mistral Small
    description: Fast, cost-efficient

  - id: codestral-latest
    display_name: Codestral
    description: Code-specialized model
17 application/core/models/google.yaml Normal file
@@ -0,0 +1,17 @@
provider: google
defaults:
  supports_tools: true
  supports_structured_output: true
  attachments: [pdf, image]
  context_window: 1048576

models:
  - id: gemini-3.1-pro-preview
    display_name: Gemini 3.1 Pro
    description: Most capable Gemini 3 model with advanced reasoning and agentic coding (preview)
  - id: gemini-3-flash-preview
    display_name: Gemini 3 Flash
    description: Frontier-class performance for low-latency, high-volume tasks (preview)
  - id: gemini-3.1-flash-lite-preview
    display_name: Gemini 3.1 Flash-Lite
    description: Cost-efficient frontier-class multimodal model for high-throughput workloads (preview)
16 application/core/models/groq.yaml Normal file
@@ -0,0 +1,16 @@
provider: groq
defaults:
  supports_tools: true
  context_window: 131072

models:
  - id: openai/gpt-oss-120b
    display_name: GPT-OSS 120B
    description: OpenAI's open-weight 120B flagship served on Groq's LPU hardware; strong general reasoning with strict structured output support
    supports_structured_output: true
  - id: llama-3.3-70b-versatile
    display_name: Llama 3.3 70B Versatile
    description: Meta's Llama 3.3 70B for general-purpose chat with parallel tool use
  - id: llama-3.1-8b-instant
    display_name: Llama 3.1 8B Instant
    description: Small, very low-latency Llama model (~560 tok/s) with parallel tool use
7 application/core/models/huggingface.yaml Normal file
@@ -0,0 +1,7 @@
provider: huggingface

models:
  - id: huggingface-local
    display_name: Hugging Face Model
    description: Local Hugging Face model
    supports_tools: false
21 application/core/models/novita.yaml Normal file
@@ -0,0 +1,21 @@
provider: novita
defaults:
  supports_tools: true
  supports_structured_output: true

models:
  - id: deepseek/deepseek-v4-pro
    display_name: DeepSeek V4 Pro
    description: 1.6T MoE (49B active) with 1M context, hybrid CSA/HCA attention, top-tier reasoning and agentic coding
    context_window: 1048576

  - id: moonshotai/kimi-k2.6
    display_name: Kimi K2.6
    description: 1T-parameter open-weight MoE with native vision/video, multi-step tool calling, and agentic long-horizon execution
    attachments: [image]
    context_window: 262144

  - id: zai-org/glm-5
    display_name: GLM-5
    description: Z.AI 754B-parameter MoE with strong general reasoning, function calling, and structured output
    context_window: 202800
18 application/core/models/openai.yaml Normal file
@@ -0,0 +1,18 @@
provider: openai
defaults:
  supports_tools: true
  supports_structured_output: true
  attachments: [image]
  context_window: 400000

models:
  - id: gpt-5.5
    display_name: GPT-5.5
    description: Flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
    context_window: 1050000
  - id: gpt-5.4-mini
    display_name: GPT-5.4 Mini
    description: Cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
  - id: gpt-5.4-nano
    display_name: GPT-5.4 Nano
    description: Cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most
25 application/core/models/openrouter.yaml Normal file
@@ -0,0 +1,25 @@
provider: openrouter
defaults:
  supports_tools: true
  attachments: [image]
  context_window: 128000

models:
  - id: qwen/qwen3-coder:free
    display_name: Qwen3 Coder (free)
    description: Free-tier 480B MoE coder model with strong agentic tool use; rate-limited
    context_window: 262000
    attachments: []

  - id: deepseek/deepseek-v3.2
    display_name: DeepSeek V3.2
    description: Open-weights reasoning model, very low cost (~$0.25 in / $0.38 out per 1M)
    context_window: 131072
    attachments: []
    supports_structured_output: true

  - id: anthropic/claude-sonnet-4.6
    display_name: Claude Sonnet 4.6 (via OpenRouter)
    description: Frontier Sonnet-class model with 1M context, vision, and extended thinking
    context_window: 1000000
    supports_structured_output: true
@@ -23,6 +23,10 @@ class Settings(BaseSettings):
     EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
     EMBEDDINGS_BASE_URL: Optional[str] = None  # Remote embeddings API URL (OpenAI-compatible)
     EMBEDDINGS_KEY: Optional[str] = None  # api key for embeddings (if using openai, just copy API_KEY)
+    # Optional directory of operator-supplied model YAMLs, loaded after the
+    # built-in catalog under application/core/models/. Later wins on
+    # duplicate model id. See application/core/models/README.md.
+    MODELS_CONFIG_DIR: Optional[str] = None

     CELERY_BROKER_URL: str = "redis://localhost:6379/0"
     CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
@@ -149,6 +153,9 @@ class Settings(BaseSettings):

     FLASK_DEBUG_MODE: bool = False
     STORAGE_TYPE: str = "local"  # local or s3
+
+    # Anonymous startup version check for security issues.
+    VERSION_CHECK: bool = True
     URL_STRATEGY: str = "backend"  # backend or s3

     JWT_SECRET_KEY: str = ""
72 application/gunicorn_conf.py Normal file
@@ -0,0 +1,72 @@
"""Gunicorn config — keeps uvicorn's access log in NCSA format."""

from __future__ import annotations

import logging
import logging.config

# NCSA common log format:
# %(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"
# Uvicorn's access formatter exposes a ``client_addr``/``request_line``/
# ``status_code`` trio but not the full NCSA field set, so we re-derive
# what we can.
_NCSA_FMT = (
    '%(client_addr)s - - [%(asctime)s] "%(request_line)s" %(status_code)s'
)

logconfig_dict = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "ncsa_access": {
            "()": "uvicorn.logging.AccessFormatter",
            "fmt": _NCSA_FMT,
            "datefmt": "%d/%b/%Y:%H:%M:%S %z",
            "use_colors": False,
        },
        "default": {
            "format": "[%(asctime)s] [%(process)d] [%(levelname)s] %(name)s: %(message)s",
        },
    },
    "handlers": {
        "access": {
            "class": "logging.StreamHandler",
            "formatter": "ncsa_access",
            "stream": "ext://sys.stdout",
        },
        "default": {
            "class": "logging.StreamHandler",
            "formatter": "default",
            "stream": "ext://sys.stderr",
        },
    },
    "loggers": {
        "uvicorn": {"handlers": ["default"], "level": "INFO", "propagate": False},
        "uvicorn.error": {
            "handlers": ["default"],
            "level": "INFO",
            "propagate": False,
        },
        "uvicorn.access": {
            "handlers": ["access"],
            "level": "INFO",
            "propagate": False,
        },
        "gunicorn.error": {
            "handlers": ["default"],
            "level": "INFO",
            "propagate": False,
        },
        "gunicorn.access": {
            "handlers": ["access"],
            "level": "INFO",
            "propagate": False,
        },
    },
    "root": {"handlers": ["default"], "level": "INFO"},
}


def on_starting(server):  # pragma: no cover — gunicorn hook
    """Ensure gunicorn's own loggers use the configured handlers."""
    logging.config.dictConfig(logconfig_dict)
@@ -11,6 +11,7 @@ logger = logging.getLogger(__name__)


 class AnthropicLLM(BaseLLM):
+    provider_name = "anthropic"

     def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
@@ -1,5 +1,6 @@
 import logging
 from abc import ABC, abstractmethod
+from typing import ClassVar

 from application.cache import gen_cache, stream_cache
@@ -10,6 +11,10 @@ logger = logging.getLogger(__name__)


 class BaseLLM(ABC):
+    # Stamped onto the ``llm_stream_start`` event so dashboards can group
+    # calls by vendor. Subclasses override.
+    provider_name: ClassVar[str] = "unknown"
+
     def __init__(
         self,
         decoded_token=None,
@@ -17,6 +22,8 @@ class BaseLLM(ABC):
         model_id=None,
         base_url=None,
         backup_models=None,
+        model_user_id=None,
+        capabilities=None,
     ):
         self.decoded_token = decoded_token
         self.agent_id = str(agent_id) if agent_id else None
@@ -25,6 +32,12 @@ class BaseLLM(ABC):
         self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
         self._backup_models = backup_models or []
         self._fallback_llm = None
+        # Registry-resolved per-model capability overrides (BYOM caps,
+        # operator YAML). None falls back to provider-class defaults.
+        self.capabilities = capabilities
+        # BYOM-resolution scope captured at LLM creation time so backup
+        # / fallback lookups hit the same per-user layer as the primary.
+        self.model_user_id = model_user_id

     @property
     def fallback_llm(self):
@@ -39,10 +52,19 @@ class BaseLLM(ABC):
             get_api_key_for_provider,
         )

-        # Try per-agent backup models first
+        # model_user_id (BYOM scope) takes precedence over the caller's
+        # sub so shared-agent backups resolve under the owner's layer.
+        caller_sub = (
+            self.decoded_token.get("sub")
+            if isinstance(self.decoded_token, dict)
+            else None
+        )
+        backup_user_id = self.model_user_id or caller_sub
         for backup_model_id in self._backup_models:
             try:
-                provider = get_provider_from_model_id(backup_model_id)
+                provider = get_provider_from_model_id(
+                    backup_model_id, user_id=backup_user_id
+                )
                 if not provider:
                     logger.warning(
                         f"Could not resolve provider for backup model: {backup_model_id}"
@@ -56,6 +78,7 @@ class BaseLLM(ABC):
                     decoded_token=self.decoded_token,
                     model_id=backup_model_id,
                     agent_id=self.agent_id,
+                    model_user_id=self.model_user_id,
                 )
                 logger.info(
                     f"Fallback LLM initialized from agent backup model: "
@@ -68,7 +91,10 @@ class BaseLLM(ABC):
                 )
                 continue

-        # Fall back to global FALLBACK_* settings
+        # Fall back to global FALLBACK_* settings. Forward
+        # ``model_user_id`` here too: deployments can configure
+        # ``FALLBACK_LLM_NAME`` to a BYOM UUID, and that UUID is owned
+        # by the same user the primary model was resolved under.
         if settings.FALLBACK_LLM_PROVIDER:
             try:
                 self._fallback_llm = LLMCreator.create_llm(
@@ -78,6 +104,7 @@ class BaseLLM(ABC):
                     decoded_token=self.decoded_token,
                     model_id=settings.FALLBACK_LLM_NAME,
                     agent_id=self.agent_id,
+                    model_user_id=self.model_user_id,
                 )
                 logger.info(
                     f"Fallback LLM initialized from global settings: "
@@ -96,6 +123,26 @@ class BaseLLM(ABC):
             return args_dict
         return {k: v for k, v in args_dict.items() if v is not None}

+    @staticmethod
+    def _is_non_retriable_client_error(exc: BaseException) -> bool:
+        """4xx errors mean the request itself is malformed — retrying with
+        a different model fails identically and doubles the work. Only
+        transient/5xx/connection errors should trigger fallback."""
+        try:
+            from google.genai.errors import ClientError as _GenaiClientError
+
+            if isinstance(exc, _GenaiClientError):
+                return True
+        except ImportError:
+            pass
+        for attr in ("status_code", "code", "http_status"):
+            v = getattr(exc, attr, None)
+            if isinstance(v, int) and 400 <= v < 500:
+                return True
+        resp = getattr(exc, "response", None)
+        v = getattr(resp, "status_code", None)
+        return isinstance(v, int) and 400 <= v < 500
+
     def _execute_with_fallback(
         self, method_name: str, decorators: list, *args, **kwargs
     ):
@@ -119,12 +166,18 @@ class BaseLLM(ABC):

         if is_stream:
             return self._stream_with_fallback(
-                decorated_method, method_name, *args, **kwargs
+                decorated_method, method_name, decorators, *args, **kwargs
             )

         try:
             return decorated_method()
         except Exception as e:
+            if self._is_non_retriable_client_error(e):
+                logger.error(
+                    f"Primary LLM failed with non-retriable client error; "
+                    f"skipping fallback: {str(e)}"
+                )
+                raise
             if not self.fallback_llm:
                 logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
                 raise
@@ -134,14 +187,27 @@ class BaseLLM(ABC):
                 f"{fallback.model_id}. Error: {str(e)}"
             )

-            fallback_method = getattr(
-                fallback, method_name.replace("_raw_", "")
-            )
+            # Apply decorators to fallback's raw method directly — calling
+            # fallback.gen() would re-enter the orchestrator and recurse via
+            # fallback.fallback_llm.
+            fallback_method = getattr(fallback, method_name)
+            for decorator in decorators:
+                fallback_method = decorator(fallback_method)
             fallback_kwargs = {**kwargs, "model": fallback.model_id}
-            return fallback_method(*args, **fallback_kwargs)
+            try:
+                return fallback_method(fallback, *args, **fallback_kwargs)
+            except Exception as e2:
+                if self._is_non_retriable_client_error(e2):
+                    logger.error(
+                        f"Fallback LLM failed with non-retriable client "
+                        f"error; giving up: {str(e2)}"
+                    )
+                else:
+                    logger.error(f"Fallback LLM also failed; giving up: {str(e2)}")
+                raise

     def _stream_with_fallback(
-        self, decorated_method, method_name, *args, **kwargs
+        self, decorated_method, method_name, decorators, *args, **kwargs
     ):
         """
         Wrapper generator that catches mid-stream errors and falls back.
@@ -154,6 +220,12 @@ class BaseLLM(ABC):
         try:
             yield from decorated_method()
         except Exception as e:
+            if self._is_non_retriable_client_error(e):
+                logger.error(
+                    f"Primary LLM failed mid-stream with non-retriable client "
+                    f"error; skipping fallback: {str(e)}"
+                )
+                raise
             if not self.fallback_llm:
                 logger.error(
                     f"Primary LLM failed and no fallback configured: {str(e)}"
@@ -164,11 +236,37 @@ class BaseLLM(ABC):
                 f"Primary LLM failed mid-stream. Falling back to "
                 f"{fallback.model_id}. Error: {str(e)}"
             )
-            fallback_method = getattr(
-                fallback, method_name.replace("_raw_", "")
-            )
+            # Apply decorators to fallback's raw stream method directly —
+            # calling fallback.gen_stream() would re-enter the orchestrator
+            # and recurse via fallback.fallback_llm. Emit the stream-start
+            # event manually so dashboards still see the fallback's
+            # provider/model when the response actually comes from it.
+            fallback._emit_stream_start_log(
+                fallback.model_id,
+                kwargs.get("messages"),
+                kwargs.get("tools"),
+                bool(
+                    kwargs.get("_usage_attachments")
+                    or kwargs.get("attachments")
+                ),
+            )
+            fallback_method = getattr(fallback, method_name)
+            for decorator in decorators:
+                fallback_method = decorator(fallback_method)
             fallback_kwargs = {**kwargs, "model": fallback.model_id}
-            yield from fallback_method(*args, **fallback_kwargs)
+            try:
+                yield from fallback_method(fallback, *args, **fallback_kwargs)
+            except Exception as e2:
+                if self._is_non_retriable_client_error(e2):
+                    logger.error(
+                        f"Fallback LLM failed mid-stream with non-retriable "
+                        f"client error; giving up: {str(e2)}"
+                    )
+                else:
+                    logger.error(
+                        f"Fallback LLM also failed mid-stream; giving up: {str(e2)}"
+                    )
+                raise

     def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
         decorators = [gen_token_usage, gen_cache]
@@ -183,7 +281,58 @@ class BaseLLM(ABC):
             **kwargs,
         )

+    def _emit_stream_start_log(self, model, messages, tools, has_attachments):
+        # Stamped with ``self.provider_name`` so dashboards can group calls
+        # by vendor; the fallback path emits its own copy on the fallback
+        # instance so the actual responding provider is recorded.
+        logging.info(
+            "llm_stream_start",
+            extra={
+                "model": model,
+                "provider": self.provider_name,
+                "message_count": len(messages) if messages is not None else 0,
+                "has_attachments": bool(has_attachments),
+                "has_tools": bool(tools),
+            },
+        )
+
+    def _emit_stream_finished_log(
+        self,
+        model,
+        *,
+        prompt_tokens,
+        completion_tokens,
+        latency_ms,
+        cached_tokens=None,
+        error=None,
+    ):
+        # Paired with ``llm_stream_start`` so cost dashboards can sum tokens
+        # by user/agent/provider. Token counts are client-side estimates
+        # from ``stream_token_usage``; vendor-reported counts (incl.
+        # ``cached_tokens`` for prompt caching) require per-provider
+        # extraction in each ``_raw_gen_stream`` and aren't wired yet.
+        extra = {
+            "model": model,
+            "provider": self.provider_name,
+            "prompt_tokens": int(prompt_tokens),
+            "completion_tokens": int(completion_tokens),
+            "latency_ms": int(latency_ms),
+            "status": "error" if error is not None else "ok",
+        }
+        if cached_tokens is not None:
+            extra["cached_tokens"] = int(cached_tokens)
+        if error is not None:
+            extra["error_class"] = type(error).__name__
+        logging.info("llm_stream_finished", extra=extra)
+
     def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
+        # Attachments arrive as ``_usage_attachments`` from ``Agent._llm_gen``;
+        # the ``stream_token_usage`` decorator pops that key, but the log
+        # fires before the decorator runs so it's still in ``kwargs`` here.
+        has_attachments = bool(
+            kwargs.get("_usage_attachments") or kwargs.get("attachments")
+        )
+        self._emit_stream_start_log(model, messages, tools, has_attachments)
         decorators = [stream_cache, stream_token_usage]
         return self._execute_with_fallback(
             "_raw_gen_stream",
@@ -6,6 +6,8 @@ DOCSGPT_BASE_URL = "https://oai.arc53.com"
|
|||||||
DOCSGPT_MODEL = "docsgpt"
|
DOCSGPT_MODEL = "docsgpt"
|
||||||
|
|
||||||
class DocsGPTAPILLM(OpenAILLM):
|
class DocsGPTAPILLM(OpenAILLM):
|
||||||
|
provider_name = "docsgpt"
|
||||||
|
|
||||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
api_key=DOCSGPT_API_KEY,
|
api_key=DOCSGPT_API_KEY,
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ from application.storage.storage_creator import StorageCreator


 class GoogleLLM(BaseLLM):
+    provider_name = "google"
+
     def __init__(
         self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
     ):
@@ -79,24 +81,39 @@ class GoogleLLM(BaseLLM):
         for attachment in attachments:
             mime_type = attachment.get("mime_type")

-            if mime_type in self.get_supported_attachment_types():
-                try:
+            if mime_type not in self.get_supported_attachment_types():
+                continue
+            try:
+                # Images go inline as bytes per Google's guidance for
+                # requests under 20MB; the Files API can return before
+                # the upload reaches ACTIVE state and yield an empty URI.
+                if mime_type.startswith("image/"):
+                    file_bytes = self._read_attachment_bytes(attachment)
+                    files.append(
+                        {"file_bytes": file_bytes, "mime_type": mime_type}
+                    )
+                else:
                     file_uri = self._upload_file_to_google(attachment)
+                    if not file_uri:
+                        raise ValueError(
+                            f"Google Files API returned empty URI for "
+                            f"{attachment.get('path', 'unknown')}"
+                        )
                     logging.info(
                         f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}"
                     )
                     files.append({"file_uri": file_uri, "mime_type": mime_type})
-                except Exception as e:
-                    logging.error(
-                        f"GoogleLLM: Error uploading file: {e}", exc_info=True
-                    )
-                    if "content" in attachment:
-                        prepared_messages[user_message_index]["content"].append(
-                            {
-                                "type": "text",
-                                "text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
-                            }
-                        )
+            except Exception as e:
+                logging.error(
+                    f"GoogleLLM: Error processing attachment: {e}", exc_info=True
+                )
+                if "content" in attachment:
+                    prepared_messages[user_message_index]["content"].append(
+                        {
+                            "type": "text",
+                            "text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
+                        }
+                    )
         if files:
             logging.info(f"GoogleLLM: Adding {len(files)} files to message")
             prepared_messages[user_message_index]["content"].append({"files": files})
@@ -112,7 +129,9 @@ class GoogleLLM(BaseLLM):
         Returns:
             str: Google AI file URI for the uploaded file.
         """
-        if "google_file_uri" in attachment:
+        # Truthy check, not membership: a poisoned cache row of "" or
+        # None must be treated as a miss and trigger a fresh upload.
+        if attachment.get("google_file_uri"):
             return attachment["google_file_uri"]
         file_path = attachment.get("path")
         if not file_path:
@@ -126,6 +145,10 @@ class GoogleLLM(BaseLLM):
                 file=local_path
             ).uri,
         )
+        if not file_uri:
+            raise ValueError(
+                f"Google Files API upload returned empty URI for {file_path}"
+            )

         # Cache the Google file URI on the attachment row so we don't
         # re-upload on the next LLM call. Accept either a PG UUID
@@ -159,6 +182,26 @@ class GoogleLLM(BaseLLM):
             logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
             raise

+    def _read_attachment_bytes(self, attachment):
+        """
+        Read attachment bytes from storage for inline transmission.
+
+        Args:
+            attachment (dict): Attachment dictionary with path and metadata.
+
+        Returns:
+            bytes: Raw file bytes.
+        """
+        file_path = attachment.get("path")
+        if not file_path:
+            raise ValueError("No file path provided in attachment")
+        if not self.storage.file_exists(file_path):
+            raise FileNotFoundError(f"File not found: {file_path}")
+        return self.storage.process_file(
+            file_path,
+            lambda local_path, **kwargs: open(local_path, "rb").read(),
+        )
+
     def _clean_messages_google(self, messages):
         """
         Convert OpenAI format messages to Google AI format and collect system prompts.
@@ -298,12 +341,24 @@ class GoogleLLM(BaseLLM):
                         )
                     elif "files" in item:
                         for file_data in item["files"]:
-                            parts.append(
-                                types.Part.from_uri(
-                                    file_uri=file_data["file_uri"],
-                                    mime_type=file_data["mime_type"],
-                                )
-                            )
+                            if "file_bytes" in file_data:
+                                parts.append(
+                                    types.Part.from_bytes(
+                                        data=file_data["file_bytes"],
+                                        mime_type=file_data["mime_type"],
+                                    )
+                                )
+                            elif file_data.get("file_uri"):
+                                parts.append(
+                                    types.Part.from_uri(
+                                        file_uri=file_data["file_uri"],
+                                        mime_type=file_data["mime_type"],
+                                    )
+                                )
+                            else:
+                                logging.warning(
+                                    "GoogleLLM: dropping file part with empty URI and no bytes"
+                                )
                     else:
                         raise ValueError(
                             f"Unexpected content dictionary format:{item}"
@@ -541,22 +596,6 @@ class GoogleLLM(BaseLLM):
             config.response_mime_type = "application/json"
         # Check if we have both tools and file attachments
-        has_attachments = False
-        for message in messages:
-            for part in message.parts:
-                if hasattr(part, "file_data") and part.file_data is not None:
-                    has_attachments = True
-                    break
-            if has_attachments:
-                break
-        messages_summary = self._summarize_messages_for_log(messages)
-        logging.info(
-            "GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
-            model,
-            messages_summary,
-            has_attachments,
-        )
-
         response = client.models.generate_content_stream(
             model=model,
             contents=messages,
@@ -5,6 +5,8 @@ GROQ_BASE_URL = "https://api.groq.com/openai/v1"
|
|||||||
|
|
||||||
|
|
||||||
class GroqLLM(OpenAILLM):
|
class GroqLLM(OpenAILLM):
|
||||||
|
provider_name = "groq"
|
||||||
|
|
||||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
api_key=api_key or settings.GROQ_API_KEY or settings.API_KEY,
|
api_key=api_key or settings.GROQ_API_KEY or settings.API_KEY,
|
||||||
|
|||||||
@@ -280,7 +280,26 @@ class LLMHandler(ABC):
                     # Keep serialized function calls/responses so the compressor sees actions
                     parts_text.append(str(item))
                 elif "files" in item:
-                    parts_text.append(str(item))
+                    # Image attachments arrive with raw bytes / base64
+                    # inline (see GoogleLLM.prepare_messages_with_attachments).
+                    # ``str(item)`` would dump the whole byte/base64
+                    # blob into the compression prompt and bust the
+                    # compression LLM's input limit.
+                    files = item.get("files") or []
+                    descriptors = []
+                    if isinstance(files, list):
+                        for f in files:
+                            if isinstance(f, dict):
+                                descriptors.append(
+                                    f.get("mime_type") or "file"
+                                )
+                            elif isinstance(f, str):
+                                descriptors.append(f)
+                    if not descriptors:
+                        descriptors = ["file"]
+                    parts_text.append(
+                        f"[attachment: {', '.join(descriptors)}]"
+                    )
             return "\n".join(parts_text)
         return ""
@@ -470,10 +489,14 @@ class LLMHandler(ABC):
             )
             return self._perform_in_memory_compression(agent, messages)

-        # Use orchestrator to perform compression
+        # Use orchestrator to perform compression. ``model_user_id``
+        # keeps BYOM registry resolution scoped to the model owner
+        # (shared-agent dispatch) while ``user_id`` stays the caller
+        # for the conversation access check.
         result = orchestrator.compress_mid_execution(
             conversation_id=agent.conversation_id,
             user_id=agent.initial_user_id,
+            model_user_id=getattr(agent, "model_user_id", None),
             model_id=agent.model_id,
             decoded_token=getattr(agent, "decoded_token", {}),
             current_conversation=conversation,
@@ -577,7 +600,20 @@ class LLMHandler(ABC):
             if settings.COMPRESSION_MODEL_OVERRIDE
             else agent.model_id
         )
-        provider = get_provider_from_model_id(compression_model)
+        agent_decoded = getattr(agent, "decoded_token", None)
+        caller_sub = (
+            agent_decoded.get("sub")
+            if isinstance(agent_decoded, dict)
+            else None
+        )
+        # Use model-owner scope (mirrors orchestrator path) so
+        # shared-agent owner-BYOM resolves under the owner's layer.
+        compression_user_id = (
+            getattr(agent, "model_user_id", None) or caller_sub
+        )
+        provider = get_provider_from_model_id(
+            compression_model, user_id=compression_user_id
+        )
         api_key = get_api_key_for_provider(provider)
         compression_llm = LLMCreator.create_llm(
             provider,
@@ -586,6 +622,7 @@ class LLMHandler(ABC):
             getattr(agent, "decoded_token", None),
             model_id=compression_model,
             agent_id=getattr(agent, "agent_id", None),
+            model_user_id=compression_user_id,
         )

         # Create service without DB persistence capability
@@ -921,8 +958,15 @@ class LLMHandler(ABC):
             }
             return ""

+        # ``agent.model_id`` is the registry id (a UUID for BYOM
+        # records). Use the LLM's own model_id, which LLMCreator
+        # already resolved to the upstream model name. Built-ins:
+        # the two are equal; BYOM: the upstream name like
+        # "mistral-large-latest" instead of the UUID.
         response = agent.llm.gen(
-            model=agent.model_id, messages=messages, tools=agent.tools
+            model=getattr(agent.llm, "model_id", None) or agent.model_id,
+            messages=messages,
+            tools=agent.tools,
         )
         parsed = self.parse_response(response)
         self.llm_calls.append(build_stack_data(agent.llm))
@@ -1011,8 +1055,11 @@ class LLMHandler(ABC):
             })
             logger.info("Context limit reached - instructing agent to wrap up")

+        # See note above on agent.model_id vs llm.model_id.
         response = agent.llm.gen_stream(
-            model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None
+            model=getattr(agent.llm, "model_id", None) or agent.model_id,
+            messages=messages,
+            tools=agent.tools if not agent.context_limit_reached else None,
         )
         self.llm_calls.append(build_stack_data(agent.llm))
@@ -26,6 +26,8 @@ class LlamaSingleton:


 class LlamaCpp(BaseLLM):
+    provider_name = "llama_cpp"
+
     def __init__(
         self,
         api_key=None,
@@ -1,34 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from application.llm.anthropic import AnthropicLLM
|
from application.llm.providers import PROVIDERS_BY_NAME
|
||||||
from application.llm.docsgpt_provider import DocsGPTAPILLM
|
|
||||||
from application.llm.google_ai import GoogleLLM
|
|
||||||
from application.llm.groq import GroqLLM
|
|
||||||
from application.llm.llama_cpp import LlamaCpp
|
|
||||||
from application.llm.novita import NovitaLLM
|
|
||||||
from application.llm.openai import AzureOpenAILLM, OpenAILLM
|
|
||||||
-from application.llm.premai import PremAILLM
-from application.llm.sagemaker import SagemakerAPILLM
-from application.llm.open_router import OpenRouterLLM

 logger = logging.getLogger(__name__)


 class LLMCreator:
-    llms = {
-        "openai": OpenAILLM,
-        "azure_openai": AzureOpenAILLM,
-        "sagemaker": SagemakerAPILLM,
-        "llama.cpp": LlamaCpp,
-        "anthropic": AnthropicLLM,
-        "docsgpt": DocsGPTAPILLM,
-        "premai": PremAILLM,
-        "groq": GroqLLM,
-        "google": GoogleLLM,
-        "novita": NovitaLLM,
-        "openrouter": OpenRouterLLM,
-    }

     @classmethod
     def create_llm(
         cls,
@@ -39,28 +16,111 @@ class LLMCreator:
         model_id=None,
         agent_id=None,
         backup_models=None,
+        model_user_id=None,
         *args,
         **kwargs,
     ):
-        from application.core.model_utils import get_base_url_for_model
+        """Construct an LLM for the given provider ``type``.

-        llm_class = cls.llms.get(type.lower())
-        if not llm_class:
+        ``model_user_id`` is the BYOM-resolution scope. Defaults to
+        ``decoded_token['sub']`` (the caller). Pass it explicitly when
+        the model record belongs to a *different* user — most notably
+        for shared-agent dispatch, where the agent's stored
+        ``default_model_id`` is the owner's BYOM UUID but
+        ``decoded_token`` represents the caller.
+        """
+        from application.core.model_registry import ModelRegistry
+        from application.security.safe_url import (
+            UnsafeUserUrlError,
+            pinned_httpx_client,
+            validate_user_base_url,
+        )
+
+        plugin = PROVIDERS_BY_NAME.get(type.lower())
+        if plugin is None or plugin.llm_class is None:
             raise ValueError(f"No LLM class found for type {type}")

-        # Extract base_url from model configuration if model_id is provided
+        # Prefer per-model endpoint config from the registry. This is what
+        # makes openai_compatible AND end-user BYOM work without changing
+        # every call site: if the registered AvailableModel carries its
+        # own api_key / base_url, they win over whatever the caller
+        # resolved via the provider plugin.
+        #
+        # End-user BYOM lookups need the user_id from decoded_token to
+        # find the user's per-user models layer (built-in models resolve
+        # without it, so this stays back-compat).
         base_url = None
+        upstream_model_id = model_id
+        capabilities = None
         if model_id:
-            base_url = get_base_url_for_model(model_id)
+            user_id = model_user_id
+            if user_id is None:
+                user_id = (
+                    (decoded_token or {}).get("sub") if decoded_token else None
+                )
+            model = ModelRegistry.get_instance().get_model(model_id, user_id=user_id)
+            if model is not None:
+                # Forward registry caps so the LLM enforces them at
+                # dispatch (built-in classes hard-code True otherwise).
+                capabilities = getattr(model, "capabilities", None)
+                # SECURITY: refuse user-source dispatch without its own
+                # api_key (would leak settings.API_KEY to base_url).
+                if (
+                    getattr(model, "source", "builtin") == "user"
+                    and not model.api_key
+                ):
+                    raise ValueError(
+                        f"Custom model {model_id!r} has no usable API key "
+                        "(decryption may have failed). Re-save the model "
+                        "in settings to dispatch it."
+                    )
+                if model.api_key:
+                    api_key = model.api_key
+                if model.base_url:
+                    base_url = model.base_url
+                # For BYOM the registry id is a UUID; the upstream API
+                # call needs the user's typed model name instead.
+                if model.upstream_model_id:
+                    upstream_model_id = model.upstream_model_id

-        return llm_class(
+                # SECURITY: re-validate at dispatch (defense in depth
+                # for pre-guard rows / YAML-supplied entries). The
+                # pinned httpx.Client below is what actually closes the
+                # DNS-rebinding TOCTOU window.
+                if base_url and getattr(model, "source", "builtin") == "user":
+                    try:
+                        validate_user_base_url(base_url)
+                    except UnsafeUserUrlError as e:
+                        raise ValueError(
+                            f"Refusing to dispatch model {model_id!r}: {e}"
+                        ) from e
+                    # Pinned httpx.Client: resolves once, validates, and
+                    # binds the SDK's outbound socket to the validated IP
+                    # (preserves Host / SNI). Future BYOM providers must
+                    # opt in explicitly — only openai_compatible takes
+                    # http_client today.
+                    if plugin.name == "openai_compatible":
+                        try:
+                            kwargs["http_client"] = pinned_httpx_client(
+                                base_url
+                            )
+                        except UnsafeUserUrlError as e:
+                            raise ValueError(
+                                f"Refusing to dispatch model {model_id!r}: {e}"
+                            ) from e

+        # Forward model_user_id so backup/fallback resolves under the
+        # owner's scope on shared-agent dispatch.
+        return plugin.llm_class(
             api_key,
             user_api_key,
             decoded_token=decoded_token,
-            model_id=model_id,
+            model_id=upstream_model_id,
             agent_id=agent_id,
             base_url=base_url,
             backup_models=backup_models,
+            model_user_id=model_user_id,
+            capabilities=capabilities,
             *args,
             **kwargs,
         )
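A minimal dispatch sketch for the shared-agent case described in the docstring above. The import path and the identifier values are assumptions for illustration; only the keyword names come from the signature in this diff.

```python
from application.llm.llm_creator import LLMCreator  # assumed module path

owner_id = "user-owner"          # hypothetical agent owner
caller_id = "user-caller"        # hypothetical caller from the JWT
byom_model_uuid = "registry-uuid-of-owners-model"  # hypothetical

llm = LLMCreator.create_llm(
    "openai_compatible",
    api_key=None,                      # superseded by the model record's key
    user_api_key=None,
    decoded_token={"sub": caller_id},  # represents the caller...
    model_id=byom_model_uuid,          # ...but the model belongs to the owner,
    model_user_id=owner_id,            # so resolve BYOM under the owner's scope
)
```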

@@ -5,6 +5,8 @@ NOVITA_BASE_URL = "https://api.novita.ai/openai"


 class NovitaLLM(OpenAILLM):
+    provider_name = "novita"
+
     def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
         super().__init__(
             api_key=api_key or settings.NOVITA_API_KEY or settings.API_KEY,

@@ -5,6 +5,8 @@ OPEN_ROUTER_BASE_URL = "https://openrouter.ai/api/v1"


 class OpenRouterLLM(OpenAILLM):
+    provider_name = "openrouter"
+
     def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
         super().__init__(
             api_key=api_key or settings.OPEN_ROUTER_API_KEY or settings.API_KEY,

@@ -61,8 +61,17 @@ def _truncate_base64_for_logging(messages):


 class OpenAILLM(BaseLLM):
+    provider_name = "openai"
+
-    def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
+    def __init__(
+        self,
+        api_key=None,
+        user_api_key=None,
+        base_url=None,
+        http_client=None,
+        *args,
+        **kwargs,
+    ):
         super().__init__(*args, **kwargs)
         self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
@@ -80,7 +89,18 @@ class OpenAILLM(BaseLLM):
         else:
             effective_base_url = "https://api.openai.com/v1"

-        self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
+        # http_client (set by LLMCreator for BYOM) is a DNS-rebinding-safe
+        # httpx.Client; without it the SDK re-resolves DNS per request.
+        if http_client is not None:
+            self.client = OpenAI(
+                api_key=self.api_key,
+                base_url=effective_base_url,
+                http_client=http_client,
+            )
+        else:
+            self.client = OpenAI(
+                api_key=self.api_key, base_url=effective_base_url
+            )
         self.storage = StorageCreator.get_storage()

     def _clean_messages_openai(self, messages):
@@ -243,6 +263,13 @@ class OpenAILLM(BaseLLM):
         if "max_tokens" in kwargs:
             kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")

+        # Defense-in-depth: drop tools / response_format if the
+        # registry's capability flags deny them.
+        if tools and not self._supports_tools():
+            tools = None
+        if response_format and not self._supports_structured_output():
+            response_format = None
+
         request_params = {
             "model": model,
             "messages": messages,
@@ -279,6 +306,13 @@ class OpenAILLM(BaseLLM):
         if "max_tokens" in kwargs:
             kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")

+        # See _raw_gen for rationale — drop tools/response_format when the
+        # registry-provided capabilities say the model doesn't support them.
+        if tools and not self._supports_tools():
+            tools = None
+        if response_format and not self._supports_structured_output():
+            response_format = None
+
         request_params = {
             "model": model,
             "messages": messages,
@@ -320,9 +354,17 @@ class OpenAILLM(BaseLLM):
             response.close()

     def _supports_tools(self):
+        # When the LLM was constructed via LLMCreator with a registered
+        # AvailableModel, ``self.capabilities`` is the per-model record.
+        # BYOM users can disable tool support; respect that. Otherwise
+        # OpenAI's API supports tools by default.
+        if self.capabilities is not None:
+            return bool(self.capabilities.supports_tools)
         return True

     def _supports_structured_output(self):
+        if self.capabilities is not None:
+            return bool(self.capabilities.supports_structured_output)
         return True

     def prepare_structured_output_format(self, json_schema):
@@ -389,8 +431,14 @@ class OpenAILLM(BaseLLM):
         Returns:
             list: List of supported MIME types
         """
-        from application.core.model_configs import OPENAI_ATTACHMENTS
-        return OPENAI_ATTACHMENTS
+        # Per-model caps from the registry win when present — a BYOM
+        # endpoint that doesn't accept images would otherwise still be
+        # sent base64 image parts because the OpenAI default below
+        # advertises the image alias unconditionally.
+        if self.capabilities is not None:
+            return list(self.capabilities.supported_attachment_types or [])
+        from application.core.model_yaml import resolve_attachment_alias
+        return resolve_attachment_alias("image")

     def prepare_messages_with_attachments(self, messages, attachments=None):
         """
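For reference, the `http_client=` hook used above is the OpenAI Python SDK's standard way to take over the transport. A minimal sketch with a plain httpx client (the real code passes the pinned client built in safe_url.py; the key and URL here are placeholders):

```python
import httpx
from openai import OpenAI

# Plain client shown for shape only; pinned_httpx_client() instead returns
# an httpx.Client whose transport dials a pre-validated IP address.
client = OpenAI(
    api_key="sk-example",                   # placeholder
    base_url="https://llm.example.com/v1",  # placeholder
    http_client=httpx.Client(timeout=30.0),
)
```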

@@ -3,6 +3,7 @@ from application.core.settings import settings


 class PremAILLM(BaseLLM):
+    provider_name = "premai"

     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         from premai import Prem

application/llm/providers/__init__.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+"""Provider plugin registry.
+
+Plugins are imported eagerly so import errors surface at app boot rather
+than at first request. ``ALL_PROVIDERS`` is the canonical ordered list;
+``PROVIDERS_BY_NAME`` is a name-keyed lookup for LLMCreator and the
+model registry.
+"""
+
+from __future__ import annotations
+
+from typing import Dict, List
+
+from application.llm.providers.anthropic import AnthropicProvider
+from application.llm.providers.azure_openai import AzureOpenAIProvider
+from application.llm.providers.base import Provider
+from application.llm.providers.docsgpt import DocsGPTProvider
+from application.llm.providers.google import GoogleProvider
+from application.llm.providers.groq import GroqProvider
+from application.llm.providers.huggingface import HuggingFaceProvider
+from application.llm.providers.llama_cpp import LlamaCppProvider
+from application.llm.providers.novita import NovitaProvider
+from application.llm.providers.openai import OpenAIProvider
+from application.llm.providers.openai_compatible import OpenAICompatibleProvider
+from application.llm.providers.openrouter import OpenRouterProvider
+from application.llm.providers.premai import PremAIProvider
+from application.llm.providers.sagemaker import SagemakerProvider
+
+# Order here is the order the registry iterates providers (and therefore
+# the order ``/api/models`` reports them). Match the historical order
+# from the old ModelRegistry._load_models for byte-stable output during
+# the migration. ``openai_compatible`` slots in right after ``openai``
+# so legacy ``OPENAI_BASE_URL`` models keep landing in the same place.
+ALL_PROVIDERS: List[Provider] = [
+    DocsGPTProvider(),
+    OpenAIProvider(),
+    OpenAICompatibleProvider(),
+    AzureOpenAIProvider(),
+    AnthropicProvider(),
+    GoogleProvider(),
+    GroqProvider(),
+    OpenRouterProvider(),
+    NovitaProvider(),
+    HuggingFaceProvider(),
+    LlamaCppProvider(),
+    PremAIProvider(),
+    SagemakerProvider(),
+]
+
+PROVIDERS_BY_NAME: Dict[str, Provider] = {p.name: p for p in ALL_PROVIDERS}
+
+__all__ = ["ALL_PROVIDERS", "PROVIDERS_BY_NAME", "Provider"]
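The lookup contract LLMCreator relies on, restated as a short usage sketch (runnable inside the application environment):

```python
from application.llm.providers import PROVIDERS_BY_NAME

plugin = PROVIDERS_BY_NAME.get("openai")
if plugin is not None and plugin.llm_class is not None:
    print(plugin.llm_class.__name__)  # -> "OpenAILLM"
```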

application/llm/providers/_apikey_or_llm_name.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+"""Shared helper for providers that follow the
+``<X>_API_KEY or (LLM_PROVIDER==X and API_KEY)`` pattern.
+
+This is the dominant pattern across Anthropic, Google, Groq, OpenRouter,
+and Novita. Extracted here so each plugin stays a few lines long.
+"""
+
+from __future__ import annotations
+
+from typing import List, Optional
+
+from application.core.model_settings import AvailableModel
+
+
+def get_api_key(
+    settings,
+    provider_name: str,
+    provider_specific_key: Optional[str],
+) -> Optional[str]:
+    if provider_specific_key:
+        return provider_specific_key
+    if settings.LLM_PROVIDER == provider_name and settings.API_KEY:
+        return settings.API_KEY
+    return None
+
+
+def filter_models_by_llm_name(
+    settings,
+    provider_name: str,
+    provider_specific_key: Optional[str],
+    models: List[AvailableModel],
+) -> List[AvailableModel]:
+    """Mirrors the historical ``_add_<X>_models`` selection logic.
+
+    Behavior:
+    - If the provider-specific API key is set → load all models.
+    - Else if ``LLM_PROVIDER`` matches and ``LLM_NAME`` matches a known
+      model → load just that model.
+    - Otherwise → load all models (preserved "load anyway" branch from
+      the original methods).
+    """
+    if provider_specific_key:
+        return models
+    if (
+        settings.LLM_PROVIDER == provider_name
+        and settings.LLM_NAME
+    ):
+        named = [m for m in models if m.id == settings.LLM_NAME]
+        if named:
+            return named
+    return models
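The three branches of filter_models_by_llm_name, exercised standalone; SimpleNamespace stands in for the real Settings object, and the model stubs only need an `.id`:

```python
from types import SimpleNamespace

models = [SimpleNamespace(id="m-small"), SimpleNamespace(id="m-large")]

def pick(s):
    if s.PROVIDER_KEY:                           # provider key set: everything
        return models
    if s.LLM_PROVIDER == "groq" and s.LLM_NAME:  # narrow to the named model
        named = [m for m in models if m.id == s.LLM_NAME]
        if named:
            return named
    return models                                # "load anyway" branch

print([m.id for m in pick(SimpleNamespace(
    PROVIDER_KEY=None, LLM_PROVIDER="groq", LLM_NAME="m-small"))])
# ['m-small']
```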

application/llm/providers/anthropic.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from application.llm.anthropic import AnthropicLLM
+from application.llm.providers._apikey_or_llm_name import (
+    filter_models_by_llm_name,
+    get_api_key,
+)
+from application.llm.providers.base import Provider
+
+
+class AnthropicProvider(Provider):
+    name = "anthropic"
+    llm_class = AnthropicLLM
+
+    def get_api_key(self, settings) -> Optional[str]:
+        return get_api_key(settings, self.name, settings.ANTHROPIC_API_KEY)
+
+    def filter_yaml_models(self, settings, models):
+        return filter_models_by_llm_name(
+            settings, self.name, settings.ANTHROPIC_API_KEY, models
+        )

application/llm/providers/azure_openai.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from application.llm.openai import AzureOpenAILLM
+from application.llm.providers.base import Provider
+
+
+class AzureOpenAIProvider(Provider):
+    name = "azure_openai"
+    llm_class = AzureOpenAILLM
+
+    def get_api_key(self, settings) -> Optional[str]:
+        # Azure historically uses the generic API_KEY field.
+        return settings.API_KEY
+
+    def is_enabled(self, settings) -> bool:
+        if settings.OPENAI_API_BASE:
+            return True
+        return settings.LLM_PROVIDER == self.name and bool(settings.API_KEY)
+
+    def filter_yaml_models(self, settings, models):
+        # Mirrors _add_azure_openai_models: when LLM_PROVIDER==azure_openai
+        # and LLM_NAME matches a known model, narrow to that one model.
+        # Otherwise load the entire catalog.
+        if settings.LLM_PROVIDER == self.name and settings.LLM_NAME:
+            named = [m for m in models if m.id == settings.LLM_NAME]
+            if named:
+                return named
+        return models

application/llm/providers/base.py (new file, 74 lines)
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, ClassVar, List, Optional, Type
+
+if TYPE_CHECKING:
+    from application.core.model_settings import AvailableModel
+    from application.core.model_yaml import ProviderCatalog
+    from application.core.settings import Settings
+    from application.llm.base import BaseLLM
+
+
+class Provider(ABC):
+    """Owns the *behavior* of an LLM provider.
+
+    Concrete providers declare their name, the LLM class to instantiate,
+    and how to resolve credentials from settings. Static model catalogs
+    live in YAML under ``application/core/models/`` and are joined to the
+    provider by name at registry load time.
+
+    Most plugins receive zero or one catalog at registry-build time. The
+    ``openai_compatible`` plugin is the exception: it receives one catalog
+    per matching YAML file, each with its own ``api_key_env`` and
+    ``base_url``. Plugins that need per-catalog metadata override
+    ``get_models``; the default implementation merges catalogs and routes
+    through ``filter_yaml_models`` + ``extra_models``.
+    """
+
+    name: ClassVar[str]
+    # ``None`` means the provider appears in the catalog but isn't
+    # dispatchable through LLMCreator (e.g. Hugging Face today, where the
+    # original LLMCreator dict had no entry).
+    llm_class: ClassVar[Optional[Type["BaseLLM"]]] = None
+
+    @abstractmethod
+    def get_api_key(self, settings: "Settings") -> Optional[str]:
+        """Return the API key for this provider, or None if unavailable."""
+
+    def is_enabled(self, settings: "Settings") -> bool:
+        """Whether this provider should contribute models to the registry."""
+        return bool(self.get_api_key(settings))
+
+    def filter_yaml_models(
+        self, settings: "Settings", models: List["AvailableModel"]
+    ) -> List["AvailableModel"]:
+        """Hook to filter YAML-loaded models. Default: return all."""
+        return models
+
+    def extra_models(self, settings: "Settings") -> List["AvailableModel"]:
+        """Hook to add dynamic models not declared in YAML. Default: none."""
+        return []
+
+    def get_models(
+        self,
+        settings: "Settings",
+        catalogs: List["ProviderCatalog"],
+    ) -> List["AvailableModel"]:
+        """Final list of models this plugin contributes.
+
+        Default: merge the models across all matched catalogs (later
+        catalog wins on duplicate id), filter via ``filter_yaml_models``,
+        then append ``extra_models``. Override when per-catalog metadata
+        matters (see ``OpenAICompatibleProvider``).
+        """
+        merged: List["AvailableModel"] = []
+        seen: dict = {}
+        for c in catalogs:
+            for m in c.models:
+                if m.id in seen:
+                    merged[seen[m.id]] = m
+                else:
+                    seen[m.id] = len(merged)
+                    merged.append(m)
+        return self.filter_yaml_models(settings, merged) + self.extra_models(settings)
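A hypothetical plugin against the ABC above; the class name and the settings field are invented, and only the required surface (`name`, `llm_class`, `get_api_key`) comes from the base class:

```python
from typing import Optional

from application.llm.providers.base import Provider


class ExampleProvider(Provider):
    name = "example"
    llm_class = None  # catalog-only, like HuggingFaceProvider

    def get_api_key(self, settings) -> Optional[str]:
        # EXAMPLE_API_KEY is a made-up settings field for illustration.
        return getattr(settings, "EXAMPLE_API_KEY", None)
```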

application/llm/providers/docsgpt.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from application.llm.docsgpt_provider import DocsGPTAPILLM
+from application.llm.providers.base import Provider
+
+
+class DocsGPTProvider(Provider):
+    name = "docsgpt"
+    llm_class = DocsGPTAPILLM
+
+    def get_api_key(self, settings) -> Optional[str]:
+        # No provider-specific key; the LLM class can use the generic
+        # API_KEY fallback if it needs one. Mirrors model_utils' historical
+        # behavior of returning settings.API_KEY when no specific key exists.
+        return settings.API_KEY
+
+    def is_enabled(self, settings) -> bool:
+        # The hosted DocsGPT model is hidden when the deployment is
+        # pointed at a custom OpenAI-compatible endpoint.
+        return not settings.OPENAI_BASE_URL

application/llm/providers/google.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from application.llm.google_ai import GoogleLLM
+from application.llm.providers._apikey_or_llm_name import (
+    filter_models_by_llm_name,
+    get_api_key,
+)
+from application.llm.providers.base import Provider
+
+
+class GoogleProvider(Provider):
+    name = "google"
+    llm_class = GoogleLLM
+
+    def get_api_key(self, settings) -> Optional[str]:
+        return get_api_key(settings, self.name, settings.GOOGLE_API_KEY)
+
+    def filter_yaml_models(self, settings, models):
+        return filter_models_by_llm_name(
+            settings, self.name, settings.GOOGLE_API_KEY, models
+        )

application/llm/providers/groq.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from application.llm.groq import GroqLLM
+from application.llm.providers._apikey_or_llm_name import (
+    filter_models_by_llm_name,
+    get_api_key,
+)
+from application.llm.providers.base import Provider
+
+
+class GroqProvider(Provider):
+    name = "groq"
+    llm_class = GroqLLM
+
+    def get_api_key(self, settings) -> Optional[str]:
+        return get_api_key(settings, self.name, settings.GROQ_API_KEY)
+
+    def filter_yaml_models(self, settings, models):
+        return filter_models_by_llm_name(
+            settings, self.name, settings.GROQ_API_KEY, models
+        )

application/llm/providers/huggingface.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from application.llm.providers._apikey_or_llm_name import (
+    get_api_key as shared_get_api_key,
+)
+from application.llm.providers.base import Provider
+
+
+class HuggingFaceProvider(Provider):
+    """Surfaces ``huggingface-local`` to the model catalog.
+
+    Not dispatchable through LLMCreator — historically there was no
+    HuggingFaceLLM entry in ``LLMCreator.llms``, and calling ``create_llm``
+    with ``"huggingface"`` raised ``ValueError``. We preserve that
+    behavior: the model appears in ``/api/models`` but selecting it
+    surfaces the same error it always did.
+    """
+
+    name = "huggingface"
+    llm_class = None  # not dispatchable
+
+    def get_api_key(self, settings) -> Optional[str]:
+        return shared_get_api_key(settings, self.name, settings.HUGGINGFACE_API_KEY)

application/llm/providers/llama_cpp.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from application.llm.llama_cpp import LlamaCpp
+from application.llm.providers.base import Provider
+
+
+class LlamaCppProvider(Provider):
+    """LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog."""
+
+    name = "llama.cpp"
+    llm_class = LlamaCpp
+
+    def get_api_key(self, settings) -> Optional[str]:
+        return settings.API_KEY
+
+    def is_enabled(self, settings) -> bool:
+        return False

application/llm/providers/novita.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from application.llm.novita import NovitaLLM
+from application.llm.providers._apikey_or_llm_name import (
+    filter_models_by_llm_name,
+    get_api_key,
+)
+from application.llm.providers.base import Provider
+
+
+class NovitaProvider(Provider):
+    name = "novita"
+    llm_class = NovitaLLM
+
+    def get_api_key(self, settings) -> Optional[str]:
+        return get_api_key(settings, self.name, settings.NOVITA_API_KEY)
+
+    def filter_yaml_models(self, settings, models):
+        return filter_models_by_llm_name(
+            settings, self.name, settings.NOVITA_API_KEY, models
+        )

application/llm/providers/openai.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from application.llm.openai import OpenAILLM
+from application.llm.providers.base import Provider
+
+
+class OpenAIProvider(Provider):
+    name = "openai"
+    llm_class = OpenAILLM
+
+    def get_api_key(self, settings) -> Optional[str]:
+        if settings.OPENAI_API_KEY:
+            return settings.OPENAI_API_KEY
+        if settings.LLM_PROVIDER == self.name and settings.API_KEY:
+            return settings.API_KEY
+        return None
+
+    def is_enabled(self, settings) -> bool:
+        # When the deployment is pointed at a custom OpenAI-compatible
+        # endpoint (Ollama, LM Studio, ...), the cloud-OpenAI catalog is
+        # suppressed but ``is_enabled`` stays True — necessary so the
+        # filter below still gets to drop the catalog (rather than the
+        # registry skipping the provider entirely and missing the rule).
+        if settings.OPENAI_BASE_URL:
+            return True
+        return bool(self.get_api_key(settings))
+
+    def filter_yaml_models(self, settings, models):
+        # Legacy local-endpoint mode hides the cloud catalog. The
+        # corresponding dynamic models live in OpenAICompatibleProvider.
+        if settings.OPENAI_BASE_URL:
+            return []
+        if not settings.OPENAI_API_KEY:
+            return []
+        return models

application/llm/providers/openai_compatible.py (new file, 149 lines)
@@ -0,0 +1,149 @@
+"""Generic provider for OpenAI-wire-compatible endpoints.
+
+Each ``openai_compatible`` YAML file describes one logical endpoint
+(Mistral, Together, Fireworks, Ollama, ...) with its own
+``api_key_env`` and ``base_url``. Multiple files can coexist; the
+plugin produces one set of models per file, each pre-configured with
+the right credentials and URL.
+
+The plugin also handles the **legacy** ``OPENAI_BASE_URL`` + ``LLM_NAME``
+local-endpoint pattern that previously lived in ``OpenAIProvider``. That
+path generates models dynamically from ``LLM_NAME``, using
+``OPENAI_BASE_URL`` and ``OPENAI_API_KEY`` as the endpoint config.
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+from typing import List, Optional
+
+from application.core.model_settings import (
+    AvailableModel,
+    ModelCapabilities,
+    ModelProvider,
+)
+from application.llm.openai import OpenAILLM
+from application.llm.providers.base import Provider
+
+logger = logging.getLogger(__name__)
+
+
+def _parse_model_names(llm_name: Optional[str]) -> List[str]:
+    if not llm_name:
+        return []
+    return [name.strip() for name in llm_name.split(",") if name.strip()]
+
+
+class OpenAICompatibleProvider(Provider):
+    name = "openai_compatible"
+    llm_class = OpenAILLM
+
+    def get_api_key(self, settings) -> Optional[str]:
+        # Per-model: each catalog supplies its own ``api_key_env``. There
+        # is no single plugin-wide key. LLMCreator reads the per-model
+        # ``api_key`` set during catalog materialization.
+        return None
+
+    def is_enabled(self, settings) -> bool:
+        # Concrete enablement happens per catalog (in ``get_models``).
+        # Returning True lets the registry call ``get_models`` so we can
+        # decide per-file whether to contribute models.
+        return True
+
+    def get_models(self, settings, catalogs) -> List[AvailableModel]:
+        out: List[AvailableModel] = []
+        for catalog in catalogs:
+            out.extend(self._materialize_yaml_catalog(catalog))
+        if settings.OPENAI_BASE_URL and settings.LLM_NAME:
+            out.extend(self._materialize_legacy_local_endpoint(settings))
+        return out
+
+    def _materialize_yaml_catalog(self, catalog) -> List[AvailableModel]:
+        """Resolve one openai_compatible YAML into ready-to-dispatch models.
+
+        Skipped (with an INFO-level log) if ``api_key_env`` resolves to
+        nothing — no point publishing models the user can't actually
+        call. INFO rather than WARNING because operators may legitimately
+        drop multiple provider YAMLs as templates and only set the env
+        vars for the ones they actually use; a missing key is ambiguous,
+        not necessarily a misconfig.
+        """
+        if not catalog.base_url:
+            raise ValueError(
+                f"{catalog.source_path}: openai_compatible YAML must set "
+                "'base_url'."
+            )
+        if not catalog.api_key_env:
+            raise ValueError(
+                f"{catalog.source_path}: openai_compatible YAML must set "
+                "'api_key_env'."
+            )
+
+        api_key = os.environ.get(catalog.api_key_env)
+        if not api_key:
+            logger.info(
+                "openai_compatible catalog %s skipped: env var %s is not set",
+                catalog.source_path,
+                catalog.api_key_env,
+            )
+            return []
+
+        out: List[AvailableModel] = []
+        for m in catalog.models:
+            out.append(self._with_endpoint(m, catalog.base_url, api_key))
+        return out
+
+    def _materialize_legacy_local_endpoint(self, settings) -> List[AvailableModel]:
+        """Generate AvailableModels from ``LLM_NAME`` for the legacy
+        ``OPENAI_BASE_URL`` deployment pattern (Ollama, LM Studio, ...).
+
+        Preserves the historical ``provider="openai"`` display behavior
+        by setting ``display_provider="openai"``.
+        """
+        from application.core.model_yaml import resolve_attachment_alias
+
+        attachments = resolve_attachment_alias("image")
+        api_key = settings.OPENAI_API_KEY or settings.API_KEY
+        out: List[AvailableModel] = []
+        for model_name in _parse_model_names(settings.LLM_NAME):
+            out.append(
+                AvailableModel(
+                    id=model_name,
+                    provider=ModelProvider.OPENAI_COMPATIBLE,
+                    display_name=model_name,
+                    description=f"Custom OpenAI-compatible model at {settings.OPENAI_BASE_URL}",
+                    base_url=settings.OPENAI_BASE_URL,
+                    capabilities=ModelCapabilities(
+                        supports_tools=True,
+                        supported_attachment_types=attachments,
+                    ),
+                    api_key=api_key,
+                    display_provider="openai",
+                )
+            )
+        return out
+
+    @staticmethod
+    def _with_endpoint(
+        model: AvailableModel, base_url: str, api_key: str
+    ) -> AvailableModel:
+        """Return a copy of ``model`` carrying the catalog's endpoint config.
+
+        The catalog-level ``base_url`` is the default; an explicit
+        per-model ``base_url`` in the YAML wins.
+        """
+        return AvailableModel(
+            id=model.id,
+            provider=model.provider,
+            display_name=model.display_name,
+            description=model.description,
+            capabilities=model.capabilities,
+            enabled=model.enabled,
+            base_url=model.base_url or base_url,
+            display_provider=model.display_provider,
+            api_key=api_key,
+        )
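The per-catalog skip rule in _materialize_yaml_catalog, reduced to its decision logic. The field values are invented; the real ProviderCatalog comes from application.core.model_yaml:

```python
import os
from types import SimpleNamespace

catalog = SimpleNamespace(
    base_url="https://api.example-llm.test/v1",   # hypothetical
    api_key_env="EXAMPLE_LLM_API_KEY",            # hypothetical
    source_path="application/core/models/example.yaml",
    models=[],
)

api_key = os.environ.get(catalog.api_key_env)
if not api_key:
    # INFO, not WARNING: an unset env var may just mean an unused template.
    print(f"{catalog.source_path} skipped: {catalog.api_key_env} is not set")
```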

application/llm/providers/openrouter.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from application.llm.open_router import OpenRouterLLM
+from application.llm.providers._apikey_or_llm_name import (
+    filter_models_by_llm_name,
+    get_api_key,
+)
+from application.llm.providers.base import Provider
+
+
+class OpenRouterProvider(Provider):
+    name = "openrouter"
+    llm_class = OpenRouterLLM
+
+    def get_api_key(self, settings) -> Optional[str]:
+        return get_api_key(settings, self.name, settings.OPEN_ROUTER_API_KEY)
+
+    def filter_yaml_models(self, settings, models):
+        return filter_models_by_llm_name(
+            settings, self.name, settings.OPEN_ROUTER_API_KEY, models
+        )

application/llm/providers/premai.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from application.llm.premai import PremAILLM
+from application.llm.providers.base import Provider
+
+
+class PremAIProvider(Provider):
+    """LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog."""
+
+    name = "premai"
+    llm_class = PremAILLM
+
+    def get_api_key(self, settings) -> Optional[str]:
+        return settings.API_KEY
+
+    def is_enabled(self, settings) -> bool:
+        return False

application/llm/providers/sagemaker.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from typing import Optional
+
+from application.llm.sagemaker import SagemakerAPILLM
+from application.llm.providers.base import Provider
+
+
+class SagemakerProvider(Provider):
+    """LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog.
+
+    SageMaker reads its credentials from ``SAGEMAKER_*`` settings inside
+    the LLM class itself; this plugin's ``get_api_key`` exists only for
+    LLMCreator's symmetry.
+    """
+
+    name = "sagemaker"
+    llm_class = SagemakerAPILLM
+
+    def get_api_key(self, settings) -> Optional[str]:
+        return settings.API_KEY
+
+    def is_enabled(self, settings) -> bool:
+        return False

@@ -59,6 +59,7 @@ class LineIterator:


 class SagemakerAPILLM(BaseLLM):
+    provider_name = "sagemaker"

     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         import boto3

@@ -1,11 +1,13 @@
 import datetime
 import functools
 import inspect
+import time
 import logging
 import uuid
 from typing import Any, Callable, Dict, Generator, List

+from application.core import log_context
 from application.storage.db.repositories.stack_logs import StackLogsRepository
 from application.storage.db.session import db_session

@@ -22,6 +24,15 @@ class LogContext:
         self.api_key = api_key
         self.query = query
         self.stacks = []
+        # Per-activity response aggregates populated by ``_consume_and_log``
+        # while it forwards stream items, then flushed onto the
+        # ``activity_finished`` event so every Flask request gets the
+        # same summary that ``run_agent_logic`` used to log only for the
+        # Celery webhook path.
+        self.answer_length = 0
+        self.thought_length = 0
+        self.source_count = 0
+        self.tool_call_count = 0


 def build_stack_data(
@@ -78,25 +89,125 @@ def log_activity() -> Callable:
             user = data.get("user", "local")
             api_key = data.get("user_api_key", "")
             query = kwargs.get("query", getattr(args[0], "query", ""))
+            agent_id = getattr(args[0], "agent_id", None) or kwargs.get("agent_id")
+            conversation_id = (
+                kwargs.get("conversation_id")
+                or getattr(args[0], "conversation_id", None)
+            )
+            model = getattr(args[0], "gpt_model", None) or getattr(args[0], "model", None)
+
+            # Capture the surrounding activity_id before overlaying ours,
+            # so nested activities record the parent → child link.
+            parent_activity_id = log_context.snapshot().get("activity_id")
+
             context = LogContext(endpoint, activity_id, user, api_key, query)
             kwargs["log_context"] = context

-            logging.info(
-                f"Starting activity: {endpoint} - {activity_id} - User: {user}"
+            ctx_token = log_context.bind(
+                activity_id=activity_id,
+                parent_activity_id=parent_activity_id,
+                user_id=user,
+                agent_id=agent_id,
+                conversation_id=conversation_id,
+                endpoint=endpoint,
+                model=model,
             )

-            generator = func(*args, **kwargs)
-            yield from _consume_and_log(generator, context)
+            started_at = time.monotonic()
+            logging.info(
+                "activity_started",
+                extra={
+                    "activity_id": activity_id,
+                    "parent_activity_id": parent_activity_id,
+                    "user_id": user,
+                    "agent_id": agent_id,
+                    "conversation_id": conversation_id,
+                    "endpoint": endpoint,
+                    "model": model,
+                },
+            )
+
+            error: BaseException | None = None
+            try:
+                generator = func(*args, **kwargs)
+                yield from _consume_and_log(generator, context)
+            except Exception as exc:
+                # Only ``Exception`` counts as an activity error; ``GeneratorExit``
+                # (consumer disconnected mid-stream) and ``KeyboardInterrupt``
+                # flow through the finally as ``status="ok"``, matching
+                # ``_consume_and_log``.
+                error = exc
+                raise
+            finally:
+                _emit_activity_finished(
+                    context=context,
+                    parent_activity_id=parent_activity_id,
+                    started_at=started_at,
+                    error=error,
+                )
+                log_context.reset(ctx_token)

         return wrapper

     return decorator


+def _emit_activity_finished(
+    *,
+    context: "LogContext",
+    parent_activity_id: str | None,
+    started_at: float,
+    error: BaseException | None,
+) -> None:
+    """Emit the paired ``activity_finished`` event with duration, outcome,
+    and per-activity response aggregates accumulated in ``_consume_and_log``.
+    """
+    duration_ms = int((time.monotonic() - started_at) * 1000)
+    logging.info(
+        "activity_finished",
+        extra={
+            "activity_id": context.activity_id,
+            "parent_activity_id": parent_activity_id,
+            "user_id": context.user,
+            "endpoint": context.endpoint,
+            "duration_ms": duration_ms,
+            "status": "error" if error is not None else "ok",
+            "error_class": type(error).__name__ if error is not None else None,
+            "answer_length": context.answer_length,
+            "thought_length": context.thought_length,
+            "source_count": context.source_count,
+            "tool_call_count": context.tool_call_count,
+        },
+    )
+
+
+def _accumulate_response_summary(item: Any, context: "LogContext") -> None:
+    """Mirror the per-line aggregation that ``run_agent_logic`` did for the
+    Celery webhook path, but at the generator-consumption layer so every
+    ``Agent.gen`` activity (Flask streaming, sub-agents, workflow agents)
+    gets the same summary.
+    """
+    if not isinstance(item, dict):
+        return
+    if "answer" in item:
+        context.answer_length += len(str(item["answer"]))
+        return
+    if "thought" in item:
+        context.thought_length += len(str(item["thought"]))
+        return
+    sources = item.get("sources") if "sources" in item else None
+    if isinstance(sources, list):
+        context.source_count += len(sources)
+        return
+    tool_calls = item.get("tool_calls") if "tool_calls" in item else None
+    if isinstance(tool_calls, list):
+        context.tool_call_count += len(tool_calls)
+
+
 def _consume_and_log(generator: Generator, context: "LogContext"):
     try:
         for item in generator:
+            _accumulate_response_summary(item, context)
             yield item
     except Exception as e:
         logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}")
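The bind/snapshot/reset contract the decorator leans on is the usual contextvars pattern. The real application.core.log_context helper is not part of this diff, so the following is an assumed shape for illustration, not its actual implementation:

```python
import contextvars

_ctx = contextvars.ContextVar("log_ctx", default={})

def bind(**fields):
    # Overlay fields on the current context; return a reset token.
    return _ctx.set({**_ctx.get(), **fields})

def snapshot():
    return dict(_ctx.get())

def reset(token):
    _ctx.reset(token)

parent = bind(activity_id="a-parent")
# Arguments are evaluated before bind() runs, so snapshot() still sees
# the parent's activity_id: exactly the parent -> child link above.
child = bind(activity_id="a-child",
             parent_activity_id=snapshot().get("activity_id"))
assert snapshot()["parent_activity_id"] == "a-parent"
reset(child)
reset(parent)
```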

application/mcp_server.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+"""FastMCP server exposing DocsGPT retrieval over streamable HTTP.
+
+Mounted at ``/mcp`` by ``application/asgi.py``. Bearer tokens are the
+existing DocsGPT agent API keys — no new credential surface.
+
+The tool reads the ``Authorization`` header directly via
+``get_http_headers(include={"authorization"})``. The ``include`` kwarg
+is required: by default ``get_http_headers`` strips ``authorization``
+(and a handful of other hop-by-hop headers) so they aren't forwarded
+to downstream services — since we deliberately want the caller's
+token, we opt it back in.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+
+from fastmcp import FastMCP
+from fastmcp.server.dependencies import get_http_headers
+
+from application.services.search_service import (
+    InvalidAPIKey,
+    SearchFailed,
+    search,
+)
+
+logger = logging.getLogger(__name__)
+
+mcp = FastMCP("docsgpt")
+
+
+def _extract_bearer_token() -> str | None:
+    auth = get_http_headers(include={"authorization"}).get("authorization", "")
+    parts = auth.split(None, 1)
+    if len(parts) != 2 or parts[0].lower() != "bearer" or not parts[1]:
+        return None
+    return parts[1]
+
+
+@mcp.tool
+async def search_docs(query: str, chunks: int = 5) -> list[dict]:
+    """Search the caller's DocsGPT knowledge base.
+
+    Authentication is via ``Authorization: Bearer <agent-api-key>`` on
+    the MCP request — the same opaque key that ``/api/search`` accepts
+    in its JSON body. Returns at most ``chunks`` hits, each a dict with
+    ``text``, ``title``, ``source`` keys.
+    """
+    api_key = _extract_bearer_token()
+    if not api_key:
+        raise PermissionError("Missing Bearer token")
+    try:
+        return await asyncio.to_thread(search, api_key, query, chunks)
+    except InvalidAPIKey as exc:
+        raise PermissionError("Invalid API key") from exc
+    except SearchFailed:
+        logger.exception("search_docs failed")
+        raise
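The parsing contract of _extract_bearer_token, restated standalone: the scheme match is case-insensitive, exactly one token is required, and empty tokens are rejected.

```python
def parse_bearer(auth_header: str):
    parts = auth_header.split(None, 1)
    if len(parts) != 2 or parts[0].lower() != "bearer" or not parts[1]:
        return None
    return parts[1]

assert parse_bearer("Bearer abc123") == "abc123"
assert parse_bearer("bearer abc123") == "abc123"  # case-insensitive scheme
assert parse_bearer("Basic abc123") is None       # wrong scheme
assert parse_bearer("Bearer") is None             # missing token
```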

@@ -1,9 +1,12 @@
+a2wsgi==1.10.10
 alembic>=1.13,<2
 anthropic==0.88.0
+asgiref>=3.11.1
 boto3==1.42.83
 beautifulsoup4==4.14.3
 cel-python==0.5.0
 celery==5.6.3
+celery-redbeat==2.3.3
 cryptography==46.0.7
 dataclasses-json==0.6.7
 defusedxml==0.7.1
@@ -14,7 +17,7 @@ docx2txt==0.9
 ddgs>=8.0.0
 fast-ebook
 elevenlabs==2.43.0
-Flask==3.1.3
+Flask==3.1.1
 faiss-cpu==1.13.2
 fastmcp==3.2.4
 flask-restx==1.3.2
@@ -49,6 +52,16 @@ networkx==3.6.1
 numpy==2.4.4
 openai==2.32.0
 openapi3-parser==1.1.22
+opentelemetry-distro>=0.50b0,<1
+opentelemetry-exporter-otlp>=1.29.0,<2
+opentelemetry-instrumentation-celery>=0.50b0,<1
+opentelemetry-instrumentation-flask>=0.50b0,<1
+opentelemetry-instrumentation-logging>=0.50b0,<1
+opentelemetry-instrumentation-psycopg>=0.50b0,<1
+opentelemetry-instrumentation-redis>=0.50b0,<1
+opentelemetry-instrumentation-requests>=0.50b0,<1
+opentelemetry-instrumentation-sqlalchemy>=0.50b0,<1
+opentelemetry-instrumentation-starlette>=0.50b0,<1
 orjson==3.11.7
 packaging==26.0
 pandas==3.0.2
@@ -58,7 +71,7 @@ pdf2image>=1.17.0
 pillow
 portalocker>=2.7.0,<4.0.0
 prompt-toolkit==3.0.52
-protobuf==7.34.1
+protobuf==6.33.6
 psycopg[binary,pool]>=3.1,<4
 py==1.11.0
 pydantic
@@ -69,6 +82,7 @@ python-dateutil==2.9.0.post0
 python-dotenv
 python-jose==3.5.0
 python-pptx==1.0.2
+PyYAML
 redis==7.4.0
 referencing>=0.28.0,<0.38.0
 regex==2026.4.4
@@ -76,6 +90,7 @@ requests==2.33.1
 retry==0.9.2
 sentence-transformers==5.3.0
 sqlalchemy>=2.0,<3
+starlette>=1.0,<2
 tiktoken==0.12.0
 tokenizers==0.22.2
 torch==2.11.0
@@ -85,6 +100,8 @@ typing-extensions==4.15.0
 typing-inspect==0.9.0
 tzdata==2026.1
 urllib3==2.6.3
+uvicorn[standard]>=0.30,<1
+uvicorn-worker>=0.4,<1
 vine==5.1.0
 wcwidth==0.6.0
 werkzeug>=3.1.0

@@ -22,6 +22,7 @@ class ClassicRAG(BaseRetriever):
         llm_name=settings.LLM_PROVIDER,
         api_key=settings.API_KEY,
         decoded_token=None,
+        model_user_id=None,
     ):
         self.original_question = source.get("question", "")
         self.chat_history = chat_history if chat_history is not None else []
@@ -42,17 +43,22 @@ class ClassicRAG(BaseRetriever):
             f"sources={'active_docs' in source and source['active_docs'] is not None}"
         )
         self.model_id = model_id
+        self.model_user_id = model_user_id
        self.doc_token_limit = doc_token_limit
        self.user_api_key = user_api_key
        self.agent_id = agent_id
        self.llm_name = llm_name
        self.api_key = api_key
+        # Forward model_id + model_user_id so LLMCreator resolves BYOM
+        # base_url / api_key / upstream id for the rephrase client.
        self.llm = LLMCreator.create_llm(
            self.llm_name,
            api_key=self.api_key,
            user_api_key=self.user_api_key,
            decoded_token=decoded_token,
+           model_id=self.model_id,
            agent_id=self.agent_id,
+           model_user_id=self.model_user_id,
        )

        if "active_docs" in source and source["active_docs"] is not None:
@@ -103,7 +109,11 @@ class ClassicRAG(BaseRetriever):
         ]

         try:
-            rephrased_query = self.llm.gen(model=self.model_id, messages=messages)
+            # Send upstream id (resolved by LLMCreator), not registry UUID.
+            rephrased_query = self.llm.gen(
+                model=getattr(self.llm, "model_id", None) or self.model_id,
+                messages=messages,
+            )
             print(f"Rephrased query: {rephrased_query}")
             return rephrased_query if rephrased_query else self.original_question
         except Exception as e:
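Why the rephrase call prefers `getattr(self.llm, "model_id", None)`: for BYOM the retriever holds the registry UUID, while LLMCreator already rewrote the LLM's model_id to the upstream name. Illustrative stand-ins only:

```python
class _LLMStub:
    model_id = "llama-3.1-8b-instruct"  # upstream name set by LLMCreator

llm = _LLMStub()
registry_uuid = "0b1f-example-byom-uuid"  # what the retriever was given

model_for_api = getattr(llm, "model_id", None) or registry_uuid
assert model_for_api == "llama-3.1-8b-instruct"
```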

application/security/safe_url.py (new file, 464 lines)
@@ -0,0 +1,464 @@
+"""SSRF protection for user-supplied OpenAI-compatible base URLs.
+
+This module is the single chokepoint for validating any URL that a user
+provides as an OpenAI-compatible ``base_url`` ("Bring Your Own Model").
+The backend will later issue outbound HTTP requests to that URL on the
+user's behalf, so we must reject anything that could be used to reach
+internal-network resources (cloud metadata services, RFC 1918 ranges,
+loopback, link-local, etc.).
+
+Three entry points:
+
+* :func:`validate_user_base_url` — called at create/update time on REST
+  routes that persist the URL, to give the user immediate feedback.
+* :func:`pinned_post` — called at dispatch time when the caller drives
+  ``requests`` directly (e.g. the ``/api/models/test`` endpoint).
+  Resolves once, dials the IP literal, preserves the original hostname
+  in the ``Host`` header and via SNI / cert verification for HTTPS.
+* :func:`pinned_httpx_client` — called at dispatch time when the caller
+  hands an ``httpx.Client`` to a third-party SDK (e.g. the OpenAI
+  Python SDK via ``OpenAI(http_client=...)``). Same DNS-rebinding
+  closure on the httpx transport layer.
+
+Why all three: the OpenAI / httpx ecosystem performs its own DNS lookup
+inside ``socket.getaddrinfo`` when a connection opens, so a hostile DNS
+server can hand a public IP to the validator and a loopback / link-local
+address to the HTTP client. Validate-then-construct-SDK is unsafe; the
+pinned variants close that TOCTOU window by resolving exactly once and
+dialing the chosen IP literal directly.
+"""
+
+from __future__ import annotations
+
+import ipaddress
+import socket
+from typing import Any, Iterable
+from urllib.parse import urlsplit, urlunsplit
+
+import httpx
+import requests
+from requests.adapters import HTTPAdapter
+
+# Allowed URL schemes. Anything else (file, gopher, ftp, data, ...) is
+# rejected outright because it either bypasses HTTP entirely or enables
+# protocol smuggling against the proxy stack.
+_ALLOWED_SCHEMES: frozenset[str] = frozenset({"http", "https"})
+
+# Hostnames that resolve to a loopback / metadata / unspecified address
+# but which we want to reject *by name* as well, so the rejection
+# message is unambiguous and so we never accidentally call DNS on them.
+_BLOCKED_HOSTNAMES: frozenset[str] = frozenset(
+    {
+        "localhost",
+        "localhost.localdomain",
+        "0.0.0.0",
+        "::",
+        "::1",
+        "ip6-localhost",
+        "ip6-loopback",
+        # GCP metadata service. AWS/Azure use 169.254.169.254 which the
+        # IP-range check below already covers via the link-local range,
+        # but Google's hostname does not always resolve to a link-local
+        # IP from every VPC, so we hard-deny the string too.
+        "metadata.google.internal",
+    }
+)
+
+# Carrier-grade NAT (RFC 6598). Python's ``ipaddress`` module does NOT
+# classify this range as ``is_private``, so we must check it explicitly.
+_CGNAT_NETWORK_V4: ipaddress.IPv4Network = ipaddress.IPv4Network("100.64.0.0/10")
+
+
+class UnsafeUserUrlError(ValueError):
+    """Raised when a user-supplied URL fails SSRF validation.
+
+    Subclasses :class:`ValueError` so call sites that already treat
+    invalid input as a 400-class error continue to work. The string
+    message names the specific reason (scheme, hostname, resolved IP,
+    DNS failure, ...) so that it can be surfaced to the user verbatim.
+    """
+
+
+def _strip_ipv6_brackets(host: str) -> str:
+    """Return ``host`` with surrounding ``[`` / ``]`` removed if present."""
+    if host.startswith("[") and host.endswith("]"):
+        return host[1:-1]
+    return host
+
+
+def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
+    """Return ``True`` if ``ip`` falls in any range we refuse to dial.
+
+    This is the single source of truth for the IP-range policy:
+
+    * loopback (``127.0.0.0/8``, ``::1``)
+    * private (RFC 1918, ULA ``fc00::/7``)
+    * link-local (``169.254.0.0/16``, ``fe80::/10``)
+    * multicast (``224.0.0.0/4``, ``ff00::/8``)
+    * unspecified (``0.0.0.0``, ``::``)
+    * reserved (``240.0.0.0/4``, etc.)
+    * carrier-grade NAT (``100.64.0.0/10``) — not covered by ``is_private``
+    """
+    if (
+        ip.is_loopback
+        or ip.is_private
+        or ip.is_link_local
+        or ip.is_multicast
+        or ip.is_unspecified
+        or ip.is_reserved
+    ):
+        return True
+    if isinstance(ip, ipaddress.IPv4Address) and ip in _CGNAT_NETWORK_V4:
+        return True
+    return False
|
|
||||||
|
|
||||||
|
def _resolve(host: str) -> Iterable[ipaddress.IPv4Address | ipaddress.IPv6Address]:
|
||||||
|
"""Resolve ``host`` to every A/AAAA record returned by the system.
|
||||||
|
|
||||||
|
Returning *all* addresses (rather than the first one) is critical:
|
||||||
|
a hostile DNS server can return a public IP first followed by a
|
||||||
|
private IP, and the underlying HTTP client may fail over to the
|
||||||
|
private one on connect. We treat the set as unsafe if any element
|
||||||
|
is unsafe.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = socket.getaddrinfo(host, None)
|
||||||
|
except socket.gaierror as exc: # noqa: PERF203 — re-raise as our own type
|
||||||
|
raise UnsafeUserUrlError(f"could not resolve hostname {host!r}: {exc}") from exc
|
||||||
|
|
||||||
|
addresses: list[ipaddress.IPv4Address | ipaddress.IPv6Address] = []
|
||||||
|
for entry in results:
|
||||||
|
sockaddr = entry[4]
|
||||||
|
# IPv4 sockaddr: (host, port). IPv6 sockaddr: (host, port, flowinfo, scope_id).
|
||||||
|
ip_str = sockaddr[0]
|
||||||
|
# Strip IPv6 zone-id ("fe80::1%lo0") before parsing.
|
||||||
|
if "%" in ip_str:
|
||||||
|
ip_str = ip_str.split("%", 1)[0]
|
||||||
|
try:
|
||||||
|
addresses.append(ipaddress.ip_address(ip_str))
|
||||||
|
except ValueError:
|
||||||
|
# An entry we can't parse is itself suspicious; treat as unsafe.
|
||||||
|
raise UnsafeUserUrlError(
|
||||||
|
f"hostname {host!r} resolved to unparseable address {ip_str!r}"
|
||||||
|
) from None
|
||||||
|
return addresses
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_and_pick_ip(
|
||||||
|
url: str,
|
||||||
|
) -> tuple[str, ipaddress.IPv4Address | ipaddress.IPv6Address, "urlsplit"]:
|
||||||
|
"""Run the SSRF guard and return the data needed to dial safely.
|
||||||
|
|
||||||
|
Performs every check :func:`validate_user_base_url` performs, but
|
||||||
|
additionally returns ``(hostname, ip, parts)`` where ``ip`` is one
|
||||||
|
of the validated addresses (the first record returned by the
|
||||||
|
resolver, or the literal itself if the URL already used an IP) and
|
||||||
|
``parts`` is the :func:`urllib.parse.urlsplit` result so callers do
|
||||||
|
not have to re-parse the URL.
|
||||||
|
|
||||||
|
Raises :class:`UnsafeUserUrlError` on the same conditions as
|
||||||
|
:func:`validate_user_base_url`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(url, str) or not url.strip():
|
||||||
|
raise UnsafeUserUrlError("url must be a non-empty string")
|
||||||
|
|
||||||
|
try:
|
||||||
|
parts = urlsplit(url)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise UnsafeUserUrlError(f"could not parse url {url!r}: {exc}") from exc
|
||||||
|
|
||||||
|
scheme = parts.scheme.lower()
|
||||||
|
if scheme not in _ALLOWED_SCHEMES:
|
||||||
|
raise UnsafeUserUrlError(
|
||||||
|
f"scheme {scheme!r} is not allowed; only http and https are permitted"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ``urlsplit`` returns the bracketed form for IPv6 in ``netloc`` but
|
||||||
|
# the bare form in ``hostname``. Normalize via lower() because
|
||||||
|
# hostnames are case-insensitive and we compare against a lowercase
|
||||||
|
# blocklist.
|
||||||
|
raw_host = parts.hostname
|
||||||
|
if not raw_host:
|
||||||
|
raise UnsafeUserUrlError(f"url {url!r} has no hostname")
|
||||||
|
|
||||||
|
host = raw_host.lower()
|
||||||
|
|
||||||
|
# Check the literal-string blocklist first. urlsplit().hostname strips
|
||||||
|
# IPv6 brackets, so we also test the bracketed form for completeness
|
||||||
|
# (matches the public-spec note about ``[::]``).
|
||||||
|
bracketed = f"[{host}]"
|
||||||
|
if host in _BLOCKED_HOSTNAMES or bracketed in _BLOCKED_HOSTNAMES:
|
||||||
|
raise UnsafeUserUrlError(
|
||||||
|
f"hostname {raw_host!r} is not allowed (matches internal-only name)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If the host is already an IP literal (with or without IPv6 brackets),
|
||||||
|
# check it directly without going to DNS — DNS for an IP literal is a
|
||||||
|
# no-op but it's clearer to short-circuit and gives a better message.
|
||||||
|
candidate = _strip_ipv6_brackets(host)
|
||||||
|
try:
|
||||||
|
literal = ipaddress.ip_address(candidate)
|
||||||
|
except ValueError:
|
||||||
|
literal = None
|
||||||
|
|
||||||
|
if literal is not None:
|
||||||
|
if _is_blocked_ip(literal):
|
||||||
|
raise UnsafeUserUrlError(
|
||||||
|
f"hostname {raw_host!r} resolves to blocked address {literal} "
|
||||||
|
f"(loopback/private/link-local/multicast/reserved/CGNAT)"
|
||||||
|
)
|
||||||
|
return host, literal, parts
|
||||||
|
|
||||||
|
# Hostname (not an IP literal) — resolve and validate every record.
|
||||||
|
addresses = list(_resolve(host))
|
||||||
|
for ip in addresses:
|
||||||
|
if _is_blocked_ip(ip):
|
||||||
|
raise UnsafeUserUrlError(
|
||||||
|
f"hostname {raw_host!r} resolves to blocked address {ip} "
|
||||||
|
f"(loopback/private/link-local/multicast/reserved/CGNAT)"
|
||||||
|
)
|
||||||
|
if not addresses:
|
||||||
|
# ``getaddrinfo`` would normally raise instead of returning an
|
||||||
|
# empty list, but treat the degenerate case as unsafe too — we
|
||||||
|
# have nothing to bind a connection to.
|
||||||
|
raise UnsafeUserUrlError(
|
||||||
|
f"hostname {raw_host!r} returned no addresses from DNS"
|
||||||
|
)
|
||||||
|
return host, addresses[0], parts
|
||||||
|
|
||||||
|
|
||||||
|
def validate_user_base_url(url: str) -> None:
|
||||||
|
"""Validate that ``url`` is safe to use as an outbound base URL.
|
||||||
|
|
||||||
|
Resolve the URL's hostname to one or more IPs and reject if any
|
||||||
|
resolved IP is private/loopback/link-local/multicast/reserved, or if
|
||||||
|
the URL uses a non-http(s) scheme, or if the hostname is one of the
|
||||||
|
known dangerous strings (``localhost``, ``0.0.0.0``, ``[::]``).
|
||||||
|
|
||||||
|
Raises :class:`UnsafeUserUrlError` on rejection. Returns ``None`` on
|
||||||
|
success.
|
||||||
|
|
||||||
|
This function is the create/update-time check. At dispatch time use
|
||||||
|
:func:`pinned_post` instead, which performs the same validation
|
||||||
|
*and* pins the outbound connection to the validated IP so a DNS
|
||||||
|
rebinder cannot flip the resolution between check and connect.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The user-supplied URL to validate. Expected to be an
|
||||||
|
absolute URL with an ``http`` or ``https`` scheme.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
UnsafeUserUrlError: If the URL fails to parse, uses a forbidden
|
||||||
|
scheme, has an empty/blocklisted hostname, fails DNS
|
||||||
|
resolution, or resolves to any IP in a blocked range.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_validate_and_pick_ip(url)
|
||||||
|
|
||||||
|
|
||||||
|
class _PinnedHostAdapter(HTTPAdapter):
|
||||||
|
"""HTTPS adapter that performs SNI and cert verification against a
|
||||||
|
fixed hostname even when the URL connects to an IP literal.
|
||||||
|
|
||||||
|
Used by :func:`pinned_post` so that resolving the user-supplied
|
||||||
|
hostname once and dialing the resolved IP doesn't break TLS.
|
||||||
|
Without this, ``urllib3`` would default ``server_hostname`` /
|
||||||
|
``assert_hostname`` to the connect host (the IP) and either send the
|
||||||
|
wrong SNI or fail cert verification — the cert is for the original
|
||||||
|
hostname, not the IP literal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, server_hostname: str, *args: Any, **kwargs: Any) -> None:
|
||||||
|
self._server_hostname = server_hostname
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def init_poolmanager(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
kwargs["server_hostname"] = self._server_hostname
|
||||||
|
kwargs["assert_hostname"] = self._server_hostname
|
||||||
|
super().init_poolmanager(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _ip_to_url_host(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> str:
|
||||||
|
"""Return ``ip`` formatted for use in a URL netloc (brackets for v6)."""
|
||||||
|
|
||||||
|
if isinstance(ip, ipaddress.IPv6Address):
|
||||||
|
return f"[{ip}]"
|
||||||
|
return str(ip)
|
||||||
|
|
||||||
|
|
||||||
|
def pinned_post(
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
json: Any = None,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
timeout: float = 5.0,
|
||||||
|
allow_redirects: bool = False,
|
||||||
|
) -> requests.Response:
|
||||||
|
"""POST to ``url`` with the outbound connection pinned to a single
|
||||||
|
validated IP, closing the DNS-rebinding TOCTOU window left by the
|
||||||
|
naive validate-then-``requests.post`` pattern.
|
||||||
|
|
||||||
|
The URL's hostname is resolved exactly once. Every returned address
|
||||||
|
must pass the same SSRF guard as :func:`validate_user_base_url`. The
|
||||||
|
outbound request is issued against the chosen IP literal (so
|
||||||
|
``urllib3`` cannot ask the resolver again and receive a different
|
||||||
|
answer); the original hostname is preserved in the ``Host`` header
|
||||||
|
and, for HTTPS, via :class:`_PinnedHostAdapter` for SNI and cert
|
||||||
|
verification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: Absolute http(s) URL to POST to.
|
||||||
|
json: JSON-serializable payload — passed through to ``requests``.
|
||||||
|
headers: Caller-supplied headers. Any caller-supplied ``Host``
|
||||||
|
entry is overwritten so the in-flight request matches what
|
||||||
|
was validated.
|
||||||
|
timeout: Per-request timeout (seconds).
|
||||||
|
allow_redirects: Forwarded to ``requests``. Defaults to
|
||||||
|
``False`` because the SSRF guard only inspects the supplied
|
||||||
|
URL — following redirects would let a hostile upstream
|
||||||
|
bounce the request to an internal address.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
UnsafeUserUrlError: If the URL fails the SSRF guard.
|
||||||
|
requests.RequestException: For network-level failures.
|
||||||
|
"""
|
||||||
|
|
||||||
|
host, ip, parts = _validate_and_pick_ip(url)
|
||||||
|
|
||||||
|
netloc = _ip_to_url_host(ip)
|
||||||
|
if parts.port is not None:
|
||||||
|
netloc = f"{netloc}:{parts.port}"
|
||||||
|
pinned_url = urlunsplit(
|
||||||
|
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
|
||||||
|
)
|
||||||
|
|
||||||
|
request_headers = dict(headers or {})
|
||||||
|
host_header = host if parts.port is None else f"{host}:{parts.port}"
|
||||||
|
request_headers["Host"] = host_header
|
||||||
|
|
||||||
|
session = requests.Session()
|
||||||
|
if parts.scheme == "https":
|
||||||
|
session.mount("https://", _PinnedHostAdapter(host))
|
||||||
|
try:
|
||||||
|
return session.post(
|
||||||
|
pinned_url,
|
||||||
|
json=json,
|
||||||
|
headers=request_headers,
|
||||||
|
timeout=timeout,
|
||||||
|
allow_redirects=allow_redirects,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
|
class _PinnedHTTPSTransport(httpx.HTTPTransport):
|
||||||
|
"""``httpx`` transport pinned to a single validated IP literal.
|
||||||
|
|
||||||
|
Closes the DNS-rebinding TOCTOU window that
|
||||||
|
:func:`validate_user_base_url` cannot close on its own. The OpenAI
|
||||||
|
Python SDK (and any other SDK that uses ``httpx``) re-resolves the
|
||||||
|
hostname inside ``socket.getaddrinfo`` at request time, so a
|
||||||
|
hostile DNS server can return a public IP at validation time and a
|
||||||
|
private IP at request time. This transport rewrites every outgoing
|
||||||
|
request's URL host to the validated IP literal so ``httpcore``
|
||||||
|
dials that IP without a fresh lookup.
|
||||||
|
|
||||||
|
The original hostname is preserved in two places:
|
||||||
|
|
||||||
|
1. ``Host`` header — ``httpx.Request._prepare`` set it from the URL
|
||||||
|
netloc *before* this transport runs, so it carries the hostname
|
||||||
|
not the IP literal. We deliberately do not touch headers here.
|
||||||
|
2. TLS SNI / cert verification — set via the
|
||||||
|
``request.extensions["sni_hostname"]`` extension which
|
||||||
|
``httpcore`` feeds into ``start_tls``'s ``server_hostname``
|
||||||
|
parameter. Without this, ``urllib3``-equivalent code would use
|
||||||
|
the IP literal as SNI and cert verification would fail (the
|
||||||
|
cert is for the original hostname, not the IP).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
validated_host: str,
|
||||||
|
validated_ip: ipaddress.IPv4Address | ipaddress.IPv6Address,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
# http2=False (the httpx default) — defense in depth against
|
||||||
|
# HTTP/2 connection coalescing (RFC 7540 §9.1.1), where a
|
||||||
|
# client may reuse a TCP connection for any host whose cert
|
||||||
|
# covers it. Per-IP pinning never shares connections across
|
||||||
|
# hosts, but explicit is safer than relying on the default.
|
||||||
|
kwargs.setdefault("http2", False)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._host = validated_host
|
||||||
|
self._ip_netloc = _ip_to_url_host(validated_ip)
|
||||||
|
|
||||||
|
def handle_request(self, request: httpx.Request) -> httpx.Response:
|
||||||
|
# Defense in depth: refuse if the request URL's host doesn't
|
||||||
|
# match what we validated. Catches any future SDK regression
|
||||||
|
# that rewrites the URL between Request construction and dial,
|
||||||
|
# and any rare case where the SDK reuses our pinned client for
|
||||||
|
# a different host (which it shouldn't, but assert it anyway).
|
||||||
|
if request.url.host != self._host:
|
||||||
|
raise UnsafeUserUrlError(
|
||||||
|
f"pinned transport bound to {self._host!r}, refused "
|
||||||
|
f"request for {request.url.host!r}"
|
||||||
|
)
|
||||||
|
# SNI/server_hostname for TLS verification. httpcore reads this
|
||||||
|
# extension at _sync/connection.py and feeds it into
|
||||||
|
# start_tls's server_hostname argument. Set before the URL host
|
||||||
|
# is rewritten so cert validation continues to use the original
|
||||||
|
# hostname even though TCP dials the IP literal.
|
||||||
|
request.extensions = {
|
||||||
|
**request.extensions,
|
||||||
|
"sni_hostname": self._host.encode("ascii"),
|
||||||
|
}
|
||||||
|
request.url = request.url.copy_with(host=self._ip_netloc)
|
||||||
|
return super().handle_request(request)
|
||||||
|
|
||||||
|
|
||||||
|
def pinned_httpx_client(
|
||||||
|
base_url: str,
|
||||||
|
*,
|
||||||
|
timeout: float = 600.0,
|
||||||
|
) -> httpx.Client:
|
||||||
|
"""Return an :class:`httpx.Client` whose connections are pinned to
|
||||||
|
one validated IP, closing the DNS-rebinding TOCTOU window the naive
|
||||||
|
``OpenAI(base_url=...)`` flow leaves open.
|
||||||
|
|
||||||
|
The hostname in ``base_url`` is resolved exactly once. Every
|
||||||
|
returned address must pass :func:`_validate_and_pick_ip`'s SSRF
|
||||||
|
guard (loopback, RFC 1918, link-local, multicast, reserved, CGNAT,
|
||||||
|
cloud metadata names). The chosen IP becomes the URL host on every
|
||||||
|
outgoing request so ``httpcore`` cannot ask the resolver again.
|
||||||
|
|
||||||
|
Pass via ``OpenAI(http_client=pinned_httpx_client(base_url))`` (or
|
||||||
|
any other SDK that accepts an ``httpx.Client``) to make BYOM
|
||||||
|
dispatch immune to DNS-rebinding TOCTOU.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: User-supplied http(s) URL. Validated through the same
|
||||||
|
SSRF guard as :func:`validate_user_base_url`.
|
||||||
|
timeout: Per-request timeout (seconds). Defaults to 600 to
|
||||||
|
match the OpenAI SDK's default; callers should override
|
||||||
|
for non-LLM workloads.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
UnsafeUserUrlError: If ``base_url`` fails the SSRF guard.
|
||||||
|
"""
|
||||||
|
|
||||||
|
host, ip, _parts = _validate_and_pick_ip(base_url)
|
||||||
|
transport = _PinnedHTTPSTransport(host, ip)
|
||||||
|
# follow_redirects=False — the SSRF guard only inspects the
|
||||||
|
# supplied URL; following 3xx would let a hostile upstream bounce
|
||||||
|
# the in-network request to an internal address (cloud metadata,
|
||||||
|
# RFC1918, loopback) carrying whatever credentials the SDK adds.
|
||||||
|
return httpx.Client(
|
||||||
|
transport=transport,
|
||||||
|
timeout=timeout,
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
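
For orientation, a caller would combine the two checkpoints the module docstring
names (persist-time validation, then dispatch-time pinning) roughly as below. A
minimal sketch: the import path and the ``ping`` payload are illustrative
assumptions, not taken from this diff.

from application.security.url_guard import (  # import path assumed
    UnsafeUserUrlError,
    pinned_post,
    validate_user_base_url,
)


def save_base_url(url: str) -> None:
    # Create/update time: fail fast with a user-visible reason.
    validate_user_base_url(url)  # raises UnsafeUserUrlError on rejection


def probe_endpoint(url: str) -> int:
    # Dispatch time: resolve once, dial the pinned IP, never re-resolve.
    try:
        resp = pinned_post(url, json={"ping": True}, timeout=5.0)
    except UnsafeUserUrlError:
        return 400
    return resp.status_code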

application/services/search_service.py (new file, 153 lines)
@@ -0,0 +1,153 @@
"""Shared retrieval service used by the HTTP search route and the MCP tool.
|
||||||
|
|
||||||
|
Flask-free. Raises domain exceptions (``InvalidAPIKey``, ``SearchFailed``)
|
||||||
|
that callers translate into their own wire protocol (HTTP status codes,
|
||||||
|
MCP error responses, etc.).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.storage.db.repositories.agents import AgentsRepository
|
||||||
|
from application.storage.db.session import db_readonly
|
||||||
|
from application.vectorstore.vector_creator import VectorCreator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidAPIKey(Exception):
|
||||||
|
"""The supplied ``api_key`` does not resolve to an agent."""
|
||||||
|
|
||||||
|
|
||||||
|
class SearchFailed(Exception):
|
||||||
|
"""Unexpected error during retrieval (e.g. DB outage). Caller maps to 5xx."""
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_source_ids(agent: Dict[str, Any]) -> List[str]:
|
||||||
|
"""Extract the ordered list of source UUIDs to search.
|
||||||
|
|
||||||
|
Prefers ``extra_source_ids`` (PG ARRAY(UUID) of multi-source agents);
|
||||||
|
falls back to the legacy single ``source_id`` field.
|
||||||
|
"""
|
||||||
|
source_ids: List[str] = []
|
||||||
|
extra = agent.get("extra_source_ids") or []
|
||||||
|
for src in extra:
|
||||||
|
if src:
|
||||||
|
source_ids.append(str(src))
|
||||||
|
if not source_ids:
|
||||||
|
single = agent.get("source_id")
|
||||||
|
if single:
|
||||||
|
source_ids.append(str(single))
|
||||||
|
return source_ids
|
||||||
|
|
||||||
|
|
||||||
|
def _search_sources(
|
||||||
|
query: str, source_ids: List[str], chunks: int
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Search across each source's vectorstore and return up to ``chunks`` hits.
|
||||||
|
|
||||||
|
Per-source errors are logged and skipped so one broken index doesn't
|
||||||
|
take down the whole search. Results are de-duplicated by content hash.
|
||||||
|
"""
|
||||||
|
if chunks <= 0 or not source_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
results: List[Dict[str, Any]] = []
|
||||||
|
chunks_per_source = max(1, chunks // len(source_ids))
|
||||||
|
seen_texts: set[int] = set()
|
||||||
|
|
||||||
|
for source_id in source_ids:
|
||||||
|
if not source_id or not source_id.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
docsearch = VectorCreator.create_vectorstore(
|
||||||
|
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
|
||||||
|
)
|
||||||
|
docs = docsearch.search(query, k=chunks_per_source * 2)
|
||||||
|
|
||||||
|
for doc in docs:
|
||||||
|
if len(results) >= chunks:
|
||||||
|
break
|
||||||
|
|
||||||
|
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
|
||||||
|
page_content = doc.page_content
|
||||||
|
metadata = doc.metadata
|
||||||
|
else:
|
||||||
|
page_content = doc.get("text", doc.get("page_content", ""))
|
||||||
|
metadata = doc.get("metadata", {})
|
||||||
|
|
||||||
|
text_hash = hash(page_content[:200])
|
||||||
|
if text_hash in seen_texts:
|
||||||
|
continue
|
||||||
|
seen_texts.add(text_hash)
|
||||||
|
|
||||||
|
title = metadata.get("title", metadata.get("post_title", ""))
|
||||||
|
if not isinstance(title, str):
|
||||||
|
title = str(title) if title else ""
|
||||||
|
|
||||||
|
if title:
|
||||||
|
title = title.split("/")[-1]
|
||||||
|
else:
|
||||||
|
title = metadata.get("filename", page_content[:50] + "...")
|
||||||
|
|
||||||
|
source = metadata.get("source", source_id)
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"text": page_content,
|
||||||
|
"title": title,
|
||||||
|
"source": source,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(results) >= chunks:
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error searching vectorstore {source_id}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return results[:chunks]
|
||||||
|
|
||||||
|
|
||||||
|
def search(api_key: str, query: str, chunks: int = 5) -> List[Dict[str, Any]]:
|
||||||
|
"""Resolve an agent by API key and search its sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: Agent API key (the opaque string stored on
|
||||||
|
``agents.key`` in Postgres).
|
||||||
|
query: Free-text search query.
|
||||||
|
chunks: Max number of hits to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of hit dicts with ``text``, ``title``, ``source`` keys.
|
||||||
|
Empty list if the agent has no sources configured.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidAPIKey: if ``api_key`` does not resolve to an agent.
|
||||||
|
SearchFailed: on unexpected DB / infrastructure errors.
|
||||||
|
"""
|
||||||
|
if chunks <= 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
with db_readonly() as conn:
|
||||||
|
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||||
|
except Exception as e:
|
||||||
|
raise SearchFailed("agent lookup failed") from e
|
||||||
|
|
||||||
|
if not agent:
|
||||||
|
raise InvalidAPIKey()
|
||||||
|
|
||||||
|
source_ids = _collect_source_ids(agent)
|
||||||
|
if not source_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return _search_sources(query, source_ids, chunks)
|
||||||
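
A sketch of the HTTP-side translation the module docstring describes: domain
exceptions mapped onto status codes. The route shape, request field names, and
status codes are assumptions for illustration; only the exception contract
comes from the module above.

from flask import jsonify, request

from application.services import search_service


def search_route():
    data = request.get_json() or {}
    try:
        hits = search_service.search(
            api_key=data.get("api_key", ""),
            query=data.get("question", ""),
            chunks=int(data.get("chunks", 5)),
        )
    except search_service.InvalidAPIKey:
        return jsonify({"error": "invalid API key"}), 401
    except search_service.SearchFailed:
        return jsonify({"error": "search failed"}), 500
    return jsonify(hits), 200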

@@ -117,6 +117,16 @@ stack_logs_table = Table(
     Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
 )
+
+# Singleton key/value table for instance-wide state (e.g. anonymous
+# instance UUID, one-shot notice flags). Added in migration
+# ``0002_app_metadata``.
+app_metadata_table = Table(
+    "app_metadata",
+    metadata,
+    Column("key", Text, primary_key=True),
+    Column("value", Text, nullable=False),
+)
 
 
 # --- Phase 2, Tier 2 --------------------------------------------------------
 
@@ -193,6 +203,24 @@ agents_table = Table(
     Column("legacy_mongo_id", Text),
 )
+
+user_custom_models_table = Table(
+    "user_custom_models",
+    metadata,
+    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
+    Column("user_id", Text, nullable=False),
+    Column("upstream_model_id", Text, nullable=False),
+    Column("display_name", Text, nullable=False),
+    Column("description", Text, nullable=False, server_default=""),
+    Column("base_url", Text, nullable=False),
+    # AES-CBC ciphertext (base64) keyed via per-user PBKDF2 in
+    # application.security.encryption.encrypt_credentials.
+    Column("api_key_encrypted", Text, nullable=False),
+    Column("capabilities", JSONB, nullable=False, server_default="{}"),
+    Column("enabled", Boolean, nullable=False, server_default="true"),
+    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
+    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
+)
 
 attachments_table = Table(
     "attachments",
     metadata,

@@ -1,7 +1,6 @@
 """Repository for the ``agents`` table.
 
-This is the most complex Phase 2 repository. Covers every write operation
-the legacy Mongo code performs on ``agents_collection``:
+Covers every write operation the legacy Mongo code performs on ``agents_collection``:
 
 - create, update, delete
 - find by key (API key lookup)

application/storage/db/repositories/app_metadata.py (new file, 60 lines)
@@ -0,0 +1,60 @@
"""Repository for the ``app_metadata`` singleton key/value table.
|
||||||
|
|
||||||
|
Owns the instance-wide state the version-check client needs:
|
||||||
|
``instance_id`` (anonymous UUID sent with each check) and
|
||||||
|
``version_check_notice_shown`` (one-shot flag for the first-run
|
||||||
|
telemetry notice). Kept deliberately generic so future one-off config
|
||||||
|
values can piggyback without a new migration each time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
|
||||||
|
|
||||||
|
class AppMetadataRepository:
|
||||||
|
"""Postgres-backed ``app_metadata`` store. Tiny by design."""
|
||||||
|
|
||||||
|
def __init__(self, conn: Connection) -> None:
|
||||||
|
self._conn = conn
|
||||||
|
|
||||||
|
def get(self, key: str) -> Optional[str]:
|
||||||
|
row = self._conn.execute(
|
||||||
|
text("SELECT value FROM app_metadata WHERE key = :key"),
|
||||||
|
{"key": key},
|
||||||
|
).fetchone()
|
||||||
|
return row[0] if row is not None else None
|
||||||
|
|
||||||
|
def set(self, key: str, value: str) -> None:
|
||||||
|
self._conn.execute(
|
||||||
|
text(
|
||||||
|
"INSERT INTO app_metadata (key, value) VALUES (:key, :value) "
|
||||||
|
"ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value"
|
||||||
|
),
|
||||||
|
{"key": key, "value": value},
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_or_create_instance_id(self) -> str:
|
||||||
|
"""Return the anonymous instance UUID, generating one if absent.
|
||||||
|
|
||||||
|
Uses ``INSERT ... ON CONFLICT DO NOTHING`` + re-read so two
|
||||||
|
workers racing on the very first startup converge on a single
|
||||||
|
UUID instead of each persisting their own.
|
||||||
|
"""
|
||||||
|
existing = self.get("instance_id")
|
||||||
|
if existing:
|
||||||
|
return existing
|
||||||
|
candidate = str(uuid.uuid4())
|
||||||
|
self._conn.execute(
|
||||||
|
text(
|
||||||
|
"INSERT INTO app_metadata (key, value) VALUES ('instance_id', :value) "
|
||||||
|
"ON CONFLICT (key) DO NOTHING"
|
||||||
|
),
|
||||||
|
{"value": candidate},
|
||||||
|
)
|
||||||
|
# Re-read: if another worker won the race, their UUID is now authoritative.
|
||||||
|
winner = self.get("instance_id")
|
||||||
|
return winner or candidate
|
||||||
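
The intended call pattern mirrors ``_resolve_instance_id_and_notice`` in
``application/updates/version_check.py`` later in this diff: one short
transaction covers the UUID read/create and the one-shot notice flag.

from application.storage.db.repositories.app_metadata import AppMetadataRepository
from application.storage.db.session import db_session

with db_session() as conn:
    repo = AppMetadataRepository(conn)
    instance_id = repo.get_or_create_instance_id()  # stable across racing workers
    if repo.get("version_check_notice_shown") is None:
        repo.set("version_check_notice_shown", "1")  # one-shot flag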

@@ -17,6 +17,21 @@ _UPDATABLE_SCALARS = {
 _UPDATABLE_JSONB = {"metadata"}
 
 
+def _attachment_to_dict(row: Any) -> dict:
+    """row_to_dict + ``upload_path`` → ``path`` alias.
+
+    Pre-Postgres, the Mongo attachment shape used ``path``. The PG column
+    is ``upload_path``; LLM provider code (google_ai/openai/anthropic and
+    handlers/base) still reads ``attachment.get("path")``. Mirrors the
+    ``id``/``_id`` dual-emit in row_to_dict so consumers don't need to
+    know which storage backend produced the dict.
+    """
+    out = row_to_dict(row)
+    if "upload_path" in out and out.get("path") is None:
+        out["path"] = out["upload_path"]
+    return out
+
+
 class AttachmentsRepository:
     def __init__(self, conn: Connection) -> None:
         self._conn = conn
@@ -66,7 +81,7 @@ class AttachmentsRepository:
                 "legacy_mongo_id": legacy_mongo_id,
             },
         )
-        return row_to_dict(result.fetchone())
+        return _attachment_to_dict(result.fetchone())
 
     def get(self, attachment_id: str, user_id: str) -> Optional[dict]:
         result = self._conn.execute(
@@ -76,7 +91,7 @@ class AttachmentsRepository:
             {"id": attachment_id, "user_id": user_id},
         )
         row = result.fetchone()
-        return row_to_dict(row) if row is not None else None
+        return _attachment_to_dict(row) if row is not None else None
 
     def get_any(self, attachment_id: str, user_id: str) -> Optional[dict]:
         """Resolve an attachment by either PG UUID or legacy Mongo ObjectId string."""
@@ -155,14 +170,14 @@ class AttachmentsRepository:
         params["user_id"] = user_id
         result = self._conn.execute(text(sql), params)
         row = result.fetchone()
-        return row_to_dict(row) if row is not None else None
+        return _attachment_to_dict(row) if row is not None else None
 
     def list_for_user(self, user_id: str) -> list[dict]:
         result = self._conn.execute(
             text("SELECT * FROM attachments WHERE user_id = :user_id ORDER BY created_at DESC"),
             {"user_id": user_id},
         )
-        return [row_to_dict(r) for r in result.fetchall()]
+        return [_attachment_to_dict(r) for r in result.fetchall()]
 
     def update(self, attachment_id: str, user_id: str, fields: dict) -> bool:
         """Partial update. Used by the LLM providers to cache their
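
The alias in action, conceptually (the call site and field values are made up):

attachment = attachments_repo.get(attachment_id, user_id)  # hypothetical caller
# attachment == {
#     "id": "4be0...",                            # plus the "_id" mirror from row_to_dict
#     "upload_path": "attachments/u1/report.pdf",
#     "path": "attachments/u1/report.pdf",        # alias added by _attachment_to_dict
#     ...
# }
attachment.get("path")  # what the LLM provider code still reads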

application/storage/db/repositories/user_custom_models.py (new file, 199 lines)
@@ -0,0 +1,199 @@
"""Repository for the ``user_custom_models`` table.
|
||||||
|
|
||||||
|
Backs the end-user "Bring Your Own Model" feature. Each row is one
|
||||||
|
user-supplied OpenAI-compatible endpoint (Mistral, Together, vLLM, ...).
|
||||||
|
The ``id`` UUID is the internal DocsGPT identifier (what agents store
|
||||||
|
in ``default_model_id``); ``upstream_model_id`` is what we send verbatim
|
||||||
|
to the provider's API.
|
||||||
|
|
||||||
|
API key handling: callers pass plaintext via ``api_key_plaintext``;
|
||||||
|
this module wraps the existing ``application.security.encryption``
|
||||||
|
helper (AES-CBC + per-user PBKDF2 salt) and writes the base64 ciphertext
|
||||||
|
to the ``api_key_encrypted`` column. Decryption is the caller's
|
||||||
|
responsibility (they hold the ``user_id``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, func, text
|
||||||
|
|
||||||
|
from application.security.encryption import (
|
||||||
|
decrypt_credentials,
|
||||||
|
encrypt_credentials,
|
||||||
|
)
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
from application.storage.db.models import user_custom_models_table
|
||||||
|
|
||||||
|
|
||||||
|
_ALLOWED_CAPABILITY_KEYS = frozenset(
|
||||||
|
{
|
||||||
|
"supports_tools",
|
||||||
|
"supports_structured_output",
|
||||||
|
"supports_streaming",
|
||||||
|
"attachments",
|
||||||
|
"context_window",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UserCustomModelsRepository:
|
||||||
|
def __init__(self, conn: Connection) -> None:
|
||||||
|
self._conn = conn
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
# Encryption wrappers
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _encrypt_api_key(api_key_plaintext: str, user_id: str) -> str:
|
||||||
|
"""Encrypt ``api_key_plaintext`` with the per-user PBKDF2 scheme."""
|
||||||
|
return encrypt_credentials({"api_key": api_key_plaintext}, user_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _decrypt_api_key(api_key_encrypted: str, user_id: str) -> Optional[str]:
|
||||||
|
"""Decrypt the API key. Returns None on failure (which the caller
|
||||||
|
should surface as a configuration error rather than silently
|
||||||
|
proceeding with the upstream call)."""
|
||||||
|
if not api_key_encrypted:
|
||||||
|
return None
|
||||||
|
creds = decrypt_credentials(api_key_encrypted, user_id)
|
||||||
|
return creds.get("api_key") if creds else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_capabilities(caps: Optional[dict]) -> dict:
|
||||||
|
"""Drop unknown keys; nothing else is forced. Callers (the route
|
||||||
|
layer) are responsible for value validation (numeric ranges,
|
||||||
|
attachment alias resolution)."""
|
||||||
|
if not caps:
|
||||||
|
return {}
|
||||||
|
return {k: v for k, v in caps.items() if k in _ALLOWED_CAPABILITY_KEYS}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
# CRUD
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
upstream_model_id: str,
|
||||||
|
display_name: str,
|
||||||
|
base_url: str,
|
||||||
|
api_key_plaintext: str,
|
||||||
|
description: str = "",
|
||||||
|
capabilities: Optional[dict] = None,
|
||||||
|
enabled: bool = True,
|
||||||
|
) -> dict:
|
||||||
|
values = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"upstream_model_id": upstream_model_id,
|
||||||
|
"display_name": display_name,
|
||||||
|
"description": description or "",
|
||||||
|
"base_url": base_url,
|
||||||
|
"api_key_encrypted": self._encrypt_api_key(api_key_plaintext, user_id),
|
||||||
|
"capabilities": self._normalize_capabilities(capabilities),
|
||||||
|
"enabled": bool(enabled),
|
||||||
|
}
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
pg_insert(user_custom_models_table)
|
||||||
|
.values(**values)
|
||||||
|
.returning(user_custom_models_table)
|
||||||
|
)
|
||||||
|
result = self._conn.execute(stmt)
|
||||||
|
return row_to_dict(result.fetchone())
|
||||||
|
|
||||||
|
def get(self, model_id: str, user_id: str) -> Optional[dict]:
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT * FROM user_custom_models "
|
||||||
|
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||||
|
),
|
||||||
|
{"id": str(model_id), "user_id": user_id},
|
||||||
|
)
|
||||||
|
row = result.fetchone()
|
||||||
|
return row_to_dict(row) if row is not None else None
|
||||||
|
|
||||||
|
def list_for_user(self, user_id: str) -> list[dict]:
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT * FROM user_custom_models "
|
||||||
|
"WHERE user_id = :user_id ORDER BY created_at DESC"
|
||||||
|
),
|
||||||
|
{"user_id": user_id},
|
||||||
|
)
|
||||||
|
return [row_to_dict(r) for r in result.fetchall()]
|
||||||
|
|
||||||
|
def update(self, model_id: str, user_id: str, fields: dict) -> bool:
|
||||||
|
"""Apply a partial update.
|
||||||
|
|
||||||
|
Special-cases ``api_key_plaintext``: when present, it is encrypted
|
||||||
|
and stored in ``api_key_encrypted``. When absent (or empty), the
|
||||||
|
existing ciphertext is kept untouched. This is the wire-shape
|
||||||
|
``PATCH`` expects (the UI sends a blank password field when the
|
||||||
|
operator wants to keep the existing key).
|
||||||
|
"""
|
||||||
|
allowed = {
|
||||||
|
"upstream_model_id",
|
||||||
|
"display_name",
|
||||||
|
"description",
|
||||||
|
"base_url",
|
||||||
|
"capabilities",
|
||||||
|
"enabled",
|
||||||
|
}
|
||||||
|
values: dict[str, Any] = {}
|
||||||
|
for col, val in fields.items():
|
||||||
|
if col not in allowed or val is None:
|
||||||
|
continue
|
||||||
|
if col == "capabilities":
|
||||||
|
values[col] = self._normalize_capabilities(val)
|
||||||
|
elif col == "enabled":
|
||||||
|
values[col] = bool(val)
|
||||||
|
else:
|
||||||
|
values[col] = val
|
||||||
|
|
||||||
|
api_key_plaintext = fields.get("api_key_plaintext")
|
||||||
|
if api_key_plaintext:
|
||||||
|
values["api_key_encrypted"] = self._encrypt_api_key(
|
||||||
|
api_key_plaintext, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not values:
|
||||||
|
return False
|
||||||
|
values["updated_at"] = func.now()
|
||||||
|
|
||||||
|
t = user_custom_models_table
|
||||||
|
stmt = (
|
||||||
|
t.update()
|
||||||
|
.where(t.c.id == str(model_id))
|
||||||
|
.where(t.c.user_id == user_id)
|
||||||
|
.values(**values)
|
||||||
|
)
|
||||||
|
result = self._conn.execute(stmt)
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
def delete(self, model_id: str, user_id: str) -> bool:
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"DELETE FROM user_custom_models "
|
||||||
|
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||||
|
),
|
||||||
|
{"id": str(model_id), "user_id": user_id},
|
||||||
|
)
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
# Decryption helpers exposed to the registry layer
|
||||||
|
# ------------------------------------------------------------------ #
|
||||||
|
|
||||||
|
def get_decrypted_api_key(
|
||||||
|
self, model_id: str, user_id: str
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Convenience: fetch the row and return the decrypted API key,
|
||||||
|
or ``None`` if the row is missing or decryption fails."""
|
||||||
|
row = self.get(model_id, user_id)
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return self._decrypt_api_key(row.get("api_key_encrypted", ""), user_id)
|
||||||
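
End to end, the registry layer this repository serves would resolve a custom
model roughly as below. A sketch: the function name and the guard-module import
path are illustrative, while the ``OpenAI(http_client=...)`` wiring matches the
``pinned_httpx_client`` docstring earlier in this diff.

from openai import OpenAI

from application.security.url_guard import pinned_httpx_client  # path assumed
from application.storage.db.repositories.user_custom_models import (
    UserCustomModelsRepository,
)
from application.storage.db.session import db_readonly


def client_for_custom_model(model_id: str, user_id: str) -> tuple[OpenAI, str]:
    with db_readonly() as conn:
        repo = UserCustomModelsRepository(conn)
        row = repo.get(model_id, user_id)
        api_key = repo.get_decrypted_api_key(model_id, user_id)
    if row is None or api_key is None:
        raise ValueError("custom model missing or API key undecryptable")
    client = OpenAI(
        base_url=row["base_url"],
        api_key=api_key,
        http_client=pinned_httpx_client(row["base_url"]),  # DNS-rebind guard
    )
    return client, row["upstream_model_id"]  # sent verbatim to the provider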

application/updates/__init__.py (new file, empty)
application/updates/version_check.py (new file, 304 lines)
@@ -0,0 +1,304 @@
"""Anonymous version-check client.
|
||||||
|
|
||||||
|
Fired on every Celery worker boot (see ``application/celery_init.py``
|
||||||
|
``worker_ready`` handler) and on a 7h periodic schedule (see the
|
||||||
|
``version-check`` entry in ``application/api/user/tasks.py``). Posts
|
||||||
|
the running version + anonymous instance UUID to
|
||||||
|
``gptcloud.arc53.com/api/check``, caches the response in Redis, and
|
||||||
|
surfaces any advisories to stdout + logs.
|
||||||
|
|
||||||
|
Design invariants — all enforced by a broad ``try/except`` at the top
|
||||||
|
of :func:`run_check`:
|
||||||
|
|
||||||
|
* Never blocks worker startup (fired from a daemon thread).
|
||||||
|
* Never raises to the caller (every failure is swallowed + logged at
|
||||||
|
``DEBUG``).
|
||||||
|
* Opt-out via ``VERSION_CHECK=0`` short-circuits before any Postgres
|
||||||
|
write, Redis access, or outbound request.
|
||||||
|
* Redis coordinates multi-worker and multi-replica deployments — the
|
||||||
|
first worker to acquire ``docsgpt:version_check:lock`` fetches, the
|
||||||
|
rest read from the cached response on the next cycle.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import socket
|
||||||
|
import sys
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from application.cache import get_redis_instance
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.storage.db.repositories.app_metadata import AppMetadataRepository
|
||||||
|
from application.storage.db.session import db_session
|
||||||
|
from application.version import get_version
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ENDPOINT_URL = "https://gptcloud.arc53.com/api/check"
|
||||||
|
CLIENT_NAME = "docsgpt-backend"
|
||||||
|
REQUEST_TIMEOUT_SECONDS = 5
|
||||||
|
|
||||||
|
CACHE_KEY = "docsgpt:version_check:response"
|
||||||
|
LOCK_KEY = "docsgpt:version_check:lock"
|
||||||
|
CACHE_TTL_SECONDS = 6 * 3600 # 6h default; shortened by response `next_check_after`.
|
||||||
|
LOCK_TTL_SECONDS = 60
|
||||||
|
|
||||||
|
NOTICE_KEY = "version_check_notice_shown"
|
||||||
|
INSTANCE_ID_KEY = "instance_id"
|
||||||
|
|
||||||
|
_HIGH_SEVERITIES = {"high", "critical"}
|
||||||
|
|
||||||
|
_ANSI_RESET = "\033[0m"
|
||||||
|
_ANSI_RED = "\033[31m"
|
||||||
|
_ANSI_YELLOW = "\033[33m"
|
||||||
|
|
||||||
|
|
||||||
|
def run_check() -> None:
|
||||||
|
"""Entry point for the worker-startup daemon thread.
|
||||||
|
|
||||||
|
Safe to call unconditionally: the opt-out, Redis-outage, and
|
||||||
|
Postgres-outage paths all return silently. No exception propagates.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
_run_check_inner()
|
||||||
|
except Exception as exc: # noqa: BLE001 — belt-and-braces; nothing escapes.
|
||||||
|
logger.debug("version check crashed: %s", exc, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_check_inner() -> None:
|
||||||
|
if not settings.VERSION_CHECK:
|
||||||
|
return
|
||||||
|
|
||||||
|
instance_id = _resolve_instance_id_and_notice()
|
||||||
|
if instance_id is None:
|
||||||
|
# Postgres unavailable — per spec we skip the check entirely
|
||||||
|
# rather than phone home with a synthetic/ephemeral UUID.
|
||||||
|
return
|
||||||
|
|
||||||
|
redis_client = get_redis_instance()
|
||||||
|
|
||||||
|
cached = _read_cache(redis_client)
|
||||||
|
if cached is not None:
|
||||||
|
_render_advisories(cached)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Cache miss. Try to win the lock; if another worker has it, skip.
|
||||||
|
# ``redis_client is None`` here means Redis is unreachable — per the
|
||||||
|
# spec we still proceed uncached (acceptable duplicate calls in
|
||||||
|
# multi-worker Redis-less deploys).
|
||||||
|
if redis_client is not None and not _acquire_lock(redis_client):
|
||||||
|
return
|
||||||
|
|
||||||
|
response = _fetch(instance_id)
|
||||||
|
if response is None:
|
||||||
|
if redis_client is not None:
|
||||||
|
_release_lock(redis_client)
|
||||||
|
return
|
||||||
|
|
||||||
|
_write_cache(redis_client, response)
|
||||||
|
_render_advisories(response)
|
||||||
|
if redis_client is not None:
|
||||||
|
_release_lock(redis_client)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_instance_id_and_notice() -> Optional[str]:
|
||||||
|
"""Load (or create) the instance UUID and emit the first-run notice.
|
||||||
|
|
||||||
|
The notice is printed at most once across the lifetime of the
|
||||||
|
installation — tracked via the ``version_check_notice_shown`` row
|
||||||
|
in ``app_metadata``. Both reads and the write happen inside one
|
||||||
|
short transaction so two racing workers can't each emit the notice.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with db_session() as conn:
|
||||||
|
repo = AppMetadataRepository(conn)
|
||||||
|
instance_id = repo.get_or_create_instance_id()
|
||||||
|
if repo.get(NOTICE_KEY) is None:
|
||||||
|
_print_first_run_notice()
|
||||||
|
repo.set(NOTICE_KEY, "1")
|
||||||
|
return instance_id
|
||||||
|
except Exception as exc: # noqa: BLE001 — Postgres down, bad URI, etc.
|
||||||
|
logger.debug("version check: Postgres unavailable (%s)", exc, exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _print_first_run_notice() -> None:
|
||||||
|
message = (
|
||||||
|
"Anonymous version check enabled — sends version to "
|
||||||
|
"gptcloud.arc53.com.\nDisable with VERSION_CHECK=0."
|
||||||
|
)
|
||||||
|
print(message, flush=True)
|
||||||
|
logger.info("version check: first-run notice shown")
|
||||||
|
|
||||||
|
|
||||||
|
def _read_cache(redis_client) -> Optional[Dict[str, Any]]:
|
||||||
|
if redis_client is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
raw = redis_client.get(CACHE_KEY)
|
||||||
|
except Exception as exc: # noqa: BLE001 — Redis transient errors.
|
||||||
|
logger.debug("version check: cache GET failed (%s)", exc, exc_info=True)
|
||||||
|
return None
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return json.loads(raw.decode("utf-8") if isinstance(raw, bytes) else raw)
|
||||||
|
except (ValueError, AttributeError) as exc:
|
||||||
|
logger.debug("version check: cache decode failed (%s)", exc, exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _write_cache(redis_client, response: Dict[str, Any]) -> None:
|
||||||
|
if redis_client is None:
|
||||||
|
return
|
||||||
|
ttl = _compute_ttl(response)
|
||||||
|
try:
|
||||||
|
redis_client.setex(CACHE_KEY, ttl, json.dumps(response))
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.debug("version check: cache SETEX failed (%s)", exc, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_ttl(response: Dict[str, Any]) -> int:
|
||||||
|
"""Cap the cache at 6h but honor a shorter server-specified window."""
|
||||||
|
next_after = response.get("next_check_after")
|
||||||
|
if isinstance(next_after, (int, float)) and next_after > 0:
|
||||||
|
return max(1, min(CACHE_TTL_SECONDS, int(next_after)))
|
||||||
|
return CACHE_TTL_SECONDS
|
||||||
|
|
||||||
|
|
||||||
|
def _acquire_lock(redis_client) -> bool:
|
||||||
|
try:
|
||||||
|
owner = f"{socket.gethostname()}:{os.getpid()}"
|
||||||
|
return bool(
|
||||||
|
redis_client.set(LOCK_KEY, owner, nx=True, ex=LOCK_TTL_SECONDS)
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
# Treat a failing Redis the same as "no lock infra" — skip rather
|
||||||
|
# than fire without coordination, because Redis outage is
|
||||||
|
# usually transient and one missed cycle is harmless.
|
||||||
|
logger.debug("version check: lock acquire failed (%s)", exc, exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _release_lock(redis_client) -> None:
|
||||||
|
try:
|
||||||
|
redis_client.delete(LOCK_KEY)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.debug("version check: lock release failed (%s)", exc, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch(instance_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
version = get_version()
|
||||||
|
if version in ("", "unknown"):
|
||||||
|
# The endpoint rejects payloads without a valid semver, and the
|
||||||
|
# rejection is otherwise logged at DEBUG — invisible under the
|
||||||
|
# usual ``-l INFO`` Celery worker start. Surface it loudly so a
|
||||||
|
# misconfigured release (missing or unset ``__version__``) is
|
||||||
|
# obvious instead of silently disabling the check.
|
||||||
|
logger.warning(
|
||||||
|
"version check: skipping — get_version() returned %r. "
|
||||||
|
"Set __version__ in application/version.py to a valid "
|
||||||
|
"version string.",
|
||||||
|
version,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
payload = {
|
||||||
|
"version": version,
|
||||||
|
"instance_id": instance_id,
|
||||||
|
"python_version": platform.python_version(),
|
||||||
|
"platform": sys.platform,
|
||||||
|
"client": CLIENT_NAME,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
resp = requests.post(
|
||||||
|
ENDPOINT_URL,
|
||||||
|
json=payload,
|
||||||
|
timeout=REQUEST_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
except requests.RequestException as exc:
|
||||||
|
logger.debug("version check: request failed (%s)", exc, exc_info=True)
|
||||||
|
return None
|
||||||
|
if resp.status_code >= 400:
|
||||||
|
logger.debug("version check: non-2xx response %s", resp.status_code)
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return resp.json()
|
||||||
|
except ValueError as exc:
|
||||||
|
logger.debug("version check: response decode failed (%s)", exc, exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _render_advisories(response: Dict[str, Any]) -> None:
|
||||||
|
advisories = response.get("advisories") or []
|
||||||
|
if not isinstance(advisories, list):
|
||||||
|
return
|
||||||
|
current_version = get_version()
|
||||||
|
for advisory in advisories:
|
||||||
|
if not isinstance(advisory, dict):
|
||||||
|
continue
|
||||||
|
severity = str(advisory.get("severity", "")).lower()
|
||||||
|
advisory_id = advisory.get("id", "UNKNOWN")
|
||||||
|
title = advisory.get("title", "")
|
||||||
|
url = advisory.get("url", "")
|
||||||
|
fixed_in = advisory.get("fixed_in")
|
||||||
|
summary = advisory.get(
|
||||||
|
"summary",
|
||||||
|
f"Your DocsGPT version {current_version} is vulnerable.",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"security advisory %s (severity=%s) affects version %s: %s%s%s",
|
||||||
|
advisory_id,
|
||||||
|
severity or "unknown",
|
||||||
|
current_version,
|
||||||
|
title or summary,
|
||||||
|
f" — fixed in {fixed_in}" if fixed_in else "",
|
||||||
|
f" — {url}" if url else "",
|
||||||
|
)
|
||||||
|
|
||||||
|
if severity in _HIGH_SEVERITIES:
|
||||||
|
_print_console_advisory(
|
||||||
|
advisory_id=advisory_id,
|
||||||
|
title=title,
|
||||||
|
severity=severity,
|
||||||
|
summary=summary,
|
||||||
|
fixed_in=fixed_in,
|
||||||
|
url=url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _print_console_advisory(
|
||||||
|
*,
|
||||||
|
advisory_id: str,
|
||||||
|
title: str,
|
||||||
|
severity: str,
|
||||||
|
summary: str,
|
||||||
|
fixed_in: Optional[str],
|
||||||
|
url: str,
|
||||||
|
) -> None:
|
||||||
|
color = _ANSI_RED if severity == "critical" else _ANSI_YELLOW
|
||||||
|
bar = "=" * 60
|
||||||
|
upgrade_line = ""
|
||||||
|
if fixed_in and url:
|
||||||
|
upgrade_line = f" Upgrade to {fixed_in}+ — {url}"
|
||||||
|
elif fixed_in:
|
||||||
|
upgrade_line = f" Upgrade to {fixed_in}+"
|
||||||
|
elif url:
|
||||||
|
upgrade_line = f" {url}"
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
bar,
|
||||||
|
f"\u26a0 SECURITY ADVISORY: {advisory_id}",
|
||||||
|
f" {summary}",
|
||||||
|
f" {title} (severity: {severity})" if title else f" severity: {severity}",
|
||||||
|
]
|
||||||
|
if upgrade_line:
|
||||||
|
lines.append(upgrade_line)
|
||||||
|
lines.append(bar)
|
||||||
|
print(f"{color}{chr(10).join(lines)}{_ANSI_RESET}", flush=True)
|
||||||
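
The "never blocks worker startup" invariant is the caller's job. The hookup the
module header points at looks roughly like this (handler name is illustrative;
the real one lives in ``application/celery_init.py``):

import threading

from celery.signals import worker_ready

from application.updates import version_check


@worker_ready.connect
def _start_version_check(sender=None, **kwargs):
    # Daemon thread: the worker never waits on the check, and a hung
    # HTTP call cannot delay worker shutdown.
    threading.Thread(target=version_check.run_check, daemon=True).start()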

@@ -1,5 +1,6 @@
 import sys
 import logging
+import time
 from datetime import datetime
 
 from application.storage.db.repositories.token_usage import TokenUsageRepository
@@ -20,6 +21,15 @@ def _serialize_for_token_count(value):
     if value is None:
         return ""
 
+    # Raw binary payloads (image/file attachments arrive as ``bytes`` from
+    # ``GoogleLLM.prepare_messages_with_attachments``) — without this
+    # branch they fall through to ``str(value)`` below, which produces a
+    # multi-megabyte ``"b'\x89PNG...'"`` repr-string and inflates
+    # ``prompt_tokens`` by orders of magnitude. Same intent as the
+    # data-URL skip above.
+    if isinstance(value, (bytes, bytearray, memoryview)):
+        return ""
+
     if isinstance(value, list):
         return [_serialize_for_token_count(item) for item in value]
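
To make the comment's "orders of magnitude" concrete (illustrative numbers):
every byte of a binary payload becomes two to four characters once ``str()``
turns it into a repr-string, so a 2 MB PNG would feed roughly 8 MB of text to
the token counter instead of nothing.

payload = b"\x89PNG" + b"\x00" * 2_000_000
len(payload)       # 2_000_004 bytes
len(str(payload))  # ~8_000_000 chars; each zero byte renders as the 4-char "\x00"
# With the new branch, _serialize_for_token_count(payload) returns "" instead.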

@@ -145,19 +155,44 @@ def stream_token_usage(func):
             **kwargs,
         )
         batch = []
-        result = func(self, model, messages, stream, tools, **kwargs)
-        for r in result:
-            batch.append(r)
-            yield r
-        for line in batch:
-            call_usage["generated_tokens"] += _count_tokens(line)
-        self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
-        self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
-        update_token_usage(
-            self.decoded_token,
-            self.user_api_key,
-            call_usage,
-            getattr(self, "agent_id", None),
-        )
+        started_at = time.monotonic()
+        error: BaseException | None = None
+        try:
+            result = func(self, model, messages, stream, tools, **kwargs)
+            for r in result:
+                batch.append(r)
+                yield r
+        except Exception as exc:
+            # ``GeneratorExit`` (consumer disconnected) and KeyboardInterrupt
+            # flow through as ``status="ok"`` — same convention as
+            # ``application.logging._consume_and_log``.
+            error = exc
+            raise
+        finally:
+            for line in batch:
+                call_usage["generated_tokens"] += _count_tokens(line)
+            self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
+            self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
+            # Persist usage rows only on success: a partial mid-stream
+            # failure shouldn't bill the user for a response they never got.
+            if error is None:
+                update_token_usage(
+                    self.decoded_token,
+                    self.user_api_key,
+                    call_usage,
+                    getattr(self, "agent_id", None),
+                )
+            emit = getattr(self, "_emit_stream_finished_log", None)
+            if callable(emit):
+                try:
+                    emit(
+                        model,
+                        prompt_tokens=call_usage["prompt_tokens"],
+                        completion_tokens=call_usage["generated_tokens"],
+                        latency_ms=int((time.monotonic() - started_at) * 1000),
+                        error=error,
+                    )
+                except Exception:
+                    logger.exception("Failed to emit llm_stream_finished")
 
         return wrapper
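
The fix generalizes: any decorator that wraps a generator and must account for
partial output needs its bookkeeping in ``finally`` (which runs on success,
failure, and consumer disconnect alike), with success-only side effects gated
on a captured error. A stripped-down sketch of the same shape; ``record_usage``
is a hypothetical sink.

import time


def metered(gen_func):
    def wrapper(*args, **kwargs):
        produced = []
        started = time.monotonic()
        error = None
        try:
            for item in gen_func(*args, **kwargs):
                produced.append(item)
                yield item
        except Exception as exc:
            error = exc
            raise
        finally:
            # Counting happens even if the stream died or the consumer left.
            elapsed_ms = int((time.monotonic() - started) * 1000)
            if error is None:  # never bill a stream that failed mid-flight
                record_usage(len(produced), elapsed_ms)  # hypothetical sink
    return wrapper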
@@ -83,9 +83,9 @@ def count_tokens_docs(docs):
 
 
 def calculate_doc_token_budget(
-    model_id: str = "gpt-4o"
+    model_id: str = "gpt-4o", user_id: str | None = None
 ) -> int:
-    total_context = get_token_limit(model_id)
+    total_context = get_token_limit(model_id, user_id=user_id)
     reserved = sum(settings.RESERVED_TOKENS.values())
     doc_budget = total_context - reserved
     return max(doc_budget, 1000)
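The budget arithmetic above, with hypothetical numbers (the real context window comes from `get_token_limit` and the reservations from `settings.RESERVED_TOKENS`; the keys below are made up):

```python
# Assumed values for illustration only.
total_context = 128_000  # e.g. what get_token_limit(...) might return
reserved = sum({"system": 2_000, "history": 4_000, "answer": 2_000}.values())

doc_budget = total_context - reserved  # 120_000 tokens left for documents
print(max(doc_budget, 1_000))          # the floor keeps tiny models usable
```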
@@ -150,9 +150,11 @@ def get_hash(data):
     return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()
 
 
-def limit_chat_history(history, max_token_limit=None, model_id="docsgpt-local"):
+def limit_chat_history(
+    history, max_token_limit=None, model_id="docsgpt-local", user_id=None
+):
     """Limit chat history to fit within token limit."""
-    model_token_limit = get_token_limit(model_id)
+    model_token_limit = get_token_limit(model_id, user_id=user_id)
     max_token_limit = (
         max_token_limit
         if max_token_limit and max_token_limit < model_token_limit
@@ -204,7 +206,9 @@ def generate_image_url(image_path):
 
 
 def calculate_compression_threshold(
-    model_id: str, threshold_percentage: float = 0.8
+    model_id: str,
+    threshold_percentage: float = 0.8,
+    user_id: str | None = None,
 ) -> int:
     """
     Calculate token threshold for triggering compression.
@@ -212,11 +216,13 @@ def calculate_compression_threshold(
     Args:
         model_id: Model identifier
         threshold_percentage: Percentage of context window (default 80%)
+        user_id: When set, BYOM custom-model records (UUID-keyed) resolve
+            for context-window lookup.
+
     Returns:
         Token count threshold
     """
-    total_context = get_token_limit(model_id)
+    total_context = get_token_limit(model_id, user_id=user_id)
     threshold = int(total_context * threshold_percentage)
     return threshold
 
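With the same hypothetical 128k context window, the default 80% threshold works out as:

```python
total_context = 128_000          # assumed; the real value comes from get_token_limit
print(int(total_context * 0.8))  # 102400: compress once usage passes this point
```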
@@ -1,9 +1,23 @@
 import logging
+from functools import cached_property
 
 from application.core.settings import settings
 from application.vectorstore.base import BaseVectorStore
 from application.vectorstore.document_class import Document
 
 
+def _lazy_import_pymongo():
+    """Lazy import of pymongo so installations that don't use the MongoDB vectorstore don't need it."""
+    try:
+        import pymongo
+    except ImportError as exc:
+        raise ImportError(
+            "Could not import pymongo python package. "
+            "Please install it with `pip install pymongo`."
+        ) from exc
+    return pymongo
+
+
 class MongoDBVectorStore(BaseVectorStore):
     def __init__(
         self,
@@ -20,20 +34,23 @@ class MongoDBVectorStore(BaseVectorStore):
         self._embedding_key = embedding_key
         self._embeddings_key = embeddings_key
         self._mongo_uri = settings.MONGO_URI
+        self._database_name = database
+        self._collection_name = collection
         self._source_id = source_id.replace("application/indexes/", "").rstrip("/")
         self._embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)
 
-        try:
-            import pymongo
-        except ImportError:
-            raise ImportError(
-                "Could not import pymongo python package. "
-                "Please install it with `pip install pymongo`."
-            )
-
-        self._client = pymongo.MongoClient(self._mongo_uri)
-        self._database = self._client[database]
-        self._collection = self._database[collection]
+    @cached_property
+    def _client(self):
+        pymongo = _lazy_import_pymongo()
+        return pymongo.MongoClient(self._mongo_uri)
+
+    @cached_property
+    def _database(self):
+        return self._client[self._database_name]
+
+    @cached_property
+    def _collection(self):
+        return self._database[self._collection_name]
 
     def search(self, question, k=2, *args, **kwargs):
         query_vector = self._embedding.embed_query(question)
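A toy demonstration of what the `cached_property` refactor buys: constructing the store no longer opens a connection; the client is created on first use and reused afterwards (the names below are stand-ins, not the project's classes).

```python
from functools import cached_property

class Store:
    @cached_property
    def _client(self):
        print("connecting once, on first use")
        return object()  # stand-in for pymongo.MongoClient(...)

s = Store()    # no connection happens here
_ = s._client  # prints: connecting once, on first use
_ = s._client  # cached: no second connection
```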
10
application/version.py
Normal file
@@ -0,0 +1,10 @@
+"""DocsGPT backend version string."""
+
+from __future__ import annotations
+
+__version__ = "0.17.0"
+
+
+def get_version() -> str:
+    """Return the DocsGPT backend version."""
+    return __version__
@@ -344,18 +344,34 @@ def run_agent_logic(agent_config, input_data):
 
         # Determine model_id: check agent's default_model_id, fallback to system default
         agent_default_model = agent_config.get("default_model_id", "")
-        if agent_default_model and validate_model_id(agent_default_model):
+        if agent_default_model and validate_model_id(
+            agent_default_model, user_id=owner
+        ):
             model_id = agent_default_model
         else:
             model_id = get_default_model_id()
+            if agent_default_model:
+                # Stored model_id no longer resolves in the registry. Log so
+                # operators can detect bad YAML edits before users complain;
+                # behavior matches the historical silent fallback.
+                logging.warning(
+                    "Agent %s references unknown model_id %r; falling back to %r",
+                    agent_id,
+                    agent_default_model,
+                    model_id,
+                )
 
         # Get provider and API key for the selected model
-        provider = get_provider_from_model_id(model_id) if model_id else settings.LLM_PROVIDER
+        provider = (
+            get_provider_from_model_id(model_id, user_id=owner)
+            if model_id
+            else settings.LLM_PROVIDER
+        )
         system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
 
         # Calculate proper doc_token_limit based on model's context window
         doc_token_limit = calculate_doc_token_budget(
-            model_id=model_id
+            model_id=model_id, user_id=owner
        )
 
         retriever = RetrieverCreator.create_retriever(
@@ -416,7 +432,10 @@ def run_agent_logic(agent_config, input_data):
             "tool_calls": tool_calls,
             "thought": thought,
         }
-        logging.info(f"Agent response: {result}")
+        # Per-activity summary fields (answer_length, thought_length,
+        # source_count, tool_call_count) now ride on the inner
+        # ``activity_finished`` event emitted by ``log_activity`` around
+        # ``Agent.gen`` above; no separate ``agent_response`` log needed.
         return result
     except Exception as e:
         logging.error(f"Error in run_agent_logic: {e}", exc_info=True)
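The validate-then-fall-back pattern above, reduced to a toy function (names are illustrative, not the worker's API):

```python
import logging

def pick_model(requested: str | None, known: set[str], default: str) -> str:
    if requested and requested in known:
        return requested
    if requested:
        # Mirrors the new warning: visible to operators, same fallback as before.
        logging.warning("unknown model_id %r; falling back to %r", requested, default)
    return default

print(pick_model("mistral-xl", {"gpt-4o"}, "gpt-4o"))  # warns, prints gpt-4o
```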
@@ -104,7 +104,15 @@ To run the DocsGPT backend locally, you'll need to set up a Python environment a
    flask --app application/app.py run --host=0.0.0.0 --port=7091
    ```
 
-   This command will launch the backend server, making it accessible on `http://localhost:7091`.
+   This command will launch the backend server, making it accessible on `http://localhost:7091`. It's the fastest inner-loop option for day-to-day development — the Werkzeug interactive debugger still works and it hot-reloads on source changes. It serves the Flask routes only.
+
+   If you need to exercise the full ASGI stack — the `/mcp` endpoint (FastMCP server), or to match the production runtime — run the ASGI composition under uvicorn instead:
+
+   ```bash
+   uvicorn application.asgi:asgi_app --host 0.0.0.0 --port 7091 --reload
+   ```
+
+   Production uses `gunicorn -k uvicorn_worker.UvicornWorker` against the same `application.asgi:asgi_app` target.
 
 6. **Start the Celery Worker:**
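For orientation, one plausible shape for `application.asgi:asgi_app` (an assumption, not the repository's actual file; check `application/asgi.py`): the Flask WSGI app wrapped for ASGI, with the FastMCP app mounted alongside it.

```python
# Hypothetical sketch of an ASGI composition; names other than
# application.app are assumptions.
from asgiref.wsgi import WsgiToAsgi
from starlette.applications import Starlette
from starlette.routing import Mount

from application.app import app as flask_app

asgi_app = Starlette(
    routes=[
        # Mount("/mcp", app=mcp_asgi_app),  # FastMCP's ASGI app (name assumed)
        Mount("/", app=WsgiToAsgi(flask_app)),  # all Flask routes
    ]
)
```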
@@ -99,6 +99,82 @@ EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2 # You can al
 
 In this case, even though you are using Ollama locally, `LLM_PROVIDER` is set to `openai` because Ollama (and many other local inference engines) are designed to be API-compatible with OpenAI. `OPENAI_BASE_URL` points DocsGPT to the local Ollama server.
 
+## Adding Custom Models (`MODELS_CONFIG_DIR`)
+
+DocsGPT ships with a built-in catalog of models for the providers it
+supports out of the box (OpenAI, Anthropic, Google, Groq, OpenRouter,
+Novita, Azure OpenAI, Hugging Face, DocsGPT). To add **your own
+models** without forking the repo — for example, a Mistral or Together
+account, a self-hosted vLLM endpoint, or any other OpenAI-compatible
+API — point `MODELS_CONFIG_DIR` at a directory of YAML files.
+
+```
+MODELS_CONFIG_DIR=/etc/docsgpt/models
+MISTRAL_API_KEY=sk-...
+```
+
+A minimal YAML for one provider:
+
+```yaml
+# /etc/docsgpt/models/mistral.yaml
+provider: openai_compatible
+display_provider: mistral
+api_key_env: MISTRAL_API_KEY
+base_url: https://api.mistral.ai/v1
+defaults:
+  supports_tools: true
+  context_window: 128000
+models:
+  - id: mistral-large-latest
+    display_name: Mistral Large
+  - id: mistral-small-latest
+    display_name: Mistral Small
+```
+
+After restart, those models appear in `/api/models` and are selectable
+in the UI. A working template lives at
+`application/core/models/examples/mistral.yaml.example`.
+
+**What you can do:**
+
+- Add new `openai_compatible` providers (Mistral, Together, Fireworks,
+  Ollama, vLLM, ...) — one YAML per provider, each with its own
+  `api_key_env` and `base_url`.
+- Extend an existing provider's catalog by dropping a YAML with the
+  same `provider:` value as the built-in (e.g. `provider: anthropic`
+  with extra models).
+- Override a built-in model's capabilities by re-declaring the same
+  `id` — later wins, override is logged at `WARNING`.
+
+**What you cannot do via `MODELS_CONFIG_DIR`:** add a brand-new
+non-OpenAI provider. That requires a Python plugin under
+`application/llm/providers/`. See
+`application/core/models/README.md` for the full schema reference.
+
+### Docker
+
+Mount the directory and set the env var:
+
+```yaml
+# docker-compose.yml
+services:
+  app:
+    image: arc53/docsgpt
+    environment:
+      MODELS_CONFIG_DIR: /etc/docsgpt/models
+      MISTRAL_API_KEY: ${MISTRAL_API_KEY}
+    volumes:
+      - ./my-models:/etc/docsgpt/models:ro
+```
+
+### Misconfiguration
+
+If `MODELS_CONFIG_DIR` is set but the path doesn't exist (or isn't a
+directory), the app logs a `WARNING` at boot and continues with just
+the built-in catalog — it does **not** fail to start. If a YAML
+declares an unknown provider name or has a schema error, the app
+**does** fail to start, with the offending file path in the message.
+
 ## Speech-to-Text Settings
 
 DocsGPT can transcribe audio in two places:
Some files were not shown because too many files have changed in this diff.