mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-22 13:25:08 +00:00
Compare commits
32 Commits
dependabot
...
feat-notif
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
827a0bb382 | ||
|
|
b04cb44ab5 | ||
|
|
42384a0e92 | ||
|
|
0bce35ad29 | ||
|
|
9de8bb4499 | ||
|
|
cdbd3f061d | ||
|
|
2ac46fd858 | ||
|
|
daa4320da2 | ||
|
|
e70a7a5115 | ||
|
|
150d9f4e37 | ||
|
|
746bcbc5f9 | ||
|
|
aa91117fbf | ||
|
|
abbd56cb66 | ||
|
|
85d8375e6c | ||
|
|
7e98d21b61 | ||
|
|
249f9f9fe0 | ||
|
|
6c4346eb84 | ||
|
|
cb3ca8a36b | ||
|
|
4c8230fb6c | ||
|
|
649557798d | ||
|
|
afe8354ca5 | ||
|
|
5483eb0e27 | ||
|
|
bd2985db47 | ||
|
|
b99147ba83 | ||
|
|
c3023f8b71 | ||
|
|
c168a530f5 | ||
|
|
2d539f3199 | ||
|
|
ed9444cf3d | ||
|
|
4d6f360e3a | ||
|
|
e245057822 | ||
|
|
e692c645b9 | ||
|
|
b4c4ab68f0 |
@@ -1,18 +1,107 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.security.encryption import decrypt_credentials
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _record_proposed(
    call_id: str,
    tool_name: str,
    action_name: str,
    arguments: Any,
    *,
    tool_id: Optional[str] = None,
) -> bool:
    """Insert a ``proposed`` journal row for a tool call before it runs.

    Infrastructure failures are logged and swallowed so tool calls still
    run when the journal is unreachable.

    Args:
        call_id: Identifier of the tool call; used as the journal key.
        tool_name: Name of the tool being invoked.
        action_name: Name of the action being invoked on the tool.
        arguments: Call arguments stored alongside the row.
        tool_id: Optional tool identifier. Forwarded only when it is
            non-empty and passes ``looks_like_uuid``; otherwise ``None``
            is written.

    Returns:
        True iff the row is now journaled (newly created or already
        present); False when the write raised.
    """
    try:
        with db_session() as conn:
            inserted = ToolCallAttemptsRepository(conn).record_proposed(
                call_id,
                tool_name,
                action_name,
                arguments,
                tool_id=tool_id if tool_id and looks_like_uuid(tool_id) else None,
            )
        if not inserted:
            # A falsy return is treated as "row already existed": report the
            # collision but still count the call as journaled.
            logger.warning(
                "tool_call_attempts duplicate call_id=%s; existing row left in place",
                call_id,
                extra={"alert": "tool_call_id_collision", "call_id": call_id},
            )
        return True
    except Exception:
        # Deliberately broad: journaling is best-effort and must never
        # prevent the tool call from running.
        logger.exception("tool_call_attempts proposed write failed for %s", call_id)
        return False
|
||||
|
||||
|
||||
def _mark_executed(
    call_id: str,
    result: Any,
    *,
    message_id: Optional[str] = None,
    artifact_id: Optional[str] = None,
    proposed_ok: bool = True,
    tool_name: Optional[str] = None,
    action_name: Optional[str] = None,
    arguments: Any = None,
    tool_id: Optional[str] = None,
) -> None:
    """Flip the journal row for ``call_id`` to ``executed``.

    If ``proposed_ok`` is False (the proposed write failed earlier), or the
    update reports that no row was changed, a fresh row is upserted in
    ``executed`` so the reconciler can still see the attempt. Without this
    the side effect would be invisible to the journal.

    Args:
        call_id: Identifier of the tool call.
        result: Tool result to store on the row.
        message_id: Optional message identifier to attach to the row.
        artifact_id: Optional artifact identifier to attach to the row.
        proposed_ok: Whether the earlier ``proposed`` write succeeded.
        tool_name: Tool name for the fallback row; ``"unknown"`` if falsy.
        action_name: Action name for the fallback row; ``""`` if falsy.
        arguments: Arguments for the fallback row; ``{}`` if ``None``.
        tool_id: Tool identifier for the fallback row; forwarded only when
            it passes ``looks_like_uuid``.

    Any exception is logged and swallowed; nothing is returned.
    """
    try:
        with db_session() as conn:
            repo = ToolCallAttemptsRepository(conn)
            if proposed_ok:
                updated = repo.mark_executed(
                    call_id,
                    result,
                    message_id=message_id,
                    artifact_id=artifact_id,
                )
                if updated:
                    return
            # Fallback synthesizes the row so the journal isn't lost.
            repo.upsert_executed(
                call_id,
                tool_name=tool_name or "unknown",
                action_name=action_name or "",
                arguments=arguments if arguments is not None else {},
                result=result,
                tool_id=tool_id if tool_id and looks_like_uuid(tool_id) else None,
                message_id=message_id,
                artifact_id=artifact_id,
            )
    except Exception:
        # Best-effort: a journal failure must not surface to the caller.
        logger.exception("tool_call_attempts executed write failed for %s", call_id)
|
||||
|
||||
|
||||
def _mark_failed(call_id: str, error: str) -> None:
    """Mark the journal row for ``call_id`` as failed with ``error``.

    Best-effort: any exception raised while opening the session or writing
    the row is logged and swallowed.

    Args:
        call_id: Identifier of the tool call.
        error: Error text to record on the row.
    """
    try:
        with db_session() as conn:
            ToolCallAttemptsRepository(conn).mark_failed(call_id, error)
    except Exception:
        logger.exception("tool_call_attempts failed-write failed for %s", call_id)
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""Handles tool discovery, preparation, and execution.
|
||||
|
||||
@@ -31,6 +120,7 @@ class ToolExecutor:
|
||||
self.tool_calls: List[Dict] = []
|
||||
self._loaded_tools: Dict[str, object] = {}
|
||||
self.conversation_id: Optional[str] = None
|
||||
self.message_id: Optional[str] = None
|
||||
self.client_tools: Optional[List[Dict]] = None
|
||||
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
|
||||
self._tool_to_name: Dict[Tuple[str, str], str] = {}
|
||||
@@ -323,9 +413,36 @@ class ToolExecutor:
|
||||
"action_name": llm_name,
|
||||
"arguments": call_args,
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
# Journal first so the reconciler sees malformed calls and any
|
||||
# subsequent ``_mark_failed`` actually updates a real row.
|
||||
proposed_ok = _record_proposed(
|
||||
call_id,
|
||||
tool_data["name"],
|
||||
action_name,
|
||||
call_args if isinstance(call_args, dict) else {},
|
||||
tool_id=tool_data.get("id"),
|
||||
)
|
||||
# Defensive guard: a non-dict ``call_args`` (e.g. malformed
|
||||
# JSON on the resume path) would crash the param walk below
|
||||
# with AttributeError on ``.items()``. Surface a clean error
|
||||
# event and flip the journal row to ``failed`` instead of
|
||||
# killing the stream.
|
||||
if not isinstance(call_args, dict):
|
||||
error_message = (
|
||||
f"Tool call arguments must be a JSON object, got "
|
||||
f"{type(call_args).__name__}."
|
||||
)
|
||||
tool_call_data["result"] = error_message
|
||||
tool_call_data["arguments"] = {}
|
||||
_mark_failed(call_id, error_message)
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {**tool_call_data, "status": "error"},
|
||||
}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return error_message, call_id
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||
action_data = (
|
||||
tool_data["config"]["actions"][action_name]
|
||||
if tool_data["name"] == "api_tool"
|
||||
@@ -381,6 +498,7 @@ class ToolExecutor:
|
||||
},
|
||||
)
|
||||
tool_call_data["result"] = error_message
|
||||
_mark_failed(call_id, error_message)
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return error_message, call_id
|
||||
@@ -390,14 +508,18 @@ class ToolExecutor:
|
||||
if tool_data["name"] == "api_tool"
|
||||
else parameters
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
logger.debug(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
)
|
||||
result = tool.execute_action(action_name, **body)
|
||||
else:
|
||||
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
|
||||
result = tool.execute_action(action_name, **parameters)
|
||||
try:
|
||||
if tool_data["name"] == "api_tool":
|
||||
logger.debug(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
)
|
||||
result = tool.execute_action(action_name, **body)
|
||||
else:
|
||||
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
|
||||
result = tool.execute_action(action_name, **parameters)
|
||||
except Exception as exc:
|
||||
_mark_failed(call_id, str(exc))
|
||||
raise
|
||||
|
||||
get_artifact_id = (
|
||||
getattr(tool, "get_artifact_id", None)
|
||||
@@ -426,6 +548,22 @@ class ToolExecutor:
|
||||
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
|
||||
)
|
||||
|
||||
# Tool side effect has run; flip the journal row so the
|
||||
# message-finalize path can later confirm it. If the proposed
|
||||
# write failed (DB outage), upsert a fresh row in ``executed`` so
|
||||
# the reconciler still sees the side effect.
|
||||
_mark_executed(
|
||||
call_id,
|
||||
result_full,
|
||||
message_id=self.message_id,
|
||||
artifact_id=artifact_id or None,
|
||||
proposed_ok=proposed_ok,
|
||||
tool_name=tool_data["name"],
|
||||
action_name=action_name,
|
||||
arguments=call_args,
|
||||
tool_id=tool_data.get("id"),
|
||||
)
|
||||
|
||||
stream_tool_call_data = {
|
||||
key: value
|
||||
for key, value in tool_call_data.items()
|
||||
|
||||
@@ -20,10 +20,11 @@ from pydantic import AnyHttpUrl, ValidationError
|
||||
from redis import Redis
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
|
||||
from application.api.user.tasks import mcp_oauth_task
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.core.url_validation import SSRFError, validate_url
|
||||
from application.events.keys import stream_key
|
||||
from application.security.encryption import decrypt_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -76,6 +77,12 @@ class MCPTool(Tool):
|
||||
self.oauth_task_id = config.get("oauth_task_id", None)
|
||||
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
|
||||
self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))
|
||||
# Pulled out of ``config`` (rather than left in ``self.config``)
|
||||
# because it is a callable supplied by the OAuth worker — not
|
||||
# something the rest of the tool plumbing should marshal or
|
||||
# serialize. ``DocsGPTOAuth`` invokes it from ``redirect_handler``
|
||||
# so the SSE envelope can carry ``authorization_url``.
|
||||
self.oauth_redirect_publish = config.pop("oauth_redirect_publish", None)
|
||||
|
||||
self.available_tools = []
|
||||
self._cache_key = self._generate_cache_key()
|
||||
@@ -167,6 +174,7 @@ class MCPTool(Tool):
|
||||
redirect_uri=self.redirect_uri,
|
||||
task_id=self.oauth_task_id,
|
||||
user_id=self.user_id,
|
||||
redirect_publish=self.oauth_redirect_publish,
|
||||
)
|
||||
elif self.auth_type == "bearer":
|
||||
token = self.auth_credentials.get(
|
||||
@@ -679,12 +687,17 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
user_id=None,
|
||||
additional_client_metadata: dict[str, Any] | None = None,
|
||||
skip_redirect_validation: bool = False,
|
||||
redirect_publish=None,
|
||||
):
|
||||
self.redirect_uri = redirect_uri
|
||||
self.redis_client = redis_client
|
||||
self.redis_prefix = redis_prefix
|
||||
self.task_id = task_id
|
||||
self.user_id = user_id
|
||||
# Worker-supplied callback. Invoked from ``redirect_handler``
|
||||
# once the authorization URL is known so the SSE envelope can
|
||||
# carry it. ``None`` for any non-worker entrypoint.
|
||||
self.redirect_publish = redirect_publish
|
||||
|
||||
parsed_url = urlparse(mcp_url)
|
||||
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
@@ -744,17 +757,19 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
self.redis_client.setex(key, 600, auth_url)
|
||||
logger.info("Stored auth_url in Redis: %s", key)
|
||||
|
||||
if self.task_id:
|
||||
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||
status_data = {
|
||||
"status": "requires_redirect",
|
||||
"message": "Authorization required",
|
||||
"authorization_url": self.auth_url,
|
||||
"state": self.extracted_state,
|
||||
"requires_oauth": True,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
if self.redirect_publish is not None:
|
||||
# Best-effort: a publish failure must not abort the OAuth
|
||||
# handshake — the user can still authorize via the popup
|
||||
# opened from the legacy polling fallback if the SSE
|
||||
# envelope is lost.
|
||||
try:
|
||||
self.redirect_publish(auth_url)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"redirect_publish callback raised for task_id=%s",
|
||||
self.task_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def callback_handler(self) -> tuple[str, str | None]:
|
||||
"""Wait for auth code from Redis using the state value."""
|
||||
@@ -764,17 +779,6 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
max_wait_time = 300
|
||||
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
|
||||
|
||||
if self.task_id:
|
||||
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||
status_data = {
|
||||
"status": "awaiting_callback",
|
||||
"message": "Waiting for authorization...",
|
||||
"authorization_url": self.auth_url,
|
||||
"state": self.extracted_state,
|
||||
"requires_oauth": True,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
code_data = self.redis_client.get(code_key)
|
||||
@@ -789,14 +793,6 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}state:{self.extracted_state}"
|
||||
)
|
||||
|
||||
if self.task_id:
|
||||
status_data = {
|
||||
"status": "callback_received",
|
||||
"message": "Completing authentication...",
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
return code, returned_state
|
||||
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
|
||||
error_data = self.redis_client.get(error_key)
|
||||
@@ -1038,8 +1034,73 @@ class MCPOAuthManager:
|
||||
logger.error("Error handling OAuth callback: %s", e)
|
||||
return False
|
||||
|
||||
def get_oauth_status(self, task_id: str, user_id: str) -> Dict[str, Any]:
    """Return the latest OAuth status for ``task_id`` from the user's SSE journal.

    Mirrors the legacy polling contract: ``status`` derived from the
    ``mcp.oauth.*`` event-type suffix, with payload fields surfaced
    (e.g. ``tools``/``tools_count`` on ``completed``).

    Args:
        task_id: OAuth task identifier, matched against the envelope's
            ``scope.id``.
        user_id: Owner of the event stream to scan.

    Returns:
        A dict with at least ``status``. ``not_started`` when ``task_id``
        is empty; ``not_found`` (with a ``message``) when the user is
        missing, Redis is unavailable, the read fails, or no matching
        envelope exists; otherwise the event-type suffix plus ``task_id``
        and the envelope payload fields.
    """
    if not task_id:
        return {"status": "not_started", "message": "OAuth flow not started"}
    if not user_id:
        return {"status": "not_found", "message": "User not provided"}
    if self.redis_client is None:
        return {"status": "not_found", "message": "Redis unavailable"}

    try:
        # OAuth flows are short-lived but a concurrent source
        # ingest can flood the user channel between the OAuth
        # popup completing and the user clicking Save, pushing the
        # completion envelope outside the read window. Bound the
        # scan by the configured stream cap so we cover the full
        # journal — XADD MAXLEN keeps that bounded too.
        scan_count = max(settings.EVENTS_STREAM_MAXLEN, 200)
        entries = self.redis_client.xrevrange(
            stream_key(user_id), count=scan_count
        )
    except Exception:
        logger.exception(
            "xrevrange failed for oauth status: user_id=%s task_id=%s",
            user_id,
            task_id,
        )
        return {"status": "not_found", "message": "Status unavailable"}

    # Entries come back from XREVRANGE newest-first, so the first envelope
    # that matches is the latest status for this task.
    for _entry_id, fields in entries:
        if not isinstance(fields, dict):
            continue
        # decode_responses=False ⇒ bytes keys; the string-key fallback
        # covers a future flip of that default without a forced refactor.
        event_raw = fields.get(b"event")
        if event_raw is None:
            event_raw = fields.get("event")
        if event_raw is None:
            continue
        if isinstance(event_raw, bytes):
            try:
                event_raw = event_raw.decode("utf-8")
            except Exception:
                continue
        try:
            envelope = json.loads(event_raw)
        except Exception:
            continue
        if not isinstance(envelope, dict):
            continue
        event_type = envelope.get("type", "")
        if not isinstance(event_type, str) or not event_type.startswith(
            "mcp.oauth."
        ):
            continue
        scope = envelope.get("scope") or {}
        # A malformed (non-object) scope is skipped like every other
        # malformed envelope instead of raising AttributeError on ``.get``.
        if not isinstance(scope, dict):
            continue
        if scope.get("kind") != "mcp_oauth" or scope.get("id") != task_id:
            continue
        payload = envelope.get("payload") or {}
        # ``**payload`` needs a mapping; a non-object payload contributes
        # no extra fields rather than raising TypeError.
        if not isinstance(payload, dict):
            payload = {}
        # NOTE(review): payload keys are spread last, so a payload that
        # carries ``status`` or ``task_id`` overrides the derived values —
        # confirm that is intended.
        return {
            "status": event_type[len("mcp.oauth."):],
            "task_id": task_id,
            **payload,
        }

    return {"status": "not_found", "message": "Status not found"}
|
||||
|
||||
@@ -177,3 +177,4 @@ class PostgresTool(Tool):
|
||||
"order": 1,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -57,6 +57,29 @@ class ToolActionParser:
|
||||
def _parse_google_llm(self, call):
|
||||
try:
|
||||
call_args = call.arguments
|
||||
# Gemini's SDK natively returns ``args`` as a dict, but the
|
||||
# resume path (``gen_continuation``) stringifies it for the
|
||||
# assistant message. Coerce a JSON string back into a dict;
|
||||
# fall back to an empty dict on malformed input so downstream
|
||||
# ``call_args.items()`` doesn't crash the stream.
|
||||
if isinstance(call_args, str):
|
||||
try:
|
||||
call_args = json.loads(call_args)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(
|
||||
"Google call.arguments was not valid JSON; "
|
||||
"falling back to empty args for %s",
|
||||
getattr(call, "name", "<unknown>"),
|
||||
)
|
||||
call_args = {}
|
||||
if not isinstance(call_args, dict):
|
||||
logger.warning(
|
||||
"Google call.arguments has unexpected type %s; "
|
||||
"falling back to empty args for %s",
|
||||
type(call_args).__name__,
|
||||
getattr(call, "name", "<unknown>"),
|
||||
)
|
||||
call_args = {}
|
||||
|
||||
resolved = self._resolve_via_mapping(call.name)
|
||||
if resolved:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""0001 initial schema — consolidated Phase-1..3 baseline.
|
||||
"""0001 initial schema — consolidated baseline for user-data tables.
|
||||
|
||||
Revision ID: 0001_initial
|
||||
Revises:
|
||||
|
||||
217
application/alembic/versions/0004_durability_foundation.py
Normal file
217
application/alembic/versions/0004_durability_foundation.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""0004 durability foundation — idempotency, tool-call log, ingest checkpoint.
|
||||
|
||||
Adds ``task_dedup``, ``webhook_dedup``, ``tool_call_attempts``,
|
||||
``ingest_chunk_progress``, and per-row status flags on
|
||||
``conversation_messages`` and ``pending_tool_state``. Also adds
|
||||
``token_usage.source`` and ``token_usage.request_id`` so per-channel
|
||||
cost attribution (``agent_stream`` / ``title`` / ``compression`` /
|
||||
``rag_condense`` / ``fallback``) is queryable and multi-call agent runs
|
||||
can be DISTINCT-collapsed into a single user request for rate limiting.
|
||||
|
||||
Revision ID: 0004_durability_foundation
|
||||
Revises: 0003_user_custom_models
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0004_durability_foundation"
|
||||
down_revision: Union[str, None] = "0003_user_custom_models"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Apply revision 0004.

    Creates ``task_dedup``, ``webhook_dedup``, ``tool_call_attempts`` and
    ``ingest_chunk_progress``; adds status/request columns to
    ``conversation_messages``, ``pending_tool_state`` and ``token_usage``;
    then creates the supporting indexes and the ``updated_at`` trigger on
    ``tool_call_attempts``.
    """
    # ------------------------------------------------------------------
    # New tables
    # ------------------------------------------------------------------
    # ``attempt_count`` bounds the per-Celery-task idempotency wrapper's
    # retry loop so a poison message can't run forever; default 0 means
    # existing rows behave as if no attempts have run yet.
    op.execute(
        """
        CREATE TABLE task_dedup (
            idempotency_key TEXT PRIMARY KEY,
            task_name TEXT NOT NULL,
            task_id TEXT NOT NULL,
            result_json JSONB,
            status TEXT NOT NULL
                CHECK (status IN ('pending', 'completed', 'failed')),
            attempt_count INT NOT NULL DEFAULT 0,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """
    )

    op.execute(
        """
        CREATE TABLE webhook_dedup (
            idempotency_key TEXT PRIMARY KEY,
            agent_id UUID NOT NULL,
            task_id TEXT NOT NULL,
            response_json JSONB,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """
    )

    # FK on ``message_id`` uses ``ON DELETE SET NULL`` so the journal row
    # survives parent-message deletion (compliance / cost-attribution).
    op.execute(
        """
        CREATE TABLE tool_call_attempts (
            call_id TEXT PRIMARY KEY,
            message_id UUID
                REFERENCES conversation_messages (id)
                ON DELETE SET NULL,
            tool_id UUID,
            tool_name TEXT NOT NULL,
            action_name TEXT NOT NULL,
            arguments JSONB NOT NULL,
            result JSONB,
            error TEXT,
            status TEXT NOT NULL
                CHECK (status IN (
                    'proposed', 'executed', 'confirmed',
                    'compensated', 'failed'
                )),
            attempted_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """
    )

    op.execute(
        """
        CREATE TABLE ingest_chunk_progress (
            source_id UUID PRIMARY KEY,
            total_chunks INT NOT NULL,
            embedded_chunks INT NOT NULL DEFAULT 0,
            last_index INT NOT NULL DEFAULT -1,
            last_updated TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """
    )

    # ------------------------------------------------------------------
    # Column additions on existing tables
    # ------------------------------------------------------------------
    # DEFAULT 'complete' backfills existing rows — they're already done.
    op.execute(
        """
        ALTER TABLE conversation_messages
            ADD COLUMN status TEXT NOT NULL DEFAULT 'complete'
                CHECK (status IN ('pending', 'streaming', 'complete', 'failed')),
            ADD COLUMN request_id TEXT;
        """
    )

    op.execute(
        """
        ALTER TABLE pending_tool_state
            ADD COLUMN status TEXT NOT NULL DEFAULT 'pending'
                CHECK (status IN ('pending', 'resuming')),
            ADD COLUMN resumed_at TIMESTAMPTZ;
        """
    )

    # Default ``agent_stream`` backfills historical rows under the
    # assumption they were written from the primary path — pre-fix the
    # only path that wrote was the error branch reading agent.llm.
    # ``request_id`` is the stream-scoped UUID stamped by the route on
    # ``agent.llm`` so multi-tool agent runs (which produce N rows)
    # collapse to one request via DISTINCT in ``count_in_range``.
    # Side-channel sources (``title`` / ``compression`` / ``rag_condense``
    # / ``fallback``) leave it NULL and are excluded from the request
    # count by source filter.
    op.execute(
        """
        ALTER TABLE token_usage
            ADD COLUMN source TEXT NOT NULL DEFAULT 'agent_stream',
            ADD COLUMN request_id TEXT;
        """
    )

    # ------------------------------------------------------------------
    # Indexes — partial where the predicate selects only non-terminal rows
    # ------------------------------------------------------------------
    op.execute(
        "CREATE INDEX conversation_messages_pending_ts_idx "
        "ON conversation_messages (timestamp) "
        "WHERE status IN ('pending', 'streaming');"
    )
    op.execute(
        "CREATE INDEX tool_call_attempts_pending_ts_idx "
        "ON tool_call_attempts (attempted_at) "
        "WHERE status IN ('proposed', 'executed');"
    )
    op.execute(
        "CREATE INDEX tool_call_attempts_message_idx "
        "ON tool_call_attempts (message_id) "
        "WHERE message_id IS NOT NULL;"
    )
    op.execute(
        "CREATE INDEX pending_tool_state_resuming_ts_idx "
        "ON pending_tool_state (resumed_at) "
        "WHERE status = 'resuming';"
    )
    op.execute(
        "CREATE INDEX webhook_dedup_agent_idx "
        "ON webhook_dedup (agent_id);"
    )
    op.execute(
        "CREATE INDEX task_dedup_pending_attempts_idx "
        "ON task_dedup (attempt_count) WHERE status = 'pending';"
    )
    # Cost-attribution dashboards filter ``token_usage`` by
    # ``(timestamp, source)``; index the same shape so they stay cheap.
    op.execute(
        "CREATE INDEX token_usage_source_ts_idx "
        "ON token_usage (source, timestamp);"
    )
    # Partial index — only rows with a stamped request_id participate
    # in the DISTINCT count. NULL rows fall through to the COUNT(*)
    # branch in the repository query.
    op.execute(
        "CREATE INDEX token_usage_request_id_idx "
        "ON token_usage (request_id) "
        "WHERE request_id IS NOT NULL;"
    )

    # NOTE(review): relies on a ``set_updated_at()`` function that is not
    # created in this revision — presumably defined by an earlier
    # migration; confirm.
    op.execute(
        "CREATE TRIGGER tool_call_attempts_set_updated_at "
        "BEFORE UPDATE ON tool_call_attempts "
        "FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
        "EXECUTE FUNCTION set_updated_at();"
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert revision 0004.

    Drops the four tables created by ``upgrade`` and removes the columns
    added to ``conversation_messages``, ``pending_tool_state`` and
    ``token_usage`` (dropping the two ``token_usage`` indexes first).
    Every statement uses ``IF EXISTS``.
    """
    # CASCADE so the downgrade stays safe if later migrations FK into these.
    for table in (
        "ingest_chunk_progress",
        "tool_call_attempts",
        "webhook_dedup",
        "task_dedup",
    ):
        op.execute(f"DROP TABLE IF EXISTS {table} CASCADE;")

    op.execute(
        "ALTER TABLE conversation_messages "
        "DROP COLUMN IF EXISTS request_id, "
        "DROP COLUMN IF EXISTS status;"
    )
    op.execute(
        "ALTER TABLE pending_tool_state "
        "DROP COLUMN IF EXISTS resumed_at, "
        "DROP COLUMN IF EXISTS status;"
    )
    op.execute("DROP INDEX IF EXISTS token_usage_request_id_idx;")
    op.execute("DROP INDEX IF EXISTS token_usage_source_ts_idx;")
    op.execute(
        "ALTER TABLE token_usage "
        "DROP COLUMN IF EXISTS request_id, "
        "DROP COLUMN IF EXISTS source;"
    )
|
||||
44
application/alembic/versions/0005_ingest_attempt_id.py
Normal file
44
application/alembic/versions/0005_ingest_attempt_id.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""0005 ingest_chunk_progress.attempt_id — per-attempt resume scoping.
|
||||
|
||||
Without this column, a completed checkpoint row poisoned every later
|
||||
embed call on the same ``source_id``: a sync after an upload finished
|
||||
read the upload's terminal ``last_index`` and either embedded zero
|
||||
chunks (if new ``total_docs <= last_index + 1``) or stacked new chunks
|
||||
on top of the old vectors (if ``total_docs > last_index + 1``).
|
||||
|
||||
``attempt_id`` is stamped from ``self.request.id`` (Celery's stable
|
||||
task id, which survives ``acks_late`` retries of the same task but
|
||||
differs across separate task invocations). The repository's
|
||||
``init_progress`` upsert resets ``last_index`` / ``embedded_chunks``
|
||||
when the incoming ``attempt_id`` differs from the stored one — so a
|
||||
fresh sync starts from chunk 0 while a retry of the same task resumes
|
||||
from the last checkpointed chunk.
|
||||
|
||||
Revision ID: 0005_ingest_attempt_id
|
||||
Revises: 0004_durability_foundation
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0005_ingest_attempt_id"
|
||||
down_revision: Union[str, None] = "0004_durability_foundation"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable ``attempt_id`` text column to ``ingest_chunk_progress``."""
    op.execute(
        """
        ALTER TABLE ingest_chunk_progress
            ADD COLUMN attempt_id TEXT;
        """
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop ``ingest_chunk_progress.attempt_id`` if it exists."""
    op.execute(
        "ALTER TABLE ingest_chunk_progress DROP COLUMN IF EXISTS attempt_id;"
    )
|
||||
57
application/alembic/versions/0006_idempotency_lease.py
Normal file
57
application/alembic/versions/0006_idempotency_lease.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""0006 task_dedup lease columns — running-lease for in-flight tasks.
|
||||
|
||||
Without these, ``with_idempotency`` only short-circuits *completed*
|
||||
rows. A late-ack redelivery (Redis ``visibility_timeout`` exceeded by a
|
||||
long ingest, or a hung-but-alive worker) hands the same message to a
|
||||
second worker; ``_claim_or_bump`` only bumped the attempt counter and
|
||||
both workers ran the task body in parallel — duplicate vector writes,
|
||||
duplicate token spend, duplicate webhook side effects.
|
||||
|
||||
``lease_owner_id`` + ``lease_expires_at`` turn that into an atomic
|
||||
compare-and-swap. The wrapper claims a lease at entry, refreshes it via
|
||||
a 30 s heartbeat thread, and finalises (which makes the lease moot via
|
||||
``status='completed'``). A second worker hitting the same key sees a
|
||||
fresh lease and ``self.retry(countdown=LEASE_TTL)``s instead of running.
|
||||
A crashed worker's lease expires after ``LEASE_TTL`` seconds and the
|
||||
next retry can claim it.
|
||||
|
||||
Revision ID: 0006_idempotency_lease
|
||||
Revises: 0005_ingest_attempt_id
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0006_idempotency_lease"
|
||||
down_revision: Union[str, None] = "0005_ingest_attempt_id"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the lease columns to ``task_dedup`` and index pending leases.

    Adds nullable ``lease_owner_id`` and ``lease_expires_at`` and creates a
    partial index on ``lease_expires_at`` restricted to ``pending`` rows.
    """
    op.execute(
        """
        ALTER TABLE task_dedup
            ADD COLUMN lease_owner_id TEXT,
            ADD COLUMN lease_expires_at TIMESTAMPTZ;
        """
    )
    # Reconciler's stuck-pending sweep filters by
    # ``(status='pending', lease_expires_at < now() - 60s, attempt_count >= 5)``.
    # Partial index keeps the scan small even under heavy task throughput.
    op.execute(
        "CREATE INDEX task_dedup_pending_lease_idx "
        "ON task_dedup (lease_expires_at) "
        "WHERE status = 'pending';"
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the pending-lease index, then the two lease columns."""
    op.execute("DROP INDEX IF EXISTS task_dedup_pending_lease_idx;")
    op.execute(
        "ALTER TABLE task_dedup "
        "DROP COLUMN IF EXISTS lease_expires_at, "
        "DROP COLUMN IF EXISTS lease_owner_id;"
    )
|
||||
40
application/alembic/versions/0007_message_events.py
Normal file
40
application/alembic/versions/0007_message_events.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""0007 message_events — durable journal of chat-stream events.
|
||||
|
||||
Snapshot half of the chat-stream snapshot+tail pattern. Composite PK
|
||||
``(message_id, sequence_no)``, ``created_at`` indexed for retention
|
||||
sweeps, ``ON DELETE CASCADE`` from ``conversation_messages``.
|
||||
|
||||
Revision ID: 0007_message_events
|
||||
Revises: 0006_idempotency_lease
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0007_message_events"
|
||||
down_revision: Union[str, None] = "0006_idempotency_lease"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``message_events`` journal table and its ``created_at`` index.

    The table is keyed by ``(message_id, sequence_no)`` and its ``message_id``
    references ``conversation_messages(id)`` with ``ON DELETE CASCADE``. Both
    the table and the index are created in a single ``op.execute`` call.
    """
    create_message_events = """
        CREATE TABLE message_events (
            message_id UUID NOT NULL REFERENCES conversation_messages(id) ON DELETE CASCADE,
            sequence_no INTEGER NOT NULL,
            event_type TEXT NOT NULL,
            payload JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            PRIMARY KEY (message_id, sequence_no)
        );
        CREATE INDEX message_events_created_at_idx ON message_events(created_at);
        """
    op.execute(create_message_events)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``created_at`` index and then the ``message_events`` table."""
    teardown_statements = (
        "DROP INDEX IF EXISTS message_events_created_at_idx;",
        "DROP TABLE IF EXISTS message_events;",
    )
    for statement in teardown_statements:
        op.execute(statement)
|
||||
@@ -102,6 +102,8 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
"reserved_message_id": processor.reserved_message_id,
|
||||
"request_id": processor.request_id,
|
||||
},
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
from flask import jsonify, make_response, Response
|
||||
from flask_restx import Namespace
|
||||
|
||||
from application.api.answer.services.continuation_service import ContinuationService
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
TERMINATED_RESPONSE_PLACEHOLDER,
|
||||
)
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
@@ -18,9 +23,16 @@ from application.core.settings import settings
|
||||
from application.error import sanitize_api_error
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.conversations import MessageUpdateOutcome
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
from application.storage.db.repositories.user_logs import UserLogsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.events.publisher import publish_user_event
|
||||
from application.streaming.event_replay import format_sse_event
|
||||
from application.streaming.message_journal import (
|
||||
BatchedJournalWriter,
|
||||
record_event,
|
||||
)
|
||||
from application.utils import check_required_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -203,13 +215,188 @@ class BaseAnswerResource:
|
||||
Yields:
|
||||
Server-sent event strings
|
||||
"""
|
||||
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
query_metadata: Dict[str, Any] = {}
|
||||
paused = False
|
||||
|
||||
# One id shared across the WAL row, primary LLM (token_usage
|
||||
# attribution), the SSE event, and resumed continuations.
|
||||
request_id = (
|
||||
_continuation.get("request_id") if _continuation else None
|
||||
) or str(uuid.uuid4())
|
||||
|
||||
# Reserve the placeholder row before the LLM call so a crash
|
||||
# mid-stream still leaves the question queryable. Continuations
|
||||
# reuse the original placeholder.
|
||||
reserved_message_id: Optional[str] = None
|
||||
wal_eligible = should_save_conversation and not _continuation
|
||||
if wal_eligible:
|
||||
try:
|
||||
reservation = self.conversation_service.save_user_question(
|
||||
conversation_id=conversation_id,
|
||||
question=question,
|
||||
decoded_token=decoded_token,
|
||||
attachment_ids=attachment_ids,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
model_id=model_id or self.default_model_id,
|
||||
request_id=request_id,
|
||||
index=index,
|
||||
)
|
||||
conversation_id = reservation["conversation_id"]
|
||||
reserved_message_id = reservation["message_id"]
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to reserve message row before stream: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
elif _continuation and _continuation.get("reserved_message_id"):
|
||||
reserved_message_id = _continuation["reserved_message_id"]
|
||||
|
||||
primary_llm = getattr(agent, "llm", None)
|
||||
if primary_llm is not None:
|
||||
primary_llm._request_id = request_id
|
||||
|
||||
# Flipped to ``streaming`` on first chunk; reconciler uses this
|
||||
# to tell "never started" from "in flight".
|
||||
streaming_marked = False
|
||||
# Heartbeat goes into ``metadata.last_heartbeat_at`` (not
|
||||
# ``updated_at``, which reconciler-side writes share) and uses
|
||||
# ``time.monotonic`` so a blocked event loop can't fake fresh.
|
||||
STREAM_HEARTBEAT_INTERVAL = 60
|
||||
last_heartbeat_at = time.monotonic()
|
||||
|
||||
def _mark_streaming_once() -> None:
|
||||
nonlocal streaming_marked, last_heartbeat_at
|
||||
if streaming_marked or not reserved_message_id:
|
||||
return
|
||||
try:
|
||||
self.conversation_service.update_message_status(
|
||||
reserved_message_id, "streaming",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"update_message_status streaming failed for %s",
|
||||
reserved_message_id,
|
||||
)
|
||||
# Seed last_heartbeat_at so watchdog doesn't fall back to `timestamp`
|
||||
# (creation time) before the first STREAM_HEARTBEAT_INTERVAL tick.
|
||||
try:
|
||||
self.conversation_service.heartbeat_message(
|
||||
reserved_message_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"initial heartbeat seed failed for %s",
|
||||
reserved_message_id,
|
||||
)
|
||||
streaming_marked = True
|
||||
last_heartbeat_at = time.monotonic()
|
||||
|
||||
def _heartbeat_streaming() -> None:
|
||||
nonlocal last_heartbeat_at
|
||||
if not reserved_message_id or not streaming_marked:
|
||||
return
|
||||
now_mono = time.monotonic()
|
||||
if now_mono - last_heartbeat_at < STREAM_HEARTBEAT_INTERVAL:
|
||||
return
|
||||
try:
|
||||
self.conversation_service.heartbeat_message(
|
||||
reserved_message_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"stream heartbeat update failed for %s",
|
||||
reserved_message_id,
|
||||
)
|
||||
last_heartbeat_at = now_mono
|
||||
|
||||
# Correlates tool_call_attempts rows with this message.
|
||||
if reserved_message_id and getattr(agent, "tool_executor", None):
|
||||
try:
|
||||
agent.tool_executor.message_id = reserved_message_id
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Could not set tool_executor.message_id; tool-call correlation will be missing for message_id=%s",
|
||||
reserved_message_id,
|
||||
)
|
||||
|
||||
# Per-stream monotonic SSE event id. Allocated by ``_emit`` and
|
||||
# threaded through both the wire format (``id: <seq>\\n``) and
|
||||
# the journal write so a reconnecting client can ``Last-Event-
|
||||
# ID`` past anything they already saw. Continuations resume
|
||||
# against the original ``reserved_message_id`` — seed the
|
||||
# allocator from the journal's high-water mark so we don't
|
||||
# collide on the duplicate-PK and silently lose every emit
|
||||
# past the resume point.
|
||||
sequence_no = -1
|
||||
if _continuation and reserved_message_id:
|
||||
try:
|
||||
from application.storage.db.repositories.message_events import (
|
||||
MessageEventsRepository,
|
||||
)
|
||||
|
||||
with db_readonly() as conn:
|
||||
latest = MessageEventsRepository(conn).latest_sequence_no(
|
||||
reserved_message_id
|
||||
)
|
||||
if latest is not None:
|
||||
sequence_no = latest
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Continuation seq seed lookup failed for message_id=%s; "
|
||||
"falling back to seq=-1 (duplicate-PK collisions will "
|
||||
"be swallowed)",
|
||||
reserved_message_id,
|
||||
)
|
||||
|
||||
# One batched journal writer per stream.
|
||||
journal_writer: Optional[BatchedJournalWriter] = (
|
||||
BatchedJournalWriter(reserved_message_id)
|
||||
if reserved_message_id
|
||||
else None
|
||||
)
|
||||
|
||||
def _emit(payload: dict) -> str:
|
||||
"""Format-and-journal one SSE event.
|
||||
|
||||
With a reserved ``message_id``, buffers into the journal and
|
||||
emits ``id: <seq>``-tagged SSE frames; otherwise falls back to
|
||||
legacy ``data: ...\\n\\n`` framing.
|
||||
"""
|
||||
nonlocal sequence_no
|
||||
if not reserved_message_id or journal_writer is None:
|
||||
return f"data: {json.dumps(payload)}\n\n"
|
||||
sequence_no += 1
|
||||
seq = sequence_no
|
||||
event_type = (
|
||||
payload.get("type", "data")
|
||||
if isinstance(payload, dict)
|
||||
else "data"
|
||||
)
|
||||
normalised = payload if isinstance(payload, dict) else {"value": payload}
|
||||
journal_writer.record(seq, event_type, normalised)
|
||||
return format_sse_event(normalised, seq)
|
||||
|
||||
try:
|
||||
response_full, thought, source_log_docs, tool_calls = "", "", [], []
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
query_metadata = {}
|
||||
paused = False
|
||||
# Surface the placeholder id before any LLM tokens so a
|
||||
# mid-handshake disconnect still has a row to tail-poll.
|
||||
if reserved_message_id:
|
||||
yield _emit(
|
||||
{
|
||||
"type": "message_id",
|
||||
"message_id": reserved_message_id,
|
||||
"conversation_id": (
|
||||
str(conversation_id) if conversation_id else None
|
||||
),
|
||||
"request_id": request_id,
|
||||
}
|
||||
)
|
||||
|
||||
if _continuation:
|
||||
gen_iter = agent.gen_continuation(
|
||||
@@ -222,18 +409,24 @@ class BaseAnswerResource:
|
||||
gen_iter = agent.gen(query=question)
|
||||
|
||||
for line in gen_iter:
|
||||
# Cheap closure check that only hits the DB when the
|
||||
# heartbeat interval has elapsed.
|
||||
_heartbeat_streaming()
|
||||
if "metadata" in line:
|
||||
query_metadata.update(line["metadata"])
|
||||
elif "answer" in line:
|
||||
_mark_streaming_once()
|
||||
response_full += str(line["answer"])
|
||||
if line.get("structured"):
|
||||
is_structured = True
|
||||
schema_info = line.get("schema")
|
||||
structured_chunks.append(line["answer"])
|
||||
else:
|
||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(
|
||||
{"type": "answer", "answer": line["answer"]}
|
||||
)
|
||||
elif "sources" in line:
|
||||
_mark_streaming_once()
|
||||
truncated_sources = []
|
||||
source_log_docs = line["sources"]
|
||||
for source in line["sources"]:
|
||||
@@ -244,54 +437,48 @@ class BaseAnswerResource:
|
||||
)
|
||||
truncated_sources.append(truncated_source)
|
||||
if truncated_sources:
|
||||
data = json.dumps(
|
||||
yield _emit(
|
||||
{"type": "source", "source": truncated_sources}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
elif "tool_calls" in line:
|
||||
tool_calls = line["tool_calls"]
|
||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
elif "thought" in line:
|
||||
thought += line["thought"]
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "thought", "thought": line["thought"]})
|
||||
elif "type" in line:
|
||||
if line.get("type") == "tool_calls_pending":
|
||||
# Save continuation state and end the stream
|
||||
paused = True
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(line)
|
||||
elif line.get("type") == "error":
|
||||
sanitized_error = {
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(line.get("error", "An error occurred"))
|
||||
}
|
||||
data = json.dumps(sanitized_error)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(
|
||||
{
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(
|
||||
line.get("error", "An error occurred")
|
||||
),
|
||||
}
|
||||
)
|
||||
else:
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(line)
|
||||
if is_structured and structured_chunks:
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
"answer": response_full,
|
||||
"structured": True,
|
||||
"schema": schema_info,
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(
|
||||
{
|
||||
"type": "structured_answer",
|
||||
"answer": response_full,
|
||||
"structured": True,
|
||||
"schema": schema_info,
|
||||
}
|
||||
)
|
||||
|
||||
# ---- Paused: save continuation state and end stream early ----
|
||||
if paused:
|
||||
continuation = getattr(agent, "_pending_continuation", None)
|
||||
if continuation:
|
||||
# Ensure we have a conversation_id — create a partial
|
||||
# conversation if this is the first turn.
|
||||
# First-turn pause needs a conversation row to attach to.
|
||||
if not conversation_id and should_save_conversation:
|
||||
try:
|
||||
# Use model-owner scope so shared-agent
|
||||
# owner-BYOM resolves to its registered plugin.
|
||||
provider = (
|
||||
get_provider_from_model_id(
|
||||
model_id,
|
||||
@@ -340,6 +527,7 @@ class BaseAnswerResource:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
state_saved = False
|
||||
if conversation_id:
|
||||
try:
|
||||
cont_service = ContinuationService()
|
||||
@@ -352,8 +540,8 @@ class BaseAnswerResource:
|
||||
tool_schemas=getattr(agent, "tools", []),
|
||||
agent_config={
|
||||
"model_id": model_id or self.default_model_id,
|
||||
# Persist BYOM scope so resume doesn't
|
||||
# fall back to caller's layer.
|
||||
# BYOM scope; without it resume falls
|
||||
# back to caller's layer.
|
||||
"model_user_id": model_user_id,
|
||||
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
|
||||
"api_key": getattr(agent, "api_key", None),
|
||||
@@ -363,30 +551,81 @@ class BaseAnswerResource:
|
||||
"prompt": getattr(agent, "prompt", ""),
|
||||
"json_schema": getattr(agent, "json_schema", None),
|
||||
"retriever_config": getattr(agent, "retriever_config", None),
|
||||
# Reused on resume so the same WAL row
|
||||
# is finalised and request_id stays
|
||||
# consistent across token_usage rows.
|
||||
"reserved_message_id": reserved_message_id,
|
||||
"request_id": request_id,
|
||||
},
|
||||
client_tools=getattr(
|
||||
agent.tool_executor, "client_tools", None
|
||||
),
|
||||
)
|
||||
state_saved = True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save continuation state: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
# Notify the user out-of-band so they can navigate
|
||||
# back to the conversation and decide on the
|
||||
# pending tool calls. Gated on ``state_saved``: a
|
||||
# missing pending_tool_state row would 404 the
|
||||
# resume endpoint, so an unfulfillable notification
|
||||
# is worse than no notification.
|
||||
user_id_for_event = (
|
||||
decoded_token.get("sub") if decoded_token else None
|
||||
)
|
||||
if state_saved and user_id_for_event and conversation_id:
|
||||
pending_calls = continuation.get(
|
||||
"pending_tool_calls", []
|
||||
) if continuation else []
|
||||
# Trim each pending tool call to its identifying
|
||||
# metadata so a tool with a multi-MB argument
|
||||
# doesn't blow out the per-event payload size
|
||||
# cap. The resume page fetches full args from
|
||||
# ``pending_tool_state`` regardless.
|
||||
pending_summaries = [
|
||||
{
|
||||
k: tc.get(k)
|
||||
for k in (
|
||||
"call_id",
|
||||
"tool_name",
|
||||
"action_name",
|
||||
"name",
|
||||
)
|
||||
if isinstance(tc, dict) and tc.get(k) is not None
|
||||
}
|
||||
for tc in (pending_calls or [])
|
||||
if isinstance(tc, dict)
|
||||
]
|
||||
publish_user_event(
|
||||
user_id_for_event,
|
||||
"tool.approval.required",
|
||||
{
|
||||
"conversation_id": str(conversation_id),
|
||||
"message_id": reserved_message_id,
|
||||
"pending_tool_calls": pending_summaries,
|
||||
},
|
||||
scope={
|
||||
"kind": "conversation",
|
||||
"id": str(conversation_id),
|
||||
},
|
||||
)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "id", "id": str(conversation_id)})
|
||||
yield _emit({"type": "end"})
|
||||
# Drain the terminal ``end`` so a reconnecting client
|
||||
# sees it on snapshot — same reason as the main exit.
|
||||
if journal_writer is not None:
|
||||
journal_writer.close()
|
||||
return
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
# Run under model-owner scope so title-gen LLM inside
|
||||
# save_conversation uses the owner's BYOM provider/key.
|
||||
# Model-owner scope so title-gen uses owner's BYOM key.
|
||||
provider = (
|
||||
get_provider_from_model_id(
|
||||
model_id,
|
||||
@@ -407,26 +646,49 @@ class BaseAnswerResource:
|
||||
agent_id=agent_id,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
# Title-gen only; agent stream tokens live on ``agent.llm``.
|
||||
llm._token_usage_source = "title"
|
||||
|
||||
if should_save_conversation:
|
||||
conversation_id = self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
if reserved_message_id is not None:
|
||||
self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full,
|
||||
thought=thought,
|
||||
sources=source_log_docs,
|
||||
tool_calls=tool_calls,
|
||||
model_id=model_id or self.default_model_id,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
status="complete",
|
||||
title_inputs={
|
||||
"llm": llm,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"model_id": model_id or self.default_model_id,
|
||||
"fallback_name": (
|
||||
question[:50] if question else "New Conversation"
|
||||
),
|
||||
},
|
||||
)
|
||||
else:
|
||||
conversation_id = self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
@@ -449,9 +711,22 @@ class BaseAnswerResource:
|
||||
)
|
||||
else:
|
||||
conversation_id = None
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
# Resume finished cleanly; drop the continuation row.
|
||||
# Crash-paths leave it ``resuming`` for the janitor to revert.
|
||||
if _continuation and conversation_id:
|
||||
try:
|
||||
cont_service = ContinuationService()
|
||||
cont_service.delete_state(
|
||||
str(conversation_id),
|
||||
decoded_token.get("sub", "local"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete continuation state on resume "
|
||||
f"completion: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
yield _emit({"type": "id", "id": str(conversation_id)})
|
||||
|
||||
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
|
||||
getattr(agent, "tool_calls", tool_calls) or tool_calls
|
||||
@@ -492,21 +767,40 @@ class BaseAnswerResource:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "end"})
|
||||
# Drain the journal buffer so the terminal ``end`` event is
|
||||
# visible to any reconnecting client. Without this the
|
||||
# client could snapshot up to the last flush boundary and
|
||||
# then live-tail waiting for an ``end`` that's still
|
||||
# sitting in memory.
|
||||
if journal_writer is not None:
|
||||
journal_writer.close()
|
||||
except GeneratorExit:
|
||||
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
|
||||
# Drain any buffered events before the terminal one-shot
|
||||
# ``record_event`` below — keeps the journal's seq order
|
||||
# contiguous (buffered events ... terminal event). ``close``
|
||||
# is idempotent; pairing it with ``flush`` matches the
|
||||
# normal-exit and error branches so any future ``record()``
|
||||
# past this point would log instead of silently buffering.
|
||||
if journal_writer is not None:
|
||||
journal_writer.flush()
|
||||
journal_writer.close()
|
||||
# Save partial response
|
||||
|
||||
# Whether the DB row was flipped to ``complete`` during this
|
||||
# abort handler. Drives the choice of terminal journal event
|
||||
# below: journal ``end`` only when the row actually matches,
|
||||
# else journal ``error`` so a reconnecting client sees a
|
||||
# failed terminal state instead of a blank "success".
|
||||
finalized_complete = False
|
||||
if should_save_conversation and response_full:
|
||||
try:
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
# Mirror the normal-path provider resolution so the
|
||||
# partial-save title LLM uses the model-owner's BYOM
|
||||
# registration (shared-agent dispatch) rather than
|
||||
# the deployment default with the instance api key.
|
||||
# Resolve under model-owner scope so shared-agent
|
||||
# title-gen uses owner BYOM, not deployment default.
|
||||
provider = (
|
||||
get_provider_from_model_id(
|
||||
model_id,
|
||||
@@ -532,24 +826,58 @@ class BaseAnswerResource:
|
||||
agent_id=agent_id,
|
||||
model_user_id=model_user_id,
|
||||
)
|
||||
self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
llm._token_usage_source = "title"
|
||||
if reserved_message_id is not None:
|
||||
outcome = self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full,
|
||||
thought=thought,
|
||||
sources=source_log_docs,
|
||||
tool_calls=tool_calls,
|
||||
model_id=model_id or self.default_model_id,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
status="complete",
|
||||
title_inputs={
|
||||
"llm": llm,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"model_id": model_id or self.default_model_id,
|
||||
"fallback_name": (
|
||||
question[:50] if question else "New Conversation"
|
||||
),
|
||||
},
|
||||
)
|
||||
# ``ALREADY_COMPLETE`` means the normal-path
|
||||
# ``finalize_message`` call won the race: the DB row
|
||||
# is already at ``complete`` and the reconnect
|
||||
# journal should reflect that with ``end``,
|
||||
# not a spurious ``error``.
|
||||
finalized_complete = outcome in (
|
||||
MessageUpdateOutcome.UPDATED,
|
||||
MessageUpdateOutcome.ALREADY_COMPLETE,
|
||||
)
|
||||
else:
|
||||
self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
# No journal row to gate, but flag the save as
|
||||
# successful for symmetry with the WAL path.
|
||||
finalized_complete = True
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
if conversation_id and compression_meta and not compression_saved:
|
||||
@@ -573,16 +901,94 @@ class BaseAnswerResource:
|
||||
logger.error(
|
||||
f"Error saving partial response: {str(e)}", exc_info=True
|
||||
)
|
||||
# Journal a terminal event so reconnecting clients stop tailing;
|
||||
# ``end`` only when the row is ``complete``, else ``error``.
|
||||
if reserved_message_id is not None:
|
||||
try:
|
||||
sequence_no += 1
|
||||
if finalized_complete:
|
||||
# Match the wire shape ``_emit({"type": "end"})``
|
||||
# uses on the normal path — the replay terminal
|
||||
# check at ``event_replay._payload_is_terminal``
|
||||
# reads ``payload.type``, and the frontend parses
|
||||
# the same key off ``data:``.
|
||||
record_event(
|
||||
reserved_message_id,
|
||||
sequence_no,
|
||||
"end",
|
||||
{"type": "end"},
|
||||
)
|
||||
else:
|
||||
# Nothing was persisted under the complete status
|
||||
# — mark the row failed so the reconciler doesn't
|
||||
# need to sweep it, and journal an ``error`` so a
|
||||
# reconnecting client surfaces the same failure
|
||||
# the UI would show on a live error.
|
||||
try:
|
||||
self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full or TERMINATED_RESPONSE_PLACEHOLDER,
|
||||
thought=thought,
|
||||
sources=source_log_docs,
|
||||
tool_calls=tool_calls,
|
||||
model_id=model_id or self.default_model_id,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
status="failed",
|
||||
error=ConnectionError(
|
||||
"client disconnected before response was persisted"
|
||||
),
|
||||
)
|
||||
except Exception as fin_err:
|
||||
logger.error(
|
||||
f"Failed to mark aborted message failed: {fin_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
record_event(
|
||||
reserved_message_id,
|
||||
sequence_no,
|
||||
"error",
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Stream aborted before any response was produced.",
|
||||
"code": "client_disconnect",
|
||||
},
|
||||
)
|
||||
except Exception as journal_err:
|
||||
logger.error(
|
||||
f"Failed to journal terminal event on abort: {journal_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
||||
data = json.dumps(
|
||||
if reserved_message_id is not None:
|
||||
try:
|
||||
self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full or TERMINATED_RESPONSE_PLACEHOLDER,
|
||||
thought=thought,
|
||||
sources=source_log_docs,
|
||||
tool_calls=tool_calls,
|
||||
model_id=model_id or self.default_model_id,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
status="failed",
|
||||
error=e,
|
||||
)
|
||||
except Exception as fin_err:
|
||||
logger.error(
|
||||
f"Failed to finalize errored message: {fin_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
yield _emit(
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Please try again later. We apologize for any inconvenience.",
|
||||
}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
# Drain the terminal ``error`` event we just yielded so a
|
||||
# reconnecting client sees it on snapshot.
|
||||
if journal_writer is not None:
|
||||
journal_writer.close()
|
||||
return
|
||||
|
||||
def process_response_stream(self, stream) -> Dict[str, Any]:
|
||||
@@ -604,8 +1010,22 @@ class BaseAnswerResource:
|
||||
|
||||
for line in stream:
|
||||
try:
|
||||
event_data = line.replace("data: ", "").strip()
|
||||
# Each chunk may carry an ``id: <seq>`` header before
|
||||
# the ``data:`` line. Pull just the ``data:`` body so
|
||||
# the JSON decode doesn't choke on the SSE framing.
|
||||
event_data = ""
|
||||
for raw in line.split("\n"):
|
||||
if raw.startswith("data:"):
|
||||
event_data = raw[len("data:") :].lstrip()
|
||||
break
|
||||
if not event_data:
|
||||
continue
|
||||
event = json.loads(event_data)
|
||||
# The ``message_id`` event is informational for the
|
||||
# streaming consumer and has no synchronous-API field;
|
||||
# skip it so the type-switch below doesn't KeyError.
|
||||
if event.get("type") == "message_id":
|
||||
continue
|
||||
|
||||
if event["type"] == "id":
|
||||
conversation_id = event["id"]
|
||||
|
||||
135
application/api/answer/routes/messages.py
Normal file
135
application/api/answer/routes/messages.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""GET /api/messages/<message_id>/events — chat-stream reconnect endpoint.
|
||||
|
||||
Authenticates the caller, verifies ``message_id`` belongs to the user,
|
||||
then hands off to ``build_message_event_stream`` for snapshot+tail.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.streaming.event_replay import (
|
||||
DEFAULT_KEEPALIVE_SECONDS,
|
||||
DEFAULT_POLL_TIMEOUT_SECONDS,
|
||||
build_message_event_stream,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
messages_bp = Blueprint("message_stream", __name__)
|
||||
|
||||
# A message_id is the canonical UUID hex format. Reject anything else
|
||||
# before the SQL layer so a malformed cookie can't surface as a 500.
|
||||
_MESSAGE_ID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
|
||||
r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
# ``sequence_no`` is a non-negative decimal integer. Anything else is
|
||||
# corrupt client state — fall through to a fresh-replay cursor and let
|
||||
# the snapshot reader catch the client up.
|
||||
_SEQUENCE_NO_RE = re.compile(r"^\d+$")
|
||||
|
||||
|
||||
def _normalise_last_event_id(raw: Optional[str]) -> Optional[int]:
|
||||
if raw is None:
|
||||
return None
|
||||
raw = raw.strip()
|
||||
if not raw or not _SEQUENCE_NO_RE.match(raw):
|
||||
return None
|
||||
return int(raw)
|
||||
|
||||
|
||||
def _user_owns_message(message_id: str, user_id: str) -> bool:
    """Report whether ``message_id`` is a message row owned by ``user_id``.

    Runs a single-row lookup against ``conversation_messages`` on a
    read-only connection. Any exception raised during the lookup is logged
    with its traceback and reported as "not owned" (``False``).
    """
    lookup_params = {"id": message_id, "u": user_id}
    try:
        with db_readonly() as conn:
            match = conn.execute(
                text(
                    """
                    SELECT 1 FROM conversation_messages
                    WHERE id = CAST(:id AS uuid)
                    AND user_id = :u
                    LIMIT 1
                    """
                ),
                lookup_params,
            ).first()
    except Exception:
        logger.exception(
            "Ownership lookup failed for message_id=%s user_id=%s",
            message_id,
            user_id,
        )
        return False
    return match is not None
|
||||
|
||||
|
||||
@messages_bp.route("/api/messages/<message_id>/events", methods=["GET"])
def stream_message_events(message_id: str) -> Response:
    """Serve the event stream of one message as ``text/event-stream``.

    Responds with a JSON error and 401 when the request carries no
    ``sub`` claim, 400 when ``message_id`` does not match
    ``_MESSAGE_ID_RE``, and 404 when the ownership check fails.
    Otherwise the body is produced by ``build_message_event_stream``,
    resuming from the cursor taken from the ``Last-Event-ID`` header or
    the ``last_event_id`` query parameter.
    """

    def _json_error(message: str, status_code: int) -> Response:
        # All rejections share the same ``{"success": False, ...}`` shape.
        return make_response(
            jsonify({"success": False, "message": message}),
            status_code,
        )

    token = getattr(request, "decoded_token", None)
    user_id = token.get("sub") if isinstance(token, dict) else None
    if not user_id:
        return _json_error("Authentication required", 401)

    if not _MESSAGE_ID_RE.match(message_id):
        return _json_error("Invalid message id", 400)

    if not _user_owns_message(message_id, user_id):
        # Same 404 whether the id is unknown, owned by someone else, or
        # deleted, so the response never reveals that a row exists.
        return _json_error("Not found", 404)

    header_cursor = request.headers.get("Last-Event-ID")
    raw_cursor = header_cursor or request.args.get("last_event_id")
    last_event_id = _normalise_last_event_id(raw_cursor)
    keepalive_seconds = float(
        getattr(settings, "SSE_KEEPALIVE_SECONDS", DEFAULT_KEEPALIVE_SECONDS)
    )

    @stream_with_context
    def generate() -> Iterator[str]:
        try:
            yield from build_message_event_stream(
                message_id,
                last_event_id=last_event_id,
                keepalive_seconds=keepalive_seconds,
                poll_timeout_seconds=DEFAULT_POLL_TIMEOUT_SECONDS,
            )
        except GeneratorExit:
            # The generator is being closed; finish without logging.
            return
        except Exception:
            logger.exception(
                "Reconnect stream crashed for message_id=%s user_id=%s",
                message_id,
                user_id,
            )

    response = Response(generate(), mimetype="text/event-stream")
    for header_name, header_value in (
        ("Cache-Control", "no-store"),
        ("X-Accel-Buffering", "no"),
        ("Connection", "keep-alive"),
    ):
        response.headers[header_name] = header_value

    cursor_for_log = "-" if last_event_id is None else last_event_id
    logger.info(
        "message.event.connect message_id=%s user_id=%s last_event_id=%s",
        message_id,
        user_id,
        cursor_for_log,
    )
    return response
|
||||
@@ -115,6 +115,8 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
"reserved_message_id": processor.reserved_message_id,
|
||||
"request_id": processor.request_id,
|
||||
},
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
|
||||
@@ -160,6 +160,9 @@ class CompressionOrchestrator:
|
||||
agent_id=conversation.get("agent_id"),
|
||||
model_user_id=registry_user_id,
|
||||
)
|
||||
# Side-channel LLM tag — distinguishes compression rows
|
||||
# from primary stream rows for cost-attribution dashboards.
|
||||
compression_llm._token_usage_source = "compression"
|
||||
|
||||
# Create compression service with DB update capability
|
||||
compression_service = CompressionService(
|
||||
|
||||
@@ -7,13 +7,13 @@ resume later by sending tool_actions.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.pending_tool_state import (
|
||||
PendingToolStateRepository,
|
||||
)
|
||||
from application.storage.db.serialization import coerce_pg_native as _make_serializable
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -21,23 +21,9 @@ logger = logging.getLogger(__name__)
|
||||
# TTL for pending states — auto-cleaned after this period
|
||||
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _make_serializable(obj: Any) -> Any:
|
||||
"""Recursively coerce non-JSON values into JSON-safe forms.
|
||||
|
||||
Handles ``uuid.UUID`` (from PG columns), ``bytes``, and recurses into
|
||||
dicts/lists. Post-Mongo-cutover the ObjectId branch is gone — none of
|
||||
our writers produce them anymore.
|
||||
"""
|
||||
if isinstance(obj, UUID):
|
||||
return str(obj)
|
||||
if isinstance(obj, dict):
|
||||
return {str(k): _make_serializable(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_make_serializable(v) for v in obj]
|
||||
if isinstance(obj, bytes):
|
||||
return obj.decode("utf-8", errors="replace")
|
||||
return obj
|
||||
# Re-export so the existing tests at tests/api/answer/services/test_continuation_service_pg.py
|
||||
# can keep importing ``_make_serializable`` from here.
|
||||
__all__ = ["_make_serializable", "ContinuationService", "PENDING_STATE_TTL_SECONDS"]
|
||||
|
||||
|
||||
class ContinuationService:
|
||||
@@ -155,3 +141,23 @@ class ContinuationService:
|
||||
f"Deleted continuation state for conversation {conversation_id}"
|
||||
)
|
||||
return deleted
|
||||
|
||||
def mark_resuming(self, conversation_id: str, user: str) -> bool:
    """Ask the pending-state repository to mark the row as ``resuming``.

    ``conversation_id`` is first resolved through
    ``get_by_legacy_id``; when that finds nothing it is used as-is if it
    looks like a UUID. Returns ``False`` when it cannot be resolved,
    otherwise whatever the repository's ``mark_resuming`` returns.
    """
    with db_session() as conn:
        legacy_match = ConversationsRepository(conn).get_by_legacy_id(
            conversation_id
        )
        if legacy_match is None and not looks_like_uuid(conversation_id):
            return False
        pg_conv_id = (
            legacy_match["id"] if legacy_match is not None else conversation_id
        )
        flipped = PendingToolStateRepository(conn).mark_resuming(
            pg_conv_id, user
        )
    if flipped:
        logger.info(
            f"Marked continuation state as resuming for conversation {conversation_id}"
        )
    return flipped
|
||||
|
||||
@@ -6,6 +6,7 @@ than held for the duration of a stream.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -14,13 +15,22 @@ from sqlalchemy import text as sql_text
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.conversations import (
|
||||
ConversationsRepository,
|
||||
MessageUpdateOutcome,
|
||||
)
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Shown to the user if the worker dies mid-stream and the response is never finalised.
|
||||
TERMINATED_RESPONSE_PLACEHOLDER = (
|
||||
"Response was terminated prior to completion, try regenerating."
|
||||
)
|
||||
|
||||
|
||||
class ConversationService:
|
||||
def get_conversation(
|
||||
self, conversation_id: str, user_id: str
|
||||
@@ -179,6 +189,243 @@ class ConversationService:
|
||||
repo.append_message(conv_pg_id, append_payload)
|
||||
return conv_pg_id
|
||||
|
||||
def save_user_question(
    self,
    conversation_id: Optional[str],
    question: str,
    decoded_token: Dict[str, Any],
    *,
    attachment_ids: Optional[List[str]] = None,
    api_key: Optional[str] = None,
    agent_id: Optional[str] = None,
    is_shared_usage: bool = False,
    shared_token: Optional[str] = None,
    model_id: Optional[str] = None,
    request_id: Optional[str] = None,
    status: str = "pending",
    index: Optional[int] = None,
) -> Dict[str, str]:
    """Reserve a placeholder message row for ``question``.

    With a ``conversation_id`` the conversation is looked up for the
    token's user; without one a new conversation is created, named after
    the first 50 characters of the question. A non-negative integer
    ``index`` truncates the existing conversation first (regenerate /
    edit semantics) so the new row is reserved after position
    ``index - 1``.

    Returns:
        ``{"conversation_id", "message_id", "request_id"}``, all strings.

    Raises:
        ValueError: The token is missing, has no ``sub``, or the
            conversation is not found for this user.
    """
    if decoded_token is None:
        raise ValueError("Invalid or missing authentication token")
    user_id = decoded_token.get("sub")
    if not user_id:
        raise ValueError("User ID not found in token")

    # A fresh request id is minted only when the caller did not supply one.
    if not request_id:
        request_id = str(uuid.uuid4())

    # The agent key is only resolved when a new conversation will be created.
    resolved_api_key: Optional[str] = None
    if api_key and not conversation_id:
        with db_readonly() as conn:
            agent = AgentsRepository(conn).find_by_key(api_key)
        if agent:
            resolved_api_key = agent.get("key")
    resolved_agent_id: Optional[str] = agent_id if agent_id else None

    with db_session() as conn:
        repo = ConversationsRepository(conn)
        if not conversation_id:
            title = question[:50] if question else "New Conversation"
            shared = resolved_agent_id and is_shared_usage
            created = repo.create(
                user_id,
                title,
                agent_id=resolved_agent_id,
                api_key=resolved_api_key,
                is_shared_usage=bool(shared),
                shared_token=shared_token if shared else None,
            )
            conv_pg_id = str(created["id"])
        else:
            existing = repo.get_any(conversation_id, user_id)
            if existing is None:
                raise ValueError("Conversation not found or unauthorized")
            conv_pg_id = str(existing["id"])
            # Regenerate / edit-prior-question: drop the message at
            # ``index`` and everything after it before reserving.
            if isinstance(index, int) and index >= 0:
                repo.truncate_after(conv_pg_id, keep_up_to=index - 1)

        reserved = repo.reserve_message(
            conv_pg_id,
            prompt=question,
            placeholder_response=TERMINATED_RESPONSE_PLACEHOLDER,
            request_id=request_id,
            status=status,
            attachments=attachment_ids,
            model_id=model_id,
        )
        message_id = str(reserved["id"])

    return {
        "conversation_id": conv_pg_id,
        "message_id": message_id,
        "request_id": request_id,
    }
|
||||
|
||||
def update_message_status(self, message_id: str, status: str) -> bool:
    """Set only the status of a message (e.g. ``pending → streaming``).

    Returns ``False`` without touching the database when ``message_id``
    is empty; otherwise returns the repository's result.
    """
    if message_id:
        with db_session() as conn:
            repo = ConversationsRepository(conn)
            return repo.update_message_status(message_id, status)
    return False
|
||||
|
||||
def heartbeat_message(self, message_id: str) -> bool:
    """Record a heartbeat for ``message_id`` via the repository.

    Returns ``False`` without touching the database when ``message_id``
    is empty; otherwise returns the repository's result.
    """
    if message_id:
        with db_session() as conn:
            repo = ConversationsRepository(conn)
            return repo.heartbeat_message(message_id)
    return False
|
||||
|
||||
def finalize_message(
    self,
    message_id: str,
    response: str,
    *,
    thought: str = "",
    sources: Optional[List[Dict[str, Any]]] = None,
    tool_calls: Optional[List[Dict[str, Any]]] = None,
    model_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    status: str = "complete",
    error: Optional[BaseException] = None,
    title_inputs: Optional[Dict[str, Any]] = None,
) -> MessageUpdateOutcome:
    """Write the final response and confirm tool calls in one session.

    Returns ``MessageUpdateOutcome.INVALID`` for an empty
    ``message_id``. Otherwise the outcome of ``update_message_by_id`` is
    returned unchanged when it is not ``UPDATED``, so callers can tell a
    fresh finalize from a row that was not updated. On ``UPDATED`` the
    executed tool calls are confirmed and, for ``status == "complete"``
    with ``title_inputs`` given, title generation is attempted; failures
    there are logged and do not change the result.
    """
    if not message_id:
        return MessageUpdateOutcome.INVALID

    # Source texts are capped at 1000 characters before being stored.
    # NOTE(review): this writes into the caller's own dicts — confirm
    # that no caller needs the untruncated text afterwards.
    sources = sources or []
    for entry in sources:
        if "text" in entry and isinstance(entry["text"], str):
            entry["text"] = entry["text"][:1000]

    merged_metadata: Dict[str, Any] = dict(metadata or {})
    if error is not None and status == "failed":
        # An "error" key already present in ``metadata`` wins.
        merged_metadata.setdefault(
            "error", f"{type(error).__name__}: {str(error)}"
        )

    update_fields: Dict[str, Any] = {
        "response": response,
        "status": status,
        "thought": thought,
        "sources": sources,
        "tool_calls": tool_calls or [],
        "metadata": merged_metadata,
    }
    if model_id is not None:
        update_fields["model_id"] = model_id

    # Message update and tool-call confirmation share one session; the
    # ``only_if_non_terminal`` flag leaves already-terminal rows alone.
    with db_session() as conn:
        repo = ConversationsRepository(conn)
        outcome = repo.update_message_by_id(
            message_id,
            update_fields,
            only_if_non_terminal=True,
        )
        if outcome is not MessageUpdateOutcome.UPDATED:
            logger.warning(
                f"finalize_message: no row updated for message_id={message_id} "
                f"(outcome={outcome.value} — possibly already terminal)"
            )
            return outcome
        repo.confirm_executed_tool_calls(message_id)

    # Title generation runs after the session above has been released.
    # NOTE(review): it opens a second ``db_session`` that stays open
    # while ``_maybe_generate_title`` calls the LLM — confirm that is
    # acceptable.
    wants_title = bool(title_inputs) and status == "complete"
    if wants_title:
        try:
            with db_session() as conn:
                self._maybe_generate_title(conn, message_id, title_inputs)
        except Exception as e:
            logger.error(
                f"finalize_message title generation failed: {e}",
                exc_info=True,
            )
    return MessageUpdateOutcome.UPDATED
|
||||
|
||||
def _maybe_generate_title(
|
||||
self,
|
||||
conn,
|
||||
message_id: str,
|
||||
title_inputs: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Generate an LLM-summarised conversation name if one isn't set yet."""
|
||||
llm = title_inputs.get("llm")
|
||||
question = title_inputs.get("question") or ""
|
||||
response = title_inputs.get("response") or ""
|
||||
fallback_name = title_inputs.get("fallback_name") or question[:50]
|
||||
if llm is None:
|
||||
return
|
||||
|
||||
row = conn.execute(
|
||||
sql_text(
|
||||
"SELECT c.id, c.name FROM conversation_messages m "
|
||||
"JOIN conversations c ON c.id = m.conversation_id "
|
||||
"WHERE m.id = CAST(:mid AS uuid)"
|
||||
),
|
||||
{"mid": message_id},
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return
|
||||
conv_id, current_name = str(row[0]), row[1]
|
||||
if current_name and current_name != fallback_name:
|
||||
return
|
||||
|
||||
messages_summary = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that creates concise conversation titles. "
|
||||
"Summarize conversations in 3 words or less using the same language as the user.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Summarise following conversation in no more than 3 words, "
|
||||
"respond ONLY with the summary, use the same language as the "
|
||||
"user query \n\nUser: " + question + "\n\n" + "AI: " + response,
|
||||
},
|
||||
]
|
||||
completion = llm.gen(
|
||||
model=getattr(llm, "model_id", None) or title_inputs.get("model_id"),
|
||||
messages=messages_summary,
|
||||
max_tokens=500,
|
||||
)
|
||||
if not completion or not completion.strip():
|
||||
completion = fallback_name or "New Conversation"
|
||||
conn.execute(
|
||||
sql_text(
|
||||
"UPDATE conversations SET name = :name, updated_at = now() "
|
||||
"WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
{"id": conv_id, "name": completion.strip()},
|
||||
)
|
||||
|
||||
def update_compression_metadata(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
|
||||
@@ -123,6 +123,10 @@ class StreamProcessor:
|
||||
self.model_id: Optional[str] = None
|
||||
# BYOM-resolution scope, set by _validate_and_set_model.
|
||||
self.model_user_id: Optional[str] = None
|
||||
# WAL placeholder id pulled from continuation state on resume.
|
||||
self.reserved_message_id: Optional[str] = None
|
||||
# Carried through resumes so multi-pause runs keep one request_id.
|
||||
self.request_id: Optional[str] = None
|
||||
self.conversation_service = ConversationService()
|
||||
self.compression_orchestrator = CompressionOrchestrator(
|
||||
self.conversation_service
|
||||
@@ -928,6 +932,20 @@ class StreamProcessor:
|
||||
if not state:
|
||||
raise ValueError("No pending tool state found for this conversation")
|
||||
|
||||
# Claim the resume up-front. ``mark_resuming`` only flips ``pending``
|
||||
# → ``resuming``; if it returns False, another resume already
|
||||
# claimed this row (status='resuming') — bail before any further
|
||||
# LLM/tool work to avoid double-execution. The cleanup janitor
|
||||
# reverts a stale ``resuming`` claim back to ``pending`` after the
|
||||
# 10-minute grace window so the user can retry.
|
||||
if not cont_service.mark_resuming(
|
||||
conversation_id, self.initial_user_id,
|
||||
):
|
||||
raise ValueError(
|
||||
"Resume already in progress for this conversation; "
|
||||
"retry after the grace window if it stalls."
|
||||
)
|
||||
|
||||
messages = state["messages"]
|
||||
pending_tool_calls = state["pending_tool_calls"]
|
||||
tools_dict = state["tools_dict"]
|
||||
@@ -1022,9 +1040,10 @@ class StreamProcessor:
|
||||
self.agent_id = agent_id
|
||||
self.agent_config["user_api_key"] = user_api_key
|
||||
self.conversation_id = conversation_id
|
||||
|
||||
# Delete state so it can't be replayed
|
||||
cont_service.delete_state(conversation_id, self.initial_user_id)
|
||||
# Reused on resume so the same WAL row gets finalised and
|
||||
# request_id stays consistent across token_usage rows.
|
||||
self.reserved_message_id = agent_config.get("reserved_message_id")
|
||||
self.request_id = agent_config.get("request_id")
|
||||
|
||||
return agent, messages, tools_dict, pending_tool_calls, tool_actions
|
||||
|
||||
|
||||
0
application/api/events/__init__.py
Normal file
0
application/api/events/__init__.py
Normal file
504
application/api/events/routes.py
Normal file
504
application/api/events/routes.py
Normal file
@@ -0,0 +1,504 @@
|
||||
"""GET /api/events — user-scoped Server-Sent Events endpoint.
|
||||
|
||||
Subscribe-then-snapshot pattern: subscribe to ``user:{user_id}``
|
||||
pub/sub, snapshot the Redis Streams backlog past ``Last-Event-ID``
|
||||
inside the SUBSCRIBE-ack callback, flush snapshot, then tail live
|
||||
events (dedup'd by stream id). See ``docs/runbooks/sse-notifications.md``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.events.keys import (
|
||||
connection_counter_key,
|
||||
replay_budget_key,
|
||||
stream_id_compare,
|
||||
stream_key,
|
||||
topic_name,
|
||||
)
|
||||
from application.streaming.broadcast_channel import Topic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
events = Blueprint("event_stream", __name__)
|
||||
|
||||
SUBSCRIBE_POLL_INTERVAL_SECONDS = 1.0
|
||||
|
||||
# WHATWG SSE treats CRLF, CR, and LF equivalently as line terminators.
|
||||
_SSE_LINE_SPLIT = re.compile(r"\r\n|\r|\n")
|
||||
|
||||
# Redis Streams ids are ``ms`` or ``ms-seq`` where both halves are decimal.
|
||||
# Anything else is a corrupted client cookie / IndexedDB residue and must
|
||||
# not be passed to XRANGE — Redis would reject it and our truncation gate
|
||||
# would silently fail.
|
||||
_STREAM_ID_RE = re.compile(r"^\d+(-\d+)?$")
|
||||
|
||||
# Only emitted at most once per process so a misconfigured deployment
|
||||
# doesn't drown the logs.
|
||||
_local_user_warned = False
|
||||
|
||||
|
||||
def _format_sse(data: str, *, event_id: Optional[str] = None) -> str:
    """Render ``data`` as a single SSE message ending in a blank line.

    The payload is split on every line-terminator variant (``\\r\\n``,
    ``\\r``, ``\\n``) and each piece becomes its own ``data:`` field, so
    a stray CR in upstream content cannot introduce an early line
    boundary on the wire. A truthy ``event_id`` is emitted first as an
    ``id:`` field.
    """
    id_field = [f"id: {event_id}"] if event_id else []
    data_fields = [f"data: {piece}" for piece in _SSE_LINE_SPLIT.split(data)]
    return "\n".join(id_field + data_fields) + "\n\n"
|
||||
|
||||
|
||||
def _decode(value) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, (bytes, bytearray)):
|
||||
try:
|
||||
return value.decode("utf-8")
|
||||
except Exception:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
|
||||
def _oldest_retained_id(redis_client, user_id: str) -> Optional[str]:
    """Look up the id of the first entry currently in the user's stream.

    Returns ``None`` when XINFO STREAM fails, returns something other
    than a dict, reports no first entry, or the id cannot be extracted.
    Used to detect ``Last-Event-ID`` having slid off the back of the
    MAXLEN'd window.
    """
    try:
        info = redis_client.xinfo_stream(stream_key(user_id))
    except Exception:
        return None

    # The str key is tried first; the bytes key is a fallback in case
    # the client returns undecoded keys.
    head = (
        info.get("first-entry") or info.get(b"first-entry")
        if isinstance(info, dict)
        else None
    )
    if not head:
        return None

    # XINFO STREAM returns first-entry as [id, [field, value, ...]]
    try:
        return _decode(head[0])
    except Exception:
        return None
|
||||
|
||||
|
||||
def _allow_replay(
    redis_client, user_id: str, last_event_id: Optional[str]
) -> bool:
    """Decide whether this connect may run a snapshot replay.

    Allows the replay when the budget setting is zero or negative, when
    no Redis client is available, or when the counter update itself
    fails (fail-open). A connect without a cursor against an empty
    stream is allowed without being counted. Otherwise the per-user
    counter is incremented and compared against the budget.
    """
    budget = int(settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW)
    if budget <= 0 or redis_client is None:
        return True

    # Only cursor-less connects get the XLEN pre-check: when a cursor is
    # present the request is always counted, without probing further.
    if last_event_id is None:
        try:
            backlog_len = int(redis_client.xlen(stream_key(user_id)))
        except Exception:
            # A failed probe falls through to the counted path so a
            # Redis hiccup cannot bypass the budget.
            logger.debug(
                "XLEN probe failed for replay budget check user=%s; "
                "proceeding to INCR",
                user_id,
            )
        else:
            if backlog_len == 0:
                return True

    window = max(1, int(settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS))
    key = replay_budget_key(user_id)
    try:
        used = int(redis_client.incr(key))
        # The TTL is (re)seeded on every call rather than only when the
        # counter is created, so a failed EXPIRE on the first call
        # cannot leave the counter without a TTL.
        # NOTE(review): re-seeding on every call also pushes the expiry
        # out each time, so the counter only resets after ``window``
        # seconds with no counted connects — confirm this is intended.
        redis_client.expire(key, window)
    except Exception:
        logger.debug(
            "replay budget probe failed for user=%s; failing open",
            user_id,
        )
        return True
    return used <= budget
|
||||
|
||||
|
||||
def _normalize_last_event_id(raw: Optional[str]) -> Optional[str]:
    """Validate a ``Last-Event-ID`` header / query-parameter value.

    Returns the whitespace-stripped value when it matches
    ``_STREAM_ID_RE`` (``ms`` or ``ms-seq``, decimal digits only).
    Returns ``None`` when the value is absent, blank, or malformed, so
    an invalid id is never handed on to XRANGE.
    """
    if raw is None:
        return None
    candidate = raw.strip()
    if candidate and _STREAM_ID_RE.match(candidate):
        return candidate
    return None
|
||||
|
||||
|
||||
def _replay_backlog(
    redis_client, user_id: str, last_event_id: Optional[str], max_count: int
) -> Iterator[tuple[str, str]]:
    """Yield ``(entry_id, sse_line)`` for stream entries after the cursor.

    At most ``max_count`` entries are requested from XRANGE. An XRANGE
    failure is logged and ends the replay with nothing yielded. Entries
    whose id or ``event`` field is missing or undecodable are skipped.
    When the event payload is a JSON object the stream id is written
    into its ``id`` key; payloads that fail to parse are passed through
    unchanged.
    """
    # Exclusive start: '(<id>' skips the already-delivered entry.
    range_start = "-" if not last_event_id else f"({last_event_id}"
    try:
        entries = redis_client.xrange(
            stream_key(user_id), min=range_start, max="+", count=max_count
        )
    except Exception as exc:
        logger.warning(
            "xrange replay failed for user=%s last_id=%s err=%s",
            user_id,
            last_event_id or "-",
            exc,
        )
        return

    for entry_id, fields in entries:
        entry_id_str = _decode(entry_id)
        if not entry_id_str:
            continue

        # Field keys are looked up as bytes first, then as str, so the
        # lookup works with either key type.
        raw_event = None
        if isinstance(fields, dict):
            raw_event = fields.get(b"event")
            if raw_event is None:
                raw_event = fields.get("event")

        event_str = _decode(raw_event)
        if not event_str:
            continue

        try:
            envelope = json.loads(event_str)
            if isinstance(envelope, dict):
                envelope["id"] = entry_id_str
                event_str = json.dumps(envelope)
        except Exception:
            logger.debug(
                "Replay envelope parse failed for entry %s; passing through raw",
                entry_id_str,
            )
        yield entry_id_str, _format_sse(event_str, event_id=entry_id_str)
|
||||
|
||||
|
||||
def _truncation_notice_line(oldest_id: str) -> str:
    """Build the ``backlog.truncated`` SSE message for ``oldest_id``.

    The frontend can react to this event with a full-state refetch.
    """
    notice = {
        "type": "backlog.truncated",
        "payload": {"oldest_retained_id": oldest_id},
    }
    return _format_sse(json.dumps(notice))
|
||||
|
||||
|
||||
@events.route("/api/events", methods=["GET"])
|
||||
def stream_events() -> Response:
|
||||
decoded = getattr(request, "decoded_token", None)
|
||||
user_id = decoded.get("sub") if isinstance(decoded, dict) else None
|
||||
if not user_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}),
|
||||
401,
|
||||
)
|
||||
|
||||
# In dev deployments without AUTH_TYPE configured, every request
|
||||
# resolves to user_id="local" and shares one stream. Surface this so
|
||||
# an accidentally-multi-user dev box doesn't silently cross-stream.
|
||||
global _local_user_warned
|
||||
if user_id == "local" and not _local_user_warned:
|
||||
logger.warning(
|
||||
"SSE serving user_id='local' (AUTH_TYPE not set). "
|
||||
"All clients on this deployment will share one event stream."
|
||||
)
|
||||
_local_user_warned = True
|
||||
|
||||
raw_last_event_id = request.headers.get("Last-Event-ID") or request.args.get(
|
||||
"last_event_id"
|
||||
)
|
||||
last_event_id = _normalize_last_event_id(raw_last_event_id)
|
||||
last_event_id_invalid = raw_last_event_id is not None and last_event_id is None
|
||||
|
||||
keepalive_seconds = float(settings.SSE_KEEPALIVE_SECONDS)
|
||||
push_enabled = settings.ENABLE_SSE_PUSH
|
||||
cap = int(settings.SSE_MAX_CONCURRENT_PER_USER)
|
||||
|
||||
redis_client = get_redis_instance()
|
||||
counter_key = connection_counter_key(user_id)
|
||||
counted = False
|
||||
|
||||
if push_enabled and redis_client is not None and cap > 0:
|
||||
try:
|
||||
current = int(redis_client.incr(counter_key))
|
||||
counted = True
|
||||
except Exception:
|
||||
current = 0
|
||||
logger.debug(
|
||||
"SSE connection counter INCR failed for user=%s", user_id
|
||||
)
|
||||
if counted:
|
||||
# 1h safety TTL — orphaned counts from hard crashes self-heal.
|
||||
# EXPIRE failure must NOT clobber ``current`` and bypass the cap.
|
||||
try:
|
||||
redis_client.expire(counter_key, 3600)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter EXPIRE failed for user=%s", user_id
|
||||
)
|
||||
if current > cap:
|
||||
try:
|
||||
redis_client.decr(counter_key)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter DECR failed for user=%s",
|
||||
user_id,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Too many concurrent SSE connections",
|
||||
}
|
||||
),
|
||||
429,
|
||||
)
|
||||
|
||||
# Replay budget is checked here, before the generator opens the
|
||||
# stream, so a denial can surface as HTTP 429 instead of a silent
|
||||
# snapshot skip. The earlier in-generator skip lost events between
|
||||
# the client's cursor and the first live-tailed entry: the live
|
||||
# tail still carried ``id:`` headers, the frontend advanced
|
||||
# ``lastEventId`` to one of those ids, and the events in between
|
||||
# were never reachable on the next reconnect. 429 keeps the
|
||||
# cursor pinned and lets the frontend back off until the window
|
||||
# slides (eventStreamClient.ts treats 429 as escalated backoff).
|
||||
if push_enabled and redis_client is not None and not _allow_replay(
|
||||
redis_client, user_id, last_event_id
|
||||
):
|
||||
if counted:
|
||||
try:
|
||||
redis_client.decr(counter_key)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter DECR failed for user=%s",
|
||||
user_id,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Replay budget exhausted",
|
||||
}
|
||||
),
|
||||
429,
|
||||
)
|
||||
|
||||
@stream_with_context
|
||||
def generate() -> Iterator[str]:
|
||||
connect_ts = time.monotonic()
|
||||
replayed_count = 0
|
||||
try:
|
||||
# First frame primes intermediaries (Cloudflare, nginx) so they
|
||||
# don't sit on a buffer waiting for body bytes.
|
||||
yield ": connected\n\n"
|
||||
|
||||
if not push_enabled:
|
||||
yield ": push_disabled\n\n"
|
||||
return
|
||||
|
||||
replay_lines: list[str] = []
|
||||
max_replayed_id: Optional[str] = None
|
||||
replay_done = False
|
||||
|
||||
# If the client sent a malformed Last-Event-ID, surface the
|
||||
# truncation notice synchronously *before* the subscribe
|
||||
# loop. Buffering it into ``replay_lines`` would lose it
|
||||
# when ``Topic.subscribe`` returns immediately (Redis down)
|
||||
# — the loop body never runs, and the flush at line ~335
|
||||
# never fires.
|
||||
if last_event_id_invalid:
|
||||
yield _truncation_notice_line("")
|
||||
replayed_count += 1
|
||||
|
||||
def _on_subscribe_callback() -> None:
|
||||
# Runs synchronously inside Topic.subscribe after the
|
||||
# SUBSCRIBE is acked. By doing XRANGE here, any publisher
|
||||
# firing between SUBSCRIBE-send and SUBSCRIBE-ack has its
|
||||
# XADD captured by XRANGE *and* its PUBLISH buffered at
|
||||
# the connection layer until we read it — closing the
|
||||
# replay/subscribe race the design doc warns about.
|
||||
#
|
||||
# Truncation contract: ``backlog.truncated`` is emitted
|
||||
# ONLY when the client's ``Last-Event-ID`` has slid off
|
||||
# the MAXLEN'd window — that's the case where the
|
||||
# journal is genuinely gone past the cursor and the
|
||||
# frontend should clear its slice cursor and refetch
|
||||
# state. Cap-hit skips the snapshot silently: the
|
||||
# cursor advances via the per-entry ``id:`` headers
|
||||
# and the frontend's slice keeps the latest id so the
|
||||
# next reconnect resumes from there. Budget-exhausted
|
||||
# never reaches this callback — the route 429s before
|
||||
# opening the stream, keeping the cursor pinned.
|
||||
# Conflating these with stale-cursor truncation would
|
||||
# tell the client to clear its cursor and re-receive
|
||||
# the same oldest-N entries on every reconnect —
|
||||
# locking the user out of entries past N.
|
||||
nonlocal max_replayed_id, replay_done
|
||||
try:
|
||||
if redis_client is None:
|
||||
return
|
||||
oldest = _oldest_retained_id(redis_client, user_id)
|
||||
if (
|
||||
last_event_id
|
||||
and oldest
|
||||
and stream_id_compare(last_event_id, oldest) < 0
|
||||
):
|
||||
# The Last-Event-ID has slid off the MAXLEN window.
|
||||
# Tell the client so it can fetch full state.
|
||||
replay_lines.append(_truncation_notice_line(oldest))
|
||||
replay_cap = int(settings.EVENTS_REPLAY_MAX_PER_REQUEST)
|
||||
for entry_id, sse_line in _replay_backlog(
|
||||
redis_client, user_id, last_event_id, replay_cap
|
||||
):
|
||||
replay_lines.append(sse_line)
|
||||
max_replayed_id = entry_id
|
||||
finally:
|
||||
# Always flip the flag — even on partial-replay failure
|
||||
# the outer loop must reach the flush step so we don't
|
||||
# silently strand whatever entries did land.
|
||||
replay_done = True
|
||||
|
||||
topic = Topic(topic_name(user_id))
|
||||
last_keepalive = time.monotonic()
|
||||
for payload in topic.subscribe(
|
||||
on_subscribe=_on_subscribe_callback,
|
||||
poll_timeout=SUBSCRIBE_POLL_INTERVAL_SECONDS,
|
||||
):
|
||||
# Flush snapshot on the first iteration after the SUBSCRIBE
|
||||
# callback ran. This runs at most once per connection.
|
||||
if replay_done and replay_lines:
|
||||
for line in replay_lines:
|
||||
yield line
|
||||
replayed_count += 1
|
||||
replay_lines.clear()
|
||||
|
||||
now = time.monotonic()
|
||||
if payload is None:
|
||||
if now - last_keepalive >= keepalive_seconds:
|
||||
yield ": keepalive\n\n"
|
||||
last_keepalive = now
|
||||
continue
|
||||
|
||||
event_str = _decode(payload) or ""
|
||||
event_id: Optional[str] = None
|
||||
try:
|
||||
envelope = json.loads(event_str)
|
||||
if isinstance(envelope, dict):
|
||||
candidate = envelope.get("id")
|
||||
# Only trust ids that look like real Redis Streams
|
||||
# ids (``ms`` or ``ms-seq``). A malformed or
|
||||
# adversarial publisher could otherwise pin
|
||||
# dedupe forever — a lex-greater bogus id would
|
||||
# make every legitimate later id compare ``<=``
|
||||
# and get dropped silently.
|
||||
if isinstance(candidate, str) and _STREAM_ID_RE.match(
|
||||
candidate
|
||||
):
|
||||
event_id = candidate
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Dedupe: if this id was already covered by replay, drop it.
|
||||
if (
|
||||
event_id is not None
|
||||
and max_replayed_id is not None
|
||||
and stream_id_compare(event_id, max_replayed_id) <= 0
|
||||
):
|
||||
continue
|
||||
|
||||
yield _format_sse(event_str, event_id=event_id)
|
||||
last_keepalive = now
|
||||
|
||||
# Topic.subscribe exited before the first yield (transient
|
||||
# Redis hiccup between SUBSCRIBE-ack and the first poll, or
|
||||
# an immediate Redis-down return). The callback may already
|
||||
# have populated the snapshot — flush it so the client gets
|
||||
# the backlog instead of a silent drop. Safe no-op when the
|
||||
# in-loop flush ran (it clear()'d the buffer) and when the
|
||||
# callback never fired (replay_done stays False).
|
||||
if replay_done and replay_lines:
|
||||
for line in replay_lines:
|
||||
yield line
|
||||
replayed_count += 1
|
||||
replay_lines.clear()
|
||||
except GeneratorExit:
|
||||
return
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"SSE event-stream generator crashed for user=%s", user_id
|
||||
)
|
||||
finally:
|
||||
duration_s = time.monotonic() - connect_ts
|
||||
logger.info(
|
||||
"event.disconnect user=%s duration_s=%.1f replayed=%d",
|
||||
user_id,
|
||||
duration_s,
|
||||
replayed_count,
|
||||
)
|
||||
if counted and redis_client is not None:
|
||||
try:
|
||||
redis_client.decr(counter_key)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter DECR failed for user=%s on disconnect",
|
||||
user_id,
|
||||
)
|
||||
|
||||
response = Response(generate(), mimetype="text/event-stream")
|
||||
response.headers["Cache-Control"] = "no-store"
|
||||
response.headers["X-Accel-Buffering"] = "no"
|
||||
response.headers["Connection"] = "keep-alive"
|
||||
logger.info(
|
||||
"event.connect user=%s last_event_id=%s%s",
|
||||
user_id,
|
||||
last_event_id or "-",
|
||||
" (rejected_invalid)" if last_event_id_invalid else "",
|
||||
)
|
||||
return response
|
||||
@@ -46,7 +46,9 @@ AGENT_TYPE_SCHEMAS = {
|
||||
"prompt_id",
|
||||
],
|
||||
"required_draft": ["name"],
|
||||
"validate_published": ["name", "description", "prompt_id"],
|
||||
# ``prompt_id`` intentionally omitted — the "default" sentinel
|
||||
# is acceptable and maps to NULL downstream.
|
||||
"validate_published": ["name", "description"],
|
||||
"validate_draft": [],
|
||||
"require_source": True,
|
||||
"fields": [
|
||||
@@ -1009,12 +1011,16 @@ class UpdateAgent(Resource):
|
||||
400,
|
||||
)
|
||||
else:
|
||||
# ``prompt_id`` is intentionally omitted: the
|
||||
# frontend's "default" choice maps to NULL here
|
||||
# (see the prompt_id branch above), and NULL
|
||||
# means "use the built-in default prompt" which
|
||||
# is a valid published-agent state.
|
||||
missing_published_fields = []
|
||||
for req_field, field_label in (
|
||||
("name", "Agent name"),
|
||||
("description", "Agent description"),
|
||||
("chunks", "Chunks count"),
|
||||
("prompt_id", "Prompt"),
|
||||
("agent_type", "Agent type"),
|
||||
):
|
||||
final_value = update_fields.get(
|
||||
@@ -1028,8 +1034,23 @@ class UpdateAgent(Resource):
|
||||
extra_final = update_fields.get(
|
||||
"extra_source_ids", existing_agent.get("extra_source_ids") or [],
|
||||
)
|
||||
if not source_final and not extra_final:
|
||||
missing_published_fields.append("Source")
|
||||
# ``retriever`` carries the runtime identity for
|
||||
# agents that publish against the synthetic
|
||||
# "Default" source (frontend's auto-selected
|
||||
# ``{name: "Default", retriever: "classic"}``
|
||||
# entry has no ``id``, so ``source_id`` ends up
|
||||
# NULL even though the user picked something).
|
||||
# Without this fallback the most common new-agent
|
||||
# publish flow gets a 400.
|
||||
retriever_final = update_fields.get(
|
||||
"retriever", existing_agent.get("retriever"),
|
||||
)
|
||||
if (
|
||||
not source_final
|
||||
and not extra_final
|
||||
and not retriever_final
|
||||
):
|
||||
missing_published_fields.append("Source or retriever")
|
||||
if missing_published_fields:
|
||||
return make_response(
|
||||
jsonify(
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
"""Agent management webhook handlers."""
|
||||
|
||||
import secrets
|
||||
import uuid
|
||||
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import require_agent
|
||||
from application.api.user.tasks import process_agent_webhook
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.idempotency import IdempotencyRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
@@ -18,6 +22,37 @@ agents_webhooks_ns = Namespace(
|
||||
)
|
||||
|
||||
|
||||
_IDEMPOTENCY_KEY_MAX_LEN = 256
|
||||
|
||||
|
||||
def _read_idempotency_key():
    """Read the optional ``Idempotency-Key`` header of the current request.

    Returns a ``(key, error_response)`` pair:

    - header absent or empty -> ``(None, None)``
    - header longer than ``_IDEMPOTENCY_KEY_MAX_LEN`` -> ``(None, <400 response>)``
    - otherwise -> ``(header_value, None)``
    """
    header_value = request.headers.get("Idempotency-Key")
    if header_value and len(header_value) > _IDEMPOTENCY_KEY_MAX_LEN:
        too_long_message = (
            f"Idempotency-Key exceeds maximum length of "
            f"{_IDEMPOTENCY_KEY_MAX_LEN} characters"
        )
        error_body = jsonify({"success": False, "message": too_long_message})
        return None, make_response(error_body, 400)
    # A missing header and an empty one are both reported as "no key".
    return (header_value or None), None
|
||||
|
||||
|
||||
def _scoped_idempotency_key(idempotency_key, scope):
|
||||
"""``{scope}:{key}`` so different agents can't collide on the same key."""
|
||||
if not idempotency_key or not scope:
|
||||
return None
|
||||
return f"{scope}:{idempotency_key}"
|
||||
|
||||
|
||||
@agents_webhooks_ns.route("/agent_webhook")
|
||||
class AgentWebhook(Resource):
|
||||
@api.doc(
|
||||
@@ -68,7 +103,7 @@ class AgentWebhook(Resource):
|
||||
class AgentWebhookListener(Resource):
|
||||
method_decorators = [require_agent]
|
||||
|
||||
def _enqueue_webhook_task(self, agent_id_str, payload, source_method):
|
||||
def _enqueue_webhook_task(self, agent_id_str, payload, source_method, agent=None):
|
||||
if not payload:
|
||||
current_app.logger.warning(
|
||||
f"Webhook ({source_method}) received for agent {agent_id_str} with empty payload."
|
||||
@@ -77,26 +112,94 @@ class AgentWebhookListener(Resource):
|
||||
f"Incoming {source_method} webhook for agent {agent_id_str}. Enqueuing task with payload: {payload}"
|
||||
)
|
||||
|
||||
try:
|
||||
task = process_agent_webhook.delay(
|
||||
agent_id=agent_id_str,
|
||||
payload=payload,
|
||||
idempotency_key, key_error = _read_idempotency_key()
|
||||
if key_error is not None:
|
||||
return key_error
|
||||
# Resolve to PG UUID first so dedup writes don't crash on legacy ids.
|
||||
agent_uuid = None
|
||||
if agent is not None:
|
||||
candidate = str(agent.get("id") or "")
|
||||
if looks_like_uuid(candidate):
|
||||
agent_uuid = candidate
|
||||
if idempotency_key and agent_uuid is None:
|
||||
current_app.logger.warning(
|
||||
"Skipping webhook idempotency dedup: agent %s has non-UUID id",
|
||||
agent_id_str,
|
||||
)
|
||||
idempotency_key = None
|
||||
# Agent-scoped (webhooks have no user_id).
|
||||
scoped_key = _scoped_idempotency_key(idempotency_key, agent_uuid)
|
||||
# Claim before enqueue; the loser returns the winner's task_id.
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id = str(uuid.uuid4())
|
||||
with db_session() as conn:
|
||||
claimed = IdempotencyRepository(conn).record_webhook(
|
||||
key=scoped_key,
|
||||
agent_id=agent_uuid,
|
||||
task_id=predetermined_task_id,
|
||||
response_json={
|
||||
"success": True, "task_id": predetermined_task_id,
|
||||
},
|
||||
)
|
||||
if claimed is None:
|
||||
with db_readonly() as conn:
|
||||
cached = IdempotencyRepository(conn).get_webhook(scoped_key)
|
||||
if cached is not None:
|
||||
return make_response(jsonify(cached["response_json"]), 200)
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": "deduplicated"}), 200
|
||||
)
|
||||
|
||||
try:
|
||||
apply_kwargs = dict(
|
||||
kwargs={
|
||||
"agent_id": agent_id_str,
|
||||
"payload": payload,
|
||||
# Scoped so the worker dedup row matches the HTTP claim.
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
},
|
||||
)
|
||||
if predetermined_task_id is not None:
|
||||
apply_kwargs["task_id"] = predetermined_task_id
|
||||
task = process_agent_webhook.apply_async(**apply_kwargs)
|
||||
current_app.logger.info(
|
||||
f"Task {task.id} enqueued for agent {agent_id_str} ({source_method})."
|
||||
)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
response_payload = {"success": True, "task_id": task.id}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error enqueuing webhook task ({source_method}) for agent {agent_id_str}: {err}",
|
||||
exc_info=True,
|
||||
)
|
||||
if scoped_key:
|
||||
# Roll back the claim so a retry can succeed.
|
||||
try:
|
||||
with db_session() as conn:
|
||||
conn.execute(
|
||||
sql_text(
|
||||
"DELETE FROM webhook_dedup "
|
||||
"WHERE idempotency_key = :k"
|
||||
),
|
||||
{"k": scoped_key},
|
||||
)
|
||||
except Exception:
|
||||
current_app.logger.exception(
|
||||
"Failed to release webhook_dedup claim for key=%s",
|
||||
scoped_key,
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Error processing webhook"}), 500
|
||||
)
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (POST). Expects JSON payload, which is used to trigger processing.",
|
||||
description=(
|
||||
"Webhook listener for agent events (POST). Expects JSON payload, which "
|
||||
"is used to trigger processing. Honors an optional ``Idempotency-Key`` "
|
||||
"header: a repeat request with the same key within 24h returns the "
|
||||
"original cached response and does not re-enqueue the task."
|
||||
),
|
||||
)
|
||||
def post(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.get_json()
|
||||
@@ -110,11 +213,20 @@ class AgentWebhookListener(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="POST")
|
||||
return self._enqueue_webhook_task(
|
||||
agent_id_str, payload, source_method="POST", agent=agent,
|
||||
)
|
||||
|
||||
@api.doc(
|
||||
description="Webhook listener for agent events (GET). Uses URL query parameters as payload to trigger processing.",
|
||||
description=(
|
||||
"Webhook listener for agent events (GET). Uses URL query parameters as "
|
||||
"payload to trigger processing. Honors an optional ``Idempotency-Key`` "
|
||||
"header: a repeat request with the same key within 24h returns the "
|
||||
"original cached response and does not re-enqueue the task."
|
||||
),
|
||||
)
|
||||
def get(self, webhook_token, agent, agent_id_str):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
return self._enqueue_webhook_task(agent_id_str, payload, source_method="GET")
|
||||
return self._enqueue_webhook_task(
|
||||
agent_id_str, payload, source_method="GET", agent=agent,
|
||||
)
|
||||
|
||||
@@ -214,6 +214,10 @@ class StoreAttachment(Resource):
|
||||
{
|
||||
"success": True,
|
||||
"task_id": tasks[0]["task_id"],
|
||||
# Surface the attachment_id so the frontend
|
||||
# can correlate ``attachment.*`` SSE events
|
||||
# to this row and skip the polling fallback.
|
||||
"attachment_id": tasks[0]["attachment_id"],
|
||||
"message": "File uploaded successfully. Processing started.",
|
||||
}
|
||||
),
|
||||
|
||||
@@ -4,10 +4,16 @@ import datetime
|
||||
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.api.answer.services.conversation_service import (
|
||||
TERMINATED_RESPONSE_PLACEHOLDER,
|
||||
)
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.repositories.attachments import AttachmentsRepository
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.message_events import MessageEventsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import check_required_fields
|
||||
|
||||
@@ -133,6 +139,7 @@ class GetSingleConversation(Resource):
|
||||
attachments_repo = AttachmentsRepository(conn)
|
||||
queries = []
|
||||
for msg in messages:
|
||||
metadata = msg.get("metadata") or {}
|
||||
query = {
|
||||
"prompt": msg.get("prompt"),
|
||||
"response": msg.get("response"),
|
||||
@@ -141,9 +148,15 @@ class GetSingleConversation(Resource):
|
||||
"tool_calls": msg.get("tool_calls") or [],
|
||||
"timestamp": msg.get("timestamp"),
|
||||
"model_id": msg.get("model_id"),
|
||||
# Lets the client distinguish placeholder rows from
|
||||
# finalised answers and tail-poll in-flight ones.
|
||||
"message_id": str(msg["id"]) if msg.get("id") else None,
|
||||
"status": msg.get("status"),
|
||||
"request_id": msg.get("request_id"),
|
||||
"last_heartbeat_at": metadata.get("last_heartbeat_at"),
|
||||
}
|
||||
if msg.get("metadata"):
|
||||
query["metadata"] = msg["metadata"]
|
||||
if metadata:
|
||||
query["metadata"] = metadata
|
||||
# Feedback on conversation_messages is a JSONB blob with
|
||||
# shape {"text": <str>, "timestamp": <iso>}. The legacy
|
||||
# frontend consumed a flat scalar feedback string, so
|
||||
@@ -301,3 +314,80 @@ class SubmitFeedback(Resource):
|
||||
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/messages/<string:message_id>/tail")
class GetMessageTail(Resource):
    # Read-only endpoint returning the current state of one message row.
    # NOTE(review): documented with ``#`` comments rather than docstrings
    # because flask-restx may fold docstrings into the generated API docs —
    # confirm before converting these to docstrings.

    @api.doc(
        description=(
            "Current state of one conversation_messages row, scoped to the "
            "authenticated user. Used to reconnect to an in-flight stream "
            "after a refresh."
        ),
        params={"message_id": "Message UUID"},
    )
    def get(self, message_id):
        # Responses:
        #   401 - request carries no decoded token
        #   400 - ``message_id`` is not UUID-shaped, or the lookup raised
        #   404 - no row visible to this user
        #   200 - row state (partial output rebuilt while still in flight)
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)
        if not looks_like_uuid(message_id):
            invalid_id_body = {"success": False, "message": "Invalid message id"}
            return make_response(jsonify(invalid_id_body), 400)

        user_id = token.get("sub")
        try:
            # Owner-or-shared, matching ``ConversationsRepository.get``.
            lookup = sql_text(
                "SELECT m.* FROM conversation_messages m "
                "JOIN conversations c ON c.id = m.conversation_id "
                "WHERE m.id = CAST(:mid AS uuid) "
                "AND (c.user_id = :uid OR :uid = ANY(c.shared_with))"
            )
            with db_readonly() as conn:
                row = conn.execute(
                    lookup, {"mid": message_id, "uid": user_id}
                ).fetchone()
                if row is None:
                    return make_response(jsonify({"status": "not found"}), 404)

                msg = row_to_dict(row)
                status = msg.get("status")
                response = msg.get("response")
                thought = msg.get("thought")
                sources = msg.get("sources") or []
                tool_calls = msg.get("tool_calls") or []

                # Mid-stream the row still holds the placeholder response;
                # rebuild the live partial from the journal so /tail mirrors
                # what the SSE stream has delivered so far.
                still_streaming = status in ("pending", "streaming")
                if still_streaming and response == TERMINATED_RESPONSE_PLACEHOLDER:
                    partial = MessageEventsRepository(conn).reconstruct_partial(
                        message_id
                    )
                    response = partial["response"]
                    # Journal values win only when non-empty; otherwise keep
                    # whatever the row already had.
                    thought = partial["thought"] or thought
                    sources = partial["sources"] or sources
                    tool_calls = partial["tool_calls"] or tool_calls
        except Exception as err:
            current_app.logger.error(
                f"Error tailing message {message_id}: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)

        metadata = msg.get("message_metadata") or {}
        tail_body = {
            "message_id": str(msg["id"]),
            "status": status,
            "response": response,
            "thought": thought,
            "sources": sources,
            "tool_calls": tool_calls,
            "request_id": msg.get("request_id"),
            "last_heartbeat_at": metadata.get("last_heartbeat_at"),
            "error": metadata.get("error"),
        }
        return make_response(jsonify(tail_body), 200)
|
||||
|
||||
237
application/api/user/idempotency.py
Normal file
237
application/api/user/idempotency.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Per-Celery-task idempotency wrapper backed by ``task_dedup``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from application.storage.db.repositories.idempotency import IdempotencyRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Poison-loop cap; transient-failure headroom without infinite retry.
|
||||
MAX_TASK_ATTEMPTS = 5
|
||||
|
||||
# 30s heartbeat / 60s TTL → ~2 missed ticks of slack before reclaim.
|
||||
LEASE_TTL_SECONDS = 60
|
||||
LEASE_HEARTBEAT_INTERVAL = 30
|
||||
|
||||
# 10 × 60s ≈ 10 min of deferral before giving up on a held lease.
|
||||
LEASE_RETRY_MAX = 10
|
||||
|
||||
|
||||
def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that makes a bound task idempotent per key.

    The wrapped callable must accept ``self`` first and an
    ``idempotency_key`` keyword. When that keyword is not a non-empty
    string the task runs unguarded. Otherwise, in order:

    - a completed row for the key -> its cached result is returned;
    - the lease cannot be claimed -> ``self.retry`` is raised with
      ``countdown=LEASE_TTL_SECONDS`` and ``max_retries=LEASE_RETRY_MAX``;
    - the attempt count exceeds ``MAX_TASK_ATTEMPTS`` -> an alert is
      logged, the key is finalised as ``failed`` and the failure payload
      is returned;
    - otherwise the task body runs under a lease heartbeat. Success is
      finalised as ``completed``; an exception releases the lease and
      propagates so the caller's retry machinery can take over.
    """

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(fn)
        def wrapper(self, *args: Any, idempotency_key: Any = None, **kwargs: Any) -> Any:
            def run_task_body() -> Any:
                # Always forward the caller's original keyword untouched.
                return fn(self, *args, idempotency_key=idempotency_key, **kwargs)

            has_usable_key = isinstance(idempotency_key, str) and bool(idempotency_key)
            if not has_usable_key:
                return run_task_body()
            key = idempotency_key

            prior_result = _lookup_completed(key)
            if prior_result is not None:
                logger.info(
                    "idempotency hit for task=%s key=%s — returning cached result",
                    task_name, key,
                )
                return prior_result

            lease_owner = str(uuid.uuid4())
            attempt_no = _try_claim_lease(
                key, task_name, _safe_task_id(self), lease_owner,
            )

            if attempt_no is None:
                # Someone else holds a live lease. Defer for one lease TTL:
                # by then the holder has either finalised (the next run
                # returns the cached result) or its lease has lapsed and
                # can be claimed.
                logger.info(
                    "idempotency: live lease held; deferring task=%s key=%s",
                    task_name, key,
                )
                raise self.retry(
                    countdown=LEASE_TTL_SECONDS,
                    max_retries=LEASE_RETRY_MAX,
                )

            if attempt_no > MAX_TASK_ATTEMPTS:
                alert_fields = {
                    "alert": "idempotency_poison_loop",
                    "task_name": task_name,
                    "idempotency_key": key,
                    "attempts": attempt_no,
                }
                logger.error(
                    "idempotency poison-loop guard: task=%s key=%s attempts=%s",
                    task_name, key, attempt_no,
                    extra=alert_fields,
                )
                failure_payload = {
                    "success": False,
                    "error": "idempotency poison-loop guard tripped",
                    "attempts": attempt_no,
                }
                _finalize(key, failure_payload, status="failed")
                return failure_payload

            hb_thread, hb_stop = _start_lease_heartbeat(key, lease_owner)
            try:
                outcome = run_task_body()
                _finalize(key, outcome, status="completed")
                return outcome
            except Exception:
                # Drop the lease so the next retry doesn't wait LEASE_TTL.
                _release_lease(key, lease_owner)
                raise
            finally:
                _stop_lease_heartbeat(hb_thread, hb_stop)

        return wrapper

    return decorator
|
||||
|
||||
|
||||
def _lookup_completed(key: str) -> Any:
    """Fetch the stored ``result_json`` for ``key`` if its row is completed.

    Returns ``None`` when no row exists or the row's ``status`` is anything
    other than ``"completed"``.
    """
    with db_readonly() as conn:
        row = IdempotencyRepository(conn).get_task(key)
    if row is None or row.get("status") != "completed":
        return None
    return row.get("result_json")
|
||||
|
||||
|
||||
def _try_claim_lease(
    key: str, task_name: str, task_id: str, owner_id: str,
) -> Optional[int]:
    """Try to claim the lease for ``key`` on behalf of ``owner_id``.

    Delegates to ``IdempotencyRepository.try_claim_lease`` with a TTL of
    ``LEASE_TTL_SECONDS`` and returns whatever it returns; callers treat an
    integer as the attempt count and ``None`` as "lease held elsewhere".

    Any exception (e.g. a database outage) is logged and reported as
    attempt ``1`` so a transient failure does not block task execution.
    """
    try:
        with db_session() as conn:
            attempt_count = IdempotencyRepository(conn).try_claim_lease(
                key=key,
                task_name=task_name,
                task_id=task_id,
                owner_id=owner_id,
                ttl_seconds=LEASE_TTL_SECONDS,
            )
    except Exception:
        logger.exception(
            "idempotency lease-claim failed for key=%s task=%s", key, task_name,
        )
        return 1
    return attempt_count
|
||||
|
||||
|
||||
def _finalize(key: str, result_json: Any, *, status: str) -> None:
    """Record the terminal ``status`` and ``result_json`` for ``key``.

    Best effort: any exception is logged and swallowed so a failed write
    never fails the task itself.
    """
    try:
        with db_session() as conn:
            repo = IdempotencyRepository(conn)
            repo.finalize_task(key=key, result_json=result_json, status=status)
    except Exception:
        logger.exception(
            "idempotency finalize failed for key=%s status=%s", key, status,
        )
|
||||
|
||||
|
||||
def _release_lease(key: str, owner_id: str) -> None:
    """Release the lease on ``key`` held by ``owner_id``.

    Best effort: any exception is logged and swallowed, so the caller's
    own error handling is never masked by a failed release.
    """
    try:
        with db_session() as conn:
            repo = IdempotencyRepository(conn)
            repo.release_lease(key, owner_id)
    except Exception:
        logger.exception("idempotency release-lease failed for key=%s", key)
|
||||
|
||||
|
||||
def _start_lease_heartbeat(
    key: str, owner_id: str,
) -> tuple[threading.Thread, threading.Event]:
    """Start a daemon thread that keeps the lease on ``key`` alive.

    The thread runs ``_lease_heartbeat_loop`` with an interval of
    ``LEASE_HEARTBEAT_INTERVAL`` seconds. Returns ``(thread, stop_event)``;
    set the event (or call ``_stop_lease_heartbeat``) to end the loop.
    """
    halt_signal = threading.Event()
    heartbeat = threading.Thread(
        name=f"idempotency-lease-heartbeat:{key[:32]}",
        target=_lease_heartbeat_loop,
        args=(key, owner_id, halt_signal, LEASE_HEARTBEAT_INTERVAL),
        daemon=True,
    )
    heartbeat.start()
    return heartbeat, halt_signal
|
||||
|
||||
|
||||
def _stop_lease_heartbeat(
|
||||
thread: threading.Thread, stop_event: threading.Event,
|
||||
) -> None:
|
||||
"""Signal the heartbeat thread to exit and join with a short timeout."""
|
||||
stop_event.set()
|
||||
thread.join(timeout=10)
|
||||
|
||||
|
||||
def _lease_heartbeat_loop(
|
||||
key: str,
|
||||
owner_id: str,
|
||||
stop_event: threading.Event,
|
||||
interval: int,
|
||||
) -> None:
|
||||
"""Refresh the lease until ``stop_event`` is set or ownership is lost.
|
||||
|
||||
A failed refresh (rowcount 0) means another worker stole the lease
|
||||
after expiry — at that point the damage is already possible, so we
|
||||
log and keep ticking. Don't escalate to thread death; the main task
|
||||
body needs to keep running so its outcome is at least *recorded*.
|
||||
"""
|
||||
while not stop_event.wait(interval):
|
||||
try:
|
||||
with db_session() as conn:
|
||||
still_owned = IdempotencyRepository(conn).refresh_lease(
|
||||
key=key, owner_id=owner_id, ttl_seconds=LEASE_TTL_SECONDS,
|
||||
)
|
||||
if not still_owned:
|
||||
logger.warning(
|
||||
"idempotency lease lost mid-task for key=%s "
|
||||
"(another worker may have taken over)",
|
||||
key,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"idempotency lease-heartbeat tick failed for key=%s", key,
|
||||
)
|
||||
|
||||
|
||||
def _safe_task_id(task_self: Any) -> str:
|
||||
"""Best-effort extraction of ``self.request.id`` from a Celery task."""
|
||||
try:
|
||||
request = getattr(task_self, "request", None)
|
||||
task_id: Optional[str] = (
|
||||
getattr(request, "id", None) if request is not None else None
|
||||
)
|
||||
except Exception:
|
||||
task_id = None
|
||||
return task_id or "unknown"
|
||||
196
application/api/user/reconciliation.py
Normal file
196
application/api/user/reconciliation.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Reconciler tick: sweep stuck rows and escalate to terminal status + alert."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from sqlalchemy import Connection
|
||||
|
||||
from application.api.user.idempotency import MAX_TASK_ATTEMPTS
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.reconciliation import (
|
||||
ReconciliationRepository,
|
||||
)
|
||||
from application.storage.db.repositories.stack_logs import StackLogsRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MAX_MESSAGE_RECONCILE_ATTEMPTS = 3
|
||||
|
||||
|
||||
def run_reconciliation() -> Dict[str, Any]:
    """Single tick of the reconciler. Five sweeps, FOR UPDATE SKIP LOCKED.

    Each sweep runs in its own ``engine.begin()`` transaction, in this
    order: stuck messages, stuck ``proposed`` tool calls, stuck
    ``executed`` tool calls, stalled ingests, stuck idempotency rows.

    Stuck ``executed`` tool calls always flip to ``failed`` — operators
    handle cleanup manually via the structured alert. The side effect is
    assumed to have committed; no automated rollback is attempted.

    Stuck ``task_dedup`` rows (lease expired AND attempts >= max)
    promote to ``failed`` so a same-key retry can re-claim instead of
    sitting in ``pending`` until 24 h TTL.

    Returns:
        A summary dict with the counters ``messages_failed``,
        ``tool_calls_failed``, ``ingests_stalled`` and
        ``idempotency_pending_failed``. When ``settings.POSTGRES_URI`` is
        unset nothing is swept: every counter is ``0`` and an extra
        ``skipped`` key carries the reason.
    """
    # Built up front so the skipped path exposes the same counters as a
    # real tick; consumers never have to special-case missing keys.
    summary: Dict[str, Any] = {
        "messages_failed": 0,
        "tool_calls_failed": 0,
        "ingests_stalled": 0,
        "idempotency_pending_failed": 0,
    }
    if not settings.POSTGRES_URI:
        summary["skipped"] = "POSTGRES_URI not set"
        return summary

    engine = get_engine()
    summary["messages_failed"] += _sweep_stuck_messages(engine)
    summary["tool_calls_failed"] += _sweep_proposed_tool_calls(engine)
    summary["tool_calls_failed"] += _sweep_executed_tool_calls(engine)
    summary["ingests_stalled"] += _sweep_stalled_ingests(engine)
    summary["idempotency_pending_failed"] += _sweep_stuck_idempotency(engine)
    return summary


def _sweep_stuck_messages(engine) -> int:
    """Bump reconcile attempts on stuck messages; fail those at the cap.

    Returns the number of messages marked failed in this tick.
    """
    failed = 0
    with engine.begin() as conn:
        repo = ReconciliationRepository(conn)
        for msg in repo.find_and_lock_stuck_messages():
            new_count = repo.increment_message_reconcile_attempts(msg["id"])
            if new_count < MAX_MESSAGE_RECONCILE_ATTEMPTS:
                # Not exhausted yet; leave the row for a later tick.
                continue
            repo.mark_message_failed(
                msg["id"],
                error=(
                    "reconciler: stuck in pending/streaming for >5 min "
                    f"after {new_count} attempts"
                ),
            )
            failed += 1
            _emit_alert(
                conn,
                name="reconciler_message_failed",
                user_id=msg.get("user_id"),
                detail={
                    "message_id": str(msg["id"]),
                    "attempts": new_count,
                },
            )
    return failed


def _sweep_proposed_tool_calls(engine) -> int:
    """Fail tool calls stuck in ``proposed``; return how many were failed."""
    failed = 0
    with engine.begin() as conn:
        repo = ReconciliationRepository(conn)
        for row in repo.find_and_lock_proposed_tool_calls():
            repo.mark_tool_call_failed(
                row["call_id"],
                error=(
                    "reconciler: stuck in 'proposed' for >5 min; "
                    "side effect status unknown"
                ),
            )
            failed += 1
            _emit_alert(
                conn,
                name="reconciler_tool_call_failed_proposed",
                user_id=None,
                detail={
                    "call_id": row["call_id"],
                    "tool_name": row.get("tool_name"),
                },
            )
    return failed


def _sweep_executed_tool_calls(engine) -> int:
    """Fail tool calls stuck in ``executed``; return how many were failed."""
    failed = 0
    with engine.begin() as conn:
        repo = ReconciliationRepository(conn)
        for row in repo.find_and_lock_executed_tool_calls():
            repo.mark_tool_call_failed(
                row["call_id"],
                error=(
                    "reconciler: executed-not-confirmed; side effect "
                    "assumed committed, manual cleanup required"
                ),
            )
            failed += 1
            _emit_alert(
                conn,
                name="reconciler_tool_call_failed_executed",
                user_id=None,
                detail={
                    "call_id": row["call_id"],
                    "tool_name": row.get("tool_name"),
                    "action_name": row.get("action_name"),
                },
            )
    return failed


def _sweep_stalled_ingests(engine) -> int:
    """Alert on ingest checkpoints whose heartbeat has gone silent.

    Q4: the reconciler only escalates (alerts) — it doesn't kill the
    worker or roll back the partial embed. The next dispatch resumes from
    ``last_index`` thanks to the per-chunk checkpoint, so this is an
    observability sweep, not a recovery action.

    Returns the number of stalled ingests alerted on.
    """
    stalled = 0
    with engine.begin() as conn:
        repo = ReconciliationRepository(conn)
        for row in repo.find_and_lock_stalled_ingests():
            stalled += 1
            _emit_alert(
                conn,
                name="reconciler_ingest_stalled",
                user_id=None,
                detail={
                    "source_id": str(row.get("source_id")),
                    "embedded_chunks": row.get("embedded_chunks"),
                    "total_chunks": row.get("total_chunks"),
                    "last_updated": str(row.get("last_updated")),
                },
            )
            # Bump the heartbeat so we don't re-alert every tick.
            repo.touch_ingest_progress(str(row["source_id"]))
    return stalled


def _sweep_stuck_idempotency(engine) -> int:
    """Promote exhausted, lease-expired ``pending`` idempotency rows to failed.

    Q5: the wrapper's poison-loop guard normally finalises these, but if
    the wrapper itself died mid-task (worker SIGKILL, OOM during
    heartbeat) the row sits in ``pending`` blocking same-key retries via
    ``_lookup_completed`` returning None for the whole 24 h TTL. Promote
    to ``failed`` so a retry can re-claim and either resume or fail
    loudly.

    Returns the number of rows promoted.
    """
    promoted = 0
    with engine.begin() as conn:
        repo = ReconciliationRepository(conn)
        for row in repo.find_stuck_idempotency_pending(
            max_attempts=MAX_TASK_ATTEMPTS,
        ):
            error_msg = (
                "reconciler: idempotency lease expired with attempts "
                f"({row['attempt_count']}) >= {MAX_TASK_ATTEMPTS}; "
                "task abandoned"
            )
            repo.mark_idempotency_pending_failed(
                row["idempotency_key"], error=error_msg,
            )
            promoted += 1
            _emit_alert(
                conn,
                name="reconciler_idempotency_pending_failed",
                user_id=None,
                detail={
                    "idempotency_key": row["idempotency_key"],
                    "task_name": row.get("task_name"),
                    "task_id": row.get("task_id"),
                    "attempts": row.get("attempt_count"),
                },
            )
    return promoted
|
||||
|
||||
|
||||
def _emit_alert(
    conn: Connection,
    *,
    name: str,
    user_id: Optional[str],
    detail: Dict[str, Any],
) -> None:
    """Structured ``logger.error`` plus a ``stack_logs`` row for operators.

    Args:
        conn: Connection the caller's sweep transaction is running on; the
            ``stack_logs`` row is written through it.
        name: Alert name. Used as the log message argument, the ``alert``
            key of the structured payload and the ``query`` column.
        user_id: Owner of the affected row, or ``None`` when unknown.
        detail: Extra key/value pairs merged into the structured payload.

    The ``stack_logs`` write is best-effort: any exception is logged and
    swallowed so alerting can never fail the sweep that triggered it.
    """
    extra = {"alert": name, **detail}
    logger.error("reconciler alert: %s", name, extra=extra)
    try:
        # Run the insert inside a SAVEPOINT. On PostgreSQL a failed
        # statement aborts the enclosing transaction, so merely catching
        # the exception would still doom every later statement of the
        # caller's sweep; rolling back to the savepoint keeps the outer
        # transaction usable.
        with conn.begin_nested():
            StackLogsRepository(conn).insert(
                activity_id=str(uuid.uuid4()),
                endpoint="reconciliation_worker",
                level="ERROR",
                user_id=user_id,
                query=name,
                stacks=[extra],
            )
    except Exception:
        logger.exception("reconciler: failed to write stack_logs row for %s", name)
|
||||
@@ -3,16 +3,20 @@
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
import zipfile
|
||||
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.source_ids import derive_source_id as _derive_source_id
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
|
||||
from application.storage.db.repositories.idempotency import IdempotencyRepository
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
@@ -30,6 +34,91 @@ sources_upload_ns = Namespace(
|
||||
)
|
||||
|
||||
|
||||
_IDEMPOTENCY_KEY_MAX_LEN = 256
|
||||
|
||||
|
||||
def _read_idempotency_key():
    """Read the ``Idempotency-Key`` request header.

    Returns a ``(key, error_response)`` pair:

    * header absent or empty → ``(None, None)``
    * header longer than ``_IDEMPOTENCY_KEY_MAX_LEN`` → ``(None, <400 response>)``
    * otherwise → ``(key, None)``
    """
    header_value = request.headers.get("Idempotency-Key")
    if not header_value:
        return None, None
    if len(header_value) <= _IDEMPOTENCY_KEY_MAX_LEN:
        return header_value, None

    # Oversized key: reject up front instead of passing it downstream.
    message = (
        f"Idempotency-Key exceeds maximum length of "
        f"{_IDEMPOTENCY_KEY_MAX_LEN} characters"
    )
    error_response = make_response(
        jsonify({"success": False, "message": message}),
        400,
    )
    return None, error_response
|
||||
|
||||
|
||||
def _scoped_idempotency_key(idempotency_key, scope):
    """Prefix the key with its scope as ``{scope}:{key}``.

    Scoping keeps different users from colliding on the same raw key.
    Returns ``None`` when either the key or the scope is missing/empty.
    """
    if idempotency_key and scope:
        return f"{scope}:{idempotency_key}"
    return None
|
||||
|
||||
|
||||
def _claim_task_or_get_cached(key, task_name):
    """Claim ``key`` for this request, or hand back the winner's payload.

    A celery task id is generated before the claim and stored with it, so
    a request that loses the claim can read the winner's id right away.

    Returns ``(task_id, cached_response)``:

    * claim succeeded → ``(task_id, None)``; the caller should enqueue
      using that id.
    * claim lost → ``(None, payload)``; the caller should return
      ``payload`` without enqueuing. The payload mirrors the
      fresh-request response shape. ``source_id`` is included only when
      the existing row carries a task id; otherwise ``task_id`` is the
      ``"deduplicated"`` sentinel and ``source_id`` is deliberately left
      out so the frontend doesn't bind to a phantom source.
    """
    fresh_task_id = str(uuid.uuid4())
    with db_session() as conn:
        claim_row = IdempotencyRepository(conn).claim_task(
            key=key, task_name=task_name, task_id=fresh_task_id,
        )
    if claim_row is not None:
        return claim_row["task_id"], None

    # Lost the claim: look up whoever holds it.
    with db_readonly() as conn:
        winner_row = IdempotencyRepository(conn).get_task(key)
    winner_task_id = winner_row.get("task_id") if winner_row else None

    if winner_task_id is None:
        # No row (or no task id on it) to correlate with.
        return None, {"success": True, "task_id": "deduplicated"}

    response: dict = {
        "success": True,
        "task_id": winner_task_id or "deduplicated",
        "source_id": str(_derive_source_id(key)),
    }
    return None, response
|
||||
|
||||
|
||||
def _release_claim(key):
    """Delete the ``pending`` ``task_dedup`` row for ``key``.

    Lets a client retry re-claim the key. Best-effort: any failure is
    logged through the Flask app logger and not re-raised.
    """
    try:
        delete_pending = sql_text(
            "DELETE FROM task_dedup WHERE idempotency_key = :k "
            "AND status = 'pending'"
        )
        with db_session() as conn:
            conn.execute(delete_pending, {"k": key})
    except Exception:
        current_app.logger.exception(
            "Failed to release task_dedup claim for key=%s", key,
        )
|
||||
|
||||
|
||||
|
||||
|
||||
def _enforce_audio_path_size_limit(file_path: str, filename: str) -> None:
|
||||
if not is_audio_filename(filename):
|
||||
return
|
||||
@@ -49,17 +138,38 @@ class UploadFile(Resource):
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads a file to be vectorized and indexed",
|
||||
description=(
|
||||
"Uploads a file to be vectorized and indexed. Honors an optional "
|
||||
"``Idempotency-Key`` header: a repeat request with the same key "
|
||||
"within 24h returns the original cached response without re-enqueuing."
|
||||
),
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
idempotency_key, key_error = _read_idempotency_key()
|
||||
if key_error is not None:
|
||||
return key_error
|
||||
# User-scoped to avoid cross-user collisions; also feeds
|
||||
# ``_derive_source_id`` so uuid5 stays user-disjoint.
|
||||
scoped_key = _scoped_idempotency_key(idempotency_key, user)
|
||||
# Claim before enqueue; the loser returns the winner's task_id.
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id, cached = _claim_task_or_get_cached(
|
||||
scoped_key, "ingest",
|
||||
)
|
||||
if cached is not None:
|
||||
return make_response(jsonify(cached), 200)
|
||||
data = request.form
|
||||
files = request.files.getlist("file")
|
||||
required_fields = ["user", "name"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields or not files or all(file.filename == "" for file in files):
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -69,7 +179,6 @@ class UploadFile(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
user = decoded_token.get("sub")
|
||||
job_name = request.form["name"]
|
||||
|
||||
# Create safe versions for filesystem operations
|
||||
@@ -140,16 +249,37 @@ class UploadFile(Resource):
|
||||
file_path = f"{base_path}/{safe_file}"
|
||||
with open(temp_file_path, "rb") as f:
|
||||
storage.save_file(f, file_path)
|
||||
task = ingest.delay(
|
||||
settings.UPLOAD_FOLDER,
|
||||
list(SUPPORTED_SOURCE_EXTENSIONS),
|
||||
job_name,
|
||||
user,
|
||||
file_path=base_path,
|
||||
filename=dir_name,
|
||||
file_name_map=file_name_map,
|
||||
# Mint the source UUID up here so the HTTP response and the
|
||||
# worker's SSE envelopes share one id. With an idempotency
|
||||
# key we reuse the deterministic uuid5 (retried task lands on
|
||||
# the same source row); without a key we fall back to uuid4.
|
||||
# The worker is told to use this id verbatim — see
|
||||
# ``ingest_worker(source_id=...)``.
|
||||
source_uuid = (
|
||||
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
|
||||
)
|
||||
ingest_kwargs = dict(
|
||||
args=(
|
||||
settings.UPLOAD_FOLDER,
|
||||
list(SUPPORTED_SOURCE_EXTENSIONS),
|
||||
job_name,
|
||||
user,
|
||||
),
|
||||
kwargs={
|
||||
"file_path": base_path,
|
||||
"filename": dir_name,
|
||||
"file_name_map": file_name_map,
|
||||
# Scoped so the worker dedup row matches the HTTP claim.
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
"source_id": str(source_uuid),
|
||||
},
|
||||
)
|
||||
if predetermined_task_id is not None:
|
||||
ingest_kwargs["task_id"] = predetermined_task_id
|
||||
task = ingest.apply_async(**ingest_kwargs)
|
||||
except AudioFileTooLargeError:
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -161,8 +291,21 @@ class UploadFile(Resource):
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
# Predetermined id matches the dedup-claim row; loser GET sees same.
|
||||
response_task_id = predetermined_task_id or task.id
|
||||
# ``source_uuid`` was minted above and passed to the worker as
|
||||
# ``source_id``; the worker uses it verbatim for every SSE event,
|
||||
# so the frontend can correlate inbound ``source.ingest.*`` to
|
||||
# this upload regardless of whether an idempotency key was set.
|
||||
response_payload: dict = {
|
||||
"success": True,
|
||||
"task_id": response_task_id,
|
||||
"source_id": str(source_uuid),
|
||||
}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/remote")
|
||||
@@ -182,17 +325,50 @@ class UploadRemote(Resource):
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads remote source for vectorization",
|
||||
description=(
|
||||
"Uploads remote source for vectorization. Honors an optional "
|
||||
"``Idempotency-Key`` header: a repeat request with the same key "
|
||||
"within 24h returns the original cached response without re-enqueuing."
|
||||
),
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
idempotency_key, key_error = _read_idempotency_key()
|
||||
if key_error is not None:
|
||||
return key_error
|
||||
scoped_key = _scoped_idempotency_key(idempotency_key, user)
|
||||
data = request.form
|
||||
required_fields = ["user", "source", "name", "data"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
task_name_for_dedup = (
|
||||
"ingest_connector_task"
|
||||
if data.get("source") in ConnectorCreator.get_supported_connectors()
|
||||
else "ingest_remote"
|
||||
)
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id, cached = _claim_task_or_get_cached(
|
||||
scoped_key, task_name_for_dedup,
|
||||
)
|
||||
if cached is not None:
|
||||
return make_response(jsonify(cached), 200)
|
||||
# Mint the source UUID up here so the HTTP response and the
|
||||
# worker's SSE envelopes share one id. Same pattern as
|
||||
# ``UploadFile.post``: with an idempotency key we reuse the
|
||||
# deterministic uuid5 (retried task lands on the same source
|
||||
# row); without a key we fall back to uuid4. The worker is told
|
||||
# to use this id verbatim — see ``remote_worker`` and
|
||||
# ``ingest_connector``. Without this the no-key path would mint
|
||||
# a random uuid4 inside the worker that the frontend has no way
|
||||
# to correlate SSE events to.
|
||||
source_uuid = (
|
||||
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
|
||||
)
|
||||
try:
|
||||
config = json.loads(data["data"])
|
||||
source_data = None
|
||||
@@ -208,6 +384,8 @@ class UploadRemote(Resource):
|
||||
elif data["source"] in ConnectorCreator.get_supported_connectors():
|
||||
session_token = config.get("session_token")
|
||||
if not session_token:
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -236,31 +414,62 @@ class UploadRemote(Resource):
|
||||
config["file_ids"] = file_ids
|
||||
config["folder_ids"] = folder_ids
|
||||
|
||||
task = ingest_connector_task.delay(
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
source_type=data["source"],
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=config.get("recursive", False),
|
||||
retriever=config.get("retriever", "classic"),
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": task.id}), 200
|
||||
)
|
||||
task = ingest_remote.delay(
|
||||
source_data=source_data,
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
loader=data["source"],
|
||||
)
|
||||
connector_kwargs = {
|
||||
"kwargs": {
|
||||
"job_name": data["name"],
|
||||
"user": user,
|
||||
"source_type": data["source"],
|
||||
"session_token": session_token,
|
||||
"file_ids": file_ids,
|
||||
"folder_ids": folder_ids,
|
||||
"recursive": config.get("recursive", False),
|
||||
"retriever": config.get("retriever", "classic"),
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
"source_id": str(source_uuid),
|
||||
},
|
||||
}
|
||||
if predetermined_task_id is not None:
|
||||
connector_kwargs["task_id"] = predetermined_task_id
|
||||
task = ingest_connector_task.apply_async(**connector_kwargs)
|
||||
response_task_id = predetermined_task_id or task.id
|
||||
# ``source_uuid`` was minted above and passed to the
|
||||
# worker as ``source_id``; the worker uses it verbatim
|
||||
# for every SSE event, so the frontend can correlate
|
||||
# inbound ``source.ingest.*`` regardless of whether an
|
||||
# idempotency key was set.
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"task_id": response_task_id,
|
||||
"source_id": str(source_uuid),
|
||||
}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
remote_kwargs = {
|
||||
"kwargs": {
|
||||
"source_data": source_data,
|
||||
"job_name": data["name"],
|
||||
"user": user,
|
||||
"loader": data["source"],
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
"source_id": str(source_uuid),
|
||||
},
|
||||
}
|
||||
if predetermined_task_id is not None:
|
||||
remote_kwargs["task_id"] = predetermined_task_id
|
||||
task = ingest_remote.apply_async(**remote_kwargs)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error uploading remote source: {err}", exc_info=True
|
||||
)
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
response_task_id = predetermined_task_id or task.id
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"task_id": response_task_id,
|
||||
"source_id": str(source_uuid),
|
||||
}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/manage_source_files")
|
||||
@@ -305,6 +514,10 @@ class ManageSourceFiles(Resource):
|
||||
jsonify({"success": False, "message": "Unauthorized"}), 401
|
||||
)
|
||||
user = decoded_token.get("sub")
|
||||
idempotency_key, key_error = _read_idempotency_key()
|
||||
if key_error is not None:
|
||||
return key_error
|
||||
scoped_key = _scoped_idempotency_key(idempotency_key, user)
|
||||
source_id = request.form.get("source_id")
|
||||
operation = request.form.get("operation")
|
||||
|
||||
@@ -347,6 +560,12 @@ class ManageSourceFiles(Resource):
|
||||
jsonify({"success": False, "message": "Database error"}), 500
|
||||
)
|
||||
resolved_source_id = str(source["id"])
|
||||
# Flips to True after each branch's ``apply_async`` returns
|
||||
# successfully — at that point the worker owns the predetermined
|
||||
# task_id. The outer ``except`` only releases the claim while
|
||||
# this is False, so a post-``apply_async`` failure (jsonify,
|
||||
# make_response, etc.) doesn't double-enqueue on the next retry.
|
||||
claim_transferred = False
|
||||
try:
|
||||
storage = StorageCreator.get_storage()
|
||||
source_file_path = source.get("file_path", "")
|
||||
@@ -379,6 +598,34 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
# Claim before any storage mutation so a duplicate request
|
||||
# short-circuits without touching the filesystem. Mirrors
|
||||
# the pattern in ``UploadFile.post`` / ``UploadRemote.post``
|
||||
# — without it ``.delay()`` would enqueue twice for two
|
||||
# racing same-key POSTs (the worker decorator only
|
||||
# deduplicates *after* completion).
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id, cached = _claim_task_or_get_cached(
|
||||
scoped_key, "reingest_source_task",
|
||||
)
|
||||
if cached is not None:
|
||||
# Frontend keys reingest polling on
|
||||
# ``reingest_task_id``; the shared cache helper
|
||||
# writes ``task_id``. Alias here so a dedup
|
||||
# response doesn't silently break FileTree's
|
||||
# poller. Override ``source_id`` too — the
|
||||
# helper derives it from the scoped key, which
|
||||
# is correct for upload but wrong for reingest
|
||||
# (the worker publishes events scoped to the
|
||||
# actual source row id).
|
||||
cached_task_id = cached.pop("task_id", None)
|
||||
if cached_task_id is not None:
|
||||
cached["reingest_task_id"] = cached_task_id
|
||||
cached["source_id"] = resolved_source_id
|
||||
return make_response(jsonify(cached), 200)
|
||||
|
||||
added_files = []
|
||||
map_updated = False
|
||||
|
||||
@@ -414,9 +661,15 @@ class ManageSourceFiles(Resource):
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id, user=user
|
||||
task = reingest_source_task.apply_async(
|
||||
kwargs={
|
||||
"source_id": resolved_source_id,
|
||||
"user": user,
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
},
|
||||
task_id=predetermined_task_id,
|
||||
)
|
||||
claim_transferred = True
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -426,6 +679,12 @@ class ManageSourceFiles(Resource):
|
||||
"added_files": added_files,
|
||||
"parent_dir": parent_dir,
|
||||
"reingest_task_id": task.id,
|
||||
# ``source_id`` lets the frontend correlate
|
||||
# inbound ``source.ingest.*`` SSE events
|
||||
# (emitted by ``reingest_source_worker``)
|
||||
# back to the reingest task — matches the
|
||||
# upload route's source-id contract.
|
||||
"source_id": resolved_source_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
@@ -455,10 +714,8 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Remove files from storage and directory structure
|
||||
|
||||
removed_files = []
|
||||
map_updated = False
|
||||
# Path-traversal guard runs *before* the claim so a 400
|
||||
# for an invalid path doesn't leave a pending dedup row.
|
||||
for file_path in file_paths:
|
||||
if ".." in str(file_path) or str(file_path).startswith("/"):
|
||||
return make_response(
|
||||
@@ -470,6 +727,31 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
|
||||
# Claim before any storage mutation. See ``add`` branch
|
||||
# comment for rationale.
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id, cached = _claim_task_or_get_cached(
|
||||
scoped_key, "reingest_source_task",
|
||||
)
|
||||
if cached is not None:
|
||||
cached_task_id = cached.pop("task_id", None)
|
||||
if cached_task_id is not None:
|
||||
cached["reingest_task_id"] = cached_task_id
|
||||
# Override the helper's synthetic source_id (uuid5
|
||||
# of the scoped key) with the real source row id
|
||||
# — the reingest worker publishes SSE events
|
||||
# scoped to ``resolved_source_id`` and FileTree
|
||||
# correlates on it.
|
||||
cached["source_id"] = resolved_source_id
|
||||
return make_response(jsonify(cached), 200)
|
||||
|
||||
# Remove files from storage and directory structure
|
||||
|
||||
removed_files = []
|
||||
map_updated = False
|
||||
for file_path in file_paths:
|
||||
full_path = f"{source_file_path}/{file_path}"
|
||||
|
||||
# Remove from storage
|
||||
@@ -491,9 +773,15 @@ class ManageSourceFiles(Resource):
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id, user=user
|
||||
task = reingest_source_task.apply_async(
|
||||
kwargs={
|
||||
"source_id": resolved_source_id,
|
||||
"user": user,
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
},
|
||||
task_id=predetermined_task_id,
|
||||
)
|
||||
claim_transferred = True
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -502,6 +790,7 @@ class ManageSourceFiles(Resource):
|
||||
"message": f"Removed {len(removed_files)} files",
|
||||
"removed_files": removed_files,
|
||||
"reingest_task_id": task.id,
|
||||
"source_id": resolved_source_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
@@ -552,6 +841,24 @@ class ManageSourceFiles(Resource):
|
||||
),
|
||||
404,
|
||||
)
|
||||
|
||||
# Claim before mutation. See ``add`` branch for rationale.
|
||||
predetermined_task_id = None
|
||||
if scoped_key:
|
||||
predetermined_task_id, cached = _claim_task_or_get_cached(
|
||||
scoped_key, "reingest_source_task",
|
||||
)
|
||||
if cached is not None:
|
||||
cached_task_id = cached.pop("task_id", None)
|
||||
if cached_task_id is not None:
|
||||
cached["reingest_task_id"] = cached_task_id
|
||||
# Same source_id override as the ``remove`` /
|
||||
# ``add`` cached branches — the helper's synthetic
|
||||
# id doesn't match what reingest_source_worker
|
||||
# tags its SSE events with.
|
||||
cached["source_id"] = resolved_source_id
|
||||
return make_response(jsonify(cached), 200)
|
||||
|
||||
success = storage.remove_directory(full_directory_path)
|
||||
|
||||
if not success:
|
||||
@@ -560,6 +867,11 @@ class ManageSourceFiles(Resource):
|
||||
f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
|
||||
f"Full path: {full_directory_path}"
|
||||
)
|
||||
# Release so a client retry can reclaim — otherwise
|
||||
# the next request would silently 200-cache to the
|
||||
# task_id that never enqueued.
|
||||
if scoped_key:
|
||||
_release_claim(scoped_key)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Failed to remove directory"}
|
||||
@@ -591,9 +903,15 @@ class ManageSourceFiles(Resource):
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id, user=user
|
||||
task = reingest_source_task.apply_async(
|
||||
kwargs={
|
||||
"source_id": resolved_source_id,
|
||||
"user": user,
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
},
|
||||
task_id=predetermined_task_id,
|
||||
)
|
||||
claim_transferred = True
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -602,11 +920,20 @@ class ManageSourceFiles(Resource):
|
||||
"message": f"Successfully removed directory: {directory_path}",
|
||||
"removed_directory": directory_path,
|
||||
"reingest_task_id": task.id,
|
||||
"source_id": resolved_source_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
# Release the dedup claim only if it wasn't transferred to
|
||||
# a worker. Without this, a same-key retry within the 24h
|
||||
# TTL would 200-cache to a predetermined task_id whose
|
||||
# ``apply_async`` never ran (or ran but the response builder
|
||||
# blew up afterward — only the first case matters in
|
||||
# practice; the flag protects both).
|
||||
if scoped_key and not claim_transferred:
|
||||
_release_claim(scoped_key)
|
||||
error_context = f"operation={operation}, user={user}, source_id={source_id}"
|
||||
if operation == "remove_directory":
|
||||
directory_path = request.form.get("directory_path", "")
|
||||
|
||||
@@ -1,21 +1,45 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from application.api.user.idempotency import with_idempotency
|
||||
from application.celery_init import celery
|
||||
from application.worker import (
|
||||
agent_webhook_worker,
|
||||
attachment_worker,
|
||||
ingest_worker,
|
||||
mcp_oauth,
|
||||
mcp_oauth_status,
|
||||
remote_worker,
|
||||
sync,
|
||||
sync_worker,
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
# Shared ``@celery.task`` options for long-running, side-effecting tasks.
# ``acks_late`` is also the celeryconfig default but is spelled out here so
# each task's durability story is grep-able next to its body. It is paired
# with ``autoretry_for=(Exception,)`` and a bounded ``max_retries`` so a
# poison message can't loop forever.
# NOTE(review): with ``retry_backoff`` enabled Celery's autoretry appears to
# recompute ``countdown`` on every retry, so the fixed 60 below is presumably
# overridden by the exponential backoff — confirm which delay is intended.
DURABLE_TASK = {
    "bind": True,
    "acks_late": True,
    "autoretry_for": (Exception,),
    "retry_kwargs": {"max_retries": 3, "countdown": 60},
    "retry_backoff": True,
}
|
||||
|
||||
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest")
|
||||
def ingest(
|
||||
self, directory, formats, job_name, user, file_path, filename, file_name_map=None
|
||||
self,
|
||||
directory,
|
||||
formats,
|
||||
job_name,
|
||||
user,
|
||||
file_path,
|
||||
filename,
|
||||
file_name_map=None,
|
||||
idempotency_key=None,
|
||||
source_id=None,
|
||||
):
|
||||
resp = ingest_worker(
|
||||
self,
|
||||
@@ -26,25 +50,40 @@ def ingest(
|
||||
filename,
|
||||
user,
|
||||
file_name_map=file_name_map,
|
||||
idempotency_key=idempotency_key,
|
||||
source_id=source_id,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def ingest_remote(self, source_data, job_name, user, loader):
|
||||
resp = remote_worker(self, source_data, job_name, user, loader)
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest_remote")
|
||||
def ingest_remote(
|
||||
self, source_data, job_name, user, loader,
|
||||
idempotency_key=None, source_id=None,
|
||||
):
|
||||
resp = remote_worker(
|
||||
self, source_data, job_name, user, loader,
|
||||
idempotency_key=idempotency_key,
|
||||
source_id=source_id,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def reingest_source_task(self, source_id, user):
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="reingest_source_task")
|
||||
def reingest_source_task(self, source_id, user, idempotency_key=None):
|
||||
from application.worker import reingest_source_worker
|
||||
|
||||
resp = reingest_source_worker(self, source_id, user)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
# Beat-driven dispatch tasks default to ``acks_late=False``: a SIGKILL
|
||||
# of a beat tick is harmless to redeliver only if the dispatch itself is
|
||||
# idempotent. We keep these early-ACK so the broker doesn't replay a
|
||||
# dispatch that already enqueued downstream work.
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def schedule_syncs(self, frequency):
|
||||
resp = sync_worker(self, frequency)
|
||||
return resp
|
||||
@@ -74,19 +113,22 @@ def sync_source(
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def store_attachment(self, file_info, user):
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="store_attachment")
|
||||
def store_attachment(self, file_info, user, idempotency_key=None):
|
||||
resp = attachment_worker(self, file_info, user)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def process_agent_webhook(self, agent_id, payload):
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="process_agent_webhook")
|
||||
def process_agent_webhook(self, agent_id, payload, idempotency_key=None):
|
||||
resp = agent_webhook_worker(self, agent_id, payload)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest_connector_task")
|
||||
def ingest_connector_task(
|
||||
self,
|
||||
job_name,
|
||||
@@ -100,6 +142,8 @@ def ingest_connector_task(
|
||||
operation_mode="upload",
|
||||
doc_id=None,
|
||||
sync_frequency="never",
|
||||
idempotency_key=None,
|
||||
source_id=None,
|
||||
):
|
||||
from application.worker import ingest_connector
|
||||
|
||||
@@ -116,6 +160,8 @@ def ingest_connector_task(
|
||||
operation_mode=operation_mode,
|
||||
doc_id=doc_id,
|
||||
sync_frequency=sync_frequency,
|
||||
idempotency_key=idempotency_key,
|
||||
source_id=source_id,
|
||||
)
|
||||
return resp
|
||||
|
||||
@@ -140,11 +186,33 @@ def setup_periodic_tasks(sender, **kwargs):
|
||||
cleanup_pending_tool_state.s(),
|
||||
name="cleanup-pending-tool-state",
|
||||
)
|
||||
# Pure housekeeping for ``task_dedup`` / ``webhook_dedup`` — the
|
||||
# upsert paths already handle stale rows, so cadence only bounds
|
||||
# table size. Hourly is plenty for typical traffic.
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=1),
|
||||
cleanup_idempotency_dedup.s(),
|
||||
name="cleanup-idempotency-dedup",
|
||||
)
|
||||
sender.add_periodic_task(
|
||||
timedelta(seconds=30),
|
||||
reconciliation_task.s(),
|
||||
name="reconciliation",
|
||||
)
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=7),
|
||||
version_check_task.s(),
|
||||
name="version-check",
|
||||
)
|
||||
# Bound ``message_events`` growth — every streamed SSE chunk writes
|
||||
# one row, so retained chats accumulate hundreds of rows per
|
||||
# message. Reconnect-replay is only meaningful for streams the user
|
||||
# could plausibly still be waiting on, so 14 days is generous.
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=24),
|
||||
cleanup_message_events.s(),
|
||||
name="cleanup-message-events",
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@@ -153,24 +221,12 @@ def mcp_oauth_task(self, config, user):
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def mcp_oauth_status_task(self, task_id):
|
||||
resp = mcp_oauth_status(self, task_id)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def cleanup_pending_tool_state(self):
|
||||
"""Delete pending_tool_state rows past their TTL.
|
||||
|
||||
Replaces Mongo's ``expireAfterSeconds=0`` TTL index — Postgres has
|
||||
no native TTL, so this task runs every 60 seconds to keep
|
||||
``pending_tool_state`` bounded. No-ops if ``POSTGRES_URI`` isn't
|
||||
configured (keeps the task runnable in Mongo-only environments).
|
||||
"""
|
||||
"""Revert stale ``resuming`` rows, then delete TTL-expired rows."""
|
||||
from application.core.settings import settings
|
||||
if not settings.POSTGRES_URI:
|
||||
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
|
||||
return {"deleted": 0, "reverted": 0, "skipped": "POSTGRES_URI not set"}
|
||||
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.pending_tool_state import (
|
||||
@@ -179,11 +235,73 @@ def cleanup_pending_tool_state(self):
|
||||
|
||||
engine = get_engine()
|
||||
with engine.begin() as conn:
|
||||
deleted = PendingToolStateRepository(conn).cleanup_expired()
|
||||
return {"deleted": deleted}
|
||||
repo = PendingToolStateRepository(conn)
|
||||
reverted = repo.revert_stale_resuming(grace_seconds=600)
|
||||
deleted = repo.cleanup_expired()
|
||||
return {"deleted": deleted, "reverted": reverted}
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@celery.task(bind=True, acks_late=False)
def cleanup_idempotency_dedup(self):
    """Delete TTL-expired rows from ``task_dedup`` and ``webhook_dedup``.

    Pure housekeeping — the upsert paths already ignore stale rows
    (TTL-aware ``ON CONFLICT DO UPDATE``), so this only bounds table
    growth and keeps SELECT planning tight on large deployments.

    Returns:
        Whatever ``IdempotencyRepository.cleanup_expired()`` returns, or a
        zero-count dict with a ``skipped`` reason when ``POSTGRES_URI`` is
        not configured.
    """
    # Imports are deferred to call time, as in the other tasks in this module.
    from application.core.settings import settings

    if not settings.POSTGRES_URI:
        return {
            "task_dedup_deleted": 0,
            "webhook_dedup_deleted": 0,
            "skipped": "POSTGRES_URI not set",
        }

    from application.storage.db.engine import get_engine
    from application.storage.db.repositories.idempotency import (
        IdempotencyRepository,
    )

    engine = get_engine()
    # ``engine.begin()`` scopes the cleanup to a single transaction.
    with engine.begin() as conn:
        return IdempotencyRepository(conn).cleanup_expired()
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
def reconciliation_task(self):
    """Sweep stuck durability rows and escalate them to terminal status + alert.

    Thin Celery wrapper: the work is delegated to ``run_reconciliation`` and
    its result is returned unchanged.
    """
    # Deferred import, matching the other tasks in this module.
    from application.api.user.reconciliation import run_reconciliation

    return run_reconciliation()
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
def cleanup_message_events(self):
    """Delete ``message_events`` rows older than the retention window.

    Streamed answer responses write one journal row per SSE yield,
    so unbounded growth would dominate Postgres for any retained-
    conversations deployment. The reconnect-replay path only needs
    rows for in-flight streams; the window is read from
    ``settings.MESSAGE_EVENTS_RETENTION_DAYS``.

    Returns:
        ``{"deleted": <count>, "ttl_days": <window>}`` on a normal run, or
        ``{"deleted": 0, "skipped": ...}`` when ``POSTGRES_URI`` is not set.
    """
    from application.core.settings import settings

    if not settings.POSTGRES_URI:
        return {"deleted": 0, "skipped": "POSTGRES_URI not set"}

    from application.storage.db.engine import get_engine
    from application.storage.db.repositories.message_events import (
        MessageEventsRepository,
    )

    ttl_days = settings.MESSAGE_EVENTS_RETENTION_DAYS
    engine = get_engine()
    # One transaction for the whole delete.
    with engine.begin() as conn:
        deleted = MessageEventsRepository(conn).cleanup_older_than(ttl_days)
    return {"deleted": deleted, "ttl_days": ttl_days}
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def version_check_task(self):
|
||||
"""Periodic anonymous version check.
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Tool management MCP server integration."""
|
||||
|
||||
import json
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
@@ -226,7 +225,9 @@ class MCPServerSave(Resource):
|
||||
)
|
||||
redis_client = get_redis_instance()
|
||||
manager = MCPOAuthManager(redis_client)
|
||||
result = manager.get_oauth_status(config["oauth_task_id"])
|
||||
result = manager.get_oauth_status(
|
||||
config["oauth_task_id"], user
|
||||
)
|
||||
if not result.get("status") == "completed":
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -438,56 +439,6 @@ class MCPOAuthCallback(Resource):
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
|
||||
class MCPOAuthStatus(Resource):
|
||||
def get(self, task_id):
|
||||
try:
|
||||
redis_client = get_redis_instance()
|
||||
status_key = f"mcp_oauth_status:{task_id}"
|
||||
status_data = redis_client.get(status_key)
|
||||
|
||||
if status_data:
|
||||
status = json.loads(status_data)
|
||||
if "tools" in status and isinstance(status["tools"], list):
|
||||
status["tools"] = [
|
||||
{
|
||||
"name": t.get("name", "unknown"),
|
||||
"description": t.get("description", ""),
|
||||
}
|
||||
for t in status["tools"]
|
||||
]
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": task_id, **status})
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"task_id": task_id,
|
||||
"status": "pending",
|
||||
"message": "Waiting for OAuth to start...",
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error getting OAuth status for task {task_id}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Failed to get OAuth status",
|
||||
"task_id": task_id,
|
||||
}
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/auth_status")
|
||||
class MCPAuthStatus(Resource):
|
||||
@api.doc(
|
||||
|
||||
@@ -9,6 +9,7 @@ import json
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from flask import Blueprint, jsonify, make_response, request, Response
|
||||
@@ -221,13 +222,26 @@ def _stream_response(
|
||||
for line in internal_stream:
|
||||
if not line.strip():
|
||||
continue
|
||||
# Parse the internal SSE event
|
||||
event_str = line.replace("data: ", "").strip()
|
||||
# ``complete_stream`` prefixes each frame with ``id: <seq>\n``
|
||||
# before the ``data:`` line. Extract just the data line so JSON
|
||||
# decode doesn't choke on the SSE framing.
|
||||
event_str = ""
|
||||
for raw in line.split("\n"):
|
||||
if raw.startswith("data:"):
|
||||
event_str = raw[len("data:") :].lstrip()
|
||||
break
|
||||
if not event_str:
|
||||
continue
|
||||
try:
|
||||
event_data = json.loads(event_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
|
||||
# Skip the informational ``message_id`` event — it has no v1 /
|
||||
# OpenAI-compatible analog.
|
||||
if event_data.get("type") == "message_id":
|
||||
continue
|
||||
|
||||
# Update completion_id when we get the conversation id
|
||||
if event_data.get("type") == "id":
|
||||
conv_id = event_data.get("id", "")
|
||||
@@ -306,7 +320,16 @@ def list_models():
|
||||
401,
|
||||
)
|
||||
|
||||
# Repository rows now go through ``coerce_pg_native`` at SELECT
|
||||
# time, so timestamps arrive as ISO 8601 strings. Parse before
|
||||
# taking ``.timestamp()``; fall back to ``time.time()`` only when
|
||||
# the value is genuinely missing or unparseable.
|
||||
created = agent.get("created_at") or agent.get("createdAt")
|
||||
if isinstance(created, str):
|
||||
try:
|
||||
created = datetime.fromisoformat(created)
|
||||
except (ValueError, TypeError):
|
||||
created = None
|
||||
created_ts = (
|
||||
int(created.timestamp()) if hasattr(created, "timestamp")
|
||||
else int(time.time())
|
||||
|
||||
@@ -16,6 +16,8 @@ setup_logging()
|
||||
|
||||
from application.api import api # noqa: E402
|
||||
from application.api.answer import answer # noqa: E402
|
||||
from application.api.answer.routes.messages import messages_bp # noqa: E402
|
||||
from application.api.events.routes import events # noqa: E402
|
||||
from application.api.internal.routes import internal # noqa: E402
|
||||
from application.api.user.routes import user # noqa: E402
|
||||
from application.api.connector.routes import connector # noqa: E402
|
||||
@@ -49,6 +51,8 @@ ensure_database_ready(
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(user)
|
||||
app.register_blueprint(answer)
|
||||
app.register_blueprint(events)
|
||||
app.register_blueprint(messages_bp)
|
||||
app.register_blueprint(internal)
|
||||
app.register_blueprint(connector)
|
||||
app.register_blueprint(v1_bp)
|
||||
@@ -200,7 +204,9 @@ def _bind_user_id_to_log_context():
|
||||
def after_request(response: Response) -> Response:
|
||||
"""Add CORS headers for the pure Flask development entrypoint."""
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
|
||||
response.headers["Access-Control-Allow-Headers"] = (
|
||||
"Content-Type, Authorization, Idempotency-Key"
|
||||
)
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
return response
|
||||
|
||||
|
||||
@@ -25,7 +25,12 @@ asgi_app = Starlette(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization", "Mcp-Session-Id"],
|
||||
allow_headers=[
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"Mcp-Session-Id",
|
||||
"Idempotency-Key",
|
||||
],
|
||||
expose_headers=["Mcp-Session-Id"],
|
||||
),
|
||||
],
|
||||
|
||||
@@ -29,8 +29,17 @@ def get_redis_instance():
|
||||
with _instance_lock:
|
||||
if _redis_instance is None and not _redis_creation_failed:
|
||||
try:
|
||||
# ``health_check_interval`` makes redis-py ping the
|
||||
# connection every N seconds when otherwise idle.
|
||||
# Without it, a half-open TCP (NAT silently dropped
|
||||
# state, ELB idle-close) can hang the SSE generator
|
||||
# in ``pubsub.get_message`` past its keepalive
|
||||
# cadence — the kernel never surfaces the dead
|
||||
# socket because no payload is in flight.
|
||||
_redis_instance = redis.Redis.from_url(
|
||||
settings.CACHE_REDIS_URL, socket_connect_timeout=2
|
||||
settings.CACHE_REDIS_URL,
|
||||
socket_connect_timeout=2,
|
||||
health_check_interval=10,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid Redis URL: {e}")
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import os
|
||||
from application.core.settings import settings
|
||||
|
||||
broker_url = os.getenv("CELERY_BROKER_URL")
|
||||
result_backend = os.getenv("CELERY_RESULT_BACKEND")
|
||||
# Pydantic loads .env into ``settings`` but does not inject values into
|
||||
# ``os.environ`` — read directly from settings so beat startup (which
|
||||
# imports this module before any explicit env load) sees a real URL.
|
||||
broker_url = settings.CELERY_BROKER_URL
|
||||
result_backend = settings.CELERY_RESULT_BACKEND
|
||||
|
||||
task_serializer = 'json'
|
||||
result_serializer = 'json'
|
||||
@@ -10,7 +13,21 @@ accept_content = ['json']
|
||||
# Autodiscover tasks
|
||||
imports = ('application.api.user.tasks',)
|
||||
|
||||
# Project-scoped queue so a stray sibling worker on the same broker
|
||||
# (other repo, same default ``celery`` queue) can't grab DocsGPT tasks.
|
||||
task_default_queue = "docsgpt"
|
||||
task_default_exchange = "docsgpt"
|
||||
task_default_routing_key = "docsgpt"
|
||||
|
||||
beat_scheduler = "redbeat.RedBeatScheduler"
|
||||
redbeat_redis_url = broker_url
|
||||
redbeat_key_prefix = "redbeat:docsgpt:"
|
||||
redbeat_lock_timeout = 90
|
||||
|
||||
# Survive worker SIGKILL/OOM without silently dropping in-flight tasks.
|
||||
task_acks_late = True
|
||||
task_reject_on_worker_lost = True
|
||||
worker_prefetch_multiplier = settings.CELERY_WORKER_PREFETCH_MULTIPLIER
|
||||
broker_transport_options = {"visibility_timeout": settings.CELERY_VISIBILITY_TIMEOUT}
|
||||
result_expires = 86400 * 7
|
||||
task_track_started = True
|
||||
|
||||
@@ -30,6 +30,12 @@ class Settings(BaseSettings):
|
||||
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
# Prefetch=1 caps SIGKILL loss to one task. Visibility timeout must exceed
|
||||
# the longest legitimate task runtime (ingest, agent webhook) but stay
|
||||
# short enough that SIGKILLed tasks redeliver promptly. 1h matches Onyx
|
||||
# and Dify defaults; long ingests can override via env.
|
||||
CELERY_WORKER_PREFETCH_MULTIPLIER: int = 1
|
||||
CELERY_VISIBILITY_TIMEOUT: int = 3600
|
||||
# Only consulted when VECTOR_STORE=mongodb or when running scripts/db/backfill.py; user data lives in Postgres.
|
||||
MONGO_URI: Optional[str] = None
|
||||
# User-data Postgres DB.
|
||||
@@ -182,6 +188,42 @@ class Settings(BaseSettings):
|
||||
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
|
||||
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
|
||||
|
||||
# Internal SSE push channel (notifications + durable replay journal)
|
||||
# Master switch — when False, /api/events emits a "push_disabled" comment
|
||||
# and returns; clients fall back to polling. Publisher becomes a no-op.
|
||||
ENABLE_SSE_PUSH: bool = True
|
||||
# Per-user durable backlog cap (~entries). At typical event rates this
|
||||
# gives ~24h of replay; tune up for verbose feeds, down for memory.
|
||||
EVENTS_STREAM_MAXLEN: int = 1000
|
||||
# SSE keepalive comment cadence. Must sit under Cloudflare's 100s idle
|
||||
# close and iOS Safari's ~60s — 15s gives generous headroom.
|
||||
SSE_KEEPALIVE_SECONDS: int = 15
|
||||
# Cap on simultaneous SSE connections per user. Each connection holds
|
||||
# one WSGI thread (32 per gunicorn worker) and one Redis pub/sub
|
||||
# connection. 8 covers normal multi-tab use without letting one user
|
||||
# starve the pool. Set to 0 to disable the cap.
|
||||
SSE_MAX_CONCURRENT_PER_USER: int = 8
|
||||
# Per-request cap on the number of backlog entries XRANGE returns
|
||||
# for ``/api/events`` snapshots. Bounds the bytes a single replay
|
||||
# can move from Redis to the wire — a malicious client looping
|
||||
# ``Last-Event-ID=<oldest>`` reconnects can only enumerate this
|
||||
# many entries per round-trip. Combined with the per-user
|
||||
# connection cap above and the windowed budget below, total
|
||||
# enumeration throughput is bounded.
|
||||
EVENTS_REPLAY_MAX_PER_REQUEST: int = 200
|
||||
# Sliding-window cap on snapshot replays per user. Once the budget
|
||||
# is exhausted the route returns HTTP 429 with the cursor pinned;
|
||||
# the client backs off and retries after the window rolls over.
|
||||
EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW: int = 30
|
||||
EVENTS_REPLAY_BUDGET_WINDOW_SECONDS: int = 60
|
||||
|
||||
# Retention for the ``message_events`` journal. The ``cleanup_message_events``
|
||||
# beat task deletes rows older than this. Reconnect-replay only
|
||||
# needs the journal for streams a client could still be tailing,
|
||||
# so 14 days is a generous default that covers paused/tool-action
|
||||
# flows without unbounded table growth.
|
||||
MESSAGE_EVENTS_RETENTION_DAYS: int = 14
|
||||
|
||||
@field_validator("POSTGRES_URI", mode="before")
|
||||
@classmethod
|
||||
def _normalize_postgres_uri_validator(cls, v):
|
||||
|
||||
0
application/events/__init__.py
Normal file
0
application/events/__init__.py
Normal file
52
application/events/keys.py
Normal file
52
application/events/keys.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Stream/topic key derivations shared by publisher and SSE consumer.
|
||||
|
||||
Single source of truth for the per-user Redis Streams key and pub/sub
|
||||
topic name. Both must agree exactly — a typo here splits the
|
||||
publisher's writes from the consumer's reads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def stream_key(user_id: str) -> str:
    """Redis Streams key holding the durable backlog for ``user_id``."""
    return f"user:{user_id}:stream"
|
||||
|
||||
|
||||
def topic_name(user_id: str) -> str:
    """Redis pub/sub channel used for live fan-out to ``user_id``."""
    return f"user:{user_id}"
|
||||
|
||||
|
||||
def connection_counter_key(user_id: str) -> str:
    """Redis counter tracking active SSE connections for ``user_id``."""
    return f"user:{user_id}:sse_count"
|
||||
|
||||
|
||||
def replay_budget_key(user_id: str) -> str:
    """Redis counter tracking snapshot replays for ``user_id``.

    The counter covers the rolling rate-limit window.
    """
    return f"user:{user_id}:replay_count"
|
||||
|
||||
|
||||
def stream_id_compare(a: str, b: str) -> int:
    """Compare two Redis Streams ids. Returns -1, 0, 1 like ``cmp``.

    Stream ids are ``ms-seq`` strings; comparing as strings would be wrong
    once ``ms`` straddles digit-count boundaries. We parse and compare
    as ``(int, int)`` tuples. An id without a ``-seq`` part is treated as
    sequence ``0``.

    Raises ``ValueError`` on malformed input. Callers must pre-validate
    against ``_STREAM_ID_RE`` (or equivalent) — a lex fallback here let
    a malformed id compare lex-greater than a real one and silently pin
    dedup forever.
    """
    a_ms, _, a_seq = a.partition("-")
    b_ms, _, b_seq = b.partition("-")
    # ``int()`` raises ValueError for non-numeric parts; there is
    # deliberately no string-comparison fallback.
    a_tuple = (int(a_ms), int(a_seq) if a_seq else 0)
    b_tuple = (int(b_ms), int(b_seq) if b_seq else 0)
    if a_tuple < b_tuple:
        return -1
    if a_tuple > b_tuple:
        return 1
    return 0
|
||||
144
application/events/publisher.py
Normal file
144
application/events/publisher.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""User-scoped event publisher: durable backlog + live fan-out.
|
||||
|
||||
Each ``publish_user_event`` call writes twice:
|
||||
|
||||
1. ``XADD user:{user_id}:stream MAXLEN ~ <cap> * event <json>`` — the
|
||||
durable backlog used by SSE reconnect (``Last-Event-ID``) and stream
|
||||
replay. Bounded by ``EVENTS_STREAM_MAXLEN`` (~24h at typical event
|
||||
rates) so the per-user footprint stays predictable.
|
||||
2. ``PUBLISH user:{user_id} <json-with-id>`` — live fan-out to every
|
||||
currently connected SSE generator for the user, across instances.
|
||||
|
||||
Together they give a snapshot-plus-tail story: a reconnecting client
|
||||
reads ``XRANGE`` from its last seen id and then transitions onto the
|
||||
live pub/sub. The Redis Streams entry id (e.g. ``1735682400000-0``) is
|
||||
the canonical, monotonically increasing event id and is what
|
||||
``Last-Event-ID`` carries.
|
||||
|
||||
Failures are logged and swallowed: the caller is typically a Celery
|
||||
task whose primary work has already succeeded, and a notification
|
||||
delivery miss should not surface as a task failure.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.events.keys import stream_key, topic_name
|
||||
from application.streaming.broadcast_channel import Topic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _iso_now() -> str:
    """ISO 8601 UTC with millisecond precision and Z suffix."""
    return (
        datetime.now(timezone.utc)
        .isoformat(timespec="milliseconds")
        # ``isoformat`` renders UTC as "+00:00"; swap in the "Z" form.
        .replace("+00:00", "Z")
    )
|
||||
|
||||
|
||||
def publish_user_event(
    user_id: str,
    event_type: str,
    payload: dict[str, Any],
    *,
    scope: Optional[dict[str, Any]] = None,
) -> Optional[str]:
    """Publish a user-scoped event; return the Redis Streams id or ``None``.

    Fire-and-forget: never raises. ``None`` means the event reached
    neither the journal nor live subscribers (see runbook for causes).

    Args:
        user_id: Owner of the event; selects the stream key and topic.
        event_type: Stored as the envelope's ``type`` field.
        payload: JSON-serializable body stored under ``payload``.
        scope: Optional extra context; an empty dict is stored when omitted.

    Returns:
        The stream entry id assigned by ``XADD``. It is returned even if the
        subsequent live publish fails, because the journal write succeeded.
    """
    if not user_id or not event_type:
        logger.warning(
            "publish_user_event called without user_id or event_type "
            "(user_id=%r, event_type=%r)",
            user_id,
            event_type,
        )
        return None
    if not settings.ENABLE_SSE_PUSH:
        return None

    topic = topic_name(user_id)
    envelope_partial: dict[str, Any] = {
        "type": event_type,
        "ts": _iso_now(),
        "user_id": user_id,
        "topic": topic,
        "scope": scope or {},
        "payload": payload,
    }

    try:
        envelope_partial_json = json.dumps(envelope_partial)
    except (TypeError, ValueError) as exc:
        logger.warning(
            "publish_user_event payload not JSON-serializable: "
            "user=%s type=%s err=%s",
            user_id,
            event_type,
            exc,
        )
        return None

    redis = get_redis_instance()
    if redis is None:
        logger.debug("Redis unavailable; skipping publish_user_event")
        return None

    maxlen = settings.EVENTS_STREAM_MAXLEN
    stream_id: Optional[str] = None
    try:
        # Auto-id ('*') gives a monotonic ms-seq id that doubles as the
        # SSE event id. ``approximate=True`` lets Redis trim in chunks
        # for performance; the cap is treated as ~MAXLEN, never <.
        result = redis.xadd(
            stream_key(user_id),
            {"event": envelope_partial_json},
            maxlen=maxlen,
            approximate=True,
        )
        # The client may hand back bytes or str; normalize to str.
        stream_id = (
            result.decode("utf-8")
            if isinstance(result, (bytes, bytearray))
            else str(result)
        )
    except Exception:
        logger.exception(
            "xadd failed for user=%s event_type=%s", user_id, event_type
        )

    # If the durable journal write failed there is no canonical id to
    # ship — publishing the envelope live would put an id-less record
    # on the wire that bypasses the SSE route's dedup floor and breaks
    # ``Last-Event-ID`` semantics for any reconnect. Best-effort
    # delivery means dropping consistently, not delivering inconsistent
    # state.
    if stream_id is None:
        return None

    envelope = dict(envelope_partial)
    envelope["id"] = stream_id

    try:
        Topic(topic).publish(json.dumps(envelope))
    except Exception:
        logger.exception(
            "publish failed for user=%s event_type=%s", user_id, event_type
        )

    logger.debug(
        "event.published topic=%s type=%s id=%s",
        topic,
        event_type,
        stream_id,
    )

    return stream_id
|
||||
@@ -80,6 +80,14 @@ class BaseLLM(ABC):
|
||||
agent_id=self.agent_id,
|
||||
model_user_id=self.model_user_id,
|
||||
)
|
||||
# Tag the fallback LLM so its rows land as
|
||||
# ``source='fallback'`` in cost-attribution dashboards.
|
||||
# Propagate the parent's ``_request_id`` so a user
|
||||
# request that ran fallback is still grouped under one id.
|
||||
self._fallback_llm._token_usage_source = "fallback"
|
||||
self._fallback_llm._request_id = getattr(
|
||||
self, "_request_id", None,
|
||||
)
|
||||
logger.info(
|
||||
f"Fallback LLM initialized from agent backup model: "
|
||||
f"{provider}/{backup_model_id}"
|
||||
@@ -106,6 +114,11 @@ class BaseLLM(ABC):
|
||||
agent_id=self.agent_id,
|
||||
model_user_id=self.model_user_id,
|
||||
)
|
||||
# Same rationale as the agent-backup branch.
|
||||
self._fallback_llm._token_usage_source = "fallback"
|
||||
self._fallback_llm._request_id = getattr(
|
||||
self, "_request_id", None,
|
||||
)
|
||||
logger.info(
|
||||
f"Fallback LLM initialized from global settings: "
|
||||
f"{settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
|
||||
|
||||
@@ -6,6 +6,7 @@ from google.genai import types
|
||||
from application.core.settings import settings
|
||||
|
||||
from application.llm.base import BaseLLM
|
||||
from application.llm.handlers.google import _decode_thought_signature
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
|
||||
|
||||
@@ -258,7 +259,7 @@ class GoogleLLM(BaseLLM):
|
||||
except (_json.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
cleaned_args = self._remove_null_values(args)
|
||||
thought_sig = tc.get("thought_signature")
|
||||
thought_sig = _decode_thought_signature(tc.get("thought_signature"))
|
||||
if thought_sig:
|
||||
parts.append(
|
||||
types.Part(
|
||||
@@ -322,7 +323,9 @@ class GoogleLLM(BaseLLM):
|
||||
name=item["function_call"]["name"],
|
||||
args=cleaned_args,
|
||||
),
|
||||
thoughtSignature=item["thought_signature"],
|
||||
thoughtSignature=_decode_thought_signature(
|
||||
item["thought_signature"]
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -10,6 +10,18 @@ from application.logging import build_stack_data
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Cap the agent tool-call loop. Without this an LLM that keeps
|
||||
# requesting more tool calls (preview models, sparse tool results,
|
||||
# under-specified prompts) can chain searches indefinitely and the
|
||||
# stream never finalises. 25 mirrors Dify's default.
|
||||
MAX_TOOL_ITERATIONS = 25
|
||||
_FINALIZE_INSTRUCTION = (
|
||||
f"You have made {MAX_TOOL_ITERATIONS} tool calls. Provide a final "
|
||||
"response to the user based on what you have, without making any "
|
||||
"additional tool calls."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""Represents a tool/function call from the LLM."""
|
||||
@@ -624,6 +636,10 @@ class LLMHandler(ABC):
|
||||
agent_id=getattr(agent, "agent_id", None),
|
||||
model_user_id=compression_user_id,
|
||||
)
|
||||
# Side-channel LLM tag — see ``orchestrator.py`` for rationale.
|
||||
compression_llm._token_usage_source = "compression"
|
||||
compression_llm._request_id = getattr(agent, "_request_id", None) \
|
||||
or getattr(getattr(agent, "llm", None), "_request_id", None)
|
||||
|
||||
# Create service without DB persistence capability
|
||||
compression_service = CompressionService(
|
||||
@@ -934,7 +950,9 @@ class LLMHandler(ABC):
|
||||
parsed = self.parse_response(response)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
iteration = 0
|
||||
while parsed.requires_tool_call:
|
||||
iteration += 1
|
||||
tool_handler_gen = self.handle_tool_calls(
|
||||
agent, parsed.tool_calls, tools_dict, messages
|
||||
)
|
||||
@@ -958,6 +976,25 @@ class LLMHandler(ABC):
|
||||
}
|
||||
return ""
|
||||
|
||||
# Cap reached: force one final tool-less call so the stream
|
||||
# always ends with content rather than cutting off.
|
||||
if iteration >= MAX_TOOL_ITERATIONS:
|
||||
logger.warning(
|
||||
"agent tool loop hit cap (%d); forcing finalize",
|
||||
MAX_TOOL_ITERATIONS,
|
||||
)
|
||||
messages.append(
|
||||
{"role": "system", "content": _FINALIZE_INSTRUCTION},
|
||||
)
|
||||
response = agent.llm.gen(
|
||||
model=getattr(agent.llm, "model_id", None) or agent.model_id,
|
||||
messages=messages,
|
||||
tools=None,
|
||||
)
|
||||
parsed = self.parse_response(response)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
break
|
||||
|
||||
# ``agent.model_id`` is the registry id (a UUID for BYOM
|
||||
# records). Use the LLM's own model_id, which LLMCreator
|
||||
# already resolved to the upstream model name. Built-ins:
|
||||
@@ -973,7 +1010,12 @@ class LLMHandler(ABC):
|
||||
return parsed.content
|
||||
|
||||
def handle_streaming(
|
||||
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
|
||||
self,
|
||||
agent,
|
||||
response: Any,
|
||||
tools_dict: Dict,
|
||||
messages: List[Dict],
|
||||
_iteration: int = 0,
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle streaming response flow.
|
||||
@@ -1042,6 +1084,9 @@ class LLMHandler(ABC):
|
||||
}
|
||||
return
|
||||
|
||||
next_iteration = _iteration + 1
|
||||
cap_reached = next_iteration >= MAX_TOOL_ITERATIONS
|
||||
|
||||
# Check if context limit was reached during tool execution
|
||||
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
|
||||
# Add system message warning about context limit
|
||||
@@ -1054,16 +1099,32 @@ class LLMHandler(ABC):
|
||||
)
|
||||
})
|
||||
logger.info("Context limit reached - instructing agent to wrap up")
|
||||
elif cap_reached:
|
||||
logger.warning(
|
||||
"agent tool loop hit cap (%d); forcing finalize",
|
||||
MAX_TOOL_ITERATIONS,
|
||||
)
|
||||
messages.append(
|
||||
{"role": "system", "content": _FINALIZE_INSTRUCTION},
|
||||
)
|
||||
|
||||
# See note above on agent.model_id vs llm.model_id.
|
||||
response = agent.llm.gen_stream(
|
||||
model=getattr(agent.llm, "model_id", None) or agent.model_id,
|
||||
messages=messages,
|
||||
tools=agent.tools if not agent.context_limit_reached else None,
|
||||
tools=(
|
||||
None
|
||||
if cap_reached
|
||||
or getattr(agent, "context_limit_reached", False)
|
||||
else agent.tools
|
||||
),
|
||||
)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
yield from self.handle_streaming(agent, response, tools_dict, messages)
|
||||
yield from self.handle_streaming(
|
||||
agent, response, tools_dict, messages,
|
||||
_iteration=next_iteration,
|
||||
)
|
||||
return
|
||||
if parsed.content:
|
||||
buffer += parsed.content
|
||||
|
||||
@@ -1,9 +1,35 @@
|
||||
import base64
|
||||
import binascii
|
||||
import uuid
|
||||
from typing import Any, Dict, Generator
|
||||
from typing import Any, Dict, Generator, Optional, Union
|
||||
|
||||
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
|
||||
|
||||
|
||||
def _encode_thought_signature(sig: Optional[Union[bytes, str]]) -> Optional[str]:
|
||||
# Gemini's Python SDK returns thought_signature as raw bytes, but the
|
||||
# field is typed Optional[str] downstream and gets json.dumps'd into
|
||||
# SSE events. Encode once at ingress so callers only ever see a str.
|
||||
if isinstance(sig, bytes):
|
||||
return base64.b64encode(sig).decode("ascii")
|
||||
return sig
|
||||
|
||||
|
||||
def _decode_thought_signature(
|
||||
sig: Optional[Union[bytes, str]],
|
||||
) -> Optional[Union[bytes, str]]:
|
||||
# Reverse of _encode_thought_signature — Gemini's SDK expects bytes
|
||||
# back when we replay a tool call. ``validate=True`` keeps ASCII
|
||||
# strings that happen to be loosely decodable from being silently
|
||||
# turned into bytes; non-base64 inputs pass through unchanged.
|
||||
if isinstance(sig, str):
|
||||
try:
|
||||
return base64.b64decode(sig.encode("ascii"), validate=True)
|
||||
except (binascii.Error, ValueError):
|
||||
return sig
|
||||
return sig
|
||||
|
||||
|
||||
class GoogleLLMHandler(LLMHandler):
|
||||
"""Handler for Google's GenAI API."""
|
||||
|
||||
@@ -23,7 +49,7 @@ class GoogleLLMHandler(LLMHandler):
|
||||
for idx, part in enumerate(parts):
|
||||
if hasattr(part, "function_call") and part.function_call is not None:
|
||||
has_sig = hasattr(part, "thought_signature") and part.thought_signature is not None
|
||||
thought_sig = part.thought_signature if has_sig else None
|
||||
thought_sig = _encode_thought_signature(part.thought_signature) if has_sig else None
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
@@ -50,7 +76,7 @@ class GoogleLLMHandler(LLMHandler):
|
||||
tool_calls = []
|
||||
if hasattr(response, "function_call") and response.function_call is not None:
|
||||
has_sig = hasattr(response, "thought_signature") and response.thought_signature is not None
|
||||
thought_sig = response.thought_signature if has_sig else None
|
||||
thought_sig = _encode_thought_signature(response.thought_signature) if has_sig else None
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=str(uuid.uuid4()),
|
||||
@@ -70,8 +96,15 @@ class GoogleLLMHandler(LLMHandler):
|
||||
"""Create a tool result message in the standard internal format."""
|
||||
import json as _json
|
||||
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
# PostgresTool results commonly include PG-native types
|
||||
# (datetime / UUID / Decimal / bytea) when SELECT touches
|
||||
# timestamptz / numeric / uuid / bytea columns. The shared
|
||||
# encoder handles all five — bytes get base64 (lossless) instead
|
||||
# of the ``str(b'...')`` repr that ``default=str`` would emit.
|
||||
content = (
|
||||
_json.dumps(result)
|
||||
_json.dumps(result, cls=PGNativeJSONEncoder)
|
||||
if not isinstance(result, str)
|
||||
else result
|
||||
)
|
||||
|
||||
@@ -40,8 +40,15 @@ class OpenAILLMHandler(LLMHandler):
|
||||
"""Create a tool result message in the standard internal format."""
|
||||
import json as _json
|
||||
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
# PostgresTool results commonly include PG-native types
|
||||
# (datetime / UUID / Decimal / bytea) when SELECT touches
|
||||
# timestamptz / numeric / uuid / bytea columns. The shared
|
||||
# encoder handles all five — bytes get base64 (lossless) instead
|
||||
# of the ``str(b'...')`` repr that ``default=str`` would emit.
|
||||
content = (
|
||||
_json.dumps(result)
|
||||
_json.dumps(result, cls=PGNativeJSONEncoder)
|
||||
if not isinstance(result, str)
|
||||
else result
|
||||
)
|
||||
|
||||
@@ -1,12 +1,28 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import List, Any
|
||||
from typing import Any, List, Optional
|
||||
from retry import retry
|
||||
from tqdm import tqdm
|
||||
from application.core.settings import settings
|
||||
from application.events.publisher import publish_user_event
|
||||
from application.storage.db.repositories.ingest_chunk_progress import (
|
||||
IngestChunkProgressRepository,
|
||||
)
|
||||
from application.storage.db.session import db_session
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
class EmbeddingPipelineError(Exception):
    """Raised when the per-chunk embed loop produces a partial index.

    Escapes into Celery's ``autoretry_for`` so a transient cause (rate
    limit, network blip) gets another shot. The chunk-progress
    checkpoint makes retries cheap — only the failed-and-after chunks
    re-run. After ``MAX_TASK_ATTEMPTS`` the poison-loop guard in
    ``with_idempotency`` finalises the row as ``failed``.
    """
|
||||
|
||||
|
||||
def sanitize_content(content: str) -> str:
|
||||
"""
|
||||
Remove NUL characters that can cause vector store ingestion to fail.
|
||||
@@ -22,7 +38,11 @@ def sanitize_content(content: str) -> str:
|
||||
return content.replace('\x00', '')
|
||||
|
||||
|
||||
@retry(tries=10, delay=60)
|
||||
# Per-chunk inline retry. Aggressive defaults (tries=10, delay=60) blocked
|
||||
# the loop for up to 9 min per chunk and wedged the heartbeat: lower the
|
||||
# tail so a transient failure fails-fast and the chunk-progress checkpoint
|
||||
# resumes cleanly on next dispatch.
|
||||
@retry(tries=3, delay=5, backoff=2)
|
||||
def add_text_to_store_with_retry(store: Any, doc: Any, source_id: str) -> None:
|
||||
"""Add a document's text and metadata to the vector store with retry logic.
|
||||
|
||||
@@ -45,21 +65,124 @@ def add_text_to_store_with_retry(store: Any, doc: Any, source_id: str) -> None:
|
||||
raise
|
||||
|
||||
|
||||
def embed_and_store_documents(docs: List[Any], folder_name: str, source_id: str, task_status: Any) -> None:
|
||||
def _init_progress_and_resume_index(
    source_id: str, total_chunks: int, attempt_id: Optional[str],
) -> int:
    """Upsert the progress row and return the next chunk index to embed.

    The repository's upsert preserves ``last_index`` only when the
    incoming ``attempt_id`` matches the stored one (a Celery autoretry
    of the same task). On a fresh attempt — including any caller that
    doesn't pass an ``attempt_id``, e.g. legacy code or tests — the
    row's checkpoint is reset so the loop starts from chunk 0. This
    is what prevents a completed checkpoint from any prior run
    silently no-op'ing the next sync/reingest.

    Best-effort: a DB outage falls back to ``0`` (fresh run from
    chunk 0). The embed loop's own re-raise still ensures partial
    runs don't get cached as complete.

    Args:
        source_id: Identifier of the source being ingested.
        total_chunks: Number of chunks this run intends to embed.
        attempt_id: Id of the current task invocation, or ``None``.

    Returns:
        ``last_index + 1`` from the progress row, or ``0`` when the row is
        missing, the lookup failed, or ``last_index`` is ``None``/negative.
    """
    try:
        with db_session() as conn:
            progress = IngestChunkProgressRepository(conn).init_progress(
                source_id, total_chunks, attempt_id,
            )
    except Exception as e:
        # Deliberate best-effort: losing the checkpoint only costs a
        # re-embed from chunk 0, so log and carry on.
        logging.warning(
            "Could not init ingest progress for %s: %s",
            source_id,
            e,
            exc_info=True,
        )
        return 0
    if not progress:
        return 0
    last_index = progress.get("last_index", -1)
    if last_index is None or last_index < 0:
        return 0
    return int(last_index) + 1
|
||||
|
||||
|
||||
def _record_progress(source_id: str, last_index: int, embedded_chunks: int) -> None:
    """Best-effort checkpoint after each chunk; logged but never raised.

    Args:
        source_id: Identifier of the source being ingested.
        last_index: Index of the most recently embedded chunk.
        embedded_chunks: Running count of chunks embedded so far.
    """
    try:
        with db_session() as conn:
            IngestChunkProgressRepository(conn).record_chunk(
                source_id, last_index=last_index, embedded_chunks=embedded_chunks
            )
    except Exception as e:
        # A failed checkpoint write must not abort the embed loop.
        logging.warning(
            "Could not record ingest progress for %s: %s",
            source_id,
            e,
            exc_info=True,
        )
|
||||
|
||||
|
||||
def assert_index_complete(source_id: str) -> None:
    """Raise ``EmbeddingPipelineError`` if ``ingest_chunk_progress``
    shows a partial embed for ``source_id``.

    Defense-in-depth tripwire that workers run after
    ``embed_and_store_documents`` to catch any future swallow path
    that bypasses the function's own re-raise — the chunk-progress
    row is the authoritative record of how many chunks landed.
    No-op when no row exists (zero-doc validation raised before init,
    or progress repo was unreachable).

    Args:
        source_id: Identifier of the source whose index is being checked.

    Raises:
        EmbeddingPipelineError: If the recorded ``embedded_chunks`` is
            lower than ``total_chunks``.
    """
    try:
        with db_session() as conn:
            progress = IngestChunkProgressRepository(conn).get_progress(source_id)
    except Exception as e:
        # The lookup itself is best-effort: an unreachable progress table
        # must not fail an otherwise successful ingest.
        logging.warning(
            "assert_index_complete: progress lookup failed for %s: %s",
            source_id,
            e,
            exc_info=True,
        )
        return
    if not progress:
        return
    # Missing or NULL counters are treated as 0.
    embedded = int(progress.get("embedded_chunks") or 0)
    total = int(progress.get("total_chunks") or 0)
    if embedded < total:
        raise EmbeddingPipelineError(
            f"partial index for source {source_id}: "
            f"{embedded}/{total} chunks embedded"
        )
|
||||
|
||||
|
||||
def embed_and_store_documents(
|
||||
docs: List[Any],
|
||||
folder_name: str,
|
||||
source_id: str,
|
||||
task_status: Any,
|
||||
*,
|
||||
attempt_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Embeds documents and stores them in a vector store.
|
||||
|
||||
Resumable across Celery autoretries of the *same* task: when
|
||||
``attempt_id`` matches the stored checkpoint's ``attempt_id``,
|
||||
the loop resumes from ``last_index + 1``. A different
|
||||
``attempt_id`` (a fresh sync / reingest invocation) resets the
|
||||
checkpoint so the index is rebuilt from chunk 0 — this is what
|
||||
keeps a completed checkpoint from poisoning the next sync.
|
||||
|
||||
Args:
|
||||
docs: List of documents to be embedded and stored.
|
||||
folder_name: Directory to save the vector store.
|
||||
source_id: Unique identifier for the source.
|
||||
task_status: Task state manager for progress updates.
|
||||
attempt_id: Stable id of the current task invocation,
|
||||
typically ``self.request.id`` from the Celery task body.
|
||||
``None`` is treated as a fresh attempt every time.
|
||||
user_id: When provided, per-percent SSE progress events are
|
||||
published to ``user:{user_id}`` for the in-app upload toast.
|
||||
``None`` is the safe default — workers without a user
|
||||
context (e.g. background syncs) skip the publish.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
|
||||
Raises:
|
||||
OSError: If unable to create folder or save vector store.
|
||||
Exception: If vector store creation or document embedding fails.
|
||||
EmbeddingPipelineError: If a chunk fails after retries.
|
||||
"""
|
||||
# Ensure the folder exists
|
||||
if not os.path.exists(folder_name):
|
||||
@@ -69,41 +192,108 @@ def embed_and_store_documents(docs: List[Any], folder_name: str, source_id: str,
|
||||
if not docs:
|
||||
raise ValueError("No documents to embed - check file format and extension")
|
||||
|
||||
total_docs = len(docs)
|
||||
# Atomic upsert that preserves checkpoint state on attempt-id match
|
||||
# (autoretry of same task) and resets it on mismatch (fresh sync /
|
||||
# reingest). Returns the new resume index — 0 means "start fresh".
|
||||
resume_index = _init_progress_and_resume_index(
|
||||
source_id, total_docs, attempt_id,
|
||||
)
|
||||
is_resume = resume_index > 0
|
||||
|
||||
# Initialize vector store
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
docs_init = [docs.pop(0)]
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
docs_init=docs_init,
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
if is_resume:
|
||||
# Load the existing FAISS index from storage so chunks
|
||||
# already embedded by the prior attempt survive the
|
||||
# save_local rewrite at the end of this run.
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
loop_start = resume_index
|
||||
else:
|
||||
# FAISS requires at least one doc to construct the store;
|
||||
# seed with ``docs[0]`` and let the loop pick up at index 1.
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
docs_init=[docs[0]],
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
# Record the seeded chunk so single-doc ingests don't fail
|
||||
# ``assert_index_complete`` — the loop never runs for
|
||||
# ``total_docs == 1`` and would otherwise leave
|
||||
# ``embedded_chunks`` at 0 / ``last_index`` at -1. The loop
|
||||
# body's per-iteration ``_record_progress`` overshoots
|
||||
# correctly for multi-chunk runs (counts seed + iterations),
|
||||
# so writing this checkpoint up-front is a no-op for those.
|
||||
_record_progress(source_id, last_index=0, embedded_chunks=1)
|
||||
loop_start = 1
|
||||
else:
|
||||
store = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE,
|
||||
source_id=source_id,
|
||||
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
|
||||
)
|
||||
store.delete_index()
|
||||
# Only wipe the index on a fresh run — a resume must keep the
|
||||
# chunks that earlier attempts already embedded.
|
||||
if not is_resume:
|
||||
store.delete_index()
|
||||
loop_start = resume_index
|
||||
|
||||
total_docs = len(docs)
|
||||
if is_resume and loop_start >= total_docs:
|
||||
# Nothing left to do; the loop runs zero iterations and
|
||||
# downstream finalize logic still executes. This is only
|
||||
# reachable on a same-attempt retry of a task whose previous
|
||||
# attempt finished — typically a Celery acks_late redelivery
|
||||
# after the task already returned. The ``assert_index_complete``
|
||||
# tripwire still validates ``embedded == total`` afterwards.
|
||||
loop_start = total_docs
|
||||
|
||||
# Process and embed documents
|
||||
for idx, doc in tqdm(
|
||||
enumerate(docs),
|
||||
chunk_error: Exception | None = None
|
||||
failed_idx: int | None = None
|
||||
last_published_pct = -1
|
||||
source_id_str = str(source_id)
|
||||
for idx in tqdm(
|
||||
range(loop_start, total_docs),
|
||||
desc="Embedding 🦖",
|
||||
unit="docs",
|
||||
total=total_docs,
|
||||
total=total_docs - loop_start,
|
||||
bar_format="{l_bar}{bar}| Time Left: {remaining}",
|
||||
):
|
||||
doc = docs[idx]
|
||||
try:
|
||||
# Update task status for progress tracking
|
||||
progress = int(((idx + 1) / total_docs) * 100)
|
||||
task_status.update_state(state="PROGRESS", meta={"current": progress})
|
||||
|
||||
# SSE push for sub-second upload-toast updates. Throttled to one
|
||||
# event per percent so a 10k-chunk ingest emits ~100 events,
|
||||
# not 10k. The Celery update_state above stays the source of
|
||||
# truth for the polling-fallback path.
|
||||
if user_id and progress > last_published_pct:
|
||||
publish_user_event(
|
||||
user_id,
|
||||
"source.ingest.progress",
|
||||
{
|
||||
"current": progress,
|
||||
"total": total_docs,
|
||||
"embedded_chunks": idx + 1,
|
||||
"stage": "embedding",
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_str},
|
||||
)
|
||||
last_published_pct = progress
|
||||
|
||||
# Add document to vector store
|
||||
add_text_to_store_with_retry(store, doc, source_id)
|
||||
_record_progress(source_id, last_index=idx, embedded_chunks=idx + 1)
|
||||
except Exception as e:
|
||||
chunk_error = e
|
||||
failed_idx = idx
|
||||
logging.error(f"Error embedding document {idx}: {e}", exc_info=True)
|
||||
logging.info(f"Saving progress at document {idx} out of {total_docs}")
|
||||
try:
|
||||
@@ -124,3 +314,16 @@ def embed_and_store_documents(docs: List[Any], folder_name: str, source_id: str,
|
||||
raise OSError(f"Unable to save vector store to {folder_name}: {e}") from e
|
||||
else:
|
||||
logging.info("Vector store saved successfully.")
|
||||
|
||||
# Re-raise after the partial save: the chunks that *did* embed are
|
||||
# flushed to disk and recorded in ``ingest_chunk_progress``, so a
|
||||
# Celery autoretry resumes via ``_init_progress_and_resume_index`` and only
|
||||
# re-runs the failed-and-after chunks. Without the raise, the
|
||||
# task body returns success and ``with_idempotency`` finalises
|
||||
# ``task_dedup`` as ``completed`` for a partial index — poisoning
|
||||
# the cache for 24h.
|
||||
if chunk_error is not None:
|
||||
raise EmbeddingPipelineError(
|
||||
f"embed failure at chunk {failed_idx}/{total_docs} "
|
||||
f"for source {source_id}"
|
||||
) from chunk_error
|
||||
|
||||
@@ -60,6 +60,9 @@ class ClassicRAG(BaseRetriever):
|
||||
agent_id=self.agent_id,
|
||||
model_user_id=self.model_user_id,
|
||||
)
|
||||
# Query-rephrase LLM is a side channel — tag it so its rows
|
||||
# land as ``source='rag_condense'`` in cost-attribution.
|
||||
self.llm._token_usage_source = "rag_condense"
|
||||
|
||||
if "active_docs" in source and source["active_docs"] is not None:
|
||||
if isinstance(source["active_docs"], list):
|
||||
|
||||
@@ -11,6 +11,8 @@ import re
|
||||
from typing import Any, Mapping
|
||||
from uuid import UUID
|
||||
|
||||
from application.storage.db.serialization import coerce_pg_native
|
||||
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
|
||||
@@ -34,12 +36,17 @@ def looks_like_uuid(value: Any) -> bool:
|
||||
|
||||
|
||||
def row_to_dict(row: Any) -> dict:
|
||||
"""Convert a SQLAlchemy ``Row`` to a plain dict with Mongo-compatible ids.
|
||||
"""Convert a SQLAlchemy ``Row`` to a plain JSON-safe dict.
|
||||
|
||||
During the migration window, API responses and downstream code still
|
||||
expect a string ``_id`` field (matching the Mongo shape). This helper
|
||||
normalizes UUID columns to strings and emits both ``id`` and ``_id`` so
|
||||
existing serializers keep working unchanged.
|
||||
Normalises PG-native types at the SELECT boundary: UUID, datetime,
|
||||
date, Decimal, and bytes are coerced to JSON-safe forms via
|
||||
:func:`coerce_pg_native`. Downstream serialisation (SSE events,
|
||||
JSONB writes, API responses) becomes safe by default — repository
|
||||
consumers no longer need to know that PG returns a different type
|
||||
set than Mongo did.
|
||||
|
||||
Also emits ``_id`` alongside ``id`` for the duration of the Mongo→PG
|
||||
cutover so legacy serializers expecting Mongo's shape keep working.
|
||||
|
||||
Args:
|
||||
row: A SQLAlchemy ``Row`` object, or ``None``.
|
||||
@@ -52,10 +59,9 @@ def row_to_dict(row: Any) -> dict:
|
||||
|
||||
# Row has a ``._mapping`` attribute exposing a MappingProxy view.
|
||||
mapping: Mapping[str, Any] = row._mapping # type: ignore[attr-defined]
|
||||
out = dict(mapping)
|
||||
out = coerce_pg_native(dict(mapping))
|
||||
|
||||
if "id" in out and out["id"] is not None:
|
||||
out["id"] = str(out["id"]) if isinstance(out["id"], UUID) else out["id"]
|
||||
out["_id"] = out["id"]
|
||||
|
||||
return out
|
||||
|
||||
@@ -34,7 +34,7 @@ from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, JSONB, UUID
|
||||
metadata = MetaData()
|
||||
|
||||
|
||||
# --- Phase 1, Tier 1 --------------------------------------------------------
|
||||
# --- Users, prompts, tools, logs --------------------------------------------
|
||||
|
||||
users_table = Table(
|
||||
"users",
|
||||
@@ -91,6 +91,16 @@ token_usage_table = Table(
|
||||
Column("prompt_tokens", Integer, nullable=False, server_default="0"),
|
||||
Column("generated_tokens", Integer, nullable=False, server_default="0"),
|
||||
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
# Added in ``0004_durability_foundation``. Distinguishes
|
||||
# ``agent_stream`` (primary completion) from side-channel inserts
|
||||
# (``title`` / ``compression`` / ``rag_condense`` / ``fallback``)
|
||||
# so cost attribution dashboards can group by call source.
|
||||
Column("source", Text, nullable=False, server_default="agent_stream"),
|
||||
# Added in ``0005_token_usage_request_id``. Stream-scoped UUID stamped
|
||||
# on the agent's primary LLM so multi-call agent runs (which produce
|
||||
# N rows) count as a single request via DISTINCT in the repository
|
||||
# query. NULL on side-channel sources by design.
|
||||
Column("request_id", Text),
|
||||
)
|
||||
|
||||
user_logs_table = Table(
|
||||
@@ -128,7 +138,7 @@ app_metadata_table = Table(
|
||||
)
|
||||
|
||||
|
||||
# --- Phase 2, Tier 2 --------------------------------------------------------
|
||||
# --- Agents, sources, attachments, artifacts --------------------------------
|
||||
|
||||
agent_folders_table = Table(
|
||||
"agent_folders",
|
||||
@@ -297,7 +307,7 @@ connector_sessions_table = Table(
|
||||
)
|
||||
|
||||
|
||||
# --- Phase 3, Tier 3 --------------------------------------------------------
|
||||
# --- Conversations, messages, workflows -------------------------------------
|
||||
|
||||
conversations_table = Table(
|
||||
"conversations",
|
||||
@@ -345,9 +355,44 @@ conversation_messages_table = Table(
|
||||
Column("feedback", JSONB),
|
||||
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
# Added in 0004_durability_foundation. ``status`` is the WAL state
|
||||
# machine (pending|streaming|complete|failed); ``request_id`` ties a
|
||||
# row to a specific HTTP request for log correlation.
|
||||
Column("status", Text, nullable=False, server_default="complete"),
|
||||
Column("request_id", Text),
|
||||
UniqueConstraint("conversation_id", "position", name="conversation_messages_conv_pos_uidx"),
|
||||
)
|
||||
|
||||
# Per-yield journal of chat-stream events, used by the snapshot+tail
|
||||
# reconnect: the route's GET reconnect endpoint reads
|
||||
# ``WHERE message_id = ? AND sequence_no > ?`` from this table before
|
||||
# tailing the live ``channel:{message_id}`` pub/sub. See
|
||||
# ``application/streaming/event_replay.py`` and migration 0007.
|
||||
message_events_table = Table(
|
||||
"message_events",
|
||||
metadata,
|
||||
# PK is the composite ``(message_id, sequence_no)`` — it doubles as
|
||||
# the snapshot read index (covering range scan on
|
||||
# ``WHERE message_id = ? AND sequence_no > ?``).
|
||||
Column(
|
||||
"message_id",
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("conversation_messages.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
nullable=False,
|
||||
),
|
||||
# Strictly monotonic per ``message_id``. Allocated by the route as it
|
||||
# yields, so the writer is single-threaded for the lifetime of one
|
||||
# stream — no contention, no SERIAL needed.
|
||||
Column("sequence_no", Integer, primary_key=True, nullable=False),
|
||||
Column("event_type", Text, nullable=False),
|
||||
Column("payload", JSONB, nullable=False, server_default="{}"),
|
||||
Column(
|
||||
"created_at", DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
shared_conversations_table = Table(
|
||||
"shared_conversations",
|
||||
metadata,
|
||||
@@ -377,9 +422,101 @@ pending_tool_state_table = Table(
|
||||
Column("client_tools", JSONB),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("expires_at", DateTime(timezone=True), nullable=False),
|
||||
# Added in ``0004_durability_foundation``. ``status`` is the
|
||||
# ``pending|resuming`` claim flag for the resumed-run path;
|
||||
# ``resumed_at`` stamps when ``mark_resuming`` flipped the row so
|
||||
# the cleanup janitor can revert stale claims after the grace
|
||||
# window.
|
||||
Column("status", Text, nullable=False, server_default="pending"),
|
||||
Column("resumed_at", DateTime(timezone=True)),
|
||||
UniqueConstraint("conversation_id", "user_id", name="pending_tool_state_conv_user_uidx"),
|
||||
)
|
||||
|
||||
|
||||
# --- Durability foundation (idempotency / journals, migration 0004) ---------
|
||||
# CHECK constraints (status enums) and partial indexes are intentionally
|
||||
# omitted from these declarations — the DB is the authority. Repositories
|
||||
# use raw ``text(...)`` SQL against these tables, not the Core objects.
|
||||
|
||||
task_dedup_table = Table(
|
||||
"task_dedup",
|
||||
metadata,
|
||||
Column("idempotency_key", Text, primary_key=True),
|
||||
Column("task_name", Text, nullable=False),
|
||||
Column("task_id", Text, nullable=False),
|
||||
Column("result_json", JSONB),
|
||||
# CHECK (status IN ('pending', 'completed', 'failed')) lives in 0004.
|
||||
Column("status", Text, nullable=False),
|
||||
# Bumped each time the per-Celery-task wrapper re-enters; the
|
||||
# poison-loop guard (``MAX_TASK_ATTEMPTS=5``) refuses to run fn once
|
||||
# this exceeds the threshold.
|
||||
Column("attempt_count", Integer, nullable=False, server_default="0"),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
# Added in ``0006_idempotency_lease``. Per-invocation random id
|
||||
# written by the wrapper at lease claim; refreshed every 30 s by a
|
||||
# heartbeat thread. Other workers seeing a fresh lease (NOT NULL
|
||||
# AND ``lease_expires_at > now()``) refuse to run the task body.
|
||||
Column("lease_owner_id", Text),
|
||||
Column("lease_expires_at", DateTime(timezone=True)),
|
||||
)
|
||||
|
||||
webhook_dedup_table = Table(
|
||||
"webhook_dedup",
|
||||
metadata,
|
||||
Column("idempotency_key", Text, primary_key=True),
|
||||
Column("agent_id", UUID(as_uuid=True), nullable=False),
|
||||
Column("task_id", Text, nullable=False),
|
||||
Column("response_json", JSONB),
|
||||
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
# Three-phase tool-call journal: ``proposed → executed → confirmed``
|
||||
# (terminal: ``failed``; ``compensated`` is grandfathered in the CHECK
|
||||
# from migration 0004 but no code writes it). The reconciler sweeps
|
||||
# stuck rows via the partial ``tool_call_attempts_pending_ts_idx``.
|
||||
tool_call_attempts_table = Table(
|
||||
"tool_call_attempts",
|
||||
metadata,
|
||||
Column("call_id", Text, primary_key=True),
|
||||
# ON DELETE SET NULL preserves the journal even after the parent
|
||||
# message is deleted — useful for cost-attribution / compliance.
|
||||
Column(
|
||||
"message_id",
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("conversation_messages.id", ondelete="SET NULL"),
|
||||
),
|
||||
Column("tool_id", UUID(as_uuid=True)),
|
||||
Column("tool_name", Text, nullable=False),
|
||||
Column("action_name", Text, nullable=False),
|
||||
Column("arguments", JSONB, nullable=False),
|
||||
Column("result", JSONB),
|
||||
Column("error", Text),
|
||||
# CHECK (status IN ('proposed', 'executed', 'confirmed',
|
||||
# 'compensated', 'failed')) lives in 0004.
|
||||
Column("status", Text, nullable=False),
|
||||
Column("attempted_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
)
|
||||
|
||||
# Per-source ingest checkpoint. Heartbeat thread bumps ``last_updated``
|
||||
# every 30s while a worker embeds; the reconciler escalates when it
|
||||
# stops ticking.
|
||||
ingest_chunk_progress_table = Table(
|
||||
"ingest_chunk_progress",
|
||||
metadata,
|
||||
Column("source_id", UUID(as_uuid=True), primary_key=True),
|
||||
Column("total_chunks", Integer, nullable=False),
|
||||
Column("embedded_chunks", Integer, nullable=False, server_default="0"),
|
||||
Column("last_index", Integer, nullable=False, server_default="-1"),
|
||||
Column("last_updated", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||
# Added in ``0005_ingest_attempt_id``. Stamped from
|
||||
# ``self.request.id`` (Celery's stable task id) so a retry of the
|
||||
# same task resumes from the checkpoint, but a separate invocation
|
||||
# (manual reingest, scheduled sync) resets to a clean re-index.
|
||||
Column("attempt_id", Text),
|
||||
)
|
||||
|
||||
|
||||
workflows_table = Table(
|
||||
"workflows",
|
||||
metadata,
|
||||
|
||||
@@ -25,6 +25,7 @@ from typing import Any, Optional
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
|
||||
_UPDATABLE_SCALARS = {
|
||||
@@ -36,7 +37,7 @@ _UPDATABLE_JSONB = {"session_data", "token_info"}
|
||||
def _jsonb(value: Any) -> Any:
    """Serialise ``value`` to a JSON string for a JSONB parameter.

    ``None`` is passed through unchanged. Everything else is dumped with
    ``PGNativeJSONEncoder`` rather than ``default=str``, so only the
    encoder-based return is kept.
    """
    if value is None:
        return None
    return json.dumps(value, cls=PGNativeJSONEncoder)
|
||||
|
||||
|
||||
class ConnectorSessionsRepository:
|
||||
|
||||
@@ -15,6 +15,7 @@ Covers every operation the legacy Mongo code performs on
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
@@ -22,6 +23,23 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.models import conversations_table, conversation_messages_table
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
|
||||
class MessageUpdateOutcome(str, Enum):
    """Discriminated result of ``update_message_by_id``.

    Distinguishes the row-actually-updated case from the row-already-at-
    the-requested-terminal-state case so an abort handler can journal
    ``end`` instead of ``error`` when the normal-path finalize already
    flipped the row to ``complete``.
    """

    # The row was actually updated by this call.
    UPDATED = "updated"
    # The row was already in the ``complete`` terminal state.
    ALREADY_COMPLETE = "already_complete"
    # Presumably the row was already in the ``failed`` terminal state —
    # TODO confirm against ``update_message_by_id``.
    ALREADY_FAILED = "already_failed"
    # Presumably no row matched the given message id — TODO confirm.
    NOT_FOUND = "not_found"
    # Presumably the request was rejected as invalid — TODO confirm.
    INVALID = "invalid"
|
||||
|
||||
|
||||
def _message_row_to_dict(row) -> dict:
|
||||
@@ -57,8 +75,8 @@ class ConversationsRepository:
|
||||
- Already-UUID-shaped → returned as-is.
|
||||
- Otherwise treated as a Mongo ObjectId and looked up via
|
||||
``agents.legacy_mongo_id``. Returns ``None`` if no PG row
|
||||
exists yet (e.g. the agent was created before Phase 1
|
||||
backfill).
|
||||
exists yet (e.g. the agent was created before the backfill
|
||||
ran).
|
||||
"""
|
||||
if not agent_id_raw:
|
||||
return None
|
||||
@@ -452,7 +470,7 @@ class ConversationsRepository:
|
||||
),
|
||||
{
|
||||
"id": conversation_id,
|
||||
"point": json.dumps(point, default=str),
|
||||
"point": json.dumps(point, cls=PGNativeJSONEncoder),
|
||||
"max_points": int(max_points),
|
||||
},
|
||||
)
|
||||
@@ -632,6 +650,233 @@ class ConversationsRepository:
|
||||
result = self._conn.execute(text(sql), params)
|
||||
return result.rowcount > 0
|
||||
|
||||
def reserve_message(
    self,
    conversation_id: str,
    *,
    prompt: str,
    placeholder_response: str,
    request_id: str | None = None,
    status: str = "pending",
    attachments: list[str] | None = None,
    model_id: str | None = None,
    metadata: dict | None = None,
) -> dict:
    """Insert a placeholder assistant message ahead of the LLM call.

    The parent ``conversations`` row is selected ``FOR UPDATE`` first, the
    new message is placed at ``MAX(position) + 1`` (``0`` when the
    conversation has no messages yet), and the conversation's
    ``updated_at`` is bumped after the insert. ``attachments`` go through
    ``_resolve_attachment_refs`` and are stored only when that returns a
    non-empty value; ``metadata`` defaults to an empty dict.

    Returns:
        The inserted row, converted by ``_message_row_to_dict``.
    """
    conv_param = {"conv_id": conversation_id}

    # Take the row lock before reading MAX(position); presumably this
    # serialises concurrent reservations on one conversation — confirm.
    self._conn.execute(
        text(
            "SELECT id FROM conversations "
            "WHERE id = CAST(:conv_id AS uuid) FOR UPDATE"
        ),
        conv_param,
    )
    position_query = text(
        "SELECT COALESCE(MAX(position), -1) + 1 AS next_pos "
        "FROM conversation_messages "
        "WHERE conversation_id = CAST(:conv_id AS uuid)"
    )
    position = self._conn.execute(position_query, conv_param).scalar()

    row_values = {
        "conversation_id": conversation_id,
        "position": position,
        "prompt": prompt,
        "response": placeholder_response,
        "status": status,
        "request_id": request_id,
        "model_id": model_id,
        "message_metadata": metadata or {},
    }
    attachment_ids = (
        self._resolve_attachment_refs([str(ref) for ref in attachments])
        if attachments
        else None
    )
    if attachment_ids:
        row_values["attachments"] = attachment_ids

    insert_stmt = pg_insert(conversation_messages_table).values(**row_values)
    insert_stmt = insert_stmt.returning(conversation_messages_table)
    inserted = self._conn.execute(insert_stmt)

    # Touch the parent conversation so it reflects the new activity.
    self._conn.execute(
        text(
            "UPDATE conversations SET updated_at = now() "
            "WHERE id = CAST(:id AS uuid)"
        ),
        {"id": conversation_id},
    )
    return _message_row_to_dict(inserted.fetchone())
|
||||
|
||||
def update_message_by_id(
    self, message_id: str, fields: dict,
    *, only_if_non_terminal: bool = False,
) -> MessageUpdateOutcome:
    """Update specific fields on a message identified by its UUID.

    ``metadata`` is merged into the existing JSONB rather than
    overwritten, so a reconciler-set ``reconcile_attempts`` survives
    a successful late finalize. When ``only_if_non_terminal`` is
    True, the update is gated so a late finalize cannot retract a
    reconciler-set ``failed`` (or a prior ``complete``).

    Args:
        message_id: UUID of the ``conversation_messages`` row.
        fields: Mapping of API field name to new value. Keys outside
            the allow-list below are silently dropped.
        only_if_non_terminal: When True, rows whose status is
            ``complete`` or ``failed`` are left untouched.

    Returns:
        ``INVALID`` when ``message_id`` is not UUID-shaped or no
        allowed field was supplied; ``UPDATED`` when the row was
        modified; ``ALREADY_COMPLETE`` / ``ALREADY_FAILED`` when the
        gate skipped a terminal row; ``NOT_FOUND`` otherwise. This lets
        the abort handler journal ``end`` when the normal-path finalize
        already ran.
    """
    if not looks_like_uuid(message_id):
        return MessageUpdateOutcome.INVALID
    allowed = {
        "prompt", "response", "thought", "sources", "tool_calls",
        "attachments", "model_id", "metadata", "timestamp", "status",
        "request_id", "feedback", "feedback_timestamp",
    }
    filtered = {k: v for k, v in fields.items() if k in allowed}
    if not filtered:
        return MessageUpdateOutcome.INVALID

    # API name -> column name, for the one field where they differ.
    api_to_col = {"metadata": "message_metadata"}

    def _jsonb_param(val: Any) -> str:
        # Strings are treated as already-serialised JSON and passed
        # through. Everything else uses the encoder this module already
        # applies to its other JSONB parameters, so all JSONB writes here
        # serialise the same way (presumably it covers values plain
        # ``json.dumps`` rejects — confirm in
        # ``application.storage.db.serialization``).
        if isinstance(val, str):
            return val
        return json.dumps(val, cls=PGNativeJSONEncoder)

    set_parts = []
    params: dict = {"id": message_id}
    for key, val in filtered.items():
        col = api_to_col.get(key, key)
        if key == "metadata":
            if val is None:
                set_parts.append(f"{col} = NULL")
            else:
                # Shallow JSONB merge: incoming keys win, others survive.
                set_parts.append(
                    f"{col} = COALESCE({col}, '{{}}'::jsonb) "
                    f"|| CAST(:{col} AS jsonb)"
                )
                params[col] = _jsonb_param(val)
        elif key in ("sources", "tool_calls", "feedback"):
            set_parts.append(f"{col} = CAST(:{col} AS jsonb)")
            params[col] = None if val is None else _jsonb_param(val)
        elif key == "attachments":
            set_parts.append(f"{col} = CAST(:{col} AS uuid[])")
            params[col] = self._resolve_attachment_refs(
                [str(a) for a in val] if val else [],
            )
        else:
            set_parts.append(f"{col} = :{col}")
            params[col] = val

    set_parts.append("updated_at = now()")
    update_where = ["id = CAST(:id AS uuid)"]
    if only_if_non_terminal:
        update_where.append("status NOT IN ('complete', 'failed')")
    # Single-statement attempt + prior-status probe. Both CTEs see
    # the same MVCC snapshot, so ``prior.status`` reflects the row
    # state before the UPDATE — exactly what we need to tell
    # ``ALREADY_COMPLETE`` apart from ``ALREADY_FAILED`` apart from
    # ``NOT_FOUND`` without a follow-up SELECT.
    sql = (
        "WITH attempted AS ("
        f"  UPDATE conversation_messages SET {', '.join(set_parts)} "
        f"  WHERE {' AND '.join(update_where)} "
        "  RETURNING 1 AS updated"
        "), "
        "prior AS ("
        "  SELECT status FROM conversation_messages "
        "  WHERE id = CAST(:id AS uuid)"
        ") "
        "SELECT (SELECT updated FROM attempted) AS updated, "
        "       (SELECT status FROM prior) AS prior_status"
    )
    row = self._conn.execute(text(sql), params).fetchone()
    if row is None:
        return MessageUpdateOutcome.NOT_FOUND
    updated, prior_status = row[0], row[1]
    if updated:
        return MessageUpdateOutcome.UPDATED
    if prior_status is None:
        return MessageUpdateOutcome.NOT_FOUND
    if prior_status == "complete":
        return MessageUpdateOutcome.ALREADY_COMPLETE
    if prior_status == "failed":
        return MessageUpdateOutcome.ALREADY_FAILED
    # ``only_if_non_terminal=False`` always updates an existing row,
    # so reaching here means the gate excluded it for some status
    # the terminal set doesn't cover — treat as "not found" rather
    # than inventing a new variant.
    return MessageUpdateOutcome.NOT_FOUND
|
||||
|
||||
def update_message_status(
    self, message_id: str, status: str,
) -> bool:
    """Set ``status`` on a message that has not reached a terminal state.

    Intended for cheap transitions such as pending → streaming. Rows
    already marked ``complete`` or ``failed`` are excluded by the WHERE
    clause, so a late streaming chunk cannot silently undo a
    reconciler-set ``failed``.

    Returns:
        True when a row was updated; False when ``message_id`` is not
        UUID-shaped or no non-terminal row matched.
    """
    if looks_like_uuid(message_id):
        stmt = text(
            "UPDATE conversation_messages SET status = :status, "
            "updated_at = now() "
            "WHERE id = CAST(:id AS uuid) "
            "AND status NOT IN ('complete', 'failed')"
        )
        outcome = self._conn.execute(
            stmt, {"id": message_id, "status": status}
        )
        return outcome.rowcount > 0
    return False
|
||||
|
||||
def heartbeat_message(self, message_id: str) -> bool:
    """Write ``clock_timestamp()`` into ``message_metadata.last_heartbeat_at``.

    Only the metadata key is touched: ``timestamp`` (creation time, used
    for history sort) and ``status`` (the WAL marker) are left alone. The
    reconciler's staleness check uses ``GREATEST(timestamp,
    last_heartbeat_at)``, so the stamp keeps a long-running stream looking
    fresh. Terminal rows are skipped so a late heartbeat cannot silently
    retract a reconciler-set ``failed``.

    Returns:
        True when a row was stamped; False when ``message_id`` is not
        UUID-shaped or no non-terminal row matched.
    """
    if looks_like_uuid(message_id):
        stamp_stmt = text(
            """
            UPDATE conversation_messages
            SET message_metadata = jsonb_set(
                COALESCE(message_metadata, '{}'::jsonb),
                '{last_heartbeat_at}',
                to_jsonb(clock_timestamp())
            )
            WHERE id = CAST(:id AS uuid)
              AND status NOT IN ('complete', 'failed')
            """
        )
        outcome = self._conn.execute(stamp_stmt, {"id": message_id})
        return outcome.rowcount > 0
    return False
|
||||
|
||||
def confirm_executed_tool_calls(self, message_id: str) -> int:
    """Promote the message's ``executed`` tool-call attempts to ``confirmed``.

    Returns:
        The number of ``tool_call_attempts`` rows updated; ``0`` when
        ``message_id`` is not UUID-shaped or nothing matched.
    """
    if looks_like_uuid(message_id):
        confirm_stmt = text(
            "UPDATE tool_call_attempts SET status = 'confirmed', "
            "updated_at = now() "
            "WHERE message_id = CAST(:mid AS uuid) AND status = 'executed'"
        )
        outcome = self._conn.execute(confirm_stmt, {"mid": message_id})
        return outcome.rowcount or 0
    return 0
|
||||
|
||||
def truncate_after(self, conversation_id: str, keep_up_to: int) -> int:
|
||||
"""Delete messages with position > keep_up_to.
|
||||
|
||||
|
||||
346
application/storage/db/repositories/idempotency.py
Normal file
346
application/storage/db/repositories/idempotency.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""Repository for ``webhook_dedup`` and ``task_dedup``; 24h TTL enforced at read."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
# 24h TTL is the contract surfaced in the upload/webhook docstrings; the
|
||||
# read filters and the stale-row replacement predicate must agree, or the
|
||||
# upsert can fall into a window where the row is "fresh" to the writer
|
||||
# but "expired" to the reader (or vice versa). Keep one constant so any
|
||||
# future change moves both directions in lockstep.
|
||||
DEDUP_TTL_INTERVAL = "24 hours"
|
||||
|
||||
|
||||
def _jsonb(value: Any) -> Any:
    """Serialise ``value`` to a JSON string with ``PGNativeJSONEncoder``.

    ``None`` is returned unchanged rather than being encoded.
    """
    return None if value is None else json.dumps(value, cls=PGNativeJSONEncoder)
|
||||
|
||||
|
||||
class IdempotencyRepository:
    """Access layer for the ``webhook_dedup`` and ``task_dedup`` tables.

    Every method issues SQL on the ``Connection`` supplied to the
    constructor; no method here commits or rolls back.
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    @staticmethod
    def _as_dict(row) -> Optional[dict]:
        """Convert a fetched row with ``row_to_dict``; ``None`` stays ``None``."""
        return None if row is None else row_to_dict(row)

    # --- webhook_dedup -----------------------------------------------------

    def get_webhook(self, key: str) -> Optional[dict]:
        """Fetch the webhook dedup row for ``key`` while it is inside the TTL.

        Rows whose ``created_at`` is not newer than ``now()`` minus
        ``DEDUP_TTL_INTERVAL`` are filtered out and read as ``None``.
        """
        lookup = text(
            """
            SELECT * FROM webhook_dedup
            WHERE idempotency_key = :key
              AND created_at > now() - CAST(:ttl AS interval)
            """
        )
        bind = {"key": key, "ttl": DEDUP_TTL_INTERVAL}
        return self._as_dict(self._conn.execute(lookup, bind).fetchone())

    def record_webhook(
        self,
        key: str,
        agent_id: str,
        task_id: str,
        response_json: dict,
    ) -> Optional[dict]:
        """Store a webhook dedup row, recycling an expired one in place.

        On a key conflict the existing row is overwritten only when its
        ``created_at`` is past the TTL — atomic stale-row recycling under
        the row lock. A conflict with a row still inside the TTL produces
        no row, so ``None`` means another writer raced and won; the caller
        resolves that through :meth:`get_webhook`.
        """
        upsert = text(
            """
            INSERT INTO webhook_dedup (
                idempotency_key, agent_id, task_id, response_json
            )
            VALUES (
                :key, CAST(:agent_id AS uuid), :task_id,
                CAST(:response_json AS jsonb)
            )
            ON CONFLICT (idempotency_key) DO UPDATE
            SET agent_id = EXCLUDED.agent_id,
                task_id = EXCLUDED.task_id,
                response_json = EXCLUDED.response_json,
                created_at = now()
            WHERE webhook_dedup.created_at
                  <= now() - CAST(:ttl AS interval)
            RETURNING *
            """
        )
        bind = {
            "key": key,
            "agent_id": agent_id,
            "task_id": task_id,
            "response_json": _jsonb(response_json),
            "ttl": DEDUP_TTL_INTERVAL,
        }
        return self._as_dict(self._conn.execute(upsert, bind).fetchone())

    # --- task_dedup --------------------------------------------------------

    def get_task(self, key: str) -> Optional[dict]:
        """Fetch the task dedup row for ``key`` while it is inside the TTL."""
        lookup = text(
            """
            SELECT * FROM task_dedup
            WHERE idempotency_key = :key
              AND created_at > now() - CAST(:ttl AS interval)
            """
        )
        bind = {"key": key, "ttl": DEDUP_TTL_INTERVAL}
        return self._as_dict(self._conn.execute(lookup, bind).fetchone())

    def claim_task(
        self,
        key: str,
        task_name: str,
        task_id: str,
    ) -> Optional[dict]:
        """Claim ``key`` for this task; ``None`` means another writer won.

        The HTTP entry must call this *before* ``.delay()`` so only the
        winner enqueues the Celery task.

        An existing row is replaced in two situations:

        - **status='failed'**: the worker's poison-loop guard or the
          reconciler's stuck-pending sweep finalised the earlier attempt
          as failed, and both intend a same-key retry to re-run (see the
          ``run_reconciliation`` Q5 docstring). Blocking for 24 h would
          silently undo that intent.
        - **created_at past TTL**: a stale claim, whatever its status, is
          no longer a meaningful dedup signal.

        Inside the TTL, ``status='completed'`` rows keep blocking (the
        cached-success contract) and ``status='pending'`` rows keep
        blocking so concurrent same-key requests collapse onto the
        in-flight task. A replacement resets ``result_json``, ``status``,
        ``attempt_count`` and ``created_at`` to fresh-claim values.
        """
        upsert = text(
            """
            INSERT INTO task_dedup (
                idempotency_key, task_name, task_id, result_json, status
            )
            VALUES (
                :key, :task_name, :task_id, NULL, 'pending'
            )
            ON CONFLICT (idempotency_key) DO UPDATE
            SET task_name = EXCLUDED.task_name,
                task_id = EXCLUDED.task_id,
                result_json = NULL,
                status = 'pending',
                attempt_count = 0,
                created_at = now()
            WHERE task_dedup.status = 'failed'
               OR task_dedup.created_at
                  <= now() - CAST(:ttl AS interval)
            RETURNING *
            """
        )
        bind = {
            "key": key,
            "task_name": task_name,
            "task_id": task_id,
            "ttl": DEDUP_TTL_INTERVAL,
        }
        return self._as_dict(self._conn.execute(upsert, bind).fetchone())

    def try_claim_lease(
        self,
        key: str,
        task_name: str,
        task_id: str,
        owner_id: str,
        ttl_seconds: int = 60,
    ) -> Optional[int]:
        """Atomically take the running lease for ``key``.

        Returns the resulting ``attempt_count`` when this caller now holds
        the lease — either a fresh insert (count ``1``) or a takeover of a
        row whose lease was empty or expired (count incremented). Returns
        ``None`` when the conflicting row is ``completed`` or still has a
        live lease.

        Bumping ``attempt_count`` on the conflict path is what lets the
        poison-loop guard in :func:`with_idempotency` fire after
        :data:`MAX_TASK_ATTEMPTS` reclaims. ``completed`` rows are left
        alone on purpose — :func:`_lookup_completed` is the cache
        short-circuit and runs before this. ``clock_timestamp()`` is used
        so a refresh within the same transaction really moves the expiry
        forward (``now()`` is frozen at transaction start).
        """
        upsert = text(
            """
            INSERT INTO task_dedup (
                idempotency_key, task_name, task_id, status, attempt_count,
                lease_owner_id, lease_expires_at
            ) VALUES (
                :key, :task_name, :task_id, 'pending', 1,
                :owner,
                clock_timestamp() + make_interval(secs => :ttl)
            )
            ON CONFLICT (idempotency_key) DO UPDATE
            SET attempt_count = task_dedup.attempt_count + 1,
                task_name = EXCLUDED.task_name,
                lease_owner_id = EXCLUDED.lease_owner_id,
                lease_expires_at = EXCLUDED.lease_expires_at
            WHERE task_dedup.status <> 'completed'
              AND (task_dedup.lease_expires_at IS NULL
                   OR task_dedup.lease_expires_at <= clock_timestamp())
            RETURNING attempt_count
            """
        )
        bind = {
            "key": key,
            "task_name": task_name,
            "task_id": task_id,
            "owner": owner_id,
            "ttl": int(ttl_seconds),
        }
        claimed = self._conn.execute(upsert, bind).fetchone()
        if claimed is None:
            return None
        return int(claimed[0])

    def refresh_lease(
        self,
        key: str,
        owner_id: str,
        ttl_seconds: int = 60,
    ) -> bool:
        """Extend ``lease_expires_at`` while this caller still owns the lease.

        Returns False when no ``pending`` row for ``key`` is owned by
        ``owner_id`` — the lease was taken by another worker after expiry,
        or the row was finalised. The heartbeat thread logs that as a
        warning but does not abort the running task: at-most-one-worker is
        bounded by ``ttl_seconds``, and a brief overlap window cannot be
        avoided in that case.
        """
        extend = text(
            """
            UPDATE task_dedup
            SET lease_expires_at =
                clock_timestamp() + make_interval(secs => :ttl)
            WHERE idempotency_key = :key
              AND lease_owner_id = :owner
              AND status = 'pending'
            """
        )
        bind = {"key": key, "owner": owner_id, "ttl": int(ttl_seconds)}
        return self._conn.execute(extend, bind).rowcount > 0

    def release_lease(self, key: str, owner_id: str) -> bool:
        """Null out the lease columns for a ``pending`` row this caller owns.

        Used on the wrapper's exception path so Celery's ``autoretry_for``
        does not have to wait the full ``ttl_seconds`` before the next
        worker can re-claim. When a different worker has since taken the
        lease nothing matches and False is returned — a benign outcome.
        """
        release = text(
            """
            UPDATE task_dedup
            SET lease_owner_id = NULL,
                lease_expires_at = NULL
            WHERE idempotency_key = :key
              AND lease_owner_id = :owner
              AND status = 'pending'
            """
        )
        bind = {"key": key, "owner": owner_id}
        return self._conn.execute(release, bind).rowcount > 0

    def finalize_task(
        self,
        key: str,
        *,
        result_json: Optional[dict],
        status: str,
    ) -> bool:
        """Move a ``pending`` row to ``completed`` or ``failed``.

        Stores ``result_json`` and clears the lease columns so a stale
        ``lease_expires_at`` does not show up in operator dashboards. A
        row that is already terminal is not matched, which preserves the
        first writer's outcome on a crash + retry.

        Raises:
            ValueError: ``status`` is neither ``"completed"`` nor
                ``"failed"``.

        Returns:
            True when a pending row was finalised, else False.
        """
        if status not in ("completed", "failed"):
            raise ValueError(f"finalize_task: invalid status {status!r}")
        finalize = text(
            """
            UPDATE task_dedup
            SET status = :status,
                result_json = CAST(:result_json AS jsonb),
                lease_owner_id = NULL,
                lease_expires_at = NULL
            WHERE idempotency_key = :key
              AND status = 'pending'
            """
        )
        bind = {
            "key": key,
            "status": status,
            "result_json": _jsonb(result_json),
        }
        return self._conn.execute(finalize, bind).rowcount > 0

    # --- housekeeping ------------------------------------------------------

    def cleanup_expired(self) -> dict:
        """Delete past-TTL rows from both dedup tables and report the counts.

        The TTL-aware upserts already stop stale rows from blocking new
        work, so this is housekeeping only: it bounds table growth and
        keeps test isolation cheap. It is safe alongside other writers — a
        same-key INSERT racing the DELETE either finds no row (a fresh
        insert) or finds a fresh row re-created between the DELETE and the
        conflict check; neither is wrong.

        Returns:
            ``{"task_dedup_deleted": int, "webhook_dedup_deleted": int}``.
        """
        ttl_bind = {"ttl": DEDUP_TTL_INTERVAL}
        purged_tasks = self._conn.execute(
            text(
                """
                DELETE FROM task_dedup
                WHERE created_at <= now() - CAST(:ttl AS interval)
                """
            ),
            ttl_bind,
        ).rowcount
        purged_webhooks = self._conn.execute(
            text(
                """
                DELETE FROM webhook_dedup
                WHERE created_at <= now() - CAST(:ttl AS interval)
                """
            ),
            ttl_bind,
        ).rowcount
        return {
            "task_dedup_deleted": int(purged_tasks or 0),
            "webhook_dedup_deleted": int(purged_webhooks or 0),
        }
|
||||
|
||||
127
application/storage/db/repositories/ingest_chunk_progress.py
Normal file
127
application/storage/db/repositories/ingest_chunk_progress.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Repository for ``ingest_chunk_progress``; per-source resume + heartbeat."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class IngestChunkProgressRepository:
    """Accessors for the ``ingest_chunk_progress`` table (one row per source)."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def init_progress(
        self,
        source_id: str,
        total_chunks: int,
        attempt_id: Optional[str] = None,
    ) -> dict:
        """Create or refresh the progress row for ``source_id``.

        A new row starts at ``embedded_chunks = 0`` and ``last_index = -1``.
        When a row already exists, the stored ``attempt_id`` decides what
        happens:

        - **Same attempt**: this is a Celery autoretry of the same task,
          so ``last_index`` and ``embedded_chunks`` are kept and the embed
          loop resumes from the checkpoint. Only ``total_chunks`` and
          ``last_updated`` are refreshed.
        - **Different attempt** (manual reingest, scheduled sync, or any
          caller that passed no ``attempt_id``): ``last_index`` goes back
          to ``-1`` and ``embedded_chunks`` to ``0`` so the loop restarts
          at chunk 0 and a finished checkpoint from an earlier run cannot
          poison the index.

        The comparison is ``IS NOT DISTINCT FROM``, which treats two NULLs
        as equal: a legacy row with NULL ``attempt_id`` resumes for
        another NULL caller (e.g. test fixtures) and is reset as soon as
        a real ``attempt_id`` arrives.

        Returns:
            The resulting row as a dict (via ``row_to_dict``).
        """
        upsert = text(
            """
            INSERT INTO ingest_chunk_progress (
                source_id, total_chunks, embedded_chunks, last_index,
                attempt_id, last_updated
            )
            VALUES (
                CAST(:source_id AS uuid), :total_chunks, 0, -1,
                :attempt_id, now()
            )
            ON CONFLICT (source_id) DO UPDATE SET
                total_chunks = EXCLUDED.total_chunks,
                last_updated = now(),
                last_index = CASE
                    WHEN ingest_chunk_progress.attempt_id
                         IS NOT DISTINCT FROM EXCLUDED.attempt_id
                    THEN ingest_chunk_progress.last_index
                    ELSE -1
                END,
                embedded_chunks = CASE
                    WHEN ingest_chunk_progress.attempt_id
                         IS NOT DISTINCT FROM EXCLUDED.attempt_id
                    THEN ingest_chunk_progress.embedded_chunks
                    ELSE 0
                END,
                attempt_id = EXCLUDED.attempt_id
            RETURNING *
            """
        )
        bind = {
            "source_id": str(source_id),
            "total_chunks": int(total_chunks),
            "attempt_id": attempt_id,
        }
        upserted = self._conn.execute(upsert, bind)
        return row_to_dict(upserted.fetchone())

    def record_chunk(
        self, source_id: str, last_index: int, embedded_chunks: int
    ) -> None:
        """Store the checkpoint reached after a chunk has been embedded.

        Overwrites ``last_index`` and ``embedded_chunks`` and refreshes
        ``last_updated`` on the row for ``source_id``.
        """
        checkpoint = text(
            """
            UPDATE ingest_chunk_progress
            SET last_index = :last_index,
                embedded_chunks = :embedded_chunks,
                last_updated = now()
            WHERE source_id = CAST(:source_id AS uuid)
            """
        )
        bind = {
            "source_id": str(source_id),
            "last_index": int(last_index),
            "embedded_chunks": int(embedded_chunks),
        }
        self._conn.execute(checkpoint, bind)

    def get_progress(self, source_id: str) -> Optional[dict]:
        """Fetch the progress row for ``source_id``, or ``None`` when absent."""
        lookup = text(
            "SELECT * FROM ingest_chunk_progress "
            "WHERE source_id = CAST(:source_id AS uuid)"
        )
        found = self._conn.execute(
            lookup, {"source_id": str(source_id)}
        ).fetchone()
        if found is None:
            return None
        return row_to_dict(found)

    def bump_heartbeat(self, source_id: str) -> None:
        """Set ``last_updated`` to ``now()`` so the reconciler sees the row as alive."""
        touch = text(
            """
            UPDATE ingest_chunk_progress
            SET last_updated = now()
            WHERE source_id = CAST(:source_id AS uuid)
            """
        )
        self._conn.execute(touch, {"source_id": str(source_id)})
|
||||
248
application/storage/db/repositories/message_events.py
Normal file
248
application/storage/db/repositories/message_events.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Repository for ``message_events`` — the chat-stream snapshot journal.
|
||||
|
||||
``record`` / ``bulk_record`` write per-yield events; ``read_after``
|
||||
replays rows past a cursor for reconnect snapshots. Composite PK
|
||||
``(message_id, sequence_no)`` raises ``IntegrityError`` on duplicates.
|
||||
Callers must use short-lived per-call transactions — long-lived
|
||||
transactions hide writes from reconnecting clients on a separate
|
||||
connection and turn one bad row into ``InFailedSqlTransaction``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageEventsRepository:
|
||||
"""Read/write helpers for ``message_events``."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
    """Keep ``conn``; every repository method executes its SQL on it."""
    self._conn = conn
|
||||
|
||||
def record(
    self,
    message_id: str,
    sequence_no: int,
    event_type: str,
    payload: Optional[Any] = None,
) -> None:
    """Insert one event row into the journal.

    At this raw repository layer a non-``None`` ``payload`` is stored as
    given (lists, scalars and dicts all round-trip through JSONB), while
    ``None`` is replaced by an empty object so the column's NOT NULL
    invariant holds. The streaming-route wrapper
    ``application/streaming/message_journal.py::record_event`` narrows
    this to dicts only, because the live and replay paths rebuild
    non-dict payloads differently. Direct callers (cleanup tasks, tests,
    ad-hoc consumers) keep the wider JSONB-compatible surface.

    Raises:
        ValueError: ``event_type`` is empty.
        sqlalchemy.exc.IntegrityError: duplicate
            ``(message_id, sequence_no)``.
        sqlalchemy.exc.DataError: ``message_id`` is not a valid UUID.

    The two database errors abort the surrounding transaction, so callers
    must run inside a short-lived per-event session (see the module
    docstring).
    """
    if not event_type:
        raise ValueError("event_type must be a non-empty string")

    body = {} if payload is None else payload
    insert_event = text(
        """
        INSERT INTO message_events (
            message_id, sequence_no, event_type, payload
        ) VALUES (
            CAST(:message_id AS uuid), :sequence_no, :event_type,
            CAST(:payload AS jsonb)
        )
        """
    )
    bind = {
        "message_id": str(message_id),
        "sequence_no": int(sequence_no),
        "event_type": event_type,
        "payload": json.dumps(body),
    }
    self._conn.execute(insert_event, bind)
|
||||
|
||||
def bulk_record(
    self,
    message_id: str,
    events: list[tuple[int, str, dict]],
) -> None:
    """Insert several events for ``message_id`` with one ``execute`` call.

    ``events`` holds ``(sequence_no, event_type, payload)`` tuples; a
    ``None`` payload is stored as an empty object. An empty list returns
    immediately without touching the connection. The parameter list is
    handed to SQLAlchemy as an ``executemany``; an ``IntegrityError`` on
    any row aborts the whole batch.

    Caller contract: after an ``IntegrityError`` do NOT retry this method
    with the same batch. Fall back to per-row ``record()`` calls, each in
    its own short-lived session, so one colliding sequence number does not
    drop the rest. ``BatchedJournalWriter`` in
    ``application/streaming/message_journal.py`` is the canonical
    consumer.
    """
    if not events:
        return

    mid = str(message_id)
    batch = []
    for seq, kind, body in events:
        batch.append(
            {
                "message_id": mid,
                "sequence_no": int(seq),
                "event_type": kind,
                "payload": json.dumps({} if body is None else body),
            }
        )
    insert_event = text(
        """
        INSERT INTO message_events (
            message_id, sequence_no, event_type, payload
        ) VALUES (
            CAST(:message_id AS uuid), :sequence_no, :event_type,
            CAST(:payload AS jsonb)
        )
        """
    )
    self._conn.execute(insert_event, batch)
|
||||
|
||||
def read_after(
    self,
    message_id: str,
    last_sequence_no: Optional[int] = None,
) -> list[dict]:
    """Fetch the events of ``message_id`` that follow ``last_sequence_no``.

    With ``last_sequence_no=None`` the cursor is ``-1``, so the full
    backlog is returned. Rows come back in ascending ``sequence_no``
    order. The composite PK is the snapshot read index for this scan:
    Postgres typically chooses an in-order index range scan, though the
    planner may choose bitmap+sort for highly mixed data; the ordering of
    the result is the same either way.

    A ``list`` is returned rather than a generator so the underlying
    ``Result`` is fully drained before the caller can issue another
    query on the same connection.
    """
    if last_sequence_no is None:
        cursor = -1
    else:
        cursor = int(last_sequence_no)

    replay_query = text(
        """
        SELECT message_id, sequence_no, event_type, payload, created_at
        FROM message_events
        WHERE message_id = CAST(:message_id AS uuid)
          AND sequence_no > :cursor
        ORDER BY sequence_no ASC
        """
    )
    fetched = self._conn.execute(
        replay_query, {"message_id": str(message_id), "cursor": cursor}
    ).fetchall()
    return list(map(row_to_dict, fetched))
|
||||
|
||||
def cleanup_older_than(self, ttl_days: int) -> int:
    """Purge journal rows created more than ``ttl_days`` days ago.

    Reconnect-replay only matters for streams a client could plausibly
    still be waiting on, so older rows are dead weight. The
    ``message_events_created_at_idx`` btree keeps the range delete a
    cheap index scan even on large tables.

    Raises:
        ValueError: ``ttl_days`` is zero or negative.

    Returns:
        The number of rows deleted.
    """
    if ttl_days <= 0:
        raise ValueError("ttl_days must be positive")

    purge = text(
        """
        DELETE FROM message_events
        WHERE created_at < now() - make_interval(days => :ttl_days)
        """
    )
    outcome = self._conn.execute(purge, {"ttl_days": int(ttl_days)})
    return int(outcome.rowcount or 0)
|
||||
|
||||
def reconstruct_partial(self, message_id: str) -> dict:
    """Rebuild the partial response state of a message from its journal.

    Events are read in ascending ``sequence_no`` order. ``answer`` and
    ``thought`` events contribute string chunks that are concatenated;
    ``source`` and ``tool_calls`` events carry a full list each time, so
    the last well-formed one wins. Rows whose payload is not a dict, or
    whose value has the wrong type, are ignored.

    Returns:
        Dict with keys ``response``, ``thought``, ``sources`` and
        ``tool_calls``.
    """
    events = self._conn.execute(
        text(
            """
            SELECT sequence_no, event_type, payload
            FROM message_events
            WHERE message_id = CAST(:message_id AS uuid)
            ORDER BY sequence_no ASC
            """
        ),
        {"message_id": str(message_id)},
    ).fetchall()

    # Streamed text: every chunk is kept and joined at the end.
    chunked: dict[str, list[str]] = {"answer": [], "thought": []}
    # Snapshot events: each emit holds the complete list, latest wins.
    snapshots: dict[str, list] = {"source": [], "tool_calls": []}

    for event in events:
        body = event.payload
        if not isinstance(body, dict):
            continue
        kind = event.event_type
        # For all four event types the payload key equals the event type.
        if kind in chunked:
            piece = body.get(kind)
            if isinstance(piece, str):
                chunked[kind].append(piece)
        elif kind in snapshots:
            latest = body.get(kind)
            if isinstance(latest, list):
                snapshots[kind] = latest

    return {
        "response": "".join(chunked["answer"]),
        "thought": "".join(chunked["thought"]),
        "sources": snapshots["source"],
        "tool_calls": snapshots["tool_calls"],
    }
|
||||
|
||||
def latest_sequence_no(self, message_id: str) -> Optional[int]:
    """Return the highest ``sequence_no`` journaled for ``message_id``.

    Used by the route to seed the per-stream allocator on retry /
    process restart so a re-run continues numbering instead of
    trampling earlier entries with duplicate sequence_no.

    Returns:
        The maximum as an ``int``, or ``None`` when no events exist.
    """
    # ``MAX`` without GROUP BY always yields one row whose value is NULL
    # for an empty journal, so the value (not row presence) is checked.
    top = self._conn.execute(
        text(
            """
            SELECT MAX(sequence_no) AS s
            FROM message_events
            WHERE message_id = CAST(:message_id AS uuid)
            """
        ),
        {"message_id": str(message_id)},
    ).first()
    if top is None or top[0] is None:
        return None
    return int(top[0])
|
||||
@@ -7,6 +7,11 @@ Mirrors the continuation service's three operations on
|
||||
- load_state → find_one by (conversation_id, user_id)
|
||||
- delete_state → delete_one by (conversation_id, user_id)
|
||||
|
||||
Adds ``mark_resuming`` so a resumed run can claim a row without
|
||||
deleting it; a separate ``revert_stale_resuming`` flips abandoned
|
||||
``resuming`` rows back to ``pending`` so a crashed worker doesn't
|
||||
strand the user.
|
||||
|
||||
Plus a cleanup method for the Celery beat task that replaces Mongo's
|
||||
TTL index.
|
||||
"""
|
||||
@@ -20,6 +25,7 @@ from typing import Optional
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
PENDING_STATE_TTL_SECONDS = 30 * 60 # 1800 seconds
|
||||
|
||||
@@ -71,19 +77,24 @@ class PendingToolStateRepository:
|
||||
agent_config = EXCLUDED.agent_config,
|
||||
client_tools = EXCLUDED.client_tools,
|
||||
created_at = EXCLUDED.created_at,
|
||||
expires_at = EXCLUDED.expires_at
|
||||
expires_at = EXCLUDED.expires_at,
|
||||
status = 'pending',
|
||||
resumed_at = NULL
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{
|
||||
"conv_id": conversation_id,
|
||||
"user_id": user_id,
|
||||
"messages": json.dumps(messages),
|
||||
"pending": json.dumps(pending_tool_calls),
|
||||
"tools_dict": json.dumps(tools_dict),
|
||||
"schemas": json.dumps(tool_schemas),
|
||||
"agent_config": json.dumps(agent_config),
|
||||
"client_tools": json.dumps(client_tools) if client_tools is not None else None,
|
||||
"messages": json.dumps(messages, cls=PGNativeJSONEncoder),
|
||||
"pending": json.dumps(pending_tool_calls, cls=PGNativeJSONEncoder),
|
||||
"tools_dict": json.dumps(tools_dict, cls=PGNativeJSONEncoder),
|
||||
"schemas": json.dumps(tool_schemas, cls=PGNativeJSONEncoder),
|
||||
"agent_config": json.dumps(agent_config, cls=PGNativeJSONEncoder),
|
||||
"client_tools": (
|
||||
json.dumps(client_tools, cls=PGNativeJSONEncoder)
|
||||
if client_tools is not None else None
|
||||
),
|
||||
"created_at": now,
|
||||
"expires_at": expires,
|
||||
},
|
||||
@@ -113,6 +124,45 @@ class PendingToolStateRepository:
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def mark_resuming(self, conversation_id: str, user_id: str) -> bool:
    """Claim a ``pending`` row for a resumed run.

    Sets ``status = 'resuming'`` and stamps ``resumed_at`` with
    ``clock_timestamp()`` on the row matching
    ``(conversation_id, user_id)``, but only while it is still
    ``pending``.

    Returns:
        True when a row was updated, False otherwise.
    """
    claim = text(
        """
        UPDATE pending_tool_state
        SET status = 'resuming', resumed_at = clock_timestamp()
        WHERE conversation_id = CAST(:conv_id AS uuid)
          AND user_id = :user_id
          AND status = 'pending'
        """
    )
    binds = {"conv_id": conversation_id, "user_id": user_id}
    outcome = self._conn.execute(claim, binds)
    return outcome.rowcount > 0
|
||||
|
||||
def revert_stale_resuming(
    self,
    grace_seconds: int = 600,
    ttl_extension_seconds: int = PENDING_STATE_TTL_SECONDS,
) -> int:
    """Return abandoned ``resuming`` rows to ``pending`` and extend their TTL.

    A row qualifies when its ``resumed_at`` is more than
    ``grace_seconds`` before ``clock_timestamp()``. Qualifying rows get
    ``resumed_at`` cleared and ``expires_at`` reset to
    ``clock_timestamp() + ttl_extension_seconds``.

    Returns:
        The driver-reported number of rows updated.
    """
    revert = text(
        """
        UPDATE pending_tool_state
        SET status = 'pending',
            resumed_at = NULL,
            expires_at = clock_timestamp()
                         + make_interval(secs => :ttl)
        WHERE status = 'resuming'
          AND resumed_at
              < clock_timestamp() - make_interval(secs => :grace)
        """
    )
    binds = {"grace": grace_seconds, "ttl": ttl_extension_seconds}
    outcome = self._conn.execute(revert, binds)
    return outcome.rowcount
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""Delete rows where ``expires_at < now()``.
|
||||
|
||||
|
||||
273
application/storage/db/repositories/reconciliation.py
Normal file
273
application/storage/db/repositories/reconciliation.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""Repository for reconciliation sweeps over stuck durability rows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
|
||||
class ReconciliationRepository:
    """Sweeps and terminal writes for the reconciler beat task.

    Every ``find_*`` method selects with ``FOR UPDATE ... SKIP LOCKED``:
    matching rows are row-locked by the statement and rows already
    locked by another transaction are skipped instead of waited on.
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def _select_dicts(self, sql: str, params: dict) -> list[dict]:
        """Execute ``sql`` and return every row converted by ``row_to_dict``."""
        result = self._conn.execute(text(sql), params)
        return [row_to_dict(r) for r in result.fetchall()]

    def _apply(self, sql: str, params: dict) -> bool:
        """Execute a write and report whether at least one row was affected."""
        result = self._conn.execute(text(sql), params)
        return result.rowcount > 0

    def find_and_lock_stuck_messages(
        self,
        *,
        age_minutes: int = 5,
        limit: int = 100,
        resuming_grace_minutes: int = 10,
    ) -> list[dict]:
        """Lock stuck pending/streaming messages, skipping live resumes.

        Staleness rides on the **later of** ``cm.timestamp`` (creation)
        and ``message_metadata.last_heartbeat_at`` (route heartbeat). An
        in-flight stream that re-stamps the heartbeat each minute stays
        out of the sweep; reconciler-side writes deliberately don't
        touch either column so the per-row attempts counter advances
        across ticks. Liveness exemption covers both ``pending`` (paused
        waiting for resume) and ``resuming`` (actively executing)
        ``pending_tool_state`` rows so a paused message survives until
        the PT row's own TTL retires it.

        Args:
            age_minutes: Minimum age of both the creation timestamp and
                the heartbeat (when present) for a message to qualify.
            limit: Maximum number of rows to lock and return.
            resuming_grace_minutes: How recently a ``resuming``
                ``pending_tool_state`` row must have been stamped to
                exempt its conversation. Defaults to 10, the value that
                was previously hard-coded in the query.

        Returns:
            Locked rows as dicts, oldest ``timestamp`` first.
        """
        return self._select_dicts(
            """
            SELECT cm.id, cm.conversation_id, cm.user_id, cm.timestamp,
                   cm.message_metadata
            FROM conversation_messages cm
            WHERE cm.status IN ('pending', 'streaming')
              AND cm.timestamp < now() - make_interval(mins => :age)
              AND COALESCE(
                    (cm.message_metadata->>'last_heartbeat_at')::timestamptz,
                    cm.timestamp
                  ) < now() - make_interval(mins => :age)
              AND NOT EXISTS (
                    SELECT 1
                    FROM pending_tool_state pts
                    WHERE pts.conversation_id = cm.conversation_id
                      AND (
                            (pts.status = 'pending'
                             AND pts.expires_at > now())
                            OR
                            (pts.status = 'resuming'
                             AND pts.resumed_at
                                 > now() - make_interval(mins => :resuming_grace))
                          )
                  )
            ORDER BY cm.timestamp ASC
            LIMIT :limit
            FOR UPDATE OF cm SKIP LOCKED
            """,
            {
                "age": age_minutes,
                "limit": limit,
                "resuming_grace": resuming_grace_minutes,
            },
        )

    def find_and_lock_proposed_tool_calls(
        self, *, age_minutes: int = 5, limit: int = 100,
    ) -> list[dict]:
        """Lock tool_call_attempts that never advanced past ``proposed``.

        Rows qualify when ``attempted_at`` is older than ``age_minutes``;
        at most ``limit`` are returned, oldest first.
        """
        return self._select_dicts(
            """
            SELECT call_id, message_id, tool_id, tool_name, action_name,
                   arguments, attempted_at, updated_at
            FROM tool_call_attempts
            WHERE status = 'proposed'
              AND attempted_at < now() - make_interval(mins => :age)
            ORDER BY attempted_at ASC
            LIMIT :limit
            FOR UPDATE SKIP LOCKED
            """,
            {"age": age_minutes, "limit": limit},
        )

    def find_and_lock_executed_tool_calls(
        self, *, age_minutes: int = 15, limit: int = 100,
    ) -> list[dict]:
        """Lock tool_call_attempts stuck in ``executed`` past confirm window.

        Rows qualify when ``updated_at`` is older than ``age_minutes``;
        at most ``limit`` are returned, least recently updated first.
        """
        return self._select_dicts(
            """
            SELECT call_id, message_id, tool_id, tool_name, action_name,
                   arguments, result, attempted_at, updated_at
            FROM tool_call_attempts
            WHERE status = 'executed'
              AND updated_at < now() - make_interval(mins => :age)
            ORDER BY updated_at ASC
            LIMIT :limit
            FOR UPDATE SKIP LOCKED
            """,
            {"age": age_minutes, "limit": limit},
        )

    def find_and_lock_stalled_ingests(
        self, *, age_minutes: int = 30, limit: int = 100,
    ) -> list[dict]:
        """Lock ingest checkpoints whose heartbeat hasn't ticked recently.

        Only unfinished checkpoints (``embedded_chunks < total_chunks``)
        whose ``last_updated`` is older than ``age_minutes`` qualify.
        """
        return self._select_dicts(
            """
            SELECT source_id, total_chunks, embedded_chunks,
                   last_index, last_updated
            FROM ingest_chunk_progress
            WHERE last_updated < now() - make_interval(mins => :age)
              AND embedded_chunks < total_chunks
            ORDER BY last_updated ASC
            LIMIT :limit
            FOR UPDATE SKIP LOCKED
            """,
            {"age": age_minutes, "limit": limit},
        )

    def touch_ingest_progress(self, source_id: str) -> bool:
        """Bump ``last_updated`` so a once-stalled ingest re-enters the watch window.

        Returns True when a row for ``source_id`` was updated.
        """
        return self._apply(
            "UPDATE ingest_chunk_progress SET last_updated = now() "
            "WHERE source_id = CAST(:sid AS uuid)",
            {"sid": str(source_id)},
        )

    def increment_message_reconcile_attempts(self, message_id: str) -> int:
        """Bump ``message_metadata.reconcile_attempts`` and return the new count.

        A missing counter (or NULL metadata) is treated as 0 before the
        increment. Returns 0 when no message matches ``message_id``.
        """
        result = self._conn.execute(
            text(
                """
                UPDATE conversation_messages
                SET message_metadata = jsonb_set(
                    COALESCE(message_metadata, '{}'::jsonb),
                    '{reconcile_attempts}',
                    to_jsonb(
                        COALESCE(
                            (message_metadata->>'reconcile_attempts')::int,
                            0
                        ) + 1
                    )
                )
                WHERE id = CAST(:message_id AS uuid)
                RETURNING (message_metadata->>'reconcile_attempts')::int
                    AS new_count
                """
            ),
            {"message_id": message_id},
        )
        row = result.fetchone()
        return int(row[0]) if row is not None else 0

    def mark_message_failed(self, message_id: str, *, error: str) -> bool:
        """Flip a message to ``status='failed'`` and stash ``error`` in metadata.

        Returns True when a row for ``message_id`` was updated.
        """
        return self._apply(
            """
            UPDATE conversation_messages
            SET status = 'failed',
                message_metadata = jsonb_set(
                    COALESCE(message_metadata, '{}'::jsonb),
                    '{error}',
                    to_jsonb(CAST(:error AS text))
                )
            WHERE id = CAST(:message_id AS uuid)
            """,
            {"message_id": message_id, "error": error},
        )

    def mark_tool_call_failed(self, call_id: str, *, error: str) -> bool:
        """Flip a tool_call_attempts row to ``failed`` with ``error``.

        The update matches on ``call_id`` only; the row's current status
        is not checked. Returns True when a row was updated.
        """
        return self._apply(
            "UPDATE tool_call_attempts SET status = 'failed', "
            "error = :error WHERE call_id = :call_id",
            {"call_id": call_id, "error": error},
        )

    def find_stuck_idempotency_pending(
        self,
        *,
        max_attempts: int,
        lease_grace_seconds: int = 60,
        limit: int = 100,
    ) -> list[dict]:
        """Lock ``task_dedup`` rows abandoned past the lease + retry budget.

        A row is "stuck" when:

        - ``status='pending'`` (lease was claimed but never finalised)
        - ``lease_expires_at`` is past by at least ``lease_grace_seconds``
          (the heartbeat thread is gone — the lease isn't going to come
          back)
        - ``attempt_count >= max_attempts`` (the poison-loop guard
          should already have escalated this; if it hasn't, the wrapper
          died before getting there)

        These rows would otherwise sit in ``pending`` until the 24 h
        TTL aged them out, blocking same-key retries via
        ``_lookup_completed`` returning None for the whole window.

        All three numeric arguments are coerced with ``int()`` before
        binding. Rows are returned oldest ``created_at`` first.
        """
        return self._select_dicts(
            """
            SELECT idempotency_key, task_name, task_id, attempt_count,
                   lease_owner_id, lease_expires_at, created_at
            FROM task_dedup
            WHERE status = 'pending'
              AND lease_expires_at IS NOT NULL
              AND lease_expires_at
                  < now() - make_interval(secs => :grace)
              AND attempt_count >= :max_attempts
            ORDER BY created_at ASC
            LIMIT :limit
            FOR UPDATE SKIP LOCKED
            """,
            {
                "max_attempts": int(max_attempts),
                "grace": int(lease_grace_seconds),
                "limit": int(limit),
            },
        )

    def mark_idempotency_pending_failed(
        self, key: str, *, error: str,
    ) -> bool:
        """Promote a stuck pending ``task_dedup`` row to ``failed``.

        Stores ``{"success": false, "error": ..., "reconciled": true}``
        as ``result_json`` and clears the lease columns. Only a row that
        is still ``pending`` is touched. Returns True when one was
        updated.
        """
        # Kept as function-scope imports, matching the original block.
        import json

        from application.storage.db.serialization import PGNativeJSONEncoder

        outcome = {"success": False, "error": error, "reconciled": True}
        return self._apply(
            """
            UPDATE task_dedup
            SET status = 'failed',
                result_json = CAST(:result AS jsonb),
                lease_owner_id = NULL,
                lease_expires_at = NULL
            WHERE idempotency_key = :key
              AND status = 'pending'
            """,
            {
                "key": key,
                "result": json.dumps(outcome, cls=PGNativeJSONEncoder),
            },
        )
|
||||
@@ -13,6 +13,8 @@ import json
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
|
||||
@@ -52,7 +54,7 @@ class StackLogsRepository:
|
||||
"user_id": user_id,
|
||||
"api_key": api_key,
|
||||
"query": query,
|
||||
"stacks": json.dumps(stacks or []),
|
||||
"stacks": json.dumps(stacks or [], cls=PGNativeJSONEncoder),
|
||||
"timestamp": timestamp,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -31,6 +31,8 @@ class TokenUsageRepository:
|
||||
agent_id: Optional[str] = None,
|
||||
prompt_tokens: int = 0,
|
||||
generated_tokens: int = 0,
|
||||
source: str = "agent_stream",
|
||||
request_id: Optional[str] = None,
|
||||
timestamp: Optional[datetime] = None,
|
||||
) -> None:
|
||||
# Attribution guard: the ``token_usage_attribution_chk`` CHECK
|
||||
@@ -54,12 +56,16 @@ class TokenUsageRepository:
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO token_usage (user_id, api_key, agent_id, prompt_tokens, generated_tokens, timestamp)
|
||||
INSERT INTO token_usage (
|
||||
user_id, api_key, agent_id,
|
||||
prompt_tokens, generated_tokens,
|
||||
source, request_id, timestamp
|
||||
)
|
||||
VALUES (
|
||||
:user_id, :api_key,
|
||||
CAST(:agent_id AS uuid),
|
||||
:prompt_tokens, :generated_tokens,
|
||||
COALESCE(:timestamp, now())
|
||||
:source, :request_id, COALESCE(:timestamp, now())
|
||||
)
|
||||
"""
|
||||
),
|
||||
@@ -69,6 +75,8 @@ class TokenUsageRepository:
|
||||
"agent_id": agent_id_uuid,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"generated_tokens": generated_tokens,
|
||||
"source": source,
|
||||
"request_id": request_id,
|
||||
"timestamp": timestamp,
|
||||
},
|
||||
)
|
||||
@@ -173,8 +181,22 @@ class TokenUsageRepository:
|
||||
user_id: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Count of token_usage rows in the given time range (for request limiting)."""
|
||||
clauses = ["timestamp >= :start", "timestamp <= :end"]
|
||||
"""Count user-initiated requests in the given time range.
|
||||
|
||||
A request = one ``agent_stream`` invocation. Multi-tool agent
|
||||
runs produce multiple rows (one per LLM call) tagged with the
|
||||
same ``request_id``; we DISTINCT on that to count the request
|
||||
once. Pre-migration rows have ``request_id=NULL`` and are
|
||||
counted one-per-row via the second branch (back-compat).
|
||||
Side-channel sources (``title`` / ``compression`` /
|
||||
``rag_condense`` / ``fallback``) are excluded — they aren't
|
||||
user-initiated and shouldn't tick the request limit.
|
||||
"""
|
||||
clauses = [
|
||||
"timestamp >= :start",
|
||||
"timestamp <= :end",
|
||||
"source = 'agent_stream'",
|
||||
]
|
||||
params: dict = {"start": start, "end": end}
|
||||
if user_id is not None:
|
||||
clauses.append("user_id = :user_id")
|
||||
@@ -184,7 +206,15 @@ class TokenUsageRepository:
|
||||
params["api_key"] = api_key
|
||||
where = " AND ".join(clauses)
|
||||
result = self._conn.execute(
|
||||
text(f"SELECT COUNT(*) FROM token_usage WHERE {where}"),
|
||||
text(
|
||||
f"""
|
||||
SELECT
|
||||
COUNT(DISTINCT request_id) FILTER (WHERE request_id IS NOT NULL)
|
||||
+ COUNT(*) FILTER (WHERE request_id IS NULL)
|
||||
FROM token_usage
|
||||
WHERE {where}
|
||||
"""
|
||||
),
|
||||
params,
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
144
application/storage/db/repositories/tool_call_attempts.py
Normal file
144
application/storage/db/repositories/tool_call_attempts.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Repository for ``tool_call_attempts``; executor's proposed/executed/failed writes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
|
||||
class ToolCallAttemptsRepository:
    """Writes the executor's proposed / executed / failed journal rows."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    @staticmethod
    def _arguments_json(arguments: Any) -> str:
        """Serialise tool arguments, storing ``{}`` when they are ``None``."""
        payload = {} if arguments is None else arguments
        return json.dumps(payload, cls=PGNativeJSONEncoder)

    @staticmethod
    def _result_json(result: Any, artifact_id: Optional[str]) -> str:
        """Wrap ``result`` (plus a truthy ``artifact_id``) and serialise it."""
        envelope: dict = {"result": result}
        if artifact_id:
            envelope["artifact_id"] = artifact_id
        return json.dumps(envelope, cls=PGNativeJSONEncoder)

    def record_proposed(
        self,
        call_id: str,
        tool_name: str,
        action_name: str,
        arguments: Any,
        *,
        tool_id: Optional[str] = None,
    ) -> bool:
        """Insert a ``proposed`` row before the tool executes.

        ``ON CONFLICT (call_id) DO NOTHING`` keeps an existing row
        untouched when the LLM emits a duplicate ``call_id``, instead of
        the insert raising ``IntegrityError``.

        Returns:
            True if a new row was created, False if one already existed.
        """
        binds = {
            "call_id": call_id,
            "tool_id": tool_id,
            "tool_name": tool_name,
            "action_name": action_name,
            "arguments": self._arguments_json(arguments),
        }
        outcome = self._conn.execute(
            text(
                """
                INSERT INTO tool_call_attempts
                    (call_id, tool_id, tool_name, action_name, arguments, status)
                VALUES
                    (:call_id, CAST(:tool_id AS uuid), :tool_name,
                     :action_name, CAST(:arguments AS jsonb), 'proposed')
                ON CONFLICT (call_id) DO NOTHING
                """
            ),
            binds,
        )
        return outcome.rowcount > 0

    def upsert_executed(
        self,
        call_id: str,
        tool_name: str,
        action_name: str,
        arguments: Any,
        result: Any,
        *,
        tool_id: Optional[str] = None,
        message_id: Optional[str] = None,
        artifact_id: Optional[str] = None,
    ) -> None:
        """Insert OR upgrade a row to ``executed``.

        Used as a fallback when ``record_proposed`` failed (DB outage)
        and the tool ran anyway — preserves the journal so the
        reconciler can still see the attempt. On conflict only
        ``status``, ``result`` and ``message_id`` are rewritten, and an
        existing ``message_id`` is kept when none is supplied.
        """
        binds = {
            "call_id": call_id,
            "tool_id": tool_id,
            "tool_name": tool_name,
            "action_name": action_name,
            "arguments": self._arguments_json(arguments),
            "result": self._result_json(result, artifact_id),
            "message_id": message_id,
        }
        self._conn.execute(
            text(
                """
                INSERT INTO tool_call_attempts
                    (call_id, tool_id, tool_name, action_name, arguments,
                     result, message_id, status)
                VALUES
                    (:call_id, CAST(:tool_id AS uuid), :tool_name,
                     :action_name, CAST(:arguments AS jsonb),
                     CAST(:result AS jsonb), CAST(:message_id AS uuid),
                     'executed')
                ON CONFLICT (call_id) DO UPDATE
                    SET status = 'executed',
                        result = EXCLUDED.result,
                        message_id = COALESCE(EXCLUDED.message_id, tool_call_attempts.message_id)
                """
            ),
            binds,
        )

    def mark_executed(
        self,
        call_id: str,
        result: Any,
        *,
        message_id: Optional[str] = None,
        artifact_id: Optional[str] = None,
    ) -> bool:
        """Set the row for ``call_id`` to ``executed`` with the tool result.

        ``artifact_id`` (when present) is stored alongside ``result`` in
        the JSONB as audit data — the reconciler reads it for diagnostic
        alerts when escalating stuck rows to ``failed``. ``message_id``
        is only written when supplied. The update matches on ``call_id``
        alone; the row's current status is not checked.

        Returns:
            True when a row was updated.
        """
        assignments = ["status = 'executed'", "result = CAST(:result AS jsonb)"]
        binds: dict[str, Any] = {
            "call_id": call_id,
            "result": self._result_json(result, artifact_id),
        }
        if message_id is not None:
            assignments.append("message_id = CAST(:message_id AS uuid)")
            binds["message_id"] = message_id
        statement = (
            "UPDATE tool_call_attempts SET "
            + ", ".join(assignments)
            + " WHERE call_id = :call_id"
        )
        outcome = self._conn.execute(text(statement), binds)
        return outcome.rowcount > 0

    def mark_failed(self, call_id: str, error: str) -> bool:
        """Set the row for ``call_id`` to ``failed`` with the exception text.

        The update matches on ``call_id`` alone; the row's current
        status is not checked. Returns True when a row was updated.
        """
        statement = text(
            "UPDATE tool_call_attempts SET status = 'failed', error = :error "
            "WHERE call_id = :call_id"
        )
        outcome = self._conn.execute(
            statement, {"call_id": call_id, "error": error}
        )
        return outcome.rowcount > 0
|
||||
@@ -20,6 +20,7 @@ from typing import Optional
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
|
||||
class UserLogsRepository:
|
||||
@@ -46,7 +47,7 @@ class UserLogsRepository:
|
||||
{
|
||||
"user_id": user_id,
|
||||
"endpoint": endpoint,
|
||||
"data": json.dumps(data, default=str) if data is not None else None,
|
||||
"data": json.dumps(data, cls=PGNativeJSONEncoder) if data is not None else None,
|
||||
"timestamp": timestamp,
|
||||
},
|
||||
)
|
||||
|
||||
93
application/storage/db/serialization.py
Normal file
93
application/storage/db/serialization.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""JSON-safe coercion for PG-native Python types.
|
||||
|
||||
Postgres (via psycopg) returns native Python types — ``uuid.UUID``,
|
||||
``datetime.datetime``/``datetime.date``, ``decimal.Decimal``, ``bytes``
|
||||
— that ``json.dumps`` rejects. This module is the single place those
|
||||
coercion rules live; everywhere else should call into it.
|
||||
|
||||
Two interfaces sharing the same scalar coercion rules (only ``coerce_pg_native`` also stringifies dict keys):
|
||||
|
||||
* :func:`coerce_pg_native` — recursive walk returning a JSON-safe copy.
|
||||
Use when you need to inspect the dict yourself or pass it to a
|
||||
serializer that doesn't accept a custom encoder (e.g. SQLAlchemy
|
||||
parameter binding for a JSONB column).
|
||||
* :class:`PGNativeJSONEncoder` — ``JSONEncoder`` subclass. Use as
|
||||
``json.dumps(obj, cls=PGNativeJSONEncoder)`` for serialise-once flows
|
||||
where the extra recursive walk is wasted work.
|
||||
|
||||
Coercion rules:
|
||||
|
||||
* ``UUID`` → canonical hyphenated string (``str(uuid)``).
|
||||
* ``datetime`` / ``date`` → ISO 8601 string.
|
||||
* ``Decimal`` → numeric string (preserves precision; ``float()`` would not).
|
||||
* ``bytes`` → base64 string. Lossless and universally JSON-safe;
|
||||
prior code used UTF-8 with ``errors="replace"`` which silently
|
||||
corrupted binary payloads (e.g. Gemini's ``thought_signature``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import json
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
def _coerce_scalar(obj: Any) -> Any:
    """Convert one PG-native value to its JSON-safe form.

    ``bytes`` becomes base64 text, ``date``/``datetime`` become ISO 8601
    strings, and ``UUID``/``Decimal`` become ``str(obj)``. Anything else
    is returned as the very same object, which callers use to detect
    "no rule applied".
    """
    if isinstance(obj, bytes):
        return base64.b64encode(obj).decode("ascii")
    if isinstance(obj, (date, datetime)):
        return obj.isoformat()
    if isinstance(obj, (UUID, Decimal)):
        return str(obj)
    return obj
|
||||
|
||||
|
||||
def coerce_pg_native(obj: Any) -> Any:
    """Return a JSON-safe copy of ``obj`` with PG-native values coerced.

    Lists and tuples both come back as lists (JSON has no tuple type),
    dicts come back with every key passed through ``str()``, and every
    other value goes through ``_coerce_scalar``, which leaves
    unrecognised types unchanged.
    """
    if isinstance(obj, (list, tuple)):
        return [coerce_pg_native(item) for item in obj]
    if not isinstance(obj, dict):
        return _coerce_scalar(obj)
    return {str(key): coerce_pg_native(val) for key, val in obj.items()}
|
||||
|
||||
|
||||
def decode_base64_bytes(value: Any) -> Any:
    """Reverse ``coerce_pg_native``'s bytes-to-base64 step.

    Useful at egress points that need the original bytes back (e.g.
    sending Gemini's ``thought_signature`` to the SDK on resume).

    Decoding is strict (``validate=True``): a string containing any
    character outside the base64 alphabet, or with incorrect padding, is
    returned unchanged, as is any non-ASCII string. Strictness cannot
    tell base64 from ordinary text that happens to be well-formed
    base64, so ``"abcd"`` decodes to ``b"i\\xb7\\x1d"`` and ``""`` to
    ``b""``. Call this only on fields known to hold encoded bytes.

    Args:
        value: Candidate value; non-``str`` inputs pass through untouched.

    Returns:
        The decoded ``bytes`` when ``value`` is a well-formed base64
        string, otherwise ``value`` itself.
    """
    if not isinstance(value, str):
        return value
    try:
        return base64.b64decode(value.encode("ascii"), validate=True)
    except (binascii.Error, ValueError):
        # binascii.Error: bad alphabet/padding. ValueError also covers the
        # UnicodeEncodeError raised by ``encode("ascii")`` on non-ASCII text.
        return value
|
||||
|
||||
|
||||
class PGNativeJSONEncoder(json.JSONEncoder):
    """``JSONEncoder`` covering UUID / datetime / date / Decimal / bytes.

    Use as ``json.dumps(obj, cls=PGNativeJSONEncoder)``. Values are
    coerced with the same rules as :func:`coerce_pg_native`, but lazily
    during encoding instead of via an eager walk. ``default`` is only
    consulted for values, so dict keys are not stringified here.
    """

    def default(self, obj: Any) -> Any:
        """Return the coerced form of ``obj`` or defer to the base class."""
        replacement = _coerce_scalar(obj)
        # ``_coerce_scalar`` hands back the same object when no rule
        # applies; only then let the base class raise its TypeError.
        if replacement is not obj:
            return replacement
        return super().default(obj)
|
||||
23
application/storage/db/source_ids.py
Normal file
23
application/storage/db/source_ids.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Deterministic source-id derivation for idempotent ingest.
|
||||
|
||||
DO NOT CHANGE the pinned UUID namespace — it backs cross-deploy
|
||||
idempotency keys.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
# DO NOT CHANGE. See module docstring.
DOCSGPT_INGEST_NAMESPACE = uuid.UUID("fa25d5d1-398b-46df-ac89-8d1c360b9bea")


def derive_source_id(idempotency_key) -> uuid.UUID:
    """Derive a source id from an optional idempotency key.

    A non-empty ``str`` key maps deterministically to
    ``uuid5(DOCSGPT_INGEST_NAMESPACE, key)``. Any other input (``None``,
    an empty string, a non-string) yields a fresh ``uuid4()`` so the
    caller always gets an id rather than a TypeError mid-route.
    """
    usable = isinstance(idempotency_key, str) and bool(idempotency_key)
    if not usable:
        return uuid.uuid4()
    return uuid.uuid5(DOCSGPT_INGEST_NAMESPACE, idempotency_key)
|
||||
0
application/streaming/__init__.py
Normal file
0
application/streaming/__init__.py
Normal file
126
application/streaming/broadcast_channel.py
Normal file
126
application/streaming/broadcast_channel.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Redis pub/sub Topic abstraction for SSE fan-out.
|
||||
|
||||
A Topic is a named channel for one-shot live event delivery. Canonical uses:
|
||||
|
||||
- ``user:{user_id}`` for per-user notifications
|
||||
- ``channel:{message_id}`` for per-chat-message streams
|
||||
|
||||
Subscription is race-free via ``on_subscribe``: the callback fires only
|
||||
after Redis acknowledges ``SUBSCRIBE``, so a publisher dispatched inside
|
||||
the callback cannot lose its first event to a not-yet-registered
|
||||
subscriber.
|
||||
|
||||
The subscribe iterator yields ``None`` on poll timeout so the caller can
|
||||
emit SSE keepalive comments without spawning a separate timer thread.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, Iterator, Optional
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Topic:
    """Named Redis pub/sub channel used for live event fan-out."""

    def __init__(self, name: str) -> None:
        self.name = name

    def publish(self, payload: str | bytes) -> int:
        """Send ``payload`` to the clients subscribed right now.

        Returns the receiver count reported by Redis, or ``0`` when the
        Redis client is unavailable or the publish call fails. Failures
        are logged and swallowed, so this never raises.
        """
        client = get_redis_instance()
        if client is None:
            logger.debug("Redis unavailable; dropping publish to %s", self.name)
            return 0
        try:
            receivers = client.publish(self.name, payload)
            return int(receivers)
        except Exception:
            logger.exception("Topic.publish failed for %s", self.name)
            return 0

    def subscribe(
        self,
        on_subscribe: Optional[Callable[[], None]] = None,
        poll_timeout: float = 1.0,
    ) -> Iterator[Optional[bytes]]:
        """Iterate over the topic, yielding payload bytes or ``None`` ticks.

        A ``None`` is yielded whenever a poll of ``poll_timeout`` seconds
        returns no frame, which lets the caller emit keepalives or check
        for cancellation. Each delivered message is yielded as ``bytes``
        (non-bytes data is stringified and UTF-8 encoded).

        ``on_subscribe`` is invoked synchronously, at most once, when the
        SUBSCRIBE acknowledgement frame is read. An exception raised by
        the callback is logged and does not stop the iteration.

        The generator ends without yielding when Redis is unavailable or
        the SUBSCRIBE call fails, and ends early when a poll raises. The
        ``finally`` clause unsubscribes (only if the ack was seen) and
        closes the pubsub object, including on client disconnect.
        """
        client = get_redis_instance()
        if client is None:
            logger.debug("Redis unavailable; subscribe to %s yielded nothing", self.name)
            return

        pubsub = None
        # True once Redis has acknowledged SUBSCRIBE; gates the explicit
        # unsubscribe during teardown.
        acked = False
        try:
            pubsub = client.pubsub()
            try:
                pubsub.subscribe(self.name)
            except Exception:
                # A failed SUBSCRIBE is handled like "Redis unavailable":
                # yield nothing and let the caller apply its own fallback.
                # Teardown below still closes the pubsub object.
                logger.exception("pubsub.subscribe failed for %s", self.name)
                return

            while True:
                try:
                    frame = pubsub.get_message(timeout=poll_timeout)
                except Exception:
                    logger.exception("pubsub.get_message failed for %s", self.name)
                    return

                if frame is None:
                    # Poll timed out with nothing to deliver: idle tick.
                    yield None
                    continue

                kind = frame.get("type")
                if kind == "subscribe":
                    if on_subscribe is not None and not acked:
                        try:
                            on_subscribe()
                        except Exception:
                            logger.exception(
                                "on_subscribe callback failed for %s", self.name
                            )
                    acked = True
                elif kind == "message":
                    body = frame.get("data")
                    if body is not None:
                        yield body if isinstance(body, bytes) else str(body).encode("utf-8")
                # Any other frame type is ignored.
        finally:
            if pubsub is not None:
                if acked:
                    try:
                        pubsub.unsubscribe(self.name)
                    except Exception:
                        logger.debug(
                            "pubsub unsubscribe error for %s",
                            self.name,
                            exc_info=True,
                        )
                try:
                    pubsub.close()
                except Exception:
                    logger.debug("pubsub close error for %s", self.name, exc_info=True)
|
||||
434
application/streaming/event_replay.py
Normal file
434
application/streaming/event_replay.py
Normal file
@@ -0,0 +1,434 @@
|
||||
"""Snapshot+tail iterator for chat-stream reconnect-after-disconnect.
|
||||
|
||||
Subscribe to ``channel:{message_id}``, snapshot ``message_events``
|
||||
rows past ``last_event_id`` inside the SUBSCRIBE-ack callback, flush
|
||||
snapshot, then tail live pub/sub (dedup'd by ``sequence_no``). See
|
||||
``docs/runbooks/sse-notifications.md``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.storage.db.repositories.message_events import (
|
||||
MessageEventsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.streaming.broadcast_channel import Topic
|
||||
from application.streaming.keys import message_topic_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_KEEPALIVE_SECONDS = 15.0
|
||||
DEFAULT_POLL_TIMEOUT_SECONDS = 1.0
|
||||
# When the live tail has no events and no terminal in snapshot, fall
|
||||
# back to checking ``conversation_messages`` directly. If the row has
|
||||
# already gone terminal (worker journaled ``end``/``error`` to the DB
|
||||
# but the matching pub/sub publish was lost, or the row was finalized
|
||||
# without a journal write at all) we surface a terminal event so the
|
||||
# client doesn't hang on keepalives. If the row is still non-terminal
|
||||
# but the producer heartbeat is older than ``PRODUCER_IDLE_SECONDS``
|
||||
# the producer is presumed dead (worker crash / recycle between chunks
|
||||
# and finalize) and we emit a terminal ``error`` so the UI can recover.
|
||||
DEFAULT_WATCHDOG_INTERVAL_SECONDS = 5.0
|
||||
# 1.5× the route's 60s heartbeat interval — long enough that a normal
|
||||
# heartbeat skew doesn't false-positive, short enough that a stuck
|
||||
# stream surfaces before the 5-minute reconciler sweep escalates.
|
||||
DEFAULT_PRODUCER_IDLE_SECONDS = 90.0
|
||||
|
||||
# WHATWG SSE accepts CRLF, CR, LF — split on any of them so a stray CR
|
||||
# can't smuggle a record boundary into the wire format.
|
||||
_SSE_LINE_SPLIT_PATTERN = re.compile(r"\r\n|\r|\n")
|
||||
|
||||
# Event types that mark the end of a chat answer. After delivering one
|
||||
# we close the reconnect stream — keeping the connection open past a
|
||||
# terminal event would leak both the client's reconnect promise and
|
||||
# the server's WSGI thread waiting on keepalives that the user no
|
||||
# longer cares about. The agent loop emits ``end`` for normal /
|
||||
# tool-paused completion and ``error`` for the catch-all failure path
|
||||
# (which doesn't get a trailing ``end``).
|
||||
_TERMINAL_EVENT_TYPES = frozenset({"end", "error"})
|
||||
|
||||
|
||||
def _payload_is_terminal(
    payload: object, event_type: Optional[str] = None
) -> bool:
    """Report whether an event closes the stream.

    An event is terminal when ``payload`` is a dict whose ``"type"`` value
    is in ``_TERMINAL_EVENT_TYPES``, or when ``event_type`` itself is.
    """
    terminal = _TERMINAL_EVENT_TYPES
    return (
        isinstance(payload, dict) and payload.get("type") in terminal
    ) or event_type in terminal
|
||||
|
||||
|
||||
def format_sse_event(payload: dict, sequence_no: int) -> str:
    """Render one journal event as an SSE record.

    The record is an ``id:`` line carrying ``sequence_no`` followed by one
    ``data:`` line per line of the JSON-serialised payload, terminated by
    a blank line. The serialised body is split on CRLF / CR / LF so that
    no line terminator can end up inside a single ``data:`` field.
    """
    serialised = json.dumps(payload)
    data_lines = [
        f"data: {part}" for part in _SSE_LINE_SPLIT_PATTERN.split(serialised)
    ]
    return "\n".join([f"id: {sequence_no}", *data_lines]) + "\n\n"
|
||||
|
||||
|
||||
def _check_producer_liveness(
    message_id: str, idle_seconds: float
) -> Optional[dict]:
    """Return a terminal SSE payload if the producer is done or gone.

    Runs a single query against ``conversation_messages`` and maps the
    result as follows:

    - no row for ``message_id``: ``error`` with code ``message_missing``.
    - ``status == 'complete'``: ``{"type": "end"}``.
    - ``status == 'failed'``: ``error`` with code ``producer_failed``,
      carrying ``message_metadata->>'error'`` when it is set.
    - any other status whose freshest of ``timestamp`` and
      ``message_metadata->>'last_heartbeat_at'`` is older than
      ``idle_seconds``: ``error`` with code ``producer_stale``.
    - otherwise: ``None``.

    A failing query is logged and also yields ``None``.
    """

    def _error(text: str, code: str) -> dict:
        # Shared shape of every synthetic error payload built here.
        return {
            "type": "error",
            "error": text,
            "code": code,
            "message_id": message_id,
        }

    query = sql_text(
        """
                    SELECT
                        status,
                        message_metadata->>'error' AS err,
                        GREATEST(
                            timestamp,
                            COALESCE(
                                (message_metadata->>'last_heartbeat_at')
                                    ::timestamptz,
                                timestamp
                            )
                        ) < now() - make_interval(secs => :idle_secs)
                            AS is_stale
                    FROM conversation_messages
                    WHERE id = CAST(:id AS uuid)
                    """
    )
    try:
        with db_readonly() as conn:
            record = conn.execute(
                query,
                {"id": message_id, "idle_secs": float(idle_seconds)},
            ).first()
    except Exception:
        logger.exception(
            "Watchdog liveness check failed for message_id=%s", message_id
        )
        return None

    if record is None:
        # The row vanished; report terminal so the client stops tailing a
        # message that no longer exists.
        return _error("Message no longer exists; please refresh.", "message_missing")

    status, stored_error, stale = record[0], record[1], bool(record[2])
    if status == "complete":
        return {"type": "end"}
    if status == "failed":
        return _error(
            stored_error or "Stream failed; please try again.", "producer_failed"
        )
    if stale:
        return _error(
            "Stream producer is no longer responding; please try again.",
            "producer_stale",
        )
    return None
|
||||
|
||||
|
||||
def build_message_event_stream(
    message_id: str,
    last_event_id: Optional[int] = None,
    *,
    keepalive_seconds: float = DEFAULT_KEEPALIVE_SECONDS,
    poll_timeout_seconds: float = DEFAULT_POLL_TIMEOUT_SECONDS,
    watchdog_interval_seconds: float = DEFAULT_WATCHDOG_INTERVAL_SECONDS,
    producer_idle_seconds: float = DEFAULT_PRODUCER_IDLE_SECONDS,
) -> Iterator[str]:
    """Yield SSE-formatted frames for one ``message_id`` reconnect stream.

    The first frame is always ``: connected``. Later frames are journal
    rows read past ``last_event_id`` (the snapshot), live pub/sub events
    with a ``sequence_no`` above the dedup floor, or ``: keepalive``
    comments.

    The generator finishes when any of these happens:

    - a terminal (``end`` / ``error``) event was delivered, either from
      the snapshot or from the live tail;
    - the snapshot read failed: a synthetic ``error`` frame with
      ``id: -1`` and code ``snapshot_failed`` is emitted first;
    - the watchdog's liveness check returned a terminal payload, which is
      emitted with ``id: -1``;
    - the subscription iterator ended (e.g. Redis unavailable): the
      snapshot is read directly if it has not been read yet, then flushed;
    - the client disconnects (``GeneratorExit``).

    Args:
        message_id: Message whose events are replayed and tailed.
        last_event_id: Client cursor; seeds the dedup floor and is passed
            to the journal read as ``last_sequence_no``.
        keepalive_seconds: Minimum gap between ``: keepalive`` comments,
            evaluated on idle ticks.
        poll_timeout_seconds: Passed to ``Topic.subscribe`` as
            ``poll_timeout``.
        watchdog_interval_seconds: Minimum gap between liveness checks on
            idle ticks; a negative value disables the watchdog.
        producer_idle_seconds: Staleness threshold handed to
            ``_check_producer_liveness``.
    """
    yield ": connected\n\n"

    # Replay buffer — populated inside ``_on_subscribe`` (or the
    # Redis-unavailable fallback below), drained on the first iteration
    # of the subscribe loop after the callback runs.
    replay_buffer: list[str] = []
    # Dedup floor: seeded with the client's cursor so an empty snapshot
    # still rejects re-published live events with seq <= last_event_id.
    # Advanced by snapshot rows AND by yielded live events, so any
    # republish past the snapshot ceiling is also dropped.
    max_replayed_seq: Optional[int] = last_event_id
    replay_done = False
    replay_failed = False
    # Set when a snapshot row carries a terminal ``end`` / ``error``
    # event. After flushing the buffer the generator returns; if we
    # kept tailing we'd loop on keepalives forever for a stream that
    # already finished.
    terminal_in_snapshot = False

    def _read_snapshot_into_buffer() -> None:
        # Reads journal rows past the cursor, formats each as an SSE
        # record into ``replay_buffer``, and updates the dedup floor and
        # the terminal / failure flags. Never raises.
        nonlocal max_replayed_seq, replay_failed, terminal_in_snapshot
        try:
            with db_readonly() as conn:
                rows = MessageEventsRepository(conn).read_after(
                    message_id, last_sequence_no=last_event_id
                )
                for row in rows:
                    seq = int(row["sequence_no"])
                    payload = row.get("payload")
                    if not isinstance(payload, dict):
                        # ``record_event`` rejects non-dict payloads at the
                        # write gate, so this can only be a legacy row from
                        # before that contract or a direct SQL insert. The
                        # original synthetic fallback (``{"type": event_type}``)
                        # used to ship a malformed envelope here — drop the
                        # row instead so a corrupt journal entry doesn't
                        # poison a reconnect.
                        logger.warning(
                            "Skipping non-dict payload from message_events: "
                            "message_id=%s seq=%s type=%s",
                            message_id,
                            seq,
                            row.get("event_type"),
                        )
                        continue
                    replay_buffer.append(format_sse_event(payload, seq))
                    if max_replayed_seq is None or seq > max_replayed_seq:
                        max_replayed_seq = seq
                    if _payload_is_terminal(payload, row.get("event_type")):
                        terminal_in_snapshot = True
        except Exception:
            logger.exception(
                "Snapshot read failed for message_id=%s last_event_id=%s",
                message_id,
                last_event_id,
            )
            replay_failed = True

    def _on_subscribe() -> None:
        # SUBSCRIBE has been acked — Postgres reads from this point
        # capture every row that's been committed. Pub/sub messages
        # published after this point are queued at the connection level
        # until the outer loop calls ``get_message`` again.
        nonlocal replay_done
        try:
            _read_snapshot_into_buffer()
        finally:
            # Flip even on failure so the outer loop continues to live
            # tail and the client doesn't hang waiting for a snapshot
            # flush that will never come.
            replay_done = True

    topic = Topic(message_topic_name(message_id))
    last_keepalive = time.monotonic()
    # Rate-limit the watchdog's DB hit. ``-inf`` makes the first idle
    # tick after replay_done fire immediately so a snapshot-already-
    # terminal-in-DB case is surfaced before any keepalive cadence.
    # Subsequent checks are gated by ``watchdog_interval_seconds``.
    last_watchdog_check = float("-inf")
    # Synthetic terminal events emitted by the watchdog use the same
    # ``sequence_no=-1`` convention as the snapshot-failure path so the
    # frontend's strict ``\d+`` cursor regex rejects them as a
    # ``Last-Event-ID`` for any future reconnect. The chosen
    # discriminator ensures a manual page refresh after a watchdog-fired
    # error doesn't loop on the same synthetic id.
    watchdog_synthetic_seq = -1

    try:
        for payload in topic.subscribe(
            on_subscribe=_on_subscribe,
            poll_timeout=poll_timeout_seconds,
        ):
            # Flush snapshot exactly once after the SUBSCRIBE callback
            # has run and produced a buffer.
            if replay_done and replay_buffer:
                for line in replay_buffer:
                    yield line
                replay_buffer.clear()
                if terminal_in_snapshot:
                    # The original stream already finished; tailing
                    # would just emit keepalives forever and pin both a
                    # client reconnect promise and a server WSGI thread.
                    return

            if replay_failed:
                # Snapshot read failed (DB blip / transient timeout). Emit a
                # terminal ``error`` event and return — the client only
                # reconnects after the original stream has already moved on,
                # so without a snapshot there's nothing live left to tail and
                # holding the connection open would just emit keepalives
                # until the proxy idle-timeout fires. ``code`` preserves the
                # snapshot-vs-agent-loop distinction so a future client can
                # opt into a refetch instead of a hard failure.
                yield format_sse_event(
                    {
                        "type": "error",
                        "error": "Stream replay failed; please refresh to load the latest state.",
                        "code": "snapshot_failed",
                        "message_id": message_id,
                    },
                    sequence_no=-1,
                )
                return

            now = time.monotonic()
            if payload is None:
                # Idle tick — check both keepalive and watchdog. The
                # watchdog only kicks in once the snapshot half has been
                # flushed (``replay_done``) so we don't race the
                # snapshot read on the first iteration.
                if (
                    replay_done
                    and watchdog_interval_seconds >= 0
                    and now - last_watchdog_check >= watchdog_interval_seconds
                ):
                    last_watchdog_check = now
                    terminal_payload = _check_producer_liveness(
                        message_id, producer_idle_seconds
                    )
                    if terminal_payload is not None:
                        yield format_sse_event(
                            terminal_payload,
                            sequence_no=watchdog_synthetic_seq,
                        )
                        return
                if now - last_keepalive >= keepalive_seconds:
                    yield ": keepalive\n\n"
                    last_keepalive = now
                continue

            envelope = _decode_pubsub_message(payload)
            if envelope is None:
                continue
            seq = envelope.get("sequence_no")
            inner = envelope.get("payload")
            # Reject envelopes without an integer sequence number (``bool``
            # is excluded explicitly because it subclasses ``int``) or
            # without a dict payload.
            if (
                not isinstance(seq, int)
                or isinstance(seq, bool)
                or not isinstance(inner, dict)
            ):
                continue
            if max_replayed_seq is not None and seq <= max_replayed_seq:
                # Snapshot already covered this id — drop the duplicate.
                continue
            yield format_sse_event(inner, seq)
            # Advance the dedup floor on the live path too, so a stale
            # republish of an already-yielded seq (process restart, retry
            # tool, etc.) is dropped on a later iteration.
            max_replayed_seq = seq
            last_keepalive = now
            if _payload_is_terminal(inner, envelope.get("event_type")):
                # Live tail just delivered the terminal event — close
                # out the reconnect stream so the client's drain
                # promise resolves and the WSGI thread is freed.
                return

        # The subscribe iterator finished (Redis unavailable,
        # ``pubsub.subscribe`` raised, or a poll failed). The snapshot
        # half is in Postgres and is still serviceable — read it directly
        # so a Redis-only outage doesn't cost the client their reconnect
        # backlog. Gate the read on ``replay_done``: if ``_on_subscribe``
        # already populated the buffer, re-reading would append the same
        # rows twice and double the answer chunks on the client (the
        # per-message reconnect dispatcher does not dedup by ``id``).
        if not replay_done:
            _read_snapshot_into_buffer()
            replay_done = True
        for line in replay_buffer:
            yield line
        replay_buffer.clear()
        if replay_failed:
            # Mirror the live-tail branch: emit a terminal ``error`` so
            # the frontend's existing end/error handling drives the UI
            # to a failed state instead of relying on the proxy timeout.
            yield format_sse_event(
                {
                    "type": "error",
                    "error": "Stream replay failed; please refresh to load the latest state.",
                    "code": "snapshot_failed",
                    "message_id": message_id,
                },
                sequence_no=-1,
            )
            return
        # Same close-on-terminal contract as the live-tail branch.
        # Without it a Redis-down + already-completed-stream client
        # would also hang on a never-ending generator.
        if terminal_in_snapshot:
            return
    except GeneratorExit:
        # Client disconnect — let the underlying ``Topic.subscribe``
        # ``finally`` block tear down its pubsub cleanly.
        return
|
||||
|
||||
|
||||
def _decode_pubsub_message(raw) -> Optional[dict]:
|
||||
"""Parse a ``Topic.publish`` payload to ``{sequence_no, payload, ...}``.
|
||||
|
||||
Returns ``None`` for malformed messages (drop silently — the
|
||||
journal is still authoritative on reconnect).
|
||||
"""
|
||||
try:
|
||||
if isinstance(raw, (bytes, bytearray)):
|
||||
text_value = raw.decode("utf-8")
|
||||
else:
|
||||
text_value = str(raw)
|
||||
envelope = json.loads(text_value)
|
||||
except Exception:
|
||||
return None
|
||||
if not isinstance(envelope, dict):
|
||||
return None
|
||||
return envelope
|
||||
|
||||
|
||||
def encode_pubsub_message(
    message_id: str,
    sequence_no: int,
    event_type: str,
    payload: dict,
) -> str:
    """Serialise one event into the JSON envelope published on the topic.

    The envelope has the keys ``message_id`` (coerced to ``str``),
    ``sequence_no`` (coerced to ``int``), ``event_type`` and ``payload``,
    in that order. It is the counterpart of ``_decode_pubsub_message``.
    """
    envelope = dict(
        message_id=str(message_id),
        sequence_no=int(sequence_no),
        event_type=event_type,
        payload=payload,
    )
    return json.dumps(envelope)
|
||||
19
application/streaming/keys.py
Normal file
19
application/streaming/keys.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Per-chat-message stream key derivations.
|
||||
|
||||
Single source of truth for the Redis pub/sub topic name and any
|
||||
auxiliary keys that the chat-stream snapshot+tail reconnect path
|
||||
shares between the writer (``complete_stream`` + journal) and the
|
||||
reader (``/api/messages/<id>/events`` reconnect endpoint).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def message_topic_name(message_id: str) -> str:
    """Return the Redis pub/sub channel name for one chat message.

    The name is ``channel:`` followed by ``message_id``.
    """
    return "channel:{}".format(message_id)
|
||||
400
application/streaming/message_journal.py
Normal file
400
application/streaming/message_journal.py
Normal file
@@ -0,0 +1,400 @@
|
||||
"""Per-yield journal write for the chat-stream snapshot+tail pattern.
|
||||
|
||||
``record_event`` inserts into ``message_events`` and publishes to
|
||||
``channel:{message_id}``. Both are best-effort; the INSERT commits
|
||||
before the publish so a fast reconnect sees the row. See
|
||||
``docs/runbooks/sse-notifications.md``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from application.storage.db.repositories.message_events import (
|
||||
MessageEventsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.streaming.broadcast_channel import Topic
|
||||
from application.streaming.event_replay import encode_pubsub_message
|
||||
from application.streaming.keys import message_topic_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tunables for ``BatchedJournalWriter``. A streaming answer emits ~100s
|
||||
# of ``answer`` chunks per response; without batching, that's one PG
|
||||
# transaction per yield in the WSGI thread. With these defaults, ~10x
|
||||
# fewer commits at the cost of a ≤100ms reconnect-visibility lag for
|
||||
# any event still sitting in the buffer.
|
||||
DEFAULT_BATCH_SIZE = 16
|
||||
DEFAULT_BATCH_INTERVAL_MS = 100
|
||||
|
||||
|
||||
def _strip_null_bytes(value: Any) -> Any:
|
||||
"""Recursively strip ``\\x00`` from string keys/values in ``value``.
|
||||
|
||||
Postgres JSONB rejects the NUL escape; an LLM emitting a stray NUL
|
||||
in a chunk would otherwise raise ``DataError`` at INSERT and the row
|
||||
would be lost from the journal (live stream proceeds, reconnect
|
||||
snapshot misses the chunk). Mirrors the strip already done in
|
||||
``parser/embedding_pipeline.py`` and
|
||||
``api/user/attachments/routes.py``.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
return value.replace("\x00", "") if "\x00" in value else value
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
(k.replace("\x00", "") if isinstance(k, str) and "\x00" in k else k):
|
||||
_strip_null_bytes(v)
|
||||
for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [_strip_null_bytes(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return tuple(_strip_null_bytes(item) for item in value)
|
||||
return value
|
||||
|
||||
|
||||
def record_event(
    message_id: str,
    sequence_no: int,
    event_type: str,
    payload: Optional[dict[str, Any]] = None,
) -> bool:
    """Journal one SSE event and publish it live. Best-effort.

    The event is inserted into the journal in its own short transaction
    and then published on the message's pub/sub topic. The publish is
    attempted regardless of whether the INSERT succeeded.

    Args:
        message_id: Message the event belongs to; must be truthy.
        sequence_no: Sequence number the caller wants to write at.
        event_type: Event type label; must be truthy.
        payload: ``dict`` or ``None`` (stored as ``{}``). Any other type
            is logged and dropped so live and replay envelopes stay
            byte-identical. NUL characters are stripped from strings.

    Returns:
        ``True`` when the journal INSERT committed (on the first attempt
        or on the single collision retry), ``False`` otherwise — including
        when the arguments were rejected. Never raises.
    """
    if not message_id or not event_type:
        logger.warning(
            "record_event called without message_id/event_type "
            "(message_id=%r, event_type=%r)",
            message_id,
            event_type,
        )
        return False

    if payload is None:
        materialised_payload: dict[str, Any] = {}
    elif isinstance(payload, dict):
        materialised_payload = _strip_null_bytes(payload)
    else:
        logger.warning(
            "record_event called with non-dict payload "
            "(message_id=%s seq=%s type=%s payload_type=%s) — dropping",
            message_id,
            sequence_no,
            event_type,
            type(payload).__name__,
        )
        return False

    journal_committed = False
    # The seq we actually managed to write. Diverges from
    # ``sequence_no`` only on the IntegrityError-retry path below.
    materialised_seq = sequence_no
    try:
        # Short-lived per-event transaction. Critical for visibility:
        # the reconnect endpoint reads the journal from a separate
        # connection and only sees committed rows.
        with db_session() as conn:
            MessageEventsRepository(conn).record(
                message_id, sequence_no, event_type, materialised_payload
            )
        journal_committed = True
    except IntegrityError:
        # Composite-PK collision on (message_id, sequence_no). Most
        # likely cause is a stale ``latest_sequence_no`` seed on a
        # continuation retry — the route read MAX(seq) from a separate
        # connection before another writer committed past it. Look up
        # the live latest and retry once with latest+1 so the event is
        # not silently lost. Bounded to a single retry — if two
        # writers keep racing in lockstep the route-level retry will
        # converge them across attempts.
        try:
            with db_readonly() as conn:
                latest = MessageEventsRepository(conn).latest_sequence_no(
                    message_id
                )
            # An empty journal (``latest is None``) restarts at 0.
            materialised_seq = (latest if latest is not None else -1) + 1
            with db_session() as conn:
                MessageEventsRepository(conn).record(
                    message_id,
                    materialised_seq,
                    event_type,
                    materialised_payload,
                )
            journal_committed = True
            logger.info(
                "record_event: collision at seq=%s recovered → wrote at "
                "seq=%s message_id=%s type=%s",
                sequence_no,
                materialised_seq,
                message_id,
                event_type,
            )
        except IntegrityError:
            # Second collision under the same retry — give up and log.
            # The route's nonlocal counter will continue at
            # ``sequence_no+1`` on the next emit; the next call may
            # land cleanly past the contended window.
            logger.warning(
                "record_event: IntegrityError persists after seq+1 retry; "
                "dropping. message_id=%s original_seq=%s retry_seq=%s "
                "type=%s",
                message_id,
                sequence_no,
                materialised_seq,
                event_type,
            )
        except Exception:
            logger.exception(
                "record_event: retry path failed unexpectedly "
                "(message_id=%s seq=%s type=%s)",
                message_id,
                sequence_no,
                event_type,
            )
    except Exception:
        logger.exception(
            "message_events INSERT failed: message_id=%s seq=%s type=%s",
            message_id,
            sequence_no,
            event_type,
        )

    try:
        # Publish using ``materialised_seq`` so the live pubsub frame
        # matches the journal row that other clients will snapshot on
        # reconnect. The original POST stream's SSE ``id:`` still
        # carries the caller's ``sequence_no`` — a reconnect from that
        # client will receive the same event at ``materialised_seq``
        # on the snapshot, which is a benign duplicate (the slice's
        # ``max_replayed_seq`` advances past it). No-collision case:
        # ``materialised_seq == sequence_no`` and this is identical to
        # the prior behaviour.
        wire = encode_pubsub_message(
            message_id, materialised_seq, event_type, materialised_payload
        )
        Topic(message_topic_name(message_id)).publish(wire)
    except Exception:
        logger.exception(
            "channel:%s publish failed: seq=%s type=%s",
            message_id,
            materialised_seq,
            event_type,
        )

    return journal_committed
|
||||
|
||||
|
||||
class BatchedJournalWriter:
|
||||
"""Per-stream journal writer that batches PG INSERTs.
|
||||
|
||||
One writer per ``message_id``; ``record()`` buffers events and flushes
|
||||
on size/time/``close()`` triggers. Pubsub publishes fire only after the
|
||||
INSERT commits. On ``IntegrityError`` falls back to per-row writes.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    message_id: str,
    *,
    batch_size: int = DEFAULT_BATCH_SIZE,
    batch_interval_ms: int = DEFAULT_BATCH_INTERVAL_MS,
) -> None:
    """Create an open writer with an empty buffer for ``message_id``.

    ``batch_size`` and ``batch_interval_ms`` are the size and age
    thresholds consulted by ``_should_flush``. The flush clock starts
    at construction time.
    """
    self._closed = False
    self._buffer: list[tuple[int, str, dict[str, Any]]] = []
    self._message_id = message_id
    self._batch_size = batch_size
    self._batch_interval_ms = batch_interval_ms
    # Monotonic clock in milliseconds at the last flush.
    self._last_flush_mono_ms = time.monotonic() * 1000.0
|
||||
|
||||
def record(
    self,
    sequence_no: int,
    event_type: str,
    payload: Optional[dict[str, Any]] = None,
) -> bool:
    """Buffer one event and flush when a threshold is reached.

    Returns ``False`` (after logging a warning) when the writer is
    closed, ``event_type`` is empty, or ``payload`` is neither a dict
    nor ``None``. Otherwise the event is appended to the buffer —
    ``None`` becomes ``{}`` and NUL characters are stripped from dict
    payloads — ``flush()`` is called if ``_should_flush()`` says so, and
    ``True`` is returned.
    """
    if self._closed:
        logger.warning(
            "BatchedJournalWriter.record after close: "
            "message_id=%s seq=%s type=%s",
            self._message_id,
            sequence_no,
            event_type,
        )
        return False
    if not event_type:
        logger.warning(
            "BatchedJournalWriter.record without event_type: "
            "message_id=%s seq=%s",
            self._message_id,
            sequence_no,
        )
        return False
    if payload is not None and not isinstance(payload, dict):
        # Same contract as ``record_event``: non-dict payloads are
        # rejected so the live and replay envelopes cannot diverge.
        logger.warning(
            "BatchedJournalWriter.record with non-dict payload: "
            "message_id=%s seq=%s type=%s payload_type=%s — dropping",
            self._message_id,
            sequence_no,
            event_type,
            type(payload).__name__,
        )
        return False

    materialised: dict[str, Any] = (
        {} if payload is None else _strip_null_bytes(payload)
    )
    self._buffer.append((sequence_no, event_type, materialised))

    if self._should_flush():
        self.flush()
    return True
|
||||
|
||||
def _should_flush(self) -> bool:
|
||||
if len(self._buffer) >= self._batch_size:
|
||||
return True
|
||||
elapsed_ms = (time.monotonic() * 1000.0) - self._last_flush_mono_ms
|
||||
return elapsed_ms >= self._batch_interval_ms and len(self._buffer) > 0
|
||||
|
||||
def flush(self) -> None:
    """Write the buffered events to the journal, then publish them.

    Best-effort. A single bulk insert is attempted first. If it
    raises ``IntegrityError`` (composite PK collision, typically a
    stale continuation seed), the batch is handed to
    ``_flush_per_row`` so one bad sequence number does not discard
    the rest. Any other failure is logged and the batch is dropped
    from the journal. The buffer is emptied in every case so memory
    stays bounded: an event missing from a snapshot is degraded UX,
    whereas an unbounded buffer is worse.
    """
    # The flush timestamp is refreshed whether or not there is
    # anything to write.
    self._last_flush_mono_ms = time.monotonic() * 1000.0
    if not self._buffer:
        return

    # Detach the batch before doing any I/O so a record() call made
    # in the meantime lands in a fresh list. ``complete_stream`` is
    # single-threaded per stream, so this is only a precaution
    # against future changes.
    batch, self._buffer = self._buffer, []

    try:
        with db_session() as conn:
            MessageEventsRepository(conn).bulk_record(self._message_id, batch)
    except IntegrityError:
        logger.info(
            "BatchedJournalWriter: bulk INSERT collided for "
            "message_id=%s n=%d; falling back to per-row writes",
            self._message_id,
            len(batch),
        )
        self._flush_per_row(batch)
    except Exception:
        logger.exception(
            "BatchedJournalWriter: bulk INSERT failed for "
            "message_id=%s n=%d; events dropped from journal",
            self._message_id,
            len(batch),
        )
    else:
        # The bulk insert committed: publish every frame in order.
        # ``_publish`` is best-effort, so one failed publish does not
        # stop the remaining frames.
        for seq, event_type, payload in batch:
            self._publish(seq, event_type, payload)
|
||||
|
||||
def _flush_per_row(
    self, pending: list[tuple[int, str, dict[str, Any]]]
) -> None:
    """Write events one at a time after the bulk insert collided.

    Each event is inserted in its own session. If that insert raises
    ``IntegrityError``, the latest stored sequence number for the
    message is read and the insert is retried once at ``latest + 1``
    (or ``0`` when nothing is stored). An event is published only
    after one of its inserts succeeds, using the sequence number that
    was actually written. Failures are logged and the event is
    skipped.

    Args:
        pending: ``(sequence_no, event_type, payload)`` tuples to
            persist, processed in order.
    """
    for original_seq, event_type, payload in pending:
        written_seq: Optional[int] = None
        try:
            with db_session() as conn:
                MessageEventsRepository(conn).record(
                    self._message_id, original_seq, event_type, payload
                )
            written_seq = original_seq
        except IntegrityError:
            # The sequence number is already taken: retry once at
            # the next free slot.
            try:
                with db_readonly() as conn:
                    journal = MessageEventsRepository(conn)
                    latest = journal.latest_sequence_no(self._message_id)
                next_seq = 0 if latest is None else latest + 1
                with db_session() as conn:
                    MessageEventsRepository(conn).record(
                        self._message_id, next_seq, event_type, payload
                    )
                written_seq = next_seq
            except IntegrityError:
                logger.warning(
                    "BatchedJournalWriter: IntegrityError persists "
                    "after seq+1 retry; dropping. message_id=%s "
                    "original_seq=%s type=%s",
                    self._message_id,
                    original_seq,
                    event_type,
                )
            except Exception:
                logger.exception(
                    "BatchedJournalWriter: per-row retry failed "
                    "(message_id=%s seq=%s type=%s)",
                    self._message_id,
                    original_seq,
                    event_type,
                )
        except Exception:
            logger.exception(
                "BatchedJournalWriter: per-row INSERT failed "
                "(message_id=%s seq=%s type=%s)",
                self._message_id,
                original_seq,
                event_type,
            )

        if written_seq is None:
            # Nothing was committed for this event, so there is
            # nothing to publish.
            continue
        self._publish(written_seq, event_type, payload)
|
||||
|
||||
def _publish(
    self, sequence_no: int, event_type: str, payload: dict[str, Any]
) -> None:
    """Send one frame to this message's pubsub topic. Best-effort.

    The frame is encoded with ``encode_pubsub_message`` and published
    on the topic named by ``message_topic_name(self._message_id)``.
    Any exception is logged and swallowed so a failed publish never
    propagates to the caller.
    """
    try:
        # Encode first: if encoding fails, no topic is constructed.
        frame = encode_pubsub_message(
            self._message_id, sequence_no, event_type, payload
        )
        channel = message_topic_name(self._message_id)
        Topic(channel).publish(frame)
    except Exception:
        logger.exception(
            "channel:%s publish failed: seq=%s type=%s",
            self._message_id,
            sequence_no,
            event_type,
        )
|
||||
|
||||
def close(self) -> None:
    """Flush whatever is still buffered and mark the writer closed.

    Idempotent: calls made after the first are no-ops, so it is safe
    to invoke from several ``finally`` clauses.
    """
    if not self._closed:
        self.flush()
        self._closed = True
|
||||
@@ -1,7 +1,5 @@
|
||||
import sys
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
from application.storage.db.session import db_session
|
||||
@@ -93,33 +91,62 @@ def _count_prompt_tokens(messages, tools=None, usage_attachments=None, **kwargs)
|
||||
return prompt_tokens
|
||||
|
||||
|
||||
def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||
if "pytest" in sys.modules:
|
||||
return
|
||||
user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
|
||||
normalized_agent_id = str(agent_id) if agent_id else None
|
||||
def _persist_call_usage(llm, call_usage):
|
||||
"""Write one ``token_usage`` row per LLM call. Always-on; no flag.
|
||||
|
||||
if not user_id and not user_api_key and not normalized_agent_id:
|
||||
Source defaults to ``agent_stream`` and can be overridden per
|
||||
instance via ``_token_usage_source`` (set on side-channel LLMs:
|
||||
title / compression / rag_condense / fallback). A ``_request_id``
|
||||
stamped on the LLM lets ``count_in_range`` deduplicate the multiple
|
||||
rows produced by a single multi-tool agent run.
|
||||
"""
|
||||
if call_usage["prompt_tokens"] == 0 and call_usage["generated_tokens"] == 0:
|
||||
return
|
||||
decoded_token = getattr(llm, "decoded_token", None)
|
||||
user_id = (
|
||||
decoded_token.get("sub") if isinstance(decoded_token, dict) else None
|
||||
)
|
||||
user_api_key = getattr(llm, "user_api_key", None)
|
||||
agent_id = getattr(llm, "agent_id", None)
|
||||
if not user_id and not user_api_key:
|
||||
# Repository would raise on the attribution check — log instead
|
||||
# so operators see the gap rather than crashing the stream.
|
||||
logger.warning(
|
||||
"Skipping token usage insert: missing user_id, api_key, and agent_id"
|
||||
"token_usage skip: no user_id/api_key on LLM instance",
|
||||
extra={
|
||||
"source": getattr(llm, "_token_usage_source", "agent_stream"),
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
# ``timestamp`` is omitted so Postgres ``server_default
|
||||
# = func.now()`` populates a tz-aware UTC value; passing
|
||||
# naive ``datetime.now()`` would silently shift on
|
||||
# non-UTC servers.
|
||||
TokenUsageRepository(conn).insert(
|
||||
user_id=user_id,
|
||||
api_key=user_api_key,
|
||||
agent_id=normalized_agent_id,
|
||||
prompt_tokens=token_usage["prompt_tokens"],
|
||||
generated_tokens=token_usage["generated_tokens"],
|
||||
timestamp=datetime.now(),
|
||||
agent_id=str(agent_id) if agent_id else None,
|
||||
prompt_tokens=call_usage["prompt_tokens"],
|
||||
generated_tokens=call_usage["generated_tokens"],
|
||||
source=(
|
||||
getattr(llm, "_token_usage_source", None) or "agent_stream"
|
||||
),
|
||||
request_id=getattr(llm, "_request_id", None),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to record token usage: {e}", exc_info=True)
|
||||
except Exception:
|
||||
logger.exception("token_usage persist failed")
|
||||
|
||||
|
||||
def gen_token_usage(func):
|
||||
"""Accumulate per-call token counts and write a ``token_usage`` row.
|
||||
|
||||
The accumulator on ``self.token_usage`` stays in place for code
|
||||
paths that introspect it (e.g., logging, response payloads). DB
|
||||
persistence happens here for every call so primary streams,
|
||||
side-channel LLMs, and no-save flows all produce rows uniformly.
|
||||
"""
|
||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||
usage_attachments = kwargs.pop("_usage_attachments", None)
|
||||
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
@@ -133,18 +160,14 @@ def gen_token_usage(func):
|
||||
call_usage["generated_tokens"] += _count_tokens(result)
|
||||
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
|
||||
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
|
||||
update_token_usage(
|
||||
self.decoded_token,
|
||||
self.user_api_key,
|
||||
call_usage,
|
||||
getattr(self, "agent_id", None),
|
||||
)
|
||||
_persist_call_usage(self, call_usage)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def stream_token_usage(func):
|
||||
"""Stream variant of ``gen_token_usage``. Same persistence contract."""
|
||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||
usage_attachments = kwargs.pop("_usage_attachments", None)
|
||||
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
@@ -173,15 +196,7 @@ def stream_token_usage(func):
|
||||
call_usage["generated_tokens"] += _count_tokens(line)
|
||||
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
|
||||
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
|
||||
# Persist usage rows only on success: a partial mid-stream
|
||||
# failure shouldn't bill the user for a response they never got.
|
||||
if error is None:
|
||||
update_token_usage(
|
||||
self.decoded_token,
|
||||
self.user_api_key,
|
||||
call_usage,
|
||||
getattr(self, "agent_id", None),
|
||||
)
|
||||
_persist_call_usage(self, call_usage)
|
||||
emit = getattr(self, "_emit_stream_finished_log", None)
|
||||
if callable(emit):
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import logging
|
||||
from typing import List, Optional, Any, Dict
|
||||
|
||||
from psycopg.types.json import Jsonb
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.base import BaseVectorStore
|
||||
from application.vectorstore.document_class import Document
|
||||
@@ -175,7 +178,7 @@ class PGVectorStore(BaseVectorStore):
|
||||
for text, embedding, metadata in zip(texts, embeddings, metadatas):
|
||||
cursor.execute(
|
||||
insert_query,
|
||||
(text, embedding, metadata, self._source_id)
|
||||
(text, embedding, Jsonb(metadata), self._source_id)
|
||||
)
|
||||
inserted_id = cursor.fetchone()[0]
|
||||
inserted_ids.append(str(inserted_id))
|
||||
@@ -266,7 +269,7 @@ class PGVectorStore(BaseVectorStore):
|
||||
|
||||
cursor.execute(
|
||||
insert_query,
|
||||
(text, embeddings[0], final_metadata, self._source_id)
|
||||
(text, embeddings[0], Jsonb(final_metadata), self._source_id)
|
||||
)
|
||||
inserted_id = cursor.fetchone()[0]
|
||||
conn.commit()
|
||||
|
||||
@@ -6,6 +6,7 @@ import os
|
||||
import shutil
|
||||
import string
|
||||
import tempfile
|
||||
import threading
|
||||
from typing import Any, Dict
|
||||
import zipfile
|
||||
|
||||
@@ -18,11 +19,14 @@ import requests
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.events.publisher import publish_user_event
|
||||
from application.parser.chunking import Chunker
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.parser.embedding_pipeline import embed_and_store_documents
|
||||
from application.parser.embedding_pipeline import (
|
||||
assert_index_complete,
|
||||
embed_and_store_documents,
|
||||
)
|
||||
from application.parser.file.bulk import SimpleDirectoryReader, get_default_file_extractor
|
||||
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
|
||||
from application.parser.remote.remote_creator import RemoteCreator
|
||||
@@ -32,6 +36,9 @@ from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.attachments import AttachmentsRepository
|
||||
from application.storage.db.repositories.ingest_chunk_progress import (
|
||||
IngestChunkProgressRepository,
|
||||
)
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
@@ -43,6 +50,51 @@ from application.utils import count_tokens_docs, num_tokens_from_string, safe_fi
|
||||
MIN_TOKENS = 150
|
||||
MAX_TOKENS = 1250
|
||||
RECURSION_DEPTH = 2
|
||||
INGEST_HEARTBEAT_INTERVAL_SECONDS = 30
|
||||
|
||||
# Re-exported here for backward-compatible imports
|
||||
# (``from application.worker import _derive_source_id`` /
|
||||
# ``DOCSGPT_INGEST_NAMESPACE``) from tests and any other in-tree callers.
|
||||
# New code should import from ``application.storage.db.source_ids``
|
||||
# directly to avoid pulling this Celery worker module into the API
|
||||
# process at import time.
|
||||
from application.storage.db.source_ids import ( # noqa: E402, F401
|
||||
DOCSGPT_INGEST_NAMESPACE,
|
||||
derive_source_id as _derive_source_id,
|
||||
)
|
||||
|
||||
|
||||
def _ingest_heartbeat_loop(source_id, stop_event, interval=INGEST_HEARTBEAT_INTERVAL_SECONDS):
    """Bump the ingest heartbeat for ``source_id`` until told to stop.

    Waits ``interval`` seconds on ``stop_event`` between beats and
    returns as soon as the event is set. Each beat calls
    ``IngestChunkProgressRepository.bump_heartbeat`` in its own
    ``db_session``. A failed beat is logged as a warning and the loop
    carries on.
    """
    while True:
        # ``wait`` returns True once the event is set: time to exit.
        if stop_event.wait(interval):
            return
        try:
            with db_session() as conn:
                IngestChunkProgressRepository(conn).bump_heartbeat(source_id)
        except Exception as e:
            logging.warning(
                f"Heartbeat failed for {source_id}: {e}", exc_info=True
            )
|
||||
|
||||
|
||||
def _start_ingest_heartbeat(source_id):
    """Launch the heartbeat thread for ``source_id``.

    Starts a daemon thread that runs ``_ingest_heartbeat_loop`` with
    the stringified source id and a fresh stop event.

    Returns:
        ``(thread, stop_event)``; set the event to end the loop.
    """
    stop_signal = threading.Event()
    worker = threading.Thread(
        name=f"ingest-heartbeat-{source_id}",
        target=_ingest_heartbeat_loop,
        args=(str(source_id), stop_signal),
        daemon=True,
    )
    worker.start()
    return worker, stop_signal
|
||||
|
||||
|
||||
def _stop_ingest_heartbeat(thread, stop_event):
    """Ask the heartbeat thread to exit and wait briefly for it.

    Sets ``stop_event`` when one is given, then joins ``thread`` for
    at most 5 seconds when one is given. Either argument may be
    ``None``, in which case that step is skipped.
    """
    if stop_event is not None:
        stop_event.set()
    if thread is None:
        return
    thread.join(timeout=5)
|
||||
|
||||
|
||||
# Define a function to extract metadata from a given filename.
|
||||
@@ -455,6 +507,8 @@ def ingest_worker(
|
||||
user,
|
||||
retriever="classic",
|
||||
file_name_map=None,
|
||||
idempotency_key=None,
|
||||
source_id=None,
|
||||
):
|
||||
"""
|
||||
Ingest and process documents.
|
||||
@@ -469,6 +523,14 @@ def ingest_worker(
|
||||
user (str): Identifier for the user initiating the ingestion (original, unsanitized).
|
||||
retriever (str): Type of retriever to use for processing the documents.
|
||||
file_name_map (dict|str|None): Optional mapping of safe relative paths to original filenames.
|
||||
idempotency_key (str|None): When provided, the ``source_id`` is derived
|
||||
deterministically from the key so a retried task reuses the same
|
||||
source row instead of duplicating it.
|
||||
source_id (str|None): UUID minted by the HTTP route and returned in
|
||||
its response. When supplied, the worker uses it verbatim so SSE
|
||||
envelopes carry the same id the frontend already has — required
|
||||
for non-idempotent uploads where the route can't predict
|
||||
``_derive_source_id(idempotency_key)``.
|
||||
|
||||
Returns:
|
||||
dict: Information about the completed ingestion task, including input parameters and a "limited" flag.
|
||||
@@ -483,10 +545,41 @@ def ingest_worker(
|
||||
|
||||
logging.info(f"Ingest path: {file_path}", extra={"user": user, "job": job_name})
|
||||
|
||||
# Create temporary working directory
|
||||
# Source id resolution order:
|
||||
# 1. Caller-supplied ``source_id`` (HTTP route minted + returned to
|
||||
# the frontend) — keeps the route response and the SSE event
|
||||
# payloads in lockstep on the non-idempotent path.
|
||||
# 2. Deterministic uuid5 from ``idempotency_key`` — retried tasks
|
||||
# reuse the original source row instead of duplicating it.
|
||||
# 3. Fresh uuid4 (caller has neither) — opaque, single-shot only.
|
||||
if source_id:
|
||||
source_uuid = uuid.UUID(source_id)
|
||||
else:
|
||||
source_uuid = _derive_source_id(idempotency_key)
|
||||
source_id_for_events = str(source_uuid)
|
||||
# Only emit ``queued`` on the original attempt. Celery retries re-run
|
||||
# the body, and re-publishing here would oscillate the toast through
|
||||
# ``queued`` again between ``failed`` and ``completed``.
|
||||
if self.request.retries == 0:
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.queued",
|
||||
{
|
||||
"job_name": job_name,
|
||||
"filename": filename,
|
||||
"source_id": source_id_for_events,
|
||||
"operation": "upload",
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_for_events},
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
try:
|
||||
# Wrap the entire body in try/except so a failure between the
|
||||
# ``queued`` publish above and the inner work (e.g. tempdir
|
||||
# creation, OS-level resource exhaustion) still emits a terminal
|
||||
# ``failed`` event rather than leaving the toast wedged on
|
||||
# 'training' until the polling fallback rescues it 30s later.
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
if storage.is_directory(file_path):
|
||||
@@ -575,12 +668,22 @@ def ingest_worker(
|
||||
|
||||
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
|
||||
|
||||
id = uuid.uuid4()
|
||||
|
||||
vector_store_path = os.path.join(temp_dir, "vector_store")
|
||||
os.makedirs(vector_store_path, exist_ok=True)
|
||||
|
||||
embed_and_store_documents(docs, vector_store_path, id, self)
|
||||
heartbeat_thread, heartbeat_stop = _start_ingest_heartbeat(source_uuid)
|
||||
try:
|
||||
embed_and_store_documents(
|
||||
docs, vector_store_path, source_uuid, self,
|
||||
attempt_id=getattr(self.request, "id", None),
|
||||
user_id=user,
|
||||
)
|
||||
finally:
|
||||
_stop_ingest_heartbeat(heartbeat_thread, heartbeat_stop)
|
||||
# Defense-in-depth: chunk-progress is the authoritative
|
||||
# record of how many chunks landed; mismatch raises so the
|
||||
# task fails loud rather than caching a partial index.
|
||||
assert_index_complete(source_uuid)
|
||||
|
||||
tokens = count_tokens_docs(docs)
|
||||
|
||||
@@ -595,7 +698,7 @@ def ingest_worker(
|
||||
"user": user,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"id": str(id),
|
||||
"id": source_id_for_events,
|
||||
"type": "local",
|
||||
"file_path": file_path,
|
||||
"directory_structure": json.dumps(directory_structure),
|
||||
@@ -604,9 +707,36 @@ def ingest_worker(
|
||||
file_data["file_name_map"] = json.dumps(file_name_map)
|
||||
|
||||
upload_index(vector_store_path, file_data)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in ingest_worker: {e}", exc_info=True)
|
||||
raise
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.completed",
|
||||
{
|
||||
"source_id": source_id_for_events,
|
||||
"filename": filename,
|
||||
"tokens": tokens,
|
||||
"operation": "upload",
|
||||
# Forward-looking contract: ``limited`` is always
|
||||
# ``False`` today but is carried on the wire so a
|
||||
# future token-cap detection path can flip it and
|
||||
# the frontend slice / UploadToast already react.
|
||||
"limited": False,
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_for_events},
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in ingest_worker: {e}", exc_info=True)
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.failed",
|
||||
{
|
||||
"source_id": source_id_for_events,
|
||||
"filename": filename,
|
||||
"operation": "upload",
|
||||
"error": str(e)[:1024],
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_for_events},
|
||||
)
|
||||
raise
|
||||
return {
|
||||
"directory": directory,
|
||||
"formats": formats,
|
||||
@@ -630,7 +760,23 @@ def reingest_source_worker(self, source_id, user):
|
||||
|
||||
Returns:
|
||||
dict: Information about the re-ingestion task
|
||||
|
||||
Note:
|
||||
Reingest does its own ``vector_store.add_chunk`` work rather
|
||||
than going through ``embed_and_store_documents`` so it does
|
||||
*not* emit per-percent SSE progress events — only ``queued``,
|
||||
``completed`` (carrying ``chunks_added`` / ``chunks_deleted``),
|
||||
or ``failed``. v1 limitation; revisit if reingest gains a
|
||||
progress-driven UI.
|
||||
"""
|
||||
# Declared at the function scope so the outer except can include
|
||||
# ``name`` in the failed event payload when the failure happens
|
||||
# after the source lookup. Empty string until the lookup succeeds.
|
||||
source_name = ""
|
||||
# Tracks inner-block failures so a ``completed`` event reflects
|
||||
# partial-success accurately rather than masking it.
|
||||
inner_warnings: list[str] = []
|
||||
|
||||
try:
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
@@ -644,6 +790,27 @@ def reingest_source_worker(self, source_id, user):
|
||||
if not source:
|
||||
raise ValueError(f"Source {source_id} not found or access denied")
|
||||
source_id = str(source["id"])
|
||||
source_name = source.get("name") or ""
|
||||
|
||||
# Publish ``queued`` *after* canonicalising ``source_id`` so the
|
||||
# event references the same id as the source row. Trade-off
|
||||
# documented: a Celery-backend or PG-lookup hiccup before this
|
||||
# publish means the toast may see only a ``failed`` event with
|
||||
# no preceding ``queued`` — acceptable for v1 since both
|
||||
# conditions also imply broader system trouble. Gate on first
|
||||
# attempt only so Celery retries don't re-emit ``queued`` after
|
||||
# a prior attempt already published ``failed``.
|
||||
if self.request.retries == 0:
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.queued",
|
||||
{
|
||||
"source_id": source_id,
|
||||
"name": source_name,
|
||||
"operation": "reingest",
|
||||
},
|
||||
scope={"kind": "source", "id": source_id},
|
||||
)
|
||||
|
||||
storage = StorageCreator.get_storage()
|
||||
source_file_path = source.get("file_path", "")
|
||||
@@ -741,6 +908,19 @@ def reingest_source_worker(self, source_id, user):
|
||||
try:
|
||||
if not added_files and not removed_files:
|
||||
logging.info("No changes detected.")
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.completed",
|
||||
{
|
||||
"source_id": source_id,
|
||||
"name": source_name,
|
||||
"operation": "reingest",
|
||||
"no_changes": True,
|
||||
"chunks_added": 0,
|
||||
"chunks_deleted": 0,
|
||||
},
|
||||
scope={"kind": "source", "id": source_id},
|
||||
)
|
||||
return {
|
||||
"source_id": source_id,
|
||||
"user": user,
|
||||
@@ -792,6 +972,9 @@ def reingest_source_worker(self, source_id, user):
|
||||
f"Error during deletion of removed file chunks: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
inner_warnings.append(
|
||||
f"deletion failed: {str(e)[:200]}"
|
||||
)
|
||||
|
||||
# 2) Add chunks from new files
|
||||
added = 0
|
||||
@@ -884,6 +1067,9 @@ def reingest_source_worker(self, source_id, user):
|
||||
logging.error(
|
||||
f"Error during ingestion of new files: {e}", exc_info=True
|
||||
)
|
||||
inner_warnings.append(
|
||||
f"add failed: {str(e)[:200]}"
|
||||
)
|
||||
|
||||
# 3) Update source directory structure timestamp
|
||||
try:
|
||||
@@ -912,6 +1098,25 @@ def reingest_source_worker(self, source_id, user):
|
||||
meta={"current": 100, "status": "Re-ingestion completed"},
|
||||
)
|
||||
|
||||
completed_payload: dict = {
|
||||
"source_id": source_id,
|
||||
"name": source_name,
|
||||
"operation": "reingest",
|
||||
"chunks_added": added,
|
||||
"chunks_deleted": deleted,
|
||||
"tokens": int(total_tokens) if "total_tokens" in locals() else 0,
|
||||
}
|
||||
if inner_warnings:
|
||||
# Surface the per-block failures so the toast can warn
|
||||
# rather than claim a clean success.
|
||||
completed_payload["warnings"] = inner_warnings
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.completed",
|
||||
completed_payload,
|
||||
scope={"kind": "source", "id": source_id},
|
||||
)
|
||||
|
||||
return {
|
||||
"source_id": source_id,
|
||||
"user": user,
|
||||
@@ -929,6 +1134,17 @@ def reingest_source_worker(self, source_id, user):
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in reingest_source_worker: {e}", exc_info=True)
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.failed",
|
||||
{
|
||||
"source_id": str(source_id),
|
||||
"name": source_name,
|
||||
"operation": "reingest",
|
||||
"error": str(e)[:1024],
|
||||
},
|
||||
scope={"kind": "source", "id": str(source_id)},
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@@ -943,12 +1159,52 @@ def remote_worker(
|
||||
sync_frequency="never",
|
||||
operation_mode="upload",
|
||||
doc_id=None,
|
||||
idempotency_key=None,
|
||||
source_id=None,
|
||||
):
|
||||
safe_user = safe_filename(user)
|
||||
full_path = os.path.join(directory, safe_user, uuid.uuid4().hex)
|
||||
os.makedirs(full_path, exist_ok=True)
|
||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||
|
||||
# Source id resolution order matches ``ingest_worker``:
|
||||
# 1. ``operation_mode == "sync"`` reuses the existing source's ``doc_id``.
|
||||
# 2. Caller-supplied ``source_id`` (the HTTP route minted it and
|
||||
# already returned it to the frontend) — keeps the route
|
||||
# response and the SSE event payloads in lockstep on the
|
||||
# no-idempotency-key path.
|
||||
# 3. Deterministic uuid5 from ``idempotency_key`` — retried tasks
|
||||
# reuse the original source row instead of duplicating it.
|
||||
# 4. Fresh uuid4 — opaque, single-shot only.
|
||||
if operation_mode == "sync" and doc_id:
|
||||
source_uuid = str(doc_id)
|
||||
elif source_id:
|
||||
source_uuid = uuid.UUID(source_id)
|
||||
else:
|
||||
source_uuid = _derive_source_id(idempotency_key)
|
||||
source_id_for_events = str(source_uuid)
|
||||
|
||||
# Emit the queued event before any work that could fail (including
|
||||
# ``update_state``) so the toast UI always sees a queued envelope
|
||||
# before any subsequent failed event. Gated on first attempt so
|
||||
# Celery retries don't re-emit ``queued`` after a prior ``failed``.
|
||||
if self.request.retries == 0:
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.queued",
|
||||
{
|
||||
"source_id": source_id_for_events,
|
||||
"job_name": name_job,
|
||||
"loader": loader,
|
||||
"operation": operation_mode,
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_for_events},
|
||||
)
|
||||
|
||||
# Wrap ``update_state`` plus the entire body so any pre-loader
|
||||
# failure (Celery backend down, OS resource issue) still emits a
|
||||
# terminal ``failed`` event rather than wedging the toast.
|
||||
try:
|
||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||
logging.info("Initializing remote loader with type: %s", loader)
|
||||
remote_loader = RemoteCreator.create_loader(loader)
|
||||
raw_docs = remote_loader.load_data(source_data)
|
||||
@@ -1035,14 +1291,22 @@ def remote_worker(
|
||||
)
|
||||
|
||||
if operation_mode == "upload":
|
||||
id = uuid.uuid4()
|
||||
embed_and_store_documents(docs, full_path, id, self)
|
||||
embed_and_store_documents(
|
||||
docs, full_path, source_uuid, self,
|
||||
attempt_id=getattr(self.request, "id", None),
|
||||
user_id=user,
|
||||
)
|
||||
assert_index_complete(source_uuid)
|
||||
elif operation_mode == "sync":
|
||||
if not doc_id:
|
||||
logging.error("Invalid doc_id provided for sync operation: %s", doc_id)
|
||||
raise ValueError("doc_id must be provided for sync operation.")
|
||||
id = str(doc_id)
|
||||
embed_and_store_documents(docs, full_path, id, self)
|
||||
embed_and_store_documents(
|
||||
docs, full_path, source_uuid, self,
|
||||
attempt_id=getattr(self.request, "id", None),
|
||||
user_id=user,
|
||||
)
|
||||
assert_index_complete(source_uuid)
|
||||
self.update_state(state="PROGRESS", meta={"current": 100})
|
||||
|
||||
# Serialize remote_data as JSON if it's a dict (for S3, Reddit, etc.)
|
||||
@@ -1054,7 +1318,7 @@ def remote_worker(
|
||||
"user": user,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"id": str(id),
|
||||
"id": source_id_for_events,
|
||||
"type": loader,
|
||||
"remote_data": remote_data_serialized,
|
||||
"sync_frequency": sync_frequency,
|
||||
@@ -1068,23 +1332,49 @@ def remote_worker(
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = SourcesRepository(conn)
|
||||
src = repo.get_any(str(id), user)
|
||||
src = repo.get_any(source_id_for_events, user)
|
||||
if src is not None:
|
||||
repo.update(str(src["id"]), user, {"date": last_sync_now})
|
||||
except Exception as upd_err:
|
||||
logging.warning(
|
||||
f"Failed to update last_sync for source {id}: {upd_err}"
|
||||
f"Failed to update last_sync for source {source_id_for_events}: {upd_err}"
|
||||
)
|
||||
upload_index(full_path, file_data)
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.completed",
|
||||
{
|
||||
"source_id": source_id_for_events,
|
||||
"job_name": name_job,
|
||||
"loader": loader,
|
||||
"operation": operation_mode,
|
||||
"tokens": tokens,
|
||||
# Forward-looking contract: see ingest_worker.
|
||||
"limited": False,
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_for_events},
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error("Error in remote_worker task: %s", str(e), exc_info=True)
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.failed",
|
||||
{
|
||||
"source_id": source_id_for_events,
|
||||
"job_name": name_job,
|
||||
"loader": loader,
|
||||
"operation": operation_mode,
|
||||
"error": str(e)[:1024],
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_for_events},
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if os.path.exists(full_path):
|
||||
shutil.rmtree(full_path)
|
||||
logging.info("remote_worker task completed successfully")
|
||||
return {
|
||||
"id": str(id),
|
||||
"id": source_id_for_events,
|
||||
"urls": source_data,
|
||||
"name_job": name_job,
|
||||
"user": user,
|
||||
@@ -1167,6 +1457,13 @@ def attachment_worker(self, file_info, user):
|
||||
relative_path = file_info["path"]
|
||||
metadata = file_info.get("metadata", {})
|
||||
|
||||
publish_user_event(
|
||||
user,
|
||||
"attachment.queued",
|
||||
{"attachment_id": str(attachment_id), "filename": filename},
|
||||
scope={"kind": "attachment", "id": str(attachment_id)},
|
||||
)
|
||||
|
||||
try:
|
||||
self.update_state(state="PROGRESS", meta={"current": 10})
|
||||
storage = StorageCreator.get_storage()
|
||||
@@ -1174,6 +1471,17 @@ def attachment_worker(self, file_info, user):
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 30, "status": "Processing content"}
|
||||
)
|
||||
publish_user_event(
|
||||
user,
|
||||
"attachment.progress",
|
||||
{
|
||||
"attachment_id": str(attachment_id),
|
||||
"filename": filename,
|
||||
"current": 30,
|
||||
"stage": "processing",
|
||||
},
|
||||
scope={"kind": "attachment", "id": str(attachment_id)},
|
||||
)
|
||||
|
||||
file_extractor = get_default_file_extractor(
|
||||
ocr_enabled=settings.DOCLING_OCR_ATTACHMENTS_ENABLED
|
||||
@@ -1206,6 +1514,17 @@ def attachment_worker(self, file_info, user):
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 80, "status": "Storing in database"}
|
||||
)
|
||||
publish_user_event(
|
||||
user,
|
||||
"attachment.progress",
|
||||
{
|
||||
"attachment_id": str(attachment_id),
|
||||
"filename": filename,
|
||||
"current": 80,
|
||||
"stage": "storing",
|
||||
},
|
||||
scope={"kind": "attachment", "id": str(attachment_id)},
|
||||
)
|
||||
|
||||
mime_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
|
||||
@@ -1230,6 +1549,18 @@ def attachment_worker(self, file_info, user):
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"})
|
||||
|
||||
publish_user_event(
|
||||
user,
|
||||
"attachment.completed",
|
||||
{
|
||||
"attachment_id": str(attachment_id),
|
||||
"filename": filename,
|
||||
"token_count": token_count,
|
||||
"mime_type": mime_type,
|
||||
},
|
||||
scope={"kind": "attachment", "id": str(attachment_id)},
|
||||
)
|
||||
|
||||
return {
|
||||
"filename": filename,
|
||||
"path": relative_path,
|
||||
@@ -1244,20 +1575,24 @@ def attachment_worker(self, file_info, user):
|
||||
extra={"user": user},
|
||||
exc_info=True,
|
||||
)
|
||||
publish_user_event(
|
||||
user,
|
||||
"attachment.failed",
|
||||
{
|
||||
"attachment_id": str(attachment_id),
|
||||
"filename": filename,
|
||||
"error": str(e)[:1024],
|
||||
},
|
||||
scope={"kind": "attachment", "id": str(attachment_id)},
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def agent_webhook_worker(self, agent_id, payload):
|
||||
"""
|
||||
Process the webhook payload for an agent.
|
||||
"""Process the webhook payload for an agent.
|
||||
|
||||
Args:
|
||||
self: Reference to the instance of the task.
|
||||
agent_id (str): Unique identifier for the agent.
|
||||
payload (dict): The payload data from the webhook.
|
||||
|
||||
Returns:
|
||||
dict: Information about the processed webhook.
|
||||
Raises on failure: Celery treats a returned dict as success and
|
||||
would skip retries, leaving the caller with a stale 200.
|
||||
"""
|
||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||
try:
|
||||
@@ -1283,13 +1618,13 @@ def agent_webhook_worker(self, agent_id, payload):
|
||||
input_data = json.dumps(payload)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing agent webhook: {e}", exc_info=True)
|
||||
return {"status": "error", "error": str(e)}
|
||||
raise
|
||||
self.update_state(state="PROGRESS", meta={"current": 50})
|
||||
try:
|
||||
result = run_agent_logic(agent_config, input_data)
|
||||
except Exception as e:
|
||||
logging.error(f"Error running agent logic: {e}", exc_info=True)
|
||||
return {"status": "error"}
|
||||
raise
|
||||
else:
|
||||
logging.info(
|
||||
f"Webhook processed for agent {agent_id}", extra={"agent_id": agent_id}
|
||||
@@ -1312,6 +1647,8 @@ def ingest_connector(
|
||||
operation_mode: str = "upload",
|
||||
doc_id=None,
|
||||
sync_frequency: str = "never",
|
||||
idempotency_key=None,
|
||||
source_id=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Ingestion for internal knowledge bases (GoogleDrive, etc.).
|
||||
@@ -1328,14 +1665,52 @@ def ingest_connector(
|
||||
operation_mode: "upload" for initial ingestion, "sync" for incremental sync
|
||||
doc_id: Document ID for sync operations (required when operation_mode="sync")
|
||||
sync_frequency: How often to sync ("never", "daily", "weekly", "monthly")
|
||||
idempotency_key: When provided, the ``source_id`` is derived
|
||||
deterministically so a retried upload reuses the same source row.
|
||||
source_id: When supplied, the worker uses it verbatim so SSE envelopes
|
||||
carry the same id the HTTP route already returned to the frontend
|
||||
— required for non-idempotent uploads where the route can't
|
||||
predict ``_derive_source_id(idempotency_key)``.
|
||||
"""
|
||||
logging.info(
|
||||
f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}"
|
||||
)
|
||||
|
||||
# Source id resolution mirrors ``ingest_worker`` / ``remote_worker``:
|
||||
# sync mode reuses ``doc_id``; otherwise the caller-supplied
|
||||
# ``source_id`` (minted by the HTTP route and already echoed to the
|
||||
# client) wins; fall back to ``_derive_source_id`` only when neither
|
||||
# is supplied. Without rule (2) the no-idempotency-key path would
|
||||
# mint a fresh uuid4 here that the frontend has no way to correlate
|
||||
# SSE envelopes to.
|
||||
if operation_mode == "sync" and doc_id:
|
||||
source_uuid = str(doc_id)
|
||||
elif source_id:
|
||||
source_uuid = uuid.UUID(source_id)
|
||||
else:
|
||||
source_uuid = _derive_source_id(idempotency_key)
|
||||
source_id_for_events = str(source_uuid)
|
||||
|
||||
# First-attempt gate: Celery retries re-run the body, and a
|
||||
# repeated ``queued`` here would oscillate the toast through
|
||||
# ``queued`` again between ``failed`` and ``completed``.
|
||||
if self.request.retries == 0:
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.queued",
|
||||
{
|
||||
"source_id": source_id_for_events,
|
||||
"job_name": job_name,
|
||||
"loader": source_type,
|
||||
"operation": operation_mode,
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_for_events},
|
||||
)
|
||||
|
||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
try:
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Step 1: Initialize the appropriate loader
|
||||
self.update_state(
|
||||
state="PROGRESS",
|
||||
@@ -1373,6 +1748,22 @@ def ingest_connector(
|
||||
"files_downloaded", 0
|
||||
):
|
||||
logging.warning(f"No files were downloaded from {source_type}")
|
||||
# Connector returned no files — surface as a benign
|
||||
# ``completed`` event with zero tokens so the toast
|
||||
# closes out cleanly instead of waiting on polling.
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.completed",
|
||||
{
|
||||
"source_id": source_id_for_events,
|
||||
"job_name": job_name,
|
||||
"loader": source_type,
|
||||
"operation": operation_mode,
|
||||
"tokens": 0,
|
||||
"no_changes": True,
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_for_events},
|
||||
)
|
||||
# Create empty result directly instead of calling a separate method
|
||||
return {
|
||||
"name": job_name,
|
||||
@@ -1422,16 +1813,16 @@ def ingest_connector(
|
||||
|
||||
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
|
||||
|
||||
if operation_mode == "upload":
|
||||
id = uuid.uuid4()
|
||||
elif operation_mode == "sync":
|
||||
if not doc_id:
|
||||
logging.error(
|
||||
"Invalid doc_id provided for sync operation: %s", doc_id
|
||||
)
|
||||
raise ValueError("doc_id must be provided for sync operation.")
|
||||
id = str(doc_id)
|
||||
else:
|
||||
# Validate operation_mode here too (the source_uuid path
|
||||
# at the top of the function only branches on the
|
||||
# sync+doc_id combination; surfacing the wrong-mode error
|
||||
# this far in matches the legacy behaviour).
|
||||
if operation_mode == "sync" and not doc_id:
|
||||
logging.error(
|
||||
"Invalid doc_id provided for sync operation: %s", doc_id
|
||||
)
|
||||
raise ValueError("doc_id must be provided for sync operation.")
|
||||
if operation_mode not in ("upload", "sync"):
|
||||
raise ValueError(f"Invalid operation_mode: {operation_mode}")
|
||||
|
||||
vector_store_path = os.path.join(temp_dir, "vector_store")
|
||||
@@ -1440,7 +1831,12 @@ def ingest_connector(
|
||||
self.update_state(
|
||||
state="PROGRESS", meta={"current": 80, "status": "Storing documents"}
|
||||
)
|
||||
embed_and_store_documents(docs, vector_store_path, id, self)
|
||||
embed_and_store_documents(
|
||||
docs, vector_store_path, source_uuid, self,
|
||||
attempt_id=getattr(self.request, "id", None),
|
||||
user_id=user,
|
||||
)
|
||||
assert_index_complete(source_uuid)
|
||||
|
||||
tokens = count_tokens_docs(docs)
|
||||
|
||||
@@ -1450,7 +1846,7 @@ def ingest_connector(
|
||||
"name": job_name,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"id": str(id),
|
||||
"id": source_id_for_events,
|
||||
"type": "connector:file",
|
||||
"remote_data": json.dumps(
|
||||
{"provider": source_type, **api_source_config}
|
||||
@@ -1459,16 +1855,13 @@ def ingest_connector(
|
||||
"sync_frequency": sync_frequency,
|
||||
}
|
||||
|
||||
if operation_mode == "sync":
|
||||
file_data["last_sync"] = datetime.datetime.now()
|
||||
else:
|
||||
file_data["last_sync"] = datetime.datetime.now()
|
||||
file_data["last_sync"] = datetime.datetime.now()
|
||||
|
||||
if operation_mode == "sync":
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = SourcesRepository(conn)
|
||||
src = repo.get_any(str(id), user)
|
||||
src = repo.get_any(source_id_for_events, user)
|
||||
if src is not None:
|
||||
repo.update(
|
||||
str(src["id"]), user,
|
||||
@@ -1476,7 +1869,9 @@ def ingest_connector(
|
||||
)
|
||||
except Exception as upd_err:
|
||||
logging.warning(
|
||||
f"Failed to update last_sync for source {id}: {upd_err}"
|
||||
"Failed to update last_sync for source %s: %s",
|
||||
source_id_for_events,
|
||||
upd_err,
|
||||
)
|
||||
|
||||
upload_index(vector_store_path, file_data)
|
||||
@@ -1488,45 +1883,104 @@ def ingest_connector(
|
||||
|
||||
logging.info(f"Remote ingestion completed: {job_name}")
|
||||
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.completed",
|
||||
{
|
||||
"source_id": source_id_for_events,
|
||||
"job_name": job_name,
|
||||
"loader": source_type,
|
||||
"operation": operation_mode,
|
||||
"tokens": tokens,
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_for_events},
|
||||
)
|
||||
|
||||
return {
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"tokens": tokens,
|
||||
"type": source_type,
|
||||
"id": str(id),
|
||||
"id": source_id_for_events,
|
||||
"status": "complete",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
|
||||
raise
|
||||
except Exception as e:
|
||||
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.failed",
|
||||
{
|
||||
"source_id": source_id_for_events,
|
||||
"job_name": job_name,
|
||||
"loader": source_type,
|
||||
"operation": operation_mode,
|
||||
"error": str(e)[:1024],
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_for_events},
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
|
||||
"""Worker to handle MCP OAuth flow asynchronously."""
|
||||
"""Worker to handle MCP OAuth flow asynchronously.
|
||||
|
||||
Publishes SSE events at each phase boundary so the frontend can
|
||||
drive the OAuth popup directly from the push channel. The
|
||||
``mcp.oauth.awaiting_redirect`` envelope carries the
|
||||
``authorization_url`` once the upstream OAuth client surfaces it,
|
||||
eliminating the prior polling-only path for that URL.
|
||||
"""
|
||||
|
||||
# Bind ``task_id`` and the publish helpers OUTSIDE the outer try so
|
||||
# the ``except`` handler at the bottom can reach them even when an
|
||||
# early statement raises. Without this, ``publish_oauth`` would
|
||||
# UnboundLocalError on top of the original failure.
|
||||
task_id = self.request.id if getattr(self, "request", None) else None
|
||||
|
||||
def publish_oauth(event_type: str, payload: Dict[str, Any]) -> None:
|
||||
# MCP OAuth can be invoked without a route-bound user_id by
|
||||
# legacy paths. Skip the SSE publish in that case — the caller
|
||||
# has no per-user channel to subscribe to, and the status is
|
||||
# surfaced via the task's return value.
|
||||
if not user_id or task_id is None:
|
||||
return
|
||||
publish_user_event(
|
||||
user_id,
|
||||
event_type,
|
||||
{"task_id": task_id, **payload},
|
||||
scope={"kind": "mcp_oauth", "id": task_id},
|
||||
)
|
||||
|
||||
def publish_awaiting_redirect(authorization_url: str) -> None:
|
||||
"""Callback invoked by ``DocsGPTOAuth.redirect_handler`` once
|
||||
the OAuth client has minted the authorization URL.
|
||||
|
||||
Carrying the URL on the SSE envelope lets the frontend open the
|
||||
popup directly from the event \u2014 the prior polling-only path
|
||||
for the URL is gone.
|
||||
"""
|
||||
publish_oauth(
|
||||
"mcp.oauth.awaiting_redirect",
|
||||
{
|
||||
"message": "Awaiting OAuth redirect...",
|
||||
"authorization_url": authorization_url,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
task_id = self.request.id
|
||||
redis_client = get_redis_instance()
|
||||
|
||||
def update_status(status_data: Dict[str, Any]):
|
||||
status_key = f"mcp_oauth_status:{task_id}"
|
||||
redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
|
||||
update_status(
|
||||
{
|
||||
"status": "in_progress",
|
||||
"message": "Starting OAuth...",
|
||||
"task_id": task_id,
|
||||
}
|
||||
)
|
||||
publish_oauth("mcp.oauth.in_progress", {"message": "Starting OAuth..."})
|
||||
|
||||
tool_config = config.copy()
|
||||
tool_config["oauth_task_id"] = task_id
|
||||
# Inject the awaiting-redirect publish callback. ``MCPTool`` pops
|
||||
# it out of the config and threads it into ``DocsGPTOAuth`` so
|
||||
# the publish fires synchronously from inside
|
||||
# ``redirect_handler`` — the only point where the URL is known.
|
||||
tool_config["oauth_redirect_publish"] = publish_awaiting_redirect
|
||||
mcp_tool = MCPTool(tool_config, user_id)
|
||||
|
||||
async def run_oauth_discovery():
|
||||
@@ -1534,14 +1988,6 @@ def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, An
|
||||
mcp_tool._setup_client()
|
||||
return await mcp_tool._execute_with_client("list_tools")
|
||||
|
||||
update_status(
|
||||
{
|
||||
"status": "awaiting_redirect",
|
||||
"message": "Awaiting OAuth redirect...",
|
||||
"task_id": task_id,
|
||||
}
|
||||
)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
@@ -1549,49 +1995,21 @@ def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, An
|
||||
loop.run_until_complete(run_oauth_discovery())
|
||||
tools = mcp_tool.get_actions_metadata()
|
||||
|
||||
update_status(
|
||||
{
|
||||
"status": "completed",
|
||||
"message": f"Connected \u2014 found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
|
||||
"tools": tools,
|
||||
"tools_count": len(tools),
|
||||
"task_id": task_id,
|
||||
}
|
||||
publish_oauth(
|
||||
"mcp.oauth.completed",
|
||||
{"tools": tools, "tools_count": len(tools)},
|
||||
)
|
||||
|
||||
return {"success": True, "tools": tools, "tools_count": len(tools)}
|
||||
except Exception as e:
|
||||
error_msg = f"OAuth failed: {str(e)}"
|
||||
logging.error("MCP OAuth discovery failed: %s", error_msg, exc_info=True)
|
||||
update_status(
|
||||
{
|
||||
"status": "error",
|
||||
"message": error_msg,
|
||||
"task_id": task_id,
|
||||
}
|
||||
)
|
||||
publish_oauth("mcp.oauth.failed", {"error": error_msg[:1024]})
|
||||
return {"success": False, "error": error_msg}
|
||||
finally:
|
||||
loop.close()
|
||||
except Exception as e:
|
||||
error_msg = f"OAuth init failed: {str(e)}"
|
||||
logging.error("MCP OAuth init failed: %s", error_msg, exc_info=True)
|
||||
update_status(
|
||||
{
|
||||
"status": "error",
|
||||
"message": error_msg,
|
||||
"task_id": task_id,
|
||||
}
|
||||
)
|
||||
publish_oauth("mcp.oauth.failed", {"error": error_msg[:1024]})
|
||||
return {"success": False, "error": error_msg}
|
||||
|
||||
|
||||
def mcp_oauth_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""Check the status of an MCP OAuth flow."""
|
||||
redis_client = get_redis_instance()
|
||||
status_key = f"mcp_oauth_status:{task_id}"
|
||||
|
||||
status_data = redis_client.get(status_key)
|
||||
if status_data:
|
||||
return json.loads(status_data)
|
||||
return {"status": "not_found", "message": "Status not found"}
|
||||
|
||||
385
docs/runbooks/sse-notifications.md
Normal file
385
docs/runbooks/sse-notifications.md
Normal file
@@ -0,0 +1,385 @@
|
||||
# SSE Notifications Runbook
|
||||
|
||||
> Operations guide for "user says they didn't get a notification" — and
|
||||
> the related "the bell never lights up" / "my upload toast hangs" /
|
||||
> "the chat answer doesn't reconnect" symptoms.
|
||||
|
||||
The user-facing notifications channel is the SSE pipe at
|
||||
`/api/events` plus per-message reconnects at
|
||||
`/api/messages/<id>/events`. This document maps a user complaint to
|
||||
the diagnostic that surfaces the cause.
|
||||
|
||||
---
|
||||
|
||||
## TL;DR — first 60 seconds
|
||||
|
||||
Run these three commands in parallel before anything else:
|
||||
|
||||
```bash
|
||||
# 1) Is Redis up and serving the pipe? Should print PONG instantly.
|
||||
redis-cli -n 2 PING
|
||||
|
||||
# 2) Anyone subscribed to the channel right now? Numbers per channel.
|
||||
redis-cli -n 2 PUBSUB NUMSUB user:<user_id>
|
||||
|
||||
# 3) Is the user's backlog populated? Returns the count of journaled events.
|
||||
redis-cli -n 2 XLEN user:<user_id>:stream
|
||||
```
|
||||
|
||||
- `PING` failing → Redis is the problem. Skip to "Redis-down".
|
||||
- `NUMSUB user:<user_id>` returns 0 → no client connected. Skip to "Stale event-stream client".
|
||||
- `XLEN user:<user_id>:stream` returns 0 or low → publisher isn't writing. Skip to "Publisher silent".
|
||||
- All three look healthy → the events are flowing on the wire; the issue is downstream of the slice (UI rendering, toast suppression, etc.). Work through the "Symptom → diagnostic map" below.
|
||||
|
||||
---
|
||||
|
||||
## Architecture cheat-sheet
|
||||
|
||||
```
|
||||
Worker (publish_user_event) Frontend tab
|
||||
│ ▲
|
||||
▼ │ GET /api/events SSE
|
||||
Redis Streams: XADD Flask route
|
||||
user:<id>:stream ──────────────► replay_backlog (snapshot)
|
||||
│ +
|
||||
▼ Topic.subscribe (live tail)
|
||||
Redis pub/sub: PUBLISH │
|
||||
user:<id> ────────────────────────────────┘
|
||||
```
|
||||
|
||||
**Source of truth:**
|
||||
- Persistent journal: Redis Stream `user:<user_id>:stream`, capped at
|
||||
`EVENTS_STREAM_MAXLEN` (default 1000) entries via `MAXLEN ~` — roughly 24h of history
|
||||
at typical event rates.
|
||||
- Live fan-out: Redis pub/sub channel `user:<user_id>`. No durability;
|
||||
subscribers must be attached at publish time.
|
||||
|
||||
The chat-stream pipe is separate, parallel infrastructure:
|
||||
- Journal: Postgres `message_events` table.
|
||||
- Live fan-out: Redis pub/sub `channel:<message_id>`.
|
||||
|
||||
Same patterns, different durability layer. This doc covers both;
|
||||
they share most diagnostic commands.
|
||||
|
||||
---
|
||||
|
||||
## Symptom → diagnostic map
|
||||
|
||||
### A. "I uploaded a source and the toast never appeared"
|
||||
|
||||
User flow: chat → upload → expect toast.
|
||||
|
||||
| Step | Command | Expect |
|
||||
| ------------------------------------------------- | ------------------------------------------------------------- | ----------------------------------------------- |
|
||||
| Worker received the task | `tail -f celery.log` filtered by user | `ingest_worker` start log line |
|
||||
| Worker published the queued event | `redis-cli -n 2 XREVRANGE user:<id>:stream + - COUNT 5` | A `source.ingest.queued` entry within seconds |
|
||||
| Frontend got it | DevTools → Network → `/api/events` → EventStream tab | `data: {"type":"source.ingest.queued",...}` |
|
||||
| Slice updated | Redux DevTools → state.upload.tasks | Task with matching `sourceId`, `status:'training'` |
|
||||
|
||||
If the worker's queued log line is there but the XADD didn't land →
|
||||
look for a `publish_user_event payload not JSON-serializable` warning
|
||||
in the worker log (the publisher swallows `TypeError`).
|
||||
|
||||
If the XADD landed but the frontend never received it → check
|
||||
`PUBSUB NUMSUB user:<id>` while the user is on the page. If 0, the
|
||||
SSE connection isn't subscribed; skip to "Stale event-stream client".
|
||||
|
||||
If the frontend received it but the toast didn't render → the
|
||||
`uploadSlice` extraReducer requires `task.sourceId` to match the
|
||||
event's `scope.id`. Check the upload route returned `source_id` in
|
||||
its POST response (the upload, connector, and reingest paths all
|
||||
include it). Idempotent / cached responses must also include
|
||||
`source_id` (`_claim_task_or_get_cached`).
|
||||
|
||||
### B. "The bell badge never goes up"
|
||||
|
||||
There is no bell — the global notifications surface is per-event
|
||||
toasts, not an aggregated counter. If the user is on an old build,
|
||||
have them hard-reload (`Cmd-Shift-R`) to bypass the cache. The surfaces they're looking for are
|
||||
`UploadToast` for source uploads and `ToolApprovalToast` for
|
||||
tool-approval events.
|
||||
|
||||
### C. "My chat answer froze mid-stream and never recovered"
|
||||
|
||||
User flow: ask question → answer streaming → network blip → answer
|
||||
stops; should reconnect.
|
||||
|
||||
```bash
|
||||
# Was the original message reserved in PG?
|
||||
psql -c "SELECT id, status, prompt FROM conversation_messages \
|
||||
WHERE user_id = '<user>' ORDER BY timestamp DESC LIMIT 5;"
|
||||
|
||||
# Did the journal capture events past the user's last-seen seq?
|
||||
psql -c "SELECT sequence_no, event_type FROM message_events \
|
||||
WHERE message_id = '<id>' ORDER BY sequence_no;"
|
||||
|
||||
# Is the live tail still producing? (subscribe and watch)
|
||||
redis-cli -n 2 SUBSCRIBE channel:<message_id>
|
||||
```
|
||||
|
||||
The frontend should reconnect via `GET /api/messages/<id>/events`
|
||||
when the original POST stream closes without a typed `end` or
|
||||
`error` event. If it's not reconnecting, `console.warn('Stream
|
||||
reconnect failed', ...)` will be in the browser console — the
|
||||
reconnect HTTP request errored. Common cases:
|
||||
|
||||
- The user's JWT rotated mid-stream → 401 on the GET. Frontend
|
||||
doesn't auto-refresh; the user reloads.
|
||||
- The user is on a different host than the API and CORS is rejecting
|
||||
the GET → check `application/asgi.py` allow-headers.
|
||||
|
||||
### D. "The dev install never delivers any notifications at all"
|
||||
|
||||
With `AUTH_TYPE` unset (the default), `decoded_token` is `{"sub": "local"}`
|
||||
for every request. The SSE client connects without the
|
||||
`Authorization` header in this case, and `user:local:stream` is
|
||||
the shared channel everything goes to. If the user has multiple dev
|
||||
machines pointing at the same Redis, they will see each other's
|
||||
events. Confirm with:
|
||||
|
||||
```bash
|
||||
redis-cli -n 2 KEYS 'user:local:*'
|
||||
```
|
||||
|
||||
If multiple deployments share the Redis, document that as a known
|
||||
multi-user-on-local-channel limitation. Set `AUTH_TYPE=simple_jwt`
|
||||
to scope per-user.
|
||||
|
||||
### E. "The notifications channel was working, then suddenly stopped after the user reloaded the page"
|
||||
|
||||
Likely path: `backlog.truncated` event fired, the slice cleared
|
||||
`lastEventId` to null, the closure was carrying the same stale id and
|
||||
re-tripped the same truncation on every reconnect. **Verify the user
|
||||
is on a current build — `eventStreamClient.ts` must re-read
|
||||
`lastEventId = opts.getLastEventId();` without a truthy guard so the
|
||||
null clear propagates into the next reconnect.**
|
||||
|
||||
### F. "I keep getting 429 on /api/events"
|
||||
|
||||
The per-user concurrent-connection cap (`SSE_MAX_CONCURRENT_PER_USER`,
|
||||
default 8) refused the connection. User has too many tabs open or a
|
||||
runaway reconnect loop. `redis-cli -n 2 GET user:<id>:sse_count`
|
||||
shows the live counter; the TTL is 1h from the last connection
|
||||
attempt (rolling — every INCR re-seeds it), so the key only ages
|
||||
out after the user stops reconnecting for a full hour.
|
||||
|
||||
If the count is wedged high without explanation, the
|
||||
counter-DECR-in-finally path didn't run (worker SIGKILL, OOM). Wait
|
||||
for the TTL, or run `redis-cli -n 2 DEL user:<id>:sse_count` to reset it.
|
||||
|
||||
### G. "Replay snapshot stops at 200 events"
|
||||
|
||||
The route caps each replay at `EVENTS_REPLAY_MAX_PER_REQUEST`
|
||||
(default 200). The cap is intentionally **silent** — the route does NOT
|
||||
emit a `backlog.truncated` notice for cap-hit. The 200 entries each
|
||||
carry their own `id:` header, so the frontend's slice cursor
|
||||
advances to the most-recent delivered id. Next reconnect sends
|
||||
`last_event_id=<max_replayed>` and the snapshot resumes from there.
|
||||
A user that was 1000 entries behind catches up over ~5 reconnects.
|
||||
|
||||
If the user reports getting HTTP 429 on `/api/events` despite being
|
||||
well under `SSE_MAX_CONCURRENT_PER_USER`, they hit the windowed
|
||||
replay budget (`EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW`, default
|
||||
30 / `EVENTS_REPLAY_BUDGET_WINDOW_SECONDS` 60s). The route refuses
|
||||
the connection so the slice cursor stays pinned at whatever value
|
||||
it had; the frontend backs off and the next reconnect (after the
|
||||
window rolls) gets the proper snapshot. Serving the live tail
|
||||
without a snapshot used to be the behavior here, but that let the
|
||||
client advance `lastEventId` past entries it never received,
|
||||
permanently stranding the un-replayed window — so the route now
|
||||
429s instead. `redis-cli -n 2 GET user:<id>:replay_count` shows the
|
||||
current counter; TTL is the window size.
|
||||
|
||||
`backlog.truncated` is emitted ONLY when the client's
|
||||
`Last-Event-ID` has slid off the MAXLEN'd window — i.e. the journal
|
||||
is genuinely gone past the cursor and the frontend should clear the
|
||||
slice cursor and refetch state. Treating cap-hit or
|
||||
budget-exhaustion the same way would lock the user into re-receiving
|
||||
the oldest 200 entries on every reconnect (the cursor would clear,
|
||||
the snapshot would re-serve from the start, the cap would re-trip).
|
||||
|
||||
### H. "User says push notifications stopped after a deploy"
|
||||
|
||||
- Pull `event.published topic=user:<id> type=...` from the worker
|
||||
logs to confirm the publisher is still firing.
|
||||
- Pull `event.connect user=<id>` from the API logs to confirm the
|
||||
client is reconnecting.
|
||||
- Check the gunicorn worker count and `WSGIMiddleware(workers=32)` —
|
||||
if the deploy reduced worker count, the per-user cap is still 8
|
||||
but total concurrent SSE connections are bounded by `gunicorn
|
||||
workers × 32`. A capacity miss looks like users randomly getting
|
||||
429'd.
|
||||
|
||||
---
|
||||
|
||||
## Common failure modes
|
||||
|
||||
### Redis-down
|
||||
|
||||
Symptoms: `/api/events` returns 200 but emits only `: connected`
|
||||
then the body closes. `XLEN` and `PUBLISH` both fail. The publisher's
|
||||
`record_event` swallows the failure and returns False; the live tail
|
||||
publish also drops on the floor. Frontend retries forever with
|
||||
exponential backoff.
|
||||
|
||||
Resolution: bring Redis back. The journal is gone (was in-memory
|
||||
only — Streams persist within a single Redis instance, no replication
|
||||
configured). New events flow as soon as Redis comes back.
|
||||
|
||||
### `AUTH_TYPE` misconfigured = sub:"local" cross-stream
|
||||
|
||||
Symptoms: every user shares `user:local:stream`. Any user sees
|
||||
everyone else's notifications.
|
||||
|
||||
Resolution: set `AUTH_TYPE=simple_jwt` (or `session_jwt`) in `.env`.
|
||||
The events route logs a one-time WARNING per process when
|
||||
`sub == "local"` is observed. A repeat WARNING after a restart
|
||||
confirms the misconfiguration.
|
||||
|
||||
### MAXLEN trimmed past Last-Event-ID
|
||||
|
||||
Symptoms: client reconnects with `last_event_id=X`, snapshot returns
|
||||
the entire MAXLEN'd backlog (because X is older than the oldest
|
||||
retained entry). Old events appear duplicated.
|
||||
|
||||
Detection: the route's `_oldest_retained_id` check emits
|
||||
`backlog.truncated` when this case fires. Frontend's
|
||||
`dispatchSSEEvent` clears `lastEventId` so the next reconnect starts
|
||||
fresh.
|
||||
|
||||
If the WARNING isn't firing but symptoms match: the user's client
|
||||
may have a corrupt cached `lastEventId`. `localStorage` doesn't
|
||||
store this state; check Redux state via DevTools.
|
||||
|
||||
### Stale event-stream client
|
||||
|
||||
Symptoms: events visible in `XRANGE` but the frontend slice doesn't
|
||||
update.
|
||||
|
||||
```bash
|
||||
# Is the client subscribed?
|
||||
redis-cli -n 2 PUBSUB NUMSUB user:<id>
|
||||
|
||||
# When did its connection start?
|
||||
grep "event.connect user=<id>" /var/log/docsgpt.log | tail -3
|
||||
```
|
||||
|
||||
If `NUMSUB` is 0 and no recent `event.connect`, the user's tab is
|
||||
closed or the connection died and never reconnected. Push them to
|
||||
reload.
|
||||
|
||||
### Publisher silent
|
||||
|
||||
Symptoms: worker is processing the task (Celery says SUCCESS), but
|
||||
no XADD and no PUBLISH. User sees no events.
|
||||
|
||||
```bash
|
||||
# Was the publisher import error suppressed?
|
||||
grep "publish_user_event" /var/log/celery.log | grep -i "warn\|error" | tail -20
|
||||
|
||||
# Is push disabled?
|
||||
grep "ENABLE_SSE_PUSH" /var/log/docsgpt.log | tail -5
|
||||
```
|
||||
|
||||
`ENABLE_SSE_PUSH=False` in `.env` would silence the publisher
|
||||
globally. Useful for incident response if a runaway publisher is
|
||||
DoS'ing Redis; toggle off, fix root cause, toggle on.
|
||||
|
||||
---
|
||||
|
||||
## Useful one-liners
|
||||
|
||||
```bash
|
||||
# Watch a user's live event stream in real time (all events, all types)
|
||||
redis-cli -n 2 PSUBSCRIBE 'user:*' | grep "user:<id>"
|
||||
|
||||
# Last 10 events the user would see on reconnect
|
||||
redis-cli -n 2 XREVRANGE user:<id>:stream + - COUNT 10
|
||||
|
||||
# Live count of subscribed clients per user
|
||||
redis-cli -n 2 PUBSUB NUMSUB $(redis-cli -n 2 PUBSUB CHANNELS 'user:*')
|
||||
|
||||
# Trim a runaway stream (CAREFUL — destroys backlog for all current
|
||||
# subscribers; OK after explaining to the user)
|
||||
redis-cli -n 2 XTRIM user:<id>:stream MAXLEN 0
|
||||
|
||||
# Clear a wedged concurrent-connection counter
|
||||
redis-cli -n 2 DEL user:<id>:sse_count
|
||||
|
||||
# Force-flip every client to re-snapshot (drop the stream key entirely
|
||||
# — destroys the backlog; clients reconnect with their last id and
|
||||
# get a backlog.truncated)
|
||||
redis-cli -n 2 DEL user:<id>:stream
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Settings reference
|
||||
|
||||
Everything in `application/core/settings.py`:
|
||||
|
||||
| Setting | Default | Purpose |
|
||||
| --------------------------------------------- | ------- | --------------------------------------------- |
|
||||
| `ENABLE_SSE_PUSH` | `True` | Master switch. False = publisher no-ops, route serves "push_disabled" comment. |
|
||||
| `EVENTS_STREAM_MAXLEN` | `1000` | Per-user backlog cap. Approximate via `XADD MAXLEN ~`. |
|
||||
| `SSE_KEEPALIVE_SECONDS` | `15` | Comment-frame cadence. Must sit under reverse-proxy idle close. |
|
||||
| `SSE_MAX_CONCURRENT_PER_USER` | `8` | Cap on simultaneous SSE connections per user. 0 = disabled. |
|
||||
| `EVENTS_REPLAY_MAX_PER_REQUEST` | `200` | Hard cap on snapshot rows per request. |
|
||||
| `EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW` | `30` | Per-user replays per window. 0 = disabled. |
|
||||
| `EVENTS_REPLAY_BUDGET_WINDOW_SECONDS` | `60` | Window length. |
|
||||
| `MESSAGE_EVENTS_RETENTION_DAYS` | `14` | Retention for the `message_events` journal; `cleanup_message_events` beat task deletes older rows. |
|
||||
|
||||
---
|
||||
|
||||
## Known limitations
|
||||
|
||||
### Each tab runs its own SSE connection
|
||||
|
||||
There is no cross-tab dedup. Every tab open to the app holds its
|
||||
own SSE connection and dispatches every received event into its
|
||||
own Redux store, so a user with N tabs open will see N copies of
|
||||
each toast. With `SSE_MAX_CONCURRENT_PER_USER=8` (the default) a
|
||||
heavy multi-tab user can also hit the connection cap and start
|
||||
seeing 429s. Cross-tab dedup via a `BroadcastChannel` ring +
|
||||
`navigator.locks`-based leader election is tracked as future work.
|
||||
|
||||
### `/c/<unknown-id>` normalises to `/c/new`
|
||||
|
||||
If a user navigates to a conversation id that isn't in their
|
||||
loaded list, the conversation route rewrites the URL to `/c/new`.
|
||||
`ToolApprovalToast`'s gate uses `useMatch('/c/:conversationId')`,
|
||||
so for the brief window after the rewrite the toast may surface
|
||||
for a conversation the user *thought* they were already viewing.
|
||||
Pre-existing route behaviour; not a notifications regression.
|
||||
|
||||
### Terminal events un-dismiss running uploads
|
||||
|
||||
`frontend/src/upload/uploadSlice.ts` sets `dismissed: false` when
|
||||
an upload reaches `completed` or `failed`. If the user dismissed a
|
||||
running task and the terminal SSE arrives later, the toast pops
|
||||
back. Intentional ("notify the user it's done"); revisit if the
|
||||
re-surface UX is too aggressive for v2.
|
||||
|
||||
### Werkzeug doesn't auto-reload route files
|
||||
|
||||
The dev server (`flask run`) doesn't watch
|
||||
`application/api/events/routes.py` for changes by default.
|
||||
After editing the route, restart Flask manually — `--reload`
|
||||
isn't on. (Production gunicorn reloads via deploy.)
|
||||
|
||||
### MCP OAuth completion can fall outside the user stream's MAXLEN window
|
||||
|
||||
`get_oauth_status` scans up to `EVENTS_STREAM_MAXLEN` (~1000) entries via `XREVRANGE`. If the user has a high-rate ingest running concurrently with the OAuth handshake, the `mcp.oauth.completed` envelope can be trimmed off the back before they click Save. Symptom: backend returns "OAuth failed or not completed" even though the popup completed successfully.
|
||||
|
||||
Mitigation today: bump `EVENTS_STREAM_MAXLEN` per-deployment if your users routinely flood the channel during OAuth flows. A dedicated short-TTL Redis key for OAuth task results is tracked as a follow-up.
|
||||
|
||||
### React StrictMode double-mounts SSE
|
||||
|
||||
In dev, React 18 StrictMode mounts → unmounts → remounts every
|
||||
component, briefly opening two SSE connections per tab before the
|
||||
first is aborted. With `SSE_MAX_CONCURRENT_PER_USER=8` and 4–5
|
||||
tabs open concurrently you can transiently hit the cap and see
|
||||
HTTP 429 on cold-load. The first connection's counter increment
|
||||
fires before the AbortController-induced disconnect can decrement
|
||||
it. Production (single mount, no StrictMode) is unaffected; raise
|
||||
the cap in dev or accept transient 429s.
|
||||
1531
frontend/package-lock.json
generated
1531
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -7,6 +7,8 @@
|
||||
"dev": "vite",
|
||||
"build": "tsc && vite build",
|
||||
"preview": "vite preview",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest",
|
||||
"lint": "eslint ./src --ext .jsx,.js,.ts,.tsx",
|
||||
"lint-fix": "eslint ./src --ext .jsx,.js,.ts,.tsx --fix",
|
||||
"format": "prettier ./src --write",
|
||||
@@ -69,6 +71,7 @@
|
||||
"eslint-plugin-promise": "^6.6.0",
|
||||
"eslint-plugin-react": "^7.37.5",
|
||||
"eslint-plugin-unused-imports": "^4.1.4",
|
||||
"happy-dom": "^17.6.3",
|
||||
"husky": "^9.1.7",
|
||||
"lint-staged": "^16.4.0",
|
||||
"postcss": "^8.5.12",
|
||||
@@ -78,6 +81,7 @@
|
||||
"tw-animate-css": "^1.4.0",
|
||||
"typescript": "^6.0.3",
|
||||
"vite": "^8.0.10",
|
||||
"vite-plugin-svgr": "^4.3.0"
|
||||
"vite-plugin-svgr": "^4.3.0",
|
||||
"vitest": "^3.2.4"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import Spinner from './components/Spinner';
|
||||
import UploadToast from './components/UploadToast';
|
||||
import Conversation from './conversation/Conversation';
|
||||
import { SharedConversation } from './conversation/SharedConversation';
|
||||
import { EventStreamProvider } from './events/EventStreamProvider';
|
||||
import { useDarkTheme, useMediaQuery } from './hooks';
|
||||
import useDataInitializer from './hooks/useDataInitializer';
|
||||
import useTokenAuth from './hooks/useTokenAuth';
|
||||
@@ -17,6 +18,7 @@ import Navigation from './Navigation';
|
||||
import PageNotFound from './PageNotFound';
|
||||
import Setting from './settings';
|
||||
import Notification from './components/Notification';
|
||||
import ToolApprovalToast from './notifications/ToolApprovalToast';
|
||||
|
||||
function AuthWrapper({ children }: { children: React.ReactNode }) {
|
||||
const { isAuthLoading } = useTokenAuth();
|
||||
@@ -29,7 +31,7 @@ function AuthWrapper({ children }: { children: React.ReactNode }) {
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return <>{children}</>;
|
||||
return <EventStreamProvider>{children}</EventStreamProvider>;
|
||||
}
|
||||
|
||||
function MainLayout() {
|
||||
@@ -50,6 +52,7 @@ function MainLayout() {
|
||||
<Outlet />
|
||||
</div>
|
||||
<UploadToast />
|
||||
<ToolApprovalToast />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -85,6 +88,13 @@ export default function App() {
|
||||
}
|
||||
>
|
||||
<Route index element={<Conversation />} />
|
||||
{/* One dynamic route (accepting "new" or a UUID) so the
|
||||
/c/new → /c/<id> replace doesn't remount Conversation. */}
|
||||
<Route path="/c/:conversationId" element={<Conversation />} />
|
||||
<Route
|
||||
path="/agents/:agentId/c/:conversationId"
|
||||
element={<Conversation />}
|
||||
/>
|
||||
<Route path="/settings/*" element={<Setting />} />
|
||||
<Route path="/agents/*" element={<Agents />} />
|
||||
</Route>
|
||||
|
||||
@@ -25,6 +25,7 @@ import UnPin from './assets/unpin.svg';
|
||||
import Help from './components/Help';
|
||||
import {
|
||||
handleAbort,
|
||||
loadConversation,
|
||||
selectQueries,
|
||||
setConversation,
|
||||
updateConversationId,
|
||||
@@ -50,6 +51,7 @@ import {
|
||||
setSelectedAgent,
|
||||
setSharedAgents,
|
||||
} from './preferences/preferenceSlice';
|
||||
import { AppDispatch } from './store';
|
||||
import Upload from './upload/Upload';
|
||||
|
||||
interface NavigationProps {
|
||||
@@ -58,7 +60,7 @@ interface NavigationProps {
|
||||
}
|
||||
|
||||
export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
const dispatch = useDispatch();
|
||||
const dispatch = useDispatch<AppDispatch>();
|
||||
const navigate = useNavigate();
|
||||
|
||||
const { t } = useTranslation();
|
||||
@@ -182,7 +184,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
resetConversation();
|
||||
dispatch(setSelectedAgent(agent));
|
||||
if (isMobile || isTablet) setNavOpen(!navOpen);
|
||||
navigate('/');
|
||||
navigate(agent.id ? `/agents/${agent.id}/c/new` : '/c/new');
|
||||
};
|
||||
|
||||
const handleTogglePin = (agent: Agent) => {
|
||||
@@ -200,20 +202,21 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
try {
|
||||
dispatch(setSelectedAgent(null));
|
||||
|
||||
const response = await conversationService.getConversation(index, token);
|
||||
if (!response.ok) {
|
||||
navigate('/');
|
||||
// Pre-fetch to choose the route shape (owned-agent / shared / none).
|
||||
const result = await dispatch(
|
||||
loadConversation({ id: index, force: true }),
|
||||
).unwrap();
|
||||
// Stale: a newer load has already updated Redux; the URL is
|
||||
// wherever that newer flow lands, leave it alone.
|
||||
if (result.stale) return;
|
||||
const data = result.data;
|
||||
if (!data) {
|
||||
navigate('/c/new');
|
||||
return;
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
if (!data) return;
|
||||
|
||||
dispatch(setConversation(data.queries));
|
||||
dispatch(updateConversationId({ query: { conversationId: index } }));
|
||||
|
||||
if (!data.agent_id) {
|
||||
navigate('/');
|
||||
navigate(`/c/${index}`);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -224,7 +227,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
token,
|
||||
);
|
||||
if (!sharedResponse.ok) {
|
||||
navigate('/');
|
||||
navigate(`/c/${index}`);
|
||||
return;
|
||||
}
|
||||
agent = await sharedResponse.json();
|
||||
@@ -232,7 +235,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
} else {
|
||||
const agentResponse = await userService.getAgent(data.agent_id, token);
|
||||
if (!agentResponse.ok) {
|
||||
navigate('/');
|
||||
navigate(`/c/${index}`);
|
||||
return;
|
||||
}
|
||||
agent = await agentResponse.json();
|
||||
@@ -240,12 +243,12 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
navigate(`/agents/shared/${agent.shared_token}`);
|
||||
} else {
|
||||
await Promise.resolve(dispatch(setSelectedAgent(agent)));
|
||||
navigate('/');
|
||||
navigate(`/agents/${data.agent_id}/c/${index}`);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error handling conversation click:', error);
|
||||
navigate('/');
|
||||
navigate('/c/new');
|
||||
}
|
||||
};
|
||||
|
||||
@@ -264,6 +267,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
if (queries && queries?.length > 0) {
|
||||
resetConversation();
|
||||
}
|
||||
navigate('/c/new');
|
||||
};
|
||||
|
||||
async function updateConversationName(updatedConversation: {
|
||||
@@ -275,7 +279,6 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
.then((response) => response.json())
|
||||
.then((data) => {
|
||||
if (data) {
|
||||
navigate('/');
|
||||
fetchConversations();
|
||||
}
|
||||
})
|
||||
@@ -370,7 +373,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
</button>
|
||||
</div>
|
||||
<NavLink
|
||||
to={'/'}
|
||||
to={'/c/new'}
|
||||
onClick={() => {
|
||||
if (isMobile || isTablet) {
|
||||
setNavOpen(!navOpen);
|
||||
|
||||
@@ -174,7 +174,7 @@ export default function AgentCard({
|
||||
if (section === 'user') {
|
||||
if (agent.status === 'published') {
|
||||
dispatch(setSelectedAgent(agent));
|
||||
navigate(`/`);
|
||||
navigate(agent.id ? `/agents/${agent.id}/c/new` : '/c/new');
|
||||
}
|
||||
}
|
||||
if (section === 'shared') {
|
||||
|
||||
@@ -565,8 +565,22 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
setJsonSchemaText(jsonText);
|
||||
setJsonSchemaValid(true);
|
||||
}
|
||||
setAgent(data);
|
||||
initialAgentRef.current = data;
|
||||
// Backfill required fields so older agents (created before
|
||||
// agent_type / prompt_id / models existed) don't fail
|
||||
// ``isPublishable()`` and leave Save permanently disabled.
|
||||
const normalized = {
|
||||
...data,
|
||||
agent_type: data.agent_type || 'classic',
|
||||
prompt_id: data.prompt_id || 'default',
|
||||
retriever: data.retriever || 'classic',
|
||||
chunks: data.chunks || '2',
|
||||
tools: data.tools || [],
|
||||
sources: data.sources || [],
|
||||
models: data.models || [],
|
||||
default_model_id: data.default_model_id || '',
|
||||
};
|
||||
setAgent(normalized);
|
||||
initialAgentRef.current = normalized;
|
||||
};
|
||||
getAgent();
|
||||
}
|
||||
|
||||
@@ -1,8 +1,18 @@
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import EditIcon from '../assets/edit.svg';
|
||||
import AgentImage from '../components/AgentImage';
|
||||
import { getToolDisplayName } from '../utils/toolUtils';
|
||||
import { Agent } from './types';
|
||||
|
||||
export default function SharedAgentCard({ agent }: { agent: Agent }) {
|
||||
export default function SharedAgentCard({
|
||||
agent,
|
||||
onEdit,
|
||||
}: {
|
||||
agent: Agent;
|
||||
onEdit?: () => void;
|
||||
}) {
|
||||
const { t } = useTranslation();
|
||||
// Check if shared metadata exists and has properties (type is 'any' so we validate it's a non-empty object)
|
||||
const hasSharedMetadata =
|
||||
agent.shared_metadata &&
|
||||
@@ -11,14 +21,14 @@ export default function SharedAgentCard({ agent }: { agent: Agent }) {
|
||||
Object.keys(agent.shared_metadata).length > 0;
|
||||
return (
|
||||
<div className="border-border dark:border-border flex w-full max-w-[720px] flex-col rounded-3xl border p-6 shadow-xs sm:w-fit sm:min-w-[480px]">
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="flex h-12 w-12 items-center justify-center overflow-hidden rounded-full p-1">
|
||||
<AgentImage
|
||||
src={agent.image}
|
||||
className="h-full w-full rounded-full object-contain"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex max-h-[92px] w-[80%] flex-col gap-px">
|
||||
<div className="flex max-h-[92px] flex-1 flex-col gap-px">
|
||||
<h2 className="text-foreground text-base font-semibold sm:text-lg">
|
||||
{agent.name}
|
||||
</h2>
|
||||
@@ -26,6 +36,17 @@ export default function SharedAgentCard({ agent }: { agent: Agent }) {
|
||||
{agent.description}
|
||||
</p>
|
||||
</div>
|
||||
{onEdit && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={onEdit}
|
||||
className="border-border hover:bg-accent text-foreground flex shrink-0 items-center gap-1.5 rounded-full border px-3 py-1.5 text-sm font-medium transition-colors"
|
||||
aria-label={t('agents.edit')}
|
||||
>
|
||||
<img src={EditIcon} alt="" className="h-3.5 w-3.5" />
|
||||
{t('agents.edit')}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
{hasSharedMetadata && (
|
||||
<div className="mt-4 flex items-center gap-8">
|
||||
|
||||
@@ -813,7 +813,11 @@ function WorkflowBuilderInner() {
|
||||
const response = await userService.getWorkflow(workflowId, token);
|
||||
if (!response.ok) throw new Error('Failed to fetch workflow');
|
||||
const responseData = await response.json();
|
||||
const { workflow, nodes: apiNodes, edges: apiEdges } = responseData.data;
|
||||
const {
|
||||
workflow,
|
||||
nodes: apiNodes,
|
||||
edges: apiEdges,
|
||||
} = responseData.data;
|
||||
const nextWorkflowName = workflow.name;
|
||||
const nextWorkflowDescription = workflow.description || '';
|
||||
const mappedNodes = apiNodes.map((n: WorkflowNode) => {
|
||||
@@ -1472,7 +1476,9 @@ function WorkflowBuilderInner() {
|
||||
{t('agents.form.advanced.systemPromptOverride')}
|
||||
</label>
|
||||
<p className="mt-0.5 text-[11px] text-gray-500 dark:text-gray-400">
|
||||
{t('agents.form.advanced.systemPromptOverrideDescription')}
|
||||
{t(
|
||||
'agents.form.advanced.systemPromptOverrideDescription',
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import { withThrottle, type FetchLike } from './throttle';
|
||||
|
||||
export const baseURL =
|
||||
import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
|
||||
|
||||
@@ -18,112 +20,121 @@ const getHeaders = (
|
||||
return headers;
|
||||
};
|
||||
|
||||
const apiClient = {
|
||||
get: (
|
||||
url: string,
|
||||
token: string | null,
|
||||
headers = {},
|
||||
signal?: AbortSignal,
|
||||
): Promise<any> =>
|
||||
fetch(`${baseURL}${url}`, {
|
||||
method: 'GET',
|
||||
headers: getHeaders(token, headers),
|
||||
signal,
|
||||
}).then((response) => {
|
||||
return response;
|
||||
}),
|
||||
const createClient = (transport: FetchLike) => {
|
||||
const request = (url: string, init: RequestInit): Promise<Response> =>
|
||||
transport(`${baseURL}${url}`, init);
|
||||
|
||||
post: (
|
||||
url: string,
|
||||
data: any,
|
||||
token: string | null,
|
||||
headers = {},
|
||||
signal?: AbortSignal,
|
||||
): Promise<any> =>
|
||||
fetch(`${baseURL}${url}`, {
|
||||
method: 'POST',
|
||||
headers: getHeaders(token, headers),
|
||||
body: JSON.stringify(data),
|
||||
signal,
|
||||
}).then((response) => {
|
||||
return response;
|
||||
}),
|
||||
return {
|
||||
get: (
|
||||
url: string,
|
||||
token: string | null,
|
||||
headers = {},
|
||||
signal?: AbortSignal,
|
||||
): Promise<any> =>
|
||||
request(url, {
|
||||
method: 'GET',
|
||||
headers: getHeaders(token, headers),
|
||||
signal,
|
||||
}),
|
||||
|
||||
postFormData: (
|
||||
url: string,
|
||||
formData: FormData,
|
||||
token: string | null,
|
||||
headers = {},
|
||||
signal?: AbortSignal,
|
||||
): Promise<Response> => {
|
||||
return fetch(`${baseURL}${url}`, {
|
||||
method: 'POST',
|
||||
headers: getHeaders(token, headers, true),
|
||||
body: formData,
|
||||
signal,
|
||||
});
|
||||
},
|
||||
post: (
|
||||
url: string,
|
||||
data: any,
|
||||
token: string | null,
|
||||
headers = {},
|
||||
signal?: AbortSignal,
|
||||
): Promise<any> =>
|
||||
request(url, {
|
||||
method: 'POST',
|
||||
headers: getHeaders(token, headers),
|
||||
body: JSON.stringify(data),
|
||||
signal,
|
||||
}),
|
||||
|
||||
put: (
|
||||
url: string,
|
||||
data: any,
|
||||
token: string | null,
|
||||
headers = {},
|
||||
signal?: AbortSignal,
|
||||
): Promise<any> =>
|
||||
fetch(`${baseURL}${url}`, {
|
||||
method: 'PUT',
|
||||
headers: getHeaders(token, headers),
|
||||
body: JSON.stringify(data),
|
||||
signal,
|
||||
}).then((response) => {
|
||||
return response;
|
||||
}),
|
||||
postFormData: (
|
||||
url: string,
|
||||
formData: FormData,
|
||||
token: string | null,
|
||||
headers = {},
|
||||
signal?: AbortSignal,
|
||||
): Promise<Response> =>
|
||||
request(url, {
|
||||
method: 'POST',
|
||||
headers: getHeaders(token, headers, true),
|
||||
body: formData,
|
||||
signal,
|
||||
}),
|
||||
|
||||
patch: (
|
||||
url: string,
|
||||
data: any,
|
||||
token: string | null,
|
||||
headers = {},
|
||||
signal?: AbortSignal,
|
||||
): Promise<any> =>
|
||||
fetch(`${baseURL}${url}`, {
|
||||
method: 'PATCH',
|
||||
headers: getHeaders(token, headers),
|
||||
body: JSON.stringify(data),
|
||||
signal,
|
||||
}).then((response) => {
|
||||
return response;
|
||||
}),
|
||||
put: (
|
||||
url: string,
|
||||
data: any,
|
||||
token: string | null,
|
||||
headers = {},
|
||||
signal?: AbortSignal,
|
||||
): Promise<any> =>
|
||||
request(url, {
|
||||
method: 'PUT',
|
||||
headers: getHeaders(token, headers),
|
||||
body: JSON.stringify(data),
|
||||
signal,
|
||||
}),
|
||||
|
||||
putFormData: (
|
||||
url: string,
|
||||
formData: FormData,
|
||||
token: string | null,
|
||||
headers = {},
|
||||
signal?: AbortSignal,
|
||||
): Promise<Response> => {
|
||||
return fetch(`${baseURL}${url}`, {
|
||||
method: 'PUT',
|
||||
headers: getHeaders(token, headers, true),
|
||||
body: formData,
|
||||
signal,
|
||||
});
|
||||
},
|
||||
patch: (
|
||||
url: string,
|
||||
data: any,
|
||||
token: string | null,
|
||||
headers = {},
|
||||
signal?: AbortSignal,
|
||||
): Promise<any> =>
|
||||
request(url, {
|
||||
method: 'PATCH',
|
||||
headers: getHeaders(token, headers),
|
||||
body: JSON.stringify(data),
|
||||
signal,
|
||||
}),
|
||||
|
||||
delete: (
|
||||
url: string,
|
||||
token: string | null,
|
||||
headers = {},
|
||||
signal?: AbortSignal,
|
||||
): Promise<any> =>
|
||||
fetch(`${baseURL}${url}`, {
|
||||
method: 'DELETE',
|
||||
headers: getHeaders(token, headers),
|
||||
signal,
|
||||
}).then((response) => {
|
||||
return response;
|
||||
}),
|
||||
putFormData: (
|
||||
url: string,
|
||||
formData: FormData,
|
||||
token: string | null,
|
||||
headers = {},
|
||||
signal?: AbortSignal,
|
||||
): Promise<Response> =>
|
||||
request(url, {
|
||||
method: 'PUT',
|
||||
headers: getHeaders(token, headers, true),
|
||||
body: formData,
|
||||
signal,
|
||||
}),
|
||||
|
||||
delete: (
|
||||
url: string,
|
||||
token: string | null,
|
||||
headers = {},
|
||||
signal?: AbortSignal,
|
||||
): Promise<any> =>
|
||||
request(url, {
|
||||
method: 'DELETE',
|
||||
headers: getHeaders(token, headers),
|
||||
signal,
|
||||
}),
|
||||
};
|
||||
};
|
||||
|
||||
const apiClient = createClient((url, init) => fetch(url, init));
|
||||
|
||||
// Throttled client for endpoints that fan out, are polled, or are commonly
|
||||
// requested concurrently from multiple components. Shares a single concurrency
|
||||
// budget and de-duplicates identical in-flight GETs.
|
||||
export const throttledApiClient = createClient(
|
||||
withThrottle((url, init) => fetch(url, init), { debugLabel: 'api' }),
|
||||
);
|
||||
|
||||
if (import.meta.env.DEV && typeof window !== 'undefined') {
|
||||
(window as unknown as Record<string, unknown>).__apiClient = apiClient;
|
||||
(window as unknown as Record<string, unknown>).__throttledApiClient =
|
||||
throttledApiClient;
|
||||
(window as unknown as Record<string, unknown>).__baseURL = baseURL;
|
||||
}
|
||||
|
||||
export default apiClient;
|
||||
|
||||
@@ -28,7 +28,6 @@ const endpoints = {
|
||||
UPDATE_PROMPT: '/api/update_prompt',
|
||||
SINGLE_PROMPT: (id: string) => `/api/get_single_prompt?id=${id}`,
|
||||
DELETE_PATH: (docPath: string) => `/api/delete_old?source_id=${docPath}`,
|
||||
TASK_STATUS: (task_id: string) => `/api/task_status?task_id=${task_id}`,
|
||||
MESSAGE_ANALYTICS: '/api/get_message_analytics',
|
||||
TOKEN_ANALYTICS: '/api/get_token_analytics',
|
||||
FEEDBACK_ANALYTICS: '/api/get_feedback_analytics',
|
||||
@@ -43,6 +42,11 @@ const endpoints = {
|
||||
DELETE_TOOL: '/api/delete_tool',
|
||||
PARSE_SPEC: '/api/parse_spec',
|
||||
SYNC_CONNECTOR: '/api/connectors/sync',
|
||||
CONNECTOR_AUTH: (provider: string) =>
|
||||
`/api/connectors/auth?provider=${provider}`,
|
||||
CONNECTOR_FILES: '/api/connectors/files',
|
||||
CONNECTOR_VALIDATE_SESSION: '/api/connectors/validate-session',
|
||||
CONNECTOR_DISCONNECT: '/api/connectors/disconnect',
|
||||
GET_CHUNKS: (
|
||||
docId: string,
|
||||
page: number,
|
||||
@@ -59,6 +63,7 @@ const endpoints = {
|
||||
UPDATE_CHUNK: '/api/update_chunk',
|
||||
STORE_ATTACHMENT: '/api/store_attachment',
|
||||
STT: '/api/stt',
|
||||
TTS: '/api/tts',
|
||||
LIVE_STT_START: '/api/stt/live/start',
|
||||
LIVE_STT_CHUNK: '/api/stt/live/chunk',
|
||||
LIVE_STT_FINISH: '/api/stt/live/finish',
|
||||
@@ -67,8 +72,6 @@ const endpoints = {
|
||||
MANAGE_SOURCE_FILES: '/api/manage_source_files',
|
||||
MCP_TEST_CONNECTION: '/api/mcp_server/test',
|
||||
MCP_SAVE_SERVER: '/api/mcp_server/save',
|
||||
MCP_OAUTH_STATUS: (task_id: string) =>
|
||||
`/api/mcp_server/oauth_status/${task_id}`,
|
||||
MCP_AUTH_STATUS: '/api/mcp_server/auth_status',
|
||||
AGENT_FOLDERS: '/api/agents/folders/',
|
||||
AGENT_FOLDER: (id: string) => `/api/agents/folders/${id}`,
|
||||
@@ -92,6 +95,7 @@ const endpoints = {
|
||||
FEEDBACK: '/api/feedback',
|
||||
CONVERSATION: (id: string) => `/api/get_single_conversation?id=${id}`,
|
||||
CONVERSATIONS: '/api/get_conversations',
|
||||
MESSAGE_TAIL: (messageId: string) => `/api/messages/${messageId}/tail`,
|
||||
SHARE_CONVERSATION: (isPromptable: boolean) =>
|
||||
`/api/share?isPromptable=${isPromptable}`,
|
||||
SHARED_CONVERSATION: (identifier: string) =>
|
||||
|
||||
@@ -6,18 +6,20 @@ const conversationService = {
|
||||
data: any,
|
||||
token: string | null,
|
||||
signal: AbortSignal,
|
||||
headers: Record<string, string> = {},
|
||||
): Promise<any> =>
|
||||
apiClient.post(endpoints.CONVERSATION.ANSWER, data, token, {}, signal),
|
||||
apiClient.post(endpoints.CONVERSATION.ANSWER, data, token, headers, signal),
|
||||
answerStream: (
|
||||
data: any,
|
||||
token: string | null,
|
||||
signal: AbortSignal,
|
||||
headers: Record<string, string> = {},
|
||||
): Promise<any> =>
|
||||
apiClient.post(
|
||||
endpoints.CONVERSATION.ANSWER_STREAMING,
|
||||
data,
|
||||
token,
|
||||
{},
|
||||
headers,
|
||||
signal,
|
||||
),
|
||||
search: (data: any, token: string | null): Promise<any> =>
|
||||
@@ -26,6 +28,8 @@ const conversationService = {
|
||||
apiClient.post(endpoints.CONVERSATION.FEEDBACK, data, token, {}),
|
||||
getConversation: (id: string, token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.CONVERSATION.CONVERSATION(id), token, {}),
|
||||
tailMessage: (messageId: string, token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.CONVERSATION.MESSAGE_TAIL(messageId), token, {}),
|
||||
getConversations: (token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.CONVERSATION.CONVERSATIONS, token, {}),
|
||||
shareConversation: (
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import { getSessionToken } from '../../utils/providerUtils';
|
||||
import apiClient from '../client';
|
||||
import apiClient, { throttledApiClient } from '../client';
|
||||
import endpoints from '../endpoints';
|
||||
|
||||
const userService = {
|
||||
getConfig: (): Promise<any> => apiClient.get(endpoints.USER.CONFIG, null),
|
||||
getConfig: (): Promise<any> =>
|
||||
throttledApiClient.get(endpoints.USER.CONFIG, null),
|
||||
getNewToken: (): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.NEW_TOKEN, null),
|
||||
throttledApiClient.get(endpoints.USER.NEW_TOKEN, null),
|
||||
getDocs: (token: string | null): Promise<any> =>
|
||||
apiClient.get(`${endpoints.USER.DOCS}`, token),
|
||||
getDocsWithPagination: (query: string, token: string | null): Promise<any> =>
|
||||
@@ -17,9 +18,9 @@ const userService = {
|
||||
deleteAPIKey: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.DELETE_API_KEY, data, token),
|
||||
getAgent: (id: string, token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.AGENT(id), token),
|
||||
throttledApiClient.get(endpoints.USER.AGENT(id), token),
|
||||
getAgents: (token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.AGENTS, token),
|
||||
throttledApiClient.get(endpoints.USER.AGENTS, token),
|
||||
createAgent: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.postFormData(endpoints.USER.CREATE_AGENT, data, token),
|
||||
updateAgent: (
|
||||
@@ -31,19 +32,19 @@ const userService = {
|
||||
deleteAgent: (id: string, token: string | null): Promise<any> =>
|
||||
apiClient.delete(endpoints.USER.DELETE_AGENT(id), token),
|
||||
getPinnedAgents: (token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.PINNED_AGENTS, token),
|
||||
throttledApiClient.get(endpoints.USER.PINNED_AGENTS, token),
|
||||
togglePinAgent: (id: string, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.TOGGLE_PIN_AGENT(id), {}, token),
|
||||
getSharedAgent: (id: string, token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.SHARED_AGENT(id), token),
|
||||
getSharedAgents: (token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.SHARED_AGENTS, token),
|
||||
throttledApiClient.get(endpoints.USER.SHARED_AGENTS, token),
|
||||
shareAgent: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.put(endpoints.USER.SHARE_AGENT, data, token),
|
||||
removeSharedAgent: (id: string, token: string | null): Promise<any> =>
|
||||
apiClient.delete(endpoints.USER.REMOVE_SHARED_AGENT(id), token),
|
||||
getTemplateAgents: (token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.TEMPLATE_AGENTS, token),
|
||||
throttledApiClient.get(endpoints.USER.TEMPLATE_AGENTS, token),
|
||||
adoptAgent: (id: string, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.ADOPT_AGENT(id), {}, token),
|
||||
getAgentWebhook: (id: string, token: string | null): Promise<any> =>
|
||||
@@ -60,8 +61,6 @@ const userService = {
|
||||
apiClient.get(endpoints.USER.SINGLE_PROMPT(id), token),
|
||||
deletePath: (docPath: string, token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.DELETE_PATH(docPath), token),
|
||||
getTaskStatus: (task_id: string, token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.TASK_STATUS(task_id), token),
|
||||
getMessageAnalytics: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.MESSAGE_ANALYTICS, data, token),
|
||||
getTokenAnalytics: (data: any, token: string | null): Promise<any> =>
|
||||
@@ -149,7 +148,7 @@ const userService = {
|
||||
path?: string,
|
||||
search?: string,
|
||||
): Promise<any> =>
|
||||
apiClient.get(
|
||||
throttledApiClient.get(
|
||||
endpoints.USER.GET_CHUNKS(docId, page, perPage, path, search),
|
||||
token,
|
||||
),
|
||||
@@ -164,17 +163,15 @@ const userService = {
|
||||
updateChunk: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.put(endpoints.USER.UPDATE_CHUNK, data, token),
|
||||
getDirectoryStructure: (docId: string, token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
|
||||
throttledApiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
|
||||
manageSourceFiles: (data: FormData, token: string | null): Promise<any> =>
|
||||
apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token),
|
||||
testMCPConnection: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
|
||||
saveMCPServer: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
|
||||
getMCPOAuthStatus: (task_id: string, token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.MCP_OAUTH_STATUS(task_id), token),
|
||||
getMCPAuthStatus: (token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.MCP_AUTH_STATUS, token),
|
||||
throttledApiClient.get(endpoints.USER.MCP_AUTH_STATUS, token),
|
||||
syncConnector: (
|
||||
docId: string,
|
||||
provider: string,
|
||||
@@ -191,8 +188,50 @@ const userService = {
|
||||
token,
|
||||
);
|
||||
},
|
||||
getConnectorAuthUrl: (provider: string, token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.CONNECTOR_AUTH(provider), token),
|
||||
getConnectorFiles: (
|
||||
data: any,
|
||||
token: string | null,
|
||||
signal?: AbortSignal,
|
||||
): Promise<any> =>
|
||||
throttledApiClient.post(
|
||||
endpoints.USER.CONNECTOR_FILES,
|
||||
data,
|
||||
token,
|
||||
{},
|
||||
signal,
|
||||
),
|
||||
validateConnectorSession: (
|
||||
provider: string,
|
||||
token: string | null,
|
||||
): Promise<any> =>
|
||||
apiClient.post(
|
||||
endpoints.USER.CONNECTOR_VALIDATE_SESSION,
|
||||
{
|
||||
provider,
|
||||
session_token: getSessionToken(provider),
|
||||
},
|
||||
token,
|
||||
),
|
||||
disconnectConnector: (
|
||||
provider: string,
|
||||
sessionToken: string,
|
||||
token: string | null,
|
||||
): Promise<any> =>
|
||||
apiClient.post(
|
||||
endpoints.USER.CONNECTOR_DISCONNECT,
|
||||
{ provider, session_token: sessionToken },
|
||||
token,
|
||||
),
|
||||
textToSpeech: (
|
||||
text: string,
|
||||
token: string | null,
|
||||
signal?: AbortSignal,
|
||||
): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.TTS, { text }, token, {}, signal),
|
||||
getAgentFolders: (token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.AGENT_FOLDERS, token),
|
||||
throttledApiClient.get(endpoints.USER.AGENT_FOLDERS, token),
|
||||
createAgentFolder: (
|
||||
data: { name: string; parent_id?: string },
|
||||
token: string | null,
|
||||
|
||||
223
frontend/src/api/throttle.ts
Normal file
223
frontend/src/api/throttle.ts
Normal file
@@ -0,0 +1,223 @@
|
||||
/**
|
||||
* Transport-layer middleware factory for the frontend API layer.
|
||||
*/
|
||||
|
||||
/**
 * Minimal fetch-shaped transport signature: takes a URL string and optional
 * `RequestInit`, and resolves to a `Response`.
 */
export type FetchLike = (
|
||||
input: string,
|
||||
init?: RequestInit,
|
||||
) => Promise<Response>;
|
||||
|
||||
/** Options accepted by `withThrottle`; every field is optional. */
export interface ThrottleConfig {
|
||||
// Cap on requests in flight across all routes; falls back to
// DEFAULT_MAX_CONCURRENT_GLOBAL when omitted.
maxConcurrentGlobal?: number;
|
||||
// Cap on requests in flight for one route key; falls back to
// DEFAULT_MAX_CONCURRENT_PER_ROUTE when omitted.
maxConcurrentPerRoute?: number;
|
||||
// De-duplication stays enabled unless this is explicitly `false`.
dedupe?: boolean;
|
||||
// Optional per-request hook; returning `false` opts that request out of
// de-duplication. NOTE(review): the string return is presumably the dedupe
// key itself — confirm against the rest of withThrottle.
dedupeKey?: (url: string, init?: RequestInit) => string | false;
|
||||
// When set, debug log lines are tagged `[throttle:<debugLabel>]`.
debugLabel?: string;
|
||||
}
|
||||
|
||||
// Default cap on requests in flight across all routes for one throttled wrapper.
const DEFAULT_MAX_CONCURRENT_GLOBAL = 8;
|
||||
// Default cap on requests in flight for a single route key.
const DEFAULT_MAX_CONCURRENT_PER_ROUTE = 3;
|
||||
|
||||
/** A request waiting in a per-route queue for a concurrency slot. */
type QueueItem = {
|
||||
// Called when the slot is granted: claims the slot and lets the waiter proceed.
run: () => void;
|
||||
// The caller's AbortSignal, if one was supplied with the request.
signal?: AbortSignal;
|
||||
// Abort listener attached while the item is queued; it is detached again
// when the item is dequeued for dispatch.
onAbort: () => void;
|
||||
};
|
||||
|
||||
/**
 * Build the key used to group requests for per-route throttling.
 *
 * The key is the upper-cased HTTP method followed by the URL's pathname, so
 * requests that differ only in their query string share one key.
 */
function routeKey(method: string, url: string): string {
|
||||
let pathname = url;
|
||||
try {
|
||||
// The placeholder base lets relative URLs parse; only the pathname is kept.
pathname = new URL(url, 'http://_').pathname;
|
||||
} catch {
|
||||
// Fallback when URL parsing throws: keep everything before the first '?'.
pathname = url.split('?')[0];
|
||||
}
|
||||
return `${method.toUpperCase()} ${pathname}`;
|
||||
}
|
||||
|
||||
/**
 * Create a `DOMException` named 'AbortError', used to reject a waiter whose
 * signal is already aborted or aborts while the request is still queued.
 */
function abortError(): DOMException {
|
||||
return new DOMException('The operation was aborted.', 'AbortError');
|
||||
}
|
||||
|
||||
/** Mutable bookkeeping held by one throttled wrapper. */
interface ThrottleState {
|
||||
// Waiting requests, keyed by route key.
perRouteQueues: Map<string, QueueItem[]>;
|
||||
// Number of requests currently in flight per route key.
inflightPerRoute: Map<string, number>;
|
||||
// NOTE(review): presumably the shared promises for de-duplicated in-flight
// GETs — its use lies outside this view; confirm.
inflightGets: Map<string, Promise<Response>>;
|
||||
// Number of requests currently in flight across all routes.
inflightGlobal: number;
|
||||
}
|
||||
|
||||
/**
 * Return a fresh `ThrottleState`: three empty maps and a zeroed global
 * in-flight counter. Each call produces independent objects.
 */
function createState(): ThrottleState {
|
||||
return {
|
||||
perRouteQueues: new Map(),
|
||||
inflightPerRoute: new Map(),
|
||||
inflightGets: new Map(),
|
||||
inflightGlobal: 0,
|
||||
};
|
||||
}
|
||||
|
||||
export function withThrottle(
|
||||
fetchLike: FetchLike,
|
||||
config: ThrottleConfig = {},
|
||||
): FetchLike & { __reset: () => void } {
|
||||
const maxGlobal = config.maxConcurrentGlobal ?? DEFAULT_MAX_CONCURRENT_GLOBAL;
|
||||
const maxPerRoute =
|
||||
config.maxConcurrentPerRoute ?? DEFAULT_MAX_CONCURRENT_PER_ROUTE;
|
||||
const dedupeEnabled = config.dedupe !== false;
|
||||
const state = createState();
|
||||
|
||||
// Toggle in DevTools with: localStorage.setItem('debug:throttle', '1')
|
||||
const isDebug = (): boolean => {
|
||||
try {
|
||||
return (
|
||||
typeof localStorage !== 'undefined' &&
|
||||
localStorage.getItem('debug:throttle') === '1'
|
||||
);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
const log = (
|
||||
event: string,
|
||||
key: string,
|
||||
extra?: Record<string, unknown>,
|
||||
): void => {
|
||||
if (!isDebug()) return;
|
||||
const queued = state.perRouteQueues.get(key)?.length ?? 0;
|
||||
const perRoute = state.inflightPerRoute.get(key) ?? 0;
|
||||
const tag = config.debugLabel
|
||||
? `[throttle:${config.debugLabel}]`
|
||||
: '[throttle]';
|
||||
console.debug(
|
||||
`${tag} ${event} ${key} | inflight=${state.inflightGlobal}/${maxGlobal} route=${perRoute}/${maxPerRoute} queued=${queued}`,
|
||||
extra ?? '',
|
||||
);
|
||||
};
|
||||
|
||||
const canDispatch = (key: string): boolean => {
|
||||
const perRoute = state.inflightPerRoute.get(key) ?? 0;
|
||||
return state.inflightGlobal < maxGlobal && perRoute < maxPerRoute;
|
||||
};
|
||||
|
||||
const pumpQueues = (): void => {
|
||||
for (const [key, queue] of state.perRouteQueues) {
|
||||
while (queue.length > 0 && canDispatch(key)) {
|
||||
const item = queue.shift()!;
|
||||
item.signal?.removeEventListener('abort', item.onAbort);
|
||||
item.run();
|
||||
}
|
||||
if (queue.length === 0) state.perRouteQueues.delete(key);
|
||||
}
|
||||
};
|
||||
|
||||
const enqueue = (key: string, item: QueueItem): void => {
|
||||
let queue = state.perRouteQueues.get(key);
|
||||
if (!queue) {
|
||||
queue = [];
|
||||
state.perRouteQueues.set(key, queue);
|
||||
}
|
||||
queue.push(item);
|
||||
};
|
||||
|
||||
// Resolves once the caller holds a concurrency slot for `key`.
//
// The slot is claimed immediately when nothing is already queued for the
// route and the limits allow it; otherwise the request waits in the route's
// queue until `pumpQueues` runs it. Rejects with `abortError()` if `signal`
// is already aborted, or aborts while the request is still queued.
const acquireSlot = (key: string, signal?: AbortSignal): Promise<void> =>
  new Promise((resolve, reject) => {
    if (signal?.aborted) {
      reject(abortError());
      return;
    }

    // Books the slot in both counters, then lets the caller proceed.
    const claim = (): void => {
      state.inflightGlobal += 1;
      const routeInflight = state.inflightPerRoute.get(key) ?? 0;
      state.inflightPerRoute.set(key, routeInflight + 1);
      resolve();
    };

    // Abort while still queued: leave the queue and reject. No slot was
    // taken, so there is nothing to release.
    const cancel = (): void => {
      const waiting = state.perRouteQueues.get(key);
      if (waiting) {
        const position = waiting.indexOf(item);
        if (position >= 0) waiting.splice(position, 1);
      }
      log('abort-queued', key);
      reject(abortError());
    };

    const item: QueueItem = { signal, run: claim, onAbort: cancel };

    // Only skip the queue when no earlier request for this route is waiting.
    const ahead = state.perRouteQueues.get(key);
    const nothingAhead = !ahead || ahead.length === 0;
    if (nothingAhead && canDispatch(key)) {
      item.run();
      log('dispatch', key);
      return;
    }

    signal?.addEventListener('abort', item.onAbort, { once: true });
    enqueue(key, item);
    log('queued', key);
  });
|
||||
|
||||
// Returns the slot held for `key` and lets queued requests start.
const releaseSlot = (key: string): void => {
  // Clamped at zero so a release after a state reset cannot go negative.
  state.inflightGlobal = Math.max(0, state.inflightGlobal - 1);

  const remaining = (state.inflightPerRoute.get(key) ?? 1) - 1;
  if (remaining <= 0) {
    // Idle routes are removed rather than kept with a zero count.
    state.inflightPerRoute.delete(key);
  } else {
    state.inflightPerRoute.set(key, remaining);
  }

  log('release', key);
  pumpQueues();
};
|
||||
|
||||
// Throttled, optionally de-duplicating stand-in for `fetchLike`.
//
// Each call derives a route key from method + URL, waits for a concurrency
// slot (global and per-route limits), performs the request, and releases the
// slot once the request settles. Concurrent eligible GETs that share a dedupe
// key reuse one underlying request; each caller receives its own clone of
// the Response.
//
// url:  request target, forwarded unchanged to `fetchLike`.
// init: request options, forwarded unchanged; `method` defaults to 'GET'.
// Returns the Response from `fetchLike` (a clone on the dedupe path).
// Rejects with whatever `acquireSlot` or `fetchLike` rejects with.
const wrapped = (async (url, init = {}) => {
  const method = (init.method ?? 'GET').toUpperCase();
  const signal = init.signal ?? undefined;
  const key = routeKey(method, url);

  // Dedupe is restricted to GETs without a caller-supplied AbortSignal:
  // sharing a single underlying fetch across waiters means an abort by one
  // caller would reject the others, which is not the contract callers expect.
  const customKey = config.dedupeKey?.(url, init);
  const dedupeAllowed =
    dedupeEnabled &&
    customKey !== false &&
    method === 'GET' &&
    !init.body &&
    !signal;
  const dedupeKey = typeof customKey === 'string' ? customKey : `GET ${url}`;

  if (dedupeAllowed) {
    const existing = state.inflightGets.get(dedupeKey);
    if (existing) {
      log('dedupe-hit', key, { dedupeKey });
      return existing.then((r) => r.clone());
    }
  }

  // Acquire a slot, run the request, and always give the slot back.
  const run = async (): Promise<Response> => {
    await acquireSlot(key, signal);
    try {
      return await fetchLike(url, init);
    } finally {
      releaseSlot(key);
    }
  };

  if (dedupeAllowed) {
    const promise = run();
    state.inflightGets.set(dedupeKey, promise);

    // Only remove the entry if it is still ours; a reset or a later request
    // may have replaced it in the meantime.
    const forget = (): void => {
      if (state.inflightGets.get(dedupeKey) === promise) {
        state.inflightGets.delete(dedupeKey);
      }
    };
    // `promise.finally(forget)` would create a second promise that rejects
    // whenever the request fails; with no handler attached it surfaces as an
    // unhandled rejection. Registering `forget` for both outcomes yields a
    // derived promise that always fulfils, while callers still observe the
    // failure through the promise returned below.
    void promise.then(forget, forget);

    return promise.then((r) => r.clone());
  }

  return run();
}) as FetchLike & { __reset: () => void };
|
||||
|
||||
// Restores the shared throttle state to empty: zeroes the global counter and
// clears the per-route counters, the dedupe table and the wait queues.
// Queued waiters are dropped without being resolved or rejected.
wrapped.__reset = () => {
  state.inflightGlobal = 0;
  state.inflightGets.clear();
  state.inflightPerRoute.clear();
  state.perRouteQueues.clear();
};
|
||||
|
||||
return wrapped;
|
||||
}
|
||||
@@ -40,7 +40,7 @@ export default function ActionButtons({
|
||||
query: { conversationId: null },
|
||||
}),
|
||||
);
|
||||
navigate('/');
|
||||
navigate('/c/new');
|
||||
};
|
||||
return (
|
||||
<div
|
||||
|
||||
@@ -11,6 +11,7 @@ import NoFilesIcon from '../assets/no-files.svg';
|
||||
import SearchIcon from '../assets/search.svg';
|
||||
import {
|
||||
useDarkTheme,
|
||||
useDebouncedValue,
|
||||
useLoaderState,
|
||||
useMediaQuery,
|
||||
useOutsideAlerter,
|
||||
@@ -130,6 +131,7 @@ const Chunks: React.FC<ChunksProps> = ({
|
||||
const [totalChunks, setTotalChunks] = useState(0);
|
||||
const [loading, setLoading] = useLoaderState(true);
|
||||
const [searchTerm, setSearchTerm] = useState<string>('');
|
||||
const debouncedSearchTerm = useDebouncedValue(searchTerm, 300);
|
||||
const [editingChunk, setEditingChunk] = useState<ChunkType | null>(null);
|
||||
const [editingTitle, setEditingTitle] = useState('');
|
||||
const [editingText, setEditingText] = useState('');
|
||||
@@ -151,7 +153,7 @@ const Chunks: React.FC<ChunksProps> = ({
|
||||
perPage,
|
||||
token,
|
||||
path,
|
||||
searchTerm,
|
||||
debouncedSearchTerm,
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
@@ -276,16 +278,12 @@ const Chunks: React.FC<ChunksProps> = ({
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const delayDebounceFn = setTimeout(() => {
|
||||
if (page !== 1) {
|
||||
setPage(1);
|
||||
} else {
|
||||
fetchChunks();
|
||||
}
|
||||
}, 300);
|
||||
|
||||
return () => clearTimeout(delayDebounceFn);
|
||||
}, [searchTerm]);
|
||||
if (page !== 1) {
|
||||
setPage(1);
|
||||
} else {
|
||||
fetchChunks();
|
||||
}
|
||||
}, [debouncedSearchTerm]);
|
||||
|
||||
useEffect(() => {
|
||||
!loading && fetchChunks();
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import React, { useRef } from 'react';
|
||||
import React, { useEffect, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import { useDarkTheme } from '../hooks';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
|
||||
@@ -31,13 +32,24 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
const [isDarkTheme] = useDarkTheme();
|
||||
const completedRef = useRef(false);
|
||||
const intervalRef = useRef<number | null>(null);
|
||||
const authWindowRef = useRef<Window | null>(null);
|
||||
// Hold the exact listener identity so unmount cleanup removes the same fn.
|
||||
const messageHandlerRef = useRef<((event: MessageEvent) => void) | null>(
|
||||
null,
|
||||
);
|
||||
// Tracks mount status so async ``fetch`` resolves after unmount don't
|
||||
// call ``onSuccess`` / ``onError`` on a vanished parent.
|
||||
const mountedRef = useRef(true);
|
||||
|
||||
const cleanup = () => {
|
||||
if (intervalRef.current) {
|
||||
clearInterval(intervalRef.current);
|
||||
intervalRef.current = null;
|
||||
}
|
||||
window.removeEventListener('message', handleAuthMessage as any);
|
||||
if (messageHandlerRef.current) {
|
||||
window.removeEventListener('message', messageHandlerRef.current as any);
|
||||
messageHandlerRef.current = null;
|
||||
}
|
||||
};
|
||||
|
||||
const handleAuthMessage = (event: MessageEvent) => {
|
||||
@@ -48,6 +60,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
if (successGeneric || successProvider) {
|
||||
completedRef.current = true;
|
||||
cleanup();
|
||||
authWindowRef.current = null;
|
||||
onSuccess({
|
||||
session_token: event.data.session_token,
|
||||
user_email:
|
||||
@@ -57,6 +70,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
} else if (errorProvider) {
|
||||
completedRef.current = true;
|
||||
cleanup();
|
||||
authWindowRef.current = null;
|
||||
onError(
|
||||
event.data.error || t('modals.uploadDoc.connectors.auth.authFailed'),
|
||||
);
|
||||
@@ -66,15 +80,20 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
const handleAuth = async () => {
|
||||
try {
|
||||
completedRef.current = false;
|
||||
// Close any popup left over from a previous click before wiping
|
||||
// the ref — otherwise the old window keeps living with no
|
||||
// interval watching it and no listener handling its messages.
|
||||
if (authWindowRef.current && !authWindowRef.current.closed) {
|
||||
authWindowRef.current.close();
|
||||
}
|
||||
authWindowRef.current = null;
|
||||
cleanup();
|
||||
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
const authResponse = await fetch(
|
||||
`${apiHost}/api/connectors/auth?provider=${provider}`,
|
||||
{
|
||||
headers: { Authorization: `Bearer ${token}` },
|
||||
},
|
||||
const authResponse = await userService.getConnectorAuthUrl(
|
||||
provider,
|
||||
token,
|
||||
);
|
||||
if (!mountedRef.current) return;
|
||||
|
||||
if (!authResponse.ok) {
|
||||
throw new Error(
|
||||
@@ -83,6 +102,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
}
|
||||
|
||||
const authData = await authResponse.json();
|
||||
if (!mountedRef.current) return;
|
||||
if (!authData.success || !authData.authorization_url) {
|
||||
throw new Error(
|
||||
authData.error || t('modals.uploadDoc.connectors.auth.authUrlFailed'),
|
||||
@@ -97,13 +117,23 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
if (!authWindow) {
|
||||
throw new Error(t('modals.uploadDoc.connectors.auth.popupBlocked'));
|
||||
}
|
||||
authWindowRef.current = authWindow;
|
||||
|
||||
messageHandlerRef.current = handleAuthMessage;
|
||||
window.addEventListener('message', handleAuthMessage as any);
|
||||
|
||||
const checkClosed = window.setInterval(() => {
|
||||
if (authWindow.closed) {
|
||||
clearInterval(checkClosed);
|
||||
window.removeEventListener('message', handleAuthMessage as any);
|
||||
intervalRef.current = null;
|
||||
if (messageHandlerRef.current) {
|
||||
window.removeEventListener(
|
||||
'message',
|
||||
messageHandlerRef.current as any,
|
||||
);
|
||||
messageHandlerRef.current = null;
|
||||
}
|
||||
authWindowRef.current = null;
|
||||
if (!completedRef.current) {
|
||||
onError(t('modals.uploadDoc.connectors.auth.authCancelled'));
|
||||
}
|
||||
@@ -111,6 +141,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
}, 1000);
|
||||
intervalRef.current = checkClosed;
|
||||
} catch (error) {
|
||||
if (!mountedRef.current) return;
|
||||
onError(
|
||||
error instanceof Error
|
||||
? error.message
|
||||
@@ -119,6 +150,18 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
}
|
||||
};
|
||||
|
||||
// Release interval, message listener, and popup on unmount only.
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
mountedRef.current = false;
|
||||
cleanup();
|
||||
if (authWindowRef.current && !authWindowRef.current.closed) {
|
||||
authWindowRef.current.close();
|
||||
}
|
||||
authWindowRef.current = null;
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
{errorMessage && (
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import React, { useEffect, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { useSelector, useStore } from 'react-redux';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import ArrowLeft from '../assets/arrow-left.svg';
|
||||
@@ -14,6 +14,7 @@ import { useLoaderState, useOutsideAlerter } from '../hooks';
|
||||
import ConfirmationModal from '../modals/ConfirmationModal';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
import type { RootState } from '../store';
|
||||
import { formatBytes } from '../utils/stringUtils';
|
||||
import Chunks from './Chunks';
|
||||
import ContextMenu, { MenuOption } from './ContextMenu';
|
||||
@@ -64,6 +65,7 @@ const ConnectorTree: React.FC<ConnectorTreeProps> = ({
|
||||
useState<DirectoryStructure | null>(null);
|
||||
const [currentPath, setCurrentPath] = useState<string[]>([]);
|
||||
const token = useSelector(selectToken);
|
||||
const store = useStore<RootState>();
|
||||
const [activeMenuId, setActiveMenuId] = useState<string | null>(null);
|
||||
const menuRefs = useRef<{
|
||||
[key: string]: React.RefObject<HTMLDivElement | null>;
|
||||
@@ -81,6 +83,25 @@ const ConnectorTree: React.FC<ConnectorTreeProps> = ({
|
||||
const [syncDone, setSyncDone] = useState<boolean>(false);
|
||||
const [syncConfirmationModal, setSyncConfirmationModal] =
|
||||
useState<ActiveState>('INACTIVE');
|
||||
const mountedRef = useRef(true);
|
||||
const syncUnsubscribeRef = useRef<(() => void) | null>(null);
|
||||
// Holds the 5-minute SSE-wait timer so the unmount cleanup can clear
|
||||
// it — otherwise the timer fires up to 5 min after unmount and
|
||||
// resolves an abandoned Promise.
|
||||
const syncTimerRef = useRef<number | null>(null);
|
||||
|
||||
useEffect(
|
||||
() => () => {
|
||||
mountedRef.current = false;
|
||||
syncUnsubscribeRef.current?.();
|
||||
syncUnsubscribeRef.current = null;
|
||||
if (syncTimerRef.current !== null) {
|
||||
window.clearTimeout(syncTimerRef.current);
|
||||
syncTimerRef.current = null;
|
||||
}
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
useOutsideAlerter(
|
||||
searchDropdownRef,
|
||||
@@ -116,67 +137,108 @@ const ConnectorTree: React.FC<ConnectorTreeProps> = ({
|
||||
console.log('Sync started successfully:', data.task_id);
|
||||
setSyncProgress(10);
|
||||
|
||||
// Poll task status using userService
|
||||
const maxAttempts = 30;
|
||||
const pollInterval = 2000;
|
||||
// The connector worker (``ingest_connector`` in
|
||||
// ``application/worker.py``) publishes
|
||||
// ``source.ingest.{queued,completed,failed}`` envelopes keyed on
|
||||
// ``scope.id == docId`` (sync mode reuses the source uuid). Wait
|
||||
// on the bounded ``notifications.recentEvents`` ring for a
|
||||
// terminal envelope rather than polling ``/task_status``.
|
||||
// Mirrors FileTree's slice-walking pattern, including the
|
||||
// ``opStartedAt`` guard so a stale terminal event from a prior
|
||||
// sync of this same source can't short-circuit the current op.
|
||||
const opStartedAt = Date.now();
|
||||
|
||||
// Scans the store's recent notification events for a terminal ingest event
// belonging to `docId` and stamped at or after `opStartedAt`.
// Returns 'completed' or 'failed' for the first match, otherwise null.
const terminalFromSse = (): 'completed' | 'failed' | null => {
  const { recentEvents } = store.getState().notifications;
  for (const evt of recentEvents) {
    if (evt.scope?.id !== docId) continue;

    // Events with a missing or unparsable timestamp, or one older than this
    // operation, are ignored.
    const stamp = evt.ts ? Date.parse(evt.ts) : NaN;
    if (!Number.isFinite(stamp) || stamp < opStartedAt) continue;

    switch (evt.type) {
      case 'source.ingest.completed':
        return 'completed';
      case 'source.ingest.failed':
        return 'failed';
    }
  }
  return null;
};
|
||||
|
||||
const MAX_WAIT_MS = 5 * 60_000;
|
||||
const terminal = await new Promise<
|
||||
'completed' | 'failed' | 'timeout' | 'unmounted'
|
||||
>((resolve) => {
|
||||
// Cover the race where the event landed between the POST
|
||||
// returning and the subscribe call.
|
||||
const initial = terminalFromSse();
|
||||
if (initial) {
|
||||
resolve(initial);
|
||||
return;
|
||||
}
|
||||
if (!mountedRef.current) {
|
||||
resolve('unmounted');
|
||||
return;
|
||||
}
|
||||
let settled = false;
|
||||
const finish = (
|
||||
value: 'completed' | 'failed' | 'timeout' | 'unmounted',
|
||||
) => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
if (syncTimerRef.current !== null) {
|
||||
window.clearTimeout(syncTimerRef.current);
|
||||
syncTimerRef.current = null;
|
||||
}
|
||||
if (syncUnsubscribeRef.current) {
|
||||
syncUnsubscribeRef.current();
|
||||
syncUnsubscribeRef.current = null;
|
||||
}
|
||||
resolve(value);
|
||||
};
|
||||
syncTimerRef.current = window.setTimeout(
|
||||
() => finish('timeout'),
|
||||
MAX_WAIT_MS,
|
||||
);
|
||||
syncUnsubscribeRef.current = store.subscribe(() => {
|
||||
if (!mountedRef.current) {
|
||||
finish('unmounted');
|
||||
return;
|
||||
}
|
||||
const next = terminalFromSse();
|
||||
if (next) finish(next);
|
||||
});
|
||||
});
|
||||
|
||||
if (terminal === 'timeout') {
|
||||
console.error('Sync timed out waiting for SSE terminal');
|
||||
} else if (terminal === 'unmounted') {
|
||||
return;
|
||||
}
|
||||
|
||||
if (terminal === 'completed') {
|
||||
// The "no files downloaded" early-return path publishes
|
||||
// ``completed`` with ``no_changes: true`` — treated as success
|
||||
// here; refreshing the directory is cheap and idempotent.
|
||||
setSyncProgress(100);
|
||||
console.log('Sync completed successfully');
|
||||
|
||||
for (let attempt = 0; attempt < maxAttempts; attempt++) {
|
||||
try {
|
||||
const statusResponse = await userService.getTaskStatus(
|
||||
data.task_id,
|
||||
const refreshResponse = await userService.getDirectoryStructure(
|
||||
docId,
|
||||
token,
|
||||
);
|
||||
const statusData = await statusResponse.json();
|
||||
|
||||
console.log(
|
||||
`Task status (attempt ${attempt + 1}):`,
|
||||
statusData.status,
|
||||
);
|
||||
|
||||
if (statusData.status === 'SUCCESS') {
|
||||
setSyncProgress(100);
|
||||
console.log('Sync completed successfully');
|
||||
|
||||
// Refresh directory structure
|
||||
try {
|
||||
const refreshResponse = await userService.getDirectoryStructure(
|
||||
docId,
|
||||
token,
|
||||
);
|
||||
const refreshData = await refreshResponse.json();
|
||||
if (refreshData && refreshData.directory_structure) {
|
||||
setDirectoryStructure(refreshData.directory_structure);
|
||||
setCurrentPath([]);
|
||||
}
|
||||
if (refreshData && refreshData.provider) {
|
||||
setSourceProvider(refreshData.provider);
|
||||
}
|
||||
|
||||
setSyncDone(true);
|
||||
setTimeout(() => setSyncDone(false), 5000);
|
||||
} catch (err) {
|
||||
console.error('Error refreshing directory structure:', err);
|
||||
}
|
||||
break;
|
||||
} else if (statusData.status === 'FAILURE') {
|
||||
console.error('Sync task failed:', statusData.result);
|
||||
break;
|
||||
} else if (statusData.status === 'PROGRESS') {
|
||||
const progress = Number(
|
||||
statusData.result && statusData.result.current != null
|
||||
? statusData.result.current
|
||||
: statusData.meta && statusData.meta.current != null
|
||||
? statusData.meta.current
|
||||
: 0,
|
||||
);
|
||||
setSyncProgress(Math.max(10, progress));
|
||||
const refreshData = await refreshResponse.json();
|
||||
if (refreshData && refreshData.directory_structure) {
|
||||
setDirectoryStructure(refreshData.directory_structure);
|
||||
setCurrentPath([]);
|
||||
}
|
||||
if (refreshData && refreshData.provider) {
|
||||
setSourceProvider(refreshData.provider);
|
||||
}
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, pollInterval));
|
||||
} catch (error) {
|
||||
console.error('Error polling task status:', error);
|
||||
break;
|
||||
setSyncDone(true);
|
||||
setTimeout(() => setSyncDone(false), 5000);
|
||||
} catch (err) {
|
||||
console.error('Error refreshing directory structure:', err);
|
||||
}
|
||||
} else if (terminal === 'failed') {
|
||||
console.error('Sync task failed (per SSE)');
|
||||
}
|
||||
} else {
|
||||
console.error('Sync failed:', data.error);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import React, { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import userService from '../api/services/userService';
|
||||
import { formatBytes } from '../utils/stringUtils';
|
||||
import { formatDate } from '../utils/dateTimeUtils';
|
||||
import {
|
||||
@@ -22,6 +23,7 @@ import {
|
||||
TableHeader,
|
||||
TableCell,
|
||||
} from './Table';
|
||||
import { useDebouncedCallback } from '../hooks';
|
||||
|
||||
interface CloudFile {
|
||||
id: string;
|
||||
@@ -100,7 +102,6 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
|
||||
const [activeTab, setActiveTab] = useState<'my_files' | 'shared'>('my_files');
|
||||
|
||||
const scrollContainerRef = useRef<HTMLDivElement>(null);
|
||||
const searchTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
|
||||
const isFolder = (file: CloudFile) => {
|
||||
@@ -126,7 +127,6 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
|
||||
|
||||
setIsLoading(true);
|
||||
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
if (!pageToken) {
|
||||
setFiles([]);
|
||||
}
|
||||
@@ -141,15 +141,11 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
|
||||
search_query: searchQuery,
|
||||
shared: shared,
|
||||
};
|
||||
const response = await fetch(`${apiHost}/api/connectors/files`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
signal: controller.signal,
|
||||
});
|
||||
const response = await userService.getConnectorFiles(
|
||||
body,
|
||||
token,
|
||||
controller.signal,
|
||||
);
|
||||
|
||||
const data = await response.json();
|
||||
if (data.success) {
|
||||
@@ -187,20 +183,9 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
|
||||
}
|
||||
|
||||
try {
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
const validateResponse = await fetch(
|
||||
`${apiHost}/api/connectors/validate-session`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: provider,
|
||||
session_token: sessionToken,
|
||||
}),
|
||||
},
|
||||
const validateResponse = await userService.validateConnectorSession(
|
||||
provider,
|
||||
token,
|
||||
);
|
||||
|
||||
if (!validateResponse.ok) {
|
||||
@@ -292,32 +277,26 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (searchTimeoutRef.current) {
|
||||
clearTimeout(searchTimeoutRef.current);
|
||||
}
|
||||
abortControllerRef.current?.abort();
|
||||
};
|
||||
}, []);
|
||||
|
||||
const debouncedLoadFiles = useDebouncedCallback((query: string) => {
|
||||
const sessionToken = getSessionToken(provider);
|
||||
if (sessionToken) {
|
||||
loadCloudFiles(
|
||||
sessionToken,
|
||||
currentFolderId,
|
||||
undefined,
|
||||
query,
|
||||
activeTab === 'shared' && !currentFolderId,
|
||||
);
|
||||
}
|
||||
}, 300);
|
||||
|
||||
const handleSearchChange = (query: string) => {
|
||||
setSearchQuery(query);
|
||||
|
||||
if (searchTimeoutRef.current) {
|
||||
clearTimeout(searchTimeoutRef.current);
|
||||
}
|
||||
|
||||
searchTimeoutRef.current = setTimeout(() => {
|
||||
const sessionToken = getSessionToken(provider);
|
||||
if (sessionToken) {
|
||||
loadCloudFiles(
|
||||
sessionToken,
|
||||
currentFolderId,
|
||||
undefined,
|
||||
query,
|
||||
activeTab === 'shared' && !currentFolderId,
|
||||
);
|
||||
}
|
||||
}, 300);
|
||||
debouncedLoadFiles(query);
|
||||
};
|
||||
|
||||
const handleFolderClick = (folderId: string, folderName: string) => {
|
||||
@@ -424,23 +403,14 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
|
||||
onDisconnect={() => {
|
||||
const sessionToken = getSessionToken(provider);
|
||||
if (sessionToken) {
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
fetch(`${apiHost}/api/connectors/disconnect`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: provider,
|
||||
session_token: sessionToken,
|
||||
}),
|
||||
}).catch((err) =>
|
||||
console.error(
|
||||
`Error disconnecting from ${getProviderConfig(provider).displayName}:`,
|
||||
err,
|
||||
),
|
||||
);
|
||||
userService
|
||||
.disconnectConnector(provider, sessionToken, token)
|
||||
.catch((err) =>
|
||||
console.error(
|
||||
`Error disconnecting from ${getProviderConfig(provider).displayName}:`,
|
||||
err,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
removeSessionToken(provider);
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import React, { useState, useRef, useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { useSelector, useStore } from 'react-redux';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
import type { RootState } from '../store';
|
||||
import { formatBytes } from '../utils/stringUtils';
|
||||
import Chunks from './Chunks';
|
||||
import ContextMenu, { MenuOption } from './ContextMenu';
|
||||
@@ -56,6 +57,7 @@ const FileTree: React.FC<FileTreeProps> = ({
|
||||
onBackToDocuments,
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const store = useStore<RootState>();
|
||||
const [loading, setLoading] = useLoaderState(true, 500);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [directoryStructure, setDirectoryStructure] =
|
||||
@@ -95,6 +97,25 @@ const FileTree: React.FC<FileTreeProps> = ({
|
||||
const opQueueRef = useRef<QueuedOperation[]>([]);
|
||||
const processingRef = useRef(false);
|
||||
const [queueLength, setQueueLength] = useState(0);
|
||||
const mountedRef = useRef(true);
|
||||
const waitUnsubscribeRef = useRef<(() => void) | null>(null);
|
||||
// Holds the 5-minute SSE-wait timer so the unmount cleanup can clear
|
||||
// it — otherwise the timer fires up to 5 min after unmount and
|
||||
// resolves an abandoned Promise.
|
||||
const waitTimerRef = useRef<number | null>(null);
|
||||
|
||||
useEffect(
|
||||
() => () => {
|
||||
mountedRef.current = false;
|
||||
waitUnsubscribeRef.current?.();
|
||||
waitUnsubscribeRef.current = null;
|
||||
if (waitTimerRef.current !== null) {
|
||||
window.clearTimeout(waitTimerRef.current);
|
||||
waitTimerRef.current = null;
|
||||
}
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
useOutsideAlerter(
|
||||
searchDropdownRef,
|
||||
@@ -313,47 +334,103 @@ const FileTree: React.FC<FileTreeProps> = ({
|
||||
}
|
||||
console.log('Reingest task started:', result.reingest_task_id);
|
||||
|
||||
const maxAttempts = 30;
|
||||
const pollInterval = 2000;
|
||||
// SSE is the sole driver here. The backend's
|
||||
// ``reingest_source_worker`` publishes ``source.ingest.*``
|
||||
// keyed on the resolved ``source_id`` (the
|
||||
// ``manage_source_files`` route returns it explicitly so we
|
||||
// can match without consulting any slice). Subscribe to the
|
||||
// store and resolve when a terminal event tagged with our
|
||||
// source lands in ``notifications.recentEvents``. Re-checking
|
||||
// on every dispatch (rather than polling on a timer) avoids
|
||||
// races where a terminal could roll off the bounded ring
|
||||
// before the next tick observes it in chatty sessions.
|
||||
const reingestSourceId: string | undefined = result.source_id;
|
||||
// Cutoff so we don't pick up terminal events from a *previous*
|
||||
// reingest of the same source — the backend's
|
||||
// ``source.ingest.*`` payload doesn't carry a Celery task id,
|
||||
// so source_id alone is ambiguous when ops repeat.
|
||||
const opStartedAt = Date.now();
|
||||
const MAX_WAIT_MS = 5 * 60_000;
|
||||
|
||||
for (let attempt = 0; attempt < maxAttempts; attempt++) {
|
||||
try {
|
||||
const statusResponse = await userService.getTaskStatus(
|
||||
result.reingest_task_id,
|
||||
token,
|
||||
);
|
||||
const statusData = await statusResponse.json();
|
||||
|
||||
console.log(
|
||||
`Task status (attempt ${attempt + 1}):`,
|
||||
statusData.status,
|
||||
);
|
||||
|
||||
if (statusData.status === 'SUCCESS') {
|
||||
console.log('Task completed successfully');
|
||||
|
||||
const structureResponse = await userService.getDirectoryStructure(
|
||||
docId,
|
||||
token,
|
||||
);
|
||||
const structureData = await structureResponse.json();
|
||||
|
||||
if (structureData && structureData.directory_structure) {
|
||||
setDirectoryStructure(structureData.directory_structure);
|
||||
currentOpRef.current = null;
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
} else if (statusData.status === 'FAILURE') {
|
||||
console.error('Task failed');
|
||||
break;
|
||||
}
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, pollInterval));
|
||||
} catch (error) {
|
||||
console.error('Error polling task status:', error);
|
||||
break;
|
||||
// Scans the store's recent notification events for a terminal ingest event
// belonging to `reingestSourceId` and stamped at or after `opStartedAt`.
// Returns 'completed' or 'failed' for the first match, otherwise null
// (including when no source id is known).
const terminalFromSse = (): 'completed' | 'failed' | null => {
  if (!reingestSourceId) return null;

  const { recentEvents } = store.getState().notifications;
  for (const evt of recentEvents) {
    if (evt.scope?.id !== reingestSourceId) continue;

    // Events with a missing or unparsable timestamp, or one older than this
    // operation, are ignored.
    const stamp = evt.ts ? Date.parse(evt.ts) : NaN;
    if (!Number.isFinite(stamp) || stamp < opStartedAt) continue;

    switch (evt.type) {
      case 'source.ingest.completed':
        return 'completed';
      case 'source.ingest.failed':
        return 'failed';
    }
  }
  return null;
};
|
||||
|
||||
// Re-fetches the directory structure for `docId` and stores it in component
// state. Resolves to true when a structure was applied; false when the
// component unmounted in the meantime or the response carried no structure.
const refreshStructure = async (): Promise<boolean> => {
  const response = await userService.getDirectoryStructure(docId, token);
  const payload = await response.json();

  // Skip the state update if the component unmounted during the awaits.
  if (!mountedRef.current) return false;

  const tree = payload && payload.directory_structure;
  if (!tree) return false;

  setDirectoryStructure(tree);
  currentOpRef.current = null;
  return true;
};
|
||||
|
||||
const terminal = await new Promise<
|
||||
'completed' | 'failed' | 'timeout' | 'unmounted'
|
||||
>((resolve) => {
|
||||
if (!mountedRef.current) {
|
||||
resolve('unmounted');
|
||||
return;
|
||||
}
|
||||
// Cover the race where the terminal event landed between
|
||||
// the POST returning and the subscribe call.
|
||||
const initial = terminalFromSse();
|
||||
if (initial) {
|
||||
resolve(initial);
|
||||
return;
|
||||
}
|
||||
const timer = window.setTimeout(() => {
|
||||
waitUnsubscribeRef.current?.();
|
||||
waitUnsubscribeRef.current = null;
|
||||
waitTimerRef.current = null;
|
||||
resolve('timeout');
|
||||
}, MAX_WAIT_MS);
|
||||
waitTimerRef.current = timer;
|
||||
waitUnsubscribeRef.current = store.subscribe(() => {
|
||||
if (!mountedRef.current) {
|
||||
window.clearTimeout(timer);
|
||||
waitTimerRef.current = null;
|
||||
waitUnsubscribeRef.current?.();
|
||||
waitUnsubscribeRef.current = null;
|
||||
resolve('unmounted');
|
||||
return;
|
||||
}
|
||||
const next = terminalFromSse();
|
||||
if (next) {
|
||||
window.clearTimeout(timer);
|
||||
waitTimerRef.current = null;
|
||||
waitUnsubscribeRef.current?.();
|
||||
waitUnsubscribeRef.current = null;
|
||||
resolve(next);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
if (!mountedRef.current) return false;
|
||||
|
||||
if (terminal === 'completed') {
|
||||
if (await refreshStructure()) return true;
|
||||
} else if (terminal === 'failed') {
|
||||
console.error('Reingest task failed (per SSE)');
|
||||
} else if (terminal === 'unmounted') {
|
||||
return false;
|
||||
} else {
|
||||
console.error('Reingest timed out waiting for SSE terminal');
|
||||
}
|
||||
} else {
|
||||
throw new Error(
|
||||
@@ -374,7 +451,7 @@ const FileTree: React.FC<FileTreeProps> = ({
|
||||
? 'delete directory'
|
||||
: 'delete file(s)';
|
||||
console.error(`Error ${actionText}:`, error);
|
||||
setError(`Failed to ${errorText}`);
|
||||
if (mountedRef.current) setError(`Failed to ${errorText}`);
|
||||
} finally {
|
||||
currentOpRef.current = null;
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ import React, { useState, useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import useDrivePicker from 'react-google-drive-picker';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import ConnectorAuth from './ConnectorAuth';
|
||||
import {
|
||||
getSessionToken,
|
||||
@@ -199,18 +200,11 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
||||
const sessionToken = getSessionToken('google_drive');
|
||||
if (sessionToken) {
|
||||
try {
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
await fetch(`${apiHost}/api/connectors/disconnect`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${token}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: 'google_drive',
|
||||
session_token: sessionToken,
|
||||
}),
|
||||
});
|
||||
await userService.disconnectConnector(
|
||||
'google_drive',
|
||||
sessionToken,
|
||||
token,
|
||||
);
|
||||
} catch (err) {
|
||||
console.error('Error disconnecting from Google Drive:', err);
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ import { createPortal } from 'react-dom';
|
||||
import { LoaderCircle, Mic, Square } from 'lucide-react';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { useDispatch, useSelector, useStore } from 'react-redux';
|
||||
|
||||
import endpoints from '../api/endpoints';
|
||||
import userService from '../api/services/userService';
|
||||
@@ -28,6 +28,7 @@ import {
|
||||
selectSelectedDocs,
|
||||
selectToken,
|
||||
} from '../preferences/preferenceSlice';
|
||||
import type { RootState } from '../store';
|
||||
import Upload from '../upload/Upload';
|
||||
import { getOS, isTouchDevice } from '../utils/browserUtils';
|
||||
import SourcesPopup from './SourcesPopup';
|
||||
@@ -316,6 +317,7 @@ export default function MessageInput({
|
||||
const attachments = useSelector(selectAttachments);
|
||||
|
||||
const dispatch = useDispatch();
|
||||
const store = useStore<RootState>();
|
||||
const mediaStreamRef = useRef<MediaStream | null>(null);
|
||||
const audioContextRef = useRef<AudioContext | null>(null);
|
||||
const audioSourceNodeRef = useRef<MediaStreamAudioSourceNode | null>(null);
|
||||
@@ -410,6 +412,86 @@ export default function MessageInput({
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Recover the race where attachment.* SSE arrives before the upload
|
||||
// XHR's onload sets ``attachmentId``: walk recentEvents and watchdog
|
||||
// the row so it can't stay stuck on 'processing'. Mirrors
|
||||
// Upload.tsx's ``trackTraining``.
|
||||
const trackAttachment = useCallback(
|
||||
(clientId: string, attachmentId: string) => {
|
||||
let handled = false;
|
||||
|
||||
const check = () => {
|
||||
const state = store.getState();
|
||||
const row = state.upload.attachments.find((a) => a.id === clientId);
|
||||
if (!row) return true; // removed by user; stop tracking
|
||||
if (row.status === 'completed' || row.status === 'failed') {
|
||||
handled = true;
|
||||
return true;
|
||||
}
|
||||
for (const event of state.notifications.recentEvents) {
|
||||
if (event.scope?.id !== attachmentId) continue;
|
||||
if (event.type === 'attachment.completed') {
|
||||
const payload = (event.payload || {}) as Record<string, unknown>;
|
||||
const tokenCount = Number(payload.token_count);
|
||||
handled = true;
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: clientId,
|
||||
updates: {
|
||||
status: 'completed',
|
||||
progress: 100,
|
||||
...(Number.isFinite(tokenCount)
|
||||
? { token_count: tokenCount }
|
||||
: {}),
|
||||
},
|
||||
}),
|
||||
);
|
||||
return true;
|
||||
}
|
||||
if (event.type === 'attachment.failed') {
|
||||
handled = true;
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: clientId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
if (check()) return;
|
||||
const MAX_WAIT_MS = 5 * 60_000;
|
||||
let unsubscribe: (() => void) | null = null;
|
||||
const timer = window.setTimeout(() => {
|
||||
unsubscribe?.();
|
||||
if (!handled) {
|
||||
handled = true;
|
||||
console.warn(
|
||||
'trackAttachment: timed out waiting for terminal SSE',
|
||||
clientId,
|
||||
attachmentId,
|
||||
);
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: clientId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
}
|
||||
}, MAX_WAIT_MS);
|
||||
unsubscribe = store.subscribe(() => {
|
||||
if (check()) {
|
||||
window.clearTimeout(timer);
|
||||
unsubscribe?.();
|
||||
}
|
||||
});
|
||||
},
|
||||
[dispatch, store],
|
||||
);
|
||||
|
||||
const uploadFiles = useCallback(
|
||||
(files: File[]) => {
|
||||
if (!files || files.length === 0) return;
|
||||
@@ -510,11 +592,19 @@ export default function MessageInput({
|
||||
id: uiId,
|
||||
updates: {
|
||||
taskId: task.task_id,
|
||||
// Stash the server's attachment id so SSE
|
||||
// ``attachment.*`` events can match this
|
||||
// row by ``scope.id`` and drive the
|
||||
// per-attachment push-fresh poll gate.
|
||||
attachmentId: task.attachment_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
if (task.attachment_id) {
|
||||
trackAttachment(uiId, task.attachment_id);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -545,11 +635,15 @@ export default function MessageInput({
|
||||
id: uiId,
|
||||
updates: {
|
||||
taskId: t.task_id,
|
||||
attachmentId: t.attachment_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
if (t.attachment_id) {
|
||||
trackAttachment(uiId, t.attachment_id);
|
||||
}
|
||||
} else {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
@@ -583,11 +677,15 @@ export default function MessageInput({
|
||||
id: uiId,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
attachmentId: response.attachment_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
if (response.attachment_id) {
|
||||
trackAttachment(uiId, response.attachment_id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
console.warn(
|
||||
@@ -714,11 +812,15 @@ export default function MessageInput({
|
||||
id: uniqueId,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
attachmentId: response.attachment_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
if (response.attachment_id) {
|
||||
trackAttachment(uniqueId, response.attachment_id);
|
||||
}
|
||||
} else {
|
||||
// If backend returned tasks[] for single-file, handle gracefully:
|
||||
if (
|
||||
@@ -730,11 +832,15 @@ export default function MessageInput({
|
||||
id: uniqueId,
|
||||
updates: {
|
||||
taskId: response.tasks[0].task_id,
|
||||
attachmentId: response.tasks[0].attachment_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
if (response.tasks[0].attachment_id) {
|
||||
trackAttachment(uniqueId, response.tasks[0].attachment_id);
|
||||
}
|
||||
} else {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
@@ -781,7 +887,7 @@ export default function MessageInput({
|
||||
xhr.send(formData);
|
||||
});
|
||||
},
|
||||
[dispatch, token],
|
||||
[dispatch, token, trackAttachment],
|
||||
);
|
||||
|
||||
const handleFileAttachment = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
@@ -816,65 +922,6 @@ export default function MessageInput({
|
||||
accept: FILE_UPLOAD_ACCEPT,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
const checkTaskStatus = () => {
|
||||
const processingAttachments = attachments.filter(
|
||||
(att) => att.status === 'processing' && att.taskId,
|
||||
);
|
||||
|
||||
processingAttachments.forEach((attachment) => {
|
||||
userService
|
||||
.getTaskStatus(attachment.taskId!, null)
|
||||
.then((data) => data.json())
|
||||
.then((data) => {
|
||||
if (data.status === 'SUCCESS') {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
updates: {
|
||||
status: 'completed',
|
||||
progress: 100,
|
||||
id: data.result?.attachment_id,
|
||||
token_count: data.result?.token_count,
|
||||
},
|
||||
}),
|
||||
);
|
||||
} else if (data.status === 'FAILURE') {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
} else if (data.status === 'PROGRESS' && data.result?.current) {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
updates: { progress: data.result.current },
|
||||
}),
|
||||
);
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
const interval = setInterval(() => {
|
||||
if (attachments.some((att) => att.status === 'processing')) {
|
||||
checkTaskStatus();
|
||||
}
|
||||
}, 2000);
|
||||
|
||||
return () => clearInterval(interval);
|
||||
}, [attachments, dispatch]);
|
||||
|
||||
const handleInput = useCallback(() => {
|
||||
if (inputRef.current) {
|
||||
if (window.innerWidth < 350) inputRef.current.style.height = 'auto';
|
||||
|
||||
@@ -2,8 +2,7 @@ import { useState, useRef, useEffect } from 'react';
|
||||
import Speaker from '../assets/speaker.svg?react';
|
||||
import Stopspeech from '../assets/stopspeech.svg?react';
|
||||
import LoadingIcon from '../assets/Loading.svg?react'; // Add a loading icon SVG here
|
||||
|
||||
const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
|
||||
import userService from '../api/services/userService';
|
||||
|
||||
let currentlyPlayingAudio: {
|
||||
audio: HTMLAudioElement;
|
||||
@@ -114,12 +113,11 @@ export default function SpeakButton({ text }: { text: string }) {
|
||||
},
|
||||
};
|
||||
|
||||
const response = await fetch(apiHost + '/api/tts', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ text }),
|
||||
signal: abortController.signal,
|
||||
});
|
||||
const response = await userService.textToSpeech(
|
||||
text,
|
||||
null,
|
||||
abortController.signal,
|
||||
);
|
||||
|
||||
const data = await response.json();
|
||||
abortControllerRef.current = null;
|
||||
|
||||
@@ -5,41 +5,54 @@ import { useDispatch, useSelector } from 'react-redux';
|
||||
import CheckCircleFilled from '../assets/check-circle-filled.svg';
|
||||
import ChevronDown from '../assets/chevron-down.svg';
|
||||
import WarnIcon from '../assets/warn.svg';
|
||||
import { dismissUploadTask, selectUploadTasks } from '../upload/uploadSlice';
|
||||
import {
|
||||
dismissUploadTask,
|
||||
selectUploadTasks,
|
||||
type UploadTask,
|
||||
} from '../upload/uploadSlice';
|
||||
|
||||
const PROGRESS_RADIUS = 10;
|
||||
const PROGRESS_CIRCUMFERENCE = 2 * Math.PI * PROGRESS_RADIUS;
|
||||
|
||||
export default function UploadToast() {
|
||||
const [collapsedTasks, setCollapsedTasks] = useState<Record<string, boolean>>(
|
||||
{},
|
||||
);
|
||||
const IN_PROGRESS_STATUSES = new Set<UploadTask['status']>([
|
||||
'preparing',
|
||||
'uploading',
|
||||
'training',
|
||||
]);
|
||||
|
||||
const toggleTaskCollapse = (taskId: string) => {
|
||||
setCollapsedTasks((prev) => ({
|
||||
...prev,
|
||||
[taskId]: !prev[taskId],
|
||||
}));
|
||||
};
|
||||
/**
|
||||
* Single merged upload card — Google-Drive style. Multiple in-flight
|
||||
* uploads share one toast with a list of rows; the header reflects
|
||||
* the *primary* task's status (the newest still-running task, or the
|
||||
* newest task overall if all are terminal). Per-task progress lives
|
||||
* on each row.
|
||||
*
|
||||
* Dismissal: the header X dismisses every visible task at once
|
||||
* (mirrors the GDrive panel close — keeps the surface tidy without
|
||||
* per-row controls). The chevron collapses the row list.
|
||||
*/
|
||||
export default function UploadToast() {
|
||||
const [collapsed, setCollapsed] = useState(false);
|
||||
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useDispatch();
|
||||
const uploadTasks = useSelector(selectUploadTasks);
|
||||
|
||||
const getStatusHeading = (status: string) => {
|
||||
switch (status) {
|
||||
case 'preparing':
|
||||
return t('modals.uploadDoc.progress.wait');
|
||||
case 'uploading':
|
||||
return t('modals.uploadDoc.progress.upload');
|
||||
case 'training':
|
||||
return t('modals.uploadDoc.progress.upload');
|
||||
case 'completed':
|
||||
return t('modals.uploadDoc.progress.completed');
|
||||
case 'failed':
|
||||
return t('modals.uploadDoc.progress.failed');
|
||||
default:
|
||||
return t('modals.uploadDoc.progress.preparing');
|
||||
const visibleTasks = uploadTasks.filter((task) => !task.dismissed);
|
||||
if (visibleTasks.length === 0) return null;
|
||||
|
||||
// Pick the task that drives the header status: prefer a still-
|
||||
// running task (most-recent first since the slice unshifts), and
|
||||
// fall back to whatever's most-recent if everything is terminal.
|
||||
const primaryTask =
|
||||
visibleTasks.find((task) => IN_PROGRESS_STATUSES.has(task.status)) ??
|
||||
visibleTasks[0];
|
||||
|
||||
const headerLabel = getStatusHeading(primaryTask.status, t);
|
||||
|
||||
const dismissAll = () => {
|
||||
for (const task of visibleTasks) {
|
||||
dispatch(dismissUploadTask(task.id));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -47,180 +60,205 @@ export default function UploadToast() {
|
||||
<div
|
||||
className="fixed right-4 bottom-4 z-50 flex max-w-md flex-col gap-2"
|
||||
onMouseDown={(e) => e.stopPropagation()}
|
||||
role="status"
|
||||
aria-live="polite"
|
||||
aria-atomic="false"
|
||||
>
|
||||
{uploadTasks
|
||||
.filter((task) => !task.dismissed)
|
||||
.map((task) => {
|
||||
const shouldShowProgress = [
|
||||
'preparing',
|
||||
'uploading',
|
||||
'training',
|
||||
].includes(task.status);
|
||||
const rawProgress = Math.min(Math.max(task.progress ?? 0, 0), 100);
|
||||
const formattedProgress = Math.round(rawProgress);
|
||||
const progressOffset =
|
||||
PROGRESS_CIRCUMFERENCE * (1 - rawProgress / 100);
|
||||
const isCollapsed = collapsedTasks[task.id] ?? false;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={task.id}
|
||||
className={`border-border bg-card w-[271px] overflow-hidden rounded-2xl border shadow-[0px_24px_48px_0px_#00000029] transition-all duration-300`}
|
||||
<div
|
||||
className={`border-border bg-card w-[271px] overflow-hidden rounded-2xl border shadow-[0px_24px_48px_0px_#00000029] transition-all duration-300`}
|
||||
>
|
||||
<div
|
||||
className={`flex items-center justify-between px-4 py-3 ${
|
||||
primaryTask.status !== 'failed'
|
||||
? 'bg-accent/50 dark:bg-muted'
|
||||
: 'bg-destructive/10 dark:bg-destructive/10'
|
||||
}`}
|
||||
>
|
||||
<h3 className="font-inter dark:text-foreground text-[14px] leading-[16.5px] font-medium text-black">
|
||||
{headerLabel}
|
||||
</h3>
|
||||
<div className="flex items-center gap-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setCollapsed((prev) => !prev)}
|
||||
aria-label={
|
||||
collapsed
|
||||
? t('modals.uploadDoc.progress.expandDetails')
|
||||
: t('modals.uploadDoc.progress.collapseDetails')
|
||||
}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
<div
|
||||
className={`flex items-center justify-between px-4 py-3 ${
|
||||
task.status !== 'failed'
|
||||
? 'bg-accent/50 dark:bg-muted'
|
||||
: 'bg-destructive/10 dark:bg-destructive/10'
|
||||
}`}
|
||||
>
|
||||
<h3 className="font-inter dark:text-foreground text-[14px] leading-[16.5px] font-medium text-black">
|
||||
{getStatusHeading(task.status)}
|
||||
</h3>
|
||||
<div className="flex items-center gap-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => toggleTaskCollapse(task.id)}
|
||||
aria-label={
|
||||
isCollapsed
|
||||
? t('modals.uploadDoc.progress.expandDetails')
|
||||
: t('modals.uploadDoc.progress.collapseDetails')
|
||||
}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
>
|
||||
<img
|
||||
src={ChevronDown}
|
||||
alt=""
|
||||
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${
|
||||
isCollapsed ? 'rotate-180' : ''
|
||||
}`}
|
||||
/>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => dispatch(dismissUploadTask(task.id))}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
aria-label={t('modals.uploadDoc.progress.dismiss')}
|
||||
>
|
||||
<svg
|
||||
width="16"
|
||||
height="16"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-4 w-4"
|
||||
>
|
||||
<path
|
||||
d="M18 6L6 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M6 6L18 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<img
|
||||
src={ChevronDown}
|
||||
alt=""
|
||||
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${
|
||||
collapsed ? 'rotate-180' : ''
|
||||
}`}
|
||||
/>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={dismissAll}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
aria-label={t('modals.uploadDoc.progress.dismiss')}
|
||||
>
|
||||
<svg
|
||||
width="16"
|
||||
height="16"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-4 w-4"
|
||||
>
|
||||
<path
|
||||
d="M18 6L6 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M6 6L18 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
className="grid overflow-hidden transition-[grid-template-rows] duration-300 ease-out"
|
||||
style={{ gridTemplateRows: isCollapsed ? '0fr' : '1fr' }}
|
||||
>
|
||||
<div
|
||||
className={`min-h-0 overflow-hidden transition-opacity duration-300 ${
|
||||
isCollapsed ? 'opacity-0' : 'opacity-100'
|
||||
}`}
|
||||
>
|
||||
<div className="flex items-center justify-between px-5 py-3">
|
||||
<p
|
||||
className="font-inter dark:text-muted-foreground max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black"
|
||||
title={task.fileName}
|
||||
>
|
||||
{task.fileName}
|
||||
</p>
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
{shouldShowProgress && (
|
||||
<svg
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
className="h-6 w-6 shrink-0 text-[#7D54D1]"
|
||||
role="progressbar"
|
||||
aria-valuemin={0}
|
||||
aria-valuemax={100}
|
||||
aria-valuenow={formattedProgress}
|
||||
aria-label={t(
|
||||
'modals.uploadDoc.progress.uploadProgress',
|
||||
{
|
||||
progress: formattedProgress,
|
||||
},
|
||||
)}
|
||||
>
|
||||
<circle
|
||||
className="text-muted dark:text-muted-foreground/30"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
/>
|
||||
<circle
|
||||
className="text-[#7D54D1]"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeDasharray={PROGRESS_CIRCUMFERENCE}
|
||||
strokeDashoffset={progressOffset}
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
transform="rotate(-90 12 12)"
|
||||
/>
|
||||
</svg>
|
||||
)}
|
||||
|
||||
{task.status === 'completed' && (
|
||||
<img
|
||||
src={CheckCircleFilled}
|
||||
alt=""
|
||||
className="h-6 w-6 shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
|
||||
{task.status === 'failed' && (
|
||||
<img
|
||||
src={WarnIcon}
|
||||
alt=""
|
||||
className="h-6 w-6 shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{task.status === 'failed' && task.errorMessage && (
|
||||
<span className="block px-5 pb-3 text-xs text-red-500">
|
||||
{task.errorMessage}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
<div
|
||||
className="grid overflow-hidden transition-[grid-template-rows] duration-300 ease-out"
|
||||
style={{ gridTemplateRows: collapsed ? '0fr' : '1fr' }}
|
||||
>
|
||||
<div
|
||||
className={`min-h-0 overflow-hidden transition-opacity duration-300 ${
|
||||
collapsed ? 'opacity-0' : 'opacity-100'
|
||||
}`}
|
||||
>
|
||||
<ul className="max-h-72 overflow-y-auto">
|
||||
{visibleTasks.map((task) => (
|
||||
<UploadRow key={task.id} task={task} t={t} />
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function UploadRow({
|
||||
task,
|
||||
t,
|
||||
}: {
|
||||
task: UploadTask;
|
||||
t: ReturnType<typeof useTranslation>['t'];
|
||||
}) {
|
||||
const showProgress = IN_PROGRESS_STATUSES.has(task.status);
|
||||
const rawProgress = Math.min(Math.max(task.progress ?? 0, 0), 100);
|
||||
const formattedProgress = Math.round(rawProgress);
|
||||
const progressOffset = PROGRESS_CIRCUMFERENCE * (1 - rawProgress / 100);
|
||||
|
||||
return (
|
||||
<li className="border-border/50 border-b last:border-b-0">
|
||||
<div className="flex items-center justify-between px-5 py-3">
|
||||
<p
|
||||
className="font-inter dark:text-muted-foreground max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black"
|
||||
title={task.fileName}
|
||||
>
|
||||
{task.fileName}
|
||||
</p>
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
{showProgress && (
|
||||
<svg
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
className="h-6 w-6 shrink-0 text-[#7D54D1]"
|
||||
role="progressbar"
|
||||
aria-valuemin={0}
|
||||
aria-valuemax={100}
|
||||
aria-valuenow={formattedProgress}
|
||||
aria-label={t('modals.uploadDoc.progress.uploadProgress', {
|
||||
progress: formattedProgress,
|
||||
})}
|
||||
>
|
||||
<circle
|
||||
className="text-muted dark:text-muted-foreground/30"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
/>
|
||||
<circle
|
||||
className="text-[#7D54D1]"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeDasharray={PROGRESS_CIRCUMFERENCE}
|
||||
strokeDashoffset={progressOffset}
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
transform="rotate(-90 12 12)"
|
||||
/>
|
||||
</svg>
|
||||
)}
|
||||
|
||||
{task.status === 'completed' && (
|
||||
<img
|
||||
src={CheckCircleFilled}
|
||||
alt=""
|
||||
className="h-6 w-6 shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
|
||||
{task.status === 'failed' && (
|
||||
<img
|
||||
src={WarnIcon}
|
||||
alt=""
|
||||
className="h-6 w-6 shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{task.status === 'failed' &&
|
||||
(task.tokenLimitReached || task.errorMessage) && (
|
||||
<span className="block px-5 pb-3 text-xs text-red-500">
|
||||
{task.tokenLimitReached
|
||||
? t('modals.uploadDoc.progress.tokenLimit')
|
||||
: task.errorMessage}
|
||||
</span>
|
||||
)}
|
||||
</li>
|
||||
);
|
||||
}
|
||||
|
||||
function getStatusHeading(
|
||||
status: UploadTask['status'],
|
||||
t: ReturnType<typeof useTranslation>['t'],
|
||||
): string {
|
||||
switch (status) {
|
||||
case 'preparing':
|
||||
return t('modals.uploadDoc.progress.wait');
|
||||
case 'uploading':
|
||||
case 'training':
|
||||
return t('modals.uploadDoc.progress.upload');
|
||||
case 'completed':
|
||||
return t('modals.uploadDoc.progress.completed');
|
||||
case 'failed':
|
||||
return t('modals.uploadDoc.progress.failed');
|
||||
default:
|
||||
return t('modals.uploadDoc.progress.preparing');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { useNavigate, useParams } from 'react-router-dom';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import SharedAgentCard from '../agents/SharedAgentCard';
|
||||
import { Agent } from '../agents/types';
|
||||
import ArtifactSidebar from '../components/ArtifactSidebar';
|
||||
import MessageInput from '../components/MessageInput';
|
||||
import { useMediaQuery } from '../hooks';
|
||||
@@ -10,6 +13,7 @@ import {
|
||||
selectConversationId,
|
||||
selectSelectedAgent,
|
||||
selectToken,
|
||||
setSelectedAgent,
|
||||
} from '../preferences/preferenceSlice';
|
||||
import { AppDispatch } from '../store';
|
||||
import { handleSendFeedback } from './conversationHandlers';
|
||||
@@ -19,7 +23,9 @@ import { ToolCallsType } from './types';
|
||||
import {
|
||||
addQuery,
|
||||
fetchAnswer,
|
||||
loadConversation,
|
||||
resendQuery,
|
||||
resetConversation,
|
||||
selectQueries,
|
||||
selectStatus,
|
||||
submitToolActions,
|
||||
@@ -31,6 +37,16 @@ export default function Conversation() {
|
||||
const { t } = useTranslation();
|
||||
const { isMobile } = useMediaQuery();
|
||||
const dispatch = useDispatch<AppDispatch>();
|
||||
const navigate = useNavigate();
|
||||
const params = useParams<{
|
||||
conversationId?: string;
|
||||
agentId?: string;
|
||||
}>();
|
||||
const urlConversationId = params.conversationId;
|
||||
const urlAgentId = params.agentId;
|
||||
// ``new`` is treated as empty-chat intent, not a real id to fetch.
|
||||
const isNewChatRoute =
|
||||
urlConversationId === undefined || urlConversationId === 'new';
|
||||
|
||||
const token = useSelector(selectToken);
|
||||
const queries = useSelector(selectQueries);
|
||||
@@ -42,6 +58,65 @@ export default function Conversation() {
|
||||
const [lastQueryReturnedErr, setLastQueryReturnedErr] =
|
||||
useState<boolean>(false);
|
||||
|
||||
// URL → state. Thunk short-circuits when Redux already matches.
|
||||
useEffect(() => {
|
||||
if (isNewChatRoute) {
|
||||
// Skip when nothing to reset; avoids wiping the in-flight stream
|
||||
// during the null → assigned-id replace below.
|
||||
if (conversationId !== null) {
|
||||
dispatch(resetConversation());
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (urlConversationId && urlConversationId !== conversationId) {
|
||||
dispatch(loadConversation({ id: urlConversationId }))
|
||||
.unwrap()
|
||||
.then((result) => {
|
||||
if (result.stale) return;
|
||||
if (result.data === null) {
|
||||
navigate('/c/new', { replace: true });
|
||||
}
|
||||
})
|
||||
.catch(() => navigate('/c/new', { replace: true }));
|
||||
}
|
||||
}, [urlConversationId, isNewChatRoute]);
|
||||
|
||||
// Agent context follows the URL. ``cancelled`` covers two races:
|
||||
// the user switches agents before the fetch resolves, or leaves the
|
||||
// agent route entirely; either way the late dispatch must be dropped.
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
if (urlAgentId) {
|
||||
if (selectedAgent?.id !== urlAgentId) {
|
||||
userService
|
||||
.getAgent(urlAgentId, token)
|
||||
.then((response) => (response.ok ? response.json() : null))
|
||||
.then((agent: Agent | null) => {
|
||||
if (cancelled) return;
|
||||
if (agent) dispatch(setSelectedAgent(agent));
|
||||
})
|
||||
.catch((err) => {
|
||||
if (!cancelled) console.error('Failed to load agent:', err);
|
||||
});
|
||||
}
|
||||
} else if (selectedAgent !== null) {
|
||||
dispatch(setSelectedAgent(null));
|
||||
}
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [urlAgentId, token]);
|
||||
|
||||
// State → URL. ``replace`` so Back doesn't return to /c/new and
|
||||
// reset the just-streamed chat.
|
||||
useEffect(() => {
|
||||
if (!isNewChatRoute || !conversationId) return;
|
||||
const target = urlAgentId
|
||||
? `/agents/${urlAgentId}/c/${conversationId}`
|
||||
: `/c/${conversationId}`;
|
||||
navigate(target, { replace: true });
|
||||
}, [conversationId, isNewChatRoute, urlAgentId]);
|
||||
|
||||
const handleToolAction = useCallback(
|
||||
(callId: string, decision: 'approved' | 'denied', comment?: string) => {
|
||||
dispatch(
|
||||
@@ -101,7 +176,13 @@ export default function Conversation() {
|
||||
.map((a) => ({ id: a.id as string, fileName: a.fileName }));
|
||||
|
||||
if (index !== undefined) {
|
||||
dispatch(resendQuery({ index, prompt: trimmedQuestion }));
|
||||
dispatch(
|
||||
resendQuery({
|
||||
index,
|
||||
prompt: trimmedQuestion,
|
||||
keepIdempotencyKey: isRetry,
|
||||
}),
|
||||
);
|
||||
handleFetchAnswer({ question: trimmedQuestion, index });
|
||||
} else {
|
||||
if (!isRetry)
|
||||
@@ -151,17 +232,22 @@ export default function Conversation() {
|
||||
} else if (question && status !== 'loading') {
|
||||
if (lastQueryReturnedErr && queries.length > 0) {
|
||||
const retryIndex = queries.length - 1;
|
||||
dispatch(
|
||||
updateQuery({
|
||||
index: retryIndex,
|
||||
query: {
|
||||
prompt: question,
|
||||
},
|
||||
}),
|
||||
);
|
||||
// Different prompt = new logical action, fresh idempotency key.
|
||||
const prevPrompt = queries[retryIndex].prompt;
|
||||
const isSamePrompt = prevPrompt === question;
|
||||
if (!isSamePrompt) {
|
||||
dispatch(
|
||||
updateQuery({
|
||||
index: retryIndex,
|
||||
query: {
|
||||
prompt: question,
|
||||
},
|
||||
}),
|
||||
);
|
||||
}
|
||||
handleQuestion({
|
||||
question,
|
||||
isRetry: true,
|
||||
isRetry: isSamePrompt,
|
||||
index: retryIndex,
|
||||
});
|
||||
} else {
|
||||
@@ -236,7 +322,7 @@ export default function Conversation() {
|
||||
isSplitArtifactOpen ? 'w-[60%] px-6' : 'w-full'
|
||||
}`}
|
||||
>
|
||||
<div className="relative min-h-0 flex-1 ">
|
||||
<div className="relative min-h-0 flex-1">
|
||||
<ConversationMessages
|
||||
handleQuestion={handleQuestion}
|
||||
handleQuestionSubmission={handleQuestionSubmission}
|
||||
@@ -250,7 +336,19 @@ export default function Conversation() {
|
||||
headerContent={
|
||||
selectedAgent ? (
|
||||
<div className="flex w-full items-center justify-center py-4">
|
||||
<SharedAgentCard agent={selectedAgent} />
|
||||
<SharedAgentCard
|
||||
agent={selectedAgent}
|
||||
onEdit={
|
||||
selectedAgent.id
|
||||
? () =>
|
||||
navigate(
|
||||
selectedAgent.agent_type === 'workflow'
|
||||
? `/agents/workflow/edit/${selectedAgent.id}`
|
||||
: `/agents/edit/${selectedAgent.id}`,
|
||||
)
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
) : undefined
|
||||
}
|
||||
|
||||
@@ -132,6 +132,8 @@ const ConversationBubble = forwardRef<
|
||||
}, [message]);
|
||||
|
||||
const handleEditClick = () => {
|
||||
if (!editInputBox.trim() || editInputBox.trim() === (message ?? '').trim())
|
||||
return;
|
||||
setIsEditClicked(false);
|
||||
handleUpdatedQuestionSubmission?.(editInputBox, true, questionNumber);
|
||||
};
|
||||
@@ -242,8 +244,12 @@ const ConversationBubble = forwardRef<
|
||||
{t('conversation.edit.cancel')}
|
||||
</button>
|
||||
<button
|
||||
className="bg-primary hover:bg-primary/90 dark:hover:bg-primary/90 rounded-full px-4 py-2 text-sm font-medium text-white transition-colors"
|
||||
className="bg-primary not-disabled:hover:bg-primary/90 not-disabled:dark:hover:bg-primary/90 disabled:bg-primary/30 rounded-full px-4 py-2 text-sm font-medium text-white transition-colors disabled:cursor-not-allowed"
|
||||
onClick={handleEditClick}
|
||||
disabled={
|
||||
!editInputBox.trim() ||
|
||||
editInputBox.trim() === (message ?? '').trim()
|
||||
}
|
||||
>
|
||||
{t('conversation.edit.update')}
|
||||
</button>
|
||||
|
||||
@@ -248,32 +248,8 @@ export default function ConversationMessages({
|
||||
? LAST_BUBBLE_MARGIN
|
||||
: DEFAULT_BUBBLE_MARGIN;
|
||||
|
||||
if (query.thought || query.response || query.tool_calls || query.research) {
|
||||
const isCurrentlyStreaming =
|
||||
status === 'loading' && index === queries.length - 1;
|
||||
return (
|
||||
<ConversationBubble
|
||||
className={bubbleMargin}
|
||||
key={`${index}-ANSWER`}
|
||||
message={query.response}
|
||||
type={'ANSWER'}
|
||||
thought={query.thought}
|
||||
sources={query.sources}
|
||||
toolCalls={query.tool_calls}
|
||||
research={query.research}
|
||||
onOpenArtifact={onOpenArtifact}
|
||||
onToolAction={onToolAction}
|
||||
feedback={query.feedback}
|
||||
isStreaming={isCurrentlyStreaming}
|
||||
handleFeedback={
|
||||
handleFeedback
|
||||
? (feedback) => handleFeedback(query, feedback, index)
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// Error first; reconciler-failed rows may carry partial thought/
|
||||
// tool_calls and would otherwise fall into the answer branch.
|
||||
if (query.error) {
|
||||
const retryButton = (
|
||||
<button
|
||||
@@ -303,6 +279,38 @@ export default function ConversationMessages({
|
||||
);
|
||||
}
|
||||
|
||||
// tool_calls.length, not tool_calls — empty arrays are JS-truthy.
|
||||
const hasContent =
|
||||
query.thought ||
|
||||
query.response ||
|
||||
(query.tool_calls && query.tool_calls.length > 0) ||
|
||||
query.research;
|
||||
if (hasContent) {
|
||||
const isCurrentlyStreaming =
|
||||
status === 'loading' && index === queries.length - 1;
|
||||
return (
|
||||
<ConversationBubble
|
||||
className={bubbleMargin}
|
||||
key={`${index}-ANSWER`}
|
||||
message={query.response}
|
||||
type={'ANSWER'}
|
||||
thought={query.thought}
|
||||
sources={query.sources}
|
||||
toolCalls={query.tool_calls}
|
||||
research={query.research}
|
||||
onOpenArtifact={onOpenArtifact}
|
||||
onToolAction={onToolAction}
|
||||
feedback={query.feedback}
|
||||
isStreaming={isCurrentlyStreaming}
|
||||
handleFeedback={
|
||||
handleFeedback
|
||||
? (feedback) => handleFeedback(query, feedback, index)
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (status === 'loading' && isLastMessage) {
|
||||
return (
|
||||
<div
|
||||
|
||||
@@ -64,7 +64,10 @@ export default function ConversationTile({
|
||||
}
|
||||
|
||||
function handleSaveConversation(changedConversation: ConversationProps) {
|
||||
if (changedConversation.name.trim().length) {
|
||||
if (
|
||||
changedConversation.name.trim().length &&
|
||||
changedConversation.name.trim() !== conversation.name.trim()
|
||||
) {
|
||||
onSave(changedConversation);
|
||||
setIsEdit(false);
|
||||
} else {
|
||||
|
||||
@@ -1,8 +1,155 @@
|
||||
import { baseURL } from '../api/client';
|
||||
import conversationService from '../api/services/conversationService';
|
||||
import { Doc } from '../models/misc';
|
||||
import { Answer, FEEDBACK, RetrievalPayload } from './conversationModels';
|
||||
import { ToolCallsType } from './types';
|
||||
|
||||
/**
|
||||
* Mirrors the backend's ``_SEQUENCE_NO_RE`` (application/api/answer/
|
||||
* routes/messages.py) — only non-negative decimal integers are valid
|
||||
* cursors. Rejects empty strings (Number("") === 0), hex literals,
|
||||
* exponential notation, and anything else that ``Number(...)`` would
|
||||
* happily coerce.
|
||||
*/
|
||||
const _SEQUENCE_NO_RE = /^\d+$/;
|
||||
|
||||
/**
|
||||
* Drain an SSE response body, forwarding each ``data:`` line to
|
||||
* ``onData`` and tracking the most recent ``id:`` header. Returns
|
||||
* when the body ends, the signal aborts, or ``shouldStop()`` returns
|
||||
* true (e.g. a terminal ``end``/``error`` event was dispatched —
|
||||
* the reconnect endpoint is a live tail that doesn't close on its
|
||||
* own past terminal replay).
|
||||
*/
|
||||
/**
|
||||
* Convert a non-SSE pre-stream HTTP failure (e.g. ``check_usage``'s
|
||||
* 429 JSON response) into a synthetic typed ``error`` frame so the
|
||||
* caller's slice sees the actual server message instead of the
|
||||
* generic "Connection lost" synthesised when the drainer finishes
|
||||
* with zero events. Returns true if a frame was dispatched and the
|
||||
* caller should skip ``_drainSseBody`` entirely.
|
||||
*
|
||||
* SSE-shaped error bodies (``mimetype="text/event-stream"``) are
|
||||
* left alone — the drainer parses the typed ``error`` frame they
|
||||
* carry through the normal path.
|
||||
*/
|
||||
async function _handlePreStreamHttpError(
|
||||
response: Response,
|
||||
dispatch: (data: string) => void,
|
||||
): Promise<boolean> {
|
||||
if (response.ok) return false;
|
||||
const contentType = (
|
||||
response.headers.get('content-type') ?? ''
|
||||
).toLowerCase();
|
||||
if (contentType.includes('text/event-stream')) return false;
|
||||
let message: string | null = null;
|
||||
try {
|
||||
const text = await response.text();
|
||||
if (text) {
|
||||
try {
|
||||
const parsed = JSON.parse(text);
|
||||
if (parsed && typeof parsed === 'object') {
|
||||
message =
|
||||
(typeof parsed.message === 'string' && parsed.message) ||
|
||||
(typeof parsed.error === 'string' && parsed.error) ||
|
||||
(typeof parsed.detail === 'string' && parsed.detail) ||
|
||||
null;
|
||||
}
|
||||
} catch {
|
||||
message = text.slice(0, 500);
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Body already consumed or unreadable — fall through to the
|
||||
// status-line fallback below.
|
||||
}
|
||||
if (!message) {
|
||||
message = `HTTP ${response.status} ${response.statusText}`.trim();
|
||||
}
|
||||
dispatch(JSON.stringify({ type: 'error', error: message }));
|
||||
return true;
|
||||
}
|
||||
|
||||
async function _drainSseBody(
|
||||
body: ReadableStream<Uint8Array>,
|
||||
signal: AbortSignal,
|
||||
onData: (data: string) => void,
|
||||
onId: (id: number) => void,
|
||||
shouldStop?: () => boolean,
|
||||
): Promise<void> {
|
||||
const reader = body.getReader();
|
||||
const decoder = new TextDecoder('utf-8');
|
||||
let buffer = '';
|
||||
let stoppedEarly = false;
|
||||
try {
|
||||
while (true) {
|
||||
if (signal.aborted) break;
|
||||
if (shouldStop?.()) {
|
||||
stoppedEarly = true;
|
||||
break;
|
||||
}
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
// Normalise mixed line terminators so a stray CR can't smuggle
|
||||
// a record boundary inside a JSON payload.
|
||||
buffer = buffer.replace(/\r\n/g, '\n').replace(/\r/g, '\n');
|
||||
let boundary = buffer.indexOf('\n\n');
|
||||
while (boundary !== -1) {
|
||||
const record = buffer.slice(0, boundary);
|
||||
buffer = buffer.slice(boundary + 2);
|
||||
boundary = buffer.indexOf('\n\n');
|
||||
if (record.length === 0) continue;
|
||||
const dataParts: string[] = [];
|
||||
let sawDataField = false;
|
||||
for (const line of record.split('\n')) {
|
||||
if (line.length === 0) continue;
|
||||
if (line.startsWith(':')) continue; // SSE comment / keepalive
|
||||
const colonIdx = line.indexOf(':');
|
||||
const field = colonIdx === -1 ? line : line.slice(0, colonIdx);
|
||||
let value = colonIdx === -1 ? '' : line.slice(colonIdx + 1);
|
||||
if (value.startsWith(' ')) value = value.slice(1);
|
||||
if (field === 'id') {
|
||||
// Strict regex match — empty value, hex, ``-1`` (the
|
||||
// backend's terminal snapshot-failure synthetic), and
|
||||
// exponent forms are all rejected so they can't silently
|
||||
// rewrite the reconnect cursor.
|
||||
if (_SEQUENCE_NO_RE.test(value)) onId(parseInt(value, 10));
|
||||
} else if (field === 'data') {
|
||||
sawDataField = true;
|
||||
dataParts.push(value);
|
||||
}
|
||||
}
|
||||
if (!sawDataField) continue;
|
||||
const data = dataParts.join('\n').trim();
|
||||
if (data.length === 0) continue;
|
||||
onData(data);
|
||||
if (shouldStop?.()) {
|
||||
stoppedEarly = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (stoppedEarly) break;
|
||||
}
|
||||
} finally {
|
||||
if (stoppedEarly) {
|
||||
// Ask the runtime to tear the underlying response body down so
|
||||
// the server-side WSGI thread isn't pinned waiting on
|
||||
// keepalives. ``releaseLock`` alone leaves the body half-open.
|
||||
try {
|
||||
await reader.cancel();
|
||||
} catch {
|
||||
// Already errored / closed.
|
||||
}
|
||||
}
|
||||
try {
|
||||
reader.releaseLock();
|
||||
} catch {
|
||||
// Already released.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function handleFetchAnswer(
|
||||
question: string,
|
||||
signal: AbortSignal,
|
||||
@@ -15,6 +162,7 @@ export function handleFetchAnswer(
|
||||
attachments?: string[],
|
||||
save_conversation = true,
|
||||
modelId?: string,
|
||||
idempotencyKey?: string,
|
||||
): Promise<
|
||||
| {
|
||||
result: any;
|
||||
@@ -66,8 +214,10 @@ export function handleFetchAnswer(
|
||||
payload.retriever = selectedDocs[0].retriever as string;
|
||||
}
|
||||
}
|
||||
const headers: Record<string, string> = {};
|
||||
if (idempotencyKey) headers['Idempotency-Key'] = idempotencyKey;
|
||||
return conversationService
|
||||
.answer(payload, token, signal)
|
||||
.answer(payload, token, signal, headers)
|
||||
.then((response) => {
|
||||
if (response.ok) {
|
||||
return response.json();
|
||||
@@ -104,6 +254,7 @@ export function handleFetchAnswerSteaming(
|
||||
attachments?: string[],
|
||||
save_conversation = true,
|
||||
modelId?: string,
|
||||
idempotencyKey?: string,
|
||||
): Promise<Answer> {
|
||||
const payload: RetrievalPayload = {
|
||||
question: question,
|
||||
@@ -137,54 +288,155 @@ export function handleFetchAnswerSteaming(
|
||||
}
|
||||
}
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
if (idempotencyKey) headers['Idempotency-Key'] = idempotencyKey;
|
||||
|
||||
// Per-stream state used for reconnect-after-disconnect.
|
||||
let messageId: string | null = null;
|
||||
let lastEventId: number | null = null;
|
||||
// The single JSON.parse below feeds both the message_id capture and
|
||||
// the termination flag — cheaper and stricter than substring
|
||||
// matching the wire bytes.
|
||||
let endReceived = false;
|
||||
|
||||
const dispatch = (data: string) => {
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
if (parsed && typeof parsed === 'object') {
|
||||
if (parsed.type === 'message_id' && parsed.message_id) {
|
||||
messageId = parsed.message_id;
|
||||
} else if (parsed.type === 'end' || parsed.type === 'error') {
|
||||
endReceived = true;
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Not JSON — pass through anyway; the caller handles raw lines.
|
||||
}
|
||||
onEvent(new MessageEvent('message', { data }));
|
||||
};
|
||||
|
||||
const runInitialPost = async (): Promise<void> => {
|
||||
const response = await conversationService.answerStream(
|
||||
payload,
|
||||
token,
|
||||
signal,
|
||||
headers,
|
||||
);
|
||||
// Pre-stream HTTP failures with non-SSE bodies (e.g. ``check_usage``
|
||||
// returning a JSON 429) drain as zero events and would otherwise
|
||||
// be masked by the generic "Connection lost" synthetic. Convert
|
||||
// them into a typed ``error`` frame so the real message surfaces.
|
||||
if (await _handlePreStreamHttpError(response, dispatch)) return;
|
||||
if (!response.body) throw new Error('No response body');
|
||||
await _drainSseBody(response.body, signal, dispatch, (id) => {
|
||||
lastEventId = id;
|
||||
});
|
||||
};
|
||||
|
||||
// Reconnect's stop predicate: as soon as ``dispatch`` flips
|
||||
// ``endReceived`` (typed ``end`` or ``error`` event seen — both
|
||||
// are terminal per the backend's contract). Without this the
|
||||
// live-tail endpoint would emit keepalives indefinitely and the
|
||||
// await would never return.
|
||||
const reconnectShouldStop = () => endReceived;
|
||||
|
||||
const runReconnect = async (): Promise<void> => {
|
||||
if (!messageId) {
|
||||
throw new Error('reconnect: no message_id captured');
|
||||
}
|
||||
const url = new URL(`${baseURL}/api/messages/${messageId}/events`);
|
||||
if (lastEventId !== null) {
|
||||
url.searchParams.set('last_event_id', String(lastEventId));
|
||||
}
|
||||
const reconnectHeaders: Record<string, string> = {
|
||||
Accept: 'text/event-stream',
|
||||
};
|
||||
if (token) reconnectHeaders.Authorization = `Bearer ${token}`;
|
||||
// NB: there is no slice consumer for a synthetic ``reconnecting``
|
||||
// event yet — surface only the underlying network reality. The
|
||||
// user-visible ``Reconnecting…`` affordance is a follow-up that
|
||||
// needs ``conversationSlice`` to gain a status case.
|
||||
const response = await fetch(url.toString(), {
|
||||
method: 'GET',
|
||||
headers: reconnectHeaders,
|
||||
signal,
|
||||
cache: 'no-store',
|
||||
});
|
||||
if (!response.ok || !response.body) {
|
||||
throw new Error(
|
||||
`reconnect: HTTP ${response.status} ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
await _drainSseBody(
|
||||
response.body,
|
||||
signal,
|
||||
dispatch,
|
||||
(id) => {
|
||||
lastEventId = id;
|
||||
},
|
||||
reconnectShouldStop,
|
||||
);
|
||||
};
|
||||
|
||||
return new Promise<Answer>((resolve, reject) => {
|
||||
conversationService
|
||||
.answerStream(payload, token, signal)
|
||||
.then((response) => {
|
||||
if (!response.body) throw Error('No response body');
|
||||
|
||||
let buffer = '';
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder('utf-8');
|
||||
let counterrr = 0;
|
||||
const processStream = ({
|
||||
done,
|
||||
value,
|
||||
}: ReadableStreamReadResult<Uint8Array>) => {
|
||||
if (done) return;
|
||||
|
||||
counterrr += 1;
|
||||
|
||||
const chunk = decoder.decode(value);
|
||||
buffer += chunk;
|
||||
|
||||
const events = buffer.split('\n\n');
|
||||
buffer = events.pop() ?? '';
|
||||
|
||||
for (const event of events) {
|
||||
if (event.trim().startsWith('data:')) {
|
||||
const dataLine: string = event
|
||||
.split('\n')
|
||||
.map((line: string) => line.replace(/^data:\s?/, ''))
|
||||
.join('');
|
||||
|
||||
const messageEvent = new MessageEvent('message', {
|
||||
data: dataLine.trim(),
|
||||
});
|
||||
|
||||
onEvent(messageEvent);
|
||||
}
|
||||
(async () => {
|
||||
try {
|
||||
try {
|
||||
await runInitialPost();
|
||||
} catch (initialErr) {
|
||||
// Mid-stream network failures (WiFi blip, worker recycle,
|
||||
// body reader rejecting) surface as a thrown error — not a
|
||||
// graceful EOF. If the stream had already started (we have a
|
||||
// ``messageId``), fall through to the reconnect path so the
|
||||
// journal-backed replay can finish what the live socket
|
||||
// couldn't. Pre-stream failures (auth, DNS, server 4xx/5xx
|
||||
// before any yield) lack a messageId and bubble up.
|
||||
if (signal.aborted || !messageId) throw initialErr;
|
||||
console.warn(
|
||||
'Initial stream failed mid-flight, attempting reconnect:',
|
||||
initialErr,
|
||||
);
|
||||
}
|
||||
// The backend ends the stream cleanly with a typed ``end``
|
||||
// event. Anything else (network drop, gunicorn worker recycle,
|
||||
// load-balancer timeout) is a "premature close" — try one
|
||||
// reconnect via the GET /api/messages/<id>/events endpoint.
|
||||
if (!endReceived && !signal.aborted && messageId) {
|
||||
try {
|
||||
await runReconnect();
|
||||
} catch (reconnectErr) {
|
||||
console.warn('Stream reconnect failed:', reconnectErr);
|
||||
}
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
};
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
})
|
||||
.catch((error) => {
|
||||
}
|
||||
// If we never observed a terminal frame (reconnect 4xx/5xx,
|
||||
// network drop during reconnect, or live tail still silent),
|
||||
// synthesize one through the same ``dispatch`` path the wire
|
||||
// events use. Without this the caller's slice never transitions
|
||||
// out of ``streaming`` and the UI stays in a loading spinner
|
||||
// forever — the conversationSlice handles ``data.type === 'error'``
|
||||
// by setting status=failed.
|
||||
if (!endReceived && !signal.aborted) {
|
||||
dispatch(
|
||||
JSON.stringify({
|
||||
type: 'error',
|
||||
error:
|
||||
'Connection lost. The response could not be resumed; please try again.',
|
||||
}),
|
||||
);
|
||||
}
|
||||
// The handler historically never explicitly resolved with a
|
||||
// value — callers consume the streamed events via ``onEvent``
|
||||
// and read final state from Redux. Preserve that contract.
|
||||
resolve(undefined as unknown as Answer);
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
resolve(undefined as unknown as Answer);
|
||||
return;
|
||||
}
|
||||
console.error('Connection failed:', error);
|
||||
reject(error);
|
||||
});
|
||||
}
|
||||
})();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -199,58 +451,158 @@ export function handleSubmitToolActions(
|
||||
token: string | null,
|
||||
signal: AbortSignal,
|
||||
onEvent: (event: MessageEvent) => void,
|
||||
idempotencyKey?: string,
|
||||
): Promise<Answer> {
|
||||
const payload = {
|
||||
conversation_id: conversationId,
|
||||
tool_actions: toolActions,
|
||||
};
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
if (idempotencyKey) headers['Idempotency-Key'] = idempotencyKey;
|
||||
|
||||
// Tool-action submissions resume against the original
|
||||
// ``reserved_message_id``, so the backend's continuation path emits
|
||||
// ``id:`` prefixed records that the legacy parser would silently
|
||||
// drop. Use the shared SSE drainer — and the same reconnect-on-
|
||||
// premature-close pattern as ``handleFetchAnswerSteaming`` so a
|
||||
// dropped tool-action stream can pick up after the disconnect.
|
||||
let messageId: string | null = null;
|
||||
let lastEventId: number | null = null;
|
||||
|
||||
// Track whether the typed ``end`` event was observed. The single
|
||||
// JSON.parse below feeds both the message_id capture and the
|
||||
// termination flag — cheaper and stricter than substring matching
|
||||
// the wire bytes.
|
||||
let endReceived = false;
|
||||
|
||||
const dispatch = (data: string) => {
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
if (parsed && typeof parsed === 'object') {
|
||||
if (parsed.type === 'message_id' && parsed.message_id) {
|
||||
messageId = parsed.message_id;
|
||||
} else if (parsed.type === 'end' || parsed.type === 'error') {
|
||||
// Match the backend's terminal set in
|
||||
// ``application/streaming/event_replay.py``: the agent's
|
||||
// catch-all path emits ``error`` *without* a trailing
|
||||
// ``end``, so treating only ``end`` as terminal would
|
||||
// trigger a reconnect against an already-finished stream
|
||||
// and hang on keepalives.
|
||||
endReceived = true;
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Not JSON — pass through anyway; the caller handles raw lines.
|
||||
}
|
||||
onEvent(new MessageEvent('message', { data }));
|
||||
};
|
||||
|
||||
const runInitial = async (): Promise<void> => {
|
||||
const response = await conversationService.answerStream(
|
||||
payload,
|
||||
token,
|
||||
signal,
|
||||
headers,
|
||||
);
|
||||
// See ``handleFetchAnswerSteaming`` for the rationale: non-SSE
|
||||
// HTTP failures (e.g. ``check_usage`` 429 JSON) need to be lifted
|
||||
// into a typed ``error`` frame before they reach the drainer.
|
||||
if (await _handlePreStreamHttpError(response, dispatch)) return;
|
||||
if (!response.body) throw new Error('No response body');
|
||||
await _drainSseBody(response.body, signal, dispatch, (id) => {
|
||||
lastEventId = id;
|
||||
});
|
||||
};
|
||||
|
||||
// Reconnect's stop predicate: as soon as ``dispatch`` flips
|
||||
// ``endReceived`` (typed ``end`` or ``error`` event seen — both
|
||||
// are terminal per the backend's contract). Without this the
|
||||
// live-tail endpoint would emit keepalives indefinitely and the
|
||||
// await would never return.
|
||||
const reconnectShouldStop = () => endReceived;
|
||||
|
||||
const runReconnect = async (): Promise<void> => {
|
||||
if (!messageId) {
|
||||
throw new Error('reconnect: no message_id captured');
|
||||
}
|
||||
const url = new URL(`${baseURL}/api/messages/${messageId}/events`);
|
||||
if (lastEventId !== null) {
|
||||
url.searchParams.set('last_event_id', String(lastEventId));
|
||||
}
|
||||
const reconnectHeaders: Record<string, string> = {
|
||||
Accept: 'text/event-stream',
|
||||
};
|
||||
if (token) reconnectHeaders.Authorization = `Bearer ${token}`;
|
||||
const response = await fetch(url.toString(), {
|
||||
method: 'GET',
|
||||
headers: reconnectHeaders,
|
||||
signal,
|
||||
cache: 'no-store',
|
||||
});
|
||||
if (!response.ok || !response.body) {
|
||||
throw new Error(
|
||||
`reconnect: HTTP ${response.status} ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
await _drainSseBody(
|
||||
response.body,
|
||||
signal,
|
||||
dispatch,
|
||||
(id) => {
|
||||
lastEventId = id;
|
||||
},
|
||||
reconnectShouldStop,
|
||||
);
|
||||
};
|
||||
|
||||
return new Promise<Answer>((resolve, reject) => {
|
||||
conversationService
|
||||
.answerStream(payload, token, signal)
|
||||
.then((response) => {
|
||||
if (!response.body) throw Error('No response body');
|
||||
|
||||
let buffer = '';
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder('utf-8');
|
||||
|
||||
const processStream = ({
|
||||
done,
|
||||
value,
|
||||
}: ReadableStreamReadResult<Uint8Array>) => {
|
||||
if (done) return;
|
||||
|
||||
const chunk = decoder.decode(value);
|
||||
buffer += chunk;
|
||||
|
||||
const events = buffer.split('\n\n');
|
||||
buffer = events.pop() ?? '';
|
||||
|
||||
for (const event of events) {
|
||||
if (event.trim().startsWith('data:')) {
|
||||
const dataLine: string = event
|
||||
.split('\n')
|
||||
.map((line: string) => line.replace(/^data:\s?/, ''))
|
||||
.join('');
|
||||
|
||||
const messageEvent = new MessageEvent('message', {
|
||||
data: dataLine.trim(),
|
||||
});
|
||||
|
||||
onEvent(messageEvent);
|
||||
}
|
||||
(async () => {
|
||||
try {
|
||||
try {
|
||||
await runInitial();
|
||||
} catch (initialErr) {
|
||||
// Same premature-close handling as
|
||||
// ``handleFetchAnswerSteaming``: a thrown reader error after
|
||||
// the message_id frame still warrants one reconnect attempt
|
||||
// against the journal. Pre-stream failures lack a messageId
|
||||
// and bubble up.
|
||||
if (signal.aborted || !messageId) throw initialErr;
|
||||
console.warn(
|
||||
'Tool-actions stream failed mid-flight, attempting reconnect:',
|
||||
initialErr,
|
||||
);
|
||||
}
|
||||
if (!endReceived && !signal.aborted && messageId) {
|
||||
try {
|
||||
await runReconnect();
|
||||
} catch (reconnectErr) {
|
||||
console.warn('Tool-actions reconnect failed:', reconnectErr);
|
||||
}
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
};
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
})
|
||||
.catch((error) => {
|
||||
}
|
||||
// Synthesize a terminal error if reconnect couldn't deliver one
|
||||
// (4xx/5xx, network drop, silent live tail). Same reasoning as
|
||||
// ``handleFetchAnswerSteaming``: the caller's slice only exits
|
||||
// the streaming state on a terminal frame.
|
||||
if (!endReceived && !signal.aborted) {
|
||||
dispatch(
|
||||
JSON.stringify({
|
||||
type: 'error',
|
||||
error:
|
||||
'Connection lost. The tool response could not be resumed; please try again.',
|
||||
}),
|
||||
);
|
||||
}
|
||||
resolve(undefined as unknown as Answer);
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
resolve(undefined as unknown as Answer);
|
||||
return;
|
||||
}
|
||||
console.error('Tool actions submission failed:', error);
|
||||
reject(error);
|
||||
});
|
||||
}
|
||||
})();
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ import { ToolCallsType } from './types';
|
||||
export type MESSAGE_TYPE = 'QUESTION' | 'ANSWER' | 'ERROR';
|
||||
export type Status = 'idle' | 'loading' | 'failed' | 'awaiting_tool_actions';
|
||||
export type FEEDBACK = 'LIKE' | 'DISLIKE' | null;
|
||||
// Mirrors ``conversation_messages.status``.
|
||||
export type MessageStatus = 'pending' | 'streaming' | 'complete' | 'failed';
|
||||
|
||||
export interface Message {
|
||||
text: string;
|
||||
@@ -65,6 +67,13 @@ export interface Query {
|
||||
structured?: boolean;
|
||||
schema?: object;
|
||||
research?: ResearchState;
|
||||
// WAL placeholder id; lets the client tail an in-flight stream.
|
||||
messageId?: string;
|
||||
messageStatus?: MessageStatus;
|
||||
requestId?: string;
|
||||
lastHeartbeatAt?: string;
|
||||
// Persisted so Retry can re-send the same key for server-side dedup.
|
||||
idempotencyKey?: string;
|
||||
}
|
||||
|
||||
export interface RetrievalPayload {
|
||||
|
||||
153
frontend/src/conversation/conversationSlice.test.ts
Normal file
153
frontend/src/conversation/conversationSlice.test.ts
Normal file
@@ -0,0 +1,153 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import reducer, {
|
||||
applyMessageTail,
|
||||
setConversation,
|
||||
} from './conversationSlice';
|
||||
|
||||
const baseQuery = {
|
||||
prompt: 'tell me a poem',
|
||||
messageId: 'm-1',
|
||||
messageStatus: 'pending' as const,
|
||||
};
|
||||
|
||||
const seedSlice = () => reducer(undefined, setConversation([baseQuery]));
|
||||
|
||||
describe('applyMessageTail — streaming partial', () => {
|
||||
it('writes response to the query while status is streaming', () => {
|
||||
const state = seedSlice();
|
||||
const next = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'Hello, par',
|
||||
thought: null,
|
||||
sources: [],
|
||||
tool_calls: [],
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(next.queries[0].messageStatus).toBe('streaming');
|
||||
expect(next.queries[0].response).toBe('Hello, par');
|
||||
});
|
||||
|
||||
it('updates response on each successive tail tick', () => {
|
||||
let state = seedSlice();
|
||||
state = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'Hello',
|
||||
sources: [],
|
||||
tool_calls: [],
|
||||
},
|
||||
}),
|
||||
);
|
||||
state = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'Hello, world',
|
||||
sources: [],
|
||||
tool_calls: [],
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(state.queries[0].response).toBe('Hello, world');
|
||||
});
|
||||
|
||||
it('applies sources and tool_calls when they appear mid-stream', () => {
|
||||
const state = seedSlice();
|
||||
const next = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'partial',
|
||||
sources: [{ id: 's1', title: 'doc' }],
|
||||
tool_calls: [{ name: 'search' }],
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(next.queries[0].sources).toEqual([{ id: 's1', title: 'doc' }]);
|
||||
expect(next.queries[0].tool_calls).toEqual([{ name: 'search' }]);
|
||||
});
|
||||
|
||||
it('ignores empty sources / tool_calls arrays so the renderer stays clean', () => {
|
||||
const state = seedSlice();
|
||||
const next = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'partial',
|
||||
sources: [],
|
||||
tool_calls: [],
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(next.queries[0].sources).toBeUndefined();
|
||||
expect(next.queries[0].tool_calls).toBeUndefined();
|
||||
});
|
||||
|
||||
it('promotes to complete with the final response and clears any error', () => {
|
||||
let state = seedSlice();
|
||||
state = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'partial',
|
||||
},
|
||||
}),
|
||||
);
|
||||
state = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'complete',
|
||||
response: 'Final answer.',
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(state.queries[0].messageStatus).toBe('complete');
|
||||
expect(state.queries[0].response).toBe('Final answer.');
|
||||
expect(state.queries[0].error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('surfaces failed status as error and clears response', () => {
|
||||
const state = seedSlice();
|
||||
const next = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'failed',
|
||||
response: 'whatever',
|
||||
error: 'worker died',
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(next.queries[0].messageStatus).toBe('failed');
|
||||
expect(next.queries[0].error).toBe('worker died');
|
||||
expect(next.queries[0].response).toBeUndefined();
|
||||
});
|
||||
});
|
||||
@@ -1,5 +1,6 @@
|
||||
import { createAsyncThunk, createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
|
||||
import conversationService from '../api/services/conversationService';
|
||||
import { getConversations } from '../preferences/preferenceApi';
|
||||
import { setConversations } from '../preferences/preferenceSlice';
|
||||
import store from '../store';
|
||||
@@ -7,6 +8,7 @@ import {
|
||||
clearAttachments,
|
||||
selectCompletedAttachments,
|
||||
} from '../upload/uploadSlice';
|
||||
import { newIdempotencyKey } from '../utils/idempotency';
|
||||
import {
|
||||
handleFetchAnswer,
|
||||
handleFetchAnswerSteaming,
|
||||
@@ -16,12 +18,61 @@ import {
|
||||
import {
|
||||
Answer,
|
||||
ConversationState,
|
||||
MessageStatus,
|
||||
Query,
|
||||
ResearchStep,
|
||||
Status,
|
||||
} from './conversationModels';
|
||||
import { ToolCallsType } from './types';
|
||||
|
||||
// Maps a server message dict into the client ``Query`` shape. Only
|
||||
// terminal ``complete`` rows expose ``response``; non-terminal rows
|
||||
// would carry the WAL placeholder text, which must never render.
|
||||
// ``failed`` rows surface as ``error`` so they pick up Retry.
|
||||
export function mapServerQueryToClient(raw: any): Query {
|
||||
const status = raw?.status as MessageStatus | undefined;
|
||||
const isTerminalComplete = status === 'complete';
|
||||
const isFailed = status === 'failed';
|
||||
const metadata = raw?.metadata || {};
|
||||
|
||||
// Empty arrays are JS-truthy; coercing to undefined keeps the
|
||||
// renderer from rendering a blank bubble for in-flight rows and
|
||||
// matches the shape live-stream queries start with.
|
||||
const toolCalls = Array.isArray(raw?.tool_calls) ? raw.tool_calls : undefined;
|
||||
const sources = Array.isArray(raw?.sources) ? raw.sources : undefined;
|
||||
const query: Query = {
|
||||
prompt: raw?.prompt ?? '',
|
||||
feedback: raw?.feedback ?? undefined,
|
||||
thought: raw?.thought ?? undefined,
|
||||
sources: sources && sources.length > 0 ? sources : undefined,
|
||||
tool_calls: toolCalls && toolCalls.length > 0 ? toolCalls : undefined,
|
||||
attachments: raw?.attachments ?? undefined,
|
||||
messageId: raw?.message_id ?? undefined,
|
||||
messageStatus: status,
|
||||
requestId: raw?.request_id ?? undefined,
|
||||
lastHeartbeatAt: raw?.last_heartbeat_at ?? undefined,
|
||||
};
|
||||
|
||||
if (isTerminalComplete) {
|
||||
query.response = raw?.response ?? '';
|
||||
}
|
||||
if (isFailed) {
|
||||
query.error =
|
||||
(typeof metadata.error === 'string' && metadata.error) ||
|
||||
'Generation failed before completing.';
|
||||
}
|
||||
return query;
|
||||
}
|
||||
|
||||
// Placeholder still being produced server-side; client should tail
|
||||
// rather than treat as idle.
|
||||
export function isInFlightMessage(query: Query | undefined): boolean {
|
||||
if (!query) return false;
|
||||
return (
|
||||
query.messageStatus === 'pending' || query.messageStatus === 'streaming'
|
||||
);
|
||||
}
|
||||
|
||||
const initialState: ConversationState = {
|
||||
queries: [],
|
||||
status: 'idle',
|
||||
@@ -39,6 +90,63 @@ export function handleAbort() {
|
||||
}
|
||||
}
|
||||
|
||||
// Loads a conversation and applies it to the slice. Returns
|
||||
// ``{data, stale}``: ``stale`` true means a newer load has superseded
|
||||
// this one (or Redux already matches), so callers should not react to
|
||||
// the returned data; ``data`` null with ``stale`` false means 404.
|
||||
export type LoadConversationResult = {
|
||||
data: any | null;
|
||||
stale: boolean;
|
||||
};
|
||||
|
||||
let loadSeq = 0;
|
||||
|
||||
/**
 * Fetches a conversation by id and applies it to the slice.
 *
 * Resolves to `{ data, stale }`:
 *  - `stale: true`  -> Redux already held this conversation (and `force`
 *    was not set), or a newer `loadConversation` call was issued while
 *    this one was awaiting; nothing was written and callers should not
 *    react to the result.
 *  - `stale: false`, `data: null` -> the request failed or returned an
 *    empty body.
 *  - `stale: false`, `data` set -> the conversation was written to Redux.
 *
 * When the trailing message is still in flight, a `tailInFlightMessage`
 * poll is started for it.
 */
export const loadConversation = createAsyncThunk<
  LoadConversationResult,
  { id: string; force?: boolean }
>('loadConversation', async ({ id, force }, { dispatch, getState }) => {
  // Take a ticket immediately so any later call supersedes this one.
  const seq = ++loadSeq;
  const isSuperseded = () => seq !== loadSeq;

  const state = getState() as RootState;
  const token = state.preference.token;
  if (!force && state.conversation.conversationId === id) {
    return { data: null, stale: true };
  }

  const response = await conversationService.getConversation(id, token);
  // Check supersession BEFORE reporting a failure: a superseded request
  // that happens to fail must not be read by the caller as a real 404
  // while a newer load is still pending.
  if (isSuperseded()) {
    return { data: null, stale: true };
  }
  if (!response.ok) {
    return { data: null, stale: false };
  }

  const data = await response.json();
  // A later loadConversation may have been issued while the body was being
  // read; drop our writes so its result wins, and tell the caller not to
  // navigate off our return.
  if (isSuperseded()) {
    return { data: null, stale: true };
  }
  if (!data) {
    return { data: null, stale: false };
  }

  const mappedQueries = (data.queries || []).map(mapServerQueryToClient);
  dispatch(conversationSlice.actions.setConversation(mappedQueries));
  dispatch(
    conversationSlice.actions.updateConversationId({
      query: { conversationId: id },
    }),
  );

  // Only tail the trailing message; earlier in-flight rows are rare.
  const lastIdx = mappedQueries.length - 1;
  const lastQuery = mappedQueries[lastIdx];
  if (lastQuery && lastQuery.messageId && isInFlightMessage(lastQuery)) {
    dispatch(
      tailInFlightMessage({
        messageId: lastQuery.messageId,
        index: lastIdx,
        conversationId: id,
      }),
    );
  }
  return { data, stale: false };
});
|
||||
|
||||
export const fetchAnswer = createAsyncThunk<
|
||||
Answer,
|
||||
{ question: string; indx?: number }
|
||||
@@ -57,11 +165,30 @@ export const fetchAnswer = createAsyncThunk<
|
||||
dispatch(clearAttachments());
|
||||
}
|
||||
|
||||
const currentConversationId = state.conversation.conversationId;
|
||||
// Mutable so the SSE handler can adopt a server-assigned id and
|
||||
// keep passing it to reducer guards once the early ``message_id``
|
||||
// event lands.
|
||||
let currentConversationId = state.conversation.conversationId;
|
||||
const modelId =
|
||||
state.preference.selectedAgent?.default_model_id ||
|
||||
state.preference.selectedModel?.id;
|
||||
|
||||
// Reuse the key on the target Query when present (retry path),
|
||||
// else mint and persist so a later retry can re-send it.
|
||||
const targetIndexForKey =
|
||||
indx ?? Math.max(state.conversation.queries.length - 1, 0);
|
||||
let idempotencyKey =
|
||||
state.conversation.queries[targetIndexForKey]?.idempotencyKey;
|
||||
if (!idempotencyKey) {
|
||||
idempotencyKey = newIdempotencyKey();
|
||||
dispatch(
|
||||
conversationSlice.actions.updateQuery({
|
||||
index: targetIndexForKey,
|
||||
query: { idempotencyKey },
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
if (state.preference) {
|
||||
const agentKey = state.preference.selectedAgent?.key;
|
||||
if (USE_V1_API && agentKey) {
|
||||
@@ -79,7 +206,11 @@ export const fetchAnswer = createAsyncThunk<
|
||||
const data = JSON.parse(event.data);
|
||||
const targetIndex = indx ?? state.conversation.queries.length - 1;
|
||||
|
||||
if (currentConversationId === state.conversation.conversationId) {
|
||||
// Live Redux check; the closure ``state`` is a stale snapshot.
|
||||
if (
|
||||
currentConversationId ===
|
||||
(getState() as RootState).conversation.conversationId
|
||||
) {
|
||||
if (data.type === 'end') {
|
||||
dispatch(conversationSlice.actions.setStatus('idle'));
|
||||
getConversations(state.preference.token)
|
||||
@@ -107,6 +238,28 @@ export const fetchAnswer = createAsyncThunk<
|
||||
}),
|
||||
);
|
||||
}
|
||||
} else if (data.type === 'message_id') {
|
||||
if (data.conversation_id) {
|
||||
const currentState = getState() as RootState;
|
||||
if (currentState.conversation.conversationId === null) {
|
||||
// setConversationId leaves status='loading'; the
|
||||
// status-touching updateConversationId would flip it
|
||||
// to 'idle' and drop subsequent chunks.
|
||||
dispatch(
|
||||
conversationSlice.actions.setConversationId(
|
||||
data.conversation_id,
|
||||
),
|
||||
);
|
||||
currentConversationId = data.conversation_id;
|
||||
}
|
||||
}
|
||||
dispatch(
|
||||
conversationSlice.actions.updateMessageMeta({
|
||||
index: targetIndex,
|
||||
messageId: data.message_id,
|
||||
requestId: data.request_id,
|
||||
}),
|
||||
);
|
||||
} else if (data.type === 'thought') {
|
||||
dispatch(
|
||||
updateThought({
|
||||
@@ -171,8 +324,11 @@ export const fetchAnswer = createAsyncThunk<
|
||||
const data = JSON.parse(event.data);
|
||||
const targetIndex = indx ?? state.conversation.queries.length - 1;
|
||||
|
||||
// Only process events if they match the current conversation
|
||||
if (currentConversationId === state.conversation.conversationId) {
|
||||
// Live Redux check; the closure ``state`` is a stale snapshot.
|
||||
if (
|
||||
currentConversationId ===
|
||||
(getState() as RootState).conversation.conversationId
|
||||
) {
|
||||
if (data.type === 'end') {
|
||||
dispatch(conversationSlice.actions.setStatus('idle'));
|
||||
// Only update research status if this query has research data
|
||||
@@ -211,6 +367,28 @@ export const fetchAnswer = createAsyncThunk<
|
||||
}),
|
||||
);
|
||||
}
|
||||
} else if (data.type === 'message_id') {
|
||||
if (data.conversation_id) {
|
||||
const currentState = getState() as RootState;
|
||||
if (currentState.conversation.conversationId === null) {
|
||||
// setConversationId leaves status='loading'; the
|
||||
// status-touching updateConversationId would flip it
|
||||
// to 'idle' and drop subsequent chunks.
|
||||
dispatch(
|
||||
conversationSlice.actions.setConversationId(
|
||||
data.conversation_id,
|
||||
),
|
||||
);
|
||||
currentConversationId = data.conversation_id;
|
||||
}
|
||||
}
|
||||
dispatch(
|
||||
conversationSlice.actions.updateMessageMeta({
|
||||
index: targetIndex,
|
||||
messageId: data.message_id,
|
||||
requestId: data.request_id,
|
||||
}),
|
||||
);
|
||||
} else if (data.type === 'thought') {
|
||||
const result = data.thought;
|
||||
dispatch(
|
||||
@@ -293,6 +471,7 @@ export const fetchAnswer = createAsyncThunk<
|
||||
attachmentIds,
|
||||
true,
|
||||
modelId,
|
||||
idempotencyKey,
|
||||
);
|
||||
} else {
|
||||
const answer = await handleFetchAnswer(
|
||||
@@ -307,6 +486,7 @@ export const fetchAnswer = createAsyncThunk<
|
||||
attachmentIds,
|
||||
true,
|
||||
modelId,
|
||||
idempotencyKey,
|
||||
);
|
||||
if (answer) {
|
||||
let sourcesPrepped = [];
|
||||
@@ -362,6 +542,67 @@ export const fetchAnswer = createAsyncThunk<
|
||||
};
|
||||
});
|
||||
|
||||
// Tail-polls the placeholder until terminal status, navigation away,
|
||||
// or hard timeout. First poll fires immediately so rows that are
|
||||
// already terminal resolve without delay.
|
||||
const TAIL_POLL_INTERVAL_MS = 2000;
|
||||
const TAIL_MAX_POLL_DURATION_MS = 10 * 60 * 1000;
|
||||
|
||||
/**
 * Polls the tail endpoint for an in-flight message until it reaches a
 * terminal status ('complete' or 'failed'), the user leaves the
 * conversation, the row disappears (404), or the hard timeout elapses.
 *
 * The first poll fires immediately; later polls are spaced by
 * `TAIL_POLL_INTERVAL_MS`. Transport errors, non-OK responses and
 * unparseable bodies are retried after the same interval.
 *
 * @param messageId - Server id of the message being tailed.
 * @param index - Position of the corresponding query in `state.queries`.
 * @param conversationId - Conversation the message belongs to; polling
 *   stops, and no writes are made, once Redux points at another one.
 */
export const tailInFlightMessage = createAsyncThunk<
  void,
  { messageId: string; index: number; conversationId: string }
>(
  'tailInFlightMessage',
  async ({ messageId, index, conversationId }, { dispatch, getState }) => {
    const token = (getState() as RootState).preference.token;
    const deadline = Date.now() + TAIL_MAX_POLL_DURATION_MS;

    // Live check against Redux: true while the user is still looking at
    // the conversation this poll was started for.
    const isStillViewing = () =>
      (getState() as RootState).conversation.conversationId === conversationId;
    const pause = () =>
      new Promise<void>((resolve) =>
        setTimeout(resolve, TAIL_POLL_INTERVAL_MS),
      );

    dispatch(conversationSlice.actions.setStatus('loading'));

    while (Date.now() < deadline) {
      if (!isStillViewing()) return;

      let resp: Response;
      try {
        resp = await conversationService.tailMessage(messageId, token);
      } catch {
        await pause();
        continue;
      }
      // The user may have switched conversations while the request was in
      // flight; writing now would touch the wrong conversation's state.
      if (!isStillViewing()) return;

      // 404 → row deleted (e.g. conversation wiped); bail quietly.
      if (resp.status === 404) {
        dispatch(conversationSlice.actions.setStatus('idle'));
        return;
      }

      if (!resp.ok) {
        await pause();
        continue;
      }

      let data: any;
      try {
        data = await resp.json();
      } catch {
        // Malformed body: treat like a transient failure instead of
        // rejecting the thunk and leaving status stuck on 'loading'.
        await pause();
        continue;
      }
      // Re-check after the second await for the same reason as above.
      if (!isStillViewing()) return;

      dispatch(
        conversationSlice.actions.applyMessageTail({ index, tail: data }),
      );

      const status = data?.status as MessageStatus | undefined;
      if (status === 'complete' || status === 'failed') {
        dispatch(
          conversationSlice.actions.setStatus(
            status === 'failed' ? 'failed' : 'idle',
          ),
        );
        return;
      }
      await pause();
    }

    // Hard timeout: drop status to idle so the user can interact again,
    // but only if this conversation is still the one on screen.
    if (isStillViewing()) {
      dispatch(conversationSlice.actions.setStatus('idle'));
    }
  },
);
|
||||
|
||||
export const submitToolActions = createAsyncThunk<
|
||||
void,
|
||||
{
|
||||
@@ -379,10 +620,26 @@ export const submitToolActions = createAsyncThunk<
|
||||
|
||||
const state = getState() as RootState;
|
||||
const conversationId = state.conversation.conversationId;
|
||||
if (!conversationId) return;
|
||||
if (!conversationId) {
|
||||
const targetIndex = state.conversation.queries.length - 1;
|
||||
if (targetIndex >= 0) {
|
||||
dispatch(
|
||||
conversationSlice.actions.raiseError({
|
||||
conversationId: null,
|
||||
index: targetIndex,
|
||||
message:
|
||||
'Cannot submit decision — the conversation was not initialized. Please retry the question.',
|
||||
}),
|
||||
);
|
||||
}
|
||||
dispatch(conversationSlice.actions.setStatus('failed'));
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(conversationSlice.actions.setStatus('loading'));
|
||||
|
||||
// Fresh per submission: a tool decision is its own logical action.
|
||||
const idempotencyKey = newIdempotencyKey();
|
||||
await handleSubmitToolActions(
|
||||
conversationId,
|
||||
toolActions,
|
||||
@@ -403,6 +660,15 @@ export const submitToolActions = createAsyncThunk<
|
||||
});
|
||||
} else if (data.type === 'id') {
|
||||
// conversation ID already set
|
||||
} else if (data.type === 'message_id') {
|
||||
// Re-stamp; continuation reuses the original placeholder.
|
||||
dispatch(
|
||||
conversationSlice.actions.updateMessageMeta({
|
||||
index: targetIndex,
|
||||
messageId: data.message_id,
|
||||
requestId: data.request_id,
|
||||
}),
|
||||
);
|
||||
} else if (data.type === 'thought') {
|
||||
dispatch(
|
||||
updateThought({
|
||||
@@ -447,6 +713,7 @@ export const submitToolActions = createAsyncThunk<
|
||||
);
|
||||
}
|
||||
},
|
||||
idempotencyKey,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -462,9 +729,13 @@ export const conversationSlice = createSlice({
|
||||
},
|
||||
resendQuery(
|
||||
state,
|
||||
action: PayloadAction<{ index: number; prompt: string }>,
|
||||
action: PayloadAction<{
|
||||
index: number;
|
||||
prompt: string;
|
||||
keepIdempotencyKey?: boolean;
|
||||
}>,
|
||||
) {
|
||||
const { index, prompt } = action.payload;
|
||||
const { index, prompt, keepIdempotencyKey } = action.payload;
|
||||
if (index < 0 || index >= state.queries.length) return;
|
||||
|
||||
state.queries.splice(index + 1);
|
||||
@@ -478,6 +749,15 @@ export const conversationSlice = createSlice({
|
||||
delete state.queries[index].schema;
|
||||
delete state.queries[index].feedback;
|
||||
delete state.queries[index].research;
|
||||
// Drop stale WAL refs; the next stream's message_id event repopulates.
|
||||
delete state.queries[index].messageId;
|
||||
delete state.queries[index].messageStatus;
|
||||
delete state.queries[index].requestId;
|
||||
delete state.queries[index].lastHeartbeatAt;
|
||||
// Retry keeps the key so the server can dedupe; Edit drops it.
|
||||
if (!keepIdempotencyKey) {
|
||||
delete state.queries[index].idempotencyKey;
|
||||
}
|
||||
},
|
||||
updateStreamingQuery(
|
||||
state,
|
||||
@@ -512,6 +792,11 @@ export const conversationSlice = createSlice({
|
||||
state.conversationId = action.payload.query.conversationId ?? null;
|
||||
state.status = 'idle';
|
||||
},
|
||||
/**
 * Stores the conversation id and leaves `status` untouched. Intended for
 * mid-stream use, where the status-flipping `updateConversationId` would
 * cause later chunks to be dropped.
 */
setConversationId(state, { payload }: PayloadAction<string | null>) {
  state.conversationId = payload;
},
|
||||
updateThought(
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
@@ -646,6 +931,61 @@ export const conversationSlice = createSlice({
|
||||
setStatus(state, action: PayloadAction<Status>) {
|
||||
state.status = action.payload;
|
||||
},
|
||||
/**
 * Stamps server-assigned identifiers onto the query at `index`.
 * Falsy `messageId` / `requestId` values leave the existing fields alone,
 * and an out-of-range index is ignored.
 */
updateMessageMeta(
  state,
  action: PayloadAction<{
    index: number;
    messageId?: string;
    requestId?: string;
  }>,
) {
  const target = state.queries[action.payload.index];
  if (!target) return;

  const { messageId, requestId } = action.payload;
  if (messageId) {
    target.messageId = messageId;
  }
  if (requestId) {
    target.requestId = requestId;
  }
  // Mirror the server-side default so a refresh sees 'pending'.
  if (!target.messageStatus) {
    target.messageStatus = 'pending';
  }
},
|
||||
/**
 * Merges one tail-poll payload into the query at `index`.
 *
 * - Always overwrites `messageStatus` with the payload's status and
 *   refreshes `lastHeartbeatAt` when the payload carries one.
 * - On 'failed': records an error message, removes any response, and
 *   stops there.
 * - Otherwise: copies a string `response` / `thought` and non-empty
 *   `sources` / `tool_calls`; a 'complete' payload without a string
 *   response yields an empty response and clears any prior error.
 */
applyMessageTail(
  state,
  action: PayloadAction<{ index: number; tail: any }>,
) {
  const { index, tail } = action.payload;
  const target = state.queries[index];
  if (!target) return;

  const status = tail?.status as MessageStatus | undefined;
  const isComplete = status === 'complete';
  target.messageStatus = status;
  target.lastHeartbeatAt = tail?.last_heartbeat_at ?? target.lastHeartbeatAt;

  if (status === 'failed') {
    // Surface as error so the placeholder text never renders.
    const reason = typeof tail?.error === 'string' ? tail.error : '';
    target.error = reason || 'Generation failed before completing.';
    delete target.response;
    return;
  }

  // /tail returns reconstructed partials mid-stream so a second tab
  // can render the in-flight bubble; spinner is driven by status.
  if (typeof tail?.response === 'string') {
    target.response = tail.response;
  } else if (isComplete) {
    target.response = '';
  }

  if (typeof tail?.thought === 'string') {
    target.thought = tail.thought;
  }

  const hasSources = Array.isArray(tail?.sources) && tail.sources.length > 0;
  if (hasSources) {
    target.sources = tail.sources;
  }

  const hasToolCalls =
    Array.isArray(tail?.tool_calls) && tail.tool_calls.length > 0;
  if (hasToolCalls) {
    target.tool_calls = tail.tool_calls;
  }

  if (isComplete) {
    delete target.error;
  }
},
|
||||
raiseError(
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
@@ -704,8 +1044,11 @@ export const {
|
||||
updateResearchPlan,
|
||||
updateResearchProgress,
|
||||
setConversation,
|
||||
setConversationId,
|
||||
setStatus,
|
||||
raiseError,
|
||||
resetConversation,
|
||||
applyMessageTail,
|
||||
updateMessageMeta,
|
||||
} = conversationSlice.actions;
|
||||
export default conversationSlice.reducer;
|
||||
|
||||
18
frontend/src/events/EventStreamProvider.tsx
Normal file
18
frontend/src/events/EventStreamProvider.tsx
Normal file
@@ -0,0 +1,18 @@
|
||||
import React from 'react';
|
||||
|
||||
import { useEventStream } from './useEventStream';
|
||||
|
||||
/**
 * Provider meant to be mounted once; it opens the user's SSE connection
 * by calling `useEventStream` and renders its children unchanged.
 *
 * Place it inside ``AuthWrapper`` so a populated token is available, and
 * wrap the authenticated-app subtree with it so the connection lives for
 * the user's whole session.
 */
export function EventStreamProvider(props: {
  children: React.ReactNode;
}): React.ReactElement {
  useEventStream();
  return <React.Fragment>{props.children}</React.Fragment>;
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user