mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-13 15:45:26 +00:00
Compare commits
78 Commits
aesgi
...
feat-notif
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0bce35ad29 | ||
|
|
9de8bb4499 | ||
|
|
cdbd3f061d | ||
|
|
2ac46fd858 | ||
|
|
daa4320da2 | ||
|
|
e70a7a5115 | ||
|
|
150d9f4e37 | ||
|
|
746bcbc5f9 | ||
|
|
aa91117fbf | ||
|
|
abbd56cb66 | ||
|
|
85d8375e6c | ||
|
|
7e98d21b61 | ||
|
|
249f9f9fe0 | ||
|
|
6c4346eb84 | ||
|
|
cb3ca8a36b | ||
|
|
4c8230fb6c | ||
|
|
649557798d | ||
|
|
afe8354ca5 | ||
|
|
5483eb0e27 | ||
|
|
bd2985db47 | ||
|
|
b99147ba83 | ||
|
|
c3023f8b71 | ||
|
|
c168a530f5 | ||
|
|
2d539f3199 | ||
|
|
ed9444cf3d | ||
|
|
e692c645b9 | ||
|
|
b4c4ab68f0 | ||
|
|
d23679dd93 | ||
|
|
1b2239e54b | ||
|
|
5ceb99f946 | ||
|
|
892908cef5 | ||
|
|
99ffe439c7 | ||
|
|
ed87972ca6 | ||
|
|
6ad9022dd3 | ||
|
|
9b8fe2d5d0 | ||
|
|
d1dc8de27c | ||
|
|
a29fa44b51 | ||
|
|
026371d024 | ||
|
|
b0df2a479b | ||
|
|
5eae83af1b | ||
|
|
9c875c83c2 | ||
|
|
e6e671faf1 | ||
|
|
a31ec97bd7 | ||
|
|
ebe752d103 | ||
|
|
8c30c1c880 | ||
|
|
4a598e062c | ||
|
|
e285b47170 | ||
|
|
2d884a3df1 | ||
|
|
b9920731e0 | ||
|
|
f5f4c07e59 | ||
|
|
e87dc42ad0 | ||
|
|
40a30054bc | ||
|
|
707e782ac8 | ||
|
|
2bc0b6946b | ||
|
|
fbd686b725 | ||
|
|
29320eb9fd | ||
|
|
0d2a8e11f4 | ||
|
|
f0c39dec23 | ||
|
|
552bfe016a | ||
|
|
a6a5db631b | ||
|
|
8e9f661efc | ||
|
|
82c71be819 | ||
|
|
318de18d43 | ||
|
|
af618de13d | ||
|
|
ef976eeb06 | ||
|
|
9c8ae9d540 | ||
|
|
7ca33b2b72 | ||
|
|
fb24f9cf5e | ||
|
|
d1b9798f62 | ||
|
|
ddc3adf3ab | ||
|
|
a4991d01ac | ||
|
|
87fd1bd359 | ||
|
|
c71e986d34 | ||
|
|
a2a06c569e | ||
|
|
c5f00a1d1b | ||
|
|
2a15bb0102 | ||
|
|
c06888bc86 | ||
|
|
65460b0c03 |
@@ -35,8 +35,5 @@ MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
|
||||
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
|
||||
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
|
||||
|
||||
# User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration).
|
||||
# Standard Postgres URI — `postgres://` and `postgresql://` both work.
|
||||
# Leave unset while the migration is still being rolled out; the app will
|
||||
# fall back to MongoDB for user data until POSTGRES_URI is configured.
|
||||
|
||||
# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
|
||||
|
||||
@@ -115,7 +115,7 @@ vale .
|
||||
- `frontend/`: Vite + React + TypeScript application.
|
||||
- `frontend/src/`: main UI code, including `components`, `conversation`, `hooks`, `locale`, `settings`, `upload`, and Redux store wiring in `store.ts`.
|
||||
- `docs/`: separate documentation site built with Next.js/Nextra.
|
||||
- `extensions/`: integrations and widgets such as Chatwoot, Chrome, Discord, React widget, Slack bot, and web widget.
|
||||
- `extensions/`: integrations and widgets — currently the Chatwoot webhook bridge and the React widget (published to npm as `docsgpt`). The Discord bot, Slack bot, and Chrome extension have been moved to their own repos under `arc53/`.
|
||||
- `deployment/`: Docker Compose variants and Kubernetes manifests.
|
||||
|
||||
## Coding rules
|
||||
|
||||
12
README.md
12
README.md
@@ -47,11 +47,13 @@
|
||||
</ul>
|
||||
|
||||
## Roadmap
|
||||
- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
|
||||
- [x] Deep Agents ( October 2025 )
|
||||
- [x] Prompt Templating ( October 2025 )
|
||||
- [x] Full api tooling ( Dec 2025 )
|
||||
- [ ] Agent scheduling ( Jan 2026 )
|
||||
- [x] Agent Workflow Builder with conditional nodes ( February 2026 )
|
||||
- [x] SharePoint & Confluence connectors ( March – April 2026 )
|
||||
- [x] Research mode ( March 2026 )
|
||||
- [x] Postgres migration for user data ( April 2026 )
|
||||
- [x] OpenTelemetry observability ( April 2026 )
|
||||
- [x] Bring Your Own Model (BYOM) ( April 2026 )
|
||||
- [ ] Agent scheduling (RedBeat-backed) ( Q2 2026 )
|
||||
|
||||
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ class BaseAgent(ABC):
|
||||
llm_handler=None,
|
||||
tool_executor: Optional[ToolExecutor] = None,
|
||||
backup_models: Optional[List[str]] = None,
|
||||
model_user_id: Optional[str] = None,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.llm_name = llm_name
|
||||
@@ -52,10 +53,13 @@ class BaseAgent(ABC):
|
||||
self.prompt = prompt
|
||||
self.decoded_token = decoded_token or {}
|
||||
self.user: str = self.decoded_token.get("sub")
|
||||
# BYOM-resolution scope: owner for shared agents, caller for
|
||||
# caller-owned BYOM, None for built-ins. Falls back to self.user
|
||||
# for worker/legacy callers that don't thread model_user_id.
|
||||
self.model_user_id = model_user_id
|
||||
self.tools: List[Dict] = []
|
||||
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
|
||||
|
||||
# Dependency injection for LLM — fall back to creating if not provided
|
||||
if llm is not None:
|
||||
self.llm = llm
|
||||
else:
|
||||
@@ -67,8 +71,16 @@ class BaseAgent(ABC):
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
backup_models=backup_models,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
|
||||
# For BYOM, registry id (UUID) differs from upstream model id
|
||||
# (e.g. ``mistral-large-latest``). LLMCreator resolved this onto
|
||||
# the LLM instance; cache it for subsequent gen calls.
|
||||
self.upstream_model_id = (
|
||||
getattr(self.llm, "model_id", None) or model_id
|
||||
)
|
||||
|
||||
self.retrieved_docs = retrieved_docs or []
|
||||
|
||||
if llm_handler is not None:
|
||||
@@ -306,7 +318,9 @@ class BaseAgent(ABC):
|
||||
try:
|
||||
current_tokens = self._calculate_current_context_tokens(messages)
|
||||
self.current_token_count = current_tokens
|
||||
context_limit = get_token_limit(self.model_id)
|
||||
context_limit = get_token_limit(
|
||||
self.model_id, user_id=self.model_user_id or self.user
|
||||
)
|
||||
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
|
||||
|
||||
if current_tokens >= threshold:
|
||||
@@ -325,7 +339,9 @@ class BaseAgent(ABC):
|
||||
|
||||
current_tokens = self._calculate_current_context_tokens(messages)
|
||||
self.current_token_count = current_tokens
|
||||
context_limit = get_token_limit(self.model_id)
|
||||
context_limit = get_token_limit(
|
||||
self.model_id, user_id=self.model_user_id or self.user
|
||||
)
|
||||
percentage = (current_tokens / context_limit) * 100
|
||||
|
||||
if current_tokens >= context_limit:
|
||||
@@ -387,7 +403,9 @@ class BaseAgent(ABC):
|
||||
)
|
||||
system_prompt = system_prompt + compression_context
|
||||
|
||||
context_limit = get_token_limit(self.model_id)
|
||||
context_limit = get_token_limit(
|
||||
self.model_id, user_id=self.model_user_id or self.user
|
||||
)
|
||||
system_tokens = num_tokens_from_string(system_prompt)
|
||||
|
||||
safety_buffer = int(context_limit * 0.1)
|
||||
@@ -497,7 +515,10 @@ class BaseAgent(ABC):
|
||||
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
|
||||
self._validate_context_size(messages)
|
||||
|
||||
gen_kwargs = {"model": self.model_id, "messages": messages}
|
||||
# Use the upstream id resolved by LLMCreator (see __init__).
|
||||
# Built-in models: same as self.model_id. BYOM: the user's
|
||||
# typed model name, not the internal UUID.
|
||||
gen_kwargs = {"model": self.upstream_model_id, "messages": messages}
|
||||
if self.attachments:
|
||||
gen_kwargs["_usage_attachments"] = self.attachments
|
||||
|
||||
|
||||
@@ -312,7 +312,7 @@ class ResearchAgent(BaseAgent):
|
||||
|
||||
try:
|
||||
response = self.llm.gen(
|
||||
model=self.model_id,
|
||||
model=self.upstream_model_id,
|
||||
messages=messages,
|
||||
tools=None,
|
||||
response_format={"type": "json_object"},
|
||||
@@ -390,7 +390,7 @@ class ResearchAgent(BaseAgent):
|
||||
|
||||
try:
|
||||
response = self.llm.gen(
|
||||
model=self.model_id,
|
||||
model=self.upstream_model_id,
|
||||
messages=messages,
|
||||
tools=None,
|
||||
response_format={"type": "json_object"},
|
||||
@@ -506,7 +506,7 @@ class ResearchAgent(BaseAgent):
|
||||
|
||||
try:
|
||||
response = self.llm.gen(
|
||||
model=self.model_id,
|
||||
model=self.upstream_model_id,
|
||||
messages=messages,
|
||||
tools=self.tools if self.tools else None,
|
||||
)
|
||||
@@ -537,7 +537,7 @@ class ResearchAgent(BaseAgent):
|
||||
)
|
||||
try:
|
||||
response = self.llm.gen(
|
||||
model=self.model_id, messages=messages, tools=None
|
||||
model=self.upstream_model_id, messages=messages, tools=None
|
||||
)
|
||||
self._track_tokens(self._snapshot_llm_tokens())
|
||||
text = self._extract_text(response)
|
||||
@@ -664,7 +664,7 @@ class ResearchAgent(BaseAgent):
|
||||
]
|
||||
|
||||
llm_response = self.llm.gen_stream(
|
||||
model=self.model_id, messages=messages, tools=None
|
||||
model=self.upstream_model_id, messages=messages, tools=None
|
||||
)
|
||||
|
||||
if log_context:
|
||||
|
||||
@@ -1,18 +1,107 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.security.encryption import decrypt_credentials
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _record_proposed(
|
||||
call_id: str,
|
||||
tool_name: str,
|
||||
action_name: str,
|
||||
arguments: Any,
|
||||
*,
|
||||
tool_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Insert a ``proposed`` row; swallow infra failures so tool calls
|
||||
still run when the journal is unreachable. Returns True iff the row
|
||||
is now journaled (newly created or already present).
|
||||
"""
|
||||
try:
|
||||
with db_session() as conn:
|
||||
inserted = ToolCallAttemptsRepository(conn).record_proposed(
|
||||
call_id,
|
||||
tool_name,
|
||||
action_name,
|
||||
arguments,
|
||||
tool_id=tool_id if tool_id and looks_like_uuid(tool_id) else None,
|
||||
)
|
||||
if not inserted:
|
||||
logger.warning(
|
||||
"tool_call_attempts duplicate call_id=%s; existing row left in place",
|
||||
call_id,
|
||||
extra={"alert": "tool_call_id_collision", "call_id": call_id},
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("tool_call_attempts proposed write failed for %s", call_id)
|
||||
return False
|
||||
|
||||
|
||||
def _mark_executed(
|
||||
call_id: str,
|
||||
result: Any,
|
||||
*,
|
||||
message_id: Optional[str] = None,
|
||||
artifact_id: Optional[str] = None,
|
||||
proposed_ok: bool = True,
|
||||
tool_name: Optional[str] = None,
|
||||
action_name: Optional[str] = None,
|
||||
arguments: Any = None,
|
||||
tool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Flip the row to ``executed``. If ``proposed_ok`` is False (the
|
||||
proposed write failed earlier), upsert a fresh row in ``executed`` so
|
||||
the reconciler can still see the attempt — without this, the side
|
||||
effect would be invisible to the journal.
|
||||
"""
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = ToolCallAttemptsRepository(conn)
|
||||
if proposed_ok:
|
||||
updated = repo.mark_executed(
|
||||
call_id,
|
||||
result,
|
||||
message_id=message_id,
|
||||
artifact_id=artifact_id,
|
||||
)
|
||||
if updated:
|
||||
return
|
||||
# Fallback synthesizes the row so the journal isn't lost.
|
||||
repo.upsert_executed(
|
||||
call_id,
|
||||
tool_name=tool_name or "unknown",
|
||||
action_name=action_name or "",
|
||||
arguments=arguments if arguments is not None else {},
|
||||
result=result,
|
||||
tool_id=tool_id if tool_id and looks_like_uuid(tool_id) else None,
|
||||
message_id=message_id,
|
||||
artifact_id=artifact_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("tool_call_attempts executed write failed for %s", call_id)
|
||||
|
||||
|
||||
def _mark_failed(call_id: str, error: str) -> None:
|
||||
try:
|
||||
with db_session() as conn:
|
||||
ToolCallAttemptsRepository(conn).mark_failed(call_id, error)
|
||||
except Exception:
|
||||
logger.exception("tool_call_attempts failed-write failed for %s", call_id)
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""Handles tool discovery, preparation, and execution.
|
||||
|
||||
@@ -31,6 +120,7 @@ class ToolExecutor:
|
||||
self.tool_calls: List[Dict] = []
|
||||
self._loaded_tools: Dict[str, object] = {}
|
||||
self.conversation_id: Optional[str] = None
|
||||
self.message_id: Optional[str] = None
|
||||
self.client_tools: Optional[List[Dict]] = None
|
||||
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
|
||||
self._tool_to_name: Dict[Tuple[str, str], str] = {}
|
||||
@@ -274,7 +364,14 @@ class ToolExecutor:
|
||||
|
||||
if tool_id is None or action_name is None:
|
||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
|
||||
logger.error(error_message)
|
||||
logger.error(
|
||||
"tool_call_parse_failed",
|
||||
extra={
|
||||
"llm_class_name": llm_class_name,
|
||||
"llm_tool_name": llm_name,
|
||||
"call_id": call_id,
|
||||
},
|
||||
)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
@@ -289,7 +386,15 @@ class ToolExecutor:
|
||||
|
||||
if tool_id not in tools_dict:
|
||||
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
|
||||
logger.error(error_message)
|
||||
logger.error(
|
||||
"tool_id_not_found",
|
||||
extra={
|
||||
"tool_id": tool_id,
|
||||
"llm_tool_name": llm_name,
|
||||
"call_id": call_id,
|
||||
"available_tool_count": len(tools_dict),
|
||||
},
|
||||
)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
@@ -308,9 +413,36 @@ class ToolExecutor:
|
||||
"action_name": llm_name,
|
||||
"arguments": call_args,
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
# Journal first so the reconciler sees malformed calls and any
|
||||
# subsequent ``_mark_failed`` actually updates a real row.
|
||||
proposed_ok = _record_proposed(
|
||||
call_id,
|
||||
tool_data["name"],
|
||||
action_name,
|
||||
call_args if isinstance(call_args, dict) else {},
|
||||
tool_id=tool_data.get("id"),
|
||||
)
|
||||
# Defensive guard: a non-dict ``call_args`` (e.g. malformed
|
||||
# JSON on the resume path) would crash the param walk below
|
||||
# with AttributeError on ``.items()``. Surface a clean error
|
||||
# event and flip the journal row to ``failed`` instead of
|
||||
# killing the stream.
|
||||
if not isinstance(call_args, dict):
|
||||
error_message = (
|
||||
f"Tool call arguments must be a JSON object, got "
|
||||
f"{type(call_args).__name__}."
|
||||
)
|
||||
tool_call_data["result"] = error_message
|
||||
tool_call_data["arguments"] = {}
|
||||
_mark_failed(call_id, error_message)
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {**tool_call_data, "status": "error"},
|
||||
}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return error_message, call_id
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||
action_data = (
|
||||
tool_data["config"]["actions"][action_name]
|
||||
if tool_data["name"] == "api_tool"
|
||||
@@ -356,8 +488,17 @@ class ToolExecutor:
|
||||
f"Failed to load tool '{tool_data.get('name')}' (tool_id key={tool_id}): "
|
||||
"missing 'id' on tool row."
|
||||
)
|
||||
logger.error(error_message)
|
||||
logger.error(
|
||||
"tool_load_failed",
|
||||
extra={
|
||||
"tool_name": tool_data.get("name"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"call_id": call_id,
|
||||
},
|
||||
)
|
||||
tool_call_data["result"] = error_message
|
||||
_mark_failed(call_id, error_message)
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return error_message, call_id
|
||||
@@ -367,14 +508,18 @@ class ToolExecutor:
|
||||
if tool_data["name"] == "api_tool"
|
||||
else parameters
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
logger.debug(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
)
|
||||
result = tool.execute_action(action_name, **body)
|
||||
else:
|
||||
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
|
||||
result = tool.execute_action(action_name, **parameters)
|
||||
try:
|
||||
if tool_data["name"] == "api_tool":
|
||||
logger.debug(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
)
|
||||
result = tool.execute_action(action_name, **body)
|
||||
else:
|
||||
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
|
||||
result = tool.execute_action(action_name, **parameters)
|
||||
except Exception as exc:
|
||||
_mark_failed(call_id, str(exc))
|
||||
raise
|
||||
|
||||
get_artifact_id = (
|
||||
getattr(tool, "get_artifact_id", None)
|
||||
@@ -403,6 +548,22 @@ class ToolExecutor:
|
||||
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
|
||||
)
|
||||
|
||||
# Tool side effect has run; flip the journal row so the
|
||||
# message-finalize path can later confirm it. If the proposed
|
||||
# write failed (DB outage), upsert a fresh row in ``executed`` so
|
||||
# the reconciler still sees the side effect.
|
||||
_mark_executed(
|
||||
call_id,
|
||||
result_full,
|
||||
message_id=self.message_id,
|
||||
artifact_id=artifact_id or None,
|
||||
proposed_ok=proposed_ok,
|
||||
tool_name=tool_data["name"],
|
||||
action_name=action_name,
|
||||
arguments=call_args,
|
||||
tool_id=tool_data.get("id"),
|
||||
)
|
||||
|
||||
stream_tool_call_data = {
|
||||
key: value
|
||||
for key, value in tool_call_data.items()
|
||||
@@ -451,10 +612,12 @@ class ToolExecutor:
|
||||
row_id = tool_data.get("id")
|
||||
if not row_id:
|
||||
logger.error(
|
||||
"Tool data missing 'id' for tool name=%s (enumerate-key tool_id=%s); "
|
||||
"skipping load to avoid binding a non-UUID downstream.",
|
||||
tool_data.get("name"),
|
||||
tool_id,
|
||||
"tool_missing_row_id",
|
||||
extra={
|
||||
"tool_name": tool_data.get("name"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
},
|
||||
)
|
||||
return None
|
||||
tool_config["tool_id"] = str(row_id)
|
||||
|
||||
@@ -39,6 +39,7 @@ class InternalSearchTool(Tool):
|
||||
chunks=int(self.config.get("chunks", 2)),
|
||||
doc_token_limit=int(self.config.get("doc_token_limit", 50000)),
|
||||
model_id=self.config.get("model_id", "docsgpt-local"),
|
||||
model_user_id=self.config.get("model_user_id"),
|
||||
user_api_key=self.config.get("user_api_key"),
|
||||
agent_id=self.config.get("agent_id"),
|
||||
llm_name=self.config.get("llm_name", settings.LLM_PROVIDER),
|
||||
@@ -435,6 +436,7 @@ def build_internal_tool_config(
|
||||
chunks: int = 2,
|
||||
doc_token_limit: int = 50000,
|
||||
model_id: str = "docsgpt-local",
|
||||
model_user_id: Optional[str] = None,
|
||||
user_api_key: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
llm_name: str = None,
|
||||
@@ -449,6 +451,7 @@ def build_internal_tool_config(
|
||||
"chunks": chunks,
|
||||
"doc_token_limit": doc_token_limit,
|
||||
"model_id": model_id,
|
||||
"model_user_id": model_user_id,
|
||||
"user_api_key": user_api_key,
|
||||
"agent_id": agent_id,
|
||||
"llm_name": llm_name or settings.LLM_PROVIDER,
|
||||
|
||||
@@ -20,10 +20,11 @@ from pydantic import AnyHttpUrl, ValidationError
|
||||
from redis import Redis
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
|
||||
from application.api.user.tasks import mcp_oauth_task
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.core.url_validation import SSRFError, validate_url
|
||||
from application.events.keys import stream_key
|
||||
from application.security.encryption import decrypt_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -76,6 +77,12 @@ class MCPTool(Tool):
|
||||
self.oauth_task_id = config.get("oauth_task_id", None)
|
||||
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
|
||||
self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))
|
||||
# Pulled out of ``config`` (rather than left in ``self.config``)
|
||||
# because it is a callable supplied by the OAuth worker — not
|
||||
# something the rest of the tool plumbing should marshal or
|
||||
# serialize. ``DocsGPTOAuth`` invokes it from ``redirect_handler``
|
||||
# so the SSE envelope can carry ``authorization_url``.
|
||||
self.oauth_redirect_publish = config.pop("oauth_redirect_publish", None)
|
||||
|
||||
self.available_tools = []
|
||||
self._cache_key = self._generate_cache_key()
|
||||
@@ -167,6 +174,7 @@ class MCPTool(Tool):
|
||||
redirect_uri=self.redirect_uri,
|
||||
task_id=self.oauth_task_id,
|
||||
user_id=self.user_id,
|
||||
redirect_publish=self.oauth_redirect_publish,
|
||||
)
|
||||
elif self.auth_type == "bearer":
|
||||
token = self.auth_credentials.get(
|
||||
@@ -679,12 +687,17 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
user_id=None,
|
||||
additional_client_metadata: dict[str, Any] | None = None,
|
||||
skip_redirect_validation: bool = False,
|
||||
redirect_publish=None,
|
||||
):
|
||||
self.redirect_uri = redirect_uri
|
||||
self.redis_client = redis_client
|
||||
self.redis_prefix = redis_prefix
|
||||
self.task_id = task_id
|
||||
self.user_id = user_id
|
||||
# Worker-supplied callback. Invoked from ``redirect_handler``
|
||||
# once the authorization URL is known so the SSE envelope can
|
||||
# carry it. ``None`` for any non-worker entrypoint.
|
||||
self.redirect_publish = redirect_publish
|
||||
|
||||
parsed_url = urlparse(mcp_url)
|
||||
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
@@ -744,17 +757,19 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
self.redis_client.setex(key, 600, auth_url)
|
||||
logger.info("Stored auth_url in Redis: %s", key)
|
||||
|
||||
if self.task_id:
|
||||
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||
status_data = {
|
||||
"status": "requires_redirect",
|
||||
"message": "Authorization required",
|
||||
"authorization_url": self.auth_url,
|
||||
"state": self.extracted_state,
|
||||
"requires_oauth": True,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
if self.redirect_publish is not None:
|
||||
# Best-effort: a publish failure must not abort the OAuth
|
||||
# handshake — the user can still authorize via the popup
|
||||
# opened from the legacy polling fallback if the SSE
|
||||
# envelope is lost.
|
||||
try:
|
||||
self.redirect_publish(auth_url)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"redirect_publish callback raised for task_id=%s",
|
||||
self.task_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def callback_handler(self) -> tuple[str, str | None]:
|
||||
"""Wait for auth code from Redis using the state value."""
|
||||
@@ -764,17 +779,6 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
max_wait_time = 300
|
||||
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
|
||||
|
||||
if self.task_id:
|
||||
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||
status_data = {
|
||||
"status": "awaiting_callback",
|
||||
"message": "Waiting for authorization...",
|
||||
"authorization_url": self.auth_url,
|
||||
"state": self.extracted_state,
|
||||
"requires_oauth": True,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
code_data = self.redis_client.get(code_key)
|
||||
@@ -789,14 +793,6 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}state:{self.extracted_state}"
|
||||
)
|
||||
|
||||
if self.task_id:
|
||||
status_data = {
|
||||
"status": "callback_received",
|
||||
"message": "Completing authentication...",
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
return code, returned_state
|
||||
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
|
||||
error_data = self.redis_client.get(error_key)
|
||||
@@ -1038,8 +1034,73 @@ class MCPOAuthManager:
|
||||
logger.error("Error handling OAuth callback: %s", e)
|
||||
return False
|
||||
|
||||
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""Get current status of OAuth flow using provided task_id."""
|
||||
def get_oauth_status(self, task_id: str, user_id: str) -> Dict[str, Any]:
|
||||
"""Return the latest OAuth status for ``task_id`` from the user's SSE journal.
|
||||
|
||||
Mirrors the legacy polling contract: ``status`` derived from the
|
||||
``mcp.oauth.*`` event-type suffix, with payload fields surfaced
|
||||
(e.g. ``tools``/``tools_count`` on ``completed``).
|
||||
"""
|
||||
if not task_id:
|
||||
return {"status": "not_started", "message": "OAuth flow not started"}
|
||||
return mcp_oauth_status_task(task_id)
|
||||
if not user_id:
|
||||
return {"status": "not_found", "message": "User not provided"}
|
||||
if self.redis_client is None:
|
||||
return {"status": "not_found", "message": "Redis unavailable"}
|
||||
|
||||
try:
|
||||
# OAuth flows are short-lived but a concurrent source
|
||||
# ingest can flood the user channel between the OAuth
|
||||
# popup completing and the user clicking Save, pushing the
|
||||
# completion envelope outside the read window. Bound the
|
||||
# scan by the configured stream cap so we cover the full
|
||||
# journal — XADD MAXLEN keeps that bounded too.
|
||||
scan_count = max(settings.EVENTS_STREAM_MAXLEN, 200)
|
||||
entries = self.redis_client.xrevrange(
|
||||
stream_key(user_id), count=scan_count
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"xrevrange failed for oauth status: user_id=%s task_id=%s",
|
||||
user_id,
|
||||
task_id,
|
||||
)
|
||||
return {"status": "not_found", "message": "Status unavailable"}
|
||||
|
||||
for _entry_id, fields in entries:
|
||||
if not isinstance(fields, dict):
|
||||
continue
|
||||
# decode_responses=False ⇒ bytes keys; the string-key fallback
|
||||
# covers a future flip of that default without a forced refactor.
|
||||
event_raw = fields.get(b"event")
|
||||
if event_raw is None:
|
||||
event_raw = fields.get("event")
|
||||
if event_raw is None:
|
||||
continue
|
||||
if isinstance(event_raw, bytes):
|
||||
try:
|
||||
event_raw = event_raw.decode("utf-8")
|
||||
except Exception:
|
||||
continue
|
||||
try:
|
||||
envelope = json.loads(event_raw)
|
||||
except Exception:
|
||||
continue
|
||||
if not isinstance(envelope, dict):
|
||||
continue
|
||||
event_type = envelope.get("type", "")
|
||||
if not isinstance(event_type, str) or not event_type.startswith(
|
||||
"mcp.oauth."
|
||||
):
|
||||
continue
|
||||
scope = envelope.get("scope") or {}
|
||||
if scope.get("kind") != "mcp_oauth" or scope.get("id") != task_id:
|
||||
continue
|
||||
payload = envelope.get("payload") or {}
|
||||
return {
|
||||
"status": event_type[len("mcp.oauth."):],
|
||||
"task_id": task_id,
|
||||
**payload,
|
||||
}
|
||||
|
||||
return {"status": "not_found", "message": "Status not found"}
|
||||
|
||||
@@ -177,3 +177,4 @@ class PostgresTool(Tool):
|
||||
"order": 1,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -57,6 +57,29 @@ class ToolActionParser:
|
||||
def _parse_google_llm(self, call):
|
||||
try:
|
||||
call_args = call.arguments
|
||||
# Gemini's SDK natively returns ``args`` as a dict, but the
|
||||
# resume path (``gen_continuation``) stringifies it for the
|
||||
# assistant message. Coerce a JSON string back into a dict;
|
||||
# fall back to an empty dict on malformed input so downstream
|
||||
# ``call_args.items()`` doesn't crash the stream.
|
||||
if isinstance(call_args, str):
|
||||
try:
|
||||
call_args = json.loads(call_args)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(
|
||||
"Google call.arguments was not valid JSON; "
|
||||
"falling back to empty args for %s",
|
||||
getattr(call, "name", "<unknown>"),
|
||||
)
|
||||
call_args = {}
|
||||
if not isinstance(call_args, dict):
|
||||
logger.warning(
|
||||
"Google call.arguments has unexpected type %s; "
|
||||
"falling back to empty args for %s",
|
||||
type(call_args).__name__,
|
||||
getattr(call, "name", "<unknown>"),
|
||||
)
|
||||
call_args = {}
|
||||
|
||||
resolved = self._resolve_via_mapping(call.name)
|
||||
if resolved:
|
||||
|
||||
@@ -211,15 +211,26 @@ class WorkflowEngine:
|
||||
node_config.json_schema, node.title
|
||||
)
|
||||
node_model_id = node_config.model_id or self.agent.model_id
|
||||
# Inherit BYOM scope from parent agent so owner-stored BYOM
|
||||
# resolves on shared workflows.
|
||||
node_user_id = getattr(self.agent, "model_user_id", None) or (
|
||||
self.agent.decoded_token.get("sub")
|
||||
if isinstance(self.agent.decoded_token, dict)
|
||||
else None
|
||||
)
|
||||
node_llm_name = (
|
||||
node_config.llm_name
|
||||
or get_provider_from_model_id(node_model_id or "")
|
||||
or get_provider_from_model_id(
|
||||
node_model_id or "", user_id=node_user_id
|
||||
)
|
||||
or self.agent.llm_name
|
||||
)
|
||||
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
|
||||
|
||||
if node_json_schema and node_model_id:
|
||||
model_capabilities = get_model_capabilities(node_model_id)
|
||||
model_capabilities = get_model_capabilities(
|
||||
node_model_id, user_id=node_user_id
|
||||
)
|
||||
if model_capabilities and not model_capabilities.get(
|
||||
"supports_structured_output", False
|
||||
):
|
||||
@@ -232,6 +243,7 @@ class WorkflowEngine:
|
||||
"endpoint": self.agent.endpoint,
|
||||
"llm_name": node_llm_name,
|
||||
"model_id": node_model_id,
|
||||
"model_user_id": getattr(self.agent, "model_user_id", None),
|
||||
"api_key": node_api_key,
|
||||
"tool_ids": node_config.tools,
|
||||
"prompt": node_config.system_prompt,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""0001 initial schema — consolidated Phase-1..3 baseline.
|
||||
"""0001 initial schema — consolidated baseline for user-data tables.
|
||||
|
||||
Revision ID: 0001_initial
|
||||
Revises:
|
||||
|
||||
65
application/alembic/versions/0003_user_custom_models.py
Normal file
65
application/alembic/versions/0003_user_custom_models.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""0003 user_custom_models — per-user OpenAI-compatible model registrations.
|
||||
|
||||
Revision ID: 0003_user_custom_models
|
||||
Revises: 0002_app_metadata
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0003_user_custom_models"
|
||||
down_revision: Union[str, None] = "0002_app_metadata"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE user_custom_models (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
upstream_model_id TEXT NOT NULL,
|
||||
display_name TEXT NOT NULL,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
base_url TEXT NOT NULL,
|
||||
api_key_encrypted TEXT NOT NULL,
|
||||
capabilities JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX user_custom_models_user_id_idx "
|
||||
"ON user_custom_models (user_id);"
|
||||
)
|
||||
|
||||
# Mirror the project-wide invariants set up in 0001_initial:
|
||||
# * user_id FK with ON DELETE RESTRICT (deferrable),
|
||||
# * ensure_user_exists() trigger so the parent users row autocreates,
|
||||
# * set_updated_at() trigger.
|
||||
op.execute(
|
||||
"ALTER TABLE user_custom_models "
|
||||
"ADD CONSTRAINT user_custom_models_user_id_fk "
|
||||
"FOREIGN KEY (user_id) REFERENCES users(user_id) "
|
||||
"ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE TRIGGER user_custom_models_ensure_user "
|
||||
"BEFORE INSERT OR UPDATE OF user_id ON user_custom_models "
|
||||
"FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE TRIGGER user_custom_models_set_updated_at "
|
||||
"BEFORE UPDATE ON user_custom_models "
|
||||
"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
|
||||
"EXECUTE FUNCTION set_updated_at();"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP TABLE IF EXISTS user_custom_models;")
|
||||
217
application/alembic/versions/0004_durability_foundation.py
Normal file
217
application/alembic/versions/0004_durability_foundation.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""0004 durability foundation — idempotency, tool-call log, ingest checkpoint.
|
||||
|
||||
Adds ``task_dedup``, ``webhook_dedup``, ``tool_call_attempts``,
|
||||
``ingest_chunk_progress``, and per-row status flags on
|
||||
``conversation_messages`` and ``pending_tool_state``. Also adds
|
||||
``token_usage.source`` and ``token_usage.request_id`` so per-channel
|
||||
cost attribution (``agent_stream`` / ``title`` / ``compression`` /
|
||||
``rag_condense`` / ``fallback``) is queryable and multi-call agent runs
|
||||
can be DISTINCT-collapsed into a single user request for rate limiting.
|
||||
|
||||
Revision ID: 0004_durability_foundation
|
||||
Revises: 0003_user_custom_models
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0004_durability_foundation"
|
||||
down_revision: Union[str, None] = "0003_user_custom_models"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
# New tables
|
||||
# ------------------------------------------------------------------
|
||||
# ``attempt_count`` bounds the per-Celery-task idempotency wrapper's
|
||||
# retry loop so a poison message can't run forever; default 0 means
|
||||
# existing rows behave as if no attempts have run yet.
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE task_dedup (
|
||||
idempotency_key TEXT PRIMARY KEY,
|
||||
task_name TEXT NOT NULL,
|
||||
task_id TEXT NOT NULL,
|
||||
result_json JSONB,
|
||||
status TEXT NOT NULL
|
||||
CHECK (status IN ('pending', 'completed', 'failed')),
|
||||
attempt_count INT NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE webhook_dedup (
|
||||
idempotency_key TEXT PRIMARY KEY,
|
||||
agent_id UUID NOT NULL,
|
||||
task_id TEXT NOT NULL,
|
||||
response_json JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# FK on ``message_id`` uses ``ON DELETE SET NULL`` so the journal row
|
||||
# survives parent-message deletion (compliance / cost-attribution).
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE tool_call_attempts (
|
||||
call_id TEXT PRIMARY KEY,
|
||||
message_id UUID
|
||||
REFERENCES conversation_messages (id)
|
||||
ON DELETE SET NULL,
|
||||
tool_id UUID,
|
||||
tool_name TEXT NOT NULL,
|
||||
action_name TEXT NOT NULL,
|
||||
arguments JSONB NOT NULL,
|
||||
result JSONB,
|
||||
error TEXT,
|
||||
status TEXT NOT NULL
|
||||
CHECK (status IN (
|
||||
'proposed', 'executed', 'confirmed',
|
||||
'compensated', 'failed'
|
||||
)),
|
||||
attempted_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE ingest_chunk_progress (
|
||||
source_id UUID PRIMARY KEY,
|
||||
total_chunks INT NOT NULL,
|
||||
embedded_chunks INT NOT NULL DEFAULT 0,
|
||||
last_index INT NOT NULL DEFAULT -1,
|
||||
last_updated TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Column additions on existing tables
|
||||
# ------------------------------------------------------------------
|
||||
# DEFAULT 'complete' backfills existing rows — they're already done.
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE conversation_messages
|
||||
ADD COLUMN status TEXT NOT NULL DEFAULT 'complete'
|
||||
CHECK (status IN ('pending', 'streaming', 'complete', 'failed')),
|
||||
ADD COLUMN request_id TEXT;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE pending_tool_state
|
||||
ADD COLUMN status TEXT NOT NULL DEFAULT 'pending'
|
||||
CHECK (status IN ('pending', 'resuming')),
|
||||
ADD COLUMN resumed_at TIMESTAMPTZ;
|
||||
"""
|
||||
)
|
||||
|
||||
# Default ``agent_stream`` backfills historical rows under the
|
||||
# assumption they were written from the primary path — pre-fix the
|
||||
# only path that wrote was the error branch reading agent.llm.
|
||||
# ``request_id`` is the stream-scoped UUID stamped by the route on
|
||||
# ``agent.llm`` so multi-tool agent runs (which produce N rows)
|
||||
# collapse to one request via DISTINCT in ``count_in_range``.
|
||||
# Side-channel sources (``title`` / ``compression`` / ``rag_condense``
|
||||
# / ``fallback``) leave it NULL and are excluded from the request
|
||||
# count by source filter.
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE token_usage
|
||||
ADD COLUMN source TEXT NOT NULL DEFAULT 'agent_stream',
|
||||
ADD COLUMN request_id TEXT;
|
||||
"""
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Indexes — partial where the predicate selects only non-terminal rows
|
||||
# ------------------------------------------------------------------
|
||||
op.execute(
|
||||
"CREATE INDEX conversation_messages_pending_ts_idx "
|
||||
"ON conversation_messages (timestamp) "
|
||||
"WHERE status IN ('pending', 'streaming');"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX tool_call_attempts_pending_ts_idx "
|
||||
"ON tool_call_attempts (attempted_at) "
|
||||
"WHERE status IN ('proposed', 'executed');"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX tool_call_attempts_message_idx "
|
||||
"ON tool_call_attempts (message_id) "
|
||||
"WHERE message_id IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX pending_tool_state_resuming_ts_idx "
|
||||
"ON pending_tool_state (resumed_at) "
|
||||
"WHERE status = 'resuming';"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX webhook_dedup_agent_idx "
|
||||
"ON webhook_dedup (agent_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX task_dedup_pending_attempts_idx "
|
||||
"ON task_dedup (attempt_count) WHERE status = 'pending';"
|
||||
)
|
||||
# Cost-attribution dashboards filter ``token_usage`` by
|
||||
# ``(timestamp, source)``; index the same shape so they stay cheap.
|
||||
op.execute(
|
||||
"CREATE INDEX token_usage_source_ts_idx "
|
||||
"ON token_usage (source, timestamp);"
|
||||
)
|
||||
# Partial index — only rows with a stamped request_id participate
|
||||
# in the DISTINCT count. NULL rows fall through to the COUNT(*)
|
||||
# branch in the repository query.
|
||||
op.execute(
|
||||
"CREATE INDEX token_usage_request_id_idx "
|
||||
"ON token_usage (request_id) "
|
||||
"WHERE request_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE TRIGGER tool_call_attempts_set_updated_at "
|
||||
"BEFORE UPDATE ON tool_call_attempts "
|
||||
"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
|
||||
"EXECUTE FUNCTION set_updated_at();"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# CASCADE so the downgrade stays safe if later migrations FK into these.
|
||||
for table in (
|
||||
"ingest_chunk_progress",
|
||||
"tool_call_attempts",
|
||||
"webhook_dedup",
|
||||
"task_dedup",
|
||||
):
|
||||
op.execute(f"DROP TABLE IF EXISTS {table} CASCADE;")
|
||||
|
||||
op.execute(
|
||||
"ALTER TABLE conversation_messages "
|
||||
"DROP COLUMN IF EXISTS request_id, "
|
||||
"DROP COLUMN IF EXISTS status;"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE pending_tool_state "
|
||||
"DROP COLUMN IF EXISTS resumed_at, "
|
||||
"DROP COLUMN IF EXISTS status;"
|
||||
)
|
||||
op.execute("DROP INDEX IF EXISTS token_usage_request_id_idx;")
|
||||
op.execute("DROP INDEX IF EXISTS token_usage_source_ts_idx;")
|
||||
op.execute(
|
||||
"ALTER TABLE token_usage "
|
||||
"DROP COLUMN IF EXISTS request_id, "
|
||||
"DROP COLUMN IF EXISTS source;"
|
||||
)
|
||||
44
application/alembic/versions/0005_ingest_attempt_id.py
Normal file
44
application/alembic/versions/0005_ingest_attempt_id.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""0005 ingest_chunk_progress.attempt_id — per-attempt resume scoping.
|
||||
|
||||
Without this column, a completed checkpoint row poisoned every later
|
||||
embed call on the same ``source_id``: a sync after an upload finished
|
||||
read the upload's terminal ``last_index`` and either embedded zero
|
||||
chunks (if new ``total_docs <= last_index + 1``) or stacked new chunks
|
||||
on top of the old vectors (if ``total_docs > last_index + 1``).
|
||||
|
||||
``attempt_id`` is stamped from ``self.request.id`` (Celery's stable
|
||||
task id, which survives ``acks_late`` retries of the same task but
|
||||
differs across separate task invocations). The repository's
|
||||
``init_progress`` upsert resets ``last_index`` / ``embedded_chunks``
|
||||
when the incoming ``attempt_id`` differs from the stored one — so a
|
||||
fresh sync starts from chunk 0 while a retry of the same task resumes
|
||||
from the last checkpointed chunk.
|
||||
|
||||
Revision ID: 0005_ingest_attempt_id
|
||||
Revises: 0004_durability_foundation
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0005_ingest_attempt_id"
|
||||
down_revision: Union[str, None] = "0004_durability_foundation"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE ingest_chunk_progress
|
||||
ADD COLUMN attempt_id TEXT;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"ALTER TABLE ingest_chunk_progress DROP COLUMN IF EXISTS attempt_id;"
|
||||
)
|
||||
57
application/alembic/versions/0006_idempotency_lease.py
Normal file
57
application/alembic/versions/0006_idempotency_lease.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""0006 task_dedup lease columns — running-lease for in-flight tasks.
|
||||
|
||||
Without these, ``with_idempotency`` only short-circuits *completed*
|
||||
rows. A late-ack redelivery (Redis ``visibility_timeout`` exceeded by a
|
||||
long ingest, or a hung-but-alive worker) hands the same message to a
|
||||
second worker; ``_claim_or_bump`` only bumped the attempt counter and
|
||||
both workers ran the task body in parallel — duplicate vector writes,
|
||||
duplicate token spend, duplicate webhook side effects.
|
||||
|
||||
``lease_owner_id`` + ``lease_expires_at`` turn that into an atomic
|
||||
compare-and-swap. The wrapper claims a lease at entry, refreshes it via
|
||||
a 30 s heartbeat thread, and finalises (which makes the lease moot via
|
||||
``status='completed'``). A second worker hitting the same key sees a
|
||||
fresh lease and ``self.retry(countdown=LEASE_TTL)``s instead of running.
|
||||
A crashed worker's lease expires after ``LEASE_TTL`` seconds and the
|
||||
next retry can claim it.
|
||||
|
||||
Revision ID: 0006_idempotency_lease
|
||||
Revises: 0005_ingest_attempt_id
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0006_idempotency_lease"
|
||||
down_revision: Union[str, None] = "0005_ingest_attempt_id"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE task_dedup
|
||||
ADD COLUMN lease_owner_id TEXT,
|
||||
ADD COLUMN lease_expires_at TIMESTAMPTZ;
|
||||
"""
|
||||
)
|
||||
# Reconciler's stuck-pending sweep filters by
|
||||
# ``(status='pending', lease_expires_at < now() - 60s, attempt_count >= 5)``.
|
||||
# Partial index keeps the scan small even under heavy task throughput.
|
||||
op.execute(
|
||||
"CREATE INDEX task_dedup_pending_lease_idx "
|
||||
"ON task_dedup (lease_expires_at) "
|
||||
"WHERE status = 'pending';"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS task_dedup_pending_lease_idx;")
|
||||
op.execute(
|
||||
"ALTER TABLE task_dedup "
|
||||
"DROP COLUMN IF EXISTS lease_expires_at, "
|
||||
"DROP COLUMN IF EXISTS lease_owner_id;"
|
||||
)
|
||||
40
application/alembic/versions/0007_message_events.py
Normal file
40
application/alembic/versions/0007_message_events.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""0007 message_events — durable journal of chat-stream events.
|
||||
|
||||
Snapshot half of the chat-stream snapshot+tail pattern. Composite PK
|
||||
``(message_id, sequence_no)``, ``created_at`` indexed for retention
|
||||
sweeps, ``ON DELETE CASCADE`` from ``conversation_messages``.
|
||||
|
||||
Revision ID: 0007_message_events
|
||||
Revises: 0006_idempotency_lease
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0007_message_events"
|
||||
down_revision: Union[str, None] = "0006_idempotency_lease"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE message_events (
|
||||
message_id UUID NOT NULL REFERENCES conversation_messages(id) ON DELETE CASCADE,
|
||||
sequence_no INTEGER NOT NULL,
|
||||
event_type TEXT NOT NULL,
|
||||
payload JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
PRIMARY KEY (message_id, sequence_no)
|
||||
);
|
||||
CREATE INDEX message_events_created_at_idx ON message_events(created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS message_events_created_at_idx;")
|
||||
op.execute("DROP TABLE IF EXISTS message_events;")
|
||||
@@ -102,6 +102,8 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
"reserved_message_id": processor.reserved_message_id,
|
||||
"request_id": processor.request_id,
|
||||
},
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
from flask import jsonify, make_response, Response
|
||||
from flask_restx import Namespace
|
||||
|
||||
from application.api.answer.services.continuation_service import ContinuationService
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
TERMINATED_RESPONSE_PLACEHOLDER,
|
||||
)
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
@@ -18,9 +23,16 @@ from application.core.settings import settings
|
||||
from application.error import sanitize_api_error
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.conversations import MessageUpdateOutcome
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
from application.storage.db.repositories.user_logs import UserLogsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.events.publisher import publish_user_event
|
||||
from application.streaming.event_replay import format_sse_event
|
||||
from application.streaming.message_journal import (
|
||||
BatchedJournalWriter,
|
||||
record_event,
|
||||
)
|
||||
from application.utils import check_required_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -177,6 +189,7 @@ class BaseAnswerResource:
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
model_id: Optional[str] = None,
|
||||
model_user_id: Optional[str] = None,
|
||||
_continuation: Optional[Dict] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
@@ -202,13 +215,188 @@ class BaseAnswerResource:
|
||||
Yields:
|
||||
Server-sent event strings
|
||||
"""
|
||||
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
query_metadata: Dict[str, Any] = {}
|
||||
paused = False
|
||||
|
||||
# One id shared across the WAL row, primary LLM (token_usage
|
||||
# attribution), the SSE event, and resumed continuations.
|
||||
request_id = (
|
||||
_continuation.get("request_id") if _continuation else None
|
||||
) or str(uuid.uuid4())
|
||||
|
||||
# Reserve the placeholder row before the LLM call so a crash
|
||||
# mid-stream still leaves the question queryable. Continuations
|
||||
# reuse the original placeholder.
|
||||
reserved_message_id: Optional[str] = None
|
||||
wal_eligible = should_save_conversation and not _continuation
|
||||
if wal_eligible:
|
||||
try:
|
||||
reservation = self.conversation_service.save_user_question(
|
||||
conversation_id=conversation_id,
|
||||
question=question,
|
||||
decoded_token=decoded_token,
|
||||
attachment_ids=attachment_ids,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
model_id=model_id or self.default_model_id,
|
||||
request_id=request_id,
|
||||
index=index,
|
||||
)
|
||||
conversation_id = reservation["conversation_id"]
|
||||
reserved_message_id = reservation["message_id"]
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to reserve message row before stream: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
elif _continuation and _continuation.get("reserved_message_id"):
|
||||
reserved_message_id = _continuation["reserved_message_id"]
|
||||
|
||||
primary_llm = getattr(agent, "llm", None)
|
||||
if primary_llm is not None:
|
||||
primary_llm._request_id = request_id
|
||||
|
||||
# Flipped to ``streaming`` on first chunk; reconciler uses this
|
||||
# to tell "never started" from "in flight".
|
||||
streaming_marked = False
|
||||
# Heartbeat goes into ``metadata.last_heartbeat_at`` (not
|
||||
# ``updated_at``, which reconciler-side writes share) and uses
|
||||
# ``time.monotonic`` so a blocked event loop can't fake fresh.
|
||||
STREAM_HEARTBEAT_INTERVAL = 60
|
||||
last_heartbeat_at = time.monotonic()
|
||||
|
||||
def _mark_streaming_once() -> None:
|
||||
nonlocal streaming_marked, last_heartbeat_at
|
||||
if streaming_marked or not reserved_message_id:
|
||||
return
|
||||
try:
|
||||
self.conversation_service.update_message_status(
|
||||
reserved_message_id, "streaming",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"update_message_status streaming failed for %s",
|
||||
reserved_message_id,
|
||||
)
|
||||
# Seed last_heartbeat_at so watchdog doesn't fall back to `timestamp`
|
||||
# (creation time) before the first STREAM_HEARTBEAT_INTERVAL tick.
|
||||
try:
|
||||
self.conversation_service.heartbeat_message(
|
||||
reserved_message_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"initial heartbeat seed failed for %s",
|
||||
reserved_message_id,
|
||||
)
|
||||
streaming_marked = True
|
||||
last_heartbeat_at = time.monotonic()
|
||||
|
||||
def _heartbeat_streaming() -> None:
|
||||
nonlocal last_heartbeat_at
|
||||
if not reserved_message_id or not streaming_marked:
|
||||
return
|
||||
now_mono = time.monotonic()
|
||||
if now_mono - last_heartbeat_at < STREAM_HEARTBEAT_INTERVAL:
|
||||
return
|
||||
try:
|
||||
self.conversation_service.heartbeat_message(
|
||||
reserved_message_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"stream heartbeat update failed for %s",
|
||||
reserved_message_id,
|
||||
)
|
||||
last_heartbeat_at = now_mono
|
||||
|
||||
# Correlates tool_call_attempts rows with this message.
|
||||
if reserved_message_id and getattr(agent, "tool_executor", None):
|
||||
try:
|
||||
agent.tool_executor.message_id = reserved_message_id
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Could not set tool_executor.message_id; tool-call correlation will be missing for message_id=%s",
|
||||
reserved_message_id,
|
||||
)
|
||||
|
||||
# Per-stream monotonic SSE event id. Allocated by ``_emit`` and
|
||||
# threaded through both the wire format (``id: <seq>\\n``) and
|
||||
# the journal write so a reconnecting client can ``Last-Event-
|
||||
# ID`` past anything they already saw. Continuations resume
|
||||
# against the original ``reserved_message_id`` — seed the
|
||||
# allocator from the journal's high-water mark so we don't
|
||||
# collide on the duplicate-PK and silently lose every emit
|
||||
# past the resume point.
|
||||
sequence_no = -1
|
||||
if _continuation and reserved_message_id:
|
||||
try:
|
||||
from application.storage.db.repositories.message_events import (
|
||||
MessageEventsRepository,
|
||||
)
|
||||
|
||||
with db_readonly() as conn:
|
||||
latest = MessageEventsRepository(conn).latest_sequence_no(
|
||||
reserved_message_id
|
||||
)
|
||||
if latest is not None:
|
||||
sequence_no = latest
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Continuation seq seed lookup failed for message_id=%s; "
|
||||
"falling back to seq=-1 (duplicate-PK collisions will "
|
||||
"be swallowed)",
|
||||
reserved_message_id,
|
||||
)
|
||||
|
||||
# One batched journal writer per stream.
|
||||
journal_writer: Optional[BatchedJournalWriter] = (
|
||||
BatchedJournalWriter(reserved_message_id)
|
||||
if reserved_message_id
|
||||
else None
|
||||
)
|
||||
|
||||
def _emit(payload: dict) -> str:
|
||||
"""Format-and-journal one SSE event.
|
||||
|
||||
With a reserved ``message_id``, buffers into the journal and
|
||||
emits ``id: <seq>``-tagged SSE frames; otherwise falls back to
|
||||
legacy ``data: ...\\n\\n`` framing.
|
||||
"""
|
||||
nonlocal sequence_no
|
||||
if not reserved_message_id or journal_writer is None:
|
||||
return f"data: {json.dumps(payload)}\n\n"
|
||||
sequence_no += 1
|
||||
seq = sequence_no
|
||||
event_type = (
|
||||
payload.get("type", "data")
|
||||
if isinstance(payload, dict)
|
||||
else "data"
|
||||
)
|
||||
normalised = payload if isinstance(payload, dict) else {"value": payload}
|
||||
journal_writer.record(seq, event_type, normalised)
|
||||
return format_sse_event(normalised, seq)
|
||||
|
||||
try:
|
||||
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
query_metadata = {}
|
||||
paused = False
|
||||
# Surface the placeholder id before any LLM tokens so a
|
||||
# mid-handshake disconnect still has a row to tail-poll.
|
||||
if reserved_message_id:
|
||||
yield _emit(
|
||||
{
|
||||
"type": "message_id",
|
||||
"message_id": reserved_message_id,
|
||||
"conversation_id": (
|
||||
str(conversation_id) if conversation_id else None
|
||||
),
|
||||
"request_id": request_id,
|
||||
}
|
||||
)
|
||||
|
||||
if _continuation:
|
||||
gen_iter = agent.gen_continuation(
|
||||
@@ -221,18 +409,24 @@ class BaseAnswerResource:
|
||||
gen_iter = agent.gen(query=question)
|
||||
|
||||
for line in gen_iter:
|
||||
# Cheap closure check that only hits the DB when the
|
||||
# heartbeat interval has elapsed.
|
||||
_heartbeat_streaming()
|
||||
if "metadata" in line:
|
||||
query_metadata.update(line["metadata"])
|
||||
elif "answer" in line:
|
||||
_mark_streaming_once()
|
||||
response_full += str(line["answer"])
|
||||
if line.get("structured"):
|
||||
is_structured = True
|
||||
schema_info = line.get("schema")
|
||||
structured_chunks.append(line["answer"])
|
||||
else:
|
||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(
|
||||
{"type": "answer", "answer": line["answer"]}
|
||||
)
|
||||
elif "sources" in line:
|
||||
_mark_streaming_once()
|
||||
truncated_sources = []
|
||||
source_log_docs = line["sources"]
|
||||
for source in line["sources"]:
|
||||
@@ -243,54 +437,58 @@ class BaseAnswerResource:
|
||||
)
|
||||
truncated_sources.append(truncated_source)
|
||||
if truncated_sources:
|
||||
data = json.dumps(
|
||||
yield _emit(
|
||||
{"type": "source", "source": truncated_sources}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
elif "tool_calls" in line:
|
||||
tool_calls = line["tool_calls"]
|
||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
elif "thought" in line:
|
||||
thought += line["thought"]
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "thought", "thought": line["thought"]})
|
||||
elif "type" in line:
|
||||
if line.get("type") == "tool_calls_pending":
|
||||
# Save continuation state and end the stream
|
||||
paused = True
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(line)
|
||||
elif line.get("type") == "error":
|
||||
sanitized_error = {
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(line.get("error", "An error occurred"))
|
||||
}
|
||||
data = json.dumps(sanitized_error)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(
|
||||
{
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(
|
||||
line.get("error", "An error occurred")
|
||||
),
|
||||
}
|
||||
)
|
||||
else:
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(line)
|
||||
if is_structured and structured_chunks:
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
"answer": response_full,
|
||||
"structured": True,
|
||||
"schema": schema_info,
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(
|
||||
{
|
||||
"type": "structured_answer",
|
||||
"answer": response_full,
|
||||
"structured": True,
|
||||
"schema": schema_info,
|
||||
}
|
||||
)
|
||||
|
||||
# ---- Paused: save continuation state and end stream early ----
|
||||
if paused:
|
||||
continuation = getattr(agent, "_pending_continuation", None)
|
||||
if continuation:
|
||||
# Ensure we have a conversation_id — create a partial
|
||||
# conversation if this is the first turn.
|
||||
# First-turn pause needs a conversation row to attach to.
|
||||
if not conversation_id and should_save_conversation:
|
||||
try:
|
||||
provider = (
|
||||
get_provider_from_model_id(model_id)
|
||||
get_provider_from_model_id(
|
||||
model_id,
|
||||
user_id=model_user_id
|
||||
or (
|
||||
decoded_token.get("sub")
|
||||
if decoded_token
|
||||
else None
|
||||
),
|
||||
)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
@@ -304,6 +502,7 @@ class BaseAnswerResource:
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
conversation_id = (
|
||||
self.conversation_service.save_conversation(
|
||||
@@ -328,6 +527,7 @@ class BaseAnswerResource:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
state_saved = False
|
||||
if conversation_id:
|
||||
try:
|
||||
cont_service = ContinuationService()
|
||||
@@ -340,6 +540,9 @@ class BaseAnswerResource:
|
||||
tool_schemas=getattr(agent, "tools", []),
|
||||
agent_config={
|
||||
"model_id": model_id or self.default_model_id,
|
||||
# BYOM scope; without it resume falls
|
||||
# back to caller's layer.
|
||||
"model_user_id": model_user_id,
|
||||
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
|
||||
"api_key": getattr(agent, "api_key", None),
|
||||
"user_api_key": user_api_key,
|
||||
@@ -348,30 +551,87 @@ class BaseAnswerResource:
|
||||
"prompt": getattr(agent, "prompt", ""),
|
||||
"json_schema": getattr(agent, "json_schema", None),
|
||||
"retriever_config": getattr(agent, "retriever_config", None),
|
||||
# Reused on resume so the same WAL row
|
||||
# is finalised and request_id stays
|
||||
# consistent across token_usage rows.
|
||||
"reserved_message_id": reserved_message_id,
|
||||
"request_id": request_id,
|
||||
},
|
||||
client_tools=getattr(
|
||||
agent.tool_executor, "client_tools", None
|
||||
),
|
||||
)
|
||||
state_saved = True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save continuation state: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
# Notify the user out-of-band so they can navigate
|
||||
# back to the conversation and decide on the
|
||||
# pending tool calls. Gated on ``state_saved``: a
|
||||
# missing pending_tool_state row would 404 the
|
||||
# resume endpoint, so an unfulfillable notification
|
||||
# is worse than no notification.
|
||||
user_id_for_event = (
|
||||
decoded_token.get("sub") if decoded_token else None
|
||||
)
|
||||
if state_saved and user_id_for_event and conversation_id:
|
||||
pending_calls = continuation.get(
|
||||
"pending_tool_calls", []
|
||||
) if continuation else []
|
||||
# Trim each pending tool call to its identifying
|
||||
# metadata so a tool with a multi-MB argument
|
||||
# doesn't blow out the per-event payload size
|
||||
# cap. The resume page fetches full args from
|
||||
# ``pending_tool_state`` regardless.
|
||||
pending_summaries = [
|
||||
{
|
||||
k: tc.get(k)
|
||||
for k in (
|
||||
"call_id",
|
||||
"tool_name",
|
||||
"action_name",
|
||||
"name",
|
||||
)
|
||||
if isinstance(tc, dict) and tc.get(k) is not None
|
||||
}
|
||||
for tc in (pending_calls or [])
|
||||
if isinstance(tc, dict)
|
||||
]
|
||||
publish_user_event(
|
||||
user_id_for_event,
|
||||
"tool.approval.required",
|
||||
{
|
||||
"conversation_id": str(conversation_id),
|
||||
"message_id": reserved_message_id,
|
||||
"pending_tool_calls": pending_summaries,
|
||||
},
|
||||
scope={
|
||||
"kind": "conversation",
|
||||
"id": str(conversation_id),
|
||||
},
|
||||
)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "id", "id": str(conversation_id)})
|
||||
yield _emit({"type": "end"})
|
||||
# Drain the terminal ``end`` so a reconnecting client
|
||||
# sees it on snapshot — same reason as the main exit.
|
||||
if journal_writer is not None:
|
||||
journal_writer.close()
|
||||
return
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
# Model-owner scope so title-gen uses owner's BYOM key.
|
||||
provider = (
|
||||
get_provider_from_model_id(model_id)
|
||||
get_provider_from_model_id(
|
||||
model_id,
|
||||
user_id=model_user_id
|
||||
or (decoded_token.get("sub") if decoded_token else None),
|
||||
)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
@@ -384,27 +644,51 @@ class BaseAnswerResource:
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
# Title-gen only; agent stream tokens live on ``agent.llm``.
|
||||
llm._token_usage_source = "title"
|
||||
|
||||
if should_save_conversation:
|
||||
conversation_id = self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
if reserved_message_id is not None:
|
||||
self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full,
|
||||
thought=thought,
|
||||
sources=source_log_docs,
|
||||
tool_calls=tool_calls,
|
||||
model_id=model_id or self.default_model_id,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
status="complete",
|
||||
title_inputs={
|
||||
"llm": llm,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"model_id": model_id or self.default_model_id,
|
||||
"fallback_name": (
|
||||
question[:50] if question else "New Conversation"
|
||||
),
|
||||
},
|
||||
)
|
||||
else:
|
||||
conversation_id = self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
@@ -427,9 +711,22 @@ class BaseAnswerResource:
|
||||
)
|
||||
else:
|
||||
conversation_id = None
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
# Resume finished cleanly; drop the continuation row.
|
||||
# Crash-paths leave it ``resuming`` for the janitor to revert.
|
||||
if _continuation and conversation_id:
|
||||
try:
|
||||
cont_service = ContinuationService()
|
||||
cont_service.delete_state(
|
||||
str(conversation_id),
|
||||
decoded_token.get("sub", "local"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete continuation state on resume "
|
||||
f"completion: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
yield _emit({"type": "id", "id": str(conversation_id)})
|
||||
|
||||
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
|
||||
getattr(agent, "tool_calls", tool_calls) or tool_calls
|
||||
@@ -470,42 +767,117 @@ class BaseAnswerResource:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "end"})
|
||||
# Drain the journal buffer so the terminal ``end`` event is
|
||||
# visible to any reconnecting client. Without this the
|
||||
# client could snapshot up to the last flush boundary and
|
||||
# then live-tail waiting for an ``end`` that's still
|
||||
# sitting in memory.
|
||||
if journal_writer is not None:
|
||||
journal_writer.close()
|
||||
except GeneratorExit:
|
||||
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
|
||||
# Drain any buffered events before the terminal one-shot
|
||||
# ``record_event`` below — keeps the journal's seq order
|
||||
# contiguous (buffered events ... terminal event). ``close``
|
||||
# is idempotent; pairing it with ``flush`` matches the
|
||||
# normal-exit and error branches so any future ``record()``
|
||||
# past this point would log instead of silently buffering.
|
||||
if journal_writer is not None:
|
||||
journal_writer.flush()
|
||||
journal_writer.close()
|
||||
# Save partial response
|
||||
|
||||
# Whether the DB row was flipped to ``complete`` during this
|
||||
# abort handler. Drives the choice of terminal journal event
|
||||
# below: journal ``end`` only when the row actually matches,
|
||||
# else journal ``error`` so a reconnecting client sees a
|
||||
# failed terminal state instead of a blank "success".
|
||||
finalized_complete = False
|
||||
if should_save_conversation and response_full:
|
||||
try:
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
# Resolve under model-owner scope so shared-agent
|
||||
# title-gen uses owner BYOM, not deployment default.
|
||||
provider = (
|
||||
get_provider_from_model_id(
|
||||
model_id,
|
||||
user_id=model_user_id
|
||||
or (
|
||||
decoded_token.get("sub")
|
||||
if decoded_token
|
||||
else None
|
||||
),
|
||||
)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
sys_api_key = get_api_key_for_provider(
|
||||
provider or settings.LLM_PROVIDER
|
||||
)
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_PROVIDER,
|
||||
api_key=settings.API_KEY,
|
||||
provider or settings.LLM_PROVIDER,
|
||||
api_key=sys_api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
llm._token_usage_source = "title"
|
||||
if reserved_message_id is not None:
|
||||
outcome = self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full,
|
||||
thought=thought,
|
||||
sources=source_log_docs,
|
||||
tool_calls=tool_calls,
|
||||
model_id=model_id or self.default_model_id,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
status="complete",
|
||||
title_inputs={
|
||||
"llm": llm,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"model_id": model_id or self.default_model_id,
|
||||
"fallback_name": (
|
||||
question[:50] if question else "New Conversation"
|
||||
),
|
||||
},
|
||||
)
|
||||
# ``ALREADY_COMPLETE`` means the normal-path
|
||||
# finalize at line 632 won the race: the DB row
|
||||
# is already at ``complete`` and the reconnect
|
||||
# journal should reflect that with ``end``,
|
||||
# not a spurious ``error``.
|
||||
finalized_complete = outcome in (
|
||||
MessageUpdateOutcome.UPDATED,
|
||||
MessageUpdateOutcome.ALREADY_COMPLETE,
|
||||
)
|
||||
else:
|
||||
self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
# No journal row to gate, but flag the save as
|
||||
# successful for symmetry with the WAL path.
|
||||
finalized_complete = True
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
if conversation_id and compression_meta and not compression_saved:
|
||||
@@ -529,16 +901,94 @@ class BaseAnswerResource:
|
||||
logger.error(
|
||||
f"Error saving partial response: {str(e)}", exc_info=True
|
||||
)
|
||||
# Journal a terminal event so reconnecting clients stop tailing;
|
||||
# ``end`` only when the row is ``complete``, else ``error``.
|
||||
if reserved_message_id is not None:
|
||||
try:
|
||||
sequence_no += 1
|
||||
if finalized_complete:
|
||||
# Match the wire shape ``_emit({"type": "end"})``
|
||||
# uses on the normal path — the replay terminal
|
||||
# check at ``event_replay._payload_is_terminal``
|
||||
# reads ``payload.type``, and the frontend parses
|
||||
# the same key off ``data:``.
|
||||
record_event(
|
||||
reserved_message_id,
|
||||
sequence_no,
|
||||
"end",
|
||||
{"type": "end"},
|
||||
)
|
||||
else:
|
||||
# Nothing was persisted under the complete status
|
||||
# — mark the row failed so the reconciler doesn't
|
||||
# need to sweep it, and journal an ``error`` so a
|
||||
# reconnecting client surfaces the same failure
|
||||
# the UI would show on a live error.
|
||||
try:
|
||||
self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full or TERMINATED_RESPONSE_PLACEHOLDER,
|
||||
thought=thought,
|
||||
sources=source_log_docs,
|
||||
tool_calls=tool_calls,
|
||||
model_id=model_id or self.default_model_id,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
status="failed",
|
||||
error=ConnectionError(
|
||||
"client disconnected before response was persisted"
|
||||
),
|
||||
)
|
||||
except Exception as fin_err:
|
||||
logger.error(
|
||||
f"Failed to mark aborted message failed: {fin_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
record_event(
|
||||
reserved_message_id,
|
||||
sequence_no,
|
||||
"error",
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Stream aborted before any response was produced.",
|
||||
"code": "client_disconnect",
|
||||
},
|
||||
)
|
||||
except Exception as journal_err:
|
||||
logger.error(
|
||||
f"Failed to journal terminal event on abort: {journal_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
||||
data = json.dumps(
|
||||
if reserved_message_id is not None:
|
||||
try:
|
||||
self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full or TERMINATED_RESPONSE_PLACEHOLDER,
|
||||
thought=thought,
|
||||
sources=source_log_docs,
|
||||
tool_calls=tool_calls,
|
||||
model_id=model_id or self.default_model_id,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
status="failed",
|
||||
error=e,
|
||||
)
|
||||
except Exception as fin_err:
|
||||
logger.error(
|
||||
f"Failed to finalize errored message: {fin_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
yield _emit(
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Please try again later. We apologize for any inconvenience.",
|
||||
}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
# Drain the terminal ``error`` event we just yielded so a
|
||||
# reconnecting client sees it on snapshot.
|
||||
if journal_writer is not None:
|
||||
journal_writer.close()
|
||||
return
|
||||
|
||||
def process_response_stream(self, stream) -> Dict[str, Any]:
|
||||
@@ -560,8 +1010,22 @@ class BaseAnswerResource:
|
||||
|
||||
for line in stream:
|
||||
try:
|
||||
event_data = line.replace("data: ", "").strip()
|
||||
# Each chunk may carry an ``id: <seq>`` header before
|
||||
# the ``data:`` line. Pull just the ``data:`` body so
|
||||
# the JSON decode doesn't choke on the SSE framing.
|
||||
event_data = ""
|
||||
for raw in line.split("\n"):
|
||||
if raw.startswith("data:"):
|
||||
event_data = raw[len("data:") :].lstrip()
|
||||
break
|
||||
if not event_data:
|
||||
continue
|
||||
event = json.loads(event_data)
|
||||
# The ``message_id`` event is informational for the
|
||||
# streaming consumer and has no synchronous-API field;
|
||||
# skip it so the type-switch below doesn't KeyError.
|
||||
if event.get("type") == "message_id":
|
||||
continue
|
||||
|
||||
if event["type"] == "id":
|
||||
conversation_id = event["id"]
|
||||
|
||||
135
application/api/answer/routes/messages.py
Normal file
135
application/api/answer/routes/messages.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""GET /api/messages/<message_id>/events — chat-stream reconnect endpoint.
|
||||
|
||||
Authenticates the caller, verifies ``message_id`` belongs to the user,
|
||||
then hands off to ``build_message_event_stream`` for snapshot+tail.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.streaming.event_replay import (
|
||||
DEFAULT_KEEPALIVE_SECONDS,
|
||||
DEFAULT_POLL_TIMEOUT_SECONDS,
|
||||
build_message_event_stream,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
messages_bp = Blueprint("message_stream", __name__)
|
||||
|
||||
# A message_id is the canonical UUID hex format. Reject anything else
|
||||
# before the SQL layer so a malformed cookie can't surface as a 500.
|
||||
_MESSAGE_ID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
|
||||
r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
# ``sequence_no`` is a non-negative decimal integer. Anything else is
|
||||
# corrupt client state — fall through to a fresh-replay cursor and let
|
||||
# the snapshot reader catch the client up.
|
||||
_SEQUENCE_NO_RE = re.compile(r"^\d+$")
|
||||
|
||||
|
||||
def _normalise_last_event_id(raw: Optional[str]) -> Optional[int]:
|
||||
if raw is None:
|
||||
return None
|
||||
raw = raw.strip()
|
||||
if not raw or not _SEQUENCE_NO_RE.match(raw):
|
||||
return None
|
||||
return int(raw)
|
||||
|
||||
|
||||
def _user_owns_message(message_id: str, user_id: str) -> bool:
|
||||
"""Return True iff ``message_id`` belongs to ``user_id``."""
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
row = conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT 1 FROM conversation_messages
|
||||
WHERE id = CAST(:id AS uuid)
|
||||
AND user_id = :u
|
||||
LIMIT 1
|
||||
"""
|
||||
),
|
||||
{"id": message_id, "u": user_id},
|
||||
).first()
|
||||
return row is not None
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Ownership lookup failed for message_id=%s user_id=%s",
|
||||
message_id,
|
||||
user_id,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
@messages_bp.route("/api/messages/<message_id>/events", methods=["GET"])
|
||||
def stream_message_events(message_id: str) -> Response:
|
||||
decoded = getattr(request, "decoded_token", None)
|
||||
user_id = decoded.get("sub") if isinstance(decoded, dict) else None
|
||||
if not user_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}),
|
||||
401,
|
||||
)
|
||||
|
||||
if not _MESSAGE_ID_RE.match(message_id):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid message id"}),
|
||||
400,
|
||||
)
|
||||
|
||||
if not _user_owns_message(message_id, user_id):
|
||||
# Don't disclose whether the row exists — a malicious caller
|
||||
# gets the same 404 whether the id is bogus, taken by another
|
||||
# user, or simply gone.
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Not found"}),
|
||||
404,
|
||||
)
|
||||
|
||||
raw_cursor = request.headers.get("Last-Event-ID") or request.args.get(
|
||||
"last_event_id"
|
||||
)
|
||||
last_event_id = _normalise_last_event_id(raw_cursor)
|
||||
keepalive_seconds = float(
|
||||
getattr(settings, "SSE_KEEPALIVE_SECONDS", DEFAULT_KEEPALIVE_SECONDS)
|
||||
)
|
||||
|
||||
@stream_with_context
|
||||
def generate() -> Iterator[str]:
|
||||
try:
|
||||
yield from build_message_event_stream(
|
||||
message_id,
|
||||
last_event_id=last_event_id,
|
||||
keepalive_seconds=keepalive_seconds,
|
||||
poll_timeout_seconds=DEFAULT_POLL_TIMEOUT_SECONDS,
|
||||
)
|
||||
except GeneratorExit:
|
||||
return
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Reconnect stream crashed for message_id=%s user_id=%s",
|
||||
message_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
response = Response(generate(), mimetype="text/event-stream")
|
||||
response.headers["Cache-Control"] = "no-store"
|
||||
response.headers["X-Accel-Buffering"] = "no"
|
||||
response.headers["Connection"] = "keep-alive"
|
||||
logger.info(
|
||||
"message.event.connect message_id=%s user_id=%s last_event_id=%s",
|
||||
message_id,
|
||||
user_id,
|
||||
last_event_id if last_event_id is not None else "-",
|
||||
)
|
||||
return response
|
||||
@@ -109,11 +109,14 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
model_user_id=processor.model_user_id,
|
||||
_continuation={
|
||||
"messages": messages,
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
"reserved_message_id": processor.reserved_message_id,
|
||||
"request_id": processor.request_id,
|
||||
},
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
@@ -145,6 +148,7 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
model_user_id=processor.model_user_id,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
@@ -49,6 +49,7 @@ class CompressionOrchestrator:
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
current_query_tokens: int = 500,
|
||||
model_user_id: Optional[str] = None,
|
||||
) -> CompressionResult:
|
||||
"""
|
||||
Check if compression is needed and perform it if so.
|
||||
@@ -57,16 +58,18 @@ class CompressionOrchestrator:
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
user_id: User ID
|
||||
user_id: Caller's user id — used for conversation access checks
|
||||
model_id: Model being used for conversation
|
||||
decoded_token: User's decoded JWT token
|
||||
current_query_tokens: Estimated tokens for current query
|
||||
model_user_id: BYOM-resolution scope (model owner); defaults
|
||||
to ``user_id`` for built-in / caller-owned models.
|
||||
|
||||
Returns:
|
||||
CompressionResult with summary and recent queries
|
||||
"""
|
||||
try:
|
||||
# Load conversation
|
||||
# Conversation row is owned by the caller, not the model owner.
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
conversation_id, user_id
|
||||
)
|
||||
@@ -77,9 +80,14 @@ class CompressionOrchestrator:
|
||||
)
|
||||
return CompressionResult.failure("Conversation not found")
|
||||
|
||||
# Check if compression is needed
|
||||
# Use model-owner scope so per-user BYOM context windows
|
||||
# (e.g. 8k) compute the threshold against the right limit.
|
||||
registry_user_id = model_user_id or user_id
|
||||
if not self.threshold_checker.should_compress(
|
||||
conversation, model_id, current_query_tokens
|
||||
conversation,
|
||||
model_id,
|
||||
current_query_tokens,
|
||||
user_id=registry_user_id,
|
||||
):
|
||||
# No compression needed, return full history
|
||||
queries = conversation.get("queries", [])
|
||||
@@ -87,7 +95,12 @@ class CompressionOrchestrator:
|
||||
|
||||
# Perform compression
|
||||
return self._perform_compression(
|
||||
conversation_id, conversation, model_id, decoded_token
|
||||
conversation_id,
|
||||
conversation,
|
||||
model_id,
|
||||
decoded_token,
|
||||
user_id=user_id,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -102,6 +115,8 @@ class CompressionOrchestrator:
|
||||
conversation: Dict[str, Any],
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
user_id: Optional[str] = None,
|
||||
model_user_id: Optional[str] = None,
|
||||
) -> CompressionResult:
|
||||
"""
|
||||
Perform the actual compression operation.
|
||||
@@ -111,6 +126,8 @@ class CompressionOrchestrator:
|
||||
conversation: Conversation document
|
||||
model_id: Model ID for conversation
|
||||
decoded_token: User token
|
||||
user_id: Caller's id (for conversation reload after compression)
|
||||
model_user_id: BYOM-resolution scope (model owner)
|
||||
|
||||
Returns:
|
||||
CompressionResult
|
||||
@@ -123,11 +140,17 @@ class CompressionOrchestrator:
|
||||
else model_id
|
||||
)
|
||||
|
||||
# Get provider and API key for compression model
|
||||
provider = get_provider_from_model_id(compression_model)
|
||||
# Use model-owner scope so provider/api_key resolves to the
|
||||
# owner's BYOM record (shared-agent dispatch).
|
||||
caller_user_id = user_id
|
||||
if caller_user_id is None and isinstance(decoded_token, dict):
|
||||
caller_user_id = decoded_token.get("sub")
|
||||
registry_user_id = model_user_id or caller_user_id
|
||||
provider = get_provider_from_model_id(
|
||||
compression_model, user_id=registry_user_id
|
||||
)
|
||||
api_key = get_api_key_for_provider(provider)
|
||||
|
||||
# Create compression LLM
|
||||
compression_llm = LLMCreator.create_llm(
|
||||
provider,
|
||||
api_key=api_key,
|
||||
@@ -135,7 +158,11 @@ class CompressionOrchestrator:
|
||||
decoded_token=decoded_token,
|
||||
model_id=compression_model,
|
||||
agent_id=conversation.get("agent_id"),
|
||||
model_user_id=registry_user_id,
|
||||
)
|
||||
# Side-channel LLM tag — distinguishes compression rows
|
||||
# from primary stream rows for cost-attribution dashboards.
|
||||
compression_llm._token_usage_source = "compression"
|
||||
|
||||
# Create compression service with DB update capability
|
||||
compression_service = CompressionService(
|
||||
@@ -167,9 +194,12 @@ class CompressionOrchestrator:
|
||||
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
|
||||
)
|
||||
|
||||
# Reload conversation with updated metadata
|
||||
# Reload under caller (conversation is owned by caller).
|
||||
reload_user_id = caller_user_id
|
||||
if reload_user_id is None and isinstance(decoded_token, dict):
|
||||
reload_user_id = decoded_token.get("sub")
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
conversation_id, user_id=decoded_token.get("sub")
|
||||
conversation_id, user_id=reload_user_id
|
||||
)
|
||||
|
||||
# Get compressed context
|
||||
@@ -192,16 +222,21 @@ class CompressionOrchestrator:
|
||||
model_id: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
current_conversation: Optional[Dict[str, Any]] = None,
|
||||
model_user_id: Optional[str] = None,
|
||||
) -> CompressionResult:
|
||||
"""
|
||||
Perform compression during tool execution.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
user_id: User ID
|
||||
user_id: Caller's user id — used for conversation access checks
|
||||
model_id: Model ID
|
||||
decoded_token: User token
|
||||
current_conversation: Pre-loaded conversation (optional)
|
||||
model_user_id: BYOM-resolution scope (model owner). For
|
||||
shared-agent dispatch this is the agent owner; defaults
|
||||
to ``user_id`` so built-in / caller-owned models are
|
||||
unaffected.
|
||||
|
||||
Returns:
|
||||
CompressionResult
|
||||
@@ -223,7 +258,12 @@ class CompressionOrchestrator:
|
||||
|
||||
# Perform compression
|
||||
return self._perform_compression(
|
||||
conversation_id, conversation, model_id, decoded_token
|
||||
conversation_id,
|
||||
conversation,
|
||||
model_id,
|
||||
decoded_token,
|
||||
user_id=user_id,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -106,8 +106,13 @@ class CompressionService:
|
||||
f"using model {self.model_id}"
|
||||
)
|
||||
|
||||
# See note in conversation_service.py: ``self.model_id`` is
|
||||
# the registry id (UUID for BYOM); the LLM's own model_id is
|
||||
# what the provider's API actually expects.
|
||||
response = self.llm.gen(
|
||||
model=self.model_id, messages=messages, max_tokens=4000
|
||||
model=getattr(self.llm, "model_id", None) or self.model_id,
|
||||
messages=messages,
|
||||
max_tokens=4000,
|
||||
)
|
||||
|
||||
# Extract summary from response
|
||||
|
||||
@@ -30,6 +30,7 @@ class CompressionThresholdChecker:
|
||||
conversation: Dict[str, Any],
|
||||
model_id: str,
|
||||
current_query_tokens: int = 500,
|
||||
user_id: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if compression is needed.
|
||||
@@ -38,6 +39,8 @@ class CompressionThresholdChecker:
|
||||
conversation: Full conversation document
|
||||
model_id: Target model for this request
|
||||
current_query_tokens: Estimated tokens for current query
|
||||
user_id: Owner — needed so per-user BYOM custom-model UUIDs
|
||||
resolve when looking up the context window.
|
||||
|
||||
Returns:
|
||||
True if tokens >= threshold% of context window
|
||||
@@ -48,7 +51,7 @@ class CompressionThresholdChecker:
|
||||
total_tokens += current_query_tokens
|
||||
|
||||
# Get context window limit for model
|
||||
context_limit = get_token_limit(model_id)
|
||||
context_limit = get_token_limit(model_id, user_id=user_id)
|
||||
|
||||
# Calculate threshold
|
||||
threshold = int(context_limit * self.threshold_percentage)
|
||||
@@ -73,20 +76,24 @@ class CompressionThresholdChecker:
|
||||
logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
|
||||
return False
|
||||
|
||||
def check_message_tokens(self, messages: list, model_id: str) -> bool:
|
||||
def check_message_tokens(
|
||||
self, messages: list, model_id: str, user_id: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if message list exceeds threshold.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts
|
||||
model_id: Target model
|
||||
user_id: Owner — needed so per-user BYOM custom-model UUIDs
|
||||
resolve when looking up the context window.
|
||||
|
||||
Returns:
|
||||
True if at or above threshold
|
||||
"""
|
||||
try:
|
||||
current_tokens = TokenCounter.count_message_tokens(messages)
|
||||
context_limit = get_token_limit(model_id)
|
||||
context_limit = get_token_limit(model_id, user_id=user_id)
|
||||
threshold = int(context_limit * self.threshold_percentage)
|
||||
|
||||
if current_tokens >= threshold:
|
||||
|
||||
@@ -12,6 +12,12 @@ logger = logging.getLogger(__name__)
|
||||
class TokenCounter:
|
||||
"""Centralized token counting for conversations and messages."""
|
||||
|
||||
# Per-image token estimate. Provider tokenizers vary widely
|
||||
# (Gemini ~258, GPT-4o 85-1500, Claude ~1500) and the actual cost
|
||||
# depends on resolution/detail we can't see here. Errs slightly high
|
||||
# so the threshold check stays conservative.
|
||||
_IMAGE_PART_TOKEN_ESTIMATE = 1500
|
||||
|
||||
@staticmethod
|
||||
def count_message_tokens(messages: List[Dict]) -> int:
|
||||
"""
|
||||
@@ -29,12 +35,36 @@ class TokenCounter:
|
||||
if isinstance(content, str):
|
||||
total_tokens += num_tokens_from_string(content)
|
||||
elif isinstance(content, list):
|
||||
# Handle structured content (tool calls, etc.)
|
||||
# Handle structured content (tool calls, image parts, etc.)
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
total_tokens += num_tokens_from_string(str(item))
|
||||
total_tokens += TokenCounter._count_content_part(item)
|
||||
return total_tokens
|
||||
|
||||
@staticmethod
|
||||
def _count_content_part(item: Dict) -> int:
|
||||
# Image/file attachments are billed by the provider per image,
|
||||
# not proportional to the inline bytes/base64 string.
|
||||
# ``str(item)`` on a 1MB image inflates the count by ~10000x,
|
||||
# which trips spurious compression and overflows downstream
|
||||
# input limits.
|
||||
item_type = item.get("type")
|
||||
|
||||
if "files" in item:
|
||||
files = item.get("files")
|
||||
count = len(files) if isinstance(files, list) and files else 1
|
||||
return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE * count
|
||||
|
||||
if "image_url" in item or item_type in {
|
||||
"image",
|
||||
"image_url",
|
||||
"input_image",
|
||||
"file",
|
||||
}:
|
||||
return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE
|
||||
|
||||
return num_tokens_from_string(str(item))
|
||||
|
||||
@staticmethod
|
||||
def count_query_tokens(
|
||||
queries: List[Dict[str, Any]], include_tool_calls: bool = True
|
||||
|
||||
@@ -7,13 +7,13 @@ resume later by sending tool_actions.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.pending_tool_state import (
|
||||
PendingToolStateRepository,
|
||||
)
|
||||
from application.storage.db.serialization import coerce_pg_native as _make_serializable
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -21,23 +21,9 @@ logger = logging.getLogger(__name__)
|
||||
# TTL for pending states — auto-cleaned after this period
|
||||
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _make_serializable(obj: Any) -> Any:
|
||||
"""Recursively coerce non-JSON values into JSON-safe forms.
|
||||
|
||||
Handles ``uuid.UUID`` (from PG columns), ``bytes``, and recurses into
|
||||
dicts/lists. Post-Mongo-cutover the ObjectId branch is gone — none of
|
||||
our writers produce them anymore.
|
||||
"""
|
||||
if isinstance(obj, UUID):
|
||||
return str(obj)
|
||||
if isinstance(obj, dict):
|
||||
return {str(k): _make_serializable(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_make_serializable(v) for v in obj]
|
||||
if isinstance(obj, bytes):
|
||||
return obj.decode("utf-8", errors="replace")
|
||||
return obj
|
||||
# Re-export so the existing tests at tests/api/answer/services/test_continuation_service_pg.py
|
||||
# can keep importing ``_make_serializable`` from here.
|
||||
__all__ = ["_make_serializable", "ContinuationService", "PENDING_STATE_TTL_SECONDS"]
|
||||
|
||||
|
||||
class ContinuationService:
|
||||
@@ -155,3 +141,23 @@ class ContinuationService:
|
||||
f"Deleted continuation state for conversation {conversation_id}"
|
||||
)
|
||||
return deleted
|
||||
|
||||
def mark_resuming(self, conversation_id: str, user: str) -> bool:
|
||||
"""Flip the pending row to ``resuming`` so a crashed resume can be retried."""
|
||||
with db_session() as conn:
|
||||
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||
if conv is not None:
|
||||
pg_conv_id = conv["id"]
|
||||
elif looks_like_uuid(conversation_id):
|
||||
pg_conv_id = conversation_id
|
||||
else:
|
||||
return False
|
||||
flipped = PendingToolStateRepository(conn).mark_resuming(
|
||||
pg_conv_id, user
|
||||
)
|
||||
if flipped:
|
||||
logger.info(
|
||||
f"Marked continuation state as resuming for conversation "
|
||||
f"{conversation_id}"
|
||||
)
|
||||
return flipped
|
||||
|
||||
@@ -6,6 +6,7 @@ than held for the duration of a stream.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -14,13 +15,22 @@ from sqlalchemy import text as sql_text
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.conversations import (
|
||||
ConversationsRepository,
|
||||
MessageUpdateOutcome,
|
||||
)
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Shown to the user if the worker dies mid-stream and the response is never finalised.
|
||||
TERMINATED_RESPONSE_PLACEHOLDER = (
|
||||
"Response was terminated prior to completion, try regenerating."
|
||||
)
|
||||
|
||||
|
||||
class ConversationService:
|
||||
def get_conversation(
|
||||
self, conversation_id: str, user_id: str
|
||||
@@ -136,8 +146,14 @@ class ConversationService:
|
||||
},
|
||||
]
|
||||
|
||||
# ``model_id`` here is the registry id (a UUID for BYOM
|
||||
# records). The LLM's own ``model_id`` is the upstream name
|
||||
# LLMCreator resolved at construction time — that's what
|
||||
# the provider's API expects. Built-ins are unaffected.
|
||||
completion = llm.gen(
|
||||
model=model_id, messages=messages_summary, max_tokens=500
|
||||
model=getattr(llm, "model_id", None) or model_id,
|
||||
messages=messages_summary,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
if not completion or not completion.strip():
|
||||
@@ -173,6 +189,243 @@ class ConversationService:
|
||||
repo.append_message(conv_pg_id, append_payload)
|
||||
return conv_pg_id
|
||||
|
||||
def save_user_question(
|
||||
self,
|
||||
conversation_id: Optional[str],
|
||||
question: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
*,
|
||||
attachment_ids: Optional[List[str]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
model_id: Optional[str] = None,
|
||||
request_id: Optional[str] = None,
|
||||
status: str = "pending",
|
||||
index: Optional[int] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""Reserve the placeholder message row before the LLM call.
|
||||
|
||||
``index`` triggers regenerate semantics: messages at
|
||||
``position >= index`` are truncated so the new placeholder
|
||||
lands at ``position = index`` rather than appending.
|
||||
|
||||
Returns ``{"conversation_id", "message_id", "request_id"}``.
|
||||
"""
|
||||
if decoded_token is None:
|
||||
raise ValueError("Invalid or missing authentication token")
|
||||
user_id = decoded_token.get("sub")
|
||||
if not user_id:
|
||||
raise ValueError("User ID not found in token")
|
||||
|
||||
request_id = request_id or str(uuid.uuid4())
|
||||
|
||||
resolved_api_key: Optional[str] = None
|
||||
resolved_agent_id: Optional[str] = None
|
||||
if api_key and not conversation_id:
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
if agent:
|
||||
resolved_api_key = agent.get("key")
|
||||
if agent_id:
|
||||
resolved_agent_id = agent_id
|
||||
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
if conversation_id:
|
||||
conv = repo.get_any(conversation_id, user_id)
|
||||
if conv is None:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
conv_pg_id = str(conv["id"])
|
||||
# Regenerate / edit-prior-question: drop the message at
|
||||
# ``index`` and everything after it so the new
|
||||
# ``reserve_message`` lands at ``position=index`` rather
|
||||
# than appending at the end of the conversation.
|
||||
if isinstance(index, int) and index >= 0:
|
||||
repo.truncate_after(conv_pg_id, keep_up_to=index - 1)
|
||||
else:
|
||||
fallback_name = (question[:50] if question else "New Conversation")
|
||||
conv = repo.create(
|
||||
user_id,
|
||||
fallback_name,
|
||||
agent_id=resolved_agent_id,
|
||||
api_key=resolved_api_key,
|
||||
is_shared_usage=bool(resolved_agent_id and is_shared_usage),
|
||||
shared_token=(
|
||||
shared_token
|
||||
if (resolved_agent_id and is_shared_usage)
|
||||
else None
|
||||
),
|
||||
)
|
||||
conv_pg_id = str(conv["id"])
|
||||
|
||||
row = repo.reserve_message(
|
||||
conv_pg_id,
|
||||
prompt=question,
|
||||
placeholder_response=TERMINATED_RESPONSE_PLACEHOLDER,
|
||||
request_id=request_id,
|
||||
status=status,
|
||||
attachments=attachment_ids,
|
||||
model_id=model_id,
|
||||
)
|
||||
message_id = str(row["id"])
|
||||
|
||||
return {
|
||||
"conversation_id": conv_pg_id,
|
||||
"message_id": message_id,
|
||||
"request_id": request_id,
|
||||
}
|
||||
|
||||
def update_message_status(self, message_id: str, status: str) -> bool:
|
||||
"""Cheap status-only transition (e.g. ``pending → streaming``)."""
|
||||
if not message_id:
|
||||
return False
|
||||
with db_session() as conn:
|
||||
return ConversationsRepository(conn).update_message_status(
|
||||
message_id, status,
|
||||
)
|
||||
|
||||
def heartbeat_message(self, message_id: str) -> bool:
|
||||
"""Bump ``message_metadata.last_heartbeat_at`` so the reconciler's
|
||||
staleness sweep counts the row as alive. No-ops on terminal rows.
|
||||
"""
|
||||
if not message_id:
|
||||
return False
|
||||
with db_session() as conn:
|
||||
return ConversationsRepository(conn).heartbeat_message(message_id)
|
||||
|
||||
def finalize_message(
|
||||
self,
|
||||
message_id: str,
|
||||
response: str,
|
||||
*,
|
||||
thought: str = "",
|
||||
sources: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None,
|
||||
model_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
status: str = "complete",
|
||||
error: Optional[BaseException] = None,
|
||||
title_inputs: Optional[Dict[str, Any]] = None,
|
||||
) -> MessageUpdateOutcome:
|
||||
"""Commit the response and tool_call confirms in one transaction.
|
||||
|
||||
The outcome propagates directly from ``update_message_by_id`` so
|
||||
callers (notably the SSE abort handler) can tell a fresh
|
||||
finalize from "the row was already terminal" — the latter must
|
||||
still be treated as success when the prior state was
|
||||
``complete``.
|
||||
"""
|
||||
if not message_id:
|
||||
return MessageUpdateOutcome.INVALID
|
||||
sources = sources or []
|
||||
for source in sources:
|
||||
if "text" in source and isinstance(source["text"], str):
|
||||
source["text"] = source["text"][:1000]
|
||||
|
||||
merged_metadata: Dict[str, Any] = dict(metadata or {})
|
||||
if status == "failed" and error is not None:
|
||||
merged_metadata.setdefault(
|
||||
"error", f"{type(error).__name__}: {str(error)}"
|
||||
)
|
||||
|
||||
update_fields: Dict[str, Any] = {
|
||||
"response": response,
|
||||
"status": status,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls or [],
|
||||
"metadata": merged_metadata,
|
||||
}
|
||||
if model_id is not None:
|
||||
update_fields["model_id"] = model_id
|
||||
|
||||
# Atomic message update + tool_call_attempts confirm; the
|
||||
# ``only_if_non_terminal`` guard prevents a late stream from
|
||||
# retracting a row the reconciler already escalated.
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
outcome = repo.update_message_by_id(
|
||||
message_id, update_fields,
|
||||
only_if_non_terminal=True,
|
||||
)
|
||||
if outcome is not MessageUpdateOutcome.UPDATED:
|
||||
logger.warning(
|
||||
f"finalize_message: no row updated for message_id={message_id} "
|
||||
f"(outcome={outcome.value} — possibly already terminal)"
|
||||
)
|
||||
return outcome
|
||||
repo.confirm_executed_tool_calls(message_id)
|
||||
|
||||
# Outside the txn — title-gen is a multi-second LLM round trip.
|
||||
if title_inputs and status == "complete":
|
||||
try:
|
||||
with db_session() as conn:
|
||||
self._maybe_generate_title(conn, message_id, title_inputs)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"finalize_message title generation failed: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return MessageUpdateOutcome.UPDATED
|
||||
|
||||
def _maybe_generate_title(
|
||||
self,
|
||||
conn,
|
||||
message_id: str,
|
||||
title_inputs: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Generate an LLM-summarised conversation name if one isn't set yet."""
|
||||
llm = title_inputs.get("llm")
|
||||
question = title_inputs.get("question") or ""
|
||||
response = title_inputs.get("response") or ""
|
||||
fallback_name = title_inputs.get("fallback_name") or question[:50]
|
||||
if llm is None:
|
||||
return
|
||||
|
||||
row = conn.execute(
|
||||
sql_text(
|
||||
"SELECT c.id, c.name FROM conversation_messages m "
|
||||
"JOIN conversations c ON c.id = m.conversation_id "
|
||||
"WHERE m.id = CAST(:mid AS uuid)"
|
||||
),
|
||||
{"mid": message_id},
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return
|
||||
conv_id, current_name = str(row[0]), row[1]
|
||||
if current_name and current_name != fallback_name:
|
||||
return
|
||||
|
||||
messages_summary = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that creates concise conversation titles. "
|
||||
"Summarize conversations in 3 words or less using the same language as the user.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Summarise following conversation in no more than 3 words, "
|
||||
"respond ONLY with the summary, use the same language as the "
|
||||
"user query \n\nUser: " + question + "\n\n" + "AI: " + response,
|
||||
},
|
||||
]
|
||||
completion = llm.gen(
|
||||
model=getattr(llm, "model_id", None) or title_inputs.get("model_id"),
|
||||
messages=messages_summary,
|
||||
max_tokens=500,
|
||||
)
|
||||
if not completion or not completion.strip():
|
||||
completion = fallback_name or "New Conversation"
|
||||
conn.execute(
|
||||
sql_text(
|
||||
"UPDATE conversations SET name = :name, updated_at = now() "
|
||||
"WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
{"id": conv_id, "name": completion.strip()},
|
||||
)
|
||||
|
||||
def update_compression_metadata(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
|
||||
@@ -121,6 +121,12 @@ class StreamProcessor:
|
||||
self.agent_id = self.data.get("agent_id")
|
||||
self.agent_key = None
|
||||
self.model_id: Optional[str] = None
|
||||
# BYOM-resolution scope, set by _validate_and_set_model.
|
||||
self.model_user_id: Optional[str] = None
|
||||
# WAL placeholder id pulled from continuation state on resume.
|
||||
self.reserved_message_id: Optional[str] = None
|
||||
# Carried through resumes so multi-pause runs keep one request_id.
|
||||
self.request_id: Optional[str] = None
|
||||
self.conversation_service = ConversationService()
|
||||
self.compression_orchestrator = CompressionOrchestrator(
|
||||
self.conversation_service
|
||||
@@ -191,16 +197,23 @@ class StreamProcessor:
|
||||
for query in conversation.get("queries", [])
|
||||
]
|
||||
else:
|
||||
# model_user_id keeps history trim aligned with the BYOM's
|
||||
# actual context window instead of the default 128k.
|
||||
self.history = limit_chat_history(
|
||||
json.loads(self.data.get("history", "[]")), model_id=self.model_id
|
||||
json.loads(self.data.get("history", "[]")),
|
||||
model_id=self.model_id,
|
||||
user_id=self.model_user_id,
|
||||
)
|
||||
|
||||
def _handle_compression(self, conversation: Dict[str, Any]):
|
||||
"""Handle conversation compression logic using orchestrator."""
|
||||
try:
|
||||
# initial_user_id for conversation access; model_user_id
|
||||
# for BYOM context-window / provider lookups.
|
||||
result = self.compression_orchestrator.compress_if_needed(
|
||||
conversation_id=self.conversation_id,
|
||||
user_id=self.initial_user_id,
|
||||
model_user_id=self.model_user_id,
|
||||
model_id=self.model_id,
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
@@ -284,11 +297,18 @@ class StreamProcessor:
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
requested_model = self.data.get("model_id")
|
||||
# Caller picks from their own BYOM layer; agent defaults resolve
|
||||
# under the owner's layer (shared agents have caller != owner).
|
||||
caller_user_id = self.initial_user_id
|
||||
owner_user_id = self.agent_config.get("user_id") or caller_user_id
|
||||
|
||||
if requested_model:
|
||||
if not validate_model_id(requested_model):
|
||||
if not validate_model_id(requested_model, user_id=caller_user_id):
|
||||
registry = ModelRegistry.get_instance()
|
||||
available_models = [m.id for m in registry.get_enabled_models()]
|
||||
available_models = [
|
||||
m.id
|
||||
for m in registry.get_enabled_models(user_id=caller_user_id)
|
||||
]
|
||||
raise ValueError(
|
||||
f"Invalid model_id '{requested_model}'. "
|
||||
f"Available models: {', '.join(available_models[:5])}"
|
||||
@@ -299,12 +319,17 @@ class StreamProcessor:
|
||||
)
|
||||
)
|
||||
self.model_id = requested_model
|
||||
self.model_user_id = caller_user_id
|
||||
else:
|
||||
agent_default_model = self.agent_config.get("default_model_id", "")
|
||||
if agent_default_model and validate_model_id(agent_default_model):
|
||||
if agent_default_model and validate_model_id(
|
||||
agent_default_model, user_id=owner_user_id
|
||||
):
|
||||
self.model_id = agent_default_model
|
||||
self.model_user_id = owner_user_id
|
||||
else:
|
||||
self.model_id = get_default_model_id()
|
||||
self.model_user_id = None
|
||||
|
||||
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||||
"""Get API key for agent with access control."""
|
||||
@@ -514,6 +539,10 @@ class StreamProcessor:
|
||||
"allow_system_prompt_override": self._agent_data.get(
|
||||
"allow_system_prompt_override", False
|
||||
),
|
||||
# Owner identity — _validate_and_set_model reads this to
|
||||
# resolve owner-stored BYOM default_model_id against the
|
||||
# owner's per-user model layer rather than the caller's.
|
||||
"user_id": self._agent_data.get("user"),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -561,7 +590,13 @@ class StreamProcessor:
|
||||
|
||||
def _configure_retriever(self):
|
||||
"""Assemble retriever config with precedence: request > agent > default."""
|
||||
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
|
||||
# BYOM scope: owner for shared-agent BYOM, caller for own BYOM,
|
||||
# None for built-ins. Without ``user_id`` here, the doc budget
|
||||
# falls back to settings.DEFAULT_LLM_TOKEN_LIMIT and overfills
|
||||
# the upstream context window for any small (e.g. 8k/32k) BYOM.
|
||||
doc_token_limit = calculate_doc_token_budget(
|
||||
model_id=self.model_id, user_id=self.model_user_id
|
||||
)
|
||||
|
||||
# Start with defaults
|
||||
retriever_name = "classic"
|
||||
@@ -612,6 +647,7 @@ class StreamProcessor:
|
||||
chunks=self.retriever_config["chunks"],
|
||||
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
|
||||
model_id=self.model_id,
|
||||
model_user_id=self.model_user_id,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
agent_id=self.agent_id,
|
||||
decoded_token=self.decoded_token,
|
||||
@@ -896,6 +932,20 @@ class StreamProcessor:
|
||||
if not state:
|
||||
raise ValueError("No pending tool state found for this conversation")
|
||||
|
||||
# Claim the resume up-front. ``mark_resuming`` only flips ``pending``
|
||||
# → ``resuming``; if it returns False, another resume already
|
||||
# claimed this row (status='resuming') — bail before any further
|
||||
# LLM/tool work to avoid double-execution. The cleanup janitor
|
||||
# reverts a stale ``resuming`` claim back to ``pending`` after the
|
||||
# 10-minute grace window so the user can retry.
|
||||
if not cont_service.mark_resuming(
|
||||
conversation_id, self.initial_user_id,
|
||||
):
|
||||
raise ValueError(
|
||||
"Resume already in progress for this conversation; "
|
||||
"retry after the grace window if it stalls."
|
||||
)
|
||||
|
||||
messages = state["messages"]
|
||||
pending_tool_calls = state["pending_tool_calls"]
|
||||
tools_dict = state["tools_dict"]
|
||||
@@ -903,6 +953,11 @@ class StreamProcessor:
|
||||
agent_config = state["agent_config"]
|
||||
|
||||
model_id = agent_config.get("model_id")
|
||||
# BYOM scope captured at initial dispatch. None for built-ins or
|
||||
# caller-owned BYOM where decoded_token['sub'] is already the
|
||||
# right scope; non-None for shared-agent owner BYOM where the
|
||||
# caller's identity differs from the model owner's.
|
||||
model_user_id = agent_config.get("model_user_id")
|
||||
llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
|
||||
api_key = agent_config.get("api_key")
|
||||
user_api_key = agent_config.get("user_api_key")
|
||||
@@ -920,6 +975,7 @@ class StreamProcessor:
|
||||
decoded_token=self.decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
|
||||
tool_executor = ToolExecutor(
|
||||
@@ -949,6 +1005,7 @@ class StreamProcessor:
|
||||
"endpoint": "stream",
|
||||
"llm_name": llm_name,
|
||||
"model_id": model_id,
|
||||
"model_user_id": model_user_id,
|
||||
"api_key": system_api_key,
|
||||
"agent_id": agent_id,
|
||||
"user_api_key": user_api_key,
|
||||
@@ -971,12 +1028,22 @@ class StreamProcessor:
|
||||
|
||||
# Store config for the route layer
|
||||
self.model_id = model_id
|
||||
# Mirror ``model_user_id`` back onto the processor so the route
|
||||
# layer (StreamResource) reads the owner scope captured at
|
||||
# initial dispatch. Without this, ``processor.model_user_id``
|
||||
# stays at the __init__ default (None) and complete_stream
|
||||
# falls back to the caller's sub: the post-resume title-LLM
|
||||
# save misses the owner's BYOM layer, and any second tool
|
||||
# pause persists ``model_user_id=None`` — losing owner scope
|
||||
# for every subsequent resume of this conversation.
|
||||
self.model_user_id = model_user_id
|
||||
self.agent_id = agent_id
|
||||
self.agent_config["user_api_key"] = user_api_key
|
||||
self.conversation_id = conversation_id
|
||||
|
||||
# Delete state so it can't be replayed
|
||||
cont_service.delete_state(conversation_id, self.initial_user_id)
|
||||
# Reused on resume so the same WAL row gets finalised and
|
||||
# request_id stays consistent across token_usage rows.
|
||||
self.reserved_message_id = agent_config.get("reserved_message_id")
|
||||
self.request_id = agent_config.get("request_id")
|
||||
|
||||
return agent, messages, tools_dict, pending_tool_calls, tool_actions
|
||||
|
||||
@@ -1022,8 +1089,11 @@ class StreamProcessor:
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
# Use the user_id that resolved the model so owner-scoped BYOM
|
||||
# records dispatch correctly on shared-agent requests.
|
||||
model_user_id = getattr(self, "model_user_id", self.initial_user_id)
|
||||
provider = (
|
||||
get_provider_from_model_id(self.model_id)
|
||||
get_provider_from_model_id(self.model_id, user_id=model_user_id)
|
||||
if self.model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
@@ -1048,6 +1118,8 @@ class StreamProcessor:
|
||||
model_id=self.model_id,
|
||||
agent_id=self.agent_id,
|
||||
backup_models=backup_models,
|
||||
# Owner-scope on shared-agent BYOM dispatch.
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
llm_handler = LLMHandlerCreator.create_handler(
|
||||
provider if provider else "default"
|
||||
@@ -1070,6 +1142,7 @@ class StreamProcessor:
|
||||
"endpoint": "stream",
|
||||
"llm_name": provider or settings.LLM_PROVIDER,
|
||||
"model_id": self.model_id,
|
||||
"model_user_id": self.model_user_id,
|
||||
"api_key": system_api_key,
|
||||
"agent_id": self.agent_id,
|
||||
"user_api_key": self.agent_config["user_api_key"],
|
||||
@@ -1097,6 +1170,7 @@ class StreamProcessor:
|
||||
"doc_token_limit", 50000
|
||||
),
|
||||
"model_id": self.model_id,
|
||||
"model_user_id": self.model_user_id,
|
||||
"user_api_key": self.agent_config["user_api_key"],
|
||||
"agent_id": self.agent_id,
|
||||
"llm_name": provider or settings.LLM_PROVIDER,
|
||||
|
||||
504
application/api/events/routes.py
Normal file
504
application/api/events/routes.py
Normal file
@@ -0,0 +1,504 @@
|
||||
"""GET /api/events — user-scoped Server-Sent Events endpoint.
|
||||
|
||||
Subscribe-then-snapshot pattern: subscribe to ``user:{user_id}``
|
||||
pub/sub, snapshot the Redis Streams backlog past ``Last-Event-ID``
|
||||
inside the SUBSCRIBE-ack callback, flush snapshot, then tail live
|
||||
events (dedup'd by stream id). See ``docs/runbooks/sse-notifications.md``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.events.keys import (
|
||||
connection_counter_key,
|
||||
replay_budget_key,
|
||||
stream_id_compare,
|
||||
stream_key,
|
||||
topic_name,
|
||||
)
|
||||
from application.streaming.broadcast_channel import Topic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
events = Blueprint("event_stream", __name__)
|
||||
|
||||
SUBSCRIBE_POLL_INTERVAL_SECONDS = 1.0
|
||||
|
||||
# WHATWG SSE treats CRLF, CR, and LF equivalently as line terminators.
|
||||
_SSE_LINE_SPLIT = re.compile(r"\r\n|\r|\n")
|
||||
|
||||
# Redis Streams ids are ``ms`` or ``ms-seq`` where both halves are decimal.
|
||||
# Anything else is a corrupted client cookie / IndexedDB residue and must
|
||||
# not be passed to XRANGE — Redis would reject it and our truncation gate
|
||||
# would silently fail.
|
||||
_STREAM_ID_RE = re.compile(r"^\d+(-\d+)?$")
|
||||
|
||||
# Only emitted at most once per process so a misconfigured deployment
|
||||
# doesn't drown the logs.
|
||||
_local_user_warned = False
|
||||
|
||||
|
||||
def _format_sse(data: str, *, event_id: Optional[str] = None) -> str:
|
||||
"""Encode a payload as one SSE message terminated by a blank line.
|
||||
|
||||
Splits on any line-terminator variant (``\\r\\n``, ``\\r``, ``\\n``)
|
||||
so a stray CR in upstream content can't smuggle a premature line
|
||||
boundary into the wire format.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
if event_id:
|
||||
lines.append(f"id: {event_id}")
|
||||
for line in _SSE_LINE_SPLIT.split(data):
|
||||
lines.append(f"data: {line}")
|
||||
return "\n".join(lines) + "\n\n"
|
||||
|
||||
|
||||
def _decode(value) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, (bytes, bytearray)):
|
||||
try:
|
||||
return value.decode("utf-8")
|
||||
except Exception:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
|
||||
def _oldest_retained_id(redis_client, user_id: str) -> Optional[str]:
|
||||
"""Return the id of the oldest entry still in the stream, or ``None``.
|
||||
|
||||
Used to detect ``Last-Event-ID`` having slid off the back of the
|
||||
MAXLEN'd window.
|
||||
"""
|
||||
try:
|
||||
info = redis_client.xinfo_stream(stream_key(user_id))
|
||||
except Exception:
|
||||
return None
|
||||
if not isinstance(info, dict):
|
||||
return None
|
||||
# redis-py 7.4 returns str-keyed dicts here; the bytes-key probe is
|
||||
# defence in depth in case ``decode_responses`` is ever flipped.
|
||||
first_entry = info.get("first-entry") or info.get(b"first-entry")
|
||||
if not first_entry:
|
||||
return None
|
||||
# XINFO STREAM returns first-entry as [id, [field, value, ...]]
|
||||
try:
|
||||
return _decode(first_entry[0])
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _allow_replay(
|
||||
redis_client, user_id: str, last_event_id: Optional[str]
|
||||
) -> bool:
|
||||
"""Per-user sliding-window snapshot-replay budget.
|
||||
|
||||
Fails open on Redis errors or when the budget is disabled. Empty-backlog
|
||||
no-cursor connects skip INCR so dev double-mounts don't trip 429.
|
||||
"""
|
||||
budget = int(settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW)
|
||||
if budget <= 0:
|
||||
return True
|
||||
if redis_client is None:
|
||||
return True
|
||||
|
||||
# Cheap pre-check: only INCR when we might actually replay. XLEN
|
||||
# is one Redis op; the alternative (INCR every connect) is two
|
||||
# ops AND wrongly counts no-op probes. The check is conservative:
|
||||
# if ``last_event_id`` is set we always INCR, even if the cursor
|
||||
# has already overtaken the latest entry — that case is rare and
|
||||
# short-lived, and probing further would mean a redundant XRANGE.
|
||||
if last_event_id is None:
|
||||
try:
|
||||
if int(redis_client.xlen(stream_key(user_id))) == 0:
|
||||
return True
|
||||
except Exception:
|
||||
# XLEN probe failed; fall through to the INCR path so a
|
||||
# transient Redis hiccup can't bypass the budget.
|
||||
logger.debug(
|
||||
"XLEN probe failed for replay budget check user=%s; "
|
||||
"proceeding to INCR",
|
||||
user_id,
|
||||
)
|
||||
|
||||
window = max(1, int(settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS))
|
||||
key = replay_budget_key(user_id)
|
||||
try:
|
||||
used = int(redis_client.incr(key))
|
||||
# Always (re)seed the TTL. Gating on ``used == 1`` would wedge
|
||||
# the counter forever if INCR succeeds but EXPIRE raises on
|
||||
# the seeding call. EXPIRE on an existing key resets the TTL
|
||||
# to ``window`` — within ±1s of the per-window budget semantic.
|
||||
redis_client.expire(key, window)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"replay budget probe failed for user=%s; failing open",
|
||||
user_id,
|
||||
)
|
||||
return True
|
||||
return used <= budget
|
||||
|
||||
|
||||
def _normalize_last_event_id(raw: Optional[str]) -> Optional[str]:
|
||||
"""Validate the ``Last-Event-ID`` header / query param.
|
||||
|
||||
Returns the value unchanged when it parses as a Redis Streams id,
|
||||
otherwise ``None`` — callers treat ``None`` as "client has nothing"
|
||||
and replay from the start of the retained window. Invalid ids would
|
||||
otherwise pass straight to XRANGE and surface as a quiet replay
|
||||
failure plus broken truncation detection.
|
||||
"""
|
||||
if raw is None:
|
||||
return None
|
||||
raw = raw.strip()
|
||||
if not raw or not _STREAM_ID_RE.match(raw):
|
||||
return None
|
||||
return raw
|
||||
|
||||
|
||||
def _replay_backlog(
|
||||
redis_client, user_id: str, last_event_id: Optional[str], max_count: int
|
||||
) -> Iterator[tuple[str, str]]:
|
||||
"""Yield ``(entry_id, sse_line)`` for backlog entries past ``last_event_id``.
|
||||
|
||||
Capped at ``max_count`` rows; clients catch up across reconnects.
|
||||
Parse failures are skipped; the Streams id is injected into the
|
||||
envelope so replay matches live-tail shape.
|
||||
"""
|
||||
# Exclusive start: '(<id>' skips the already-delivered entry.
|
||||
start = f"({last_event_id}" if last_event_id else "-"
|
||||
try:
|
||||
entries = redis_client.xrange(
|
||||
stream_key(user_id), min=start, max="+", count=max_count
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"xrange replay failed for user=%s last_id=%s err=%s",
|
||||
user_id,
|
||||
last_event_id or "-",
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
for entry_id, fields in entries:
|
||||
entry_id_str = _decode(entry_id)
|
||||
if not entry_id_str:
|
||||
continue
|
||||
# decode_responses=False on the cache client ⇒ field keys/values
|
||||
# are bytes. The string-key fallback covers a future flip of that
|
||||
# default without a forced refactor here.
|
||||
raw_event = None
|
||||
if isinstance(fields, dict):
|
||||
raw_event = fields.get(b"event")
|
||||
if raw_event is None:
|
||||
raw_event = fields.get("event")
|
||||
event_str = _decode(raw_event)
|
||||
if not event_str:
|
||||
continue
|
||||
try:
|
||||
envelope = json.loads(event_str)
|
||||
if isinstance(envelope, dict):
|
||||
envelope["id"] = entry_id_str
|
||||
event_str = json.dumps(envelope)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Replay envelope parse failed for entry %s; passing through raw",
|
||||
entry_id_str,
|
||||
)
|
||||
yield entry_id_str, _format_sse(event_str, event_id=entry_id_str)
|
||||
|
||||
|
||||
def _truncation_notice_line(oldest_id: str) -> str:
|
||||
"""SSE event the frontend can react to with a full-state refetch."""
|
||||
return _format_sse(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "backlog.truncated",
|
||||
"payload": {"oldest_retained_id": oldest_id},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@events.route("/api/events", methods=["GET"])
|
||||
def stream_events() -> Response:
|
||||
decoded = getattr(request, "decoded_token", None)
|
||||
user_id = decoded.get("sub") if isinstance(decoded, dict) else None
|
||||
if not user_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}),
|
||||
401,
|
||||
)
|
||||
|
||||
# In dev deployments without AUTH_TYPE configured, every request
|
||||
# resolves to user_id="local" and shares one stream. Surface this so
|
||||
# an accidentally-multi-user dev box doesn't silently cross-stream.
|
||||
global _local_user_warned
|
||||
if user_id == "local" and not _local_user_warned:
|
||||
logger.warning(
|
||||
"SSE serving user_id='local' (AUTH_TYPE not set). "
|
||||
"All clients on this deployment will share one event stream."
|
||||
)
|
||||
_local_user_warned = True
|
||||
|
||||
raw_last_event_id = request.headers.get("Last-Event-ID") or request.args.get(
|
||||
"last_event_id"
|
||||
)
|
||||
last_event_id = _normalize_last_event_id(raw_last_event_id)
|
||||
last_event_id_invalid = raw_last_event_id is not None and last_event_id is None
|
||||
|
||||
keepalive_seconds = float(settings.SSE_KEEPALIVE_SECONDS)
|
||||
push_enabled = settings.ENABLE_SSE_PUSH
|
||||
cap = int(settings.SSE_MAX_CONCURRENT_PER_USER)
|
||||
|
||||
redis_client = get_redis_instance()
|
||||
counter_key = connection_counter_key(user_id)
|
||||
counted = False
|
||||
|
||||
if push_enabled and redis_client is not None and cap > 0:
|
||||
try:
|
||||
current = int(redis_client.incr(counter_key))
|
||||
counted = True
|
||||
except Exception:
|
||||
current = 0
|
||||
logger.debug(
|
||||
"SSE connection counter INCR failed for user=%s", user_id
|
||||
)
|
||||
if counted:
|
||||
# 1h safety TTL — orphaned counts from hard crashes self-heal.
|
||||
# EXPIRE failure must NOT clobber ``current`` and bypass the cap.
|
||||
try:
|
||||
redis_client.expire(counter_key, 3600)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter EXPIRE failed for user=%s", user_id
|
||||
)
|
||||
if current > cap:
|
||||
try:
|
||||
redis_client.decr(counter_key)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter DECR failed for user=%s",
|
||||
user_id,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Too many concurrent SSE connections",
|
||||
}
|
||||
),
|
||||
429,
|
||||
)
|
||||
|
||||
# Replay budget is checked here, before the generator opens the
|
||||
# stream, so a denial can surface as HTTP 429 instead of a silent
|
||||
# snapshot skip. The earlier in-generator skip lost events between
|
||||
# the client's cursor and the first live-tailed entry: the live
|
||||
# tail still carried ``id:`` headers, the frontend advanced
|
||||
# ``lastEventId`` to one of those ids, and the events in between
|
||||
# were never reachable on the next reconnect. 429 keeps the
|
||||
# cursor pinned and lets the frontend back off until the window
|
||||
# slides (eventStreamClient.ts treats 429 as escalated backoff).
|
||||
if push_enabled and redis_client is not None and not _allow_replay(
|
||||
redis_client, user_id, last_event_id
|
||||
):
|
||||
if counted:
|
||||
try:
|
||||
redis_client.decr(counter_key)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter DECR failed for user=%s",
|
||||
user_id,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Replay budget exhausted",
|
||||
}
|
||||
),
|
||||
429,
|
||||
)
|
||||
|
||||
@stream_with_context
|
||||
def generate() -> Iterator[str]:
|
||||
connect_ts = time.monotonic()
|
||||
replayed_count = 0
|
||||
try:
|
||||
# First frame primes intermediaries (Cloudflare, nginx) so they
|
||||
# don't sit on a buffer waiting for body bytes.
|
||||
yield ": connected\n\n"
|
||||
|
||||
if not push_enabled:
|
||||
yield ": push_disabled\n\n"
|
||||
return
|
||||
|
||||
replay_lines: list[str] = []
|
||||
max_replayed_id: Optional[str] = None
|
||||
replay_done = False
|
||||
|
||||
# If the client sent a malformed Last-Event-ID, surface the
|
||||
# truncation notice synchronously *before* the subscribe
|
||||
# loop. Buffering it into ``replay_lines`` would lose it
|
||||
# when ``Topic.subscribe`` returns immediately (Redis down)
|
||||
# — the loop body never runs, and the flush at line ~335
|
||||
# never fires.
|
||||
if last_event_id_invalid:
|
||||
yield _truncation_notice_line("")
|
||||
replayed_count += 1
|
||||
|
||||
def _on_subscribe_callback() -> None:
|
||||
# Runs synchronously inside Topic.subscribe after the
|
||||
# SUBSCRIBE is acked. By doing XRANGE here, any publisher
|
||||
# firing between SUBSCRIBE-send and SUBSCRIBE-ack has its
|
||||
# XADD captured by XRANGE *and* its PUBLISH buffered at
|
||||
# the connection layer until we read it — closing the
|
||||
# replay/subscribe race the design doc warns about.
|
||||
#
|
||||
# Truncation contract: ``backlog.truncated`` is emitted
|
||||
# ONLY when the client's ``Last-Event-ID`` has slid off
|
||||
# the MAXLEN'd window — that's the case where the
|
||||
# journal is genuinely gone past the cursor and the
|
||||
# frontend should clear its slice cursor and refetch
|
||||
# state. Cap-hit skips the snapshot silently: the
|
||||
# cursor advances via the per-entry ``id:`` headers
|
||||
# and the frontend's slice keeps the latest id so the
|
||||
# next reconnect resumes from there. Budget-exhausted
|
||||
# never reaches this callback — the route 429s before
|
||||
# opening the stream, keeping the cursor pinned.
|
||||
# Conflating these with stale-cursor truncation would
|
||||
# tell the client to clear its cursor and re-receive
|
||||
# the same oldest-N entries on every reconnect —
|
||||
# locking the user out of entries past N.
|
||||
nonlocal max_replayed_id, replay_done
|
||||
try:
|
||||
if redis_client is None:
|
||||
return
|
||||
oldest = _oldest_retained_id(redis_client, user_id)
|
||||
if (
|
||||
last_event_id
|
||||
and oldest
|
||||
and stream_id_compare(last_event_id, oldest) < 0
|
||||
):
|
||||
# The Last-Event-ID has slid off the MAXLEN window.
|
||||
# Tell the client so it can fetch full state.
|
||||
replay_lines.append(_truncation_notice_line(oldest))
|
||||
replay_cap = int(settings.EVENTS_REPLAY_MAX_PER_REQUEST)
|
||||
for entry_id, sse_line in _replay_backlog(
|
||||
redis_client, user_id, last_event_id, replay_cap
|
||||
):
|
||||
replay_lines.append(sse_line)
|
||||
max_replayed_id = entry_id
|
||||
finally:
|
||||
# Always flip the flag — even on partial-replay failure
|
||||
# the outer loop must reach the flush step so we don't
|
||||
# silently strand whatever entries did land.
|
||||
replay_done = True
|
||||
|
||||
topic = Topic(topic_name(user_id))
|
||||
last_keepalive = time.monotonic()
|
||||
for payload in topic.subscribe(
|
||||
on_subscribe=_on_subscribe_callback,
|
||||
poll_timeout=SUBSCRIBE_POLL_INTERVAL_SECONDS,
|
||||
):
|
||||
# Flush snapshot on the first iteration after the SUBSCRIBE
|
||||
# callback ran. This runs at most once per connection.
|
||||
if replay_done and replay_lines:
|
||||
for line in replay_lines:
|
||||
yield line
|
||||
replayed_count += 1
|
||||
replay_lines.clear()
|
||||
|
||||
now = time.monotonic()
|
||||
if payload is None:
|
||||
if now - last_keepalive >= keepalive_seconds:
|
||||
yield ": keepalive\n\n"
|
||||
last_keepalive = now
|
||||
continue
|
||||
|
||||
event_str = _decode(payload) or ""
|
||||
event_id: Optional[str] = None
|
||||
try:
|
||||
envelope = json.loads(event_str)
|
||||
if isinstance(envelope, dict):
|
||||
candidate = envelope.get("id")
|
||||
# Only trust ids that look like real Redis Streams
|
||||
# ids (``ms`` or ``ms-seq``). A malformed or
|
||||
# adversarial publisher could otherwise pin
|
||||
# dedupe forever — a lex-greater bogus id would
|
||||
# make every legitimate later id compare ``<=``
|
||||
# and get dropped silently.
|
||||
if isinstance(candidate, str) and _STREAM_ID_RE.match(
|
||||
candidate
|
||||
):
|
||||
event_id = candidate
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Dedupe: if this id was already covered by replay, drop it.
|
||||
if (
|
||||
event_id is not None
|
||||
and max_replayed_id is not None
|
||||
and stream_id_compare(event_id, max_replayed_id) <= 0
|
||||
):
|
||||
continue
|
||||
|
||||
yield _format_sse(event_str, event_id=event_id)
|
||||
last_keepalive = now
|
||||
|
||||
# Topic.subscribe exited before the first yield (transient
|
||||
# Redis hiccup between SUBSCRIBE-ack and the first poll, or
|
||||
# an immediate Redis-down return). The callback may already
|
||||
# have populated the snapshot — flush it so the client gets
|
||||
# the backlog instead of a silent drop. Safe no-op when the
|
||||
# in-loop flush ran (it clear()'d the buffer) and when the
|
||||
# callback never fired (replay_done stays False).
|
||||
if replay_done and replay_lines:
|
||||
for line in replay_lines:
|
||||
yield line
|
||||
replayed_count += 1
|
||||
replay_lines.clear()
|
||||
except GeneratorExit:
|
||||
return
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"SSE event-stream generator crashed for user=%s", user_id
|
||||
)
|
||||
finally:
|
||||
duration_s = time.monotonic() - connect_ts
|
||||
logger.info(
|
||||
"event.disconnect user=%s duration_s=%.1f replayed=%d",
|
||||
user_id,
|
||||
duration_s,
|
||||
replayed_count,
|
||||
)
|
||||
if counted and redis_client is not None:
|
||||
try:
|
||||
redis_client.decr(counter_key)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter DECR failed for user=%s on disconnect",
|
||||
user_id,
|
||||
)
|
||||
|
||||
response = Response(generate(), mimetype="text/event-stream")
|
||||
response.headers["Cache-Control"] = "no-store"
|
||||
response.headers["X-Accel-Buffering"] = "no"
|
||||
response.headers["Connection"] = "keep-alive"
|
||||
logger.info(
|
||||
"event.connect user=%s last_event_id=%s%s",
|
||||
user_id,
|
||||
last_event_id or "-",
|
||||
" (rejected_invalid)" if last_event_id_invalid else "",
|
||||
)
|
||||
return response
|
||||
@@ -46,7 +46,9 @@ AGENT_TYPE_SCHEMAS = {
|
||||
"prompt_id",
|
||||
],
|
||||
"required_draft": ["name"],
|
||||
"validate_published": ["name", "description", "prompt_id"],
|
||||
# ``prompt_id`` intentionally omitted — the "default" sentinel
|
||||
# is acceptable and maps to NULL downstream.
|
||||
"validate_published": ["name", "description"],
|
||||
"validate_draft": [],
|
||||
"require_source": True,
|
||||
"fields": [
|
||||
@@ -1009,12 +1011,16 @@ class UpdateAgent(Resource):
|
||||
400,
|
||||
)
|
||||
else:
|
||||
# ``prompt_id`` is intentionally omitted: the
|
||||
# frontend's "default" choice maps to NULL here
|
||||
# (see the prompt_id branch above), and NULL
|
||||
# means "use the built-in default prompt" which
|
||||
# is a valid published-agent state.
|
||||
missing_published_fields = []
|
||||
for req_field, field_label in (
|
||||
("name", "Agent name"),
|
||||
("description", "Agent description"),
|
||||
("chunks", "Chunks count"),
|
||||
("prompt_id", "Prompt"),
|
||||
("agent_type", "Agent type"),
|
||||
):
|
||||
final_value = update_fields.get(
|
||||
@@ -1028,8 +1034,23 @@ class UpdateAgent(Resource):
|
||||
extra_final = update_fields.get(
|
||||
"extra_source_ids", existing_agent.get("extra_source_ids") or [],
|
||||
)
|
||||
if not source_final and not extra_final:
|
||||
missing_published_fields.append("Source")
|
||||
# ``retriever`` carries the runtime identity for
|
||||
# agents that publish against the synthetic
|
||||
# "Default" source (frontend's auto-selected
|
||||
# ``{name: "Default", retriever: "classic"}``
|
||||
# entry has no ``id``, so ``source_id`` ends up
|
||||
# NULL even though the user picked something).
|
||||
# Without this fallback the most common new-agent
|
||||
# publish flow gets a 400.
|
||||
retriever_final = update_fields.get(
|
||||
"retriever", existing_agent.get("retriever"),
|
||||
)
|
||||
if (
|
||||
not source_final
|
||||
and not extra_final
|
||||
and not retriever_final
|
||||
):
|
||||
missing_published_fields.append("Source or retriever")
|
||||
if missing_published_fields:
|
||||
return make_response(
|
||||
jsonify(
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
"""Agent management webhook handlers."""
|
||||
|
||||
import secrets
|
||||
import uuid
|
||||
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import require_agent
|
||||
from application.api.user.tasks import process_agent_webhook
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.idempotency import IdempotencyRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
@@ -18,6 +22,37 @@ agents_webhooks_ns = Namespace(
|
||||
)
|
||||
|
||||
|
||||
_IDEMPOTENCY_KEY_MAX_LEN = 256
|
||||
|
||||
|
||||
def _read_idempotency_key():
|
||||
"""Return (key, error_response). Empty header → (None, None); oversized → (None, 400)."""
|
||||
key = request.headers.get("Idempotency-Key")
|
||||
if not key:
|
||||
return None, None
|
||||
if len(key) > _IDEMPOTENCY_KEY_MAX_LEN:
|
||||
return None, make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": (
|
||||
f"Idempotency-Key exceeds maximum length of "
|
||||
f"{_IDEMPOTENCY_KEY_MAX_LEN} characters"
|
||||
),
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
return key, None
|
||||
|
||||
|
||||
def _scoped_idempotency_key(idempotency_key, scope):
|
||||
"""``{scope}:{key}`` so different agents can't collide on the same key."""
|
||||
if not idempotency_key or not scope:
|
||||
return None
|
||||
return f"{scope}:{idempotency_key}"
|
||||
|
||||
|
||||
@agents_webhooks_ns.route("/agent_webhook")
|
||||
class AgentWebhook(Resource):
|
||||
@api.doc(
|
||||
@@ -68,7 +103,7 @@ class AgentWebhook(Resource):
|
||||
class AgentWebhookListener(Resource):
|
||||
method_decorators = [require_agent]
|
||||
|
||||
def _enqueue_webhook_task(self, agent_id_str, payload, source_method):
|
||||
def _enqueue_webhook_task(self, agent_id_str, payload, source_method, agent=None):
|
||||
if not payload:
|
||||
current_app.logger.warning(
|
||||
f"Webhook ({source_method}) received for agent {agent_id_str} with empty payload."
|
||||
@@ -77,26 +112,94 @@ class AgentWebhookListener(Resource):
|
||||
f"Incoming {source_method} webhook for agent {agent_id_str}. Enqueuing task with payload: {payload}"
|
||||
)
|
||||
|
||||
try:
|
||||
task = process_agent_webhook.delay(
|
||||
agent_id=agent_id_str,
|
||||
payload=payload,
|
||||
idempotency_key, key_error = _read_idempotency_key()
|
||||
if key_error is not None:
|
||||
return key_error
|
||||
# Resolve to PG UUID first so dedup writes don't crash on legacy ids.
|
||||
agent_uuid = None
|
||||
if agent is not None:
|
||||
candidate = str(agent.get("id") or "")
|
||||
if looks_like_uuid(candidate):
|
||||
agent_uuid = candidate
|
||||
if idempotency_key and agent_uuid is None:
|
||||
current_app.logger.warning(
|
||||
"Skipping webhook idempotency dedup: agent %s has non-UUID id",
|
||||
agent_id_str,
|
||||
)
|
||||
idempotency_key = None
|
||||
# Agent-scoped (webhooks have no user_id).
|
||||
scoped_key = _scoped_idempotency_key(idempotency_key, agent_uuid)
|
||||
# Claim before enqueue; the loser returns the winner's task_id.
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id = str(uuid.uuid4())
|
||||
with db_session() as conn:
|
||||
claimed = IdempotencyRepository(conn).record_webhook(
|
||||
key=scoped_key,
|
||||
agent_id=agent_uuid,
|
||||
task_id=predetermined_task_id,
|
||||
response_json={
|
||||
"success": True, "task_id": predetermined_task_id,
|
||||
},
|
||||
)
|
||||
if claimed is None:
|
||||
with db_readonly() as conn:
|
||||
cached = IdempotencyRepository(conn).get_webhook(scoped_key)
|
||||
if cached is not None:
|
||||
return make_response(jsonify(cached["response_json"]), 200)
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": "deduplicated"}), 200
|
||||
)
|
||||
|
||||
try:
|
||||
apply_kwargs = dict(
|
||||
kwargs={
|
||||
"agent_id": agent_id_str,
|
||||
"payload": payload,
|
||||
# Scoped so the worker dedup row matches the HTTP claim.
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
},
|
||||
)
|
||||
if predetermined_task_id is not None:
|
||||
apply_kwargs["task_id"] = predetermined_task_id
|
||||
task = process_agent_webhook.apply_async(**apply_kwargs)
|
||||
current_app.logger.info(
|
||||
f"Task {task.id} enqueued for agent {agent_id_str} ({source_method})."
|
||||
)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
response_payload = {"success": True, "task_id": task.id}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error enqueuing webhook task ({source_method}) for agent {agent_id_str}: {err}",
|
||||
exc_info=True,
|
||||
)
|
||||
if scoped_key:
|
||||
# Roll back the claim so a retry can succeed.
|
||||
try:
|
||||
with db_session() as conn:
|
||||
conn.execute(
|
||||
sql_text(
|
||||
"DELETE FROM webhook_dedup "
|
||||
"WHERE idempotency_key = :k"
|
||||
),
|
||||
{"k": scoped_key},
|
||||
)
|
||||
except Exception:
|
||||
current_app.logger.exception(
|
||||
"Failed to release webhook_dedup claim for key=%s",
|
||||
scoped_key,
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error processing webhook"}), 500
|
||||
)
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (POST). Expects JSON payload, which is used to trigger processing.",
|
||||
description=(
|
||||
"Webhook listener for agent events (POST). Expects JSON payload, which "
|
||||
"is used to trigger processing. Honors an optional ``Idempotency-Key`` "
|
||||
"header: a repeat request with the same key within 24h returns the "
|
||||
"original cached response and does not re-enqueue the task."
|
||||
),
|
||||
)
|
||||
def post(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.get_json()
|
||||
@@ -110,11 +213,20 @@ class AgentWebhookListener(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="POST")
|
||||
return self._enqueue_webhook_task(
|
||||
agent_id_str, payload, source_method="POST", agent=agent,
|
||||
)
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (GET). Uses URL query parameters as payload to trigger processing.",
|
||||
description=(
|
||||
"Webhook listener for agent events (GET). Uses URL query parameters as "
|
||||
"payload to trigger processing. Honors an optional ``Idempotency-Key`` "
|
||||
"header: a repeat request with the same key within 24h returns the "
|
||||
"original cached response and does not re-enqueue the task."
|
||||
),
|
||||
)
|
||||
def get(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="GET")
|
||||
return self._enqueue_webhook_task(
|
||||
agent_id_str, payload, source_method="GET", agent=agent,
|
||||
)
|
||||
|
||||
@@ -214,6 +214,10 @@ class StoreAttachment(Resource):
|
||||
{
|
||||
"success": True,
|
||||
"task_id": tasks[0]["task_id"],
|
||||
# Surface the attachment_id so the frontend
|
||||
# can correlate ``attachment.*`` SSE events
|
||||
# to this row and skip the polling fallback.
|
||||
"attachment_id": tasks[0]["attachment_id"],
|
||||
"message": "File uploaded successfully. Processing started.",
|
||||
}
|
||||
),
|
||||
|
||||
@@ -4,8 +4,10 @@ import datetime
|
||||
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.repositories.attachments import AttachmentsRepository
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
@@ -133,6 +135,7 @@ class GetSingleConversation(Resource):
|
||||
attachments_repo = AttachmentsRepository(conn)
|
||||
queries = []
|
||||
for msg in messages:
|
||||
metadata = msg.get("metadata") or {}
|
||||
query = {
|
||||
"prompt": msg.get("prompt"),
|
||||
"response": msg.get("response"),
|
||||
@@ -141,9 +144,15 @@ class GetSingleConversation(Resource):
|
||||
"tool_calls": msg.get("tool_calls") or [],
|
||||
"timestamp": msg.get("timestamp"),
|
||||
"model_id": msg.get("model_id"),
|
||||
# Lets the client distinguish placeholder rows from
|
||||
# finalised answers and tail-poll in-flight ones.
|
||||
"message_id": str(msg["id"]) if msg.get("id") else None,
|
||||
"status": msg.get("status"),
|
||||
"request_id": msg.get("request_id"),
|
||||
"last_heartbeat_at": metadata.get("last_heartbeat_at"),
|
||||
}
|
||||
if msg.get("metadata"):
|
||||
query["metadata"] = msg["metadata"]
|
||||
if metadata:
|
||||
query["metadata"] = metadata
|
||||
# Feedback on conversation_messages is a JSONB blob with
|
||||
# shape {"text": <str>, "timestamp": <iso>}. The legacy
|
||||
# frontend consumed a flat scalar feedback string, so
|
||||
@@ -301,3 +310,61 @@ class SubmitFeedback(Resource):
|
||||
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/messages/<string:message_id>/tail")
|
||||
class GetMessageTail(Resource):
|
||||
@api.doc(
|
||||
description=(
|
||||
"Current state of one conversation_messages row, scoped to the "
|
||||
"authenticated user. Used to reconnect to an in-flight stream "
|
||||
"after a refresh."
|
||||
),
|
||||
params={"message_id": "Message UUID"},
|
||||
)
|
||||
def get(self, message_id):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
if not looks_like_uuid(message_id):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid message id"}), 400
|
||||
)
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
# Owner-or-shared, matching ``ConversationsRepository.get``.
|
||||
row = conn.execute(
|
||||
sql_text(
|
||||
"SELECT m.* FROM conversation_messages m "
|
||||
"JOIN conversations c ON c.id = m.conversation_id "
|
||||
"WHERE m.id = CAST(:mid AS uuid) "
|
||||
"AND (c.user_id = :uid OR :uid = ANY(c.shared_with))"
|
||||
),
|
||||
{"mid": message_id, "uid": user_id},
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
msg = row_to_dict(row)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error tailing message {message_id}: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
metadata = msg.get("message_metadata") or {}
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"message_id": str(msg["id"]),
|
||||
"status": msg.get("status"),
|
||||
"response": msg.get("response"),
|
||||
"thought": msg.get("thought"),
|
||||
"sources": msg.get("sources") or [],
|
||||
"tool_calls": msg.get("tool_calls") or [],
|
||||
"request_id": msg.get("request_id"),
|
||||
"last_heartbeat_at": metadata.get("last_heartbeat_at"),
|
||||
"error": metadata.get("error"),
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
237
application/api/user/idempotency.py
Normal file
237
application/api/user/idempotency.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Per-Celery-task idempotency wrapper backed by ``task_dedup``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from application.storage.db.repositories.idempotency import IdempotencyRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Poison-loop cap; transient-failure headroom without infinite retry.
|
||||
MAX_TASK_ATTEMPTS = 5
|
||||
|
||||
# 30s heartbeat / 60s TTL → ~2 missed ticks of slack before reclaim.
|
||||
LEASE_TTL_SECONDS = 60
|
||||
LEASE_HEARTBEAT_INTERVAL = 30
|
||||
|
||||
# 10 × 60s ≈ 5 min of deferral before giving up on a held lease.
|
||||
LEASE_RETRY_MAX = 10
|
||||
|
||||
|
||||
def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""Short-circuit on completed key; gate concurrent runs via a lease.
|
||||
|
||||
Entry short-circuits:
|
||||
- completed row → return cached result
|
||||
- live lease held → retry(countdown=LEASE_TTL_SECONDS)
|
||||
- attempt_count > MAX_TASK_ATTEMPTS → poison-loop alert
|
||||
Success writes ``completed``; exceptions leave ``pending`` for
|
||||
autoretry until the poison-loop guard trips.
|
||||
"""
|
||||
|
||||
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@functools.wraps(fn)
|
||||
def wrapper(self, *args: Any, idempotency_key: Any = None, **kwargs: Any) -> Any:
|
||||
key = idempotency_key if isinstance(idempotency_key, str) and idempotency_key else None
|
||||
if key is None:
|
||||
return fn(self, *args, idempotency_key=idempotency_key, **kwargs)
|
||||
|
||||
cached = _lookup_completed(key)
|
||||
if cached is not None:
|
||||
logger.info(
|
||||
"idempotency hit for task=%s key=%s — returning cached result",
|
||||
task_name, key,
|
||||
)
|
||||
return cached
|
||||
|
||||
owner_id = str(uuid.uuid4())
|
||||
attempt = _try_claim_lease(
|
||||
key, task_name, _safe_task_id(self), owner_id,
|
||||
)
|
||||
if attempt is None:
|
||||
# Live lease held by another worker. Re-queue and bail
|
||||
# quickly — by the time the retry fires (LEASE_TTL
|
||||
# seconds), Worker 1 has either finalised (we'll hit
|
||||
# ``_lookup_completed`` and return cached) or its lease
|
||||
# has expired and we can claim.
|
||||
logger.info(
|
||||
"idempotency: live lease held; deferring task=%s key=%s",
|
||||
task_name, key,
|
||||
)
|
||||
raise self.retry(
|
||||
countdown=LEASE_TTL_SECONDS,
|
||||
max_retries=LEASE_RETRY_MAX,
|
||||
)
|
||||
|
||||
if attempt > MAX_TASK_ATTEMPTS:
|
||||
logger.error(
|
||||
"idempotency poison-loop guard: task=%s key=%s attempts=%s",
|
||||
task_name, key, attempt,
|
||||
extra={
|
||||
"alert": "idempotency_poison_loop",
|
||||
"task_name": task_name,
|
||||
"idempotency_key": key,
|
||||
"attempts": attempt,
|
||||
},
|
||||
)
|
||||
poisoned = {
|
||||
"success": False,
|
||||
"error": "idempotency poison-loop guard tripped",
|
||||
"attempts": attempt,
|
||||
}
|
||||
_finalize(key, poisoned, status="failed")
|
||||
return poisoned
|
||||
|
||||
heartbeat_thread, heartbeat_stop = _start_lease_heartbeat(
|
||||
key, owner_id,
|
||||
)
|
||||
try:
|
||||
result = fn(self, *args, idempotency_key=idempotency_key, **kwargs)
|
||||
_finalize(key, result, status="completed")
|
||||
return result
|
||||
except Exception:
|
||||
# Drop the lease so the next retry doesn't wait LEASE_TTL.
|
||||
_release_lease(key, owner_id)
|
||||
raise
|
||||
finally:
|
||||
_stop_lease_heartbeat(heartbeat_thread, heartbeat_stop)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _lookup_completed(key: str) -> Any:
|
||||
"""Return cached ``result_json`` if a completed row exists for ``key``, else None."""
|
||||
with db_readonly() as conn:
|
||||
row = IdempotencyRepository(conn).get_task(key)
|
||||
if row is None:
|
||||
return None
|
||||
if row.get("status") != "completed":
|
||||
return None
|
||||
return row.get("result_json")
|
||||
|
||||
|
||||
def _try_claim_lease(
|
||||
key: str, task_name: str, task_id: str, owner_id: str,
|
||||
) -> Optional[int]:
|
||||
"""Atomic CAS; returns ``attempt_count`` or ``None`` when held.
|
||||
|
||||
DB outage → treated as ``attempt=1`` so transient failures don't
|
||||
block all task execution; reconciler repairs the lease columns.
|
||||
"""
|
||||
try:
|
||||
with db_session() as conn:
|
||||
return IdempotencyRepository(conn).try_claim_lease(
|
||||
key=key,
|
||||
task_name=task_name,
|
||||
task_id=task_id,
|
||||
owner_id=owner_id,
|
||||
ttl_seconds=LEASE_TTL_SECONDS,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"idempotency lease-claim failed for key=%s task=%s", key, task_name,
|
||||
)
|
||||
return 1
|
||||
|
||||
|
||||
def _finalize(key: str, result_json: Any, *, status: str) -> None:
|
||||
"""Best-effort terminal write. Never let DB outage fail the task."""
|
||||
try:
|
||||
with db_session() as conn:
|
||||
IdempotencyRepository(conn).finalize_task(
|
||||
key=key, result_json=result_json, status=status,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"idempotency finalize failed for key=%s status=%s", key, status,
|
||||
)
|
||||
|
||||
|
||||
def _release_lease(key: str, owner_id: str) -> None:
|
||||
"""Best-effort lease release on the wrapper's exception path."""
|
||||
try:
|
||||
with db_session() as conn:
|
||||
IdempotencyRepository(conn).release_lease(key, owner_id)
|
||||
except Exception:
|
||||
logger.exception("idempotency release-lease failed for key=%s", key)
|
||||
|
||||
|
||||
def _start_lease_heartbeat(
|
||||
key: str, owner_id: str,
|
||||
) -> tuple[threading.Thread, threading.Event]:
|
||||
"""Spawn a daemon thread that bumps ``lease_expires_at`` every
|
||||
:data:`LEASE_HEARTBEAT_INTERVAL` seconds until ``stop_event`` fires.
|
||||
|
||||
Mirrors ``application.worker._start_ingest_heartbeat`` so the two
|
||||
durability heartbeats share shape and cadence.
|
||||
"""
|
||||
stop_event = threading.Event()
|
||||
thread = threading.Thread(
|
||||
target=_lease_heartbeat_loop,
|
||||
args=(key, owner_id, stop_event, LEASE_HEARTBEAT_INTERVAL),
|
||||
daemon=True,
|
||||
name=f"idempotency-lease-heartbeat:{key[:32]}",
|
||||
)
|
||||
thread.start()
|
||||
return thread, stop_event
|
||||
|
||||
|
||||
def _stop_lease_heartbeat(
|
||||
thread: threading.Thread, stop_event: threading.Event,
|
||||
) -> None:
|
||||
"""Signal the heartbeat thread to exit and join with a short timeout."""
|
||||
stop_event.set()
|
||||
thread.join(timeout=10)
|
||||
|
||||
|
||||
def _lease_heartbeat_loop(
|
||||
key: str,
|
||||
owner_id: str,
|
||||
stop_event: threading.Event,
|
||||
interval: int,
|
||||
) -> None:
|
||||
"""Refresh the lease until ``stop_event`` is set or ownership is lost.
|
||||
|
||||
A failed refresh (rowcount 0) means another worker stole the lease
|
||||
after expiry — at that point the damage is already possible, so we
|
||||
log and keep ticking. Don't escalate to thread death; the main task
|
||||
body needs to keep running so its outcome is at least *recorded*.
|
||||
"""
|
||||
while not stop_event.wait(interval):
|
||||
try:
|
||||
with db_session() as conn:
|
||||
still_owned = IdempotencyRepository(conn).refresh_lease(
|
||||
key=key, owner_id=owner_id, ttl_seconds=LEASE_TTL_SECONDS,
|
||||
)
|
||||
if not still_owned:
|
||||
logger.warning(
|
||||
"idempotency lease lost mid-task for key=%s "
|
||||
"(another worker may have taken over)",
|
||||
key,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"idempotency lease-heartbeat tick failed for key=%s", key,
|
||||
)
|
||||
|
||||
|
||||
def _safe_task_id(task_self: Any) -> str:
|
||||
"""Best-effort extraction of ``self.request.id`` from a Celery task."""
|
||||
try:
|
||||
request = getattr(task_self, "request", None)
|
||||
task_id: Optional[str] = (
|
||||
getattr(request, "id", None) if request is not None else None
|
||||
)
|
||||
except Exception:
|
||||
task_id = None
|
||||
return task_id or "unknown"
|
||||
@@ -1,18 +1,135 @@
|
||||
from flask import current_app, jsonify, make_response
|
||||
"""Model routes.
|
||||
|
||||
- ``GET /api/models`` — list available models for the current user.
|
||||
Combines the built-in catalog with the user's BYOM records.
|
||||
- ``GET/POST/PATCH/DELETE /api/user/models[/<id>]`` — CRUD for the
|
||||
user's own OpenAI-compatible model registrations (BYOM).
|
||||
- ``POST /api/user/models/<id>/test`` — sanity-check the upstream
|
||||
endpoint with a tiny request.
|
||||
|
||||
Every BYOM endpoint is user-scoped at the repository layer
|
||||
(every query filters on ``user_id`` from ``request.decoded_token``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.core.model_settings import ModelRegistry
|
||||
from application.api import api
|
||||
from application.core.model_registry import ModelRegistry
|
||||
from application.security.safe_url import (
|
||||
UnsafeUserUrlError,
|
||||
pinned_post,
|
||||
validate_user_base_url,
|
||||
)
|
||||
from application.storage.db.repositories.user_custom_models import (
|
||||
UserCustomModelsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import check_required_fields
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
models_ns = Namespace("models", description="Available models", path="/api")
|
||||
|
||||
|
||||
_CONTEXT_WINDOW_MIN = 1_000
|
||||
_CONTEXT_WINDOW_MAX = 10_000_000
|
||||
|
||||
|
||||
def _user_id_or_401():
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return None, make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
if not user_id:
|
||||
return None, make_response(jsonify({"success": False}), 401)
|
||||
return user_id, None
|
||||
|
||||
|
||||
def _normalize_capabilities(raw) -> dict:
|
||||
"""Coerce + bound the user-supplied capabilities payload."""
|
||||
raw = raw or {}
|
||||
out = {}
|
||||
if "supports_tools" in raw:
|
||||
out["supports_tools"] = bool(raw["supports_tools"])
|
||||
if "supports_structured_output" in raw:
|
||||
out["supports_structured_output"] = bool(raw["supports_structured_output"])
|
||||
if "supports_streaming" in raw:
|
||||
out["supports_streaming"] = bool(raw["supports_streaming"])
|
||||
if "attachments" in raw:
|
||||
atts = raw["attachments"] or []
|
||||
if not isinstance(atts, list):
|
||||
raise ValueError("'capabilities.attachments' must be a list")
|
||||
coerced = [str(a) for a in atts]
|
||||
# Reject unknown aliases at the API boundary so bad payloads
|
||||
# never reach the registry layer (where lenient expansion just
|
||||
# drops them). Raw MIME types (containing ``/``) pass through
|
||||
# unchanged for parity with the built-in YAML schema.
|
||||
from application.core.model_yaml import builtin_attachment_aliases
|
||||
|
||||
aliases = builtin_attachment_aliases()
|
||||
for entry in coerced:
|
||||
if "/" in entry:
|
||||
continue
|
||||
if entry not in aliases:
|
||||
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
|
||||
raise ValueError(
|
||||
f"unknown attachment alias '{entry}' in "
|
||||
f"'capabilities.attachments'. Valid aliases: {valid}, "
|
||||
f"or use a raw MIME type like 'image/png'."
|
||||
)
|
||||
out["attachments"] = coerced
|
||||
if "context_window" in raw:
|
||||
try:
|
||||
cw = int(raw["context_window"])
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError("'capabilities.context_window' must be an integer")
|
||||
if not (_CONTEXT_WINDOW_MIN <= cw <= _CONTEXT_WINDOW_MAX):
|
||||
raise ValueError(
|
||||
f"'capabilities.context_window' must be between "
|
||||
f"{_CONTEXT_WINDOW_MIN} and {_CONTEXT_WINDOW_MAX}"
|
||||
)
|
||||
out["context_window"] = cw
|
||||
return out
|
||||
|
||||
|
||||
def _row_to_response(row: dict) -> dict:
|
||||
"""Wire-format projection — never includes the API key."""
|
||||
return {
|
||||
"id": str(row["id"]),
|
||||
"upstream_model_id": row["upstream_model_id"],
|
||||
"display_name": row["display_name"],
|
||||
"description": row.get("description") or "",
|
||||
"base_url": row["base_url"],
|
||||
"capabilities": row.get("capabilities") or {},
|
||||
"enabled": bool(row.get("enabled", True)),
|
||||
"source": "user",
|
||||
}
|
||||
|
||||
|
||||
@models_ns.route("/models")
|
||||
class ModelsListResource(Resource):
|
||||
def get(self):
|
||||
"""Get list of available models with their capabilities."""
|
||||
"""Get list of available models with their capabilities.
|
||||
|
||||
When the request is authenticated, the response includes the
|
||||
user's own BYOM registrations alongside the built-in catalog.
|
||||
"""
|
||||
try:
|
||||
user_id = None
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
if decoded_token:
|
||||
user_id = decoded_token.get("sub")
|
||||
|
||||
registry = ModelRegistry.get_instance()
|
||||
models = registry.get_enabled_models()
|
||||
models = registry.get_enabled_models(user_id=user_id)
|
||||
|
||||
response = {
|
||||
"models": [model.to_dict() for model in models],
|
||||
@@ -23,3 +140,382 @@ class ModelsListResource(Resource):
|
||||
current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
return make_response(jsonify(response), 200)
|
||||
|
||||
|
||||
@models_ns.route("/user/models")
|
||||
class UserModelsCollectionResource(Resource):
|
||||
@api.doc(description="List the current user's BYOM custom models")
|
||||
def get(self):
|
||||
user_id, err = _user_id_or_401()
|
||||
if err:
|
||||
return err
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
rows = UserCustomModelsRepository(conn).list_for_user(user_id)
|
||||
return make_response(
|
||||
jsonify({"models": [_row_to_response(r) for r in rows]}), 200
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error listing user custom models: {e}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
|
||||
@api.doc(description="Register a new BYOM custom model")
|
||||
def post(self):
|
||||
user_id, err = _user_id_or_401()
|
||||
if err:
|
||||
return err
|
||||
|
||||
data = request.get_json() or {}
|
||||
missing = check_required_fields(
|
||||
data,
|
||||
["upstream_model_id", "display_name", "base_url", "api_key"],
|
||||
)
|
||||
if missing:
|
||||
return missing
|
||||
|
||||
# SECURITY: reject blank api_key — would leak instance API key
|
||||
# to the user-supplied base_url via LLMCreator fallback.
|
||||
for required_nonblank in (
|
||||
"upstream_model_id",
|
||||
"display_name",
|
||||
"base_url",
|
||||
"api_key",
|
||||
):
|
||||
value = data.get(required_nonblank)
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"'{required_nonblank}' must be a non-empty string",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
# SSRF guard at create time. Re-runs at dispatch time (LLMCreator)
|
||||
# as defense in depth against DNS rebinding and pre-guard rows.
|
||||
try:
|
||||
validate_user_base_url(data["base_url"])
|
||||
except UnsafeUserUrlError as e:
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": str(e)}), 400
|
||||
)
|
||||
|
||||
try:
|
||||
capabilities = _normalize_capabilities(data.get("capabilities"))
|
||||
except ValueError as e:
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": str(e)}), 400
|
||||
)
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
row = UserCustomModelsRepository(conn).create(
|
||||
user_id=user_id,
|
||||
upstream_model_id=data["upstream_model_id"],
|
||||
display_name=data["display_name"],
|
||||
description=data.get("description") or "",
|
||||
base_url=data["base_url"],
|
||||
api_key_plaintext=data["api_key"],
|
||||
capabilities=capabilities,
|
||||
enabled=bool(data.get("enabled", True)),
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error creating user custom model: {e}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
|
||||
ModelRegistry.invalidate_user(user_id)
|
||||
return make_response(jsonify(_row_to_response(row)), 201)
|
||||
|
||||
|
||||
@models_ns.route("/user/models/<string:model_id>")
|
||||
class UserModelResource(Resource):
|
||||
@api.doc(description="Get one BYOM custom model")
|
||||
def get(self, model_id):
|
||||
user_id, err = _user_id_or_401()
|
||||
if err:
|
||||
return err
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
row = UserCustomModelsRepository(conn).get(model_id, user_id)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error fetching user custom model: {e}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
if row is None:
|
||||
return make_response(jsonify({"success": False}), 404)
|
||||
return make_response(jsonify(_row_to_response(row)), 200)
|
||||
|
||||
@api.doc(description="Update a BYOM custom model (partial)")
|
||||
def patch(self, model_id):
|
||||
user_id, err = _user_id_or_401()
|
||||
if err:
|
||||
return err
|
||||
|
||||
data = request.get_json() or {}
|
||||
|
||||
# Reject present-but-blank values for fields where blank doesn't
|
||||
# mean "no change". (The api_key special case — blank means "keep
|
||||
# existing" — is handled below.)
|
||||
for required_nonblank in (
|
||||
"upstream_model_id",
|
||||
"display_name",
|
||||
"base_url",
|
||||
):
|
||||
if required_nonblank in data:
|
||||
value = data[required_nonblank]
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"'{required_nonblank}' cannot be blank",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
if "base_url" in data and data["base_url"]:
|
||||
try:
|
||||
validate_user_base_url(data["base_url"])
|
||||
except UnsafeUserUrlError as e:
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": str(e)}), 400
|
||||
)
|
||||
|
||||
update_fields: dict = {}
|
||||
for k in (
|
||||
"upstream_model_id",
|
||||
"display_name",
|
||||
"description",
|
||||
"base_url",
|
||||
"enabled",
|
||||
):
|
||||
if k in data:
|
||||
update_fields[k] = data[k]
|
||||
|
||||
if "capabilities" in data:
|
||||
try:
|
||||
update_fields["capabilities"] = _normalize_capabilities(
|
||||
data["capabilities"]
|
||||
)
|
||||
except ValueError as e:
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": str(e)}), 400
|
||||
)
|
||||
|
||||
# PATCH semantics: blank/missing api_key → keep the existing
|
||||
# ciphertext; non-empty api_key → re-encrypt and replace.
|
||||
if data.get("api_key"):
|
||||
update_fields["api_key_plaintext"] = data["api_key"]
|
||||
|
||||
if not update_fields:
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": "no updatable fields"}), 400
|
||||
)
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
ok = UserCustomModelsRepository(conn).update(
|
||||
model_id, user_id, update_fields
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error updating user custom model: {e}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
|
||||
if not ok:
|
||||
return make_response(jsonify({"success": False}), 404)
|
||||
|
||||
ModelRegistry.invalidate_user(user_id)
|
||||
with db_readonly() as conn:
|
||||
row = UserCustomModelsRepository(conn).get(model_id, user_id)
|
||||
return make_response(jsonify(_row_to_response(row)), 200)
|
||||
|
||||
@api.doc(description="Delete a BYOM custom model")
|
||||
def delete(self, model_id):
|
||||
user_id, err = _user_id_or_401()
|
||||
if err:
|
||||
return err
|
||||
try:
|
||||
with db_session() as conn:
|
||||
ok = UserCustomModelsRepository(conn).delete(model_id, user_id)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error deleting user custom model: {e}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 500)
|
||||
if not ok:
|
||||
return make_response(jsonify({"success": False}), 404)
|
||||
|
||||
ModelRegistry.invalidate_user(user_id)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
def _run_connection_test(
|
||||
base_url: str, api_key: str, upstream_model_id: str
|
||||
):
|
||||
"""Send a 1-token chat-completion to verify a BYOM endpoint.
|
||||
|
||||
Returns ``(body, http_status)``. Upstream errors return 200 with
|
||||
``ok=False`` so the UI can render inline errors; only local SSRF
|
||||
rejection returns 400.
|
||||
"""
|
||||
url = base_url.rstrip("/") + "/chat/completions"
|
||||
payload = {
|
||||
"model": upstream_model_id,
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"max_tokens": 1,
|
||||
"stream": False,
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
try:
|
||||
# pinned_post closes the DNS-rebinding window. Redirects off
|
||||
# because 3xx could bounce to an internal address (the SSRF
|
||||
# guard only validates the supplied URL).
|
||||
resp = pinned_post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=5,
|
||||
allow_redirects=False,
|
||||
)
|
||||
except UnsafeUserUrlError as e:
|
||||
return {"ok": False, "error": str(e)}, 400
|
||||
except requests.RequestException as e:
|
||||
return {"ok": False, "error": f"connection error: {e}"}, 200
|
||||
|
||||
if 300 <= resp.status_code < 400:
|
||||
return (
|
||||
{
|
||||
"ok": False,
|
||||
"error": (
|
||||
f"upstream returned HTTP {resp.status_code} "
|
||||
"redirect; refusing to follow"
|
||||
),
|
||||
},
|
||||
200,
|
||||
)
|
||||
|
||||
if resp.status_code >= 400:
|
||||
# Cap and only reflect JSON to avoid body-exfil via non-API responses.
|
||||
content_type = (resp.headers.get("Content-Type") or "").lower()
|
||||
if "application/json" in content_type:
|
||||
text = (resp.text or "")[:500]
|
||||
error_msg = f"upstream returned HTTP {resp.status_code}: {text}"
|
||||
else:
|
||||
error_msg = f"upstream returned HTTP {resp.status_code}"
|
||||
return {"ok": False, "error": error_msg}, 200
|
||||
|
||||
return {"ok": True}, 200
|
||||
|
||||
|
||||
@models_ns.route("/user/models/test")
|
||||
class UserModelTestPayloadResource(Resource):
|
||||
@api.doc(
|
||||
description=(
|
||||
"Test an arbitrary BYOM payload (display_name / model id / "
|
||||
"base_url / api_key) without saving. Used by the UI's 'Test "
|
||||
"connection' button so the user can validate before they "
|
||||
"Save. Same SSRF guard, same 1-token request, same 5s "
|
||||
"timeout as the by-id variant."
|
||||
)
|
||||
)
|
||||
def post(self):
|
||||
user_id, err = _user_id_or_401()
|
||||
if err:
|
||||
return err
|
||||
|
||||
data = request.get_json() or {}
|
||||
missing = check_required_fields(
|
||||
data, ["base_url", "api_key", "upstream_model_id"]
|
||||
)
|
||||
if missing:
|
||||
return missing
|
||||
|
||||
body, status = _run_connection_test(
|
||||
data["base_url"], data["api_key"], data["upstream_model_id"]
|
||||
)
|
||||
return make_response(jsonify(body), status)
|
||||
|
||||
|
||||
@models_ns.route("/user/models/<string:model_id>/test")
|
||||
class UserModelTestResource(Resource):
|
||||
@api.doc(
|
||||
description=(
|
||||
"Test a saved BYOM record. Defaults to the stored "
|
||||
"base_url / upstream_model_id / encrypted api_key, but "
|
||||
"any of those can be overridden via the request body so "
|
||||
"the UI can test in-flight edits before saving. Used by "
|
||||
"the 'Test connection' button in edit mode."
|
||||
)
|
||||
)
|
||||
def post(self, model_id):
|
||||
user_id, err = _user_id_or_401()
|
||||
if err:
|
||||
return err
|
||||
|
||||
data = request.get_json() or {}
|
||||
# Per-field overrides; blank/missing falls back to stored value.
|
||||
override_base_url = (data.get("base_url") or "").strip() or None
|
||||
override_upstream_model_id = (
|
||||
data.get("upstream_model_id") or ""
|
||||
).strip() or None
|
||||
override_api_key = (data.get("api_key") or "").strip() or None
|
||||
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
repo = UserCustomModelsRepository(conn)
|
||||
row = repo.get(model_id, user_id)
|
||||
if row is None:
|
||||
return make_response(jsonify({"success": False}), 404)
|
||||
stored_api_key = (
|
||||
repo._decrypt_api_key(
|
||||
row.get("api_key_encrypted", ""), user_id
|
||||
)
|
||||
if not override_api_key
|
||||
else None
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error loading user custom model for test: {e}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"ok": False, "error": "internal error loading model"}),
|
||||
500,
|
||||
)
|
||||
|
||||
api_key = override_api_key or stored_api_key
|
||||
if not api_key:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"ok": False,
|
||||
"error": (
|
||||
"Stored API key could not be decrypted. The "
|
||||
"encryption secret may have rotated. Re-save "
|
||||
"the model with the API key to recover."
|
||||
),
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
base_url = override_base_url or row["base_url"]
|
||||
upstream_model_id = (
|
||||
override_upstream_model_id or row["upstream_model_id"]
|
||||
)
|
||||
body, status = _run_connection_test(
|
||||
base_url, api_key, upstream_model_id
|
||||
)
|
||||
return make_response(jsonify(body), status)
|
||||
|
||||
196
application/api/user/reconciliation.py
Normal file
196
application/api/user/reconciliation.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Reconciler tick: sweep stuck rows and escalate to terminal status + alert."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from sqlalchemy import Connection
|
||||
|
||||
from application.api.user.idempotency import MAX_TASK_ATTEMPTS
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.reconciliation import (
|
||||
ReconciliationRepository,
|
||||
)
|
||||
from application.storage.db.repositories.stack_logs import StackLogsRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MAX_MESSAGE_RECONCILE_ATTEMPTS = 3
|
||||
|
||||
|
||||
def run_reconciliation() -> Dict[str, Any]:
|
||||
"""Single tick of the reconciler. Five sweeps, FOR UPDATE SKIP LOCKED.
|
||||
|
||||
Stuck ``executed`` tool calls always flip to ``failed`` — operators
|
||||
handle cleanup manually via the structured alert. The side effect is
|
||||
assumed to have committed; no automated rollback is attempted.
|
||||
|
||||
Stuck ``task_dedup`` rows (lease expired AND attempts >= max)
|
||||
promote to ``failed`` so a same-key retry can re-claim instead of
|
||||
sitting in ``pending`` until 24 h TTL.
|
||||
"""
|
||||
if not settings.POSTGRES_URI:
|
||||
return {
|
||||
"messages_failed": 0,
|
||||
"tool_calls_failed": 0,
|
||||
"skipped": "POSTGRES_URI not set",
|
||||
}
|
||||
|
||||
engine = get_engine()
|
||||
summary = {
|
||||
"messages_failed": 0,
|
||||
"tool_calls_failed": 0,
|
||||
"ingests_stalled": 0,
|
||||
"idempotency_pending_failed": 0,
|
||||
}
|
||||
|
||||
with engine.begin() as conn:
|
||||
repo = ReconciliationRepository(conn)
|
||||
for msg in repo.find_and_lock_stuck_messages():
|
||||
new_count = repo.increment_message_reconcile_attempts(msg["id"])
|
||||
if new_count >= MAX_MESSAGE_RECONCILE_ATTEMPTS:
|
||||
repo.mark_message_failed(
|
||||
msg["id"],
|
||||
error=(
|
||||
"reconciler: stuck in pending/streaming for >5 min "
|
||||
f"after {new_count} attempts"
|
||||
),
|
||||
)
|
||||
summary["messages_failed"] += 1
|
||||
_emit_alert(
|
||||
conn,
|
||||
name="reconciler_message_failed",
|
||||
user_id=msg.get("user_id"),
|
||||
detail={
|
||||
"message_id": str(msg["id"]),
|
||||
"attempts": new_count,
|
||||
},
|
||||
)
|
||||
|
||||
with engine.begin() as conn:
|
||||
repo = ReconciliationRepository(conn)
|
||||
for row in repo.find_and_lock_proposed_tool_calls():
|
||||
repo.mark_tool_call_failed(
|
||||
row["call_id"],
|
||||
error=(
|
||||
"reconciler: stuck in 'proposed' for >5 min; "
|
||||
"side effect status unknown"
|
||||
),
|
||||
)
|
||||
summary["tool_calls_failed"] += 1
|
||||
_emit_alert(
|
||||
conn,
|
||||
name="reconciler_tool_call_failed_proposed",
|
||||
user_id=None,
|
||||
detail={
|
||||
"call_id": row["call_id"],
|
||||
"tool_name": row.get("tool_name"),
|
||||
},
|
||||
)
|
||||
|
||||
with engine.begin() as conn:
|
||||
repo = ReconciliationRepository(conn)
|
||||
for row in repo.find_and_lock_executed_tool_calls():
|
||||
repo.mark_tool_call_failed(
|
||||
row["call_id"],
|
||||
error=(
|
||||
"reconciler: executed-not-confirmed; side effect "
|
||||
"assumed committed, manual cleanup required"
|
||||
),
|
||||
)
|
||||
summary["tool_calls_failed"] += 1
|
||||
_emit_alert(
|
||||
conn,
|
||||
name="reconciler_tool_call_failed_executed",
|
||||
user_id=None,
|
||||
detail={
|
||||
"call_id": row["call_id"],
|
||||
"tool_name": row.get("tool_name"),
|
||||
"action_name": row.get("action_name"),
|
||||
},
|
||||
)
|
||||
|
||||
# Q4: ingest checkpoints whose heartbeat has gone silent. The
|
||||
# reconciler only escalates (alerts) — it doesn't kill the worker
|
||||
# or roll back the partial embed. The next dispatch resumes from
|
||||
# ``last_index`` thanks to the per-chunk checkpoint, so this is an
|
||||
# observability sweep, not a recovery action.
|
||||
with engine.begin() as conn:
|
||||
repo = ReconciliationRepository(conn)
|
||||
for row in repo.find_and_lock_stalled_ingests():
|
||||
summary["ingests_stalled"] += 1
|
||||
_emit_alert(
|
||||
conn,
|
||||
name="reconciler_ingest_stalled",
|
||||
user_id=None,
|
||||
detail={
|
||||
"source_id": str(row.get("source_id")),
|
||||
"embedded_chunks": row.get("embedded_chunks"),
|
||||
"total_chunks": row.get("total_chunks"),
|
||||
"last_updated": str(row.get("last_updated")),
|
||||
},
|
||||
)
|
||||
# Bump the heartbeat so we don't re-alert every tick.
|
||||
repo.touch_ingest_progress(str(row["source_id"]))
|
||||
|
||||
# Q5: idempotency rows whose lease expired with attempts exhausted.
|
||||
# The wrapper's poison-loop guard normally finalises these, but if
|
||||
# the wrapper itself died mid-task (worker SIGKILL, OOM during
|
||||
# heartbeat) the row sits in ``pending`` blocking same-key retries
|
||||
# via ``_lookup_completed`` returning None for the whole 24 h TTL.
|
||||
# Promote to ``failed`` so a retry can re-claim and either resume
|
||||
# or fail loudly.
|
||||
with engine.begin() as conn:
|
||||
repo = ReconciliationRepository(conn)
|
||||
for row in repo.find_stuck_idempotency_pending(
|
||||
max_attempts=MAX_TASK_ATTEMPTS,
|
||||
):
|
||||
error_msg = (
|
||||
"reconciler: idempotency lease expired with attempts "
|
||||
f"({row['attempt_count']}) >= {MAX_TASK_ATTEMPTS}; "
|
||||
"task abandoned"
|
||||
)
|
||||
repo.mark_idempotency_pending_failed(
|
||||
row["idempotency_key"], error=error_msg,
|
||||
)
|
||||
summary["idempotency_pending_failed"] += 1
|
||||
_emit_alert(
|
||||
conn,
|
||||
name="reconciler_idempotency_pending_failed",
|
||||
user_id=None,
|
||||
detail={
|
||||
"idempotency_key": row["idempotency_key"],
|
||||
"task_name": row.get("task_name"),
|
||||
"task_id": row.get("task_id"),
|
||||
"attempts": row.get("attempt_count"),
|
||||
},
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def _emit_alert(
|
||||
conn: Connection,
|
||||
*,
|
||||
name: str,
|
||||
user_id: Optional[str],
|
||||
detail: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Structured ``logger.error`` plus a ``stack_logs`` row for operators."""
|
||||
extra = {"alert": name, **detail}
|
||||
logger.error("reconciler alert: %s", name, extra=extra)
|
||||
try:
|
||||
StackLogsRepository(conn).insert(
|
||||
activity_id=str(uuid.uuid4()),
|
||||
endpoint="reconciliation_worker",
|
||||
level="ERROR",
|
||||
user_id=user_id,
|
||||
query=name,
|
||||
stacks=[extra],
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("reconciler: failed to write stack_logs row for %s", name)
|
||||
@@ -3,16 +3,20 @@
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
import zipfile
|
||||
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.source_ids import derive_source_id as _derive_source_id
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
|
||||
from application.storage.db.repositories.idempotency import IdempotencyRepository
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
@@ -30,6 +34,91 @@ sources_upload_ns = Namespace(
|
||||
)
|
||||
|
||||
|
||||
_IDEMPOTENCY_KEY_MAX_LEN = 256
|
||||
|
||||
|
||||
def _read_idempotency_key():
|
||||
"""Return (key, error_response). Empty header → (None, None); oversized → (None, 400)."""
|
||||
key = request.headers.get("Idempotency-Key")
|
||||
if not key:
|
||||
return None, None
|
||||
if len(key) > _IDEMPOTENCY_KEY_MAX_LEN:
|
||||
return None, make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": (
|
||||
f"Idempotency-Key exceeds maximum length of "
|
||||
f"{_IDEMPOTENCY_KEY_MAX_LEN} characters"
|
||||
),
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
return key, None
|
||||
|
||||
|
||||
def _scoped_idempotency_key(idempotency_key, scope):
|
||||
"""``{scope}:{key}`` so different users can't collide on the same key."""
|
||||
if not idempotency_key or not scope:
|
||||
return None
|
||||
return f"{scope}:{idempotency_key}"
|
||||
|
||||
|
||||
def _claim_task_or_get_cached(key, task_name):
|
||||
"""Claim ``key`` for this request OR return the winner's cached payload.
|
||||
|
||||
Pre-generates the celery task_id so a losing writer sees the same
|
||||
id immediately. Returns ``(task_id, cached_response)``; non-None
|
||||
cached means the caller should return without enqueuing. The
|
||||
cached payload mirrors the fresh-request response shape (including
|
||||
``source_id``) so the frontend can correlate SSE ingest events to
|
||||
the cached upload task without an extra round-trip — but only when
|
||||
the cached row actually exists; the "deduplicated" sentinel
|
||||
deliberately omits ``source_id`` so the frontend doesn't bind to a
|
||||
phantom source.
|
||||
"""
|
||||
predetermined_id = str(uuid.uuid4())
|
||||
with db_session() as conn:
|
||||
claimed = IdempotencyRepository(conn).claim_task(
|
||||
key=key, task_name=task_name, task_id=predetermined_id,
|
||||
)
|
||||
if claimed is not None:
|
||||
return claimed["task_id"], None
|
||||
with db_readonly() as conn:
|
||||
existing = IdempotencyRepository(conn).get_task(key)
|
||||
cached_id = existing.get("task_id") if existing else None
|
||||
payload: dict = {
|
||||
"success": True,
|
||||
"task_id": cached_id or "deduplicated",
|
||||
}
|
||||
# Only surface ``source_id`` when there's a real winner whose worker
|
||||
# is publishing SSE events tagged with that id. The "deduplicated"
|
||||
# branch means the lock row vanished — we have nothing to correlate.
|
||||
if cached_id is not None:
|
||||
payload["source_id"] = str(_derive_source_id(key))
|
||||
return None, payload
|
||||
|
||||
|
||||
def _release_claim(key):
|
||||
"""Drop a pending claim so a client retry can re-claim it."""
|
||||
try:
|
||||
with db_session() as conn:
|
||||
conn.execute(
|
||||
sql_text(
|
||||
"DELETE FROM task_dedup WHERE idempotency_key = :k "
|
||||
"AND status = 'pending'"
|
||||
),
|
||||
{"k": key},
|
||||
)
|
||||
except Exception:
|
||||
current_app.logger.exception(
|
||||
"Failed to release task_dedup claim for key=%s", key,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def _enforce_audio_path_size_limit(file_path: str, filename: str) -> None:
|
||||
if not is_audio_filename(filename):
|
||||
return
|
||||
@@ -49,17 +138,38 @@ class UploadFile(Resource):
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads a file to be vectorized and indexed",
|
||||
description=(
|
||||
"Uploads a file to be vectorized and indexed. Honors an optional "
|
||||
"``Idempotency-Key`` header: a repeat request with the same key "
|
||||
"within 24h returns the original cached response without re-enqueuing."
|
||||
),
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
idempotency_key, key_error = _read_idempotency_key()
|
||||
if key_error is not None:
|
||||
return key_error
|
||||
# User-scoped to avoid cross-user collisions; also feeds
|
||||
# ``_derive_source_id`` so uuid5 stays user-disjoint.
|
||||
scoped_key = _scoped_idempotency_key(idempotency_key, user)
|
||||
# Claim before enqueue; the loser returns the winner's task_id.
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id, cached = _claim_task_or_get_cached(
|
||||
scoped_key, "ingest",
|
||||
)
|
||||
if cached is not None:
|
||||
return make_response(jsonify(cached), 200)
|
||||
data = request.form
|
||||
files = request.files.getlist("file")
|
||||
required_fields = ["user", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields or not files or all(file.filename == "" for file in files):
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -69,7 +179,6 @@ class UploadFile(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
user = decoded_token.get("sub")
|
||||
job_name = request.form["name"]
|
||||
|
||||
# Create safe versions for filesystem operations
|
||||
@@ -140,16 +249,37 @@ class UploadFile(Resource):
|
||||
file_path = f"{base_path}/{safe_file}"
|
||||
with open(temp_file_path, "rb") as f:
|
||||
storage.save_file(f, file_path)
|
||||
task = ingest.delay(
|
||||
settings.UPLOAD_FOLDER,
|
||||
list(SUPPORTED_SOURCE_EXTENSIONS),
|
||||
job_name,
|
||||
user,
|
||||
file_path=base_path,
|
||||
filename=dir_name,
|
||||
file_name_map=file_name_map,
|
||||
# Mint the source UUID up here so the HTTP response and the
|
||||
# worker's SSE envelopes share one id. With an idempotency
|
||||
# key we reuse the deterministic uuid5 (retried task lands on
|
||||
# the same source row); without a key we fall back to uuid4.
|
||||
# The worker is told to use this id verbatim — see
|
||||
# ``ingest_worker(source_id=...)``.
|
||||
source_uuid = (
|
||||
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
|
||||
)
|
||||
ingest_kwargs = dict(
|
||||
args=(
|
||||
settings.UPLOAD_FOLDER,
|
||||
list(SUPPORTED_SOURCE_EXTENSIONS),
|
||||
job_name,
|
||||
user,
|
||||
),
|
||||
kwargs={
|
||||
"file_path": base_path,
|
||||
"filename": dir_name,
|
||||
"file_name_map": file_name_map,
|
||||
# Scoped so the worker dedup row matches the HTTP claim.
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
"source_id": str(source_uuid),
|
||||
},
|
||||
)
|
||||
if predetermined_task_id is not None:
|
||||
ingest_kwargs["task_id"] = predetermined_task_id
|
||||
task = ingest.apply_async(**ingest_kwargs)
|
||||
except AudioFileTooLargeError:
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -161,8 +291,21 @@ class UploadFile(Resource):
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
# Predetermined id matches the dedup-claim row; loser GET sees same.
|
||||
response_task_id = predetermined_task_id or task.id
|
||||
# ``source_uuid`` was minted above and passed to the worker as
|
||||
# ``source_id``; the worker uses it verbatim for every SSE event,
|
||||
# so the frontend can correlate inbound ``source.ingest.*`` to
|
||||
# this upload regardless of whether an idempotency key was set.
|
||||
response_payload: dict = {
|
||||
"success": True,
|
||||
"task_id": response_task_id,
|
||||
"source_id": str(source_uuid),
|
||||
}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/remote")
|
||||
@@ -182,17 +325,50 @@ class UploadRemote(Resource):
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads remote source for vectorization",
|
||||
description=(
|
||||
"Uploads remote source for vectorization. Honors an optional "
|
||||
"``Idempotency-Key`` header: a repeat request with the same key "
|
||||
"within 24h returns the original cached response without re-enqueuing."
|
||||
),
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
idempotency_key, key_error = _read_idempotency_key()
|
||||
if key_error is not None:
|
||||
return key_error
|
||||
scoped_key = _scoped_idempotency_key(idempotency_key, user)
|
||||
data = request.form
|
||||
required_fields = ["user", "source", "name", "data"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
task_name_for_dedup = (
|
||||
"ingest_connector_task"
|
||||
if data.get("source") in ConnectorCreator.get_supported_connectors()
|
||||
else "ingest_remote"
|
||||
)
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id, cached = _claim_task_or_get_cached(
|
||||
scoped_key, task_name_for_dedup,
|
||||
)
|
||||
if cached is not None:
|
||||
return make_response(jsonify(cached), 200)
|
||||
# Mint the source UUID up here so the HTTP response and the
|
||||
# worker's SSE envelopes share one id. Same pattern as
|
||||
# ``UploadFile.post``: with an idempotency key we reuse the
|
||||
# deterministic uuid5 (retried task lands on the same source
|
||||
# row); without a key we fall back to uuid4. The worker is told
|
||||
# to use this id verbatim — see ``remote_worker`` and
|
||||
# ``ingest_connector``. Without this the no-key path would mint
|
||||
# a random uuid4 inside the worker that the frontend has no way
|
||||
# to correlate SSE events to.
|
||||
source_uuid = (
|
||||
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
|
||||
)
|
||||
try:
|
||||
config = json.loads(data["data"])
|
||||
source_data = None
|
||||
@@ -208,6 +384,8 @@ class UploadRemote(Resource):
|
||||
elif data["source"] in ConnectorCreator.get_supported_connectors():
|
||||
session_token = config.get("session_token")
|
||||
if not session_token:
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -236,31 +414,62 @@ class UploadRemote(Resource):
|
||||
config["file_ids"] = file_ids
|
||||
config["folder_ids"] = folder_ids
|
||||
|
||||
task = ingest_connector_task.delay(
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
source_type=data["source"],
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=config.get("recursive", False),
|
||||
retriever=config.get("retriever", "classic"),
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": task.id}), 200
|
||||
)
|
||||
task = ingest_remote.delay(
|
||||
source_data=source_data,
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
loader=data["source"],
|
||||
)
|
||||
connector_kwargs = {
|
||||
"kwargs": {
|
||||
"job_name": data["name"],
|
||||
"user": user,
|
||||
"source_type": data["source"],
|
||||
"session_token": session_token,
|
||||
"file_ids": file_ids,
|
||||
"folder_ids": folder_ids,
|
||||
"recursive": config.get("recursive", False),
|
||||
"retriever": config.get("retriever", "classic"),
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
"source_id": str(source_uuid),
|
||||
},
|
||||
}
|
||||
if predetermined_task_id is not None:
|
||||
connector_kwargs["task_id"] = predetermined_task_id
|
||||
task = ingest_connector_task.apply_async(**connector_kwargs)
|
||||
response_task_id = predetermined_task_id or task.id
|
||||
# ``source_uuid`` was minted above and passed to the
|
||||
# worker as ``source_id``; the worker uses it verbatim
|
||||
# for every SSE event, so the frontend can correlate
|
||||
# inbound ``source.ingest.*`` regardless of whether an
|
||||
# idempotency key was set.
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"task_id": response_task_id,
|
||||
"source_id": str(source_uuid),
|
||||
}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
remote_kwargs = {
|
||||
"kwargs": {
|
||||
"source_data": source_data,
|
||||
"job_name": data["name"],
|
||||
"user": user,
|
||||
"loader": data["source"],
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
"source_id": str(source_uuid),
|
||||
},
|
||||
}
|
||||
if predetermined_task_id is not None:
|
||||
remote_kwargs["task_id"] = predetermined_task_id
|
||||
task = ingest_remote.apply_async(**remote_kwargs)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error uploading remote source: {err}", exc_info=True
|
||||
)
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
response_task_id = predetermined_task_id or task.id
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"task_id": response_task_id,
|
||||
"source_id": str(source_uuid),
|
||||
}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/manage_source_files")
|
||||
@@ -305,6 +514,10 @@ class ManageSourceFiles(Resource):
|
||||
jsonify({"success": False, "message": "Unauthorized"}), 401
|
||||
)
|
||||
user = decoded_token.get("sub")
|
||||
idempotency_key, key_error = _read_idempotency_key()
|
||||
if key_error is not None:
|
||||
return key_error
|
||||
scoped_key = _scoped_idempotency_key(idempotency_key, user)
|
||||
source_id = request.form.get("source_id")
|
||||
operation = request.form.get("operation")
|
||||
|
||||
@@ -347,6 +560,12 @@ class ManageSourceFiles(Resource):
|
||||
jsonify({"success": False, "message": "Database error"}), 500
|
||||
)
|
||||
resolved_source_id = str(source["id"])
|
||||
# Flips to True after each branch's ``apply_async`` returns
|
||||
# successfully — at that point the worker owns the predetermined
|
||||
# task_id. The outer ``except`` only releases the claim while
|
||||
# this is False, so a post-``apply_async`` failure (jsonify,
|
||||
# make_response, etc.) doesn't double-enqueue on the next retry.
|
||||
claim_transferred = False
|
||||
try:
|
||||
storage = StorageCreator.get_storage()
|
||||
source_file_path = source.get("file_path", "")
|
||||
@@ -379,6 +598,34 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
# Claim before any storage mutation so a duplicate request
|
||||
# short-circuits without touching the filesystem. Mirrors
|
||||
# the pattern in ``UploadFile.post`` / ``UploadRemote.post``
|
||||
# — without it ``.delay()`` would enqueue twice for two
|
||||
# racing same-key POSTs (the worker decorator only
|
||||
# deduplicates *after* completion).
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id, cached = _claim_task_or_get_cached(
|
||||
scoped_key, "reingest_source_task",
|
||||
)
|
||||
if cached is not None:
|
||||
# Frontend keys reingest polling on
|
||||
# ``reingest_task_id``; the shared cache helper
|
||||
# writes ``task_id``. Alias here so a dedup
|
||||
# response doesn't silently break FileTree's
|
||||
# poller. Override ``source_id`` too — the
|
||||
# helper derives it from the scoped key, which
|
||||
# is correct for upload but wrong for reingest
|
||||
# (the worker publishes events scoped to the
|
||||
# actual source row id).
|
||||
cached_task_id = cached.pop("task_id", None)
|
||||
if cached_task_id is not None:
|
||||
cached["reingest_task_id"] = cached_task_id
|
||||
cached["source_id"] = resolved_source_id
|
||||
return make_response(jsonify(cached), 200)
|
||||
|
||||
added_files = []
|
||||
map_updated = False
|
||||
|
||||
@@ -414,9 +661,15 @@ class ManageSourceFiles(Resource):
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id, user=user
|
||||
task = reingest_source_task.apply_async(
|
||||
kwargs={
|
||||
"source_id": resolved_source_id,
|
||||
"user": user,
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
},
|
||||
task_id=predetermined_task_id,
|
||||
)
|
||||
claim_transferred = True
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -426,6 +679,12 @@ class ManageSourceFiles(Resource):
|
||||
"added_files": added_files,
|
||||
"parent_dir": parent_dir,
|
||||
"reingest_task_id": task.id,
|
||||
# ``source_id`` lets the frontend correlate
|
||||
# inbound ``source.ingest.*`` SSE events
|
||||
# (emitted by ``reingest_source_worker``)
|
||||
# back to the reingest task — matches the
|
||||
# upload route's source-id contract.
|
||||
"source_id": resolved_source_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
@@ -455,10 +714,8 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Remove files from storage and directory structure
|
||||
|
||||
removed_files = []
|
||||
map_updated = False
|
||||
# Path-traversal guard runs *before* the claim so a 400
|
||||
# for an invalid path doesn't leave a pending dedup row.
|
||||
for file_path in file_paths:
|
||||
if ".." in str(file_path) or str(file_path).startswith("/"):
|
||||
return make_response(
|
||||
@@ -470,6 +727,31 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
# Claim before any storage mutation. See ``add`` branch
|
||||
# comment for rationale.
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id, cached = _claim_task_or_get_cached(
|
||||
scoped_key, "reingest_source_task",
|
||||
)
|
||||
if cached is not None:
|
||||
cached_task_id = cached.pop("task_id", None)
|
||||
if cached_task_id is not None:
|
||||
cached["reingest_task_id"] = cached_task_id
|
||||
# Override the helper's synthetic source_id (uuid5
|
||||
# of the scoped key) with the real source row id
|
||||
# — the reingest worker publishes SSE events
|
||||
# scoped to ``resolved_source_id`` and FileTree
|
||||
# correlates on it.
|
||||
cached["source_id"] = resolved_source_id
|
||||
return make_response(jsonify(cached), 200)
|
||||
|
||||
# Remove files from storage and directory structure
|
||||
|
||||
removed_files = []
|
||||
map_updated = False
|
||||
for file_path in file_paths:
|
||||
full_path = f"{source_file_path}/{file_path}"
|
||||
|
||||
# Remove from storage
|
||||
@@ -491,9 +773,15 @@ class ManageSourceFiles(Resource):
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id, user=user
|
||||
task = reingest_source_task.apply_async(
|
||||
kwargs={
|
||||
"source_id": resolved_source_id,
|
||||
"user": user,
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
},
|
||||
task_id=predetermined_task_id,
|
||||
)
|
||||
claim_transferred = True
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -502,6 +790,7 @@ class ManageSourceFiles(Resource):
|
||||
"message": f"Removed {len(removed_files)} files",
|
||||
"removed_files": removed_files,
|
||||
"reingest_task_id": task.id,
|
||||
"source_id": resolved_source_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
@@ -552,6 +841,24 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
404,
|
||||
)
|
||||
|
||||
# Claim before mutation. See ``add`` branch for rationale.
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id, cached = _claim_task_or_get_cached(
|
||||
scoped_key, "reingest_source_task",
|
||||
)
|
||||
if cached is not None:
|
||||
cached_task_id = cached.pop("task_id", None)
|
||||
if cached_task_id is not None:
|
||||
cached["reingest_task_id"] = cached_task_id
|
||||
# Same source_id override as the ``remove`` /
|
||||
# ``add`` cached branches — the helper's synthetic
|
||||
# id doesn't match what reingest_source_worker
|
||||
# tags its SSE events with.
|
||||
cached["source_id"] = resolved_source_id
|
||||
return make_response(jsonify(cached), 200)
|
||||
|
||||
success = storage.remove_directory(full_directory_path)
|
||||
|
||||
if not success:
|
||||
@@ -560,6 +867,11 @@ class ManageSourceFiles(Resource):
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||
f"Full path: {full_directory_path}"
|
||||
)
|
||||
# Release so a client retry can reclaim — otherwise
|
||||
# the next request would silently 200-cache to the
|
||||
# task_id that never enqueued.
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Failed to remove directory"}
|
||||
@@ -591,9 +903,15 @@ class ManageSourceFiles(Resource):
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id, user=user
|
||||
task = reingest_source_task.apply_async(
|
||||
kwargs={
|
||||
"source_id": resolved_source_id,
|
||||
"user": user,
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
},
|
||||
task_id=predetermined_task_id,
|
||||
)
|
||||
claim_transferred = True
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -602,11 +920,20 @@ class ManageSourceFiles(Resource):
|
||||
"message": f"Successfully removed directory: {directory_path}",
|
||||
"removed_directory": directory_path,
|
||||
"reingest_task_id": task.id,
|
||||
"source_id": resolved_source_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
# Release the dedup claim only if it wasn't transferred to
|
||||
# a worker. Without this, a same-key retry within the 24h
|
||||
# TTL would 200-cache to a predetermined task_id whose
|
||||
# ``apply_async`` never ran (or ran but the response builder
|
||||
# blew up afterward — only the first case matters in
|
||||
# practice; the flag protects both).
|
||||
if scoped_key and not claim_transferred:
|
||||
_release_claim(scoped_key)
|
||||
error_context = f"operation={operation}, user={user}, source_id={source_id}"
|
||||
if operation == "remove_directory":
|
||||
directory_path = request.form.get("directory_path", "")
|
||||
|
||||
@@ -1,21 +1,45 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from application.api.user.idempotency import with_idempotency
|
||||
from application.celery_init import celery
|
||||
from application.worker import (
|
||||
agent_webhook_worker,
|
||||
attachment_worker,
|
||||
ingest_worker,
|
||||
mcp_oauth,
|
||||
mcp_oauth_status,
|
||||
remote_worker,
|
||||
sync,
|
||||
sync_worker,
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
# Shared decorator config for long-running, side-effecting tasks. ``acks_late``
|
||||
# is also the celeryconfig default but stays explicit here so each task's
|
||||
# durability story is grep-able next to the body. Combined with
|
||||
# ``autoretry_for=(Exception,)`` and a bounded ``max_retries`` so a poison
|
||||
# message can't loop forever.
|
||||
DURABLE_TASK = dict(
|
||||
bind=True,
|
||||
acks_late=True,
|
||||
autoretry_for=(Exception,),
|
||||
retry_kwargs={"max_retries": 3, "countdown": 60},
|
||||
retry_backoff=True,
|
||||
)
|
||||
|
||||
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest")
|
||||
def ingest(
|
||||
self, directory, formats, job_name, user, file_path, filename, file_name_map=None
|
||||
self,
|
||||
directory,
|
||||
formats,
|
||||
job_name,
|
||||
user,
|
||||
file_path,
|
||||
filename,
|
||||
file_name_map=None,
|
||||
idempotency_key=None,
|
||||
source_id=None,
|
||||
):
|
||||
resp = ingest_worker(
|
||||
self,
|
||||
@@ -26,25 +50,40 @@ def ingest(
|
||||
filename,
|
||||
user,
|
||||
file_name_map=file_name_map,
|
||||
idempotency_key=idempotency_key,
|
||||
source_id=source_id,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def ingest_remote(self, source_data, job_name, user, loader):
|
||||
resp = remote_worker(self, source_data, job_name, user, loader)
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest_remote")
|
||||
def ingest_remote(
|
||||
self, source_data, job_name, user, loader,
|
||||
idempotency_key=None, source_id=None,
|
||||
):
|
||||
resp = remote_worker(
|
||||
self, source_data, job_name, user, loader,
|
||||
idempotency_key=idempotency_key,
|
||||
source_id=source_id,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def reingest_source_task(self, source_id, user):
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="reingest_source_task")
|
||||
def reingest_source_task(self, source_id, user, idempotency_key=None):
|
||||
from application.worker import reingest_source_worker
|
||||
|
||||
resp = reingest_source_worker(self, source_id, user)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
# Beat-driven dispatch tasks default to ``acks_late=False``: a SIGKILL
|
||||
# of a beat tick is harmless to redeliver only if the dispatch itself is
|
||||
# idempotent. We keep these early-ACK so the broker doesn't replay a
|
||||
# dispatch that already enqueued downstream work.
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def schedule_syncs(self, frequency):
|
||||
resp = sync_worker(self, frequency)
|
||||
return resp
|
||||
@@ -74,19 +113,22 @@ def sync_source(
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def store_attachment(self, file_info, user):
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="store_attachment")
|
||||
def store_attachment(self, file_info, user, idempotency_key=None):
|
||||
resp = attachment_worker(self, file_info, user)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def process_agent_webhook(self, agent_id, payload):
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="process_agent_webhook")
|
||||
def process_agent_webhook(self, agent_id, payload, idempotency_key=None):
|
||||
resp = agent_webhook_worker(self, agent_id, payload)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest_connector_task")
|
||||
def ingest_connector_task(
|
||||
self,
|
||||
job_name,
|
||||
@@ -100,6 +142,8 @@ def ingest_connector_task(
|
||||
operation_mode="upload",
|
||||
doc_id=None,
|
||||
sync_frequency="never",
|
||||
idempotency_key=None,
|
||||
source_id=None,
|
||||
):
|
||||
from application.worker import ingest_connector
|
||||
|
||||
@@ -116,6 +160,8 @@ def ingest_connector_task(
|
||||
operation_mode=operation_mode,
|
||||
doc_id=doc_id,
|
||||
sync_frequency=sync_frequency,
|
||||
idempotency_key=idempotency_key,
|
||||
source_id=source_id,
|
||||
)
|
||||
return resp
|
||||
|
||||
@@ -140,6 +186,33 @@ def setup_periodic_tasks(sender, **kwargs):
|
||||
cleanup_pending_tool_state.s(),
|
||||
name="cleanup-pending-tool-state",
|
||||
)
|
||||
# Pure housekeeping for ``task_dedup`` / ``webhook_dedup`` — the
|
||||
# upsert paths already handle stale rows, so cadence only bounds
|
||||
# table size. Hourly is plenty for typical traffic.
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=1),
|
||||
cleanup_idempotency_dedup.s(),
|
||||
name="cleanup-idempotency-dedup",
|
||||
)
|
||||
sender.add_periodic_task(
|
||||
timedelta(seconds=30),
|
||||
reconciliation_task.s(),
|
||||
name="reconciliation",
|
||||
)
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=7),
|
||||
version_check_task.s(),
|
||||
name="version-check",
|
||||
)
|
||||
# Bound ``message_events`` growth — every streamed SSE chunk writes
|
||||
# one row, so retained chats accumulate hundreds of rows per
|
||||
# message. Reconnect-replay is only meaningful for streams the user
|
||||
# could plausibly still be waiting on, so 14 days is generous.
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=24),
|
||||
cleanup_message_events.s(),
|
||||
name="cleanup-message-events",
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@@ -148,24 +221,12 @@ def mcp_oauth_task(self, config, user):
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def mcp_oauth_status_task(self, task_id):
|
||||
resp = mcp_oauth_status(self, task_id)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def cleanup_pending_tool_state(self):
|
||||
"""Delete pending_tool_state rows past their TTL.
|
||||
|
||||
Replaces Mongo's ``expireAfterSeconds=0`` TTL index — Postgres has
|
||||
no native TTL, so this task runs every 60 seconds to keep
|
||||
``pending_tool_state`` bounded. No-ops if ``POSTGRES_URI`` isn't
|
||||
configured (keeps the task runnable in Mongo-only environments).
|
||||
"""
|
||||
"""Revert stale ``resuming`` rows, then delete TTL-expired rows."""
|
||||
from application.core.settings import settings
|
||||
if not settings.POSTGRES_URI:
|
||||
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
|
||||
return {"deleted": 0, "reverted": 0, "skipped": "POSTGRES_URI not set"}
|
||||
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.pending_tool_state import (
|
||||
@@ -174,5 +235,80 @@ def cleanup_pending_tool_state(self):
|
||||
|
||||
engine = get_engine()
|
||||
with engine.begin() as conn:
|
||||
deleted = PendingToolStateRepository(conn).cleanup_expired()
|
||||
return {"deleted": deleted}
|
||||
repo = PendingToolStateRepository(conn)
|
||||
reverted = repo.revert_stale_resuming(grace_seconds=600)
|
||||
deleted = repo.cleanup_expired()
|
||||
return {"deleted": deleted, "reverted": reverted}
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def cleanup_idempotency_dedup(self):
|
||||
"""Delete TTL-expired rows from ``task_dedup`` and ``webhook_dedup``.
|
||||
|
||||
Pure housekeeping — the upsert paths already ignore stale rows
|
||||
(TTL-aware ``ON CONFLICT DO UPDATE``), so this only bounds table
|
||||
growth and keeps SELECT planning tight on large deployments.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
if not settings.POSTGRES_URI:
|
||||
return {
|
||||
"task_dedup_deleted": 0,
|
||||
"webhook_dedup_deleted": 0,
|
||||
"skipped": "POSTGRES_URI not set",
|
||||
}
|
||||
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.idempotency import (
|
||||
IdempotencyRepository,
|
||||
)
|
||||
|
||||
engine = get_engine()
|
||||
with engine.begin() as conn:
|
||||
return IdempotencyRepository(conn).cleanup_expired()
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def reconciliation_task(self):
|
||||
"""Sweep stuck durability rows and escalate them to terminal status + alert."""
|
||||
from application.api.user.reconciliation import run_reconciliation
|
||||
|
||||
return run_reconciliation()
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def cleanup_message_events(self):
|
||||
"""Delete ``message_events`` rows older than the retention window.
|
||||
|
||||
Streamed answer responses write one journal row per SSE yield,
|
||||
so unbounded growth would dominate Postgres for any retained-
|
||||
conversations deployment. The reconnect-replay path only needs
|
||||
rows for in-flight streams; 14 days covers paused/tool-action
|
||||
flows comfortably.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
if not settings.POSTGRES_URI:
|
||||
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
|
||||
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.message_events import (
|
||||
MessageEventsRepository,
|
||||
)
|
||||
|
||||
ttl_days = settings.MESSAGE_EVENTS_RETENTION_DAYS
|
||||
engine = get_engine()
|
||||
with engine.begin() as conn:
|
||||
deleted = MessageEventsRepository(conn).cleanup_older_than(ttl_days)
|
||||
return {"deleted": deleted, "ttl_days": ttl_days}
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def version_check_task(self):
|
||||
"""Periodic anonymous version check.
|
||||
|
||||
Complements the ``worker_ready`` boot trigger so long-running
|
||||
deployments (>6h cache TTL) still refresh advisories. ``run_check``
|
||||
is fail-silent and coordinates across replicas via Redis lock +
|
||||
cache (see ``application.updates.version_check``).
|
||||
"""
|
||||
from application.updates.version_check import run_check
|
||||
run_check()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Tool management MCP server integration."""
|
||||
|
||||
import json
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
@@ -226,7 +225,9 @@ class MCPServerSave(Resource):
|
||||
)
|
||||
redis_client = get_redis_instance()
|
||||
manager = MCPOAuthManager(redis_client)
|
||||
result = manager.get_oauth_status(config["oauth_task_id"])
|
||||
result = manager.get_oauth_status(
|
||||
config["oauth_task_id"], user
|
||||
)
|
||||
if not result.get("status") == "completed":
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -438,56 +439,6 @@ class MCPOAuthCallback(Resource):
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
|
||||
class MCPOAuthStatus(Resource):
|
||||
def get(self, task_id):
|
||||
try:
|
||||
redis_client = get_redis_instance()
|
||||
status_key = f"mcp_oauth_status:{task_id}"
|
||||
status_data = redis_client.get(status_key)
|
||||
|
||||
if status_data:
|
||||
status = json.loads(status_data)
|
||||
if "tools" in status and isinstance(status["tools"], list):
|
||||
status["tools"] = [
|
||||
{
|
||||
"name": t.get("name", "unknown"),
|
||||
"description": t.get("description", ""),
|
||||
}
|
||||
for t in status["tools"]
|
||||
]
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": task_id, **status})
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"task_id": task_id,
|
||||
"status": "pending",
|
||||
"message": "Waiting for OAuth to start...",
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error getting OAuth status for task {task_id}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Failed to get OAuth status",
|
||||
"task_id": task_id,
|
||||
}
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/auth_status")
|
||||
class MCPAuthStatus(Resource):
|
||||
@api.doc(
|
||||
|
||||
@@ -198,8 +198,14 @@ def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
|
||||
return normalized_nodes
|
||||
|
||||
|
||||
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
|
||||
"""Validate workflow graph structure."""
|
||||
def validate_workflow_structure(
|
||||
nodes: List[Dict], edges: List[Dict], user_id: str | None = None
|
||||
) -> List[str]:
|
||||
"""Validate workflow graph structure.
|
||||
|
||||
``user_id`` is required so per-user BYOM custom-model UUIDs resolve
|
||||
when checking each agent node's structured-output capability.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
if not nodes:
|
||||
@@ -343,7 +349,7 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
|
||||
|
||||
model_id = raw_config.get("model_id")
|
||||
if has_json_schema and isinstance(model_id, str) and model_id.strip():
|
||||
capabilities = get_model_capabilities(model_id.strip())
|
||||
capabilities = get_model_capabilities(model_id.strip(), user_id=user_id)
|
||||
if capabilities and not capabilities.get("supports_structured_output", False):
|
||||
errors.append(
|
||||
f"Agent node '{agent_title}' selected model does not support structured output"
|
||||
@@ -389,7 +395,9 @@ class WorkflowList(Resource):
|
||||
nodes_data = data.get("nodes", [])
|
||||
edges_data = data.get("edges", [])
|
||||
|
||||
validation_errors = validate_workflow_structure(nodes_data, edges_data)
|
||||
validation_errors = validate_workflow_structure(
|
||||
nodes_data, edges_data, user_id=user_id
|
||||
)
|
||||
if validation_errors:
|
||||
return error_response(
|
||||
"Workflow validation failed", errors=validation_errors
|
||||
@@ -451,7 +459,9 @@ class WorkflowDetail(Resource):
|
||||
nodes_data = data.get("nodes", [])
|
||||
edges_data = data.get("edges", [])
|
||||
|
||||
validation_errors = validate_workflow_structure(nodes_data, edges_data)
|
||||
validation_errors = validate_workflow_structure(
|
||||
nodes_data, edges_data, user_id=user_id
|
||||
)
|
||||
if validation_errors:
|
||||
return error_response(
|
||||
"Workflow validation failed", errors=validation_errors
|
||||
|
||||
@@ -9,6 +9,7 @@ import json
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from flask import Blueprint, jsonify, make_response, request, Response
|
||||
@@ -213,6 +214,7 @@ def _stream_response(
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
model_user_id=processor.model_user_id,
|
||||
should_save_conversation=should_save_conversation,
|
||||
_continuation=continuation,
|
||||
)
|
||||
@@ -220,13 +222,26 @@ def _stream_response(
|
||||
for line in internal_stream:
|
||||
if not line.strip():
|
||||
continue
|
||||
# Parse the internal SSE event
|
||||
event_str = line.replace("data: ", "").strip()
|
||||
# ``complete_stream`` prefixes each frame with ``id: <seq>\n``
|
||||
# before the ``data:`` line. Extract just the data line so JSON
|
||||
# decode doesn't choke on the SSE framing.
|
||||
event_str = ""
|
||||
for raw in line.split("\n"):
|
||||
if raw.startswith("data:"):
|
||||
event_str = raw[len("data:") :].lstrip()
|
||||
break
|
||||
if not event_str:
|
||||
continue
|
||||
try:
|
||||
event_data = json.loads(event_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
|
||||
# Skip the informational ``message_id`` event — it has no v1 /
|
||||
# OpenAI-compatible analog.
|
||||
if event_data.get("type") == "message_id":
|
||||
continue
|
||||
|
||||
# Update completion_id when we get the conversation id
|
||||
if event_data.get("type") == "id":
|
||||
conv_id = event_data.get("id", "")
|
||||
@@ -257,6 +272,7 @@ def _non_stream_response(
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
model_user_id=processor.model_user_id,
|
||||
should_save_conversation=should_save_conversation,
|
||||
_continuation=continuation,
|
||||
)
|
||||
@@ -304,7 +320,16 @@ def list_models():
|
||||
401,
|
||||
)
|
||||
|
||||
# Repository rows now go through ``coerce_pg_native`` at SELECT
|
||||
# time, so timestamps arrive as ISO 8601 strings. Parse before
|
||||
# taking ``.timestamp()``; fall back to ``time.time()`` only when
|
||||
# the value is genuinely missing or unparseable.
|
||||
created = agent.get("created_at") or agent.get("createdAt")
|
||||
if isinstance(created, str):
|
||||
try:
|
||||
created = datetime.fromisoformat(created)
|
||||
except (ValueError, TypeError):
|
||||
created = None
|
||||
created_ts = (
|
||||
int(created.timestamp()) if hasattr(created, "timestamp")
|
||||
else int(time.time())
|
||||
|
||||
@@ -9,12 +9,15 @@ from jose import jwt
|
||||
|
||||
from application.auth import handle_auth
|
||||
|
||||
from application.core import log_context
|
||||
from application.core.logging_config import setup_logging
|
||||
|
||||
setup_logging()
|
||||
|
||||
from application.api import api # noqa: E402
|
||||
from application.api.answer import answer # noqa: E402
|
||||
from application.api.answer.routes.messages import messages_bp # noqa: E402
|
||||
from application.api.events.routes import events # noqa: E402
|
||||
from application.api.internal.routes import internal # noqa: E402
|
||||
from application.api.user.routes import user # noqa: E402
|
||||
from application.api.connector.routes import connector # noqa: E402
|
||||
@@ -48,6 +51,8 @@ ensure_database_ready(
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(user)
|
||||
app.register_blueprint(answer)
|
||||
app.register_blueprint(events)
|
||||
app.register_blueprint(messages_bp)
|
||||
app.register_blueprint(internal)
|
||||
app.register_blueprint(connector)
|
||||
app.register_blueprint(v1_bp)
|
||||
@@ -112,6 +117,38 @@ def generate_token():
|
||||
return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
|
||||
|
||||
|
||||
_LOG_CTX_TOKEN_ATTR = "_log_ctx_token"
|
||||
|
||||
|
||||
@app.before_request
|
||||
def _bind_log_context():
|
||||
"""Bind activity_id + endpoint for the duration of this request.
|
||||
|
||||
Runs before ``authenticate_request``; ``user_id`` is overlaid in a
|
||||
follow-up handler once the JWT has been decoded.
|
||||
"""
|
||||
if request.method == "OPTIONS":
|
||||
return None
|
||||
activity_id = str(uuid.uuid4())
|
||||
request.activity_id = activity_id
|
||||
token = log_context.bind(
|
||||
activity_id=activity_id,
|
||||
endpoint=request.endpoint,
|
||||
)
|
||||
setattr(request, _LOG_CTX_TOKEN_ATTR, token)
|
||||
return None
|
||||
|
||||
|
||||
@app.teardown_request
|
||||
def _reset_log_context(_exc):
|
||||
# SSE streams keep yielding after teardown fires, but a2wsgi runs each
|
||||
# request inside ``copy_context().run(...)``, so this reset doesn't
|
||||
# leak into the stream's view of the context.
|
||||
token = getattr(request, _LOG_CTX_TOKEN_ATTR, None)
|
||||
if token is not None:
|
||||
log_context.reset(token)
|
||||
|
||||
|
||||
@app.before_request
|
||||
def enforce_stt_request_size_limits():
|
||||
if request.method == "OPTIONS":
|
||||
@@ -148,12 +185,29 @@ def authenticate_request():
|
||||
request.decoded_token = decoded_token
|
||||
|
||||
|
||||
@app.before_request
|
||||
def _bind_user_id_to_log_context():
|
||||
# Registered after ``authenticate_request`` (Flask runs before_request
|
||||
# handlers in registration order), so ``request.decoded_token`` is
|
||||
# populated by the time we read it. ``teardown_request`` unwinds the
|
||||
# whole request-level bind, so no separate reset token is needed here.
|
||||
if request.method == "OPTIONS":
|
||||
return None
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
|
||||
if user_id:
|
||||
log_context.bind(user_id=user_id)
|
||||
return None
|
||||
|
||||
|
||||
@app.after_request
|
||||
def after_request(response: Response) -> Response:
|
||||
"""Add CORS headers for the pure Flask development entrypoint."""
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = (
|
||||
"Content-Type, Authorization, Idempotency-Key"
|
||||
)
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@@ -24,8 +24,13 @@ asgi_app = Starlette(
|
||||
Middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization", "Mcp-Session-Id"],
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
|
||||
allow_headers=[
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"Mcp-Session-Id",
|
||||
"Idempotency-Key",
|
||||
],
|
||||
expose_headers=["Mcp-Session-Id"],
|
||||
),
|
||||
],
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
@@ -10,6 +11,14 @@ from application.utils import get_hash
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cache_default(value):
|
||||
# Image attachments arrive inline as bytes (see GoogleLLM.prepare_messages_with_attachments);
|
||||
# hash so the cache key stays bounded in size and stable across identical content.
|
||||
if isinstance(value, (bytes, bytearray, memoryview)):
|
||||
return f"<bytes:sha256:{hashlib.sha256(bytes(value)).hexdigest()}>"
|
||||
return repr(value)
|
||||
|
||||
_redis_instance = None
|
||||
_redis_creation_failed = False
|
||||
_instance_lock = Lock()
|
||||
@@ -20,8 +29,17 @@ def get_redis_instance():
|
||||
with _instance_lock:
|
||||
if _redis_instance is None and not _redis_creation_failed:
|
||||
try:
|
||||
# ``health_check_interval`` makes redis-py ping the
|
||||
# connection every N seconds when otherwise idle.
|
||||
# Without it, a half-open TCP (NAT silently dropped
|
||||
# state, ELB idle-close) can hang the SSE generator
|
||||
# in ``pubsub.get_message`` past its keepalive
|
||||
# cadence — the kernel never surfaces the dead
|
||||
# socket because no payload is in flight.
|
||||
_redis_instance = redis.Redis.from_url(
|
||||
settings.CACHE_REDIS_URL, socket_connect_timeout=2
|
||||
settings.CACHE_REDIS_URL,
|
||||
socket_connect_timeout=2,
|
||||
health_check_interval=10,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid Redis URL: {e}")
|
||||
@@ -36,7 +54,7 @@ def get_redis_instance():
|
||||
def gen_cache_key(messages, model="docgpt", tools=None):
|
||||
if not all(isinstance(msg, dict) for msg in messages):
|
||||
raise ValueError("All messages must be dictionaries.")
|
||||
messages_str = json.dumps(messages)
|
||||
messages_str = json.dumps(messages, default=_cache_default)
|
||||
tools_str = json.dumps(str(tools)) if tools else ""
|
||||
combined = f"{model}_{messages_str}_{tools_str}"
|
||||
cache_key = get_hash(combined)
|
||||
|
||||
@@ -1,8 +1,17 @@
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from celery import Celery
|
||||
from application.core import log_context
|
||||
from application.core.settings import settings
|
||||
from celery.signals import setup_logging, worker_process_init, worker_ready
|
||||
from celery.signals import (
|
||||
setup_logging,
|
||||
task_postrun,
|
||||
task_prerun,
|
||||
worker_process_init,
|
||||
worker_ready,
|
||||
)
|
||||
|
||||
|
||||
def make_celery(app_name=__name__):
|
||||
@@ -41,6 +50,54 @@ def _dispose_db_engine_on_fork(*args, **kwargs):
|
||||
dispose_engine()
|
||||
|
||||
|
||||
# Most tasks in this repo accept ``user`` where the log context wants
|
||||
# ``user_id``; map task parameter names to context keys explicitly.
|
||||
_TASK_PARAM_TO_CTX_KEY: dict[str, str] = {
|
||||
"user": "user_id",
|
||||
"user_id": "user_id",
|
||||
"agent_id": "agent_id",
|
||||
"conversation_id": "conversation_id",
|
||||
}
|
||||
|
||||
_task_log_tokens: dict[str, object] = {}
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def _bind_task_log_context(task_id, task, args, kwargs, **_):
|
||||
# Resolve task args by parameter name — nearly every task in this repo
|
||||
# is called positionally, so ``kwargs.get('user')`` would bind nothing.
|
||||
ctx = {"activity_id": task_id}
|
||||
try:
|
||||
sig = inspect.signature(task.run)
|
||||
bound = sig.bind_partial(*args, **kwargs).arguments
|
||||
except (TypeError, ValueError):
|
||||
bound = dict(kwargs)
|
||||
for param_name, value in bound.items():
|
||||
ctx_key = _TASK_PARAM_TO_CTX_KEY.get(param_name)
|
||||
if ctx_key and value:
|
||||
ctx[ctx_key] = value
|
||||
_task_log_tokens[task_id] = log_context.bind(**ctx)
|
||||
|
||||
|
||||
@task_postrun.connect
|
||||
def _unbind_task_log_context(task_id, **_):
|
||||
# ``task_postrun`` fires on both success and failure. Required for
|
||||
# Celery: unlike the Flask path, tasks aren't isolated in their own
|
||||
# ``copy_context().run(...)``, so a missing reset would leak the
|
||||
# bind onto the next task on the same worker.
|
||||
token = _task_log_tokens.pop(task_id, None)
|
||||
if token is None:
|
||||
return
|
||||
try:
|
||||
log_context.reset(token)
|
||||
except ValueError:
|
||||
# task_prerun and task_postrun ran on different threads (non-default
|
||||
# Celery pool); the token isn't valid in this context. Drop it.
|
||||
logging.getLogger(__name__).debug(
|
||||
"log_context reset skipped for task %s", task_id
|
||||
)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def _run_version_check(*args, **kwargs):
|
||||
"""Kick off the anonymous version check on worker startup.
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import os
|
||||
from application.core.settings import settings
|
||||
|
||||
broker_url = os.getenv("CELERY_BROKER_URL")
|
||||
result_backend = os.getenv("CELERY_RESULT_BACKEND")
|
||||
# Pydantic loads .env into ``settings`` but does not inject values into
|
||||
# ``os.environ`` — read directly from settings so beat startup (which
|
||||
# imports this module before any explicit env load) sees a real URL.
|
||||
broker_url = settings.CELERY_BROKER_URL
|
||||
result_backend = settings.CELERY_RESULT_BACKEND
|
||||
|
||||
task_serializer = 'json'
|
||||
result_serializer = 'json'
|
||||
@@ -9,3 +12,22 @@ accept_content = ['json']
|
||||
|
||||
# Autodiscover tasks
|
||||
imports = ('application.api.user.tasks',)
|
||||
|
||||
# Project-scoped queue so a stray sibling worker on the same broker
|
||||
# (other repo, same default ``celery`` queue) can't grab DocsGPT tasks.
|
||||
task_default_queue = "docsgpt"
|
||||
task_default_exchange = "docsgpt"
|
||||
task_default_routing_key = "docsgpt"
|
||||
|
||||
beat_scheduler = "redbeat.RedBeatScheduler"
|
||||
redbeat_redis_url = broker_url
|
||||
redbeat_key_prefix = "redbeat:docsgpt:"
|
||||
redbeat_lock_timeout = 90
|
||||
|
||||
# Survive worker SIGKILL/OOM without silently dropping in-flight tasks.
|
||||
task_acks_late = True
|
||||
task_reject_on_worker_lost = True
|
||||
worker_prefetch_multiplier = settings.CELERY_WORKER_PREFETCH_MULTIPLIER
|
||||
broker_transport_options = {"visibility_timeout": settings.CELERY_VISIBILITY_TIMEOUT}
|
||||
result_expires = 86400 * 7
|
||||
task_track_started = True
|
||||
|
||||
57
application/core/log_context.py
Normal file
57
application/core/log_context.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Per-activity logging context backed by ``contextvars``.
|
||||
|
||||
The ``_ContextFilter`` installed by ``logging_config.setup_logging`` stamps
|
||||
every ``LogRecord`` emitted inside a ``bind`` block with the bound keys, so
|
||||
they land as first-class attributes on the OTLP log export rather than being
|
||||
buried inside formatted message bodies.
|
||||
|
||||
A single ``ContextVar`` holds a dict so nested binds reset atomically (LIFO)
|
||||
via the token returned by ``bind``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar, Token
|
||||
from typing import Mapping
|
||||
|
||||
|
||||
_CTX_KEYS: frozenset[str] = frozenset(
|
||||
{
|
||||
"activity_id",
|
||||
"parent_activity_id",
|
||||
"user_id",
|
||||
"agent_id",
|
||||
"conversation_id",
|
||||
"endpoint",
|
||||
"model",
|
||||
}
|
||||
)
|
||||
|
||||
_ctx: ContextVar[Mapping[str, str]] = ContextVar("log_ctx", default={})
|
||||
|
||||
|
||||
def bind(**kwargs: object) -> Token:
|
||||
"""Overlay the given keys onto the current context.
|
||||
|
||||
Returns a ``Token`` so the caller can ``reset`` in a ``finally`` block.
|
||||
Keys outside :data:`_CTX_KEYS` are silently dropped (so a typo can't
|
||||
stamp a stray field name onto every record), as are ``None`` values
|
||||
(a missing attribute is more useful than the literal string ``"None"``).
|
||||
"""
|
||||
overlay = {
|
||||
k: str(v)
|
||||
for k, v in kwargs.items()
|
||||
if k in _CTX_KEYS and v is not None
|
||||
}
|
||||
new = {**_ctx.get(), **overlay}
|
||||
return _ctx.set(new)
|
||||
|
||||
|
||||
def reset(token: Token) -> None:
|
||||
"""Restore the context to the snapshot captured by the matching ``bind``."""
|
||||
_ctx.reset(token)
|
||||
|
||||
|
||||
def snapshot() -> Mapping[str, str]:
|
||||
"""Return the current context dict. Treat as read-only; use :func:`bind`."""
|
||||
return _ctx.get()
|
||||
@@ -1,11 +1,75 @@
|
||||
import logging
|
||||
import os
|
||||
from logging.config import dictConfig
|
||||
|
||||
def setup_logging():
|
||||
from application.core.log_context import snapshot as _ctx_snapshot
|
||||
|
||||
|
||||
# Loggers with ``propagate=False`` don't share root's handlers, so the
|
||||
# context filter has to be installed on their handlers directly.
|
||||
_NON_PROPAGATING_LOGGERS: tuple[str, ...] = (
|
||||
"uvicorn",
|
||||
"uvicorn.access",
|
||||
"uvicorn.error",
|
||||
"celery.app.trace",
|
||||
"celery.worker.strategy",
|
||||
"gunicorn.error",
|
||||
"gunicorn.access",
|
||||
)
|
||||
|
||||
|
||||
class _ContextFilter(logging.Filter):
|
||||
"""Stamp the current ``log_context`` snapshot onto every ``LogRecord``.
|
||||
|
||||
Must be installed on **handlers**, not loggers: Python skips logger-level
|
||||
filters when a child logger's record propagates up. The ``hasattr`` guard
|
||||
keeps an explicit ``logger.info(..., extra={...})`` from being overwritten.
|
||||
"""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
for key, value in _ctx_snapshot().items():
|
||||
if not hasattr(record, key):
|
||||
setattr(record, key, value)
|
||||
return True
|
||||
|
||||
|
||||
def _otlp_logs_enabled() -> bool:
|
||||
"""Return True when the user has opted in to OTLP log export.
|
||||
|
||||
Gated by the standard OTEL env vars so no project-specific knob is needed:
|
||||
set ``OTEL_LOGS_EXPORTER=otlp`` (and leave ``OTEL_SDK_DISABLED`` unset or
|
||||
false) to flip it on. When false, ``setup_logging`` keeps its original
|
||||
console-only behavior.
|
||||
"""
|
||||
exporter = os.getenv("OTEL_LOGS_EXPORTER", "").strip().lower()
|
||||
disabled = os.getenv("OTEL_SDK_DISABLED", "false").strip().lower() == "true"
|
||||
return exporter == "otlp" and not disabled
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
"""Configure the root logger with a stdout console handler.
|
||||
|
||||
When OTLP log export is enabled, ``opentelemetry-instrument`` attaches a
|
||||
``LoggingHandler`` to the root logger before this function runs. The
|
||||
``dictConfig`` call below replaces ``root.handlers`` with the console
|
||||
handler, which would silently drop the OTEL handler. To make OTLP log
|
||||
export work without forcing every contributor to opt in, snapshot the
|
||||
OTEL handlers up front and re-attach them after ``dictConfig``.
|
||||
"""
|
||||
preserved_handlers: list[logging.Handler] = []
|
||||
if _otlp_logs_enabled():
|
||||
preserved_handlers = [
|
||||
h
|
||||
for h in logging.getLogger().handlers
|
||||
if h.__class__.__module__.startswith("opentelemetry")
|
||||
]
|
||||
|
||||
dictConfig({
|
||||
'version': 1,
|
||||
'formatters': {
|
||||
'default': {
|
||||
'format': '[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {
|
||||
"format": "[%(asctime)s] %(levelname)s in %(module)s: %(message)s",
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
@@ -15,8 +79,34 @@ def setup_logging():
|
||||
"formatter": "default",
|
||||
}
|
||||
},
|
||||
'root': {
|
||||
'level': 'INFO',
|
||||
'handlers': ['console'],
|
||||
"root": {
|
||||
"level": "INFO",
|
||||
"handlers": ["console"],
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
if preserved_handlers:
|
||||
root = logging.getLogger()
|
||||
for handler in preserved_handlers:
|
||||
if handler not in root.handlers:
|
||||
root.addHandler(handler)
|
||||
|
||||
_install_context_filter()
|
||||
|
||||
|
||||
def _install_context_filter() -> None:
|
||||
"""Attach :class:`_ContextFilter` to root's handlers + every handler on
|
||||
the known non-propagating loggers. Skipping handlers that already carry
|
||||
one keeps repeat ``setup_logging`` calls from stacking filters.
|
||||
"""
|
||||
|
||||
def _has_ctx_filter(handler: logging.Handler) -> bool:
|
||||
return any(isinstance(f, _ContextFilter) for f in handler.filters)
|
||||
|
||||
for handler in logging.getLogger().handlers:
|
||||
if not _has_ctx_filter(handler):
|
||||
handler.addFilter(_ContextFilter())
|
||||
for name in _NON_PROPAGATING_LOGGERS:
|
||||
for handler in logging.getLogger(name).handlers:
|
||||
if not _has_ctx_filter(handler):
|
||||
handler.addFilter(_ContextFilter())
|
||||
|
||||
@@ -1,266 +0,0 @@
|
||||
"""
|
||||
Model configurations for all supported LLM providers.
|
||||
"""
|
||||
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
)
|
||||
|
||||
# Base image attachment types supported by most vision-capable LLMs
|
||||
IMAGE_ATTACHMENTS = [
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
# PDF excluded: most OpenAI-compatible endpoints don't support native PDF uploads.
|
||||
# When excluded, PDFs are synthetically processed by converting pages to images.
|
||||
OPENAI_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
GOOGLE_ATTACHMENTS = ["application/pdf"] + IMAGE_ATTACHMENTS
|
||||
|
||||
ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
NOVITA_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
|
||||
OPENAI_MODELS = [
|
||||
AvailableModel(
|
||||
id="gpt-5.1",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-5.1",
|
||||
description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gpt-5-mini",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-5 Mini",
|
||||
description="Faster, cost-effective variant of GPT-5.1",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
ANTHROPIC_MODELS = [
|
||||
AvailableModel(
|
||||
id="claude-3-5-sonnet-20241022",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3.5 Sonnet (Latest)",
|
||||
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="claude-3-5-sonnet",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3.5 Sonnet",
|
||||
description="Balanced performance and capability",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="claude-3-opus",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3 Opus",
|
||||
description="Most capable Claude model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="claude-3-haiku",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3 Haiku",
|
||||
description="Fastest Claude model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
GOOGLE_MODELS = [
|
||||
AvailableModel(
|
||||
id="gemini-flash-latest",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini Flash (Latest)",
|
||||
description="Latest experimental Gemini model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||
context_window=int(1e6),
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gemini-flash-lite-latest",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini Flash Lite (Latest)",
|
||||
description="Fast with huge context window",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||
context_window=int(1e6),
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gemini-3-pro-preview",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini 3 Pro",
|
||||
description="Most capable Gemini model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||
context_window=2000000,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
GROQ_MODELS = [
|
||||
AvailableModel(
|
||||
id="llama-3.3-70b-versatile",
|
||||
provider=ModelProvider.GROQ,
|
||||
display_name="Llama 3.3 70B",
|
||||
description="Latest Llama model with high-speed inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="openai/gpt-oss-120b",
|
||||
provider=ModelProvider.GROQ,
|
||||
display_name="GPT-OSS 120B",
|
||||
description="Open-source GPT model optimized for speed",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
OPENROUTER_MODELS = [
|
||||
AvailableModel(
|
||||
id="qwen/qwen3-coder:free",
|
||||
provider=ModelProvider.OPENROUTER,
|
||||
display_name="Qwen 3 Coder",
|
||||
description="Latest Qwen model with high-speed inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
supported_attachment_types=OPENROUTER_ATTACHMENTS
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="google/gemma-3-27b-it:free",
|
||||
provider=ModelProvider.OPENROUTER,
|
||||
display_name="Gemma 3 27B",
|
||||
description="Latest Gemma model with high-speed inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
supported_attachment_types=OPENROUTER_ATTACHMENTS
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
NOVITA_MODELS = [
|
||||
AvailableModel(
|
||||
id="moonshotai/kimi-k2.5",
|
||||
provider=ModelProvider.NOVITA,
|
||||
display_name="Kimi K2.5",
|
||||
description="MoE model with function calling, structured output, reasoning, and vision",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=NOVITA_ATTACHMENTS,
|
||||
context_window=262144,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="zai-org/glm-5",
|
||||
provider=ModelProvider.NOVITA,
|
||||
display_name="GLM-5",
|
||||
description="MoE model with function calling, structured output, and reasoning",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=[],
|
||||
context_window=202800,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="minimax/minimax-m2.5",
|
||||
provider=ModelProvider.NOVITA,
|
||||
display_name="MiniMax M2.5",
|
||||
description="MoE model with function calling, structured output, and reasoning",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=[],
|
||||
context_window=204800,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
AZURE_OPENAI_MODELS = [
|
||||
AvailableModel(
|
||||
id="azure-gpt-4",
|
||||
provider=ModelProvider.AZURE_OPENAI,
|
||||
display_name="Azure OpenAI GPT-4",
|
||||
description="Azure-hosted GPT model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=8192,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
|
||||
"""Create a custom OpenAI-compatible model (e.g., LM Studio, Ollama)."""
|
||||
return AvailableModel(
|
||||
id=model_name,
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name=model_name,
|
||||
description=f"Custom OpenAI-compatible model at {base_url}",
|
||||
base_url=base_url,
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
),
|
||||
)
|
||||
385
application/core/model_registry.py
Normal file
385
application/core/model_registry.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""Layered model registry.
|
||||
|
||||
Loads model catalogs from YAML files (built-in + operator-supplied),
|
||||
groups them by provider name, then for each registered provider plugin
|
||||
calls ``get_models`` to produce the final per-provider model list.
|
||||
|
||||
End-user BYOM (per-user model records in Postgres) is layered on top:
|
||||
when a lookup arrives with a ``user_id``, the registry consults a
|
||||
per-user cache first (loaded from the ``user_custom_models`` table on
|
||||
miss) and falls through to the built-in catalog.
|
||||
|
||||
Cross-process invalidation: ``ModelRegistry`` is a per-process
|
||||
singleton, so a CRUD write only evicts the cache in the process that
|
||||
served it. Other gunicorn workers and Celery workers would otherwise
|
||||
keep using a deleted/disabled/key-rotated BYOM record indefinitely.
|
||||
``invalidate_user`` therefore both drops the local layer *and* bumps a
|
||||
Redis-side version counter; other processes notice the bump on their
|
||||
next access (after the local TTL window) and reload from Postgres. If
|
||||
Redis is unreachable the per-process TTL still bounds staleness — pure
|
||||
TTL semantics, no regression.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from application.core.model_settings import AvailableModel
|
||||
from application.core.model_yaml import (
|
||||
BUILTIN_MODELS_DIR,
|
||||
ProviderCatalog,
|
||||
load_model_yamls,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_USER_CACHE_TTL_SECONDS = 60.0
|
||||
_USER_VERSION_KEY_PREFIX = "byom:registry_version:"
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
"""Singleton registry of available models."""
|
||||
|
||||
_instance: Optional["ModelRegistry"] = None
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not ModelRegistry._initialized:
|
||||
self.models: Dict[str, AvailableModel] = {}
|
||||
self.default_model_id: Optional[str] = None
|
||||
# Per-user BYOM cache. Each entry is
|
||||
# ``(layer, version_at_load, loaded_at_monotonic)``:
|
||||
# * ``layer`` — {model_id: AvailableModel}
|
||||
# * ``version_at_load`` — Redis-side counter snapshot at
|
||||
# reload time, or ``None`` if Redis was unreachable
|
||||
# * ``loaded_at_monotonic`` — for TTL bookkeeping
|
||||
# Populated lazily, evicted by TTL + cross-process
|
||||
# invalidation (see ``invalidate_user``).
|
||||
self._user_models: Dict[
|
||||
str,
|
||||
Tuple[Dict[str, AvailableModel], Optional[int], float],
|
||||
] = {}
|
||||
self._load_models()
|
||||
ModelRegistry._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "ModelRegistry":
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def reset(cls) -> None:
|
||||
"""Clear the singleton. Intended for test fixtures."""
|
||||
cls._instance = None
|
||||
cls._initialized = False
|
||||
|
||||
@classmethod
|
||||
def invalidate_user(cls, user_id: str) -> None:
|
||||
"""Drop the cached per-user model layer for ``user_id``.
|
||||
|
||||
Called by the BYOM REST routes after every create/update/delete.
|
||||
Two effects:
|
||||
|
||||
* Local: pop the entry from this process's cache so the next
|
||||
lookup re-reads from Postgres immediately.
|
||||
* Cross-process: ``INCR`` a Redis-side version counter for this
|
||||
user. Other gunicorn/Celery processes notice the counter
|
||||
changed on their next TTL-driven recheck (see
|
||||
``_user_models_for``) and reload. If Redis is unreachable we
|
||||
log and continue — local invalidation still happened, and
|
||||
peers fall back to TTL-only staleness bounds.
|
||||
"""
|
||||
if cls._instance is not None:
|
||||
cls._instance._user_models.pop(user_id, None)
|
||||
try:
|
||||
from application.cache import get_redis_instance
|
||||
|
||||
client = get_redis_instance()
|
||||
if client is not None:
|
||||
client.incr(_USER_VERSION_KEY_PREFIX + user_id)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"BYOM invalidate: failed to publish version bump for "
|
||||
"user %s (Redis unreachable?): %s",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _read_user_version(cls, user_id: str) -> Optional[int]:
|
||||
"""Return the Redis-side invalidation counter for ``user_id``.
|
||||
|
||||
``0`` if the key has never been bumped; ``None`` if Redis is
|
||||
unreachable or the read failed (callers fall back to TTL-only
|
||||
staleness in that case).
|
||||
"""
|
||||
try:
|
||||
from application.cache import get_redis_instance
|
||||
|
||||
client = get_redis_instance()
|
||||
if client is None:
|
||||
return None
|
||||
raw = client.get(_USER_VERSION_KEY_PREFIX + user_id)
|
||||
if raw is None:
|
||||
return 0
|
||||
return int(raw)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _load_models(self) -> None:
|
||||
from pathlib import Path
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.providers import ALL_PROVIDERS
|
||||
|
||||
directories = [BUILTIN_MODELS_DIR]
|
||||
operator_dir = getattr(settings, "MODELS_CONFIG_DIR", None)
|
||||
if operator_dir:
|
||||
op_path = Path(operator_dir)
|
||||
if not op_path.exists():
|
||||
logger.warning(
|
||||
"MODELS_CONFIG_DIR=%s does not exist; no operator "
|
||||
"model YAMLs will be loaded.",
|
||||
operator_dir,
|
||||
)
|
||||
elif not op_path.is_dir():
|
||||
logger.warning(
|
||||
"MODELS_CONFIG_DIR=%s is not a directory; no operator "
|
||||
"model YAMLs will be loaded.",
|
||||
operator_dir,
|
||||
)
|
||||
else:
|
||||
directories.append(op_path)
|
||||
|
||||
catalogs = load_model_yamls(directories)
|
||||
|
||||
# Validate every catalog targets a known plugin before doing any
|
||||
# registry work, so an unknown provider name in YAML aborts boot
|
||||
# with a clear error.
|
||||
plugin_names = {p.name for p in ALL_PROVIDERS}
|
||||
for c in catalogs:
|
||||
if c.provider not in plugin_names:
|
||||
raise ValueError(
|
||||
f"{c.source_path}: YAML declares unknown provider "
|
||||
f"{c.provider!r}; no Provider plugin is registered "
|
||||
f"under that name. Known: {sorted(plugin_names)}"
|
||||
)
|
||||
|
||||
catalogs_by_provider: Dict[str, List[ProviderCatalog]] = defaultdict(list)
|
||||
for c in catalogs:
|
||||
catalogs_by_provider[c.provider].append(c)
|
||||
|
||||
self.models.clear()
|
||||
for provider in ALL_PROVIDERS:
|
||||
if not provider.is_enabled(settings):
|
||||
continue
|
||||
for model in provider.get_models(
|
||||
settings, catalogs_by_provider.get(provider.name, [])
|
||||
):
|
||||
self.models[model.id] = model
|
||||
|
||||
self.default_model_id = self._resolve_default(settings)
|
||||
|
||||
logger.info(
|
||||
"ModelRegistry loaded %d models, default: %s",
|
||||
len(self.models),
|
||||
self.default_model_id,
|
||||
)
|
||||
|
||||
def _resolve_default(self, settings) -> Optional[str]:
|
||||
if settings.LLM_NAME:
|
||||
for name in self._parse_model_names(settings.LLM_NAME):
|
||||
if name in self.models:
|
||||
return name
|
||||
if settings.LLM_NAME in self.models:
|
||||
return settings.LLM_NAME
|
||||
|
||||
if settings.LLM_PROVIDER and settings.API_KEY:
|
||||
for model_id, model in self.models.items():
|
||||
if model.provider.value == settings.LLM_PROVIDER:
|
||||
return model_id
|
||||
|
||||
if self.models:
|
||||
return next(iter(self.models.keys()))
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _parse_model_names(llm_name: str) -> List[str]:
|
||||
if not llm_name:
|
||||
return []
|
||||
return [name.strip() for name in llm_name.split(",") if name.strip()]
|
||||
|
||||
# Per-user (BYOM) layer
|
||||
|
||||
def _user_models_for(self, user_id: str) -> Dict[str, AvailableModel]:
|
||||
"""Return the user's BYOM models keyed by registry id (UUID).
|
||||
|
||||
Loaded lazily from Postgres on first access; cached subject to
|
||||
a per-process TTL (``_USER_CACHE_TTL_SECONDS``) and a Redis-
|
||||
backed version counter for cross-process invalidation. The TTL
|
||||
bounds staleness even when Redis is unreachable, while the
|
||||
version stamp lets peers refresh without a DB read on the
|
||||
common case (no invalidation since last load). Decryption
|
||||
failures and DB errors yield an empty layer (logged) — the
|
||||
user simply doesn't see their custom models on this request,
|
||||
never a 500.
|
||||
"""
|
||||
cached = self._user_models.get(user_id)
|
||||
now = time.monotonic()
|
||||
|
||||
if cached is not None:
|
||||
layer, cached_version, loaded_at = cached
|
||||
if (now - loaded_at) < _USER_CACHE_TTL_SECONDS:
|
||||
return layer
|
||||
# TTL elapsed: peek at the cross-process counter. If it
|
||||
# matches what we saw at load time, no invalidation has
|
||||
# happened — extend the TTL without touching Postgres. If
|
||||
# Redis is unreachable (``current_version is None``) we
|
||||
# fall through to a real reload, which keeps staleness
|
||||
# bounded to the TTL.
|
||||
current_version = self._read_user_version(user_id)
|
||||
if (
|
||||
current_version is not None
|
||||
and cached_version is not None
|
||||
and current_version == cached_version
|
||||
):
|
||||
self._user_models[user_id] = (layer, cached_version, now)
|
||||
return layer
|
||||
|
||||
# Capture the counter *before* the DB read so a CRUD that lands
|
||||
# mid-reload doesn't get masked: the next access will see a
|
||||
# newer version and reload again.
|
||||
version_before_read = self._read_user_version(user_id)
|
||||
|
||||
layer: Dict[str, AvailableModel] = {}
|
||||
try:
|
||||
from application.core.model_settings import (
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
)
|
||||
from application.storage.db.repositories.user_custom_models import (
|
||||
UserCustomModelsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_readonly
|
||||
|
||||
with db_readonly() as conn:
|
||||
repo = UserCustomModelsRepository(conn)
|
||||
rows = repo.list_for_user(user_id)
|
||||
for row in rows:
|
||||
api_key = repo._decrypt_api_key(
|
||||
row.get("api_key_encrypted", ""), user_id
|
||||
)
|
||||
if not api_key:
|
||||
# SECURITY: do NOT register an unroutable BYOM
|
||||
# record. If we did, LLMCreator would fall back
|
||||
# to the caller-passed api_key (settings.API_KEY
|
||||
# for openai_compatible) and POST it to the
|
||||
# user-supplied base_url — leaking the instance
|
||||
# credential to the user's chosen endpoint.
|
||||
# Most likely cause is ENCRYPTION_SECRET_KEY
|
||||
# having rotated; user must re-save the model.
|
||||
logger.warning(
|
||||
"user_custom_models: skipping model %s for "
|
||||
"user %s — api_key could not be decrypted "
|
||||
"(rotated ENCRYPTION_SECRET_KEY?). Re-save "
|
||||
"the model to recover.",
|
||||
row.get("id"),
|
||||
user_id,
|
||||
)
|
||||
continue
|
||||
caps_raw = row.get("capabilities") or {}
|
||||
# Stored attachments may be aliases (``image``) or
|
||||
# raw MIME types. Built-in YAML models expand at
|
||||
# load time; mirror that here so downstream MIME-
|
||||
# type comparisons (handlers/base.prepare_messages)
|
||||
# match concrete types like ``image/png`` rather
|
||||
# than the bare alias.
|
||||
from application.core.model_yaml import (
|
||||
expand_attachments_lenient,
|
||||
)
|
||||
|
||||
raw_attachments = caps_raw.get("attachments", []) or []
|
||||
expanded_attachments = expand_attachments_lenient(
|
||||
raw_attachments,
|
||||
f"user_custom_models[user={user_id}, model={row.get('id')}]",
|
||||
)
|
||||
caps = ModelCapabilities(
|
||||
supports_tools=bool(caps_raw.get("supports_tools", False)),
|
||||
supports_structured_output=bool(
|
||||
caps_raw.get("supports_structured_output", False)
|
||||
),
|
||||
supports_streaming=bool(
|
||||
caps_raw.get("supports_streaming", True)
|
||||
),
|
||||
supported_attachment_types=expanded_attachments,
|
||||
context_window=int(
|
||||
caps_raw.get("context_window") or 128000
|
||||
),
|
||||
)
|
||||
model_id = str(row["id"])
|
||||
layer[model_id] = AvailableModel(
|
||||
id=model_id,
|
||||
provider=ModelProvider.OPENAI_COMPATIBLE,
|
||||
display_name=row["display_name"],
|
||||
description=row.get("description") or "",
|
||||
capabilities=caps,
|
||||
enabled=bool(row.get("enabled", True)),
|
||||
base_url=row["base_url"],
|
||||
upstream_model_id=row["upstream_model_id"],
|
||||
source="user",
|
||||
api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"user_custom_models: failed to load layer for user %s: %s",
|
||||
user_id,
|
||||
e,
|
||||
)
|
||||
layer = {}
|
||||
|
||||
self._user_models[user_id] = (layer, version_before_read, now)
|
||||
return layer
|
||||
|
||||
# Lookup API. ``user_id`` enables the BYOM per-user layer; without
|
||||
# it, callers see only the built-in + operator catalog.
|
||||
|
||||
def get_model(
|
||||
self, model_id: str, user_id: Optional[str] = None
|
||||
) -> Optional[AvailableModel]:
|
||||
if user_id:
|
||||
user_layer = self._user_models_for(user_id)
|
||||
if model_id in user_layer:
|
||||
return user_layer[model_id]
|
||||
return self.models.get(model_id)
|
||||
|
||||
def get_all_models(
|
||||
self, user_id: Optional[str] = None
|
||||
) -> List[AvailableModel]:
|
||||
out = list(self.models.values())
|
||||
if user_id:
|
||||
out.extend(self._user_models_for(user_id).values())
|
||||
return out
|
||||
|
||||
def get_enabled_models(
|
||||
self, user_id: Optional[str] = None
|
||||
) -> List[AvailableModel]:
|
||||
out = [m for m in self.models.values() if m.enabled]
|
||||
if user_id:
|
||||
out.extend(
|
||||
m for m in self._user_models_for(user_id).values() if m.enabled
|
||||
)
|
||||
return out
|
||||
|
||||
def model_exists(
|
||||
self, model_id: str, user_id: Optional[str] = None
|
||||
) -> bool:
|
||||
if user_id and model_id in self._user_models_for(user_id):
|
||||
return True
|
||||
return model_id in self.models
|
||||
@@ -5,9 +5,16 @@ from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Re-exported here so existing call sites (and tests) that do
|
||||
# ``from application.core.model_settings import ModelRegistry`` keep
|
||||
# working. The implementation lives in ``application/core/model_registry.py``.
|
||||
# Imported lazily inside ``__getattr__`` to avoid an import cycle with
|
||||
# ``model_yaml`` → ``model_settings`` (this file).
|
||||
|
||||
|
||||
class ModelProvider(str, Enum):
|
||||
OPENAI = "openai"
|
||||
OPENAI_COMPATIBLE = "openai_compatible"
|
||||
OPENROUTER = "openrouter"
|
||||
AZURE_OPENAI = "azure_openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
@@ -41,11 +48,21 @@ class AvailableModel:
|
||||
capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
|
||||
enabled: bool = True
|
||||
base_url: Optional[str] = None
|
||||
# User-facing label distinct from dispatch provider (e.g. mistral
|
||||
# routed through openai_compatible).
|
||||
display_provider: Optional[str] = None
|
||||
# Sent in the API call's ``model`` field; falls back to ``self.id``
|
||||
# for built-ins where id IS the upstream name.
|
||||
upstream_model_id: Optional[str] = None
|
||||
# "builtin" for catalog YAMLs, "user" for BYOM records.
|
||||
source: str = "builtin"
|
||||
# Decrypted/resolved at registry-merge time. Never serialized.
|
||||
api_key: Optional[str] = field(default=None, repr=False, compare=False)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
result = {
|
||||
"id": self.id,
|
||||
"provider": self.provider.value,
|
||||
"provider": self.display_provider or self.provider.value,
|
||||
"display_name": self.display_name,
|
||||
"description": self.description,
|
||||
"supported_attachment_types": self.capabilities.supported_attachment_types,
|
||||
@@ -54,261 +71,21 @@ class AvailableModel:
|
||||
"supports_streaming": self.capabilities.supports_streaming,
|
||||
"context_window": self.capabilities.context_window,
|
||||
"enabled": self.enabled,
|
||||
"source": self.source,
|
||||
}
|
||||
if self.base_url:
|
||||
result["base_url"] = self.base_url
|
||||
return result
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
_instance = None
|
||||
_initialized = False
|
||||
def __getattr__(name):
|
||||
"""Lazy re-export of ``ModelRegistry`` from ``model_registry.py``.
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
Done lazily to avoid an import cycle: ``model_registry`` imports
|
||||
``model_yaml`` which imports the dataclasses from this file.
|
||||
"""
|
||||
if name == "ModelRegistry":
|
||||
from application.core.model_registry import ModelRegistry as _MR
|
||||
|
||||
def __init__(self):
|
||||
if not ModelRegistry._initialized:
|
||||
self.models: Dict[str, AvailableModel] = {}
|
||||
self.default_model_id: Optional[str] = None
|
||||
self._load_models()
|
||||
ModelRegistry._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "ModelRegistry":
|
||||
return cls()
|
||||
|
||||
def _load_models(self):
|
||||
from application.core.settings import settings
|
||||
|
||||
self.models.clear()
|
||||
|
||||
# Skip DocsGPT model if using custom OpenAI-compatible endpoint
|
||||
if not settings.OPENAI_BASE_URL:
|
||||
self._add_docsgpt_models(settings)
|
||||
if (
|
||||
settings.OPENAI_API_KEY
|
||||
or (settings.LLM_PROVIDER == "openai" and settings.API_KEY)
|
||||
or settings.OPENAI_BASE_URL
|
||||
):
|
||||
self._add_openai_models(settings)
|
||||
if settings.OPENAI_API_BASE or (
|
||||
settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
|
||||
):
|
||||
self._add_azure_openai_models(settings)
|
||||
if settings.ANTHROPIC_API_KEY or (
|
||||
settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
|
||||
):
|
||||
self._add_anthropic_models(settings)
|
||||
if settings.GOOGLE_API_KEY or (
|
||||
settings.LLM_PROVIDER == "google" and settings.API_KEY
|
||||
):
|
||||
self._add_google_models(settings)
|
||||
if settings.GROQ_API_KEY or (
|
||||
settings.LLM_PROVIDER == "groq" and settings.API_KEY
|
||||
):
|
||||
self._add_groq_models(settings)
|
||||
if settings.OPEN_ROUTER_API_KEY or (
|
||||
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
|
||||
):
|
||||
self._add_openrouter_models(settings)
|
||||
if settings.NOVITA_API_KEY or (
|
||||
settings.LLM_PROVIDER == "novita" and settings.API_KEY
|
||||
):
|
||||
self._add_novita_models(settings)
|
||||
if settings.HUGGINGFACE_API_KEY or (
|
||||
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
|
||||
):
|
||||
self._add_huggingface_models(settings)
|
||||
# Default model selection
|
||||
if settings.LLM_NAME:
|
||||
# Parse LLM_NAME (may be comma-separated)
|
||||
model_names = self._parse_model_names(settings.LLM_NAME)
|
||||
# First model in the list becomes default
|
||||
for model_name in model_names:
|
||||
if model_name in self.models:
|
||||
self.default_model_id = model_name
|
||||
break
|
||||
# Backward compat: try exact match if no parsed model found
|
||||
if not self.default_model_id and settings.LLM_NAME in self.models:
|
||||
self.default_model_id = settings.LLM_NAME
|
||||
|
||||
if not self.default_model_id:
|
||||
if settings.LLM_PROVIDER and settings.API_KEY:
|
||||
for model_id, model in self.models.items():
|
||||
if model.provider.value == settings.LLM_PROVIDER:
|
||||
self.default_model_id = model_id
|
||||
break
|
||||
|
||||
if not self.default_model_id and self.models:
|
||||
self.default_model_id = next(iter(self.models.keys()))
|
||||
logger.info(
|
||||
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
|
||||
)
|
||||
|
||||
def _add_openai_models(self, settings):
|
||||
from application.core.model_configs import (
|
||||
OPENAI_MODELS,
|
||||
create_custom_openai_model,
|
||||
)
|
||||
|
||||
# Check if using local OpenAI-compatible endpoint (Ollama, LM Studio, etc.)
|
||||
using_local_endpoint = bool(
|
||||
settings.OPENAI_BASE_URL and settings.OPENAI_BASE_URL.strip()
|
||||
)
|
||||
|
||||
if using_local_endpoint:
|
||||
# When OPENAI_BASE_URL is set, ONLY register custom models from LLM_NAME
|
||||
# Do NOT add standard OpenAI models (gpt-5.1, etc.)
|
||||
if settings.LLM_NAME:
|
||||
model_names = self._parse_model_names(settings.LLM_NAME)
|
||||
for model_name in model_names:
|
||||
custom_model = create_custom_openai_model(
|
||||
model_name, settings.OPENAI_BASE_URL
|
||||
)
|
||||
self.models[model_name] = custom_model
|
||||
logger.info(
|
||||
f"Registered custom OpenAI model: {model_name} at {settings.OPENAI_BASE_URL}"
|
||||
)
|
||||
else:
|
||||
# Standard OpenAI API usage - add standard models if API key is valid
|
||||
if settings.OPENAI_API_KEY:
|
||||
for model in OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_azure_openai_models(self, settings):
|
||||
from application.core.model_configs import AZURE_OPENAI_MODELS
|
||||
|
||||
if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
|
||||
for model in AZURE_OPENAI_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in AZURE_OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_anthropic_models(self, settings):
|
||||
from application.core.model_configs import ANTHROPIC_MODELS
|
||||
|
||||
if settings.ANTHROPIC_API_KEY:
|
||||
for model in ANTHROPIC_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
|
||||
for model in ANTHROPIC_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in ANTHROPIC_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_google_models(self, settings):
|
||||
from application.core.model_configs import GOOGLE_MODELS
|
||||
|
||||
if settings.GOOGLE_API_KEY:
|
||||
for model in GOOGLE_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
|
||||
for model in GOOGLE_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in GOOGLE_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_groq_models(self, settings):
|
||||
from application.core.model_configs import GROQ_MODELS
|
||||
|
||||
if settings.GROQ_API_KEY:
|
||||
for model in GROQ_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
|
||||
for model in GROQ_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in GROQ_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_openrouter_models(self, settings):
|
||||
from application.core.model_configs import OPENROUTER_MODELS
|
||||
|
||||
if settings.OPEN_ROUTER_API_KEY:
|
||||
for model in OPENROUTER_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "openrouter" and settings.LLM_NAME:
|
||||
for model in OPENROUTER_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in OPENROUTER_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_novita_models(self, settings):
|
||||
from application.core.model_configs import NOVITA_MODELS
|
||||
|
||||
if settings.NOVITA_API_KEY:
|
||||
for model in NOVITA_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "novita" and settings.LLM_NAME:
|
||||
for model in NOVITA_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in NOVITA_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_docsgpt_models(self, settings):
|
||||
model_id = "docsgpt-local"
|
||||
model = AvailableModel(
|
||||
id=model_id,
|
||||
provider=ModelProvider.DOCSGPT,
|
||||
display_name="DocsGPT Model",
|
||||
description="Local model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=False,
|
||||
supported_attachment_types=[],
|
||||
),
|
||||
)
|
||||
self.models[model_id] = model
|
||||
|
||||
def _add_huggingface_models(self, settings):
|
||||
model_id = "huggingface-local"
|
||||
model = AvailableModel(
|
||||
id=model_id,
|
||||
provider=ModelProvider.HUGGINGFACE,
|
||||
display_name="Hugging Face Model",
|
||||
description="Local Hugging Face model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=False,
|
||||
supported_attachment_types=[],
|
||||
),
|
||||
)
|
||||
self.models[model_id] = model
|
||||
|
||||
def _parse_model_names(self, llm_name: str) -> List[str]:
|
||||
"""
|
||||
Parse LLM_NAME which may contain comma-separated model names.
|
||||
E.g., 'deepseek-r1:1.5b,gemma:2b' -> ['deepseek-r1:1.5b', 'gemma:2b']
|
||||
"""
|
||||
if not llm_name:
|
||||
return []
|
||||
return [name.strip() for name in llm_name.split(",") if name.strip()]
|
||||
|
||||
def get_model(self, model_id: str) -> Optional[AvailableModel]:
|
||||
return self.models.get(model_id)
|
||||
|
||||
def get_all_models(self) -> List[AvailableModel]:
|
||||
return list(self.models.values())
|
||||
|
||||
def get_enabled_models(self) -> List[AvailableModel]:
|
||||
return [m for m in self.models.values() if m.enabled]
|
||||
|
||||
def model_exists(self, model_id: str) -> bool:
|
||||
return model_id in self.models
|
||||
return _MR
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -1,47 +1,59 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.core.model_settings import ModelRegistry
|
||||
from application.core.model_registry import ModelRegistry
|
||||
|
||||
|
||||
def get_api_key_for_provider(provider: str) -> Optional[str]:
|
||||
"""Get the appropriate API key for a provider"""
|
||||
"""Get the appropriate API key for a provider.
|
||||
|
||||
Delegates to the provider plugin's ``get_api_key``. Falls back to the
|
||||
generic ``settings.API_KEY`` for unknown providers.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
from application.llm.providers import PROVIDERS_BY_NAME
|
||||
|
||||
provider_key_map = {
|
||||
"openai": settings.OPENAI_API_KEY,
|
||||
"openrouter": settings.OPEN_ROUTER_API_KEY,
|
||||
"novita": settings.NOVITA_API_KEY,
|
||||
"anthropic": settings.ANTHROPIC_API_KEY,
|
||||
"google": settings.GOOGLE_API_KEY,
|
||||
"groq": settings.GROQ_API_KEY,
|
||||
"huggingface": settings.HUGGINGFACE_API_KEY,
|
||||
"azure_openai": settings.API_KEY,
|
||||
"docsgpt": None,
|
||||
"llama.cpp": None,
|
||||
}
|
||||
|
||||
provider_key = provider_key_map.get(provider)
|
||||
if provider_key:
|
||||
return provider_key
|
||||
plugin = PROVIDERS_BY_NAME.get(provider)
|
||||
if plugin is not None:
|
||||
key = plugin.get_api_key(settings)
|
||||
if key:
|
||||
return key
|
||||
return settings.API_KEY
|
||||
|
||||
|
||||
def get_all_available_models() -> Dict[str, Dict[str, Any]]:
|
||||
"""Get all available models with metadata for API response"""
|
||||
def get_all_available_models(
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get all available models with metadata for API response.
|
||||
|
||||
When ``user_id`` is supplied, the user's BYOM custom-model records
|
||||
are merged into the result alongside the built-in catalog.
|
||||
"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
return {model.id: model.to_dict() for model in registry.get_enabled_models()}
|
||||
return {
|
||||
model.id: model.to_dict()
|
||||
for model in registry.get_enabled_models(user_id=user_id)
|
||||
}
|
||||
|
||||
|
||||
def validate_model_id(model_id: str) -> bool:
|
||||
"""Check if a model ID exists in registry"""
|
||||
def validate_model_id(model_id: str, user_id: Optional[str] = None) -> bool:
|
||||
"""Check if a model ID exists in registry.
|
||||
|
||||
``user_id`` enables resolution of per-user BYOM records (UUIDs).
|
||||
Without it, only built-in catalog ids resolve.
|
||||
"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
return registry.model_exists(model_id)
|
||||
return registry.model_exists(model_id, user_id=user_id)
|
||||
|
||||
|
||||
def get_model_capabilities(model_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get capabilities for a specific model"""
|
||||
def get_model_capabilities(
|
||||
model_id: str, user_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get capabilities for a specific model.
|
||||
|
||||
``user_id`` enables resolution of per-user BYOM records.
|
||||
"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
model = registry.get_model(model_id, user_id=user_id)
|
||||
if model:
|
||||
return {
|
||||
"supported_attachment_types": model.capabilities.supported_attachment_types,
|
||||
@@ -58,36 +70,68 @@ def get_default_model_id() -> str:
|
||||
return registry.default_model_id
|
||||
|
||||
|
||||
def get_provider_from_model_id(model_id: str) -> Optional[str]:
|
||||
"""Get the provider name for a given model_id"""
|
||||
def get_provider_from_model_id(
|
||||
model_id: str, user_id: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""Get the provider name for a given model_id.
|
||||
|
||||
``user_id`` enables resolution of per-user BYOM records (UUIDs).
|
||||
Without it, BYOM model ids return ``None`` and the caller falls
|
||||
back to the deployment default.
|
||||
"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
model = registry.get_model(model_id, user_id=user_id)
|
||||
if model:
|
||||
return model.provider.value
|
||||
return None
|
||||
|
||||
|
||||
def get_token_limit(model_id: str) -> int:
|
||||
"""
|
||||
Get context window (token limit) for a model.
|
||||
Returns model's context_window or default 128000 if model not found.
|
||||
def get_token_limit(model_id: str, user_id: Optional[str] = None) -> int:
|
||||
"""Get context window (token limit) for a model.
|
||||
|
||||
Returns the model's ``context_window`` or ``DEFAULT_LLM_TOKEN_LIMIT``
|
||||
if not found. ``user_id`` enables resolution of per-user BYOM records.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
model = registry.get_model(model_id, user_id=user_id)
|
||||
if model:
|
||||
return model.capabilities.context_window
|
||||
return settings.DEFAULT_LLM_TOKEN_LIMIT
|
||||
|
||||
|
||||
def get_base_url_for_model(model_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get the custom base_url for a specific model if configured.
|
||||
Returns None if no custom base_url is set.
|
||||
def get_base_url_for_model(
|
||||
model_id: str, user_id: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""Get the custom base_url for a specific model if configured.
|
||||
|
||||
Returns ``None`` if no custom base_url is set. ``user_id`` enables
|
||||
resolution of per-user BYOM records.
|
||||
"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
model = registry.get_model(model_id, user_id=user_id)
|
||||
if model:
|
||||
return model.base_url
|
||||
return None
|
||||
|
||||
|
||||
def get_api_key_for_model(
|
||||
model_id: str, user_id: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""Resolve the API key to use when invoking ``model_id``.
|
||||
|
||||
Priority:
|
||||
1. The model record's own ``api_key`` (BYOM records and
|
||||
``openai_compatible`` YAMLs populate this).
|
||||
2. The provider plugin's settings-based key.
|
||||
|
||||
``user_id`` enables resolution of per-user BYOM records.
|
||||
"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id, user_id=user_id)
|
||||
if model is not None and model.api_key:
|
||||
return model.api_key
|
||||
if model is not None:
|
||||
return get_api_key_for_provider(model.provider.value)
|
||||
return None
|
||||
|
||||
358
application/core/model_yaml.py
Normal file
358
application/core/model_yaml.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""YAML loader for model catalog files under ``application/core/models/``.
|
||||
|
||||
Each ``*.yaml`` file declares one provider's static model catalog. Files
|
||||
are validated with Pydantic at load time; any parse, schema, or alias
|
||||
error aborts startup with the offending file path in the message.
|
||||
|
||||
For most providers, one YAML maps to one catalog. The
|
||||
``openai_compatible`` provider is special: each YAML file represents a
|
||||
distinct logical endpoint (Mistral, Together, Ollama, ...) with its own
|
||||
``api_key_env`` and ``base_url``. The loader returns a flat list so the
|
||||
registry can distinguish multiple files with the same ``provider:`` value.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BUILTIN_MODELS_DIR = Path(__file__).parent / "models"
|
||||
DEFAULTS_FILENAME = "_defaults.yaml"
|
||||
|
||||
|
||||
class _DefaultsFile(BaseModel):
|
||||
"""Schema for ``_defaults.yaml``. Currently just attachment aliases."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
attachment_aliases: Dict[str, List[str]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class _CapabilityFields(BaseModel):
|
||||
"""Capability fields shared between provider ``defaults:`` and per-model overrides.
|
||||
|
||||
All fields are optional so a per-model override can selectively replace
|
||||
a single field from the provider-level defaults.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
supports_tools: Optional[bool] = None
|
||||
supports_structured_output: Optional[bool] = None
|
||||
supports_streaming: Optional[bool] = None
|
||||
attachments: Optional[List[str]] = None
|
||||
context_window: Optional[int] = None
|
||||
input_cost_per_token: Optional[float] = None
|
||||
output_cost_per_token: Optional[float] = None
|
||||
|
||||
|
||||
class _ModelEntry(_CapabilityFields):
|
||||
"""Schema for one model row inside a YAML's ``models:`` list."""
|
||||
|
||||
id: str
|
||||
display_name: Optional[str] = None
|
||||
description: str = ""
|
||||
enabled: bool = True
|
||||
base_url: Optional[str] = None
|
||||
aliases: List[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("id")
|
||||
@classmethod
|
||||
def _id_nonempty(cls, v: str) -> str:
|
||||
if not v or not v.strip():
|
||||
raise ValueError("model id must be a non-empty string")
|
||||
return v
|
||||
|
||||
|
||||
class _ProviderFile(BaseModel):
|
||||
"""Schema for one ``<provider>.yaml`` catalog file."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
provider: str
|
||||
defaults: _CapabilityFields = Field(default_factory=_CapabilityFields)
|
||||
models: List[_ModelEntry] = Field(default_factory=list)
|
||||
# openai_compatible metadata. Optional for other providers.
|
||||
display_provider: Optional[str] = None
|
||||
api_key_env: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
|
||||
|
||||
class ProviderCatalog(BaseModel):
|
||||
"""One YAML file's parsed contents, ready for the registry.
|
||||
|
||||
For most providers, multiple catalogs with the same ``provider`` get
|
||||
merged later by the registry. The ``openai_compatible`` provider is
|
||||
the exception: each catalog is treated as a distinct endpoint, with
|
||||
its own ``api_key_env`` and ``base_url``.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
models: List[AvailableModel]
|
||||
source_path: Optional[Path] = None
|
||||
display_provider: Optional[str] = None
|
||||
api_key_env: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ModelYAMLError(ValueError):
|
||||
"""Raised when a model YAML fails parsing, schema, or alias validation."""
|
||||
|
||||
|
||||
def _expand_attachments(
|
||||
attachments: Sequence[str], aliases: Dict[str, List[str]], source: str
|
||||
) -> List[str]:
|
||||
"""Resolve attachment shorthands (``image``, ``pdf``) to MIME types.
|
||||
|
||||
Raw MIME-typed entries (containing ``/``) pass through unchanged.
|
||||
Unknown aliases raise ``ModelYAMLError``.
|
||||
"""
|
||||
expanded: List[str] = []
|
||||
seen: set = set()
|
||||
for entry in attachments:
|
||||
if "/" in entry:
|
||||
if entry not in seen:
|
||||
expanded.append(entry)
|
||||
seen.add(entry)
|
||||
continue
|
||||
if entry not in aliases:
|
||||
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
|
||||
raise ModelYAMLError(
|
||||
f"{source}: unknown attachment alias '{entry}'. "
|
||||
f"Valid aliases: {valid}. "
|
||||
"(Or use a raw MIME type like 'image/png'.)"
|
||||
)
|
||||
for mime in aliases[entry]:
|
||||
if mime not in seen:
|
||||
expanded.append(mime)
|
||||
seen.add(mime)
|
||||
return expanded
|
||||
|
||||
|
||||
def _load_defaults(directory: Path) -> Dict[str, List[str]]:
|
||||
"""Load ``_defaults.yaml`` from ``directory`` if it exists."""
|
||||
path = directory / DEFAULTS_FILENAME
|
||||
if not path.exists():
|
||||
return {}
|
||||
try:
|
||||
raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as e:
|
||||
raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
|
||||
try:
|
||||
parsed = _DefaultsFile.model_validate(raw)
|
||||
except Exception as e:
|
||||
raise ModelYAMLError(f"{path}: schema error: {e}") from e
|
||||
return parsed.attachment_aliases
|
||||
|
||||
|
||||
def _resolve_provider_enum(name: str, source: Path) -> ModelProvider:
|
||||
try:
|
||||
return ModelProvider(name)
|
||||
except ValueError as e:
|
||||
valid = ", ".join(p.value for p in ModelProvider)
|
||||
raise ModelYAMLError(
|
||||
f"{source}: unknown provider '{name}'. Valid: {valid}"
|
||||
) from e
|
||||
|
||||
|
||||
def _build_model(
|
||||
entry: _ModelEntry,
|
||||
defaults: _CapabilityFields,
|
||||
provider: ModelProvider,
|
||||
aliases: Dict[str, List[str]],
|
||||
source: Path,
|
||||
display_provider: Optional[str] = None,
|
||||
) -> AvailableModel:
|
||||
"""Merge defaults + per-model overrides into a final ``AvailableModel``."""
|
||||
|
||||
def pick(field_name: str, fallback):
|
||||
v = getattr(entry, field_name)
|
||||
if v is not None:
|
||||
return v
|
||||
d = getattr(defaults, field_name)
|
||||
if d is not None:
|
||||
return d
|
||||
return fallback
|
||||
|
||||
raw_attachments = entry.attachments
|
||||
if raw_attachments is None:
|
||||
raw_attachments = defaults.attachments
|
||||
if raw_attachments is None:
|
||||
raw_attachments = []
|
||||
expanded = _expand_attachments(
|
||||
raw_attachments, aliases, f"{source} [model={entry.id}]"
|
||||
)
|
||||
|
||||
caps = ModelCapabilities(
|
||||
supports_tools=pick("supports_tools", False),
|
||||
supports_structured_output=pick("supports_structured_output", False),
|
||||
supports_streaming=pick("supports_streaming", True),
|
||||
supported_attachment_types=expanded,
|
||||
context_window=pick("context_window", 128000),
|
||||
input_cost_per_token=pick("input_cost_per_token", None),
|
||||
output_cost_per_token=pick("output_cost_per_token", None),
|
||||
)
|
||||
|
||||
return AvailableModel(
|
||||
id=entry.id,
|
||||
provider=provider,
|
||||
display_name=entry.display_name or entry.id,
|
||||
description=entry.description,
|
||||
capabilities=caps,
|
||||
enabled=entry.enabled,
|
||||
base_url=entry.base_url,
|
||||
display_provider=display_provider,
|
||||
)
|
||||
|
||||
|
||||
def _load_one_yaml(
|
||||
path: Path, aliases: Dict[str, List[str]]
|
||||
) -> ProviderCatalog:
|
||||
try:
|
||||
raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as e:
|
||||
raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
|
||||
try:
|
||||
parsed = _ProviderFile.model_validate(raw)
|
||||
except Exception as e:
|
||||
raise ModelYAMLError(f"{path}: schema error: {e}") from e
|
||||
|
||||
provider_enum = _resolve_provider_enum(parsed.provider, path)
|
||||
models = [
|
||||
_build_model(
|
||||
entry,
|
||||
parsed.defaults,
|
||||
provider_enum,
|
||||
aliases,
|
||||
path,
|
||||
display_provider=parsed.display_provider,
|
||||
)
|
||||
for entry in parsed.models
|
||||
]
|
||||
|
||||
return ProviderCatalog(
|
||||
provider=parsed.provider,
|
||||
models=models,
|
||||
source_path=path,
|
||||
display_provider=parsed.display_provider,
|
||||
api_key_env=parsed.api_key_env,
|
||||
base_url=parsed.base_url,
|
||||
)
|
||||
|
||||
|
||||
_BUILTIN_ALIASES_CACHE: Optional[Dict[str, List[str]]] = None
|
||||
|
||||
|
||||
def builtin_attachment_aliases() -> Dict[str, List[str]]:
|
||||
"""Return the built-in attachment alias map from ``_defaults.yaml``.
|
||||
|
||||
Cached after first read so repeat calls are cheap.
|
||||
"""
|
||||
global _BUILTIN_ALIASES_CACHE
|
||||
if _BUILTIN_ALIASES_CACHE is None:
|
||||
_BUILTIN_ALIASES_CACHE = _load_defaults(BUILTIN_MODELS_DIR)
|
||||
return _BUILTIN_ALIASES_CACHE
|
||||
|
||||
|
||||
def resolve_attachment_alias(alias: str) -> List[str]:
|
||||
"""Resolve a single attachment alias (e.g. ``"image"``) to its
|
||||
canonical MIME-type list. Raises ``ModelYAMLError`` if unknown.
|
||||
"""
|
||||
aliases = builtin_attachment_aliases()
|
||||
if alias not in aliases:
|
||||
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
|
||||
raise ModelYAMLError(
|
||||
f"Unknown attachment alias '{alias}'. Valid: {valid}"
|
||||
)
|
||||
return list(aliases[alias])
|
||||
|
||||
|
||||
def expand_attachments_lenient(
|
||||
attachments: Sequence[str], source: str
|
||||
) -> List[str]:
|
||||
"""Expand attachment aliases to MIME types, tolerating unknowns.
|
||||
|
||||
Mirrors ``_expand_attachments`` but logs+skips unknown aliases
|
||||
rather than raising. Used for runtime call sites (BYOM registry
|
||||
load) where an operator-side alias-map edit must not drop the
|
||||
entire user's BYOM layer; the strict raise still happens at the
|
||||
API validation boundary.
|
||||
"""
|
||||
aliases = builtin_attachment_aliases()
|
||||
expanded: List[str] = []
|
||||
seen: set = set()
|
||||
for entry in attachments:
|
||||
if "/" in entry:
|
||||
if entry not in seen:
|
||||
expanded.append(entry)
|
||||
seen.add(entry)
|
||||
continue
|
||||
mime_list = aliases.get(entry)
|
||||
if mime_list is None:
|
||||
logger.warning(
|
||||
"%s: skipping unknown attachment alias %r", source, entry,
|
||||
)
|
||||
continue
|
||||
for mime in mime_list:
|
||||
if mime not in seen:
|
||||
expanded.append(mime)
|
||||
seen.add(mime)
|
||||
return expanded
|
||||
|
||||
|
||||
def load_model_yamls(directories: Sequence[Path]) -> List[ProviderCatalog]:
|
||||
"""Load every ``*.yaml`` file (excluding ``_defaults.yaml``) under each
|
||||
directory in order and return a flat list of catalogs.
|
||||
|
||||
Caller is responsible for merging multiple catalogs that target the
|
||||
same provider plugin. The flat-list shape lets ``openai_compatible``
|
||||
keep each file separate (one logical endpoint per file).
|
||||
|
||||
When the same model ``id`` appears in more than one YAML across the
|
||||
directory list, a warning is logged. Order in the returned list
|
||||
preserves load order, so the registry's "later wins" merge gives the
|
||||
later directory's definition.
|
||||
"""
|
||||
catalogs: List[ProviderCatalog] = []
|
||||
seen_ids: Dict[str, Path] = {}
|
||||
|
||||
aliases: Dict[str, List[str]] = {}
|
||||
for d in directories:
|
||||
if not d or not d.exists():
|
||||
continue
|
||||
aliases.update(_load_defaults(d))
|
||||
|
||||
for d in directories:
|
||||
if not d or not d.exists():
|
||||
continue
|
||||
for path in sorted(d.glob("*.yaml")):
|
||||
if path.name == DEFAULTS_FILENAME:
|
||||
continue
|
||||
catalog = _load_one_yaml(path, aliases)
|
||||
catalogs.append(catalog)
|
||||
for m in catalog.models:
|
||||
prior = seen_ids.get(m.id)
|
||||
if prior is not None and prior != path:
|
||||
logger.warning(
|
||||
"Model id %r redefined: %s overrides %s (later wins)",
|
||||
m.id,
|
||||
path,
|
||||
prior,
|
||||
)
|
||||
seen_ids[m.id] = path
|
||||
|
||||
return catalogs
|
||||
213
application/core/models/README.md
Normal file
213
application/core/models/README.md
Normal file
@@ -0,0 +1,213 @@
|
||||
# Model catalogs
|
||||
|
||||
Each `*.yaml` file in this directory declares one provider's model
|
||||
catalog. The registry loads every YAML at boot and joins it to the
|
||||
matching provider plugin under `application/llm/providers/`.
|
||||
|
||||
To add or edit models, you almost always only touch a YAML here — no
|
||||
Python code required.
|
||||
|
||||
## Add a model to an existing provider
|
||||
|
||||
Open the provider's YAML (e.g. `anthropic.yaml`) and append two lines
|
||||
under `models:`:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
- id: claude-3-7-sonnet
|
||||
display_name: Claude 3.7 Sonnet
|
||||
```
|
||||
|
||||
Capabilities default to the provider's `defaults:` block. Override
|
||||
per-model only when needed:
|
||||
|
||||
```yaml
|
||||
- id: claude-3-7-sonnet
|
||||
display_name: Claude 3.7 Sonnet
|
||||
context_window: 500000
|
||||
```
|
||||
|
||||
Restart the app. The new model appears in `/api/models`.
|
||||
|
||||
> The model `id` is what gets stored in agent / workflow records. Once
|
||||
> users start picking the model, **don't rename it** — agent and
|
||||
> workflow rows reference it as a free-form string and silently fall
|
||||
> back to the system default if the id disappears.
|
||||
|
||||
## Add an OpenAI-compatible provider (zero Python)
|
||||
|
||||
Drop a YAML in this directory (or in your `MODELS_CONFIG_DIR`) that uses
|
||||
the `openai_compatible` plugin. Set the env var named in `api_key_env`
|
||||
and you're done — no Python, no settings.py edit, no LLMCreator change:
|
||||
|
||||
```yaml
|
||||
# mistral.yaml
|
||||
provider: openai_compatible
|
||||
display_provider: mistral # shown in /api/models response
|
||||
api_key_env: MISTRAL_API_KEY # env var the plugin reads at boot
|
||||
base_url: https://api.mistral.ai/v1
|
||||
defaults:
|
||||
supports_tools: true
|
||||
context_window: 128000
|
||||
models:
|
||||
- id: mistral-large-latest
|
||||
display_name: Mistral Large
|
||||
- id: mistral-small-latest
|
||||
display_name: Mistral Small
|
||||
```
|
||||
|
||||
`MISTRAL_API_KEY=sk-... ; restart` — Mistral models appear in
|
||||
`/api/models` with `provider: "mistral"`. They route through the OpenAI
|
||||
wire format (it's `OpenAILLM` under the hood) but with Mistral's
|
||||
endpoint and key.
|
||||
|
||||
Multiple `openai_compatible` YAMLs coexist: each file is one logical
|
||||
endpoint with its own `api_key_env` and `base_url`. Drop in
|
||||
`together.yaml`, `fireworks.yaml`, etc. side by side. If an env var
|
||||
isn't set, that catalog is silently skipped at boot (logged at INFO) —
|
||||
no error.
|
||||
|
||||
Working example: `examples/mistral.yaml.example`. Files inside
|
||||
`examples/` aren't loaded by the registry; the glob only picks up
|
||||
`*.yaml` at the top level.
|
||||
|
||||
## Add a provider with its own SDK
|
||||
|
||||
For a provider that doesn't speak OpenAI's wire format, add one Python
|
||||
file to `application/llm/providers/<name>.py`:
|
||||
|
||||
```python
|
||||
from application.llm.providers.base import Provider
|
||||
from application.llm.my_provider import MyLLM
|
||||
|
||||
class MyProvider(Provider):
|
||||
name = "my_provider"
|
||||
llm_class = MyLLM
|
||||
|
||||
def get_api_key(self, settings):
|
||||
return settings.MY_PROVIDER_API_KEY
|
||||
```
|
||||
|
||||
Register it in `application/llm/providers/__init__.py` (one line in
|
||||
`ALL_PROVIDERS`), add `MY_PROVIDER_API_KEY` to `settings.py`, and create
|
||||
`my_provider.yaml` here with the model catalog.
|
||||
|
||||
## Schema reference
|
||||
|
||||
```yaml
|
||||
provider: <string, required> # matches the Provider plugin's `name`
|
||||
|
||||
# openai_compatible only — required for that provider, ignored for others
|
||||
display_provider: <string> # label shown in /api/models response
|
||||
api_key_env: <string> # name of the env var carrying the key
|
||||
base_url: <string> # endpoint URL
|
||||
|
||||
defaults: # optional, applied to every model below
|
||||
supports_tools: bool # default false
|
||||
supports_structured_output: bool # default false
|
||||
supports_streaming: bool # default true
|
||||
attachments: [<alias-or-mime>, ...] # default []
|
||||
context_window: int # default 128000
|
||||
input_cost_per_token: float # default null
|
||||
output_cost_per_token: float # default null
|
||||
|
||||
models: # required
|
||||
- id: <string, required> # the value persisted in agent records
|
||||
display_name: <string> # default: id
|
||||
description: <string> # default: ""
|
||||
enabled: bool # default true; false hides from /api/models
|
||||
base_url: <string> # optional custom endpoint for this model
|
||||
# All `defaults:` fields above can be overridden here per-model.
|
||||
```
|
||||
|
||||
### Attachment aliases
|
||||
|
||||
The `attachments:` list can mix human-readable aliases with raw MIME
|
||||
types. Aliases are defined in `_defaults.yaml`:
|
||||
|
||||
| Alias | Expands to |
|
||||
|---|---|
|
||||
| `image` | `image/png`, `image/jpeg`, `image/jpg`, `image/webp`, `image/gif` |
|
||||
| `pdf` | `application/pdf` |
|
||||
| `audio` | `audio/mpeg`, `audio/wav`, `audio/ogg` |
|
||||
|
||||
Use raw MIME types when you need surgical control:
|
||||
|
||||
```yaml
|
||||
attachments: [image/png, image/webp] # only these two
|
||||
```
|
||||
|
||||
## Operator-supplied YAMLs (`MODELS_CONFIG_DIR`)
|
||||
|
||||
Set the `MODELS_CONFIG_DIR` env var (or `.env` entry) to a directory
|
||||
path. Every `*.yaml` in that directory is loaded **after** the built-in
|
||||
catalog under `application/core/models/`. Operators use this to:
|
||||
|
||||
- Add new `openai_compatible` providers (Mistral, Together, Fireworks,
|
||||
Ollama, ...) without forking the repo.
|
||||
- Extend an existing provider's catalog with extra models — append
|
||||
models under `provider: anthropic` and they show up alongside the
|
||||
built-ins.
|
||||
- Override a built-in model's capabilities — declare the same `id`
|
||||
with different fields (e.g. a higher `context_window`). Later wins;
|
||||
the override is logged as a `WARNING` so you can audit it.
|
||||
|
||||
Things you cannot do via `MODELS_CONFIG_DIR`:
|
||||
|
||||
- Add a brand-new non-OpenAI provider — that needs a Python plugin
|
||||
under `application/llm/providers/` (see "Add a provider with its own
|
||||
SDK" above). Operator YAMLs may only target a `provider:` value that
|
||||
already has a registered plugin.
|
||||
|
||||
### Example: Docker
|
||||
|
||||
Mount your model YAMLs into the container and point the env var at the
|
||||
mount path:
|
||||
|
||||
```yaml
|
||||
# docker-compose.yml
|
||||
services:
|
||||
app:
|
||||
image: arc53/docsgpt
|
||||
environment:
|
||||
MODELS_CONFIG_DIR: /etc/docsgpt/models
|
||||
MISTRAL_API_KEY: ${MISTRAL_API_KEY}
|
||||
volumes:
|
||||
- ./my-models:/etc/docsgpt/models:ro
|
||||
```
|
||||
|
||||
Then `./my-models/mistral.yaml` (the file from
|
||||
`examples/mistral.yaml.example`) gets picked up at boot.
|
||||
|
||||
### Example: Kubernetes
|
||||
|
||||
Mount a `ConfigMap` containing your YAMLs at a known path and set
|
||||
`MODELS_CONFIG_DIR` on the deployment. The same `examples/mistral.yaml.example`
|
||||
becomes a key in the ConfigMap.
|
||||
|
||||
### Misconfiguration
|
||||
|
||||
If `MODELS_CONFIG_DIR` is set but the path doesn't exist (or isn't a
|
||||
directory), the app logs a `WARNING` at boot and continues with just
|
||||
the built-in catalog. The app does *not* fail to start — operators can
|
||||
ship config drift without taking down the service — but the warning is
|
||||
loud enough to surface in any reasonable log aggregator.
|
||||
|
||||
## Validation
|
||||
|
||||
YAMLs are parsed with Pydantic at boot. The app fails to start with a
|
||||
clear error message if:
|
||||
|
||||
- a top-level key is unknown
|
||||
- a model is missing `id`
|
||||
- an attachment alias isn't defined
|
||||
- the `provider:` value isn't registered as a plugin
|
||||
|
||||
This is intentional — silent fallbacks would mean users don't notice
|
||||
their model picks broke until they hit the API.
|
||||
|
||||
## Reserved fields (not yet implemented)
|
||||
|
||||
- `aliases:` on a model — old IDs that resolve to this model. Reserved
|
||||
for future renames; the schema accepts the field but it is not yet
|
||||
acted on.
|
||||
18
application/core/models/_defaults.yaml
Normal file
18
application/core/models/_defaults.yaml
Normal file
@@ -0,0 +1,18 @@
|
||||
# Global defaults applied across every model YAML in this directory.
|
||||
# Keep this file sparse — per-provider `defaults:` blocks are clearer
|
||||
# than a deep global default chain. This file is for things that
|
||||
# genuinely never vary, like the meaning of "image".
|
||||
|
||||
attachment_aliases:
|
||||
image:
|
||||
- image/png
|
||||
- image/jpeg
|
||||
- image/jpg
|
||||
- image/webp
|
||||
- image/gif
|
||||
pdf:
|
||||
- application/pdf
|
||||
audio:
|
||||
- audio/mpeg
|
||||
- audio/wav
|
||||
- audio/ogg
|
||||
23
application/core/models/anthropic.yaml
Normal file
23
application/core/models/anthropic.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
provider: anthropic
|
||||
defaults:
|
||||
supports_tools: true
|
||||
attachments: [image]
|
||||
context_window: 200000
|
||||
|
||||
models:
|
||||
- id: claude-opus-4-7
|
||||
display_name: Claude Opus 4.7
|
||||
description: Most capable Claude model for complex reasoning and agentic coding
|
||||
context_window: 1000000
|
||||
supports_structured_output: true
|
||||
|
||||
- id: claude-sonnet-4-6
|
||||
display_name: Claude Sonnet 4.6
|
||||
description: Best balance of speed and intelligence with extended thinking
|
||||
context_window: 1000000
|
||||
supports_structured_output: true
|
||||
|
||||
- id: claude-haiku-4-5
|
||||
display_name: Claude Haiku 4.5
|
||||
description: Fastest Claude model with near-frontier intelligence
|
||||
supports_structured_output: true
|
||||
31
application/core/models/azure_openai.yaml
Normal file
31
application/core/models/azure_openai.yaml
Normal file
@@ -0,0 +1,31 @@
|
||||
# Azure OpenAI catalog.
|
||||
#
|
||||
# IMPORTANT: For Azure OpenAI, the `id` field is the **deployment name**, not
|
||||
# a model name. Deployment names are arbitrary strings the operator chooses
|
||||
# in Azure portal (or via ARM/Bicep/Terraform) when they create a deployment
|
||||
# for a given underlying model + version.
|
||||
#
|
||||
# The IDs below are sensible defaults that mirror the underlying OpenAI
|
||||
# model name (prefixed with `azure-`). Operators almost always need to
|
||||
# override them via `MODELS_CONFIG_DIR` to match the deployment names that
|
||||
# actually exist in their Azure resource. The `display_name`, capability
|
||||
# flags, and `context_window` reflect the underlying OpenAI model.
|
||||
provider: azure_openai
|
||||
|
||||
defaults:
|
||||
supports_tools: true
|
||||
supports_structured_output: true
|
||||
attachments: [image]
|
||||
context_window: 400000
|
||||
|
||||
models:
|
||||
- id: azure-gpt-5.5
|
||||
display_name: Azure OpenAI GPT-5.5
|
||||
description: Azure-hosted flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
|
||||
context_window: 1050000
|
||||
- id: azure-gpt-5.4-mini
|
||||
display_name: Azure OpenAI GPT-5.4 Mini
|
||||
description: Azure-hosted cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
|
||||
- id: azure-gpt-5.4-nano
|
||||
display_name: Azure OpenAI GPT-5.4 Nano
|
||||
description: Azure-hosted cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most
|
||||
7
application/core/models/docsgpt.yaml
Normal file
7
application/core/models/docsgpt.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
provider: docsgpt
|
||||
|
||||
models:
|
||||
- id: docsgpt-local
|
||||
display_name: DocsGPT Model
|
||||
description: Local model
|
||||
supports_tools: false
|
||||
31
application/core/models/examples/mistral.yaml.example
Normal file
31
application/core/models/examples/mistral.yaml.example
Normal file
@@ -0,0 +1,31 @@
|
||||
# EXAMPLE — copy this file to ../mistral.yaml (or to your
|
||||
# MODELS_CONFIG_DIR) and set MISTRAL_API_KEY in your environment.
|
||||
#
|
||||
# This is the entire integration. No Python required: the
|
||||
# `openai_compatible` plugin reads `api_key_env` and `base_url` from
|
||||
# the file and routes calls through the OpenAI wire format.
|
||||
#
|
||||
# Files in this `examples/` directory are NOT loaded by the registry
|
||||
# (the loader globs *.yaml at the top level only).
|
||||
|
||||
provider: openai_compatible
|
||||
display_provider: mistral # shown in /api/models response
|
||||
api_key_env: MISTRAL_API_KEY # env var the plugin reads
|
||||
base_url: https://api.mistral.ai/v1 # OpenAI-compatible endpoint
|
||||
|
||||
defaults:
|
||||
supports_tools: true
|
||||
context_window: 128000
|
||||
|
||||
models:
|
||||
- id: mistral-large-latest
|
||||
display_name: Mistral Large
|
||||
description: Top-tier reasoning model
|
||||
|
||||
- id: mistral-small-latest
|
||||
display_name: Mistral Small
|
||||
description: Fast, cost-efficient
|
||||
|
||||
- id: codestral-latest
|
||||
display_name: Codestral
|
||||
description: Code-specialized model
|
||||
17
application/core/models/google.yaml
Normal file
17
application/core/models/google.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
provider: google
|
||||
defaults:
|
||||
supports_tools: true
|
||||
supports_structured_output: true
|
||||
attachments: [pdf, image]
|
||||
context_window: 1048576
|
||||
|
||||
models:
|
||||
- id: gemini-3.1-pro-preview
|
||||
display_name: Gemini 3.1 Pro
|
||||
description: Most capable Gemini 3 model with advanced reasoning and agentic coding (preview)
|
||||
- id: gemini-3-flash-preview
|
||||
display_name: Gemini 3 Flash
|
||||
description: Frontier-class performance for low-latency, high-volume tasks (preview)
|
||||
- id: gemini-3.1-flash-lite-preview
|
||||
display_name: Gemini 3.1 Flash-Lite
|
||||
description: Cost-efficient frontier-class multimodal model for high-throughput workloads (preview)
|
||||
16
application/core/models/groq.yaml
Normal file
16
application/core/models/groq.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
provider: groq
|
||||
defaults:
|
||||
supports_tools: true
|
||||
context_window: 131072
|
||||
|
||||
models:
|
||||
- id: openai/gpt-oss-120b
|
||||
display_name: GPT-OSS 120B
|
||||
description: OpenAI's open-weight 120B flagship served on Groq's LPU hardware; strong general reasoning with strict structured output support
|
||||
supports_structured_output: true
|
||||
- id: llama-3.3-70b-versatile
|
||||
display_name: Llama 3.3 70B Versatile
|
||||
description: Meta's Llama 3.3 70B for general-purpose chat with parallel tool use
|
||||
- id: llama-3.1-8b-instant
|
||||
display_name: Llama 3.1 8B Instant
|
||||
description: Small, very low-latency Llama model (~560 tok/s) with parallel tool use
|
||||
7
application/core/models/huggingface.yaml
Normal file
7
application/core/models/huggingface.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
provider: huggingface
|
||||
|
||||
models:
|
||||
- id: huggingface-local
|
||||
display_name: Hugging Face Model
|
||||
description: Local Hugging Face model
|
||||
supports_tools: false
|
||||
21
application/core/models/novita.yaml
Normal file
21
application/core/models/novita.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
provider: novita
|
||||
defaults:
|
||||
supports_tools: true
|
||||
supports_structured_output: true
|
||||
|
||||
models:
|
||||
- id: deepseek/deepseek-v4-pro
|
||||
display_name: DeepSeek V4 Pro
|
||||
description: 1.6T MoE (49B active) with 1M context, hybrid CSA/HCA attention, top-tier reasoning and agentic coding
|
||||
context_window: 1048576
|
||||
|
||||
- id: moonshotai/kimi-k2.6
|
||||
display_name: Kimi K2.6
|
||||
description: 1T-parameter open-weight MoE with native vision/video, multi-step tool calling, and agentic long-horizon execution
|
||||
attachments: [image]
|
||||
context_window: 262144
|
||||
|
||||
- id: zai-org/glm-5
|
||||
display_name: GLM-5
|
||||
description: Z.AI 754B-parameter MoE with strong general reasoning, function calling, and structured output
|
||||
context_window: 202800
|
||||
18
application/core/models/openai.yaml
Normal file
18
application/core/models/openai.yaml
Normal file
@@ -0,0 +1,18 @@
|
||||
provider: openai
|
||||
defaults:
|
||||
supports_tools: true
|
||||
supports_structured_output: true
|
||||
attachments: [image]
|
||||
context_window: 400000
|
||||
|
||||
models:
|
||||
- id: gpt-5.5
|
||||
display_name: GPT-5.5
|
||||
description: Flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
|
||||
context_window: 1050000
|
||||
- id: gpt-5.4-mini
|
||||
display_name: GPT-5.4 Mini
|
||||
description: Cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
|
||||
- id: gpt-5.4-nano
|
||||
display_name: GPT-5.4 Nano
|
||||
description: Cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most
|
||||
25
application/core/models/openrouter.yaml
Normal file
25
application/core/models/openrouter.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
provider: openrouter
|
||||
defaults:
|
||||
supports_tools: true
|
||||
attachments: [image]
|
||||
context_window: 128000
|
||||
|
||||
models:
|
||||
- id: qwen/qwen3-coder:free
|
||||
display_name: Qwen3 Coder (free)
|
||||
description: Free-tier 480B MoE coder model with strong agentic tool use; rate-limited
|
||||
context_window: 262000
|
||||
attachments: []
|
||||
|
||||
- id: deepseek/deepseek-v3.2
|
||||
display_name: DeepSeek V3.2
|
||||
description: Open-weights reasoning model, very low cost (~$0.25 in / $0.38 out per 1M)
|
||||
context_window: 131072
|
||||
attachments: []
|
||||
supports_structured_output: true
|
||||
|
||||
- id: anthropic/claude-sonnet-4.6
|
||||
display_name: Claude Sonnet 4.6 (via OpenRouter)
|
||||
description: Frontier Sonnet-class model with 1M context, vision, and extended thinking
|
||||
context_window: 1000000
|
||||
supports_structured_output: true
|
||||
@@ -23,9 +23,19 @@ class Settings(BaseSettings):
|
||||
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
|
||||
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
|
||||
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
|
||||
# Optional directory of operator-supplied model YAMLs, loaded after the
|
||||
# built-in catalog under application/core/models/. Later wins on
|
||||
# duplicate model id. See application/core/models/README.md.
|
||||
MODELS_CONFIG_DIR: Optional[str] = None
|
||||
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
# Prefetch=1 caps SIGKILL loss to one task. Visibility timeout must exceed
|
||||
# the longest legitimate task runtime (ingest, agent webhook) but stay
|
||||
# short enough that SIGKILLed tasks redeliver promptly. 1h matches Onyx
|
||||
# and Dify defaults; long ingests can override via env.
|
||||
CELERY_WORKER_PREFETCH_MULTIPLIER: int = 1
|
||||
CELERY_VISIBILITY_TIMEOUT: int = 3600
|
||||
# Only consulted when VECTOR_STORE=mongodb or when running scripts/db/backfill.py; user data lives in Postgres.
|
||||
MONGO_URI: Optional[str] = None
|
||||
# User-data Postgres DB.
|
||||
@@ -178,6 +188,42 @@ class Settings(BaseSettings):
|
||||
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
|
||||
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
|
||||
|
||||
# Internal SSE push channel (notifications + durable replay journal)
|
||||
# Master switch — when False, /api/events emits a "push_disabled" comment
|
||||
# and returns; clients fall back to polling. Publisher becomes a no-op.
|
||||
ENABLE_SSE_PUSH: bool = True
|
||||
# Per-user durable backlog cap (~entries). At typical event rates this
|
||||
# gives ~24h of replay; tune up for verbose feeds, down for memory.
|
||||
EVENTS_STREAM_MAXLEN: int = 1000
|
||||
# SSE keepalive comment cadence. Must sit under Cloudflare's 100s idle
|
||||
# close and iOS Safari's ~60s — 15s gives generous headroom.
|
||||
SSE_KEEPALIVE_SECONDS: int = 15
|
||||
# Cap on simultaneous SSE connections per user. Each connection holds
|
||||
# one WSGI thread (32 per gunicorn worker) and one Redis pub/sub
|
||||
# connection. 8 covers normal multi-tab use without letting one user
|
||||
# starve the pool. Set to 0 to disable the cap.
|
||||
SSE_MAX_CONCURRENT_PER_USER: int = 8
|
||||
# Per-request cap on the number of backlog entries XRANGE returns
|
||||
# for ``/api/events`` snapshots. Bounds the bytes a single replay
|
||||
# can move from Redis to the wire — a malicious client looping
|
||||
# ``Last-Event-ID=<oldest>`` reconnects can only enumerate this
|
||||
# many entries per round-trip. Combined with the per-user
|
||||
# connection cap above and the windowed budget below, total
|
||||
# enumeration throughput is bounded.
|
||||
EVENTS_REPLAY_MAX_PER_REQUEST: int = 200
|
||||
# Sliding-window cap on snapshot replays per user. Once the budget
|
||||
# is exhausted the route returns HTTP 429 with the cursor pinned;
|
||||
# the client backs off and retries after the window rolls over.
|
||||
EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW: int = 30
|
||||
EVENTS_REPLAY_BUDGET_WINDOW_SECONDS: int = 60
|
||||
|
||||
# Retention for the ``message_events`` journal. The ``cleanup_message_events``
|
||||
# beat task deletes rows older than this. Reconnect-replay only
|
||||
# needs the journal for streams a client could still be tailing,
|
||||
# so 14 days is a generous default that covers paused/tool-action
|
||||
# flows without unbounded table growth.
|
||||
MESSAGE_EVENTS_RETENTION_DAYS: int = 14
|
||||
|
||||
@field_validator("POSTGRES_URI", mode="before")
|
||||
@classmethod
|
||||
def _normalize_postgres_uri_validator(cls, v):
|
||||
|
||||
0
application/events/__init__.py
Normal file
0
application/events/__init__.py
Normal file
52
application/events/keys.py
Normal file
52
application/events/keys.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Stream/topic key derivations shared by publisher and SSE consumer.
|
||||
|
||||
Single source of truth for the per-user Redis Streams key and pub/sub
|
||||
topic name. Both must agree exactly — a typo here splits the
|
||||
publisher's writes from the consumer's reads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def stream_key(user_id: str) -> str:
|
||||
"""Redis Streams key holding the durable backlog for ``user_id``."""
|
||||
return f"user:{user_id}:stream"
|
||||
|
||||
|
||||
def topic_name(user_id: str) -> str:
|
||||
"""Redis pub/sub channel used for live fan-out to ``user_id``."""
|
||||
return f"user:{user_id}"
|
||||
|
||||
|
||||
def connection_counter_key(user_id: str) -> str:
|
||||
"""Redis counter tracking active SSE connections for ``user_id``."""
|
||||
return f"user:{user_id}:sse_count"
|
||||
|
||||
|
||||
def replay_budget_key(user_id: str) -> str:
|
||||
"""Redis counter tracking snapshot replays for ``user_id`` in the
|
||||
rolling rate-limit window."""
|
||||
return f"user:{user_id}:replay_count"
|
||||
|
||||
|
||||
def stream_id_compare(a: str, b: str) -> int:
|
||||
"""Compare two Redis Streams ids. Returns -1, 0, 1 like ``cmp``.
|
||||
|
||||
Stream ids are ``ms-seq`` strings; comparing as strings would be wrong
|
||||
once ``ms`` straddles digit-count boundaries. We parse and compare
|
||||
as ``(int, int)`` tuples.
|
||||
|
||||
Raises ``ValueError`` on malformed input. Callers must pre-validate
|
||||
against ``_STREAM_ID_RE`` (or equivalent) — a lex fallback here let
|
||||
a malformed id compare lex-greater than a real one and silently pin
|
||||
dedup forever.
|
||||
"""
|
||||
a_ms, _, a_seq = a.partition("-")
|
||||
b_ms, _, b_seq = b.partition("-")
|
||||
a_tuple = (int(a_ms), int(a_seq) if a_seq else 0)
|
||||
b_tuple = (int(b_ms), int(b_seq) if b_seq else 0)
|
||||
if a_tuple < b_tuple:
|
||||
return -1
|
||||
if a_tuple > b_tuple:
|
||||
return 1
|
||||
return 0
|
||||
144
application/events/publisher.py
Normal file
144
application/events/publisher.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""User-scoped event publisher: durable backlog + live fan-out.
|
||||
|
||||
Each ``publish_user_event`` call writes twice:
|
||||
|
||||
1. ``XADD user:{user_id}:stream MAXLEN ~ <cap> * event <json>`` — the
|
||||
durable backlog used by SSE reconnect (``Last-Event-ID``) and stream
|
||||
replay. Bounded by ``EVENTS_STREAM_MAXLEN`` (~24h at typical event
|
||||
rates) so the per-user footprint stays predictable.
|
||||
2. ``PUBLISH user:{user_id} <json-with-id>`` — live fan-out to every
|
||||
currently connected SSE generator for the user, across instances.
|
||||
|
||||
Together they give a snapshot-plus-tail story: a reconnecting client
|
||||
reads ``XRANGE`` from its last seen id and then transitions onto the
|
||||
live pub/sub. The Redis Streams entry id (e.g. ``1735682400000-0``) is
|
||||
the canonical, monotonically increasing event id and is what
|
||||
``Last-Event-ID`` carries.
|
||||
|
||||
Failures are logged and swallowed: the caller is typically a Celery
|
||||
task whose primary work has already succeeded, and a notification
|
||||
delivery miss should not surface as a task failure.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.events.keys import stream_key, topic_name
|
||||
from application.streaming.broadcast_channel import Topic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _iso_now() -> str:
|
||||
"""ISO 8601 UTC with millisecond precision and Z suffix."""
|
||||
return (
|
||||
datetime.now(timezone.utc)
|
||||
.isoformat(timespec="milliseconds")
|
||||
.replace("+00:00", "Z")
|
||||
)
|
||||
|
||||
|
||||
def publish_user_event(
|
||||
user_id: str,
|
||||
event_type: str,
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
scope: Optional[dict[str, Any]] = None,
|
||||
) -> Optional[str]:
|
||||
"""Publish a user-scoped event; return the Redis Streams id or ``None``.
|
||||
|
||||
Fire-and-forget: never raises. ``None`` means the event reached
|
||||
neither the journal nor live subscribers (see runbook for causes).
|
||||
"""
|
||||
if not user_id or not event_type:
|
||||
logger.warning(
|
||||
"publish_user_event called without user_id or event_type "
|
||||
"(user_id=%r, event_type=%r)",
|
||||
user_id,
|
||||
event_type,
|
||||
)
|
||||
return None
|
||||
if not settings.ENABLE_SSE_PUSH:
|
||||
return None
|
||||
|
||||
envelope_partial: dict[str, Any] = {
|
||||
"type": event_type,
|
||||
"ts": _iso_now(),
|
||||
"user_id": user_id,
|
||||
"topic": topic_name(user_id),
|
||||
"scope": scope or {},
|
||||
"payload": payload,
|
||||
}
|
||||
|
||||
try:
|
||||
envelope_partial_json = json.dumps(envelope_partial)
|
||||
except (TypeError, ValueError) as exc:
|
||||
logger.warning(
|
||||
"publish_user_event payload not JSON-serializable: "
|
||||
"user=%s type=%s err=%s",
|
||||
user_id,
|
||||
event_type,
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
redis = get_redis_instance()
|
||||
if redis is None:
|
||||
logger.debug("Redis unavailable; skipping publish_user_event")
|
||||
return None
|
||||
|
||||
maxlen = settings.EVENTS_STREAM_MAXLEN
|
||||
stream_id: Optional[str] = None
|
||||
try:
|
||||
# Auto-id ('*') gives a monotonic ms-seq id that doubles as the
|
||||
# SSE event id. ``approximate=True`` lets Redis trim in chunks
|
||||
# for performance; the cap is treated as ~MAXLEN, never <.
|
||||
result = redis.xadd(
|
||||
stream_key(user_id),
|
||||
{"event": envelope_partial_json},
|
||||
maxlen=maxlen,
|
||||
approximate=True,
|
||||
)
|
||||
stream_id = (
|
||||
result.decode("utf-8")
|
||||
if isinstance(result, (bytes, bytearray))
|
||||
else str(result)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"xadd failed for user=%s event_type=%s", user_id, event_type
|
||||
)
|
||||
|
||||
# If the durable journal write failed there is no canonical id to
|
||||
# ship — publishing the envelope live would put an id-less record
|
||||
# on the wire that bypasses the SSE route's dedup floor and breaks
|
||||
# ``Last-Event-ID`` semantics for any reconnect. Best-effort
|
||||
# delivery means dropping consistently, not delivering inconsistent
|
||||
# state.
|
||||
if stream_id is None:
|
||||
return None
|
||||
|
||||
envelope = dict(envelope_partial)
|
||||
envelope["id"] = stream_id
|
||||
|
||||
try:
|
||||
Topic(topic_name(user_id)).publish(json.dumps(envelope))
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"publish failed for user=%s event_type=%s", user_id, event_type
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"event.published topic=%s type=%s id=%s",
|
||||
topic_name(user_id),
|
||||
event_type,
|
||||
stream_id,
|
||||
)
|
||||
|
||||
return stream_id
|
||||
@@ -11,6 +11,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicLLM(BaseLLM):
|
||||
provider_name = "anthropic"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import ClassVar
|
||||
|
||||
from application.cache import gen_cache, stream_cache
|
||||
|
||||
@@ -10,6 +11,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
# Stamped onto the ``llm_stream_start`` event so dashboards can group
|
||||
# calls by vendor. Subclasses override.
|
||||
provider_name: ClassVar[str] = "unknown"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decoded_token=None,
|
||||
@@ -17,6 +22,8 @@ class BaseLLM(ABC):
|
||||
model_id=None,
|
||||
base_url=None,
|
||||
backup_models=None,
|
||||
model_user_id=None,
|
||||
capabilities=None,
|
||||
):
|
||||
self.decoded_token = decoded_token
|
||||
self.agent_id = str(agent_id) if agent_id else None
|
||||
@@ -25,6 +32,12 @@ class BaseLLM(ABC):
|
||||
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
self._backup_models = backup_models or []
|
||||
self._fallback_llm = None
|
||||
# Registry-resolved per-model capability overrides (BYOM caps,
|
||||
# operator YAML). None falls back to provider-class defaults.
|
||||
self.capabilities = capabilities
|
||||
# BYOM-resolution scope captured at LLM creation time so backup
|
||||
# / fallback lookups hit the same per-user layer as the primary.
|
||||
self.model_user_id = model_user_id
|
||||
|
||||
@property
|
||||
def fallback_llm(self):
|
||||
@@ -39,10 +52,19 @@ class BaseLLM(ABC):
|
||||
get_api_key_for_provider,
|
||||
)
|
||||
|
||||
# Try per-agent backup models first
|
||||
# model_user_id (BYOM scope) takes precedence over the caller's
|
||||
# sub so shared-agent backups resolve under the owner's layer.
|
||||
caller_sub = (
|
||||
self.decoded_token.get("sub")
|
||||
if isinstance(self.decoded_token, dict)
|
||||
else None
|
||||
)
|
||||
backup_user_id = self.model_user_id or caller_sub
|
||||
for backup_model_id in self._backup_models:
|
||||
try:
|
||||
provider = get_provider_from_model_id(backup_model_id)
|
||||
provider = get_provider_from_model_id(
|
||||
backup_model_id, user_id=backup_user_id
|
||||
)
|
||||
if not provider:
|
||||
logger.warning(
|
||||
f"Could not resolve provider for backup model: {backup_model_id}"
|
||||
@@ -56,6 +78,15 @@ class BaseLLM(ABC):
|
||||
decoded_token=self.decoded_token,
|
||||
model_id=backup_model_id,
|
||||
agent_id=self.agent_id,
|
||||
model_user_id=self.model_user_id,
|
||||
)
|
||||
# Tag the fallback LLM so its rows land as
|
||||
# ``source='fallback'`` in cost-attribution dashboards.
|
||||
# Propagate the parent's ``_request_id`` so a user
|
||||
# request that ran fallback is still grouped under one id.
|
||||
self._fallback_llm._token_usage_source = "fallback"
|
||||
self._fallback_llm._request_id = getattr(
|
||||
self, "_request_id", None,
|
||||
)
|
||||
logger.info(
|
||||
f"Fallback LLM initialized from agent backup model: "
|
||||
@@ -68,7 +99,10 @@ class BaseLLM(ABC):
|
||||
)
|
||||
continue
|
||||
|
||||
# Fall back to global FALLBACK_* settings
|
||||
# Fall back to global FALLBACK_* settings. Forward
|
||||
# ``model_user_id`` here too: deployments can configure
|
||||
# ``FALLBACK_LLM_NAME`` to a BYOM UUID, and that UUID is owned
|
||||
# by the same user the primary model was resolved under.
|
||||
if settings.FALLBACK_LLM_PROVIDER:
|
||||
try:
|
||||
self._fallback_llm = LLMCreator.create_llm(
|
||||
@@ -78,6 +112,12 @@ class BaseLLM(ABC):
|
||||
decoded_token=self.decoded_token,
|
||||
model_id=settings.FALLBACK_LLM_NAME,
|
||||
agent_id=self.agent_id,
|
||||
model_user_id=self.model_user_id,
|
||||
)
|
||||
# Same rationale as the agent-backup branch.
|
||||
self._fallback_llm._token_usage_source = "fallback"
|
||||
self._fallback_llm._request_id = getattr(
|
||||
self, "_request_id", None,
|
||||
)
|
||||
logger.info(
|
||||
f"Fallback LLM initialized from global settings: "
|
||||
@@ -96,6 +136,26 @@ class BaseLLM(ABC):
|
||||
return args_dict
|
||||
return {k: v for k, v in args_dict.items() if v is not None}
|
||||
|
||||
@staticmethod
|
||||
def _is_non_retriable_client_error(exc: BaseException) -> bool:
|
||||
"""4xx errors mean the request itself is malformed — retrying with
|
||||
a different model fails identically and doubles the work. Only
|
||||
transient/5xx/connection errors should trigger fallback."""
|
||||
try:
|
||||
from google.genai.errors import ClientError as _GenaiClientError
|
||||
|
||||
if isinstance(exc, _GenaiClientError):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
for attr in ("status_code", "code", "http_status"):
|
||||
v = getattr(exc, attr, None)
|
||||
if isinstance(v, int) and 400 <= v < 500:
|
||||
return True
|
||||
resp = getattr(exc, "response", None)
|
||||
v = getattr(resp, "status_code", None)
|
||||
return isinstance(v, int) and 400 <= v < 500
|
||||
|
||||
def _execute_with_fallback(
|
||||
self, method_name: str, decorators: list, *args, **kwargs
|
||||
):
|
||||
@@ -119,12 +179,18 @@ class BaseLLM(ABC):
|
||||
|
||||
if is_stream:
|
||||
return self._stream_with_fallback(
|
||||
decorated_method, method_name, *args, **kwargs
|
||||
decorated_method, method_name, decorators, *args, **kwargs
|
||||
)
|
||||
|
||||
try:
|
||||
return decorated_method()
|
||||
except Exception as e:
|
||||
if self._is_non_retriable_client_error(e):
|
||||
logger.error(
|
||||
f"Primary LLM failed with non-retriable client error; "
|
||||
f"skipping fallback: {str(e)}"
|
||||
)
|
||||
raise
|
||||
if not self.fallback_llm:
|
||||
logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
|
||||
raise
|
||||
@@ -134,14 +200,27 @@ class BaseLLM(ABC):
|
||||
f"{fallback.model_id}. Error: {str(e)}"
|
||||
)
|
||||
|
||||
fallback_method = getattr(
|
||||
fallback, method_name.replace("_raw_", "")
|
||||
)
|
||||
# Apply decorators to fallback's raw method directly — calling
|
||||
# fallback.gen() would re-enter the orchestrator and recurse via
|
||||
# fallback.fallback_llm.
|
||||
fallback_method = getattr(fallback, method_name)
|
||||
for decorator in decorators:
|
||||
fallback_method = decorator(fallback_method)
|
||||
fallback_kwargs = {**kwargs, "model": fallback.model_id}
|
||||
return fallback_method(*args, **fallback_kwargs)
|
||||
try:
|
||||
return fallback_method(fallback, *args, **fallback_kwargs)
|
||||
except Exception as e2:
|
||||
if self._is_non_retriable_client_error(e2):
|
||||
logger.error(
|
||||
f"Fallback LLM failed with non-retriable client "
|
||||
f"error; giving up: {str(e2)}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"Fallback LLM also failed; giving up: {str(e2)}")
|
||||
raise
|
||||
|
||||
def _stream_with_fallback(
|
||||
self, decorated_method, method_name, *args, **kwargs
|
||||
self, decorated_method, method_name, decorators, *args, **kwargs
|
||||
):
|
||||
"""
|
||||
Wrapper generator that catches mid-stream errors and falls back.
|
||||
@@ -154,6 +233,12 @@ class BaseLLM(ABC):
|
||||
try:
|
||||
yield from decorated_method()
|
||||
except Exception as e:
|
||||
if self._is_non_retriable_client_error(e):
|
||||
logger.error(
|
||||
f"Primary LLM failed mid-stream with non-retriable client "
|
||||
f"error; skipping fallback: {str(e)}"
|
||||
)
|
||||
raise
|
||||
if not self.fallback_llm:
|
||||
logger.error(
|
||||
f"Primary LLM failed and no fallback configured: {str(e)}"
|
||||
@@ -164,11 +249,37 @@ class BaseLLM(ABC):
|
||||
f"Primary LLM failed mid-stream. Falling back to "
|
||||
f"{fallback.model_id}. Error: {str(e)}"
|
||||
)
|
||||
fallback_method = getattr(
|
||||
fallback, method_name.replace("_raw_", "")
|
||||
# Apply decorators to fallback's raw stream method directly —
|
||||
# calling fallback.gen_stream() would re-enter the orchestrator
|
||||
# and recurse via fallback.fallback_llm. Emit the stream-start
|
||||
# event manually so dashboards still see the fallback's
|
||||
# provider/model when the response actually comes from it.
|
||||
fallback._emit_stream_start_log(
|
||||
fallback.model_id,
|
||||
kwargs.get("messages"),
|
||||
kwargs.get("tools"),
|
||||
bool(
|
||||
kwargs.get("_usage_attachments")
|
||||
or kwargs.get("attachments")
|
||||
),
|
||||
)
|
||||
fallback_method = getattr(fallback, method_name)
|
||||
for decorator in decorators:
|
||||
fallback_method = decorator(fallback_method)
|
||||
fallback_kwargs = {**kwargs, "model": fallback.model_id}
|
||||
yield from fallback_method(*args, **fallback_kwargs)
|
||||
try:
|
||||
yield from fallback_method(fallback, *args, **fallback_kwargs)
|
||||
except Exception as e2:
|
||||
if self._is_non_retriable_client_error(e2):
|
||||
logger.error(
|
||||
f"Fallback LLM failed mid-stream with non-retriable "
|
||||
f"client error; giving up: {str(e2)}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Fallback LLM also failed mid-stream; giving up: {str(e2)}"
|
||||
)
|
||||
raise
|
||||
|
||||
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
|
||||
decorators = [gen_token_usage, gen_cache]
|
||||
@@ -183,7 +294,58 @@ class BaseLLM(ABC):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _emit_stream_start_log(self, model, messages, tools, has_attachments):
|
||||
# Stamped with ``self.provider_name`` so dashboards can group calls
|
||||
# by vendor; the fallback path emits its own copy on the fallback
|
||||
# instance so the actual responding provider is recorded.
|
||||
logging.info(
|
||||
"llm_stream_start",
|
||||
extra={
|
||||
"model": model,
|
||||
"provider": self.provider_name,
|
||||
"message_count": len(messages) if messages is not None else 0,
|
||||
"has_attachments": bool(has_attachments),
|
||||
"has_tools": bool(tools),
|
||||
},
|
||||
)
|
||||
|
||||
def _emit_stream_finished_log(
|
||||
self,
|
||||
model,
|
||||
*,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
latency_ms,
|
||||
cached_tokens=None,
|
||||
error=None,
|
||||
):
|
||||
# Paired with ``llm_stream_start`` so cost dashboards can sum tokens
|
||||
# by user/agent/provider. Token counts are client-side estimates
|
||||
# from ``stream_token_usage``; vendor-reported counts (incl.
|
||||
# ``cached_tokens`` for prompt caching) require per-provider
|
||||
# extraction in each ``_raw_gen_stream`` and aren't wired yet.
|
||||
extra = {
|
||||
"model": model,
|
||||
"provider": self.provider_name,
|
||||
"prompt_tokens": int(prompt_tokens),
|
||||
"completion_tokens": int(completion_tokens),
|
||||
"latency_ms": int(latency_ms),
|
||||
"status": "error" if error is not None else "ok",
|
||||
}
|
||||
if cached_tokens is not None:
|
||||
extra["cached_tokens"] = int(cached_tokens)
|
||||
if error is not None:
|
||||
extra["error_class"] = type(error).__name__
|
||||
logging.info("llm_stream_finished", extra=extra)
|
||||
|
||||
def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
|
||||
# Attachments arrive as ``_usage_attachments`` from ``Agent._llm_gen``;
|
||||
# the ``stream_token_usage`` decorator pops that key, but the log
|
||||
# fires before the decorator runs so it's still in ``kwargs`` here.
|
||||
has_attachments = bool(
|
||||
kwargs.get("_usage_attachments") or kwargs.get("attachments")
|
||||
)
|
||||
self._emit_stream_start_log(model, messages, tools, has_attachments)
|
||||
decorators = [stream_cache, stream_token_usage]
|
||||
return self._execute_with_fallback(
|
||||
"_raw_gen_stream",
|
||||
|
||||
@@ -6,6 +6,8 @@ DOCSGPT_BASE_URL = "https://oai.arc53.com"
|
||||
DOCSGPT_MODEL = "docsgpt"
|
||||
|
||||
class DocsGPTAPILLM(OpenAILLM):
|
||||
provider_name = "docsgpt"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=DOCSGPT_API_KEY,
|
||||
|
||||
@@ -6,10 +6,13 @@ from google.genai import types
|
||||
from application.core.settings import settings
|
||||
|
||||
from application.llm.base import BaseLLM
|
||||
from application.llm.handlers.google import _decode_thought_signature
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
|
||||
|
||||
class GoogleLLM(BaseLLM):
|
||||
provider_name = "google"
|
||||
|
||||
def __init__(
|
||||
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
|
||||
):
|
||||
@@ -79,24 +82,39 @@ class GoogleLLM(BaseLLM):
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.get("mime_type")
|
||||
|
||||
if mime_type in self.get_supported_attachment_types():
|
||||
try:
|
||||
if mime_type not in self.get_supported_attachment_types():
|
||||
continue
|
||||
try:
|
||||
# Images go inline as bytes per Google's guidance for
|
||||
# requests under 20MB; the Files API can return before
|
||||
# the upload reaches ACTIVE state and yield an empty URI.
|
||||
if mime_type.startswith("image/"):
|
||||
file_bytes = self._read_attachment_bytes(attachment)
|
||||
files.append(
|
||||
{"file_bytes": file_bytes, "mime_type": mime_type}
|
||||
)
|
||||
else:
|
||||
file_uri = self._upload_file_to_google(attachment)
|
||||
if not file_uri:
|
||||
raise ValueError(
|
||||
f"Google Files API returned empty URI for "
|
||||
f"{attachment.get('path', 'unknown')}"
|
||||
)
|
||||
logging.info(
|
||||
f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}"
|
||||
)
|
||||
files.append({"file_uri": file_uri, "mime_type": mime_type})
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"GoogleLLM: Error uploading file: {e}", exc_info=True
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"GoogleLLM: Error processing attachment: {e}", exc_info=True
|
||||
)
|
||||
if "content" in attachment:
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
|
||||
}
|
||||
)
|
||||
if "content" in attachment:
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
|
||||
}
|
||||
)
|
||||
if files:
|
||||
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
|
||||
prepared_messages[user_message_index]["content"].append({"files": files})
|
||||
@@ -112,7 +130,9 @@ class GoogleLLM(BaseLLM):
|
||||
Returns:
|
||||
str: Google AI file URI for the uploaded file.
|
||||
"""
|
||||
if "google_file_uri" in attachment:
|
||||
# Truthy check, not membership: a poisoned cache row of "" or
|
||||
# None must be treated as a miss and trigger a fresh upload.
|
||||
if attachment.get("google_file_uri"):
|
||||
return attachment["google_file_uri"]
|
||||
file_path = attachment.get("path")
|
||||
if not file_path:
|
||||
@@ -126,6 +146,10 @@ class GoogleLLM(BaseLLM):
|
||||
file=local_path
|
||||
).uri,
|
||||
)
|
||||
if not file_uri:
|
||||
raise ValueError(
|
||||
f"Google Files API upload returned empty URI for {file_path}"
|
||||
)
|
||||
|
||||
# Cache the Google file URI on the attachment row so we don't
|
||||
# re-upload on the next LLM call. Accept either a PG UUID
|
||||
@@ -159,6 +183,26 @@ class GoogleLLM(BaseLLM):
|
||||
logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _read_attachment_bytes(self, attachment):
|
||||
"""
|
||||
Read attachment bytes from storage for inline transmission.
|
||||
|
||||
Args:
|
||||
attachment (dict): Attachment dictionary with path and metadata.
|
||||
|
||||
Returns:
|
||||
bytes: Raw file bytes.
|
||||
"""
|
||||
file_path = attachment.get("path")
|
||||
if not file_path:
|
||||
raise ValueError("No file path provided in attachment")
|
||||
if not self.storage.file_exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
return self.storage.process_file(
|
||||
file_path,
|
||||
lambda local_path, **kwargs: open(local_path, "rb").read(),
|
||||
)
|
||||
|
||||
def _clean_messages_google(self, messages):
|
||||
"""
|
||||
Convert OpenAI format messages to Google AI format and collect system prompts.
|
||||
@@ -215,7 +259,7 @@ class GoogleLLM(BaseLLM):
|
||||
except (_json.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
cleaned_args = self._remove_null_values(args)
|
||||
thought_sig = tc.get("thought_signature")
|
||||
thought_sig = _decode_thought_signature(tc.get("thought_signature"))
|
||||
if thought_sig:
|
||||
parts.append(
|
||||
types.Part(
|
||||
@@ -279,7 +323,9 @@ class GoogleLLM(BaseLLM):
|
||||
name=item["function_call"]["name"],
|
||||
args=cleaned_args,
|
||||
),
|
||||
thoughtSignature=item["thought_signature"],
|
||||
thoughtSignature=_decode_thought_signature(
|
||||
item["thought_signature"]
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -298,12 +344,24 @@ class GoogleLLM(BaseLLM):
|
||||
)
|
||||
elif "files" in item:
|
||||
for file_data in item["files"]:
|
||||
parts.append(
|
||||
types.Part.from_uri(
|
||||
file_uri=file_data["file_uri"],
|
||||
mime_type=file_data["mime_type"],
|
||||
if "file_bytes" in file_data:
|
||||
parts.append(
|
||||
types.Part.from_bytes(
|
||||
data=file_data["file_bytes"],
|
||||
mime_type=file_data["mime_type"],
|
||||
)
|
||||
)
|
||||
elif file_data.get("file_uri"):
|
||||
parts.append(
|
||||
types.Part.from_uri(
|
||||
file_uri=file_data["file_uri"],
|
||||
mime_type=file_data["mime_type"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
"GoogleLLM: dropping file part with empty URI and no bytes"
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected content dictionary format:{item}"
|
||||
@@ -541,22 +599,6 @@ class GoogleLLM(BaseLLM):
|
||||
config.response_mime_type = "application/json"
|
||||
# Check if we have both tools and file attachments
|
||||
|
||||
has_attachments = False
|
||||
for message in messages:
|
||||
for part in message.parts:
|
||||
if hasattr(part, "file_data") and part.file_data is not None:
|
||||
has_attachments = True
|
||||
break
|
||||
if has_attachments:
|
||||
break
|
||||
messages_summary = self._summarize_messages_for_log(messages)
|
||||
logging.info(
|
||||
"GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
|
||||
model,
|
||||
messages_summary,
|
||||
has_attachments,
|
||||
)
|
||||
|
||||
response = client.models.generate_content_stream(
|
||||
model=model,
|
||||
contents=messages,
|
||||
|
||||
@@ -5,6 +5,8 @@ GROQ_BASE_URL = "https://api.groq.com/openai/v1"
|
||||
|
||||
|
||||
class GroqLLM(OpenAILLM):
|
||||
provider_name = "groq"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=api_key or settings.GROQ_API_KEY or settings.API_KEY,
|
||||
|
||||
@@ -10,6 +10,18 @@ from application.logging import build_stack_data
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Cap the agent tool-call loop. Without this an LLM that keeps
|
||||
# requesting more tool calls (preview models, sparse tool results,
|
||||
# under-specified prompts) can chain searches indefinitely and the
|
||||
# stream never finalises. 25 mirrors Dify's default.
|
||||
MAX_TOOL_ITERATIONS = 25
|
||||
_FINALIZE_INSTRUCTION = (
|
||||
f"You have made {MAX_TOOL_ITERATIONS} tool calls. Provide a final "
|
||||
"response to the user based on what you have, without making any "
|
||||
"additional tool calls."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""Represents a tool/function call from the LLM."""
|
||||
@@ -280,7 +292,26 @@ class LLMHandler(ABC):
|
||||
# Keep serialized function calls/responses so the compressor sees actions
|
||||
parts_text.append(str(item))
|
||||
elif "files" in item:
|
||||
parts_text.append(str(item))
|
||||
# Image attachments arrive with raw bytes / base64
|
||||
# inline (see GoogleLLM.prepare_messages_with_attachments).
|
||||
# ``str(item)`` would dump the whole byte/base64
|
||||
# blob into the compression prompt and bust the
|
||||
# compression LLM's input limit.
|
||||
files = item.get("files") or []
|
||||
descriptors = []
|
||||
if isinstance(files, list):
|
||||
for f in files:
|
||||
if isinstance(f, dict):
|
||||
descriptors.append(
|
||||
f.get("mime_type") or "file"
|
||||
)
|
||||
elif isinstance(f, str):
|
||||
descriptors.append(f)
|
||||
if not descriptors:
|
||||
descriptors = ["file"]
|
||||
parts_text.append(
|
||||
f"[attachment: {', '.join(descriptors)}]"
|
||||
)
|
||||
return "\n".join(parts_text)
|
||||
return ""
|
||||
|
||||
@@ -470,10 +501,14 @@ class LLMHandler(ABC):
|
||||
)
|
||||
return self._perform_in_memory_compression(agent, messages)
|
||||
|
||||
# Use orchestrator to perform compression
|
||||
# Use orchestrator to perform compression. ``model_user_id``
|
||||
# keeps BYOM registry resolution scoped to the model owner
|
||||
# (shared-agent dispatch) while ``user_id`` stays the caller
|
||||
# for the conversation access check.
|
||||
result = orchestrator.compress_mid_execution(
|
||||
conversation_id=agent.conversation_id,
|
||||
user_id=agent.initial_user_id,
|
||||
model_user_id=getattr(agent, "model_user_id", None),
|
||||
model_id=agent.model_id,
|
||||
decoded_token=getattr(agent, "decoded_token", {}),
|
||||
current_conversation=conversation,
|
||||
@@ -577,7 +612,20 @@ class LLMHandler(ABC):
|
||||
if settings.COMPRESSION_MODEL_OVERRIDE
|
||||
else agent.model_id
|
||||
)
|
||||
provider = get_provider_from_model_id(compression_model)
|
||||
agent_decoded = getattr(agent, "decoded_token", None)
|
||||
caller_sub = (
|
||||
agent_decoded.get("sub")
|
||||
if isinstance(agent_decoded, dict)
|
||||
else None
|
||||
)
|
||||
# Use model-owner scope (mirrors orchestrator path) so
|
||||
# shared-agent owner-BYOM resolves under the owner's layer.
|
||||
compression_user_id = (
|
||||
getattr(agent, "model_user_id", None) or caller_sub
|
||||
)
|
||||
provider = get_provider_from_model_id(
|
||||
compression_model, user_id=compression_user_id
|
||||
)
|
||||
api_key = get_api_key_for_provider(provider)
|
||||
compression_llm = LLMCreator.create_llm(
|
||||
provider,
|
||||
@@ -586,7 +634,12 @@ class LLMHandler(ABC):
|
||||
getattr(agent, "decoded_token", None),
|
||||
model_id=compression_model,
|
||||
agent_id=getattr(agent, "agent_id", None),
|
||||
model_user_id=compression_user_id,
|
||||
)
|
||||
# Side-channel LLM tag — see ``orchestrator.py`` for rationale.
|
||||
compression_llm._token_usage_source = "compression"
|
||||
compression_llm._request_id = getattr(agent, "_request_id", None) \
|
||||
or getattr(getattr(agent, "llm", None), "_request_id", None)
|
||||
|
||||
# Create service without DB persistence capability
|
||||
compression_service = CompressionService(
|
||||
@@ -897,7 +950,9 @@ class LLMHandler(ABC):
|
||||
parsed = self.parse_response(response)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
iteration = 0
|
||||
while parsed.requires_tool_call:
|
||||
iteration += 1
|
||||
tool_handler_gen = self.handle_tool_calls(
|
||||
agent, parsed.tool_calls, tools_dict, messages
|
||||
)
|
||||
@@ -921,15 +976,46 @@ class LLMHandler(ABC):
|
||||
}
|
||||
return ""
|
||||
|
||||
# Cap reached: force one final tool-less call so the stream
|
||||
# always ends with content rather than cutting off.
|
||||
if iteration >= MAX_TOOL_ITERATIONS:
|
||||
logger.warning(
|
||||
"agent tool loop hit cap (%d); forcing finalize",
|
||||
MAX_TOOL_ITERATIONS,
|
||||
)
|
||||
messages.append(
|
||||
{"role": "system", "content": _FINALIZE_INSTRUCTION},
|
||||
)
|
||||
response = agent.llm.gen(
|
||||
model=getattr(agent.llm, "model_id", None) or agent.model_id,
|
||||
messages=messages,
|
||||
tools=None,
|
||||
)
|
||||
parsed = self.parse_response(response)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
break
|
||||
|
||||
# ``agent.model_id`` is the registry id (a UUID for BYOM
|
||||
# records). Use the LLM's own model_id, which LLMCreator
|
||||
# already resolved to the upstream model name. Built-ins:
|
||||
# the two are equal; BYOM: the upstream name like
|
||||
# "mistral-large-latest" instead of the UUID.
|
||||
response = agent.llm.gen(
|
||||
model=agent.model_id, messages=messages, tools=agent.tools
|
||||
model=getattr(agent.llm, "model_id", None) or agent.model_id,
|
||||
messages=messages,
|
||||
tools=agent.tools,
|
||||
)
|
||||
parsed = self.parse_response(response)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
return parsed.content
|
||||
|
||||
def handle_streaming(
|
||||
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
|
||||
self,
|
||||
agent,
|
||||
response: Any,
|
||||
tools_dict: Dict,
|
||||
messages: List[Dict],
|
||||
_iteration: int = 0,
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle streaming response flow.
|
||||
@@ -998,6 +1084,9 @@ class LLMHandler(ABC):
|
||||
}
|
||||
return
|
||||
|
||||
next_iteration = _iteration + 1
|
||||
cap_reached = next_iteration >= MAX_TOOL_ITERATIONS
|
||||
|
||||
# Check if context limit was reached during tool execution
|
||||
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
|
||||
# Add system message warning about context limit
|
||||
@@ -1010,13 +1099,32 @@ class LLMHandler(ABC):
|
||||
)
|
||||
})
|
||||
logger.info("Context limit reached - instructing agent to wrap up")
|
||||
elif cap_reached:
|
||||
logger.warning(
|
||||
"agent tool loop hit cap (%d); forcing finalize",
|
||||
MAX_TOOL_ITERATIONS,
|
||||
)
|
||||
messages.append(
|
||||
{"role": "system", "content": _FINALIZE_INSTRUCTION},
|
||||
)
|
||||
|
||||
# See note above on agent.model_id vs llm.model_id.
|
||||
response = agent.llm.gen_stream(
|
||||
model=agent.model_id, messages=messages, tools=agent.tools if not agent.context_limit_reached else None
|
||||
model=getattr(agent.llm, "model_id", None) or agent.model_id,
|
||||
messages=messages,
|
||||
tools=(
|
||||
None
|
||||
if cap_reached
|
||||
or getattr(agent, "context_limit_reached", False)
|
||||
else agent.tools
|
||||
),
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
yield from self.handle_streaming(agent, response, tools_dict, messages)
|
||||
yield from self.handle_streaming(
|
||||
agent, response, tools_dict, messages,
|
||||
_iteration=next_iteration,
|
||||
)
|
||||
return
|
||||
if parsed.content:
|
||||
buffer += parsed.content
|
||||
|
||||
@@ -1,9 +1,35 @@
|
||||
import base64
|
||||
import binascii
|
||||
import uuid
|
||||
from typing import Any, Dict, Generator
|
||||
from typing import Any, Dict, Generator, Optional, Union
|
||||
|
||||
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
|
||||
|
||||
|
||||
def _encode_thought_signature(sig: Optional[Union[bytes, str]]) -> Optional[str]:
|
||||
# Gemini's Python SDK returns thought_signature as raw bytes, but the
|
||||
# field is typed Optional[str] downstream and gets json.dumps'd into
|
||||
# SSE events. Encode once at ingress so callers only ever see a str.
|
||||
if isinstance(sig, bytes):
|
||||
return base64.b64encode(sig).decode("ascii")
|
||||
return sig
|
||||
|
||||
|
||||
def _decode_thought_signature(
|
||||
sig: Optional[Union[bytes, str]],
|
||||
) -> Optional[Union[bytes, str]]:
|
||||
# Reverse of _encode_thought_signature — Gemini's SDK expects bytes
|
||||
# back when we replay a tool call. ``validate=True`` keeps ASCII
|
||||
# strings that happen to be loosely decodable from being silently
|
||||
# turned into bytes; non-base64 inputs pass through unchanged.
|
||||
if isinstance(sig, str):
|
||||
try:
|
||||
return base64.b64decode(sig.encode("ascii"), validate=True)
|
||||
except (binascii.Error, ValueError):
|
||||
return sig
|
||||
return sig
|
||||
|
||||
|
||||
class GoogleLLMHandler(LLMHandler):
|
||||
"""Handler for Google's GenAI API."""
|
||||
|
||||
@@ -23,7 +49,7 @@ class GoogleLLMHandler(LLMHandler):
|
||||
for idx, part in enumerate(parts):
|
||||
if hasattr(part, "function_call") and part.function_call is not None:
|
||||
has_sig = hasattr(part, "thought_signature") and part.thought_signature is not None
|
||||
thought_sig = part.thought_signature if has_sig else None
|
||||
thought_sig = _encode_thought_signature(part.thought_signature) if has_sig else None
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
@@ -50,7 +76,7 @@ class GoogleLLMHandler(LLMHandler):
|
||||
tool_calls = []
|
||||
if hasattr(response, "function_call") and response.function_call is not None:
|
||||
has_sig = hasattr(response, "thought_signature") and response.thought_signature is not None
|
||||
thought_sig = response.thought_signature if has_sig else None
|
||||
thought_sig = _encode_thought_signature(response.thought_signature) if has_sig else None
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
@@ -70,8 +96,15 @@ class GoogleLLMHandler(LLMHandler):
|
||||
"""Create a tool result message in the standard internal format."""
|
||||
import json as _json
|
||||
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
# PostgresTool results commonly include PG-native types
|
||||
# (datetime / UUID / Decimal / bytea) when SELECT touches
|
||||
# timestamptz / numeric / uuid / bytea columns. The shared
|
||||
# encoder handles all five — bytes get base64 (lossless) instead
|
||||
# of the ``str(b'...')`` repr that ``default=str`` would emit.
|
||||
content = (
|
||||
_json.dumps(result)
|
||||
_json.dumps(result, cls=PGNativeJSONEncoder)
|
||||
if not isinstance(result, str)
|
||||
else result
|
||||
)
|
||||
|
||||
@@ -40,8 +40,15 @@ class OpenAILLMHandler(LLMHandler):
|
||||
"""Create a tool result message in the standard internal format."""
|
||||
import json as _json
|
||||
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
# PostgresTool results commonly include PG-native types
|
||||
# (datetime / UUID / Decimal / bytea) when SELECT touches
|
||||
# timestamptz / numeric / uuid / bytea columns. The shared
|
||||
# encoder handles all five — bytes get base64 (lossless) instead
|
||||
# of the ``str(b'...')`` repr that ``default=str`` would emit.
|
||||
content = (
|
||||
_json.dumps(result)
|
||||
_json.dumps(result, cls=PGNativeJSONEncoder)
|
||||
if not isinstance(result, str)
|
||||
else result
|
||||
)
|
||||
|
||||
@@ -26,6 +26,8 @@ class LlamaSingleton:
|
||||
|
||||
|
||||
class LlamaCpp(BaseLLM):
|
||||
provider_name = "llama_cpp"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key=None,
|
||||
|
||||
@@ -1,34 +1,11 @@
|
||||
import logging
|
||||
|
||||
from application.llm.anthropic import AnthropicLLM
|
||||
from application.llm.docsgpt_provider import DocsGPTAPILLM
|
||||
from application.llm.google_ai import GoogleLLM
|
||||
from application.llm.groq import GroqLLM
|
||||
from application.llm.llama_cpp import LlamaCpp
|
||||
from application.llm.novita import NovitaLLM
|
||||
from application.llm.openai import AzureOpenAILLM, OpenAILLM
|
||||
from application.llm.premai import PremAILLM
|
||||
from application.llm.sagemaker import SagemakerAPILLM
|
||||
from application.llm.open_router import OpenRouterLLM
|
||||
from application.llm.providers import PROVIDERS_BY_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMCreator:
|
||||
llms = {
|
||||
"openai": OpenAILLM,
|
||||
"azure_openai": AzureOpenAILLM,
|
||||
"sagemaker": SagemakerAPILLM,
|
||||
"llama.cpp": LlamaCpp,
|
||||
"anthropic": AnthropicLLM,
|
||||
"docsgpt": DocsGPTAPILLM,
|
||||
"premai": PremAILLM,
|
||||
"groq": GroqLLM,
|
||||
"google": GoogleLLM,
|
||||
"novita": NovitaLLM,
|
||||
"openrouter": OpenRouterLLM,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_llm(
|
||||
cls,
|
||||
@@ -39,28 +16,111 @@ class LLMCreator:
|
||||
model_id=None,
|
||||
agent_id=None,
|
||||
backup_models=None,
|
||||
model_user_id=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
from application.core.model_utils import get_base_url_for_model
|
||||
"""Construct an LLM for the given provider ``type``.
|
||||
|
||||
llm_class = cls.llms.get(type.lower())
|
||||
if not llm_class:
|
||||
``model_user_id`` is the BYOM-resolution scope. Defaults to
|
||||
``decoded_token['sub']`` (the caller). Pass it explicitly when
|
||||
the model record belongs to a *different* user — most notably
|
||||
for shared-agent dispatch, where the agent's stored
|
||||
``default_model_id`` is the owner's BYOM UUID but
|
||||
``decoded_token`` represents the caller.
|
||||
"""
|
||||
from application.core.model_registry import ModelRegistry
|
||||
from application.security.safe_url import (
|
||||
UnsafeUserUrlError,
|
||||
pinned_httpx_client,
|
||||
validate_user_base_url,
|
||||
)
|
||||
|
||||
plugin = PROVIDERS_BY_NAME.get(type.lower())
|
||||
if plugin is None or plugin.llm_class is None:
|
||||
raise ValueError(f"No LLM class found for type {type}")
|
||||
|
||||
# Extract base_url from model configuration if model_id is provided
|
||||
# Prefer per-model endpoint config from the registry. This is what
|
||||
# makes openai_compatible AND end-user BYOM work without changing
|
||||
# every call site: if the registered AvailableModel carries its
|
||||
# own api_key / base_url, they win over whatever the caller
|
||||
# resolved via the provider plugin.
|
||||
#
|
||||
# End-user BYOM lookups need the user_id from decoded_token to
|
||||
# find the user's per-user models layer (built-in models resolve
|
||||
# without it, so this stays back-compat).
|
||||
base_url = None
|
||||
upstream_model_id = model_id
|
||||
capabilities = None
|
||||
if model_id:
|
||||
base_url = get_base_url_for_model(model_id)
|
||||
user_id = model_user_id
|
||||
if user_id is None:
|
||||
user_id = (
|
||||
(decoded_token or {}).get("sub") if decoded_token else None
|
||||
)
|
||||
model = ModelRegistry.get_instance().get_model(model_id, user_id=user_id)
|
||||
if model is not None:
|
||||
# Forward registry caps so the LLM enforces them at
|
||||
# dispatch (built-in classes hard-code True otherwise).
|
||||
capabilities = getattr(model, "capabilities", None)
|
||||
# SECURITY: refuse user-source dispatch without its own
|
||||
# api_key (would leak settings.API_KEY to base_url).
|
||||
if (
|
||||
getattr(model, "source", "builtin") == "user"
|
||||
and not model.api_key
|
||||
):
|
||||
raise ValueError(
|
||||
f"Custom model {model_id!r} has no usable API key "
|
||||
"(decryption may have failed). Re-save the model "
|
||||
"in settings to dispatch it."
|
||||
)
|
||||
if model.api_key:
|
||||
api_key = model.api_key
|
||||
if model.base_url:
|
||||
base_url = model.base_url
|
||||
# For BYOM the registry id is a UUID; the upstream API
|
||||
# call needs the user's typed model name instead.
|
||||
if model.upstream_model_id:
|
||||
upstream_model_id = model.upstream_model_id
|
||||
|
||||
return llm_class(
|
||||
# SECURITY: re-validate at dispatch (defense in depth
|
||||
# for pre-guard rows / YAML-supplied entries). The
|
||||
# pinned httpx.Client below is what actually closes the
|
||||
# DNS-rebinding TOCTOU window.
|
||||
if base_url and getattr(model, "source", "builtin") == "user":
|
||||
try:
|
||||
validate_user_base_url(base_url)
|
||||
except UnsafeUserUrlError as e:
|
||||
raise ValueError(
|
||||
f"Refusing to dispatch model {model_id!r}: {e}"
|
||||
) from e
|
||||
# Pinned httpx.Client: resolves once, validates, and
|
||||
# binds the SDK's outbound socket to the validated IP
|
||||
# (preserves Host / SNI). Future BYOM providers must
|
||||
# opt in explicitly — only openai_compatible takes
|
||||
# http_client today.
|
||||
if plugin.name == "openai_compatible":
|
||||
try:
|
||||
kwargs["http_client"] = pinned_httpx_client(
|
||||
base_url
|
||||
)
|
||||
except UnsafeUserUrlError as e:
|
||||
raise ValueError(
|
||||
f"Refusing to dispatch model {model_id!r}: {e}"
|
||||
) from e
|
||||
|
||||
# Forward model_user_id so backup/fallback resolves under the
|
||||
# owner's scope on shared-agent dispatch.
|
||||
return plugin.llm_class(
|
||||
api_key,
|
||||
user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
model_id=upstream_model_id,
|
||||
agent_id=agent_id,
|
||||
base_url=base_url,
|
||||
backup_models=backup_models,
|
||||
model_user_id=model_user_id,
|
||||
capabilities=capabilities,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -5,6 +5,8 @@ NOVITA_BASE_URL = "https://api.novita.ai/openai"
|
||||
|
||||
|
||||
class NovitaLLM(OpenAILLM):
|
||||
provider_name = "novita"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=api_key or settings.NOVITA_API_KEY or settings.API_KEY,
|
||||
|
||||
@@ -5,6 +5,8 @@ OPEN_ROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
class OpenRouterLLM(OpenAILLM):
|
||||
provider_name = "openrouter"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=api_key or settings.OPEN_ROUTER_API_KEY or settings.API_KEY,
|
||||
|
||||
@@ -61,8 +61,17 @@ def _truncate_base64_for_logging(messages):
|
||||
|
||||
|
||||
class OpenAILLM(BaseLLM):
|
||||
provider_name = "openai"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
api_key=None,
|
||||
user_api_key=None,
|
||||
base_url=None,
|
||||
http_client=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
|
||||
@@ -80,7 +89,18 @@ class OpenAILLM(BaseLLM):
|
||||
else:
|
||||
effective_base_url = "https://api.openai.com/v1"
|
||||
|
||||
self.client = OpenAI(api_key=self.api_key, base_url=effective_base_url)
|
||||
# http_client (set by LLMCreator for BYOM) is a DNS-rebinding-safe
|
||||
# httpx.Client; without it the SDK re-resolves DNS per request.
|
||||
if http_client is not None:
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=effective_base_url,
|
||||
http_client=http_client,
|
||||
)
|
||||
else:
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key, base_url=effective_base_url
|
||||
)
|
||||
self.storage = StorageCreator.get_storage()
|
||||
|
||||
def _clean_messages_openai(self, messages):
|
||||
@@ -243,6 +263,13 @@ class OpenAILLM(BaseLLM):
|
||||
if "max_tokens" in kwargs:
|
||||
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
|
||||
|
||||
# Defense-in-depth: drop tools / response_format if the
|
||||
# registry's capability flags deny them.
|
||||
if tools and not self._supports_tools():
|
||||
tools = None
|
||||
if response_format and not self._supports_structured_output():
|
||||
response_format = None
|
||||
|
||||
request_params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
@@ -279,6 +306,13 @@ class OpenAILLM(BaseLLM):
|
||||
if "max_tokens" in kwargs:
|
||||
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
|
||||
|
||||
# See _raw_gen for rationale — drop tools/response_format when the
|
||||
# registry-provided capabilities say the model doesn't support them.
|
||||
if tools and not self._supports_tools():
|
||||
tools = None
|
||||
if response_format and not self._supports_structured_output():
|
||||
response_format = None
|
||||
|
||||
request_params = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
@@ -320,9 +354,17 @@ class OpenAILLM(BaseLLM):
|
||||
response.close()
|
||||
|
||||
def _supports_tools(self):
|
||||
# When the LLM was constructed via LLMCreator with a registered
|
||||
# AvailableModel, ``self.capabilities`` is the per-model record.
|
||||
# BYOM users can disable tool support; respect that. Otherwise
|
||||
# OpenAI's API supports tools by default.
|
||||
if self.capabilities is not None:
|
||||
return bool(self.capabilities.supports_tools)
|
||||
return True
|
||||
|
||||
def _supports_structured_output(self):
|
||||
if self.capabilities is not None:
|
||||
return bool(self.capabilities.supports_structured_output)
|
||||
return True
|
||||
|
||||
def prepare_structured_output_format(self, json_schema):
|
||||
@@ -389,8 +431,14 @@ class OpenAILLM(BaseLLM):
|
||||
Returns:
|
||||
list: List of supported MIME types
|
||||
"""
|
||||
from application.core.model_configs import OPENAI_ATTACHMENTS
|
||||
return OPENAI_ATTACHMENTS
|
||||
# Per-model caps from the registry win when present — a BYOM
|
||||
# endpoint that doesn't accept images would otherwise still be
|
||||
# sent base64 image parts because the OpenAI default below
|
||||
# advertises the image alias unconditionally.
|
||||
if self.capabilities is not None:
|
||||
return list(self.capabilities.supported_attachment_types or [])
|
||||
from application.core.model_yaml import resolve_attachment_alias
|
||||
return resolve_attachment_alias("image")
|
||||
|
||||
def prepare_messages_with_attachments(self, messages, attachments=None):
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,7 @@ from application.core.settings import settings
|
||||
|
||||
|
||||
class PremAILLM(BaseLLM):
|
||||
provider_name = "premai"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
from premai import Prem
|
||||
|
||||
51
application/llm/providers/__init__.py
Normal file
51
application/llm/providers/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Provider plugin registry.
|
||||
|
||||
Plugins are imported eagerly so import errors surface at app boot rather
|
||||
than at first request. ``ALL_PROVIDERS`` is the canonical ordered list;
|
||||
``PROVIDERS_BY_NAME`` is a name-keyed lookup for LLMCreator and the
|
||||
model registry.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from application.llm.providers.anthropic import AnthropicProvider
|
||||
from application.llm.providers.azure_openai import AzureOpenAIProvider
|
||||
from application.llm.providers.base import Provider
|
||||
from application.llm.providers.docsgpt import DocsGPTProvider
|
||||
from application.llm.providers.google import GoogleProvider
|
||||
from application.llm.providers.groq import GroqProvider
|
||||
from application.llm.providers.huggingface import HuggingFaceProvider
|
||||
from application.llm.providers.llama_cpp import LlamaCppProvider
|
||||
from application.llm.providers.novita import NovitaProvider
|
||||
from application.llm.providers.openai import OpenAIProvider
|
||||
from application.llm.providers.openai_compatible import OpenAICompatibleProvider
|
||||
from application.llm.providers.openrouter import OpenRouterProvider
|
||||
from application.llm.providers.premai import PremAIProvider
|
||||
from application.llm.providers.sagemaker import SagemakerProvider
|
||||
|
||||
# Order here is the order the registry iterates providers (and therefore
|
||||
# the order ``/api/models`` reports them). Match the historical order
|
||||
# from the old ModelRegistry._load_models for byte-stable output during
|
||||
# the migration. ``openai_compatible`` slots in right after ``openai``
|
||||
# so legacy ``OPENAI_BASE_URL`` models keep landing in the same place.
|
||||
ALL_PROVIDERS: List[Provider] = [
|
||||
DocsGPTProvider(),
|
||||
OpenAIProvider(),
|
||||
OpenAICompatibleProvider(),
|
||||
AzureOpenAIProvider(),
|
||||
AnthropicProvider(),
|
||||
GoogleProvider(),
|
||||
GroqProvider(),
|
||||
OpenRouterProvider(),
|
||||
NovitaProvider(),
|
||||
HuggingFaceProvider(),
|
||||
LlamaCppProvider(),
|
||||
PremAIProvider(),
|
||||
SagemakerProvider(),
|
||||
]
|
||||
|
||||
PROVIDERS_BY_NAME: Dict[str, Provider] = {p.name: p for p in ALL_PROVIDERS}
|
||||
|
||||
__all__ = ["ALL_PROVIDERS", "PROVIDERS_BY_NAME", "Provider"]
|
||||
51
application/llm/providers/_apikey_or_llm_name.py
Normal file
51
application/llm/providers/_apikey_or_llm_name.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Shared helper for providers that follow the
|
||||
``<X>_API_KEY or (LLM_PROVIDER==X and API_KEY)`` pattern.
|
||||
|
||||
This is the dominant pattern across Anthropic, Google, Groq, OpenRouter,
|
||||
and Novita. Extracted here so each plugin stays a few lines long.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from application.core.model_settings import AvailableModel
|
||||
|
||||
|
||||
def get_api_key(
|
||||
settings,
|
||||
provider_name: str,
|
||||
provider_specific_key: Optional[str],
|
||||
) -> Optional[str]:
|
||||
if provider_specific_key:
|
||||
return provider_specific_key
|
||||
if settings.LLM_PROVIDER == provider_name and settings.API_KEY:
|
||||
return settings.API_KEY
|
||||
return None
|
||||
|
||||
|
||||
def filter_models_by_llm_name(
|
||||
settings,
|
||||
provider_name: str,
|
||||
provider_specific_key: Optional[str],
|
||||
models: List[AvailableModel],
|
||||
) -> List[AvailableModel]:
|
||||
"""Mirrors the historical ``_add_<X>_models`` selection logic.
|
||||
|
||||
Behavior:
|
||||
- If the provider-specific API key is set → load all models.
|
||||
- Else if ``LLM_PROVIDER`` matches and ``LLM_NAME`` matches a known
|
||||
model → load just that model.
|
||||
- Otherwise → load all models (preserved "load anyway" branch from
|
||||
the original methods).
|
||||
"""
|
||||
if provider_specific_key:
|
||||
return models
|
||||
if (
|
||||
settings.LLM_PROVIDER == provider_name
|
||||
and settings.LLM_NAME
|
||||
):
|
||||
named = [m for m in models if m.id == settings.LLM_NAME]
|
||||
if named:
|
||||
return named
|
||||
return models
|
||||
23
application/llm/providers/anthropic.py
Normal file
23
application/llm/providers/anthropic.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.anthropic import AnthropicLLM
|
||||
from application.llm.providers._apikey_or_llm_name import (
|
||||
filter_models_by_llm_name,
|
||||
get_api_key,
|
||||
)
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class AnthropicProvider(Provider):
|
||||
name = "anthropic"
|
||||
llm_class = AnthropicLLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return get_api_key(settings, self.name, settings.ANTHROPIC_API_KEY)
|
||||
|
||||
def filter_yaml_models(self, settings, models):
|
||||
return filter_models_by_llm_name(
|
||||
settings, self.name, settings.ANTHROPIC_API_KEY, models
|
||||
)
|
||||
30
application/llm/providers/azure_openai.py
Normal file
30
application/llm/providers/azure_openai.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.openai import AzureOpenAILLM
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class AzureOpenAIProvider(Provider):
|
||||
name = "azure_openai"
|
||||
llm_class = AzureOpenAILLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
# Azure historically uses the generic API_KEY field.
|
||||
return settings.API_KEY
|
||||
|
||||
def is_enabled(self, settings) -> bool:
|
||||
if settings.OPENAI_API_BASE:
|
||||
return True
|
||||
return settings.LLM_PROVIDER == self.name and bool(settings.API_KEY)
|
||||
|
||||
def filter_yaml_models(self, settings, models):
|
||||
# Mirrors _add_azure_openai_models: when LLM_PROVIDER==azure_openai
|
||||
# and LLM_NAME matches a known model, narrow to that one model.
|
||||
# Otherwise load the entire catalog.
|
||||
if settings.LLM_PROVIDER == self.name and settings.LLM_NAME:
|
||||
named = [m for m in models if m.id == settings.LLM_NAME]
|
||||
if named:
|
||||
return named
|
||||
return models
|
||||
74
application/llm/providers/base.py
Normal file
74
application/llm/providers/base.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from application.core.model_settings import AvailableModel
|
||||
from application.core.model_yaml import ProviderCatalog
|
||||
from application.core.settings import Settings
|
||||
from application.llm.base import BaseLLM
|
||||
|
||||
|
||||
class Provider(ABC):
|
||||
"""Owns the *behavior* of an LLM provider.
|
||||
|
||||
Concrete providers declare their name, the LLM class to instantiate,
|
||||
and how to resolve credentials from settings. Static model catalogs
|
||||
live in YAML under ``application/core/models/`` and are joined to the
|
||||
provider by name at registry load time.
|
||||
|
||||
Most plugins receive zero or one catalog at registry-build time. The
|
||||
``openai_compatible`` plugin is the exception: it receives one catalog
|
||||
per matching YAML file, each with its own ``api_key_env`` and
|
||||
``base_url``. Plugins that need per-catalog metadata override
|
||||
``get_models``; the default implementation merges catalogs and routes
|
||||
through ``filter_yaml_models`` + ``extra_models``.
|
||||
"""
|
||||
|
||||
name: ClassVar[str]
|
||||
# ``None`` means the provider appears in the catalog but isn't
|
||||
# dispatchable through LLMCreator (e.g. Hugging Face today, where the
|
||||
# original LLMCreator dict had no entry).
|
||||
llm_class: ClassVar[Optional[Type["BaseLLM"]]] = None
|
||||
|
||||
@abstractmethod
|
||||
def get_api_key(self, settings: "Settings") -> Optional[str]:
|
||||
"""Return the API key for this provider, or None if unavailable."""
|
||||
|
||||
def is_enabled(self, settings: "Settings") -> bool:
|
||||
"""Whether this provider should contribute models to the registry."""
|
||||
return bool(self.get_api_key(settings))
|
||||
|
||||
def filter_yaml_models(
|
||||
self, settings: "Settings", models: List["AvailableModel"]
|
||||
) -> List["AvailableModel"]:
|
||||
"""Hook to filter YAML-loaded models. Default: return all."""
|
||||
return models
|
||||
|
||||
def extra_models(self, settings: "Settings") -> List["AvailableModel"]:
|
||||
"""Hook to add dynamic models not declared in YAML. Default: none."""
|
||||
return []
|
||||
|
||||
def get_models(
|
||||
self,
|
||||
settings: "Settings",
|
||||
catalogs: List["ProviderCatalog"],
|
||||
) -> List["AvailableModel"]:
|
||||
"""Final list of models this plugin contributes.
|
||||
|
||||
Default: merge the models across all matched catalogs (later
|
||||
catalog wins on duplicate id), filter via ``filter_yaml_models``,
|
||||
then append ``extra_models``. Override when per-catalog metadata
|
||||
matters (see ``OpenAICompatibleProvider``).
|
||||
"""
|
||||
merged: List["AvailableModel"] = []
|
||||
seen: dict = {}
|
||||
for c in catalogs:
|
||||
for m in c.models:
|
||||
if m.id in seen:
|
||||
merged[seen[m.id]] = m
|
||||
else:
|
||||
seen[m.id] = len(merged)
|
||||
merged.append(m)
|
||||
return self.filter_yaml_models(settings, merged) + self.extra_models(settings)
|
||||
22
application/llm/providers/docsgpt.py
Normal file
22
application/llm/providers/docsgpt.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.docsgpt_provider import DocsGPTAPILLM
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class DocsGPTProvider(Provider):
|
||||
name = "docsgpt"
|
||||
llm_class = DocsGPTAPILLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
# No provider-specific key; the LLM class can use the generic
|
||||
# API_KEY fallback if it needs one. Mirrors model_utils' historical
|
||||
# behavior of returning settings.API_KEY when no specific key exists.
|
||||
return settings.API_KEY
|
||||
|
||||
def is_enabled(self, settings) -> bool:
|
||||
# The hosted DocsGPT model is hidden when the deployment is
|
||||
# pointed at a custom OpenAI-compatible endpoint.
|
||||
return not settings.OPENAI_BASE_URL
|
||||
23
application/llm/providers/google.py
Normal file
23
application/llm/providers/google.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.google_ai import GoogleLLM
|
||||
from application.llm.providers._apikey_or_llm_name import (
|
||||
filter_models_by_llm_name,
|
||||
get_api_key,
|
||||
)
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class GoogleProvider(Provider):
|
||||
name = "google"
|
||||
llm_class = GoogleLLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return get_api_key(settings, self.name, settings.GOOGLE_API_KEY)
|
||||
|
||||
def filter_yaml_models(self, settings, models):
|
||||
return filter_models_by_llm_name(
|
||||
settings, self.name, settings.GOOGLE_API_KEY, models
|
||||
)
|
||||
23
application/llm/providers/groq.py
Normal file
23
application/llm/providers/groq.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.groq import GroqLLM
|
||||
from application.llm.providers._apikey_or_llm_name import (
|
||||
filter_models_by_llm_name,
|
||||
get_api_key,
|
||||
)
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class GroqProvider(Provider):
|
||||
name = "groq"
|
||||
llm_class = GroqLLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return get_api_key(settings, self.name, settings.GROQ_API_KEY)
|
||||
|
||||
def filter_yaml_models(self, settings, models):
|
||||
return filter_models_by_llm_name(
|
||||
settings, self.name, settings.GROQ_API_KEY, models
|
||||
)
|
||||
25
application/llm/providers/huggingface.py
Normal file
25
application/llm/providers/huggingface.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.providers._apikey_or_llm_name import (
|
||||
get_api_key as shared_get_api_key,
|
||||
)
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class HuggingFaceProvider(Provider):
|
||||
"""Surfaces ``huggingface-local`` to the model catalog.
|
||||
|
||||
Not dispatchable through LLMCreator — historically there was no
|
||||
HuggingFaceLLM entry in ``LLMCreator.llms``, and calling ``create_llm``
|
||||
with ``"huggingface"`` raised ``ValueError``. We preserve that
|
||||
behavior: the model appears in ``/api/models`` but selecting it
|
||||
surfaces the same error it always did.
|
||||
"""
|
||||
|
||||
name = "huggingface"
|
||||
llm_class = None # not dispatchable
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return shared_get_api_key(settings, self.name, settings.HUGGINGFACE_API_KEY)
|
||||
19
application/llm/providers/llama_cpp.py
Normal file
19
application/llm/providers/llama_cpp.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.llama_cpp import LlamaCpp
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class LlamaCppProvider(Provider):
|
||||
"""LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog."""
|
||||
|
||||
name = "llama.cpp"
|
||||
llm_class = LlamaCpp
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return settings.API_KEY
|
||||
|
||||
def is_enabled(self, settings) -> bool:
|
||||
return False
|
||||
23
application/llm/providers/novita.py
Normal file
23
application/llm/providers/novita.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.novita import NovitaLLM
|
||||
from application.llm.providers._apikey_or_llm_name import (
|
||||
filter_models_by_llm_name,
|
||||
get_api_key,
|
||||
)
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class NovitaProvider(Provider):
|
||||
name = "novita"
|
||||
llm_class = NovitaLLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return get_api_key(settings, self.name, settings.NOVITA_API_KEY)
|
||||
|
||||
def filter_yaml_models(self, settings, models):
|
||||
return filter_models_by_llm_name(
|
||||
settings, self.name, settings.NOVITA_API_KEY, models
|
||||
)
|
||||
37
application/llm/providers/openai.py
Normal file
37
application/llm/providers/openai.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.openai import OpenAILLM
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class OpenAIProvider(Provider):
|
||||
name = "openai"
|
||||
llm_class = OpenAILLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
if settings.OPENAI_API_KEY:
|
||||
return settings.OPENAI_API_KEY
|
||||
if settings.LLM_PROVIDER == self.name and settings.API_KEY:
|
||||
return settings.API_KEY
|
||||
return None
|
||||
|
||||
def is_enabled(self, settings) -> bool:
|
||||
# When the deployment is pointed at a custom OpenAI-compatible
|
||||
# endpoint (Ollama, LM Studio, ...), the cloud-OpenAI catalog is
|
||||
# suppressed but ``is_enabled`` stays True — necessary so the
|
||||
# filter below still gets to drop the catalog (rather than the
|
||||
# registry skipping the provider entirely and missing the rule).
|
||||
if settings.OPENAI_BASE_URL:
|
||||
return True
|
||||
return bool(self.get_api_key(settings))
|
||||
|
||||
def filter_yaml_models(self, settings, models):
|
||||
# Legacy local-endpoint mode hides the cloud catalog. The
|
||||
# corresponding dynamic models live in OpenAICompatibleProvider.
|
||||
if settings.OPENAI_BASE_URL:
|
||||
return []
|
||||
if not settings.OPENAI_API_KEY:
|
||||
return []
|
||||
return models
|
||||
149
application/llm/providers/openai_compatible.py
Normal file
149
application/llm/providers/openai_compatible.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Generic provider for OpenAI-wire-compatible endpoints.
|
||||
|
||||
Each ``openai_compatible`` YAML file describes one logical endpoint
|
||||
(Mistral, Together, Fireworks, Ollama, ...) with its own
|
||||
``api_key_env`` and ``base_url``. Multiple files can coexist; the
|
||||
plugin produces one set of models per file, each pre-configured with
|
||||
the right credentials and URL.
|
||||
|
||||
The plugin also handles the **legacy** ``OPENAI_BASE_URL`` + ``LLM_NAME``
|
||||
local-endpoint pattern that previously lived in ``OpenAIProvider``. That
|
||||
path generates models dynamically from ``LLM_NAME``, using
|
||||
``OPENAI_BASE_URL`` and ``OPENAI_API_KEY`` as the endpoint config.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
)
|
||||
from application.llm.openai import OpenAILLM
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_model_names(llm_name: Optional[str]) -> List[str]:
|
||||
if not llm_name:
|
||||
return []
|
||||
return [name.strip() for name in llm_name.split(",") if name.strip()]
|
||||
|
||||
|
||||
class OpenAICompatibleProvider(Provider):
|
||||
name = "openai_compatible"
|
||||
llm_class = OpenAILLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
# Per-model: each catalog supplies its own ``api_key_env``. There
|
||||
# is no single plugin-wide key. LLMCreator reads the per-model
|
||||
# ``api_key`` set during catalog materialization.
|
||||
return None
|
||||
|
||||
def is_enabled(self, settings) -> bool:
|
||||
# Concrete enablement happens per catalog (in ``get_models``).
|
||||
# Returning True lets the registry call ``get_models`` so we can
|
||||
# decide per-file whether to contribute models.
|
||||
return True
|
||||
|
||||
def get_models(self, settings, catalogs) -> List[AvailableModel]:
|
||||
out: List[AvailableModel] = []
|
||||
|
||||
for catalog in catalogs:
|
||||
out.extend(self._materialize_yaml_catalog(catalog))
|
||||
|
||||
if settings.OPENAI_BASE_URL and settings.LLM_NAME:
|
||||
out.extend(self._materialize_legacy_local_endpoint(settings))
|
||||
|
||||
return out
|
||||
|
||||
def _materialize_yaml_catalog(self, catalog) -> List[AvailableModel]:
|
||||
"""Resolve one openai_compatible YAML into ready-to-dispatch models.
|
||||
|
||||
Skipped (with an INFO-level log) if ``api_key_env`` resolves to
|
||||
nothing — no point publishing models the user can't actually
|
||||
call. INFO rather than WARNING because operators may legitimately
|
||||
drop multiple provider YAMLs as templates and only set the env
|
||||
vars for the ones they actually use; a missing key is ambiguous,
|
||||
not necessarily a misconfig.
|
||||
"""
|
||||
if not catalog.base_url:
|
||||
raise ValueError(
|
||||
f"{catalog.source_path}: openai_compatible YAML must set "
|
||||
"'base_url'."
|
||||
)
|
||||
if not catalog.api_key_env:
|
||||
raise ValueError(
|
||||
f"{catalog.source_path}: openai_compatible YAML must set "
|
||||
"'api_key_env'."
|
||||
)
|
||||
|
||||
api_key = os.environ.get(catalog.api_key_env)
|
||||
if not api_key:
|
||||
logger.info(
|
||||
"openai_compatible catalog %s skipped: env var %s is not set",
|
||||
catalog.source_path,
|
||||
catalog.api_key_env,
|
||||
)
|
||||
return []
|
||||
|
||||
out: List[AvailableModel] = []
|
||||
for m in catalog.models:
|
||||
out.append(self._with_endpoint(m, catalog.base_url, api_key))
|
||||
return out
|
||||
|
||||
def _materialize_legacy_local_endpoint(self, settings) -> List[AvailableModel]:
|
||||
"""Generate AvailableModels from ``LLM_NAME`` for the legacy
|
||||
``OPENAI_BASE_URL`` deployment pattern (Ollama, LM Studio, ...).
|
||||
|
||||
Preserves the historical ``provider="openai"`` display behavior
|
||||
by setting ``display_provider="openai"``.
|
||||
"""
|
||||
from application.core.model_yaml import resolve_attachment_alias
|
||||
|
||||
attachments = resolve_attachment_alias("image")
|
||||
api_key = settings.OPENAI_API_KEY or settings.API_KEY
|
||||
out: List[AvailableModel] = []
|
||||
for model_name in _parse_model_names(settings.LLM_NAME):
|
||||
out.append(
|
||||
AvailableModel(
|
||||
id=model_name,
|
||||
provider=ModelProvider.OPENAI_COMPATIBLE,
|
||||
display_name=model_name,
|
||||
description=f"Custom OpenAI-compatible model at {settings.OPENAI_BASE_URL}",
|
||||
base_url=settings.OPENAI_BASE_URL,
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=attachments,
|
||||
),
|
||||
api_key=api_key,
|
||||
display_provider="openai",
|
||||
)
|
||||
)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _with_endpoint(
|
||||
model: AvailableModel, base_url: str, api_key: str
|
||||
) -> AvailableModel:
|
||||
"""Return a copy of ``model`` carrying the catalog's endpoint config.
|
||||
|
||||
The catalog-level ``base_url`` is the default; an explicit
|
||||
per-model ``base_url`` in the YAML wins.
|
||||
"""
|
||||
return AvailableModel(
|
||||
id=model.id,
|
||||
provider=model.provider,
|
||||
display_name=model.display_name,
|
||||
description=model.description,
|
||||
capabilities=model.capabilities,
|
||||
enabled=model.enabled,
|
||||
base_url=model.base_url or base_url,
|
||||
display_provider=model.display_provider,
|
||||
api_key=api_key,
|
||||
)
|
||||
23
application/llm/providers/openrouter.py
Normal file
23
application/llm/providers/openrouter.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.open_router import OpenRouterLLM
|
||||
from application.llm.providers._apikey_or_llm_name import (
|
||||
filter_models_by_llm_name,
|
||||
get_api_key,
|
||||
)
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class OpenRouterProvider(Provider):
|
||||
name = "openrouter"
|
||||
llm_class = OpenRouterLLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return get_api_key(settings, self.name, settings.OPEN_ROUTER_API_KEY)
|
||||
|
||||
def filter_yaml_models(self, settings, models):
|
||||
return filter_models_by_llm_name(
|
||||
settings, self.name, settings.OPEN_ROUTER_API_KEY, models
|
||||
)
|
||||
19
application/llm/providers/premai.py
Normal file
19
application/llm/providers/premai.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.premai import PremAILLM
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class PremAIProvider(Provider):
|
||||
"""LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog."""
|
||||
|
||||
name = "premai"
|
||||
llm_class = PremAILLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return settings.API_KEY
|
||||
|
||||
def is_enabled(self, settings) -> bool:
|
||||
return False
|
||||
24
application/llm/providers/sagemaker.py
Normal file
24
application/llm/providers/sagemaker.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from application.llm.sagemaker import SagemakerAPILLM
|
||||
from application.llm.providers.base import Provider
|
||||
|
||||
|
||||
class SagemakerProvider(Provider):
|
||||
"""LLMCreator-only plugin: invocable via LLM_PROVIDER but not in the catalog.
|
||||
|
||||
SageMaker reads its credentials from ``SAGEMAKER_*`` settings inside
|
||||
the LLM class itself; this plugin's ``get_api_key`` exists only for
|
||||
LLMCreator's symmetry.
|
||||
"""
|
||||
|
||||
name = "sagemaker"
|
||||
llm_class = SagemakerAPILLM
|
||||
|
||||
def get_api_key(self, settings) -> Optional[str]:
|
||||
return settings.API_KEY
|
||||
|
||||
def is_enabled(self, settings) -> bool:
|
||||
return False
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user