mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-22 13:25:08 +00:00
Compare commits
54 Commits
feat-bring
...
feat/defau
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
85e885520b | ||
|
|
16695205d5 | ||
|
|
764a23d641 | ||
|
|
0bbcbf4539 | ||
|
|
d041db77e1 | ||
|
|
1de82ca040 | ||
|
|
8f7742c937 | ||
|
|
e3bf6a5471 | ||
|
|
e167cf8247 | ||
|
|
c06646519e | ||
|
|
97a362b703 | ||
|
|
29477b40b3 | ||
|
|
e351f45d88 | ||
|
|
4d6f360e3a | ||
|
|
e245057822 | ||
|
|
e692c645b9 | ||
|
|
b4c4ab68f0 | ||
|
|
d23679dd93 | ||
|
|
1b2239e54b | ||
|
|
5ceb99f946 | ||
|
|
892908cef5 | ||
|
|
99ffe439c7 | ||
|
|
ed87972ca6 | ||
|
|
6ad9022dd3 | ||
|
|
9b8fe2d5d0 | ||
|
|
d1dc8de27c | ||
|
|
a29fa44b51 | ||
|
|
026371d024 | ||
|
|
b0df2a479b | ||
|
|
5eae83af1b | ||
|
|
9c875c83c2 | ||
|
|
e6e671faf1 | ||
|
|
a31ec97bd7 | ||
|
|
ebe752d103 | ||
|
|
8c30c1c880 | ||
|
|
4a598e062c | ||
|
|
e285b47170 | ||
|
|
2d884a3df1 | ||
|
|
b9920731e0 | ||
|
|
f5f4c07e59 | ||
|
|
e87dc42ad0 | ||
|
|
40a30054bc | ||
|
|
707e782ac8 | ||
|
|
2bc0b6946b | ||
|
|
fbd686b725 | ||
|
|
29320eb9fd | ||
|
|
0d2a8e11f4 | ||
|
|
f0c39dec23 | ||
|
|
552bfe016a | ||
|
|
a6a5db631b | ||
|
|
8e9f661efc | ||
|
|
82c71be819 | ||
|
|
318de18d43 | ||
|
|
fb24f9cf5e |
12
README.md
12
README.md
@@ -47,11 +47,13 @@
|
||||
</ul>
|
||||
|
||||
## Roadmap
|
||||
- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
|
||||
- [x] Deep Agents ( October 2025 )
|
||||
- [x] Prompt Templating ( October 2025 )
|
||||
- [x] Full api tooling ( Dec 2025 )
|
||||
- [ ] Agent scheduling ( Jan 2026 )
|
||||
- [x] Agent Workflow Builder with conditional nodes ( February 2026 )
|
||||
- [x] SharePoint & Confluence connectors ( March – April 2026 )
|
||||
- [x] Research mode ( March 2026 )
|
||||
- [x] Postgres migration for user data ( April 2026 )
|
||||
- [x] OpenTelemetry observability ( April 2026 )
|
||||
- [x] Bring Your Own Model (BYOM) ( April 2026 )
|
||||
- [ ] Agent scheduling (RedBeat-backed) ( Q2 2026 )
|
||||
|
||||
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues; it helps us improve DocsGPT!
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ RUN apt-get update && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Verify Python installation and setup symlink
|
||||
RUN if [ -f /usr/bin/python3.12 ]; then \
|
||||
@@ -73,7 +73,7 @@ COPY --from=builder /models /app/models
|
||||
COPY . /app/application
|
||||
|
||||
# Change the ownership of the /app directory to the appuser
|
||||
|
||||
|
||||
RUN mkdir -p /app/application/inputs/local
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
@@ -82,6 +82,11 @@ ENV FLASK_APP=app.py \
|
||||
FLASK_DEBUG=true \
|
||||
PATH="/venv/bin:$PATH"
|
||||
|
||||
ENV MALLOC_ARENA_MAX=2 \
|
||||
OMP_NUM_THREADS=4 \
|
||||
MKL_NUM_THREADS=4 \
|
||||
OPENBLAS_NUM_THREADS=4
|
||||
|
||||
# Expose the port the app runs on
|
||||
EXPOSE 7091
|
||||
|
||||
|
||||
@@ -98,6 +98,7 @@ class BaseAgent(ABC):
|
||||
user_api_key=user_api_key,
|
||||
user=self.user,
|
||||
decoded_token=decoded_token,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
self.attachments = attachments or []
|
||||
@@ -114,6 +115,8 @@ class BaseAgent(ABC):
|
||||
self.compressed_summary = compressed_summary
|
||||
self.current_token_count = 0
|
||||
self.context_limit_reached = False
|
||||
self.conversation_id: Optional[str] = None
|
||||
self.initial_user_id: Optional[str] = None
|
||||
|
||||
@log_activity()
|
||||
def gen(
|
||||
|
||||
356
application/agents/default_tools.py
Normal file
356
application/agents/default_tools.py
Normal file
@@ -0,0 +1,356 @@
|
||||
"""Default chat tools — config-free tools on by default in chats."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Stable uuid5 namespace for synthetic tool ids. Ids derived from it are
# persisted, so this value must never be regenerated.
_DEFAULT_TOOL_NAMESPACE = uuid.UUID("6b1d3f2a-9c84-4d17-bf6e-2a0c5e8d4471")

# Tools whose storage tables foreign-key ``tool_id`` to ``user_tools.id``.
# A synthetic id has no backing row, so a write through one would violate
# the FK. Schema-rot guard:
# ``tests.agents.test_default_tools.TestFkBoundToolsIsInSync``.
_FK_BOUND_TOOLS = frozenset(("notes", "todo_list"))

# Tools that must never be offered to a headless run (scheduled or webhook).
# ``scheduler`` only makes sense from an interactive chat: an LLM calling
# ``schedule_task`` from a scheduled run would chain a new schedule on every
# fire, bounded only by ``SCHEDULE_MAX_PER_USER`` (cost foot-gun, confusing UX).
_HEADLESS_EXCLUDED_TOOLS = frozenset(("scheduler",))

# Agent-selectable builtins: hidden from the Add-Tool catalog (internal=True)
# and exposed to the agent picker through the same synthetic-id machinery as
# default tools. A name may also appear in DEFAULT_CHAT_TOOLS (e.g.
# ``scheduler``); both registries derive ids from ``_DEFAULT_TOOL_NAMESPACE``,
# so the same uuid5 resolves either way (the dual-flag row carries both
# ``default`` and ``builtin``).
BUILTIN_AGENT_TOOLS: tuple = ("scheduler",)

# Module-level memo tables. The tool cache is keyed by tool name; the others
# are keyed by the tuple of configured names they were computed from.
_tool_cache: Dict[str, Optional[Any]] = {}
_ids_cache: Dict[tuple, Dict[str, str]] = {}
_loaded_cache: Dict[tuple, List[str]] = {}
_builtin_ids_cache: Dict[tuple, Dict[str, str]] = {}
_builtin_loaded_cache: Dict[tuple, List[str]] = {}
|
||||
|
||||
|
||||
def _load_tool(tool_name: str) -> Optional[Any]:
    """Return a metadata-only instance of a tool, or None if it has no class.

    Only ``application.agents.tools.<tool_name>`` is imported (not the whole
    package), which avoids the circular import via ``mcp_tool`` →
    ``application.api.user``. Every outcome, including ``None``, is memoized
    in ``_tool_cache``.

    Args:
        tool_name: Module name of the tool under ``application.agents.tools``.

    Returns:
        An instance of the first ``Tool`` subclass found in the module
        (members are visited in name order), built with an empty config, or
        ``None`` when the module cannot be imported, defines no ``Tool``
        subclass, or the class fails to instantiate.
    """
    if tool_name in _tool_cache:
        return _tool_cache[tool_name]

    from application.agents.tools.base import Tool

    module_path = f"application.agents.tools.{tool_name}"
    instance: Optional[Any] = None
    try:
        module = importlib.import_module(module_path)
    except ModuleNotFoundError as exc:
        # ``exc.name`` names the module that was not found. When it is not
        # the tool module itself, the tool exists but one of its own imports
        # is missing; report that rather than silently treating the tool as
        # unimplemented.
        if exc.name != module_path:
            logger.warning(
                "Tool module %r could not be imported (missing module %r); "
                "treating it as unavailable.",
                tool_name,
                exc.name,
            )
        _tool_cache[tool_name] = None
        return None
    for _, obj in inspect.getmembers(module, inspect.isclass):
        if issubclass(obj, Tool) and obj is not Tool:
            try:
                instance = obj({})
            except Exception:
                # Best-effort: a broken tool must not take the caller down,
                # but keep the traceback so the failure can be diagnosed.
                logger.warning(
                    "DEFAULT_CHAT_TOOLS entry %r failed to instantiate; skipping.",
                    tool_name,
                    exc_info=True,
                )
                instance = None
            # Only the first Tool subclass is considered, even if it failed.
            break
    _tool_cache[tool_name] = instance
    return instance
|
||||
|
||||
|
||||
def default_tool_id(tool_name: str) -> str:
    """Derive the deterministic synthetic id (uuid5 string) for a tool name."""
    derived = uuid.uuid5(_DEFAULT_TOOL_NAMESPACE, tool_name)
    return str(derived)
|
||||
|
||||
|
||||
def default_tool_ids() -> Dict[str, str]:
    """Return ``{name: synthetic id}`` for the configured default tools.

    The mapping is memoized per distinct ``settings.DEFAULT_CHAT_TOOLS`` value.
    """
    names = tuple(settings.DEFAULT_CHAT_TOOLS)
    mapping = _ids_cache.get(names)
    if mapping is not None:
        return mapping
    mapping = {tool_name: default_tool_id(tool_name) for tool_name in names}
    _ids_cache[names] = mapping
    return mapping
|
||||
|
||||
|
||||
def is_default_tool_id(tool_id: Any) -> bool:
    """Tell whether ``tool_id`` is one of the synthetic default-tool ids."""
    if tool_id:
        known_ids = set(default_tool_ids().values())
        return str(tool_id) in known_ids
    # Falsy ids (None, "", 0) are never synthetic.
    return False
|
||||
|
||||
|
||||
def default_tool_name_for_id(tool_id: Any) -> Optional[str]:
    """Look up the default-tool name behind a synthetic id; None if unknown."""
    wanted = str(tool_id) if tool_id else ""
    return next(
        (
            tool_name
            for tool_name, synthetic_id in default_tool_ids().items()
            if synthetic_id == wanted
        ),
        None,
    )
|
||||
|
||||
|
||||
def builtin_agent_tool_ids() -> Dict[str, str]:
    """Return ``{name: synthetic id}`` for the agent-selectable builtins.

    The mapping is memoized per distinct ``BUILTIN_AGENT_TOOLS`` value.
    """
    names = tuple(BUILTIN_AGENT_TOOLS)
    mapping = _builtin_ids_cache.get(names)
    if mapping is not None:
        return mapping
    mapping = {tool_name: default_tool_id(tool_name) for tool_name in names}
    _builtin_ids_cache[names] = mapping
    return mapping
|
||||
|
||||
|
||||
def is_builtin_agent_tool_id(tool_id: Any) -> bool:
    """Tell whether ``tool_id`` is a synthetic id of an agent-selectable builtin."""
    if tool_id:
        known_ids = set(builtin_agent_tool_ids().values())
        return str(tool_id) in known_ids
    # Falsy ids (None, "", 0) are never synthetic.
    return False
|
||||
|
||||
|
||||
def builtin_agent_tool_name_for_id(tool_id: Any) -> Optional[str]:
    """Look up the builtin tool name behind a synthetic id; None if unknown."""
    wanted = str(tool_id) if tool_id else ""
    return next(
        (
            tool_name
            for tool_name, synthetic_id in builtin_agent_tool_ids().items()
            if synthetic_id == wanted
        ),
        None,
    )
|
||||
|
||||
|
||||
def synthesized_tool_name_for_id(tool_id: Any) -> Optional[str]:
    """Resolve a synthetic id to its tool name, trying defaults then builtins."""
    default_name = default_tool_name_for_id(tool_id)
    if default_name:
        return default_name
    return builtin_agent_tool_name_for_id(tool_id)
|
||||
|
||||
|
||||
def is_synthesized_tool_id(tool_id: Any) -> bool:
    """Tell whether ``tool_id`` is synthetic (default chat tool or agent builtin)."""
    if is_default_tool_id(tool_id):
        return True
    return is_builtin_agent_tool_id(tool_id)
|
||||
|
||||
|
||||
def loaded_default_tools() -> List[str]:
    """List the configured default-tool names whose tool actually loads."""
    # Called per request, so it stays silent and memoized; the skip notice
    # for names that do not load is emitted by ``validate_default_chat_tools``.
    names = tuple(settings.DEFAULT_CHAT_TOOLS)
    usable = _loaded_cache.get(names)
    if usable is not None:
        return usable
    usable = [tool_name for tool_name in names if _load_tool(tool_name) is not None]
    _loaded_cache[names] = usable
    return usable
|
||||
|
||||
|
||||
def loaded_builtin_agent_tools() -> List[str]:
    """List the builtin agent-tool names whose tool actually loads (memoized)."""
    names = tuple(BUILTIN_AGENT_TOOLS)
    usable = _builtin_loaded_cache.get(names)
    if usable is not None:
        return usable
    usable = [tool_name for tool_name in names if _load_tool(tool_name) is not None]
    _builtin_loaded_cache[names] = usable
    return usable
|
||||
|
||||
|
||||
def validate_default_chat_tools() -> List[str]:
    """Check ``DEFAULT_CHAT_TOOLS`` at startup and return the usable names.

    Names with no loadable tool are reported at debug level and skipped.

    Raises:
        ValueError: if a usable entry is FK-bound to ``user_tools`` or
            declares any required config field.
    """
    missing = [
        tool_name
        for tool_name in settings.DEFAULT_CHAT_TOOLS
        if _load_tool(tool_name) is None
    ]
    if missing:
        logger.debug(
            "DEFAULT_CHAT_TOOLS entries with no loaded tool, skipped: %s. "
            "Each activates automatically once its tool exists.",
            ", ".join(missing),
        )

    active = loaded_default_tools()
    for tool_name in active:
        if tool_name in _FK_BOUND_TOOLS:
            raise ValueError(
                f"DEFAULT_CHAT_TOOLS entry {tool_name!r} has a storage table "
                f"that foreign-keys tool_id to user_tools; a default tool "
                f"has a synthetic id with no user_tools row, so it would "
                f"fail at write time. It cannot be defaulted on."
            )
        specs = _load_tool(tool_name).get_config_requirements() or {}
        # Only dict-shaped specs that flag themselves ``required`` count.
        mandatory = [
            field
            for field, spec in specs.items()
            if isinstance(spec, dict) and spec.get("required")
        ]
        if mandatory:
            raise ValueError(
                f"DEFAULT_CHAT_TOOLS entry {tool_name!r} requires config "
                f"fields {mandatory}; only config-free tools may be "
                "defaulted on."
            )

    if active:
        logger.info("Default chat tools active: %s", ", ".join(active))
    return active
|
||||
|
||||
|
||||
def _tool_display(tool_name: str) -> str:
    """Return the first docstring line of the tool, falling back to its name."""
    instance = _load_tool(tool_name)
    if not instance:
        return tool_name
    docstring = (instance.__doc__ or "").strip()
    if not docstring:
        return tool_name
    headline = docstring.partition("\n")[0].strip()
    return headline or tool_name
|
||||
|
||||
|
||||
def _tool_description(tool_name: str) -> str:
    """Return the tool docstring minus its first line ("" when there is none)."""
    instance = _load_tool(tool_name)
    docstring = (instance.__doc__ or "").strip() if instance else ""
    _, newline, remainder = docstring.partition("\n")
    # No newline means the docstring is a single line: nothing after it.
    return remainder.strip() if newline else ""
|
||||
|
||||
|
||||
def synthesize_default_tool(tool_name: str) -> Optional[Dict[str, Any]]:
    """Build an in-memory ``user_tools``-shaped row for a default tool.

    Returns None when the tool cannot be loaded. The row carries the same
    synthetic id under both ``id`` and ``_id`` and is flagged ``default``.
    """
    instance = _load_tool(tool_name)
    if instance is None:
        return None

    tool_id = default_tool_id(tool_name)
    label = _tool_display(tool_name)
    summary = _tool_description(tool_name)
    actions = instance.get_actions_metadata() or []
    return {
        "id": tool_id,
        "_id": tool_id,
        "name": tool_name,
        "display_name": label,
        "custom_name": "",
        "description": summary,
        "config": {},
        "config_requirements": {},
        "actions": actions,
        "status": True,
        "default": True,
    }
|
||||
|
||||
|
||||
def synthesize_builtin_agent_tool(tool_name: str) -> Optional[Dict[str, Any]]:
    """Build an in-memory ``user_tools``-shaped row for a builtin agent tool.

    Returns None when the tool cannot be loaded. The row carries the same
    synthetic id under both ``id`` and ``_id``, with ``default`` False and
    ``builtin`` True.
    """
    instance = _load_tool(tool_name)
    if instance is None:
        return None

    tool_id = default_tool_id(tool_name)
    label = _tool_display(tool_name)
    summary = _tool_description(tool_name)
    actions = instance.get_actions_metadata() or []
    return {
        "id": tool_id,
        "_id": tool_id,
        "name": tool_name,
        "display_name": label,
        "custom_name": "",
        "description": summary,
        "config": {},
        "config_requirements": {},
        "actions": actions,
        "status": True,
        "default": False,
        "builtin": True,
    }
|
||||
|
||||
|
||||
def synthesize_tool_by_name(tool_name: str) -> Optional[Dict[str, Any]]:
    """Synthesize the row for a tool name, preferring the builtin shape.

    Names listed in ``BUILTIN_AGENT_TOOLS`` get the builtin row; every other
    name is treated as a default tool.
    """
    builder = (
        synthesize_builtin_agent_tool
        if tool_name in BUILTIN_AGENT_TOOLS
        else synthesize_default_tool
    )
    return builder(tool_name)
|
||||
|
||||
|
||||
def disabled_default_tools(user_doc: Optional[Dict[str, Any]]) -> List[str]:
    """Extract the user's default-tool opt-out list from ``tool_preferences``.

    Any unexpected shape (non-dict document, non-dict preferences, non-list
    opt-out value) yields an empty list. Entries are coerced to ``str``.
    """
    if isinstance(user_doc, dict):
        preferences = user_doc.get("tool_preferences") or {}
        if isinstance(preferences, dict):
            opted_out = preferences.get("disabled_default_tools") or []
            if isinstance(opted_out, list):
                return list(map(str, opted_out))
    return []
|
||||
|
||||
|
||||
def synthesized_default_tools(
    user_doc: Optional[Dict[str, Any]] = None,
    *,
    headless: bool = False,
) -> List[Dict[str, Any]]:
    """Return the synthesized default-tool rows for an agentless chat.

    Tools the user opted out of are omitted. With ``headless=True`` the
    chat-only tools (e.g. ``scheduler``) are omitted as well, so a scheduled
    or webhook LLM cannot re-schedule itself.
    """
    # Agent-bound chats must NOT call this; they resolve exactly
    # ``agents.tools``.
    opted_out = set(disabled_default_tools(user_doc))
    result: List[Dict[str, Any]] = []
    for tool_name in loaded_default_tools():
        excluded = tool_name in opted_out or (
            headless and tool_name in _HEADLESS_EXCLUDED_TOOLS
        )
        if excluded:
            continue
        synthesized = synthesize_default_tool(tool_name)
        if synthesized is not None:
            result.append(synthesized)
    return result
|
||||
|
||||
|
||||
def is_headless_excluded_tool(tool_name: Optional[str]) -> bool:
    """Tell whether ``tool_name`` must be hidden from headless runs."""
    if not tool_name:
        return False
    return tool_name in _HEADLESS_EXCLUDED_TOOLS
|
||||
|
||||
|
||||
def default_tools_for_management(
    user_doc: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """Return every loaded default tool, with ``status`` set to on/off."""
    # Unlike ``synthesized_default_tools`` (the chat toolset), opted-out
    # tools are kept so the management UI can render their toggle.
    opted_out = set(disabled_default_tools(user_doc))
    result: List[Dict[str, Any]] = []
    for tool_name in loaded_default_tools():
        synthesized = synthesize_default_tool(tool_name)
        if synthesized is not None:
            synthesized["status"] = tool_name not in opted_out
            result.append(synthesized)
    return result
|
||||
|
||||
|
||||
def builtin_agent_tools_for_management() -> List[Dict[str, Any]]:
    """Return a row for every loaded agent-builtin tool (no per-user state)."""
    candidates = map(synthesize_builtin_agent_tool, loaded_builtin_agent_tools())
    return [synthesized for synthesized in candidates if synthesized is not None]
|
||||
|
||||
|
||||
def resolve_tool_by_id(
    tool_id: Any,
    user: Optional[str],
    *,
    user_tools_repo: Any = None,
) -> Optional[Dict[str, Any]]:
    """Resolve a tool by id: default/builtin synthetic id, else user_tools row.

    Dual-registered tools (e.g. ``scheduler``) get both flags on the resolved
    row so callers can branch on either path without losing the discriminator.

    Args:
        tool_id: Synthetic id or ``user_tools`` id; stringified for the
            repository lookup.
        user: Owner used for the repository lookup.
        user_tools_repo: Object exposing ``get_any(tool_id, user)``; when
            omitted, only synthetic ids can be resolved.

    Returns:
        The synthesized or stored row, or ``None`` when the id is unknown,
        the matching tool cannot be synthesized, or no repository/user is
        available for a non-synthetic id.
    """
    default_name = default_tool_name_for_id(tool_id)
    builtin_name = builtin_agent_tool_name_for_id(tool_id)

    if default_name is not None:
        row = synthesize_default_tool(default_name)
        # Only a real row may be flagged; flagging a placeholder dict would
        # turn a failed synthesis into a bogus ``{"builtin": True}`` result.
        if row is not None and builtin_name is not None:
            row["builtin"] = True
        return row
    if builtin_name is not None:
        return synthesize_builtin_agent_tool(builtin_name)
    if user_tools_repo is None or not user:
        return None
    return user_tools_repo.get_any(str(tool_id), user)
|
||||
173
application/agents/headless_runner.py
Normal file
173
application/agents/headless_runner.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Shared headless agent runner used by webhooks and scheduled runs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
from application.core.settings import settings
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_owner(agent_config: Dict[str, Any]) -> Optional[str]:
    """Return the agent owner from ``user_id``, falling back to ``user``."""
    owner = agent_config.get("user_id")
    if not owner:
        owner = agent_config.get("user")
    return owner
|
||||
|
||||
|
||||
def _resolve_agent_id(agent_config: Dict[str, Any]) -> Optional[str]:
    """Return the agent id (``id``, else ``_id``) as a string, or None."""
    identifier = agent_config.get("id") or agent_config.get("_id")
    if identifier:
        return str(identifier)
    return None
|
||||
|
||||
|
||||
def run_agent_headless(
    agent_config: Dict[str, Any],
    query: str,
    *,
    tool_allowlist: Optional[Iterable[str]] = None,
    model_id_override: Optional[str] = None,
    endpoint: str = "headless",
    chat_history: Optional[List[Dict[str, Any]]] = None,
    conversation_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Run an agent with no live client and return a structured outcome.

    Args:
        agent_config: Agent row; read for owner, source, retriever, chunks,
            prompt, key, agent type, JSON schema and default model.
        query: Text passed to the retriever search and to ``agent.gen``.
        tool_allowlist: Tool names handed to the headless ``ToolExecutor``.
        model_id_override: Model id tried before the agent's own default.
        endpoint: Endpoint label forwarded to the agent.
        chat_history: Prior turns given to the retriever and the agent.
        conversation_id: When set, stamped on the tool executor and agent.

    Returns:
        Dict with ``answer``, ``thought``, ``sources``, ``tool_calls``,
        ``prompt_tokens``, ``generated_tokens``, ``denied``, ``error_type``
        and ``model_id``.

    Raises:
        ValueError: if the config carries neither ``user_id`` nor ``user``.
    """
    from application.core.model_utils import (
        get_api_key_for_provider,
        get_default_model_id,
        get_provider_from_model_id,
        validate_model_id,
    )
    from application.utils import calculate_doc_token_budget

    owner = _resolve_owner(agent_config)
    if not owner:
        raise ValueError("Agent config is missing user_id; cannot run headless.")
    token = {"sub": owner}

    # Resolve the source; a stored source row may override the retriever kind.
    retriever_name = agent_config.get("retriever", "classic")
    source_ref = agent_config.get("source_id") or agent_config.get("source")
    active_docs: Any = {}
    if source_ref:
        with db_readonly() as conn:
            source_row = SourcesRepository(conn).get(str(source_ref), owner)
        if source_row:
            active_docs = str(source_row["id"])
            retriever_name = source_row.get("retriever", retriever_name)
    source = {"active_docs": active_docs}

    chunk_count = int(agent_config.get("chunks", 2) or 2)
    prompt_ref = agent_config.get("prompt_id", "default")
    user_api_key = agent_config.get("key")
    agent_id = _resolve_agent_id(agent_config)
    agent_kind = agent_config.get("agent_type", "classic")
    json_schema = agent_config.get("json_schema")
    prompt = get_prompt(prompt_ref)

    # Pick the model: override, then agent default, then the system default.
    requested_model = model_id_override or agent_config.get("default_model_id") or ""
    if requested_model and validate_model_id(requested_model, user_id=owner):
        model_id = requested_model
    else:
        model_id = get_default_model_id()
        if requested_model:
            logger.warning(
                "Agent %s references unknown model_id %r; falling back to %r",
                agent_id, requested_model, model_id,
            )

    if model_id:
        provider = get_provider_from_model_id(model_id, user_id=owner)
    else:
        provider = settings.LLM_PROVIDER
    system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
    doc_token_limit = calculate_doc_token_budget(model_id=model_id, user_id=owner)

    retriever = RetrieverCreator.create_retriever(
        retriever_name,
        source=source,
        chat_history=chat_history or [],
        prompt=prompt,
        chunks=chunk_count,
        doc_token_limit=doc_token_limit,
        model_id=model_id,
        user_api_key=user_api_key,
        agent_id=agent_id,
        decoded_token=token,
    )
    # Retrieval is best-effort: a failure is logged and the run continues
    # with no documents.
    retrieved_docs: List[Dict[str, Any]] = []
    try:
        found = retriever.search(query)
        if found:
            retrieved_docs = found
    except Exception as exc:
        logger.warning("Headless retrieve failed: %s", exc)

    tool_executor = ToolExecutor(
        user_api_key=user_api_key,
        user=owner,
        decoded_token=token,
        agent_id=agent_id,
        headless=True,
        tool_allowlist=list(tool_allowlist or []),
    )
    if conversation_id:
        tool_executor.conversation_id = str(conversation_id)

    agent = AgentCreator.create_agent(
        agent_kind,
        endpoint=endpoint,
        llm_name=provider or settings.LLM_PROVIDER,
        model_id=model_id,
        api_key=system_api_key,
        agent_id=agent_id,
        user_api_key=user_api_key,
        prompt=prompt,
        chat_history=chat_history or [],
        retrieved_docs=retrieved_docs,
        decoded_token=token,
        attachments=[],
        json_schema=json_schema,
        tool_executor=tool_executor,
    )
    if conversation_id:
        agent.conversation_id = str(conversation_id)

    # Drain the event stream; each dict event contributes to at most one
    # accumulator, checked in the order answer, sources, tool_calls, thought.
    answer_parts: List[str] = []
    thought_parts: List[str] = []
    sources_log: List[Dict[str, Any]] = []
    tool_calls: List[Dict[str, Any]] = []
    for event in agent.gen(query=query):
        if not isinstance(event, dict):
            continue
        if "answer" in event:
            answer_parts.append(str(event["answer"]))
        elif "sources" in event:
            sources_log.extend(event["sources"])
        elif "tool_calls" in event:
            tool_calls.extend(event["tool_calls"])
        elif "thought" in event:
            thought_parts.append(str(event["thought"]))
    answer_full = "".join(answer_parts)
    thought = "".join(thought_parts)

    # Denials only count as an error when no answer text was produced.
    denied = list(getattr(tool_executor, "headless_denials", []))
    if denied and not answer_full.strip():
        error_type = "tool_not_allowed"
    else:
        error_type = None

    # Use the LLM accumulator (gen_token_usage / stream_token_usage decorators);
    # current_token_count is a context-size sentinel, not a usage tally.
    llm = getattr(agent, "llm", None)
    usage = getattr(llm, "token_usage", None) or {}
    prompt_tokens = int(usage.get("prompt_tokens", 0) or 0)
    generated_tokens = int(usage.get("generated_tokens", 0) or 0)

    return {
        "answer": answer_full,
        "thought": thought,
        "sources": sources_log,
        "tool_calls": tool_calls,
        "prompt_tokens": prompt_tokens,
        "generated_tokens": generated_tokens,
        "denied": denied,
        "error_type": error_type,
        "model_id": model_id,
    }
|
||||
131
application/agents/scheduler_utils.py
Normal file
131
application/agents/scheduler_utils.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Cron/tz computations for the scheduler (shared by dispatcher, routes, and tool)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from croniter import croniter
|
||||
|
||||
|
||||
# "<digits><unit>" with optional whitespace around either part; the unit is
# one of s/m/h/d in any letter case.
_DELAY_RE = re.compile(r"^\s*(\d+)\s*(s|m|h|d)\s*$", re.IGNORECASE)
# Seconds per (lower-cased) delay unit.
_DELAY_MULTIPLIERS = {"s": 1, "m": 60, "h": 60 * 60, "d": 24 * 60 * 60}
|
||||
|
||||
|
||||
class ScheduleValidationError(ValueError):
    """Signals an invalid cron expression, ``run_at`` timestamp, or delay."""
|
||||
|
||||
|
||||
def resolve_timezone(tz_name: Optional[str]) -> ZoneInfo:
    """Return a ``ZoneInfo`` for ``tz_name`` (default UTC).

    Args:
        tz_name: IANA timezone key; ``None``, empty, or whitespace-only
            values resolve to UTC.

    Raises:
        ScheduleValidationError: if the key does not name a usable timezone.
    """
    name = (tz_name or "UTC").strip() or "UTC"
    try:
        return ZoneInfo(name)
    except (ZoneInfoNotFoundError, ValueError, OSError) as exc:
        # Besides "not found", ZoneInfo rejects malformed keys (absolute or
        # non-normalized paths) with ValueError, and a key that names a
        # directory can surface as an OSError. All of them mean the caller
        # supplied an unusable timezone.
        raise ScheduleValidationError(f"Unknown timezone: {name}") from exc
|
||||
|
||||
|
||||
def parse_cron(expression: str) -> None:
    """Validate a 5-field cron expression.

    Raises:
        ScheduleValidationError: if the expression is missing, not a string,
            does not have exactly five fields, or is rejected by croniter.
    """
    if not (expression and isinstance(expression, str)):
        raise ScheduleValidationError("Cron expression is required.")
    if len(expression.split()) != 5:
        raise ScheduleValidationError("Cron expression must have 5 fields.")
    try:
        # croniter defers some malformed inputs until the first tick is
        # computed, so ask for one here.
        croniter(expression, datetime.now(timezone.utc)).get_next(datetime)
    except (ValueError, KeyError) as exc:
        raise ScheduleValidationError(f"Invalid cron expression: {exc}") from exc
|
||||
|
||||
|
||||
_CRON_INTERVAL_WINDOW = 64
|
||||
|
||||
|
||||
def cron_interval_seconds(expression: str, tz_name: Optional[str]) -> int:
    """Return the smallest gap between ticks in a rolling window (enforces SCHEDULE_MIN_INTERVAL).

    The next ``_CRON_INTERVAL_WINDOW`` ticks are examined because bursty
    expressions such as ``* 9 * * *`` have tiny gaps inside a burst and huge
    gaps between bursts; comparing only two adjacent ticks could miss the
    small one. Returns 0 when no positive gap is found.
    """
    parse_cron(expression)
    zone = resolve_timezone(tz_name)
    start_local = datetime.now(timezone.utc).astimezone(zone)
    schedule = croniter(expression, start_local)

    ticks = [schedule.get_next(datetime) for _ in range(_CRON_INTERVAL_WINDOW)]
    gaps = (
        int((later - earlier).total_seconds())
        for earlier, later in zip(ticks, ticks[1:])
    )
    positive_gaps = [gap for gap in gaps if gap > 0]
    return min(positive_gaps) if positive_gaps else 0
|
||||
|
||||
|
||||
def next_cron_run(
    expression: str,
    tz_name: Optional[str],
    after: Optional[datetime] = None,
) -> datetime:
    """Return the next fire time strictly after ``after`` (UTC, tz-aware).

    The cadence is evaluated in the schedule's IANA timezone so that DST
    boundaries land on the intended local clock time (e.g. 9 AM Warsaw stays
    9 AM across the jump). ``after`` defaults to the current time; a naive
    ``after`` is taken to be UTC.
    """
    parse_cron(expression)
    zone = resolve_timezone(tz_name)

    if after is None:
        reference = datetime.now(timezone.utc)
    else:
        reference = after
    if reference.tzinfo is None:
        reference = reference.replace(tzinfo=timezone.utc)

    schedule = croniter(expression, reference.astimezone(zone))
    local_next = schedule.get_next(datetime)
    return local_next.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def parse_delay(delay: str) -> timedelta:
    """Parse a duration like ``30m`` / ``2h`` / ``1d`` into a timedelta.

    Args:
        delay: Positive integer followed by a unit (``s``, ``m``, ``h`` or
            ``d``, any letter case); surrounding whitespace is ignored.

    Raises:
        ScheduleValidationError: if ``delay`` is not a string, does not match
            the expected shape, is not positive, or is too large to represent.
    """
    if not isinstance(delay, str):
        raise ScheduleValidationError("delay must be a string like '30m' or '2h'.")
    match = _DELAY_RE.match(delay)
    if not match:
        raise ScheduleValidationError(
            "delay must look like '30s', '15m', '2h', or '1d'."
        )
    amount, unit = int(match.group(1)), match.group(2).lower()
    if amount <= 0:
        raise ScheduleValidationError("delay must be positive.")
    try:
        return timedelta(seconds=amount * _DELAY_MULTIPLIERS[unit])
    except OverflowError as exc:
        # The pattern accepts any number of digits, so the amount can exceed
        # what a timedelta can hold; report it as a validation failure.
        raise ScheduleValidationError("delay is too large.") from exc
|
||||
|
||||
|
||||
def parse_run_at(run_at: str, tz_name: Optional[str] = None) -> datetime:
    """Parse an ISO 8601 timestamp; naive values resolve in ``tz_name``.

    Naive values inside the DST "fall back" hour resolve to the earlier instance
    (zoneinfo default fold=0); pass an explicit offset to select the later one.

    Args:
        run_at: ISO 8601 timestamp; a ``Z`` suffix is accepted as UTC.
        tz_name: Timezone applied to naive timestamps.

    Returns:
        The timestamp converted to a tz-aware UTC datetime.

    Raises:
        ScheduleValidationError: if ``run_at`` is not a non-blank string,
            cannot be parsed, names an unknown timezone, or cannot be
            represented once converted to UTC.
    """
    if not isinstance(run_at, str) or not run_at.strip():
        raise ScheduleValidationError("run_at must be an ISO 8601 string.")
    try:
        parsed = datetime.fromisoformat(run_at.strip().replace("Z", "+00:00"))
    except ValueError as exc:
        raise ScheduleValidationError(f"Invalid run_at: {exc}") from exc
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=resolve_timezone(tz_name))
    try:
        return parsed.astimezone(timezone.utc)
    except OverflowError as exc:
        # A timestamp at the edge of the datetime range can fall outside it
        # once shifted to UTC; treat that as invalid input.
        raise ScheduleValidationError(f"Invalid run_at: {exc}") from exc
|
||||
|
||||
|
||||
def clamp_once_horizon(run_at: datetime, max_horizon_seconds: int) -> None:
    """Reject a ``run_at`` that is in the past or beyond the once-task horizon.

    A non-positive ``max_horizon_seconds`` disables the horizon check.

    Raises:
        ScheduleValidationError: if ``run_at`` is not after the current time
            or lies further ahead than the allowed horizon.
    """
    current = datetime.now(timezone.utc)
    if run_at <= current:
        raise ScheduleValidationError("run_at is in the past.")
    if max_horizon_seconds > 0:
        horizon = timedelta(seconds=max_horizon_seconds)
        if run_at - current > horizon:
            raise ScheduleValidationError(
                "run_at is beyond the maximum allowed scheduling horizon."
            )
|
||||
@@ -1,18 +1,113 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from application.agents.default_tools import (
|
||||
is_headless_excluded_tool,
|
||||
resolve_tool_by_id,
|
||||
synthesized_default_tools,
|
||||
)
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.security.encryption import decrypt_credentials
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _record_proposed(
    call_id: str,
    tool_name: str,
    action_name: str,
    arguments: Any,
    *,
    tool_id: Optional[str] = None,
) -> bool:
    """Journal a tool call in the ``proposed`` state on a best-effort basis.

    Any failure while writing is logged and swallowed so the tool call can
    still run when the journal is unreachable.

    Args:
        call_id: Identifier of the tool call being journaled.
        tool_name: Name of the tool being invoked.
        action_name: Action requested on the tool.
        arguments: Arguments forwarded to the repository as-is.
        tool_id: Optional tool row id; forwarded only when it is truthy and
            passes ``looks_like_uuid``, otherwise ``None`` is written.

    Returns:
        True when the row is journaled (newly created or already present),
        False when the write raised.
    """
    try:
        with db_session() as conn:
            attempts = ToolCallAttemptsRepository(conn)
            journal_tool_id = (
                tool_id if tool_id and looks_like_uuid(tool_id) else None
            )
            created = attempts.record_proposed(
                call_id,
                tool_name,
                action_name,
                arguments,
                tool_id=journal_tool_id,
            )
        if not created:
            # A falsy result is treated as a call_id collision: the existing
            # row stays untouched and the call still counts as journaled.
            logger.warning(
                "tool_call_attempts duplicate call_id=%s; existing row left in place",
                call_id,
                extra={"alert": "tool_call_id_collision", "call_id": call_id},
            )
    except Exception:
        logger.exception("tool_call_attempts proposed write failed for %s", call_id)
        return False
    return True
|
||||
|
||||
|
||||
def _mark_executed(
    call_id: str,
    result: Any,
    *,
    message_id: Optional[str] = None,
    artifact_id: Optional[str] = None,
    proposed_ok: bool = True,
    tool_name: Optional[str] = None,
    action_name: Optional[str] = None,
    arguments: Any = None,
    tool_id: Optional[str] = None,
) -> None:
    """Move the journal row for ``call_id`` to ``executed``.

    When the earlier ``proposed`` write failed (``proposed_ok`` is False),
    or the update reports that nothing was updated, a complete row is
    upserted directly in ``executed`` so the attempt stays visible to the
    journal. Write failures are logged and swallowed.

    Args:
        call_id: Identifier of the tool call.
        result: Tool result to store on the row.
        message_id: Optional message id to attach.
        artifact_id: Optional artifact id to attach.
        proposed_ok: Whether the ``proposed`` row was journaled earlier.
        tool_name: Tool name for the fallback row (``"unknown"`` if falsy).
        action_name: Action name for the fallback row (``""`` if falsy).
        arguments: Arguments for the fallback row (``{}`` if ``None``).
        tool_id: Tool row id for the fallback row; stored only when it
            passes ``looks_like_uuid``.
    """
    try:
        with db_session() as conn:
            attempts = ToolCallAttemptsRepository(conn)
            # Short-circuit: the update is only attempted when a proposed
            # row is expected to exist.
            if proposed_ok and attempts.mark_executed(
                call_id,
                result,
                message_id=message_id,
                artifact_id=artifact_id,
            ):
                return
            # Fallback synthesizes the row so the journal isn't lost.
            fallback_arguments = {} if arguments is None else arguments
            fallback_tool_id = (
                tool_id if tool_id and looks_like_uuid(tool_id) else None
            )
            attempts.upsert_executed(
                call_id,
                tool_name=tool_name or "unknown",
                action_name=action_name or "",
                arguments=fallback_arguments,
                result=result,
                tool_id=fallback_tool_id,
                message_id=message_id,
                artifact_id=artifact_id,
            )
    except Exception:
        logger.exception("tool_call_attempts executed write failed for %s", call_id)
|
||||
|
||||
|
||||
def _mark_failed(call_id: str, error: str) -> None:
    """Record ``error`` against the journal row for ``call_id``.

    Best-effort: any exception raised while writing is logged and
    swallowed rather than propagated to the caller.
    """
    try:
        with db_session() as session:
            attempts = ToolCallAttemptsRepository(session)
            attempts.mark_failed(call_id, error)
    except Exception:
        logger.exception("tool_call_attempts failed-write failed for %s", call_id)
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""Handles tool discovery, preparation, and execution.
|
||||
|
||||
@@ -24,16 +119,31 @@ class ToolExecutor:
|
||||
user_api_key: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
decoded_token: Optional[Dict] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
*,
|
||||
headless: bool = False,
|
||||
tool_allowlist: Optional[List[str]] = None,
|
||||
):
|
||||
self.user_api_key = user_api_key
|
||||
self.user = user
|
||||
self.decoded_token = decoded_token
|
||||
self.agent_id = agent_id
|
||||
# Headless mode (scheduled / webhook): no human to resolve a pause,
|
||||
# so check_pause returns headless_denied sentinels instead.
|
||||
self.headless = bool(headless)
|
||||
# Tool-instance ids pre-authorized for headless approval-gated execution.
|
||||
self.tool_allowlist: set = (
|
||||
{str(x) for x in tool_allowlist} if tool_allowlist else set()
|
||||
)
|
||||
self.tool_calls: List[Dict] = []
|
||||
self._loaded_tools: Dict[str, object] = {}
|
||||
self.conversation_id: Optional[str] = None
|
||||
self.message_id: Optional[str] = None
|
||||
self.client_tools: Optional[List[Dict]] = None
|
||||
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
|
||||
self._tool_to_name: Dict[Tuple[str, str], str] = {}
|
||||
# Filled by the LLMHandler.handle_tool_calls headless loop.
|
||||
self.headless_denials: List[Dict] = []
|
||||
|
||||
def get_tools(self) -> Dict[str, Dict]:
|
||||
"""Load tool configs from DB based on user context.
|
||||
@@ -50,29 +160,54 @@ class ToolExecutor:
|
||||
return tools
|
||||
|
||||
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
    """Resolve an agent's toolset — exactly ``agents.tools``, no defaults.

    Args:
        api_key: Agent API key used to look the agent up.

    Returns:
        Mapping of ``str(tool["id"])`` to the tool row for every id in the
        agent's ``tools`` list that resolves. Empty when the agent is not
        found or lists no tools. In headless mode, tools flagged by
        ``is_headless_excluded_tool`` are dropped.
    """
    # Per-operation session: the answer pipeline spans a long-lived
    # generator; wrapping it in a single connection would pin a PG
    # conn for the whole stream. Open, fetch, close.
    with db_readonly() as conn:
        agent_data = AgentsRepository(conn).find_by_key(api_key)
        tool_ids = agent_data.get("tools", []) if agent_data else []
        if not tool_ids:
            return {}
        tools_repo = UserToolsRepository(conn)
        owner = (
            (agent_data.get("user_id") or agent_data.get("user"))
            if agent_data
            else None
        )
        tools: List[Dict] = []
        for tid in tool_ids:
            row = resolve_tool_by_id(tid, owner, user_tools_repo=tools_repo)
            if row is None:
                # Unresolvable ids are skipped rather than failing the run.
                continue
            # Headless runs (scheduled / webhook) drop chat-only tools
            # like ``scheduler`` so a fire-time LLM can't chain schedules.
            if self.headless and is_headless_excluded_tool(row.get("name")):
                continue
            tools.append(row)
    return {str(tool["id"]): tool for tool in tools}
|
||||
|
||||
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict]:
    """Resolve an agentless chat's toolset: explicit user tools plus defaults.

    Args:
        user: User id whose active tools are listed.

    Returns:
        Mapping of tool key to tool row. Explicit user tools are keyed by
        their position (as a string) in the filtered list. When no
        ``agent_id`` is set, rows from ``synthesized_default_tools`` are
        added under their own ``id``.
    """
    with db_readonly() as conn:
        user_tools = UserToolsRepository(conn).list_active_for_user(user)
        # The user document is only needed to synthesize defaults, which
        # happens for agentless runs only.
        user_doc = (
            UsersRepository(conn).get(user) if self.agent_id is None else None
        )
    # Headless agentless runs (e.g. scheduled fire) drop chat-only
    # tools (``scheduler``) from explicit user_tools too.
    filtered_user_tools = [
        t for t in user_tools
        if not (self.headless and is_headless_excluded_tool(t.get("name")))
    ]
    # Index keys (ints) and synthetic uuid5 keys can't collide.
    tools: Dict[str, Dict] = {
        str(i): tool for i, tool in enumerate(filtered_user_tools)
    }
    if self.agent_id is None:
        for default_row in synthesized_default_tools(
            user_doc, headless=self.headless,
        ):
            tools[str(default_row["id"])] = default_row
    return tools
|
||||
|
||||
def merge_client_tools(
|
||||
self, tools_dict: Dict, client_tools: List[Dict]
|
||||
@@ -210,9 +345,11 @@ class ToolExecutor:
|
||||
def check_pause(
|
||||
self, tools_dict: Dict, call, llm_class_name: str
|
||||
) -> Optional[Dict]:
|
||||
"""Check if a tool call requires pausing for approval or client execution.
|
||||
"""Return a pending-action dict (approval / client / headless_denied) or None.
|
||||
|
||||
Returns a dict describing the pending action if pause is needed, None otherwise.
|
||||
In headless mode the dict's pause_type is ``headless_denied`` so the
|
||||
upstream loop synthesizes a tool result instead of pausing (nothing can
|
||||
resume a scheduled / webhook run).
|
||||
"""
|
||||
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
@@ -223,9 +360,26 @@ class ToolExecutor:
|
||||
return None # Will be handled as error by execute()
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
arguments = call_args if isinstance(call_args, dict) else {}
|
||||
|
||||
# Client-side tools
|
||||
if tool_data.get("client_side"):
|
||||
if self.headless:
|
||||
return {
|
||||
"call_id": call_id,
|
||||
"name": llm_name,
|
||||
"tool_name": tool_data.get("name", "unknown"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"llm_name": llm_name,
|
||||
"arguments": arguments,
|
||||
"pause_type": "headless_denied",
|
||||
"deny_reason": (
|
||||
"Client-side tools cannot run in headless / scheduled runs."
|
||||
),
|
||||
"error_type": "tool_not_allowed",
|
||||
"thought_signature": getattr(call, "thought_signature", None),
|
||||
}
|
||||
return {
|
||||
"call_id": call_id,
|
||||
"name": llm_name,
|
||||
@@ -233,7 +387,7 @@ class ToolExecutor:
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"llm_name": llm_name,
|
||||
"arguments": call_args if isinstance(call_args, dict) else {},
|
||||
"arguments": arguments,
|
||||
"pause_type": "requires_client_execution",
|
||||
"thought_signature": getattr(call, "thought_signature", None),
|
||||
}
|
||||
@@ -250,6 +404,27 @@ class ToolExecutor:
|
||||
)
|
||||
|
||||
if action_data.get("require_approval"):
|
||||
if self.headless:
|
||||
tool_row_id = str(tool_data.get("id") or tool_id)
|
||||
if tool_row_id in self.tool_allowlist:
|
||||
# Pre-authorized for headless execution — fall through.
|
||||
return None
|
||||
return {
|
||||
"call_id": call_id,
|
||||
"name": llm_name,
|
||||
"tool_name": tool_data.get("name", "unknown"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"llm_name": llm_name,
|
||||
"arguments": arguments,
|
||||
"pause_type": "headless_denied",
|
||||
"deny_reason": (
|
||||
"This tool requires approval and is not in the run's "
|
||||
"tool_allowlist."
|
||||
),
|
||||
"error_type": "tool_not_allowed",
|
||||
"thought_signature": getattr(call, "thought_signature", None),
|
||||
}
|
||||
return {
|
||||
"call_id": call_id,
|
||||
"name": llm_name,
|
||||
@@ -257,7 +432,7 @@ class ToolExecutor:
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"llm_name": llm_name,
|
||||
"arguments": call_args if isinstance(call_args, dict) else {},
|
||||
"arguments": arguments,
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": getattr(call, "thought_signature", None),
|
||||
}
|
||||
@@ -274,7 +449,14 @@ class ToolExecutor:
|
||||
|
||||
if tool_id is None or action_name is None:
|
||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
|
||||
logger.error(error_message)
|
||||
logger.error(
|
||||
"tool_call_parse_failed",
|
||||
extra={
|
||||
"llm_class_name": llm_class_name,
|
||||
"llm_tool_name": llm_name,
|
||||
"call_id": call_id,
|
||||
},
|
||||
)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
@@ -289,7 +471,15 @@ class ToolExecutor:
|
||||
|
||||
if tool_id not in tools_dict:
|
||||
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
|
||||
logger.error(error_message)
|
||||
logger.error(
|
||||
"tool_id_not_found",
|
||||
extra={
|
||||
"tool_id": tool_id,
|
||||
"llm_tool_name": llm_name,
|
||||
"call_id": call_id,
|
||||
"available_tool_count": len(tools_dict),
|
||||
},
|
||||
)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
@@ -308,9 +498,36 @@ class ToolExecutor:
|
||||
"action_name": llm_name,
|
||||
"arguments": call_args,
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
# Journal first so the reconciler sees malformed calls and any
|
||||
# subsequent ``_mark_failed`` actually updates a real row.
|
||||
proposed_ok = _record_proposed(
|
||||
call_id,
|
||||
tool_data["name"],
|
||||
action_name,
|
||||
call_args if isinstance(call_args, dict) else {},
|
||||
tool_id=tool_data.get("id"),
|
||||
)
|
||||
# Defensive guard: a non-dict ``call_args`` (e.g. malformed
|
||||
# JSON on the resume path) would crash the param walk below
|
||||
# with AttributeError on ``.items()``. Surface a clean error
|
||||
# event and flip the journal row to ``failed`` instead of
|
||||
# killing the stream.
|
||||
if not isinstance(call_args, dict):
|
||||
error_message = (
|
||||
f"Tool call arguments must be a JSON object, got "
|
||||
f"{type(call_args).__name__}."
|
||||
)
|
||||
tool_call_data["result"] = error_message
|
||||
tool_call_data["arguments"] = {}
|
||||
_mark_failed(call_id, error_message)
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {**tool_call_data, "status": "error"},
|
||||
}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return error_message, call_id
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||
action_data = (
|
||||
tool_data["config"]["actions"][action_name]
|
||||
if tool_data["name"] == "api_tool"
|
||||
@@ -356,8 +573,17 @@ class ToolExecutor:
|
||||
f"Failed to load tool '{tool_data.get('name')}' (tool_id key={tool_id}): "
|
||||
"missing 'id' on tool row."
|
||||
)
|
||||
logger.error(error_message)
|
||||
logger.error(
|
||||
"tool_load_failed",
|
||||
extra={
|
||||
"tool_name": tool_data.get("name"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"call_id": call_id,
|
||||
},
|
||||
)
|
||||
tool_call_data["result"] = error_message
|
||||
_mark_failed(call_id, error_message)
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return error_message, call_id
|
||||
@@ -367,14 +593,18 @@ class ToolExecutor:
|
||||
if tool_data["name"] == "api_tool"
|
||||
else parameters
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
logger.debug(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
)
|
||||
result = tool.execute_action(action_name, **body)
|
||||
else:
|
||||
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
|
||||
result = tool.execute_action(action_name, **parameters)
|
||||
try:
|
||||
if tool_data["name"] == "api_tool":
|
||||
logger.debug(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
)
|
||||
result = tool.execute_action(action_name, **body)
|
||||
else:
|
||||
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
|
||||
result = tool.execute_action(action_name, **parameters)
|
||||
except Exception as exc:
|
||||
_mark_failed(call_id, str(exc))
|
||||
raise
|
||||
|
||||
get_artifact_id = (
|
||||
getattr(tool, "get_artifact_id", None)
|
||||
@@ -403,6 +633,22 @@ class ToolExecutor:
|
||||
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
|
||||
)
|
||||
|
||||
# Tool side effect has run; flip the journal row so the
|
||||
# message-finalize path can later confirm it. If the proposed
|
||||
# write failed (DB outage), upsert a fresh row in ``executed`` so
|
||||
# the reconciler still sees the side effect.
|
||||
_mark_executed(
|
||||
call_id,
|
||||
result_full,
|
||||
message_id=self.message_id,
|
||||
artifact_id=artifact_id or None,
|
||||
proposed_ok=proposed_ok,
|
||||
tool_name=tool_data["name"],
|
||||
action_name=action_name,
|
||||
arguments=call_args,
|
||||
tool_id=tool_data.get("id"),
|
||||
)
|
||||
|
||||
stream_tool_call_data = {
|
||||
key: value
|
||||
for key, value in tool_call_data.items()
|
||||
@@ -451,15 +697,24 @@ class ToolExecutor:
|
||||
row_id = tool_data.get("id")
|
||||
if not row_id:
|
||||
logger.error(
|
||||
"Tool data missing 'id' for tool name=%s (enumerate-key tool_id=%s); "
|
||||
"skipping load to avoid binding a non-UUID downstream.",
|
||||
tool_data.get("name"),
|
||||
tool_id,
|
||||
"tool_missing_row_id",
|
||||
extra={
|
||||
"tool_name": tool_data.get("name"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
},
|
||||
)
|
||||
return None
|
||||
tool_config["tool_id"] = str(row_id)
|
||||
if self.conversation_id:
|
||||
tool_config["conversation_id"] = self.conversation_id
|
||||
if tool_data["name"] == "scheduler":
|
||||
# Agent-bound: stamp schedules.agent_id. Agentless: the tool
|
||||
# falls back to ``origin_conversation_id`` as the schedule's
|
||||
# conversation home.
|
||||
tool_config["agent_id"] = (
|
||||
str(self.agent_id) if self.agent_id else None
|
||||
)
|
||||
if tool_data["name"] == "mcp_tool":
|
||||
tool_config["query_mode"] = True
|
||||
|
||||
|
||||
@@ -20,10 +20,11 @@ from pydantic import AnyHttpUrl, ValidationError
|
||||
from redis import Redis
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
|
||||
from application.api.user.tasks import mcp_oauth_task
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.core.url_validation import SSRFError, validate_url
|
||||
from application.events.keys import stream_key
|
||||
from application.security.encryption import decrypt_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -76,6 +77,12 @@ class MCPTool(Tool):
|
||||
self.oauth_task_id = config.get("oauth_task_id", None)
|
||||
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
|
||||
self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))
|
||||
# Pulled out of ``config`` (rather than left in ``self.config``)
|
||||
# because it is a callable supplied by the OAuth worker — not
|
||||
# something the rest of the tool plumbing should marshal or
|
||||
# serialize. ``DocsGPTOAuth`` invokes it from ``redirect_handler``
|
||||
# so the SSE envelope can carry ``authorization_url``.
|
||||
self.oauth_redirect_publish = config.pop("oauth_redirect_publish", None)
|
||||
|
||||
self.available_tools = []
|
||||
self._cache_key = self._generate_cache_key()
|
||||
@@ -167,6 +174,7 @@ class MCPTool(Tool):
|
||||
redirect_uri=self.redirect_uri,
|
||||
task_id=self.oauth_task_id,
|
||||
user_id=self.user_id,
|
||||
redirect_publish=self.oauth_redirect_publish,
|
||||
)
|
||||
elif self.auth_type == "bearer":
|
||||
token = self.auth_credentials.get(
|
||||
@@ -679,12 +687,17 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
user_id=None,
|
||||
additional_client_metadata: dict[str, Any] | None = None,
|
||||
skip_redirect_validation: bool = False,
|
||||
redirect_publish=None,
|
||||
):
|
||||
self.redirect_uri = redirect_uri
|
||||
self.redis_client = redis_client
|
||||
self.redis_prefix = redis_prefix
|
||||
self.task_id = task_id
|
||||
self.user_id = user_id
|
||||
# Worker-supplied callback. Invoked from ``redirect_handler``
|
||||
# once the authorization URL is known so the SSE envelope can
|
||||
# carry it. ``None`` for any non-worker entrypoint.
|
||||
self.redirect_publish = redirect_publish
|
||||
|
||||
parsed_url = urlparse(mcp_url)
|
||||
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
@@ -744,17 +757,19 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
self.redis_client.setex(key, 600, auth_url)
|
||||
logger.info("Stored auth_url in Redis: %s", key)
|
||||
|
||||
if self.task_id:
|
||||
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||
status_data = {
|
||||
"status": "requires_redirect",
|
||||
"message": "Authorization required",
|
||||
"authorization_url": self.auth_url,
|
||||
"state": self.extracted_state,
|
||||
"requires_oauth": True,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
if self.redirect_publish is not None:
|
||||
# Best-effort: a publish failure must not abort the OAuth
|
||||
# handshake — the user can still authorize via the popup
|
||||
# opened from the legacy polling fallback if the SSE
|
||||
# envelope is lost.
|
||||
try:
|
||||
self.redirect_publish(auth_url)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"redirect_publish callback raised for task_id=%s",
|
||||
self.task_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def callback_handler(self) -> tuple[str, str | None]:
|
||||
"""Wait for auth code from Redis using the state value."""
|
||||
@@ -764,17 +779,6 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
max_wait_time = 300
|
||||
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
|
||||
|
||||
if self.task_id:
|
||||
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||
status_data = {
|
||||
"status": "awaiting_callback",
|
||||
"message": "Waiting for authorization...",
|
||||
"authorization_url": self.auth_url,
|
||||
"state": self.extracted_state,
|
||||
"requires_oauth": True,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
code_data = self.redis_client.get(code_key)
|
||||
@@ -789,14 +793,6 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}state:{self.extracted_state}"
|
||||
)
|
||||
|
||||
if self.task_id:
|
||||
status_data = {
|
||||
"status": "callback_received",
|
||||
"message": "Completing authentication...",
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
return code, returned_state
|
||||
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
|
||||
error_data = self.redis_client.get(error_key)
|
||||
@@ -1038,8 +1034,73 @@ class MCPOAuthManager:
|
||||
logger.error("Error handling OAuth callback: %s", e)
|
||||
return False
|
||||
|
||||
def get_oauth_status(self, task_id: str, user_id: str) -> Dict[str, Any]:
    """Return the latest OAuth status for ``task_id`` from the user's SSE journal.

    Mirrors the legacy polling contract: ``status`` derived from the
    ``mcp.oauth.*`` event-type suffix, with payload fields surfaced
    (e.g. ``tools``/``tools_count`` on ``completed``).

    Args:
        task_id: OAuth task whose status is requested.
        user_id: Owner of the event stream that is scanned.

    Returns:
        A dict with at least ``status``. ``not_started`` when ``task_id`` is
        empty; ``not_found`` when the user is missing, Redis is unavailable,
        the read fails, or no matching event exists. Otherwise the event
        suffix, ``task_id`` and the event payload fields.
    """
    if not task_id:
        return {"status": "not_started", "message": "OAuth flow not started"}
    if not user_id:
        return {"status": "not_found", "message": "User not provided"}
    if self.redis_client is None:
        return {"status": "not_found", "message": "Redis unavailable"}

    try:
        # OAuth flows are short-lived but a concurrent source
        # ingest can flood the user channel between the OAuth
        # popup completing and the user clicking Save, pushing the
        # completion envelope outside the read window. Bound the
        # scan by the configured stream cap so we cover the full
        # journal — XADD MAXLEN keeps that bounded too.
        scan_count = max(settings.EVENTS_STREAM_MAXLEN, 200)
        entries = self.redis_client.xrevrange(
            stream_key(user_id), count=scan_count
        )
    except Exception:
        logger.exception(
            "xrevrange failed for oauth status: user_id=%s task_id=%s",
            user_id,
            task_id,
        )
        return {"status": "not_found", "message": "Status unavailable"}

    prefix = "mcp.oauth."
    # Entries are read in reverse stream order, so the first matching
    # envelope is the most recent status for this task.
    for _entry_id, fields in entries:
        if not isinstance(fields, dict):
            continue
        # decode_responses=False ⇒ bytes keys; the string-key fallback
        # covers a future flip of that default without a forced refactor.
        event_raw = fields.get(b"event")
        if event_raw is None:
            event_raw = fields.get("event")
        if event_raw is None:
            continue
        if isinstance(event_raw, bytes):
            try:
                event_raw = event_raw.decode("utf-8")
            except Exception:
                continue
        try:
            envelope = json.loads(event_raw)
        except Exception:
            continue
        if not isinstance(envelope, dict):
            continue
        event_type = envelope.get("type", "")
        if not isinstance(event_type, str) or not event_type.startswith(prefix):
            continue
        scope = envelope.get("scope") or {}
        # A malformed (non-object) scope cannot match; skip it instead of
        # raising AttributeError out of the status poll.
        if not isinstance(scope, dict):
            continue
        if scope.get("kind") != "mcp_oauth" or scope.get("id") != task_id:
            continue
        payload = envelope.get("payload") or {}
        if not isinstance(payload, dict):
            # Only object payloads can be merged into the response.
            payload = {}
        return {
            "status": event_type[len(prefix):],
            "task_id": task_id,
            **payload,
        }

    return {"status": "not_found", "message": "Status not found"}
|
||||
|
||||
@@ -177,3 +177,4 @@ class PostgresTool(Tool):
|
||||
"order": 1,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
342
application/agents/tools/scheduler.py
Normal file
342
application/agents/tools/scheduler.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""Scheduler tool: one-time agent tasks in agent-bound or agentless chats."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.agents.scheduler_utils import (
|
||||
ScheduleValidationError,
|
||||
clamp_once_horizon,
|
||||
parse_delay,
|
||||
parse_run_at,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
from .base import Tool
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SchedulerTool(Tool):
|
||||
"""
|
||||
Scheduling
|
||||
Schedules a one-time task for the agent to run at a chosen time or delay.
|
||||
"""
|
||||
|
||||
# internal=True keeps scheduler out of /api/available_tools and the
|
||||
# agentless Add-Tool modal; tool_manager.load_tool still lazy-loads it
|
||||
# per-user at execute time (same as memory/notes/todo_list).
|
||||
internal: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_config: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> None:
|
||||
cfg = tool_config or {}
|
||||
self.user_id: Optional[str] = user_id
|
||||
self.agent_id: Optional[str] = cfg.get("agent_id")
|
||||
self.conversation_id: Optional[str] = cfg.get("conversation_id")
|
||||
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
    """Validate the tool's identity context, then route to the named action.

    Args:
        action_name: ``schedule_task``, ``list_scheduled_tasks`` or
            ``cancel_scheduled_task``.
        **kwargs: Action arguments supplied by the LLM.

    Returns:
        The action's string result, an ``Error: ...`` message when the
        identity context is invalid, or ``Unknown action: ...`` for any
        other action name.
    """
    if not self.user_id:
        return "Error: SchedulerTool requires a valid user_id."

    agent_bound = bool(self.agent_id)
    # Agent-bound chats must carry a UUID-shaped agent_id.
    if agent_bound and not looks_like_uuid(str(self.agent_id)):
        return "Error: SchedulerTool received an invalid agent_id."
    # Agentless chats have no agent_id, so the originating conversation is
    # mandatory: it is the schedule's conversation home.
    if not agent_bound and not self.conversation_id:
        return (
            "Error: SchedulerTool requires an agent_id or a "
            "conversation_id (no conversation home)."
        )

    if action_name == "list_scheduled_tasks":
        return self._list_scheduled_tasks()
    if action_name == "cancel_scheduled_task":
        return self._cancel_scheduled_task(kwargs.get("task_id", ""))
    if action_name != "schedule_task":
        return f"Unknown action: {action_name}"
    return self._schedule_task(
        instruction=kwargs.get("instruction", ""),
        delay=kwargs.get("delay"),
        run_at=kwargs.get("run_at"),
        tz=kwargs.get("timezone"),
    )
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
    """Describe the scheduler actions for the LLM tool catalogue.

    Returns:
        Schema entries for ``schedule_task``, ``list_scheduled_tasks`` and
        ``cancel_scheduled_task``, in that order. Each entry carries a
        ``name``, a ``description`` and a ``parameters`` object schema.
    """
    schedule_properties = {
        "instruction": {
            "type": "string",
            "description": "What the agent should do at fire time.",
        },
        "delay": {
            "type": "string",
            "description": "Duration like '30m', '2h', '1d'.",
        },
        "run_at": {
            "type": "string",
            "description": "Absolute ISO 8601 timestamp.",
        },
        "timezone": {
            "type": "string",
            "description": (
                "IANA timezone (e.g. Europe/Warsaw) for naive run_at."
            ),
        },
    }
    schedule_task = {
        "name": "schedule_task",
        "description": (
            "Schedule a one-time task. Provide either a `delay` "
            "(e.g. '30m', '2h', '1d') from now, or a `run_at` ISO-8601 "
            "absolute time. Optionally pass an IANA `timezone` to resolve "
            "naive run_at values. The instruction is the task that will "
            "execute at fire time (including delivery, e.g. 'send to my "
            "Telegram'). For recurring schedules in an agent chat, point "
            "the user to the agent's Schedules tab."
        ),
        "parameters": {
            "type": "object",
            "properties": schedule_properties,
            "required": ["instruction"],
        },
    }
    list_scheduled_tasks = {
        "name": "list_scheduled_tasks",
        "description": (
            "List pending one-time tasks for the current chat. "
            "Agent-bound chats scope to user+agent; agentless chats "
            "scope to user+originating conversation."
        ),
        "parameters": {"type": "object", "properties": {}},
    }
    cancel_scheduled_task = {
        "name": "cancel_scheduled_task",
        "description": "Cancel a pending one-time task by its task_id.",
        "parameters": {
            "type": "object",
            "properties": {
                "task_id": {
                    "type": "string",
                    "description": "The schedule id returned by schedule_task.",
                },
            },
            "required": ["task_id"],
        },
    }
    return [schedule_task, list_scheduled_tasks, cancel_scheduled_task]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
    """Return the tool's configuration requirements (none: an empty dict)."""
    requirements: Dict[str, Any] = {}
    return requirements
|
||||
|
||||
def _schedule_task(
|
||||
self,
|
||||
instruction: str,
|
||||
delay: Optional[str],
|
||||
run_at: Optional[str],
|
||||
tz: Optional[str],
|
||||
) -> str:
|
||||
if not instruction or not isinstance(instruction, str):
|
||||
return "Error: instruction is required."
|
||||
if not delay and not run_at:
|
||||
return "Error: provide either `delay` or `run_at`."
|
||||
if delay and run_at:
|
||||
return "Error: provide only one of `delay` or `run_at`."
|
||||
|
||||
try:
|
||||
if delay:
|
||||
fire = datetime.now(timezone.utc) + parse_delay(delay)
|
||||
else:
|
||||
fire = parse_run_at(run_at, tz)
|
||||
clamp_once_horizon(fire, settings.SCHEDULE_ONCE_MAX_HORIZON)
|
||||
except ScheduleValidationError as exc:
|
||||
return f"Error: {exc}"
|
||||
|
||||
with db_readonly() as conn:
|
||||
count = SchedulesRepository(conn).count_active_for_user(self.user_id)
|
||||
if (
|
||||
settings.SCHEDULE_MAX_PER_USER > 0
|
||||
and count >= settings.SCHEDULE_MAX_PER_USER
|
||||
):
|
||||
return (
|
||||
"Error: you have reached the maximum number of active schedules."
|
||||
)
|
||||
|
||||
# Chat-created tasks default to the user's non-approval tools (for the
|
||||
# agent's toolset when agent-bound, or the user's defaults+user_tools
|
||||
# when agentless).
|
||||
allowlist = _safe_default_allowlist(self.agent_id, self.user_id)
|
||||
|
||||
auto_name = _name_from_instruction(instruction)
|
||||
try:
|
||||
with db_session() as conn:
|
||||
created = SchedulesRepository(conn).create(
|
||||
user_id=self.user_id,
|
||||
agent_id=self.agent_id,
|
||||
trigger_type="once",
|
||||
instruction=instruction.strip(),
|
||||
name=auto_name,
|
||||
run_at=fire,
|
||||
next_run_at=fire,
|
||||
timezone=tz or "UTC",
|
||||
tool_allowlist=allowlist,
|
||||
origin_conversation_id=self.conversation_id,
|
||||
created_via="chat",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("schedule_task create failed: %s", exc)
|
||||
return "Error: failed to create scheduled task."
|
||||
return json.dumps(
|
||||
{
|
||||
"task_id": str(created["id"]),
|
||||
"resolved_run_at": _iso_utc(fire),
|
||||
"timezone": tz or "UTC",
|
||||
"instruction": instruction.strip(),
|
||||
"name": auto_name,
|
||||
}
|
||||
)
|
||||
|
||||
def _list_scheduled_tasks(self) -> str:
    """Return this user's pending one-time tasks as JSON, earliest fire first.

    Chats bound to an agent are scoped by user and agent. Agentless chats
    are scoped by user and originating conversation, so only tasks created
    from this chat are listed.

    Returns:
        A JSON document of the form ``{"tasks": [...]}``.
    """
    shared_filters = {"statuses": ["active"], "trigger_type": "once"}
    with db_readonly() as conn:
        schedules = SchedulesRepository(conn)
        if self.agent_id:
            pending = schedules.list_for_agent(
                self.agent_id, self.user_id, **shared_filters
            )
        else:
            pending = schedules.list_for_conversation(
                self.user_id, self.conversation_id, **shared_filters
            )

    def _fire_time(task):
        # Values arrive as ISO strings (coerce_pg_native); a string sentinel
        # keeps the sort keys uniform and pushes tasks without a fire time last.
        return task.get("next_run_at") or "9999-12-31T23:59:59Z"

    tasks = []
    for task in sorted(pending, key=_fire_time):
        tasks.append(
            {
                "task_id": str(task["id"]),
                "instruction": task.get("instruction"),
                "name": task.get("name"),
                "resolved_run_at": _iso_utc(task.get("next_run_at")),
                "timezone": task.get("timezone"),
                "status": task.get("status"),
            }
        )
    return json.dumps({"tasks": tasks})
|
||||
|
||||
def _cancel_scheduled_task(self, task_id: str) -> str:
    """Cancel one of the caller's scheduled tasks.

    Args:
        task_id: Identifier of the task to cancel; must look like a UUID.

    Returns:
        A JSON document ``{"task_id": ..., "status": "cancelled"}`` on
        success, otherwise a human-readable ``"Error: ..."`` string.
    """
    not_found = "Error: scheduled task not found or already terminal."
    if not (task_id and looks_like_uuid(str(task_id))):
        return "Error: task_id must be a valid id."

    with db_session() as conn:
        schedules = SchedulesRepository(conn)
        if not self.agent_id:
            # Agentless chats may only cancel tasks that are themselves
            # agentless and were created from the current conversation.
            task = schedules.get(task_id, self.user_id)
            if task is None:
                return not_found
            if task.get("agent_id") is not None:
                return not_found
            origin = str(task.get("origin_conversation_id") or "")
            if origin != str(self.conversation_id or ""):
                return not_found
        cancelled = schedules.cancel(task_id, self.user_id)
        if not cancelled:
            return not_found
    return json.dumps({"task_id": str(task_id), "status": "cancelled"})
|
||||
|
||||
|
||||
def _name_from_instruction(instruction: str, *, max_len: int = 80) -> str:
    """Derive a compact display name from the instruction's first line.

    Args:
        instruction: Free-form task instruction; may span several lines.
        max_len: Maximum length of the returned name, ellipsis included.

    Returns:
        The first line with surrounding whitespace removed. A line longer
        than ``max_len`` is cut and terminated with ``…``. Returns an empty
        string when the instruction is blank, or when it would need cutting
        but ``max_len`` leaves no room for any character.
    """
    # ``splitlines`` understands ``\r\n`` and bare ``\r``, so text with
    # Windows line endings no longer leaves a stray carriage return in the name.
    lines = instruction.strip().splitlines()
    first_line = lines[0].rstrip() if lines else ""
    if len(first_line) <= max_len:
        return first_line
    if max_len < 1:
        # A non-positive budget would turn the slice below into a negative
        # index and produce a name longer than requested.
        return ""
    return first_line[: max_len - 1] + "…"
|
||||
|
||||
|
||||
def _iso_utc(value: Any) -> Optional[str]:
    """Format a datetime or ISO-8601 string as an RFC 3339 UTC timestamp.

    ``None`` is passed through. A string that cannot be parsed is returned
    unchanged. Naive datetimes are treated as already being in UTC.
    """
    if value is None:
        return None

    moment = value
    if isinstance(moment, str):
        try:
            # Swap the "Z" suffix for an explicit offset before parsing.
            moment = datetime.fromisoformat(moment.replace("Z", "+00:00"))
        except ValueError:
            return value

    aware = (
        moment
        if moment.tzinfo is not None
        else moment.replace(tzinfo=timezone.utc)
    )
    rendered = aware.astimezone(timezone.utc).isoformat()
    return rendered.replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _safe_default_allowlist(
    agent_id: Optional[str], user_id: str,
) -> List[str]:
    """Collect ids of available tools that have no approval-gated action.

    Agent-bound: candidates are the agent's ``agents.tools`` entries.
    Agentless: candidates are the user's active ``user_tools`` rows followed
    by the synthesized default chat tools (resolved against
    ``settings.DEFAULT_CHAT_TOOLS`` and the user's
    ``tool_preferences.disabled_default_tools`` opt-outs).

    Any failure while building the list is logged and yields ``[]``.
    """
    from application.agents.default_tools import (
        resolve_tool_by_id,
        synthesized_default_tools,
    )
    from application.storage.db.repositories.agents import AgentsRepository
    from application.storage.db.repositories.user_tools import UserToolsRepository
    from application.storage.db.repositories.users import UsersRepository

    def _needs_approval(tool_row: Dict[str, Any]) -> bool:
        # A tool is unsafe as soon as one of its actions requires approval.
        return any(
            action.get("require_approval")
            for action in (tool_row.get("actions") or [])
        )

    try:
        with db_readonly() as conn:
            user_tools = UserToolsRepository(conn)
            allowed: List[str] = []
            if agent_id:
                agent_row = AgentsRepository(conn).get(agent_id, user_id)
                for entry in (agent_row or {}).get("tools") or []:
                    candidate_id = str(entry)
                    resolved = resolve_tool_by_id(
                        candidate_id, user_id, user_tools_repo=user_tools,
                    )
                    if resolved and not _needs_approval(resolved):
                        allowed.append(candidate_id)
            else:
                # Agentless: explicit active user tools first, then the
                # synthesized defaults that honour the user's opt-outs.
                user_doc = UsersRepository(conn).get(user_id)
                allowed.extend(
                    str(tool_row["id"])
                    for tool_row in user_tools.list_active_for_user(user_id)
                    if not _needs_approval(tool_row)
                )
                allowed.extend(
                    str(default_row["id"])
                    for default_row in synthesized_default_tools(user_doc)
                    if not _needs_approval(default_row)
                )
    except Exception:  # pragma: no cover — best-effort fallback
        logger.exception("scheduler: default allowlist build failed")
        return []
    return allowed
|
||||
@@ -57,6 +57,29 @@ class ToolActionParser:
|
||||
def _parse_google_llm(self, call):
|
||||
try:
|
||||
call_args = call.arguments
|
||||
# Gemini's SDK natively returns ``args`` as a dict, but the
|
||||
# resume path (``gen_continuation``) stringifies it for the
|
||||
# assistant message. Coerce a JSON string back into a dict;
|
||||
# fall back to an empty dict on malformed input so downstream
|
||||
# ``call_args.items()`` doesn't crash the stream.
|
||||
if isinstance(call_args, str):
|
||||
try:
|
||||
call_args = json.loads(call_args)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(
|
||||
"Google call.arguments was not valid JSON; "
|
||||
"falling back to empty args for %s",
|
||||
getattr(call, "name", "<unknown>"),
|
||||
)
|
||||
call_args = {}
|
||||
if not isinstance(call_args, dict):
|
||||
logger.warning(
|
||||
"Google call.arguments has unexpected type %s; "
|
||||
"falling back to empty args for %s",
|
||||
type(call_args).__name__,
|
||||
getattr(call, "name", "<unknown>"),
|
||||
)
|
||||
call_args = {}
|
||||
|
||||
resolved = self._resolve_via_mapping(call.name)
|
||||
if resolved:
|
||||
|
||||
@@ -28,7 +28,10 @@ class ToolManager:
|
||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
if tool_name in {"mcp_tool", "notes", "memory", "todo_list"} and user_id:
|
||||
if (
|
||||
tool_name in {"mcp_tool", "notes", "memory", "todo_list", "scheduler"}
|
||||
and user_id
|
||||
):
|
||||
return obj(tool_config, user_id)
|
||||
else:
|
||||
return obj(tool_config)
|
||||
@@ -36,7 +39,10 @@ class ToolManager:
|
||||
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
if tool_name in {"mcp_tool", "memory", "todo_list", "notes"} and user_id:
|
||||
if (
|
||||
tool_name in {"mcp_tool", "memory", "todo_list", "notes", "scheduler"}
|
||||
and user_id
|
||||
):
|
||||
tool_config = self.config.get(tool_name, {})
|
||||
tool = self.load_tool(tool_name, tool_config, user_id)
|
||||
return tool.execute_action(action_name, **kwargs)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""0001 initial schema — consolidated Phase-1..3 baseline.
|
||||
"""0001 initial schema — consolidated baseline for user-data tables.
|
||||
|
||||
Revision ID: 0001_initial
|
||||
Revises:
|
||||
|
||||
217
application/alembic/versions/0004_durability_foundation.py
Normal file
217
application/alembic/versions/0004_durability_foundation.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""0004 durability foundation — idempotency, tool-call log, ingest checkpoint.
|
||||
|
||||
Adds ``task_dedup``, ``webhook_dedup``, ``tool_call_attempts``,
|
||||
``ingest_chunk_progress``, and per-row status flags on
|
||||
``conversation_messages`` and ``pending_tool_state``. Also adds
|
||||
``token_usage.source`` and ``token_usage.request_id`` so per-channel
|
||||
cost attribution (``agent_stream`` / ``title`` / ``compression`` /
|
||||
``rag_condense`` / ``fallback``) is queryable and multi-call agent runs
|
||||
can be DISTINCT-collapsed into a single user request for rate limiting.
|
||||
|
||||
Revision ID: 0004_durability_foundation
|
||||
Revises: 0003_user_custom_models
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0004_durability_foundation"
|
||||
down_revision: Union[str, None] = "0003_user_custom_models"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the durability tables, status columns, indexes and trigger.

    Statements run in order: new tables, column additions on existing
    tables, indexes, then the ``updated_at`` trigger.
    """
    new_tables = (
        # ``attempt_count`` bounds the retry loop of the per-Celery-task
        # idempotency wrapper so a poison message can't run forever. The
        # default of 0 means no attempts have run yet.
        """
        CREATE TABLE task_dedup (
            idempotency_key TEXT PRIMARY KEY,
            task_name TEXT NOT NULL,
            task_id TEXT NOT NULL,
            result_json JSONB,
            status TEXT NOT NULL
                CHECK (status IN ('pending', 'completed', 'failed')),
            attempt_count INT NOT NULL DEFAULT 0,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """,
        """
        CREATE TABLE webhook_dedup (
            idempotency_key TEXT PRIMARY KEY,
            agent_id UUID NOT NULL,
            task_id TEXT NOT NULL,
            response_json JSONB,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """,
        # The ``message_id`` FK is ``ON DELETE SET NULL`` so a journal row
        # outlives its parent message (compliance / cost attribution).
        """
        CREATE TABLE tool_call_attempts (
            call_id TEXT PRIMARY KEY,
            message_id UUID
                REFERENCES conversation_messages (id)
                ON DELETE SET NULL,
            tool_id UUID,
            tool_name TEXT NOT NULL,
            action_name TEXT NOT NULL,
            arguments JSONB NOT NULL,
            result JSONB,
            error TEXT,
            status TEXT NOT NULL
                CHECK (status IN (
                    'proposed', 'executed', 'confirmed',
                    'compensated', 'failed'
                )),
            attempted_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """,
        """
        CREATE TABLE ingest_chunk_progress (
            source_id UUID PRIMARY KEY,
            total_chunks INT NOT NULL,
            embedded_chunks INT NOT NULL DEFAULT 0,
            last_index INT NOT NULL DEFAULT -1,
            last_updated TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """,
    )

    column_additions = (
        # DEFAULT 'complete' backfills existing rows, which are already done.
        """
        ALTER TABLE conversation_messages
            ADD COLUMN status TEXT NOT NULL DEFAULT 'complete'
                CHECK (status IN ('pending', 'streaming', 'complete', 'failed')),
            ADD COLUMN request_id TEXT;
        """,
        """
        ALTER TABLE pending_tool_state
            ADD COLUMN status TEXT NOT NULL DEFAULT 'pending'
                CHECK (status IN ('pending', 'resuming')),
            ADD COLUMN resumed_at TIMESTAMPTZ;
        """,
        # The ``agent_stream`` default backfills historical rows on the
        # assumption that they came from the primary path; before this
        # change the only writer was the error branch reading agent.llm.
        # ``request_id`` is the stream-scoped UUID the route stamps on
        # ``agent.llm``, so a multi-tool agent run (N rows) collapses to
        # one request via DISTINCT in ``count_in_range``. Side-channel
        # sources (``title`` / ``compression`` / ``rag_condense`` /
        # ``fallback``) leave it NULL and are excluded from the request
        # count by the source filter.
        """
        ALTER TABLE token_usage
            ADD COLUMN source TEXT NOT NULL DEFAULT 'agent_stream',
            ADD COLUMN request_id TEXT;
        """,
    )

    # Indexes are partial wherever the predicate selects only non-terminal rows.
    indexes = (
        "CREATE INDEX conversation_messages_pending_ts_idx "
        "ON conversation_messages (timestamp) "
        "WHERE status IN ('pending', 'streaming');",
        "CREATE INDEX tool_call_attempts_pending_ts_idx "
        "ON tool_call_attempts (attempted_at) "
        "WHERE status IN ('proposed', 'executed');",
        "CREATE INDEX tool_call_attempts_message_idx "
        "ON tool_call_attempts (message_id) "
        "WHERE message_id IS NOT NULL;",
        "CREATE INDEX pending_tool_state_resuming_ts_idx "
        "ON pending_tool_state (resumed_at) "
        "WHERE status = 'resuming';",
        "CREATE INDEX webhook_dedup_agent_idx ON webhook_dedup (agent_id);",
        "CREATE INDEX task_dedup_pending_attempts_idx "
        "ON task_dedup (attempt_count) WHERE status = 'pending';",
        # Cost-attribution dashboards filter ``token_usage`` by source and
        # timestamp; this index keeps those queries cheap.
        "CREATE INDEX token_usage_source_ts_idx "
        "ON token_usage (source, timestamp);",
        # Only rows with a stamped request_id take part in the DISTINCT
        # count; NULL rows fall through to the COUNT(*) branch of the
        # repository query.
        "CREATE INDEX token_usage_request_id_idx "
        "ON token_usage (request_id) "
        "WHERE request_id IS NOT NULL;",
    )

    updated_at_trigger = (
        "CREATE TRIGGER tool_call_attempts_set_updated_at "
        "BEFORE UPDATE ON tool_call_attempts "
        "FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
        "EXECUTE FUNCTION set_updated_at();"
    )

    for statement in (*new_tables, *column_additions, *indexes, updated_at_trigger):
        op.execute(statement)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the durability tables and the columns and indexes added by 0004."""
    # CASCADE keeps the downgrade safe if later migrations FK into these.
    doomed_tables = (
        "ingest_chunk_progress",
        "tool_call_attempts",
        "webhook_dedup",
        "task_dedup",
    )
    statements = [
        f"DROP TABLE IF EXISTS {table} CASCADE;" for table in doomed_tables
    ]
    statements.extend(
        [
            "ALTER TABLE conversation_messages "
            "DROP COLUMN IF EXISTS request_id, DROP COLUMN IF EXISTS status;",
            "ALTER TABLE pending_tool_state "
            "DROP COLUMN IF EXISTS resumed_at, DROP COLUMN IF EXISTS status;",
            "DROP INDEX IF EXISTS token_usage_request_id_idx;",
            "DROP INDEX IF EXISTS token_usage_source_ts_idx;",
            "ALTER TABLE token_usage "
            "DROP COLUMN IF EXISTS request_id, DROP COLUMN IF EXISTS source;",
        ]
    )
    for statement in statements:
        op.execute(statement)
|
||||
44
application/alembic/versions/0005_ingest_attempt_id.py
Normal file
44
application/alembic/versions/0005_ingest_attempt_id.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""0005 ingest_chunk_progress.attempt_id — per-attempt resume scoping.
|
||||
|
||||
Without this column, a completed checkpoint row poisoned every later
|
||||
embed call on the same ``source_id``: a sync after an upload finished
|
||||
read the upload's terminal ``last_index`` and either embedded zero
|
||||
chunks (if new ``total_docs <= last_index + 1``) or stacked new chunks
|
||||
on top of the old vectors (if ``total_docs > last_index + 1``).
|
||||
|
||||
``attempt_id`` is stamped from ``self.request.id`` (Celery's stable
|
||||
task id, which survives ``acks_late`` retries of the same task but
|
||||
differs across separate task invocations). The repository's
|
||||
``init_progress`` upsert resets ``last_index`` / ``embedded_chunks``
|
||||
when the incoming ``attempt_id`` differs from the stored one — so a
|
||||
fresh sync starts from chunk 0 while a retry of the same task resumes
|
||||
from the last checkpointed chunk.
|
||||
|
||||
Revision ID: 0005_ingest_attempt_id
|
||||
Revises: 0004_durability_foundation
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0005_ingest_attempt_id"
|
||||
down_revision: Union[str, None] = "0004_durability_foundation"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable ``attempt_id`` column to ``ingest_chunk_progress``."""
    add_attempt_id = """
        ALTER TABLE ingest_chunk_progress
            ADD COLUMN attempt_id TEXT;
        """
    op.execute(add_attempt_id)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove ``ingest_chunk_progress.attempt_id`` if it is present."""
    drop_attempt_id = (
        "ALTER TABLE ingest_chunk_progress DROP COLUMN IF EXISTS attempt_id;"
    )
    op.execute(drop_attempt_id)
|
||||
57
application/alembic/versions/0006_idempotency_lease.py
Normal file
57
application/alembic/versions/0006_idempotency_lease.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""0006 task_dedup lease columns — running-lease for in-flight tasks.
|
||||
|
||||
Without these, ``with_idempotency`` only short-circuits *completed*
|
||||
rows. A late-ack redelivery (Redis ``visibility_timeout`` exceeded by a
|
||||
long ingest, or a hung-but-alive worker) hands the same message to a
|
||||
second worker; ``_claim_or_bump`` only bumped the attempt counter and
|
||||
both workers ran the task body in parallel — duplicate vector writes,
|
||||
duplicate token spend, duplicate webhook side effects.
|
||||
|
||||
``lease_owner_id`` + ``lease_expires_at`` turn that into an atomic
|
||||
compare-and-swap. The wrapper claims a lease at entry, refreshes it via
|
||||
a 30 s heartbeat thread, and finalises (which makes the lease moot via
|
||||
``status='completed'``). A second worker hitting the same key sees a
|
||||
fresh lease and ``self.retry(countdown=LEASE_TTL)``s instead of running.
|
||||
A crashed worker's lease expires after ``LEASE_TTL`` seconds and the
|
||||
next retry can claim it.
|
||||
|
||||
Revision ID: 0006_idempotency_lease
|
||||
Revises: 0005_ingest_attempt_id
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0006_idempotency_lease"
|
||||
down_revision: Union[str, None] = "0005_ingest_attempt_id"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the lease columns to ``task_dedup`` and index pending leases."""
    add_lease_columns = """
        ALTER TABLE task_dedup
            ADD COLUMN lease_owner_id TEXT,
            ADD COLUMN lease_expires_at TIMESTAMPTZ;
        """
    # The reconciler's stuck-pending sweep filters by
    # ``(status='pending', lease_expires_at < now() - 60s, attempt_count >= 5)``.
    # A partial index keeps that scan small even under heavy task throughput.
    index_pending_leases = (
        "CREATE INDEX task_dedup_pending_lease_idx "
        "ON task_dedup (lease_expires_at) "
        "WHERE status = 'pending';"
    )
    for statement in (add_lease_columns, index_pending_leases):
        op.execute(statement)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the pending-lease index, then the two lease columns."""
    rollback = (
        "DROP INDEX IF EXISTS task_dedup_pending_lease_idx;",
        "ALTER TABLE task_dedup "
        "DROP COLUMN IF EXISTS lease_expires_at, "
        "DROP COLUMN IF EXISTS lease_owner_id;",
    )
    for statement in rollback:
        op.execute(statement)
|
||||
40
application/alembic/versions/0007_message_events.py
Normal file
40
application/alembic/versions/0007_message_events.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""0007 message_events — durable journal of chat-stream events.
|
||||
|
||||
Snapshot half of the chat-stream snapshot+tail pattern. Composite PK
|
||||
``(message_id, sequence_no)``, ``created_at`` indexed for retention
|
||||
sweeps, ``ON DELETE CASCADE`` from ``conversation_messages``.
|
||||
|
||||
Revision ID: 0007_message_events
|
||||
Revises: 0006_idempotency_lease
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0007_message_events"
|
||||
down_revision: Union[str, None] = "0006_idempotency_lease"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create ``message_events`` and its ``created_at`` index."""
    # Both statements are sent in a single execute call.
    create_journal = """
        CREATE TABLE message_events (
            message_id UUID NOT NULL REFERENCES conversation_messages(id) ON DELETE CASCADE,
            sequence_no INTEGER NOT NULL,
            event_type TEXT NOT NULL,
            payload JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            PRIMARY KEY (message_id, sequence_no)
        );
        CREATE INDEX message_events_created_at_idx ON message_events(created_at);
        """
    op.execute(create_journal)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``message_events`` index, then the table itself."""
    for statement in (
        "DROP INDEX IF EXISTS message_events_created_at_idx;",
        "DROP TABLE IF EXISTS message_events;",
    ):
        op.execute(statement)
|
||||
44
application/alembic/versions/0008_ingest_progress_status.py
Normal file
44
application/alembic/versions/0008_ingest_progress_status.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""0008 ingest_chunk_progress.status — terminal flag for stalled ingests.
|
||||
|
||||
The reconciler's stalled-ingest sweep had no terminal write, so a dead
|
||||
ingest re-alerted every ~30 min forever. ``status`` lets it escalate a
|
||||
stalled checkpoint to ``'stalled'`` once and stop re-selecting it;
|
||||
``init_progress`` resets it to ``'active'`` on reingest.
|
||||
|
||||
Revision ID: 0008_ingest_progress_status
|
||||
Revises: 0007_message_events
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0008_ingest_progress_status"
|
||||
down_revision: Union[str, None] = "0007_message_events"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add ``ingest_chunk_progress.status`` and index the active rows."""
    # Constant DEFAULT — metadata-only ADD COLUMN, no table rewrite.
    add_status = """
        ALTER TABLE ingest_chunk_progress
            ADD COLUMN status TEXT NOT NULL DEFAULT 'active'
                CHECK (status IN ('active', 'stalled'));
        """
    # Partial index for the reconciler's stalled-ingest sweep.
    index_active = (
        "CREATE INDEX ingest_chunk_progress_active_idx "
        "ON ingest_chunk_progress (last_updated) "
        "WHERE status = 'active';"
    )
    for statement in (add_status, index_active):
        op.execute(statement)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the active-rows index, then the ``status`` column."""
    for statement in (
        "DROP INDEX IF EXISTS ingest_chunk_progress_active_idx;",
        "ALTER TABLE ingest_chunk_progress DROP COLUMN IF EXISTS status;",
    ):
        op.execute(statement)
|
||||
83
application/alembic/versions/0009_tool_preferences.py
Normal file
83
application/alembic/versions/0009_tool_preferences.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""0009 default chat tools — users.tool_preferences + memories.tool_id.
|
||||
|
||||
Adds ``users.tool_preferences`` JSONB and drops the
|
||||
``memories.tool_id`` FK to ``user_tools`` (synthetic default-tool ids
|
||||
have no ``user_tools`` row). Delete-cascade for real tools is kept via
|
||||
an AFTER DELETE trigger on ``user_tools``. Idempotent both ways.
|
||||
|
||||
Revision ID: 0009_tool_preferences
|
||||
Revises: 0008_ingest_progress_status
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0009_tool_preferences"
|
||||
down_revision: Union[str, None] = "0008_ingest_progress_status"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add ``users.tool_preferences`` and swap the memories FK for a trigger.

    The ``memories.tool_id`` foreign key is dropped and an AFTER DELETE
    trigger on ``user_tools`` takes over deleting the matching memories.
    """
    add_preferences = """
        ALTER TABLE users
            ADD COLUMN IF NOT EXISTS tool_preferences JSONB
            NOT NULL DEFAULT '{}'::jsonb;
        """
    drop_memories_fk = (
        "ALTER TABLE memories DROP CONSTRAINT IF EXISTS memories_tool_id_fkey;"
    )
    cleanup_function = """
        CREATE OR REPLACE FUNCTION cleanup_tool_memories() RETURNS trigger
        LANGUAGE plpgsql AS $$
        BEGIN
            DELETE FROM memories WHERE tool_id = OLD.id;
            RETURN OLD;
        END;
        $$;
        """
    # DROP-then-CREATE — no CREATE OR REPLACE TRIGGER for this signature.
    drop_trigger = "DROP TRIGGER IF EXISTS user_tools_cleanup_memories ON user_tools;"
    create_trigger = (
        "CREATE TRIGGER user_tools_cleanup_memories "
        "AFTER DELETE ON user_tools "
        "FOR EACH ROW EXECUTE FUNCTION cleanup_tool_memories();"
    )
    for statement in (
        add_preferences,
        drop_memories_fk,
        cleanup_function,
        drop_trigger,
        create_trigger,
    ):
        op.execute(statement)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Restore the ``memories.tool_id`` FK and drop ``users.tool_preferences``.

    DESTRUCTIVE: memories whose ``tool_id`` has no ``user_tools`` row are
    deleted so the foreign key can be re-created.
    """
    drop_trigger = "DROP TRIGGER IF EXISTS user_tools_cleanup_memories ON user_tools;"
    drop_function = "DROP FUNCTION IF EXISTS cleanup_tool_memories();"
    # Restoring the FK requires every memories.tool_id to reference a real
    # user_tools row. Memories written by a built-in default tool (synthetic
    # uuid5 id, no user_tools row) are permanently DELETED here. Downgrading
    # 0009 therefore loses all built-in-memory-tool data, because the
    # restored schema cannot represent it.
    purge_orphans = """
        DELETE FROM memories
        WHERE tool_id IS NOT NULL
          AND tool_id NOT IN (SELECT id FROM user_tools);
        """
    restore_fk = """
        ALTER TABLE memories
            ADD CONSTRAINT memories_tool_id_fkey
            FOREIGN KEY (tool_id) REFERENCES user_tools(id) ON DELETE CASCADE;
        """
    drop_preferences = "ALTER TABLE users DROP COLUMN IF EXISTS tool_preferences;"
    for statement in (
        drop_trigger,
        drop_function,
        purge_orphans,
        restore_fk,
        drop_preferences,
    ):
        op.execute(statement)
|
||||
147
application/alembic/versions/0010_schedules.py
Normal file
147
application/alembic/versions/0010_schedules.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""0010 scheduler — schedules + schedule_runs tables.
|
||||
|
||||
Revision ID: 0010_schedules
|
||||
Revises: 0009_tool_preferences
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0010_schedules"
|
||||
down_revision: Union[str, None] = "0009_tool_preferences"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create ``schedules`` and ``schedule_runs`` with indexes and triggers.

    Statements run in order: the ``schedules`` table, its indexes and
    trigger, then the ``schedule_runs`` table, its indexes and trigger.
    """
    schedules_table = """
        CREATE TABLE schedules (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
            trigger_type TEXT NOT NULL,
            name TEXT,
            instruction TEXT NOT NULL,
            status TEXT NOT NULL DEFAULT 'active',
            cron TEXT,
            run_at TIMESTAMPTZ,
            timezone TEXT NOT NULL DEFAULT 'UTC',
            next_run_at TIMESTAMPTZ,
            last_run_at TIMESTAMPTZ,
            end_at TIMESTAMPTZ,
            tool_allowlist JSONB NOT NULL DEFAULT '[]'::jsonb,
            model_id TEXT,
            token_budget INTEGER,
            origin_conversation_id UUID REFERENCES conversations(id) ON DELETE SET NULL,
            created_via TEXT NOT NULL DEFAULT 'ui',
            consecutive_failure_count INTEGER NOT NULL DEFAULT 0,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            CONSTRAINT schedules_trigger_type_chk
                CHECK (trigger_type IN ('once', 'recurring')),
            CONSTRAINT schedules_status_chk
                CHECK (status IN ('active', 'paused', 'completed', 'cancelled')),
            CONSTRAINT schedules_created_via_chk
                CHECK (created_via IN ('chat', 'ui')),
            CONSTRAINT schedules_recurring_cron_chk
                CHECK (trigger_type <> 'recurring' OR cron IS NOT NULL),
            CONSTRAINT schedules_once_run_at_chk
                CHECK (trigger_type <> 'once' OR run_at IS NOT NULL)
        );
        """
    schedules_extras = (
        "CREATE INDEX schedules_user_idx ON schedules (user_id);",
        "CREATE INDEX schedules_agent_idx ON schedules (agent_id);",
        # Dispatcher hot path: status='active' AND next_run_at <= now().
        "CREATE INDEX schedules_due_idx "
        "ON schedules (status, next_run_at) "
        "WHERE status = 'active';",
        "CREATE TRIGGER schedules_set_updated_at "
        "BEFORE UPDATE ON schedules "
        "FOR EACH ROW EXECUTE FUNCTION set_updated_at();",
    )

    runs_table = """
        CREATE TABLE schedule_runs (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            schedule_id UUID NOT NULL REFERENCES schedules(id) ON DELETE CASCADE,
            user_id TEXT NOT NULL,
            agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
            status TEXT NOT NULL DEFAULT 'pending',
            scheduled_for TIMESTAMPTZ NOT NULL,
            trigger_source TEXT NOT NULL DEFAULT 'cron',
            started_at TIMESTAMPTZ,
            finished_at TIMESTAMPTZ,
            output TEXT,
            output_truncated BOOLEAN NOT NULL DEFAULT false,
            error TEXT,
            error_type TEXT,
            prompt_tokens INTEGER NOT NULL DEFAULT 0,
            generated_tokens INTEGER NOT NULL DEFAULT 0,
            conversation_id UUID,
            message_id UUID,
            celery_task_id TEXT,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            CONSTRAINT schedule_runs_status_chk
                CHECK (status IN (
                    'pending', 'running', 'success', 'failed', 'skipped', 'timeout'
                )),
            CONSTRAINT schedule_runs_trigger_source_chk
                CHECK (trigger_source IN ('cron', 'manual')),
            CONSTRAINT schedule_runs_error_type_chk
                CHECK (error_type IS NULL OR error_type IN (
                    'auth_expired', 'tool_not_allowed', 'budget_exceeded',
                    'timeout', 'agent_error', 'internal', 'missed', 'overlap'
                ))
        );
        """
    runs_extras = (
        # Dedup primitive: racing dispatchers hit ON CONFLICT on this index.
        "CREATE UNIQUE INDEX schedule_runs_dedup_uidx "
        "ON schedule_runs (schedule_id, scheduled_for);",
        "CREATE INDEX schedule_runs_schedule_recent_idx "
        "ON schedule_runs (schedule_id, scheduled_for DESC);",
        "CREATE INDEX schedule_runs_user_idx ON schedule_runs (user_id);",
        "CREATE INDEX schedule_runs_running_idx "
        "ON schedule_runs (status, started_at) "
        "WHERE status = 'running';",
        "CREATE TRIGGER schedule_runs_set_updated_at "
        "BEFORE UPDATE ON schedule_runs "
        "FOR EACH ROW EXECUTE FUNCTION set_updated_at();",
    )

    for statement in (schedules_table, *schedules_extras, runs_table, *runs_extras):
        op.execute(statement)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop ``schedule_runs`` and then ``schedules``, triggers first."""
    # Triggers are dropped explicitly (grep-able) before each CASCADE drop.
    for table in ("schedule_runs", "schedules"):
        op.execute(f"DROP TRIGGER IF EXISTS {table}_set_updated_at ON {table};")
        op.execute(f"DROP TABLE IF EXISTS {table} CASCADE;")
|
||||
@@ -0,0 +1,53 @@
|
||||
"""0011 scheduler — make schedules.agent_id / schedule_runs.agent_id nullable.
|
||||
|
||||
Agentless schedules (created from agentless chats via the dual-registered
|
||||
``scheduler`` default chat tool) carry ``agent_id IS NULL``. Existing FK +
|
||||
``ON DELETE CASCADE`` semantics on ``agents(id)`` are unaffected — Postgres
|
||||
only cascades when the parent row is deleted; rows with a NULL ``agent_id`` reference no parent and are never matched.
|
||||
|
||||
Revision ID: 0011_schedules_nullable_agent
|
||||
Revises: 0010_schedules
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0011_schedules_nullable_agent"
|
||||
down_revision: Union[str, None] = "0010_schedules"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Drop the ``NOT NULL`` constraint on ``agent_id`` for both scheduler tables.

    ``schedules`` is altered first, then ``schedule_runs``, so agentless
    rows (``agent_id IS NULL``) can be stored in either table.
    """
    for table_name in ("schedules", "schedule_runs"):
        op.execute(
            f"ALTER TABLE {table_name} ALTER COLUMN agent_id DROP NOT NULL;"
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Restore ``NOT NULL`` on ``agent_id`` for ``schedule_runs`` and ``schedules``.

    A PL/pgSQL guard runs first. It counts rows with ``agent_id IS NULL``
    in both tables and raises an exception carrying both counts when any
    exist, so the downgrade stops with an actionable message before either
    ``ALTER TABLE`` is attempted.
    """
    # Agentless rows carry agent_id IS NULL by design; re-adding NOT NULL
    # while they exist cannot succeed, so fail loudly and tell the operator
    # what to clean up.
    null_guard_sql = """
        DO $$
        DECLARE
            sched_nulls INTEGER;
            run_nulls INTEGER;
        BEGIN
            SELECT count(*) INTO sched_nulls
            FROM schedules WHERE agent_id IS NULL;
            SELECT count(*) INTO run_nulls
            FROM schedule_runs WHERE agent_id IS NULL;
            IF sched_nulls > 0 OR run_nulls > 0 THEN
                RAISE EXCEPTION
                    'Cannot downgrade 0011: agentless rows present '
                    '(schedules=%, schedule_runs=%). '
                    'Delete or reassign them before retrying.',
                    sched_nulls, run_nulls;
            END IF;
        END$$;
        """
    op.execute(null_guard_sql)

    # Same order as the original statements: the child table first.
    for table_name in ("schedule_runs", "schedules"):
        op.execute(
            f"ALTER TABLE {table_name} ALTER COLUMN agent_id SET NOT NULL;"
        )
|
||||
@@ -102,6 +102,8 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
"reserved_message_id": processor.reserved_message_id,
|
||||
"request_id": processor.request_id,
|
||||
},
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
from flask import jsonify, make_response, Response
|
||||
from flask_restx import Namespace
|
||||
|
||||
from application.api.answer.services.continuation_service import ContinuationService
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
TERMINATED_RESPONSE_PLACEHOLDER,
|
||||
)
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
@@ -18,9 +23,16 @@ from application.core.settings import settings
|
||||
from application.error import sanitize_api_error
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.conversations import MessageUpdateOutcome
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
from application.storage.db.repositories.user_logs import UserLogsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.events.publisher import publish_user_event
|
||||
from application.streaming.event_replay import format_sse_event
|
||||
from application.streaming.message_journal import (
|
||||
BatchedJournalWriter,
|
||||
record_event,
|
||||
)
|
||||
from application.utils import check_required_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -203,13 +215,199 @@ class BaseAnswerResource:
|
||||
Yields:
|
||||
Server-sent event strings
|
||||
"""
|
||||
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
query_metadata: Dict[str, Any] = {}
|
||||
paused = False
|
||||
|
||||
# One id shared across the WAL row, primary LLM (token_usage
|
||||
# attribution), the SSE event, and resumed continuations.
|
||||
request_id = (
|
||||
_continuation.get("request_id") if _continuation else None
|
||||
) or str(uuid.uuid4())
|
||||
|
||||
# Reserve the placeholder row before the LLM call so a crash
|
||||
# mid-stream still leaves the question queryable. Continuations
|
||||
# reuse the original placeholder.
|
||||
reserved_message_id: Optional[str] = None
|
||||
wal_eligible = should_save_conversation and not _continuation
|
||||
if wal_eligible:
|
||||
try:
|
||||
reservation = self.conversation_service.save_user_question(
|
||||
conversation_id=conversation_id,
|
||||
question=question,
|
||||
decoded_token=decoded_token,
|
||||
attachment_ids=attachment_ids,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
model_id=model_id or self.default_model_id,
|
||||
request_id=request_id,
|
||||
index=index,
|
||||
)
|
||||
conversation_id = reservation["conversation_id"]
|
||||
reserved_message_id = reservation["message_id"]
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to reserve message row before stream: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
elif _continuation and _continuation.get("reserved_message_id"):
|
||||
reserved_message_id = _continuation["reserved_message_id"]
|
||||
|
||||
primary_llm = getattr(agent, "llm", None)
|
||||
if primary_llm is not None:
|
||||
primary_llm._request_id = request_id
|
||||
|
||||
# Flipped to ``streaming`` on first chunk; reconciler uses this
|
||||
# to tell "never started" from "in flight".
|
||||
streaming_marked = False
|
||||
# Heartbeat goes into ``metadata.last_heartbeat_at`` (not
|
||||
# ``updated_at``, which reconciler-side writes share) and uses
|
||||
# ``time.monotonic`` so a blocked event loop can't fake fresh.
|
||||
STREAM_HEARTBEAT_INTERVAL = 60
|
||||
last_heartbeat_at = time.monotonic()
|
||||
|
||||
def _mark_streaming_once() -> None:
|
||||
nonlocal streaming_marked, last_heartbeat_at
|
||||
if streaming_marked or not reserved_message_id:
|
||||
return
|
||||
try:
|
||||
self.conversation_service.update_message_status(
|
||||
reserved_message_id, "streaming",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"update_message_status streaming failed for %s",
|
||||
reserved_message_id,
|
||||
)
|
||||
# Seed last_heartbeat_at so watchdog doesn't fall back to `timestamp`
|
||||
# (creation time) before the first STREAM_HEARTBEAT_INTERVAL tick.
|
||||
try:
|
||||
self.conversation_service.heartbeat_message(
|
||||
reserved_message_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"initial heartbeat seed failed for %s",
|
||||
reserved_message_id,
|
||||
)
|
||||
streaming_marked = True
|
||||
last_heartbeat_at = time.monotonic()
|
||||
|
||||
def _heartbeat_streaming() -> None:
|
||||
nonlocal last_heartbeat_at
|
||||
if not reserved_message_id or not streaming_marked:
|
||||
return
|
||||
now_mono = time.monotonic()
|
||||
if now_mono - last_heartbeat_at < STREAM_HEARTBEAT_INTERVAL:
|
||||
return
|
||||
try:
|
||||
self.conversation_service.heartbeat_message(
|
||||
reserved_message_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"stream heartbeat update failed for %s",
|
||||
reserved_message_id,
|
||||
)
|
||||
last_heartbeat_at = now_mono
|
||||
|
||||
# Correlates tool_call_attempts rows with this message.
|
||||
if reserved_message_id and getattr(agent, "tool_executor", None):
|
||||
try:
|
||||
agent.tool_executor.message_id = reserved_message_id
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Could not set tool_executor.message_id; tool-call correlation will be missing for message_id=%s",
|
||||
reserved_message_id,
|
||||
)
|
||||
# The reservation above may create the conversation row (first turn in
|
||||
# a new chat). Propagate that fresh id to the tool_executor so tools
|
||||
# that need a conversation home (e.g. ``scheduler`` in agentless chats)
|
||||
# see it on the very first call instead of waiting for the next turn.
|
||||
if conversation_id and getattr(agent, "tool_executor", None):
|
||||
try:
|
||||
agent.tool_executor.conversation_id = str(conversation_id)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Could not set tool_executor.conversation_id post-reserve",
|
||||
)
|
||||
|
||||
# Per-stream monotonic SSE event id. Allocated by ``_emit`` and
|
||||
# threaded through both the wire format (``id: <seq>\\n``) and
|
||||
# the journal write so a reconnecting client can ``Last-Event-
|
||||
# ID`` past anything they already saw. Continuations resume
|
||||
# against the original ``reserved_message_id`` — seed the
|
||||
# allocator from the journal's high-water mark so we don't
|
||||
# collide on the duplicate-PK and silently lose every emit
|
||||
# past the resume point.
|
||||
sequence_no = -1
|
||||
if _continuation and reserved_message_id:
|
||||
try:
|
||||
from application.storage.db.repositories.message_events import (
|
||||
MessageEventsRepository,
|
||||
)
|
||||
|
||||
with db_readonly() as conn:
|
||||
latest = MessageEventsRepository(conn).latest_sequence_no(
|
||||
reserved_message_id
|
||||
)
|
||||
if latest is not None:
|
||||
sequence_no = latest
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Continuation seq seed lookup failed for message_id=%s; "
|
||||
"falling back to seq=-1 (duplicate-PK collisions will "
|
||||
"be swallowed)",
|
||||
reserved_message_id,
|
||||
)
|
||||
|
||||
# One batched journal writer per stream.
|
||||
journal_writer: Optional[BatchedJournalWriter] = (
|
||||
BatchedJournalWriter(reserved_message_id)
|
||||
if reserved_message_id
|
||||
else None
|
||||
)
|
||||
|
||||
def _emit(payload: dict) -> str:
|
||||
"""Format-and-journal one SSE event.
|
||||
|
||||
With a reserved ``message_id``, buffers into the journal and
|
||||
emits ``id: <seq>``-tagged SSE frames; otherwise falls back to
|
||||
legacy ``data: ...\\n\\n`` framing.
|
||||
"""
|
||||
nonlocal sequence_no
|
||||
if not reserved_message_id or journal_writer is None:
|
||||
return f"data: {json.dumps(payload)}\n\n"
|
||||
sequence_no += 1
|
||||
seq = sequence_no
|
||||
event_type = (
|
||||
payload.get("type", "data")
|
||||
if isinstance(payload, dict)
|
||||
else "data"
|
||||
)
|
||||
normalised = payload if isinstance(payload, dict) else {"value": payload}
|
||||
journal_writer.record(seq, event_type, normalised)
|
||||
return format_sse_event(normalised, seq)
|
||||
|
||||
try:
|
||||
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
query_metadata = {}
|
||||
paused = False
|
||||
# Surface the placeholder id before any LLM tokens so a
|
||||
# mid-handshake disconnect still has a row to tail-poll.
|
||||
if reserved_message_id:
|
||||
yield _emit(
|
||||
{
|
||||
"type": "message_id",
|
||||
"message_id": reserved_message_id,
|
||||
"conversation_id": (
|
||||
str(conversation_id) if conversation_id else None
|
||||
),
|
||||
"request_id": request_id,
|
||||
}
|
||||
)
|
||||
|
||||
if _continuation:
|
||||
gen_iter = agent.gen_continuation(
|
||||
@@ -222,18 +420,24 @@ class BaseAnswerResource:
|
||||
gen_iter = agent.gen(query=question)
|
||||
|
||||
for line in gen_iter:
|
||||
# Cheap closure check that only hits the DB when the
|
||||
# heartbeat interval has elapsed.
|
||||
_heartbeat_streaming()
|
||||
if "metadata" in line:
|
||||
query_metadata.update(line["metadata"])
|
||||
elif "answer" in line:
|
||||
_mark_streaming_once()
|
||||
response_full += str(line["answer"])
|
||||
if line.get("structured"):
|
||||
is_structured = True
|
||||
schema_info = line.get("schema")
|
||||
structured_chunks.append(line["answer"])
|
||||
else:
|
||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(
|
||||
{"type": "answer", "answer": line["answer"]}
|
||||
)
|
||||
elif "sources" in line:
|
||||
_mark_streaming_once()
|
||||
truncated_sources = []
|
||||
source_log_docs = line["sources"]
|
||||
for source in line["sources"]:
|
||||
@@ -244,54 +448,48 @@ class BaseAnswerResource:
|
||||
)
|
||||
truncated_sources.append(truncated_source)
|
||||
if truncated_sources:
|
||||
data = json.dumps(
|
||||
yield _emit(
|
||||
{"type": "source", "source": truncated_sources}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
elif "tool_calls" in line:
|
||||
tool_calls = line["tool_calls"]
|
||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
elif "thought" in line:
|
||||
thought += line["thought"]
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "thought", "thought": line["thought"]})
|
||||
elif "type" in line:
|
||||
if line.get("type") == "tool_calls_pending":
|
||||
# Save continuation state and end the stream
|
||||
paused = True
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(line)
|
||||
elif line.get("type") == "error":
|
||||
sanitized_error = {
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(line.get("error", "An error occurred"))
|
||||
}
|
||||
data = json.dumps(sanitized_error)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(
|
||||
{
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(
|
||||
line.get("error", "An error occurred")
|
||||
),
|
||||
}
|
||||
)
|
||||
else:
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(line)
|
||||
if is_structured and structured_chunks:
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
"answer": response_full,
|
||||
"structured": True,
|
||||
"schema": schema_info,
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(
|
||||
{
|
||||
"type": "structured_answer",
|
||||
"answer": response_full,
|
||||
"structured": True,
|
||||
"schema": schema_info,
|
||||
}
|
||||
)
|
||||
|
||||
# ---- Paused: save continuation state and end stream early ----
|
||||
if paused:
|
||||
continuation = getattr(agent, "_pending_continuation", None)
|
||||
if continuation:
|
||||
# Ensure we have a conversation_id — create a partial
|
||||
# conversation if this is the first turn.
|
||||
# First-turn pause needs a conversation row to attach to.
|
||||
if not conversation_id and should_save_conversation:
|
||||
try:
|
||||
# Use model-owner scope so shared-agent
|
||||
# owner-BYOM resolves to its registered plugin.
|
||||
provider = (
|
||||
get_provider_from_model_id(
|
||||
model_id,
|
||||
@@ -340,6 +538,7 @@ class BaseAnswerResource:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
state_saved = False
|
||||
if conversation_id:
|
||||
try:
|
||||
cont_service = ContinuationService()
|
||||
@@ -352,8 +551,8 @@ class BaseAnswerResource:
|
||||
tool_schemas=getattr(agent, "tools", []),
|
||||
agent_config={
|
||||
"model_id": model_id or self.default_model_id,
|
||||
# Persist BYOM scope so resume doesn't
|
||||
# fall back to caller's layer.
|
||||
# BYOM scope; without it resume falls
|
||||
# back to caller's layer.
|
||||
"model_user_id": model_user_id,
|
||||
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
|
||||
"api_key": getattr(agent, "api_key", None),
|
||||
@@ -363,30 +562,81 @@ class BaseAnswerResource:
|
||||
"prompt": getattr(agent, "prompt", ""),
|
||||
"json_schema": getattr(agent, "json_schema", None),
|
||||
"retriever_config": getattr(agent, "retriever_config", None),
|
||||
# Reused on resume so the same WAL row
|
||||
# is finalised and request_id stays
|
||||
# consistent across token_usage rows.
|
||||
"reserved_message_id": reserved_message_id,
|
||||
"request_id": request_id,
|
||||
},
|
||||
client_tools=getattr(
|
||||
agent.tool_executor, "client_tools", None
|
||||
),
|
||||
)
|
||||
state_saved = True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save continuation state: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
# Notify the user out-of-band so they can navigate
|
||||
# back to the conversation and decide on the
|
||||
# pending tool calls. Gated on ``state_saved``: a
|
||||
# missing pending_tool_state row would 404 the
|
||||
# resume endpoint, so an unfulfillable notification
|
||||
# is worse than no notification.
|
||||
user_id_for_event = (
|
||||
decoded_token.get("sub") if decoded_token else None
|
||||
)
|
||||
if state_saved and user_id_for_event and conversation_id:
|
||||
pending_calls = continuation.get(
|
||||
"pending_tool_calls", []
|
||||
) if continuation else []
|
||||
# Trim each pending tool call to its identifying
|
||||
# metadata so a tool with a multi-MB argument
|
||||
# doesn't blow out the per-event payload size
|
||||
# cap. The resume page fetches full args from
|
||||
# ``pending_tool_state`` regardless.
|
||||
pending_summaries = [
|
||||
{
|
||||
k: tc.get(k)
|
||||
for k in (
|
||||
"call_id",
|
||||
"tool_name",
|
||||
"action_name",
|
||||
"name",
|
||||
)
|
||||
if isinstance(tc, dict) and tc.get(k) is not None
|
||||
}
|
||||
for tc in (pending_calls or [])
|
||||
if isinstance(tc, dict)
|
||||
]
|
||||
publish_user_event(
|
||||
user_id_for_event,
|
||||
"tool.approval.required",
|
||||
{
|
||||
"conversation_id": str(conversation_id),
|
||||
"message_id": reserved_message_id,
|
||||
"pending_tool_calls": pending_summaries,
|
||||
},
|
||||
scope={
|
||||
"kind": "conversation",
|
||||
"id": str(conversation_id),
|
||||
},
|
||||
)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "id", "id": str(conversation_id)})
|
||||
yield _emit({"type": "end"})
|
||||
# Drain the terminal ``end`` so a reconnecting client
|
||||
# sees it on snapshot — same reason as the main exit.
|
||||
if journal_writer is not None:
|
||||
journal_writer.close()
|
||||
return
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
# Run under model-owner scope so title-gen LLM inside
|
||||
# save_conversation uses the owner's BYOM provider/key.
|
||||
# Model-owner scope so title-gen uses owner's BYOM key.
|
||||
provider = (
|
||||
get_provider_from_model_id(
|
||||
model_id,
|
||||
@@ -407,26 +657,49 @@ class BaseAnswerResource:
|
||||
agent_id=agent_id,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
# Title-gen only; agent stream tokens live on ``agent.llm``.
|
||||
llm._token_usage_source = "title"
|
||||
|
||||
if should_save_conversation:
|
||||
conversation_id = self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
if reserved_message_id is not None:
|
||||
self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full,
|
||||
thought=thought,
|
||||
sources=source_log_docs,
|
||||
tool_calls=tool_calls,
|
||||
model_id=model_id or self.default_model_id,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
status="complete",
|
||||
title_inputs={
|
||||
"llm": llm,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"model_id": model_id or self.default_model_id,
|
||||
"fallback_name": (
|
||||
question[:50] if question else "New Conversation"
|
||||
),
|
||||
},
|
||||
)
|
||||
else:
|
||||
conversation_id = self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
@@ -449,9 +722,22 @@ class BaseAnswerResource:
|
||||
)
|
||||
else:
|
||||
conversation_id = None
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
# Resume finished cleanly; drop the continuation row.
|
||||
# Crash-paths leave it ``resuming`` for the janitor to revert.
|
||||
if _continuation and conversation_id:
|
||||
try:
|
||||
cont_service = ContinuationService()
|
||||
cont_service.delete_state(
|
||||
str(conversation_id),
|
||||
decoded_token.get("sub", "local"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete continuation state on resume "
|
||||
f"completion: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
yield _emit({"type": "id", "id": str(conversation_id)})
|
||||
|
||||
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
|
||||
getattr(agent, "tool_calls", tool_calls) or tool_calls
|
||||
@@ -492,21 +778,40 @@ class BaseAnswerResource:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "end"})
|
||||
# Drain the journal buffer so the terminal ``end`` event is
|
||||
# visible to any reconnecting client. Without this the
|
||||
# client could snapshot up to the last flush boundary and
|
||||
# then live-tail waiting for an ``end`` that's still
|
||||
# sitting in memory.
|
||||
if journal_writer is not None:
|
||||
journal_writer.close()
|
||||
except GeneratorExit:
|
||||
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
|
||||
# Drain any buffered events before the terminal one-shot
|
||||
# ``record_event`` below — keeps the journal's seq order
|
||||
# contiguous (buffered events ... terminal event). ``close``
|
||||
# is idempotent; pairing it with ``flush`` matches the
|
||||
# normal-exit and error branches so any future ``record()``
|
||||
# past this point would log instead of silently buffering.
|
||||
if journal_writer is not None:
|
||||
journal_writer.flush()
|
||||
journal_writer.close()
|
||||
# Save partial response
|
||||
|
||||
# Whether the DB row was flipped to ``complete`` during this
|
||||
# abort handler. Drives the choice of terminal journal event
|
||||
# below: journal ``end`` only when the row actually matches,
|
||||
# else journal ``error`` so a reconnecting client sees a
|
||||
# failed terminal state instead of a blank "success".
|
||||
finalized_complete = False
|
||||
if should_save_conversation and response_full:
|
||||
try:
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
# Mirror the normal-path provider resolution so the
|
||||
# partial-save title LLM uses the model-owner's BYOM
|
||||
# registration (shared-agent dispatch) rather than
|
||||
# the deployment default with the instance api key.
|
||||
# Resolve under model-owner scope so shared-agent
|
||||
# title-gen uses owner BYOM, not deployment default.
|
||||
provider = (
|
||||
get_provider_from_model_id(
|
||||
model_id,
|
||||
@@ -532,24 +837,58 @@ class BaseAnswerResource:
|
||||
agent_id=agent_id,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
llm._token_usage_source = "title"
|
||||
if reserved_message_id is not None:
|
||||
outcome = self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full,
|
||||
thought=thought,
|
||||
sources=source_log_docs,
|
||||
tool_calls=tool_calls,
|
||||
model_id=model_id or self.default_model_id,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
status="complete",
|
||||
title_inputs={
|
||||
"llm": llm,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"model_id": model_id or self.default_model_id,
|
||||
"fallback_name": (
|
||||
question[:50] if question else "New Conversation"
|
||||
),
|
||||
},
|
||||
)
|
||||
# ``ALREADY_COMPLETE`` means the normal-path
|
||||
# finalize at line 632 won the race: the DB row
|
||||
# is already at ``complete`` and the reconnect
|
||||
# journal should reflect that with ``end``,
|
||||
# not a spurious ``error``.
|
||||
finalized_complete = outcome in (
|
||||
MessageUpdateOutcome.UPDATED,
|
||||
MessageUpdateOutcome.ALREADY_COMPLETE,
|
||||
)
|
||||
else:
|
||||
self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
# No journal row to gate, but flag the save as
|
||||
# successful for symmetry with the WAL path.
|
||||
finalized_complete = True
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
if conversation_id and compression_meta and not compression_saved:
|
||||
@@ -573,16 +912,94 @@ class BaseAnswerResource:
|
||||
logger.error(
|
||||
f"Error saving partial response: {str(e)}", exc_info=True
|
||||
)
|
||||
# Journal a terminal event so reconnecting clients stop tailing;
|
||||
# ``end`` only when the row is ``complete``, else ``error``.
|
||||
if reserved_message_id is not None:
|
||||
try:
|
||||
sequence_no += 1
|
||||
if finalized_complete:
|
||||
# Match the wire shape ``_emit({"type": "end"})``
|
||||
# uses on the normal path — the replay terminal
|
||||
# check at ``event_replay._payload_is_terminal``
|
||||
# reads ``payload.type``, and the frontend parses
|
||||
# the same key off ``data:``.
|
||||
record_event(
|
||||
reserved_message_id,
|
||||
sequence_no,
|
||||
"end",
|
||||
{"type": "end"},
|
||||
)
|
||||
else:
|
||||
# Nothing was persisted under the complete status
|
||||
# — mark the row failed so the reconciler doesn't
|
||||
# need to sweep it, and journal an ``error`` so a
|
||||
# reconnecting client surfaces the same failure
|
||||
# the UI would show on a live error.
|
||||
try:
|
||||
self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full or TERMINATED_RESPONSE_PLACEHOLDER,
|
||||
thought=thought,
|
||||
sources=source_log_docs,
|
||||
tool_calls=tool_calls,
|
||||
model_id=model_id or self.default_model_id,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
status="failed",
|
||||
error=ConnectionError(
|
||||
"client disconnected before response was persisted"
|
||||
),
|
||||
)
|
||||
except Exception as fin_err:
|
||||
logger.error(
|
||||
f"Failed to mark aborted message failed: {fin_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
record_event(
|
||||
reserved_message_id,
|
||||
sequence_no,
|
||||
"error",
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Stream aborted before any response was produced.",
|
||||
"code": "client_disconnect",
|
||||
},
|
||||
)
|
||||
except Exception as journal_err:
|
||||
logger.error(
|
||||
f"Failed to journal terminal event on abort: {journal_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
||||
data = json.dumps(
|
||||
if reserved_message_id is not None:
|
||||
try:
|
||||
self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full or TERMINATED_RESPONSE_PLACEHOLDER,
|
||||
thought=thought,
|
||||
sources=source_log_docs,
|
||||
tool_calls=tool_calls,
|
||||
model_id=model_id or self.default_model_id,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
status="failed",
|
||||
error=e,
|
||||
)
|
||||
except Exception as fin_err:
|
||||
logger.error(
|
||||
f"Failed to finalize errored message: {fin_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
yield _emit(
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Please try again later. We apologize for any inconvenience.",
|
||||
}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
# Drain the terminal ``error`` event we just yielded so a
|
||||
# reconnecting client sees it on snapshot.
|
||||
if journal_writer is not None:
|
||||
journal_writer.close()
|
||||
return
|
||||
|
||||
def process_response_stream(self, stream) -> Dict[str, Any]:
|
||||
@@ -604,8 +1021,22 @@ class BaseAnswerResource:
|
||||
|
||||
for line in stream:
|
||||
try:
|
||||
event_data = line.replace("data: ", "").strip()
|
||||
# Each chunk may carry an ``id: <seq>`` header before
|
||||
# the ``data:`` line. Pull just the ``data:`` body so
|
||||
# the JSON decode doesn't choke on the SSE framing.
|
||||
event_data = ""
|
||||
for raw in line.split("\n"):
|
||||
if raw.startswith("data:"):
|
||||
event_data = raw[len("data:") :].lstrip()
|
||||
break
|
||||
if not event_data:
|
||||
continue
|
||||
event = json.loads(event_data)
|
||||
# The ``message_id`` event is informational for the
|
||||
# streaming consumer and has no synchronous-API field;
|
||||
# skip it so the type-switch below doesn't KeyError.
|
||||
if event.get("type") == "message_id":
|
||||
continue
|
||||
|
||||
if event["type"] == "id":
|
||||
conversation_id = event["id"]
|
||||
|
||||
135
application/api/answer/routes/messages.py
Normal file
135
application/api/answer/routes/messages.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""GET /api/messages/<message_id>/events — chat-stream reconnect endpoint.
|
||||
|
||||
Authenticates the caller, verifies ``message_id`` belongs to the user,
|
||||
then hands off to ``build_message_event_stream`` for snapshot+tail.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.streaming.event_replay import (
|
||||
DEFAULT_KEEPALIVE_SECONDS,
|
||||
DEFAULT_POLL_TIMEOUT_SECONDS,
|
||||
build_message_event_stream,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
messages_bp = Blueprint("message_stream", __name__)
|
||||
|
||||
# A message_id is the canonical UUID hex format. Reject anything else
|
||||
# before the SQL layer so a malformed cookie can't surface as a 500.
|
||||
_MESSAGE_ID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
|
||||
r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
# ``sequence_no`` is a non-negative decimal integer. Anything else is
|
||||
# corrupt client state — fall through to a fresh-replay cursor and let
|
||||
# the snapshot reader catch the client up.
|
||||
_SEQUENCE_NO_RE = re.compile(r"^\d+$")
|
||||
|
||||
|
||||
def _normalise_last_event_id(raw: Optional[str]) -> Optional[int]:
|
||||
if raw is None:
|
||||
return None
|
||||
raw = raw.strip()
|
||||
if not raw or not _SEQUENCE_NO_RE.match(raw):
|
||||
return None
|
||||
return int(raw)
|
||||
|
||||
|
||||
def _user_owns_message(message_id: str, user_id: str) -> bool:
    """Check whether ``message_id`` is a message owned by ``user_id``.

    Runs one read-only lookup against ``conversation_messages``. Any
    exception raised while doing so is logged and reported as "not owned",
    so a failing lookup denies access instead of propagating.
    """
    try:
        ownership_query = text(
            """
            SELECT 1 FROM conversation_messages
            WHERE id = CAST(:id AS uuid)
            AND user_id = :u
            LIMIT 1
            """
        )
        bind_params = {"id": message_id, "u": user_id}
        with db_readonly() as conn:
            match = conn.execute(ownership_query, bind_params).first()
    except Exception:
        logger.exception(
            "Ownership lookup failed for message_id=%s user_id=%s",
            message_id,
            user_id,
        )
        return False
    else:
        return match is not None
|
||||
|
||||
|
||||
@messages_bp.route("/api/messages/<message_id>/events", methods=["GET"])
def stream_message_events(message_id: str) -> Response:
    """Reconnect endpoint: replay and tail the event stream of one message.

    Responds with 401 when ``request.decoded_token`` carries no ``sub``
    claim, 400 when ``message_id`` is not UUID-shaped, and 404 when the
    ownership check fails. Otherwise returns a ``text/event-stream``
    response fed by ``build_message_event_stream``, resuming after the
    cursor read from the ``Last-Event-ID`` header or, failing that, the
    ``last_event_id`` query parameter.
    """
    token = getattr(request, "decoded_token", None)
    user_id = token.get("sub") if isinstance(token, dict) else None

    if not user_id:
        unauthenticated = jsonify(
            {"success": False, "message": "Authentication required"}
        )
        return make_response(unauthenticated, 401)

    if not _MESSAGE_ID_RE.match(message_id):
        invalid = jsonify({"success": False, "message": "Invalid message id"})
        return make_response(invalid, 400)

    if not _user_owns_message(message_id, user_id):
        # One indistinguishable 404 for "bogus id", "someone else's message"
        # and "no longer exists", so the response never reveals whether the
        # row is there.
        missing = jsonify({"success": False, "message": "Not found"})
        return make_response(missing, 404)

    header_cursor = request.headers.get("Last-Event-ID")
    raw_cursor = header_cursor or request.args.get("last_event_id")
    last_event_id = _normalise_last_event_id(raw_cursor)
    configured_keepalive = getattr(
        settings, "SSE_KEEPALIVE_SECONDS", DEFAULT_KEEPALIVE_SECONDS
    )
    keepalive_seconds = float(configured_keepalive)

    @stream_with_context
    def generate() -> Iterator[str]:
        try:
            yield from build_message_event_stream(
                message_id,
                last_event_id=last_event_id,
                keepalive_seconds=keepalive_seconds,
                poll_timeout_seconds=DEFAULT_POLL_TIMEOUT_SECONDS,
            )
        except GeneratorExit:
            # The generator is being closed; end quietly.
            return
        except Exception:
            # Log and let the stream end rather than raising mid-response.
            logger.exception(
                "Reconnect stream crashed for message_id=%s user_id=%s",
                message_id,
                user_id,
            )

    response = Response(generate(), mimetype="text/event-stream")
    streaming_headers = (
        ("Cache-Control", "no-store"),
        ("X-Accel-Buffering", "no"),
        ("Connection", "keep-alive"),
    )
    for header_name, header_value in streaming_headers:
        response.headers[header_name] = header_value

    cursor_for_log = last_event_id if last_event_id is not None else "-"
    logger.info(
        "message.event.connect message_id=%s user_id=%s last_event_id=%s",
        message_id,
        user_id,
        cursor_for_log,
    )
    return response
|
||||
@@ -115,6 +115,8 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
"reserved_message_id": processor.reserved_message_id,
|
||||
"request_id": processor.request_id,
|
||||
},
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
|
||||
@@ -160,6 +160,9 @@ class CompressionOrchestrator:
|
||||
agent_id=conversation.get("agent_id"),
|
||||
model_user_id=registry_user_id,
|
||||
)
|
||||
# Side-channel LLM tag — distinguishes compression rows
|
||||
# from primary stream rows for cost-attribution dashboards.
|
||||
compression_llm._token_usage_source = "compression"
|
||||
|
||||
# Create compression service with DB update capability
|
||||
compression_service = CompressionService(
|
||||
|
||||
@@ -12,6 +12,12 @@ logger = logging.getLogger(__name__)
|
||||
class TokenCounter:
|
||||
"""Centralized token counting for conversations and messages."""
|
||||
|
||||
# Per-image token estimate. Provider tokenizers vary widely
|
||||
# (Gemini ~258, GPT-4o 85-1500, Claude ~1500) and the actual cost
|
||||
# depends on resolution/detail we can't see here. Errs slightly high
|
||||
# so the threshold check stays conservative.
|
||||
_IMAGE_PART_TOKEN_ESTIMATE = 1500
|
||||
|
||||
@staticmethod
|
||||
def count_message_tokens(messages: List[Dict]) -> int:
|
||||
"""
|
||||
@@ -29,12 +35,36 @@ class TokenCounter:
|
||||
if isinstance(content, str):
|
||||
total_tokens += num_tokens_from_string(content)
|
||||
elif isinstance(content, list):
|
||||
# Handle structured content (tool calls, etc.)
|
||||
# Handle structured content (tool calls, image parts, etc.)
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
total_tokens += num_tokens_from_string(str(item))
|
||||
total_tokens += TokenCounter._count_content_part(item)
|
||||
return total_tokens
|
||||
|
||||
@staticmethod
def _count_content_part(item: Dict) -> int:
    """Estimate the token cost of one structured content part.

    Attachment-like parts are charged a flat per-item estimate instead of
    being tokenised: stringifying inline image/file data would inflate the
    count far beyond what a provider bills for the attachment. Every other
    part is counted from its ``str()`` form.
    """
    flat_estimate = TokenCounter._IMAGE_PART_TOKEN_ESTIMATE

    if "files" in item:
        attached = item.get("files")
        # A non-empty list is charged per entry; anything else counts once.
        if isinstance(attached, list) and attached:
            return flat_estimate * len(attached)
        return flat_estimate

    attachment_types = {"image", "image_url", "input_image", "file"}
    if "image_url" in item or item.get("type") in attachment_types:
        return flat_estimate

    return num_tokens_from_string(str(item))
|
||||
|
||||
@staticmethod
|
||||
def count_query_tokens(
|
||||
queries: List[Dict[str, Any]], include_tool_calls: bool = True
|
||||
|
||||
@@ -7,13 +7,13 @@ resume later by sending tool_actions.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.pending_tool_state import (
|
||||
PendingToolStateRepository,
|
||||
)
|
||||
from application.storage.db.serialization import coerce_pg_native as _make_serializable
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -21,23 +21,9 @@ logger = logging.getLogger(__name__)
|
||||
# TTL for pending states — auto-cleaned after this period
|
||||
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _make_serializable(obj: Any) -> Any:
|
||||
"""Recursively coerce non-JSON values into JSON-safe forms.
|
||||
|
||||
Handles ``uuid.UUID`` (from PG columns), ``bytes``, and recurses into
|
||||
dicts/lists. Post-Mongo-cutover the ObjectId branch is gone — none of
|
||||
our writers produce them anymore.
|
||||
"""
|
||||
if isinstance(obj, UUID):
|
||||
return str(obj)
|
||||
if isinstance(obj, dict):
|
||||
return {str(k): _make_serializable(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_make_serializable(v) for v in obj]
|
||||
if isinstance(obj, bytes):
|
||||
return obj.decode("utf-8", errors="replace")
|
||||
return obj
|
||||
# Re-export so the existing tests at tests/api/answer/services/test_continuation_service_pg.py
|
||||
# can keep importing ``_make_serializable`` from here.
|
||||
__all__ = ["_make_serializable", "ContinuationService", "PENDING_STATE_TTL_SECONDS"]
|
||||
|
||||
|
||||
class ContinuationService:
|
||||
@@ -155,3 +141,23 @@ class ContinuationService:
|
||||
f"Deleted continuation state for conversation {conversation_id}"
|
||||
)
|
||||
return deleted
|
||||
|
||||
def mark_resuming(self, conversation_id: str, user: str) -> bool:
    """Mark the conversation's pending tool state as ``resuming``.

    ``conversation_id`` is first resolved through the legacy-id lookup; if
    that finds nothing it is used as-is when it looks like a UUID. Returns
    ``False`` when the id cannot be resolved, otherwise whatever the
    repository reports for the flip.
    """
    with db_session() as conn:
        legacy_match = ConversationsRepository(conn).get_by_legacy_id(
            conversation_id
        )
        if legacy_match is None and not looks_like_uuid(conversation_id):
            return False
        target_id = (
            conversation_id if legacy_match is None else legacy_match["id"]
        )
        was_flipped = PendingToolStateRepository(conn).mark_resuming(
            target_id, user
        )
    if was_flipped:
        logger.info(
            f"Marked continuation state as resuming for conversation "
            f"{conversation_id}"
        )
    return was_flipped
|
||||
|
||||
@@ -6,6 +6,7 @@ than held for the duration of a stream.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -14,13 +15,22 @@ from sqlalchemy import text as sql_text
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.conversations import (
|
||||
ConversationsRepository,
|
||||
MessageUpdateOutcome,
|
||||
)
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Shown to the user if the worker dies mid-stream and the response is never finalised.
|
||||
TERMINATED_RESPONSE_PLACEHOLDER = (
|
||||
"Response was terminated prior to completion, try regenerating."
|
||||
)
|
||||
|
||||
|
||||
class ConversationService:
|
||||
def get_conversation(
|
||||
self, conversation_id: str, user_id: str
|
||||
@@ -179,6 +189,243 @@ class ConversationService:
|
||||
repo.append_message(conv_pg_id, append_payload)
|
||||
return conv_pg_id
|
||||
|
||||
def save_user_question(
    self,
    conversation_id: Optional[str],
    question: str,
    decoded_token: Dict[str, Any],
    *,
    attachment_ids: Optional[List[str]] = None,
    api_key: Optional[str] = None,
    agent_id: Optional[str] = None,
    is_shared_usage: bool = False,
    shared_token: Optional[str] = None,
    model_id: Optional[str] = None,
    request_id: Optional[str] = None,
    status: str = "pending",
    index: Optional[int] = None,
) -> Dict[str, str]:
    """Reserve the placeholder message row before the LLM call.

    ``index`` triggers regenerate semantics: messages at
    ``position >= index`` are truncated so the new placeholder
    lands at ``position = index`` rather than appending.

    Args:
        conversation_id: Existing conversation to add to; when falsy a new
            conversation is created for the user.
        question: The user's prompt, stored on the reserved row. Its first
            50 characters also name a newly created conversation.
        decoded_token: Auth claims; ``sub`` is used as the user id.
        attachment_ids: Passed through to the reserved row as attachments.
        api_key: Agent key; only looked up when no ``conversation_id`` is
            given.
        agent_id: Agent to bind a newly created conversation to.
        is_shared_usage: Recorded on a new conversation only when an agent
            id is present.
        shared_token: Recorded on a new conversation only when an agent id
            is present and ``is_shared_usage`` is true.
        model_id: Passed through to the reserved row.
        request_id: Reused when given; otherwise a fresh UUID4 string.
        status: Initial status of the reserved row.
        index: Non-negative position to regenerate from (see above).

    Returns ``{"conversation_id", "message_id", "request_id"}``.

    Raises:
        ValueError: If ``decoded_token`` is ``None``, has no ``sub``, or
            ``conversation_id`` does not resolve for this user.
    """
    if decoded_token is None:
        raise ValueError("Invalid or missing authentication token")
    user_id = decoded_token.get("sub")
    if not user_id:
        raise ValueError("User ID not found in token")

    request_id = request_id or str(uuid.uuid4())

    # These only feed ``repo.create`` below, i.e. the new-conversation path.
    resolved_api_key: Optional[str] = None
    resolved_agent_id: Optional[str] = None
    if api_key and not conversation_id:
        with db_readonly() as conn:
            agent = AgentsRepository(conn).find_by_key(api_key)
        if agent:
            resolved_api_key = agent.get("key")
    if agent_id:
        resolved_agent_id = agent_id

    with db_session() as conn:
        repo = ConversationsRepository(conn)
        if conversation_id:
            conv = repo.get_any(conversation_id, user_id)
            if conv is None:
                raise ValueError("Conversation not found or unauthorized")
            conv_pg_id = str(conv["id"])
            # Regenerate / edit-prior-question: drop the message at
            # ``index`` and everything after it so the new
            # ``reserve_message`` lands at ``position=index`` rather
            # than appending at the end of the conversation.
            if isinstance(index, int) and index >= 0:
                repo.truncate_after(conv_pg_id, keep_up_to=index - 1)
        else:
            fallback_name = (question[:50] if question else "New Conversation")
            conv = repo.create(
                user_id,
                fallback_name,
                agent_id=resolved_agent_id,
                api_key=resolved_api_key,
                is_shared_usage=bool(resolved_agent_id and is_shared_usage),
                shared_token=(
                    shared_token
                    if (resolved_agent_id and is_shared_usage)
                    else None
                ),
            )
            conv_pg_id = str(conv["id"])

        # The row starts out holding the "terminated" placeholder text as
        # its response.
        row = repo.reserve_message(
            conv_pg_id,
            prompt=question,
            placeholder_response=TERMINATED_RESPONSE_PLACEHOLDER,
            request_id=request_id,
            status=status,
            attachments=attachment_ids,
            model_id=model_id,
        )
        message_id = str(row["id"])

    return {
        "conversation_id": conv_pg_id,
        "message_id": message_id,
        "request_id": request_id,
    }
|
||||
|
||||
def update_message_status(self, message_id: str, status: str) -> bool:
    """Set only the status of a message row (e.g. ``pending → streaming``).

    Returns ``False`` without touching the database when ``message_id`` is
    falsy; otherwise returns the repository's result.
    """
    if message_id:
        with db_session() as conn:
            repo = ConversationsRepository(conn)
            return repo.update_message_status(message_id, status)
    return False
|
||||
|
||||
def heartbeat_message(self, message_id: str) -> bool:
    """Refresh ``message_metadata.last_heartbeat_at`` for a message.

    Keeps the row counted as alive by the reconciler's staleness sweep;
    terminal rows are left alone. Returns ``False`` without touching the
    database when ``message_id`` is falsy.
    """
    if message_id:
        with db_session() as conn:
            repo = ConversationsRepository(conn)
            return repo.heartbeat_message(message_id)
    return False
|
||||
|
||||
def finalize_message(
    self,
    message_id: str,
    response: str,
    *,
    thought: str = "",
    sources: Optional[List[Dict[str, Any]]] = None,
    tool_calls: Optional[List[Dict[str, Any]]] = None,
    model_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    status: str = "complete",
    error: Optional[BaseException] = None,
    title_inputs: Optional[Dict[str, Any]] = None,
) -> MessageUpdateOutcome:
    """Commit the response and tool_call confirms in one transaction.

    The outcome propagates directly from ``update_message_by_id`` so
    callers (notably the SSE abort handler) can tell a fresh
    finalize from "the row was already terminal" — the latter must
    still be treated as success when the prior state was
    ``complete``.

    Args:
        message_id: Row to finalise; a falsy value short-circuits to
            ``MessageUpdateOutcome.INVALID``.
        response: Final response text.
        thought: Stored alongside the response.
        sources: Source dicts; each string ``text`` is cut to 1000
            characters. NOTE: the truncation is applied to the passed-in
            dicts themselves, so the caller's objects are modified.
        tool_calls: Stored on the row; ``None`` becomes ``[]``.
        model_id: Only written when not ``None``.
        metadata: Copied before use; ``error`` is added to the copy when
            ``status == "failed"`` and ``error`` is given, unless the key
            is already present.
        status: Status to write.
        error: Exception to record in metadata for failed rows.
        title_inputs: When truthy and ``status == "complete"``, triggers a
            title-generation attempt after the update.

    Returns:
        ``MessageUpdateOutcome.UPDATED`` on a successful write, otherwise
        the non-``UPDATED`` outcome reported by the repository.
    """
    if not message_id:
        return MessageUpdateOutcome.INVALID
    sources = sources or []
    for source in sources:
        if "text" in source and isinstance(source["text"], str):
            source["text"] = source["text"][:1000]

    merged_metadata: Dict[str, Any] = dict(metadata or {})
    if status == "failed" and error is not None:
        # ``setdefault`` keeps an ``error`` the caller already supplied.
        merged_metadata.setdefault(
            "error", f"{type(error).__name__}: {str(error)}"
        )

    update_fields: Dict[str, Any] = {
        "response": response,
        "status": status,
        "thought": thought,
        "sources": sources,
        "tool_calls": tool_calls or [],
        "metadata": merged_metadata,
    }
    if model_id is not None:
        update_fields["model_id"] = model_id

    # Atomic message update + tool_call_attempts confirm; the
    # ``only_if_non_terminal`` guard prevents a late stream from
    # retracting a row the reconciler already escalated.
    with db_session() as conn:
        repo = ConversationsRepository(conn)
        outcome = repo.update_message_by_id(
            message_id, update_fields,
            only_if_non_terminal=True,
        )
        if outcome is not MessageUpdateOutcome.UPDATED:
            logger.warning(
                f"finalize_message: no row updated for message_id={message_id} "
                f"(outcome={outcome.value} — possibly already terminal)"
            )
            # Skip the tool-call confirm and title generation entirely.
            return outcome
        repo.confirm_executed_tool_calls(message_id)

    # Outside the txn — title-gen is a multi-second LLM round trip.
    if title_inputs and status == "complete":
        try:
            with db_session() as conn:
                self._maybe_generate_title(conn, message_id, title_inputs)
        except Exception as e:
            # Best-effort: a title failure must not undo the finalize.
            logger.error(
                f"finalize_message title generation failed: {e}",
                exc_info=True,
            )
    return MessageUpdateOutcome.UPDATED
|
||||
|
||||
def _maybe_generate_title(
    self,
    conn,
    message_id: str,
    title_inputs: Dict[str, Any],
) -> None:
    """Generate an LLM-summarised conversation name if one isn't set yet.

    Args:
        conn: Database connection used for both the lookup and the update.
        message_id: Message whose parent conversation is to be named.
        title_inputs: Reads the keys ``llm``, ``question``, ``response``,
            ``fallback_name`` and ``model_id``. Nothing happens when
            ``llm`` is missing.
    """
    llm = title_inputs.get("llm")
    question = title_inputs.get("question") or ""
    response = title_inputs.get("response") or ""
    fallback_name = title_inputs.get("fallback_name") or question[:50]
    if llm is None:
        return

    row = conn.execute(
        sql_text(
            "SELECT c.id, c.name FROM conversation_messages m "
            "JOIN conversations c ON c.id = m.conversation_id "
            "WHERE m.id = CAST(:mid AS uuid)"
        ),
        {"mid": message_id},
    ).fetchone()
    if row is None:
        return
    conv_id, current_name = str(row[0]), row[1]
    # A name that differs from the fallback is kept as-is; only an empty
    # name or the fallback name gets replaced.
    if current_name and current_name != fallback_name:
        return

    messages_summary = [
        {
            "role": "system",
            "content": "You are a helpful assistant that creates concise conversation titles. "
            "Summarize conversations in 3 words or less using the same language as the user.",
        },
        {
            "role": "user",
            "content": "Summarise following conversation in no more than 3 words, "
            "respond ONLY with the summary, use the same language as the "
            "user query \n\nUser: " + question + "\n\n" + "AI: " + response,
        },
    ]
    # NOTE(review): this LLM call sits between the SELECT above and the
    # UPDATE below on the same ``conn``; whether that keeps a transaction
    # open for the duration depends on how the caller's session behaves —
    # confirm.
    completion = llm.gen(
        model=getattr(llm, "model_id", None) or title_inputs.get("model_id"),
        messages=messages_summary,
        max_tokens=500,
    )
    # An empty or whitespace-only completion falls back to the default name.
    if not completion or not completion.strip():
        completion = fallback_name or "New Conversation"
    conn.execute(
        sql_text(
            "UPDATE conversations SET name = :name, updated_at = now() "
            "WHERE id = CAST(:id AS uuid)"
        ),
        {"id": conv_id, "name": completion.strip()},
    )
|
||||
|
||||
def update_compression_metadata(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
|
||||
@@ -6,6 +6,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.agents.default_tools import synthesized_default_tools
|
||||
from application.api.answer.services.compression import CompressionOrchestrator
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
@@ -25,6 +26,7 @@ from application.storage.db.repositories.attachments import AttachmentsRepositor
|
||||
from application.storage.db.repositories.prompts import PromptsRepository
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import (
|
||||
@@ -123,6 +125,10 @@ class StreamProcessor:
|
||||
self.model_id: Optional[str] = None
|
||||
# BYOM-resolution scope, set by _validate_and_set_model.
|
||||
self.model_user_id: Optional[str] = None
|
||||
# WAL placeholder id pulled from continuation state on resume.
|
||||
self.reserved_message_id: Optional[str] = None
|
||||
# Carried through resumes so multi-pause runs keep one request_id.
|
||||
self.request_id: Optional[str] = None
|
||||
self.conversation_service = ConversationService()
|
||||
self.compression_orchestrator = CompressionOrchestrator(
|
||||
self.conversation_service
|
||||
@@ -289,7 +295,7 @@ class StreamProcessor:
|
||||
return attachments
|
||||
|
||||
def _validate_and_set_model(self):
|
||||
"""Validate and set model_id from request"""
|
||||
"""Pick model_id with agent authority on agent-bound chats."""
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
requested_model = self.data.get("model_id")
|
||||
@@ -298,6 +304,20 @@ class StreamProcessor:
|
||||
caller_user_id = self.initial_user_id
|
||||
owner_user_id = self.agent_config.get("user_id") or caller_user_id
|
||||
|
||||
# Agent-bound: agent's default_model_id wins, body's model_id is dropped.
|
||||
agent_bound = self._agent_data is not None
|
||||
if agent_bound:
|
||||
agent_default_model = self.agent_config.get("default_model_id", "")
|
||||
if agent_default_model and validate_model_id(
|
||||
agent_default_model, user_id=owner_user_id
|
||||
):
|
||||
self.model_id = agent_default_model
|
||||
self.model_user_id = owner_user_id
|
||||
else:
|
||||
self.model_id = get_default_model_id()
|
||||
self.model_user_id = None
|
||||
return
|
||||
|
||||
if requested_model:
|
||||
if not validate_model_id(requested_model, user_id=caller_user_id):
|
||||
registry = ModelRegistry.get_instance()
|
||||
@@ -317,15 +337,8 @@ class StreamProcessor:
|
||||
self.model_id = requested_model
|
||||
self.model_user_id = caller_user_id
|
||||
else:
|
||||
agent_default_model = self.agent_config.get("default_model_id", "")
|
||||
if agent_default_model and validate_model_id(
|
||||
agent_default_model, user_id=owner_user_id
|
||||
):
|
||||
self.model_id = agent_default_model
|
||||
self.model_user_id = owner_user_id
|
||||
else:
|
||||
self.model_id = get_default_model_id()
|
||||
self.model_user_id = None
|
||||
self.model_id = get_default_model_id()
|
||||
self.model_user_id = None
|
||||
|
||||
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||||
"""Get API key for agent with access control."""
|
||||
@@ -381,6 +394,7 @@ class StreamProcessor:
|
||||
raise
|
||||
|
||||
def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
|
||||
"""Resolve agent metadata + the unioned source set for the given key."""
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
if not agent:
|
||||
@@ -391,36 +405,66 @@ class StreamProcessor:
|
||||
data: Dict[str, Any] = dict(agent)
|
||||
data["user"] = agent.get("user_id")
|
||||
|
||||
# Resolve the primary source row (if any) for retriever/chunks.
|
||||
source_id = agent.get("source_id")
|
||||
if source_id:
|
||||
source_doc = sources_repo.get(str(source_id), agent.get("user_id"))
|
||||
# Active sources = primary ∪ extras, primary first, deduplicated.
|
||||
# ``_configure_source`` ignores an empty ``data["sources"]``,
|
||||
# so the primary must appear in the union too — not only in
|
||||
# the legacy ``data["source"]`` slot.
|
||||
sources_list: list = []
|
||||
seen: set = set()
|
||||
owner = agent.get("user_id")
|
||||
primary_id = agent.get("source_id")
|
||||
# ``sources`` row may have NULL ``retriever``/``chunks`` —
|
||||
# fall back to the agent's value (``dict.get`` returns None
|
||||
# even when the key exists with value None).
|
||||
if primary_id:
|
||||
source_doc = sources_repo.get(str(primary_id), owner)
|
||||
if source_doc:
|
||||
data["source"] = str(source_doc["id"])
|
||||
data["retriever"] = source_doc.get(
|
||||
"retriever", data.get("retriever")
|
||||
sid = str(source_doc["id"])
|
||||
data["source"] = sid
|
||||
src_retriever = source_doc.get("retriever")
|
||||
if src_retriever:
|
||||
data["retriever"] = src_retriever
|
||||
src_chunks = source_doc.get("chunks")
|
||||
if src_chunks is not None:
|
||||
data["chunks"] = src_chunks
|
||||
sources_list.append(
|
||||
{
|
||||
"id": sid,
|
||||
"retriever": src_retriever or "classic",
|
||||
"chunks": (
|
||||
src_chunks if src_chunks is not None
|
||||
else data.get("chunks", "2")
|
||||
),
|
||||
}
|
||||
)
|
||||
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||
seen.add(sid)
|
||||
else:
|
||||
data["source"] = None
|
||||
else:
|
||||
data["source"] = None
|
||||
|
||||
sources_list = []
|
||||
extra = agent.get("extra_source_ids") or []
|
||||
if extra:
|
||||
for sid in extra:
|
||||
source_doc = sources_repo.get(str(sid), agent.get("user_id"))
|
||||
if source_doc:
|
||||
sources_list.append(
|
||||
{
|
||||
"id": str(source_doc["id"]),
|
||||
"retriever": source_doc.get("retriever", "classic"),
|
||||
"chunks": source_doc.get(
|
||||
"chunks", data.get("chunks", "2")
|
||||
),
|
||||
}
|
||||
)
|
||||
for sid_raw in agent.get("extra_source_ids") or []:
|
||||
if not sid_raw:
|
||||
continue
|
||||
source_doc = sources_repo.get(str(sid_raw), owner)
|
||||
if not source_doc:
|
||||
continue
|
||||
sid = str(source_doc["id"])
|
||||
if sid in seen:
|
||||
continue
|
||||
src_retriever = source_doc.get("retriever")
|
||||
src_chunks = source_doc.get("chunks")
|
||||
sources_list.append(
|
||||
{
|
||||
"id": sid,
|
||||
"retriever": src_retriever or "classic",
|
||||
"chunks": (
|
||||
src_chunks if src_chunks is not None
|
||||
else data.get("chunks", "2")
|
||||
),
|
||||
}
|
||||
)
|
||||
seen.add(sid)
|
||||
data["sources"] = sources_list
|
||||
data["default_model_id"] = data.get("default_model_id", "")
|
||||
return data
|
||||
@@ -585,7 +629,7 @@ class StreamProcessor:
|
||||
)
|
||||
|
||||
def _configure_retriever(self):
|
||||
"""Assemble retriever config with precedence: request > agent > default."""
|
||||
"""Assemble retriever config; agent's values are authoritative when bound."""
|
||||
# BYOM scope: owner for shared-agent BYOM, caller for own BYOM,
|
||||
# None for built-ins. Without ``user_id`` here, the doc budget
|
||||
# falls back to settings.DEFAULT_LLM_TOKEN_LIMIT and overfills
|
||||
@@ -594,12 +638,11 @@ class StreamProcessor:
|
||||
model_id=self.model_id, user_id=self.model_user_id
|
||||
)
|
||||
|
||||
# Start with defaults
|
||||
retriever_name = "classic"
|
||||
chunks = 2
|
||||
|
||||
# Layer agent-level config (if present)
|
||||
if self._agent_data:
|
||||
if self._agent_data is not None:
|
||||
# Agent-bound: agent wins, body's retriever/chunks are dropped.
|
||||
if self._agent_data.get("retriever"):
|
||||
retriever_name = self._agent_data["retriever"]
|
||||
if self._agent_data.get("chunks") is not None:
|
||||
@@ -610,18 +653,17 @@ class StreamProcessor:
|
||||
f"Invalid agent chunks value: {self._agent_data['chunks']}, "
|
||||
"using default value 2"
|
||||
)
|
||||
|
||||
# Explicit request values win over agent config
|
||||
if "retriever" in self.data:
|
||||
retriever_name = self.data["retriever"]
|
||||
if "chunks" in self.data:
|
||||
try:
|
||||
chunks = int(self.data["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid request chunks value: {self.data['chunks']}, "
|
||||
"using default value 2"
|
||||
)
|
||||
else:
|
||||
if "retriever" in self.data:
|
||||
retriever_name = self.data["retriever"]
|
||||
if "chunks" in self.data:
|
||||
try:
|
||||
chunks = int(self.data["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid request chunks value: {self.data['chunks']}, "
|
||||
"using default value 2"
|
||||
)
|
||||
|
||||
self.retriever_config = {
|
||||
"retriever_name": retriever_name,
|
||||
@@ -629,7 +671,7 @@ class StreamProcessor:
|
||||
"doc_token_limit": doc_token_limit,
|
||||
}
|
||||
|
||||
# isNoneDoc without an API key forces no retrieval
|
||||
# isNoneDoc without an API key forces no retrieval (agentless only)
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||
self.retriever_config["chunks"] = 0
|
||||
@@ -704,17 +746,26 @@ class StreamProcessor:
|
||||
|
||||
try:
|
||||
user_id = self.initial_user_id or "local"
|
||||
agentless = self.agent_id is None
|
||||
with db_readonly() as conn:
|
||||
user_tools = UserToolsRepository(conn).list_active_for_user(user_id)
|
||||
user_doc = (
|
||||
UsersRepository(conn).get(user_id) if agentless else None
|
||||
)
|
||||
|
||||
if not user_tools:
|
||||
default_docs = (
|
||||
synthesized_default_tools(user_doc) if agentless else []
|
||||
)
|
||||
tool_docs = list(user_tools) + default_docs
|
||||
if not tool_docs:
|
||||
return None
|
||||
|
||||
tools_data = {}
|
||||
|
||||
for tool_doc in user_tools:
|
||||
for tool_doc in tool_docs:
|
||||
tool_name = tool_doc.get("name")
|
||||
tool_id = str(tool_doc.get("_id"))
|
||||
tool_id = str(tool_doc.get("_id") or tool_doc.get("id"))
|
||||
is_default = bool(tool_doc.get("default"))
|
||||
|
||||
if filtering_enabled:
|
||||
required_actions_by_name = required_tool_actions.get(
|
||||
@@ -727,11 +778,18 @@ class StreamProcessor:
|
||||
if not required_actions:
|
||||
continue
|
||||
else:
|
||||
# No template names a default tool, so running its
|
||||
# actions blind would only inject noise.
|
||||
if is_default:
|
||||
continue
|
||||
required_actions = None
|
||||
|
||||
tool_data = self._fetch_tool_data(tool_doc, required_actions)
|
||||
if tool_data:
|
||||
tools_data[tool_name] = tool_data
|
||||
# Defaults reachable by synthetic id only — the name
|
||||
# key stays bound to an explicit row of the same name.
|
||||
if not is_default:
|
||||
tools_data[tool_name] = tool_data
|
||||
tools_data[tool_id] = tool_data
|
||||
|
||||
return tools_data if tools_data else None
|
||||
@@ -928,6 +986,20 @@ class StreamProcessor:
|
||||
if not state:
|
||||
raise ValueError("No pending tool state found for this conversation")
|
||||
|
||||
# Claim the resume up-front. ``mark_resuming`` only flips ``pending``
|
||||
# → ``resuming``; if it returns False, another resume already
|
||||
# claimed this row (status='resuming') — bail before any further
|
||||
# LLM/tool work to avoid double-execution. The cleanup janitor
|
||||
# reverts a stale ``resuming`` claim back to ``pending`` after the
|
||||
# 10-minute grace window so the user can retry.
|
||||
if not cont_service.mark_resuming(
|
||||
conversation_id, self.initial_user_id,
|
||||
):
|
||||
raise ValueError(
|
||||
"Resume already in progress for this conversation; "
|
||||
"retry after the grace window if it stalls."
|
||||
)
|
||||
|
||||
messages = state["messages"]
|
||||
pending_tool_calls = state["pending_tool_calls"]
|
||||
tools_dict = state["tools_dict"]
|
||||
@@ -964,6 +1036,7 @@ class StreamProcessor:
|
||||
user_api_key=user_api_key,
|
||||
user=self.initial_user_id,
|
||||
decoded_token=self.decoded_token,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
tool_executor.conversation_id = conversation_id
|
||||
# Restore client tools so they stay available for subsequent LLM calls
|
||||
@@ -1022,9 +1095,10 @@ class StreamProcessor:
|
||||
self.agent_id = agent_id
|
||||
self.agent_config["user_api_key"] = user_api_key
|
||||
self.conversation_id = conversation_id
|
||||
|
||||
# Delete state so it can't be replayed
|
||||
cont_service.delete_state(conversation_id, self.initial_user_id)
|
||||
# Reused on resume so the same WAL row gets finalised and
|
||||
# request_id stays consistent across token_usage rows.
|
||||
self.reserved_message_id = agent_config.get("reserved_message_id")
|
||||
self.request_id = agent_config.get("request_id")
|
||||
|
||||
return agent, messages, tools_dict, pending_tool_calls, tool_actions
|
||||
|
||||
@@ -1111,6 +1185,7 @@ class StreamProcessor:
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
user=user,
|
||||
decoded_token=self.decoded_token,
|
||||
agent_id=self.agent_id,
|
||||
)
|
||||
tool_executor.conversation_id = self.conversation_id
|
||||
# Pass client-side tools so they get merged in get_tools()
|
||||
@@ -1118,7 +1193,6 @@ class StreamProcessor:
|
||||
if client_tools:
|
||||
tool_executor.client_tools = client_tools
|
||||
|
||||
# Base agent kwargs
|
||||
agent_kwargs = {
|
||||
"endpoint": "stream",
|
||||
"llm_name": provider or settings.LLM_PROVIDER,
|
||||
|
||||
0
application/api/events/__init__.py
Normal file
0
application/api/events/__init__.py
Normal file
504
application/api/events/routes.py
Normal file
504
application/api/events/routes.py
Normal file
@@ -0,0 +1,504 @@
|
||||
"""GET /api/events — user-scoped Server-Sent Events endpoint.
|
||||
|
||||
Subscribe-then-snapshot pattern: subscribe to ``user:{user_id}``
|
||||
pub/sub, snapshot the Redis Streams backlog past ``Last-Event-ID``
|
||||
inside the SUBSCRIBE-ack callback, flush snapshot, then tail live
|
||||
events (dedup'd by stream id). See ``docs/runbooks/sse-notifications.md``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.events.keys import (
|
||||
connection_counter_key,
|
||||
replay_budget_key,
|
||||
stream_id_compare,
|
||||
stream_key,
|
||||
topic_name,
|
||||
)
|
||||
from application.streaming.broadcast_channel import Topic
|
||||
|
||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
# Blueprint on which the ``/api/events`` route in this module is registered.
events = Blueprint("event_stream", __name__)
|
||||
|
||||
# Passed as ``poll_timeout`` to ``Topic.subscribe`` by ``stream_events``.
SUBSCRIBE_POLL_INTERVAL_SECONDS = 1.0
|
||||
|
||||
# WHATWG SSE treats CRLF, CR, and LF equivalently as line terminators.
# ``_format_sse`` splits payloads on this pattern, one ``data:`` field per piece.
|
||||
_SSE_LINE_SPLIT = re.compile(r"\r\n|\r|\n")
|
||||
|
||||
# Redis Streams ids are ``ms`` or ``ms-seq`` where both halves are decimal.
# Anything else is a corrupted client cookie / IndexedDB residue and must
# not be passed to XRANGE — Redis would reject it and our truncation gate
# would silently fail.
#
# The pattern is anchored with ``\A`` / ``\Z`` and spelled with ``[0-9]``
# on purpose: ``$`` also matches just before a trailing newline (so
# ``"5-0\n"`` would pass ``.match()`` and then break the SSE ``id:`` field
# onto a second line), and ``\d`` in a ``str`` pattern accepts non-ASCII
# decimal digits that Redis does not treat as a valid id.
_STREAM_ID_RE = re.compile(r"\A[0-9]+(-[0-9]+)?\Z")
|
||||
|
||||
# Latch for the ``user_id='local'`` warning logged by ``stream_events``;
# that function flips it to True after the first warning.
# Only emitted at most once per process so a misconfigured deployment
|
||||
# doesn't drown the logs.
|
||||
_local_user_warned = False
|
||||
|
||||
|
||||
def _format_sse(data: str, *, event_id: Optional[str] = None) -> str:
    """Serialise ``data`` as a single SSE message ending in a blank line.

    ``data`` is broken on every line-terminator form (``\\r\\n``, ``\\r``,
    ``\\n``) and each piece is written as its own ``data:`` field, so a CR
    or LF embedded in upstream content cannot end the message early.

    Args:
        data: Payload text to send.
        event_id: When truthy, emitted first as an ``id:`` field. ``None``
            or an empty string omits the field.

    Returns:
        The wire-format message, fields joined by ``\\n`` and terminated
        by ``\\n\\n``.
    """
    id_field = [f"id: {event_id}"] if event_id else []
    data_fields = [f"data: {piece}" for piece in _SSE_LINE_SPLIT.split(data)]
    return "\n".join(id_field + data_fields) + "\n\n"
|
||||
|
||||
|
||||
def _decode(value) -> Optional[str]:
    """Coerce a Redis reply value to ``str``.

    ``bytes`` / ``bytearray`` are decoded as UTF-8, yielding ``None`` when
    decoding fails. ``None`` stays ``None``. Any other value is passed
    through ``str()``.
    """
    if isinstance(value, (bytes, bytearray)):
        try:
            return value.decode("utf-8")
        except Exception:
            return None
    return None if value is None else str(value)
|
||||
|
||||
|
||||
def _oldest_retained_id(redis_client, user_id: str) -> Optional[str]:
    """Look up the id of the first entry still held in the user's stream.

    ``stream_events`` compares the result with the client's
    ``Last-Event-ID`` to detect a cursor that is older than everything
    the stream still retains.

    Returns:
        The entry id as ``str``, or ``None`` when XINFO STREAM raises, the
        reply is not a dict, the reply has no ``first-entry``, or the id
        cannot be read from that entry.
    """
    try:
        info = redis_client.xinfo_stream(stream_key(user_id))
    except Exception:
        return None
    if not isinstance(info, dict):
        return None

    # Try the str key first and the bytes key second so the lookup works
    # whichever way the client decodes replies.
    head = info.get("first-entry") or info.get(b"first-entry")
    if not head:
        return None

    # The entry is indexed as a sequence with the id in slot 0; anything
    # that cannot be indexed that way is treated as "no id".
    try:
        oldest_id = _decode(head[0])
    except Exception:
        return None
    return oldest_id
|
||||
|
||||
|
||||
def _allow_replay(
    redis_client, user_id: str, last_event_id: Optional[str]
) -> bool:
    """Charge one snapshot replay to the user's budget and say if it fits.

    The budget is a Redis counter under ``replay_budget_key(user_id)``.
    Each charged call INCRs it and then sets its TTL to the configured
    window length.

    Returns ``True`` (allow) without charging when:
      * ``EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW`` is ``<= 0``;
      * ``redis_client`` is ``None``;
      * the client sent no cursor and XLEN reports an empty stream, so
        there is nothing to replay.

    Also returns ``True`` when the INCR/EXPIRE pair raises (fail open).
    Otherwise returns whether the post-increment count is within budget.

    NOTE(review): EXPIRE runs on every charged call, so the key only
    lapses after a full window with no charged connects — confirm this
    matches the intended per-window semantics.
    """
    limit = int(settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW)
    if limit <= 0 or redis_client is None:
        return True

    # A cursor-less connect against an empty stream replays nothing, so it
    # is not charged. A connect that carries a cursor is always charged,
    # without probing whether anything lies past that cursor.
    if last_event_id is None:
        try:
            stream_is_empty = int(redis_client.xlen(stream_key(user_id))) == 0
        except Exception:
            # A failed probe must not become a free pass: continue to the
            # counted path below.
            stream_is_empty = False
            logger.debug(
                "XLEN probe failed for replay budget check user=%s; "
                "proceeding to INCR",
                user_id,
            )
        if stream_is_empty:
            return True

    ttl_seconds = max(1, int(settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS))
    counter = replay_budget_key(user_id)
    try:
        spent = int(redis_client.incr(counter))
        # The TTL is set unconditionally rather than only on the first
        # increment, so a counter whose first EXPIRE failed cannot be left
        # without an expiry.
        redis_client.expire(counter, ttl_seconds)
    except Exception:
        logger.debug(
            "replay budget probe failed for user=%s; failing open",
            user_id,
        )
        return True
    return spent <= limit
|
||||
|
||||
|
||||
def _normalize_last_event_id(raw: Optional[str]) -> Optional[str]:
    """Validate a client-supplied ``Last-Event-ID`` value.

    Args:
        raw: The header / query-param value, or ``None`` when absent.

    Returns:
        The whitespace-stripped value when it matches ``_STREAM_ID_RE``;
        ``None`` when ``raw`` is ``None``, blank, or does not match.
        ``stream_events`` passes the result on as the replay cursor, so a
        rejected value never reaches XRANGE.
    """
    if raw is None:
        return None
    candidate = raw.strip()
    if candidate and _STREAM_ID_RE.match(candidate):
        return candidate
    return None
|
||||
|
||||
|
||||
def _replay_backlog(
    redis_client, user_id: str, last_event_id: Optional[str], max_count: int
) -> Iterator[tuple[str, str]]:
    """Yield ``(entry_id, sse_line)`` for stream entries after ``last_event_id``.

    Issues a single XRANGE limited to ``max_count`` rows. If XRANGE raises,
    a warning is logged and nothing is yielded.

    Per entry:
      * entries whose id or ``event`` field is missing or undecodable are
        skipped;
      * a payload that parses as a JSON object gets the stream id written
        into its ``"id"`` key and is re-serialised;
      * a payload that fails JSON handling is forwarded unchanged.

    Every yielded SSE message carries the stream id as its ``id:`` field.
    """
    # A leading '(' makes the XRANGE lower bound exclusive, so the entry the
    # client already holds is not sent again. '-' starts from the beginning.
    lower_bound = "-" if not last_event_id else f"({last_event_id}"
    try:
        entries = redis_client.xrange(
            stream_key(user_id), min=lower_bound, max="+", count=max_count
        )
    except Exception as exc:
        logger.warning(
            "xrange replay failed for user=%s last_id=%s err=%s",
            user_id,
            last_event_id or "-",
            exc,
        )
        return

    for raw_id, fields in entries:
        stream_id = _decode(raw_id)
        if not stream_id:
            continue

        # Look the field up under the bytes key first, then the str key,
        # so either reply-decoding mode of the client is handled.
        raw_event = None
        if isinstance(fields, dict):
            for field_name in (b"event", "event"):
                raw_event = fields.get(field_name)
                if raw_event is not None:
                    break

        payload = _decode(raw_event)
        if not payload:
            continue

        try:
            envelope = json.loads(payload)
            if isinstance(envelope, dict):
                envelope["id"] = stream_id
                payload = json.dumps(envelope)
        except Exception:
            logger.debug(
                "Replay envelope parse failed for entry %s; passing through raw",
                stream_id,
            )

        yield stream_id, _format_sse(payload, event_id=stream_id)
|
||||
|
||||
|
||||
def _truncation_notice_line(oldest_id: str) -> str:
    """Build the ``backlog.truncated`` SSE message.

    The message is a JSON object of type ``backlog.truncated`` whose
    payload carries ``oldest_retained_id``. It is sent without an ``id:``
    field. The frontend can react to it with a full-state refetch.
    """
    notice = {
        "type": "backlog.truncated",
        "payload": {"oldest_retained_id": oldest_id},
    }
    return _format_sse(json.dumps(notice))
|
||||
|
||||
|
||||
@events.route("/api/events", methods=["GET"])
def stream_events() -> Response:
    """Serve the authenticated user's events as a ``text/event-stream``.

    Request handling, in order:
      1. Resolve ``user_id`` from ``request.decoded_token["sub"]``; reply
         401 when it is missing.
      2. Read and validate the cursor from the ``Last-Event-ID`` header or
         the ``last_event_id`` query parameter.
      3. When push is enabled, Redis is available and the per-user cap is
         positive, take a slot on the connection counter; reply 429 when
         the cap is exceeded.
      4. Check the replay budget; reply 429 (releasing the slot) when it
         is exhausted.
      5. Return a streaming response whose generator subscribes to the
         user's topic, snapshots the backlog inside the subscribe
         callback, flushes that snapshot, then tails live events while
         dropping any already covered by the snapshot.

    The connection slot taken in step 3 is released in the generator's
    ``finally`` block.
    """
    global _local_user_warned

    def _json_error(message: str, status: int) -> Response:
        # Shared shape for the 401 and both 429 replies.
        return make_response(
            jsonify({"success": False, "message": message}), status
        )

    decoded = getattr(request, "decoded_token", None)
    user_id = decoded.get("sub") if isinstance(decoded, dict) else None
    if not user_id:
        return _json_error("Authentication required", 401)

    # Without AUTH_TYPE every request resolves to user_id="local" and all
    # clients share one stream. Say so once per process so a dev box that
    # is accidentally multi-user does not cross-stream silently.
    if user_id == "local" and not _local_user_warned:
        logger.warning(
            "SSE serving user_id='local' (AUTH_TYPE not set). "
            "All clients on this deployment will share one event stream."
        )
        _local_user_warned = True

    raw_last_event_id = request.headers.get("Last-Event-ID") or request.args.get(
        "last_event_id"
    )
    last_event_id = _normalize_last_event_id(raw_last_event_id)
    # True when the client sent a cursor that failed validation.
    last_event_id_invalid = raw_last_event_id is not None and last_event_id is None

    keepalive_seconds = float(settings.SSE_KEEPALIVE_SECONDS)
    push_enabled = settings.ENABLE_SSE_PUSH
    cap = int(settings.SSE_MAX_CONCURRENT_PER_USER)

    redis_client = get_redis_instance()
    counter_key = connection_counter_key(user_id)
    slot_taken = False

    def _release_slot(failure_log: str) -> None:
        # Best-effort DECR of the per-user connection counter.
        try:
            redis_client.decr(counter_key)
        except Exception:
            logger.debug(failure_log, user_id)

    if push_enabled and redis_client is not None and cap > 0:
        try:
            current = int(redis_client.incr(counter_key))
            slot_taken = True
        except Exception:
            current = 0
            logger.debug(
                "SSE connection counter INCR failed for user=%s", user_id
            )
        if slot_taken:
            # 1h safety TTL so counts orphaned by a hard crash self-heal.
            # It has its own try block so that an EXPIRE failure cannot
            # reset ``current`` and bypass the cap.
            try:
                redis_client.expire(counter_key, 3600)
            except Exception:
                logger.debug(
                    "SSE connection counter EXPIRE failed for user=%s", user_id
                )
        if current > cap:
            _release_slot("SSE connection counter DECR failed for user=%s")
            return _json_error("Too many concurrent SSE connections", 429)

    # The replay budget is enforced before the stream opens so a denial is
    # an HTTP 429, not a silently skipped snapshot. Skipping inside the
    # generator would still tail live events with ``id:`` fields; the
    # client would advance its cursor past events it never received and
    # could not reach them on the next reconnect. A 429 leaves the cursor
    # where it was until the budget frees up.
    if push_enabled and redis_client is not None and not _allow_replay(
        redis_client, user_id, last_event_id
    ):
        if slot_taken:
            _release_slot("SSE connection counter DECR failed for user=%s")
        return _json_error("Replay budget exhausted", 429)

    @stream_with_context
    def generate() -> Iterator[str]:
        started_at = time.monotonic()
        replayed_count = 0
        try:
            # First frame primes intermediaries (Cloudflare, nginx) so they
            # don't sit on a buffer waiting for body bytes.
            yield ": connected\n\n"

            if not push_enabled:
                yield ": push_disabled\n\n"
                return

            replay_lines: list[str] = []
            max_replayed_id: Optional[str] = None
            replay_done = False

            def _drain_snapshot() -> Iterator[str]:
                # Emit the buffered snapshot once the subscribe callback
                # has finished, then empty the buffer so a later call is a
                # no-op.
                nonlocal replayed_count
                if not (replay_done and replay_lines):
                    return
                for line in replay_lines:
                    yield line
                    replayed_count += 1
                replay_lines.clear()

            def _trusted_event_id(event_json: str) -> Optional[str]:
                # Only ids shaped like a Redis Streams id (``ms`` or
                # ``ms-seq``) are trusted. A bogus id that compares greater
                # than real ones would otherwise make the dedupe check drop
                # every legitimate later event.
                try:
                    envelope = json.loads(event_json)
                    if isinstance(envelope, dict):
                        candidate = envelope.get("id")
                        if isinstance(candidate, str) and _STREAM_ID_RE.match(
                            candidate
                        ):
                            return candidate
                except Exception:
                    pass
                return None

            # A malformed Last-Event-ID is reported right away, before
            # subscribing. Buffering the notice in ``replay_lines`` would
            # lose it whenever the subscribe callback never runs (for
            # example ``Topic.subscribe`` returning immediately), because
            # the snapshot is only flushed once the callback has set
            # ``replay_done``.
            if last_event_id_invalid:
                yield _truncation_notice_line("")
                replayed_count += 1

            def _on_subscribe_callback() -> None:
                # Runs synchronously inside Topic.subscribe after the
                # SUBSCRIBE is acked. Taking the XRANGE snapshot here means
                # an event published between SUBSCRIBE-send and
                # SUBSCRIBE-ack is both captured by XRANGE and buffered as
                # a PUBLISH until read, which closes the replay/subscribe
                # race.
                #
                # ``backlog.truncated`` is queued ONLY when the client's
                # cursor is older than the oldest retained entry, i.e. the
                # journal past the cursor is really gone and the frontend
                # should refetch state. Hitting the per-request replay cap
                # sends no notice: the cursor advances through the
                # per-entry ``id:`` fields and the next reconnect resumes
                # from there. Budget exhaustion never gets here because
                # the route answers 429 first. Sending the notice in those
                # cases would make the client clear its cursor and receive
                # the same oldest entries on every reconnect.
                nonlocal max_replayed_id, replay_done
                try:
                    if redis_client is None:
                        return
                    oldest = _oldest_retained_id(redis_client, user_id)
                    cursor_fell_off = (
                        last_event_id
                        and oldest
                        and stream_id_compare(last_event_id, oldest) < 0
                    )
                    if cursor_fell_off:
                        replay_lines.append(_truncation_notice_line(oldest))
                    replay_cap = int(settings.EVENTS_REPLAY_MAX_PER_REQUEST)
                    for entry_id, sse_line in _replay_backlog(
                        redis_client, user_id, last_event_id, replay_cap
                    ):
                        replay_lines.append(sse_line)
                        max_replayed_id = entry_id
                finally:
                    # Always flip the flag, even if the replay failed part
                    # way, so whatever was buffered still gets flushed.
                    replay_done = True

            topic = Topic(topic_name(user_id))
            last_keepalive = time.monotonic()
            for payload in topic.subscribe(
                on_subscribe=_on_subscribe_callback,
                poll_timeout=SUBSCRIBE_POLL_INTERVAL_SECONDS,
            ):
                # Flushes the snapshot on the first iteration after the
                # callback ran; a no-op on every later iteration.
                yield from _drain_snapshot()

                now = time.monotonic()
                if payload is None:
                    # Poll timeout with no message: keep the connection
                    # warm if it has been quiet for long enough.
                    if now - last_keepalive >= keepalive_seconds:
                        yield ": keepalive\n\n"
                        last_keepalive = now
                    continue

                event_str = _decode(payload) or ""
                event_id = _trusted_event_id(event_str)

                # Dedupe: drop live events already covered by the snapshot.
                if (
                    event_id is not None
                    and max_replayed_id is not None
                    and stream_id_compare(event_id, max_replayed_id) <= 0
                ):
                    continue

                yield _format_sse(event_str, event_id=event_id)
                last_keepalive = now

            # ``Topic.subscribe`` can end before its first yield while the
            # callback has already filled the snapshot. Flush it here so
            # the backlog is not dropped. This is a no-op when the in-loop
            # flush already ran or the callback never fired.
            yield from _drain_snapshot()
        except GeneratorExit:
            return
        except Exception:
            logger.exception(
                "SSE event-stream generator crashed for user=%s", user_id
            )
        finally:
            duration_s = time.monotonic() - started_at
            logger.info(
                "event.disconnect user=%s duration_s=%.1f replayed=%d",
                user_id,
                duration_s,
                replayed_count,
            )
            if slot_taken and redis_client is not None:
                _release_slot(
                    "SSE connection counter DECR failed for user=%s on disconnect"
                )

    response = Response(generate(), mimetype="text/event-stream")
    response.headers["Cache-Control"] = "no-store"
    response.headers["X-Accel-Buffering"] = "no"
    response.headers["Connection"] = "keep-alive"
    logger.info(
        "event.connect user=%s last_event_id=%s%s",
        user_id,
        last_event_id or "-",
        " (rejected_invalid)" if last_event_id_invalid else "",
    )
    return response
|
||||
@@ -46,7 +46,9 @@ AGENT_TYPE_SCHEMAS = {
|
||||
"prompt_id",
|
||||
],
|
||||
"required_draft": ["name"],
|
||||
"validate_published": ["name", "description", "prompt_id"],
|
||||
# ``prompt_id`` intentionally omitted — the "default" sentinel
|
||||
# is acceptable and maps to NULL downstream.
|
||||
"validate_published": ["name", "description"],
|
||||
"validate_draft": [],
|
||||
"require_source": True,
|
||||
"fields": [
|
||||
@@ -1009,12 +1011,16 @@ class UpdateAgent(Resource):
|
||||
400,
|
||||
)
|
||||
else:
|
||||
# ``prompt_id`` is intentionally omitted: the
|
||||
# frontend's "default" choice maps to NULL here
|
||||
# (see the prompt_id branch above), and NULL
|
||||
# means "use the built-in default prompt" which
|
||||
# is a valid published-agent state.
|
||||
missing_published_fields = []
|
||||
for req_field, field_label in (
|
||||
("name", "Agent name"),
|
||||
("description", "Agent description"),
|
||||
("chunks", "Chunks count"),
|
||||
("prompt_id", "Prompt"),
|
||||
("agent_type", "Agent type"),
|
||||
):
|
||||
final_value = update_fields.get(
|
||||
@@ -1028,8 +1034,23 @@ class UpdateAgent(Resource):
|
||||
extra_final = update_fields.get(
|
||||
"extra_source_ids", existing_agent.get("extra_source_ids") or [],
|
||||
)
|
||||
if not source_final and not extra_final:
|
||||
missing_published_fields.append("Source")
|
||||
# ``retriever`` carries the runtime identity for
|
||||
# agents that publish against the synthetic
|
||||
# "Default" source (frontend's auto-selected
|
||||
# ``{name: "Default", retriever: "classic"}``
|
||||
# entry has no ``id``, so ``source_id`` ends up
|
||||
# NULL even though the user picked something).
|
||||
# Without this fallback the most common new-agent
|
||||
# publish flow gets a 400.
|
||||
retriever_final = update_fields.get(
|
||||
"retriever", existing_agent.get("retriever"),
|
||||
)
|
||||
if (
|
||||
not source_final
|
||||
and not extra_final
|
||||
and not retriever_final
|
||||
):
|
||||
missing_published_fields.append("Source or retriever")
|
||||
if missing_published_fields:
|
||||
return make_response(
|
||||
jsonify(
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
"""Agent management webhook handlers."""
|
||||
|
||||
import secrets
|
||||
import uuid
|
||||
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import require_agent
|
||||
from application.api.user.tasks import process_agent_webhook
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.idempotency import IdempotencyRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
@@ -18,6 +22,37 @@ agents_webhooks_ns = Namespace(
|
||||
)
|
||||
|
||||
|
||||
# Longest ``Idempotency-Key`` header value accepted, in characters.
_IDEMPOTENCY_KEY_MAX_LEN = 256


def _read_idempotency_key():
    """Read the ``Idempotency-Key`` header of the current request.

    Returns:
        A ``(key, error_response)`` pair:
          * ``(None, None)`` when the header is absent or empty;
          * ``(None, <400 response>)`` when the value is longer than
            ``_IDEMPOTENCY_KEY_MAX_LEN`` characters;
          * ``(key, None)`` otherwise.
    """
    header_value = request.headers.get("Idempotency-Key")
    if not header_value:
        return None, None
    if len(header_value) <= _IDEMPOTENCY_KEY_MAX_LEN:
        return header_value, None

    message = (
        f"Idempotency-Key exceeds maximum length of "
        f"{_IDEMPOTENCY_KEY_MAX_LEN} characters"
    )
    rejection = make_response(
        jsonify({"success": False, "message": message}),
        400,
    )
    return None, rejection
|
||||
|
||||
|
||||
def _scoped_idempotency_key(idempotency_key, scope):
    """Namespace ``idempotency_key`` under ``scope`` as ``{scope}:{key}``.

    Scoping keeps different agents from colliding on the same
    client-chosen key. Returns ``None`` when either argument is falsy.
    """
    if idempotency_key and scope:
        return f"{scope}:{idempotency_key}"
    return None
|
||||
|
||||
|
||||
@agents_webhooks_ns.route("/agent_webhook")
|
||||
class AgentWebhook(Resource):
|
||||
@api.doc(
|
||||
@@ -68,7 +103,7 @@ class AgentWebhook(Resource):
|
||||
class AgentWebhookListener(Resource):
|
||||
method_decorators = [require_agent]
|
||||
|
||||
def _enqueue_webhook_task(self, agent_id_str, payload, source_method):
|
||||
def _enqueue_webhook_task(self, agent_id_str, payload, source_method, agent=None):
|
||||
if not payload:
|
||||
current_app.logger.warning(
|
||||
f"Webhook ({source_method}) received for agent {agent_id_str} with empty payload."
|
||||
@@ -77,26 +112,94 @@ class AgentWebhookListener(Resource):
|
||||
f"Incoming {source_method} webhook for agent {agent_id_str}. Enqueuing task with payload: {payload}"
|
||||
)
|
||||
|
||||
try:
|
||||
task = process_agent_webhook.delay(
|
||||
agent_id=agent_id_str,
|
||||
payload=payload,
|
||||
idempotency_key, key_error = _read_idempotency_key()
|
||||
if key_error is not None:
|
||||
return key_error
|
||||
# Resolve to PG UUID first so dedup writes don't crash on legacy ids.
|
||||
agent_uuid = None
|
||||
if agent is not None:
|
||||
candidate = str(agent.get("id") or "")
|
||||
if looks_like_uuid(candidate):
|
||||
agent_uuid = candidate
|
||||
if idempotency_key and agent_uuid is None:
|
||||
current_app.logger.warning(
|
||||
"Skipping webhook idempotency dedup: agent %s has non-UUID id",
|
||||
agent_id_str,
|
||||
)
|
||||
idempotency_key = None
|
||||
# Agent-scoped (webhooks have no user_id).
|
||||
scoped_key = _scoped_idempotency_key(idempotency_key, agent_uuid)
|
||||
# Claim before enqueue; the loser returns the winner's task_id.
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id = str(uuid.uuid4())
|
||||
with db_session() as conn:
|
||||
claimed = IdempotencyRepository(conn).record_webhook(
|
||||
key=scoped_key,
|
||||
agent_id=agent_uuid,
|
||||
task_id=predetermined_task_id,
|
||||
response_json={
|
||||
"success": True, "task_id": predetermined_task_id,
|
||||
},
|
||||
)
|
||||
if claimed is None:
|
||||
with db_readonly() as conn:
|
||||
cached = IdempotencyRepository(conn).get_webhook(scoped_key)
|
||||
if cached is not None:
|
||||
return make_response(jsonify(cached["response_json"]), 200)
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": "deduplicated"}), 200
|
||||
)
|
||||
|
||||
try:
|
||||
apply_kwargs = dict(
|
||||
kwargs={
|
||||
"agent_id": agent_id_str,
|
||||
"payload": payload,
|
||||
# Scoped so the worker dedup row matches the HTTP claim.
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
},
|
||||
)
|
||||
if predetermined_task_id is not None:
|
||||
apply_kwargs["task_id"] = predetermined_task_id
|
||||
task = process_agent_webhook.apply_async(**apply_kwargs)
|
||||
current_app.logger.info(
|
||||
f"Task {task.id} enqueued for agent {agent_id_str} ({source_method})."
|
||||
)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
response_payload = {"success": True, "task_id": task.id}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error enqueuing webhook task ({source_method}) for agent {agent_id_str}: {err}",
|
||||
exc_info=True,
|
||||
)
|
||||
if scoped_key:
|
||||
# Roll back the claim so a retry can succeed.
|
||||
try:
|
||||
with db_session() as conn:
|
||||
conn.execute(
|
||||
sql_text(
|
||||
"DELETE FROM webhook_dedup "
|
||||
"WHERE idempotency_key = :k"
|
||||
),
|
||||
{"k": scoped_key},
|
||||
)
|
||||
except Exception:
|
||||
current_app.logger.exception(
|
||||
"Failed to release webhook_dedup claim for key=%s",
|
||||
scoped_key,
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error processing webhook"}), 500
|
||||
)
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (POST). Expects JSON payload, which is used to trigger processing.",
|
||||
description=(
|
||||
"Webhook listener for agent events (POST). Expects JSON payload, which "
|
||||
"is used to trigger processing. Honors an optional ``Idempotency-Key`` "
|
||||
"header: a repeat request with the same key within 24h returns the "
|
||||
"original cached response and does not re-enqueue the task."
|
||||
),
|
||||
)
|
||||
def post(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.get_json()
|
||||
@@ -110,11 +213,20 @@ class AgentWebhookListener(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="POST")
|
||||
return self._enqueue_webhook_task(
|
||||
agent_id_str, payload, source_method="POST", agent=agent,
|
||||
)
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (GET). Uses URL query parameters as payload to trigger processing.",
|
||||
description=(
|
||||
"Webhook listener for agent events (GET). Uses URL query parameters as "
|
||||
"payload to trigger processing. Honors an optional ``Idempotency-Key`` "
|
||||
"header: a repeat request with the same key within 24h returns the "
|
||||
"original cached response and does not re-enqueue the task."
|
||||
),
|
||||
)
|
||||
def get(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="GET")
|
||||
return self._enqueue_webhook_task(
|
||||
agent_id_str, payload, source_method="GET", agent=agent,
|
||||
)
|
||||
|
||||
@@ -214,6 +214,10 @@ class StoreAttachment(Resource):
|
||||
{
|
||||
"success": True,
|
||||
"task_id": tasks[0]["task_id"],
|
||||
# Surface the attachment_id so the frontend
|
||||
# can correlate ``attachment.*`` SSE events
|
||||
# to this row and skip the polling fallback.
|
||||
"attachment_id": tasks[0]["attachment_id"],
|
||||
"message": "File uploaded successfully. Processing started.",
|
||||
}
|
||||
),
|
||||
|
||||
@@ -83,13 +83,15 @@ def resolve_tool_details(tool_ids):
|
||||
"""
|
||||
Resolve tool IDs to their display details.
|
||||
|
||||
Accepts either Postgres UUIDs or legacy Mongo ObjectId strings (mixed
|
||||
lists are supported — each id is looked up via ``get_any``, which
|
||||
resolves to whichever column matches). Unknown ids are silently
|
||||
Accepts Postgres UUIDs, legacy Mongo ObjectId strings, or the
|
||||
synthetic ids of default chat tools / agent-selectable builtins
|
||||
(mixed lists are supported). Synthetic ids are resolved in memory;
|
||||
real ids are looked up via ``get_any``. Unknown ids are silently
|
||||
skipped.
|
||||
|
||||
Args:
|
||||
tool_ids: List of tool IDs (UUIDs or legacy Mongo ObjectId strings).
|
||||
tool_ids: List of tool IDs (UUIDs, legacy ObjectId strings, or
|
||||
synthetic default-tool / builtin ids).
|
||||
|
||||
Returns:
|
||||
List of tool details with ``id``, ``name``, and ``display_name``.
|
||||
@@ -97,19 +99,37 @@ def resolve_tool_details(tool_ids):
|
||||
if not tool_ids:
|
||||
return []
|
||||
|
||||
from application.agents.default_tools import (
|
||||
is_synthesized_tool_id,
|
||||
synthesize_tool_by_name,
|
||||
synthesized_tool_name_for_id,
|
||||
)
|
||||
|
||||
uuid_ids: list[str] = []
|
||||
legacy_ids: list[str] = []
|
||||
default_details: list[dict] = []
|
||||
for tid in tool_ids:
|
||||
if not tid:
|
||||
continue
|
||||
tid_str = str(tid)
|
||||
if is_synthesized_tool_id(tid_str):
|
||||
synth = synthesize_tool_by_name(synthesized_tool_name_for_id(tid_str))
|
||||
if synth is not None:
|
||||
default_details.append(
|
||||
{
|
||||
"id": tid_str,
|
||||
"name": synth.get("name", ""),
|
||||
"display_name": synth.get("display_name", ""),
|
||||
}
|
||||
)
|
||||
continue
|
||||
if looks_like_uuid(tid_str):
|
||||
uuid_ids.append(tid_str)
|
||||
else:
|
||||
legacy_ids.append(tid_str)
|
||||
|
||||
if not uuid_ids and not legacy_ids:
|
||||
return []
|
||||
return default_details
|
||||
|
||||
rows: list[dict] = []
|
||||
with db_readonly() as conn:
|
||||
@@ -132,7 +152,7 @@ def resolve_tool_details(tool_ids):
|
||||
)
|
||||
rows.extend(row_to_dict(r) for r in result.fetchall())
|
||||
|
||||
return [
|
||||
return default_details + [
|
||||
{
|
||||
"id": str(tool.get("id") or tool.get("legacy_mongo_id") or ""),
|
||||
"name": tool.get("name", "") or "",
|
||||
|
||||
@@ -4,10 +4,16 @@ import datetime
|
||||
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.api.answer.services.conversation_service import (
|
||||
TERMINATED_RESPONSE_PLACEHOLDER,
|
||||
)
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.repositories.attachments import AttachmentsRepository
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.message_events import MessageEventsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import check_required_fields
|
||||
|
||||
@@ -133,6 +139,7 @@ class GetSingleConversation(Resource):
|
||||
attachments_repo = AttachmentsRepository(conn)
|
||||
queries = []
|
||||
for msg in messages:
|
||||
metadata = msg.get("metadata") or {}
|
||||
query = {
|
||||
"prompt": msg.get("prompt"),
|
||||
"response": msg.get("response"),
|
||||
@@ -141,9 +148,15 @@ class GetSingleConversation(Resource):
|
||||
"tool_calls": msg.get("tool_calls") or [],
|
||||
"timestamp": msg.get("timestamp"),
|
||||
"model_id": msg.get("model_id"),
|
||||
# Lets the client distinguish placeholder rows from
|
||||
# finalised answers and tail-poll in-flight ones.
|
||||
"message_id": str(msg["id"]) if msg.get("id") else None,
|
||||
"status": msg.get("status"),
|
||||
"request_id": msg.get("request_id"),
|
||||
"last_heartbeat_at": metadata.get("last_heartbeat_at"),
|
||||
}
|
||||
if msg.get("metadata"):
|
||||
query["metadata"] = msg["metadata"]
|
||||
if metadata:
|
||||
query["metadata"] = metadata
|
||||
# Feedback on conversation_messages is a JSONB blob with
|
||||
# shape {"text": <str>, "timestamp": <iso>}. The legacy
|
||||
# frontend consumed a flat scalar feedback string, so
|
||||
@@ -301,3 +314,80 @@ class SubmitFeedback(Resource):
|
||||
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/messages/<string:message_id>/tail")
class GetMessageTail(Resource):
    @api.doc(
        description=(
            "Current state of one conversation_messages row, scoped to the "
            "authenticated user. Used to reconnect to an in-flight stream "
            "after a refresh."
        ),
        params={"message_id": "Message UUID"},
    )
    def get(self, message_id):
        """Return the current snapshot of a single message row.

        Responds 401 without a decoded token, 400 for a non-UUID id or any
        error raised while reading, 404 when no row is visible to the
        caller, and 200 with the message snapshot otherwise.
        """
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)
        if not looks_like_uuid(message_id):
            invalid = jsonify({"success": False, "message": "Invalid message id"})
            return make_response(invalid, 400)

        caller = token.get("sub")
        # Visible to the conversation owner or anyone it is shared with,
        # matching ``ConversationsRepository.get``.
        lookup = sql_text(
            "SELECT m.* FROM conversation_messages m "
            "JOIN conversations c ON c.id = m.conversation_id "
            "WHERE m.id = CAST(:mid AS uuid) "
            "AND (c.user_id = :uid OR :uid = ANY(c.shared_with))"
        )
        try:
            with db_readonly() as conn:
                found = conn.execute(
                    lookup, {"mid": message_id, "uid": caller}
                ).fetchone()
                if found is None:
                    return make_response(jsonify({"status": "not found"}), 404)
                msg = row_to_dict(found)

                status = msg.get("status")
                response = msg.get("response")
                thought = msg.get("thought")
                sources = msg.get("sources") or []
                tool_calls = msg.get("tool_calls") or []

                # While the answer is still being produced the stored
                # response is only the placeholder, so the partial is
                # rebuilt from the message-events journal instead.
                in_flight = status in ("pending", "streaming")
                if in_flight and response == TERMINATED_RESPONSE_PLACEHOLDER:
                    journal = MessageEventsRepository(conn)
                    partial = journal.reconstruct_partial(message_id)
                    response = partial["response"]
                    thought = partial["thought"] or thought
                    if partial["sources"]:
                        sources = partial["sources"]
                    if partial["tool_calls"]:
                        tool_calls = partial["tool_calls"]
        except Exception as err:
            current_app.logger.error(
                f"Error tailing message {message_id}: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)

        metadata = msg.get("message_metadata") or {}
        snapshot = {
            "message_id": str(msg["id"]),
            "status": status,
            "response": response,
            "thought": thought,
            "sources": sources,
            "tool_calls": tool_calls,
            "request_id": msg.get("request_id"),
            "last_heartbeat_at": metadata.get("last_heartbeat_at"),
            "error": metadata.get("error"),
        }
        return make_response(jsonify(snapshot), 200)
|
||||
|
||||
294
application/api/user/idempotency.py
Normal file
294
application/api/user/idempotency.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Per-Celery-task idempotency wrapper backed by ``task_dedup``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from application.storage.db.repositories.idempotency import IdempotencyRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Poison-loop cap; transient-failure headroom without infinite retry.
|
||||
MAX_TASK_ATTEMPTS = 5
|
||||
|
||||
# 30s heartbeat / 60s TTL → ~2 missed ticks of slack before reclaim.
|
||||
LEASE_TTL_SECONDS = 60
|
||||
LEASE_HEARTBEAT_INTERVAL = 30
|
||||
|
||||
# 10 × 60s ≈ 10 min of deferral before giving up on a held lease.
|
||||
LEASE_RETRY_MAX = 10
|
||||
|
||||
|
||||
def with_idempotency(
    task_name: str,
    *,
    on_poison: Optional[Callable[[str, dict], None]] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that guards a bound Celery task body with a dedup key.

    The guard key is the caller's ``idempotency_key`` when it is a non-empty
    string; otherwise ``_synthesize_guard_key`` derives one from the call
    kwargs. When neither yields a key the task body runs unguarded.

    For a guarded call, in order:
      1. A completed row short-circuits and returns its cached result.
      2. A lease claim returning ``None`` (held elsewhere) re-queues the
         task through ``self.retry`` after ``LEASE_TTL_SECONDS``.
      3. An attempt count above ``MAX_TASK_ATTEMPTS`` writes a ``failed``
         result, runs ``on_poison`` and returns the failure payload.
      4. Otherwise the body runs under a lease heartbeat. Success is
         finalized as ``completed``; an exception releases the lease and
         is re-raised.

    Args:
        task_name: Name used in log lines, alerts and synthesized keys.
        on_poison: Optional hook invoked with ``(task_name, call_arguments)``
            when the poison-loop guard trips.

    Returns:
        A decorator for functions whose first positional argument is the
        task instance (``self``).
    """

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(fn)
        def wrapper(self, *args: Any, idempotency_key: Any = None, **kwargs: Any) -> Any:
            def run_body() -> Any:
                # The caller's original idempotency_key is always forwarded
                # untouched, even when a synthesized key guards the call.
                return fn(self, *args, idempotency_key=idempotency_key, **kwargs)

            explicit_key = None
            if isinstance(idempotency_key, str) and idempotency_key:
                explicit_key = idempotency_key

            # A keyless dispatch still gets the guard via a synthesized key;
            # None means no anchor exists, so the body runs unguarded.
            key = explicit_key or _synthesize_guard_key(task_name, kwargs)
            if key is None:
                return run_body()

            previous = _lookup_completed(key)
            if previous is not None:
                logger.info(
                    "idempotency hit for task=%s key=%s — returning cached result",
                    task_name, key,
                )
                return previous

            owner_id = str(uuid.uuid4())
            attempt = _try_claim_lease(key, task_name, _safe_task_id(self), owner_id)

            if attempt is None:
                # Someone else holds a live lease. Re-queue: when the retry
                # fires, either the completed row is there to be returned or
                # the lease has expired and can be claimed.
                logger.info(
                    "idempotency: live lease held; deferring task=%s key=%s",
                    task_name, key,
                )
                raise self.retry(
                    countdown=LEASE_TTL_SECONDS,
                    max_retries=LEASE_RETRY_MAX,
                )

            if attempt > MAX_TASK_ATTEMPTS:
                alert_fields = {
                    "alert": "idempotency_poison_loop",
                    "task_name": task_name,
                    "idempotency_key": key,
                    "attempts": attempt,
                }
                logger.error(
                    "idempotency poison-loop guard: task=%s key=%s attempts=%s",
                    task_name, key, attempt,
                    extra=alert_fields,
                )
                poisoned = {
                    "success": False,
                    "error": "idempotency poison-loop guard tripped",
                    "attempts": attempt,
                }
                _finalize(key, poisoned, status="failed")
                _run_poison_hook(
                    on_poison, task_name, fn, self, args, kwargs, idempotency_key,
                )
                return poisoned

            beat_thread, beat_stop = _start_lease_heartbeat(key, owner_id)
            try:
                outcome = run_body()
                _finalize(key, outcome, status="completed")
                return outcome
            except Exception:
                # Drop the lease so the next retry doesn't wait LEASE_TTL.
                _release_lease(key, owner_id)
                raise
            finally:
                _stop_lease_heartbeat(beat_thread, beat_stop)

        return wrapper

    return decorator
|
||||
|
||||
|
||||
def _synthesize_guard_key(task_name: str, kwargs: dict) -> Optional[str]:
|
||||
"""Derive a deterministic guard key from ``source_id`` for a keyless dispatch.
|
||||
|
||||
``source_id`` is stable across broker redeliveries and unique per
|
||||
upload, so the poison-loop counter survives an OOM SIGKILL. Returns
|
||||
``None`` when absent — the dispatch then runs unguarded as before.
|
||||
"""
|
||||
source_id = kwargs.get("source_id")
|
||||
if source_id:
|
||||
return f"auto:{task_name}:{source_id}"
|
||||
return None
|
||||
|
||||
|
||||
def _run_poison_hook(
|
||||
on_poison: Optional[Callable[[str, dict], None]],
|
||||
task_name: str,
|
||||
fn: Callable[..., Any],
|
||||
task_self: Any,
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
idempotency_key: Any,
|
||||
) -> None:
|
||||
"""Invoke a task's poison-path hook with named call args; swallow failures.
|
||||
|
||||
A hook failure must never change the poison-guard outcome.
|
||||
"""
|
||||
if on_poison is None:
|
||||
return
|
||||
try:
|
||||
bound = inspect.signature(fn).bind_partial(
|
||||
task_self, *args, idempotency_key=idempotency_key, **kwargs,
|
||||
)
|
||||
on_poison(task_name, dict(bound.arguments))
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"idempotency: poison hook failed for task=%s", task_name,
|
||||
)
|
||||
|
||||
|
||||
def _lookup_completed(key: str) -> Any:
    """Fetch the stored result for ``key`` if its row is marked completed.

    Returns:
        The row's ``result_json`` when a row exists with status
        ``"completed"``; ``None`` when there is no row or it has any other
        status.
    """
    with db_readonly() as conn:
        record = IdempotencyRepository(conn).get_task(key)
        if record is not None and record.get("status") == "completed":
            return record.get("result_json")
        return None
|
||||
|
||||
|
||||
def _try_claim_lease(
    key: str, task_name: str, task_id: str, owner_id: str,
) -> Optional[int]:
    """Try to claim the lease for ``key`` on behalf of ``owner_id``.

    Delegates to ``IdempotencyRepository.try_claim_lease`` with a TTL of
    ``LEASE_TTL_SECONDS``.

    Returns:
        Whatever the repository returns: the attempt count on a successful
        claim, or ``None`` when the lease is held. If the claim itself
        raises (for example during a DB outage) the error is logged and
        ``1`` is returned, so the task still runs instead of being blocked.
    """
    claim = {
        "key": key,
        "task_name": task_name,
        "task_id": task_id,
        "owner_id": owner_id,
        "ttl_seconds": LEASE_TTL_SECONDS,
    }
    try:
        with db_session() as conn:
            repo = IdempotencyRepository(conn)
            return repo.try_claim_lease(**claim)
    except Exception:
        logger.exception(
            "idempotency lease-claim failed for key=%s task=%s", key, task_name,
        )
        return 1
|
||||
|
||||
|
||||
def _finalize(key: str, result_json: Any, *, status: str) -> None:
    """Record the terminal ``status`` and result for ``key``, best-effort.

    Any exception raised by the write is logged and swallowed so that a
    database problem never fails the task itself.
    """
    try:
        with db_session() as conn:
            repo = IdempotencyRepository(conn)
            repo.finalize_task(key=key, result_json=result_json, status=status)
    except Exception:
        logger.exception(
            "idempotency finalize failed for key=%s status=%s", key, status,
        )
|
||||
|
||||
|
||||
def _release_lease(key: str, owner_id: str) -> None:
    """Give up ``owner_id``'s lease on ``key``; failures are only logged."""
    try:
        with db_session() as conn:
            repo = IdempotencyRepository(conn)
            repo.release_lease(key, owner_id)
    except Exception:
        logger.exception("idempotency release-lease failed for key=%s", key)
|
||||
|
||||
|
||||
def _start_lease_heartbeat(
    key: str, owner_id: str,
) -> tuple[threading.Thread, threading.Event]:
    """Start the background thread that keeps ``owner_id``'s lease alive.

    The thread is a daemon running ``_lease_heartbeat_loop`` every
    ``LEASE_HEARTBEAT_INTERVAL`` seconds. Its name carries the first 32
    characters of ``key``.

    Returns:
        ``(thread, stop_event)``; set ``stop_event`` to end the loop.
    """
    halt = threading.Event()
    loop_args = (key, owner_id, halt, LEASE_HEARTBEAT_INTERVAL)
    worker = threading.Thread(
        target=_lease_heartbeat_loop,
        args=loop_args,
        daemon=True,
        name=f"idempotency-lease-heartbeat:{key[:32]}",
    )
    worker.start()
    return worker, halt
|
||||
|
||||
|
||||
def _stop_lease_heartbeat(
|
||||
thread: threading.Thread, stop_event: threading.Event,
|
||||
) -> None:
|
||||
"""Signal the heartbeat thread to exit and join with a short timeout."""
|
||||
stop_event.set()
|
||||
thread.join(timeout=10)
|
||||
|
||||
|
||||
def _lease_heartbeat_loop(
|
||||
key: str,
|
||||
owner_id: str,
|
||||
stop_event: threading.Event,
|
||||
interval: int,
|
||||
) -> None:
|
||||
"""Refresh the lease until ``stop_event`` is set or ownership is lost.
|
||||
|
||||
A failed refresh (rowcount 0) means another worker stole the lease
|
||||
after expiry — at that point the damage is already possible, so we
|
||||
log and keep ticking. Don't escalate to thread death; the main task
|
||||
body needs to keep running so its outcome is at least *recorded*.
|
||||
"""
|
||||
while not stop_event.wait(interval):
|
||||
try:
|
||||
with db_session() as conn:
|
||||
still_owned = IdempotencyRepository(conn).refresh_lease(
|
||||
key=key, owner_id=owner_id, ttl_seconds=LEASE_TTL_SECONDS,
|
||||
)
|
||||
if not still_owned:
|
||||
logger.warning(
|
||||
"idempotency lease lost mid-task for key=%s "
|
||||
"(another worker may have taken over)",
|
||||
key,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"idempotency lease-heartbeat tick failed for key=%s", key,
|
||||
)
|
||||
|
||||
|
||||
def _safe_task_id(task_self: Any) -> str:
|
||||
"""Best-effort extraction of ``self.request.id`` from a Celery task."""
|
||||
try:
|
||||
request = getattr(task_self, "request", None)
|
||||
task_id: Optional[str] = (
|
||||
getattr(request, "id", None) if request is not None else None
|
||||
)
|
||||
except Exception:
|
||||
task_id = None
|
||||
return task_id or "unknown"
|
||||
292
application/api/user/reconciliation.py
Normal file
292
application/api/user/reconciliation.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""Reconciler tick: sweep stuck rows and escalate to terminal status + alert."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Connection
|
||||
|
||||
from application.api.user.idempotency import MAX_TASK_ATTEMPTS
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.reconciliation import (
|
||||
ReconciliationRepository,
|
||||
)
|
||||
from application.storage.db.repositories.stack_logs import StackLogsRepository
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MAX_MESSAGE_RECONCILE_ATTEMPTS = 3
|
||||
|
||||
|
||||
def run_reconciliation() -> Dict[str, Any]:
    """Run one reconciler tick: seven sweeps, each in its own transaction.

    Sweeps, in order:
      1. Stuck messages: bump the reconcile-attempt counter; once it
         reaches ``MAX_MESSAGE_RECONCILE_ATTEMPTS`` the message is marked
         failed and alerted.
      2. Tool calls stuck in ``proposed``: marked failed and alerted.
      3. Tool calls stuck in ``executed``: always marked failed. Operators
         handle cleanup manually via the structured alert; the side effect
         is assumed to have committed and no rollback is attempted.
      4. Ingest checkpoints with a silent heartbeat: alerted, then marked
         stalled.
      5. ``task_dedup`` rows whose lease expired with attempts exhausted:
         promoted to failed so a same-key retry can re-claim.
      6. Schedule runs stuck in ``running``: marked ``timeout``.
      7. Schedule runs stuck in ``pending``: marked ``failed``.

    Returns:
        A summary dict with the counters ``messages_failed``,
        ``tool_calls_failed``, ``ingests_stalled``,
        ``idempotency_pending_failed`` and ``schedule_runs_failed``. When
        ``settings.POSTGRES_URI`` is unset nothing is swept: every counter
        is 0 and a ``skipped`` key carries the reason.
    """
    summary: Dict[str, Any] = {
        "messages_failed": 0,
        "tool_calls_failed": 0,
        "ingests_stalled": 0,
        "idempotency_pending_failed": 0,
        "schedule_runs_failed": 0,
    }
    if not settings.POSTGRES_URI:
        # Same counters as a real tick, so callers can read any key on
        # either path.
        return {**summary, "skipped": "POSTGRES_URI not set"}

    engine = get_engine()

    with engine.begin() as conn:
        repo = ReconciliationRepository(conn)
        for msg in repo.find_and_lock_stuck_messages():
            new_count = repo.increment_message_reconcile_attempts(msg["id"])
            if new_count >= MAX_MESSAGE_RECONCILE_ATTEMPTS:
                repo.mark_message_failed(
                    msg["id"],
                    error=(
                        "reconciler: stuck in pending/streaming for >5 min "
                        f"after {new_count} attempts"
                    ),
                )
                summary["messages_failed"] += 1
                _emit_alert(
                    conn,
                    name="reconciler_message_failed",
                    user_id=msg.get("user_id"),
                    detail={
                        "message_id": str(msg["id"]),
                        "attempts": new_count,
                    },
                )

    with engine.begin() as conn:
        repo = ReconciliationRepository(conn)
        for row in repo.find_and_lock_proposed_tool_calls():
            repo.mark_tool_call_failed(
                row["call_id"],
                error=(
                    "reconciler: stuck in 'proposed' for >5 min; "
                    "side effect status unknown"
                ),
            )
            summary["tool_calls_failed"] += 1
            _emit_alert(
                conn,
                name="reconciler_tool_call_failed_proposed",
                user_id=None,
                detail={
                    "call_id": row["call_id"],
                    "tool_name": row.get("tool_name"),
                },
            )

    with engine.begin() as conn:
        repo = ReconciliationRepository(conn)
        for row in repo.find_and_lock_executed_tool_calls():
            repo.mark_tool_call_failed(
                row["call_id"],
                error=(
                    "reconciler: executed-not-confirmed; side effect "
                    "assumed committed, manual cleanup required"
                ),
            )
            summary["tool_calls_failed"] += 1
            _emit_alert(
                conn,
                name="reconciler_tool_call_failed_executed",
                user_id=None,
                detail={
                    "call_id": row["call_id"],
                    "tool_name": row.get("tool_name"),
                    "action_name": row.get("action_name"),
                },
            )

    # Q4: ingest checkpoints whose heartbeat has gone silent. Each is
    # escalated to terminal ``status='stalled'`` and alerted once — no
    # worker kill, no rollback of the partial embed. The 'stalled' flag
    # ends the re-alert loop and drives the "indexing failed" badge the
    # sources list derives from this row.
    with engine.begin() as conn:
        repo = ReconciliationRepository(conn)
        for row in repo.find_and_lock_stalled_ingests():
            summary["ingests_stalled"] += 1
            _emit_alert(
                conn,
                name="reconciler_ingest_stalled",
                user_id=None,
                detail={
                    "source_id": str(row.get("source_id")),
                    "embedded_chunks": row.get("embedded_chunks"),
                    "total_chunks": row.get("total_chunks"),
                    "last_updated": str(row.get("last_updated")),
                },
            )
            repo.mark_ingest_stalled(str(row["source_id"]))

    # Q5: idempotency rows whose lease expired with attempts exhausted.
    # The wrapper's poison-loop guard normally finalises these, but if
    # the wrapper itself died mid-task (worker SIGKILL, OOM during
    # heartbeat) the row sits in ``pending`` blocking same-key retries
    # via ``_lookup_completed`` returning None for the whole 24 h TTL.
    # Promote to ``failed`` so a retry can re-claim and either resume
    # or fail loudly.
    with engine.begin() as conn:
        repo = ReconciliationRepository(conn)
        for row in repo.find_stuck_idempotency_pending(
            max_attempts=MAX_TASK_ATTEMPTS,
        ):
            error_msg = (
                "reconciler: idempotency lease expired with attempts "
                f"({row['attempt_count']}) >= {MAX_TASK_ATTEMPTS}; "
                "task abandoned"
            )
            repo.mark_idempotency_pending_failed(
                row["idempotency_key"], error=error_msg,
            )
            summary["idempotency_pending_failed"] += 1
            _emit_alert(
                conn,
                name="reconciler_idempotency_pending_failed",
                user_id=None,
                detail={
                    "idempotency_key": row["idempotency_key"],
                    "task_name": row.get("task_name"),
                    "task_id": row.get("task_id"),
                    "attempts": row.get("attempt_count"),
                },
            )

    # Q6: scheduler runs stuck in 'running' past the soft-time-limit window.
    from application.storage.db.repositories.schedule_runs import (
        ScheduleRunsRepository,
    )
    from application.storage.db.repositories.schedules import SchedulesRepository
    from application.core.settings import settings as _settings

    # Run timeout in whole minutes plus 5, never less than 15.
    stuck_age = max(
        15, int(_settings.SCHEDULE_RUN_TIMEOUT // 60) + 5,
    )
    with engine.begin() as conn:
        runs_repo = ScheduleRunsRepository(conn)
        schedules_repo = SchedulesRepository(conn)
        for run in runs_repo.list_stuck_running(age_minutes=stuck_age):
            runs_repo.update(
                run["id"],
                {
                    "status": "timeout",
                    "finished_at": datetime.now(timezone.utc),
                    "error_type": "timeout",
                    "error": (
                        "reconciler: schedule_run stuck in 'running' past "
                        f"{stuck_age} min"
                    ),
                },
            )
            schedules_repo.bump_failure_count(str(run["schedule_id"]))
            _terminal_flip_once_schedule(
                schedules_repo, str(run["schedule_id"]),
            )
            summary["schedule_runs_failed"] += 1
            _emit_alert(
                conn,
                name="reconciler_schedule_run_timeout",
                user_id=run.get("user_id"),
                detail={
                    "run_id": str(run["id"]),
                    "schedule_id": str(run["schedule_id"]),
                },
            )

    # Q7: scheduler runs orphaned in 'pending' — dispatcher committed but
    # apply_async failed (broker outage / crash mid-dispatch).
    with engine.begin() as conn:
        runs_repo = ScheduleRunsRepository(conn)
        schedules_repo = SchedulesRepository(conn)
        for run in runs_repo.list_stuck_pending(age_minutes=stuck_age):
            runs_repo.update(
                run["id"],
                {
                    "status": "failed",
                    "finished_at": datetime.now(timezone.utc),
                    "error_type": "internal",
                    "error": (
                        "reconciler: schedule_run stuck in 'pending' past "
                        f"{stuck_age} min (worker_never_started)"
                    ),
                },
            )
            schedules_repo.bump_failure_count(str(run["schedule_id"]))
            _terminal_flip_once_schedule(
                schedules_repo, str(run["schedule_id"]),
            )
            summary["schedule_runs_failed"] += 1
            _emit_alert(
                conn,
                name="reconciler_schedule_run_pending",
                user_id=run.get("user_id"),
                detail={
                    "run_id": str(run["id"]),
                    "schedule_id": str(run["schedule_id"]),
                },
            )

    return summary
|
||||
|
||||
|
||||
def _terminal_flip_once_schedule(
|
||||
schedules_repo: "SchedulesRepository", schedule_id: str,
|
||||
) -> None:
|
||||
"""Flip a once-schedule to 'completed' after its run terminates.
|
||||
|
||||
Recurring schedules keep firing; once-schedules would otherwise read
|
||||
'active forever' since next_run_at is already NULL.
|
||||
"""
|
||||
schedule = schedules_repo.get_internal(schedule_id)
|
||||
if schedule is None or schedule.get("trigger_type") != "once":
|
||||
return
|
||||
if schedule.get("status") in {"completed", "cancelled"}:
|
||||
return
|
||||
schedules_repo.update_internal(
|
||||
schedule_id, {"status": "completed", "next_run_at": None},
|
||||
)
|
||||
|
||||
|
||||
def _emit_alert(
    conn: Connection,
    *,
    name: str,
    user_id: Optional[str],
    detail: Dict[str, Any],
) -> None:
    """Raise an operator alert as an error log line and a ``stack_logs`` row.

    The log record's ``extra`` is ``detail`` plus an ``alert`` field holding
    ``name``, and the same dict is stored as the single entry of the row's
    ``stacks``. A failure to insert the row is logged and swallowed.

    Args:
        conn: Connection the ``stack_logs`` insert runs on.
        name: Alert name; also stored as the row's ``query``.
        user_id: User the alert relates to, or ``None``.
        detail: Extra fields describing the alert.
    """
    payload: Dict[str, Any] = {"alert": name}
    payload.update(detail)
    logger.error("reconciler alert: %s", name, extra=payload)
    try:
        stack_logs = StackLogsRepository(conn)
        stack_logs.insert(
            activity_id=str(uuid.uuid4()),
            endpoint="reconciliation_worker",
            level="ERROR",
            user_id=user_id,
            query=name,
            stacks=[payload],
        )
    except Exception:
        logger.exception("reconciler: failed to write stack_logs row for %s", name)
|
||||
@@ -11,6 +11,7 @@ from .attachments import attachments_ns
|
||||
from .conversations import conversations_ns
|
||||
from .models import models_ns
|
||||
from .prompts import prompts_ns
|
||||
from .schedules import schedules_ns
|
||||
from .sharing import sharing_ns
|
||||
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
|
||||
from .tools import tools_mcp_ns, tools_ns
|
||||
@@ -40,6 +41,9 @@ api.add_namespace(agents_folders_ns)
|
||||
# Prompts
|
||||
api.add_namespace(prompts_ns)
|
||||
|
||||
# Schedules
|
||||
api.add_namespace(schedules_ns)
|
||||
|
||||
# Sharing
|
||||
api.add_namespace(sharing_ns)
|
||||
|
||||
|
||||
186
application/api/user/scheduler_dispatcher.py
Normal file
186
application/api/user/scheduler_dispatcher.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Schedule dispatcher: poll Postgres, claim due rows under FOR UPDATE SKIP LOCKED,
|
||||
advance next_run_at atomically with the run claim, then enqueue.
|
||||
|
||||
Per-schedule IANA tz semantics (croniter+zoneinfo) outside Celery's app-wide tz,
|
||||
plus Postgres-native dedup avoid Redis visibility_timeout double-fires.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.agents.scheduler_utils import next_cron_run
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.schedule_runs import (
|
||||
ScheduleRunsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_dt(value: Any) -> Optional[datetime]:
|
||||
"""Accept a datetime / ISO string / None and return a tz-aware UTC dt."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, datetime):
|
||||
return value.astimezone(timezone.utc) if value.tzinfo else (
|
||||
value.replace(tzinfo=timezone.utc)
|
||||
)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
return None
|
||||
return parsed.astimezone(timezone.utc) if parsed.tzinfo else (
|
||||
parsed.replace(tzinfo=timezone.utc)
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _compute_next(
|
||||
schedule: Dict[str, Any],
|
||||
*,
|
||||
after: datetime,
|
||||
) -> Optional[datetime]:
|
||||
"""Next next_run_at for a recurring schedule, or None when past end_at."""
|
||||
cron = schedule.get("cron")
|
||||
if not cron:
|
||||
return None
|
||||
end_at = _normalize_dt(schedule.get("end_at"))
|
||||
candidate = next_cron_run(cron, schedule.get("timezone"), after=after)
|
||||
if end_at is not None and candidate > end_at:
|
||||
return None
|
||||
return candidate
|
||||
|
||||
|
||||
def dispatch_due_runs() -> Dict[str, int]:
|
||||
"""One dispatcher tick; returns counts for schedule_syncs-style logging."""
|
||||
if not settings.POSTGRES_URI:
|
||||
return {"enqueued": 0, "skipped": 0, "advanced": 0}
|
||||
|
||||
from application.api.user.tasks import execute_scheduled_run
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
grace = timedelta(seconds=max(0, settings.SCHEDULE_MISFIRE_GRACE))
|
||||
engine = get_engine()
|
||||
counts = {"enqueued": 0, "skipped": 0, "advanced": 0}
|
||||
enqueue_args: List[str] = []
|
||||
|
||||
with engine.begin() as conn:
|
||||
schedules_repo = SchedulesRepository(conn)
|
||||
runs_repo = ScheduleRunsRepository(conn)
|
||||
for schedule in schedules_repo.list_due():
|
||||
scheduled_for = _normalize_dt(schedule.get("next_run_at"))
|
||||
if scheduled_for is None:
|
||||
continue
|
||||
|
||||
trigger_type = schedule.get("trigger_type")
|
||||
agent_id_raw = schedule.get("agent_id")
|
||||
agent_id = str(agent_id_raw) if agent_id_raw else None
|
||||
|
||||
# Misfire grace applies to recurring only — once-tasks fire late, not vanish.
|
||||
if (
|
||||
trigger_type == "recurring"
|
||||
and grace > timedelta(0)
|
||||
and (now - scheduled_for) > grace
|
||||
):
|
||||
runs_repo.record_skipped(
|
||||
str(schedule["id"]),
|
||||
schedule["user_id"],
|
||||
agent_id,
|
||||
scheduled_for,
|
||||
error_type="missed",
|
||||
error="misfire grace exceeded",
|
||||
)
|
||||
counts["skipped"] += 1
|
||||
nxt = _compute_next(schedule, after=now)
|
||||
if nxt is None:
|
||||
schedules_repo.update_internal(
|
||||
str(schedule["id"]),
|
||||
{"status": "completed", "next_run_at": None,
|
||||
"last_run_at": now},
|
||||
)
|
||||
else:
|
||||
schedules_repo.update_internal(
|
||||
str(schedule["id"]),
|
||||
{"next_run_at": nxt, "last_run_at": now},
|
||||
)
|
||||
counts["advanced"] += 1
|
||||
continue
|
||||
|
||||
# Overlap guard: never enqueue while a previous run is active.
|
||||
if runs_repo.has_active_run(str(schedule["id"])):
|
||||
runs_repo.record_skipped(
|
||||
str(schedule["id"]),
|
||||
schedule["user_id"],
|
||||
agent_id,
|
||||
scheduled_for,
|
||||
error_type="overlap",
|
||||
error="previous run still active",
|
||||
)
|
||||
counts["skipped"] += 1
|
||||
if trigger_type == "recurring":
|
||||
nxt = _compute_next(schedule, after=scheduled_for)
|
||||
schedules_repo.update_internal(
|
||||
str(schedule["id"]),
|
||||
{"next_run_at": nxt, "last_run_at": now},
|
||||
)
|
||||
else:
|
||||
# Once: null next_run_at so we don't re-pick; the in-flight
|
||||
# run will terminal-flip the schedule when it finishes.
|
||||
schedules_repo.update_internal(
|
||||
str(schedule["id"]),
|
||||
{"next_run_at": None, "last_run_at": now},
|
||||
)
|
||||
continue
|
||||
|
||||
# Dedup primitive: two racing dispatchers see exactly one row.
|
||||
run = runs_repo.record_pending(
|
||||
str(schedule["id"]),
|
||||
schedule["user_id"],
|
||||
agent_id,
|
||||
scheduled_for,
|
||||
trigger_source="cron",
|
||||
)
|
||||
if run is None:
|
||||
counts["skipped"] += 1
|
||||
else:
|
||||
enqueue_args.append(str(run["id"]))
|
||||
counts["enqueued"] += 1
|
||||
|
||||
# Advance: recurring picks next tick, once nulls next_run_at
|
||||
# (worker terminal-flips status on completion).
|
||||
if trigger_type == "recurring":
|
||||
nxt = _compute_next(schedule, after=scheduled_for)
|
||||
if nxt is None:
|
||||
schedules_repo.update_internal(
|
||||
str(schedule["id"]),
|
||||
{"status": "completed", "next_run_at": None,
|
||||
"last_run_at": now},
|
||||
)
|
||||
else:
|
||||
schedules_repo.update_internal(
|
||||
str(schedule["id"]),
|
||||
{"next_run_at": nxt, "last_run_at": now},
|
||||
)
|
||||
else:
|
||||
schedules_repo.update_internal(
|
||||
str(schedule["id"]),
|
||||
{"next_run_at": None, "last_run_at": now},
|
||||
)
|
||||
counts["advanced"] += 1
|
||||
|
||||
# Enqueue after commit so the worker sees the schedule_runs row on pick-up.
|
||||
for run_id in enqueue_args:
|
||||
try:
|
||||
execute_scheduled_run.apply_async(args=[run_id], queue="docsgpt")
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"dispatcher: failed to enqueue execute_scheduled_run for %s",
|
||||
run_id,
|
||||
)
|
||||
return counts
|
||||
433
application/api/user/scheduler_worker.py
Normal file
433
application/api/user/scheduler_worker.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""Body of ``execute_scheduled_run`` — runs a single agent execution.
|
||||
|
||||
Not a DURABLE_TASK: agent runs have side effects (messages, CRM writes)
|
||||
and blind auto-retry would double them. Failures after agent.gen starts
|
||||
are terminal and recorded; only the pre-start load is retry-safe.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.agents.headless_runner import run_agent_headless
|
||||
from application.core.settings import settings
|
||||
from application.events.publisher import publish_user_event
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.conversations import (
|
||||
ConversationsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.schedule_runs import (
|
||||
ScheduleRunsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Maximum number of characters of agent output stored verbatim in the run log.
# Longer answers keep only the head and the run is stamped ``output_truncated``.
_OUTPUT_CAP_CHARS = 24_000
|
||||
|
||||
|
||||
def _agent_config_for_schedule(schedule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Return the agent config a scheduled run should execute with.

    Agent-bound schedules (truthy ``agent_id``) load the matching ``agents``
    row and return it as a dict, or ``None`` when no such row exists.
    Agentless schedules delegate to ``_ephemeral_agent_for_agentless``, which
    builds an in-memory classic-agent config instead of reading a row.
    """
    bound_agent_id = schedule.get("agent_id")
    if not bound_agent_id:
        return _ephemeral_agent_for_agentless(schedule)

    lookup = sql_text("SELECT * FROM agents WHERE id = CAST(:id AS uuid)")
    with get_engine().connect() as conn:
        agent_row = conn.execute(lookup, {"id": str(bound_agent_id)}).fetchone()
    if agent_row is None:
        return None
    return row_to_dict(agent_row)
|
||||
|
||||
|
||||
def _ephemeral_agent_for_agentless(
|
||||
schedule: Dict[str, Any],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Build an agent-shaped config for a schedule with no parent agent."""
|
||||
# ``agent_config["tools"]`` is intentionally omitted: ``run_agent_headless``
|
||||
# never reads it. The runtime toolset is rebuilt by
|
||||
# ``ToolExecutor._get_user_tools(owner)`` at fire time — same dereference
|
||||
# the agent-bound path uses, so a tool added/disabled after creation is
|
||||
# reflected. Headless mode there filters chat-only tools (``scheduler``).
|
||||
user_id = schedule.get("user_id")
|
||||
if not user_id:
|
||||
return None
|
||||
return {
|
||||
"id": None,
|
||||
"user_id": user_id,
|
||||
"agent_type": "classic",
|
||||
"retriever": "classic",
|
||||
"chunks": 2,
|
||||
"prompt_id": "default",
|
||||
"source_id": None,
|
||||
"default_model_id": schedule.get("model_id") or "",
|
||||
}
|
||||
|
||||
|
||||
def _load_chat_history(schedule: Dict[str, Any]) -> list:
|
||||
"""Originating conversation history (one-time only; recurring has none)."""
|
||||
origin = schedule.get("origin_conversation_id")
|
||||
if not origin or schedule.get("trigger_type") != "once":
|
||||
return []
|
||||
user_id = schedule.get("user_id")
|
||||
if not user_id:
|
||||
return []
|
||||
try:
|
||||
engine = get_engine()
|
||||
with engine.connect() as conn:
|
||||
conv = ConversationsRepository(conn).get_any(str(origin), user_id)
|
||||
if conv is None:
|
||||
return []
|
||||
messages = ConversationsRepository(conn).get_messages(str(conv["id"]))
|
||||
except Exception:
|
||||
logger.exception("scheduler: failed loading chat history")
|
||||
return []
|
||||
history: list = []
|
||||
for msg in messages:
|
||||
if msg.get("prompt") and msg.get("response"):
|
||||
history.append({
|
||||
"prompt": msg["prompt"],
|
||||
"response": msg["response"],
|
||||
})
|
||||
return history
|
||||
|
||||
|
||||
def _publish_run_event(
|
||||
event_type: str, run: Dict[str, Any], schedule: Dict[str, Any], **extra: Any
|
||||
) -> None:
|
||||
"""Best-effort SSE publish for a scheduler run state transition."""
|
||||
user_id = run.get("user_id") or schedule.get("user_id")
|
||||
if not user_id:
|
||||
return
|
||||
agent_id_raw = schedule.get("agent_id")
|
||||
payload = {
|
||||
"run_id": str(run["id"]),
|
||||
"schedule_id": str(schedule["id"]),
|
||||
"agent_id": str(agent_id_raw) if agent_id_raw else None,
|
||||
"trigger_type": schedule.get("trigger_type"),
|
||||
"status": run.get("status"),
|
||||
**extra,
|
||||
}
|
||||
try:
|
||||
publish_user_event(
|
||||
user_id,
|
||||
event_type,
|
||||
payload,
|
||||
scope={"kind": "schedule", "id": str(schedule["id"])},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"scheduler: SSE publish failed event=%s run=%s",
|
||||
event_type, run.get("id"),
|
||||
)
|
||||
|
||||
|
||||
def _publish_message_appended(
    user_id: str,
    conversation_id: str,
    message: Dict[str, Any],
    schedule_id: str,
    run_id: str,
) -> None:
    """Publish ``schedule.message.appended`` for a turn added by a one-time run.

    The event is scoped to the conversation. Any failure, including a
    malformed ``message``, is logged and swallowed.
    """
    try:
        # Payload construction stays inside the ``try`` so a bad message
        # (missing ``id``, non-numeric ``position``) is logged, not raised.
        conversation_ref = str(conversation_id)
        event_body = {
            "conversation_id": conversation_ref,
            "message_id": str(message["id"]),
            "schedule_id": str(schedule_id),
            "run_id": str(run_id),
            "position": int(message.get("position", 0)),
        }
        publish_user_event(
            user_id,
            "schedule.message.appended",
            event_body,
            scope={"kind": "conversation", "id": conversation_ref},
        )
    except Exception:
        logger.exception(
            "scheduler: message.appended publish failed run=%s", run_id,
        )
|
||||
|
||||
|
||||
def _append_one_time_turn(
|
||||
schedule: Dict[str, Any],
|
||||
run: Dict[str, Any],
|
||||
outcome: Dict[str, Any],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Insert an assistant turn in the originating conversation (once only)."""
|
||||
origin = schedule.get("origin_conversation_id")
|
||||
if not origin:
|
||||
return None
|
||||
engine = get_engine()
|
||||
user_id = schedule.get("user_id")
|
||||
metadata = {
|
||||
"scheduled": True,
|
||||
"schedule_id": str(schedule["id"]),
|
||||
"run_id": str(run["id"]),
|
||||
"scheduled_run_at": (
|
||||
run.get("scheduled_for")
|
||||
if isinstance(run.get("scheduled_for"), str)
|
||||
else None
|
||||
),
|
||||
}
|
||||
with engine.begin() as conn:
|
||||
conv = ConversationsRepository(conn).get_any(str(origin), user_id)
|
||||
if conv is None:
|
||||
return None
|
||||
message = ConversationsRepository(conn).append_message(
|
||||
str(conv["id"]),
|
||||
{
|
||||
"prompt": schedule.get("instruction") or "",
|
||||
"response": outcome.get("answer") or "",
|
||||
"thought": outcome.get("thought") or "",
|
||||
"sources": outcome.get("sources") or [],
|
||||
"tool_calls": outcome.get("tool_calls") or [],
|
||||
"model_id": outcome.get("model_id"),
|
||||
"metadata": metadata,
|
||||
},
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
def execute_scheduled_run_body(run_id: str, celery_task_id: Optional[str]) -> Dict[str, Any]:
    """Execute one scheduled run by id; returns a result dict for tracing.

    Flow: load the run and its schedule, skip runs that are not ``pending`` or
    whose schedule is cancelled/completed (manual runs bypass the latter),
    resolve the agent config, claim the run via ``mark_running``, invoke
    ``run_agent_headless``, then persist the outcome, token usage and
    failure-count bookkeeping in one transaction and publish SSE events.

    Args:
        run_id: Id of the ``schedule_runs`` row to execute.
        celery_task_id: Passed through to ``mark_running``; may be ``None``.

    Returns:
        ``{"status": "skipped" | "failed", "reason": ...}`` on the early
        exits, otherwise ``{"status", "run_id", "error_type"}`` where status
        is ``"success"``, ``"failed"`` or ``"timeout"``.
    """
    if not settings.POSTGRES_URI:
        return {"status": "skipped", "reason": "POSTGRES_URI not set"}

    engine = get_engine()

    # Read-only load of the run and its parent schedule.
    with engine.connect() as conn:
        run = ScheduleRunsRepository(conn).get_internal(run_id)
        if run is None:
            return {"status": "skipped", "reason": "run not found"}
        schedule = SchedulesRepository(conn).get_internal(str(run["schedule_id"]))
        if schedule is None:
            return {"status": "skipped", "reason": "schedule not found"}

    # Refuse non-runnable terminal states; manual run-now bypasses.
    if run.get("status") != "pending":
        return {"status": "skipped", "reason": f"run status={run.get('status')}"}
    if schedule.get("status") in {"cancelled", "completed"} and run.get(
        "trigger_source"
    ) != "manual":
        with engine.begin() as conn:
            ScheduleRunsRepository(conn).update(
                run_id,
                {
                    "status": "skipped",
                    "finished_at": datetime.now(timezone.utc),
                    "error_type": "internal",
                    "error": "schedule no longer active",
                },
            )
        return {"status": "skipped", "reason": "schedule terminal"}

    # A missing agent is recorded as a failed run and counts towards the
    # schedule's consecutive-failure counter.
    agent_config = _agent_config_for_schedule(schedule)
    if agent_config is None:
        with engine.begin() as conn:
            updated = ScheduleRunsRepository(conn).update(
                run_id,
                {
                    "status": "failed",
                    "finished_at": datetime.now(timezone.utc),
                    "error_type": "internal",
                    "error": "agent missing",
                },
            )
            SchedulesRepository(conn).bump_failure_count(str(schedule["id"]))
        _publish_run_event("schedule.run.failed", updated or run, schedule,
                           error="agent missing")
        return {"status": "failed", "reason": "agent missing"}

    # Claim the run; a falsy result means another worker got there first.
    with engine.begin() as conn:
        if not ScheduleRunsRepository(conn).mark_running(run_id, celery_task_id):
            return {"status": "skipped", "reason": "lost race to mark_running"}

    started = datetime.now(timezone.utc)
    instruction = schedule.get("instruction") or ""
    allowlist = schedule.get("tool_allowlist") or []
    chat_history = _load_chat_history(schedule)
    outcome: Dict[str, Any]
    error_type: Optional[str] = None
    error_text: Optional[str] = None
    timed_out = False
    try:
        outcome = run_agent_headless(
            agent_config,
            instruction,
            tool_allowlist=allowlist,
            model_id_override=schedule.get("model_id"),
            endpoint="schedule",
            conversation_id=schedule.get("origin_conversation_id"),
            chat_history=chat_history,
        )
    except SoftTimeLimitExceeded:
        # Celery soft time limit: record as a timeout with an empty outcome.
        timed_out = True
        outcome = {"answer": "", "tool_calls": [], "sources": [], "thought": ""}
        error_type = "timeout"
        error_text = "run exceeded soft time limit"
    except Exception as exc:
        # Any other agent failure is terminal for this run (no retry here).
        outcome = {"answer": "", "tool_calls": [], "sources": [], "thought": ""}
        error_type = "agent_error"
        error_text = str(exc)
        logger.exception("scheduler: agent run failed run=%s", run_id)

    finished = datetime.now(timezone.utc)

    # Headless denial with no usable output → tool_not_allowed.
    if (
        error_type is None
        and (outcome.get("denied") or [])
        and not (outcome.get("answer") or "").strip()
    ):
        error_type = "tool_not_allowed"
        error_text = "headless allowlist blocked required tool"

    prompt_tokens = int(outcome.get("prompt_tokens", 0) or 0)
    generated_tokens = int(outcome.get("generated_tokens", 0) or 0)
    used_tokens = prompt_tokens + generated_tokens
    # A positive token budget is enforced after the fact; exceeding it
    # overrides any earlier error_type and marks the run as failed.
    if (
        schedule.get("token_budget") is not None
        and int(schedule["token_budget"]) > 0
        and used_tokens > int(schedule["token_budget"])
    ):
        error_type = "budget_exceeded"
        error_text = (
            f"used {used_tokens} tokens exceeds budget "
            f"{schedule['token_budget']}"
        )

    # Keep only the head of over-long answers and flag the truncation.
    answer = outcome.get("answer") or ""
    truncated = False
    if len(answer) > _OUTPUT_CAP_CHARS:
        answer = answer[:_OUTPUT_CAP_CHARS]
        truncated = True

    new_status = (
        "timeout" if timed_out else ("failed" if error_type else "success")
    )

    # One transaction: run row, token usage, failure counters, once-flip.
    with engine.begin() as conn:
        update_fields: Dict[str, Any] = {
            "status": new_status,
            "started_at": started,
            "finished_at": finished,
            "output": answer or None,
            "output_truncated": truncated,
            "prompt_tokens": prompt_tokens,
            "generated_tokens": generated_tokens,
        }
        if error_type:
            update_fields["error_type"] = error_type
            update_fields["error"] = error_text
        updated_run = ScheduleRunsRepository(conn).update(run_id, update_fields)
        if used_tokens > 0:
            agent_id_raw = schedule.get("agent_id")
            try:
                TokenUsageRepository(conn).insert(
                    user_id=schedule.get("user_id"),
                    api_key=None,
                    prompt_tokens=prompt_tokens,
                    generated_tokens=generated_tokens,
                    timestamp=finished,
                    agent_id=str(agent_id_raw) if agent_id_raw else None,
                    source="schedule",
                    request_id=str(run_id),
                )
            except Exception:
                # Usage accounting is best-effort; the failure is only logged.
                logger.exception(
                    "scheduler: token_usage insert failed run=%s", run_id,
                )
        schedules_repo = SchedulesRepository(conn)
        autopaused = False
        if new_status == "success":
            schedules_repo.reset_failure_count(str(schedule["id"]))
        elif new_status in ("failed", "timeout"):
            count = schedules_repo.bump_failure_count(str(schedule["id"]))
            # Auto-pause applies to recurring schedules only, and only when
            # the threshold setting is positive.
            if (
                settings.SCHEDULE_AUTOPAUSE_FAILURES > 0
                and count >= settings.SCHEDULE_AUTOPAUSE_FAILURES
                and schedule.get("trigger_type") == "recurring"
            ):
                autopaused = schedules_repo.autopause(str(schedule["id"]))
        # Once: terminal-flip on cron-fired runs only; manual runs on a
        # still-active once-schedule leave the future cadence intact.
        if (
            schedule.get("trigger_type") == "once"
            and run.get("trigger_source") != "manual"
            and schedule.get("status") == "active"
        ):
            schedules_repo.update_internal(
                str(schedule["id"]),
                {"status": "completed", "next_run_at": None},
            )

    # Successful one-time runs write their answer back into the originating
    # conversation; a failure to append is logged and does not fail the run.
    appended: Optional[Dict[str, Any]] = None
    if (
        schedule.get("trigger_type") == "once"
        and new_status == "success"
        and schedule.get("origin_conversation_id")
    ):
        try:
            appended = _append_one_time_turn(schedule, updated_run or run, outcome)
        except Exception:
            logger.exception(
                "scheduler: append turn failed run=%s", run_id,
            )
        if appended is not None:
            # Link the run to the appended message, then notify the client.
            with engine.begin() as conn:
                ScheduleRunsRepository(conn).update(
                    run_id,
                    {
                        "conversation_id": str(appended["conversation_id"]),
                        "message_id": str(appended["id"]),
                    },
                )
            _publish_message_appended(
                schedule.get("user_id"),
                str(appended["conversation_id"]),
                appended,
                str(schedule["id"]),
                run_id,
            )

    if new_status == "success":
        _publish_run_event("schedule.run.completed", updated_run or run, schedule)
    else:
        _publish_run_event(
            "schedule.run.failed",
            updated_run or run,
            schedule,
            error_type=error_type,
            error=error_text,
        )

    if autopaused:
        # NOTE(review): the event reports the configured threshold, not the
        # actual counter value returned by bump_failure_count.
        _publish_run_event(
            "schedule.autopaused",
            updated_run or run,
            schedule,
            consecutive_failure_count=settings.SCHEDULE_AUTOPAUSE_FAILURES,
        )

    return {
        "status": new_status,
        "run_id": run_id,
        "error_type": error_type,
    }
|
||||
5
application/api/user/schedules/__init__.py
Normal file
5
application/api/user/schedules/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Schedules module."""
|
||||
|
||||
from .routes import schedules_ns
|
||||
|
||||
__all__ = ["schedules_ns"]
|
||||
550
application/api/user/schedules/routes.py
Normal file
550
application/api/user/schedules/routes.py
Normal file
@@ -0,0 +1,550 @@
|
||||
"""Schedules REST API (owner-scoped via request.decoded_token)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
|
||||
from application.agents.scheduler_utils import (
|
||||
ScheduleValidationError,
|
||||
clamp_once_horizon,
|
||||
cron_interval_seconds,
|
||||
next_cron_run,
|
||||
parse_cron,
|
||||
parse_run_at,
|
||||
resolve_timezone,
|
||||
)
|
||||
from application.api import api
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.schedule_runs import (
|
||||
ScheduleRunsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.schedules import SchedulesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
schedules_ns = Namespace(
|
||||
"schedules", description="Agent schedule management", path="/api",
|
||||
)
|
||||
|
||||
|
||||
def _ok(data: Any, status: int = 200):
    """Build a JSON response carrying ``data`` with the given HTTP status."""
    body = jsonify(data)
    return make_response(body, status)
|
||||
|
||||
|
||||
def _err(message: str, status: int = 400):
    """Build a JSON error response: ``success`` false plus ``message``."""
    failure = {"success": False, "message": message}
    return make_response(jsonify(failure), status)
|
||||
|
||||
|
||||
def _format_schedule(row: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Render a schedule row for the API (id-as-string + ISO timestamps)."""
|
||||
if not row:
|
||||
return {}
|
||||
out = dict(row)
|
||||
for key in (
|
||||
"id", "agent_id", "origin_conversation_id",
|
||||
):
|
||||
if out.get(key) is not None:
|
||||
out[key] = str(out[key])
|
||||
out.pop("_id", None) # drop dual-id legacy mirror
|
||||
return out
|
||||
|
||||
|
||||
def _format_run(row: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Render a schedule_run row for the API."""
|
||||
if not row:
|
||||
return {}
|
||||
out = dict(row)
|
||||
for key in (
|
||||
"id", "schedule_id", "agent_id", "conversation_id", "message_id",
|
||||
):
|
||||
if out.get(key) is not None:
|
||||
out[key] = str(out[key])
|
||||
out.pop("_id", None)
|
||||
return out
|
||||
|
||||
|
||||
def _agent_owned(agent_id: str, user_id: str) -> Optional[Dict[str, Any]]:
    """Look up ``agent_id`` for ``user_id``; ``None`` if the id is not UUID-shaped.

    Ids that fail ``looks_like_uuid`` are rejected without touching the
    database; otherwise the result of ``AgentsRepository.get_any`` is returned.
    """
    if looks_like_uuid(str(agent_id)):
        with db_readonly() as conn:
            return AgentsRepository(conn).get_any(agent_id, user_id)
    return None
|
||||
|
||||
|
||||
def _user_id() -> Optional[str]:
    """Return the ``sub`` claim of the request's decoded token, if any."""
    token = getattr(request, "decoded_token", None)
    return token.get("sub") if token else None
|
||||
|
||||
|
||||
@schedules_ns.route("/agents/<string:agent_id>/schedules")
class AgentSchedules(Resource):
    # Collection endpoint: list and create schedules of one agent owned by the
    # authenticated user. Handler notes are ``#`` comments rather than
    # docstrings so the generated API docs keep using the ``api.doc`` text.

    @api.doc(description="List schedules for an agent (recurring + one-time).")
    def get(self, agent_id):
        # 401 without a user, 404 when the agent is not owned by the caller,
        # 500 when the listing query fails.
        user_id = _user_id()
        if not user_id:
            return _err("unauthorized", 401)
        agent = _agent_owned(agent_id, user_id)
        if agent is None:
            return _err("agent not found", 404)
        try:
            with db_readonly() as conn:
                rows = SchedulesRepository(conn).list_for_agent(
                    str(agent["id"]), user_id,
                )
        except Exception as exc:
            current_app.logger.error("list schedules failed: %s", exc, exc_info=True)
            return _err("internal error", 500)
        return _ok({"schedules": [_format_schedule(r) for r in rows]})

    create_model = api.model(
        "ScheduleCreate",
        {
            "instruction": fields.String(required=True),
            "trigger_type": fields.String(
                required=False,
                description="'recurring' (default) or 'once'",
            ),
            "cron": fields.String(
                required=False,
                description="Required when trigger_type == 'recurring'",
            ),
            "run_at": fields.String(
                required=False,
                description="ISO 8601 — required when trigger_type == 'once'",
            ),
            "timezone": fields.String(required=False),
            "name": fields.String(required=False),
            "end_at": fields.String(required=False, description="ISO 8601"),
            "tool_allowlist": fields.List(fields.String, required=False),
            "model_id": fields.String(required=False),
            "token_budget": fields.Integer(required=False),
        },
    )

    @api.expect(create_model)
    @api.doc(description="Create a schedule (recurring or one-time) for an agent.")
    def post(self, agent_id):
        # Validates the payload, enforces the per-user schedule cap, then
        # creates either a one-time schedule (``run_at``) or a recurring one
        # (``cron``). Responds 201 with the created schedule.
        user_id = _user_id()
        if not user_id:
            return _err("unauthorized", 401)
        agent = _agent_owned(agent_id, user_id)
        if agent is None:
            return _err("agent not found", 404)
        data = request.get_json(silent=True) or {}
        # A JSON array/scalar body would otherwise crash on ``data.get``.
        if not isinstance(data, dict):
            return _err("request body must be a JSON object")
        instruction = (data.get("instruction") or "").strip()
        tz_name = (data.get("timezone") or "UTC").strip() or "UTC"
        trigger_type = (data.get("trigger_type") or "recurring").strip().lower()
        if trigger_type not in ("recurring", "once"):
            return _err("trigger_type must be 'recurring' or 'once'")
        if not instruction:
            return _err("instruction is required")
        try:
            resolve_timezone(tz_name)
        except ScheduleValidationError as exc:
            return _err(str(exc))
        token_budget = data.get("token_budget")
        if token_budget is not None:
            try:
                token_budget = int(token_budget)
                if token_budget < 0:
                    raise ValueError
            except (TypeError, ValueError):
                return _err("token_budget must be a non-negative integer")
        with db_readonly() as conn:
            count = SchedulesRepository(conn).count_active_for_user(user_id)
        # A non-positive SCHEDULE_MAX_PER_USER disables the cap.
        if (
            settings.SCHEDULE_MAX_PER_USER > 0
            and count >= settings.SCHEDULE_MAX_PER_USER
        ):
            return _err("max schedules per user reached", 429)

        if trigger_type == "once":
            run_at_raw = (data.get("run_at") or "").strip()
            if not run_at_raw:
                return _err("run_at is required for trigger_type 'once'")
            try:
                fire = parse_run_at(run_at_raw, tz_name)
                clamp_once_horizon(
                    fire, settings.SCHEDULE_ONCE_MAX_HORIZON,
                )
            except ScheduleValidationError as exc:
                return _err(str(exc))
            try:
                with db_session() as conn:
                    created = SchedulesRepository(conn).create(
                        user_id=user_id,
                        agent_id=str(agent["id"]),
                        trigger_type="once",
                        instruction=instruction,
                        run_at=fire,
                        next_run_at=fire,
                        timezone=tz_name,
                        name=(data.get("name") or "").strip() or None,
                        tool_allowlist=data.get("tool_allowlist") or [],
                        model_id=(data.get("model_id") or None),
                        token_budget=token_budget,
                        created_via="ui",
                    )
            except Exception as exc:
                current_app.logger.error(
                    "create one-time schedule failed: %s", exc, exc_info=True,
                )
                return _err("internal error", 500)
            return _ok({"schedule": _format_schedule(created)}, status=201)

        cron = (data.get("cron") or "").strip()
        if not cron:
            return _err("cron is required")
        try:
            parse_cron(cron)
        except ScheduleValidationError as exc:
            return _err(str(exc))
        # A non-positive SCHEDULE_MIN_INTERVAL disables the cadence floor.
        min_interval = max(0, int(settings.SCHEDULE_MIN_INTERVAL))
        if min_interval > 0:
            try:
                cadence = cron_interval_seconds(cron, tz_name)
            except ScheduleValidationError as exc:
                return _err(str(exc))
            if cadence < min_interval:
                return _err(
                    "cadence below minimum interval "
                    f"({cadence}s < {min_interval}s)",
                )
        end_at = None
        if data.get("end_at"):
            try:
                end_at = datetime.fromisoformat(
                    str(data["end_at"]).replace("Z", "+00:00"),
                )
            except ValueError:
                return _err("invalid end_at")
            # An offset-less ``end_at`` parses as a naive datetime. Comparing
            # it with an aware ``next_run`` below raises TypeError, so naive
            # values are treated as UTC.
            if end_at.tzinfo is None:
                end_at = end_at.replace(tzinfo=timezone.utc)
        try:
            next_run = next_cron_run(cron, tz_name, after=datetime.now(timezone.utc))
        except ScheduleValidationError as exc:
            return _err(str(exc))
        if end_at is not None and next_run > end_at:
            return _err("end_at is before the first cron tick")
        try:
            with db_session() as conn:
                created = SchedulesRepository(conn).create(
                    user_id=user_id,
                    agent_id=str(agent["id"]),
                    trigger_type="recurring",
                    instruction=instruction,
                    cron=cron,
                    timezone=tz_name,
                    next_run_at=next_run,
                    end_at=end_at,
                    name=(data.get("name") or "").strip() or None,
                    tool_allowlist=data.get("tool_allowlist") or [],
                    model_id=(data.get("model_id") or None),
                    token_budget=token_budget,
                    created_via="ui",
                )
        except Exception as exc:
            current_app.logger.error(
                "create schedule failed: %s", exc, exc_info=True,
            )
            return _err("internal error", 500)
        return _ok({"schedule": _format_schedule(created)}, status=201)
|
||||
|
||||
|
||||
def _assume_utc_if_naive(moment: datetime) -> datetime:
    """Return ``moment`` unchanged when timezone-aware, else tagged as UTC."""
    if moment.tzinfo is None:
        return moment.replace(tzinfo=timezone.utc)
    return moment


@schedules_ns.route("/schedules/<string:schedule_id>")
class ScheduleResource(Resource):
    # Item endpoint: read, edit, pause/resume and delete one schedule owned by
    # the authenticated user. Handler notes are ``#`` comments rather than
    # docstrings so the generated API docs keep using the ``api.doc`` text.

    @api.doc(description="Get schedule by id.")
    def get(self, schedule_id):
        user_id = _user_id()
        if not user_id:
            return _err("unauthorized", 401)
        if not looks_like_uuid(schedule_id):
            return _err("invalid schedule id", 400)
        with db_readonly() as conn:
            row = SchedulesRepository(conn).get(schedule_id, user_id)
        if row is None:
            return _err("schedule not found", 404)
        return _ok({"schedule": _format_schedule(row)})

    @api.doc(description="Edit a schedule's editable fields.")
    def put(self, schedule_id):
        # Only keys present in the body are validated and written. Changing
        # ``cron`` or ``timezone`` on a recurring schedule re-checks the
        # minimum cadence and recomputes ``next_run_at``.
        user_id = _user_id()
        if not user_id:
            return _err("unauthorized", 401)
        if not looks_like_uuid(schedule_id):
            return _err("invalid schedule id", 400)
        data = request.get_json(silent=True) or {}
        # A JSON array/scalar body would otherwise crash on the key lookups.
        if not isinstance(data, dict):
            return _err("request body must be a JSON object")
        fields_in: Dict[str, Any] = {}
        if "instruction" in data:
            inst = (data["instruction"] or "").strip()
            if not inst:
                return _err("instruction must not be empty")
            fields_in["instruction"] = inst
        if "cron" in data:
            cron = (data["cron"] or "").strip()
            try:
                parse_cron(cron)
            except ScheduleValidationError as exc:
                return _err(str(exc))
            fields_in["cron"] = cron
        if "timezone" in data:
            tz_name = (data["timezone"] or "UTC").strip() or "UTC"
            try:
                resolve_timezone(tz_name)
            except ScheduleValidationError as exc:
                return _err(str(exc))
            fields_in["timezone"] = tz_name
        if "tool_allowlist" in data:
            fields_in["tool_allowlist"] = data["tool_allowlist"] or []
        if "name" in data:
            fields_in["name"] = (data["name"] or "").strip() or None
        if "model_id" in data:
            fields_in["model_id"] = (data["model_id"] or None)
        if "token_budget" in data:
            tb = data["token_budget"]
            if tb is not None:
                try:
                    tb = int(tb)
                    if tb < 0:
                        raise ValueError
                except (TypeError, ValueError):
                    return _err("token_budget must be a non-negative integer")
            fields_in["token_budget"] = tb
        if "end_at" in data:
            if data["end_at"]:
                try:
                    parsed_end = datetime.fromisoformat(
                        str(data["end_at"]).replace("Z", "+00:00"),
                    )
                except ValueError:
                    return _err("invalid end_at")
                # Store an unambiguous instant: offset-less input is read as
                # UTC instead of being persisted as a naive datetime.
                fields_in["end_at"] = _assume_utc_if_naive(parsed_end)
            else:
                fields_in["end_at"] = None
        # Recompute next_run_at when cron/tz changes.
        with db_session() as conn:
            existing = SchedulesRepository(conn).get(schedule_id, user_id)
            if existing is None:
                return _err("schedule not found", 404)
            if (
                ("cron" in fields_in or "timezone" in fields_in)
                and existing.get("trigger_type") == "recurring"
            ):
                cron_eff = fields_in.get("cron") or existing.get("cron")
                tz_eff = fields_in.get("timezone") or existing.get("timezone")
                if cron_eff:
                    min_interval = max(0, int(settings.SCHEDULE_MIN_INTERVAL))
                    if min_interval > 0:
                        try:
                            cadence = cron_interval_seconds(cron_eff, tz_eff)
                        except ScheduleValidationError as exc:
                            return _err(str(exc))
                        if cadence < min_interval:
                            return _err(
                                "cadence below minimum interval "
                                f"({cadence}s < {min_interval}s)",
                            )
                    try:
                        fields_in["next_run_at"] = next_cron_run(
                            cron_eff, tz_eff, after=datetime.now(timezone.utc),
                        )
                    except ScheduleValidationError as exc:
                        return _err(str(exc))
            updated = SchedulesRepository(conn).update(
                schedule_id, user_id, fields_in,
            )
        return _ok({"schedule": _format_schedule(updated or {})})

    @api.doc(description="Pause / resume a schedule.")
    def patch(self, schedule_id):
        # ``action`` is "pause" or "resume". Cancelled/completed schedules are
        # rejected with 409. Resuming also resets the failure counter.
        user_id = _user_id()
        if not user_id:
            return _err("unauthorized", 401)
        if not looks_like_uuid(schedule_id):
            return _err("invalid schedule id", 400)
        data = request.get_json(silent=True) or {}
        if not isinstance(data, dict):
            return _err("request body must be a JSON object")
        action = (data.get("action") or "").lower().strip()
        if action not in {"pause", "resume"}:
            return _err("action must be 'pause' or 'resume'")
        with db_session() as conn:
            existing = SchedulesRepository(conn).get(schedule_id, user_id)
            if existing is None:
                return _err("schedule not found", 404)
            if existing.get("status") in ("cancelled", "completed"):
                return _err("schedule is terminal", 409)
            if action == "pause":
                fields_in: Dict[str, Any] = {"status": "paused", "next_run_at": None}
            else:
                # Resume: recurring recomputes from now; once honours run_at if still future.
                fields_in = {"status": "active"}
                now_utc = datetime.now(timezone.utc)
                if existing.get("trigger_type") == "recurring":
                    try:
                        fields_in["next_run_at"] = next_cron_run(
                            existing["cron"],
                            existing["timezone"],
                            after=now_utc,
                        )
                    except ScheduleValidationError as exc:
                        return _err(str(exc))
                else:
                    new_run_at = data.get("run_at")
                    if new_run_at:
                        try:
                            run_at_dt = datetime.fromisoformat(
                                str(new_run_at).replace("Z", "+00:00"),
                            )
                        except ValueError:
                            return _err("invalid run_at")
                        # Offset-less input parses as naive; comparing naive
                        # with aware raises TypeError, so it is read as UTC.
                        # NOTE(review): creation goes through ``parse_run_at``
                        # with the schedule timezone — confirm UTC is the
                        # intended reading here.
                        run_at_dt = _assume_utc_if_naive(run_at_dt)
                        if run_at_dt <= now_utc:
                            return _err(
                                "run_at must be in the future to resume", 409,
                            )
                        fields_in["next_run_at"] = run_at_dt
                        fields_in["run_at"] = run_at_dt
                    else:
                        run_at = existing.get("run_at")
                        if run_at:
                            if isinstance(run_at, str):
                                try:
                                    run_at_dt = datetime.fromisoformat(
                                        run_at.replace("Z", "+00:00"),
                                    )
                                except ValueError:
                                    return _err("schedule run_at is invalid")
                            else:
                                run_at_dt = run_at
                            # Same naive/aware guard for the stored value.
                            run_at_dt = _assume_utc_if_naive(run_at_dt)
                            if run_at_dt <= now_utc:
                                return _err(
                                    "the once schedule has elapsed; recreate "
                                    "it or supply a new run_at",
                                    409,
                                )
                            fields_in["next_run_at"] = run_at_dt
            updated = SchedulesRepository(conn).update(
                schedule_id, user_id, fields_in,
            )
            if action == "resume":
                SchedulesRepository(conn).reset_failure_count(schedule_id)
        return _ok({"schedule": _format_schedule(updated or {})})

    @api.doc(description="Cancel / delete a schedule.")
    def delete(self, schedule_id):
        user_id = _user_id()
        if not user_id:
            return _err("unauthorized", 401)
        if not looks_like_uuid(schedule_id):
            return _err("invalid schedule id", 400)
        with db_session() as conn:
            ok = SchedulesRepository(conn).delete(schedule_id, user_id)
        if not ok:
            return _err("schedule not found", 404)
        return _ok({"success": True})
|
||||
|
||||
|
||||
@schedules_ns.route("/schedules/<string:schedule_id>/run")
class ScheduleRunNow(Resource):
    """Manual "run now" trigger for a single schedule."""

    @api.doc(description="Run a schedule immediately (trigger_source='manual').")
    def post(self, schedule_id):
        """Record a manual pending run and enqueue it for execution.

        Error responses: 401 without a user id, 400 for a malformed id,
        404 for an unknown schedule, 409 when the schedule is cancelled,
        already has an active run, or the pending row could not be
        claimed, and 500 when the task could not be enqueued. On success
        the formatted run is returned with status 202.
        """
        caller = _user_id()
        if not caller:
            return _err("unauthorized", 401)
        if not looks_like_uuid(schedule_id):
            return _err("invalid schedule id", 400)

        # FOR UPDATE serializes concurrent Run-Now POSTs (timestamp-unique
        # scheduled_for values would otherwise sneak past the unique index).
        with db_session() as conn:
            schedule = SchedulesRepository(conn).get_for_update(
                schedule_id, caller,
            )
            if schedule is None:
                return _err("schedule not found", 404)
            if schedule.get("status") == "cancelled":
                return _err("schedule is cancelled", 409)

            runs_repo = ScheduleRunsRepository(conn)
            if runs_repo.has_active_run(schedule_id):
                return _err("a run is already in flight", 409)

            fired_at = datetime.now(timezone.utc)
            agent_ref = schedule.get("agent_id")
            run = runs_repo.record_pending(
                schedule_id,
                caller,
                str(agent_ref) if agent_ref else None,
                fired_at,
                trigger_source="manual",
            )
            if run is None:
                return _err("could not claim run (concurrent dispatch)", 409)

        # NOTE(review): the enqueue is placed after the session block so the
        # pending row is committed before a worker can pick the task up —
        # confirm this matches the intended transaction scope.
        # Import inside the handler to avoid a circular tasks <-> routes import.
        try:
            from application.api.user.tasks import execute_scheduled_run

            execute_scheduled_run.apply_async(
                args=[str(run["id"])], queue="docsgpt",
            )
        except Exception as exc:
            current_app.logger.error(
                "run-now enqueue failed: %s", exc, exc_info=True,
            )
            return _err("enqueue failed", 500)

        return _ok({"run": _format_run(run)}, status=202)
|
||||
|
||||
|
||||
@schedules_ns.route("/schedules/<string:schedule_id>/runs")
class ScheduleRunList(Resource):
    """Paginated run history for one schedule."""

    @api.doc(
        description="Paginated run log for a schedule.",
        params={"limit": "Page size (default 50)", "offset": "Page offset"},
    )
    def get(self, schedule_id):
        """Return one page of runs for the caller's schedule.

        ``limit`` is clamped to the range 1..200 and ``offset`` to >= 0;
        values that cannot be parsed as integers fall back to 50 and 0.
        Responds 401 without a user id, 400 for a malformed schedule id
        and 404 when the schedule is not found for this user.
        """
        caller = _user_id()
        if not caller:
            return _err("unauthorized", 401)
        if not looks_like_uuid(schedule_id):
            return _err("invalid schedule id", 400)

        try:
            requested_limit = int(request.args.get("limit", 50))
        except (TypeError, ValueError):
            limit = 50
        else:
            limit = max(1, min(requested_limit, 200))

        try:
            requested_offset = int(request.args.get("offset", 0))
        except (TypeError, ValueError):
            offset = 0
        else:
            offset = max(0, requested_offset)

        with db_readonly() as conn:
            # Ownership check first: an unknown schedule is a 404, not an
            # empty page.
            if SchedulesRepository(conn).get(schedule_id, caller) is None:
                return _err("schedule not found", 404)
            rows = ScheduleRunsRepository(conn).list_runs(
                schedule_id, caller, limit=limit, offset=offset,
            )
            page = {
                "runs": [_format_run(row) for row in rows],
                "limit": limit,
                "offset": offset,
            }
            return _ok(page)
|
||||
|
||||
|
||||
@schedules_ns.route("/schedules/<string:schedule_id>/runs/<string:run_id>")
class ScheduleRunDetail(Resource):
    """Detail view of a single run belonging to a schedule."""

    @api.doc(description="Full output / error for a single run.")
    def get(self, schedule_id, run_id):
        """Return the formatted run ``run_id`` of schedule ``schedule_id``.

        Responds 401 without a user id, 400 when either id does not look
        like a UUID, and 404 when the schedule is not found, the run is
        not found, or the run belongs to a different schedule.
        """
        caller = _user_id()
        if not caller:
            return _err("unauthorized", 401)
        if not (looks_like_uuid(schedule_id) and looks_like_uuid(run_id)):
            return _err("invalid id", 400)

        with db_readonly() as conn:
            schedule = SchedulesRepository(conn).get(schedule_id, caller)
            if schedule is None:
                return _err("schedule not found", 404)

            run = ScheduleRunsRepository(conn).get(run_id, caller)
            if run is None:
                return _err("run not found", 404)
            # A run that exists under another schedule is reported as
            # missing for this URL.
            if str(run.get("schedule_id")) != str(schedule["id"]):
                return _err("run not found", 404)

            return _ok({"run": _format_run(run)})
|
||||
@@ -7,8 +7,12 @@ from flask import current_app, jsonify, make_response, redirect, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.tasks import sync_source
|
||||
from application.api.user.tasks import reingest_source_task, sync_source
|
||||
from application.core.settings import settings
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
from application.storage.db.repositories.ingest_chunk_progress import (
|
||||
IngestChunkProgressRepository,
|
||||
)
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
@@ -139,6 +143,8 @@ class PaginatedSources(Resource):
|
||||
"provider": provider,
|
||||
"isNested": bool(doc.get("directory_structure")),
|
||||
"type": doc.get("type", "file"),
|
||||
# Derived in SourcesRepository.list_for_user.
|
||||
"ingestStatus": doc.get("ingest_status"),
|
||||
}
|
||||
)
|
||||
response = {
|
||||
@@ -322,7 +328,7 @@ class SyncSource(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
source_data = doc.get("remote_data")
|
||||
source_data = normalize_remote_data(source_type, doc.get("remote_data"))
|
||||
if not source_data:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Source is not syncable"}), 400
|
||||
@@ -346,6 +352,70 @@ class SyncSource(Resource):
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/sources/reingest")
class ReingestSource(Resource):
    """Endpoint that re-runs ingestion for an existing source."""

    reingest_source_model = api.model(
        "ReingestSourceModel",
        {"source_id": fields.String(required=True, description="Source ID")},
    )

    @api.expect(reingest_source_model)
    @api.doc(
        description="Re-run ingestion for a source — e.g. to recover a "
        "stalled embed flagged by the reconciler."
    )
    def post(self):
        """Look up the caller's source, clear its progress row, enqueue reingest.

        Responds 401 without a decoded token, with the response from
        ``check_required_fields`` when ``source_id`` is missing, 400 when
        the lookup raises, 404 when no source is found, 400 when the task
        cannot be enqueued, and 200 with the task id on success.
        """
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)
        user = token.get("sub")

        body = request.get_json() or {}
        missing = check_required_fields(body, ["source_id"])
        if missing:
            return missing
        source_id = body["source_id"]

        try:
            with db_readonly() as conn:
                doc = SourcesRepository(conn).get_any(source_id, user)
        except Exception as err:
            current_app.logger.error(
                f"Error looking up source: {err}", exc_info=True
            )
            return make_response(
                jsonify({"success": False, "message": "Invalid source ID"}), 400
            )
        if not doc:
            return make_response(
                jsonify({"success": False, "message": "Source not found"}), 404
            )

        resolved_source_id = str(doc["id"])

        # Drop the stale chunk-progress row so the sources list stops
        # deriving a 'failed' status; reingest never rewrites it itself.
        # Best effort: a failure here is logged and does not block the
        # reingest below.
        try:
            with db_session() as conn:
                IngestChunkProgressRepository(conn).delete(resolved_source_id)
        except Exception as err:
            current_app.logger.warning(
                f"Could not clear ingest progress for {resolved_source_id}: "
                f"{err}",
                exc_info=True,
            )

        # Scoped key so repeated clicks collapse onto one reingest.
        dedup_key = f"reingest-source:{user}:{resolved_source_id}"
        try:
            task = reingest_source_task.delay(
                source_id=resolved_source_id,
                user=user,
                idempotency_key=dedup_key,
            )
        except Exception as err:
            current_app.logger.error(
                f"Error starting reingest for source {source_id}: {err}",
                exc_info=True,
            )
            return make_response(jsonify({"success": False}), 400)

        return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/directory_structure")
|
||||
class DirectoryStructure(Resource):
|
||||
@api.doc(
|
||||
|
||||
@@ -3,16 +3,20 @@
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
import zipfile
|
||||
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.source_ids import derive_source_id as _derive_source_id
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
|
||||
from application.storage.db.repositories.idempotency import IdempotencyRepository
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
@@ -30,6 +34,91 @@ sources_upload_ns = Namespace(
|
||||
)
|
||||
|
||||
|
||||
_IDEMPOTENCY_KEY_MAX_LEN = 256
|
||||
|
||||
|
||||
def _read_idempotency_key():
    """Read the ``Idempotency-Key`` request header.

    Returns a ``(key, error_response)`` pair:

    * header absent or empty -> ``(None, None)``
    * header longer than ``_IDEMPOTENCY_KEY_MAX_LEN`` -> ``(None, <400 response>)``
    * otherwise -> ``(key, None)``
    """
    header_value = request.headers.get("Idempotency-Key")
    if not header_value:
        return None, None
    if len(header_value) <= _IDEMPOTENCY_KEY_MAX_LEN:
        return header_value, None

    message = (
        f"Idempotency-Key exceeds maximum length of "
        f"{_IDEMPOTENCY_KEY_MAX_LEN} characters"
    )
    too_long = make_response(
        jsonify({"success": False, "message": message}),
        400,
    )
    return None, too_long
|
||||
|
||||
|
||||
def _scoped_idempotency_key(idempotency_key, scope):
|
||||
"""``{scope}:{key}`` so different users can't collide on the same key."""
|
||||
if not idempotency_key or not scope:
|
||||
return None
|
||||
return f"{scope}:{idempotency_key}"
|
||||
|
||||
|
||||
def _claim_task_or_get_cached(key, task_name):
    """Claim ``key`` for this request, or return the existing claim's payload.

    A task id is generated up front and offered to
    ``IdempotencyRepository.claim_task``. Returns
    ``(task_id, cached_response)``:

    * claim succeeded -> ``(claimed["task_id"], None)``; the caller goes
      on to enqueue.
    * claim failed -> ``(None, payload)``; the caller should return
      ``payload`` without enqueuing. ``payload`` carries the stored task id
      and a ``source_id`` derived from ``key``. If no stored row can be
      read back, ``task_id`` is the sentinel ``"deduplicated"`` and
      ``source_id`` is deliberately omitted so the frontend doesn't bind to
      a phantom source.
    """
    candidate_task_id = str(uuid.uuid4())
    with db_session() as conn:
        claim = IdempotencyRepository(conn).claim_task(
            key=key, task_name=task_name, task_id=candidate_task_id,
        )
    if claim is not None:
        return claim["task_id"], None

    # Another request holds the key: read back what it stored.
    with db_readonly() as conn:
        winner = IdempotencyRepository(conn).get_task(key)
    winner_task_id = winner.get("task_id") if winner else None

    response: dict = {
        "success": True,
        "task_id": winner_task_id or "deduplicated",
    }
    # Only surface ``source_id`` when there's a real winner whose worker
    # is publishing SSE events tagged with that id. The "deduplicated"
    # branch means the lock row vanished — we have nothing to correlate.
    if winner_task_id is not None:
        response["source_id"] = str(_derive_source_id(key))
    return None, response
|
||||
|
||||
|
||||
def _release_claim(key):
    """Delete the pending ``task_dedup`` row for ``key`` so a retry can re-claim.

    Only rows whose status is ``'pending'`` are removed. Best effort: any
    exception is logged and not propagated to the caller.
    """
    try:
        with db_session() as conn:
            delete_pending = sql_text(
                "DELETE FROM task_dedup WHERE idempotency_key = :k "
                "AND status = 'pending'"
            )
            conn.execute(delete_pending, {"k": key})
    except Exception:
        current_app.logger.exception(
            "Failed to release task_dedup claim for key=%s", key,
        )
|
||||
|
||||
|
||||
|
||||
|
||||
def _enforce_audio_path_size_limit(file_path: str, filename: str) -> None:
|
||||
if not is_audio_filename(filename):
|
||||
return
|
||||
@@ -49,17 +138,38 @@ class UploadFile(Resource):
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads a file to be vectorized and indexed",
|
||||
description=(
|
||||
"Uploads a file to be vectorized and indexed. Honors an optional "
|
||||
"``Idempotency-Key`` header: a repeat request with the same key "
|
||||
"within 24h returns the original cached response without re-enqueuing."
|
||||
),
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
idempotency_key, key_error = _read_idempotency_key()
|
||||
if key_error is not None:
|
||||
return key_error
|
||||
# User-scoped to avoid cross-user collisions; also feeds
|
||||
# ``_derive_source_id`` so uuid5 stays user-disjoint.
|
||||
scoped_key = _scoped_idempotency_key(idempotency_key, user)
|
||||
# Claim before enqueue; the loser returns the winner's task_id.
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id, cached = _claim_task_or_get_cached(
|
||||
scoped_key, "ingest",
|
||||
)
|
||||
if cached is not None:
|
||||
return make_response(jsonify(cached), 200)
|
||||
data = request.form
|
||||
files = request.files.getlist("file")
|
||||
required_fields = ["user", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields or not files or all(file.filename == "" for file in files):
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -69,7 +179,6 @@ class UploadFile(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
user = decoded_token.get("sub")
|
||||
job_name = request.form["name"]
|
||||
|
||||
# Create safe versions for filesystem operations
|
||||
@@ -140,16 +249,37 @@ class UploadFile(Resource):
|
||||
file_path = f"{base_path}/{safe_file}"
|
||||
with open(temp_file_path, "rb") as f:
|
||||
storage.save_file(f, file_path)
|
||||
task = ingest.delay(
|
||||
settings.UPLOAD_FOLDER,
|
||||
list(SUPPORTED_SOURCE_EXTENSIONS),
|
||||
job_name,
|
||||
user,
|
||||
file_path=base_path,
|
||||
filename=dir_name,
|
||||
file_name_map=file_name_map,
|
||||
# Mint the source UUID up here so the HTTP response and the
|
||||
# worker's SSE envelopes share one id. With an idempotency
|
||||
# key we reuse the deterministic uuid5 (retried task lands on
|
||||
# the same source row); without a key we fall back to uuid4.
|
||||
# The worker is told to use this id verbatim — see
|
||||
# ``ingest_worker(source_id=...)``.
|
||||
source_uuid = (
|
||||
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
|
||||
)
|
||||
ingest_kwargs = dict(
|
||||
args=(
|
||||
settings.UPLOAD_FOLDER,
|
||||
list(SUPPORTED_SOURCE_EXTENSIONS),
|
||||
job_name,
|
||||
user,
|
||||
),
|
||||
kwargs={
|
||||
"file_path": base_path,
|
||||
"filename": dir_name,
|
||||
"file_name_map": file_name_map,
|
||||
# Scoped so the worker dedup row matches the HTTP claim.
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
"source_id": str(source_uuid),
|
||||
},
|
||||
)
|
||||
if predetermined_task_id is not None:
|
||||
ingest_kwargs["task_id"] = predetermined_task_id
|
||||
task = ingest.apply_async(**ingest_kwargs)
|
||||
except AudioFileTooLargeError:
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -161,8 +291,21 @@ class UploadFile(Resource):
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
# Predetermined id matches the dedup-claim row; loser GET sees same.
|
||||
response_task_id = predetermined_task_id or task.id
|
||||
# ``source_uuid`` was minted above and passed to the worker as
|
||||
# ``source_id``; the worker uses it verbatim for every SSE event,
|
||||
# so the frontend can correlate inbound ``source.ingest.*`` to
|
||||
# this upload regardless of whether an idempotency key was set.
|
||||
response_payload: dict = {
|
||||
"success": True,
|
||||
"task_id": response_task_id,
|
||||
"source_id": str(source_uuid),
|
||||
}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/remote")
|
||||
@@ -182,17 +325,50 @@ class UploadRemote(Resource):
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads remote source for vectorization",
|
||||
description=(
|
||||
"Uploads remote source for vectorization. Honors an optional "
|
||||
"``Idempotency-Key`` header: a repeat request with the same key "
|
||||
"within 24h returns the original cached response without re-enqueuing."
|
||||
),
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
idempotency_key, key_error = _read_idempotency_key()
|
||||
if key_error is not None:
|
||||
return key_error
|
||||
scoped_key = _scoped_idempotency_key(idempotency_key, user)
|
||||
data = request.form
|
||||
required_fields = ["user", "source", "name", "data"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
task_name_for_dedup = (
|
||||
"ingest_connector_task"
|
||||
if data.get("source") in ConnectorCreator.get_supported_connectors()
|
||||
else "ingest_remote"
|
||||
)
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id, cached = _claim_task_or_get_cached(
|
||||
scoped_key, task_name_for_dedup,
|
||||
)
|
||||
if cached is not None:
|
||||
return make_response(jsonify(cached), 200)
|
||||
# Mint the source UUID up here so the HTTP response and the
|
||||
# worker's SSE envelopes share one id. Same pattern as
|
||||
# ``UploadFile.post``: with an idempotency key we reuse the
|
||||
# deterministic uuid5 (retried task lands on the same source
|
||||
# row); without a key we fall back to uuid4. The worker is told
|
||||
# to use this id verbatim — see ``remote_worker`` and
|
||||
# ``ingest_connector``. Without this the no-key path would mint
|
||||
# a random uuid4 inside the worker that the frontend has no way
|
||||
# to correlate SSE events to.
|
||||
source_uuid = (
|
||||
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
|
||||
)
|
||||
try:
|
||||
config = json.loads(data["data"])
|
||||
source_data = None
|
||||
@@ -208,6 +384,8 @@ class UploadRemote(Resource):
|
||||
elif data["source"] in ConnectorCreator.get_supported_connectors():
|
||||
session_token = config.get("session_token")
|
||||
if not session_token:
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -236,31 +414,62 @@ class UploadRemote(Resource):
|
||||
config["file_ids"] = file_ids
|
||||
config["folder_ids"] = folder_ids
|
||||
|
||||
task = ingest_connector_task.delay(
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
source_type=data["source"],
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=config.get("recursive", False),
|
||||
retriever=config.get("retriever", "classic"),
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": task.id}), 200
|
||||
)
|
||||
task = ingest_remote.delay(
|
||||
source_data=source_data,
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
loader=data["source"],
|
||||
)
|
||||
connector_kwargs = {
|
||||
"kwargs": {
|
||||
"job_name": data["name"],
|
||||
"user": user,
|
||||
"source_type": data["source"],
|
||||
"session_token": session_token,
|
||||
"file_ids": file_ids,
|
||||
"folder_ids": folder_ids,
|
||||
"recursive": config.get("recursive", False),
|
||||
"retriever": config.get("retriever", "classic"),
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
"source_id": str(source_uuid),
|
||||
},
|
||||
}
|
||||
if predetermined_task_id is not None:
|
||||
connector_kwargs["task_id"] = predetermined_task_id
|
||||
task = ingest_connector_task.apply_async(**connector_kwargs)
|
||||
response_task_id = predetermined_task_id or task.id
|
||||
# ``source_uuid`` was minted above and passed to the
|
||||
# worker as ``source_id``; the worker uses it verbatim
|
||||
# for every SSE event, so the frontend can correlate
|
||||
# inbound ``source.ingest.*`` regardless of whether an
|
||||
# idempotency key was set.
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"task_id": response_task_id,
|
||||
"source_id": str(source_uuid),
|
||||
}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
remote_kwargs = {
|
||||
"kwargs": {
|
||||
"source_data": source_data,
|
||||
"job_name": data["name"],
|
||||
"user": user,
|
||||
"loader": data["source"],
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
"source_id": str(source_uuid),
|
||||
},
|
||||
}
|
||||
if predetermined_task_id is not None:
|
||||
remote_kwargs["task_id"] = predetermined_task_id
|
||||
task = ingest_remote.apply_async(**remote_kwargs)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error uploading remote source: {err}", exc_info=True
|
||||
)
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
response_task_id = predetermined_task_id or task.id
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"task_id": response_task_id,
|
||||
"source_id": str(source_uuid),
|
||||
}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/manage_source_files")
|
||||
@@ -305,6 +514,10 @@ class ManageSourceFiles(Resource):
|
||||
jsonify({"success": False, "message": "Unauthorized"}), 401
|
||||
)
|
||||
user = decoded_token.get("sub")
|
||||
idempotency_key, key_error = _read_idempotency_key()
|
||||
if key_error is not None:
|
||||
return key_error
|
||||
scoped_key = _scoped_idempotency_key(idempotency_key, user)
|
||||
source_id = request.form.get("source_id")
|
||||
operation = request.form.get("operation")
|
||||
|
||||
@@ -347,6 +560,12 @@ class ManageSourceFiles(Resource):
|
||||
jsonify({"success": False, "message": "Database error"}), 500
|
||||
)
|
||||
resolved_source_id = str(source["id"])
|
||||
# Flips to True after each branch's ``apply_async`` returns
|
||||
# successfully — at that point the worker owns the predetermined
|
||||
# task_id. The outer ``except`` only releases the claim while
|
||||
# this is False, so a post-``apply_async`` failure (jsonify,
|
||||
# make_response, etc.) doesn't double-enqueue on the next retry.
|
||||
claim_transferred = False
|
||||
try:
|
||||
storage = StorageCreator.get_storage()
|
||||
source_file_path = source.get("file_path", "")
|
||||
@@ -379,6 +598,34 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
# Claim before any storage mutation so a duplicate request
|
||||
# short-circuits without touching the filesystem. Mirrors
|
||||
# the pattern in ``UploadFile.post`` / ``UploadRemote.post``
|
||||
# — without it ``.delay()`` would enqueue twice for two
|
||||
# racing same-key POSTs (the worker decorator only
|
||||
# deduplicates *after* completion).
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id, cached = _claim_task_or_get_cached(
|
||||
scoped_key, "reingest_source_task",
|
||||
)
|
||||
if cached is not None:
|
||||
# Frontend keys reingest polling on
|
||||
# ``reingest_task_id``; the shared cache helper
|
||||
# writes ``task_id``. Alias here so a dedup
|
||||
# response doesn't silently break FileTree's
|
||||
# poller. Override ``source_id`` too — the
|
||||
# helper derives it from the scoped key, which
|
||||
# is correct for upload but wrong for reingest
|
||||
# (the worker publishes events scoped to the
|
||||
# actual source row id).
|
||||
cached_task_id = cached.pop("task_id", None)
|
||||
if cached_task_id is not None:
|
||||
cached["reingest_task_id"] = cached_task_id
|
||||
cached["source_id"] = resolved_source_id
|
||||
return make_response(jsonify(cached), 200)
|
||||
|
||||
added_files = []
|
||||
map_updated = False
|
||||
|
||||
@@ -414,9 +661,15 @@ class ManageSourceFiles(Resource):
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id, user=user
|
||||
task = reingest_source_task.apply_async(
|
||||
kwargs={
|
||||
"source_id": resolved_source_id,
|
||||
"user": user,
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
},
|
||||
task_id=predetermined_task_id,
|
||||
)
|
||||
claim_transferred = True
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -426,6 +679,12 @@ class ManageSourceFiles(Resource):
|
||||
"added_files": added_files,
|
||||
"parent_dir": parent_dir,
|
||||
"reingest_task_id": task.id,
|
||||
# ``source_id`` lets the frontend correlate
|
||||
# inbound ``source.ingest.*`` SSE events
|
||||
# (emitted by ``reingest_source_worker``)
|
||||
# back to the reingest task — matches the
|
||||
# upload route's source-id contract.
|
||||
"source_id": resolved_source_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
@@ -455,10 +714,8 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Remove files from storage and directory structure
|
||||
|
||||
removed_files = []
|
||||
map_updated = False
|
||||
# Path-traversal guard runs *before* the claim so a 400
|
||||
# for an invalid path doesn't leave a pending dedup row.
|
||||
for file_path in file_paths:
|
||||
if ".." in str(file_path) or str(file_path).startswith("/"):
|
||||
return make_response(
|
||||
@@ -470,6 +727,31 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
# Claim before any storage mutation. See ``add`` branch
|
||||
# comment for rationale.
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id, cached = _claim_task_or_get_cached(
|
||||
scoped_key, "reingest_source_task",
|
||||
)
|
||||
if cached is not None:
|
||||
cached_task_id = cached.pop("task_id", None)
|
||||
if cached_task_id is not None:
|
||||
cached["reingest_task_id"] = cached_task_id
|
||||
# Override the helper's synthetic source_id (uuid5
|
||||
# of the scoped key) with the real source row id
|
||||
# — the reingest worker publishes SSE events
|
||||
# scoped to ``resolved_source_id`` and FileTree
|
||||
# correlates on it.
|
||||
cached["source_id"] = resolved_source_id
|
||||
return make_response(jsonify(cached), 200)
|
||||
|
||||
# Remove files from storage and directory structure
|
||||
|
||||
removed_files = []
|
||||
map_updated = False
|
||||
for file_path in file_paths:
|
||||
full_path = f"{source_file_path}/{file_path}"
|
||||
|
||||
# Remove from storage
|
||||
@@ -491,9 +773,15 @@ class ManageSourceFiles(Resource):
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id, user=user
|
||||
task = reingest_source_task.apply_async(
|
||||
kwargs={
|
||||
"source_id": resolved_source_id,
|
||||
"user": user,
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
},
|
||||
task_id=predetermined_task_id,
|
||||
)
|
||||
claim_transferred = True
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -502,6 +790,7 @@ class ManageSourceFiles(Resource):
|
||||
"message": f"Removed {len(removed_files)} files",
|
||||
"removed_files": removed_files,
|
||||
"reingest_task_id": task.id,
|
||||
"source_id": resolved_source_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
@@ -552,6 +841,24 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
404,
|
||||
)
|
||||
|
||||
# Claim before mutation. See ``add`` branch for rationale.
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id, cached = _claim_task_or_get_cached(
|
||||
scoped_key, "reingest_source_task",
|
||||
)
|
||||
if cached is not None:
|
||||
cached_task_id = cached.pop("task_id", None)
|
||||
if cached_task_id is not None:
|
||||
cached["reingest_task_id"] = cached_task_id
|
||||
# Same source_id override as the ``remove`` /
|
||||
# ``add`` cached branches — the helper's synthetic
|
||||
# id doesn't match what reingest_source_worker
|
||||
# tags its SSE events with.
|
||||
cached["source_id"] = resolved_source_id
|
||||
return make_response(jsonify(cached), 200)
|
||||
|
||||
success = storage.remove_directory(full_directory_path)
|
||||
|
||||
if not success:
|
||||
@@ -560,6 +867,11 @@ class ManageSourceFiles(Resource):
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||
f"Full path: {full_directory_path}"
|
||||
)
|
||||
# Release so a client retry can reclaim — otherwise
|
||||
# the next request would silently 200-cache to the
|
||||
# task_id that never enqueued.
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Failed to remove directory"}
|
||||
@@ -591,9 +903,15 @@ class ManageSourceFiles(Resource):
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id, user=user
|
||||
task = reingest_source_task.apply_async(
|
||||
kwargs={
|
||||
"source_id": resolved_source_id,
|
||||
"user": user,
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
},
|
||||
task_id=predetermined_task_id,
|
||||
)
|
||||
claim_transferred = True
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -602,11 +920,20 @@ class ManageSourceFiles(Resource):
|
||||
"message": f"Successfully removed directory: {directory_path}",
|
||||
"removed_directory": directory_path,
|
||||
"reingest_task_id": task.id,
|
||||
"source_id": resolved_source_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
# Release the dedup claim only if it wasn't transferred to
|
||||
# a worker. Without this, a same-key retry within the 24h
|
||||
# TTL would 200-cache to a predetermined task_id whose
|
||||
# ``apply_async`` never ran (or ran but the response builder
|
||||
# blew up afterward — only the first case matters in
|
||||
# practice; the flag protects both).
|
||||
if scoped_key and not claim_transferred:
|
||||
_release_claim(scoped_key)
|
||||
error_context = f"operation={operation}, user={user}, source_id={source_id}"
|
||||
if operation == "remove_directory":
|
||||
directory_path = request.form.get("directory_path", "")
|
||||
|
||||
@@ -1,21 +1,79 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from application.api.user.idempotency import with_idempotency
|
||||
from application.celery_init import celery
|
||||
from application.worker import (
|
||||
agent_webhook_worker,
|
||||
attachment_worker,
|
||||
ingest_worker,
|
||||
mcp_oauth,
|
||||
mcp_oauth_status,
|
||||
remote_worker,
|
||||
sync,
|
||||
sync_worker,
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
# Shared decorator options for long-running, side-effecting tasks.
# ``acks_late`` is also the celeryconfig default, but it is spelled out here
# so each task's durability story stays grep-able next to its body.
# ``autoretry_for=(Exception,)`` is paired with a bounded ``max_retries`` so a
# poison message can't loop forever.
DURABLE_TASK = {
    "bind": True,
    "acks_late": True,
    "autoretry_for": (Exception,),
    "retry_kwargs": {"max_retries": 3, "countdown": 60},
    "retry_backoff": True,
}
|
||||
|
||||
|
||||
# operation tag for the poison-path source.ingest.failed event, per task.
|
||||
_INGEST_POISON_OPERATION = {
|
||||
"ingest": "upload",
|
||||
"ingest_remote": "upload",
|
||||
"ingest_connector_task": "upload",
|
||||
"reingest_source_task": "reingest",
|
||||
}
|
||||
|
||||
|
||||
def _emit_ingest_poison_event(task_name, bound):
|
||||
"""Publish a terminal ``source.ingest.failed`` when the poison-guard trips.
|
||||
|
||||
The guard returns before the worker runs, so the worker's own failed
|
||||
event never fires — without this the upload toast spins on "training".
|
||||
"""
|
||||
user = bound.get("user")
|
||||
source_id = bound.get("source_id")
|
||||
if not user or not source_id:
|
||||
return
|
||||
from application.events.publisher import publish_user_event
|
||||
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.failed",
|
||||
{
|
||||
"source_id": str(source_id),
|
||||
"filename": bound.get("filename") or "",
|
||||
"operation": _INGEST_POISON_OPERATION.get(task_name, "upload"),
|
||||
"error": "Ingestion stopped after repeated failures.",
|
||||
},
|
||||
scope={"kind": "source", "id": str(source_id)},
|
||||
)
|
||||
|
||||
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest", on_poison=_emit_ingest_poison_event)
|
||||
def ingest(
|
||||
self, directory, formats, job_name, user, file_path, filename, file_name_map=None
|
||||
self,
|
||||
directory,
|
||||
formats,
|
||||
job_name,
|
||||
user,
|
||||
file_path,
|
||||
filename,
|
||||
file_name_map=None,
|
||||
idempotency_key=None,
|
||||
source_id=None,
|
||||
):
|
||||
resp = ingest_worker(
|
||||
self,
|
||||
@@ -26,25 +84,42 @@ def ingest(
|
||||
filename,
|
||||
user,
|
||||
file_name_map=file_name_map,
|
||||
idempotency_key=idempotency_key,
|
||||
source_id=source_id,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def ingest_remote(self, source_data, job_name, user, loader):
|
||||
resp = remote_worker(self, source_data, job_name, user, loader)
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest_remote", on_poison=_emit_ingest_poison_event)
|
||||
def ingest_remote(
|
||||
self, source_data, job_name, user, loader,
|
||||
idempotency_key=None, source_id=None,
|
||||
):
|
||||
resp = remote_worker(
|
||||
self, source_data, job_name, user, loader,
|
||||
idempotency_key=idempotency_key,
|
||||
source_id=source_id,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def reingest_source_task(self, source_id, user):
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(
|
||||
task_name="reingest_source_task", on_poison=_emit_ingest_poison_event,
|
||||
)
|
||||
def reingest_source_task(self, source_id, user, idempotency_key=None):
|
||||
from application.worker import reingest_source_worker
|
||||
|
||||
resp = reingest_source_worker(self, source_id, user)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
# Beat-driven dispatch tasks default to ``acks_late=False``: a SIGKILL
|
||||
# of a beat tick is harmless to redeliver only if the dispatch itself is
|
||||
# idempotent. We keep these early-ACK so the broker doesn't replay a
|
||||
# dispatch that already enqueued downstream work.
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def schedule_syncs(self, frequency):
|
||||
resp = sync_worker(self, frequency)
|
||||
return resp
|
||||
@@ -74,19 +149,24 @@ def sync_source(
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def store_attachment(self, file_info, user):
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="store_attachment")
|
||||
def store_attachment(self, file_info, user, idempotency_key=None):
|
||||
resp = attachment_worker(self, file_info, user)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def process_agent_webhook(self, agent_id, payload):
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="process_agent_webhook")
|
||||
def process_agent_webhook(self, agent_id, payload, idempotency_key=None):
|
||||
resp = agent_webhook_worker(self, agent_id, payload)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(
|
||||
task_name="ingest_connector_task", on_poison=_emit_ingest_poison_event,
|
||||
)
|
||||
def ingest_connector_task(
|
||||
self,
|
||||
job_name,
|
||||
@@ -100,6 +180,8 @@ def ingest_connector_task(
|
||||
operation_mode="upload",
|
||||
doc_id=None,
|
||||
sync_frequency="never",
|
||||
idempotency_key=None,
|
||||
source_id=None,
|
||||
):
|
||||
from application.worker import ingest_connector
|
||||
|
||||
@@ -116,12 +198,70 @@ def ingest_connector_task(
|
||||
operation_mode=operation_mode,
|
||||
doc_id=doc_id,
|
||||
sync_frequency=sync_frequency,
|
||||
idempotency_key=idempotency_key,
|
||||
source_id=source_id,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
def dispatch_scheduled_runs(self):
    """Beat-driven scheduler poller; delegates to ``dispatch_due_runs``.

    Returns whatever ``scheduler_dispatcher.dispatch_due_runs`` returns.
    """
    # Imported inside the task body, matching the other lazy imports in
    # this module.
    from application.api.user.scheduler_dispatcher import dispatch_due_runs

    outcome = dispatch_due_runs()
    return outcome
|
||||
|
||||
|
||||
@celery.task(
    bind=True,
    acks_late=True,
    # Deliberately not DURABLE_TASK: agent runs have side effects, so a blind
    # retry would double them.
    autoretry_for=(),
    max_retries=0,
)
def execute_scheduled_run(self, run_id):
    """Execute one scheduled run; soft-time-limit honors SCHEDULE_RUN_TIMEOUT.

    Args:
        run_id: Identifier of the scheduled run, forwarded unchanged to
            ``execute_scheduled_run_body``.

    Returns:
        The result of ``execute_scheduled_run_body``.
    """
    from application.api.user.scheduler_worker import execute_scheduled_run_body

    # The current Celery task id is passed along when the request carries one.
    celery_task_id = getattr(self.request, "id", None)
    return execute_scheduled_run_body(run_id, celery_task_id)
|
||||
|
||||
|
||||
# Bind the runtime soft time limit so the prefork worker can raise mid-agent.
# This stays best-effort: a missing or malformed ``SCHEDULE_RUN_TIMEOUT`` must
# not break module import. The failure is logged rather than swallowed, so an
# operator can see that scheduled runs are executing without the intended
# limits.
try:
    from application.core.settings import settings as _scheduler_settings

    # Never go below a 30 second soft limit, whatever the setting says.
    execute_scheduled_run.soft_time_limit = max(
        30, int(_scheduler_settings.SCHEDULE_RUN_TIMEOUT),
    )
    # The hard limit trails the soft limit by 60 seconds.
    execute_scheduled_run.time_limit = (
        execute_scheduled_run.soft_time_limit + 60
    )
except Exception:
    import logging

    logging.getLogger(__name__).warning(
        "Could not apply SCHEDULE_RUN_TIMEOUT to execute_scheduled_run; "
        "task time limits may be left at their defaults.",
        exc_info=True,
    )
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
def cleanup_schedule_runs(self):
    """Trim ``schedule_runs`` per ``SCHEDULE_RUN_OUTPUT_RETENTION_DAYS``.

    Returns:
        ``{"deleted": <count>, "ttl_days": <days>}`` after a sweep, or a
        ``skipped`` marker with ``deleted`` set to 0 when ``POSTGRES_URI``
        is not configured.
    """
    from application.core.settings import settings

    if not settings.POSTGRES_URI:
        return {"deleted": 0, "skipped": "POSTGRES_URI not set"}

    # Database imports are deferred until the Postgres guard has passed.
    from application.storage.db.engine import get_engine
    from application.storage.db.repositories.schedule_runs import (
        ScheduleRunsRepository,
    )

    retention_days = settings.SCHEDULE_RUN_OUTPUT_RETENTION_DAYS
    with get_engine().begin() as conn:
        removed = ScheduleRunsRepository(conn).cleanup_older_than(retention_days)
    return {"deleted": removed, "ttl_days": retention_days}
|
||||
|
||||
|
||||
@celery.on_after_configure.connect
|
||||
def setup_periodic_tasks(sender, **kwargs):
|
||||
from application.core.settings import settings
|
||||
|
||||
sender.add_periodic_task(
|
||||
timedelta(days=1),
|
||||
schedule_syncs.s("daily"),
|
||||
@@ -140,11 +280,49 @@ def setup_periodic_tasks(sender, **kwargs):
|
||||
cleanup_pending_tool_state.s(),
|
||||
name="cleanup-pending-tool-state",
|
||||
)
|
||||
# Pure housekeeping for ``task_dedup`` / ``webhook_dedup`` — the
|
||||
# upsert paths already handle stale rows, so cadence only bounds
|
||||
# table size. Hourly is plenty for typical traffic.
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=1),
|
||||
cleanup_idempotency_dedup.s(),
|
||||
name="cleanup-idempotency-dedup",
|
||||
)
|
||||
sender.add_periodic_task(
|
||||
timedelta(seconds=30),
|
||||
reconciliation_task.s(),
|
||||
name="reconciliation",
|
||||
)
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=7),
|
||||
version_check_task.s(),
|
||||
name="version-check",
|
||||
)
|
||||
# Bound ``message_events`` growth — every streamed SSE chunk writes
|
||||
# one row, so retained chats accumulate hundreds of rows per
|
||||
# message. Reconnect-replay is only meaningful for streams the user
|
||||
# could plausibly still be waiting on, so 14 days is generous.
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=24),
|
||||
cleanup_message_events.s(),
|
||||
name="cleanup-message-events",
|
||||
)
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=24),
|
||||
cleanup_orphan_memories.s(),
|
||||
name="cleanup-orphan-memories",
|
||||
)
|
||||
# Scheduler dispatcher and run-log trim.
|
||||
sender.add_periodic_task(
|
||||
timedelta(seconds=max(15, settings.SCHEDULE_DISPATCHER_INTERVAL)),
|
||||
dispatch_scheduled_runs.s(),
|
||||
name="dispatch-scheduled-runs",
|
||||
)
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=24),
|
||||
cleanup_schedule_runs.s(),
|
||||
name="cleanup-schedule-runs",
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@@ -153,24 +331,12 @@ def mcp_oauth_task(self, config, user):
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def mcp_oauth_status_task(self, task_id):
|
||||
resp = mcp_oauth_status(self, task_id)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def cleanup_pending_tool_state(self):
|
||||
"""Delete pending_tool_state rows past their TTL.
|
||||
|
||||
Replaces Mongo's ``expireAfterSeconds=0`` TTL index — Postgres has
|
||||
no native TTL, so this task runs every 60 seconds to keep
|
||||
``pending_tool_state`` bounded. No-ops if ``POSTGRES_URI`` isn't
|
||||
configured (keeps the task runnable in Mongo-only environments).
|
||||
"""
|
||||
"""Revert stale ``resuming`` rows, then delete TTL-expired rows."""
|
||||
from application.core.settings import settings
|
||||
if not settings.POSTGRES_URI:
|
||||
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
|
||||
return {"deleted": 0, "reverted": 0, "skipped": "POSTGRES_URI not set"}
|
||||
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.pending_tool_state import (
|
||||
@@ -179,11 +345,96 @@ def cleanup_pending_tool_state(self):
|
||||
|
||||
engine = get_engine()
|
||||
with engine.begin() as conn:
|
||||
deleted = PendingToolStateRepository(conn).cleanup_expired()
|
||||
repo = PendingToolStateRepository(conn)
|
||||
reverted = repo.revert_stale_resuming(grace_seconds=600)
|
||||
deleted = repo.cleanup_expired()
|
||||
return {"deleted": deleted, "reverted": reverted}
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
def cleanup_idempotency_dedup(self):
    """Delete TTL-expired rows from ``task_dedup`` and ``webhook_dedup``.

    Pure housekeeping: the upsert paths already ignore stale rows
    (TTL-aware ``ON CONFLICT DO UPDATE``), so this only bounds table
    growth and keeps SELECT planning tight on large deployments.

    Returns:
        The result of ``IdempotencyRepository.cleanup_expired()``, or a
        zero-count dict with a ``skipped`` marker when ``POSTGRES_URI`` is
        not configured.
    """
    from application.core.settings import settings

    postgres_configured = bool(settings.POSTGRES_URI)
    if not postgres_configured:
        return {
            "task_dedup_deleted": 0,
            "webhook_dedup_deleted": 0,
            "skipped": "POSTGRES_URI not set",
        }

    # Database imports are deferred until the Postgres guard has passed.
    from application.storage.db.engine import get_engine
    from application.storage.db.repositories.idempotency import (
        IdempotencyRepository,
    )

    with get_engine().begin() as conn:
        summary = IdempotencyRepository(conn).cleanup_expired()
        return summary
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
def reconciliation_task(self):
    """Sweep stuck durability rows and escalate them to terminal status + alert.

    Thin task wrapper: the work lives in ``run_reconciliation`` and its
    result is returned as-is.
    """
    from application.api.user.reconciliation import run_reconciliation

    report = run_reconciliation()
    return report
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
def cleanup_message_events(self):
    """Delete ``message_events`` rows older than the retention window.

    Streamed answer responses write one journal row per SSE yield, so
    unbounded growth would dominate Postgres for any deployment that
    retains conversations. The reconnect-replay path only needs rows for
    in-flight streams. The window is read from
    ``settings.MESSAGE_EVENTS_RETENTION_DAYS``.

    Returns:
        ``{"deleted": <count>, "ttl_days": <days>}`` after a sweep, or a
        ``skipped`` marker with ``deleted`` set to 0 when ``POSTGRES_URI``
        is not configured.
    """
    from application.core.settings import settings

    if not settings.POSTGRES_URI:
        return {"deleted": 0, "skipped": "POSTGRES_URI not set"}

    # Database imports are deferred until the Postgres guard has passed.
    from application.storage.db.engine import get_engine
    from application.storage.db.repositories.message_events import (
        MessageEventsRepository,
    )

    retention_days = settings.MESSAGE_EVENTS_RETENTION_DAYS
    with get_engine().begin() as conn:
        removed = MessageEventsRepository(conn).cleanup_older_than(retention_days)
    return {"deleted": removed, "ttl_days": retention_days}
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
def cleanup_orphan_memories(self):
    """Sweep orphan memories left by the 0009 FK-to-trigger orphan window.

    A ``memories`` INSERT for a real ``tool_id`` racing a ``user_tools``
    DELETE leaves a permanent orphan that the dropped FK would have
    rejected. Default-tool synthetic ids are passed to the repository as
    ids to keep, since they are legitimate built-in data.

    Returns:
        ``{"deleted": <count>}`` after a sweep, or a ``skipped`` marker with
        ``deleted`` set to 0 when ``POSTGRES_URI`` is not configured.
    """
    from application.core.settings import settings

    if not settings.POSTGRES_URI:
        return {"deleted": 0, "skipped": "POSTGRES_URI not set"}

    # Remaining imports are deferred until the Postgres guard has passed.
    from application.agents.default_tools import default_tool_ids
    from application.storage.db.engine import get_engine
    from application.storage.db.repositories.memories import MemoriesRepository

    protected_ids = list(default_tool_ids().values())
    with get_engine().begin() as conn:
        removed = MemoriesRepository(conn).delete_orphans(protected_ids)
    return {"deleted": removed}
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def version_check_task(self):
|
||||
"""Periodic anonymous version check.
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Tool management MCP server integration."""
|
||||
|
||||
import json
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
@@ -226,7 +225,9 @@ class MCPServerSave(Resource):
|
||||
)
|
||||
redis_client = get_redis_instance()
|
||||
manager = MCPOAuthManager(redis_client)
|
||||
result = manager.get_oauth_status(config["oauth_task_id"])
|
||||
result = manager.get_oauth_status(
|
||||
config["oauth_task_id"], user
|
||||
)
|
||||
if not result.get("status") == "completed":
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -438,56 +439,6 @@ class MCPOAuthCallback(Resource):
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
|
||||
class MCPOAuthStatus(Resource):
|
||||
def get(self, task_id):
|
||||
try:
|
||||
redis_client = get_redis_instance()
|
||||
status_key = f"mcp_oauth_status:{task_id}"
|
||||
status_data = redis_client.get(status_key)
|
||||
|
||||
if status_data:
|
||||
status = json.loads(status_data)
|
||||
if "tools" in status and isinstance(status["tools"], list):
|
||||
status["tools"] = [
|
||||
{
|
||||
"name": t.get("name", "unknown"),
|
||||
"description": t.get("description", ""),
|
||||
}
|
||||
for t in status["tools"]
|
||||
]
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": task_id, **status})
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"task_id": task_id,
|
||||
"status": "pending",
|
||||
"message": "Waiting for OAuth to start...",
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error getting OAuth status for task {task_id}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Failed to get OAuth status",
|
||||
"task_id": task_id,
|
||||
}
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/auth_status")
|
||||
class MCPAuthStatus(Resource):
|
||||
@api.doc(
|
||||
|
||||
@@ -3,6 +3,15 @@
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.agents.default_tools import (
|
||||
builtin_agent_tools_for_management,
|
||||
BUILTIN_AGENT_TOOLS,
|
||||
default_tool_name_for_id,
|
||||
default_tools_for_management,
|
||||
is_builtin_agent_tool_id,
|
||||
is_default_tool_id,
|
||||
is_synthesized_tool_id,
|
||||
)
|
||||
from application.agents.tools.spec_parser import parse_spec
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api import api
|
||||
@@ -11,6 +20,7 @@ from application.security.encryption import decrypt_credentials, encrypt_credent
|
||||
from application.storage.db.repositories.notes import NotesRepository
|
||||
from application.storage.db.repositories.todos import TodosRepository
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import check_required_fields, validate_function_name
|
||||
|
||||
@@ -208,6 +218,7 @@ class GetTools(Resource):
|
||||
user = decoded_token.get("sub")
|
||||
with db_readonly() as conn:
|
||||
rows = UserToolsRepository(conn).list_for_user(user)
|
||||
user_doc = UsersRepository(conn).get(user)
|
||||
user_tools = []
|
||||
for row in rows:
|
||||
tool_copy = _row_to_api(row)
|
||||
@@ -227,6 +238,29 @@ class GetTools(Resource):
|
||||
tool_copy["config"].pop("encrypted_credentials", None)
|
||||
|
||||
user_tools.append(tool_copy)
|
||||
|
||||
# ``scheduler`` is dual-registered (default chat tool + agent-
|
||||
# selectable builtin) and resolves to the same synthetic uuid5 id.
|
||||
# Surface a single row with both flags so the frontend can show it
|
||||
# in the management page (toggle) and the agent picker.
|
||||
seen_ids: set = set()
|
||||
for default_row in default_tools_for_management(user_doc):
|
||||
default_copy = _row_to_api(default_row)
|
||||
default_copy["default"] = True
|
||||
if default_copy.get("name") in BUILTIN_AGENT_TOOLS:
|
||||
default_copy["builtin"] = True
|
||||
seen_ids.add(str(default_copy["id"]))
|
||||
user_tools.append(default_copy)
|
||||
# Builtins (e.g. scheduler) hidden from Add-Tool catalog, visible
|
||||
# to the agent picker. Skip ones already added via the default
|
||||
# path — both registries share ``_DEFAULT_TOOL_NAMESPACE``.
|
||||
for builtin_row in builtin_agent_tools_for_management():
|
||||
builtin_copy = _row_to_api(builtin_row)
|
||||
if str(builtin_copy["id"]) in seen_ids:
|
||||
continue
|
||||
builtin_copy["builtin"] = True
|
||||
builtin_copy["default"] = False
|
||||
user_tools.append(builtin_copy)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -367,6 +401,46 @@ class UpdateTool(Resource):
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
# Default-tool branch first: a dual-registered tool (e.g. ``scheduler``)
|
||||
# matches BOTH ``is_default_tool_id`` and ``is_builtin_agent_tool_id``.
|
||||
# The toggle in Tools settings is the per-user opt-out for the
|
||||
# agentless default — it must reach the ``set_default_tool_enabled``
|
||||
# path, not the builtin "not editable" reject.
|
||||
if is_default_tool_id(data["id"]):
|
||||
if "status" not in data:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Default tools are not editable; "
|
||||
"only their on/off status can be changed.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
tool_name = default_tool_name_for_id(data["id"])
|
||||
try:
|
||||
with db_session() as conn:
|
||||
UsersRepository(conn).set_default_tool_enabled(
|
||||
user, tool_name, bool(data["status"])
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating default tool: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
if is_builtin_agent_tool_id(data["id"]):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Built-in agent tools are not editable; "
|
||||
"add them to an agent via the agent picker.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
update_data: dict = {}
|
||||
for key in ("name", "displayName", "customName", "description", "actions"):
|
||||
@@ -471,6 +545,17 @@ class UpdateToolConfig(Resource):
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
if is_synthesized_tool_id(data["id"]):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Default and built-in tools are config-free "
|
||||
"and cannot be configured.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
@@ -550,6 +635,16 @@ class UpdateToolActions(Resource):
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
if is_synthesized_tool_id(data["id"]):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Default and built-in tools' actions are not editable.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
@@ -595,6 +690,27 @@ class UpdateToolStatus(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
# Default branch first so a dual-registered id (e.g. ``scheduler``)
|
||||
# writes the per-user opt-out instead of being rejected as a
|
||||
# not-editable builtin (both predicates match the same uuid5).
|
||||
if is_default_tool_id(data["id"]):
|
||||
tool_name = default_tool_name_for_id(data["id"])
|
||||
with db_session() as conn:
|
||||
UsersRepository(conn).set_default_tool_enabled(
|
||||
user, tool_name, bool(data["status"])
|
||||
)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
if is_builtin_agent_tool_id(data["id"]):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Built-in agent tools have no per-user "
|
||||
"toggle; add them to an agent via the agent picker.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
tool_doc = repo.get_any(data["id"], user)
|
||||
@@ -633,6 +749,16 @@ class DeleteTool(Resource):
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
if is_synthesized_tool_id(data["id"]):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Built-in tools cannot be deleted; disable them instead.",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
|
||||
@@ -9,6 +9,7 @@ import json
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from flask import Blueprint, jsonify, make_response, request, Response
|
||||
@@ -221,13 +222,26 @@ def _stream_response(
|
||||
for line in internal_stream:
|
||||
if not line.strip():
|
||||
continue
|
||||
# Parse the internal SSE event
|
||||
event_str = line.replace("data: ", "").strip()
|
||||
# ``complete_stream`` prefixes each frame with ``id: <seq>\n``
|
||||
# before the ``data:`` line. Extract just the data line so JSON
|
||||
# decode doesn't choke on the SSE framing.
|
||||
event_str = ""
|
||||
for raw in line.split("\n"):
|
||||
if raw.startswith("data:"):
|
||||
event_str = raw[len("data:") :].lstrip()
|
||||
break
|
||||
if not event_str:
|
||||
continue
|
||||
try:
|
||||
event_data = json.loads(event_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
|
||||
# Skip the informational ``message_id`` event — it has no v1 /
|
||||
# OpenAI-compatible analog.
|
||||
if event_data.get("type") == "message_id":
|
||||
continue
|
||||
|
||||
# Update completion_id when we get the conversation id
|
||||
if event_data.get("type") == "id":
|
||||
conv_id = event_data.get("id", "")
|
||||
@@ -306,7 +320,16 @@ def list_models():
|
||||
401,
|
||||
)
|
||||
|
||||
# Repository rows now go through ``coerce_pg_native`` at SELECT
|
||||
# time, so timestamps arrive as ISO 8601 strings. Parse before
|
||||
# taking ``.timestamp()``; fall back to ``time.time()`` only when
|
||||
# the value is genuinely missing or unparseable.
|
||||
created = agent.get("created_at") or agent.get("createdAt")
|
||||
if isinstance(created, str):
|
||||
try:
|
||||
created = datetime.fromisoformat(created)
|
||||
except (ValueError, TypeError):
|
||||
created = None
|
||||
created_ts = (
|
||||
int(created.timestamp()) if hasattr(created, "timestamp")
|
||||
else int(time.time())
|
||||
|
||||
@@ -9,12 +9,15 @@ from jose import jwt
|
||||
|
||||
from application.auth import handle_auth
|
||||
|
||||
from application.core import log_context
|
||||
from application.core.logging_config import setup_logging
|
||||
|
||||
setup_logging()
|
||||
|
||||
from application.api import api # noqa: E402
|
||||
from application.api.answer import answer # noqa: E402
|
||||
from application.api.answer.routes.messages import messages_bp # noqa: E402
|
||||
from application.api.events.routes import events # noqa: E402
|
||||
from application.api.internal.routes import internal # noqa: E402
|
||||
from application.api.user.routes import user # noqa: E402
|
||||
from application.api.connector.routes import connector # noqa: E402
|
||||
@@ -45,9 +48,17 @@ ensure_database_ready(
|
||||
logger=logging.getLogger("application.app"),
|
||||
)
|
||||
|
||||
from application.agents.default_tools import ( # noqa: E402
|
||||
validate_default_chat_tools,
|
||||
)
|
||||
|
||||
validate_default_chat_tools()
|
||||
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(user)
|
||||
app.register_blueprint(answer)
|
||||
app.register_blueprint(events)
|
||||
app.register_blueprint(messages_bp)
|
||||
app.register_blueprint(internal)
|
||||
app.register_blueprint(connector)
|
||||
app.register_blueprint(v1_bp)
|
||||
@@ -112,6 +123,38 @@ def generate_token():
|
||||
return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
|
||||
|
||||
|
||||
# Name of the ``request`` attribute that stores the log-context reset token.
_LOG_CTX_TOKEN_ATTR = "_log_ctx_token"


@app.before_request
def _bind_log_context():
    """Bind activity_id + endpoint for the duration of this request.

    Runs before ``authenticate_request``; ``user_id`` is overlaid in a
    follow-up handler once the JWT has been decoded. OPTIONS requests are
    left untouched.
    """
    if request.method != "OPTIONS":
        # A fresh id per request, also exposed as ``request.activity_id``.
        request.activity_id = str(uuid.uuid4())
        ctx_token = log_context.bind(
            activity_id=request.activity_id,
            endpoint=request.endpoint,
        )
        # Stored on the request so the bind can be reset at teardown.
        setattr(request, _LOG_CTX_TOKEN_ATTR, ctx_token)
    return None
|
||||
|
||||
|
||||
@app.teardown_request
def _reset_log_context(_exc):
    """Undo the request-level log-context bind, if one was made.

    SSE streams keep yielding after teardown fires, but a2wsgi runs each
    request inside ``copy_context().run(...)``, so this reset doesn't leak
    into the stream's view of the context.
    """
    ctx_token = getattr(request, _LOG_CTX_TOKEN_ATTR, None)
    if ctx_token is None:
        # No token was stored for this request, so there is nothing to undo.
        return
    log_context.reset(ctx_token)
|
||||
|
||||
|
||||
@app.before_request
|
||||
def enforce_stt_request_size_limits():
|
||||
if request.method == "OPTIONS":
|
||||
@@ -148,11 +191,28 @@ def authenticate_request():
|
||||
request.decoded_token = decoded_token
|
||||
|
||||
|
||||
@app.before_request
def _bind_user_id_to_log_context():
    """Overlay ``user_id`` onto the log context once the JWT is decoded.

    Registered after ``authenticate_request`` (Flask runs before_request
    handlers in registration order), so ``request.decoded_token`` is
    populated by the time it is read. ``teardown_request`` unwinds the
    whole request-level bind, so no separate reset token is kept here.
    """
    if request.method == "OPTIONS":
        return None

    claims = getattr(request, "decoded_token", None)
    if isinstance(claims, dict):
        subject = claims.get("sub")
        # Bind only when a non-empty subject is present.
        if subject:
            log_context.bind(user_id=subject)
    return None
|
||||
|
||||
|
||||
@app.after_request
|
||||
def after_request(response: Response) -> Response:
|
||||
"""Add CORS headers for the pure Flask development entrypoint."""
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
|
||||
response.headers["Access-Control-Allow-Headers"] = (
|
||||
"Content-Type, Authorization, Idempotency-Key"
|
||||
)
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
return response
|
||||
|
||||
|
||||
@@ -25,7 +25,12 @@ asgi_app = Starlette(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization", "Mcp-Session-Id"],
|
||||
allow_headers=[
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"Mcp-Session-Id",
|
||||
"Idempotency-Key",
|
||||
],
|
||||
expose_headers=["Mcp-Session-Id"],
|
||||
),
|
||||
],
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
@@ -10,6 +11,14 @@ from application.utils import get_hash
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cache_default(value):
|
||||
# Image attachments arrive inline as bytes (see GoogleLLM.prepare_messages_with_attachments);
|
||||
# hash so the cache key stays bounded in size and stable across identical content.
|
||||
if isinstance(value, (bytes, bytearray, memoryview)):
|
||||
return f"<bytes:sha256:{hashlib.sha256(bytes(value)).hexdigest()}>"
|
||||
return repr(value)
|
||||
|
||||
_redis_instance = None
|
||||
_redis_creation_failed = False
|
||||
_instance_lock = Lock()
|
||||
@@ -20,8 +29,17 @@ def get_redis_instance():
|
||||
with _instance_lock:
|
||||
if _redis_instance is None and not _redis_creation_failed:
|
||||
try:
|
||||
# ``health_check_interval`` makes redis-py ping the
|
||||
# connection every N seconds when otherwise idle.
|
||||
# Without it, a half-open TCP (NAT silently dropped
|
||||
# state, ELB idle-close) can hang the SSE generator
|
||||
# in ``pubsub.get_message`` past its keepalive
|
||||
# cadence — the kernel never surfaces the dead
|
||||
# socket because no payload is in flight.
|
||||
_redis_instance = redis.Redis.from_url(
|
||||
settings.CACHE_REDIS_URL, socket_connect_timeout=2
|
||||
settings.CACHE_REDIS_URL,
|
||||
socket_connect_timeout=2,
|
||||
health_check_interval=10,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid Redis URL: {e}")
|
||||
@@ -36,7 +54,7 @@ def get_redis_instance():
|
||||
def gen_cache_key(messages, model="docgpt", tools=None):
|
||||
if not all(isinstance(msg, dict) for msg in messages):
|
||||
raise ValueError("All messages must be dictionaries.")
|
||||
messages_str = json.dumps(messages)
|
||||
messages_str = json.dumps(messages, default=_cache_default)
|
||||
tools_str = json.dumps(str(tools)) if tools else ""
|
||||
combined = f"{model}_{messages_str}_{tools_str}"
|
||||
cache_key = get_hash(combined)
|
||||
|
||||
@@ -1,8 +1,20 @@
|
||||
import ctypes
|
||||
import gc
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from celery import Celery
|
||||
from application.core import log_context
|
||||
from application.core.settings import settings
|
||||
from celery.signals import setup_logging, worker_process_init, worker_ready
|
||||
from celery.signals import (
|
||||
setup_logging,
|
||||
task_postrun,
|
||||
task_prerun,
|
||||
worker_process_init,
|
||||
worker_ready,
|
||||
)
|
||||
|
||||
|
||||
def make_celery(app_name=__name__):
|
||||
@@ -41,6 +53,82 @@ def _dispose_db_engine_on_fork(*args, **kwargs):
|
||||
dispose_engine()
|
||||
|
||||
|
||||
# Most tasks in this repo accept ``user`` where the log context wants
|
||||
# ``user_id``; map task parameter names to context keys explicitly.
|
||||
_TASK_PARAM_TO_CTX_KEY: dict[str, str] = {
|
||||
"user": "user_id",
|
||||
"user_id": "user_id",
|
||||
"agent_id": "agent_id",
|
||||
"conversation_id": "conversation_id",
|
||||
}
|
||||
|
||||
_task_log_tokens: dict[str, object] = {}
|
||||
|
||||
|
||||
@task_prerun.connect
def _bind_task_log_context(task_id, task, args, kwargs, **_):
    """Bind the logging context for a task just before it runs.

    The task id is always bound as ``activity_id``. Call arguments whose
    parameter name appears in ``_TASK_PARAM_TO_CTX_KEY`` are bound under the
    mapped context key when their value is truthy. The token returned by
    ``log_context.bind`` is stored in ``_task_log_tokens`` keyed by ``task_id``.
    """
    # Resolve arguments by parameter name: nearly every task in this repo is
    # called positionally, so ``kwargs.get('user')`` alone would bind nothing.
    try:
        call_args = (
            inspect.signature(task.run).bind_partial(*args, **kwargs).arguments
        )
    except (TypeError, ValueError):
        # Signature unavailable or arguments don't fit it; fall back to the
        # keyword arguments only.
        call_args = dict(kwargs)

    fields = {"activity_id": task_id}
    for name, arg in call_args.items():
        target = _TASK_PARAM_TO_CTX_KEY.get(name)
        if not target or not arg:
            continue
        fields[target] = arg

    _task_log_tokens[task_id] = log_context.bind(**fields)
|
||||
|
||||
|
||||
@task_postrun.connect
def _unbind_task_log_context(task_id, **_):
    """Reset the logging context bound for ``task_id``, if one was recorded.

    Does nothing when no token is stored for the task.
    """
    # ``task_postrun`` fires on both success and failure. Required for Celery:
    # unlike the Flask path, tasks aren't isolated in their own
    # ``copy_context().run(...)``, so a missing reset would leak the bind onto
    # the next task on the same worker.
    token = _task_log_tokens.pop(task_id, None)
    if token is not None:
        try:
            log_context.reset(token)
        except ValueError:
            # task_prerun and task_postrun ran on different threads
            # (non-default Celery pool); the token isn't valid in this
            # context. Drop it.
            logging.getLogger(__name__).debug(
                "log_context reset skipped for task %s", task_id
            )
|
||||
|
||||
|
||||
def _trim_native_heap() -> None:
|
||||
"""Return freed glibc heap pages to the OS (Linux only; no-op elsewhere)."""
|
||||
# docling/torch parsing makes large transient allocations; glibc keeps the
|
||||
# freed pages in per-thread malloc arenas rather than returning them, so a
|
||||
# long-lived worker child's RSS only ever climbs. malloc_trim hands them
|
||||
# back. The symbol is glibc-only — absent in macOS libc.
|
||||
if not sys.platform.startswith("linux"):
|
||||
return
|
||||
try:
|
||||
ctypes.CDLL("libc.so.6").malloc_trim(0)
|
||||
except (OSError, AttributeError):
|
||||
pass
|
||||
|
||||
|
||||
@task_postrun.connect
def _reclaim_memory_after_task(*args, **kwargs):
    """Release per-task allocations so the prefork child's RSS doesn't ratchet.

    Runs a garbage collection, empties the CUDA cache when ``torch`` is already
    imported and CUDA is available, then trims the native heap.
    """
    gc.collect()
    # Look torch up in ``sys.modules`` rather than importing it, so a worker
    # that never loaded torch does not load it here.
    torch_mod = sys.modules.get("torch")
    try:
        if torch_mod is not None and torch_mod.cuda.is_available():
            torch_mod.cuda.empty_cache()
    except Exception:
        # Cache release is best-effort and must never fail the signal handler.
        pass
    _trim_native_heap()
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def _run_version_check(*args, **kwargs):
|
||||
"""Kick off the anonymous version check on worker startup.
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import os
|
||||
from application.core.settings import settings
|
||||
|
||||
broker_url = os.getenv("CELERY_BROKER_URL")
|
||||
result_backend = os.getenv("CELERY_RESULT_BACKEND")
|
||||
# Pydantic loads .env into ``settings`` but does not inject values into
|
||||
# ``os.environ`` — read directly from settings so beat startup (which
|
||||
# imports this module before any explicit env load) sees a real URL.
|
||||
broker_url = settings.CELERY_BROKER_URL
|
||||
result_backend = settings.CELERY_RESULT_BACKEND
|
||||
|
||||
task_serializer = 'json'
|
||||
result_serializer = 'json'
|
||||
@@ -10,7 +13,28 @@ accept_content = ['json']
|
||||
# Autodiscover tasks
|
||||
imports = ('application.api.user.tasks',)
|
||||
|
||||
# Project-scoped queue so a stray sibling worker on the same broker
|
||||
# (other repo, same default ``celery`` queue) can't grab DocsGPT tasks.
|
||||
task_default_queue = "docsgpt"
|
||||
task_default_exchange = "docsgpt"
|
||||
task_default_routing_key = "docsgpt"
|
||||
|
||||
beat_scheduler = "redbeat.RedBeatScheduler"
|
||||
redbeat_redis_url = broker_url
|
||||
redbeat_key_prefix = "redbeat:docsgpt:"
|
||||
redbeat_lock_timeout = 90
|
||||
|
||||
# Survive worker SIGKILL/OOM without silently dropping in-flight tasks.
|
||||
task_acks_late = True
|
||||
task_reject_on_worker_lost = True
|
||||
worker_prefetch_multiplier = settings.CELERY_WORKER_PREFETCH_MULTIPLIER
|
||||
broker_transport_options = {"visibility_timeout": settings.CELERY_VISIBILITY_TIMEOUT}
|
||||
result_expires = 86400 * 7
|
||||
task_track_started = True
|
||||
|
||||
# Recycle the prefork worker child to bound native-heap growth from
|
||||
# docling/torch parsing. Left unset (Celery's unlimited default) when 0.
|
||||
if settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD > 0:
|
||||
worker_max_memory_per_child = settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD
|
||||
if settings.CELERY_WORKER_MAX_TASKS_PER_CHILD > 0:
|
||||
worker_max_tasks_per_child = settings.CELERY_WORKER_MAX_TASKS_PER_CHILD
|
||||
|
||||
57
application/core/log_context.py
Normal file
57
application/core/log_context.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Per-activity logging context backed by ``contextvars``.
|
||||
|
||||
The ``_ContextFilter`` installed by ``logging_config.setup_logging`` stamps
|
||||
every ``LogRecord`` emitted inside a ``bind`` block with the bound keys, so
|
||||
they land as first-class attributes on the OTLP log export rather than being
|
||||
buried inside formatted message bodies.
|
||||
|
||||
A single ``ContextVar`` holds a dict so nested binds reset atomically (LIFO)
|
||||
via the token returned by ``bind``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar, Token
|
||||
from typing import Mapping
|
||||
|
||||
|
||||
_CTX_KEYS: frozenset[str] = frozenset(
|
||||
{
|
||||
"activity_id",
|
||||
"parent_activity_id",
|
||||
"user_id",
|
||||
"agent_id",
|
||||
"conversation_id",
|
||||
"endpoint",
|
||||
"model",
|
||||
}
|
||||
)
|
||||
|
||||
_ctx: ContextVar[Mapping[str, str]] = ContextVar("log_ctx", default={})
|
||||
|
||||
|
||||
def bind(**kwargs: object) -> Token:
    """Layer the given keys on top of the current logging context.

    Keys that are not in :data:`_CTX_KEYS` are dropped silently, so a typo
    cannot stamp a stray field name onto every record. ``None`` values are
    dropped too: a missing attribute is more useful than the literal string
    ``"None"``. Accepted values are stored as ``str(value)``.

    Returns:
        The ``Token`` from setting the context variable; pass it to ``reset``
        (typically in a ``finally`` block) to undo this bind.
    """
    accepted = {}
    for key, value in kwargs.items():
        if value is None or key not in _CTX_KEYS:
            continue
        accepted[key] = str(value)

    merged = dict(_ctx.get())
    merged.update(accepted)
    return _ctx.set(merged)
|
||||
|
||||
|
||||
def reset(token: Token) -> None:
    """Undo a ``bind`` by resetting the context variable with its token.

    Args:
        token: The ``Token`` returned by the ``bind`` call being undone.
    """
    _ctx.reset(token)
|
||||
|
||||
|
||||
def snapshot() -> Mapping[str, str]:
    """Return the mapping currently held by the context variable.

    The returned object is the live value, not a copy; treat it as read-only
    and change the context through :func:`bind` instead.
    """
    current = _ctx.get()
    return current
|
||||
@@ -2,6 +2,36 @@ import logging
|
||||
import os
|
||||
from logging.config import dictConfig
|
||||
|
||||
from application.core.log_context import snapshot as _ctx_snapshot
|
||||
|
||||
|
||||
# Loggers with ``propagate=False`` don't share root's handlers, so the
|
||||
# context filter has to be installed on their handlers directly.
|
||||
_NON_PROPAGATING_LOGGERS: tuple[str, ...] = (
|
||||
"uvicorn",
|
||||
"uvicorn.access",
|
||||
"uvicorn.error",
|
||||
"celery.app.trace",
|
||||
"celery.worker.strategy",
|
||||
"gunicorn.error",
|
||||
"gunicorn.access",
|
||||
)
|
||||
|
||||
|
||||
class _ContextFilter(logging.Filter):
    """Copy the current ``log_context`` snapshot onto each ``LogRecord``.

    Must be installed on **handlers**, not loggers: Python skips logger-level
    filters when a child logger's record propagates up. Attributes already
    present on the record are left alone, so an explicit
    ``logger.info(..., extra={...})`` is never overwritten.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        """Stamp missing context attributes on ``record``; never drops it."""
        bound = _ctx_snapshot()
        for name, value in bound.items():
            if hasattr(record, name):
                continue
            setattr(record, name, value)
        return True
|
||||
|
||||
|
||||
def _otlp_logs_enabled() -> bool:
|
||||
"""Return True when the user has opted in to OTLP log export.
|
||||
@@ -60,3 +90,23 @@ def setup_logging() -> None:
|
||||
for handler in preserved_handlers:
|
||||
if handler not in root.handlers:
|
||||
root.addHandler(handler)
|
||||
|
||||
_install_context_filter()
|
||||
|
||||
|
||||
def _install_context_filter() -> None:
    """Attach a :class:`_ContextFilter` to every relevant handler.

    Covers the root logger's handlers and the handlers of each logger named in
    ``_NON_PROPAGATING_LOGGERS``. A handler that already carries a
    ``_ContextFilter`` is skipped, so repeated ``setup_logging`` calls do not
    stack filters.
    """
    targets = [logging.getLogger()]
    targets.extend(logging.getLogger(name) for name in _NON_PROPAGATING_LOGGERS)

    for target in targets:
        for handler in target.handlers:
            already_installed = any(
                isinstance(existing, _ContextFilter) for existing in handler.filters
            )
            if not already_installed:
                handler.addFilter(_ContextFilter())
|
||||
|
||||
@@ -30,6 +30,17 @@ class Settings(BaseSettings):
|
||||
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
# Prefetch=1 caps SIGKILL loss to one task. Visibility timeout must exceed
|
||||
# the longest legitimate task runtime (ingest, agent webhook) but stay
|
||||
# short enough that SIGKILLed tasks redeliver promptly. 1h matches Onyx
|
||||
# and Dify defaults; long ingests can override via env.
|
||||
CELERY_WORKER_PREFETCH_MULTIPLIER: int = 1
|
||||
CELERY_VISIBILITY_TIMEOUT: int = 3600
|
||||
# Recycle the prefork worker child once its resident size crosses this many
|
||||
# kilobytes — backstops native-heap growth from docling/torch parsing. 0 disables.
|
||||
CELERY_WORKER_MAX_MEMORY_PER_CHILD: int = 4194304
|
||||
# Recycle the child after this many tasks; 0 disables (memory cap is the primary knob).
|
||||
CELERY_WORKER_MAX_TASKS_PER_CHILD: int = 0
|
||||
# Only consulted when VECTOR_STORE=mongodb or when running scripts/db/backfill.py; user data lives in Postgres.
|
||||
MONGO_URI: Optional[str] = None
|
||||
# User-data Postgres DB.
|
||||
@@ -55,6 +66,9 @@ class Settings(BaseSettings):
|
||||
PARSE_IMAGE_REMOTE: bool = False
|
||||
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
|
||||
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
|
||||
# Pages docling's threaded pipeline buffers in flight; the library
|
||||
# default (100) drives worker RSS to ~3 GB on a mid-size PDF.
|
||||
DOCLING_PIPELINE_QUEUE_MAX_SIZE: int = 2
|
||||
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
|
||||
RETRIEVERS_ENABLED: list = ["classic_rag"]
|
||||
AGENT_NAME: str = "classic"
|
||||
@@ -175,6 +189,11 @@ class Settings(BaseSettings):
|
||||
# Tool pre-fetch settings
|
||||
ENABLE_TOOL_PREFETCH: bool = True
|
||||
|
||||
# Config-free tools on by default in agentless chats. ``scheduler`` is
|
||||
# dual-registered (also in ``BUILTIN_AGENT_TOOLS``) so the same synthetic id
|
||||
# resolves whether reached via defaults or the agent picker.
|
||||
DEFAULT_CHAT_TOOLS: list = ["memory", "read_webpage", "scheduler"]
|
||||
|
||||
# Conversation Compression Settings
|
||||
ENABLE_CONVERSATION_COMPRESSION: bool = True
|
||||
COMPRESSION_THRESHOLD_PERCENTAGE: float = 0.8 # Trigger at 80% of context
|
||||
@@ -182,6 +201,52 @@ class Settings(BaseSettings):
|
||||
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
|
||||
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
|
||||
|
||||
# Internal SSE push channel (notifications + durable replay journal)
|
||||
# Master switch — when False, /api/events emits a "push_disabled" comment
|
||||
# and returns; clients fall back to polling. Publisher becomes a no-op.
|
||||
ENABLE_SSE_PUSH: bool = True
|
||||
# Per-user durable backlog cap (~entries). At typical event rates this
|
||||
# gives ~24h of replay; tune up for verbose feeds, down for memory.
|
||||
EVENTS_STREAM_MAXLEN: int = 1000
|
||||
# SSE keepalive comment cadence. Must sit under Cloudflare's 100s idle
|
||||
# close and iOS Safari's ~60s — 15s gives generous headroom.
|
||||
SSE_KEEPALIVE_SECONDS: int = 15
|
||||
# Cap on simultaneous SSE connections per user. Each connection holds
|
||||
# one WSGI thread (32 per gunicorn worker) and one Redis pub/sub
|
||||
# connection. 8 covers normal multi-tab use without letting one user
|
||||
# starve the pool. Set to 0 to disable the cap.
|
||||
SSE_MAX_CONCURRENT_PER_USER: int = 8
|
||||
# Per-request cap on the number of backlog entries XRANGE returns
|
||||
# for ``/api/events`` snapshots. Bounds the bytes a single replay
|
||||
# can move from Redis to the wire — a malicious client looping
|
||||
# ``Last-Event-ID=<oldest>`` reconnects can only enumerate this
|
||||
# many entries per round-trip. Combined with the per-user
|
||||
# connection cap above and the windowed budget below, total
|
||||
# enumeration throughput is bounded.
|
||||
EVENTS_REPLAY_MAX_PER_REQUEST: int = 200
|
||||
# Sliding-window cap on snapshot replays per user. Once the budget
|
||||
# is exhausted the route returns HTTP 429 with the cursor pinned;
|
||||
# the client backs off and retries after the window rolls over.
|
||||
EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW: int = 30
|
||||
EVENTS_REPLAY_BUDGET_WINDOW_SECONDS: int = 60
|
||||
|
||||
# Retention for the ``message_events`` journal. The ``cleanup_message_events``
|
||||
# beat task deletes rows older than this. Reconnect-replay only
|
||||
# needs the journal for streams a client could still be tailing,
|
||||
# so 14 days is a generous default that covers paused/tool-action
|
||||
# flows without unbounded table growth.
|
||||
MESSAGE_EVENTS_RETENTION_DAYS: int = 14
|
||||
|
||||
# Scheduler (see scheduler.md).
|
||||
SCHEDULE_DISPATCHER_INTERVAL: int = 30
|
||||
SCHEDULE_MIN_INTERVAL: int = 900
|
||||
SCHEDULE_MAX_PER_USER: int = 50
|
||||
SCHEDULE_RUN_TIMEOUT: int = 600
|
||||
SCHEDULE_MISFIRE_GRACE: int = 60
|
||||
SCHEDULE_AUTOPAUSE_FAILURES: int = 3
|
||||
SCHEDULE_ONCE_MAX_HORIZON: int = 31_536_000
|
||||
SCHEDULE_RUN_OUTPUT_RETENTION_DAYS: int = 90
|
||||
|
||||
@field_validator("POSTGRES_URI", mode="before")
|
||||
@classmethod
|
||||
def _normalize_postgres_uri_validator(cls, v):
|
||||
|
||||
0
application/events/__init__.py
Normal file
0
application/events/__init__.py
Normal file
52
application/events/keys.py
Normal file
52
application/events/keys.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Stream/topic key derivations shared by publisher and SSE consumer.
|
||||
|
||||
Single source of truth for the per-user Redis Streams key and pub/sub
|
||||
topic name. Both must agree exactly — a typo here splits the
|
||||
publisher's writes from the consumer's reads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def stream_key(user_id: str) -> str:
    """Return the Redis Streams key of the durable backlog for ``user_id``."""
    backlog_key = f"user:{user_id}:stream"
    return backlog_key
|
||||
|
||||
|
||||
def topic_name(user_id: str) -> str:
    """Return the Redis pub/sub channel used for live fan-out to ``user_id``."""
    channel = f"user:{user_id}"
    return channel
|
||||
|
||||
|
||||
def connection_counter_key(user_id: str) -> str:
    """Return the Redis key of the active-SSE-connection counter for ``user_id``."""
    counter_key = f"user:{user_id}:sse_count"
    return counter_key
|
||||
|
||||
|
||||
def replay_budget_key(user_id: str) -> str:
    """Return the Redis key of the snapshot-replay counter for ``user_id``.

    The counter tracks replays inside the rolling rate-limit window.
    """
    budget_key = f"user:{user_id}:replay_count"
    return budget_key
|
||||
|
||||
|
||||
def stream_id_compare(a: str, b: str) -> int:
    """Order two Redis Streams ids; return -1, 0 or 1 like ``cmp``.

    Stream ids are ``ms-seq`` strings. Comparing them as text goes wrong once
    ``ms`` straddles a digit-count boundary, so each id is parsed into an
    ``(int, int)`` pair and the pairs are compared. A missing ``seq`` part
    counts as 0.

    Raises:
        ValueError: On malformed input. Callers must pre-validate against
            ``_STREAM_ID_RE`` (or equivalent): a lexical fallback here once
            let a malformed id compare greater than a real one and silently
            pin dedup forever.
    """
    a_head, _, a_tail = a.partition("-")
    b_head, _, b_tail = b.partition("-")
    left = (int(a_head), int(a_tail or 0))
    right = (int(b_head), int(b_tail or 0))
    # bool arithmetic yields exactly -1, 0 or 1.
    return (left > right) - (left < right)
|
||||
144
application/events/publisher.py
Normal file
144
application/events/publisher.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""User-scoped event publisher: durable backlog + live fan-out.
|
||||
|
||||
Each ``publish_user_event`` call writes twice:
|
||||
|
||||
1. ``XADD user:{user_id}:stream MAXLEN ~ <cap> * event <json>`` — the
|
||||
durable backlog used by SSE reconnect (``Last-Event-ID``) and stream
|
||||
replay. Bounded by ``EVENTS_STREAM_MAXLEN`` (~24h at typical event
|
||||
rates) so the per-user footprint stays predictable.
|
||||
2. ``PUBLISH user:{user_id} <json-with-id>`` — live fan-out to every
|
||||
currently connected SSE generator for the user, across instances.
|
||||
|
||||
Together they give a snapshot-plus-tail story: a reconnecting client
|
||||
reads ``XRANGE`` from its last seen id and then transitions onto the
|
||||
live pub/sub. The Redis Streams entry id (e.g. ``1735682400000-0``) is
|
||||
the canonical, monotonically increasing event id and is what
|
||||
``Last-Event-ID`` carries.
|
||||
|
||||
Failures are logged and swallowed: the caller is typically a Celery
|
||||
task whose primary work has already succeeded, and a notification
|
||||
delivery miss should not surface as a task failure.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.events.keys import stream_key, topic_name
|
||||
from application.streaming.broadcast_channel import Topic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _iso_now() -> str:
|
||||
"""ISO 8601 UTC with millisecond precision and Z suffix."""
|
||||
return (
|
||||
datetime.now(timezone.utc)
|
||||
.isoformat(timespec="milliseconds")
|
||||
.replace("+00:00", "Z")
|
||||
)
|
||||
|
||||
|
||||
def publish_user_event(
    user_id: str,
    event_type: str,
    payload: dict[str, Any],
    *,
    scope: Optional[dict[str, Any]] = None,
) -> Optional[str]:
    """Publish a user-scoped event; return the Redis Streams id or ``None``.

    The event envelope is first appended to the user's stream (the durable
    journal), then re-serialized with the resulting stream id and published on
    the user's live topic.

    Fire-and-forget: never raises. ``None`` means the event reached neither
    the journal nor live subscribers (see runbook for causes).

    Args:
        user_id: Owner of the event; must be truthy.
        event_type: Value stored under the envelope's ``type`` key; must be truthy.
        payload: Event body; must be JSON-serializable.
        scope: Optional scope mapping; a falsy value is sent as ``{}``.

    Returns:
        The stream id of the journal entry. A failed live publish is only
        logged, so the id is still returned in that case.
    """
    if not (user_id and event_type):
        logger.warning(
            "publish_user_event called without user_id or event_type "
            "(user_id=%r, event_type=%r)",
            user_id,
            event_type,
        )
        return None
    if not settings.ENABLE_SSE_PUSH:
        return None

    base_envelope: dict[str, Any] = {
        "type": event_type,
        "ts": _iso_now(),
        "user_id": user_id,
        "topic": topic_name(user_id),
        "scope": scope or {},
        "payload": payload,
    }

    try:
        journal_body = json.dumps(base_envelope)
    except (TypeError, ValueError) as exc:
        logger.warning(
            "publish_user_event payload not JSON-serializable: "
            "user=%s type=%s err=%s",
            user_id,
            event_type,
            exc,
        )
        return None

    client = get_redis_instance()
    if client is None:
        logger.debug("Redis unavailable; skipping publish_user_event")
        return None

    backlog_cap = settings.EVENTS_STREAM_MAXLEN
    try:
        # Auto-id ('*') gives a monotonic ms-seq id that doubles as the SSE
        # event id. ``approximate=True`` lets Redis trim in chunks for
        # performance; the cap is treated as ~MAXLEN, never <.
        raw_id = client.xadd(
            stream_key(user_id),
            {"event": journal_body},
            maxlen=backlog_cap,
            approximate=True,
        )
        if isinstance(raw_id, (bytes, bytearray)):
            stream_id = raw_id.decode("utf-8")
        else:
            stream_id = str(raw_id)
    except Exception:
        logger.exception(
            "xadd failed for user=%s event_type=%s", user_id, event_type
        )
        # Without a journal entry there is no canonical id to ship.
        # Publishing live anyway would put an id-less record on the wire that
        # bypasses the SSE route's dedup floor and breaks ``Last-Event-ID``
        # semantics for any reconnect. Best-effort delivery means dropping
        # consistently, not delivering inconsistent state.
        return None

    live_envelope = {**base_envelope, "id": stream_id}
    channel = topic_name(user_id)
    try:
        Topic(channel).publish(json.dumps(live_envelope))
    except Exception:
        logger.exception(
            "publish failed for user=%s event_type=%s", user_id, event_type
        )

    logger.debug(
        "event.published topic=%s type=%s id=%s",
        channel,
        event_type,
        stream_id,
    )
    return stream_id
|
||||
@@ -11,6 +11,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicLLM(BaseLLM):
|
||||
provider_name = "anthropic"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import ClassVar
|
||||
|
||||
from application.cache import gen_cache, stream_cache
|
||||
|
||||
@@ -10,6 +11,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
# Stamped onto the ``llm_stream_start`` event so dashboards can group
|
||||
# calls by vendor. Subclasses override.
|
||||
provider_name: ClassVar[str] = "unknown"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decoded_token=None,
|
||||
@@ -75,6 +80,14 @@ class BaseLLM(ABC):
|
||||
agent_id=self.agent_id,
|
||||
model_user_id=self.model_user_id,
|
||||
)
|
||||
# Tag the fallback LLM so its rows land as
|
||||
# ``source='fallback'`` in cost-attribution dashboards.
|
||||
# Propagate the parent's ``_request_id`` so a user
|
||||
# request that ran fallback is still grouped under one id.
|
||||
self._fallback_llm._token_usage_source = "fallback"
|
||||
self._fallback_llm._request_id = getattr(
|
||||
self, "_request_id", None,
|
||||
)
|
||||
logger.info(
|
||||
f"Fallback LLM initialized from agent backup model: "
|
||||
f"{provider}/{backup_model_id}"
|
||||
@@ -101,6 +114,11 @@ class BaseLLM(ABC):
|
||||
agent_id=self.agent_id,
|
||||
model_user_id=self.model_user_id,
|
||||
)
|
||||
# Same rationale as the agent-backup branch.
|
||||
self._fallback_llm._token_usage_source = "fallback"
|
||||
self._fallback_llm._request_id = getattr(
|
||||
self, "_request_id", None,
|
||||
)
|
||||
logger.info(
|
||||
f"Fallback LLM initialized from global settings: "
|
||||
f"{settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
|
||||
@@ -118,6 +136,26 @@ class BaseLLM(ABC):
|
||||
return args_dict
|
||||
return {k: v for k, v in args_dict.items() if v is not None}
|
||||
|
||||
@staticmethod
|
||||
def _is_non_retriable_client_error(exc: BaseException) -> bool:
|
||||
"""4xx errors mean the request itself is malformed — retrying with
|
||||
a different model fails identically and doubles the work. Only
|
||||
transient/5xx/connection errors should trigger fallback."""
|
||||
try:
|
||||
from google.genai.errors import ClientError as _GenaiClientError
|
||||
|
||||
if isinstance(exc, _GenaiClientError):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
for attr in ("status_code", "code", "http_status"):
|
||||
v = getattr(exc, attr, None)
|
||||
if isinstance(v, int) and 400 <= v < 500:
|
||||
return True
|
||||
resp = getattr(exc, "response", None)
|
||||
v = getattr(resp, "status_code", None)
|
||||
return isinstance(v, int) and 400 <= v < 500
|
||||
|
||||
def _execute_with_fallback(
|
||||
self, method_name: str, decorators: list, *args, **kwargs
|
||||
):
|
||||
@@ -141,12 +179,18 @@ class BaseLLM(ABC):
|
||||
|
||||
if is_stream:
|
||||
return self._stream_with_fallback(
|
||||
decorated_method, method_name, *args, **kwargs
|
||||
decorated_method, method_name, decorators, *args, **kwargs
|
||||
)
|
||||
|
||||
try:
|
||||
return decorated_method()
|
||||
except Exception as e:
|
||||
if self._is_non_retriable_client_error(e):
|
||||
logger.error(
|
||||
f"Primary LLM failed with non-retriable client error; "
|
||||
f"skipping fallback: {str(e)}"
|
||||
)
|
||||
raise
|
||||
if not self.fallback_llm:
|
||||
logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
|
||||
raise
|
||||
@@ -156,14 +200,27 @@ class BaseLLM(ABC):
|
||||
f"{fallback.model_id}. Error: {str(e)}"
|
||||
)
|
||||
|
||||
fallback_method = getattr(
|
||||
fallback, method_name.replace("_raw_", "")
|
||||
)
|
||||
# Apply decorators to fallback's raw method directly — calling
|
||||
# fallback.gen() would re-enter the orchestrator and recurse via
|
||||
# fallback.fallback_llm.
|
||||
fallback_method = getattr(fallback, method_name)
|
||||
for decorator in decorators:
|
||||
fallback_method = decorator(fallback_method)
|
||||
fallback_kwargs = {**kwargs, "model": fallback.model_id}
|
||||
return fallback_method(*args, **fallback_kwargs)
|
||||
try:
|
||||
return fallback_method(fallback, *args, **fallback_kwargs)
|
||||
except Exception as e2:
|
||||
if self._is_non_retriable_client_error(e2):
|
||||
logger.error(
|
||||
f"Fallback LLM failed with non-retriable client "
|
||||
f"error; giving up: {str(e2)}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"Fallback LLM also failed; giving up: {str(e2)}")
|
||||
raise
|
||||
|
||||
def _stream_with_fallback(
|
||||
self, decorated_method, method_name, *args, **kwargs
|
||||
self, decorated_method, method_name, decorators, *args, **kwargs
|
||||
):
|
||||
"""
|
||||
Wrapper generator that catches mid-stream errors and falls back.
|
||||
@@ -176,6 +233,12 @@ class BaseLLM(ABC):
|
||||
try:
|
||||
yield from decorated_method()
|
||||
except Exception as e:
|
||||
if self._is_non_retriable_client_error(e):
|
||||
logger.error(
|
||||
f"Primary LLM failed mid-stream with non-retriable client "
|
||||
f"error; skipping fallback: {str(e)}"
|
||||
)
|
||||
raise
|
||||
if not self.fallback_llm:
|
||||
logger.error(
|
||||
f"Primary LLM failed and no fallback configured: {str(e)}"
|
||||
@@ -186,11 +249,37 @@ class BaseLLM(ABC):
|
||||
f"Primary LLM failed mid-stream. Falling back to "
|
||||
f"{fallback.model_id}. Error: {str(e)}"
|
||||
)
|
||||
fallback_method = getattr(
|
||||
fallback, method_name.replace("_raw_", "")
|
||||
# Apply decorators to fallback's raw stream method directly —
|
||||
# calling fallback.gen_stream() would re-enter the orchestrator
|
||||
# and recurse via fallback.fallback_llm. Emit the stream-start
|
||||
# event manually so dashboards still see the fallback's
|
||||
# provider/model when the response actually comes from it.
|
||||
fallback._emit_stream_start_log(
|
||||
fallback.model_id,
|
||||
kwargs.get("messages"),
|
||||
kwargs.get("tools"),
|
||||
bool(
|
||||
kwargs.get("_usage_attachments")
|
||||
or kwargs.get("attachments")
|
||||
),
|
||||
)
|
||||
fallback_method = getattr(fallback, method_name)
|
||||
for decorator in decorators:
|
||||
fallback_method = decorator(fallback_method)
|
||||
fallback_kwargs = {**kwargs, "model": fallback.model_id}
|
||||
yield from fallback_method(*args, **fallback_kwargs)
|
||||
try:
|
||||
yield from fallback_method(fallback, *args, **fallback_kwargs)
|
||||
except Exception as e2:
|
||||
if self._is_non_retriable_client_error(e2):
|
||||
logger.error(
|
||||
f"Fallback LLM failed mid-stream with non-retriable "
|
||||
f"client error; giving up: {str(e2)}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Fallback LLM also failed mid-stream; giving up: {str(e2)}"
|
||||
)
|
||||
raise
|
||||
|
||||
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
|
||||
decorators = [gen_token_usage, gen_cache]
|
||||
@@ -205,7 +294,58 @@ class BaseLLM(ABC):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _emit_stream_start_log(self, model, messages, tools, has_attachments):
|
||||
# Stamped with ``self.provider_name`` so dashboards can group calls
|
||||
# by vendor; the fallback path emits its own copy on the fallback
|
||||
# instance so the actual responding provider is recorded.
|
||||
logging.info(
|
||||
"llm_stream_start",
|
||||
extra={
|
||||
"model": model,
|
||||
"provider": self.provider_name,
|
||||
"message_count": len(messages) if messages is not None else 0,
|
||||
"has_attachments": bool(has_attachments),
|
||||
"has_tools": bool(tools),
|
||||
},
|
||||
)
|
||||
|
||||
def _emit_stream_finished_log(
    self,
    model,
    *,
    prompt_tokens,
    completion_tokens,
    latency_ms,
    cached_tokens=None,
    error=None,
):
    """Emit the ``llm_stream_finished`` log event that pairs with
    ``llm_stream_start``.

    Args:
        model: Model identifier recorded on the event.
        prompt_tokens: Prompt token count; coerced with ``int``.
        completion_tokens: Completion token count; coerced with ``int``.
        latency_ms: Elapsed time in milliseconds; coerced with ``int``.
        cached_tokens: Optional cached-token count; the field is omitted
            from the event when this is ``None``.
        error: Optional exception; when given, ``status`` becomes
            ``"error"`` and ``error_class`` holds the exception type name.
    """
    # Per the original author's note: token counts are client-side
    # estimates from ``stream_token_usage``; vendor-reported counts
    # (including ``cached_tokens``) are not wired through yet.
    failed = error is not None
    fields = {
        "model": model,
        "provider": self.provider_name,
        "prompt_tokens": int(prompt_tokens),
        "completion_tokens": int(completion_tokens),
        "latency_ms": int(latency_ms),
        "status": "error" if failed else "ok",
    }
    if cached_tokens is not None:
        fields["cached_tokens"] = int(cached_tokens)
    if failed:
        fields["error_class"] = type(error).__name__
    logging.info("llm_stream_finished", extra=fields)
|
||||
|
||||
def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
|
||||
# Attachments arrive as ``_usage_attachments`` from ``Agent._llm_gen``;
|
||||
# the ``stream_token_usage`` decorator pops that key, but the log
|
||||
# fires before the decorator runs so it's still in ``kwargs`` here.
|
||||
has_attachments = bool(
|
||||
kwargs.get("_usage_attachments") or kwargs.get("attachments")
|
||||
)
|
||||
self._emit_stream_start_log(model, messages, tools, has_attachments)
|
||||
decorators = [stream_cache, stream_token_usage]
|
||||
return self._execute_with_fallback(
|
||||
"_raw_gen_stream",
|
||||
|
||||
@@ -6,6 +6,8 @@ DOCSGPT_BASE_URL = "https://oai.arc53.com"
|
||||
DOCSGPT_MODEL = "docsgpt"
|
||||
|
||||
class DocsGPTAPILLM(OpenAILLM):
|
||||
provider_name = "docsgpt"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=DOCSGPT_API_KEY,
|
||||
|
||||
@@ -6,10 +6,13 @@ from google.genai import types
|
||||
from application.core.settings import settings
|
||||
|
||||
from application.llm.base import BaseLLM
|
||||
from application.llm.handlers.google import _decode_thought_signature
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
|
||||
|
||||
class GoogleLLM(BaseLLM):
|
||||
provider_name = "google"
|
||||
|
||||
def __init__(
|
||||
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
|
||||
):
|
||||
@@ -79,24 +82,39 @@ class GoogleLLM(BaseLLM):
|
||||
for attachment in attachments:
|
||||
mime_type = attachment.get("mime_type")
|
||||
|
||||
if mime_type in self.get_supported_attachment_types():
|
||||
try:
|
||||
if mime_type not in self.get_supported_attachment_types():
|
||||
continue
|
||||
try:
|
||||
# Images go inline as bytes per Google's guidance for
|
||||
# requests under 20MB; the Files API can return before
|
||||
# the upload reaches ACTIVE state and yield an empty URI.
|
||||
if mime_type.startswith("image/"):
|
||||
file_bytes = self._read_attachment_bytes(attachment)
|
||||
files.append(
|
||||
{"file_bytes": file_bytes, "mime_type": mime_type}
|
||||
)
|
||||
else:
|
||||
file_uri = self._upload_file_to_google(attachment)
|
||||
if not file_uri:
|
||||
raise ValueError(
|
||||
f"Google Files API returned empty URI for "
|
||||
f"{attachment.get('path', 'unknown')}"
|
||||
)
|
||||
logging.info(
|
||||
f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}"
|
||||
)
|
||||
files.append({"file_uri": file_uri, "mime_type": mime_type})
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"GoogleLLM: Error uploading file: {e}", exc_info=True
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"GoogleLLM: Error processing attachment: {e}", exc_info=True
|
||||
)
|
||||
if "content" in attachment:
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
|
||||
}
|
||||
)
|
||||
if "content" in attachment:
|
||||
prepared_messages[user_message_index]["content"].append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
|
||||
}
|
||||
)
|
||||
if files:
|
||||
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
|
||||
prepared_messages[user_message_index]["content"].append({"files": files})
|
||||
@@ -112,7 +130,9 @@ class GoogleLLM(BaseLLM):
|
||||
Returns:
|
||||
str: Google AI file URI for the uploaded file.
|
||||
"""
|
||||
if "google_file_uri" in attachment:
|
||||
# Truthy check, not membership: a poisoned cache row of "" or
|
||||
# None must be treated as a miss and trigger a fresh upload.
|
||||
if attachment.get("google_file_uri"):
|
||||
return attachment["google_file_uri"]
|
||||
file_path = attachment.get("path")
|
||||
if not file_path:
|
||||
@@ -126,6 +146,10 @@ class GoogleLLM(BaseLLM):
|
||||
file=local_path
|
||||
).uri,
|
||||
)
|
||||
if not file_uri:
|
||||
raise ValueError(
|
||||
f"Google Files API upload returned empty URI for {file_path}"
|
||||
)
|
||||
|
||||
# Cache the Google file URI on the attachment row so we don't
|
||||
# re-upload on the next LLM call. Accept either a PG UUID
|
||||
@@ -159,6 +183,26 @@ class GoogleLLM(BaseLLM):
|
||||
logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _read_attachment_bytes(self, attachment):
    """
    Read attachment bytes from storage for inline transmission.

    Args:
        attachment (dict): Attachment dictionary; its ``path`` entry
            identifies the file in ``self.storage``.

    Returns:
        bytes: Raw file bytes.

    Raises:
        ValueError: If the attachment has no (or an empty) ``path``.
        FileNotFoundError: If storage reports the path does not exist.
    """
    file_path = attachment.get("path")
    if not file_path:
        raise ValueError("No file path provided in attachment")
    if not self.storage.file_exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    def _read_all(local_path, **kwargs):
        # Close the handle deterministically instead of relying on the
        # garbage collector; a bare ``open(...).read()`` leaks the file
        # descriptor until the object happens to be collected.
        with open(local_path, "rb") as handle:
            return handle.read()

    return self.storage.process_file(file_path, _read_all)
|
||||
|
||||
def _clean_messages_google(self, messages):
|
||||
"""
|
||||
Convert OpenAI format messages to Google AI format and collect system prompts.
|
||||
@@ -215,7 +259,7 @@ class GoogleLLM(BaseLLM):
|
||||
except (_json.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
cleaned_args = self._remove_null_values(args)
|
||||
thought_sig = tc.get("thought_signature")
|
||||
thought_sig = _decode_thought_signature(tc.get("thought_signature"))
|
||||
if thought_sig:
|
||||
parts.append(
|
||||
types.Part(
|
||||
@@ -279,7 +323,9 @@ class GoogleLLM(BaseLLM):
|
||||
name=item["function_call"]["name"],
|
||||
args=cleaned_args,
|
||||
),
|
||||
thoughtSignature=item["thought_signature"],
|
||||
thoughtSignature=_decode_thought_signature(
|
||||
item["thought_signature"]
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -298,12 +344,24 @@ class GoogleLLM(BaseLLM):
|
||||
)
|
||||
elif "files" in item:
|
||||
for file_data in item["files"]:
|
||||
parts.append(
|
||||
types.Part.from_uri(
|
||||
file_uri=file_data["file_uri"],
|
||||
mime_type=file_data["mime_type"],
|
||||
if "file_bytes" in file_data:
|
||||
parts.append(
|
||||
types.Part.from_bytes(
|
||||
data=file_data["file_bytes"],
|
||||
mime_type=file_data["mime_type"],
|
||||
)
|
||||
)
|
||||
elif file_data.get("file_uri"):
|
||||
parts.append(
|
||||
types.Part.from_uri(
|
||||
file_uri=file_data["file_uri"],
|
||||
mime_type=file_data["mime_type"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
"GoogleLLM: dropping file part with empty URI and no bytes"
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected content dictionary format:{item}"
|
||||
@@ -541,22 +599,6 @@ class GoogleLLM(BaseLLM):
|
||||
config.response_mime_type = "application/json"
|
||||
# Check if we have both tools and file attachments
|
||||
|
||||
has_attachments = False
|
||||
for message in messages:
|
||||
for part in message.parts:
|
||||
if hasattr(part, "file_data") and part.file_data is not None:
|
||||
has_attachments = True
|
||||
break
|
||||
if has_attachments:
|
||||
break
|
||||
messages_summary = self._summarize_messages_for_log(messages)
|
||||
logging.info(
|
||||
"GoogleLLM: Starting stream generation. Model: %s, Messages: %s, Has attachments: %s",
|
||||
model,
|
||||
messages_summary,
|
||||
has_attachments,
|
||||
)
|
||||
|
||||
response = client.models.generate_content_stream(
|
||||
model=model,
|
||||
contents=messages,
|
||||
|
||||
@@ -5,6 +5,8 @@ GROQ_BASE_URL = "https://api.groq.com/openai/v1"
|
||||
|
||||
|
||||
class GroqLLM(OpenAILLM):
|
||||
provider_name = "groq"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=api_key or settings.GROQ_API_KEY or settings.API_KEY,
|
||||
|
||||
@@ -10,6 +10,18 @@ from application.logging import build_stack_data
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Upper bound on the agent tool-call loop. An LLM that keeps asking for
# more tool calls (preview models, sparse tool results, under-specified
# prompts) could otherwise chain calls indefinitely so the stream never
# finalises. The original author notes that 25 mirrors Dify's default.
MAX_TOOL_ITERATIONS = 25

# System message appended once the cap is hit, telling the model to answer
# with what it already has instead of requesting further tools.
_FINALIZE_INSTRUCTION = (
    "You have made {limit} tool calls. Provide a final response to the "
    "user based on what you have, without making any additional tool calls."
).format(limit=MAX_TOOL_ITERATIONS)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""Represents a tool/function call from the LLM."""
|
||||
@@ -280,7 +292,26 @@ class LLMHandler(ABC):
|
||||
# Keep serialized function calls/responses so the compressor sees actions
|
||||
parts_text.append(str(item))
|
||||
elif "files" in item:
|
||||
parts_text.append(str(item))
|
||||
# Image attachments arrive with raw bytes / base64
|
||||
# inline (see GoogleLLM.prepare_messages_with_attachments).
|
||||
# ``str(item)`` would dump the whole byte/base64
|
||||
# blob into the compression prompt and bust the
|
||||
# compression LLM's input limit.
|
||||
files = item.get("files") or []
|
||||
descriptors = []
|
||||
if isinstance(files, list):
|
||||
for f in files:
|
||||
if isinstance(f, dict):
|
||||
descriptors.append(
|
||||
f.get("mime_type") or "file"
|
||||
)
|
||||
elif isinstance(f, str):
|
||||
descriptors.append(f)
|
||||
if not descriptors:
|
||||
descriptors = ["file"]
|
||||
parts_text.append(
|
||||
f"[attachment: {', '.join(descriptors)}]"
|
||||
)
|
||||
return "\n".join(parts_text)
|
||||
return ""
|
||||
|
||||
@@ -605,6 +636,10 @@ class LLMHandler(ABC):
|
||||
agent_id=getattr(agent, "agent_id", None),
|
||||
model_user_id=compression_user_id,
|
||||
)
|
||||
# Side-channel LLM tag — see ``orchestrator.py`` for rationale.
|
||||
compression_llm._token_usage_source = "compression"
|
||||
compression_llm._request_id = getattr(agent, "_request_id", None) \
|
||||
or getattr(getattr(agent, "llm", None), "_request_id", None)
|
||||
|
||||
# Create service without DB persistence capability
|
||||
compression_service = CompressionService(
|
||||
@@ -815,6 +850,79 @@ class LLMHandler(ABC):
|
||||
tools_dict, call, llm_class
|
||||
)
|
||||
if pause_info:
|
||||
# Headless (scheduled / webhook): synthesize a denial tool message
|
||||
# so the LLM finishes gracefully instead of stalling on a pause
|
||||
# nobody will resolve, then journal so the reconciler sees it.
|
||||
if pause_info.get("pause_type") == "headless_denied":
|
||||
deny_reason = pause_info.get(
|
||||
"deny_reason", "Tool blocked in headless mode."
|
||||
)
|
||||
args_str = (
|
||||
json.dumps(call.arguments)
|
||||
if isinstance(call.arguments, dict)
|
||||
else (call.arguments or "{}")
|
||||
)
|
||||
tool_call_obj = {
|
||||
"id": pause_info["call_id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": call.name,
|
||||
"arguments": args_str,
|
||||
},
|
||||
}
|
||||
if getattr(call, "thought_signature", None):
|
||||
tool_call_obj["thought_signature"] = call.thought_signature
|
||||
updated_messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [tool_call_obj],
|
||||
})
|
||||
denial_call = ToolCall(
|
||||
id=pause_info["call_id"],
|
||||
name=call.name,
|
||||
arguments=call.arguments,
|
||||
)
|
||||
updated_messages.append(
|
||||
self.create_tool_message(
|
||||
denial_call,
|
||||
f"Tool denied (headless): {deny_reason}",
|
||||
)
|
||||
)
|
||||
if hasattr(agent.tool_executor, "headless_denials"):
|
||||
agent.tool_executor.headless_denials.append(pause_info)
|
||||
from application.agents.tool_executor import (
|
||||
_mark_failed,
|
||||
_record_proposed,
|
||||
)
|
||||
|
||||
_record_proposed(
|
||||
pause_info["call_id"],
|
||||
pause_info["tool_name"],
|
||||
pause_info["action_name"],
|
||||
pause_info.get("arguments") or {},
|
||||
tool_id=pause_info.get("tool_id"),
|
||||
)
|
||||
_mark_failed(
|
||||
pause_info["call_id"],
|
||||
f"headless: {deny_reason}",
|
||||
)
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": pause_info["tool_name"],
|
||||
"call_id": pause_info["call_id"],
|
||||
"action_name": pause_info.get(
|
||||
"llm_name", pause_info["name"]
|
||||
),
|
||||
"arguments": pause_info["arguments"],
|
||||
"status": "denied",
|
||||
"error": deny_reason,
|
||||
"error_type": pause_info.get(
|
||||
"error_type", "tool_not_allowed"
|
||||
),
|
||||
},
|
||||
}
|
||||
continue
|
||||
# Yield pause event so the client knows this tool is waiting
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
@@ -915,7 +1023,9 @@ class LLMHandler(ABC):
|
||||
parsed = self.parse_response(response)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
iteration = 0
|
||||
while parsed.requires_tool_call:
|
||||
iteration += 1
|
||||
tool_handler_gen = self.handle_tool_calls(
|
||||
agent, parsed.tool_calls, tools_dict, messages
|
||||
)
|
||||
@@ -939,6 +1049,25 @@ class LLMHandler(ABC):
|
||||
}
|
||||
return ""
|
||||
|
||||
# Cap reached: force one final tool-less call so the stream
|
||||
# always ends with content rather than cutting off.
|
||||
if iteration >= MAX_TOOL_ITERATIONS:
|
||||
logger.warning(
|
||||
"agent tool loop hit cap (%d); forcing finalize",
|
||||
MAX_TOOL_ITERATIONS,
|
||||
)
|
||||
messages.append(
|
||||
{"role": "system", "content": _FINALIZE_INSTRUCTION},
|
||||
)
|
||||
response = agent.llm.gen(
|
||||
model=getattr(agent.llm, "model_id", None) or agent.model_id,
|
||||
messages=messages,
|
||||
tools=None,
|
||||
)
|
||||
parsed = self.parse_response(response)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
break
|
||||
|
||||
# ``agent.model_id`` is the registry id (a UUID for BYOM
|
||||
# records). Use the LLM's own model_id, which LLMCreator
|
||||
# already resolved to the upstream model name. Built-ins:
|
||||
@@ -954,7 +1083,12 @@ class LLMHandler(ABC):
|
||||
return parsed.content
|
||||
|
||||
def handle_streaming(
|
||||
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
|
||||
self,
|
||||
agent,
|
||||
response: Any,
|
||||
tools_dict: Dict,
|
||||
messages: List[Dict],
|
||||
_iteration: int = 0,
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle streaming response flow.
|
||||
@@ -1023,6 +1157,9 @@ class LLMHandler(ABC):
|
||||
}
|
||||
return
|
||||
|
||||
next_iteration = _iteration + 1
|
||||
cap_reached = next_iteration >= MAX_TOOL_ITERATIONS
|
||||
|
||||
# Check if context limit was reached during tool execution
|
||||
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
|
||||
# Add system message warning about context limit
|
||||
@@ -1035,16 +1172,32 @@ class LLMHandler(ABC):
|
||||
)
|
||||
})
|
||||
logger.info("Context limit reached - instructing agent to wrap up")
|
||||
elif cap_reached:
|
||||
logger.warning(
|
||||
"agent tool loop hit cap (%d); forcing finalize",
|
||||
MAX_TOOL_ITERATIONS,
|
||||
)
|
||||
messages.append(
|
||||
{"role": "system", "content": _FINALIZE_INSTRUCTION},
|
||||
)
|
||||
|
||||
# See note above on agent.model_id vs llm.model_id.
|
||||
response = agent.llm.gen_stream(
|
||||
model=getattr(agent.llm, "model_id", None) or agent.model_id,
|
||||
messages=messages,
|
||||
tools=agent.tools if not agent.context_limit_reached else None,
|
||||
tools=(
|
||||
None
|
||||
if cap_reached
|
||||
or getattr(agent, "context_limit_reached", False)
|
||||
else agent.tools
|
||||
),
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
yield from self.handle_streaming(agent, response, tools_dict, messages)
|
||||
yield from self.handle_streaming(
|
||||
agent, response, tools_dict, messages,
|
||||
_iteration=next_iteration,
|
||||
)
|
||||
return
|
||||
if parsed.content:
|
||||
buffer += parsed.content
|
||||
|
||||
@@ -1,9 +1,35 @@
|
||||
import base64
|
||||
import binascii
|
||||
import uuid
|
||||
from typing import Any, Dict, Generator
|
||||
from typing import Any, Dict, Generator, Optional, Union
|
||||
|
||||
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
|
||||
|
||||
|
||||
def _encode_thought_signature(sig: Optional[Union[bytes, str]]) -> Optional[str]:
    """Return ``sig`` in a JSON-safe form.

    ``bytes`` input is base64-encoded to an ASCII ``str``; any other value
    (``str``, ``None``, ...) is returned unchanged. Encoding once where the
    signature enters the system lets downstream code that ``json.dumps``
    the value treat it as a plain string.
    """
    if not isinstance(sig, bytes):
        return sig
    encoded = base64.b64encode(sig)
    return encoded.decode("ascii")
|
||||
|
||||
|
||||
def _decode_thought_signature(
    sig: Optional[Union[bytes, str]],
) -> Optional[Union[bytes, str]]:
    """Inverse of ``_encode_thought_signature``.

    A ``str`` that is strictly valid base64 is decoded to ``bytes``. A
    ``str`` that is not (non-ASCII text, characters outside the base64
    alphabet, bad padding) is returned unchanged, as is any non-``str``
    input. ``validate=True`` is what keeps loosely decodable ASCII text
    from being silently turned into bytes.
    """
    if not isinstance(sig, str):
        return sig
    try:
        raw = sig.encode("ascii")
        decoded = base64.b64decode(raw, validate=True)
    except (binascii.Error, ValueError):
        # Not base64 (or not ASCII at all): pass the original through.
        return sig
    else:
        return decoded
|
||||
|
||||
|
||||
class GoogleLLMHandler(LLMHandler):
|
||||
"""Handler for Google's GenAI API."""
|
||||
|
||||
@@ -23,7 +49,7 @@ class GoogleLLMHandler(LLMHandler):
|
||||
for idx, part in enumerate(parts):
|
||||
if hasattr(part, "function_call") and part.function_call is not None:
|
||||
has_sig = hasattr(part, "thought_signature") and part.thought_signature is not None
|
||||
thought_sig = part.thought_signature if has_sig else None
|
||||
thought_sig = _encode_thought_signature(part.thought_signature) if has_sig else None
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
@@ -50,7 +76,7 @@ class GoogleLLMHandler(LLMHandler):
|
||||
tool_calls = []
|
||||
if hasattr(response, "function_call") and response.function_call is not None:
|
||||
has_sig = hasattr(response, "thought_signature") and response.thought_signature is not None
|
||||
thought_sig = response.thought_signature if has_sig else None
|
||||
thought_sig = _encode_thought_signature(response.thought_signature) if has_sig else None
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
@@ -70,8 +96,15 @@ class GoogleLLMHandler(LLMHandler):
|
||||
"""Create a tool result message in the standard internal format."""
|
||||
import json as _json
|
||||
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
# PostgresTool results commonly include PG-native types
|
||||
# (datetime / UUID / Decimal / bytea) when SELECT touches
|
||||
# timestamptz / numeric / uuid / bytea columns. The shared
|
||||
# encoder handles all five — bytes get base64 (lossless) instead
|
||||
# of the ``str(b'...')`` repr that ``default=str`` would emit.
|
||||
content = (
|
||||
_json.dumps(result)
|
||||
_json.dumps(result, cls=PGNativeJSONEncoder)
|
||||
if not isinstance(result, str)
|
||||
else result
|
||||
)
|
||||
|
||||
@@ -40,8 +40,15 @@ class OpenAILLMHandler(LLMHandler):
|
||||
"""Create a tool result message in the standard internal format."""
|
||||
import json as _json
|
||||
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
# PostgresTool results commonly include PG-native types
|
||||
# (datetime / UUID / Decimal / bytea) when SELECT touches
|
||||
# timestamptz / numeric / uuid / bytea columns. The shared
|
||||
# encoder handles all five — bytes get base64 (lossless) instead
|
||||
# of the ``str(b'...')`` repr that ``default=str`` would emit.
|
||||
content = (
|
||||
_json.dumps(result)
|
||||
_json.dumps(result, cls=PGNativeJSONEncoder)
|
||||
if not isinstance(result, str)
|
||||
else result
|
||||
)
|
||||
|
||||
@@ -26,6 +26,8 @@ class LlamaSingleton:
|
||||
|
||||
|
||||
class LlamaCpp(BaseLLM):
|
||||
provider_name = "llama_cpp"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key=None,
|
||||
|
||||
@@ -5,6 +5,8 @@ NOVITA_BASE_URL = "https://api.novita.ai/openai"
|
||||
|
||||
|
||||
class NovitaLLM(OpenAILLM):
|
||||
provider_name = "novita"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=api_key or settings.NOVITA_API_KEY or settings.API_KEY,
|
||||
|
||||
@@ -5,6 +5,8 @@ OPEN_ROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
class OpenRouterLLM(OpenAILLM):
|
||||
provider_name = "openrouter"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
|
||||
super().__init__(
|
||||
api_key=api_key or settings.OPEN_ROUTER_API_KEY or settings.API_KEY,
|
||||
|
||||
@@ -61,6 +61,7 @@ def _truncate_base64_for_logging(messages):
|
||||
|
||||
|
||||
class OpenAILLM(BaseLLM):
|
||||
provider_name = "openai"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -3,6 +3,7 @@ from application.core.settings import settings
|
||||
|
||||
|
||||
class PremAILLM(BaseLLM):
|
||||
provider_name = "premai"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
from premai import Prem
|
||||
|
||||
@@ -59,6 +59,7 @@ class LineIterator:
|
||||
|
||||
|
||||
class SagemakerAPILLM(BaseLLM):
|
||||
provider_name = "sagemaker"
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
import boto3
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import datetime
|
||||
import functools
|
||||
import inspect
|
||||
import time
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, Generator, List
|
||||
|
||||
from application.core import log_context
|
||||
from application.storage.db.repositories.stack_logs import StackLogsRepository
|
||||
from application.storage.db.session import db_session
|
||||
|
||||
@@ -22,6 +24,15 @@ class LogContext:
|
||||
self.api_key = api_key
|
||||
self.query = query
|
||||
self.stacks = []
|
||||
# Per-activity response aggregates populated by ``_consume_and_log``
|
||||
# while it forwards stream items, then flushed onto the
|
||||
# ``activity_finished`` event so every Flask request gets the
|
||||
# same summary that ``run_agent_logic`` used to log only for the
|
||||
# Celery webhook path.
|
||||
self.answer_length = 0
|
||||
self.thought_length = 0
|
||||
self.source_count = 0
|
||||
self.tool_call_count = 0
|
||||
|
||||
|
||||
def build_stack_data(
|
||||
@@ -78,25 +89,125 @@ def log_activity() -> Callable:
|
||||
user = data.get("user", "local")
|
||||
api_key = data.get("user_api_key", "")
|
||||
query = kwargs.get("query", getattr(args[0], "query", ""))
|
||||
agent_id = getattr(args[0], "agent_id", None) or kwargs.get("agent_id")
|
||||
conversation_id = (
|
||||
kwargs.get("conversation_id")
|
||||
or getattr(args[0], "conversation_id", None)
|
||||
)
|
||||
model = getattr(args[0], "gpt_model", None) or getattr(args[0], "model", None)
|
||||
|
||||
# Capture the surrounding activity_id before overlaying ours,
|
||||
# so nested activities record the parent → child link.
|
||||
parent_activity_id = log_context.snapshot().get("activity_id")
|
||||
|
||||
context = LogContext(endpoint, activity_id, user, api_key, query)
|
||||
kwargs["log_context"] = context
|
||||
|
||||
logging.info(
|
||||
f"Starting activity: {endpoint} - {activity_id} - User: {user}"
|
||||
ctx_token = log_context.bind(
|
||||
activity_id=activity_id,
|
||||
parent_activity_id=parent_activity_id,
|
||||
user_id=user,
|
||||
agent_id=agent_id,
|
||||
conversation_id=conversation_id,
|
||||
endpoint=endpoint,
|
||||
model=model,
|
||||
)
|
||||
|
||||
generator = func(*args, **kwargs)
|
||||
yield from _consume_and_log(generator, context)
|
||||
started_at = time.monotonic()
|
||||
logging.info(
|
||||
"activity_started",
|
||||
extra={
|
||||
"activity_id": activity_id,
|
||||
"parent_activity_id": parent_activity_id,
|
||||
"user_id": user,
|
||||
"agent_id": agent_id,
|
||||
"conversation_id": conversation_id,
|
||||
"endpoint": endpoint,
|
||||
"model": model,
|
||||
},
|
||||
)
|
||||
|
||||
error: BaseException | None = None
|
||||
try:
|
||||
generator = func(*args, **kwargs)
|
||||
yield from _consume_and_log(generator, context)
|
||||
except Exception as exc:
|
||||
# Only ``Exception`` counts as an activity error; ``GeneratorExit``
|
||||
# (consumer disconnected mid-stream) and ``KeyboardInterrupt``
|
||||
# flow through the finally as ``status="ok"``, matching
|
||||
# ``_consume_and_log``.
|
||||
error = exc
|
||||
raise
|
||||
finally:
|
||||
_emit_activity_finished(
|
||||
context=context,
|
||||
parent_activity_id=parent_activity_id,
|
||||
started_at=started_at,
|
||||
error=error,
|
||||
)
|
||||
log_context.reset(ctx_token)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _emit_activity_finished(
|
||||
*,
|
||||
context: "LogContext",
|
||||
parent_activity_id: str | None,
|
||||
started_at: float,
|
||||
error: BaseException | None,
|
||||
) -> None:
|
||||
"""Emit the paired ``activity_finished`` event with duration, outcome,
|
||||
and per-activity response aggregates accumulated in ``_consume_and_log``.
|
||||
"""
|
||||
duration_ms = int((time.monotonic() - started_at) * 1000)
|
||||
logging.info(
|
||||
"activity_finished",
|
||||
extra={
|
||||
"activity_id": context.activity_id,
|
||||
"parent_activity_id": parent_activity_id,
|
||||
"user_id": context.user,
|
||||
"endpoint": context.endpoint,
|
||||
"duration_ms": duration_ms,
|
||||
"status": "error" if error is not None else "ok",
|
||||
"error_class": type(error).__name__ if error is not None else None,
|
||||
"answer_length": context.answer_length,
|
||||
"thought_length": context.thought_length,
|
||||
"source_count": context.source_count,
|
||||
"tool_call_count": context.tool_call_count,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _accumulate_response_summary(item: Any, context: "LogContext") -> None:
    """Fold one streamed item into the running totals held on ``context``.

    Non-dict items are ignored. For a dict, exactly one counter is updated,
    chosen by the first rule that matches:

    1. ``"answer"`` key present: add ``len(str(value))`` to ``answer_length``.
    2. ``"thought"`` key present: add ``len(str(value))`` to ``thought_length``.
    3. ``"sources"`` is a list: add its length to ``source_count``.
    4. ``"tool_calls"`` is a list: add its length to ``tool_call_count``.

    A non-list ``"sources"`` value does not stop rule 4 from applying.
    """
    if not isinstance(item, dict):
        return

    if "answer" in item:
        context.answer_length += len(str(item["answer"]))
    elif "thought" in item:
        context.thought_length += len(str(item["thought"]))
    elif isinstance(item.get("sources"), list):
        context.source_count += len(item["sources"])
    elif isinstance(item.get("tool_calls"), list):
        context.tool_call_count += len(item["tool_calls"])
|
||||
|
||||
|
||||
def _consume_and_log(generator: Generator, context: "LogContext"):
|
||||
try:
|
||||
for item in generator:
|
||||
_accumulate_response_summary(item, context)
|
||||
yield item
|
||||
except Exception as e:
|
||||
logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}")
|
||||
|
||||
@@ -1,12 +1,28 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import List, Any
|
||||
from typing import Any, List, Optional
|
||||
from retry import retry
|
||||
from tqdm import tqdm
|
||||
from application.core.settings import settings
|
||||
from application.events.publisher import publish_user_event
|
||||
from application.storage.db.repositories.ingest_chunk_progress import (
|
||||
IngestChunkProgressRepository,
|
||||
)
|
||||
from application.storage.db.session import db_session
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
class EmbeddingPipelineError(Exception):
    """Signal that the per-chunk embed loop left a partial index behind.

    The error is meant to propagate into Celery's ``autoretry_for`` so a
    transient cause (rate limit, network blip) is retried. Retries are
    cheap because the chunk-progress checkpoint means only the failed
    chunk and those after it re-run. Once ``MAX_TASK_ATTEMPTS`` is
    reached, the poison-loop guard in ``with_idempotency`` finalises the
    row as ``failed``.
    """
|
||||
|
||||
|
||||
def sanitize_content(content: str) -> str:
|
||||
"""
|
||||
Remove NUL characters that can cause vector store ingestion to fail.
|
||||
@@ -22,7 +38,11 @@ def sanitize_content(content: str) -> str:
|
||||
return content.replace('\x00', '')
|
||||
|
||||
|
||||
@retry(tries=10, delay=60)
|
||||
# Per-chunk inline retry. Aggressive defaults (tries=10, delay=60) blocked
|
||||
# the loop for up to 9 min per chunk and wedged the heartbeat: lower the
|
||||
# tail so a transient failure fails-fast and the chunk-progress checkpoint
|
||||
# resumes cleanly on next dispatch.
|
||||
@retry(tries=3, delay=5, backoff=2)
|
||||
def add_text_to_store_with_retry(store: Any, doc: Any, source_id: str) -> None:
|
||||
"""Add a document's text and metadata to the vector store with retry logic.
|
||||
|
||||
@@ -45,21 +65,131 @@ def add_text_to_store_with_retry(store: Any, doc: Any, source_id: str) -> None:
|
||||
raise
|
||||
|
||||
|
||||
def embed_and_store_documents(docs: List[Any], folder_name: str, source_id: str, task_status: Any) -> None:
|
||||
def _init_progress_and_resume_index(
    source_id: str, total_chunks: int, attempt_id: Optional[str],
) -> int:
    """Initialise the progress row and report where embedding should resume.

    Calls ``IngestChunkProgressRepository.init_progress`` with the given
    ``source_id``, ``total_chunks`` and ``attempt_id`` and returns
    ``last_index + 1`` from the returned row. Returns ``0`` when the row
    is empty, when ``last_index`` is missing, ``None`` or negative, or
    when the lookup fails.

    Per the original author's note, the repository keeps ``last_index``
    only when ``attempt_id`` matches the stored one (a retry of the same
    task) and resets the checkpoint otherwise, so a finished earlier run
    cannot turn the next sync into a no-op.

    Best-effort: any exception from the DB call is logged as a warning
    and treated as a fresh run from chunk 0.
    """
    try:
        with db_session() as conn:
            repo = IngestChunkProgressRepository(conn)
            row = repo.init_progress(source_id, total_chunks, attempt_id)
    except Exception as exc:
        logging.warning(
            f"Could not init ingest progress for {source_id}: {exc}",
            exc_info=True,
        )
        return 0

    checkpoint = row.get("last_index", -1) if row else None
    if checkpoint is None or checkpoint < 0:
        return 0
    return int(checkpoint) + 1
|
||||
|
||||
|
||||
def _record_progress(source_id: str, last_index: int, embedded_chunks: int) -> None:
    """Persist a per-chunk checkpoint; failures are logged and swallowed.

    Args:
        source_id: Identifier of the source being ingested.
        last_index: Index of the most recent chunk that was embedded.
        embedded_chunks: Count of chunks embedded so far.
    """
    try:
        with db_session() as conn:
            repo = IngestChunkProgressRepository(conn)
            repo.record_chunk(
                source_id,
                last_index=last_index,
                embedded_chunks=embedded_chunks,
            )
    except Exception as e:
        # Checkpointing is advisory — never let it abort the embed loop.
        logging.warning(
            f"Could not record ingest progress for {source_id}: {e}", exc_info=True
        )
|
||||
|
||||
|
||||
def assert_index_complete(source_id: str) -> None:
    """Fail loudly when the progress row records fewer chunks than expected.

    Intended as a post-embed tripwire: it compares ``embedded_chunks``
    with ``total_chunks`` from ``ingest_chunk_progress`` for ``source_id``.
    The check is skipped (no exception) when the lookup itself fails or
    when no row exists.

    Args:
        source_id: Identifier of the source whose index is being checked.

    Raises:
        EmbeddingPipelineError: If ``embedded_chunks < total_chunks``.
    """
    try:
        with db_session() as conn:
            row = IngestChunkProgressRepository(conn).get_progress(source_id)
    except Exception as e:
        # An unreachable progress store must not fail an otherwise good run.
        logging.warning(
            f"assert_index_complete: progress lookup failed for "
            f"{source_id}: {e}",
            exc_info=True,
        )
        return

    if not row:
        return

    # Missing / NULL counters are treated as zero.
    done = int(row.get("embedded_chunks") or 0)
    expected = int(row.get("total_chunks") or 0)
    if done >= expected:
        return
    raise EmbeddingPipelineError(
        f"partial index for source {source_id}: "
        f"{done}/{expected} chunks embedded"
    )
|
||||
|
||||
|
||||
def embed_and_store_documents(
|
||||
docs: List[Any],
|
||||
folder_name: str,
|
||||
source_id: str,
|
||||
task_status: Any,
|
||||
*,
|
||||
attempt_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
progress_start: int = 0,
|
||||
progress_end: int = 100,
|
||||
) -> None:
|
||||
"""Embeds documents and stores them in a vector store.
|
||||
|
||||
Resumable across Celery autoretries of the *same* task: when
|
||||
``attempt_id`` matches the stored checkpoint's ``attempt_id``,
|
||||
the loop resumes from ``last_index + 1``. A different
|
||||
``attempt_id`` (a fresh sync / reingest invocation) resets the
|
||||
checkpoint so the index is rebuilt from chunk 0 — this is what
|
||||
keeps a completed checkpoint from poisoning the next sync.
|
||||
|
||||
Args:
|
||||
docs: List of documents to be embedded and stored.
|
||||
folder_name: Directory to save the vector store.
|
||||
source_id: Unique identifier for the source.
|
||||
task_status: Task state manager for progress updates.
|
||||
attempt_id: Stable id of the current task invocation,
|
||||
typically ``self.request.id`` from the Celery task body.
|
||||
``None`` is treated as a fresh attempt every time.
|
||||
user_id: When provided, per-percent SSE progress events are
|
||||
published to ``user:{user_id}`` for the in-app upload toast.
|
||||
``None`` is the safe default — workers without a user
|
||||
context (e.g. background syncs) skip the publish.
|
||||
progress_start: Percent the reported progress maps to at chunk 0.
|
||||
Lets a caller reserve the lower band for an earlier stage
|
||||
(e.g. parsing). Defaults to ``0`` (embed owns the whole bar).
|
||||
progress_end: Percent the reported progress maps to at the final
|
||||
chunk. Defaults to ``100``.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
|
||||
Raises:
|
||||
OSError: If unable to create folder or save vector store.
|
||||
Exception: If vector store creation or document embedding fails.
|
||||
EmbeddingPipelineError: If a chunk fails after retries.
|
||||
"""
|
||||
# Ensure the folder exists
|
||||
if not os.path.exists(folder_name):
|
||||
@@ -69,41 +199,111 @@ def embed_and_store_documents(docs: List[Any], folder_name: str, source_id: str,
|
||||
if not docs:
|
||||
raise ValueError("No documents to embed - check file format and extension")
|
||||
|
||||
total_docs = len(docs)
|
||||
# Atomic upsert that preserves checkpoint state on attempt-id match
|
||||
# (autoretry of same task) and resets it on mismatch (fresh sync /
|
||||
# reingest). Returns the new resume index — 0 means "start fresh".
|
||||
resume_index = _init_progress_and_resume_index(
|
||||
source_id, total_docs, attempt_id,
|
||||
)
|
||||
is_resume = resume_index > 0
|
||||
|
||||
# Initialize vector store
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
docs_init = [docs.pop(0)]
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
docs_init=docs_init,
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
if is_resume:
|
||||
# Load the existing FAISS index from storage so chunks
|
||||
# already embedded by the prior attempt survive the
|
||||
# save_local rewrite at the end of this run.
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
loop_start = resume_index
|
||||
else:
|
||||
# FAISS requires at least one doc to construct the store;
|
||||
# seed with ``docs[0]`` and let the loop pick up at index 1.
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
docs_init=[docs[0]],
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
# Record the seeded chunk so single-doc ingests don't fail
|
||||
# ``assert_index_complete`` — the loop never runs for
|
||||
# ``total_docs == 1`` and would otherwise leave
|
||||
# ``embedded_chunks`` at 0 / ``last_index`` at -1. The loop
|
||||
# body's per-iteration ``_record_progress`` overshoots
|
||||
# correctly for multi-chunk runs (counts seed + iterations),
|
||||
# so writing this checkpoint up-front is a no-op for those.
|
||||
_record_progress(source_id, last_index=0, embedded_chunks=1)
|
||||
loop_start = 1
|
||||
else:
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
store.delete_index()
|
||||
# Only wipe the index on a fresh run — a resume must keep the
|
||||
# chunks that earlier attempts already embedded.
|
||||
if not is_resume:
|
||||
store.delete_index()
|
||||
loop_start = resume_index
|
||||
|
||||
total_docs = len(docs)
|
||||
if is_resume and loop_start >= total_docs:
|
||||
# Nothing left to do; the loop runs zero iterations and
|
||||
# downstream finalize logic still executes. This is only
|
||||
# reachable on a same-attempt retry of a task whose previous
|
||||
# attempt finished — typically a Celery acks_late redelivery
|
||||
# after the task already returned. The ``assert_index_complete``
|
||||
# tripwire still validates ``embedded == total`` afterwards.
|
||||
loop_start = total_docs
|
||||
|
||||
# Process and embed documents
|
||||
for idx, doc in tqdm(
|
||||
enumerate(docs),
|
||||
chunk_error: Exception | None = None
|
||||
failed_idx: int | None = None
|
||||
last_published_pct = -1
|
||||
source_id_str = str(source_id)
|
||||
progress_span = progress_end - progress_start
|
||||
for idx in tqdm(
|
||||
range(loop_start, total_docs),
|
||||
desc="Embedding 🦖",
|
||||
unit="docs",
|
||||
total=total_docs,
|
||||
total=total_docs - loop_start,
|
||||
bar_format="{l_bar}{bar}| Time Left: {remaining}",
|
||||
):
|
||||
doc = docs[idx]
|
||||
try:
|
||||
# Update task status for progress tracking
|
||||
progress = int(((idx + 1) / total_docs) * 100)
|
||||
# Map the embed loop into [progress_start, progress_end].
|
||||
progress = progress_start + int(
|
||||
((idx + 1) / total_docs) * progress_span
|
||||
)
|
||||
task_status.update_state(state="PROGRESS", meta={"current": progress})
|
||||
|
||||
# SSE push for sub-second upload-toast updates. Throttled to one
|
||||
# event per percent so a 10k-chunk ingest emits ~100 events,
|
||||
# not 10k. The Celery update_state above stays the source of
|
||||
# truth for the polling-fallback path.
|
||||
if user_id and progress > last_published_pct:
|
||||
publish_user_event(
|
||||
user_id,
|
||||
"source.ingest.progress",
|
||||
{
|
||||
"current": progress,
|
||||
"total": total_docs,
|
||||
"embedded_chunks": idx + 1,
|
||||
"stage": "embedding",
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_str},
|
||||
)
|
||||
last_published_pct = progress
|
||||
|
||||
# Add document to vector store
|
||||
add_text_to_store_with_retry(store, doc, source_id)
|
||||
_record_progress(source_id, last_index=idx, embedded_chunks=idx + 1)
|
||||
except Exception as e:
|
||||
chunk_error = e
|
||||
failed_idx = idx
|
||||
logging.error(f"Error embedding document {idx}: {e}", exc_info=True)
|
||||
logging.info(f"Saving progress at document {idx} out of {total_docs}")
|
||||
try:
|
||||
@@ -124,3 +324,16 @@ def embed_and_store_documents(docs: List[Any], folder_name: str, source_id: str,
|
||||
raise OSError(f"Unable to save vector store to {folder_name}: {e}") from e
|
||||
else:
|
||||
logging.info("Vector store saved successfully.")
|
||||
|
||||
# Re-raise after the partial save: the chunks that *did* embed are
|
||||
# flushed to disk and recorded in ``ingest_chunk_progress``, so a
|
||||
# Celery autoretry resumes via ``_init_progress_and_resume_index`` and only
|
||||
# re-runs the failed-and-after chunks. Without the raise, the
|
||||
# task body returns success and ``with_idempotency`` finalises
|
||||
# ``task_dedup`` as ``completed`` for a partial index — poisoning
|
||||
# the cache for 24h.
|
||||
if chunk_error is not None:
|
||||
raise EmbeddingPipelineError(
|
||||
f"embed failure at chunk {failed_idx}/{total_docs} "
|
||||
f"for source {source_id}"
|
||||
) from chunk_error
|
||||
|
||||
@@ -211,13 +211,22 @@ class SimpleDirectoryReader(BaseReader):
|
||||
|
||||
return new_input_files
|
||||
|
||||
def load_data(self, concatenate: bool = False) -> List[Document]:
|
||||
def load_data(
|
||||
self,
|
||||
concatenate: bool = False,
|
||||
progress_callback: Optional[Callable[[int, int], None]] = None,
|
||||
) -> List[Document]:
|
||||
"""Load data from the input directory.
|
||||
|
||||
Args:
|
||||
concatenate (bool): whether to concatenate all files into one document.
|
||||
If set to True, file metadata is ignored.
|
||||
False by default.
|
||||
progress_callback (Optional[Callable[[int, int], None]]): Called
|
||||
after each file is parsed with ``(files_done, total_files)``.
|
||||
Lets callers surface parse/OCR progress before embedding
|
||||
begins. Exceptions raised by the callback are swallowed so
|
||||
progress reporting can never fail ingestion.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of documents.
|
||||
@@ -226,8 +235,9 @@ class SimpleDirectoryReader(BaseReader):
|
||||
data_list: List[str] = []
|
||||
metadata_list = []
|
||||
self.file_token_counts = {}
|
||||
|
||||
for input_file in self.input_files:
|
||||
|
||||
total_files = len(self.input_files)
|
||||
for file_index, input_file in enumerate(self.input_files):
|
||||
suffix_lower = input_file.suffix.lower()
|
||||
parser_metadata = {}
|
||||
if suffix_lower in self.file_extractor:
|
||||
@@ -277,7 +287,15 @@ class SimpleDirectoryReader(BaseReader):
|
||||
else:
|
||||
data_list.append(str(data))
|
||||
metadata_list.append(base_metadata)
|
||||
|
||||
|
||||
if progress_callback is not None:
|
||||
try:
|
||||
progress_callback(file_index + 1, total_files)
|
||||
except Exception:
|
||||
logging.warning(
|
||||
"load_data progress callback failed", exc_info=True
|
||||
)
|
||||
|
||||
# Build directory structure if input_dir is provided
|
||||
if hasattr(self, 'input_dir'):
|
||||
self.directory_structure = self.build_directory_structure(self.input_dir)
|
||||
|
||||
@@ -16,6 +16,29 @@ from application.parser.file.base_parser import BaseParser
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Per-stage batch size for docling's threaded pipeline; 1 holds the
|
||||
# concurrent working set to a single page (see _apply_pipeline_caps).
|
||||
_PIPELINE_BATCH_SIZE = 1
|
||||
|
||||
|
||||
def _apply_pipeline_caps(pipeline_options) -> None:
    """Clamp docling threaded-pipeline queue depth and batch sizes in place.

    Each knob is written only when ``pipeline_options`` already exposes an
    attribute of that name, so docling builds lacking a knob are left
    untouched.

    Args:
        pipeline_options: The docling pipeline options object to mutate.
    """
    from application.core.settings import settings

    # Queue depth comes from settings but is never allowed below 1.
    queue_depth = max(1, settings.DOCLING_PIPELINE_QUEUE_MAX_SIZE)
    overrides = (
        ("queue_max_size", queue_depth),
        ("layout_batch_size", _PIPELINE_BATCH_SIZE),
        ("table_batch_size", _PIPELINE_BATCH_SIZE),
        ("ocr_batch_size", _PIPELINE_BATCH_SIZE),
    )
    for attr, cap in overrides:
        if not hasattr(pipeline_options, attr):
            continue
        setattr(pipeline_options, attr, cap)
|
||||
|
||||
|
||||
class DoclingParser(BaseParser):
|
||||
"""Parser using docling for advanced document processing.
|
||||
|
||||
@@ -86,6 +109,7 @@ class DoclingParser(BaseParser):
|
||||
do_ocr=self.ocr_enabled,
|
||||
do_table_structure=self.table_structure,
|
||||
)
|
||||
_apply_pipeline_caps(pipeline_options)
|
||||
|
||||
if self.ocr_enabled:
|
||||
ocr_options = self._get_ocr_options()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import json
|
||||
|
||||
from application.parser.remote.sitemap_loader import SitemapLoader
|
||||
from application.parser.remote.crawler_loader import CrawlerLoader
|
||||
from application.parser.remote.web_loader import WebLoader
|
||||
@@ -32,3 +34,59 @@ class RemoteCreator:
|
||||
if not loader_class:
|
||||
raise ValueError(f"No loader class found for type {type}")
|
||||
return loader_class(*args, **kwargs)
|
||||
|
||||
|
||||
# Loader types whose load_data expects a URL string, not a config dict.
_URL_LOADER_TYPES = {"url", "crawler", "sitemap", "github"}

# Keys a remote_data dict may hold the URL under (``raw`` is the legacy shape).
_URL_DATA_KEYS = ("url", "urls", "repo_url", "raw")


def normalize_remote_data(source_type, remote_data):
    """Convert a stored ``sources.remote_data`` JSONB value into the
    ``source_data`` shape the matching loader expects.

    Args:
        source_type: The ``sources.type`` value (the loader name). Matched
            case-insensitively; ``None`` is treated as an unknown loader.
        remote_data: The stored ``remote_data`` (dict, list, str, or None).

    Returns:
        Loader input: a URL string or list for url/crawler/sitemap/github,
        a JSON string for reddit, and the (possibly JSON-decoded) value
        unchanged for every other loader such as s3. ``None`` when the row
        has nothing syncable: ``remote_data`` is ``None``, is a blank
        string, or is a dict without any usable URL key for a URL loader.
    """
    if remote_data is None:
        return None

    # Some legacy rows stored the JSON itself as a string.
    if isinstance(remote_data, str):
        stripped = remote_data.strip()
        if not stripped:
            # Blank / whitespace-only text carries nothing to sync. Report
            # it the same way as a missing value instead of handing an
            # empty string to a loader.
            return None
        if stripped[:1] in ("{", "["):
            try:
                remote_data = json.loads(stripped)
            except json.JSONDecodeError:
                # Not actually JSON — leave remote_data as the original
                # string; the per-loader branches below handle a string.
                pass

    loader = (source_type or "").lower()

    if loader in _URL_LOADER_TYPES:
        if isinstance(remote_data, dict):
            # First non-empty value wins, in ``_URL_DATA_KEYS`` order.
            for key in _URL_DATA_KEYS:
                value = remote_data.get(key)
                if value:
                    return value
            # No URL key — None keeps the loader off the dict-crash path.
            return None
        return remote_data

    if loader == "reddit":
        # reddit's loader runs json.loads() on its input — needs a string.
        if isinstance(remote_data, (dict, list)):
            return json.dumps(remote_data)
        return remote_data

    # s3's loader accepts a dict or JSON string; no further reshaping here.
    return remote_data
|
||||
|
||||
@@ -7,6 +7,7 @@ beautifulsoup4==4.14.3
|
||||
cel-python==0.5.0
|
||||
celery==5.6.3
|
||||
celery-redbeat==2.3.3
|
||||
croniter==6.2.2
|
||||
cryptography==46.0.7
|
||||
dataclasses-json==0.6.7
|
||||
defusedxml==0.7.1
|
||||
|
||||
@@ -60,6 +60,9 @@ class ClassicRAG(BaseRetriever):
|
||||
agent_id=self.agent_id,
|
||||
model_user_id=self.model_user_id,
|
||||
)
|
||||
# Query-rephrase LLM is a side channel — tag it so its rows
|
||||
# land as ``source='rag_condense'`` in cost-attribution.
|
||||
self.llm._token_usage_source = "rag_condense"
|
||||
|
||||
if "active_docs" in source and source["active_docs"] is not None:
|
||||
if isinstance(source["active_docs"], list):
|
||||
|
||||
@@ -11,6 +11,8 @@ import re
|
||||
from typing import Any, Mapping
|
||||
from uuid import UUID
|
||||
|
||||
from application.storage.db.serialization import coerce_pg_native
|
||||
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
|
||||
@@ -34,12 +36,17 @@ def looks_like_uuid(value: Any) -> bool:
|
||||
|
||||
|
||||
def row_to_dict(row: Any) -> dict:
|
||||
"""Convert a SQLAlchemy ``Row`` to a plain dict with Mongo-compatible ids.
|
||||
"""Convert a SQLAlchemy ``Row`` to a plain JSON-safe dict.
|
||||
|
||||
During the migration window, API responses and downstream code still
|
||||
expect a string ``_id`` field (matching the Mongo shape). This helper
|
||||
normalizes UUID columns to strings and emits both ``id`` and ``_id`` so
|
||||
existing serializers keep working unchanged.
|
||||
Normalises PG-native types at the SELECT boundary: UUID, datetime,
|
||||
date, Decimal, and bytes are coerced to JSON-safe forms via
|
||||
:func:`coerce_pg_native`. Downstream serialisation (SSE events,
|
||||
JSONB writes, API responses) becomes safe by default — repository
|
||||
consumers no longer need to know that PG returns a different type
|
||||
set than Mongo did.
|
||||
|
||||
Also emits ``_id`` alongside ``id`` for the duration of the Mongo→PG
|
||||
cutover so legacy serializers expecting Mongo's shape keep working.
|
||||
|
||||
Args:
|
||||
row: A SQLAlchemy ``Row`` object, or ``None``.
|
||||
@@ -52,10 +59,9 @@ def row_to_dict(row: Any) -> dict:
|
||||
|
||||
# Row has a ``._mapping`` attribute exposing a MappingProxy view.
|
||||
mapping: Mapping[str, Any] = row._mapping # type: ignore[attr-defined]
|
||||
out = dict(mapping)
|
||||
out = coerce_pg_native(dict(mapping))
|
||||
|
||||
if "id" in out and out["id"] is not None:
|
||||
out["id"] = str(out["id"]) if isinstance(out["id"], UUID) else out["id"]
|
||||
out["_id"] = out["id"]
|
||||
|
||||
return out
|
||||
|
||||
@@ -34,7 +34,7 @@ from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, JSONB, UUID
|
||||
metadata = MetaData()
|
||||
|
||||
|
||||
# --- Phase 1, Tier 1 --------------------------------------------------------
|
||||
# --- Users, prompts, tools, logs --------------------------------------------
|
||||
|
||||
users_table = Table(
|
||||
"users",
|
||||
@@ -47,6 +47,7 @@ users_table = Table(
|
||||
nullable=False,
|
||||
server_default='{"pinned": [], "shared_with_me": []}',
|
||||
),
|
||||
Column("tool_preferences", JSONB, nullable=False, server_default="{}"),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
@@ -91,6 +92,16 @@ token_usage_table = Table(
|
||||
Column("prompt_tokens", Integer, nullable=False, server_default="0"),
|
||||
Column("generated_tokens", Integer, nullable=False, server_default="0"),
|
||||
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
# Added in ``0004_durability_foundation``. Distinguishes
|
||||
# ``agent_stream`` (primary completion) from side-channel inserts
|
||||
# (``title`` / ``compression`` / ``rag_condense`` / ``fallback``)
|
||||
# so cost attribution dashboards can group by call source.
|
||||
Column("source", Text, nullable=False, server_default="agent_stream"),
|
||||
# Added in ``0005_token_usage_request_id``. Stream-scoped UUID stamped
|
||||
# on the agent's primary LLM so multi-call agent runs (which produce
|
||||
# N rows) count as a single request via DISTINCT in the repository
|
||||
# query. NULL on side-channel sources by design.
|
||||
Column("request_id", Text),
|
||||
)
|
||||
|
||||
user_logs_table = Table(
|
||||
@@ -128,7 +139,7 @@ app_metadata_table = Table(
|
||||
)
|
||||
|
||||
|
||||
# --- Phase 2, Tier 2 --------------------------------------------------------
|
||||
# --- Agents, sources, attachments, artifacts --------------------------------
|
||||
|
||||
agent_folders_table = Table(
|
||||
"agent_folders",
|
||||
@@ -244,7 +255,8 @@ memories_table = Table(
|
||||
metadata,
|
||||
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||
Column("user_id", Text, nullable=False),
|
||||
Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
|
||||
# No FK since 0009 — delete-cascade preserved by trigger.
|
||||
Column("tool_id", UUID(as_uuid=True)),
|
||||
Column("path", Text, nullable=False),
|
||||
Column("content", Text, nullable=False),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
@@ -297,7 +309,7 @@ connector_sessions_table = Table(
|
||||
)
|
||||
|
||||
|
||||
# --- Phase 3, Tier 3 --------------------------------------------------------
|
||||
# --- Conversations, messages, workflows -------------------------------------
|
||||
|
||||
conversations_table = Table(
|
||||
"conversations",
|
||||
@@ -345,9 +357,44 @@ conversation_messages_table = Table(
|
||||
Column("feedback", JSONB),
|
||||
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
# Added in 0004_durability_foundation. ``status`` is the WAL state
|
||||
# machine (pending|streaming|complete|failed); ``request_id`` ties a
|
||||
# row to a specific HTTP request for log correlation.
|
||||
Column("status", Text, nullable=False, server_default="complete"),
|
||||
Column("request_id", Text),
|
||||
UniqueConstraint("conversation_id", "position", name="conversation_messages_conv_pos_uidx"),
|
||||
)
|
||||
|
||||
# Per-yield journal of chat-stream events, used by the snapshot+tail
|
||||
# reconnect: the route's GET reconnect endpoint reads
|
||||
# ``WHERE message_id = ? AND sequence_no > ?`` from this table before
|
||||
# tailing the live ``channel:{message_id}`` pub/sub. See
|
||||
# ``application/streaming/event_replay.py`` and migration 0007.
|
||||
message_events_table = Table(
    "message_events",
    metadata,
    # PK is the composite ``(message_id, sequence_no)`` — it doubles as
    # the snapshot read index (covering range scan on
    # ``WHERE message_id = ? AND sequence_no > ?``).
    # The FK cascades, so events are deleted together with their message.
    Column(
        "message_id",
        UUID(as_uuid=True),
        ForeignKey("conversation_messages.id", ondelete="CASCADE"),
        primary_key=True,
        nullable=False,
    ),
    # Strictly monotonic per ``message_id``. Allocated by the route as it
    # yields, so the writer is single-threaded for the lifetime of one
    # stream — no contention, no SERIAL needed.
    Column("sequence_no", Integer, primary_key=True, nullable=False),
    # Free-form text label for the event; no enum is declared at this level.
    Column("event_type", Text, nullable=False),
    # Event body as JSONB; the server default is an empty object.
    Column("payload", JSONB, nullable=False, server_default="{}"),
    # Insert timestamp, filled in by the database via ``now()``.
    Column(
        "created_at", DateTime(timezone=True), nullable=False, server_default=func.now()
    ),
)
|
||||
|
||||
|
||||
shared_conversations_table = Table(
|
||||
"shared_conversations",
|
||||
metadata,
|
||||
@@ -377,9 +424,104 @@ pending_tool_state_table = Table(
|
||||
Column("client_tools", JSONB),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("expires_at", DateTime(timezone=True), nullable=False),
|
||||
# Added in ``0004_durability_foundation``. ``status`` is the
|
||||
# ``pending|resuming`` claim flag for the resumed-run path;
|
||||
# ``resumed_at`` stamps when ``mark_resuming`` flipped the row so
|
||||
# the cleanup janitor can revert stale claims after the grace
|
||||
# window.
|
||||
Column("status", Text, nullable=False, server_default="pending"),
|
||||
Column("resumed_at", DateTime(timezone=True)),
|
||||
UniqueConstraint("conversation_id", "user_id", name="pending_tool_state_conv_user_uidx"),
|
||||
)
|
||||
|
||||
|
||||
# --- Durability foundation (idempotency / journals, migration 0004) ---------
|
||||
# CHECK constraints (status enums) and partial indexes are intentionally
|
||||
# omitted from these declarations — the DB is the authority. Repositories
|
||||
# use raw ``text(...)`` SQL against these tables, not the Core objects.
|
||||
|
||||
task_dedup_table = Table(
    "task_dedup",
    metadata,
    # One row per idempotency key (the primary key).
    Column("idempotency_key", Text, primary_key=True),
    Column("task_name", Text, nullable=False),
    Column("task_id", Text, nullable=False),
    # Nullable JSONB; presumably the cached task result — confirm against
    # the idempotency wrapper.
    Column("result_json", JSONB),
    # CHECK (status IN ('pending', 'completed', 'failed')) lives in 0004.
    Column("status", Text, nullable=False),
    # Bumped each time the per-Celery-task wrapper re-enters; the
    # poison-loop guard (``MAX_TASK_ATTEMPTS=5``) refuses to run fn once
    # this exceeds the threshold.
    Column("attempt_count", Integer, nullable=False, server_default="0"),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    # Added in ``0006_idempotency_lease``. Per-invocation random id
    # written by the wrapper at lease claim; refreshed every 30 s by a
    # heartbeat thread. Other workers seeing a fresh lease (NOT NULL
    # AND ``lease_expires_at > now()``) refuse to run the task body.
    # Both lease columns are nullable, i.e. a row may hold no lease.
    Column("lease_owner_id", Text),
    Column("lease_expires_at", DateTime(timezone=True)),
)
|
||||
|
||||
webhook_dedup_table = Table(
    "webhook_dedup",
    metadata,
    # One row per idempotency key (the primary key).
    Column("idempotency_key", Text, primary_key=True),
    # Plain UUID column — no ForeignKey is declared here.
    Column("agent_id", UUID(as_uuid=True), nullable=False),
    Column("task_id", Text, nullable=False),
    # Nullable JSONB; presumably the response replayed for a duplicate
    # webhook delivery — confirm against the webhook route.
    Column("response_json", JSONB),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
|
||||
|
||||
# Three-phase tool-call journal: ``proposed → executed → confirmed``
|
||||
# (terminal: ``failed``; ``compensated`` is grandfathered in the CHECK
|
||||
# from migration 0004 but no code writes it). The reconciler sweeps
|
||||
# stuck rows via the partial ``tool_call_attempts_pending_ts_idx``.
|
||||
tool_call_attempts_table = Table(
    "tool_call_attempts",
    metadata,
    # One row per tool call, keyed by its text ``call_id``.
    Column("call_id", Text, primary_key=True),
    # ON DELETE SET NULL preserves the journal even after the parent
    # message is deleted — useful for cost-attribution / compliance.
    Column(
        "message_id",
        UUID(as_uuid=True),
        ForeignKey("conversation_messages.id", ondelete="SET NULL"),
    ),
    # Plain nullable UUID — no ForeignKey is declared for the tool.
    Column("tool_id", UUID(as_uuid=True)),
    Column("tool_name", Text, nullable=False),
    Column("action_name", Text, nullable=False),
    # ``arguments`` is required; ``result`` and ``error`` are nullable.
    Column("arguments", JSONB, nullable=False),
    Column("result", JSONB),
    Column("error", Text),
    # CHECK (status IN ('proposed', 'executed', 'confirmed',
    # 'compensated', 'failed')) lives in 0004.
    Column("status", Text, nullable=False),
    Column("attempted_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
|
||||
|
||||
# Per-source ingest checkpoint. Heartbeat thread bumps ``last_updated``
|
||||
# every 30s while a worker embeds; the reconciler escalates when it
|
||||
# stops ticking.
|
||||
ingest_chunk_progress_table = Table(
    "ingest_chunk_progress",
    metadata,
    # One row per source: ``source_id`` is the primary key.
    Column("source_id", UUID(as_uuid=True), primary_key=True),
    Column("total_chunks", Integer, nullable=False),
    # Server defaults describe "nothing embedded yet": 0 chunks embedded
    # and a ``last_index`` of -1.
    Column("embedded_chunks", Integer, nullable=False, server_default="0"),
    Column("last_index", Integer, nullable=False, server_default="-1"),
    Column("last_updated", DateTime(timezone=True), nullable=False, server_default=func.now()),
    # Added in ``0005_ingest_attempt_id``. Stamped from
    # ``self.request.id`` (Celery's stable task id) so a retry of the
    # same task resumes from the checkpoint, but a separate invocation
    # (manual reingest, scheduled sync) resets to a clean re-index.
    Column("attempt_id", Text),
    # Added in ``0008_ingest_progress_status``. The reconciler flips
    # this to 'stalled'; ``init_progress`` resets it to 'active'.
    Column("status", Text, nullable=False, server_default="active"),
)
|
||||
|
||||
|
||||
workflows_table = Table(
|
||||
"workflows",
|
||||
metadata,
|
||||
@@ -458,3 +600,74 @@ workflow_runs_table = Table(
|
||||
Column("ended_at", DateTime(timezone=True)),
|
||||
Column("legacy_mongo_id", Text),
|
||||
)
|
||||
|
||||
|
||||
# --- Scheduler (migration 0010) ---------------------------------------------
|
||||
|
||||
schedules_table = Table(
    "schedules",
    metadata,
    # Primary key generated by the database (``gen_random_uuid()``).
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    # Nullable as of 0011: agentless chats create one-time schedules whose
    # run is built ephemerally at fire time from system defaults.
    # Deleting the agent deletes its schedules (ON DELETE CASCADE).
    Column(
        "agent_id",
        UUID(as_uuid=True),
        ForeignKey("agents.id", ondelete="CASCADE"),
    ),
    # Allowed values for ``trigger_type`` / ``status`` are not declared
    # at this level.
    Column("trigger_type", Text, nullable=False),
    Column("name", Text),
    Column("instruction", Text, nullable=False),
    Column("status", Text, nullable=False, server_default="active"),
    # Both nullable; presumably ``cron`` is set for recurring schedules
    # and ``run_at`` for one-time ones — confirm against the scheduler.
    Column("cron", Text),
    Column("run_at", DateTime(timezone=True)),
    Column("timezone", Text, nullable=False, server_default="UTC"),
    Column("next_run_at", DateTime(timezone=True)),
    Column("last_run_at", DateTime(timezone=True)),
    Column("end_at", DateTime(timezone=True)),
    # JSONB list; the server default is an empty list.
    Column("tool_allowlist", JSONB, nullable=False, server_default="[]"),
    Column("model_id", Text),
    Column("token_budget", Integer),
    # Link back to a conversation; nulled out if that conversation is
    # deleted (ON DELETE SET NULL).
    Column(
        "origin_conversation_id",
        UUID(as_uuid=True),
        ForeignKey("conversations.id", ondelete="SET NULL"),
    ),
    Column("created_via", Text, nullable=False, server_default="ui"),
    Column("consecutive_failure_count", Integer, nullable=False, server_default="0"),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)
|
||||
|
||||
schedule_runs_table = Table(
    "schedule_runs",
    metadata,
    # Primary key generated by the database (``gen_random_uuid()``).
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    # Runs are deleted together with their schedule (ON DELETE CASCADE).
    Column(
        "schedule_id",
        UUID(as_uuid=True),
        ForeignKey("schedules.id", ondelete="CASCADE"),
        nullable=False,
    ),
    Column("user_id", Text, nullable=False),
    # Nullable as of 0011 (mirrors ``schedules.agent_id``).
    # Unlike ``schedules.agent_id``, no ForeignKey is declared here.
    Column("agent_id", UUID(as_uuid=True)),
    Column("status", Text, nullable=False, server_default="pending"),
    Column("scheduled_for", DateTime(timezone=True), nullable=False),
    Column("trigger_source", Text, nullable=False, server_default="cron"),
    Column("started_at", DateTime(timezone=True)),
    Column("finished_at", DateTime(timezone=True)),
    Column("output", Text),
    # Presumably set when ``output`` was cut short before storage — confirm
    # against the code that writes runs.
    Column("output_truncated", Boolean, nullable=False, server_default="false"),
    Column("error", Text),
    Column("error_type", Text),
    Column("prompt_tokens", Integer, nullable=False, server_default="0"),
    Column("generated_tokens", Integer, nullable=False, server_default="0"),
    # Plain nullable UUIDs — no ForeignKey is declared for either.
    Column("conversation_id", UUID(as_uuid=True)),
    Column("message_id", UUID(as_uuid=True)),
    Column("celery_task_id", Text),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    # At most one run per (schedule, scheduled_for) pair.
    UniqueConstraint("schedule_id", "scheduled_for", name="schedule_runs_dedup_uidx"),
)
|
||||
|
||||
@@ -17,6 +17,21 @@ _UPDATABLE_SCALARS = {
|
||||
_UPDATABLE_JSONB = {"metadata"}
|
||||
|
||||
|
||||
def _attachment_to_dict(row: Any) -> dict:
    """Convert an attachment row to a dict and alias ``upload_path`` as ``path``.

    The Mongo-era attachment shape exposed ``path`` while the PG column is
    named ``upload_path``; LLM provider code (google_ai/openai/anthropic and
    handlers/base) still reads ``attachment.get("path")``. Emitting both keys
    lets consumers stay unaware of which storage backend produced the dict.

    Args:
        row: A result row accepted by ``row_to_dict``.

    Returns:
        The ``row_to_dict`` mapping, with ``path`` filled from
        ``upload_path`` whenever ``path`` is missing or ``None``.
    """
    record = row_to_dict(row)
    needs_alias = "upload_path" in record and record.get("path") is None
    if needs_alias:
        record["path"] = record["upload_path"]
    return record
|
||||
|
||||
|
||||
class AttachmentsRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
@@ -66,7 +81,7 @@ class AttachmentsRepository:
|
||||
"legacy_mongo_id": legacy_mongo_id,
|
||||
},
|
||||
)
|
||||
return row_to_dict(result.fetchone())
|
||||
return _attachment_to_dict(result.fetchone())
|
||||
|
||||
def get(self, attachment_id: str, user_id: str) -> Optional[dict]:
|
||||
result = self._conn.execute(
|
||||
@@ -76,7 +91,7 @@ class AttachmentsRepository:
|
||||
{"id": attachment_id, "user_id": user_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
return _attachment_to_dict(row) if row is not None else None
|
||||
|
||||
def get_any(self, attachment_id: str, user_id: str) -> Optional[dict]:
|
||||
"""Resolve an attachment by either PG UUID or legacy Mongo ObjectId string."""
|
||||
@@ -155,14 +170,14 @@ class AttachmentsRepository:
|
||||
params["user_id"] = user_id
|
||||
result = self._conn.execute(text(sql), params)
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
return _attachment_to_dict(row) if row is not None else None
|
||||
|
||||
def list_for_user(self, user_id: str) -> list[dict]:
    """Return every attachment owned by ``user_id``, newest first.

    Rows are converted with ``_attachment_to_dict`` so each dict carries the
    ``path`` alias alongside ``upload_path``. The earlier plain
    ``row_to_dict`` return that preceded this one made the aliasing return
    unreachable; only the aliasing conversion is kept.

    Args:
        user_id: Owner whose attachments are listed.

    Returns:
        A list of attachment dicts ordered by ``created_at`` descending;
        empty when the user has none.
    """
    result = self._conn.execute(
        text("SELECT * FROM attachments WHERE user_id = :user_id ORDER BY created_at DESC"),
        {"user_id": user_id},
    )
    return [_attachment_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def update(self, attachment_id: str, user_id: str, fields: dict) -> bool:
|
||||
"""Partial update. Used by the LLM providers to cache their
|
||||
|
||||
@@ -25,6 +25,7 @@ from typing import Any, Optional
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
|
||||
_UPDATABLE_SCALARS = {
|
||||
@@ -36,7 +37,7 @@ _UPDATABLE_JSONB = {"session_data", "token_info"}
|
||||
def _jsonb(value: Any) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
return json.dumps(value, default=str)
|
||||
return json.dumps(value, cls=PGNativeJSONEncoder)
|
||||
|
||||
|
||||
class ConnectorSessionsRepository:
|
||||
|
||||
@@ -15,6 +15,7 @@ Covers every operation the legacy Mongo code performs on
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
@@ -22,6 +23,23 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.models import conversations_table, conversation_messages_table
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
|
||||
class MessageUpdateOutcome(str, Enum):
    """Result categories returned by ``update_message_by_id``.

    Separates "the row was modified" from "the row was already in the
    terminal state", which lets an abort handler journal ``end`` rather than
    ``error`` when the normal-path finalize has already marked the row
    ``complete``.
    """

    # The UPDATE matched and modified the row.
    UPDATED = "updated"
    # Nothing was updated; the row's status was already ``complete``.
    ALREADY_COMPLETE = "already_complete"
    # Nothing was updated; the row's status was already ``failed``.
    ALREADY_FAILED = "already_failed"
    # No row with the given id (or the gate excluded it for another reason).
    NOT_FOUND = "not_found"
    # The id was not UUID-shaped, or no updatable field was supplied.
    INVALID = "invalid"
|
||||
|
||||
|
||||
def _message_row_to_dict(row) -> dict:
|
||||
@@ -57,8 +75,8 @@ class ConversationsRepository:
|
||||
- Already-UUID-shaped → returned as-is.
|
||||
- Otherwise treated as a Mongo ObjectId and looked up via
|
||||
``agents.legacy_mongo_id``. Returns ``None`` if no PG row
|
||||
exists yet (e.g. the agent was created before Phase 1
|
||||
backfill).
|
||||
exists yet (e.g. the agent was created before the backfill
|
||||
ran).
|
||||
"""
|
||||
if not agent_id_raw:
|
||||
return None
|
||||
@@ -452,7 +470,7 @@ class ConversationsRepository:
|
||||
),
|
||||
{
|
||||
"id": conversation_id,
|
||||
"point": json.dumps(point, default=str),
|
||||
"point": json.dumps(point, cls=PGNativeJSONEncoder),
|
||||
"max_points": int(max_points),
|
||||
},
|
||||
)
|
||||
@@ -632,6 +650,233 @@ class ConversationsRepository:
|
||||
result = self._conn.execute(text(sql), params)
|
||||
return result.rowcount > 0
|
||||
|
||||
def reserve_message(
    self,
    conversation_id: str,
    *,
    prompt: str,
    placeholder_response: str,
    request_id: str | None = None,
    status: str = "pending",
    attachments: list[str] | None = None,
    model_id: str | None = None,
    metadata: dict | None = None,
) -> dict:
    """Insert a placeholder assistant message ahead of the LLM call.

    Steps, all on the repository's connection:

    1. Take a row lock on the conversation (``SELECT ... FOR UPDATE``).
    2. Compute the next ``position`` as ``MAX(position) + 1`` (``0`` for an
       empty conversation).
    3. Insert the message row and bump ``conversations.updated_at``.

    Args:
        conversation_id: UUID string of the parent conversation.
        prompt: User prompt stored on the row.
        placeholder_response: Text stored in ``response`` for now.
        request_id: Optional request identifier stored on the row.
        status: Initial message status; defaults to ``"pending"``.
        attachments: Optional attachment references; resolved through
            ``_resolve_attachment_refs`` and stored only when the resolved
            list is non-empty.
        model_id: Optional model identifier stored on the row.
        metadata: Optional metadata dict; an empty dict is stored when
            omitted or falsy.

    Returns:
        The inserted row converted by ``_message_row_to_dict``.
    """
    conv_params = {"conv_id": conversation_id}

    # Lock the parent conversation row before reading the current max position.
    self._conn.execute(
        text(
            "SELECT id FROM conversations "
            "WHERE id = CAST(:conv_id AS uuid) FOR UPDATE"
        ),
        conv_params,
    )
    position = self._conn.execute(
        text(
            "SELECT COALESCE(MAX(position), -1) + 1 AS next_pos "
            "FROM conversation_messages "
            "WHERE conversation_id = CAST(:conv_id AS uuid)"
        ),
        conv_params,
    ).scalar()

    row_values = {
        "conversation_id": conversation_id,
        "position": position,
        "prompt": prompt,
        "response": placeholder_response,
        "status": status,
        "request_id": request_id,
        "model_id": model_id,
        "message_metadata": metadata or {},
    }
    # Only touch the attachments column when something actually resolved.
    resolved_refs = (
        self._resolve_attachment_refs([str(ref) for ref in attachments])
        if attachments
        else None
    )
    if resolved_refs:
        row_values["attachments"] = resolved_refs

    insert_stmt = pg_insert(conversation_messages_table).values(**row_values)
    inserted = self._conn.execute(
        insert_stmt.returning(conversation_messages_table)
    )
    self._conn.execute(
        text(
            "UPDATE conversations SET updated_at = now() "
            "WHERE id = CAST(:id AS uuid)"
        ),
        {"id": conversation_id},
    )
    return _message_row_to_dict(inserted.fetchone())
|
||||
|
||||
def update_message_by_id(
    self, message_id: str, fields: dict,
    *, only_if_non_terminal: bool = False,
) -> MessageUpdateOutcome:
    """Update specific fields on a message identified by its UUID.

    ``metadata`` is merged into the existing JSONB rather than
    overwritten, so a reconciler-set ``reconcile_attempts`` survives
    a successful late finalize. When ``only_if_non_terminal`` is
    True, the update is gated so a late finalize cannot retract a
    reconciler-set ``failed`` (or a prior ``complete``).

    JSONB values (``metadata``, ``sources``, ``tool_calls``,
    ``feedback``) are serialised with ``PGNativeJSONEncoder``, matching
    the other JSONB writes in this module; values that are already
    strings are passed through unchanged.

    Args:
        message_id: UUID string of the message row.
        fields: Mapping of API field name to new value. Keys outside the
            allowlist below are silently dropped.
        only_if_non_terminal: When True, rows whose status is
            ``complete`` or ``failed`` are left untouched.

    Returns:
        ``UPDATED`` when the row was modified; ``ALREADY_COMPLETE`` /
        ``ALREADY_FAILED`` when the gate skipped a terminal row;
        ``NOT_FOUND`` when no such row exists; ``INVALID`` when
        ``message_id`` is not UUID-shaped or no allowed field was given.
    """
    if not looks_like_uuid(message_id):
        return MessageUpdateOutcome.INVALID
    allowed = {
        "prompt", "response", "thought", "sources", "tool_calls",
        "attachments", "model_id", "metadata", "timestamp", "status",
        "request_id", "feedback", "feedback_timestamp",
    }
    filtered = {k: v for k, v in fields.items() if k in allowed}
    if not filtered:
        return MessageUpdateOutcome.INVALID

    def _as_json(value):
        # Pre-serialised JSON strings are bound as-is; everything else goes
        # through the shared encoder.
        if isinstance(value, str):
            return value
        return json.dumps(value, cls=PGNativeJSONEncoder)

    api_to_col = {"metadata": "message_metadata"}

    # Column names interpolated below come only from the fixed allowlist
    # (after the ``api_to_col`` rename); values are always bound parameters.
    set_parts = []
    params: dict = {"id": message_id}
    for key, val in filtered.items():
        col = api_to_col.get(key, key)
        if key == "metadata":
            if val is None:
                set_parts.append(f"{col} = NULL")
            else:
                # ``||`` merges the new keys into the stored object.
                set_parts.append(
                    f"{col} = COALESCE({col}, '{{}}'::jsonb) "
                    f"|| CAST(:{col} AS jsonb)"
                )
                params[col] = _as_json(val)
        elif key in ("sources", "tool_calls", "feedback"):
            set_parts.append(f"{col} = CAST(:{col} AS jsonb)")
            params[col] = None if val is None else _as_json(val)
        elif key == "attachments":
            set_parts.append(f"{col} = CAST(:{col} AS uuid[])")
            params[col] = self._resolve_attachment_refs(
                [str(a) for a in val] if val else [],
            )
        else:
            set_parts.append(f"{col} = :{col}")
            params[col] = val

    set_parts.append("updated_at = now()")
    update_where = ["id = CAST(:id AS uuid)"]
    if only_if_non_terminal:
        update_where.append("status NOT IN ('complete', 'failed')")
    # Single-statement attempt + prior-status probe. Both CTEs see
    # the same MVCC snapshot, so ``prior.status`` reflects the row
    # state before the UPDATE — exactly what we need to tell
    # ``ALREADY_COMPLETE`` apart from ``ALREADY_FAILED`` apart from
    # ``NOT_FOUND`` without a follow-up SELECT.
    sql = (
        "WITH attempted AS ("
        f"  UPDATE conversation_messages SET {', '.join(set_parts)} "
        f"  WHERE {' AND '.join(update_where)} "
        "  RETURNING 1 AS updated"
        "), "
        "prior AS ("
        "  SELECT status FROM conversation_messages "
        "  WHERE id = CAST(:id AS uuid)"
        ") "
        "SELECT (SELECT updated FROM attempted) AS updated, "
        "       (SELECT status FROM prior) AS prior_status"
    )
    row = self._conn.execute(text(sql), params).fetchone()
    if row is None:
        return MessageUpdateOutcome.NOT_FOUND
    updated, prior_status = row[0], row[1]
    if updated:
        return MessageUpdateOutcome.UPDATED
    if prior_status is None:
        return MessageUpdateOutcome.NOT_FOUND
    if prior_status == "complete":
        return MessageUpdateOutcome.ALREADY_COMPLETE
    if prior_status == "failed":
        return MessageUpdateOutcome.ALREADY_FAILED
    # ``only_if_non_terminal=False`` always updates an existing row,
    # so reaching here means the gate excluded it for some status
    # the terminal set doesn't cover — treat as "not found" rather
    # than inventing a new variant.
    return MessageUpdateOutcome.NOT_FOUND
|
||||
|
||||
def update_message_status(
    self, message_id: str, status: str,
) -> bool:
    """Set ``status`` on a message that is not yet terminal.

    Intended for cheap transitions such as pending → streaming. Rows
    already marked ``complete`` or ``failed`` are excluded by the WHERE
    clause, so a late streaming chunk cannot undo a reconciler-set
    ``failed``.

    Args:
        message_id: UUID string of the message row.
        status: New status value.

    Returns:
        True when a row was updated; False when ``message_id`` is not
        UUID-shaped or no non-terminal row matched.
    """
    if not looks_like_uuid(message_id):
        return False
    statement = text(
        "UPDATE conversation_messages SET status = :status, "
        "updated_at = now() "
        "WHERE id = CAST(:id AS uuid) "
        "AND status NOT IN ('complete', 'failed')"
    )
    outcome = self._conn.execute(
        statement, {"id": message_id, "status": status}
    )
    return outcome.rowcount > 0
|
||||
|
||||
def heartbeat_message(self, message_id: str) -> bool:
    """Write ``clock_timestamp()`` into ``message_metadata.last_heartbeat_at``.

    The reconciler's staleness check uses ``GREATEST(timestamp,
    last_heartbeat_at)``, so stamping the heartbeat keeps a long-running
    stream looking fresh while leaving ``timestamp`` (creation time, used
    for history sort) and ``status`` (the WAL marker) alone. Terminal rows
    (``complete`` / ``failed``) are excluded so a late heartbeat cannot
    retract a reconciler-set ``failed``.

    Args:
        message_id: UUID string of the message row.

    Returns:
        True when a row was stamped; False when ``message_id`` is not
        UUID-shaped or no non-terminal row matched.
    """
    if not looks_like_uuid(message_id):
        return False
    stamp_sql = text(
        """
        UPDATE conversation_messages
        SET message_metadata = jsonb_set(
            COALESCE(message_metadata, '{}'::jsonb),
            '{last_heartbeat_at}',
            to_jsonb(clock_timestamp())
        )
        WHERE id = CAST(:id AS uuid)
          AND status NOT IN ('complete', 'failed')
        """
    )
    stamped = self._conn.execute(stamp_sql, {"id": message_id})
    return stamped.rowcount > 0
|
||||
|
||||
def confirm_executed_tool_calls(self, message_id: str) -> int:
    """Promote the message's ``executed`` tool-call attempts to ``confirmed``.

    Args:
        message_id: UUID string of the message whose ``tool_call_attempts``
            rows should be flipped.

    Returns:
        Number of rows updated; ``0`` when ``message_id`` is not
        UUID-shaped or nothing matched.
    """
    if not looks_like_uuid(message_id):
        return 0
    confirm_sql = text(
        "UPDATE tool_call_attempts SET status = 'confirmed', "
        "updated_at = now() "
        "WHERE message_id = CAST(:mid AS uuid) AND status = 'executed'"
    )
    affected = self._conn.execute(confirm_sql, {"mid": message_id}).rowcount
    # Normalise a falsy rowcount to a plain 0.
    return affected or 0
|
||||
|
||||
def truncate_after(self, conversation_id: str, keep_up_to: int) -> int:
|
||||
"""Delete messages with position > keep_up_to.
|
||||
|
||||
|
||||
346
application/storage/db/repositories/idempotency.py
Normal file
346
application/storage/db/repositories/idempotency.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""Repository for ``webhook_dedup`` and ``task_dedup``; 24h TTL enforced at read."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
# 24h TTL is the contract surfaced in the upload/webhook docstrings; the
|
||||
# read filters and the stale-row replacement predicate must agree, or the
|
||||
# upsert can fall into a window where the row is "fresh" to the writer
|
||||
# but "expired" to the reader (or vice versa). Keep one constant so any
|
||||
# future change moves both directions in lockstep.
|
||||
DEDUP_TTL_INTERVAL = "24 hours"
|
||||
|
||||
|
||||
def _jsonb(value: Any) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
return json.dumps(value, cls=PGNativeJSONEncoder)
|
||||
|
||||
|
||||
class IdempotencyRepository:
    """Data-access helpers for the ``webhook_dedup`` and ``task_dedup`` tables.

    Every method runs its SQL on the connection supplied at construction;
    transaction boundaries are left to the caller.
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    # --- webhook_dedup -----------------------------------------------------

    def get_webhook(self, key: str) -> Optional[dict]:
        """Fetch the webhook dedup row for ``key`` if it is younger than the TTL.

        Returns:
            The row as a dict, or ``None`` when no row exists or the row's
            ``created_at`` is outside the ``DEDUP_TTL_INTERVAL`` window.
        """
        lookup = text(
            """
            SELECT * FROM webhook_dedup
            WHERE idempotency_key = :key
              AND created_at > now() - CAST(:ttl AS interval)
            """
        )
        found = self._conn.execute(
            lookup, {"key": key, "ttl": DEDUP_TTL_INTERVAL}
        ).fetchone()
        if found is None:
            return None
        return row_to_dict(found)

    def record_webhook(
        self,
        key: str,
        agent_id: str,
        task_id: str,
        response_json: dict,
    ) -> Optional[dict]:
        """Store the webhook response for ``key`` unless a fresh row already exists.

        The upsert overwrites an existing row only when that row's
        ``created_at`` is at or past the TTL, so stale rows are recycled in
        the same statement. A conflict with a row still inside the TTL
        produces no returned row; callers resolve that case with
        :meth:`get_webhook`.

        Returns:
            The inserted/replaced row as a dict, or ``None`` when a
            within-TTL row for ``key`` already exists.
        """
        bind = {
            "key": key,
            "agent_id": agent_id,
            "task_id": task_id,
            "response_json": _jsonb(response_json),
            "ttl": DEDUP_TTL_INTERVAL,
        }
        upsert = text(
            """
            INSERT INTO webhook_dedup (
                idempotency_key, agent_id, task_id, response_json
            )
            VALUES (
                :key, CAST(:agent_id AS uuid), :task_id,
                CAST(:response_json AS jsonb)
            )
            ON CONFLICT (idempotency_key) DO UPDATE
               SET agent_id = EXCLUDED.agent_id,
                   task_id = EXCLUDED.task_id,
                   response_json = EXCLUDED.response_json,
                   created_at = now()
             WHERE webhook_dedup.created_at
                   <= now() - CAST(:ttl AS interval)
            RETURNING *
            """
        )
        written = self._conn.execute(upsert, bind).fetchone()
        if written is None:
            return None
        return row_to_dict(written)

    # --- task_dedup --------------------------------------------------------

    def get_task(self, key: str) -> Optional[dict]:
        """Fetch the task dedup row for ``key`` if it is younger than the TTL.

        Returns:
            The row as a dict, or ``None`` when absent or past the
            ``DEDUP_TTL_INTERVAL`` window.
        """
        lookup = text(
            """
            SELECT * FROM task_dedup
            WHERE idempotency_key = :key
              AND created_at > now() - CAST(:ttl AS interval)
            """
        )
        found = self._conn.execute(
            lookup, {"key": key, "ttl": DEDUP_TTL_INTERVAL}
        ).fetchone()
        if found is None:
            return None
        return row_to_dict(found)

    def claim_task(
        self,
        key: str,
        task_name: str,
        task_id: str,
    ) -> Optional[dict]:
        """Claim ``key`` for a task before it is enqueued.

        The HTTP entry point is expected to call this *before* ``.delay()``
        so that only the winning caller enqueues the Celery task.

        An existing row is replaced in exactly two situations:

        - its ``status`` is ``'failed'`` — a same-key retry is meant to
          re-run rather than stay blocked for the rest of the TTL;
        - its ``created_at`` is at or past the TTL, whatever its status.

        Within the TTL, ``'completed'`` rows keep blocking (the cached
        success) and ``'pending'`` rows keep blocking (concurrent same-key
        requests collapse onto the in-flight task). On replacement,
        ``result_json``, ``status``, ``attempt_count`` and ``created_at``
        are reset to fresh-claim values.

        Returns:
            The inserted/replaced row as a dict, or ``None`` when an
            existing row blocked the claim.
        """
        bind = {
            "key": key,
            "task_name": task_name,
            "task_id": task_id,
            "ttl": DEDUP_TTL_INTERVAL,
        }
        claim = text(
            """
            INSERT INTO task_dedup (
                idempotency_key, task_name, task_id, result_json, status
            )
            VALUES (
                :key, :task_name, :task_id, NULL, 'pending'
            )
            ON CONFLICT (idempotency_key) DO UPDATE
               SET task_name = EXCLUDED.task_name,
                   task_id = EXCLUDED.task_id,
                   result_json = NULL,
                   status = 'pending',
                   attempt_count = 0,
                   created_at = now()
             WHERE task_dedup.status = 'failed'
                OR task_dedup.created_at
                   <= now() - CAST(:ttl AS interval)
            RETURNING *
            """
        )
        claimed = self._conn.execute(claim, bind).fetchone()
        if claimed is None:
            return None
        return row_to_dict(claimed)

    def try_claim_lease(
        self,
        key: str,
        task_name: str,
        task_id: str,
        owner_id: str,
        ttl_seconds: int = 60,
    ) -> Optional[int]:
        """Try to take the running lease for ``key`` in one statement.

        A fresh insert starts at ``attempt_count = 1``. On conflict the
        lease is taken over (and ``attempt_count`` incremented) only when
        the existing row is not ``'completed'`` and its lease is empty or
        expired. ``clock_timestamp()`` is used instead of ``now()`` so the
        expiry moves forward even within a single transaction, where
        ``now()`` is frozen at transaction start.

        Args:
            key: Idempotency key of the task.
            task_name: Task name stored on the row.
            task_id: Task id stored on a fresh insert.
            owner_id: Identifier of the worker taking the lease.
            ttl_seconds: Lease lifetime in seconds.

        Returns:
            The row's new ``attempt_count`` when this caller owns the lease,
            or ``None`` when the row is completed or another live lease
            holds it.
        """
        bind = {
            "key": key,
            "task_name": task_name,
            "task_id": task_id,
            "owner": owner_id,
            "ttl": int(ttl_seconds),
        }
        lease = text(
            """
            INSERT INTO task_dedup (
                idempotency_key, task_name, task_id, status, attempt_count,
                lease_owner_id, lease_expires_at
            ) VALUES (
                :key, :task_name, :task_id, 'pending', 1,
                :owner,
                clock_timestamp() + make_interval(secs => :ttl)
            )
            ON CONFLICT (idempotency_key) DO UPDATE
               SET attempt_count = task_dedup.attempt_count + 1,
                   task_name = EXCLUDED.task_name,
                   lease_owner_id = EXCLUDED.lease_owner_id,
                   lease_expires_at = EXCLUDED.lease_expires_at
             WHERE task_dedup.status <> 'completed'
               AND (task_dedup.lease_expires_at IS NULL
                    OR task_dedup.lease_expires_at <= clock_timestamp())
            RETURNING attempt_count
            """
        )
        granted = self._conn.execute(lease, bind).fetchone()
        if granted is None:
            return None
        return int(granted[0])

    def refresh_lease(
        self,
        key: str,
        owner_id: str,
        ttl_seconds: int = 60,
    ) -> bool:
        """Extend ``lease_expires_at`` while ``owner_id`` still holds the lease.

        Only a ``'pending'`` row whose ``lease_owner_id`` equals ``owner_id``
        is touched.

        Returns:
            True when the lease was extended; False when ownership was lost
            (another worker took over, or the row was finalised).
        """
        bind = {
            "key": key,
            "owner": owner_id,
            "ttl": int(ttl_seconds),
        }
        extend = text(
            """
            UPDATE task_dedup
            SET lease_expires_at =
                clock_timestamp() + make_interval(secs => :ttl)
            WHERE idempotency_key = :key
              AND lease_owner_id = :owner
              AND status = 'pending'
            """
        )
        return self._conn.execute(extend, bind).rowcount > 0

    def release_lease(self, key: str, owner_id: str) -> bool:
        """Drop the lease held by ``owner_id`` on a still-pending row.

        Clearing ``lease_owner_id`` / ``lease_expires_at`` lets the next
        worker re-claim immediately instead of waiting out the lease TTL.
        Nothing happens when a different owner holds the lease or the row is
        no longer ``'pending'``.

        Returns:
            True when a lease was cleared, otherwise False.
        """
        release = text(
            """
            UPDATE task_dedup
            SET lease_owner_id = NULL,
                lease_expires_at = NULL
            WHERE idempotency_key = :key
              AND lease_owner_id = :owner
              AND status = 'pending'
            """
        )
        cleared = self._conn.execute(release, {"key": key, "owner": owner_id})
        return cleared.rowcount > 0

    def finalize_task(
        self,
        key: str,
        *,
        result_json: Optional[dict],
        status: str,
    ) -> bool:
        """Move a ``'pending'`` row to ``completed`` or ``failed``.

        Stores ``result_json`` and clears both lease columns. Rows that are
        already terminal are not matched, so the first writer's outcome is
        preserved.

        Raises:
            ValueError: If ``status`` is neither ``"completed"`` nor
                ``"failed"``.

        Returns:
            True when a pending row was finalised, otherwise False.
        """
        if status not in ("completed", "failed"):
            raise ValueError(f"finalize_task: invalid status {status!r}")
        bind = {
            "key": key,
            "status": status,
            "result_json": _jsonb(result_json),
        }
        finalize = text(
            """
            UPDATE task_dedup
            SET status = :status,
                result_json = CAST(:result_json AS jsonb),
                lease_owner_id = NULL,
                lease_expires_at = NULL
            WHERE idempotency_key = :key
              AND status = 'pending'
            """
        )
        return self._conn.execute(finalize, bind).rowcount > 0

    # --- housekeeping ------------------------------------------------------

    def cleanup_expired(self) -> dict:
        """Purge rows at or past the TTL from both dedup tables.

        ``task_dedup`` is purged first, then ``webhook_dedup``. The
        TTL-aware upserts already keep stale rows from blocking new work, so
        this only bounds table growth.

        Returns:
            ``{"task_dedup_deleted": int, "webhook_dedup_deleted": int}``.
        """
        ttl_bind = {"ttl": DEDUP_TTL_INTERVAL}
        tasks_removed = self._conn.execute(
            text(
                """
                DELETE FROM task_dedup
                WHERE created_at <= now() - CAST(:ttl AS interval)
                """
            ),
            ttl_bind,
        ).rowcount
        webhooks_removed = self._conn.execute(
            text(
                """
                DELETE FROM webhook_dedup
                WHERE created_at <= now() - CAST(:ttl AS interval)
                """
            ),
            ttl_bind,
        ).rowcount
        return {
            "task_dedup_deleted": int(tasks_removed or 0),
            "webhook_dedup_deleted": int(webhooks_removed or 0),
        }
|
||||
|
||||
148
application/storage/db/repositories/ingest_chunk_progress.py
Normal file
148
application/storage/db/repositories/ingest_chunk_progress.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Repository for ``ingest_chunk_progress``; per-source resume + heartbeat."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class IngestChunkProgressRepository:
    """Accessors for the ``ingest_chunk_progress`` table (one row per source)."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def init_progress(
        self,
        source_id: str,
        total_chunks: int,
        attempt_id: Optional[str] = None,
    ) -> dict:
        """Create or refresh the progress row for ``source_id``.

        A new row starts at ``embedded_chunks = 0`` and ``last_index = -1``.
        When a row already exists the outcome depends on ``attempt_id``:

        - **Same attempt** (stored value matches, e.g. a Celery autoretry of
          the same task): ``last_index`` and ``embedded_chunks`` are kept so
          the embed loop can resume from its checkpoint.
        - **Different attempt** (manual reingest, scheduled sync, or a caller
          that passed no ``attempt_id`` against a row that has one): both
          counters are reset so a checkpoint from an earlier run cannot
          carry over.

        The comparison uses ``IS NOT DISTINCT FROM``, so two NULL attempt ids
        count as the same attempt. In every conflict case ``total_chunks``,
        ``last_updated`` and ``attempt_id`` are refreshed and ``status`` is
        set back to ``'active'``, clearing a prior ``'stalled'`` escalation.

        Returns:
            The resulting row as a dict.
        """
        bind = {
            "source_id": str(source_id),
            "total_chunks": int(total_chunks),
            "attempt_id": attempt_id,
        }
        upsert = text(
            """
            INSERT INTO ingest_chunk_progress (
                source_id, total_chunks, embedded_chunks, last_index,
                attempt_id, last_updated
            )
            VALUES (
                CAST(:source_id AS uuid), :total_chunks, 0, -1,
                :attempt_id, now()
            )
            ON CONFLICT (source_id) DO UPDATE SET
                total_chunks = EXCLUDED.total_chunks,
                last_updated = now(),
                last_index = CASE
                    WHEN ingest_chunk_progress.attempt_id
                         IS NOT DISTINCT FROM EXCLUDED.attempt_id
                    THEN ingest_chunk_progress.last_index
                    ELSE -1
                END,
                embedded_chunks = CASE
                    WHEN ingest_chunk_progress.attempt_id
                         IS NOT DISTINCT FROM EXCLUDED.attempt_id
                    THEN ingest_chunk_progress.embedded_chunks
                    ELSE 0
                END,
                attempt_id = EXCLUDED.attempt_id,
                status = 'active'
            RETURNING *
            """
        )
        stored = self._conn.execute(upsert, bind).fetchone()
        return row_to_dict(stored)

    def record_chunk(
        self, source_id: str, last_index: int, embedded_chunks: int
    ) -> None:
        """Save the checkpoint reached after embedding a chunk.

        Writes ``last_index`` and ``embedded_chunks`` and refreshes
        ``last_updated`` on the row for ``source_id``.
        """
        bind = {
            "source_id": str(source_id),
            "last_index": int(last_index),
            "embedded_chunks": int(embedded_chunks),
        }
        checkpoint = text(
            """
            UPDATE ingest_chunk_progress
            SET last_index = :last_index,
                embedded_chunks = :embedded_chunks,
                last_updated = now()
            WHERE source_id = CAST(:source_id AS uuid)
            """
        )
        self._conn.execute(checkpoint, bind)

    def get_progress(self, source_id: str) -> Optional[dict]:
        """Look up the progress row for ``source_id``.

        Returns:
            The row as a dict, or ``None`` when there is no row.
        """
        lookup = text(
            "SELECT * FROM ingest_chunk_progress "
            "WHERE source_id = CAST(:source_id AS uuid)"
        )
        found = self._conn.execute(
            lookup, {"source_id": str(source_id)}
        ).fetchone()
        if found is None:
            return None
        return row_to_dict(found)

    def delete(self, source_id: str) -> bool:
        """Remove the progress row for ``source_id``.

        A manual reingest supersedes any earlier ingest state, including a
        reconciler ``'stalled'`` escalation, so dropping the row clears the
        derived ``failed`` ingest status shown in the sources list.

        Returns:
            True when a row was deleted, otherwise False.
        """
        removal = text(
            "DELETE FROM ingest_chunk_progress "
            "WHERE source_id = CAST(:source_id AS uuid)"
        )
        removed = self._conn.execute(removal, {"source_id": str(source_id)})
        return removed.rowcount > 0

    def bump_heartbeat(self, source_id: str) -> None:
        """Touch ``last_updated`` so the reconciler sees the row as alive."""
        touch = text(
            """
            UPDATE ingest_chunk_progress
            SET last_updated = now()
            WHERE source_id = CAST(:source_id AS uuid)
            """
        )
        self._conn.execute(touch, {"source_id": str(source_id)})
|
||||
@@ -86,6 +86,22 @@ class MemoriesRepository:
|
||||
)
|
||||
return result.rowcount
|
||||
|
||||
def delete_orphans(self, keep_tool_ids: Optional[list[str]] = None) -> int:
    """Remove memories whose ``tool_id`` matches no ``user_tools`` row.

    Rows with a NULL ``tool_id`` are never deleted, and neither are rows
    whose ``tool_id`` appears in ``keep_tool_ids``.

    Args:
        keep_tool_ids: Tool ids to spare even though they have no
            ``user_tools`` row; each is stringified and bound as a
            ``uuid[]`` parameter. ``None`` behaves like an empty list.

    Returns:
        The number of deleted rows, as reported by the driver's rowcount.
    """
    spared = list(map(str, keep_tool_ids or []))
    purge = text(
        """
        DELETE FROM memories
        WHERE tool_id IS NOT NULL
          AND tool_id NOT IN (SELECT id FROM user_tools)
          AND NOT (tool_id = ANY(CAST(:keep AS uuid[])))
        """
    )
    return self._conn.execute(purge, {"keep": spared}).rowcount
|
||||
|
||||
def update_path(self, user_id: str, tool_id: str, old_path: str, new_path: str) -> bool:
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
|
||||
248
application/storage/db/repositories/message_events.py
Normal file
248
application/storage/db/repositories/message_events.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Repository for ``message_events`` — the chat-stream snapshot journal.
|
||||
|
||||
``record`` / ``bulk_record`` write per-yield events; ``read_after``
|
||||
replays rows past a cursor for reconnect snapshots. Composite PK
|
||||
``(message_id, sequence_no)`` raises ``IntegrityError`` on duplicates.
|
||||
Callers must use short-lived per-call transactions — long-lived
|
||||
transactions hide writes from reconnecting clients on a separate
|
||||
connection and turn one bad row into ``InFailedSqlTransaction``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageEventsRepository:
|
||||
"""Read/write helpers for ``message_events``."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
    """Keep ``conn`` as the connection every query in this repository runs on."""
    self._conn = conn
|
||||
|
||||
def record(
    self,
    message_id: str,
    sequence_no: int,
    event_type: str,
    payload: Optional[Any] = None,
) -> None:
    """Insert one row into ``message_events``.

    ``payload`` is JSON-encoded as given (dicts, lists and scalars are
    all accepted here); ``None`` is stored as an empty object so the
    column never receives NULL. The streaming wrapper
    ``application/streaming/message_journal.py::record_event`` narrows
    the contract to dicts only; this raw method keeps the wider
    JSONB-compatible surface for direct callers.

    Args:
        message_id: Message identifier; stringified and cast to ``uuid``.
        sequence_no: Position within the message's journal; coerced
            with ``int``.
        event_type: Non-empty event label.
        payload: JSON-serialisable value, or ``None``.

    Raises:
        ValueError: If ``event_type`` is empty.
        sqlalchemy.exc.IntegrityError: On a duplicate
            ``(message_id, sequence_no)``.
        sqlalchemy.exc.DataError: On a malformed ``message_id`` UUID.

    Both database errors abort the surrounding transaction, so callers
    should run each event in its own short-lived session.
    """
    if not event_type:
        raise ValueError("event_type must be a non-empty string")

    stored_payload = {} if payload is None else payload
    insert_stmt = text(
        """
        INSERT INTO message_events (
            message_id, sequence_no, event_type, payload
        ) VALUES (
            CAST(:message_id AS uuid), :sequence_no, :event_type,
            CAST(:payload AS jsonb)
        )
        """
    )
    bind = {
        "message_id": str(message_id),
        "sequence_no": int(sequence_no),
        "event_type": event_type,
        "payload": json.dumps(stored_payload),
    }
    self._conn.execute(insert_stmt, bind)
|
||||
|
||||
def bulk_record(
    self,
    message_id: str,
    events: list[tuple[int, str, dict]],
) -> None:
    """Insert several ``message_events`` rows for one ``message_id``.

    ``events`` holds ``(sequence_no, event_type, payload)`` tuples. The
    whole list is handed to a single ``execute`` call as a list of
    parameter dicts (SQLAlchemy's executemany form). A ``None`` payload
    is stored as an empty object. An empty ``events`` is a no-op.

    Caller contract: if the batch raises ``IntegrityError``, do NOT
    retry this method with the same batch. Fall back to per-row
    ``record()`` calls, each in its own short-lived session, so one
    colliding sequence number does not drop the rest of the batch.
    ``BatchedJournalWriter`` in
    ``application/streaming/message_journal.py`` is the canonical
    consumer.
    """
    if not events:
        return

    insert_stmt = text(
        """
        INSERT INTO message_events (
            message_id, sequence_no, event_type, payload
        ) VALUES (
            CAST(:message_id AS uuid), :sequence_no, :event_type,
            CAST(:payload AS jsonb)
        )
        """
    )
    batch = [
        {
            "message_id": str(message_id),
            "sequence_no": int(seq_no),
            "event_type": kind,
            "payload": json.dumps({} if body is None else body),
        }
        for seq_no, kind, body in events
    ]
    self._conn.execute(insert_stmt, batch)
|
||||
|
||||
def read_after(
    self,
    message_id: str,
    last_sequence_no: Optional[int] = None,
) -> list[dict]:
    """Fetch events whose ``sequence_no`` is greater than ``last_sequence_no``.

    With ``last_sequence_no=None`` the cursor is ``-1``, so the full
    backlog is returned (assuming sequence numbers start at 0 or above).
    Rows come back in ascending ``sequence_no`` order.

    The result is fetched in full with ``fetchall`` and returned as a
    list rather than a generator, so it is completely drained before the
    caller can issue another query on the same connection.

    Args:
        message_id: Message identifier; stringified and cast to ``uuid``.
        last_sequence_no: Exclusive lower bound, or ``None`` for all rows.

    Returns:
        One dict per row with ``message_id``, ``sequence_no``,
        ``event_type``, ``payload`` and ``created_at``.
    """
    floor = int(last_sequence_no) if last_sequence_no is not None else -1
    replay_query = text(
        """
        SELECT message_id, sequence_no, event_type, payload, created_at
        FROM message_events
        WHERE message_id = CAST(:message_id AS uuid)
          AND sequence_no > :cursor
        ORDER BY sequence_no ASC
        """
    )
    bind = {"message_id": str(message_id), "cursor": floor}
    fetched = self._conn.execute(replay_query, bind).fetchall()
    return [row_to_dict(record) for record in fetched]
|
||||
|
||||
def cleanup_older_than(self, ttl_days: int) -> int:
    """Delete ``message_events`` rows created more than ``ttl_days`` days ago.

    Old journal rows are dead weight: reconnect-replay only matters for
    streams a client could still be waiting on.

    Args:
        ttl_days: Retention window in days; must be positive.

    Returns:
        Number of rows deleted (``0`` when the driver reports no count).

    Raises:
        ValueError: If ``ttl_days`` is zero or negative.
    """
    if ttl_days <= 0:
        raise ValueError("ttl_days must be positive")

    purge_stmt = text(
        """
        DELETE FROM message_events
        WHERE created_at < now() - make_interval(days => :ttl_days)
        """
    )
    outcome = self._conn.execute(purge_stmt, {"ttl_days": int(ttl_days)})
    deleted = outcome.rowcount
    return int(deleted) if deleted else 0
|
||||
|
||||
def reconstruct_partial(self, message_id: str) -> dict:
    """Rebuild the partial response state of a message from its journal.

    Events are read in ascending ``sequence_no`` order. ``answer`` and
    ``thought`` chunks are concatenated in that order. ``source`` and
    ``tool_calls`` events each carry a full list, so the last one seen
    wins. Rows whose payload is not a dict, and values of the wrong
    type, are ignored.

    Returns:
        A dict with ``response``, ``thought``, ``sources`` and
        ``tool_calls`` keys.
    """
    journal_query = text(
        """
        SELECT sequence_no, event_type, payload
        FROM message_events
        WHERE message_id = CAST(:message_id AS uuid)
        ORDER BY sequence_no ASC
        """
    )
    events = self._conn.execute(
        journal_query, {"message_id": str(message_id)}
    ).fetchall()

    # Each payload stores its value under a key equal to the event type.
    text_chunks: dict[str, list[str]] = {"answer": [], "thought": []}
    latest_lists: dict[str, list] = {"source": [], "tool_calls": []}

    for event in events:
        body = event.payload
        if not isinstance(body, dict):
            continue
        kind = event.event_type
        if kind in text_chunks:
            piece = body.get(kind)
            if isinstance(piece, str):
                text_chunks[kind].append(piece)
        elif kind in latest_lists:
            snapshot = body.get(kind)
            if isinstance(snapshot, list):
                latest_lists[kind] = snapshot

    return {
        "response": "".join(text_chunks["answer"]),
        "thought": "".join(text_chunks["thought"]),
        "sources": latest_lists["source"],
        "tool_calls": latest_lists["tool_calls"],
    }
|
||||
|
||||
def latest_sequence_no(self, message_id: str) -> Optional[int]:
    """Return the highest ``sequence_no`` stored for ``message_id``.

    The route uses this to seed the per-stream allocator on retry or
    process restart, so a re-run continues numbering instead of
    colliding with earlier entries.

    Returns:
        The maximum as an ``int``, or ``None`` when the message has no
        journal rows.
    """
    max_query = text(
        """
        SELECT MAX(sequence_no) AS s
        FROM message_events
        WHERE message_id = CAST(:message_id AS uuid)
        """
    )
    top = self._conn.execute(max_query, {"message_id": str(message_id)}).first()
    # MAX always yields one row, NULL for an empty journal, so the value
    # is what has to be tested, not just the presence of a row.
    if top is None or top[0] is None:
        return None
    return int(top[0])
|
||||
@@ -7,6 +7,11 @@ Mirrors the continuation service's three operations on
|
||||
- load_state → find_one by (conversation_id, user_id)
|
||||
- delete_state → delete_one by (conversation_id, user_id)
|
||||
|
||||
Adds ``mark_resuming`` so a resumed run can claim a row without
|
||||
deleting it; a separate ``revert_stale_resuming`` flips abandoned
|
||||
``resuming`` rows back to ``pending`` so a crashed worker doesn't
|
||||
strand the user.
|
||||
|
||||
Plus a cleanup method for the Celery beat task that replaces Mongo's
|
||||
TTL index.
|
||||
"""
|
||||
@@ -20,6 +25,7 @@ from typing import Optional
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
PENDING_STATE_TTL_SECONDS = 30 * 60 # 1800 seconds
|
||||
|
||||
@@ -71,19 +77,24 @@ class PendingToolStateRepository:
|
||||
agent_config = EXCLUDED.agent_config,
|
||||
client_tools = EXCLUDED.client_tools,
|
||||
created_at = EXCLUDED.created_at,
|
||||
expires_at = EXCLUDED.expires_at
|
||||
expires_at = EXCLUDED.expires_at,
|
||||
status = 'pending',
|
||||
resumed_at = NULL
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{
|
||||
"conv_id": conversation_id,
|
||||
"user_id": user_id,
|
||||
"messages": json.dumps(messages),
|
||||
"pending": json.dumps(pending_tool_calls),
|
||||
"tools_dict": json.dumps(tools_dict),
|
||||
"schemas": json.dumps(tool_schemas),
|
||||
"agent_config": json.dumps(agent_config),
|
||||
"client_tools": json.dumps(client_tools) if client_tools is not None else None,
|
||||
"messages": json.dumps(messages, cls=PGNativeJSONEncoder),
|
||||
"pending": json.dumps(pending_tool_calls, cls=PGNativeJSONEncoder),
|
||||
"tools_dict": json.dumps(tools_dict, cls=PGNativeJSONEncoder),
|
||||
"schemas": json.dumps(tool_schemas, cls=PGNativeJSONEncoder),
|
||||
"agent_config": json.dumps(agent_config, cls=PGNativeJSONEncoder),
|
||||
"client_tools": (
|
||||
json.dumps(client_tools, cls=PGNativeJSONEncoder)
|
||||
if client_tools is not None else None
|
||||
),
|
||||
"created_at": now,
|
||||
"expires_at": expires,
|
||||
},
|
||||
@@ -113,6 +124,45 @@ class PendingToolStateRepository:
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def mark_resuming(self, conversation_id: str, user_id: str) -> bool:
    """Claim a ``pending`` row by setting it to ``resuming``.

    ``resumed_at`` is stamped with ``clock_timestamp()``. Only a row
    currently in ``pending`` for this ``(conversation_id, user_id)`` is
    updated.

    Returns:
        ``True`` when at least one row was updated.
    """
    claim_stmt = text(
        """
        UPDATE pending_tool_state
        SET status = 'resuming', resumed_at = clock_timestamp()
        WHERE conversation_id = CAST(:conv_id AS uuid)
          AND user_id = :user_id
          AND status = 'pending'
        """
    )
    bind = {"conv_id": conversation_id, "user_id": user_id}
    outcome = self._conn.execute(claim_stmt, bind)
    return outcome.rowcount > 0
|
||||
|
||||
def revert_stale_resuming(
    self,
    grace_seconds: int = 600,
    ttl_extension_seconds: int = PENDING_STATE_TTL_SECONDS,
) -> int:
    """Flip abandoned ``resuming`` rows back to ``pending``.

    A row is abandoned when ``resumed_at`` is more than ``grace_seconds``
    in the past. Reverted rows get ``resumed_at`` cleared and
    ``expires_at`` reset to ``ttl_extension_seconds`` from now.

    Returns:
        The ``rowcount`` reported for the UPDATE.
    """
    revert_stmt = text(
        """
        UPDATE pending_tool_state
        SET status = 'pending',
            resumed_at = NULL,
            expires_at = clock_timestamp()
                + make_interval(secs => :ttl)
        WHERE status = 'resuming'
          AND resumed_at
                < clock_timestamp() - make_interval(secs => :grace)
        """
    )
    bind = {"grace": grace_seconds, "ttl": ttl_extension_seconds}
    outcome = self._conn.execute(revert_stmt, bind)
    return outcome.rowcount
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""Delete rows where ``expires_at < now()``.
|
||||
|
||||
|
||||
282
application/storage/db/repositories/reconciliation.py
Normal file
282
application/storage/db/repositories/reconciliation.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""Repository for reconciliation sweeps over stuck durability rows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class ReconciliationRepository:
|
||||
"""Sweeps and terminal writes for the reconciler beat task."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
    """Keep ``conn`` as the connection every sweep and terminal write runs on."""
    self._conn = conn
|
||||
|
||||
def find_and_lock_stuck_messages(
    self,
    *,
    age_minutes: int = 5,
    limit: int = 100,
    resuming_grace_minutes: int = 10,
) -> list[dict]:
    """Lock stuck pending/streaming messages, skipping live resumes.

    A message is selected when its status is ``pending`` or
    ``streaming`` and both ``cm.timestamp`` (creation) and
    ``message_metadata.last_heartbeat_at`` (route heartbeat, falling
    back to ``cm.timestamp`` when absent) are older than
    ``age_minutes``. Staleness therefore rides on the later of the two,
    and an in-flight stream that keeps re-stamping its heartbeat stays
    out of the sweep. Reconciler-side writes deliberately leave both
    columns alone so the per-row attempts counter advances across ticks.

    Messages are exempt while their conversation has a
    ``pending_tool_state`` row that is either ``pending`` and not yet
    expired (paused, waiting for resume) or ``resuming`` with
    ``resumed_at`` inside the last ``resuming_grace_minutes``.

    Rows are locked with ``FOR UPDATE OF cm SKIP LOCKED``, oldest first.

    Args:
        age_minutes: Minimum age before a message counts as stuck.
        limit: Maximum number of rows to lock.
        resuming_grace_minutes: How long a ``resuming`` tool-state row
            shields its conversation. The default of 10 matches the
            previously hard-coded ``interval '10 minutes'``.

    Returns:
        One dict per locked row with ``id``, ``conversation_id``,
        ``user_id``, ``timestamp`` and ``message_metadata``.
    """
    sweep_query = text(
        """
        SELECT cm.id, cm.conversation_id, cm.user_id, cm.timestamp,
               cm.message_metadata
        FROM conversation_messages cm
        WHERE cm.status IN ('pending', 'streaming')
          AND cm.timestamp < now() - make_interval(mins => :age)
          AND COALESCE(
                (cm.message_metadata->>'last_heartbeat_at')::timestamptz,
                cm.timestamp
              ) < now() - make_interval(mins => :age)
          AND NOT EXISTS (
                SELECT 1
                FROM pending_tool_state pts
                WHERE pts.conversation_id = cm.conversation_id
                  AND (
                        (pts.status = 'pending'
                         AND pts.expires_at > now())
                        OR
                        (pts.status = 'resuming'
                         AND pts.resumed_at
                                > now() - make_interval(mins => :resume_grace))
                  )
          )
        ORDER BY cm.timestamp ASC
        LIMIT :limit
        FOR UPDATE OF cm SKIP LOCKED
        """
    )
    bind = {
        "age": age_minutes,
        "limit": limit,
        "resume_grace": resuming_grace_minutes,
    }
    result = self._conn.execute(sweep_query, bind)
    return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def find_and_lock_proposed_tool_calls(
    self, *, age_minutes: int = 5, limit: int = 100,
) -> list[dict]:
    """Lock ``tool_call_attempts`` rows still in ``proposed`` after ``age_minutes``.

    Age is measured from ``attempted_at``. Rows are taken oldest first,
    capped at ``limit``, and locked with ``FOR UPDATE SKIP LOCKED``.

    Returns:
        One dict per locked row.
    """
    sweep_query = text(
        """
        SELECT call_id, message_id, tool_id, tool_name, action_name,
               arguments, attempted_at, updated_at
        FROM tool_call_attempts
        WHERE status = 'proposed'
          AND attempted_at < now() - make_interval(mins => :age)
        ORDER BY attempted_at ASC
        LIMIT :limit
        FOR UPDATE SKIP LOCKED
        """
    )
    bind = {"age": age_minutes, "limit": limit}
    locked = self._conn.execute(sweep_query, bind).fetchall()
    return [row_to_dict(entry) for entry in locked]
|
||||
|
||||
def find_and_lock_executed_tool_calls(
    self, *, age_minutes: int = 15, limit: int = 100,
) -> list[dict]:
    """Lock ``tool_call_attempts`` rows left in ``executed`` past the confirm window.

    Age is measured from ``updated_at``. Rows are taken oldest first,
    capped at ``limit``, and locked with ``FOR UPDATE SKIP LOCKED``.

    Returns:
        One dict per locked row, including the stored ``result``.
    """
    sweep_query = text(
        """
        SELECT call_id, message_id, tool_id, tool_name, action_name,
               arguments, result, attempted_at, updated_at
        FROM tool_call_attempts
        WHERE status = 'executed'
          AND updated_at < now() - make_interval(mins => :age)
        ORDER BY updated_at ASC
        LIMIT :limit
        FOR UPDATE SKIP LOCKED
        """
    )
    bind = {"age": age_minutes, "limit": limit}
    locked = self._conn.execute(sweep_query, bind).fetchall()
    return [row_to_dict(entry) for entry in locked]
|
||||
|
||||
def find_and_lock_stalled_ingests(
    self, *, age_minutes: int = 30, limit: int = 100,
) -> list[dict]:
    """Lock unfinished, still-``active`` ingest checkpoints with a silent heartbeat.

    A checkpoint qualifies when ``last_updated`` is older than
    ``age_minutes``, ``embedded_chunks`` is below ``total_chunks`` and
    ``status`` is ``'active'``. The status filter skips rows already
    escalated to ``'stalled'``, so a dead ingest is alerted once rather
    than on every tick.

    Returns:
        One dict per locked row, oldest ``last_updated`` first.
    """
    sweep_query = text(
        """
        SELECT source_id, total_chunks, embedded_chunks,
               last_index, last_updated
        FROM ingest_chunk_progress
        WHERE last_updated < now() - make_interval(mins => :age)
          AND embedded_chunks < total_chunks
          AND status = 'active'
        ORDER BY last_updated ASC
        LIMIT :limit
        FOR UPDATE SKIP LOCKED
        """
    )
    bind = {"age": age_minutes, "limit": limit}
    locked = self._conn.execute(sweep_query, bind).fetchall()
    return [row_to_dict(entry) for entry in locked]
|
||||
|
||||
def mark_ingest_stalled(self, source_id: str) -> bool:
    """Set ``status='stalled'`` on the checkpoint for ``source_id``.

    The new status drops the row out of the stalled-ingest sweep, so the
    reconciler alerts once. ``init_progress`` flips it back to
    ``'active'`` on reingest.

    Returns:
        ``True`` when a row was updated.
    """
    escalate_stmt = text(
        "UPDATE ingest_chunk_progress SET status = 'stalled' "
        "WHERE source_id = CAST(:sid AS uuid)"
    )
    outcome = self._conn.execute(escalate_stmt, {"sid": str(source_id)})
    return outcome.rowcount > 0
|
||||
|
||||
def increment_message_reconcile_attempts(self, message_id: str) -> int:
    """Add one to ``message_metadata.reconcile_attempts`` and return the result.

    A NULL ``message_metadata`` is treated as ``{}`` and a missing
    counter as ``0``.

    Returns:
        The new counter value, or ``0`` when no row matched ``message_id``.
    """
    bump_stmt = text(
        """
        UPDATE conversation_messages
        SET message_metadata = jsonb_set(
            COALESCE(message_metadata, '{}'::jsonb),
            '{reconcile_attempts}',
            to_jsonb(
                COALESCE(
                    (message_metadata->>'reconcile_attempts')::int,
                    0
                ) + 1
            )
        )
        WHERE id = CAST(:message_id AS uuid)
        RETURNING (message_metadata->>'reconcile_attempts')::int
            AS new_count
        """
    )
    updated = self._conn.execute(bump_stmt, {"message_id": message_id}).fetchone()
    if updated is None:
        return 0
    return int(updated[0])
|
||||
|
||||
def mark_message_failed(self, message_id: str, *, error: str) -> bool:
    """Set a message to ``status='failed'`` and store ``error`` in its metadata.

    The text is written under the ``error`` key of ``message_metadata``.
    A NULL metadata column is treated as ``{}``.

    Returns:
        ``True`` when a row was updated.
    """
    fail_stmt = text(
        """
        UPDATE conversation_messages
        SET status = 'failed',
            message_metadata = jsonb_set(
                COALESCE(message_metadata, '{}'::jsonb),
                '{error}',
                to_jsonb(CAST(:error AS text))
            )
        WHERE id = CAST(:message_id AS uuid)
        """
    )
    bind = {"message_id": message_id, "error": error}
    outcome = self._conn.execute(fail_stmt, bind)
    return outcome.rowcount > 0
|
||||
|
||||
def mark_tool_call_failed(self, call_id: str, *, error: str) -> bool:
    """Set a ``tool_call_attempts`` row to ``failed`` and record ``error``.

    Returns:
        ``True`` when a row with this ``call_id`` was updated.
    """
    fail_stmt = text(
        "UPDATE tool_call_attempts SET status = 'failed', "
        "error = :error WHERE call_id = :call_id"
    )
    bind = {"call_id": call_id, "error": error}
    outcome = self._conn.execute(fail_stmt, bind)
    return outcome.rowcount > 0
|
||||
|
||||
def find_stuck_idempotency_pending(
    self,
    *,
    max_attempts: int,
    lease_grace_seconds: int = 60,
    limit: int = 100,
) -> list[dict]:
    """Lock ``task_dedup`` rows abandoned past the lease and the retry budget.

    A row counts as stuck when all of these hold:

    - ``status='pending'``: the lease was claimed but never finalised.
    - ``lease_expires_at`` is set and lies at least
      ``lease_grace_seconds`` in the past, so the heartbeat is gone and
      the lease is not coming back.
    - ``attempt_count >= max_attempts``: the poison-loop guard should
      already have escalated the row; if it has not, the wrapper died
      before getting there.

    Left alone, such rows sit in ``pending`` until the 24 h TTL ages
    them out, blocking same-key retries because ``_lookup_completed``
    returns None for the whole window.

    Returns:
        One dict per locked row, oldest ``created_at`` first, capped at
        ``limit``.
    """
    sweep_query = text(
        """
        SELECT idempotency_key, task_name, task_id, attempt_count,
               lease_owner_id, lease_expires_at, created_at
        FROM task_dedup
        WHERE status = 'pending'
          AND lease_expires_at IS NOT NULL
          AND lease_expires_at
                < now() - make_interval(secs => :grace)
          AND attempt_count >= :max_attempts
        ORDER BY created_at ASC
        LIMIT :limit
        FOR UPDATE SKIP LOCKED
        """
    )
    bind = {
        "max_attempts": int(max_attempts),
        "grace": int(lease_grace_seconds),
        "limit": int(limit),
    }
    locked = self._conn.execute(sweep_query, bind).fetchall()
    return [row_to_dict(entry) for entry in locked]
|
||||
|
||||
def mark_idempotency_pending_failed(
    self, key: str, *, error: str,
) -> bool:
    """Promote a stuck ``pending`` ``task_dedup`` row to ``failed``.

    ``result_json`` becomes ``{"success": false, "error": <error>,
    "reconciled": true}`` and both lease columns are cleared. The update
    only applies while the row is still ``pending``.

    Returns:
        ``True`` when a row was updated.
    """
    from application.storage.db.serialization import PGNativeJSONEncoder
    import json

    failure_record = {
        "success": False,
        "error": error,
        "reconciled": True,
    }
    fail_stmt = text(
        """
        UPDATE task_dedup
        SET status = 'failed',
            result_json = CAST(:result AS jsonb),
            lease_owner_id = NULL,
            lease_expires_at = NULL
        WHERE idempotency_key = :key
          AND status = 'pending'
        """
    )
    bind = {
        "key": key,
        "result": json.dumps(failure_record, cls=PGNativeJSONEncoder),
    }
    outcome = self._conn.execute(fail_stmt, bind)
    return outcome.rowcount > 0
|
||||
278
application/storage/db/repositories/schedule_runs.py
Normal file
278
application/storage/db/repositories/schedule_runs.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Repository for ``schedule_runs`` (record_pending is the dedup primitive)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
_ALLOWED_UPDATES = frozenset(
|
||||
{
|
||||
"status", "started_at", "finished_at", "output", "output_truncated",
|
||||
"error", "error_type", "prompt_tokens", "generated_tokens",
|
||||
"conversation_id", "message_id", "celery_task_id",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ScheduleRunsRepository:
|
||||
"""CRUD + dedup insert + reconciliation sweep for ``schedule_runs``."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
    """Keep ``conn`` as the connection every query in this repository runs on."""
    self._conn = conn
|
||||
|
||||
def record_pending(
    self,
    schedule_id: str,
    user_id: str,
    agent_id: Optional[str],
    scheduled_for: datetime,
    *,
    trigger_source: str = "cron",
) -> Optional[dict]:
    """Insert a ``pending`` run row for ``(schedule_id, scheduled_for)``.

    The insert uses ``ON CONFLICT (schedule_id, scheduled_for) DO
    NOTHING``, so it doubles as the dedup primitive: whoever gets a row
    back has claimed the slot.

    Args:
        schedule_id: Owning schedule; stringified and cast to ``uuid``.
        user_id: Owner of the run.
        agent_id: Optional agent; a falsy value is stored as NULL.
        scheduled_for: The slot this run belongs to.
        trigger_source: What triggered the run.

    Returns:
        The inserted row as a dict, or ``None`` when the slot was
        already claimed.
    """
    claim_stmt = text(
        """
        INSERT INTO schedule_runs (
            schedule_id, user_id, agent_id, scheduled_for,
            trigger_source, status
        ) VALUES (
            CAST(:schedule_id AS uuid),
            :user_id,
            CAST(:agent_id AS uuid),
            :scheduled_for,
            :trigger_source,
            'pending'
        )
        ON CONFLICT (schedule_id, scheduled_for) DO NOTHING
        RETURNING *
        """
    )
    bind = {
        "schedule_id": str(schedule_id),
        "user_id": user_id,
        "agent_id": str(agent_id) if agent_id else None,
        "scheduled_for": scheduled_for,
        "trigger_source": trigger_source,
    }
    inserted = self._conn.execute(claim_stmt, bind).fetchone()
    if inserted is None:
        return None
    return row_to_dict(inserted)
|
||||
|
||||
def record_skipped(
    self,
    schedule_id: str,
    user_id: str,
    agent_id: Optional[str],
    scheduled_for: datetime,
    *,
    error_type: str,
    error: Optional[str] = None,
    trigger_source: str = "cron",
) -> Optional[dict]:
    """Insert a terminal ``skipped`` run row for ``(schedule_id, scheduled_for)``.

    ``started_at`` and ``finished_at`` are both stamped with ``now()``.
    The insert uses ``ON CONFLICT (schedule_id, scheduled_for) DO
    NOTHING``, so an existing row for the slot is left untouched.

    Args:
        schedule_id: Owning schedule; stringified and cast to ``uuid``.
        user_id: Owner of the run.
        agent_id: Optional agent; a falsy value is stored as NULL.
        scheduled_for: The slot being skipped.
        error_type: Machine-readable reason for the skip.
        error: Optional human-readable detail.
        trigger_source: What triggered the run. Defaults to ``"cron"``,
            the value that used to be hard-coded here, and mirrors the
            parameter of the same name on ``record_pending``.

    Returns:
        The inserted row as a dict, or ``None`` on conflict.
    """
    row = self._conn.execute(
        text(
            """
            INSERT INTO schedule_runs (
                schedule_id, user_id, agent_id, scheduled_for,
                trigger_source, status, started_at, finished_at,
                error, error_type
            ) VALUES (
                CAST(:schedule_id AS uuid),
                :user_id,
                CAST(:agent_id AS uuid),
                :scheduled_for,
                :trigger_source,
                'skipped',
                now(),
                now(),
                :error,
                :error_type
            )
            ON CONFLICT (schedule_id, scheduled_for) DO NOTHING
            RETURNING *
            """
        ),
        {
            "schedule_id": str(schedule_id),
            "user_id": user_id,
            "agent_id": str(agent_id) if agent_id else None,
            "scheduled_for": scheduled_for,
            "trigger_source": trigger_source,
            "error": error,
            "error_type": error_type,
        },
    ).fetchone()
    return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get(self, run_id: str, user_id: str) -> Optional[dict]:
    """Fetch a run row by id, scoped to its owner.

    Returns:
        The row as a dict, or ``None`` when no row matches both
        ``run_id`` and ``user_id``.
    """
    owned_query = text(
        "SELECT * FROM schedule_runs "
        "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
    )
    bind = {"id": str(run_id), "user_id": user_id}
    found = self._conn.execute(owned_query, bind).fetchone()
    if found is None:
        return None
    return row_to_dict(found)
|
||||
|
||||
def get_internal(self, run_id: str) -> Optional[dict]:
    """Fetch a run row by id with no ownership check (worker-only).

    Returns:
        The row as a dict, or ``None`` when ``run_id`` is unknown.
    """
    lookup = text("SELECT * FROM schedule_runs WHERE id = CAST(:id AS uuid)")
    found = self._conn.execute(lookup, {"id": str(run_id)}).fetchone()
    if found is None:
        return None
    return row_to_dict(found)
|
||||
|
||||
def has_active_run(self, schedule_id: str) -> bool:
    """Report whether the schedule has a run in ``pending`` or ``running``.

    Returns:
        ``True`` iff at least one such row exists.
    """
    probe = text(
        "SELECT 1 FROM schedule_runs "
        "WHERE schedule_id = CAST(:id AS uuid) "
        "AND status IN ('pending', 'running') "
        "LIMIT 1"
    )
    hit = self._conn.execute(probe, {"id": str(schedule_id)}).first()
    return hit is not None
|
||||
|
||||
def list_runs(
    self,
    schedule_id: str,
    user_id: str,
    *,
    limit: int = 50,
    offset: int = 0,
) -> list[dict]:
    """Return one page of the run log for an owned schedule, newest first.

    Rows are ordered by ``scheduled_for`` descending and filtered on
    both ``schedule_id`` and ``user_id``.

    Args:
        schedule_id: Schedule whose runs are listed.
        user_id: Owner used for scoping.
        limit: Page size; coerced with ``int``.
        offset: Number of rows to skip; coerced with ``int``.

    Returns:
        One dict per run row.
    """
    page_query = text(
        """
        SELECT * FROM schedule_runs
        WHERE schedule_id = CAST(:id AS uuid) AND user_id = :user_id
        ORDER BY scheduled_for DESC
        LIMIT :limit OFFSET :offset
        """
    )
    bind = {
        "id": str(schedule_id),
        "user_id": user_id,
        "limit": int(limit),
        "offset": int(offset),
    }
    page = self._conn.execute(page_query, bind).fetchall()
    return [row_to_dict(entry) for entry in page]
|
||||
|
||||
def update(self, run_id: str, fields: dict) -> Optional[dict]:
    """Apply a partial update restricted to the ``_ALLOWED_UPDATES`` columns.

    Keys outside the whitelist are silently dropped. If nothing remains
    the row is returned unchanged via ``get_internal``. Column names in
    the generated SQL come only from the whitelist; values are always
    bound as parameters. ``conversation_id`` and ``message_id`` are cast
    to ``uuid``, with falsy values stored as NULL.

    Returns:
        The updated row as a dict, or ``None`` when ``run_id`` matched
        no row.
    """
    accepted = {k: v for k, v in fields.items() if k in _ALLOWED_UPDATES}
    if not accepted:
        return self.get_internal(run_id)

    uuid_columns = ("conversation_id", "message_id")
    assignments: list[str] = []
    bind: dict[str, Any] = {"id": str(run_id)}
    for column, value in accepted.items():
        if column in uuid_columns:
            assignments.append(f"{column} = CAST(:{column} AS uuid)")
            bind[column] = str(value) if value else None
        else:
            assignments.append(f"{column} = :{column}")
            bind[column] = value

    statement = (
        "UPDATE schedule_runs SET "
        + ", ".join(assignments)
        + " WHERE id = CAST(:id AS uuid) RETURNING *"
    )
    updated = self._conn.execute(text(statement), bind).fetchone()
    if updated is None:
        return None
    return row_to_dict(updated)
|
||||
|
||||
def mark_running(self, run_id: str, celery_task_id: Optional[str]) -> bool:
    """Move a ``pending`` run to ``running``.

    ``started_at`` is stamped with ``now()`` and ``celery_task_id`` is
    recorded. The update only applies while the row is still
    ``pending``.

    Returns:
        ``True`` when the transition happened.
    """
    start_stmt = text(
        """
        UPDATE schedule_runs
        SET status = 'running',
            started_at = now(),
            celery_task_id = :celery_task_id
        WHERE id = CAST(:id AS uuid)
          AND status = 'pending'
        """
    )
    bind = {"id": str(run_id), "celery_task_id": celery_task_id}
    outcome = self._conn.execute(start_stmt, bind)
    affected = outcome.rowcount or 0
    return affected > 0
|
||||
|
||||
def list_stuck_running(
    self, *, age_minutes: int = 15, limit: int = 50,
) -> list[dict]:
    """Lock ``running`` rows that started more than ``age_minutes`` ago.

    The window is meant to sit past the soft-time-limit envelope. Rows
    are taken oldest ``started_at`` first, capped at ``limit``, and
    locked with ``FOR UPDATE SKIP LOCKED``.

    Returns:
        One dict per locked row.
    """
    sweep_query = text(
        """
        SELECT * FROM schedule_runs
        WHERE status = 'running'
          AND started_at IS NOT NULL
          AND started_at < now() - make_interval(mins => :age)
        ORDER BY started_at ASC
        LIMIT :limit
        FOR UPDATE SKIP LOCKED
        """
    )
    bind = {"age": int(age_minutes), "limit": int(limit)}
    locked = self._conn.execute(sweep_query, bind).fetchall()
    return [row_to_dict(entry) for entry in locked]
|
||||
|
||||
def list_stuck_pending(
    self, *, age_minutes: int = 15, limit: int = 50,
) -> list[dict]:
    """Lock ``pending`` rows that no worker picked up within ``age_minutes``.

    Age is measured from ``created_at`` and only rows with a NULL
    ``started_at`` qualify. Rows are taken oldest first, capped at
    ``limit``, and locked with ``FOR UPDATE SKIP LOCKED``.

    Returns:
        One dict per locked row.
    """
    sweep_query = text(
        """
        SELECT * FROM schedule_runs
        WHERE status = 'pending'
          AND started_at IS NULL
          AND created_at < now() - make_interval(mins => :age)
        ORDER BY created_at ASC
        LIMIT :limit
        FOR UPDATE SKIP LOCKED
        """
    )
    bind = {"age": int(age_minutes), "limit": int(limit)}
    locked = self._conn.execute(sweep_query, bind).fetchall()
    return [row_to_dict(entry) for entry in locked]
|
||||
|
||||
def cleanup_older_than(
    self,
    ttl_days: int,
    *,
    keep_recent_per_schedule: int = 50,
) -> int:
    """Trim old run rows while keeping each schedule's recent log slice.

    Runs are ranked per schedule by ``scheduled_for`` descending. A row
    is deleted only when it ranks beyond ``keep_recent_per_schedule``
    AND its ``created_at`` is more than ``ttl_days`` days old.

    Args:
        ttl_days: Retention window in days; must be positive.
        keep_recent_per_schedule: Newest rows per schedule that are
            always kept.

    Returns:
        Number of rows deleted (``0`` when the driver reports no count).

    Raises:
        ValueError: If ``ttl_days`` is zero or negative.
    """
    if ttl_days <= 0:
        raise ValueError("ttl_days must be positive")

    trim_stmt = text(
        """
        DELETE FROM schedule_runs
        WHERE id IN (
            SELECT id FROM (
                SELECT id,
                       ROW_NUMBER() OVER (
                           PARTITION BY schedule_id
                           ORDER BY scheduled_for DESC
                       ) AS rn,
                       created_at
                FROM schedule_runs
            ) ranked
            WHERE ranked.rn > :keep
              AND ranked.created_at < now() - make_interval(days => :ttl)
        )
        """
    )
    bind = {"keep": int(keep_recent_per_schedule), "ttl": int(ttl_days)}
    outcome = self._conn.execute(trim_stmt, bind)
    deleted = outcome.rowcount
    return int(deleted) if deleted else 0
|
||||
352
application/storage/db/repositories/schedules.py
Normal file
352
application/storage/db/repositories/schedules.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""Repository for the ``schedules`` table (CRUD + dispatcher claim query)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Iterable, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
_ALLOWED_UPDATES = frozenset(
|
||||
{
|
||||
"name", "instruction", "status", "cron", "run_at", "timezone",
|
||||
"next_run_at", "last_run_at", "end_at", "tool_allowlist",
|
||||
"model_id", "token_budget", "consecutive_failure_count",
|
||||
"origin_conversation_id",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class SchedulesRepository:
|
||||
"""CRUD + dispatcher hot path for ``schedules``."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
    """Store ``conn`` on the instance as ``_conn`` for the repository's queries."""
    self._conn = conn
|
||||
|
||||
def create(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_id: Optional[str],
|
||||
trigger_type: str,
|
||||
instruction: str,
|
||||
*,
|
||||
cron: Optional[str] = None,
|
||||
run_at: Optional[datetime] = None,
|
||||
timezone: str = "UTC",
|
||||
next_run_at: Optional[datetime] = None,
|
||||
end_at: Optional[datetime] = None,
|
||||
name: Optional[str] = None,
|
||||
tool_allowlist: Optional[Iterable[str]] = None,
|
||||
model_id: Optional[str] = None,
|
||||
token_budget: Optional[int] = None,
|
||||
origin_conversation_id: Optional[str] = None,
|
||||
created_via: str = "ui",
|
||||
status: str = "active",
|
||||
) -> dict:
|
||||
"""Insert a new schedule and return the populated row."""
|
||||
params = {
|
||||
"user_id": user_id,
|
||||
"agent_id": str(agent_id) if agent_id else None,
|
||||
"trigger_type": trigger_type,
|
||||
"instruction": instruction,
|
||||
"cron": cron,
|
||||
"run_at": run_at,
|
||||
"tz": timezone,
|
||||
"next_run_at": next_run_at,
|
||||
"end_at": end_at,
|
||||
"name": name,
|
||||
"allowlist": json.dumps(list(tool_allowlist or [])),
|
||||
"model_id": model_id,
|
||||
"token_budget": int(token_budget) if token_budget is not None else None,
|
||||
"origin_conversation_id": (
|
||||
str(origin_conversation_id) if origin_conversation_id else None
|
||||
),
|
||||
"created_via": created_via,
|
||||
"status": status,
|
||||
}
|
||||
row = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO schedules (
|
||||
user_id, agent_id, trigger_type, instruction, status,
|
||||
cron, run_at, timezone, next_run_at, end_at, name,
|
||||
tool_allowlist, model_id, token_budget,
|
||||
origin_conversation_id, created_via
|
||||
) VALUES (
|
||||
:user_id,
|
||||
CAST(:agent_id AS uuid),
|
||||
:trigger_type,
|
||||
:instruction,
|
||||
:status,
|
||||
:cron,
|
||||
:run_at,
|
||||
:tz,
|
||||
:next_run_at,
|
||||
:end_at,
|
||||
:name,
|
||||
CAST(:allowlist AS jsonb),
|
||||
:model_id,
|
||||
:token_budget,
|
||||
CAST(:origin_conversation_id AS uuid),
|
||||
:created_via
|
||||
) RETURNING *
|
||||
"""
|
||||
),
|
||||
params,
|
||||
).fetchone()
|
||||
return row_to_dict(row)
|
||||
|
||||
def get(self, schedule_id: str, user_id: str) -> Optional[dict]:
|
||||
"""Fetch an owned schedule (None when missing or owned by another)."""
|
||||
row = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM schedules "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
{"id": str(schedule_id), "user_id": user_id},
|
||||
).fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_internal(self, schedule_id: str) -> Optional[dict]:
|
||||
"""Fetch a schedule with no ownership scoping (worker-only)."""
|
||||
row = self._conn.execute(
|
||||
text("SELECT * FROM schedules WHERE id = CAST(:id AS uuid)"),
|
||||
{"id": str(schedule_id)},
|
||||
).fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def get_for_update(
|
||||
self, schedule_id: str, user_id: str,
|
||||
) -> Optional[dict]:
|
||||
"""Owned fetch with FOR UPDATE; closes the Run-Now TOCTOU."""
|
||||
row = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM schedules "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id "
|
||||
"FOR UPDATE"
|
||||
),
|
||||
{"id": str(schedule_id), "user_id": user_id},
|
||||
).fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def list_for_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
statuses: Optional[Iterable[str]] = None,
|
||||
trigger_type: Optional[str] = None,
|
||||
) -> list[dict]:
|
||||
"""Owned schedules for an agent, newest-created first."""
|
||||
sql = (
|
||||
"SELECT * FROM schedules "
|
||||
"WHERE agent_id = CAST(:agent_id AS uuid) AND user_id = :user_id"
|
||||
)
|
||||
params: dict[str, Any] = {"agent_id": str(agent_id), "user_id": user_id}
|
||||
if statuses is not None:
|
||||
status_list = [str(s) for s in statuses]
|
||||
if not status_list:
|
||||
return []
|
||||
placeholders = ", ".join(f":s{i}" for i, _ in enumerate(status_list))
|
||||
sql += f" AND status IN ({placeholders})"
|
||||
for i, s in enumerate(status_list):
|
||||
params[f"s{i}"] = s
|
||||
if trigger_type:
|
||||
sql += " AND trigger_type = :trigger_type"
|
||||
params["trigger_type"] = trigger_type
|
||||
sql += " ORDER BY created_at DESC"
|
||||
rows = self._conn.execute(text(sql), params).fetchall()
|
||||
return [row_to_dict(r) for r in rows]
|
||||
|
||||
def list_for_conversation(
|
||||
self,
|
||||
user_id: str,
|
||||
origin_conversation_id: str,
|
||||
*,
|
||||
statuses: Optional[Iterable[str]] = None,
|
||||
trigger_type: Optional[str] = None,
|
||||
) -> list[dict]:
|
||||
"""Owned agentless schedules anchored to an originating conversation."""
|
||||
sql = (
|
||||
"SELECT * FROM schedules "
|
||||
"WHERE user_id = :user_id "
|
||||
"AND agent_id IS NULL "
|
||||
"AND origin_conversation_id = CAST(:conv AS uuid)"
|
||||
)
|
||||
params: dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"conv": str(origin_conversation_id),
|
||||
}
|
||||
if statuses is not None:
|
||||
status_list = [str(s) for s in statuses]
|
||||
if not status_list:
|
||||
return []
|
||||
placeholders = ", ".join(f":s{i}" for i, _ in enumerate(status_list))
|
||||
sql += f" AND status IN ({placeholders})"
|
||||
for i, s in enumerate(status_list):
|
||||
params[f"s{i}"] = s
|
||||
if trigger_type:
|
||||
sql += " AND trigger_type = :trigger_type"
|
||||
params["trigger_type"] = trigger_type
|
||||
sql += " ORDER BY created_at DESC"
|
||||
rows = self._conn.execute(text(sql), params).fetchall()
|
||||
return [row_to_dict(r) for r in rows]
|
||||
|
||||
def list_for_user(self, user_id: str, *, limit: int = 200) -> list[dict]:
|
||||
"""Owned schedules across all agents — admin / debugging path."""
|
||||
rows = self._conn.execute(
|
||||
text(
|
||||
"SELECT * FROM schedules WHERE user_id = :user_id "
|
||||
"ORDER BY created_at DESC LIMIT :limit"
|
||||
),
|
||||
{"user_id": user_id, "limit": int(limit)},
|
||||
).fetchall()
|
||||
return [row_to_dict(r) for r in rows]
|
||||
|
||||
def count_active_for_user(self, user_id: str) -> int:
|
||||
"""Active+paused schedules for quota enforcement."""
|
||||
scalar = self._conn.execute(
|
||||
text(
|
||||
"SELECT COUNT(*) FROM schedules "
|
||||
"WHERE user_id = :user_id AND status IN ('active', 'paused')"
|
||||
),
|
||||
{"user_id": user_id},
|
||||
).scalar()
|
||||
return int(scalar or 0)
|
||||
|
||||
def list_due(self, *, limit: int = 100) -> list[dict]:
|
||||
"""Lock and return schedules with ``next_run_at <= now()``."""
|
||||
rows = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT * FROM schedules
|
||||
WHERE status = 'active'
|
||||
AND next_run_at IS NOT NULL
|
||||
AND next_run_at <= now()
|
||||
AND (end_at IS NULL OR next_run_at <= end_at)
|
||||
ORDER BY next_run_at ASC
|
||||
LIMIT :limit
|
||||
FOR UPDATE SKIP LOCKED
|
||||
"""
|
||||
),
|
||||
{"limit": int(limit)},
|
||||
).fetchall()
|
||||
return [row_to_dict(r) for r in rows]
|
||||
|
||||
def update(
|
||||
self,
|
||||
schedule_id: str,
|
||||
user_id: str,
|
||||
fields: dict,
|
||||
) -> Optional[dict]:
|
||||
"""Apply a whitelisted partial update; return the new row or None."""
|
||||
filtered = {k: v for k, v in fields.items() if k in _ALLOWED_UPDATES}
|
||||
if not filtered:
|
||||
return self.get(schedule_id, user_id)
|
||||
set_parts: list[str] = []
|
||||
params: dict[str, Any] = {"id": str(schedule_id), "user_id": user_id}
|
||||
for key, val in filtered.items():
|
||||
if key == "tool_allowlist":
|
||||
set_parts.append("tool_allowlist = CAST(:tool_allowlist AS jsonb)")
|
||||
params["tool_allowlist"] = json.dumps(list(val or []))
|
||||
elif key == "origin_conversation_id":
|
||||
set_parts.append(
|
||||
"origin_conversation_id = CAST(:origin_conversation_id AS uuid)"
|
||||
)
|
||||
params["origin_conversation_id"] = str(val) if val else None
|
||||
else:
|
||||
set_parts.append(f"{key} = :{key}")
|
||||
params[key] = val
|
||||
sql = (
|
||||
"UPDATE schedules SET " + ", ".join(set_parts) +
|
||||
" WHERE id = CAST(:id AS uuid) AND user_id = :user_id "
|
||||
"RETURNING *"
|
||||
)
|
||||
row = self._conn.execute(text(sql), params).fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def update_internal(self, schedule_id: str, fields: dict) -> None:
|
||||
"""Apply a whitelisted partial update from a worker context."""
|
||||
filtered = {k: v for k, v in fields.items() if k in _ALLOWED_UPDATES}
|
||||
if not filtered:
|
||||
return
|
||||
set_parts: list[str] = []
|
||||
params: dict[str, Any] = {"id": str(schedule_id)}
|
||||
for key, val in filtered.items():
|
||||
if key == "tool_allowlist":
|
||||
set_parts.append("tool_allowlist = CAST(:tool_allowlist AS jsonb)")
|
||||
params["tool_allowlist"] = json.dumps(list(val or []))
|
||||
elif key == "origin_conversation_id":
|
||||
set_parts.append(
|
||||
"origin_conversation_id = CAST(:origin_conversation_id AS uuid)"
|
||||
)
|
||||
params["origin_conversation_id"] = str(val) if val else None
|
||||
else:
|
||||
set_parts.append(f"{key} = :{key}")
|
||||
params[key] = val
|
||||
sql = (
|
||||
"UPDATE schedules SET " + ", ".join(set_parts) +
|
||||
" WHERE id = CAST(:id AS uuid)"
|
||||
)
|
||||
self._conn.execute(text(sql), params)
|
||||
|
||||
def cancel(self, schedule_id: str, user_id: str) -> bool:
|
||||
"""Soft-cancel — flips ``status`` to ``cancelled`` and clears ``next_run_at``."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE schedules SET status = 'cancelled', next_run_at = NULL "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id "
|
||||
"AND status NOT IN ('cancelled', 'completed')"
|
||||
),
|
||||
{"id": str(schedule_id), "user_id": user_id},
|
||||
)
|
||||
return (result.rowcount or 0) > 0
|
||||
|
||||
def delete(self, schedule_id: str, user_id: str) -> bool:
|
||||
"""Hard-delete an owned schedule and its runs (FK cascade)."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM schedules "
|
||||
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||
),
|
||||
{"id": str(schedule_id), "user_id": user_id},
|
||||
)
|
||||
return (result.rowcount or 0) > 0
|
||||
|
||||
def bump_failure_count(self, schedule_id: str) -> int:
|
||||
"""Increment ``consecutive_failure_count`` and return the new value."""
|
||||
row = self._conn.execute(
|
||||
text(
|
||||
"UPDATE schedules "
|
||||
"SET consecutive_failure_count = consecutive_failure_count + 1 "
|
||||
"WHERE id = CAST(:id AS uuid) "
|
||||
"RETURNING consecutive_failure_count"
|
||||
),
|
||||
{"id": str(schedule_id)},
|
||||
).fetchone()
|
||||
return int(row[0]) if row is not None else 0
|
||||
|
||||
def reset_failure_count(self, schedule_id: str) -> None:
|
||||
"""Reset the failure counter to 0 after a successful run."""
|
||||
self._conn.execute(
|
||||
text(
|
||||
"UPDATE schedules SET consecutive_failure_count = 0 "
|
||||
"WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
{"id": str(schedule_id)},
|
||||
)
|
||||
|
||||
def autopause(self, schedule_id: str) -> bool:
|
||||
"""Flip an active schedule to ``paused`` after repeated failures."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE schedules SET status = 'paused', next_run_at = NULL "
|
||||
"WHERE id = CAST(:id AS uuid) AND status = 'active'"
|
||||
),
|
||||
{"id": str(schedule_id)},
|
||||
)
|
||||
return (result.rowcount or 0) > 0
|
||||
@@ -5,10 +5,10 @@ from __future__ import annotations
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, func, select, text
|
||||
from sqlalchemy import case, Connection, func, select, text
|
||||
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.models import sources_table
|
||||
from application.storage.db.models import ingest_chunk_progress_table, sources_table
|
||||
|
||||
|
||||
_SCALAR_COLUMNS = {
|
||||
@@ -61,6 +61,21 @@ def _coerce_jsonb(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _ingest_status_case():
    """Build the labelled ``ingest_status`` CASE column for source listings.

    Branches are evaluated in order against ``ingest_chunk_progress_table``:

    * ``source_id`` IS NULL (no progress row matched) -> ``None``
    * ``status == "stalled"`` -> ``"failed"``
    * ``embedded_chunks < total_chunks`` -> ``"processing"``
    * anything else (embed finished) -> ``None``
    """
    progress = ingest_chunk_progress_table.c
    no_progress_row = (progress.source_id.is_(None), None)
    stalled = (progress.status == "stalled", "failed")
    still_embedding = (
        progress.embedded_chunks < progress.total_chunks,
        "processing",
    )
    status_expr = case(no_progress_row, stalled, still_embedding, else_=None)
    return status_expr.label("ingest_status")
|
||||
|
||||
|
||||
class SourcesRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
@@ -192,13 +207,25 @@ class SourcesRepository:
|
||||
as ``"desc"``.
|
||||
|
||||
Returns:
|
||||
A list of source rows as plain dicts (via ``row_to_dict``).
|
||||
A list of source rows as plain dicts (via ``row_to_dict``),
|
||||
each carrying a derived ``ingest_status`` (``failed`` /
|
||||
``processing`` / ``None``) from the joined progress row.
|
||||
"""
|
||||
column_name = sort_field if sort_field in _SORTABLE_COLUMNS else "date"
|
||||
sort_column = sources_table.c[column_name]
|
||||
ascending = sort_order.lower() == "asc"
|
||||
|
||||
stmt = select(sources_table).where(sources_table.c.user_id == user_id)
|
||||
stmt = (
|
||||
select(sources_table, _ingest_status_case())
|
||||
.select_from(
|
||||
sources_table.outerjoin(
|
||||
ingest_chunk_progress_table,
|
||||
ingest_chunk_progress_table.c.source_id
|
||||
== sources_table.c.id,
|
||||
)
|
||||
)
|
||||
.where(sources_table.c.user_id == user_id)
|
||||
)
|
||||
if search_term:
|
||||
stmt = stmt.where(
|
||||
sources_table.c.name.ilike(
|
||||
|
||||
@@ -13,6 +13,8 @@ import json
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
|
||||
@@ -52,7 +54,7 @@ class StackLogsRepository:
|
||||
"user_id": user_id,
|
||||
"api_key": api_key,
|
||||
"query": query,
|
||||
"stacks": json.dumps(stacks or []),
|
||||
"stacks": json.dumps(stacks or [], cls=PGNativeJSONEncoder),
|
||||
"timestamp": timestamp,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -31,6 +31,8 @@ class TokenUsageRepository:
|
||||
agent_id: Optional[str] = None,
|
||||
prompt_tokens: int = 0,
|
||||
generated_tokens: int = 0,
|
||||
source: str = "agent_stream",
|
||||
request_id: Optional[str] = None,
|
||||
timestamp: Optional[datetime] = None,
|
||||
) -> None:
|
||||
# Attribution guard: the ``token_usage_attribution_chk`` CHECK
|
||||
@@ -54,12 +56,16 @@ class TokenUsageRepository:
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO token_usage (user_id, api_key, agent_id, prompt_tokens, generated_tokens, timestamp)
|
||||
INSERT INTO token_usage (
|
||||
user_id, api_key, agent_id,
|
||||
prompt_tokens, generated_tokens,
|
||||
source, request_id, timestamp
|
||||
)
|
||||
VALUES (
|
||||
:user_id, :api_key,
|
||||
CAST(:agent_id AS uuid),
|
||||
:prompt_tokens, :generated_tokens,
|
||||
COALESCE(:timestamp, now())
|
||||
:source, :request_id, COALESCE(:timestamp, now())
|
||||
)
|
||||
"""
|
||||
),
|
||||
@@ -69,6 +75,8 @@ class TokenUsageRepository:
|
||||
"agent_id": agent_id_uuid,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"generated_tokens": generated_tokens,
|
||||
"source": source,
|
||||
"request_id": request_id,
|
||||
"timestamp": timestamp,
|
||||
},
|
||||
)
|
||||
@@ -173,8 +181,22 @@ class TokenUsageRepository:
|
||||
user_id: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Count of token_usage rows in the given time range (for request limiting)."""
|
||||
clauses = ["timestamp >= :start", "timestamp <= :end"]
|
||||
"""Count user-initiated requests in the given time range.
|
||||
|
||||
A request = one ``agent_stream`` invocation. Multi-tool agent
|
||||
runs produce multiple rows (one per LLM call) tagged with the
|
||||
same ``request_id``; we DISTINCT on that to count the request
|
||||
once. Pre-migration rows have ``request_id=NULL`` and are
|
||||
counted one-per-row via the second branch (back-compat).
|
||||
Side-channel sources (``title`` / ``compression`` /
|
||||
``rag_condense`` / ``fallback``) are excluded — they aren't
|
||||
user-initiated and shouldn't tick the request limit.
|
||||
"""
|
||||
clauses = [
|
||||
"timestamp >= :start",
|
||||
"timestamp <= :end",
|
||||
"source = 'agent_stream'",
|
||||
]
|
||||
params: dict = {"start": start, "end": end}
|
||||
if user_id is not None:
|
||||
clauses.append("user_id = :user_id")
|
||||
@@ -184,7 +206,15 @@ class TokenUsageRepository:
|
||||
params["api_key"] = api_key
|
||||
where = " AND ".join(clauses)
|
||||
result = self._conn.execute(
|
||||
text(f"SELECT COUNT(*) FROM token_usage WHERE {where}"),
|
||||
text(
|
||||
f"""
|
||||
SELECT
|
||||
COUNT(DISTINCT request_id) FILTER (WHERE request_id IS NOT NULL)
|
||||
+ COUNT(*) FILTER (WHERE request_id IS NULL)
|
||||
FROM token_usage
|
||||
WHERE {where}
|
||||
"""
|
||||
),
|
||||
params,
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
151
application/storage/db/repositories/tool_call_attempts.py
Normal file
151
application/storage/db/repositories/tool_call_attempts.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Repository for ``tool_call_attempts``; executor's proposed/executed/failed writes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
|
||||
class ToolCallAttemptsRepository:
    """Journal writes for ``tool_call_attempts``, keyed by ``call_id``.

    All statements run on the connection supplied at construction; nothing
    here commits.
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def record_proposed(
        self,
        call_id: str,
        tool_name: str,
        action_name: str,
        arguments: Any,
        *,
        tool_id: Optional[str] = None,
    ) -> bool:
        """Write the ``proposed`` journal row ahead of tool execution.

        ``arguments`` is serialised to JSON (``None`` is stored as ``{}``).
        A second insert with the same ``call_id`` is swallowed by
        ``ON CONFLICT DO NOTHING`` instead of raising ``IntegrityError``,
        leaving the first row untouched.

        Returns:
            True when a row was inserted, False when ``call_id`` already
            existed.
        """
        statement = text(
            """
            INSERT INTO tool_call_attempts
                (call_id, tool_id, tool_name, action_name, arguments, status)
            VALUES
                (:call_id, CAST(:tool_id AS uuid), :tool_name,
                 :action_name, CAST(:arguments AS jsonb), 'proposed')
            ON CONFLICT (call_id) DO NOTHING
            """
        )
        arguments_json = json.dumps(
            arguments if arguments is not None else {},
            cls=PGNativeJSONEncoder,
        )
        bind_values = {
            "call_id": call_id,
            "tool_id": tool_id,
            "tool_name": tool_name,
            "action_name": action_name,
            "arguments": arguments_json,
        }
        outcome = self._conn.execute(statement, bind_values)
        return outcome.rowcount > 0

    def upsert_executed(
        self,
        call_id: str,
        tool_name: str,
        action_name: str,
        arguments: Any,
        result: Any,
        *,
        tool_id: Optional[str] = None,
        message_id: Optional[str] = None,
        artifact_id: Optional[str] = None,
    ) -> None:
        """Insert the attempt as finished, or upgrade the existing row.

        Fallback for when ``record_proposed`` could not write (DB outage) but
        the tool ran anyway, so the journal still shows the attempt. The
        stored status is ``executed`` when a ``message_id`` is given and
        ``confirmed`` otherwise, mirroring ``mark_executed``.

        On a ``call_id`` conflict only ``status``, ``result`` and
        ``message_id`` are rewritten; an existing ``message_id`` is kept when
        none is supplied.
        """
        payload: dict = {"result": result}
        if artifact_id:
            payload["artifact_id"] = artifact_id

        if message_id is None:
            final_status = "confirmed"
        else:
            final_status = "executed"

        statement = text(
            """
            INSERT INTO tool_call_attempts
                (call_id, tool_id, tool_name, action_name, arguments,
                 result, message_id, status)
            VALUES
                (:call_id, CAST(:tool_id AS uuid), :tool_name,
                 :action_name, CAST(:arguments AS jsonb),
                 CAST(:result AS jsonb), CAST(:message_id AS uuid),
                 :status)
            ON CONFLICT (call_id) DO UPDATE
                SET status = :status,
                    result = EXCLUDED.result,
                    message_id = COALESCE(EXCLUDED.message_id, tool_call_attempts.message_id)
            """
        )
        arguments_json = json.dumps(
            arguments if arguments is not None else {},
            cls=PGNativeJSONEncoder,
        )
        result_json = json.dumps(payload, cls=PGNativeJSONEncoder)
        self._conn.execute(
            statement,
            {
                "call_id": call_id,
                "tool_id": tool_id,
                "tool_name": tool_name,
                "action_name": action_name,
                "arguments": arguments_json,
                "result": result_json,
                "message_id": message_id,
                "status": final_status,
            },
        )

    def mark_executed(
        self,
        call_id: str,
        result: Any,
        *,
        message_id: Optional[str] = None,
        artifact_id: Optional[str] = None,
    ) -> bool:
        """Record a finished tool call on its existing journal row.

        Status becomes ``executed`` when a ``message_id`` is supplied, or
        goes straight to ``confirmed`` when it is not (a
        ``save_conversation=False`` request reserves no message, so no later
        finalize step would confirm it). ``message_id`` is only written when
        provided. The UPDATE matches on ``call_id`` alone.

        ``artifact_id``, when truthy, is stored next to ``result`` inside the
        JSONB payload as audit data for the reconciler's diagnostics.

        Returns:
            True when a row with that ``call_id`` was updated.
        """
        payload: dict = {"result": result}
        if artifact_id:
            payload["artifact_id"] = artifact_id

        has_message = message_id is not None
        bind_values: dict[str, Any] = {
            "call_id": call_id,
            "status": "executed" if has_message else "confirmed",
            "result": json.dumps(payload, cls=PGNativeJSONEncoder),
        }
        assignments = ["status = :status", "result = CAST(:result AS jsonb)"]
        if has_message:
            assignments.append("message_id = CAST(:message_id AS uuid)")
            bind_values["message_id"] = message_id

        statement = (
            "UPDATE tool_call_attempts SET "
            + ", ".join(assignments)
            + " WHERE call_id = :call_id"
        )
        outcome = self._conn.execute(text(statement), bind_values)
        return outcome.rowcount > 0

    def mark_failed(self, call_id: str, error: str) -> bool:
        """Set the row for ``call_id`` to ``failed`` and store the error text.

        Returns:
            True when a row with that ``call_id`` was updated.
        """
        statement = text(
            "UPDATE tool_call_attempts SET status = 'failed', error = :error "
            "WHERE call_id = :call_id"
        )
        outcome = self._conn.execute(
            statement,
            {"call_id": call_id, "error": error},
        )
        return outcome.rowcount > 0
|
||||
@@ -20,6 +20,7 @@ from typing import Optional
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
|
||||
class UserLogsRepository:
|
||||
@@ -46,7 +47,7 @@ class UserLogsRepository:
|
||||
{
|
||||
"user_id": user_id,
|
||||
"endpoint": endpoint,
|
||||
"data": json.dumps(data, default=str) if data is not None else None,
|
||||
"data": json.dumps(data, cls=PGNativeJSONEncoder) if data is not None else None,
|
||||
"timestamp": timestamp,
|
||||
},
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user