Compare commits

..

6 Commits

Author SHA1 Message Date
Alex 26e2e7d353 fix: tests 2026-05-04 23:00:21 +01:00
Alex 42c33f4e0d fix: better json validation 2026-05-04 18:09:40 +01:00
Alex 073f9fc003 fix: mini issues 2026-05-04 17:51:09 +01:00
Alex 9b974af210 fix: tests 2026-05-04 17:17:15 +01:00
Alex fe1edc6b79 feat: more durable frontend 2026-05-04 16:32:28 +01:00
Alex e550b11f39 feat: durability and idempotency keys 2026-05-03 18:36:02 +01:00
149 changed files with 1490 additions and 14924 deletions

View File

@@ -20,11 +20,10 @@ from pydantic import AnyHttpUrl, ValidationError
from redis import Redis
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_task
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.events.keys import stream_key
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
@@ -77,12 +76,6 @@ class MCPTool(Tool):
self.oauth_task_id = config.get("oauth_task_id", None)
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))
# Pulled out of ``config`` (rather than left in ``self.config``)
# because it is a callable supplied by the OAuth worker — not
# something the rest of the tool plumbing should marshal or
# serialize. ``DocsGPTOAuth`` invokes it from ``redirect_handler``
# so the SSE envelope can carry ``authorization_url``.
self.oauth_redirect_publish = config.pop("oauth_redirect_publish", None)
self.available_tools = []
self._cache_key = self._generate_cache_key()
@@ -174,7 +167,6 @@ class MCPTool(Tool):
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
user_id=self.user_id,
redirect_publish=self.oauth_redirect_publish,
)
elif self.auth_type == "bearer":
token = self.auth_credentials.get(
@@ -687,17 +679,12 @@ class DocsGPTOAuth(OAuthClientProvider):
user_id=None,
additional_client_metadata: dict[str, Any] | None = None,
skip_redirect_validation: bool = False,
redirect_publish=None,
):
self.redirect_uri = redirect_uri
self.redis_client = redis_client
self.redis_prefix = redis_prefix
self.task_id = task_id
self.user_id = user_id
# Worker-supplied callback. Invoked from ``redirect_handler``
# once the authorization URL is known so the SSE envelope can
# carry it. ``None`` for any non-worker entrypoint.
self.redirect_publish = redirect_publish
parsed_url = urlparse(mcp_url)
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
@@ -757,19 +744,17 @@ class DocsGPTOAuth(OAuthClientProvider):
self.redis_client.setex(key, 600, auth_url)
logger.info("Stored auth_url in Redis: %s", key)
if self.redirect_publish is not None:
# Best-effort: a publish failure must not abort the OAuth
# handshake — the user can still authorize via the popup
# opened from the legacy polling fallback if the SSE
# envelope is lost.
try:
self.redirect_publish(auth_url)
except Exception:
logger.warning(
"redirect_publish callback raised for task_id=%s",
self.task_id,
exc_info=True,
)
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "requires_redirect",
"message": "Authorization required",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
async def callback_handler(self) -> tuple[str, str | None]:
"""Wait for auth code from Redis using the state value."""
@@ -779,6 +764,17 @@ class DocsGPTOAuth(OAuthClientProvider):
max_wait_time = 300
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "awaiting_callback",
"message": "Waiting for authorization...",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
start_time = time.time()
while time.time() - start_time < max_wait_time:
code_data = self.redis_client.get(code_key)
@@ -793,6 +789,14 @@ class DocsGPTOAuth(OAuthClientProvider):
self.redis_client.delete(
f"{self.redis_prefix}state:{self.extracted_state}"
)
if self.task_id:
status_data = {
"status": "callback_received",
"message": "Completing authentication...",
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
return code, returned_state
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
error_data = self.redis_client.get(error_key)
@@ -1034,73 +1038,8 @@ class MCPOAuthManager:
logger.error("Error handling OAuth callback: %s", e)
return False
def get_oauth_status(self, task_id: str, user_id: str) -> Dict[str, Any]:
"""Return the latest OAuth status for ``task_id`` from the user's SSE journal.
Mirrors the legacy polling contract: ``status`` derived from the
``mcp.oauth.*`` event-type suffix, with payload fields surfaced
(e.g. ``tools``/``tools_count`` on ``completed``).
"""
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
"""Get current status of OAuth flow using provided task_id."""
if not task_id:
return {"status": "not_started", "message": "OAuth flow not started"}
if not user_id:
return {"status": "not_found", "message": "User not provided"}
if self.redis_client is None:
return {"status": "not_found", "message": "Redis unavailable"}
try:
# OAuth flows are short-lived but a concurrent source
# ingest can flood the user channel between the OAuth
# popup completing and the user clicking Save, pushing the
# completion envelope outside the read window. Bound the
# scan by the configured stream cap so we cover the full
# journal — XADD MAXLEN keeps that bounded too.
scan_count = max(settings.EVENTS_STREAM_MAXLEN, 200)
entries = self.redis_client.xrevrange(
stream_key(user_id), count=scan_count
)
except Exception:
logger.exception(
"xrevrange failed for oauth status: user_id=%s task_id=%s",
user_id,
task_id,
)
return {"status": "not_found", "message": "Status unavailable"}
for _entry_id, fields in entries:
if not isinstance(fields, dict):
continue
# decode_responses=False ⇒ bytes keys; the string-key fallback
# covers a future flip of that default without a forced refactor.
event_raw = fields.get(b"event")
if event_raw is None:
event_raw = fields.get("event")
if event_raw is None:
continue
if isinstance(event_raw, bytes):
try:
event_raw = event_raw.decode("utf-8")
except Exception:
continue
try:
envelope = json.loads(event_raw)
except Exception:
continue
if not isinstance(envelope, dict):
continue
event_type = envelope.get("type", "")
if not isinstance(event_type, str) or not event_type.startswith(
"mcp.oauth."
):
continue
scope = envelope.get("scope") or {}
if scope.get("kind") != "mcp_oauth" or scope.get("id") != task_id:
continue
payload = envelope.get("payload") or {}
return {
"status": event_type[len("mcp.oauth."):],
"task_id": task_id,
**payload,
}
return {"status": "not_found", "message": "Status not found"}
return mcp_oauth_status_task(task_id)
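Every status the flow writes above (``requires_redirect``, ``awaiting_callback``, ``callback_received``) lands under one Redis key with a 600-second TTL. A minimal sketch of the read side, assuming a default Redis connection; the helper name is hypothetical, and the real readers are the ``mcp_oauth_status`` worker and the ``MCPOAuthStatus`` resource further down in this diff:

import json
from redis import Redis

redis_client = Redis()

def read_oauth_status(task_id: str) -> dict:
    # Key name and TTL mirror the setex calls above; a missing key means
    # the flow has not written a status yet.
    raw = redis_client.get(f"mcp_oauth_status:{task_id}")
    if raw is None:
        return {"status": "pending", "message": "Waiting for OAuth to start..."}
    return json.loads(raw)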

View File

@@ -1,4 +1,4 @@
"""0001 initial schema — consolidated baseline for user-data tables.
"""0001 initial schema — consolidated Phase-1..3 baseline.
Revision ID: 0001_initial
Revises:

View File

@@ -1,40 +0,0 @@
"""0007 message_events — durable journal of chat-stream events.
Snapshot half of the chat-stream snapshot+tail pattern. Composite PK
``(message_id, sequence_no)``, ``created_at`` indexed for retention
sweeps, ``ON DELETE CASCADE`` from ``conversation_messages``.
Revision ID: 0007_message_events
Revises: 0006_idempotency_lease
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0007_message_events"
down_revision: Union[str, None] = "0006_idempotency_lease"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
CREATE TABLE message_events (
message_id UUID NOT NULL REFERENCES conversation_messages(id) ON DELETE CASCADE,
sequence_no INTEGER NOT NULL,
event_type TEXT NOT NULL,
payload JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
PRIMARY KEY (message_id, sequence_no)
);
CREATE INDEX message_events_created_at_idx ON message_events(created_at);
"""
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS message_events_created_at_idx;")
op.execute("DROP TABLE IF EXISTS message_events;")

View File

@@ -23,16 +23,9 @@ from application.core.settings import settings
from application.error import sanitize_api_error
from application.llm.llm_creator import LLMCreator
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.conversations import MessageUpdateOutcome
from application.storage.db.repositories.token_usage import TokenUsageRepository
from application.storage.db.repositories.user_logs import UserLogsRepository
from application.storage.db.session import db_readonly, db_session
from application.events.publisher import publish_user_event
from application.streaming.event_replay import format_sse_event
from application.streaming.message_journal import (
BatchedJournalWriter,
record_event,
)
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
@@ -284,17 +277,6 @@ class BaseAnswerResource:
"update_message_status streaming failed for %s",
reserved_message_id,
)
# Seed last_heartbeat_at so watchdog doesn't fall back to `timestamp`
# (creation time) before the first STREAM_HEARTBEAT_INTERVAL tick.
try:
self.conversation_service.heartbeat_message(
reserved_message_id,
)
except Exception:
logger.exception(
"initial heartbeat seed failed for %s",
reserved_message_id,
)
streaming_marked = True
last_heartbeat_at = time.monotonic()
@@ -321,73 +303,13 @@ class BaseAnswerResource:
try:
agent.tool_executor.message_id = reserved_message_id
except Exception:
logger.debug(
"Could not set tool_executor.message_id; tool-call correlation will be missing for message_id=%s",
reserved_message_id,
)
# Per-stream monotonic SSE event id. Allocated by ``_emit`` and
# threaded through both the wire format (``id: <seq>\\n``) and
# the journal write so a reconnecting client can ``Last-Event-
# ID`` past anything they already saw. Continuations resume
# against the original ``reserved_message_id`` — seed the
# allocator from the journal's high-water mark so we don't
# collide on the duplicate-PK and silently lose every emit
# past the resume point.
sequence_no = -1
if _continuation and reserved_message_id:
try:
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
with db_readonly() as conn:
latest = MessageEventsRepository(conn).latest_sequence_no(
reserved_message_id
)
if latest is not None:
sequence_no = latest
except Exception:
logger.exception(
"Continuation seq seed lookup failed for message_id=%s; "
"falling back to seq=-1 (duplicate-PK collisions will "
"be swallowed)",
reserved_message_id,
)
# One batched journal writer per stream.
journal_writer: Optional[BatchedJournalWriter] = (
BatchedJournalWriter(reserved_message_id)
if reserved_message_id
else None
)
def _emit(payload: dict) -> str:
"""Format-and-journal one SSE event.
With a reserved ``message_id``, buffers into the journal and
emits ``id: <seq>``-tagged SSE frames; otherwise falls back to
legacy ``data: ...\\n\\n`` framing.
"""
nonlocal sequence_no
if not reserved_message_id or journal_writer is None:
return f"data: {json.dumps(payload)}\n\n"
sequence_no += 1
seq = sequence_no
event_type = (
payload.get("type", "data")
if isinstance(payload, dict)
else "data"
)
normalised = payload if isinstance(payload, dict) else {"value": payload}
journal_writer.record(seq, event_type, normalised)
return format_sse_event(normalised, seq)
pass
try:
# Surface the placeholder id before any LLM tokens so a
# mid-handshake disconnect still has a row to tail-poll.
if reserved_message_id:
yield _emit(
early_event = json.dumps(
{
"type": "message_id",
"message_id": reserved_message_id,
@@ -397,6 +319,7 @@ class BaseAnswerResource:
"request_id": request_id,
}
)
yield f"data: {early_event}\n\n"
if _continuation:
gen_iter = agent.gen_continuation(
@@ -422,9 +345,8 @@ class BaseAnswerResource:
schema_info = line.get("schema")
structured_chunks.append(line["answer"])
else:
yield _emit(
{"type": "answer", "answer": line["answer"]}
)
data = json.dumps({"type": "answer", "answer": line["answer"]})
yield f"data: {data}\n\n"
elif "sources" in line:
_mark_streaming_once()
truncated_sources = []
@@ -437,40 +359,43 @@ class BaseAnswerResource:
)
truncated_sources.append(truncated_source)
if truncated_sources:
yield _emit(
data = json.dumps(
{"type": "source", "source": truncated_sources}
)
yield f"data: {data}\n\n"
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
yield _emit({"type": "tool_calls", "tool_calls": tool_calls})
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
yield f"data: {data}\n\n"
elif "thought" in line:
thought += line["thought"]
yield _emit({"type": "thought", "thought": line["thought"]})
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
if line.get("type") == "tool_calls_pending":
# Save continuation state and end the stream
paused = True
yield _emit(line)
data = json.dumps(line)
yield f"data: {data}\n\n"
elif line.get("type") == "error":
yield _emit(
{
"type": "error",
"error": sanitize_api_error(
line.get("error", "An error occurred")
),
}
)
sanitized_error = {
"type": "error",
"error": sanitize_api_error(line.get("error", "An error occurred"))
}
data = json.dumps(sanitized_error)
yield f"data: {data}\n\n"
else:
yield _emit(line)
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
yield _emit(
{
"type": "structured_answer",
"answer": response_full,
"structured": True,
"schema": schema_info,
}
)
structured_data = {
"type": "structured_answer",
"answer": response_full,
"structured": True,
"schema": schema_info,
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
# ---- Paused: save continuation state and end stream early ----
if paused:
@@ -527,7 +452,6 @@ class BaseAnswerResource:
exc_info=True,
)
state_saved = False
if conversation_id:
try:
cont_service = ContinuationService()
@@ -561,65 +485,18 @@ class BaseAnswerResource:
agent.tool_executor, "client_tools", None
),
)
state_saved = True
except Exception as e:
logger.error(
f"Failed to save continuation state: {str(e)}",
exc_info=True,
)
# Notify the user out-of-band so they can navigate
# back to the conversation and decide on the
# pending tool calls. Gated on ``state_saved``: a
# missing pending_tool_state row would 404 the
# resume endpoint, so an unfulfillable notification
# is worse than no notification.
user_id_for_event = (
decoded_token.get("sub") if decoded_token else None
)
if state_saved and user_id_for_event and conversation_id:
pending_calls = continuation.get(
"pending_tool_calls", []
) if continuation else []
# Trim each pending tool call to its identifying
# metadata so a tool with a multi-MB argument
# doesn't blow out the per-event payload size
# cap. The resume page fetches full args from
# ``pending_tool_state`` regardless.
pending_summaries = [
{
k: tc.get(k)
for k in (
"call_id",
"tool_name",
"action_name",
"name",
)
if isinstance(tc, dict) and tc.get(k) is not None
}
for tc in (pending_calls or [])
if isinstance(tc, dict)
]
publish_user_event(
user_id_for_event,
"tool.approval.required",
{
"conversation_id": str(conversation_id),
"message_id": reserved_message_id,
"pending_tool_calls": pending_summaries,
},
scope={
"kind": "conversation",
"id": str(conversation_id),
},
)
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
yield _emit({"type": "id", "id": str(conversation_id)})
yield _emit({"type": "end"})
# Drain the terminal ``end`` so a reconnecting client
# sees it on snapshot — same reason as the main exit.
if journal_writer is not None:
journal_writer.close()
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
return
if isNoneDoc:
@@ -726,7 +603,9 @@ class BaseAnswerResource:
f"completion: {e}",
exc_info=True,
)
yield _emit({"type": "id", "id": str(conversation_id)})
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
getattr(agent, "tool_calls", tool_calls) or tool_calls
@@ -767,33 +646,12 @@ class BaseAnswerResource:
exc_info=True,
)
yield _emit({"type": "end"})
# Drain the journal buffer so the terminal ``end`` event is
# visible to any reconnecting client. Without this the
# client could snapshot up to the last flush boundary and
# then live-tail waiting for an ``end`` that's still
# sitting in memory.
if journal_writer is not None:
journal_writer.close()
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except GeneratorExit:
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
# Drain any buffered events before the terminal one-shot
# ``record_event`` below — keeps the journal's seq order
# contiguous (buffered events ... terminal event). ``close``
# is idempotent; pairing it with ``flush`` matches the
# normal-exit and error branches so any future ``record()``
# past this point would log instead of silently buffering.
if journal_writer is not None:
journal_writer.flush()
journal_writer.close()
# Save partial response
# Whether the DB row was flipped to ``complete`` during this
# abort handler. Drives the choice of terminal journal event
# below: journal ``end`` only when the row actually matches,
# else journal ``error`` so a reconnecting client sees a
# failed terminal state instead of a blank "success".
finalized_complete = False
if should_save_conversation and response_full:
try:
if isNoneDoc:
@@ -828,7 +686,7 @@ class BaseAnswerResource:
)
llm._token_usage_source = "title"
if reserved_message_id is not None:
outcome = self.conversation_service.finalize_message(
self.conversation_service.finalize_message(
reserved_message_id,
response_full,
thought=thought,
@@ -847,15 +705,6 @@ class BaseAnswerResource:
),
},
)
# ``ALREADY_COMPLETE`` means the normal-path
# finalize at line 632 won the race: the DB row
# is already at ``complete`` and the reconnect
# journal should reflect that with ``end``,
# not a spurious ``error``.
finalized_complete = outcome in (
MessageUpdateOutcome.UPDATED,
MessageUpdateOutcome.ALREADY_COMPLETE,
)
else:
self.conversation_service.save_conversation(
conversation_id,
@@ -875,9 +724,6 @@ class BaseAnswerResource:
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
# No journal row to gate, but flag the save as
# successful for symmetry with the WAL path.
finalized_complete = True
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
@@ -901,63 +747,6 @@ class BaseAnswerResource:
logger.error(
f"Error saving partial response: {str(e)}", exc_info=True
)
# Journal a terminal event so reconnecting clients stop tailing;
# ``end`` only when the row is ``complete``, else ``error``.
if reserved_message_id is not None:
try:
sequence_no += 1
if finalized_complete:
# Match the wire shape ``_emit({"type": "end"})``
# uses on the normal path — the replay terminal
# check at ``event_replay._payload_is_terminal``
# reads ``payload.type``, and the frontend parses
# the same key off ``data:``.
record_event(
reserved_message_id,
sequence_no,
"end",
{"type": "end"},
)
else:
# Nothing was persisted under the complete status
# — mark the row failed so the reconciler doesn't
# need to sweep it, and journal an ``error`` so a
# reconnecting client surfaces the same failure
# the UI would show on a live error.
try:
self.conversation_service.finalize_message(
reserved_message_id,
response_full or TERMINATED_RESPONSE_PLACEHOLDER,
thought=thought,
sources=source_log_docs,
tool_calls=tool_calls,
model_id=model_id or self.default_model_id,
metadata=query_metadata if query_metadata else None,
status="failed",
error=ConnectionError(
"client disconnected before response was persisted"
),
)
except Exception as fin_err:
logger.error(
f"Failed to mark aborted message failed: {fin_err}",
exc_info=True,
)
record_event(
reserved_message_id,
sequence_no,
"error",
{
"type": "error",
"error": "Stream aborted before any response was produced.",
"code": "client_disconnect",
},
)
except Exception as journal_err:
logger.error(
f"Failed to journal terminal event on abort: {journal_err}",
exc_info=True,
)
raise
except Exception as e:
logger.error(f"Error in stream: {str(e)}", exc_info=True)
@@ -979,16 +768,13 @@ class BaseAnswerResource:
f"Failed to finalize errored message: {fin_err}",
exc_info=True,
)
yield _emit(
data = json.dumps(
{
"type": "error",
"error": "Please try again later. We apologize for any inconvenience.",
}
)
# Drain the terminal ``error`` event we just yielded so a
# reconnecting client sees it on snapshot.
if journal_writer is not None:
journal_writer.close()
yield f"data: {data}\n\n"
return
def process_response_stream(self, stream) -> Dict[str, Any]:
@@ -1010,22 +796,8 @@ class BaseAnswerResource:
for line in stream:
try:
# Each chunk may carry an ``id: <seq>`` header before
# the ``data:`` line. Pull just the ``data:`` body so
# the JSON decode doesn't choke on the SSE framing.
event_data = ""
for raw in line.split("\n"):
if raw.startswith("data:"):
event_data = raw[len("data:") :].lstrip()
break
if not event_data:
continue
event_data = line.replace("data: ", "").strip()
event = json.loads(event_data)
# The ``message_id`` event is informational for the
# streaming consumer and has no synchronous-API field;
# skip it so the type-switch below doesn't KeyError.
if event.get("type") == "message_id":
continue
if event["type"] == "id":
conversation_id = event["id"]
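The ``id: <seq>`` prefix that ``_emit`` produced (and that the removed parser above strips before JSON-decoding) is plain WHATWG SSE framing. A minimal sketch of the writer half, assuming ``format_sse_event`` has roughly this shape; the real implementation lives in ``application.streaming.event_replay`` and is not shown in this diff:

import json

def format_sse_event(payload: dict, seq: int) -> str:
    # One SSE frame: an ``id:`` header carrying the per-stream sequence
    # number, the JSON body on a ``data:`` line, then the blank line that
    # terminates the message. A reconnecting client echoes the last seq
    # back as ``Last-Event-ID`` to resume past what it already saw.
    return f"id: {seq}\ndata: {json.dumps(payload)}\n\n"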

View File

@@ -1,135 +0,0 @@
"""GET /api/messages/<message_id>/events — chat-stream reconnect endpoint.
Authenticates the caller, verifies ``message_id`` belongs to the user,
then hands off to ``build_message_event_stream`` for snapshot+tail.
"""
from __future__ import annotations
import logging
import re
from typing import Iterator, Optional
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
from sqlalchemy import text
from application.core.settings import settings
from application.storage.db.session import db_readonly
from application.streaming.event_replay import (
DEFAULT_KEEPALIVE_SECONDS,
DEFAULT_POLL_TIMEOUT_SECONDS,
build_message_event_stream,
)
logger = logging.getLogger(__name__)
messages_bp = Blueprint("message_stream", __name__)
# A message_id is in the canonical UUID hex format. Reject anything else
# before the SQL layer so a malformed cookie can't surface as a 500.
_MESSAGE_ID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
)
# ``sequence_no`` is a non-negative decimal integer. Anything else is
# corrupt client state — fall through to a fresh-replay cursor and let
# the snapshot reader catch the client up.
_SEQUENCE_NO_RE = re.compile(r"^\d+$")
def _normalise_last_event_id(raw: Optional[str]) -> Optional[int]:
if raw is None:
return None
raw = raw.strip()
if not raw or not _SEQUENCE_NO_RE.match(raw):
return None
return int(raw)
def _user_owns_message(message_id: str, user_id: str) -> bool:
"""Return True iff ``message_id`` belongs to ``user_id``."""
try:
with db_readonly() as conn:
row = conn.execute(
text(
"""
SELECT 1 FROM conversation_messages
WHERE id = CAST(:id AS uuid)
AND user_id = :u
LIMIT 1
"""
),
{"id": message_id, "u": user_id},
).first()
return row is not None
except Exception:
logger.exception(
"Ownership lookup failed for message_id=%s user_id=%s",
message_id,
user_id,
)
return False
@messages_bp.route("/api/messages/<message_id>/events", methods=["GET"])
def stream_message_events(message_id: str) -> Response:
decoded = getattr(request, "decoded_token", None)
user_id = decoded.get("sub") if isinstance(decoded, dict) else None
if not user_id:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
if not _MESSAGE_ID_RE.match(message_id):
return make_response(
jsonify({"success": False, "message": "Invalid message id"}),
400,
)
if not _user_owns_message(message_id, user_id):
# Don't disclose whether the row exists — a malicious caller
# gets the same 404 whether the id is bogus, taken by another
# user, or simply gone.
return make_response(
jsonify({"success": False, "message": "Not found"}),
404,
)
raw_cursor = request.headers.get("Last-Event-ID") or request.args.get(
"last_event_id"
)
last_event_id = _normalise_last_event_id(raw_cursor)
keepalive_seconds = float(
getattr(settings, "SSE_KEEPALIVE_SECONDS", DEFAULT_KEEPALIVE_SECONDS)
)
@stream_with_context
def generate() -> Iterator[str]:
try:
yield from build_message_event_stream(
message_id,
last_event_id=last_event_id,
keepalive_seconds=keepalive_seconds,
poll_timeout_seconds=DEFAULT_POLL_TIMEOUT_SECONDS,
)
except GeneratorExit:
return
except Exception:
logger.exception(
"Reconnect stream crashed for message_id=%s user_id=%s",
message_id,
user_id,
)
response = Response(generate(), mimetype="text/event-stream")
response.headers["Cache-Control"] = "no-store"
response.headers["X-Accel-Buffering"] = "no"
response.headers["Connection"] = "keep-alive"
logger.info(
"message.event.connect message_id=%s user_id=%s last_event_id=%s",
message_id,
user_id,
last_event_id if last_event_id is not None else "-",
)
return response
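A hypothetical consumer of this endpoint, resuming from the last sequence number it saw. ``requests`` stands in for the real frontend client, which uses an ``EventSource`` that sends ``Last-Event-ID`` automatically on reconnect:

import requests

def tail_message_events(base_url: str, message_id: str, last_seq: int | None = None):
    # Echo the cursor back so the snapshot starts past what we already saw.
    headers = {"Last-Event-ID": str(last_seq)} if last_seq is not None else {}
    url = f"{base_url}/api/messages/{message_id}/events"
    with requests.get(url, headers=headers, stream=True) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            yield line  # raw SSE lines: "id: <seq>", "data: {...}", blanks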

View File

@@ -15,10 +15,7 @@ from sqlalchemy import text as sql_text
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.conversations import (
ConversationsRepository,
MessageUpdateOutcome,
)
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.session import db_readonly, db_session
@@ -308,17 +305,10 @@ class ConversationService:
status: str = "complete",
error: Optional[BaseException] = None,
title_inputs: Optional[Dict[str, Any]] = None,
) -> MessageUpdateOutcome:
"""Commit the response and tool_call confirms in one transaction.
The outcome propagates directly from ``update_message_by_id`` so
callers (notably the SSE abort handler) can tell a fresh
finalize from "the row was already terminal" — the latter must
still be treated as success when the prior state was
``complete``.
"""
) -> bool:
"""Commit the response and tool_call confirms in one transaction."""
if not message_id:
return MessageUpdateOutcome.INVALID
return False
sources = sources or []
for source in sources:
if "text" in source and isinstance(source["text"], str):
@@ -346,16 +336,16 @@ class ConversationService:
# retracting a row the reconciler already escalated.
with db_session() as conn:
repo = ConversationsRepository(conn)
outcome = repo.update_message_by_id(
ok = repo.update_message_by_id(
message_id, update_fields,
only_if_non_terminal=True,
)
if outcome is not MessageUpdateOutcome.UPDATED:
if not ok:
logger.warning(
f"finalize_message: no row updated for message_id={message_id} "
f"(outcome={outcome.value} — possibly already terminal)"
f"(possibly already terminal — reconciler may have escalated)"
)
return outcome
return False
repo.confirm_executed_tool_calls(message_id)
# Outside the txn — title-gen is a multi-second LLM round trip.
@@ -368,7 +358,7 @@ class ConversationService:
f"finalize_message title generation failed: {e}",
exc_info=True,
)
return MessageUpdateOutcome.UPDATED
return True
def _maybe_generate_title(
self,

View File

@@ -1,504 +0,0 @@
"""GET /api/events — user-scoped Server-Sent Events endpoint.
Subscribe-then-snapshot pattern: subscribe to ``user:{user_id}``
pub/sub, snapshot the Redis Streams backlog past ``Last-Event-ID``
inside the SUBSCRIBE-ack callback, flush snapshot, then tail live
events (dedup'd by stream id). See ``docs/runbooks/sse-notifications.md``.
"""
from __future__ import annotations
import json
import logging
import re
import time
from typing import Iterator, Optional
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
from application.cache import get_redis_instance
from application.core.settings import settings
from application.events.keys import (
connection_counter_key,
replay_budget_key,
stream_id_compare,
stream_key,
topic_name,
)
from application.streaming.broadcast_channel import Topic
logger = logging.getLogger(__name__)
events = Blueprint("event_stream", __name__)
SUBSCRIBE_POLL_INTERVAL_SECONDS = 1.0
# WHATWG SSE treats CRLF, CR, and LF equivalently as line terminators.
_SSE_LINE_SPLIT = re.compile(r"\r\n|\r|\n")
# Redis Streams ids are ``ms`` or ``ms-seq`` where both halves are decimal.
# Anything else is a corrupted client cookie / IndexedDB residue and must
# not be passed to XRANGE — Redis would reject it and our truncation gate
# would silently fail.
_STREAM_ID_RE = re.compile(r"^\d+(-\d+)?$")
# Only emitted at most once per process so a misconfigured deployment
# doesn't drown the logs.
_local_user_warned = False
def _format_sse(data: str, *, event_id: Optional[str] = None) -> str:
"""Encode a payload as one SSE message terminated by a blank line.
Splits on any line-terminator variant (``\\r\\n``, ``\\r``, ``\\n``)
so a stray CR in upstream content can't smuggle a premature line
boundary into the wire format.
"""
lines: list[str] = []
if event_id:
lines.append(f"id: {event_id}")
for line in _SSE_LINE_SPLIT.split(data):
lines.append(f"data: {line}")
return "\n".join(lines) + "\n\n"
def _decode(value) -> Optional[str]:
if value is None:
return None
if isinstance(value, (bytes, bytearray)):
try:
return value.decode("utf-8")
except Exception:
return None
return str(value)
def _oldest_retained_id(redis_client, user_id: str) -> Optional[str]:
"""Return the id of the oldest entry still in the stream, or ``None``.
Used to detect ``Last-Event-ID`` having slid off the back of the
MAXLEN'd window.
"""
try:
info = redis_client.xinfo_stream(stream_key(user_id))
except Exception:
return None
if not isinstance(info, dict):
return None
# redis-py 7.4 returns str-keyed dicts here; the bytes-key probe is
# defence in depth in case ``decode_responses`` is ever flipped.
first_entry = info.get("first-entry") or info.get(b"first-entry")
if not first_entry:
return None
# XINFO STREAM returns first-entry as [id, [field, value, ...]]
try:
return _decode(first_entry[0])
except Exception:
return None
def _allow_replay(
redis_client, user_id: str, last_event_id: Optional[str]
) -> bool:
"""Per-user sliding-window snapshot-replay budget.
Fails open on Redis errors or when the budget is disabled. Empty-backlog
no-cursor connects skip INCR so dev double-mounts don't trip 429.
"""
budget = int(settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW)
if budget <= 0:
return True
if redis_client is None:
return True
# Cheap pre-check: only INCR when we might actually replay. XLEN
# is one Redis op; the alternative (INCR every connect) is two
# ops AND wrongly counts no-op probes. The check is conservative:
# if ``last_event_id`` is set we always INCR, even if the cursor
# has already overtaken the latest entry — that case is rare and
# short-lived, and probing further would mean a redundant XRANGE.
if last_event_id is None:
try:
if int(redis_client.xlen(stream_key(user_id))) == 0:
return True
except Exception:
# XLEN probe failed; fall through to the INCR path so a
# transient Redis hiccup can't bypass the budget.
logger.debug(
"XLEN probe failed for replay budget check user=%s; "
"proceeding to INCR",
user_id,
)
window = max(1, int(settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS))
key = replay_budget_key(user_id)
try:
used = int(redis_client.incr(key))
# Always (re)seed the TTL. Gating on ``used == 1`` would wedge
# the counter forever if INCR succeeds but EXPIRE raises on
# the seeding call. EXPIRE on an existing key resets the TTL
# to ``window`` — within ±1s of the per-window budget semantic.
redis_client.expire(key, window)
except Exception:
logger.debug(
"replay budget probe failed for user=%s; failing open",
user_id,
)
return True
return used <= budget
def _normalize_last_event_id(raw: Optional[str]) -> Optional[str]:
"""Validate the ``Last-Event-ID`` header / query param.
Returns the value unchanged when it parses as a Redis Streams id,
otherwise ``None`` — callers treat ``None`` as "client has nothing"
and replay from the start of the retained window. Invalid ids would
otherwise pass straight to XRANGE and surface as a quiet replay
failure plus broken truncation detection.
"""
if raw is None:
return None
raw = raw.strip()
if not raw or not _STREAM_ID_RE.match(raw):
return None
return raw
def _replay_backlog(
redis_client, user_id: str, last_event_id: Optional[str], max_count: int
) -> Iterator[tuple[str, str]]:
"""Yield ``(entry_id, sse_line)`` for backlog entries past ``last_event_id``.
Capped at ``max_count`` rows; clients catch up across reconnects.
Parse failures are skipped; the Streams id is injected into the
envelope so replay matches live-tail shape.
"""
# Exclusive start: '(<id>' skips the already-delivered entry.
start = f"({last_event_id}" if last_event_id else "-"
try:
entries = redis_client.xrange(
stream_key(user_id), min=start, max="+", count=max_count
)
except Exception as exc:
logger.warning(
"xrange replay failed for user=%s last_id=%s err=%s",
user_id,
last_event_id or "-",
exc,
)
return
for entry_id, fields in entries:
entry_id_str = _decode(entry_id)
if not entry_id_str:
continue
# decode_responses=False on the cache client ⇒ field keys/values
# are bytes. The string-key fallback covers a future flip of that
# default without a forced refactor here.
raw_event = None
if isinstance(fields, dict):
raw_event = fields.get(b"event")
if raw_event is None:
raw_event = fields.get("event")
event_str = _decode(raw_event)
if not event_str:
continue
try:
envelope = json.loads(event_str)
if isinstance(envelope, dict):
envelope["id"] = entry_id_str
event_str = json.dumps(envelope)
except Exception:
logger.debug(
"Replay envelope parse failed for entry %s; passing through raw",
entry_id_str,
)
yield entry_id_str, _format_sse(event_str, event_id=entry_id_str)
def _truncation_notice_line(oldest_id: str) -> str:
"""SSE event the frontend can react to with a full-state refetch."""
return _format_sse(
json.dumps(
{
"type": "backlog.truncated",
"payload": {"oldest_retained_id": oldest_id},
}
)
)
@events.route("/api/events", methods=["GET"])
def stream_events() -> Response:
decoded = getattr(request, "decoded_token", None)
user_id = decoded.get("sub") if isinstance(decoded, dict) else None
if not user_id:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
# In dev deployments without AUTH_TYPE configured, every request
# resolves to user_id="local" and shares one stream. Surface this so
# an accidentally-multi-user dev box doesn't silently cross-stream.
global _local_user_warned
if user_id == "local" and not _local_user_warned:
logger.warning(
"SSE serving user_id='local' (AUTH_TYPE not set). "
"All clients on this deployment will share one event stream."
)
_local_user_warned = True
raw_last_event_id = request.headers.get("Last-Event-ID") or request.args.get(
"last_event_id"
)
last_event_id = _normalize_last_event_id(raw_last_event_id)
last_event_id_invalid = raw_last_event_id is not None and last_event_id is None
keepalive_seconds = float(settings.SSE_KEEPALIVE_SECONDS)
push_enabled = settings.ENABLE_SSE_PUSH
cap = int(settings.SSE_MAX_CONCURRENT_PER_USER)
redis_client = get_redis_instance()
counter_key = connection_counter_key(user_id)
counted = False
if push_enabled and redis_client is not None and cap > 0:
try:
current = int(redis_client.incr(counter_key))
counted = True
except Exception:
current = 0
logger.debug(
"SSE connection counter INCR failed for user=%s", user_id
)
if counted:
# 1h safety TTL — orphaned counts from hard crashes self-heal.
# EXPIRE failure must NOT clobber ``current`` and bypass the cap.
try:
redis_client.expire(counter_key, 3600)
except Exception:
logger.debug(
"SSE connection counter EXPIRE failed for user=%s", user_id
)
if current > cap:
try:
redis_client.decr(counter_key)
except Exception:
logger.debug(
"SSE connection counter DECR failed for user=%s",
user_id,
)
return make_response(
jsonify(
{
"success": False,
"message": "Too many concurrent SSE connections",
}
),
429,
)
# Replay budget is checked here, before the generator opens the
# stream, so a denial can surface as HTTP 429 instead of a silent
# snapshot skip. The earlier in-generator skip lost events between
# the client's cursor and the first live-tailed entry: the live
# tail still carried ``id:`` headers, the frontend advanced
# ``lastEventId`` to one of those ids, and the events in between
# were never reachable on the next reconnect. 429 keeps the
# cursor pinned and lets the frontend back off until the window
# slides (eventStreamClient.ts treats 429 as escalated backoff).
if push_enabled and redis_client is not None and not _allow_replay(
redis_client, user_id, last_event_id
):
if counted:
try:
redis_client.decr(counter_key)
except Exception:
logger.debug(
"SSE connection counter DECR failed for user=%s",
user_id,
)
return make_response(
jsonify(
{
"success": False,
"message": "Replay budget exhausted",
}
),
429,
)
@stream_with_context
def generate() -> Iterator[str]:
connect_ts = time.monotonic()
replayed_count = 0
try:
# First frame primes intermediaries (Cloudflare, nginx) so they
# don't sit on a buffer waiting for body bytes.
yield ": connected\n\n"
if not push_enabled:
yield ": push_disabled\n\n"
return
replay_lines: list[str] = []
max_replayed_id: Optional[str] = None
replay_done = False
# If the client sent a malformed Last-Event-ID, surface the
# truncation notice synchronously *before* the subscribe
# loop. Buffering it into ``replay_lines`` would lose it
# when ``Topic.subscribe`` returns immediately (Redis down)
# — the loop body never runs, and the flush at line ~335
# never fires.
if last_event_id_invalid:
yield _truncation_notice_line("")
replayed_count += 1
def _on_subscribe_callback() -> None:
# Runs synchronously inside Topic.subscribe after the
# SUBSCRIBE is acked. By doing XRANGE here, any publisher
# firing between SUBSCRIBE-send and SUBSCRIBE-ack has its
# XADD captured by XRANGE *and* its PUBLISH buffered at
# the connection layer until we read it — closing the
# replay/subscribe race the design doc warns about.
#
# Truncation contract: ``backlog.truncated`` is emitted
# ONLY when the client's ``Last-Event-ID`` has slid off
# the MAXLEN'd window — that's the case where the
# journal is genuinely gone past the cursor and the
# frontend should clear its slice cursor and refetch
# state. Cap-hit skips the snapshot silently: the
# cursor advances via the per-entry ``id:`` headers
# and the frontend's slice keeps the latest id so the
# next reconnect resumes from there. Budget-exhausted
# never reaches this callback — the route 429s before
# opening the stream, keeping the cursor pinned.
# Conflating these with stale-cursor truncation would
# tell the client to clear its cursor and re-receive
# the same oldest-N entries on every reconnect —
# locking the user out of entries past N.
nonlocal max_replayed_id, replay_done
try:
if redis_client is None:
return
oldest = _oldest_retained_id(redis_client, user_id)
if (
last_event_id
and oldest
and stream_id_compare(last_event_id, oldest) < 0
):
# The Last-Event-ID has slid off the MAXLEN window.
# Tell the client so it can fetch full state.
replay_lines.append(_truncation_notice_line(oldest))
replay_cap = int(settings.EVENTS_REPLAY_MAX_PER_REQUEST)
for entry_id, sse_line in _replay_backlog(
redis_client, user_id, last_event_id, replay_cap
):
replay_lines.append(sse_line)
max_replayed_id = entry_id
finally:
# Always flip the flag — even on partial-replay failure
# the outer loop must reach the flush step so we don't
# silently strand whatever entries did land.
replay_done = True
topic = Topic(topic_name(user_id))
last_keepalive = time.monotonic()
for payload in topic.subscribe(
on_subscribe=_on_subscribe_callback,
poll_timeout=SUBSCRIBE_POLL_INTERVAL_SECONDS,
):
# Flush snapshot on the first iteration after the SUBSCRIBE
# callback ran. This runs at most once per connection.
if replay_done and replay_lines:
for line in replay_lines:
yield line
replayed_count += 1
replay_lines.clear()
now = time.monotonic()
if payload is None:
if now - last_keepalive >= keepalive_seconds:
yield ": keepalive\n\n"
last_keepalive = now
continue
event_str = _decode(payload) or ""
event_id: Optional[str] = None
try:
envelope = json.loads(event_str)
if isinstance(envelope, dict):
candidate = envelope.get("id")
# Only trust ids that look like real Redis Streams
# ids (``ms`` or ``ms-seq``). A malformed or
# adversarial publisher could otherwise pin
# dedupe forever — a lex-greater bogus id would
# make every legitimate later id compare ``<=``
# and get dropped silently.
if isinstance(candidate, str) and _STREAM_ID_RE.match(
candidate
):
event_id = candidate
except Exception:
pass
# Dedupe: if this id was already covered by replay, drop it.
if (
event_id is not None
and max_replayed_id is not None
and stream_id_compare(event_id, max_replayed_id) <= 0
):
continue
yield _format_sse(event_str, event_id=event_id)
last_keepalive = now
# Topic.subscribe exited before the first yield (transient
# Redis hiccup between SUBSCRIBE-ack and the first poll, or
# an immediate Redis-down return). The callback may already
# have populated the snapshot — flush it so the client gets
# the backlog instead of a silent drop. Safe no-op when the
# in-loop flush ran (it clear()'d the buffer) and when the
# callback never fired (replay_done stays False).
if replay_done and replay_lines:
for line in replay_lines:
yield line
replayed_count += 1
replay_lines.clear()
except GeneratorExit:
return
except Exception:
logger.exception(
"SSE event-stream generator crashed for user=%s", user_id
)
finally:
duration_s = time.monotonic() - connect_ts
logger.info(
"event.disconnect user=%s duration_s=%.1f replayed=%d",
user_id,
duration_s,
replayed_count,
)
if counted and redis_client is not None:
try:
redis_client.decr(counter_key)
except Exception:
logger.debug(
"SSE connection counter DECR failed for user=%s on disconnect",
user_id,
)
response = Response(generate(), mimetype="text/event-stream")
response.headers["Cache-Control"] = "no-store"
response.headers["X-Accel-Buffering"] = "no"
response.headers["Connection"] = "keep-alive"
logger.info(
"event.connect user=%s last_event_id=%s%s",
user_id,
last_event_id or "-",
" (rejected_invalid)" if last_event_id_invalid else "",
)
return response
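``stream_id_compare`` is imported above but not shown in this diff. Its assumed semantics follow Redis Streams ordering, millisecond timestamp first and then sequence number, roughly:

def stream_id_compare(a: str, b: str) -> int:
    # Redis Streams ids are "ms" or "ms-seq"; compare (ms, seq) numerically.
    def parts(s: str) -> tuple[int, int]:
        ms, _, seq = s.partition("-")
        return int(ms), int(seq or 0)
    pa, pb = parts(a), parts(b)
    return (pa > pb) - (pa < pb)  # -1, 0, or 1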

View File

@@ -214,10 +214,6 @@ class StoreAttachment(Resource):
{
"success": True,
"task_id": tasks[0]["task_id"],
# Surface the attachment_id so the frontend
# can correlate ``attachment.*`` SSE events
# to this row and skip the polling fallback.
"attachment_id": tasks[0]["attachment_id"],
"message": "File uploaded successfully. Processing started.",
}
),

View File

@@ -7,13 +7,9 @@ from flask_restx import fields, Namespace, Resource
from sqlalchemy import text as sql_text
from application.api import api
from application.api.answer.services.conversation_service import (
TERMINATED_RESPONSE_PLACEHOLDER,
)
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.message_events import MessageEventsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
@@ -350,25 +346,6 @@ class GetMessageTail(Resource):
if row is None:
return make_response(jsonify({"status": "not found"}), 404)
msg = row_to_dict(row)
# Mid-stream the row's response is the placeholder; rebuild
# the live partial from the journal so /tail mirrors SSE.
status = msg.get("status")
response = msg.get("response")
thought = msg.get("thought")
sources = msg.get("sources") or []
tool_calls = msg.get("tool_calls") or []
if status in ("pending", "streaming") and (
response == TERMINATED_RESPONSE_PLACEHOLDER
):
partial = MessageEventsRepository(conn).reconstruct_partial(
message_id
)
response = partial["response"]
thought = partial["thought"] or thought
if partial["sources"]:
sources = partial["sources"]
if partial["tool_calls"]:
tool_calls = partial["tool_calls"]
except Exception as err:
current_app.logger.error(
f"Error tailing message {message_id}: {err}", exc_info=True
@@ -379,11 +356,11 @@ class GetMessageTail(Resource):
jsonify(
{
"message_id": str(msg["id"]),
"status": status,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"status": msg.get("status"),
"response": msg.get("response"),
"thought": msg.get("thought"),
"sources": msg.get("sources") or [],
"tool_calls": msg.get("tool_calls") or [],
"request_id": msg.get("request_id"),
"last_heartbeat_at": metadata.get("last_heartbeat_at"),
"error": metadata.get("error"),

View File

@@ -13,7 +13,6 @@ from sqlalchemy import text as sql_text
from application.api import api
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
from application.core.settings import settings
from application.storage.db.source_ids import derive_source_id as _derive_source_id
from application.parser.connectors.connector_creator import ConnectorCreator
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
from application.storage.db.repositories.idempotency import IdempotencyRepository
@@ -70,13 +69,7 @@ def _claim_task_or_get_cached(key, task_name):
Pre-generates the celery task_id so a losing writer sees the same
id immediately. Returns ``(task_id, cached_response)``; non-None
cached means the caller should return without enqueuing. The
cached payload mirrors the fresh-request response shape (including
``source_id``) so the frontend can correlate SSE ingest events to
the cached upload task without an extra round-trip — but only when
the cached row actually exists; the "deduplicated" sentinel
deliberately omits ``source_id`` so the frontend doesn't bind to a
phantom source.
cached means the caller should return without enqueuing.
"""
predetermined_id = str(uuid.uuid4())
with db_session() as conn:
@@ -88,16 +81,10 @@ def _claim_task_or_get_cached(key, task_name):
with db_readonly() as conn:
existing = IdempotencyRepository(conn).get_task(key)
cached_id = existing.get("task_id") if existing else None
payload: dict = {
return None, {
"success": True,
"task_id": cached_id or "deduplicated",
}
# Only surface ``source_id`` when there's a real winner whose worker
# is publishing SSE events tagged with that id. The "deduplicated"
# branch means the lock row vanished — we have nothing to correlate.
if cached_id is not None:
payload["source_id"] = str(_derive_source_id(key))
return None, payload
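``IdempotencyRepository.claim_task`` is not shown in this diff, but the winner/loser split above implies an insert-if-absent. A sketch of the assumed shape, with hypothetical table and column names:

from sqlalchemy import text

def claim_task(conn, key: str, task_id: str, task_name: str) -> bool:
    # Exactly one writer wins the INSERT; losers fall through to the
    # cached-response read in _claim_task_or_get_cached.
    row = conn.execute(
        text(
            "INSERT INTO idempotency_keys (key, task_id, task_name) "
            "VALUES (:k, :t, :n) ON CONFLICT (key) DO NOTHING RETURNING 1"
        ),
        {"k": key, "t": task_id, "n": task_name},
    ).first()
    return row is not None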
def _release_claim(key):
@@ -249,15 +236,6 @@ class UploadFile(Resource):
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
# Mint the source UUID up here so the HTTP response and the
# worker's SSE envelopes share one id. With an idempotency
# key we reuse the deterministic uuid5 (retried task lands on
# the same source row); without a key we fall back to uuid4.
# The worker is told to use this id verbatim — see
# ``ingest_worker(source_id=...)``.
source_uuid = (
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
)
ingest_kwargs = dict(
args=(
settings.UPLOAD_FOLDER,
@@ -271,7 +249,6 @@ class UploadFile(Resource):
"file_name_map": file_name_map,
# Scoped so the worker dedup row matches the HTTP claim.
"idempotency_key": scoped_key or idempotency_key,
"source_id": str(source_uuid),
},
)
if predetermined_task_id is not None:
@@ -296,15 +273,7 @@ class UploadFile(Resource):
return make_response(jsonify({"success": False}), 400)
# Predetermined id matches the dedup-claim row; loser GET sees same.
response_task_id = predetermined_task_id or task.id
# ``source_uuid`` was minted above and passed to the worker as
# ``source_id``; the worker uses it verbatim for every SSE event,
# so the frontend can correlate inbound ``source.ingest.*`` to
# this upload regardless of whether an idempotency key was set.
response_payload: dict = {
"success": True,
"task_id": response_task_id,
"source_id": str(source_uuid),
}
response_payload = {"success": True, "task_id": response_task_id}
return make_response(jsonify(response_payload), 200)
@@ -357,18 +326,6 @@ class UploadRemote(Resource):
)
if cached is not None:
return make_response(jsonify(cached), 200)
# Mint the source UUID up here so the HTTP response and the
# worker's SSE envelopes share one id. Same pattern as
# ``UploadFile.post``: with an idempotency key we reuse the
# deterministic uuid5 (retried task lands on the same source
# row); without a key we fall back to uuid4. The worker is told
# to use this id verbatim — see ``remote_worker`` and
# ``ingest_connector``. Without this the no-key path would mint
# a random uuid4 inside the worker that the frontend has no way
# to correlate SSE events to.
source_uuid = (
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
)
try:
config = json.loads(data["data"])
source_data = None
@@ -425,23 +382,13 @@ class UploadRemote(Resource):
"recursive": config.get("recursive", False),
"retriever": config.get("retriever", "classic"),
"idempotency_key": scoped_key or idempotency_key,
"source_id": str(source_uuid),
},
}
if predetermined_task_id is not None:
connector_kwargs["task_id"] = predetermined_task_id
task = ingest_connector_task.apply_async(**connector_kwargs)
response_task_id = predetermined_task_id or task.id
# ``source_uuid`` was minted above and passed to the
# worker as ``source_id``; the worker uses it verbatim
# for every SSE event, so the frontend can correlate
# inbound ``source.ingest.*`` regardless of whether an
# idempotency key was set.
response_payload = {
"success": True,
"task_id": response_task_id,
"source_id": str(source_uuid),
}
response_payload = {"success": True, "task_id": response_task_id}
return make_response(jsonify(response_payload), 200)
remote_kwargs = {
"kwargs": {
@@ -450,7 +397,6 @@ class UploadRemote(Resource):
"user": user,
"loader": data["source"],
"idempotency_key": scoped_key or idempotency_key,
"source_id": str(source_uuid),
},
}
if predetermined_task_id is not None:
@@ -464,11 +410,7 @@ class UploadRemote(Resource):
_release_claim(scoped_key)
return make_response(jsonify({"success": False}), 400)
response_task_id = predetermined_task_id or task.id
response_payload = {
"success": True,
"task_id": response_task_id,
"source_id": str(source_uuid),
}
response_payload = {"success": True, "task_id": response_task_id}
return make_response(jsonify(response_payload), 200)
@@ -611,19 +553,6 @@ class ManageSourceFiles(Resource):
scoped_key, "reingest_source_task",
)
if cached is not None:
# Frontend keys reingest polling on
# ``reingest_task_id``; the shared cache helper
# writes ``task_id``. Alias here so a dedup
# response doesn't silently break FileTree's
# poller. Override ``source_id`` too — the
# helper derives it from the scoped key, which
# is correct for upload but wrong for reingest
# (the worker publishes events scoped to the
# actual source row id).
cached_task_id = cached.pop("task_id", None)
if cached_task_id is not None:
cached["reingest_task_id"] = cached_task_id
cached["source_id"] = resolved_source_id
return make_response(jsonify(cached), 200)
added_files = []
@@ -679,12 +608,6 @@ class ManageSourceFiles(Resource):
"added_files": added_files,
"parent_dir": parent_dir,
"reingest_task_id": task.id,
# ``source_id`` lets the frontend correlate
# inbound ``source.ingest.*`` SSE events
# (emitted by ``reingest_source_worker``)
# back to the reingest task — matches the
# upload route's source-id contract.
"source_id": resolved_source_id,
}
),
200,
@@ -736,15 +659,6 @@ class ManageSourceFiles(Resource):
scoped_key, "reingest_source_task",
)
if cached is not None:
cached_task_id = cached.pop("task_id", None)
if cached_task_id is not None:
cached["reingest_task_id"] = cached_task_id
# Override the helper's synthetic source_id (uuid5
# of the scoped key) with the real source row id
# — the reingest worker publishes SSE events
# scoped to ``resolved_source_id`` and FileTree
# correlates on it.
cached["source_id"] = resolved_source_id
return make_response(jsonify(cached), 200)
# Remove files from storage and directory structure
@@ -790,7 +704,6 @@ class ManageSourceFiles(Resource):
"message": f"Removed {len(removed_files)} files",
"removed_files": removed_files,
"reingest_task_id": task.id,
"source_id": resolved_source_id,
}
),
200,
@@ -849,14 +762,6 @@ class ManageSourceFiles(Resource):
scoped_key, "reingest_source_task",
)
if cached is not None:
cached_task_id = cached.pop("task_id", None)
if cached_task_id is not None:
cached["reingest_task_id"] = cached_task_id
# Same source_id override as the ``remove`` /
# ``add`` cached branches — the helper's synthetic
# id doesn't match what reingest_source_worker
# tags its SSE events with.
cached["source_id"] = resolved_source_id
return make_response(jsonify(cached), 200)
success = storage.remove_directory(full_directory_path)
@@ -920,7 +825,6 @@ class ManageSourceFiles(Resource):
"message": f"Successfully removed directory: {directory_path}",
"removed_directory": directory_path,
"reingest_task_id": task.id,
"source_id": resolved_source_id,
}
),
200,

View File

@@ -7,6 +7,7 @@ from application.worker import (
attachment_worker,
ingest_worker,
mcp_oauth,
mcp_oauth_status,
remote_worker,
sync,
sync_worker,
@@ -39,7 +40,6 @@ def ingest(
filename,
file_name_map=None,
idempotency_key=None,
source_id=None,
):
resp = ingest_worker(
self,
@@ -51,21 +51,16 @@ def ingest(
user,
file_name_map=file_name_map,
idempotency_key=idempotency_key,
source_id=source_id,
)
return resp
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="ingest_remote")
def ingest_remote(
self, source_data, job_name, user, loader,
idempotency_key=None, source_id=None,
):
def ingest_remote(self, source_data, job_name, user, loader, idempotency_key=None):
resp = remote_worker(
self, source_data, job_name, user, loader,
idempotency_key=idempotency_key,
source_id=source_id,
)
return resp
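``@with_idempotency`` wraps each durable task but its body is not in this diff. Given the claim/release pattern in the upload routes, a plausible sketch (``_seen_before`` is a hypothetical stand-in for the repository lookup):

import functools

def with_idempotency(task_name: str):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(self, *args, idempotency_key=None, **kwargs):
            # Re-delivery of an already-completed key becomes a no-op
            # instead of a second run of the worker.
            if idempotency_key and _seen_before(task_name, idempotency_key):
                return {"skipped": "duplicate", "task_name": task_name}
            return fn(self, *args, idempotency_key=idempotency_key, **kwargs)
        return wrapper
    return decorator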
@@ -143,7 +138,6 @@ def ingest_connector_task(
doc_id=None,
sync_frequency="never",
idempotency_key=None,
source_id=None,
):
from application.worker import ingest_connector
@@ -161,7 +155,6 @@ def ingest_connector_task(
doc_id=doc_id,
sync_frequency=sync_frequency,
idempotency_key=idempotency_key,
source_id=source_id,
)
return resp
@@ -204,15 +197,6 @@ def setup_periodic_tasks(sender, **kwargs):
version_check_task.s(),
name="version-check",
)
# Bound ``message_events`` growth — every streamed SSE chunk writes
# one row, so retained chats accumulate hundreds of rows per
# message. Reconnect-replay is only meaningful for streams the user
# could plausibly still be waiting on, so 14 days is generous.
sender.add_periodic_task(
timedelta(hours=24),
cleanup_message_events.s(),
name="cleanup-message-events",
)
@celery.task(bind=True)
@@ -221,6 +205,12 @@ def mcp_oauth_task(self, config, user):
return resp
@celery.task(bind=True)
def mcp_oauth_status_task(self, task_id):
resp = mcp_oauth_status(self, task_id)
return resp
@celery.task(bind=True, acks_late=False)
def cleanup_pending_tool_state(self):
"""Revert stale ``resuming`` rows, then delete TTL-expired rows."""
@@ -275,32 +265,6 @@ def reconciliation_task(self):
return run_reconciliation()
@celery.task(bind=True, acks_late=False)
def cleanup_message_events(self):
"""Delete ``message_events`` rows older than the retention window.
Streamed answer responses write one journal row per SSE yield,
so unbounded growth would dominate Postgres for any retained-
conversations deployment. The reconnect-replay path only needs
rows for in-flight streams; 14 days covers paused/tool-action
flows comfortably.
"""
from application.core.settings import settings
if not settings.POSTGRES_URI:
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
from application.storage.db.engine import get_engine
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
ttl_days = settings.MESSAGE_EVENTS_RETENTION_DAYS
engine = get_engine()
with engine.begin() as conn:
deleted = MessageEventsRepository(conn).cleanup_older_than(ttl_days)
return {"deleted": deleted, "ttl_days": ttl_days}
@celery.task(bind=True, acks_late=False)
def version_check_task(self):
"""Periodic anonymous version check.

View File

@@ -1,5 +1,6 @@
"""Tool management MCP server integration."""
import json
from urllib.parse import urlencode, urlparse
from flask import current_app, jsonify, make_response, redirect, request
@@ -225,9 +226,7 @@ class MCPServerSave(Resource):
)
redis_client = get_redis_instance()
manager = MCPOAuthManager(redis_client)
result = manager.get_oauth_status(
config["oauth_task_id"], user
)
result = manager.get_oauth_status(config["oauth_task_id"])
if result.get("status") != "completed":
return make_response(
jsonify(
@@ -439,6 +438,56 @@ class MCPOAuthCallback(Resource):
)
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
def get(self, task_id):
try:
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
status_data = redis_client.get(status_key)
if status_data:
status = json.loads(status_data)
if "tools" in status and isinstance(status["tools"], list):
status["tools"] = [
{
"name": t.get("name", "unknown"),
"description": t.get("description", ""),
}
for t in status["tools"]
]
return make_response(
jsonify({"success": True, "task_id": task_id, **status})
)
else:
return make_response(
jsonify(
{
"success": True,
"task_id": task_id,
"status": "pending",
"message": "Waiting for OAuth to start...",
}
),
200,
)
except Exception as e:
current_app.logger.error(
f"Error getting OAuth status for task {task_id}: {str(e)}",
exc_info=True,
)
return make_response(
jsonify(
{
"success": False,
"error": "Failed to get OAuth status",
"task_id": task_id,
}
),
500,
)
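Client-side, the new endpoint is a plain polling target. A minimal sketch of a poller, assuming a reachable API base URL (the helper name, base URL, and timings here are illustrative, not part of the changeset):

import time
import requests

def wait_for_oauth(base_url: str, task_id: str, timeout_s: float = 120.0) -> dict:
    # Poll /mcp_server/oauth_status/<task_id> until it leaves "pending".
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        status = requests.get(
            f"{base_url}/mcp_server/oauth_status/{task_id}", timeout=5
        ).json()
        if status.get("status") != "pending":
            return status  # e.g. "requires_redirect" with authorization_url
        time.sleep(2)
    raise TimeoutError(f"OAuth task {task_id} still pending after {timeout_s}s")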
@tools_mcp_ns.route("/mcp_server/auth_status")
class MCPAuthStatus(Resource):
@api.doc(

View File

@@ -222,26 +222,13 @@ def _stream_response(
for line in internal_stream:
if not line.strip():
continue
# ``complete_stream`` prefixes each frame with ``id: <seq>\n``
# before the ``data:`` line. Extract just the data line so JSON
# decode doesn't choke on the SSE framing.
event_str = ""
for raw in line.split("\n"):
if raw.startswith("data:"):
event_str = raw[len("data:") :].lstrip()
break
if not event_str:
continue
# Parse the internal SSE event
event_str = line.replace("data: ", "").strip()
try:
event_data = json.loads(event_str)
except (json.JSONDecodeError, TypeError):
continue
# Skip the informational ``message_id`` event — it has no v1 /
# OpenAI-compatible analog.
if event_data.get("type") == "message_id":
continue
# Update completion_id when we get the conversation id
if event_data.get("type") == "id":
conv_id = event_data.get("id", "")

View File

@@ -16,8 +16,6 @@ setup_logging()
from application.api import api # noqa: E402
from application.api.answer import answer # noqa: E402
from application.api.answer.routes.messages import messages_bp # noqa: E402
from application.api.events.routes import events # noqa: E402
from application.api.internal.routes import internal # noqa: E402
from application.api.user.routes import user # noqa: E402
from application.api.connector.routes import connector # noqa: E402
@@ -51,8 +49,6 @@ ensure_database_ready(
app = Flask(__name__)
app.register_blueprint(user)
app.register_blueprint(answer)
app.register_blueprint(events)
app.register_blueprint(messages_bp)
app.register_blueprint(internal)
app.register_blueprint(connector)
app.register_blueprint(v1_bp)

View File

@@ -29,17 +29,8 @@ def get_redis_instance():
with _instance_lock:
if _redis_instance is None and not _redis_creation_failed:
try:
# ``health_check_interval`` makes redis-py ping the
# connection every N seconds when otherwise idle.
# Without it, a half-open TCP connection (NAT state
# silently dropped, ELB idle close) can hang the SSE generator
# in ``pubsub.get_message`` past its keepalive
# cadence — the kernel never surfaces the dead
# socket because no payload is in flight.
_redis_instance = redis.Redis.from_url(
settings.CACHE_REDIS_URL,
socket_connect_timeout=2,
health_check_interval=10,
settings.CACHE_REDIS_URL, socket_connect_timeout=2
)
except ValueError as e:
logger.error(f"Invalid Redis URL: {e}")

View File

@@ -188,42 +188,6 @@ class Settings(BaseSettings):
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
# Internal SSE push channel (notifications + durable replay journal)
# Master switch — when False, /api/events emits a "push_disabled" comment
# and returns; clients fall back to polling. Publisher becomes a no-op.
ENABLE_SSE_PUSH: bool = True
# Per-user durable backlog cap (~entries). At typical event rates this
# gives ~24h of replay; tune up for verbose feeds, down for memory.
EVENTS_STREAM_MAXLEN: int = 1000
# SSE keepalive comment cadence. Must sit under Cloudflare's 100s idle
# close and iOS Safari's ~60s — 15s gives generous headroom.
SSE_KEEPALIVE_SECONDS: int = 15
# Cap on simultaneous SSE connections per user. Each connection holds
# one WSGI thread (32 per gunicorn worker) and one Redis pub/sub
# connection. 8 covers normal multi-tab use without letting one user
# starve the pool. Set to 0 to disable the cap.
SSE_MAX_CONCURRENT_PER_USER: int = 8
# Per-request cap on the number of backlog entries XRANGE returns
# for ``/api/events`` snapshots. Bounds the bytes a single replay
# can move from Redis to the wire — a malicious client looping
# ``Last-Event-ID=<oldest>`` reconnects can only enumerate this
# many entries per round-trip. Combined with the per-user
# connection cap above and the windowed budget below, total
# enumeration throughput is bounded.
EVENTS_REPLAY_MAX_PER_REQUEST: int = 200
# Sliding-window cap on snapshot replays per user. Once the budget
# is exhausted the route returns HTTP 429 with the cursor pinned;
# the client backs off and retries after the window rolls over.
EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW: int = 30
EVENTS_REPLAY_BUDGET_WINDOW_SECONDS: int = 60
# Retention for the ``message_events`` journal. The ``cleanup_message_events``
# beat task deletes rows older than this. Reconnect-replay only
# needs the journal for streams a client could still be tailing,
# so 14 days is a generous default that covers paused/tool-action
# flows without unbounded table growth.
MESSAGE_EVENTS_RETENTION_DAYS: int = 14
@field_validator("POSTGRES_URI", mode="before")
@classmethod
def _normalize_postgres_uri_validator(cls, v):

View File

@@ -1,52 +0,0 @@
"""Stream/topic key derivations shared by publisher and SSE consumer.
Single source of truth for the per-user Redis Streams key and pub/sub
topic name. Both must agree exactly — a typo here splits the
publisher's writes from the consumer's reads.
"""
from __future__ import annotations
def stream_key(user_id: str) -> str:
"""Redis Streams key holding the durable backlog for ``user_id``."""
return f"user:{user_id}:stream"
def topic_name(user_id: str) -> str:
"""Redis pub/sub channel used for live fan-out to ``user_id``."""
return f"user:{user_id}"
def connection_counter_key(user_id: str) -> str:
"""Redis counter tracking active SSE connections for ``user_id``."""
return f"user:{user_id}:sse_count"
def replay_budget_key(user_id: str) -> str:
"""Redis counter tracking snapshot replays for ``user_id`` in the
rolling rate-limit window."""
return f"user:{user_id}:replay_count"
def stream_id_compare(a: str, b: str) -> int:
"""Compare two Redis Streams ids. Returns -1, 0, 1 like ``cmp``.
Stream ids are ``ms-seq`` strings; comparing as strings would be wrong
once ``ms`` straddles digit-count boundaries. We parse and compare
as ``(int, int)`` tuples.
Raises ``ValueError`` on malformed input. Callers must pre-validate
against ``_STREAM_ID_RE`` (or equivalent) — an earlier lexicographic
fallback let a malformed id compare greater than a real one and
silently pin dedup forever.
"""
a_ms, _, a_seq = a.partition("-")
b_ms, _, b_seq = b.partition("-")
a_tuple = (int(a_ms), int(a_seq) if a_seq else 0)
b_tuple = (int(b_ms), int(b_seq) if b_seq else 0)
if a_tuple < b_tuple:
return -1
if a_tuple > b_tuple:
return 1
return 0
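A quick illustration of why the tuple parse matters (hypothetical values, not part of the module):

# Lexicographically "999-0" sorts after "1000-0", which would pin dedup;
# the numeric tuple compare gets the order right.
assert "999-0" > "1000-0"                           # string compare: wrong
assert stream_id_compare("999-0", "1000-0") == -1   # tuple compare: correct
assert stream_id_compare("1000-0", "1000-1") == -1  # seq breaks ms ties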

View File

@@ -1,144 +0,0 @@
"""User-scoped event publisher: durable backlog + live fan-out.
Each ``publish_user_event`` call writes twice:
1. ``XADD user:{user_id}:stream MAXLEN ~ <cap> * event <json>`` — the
durable backlog used by SSE reconnect (``Last-Event-ID``) and stream
replay. Bounded by ``EVENTS_STREAM_MAXLEN`` (~24h at typical event
rates) so the per-user footprint stays predictable.
2. ``PUBLISH user:{user_id} <json-with-id>`` — live fan-out to every
currently connected SSE generator for the user, across instances.
Together they give a snapshot-plus-tail story: a reconnecting client
reads ``XRANGE`` from its last seen id and then transitions onto the
live pub/sub. The Redis Streams entry id (e.g. ``1735682400000-0``) is
the canonical, monotonically increasing event id and is what
``Last-Event-ID`` carries.
Failures are logged and swallowed: the caller is typically a Celery
task whose primary work has already succeeded, and a notification
delivery miss should not surface as a task failure.
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from typing import Any, Optional
from application.cache import get_redis_instance
from application.core.settings import settings
from application.events.keys import stream_key, topic_name
from application.streaming.broadcast_channel import Topic
logger = logging.getLogger(__name__)
def _iso_now() -> str:
"""ISO 8601 UTC with millisecond precision and Z suffix."""
return (
datetime.now(timezone.utc)
.isoformat(timespec="milliseconds")
.replace("+00:00", "Z")
)
def publish_user_event(
user_id: str,
event_type: str,
payload: dict[str, Any],
*,
scope: Optional[dict[str, Any]] = None,
) -> Optional[str]:
"""Publish a user-scoped event; return the Redis Streams id or ``None``.
Fire-and-forget: never raises. ``None`` means the event reached
neither the journal nor live subscribers (see runbook for causes).
"""
if not user_id or not event_type:
logger.warning(
"publish_user_event called without user_id or event_type "
"(user_id=%r, event_type=%r)",
user_id,
event_type,
)
return None
if not settings.ENABLE_SSE_PUSH:
return None
envelope_partial: dict[str, Any] = {
"type": event_type,
"ts": _iso_now(),
"user_id": user_id,
"topic": topic_name(user_id),
"scope": scope or {},
"payload": payload,
}
try:
envelope_partial_json = json.dumps(envelope_partial)
except (TypeError, ValueError) as exc:
logger.warning(
"publish_user_event payload not JSON-serializable: "
"user=%s type=%s err=%s",
user_id,
event_type,
exc,
)
return None
redis = get_redis_instance()
if redis is None:
logger.debug("Redis unavailable; skipping publish_user_event")
return None
maxlen = settings.EVENTS_STREAM_MAXLEN
stream_id: Optional[str] = None
try:
# Auto-id ('*') gives a monotonic ms-seq id that doubles as the
# SSE event id. ``approximate=True`` lets Redis trim in chunks
# for performance: the stream may briefly exceed MAXLEN but is
# never trimmed below it.
result = redis.xadd(
stream_key(user_id),
{"event": envelope_partial_json},
maxlen=maxlen,
approximate=True,
)
stream_id = (
result.decode("utf-8")
if isinstance(result, (bytes, bytearray))
else str(result)
)
except Exception:
logger.exception(
"xadd failed for user=%s event_type=%s", user_id, event_type
)
# If the durable journal write failed there is no canonical id to
# ship — publishing the envelope live would put an id-less record
# on the wire that bypasses the SSE route's dedup floor and breaks
# ``Last-Event-ID`` semantics for any reconnect. Best-effort
# delivery means dropping consistently, not delivering inconsistent
# state.
if stream_id is None:
return None
envelope = dict(envelope_partial)
envelope["id"] = stream_id
try:
Topic(topic_name(user_id)).publish(json.dumps(envelope))
except Exception:
logger.exception(
"publish failed for user=%s event_type=%s", user_id, event_type
)
logger.debug(
"event.published topic=%s type=%s id=%s",
topic_name(user_id),
event_type,
stream_id,
)
return stream_id
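For orientation, the consumer half of the snapshot-plus-tail contract reads roughly like this (a sketch, not the actual route; it assumes a redis-py client and uses the standard exclusive-start ``(`` prefix for XRANGE):

from application.cache import get_redis_instance
from application.events.keys import stream_key

def replay_backlog(user_id: str, last_event_id=None, cap: int = 200):
    """Yield (stream_id, event_json) pairs past the client's cursor."""
    redis = get_redis_instance()
    if redis is None:
        return
    start = f"({last_event_id}" if last_event_id else "-"
    for entry_id, fields in redis.xrange(
        stream_key(user_id), min=start, max="+", count=cap
    ):
        yield entry_id, fields[b"event"]
    # From here the route transitions onto the live pub/sub tail.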

View File

@@ -4,7 +4,6 @@ from typing import Any, List, Optional
from retry import retry
from tqdm import tqdm
from application.core.settings import settings
from application.events.publisher import publish_user_event
from application.storage.db.repositories.ingest_chunk_progress import (
IngestChunkProgressRepository,
)
@@ -153,7 +152,6 @@ def embed_and_store_documents(
task_status: Any,
*,
attempt_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> None:
"""Embeds documents and stores them in a vector store.
@@ -172,10 +170,6 @@ def embed_and_store_documents(
attempt_id: Stable id of the current task invocation,
typically ``self.request.id`` from the Celery task body.
``None`` is treated as a fresh attempt every time.
user_id: When provided, per-percent SSE progress events are
published to ``user:{user_id}`` for the in-app upload toast.
``None`` is the safe default — workers without a user
context (e.g. background syncs) skip the publish.
Returns:
None
@@ -255,8 +249,6 @@ def embed_and_store_documents(
# Process and embed documents
chunk_error: Exception | None = None
failed_idx: int | None = None
last_published_pct = -1
source_id_str = str(source_id)
for idx in tqdm(
range(loop_start, total_docs),
desc="Embedding 🦖",
@@ -270,24 +262,6 @@ def embed_and_store_documents(
progress = int(((idx + 1) / total_docs) * 100)
task_status.update_state(state="PROGRESS", meta={"current": progress})
# SSE push for sub-second upload-toast updates. Throttled to one
# event per percent so a 10k-chunk ingest emits ~100 events,
# not 10k. The Celery update_state above stays the source of
# truth for the polling-fallback path.
if user_id and progress > last_published_pct:
publish_user_event(
user_id,
"source.ingest.progress",
{
"current": progress,
"total": total_docs,
"embedded_chunks": idx + 1,
"stage": "embedding",
},
scope={"kind": "source", "id": source_id_str},
)
last_published_pct = progress
# Add document to vector store
add_text_to_store_with_retry(store, doc, source_id)
_record_progress(source_id, last_index=idx, embedded_chunks=idx + 1)

View File

@@ -34,7 +34,7 @@ from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, JSONB, UUID
metadata = MetaData()
# --- Users, prompts, tools, logs --------------------------------------------
# --- Phase 1, Tier 1 --------------------------------------------------------
users_table = Table(
"users",
@@ -138,7 +138,7 @@ app_metadata_table = Table(
)
# --- Agents, sources, attachments, artifacts --------------------------------
# --- Phase 2, Tier 2 --------------------------------------------------------
agent_folders_table = Table(
"agent_folders",
@@ -307,7 +307,7 @@ connector_sessions_table = Table(
)
# --- Conversations, messages, workflows -------------------------------------
# --- Phase 3, Tier 3 --------------------------------------------------------
conversations_table = Table(
"conversations",
@@ -363,36 +363,6 @@ conversation_messages_table = Table(
UniqueConstraint("conversation_id", "position", name="conversation_messages_conv_pos_uidx"),
)
# Per-yield journal of chat-stream events, used by the snapshot+tail
# reconnect: the route's GET reconnect endpoint reads
# ``WHERE message_id = ? AND sequence_no > ?`` from this table before
# tailing the live ``channel:{message_id}`` pub/sub. See
# ``application/streaming/event_replay.py`` and migration 0007.
message_events_table = Table(
"message_events",
metadata,
# PK is the composite ``(message_id, sequence_no)`` — it doubles as
# the snapshot read index (covering range scan on
# ``WHERE message_id = ? AND sequence_no > ?``).
Column(
"message_id",
UUID(as_uuid=True),
ForeignKey("conversation_messages.id", ondelete="CASCADE"),
primary_key=True,
nullable=False,
),
# Strictly monotonic per ``message_id``. Allocated by the route as it
# yields, so the writer is single-threaded for the lifetime of one
# stream — no contention, no SERIAL needed.
Column("sequence_no", Integer, primary_key=True, nullable=False),
Column("event_type", Text, nullable=False),
Column("payload", JSONB, nullable=False, server_default="{}"),
Column(
"created_at", DateTime(timezone=True), nullable=False, server_default=func.now()
),
)
shared_conversations_table = Table(
"shared_conversations",
metadata,
@@ -433,7 +403,7 @@ pending_tool_state_table = Table(
)
# --- Durability foundation (idempotency / journals, migration 0004) ---------
# --- Tier 1 durability foundation (migration 0004) --------------------------
# CHECK constraints (status enums) and partial indexes are intentionally
# omitted from these declarations — the DB is the authority. Repositories
# use raw ``text(...)`` SQL against these tables, not the Core objects.

View File

@@ -15,7 +15,6 @@ Covers every operation the legacy Mongo code performs on
from __future__ import annotations
import json
from enum import Enum
from typing import Optional
from sqlalchemy import Connection, text
@@ -26,22 +25,6 @@ from application.storage.db.models import conversations_table, conversation_mess
from application.storage.db.serialization import PGNativeJSONEncoder
class MessageUpdateOutcome(str, Enum):
"""Discriminated result of ``update_message_by_id``.
Distinguishes the row-actually-updated case from the row-already-at-
the-requested-terminal-state case so an abort handler can journal
``end`` instead of ``error`` when the normal-path finalize already
flipped the row to ``complete``.
"""
UPDATED = "updated"
ALREADY_COMPLETE = "already_complete"
ALREADY_FAILED = "already_failed"
NOT_FOUND = "not_found"
INVALID = "invalid"
def _message_row_to_dict(row) -> dict:
"""Like ``row_to_dict`` but renames the DB column ``message_metadata``
back to the public API key ``metadata`` so callers keep the Mongo-era
@@ -75,8 +58,8 @@ class ConversationsRepository:
- Already-UUID-shaped → returned as-is.
- Otherwise treated as a Mongo ObjectId and looked up via
``agents.legacy_mongo_id``. Returns ``None`` if no PG row
exists yet (e.g. the agent was created before the backfill
ran).
exists yet (e.g. the agent was created before Phase 1
backfill).
"""
if not agent_id_raw:
return None
@@ -714,7 +697,7 @@ class ConversationsRepository:
def update_message_by_id(
self, message_id: str, fields: dict,
*, only_if_non_terminal: bool = False,
) -> MessageUpdateOutcome:
) -> bool:
"""Update specific fields on a message identified by its UUID.
``metadata`` is merged into the existing JSONB rather than
@@ -722,13 +705,9 @@ class ConversationsRepository:
a successful late finalize. When ``only_if_non_terminal`` is
True, the update is gated so a late finalize cannot retract a
reconciler-set ``failed`` (or a prior ``complete``).
The return value discriminates "I updated the row" from "the
row was already at a terminal state" so the abort handler can
journal ``end`` when the normal-path finalize already ran.
"""
if not looks_like_uuid(message_id):
return MessageUpdateOutcome.INVALID
return False
allowed = {
"prompt", "response", "thought", "sources", "tool_calls",
"attachments", "model_id", "metadata", "timestamp", "status",
@@ -736,7 +715,7 @@ class ConversationsRepository:
}
filtered = {k: v for k, v in fields.items() if k in allowed}
if not filtered:
return MessageUpdateOutcome.INVALID
return False
api_to_col = {"metadata": "message_metadata"}
@@ -773,44 +752,15 @@ class ConversationsRepository:
params[col] = val
set_parts.append("updated_at = now()")
update_where = ["id = CAST(:id AS uuid)"]
where_clauses = ["id = CAST(:id AS uuid)"]
if only_if_non_terminal:
update_where.append("status NOT IN ('complete', 'failed')")
# Single-statement attempt + prior-status probe. Both CTEs see
# the same MVCC snapshot, so ``prior.status`` reflects the row
# state before the UPDATE — exactly what we need to tell
# ``ALREADY_COMPLETE``, ``ALREADY_FAILED``, and ``NOT_FOUND``
# apart without a follow-up SELECT.
where_clauses.append("status NOT IN ('complete', 'failed')")
sql = (
"WITH attempted AS ("
f" UPDATE conversation_messages SET {', '.join(set_parts)} "
f" WHERE {' AND '.join(update_where)} "
" RETURNING 1 AS updated"
"), "
"prior AS ("
" SELECT status FROM conversation_messages "
" WHERE id = CAST(:id AS uuid)"
") "
"SELECT (SELECT updated FROM attempted) AS updated, "
" (SELECT status FROM prior) AS prior_status"
f"UPDATE conversation_messages SET {', '.join(set_parts)} "
f"WHERE {' AND '.join(where_clauses)}"
)
row = self._conn.execute(text(sql), params).fetchone()
if row is None:
return MessageUpdateOutcome.NOT_FOUND
updated, prior_status = row[0], row[1]
if updated:
return MessageUpdateOutcome.UPDATED
if prior_status is None:
return MessageUpdateOutcome.NOT_FOUND
if prior_status == "complete":
return MessageUpdateOutcome.ALREADY_COMPLETE
if prior_status == "failed":
return MessageUpdateOutcome.ALREADY_FAILED
# ``only_if_non_terminal=False`` always updates an existing row,
# so reaching here means the gate excluded it for some status
# the terminal set doesn't cover — treat as "not found" rather
# than inventing a new variant.
return MessageUpdateOutcome.NOT_FOUND
result = self._conn.execute(text(sql), params)
return result.rowcount > 0
def update_message_status(
self, message_id: str, status: str,

View File

@@ -1,248 +0,0 @@
"""Repository for ``message_events`` — the chat-stream snapshot journal.
``record`` / ``bulk_record`` write per-yield events; ``read_after``
replays rows past a cursor for reconnect snapshots. Composite PK
``(message_id, sequence_no)`` raises ``IntegrityError`` on duplicates.
Callers must use short-lived per-call transactions — long-lived
transactions hide writes from reconnecting clients on a separate
connection and turn one bad row into ``InFailedSqlTransaction``.
"""
from __future__ import annotations
import json
import logging
from typing import Any, Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
logger = logging.getLogger(__name__)
class MessageEventsRepository:
"""Read/write helpers for ``message_events``."""
def __init__(self, conn: Connection) -> None:
self._conn = conn
def record(
self,
message_id: str,
sequence_no: int,
event_type: str,
payload: Optional[Any] = None,
) -> None:
"""Append a single event to the journal.
At this raw repo layer ``payload`` is preserved as-is when not
``None`` (lists, scalars, and dicts all round-trip via JSONB);
``None`` substitutes an empty object so the column's NOT NULL
invariant holds. The streaming-route wrapper
``application/streaming/message_journal.py::record_event``
tightens this contract to dicts only — the live and replay
paths reconstruct non-dict payloads differently, so the wrapper
rejects them at the gate. Direct callers of this repo method
(cleanup tasks, tests, future ad-hoc consumers) keep the wider
JSONB-compatible surface.
Raises ``sqlalchemy.exc.IntegrityError`` on duplicate
``(message_id, sequence_no)`` and ``DataError`` on a malformed
``message_id`` UUID. Both abort the surrounding transaction —
callers must run inside a short-lived per-event session
(see module docstring).
"""
if not event_type:
raise ValueError("event_type must be a non-empty string")
materialised_payload = payload if payload is not None else {}
self._conn.execute(
text(
"""
INSERT INTO message_events (
message_id, sequence_no, event_type, payload
) VALUES (
CAST(:message_id AS uuid), :sequence_no, :event_type,
CAST(:payload AS jsonb)
)
"""
),
{
"message_id": str(message_id),
"sequence_no": int(sequence_no),
"event_type": event_type,
"payload": json.dumps(materialised_payload),
},
)
def bulk_record(
self,
message_id: str,
events: list[tuple[int, str, dict]],
) -> None:
"""Append multiple events for ``message_id`` in one INSERT.
``events`` is a list of ``(sequence_no, event_type, payload)``
tuples. SQLAlchemy ``executemany`` issues one bulk INSERT;
Postgres treats the whole batch as one statement, so an
IntegrityError on any row aborts the entire batch.
Caller contract: on IntegrityError, do NOT retry this method
with the same batch — fall back to per-row ``record()`` calls
(each in its own short-lived session) so a single colliding
seq doesn't drop the rest of the batch. ``BatchedJournalWriter``
in ``application/streaming/message_journal.py`` is the canonical
consumer.
"""
if not events:
return
params = [
{
"message_id": str(message_id),
"sequence_no": int(seq),
"event_type": event_type,
"payload": json.dumps(payload if payload is not None else {}),
}
for seq, event_type, payload in events
]
self._conn.execute(
text(
"""
INSERT INTO message_events (
message_id, sequence_no, event_type, payload
) VALUES (
CAST(:message_id AS uuid), :sequence_no, :event_type,
CAST(:payload AS jsonb)
)
"""
),
params,
)
def read_after(
self,
message_id: str,
last_sequence_no: Optional[int] = None,
) -> list[dict]:
"""Return events with ``sequence_no > last_sequence_no``.
``last_sequence_no=None`` returns the full backlog. Rows are
returned in ascending ``sequence_no`` order. The composite PK
is the snapshot read index for this scan — Postgres typically
picks an in-order index range scan, though for highly mixed
data the planner may pick a bitmap+sort. Either way the result
is sorted on ``sequence_no``.
Returns a ``list`` (not a generator) so the underlying
``Result`` is fully drained before the caller can issue
another query on the same connection.
"""
cursor = -1 if last_sequence_no is None else int(last_sequence_no)
rows = self._conn.execute(
text(
"""
SELECT message_id, sequence_no, event_type, payload, created_at
FROM message_events
WHERE message_id = CAST(:message_id AS uuid)
AND sequence_no > :cursor
ORDER BY sequence_no ASC
"""
),
{"message_id": str(message_id), "cursor": cursor},
).fetchall()
return [row_to_dict(row) for row in rows]
def cleanup_older_than(self, ttl_days: int) -> int:
"""Delete journal rows older than ``ttl_days``. Returns row count.
Reconnect-replay is meaningful only for streams the client
could plausibly still be waiting on, so old rows are dead
weight. The ``message_events_created_at_idx`` btree makes the
range delete a cheap index scan even on large tables.
"""
if ttl_days <= 0:
raise ValueError("ttl_days must be positive")
result = self._conn.execute(
text(
"""
DELETE FROM message_events
WHERE created_at < now() - make_interval(days => :ttl_days)
"""
),
{"ttl_days": int(ttl_days)},
)
return int(result.rowcount or 0)
def reconstruct_partial(self, message_id: str) -> dict:
"""Rebuild partial response/thought/sources/tool_calls from journal events.
``answer``/``thought`` chunks concat in seq order; ``source``/
``tool_calls`` carry the full list at emit time (last-wins).
"""
rows = self._conn.execute(
text(
"""
SELECT sequence_no, event_type, payload
FROM message_events
WHERE message_id = CAST(:message_id AS uuid)
ORDER BY sequence_no ASC
"""
),
{"message_id": str(message_id)},
).fetchall()
response_parts: list[str] = []
thought_parts: list[str] = []
sources: list = []
tool_calls: list = []
for row in rows:
payload = row.payload
if not isinstance(payload, dict):
continue
etype = row.event_type
if etype == "answer":
chunk = payload.get("answer")
if isinstance(chunk, str):
response_parts.append(chunk)
elif etype == "thought":
chunk = payload.get("thought")
if isinstance(chunk, str):
thought_parts.append(chunk)
elif etype == "source":
src = payload.get("source")
if isinstance(src, list):
sources = src
elif etype == "tool_calls":
tcs = payload.get("tool_calls")
if isinstance(tcs, list):
tool_calls = tcs
return {
"response": "".join(response_parts),
"thought": "".join(thought_parts),
"sources": sources,
"tool_calls": tool_calls,
}
def latest_sequence_no(self, message_id: str) -> Optional[int]:
"""Largest ``sequence_no`` recorded for ``message_id``, or ``None``.
Used by the route to seed the per-stream allocator on retry /
process restart so a re-run continues numbering instead of
trampling earlier entries with duplicate sequence_no.
"""
# ``MAX`` always returns one row — NULL when the journal is
# empty — so we test the value, not the row presence.
row = self._conn.execute(
text(
"""
SELECT MAX(sequence_no) AS s
FROM message_events
WHERE message_id = CAST(:message_id AS uuid)
"""
),
{"message_id": str(message_id)},
).first()
value = row[0] if row is not None else None
return int(value) if value is not None else None
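Caller-shaped usage of the short-lived-transaction contract from the module docstring (a sketch; ``message_id`` below is a hypothetical placeholder UUID):

from application.storage.db.session import db_readonly, db_session

message_id = "00000000-0000-0000-0000-000000000000"  # hypothetical

# Writer: one commit per event so a reconnecting reader sees the row.
with db_session() as conn:
    MessageEventsRepository(conn).record(message_id, 0, "answer", {"answer": "Hi"})

# Reader: snapshot the full backlog (pass last_sequence_no to resume a cursor).
with db_readonly() as conn:
    backlog = MessageEventsRepository(conn).read_after(message_id)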

View File

@@ -1,23 +0,0 @@
"""Deterministic source-id derivation for idempotent ingest.
DO NOT CHANGE the pinned UUID namespace — it backs cross-deploy
idempotency keys.
"""
from __future__ import annotations
import uuid
# DO NOT CHANGE. See module docstring.
DOCSGPT_INGEST_NAMESPACE = uuid.UUID("fa25d5d1-398b-46df-ac89-8d1c360b9bea")
def derive_source_id(idempotency_key) -> uuid.UUID:
"""``uuid5(NS, key)`` when a key is supplied; ``uuid4()`` otherwise.
A non-string / empty key falls back to ``uuid4()`` so the caller
always gets a fresh id rather than a TypeError mid-route.
"""
if isinstance(idempotency_key, str) and idempotency_key:
return uuid.uuid5(DOCSGPT_INGEST_NAMESPACE, idempotency_key)
return uuid.uuid4()
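The property that matters: re-submitting the same key reproduces the same source id across processes and deploys (keys shown are illustrative):

a = derive_source_id("user-1:contract.pdf")
b = derive_source_id("user-1:contract.pdf")
assert a == b                      # uuid5 over the pinned namespace is pure
assert derive_source_id(None) != derive_source_id(None)  # no key: fresh uuid4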

View File

@@ -1,126 +0,0 @@
"""Redis pub/sub Topic abstraction for SSE fan-out.
A Topic is a named channel for one-shot live event delivery. Canonical uses:
- ``user:{user_id}`` for per-user notifications
- ``channel:{message_id}`` for per-chat-message streams
Subscription is race-free via ``on_subscribe``: the callback fires only
after Redis acknowledges ``SUBSCRIBE``, so a publisher dispatched inside
the callback cannot lose its first event to a not-yet-registered
subscriber.
The subscribe iterator yields ``None`` on poll timeout so the caller can
emit SSE keepalive comments without spawning a separate timer thread.
"""
from __future__ import annotations
import logging
from typing import Callable, Iterator, Optional
from application.cache import get_redis_instance
logger = logging.getLogger(__name__)
class Topic:
"""A pub/sub channel identified by a string name."""
def __init__(self, name: str) -> None:
self.name = name
def publish(self, payload: str | bytes) -> int:
"""Fan out a payload to currently subscribed clients.
Returns the receiver count Redis reports for the message (counting
only subscribers connected to *this* Redis instance), or 0 if Redis
is unavailable. Never raises.
"""
redis = get_redis_instance()
if redis is None:
logger.debug("Redis unavailable; dropping publish to %s", self.name)
return 0
try:
return int(redis.publish(self.name, payload))
except Exception:
logger.exception("Topic.publish failed for %s", self.name)
return 0
def subscribe(
self,
on_subscribe: Optional[Callable[[], None]] = None,
poll_timeout: float = 1.0,
) -> Iterator[Optional[bytes]]:
"""Subscribe to the topic; yield raw payloads or ``None`` on tick.
Yields ``None`` every ``poll_timeout`` seconds while idle so the
caller can emit keepalive frames or check cancellation. Yields
``bytes`` for each delivered message.
``on_subscribe`` runs synchronously after Redis acknowledges the
SUBSCRIBE — use it to seed any state (e.g. read backlog) that
must be ordered after the subscriber is live but before the
first pub/sub message is processed.
If Redis is unavailable, returns immediately without yielding.
Cleanly unsubscribes on ``GeneratorExit`` (client disconnect).
"""
redis = get_redis_instance()
if redis is None:
logger.debug("Redis unavailable; subscribe to %s yielded nothing", self.name)
return
pubsub = None
on_subscribe_fired = False
try:
pubsub = redis.pubsub()
try:
pubsub.subscribe(self.name)
except Exception:
# Subscribe failure (transient Redis hiccup, conn reset, etc.)
# is treated like "Redis unavailable": yield nothing, let the
# caller fall back to its own resilience strategy. The finally
# block will still tear down the pubsub object cleanly.
logger.exception("pubsub.subscribe failed for %s", self.name)
return
while True:
try:
msg = pubsub.get_message(timeout=poll_timeout)
except Exception:
logger.exception("pubsub.get_message failed for %s", self.name)
return
if msg is None:
yield None
continue
msg_type = msg.get("type")
if msg_type == "subscribe":
if not on_subscribe_fired and on_subscribe is not None:
try:
on_subscribe()
except Exception:
logger.exception(
"on_subscribe callback failed for %s", self.name
)
on_subscribe_fired = True
continue
if msg_type != "message":
continue
data = msg.get("data")
if data is None:
continue
yield data if isinstance(data, bytes) else str(data).encode("utf-8")
finally:
if pubsub is not None:
if on_subscribe_fired:
try:
pubsub.unsubscribe(self.name)
except Exception:
logger.debug(
"pubsub unsubscribe error for %s",
self.name,
exc_info=True,
)
try:
pubsub.close()
except Exception:
logger.debug("pubsub close error for %s", self.name, exc_info=True)

View File

@@ -1,434 +0,0 @@
"""Snapshot+tail iterator for chat-stream reconnect-after-disconnect.
Subscribe to ``channel:{message_id}``, snapshot ``message_events``
rows past ``last_event_id`` inside the SUBSCRIBE-ack callback, flush
snapshot, then tail live pub/sub (dedup'd by ``sequence_no``). See
``docs/runbooks/sse-notifications.md``.
"""
from __future__ import annotations
import json
import logging
import re
import time
from typing import Iterator, Optional
from sqlalchemy import text as sql_text
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
from application.storage.db.session import db_readonly
from application.streaming.broadcast_channel import Topic
from application.streaming.keys import message_topic_name
logger = logging.getLogger(__name__)
DEFAULT_KEEPALIVE_SECONDS = 15.0
DEFAULT_POLL_TIMEOUT_SECONDS = 1.0
# When the live tail has no events and no terminal in snapshot, fall
# back to checking ``conversation_messages`` directly. If the row has
# already gone terminal (worker journaled ``end``/``error`` to the DB
# but the matching pub/sub publish was lost, or the row was finalized
# without a journal write at all) we surface a terminal event so the
# client doesn't hang on keepalives. If the row is still non-terminal
# but the producer heartbeat is older than ``PRODUCER_IDLE_SECONDS``
# the producer is presumed dead (worker crash / recycle between chunks
# and finalize) and we emit a terminal ``error`` so the UI can recover.
DEFAULT_WATCHDOG_INTERVAL_SECONDS = 5.0
# 1.5× the route's 60s heartbeat interval — long enough that a normal
# heartbeat skew doesn't false-positive, short enough that a stuck
# stream surfaces before the 5-minute reconciler sweep escalates.
DEFAULT_PRODUCER_IDLE_SECONDS = 90.0
# WHATWG SSE accepts CRLF, CR, LF — split on any of them so a stray CR
# can't smuggle a record boundary into the wire format.
_SSE_LINE_SPLIT_PATTERN = re.compile(r"\r\n|\r|\n")
# Event types that mark the end of a chat answer. After delivering one
# we close the reconnect stream — keeping the connection open past a
# terminal event would leak both the client's reconnect promise and
# the server's WSGI thread waiting on keepalives that the user no
# longer cares about. The agent loop emits ``end`` for normal /
# tool-paused completion and ``error`` for the catch-all failure path
# (which doesn't get a trailing ``end``).
_TERMINAL_EVENT_TYPES = frozenset({"end", "error"})
def _payload_is_terminal(
payload: object, event_type: Optional[str] = None
) -> bool:
"""True if ``payload['type']`` or ``event_type`` is a terminal sentinel."""
if isinstance(payload, dict) and payload.get("type") in _TERMINAL_EVENT_TYPES:
return True
return event_type in _TERMINAL_EVENT_TYPES
def format_sse_event(payload: dict, sequence_no: int) -> str:
"""Encode a journal event as one ``id:``/``data:`` SSE record.
The body is the payload's JSON serialisation. ``complete_stream``
payloads are flat JSON dicts with no embedded newlines, so a
single ``data:`` line is sufficient — but we still split on any
line terminator in case a future caller passes a multi-line string.
"""
body = json.dumps(payload)
lines = [f"id: {sequence_no}"]
for line in _SSE_LINE_SPLIT_PATTERN.split(body):
lines.append(f"data: {line}")
return "\n".join(lines) + "\n\n"
def _check_producer_liveness(
message_id: str, idle_seconds: float
) -> Optional[dict]:
"""Inspect ``conversation_messages`` and return a terminal SSE
payload when the producer is no longer alive, else ``None``.
Three terminal cases collapse into a single DB round-trip:
- ``status='complete'`` — the live finalize ran but its journal
terminal write didn't reach us (or never happened). Synthesise
``end`` so the client closes cleanly on the row's user-visible
state.
- ``status='failed'`` — same, but for the failure path. Carry the
stashed ``error`` from ``message_metadata`` so the UI shows the
real reason.
- non-terminal status and ``last_heartbeat_at`` (or ``timestamp``)
older than ``idle_seconds`` — the producing worker is gone.
Synthesise ``error`` so the client doesn't hang on keepalives
until the proxy idle-timeout kicks in.
"""
try:
with db_readonly() as conn:
row = conn.execute(
sql_text(
"""
SELECT
status,
message_metadata->>'error' AS err,
GREATEST(
timestamp,
COALESCE(
(message_metadata->>'last_heartbeat_at')
::timestamptz,
timestamp
)
) < now() - make_interval(secs => :idle_secs)
AS is_stale
FROM conversation_messages
WHERE id = CAST(:id AS uuid)
"""
),
{"id": message_id, "idle_secs": float(idle_seconds)},
).first()
except Exception:
logger.exception(
"Watchdog liveness check failed for message_id=%s", message_id
)
return None
if row is None:
# Row deleted out from under us — treat as terminal so the
# client doesn't keep tailing a message that no longer exists.
return {
"type": "error",
"error": "Message no longer exists; please refresh.",
"code": "message_missing",
"message_id": message_id,
}
status, err, is_stale = row[0], row[1], bool(row[2])
if status == "complete":
return {"type": "end"}
if status == "failed":
return {
"type": "error",
"error": err or "Stream failed; please try again.",
"code": "producer_failed",
"message_id": message_id,
}
if is_stale:
return {
"type": "error",
"error": (
"Stream producer is no longer responding; please try again."
),
"code": "producer_stale",
"message_id": message_id,
}
return None
def build_message_event_stream(
message_id: str,
last_event_id: Optional[int] = None,
*,
keepalive_seconds: float = DEFAULT_KEEPALIVE_SECONDS,
poll_timeout_seconds: float = DEFAULT_POLL_TIMEOUT_SECONDS,
watchdog_interval_seconds: float = DEFAULT_WATCHDOG_INTERVAL_SECONDS,
producer_idle_seconds: float = DEFAULT_PRODUCER_IDLE_SECONDS,
) -> Iterator[str]:
"""Yield SSE-formatted lines for one ``message_id`` reconnect stream.
First frame is ``: connected``; subsequent frames are snapshot rows,
live-tail events, or ``: keepalive`` comments. Runs until the client
disconnects.
"""
yield ": connected\n\n"
# Replay buffer — populated inside ``_on_subscribe`` (or the
# Redis-unavailable fallback below), drained on the first iteration
# of the subscribe loop after the callback runs.
replay_buffer: list[str] = []
# Dedup floor: seeded with the client's cursor so an empty snapshot
# still rejects re-published live events with seq <= last_event_id.
# Advanced by snapshot rows AND by yielded live events, so any
# republish past the snapshot ceiling is also dropped.
max_replayed_seq: Optional[int] = last_event_id
replay_done = False
replay_failed = False
# Set when a snapshot row carries a terminal ``end`` / ``error``
# event. After flushing the buffer the generator returns; if we
# kept tailing we'd loop on keepalives forever for a stream that
# already finished.
terminal_in_snapshot = False
def _read_snapshot_into_buffer() -> None:
nonlocal max_replayed_seq, replay_failed, terminal_in_snapshot
try:
with db_readonly() as conn:
rows = MessageEventsRepository(conn).read_after(
message_id, last_sequence_no=last_event_id
)
for row in rows:
seq = int(row["sequence_no"])
payload = row.get("payload")
if not isinstance(payload, dict):
# ``record_event`` rejects non-dict payloads at the
# write gate, so this can only be a legacy row from
# before that contract or a direct SQL insert. The
# original synthetic fallback (``{"type": event_type}``)
# used to ship a malformed envelope here — drop the
# row instead so a corrupt journal entry doesn't
# poison a reconnect.
logger.warning(
"Skipping non-dict payload from message_events: "
"message_id=%s seq=%s type=%s",
message_id,
seq,
row.get("event_type"),
)
continue
replay_buffer.append(format_sse_event(payload, seq))
if max_replayed_seq is None or seq > max_replayed_seq:
max_replayed_seq = seq
if _payload_is_terminal(payload, row.get("event_type")):
terminal_in_snapshot = True
except Exception:
logger.exception(
"Snapshot read failed for message_id=%s last_event_id=%s",
message_id,
last_event_id,
)
replay_failed = True
def _on_subscribe() -> None:
# SUBSCRIBE has been acked — Postgres reads from this point
# capture every row that's been committed. Pub/sub messages
# published after this point are queued at the connection level
# until the outer loop calls ``get_message`` again.
nonlocal replay_done
try:
_read_snapshot_into_buffer()
finally:
# Flip even on failure so the outer loop continues to live
# tail and the client doesn't hang waiting for a snapshot
# flush that will never come.
replay_done = True
topic = Topic(message_topic_name(message_id))
last_keepalive = time.monotonic()
# Rate-limit the watchdog's DB hit. ``-inf`` makes the first idle
# tick after replay_done fire immediately so a snapshot-already-
# terminal-in-DB case is surfaced before any keepalive cadence.
# Subsequent checks are gated by ``watchdog_interval_seconds``.
last_watchdog_check = float("-inf")
# Synthetic terminal events emitted by the watchdog use the same
# ``sequence_no=-1`` convention as the snapshot-failure path so the
# frontend's strict ``\d+`` cursor regex rejects them as a
# ``Last-Event-ID`` for any future reconnect. The chosen
# discriminator ensures a manual page refresh after a watchdog-fired
# error doesn't loop on the same synthetic id.
watchdog_synthetic_seq = -1
try:
for payload in topic.subscribe(
on_subscribe=_on_subscribe,
poll_timeout=poll_timeout_seconds,
):
# Flush snapshot exactly once after the SUBSCRIBE callback
# has run and produced a buffer.
if replay_done and replay_buffer:
for line in replay_buffer:
yield line
replay_buffer.clear()
if terminal_in_snapshot:
# The original stream already finished; tailing
# would just emit keepalives forever and pin both a
# client reconnect promise and a server WSGI thread.
return
if replay_failed:
# Snapshot read failed (DB blip / transient timeout). Emit a
# terminal ``error`` event and return — the client only
# reconnects after the original stream has already moved on,
# so without a snapshot there's nothing live left to tail and
# holding the connection open would just emit keepalives
# until the proxy idle-timeout fires. ``code`` preserves the
# snapshot-vs-agent-loop distinction so a future client can
# opt into a refetch instead of a hard failure.
yield format_sse_event(
{
"type": "error",
"error": "Stream replay failed; please refresh to load the latest state.",
"code": "snapshot_failed",
"message_id": message_id,
},
sequence_no=-1,
)
return
now = time.monotonic()
if payload is None:
# Idle tick — check both keepalive and watchdog. The
# watchdog only kicks in once the snapshot half has been
# flushed (``replay_done``) so we don't race the
# snapshot read on the first iteration.
if (
replay_done
and watchdog_interval_seconds >= 0
and now - last_watchdog_check >= watchdog_interval_seconds
):
last_watchdog_check = now
terminal_payload = _check_producer_liveness(
message_id, producer_idle_seconds
)
if terminal_payload is not None:
yield format_sse_event(
terminal_payload,
sequence_no=watchdog_synthetic_seq,
)
return
if now - last_keepalive >= keepalive_seconds:
yield ": keepalive\n\n"
last_keepalive = now
continue
envelope = _decode_pubsub_message(payload)
if envelope is None:
continue
seq = envelope.get("sequence_no")
inner = envelope.get("payload")
if (
not isinstance(seq, int)
or isinstance(seq, bool)
or not isinstance(inner, dict)
):
continue
if max_replayed_seq is not None and seq <= max_replayed_seq:
# Snapshot already covered this id — drop the duplicate.
continue
yield format_sse_event(inner, seq)
# Advance the dedup floor on the live path too, so a stale
# republish of an already-yielded seq (process restart, retry
# tool, etc.) is dropped on a later iteration.
max_replayed_seq = seq
last_keepalive = now
if _payload_is_terminal(inner, envelope.get("event_type")):
# Live tail just delivered the terminal event — close
# out the reconnect stream so the client's drain
# promise resolves and the WSGI thread is freed.
return
# Subscribe exited without ever yielding (Redis unavailable,
# ``pubsub.subscribe`` raised, or the inner loop died between
# SUBSCRIBE-ack and the first poll). The snapshot half is in
# Postgres and is still serviceable — read it directly so a
# Redis-only outage doesn't cost the client their reconnect
# backlog. Gate the read on ``replay_done`` rather than
# ``subscribe_started``: if ``_on_subscribe`` already populated
# the buffer, re-reading would append the same rows twice and
# double the answer chunks on the client (the per-message
# reconnect dispatcher does not dedup by ``id``).
if not replay_done:
_read_snapshot_into_buffer()
replay_done = True
for line in replay_buffer:
yield line
replay_buffer.clear()
if replay_failed:
# Mirror the live-tail branch: emit a terminal ``error`` so
# the frontend's existing end/error handling drives the UI
# to a failed state instead of relying on the proxy timeout.
yield format_sse_event(
{
"type": "error",
"error": "Stream replay failed; please refresh to load the latest state.",
"code": "snapshot_failed",
"message_id": message_id,
},
sequence_no=-1,
)
return
# Same close-on-terminal contract as the live-tail branch.
# Without it a Redis-down + already-completed-stream client
# would also hang on a never-ending generator.
if terminal_in_snapshot:
return
except GeneratorExit:
# Client disconnect — let the underlying ``Topic.subscribe``
# ``finally`` block tear down its pubsub cleanly.
return
def _decode_pubsub_message(raw) -> Optional[dict]:
"""Parse a ``Topic.publish`` payload to ``{sequence_no, payload, ...}``.
Returns ``None`` for malformed messages (drop silently — the
journal is still authoritative on reconnect).
"""
try:
if isinstance(raw, (bytes, bytearray)):
text_value = raw.decode("utf-8")
else:
text_value = str(raw)
envelope = json.loads(text_value)
except Exception:
return None
if not isinstance(envelope, dict):
return None
return envelope
def encode_pubsub_message(
message_id: str,
sequence_no: int,
event_type: str,
payload: dict,
) -> str:
"""Build the JSON envelope used for ``channel:{message_id}`` publishes.
Kept here (not in ``message_journal.py``) so the encode/decode pair
stays in one file — replay's ``_decode_pubsub_message`` and the
journal's publish must agree on the shape exactly.
"""
return json.dumps(
{
"message_id": str(message_id),
"sequence_no": int(sequence_no),
"event_type": event_type,
"payload": payload,
}
)
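The encode/decode pair must round-trip exactly; that invariant is easy to spot-check (the message id below is a hypothetical placeholder):

wire = encode_pubsub_message(
    "00000000-0000-0000-0000-000000000000", 7, "answer", {"answer": "Hi"}
)
envelope = _decode_pubsub_message(wire.encode("utf-8"))
assert envelope["sequence_no"] == 7
assert envelope["payload"] == {"answer": "Hi"}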

View File

@@ -1,19 +0,0 @@
"""Per-chat-message stream key derivations.
Single source of truth for the Redis pub/sub topic name and any
auxiliary keys that the chat-stream snapshot+tail reconnect path
shares between the writer (``complete_stream`` + journal) and the
reader (``/api/messages/<id>/events`` reconnect endpoint).
"""
from __future__ import annotations
def message_topic_name(message_id: str) -> str:
"""Redis pub/sub channel for live fan-out of one chat message.
Subscribers tail this topic for every event that ``complete_stream``
yielded after the SUBSCRIBE-ack arrived; older events are recovered
from the ``message_events`` snapshot half of the pattern.
"""
return f"channel:{message_id}"

View File

@@ -1,400 +0,0 @@
"""Per-yield journal write for the chat-stream snapshot+tail pattern.
``record_event`` inserts into ``message_events`` and publishes to
``channel:{message_id}``. Both are best-effort; the INSERT commits
before the publish so a fast reconnect sees the row. See
``docs/runbooks/sse-notifications.md``.
"""
from __future__ import annotations
import logging
import time
from typing import Any, Optional
from sqlalchemy.exc import IntegrityError
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
from application.storage.db.session import db_readonly, db_session
from application.streaming.broadcast_channel import Topic
from application.streaming.event_replay import encode_pubsub_message
from application.streaming.keys import message_topic_name
logger = logging.getLogger(__name__)
# Tunables for ``BatchedJournalWriter``. A streaming answer emits ~100s
# of ``answer`` chunks per response; without batching, that's one PG
# transaction per yield in the WSGI thread. With these defaults, ~10x
# fewer commits at the cost of a ≤100ms reconnect-visibility lag for
# any event still sitting in the buffer.
DEFAULT_BATCH_SIZE = 16
DEFAULT_BATCH_INTERVAL_MS = 100
def _strip_null_bytes(value: Any) -> Any:
"""Recursively strip ``\\x00`` from string keys/values in ``value``.
Postgres JSONB rejects the NUL escape; an LLM emitting a stray NUL
in a chunk would otherwise raise ``DataError`` at INSERT and the row
would be lost from the journal (live stream proceeds, reconnect
snapshot misses the chunk). Mirrors the strip already done in
``parser/embedding_pipeline.py`` and
``api/user/attachments/routes.py``.
"""
if isinstance(value, str):
return value.replace("\x00", "") if "\x00" in value else value
if isinstance(value, dict):
return {
(k.replace("\x00", "") if isinstance(k, str) and "\x00" in k else k):
_strip_null_bytes(v)
for k, v in value.items()
}
if isinstance(value, list):
return [_strip_null_bytes(item) for item in value]
if isinstance(value, tuple):
return tuple(_strip_null_bytes(item) for item in value)
return value
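The strip recurses through containers and leaves clean values untouched:

>>> _strip_null_bytes({"k\x00ey": ["a\x00b", ("c\x00",)]})
{'key': ['ab', ('c',)]}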
def record_event(
message_id: str,
sequence_no: int,
event_type: str,
payload: Optional[dict[str, Any]] = None,
) -> bool:
"""Journal one SSE event and publish it live. Best-effort.
``payload`` must be a ``dict`` or ``None`` (non-dicts are dropped so
live and replay envelopes stay byte-identical). Returns ``True`` when
the journal INSERT committed. Never raises.
"""
if not message_id or not event_type:
logger.warning(
"record_event called without message_id/event_type "
"(message_id=%r, event_type=%r)",
message_id,
event_type,
)
return False
if payload is None:
materialised_payload: dict[str, Any] = {}
elif isinstance(payload, dict):
materialised_payload = _strip_null_bytes(payload)
else:
logger.warning(
"record_event called with non-dict payload "
"(message_id=%s seq=%s type=%s payload_type=%s) — dropping",
message_id,
sequence_no,
event_type,
type(payload).__name__,
)
return False
journal_committed = False
# The seq we actually managed to write. Diverges from
# ``sequence_no`` only on the IntegrityError-retry path below.
materialised_seq = sequence_no
try:
# Short-lived per-event transaction. Critical for visibility:
# the reconnect endpoint reads the journal from a separate
# connection and only sees committed rows.
with db_session() as conn:
MessageEventsRepository(conn).record(
message_id, sequence_no, event_type, materialised_payload
)
journal_committed = True
except IntegrityError:
# Composite-PK collision on (message_id, sequence_no). Most
# likely cause is a stale ``latest_sequence_no`` seed on a
# continuation retry — the route read MAX(seq) from a separate
# connection before another writer committed past it. Look up
# the live latest and retry once with latest+1 so the event is
# not silently lost. Bounded to a single retry — if two
# writers keep racing in lockstep the route-level retry will
# converge them across attempts.
try:
with db_readonly() as conn:
latest = MessageEventsRepository(conn).latest_sequence_no(
message_id
)
materialised_seq = (latest if latest is not None else -1) + 1
with db_session() as conn:
MessageEventsRepository(conn).record(
message_id,
materialised_seq,
event_type,
materialised_payload,
)
journal_committed = True
logger.info(
"record_event: collision at seq=%s recovered → wrote at "
"seq=%s message_id=%s type=%s",
sequence_no,
materialised_seq,
message_id,
event_type,
)
except IntegrityError:
# Second collision under the same retry — give up and log.
# The route's nonlocal counter will continue at
# ``sequence_no+1`` on the next emit; the next call may
# land cleanly past the contended window.
logger.warning(
"record_event: IntegrityError persists after seq+1 retry; "
"dropping. message_id=%s original_seq=%s retry_seq=%s "
"type=%s",
message_id,
sequence_no,
materialised_seq,
event_type,
)
except Exception:
logger.exception(
"record_event: retry path failed unexpectedly "
"(message_id=%s seq=%s type=%s)",
message_id,
sequence_no,
event_type,
)
except Exception:
logger.exception(
"message_events INSERT failed: message_id=%s seq=%s type=%s",
message_id,
sequence_no,
event_type,
)
try:
# Publish using ``materialised_seq`` so the live pubsub frame
# matches the journal row that other clients will snapshot on
# reconnect. The original POST stream's SSE ``id:`` still
# carries the caller's ``sequence_no`` — a reconnect from that
# client will receive the same event at ``materialised_seq``
# on the snapshot, which is a benign duplicate (the slice's
# ``max_replayed_seq`` advances past it). No-collision case:
# ``materialised_seq == sequence_no`` and this is identical to
# the prior behaviour.
wire = encode_pubsub_message(
message_id, materialised_seq, event_type, materialised_payload
)
Topic(message_topic_name(message_id)).publish(wire)
except Exception:
logger.exception(
"channel:%s publish failed: seq=%s type=%s",
message_id,
materialised_seq,
event_type,
)
return journal_committed
class BatchedJournalWriter:
"""Per-stream journal writer that batches PG INSERTs.
One writer per ``message_id``; ``record()`` buffers events and flushes
on size/time/``close()`` triggers. Pubsub publishes fire only after the
INSERT commits. On ``IntegrityError`` falls back to per-row writes.
"""
def __init__(
self,
message_id: str,
*,
batch_size: int = DEFAULT_BATCH_SIZE,
batch_interval_ms: int = DEFAULT_BATCH_INTERVAL_MS,
) -> None:
self._message_id = message_id
self._batch_size = batch_size
self._batch_interval_ms = batch_interval_ms
self._buffer: list[tuple[int, str, dict[str, Any]]] = []
self._last_flush_mono_ms = time.monotonic() * 1000.0
self._closed = False
def record(
self,
sequence_no: int,
event_type: str,
payload: Optional[dict[str, Any]] = None,
) -> bool:
"""Buffer one event; maybe flush. Publish happens after journal commit."""
if self._closed:
logger.warning(
"BatchedJournalWriter.record after close: "
"message_id=%s seq=%s type=%s",
self._message_id,
sequence_no,
event_type,
)
return False
if not event_type:
logger.warning(
"BatchedJournalWriter.record without event_type: "
"message_id=%s seq=%s",
self._message_id,
sequence_no,
)
return False
if payload is None:
materialised: dict[str, Any] = {}
elif isinstance(payload, dict):
materialised = _strip_null_bytes(payload)
else:
# Same contract as ``record_event`` — non-dict payloads
# are rejected so the live and replay paths can't diverge
# on envelope reconstruction.
logger.warning(
"BatchedJournalWriter.record with non-dict payload: "
"message_id=%s seq=%s type=%s payload_type=%s — dropping",
self._message_id,
sequence_no,
event_type,
type(payload).__name__,
)
return False
self._buffer.append((sequence_no, event_type, materialised))
if self._should_flush():
self.flush()
return True
def _should_flush(self) -> bool:
if len(self._buffer) >= self._batch_size:
return True
elapsed_ms = (time.monotonic() * 1000.0) - self._last_flush_mono_ms
return elapsed_ms >= self._batch_interval_ms and len(self._buffer) > 0
def flush(self) -> None:
"""Commit buffered events to PG. Best-effort.
Tries one bulk INSERT first; on ``IntegrityError`` (composite
PK collision — typically a stale continuation seed) falls back
to per-row ``record_event`` so one bad seq doesn't drop the
rest of the batch. Always clears the buffer to bound memory,
even on failure — an event missing from the reconnect snapshot
is degraded UX, but an unbounded buffer is a memory leak in the
serving process.
"""
if not self._buffer:
self._last_flush_mono_ms = time.monotonic() * 1000.0
return
# Snapshot and clear before the I/O so a concurrent record()
# call would land in a fresh buffer rather than racing the
# flush. ``complete_stream`` is single-threaded per stream, so
# this is belt-and-suspenders for any future change.
pending = self._buffer
self._buffer = []
self._last_flush_mono_ms = time.monotonic() * 1000.0
try:
with db_session() as conn:
MessageEventsRepository(conn).bulk_record(
self._message_id, pending
)
except IntegrityError:
logger.info(
"BatchedJournalWriter: bulk INSERT collided for "
"message_id=%s n=%d; falling back to per-row writes",
self._message_id,
len(pending),
)
self._flush_per_row(pending)
return
except Exception:
logger.exception(
"BatchedJournalWriter: bulk INSERT failed for "
"message_id=%s n=%d; events dropped from journal",
self._message_id,
len(pending),
)
return
# Bulk INSERT committed — publish each frame in order. Best-effort:
# one failed publish must not poison the rest of the batch.
for seq, event_type, payload in pending:
self._publish(seq, event_type, payload)
def _flush_per_row(
self, pending: list[tuple[int, str, dict[str, Any]]]
) -> None:
"""Per-row fallback after a bulk collision. Publishes after each commit."""
for seq, event_type, payload in pending:
committed_seq: Optional[int] = None
try:
with db_session() as conn:
MessageEventsRepository(conn).record(
self._message_id, seq, event_type, payload
)
committed_seq = seq
except IntegrityError:
try:
with db_readonly() as conn:
latest = MessageEventsRepository(
conn
).latest_sequence_no(self._message_id)
retry_seq = (latest if latest is not None else -1) + 1
with db_session() as conn:
MessageEventsRepository(conn).record(
self._message_id, retry_seq, event_type, payload
)
committed_seq = retry_seq
except IntegrityError:
logger.warning(
"BatchedJournalWriter: IntegrityError persists "
"after seq+1 retry; dropping. message_id=%s "
"original_seq=%s type=%s",
self._message_id,
seq,
event_type,
)
except Exception:
logger.exception(
"BatchedJournalWriter: per-row retry failed "
"(message_id=%s seq=%s type=%s)",
self._message_id,
seq,
event_type,
)
except Exception:
logger.exception(
"BatchedJournalWriter: per-row INSERT failed "
"(message_id=%s seq=%s type=%s)",
self._message_id,
seq,
event_type,
)
if committed_seq is not None:
self._publish(committed_seq, event_type, payload)
def _publish(
self, sequence_no: int, event_type: str, payload: dict[str, Any]
) -> None:
"""Publish one frame to the per-message pubsub channel. Best-effort."""
try:
wire = encode_pubsub_message(
self._message_id, sequence_no, event_type, payload
)
Topic(message_topic_name(self._message_id)).publish(wire)
except Exception:
logger.exception(
"channel:%s publish failed: seq=%s type=%s",
self._message_id,
sequence_no,
event_type,
)
def close(self) -> None:
"""Final flush. Idempotent — safe to call from multiple
finally clauses.
"""
if self._closed:
return
self.flush()
self._closed = True
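# Illustrative only (not part of the diff): a minimal driver sketch showing
# the intended writer lifecycle under the docstring's contract. The function
# name and its ``events`` input are hypothetical; the real call sites live in
# the streaming pipeline.
def _example_stream_events(message_id, events):
    # ``events``: iterable of (sequence_no, event_type, payload) tuples.
    writer = BatchedJournalWriter(message_id, batch_size=32, batch_interval_ms=250)
    try:
        for seq, event_type, payload in events:
            writer.record(seq, event_type, payload)  # buffers; flushes on size/time
    finally:
        writer.close()  # idempotent final flush; publishes only after commit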

View File

@@ -1,8 +1,5 @@
import logging
from typing import List, Optional, Any, Dict
from psycopg.types.json import Jsonb
from application.core.settings import settings
from application.vectorstore.base import BaseVectorStore
from application.vectorstore.document_class import Document
@@ -178,7 +175,7 @@ class PGVectorStore(BaseVectorStore):
for text, embedding, metadata in zip(texts, embeddings, metadatas):
cursor.execute(
insert_query,
(text, embedding, Jsonb(metadata), self._source_id)
(text, embedding, metadata, self._source_id)
)
inserted_id = cursor.fetchone()[0]
inserted_ids.append(str(inserted_id))
@@ -269,7 +266,7 @@ class PGVectorStore(BaseVectorStore):
cursor.execute(
insert_query,
(text, embeddings[0], Jsonb(final_metadata), self._source_id)
(text, embeddings[0], final_metadata, self._source_id)
)
inserted_id = cursor.fetchone()[0]
conn.commit()

View File

@@ -19,8 +19,8 @@ import requests
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.stream_processor import get_prompt
from application.cache import get_redis_instance
from application.core.settings import settings
from application.events.publisher import publish_user_event
from application.parser.chunking import Chunker
from application.parser.connectors.connector_creator import ConnectorCreator
from application.parser.embedding_pipeline import (
@@ -52,16 +52,18 @@ MAX_TOKENS = 1250
RECURSION_DEPTH = 2
INGEST_HEARTBEAT_INTERVAL_SECONDS = 30
# Re-exported here for backward-compatible imports
# (``from application.worker import _derive_source_id`` /
# ``DOCSGPT_INGEST_NAMESPACE``) from tests and any other in-tree callers.
# New code should import from ``application.storage.db.source_ids``
# directly to avoid pulling this Celery worker module into the API
# process at import time.
from application.storage.db.source_ids import ( # noqa: E402, F401
DOCSGPT_INGEST_NAMESPACE,
derive_source_id as _derive_source_id,
)
# Stable namespace for deterministic source IDs derived from idempotency keys.
# Pinned literal — do not change. Re-rolling this would mint different
# source_ids for the same idempotency_keys across deploys, defeating the
# retry-resume contract.
DOCSGPT_INGEST_NAMESPACE = uuid.UUID("fa25d5d1-398b-46df-ac89-8d1c360b9bea")
def _derive_source_id(idempotency_key):
"""``uuid5(NS, key)`` when a key is supplied; ``uuid4()`` otherwise."""
if isinstance(idempotency_key, str) and idempotency_key:
return uuid.uuid5(DOCSGPT_INGEST_NAMESPACE, idempotency_key)
return uuid.uuid4()
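# Illustrative only (not part of the diff): the retry-resume contract in two
# asserts. uuid5 over the pinned namespace is a pure function of the key, so
# every retry of a keyed upload derives the same source_id and reuses the same
# source row; keyless calls stay opaque one-shots. The key string is made up.
def _example_source_id_determinism():
    assert _derive_source_id("user-42/upload-7") == _derive_source_id("user-42/upload-7")
    assert _derive_source_id(None) != _derive_source_id(None)  # fresh uuid4 each call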
def _ingest_heartbeat_loop(source_id, stop_event, interval=INGEST_HEARTBEAT_INTERVAL_SECONDS):
@@ -508,7 +510,6 @@ def ingest_worker(
retriever="classic",
file_name_map=None,
idempotency_key=None,
source_id=None,
):
"""
Ingest and process documents.
@@ -526,11 +527,6 @@ def ingest_worker(
idempotency_key (str|None): When provided, the ``source_id`` is derived
deterministically from the key so a retried task reuses the same
source row instead of duplicating it.
source_id (str|None): UUID minted by the HTTP route and returned in
its response. When supplied, the worker uses it verbatim so SSE
envelopes carry the same id the frontend already has — required
for non-idempotent uploads where the route can't predict
``_derive_source_id(idempotency_key)``.
Returns:
dict: Information about the completed ingestion task, including input parameters and a "limited" flag.
@@ -545,41 +541,10 @@ def ingest_worker(
logging.info(f"Ingest path: {file_path}", extra={"user": user, "job": job_name})
# Source id resolution order:
# 1. Caller-supplied ``source_id`` (HTTP route minted + returned to
# the frontend) — keeps the route response and the SSE event
# payloads in lockstep on the non-idempotent path.
# 2. Deterministic uuid5 from ``idempotency_key`` — retried tasks
# reuse the original source row instead of duplicating it.
# 3. Fresh uuid4 (caller has neither) — opaque, single-shot only.
if source_id:
source_uuid = uuid.UUID(source_id)
else:
source_uuid = _derive_source_id(idempotency_key)
source_id_for_events = str(source_uuid)
# Only emit ``queued`` on the original attempt. Celery retries re-run
# the body, and re-publishing here would oscillate the toast through
# ``queued`` again between ``failed`` and ``completed``.
if self.request.retries == 0:
publish_user_event(
user,
"source.ingest.queued",
{
"job_name": job_name,
"filename": filename,
"source_id": source_id_for_events,
"operation": "upload",
},
scope={"kind": "source", "id": source_id_for_events},
)
# Create temporary working directory
# Wrap the entire body in try/except so a failure between the
# ``queued`` publish above and the inner work (e.g. tempdir
# creation, OS-level resource exhaustion) still emits a terminal
# ``failed`` event rather than leaving the toast wedged on
# 'training' until the polling fallback rescues it 30s later.
try:
with tempfile.TemporaryDirectory() as temp_dir:
with tempfile.TemporaryDirectory() as temp_dir:
try:
os.makedirs(temp_dir, exist_ok=True)
if storage.is_directory(file_path):
@@ -668,22 +633,23 @@ def ingest_worker(
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
id = _derive_source_id(idempotency_key)
vector_store_path = os.path.join(temp_dir, "vector_store")
os.makedirs(vector_store_path, exist_ok=True)
heartbeat_thread, heartbeat_stop = _start_ingest_heartbeat(source_uuid)
heartbeat_thread, heartbeat_stop = _start_ingest_heartbeat(id)
try:
embed_and_store_documents(
docs, vector_store_path, source_uuid, self,
docs, vector_store_path, id, self,
attempt_id=getattr(self.request, "id", None),
user_id=user,
)
finally:
_stop_ingest_heartbeat(heartbeat_thread, heartbeat_stop)
# Defense-in-depth: chunk-progress is the authoritative
# record of how many chunks landed; mismatch raises so the
# task fails loud rather than caching a partial index.
assert_index_complete(source_uuid)
assert_index_complete(id)
tokens = count_tokens_docs(docs)
@@ -698,7 +664,7 @@ def ingest_worker(
"user": user,
"tokens": tokens,
"retriever": retriever,
"id": source_id_for_events,
"id": str(id),
"type": "local",
"file_path": file_path,
"directory_structure": json.dumps(directory_structure),
@@ -707,36 +673,9 @@ def ingest_worker(
file_data["file_name_map"] = json.dumps(file_name_map)
upload_index(vector_store_path, file_data)
publish_user_event(
user,
"source.ingest.completed",
{
"source_id": source_id_for_events,
"filename": filename,
"tokens": tokens,
"operation": "upload",
# Forward-looking contract: ``limited`` is always
# ``False`` today but is carried on the wire so a
# future token-cap detection path can flip it and
# the frontend slice / UploadToast already react.
"limited": False,
},
scope={"kind": "source", "id": source_id_for_events},
)
except Exception as e:
logging.error(f"Error in ingest_worker: {e}", exc_info=True)
publish_user_event(
user,
"source.ingest.failed",
{
"source_id": source_id_for_events,
"filename": filename,
"operation": "upload",
"error": str(e)[:1024],
},
scope={"kind": "source", "id": source_id_for_events},
)
raise
except Exception as e:
logging.error(f"Error in ingest_worker: {e}", exc_info=True)
raise
return {
"directory": directory,
"formats": formats,
@@ -760,23 +699,7 @@ def reingest_source_worker(self, source_id, user):
Returns:
dict: Information about the re-ingestion task
Note:
Reingest does its own ``vector_store.add_chunk`` work rather
than going through ``embed_and_store_documents`` so it does
*not* emit per-percent SSE progress events — only ``queued``,
``completed`` (carrying ``chunks_added`` / ``chunks_deleted``),
or ``failed``. v1 limitation; revisit if reingest gains a
progress-driven UI.
"""
# Declared at the function scope so the outer except can include
# ``name`` in the failed event payload when the failure happens
# after the source lookup. Empty string until the lookup succeeds.
source_name = ""
# Tracks inner-block failures so a ``completed`` event reflects
# partial-success accurately rather than masking it.
inner_warnings: list[str] = []
try:
from application.vectorstore.vector_creator import VectorCreator
@@ -790,27 +713,6 @@ def reingest_source_worker(self, source_id, user):
if not source:
raise ValueError(f"Source {source_id} not found or access denied")
source_id = str(source["id"])
source_name = source.get("name") or ""
# Publish ``queued`` *after* canonicalising ``source_id`` so the
# event references the same id as the source row. Trade-off
# documented: a Celery-backend or PG-lookup hiccup before this
# publish means the toast may see only a ``failed`` event with
# no preceding ``queued`` — acceptable for v1 since both
# conditions also imply broader system trouble. Gate on first
# attempt only so Celery retries don't re-emit ``queued`` after
# a prior attempt already published ``failed``.
if self.request.retries == 0:
publish_user_event(
user,
"source.ingest.queued",
{
"source_id": source_id,
"name": source_name,
"operation": "reingest",
},
scope={"kind": "source", "id": source_id},
)
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
@@ -908,19 +810,6 @@ def reingest_source_worker(self, source_id, user):
try:
if not added_files and not removed_files:
logging.info("No changes detected.")
publish_user_event(
user,
"source.ingest.completed",
{
"source_id": source_id,
"name": source_name,
"operation": "reingest",
"no_changes": True,
"chunks_added": 0,
"chunks_deleted": 0,
},
scope={"kind": "source", "id": source_id},
)
return {
"source_id": source_id,
"user": user,
@@ -972,9 +861,6 @@ def reingest_source_worker(self, source_id, user):
f"Error during deletion of removed file chunks: {e}",
exc_info=True,
)
inner_warnings.append(
f"deletion failed: {str(e)[:200]}"
)
# 2) Add chunks from new files
added = 0
@@ -1067,9 +953,6 @@ def reingest_source_worker(self, source_id, user):
logging.error(
f"Error during ingestion of new files: {e}", exc_info=True
)
inner_warnings.append(
f"add failed: {str(e)[:200]}"
)
# 3) Update source directory structure timestamp
try:
@@ -1098,25 +981,6 @@ def reingest_source_worker(self, source_id, user):
meta={"current": 100, "status": "Re-ingestion completed"},
)
completed_payload: dict = {
"source_id": source_id,
"name": source_name,
"operation": "reingest",
"chunks_added": added,
"chunks_deleted": deleted,
"tokens": int(total_tokens) if "total_tokens" in locals() else 0,
}
if inner_warnings:
# Surface the per-block failures so the toast can warn
# rather than claim a clean success.
completed_payload["warnings"] = inner_warnings
publish_user_event(
user,
"source.ingest.completed",
completed_payload,
scope={"kind": "source", "id": source_id},
)
return {
"source_id": source_id,
"user": user,
@@ -1134,17 +998,6 @@ def reingest_source_worker(self, source_id, user):
except Exception as e:
logging.error(f"Error in reingest_source_worker: {e}", exc_info=True)
publish_user_event(
user,
"source.ingest.failed",
{
"source_id": str(source_id),
"name": source_name,
"operation": "reingest",
"error": str(e)[:1024],
},
scope={"kind": "source", "id": str(source_id)},
)
raise
@@ -1160,51 +1013,12 @@ def remote_worker(
operation_mode="upload",
doc_id=None,
idempotency_key=None,
source_id=None,
):
safe_user = safe_filename(user)
full_path = os.path.join(directory, safe_user, uuid.uuid4().hex)
os.makedirs(full_path, exist_ok=True)
# Source id resolution order matches ``ingest_worker``:
# 1. ``operation_mode == "sync"`` reuses the existing source's ``doc_id``.
# 2. Caller-supplied ``source_id`` (the HTTP route minted it and
# already returned it to the frontend) — keeps the route
# response and the SSE event payloads in lockstep on the
# no-idempotency-key path.
# 3. Deterministic uuid5 from ``idempotency_key`` — retried tasks
# reuse the original source row instead of duplicating it.
# 4. Fresh uuid4 — opaque, single-shot only.
if operation_mode == "sync" and doc_id:
source_uuid = str(doc_id)
elif source_id:
source_uuid = uuid.UUID(source_id)
else:
source_uuid = _derive_source_id(idempotency_key)
source_id_for_events = str(source_uuid)
# Emit the queued event before any work that could fail (including
# ``update_state``) so the toast UI always sees a queued envelope
# before any subsequent failed event. Gated on first attempt so
# Celery retries don't re-emit ``queued`` after a prior ``failed``.
if self.request.retries == 0:
publish_user_event(
user,
"source.ingest.queued",
{
"source_id": source_id_for_events,
"job_name": name_job,
"loader": loader,
"operation": operation_mode,
},
scope={"kind": "source", "id": source_id_for_events},
)
# Wrap ``update_state`` plus the entire body so any pre-loader
# failure (Celery backend down, OS resource issue) still emits a
# terminal ``failed`` event rather than wedging the toast.
self.update_state(state="PROGRESS", meta={"current": 1})
try:
self.update_state(state="PROGRESS", meta={"current": 1})
logging.info("Initializing remote loader with type: %s", loader)
remote_loader = RemoteCreator.create_loader(loader)
raw_docs = remote_loader.load_data(source_data)
@@ -1291,22 +1105,22 @@ def remote_worker(
)
if operation_mode == "upload":
id = _derive_source_id(idempotency_key)
embed_and_store_documents(
docs, full_path, source_uuid, self,
docs, full_path, id, self,
attempt_id=getattr(self.request, "id", None),
user_id=user,
)
assert_index_complete(source_uuid)
assert_index_complete(id)
elif operation_mode == "sync":
if not doc_id:
logging.error("Invalid doc_id provided for sync operation: %s", doc_id)
raise ValueError("doc_id must be provided for sync operation.")
id = str(doc_id)
embed_and_store_documents(
docs, full_path, source_uuid, self,
docs, full_path, id, self,
attempt_id=getattr(self.request, "id", None),
user_id=user,
)
assert_index_complete(source_uuid)
assert_index_complete(id)
self.update_state(state="PROGRESS", meta={"current": 100})
# Serialize remote_data as JSON if it's a dict (for S3, Reddit, etc.)
@@ -1318,7 +1132,7 @@ def remote_worker(
"user": user,
"tokens": tokens,
"retriever": retriever,
"id": source_id_for_events,
"id": str(id),
"type": loader,
"remote_data": remote_data_serialized,
"sync_frequency": sync_frequency,
@@ -1332,49 +1146,23 @@ def remote_worker(
try:
with db_session() as conn:
repo = SourcesRepository(conn)
src = repo.get_any(source_id_for_events, user)
src = repo.get_any(str(id), user)
if src is not None:
repo.update(str(src["id"]), user, {"date": last_sync_now})
except Exception as upd_err:
logging.warning(
f"Failed to update last_sync for source {source_id_for_events}: {upd_err}"
f"Failed to update last_sync for source {id}: {upd_err}"
)
upload_index(full_path, file_data)
publish_user_event(
user,
"source.ingest.completed",
{
"source_id": source_id_for_events,
"job_name": name_job,
"loader": loader,
"operation": operation_mode,
"tokens": tokens,
# Forward-looking contract: see ingest_worker.
"limited": False,
},
scope={"kind": "source", "id": source_id_for_events},
)
except Exception as e:
logging.error("Error in remote_worker task: %s", str(e), exc_info=True)
publish_user_event(
user,
"source.ingest.failed",
{
"source_id": source_id_for_events,
"job_name": name_job,
"loader": loader,
"operation": operation_mode,
"error": str(e)[:1024],
},
scope={"kind": "source", "id": source_id_for_events},
)
raise
finally:
if os.path.exists(full_path):
shutil.rmtree(full_path)
logging.info("remote_worker task completed successfully")
return {
"id": source_id_for_events,
"id": str(id),
"urls": source_data,
"name_job": name_job,
"user": user,
@@ -1457,13 +1245,6 @@ def attachment_worker(self, file_info, user):
relative_path = file_info["path"]
metadata = file_info.get("metadata", {})
publish_user_event(
user,
"attachment.queued",
{"attachment_id": str(attachment_id), "filename": filename},
scope={"kind": "attachment", "id": str(attachment_id)},
)
try:
self.update_state(state="PROGRESS", meta={"current": 10})
storage = StorageCreator.get_storage()
@@ -1471,17 +1252,6 @@ def attachment_worker(self, file_info, user):
self.update_state(
state="PROGRESS", meta={"current": 30, "status": "Processing content"}
)
publish_user_event(
user,
"attachment.progress",
{
"attachment_id": str(attachment_id),
"filename": filename,
"current": 30,
"stage": "processing",
},
scope={"kind": "attachment", "id": str(attachment_id)},
)
file_extractor = get_default_file_extractor(
ocr_enabled=settings.DOCLING_OCR_ATTACHMENTS_ENABLED
@@ -1514,17 +1284,6 @@ def attachment_worker(self, file_info, user):
self.update_state(
state="PROGRESS", meta={"current": 80, "status": "Storing in database"}
)
publish_user_event(
user,
"attachment.progress",
{
"attachment_id": str(attachment_id),
"filename": filename,
"current": 80,
"stage": "storing",
},
scope={"kind": "attachment", "id": str(attachment_id)},
)
mime_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
@@ -1549,18 +1308,6 @@ def attachment_worker(self, file_info, user):
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"})
publish_user_event(
user,
"attachment.completed",
{
"attachment_id": str(attachment_id),
"filename": filename,
"token_count": token_count,
"mime_type": mime_type,
},
scope={"kind": "attachment", "id": str(attachment_id)},
)
return {
"filename": filename,
"path": relative_path,
@@ -1575,16 +1322,6 @@ def attachment_worker(self, file_info, user):
extra={"user": user},
exc_info=True,
)
publish_user_event(
user,
"attachment.failed",
{
"attachment_id": str(attachment_id),
"filename": filename,
"error": str(e)[:1024],
},
scope={"kind": "attachment", "id": str(attachment_id)},
)
raise
@@ -1648,7 +1385,6 @@ def ingest_connector(
doc_id=None,
sync_frequency: str = "never",
idempotency_key=None,
source_id=None,
) -> Dict[str, Any]:
"""
Ingestion for internal knowledge bases (GoogleDrive, etc.).
@@ -1667,50 +1403,14 @@ def ingest_connector(
sync_frequency: How often to sync ("never", "daily", "weekly", "monthly")
idempotency_key: When provided, the ``source_id`` is derived
deterministically so a retried upload reuses the same source row.
source_id: When supplied, the worker uses it verbatim so SSE envelopes
carry the same id the HTTP route already returned to the frontend
— required for non-idempotent uploads where the route can't
predict ``_derive_source_id(idempotency_key)``.
"""
logging.info(
f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}"
)
# Source id resolution mirrors ``ingest_worker`` / ``remote_worker``:
# sync mode reuses ``doc_id``; otherwise the caller-supplied
# ``source_id`` (minted by the HTTP route and already echoed to the
# client) wins; fall back to ``_derive_source_id`` only when neither
# is supplied. Without rule (2) the no-idempotency-key path would
# mint a fresh uuid4 here that the frontend has no way to correlate
# SSE envelopes to.
if operation_mode == "sync" and doc_id:
source_uuid = str(doc_id)
elif source_id:
source_uuid = uuid.UUID(source_id)
else:
source_uuid = _derive_source_id(idempotency_key)
source_id_for_events = str(source_uuid)
# First-attempt gate: Celery retries re-run the body, and a
# repeated ``queued`` here would oscillate the toast through
# ``queued`` again between ``failed`` and ``completed``.
if self.request.retries == 0:
publish_user_event(
user,
"source.ingest.queued",
{
"source_id": source_id_for_events,
"job_name": job_name,
"loader": source_type,
"operation": operation_mode,
},
scope={"kind": "source", "id": source_id_for_events},
)
self.update_state(state="PROGRESS", meta={"current": 1})
try:
with tempfile.TemporaryDirectory() as temp_dir:
with tempfile.TemporaryDirectory() as temp_dir:
try:
# Step 1: Initialize the appropriate loader
self.update_state(
state="PROGRESS",
@@ -1748,22 +1448,6 @@ def ingest_connector(
"files_downloaded", 0
):
logging.warning(f"No files were downloaded from {source_type}")
# Connector returned no files — surface as a benign
# ``completed`` event with zero tokens so the toast
# closes out cleanly instead of waiting on polling.
publish_user_event(
user,
"source.ingest.completed",
{
"source_id": source_id_for_events,
"job_name": job_name,
"loader": source_type,
"operation": operation_mode,
"tokens": 0,
"no_changes": True,
},
scope={"kind": "source", "id": source_id_for_events},
)
# Create empty result directly instead of calling a separate method
return {
"name": job_name,
@@ -1813,16 +1497,16 @@ def ingest_connector(
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
# Validate operation_mode here too (the source_uuid path
# at the top of the function only branches on the
# sync+doc_id combination; surfacing the wrong-mode error
# this far in matches the legacy behaviour).
if operation_mode == "sync" and not doc_id:
logging.error(
"Invalid doc_id provided for sync operation: %s", doc_id
)
raise ValueError("doc_id must be provided for sync operation.")
if operation_mode not in ("upload", "sync"):
if operation_mode == "upload":
id = _derive_source_id(idempotency_key)
elif operation_mode == "sync":
if not doc_id:
logging.error(
"Invalid doc_id provided for sync operation: %s", doc_id
)
raise ValueError("doc_id must be provided for sync operation.")
id = str(doc_id)
else:
raise ValueError(f"Invalid operation_mode: {operation_mode}")
vector_store_path = os.path.join(temp_dir, "vector_store")
@@ -1832,11 +1516,10 @@ def ingest_connector(
state="PROGRESS", meta={"current": 80, "status": "Storing documents"}
)
embed_and_store_documents(
docs, vector_store_path, source_uuid, self,
docs, vector_store_path, id, self,
attempt_id=getattr(self.request, "id", None),
user_id=user,
)
assert_index_complete(source_uuid)
assert_index_complete(id)
tokens = count_tokens_docs(docs)
@@ -1846,7 +1529,7 @@ def ingest_connector(
"name": job_name,
"tokens": tokens,
"retriever": retriever,
"id": source_id_for_events,
"id": str(id),
"type": "connector:file",
"remote_data": json.dumps(
{"provider": source_type, **api_source_config}
@@ -1855,13 +1538,16 @@ def ingest_connector(
"sync_frequency": sync_frequency,
}
file_data["last_sync"] = datetime.datetime.now()
if operation_mode == "sync":
file_data["last_sync"] = datetime.datetime.now()
else:
file_data["last_sync"] = datetime.datetime.now()
if operation_mode == "sync":
try:
with db_session() as conn:
repo = SourcesRepository(conn)
src = repo.get_any(source_id_for_events, user)
src = repo.get_any(str(id), user)
if src is not None:
repo.update(
str(src["id"]), user,
@@ -1869,9 +1555,7 @@ def ingest_connector(
)
except Exception as upd_err:
logging.warning(
"Failed to update last_sync for source %s: %s",
source_id_for_events,
upd_err,
f"Failed to update last_sync for source {id}: {upd_err}"
)
upload_index(vector_store_path, file_data)
@@ -1883,104 +1567,45 @@ def ingest_connector(
logging.info(f"Remote ingestion completed: {job_name}")
publish_user_event(
user,
"source.ingest.completed",
{
"source_id": source_id_for_events,
"job_name": job_name,
"loader": source_type,
"operation": operation_mode,
"tokens": tokens,
},
scope={"kind": "source", "id": source_id_for_events},
)
return {
"user": user,
"name": job_name,
"tokens": tokens,
"type": source_type,
"id": source_id_for_events,
"id": str(id),
"status": "complete",
}
except Exception as e:
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
publish_user_event(
user,
"source.ingest.failed",
{
"source_id": source_id_for_events,
"job_name": job_name,
"loader": source_type,
"operation": operation_mode,
"error": str(e)[:1024],
},
scope={"kind": "source", "id": source_id_for_events},
)
raise
except Exception as e:
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
raise
def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
"""Worker to handle MCP OAuth flow asynchronously.
Publishes SSE events at each phase boundary so the frontend can
drive the OAuth popup directly from the push channel. The
``mcp.oauth.awaiting_redirect`` envelope carries the
``authorization_url`` once the upstream OAuth client surfaces it,
eliminating the prior polling-only path for that URL.
"""
# Bind ``task_id`` and the publish helpers OUTSIDE the outer try so
# the ``except`` handler at the bottom can reach them even when an
# early statement raises. Without this, ``publish_oauth`` would
# UnboundLocalError on top of the original failure.
task_id = self.request.id if getattr(self, "request", None) else None
def publish_oauth(event_type: str, payload: Dict[str, Any]) -> None:
# MCP OAuth can be invoked without a route-bound user_id by
# legacy paths. Skip the SSE publish in that case — the caller
# has no per-user channel to subscribe to, and the status is
# surfaced via the task's return value.
if not user_id or task_id is None:
return
publish_user_event(
user_id,
event_type,
{"task_id": task_id, **payload},
scope={"kind": "mcp_oauth", "id": task_id},
)
def publish_awaiting_redirect(authorization_url: str) -> None:
"""Callback invoked by ``DocsGPTOAuth.redirect_handler`` once
the OAuth client has minted the authorization URL.
Carrying the URL on the SSE envelope lets the frontend open the
# popup directly from the event — the prior polling-only path
for the URL is gone.
"""
publish_oauth(
"mcp.oauth.awaiting_redirect",
{
"message": "Awaiting OAuth redirect...",
"authorization_url": authorization_url,
},
)
"""Worker to handle MCP OAuth flow asynchronously."""
try:
import asyncio
from application.agents.tools.mcp_tool import MCPTool
publish_oauth("mcp.oauth.in_progress", {"message": "Starting OAuth..."})
task_id = self.request.id
redis_client = get_redis_instance()
def update_status(status_data: Dict[str, Any]):
status_key = f"mcp_oauth_status:{task_id}"
redis_client.setex(status_key, 600, json.dumps(status_data))
update_status(
{
"status": "in_progress",
"message": "Starting OAuth...",
"task_id": task_id,
}
)
tool_config = config.copy()
tool_config["oauth_task_id"] = task_id
# Inject the awaiting-redirect publish callback. ``MCPTool`` pops
# it out of the config and threads it into ``DocsGPTOAuth`` so
# the publish fires synchronously from inside
# ``redirect_handler`` — the only point where the URL is known.
tool_config["oauth_redirect_publish"] = publish_awaiting_redirect
mcp_tool = MCPTool(tool_config, user_id)
async def run_oauth_discovery():
@@ -1988,6 +1613,14 @@ def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, An
mcp_tool._setup_client()
return await mcp_tool._execute_with_client("list_tools")
update_status(
{
"status": "awaiting_redirect",
"message": "Awaiting OAuth redirect...",
"task_id": task_id,
}
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@@ -1995,21 +1628,49 @@ def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, An
loop.run_until_complete(run_oauth_discovery())
tools = mcp_tool.get_actions_metadata()
publish_oauth(
"mcp.oauth.completed",
{"tools": tools, "tools_count": len(tools)},
update_status(
{
"status": "completed",
"message": f"Connected \u2014 found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
"tools": tools,
"tools_count": len(tools),
"task_id": task_id,
}
)
return {"success": True, "tools": tools, "tools_count": len(tools)}
except Exception as e:
error_msg = f"OAuth failed: {str(e)}"
logging.error("MCP OAuth discovery failed: %s", error_msg, exc_info=True)
publish_oauth("mcp.oauth.failed", {"error": error_msg[:1024]})
update_status(
{
"status": "error",
"message": error_msg,
"task_id": task_id,
}
)
return {"success": False, "error": error_msg}
finally:
loop.close()
except Exception as e:
error_msg = f"OAuth init failed: {str(e)}"
logging.error("MCP OAuth init failed: %s", error_msg, exc_info=True)
publish_oauth("mcp.oauth.failed", {"error": error_msg[:1024]})
update_status(
{
"status": "error",
"message": error_msg,
"task_id": task_id,
}
)
return {"success": False, "error": error_msg}
def mcp_oauth_status(self, task_id: str) -> Dict[str, Any]:
"""Check the status of an MCP OAuth flow."""
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
status_data = redis_client.get(status_key)
if status_data:
return json.loads(status_data)
return {"status": "not_found", "message": "Status not found"}

View File

@@ -1,385 +0,0 @@
# SSE Notifications Runbook
> Operations guide for "user says they didn't get a notification" — and
> the related "the bell never lights up" / "my upload toast hangs" /
> "the chat answer doesn't reconnect" symptoms.
The user-facing notifications channel is the SSE pipe at
`/api/events` plus per-message reconnects at
`/api/messages/<id>/events`. This document maps a user complaint to
the diagnostic that surfaces the cause.
---
## TL;DR — first 60 seconds
Run these three commands in parallel before anything else:
```bash
# 1) Is Redis up and serving the pipe? Should print PONG instantly.
redis-cli -n 2 PING
# 2) Anyone subscribed to the channel right now? Numbers per channel.
redis-cli -n 2 PUBSUB NUMSUB user:<user_id>
# 3) Is the user's backlog populated? Returns the count of journaled events.
redis-cli -n 2 XLEN user:<user_id>:stream
```
- `PING` failing → Redis is the problem. Skip to "Redis-down".
- `NUMSUB user:<user_id>` returns 0 → no client connected. Skip to "Client never connects".
- `XLEN user:<user_id>:stream` returns 0 or low → publisher isn't writing. Skip to "Publisher silent".
- All three look healthy → the events are flowing on the wire; the issue is downstream of the slice (UI rendering, toast suppression, etc.). Skip to "Events flowing but UI silent".
---
## Architecture cheat-sheet
```
Worker (publish_user_event) Frontend tab
│ ▲
▼ │ GET /api/events SSE
Redis Streams: XADD Flask route
user:<id>:stream ──────────────► replay_backlog (snapshot)
│ +
▼ Topic.subscribe (live tail)
Redis pub/sub: PUBLISH │
user:<id> ────────────────────────────────┘
```
**Source of truth:**
- Persistent journal: Redis Stream `user:<user_id>:stream`, capped at
`EVENTS_STREAM_MAXLEN` (default 1000) entries via `MAXLEN ~`. ~24h
at typical event rates.
- Live fan-out: Redis pub/sub channel `user:<user_id>`. No durability;
subscribers must be attached at publish time.
The chat-stream pipe is separate, parallel infrastructure:
- Journal: Postgres `message_events` table.
- Live fan-out: Redis pub/sub `channel:<message_id>`.
Same patterns, different durability layer. This doc covers both;
they share most diagnostic commands.
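For orientation, a minimal sketch of the dual write both pipes share (using
`redis-py`; the envelope fields here are illustrative, not the exact wire
format):

```python
import json

import redis

r = redis.Redis(db=2)

def publish_user_event_sketch(user_id: str, event_type: str, payload: dict) -> None:
    envelope = json.dumps({"type": event_type, **payload})
    # Durable journal first: capped with an approximate trim (XADD MAXLEN ~).
    r.xadd(f"user:{user_id}:stream", {"data": envelope},
           maxlen=1000, approximate=True)
    # Live fan-out second: no durability. Only currently-attached subscribers
    # see it; a reconnecting client replays the gap from the stream.
    r.publish(f"user:{user_id}", envelope)
```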
---
## Symptom → diagnostic map
### A. "I uploaded a source and the toast never appeared"
User flow: chat → upload → expect toast.
| Step | Command | Expect |
| ------------------------------------------------- | ------------------------------------------------------------- | ----------------------------------------------- |
| Worker received the task | `tail -f celery.log` filtered by user | `ingest_worker` start log line |
| Worker published the queued event | `redis-cli -n 2 XREVRANGE user:<id>:stream + - COUNT 5` | A `source.ingest.queued` entry within seconds |
| Frontend got it | DevTools → Network → `/api/events` → EventStream tab | `data: {"type":"source.ingest.queued",...}` |
| Slice updated | Redux DevTools → state.upload.tasks | Task with matching `sourceId`, `status:'training'` |
If the worker's queued log line is there but the XADD didn't land →
look for a `publish_user_event payload not JSON-serializable` warning
in the worker log (the publisher swallows `TypeError`).
If the XADD landed but the frontend never received it → check
`PUBSUB NUMSUB user:<id>` while the user is on the page. If 0, the
SSE connection isn't subscribed; skip to "Client never connects".
If the frontend received it but the toast didn't render → the
`uploadSlice` extraReducer requires `task.sourceId` to match the
event's `scope.id`. Check the upload route returned `source_id` in
its POST response (the upload, connector, and reingest paths all
include it). Idempotent / cached responses must also include
`source_id` (`_claim_task_or_get_cached`).
### B. "The bell badge never goes up"
There is no bell — the global notifications surface is per-event
toasts, not an aggregated counter. If the user is on an old build,
`Cmd-Shift-R` to bypass cache. The surfaces they're looking for are
`UploadToast` for source uploads and `ToolApprovalToast` for
tool-approval events.
### C. "My chat answer froze mid-stream and never recovered"
User flow: ask question → answer streaming → network blip → answer
stops; should reconnect.
```bash
# Was the original message reserved in PG?
psql -c "SELECT id, status, prompt FROM conversation_messages \
WHERE user_id = '<user>' ORDER BY timestamp DESC LIMIT 5;"
# Did the journal capture events past the user's last-seen seq?
psql -c "SELECT sequence_no, event_type FROM message_events \
WHERE message_id = '<id>' ORDER BY sequence_no;"
# Is the live tail still producing? (subscribe and watch)
redis-cli -n 2 SUBSCRIBE channel:<message_id>
```
The frontend should reconnect via `GET /api/messages/<id>/events`
when the original POST stream closes without a typed `end` or
`error` event. If it's not reconnecting, `console.warn('Stream
reconnect failed', ...)` will be in the browser console — the
reconnect HTTP errored. Common cases:
- The user's JWT rotated mid-stream → 401 on the GET. Frontend
doesn't auto-refresh; the user reloads.
- The user is on a different host than the API and CORS is rejecting
the GET → check `application/asgi.py` allow-headers.
### D. "The dev install never delivers any notifications at all"
Default `AUTH_TYPE` unset means `decoded_token = {"sub": "local"}`
for every request. The SSE client connects without the
`Authorization` header in this case, and `user:local:stream` is
the shared channel everything goes to. If the user has multiple dev
machines pointing at the same Redis, they will see each other's
events. Confirm with:
```bash
redis-cli -n 2 KEYS 'user:local:*'
```
If multiple deployments share the Redis, that is the known
multi-user-on-local-channel limitation, not a bug. Set
`AUTH_TYPE=simple_jwt` to scope per-user.
### E. "The notifications channel was working, then suddenly stopped after the user reloaded the page"
Likely path: `backlog.truncated` event fired, the slice cleared
`lastEventId` to null, the closure was carrying the same stale id and
re-tripped the same truncation on every reconnect. **Verify the user
is on a current build — `eventStreamClient.ts` must re-read
`lastEventId = opts.getLastEventId();` without a truthy guard so the
null clear propagates into the next reconnect.**
### F. "I keep getting 429 on /api/events"
The per-user concurrent-connection cap (`SSE_MAX_CONCURRENT_PER_USER`,
default 8) refused the connection. User has too many tabs open or a
runaway reconnect loop. `redis-cli -n 2 GET user:<id>:sse_count`
shows the live counter; the TTL is 1h from the last connection
attempt (rolling — every INCR re-seeds it), so the key only ages
out after the user stops reconnecting for a full hour.
If the count is wedged high without explanation, the
counter-DECR-in-finally path didn't run (worker SIGKILL, OOM). Wait
for the TTL or `redis-cli -n 2 DEL user:<id>:sse_count` to reset.
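The counter semantics, sketched with `redis-py` (the key name matches the
doc; the real route-side check may differ in detail):

```python
import redis

r = redis.Redis(db=2)
CAP = 8      # SSE_MAX_CONCURRENT_PER_USER
TTL = 3600   # rolling: re-seeded on every attempt

def try_acquire_sse_slot(user_id: str) -> bool:
    key = f"user:{user_id}:sse_count"
    count = r.incr(key)   # one INCR per connection attempt
    r.expire(key, TTL)    # rolling TTL; key ages out after a quiet hour
    if count > CAP:
        r.decr(key)       # over cap: hand the slot back, route answers 429
        return False
    return True           # caller must DECR in a finally on disconnect
```

A SIGKILL'd worker skips that `finally`, which is exactly the wedged-counter
case above.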
### G. "Replay snapshot stops at 200 events"
The route caps each replay at `EVENTS_REPLAY_MAX_PER_REQUEST`
(default 200). The cap is intentionally **silent** — the route does NOT
emit a `backlog.truncated` notice for cap-hit. The 200 entries each
carry their own `id:` header, so the frontend's slice cursor
advances to the most-recent delivered id. Next reconnect sends
`last_event_id=<max_replayed>` and the snapshot resumes from there.
A user that was 1000 entries behind catches up over ~5 reconnects.
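The resume arithmetic, sketched (assuming the snapshot reads the stream via
`XRANGE` with an exclusive lower bound, Redis >= 6.2 syntax):

```python
from typing import Optional

import redis

r = redis.Redis(db=2)

def replay_snapshot(user_id: str, last_event_id: Optional[str], cap: int = 200):
    # '(' makes the lower bound exclusive so the cursor entry itself is
    # not re-served; '-' means "from the oldest retained entry".
    start = f"({last_event_id}" if last_event_id else "-"
    for entry_id, fields in r.xrange(
        f"user:{user_id}:stream", min=start, max="+", count=cap
    ):
        yield entry_id, fields  # client advances its cursor per delivered id
    # Cap-hit is silent: the next reconnect carries the max delivered id
    # and the snapshot resumes exactly where this one stopped.
```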
If the user reports getting HTTP 429 on `/api/events` despite being
well under `SSE_MAX_CONCURRENT_PER_USER`, they hit the windowed
replay budget (`EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW`, default
30 / `EVENTS_REPLAY_BUDGET_WINDOW_SECONDS` 60s). The route refuses
the connection so the slice cursor stays pinned at whatever value
it had; the frontend backs off and the next reconnect (after the
window rolls) gets the proper snapshot. Serving the live tail
without a snapshot used to be the behavior here, but that let the
client advance `lastEventId` past entries it never received,
permanently stranding the un-replayed window — so the route now
429s instead. `redis-cli -n 2 GET user:<id>:replay_count` shows the
current counter; TTL is the window size.
`backlog.truncated` is emitted ONLY when the client's
`Last-Event-ID` has slid off the MAXLEN'd window — i.e. the journal
is genuinely gone past the cursor and the frontend should clear the
slice cursor and refetch state. Treating cap-hit or
budget-exhaustion the same way would lock the user into re-receiving
the oldest 200 entries on every reconnect (the cursor would clear,
the snapshot would re-serve from the start, the cap would re-trip).
### H. "User says push notifications stopped after a deploy"
- Pull `event.published topic=user:<id> type=...` from the worker
logs to confirm the publisher is still firing.
- Pull `event.connect user=<id>` from the API logs to confirm the
client is reconnecting.
- Check the gunicorn worker count and `WSGIMiddleware(workers=32)`:
if the deploy reduced the worker count, the per-user cap is still 8
but total concurrent SSE connections are bounded by `gunicorn
workers × 32`. A capacity miss looks like users randomly getting
429'd.
---
## Common failure modes
### Redis-down
Symptoms: `/api/events` returns 200 but emits only `: connected`
then the body closes. `XLEN` and `PUBLISH` both fail. The publisher's
`record_event` swallows the failure and returns False; the live tail
publish also drops on the floor. Frontend retries forever with
exponential backoff.
Resolution: bring Redis back. The backlog is likely gone (the Stream
lives only in that single Redis instance; with no replication
configured, anything not persisted to disk is lost). New events flow
as soon as Redis comes back.
### `AUTH_TYPE` misconfigured = sub:"local" cross-stream
Symptoms: every user shares `user:local:stream`. Any user sees
everyone else's notifications.
Resolution: set `AUTH_TYPE=simple_jwt` (or `session_jwt`) in `.env`.
The events route logs a one-time WARNING per process when
`sub == "local"` is observed. A repeat WARNING after a restart
confirms the misconfiguration.
### MAXLEN trimmed past Last-Event-ID
Symptoms: client reconnects with `last_event_id=X`, snapshot returns
the entire MAXLEN'd backlog (because X is older than the oldest
retained entry). Old events appear duplicated.
Detection: the route's `_oldest_retained_id` check emits
`backlog.truncated` when this case fires. Frontend's
`dispatchSSEEvent` clears `lastEventId` so the next reconnect starts
fresh.
If the WARNING isn't firing but symptoms match: the user's client
may have a corrupt cached `lastEventId`. `localStorage` doesn't
store this state; check Redux state via DevTools.
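The truncation check itself, sketched (the real `_oldest_retained_id`
logic may differ in detail):

```python
def cursor_truncated(r, user_id: str, last_event_id: str) -> bool:
    # Oldest entry still retained after MAXLEN trimming. An empty stream
    # means any cursor is stale: re-snapshot from scratch.
    oldest = r.xrange(f"user:{user_id}:stream", count=1)
    if not oldest:
        return True
    def as_tuple(i):  # stream ids are "<ms>-<seq>"; compare numerically
        return tuple(int(part) for part in i.split("-"))
    oldest_id = oldest[0][0].decode()
    return as_tuple(last_event_id) < as_tuple(oldest_id)
```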
### Stale event-stream client
Symptoms: events visible in `XRANGE` but the frontend slice doesn't
update.
```bash
# Is the client subscribed?
redis-cli -n 2 PUBSUB NUMSUB user:<id>
# When did its connection start?
grep "event.connect user=<id>" /var/log/docsgpt.log | tail -3
```
If `NUMSUB` is 0 and no recent `event.connect`, the user's tab is
closed or the connection died and never reconnected. Push them to
reload.
### Publisher silent
Symptoms: worker is processing the task (Celery says SUCCESS), but
no XADD and no PUBLISH. User sees no events.
```bash
# Was the publisher import error suppressed?
grep "publish_user_event" /var/log/celery.log | grep -i "warn\|error" | tail -20
# Is push disabled?
grep "ENABLE_SSE_PUSH" /var/log/docsgpt.log | tail -5
```
`ENABLE_SSE_PUSH=False` in `.env` would silence the publisher
globally. Useful for incident response if a runaway publisher is
DoS'ing Redis; toggle off, fix root cause, toggle on.
---
## Useful one-liners
```bash
# Watch a user's live event stream in real time (all events, all types)
redis-cli -n 2 PSUBSCRIBE 'user:*' | grep "user:<id>"
# Last 10 events the user would see on reconnect
redis-cli -n 2 XREVRANGE user:<id>:stream + - COUNT 10
# Live count of subscribed clients per user
redis-cli -n 2 PUBSUB NUMSUB $(redis-cli -n 2 PUBSUB CHANNELS 'user:*')
# Trim a runaway stream (CAREFUL — destroys backlog for all current
# subscribers; OK after explaining to the user)
redis-cli -n 2 XTRIM user:<id>:stream MAXLEN 0
# Clear a wedged concurrent-connection counter
redis-cli -n 2 DEL user:<id>:sse_count
# Force-flip every client to re-snapshot (drop the stream key entirely
# — destroys the backlog; clients reconnect with their last id and
# get a backlog.truncated)
redis-cli -n 2 DEL user:<id>:stream
```
---
## Settings reference
Everything in `application/core/settings.py`:
| Setting | Default | Purpose |
| --------------------------------------------- | ------- | --------------------------------------------- |
| `ENABLE_SSE_PUSH` | `True` | Master switch. False = publisher no-ops, route serves "push_disabled" comment. |
| `EVENTS_STREAM_MAXLEN` | `1000` | Per-user backlog cap. Approximate via `XADD MAXLEN ~`. |
| `SSE_KEEPALIVE_SECONDS` | `15` | Comment-frame cadence. Must sit under reverse-proxy idle close. |
| `SSE_MAX_CONCURRENT_PER_USER` | `8` | Cap on simultaneous SSE connections per user. 0 = disabled. |
| `EVENTS_REPLAY_MAX_PER_REQUEST` | `200` | Hard cap on snapshot rows per request. |
| `EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW` | `30` | Per-user replays per window. 0 = disabled. |
| `EVENTS_REPLAY_BUDGET_WINDOW_SECONDS` | `60` | Window length. |
| `MESSAGE_EVENTS_RETENTION_DAYS` | `14` | Retention for the `message_events` journal; `cleanup_message_events` beat task deletes older rows. |
---
## Known limitations
### Each tab runs its own SSE connection
There is no cross-tab dedup. Every tab open to the app holds its
own SSE connection and dispatches every received event into its
own Redux store, so a user with N tabs open will see N copies of
each toast. With `SSE_MAX_CONCURRENT_PER_USER=8` (the default) a
heavy multi-tab user can also hit the connection cap and start
seeing 429s. Cross-tab dedup via a `BroadcastChannel` ring +
`navigator.locks`-based leader election is tracked as future work.
### `/c/<unknown-id>` normalises to `/c/new`
If a user navigates to a conversation id that isn't in their
loaded list, the conversation route rewrites the URL to `/c/new`.
`ToolApprovalToast`'s gate uses `useMatch('/c/:conversationId')`,
so for the brief window after the rewrite the toast may surface
for a conversation the user *thought* they were already viewing.
Pre-existing route behaviour; not a notifications regression.
### Terminal events un-dismiss running uploads
`frontend/src/upload/uploadSlice.ts` sets `dismissed: false` when
an upload reaches `completed` or `failed`. If the user dismissed a
running task and the terminal SSE arrives later, the toast pops
back. Intentional ("notify the user it's done"); revisit if the
re-surface UX is too aggressive for v2.
### Werkzeug doesn't auto-reload route files
The dev server (`flask run`) doesn't watch
`application/api/events/routes.py` for changes by default.
After editing the route, restart Flask manually — `--reload`
isn't on. (Production gunicorn reloads via deploy.)
### MCP OAuth completion can fall outside the user stream's MAXLEN window
`get_oauth_status` scans up to `EVENTS_STREAM_MAXLEN` (~1000) entries via `XREVRANGE`. If the user has a high-rate ingest running concurrent with the OAuth handshake, the `mcp.oauth.completed` envelope can be trimmed off the back before they click Save. Symptom: backend returns "OAuth failed or not completed" even though the popup completed successfully.
Mitigation today: bump `EVENTS_STREAM_MAXLEN` per-deployment if your users routinely flood the channel during OAuth flows. A dedicated short-TTL Redis key for OAuth task results is tracked as a follow-up.
### React StrictMode double-mounts SSE
In dev, React 18 StrictMode mounts → unmounts → remounts every
component, briefly opening two SSE connections per tab before the
first is aborted. With `SSE_MAX_CONCURRENT_PER_USER=8`, opening 4-5
tabs concurrently can transiently hit the cap and surface HTTP 429
on cold-load. The first connection's counter increment
fires before the AbortController-induced disconnect can decrement
it. Production (single mount, no StrictMode) is unaffected; raise
the cap in dev or accept transient 429s.

File diff suppressed because it is too large

View File

@@ -7,8 +7,6 @@
"dev": "vite",
"build": "tsc && vite build",
"preview": "vite preview",
"test": "vitest run",
"test:watch": "vitest",
"lint": "eslint ./src --ext .jsx,.js,.ts,.tsx",
"lint-fix": "eslint ./src --ext .jsx,.js,.ts,.tsx --fix",
"format": "prettier ./src --write",
@@ -71,7 +69,6 @@
"eslint-plugin-promise": "^6.6.0",
"eslint-plugin-react": "^7.37.5",
"eslint-plugin-unused-imports": "^4.1.4",
"happy-dom": "^17.6.3",
"husky": "^9.1.7",
"lint-staged": "^16.4.0",
"postcss": "^8.5.12",
@@ -81,7 +78,6 @@
"tw-animate-css": "^1.4.0",
"typescript": "^6.0.3",
"vite": "^8.0.10",
"vite-plugin-svgr": "^4.3.0",
"vitest": "^3.2.4"
"vite-plugin-svgr": "^4.3.0"
}
}

View File

@@ -10,7 +10,6 @@ import Spinner from './components/Spinner';
import UploadToast from './components/UploadToast';
import Conversation from './conversation/Conversation';
import { SharedConversation } from './conversation/SharedConversation';
import { EventStreamProvider } from './events/EventStreamProvider';
import { useDarkTheme, useMediaQuery } from './hooks';
import useDataInitializer from './hooks/useDataInitializer';
import useTokenAuth from './hooks/useTokenAuth';
@@ -18,7 +17,6 @@ import Navigation from './Navigation';
import PageNotFound from './PageNotFound';
import Setting from './settings';
import Notification from './components/Notification';
import ToolApprovalToast from './notifications/ToolApprovalToast';
function AuthWrapper({ children }: { children: React.ReactNode }) {
const { isAuthLoading } = useTokenAuth();
@@ -31,7 +29,7 @@ function AuthWrapper({ children }: { children: React.ReactNode }) {
</div>
);
}
return <EventStreamProvider>{children}</EventStreamProvider>;
return <>{children}</>;
}
function MainLayout() {
@@ -52,7 +50,6 @@ function MainLayout() {
<Outlet />
</div>
<UploadToast />
<ToolApprovalToast />
</div>
);
}

View File

@@ -813,11 +813,7 @@ function WorkflowBuilderInner() {
const response = await userService.getWorkflow(workflowId, token);
if (!response.ok) throw new Error('Failed to fetch workflow');
const responseData = await response.json();
const {
workflow,
nodes: apiNodes,
edges: apiEdges,
} = responseData.data;
const { workflow, nodes: apiNodes, edges: apiEdges } = responseData.data;
const nextWorkflowName = workflow.name;
const nextWorkflowDescription = workflow.description || '';
const mappedNodes = apiNodes.map((n: WorkflowNode) => {
@@ -1476,9 +1472,7 @@ function WorkflowBuilderInner() {
{t('agents.form.advanced.systemPromptOverride')}
</label>
<p className="mt-0.5 text-[11px] text-gray-500 dark:text-gray-400">
{t(
'agents.form.advanced.systemPromptOverrideDescription',
)}
{t('agents.form.advanced.systemPromptOverrideDescription')}
</p>
</div>
<button

View File

@@ -1,5 +1,3 @@
import { withThrottle, type FetchLike } from './throttle';
export const baseURL =
import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
@@ -20,121 +18,112 @@ const getHeaders = (
return headers;
};
const createClient = (transport: FetchLike) => {
const request = (url: string, init: RequestInit): Promise<Response> =>
transport(`${baseURL}${url}`, init);
const apiClient = {
get: (
url: string,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'GET',
headers: getHeaders(token, headers),
signal,
}).then((response) => {
return response;
}),
return {
get: (
url: string,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
request(url, {
method: 'GET',
headers: getHeaders(token, headers),
signal,
}),
post: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'POST',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}).then((response) => {
return response;
}),
post: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
request(url, {
method: 'POST',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}),
postFormData: (
url: string,
formData: FormData,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<Response> => {
return fetch(`${baseURL}${url}`, {
method: 'POST',
headers: getHeaders(token, headers, true),
body: formData,
signal,
});
},
postFormData: (
url: string,
formData: FormData,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<Response> =>
request(url, {
method: 'POST',
headers: getHeaders(token, headers, true),
body: formData,
signal,
}),
put: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'PUT',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}).then((response) => {
return response;
}),
put: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
request(url, {
method: 'PUT',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}),
patch: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'PATCH',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}).then((response) => {
return response;
}),
patch: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
request(url, {
method: 'PATCH',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}),
putFormData: (
url: string,
formData: FormData,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<Response> => {
return fetch(`${baseURL}${url}`, {
method: 'PUT',
headers: getHeaders(token, headers, true),
body: formData,
signal,
});
},
putFormData: (
url: string,
formData: FormData,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<Response> =>
request(url, {
method: 'PUT',
headers: getHeaders(token, headers, true),
body: formData,
signal,
}),
delete: (
url: string,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
request(url, {
method: 'DELETE',
headers: getHeaders(token, headers),
signal,
}),
};
delete: (
url: string,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'DELETE',
headers: getHeaders(token, headers),
signal,
}).then((response) => {
return response;
}),
};
const apiClient = createClient((url, init) => fetch(url, init));
// Throttled client for endpoints that fan out, are polled, or are commonly
// requested concurrently from multiple components. Shares a single concurrency
// budget and de-duplicates identical in-flight GETs.
export const throttledApiClient = createClient(
withThrottle((url, init) => fetch(url, init), { debugLabel: 'api' }),
);
if (import.meta.env.DEV && typeof window !== 'undefined') {
(window as unknown as Record<string, unknown>).__apiClient = apiClient;
(window as unknown as Record<string, unknown>).__throttledApiClient =
throttledApiClient;
(window as unknown as Record<string, unknown>).__baseURL = baseURL;
}
export default apiClient;

View File

@@ -28,6 +28,7 @@ const endpoints = {
UPDATE_PROMPT: '/api/update_prompt',
SINGLE_PROMPT: (id: string) => `/api/get_single_prompt?id=${id}`,
DELETE_PATH: (docPath: string) => `/api/delete_old?source_id=${docPath}`,
TASK_STATUS: (task_id: string) => `/api/task_status?task_id=${task_id}`,
MESSAGE_ANALYTICS: '/api/get_message_analytics',
TOKEN_ANALYTICS: '/api/get_token_analytics',
FEEDBACK_ANALYTICS: '/api/get_feedback_analytics',
@@ -42,11 +43,6 @@ const endpoints = {
DELETE_TOOL: '/api/delete_tool',
PARSE_SPEC: '/api/parse_spec',
SYNC_CONNECTOR: '/api/connectors/sync',
CONNECTOR_AUTH: (provider: string) =>
`/api/connectors/auth?provider=${provider}`,
CONNECTOR_FILES: '/api/connectors/files',
CONNECTOR_VALIDATE_SESSION: '/api/connectors/validate-session',
CONNECTOR_DISCONNECT: '/api/connectors/disconnect',
GET_CHUNKS: (
docId: string,
page: number,
@@ -63,7 +59,6 @@ const endpoints = {
UPDATE_CHUNK: '/api/update_chunk',
STORE_ATTACHMENT: '/api/store_attachment',
STT: '/api/stt',
TTS: '/api/tts',
LIVE_STT_START: '/api/stt/live/start',
LIVE_STT_CHUNK: '/api/stt/live/chunk',
LIVE_STT_FINISH: '/api/stt/live/finish',
@@ -72,6 +67,8 @@ const endpoints = {
MANAGE_SOURCE_FILES: '/api/manage_source_files',
MCP_TEST_CONNECTION: '/api/mcp_server/test',
MCP_SAVE_SERVER: '/api/mcp_server/save',
MCP_OAUTH_STATUS: (task_id: string) =>
`/api/mcp_server/oauth_status/${task_id}`,
MCP_AUTH_STATUS: '/api/mcp_server/auth_status',
AGENT_FOLDERS: '/api/agents/folders/',
AGENT_FOLDER: (id: string) => `/api/agents/folders/${id}`,

View File

@@ -1,12 +1,11 @@
import { getSessionToken } from '../../utils/providerUtils';
import apiClient, { throttledApiClient } from '../client';
import apiClient from '../client';
import endpoints from '../endpoints';
const userService = {
getConfig: (): Promise<any> =>
throttledApiClient.get(endpoints.USER.CONFIG, null),
getConfig: (): Promise<any> => apiClient.get(endpoints.USER.CONFIG, null),
getNewToken: (): Promise<any> =>
throttledApiClient.get(endpoints.USER.NEW_TOKEN, null),
apiClient.get(endpoints.USER.NEW_TOKEN, null),
getDocs: (token: string | null): Promise<any> =>
apiClient.get(`${endpoints.USER.DOCS}`, token),
getDocsWithPagination: (query: string, token: string | null): Promise<any> =>
@@ -18,9 +17,9 @@ const userService = {
deleteAPIKey: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.DELETE_API_KEY, data, token),
getAgent: (id: string, token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.AGENT(id), token),
apiClient.get(endpoints.USER.AGENT(id), token),
getAgents: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.AGENTS, token),
apiClient.get(endpoints.USER.AGENTS, token),
createAgent: (data: any, token: string | null): Promise<any> =>
apiClient.postFormData(endpoints.USER.CREATE_AGENT, data, token),
updateAgent: (
@@ -32,19 +31,19 @@ const userService = {
deleteAgent: (id: string, token: string | null): Promise<any> =>
apiClient.delete(endpoints.USER.DELETE_AGENT(id), token),
getPinnedAgents: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.PINNED_AGENTS, token),
apiClient.get(endpoints.USER.PINNED_AGENTS, token),
togglePinAgent: (id: string, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.TOGGLE_PIN_AGENT(id), {}, token),
getSharedAgent: (id: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.SHARED_AGENT(id), token),
getSharedAgents: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.SHARED_AGENTS, token),
apiClient.get(endpoints.USER.SHARED_AGENTS, token),
shareAgent: (data: any, token: string | null): Promise<any> =>
apiClient.put(endpoints.USER.SHARE_AGENT, data, token),
removeSharedAgent: (id: string, token: string | null): Promise<any> =>
apiClient.delete(endpoints.USER.REMOVE_SHARED_AGENT(id), token),
getTemplateAgents: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.TEMPLATE_AGENTS, token),
apiClient.get(endpoints.USER.TEMPLATE_AGENTS, token),
adoptAgent: (id: string, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.ADOPT_AGENT(id), {}, token),
getAgentWebhook: (id: string, token: string | null): Promise<any> =>
@@ -61,6 +60,8 @@ const userService = {
apiClient.get(endpoints.USER.SINGLE_PROMPT(id), token),
deletePath: (docPath: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.DELETE_PATH(docPath), token),
getTaskStatus: (task_id: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.TASK_STATUS(task_id), token),
getMessageAnalytics: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MESSAGE_ANALYTICS, data, token),
getTokenAnalytics: (data: any, token: string | null): Promise<any> =>
@@ -148,7 +149,7 @@ const userService = {
path?: string,
search?: string,
): Promise<any> =>
throttledApiClient.get(
apiClient.get(
endpoints.USER.GET_CHUNKS(docId, page, perPage, path, search),
token,
),
@@ -163,15 +164,17 @@ const userService = {
updateChunk: (data: any, token: string | null): Promise<any> =>
apiClient.put(endpoints.USER.UPDATE_CHUNK, data, token),
getDirectoryStructure: (docId: string, token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
manageSourceFiles: (data: FormData, token: string | null): Promise<any> =>
apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token),
testMCPConnection: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
saveMCPServer: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
getMCPOAuthStatus: (task_id: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.MCP_OAUTH_STATUS(task_id), token),
getMCPAuthStatus: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.MCP_AUTH_STATUS, token),
apiClient.get(endpoints.USER.MCP_AUTH_STATUS, token),
syncConnector: (
docId: string,
provider: string,
@@ -188,50 +191,8 @@ const userService = {
token,
);
},
getConnectorAuthUrl: (provider: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.CONNECTOR_AUTH(provider), token),
getConnectorFiles: (
data: any,
token: string | null,
signal?: AbortSignal,
): Promise<any> =>
throttledApiClient.post(
endpoints.USER.CONNECTOR_FILES,
data,
token,
{},
signal,
),
validateConnectorSession: (
provider: string,
token: string | null,
): Promise<any> =>
apiClient.post(
endpoints.USER.CONNECTOR_VALIDATE_SESSION,
{
provider,
session_token: getSessionToken(provider),
},
token,
),
disconnectConnector: (
provider: string,
sessionToken: string,
token: string | null,
): Promise<any> =>
apiClient.post(
endpoints.USER.CONNECTOR_DISCONNECT,
{ provider, session_token: sessionToken },
token,
),
textToSpeech: (
text: string,
token: string | null,
signal?: AbortSignal,
): Promise<any> =>
apiClient.post(endpoints.USER.TTS, { text }, token, {}, signal),
getAgentFolders: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.AGENT_FOLDERS, token),
apiClient.get(endpoints.USER.AGENT_FOLDERS, token),
createAgentFolder: (
data: { name: string; parent_id?: string },
token: string | null,
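getTaskStatus returns the raw Response; the components later in this compare all poll it the same way. A minimal sketch of that loop, with the interval and attempt cap mirroring the values used below (the error messages are placeholders):

async function pollTask(taskId: string, token: string | null) {
  for (let attempt = 0; attempt < 30; attempt++) {
    const res = await userService.getTaskStatus(taskId, token);
    const data = await res.json();
    if (data.status === 'SUCCESS') return data.result;
    if (data.status === 'FAILURE') throw new Error('task failed');
    // 'PROGRESS' reports a percentage via result.current (or meta.current).
    await new Promise((r) => setTimeout(r, 2000));
  }
  throw new Error('timed out waiting for task');
}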

View File

@@ -1,223 +0,0 @@
/**
* Transport-layer middleware factory for the frontend API layer.
*/
export type FetchLike = (
input: string,
init?: RequestInit,
) => Promise<Response>;
export interface ThrottleConfig {
maxConcurrentGlobal?: number;
maxConcurrentPerRoute?: number;
dedupe?: boolean;
dedupeKey?: (url: string, init?: RequestInit) => string | false;
debugLabel?: string;
}
const DEFAULT_MAX_CONCURRENT_GLOBAL = 8;
const DEFAULT_MAX_CONCURRENT_PER_ROUTE = 3;
type QueueItem = {
run: () => void;
signal?: AbortSignal;
onAbort: () => void;
};
function routeKey(method: string, url: string): string {
let pathname = url;
try {
pathname = new URL(url, 'http://_').pathname;
} catch {
pathname = url.split('?')[0];
}
return `${method.toUpperCase()} ${pathname}`;
}
function abortError(): DOMException {
return new DOMException('The operation was aborted.', 'AbortError');
}
interface ThrottleState {
perRouteQueues: Map<string, QueueItem[]>;
inflightPerRoute: Map<string, number>;
inflightGets: Map<string, Promise<Response>>;
inflightGlobal: number;
}
function createState(): ThrottleState {
return {
perRouteQueues: new Map(),
inflightPerRoute: new Map(),
inflightGets: new Map(),
inflightGlobal: 0,
};
}
export function withThrottle(
fetchLike: FetchLike,
config: ThrottleConfig = {},
): FetchLike & { __reset: () => void } {
const maxGlobal = config.maxConcurrentGlobal ?? DEFAULT_MAX_CONCURRENT_GLOBAL;
const maxPerRoute =
config.maxConcurrentPerRoute ?? DEFAULT_MAX_CONCURRENT_PER_ROUTE;
const dedupeEnabled = config.dedupe !== false;
const state = createState();
// Toggle in DevTools with: localStorage.setItem('debug:throttle', '1')
const isDebug = (): boolean => {
try {
return (
typeof localStorage !== 'undefined' &&
localStorage.getItem('debug:throttle') === '1'
);
} catch {
return false;
}
};
const log = (
event: string,
key: string,
extra?: Record<string, unknown>,
): void => {
if (!isDebug()) return;
const queued = state.perRouteQueues.get(key)?.length ?? 0;
const perRoute = state.inflightPerRoute.get(key) ?? 0;
const tag = config.debugLabel
? `[throttle:${config.debugLabel}]`
: '[throttle]';
console.debug(
`${tag} ${event} ${key} | inflight=${state.inflightGlobal}/${maxGlobal} route=${perRoute}/${maxPerRoute} queued=${queued}`,
extra ?? '',
);
};
const canDispatch = (key: string): boolean => {
const perRoute = state.inflightPerRoute.get(key) ?? 0;
return state.inflightGlobal < maxGlobal && perRoute < maxPerRoute;
};
const pumpQueues = (): void => {
for (const [key, queue] of state.perRouteQueues) {
while (queue.length > 0 && canDispatch(key)) {
const item = queue.shift()!;
item.signal?.removeEventListener('abort', item.onAbort);
item.run();
}
if (queue.length === 0) state.perRouteQueues.delete(key);
}
};
const enqueue = (key: string, item: QueueItem): void => {
let queue = state.perRouteQueues.get(key);
if (!queue) {
queue = [];
state.perRouteQueues.set(key, queue);
}
queue.push(item);
};
const acquireSlot = (key: string, signal?: AbortSignal): Promise<void> =>
new Promise((resolve, reject) => {
if (signal?.aborted) {
reject(abortError());
return;
}
const item: QueueItem = {
signal,
run: () => {
state.inflightGlobal += 1;
state.inflightPerRoute.set(
key,
(state.inflightPerRoute.get(key) ?? 0) + 1,
);
resolve();
},
onAbort: () => {
const queue = state.perRouteQueues.get(key);
if (queue) {
const idx = queue.indexOf(item);
if (idx >= 0) queue.splice(idx, 1);
}
log('abort-queued', key);
reject(abortError());
},
};
const queued = state.perRouteQueues.get(key);
if ((!queued || queued.length === 0) && canDispatch(key)) {
item.run();
log('dispatch', key);
return;
}
signal?.addEventListener('abort', item.onAbort, { once: true });
enqueue(key, item);
log('queued', key);
});
const releaseSlot = (key: string): void => {
state.inflightGlobal = Math.max(0, state.inflightGlobal - 1);
const next = (state.inflightPerRoute.get(key) ?? 1) - 1;
if (next <= 0) state.inflightPerRoute.delete(key);
else state.inflightPerRoute.set(key, next);
log('release', key);
pumpQueues();
};
const wrapped = (async (url, init = {}) => {
const method = (init.method ?? 'GET').toUpperCase();
const signal = init.signal ?? undefined;
const key = routeKey(method, url);
// Dedupe is restricted to GETs without a caller-supplied AbortSignal:
// sharing a single underlying fetch across waiters means an abort by one
// caller would reject the others, which is not the contract callers expect.
const customKey = config.dedupeKey?.(url, init);
const dedupeAllowed =
dedupeEnabled &&
customKey !== false &&
method === 'GET' &&
!init.body &&
!signal;
const dedupeKey = typeof customKey === 'string' ? customKey : `GET ${url}`;
if (dedupeAllowed) {
const existing = state.inflightGets.get(dedupeKey);
if (existing) {
log('dedupe-hit', key, { dedupeKey });
return existing.then((r) => r.clone());
}
}
const run = async (): Promise<Response> => {
await acquireSlot(key, signal);
try {
return await fetchLike(url, init);
} finally {
releaseSlot(key);
}
};
if (dedupeAllowed) {
const promise = run();
state.inflightGets.set(dedupeKey, promise);
promise.finally(() => {
if (state.inflightGets.get(dedupeKey) === promise) {
state.inflightGets.delete(dedupeKey);
}
});
return promise.then((r) => r.clone());
}
return run();
}) as FetchLike & { __reset: () => void };
wrapped.__reset = () => {
state.perRouteQueues.clear();
state.inflightPerRoute.clear();
state.inflightGets.clear();
state.inflightGlobal = 0;
};
return wrapped;
}
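A minimal sketch of the factory in isolation; production wiring goes through createClient as shown at the top of this compare, so everything here is illustrative:

// Inside an async context:
const throttledFetch = withThrottle((url, init) => fetch(url, init), {
  maxConcurrentGlobal: 8, // the defaults, spelled out
  maxConcurrentPerRoute: 3,
  debugLabel: 'api',
});

// Identical in-flight GETs share one network request; each caller gets
// its own clone of the Response, so both bodies can be consumed.
const [a, b] = await Promise.all([
  throttledFetch('/api/user/config'),
  throttledFetch('/api/user/config'),
]);

throttledFetch.__reset(); // test hook: clear queues and counters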

View File

@@ -11,7 +11,6 @@ import NoFilesIcon from '../assets/no-files.svg';
import SearchIcon from '../assets/search.svg';
import {
useDarkTheme,
useDebouncedValue,
useLoaderState,
useMediaQuery,
useOutsideAlerter,
@@ -131,7 +130,6 @@ const Chunks: React.FC<ChunksProps> = ({
const [totalChunks, setTotalChunks] = useState(0);
const [loading, setLoading] = useLoaderState(true);
const [searchTerm, setSearchTerm] = useState<string>('');
const debouncedSearchTerm = useDebouncedValue(searchTerm, 300);
const [editingChunk, setEditingChunk] = useState<ChunkType | null>(null);
const [editingTitle, setEditingTitle] = useState('');
const [editingText, setEditingText] = useState('');
@@ -153,7 +151,7 @@ const Chunks: React.FC<ChunksProps> = ({
perPage,
token,
path,
debouncedSearchTerm,
searchTerm,
);
if (!response.ok) {
@@ -278,12 +276,16 @@ const Chunks: React.FC<ChunksProps> = ({
};
useEffect(() => {
if (page !== 1) {
setPage(1);
} else {
fetchChunks();
}
}, [debouncedSearchTerm]);
const delayDebounceFn = setTimeout(() => {
if (page !== 1) {
setPage(1);
} else {
fetchChunks();
}
}, 300);
return () => clearTimeout(delayDebounceFn);
}, [searchTerm]);
useEffect(() => {
!loading && fetchChunks();

View File

@@ -1,8 +1,7 @@
import React, { useEffect, useRef } from 'react';
import React, { useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelector } from 'react-redux';
import userService from '../api/services/userService';
import { useDarkTheme } from '../hooks';
import { selectToken } from '../preferences/preferenceSlice';
@@ -32,24 +31,13 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
const [isDarkTheme] = useDarkTheme();
const completedRef = useRef(false);
const intervalRef = useRef<number | null>(null);
const authWindowRef = useRef<Window | null>(null);
// Hold the exact listener identity so unmount cleanup removes the same fn.
const messageHandlerRef = useRef<((event: MessageEvent) => void) | null>(
null,
);
// Tracks mount status so async ``fetch`` calls that resolve after
// unmount don't call ``onSuccess`` / ``onError`` on a vanished parent.
const mountedRef = useRef(true);
const cleanup = () => {
if (intervalRef.current) {
clearInterval(intervalRef.current);
intervalRef.current = null;
}
if (messageHandlerRef.current) {
window.removeEventListener('message', messageHandlerRef.current as any);
messageHandlerRef.current = null;
}
window.removeEventListener('message', handleAuthMessage as any);
};
const handleAuthMessage = (event: MessageEvent) => {
@@ -60,7 +48,6 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
if (successGeneric || successProvider) {
completedRef.current = true;
cleanup();
authWindowRef.current = null;
onSuccess({
session_token: event.data.session_token,
user_email:
@@ -70,7 +57,6 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
} else if (errorProvider) {
completedRef.current = true;
cleanup();
authWindowRef.current = null;
onError(
event.data.error || t('modals.uploadDoc.connectors.auth.authFailed'),
);
@@ -80,20 +66,15 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
const handleAuth = async () => {
try {
completedRef.current = false;
// Close any popup left over from a previous click before wiping
// the ref — otherwise the old window keeps living with no
// interval watching it and no listener handling its messages.
if (authWindowRef.current && !authWindowRef.current.closed) {
authWindowRef.current.close();
}
authWindowRef.current = null;
cleanup();
const authResponse = await userService.getConnectorAuthUrl(
provider,
token,
const apiHost = import.meta.env.VITE_API_HOST;
const authResponse = await fetch(
`${apiHost}/api/connectors/auth?provider=${provider}`,
{
headers: { Authorization: `Bearer ${token}` },
},
);
if (!mountedRef.current) return;
if (!authResponse.ok) {
throw new Error(
@@ -102,7 +83,6 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
}
const authData = await authResponse.json();
if (!mountedRef.current) return;
if (!authData.success || !authData.authorization_url) {
throw new Error(
authData.error || t('modals.uploadDoc.connectors.auth.authUrlFailed'),
@@ -117,23 +97,13 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
if (!authWindow) {
throw new Error(t('modals.uploadDoc.connectors.auth.popupBlocked'));
}
authWindowRef.current = authWindow;
messageHandlerRef.current = handleAuthMessage;
window.addEventListener('message', handleAuthMessage as any);
const checkClosed = window.setInterval(() => {
if (authWindow.closed) {
clearInterval(checkClosed);
intervalRef.current = null;
if (messageHandlerRef.current) {
window.removeEventListener(
'message',
messageHandlerRef.current as any,
);
messageHandlerRef.current = null;
}
authWindowRef.current = null;
window.removeEventListener('message', handleAuthMessage as any);
if (!completedRef.current) {
onError(t('modals.uploadDoc.connectors.auth.authCancelled'));
}
@@ -141,7 +111,6 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
}, 1000);
intervalRef.current = checkClosed;
} catch (error) {
if (!mountedRef.current) return;
onError(
error instanceof Error
? error.message
@@ -150,18 +119,6 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
}
};
// Release interval, message listener, and popup on unmount only.
useEffect(() => {
return () => {
mountedRef.current = false;
cleanup();
if (authWindowRef.current && !authWindowRef.current.closed) {
authWindowRef.current.close();
}
authWindowRef.current = null;
};
}, []);
return (
<>
{errorMessage && (
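For reference, the popup side of the handshake handleAuthMessage listens for looks roughly like this. The exact type strings live where successGeneric / successProvider / errorProvider are computed (not visible in this hunk), the query-param names are likewise assumptions, and the sketch assumes the callback page is served from the same origin as the app; only the session_token / user_email / error field names come from the component above:

// OAuth callback page, once the provider redirects back
// (the `type` value and param names are placeholders):
const params = new URLSearchParams(window.location.search);
window.opener?.postMessage(
  {
    type: 'connector_auth_success',
    session_token: params.get('session_token'),
    user_email: params.get('user_email'),
  },
  window.location.origin,
);
window.close();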

View File

@@ -1,6 +1,6 @@
import React, { useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelector, useStore } from 'react-redux';
import { useSelector } from 'react-redux';
import userService from '../api/services/userService';
import ArrowLeft from '../assets/arrow-left.svg';
@@ -14,7 +14,6 @@ import { useLoaderState, useOutsideAlerter } from '../hooks';
import ConfirmationModal from '../modals/ConfirmationModal';
import { ActiveState } from '../models/misc';
import { selectToken } from '../preferences/preferenceSlice';
import type { RootState } from '../store';
import { formatBytes } from '../utils/stringUtils';
import Chunks from './Chunks';
import ContextMenu, { MenuOption } from './ContextMenu';
@@ -65,7 +64,6 @@ const ConnectorTree: React.FC<ConnectorTreeProps> = ({
useState<DirectoryStructure | null>(null);
const [currentPath, setCurrentPath] = useState<string[]>([]);
const token = useSelector(selectToken);
const store = useStore<RootState>();
const [activeMenuId, setActiveMenuId] = useState<string | null>(null);
const menuRefs = useRef<{
[key: string]: React.RefObject<HTMLDivElement | null>;
@@ -83,25 +81,6 @@ const ConnectorTree: React.FC<ConnectorTreeProps> = ({
const [syncDone, setSyncDone] = useState<boolean>(false);
const [syncConfirmationModal, setSyncConfirmationModal] =
useState<ActiveState>('INACTIVE');
const mountedRef = useRef(true);
const syncUnsubscribeRef = useRef<(() => void) | null>(null);
// Holds the 5-minute SSE-wait timer so the unmount cleanup can clear
// it — otherwise the timer fires up to 5 min after unmount and
// resolves an abandoned Promise.
const syncTimerRef = useRef<number | null>(null);
useEffect(
() => () => {
mountedRef.current = false;
syncUnsubscribeRef.current?.();
syncUnsubscribeRef.current = null;
if (syncTimerRef.current !== null) {
window.clearTimeout(syncTimerRef.current);
syncTimerRef.current = null;
}
},
[],
);
useOutsideAlerter(
searchDropdownRef,
@@ -137,108 +116,67 @@ const ConnectorTree: React.FC<ConnectorTreeProps> = ({
console.log('Sync started successfully:', data.task_id);
setSyncProgress(10);
// The connector worker (``ingest_connector`` in
// ``application/worker.py``) publishes
// ``source.ingest.{queued,completed,failed}`` envelopes keyed on
// ``scope.id == docId`` (sync mode reuses the source uuid). Wait
// on the bounded ``notifications.recentEvents`` ring for a
// terminal envelope rather than polling ``/task_status``.
// Mirrors FileTree's slice-walking pattern, including the
// ``opStartedAt`` guard so a stale terminal event from a prior
// sync of this same source can't short-circuit the current op.
const opStartedAt = Date.now();
const terminalFromSse = (): 'completed' | 'failed' | null => {
const events = store.getState().notifications.recentEvents;
for (const event of events) {
if (event.scope?.id !== docId) continue;
const ts = event.ts ? Date.parse(event.ts) : NaN;
if (!Number.isFinite(ts) || ts < opStartedAt) continue;
if (event.type === 'source.ingest.completed') return 'completed';
if (event.type === 'source.ingest.failed') return 'failed';
}
return null;
};
const MAX_WAIT_MS = 5 * 60_000;
const terminal = await new Promise<
'completed' | 'failed' | 'timeout' | 'unmounted'
>((resolve) => {
// Cover the race where the event landed between the POST
// returning and the subscribe call.
const initial = terminalFromSse();
if (initial) {
resolve(initial);
return;
}
if (!mountedRef.current) {
resolve('unmounted');
return;
}
let settled = false;
const finish = (
value: 'completed' | 'failed' | 'timeout' | 'unmounted',
) => {
if (settled) return;
settled = true;
if (syncTimerRef.current !== null) {
window.clearTimeout(syncTimerRef.current);
syncTimerRef.current = null;
}
if (syncUnsubscribeRef.current) {
syncUnsubscribeRef.current();
syncUnsubscribeRef.current = null;
}
resolve(value);
};
syncTimerRef.current = window.setTimeout(
() => finish('timeout'),
MAX_WAIT_MS,
);
syncUnsubscribeRef.current = store.subscribe(() => {
if (!mountedRef.current) {
finish('unmounted');
return;
}
const next = terminalFromSse();
if (next) finish(next);
});
});
if (terminal === 'timeout') {
console.error('Sync timed out waiting for SSE terminal');
} else if (terminal === 'unmounted') {
return;
}
if (terminal === 'completed') {
// The "no files downloaded" early-return path publishes
// ``completed`` with ``no_changes: true`` — treated as success
// here; refreshing the directory is cheap and idempotent.
setSyncProgress(100);
console.log('Sync completed successfully');
// Poll task status using userService
const maxAttempts = 30;
const pollInterval = 2000;
for (let attempt = 0; attempt < maxAttempts; attempt++) {
try {
const refreshResponse = await userService.getDirectoryStructure(
docId,
const statusResponse = await userService.getTaskStatus(
data.task_id,
token,
);
const refreshData = await refreshResponse.json();
if (refreshData && refreshData.directory_structure) {
setDirectoryStructure(refreshData.directory_structure);
setCurrentPath([]);
}
if (refreshData && refreshData.provider) {
setSourceProvider(refreshData.provider);
const statusData = await statusResponse.json();
console.log(
`Task status (attempt ${attempt + 1}):`,
statusData.status,
);
if (statusData.status === 'SUCCESS') {
setSyncProgress(100);
console.log('Sync completed successfully');
// Refresh directory structure
try {
const refreshResponse = await userService.getDirectoryStructure(
docId,
token,
);
const refreshData = await refreshResponse.json();
if (refreshData && refreshData.directory_structure) {
setDirectoryStructure(refreshData.directory_structure);
setCurrentPath([]);
}
if (refreshData && refreshData.provider) {
setSourceProvider(refreshData.provider);
}
setSyncDone(true);
setTimeout(() => setSyncDone(false), 5000);
} catch (err) {
console.error('Error refreshing directory structure:', err);
}
break;
} else if (statusData.status === 'FAILURE') {
console.error('Sync task failed:', statusData.result);
break;
} else if (statusData.status === 'PROGRESS') {
const progress = Number(
statusData.result && statusData.result.current != null
? statusData.result.current
: statusData.meta && statusData.meta.current != null
? statusData.meta.current
: 0,
);
setSyncProgress(Math.max(10, progress));
}
setSyncDone(true);
setTimeout(() => setSyncDone(false), 5000);
} catch (err) {
console.error('Error refreshing directory structure:', err);
await new Promise((resolve) => setTimeout(resolve, pollInterval));
} catch (error) {
console.error('Error polling task status:', error);
break;
}
} else if (terminal === 'failed') {
console.error('Sync task failed (per SSE)');
}
} else {
console.error('Sync failed:', data.error);

View File

@@ -1,6 +1,5 @@
import React, { useState, useEffect, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import userService from '../api/services/userService';
import { formatBytes } from '../utils/stringUtils';
import { formatDate } from '../utils/dateTimeUtils';
import {
@@ -23,7 +22,6 @@ import {
TableHeader,
TableCell,
} from './Table';
import { useDebouncedCallback } from '../hooks';
interface CloudFile {
id: string;
@@ -102,6 +100,7 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
const [activeTab, setActiveTab] = useState<'my_files' | 'shared'>('my_files');
const scrollContainerRef = useRef<HTMLDivElement>(null);
const searchTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
const isFolder = (file: CloudFile) => {
@@ -127,6 +126,7 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
setIsLoading(true);
const apiHost = import.meta.env.VITE_API_HOST;
if (!pageToken) {
setFiles([]);
}
@@ -141,11 +141,15 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
search_query: searchQuery,
shared: shared,
};
const response = await userService.getConnectorFiles(
body,
token,
controller.signal,
);
const response = await fetch(`${apiHost}/api/connectors/files`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`,
},
body: JSON.stringify(body),
signal: controller.signal,
});
const data = await response.json();
if (data.success) {
@@ -183,9 +187,20 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
}
try {
const validateResponse = await userService.validateConnectorSession(
provider,
token,
const apiHost = import.meta.env.VITE_API_HOST;
const validateResponse = await fetch(
`${apiHost}/api/connectors/validate-session`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`,
},
body: JSON.stringify({
provider: provider,
session_token: sessionToken,
}),
},
);
if (!validateResponse.ok) {
@@ -277,26 +292,32 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
useEffect(() => {
return () => {
if (searchTimeoutRef.current) {
clearTimeout(searchTimeoutRef.current);
}
abortControllerRef.current?.abort();
};
}, []);
const debouncedLoadFiles = useDebouncedCallback((query: string) => {
const sessionToken = getSessionToken(provider);
if (sessionToken) {
loadCloudFiles(
sessionToken,
currentFolderId,
undefined,
query,
activeTab === 'shared' && !currentFolderId,
);
}
}, 300);
const handleSearchChange = (query: string) => {
setSearchQuery(query);
debouncedLoadFiles(query);
if (searchTimeoutRef.current) {
clearTimeout(searchTimeoutRef.current);
}
searchTimeoutRef.current = setTimeout(() => {
const sessionToken = getSessionToken(provider);
if (sessionToken) {
loadCloudFiles(
sessionToken,
currentFolderId,
undefined,
query,
activeTab === 'shared' && !currentFolderId,
);
}
}, 300);
};
const handleFolderClick = (folderId: string, folderName: string) => {
@@ -403,14 +424,23 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
onDisconnect={() => {
const sessionToken = getSessionToken(provider);
if (sessionToken) {
userService
.disconnectConnector(provider, sessionToken, token)
.catch((err) =>
console.error(
`Error disconnecting from ${getProviderConfig(provider).displayName}:`,
err,
),
);
const apiHost = import.meta.env.VITE_API_HOST;
fetch(`${apiHost}/api/connectors/disconnect`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`,
},
body: JSON.stringify({
provider: provider,
session_token: sessionToken,
}),
}).catch((err) =>
console.error(
`Error disconnecting from ${getProviderConfig(provider).displayName}:`,
err,
),
);
}
removeSessionToken(provider);

View File

@@ -1,8 +1,7 @@
import React, { useState, useRef, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelector, useStore } from 'react-redux';
import { useSelector } from 'react-redux';
import { selectToken } from '../preferences/preferenceSlice';
import type { RootState } from '../store';
import { formatBytes } from '../utils/stringUtils';
import Chunks from './Chunks';
import ContextMenu, { MenuOption } from './ContextMenu';
@@ -57,7 +56,6 @@ const FileTree: React.FC<FileTreeProps> = ({
onBackToDocuments,
}) => {
const { t } = useTranslation();
const store = useStore<RootState>();
const [loading, setLoading] = useLoaderState(true, 500);
const [error, setError] = useState<string | null>(null);
const [directoryStructure, setDirectoryStructure] =
@@ -97,25 +95,6 @@ const FileTree: React.FC<FileTreeProps> = ({
const opQueueRef = useRef<QueuedOperation[]>([]);
const processingRef = useRef(false);
const [queueLength, setQueueLength] = useState(0);
const mountedRef = useRef(true);
const waitUnsubscribeRef = useRef<(() => void) | null>(null);
// Holds the 5-minute SSE-wait timer so the unmount cleanup can clear
// it — otherwise the timer fires up to 5 min after unmount and
// resolves an abandoned Promise.
const waitTimerRef = useRef<number | null>(null);
useEffect(
() => () => {
mountedRef.current = false;
waitUnsubscribeRef.current?.();
waitUnsubscribeRef.current = null;
if (waitTimerRef.current !== null) {
window.clearTimeout(waitTimerRef.current);
waitTimerRef.current = null;
}
},
[],
);
useOutsideAlerter(
searchDropdownRef,
@@ -334,103 +313,47 @@ const FileTree: React.FC<FileTreeProps> = ({
}
console.log('Reingest task started:', result.reingest_task_id);
// SSE is the sole driver here. The backend's
// ``reingest_source_worker`` publishes ``source.ingest.*``
// keyed on the resolved ``source_id`` (the
// ``manage_source_files`` route returns it explicitly so we
// can match without consulting any slice). Subscribe to the
// store and resolve when a terminal event tagged with our
// source lands in ``notifications.recentEvents``. Re-checking
// on every dispatch (rather than polling on a timer) avoids
// races where a terminal event could roll off the bounded ring
// before the next tick observes it in chatty sessions.
const reingestSourceId: string | undefined = result.source_id;
// Cutoff so we don't pick up terminal events from a *previous*
// reingest of the same source — the backend's
// ``source.ingest.*`` payload doesn't carry a Celery task id,
// so source_id alone is ambiguous when ops repeat.
const opStartedAt = Date.now();
const MAX_WAIT_MS = 5 * 60_000;
const maxAttempts = 30;
const pollInterval = 2000;
const terminalFromSse = (): 'completed' | 'failed' | null => {
if (!reingestSourceId) return null;
const events = store.getState().notifications.recentEvents;
for (const event of events) {
if (event.scope?.id !== reingestSourceId) continue;
const ts = event.ts ? Date.parse(event.ts) : NaN;
if (!Number.isFinite(ts) || ts < opStartedAt) continue;
if (event.type === 'source.ingest.completed') return 'completed';
if (event.type === 'source.ingest.failed') return 'failed';
}
return null;
};
for (let attempt = 0; attempt < maxAttempts; attempt++) {
try {
const statusResponse = await userService.getTaskStatus(
result.reingest_task_id,
token,
);
const statusData = await statusResponse.json();
const refreshStructure = async (): Promise<boolean> => {
const structureResponse = await userService.getDirectoryStructure(
docId,
token,
);
const structureData = await structureResponse.json();
if (!mountedRef.current) return false;
if (structureData && structureData.directory_structure) {
setDirectoryStructure(structureData.directory_structure);
currentOpRef.current = null;
return true;
}
return false;
};
console.log(
`Task status (attempt ${attempt + 1}):`,
statusData.status,
);
const terminal = await new Promise<
'completed' | 'failed' | 'timeout' | 'unmounted'
>((resolve) => {
if (!mountedRef.current) {
resolve('unmounted');
return;
}
// Cover the race where the terminal event landed between
// the POST returning and the subscribe call.
const initial = terminalFromSse();
if (initial) {
resolve(initial);
return;
}
const timer = window.setTimeout(() => {
waitUnsubscribeRef.current?.();
waitUnsubscribeRef.current = null;
waitTimerRef.current = null;
resolve('timeout');
}, MAX_WAIT_MS);
waitTimerRef.current = timer;
waitUnsubscribeRef.current = store.subscribe(() => {
if (!mountedRef.current) {
window.clearTimeout(timer);
waitTimerRef.current = null;
waitUnsubscribeRef.current?.();
waitUnsubscribeRef.current = null;
resolve('unmounted');
return;
if (statusData.status === 'SUCCESS') {
console.log('Task completed successfully');
const structureResponse = await userService.getDirectoryStructure(
docId,
token,
);
const structureData = await structureResponse.json();
if (structureData && structureData.directory_structure) {
setDirectoryStructure(structureData.directory_structure);
currentOpRef.current = null;
return true;
}
break;
} else if (statusData.status === 'FAILURE') {
console.error('Task failed');
break;
}
const next = terminalFromSse();
if (next) {
window.clearTimeout(timer);
waitTimerRef.current = null;
waitUnsubscribeRef.current?.();
waitUnsubscribeRef.current = null;
resolve(next);
}
});
});
if (!mountedRef.current) return false;
if (terminal === 'completed') {
if (await refreshStructure()) return true;
} else if (terminal === 'failed') {
console.error('Reingest task failed (per SSE)');
} else if (terminal === 'unmounted') {
return false;
} else {
console.error('Reingest timed out waiting for SSE terminal');
await new Promise((resolve) => setTimeout(resolve, pollInterval));
} catch (error) {
console.error('Error polling task status:', error);
break;
}
}
} else {
throw new Error(
@@ -451,7 +374,7 @@ const FileTree: React.FC<FileTreeProps> = ({
? 'delete directory'
: 'delete file(s)';
console.error(`Error ${actionText}:`, error);
if (mountedRef.current) setError(`Failed to ${errorText}`);
setError(`Failed to ${errorText}`);
} finally {
currentOpRef.current = null;
}

View File

@@ -2,7 +2,6 @@ import React, { useState, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import useDrivePicker from 'react-google-drive-picker';
import userService from '../api/services/userService';
import ConnectorAuth from './ConnectorAuth';
import {
getSessionToken,
@@ -200,11 +199,18 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
const sessionToken = getSessionToken('google_drive');
if (sessionToken) {
try {
await userService.disconnectConnector(
'google_drive',
sessionToken,
token,
);
const apiHost = import.meta.env.VITE_API_HOST;
await fetch(`${apiHost}/api/connectors/disconnect`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`,
},
body: JSON.stringify({
provider: 'google_drive',
session_token: sessionToken,
}),
});
} catch (err) {
console.error('Error disconnecting from Google Drive:', err);
}

View File

@@ -3,7 +3,7 @@ import { createPortal } from 'react-dom';
import { LoaderCircle, Mic, Square } from 'lucide-react';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { useDispatch, useSelector, useStore } from 'react-redux';
import { useDispatch, useSelector } from 'react-redux';
import endpoints from '../api/endpoints';
import userService from '../api/services/userService';
@@ -28,7 +28,6 @@ import {
selectSelectedDocs,
selectToken,
} from '../preferences/preferenceSlice';
import type { RootState } from '../store';
import Upload from '../upload/Upload';
import { getOS, isTouchDevice } from '../utils/browserUtils';
import SourcesPopup from './SourcesPopup';
@@ -317,7 +316,6 @@ export default function MessageInput({
const attachments = useSelector(selectAttachments);
const dispatch = useDispatch();
const store = useStore<RootState>();
const mediaStreamRef = useRef<MediaStream | null>(null);
const audioContextRef = useRef<AudioContext | null>(null);
const audioSourceNodeRef = useRef<MediaStreamAudioSourceNode | null>(null);
@@ -412,86 +410,6 @@ export default function MessageInput({
};
}, []);
// Recover the race where attachment.* SSE arrives before the upload
// XHR's onload sets ``attachmentId``: walk recentEvents and watchdog
// the row so it can't stay stuck on 'processing'. Mirrors
// Upload.tsx's ``trackTraining``.
const trackAttachment = useCallback(
(clientId: string, attachmentId: string) => {
let handled = false;
const check = () => {
const state = store.getState();
const row = state.upload.attachments.find((a) => a.id === clientId);
if (!row) return true; // removed by user; stop tracking
if (row.status === 'completed' || row.status === 'failed') {
handled = true;
return true;
}
for (const event of state.notifications.recentEvents) {
if (event.scope?.id !== attachmentId) continue;
if (event.type === 'attachment.completed') {
const payload = (event.payload || {}) as Record<string, unknown>;
const tokenCount = Number(payload.token_count);
handled = true;
dispatch(
updateAttachment({
id: clientId,
updates: {
status: 'completed',
progress: 100,
...(Number.isFinite(tokenCount)
? { token_count: tokenCount }
: {}),
},
}),
);
return true;
}
if (event.type === 'attachment.failed') {
handled = true;
dispatch(
updateAttachment({
id: clientId,
updates: { status: 'failed' },
}),
);
return true;
}
}
return false;
};
if (check()) return;
const MAX_WAIT_MS = 5 * 60_000;
let unsubscribe: (() => void) | null = null;
const timer = window.setTimeout(() => {
unsubscribe?.();
if (!handled) {
handled = true;
console.warn(
'trackAttachment: timed out waiting for terminal SSE',
clientId,
attachmentId,
);
dispatch(
updateAttachment({
id: clientId,
updates: { status: 'failed' },
}),
);
}
}, MAX_WAIT_MS);
unsubscribe = store.subscribe(() => {
if (check()) {
window.clearTimeout(timer);
unsubscribe?.();
}
});
},
[dispatch, store],
);
const uploadFiles = useCallback(
(files: File[]) => {
if (!files || files.length === 0) return;
@@ -592,19 +510,11 @@ export default function MessageInput({
id: uiId,
updates: {
taskId: task.task_id,
// Stash the server's attachment id so SSE
// ``attachment.*`` events can match this
// row by ``scope.id`` and drive the
// per-attachment push-fresh poll gate.
attachmentId: task.attachment_id,
status: 'processing',
progress: 10,
},
}),
);
if (task.attachment_id) {
trackAttachment(uiId, task.attachment_id);
}
return;
}
@@ -635,15 +545,11 @@ export default function MessageInput({
id: uiId,
updates: {
taskId: t.task_id,
attachmentId: t.attachment_id,
status: 'processing',
progress: 10,
},
}),
);
if (t.attachment_id) {
trackAttachment(uiId, t.attachment_id);
}
} else {
dispatch(
updateAttachment({
@@ -677,15 +583,11 @@ export default function MessageInput({
id: uiId,
updates: {
taskId: response.task_id,
attachmentId: response.attachment_id,
status: 'processing',
progress: 10,
},
}),
);
if (response.attachment_id) {
trackAttachment(uiId, response.attachment_id);
}
}
} else {
console.warn(
@@ -812,15 +714,11 @@ export default function MessageInput({
id: uniqueId,
updates: {
taskId: response.task_id,
attachmentId: response.attachment_id,
status: 'processing',
progress: 10,
},
}),
);
if (response.attachment_id) {
trackAttachment(uniqueId, response.attachment_id);
}
} else {
// If backend returned tasks[] for single-file, handle gracefully:
if (
@@ -832,15 +730,11 @@ export default function MessageInput({
id: uniqueId,
updates: {
taskId: response.tasks[0].task_id,
attachmentId: response.tasks[0].attachment_id,
status: 'processing',
progress: 10,
},
}),
);
if (response.tasks[0].attachment_id) {
trackAttachment(uniqueId, response.tasks[0].attachment_id);
}
} else {
dispatch(
updateAttachment({
@@ -887,7 +781,7 @@ export default function MessageInput({
xhr.send(formData);
});
},
[dispatch, token, trackAttachment],
[dispatch, token],
);
const handleFileAttachment = (e: React.ChangeEvent<HTMLInputElement>) => {
@@ -922,6 +816,65 @@ export default function MessageInput({
accept: FILE_UPLOAD_ACCEPT,
});
useEffect(() => {
const checkTaskStatus = () => {
const processingAttachments = attachments.filter(
(att) => att.status === 'processing' && att.taskId,
);
processingAttachments.forEach((attachment) => {
userService
.getTaskStatus(attachment.taskId!, null)
.then((data) => data.json())
.then((data) => {
if (data.status === 'SUCCESS') {
dispatch(
updateAttachment({
id: attachment.id,
updates: {
status: 'completed',
progress: 100,
id: data.result?.attachment_id,
token_count: data.result?.token_count,
},
}),
);
} else if (data.status === 'FAILURE') {
dispatch(
updateAttachment({
id: attachment.id,
updates: { status: 'failed' },
}),
);
} else if (data.status === 'PROGRESS' && data.result?.current) {
dispatch(
updateAttachment({
id: attachment.id,
updates: { progress: data.result.current },
}),
);
}
})
.catch(() => {
dispatch(
updateAttachment({
id: attachment.id,
updates: { status: 'failed' },
}),
);
});
});
};
const interval = setInterval(() => {
if (attachments.some((att) => att.status === 'processing')) {
checkTaskStatus();
}
}, 2000);
return () => clearInterval(interval);
}, [attachments, dispatch]);
const handleInput = useCallback(() => {
if (inputRef.current) {
if (window.innerWidth < 350) inputRef.current.style.height = 'auto';

View File

@@ -2,7 +2,8 @@ import { useState, useRef, useEffect } from 'react';
import Speaker from '../assets/speaker.svg?react';
import Stopspeech from '../assets/stopspeech.svg?react';
import LoadingIcon from '../assets/Loading.svg?react'; // Add a loading icon SVG here
import userService from '../api/services/userService';
const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
let currentlyPlayingAudio: {
audio: HTMLAudioElement;
@@ -113,11 +114,12 @@ export default function SpeakButton({ text }: { text: string }) {
},
};
const response = await userService.textToSpeech(
text,
null,
abortController.signal,
);
const response = await fetch(apiHost + '/api/tts', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ text }),
signal: abortController.signal,
});
const data = await response.json();
abortControllerRef.current = null;

View File

@@ -5,54 +5,41 @@ import { useDispatch, useSelector } from 'react-redux';
import CheckCircleFilled from '../assets/check-circle-filled.svg';
import ChevronDown from '../assets/chevron-down.svg';
import WarnIcon from '../assets/warn.svg';
import {
dismissUploadTask,
selectUploadTasks,
type UploadTask,
} from '../upload/uploadSlice';
import { dismissUploadTask, selectUploadTasks } from '../upload/uploadSlice';
const PROGRESS_RADIUS = 10;
const PROGRESS_CIRCUMFERENCE = 2 * Math.PI * PROGRESS_RADIUS;
const IN_PROGRESS_STATUSES = new Set<UploadTask['status']>([
'preparing',
'uploading',
'training',
]);
/**
* Single merged upload card — Google-Drive style. Multiple in-flight
* uploads share one toast with a list of rows; the header reflects
* the *primary* task's status (the newest still-running task, or the
* newest task overall if all are terminal). Per-task progress lives
* on each row.
*
* Dismissal: the header X dismisses every visible task at once
* (mirrors the GDrive panel close — keeps the surface tidy without
* per-row controls). The chevron collapses the row list.
*/
export default function UploadToast() {
const [collapsed, setCollapsed] = useState(false);
const [collapsedTasks, setCollapsedTasks] = useState<Record<string, boolean>>(
{},
);
const toggleTaskCollapse = (taskId: string) => {
setCollapsedTasks((prev) => ({
...prev,
[taskId]: !prev[taskId],
}));
};
const { t } = useTranslation();
const dispatch = useDispatch();
const uploadTasks = useSelector(selectUploadTasks);
const visibleTasks = uploadTasks.filter((task) => !task.dismissed);
if (visibleTasks.length === 0) return null;
// Pick the task that drives the header status: prefer a still-
// running task (most-recent first since the slice unshifts), and
// fall back to whatever's most-recent if everything is terminal.
const primaryTask =
visibleTasks.find((task) => IN_PROGRESS_STATUSES.has(task.status)) ??
visibleTasks[0];
const headerLabel = getStatusHeading(primaryTask.status, t);
const dismissAll = () => {
for (const task of visibleTasks) {
dispatch(dismissUploadTask(task.id));
const getStatusHeading = (status: string) => {
switch (status) {
case 'preparing':
return t('modals.uploadDoc.progress.wait');
case 'uploading':
return t('modals.uploadDoc.progress.upload');
case 'training':
return t('modals.uploadDoc.progress.upload');
case 'completed':
return t('modals.uploadDoc.progress.completed');
case 'failed':
return t('modals.uploadDoc.progress.failed');
default:
return t('modals.uploadDoc.progress.preparing');
}
};
@@ -60,205 +47,180 @@ export default function UploadToast() {
<div
className="fixed right-4 bottom-4 z-50 flex max-w-md flex-col gap-2"
onMouseDown={(e) => e.stopPropagation()}
role="status"
aria-live="polite"
aria-atomic="false"
>
<div
className={`border-border bg-card w-[271px] overflow-hidden rounded-2xl border shadow-[0px_24px_48px_0px_#00000029] transition-all duration-300`}
>
<div
className={`flex items-center justify-between px-4 py-3 ${
primaryTask.status !== 'failed'
? 'bg-accent/50 dark:bg-muted'
: 'bg-destructive/10 dark:bg-destructive/10'
}`}
>
<h3 className="font-inter dark:text-foreground text-[14px] leading-[16.5px] font-medium text-black">
{headerLabel}
</h3>
<div className="flex items-center gap-1">
<button
type="button"
onClick={() => setCollapsed((prev) => !prev)}
aria-label={
collapsed
? t('modals.uploadDoc.progress.expandDetails')
: t('modals.uploadDoc.progress.collapseDetails')
}
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
>
<img
src={ChevronDown}
alt=""
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${
collapsed ? 'rotate-180' : ''
}`}
/>
</button>
<button
type="button"
onClick={dismissAll}
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
aria-label={t('modals.uploadDoc.progress.dismiss')}
>
<svg
width="16"
height="16"
viewBox="0 0 24 24"
fill="none"
xmlns="http://www.w3.org/2000/svg"
className="h-4 w-4"
>
<path
d="M18 6L6 18"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M6 6L18 18"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
</button>
</div>
</div>
{uploadTasks
.filter((task) => !task.dismissed)
.map((task) => {
const shouldShowProgress = [
'preparing',
'uploading',
'training',
].includes(task.status);
const rawProgress = Math.min(Math.max(task.progress ?? 0, 0), 100);
const formattedProgress = Math.round(rawProgress);
const progressOffset =
PROGRESS_CIRCUMFERENCE * (1 - rawProgress / 100);
const isCollapsed = collapsedTasks[task.id] ?? false;
<div
className="grid overflow-hidden transition-[grid-template-rows] duration-300 ease-out"
style={{ gridTemplateRows: collapsed ? '0fr' : '1fr' }}
>
<div
className={`min-h-0 overflow-hidden transition-opacity duration-300 ${
collapsed ? 'opacity-0' : 'opacity-100'
}`}
>
<ul className="max-h-72 overflow-y-auto">
{visibleTasks.map((task) => (
<UploadRow key={task.id} task={task} t={t} />
))}
</ul>
</div>
</div>
</div>
return (
<div
key={task.id}
className={`border-border bg-card w-[271px] overflow-hidden rounded-2xl border shadow-[0px_24px_48px_0px_#00000029] transition-all duration-300`}
>
<div className="flex flex-col">
<div
className={`flex items-center justify-between px-4 py-3 ${
task.status !== 'failed'
? 'bg-accent/50 dark:bg-muted'
: 'bg-destructive/10 dark:bg-destructive/10'
}`}
>
<h3 className="font-inter dark:text-foreground text-[14px] leading-[16.5px] font-medium text-black">
{getStatusHeading(task.status)}
</h3>
<div className="flex items-center gap-1">
<button
type="button"
onClick={() => toggleTaskCollapse(task.id)}
aria-label={
isCollapsed
? t('modals.uploadDoc.progress.expandDetails')
: t('modals.uploadDoc.progress.collapseDetails')
}
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
>
<img
src={ChevronDown}
alt=""
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${
isCollapsed ? 'rotate-180' : ''
}`}
/>
</button>
<button
type="button"
onClick={() => dispatch(dismissUploadTask(task.id))}
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
aria-label={t('modals.uploadDoc.progress.dismiss')}
>
<svg
width="16"
height="16"
viewBox="0 0 24 24"
fill="none"
xmlns="http://www.w3.org/2000/svg"
className="h-4 w-4"
>
<path
d="M18 6L6 18"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M6 6L18 18"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
</button>
</div>
</div>
<div
className="grid overflow-hidden transition-[grid-template-rows] duration-300 ease-out"
style={{ gridTemplateRows: isCollapsed ? '0fr' : '1fr' }}
>
<div
className={`min-h-0 overflow-hidden transition-opacity duration-300 ${
isCollapsed ? 'opacity-0' : 'opacity-100'
}`}
>
<div className="flex items-center justify-between px-5 py-3">
<p
className="font-inter dark:text-muted-foreground max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black"
title={task.fileName}
>
{task.fileName}
</p>
<div className="flex items-center gap-2">
{shouldShowProgress && (
<svg
width="24"
height="24"
viewBox="0 0 24 24"
className="h-6 w-6 shrink-0 text-[#7D54D1]"
role="progressbar"
aria-valuemin={0}
aria-valuemax={100}
aria-valuenow={formattedProgress}
aria-label={t(
'modals.uploadDoc.progress.uploadProgress',
{
progress: formattedProgress,
},
)}
>
<circle
className="text-muted dark:text-muted-foreground/30"
stroke="currentColor"
strokeWidth="2"
cx="12"
cy="12"
r={PROGRESS_RADIUS}
fill="none"
/>
<circle
className="text-[#7D54D1]"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeDasharray={PROGRESS_CIRCUMFERENCE}
strokeDashoffset={progressOffset}
cx="12"
cy="12"
r={PROGRESS_RADIUS}
fill="none"
transform="rotate(-90 12 12)"
/>
</svg>
)}
{task.status === 'completed' && (
<img
src={CheckCircleFilled}
alt=""
className="h-6 w-6 shrink-0"
aria-hidden="true"
/>
)}
{task.status === 'failed' && (
<img
src={WarnIcon}
alt=""
className="h-6 w-6 shrink-0"
aria-hidden="true"
/>
)}
</div>
</div>
{task.status === 'failed' && task.errorMessage && (
<span className="block px-5 pb-3 text-xs text-red-500">
{task.errorMessage}
</span>
)}
</div>
</div>
</div>
</div>
);
})}
</div>
);
}
function UploadRow({
task,
t,
}: {
task: UploadTask;
t: ReturnType<typeof useTranslation>['t'];
}) {
const showProgress = IN_PROGRESS_STATUSES.has(task.status);
const rawProgress = Math.min(Math.max(task.progress ?? 0, 0), 100);
const formattedProgress = Math.round(rawProgress);
const progressOffset = PROGRESS_CIRCUMFERENCE * (1 - rawProgress / 100);
return (
<li className="border-border/50 border-b last:border-b-0">
<div className="flex items-center justify-between px-5 py-3">
<p
className="font-inter dark:text-muted-foreground max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black"
title={task.fileName}
>
{task.fileName}
</p>
<div className="flex items-center gap-2">
{showProgress && (
<svg
width="24"
height="24"
viewBox="0 0 24 24"
className="h-6 w-6 shrink-0 text-[#7D54D1]"
role="progressbar"
aria-valuemin={0}
aria-valuemax={100}
aria-valuenow={formattedProgress}
aria-label={t('modals.uploadDoc.progress.uploadProgress', {
progress: formattedProgress,
})}
>
<circle
className="text-muted dark:text-muted-foreground/30"
stroke="currentColor"
strokeWidth="2"
cx="12"
cy="12"
r={PROGRESS_RADIUS}
fill="none"
/>
<circle
className="text-[#7D54D1]"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeDasharray={PROGRESS_CIRCUMFERENCE}
strokeDashoffset={progressOffset}
cx="12"
cy="12"
r={PROGRESS_RADIUS}
fill="none"
transform="rotate(-90 12 12)"
/>
</svg>
)}
{task.status === 'completed' && (
<img
src={CheckCircleFilled}
alt=""
className="h-6 w-6 shrink-0"
aria-hidden="true"
/>
)}
{task.status === 'failed' && (
<img
src={WarnIcon}
alt=""
className="h-6 w-6 shrink-0"
aria-hidden="true"
/>
)}
</div>
</div>
{task.status === 'failed' &&
(task.tokenLimitReached || task.errorMessage) && (
<span className="block px-5 pb-3 text-xs text-red-500">
{task.tokenLimitReached
? t('modals.uploadDoc.progress.tokenLimit')
: task.errorMessage}
</span>
)}
</li>
);
}
function getStatusHeading(
status: UploadTask['status'],
t: ReturnType<typeof useTranslation>['t'],
): string {
switch (status) {
case 'preparing':
return t('modals.uploadDoc.progress.wait');
case 'uploading':
case 'training':
return t('modals.uploadDoc.progress.upload');
case 'completed':
return t('modals.uploadDoc.progress.completed');
case 'failed':
return t('modals.uploadDoc.progress.failed');
default:
return t('modals.uploadDoc.progress.preparing');
}
}
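A quick illustration of the header-selection rule from the component comment above (tasks abbreviated; only status matters here):

const tasks = [
  { id: 'b', status: 'completed' }, // newest first, since the slice unshifts
  { id: 'a', status: 'training' },
] as UploadTask[];

// The still-running task wins the header even though a newer one finished:
const primary =
  tasks.find((t) => IN_PROGRESS_STATUSES.has(t.status)) ?? tasks[0];
// primary.id === 'a'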

View File

@@ -322,7 +322,7 @@ export default function Conversation() {
isSplitArtifactOpen ? 'w-[60%] px-6' : 'w-full'
}`}
>
<div className="relative min-h-0 flex-1">
<div className="relative min-h-0 flex-1 ">
<ConversationMessages
handleQuestion={handleQuestion}
handleQuestionSubmission={handleQuestionSubmission}

View File

@@ -132,8 +132,6 @@ const ConversationBubble = forwardRef<
}, [message]);
const handleEditClick = () => {
if (!editInputBox.trim() || editInputBox.trim() === (message ?? '').trim())
return;
setIsEditClicked(false);
handleUpdatedQuestionSubmission?.(editInputBox, true, questionNumber);
};
@@ -244,12 +242,8 @@ const ConversationBubble = forwardRef<
{t('conversation.edit.cancel')}
</button>
<button
className="bg-primary not-disabled:hover:bg-primary/90 not-disabled:dark:hover:bg-primary/90 disabled:bg-primary/30 rounded-full px-4 py-2 text-sm font-medium text-white transition-colors disabled:cursor-not-allowed"
className="bg-primary hover:bg-primary/90 dark:hover:bg-primary/90 rounded-full px-4 py-2 text-sm font-medium text-white transition-colors"
onClick={handleEditClick}
disabled={
!editInputBox.trim() ||
editInputBox.trim() === (message ?? '').trim()
}
>
{t('conversation.edit.update')}
</button>

View File

@@ -64,10 +64,7 @@ export default function ConversationTile({
}
function handleSaveConversation(changedConversation: ConversationProps) {
if (
changedConversation.name.trim().length &&
changedConversation.name.trim() !== conversation.name.trim()
) {
if (changedConversation.name.trim().length) {
onSave(changedConversation);
setIsEdit(false);
} else {

View File

@@ -1,155 +1,8 @@
import { baseURL } from '../api/client';
import conversationService from '../api/services/conversationService';
import { Doc } from '../models/misc';
import { Answer, FEEDBACK, RetrievalPayload } from './conversationModels';
import { ToolCallsType } from './types';
/**
* Mirrors the backend's ``_SEQUENCE_NO_RE`` (application/api/answer/
* routes/messages.py) — only non-negative decimal integers are valid
* cursors. Rejects empty strings (Number("") === 0), hex literals,
* exponential notation, and anything else that ``Number(...)`` would
* happily coerce.
*/
const _SEQUENCE_NO_RE = /^\d+$/;
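// e.g. accepts '0' and '42'; rejects '' (Number('') === 0), '0x1f',
// '1e3', and '-1', none of which may rewrite the reconnect cursor.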
/**
* Convert a non-SSE pre-stream HTTP failure (e.g. ``check_usage``'s
* 429 JSON response) into a synthetic typed ``error`` frame so the
* caller's slice sees the actual server message instead of the
* generic "Connection lost" synthesised when the drainer finishes
* with zero events. Returns true if a frame was dispatched and the
* caller should skip ``_drainSseBody`` entirely.
*
* SSE-shaped error bodies (``mimetype="text/event-stream"``) are
* left alone — the drainer parses the typed ``error`` frame they
* carry through the normal path.
*/
async function _handlePreStreamHttpError(
response: Response,
dispatch: (data: string) => void,
): Promise<boolean> {
if (response.ok) return false;
const contentType = (
response.headers.get('content-type') ?? ''
).toLowerCase();
if (contentType.includes('text/event-stream')) return false;
let message: string | null = null;
try {
const text = await response.text();
if (text) {
try {
const parsed = JSON.parse(text);
if (parsed && typeof parsed === 'object') {
message =
(typeof parsed.message === 'string' && parsed.message) ||
(typeof parsed.error === 'string' && parsed.error) ||
(typeof parsed.detail === 'string' && parsed.detail) ||
null;
}
} catch {
message = text.slice(0, 500);
}
}
} catch {
// Body already consumed or unreadable — fall through to the
// status-line fallback below.
}
if (!message) {
message = `HTTP ${response.status} ${response.statusText}`.trim();
}
dispatch(JSON.stringify({ type: 'error', error: message }));
return true;
}
/**
 * Drain an SSE response body, forwarding each ``data:`` line to
 * ``onData`` and tracking the most recent ``id:`` header. Returns
 * when the body ends, the signal aborts, or ``shouldStop()`` returns
 * true (e.g. a terminal ``end``/``error`` event was dispatched —
 * the reconnect endpoint is a live tail that doesn't close on its
 * own past terminal replay).
 */
async function _drainSseBody(
body: ReadableStream<Uint8Array>,
signal: AbortSignal,
onData: (data: string) => void,
onId: (id: number) => void,
shouldStop?: () => boolean,
): Promise<void> {
const reader = body.getReader();
const decoder = new TextDecoder('utf-8');
let buffer = '';
let stoppedEarly = false;
try {
while (true) {
if (signal.aborted) break;
if (shouldStop?.()) {
stoppedEarly = true;
break;
}
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
// Normalise mixed line terminators so a stray CR can't smuggle
// a record boundary inside a JSON payload.
buffer = buffer.replace(/\r\n/g, '\n').replace(/\r/g, '\n');
let boundary = buffer.indexOf('\n\n');
while (boundary !== -1) {
const record = buffer.slice(0, boundary);
buffer = buffer.slice(boundary + 2);
boundary = buffer.indexOf('\n\n');
if (record.length === 0) continue;
const dataParts: string[] = [];
let sawDataField = false;
for (const line of record.split('\n')) {
if (line.length === 0) continue;
if (line.startsWith(':')) continue; // SSE comment / keepalive
const colonIdx = line.indexOf(':');
const field = colonIdx === -1 ? line : line.slice(0, colonIdx);
let value = colonIdx === -1 ? '' : line.slice(colonIdx + 1);
if (value.startsWith(' ')) value = value.slice(1);
if (field === 'id') {
// Strict regex match — empty value, hex, ``-1`` (the
// backend's terminal snapshot-failure synthetic), and
// exponent forms are all rejected so they can't silently
// rewrite the reconnect cursor.
if (_SEQUENCE_NO_RE.test(value)) onId(parseInt(value, 10));
} else if (field === 'data') {
sawDataField = true;
dataParts.push(value);
}
}
if (!sawDataField) continue;
const data = dataParts.join('\n').trim();
if (data.length === 0) continue;
onData(data);
if (shouldStop?.()) {
stoppedEarly = true;
break;
}
}
if (stoppedEarly) break;
}
} finally {
if (stoppedEarly) {
// Ask the runtime to tear the underlying response body down so
// the server-side WSGI thread isn't pinned waiting on
// keepalives. ``releaseLock`` alone leaves the body half-open.
try {
await reader.cancel();
} catch {
// Already errored / closed.
}
}
try {
reader.releaseLock();
} catch {
// Already released.
}
}
}
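// ``_SEQUENCE_NO_RE`` is defined elsewhere in this module. A minimal
// sketch of the guard the comments above assume — anchored decimal
// digits only, no sign, no hex, no exponent, no empty match (the exact
// pattern in the module may differ):
//
//   const _SEQUENCE_NO_RE = /^(?:0|[1-9][0-9]*)$/;
//
// e.g. '17' passes; '', '-1', '0x2a' and '1e3' are all rejected, so
// none of them can move the reconnect cursor.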
export function handleFetchAnswer(
question: string,
signal: AbortSignal,
@@ -290,153 +143,54 @@ export function handleFetchAnswerSteaming(
const headers: Record<string, string> = {};
if (idempotencyKey) headers['Idempotency-Key'] = idempotencyKey;
// Per-stream state used for reconnect-after-disconnect.
let messageId: string | null = null;
let lastEventId: number | null = null;
// The single JSON.parse below feeds both the message_id capture and
// the termination flag — cheaper and stricter than substring
// matching the wire bytes.
let endReceived = false;
const dispatch = (data: string) => {
try {
const parsed = JSON.parse(data);
if (parsed && typeof parsed === 'object') {
if (parsed.type === 'message_id' && parsed.message_id) {
messageId = parsed.message_id;
} else if (parsed.type === 'end' || parsed.type === 'error') {
endReceived = true;
}
}
} catch {
// Not JSON — pass through anyway; the caller handles raw lines.
}
onEvent(new MessageEvent('message', { data }));
};
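  // Illustrative wire frames for the narrowing above (values are
  // examples; shapes inferred from the checks):
  //   data: {"type":"message_id","message_id":"abc123"} → capture for reconnect
  //   data: {"type":"end"}                              → terminal, clean close
  //   data: {"type":"error","error":"…"}                → terminal, failure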
const runInitialPost = async (): Promise<void> => {
const response = await conversationService.answerStream(
payload,
token,
signal,
headers,
);
// Pre-stream HTTP failures with non-SSE bodies (e.g. ``check_usage``
// returning a JSON 429) drain as zero events and would otherwise
// be masked by the generic "Connection lost" synthetic. Convert
// them into a typed ``error`` frame so the real message surfaces.
if (await _handlePreStreamHttpError(response, dispatch)) return;
if (!response.body) throw new Error('No response body');
await _drainSseBody(response.body, signal, dispatch, (id) => {
lastEventId = id;
});
};
  // Reconnect's stop predicate: stop as soon as ``dispatch`` flips
  // ``endReceived`` (a typed ``end`` or ``error`` event was seen — both
  // are terminal per the backend's contract). Without this the
  // live-tail endpoint would emit keepalives indefinitely and the
  // await would never return.
const reconnectShouldStop = () => endReceived;
const runReconnect = async (): Promise<void> => {
if (!messageId) {
throw new Error('reconnect: no message_id captured');
}
const url = new URL(`${baseURL}/api/messages/${messageId}/events`);
if (lastEventId !== null) {
url.searchParams.set('last_event_id', String(lastEventId));
}
const reconnectHeaders: Record<string, string> = {
Accept: 'text/event-stream',
};
if (token) reconnectHeaders.Authorization = `Bearer ${token}`;
// NB: there is no slice consumer for a synthetic ``reconnecting``
// event yet — surface only the underlying network reality. The
// user-visible ``Reconnecting…`` affordance is a follow-up that
// needs ``conversationSlice`` to gain a status case.
const response = await fetch(url.toString(), {
method: 'GET',
headers: reconnectHeaders,
signal,
cache: 'no-store',
});
if (!response.ok || !response.body) {
throw new Error(
`reconnect: HTTP ${response.status} ${response.statusText}`,
);
}
await _drainSseBody(
response.body,
signal,
dispatch,
(id) => {
lastEventId = id;
},
reconnectShouldStop,
);
};
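  // The resulting resume request, assuming a captured message_id of
  // "abc123" and a last seen sequence number of 17 (values illustrative):
  //
  //   GET {baseURL}/api/messages/abc123/events?last_event_id=17
  //   Accept: text/event-stream
  //   Authorization: Bearer <token>   (only when a token is present)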
return new Promise<Answer>((resolve, reject) => {
(async () => {
try {
try {
await runInitialPost();
} catch (initialErr) {
// Mid-stream network failures (WiFi blip, worker recycle,
// body reader rejecting) surface as a thrown error — not a
// graceful EOF. If the stream had already started (we have a
// ``messageId``), fall through to the reconnect path so the
// journal-backed replay can finish what the live socket
// couldn't. Pre-stream failures (auth, DNS, server 4xx/5xx
// before any yield) lack a messageId and bubble up.
if (signal.aborted || !messageId) throw initialErr;
console.warn(
'Initial stream failed mid-flight, attempting reconnect:',
initialErr,
);
}
// The backend ends the stream cleanly with a typed ``end``
// event. Anything else (network drop, gunicorn worker recycle,
// load-balancer timeout) is a "premature close" — try one
// reconnect via the GET /api/messages/<id>/events endpoint.
if (!endReceived && !signal.aborted && messageId) {
try {
await runReconnect();
} catch (reconnectErr) {
console.warn('Stream reconnect failed:', reconnectErr);
conversationService
.answerStream(payload, token, signal, headers)
.then((response) => {
if (!response.body) throw Error('No response body');
let buffer = '';
const reader = response.body.getReader();
const decoder = new TextDecoder('utf-8');
let counterrr = 0;
const processStream = ({
done,
value,
}: ReadableStreamReadResult<Uint8Array>) => {
if (done) return;
counterrr += 1;
const chunk = decoder.decode(value);
buffer += chunk;
const events = buffer.split('\n\n');
buffer = events.pop() ?? '';
for (const event of events) {
if (event.trim().startsWith('data:')) {
const dataLine: string = event
.split('\n')
.map((line: string) => line.replace(/^data:\s?/, ''))
.join('');
const messageEvent = new MessageEvent('message', {
data: dataLine.trim(),
});
onEvent(messageEvent);
}
}
}
// If we never observed a terminal frame (reconnect 4xx/5xx,
// network drop during reconnect, or live tail still silent),
// synthesize one through the same ``dispatch`` path the wire
// events use. Without this the caller's slice never transitions
// out of ``streaming`` and the UI stays in a loading spinner
// forever — the conversationSlice handles ``data.type === 'error'``
// by setting status=failed.
if (!endReceived && !signal.aborted) {
dispatch(
JSON.stringify({
type: 'error',
error:
'Connection lost. The response could not be resumed; please try again.',
}),
);
}
// The handler historically never explicitly resolved with a
// value — callers consume the streamed events via ``onEvent``
// and read final state from Redux. Preserve that contract.
resolve(undefined as unknown as Answer);
} catch (error) {
if (signal.aborted) {
resolve(undefined as unknown as Answer);
return;
}
reader.read().then(processStream).catch(reject);
};
reader.read().then(processStream).catch(reject);
})
.catch((error) => {
console.error('Connection failed:', error);
reject(error);
}
})();
});
});
}
@@ -460,149 +214,52 @@ export function handleSubmitToolActions(
const headers: Record<string, string> = {};
if (idempotencyKey) headers['Idempotency-Key'] = idempotencyKey;
// Tool-action submissions resume against the original
// ``reserved_message_id``, so the backend's continuation path emits
// ``id:`` prefixed records that the legacy parser would silently
// drop. Use the shared SSE drainer — and the same reconnect-on-
// premature-close pattern as ``handleFetchAnswerSteaming`` so a
// dropped tool-action stream can pick up after the disconnect.
let messageId: string | null = null;
let lastEventId: number | null = null;
// Track whether the typed ``end`` event was observed. The single
// JSON.parse below feeds both the message_id capture and the
// termination flag — cheaper and stricter than substring matching
// the wire bytes.
let endReceived = false;
const dispatch = (data: string) => {
try {
const parsed = JSON.parse(data);
if (parsed && typeof parsed === 'object') {
if (parsed.type === 'message_id' && parsed.message_id) {
messageId = parsed.message_id;
} else if (parsed.type === 'end' || parsed.type === 'error') {
// Match the backend's terminal set in
// ``application/streaming/event_replay.py``: the agent's
// catch-all path emits ``error`` *without* a trailing
// ``end``, so treating only ``end`` as terminal would
// trigger a reconnect against an already-finished stream
// and hang on keepalives.
endReceived = true;
}
}
} catch {
// Not JSON — pass through anyway; the caller handles raw lines.
}
onEvent(new MessageEvent('message', { data }));
};
const runInitial = async (): Promise<void> => {
const response = await conversationService.answerStream(
payload,
token,
signal,
headers,
);
// See ``handleFetchAnswerSteaming`` for the rationale: non-SSE
// HTTP failures (e.g. ``check_usage`` 429 JSON) need to be lifted
// into a typed ``error`` frame before they reach the drainer.
if (await _handlePreStreamHttpError(response, dispatch)) return;
if (!response.body) throw new Error('No response body');
await _drainSseBody(response.body, signal, dispatch, (id) => {
lastEventId = id;
});
};
  // Reconnect's stop predicate: stop as soon as ``dispatch`` flips
  // ``endReceived`` (a typed ``end`` or ``error`` event was seen — both
  // are terminal per the backend's contract). Without this the
  // live-tail endpoint would emit keepalives indefinitely and the
  // await would never return.
const reconnectShouldStop = () => endReceived;
const runReconnect = async (): Promise<void> => {
if (!messageId) {
throw new Error('reconnect: no message_id captured');
}
const url = new URL(`${baseURL}/api/messages/${messageId}/events`);
if (lastEventId !== null) {
url.searchParams.set('last_event_id', String(lastEventId));
}
const reconnectHeaders: Record<string, string> = {
Accept: 'text/event-stream',
};
if (token) reconnectHeaders.Authorization = `Bearer ${token}`;
const response = await fetch(url.toString(), {
method: 'GET',
headers: reconnectHeaders,
signal,
cache: 'no-store',
});
if (!response.ok || !response.body) {
throw new Error(
`reconnect: HTTP ${response.status} ${response.statusText}`,
);
}
await _drainSseBody(
response.body,
signal,
dispatch,
(id) => {
lastEventId = id;
},
reconnectShouldStop,
);
};
return new Promise<Answer>((resolve, reject) => {
(async () => {
try {
try {
await runInitial();
} catch (initialErr) {
// Same premature-close handling as
// ``handleFetchAnswerSteaming``: a thrown reader error after
// the message_id frame still warrants one reconnect attempt
// against the journal. Pre-stream failures lack a messageId
// and bubble up.
if (signal.aborted || !messageId) throw initialErr;
console.warn(
'Tool-actions stream failed mid-flight, attempting reconnect:',
initialErr,
);
}
if (!endReceived && !signal.aborted && messageId) {
try {
await runReconnect();
} catch (reconnectErr) {
console.warn('Tool-actions reconnect failed:', reconnectErr);
conversationService
.answerStream(payload, token, signal, headers)
.then((response) => {
if (!response.body) throw Error('No response body');
let buffer = '';
const reader = response.body.getReader();
const decoder = new TextDecoder('utf-8');
const processStream = ({
done,
value,
}: ReadableStreamReadResult<Uint8Array>) => {
if (done) return;
const chunk = decoder.decode(value);
buffer += chunk;
const events = buffer.split('\n\n');
buffer = events.pop() ?? '';
for (const event of events) {
if (event.trim().startsWith('data:')) {
const dataLine: string = event
.split('\n')
.map((line: string) => line.replace(/^data:\s?/, ''))
.join('');
const messageEvent = new MessageEvent('message', {
data: dataLine.trim(),
});
onEvent(messageEvent);
}
}
}
// Synthesize a terminal error if reconnect couldn't deliver one
// (4xx/5xx, network drop, silent live tail). Same reasoning as
// ``handleFetchAnswerSteaming``: the caller's slice only exits
// the streaming state on a terminal frame.
if (!endReceived && !signal.aborted) {
dispatch(
JSON.stringify({
type: 'error',
error:
'Connection lost. The tool response could not be resumed; please try again.',
}),
);
}
resolve(undefined as unknown as Answer);
} catch (error) {
if (signal.aborted) {
resolve(undefined as unknown as Answer);
return;
}
reader.read().then(processStream).catch(reject);
};
reader.read().then(processStream).catch(reject);
})
.catch((error) => {
console.error('Tool actions submission failed:', error);
reject(error);
}
})();
});
});
}

View File

@@ -1,153 +0,0 @@
import { describe, expect, it } from 'vitest';
import reducer, {
applyMessageTail,
setConversation,
} from './conversationSlice';
const baseQuery = {
prompt: 'tell me a poem',
messageId: 'm-1',
messageStatus: 'pending' as const,
};
const seedSlice = () => reducer(undefined, setConversation([baseQuery]));
describe('applyMessageTail — streaming partial', () => {
it('writes response to the query while status is streaming', () => {
const state = seedSlice();
const next = reducer(
state,
applyMessageTail({
index: 0,
tail: {
message_id: 'm-1',
status: 'streaming',
response: 'Hello, par',
thought: null,
sources: [],
tool_calls: [],
},
}),
);
expect(next.queries[0].messageStatus).toBe('streaming');
expect(next.queries[0].response).toBe('Hello, par');
});
it('updates response on each successive tail tick', () => {
let state = seedSlice();
state = reducer(
state,
applyMessageTail({
index: 0,
tail: {
message_id: 'm-1',
status: 'streaming',
response: 'Hello',
sources: [],
tool_calls: [],
},
}),
);
state = reducer(
state,
applyMessageTail({
index: 0,
tail: {
message_id: 'm-1',
status: 'streaming',
response: 'Hello, world',
sources: [],
tool_calls: [],
},
}),
);
expect(state.queries[0].response).toBe('Hello, world');
});
it('applies sources and tool_calls when they appear mid-stream', () => {
const state = seedSlice();
const next = reducer(
state,
applyMessageTail({
index: 0,
tail: {
message_id: 'm-1',
status: 'streaming',
response: 'partial',
sources: [{ id: 's1', title: 'doc' }],
tool_calls: [{ name: 'search' }],
},
}),
);
expect(next.queries[0].sources).toEqual([{ id: 's1', title: 'doc' }]);
expect(next.queries[0].tool_calls).toEqual([{ name: 'search' }]);
});
it('ignores empty sources / tool_calls arrays so the renderer stays clean', () => {
const state = seedSlice();
const next = reducer(
state,
applyMessageTail({
index: 0,
tail: {
message_id: 'm-1',
status: 'streaming',
response: 'partial',
sources: [],
tool_calls: [],
},
}),
);
expect(next.queries[0].sources).toBeUndefined();
expect(next.queries[0].tool_calls).toBeUndefined();
});
it('promotes to complete with the final response and clears any error', () => {
let state = seedSlice();
state = reducer(
state,
applyMessageTail({
index: 0,
tail: {
message_id: 'm-1',
status: 'streaming',
response: 'partial',
},
}),
);
state = reducer(
state,
applyMessageTail({
index: 0,
tail: {
message_id: 'm-1',
status: 'complete',
response: 'Final answer.',
},
}),
);
expect(state.queries[0].messageStatus).toBe('complete');
expect(state.queries[0].response).toBe('Final answer.');
expect(state.queries[0].error).toBeUndefined();
});
it('surfaces failed status as error and clears response', () => {
const state = seedSlice();
const next = reducer(
state,
applyMessageTail({
index: 0,
tail: {
message_id: 'm-1',
status: 'failed',
response: 'whatever',
error: 'worker died',
},
}),
);
expect(next.queries[0].messageStatus).toBe('failed');
expect(next.queries[0].error).toBe('worker died');
expect(next.queries[0].response).toBeUndefined();
});
});

View File

@@ -957,34 +957,20 @@ export const conversationSlice = createSlice({
const status = tail?.status as MessageStatus | undefined;
query.messageStatus = status;
query.lastHeartbeatAt = tail?.last_heartbeat_at ?? query.lastHeartbeatAt;
if (status === 'failed') {
if (status === 'complete') {
query.response = tail?.response ?? '';
query.thought = tail?.thought ?? query.thought;
query.sources = tail?.sources ?? query.sources;
query.tool_calls = tail?.tool_calls ?? query.tool_calls;
delete query.error;
} else if (status === 'failed') {
// Surface as error so the placeholder text never renders.
query.error =
(typeof tail?.error === 'string' && tail.error) ||
'Generation failed before completing.';
delete query.response;
return;
}
// /tail returns reconstructed partials mid-stream so a second tab
// can render the in-flight bubble; spinner is driven by status.
const incomingResponse = tail?.response;
if (typeof incomingResponse === 'string') {
query.response = incomingResponse;
} else if (status === 'complete') {
query.response = '';
}
if (typeof tail?.thought === 'string') {
query.thought = tail.thought;
}
if (Array.isArray(tail?.sources) && tail.sources.length > 0) {
query.sources = tail.sources;
}
if (Array.isArray(tail?.tool_calls) && tail.tool_calls.length > 0) {
query.tool_calls = tail.tool_calls;
}
if (status === 'complete') {
delete query.error;
}
// pending / streaming: untouched; spinner keeps showing.
},
raiseError(
state,

View File

@@ -1,18 +0,0 @@
import React from 'react';
import { useEventStream } from './useEventStream';
/**
* Mount-once provider that opens the user's SSE connection. Place
* inside ``AuthWrapper`` so it sees a populated token, and wrap the
* authenticated-app subtree so the connection lives for the user's
* whole session.
*/
export function EventStreamProvider({
children,
}: {
children: React.ReactNode;
}): React.ReactElement {
useEventStream();
return <>{children}</>;
}
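// A minimal placement sketch — component names other than
// ``EventStreamProvider`` are illustrative, not the app's actual tree:
//
//   <AuthWrapper>
//     <EventStreamProvider>
//       <AppRoutes />
//     </EventStreamProvider>
//   </AuthWrapper>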

View File

@@ -1,49 +0,0 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import type { AppDispatch } from '../store';
import {
sseEventReceived,
sseLastEventIdReset,
} from '../notifications/notificationsSlice';
import { dispatchSSEEvent } from './dispatchEvent';
describe('dispatchSSEEvent', () => {
let debugSpy: ReturnType<typeof vi.spyOn>;
beforeEach(() => {
debugSpy = vi.spyOn(console, 'debug').mockImplementation(() => undefined);
});
afterEach(() => {
debugSpy.mockRestore();
});
it('dispatches sseLastEventIdReset AND sseEventReceived for backlog.truncated', () => {
const dispatch = vi.fn() as unknown as AppDispatch;
const envelope = { type: 'backlog.truncated' as const };
dispatchSSEEvent(envelope, dispatch);
const calls = (dispatch as unknown as { mock: { calls: unknown[][] } }).mock
.calls;
expect(calls).toHaveLength(2);
expect(calls[0][0]).toEqual(sseLastEventIdReset());
expect(calls[1][0]).toEqual(sseEventReceived(envelope));
});
it('does not log a debug line for known envelope types', () => {
const dispatch = vi.fn() as unknown as AppDispatch;
dispatchSSEEvent({ id: 'e-1', type: 'source.ingest.progress' }, dispatch);
expect(debugSpy).not.toHaveBeenCalled();
});
it('logs a debug line for unknown envelope types', () => {
const dispatch = vi.fn() as unknown as AppDispatch;
dispatchSSEEvent({ id: 'e-2', type: 'mystery.event' }, dispatch);
expect(debugSpy).toHaveBeenCalledTimes(1);
expect(debugSpy.mock.calls[0]).toEqual([
'[dispatchSSEEvent] unknown envelope type',
'mystery.event',
]);
});
});

View File

@@ -1,58 +0,0 @@
import type { AppDispatch } from '../store';
import {
sseEventReceived,
sseLastEventIdReset,
type SSEEvent,
} from '../notifications/notificationsSlice';
// Envelope types this build knows about. Hitting an unknown type means
// the backend published something the frontend hasn't been taught yet
// — worth a single debug line so it's visible in devtools without
// drowning the console in known per-progress traffic.
const KNOWN_TYPES: ReadonlySet<string> = new Set([
'backlog.truncated',
'source.ingest.queued',
'source.ingest.progress',
'source.ingest.completed',
'source.ingest.failed',
'attachment.queued',
'attachment.progress',
'attachment.completed',
'attachment.failed',
'mcp.oauth.awaiting_redirect',
'mcp.oauth.in_progress',
'mcp.oauth.completed',
'mcp.oauth.failed',
'tool.approval.required',
]);
/**
* Single fan-out point for inbound SSE envelopes. Always dispatches
* ``sseEventReceived`` so any slice can ``extraReducers``-listen
* (uploadSlice does this for source-ingest events), then handles the
* small set of envelope-types that need centralised side effects (e.g.
* ``backlog.truncated``).
*/
export function dispatchSSEEvent(
envelope: SSEEvent,
dispatch: AppDispatch,
): void {
if (!KNOWN_TYPES.has(envelope.type)) {
console.debug('[dispatchSSEEvent] unknown envelope type', envelope.type);
}
switch (envelope.type) {
case 'backlog.truncated':
// Backlog window slid past the client's Last-Event-ID. Drop the
// cursor so the next reconnect doesn't try to resume past the
// retained window. Slices that care about full-state freshness
// can subscribe to ``sseEventReceived`` and refetch.
dispatch(sseLastEventIdReset());
break;
default:
// No central side effect; rely on slice-level extraReducers.
break;
}
dispatch(sseEventReceived(envelope));
}
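// Sketch of the slice-level consumption pattern described above. The
// builder callback is illustrative — uploadSlice's real reducer differs:
//
//   extraReducers: (builder) => {
//     builder.addCase(sseEventReceived, (state, action) => {
//       if (action.payload.type === 'source.ingest.progress') {
//         // update per-source progress from action.payload.payload
//       }
//     });
//   },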

View File

@@ -1,386 +0,0 @@
import { baseURL } from '../api/client';
import type { SSEEvent } from '../notifications/notificationsSlice';
/**
* Connection state surfaced to the consumer. Maps directly to the
* ``PushHealth`` machine in ``notificationsSlice``.
*/
export type EventStreamHealth = 'connecting' | 'healthy' | 'unhealthy';
export interface EventStreamOptions {
/** Bearer token; ``null`` short-circuits to ``unhealthy`` (auth required). */
token: string | null;
/**
* Lazy getter for the current ``Last-Event-ID``. Called once at the
* top of each connect attempt so token rotations / remounts read
* the freshest cursor from Redux instead of a stale mount-time
* snapshot. Return ``null`` for a fresh connect.
*/
getLastEventId: () => string | null;
onEvent: (event: SSEEvent) => void;
onHealthChange: (health: EventStreamHealth) => void;
/** Called with the most recently received id so the caller can persist it. */
onLastEventId?: (id: string) => void;
/**
* Called when the server emitted an ``id:`` line with an empty value
* (WHATWG SSE cursor reset). Distinct from ``onLastEventId('')`` so
* callers can dispatch ``sseLastEventIdReset`` without overloading
* the normal advance path.
*/
onLastEventIdReset?: () => void;
/**
* Invoked once after ``MAX_CONSECUTIVE_401`` back-to-back 401s. The
* reconnect loop then bails out, so the caller is responsible for
* refreshing the token / signalling logout. Without this, an expired
* token spins forever at the 30s backoff cap.
*/
onAuthFailure?: () => void;
/**
* Invoked once when the reconnect loop bails out after
* ``MAX_CONSECUTIVE_ERRORS`` non-401 failures. Lets the caller surface
* a warning instead of the connection silently going dark.
*/
onPermanentFailure?: () => void;
}
export interface EventStreamConnection {
close(): void;
}
/**
* Backoff schedule (ms) for reconnect attempts. Capped at 30s so a long
* outage doesn't push retries past Cloudflare's typical 100s idle-close
* envelope. The schedule resets to 0 after a stream stays healthy for
* ``HEALTHY_DEBOUNCE_MS``.
*/
const BACKOFF_SCHEDULE_MS = [0, 1_000, 2_000, 4_000, 8_000, 16_000, 30_000];
const HEALTHY_DEBOUNCE_MS = 2_000;
/**
* Reconnect ceilings. Without these, the ``while (!closed)`` loop spins
* forever on a persistently-failing endpoint — expired token (401s) or
* sustained server outage (5xx). Both counters reset on a successful
* stream open. Untested (no frontend test harness); behaviour verified
* by manual trace of the loop in ``connectEventStream``.
*/
const MAX_CONSECUTIVE_401 = 3;
const MAX_CONSECUTIVE_ERRORS = 20;
/** Up to ±20% random jitter so N tabs reconnecting in lockstep are staggered. */
function withJitter(delayMs: number): number {
if (delayMs <= 0) return 0;
const span = delayMs * 0.2;
return Math.max(0, Math.round(delayMs + (Math.random() * 2 - 1) * span));
}
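// Worked example of the schedule + jitter above: attempt 2 → base
// 2_000ms, jittered into [1_600, 2_400]; attempt 9 → clamped to the
// 30_000ms cap, jittered into [24_000, 36_000].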
/**
* Open a long-lived SSE connection to ``GET /api/events`` with
* fetch-streaming semantics that mirror ``conversationHandlers.ts``.
*
* Returns immediately with an opaque handle; the connection lives in a
* background async loop until ``close()`` is called or the underlying
* ``AbortController`` fires.
*
* The ``Last-Event-ID`` cursor rides on the URL (``?last_event_id=...``)
* rather than as a header so the request stays a CORS-simple GET — a
* custom header would force a preflight OPTIONS that the production
* cross-origin deployment isn't allowlisted for.
*/
export function connectEventStream(
opts: EventStreamOptions,
): EventStreamConnection {
const controller = new AbortController();
let closed = false;
let attempt = 0;
let consecutive401 = 0;
let consecutiveErrors = 0;
// Closure cursor. Seeded from the store on each connect attempt so
// mid-session reconnects use the freshest id, but kept here too so
// an in-flight stream's reconnect doesn't lose progress between ids
// that the store hasn't seen yet (e.g. id-only frames).
let lastEventId: string | null = opts.getLastEventId();
const notifyHealth = (h: EventStreamHealth) => {
if (closed) return;
opts.onHealthChange(h);
};
void (async () => {
while (!closed) {
const baseDelay =
BACKOFF_SCHEDULE_MS[Math.min(attempt, BACKOFF_SCHEDULE_MS.length - 1)];
const delay = withJitter(baseDelay);
if (delay > 0) {
try {
await sleep(delay, controller.signal);
} catch {
return; // aborted while waiting
}
if (closed) return;
}
notifyHealth('connecting');
// Always re-read the store cursor before reconnecting and copy
// it verbatim — including null. A null cursor isn't "leave
// alone": ``backlog.truncated`` events fire ``sseLastEventIdReset``
// to clear the slice, and the client must respect that on the
// next attempt by sending no cursor (full-backlog replay) instead
// of resending the stale one and re-tripping the same truncation.
lastEventId = opts.getLastEventId();
const url = new URL(`${baseURL}/api/events`);
if (lastEventId) url.searchParams.set('last_event_id', lastEventId);
// Auth header is omitted when token is null. Self-hosted dev
// installs run with ``AUTH_TYPE`` unset; the backend maps those
// requests to ``{"sub": "local"}`` so the SSE connection works
// tokenless. When auth IS required, a missing header surfaces
// as a 401 and the response.ok check below flips the health
// back to unhealthy.
const headers: Record<string, string> = {
Accept: 'text/event-stream',
};
if (opts.token) {
headers.Authorization = `Bearer ${opts.token}`;
}
try {
const response = await fetch(url.toString(), {
method: 'GET',
headers,
signal: controller.signal,
// SSE must not be cached.
cache: 'no-store',
});
if (!response.ok || !response.body) {
notifyHealth('unhealthy');
// 401 typically means token expired. Bail out after N
// consecutive 401s so the loop doesn't spin forever at the
// 30s backoff cap with a stale token; the caller is
// responsible for refreshing auth via ``onAuthFailure``.
if (response.status === 401) {
consecutive401 += 1;
consecutiveErrors += 1;
if (consecutive401 >= MAX_CONSECUTIVE_401) {
opts.onAuthFailure?.();
return;
}
} else {
consecutive401 = 0;
consecutiveErrors += 1;
}
if (consecutiveErrors >= MAX_CONSECUTIVE_ERRORS) {
opts.onPermanentFailure?.();
return;
}
          // 429: server-side per-user concurrency cap; back off harder.
if (response.status === 429) attempt = Math.max(attempt, 4);
else attempt = Math.min(attempt + 1, BACKOFF_SCHEDULE_MS.length - 1);
continue;
}
consecutive401 = 0;
        // Connection is open. Mark healthy on whichever comes first:
        // - the first parsed envelope (``markHealthy`` in the record
        //   callback below), or
        // - 2s of open response body. The setTimeout path means a
        //   backend that never emits a single record after sending the
        //   200 still flips us out of `connecting`.
let healthyMarked = false;
const markHealthy = () => {
if (healthyMarked) return;
healthyMarked = true;
notifyHealth('healthy');
attempt = 0;
consecutiveErrors = 0;
};
const debounceTimer = setTimeout(markHealthy, HEALTHY_DEBOUNCE_MS);
try {
await readSSEStream(response.body, controller.signal, (record) => {
if (record.id !== undefined) {
lastEventId = record.id || null;
if (record.id) opts.onLastEventId?.(record.id);
else opts.onLastEventIdReset?.();
}
if (record.data === undefined) {
// Keepalive comment, id-only frame, or comment line.
// The cursor was already advanced via ``onLastEventId``
// above so the slice tracks ids even on frames we don't
// dispatch as events.
return;
}
// Empty data line is technically valid SSE but useless; skip.
if (record.data.trim().length === 0) return;
let envelope: SSEEvent | null = null;
try {
envelope = JSON.parse(record.data) as SSEEvent;
} catch {
// Malformed payload; skip.
return;
}
// Defensive shape validation — the cast above lies if the
// server (or a man-in-the-middle) sends garbage.
if (
!envelope ||
typeof envelope !== 'object' ||
typeof envelope.type !== 'string'
) {
return;
}
if (record.id && !envelope.id) envelope.id = record.id;
            // A real envelope flips healthy immediately if the 2s
            // debounce timer hasn't already.
markHealthy();
// Every tab dispatches every envelope it receives into its
// own Redux store. With N tabs open this means N copies of
// the same toast — accepted as a v1 limitation; cross-tab
// dedup via BroadcastChannel + navigator.locks is future
// work. Toast-level suppression can be handled per surface.
opts.onEvent(envelope);
});
} finally {
clearTimeout(debounceTimer);
}
// The reader returned without abort — server closed the stream.
// Fall through to reconnect.
notifyHealth('unhealthy');
consecutiveErrors += 1;
if (consecutiveErrors >= MAX_CONSECUTIVE_ERRORS) {
opts.onPermanentFailure?.();
return;
}
attempt = Math.min(attempt + 1, BACKOFF_SCHEDULE_MS.length - 1);
} catch (err) {
if (
closed ||
(err instanceof DOMException && err.name === 'AbortError')
) {
return;
}
notifyHealth('unhealthy');
consecutiveErrors += 1;
if (consecutiveErrors >= MAX_CONSECUTIVE_ERRORS) {
opts.onPermanentFailure?.();
return;
}
attempt = Math.min(attempt + 1, BACKOFF_SCHEDULE_MS.length - 1);
}
}
})();
return {
close() {
if (closed) return;
closed = true;
controller.abort();
},
};
}
interface ParsedSSERecord {
/**
* ``undefined`` when the record had no ``id`` field at all. An empty
* string means the server explicitly reset the cursor (an ``id:``
* line with no value, per WHATWG SSE).
*/
id?: string;
/** ``undefined`` for keepalive comments / id-only frames. */
data?: string;
}
/**
* Drain a ``ReadableStream<Uint8Array>`` of ``\n\n``-delimited SSE records,
* forwarding each parsed record to ``onRecord``. Honours the WHATWG SSE
* spec's mixed line-terminator handling and SSE comment lines.
*/
async function readSSEStream(
body: ReadableStream<Uint8Array>,
signal: AbortSignal,
onRecord: (record: ParsedSSERecord) => void,
): Promise<void> {
const reader = body.getReader();
const decoder = new TextDecoder('utf-8');
let buffer = '';
try {
while (true) {
if (signal.aborted) return;
const { done, value } = await reader.read();
if (done) return;
buffer += decoder.decode(value, { stream: true });
// SSE records are separated by a blank line. WHATWG spec accepts
// CRLF, CR, or LF — normalise so a stray CR can't smuggle a
// boundary mid-record.
buffer = buffer.replace(/\r\n/g, '\n').replace(/\r/g, '\n');
let boundary = buffer.indexOf('\n\n');
while (boundary !== -1) {
const raw = buffer.slice(0, boundary);
buffer = buffer.slice(boundary + 2);
const record = parseSSERecord(raw);
if (record) onRecord(record);
boundary = buffer.indexOf('\n\n');
}
}
} finally {
try {
reader.releaseLock();
} catch {
// Already released.
}
}
}
function parseSSERecord(raw: string): ParsedSSERecord | null {
if (raw.length === 0) return null;
const lines = raw.split('\n');
let id: string | undefined;
const dataParts: string[] = [];
let sawDataField = false;
for (const line of lines) {
if (line.length === 0) continue;
if (line.startsWith(':')) continue; // SSE comment / keepalive
const colonIdx = line.indexOf(':');
const field = colonIdx === -1 ? line : line.slice(0, colonIdx);
let value = colonIdx === -1 ? '' : line.slice(colonIdx + 1);
// SSE: value may be prefixed by exactly one optional space.
if (value.startsWith(' ')) value = value.slice(1);
if (field === 'id') {
id = value;
} else if (field === 'data') {
sawDataField = true;
dataParts.push(value);
}
// Other field names ('event', 'retry') are ignored for now.
}
if (!sawDataField && id === undefined) return null;
return {
id,
data: sawDataField ? dataParts.join('\n') : undefined,
};
}
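// Example (record text illustrative): the raw record
//   "id: 42\ndata: hello\ndata: world\n: keepalive"
// parses to { id: '42', data: 'hello\nworld' } — multiple data lines
// join with '\n' and comment lines are dropped, per WHATWG SSE.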
function sleep(ms: number, signal: AbortSignal): Promise<void> {
return new Promise((resolve, reject) => {
if (signal.aborted) {
reject(new DOMException('Aborted', 'AbortError'));
return;
}
const timer = setTimeout(() => {
signal.removeEventListener('abort', onAbort);
resolve();
}, ms);
const onAbort = () => {
clearTimeout(timer);
signal.removeEventListener('abort', onAbort);
reject(new DOMException('Aborted', 'AbortError'));
};
signal.addEventListener('abort', onAbort, { once: true });
});
}

View File

@@ -1,85 +0,0 @@
import { useEffect } from 'react';
import { useDispatch, useSelector, useStore } from 'react-redux';
import {
selectLastEventId,
sseHealthChanged,
sseLastEventIdAdvanced,
sseLastEventIdReset,
} from '../notifications/notificationsSlice';
import { selectToken, setToken } from '../preferences/preferenceSlice';
import type { AppDispatch, RootState } from '../store';
import { connectEventStream } from './eventStreamClient';
import { dispatchSSEEvent } from './dispatchEvent';
/**
* Open the SSE connection for the current token and keep it alive for
* the lifetime of the host component. Recreates the connection on
* token change (login / refresh).
*
* The ``lastEventId`` cursor is read lazily from the slice on each
* connect attempt via ``store.getState()`` — capturing it at mount time
* would silently re-replay the entire 24h backlog on token rotation,
* since the slice's id advances during the previous connection's
* lifetime but a snapshot ref would still hold the value seen at
* first mount.
*/
export function useEventStream(): void {
const dispatch = useDispatch<AppDispatch>();
const token = useSelector(selectToken);
const store = useStore<RootState>();
useEffect(() => {
// Connect even when token is null. Self-hosted dev installs run
// with ``AUTH_TYPE`` unset, where ``handle_auth`` maps every
// request to ``{"sub": "local"}`` regardless of headers — gating
// the connection on a populated token would silently disable push
// notifications for the most common configuration. When auth IS
// required and token is null, the backend will 401 and the
// health state will flip to ``unhealthy`` via the response check
// inside ``connectEventStream``.
const conn = connectEventStream({
token,
getLastEventId: () => selectLastEventId(store.getState()),
onEvent: (envelope) => dispatchSSEEvent(envelope, dispatch),
// Advance the slice cursor for every id-bearing frame. Each tab
// owns an independent SSE connection and Redux store, so every
// active tab tracks its own replay cursor.
onLastEventId: (id) => dispatch(sseLastEventIdAdvanced(id)),
// Server emitted ``id:`` with an empty value — WHATWG cursor reset.
// Mirror the slice so the next reconnect doesn't resend a stale id.
onLastEventIdReset: () => dispatch(sseLastEventIdReset()),
onHealthChange: (health) => dispatch(sseHealthChanged(health)),
// SSE 401 loop bail-out. Clear the stored token AND dispatch
// ``setToken(null)`` so ``useAuth`` regenerates a fresh
// ``session_jwt`` in-session; the Redux change also flips this
// hook's ``[token]`` dep, tearing down and respawning the
// connection with the new token. Without the dispatch a
// ``session_jwt`` user is stuck until a hard reload.
onAuthFailure: () => {
console.error(
'[useEventStream] giving up after repeated 401s on /api/events',
);
try {
localStorage.removeItem('authToken');
} catch {
// localStorage unavailable (private mode, etc.) — nothing to do.
}
dispatch(setToken(null));
},
// Surface a warning when the non-401 error budget is exhausted so
// the connection going dark isn't completely silent. Doesn't block
// UI — just observable in devtools.
onPermanentFailure: () => {
console.warn(
'[useEventStream] SSE connection failed permanently after repeated errors',
);
},
});
return () => {
conn.close();
};
}, [token, dispatch, store]);
}

View File

@@ -1,10 +1,4 @@
import {
useCallback,
useEffect,
useRef,
useState,
RefObject,
} from 'react';
import { useEffect, RefObject, useState } from 'react';
export function useOutsideAlerter<T extends HTMLElement>(
ref: RefObject<T | null>,
@@ -119,51 +113,6 @@ export function useDarkTheme() {
return [isDarkTheme, toggleTheme, componentMounted] as const;
}
export function useDebouncedValue<T>(value: T, delay = 300): T {
const [debounced, setDebounced] = useState<T>(value);
useEffect(() => {
const timer = setTimeout(() => setDebounced(value), delay);
return () => clearTimeout(timer);
}, [value, delay]);
return debounced;
}
export function useDebouncedCallback<A extends unknown[]>(
callback: (...args: A) => void,
delay = 300,
): ((...args: A) => void) & { cancel: () => void } {
const callbackRef = useRef(callback);
const timerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
useEffect(() => {
callbackRef.current = callback;
}, [callback]);
const cancel = useCallback(() => {
if (timerRef.current) {
clearTimeout(timerRef.current);
timerRef.current = null;
}
}, []);
useEffect(() => cancel, [cancel]);
const debounced = useCallback(
(...args: A) => {
cancel();
timerRef.current = setTimeout(() => {
timerRef.current = null;
callbackRef.current(...args);
}, delay);
},
[delay, cancel],
);
return Object.assign(debounced, { cancel });
}
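// Usage sketch (handler names illustrative): calls collapse into one
// trailing invocation per quiet period, and ``.cancel()`` drops a
// pending call on teardown:
//
//   const save = useDebouncedCallback((draft: string) => persist(draft), 500);
//   save(text);      // on each keystroke
//   save.cancel();   // in an unmount cleanup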
export function useLoaderState(
initialState = false,
delay = 250,

View File

@@ -15,39 +15,24 @@ export default function useAuth() {
const generateNewToken = async () => {
if (isGeneratingToken.current) return;
isGeneratingToken.current = true;
try {
const response = await userService.getNewToken();
const { token: newToken } = await response.json();
localStorage.setItem('authToken', newToken);
dispatch(setToken(newToken));
setIsAuthLoading(false);
return newToken;
} finally {
// Reset so a subsequent ``setToken(null)`` (SSE 401 recovery)
// can trigger another generation. Without this the in-flight
// guard would latch true forever after the first call.
isGeneratingToken.current = false;
}
const response = await userService.getNewToken();
const { token: newToken } = await response.json();
localStorage.setItem('authToken', newToken);
dispatch(setToken(newToken));
setIsAuthLoading(false);
return newToken;
};
useEffect(() => {
// Re-fires when ``token`` flips to null mid-session (e.g.
// ``useEventStream`` dispatches ``setToken(null)`` after repeated
// SSE 401s) so ``session_jwt`` users get a fresh token without a
// hard reload. ``authType`` short-circuits on subsequent runs.
const initializeAuth = async () => {
try {
let resolvedAuthType = authType;
if (resolvedAuthType === null) {
const configRes = await userService.getConfig();
const config = await configRes.json();
resolvedAuthType = config.auth_type;
setAuthType(resolvedAuthType);
}
const configRes = await userService.getConfig();
const config = await configRes.json();
setAuthType(config.auth_type);
if (resolvedAuthType === 'session_jwt' && !token) {
if (config.auth_type === 'session_jwt' && !token) {
await generateNewToken();
} else if (resolvedAuthType === 'simple_jwt' && !token) {
} else if (config.auth_type === 'simple_jwt' && !token) {
setShowTokenModal(true);
setIsAuthLoading(false);
} else {
@@ -59,7 +44,7 @@ export default function useAuth() {
}
};
initializeAuth();
}, [token, authType]);
}, []);
const handleTokenSubmit = (enteredToken: string) => {
localStorage.setItem('authToken', enteredToken);

View File

@@ -15,7 +15,6 @@ import {
SelectValue,
} from '../components/ui/select';
import { ActiveState } from '../models/misc';
import { selectRecentEvents } from '../notifications/notificationsSlice';
import { selectToken } from '../preferences/preferenceSlice';
import WrapperComponent from './WrapperModal';
@@ -34,7 +33,6 @@ export default function MCPServerModal({
}: MCPServerModalProps) {
const { t } = useTranslation();
const token = useSelector(selectToken);
const recentEvents = useSelector(selectRecentEvents);
const authTypes = [
{ label: t('settings.tools.mcp.authTypes.none'), value: 'none' },
@@ -73,29 +71,17 @@ export default function MCPServerModal({
>([]);
const [errors, setErrors] = useState<{ [key: string]: string }>({});
const oauthPopupRef = useRef<Window | null>(null);
// Set after ``test_mcp_connection`` returns ``task_id``. The SSE
// effect filters ``recentEvents`` to envelopes matching this id and
// drives the OAuth UI (popup open / completion / failure) from the
// push stream rather than polling the legacy status endpoint.
const [oauthTaskId, setOauthTaskId] = useState<string | null>(null);
// Highest event id we have already reacted to for this taskId. Each
// mcp.oauth.* envelope must fire its side-effect once; without this
// any later re-render that re-evaluates ``recentEvents`` would
// re-open the popup or re-fire onComplete.
const handledEventIdsRef = useRef<Set<string>>(new Set());
// Holds the ``testConnection`` ``onComplete`` for the current
// task id so the SSE effect can invoke it when the terminal event
// lands. Reset to ``null`` on cancel / new test / unmount.
const onCompleteRef = useRef<((result: any) => void) | null>(null);
const popupOpenedRef = useRef(false);
const pollingCancelledRef = useRef(false);
const pollTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
const [oauthCompleted, setOAuthCompleted] = useState(false);
const [saveActive, setSaveActive] = useState(false);
const cleanupOAuthListener = useCallback(() => {
setOauthTaskId(null);
handledEventIdsRef.current = new Set();
onCompleteRef.current = null;
popupOpenedRef.current = false;
const cleanupPolling = useCallback(() => {
pollingCancelledRef.current = true;
if (pollTimerRef.current) {
clearTimeout(pollTimerRef.current);
pollTimerRef.current = null;
}
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
@@ -103,8 +89,8 @@ export default function MCPServerModal({
}, []);
useEffect(() => {
return cleanupOAuthListener;
}, [cleanupOAuthListener]);
return cleanupPolling;
}, [cleanupPolling]);
useEffect(() => {
if (modalState === 'ACTIVE' && server) {
@@ -133,7 +119,7 @@ export default function MCPServerModal({
}, [modalState, server]);
const resetForm = () => {
cleanupOAuthListener();
cleanupPolling();
setFormData({
name: t('settings.tools.mcp.defaultServerName'),
server_url: '',
@@ -242,123 +228,114 @@ export default function MCPServerModal({
return config;
};
/**
* Drive the OAuth handshake straight from the SSE stream:
*
* - ``mcp.oauth.awaiting_redirect`` → open the popup with the
* ``authorization_url`` carried on the envelope. Previously this URL
* came from polling ``/api/mcp_server/oauth_status/<task_id>``; the
* worker now publishes it inline so we never need to poll.
* - ``mcp.oauth.completed`` → enable Save, surface discovered tools,
* invoke ``onComplete`` (resolves ``testConnection``'s pending state).
* - ``mcp.oauth.failed`` → surface the error and reset Save.
*
* Each event is matched to the active task id via ``scope.id``. The
* publisher is best-effort: a lost ``awaiting_redirect`` envelope
* means the popup never opens, the user retries, and we accept that
* over the prior 1s × 60 polling loop.
*/
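  // Illustrative envelope for the flow above — payload fields are
  // inferred from the handlers below, not the publisher's
  // authoritative schema:
  //
  //   {
  //     id: 'evt-1',
  //     type: 'mcp.oauth.awaiting_redirect',
  //     scope: { id: '<task_id>' },
  //     payload: { authorization_url: 'https://…/authorize?…' },
  //   }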
useEffect(() => {
if (!oauthTaskId) return;
// ``recentEvents`` is newest-first (the slice ``unshift``s on
// arrival). Walk it oldest-first so we observe the natural OAuth
// ordering (``awaiting_redirect`` → ``completed``) when both
// arrive between effect runs — otherwise we would short-circuit
// on ``completed`` and never open the popup for the
// ``awaiting_redirect`` envelope that was already buffered.
for (let i = recentEvents.length - 1; i >= 0; i--) {
const event = recentEvents[i];
if (event.scope?.id !== oauthTaskId) continue;
if (!event.id || handledEventIdsRef.current.has(event.id)) continue;
const pollOAuthStatus = async (
taskId: string,
onComplete: (result: any) => void,
) => {
let attempts = 0;
const maxAttempts = 60;
let popupOpened = false;
pollingCancelledRef.current = false;
const payload = (event.payload || {}) as Record<string, unknown>;
const poll = async () => {
if (pollingCancelledRef.current) return;
try {
const resp = await userService.getMCPOAuthStatus(taskId, token);
if (pollingCancelledRef.current) return;
const data = await resp.json();
if (pollingCancelledRef.current) return;
if (event.type === 'mcp.oauth.awaiting_redirect') {
handledEventIdsRef.current.add(event.id);
const authUrl = payload.authorization_url as string | undefined;
if (authUrl && !popupOpenedRef.current) {
popupOpenedRef.current = true;
if (data.authorization_url && !popupOpened) {
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
oauthPopupRef.current = window.open(
authUrl,
data.authorization_url,
'oauthPopup',
'width=600,height=700',
);
popupOpened = true;
if (!oauthPopupRef.current) {
// Popup blocked — surface the URL inline so the user can
// click through manually. Browsers gate ``window.open``
// outside of a user gesture, and the SSE event arrives
// asynchronously, so a blocked popup is expected on
// some browsers / configs.
setTestResult({
success: true,
message: t('settings.tools.mcp.oauthPopupBlocked', {
defaultValue:
'Popup blocked by browser. Click below to authorize:',
}),
authorization_url: authUrl,
authorization_url: data.authorization_url,
});
}
}
continue;
}
if (event.type === 'mcp.oauth.completed') {
handledEventIdsRef.current.add(event.id);
const tools = Array.isArray(payload.tools) ? payload.tools : [];
const toolsCount =
(payload.tools_count as number | undefined) ?? tools.length;
setOAuthCompleted(true);
setSaveActive(true);
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
const cb = onCompleteRef.current;
onCompleteRef.current = null;
setOauthTaskId(null);
if (cb) {
cb({
status: 'completed',
task_id: oauthTaskId,
tools,
tools_count: toolsCount,
const callbackReceived =
data.status === 'callback_received' || data.status === 'completed';
if (data.status === 'completed') {
setOAuthCompleted(true);
setSaveActive(true);
onComplete({
...data,
success: true,
message: t('settings.tools.mcp.oauthCompleted'),
});
}
continue;
}
if (event.type === 'mcp.oauth.failed') {
handledEventIdsRef.current.add(event.id);
const message =
(payload.error as string) ??
t('settings.tools.mcp.errors.oauthFailed');
setSaveActive(false);
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
const cb = onCompleteRef.current;
onCompleteRef.current = null;
setOauthTaskId(null);
if (cb) {
cb({
status: 'error',
task_id: oauthTaskId,
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
} else if (data.status === 'error' || data.success === false) {
setSaveActive(false);
onComplete({
...data,
success: false,
message,
message: data.message || t('settings.tools.mcp.errors.oauthFailed'),
});
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
} else {
if (++attempts < maxAttempts) {
if (
oauthPopupRef.current &&
oauthPopupRef.current.closed &&
popupOpened &&
!callbackReceived
) {
setSaveActive(false);
onComplete({
success: false,
message: t('settings.tools.mcp.errors.oauthFailed'),
});
return;
}
pollTimerRef.current = setTimeout(poll, 1000);
} else {
setSaveActive(false);
cleanupPolling();
onComplete({
success: false,
message: t('settings.tools.mcp.errors.oauthTimeout'),
});
}
}
} catch {
if (pollingCancelledRef.current) return;
if (++attempts < maxAttempts) {
pollTimerRef.current = setTimeout(poll, 1000);
} else {
cleanupPolling();
onComplete({
success: false,
message: t('settings.tools.mcp.errors.oauthTimeout'),
});
}
continue;
}
}
}, [recentEvents, oauthTaskId, t]);
};
poll();
};
const testConnection = async () => {
if (!validateForm()) return;
cleanupOAuthListener();
cleanupPolling();
setTesting(true);
setTestResult(null);
setDiscoveredTools([]);
@@ -378,7 +355,7 @@ export default function MCPServerModal({
message: t('settings.tools.mcp.oauthInProgress'),
});
setSaveActive(false);
onCompleteRef.current = (finalResult: any) => {
pollOAuthStatus(result.task_id, (finalResult) => {
setTestResult(finalResult);
if (finalResult.tools && Array.isArray(finalResult.tools)) {
setDiscoveredTools(finalResult.tools);
@@ -388,11 +365,7 @@ export default function MCPServerModal({
oauth_task_id: result.task_id || '',
}));
setTesting(false);
};
// Activate the SSE listener for this task id. The effect above
// will react when ``mcp.oauth.{awaiting_redirect,completed,failed}``
// arrives.
setOauthTaskId(result.task_id);
});
} else {
setTestResult(result);
if (result.success && result.tools && Array.isArray(result.tools)) {

View File

@@ -1,174 +0,0 @@
import { useDispatch, useSelector } from 'react-redux';
import { useMatch, useNavigate } from 'react-router-dom';
import WarnIcon from '../assets/warn.svg';
import type { RootState } from '../store';
import {
dismissToolApproval,
selectDismissedToolApprovals,
selectRecentEvents,
} from './notificationsSlice';
/**
* Surface ``tool.approval.required`` events as toasts that look like
* ``UploadToast`` (same fixed bottom-right rail) — but only when the
* user is NOT already on the conversation that needs the approval.
*
* - Dedup by ``conversation_id`` (the SSE ``scope.id``): keep only
* the newest pending event per conversation, so multiple paused
* tools in one conversation collapse to one toast.
* - Dismissal is per-event-id so a *new* pause of the same
* conversation will re-surface (different event id).
* - Clicking "Review" navigates to ``/c/<id>`` and dismisses.
*/
export default function ToolApprovalToast() {
const dispatch = useDispatch();
const navigate = useNavigate();
const events = useSelector(selectRecentEvents);
const dismissed = useSelector(selectDismissedToolApprovals);
// Pull the active conversation id off the route. Two route shapes
// place a conversation in view: the bare ``/c/:conversationId`` and
// the agent-scoped ``/agents/:agentId/c/:conversationId``. Hooks
// are unconditional; the toast just respects whichever matches.
//
// ``/c/new`` is the conversation route's literal-string placeholder
// for "unknown / not-yet-loaded conversation" (see the rewrite in
// the conversation route). Treat it the same as no match — the
// user isn't viewing any specific conversation yet, so an approval
// toast for any conversation should still surface.
const plainMatch = useMatch('/c/:conversationId');
const agentMatch = useMatch('/agents/:agentId/c/:conversationId');
const matchedConversationId =
plainMatch?.params.conversationId ??
agentMatch?.params.conversationId ??
null;
const currentConversationId =
matchedConversationId === 'new' ? null : matchedConversationId;
// Conversation name lookup — best-effort. The slice's
// ``preference.conversations.data`` is populated by
// ``useDataInitializer`` once auth resolves; until then we fall
// back to the conversation id.
const conversations = useSelector(
(state: RootState) => state.preference.conversations.data,
);
const dismissedSet = new Set(dismissed);
const pendingByConversation = new Map<
string,
{ eventId: string; conversationId: string }
>();
for (const event of events) {
if (event.type !== 'tool.approval.required') continue;
if (!event.id) continue; // can't dismiss without an id
if (dismissedSet.has(event.id)) continue;
const conversationId = event.scope?.id;
if (!conversationId) continue;
if (currentConversationId && conversationId === currentConversationId) {
continue;
}
if (pendingByConversation.has(conversationId)) continue;
// ``recentEvents`` is newest-first, so the first match per convId
// is the most recent unhandled approval.
pendingByConversation.set(conversationId, {
eventId: event.id,
conversationId,
});
}
if (pendingByConversation.size === 0) return null;
const conversationName = (conversationId: string): string => {
const found = conversations?.find((c) => c.id === conversationId);
return found?.name ?? 'Conversation';
};
return (
// Sit above ``UploadToast`` (which owns ``bottom-4 right-4``)
// rather than overlapping it. ``bottom-24`` ≈ 96px clears one
// standard-height upload toast; multiple in-flight uploads will
// stack into the gap, at which point the approval toast still
// floats on top via ``z-50``. Acceptable v1 layout — the two
// surfaces are rarely competing.
<div
className="fixed right-4 bottom-24 z-50 flex max-w-md flex-col gap-2"
onMouseDown={(e) => e.stopPropagation()}
role="status"
aria-live="polite"
aria-atomic="true"
>
{Array.from(pendingByConversation.values()).map(
({ eventId, conversationId }) => (
<div
key={eventId}
className="border-border bg-card w-[271px] overflow-hidden rounded-2xl border shadow-[0px_24px_48px_0px_#00000029]"
>
<div className="bg-accent/50 dark:bg-muted flex items-center justify-between px-4 py-3">
<h3 className="font-inter dark:text-foreground text-[14px] leading-[16.5px] font-medium text-black">
Tool approval needed
</h3>
<button
type="button"
onClick={() => dispatch(dismissToolApproval(eventId))}
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
aria-label="Dismiss"
>
<svg
width="16"
height="16"
viewBox="0 0 24 24"
fill="none"
xmlns="http://www.w3.org/2000/svg"
className="h-4 w-4"
>
<path
d="M18 6L6 18"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M6 6L18 18"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
</button>
</div>
<div className="flex items-center justify-between gap-3 px-5 py-3">
<div className="flex min-w-0 items-center gap-2">
<img
src={WarnIcon}
alt=""
className="h-5 w-5 shrink-0"
aria-hidden="true"
/>
<p
className="font-inter dark:text-muted-foreground max-w-[140px] truncate text-[13px] leading-[16.5px] font-normal text-black"
title={conversationName(conversationId)}
>
{conversationName(conversationId)}
</p>
</div>
<button
type="button"
onClick={() => {
dispatch(dismissToolApproval(eventId));
navigate(`/c/${conversationId}`);
}}
className="rounded-full bg-[#7D54D1] px-3 py-1 text-[12px] font-medium text-white shadow-sm hover:bg-[#6a45b8]"
>
Review
</button>
</div>
</div>
),
)}
</div>
);
}

View File

@@ -1,71 +0,0 @@
import { beforeEach, describe, expect, it } from 'vitest';
import {
isDismissed,
loadDismissed,
saveDismissed,
} from './dismissedPersistence';
const KEY = 'test:dismissed';
const TTL = 24 * 60 * 60 * 1000;
describe('dismissedPersistence', () => {
beforeEach(() => {
localStorage.clear();
});
it('saveDismissed + loadDismissed round-trips entries', () => {
const now = Date.now();
saveDismissed(KEY, [
{ id: 'a', at: now },
{ id: 'b', at: now - 1000 },
]);
const loaded = loadDismissed(KEY, TTL);
expect(loaded).toEqual([
{ id: 'a', at: now },
{ id: 'b', at: now - 1000 },
]);
});
it('loadDismissed returns [] when key absent', () => {
expect(loadDismissed(KEY, TTL)).toEqual([]);
});
it('loadDismissed drops entries past the TTL cutoff', () => {
const now = Date.now();
saveDismissed(KEY, [
{ id: 'fresh', at: now - 1000 },
{ id: 'stale', at: now - (TTL + 1000) },
]);
const loaded = loadDismissed(KEY, TTL);
expect(loaded.map((e) => e.id)).toEqual(['fresh']);
});
it('loadDismissed returns [] on malformed JSON without throwing', () => {
localStorage.setItem(KEY, '{not json');
expect(loadDismissed(KEY, TTL)).toEqual([]);
});
it('loadDismissed filters out entries with wrong shape', () => {
const now = Date.now();
localStorage.setItem(
KEY,
JSON.stringify([
{ id: 'good', at: now },
{ id: 123, at: now },
{ id: 'bad-at', at: 'nope' },
null,
'string-entry',
]),
);
const loaded = loadDismissed(KEY, TTL);
expect(loaded.map((e) => e.id)).toEqual(['good']);
});
it('isDismissed matches by id', () => {
const list = [{ id: 'a', at: 1 }];
expect(isDismissed(list, 'a')).toBe(true);
expect(isDismissed(list, 'b')).toBe(false);
expect(isDismissed([], 'a')).toBe(false);
});
});

View File

@@ -1,42 +0,0 @@
// Persisted dismissal lists for SSE-driven toasts. Without persistence
// the next page load's backlog replay re-fires the events and pops the
// toast back up. TTL matches the backend's stream retention.
export interface DismissedEntry {
id: string;
at: number;
}
export function loadDismissed(key: string, ttlMs: number): DismissedEntry[] {
if (typeof localStorage === 'undefined') return [];
try {
const raw = localStorage.getItem(key);
if (!raw) return [];
const parsed = JSON.parse(raw);
if (!Array.isArray(parsed)) return [];
const cutoff = Date.now() - ttlMs;
return parsed.filter(
(e): e is DismissedEntry =>
!!e &&
typeof e === 'object' &&
typeof (e as DismissedEntry).id === 'string' &&
typeof (e as DismissedEntry).at === 'number' &&
(e as DismissedEntry).at >= cutoff,
);
} catch {
return [];
}
}
export function saveDismissed(key: string, list: DismissedEntry[]): void {
if (typeof localStorage === 'undefined') return;
try {
localStorage.setItem(key, JSON.stringify(list));
} catch {
// Best-effort: ignore quota / private-mode errors.
}
}
export function isDismissed(list: DismissedEntry[], id: string): boolean {
return list.some((e) => e.id === id);
}
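// Intended wiring sketch (key and TTL assumed to match the slice that
// consumes this module):
//
//   const dismissed = loadDismissed('docsgpt:dismissedToolApprovals', TTL_MS);
//   // …on each dismissal:
//   saveDismissed('docsgpt:dismissedToolApprovals', next);
//   isDismissed(next, event.id);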

View File

@@ -1,109 +0,0 @@
import { describe, expect, it, vi, afterEach } from 'vitest';
import reducer, {
dismissToolApproval,
sseEventReceived,
sseLastEventIdReset,
type SSEEvent,
} from './notificationsSlice';
const baseEvent = (overrides: Partial<SSEEvent> = {}): SSEEvent => ({
id: 'evt-1',
type: 'source.ingest.progress',
...overrides,
});
const seedState = () => reducer(undefined, { type: '@@INIT' });
afterEach(() => {
vi.useRealTimers();
});
describe('sseEventReceived', () => {
it('dedupes by id when the same envelope arrives twice', () => {
let state = seedState();
state = reducer(state, sseEventReceived(baseEvent({ id: 'a' })));
state = reducer(state, sseEventReceived(baseEvent({ id: 'a' })));
expect(state.recentEvents).toHaveLength(1);
expect(state.recentEvents[0].id).toBe('a');
});
it('does not update lastEventId for envelopes without an id (backlog.truncated)', () => {
let state = seedState();
state = reducer(state, sseEventReceived(baseEvent({ id: 'cursor-1' })));
expect(state.lastEventId).toBe('cursor-1');
state = reducer(
state,
sseEventReceived({ type: 'backlog.truncated' } as SSEEvent),
);
expect(state.lastEventId).toBe('cursor-1');
});
it('caps recentEvents at 100 entries (oldest evicted)', () => {
let state = seedState();
for (let i = 0; i < 105; i += 1) {
state = reducer(state, sseEventReceived(baseEvent({ id: `e-${i}` })));
}
expect(state.recentEvents).toHaveLength(100);
// Newest first.
expect(state.recentEvents[0].id).toBe('e-104');
expect(state.recentEvents[state.recentEvents.length - 1].id).toBe('e-5');
});
});
describe('sseLastEventIdReset', () => {
it('clears lastEventId back to null', () => {
let state = seedState();
state = reducer(state, sseEventReceived(baseEvent({ id: 'x' })));
expect(state.lastEventId).toBe('x');
state = reducer(state, sseLastEventIdReset());
expect(state.lastEventId).toBeNull();
});
});
describe('dismissToolApproval', () => {
it('dedupes by id and refreshes the timestamp', () => {
vi.useFakeTimers();
vi.setSystemTime(new Date('2026-01-01T00:00:00Z'));
let state = seedState();
state = reducer(state, dismissToolApproval('approval-1'));
const firstAt = state.dismissedToolApprovals[0].at;
vi.setSystemTime(new Date('2026-01-01T00:05:00Z'));
state = reducer(state, dismissToolApproval('approval-1'));
expect(state.dismissedToolApprovals).toHaveLength(1);
expect(state.dismissedToolApprovals[0].at).toBeGreaterThan(firstAt);
});
it('evicts entries older than the 24h TTL', () => {
vi.useFakeTimers();
vi.setSystemTime(new Date('2026-01-01T00:00:00Z'));
let state = seedState();
state = reducer(state, dismissToolApproval('old-1'));
// Move past the 24h TTL window.
vi.setSystemTime(new Date('2026-01-02T00:00:01Z'));
state = reducer(state, dismissToolApproval('fresh-1'));
const ids = state.dismissedToolApprovals.map((entry) => entry.id);
expect(ids).toEqual(['fresh-1']);
});
it('applies the 200-entry cap as a backstop after TTL filtering', () => {
vi.useFakeTimers();
vi.setSystemTime(new Date('2026-01-01T00:00:00Z'));
let state = seedState();
// Insert 205 distinct ids within the TTL window.
for (let i = 0; i < 205; i += 1) {
// Nudge the clock (seconds cycle mod 60, so at-values repeat each
// minute) while keeping every entry well inside the 24h TTL.
vi.setSystemTime(
new Date(`2026-01-01T00:00:${(i % 60).toString().padStart(2, '0')}Z`),
);
state = reducer(state, dismissToolApproval(`id-${i}`));
}
expect(state.dismissedToolApprovals).toHaveLength(200);
// The 200-cap keeps the most recently pushed ids.
expect(state.dismissedToolApprovals[0].id).toBe('id-5');
expect(state.dismissedToolApprovals[199].id).toBe('id-204');
});
});

View File

@@ -1,200 +0,0 @@
import { createSelector, createSlice, PayloadAction } from '@reduxjs/toolkit';
import { RootState } from '../store';
import { loadDismissed, saveDismissed } from './dismissedPersistence';
const DISMISSED_TOOL_APPROVALS_STORAGE_KEY = 'docsgpt:dismissedToolApprovals';
/**
* Envelope shape published by the backend SSE endpoint
* (`application/events/publisher.py`). Mirrors the wire JSON 1:1.
*/
export interface SSEEvent<P = Record<string, unknown>> {
id?: string;
type: string;
ts?: string;
user_id?: string;
topic?: string;
scope?: { kind: string; id: string };
payload?: P;
}
/**
* Connection-health state machine the rest of the app reads via
* ``selectPushChannelHealthy`` to gate polling-fallback behaviour.
*
* - ``connecting`` — initial fetch in flight, or reconnecting after drop.
* - ``healthy`` — at least one event (data or keepalive) received and
* the stream has been open >2s.
* - ``unhealthy`` — last attempt failed or has been dropped without a
* successful re-establish; fall back to polling.
*/
export type PushHealth = 'connecting' | 'healthy' | 'unhealthy';
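// Illustrative consumer — `usePollingFallback` and `fetchStatus` are
// hypothetical names, not part of this module:
//   const healthy = useSelector(selectPushChannelHealthy);
//   usePollingFallback(fetchStatus, { enabled: !healthy });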
interface NotificationsState {
health: PushHealth;
/** Most-recent server-issued id; sent back as ``Last-Event-ID`` on reconnect. */
lastEventId: string | null;
/** Bounded ring of recent events for the in-app notifications surface. */
recentEvents: SSEEvent[];
/**
* Wallclock ms of last received data-bearing event. SSE keepalives
* are comment lines (no ``id:``/``data:``) and do NOT update this —
* they're filtered out at the parser level.
*/
lastEventReceivedAt: number | null;
/**
* Event ids of ``tool.approval.required`` notifications the user
* dismissed (close button or by navigating into the conversation),
* each tagged with the wallclock ms at which it was dismissed.
* Keyed by the SSE event id so a *new* approval for the same
* conversation re-surfaces; the dismissal only suppresses the one
* specific paused-tool prompt.
*
* Entries are evicted by TTL first (anything older than
* ``DISMISSED_TOOL_APPROVALS_TTL_MS``), then by FIFO cap. The TTL
* matters because a pure FIFO with a small cap can evict a *still-
* pending* approval id before the user acts on it — re-popping the
* toast on the next dispatch. The 24h ceiling is longer than any
* plausible approval-pending window.
*/
dismissedToolApprovals: Array<{ id: string; at: number }>;
}
const RECENT_EVENTS_CAP = 100;
const DISMISSED_TOOL_APPROVALS_CAP = 200;
const DISMISSED_TOOL_APPROVALS_TTL_MS = 24 * 60 * 60 * 1000;
const initialState: NotificationsState = {
health: 'connecting',
lastEventId: null,
recentEvents: [],
lastEventReceivedAt: null,
// Hydrate from localStorage: SSE backlog replay re-delivers the
// originating ``tool.approval.required`` envelopes on reload.
dismissedToolApprovals: loadDismissed(
DISMISSED_TOOL_APPROVALS_STORAGE_KEY,
DISMISSED_TOOL_APPROVALS_TTL_MS,
),
};
export const notificationsSlice = createSlice({
name: 'notifications',
initialState,
reducers: {
sseEventReceived: (state, action: PayloadAction<SSEEvent>) => {
const e = action.payload;
// Drop immediate duplicates. Snapshot replay + live tail can
// both deliver the same id when the live pubsub frame and the
// replay XRANGE overlap, and consumers that walk
// ``recentEvents`` (FileTree, ConnectorTree, MCPServerModal,
// ToolApprovalToast) would otherwise act on the same envelope
// twice. The route's dedup floor catches the common case; this
// is a belt-and-suspenders for in-tab StrictMode double-mounts
// and any envelope that slips through with the same id.
if (e.id && state.recentEvents[0]?.id === e.id) return;
state.recentEvents.unshift(e);
if (state.recentEvents.length > RECENT_EVENTS_CAP) {
state.recentEvents.length = RECENT_EVENTS_CAP;
}
if (e.id) state.lastEventId = e.id;
state.lastEventReceivedAt = Date.now();
},
sseHealthChanged: (state, action: PayloadAction<PushHealth>) => {
state.health = action.payload;
},
/**
* Lifecycle helper used by reconnect bookkeeping — does not record
* an event, just stamps "we heard from the server" so the polling
* fallback stays disabled while keepalives arrive.
*/
sseHeartbeat: (state) => {
state.lastEventReceivedAt = Date.now();
},
sseLastEventIdReset: (state) => {
// Backlog truncated — drop the cursor so the next reconnect
// doesn't try to resume past the retained window.
state.lastEventId = null;
},
/**
* Advance the cursor without recording an event. Fired for every
* id-bearing frame including keepalives and id-only comments,
* so the slice cursor tracks the freshest id the wire has
* delivered even when no envelope was dispatched. Without this,
* ``lastEventId`` would only advance via ``sseEventReceived`` and
* a long quiet period of keepalives would leave it stale —
* eventually re-snapshotting the same backlog on each reconnect
* and exhausting the per-user replay budget.
*/
sseLastEventIdAdvanced: (state, action: PayloadAction<string>) => {
state.lastEventId = action.payload;
},
clearRecentEvents: (state) => {
state.recentEvents = [];
},
/**
* Suppress a ``tool.approval.required`` notification by the SSE
* event id. The toast surface filters dismissed ids out; a *new*
* approval event for the same conversation has a different id
* and re-surfaces, which is the desired UX (each pause is its
* own decision).
*/
dismissToolApproval: (state, action: PayloadAction<string>) => {
const id = action.payload;
const now = Date.now();
// Evict expired entries first so the TTL — not the FIFO cap —
// governs when stale ids drop, keeping still-pending approvals
// suppressed.
const cutoff = now - DISMISSED_TOOL_APPROVALS_TTL_MS;
state.dismissedToolApprovals = state.dismissedToolApprovals.filter(
(entry) => entry.at >= cutoff && entry.id !== id,
);
state.dismissedToolApprovals.push({ id, at: now });
if (state.dismissedToolApprovals.length > DISMISSED_TOOL_APPROVALS_CAP) {
state.dismissedToolApprovals = state.dismissedToolApprovals.slice(
-DISMISSED_TOOL_APPROVALS_CAP,
);
}
saveDismissed(
DISMISSED_TOOL_APPROVALS_STORAGE_KEY,
state.dismissedToolApprovals,
);
},
},
});
export const {
sseEventReceived,
sseHealthChanged,
sseHeartbeat,
sseLastEventIdReset,
sseLastEventIdAdvanced,
clearRecentEvents,
dismissToolApproval,
} = notificationsSlice.actions;
export const selectSseHealth = (state: RootState): PushHealth =>
state.notifications.health;
export const selectPushChannelHealthy = (state: RootState): boolean =>
state.notifications.health === 'healthy';
export const selectLastEventId = (state: RootState): string | null =>
state.notifications.lastEventId;
export const selectLastEventReceivedAt = (state: RootState): number | null =>
state.notifications.lastEventReceivedAt;
export const selectRecentEvents = (state: RootState): SSEEvent[] =>
state.notifications.recentEvents;
// Memoised so ``useSelector`` consumers don't re-render on every
// unrelated ``notifications`` state change — the underlying ``{id,at}``
// array only changes when ``dismissToolApproval`` runs, but the
// projected ``.map`` would otherwise return a fresh array each call.
export const selectDismissedToolApprovals = createSelector(
(state: RootState) => state.notifications.dismissedToolApprovals,
(entries) => entries.map((entry) => entry.id),
);
export default notificationsSlice.reducer;

View File

@@ -17,7 +17,7 @@ import ContextMenu, { MenuOption } from '../components/ContextMenu';
import Pagination from '../components/DocumentPagination';
import DropdownMenu from '../components/DropdownMenu';
import SkeletonLoader from '../components/SkeletonLoader';
import { useDarkTheme, useDebouncedValue, useLoaderState } from '../hooks';
import { useDarkTheme, useLoaderState } from '../hooks';
import ConfirmationModal from '../modals/ConfirmationModal';
import { ActiveState, Doc, DocumentsProps } from '../models/misc';
import { getDocs, getDocsWithPagination } from '../preferences/preferenceApi';
@@ -58,7 +58,7 @@ export default function Sources({
const token = useSelector(selectToken);
const [searchTerm, setSearchTerm] = useState<string>('');
const debouncedSearchTerm = useDebouncedValue(searchTerm, 500);
const [debouncedSearchTerm, setDebouncedSearchTerm] = useState<string>('');
const [modalState, setModalState] = useState<ActiveState>('INACTIVE');
const [isOnboarding, setIsOnboarding] = useState<boolean>(false);
const [loading, setLoading] = useLoaderState(false);
@@ -117,6 +117,14 @@ export default function Sources({
document: null,
});
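// Inline 500ms debounce of the search term, replacing the shared
// useDebouncedValue hook.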
useEffect(() => {
const timer = setTimeout(() => {
setDebouncedSearchTerm(searchTerm);
}, 500);
return () => clearTimeout(timer);
}, [searchTerm]);
const refreshDocs = useCallback(
(
field: 'date' | 'tokens' | undefined,

View File

@@ -4,7 +4,6 @@ import agentPreviewReducer from './agents/agentPreviewSlice';
import workflowPreviewReducer from './agents/workflow/workflowPreviewSlice';
import { conversationSlice } from './conversation/conversationSlice';
import { sharedConversationSlice } from './conversation/sharedConversationSlice';
import notificationsReducer from './notifications/notificationsSlice';
import { getStoredRecentDocs } from './preferences/preferenceApi';
import {
Preference,
@@ -68,7 +67,6 @@ const store = configureStore({
upload: uploadReducer,
agentPreview: agentPreviewReducer,
workflowPreview: workflowPreviewReducer,
notifications: notificationsReducer,
},
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware().concat(prefListenerMiddleware.middleware),

View File

@@ -2,9 +2,9 @@ import { useCallback, useState } from 'react';
import { nanoid } from '@reduxjs/toolkit';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { useDispatch, useSelector, useStore } from 'react-redux';
import { useDispatch, useSelector } from 'react-redux';
import type { RootState } from '../store';
import userService from '../api/services/userService';
import { getSessionToken } from '../utils/providerUtils';
import Dropdown from '../components/Dropdown';
import Input from '../components/Input';
@@ -298,7 +298,6 @@ function Upload({
const { t } = useTranslation();
const dispatch = useDispatch();
const store = useStore<RootState>();
const ingestorOptions: IngestorOption[] = IngestorFormSchemas.filter(
(schema) => (schema.validate ? schema.validate() : true),
@@ -335,113 +334,110 @@ function Upload({
[dispatch],
);
/**
* Wait for the source.ingest.* SSE pipeline to flip this task into a
* terminal state, then run the post-completion side effects: refresh
* the global source list, auto-select the new doc, and fire the
* caller's ``onSuccessfulUpload`` hook. The slice's extraReducer
* (uploadSlice.ts) is the sole driver of the task's status; we only
* subscribe so the side effects can fire after the modal has closed.
*/
const trackTraining = useCallback(
(clientTaskId: string) => {
let handled = false;
(backendTaskId: string, clientTaskId: string) => {
let timeoutId: number | null = null;
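// Poll the task-status endpoint until the backend task reaches a
// terminal state, then run the post-completion side effects: refresh
// the source list, auto-select the new doc, and fire the caller's
// onSuccessfulUpload hook.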
const handleTerminal = (status: 'completed' | 'failed') => {
if (handled) return;
handled = true;
if (status !== 'completed') return;
getDocs(token)
.then((docs) => {
dispatch(setSourceDocs(docs));
if (Array.isArray(docs)) {
const existingDocIds = new Set(
(Array.isArray(sourceDocs) ? sourceDocs : [])
.map((doc: Doc) => doc?.id)
.filter((id): id is string => Boolean(id)),
);
const newDoc = docs.find(
(doc: Doc) => doc.id && !existingDocIds.has(doc.id),
);
if (newDoc) {
// If only one doc is selected, replace it completely
// If multiple docs are selected, append the new doc
if (selectedDocs.length === 1) {
dispatch(setSelectedDocs([newDoc]));
} else {
dispatch(setSelectedDocs([...selectedDocs, newDoc]));
const poll = () => {
userService
.getTaskStatus(backendTaskId, null)
.then((response) => response.json())
.then(async (data) => {
if (!data.success && data.message) {
if (timeoutId !== null) {
clearTimeout(timeoutId);
timeoutId = null;
}
handleTaskFailure(clientTaskId, data.message);
return;
}
if (data.status === 'SUCCESS') {
if (timeoutId !== null) {
clearTimeout(timeoutId);
timeoutId = null;
}
const docs = await getDocs(token);
dispatch(setSourceDocs(docs));
if (Array.isArray(docs)) {
const existingDocIds = new Set(
(Array.isArray(sourceDocs) ? sourceDocs : [])
.map((doc: Doc) => doc?.id)
.filter((id): id is string => Boolean(id)),
);
const newDoc = docs.find(
(doc: Doc) => doc.id && !existingDocIds.has(doc.id),
);
if (newDoc) {
// If only one doc is selected, replace it completely
// If multiple docs are selected, append the new doc
if (selectedDocs.length === 1) {
dispatch(setSelectedDocs([newDoc]));
} else {
dispatch(setSelectedDocs([...selectedDocs, newDoc]));
}
}
}
if (data.result?.limited) {
dispatch(
updateUploadTask({
id: clientTaskId,
updates: {
status: 'failed',
progress: 100,
errorMessage: t('modals.uploadDoc.progress.tokenLimit'),
},
}),
);
} else {
dispatch(
updateUploadTask({
id: clientTaskId,
updates: {
status: 'completed',
progress: 100,
errorMessage: undefined,
},
}),
);
onSuccessfulUpload?.();
}
} else if (data.status === 'FAILURE') {
if (timeoutId !== null) {
clearTimeout(timeoutId);
timeoutId = null;
}
handleTaskFailure(clientTaskId, data.result?.message);
} else if (data.status === 'PROGRESS') {
dispatch(
updateUploadTask({
id: clientTaskId,
updates: {
status: 'training',
progress: Math.min(100, data.result?.current ?? 0),
},
}),
);
timeoutId = window.setTimeout(poll, 5000);
} else {
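// PENDING or any unrecognised status: keep polling.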
timeoutId = window.setTimeout(poll, 5000);
}
onSuccessfulUpload?.();
})
.catch((err) => {
console.error(
'SSE-driven post-completion source-list refresh failed:',
err,
);
.catch((error) => {
if (timeoutId !== null) {
clearTimeout(timeoutId);
timeoutId = null;
}
handleTaskFailure(clientTaskId, error?.message);
});
};
const check = () => {
const state = store.getState();
const task = state.upload.tasks.find((t) => t.id === clientTaskId);
if (!task) return false;
if (task.status === 'completed' || task.status === 'failed') {
handleTerminal(task.status);
return true;
}
// Recover from the race where the terminal SSE landed before
// ``xhr.onload`` populated ``task.sourceId`` — the slice
// silently drops such events (no task to match by sourceId).
// Mirrors ConnectorTree/FileTree's ``recentEvents`` walk.
if (task.sourceId) {
for (const event of state.notifications.recentEvents) {
if (event.scope?.id !== task.sourceId) continue;
if (event.type === 'source.ingest.completed') {
handleTerminal('completed');
return true;
}
if (event.type === 'source.ingest.failed') {
handleTerminal('failed');
return true;
}
}
}
return false;
};
if (check()) return;
const MAX_WAIT_MS = 5 * 60_000;
let unsubscribe: (() => void) | null = null;
const timer = window.setTimeout(() => {
unsubscribe?.();
if (!handled) {
handled = true;
console.warn(
'trackTraining: timed out waiting for terminal SSE',
clientTaskId,
);
dispatch(
updateUploadTask({
id: clientTaskId,
updates: {
status: 'failed',
errorMessage:
'Timed out waiting for ingest completion. The ingest may still be running — please refresh to check.',
},
}),
);
}
}, MAX_WAIT_MS);
unsubscribe = store.subscribe(() => {
if (check()) {
window.clearTimeout(timer);
unsubscribe?.();
}
});
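// Kick off the first status check after a short initial delay.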
timeoutId = window.setTimeout(poll, 3000);
},
[dispatch, onSuccessfulUpload, selectedDocs, sourceDocs, store, token],
[dispatch, handleTaskFailure, onSuccessfulUpload, selectedDocs, sourceDocs, t, token],
);
const onDrop = useCallback(
@@ -503,23 +499,19 @@ function Upload({
xhr.onload = () => {
if (xhr.status >= 200 && xhr.status < 300) {
try {
const parsed = JSON.parse(xhr.responseText) as {
task_id?: string;
source_id?: string;
};
const parsed = JSON.parse(xhr.responseText) as { task_id?: string };
if (parsed.task_id) {
dispatch(
updateUploadTask({
id: clientTaskId,
updates: {
taskId: parsed.task_id,
sourceId: parsed.source_id,
status: 'training',
progress: 0,
},
}),
);
trackTraining(clientTaskId);
trackTraining(parsed.task_id, clientTaskId);
} else {
dispatch(
updateUploadTask({
@@ -635,23 +627,19 @@ function Upload({
xhr.onload = () => {
if (xhr.status >= 200 && xhr.status < 300) {
try {
const response = JSON.parse(xhr.responseText) as {
task_id?: string;
source_id?: string;
};
const response = JSON.parse(xhr.responseText) as { task_id?: string };
if (response.task_id) {
dispatch(
updateUploadTask({
id: clientTaskId,
updates: {
taskId: response.task_id,
sourceId: response.source_id,
status: 'training',
progress: 0,
},
}),
);
trackTraining(clientTaskId);
trackTraining(response.task_id, clientTaskId);
} else {
dispatch(
updateUploadTask({

View File

@@ -1,458 +0,0 @@
import { beforeEach, describe, expect, it } from 'vitest';
import notificationsReducer, {
sseEventReceived,
type SSEEvent,
} from '../notifications/notificationsSlice';
import reducer, {
addAttachment,
addUploadTask,
dismissUploadTask,
updateAttachment,
updateUploadTask,
type Attachment,
type UploadTask,
} from './uploadSlice';
const SOURCE_ID = 'src-1';
const makeTask = (overrides: Partial<UploadTask> = {}): UploadTask => ({
id: 't-1',
fileName: 'doc.pdf',
progress: 0,
status: 'preparing',
sourceId: SOURCE_ID,
...overrides,
});
const stateWithTask = (task: UploadTask) =>
reducer(undefined, addUploadTask(task));
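// Factory for source.ingest.* envelopes scoped to a source id.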
const ingest = (
type: string,
payload: Record<string, unknown> = {},
scopeId = SOURCE_ID,
) =>
sseEventReceived({
id: `id-${type}`,
type,
scope: { kind: 'source', id: scopeId },
payload,
});
describe('dismissal persistence across reload', () => {
const STORAGE_KEY = 'docsgpt:dismissedUploadSourceIds';
const SRC = 'src-persisted';
// Mirrors initialState as if hydrated from localStorage.
const seedState = (entries: { id: string; at: number }[]) =>
reducer(
{ attachments: [], tasks: [], dismissedSourceIds: entries },
{ type: '@@INIT' },
);
beforeEach(() => {
localStorage.clear();
});
it('dismissUploadTask writes the task sourceId to localStorage', () => {
let state = stateWithTask(
makeTask({ id: 't-dismiss', sourceId: SRC, status: 'completed' }),
);
state = reducer(state, dismissUploadTask('t-dismiss'));
expect(state.tasks[0].dismissed).toBe(true);
expect(state.dismissedSourceIds).toHaveLength(1);
expect(state.dismissedSourceIds[0].id).toBe(SRC);
const persisted = JSON.parse(localStorage.getItem(STORAGE_KEY) ?? '[]');
expect(persisted).toHaveLength(1);
expect(persisted[0].id).toBe(SRC);
});
it('skips persistence when the task has no sourceId yet', () => {
let state = stateWithTask(
makeTask({ id: 't-no-src', sourceId: undefined, status: 'preparing' }),
);
state = reducer(state, dismissUploadTask('t-no-src'));
expect(state.tasks[0].dismissed).toBe(true);
expect(state.dismissedSourceIds).toHaveLength(0);
expect(localStorage.getItem(STORAGE_KEY)).toBeNull();
});
it('auto-create on refresh marks the task dismissed when sourceId is in the persisted list', () => {
const state = seedState([{ id: SRC, at: Date.now() }]);
const next = reducer(
state,
ingest('source.ingest.progress', { current: 40 }, SRC),
);
const recovered = next.tasks.find((t) => t.sourceId === SRC);
expect(recovered).toBeDefined();
expect(recovered!.dismissed).toBe(true);
expect(recovered!.progress).toBe(40);
});
it('terminal events do NOT un-dismiss a task whose sourceId was previously dismissed', () => {
const state = seedState([{ id: SRC, at: Date.now() }]);
let next = reducer(state, ingest('source.ingest.queued', {}, SRC));
expect(next.tasks[0].dismissed).toBe(true);
next = reducer(next, ingest('source.ingest.completed', {}, SRC));
// Even though `wasTerminal` is false (just transitioned), the
// persisted dismissal keeps the toast closed.
expect(next.tasks[0].status).toBe('completed');
expect(next.tasks[0].dismissed).toBe(true);
});
it('updateUploadTask does not un-dismiss on terminal when sourceId was previously dismissed', () => {
const state = reducer(
{
attachments: [],
tasks: [
makeTask({
id: 't-1',
sourceId: SRC,
status: 'training',
dismissed: true,
}),
],
dismissedSourceIds: [{ id: SRC, at: Date.now() }],
},
{ type: '@@INIT' },
);
const next = reducer(
state,
updateUploadTask({
id: 't-1',
updates: { status: 'completed' },
}),
);
expect(next.tasks[0].status).toBe('completed');
expect(next.tasks[0].dismissed).toBe(true);
});
it('un-dismisses normally for sourceIds NOT in the persisted list', () => {
const state = reducer(undefined, { type: '@@INIT' });
const populated = reducer(
state,
addUploadTask(
makeTask({
id: 't-fresh',
sourceId: 'src-fresh',
status: 'training',
dismissed: true,
}),
),
);
const next = reducer(
populated,
updateUploadTask({ id: 't-fresh', updates: { status: 'completed' } }),
);
expect(next.tasks[0].dismissed).toBe(false);
});
});
describe('refresh recovery — auto-create from SSE when no task matches', () => {
it('creates a task on queued for an unknown sourceId', () => {
let state = reducer(undefined, addUploadTask(makeTask({ id: 'other' })));
state = reducer(
state,
ingest(
'source.ingest.queued',
{ filename: 'crawler.json', job_name: 'docs' },
'src-recovery',
),
);
const recovered = state.tasks.find((t) => t.sourceId === 'src-recovery');
expect(recovered).toBeDefined();
expect(recovered!.status).toBe('training');
expect(recovered!.fileName).toBe('crawler.json');
expect(recovered!.dismissed).toBe(false);
});
it('creates a task on progress for an unknown sourceId and applies the percent', () => {
let state: ReturnType<typeof reducer> = reducer(undefined, {
type: '@@INIT',
});
state = reducer(
state,
ingest(
'source.ingest.progress',
{ current: 55, total: 10, embedded_chunks: 5 },
'src-progress',
),
);
expect(state.tasks).toHaveLength(1);
expect(state.tasks[0].sourceId).toBe('src-progress');
expect(state.tasks[0].status).toBe('training');
expect(state.tasks[0].progress).toBe(55);
});
it('does NOT create a task on completed for an unknown sourceId (avoids backlog toast spam)', () => {
let state: ReturnType<typeof reducer> = reducer(undefined, {
type: '@@INIT',
});
state = reducer(
state,
ingest('source.ingest.completed', { tokens: 16 }, 'src-stale'),
);
expect(state.tasks).toHaveLength(0);
});
it('creates a task on failed for an unknown sourceId so error surfaces post-refresh', () => {
let state: ReturnType<typeof reducer> = reducer(undefined, {
type: '@@INIT',
});
state = reducer(
state,
ingest(
'source.ingest.failed',
{ error: 'embed worker died' },
'src-failed',
),
);
expect(state.tasks).toHaveLength(1);
expect(state.tasks[0].status).toBe('failed');
expect(state.tasks[0].errorMessage).toBe('embed worker died');
expect(state.tasks[0].dismissed).toBe(false);
});
it('falls back to job_name then sourceId when filename is absent', () => {
let state: ReturnType<typeof reducer> = reducer(undefined, {
type: '@@INIT',
});
state = reducer(
state,
ingest('source.ingest.queued', { job_name: 'docs-docs' }, 'src-jn'),
);
expect(state.tasks[0].fileName).toBe('docs-docs');
state = reducer(state, ingest('source.ingest.queued', {}, 'src-only'));
expect(state.tasks.find((t) => t.sourceId === 'src-only')!.fileName).toBe(
'src-only',
);
});
it('subsequent progress events update the recovered task in place', () => {
let state: ReturnType<typeof reducer> = reducer(undefined, {
type: '@@INIT',
});
state = reducer(
state,
ingest('source.ingest.queued', { filename: 'a.txt' }, 'src-flow'),
);
state = reducer(
state,
ingest('source.ingest.progress', { current: 30 }, 'src-flow'),
);
state = reducer(
state,
ingest('source.ingest.progress', { current: 80 }, 'src-flow'),
);
state = reducer(
state,
ingest('source.ingest.completed', { tokens: 100 }, 'src-flow'),
);
expect(state.tasks).toHaveLength(1);
expect(state.tasks[0].status).toBe('completed');
expect(state.tasks[0].progress).toBe(100);
});
});
describe('source.ingest.queued', () => {
it('does not regress a task already in training', () => {
let state = stateWithTask(makeTask({ status: 'training', progress: 42 }));
state = reducer(state, ingest('source.ingest.queued'));
expect(state.tasks[0].status).toBe('training');
expect(state.tasks[0].progress).toBe(42);
});
it('transitions preparing -> training and zeros progress', () => {
let state = stateWithTask(makeTask({ status: 'preparing', progress: 12 }));
state = reducer(state, ingest('source.ingest.queued'));
expect(state.tasks[0].status).toBe('training');
expect(state.tasks[0].progress).toBe(0);
});
});
describe('source.ingest.progress', () => {
it('clamps to 0..100 and is monotonic', () => {
let state = stateWithTask(makeTask({ status: 'training' }));
state = reducer(state, ingest('source.ingest.progress', { current: 30 }));
expect(state.tasks[0].progress).toBe(30);
// Higher value advances.
state = reducer(state, ingest('source.ingest.progress', { current: 150 }));
expect(state.tasks[0].progress).toBe(100);
// Lower value never regresses.
state = reducer(state, ingest('source.ingest.progress', { current: 50 }));
expect(state.tasks[0].progress).toBe(100);
// Negative gets clamped at 0 but still cannot regress already-higher.
state = reducer(state, ingest('source.ingest.progress', { current: -10 }));
expect(state.tasks[0].progress).toBe(100);
});
});
describe('source.ingest.completed', () => {
it('transitions training -> completed and sets dismissed=false', () => {
let state = stateWithTask(
makeTask({ status: 'training', dismissed: true }),
);
state = reducer(state, ingest('source.ingest.completed'));
expect(state.tasks[0].status).toBe('completed');
expect(state.tasks[0].progress).toBe(100);
expect(state.tasks[0].dismissed).toBe(false);
expect(state.tasks[0].tokenLimitReached).toBe(false);
});
it('with limited=true transitions training -> failed and flags tokenLimitReached', () => {
let state = stateWithTask(
makeTask({ status: 'training', dismissed: true }),
);
state = reducer(
state,
ingest('source.ingest.completed', { limited: true }),
);
expect(state.tasks[0].status).toBe('failed');
expect(state.tasks[0].progress).toBe(100);
expect(state.tasks[0].tokenLimitReached).toBe(true);
expect(state.tasks[0].dismissed).toBe(false);
});
it('does not re-un-dismiss when a duplicate terminal event arrives', () => {
// Initial terminal — wasTerminal=false, dismissed flipped to false.
let state = stateWithTask(makeTask({ status: 'training' }));
state = reducer(state, ingest('source.ingest.completed'));
expect(state.tasks[0].dismissed).toBe(false);
// User dismisses the toast manually.
state = {
...state,
tasks: state.tasks.map((t) => ({ ...t, dismissed: true })),
};
// Duplicate terminal envelope (StrictMode remount, reconnect overlap).
state = reducer(state, ingest('source.ingest.completed'));
expect(state.tasks[0].status).toBe('completed');
expect(state.tasks[0].dismissed).toBe(true);
});
});
describe('source.ingest.failed', () => {
it('transitions training -> failed with the error message', () => {
let state = stateWithTask(makeTask({ status: 'training' }));
state = reducer(
state,
ingest('source.ingest.failed', { error: 'parser blew up' }),
);
expect(state.tasks[0].status).toBe('failed');
expect(state.tasks[0].errorMessage).toBe('parser blew up');
expect(state.tasks[0].dismissed).toBe(false);
});
it('does not re-un-dismiss when a duplicate failed event arrives', () => {
let state = stateWithTask(makeTask({ status: 'training' }));
state = reducer(state, ingest('source.ingest.failed', { error: 'oops' }));
state = {
...state,
tasks: state.tasks.map((t) => ({ ...t, dismissed: true })),
};
state = reducer(state, ingest('source.ingest.failed', { error: 'oops' }));
expect(state.tasks[0].dismissed).toBe(true);
});
});
describe('attachment race recovery', () => {
const ATTACHMENT_ID = 'att-1';
const CLIENT_ID = 'ui-1';
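// Envelope factory scoped to the server-assigned attachment id.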
const attEvent = (
type: string,
payload: Record<string, unknown> = {},
): SSEEvent => ({
id: `id-${type}`,
type,
scope: { kind: 'attachment', id: ATTACHMENT_ID },
payload,
});
const makeAttachment = (overrides: Partial<Attachment> = {}): Attachment => ({
id: CLIENT_ID,
fileName: 'small.pdf',
progress: 10,
status: 'processing',
taskId: 'celery-1',
attachmentId: ATTACHMENT_ID,
...overrides,
});
it('drops attachment.completed silently when no row matches attachmentId', () => {
const state = reducer(
undefined,
sseEventReceived(attEvent('attachment.completed', { token_count: 42 })),
);
expect(state.attachments).toHaveLength(0);
});
it('lands the terminal envelope in notifications.recentEvents for later recovery', () => {
const notifState = notificationsReducer(
undefined,
sseEventReceived(attEvent('attachment.completed', { token_count: 7 })),
);
expect(notifState.recentEvents).toHaveLength(1);
expect(notifState.recentEvents[0].scope?.id).toBe(ATTACHMENT_ID);
expect(notifState.recentEvents[0].type).toBe('attachment.completed');
});
it('reconciler dispatch flips the row to completed after the late row addition', () => {
// Full race: terminal SSE first, then xhr.onload adds the row,
// then trackAttachment.check() walks recentEvents and dispatches.
const terminal = attEvent('attachment.completed', { token_count: 99 });
const notifState = notificationsReducer(
undefined,
sseEventReceived(terminal),
);
let uploadState = reducer(undefined, addAttachment(makeAttachment()));
expect(uploadState.attachments[0].status).toBe('processing');
const found = notifState.recentEvents.find(
(e) => e.scope?.id === ATTACHMENT_ID && e.type === 'attachment.completed',
);
expect(found).toBeDefined();
const tokenCount = Number(
(found?.payload as { token_count?: unknown })?.token_count,
);
uploadState = reducer(
uploadState,
updateAttachment({
id: CLIENT_ID,
updates: {
status: 'completed',
progress: 100,
...(Number.isFinite(tokenCount) ? { token_count: tokenCount } : {}),
},
}),
);
expect(uploadState.attachments[0].status).toBe('completed');
expect(uploadState.attachments[0].progress).toBe(100);
expect(uploadState.attachments[0].token_count).toBe(99);
});
it('attachment.failed envelope can drive a stuck row to failed via reconciler', () => {
const failed = attEvent('attachment.failed', { error: 'docling boom' });
const notifState = notificationsReducer(
undefined,
sseEventReceived(failed),
);
let uploadState = reducer(undefined, addAttachment(makeAttachment()));
const found = notifState.recentEvents.find(
(e) => e.scope?.id === ATTACHMENT_ID && e.type === 'attachment.failed',
);
expect(found).toBeDefined();
uploadState = reducer(
uploadState,
updateAttachment({ id: CLIENT_ID, updates: { status: 'failed' } }),
);
expect(uploadState.attachments[0].status).toBe('failed');
});
});

View File

@@ -1,46 +1,12 @@
import { createSelector, createSlice, PayloadAction } from '@reduxjs/toolkit';
import {
DismissedEntry,
isDismissed,
loadDismissed,
saveDismissed,
} from '../notifications/dismissedPersistence';
import { sseEventReceived } from '../notifications/notificationsSlice';
import { RootState } from '../store';
const DISMISSED_SOURCE_IDS_STORAGE_KEY = 'docsgpt:dismissedUploadSourceIds';
const DISMISSED_SOURCE_IDS_TTL_MS = 24 * 60 * 60 * 1000;
const DISMISSED_SOURCE_IDS_CAP = 200;
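// TTL-evict, de-dupe by id, append the fresh entry, then apply the
// FIFO cap as a backstop — same eviction order as dismissToolApproval.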
function recordDismissedSourceId(
list: DismissedEntry[],
sourceId: string,
): DismissedEntry[] {
const now = Date.now();
const cutoff = now - DISMISSED_SOURCE_IDS_TTL_MS;
const next = list.filter(
(entry) => entry.at >= cutoff && entry.id !== sourceId,
);
next.push({ id: sourceId, at: now });
return next.length > DISMISSED_SOURCE_IDS_CAP
? next.slice(-DISMISSED_SOURCE_IDS_CAP)
: next;
}
export interface Attachment {
id: string; // Client-side state-management id (uuid generated in MessageInput)
id: string; // Unique identifier for the attachment (required for state management)
fileName: string;
progress: number;
status: 'uploading' | 'processing' | 'completed' | 'failed';
taskId: string; // Server-assigned celery task ID (used for API calls)
/**
* Server-assigned attachment id (stable across the lifecycle —
* ``attachment.*`` SSE events use this in ``scope.id``). Set as
* soon as the upload response returns. Distinct from ``id``
* (client) and ``taskId`` (celery).
*/
attachmentId?: string;
taskId: string; // Server-assigned task ID (used for API calls)
token_count?: number;
}
@@ -57,40 +23,18 @@ export interface UploadTask {
progress: number;
status: UploadTaskStatus;
taskId?: string;
/**
* Server-derived source id (uuid5 over the idempotency key) returned by
* the upload endpoint. Used to correlate inbound SSE ingest events
* (``source.ingest.*``) back to this task without consulting the
* polling endpoint.
*/
sourceId?: string;
errorMessage?: string;
dismissed?: boolean;
/**
* Flipped when ``source.ingest.completed`` carries
* ``payload.limited === true`` (the worker hit a token cap during
* ingest). The slice routes such events to a failed status and
* sets this flag so ``UploadToast`` can surface a translated
* token-limit message instead of a generic error. Forward-looking:
* no worker code path sets ``limited: true`` today.
*/
tokenLimitReached?: boolean;
}
interface UploadState {
attachments: Attachment[];
tasks: UploadTask[];
/** Persisted dismissed sourceIds; keeps tasks auto-created from backlog replay silent. */
dismissedSourceIds: DismissedEntry[];
}
const initialState: UploadState = {
attachments: [],
tasks: [],
dismissedSourceIds: loadDismissed(
DISMISSED_SOURCE_IDS_STORAGE_KEY,
DISMISSED_SOURCE_IDS_TTL_MS,
),
};
export const uploadSlice = createSlice({
@@ -159,19 +103,9 @@ export const uploadSlice = createSlice({
);
if (index !== -1) {
const updates = action.payload.updates;
const existingSourceId = state.tasks[index].sourceId;
const incomingSourceId = updates.sourceId;
const effectiveSourceId = incomingSourceId ?? existingSourceId;
const sourceWasDismissed =
!!effectiveSourceId &&
isDismissed(state.dismissedSourceIds, effectiveSourceId);
// Re-surface on terminal status, except when the sourceId was
// dismissed in a prior session (don't re-pop after reload).
if (
(updates.status === 'completed' || updates.status === 'failed') &&
!sourceWasDismissed
) {
// When task completes or fails, set dismissed to false to notify user
if (updates.status === 'completed' || updates.status === 'failed') {
state.tasks[index] = {
...state.tasks[index],
...updates,
@@ -192,184 +126,12 @@ export const uploadSlice = createSlice({
...state.tasks[index],
dismissed: true,
};
// Persist by sourceId so a reload doesn't re-pop via the SSE
// backlog replay's auto-create branch.
const sourceId = state.tasks[index].sourceId;
if (sourceId) {
state.dismissedSourceIds = recordDismissedSourceId(
state.dismissedSourceIds,
sourceId,
);
saveDismissed(
DISMISSED_SOURCE_IDS_STORAGE_KEY,
state.dismissedSourceIds,
);
}
}
},
removeUploadTask: (state, action: PayloadAction<string>) => {
state.tasks = state.tasks.filter((task) => task.id !== action.payload);
},
},
extraReducers: (builder) => {
// Consume backend SSE ingest events for sub-second progress and
// terminal-status updates. The match is by ``sourceId`` (set by
// the upload endpoint's response). Polling stays as the
// correctness-of-record fallback in Upload.tsx.
builder.addCase(sseEventReceived, (state, action) => {
const e = action.payload;
const scopeId =
typeof e.scope?.id === 'string' && e.scope.id.length > 0
? e.scope.id
: undefined;
// Attachment events flow through the same SSE pipe; route them
// to ``state.attachments`` matched by ``attachmentId``. SSE is
// the sole driver of attachment state transitions — polling
// has been removed. Events for attachments uploaded in another
// session are silently dropped.
if (e.type.startsWith('attachment.') && scopeId) {
const attachment = state.attachments.find(
(a) => a.attachmentId === scopeId,
);
if (attachment) {
const payload = (e.payload || {}) as Record<string, unknown>;
switch (e.type) {
case 'attachment.queued':
case 'attachment.progress': {
if (
attachment.status === 'completed' ||
attachment.status === 'failed'
) {
break;
}
attachment.status = 'processing';
const current = Number(payload.current);
if (Number.isFinite(current)) {
const clamped = Math.max(0, Math.min(100, current));
if (clamped > attachment.progress) {
attachment.progress = clamped;
}
}
break;
}
case 'attachment.completed': {
attachment.status = 'completed';
attachment.progress = 100;
// Replace the client-generated uuid with the server's
// attachment id so question submission
// (Conversation.tsx:174) sends an id the backend can
// resolve. Without this the backend would silently drop
// the attachment from the message context.
attachment.id = scopeId;
const tokenCount = Number(payload.token_count);
if (Number.isFinite(tokenCount)) {
attachment.token_count = tokenCount;
}
break;
}
case 'attachment.failed': {
attachment.status = 'failed';
break;
}
default:
break;
}
}
return;
}
if (!e.type.startsWith('source.ingest.')) return;
if (!scopeId) return;
// ``source.ingest.*`` payloads do not carry a celery task_id
// (see ``application/worker.py`` — only ``source_id`` /
// ``job_name`` / ``filename`` are published), so ``sourceId``
// is the only correlation key.
const payload = (e.payload || {}) as Record<string, unknown>;
let task = state.tasks.find((t) => t.sourceId === scopeId);
if (!task) {
// Auto-create on backlog/live SSE when the local task is gone
// (page refresh mid-ingest). Skip ``completed`` so the historical
// backlog doesn't pop toasts for already-finished sources.
if (e.type === 'source.ingest.completed') return;
const fileName =
(typeof payload.filename === 'string' && payload.filename) ||
(typeof payload.job_name === 'string' && payload.job_name) ||
scopeId;
const previouslyDismissed = isDismissed(
state.dismissedSourceIds,
scopeId,
);
task = {
id: scopeId,
fileName,
progress: 0,
status: 'training',
sourceId: scopeId,
dismissed: previouslyDismissed,
};
state.tasks.push(task);
task = state.tasks[state.tasks.length - 1];
}
const wasTerminal =
task.status === 'completed' || task.status === 'failed';
const sourceWasDismissed =
!!task.sourceId && isDismissed(state.dismissedSourceIds, task.sourceId);
switch (e.type) {
case 'source.ingest.queued':
// Don't regress a task already past 'training' (e.g. the
// queued event arrives after the upload XHR finished and
// status flipped to 'training'). Idempotent on retries.
if (task.status === 'preparing' || task.status === 'uploading') {
task.status = 'training';
task.progress = 0;
}
break;
case 'source.ingest.progress': {
const current = Number(payload.current);
if (!Number.isFinite(current)) break;
// Clamp + monotonic — never regress an already-higher value.
const clamped = Math.max(0, Math.min(100, current));
if (task.status === 'completed' || task.status === 'failed') break;
task.status = 'training';
if (clamped > task.progress) task.progress = clamped;
break;
}
case 'source.ingest.completed':
if (payload.limited === true) {
// Token-cap reached during ingest — surface as a failure
// so the toast shows the translated limit message rather
// than a misleading success state.
task.status = 'failed';
task.progress = 100;
task.tokenLimitReached = true;
task.errorMessage = undefined;
// Only un-dismiss on the initial terminal transition;
// duplicate envelopes (StrictMode remount, reconnect
// overlap) must not re-pop a user-dismissed toast.
if (!wasTerminal && !sourceWasDismissed) task.dismissed = false;
} else {
task.status = 'completed';
task.progress = 100;
task.errorMessage = undefined;
task.tokenLimitReached = false;
if (!wasTerminal && !sourceWasDismissed) task.dismissed = false;
}
break;
case 'source.ingest.failed':
task.status = 'failed';
task.errorMessage =
typeof payload.error === 'string'
? payload.error
: 'Ingestion failed.';
if (!wasTerminal && !sourceWasDismissed) task.dismissed = false;
break;
default:
break;
}
});
},
});
export const {

View File

@@ -3,8 +3,6 @@
* Follows the convention: {provider}_session_token
*/
import userService from '../api/services/userService';
export const getSessionToken = (provider: string): string | null => {
return localStorage.getItem(`${provider}_session_token`);
};
@@ -21,5 +19,16 @@ export const validateProviderSession = async (
token: string | null,
provider: string,
) => {
return await userService.validateConnectorSession(provider, token);
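// Hit the validate-session endpoint directly: the caller's auth token
// rides in the Authorization header, the provider's stored session
// token in the request body.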
const apiHost = import.meta.env.VITE_API_HOST;
return await fetch(`${apiHost}/api/connectors/validate-session`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`,
},
body: JSON.stringify({
provider,
session_token: getSessionToken(provider),
}),
});
};

View File

@@ -12,7 +12,7 @@
"module": "ESNext",
"moduleResolution": "Bundler",
"resolveJsonModule": true,
"types": ["vite-plugin-svgr/client", "vitest/globals"],
"types": ["vite-plugin-svgr/client"],
"isolatedModules": true,
"noEmit": true,
"jsx": "react-jsx",

View File

@@ -1,4 +1,3 @@
/// <reference types="vitest" />
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react';
import svgr from 'vite-plugin-svgr';
@@ -12,10 +11,4 @@ export default defineConfig({
'@': path.resolve(__dirname, './src'),
},
},
test: {
environment: 'happy-dom',
globals: true,
include: ['src/**/*.test.{ts,tsx}'],
setupFiles: ['./vitest.setup.ts'],
},
});

View File

@@ -1,24 +0,0 @@
// happy-dom's localStorage lives on `window`; slices read the bare
// global at module-load. Install a Map-backed shim on globalThis.
const store = new Map<string, string>();
const shim = {
getItem: (k: string) => (store.has(k) ? store.get(k)! : null),
setItem: (k: string, v: string) => {
store.set(k, String(v));
},
removeItem: (k: string) => {
store.delete(k);
},
clear: () => store.clear(),
key: (i: number) => Array.from(store.keys())[i] ?? null,
get length() {
return store.size;
},
};
// Force-override: some happy-dom versions expose a Storage stub
// without getItem, so simple assignment isn't enough.
Object.defineProperty(globalThis, 'localStorage', {
value: shim,
writable: true,
configurable: true,
});

View File

@@ -492,7 +492,7 @@ class TestMCPOAuthManager:
from application.agents.tools.mcp_tool import MCPOAuthManager
manager = MCPOAuthManager(MagicMock())
result = manager.get_oauth_status("", "alice")
result = manager.get_oauth_status("")
assert result["status"] == "not_started"

View File

@@ -8,7 +8,6 @@ MCPOAuthManager.
import asyncio
import concurrent.futures
import json
from unittest.mock import MagicMock, patch
import pytest
@@ -870,110 +869,17 @@ class TestMCPOAuthManager:
from application.agents.tools.mcp_tool import MCPOAuthManager
manager = MCPOAuthManager(MagicMock())
result = manager.get_oauth_status("", "alice")
result = manager.get_oauth_status("")
assert result["status"] == "not_started"
def test_get_oauth_status_no_user(self):
def test_get_oauth_status_with_task(self):
from application.agents.tools.mcp_tool import MCPOAuthManager
manager = MCPOAuthManager(MagicMock())
result = manager.get_oauth_status("task123", "")
# Without a user id we can't address the per-user SSE journal,
# so the manager refuses rather than scanning every user's
# stream.
assert result["status"] == "not_found"
def test_get_oauth_status_reads_completed_envelope_from_journal(self):
"""The manager walks the user's SSE Streams journal and
surfaces the latest ``mcp.oauth.*`` envelope for the task.
Verifies the full polling-contract surface: ``status`` is
derived from the event-type suffix, and the completed
payload's ``tools`` / ``tools_count`` fields are passed
through unchanged so ``mcp.py``'s ``connect_mcp`` can use them
without further plumbing.
"""
from application.agents.tools.mcp_tool import MCPOAuthManager
completed_envelope = json.dumps(
{
"type": "mcp.oauth.completed",
"scope": {"kind": "mcp_oauth", "id": "task123"},
"payload": {
"task_id": "task123",
"tools": [{"name": "t1", "description": "d"}],
"tools_count": 1,
},
}
).encode("utf-8")
# xrevrange yields newest-first; an unrelated older entry
# should be skipped on the way to the matching one.
unrelated_envelope = json.dumps(
{
"type": "source.ingest.completed",
"scope": {"kind": "source", "id": "abc"},
"payload": {},
}
).encode("utf-8")
mock_redis = MagicMock()
mock_redis.xrevrange.return_value = [
(b"1735682400000-0", {b"event": completed_envelope}),
(b"1735682300000-0", {b"event": unrelated_envelope}),
]
manager = MCPOAuthManager(mock_redis)
result = manager.get_oauth_status("task123", "alice")
assert result["status"] == "completed"
assert result["tools"] == [{"name": "t1", "description": "d"}]
assert result["tools_count"] == 1
# The stream key is scoped per-user — never global.
mock_redis.xrevrange.assert_called_once()
call_args = mock_redis.xrevrange.call_args
assert "user:alice:stream" in call_args.args
# Scan window must cover the full bounded stream so a flood of
# concurrent source-ingest events between the popup completing and
# the Save click can't push the OAuth envelope out of view.
from application.core.settings import settings
assert call_args.kwargs.get("count") >= settings.EVENTS_STREAM_MAXLEN
def test_get_oauth_status_scan_covers_events_stream_maxlen(self):
"""Regression: count must scale with ``EVENTS_STREAM_MAXLEN``
so the OAuth completion envelope is reachable even after
concurrent source-ingest events flood the user stream."""
from application.agents.tools.mcp_tool import MCPOAuthManager
from application.core.settings import settings
mock_redis = MagicMock()
mock_redis.xrevrange.return_value = []
manager = MCPOAuthManager(mock_redis)
manager.get_oauth_status("task123", "alice")
call_args = mock_redis.xrevrange.call_args
assert call_args.kwargs.get("count") >= settings.EVENTS_STREAM_MAXLEN
def test_get_oauth_status_returns_not_found_when_no_match(self):
from application.agents.tools.mcp_tool import MCPOAuthManager
mock_redis = MagicMock()
# Stream is non-empty but has nothing matching this task.
unrelated_envelope = json.dumps(
{
"type": "mcp.oauth.completed",
"scope": {"kind": "mcp_oauth", "id": "other-task"},
"payload": {"task_id": "other-task"},
}
).encode("utf-8")
mock_redis.xrevrange.return_value = [
(b"1735682400000-0", {b"event": unrelated_envelope}),
]
manager = MCPOAuthManager(mock_redis)
result = manager.get_oauth_status("task123", "alice")
assert result["status"] == "not_found"
with patch("application.agents.tools.mcp_tool.mcp_oauth_status_task",
return_value={"status": "complete"}):
manager = MCPOAuthManager(MagicMock())
result = manager.get_oauth_status("task123")
assert result["status"] == "complete"
# =====================================================================
@@ -1419,7 +1325,6 @@ class TestSetupClientExtended:
tool._client = None
tool.available_tools = []
tool.user_id = "user1"
tool.oauth_redirect_publish = None
mock_client = MagicMock()
with patch.object(MCPTool, "_create_transport", return_value=MagicMock()), \
@@ -2352,21 +2257,16 @@ class TestMCPOAuthManagerExtended:
# Should store error in redis
mock_redis.setex.assert_called()
def test_get_oauth_status_returns_not_found_on_redis_error(self):
"""A failure inside ``xrevrange`` (Redis down, network blip) is
swallowed — the manager returns ``not_found`` so the caller
can present a clean "OAuth failed, try again" message rather
than a 500.
"""
def test_get_oauth_status_task_error(self):
from application.agents.tools.mcp_tool import MCPOAuthManager
mock_redis = MagicMock()
mock_redis.xrevrange.side_effect = Exception("Redis went away")
manager = MCPOAuthManager(mock_redis)
result = manager.get_oauth_status("task123", "alice")
assert result["status"] == "not_found"
with patch(
"application.agents.tools.mcp_tool.mcp_oauth_status_task",
side_effect=Exception("task failed"),
):
manager = MCPOAuthManager(MagicMock())
with pytest.raises(Exception, match="task failed"):
manager.get_oauth_status("task123")
# =====================================================================

View File

@@ -449,9 +449,6 @@ class TestFinalizeMessage:
question="q",
decoded_token={"sub": user},
)
from application.storage.db.repositories.conversations import (
MessageUpdateOutcome,
)
assert svc.finalize_message(
res["message_id"],
"real answer",
@@ -461,7 +458,7 @@ class TestFinalizeMessage:
model_id="gpt-4",
metadata={"foo": "bar"},
status="complete",
) is MessageUpdateOutcome.UPDATED
) is True
msgs = ConversationsRepository(pg_conn).get_messages(
res["conversation_id"],
@@ -491,15 +488,12 @@ class TestFinalizeMessage:
decoded_token={"sub": user},
)
err = RuntimeError("provider down")
from application.storage.db.repositories.conversations import (
MessageUpdateOutcome,
)
assert svc.finalize_message(
res["message_id"],
"fallback text",
status="failed",
error=err,
) is MessageUpdateOutcome.UPDATED
) is True
msgs = ConversationsRepository(pg_conn).get_messages(
res["conversation_id"],
@@ -532,12 +526,9 @@ class TestFinalizeMessage:
),
{"cid": "c1", "mid": res["message_id"]},
)
from application.storage.db.repositories.conversations import (
MessageUpdateOutcome,
)
assert svc.finalize_message(
res["message_id"], "ans", status="complete",
) is MessageUpdateOutcome.UPDATED
) is True
status = pg_conn.execute(
sql_text("SELECT status FROM tool_call_attempts WHERE call_id = :cid"),
@@ -545,19 +536,16 @@ class TestFinalizeMessage:
).scalar()
assert status == "confirmed"
def test_finalize_returns_not_found_for_unknown_message(self, pg_conn):
def test_finalize_returns_false_for_unknown_message(self, pg_conn):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.storage.db.repositories.conversations import (
MessageUpdateOutcome,
)
with _patch_db(pg_conn):
assert ConversationService().finalize_message(
"00000000-0000-0000-0000-000000000000",
"x",
status="complete",
) is MessageUpdateOutcome.NOT_FOUND
) is False
def test_finalize_rolls_back_tool_call_confirm_on_message_update_failure(
self, pg_conn
@@ -663,9 +651,6 @@ class TestFinalizeMessage:
question="long question that becomes the fallback name",
decoded_token={"sub": user},
)
from application.storage.db.repositories.conversations import (
MessageUpdateOutcome,
)
assert svc.finalize_message(
res["message_id"],
"answer",
@@ -679,7 +664,7 @@ class TestFinalizeMessage:
"long question that becomes the fallback name"[:50]
),
},
) is MessageUpdateOutcome.UPDATED
) is True
repo = ConversationsRepository(pg_conn)
conv = repo.get_any(res["conversation_id"], user)

View File

@@ -343,50 +343,6 @@ class TestProcessResponseStreamExtended:
result = resource.process_response_stream(iter(stream))
assert result["thought"] == "thinking..."
def test_handles_id_prefixed_chunks(self, mock_mongo_db, flask_app):
"""``complete_stream`` emits ``id: <seq>\\n`` before each
``data:`` line so reconnects can resume. The non-streaming
``/api/answer`` consumer must skip the ``id:`` header (and the
informational ``message_id`` event) without breaking JSON
decoding.
"""
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
stream = [
'id: 0\n'
f'data: {json.dumps({"type": "message_id", "message_id": "abc"})}\n\n',
'id: 1\n'
f'data: {json.dumps({"type": "answer", "answer": "Hello "})}\n\n',
'id: 2\n'
f'data: {json.dumps({"type": "answer", "answer": "world"})}\n\n',
'id: 3\n'
f'data: {json.dumps({"type": "id", "id": "conv-1"})}\n\n',
'id: 4\n'
f'data: {json.dumps({"type": "end"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result["error"] is None
assert result["answer"] == "Hello world"
assert result["conversation_id"] == "conv-1"
def test_skips_keepalive_comment_lines(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
stream = [
': keepalive\n\n',
'id: 0\n'
f'data: {json.dumps({"type": "answer", "answer": "ok"})}\n\n',
'id: 1\n'
f'data: {json.dumps({"type": "end"})}\n\n',
]
result = resource.process_response_stream(iter(stream))
assert result["answer"] == "ok"
assert result["error"] is None
@pytest.mark.unit
class TestCheckUsageStringBooleans:
@@ -637,251 +593,6 @@ class TestCompleteStreamGeneratorExit:
next(gen)
gen.close() # Should not crash even with save error
def test_generator_exit_before_any_response_journals_error_not_end(
self, mock_mongo_db, flask_app,
):
"""A client disconnect right after the early ``message_id`` frame
leaves ``response_full`` empty, so finalize never runs. The abort
handler must journal ``error`` (not ``end``) and flip the row to
``failed`` — otherwise a reconnecting client sees a terminal
``end`` for a row whose DB status is still non-terminal and the
UI parks on a blank successful answer.
"""
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
# An agent that's primed but never yields anything before the
# client disconnects — keeps ``response_full`` empty.
def gen_never_yields():
if False:
yield # pragma: no cover
return
mock_agent.gen.return_value = gen_never_yields()
mock_agent.compression_metadata = None
mock_agent.compression_saved = False
mock_agent.tool_calls = []
resource.conversation_service = MagicMock()
resource.conversation_service.save_user_question.return_value = {
"conversation_id": "conv1",
"message_id": "msg1",
"request_id": "req1",
}
journaled: list[tuple] = []
def _capture_record(message_id, sequence_no, event_type, payload):
journaled.append((message_id, sequence_no, event_type, payload))
return True
with patch(
"application.api.answer.routes.base.record_event",
side_effect=_capture_record,
):
gen = resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id="conv1",
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=True,
model_id="gpt-4",
)
# Drain the early ``message_id`` event, then close before
# the agent yields anything.
next(gen)
gen.close()
# The early message_id frame got journaled (seq 0); the abort
# handler must follow with an ``error`` event (NOT ``end``).
terminal_events = [
(et, pl) for (_, _, et, pl) in journaled if et in ("end", "error")
]
assert len(terminal_events) == 1, (
f"expected exactly one terminal journal write, got {terminal_events}"
)
assert terminal_events[0][0] == "error", (
f"expected ``error`` terminal but got ``end``: {terminal_events}"
)
payload = terminal_events[0][1]
assert payload.get("type") == "error"
assert payload.get("code") == "client_disconnect"
# And the DB row should have been flipped to ``failed`` via
# finalize_message. The mocked service records the call.
finalize_calls = (
resource.conversation_service.finalize_message.call_args_list
)
assert len(finalize_calls) == 1
assert finalize_calls[0].kwargs.get("status") == "failed"
def test_generator_exit_after_response_still_journals_end(
self, mock_mongo_db, flask_app,
):
"""Regression guard: a disconnect AFTER partial response was
produced and ``finalize_message`` succeeded must still journal
``end`` (the row matches ``complete``). Only the empty-response
branch flips to ``error``.
"""
from application.api.answer.routes.base import BaseAnswerResource
from application.storage.db.repositories.conversations import (
MessageUpdateOutcome,
)
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
def gen_answers():
yield {"answer": "partial"}
mock_agent.gen.return_value = gen_answers()
mock_agent.compression_metadata = None
mock_agent.compression_saved = False
mock_agent.tool_calls = []
resource.conversation_service = MagicMock()
resource.conversation_service.finalize_message.return_value = (
MessageUpdateOutcome.UPDATED
)
resource.conversation_service.save_user_question.return_value = {
"conversation_id": "conv1",
"message_id": "msg1",
"request_id": "req1",
}
journaled: list[tuple] = []
def _capture_record(message_id, sequence_no, event_type, payload):
journaled.append((message_id, sequence_no, event_type, payload))
return True
with patch(
"application.api.answer.routes.base.record_event",
side_effect=_capture_record,
):
gen = resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id="conv1",
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=True,
model_id="gpt-4",
)
next(gen) # message_id frame
next(gen) # answer frame (consumes ``partial``)
gen.close()
terminal_events = [
(et, pl) for (_, _, et, pl) in journaled if et in ("end", "error")
]
assert terminal_events and terminal_events[0][0] == "end", (
f"finalize succeeded but abort journaled {terminal_events}"
)
def test_generator_exit_after_normal_finalize_already_complete_journals_end(
self, mock_mongo_db, flask_app,
):
"""Regression for the race where the normal-path finalize wins
against a client disconnect.
Trace: agent finishes, ``complete_stream`` runs the normal-path
``finalize_message`` at base.py:632 and flips the row to
``complete``. The client TCP-resets before the ``end`` frame can
be journaled. The GeneratorExit handler calls ``finalize_message``
again — and the repository, gated by ``only_if_non_terminal``,
reports ``ALREADY_COMPLETE`` because the row is already at the
target state. The abort handler must journal ``end``, not
``error``: the DB says ``complete``, the reconnecting client
must see the same.
Without the fix the abort handler treated ``ALREADY_COMPLETE``
as a failure and journaled ``error`` — a reconnect would then
replay a successful completion as a failed answer.
"""
from application.api.answer.routes.base import BaseAnswerResource
from application.storage.db.repositories.conversations import (
MessageUpdateOutcome,
)
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
def gen_answers():
yield {"answer": "partial"}
mock_agent.gen.return_value = gen_answers()
mock_agent.compression_metadata = None
mock_agent.compression_saved = False
mock_agent.tool_calls = []
resource.conversation_service = MagicMock()
# Normal-path call returns UPDATED; abort-handler call sees
# the row already terminal and returns ALREADY_COMPLETE.
resource.conversation_service.finalize_message.side_effect = [
MessageUpdateOutcome.UPDATED,
MessageUpdateOutcome.ALREADY_COMPLETE,
]
resource.conversation_service.save_user_question.return_value = {
"conversation_id": "conv1",
"message_id": "msg1",
"request_id": "req1",
}
journaled: list[tuple] = []
def _capture_record(message_id, sequence_no, event_type, payload):
journaled.append((message_id, sequence_no, event_type, payload))
return True
with patch(
"application.api.answer.routes.base.record_event",
side_effect=_capture_record,
):
gen = resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id="conv1",
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=True,
model_id="gpt-4",
)
next(gen) # message_id frame
next(gen) # answer frame
# Pull the ``id`` frame so the normal-path finalize at
# line 632 actually runs (it sits between the answer
# frame yield and the id frame yield).
next(gen) # id frame — runs normal-path finalize first
gen.close() # GeneratorExit at the id frame yield
terminal_events = [
(et, pl) for (_, _, et, pl) in journaled if et in ("end", "error")
]
assert len(terminal_events) == 1, (
f"expected exactly one terminal journal write, got "
f"{terminal_events}"
)
assert terminal_events[0][0] == "end", (
f"row was already ``complete`` but abort journaled "
f"{terminal_events[0]} — reconnect would surface a "
f"successful answer as failed"
)
# Both finalize_message calls were made: the normal-path
# one and the abort-handler one. Asserting on side_effect
# consumption ensures the test really exercised both
# branches.
assert (
resource.conversation_service.finalize_message.call_count == 2
)
@contextmanager
def _patch_db_session(conn):
@@ -895,34 +606,10 @@ def _patch_db_session(conn):
), patch(
"application.api.answer.services.conversation_service.db_readonly",
_yield,
), patch(
# ``record_event`` opens its own short-lived ``db_session`` for
# cross-connection visibility. In tests we route it back to the
# same ``pg_conn`` so the journal write can see the message row
# the conversation_service just wrote in this transaction.
"application.streaming.message_journal.db_session",
_yield,
), patch(
# ``complete_stream`` reads ``latest_sequence_no`` via
# ``db_readonly`` to seed continuation runs. Same patch reason
# as the journal — keep the read on the same pg_conn so it sees
# uncommitted writes from this transaction.
"application.api.answer.routes.base.db_readonly",
_yield,
):
yield
def _extract_sse_data(chunk: str) -> str:
"""Pull the ``data:`` payload from an SSE record, ignoring any
``id:`` header introduced by the journal wiring.
"""
for line in chunk.split("\n"):
if line.startswith("data:"):
return line[len("data:") :].lstrip()
return ""
@pytest.mark.unit
class TestCompleteStreamWalAcceptance:
"""Acceptance for the WAL pre-persist behaviour: when the LLM raises
@@ -972,158 +659,6 @@ class TestCompleteStreamWalAcceptance:
assert "RuntimeError" in msgs[0]["metadata"]["error"]
assert "LLM upstream failed" in msgs[0]["metadata"]["error"]
def test_tool_approval_event_only_fires_when_state_saved(
self, pg_conn, flask_app,
):
"""A `tool.approval.required` notification with no resumable
``pending_tool_state`` row would deep-link the user to a 404.
Gate the publish on save_state actually committing.
"""
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
# Drive the agent into the paused branch with a single
# ``tool_calls_pending`` event.
mock_agent.gen.return_value = iter(
[
{
"type": "tool_calls_pending",
"data": {"pending_tool_calls": [{"call_id": "c1"}]},
}
]
)
mock_agent._pending_continuation = {
"messages": [],
"tools_dict": {},
"pending_tool_calls": [{"call_id": "c1"}],
}
mock_agent.tool_calls = []
mock_agent.compression_metadata = None
mock_agent.compression_saved = False
published: list = []
def _capture(*args, **kwargs):
published.append(args)
with _patch_db_session(pg_conn), patch(
"application.api.answer.routes.base.publish_user_event",
side_effect=_capture,
), patch(
"application.api.answer.services.continuation_service."
"ContinuationService.save_state",
side_effect=RuntimeError("PG outage"),
):
list(
resource.complete_stream(
question="run my tool",
agent=mock_agent,
conversation_id=None,
user_api_key=None,
decoded_token={"sub": "u-tap"},
should_save_conversation=True,
model_id="gpt-4",
)
)
# No tool.approval.required publish when save_state failed.
event_types = [a[1] for a in published if len(a) >= 2]
assert "tool.approval.required" not in event_types
def test_continuation_seeds_sequence_no_from_journal_high_water_mark(
self, pg_conn, flask_app,
):
"""A resumed (tool-actions) stream must continue numbering past
the original run's max ``sequence_no``. Otherwise the second
invocation hits a duplicate-PK collision and silently drops
every journal write past the resume point.
We pre-seed the journal with synthetic rows simulating an
original run, then invoke ``complete_stream`` with
``_continuation`` set and assert seq numbering picks up past
the high-water mark.
"""
import uuid as _uuid
from sqlalchemy import text as sql_text
from application.api.answer.routes.base import BaseAnswerResource
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
# Pre-seed: insert a parent conversation + message + a few
# journal rows so latest_sequence_no returns 7.
user_id = "u-resume"
conv_id = _uuid.uuid4()
message_id = _uuid.uuid4()
pg_conn.execute(
sql_text("INSERT INTO users (user_id) VALUES (:u)"),
{"u": user_id},
)
pg_conn.execute(
sql_text(
"INSERT INTO conversations (id, user_id, name) "
"VALUES (:id, :u, 'pre-seed')"
),
{"id": conv_id, "u": user_id},
)
pg_conn.execute(
sql_text(
"INSERT INTO conversation_messages "
"(id, conversation_id, user_id, position, prompt) "
"VALUES (:id, :c, :u, 0, 'q')"
),
{"id": message_id, "c": conv_id, "u": user_id},
)
repo = MessageEventsRepository(pg_conn)
for seq in range(8): # rows 0..7
repo.record(str(message_id), seq, "answer", {"chunk": str(seq)})
original_max = repo.latest_sequence_no(str(message_id))
assert original_max == 7
with flask_app.app_context():
resource = BaseAnswerResource()
cont_agent = MagicMock()
cont_agent.gen_continuation.return_value = iter(
[{"answer": "d"}, {"answer": "e"}]
)
cont_agent.tool_calls = []
cont_agent.compression_metadata = None
cont_agent.compression_saved = False
cont_agent.tool_executor = None
with _patch_db_session(pg_conn):
list(
resource.complete_stream(
question="",
agent=cont_agent,
conversation_id=str(conv_id),
user_api_key=None,
decoded_token={"sub": user_id},
should_save_conversation=True,
model_id="gpt-4",
_continuation={
"messages": [],
"tools_dict": {},
"pending_tool_calls": [],
"tool_actions": [],
"reserved_message_id": str(message_id),
"request_id": "req-resume",
},
)
)
new_max = MessageEventsRepository(pg_conn).latest_sequence_no(
str(message_id)
)
# Continuation extended the journal — the new high-water mark
# is strictly greater than the seeded ``original_max=7``,
# confirming the allocator picked up past the resume point.
assert new_max is not None and new_max > original_max
def test_request_id_consistent_across_sse_event_and_wal_row(
self, pg_conn, flask_app,
):
@@ -1159,9 +694,9 @@ class TestCompleteStreamWalAcceptance:
)
sse_events = [
json.loads(_extract_sse_data(s))
json.loads(s.replace("data: ", "").strip())
for s in stream
if "data:" in s
if s.startswith("data: ")
]
early_events = [e for e in sse_events if e.get("type") == "message_id"]
assert len(early_events) == 1
@@ -1177,102 +712,3 @@ class TestCompleteStreamWalAcceptance:
msgs = ConversationsRepository(pg_conn).get_messages(str(convs[0][0]))
assert len(msgs) == 1
assert msgs[0]["request_id"] == sse_request_id
@pytest.mark.unit
class TestStreamingHeartbeatSeed:
"""Regression guard: when the row flips to ``streaming`` we must seed
``last_heartbeat_at`` so the watchdog doesn't fall back to
``timestamp`` (creation time) on slow LLM cold-starts (>idle_secs).
"""
def test_heartbeat_seeded_on_first_chunk(self, mock_mongo_db, flask_app):
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter(
[{"answer": "a"}, {"answer": "b"}]
)
mock_agent.compression_metadata = None
mock_agent.compression_saved = False
mock_agent.tool_calls = []
mock_agent.tool_executor = None
resource.conversation_service = MagicMock()
resource.conversation_service.save_conversation.return_value = "conv1"
resource.conversation_service.save_user_question.return_value = {
"conversation_id": "conv1",
"message_id": "msg1",
"request_id": "req1",
}
list(
resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id="conv1",
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=True,
model_id="gpt-4",
)
)
# update_message_status flips the row to ``streaming`` exactly
# once (idempotent via streaming_marked).
status_calls = [
c for c in resource.conversation_service
.update_message_status.call_args_list
if c.args[1] == "streaming"
]
assert len(status_calls) == 1
assert status_calls[0].args[0] == "msg1"
# heartbeat seed runs exactly once at the same flip — multiple
# chunks don't re-stamp inside this window (the throttled
# _heartbeat_streaming path is gated by STREAM_HEARTBEAT_INTERVAL).
hb_calls = (
resource.conversation_service.heartbeat_message.call_args_list
)
assert len(hb_calls) == 1
assert hb_calls[0].args[0] == "msg1"
def test_heartbeat_seed_skipped_without_reserved_message_id(
self, mock_mongo_db, flask_app,
):
"""No DB-backed message row → no heartbeat call (and no error)."""
from application.api.answer.routes.base import BaseAnswerResource
with flask_app.app_context():
resource = BaseAnswerResource()
mock_agent = MagicMock()
mock_agent.gen.return_value = iter([{"answer": "a"}])
mock_agent.compression_metadata = None
mock_agent.compression_saved = False
mock_agent.tool_calls = []
mock_agent.tool_executor = None
resource.conversation_service = MagicMock()
resource.conversation_service.save_conversation.return_value = "conv1"
# save_user_question returns no message_id → reservation absent
resource.conversation_service.save_user_question.return_value = {
"conversation_id": "conv1",
"message_id": None,
"request_id": "req1",
}
list(
resource.complete_stream(
question="Q",
agent=mock_agent,
conversation_id="conv1",
user_api_key=None,
decoded_token={"sub": "u"},
should_save_conversation=True,
model_id="gpt-4",
)
)
resource.conversation_service.heartbeat_message.assert_not_called()

View File

@@ -1,233 +0,0 @@
"""Integration tests for the end-to-end snapshot+tail handoff.
Exercises the publisher → journal → reconnect endpoint round-trip
without mocking the journal layer, so a regression in any of:
- complete_stream's _emit closure
- record_event's commit-per-call contract
- build_message_event_stream's snapshot-from-DB path
- the reconnect route's auth + ownership gates
- message_events repo SQL
would surface here as a failed integration assertion.
"""
from __future__ import annotations
import uuid as _uuid
from contextlib import contextmanager
from unittest.mock import patch
import pytest
from sqlalchemy import text as sql_text
@contextmanager
def _patch_journal_session(conn):
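"""Route ``record_event``'s journal session and ``event_replay``'s
read-only session onto the test's ``pg_conn`` so writes made in this
transaction are visible without a commit.
"""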
@contextmanager
def _yield():
yield conn
with patch(
"application.streaming.message_journal.db_session", _yield
), patch(
"application.streaming.event_replay.db_readonly", _yield
):
yield
def _seed_message(conn, user_id: str | None = None):
user_id = user_id or f"u-{_uuid.uuid4().hex[:6]}"
conv_id = _uuid.uuid4()
msg_id = _uuid.uuid4()
conn.execute(sql_text("INSERT INTO users (user_id) VALUES (:u)"), {"u": user_id})
conn.execute(
sql_text(
"INSERT INTO conversations (id, user_id, name) VALUES (:id, :u, 't')"
),
{"id": conv_id, "u": user_id},
)
conn.execute(
sql_text(
"INSERT INTO conversation_messages (id, conversation_id, user_id, position) "
"VALUES (:id, :c, :u, 0)"
),
{"id": msg_id, "c": conv_id, "u": user_id},
)
return user_id, str(msg_id)
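# Inserts run in FK order (users, then conversations, then
# conversation_messages) so journal rows recorded against ``msg_id``
# satisfy their foreign key.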
@pytest.mark.integration
class TestSnapshotPlusTailRoundTrip:
def test_record_event_then_snapshot_returns_what_was_journaled(
self, pg_conn,
):
"""End-to-end of the journal half: ``record_event`` writes through
a real ``MessageEventsRepository``, ``build_message_event_stream``
reads the snapshot back via the same repo + a stub Topic.
"""
from application.streaming import event_replay
from application.streaming.message_journal import record_event
_, message_id = _seed_message(pg_conn)
with _patch_journal_session(pg_conn):
# Three events stamp the journal.
record_event(message_id, 0, "answer", {"type": "answer", "answer": "A"})
record_event(message_id, 1, "answer", {"type": "answer", "answer": "B"})
record_event(message_id, 2, "end", {"type": "end"})
# Reconnect path: subscribe yields nothing (Redis-down
# branch); the post-loop fallback runs the snapshot read
# synchronously and yields the journal contents.
def _empty_subscribe(self, on_subscribe=None, poll_timeout=1.0):
return
yield # pragma: no cover
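# The unreachable ``yield`` keeps ``_empty_subscribe`` a generator
# function, so calling it returns an immediately-exhausted iterator
# rather than ``None``.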
with patch.object(
event_replay.Topic,
"subscribe",
_empty_subscribe,
create=False,
):
gen = event_replay.build_message_event_stream(
message_id,
last_event_id=None,
keepalive_seconds=0.05,
poll_timeout_seconds=0.01,
)
out = list(gen)
# Prelude + 3 snapshot frames.
assert out[0] == ": connected\n\n"
assert "id: 0" in out[1] and '"answer": "A"' in out[1]
assert "id: 1" in out[2] and '"answer": "B"' in out[2]
assert "id: 2" in out[3] and '"type": "end"' in out[3]
def test_snapshot_resumes_past_last_event_id(self, pg_conn):
from application.streaming import event_replay
from application.streaming.message_journal import record_event
_, message_id = _seed_message(pg_conn)
with _patch_journal_session(pg_conn):
for seq in range(5):
record_event(
message_id, seq, "answer", {"type": "answer", "answer": str(seq)}
)
def _empty_subscribe(self, on_subscribe=None, poll_timeout=1.0):
return
yield # pragma: no cover
with patch.object(
event_replay.Topic,
"subscribe",
_empty_subscribe,
create=False,
):
# Client says it has seen up through seq=2; expect 3 + 4.
out = list(
event_replay.build_message_event_stream(
message_id,
last_event_id=2,
keepalive_seconds=0.05,
poll_timeout_seconds=0.01,
)
)
ids_seen = [line for line in out if line.startswith("id: ")]
# Multi-line records: extract the id integers we delivered.
emitted = sorted(
int(line.split(": ", 1)[1].split("\n")[0])
for line in out
if line.startswith("id: ")
)
# Filter for non-negative (the snapshot-failure synthetic uses -1).
emitted = [e for e in emitted if e >= 0]
assert emitted == [3, 4]
assert ids_seen # sanity
def test_reconnect_route_round_trip(self, pg_conn, flask_app):
"""``/api/messages/<id>/events`` returns the journaled events
for an authenticated owner.
"""
from flask import Flask, request
from application.api.answer.routes.messages import messages_bp
from application.streaming.message_journal import record_event
# Build a fresh Flask app routing to the reconnect blueprint
# plus a tiny auth shim that injects the test user.
user_id, message_id = _seed_message(pg_conn)
app = Flask(__name__)
app.register_blueprint(messages_bp)
app.config["TESTING"] = True
@app.before_request
def _shim_auth():
request.decoded_token = {"sub": user_id}
with _patch_journal_session(pg_conn):
record_event(message_id, 0, "answer", {"type": "answer", "answer": "x"})
record_event(message_id, 1, "end", {"type": "end"})
from application.streaming import event_replay
def _empty_subscribe(self, on_subscribe=None, poll_timeout=1.0):
return
yield # pragma: no cover
with patch.object(
event_replay.Topic,
"subscribe",
_empty_subscribe,
create=False,
), patch(
"application.api.answer.routes.messages.db_readonly"
) as ro:
ro.return_value.__enter__.return_value = pg_conn
with app.test_client() as c:
r = c.get(f"/api/messages/{message_id}/events")
assert r.status_code == 200
body = b""
for chunk in r.iter_encoded():
body += chunk
if body.count(b"\n\n") >= 4:
break
r.close()
# Both journaled events present in the response.
text = body.decode("utf-8")
assert ": connected" in text
assert '"answer": "x"' in text
assert '"type": "end"' in text
# The seq lines are correct.
assert "id: 0" in text and "id: 1" in text
def test_reconnect_rejects_non_owner(self, pg_conn, flask_app):
from flask import Flask, request
from application.api.answer.routes.messages import messages_bp
user_id, message_id = _seed_message(pg_conn)
app = Flask(__name__)
app.register_blueprint(messages_bp)
@app.before_request
def _shim_auth():
request.decoded_token = {"sub": "different-user"}
# Make the ownership check use the test connection.
with patch(
"application.api.answer.routes.messages.db_readonly"
) as ro:
from contextlib import contextmanager as _cm
@_cm
def _yield():
yield pg_conn
ro.side_effect = lambda: _yield()
with app.test_client() as c:
r = c.get(f"/api/messages/{message_id}/events")
assert r.status_code == 404
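# 404 rather than 403, presumably so a non-owner cannot probe which
# message ids exist.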

View File

@@ -1,649 +0,0 @@
"""Tests for application/api/events/routes.py — the SSE endpoint.
The SSE generator runs in a separate thread under the WSGI test client;
we drive it with mocked Redis (the ``pubsub.get_message`` and ``xrange``
sequences) and read the response body until we have enough records to
assert on, then close the response to terminate the generator.
"""
from __future__ import annotations
import json
import threading
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, request
def _make_app():
"""Mount the events blueprint on a bare Flask app + JWT shim.
The shim mimics ``application/app.py`` populating
``request.decoded_token`` so the SSE handler's auth gate sees a
user-id without requiring the full app stack.
"""
from application.api.events.routes import events
app = Flask(__name__)
app.register_blueprint(events)
app.config["TESTING"] = True
@app.before_request
def _shim_auth(): # noqa: D401
header = request.headers.get("X-Test-Sub")
request.decoded_token = {"sub": header} if header else None
return app
class _FakePubSub:
"""Minimal Redis pub/sub stand-in for the SSE handler.
``messages`` is a list of message dicts the generator should see in
order. After exhausting it, ``get_message`` returns ``None`` (poll
timeout) so the generator stays alive emitting keepalives until the
test closes the response.
"""
def __init__(self, messages: list[dict[str, Any]]):
self._messages = list(messages)
self.subscribed: list[str] = []
self.unsubscribed: list[str] = []
self.closed = False
self._lock = threading.Lock()
def subscribe(self, name: str):
self.subscribed.append(name)
def unsubscribe(self, name: str):
self.unsubscribed.append(name)
def close(self):
self.closed = True
def get_message(self, timeout: float = 0):
with self._lock:
if self._messages:
return self._messages.pop(0)
return None
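# The lock guards ``pop(0)``: the SSE generator consuming these
# messages runs on a separate thread under the WSGI test client.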
def _drain_until(response, predicate, max_chunks: int = 200) -> bytes:
"""Consume the streamed response until ``predicate(buf)`` is true.
Returns the accumulated bytes. Closes the response so the generator
exits cleanly via GeneratorExit.
"""
buf = b""
iterator = response.iter_encoded()
for _ in range(max_chunks):
try:
chunk = next(iterator)
except StopIteration:
break
if not chunk:
continue
buf += chunk
if predicate(buf):
break
response.close()
return buf
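# Typical call: _drain_until(r, lambda b: b": connected" in b).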
# ── auth gate ───────────────────────────────────────────────────────────
class TestAuthGate:
def test_rejects_when_no_decoded_token(self):
app = _make_app()
with app.test_client() as c:
r = c.get("/api/events")
assert r.status_code == 401
def test_rejects_when_decoded_token_missing_sub(self):
from application.api.events import routes as events_module
app = _make_app()
# Clear the shim's behavior — supply a decoded_token without sub.
@app.before_request
def _override():
request.decoded_token = {"email": "x@y.z"}
with patch.object(events_module, "get_redis_instance", return_value=None):
with app.test_client() as c:
r = c.get("/api/events")
assert r.status_code == 401
# ── streaming response shape ────────────────────────────────────────────
class TestStreamShape:
def test_returns_event_stream_mimetype_and_no_buffering_header(self):
from application.api.events import routes as events_module
app = _make_app()
with patch.object(events_module, "get_redis_instance", return_value=None):
with app.test_client() as c:
r = c.get("/api/events", headers={"X-Test-Sub": "alice"})
assert r.status_code == 200
assert r.mimetype == "text/event-stream"
assert r.headers.get("Cache-Control") == "no-store"
assert r.headers.get("X-Accel-Buffering") == "no"
# Drain enough to see the prelude comment then close.
body = _drain_until(r, lambda b: b": connected" in b)
assert b": connected" in body
def test_emits_push_disabled_when_setting_off(self):
from application.api.events import routes as events_module
app = _make_app()
with patch.object(events_module, "get_redis_instance", return_value=None), \
patch.object(events_module.settings, "ENABLE_SSE_PUSH", False):
with app.test_client() as c:
r = c.get("/api/events", headers={"X-Test-Sub": "alice"})
body = _drain_until(r, lambda b: b": push_disabled" in b)
assert b": push_disabled" in body
assert b": connected" in body # prelude still emitted
# ── concurrency cap ─────────────────────────────────────────────────────
class TestConcurrencyCap:
def test_returns_429_when_user_over_cap(self):
from application.api.events import routes as events_module
app = _make_app()
redis_client = MagicMock()
# First INCR returns 9 (over cap of 8).
redis_client.incr.return_value = 9
with patch.object(events_module, "get_redis_instance", return_value=redis_client), \
patch.object(events_module.settings, "SSE_MAX_CONCURRENT_PER_USER", 8):
with app.test_client() as c:
r = c.get("/api/events", headers={"X-Test-Sub": "alice"})
assert r.status_code == 429
# DECR fired to release the over-cap increment.
redis_client.decr.assert_called_once_with("user:alice:sse_count")
def test_skips_cap_when_zero_disabled(self):
from application.api.events import routes as events_module
app = _make_app()
redis_client = MagicMock()
with patch.object(events_module, "get_redis_instance", return_value=redis_client), \
patch.object(events_module.settings, "SSE_MAX_CONCURRENT_PER_USER", 0), \
patch.object(events_module, "Topic") as mock_topic_cls:
mock_topic = MagicMock()
mock_topic.subscribe.return_value = iter([])
mock_topic_cls.return_value = mock_topic
redis_client.xinfo_stream.side_effect = Exception("no stream")
redis_client.xrange.return_value = []
with app.test_client() as c:
r = c.get("/api/events", headers={"X-Test-Sub": "alice"})
assert r.status_code == 200
# Concurrency counter not touched when cap is 0. The
# replay-budget INCR is unrelated and may still fire.
incr_keys = [
call.args[0] for call in redis_client.incr.call_args_list
]
assert "user:alice:sse_count" not in incr_keys
_drain_until(r, lambda b: b": connected" in b)
# ── replay + live tail ──────────────────────────────────────────────────
class TestReplayAndTail:
def test_replay_yields_xrange_entries_with_injected_id(self):
from application.api.events import routes as events_module
app = _make_app()
redis_client = MagicMock()
redis_client.incr.return_value = 1
# Empty stream (no truncation).
redis_client.xinfo_stream.side_effect = Exception("nope")
# XRANGE returns one stored envelope (without ``id``); the route
# injects the entry id on the way out.
stored_event = json.dumps(
{
"type": "source.ingest.progress",
"ts": "2026-04-28T00:00:00.000Z",
"user_id": "alice",
"topic": "user:alice",
"scope": {"kind": "source", "id": "src-1"},
"payload": {"current": 25, "total": 100},
}
).encode()
redis_client.xrange.return_value = [
(b"1735682400000-0", {b"event": stored_event}),
]
# Topic.subscribe yields an immediate timeout so the generator
# keeps running long enough to flush replay; subsequent calls
# also return None.
from application.api.events.routes import _SSE_LINE_SPLIT # noqa: F401
# Fake the broadcast Topic to invoke on_subscribe immediately
# then yield None ticks until close.
def _fake_subscribe(self, on_subscribe=None, poll_timeout=1.0):
if on_subscribe is not None:
on_subscribe()
while True:
yield None
with patch.object(events_module, "get_redis_instance", return_value=redis_client), \
patch.object(
events_module.Topic, "subscribe", _fake_subscribe, create=False
):
with app.test_client() as c:
r = c.get(
"/api/events",
headers={"X-Test-Sub": "alice", "Last-Event-ID": "1735682300000-0"},
)
body = _drain_until(
r,
lambda b: b'"current": 25' in b or b'"current":25' in b,
max_chunks=80,
)
# Replay yields the entry id as the SSE id field.
assert b"id: 1735682400000-0" in body
# Envelope was rewritten to include the injected id.
assert b'"id": "1735682400000-0"' in body or b'"id":"1735682400000-0"' in body
# The connect log fires before replay.
assert b": connected" in body
def test_snapshot_flushed_when_subscribe_dies_after_callback(self):
"""Regression: if ``on_subscribe`` populated ``replay_lines`` but
``Topic.subscribe`` exits before yielding once (transient Redis
hiccup between SUBSCRIBE-ack and the first poll), the snapshot
must still reach the client. Prior to the fix the in-loop flush
was the only flush, so the backlog was silently dropped.
"""
from application.api.events import routes as events_module
app = _make_app()
redis_client = MagicMock()
redis_client.incr.return_value = 1
redis_client.xinfo_stream.side_effect = Exception("nope")
stored_event = json.dumps(
{
"type": "notification",
"payload": {"text": "from snapshot"},
}
).encode()
redis_client.xrange.return_value = [
(b"1735682400000-0", {b"event": stored_event}),
]
# Mimic the broadcast_channel race: SUBSCRIBE acks, on_subscribe
# runs, then the next get_message raises and the generator
# returns without ever yielding.
def _subscribe_dies_after_callback(
self, on_subscribe=None, poll_timeout=1.0
):
if on_subscribe is not None:
on_subscribe()
return
yield # pragma: no cover (make the function a generator)
with patch.object(events_module, "get_redis_instance", return_value=redis_client), \
patch.object(
events_module.Topic,
"subscribe",
_subscribe_dies_after_callback,
create=False,
):
with app.test_client() as c:
r = c.get(
"/api/events",
headers={
"X-Test-Sub": "alice",
"Last-Event-ID": "1735682300000-0",
},
)
body = _drain_until(
r,
lambda b: b"from snapshot" in b,
max_chunks=80,
)
# Snapshot frame must have been flushed via the post-loop
# safety net even though Topic.subscribe exited before
# the in-loop flush could fire.
assert b"id: 1735682400000-0" in body
assert b"from snapshot" in body
# XRANGE was issued exactly once (no double-flush).
redis_client.xrange.assert_called_once()
def test_invalid_last_event_id_emits_truncation_notice(self):
from application.api.events import routes as events_module
app = _make_app()
redis_client = MagicMock()
redis_client.incr.return_value = 1
redis_client.xinfo_stream.return_value = {"first-entry": [b"1-0", []]}
redis_client.xrange.return_value = []
def _fake_subscribe(self, on_subscribe=None, poll_timeout=1.0):
if on_subscribe is not None:
on_subscribe()
while True:
yield None
with patch.object(events_module, "get_redis_instance", return_value=redis_client), \
patch.object(events_module.Topic, "subscribe", _fake_subscribe, create=False):
with app.test_client() as c:
r = c.get(
"/api/events",
headers={"X-Test-Sub": "alice", "Last-Event-ID": "definitely-not-an-id"},
)
body = _drain_until(
r, lambda b: b"backlog.truncated" in b, max_chunks=80
)
assert b"backlog.truncated" in body
def test_live_tail_rejects_malformed_event_id_for_dedupe(self):
"""A pub/sub envelope carrying a non-Redis-Streams ``id`` must not
seed the dedup floor. Otherwise an adversarial or buggy publisher
could ship ``id="9999999999999-9"`` (lex-greater than any real
id) and pin every subsequent legitimate event below the floor,
silently dropping the user's notifications.
The event itself should still be delivered to the client — we
just refuse to use the bogus id for ordering, so it ships
without an SSE ``id:`` header and ``max_replayed_id`` stays put.
"""
from application.api.events import routes as events_module
app = _make_app()
redis_client = MagicMock()
redis_client.incr.return_value = 1
redis_client.xinfo_stream.side_effect = Exception("nope")
# Snapshot covers ids up to 1735682400000-0; max_replayed_id
# becomes that value after the in-loop flush.
replay_event = json.dumps({
"type": "source.ingest.progress",
"payload": {"step": "replay"},
}).encode()
redis_client.xrange.return_value = [
(b"1735682400000-0", {b"event": replay_event}),
]
live_bogus = json.dumps({
"id": "definitely-not-an-id",
"type": "source.ingest.completed",
"payload": {"step": "live-bogus"},
})
live_real = json.dumps({
"id": "1735682500000-0",
"type": "source.ingest.completed",
"payload": {"step": "live-real"},
})
def _fake_subscribe(self, on_subscribe=None, poll_timeout=1.0):
# ``Topic.subscribe`` already unpacks redis-py pubsub dicts
# and yields the raw ``data`` bytes (or ``None`` on poll
# timeout). Mirror that contract.
if on_subscribe is not None:
on_subscribe()
yield live_bogus.encode()
yield live_real.encode()
while True:
yield None
with patch.object(
events_module, "get_redis_instance", return_value=redis_client
), patch.object(
events_module.Topic, "subscribe", _fake_subscribe, create=False
):
with app.test_client() as c:
r = c.get(
"/api/events",
headers={
"X-Test-Sub": "alice",
"Last-Event-ID": "1735682300000-0",
},
)
body = _drain_until(
r, lambda b: b"live-real" in b, max_chunks=80
)
# Live-real arrived (its id is strictly greater than the
# replayed snapshot's id), with its valid id surfaced as
# the SSE ``id:`` header so the frontend can advance.
assert b"live-real" in body
assert b"id: 1735682500000-0" in body
# The bogus-id event was still delivered to the client,
# but no ``id: definitely-not-an-id`` line was emitted —
# the malformed id never reached the SSE wire and so
# could not pin the dedup floor.
assert b"live-bogus" in body
assert b"id: definitely-not-an-id" not in body
# ── format helpers (already covered in test_events_substrate but
# duplicated here as a smoke for the route's surface) ─────────────────
class TestReplayRateLimit:
"""Enumeration defenses on the per-user backlog."""
def test_allow_replay_returns_true_when_budget_disabled(self):
from application.api.events.routes import _allow_replay
with patch("application.api.events.routes.settings") as mock_settings:
mock_settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW = 0
mock_settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS = 60
assert _allow_replay(MagicMock(), "alice", "1735682400000-0") is True
def test_allow_replay_returns_true_when_redis_unavailable(self):
from application.api.events.routes import _allow_replay
with patch("application.api.events.routes.settings") as mock_settings:
mock_settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW = 5
mock_settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS = 60
assert _allow_replay(None, "alice", "1735682400000-0") is True
def test_allow_replay_skips_incr_when_no_cursor_and_empty_backlog(self):
"""Fresh client with no cursor and an empty user stream cannot
do snapshot work — INCR'ing the counter would needlessly
burn budget. Catches the React-StrictMode dev-burst case where
double-mounted components would otherwise trip a 429 within five connects.
"""
from application.api.events.routes import _allow_replay
with patch("application.api.events.routes.settings") as mock_settings:
mock_settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW = 3
mock_settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS = 60
redis = MagicMock()
redis.xlen.return_value = 0
# 5 connects in a row, all with no cursor — none consume
# budget because the backlog is empty.
for _ in range(5):
assert _allow_replay(redis, "alice", None) is True
redis.xlen.assert_called()
redis.incr.assert_not_called()
def test_allow_replay_incrs_when_no_cursor_but_backlog_present(self):
"""A no-cursor connect against a non-empty backlog *will* do
snapshot work, so it consumes budget normally.
"""
from application.api.events.routes import _allow_replay
with patch("application.api.events.routes.settings") as mock_settings:
mock_settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW = 5
mock_settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS = 60
redis = MagicMock()
redis.xlen.return_value = 42
redis.incr.return_value = 1
assert _allow_replay(redis, "alice", None) is True
redis.incr.assert_called_once()
def test_allow_replay_passes_until_budget_exhausted(self):
from application.api.events.routes import _allow_replay
with patch("application.api.events.routes.settings") as mock_settings:
mock_settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW = 3
mock_settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS = 60
redis = MagicMock()
counter = {"v": 0}
def _incr(_key):
counter["v"] += 1
return counter["v"]
redis.incr.side_effect = _incr
# Cursor set → XLEN short-circuit doesn't fire, INCR always runs.
cursor = "1735682400000-0"
# First three pass.
assert _allow_replay(redis, "alice", cursor) is True
assert _allow_replay(redis, "alice", cursor) is True
assert _allow_replay(redis, "alice", cursor) is True
# Fourth refused.
assert _allow_replay(redis, "alice", cursor) is False
# TTL re-seeded on every successful INCR so a transient
# EXPIRE failure on the seeding call can't wedge the key.
assert redis.expire.call_count == 4
for call in redis.expire.call_args_list:
assert call.args[1] == 60
def test_allow_replay_fail_open_on_redis_error(self):
from application.api.events.routes import _allow_replay
with patch("application.api.events.routes.settings") as mock_settings:
mock_settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW = 5
mock_settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS = 60
redis = MagicMock()
redis.incr.side_effect = Exception("redis down")
assert _allow_replay(redis, "alice", "1735682400000-0") is True
def test_allow_replay_recovers_when_seeding_expire_raises(self):
"""Regression: INCR=1 then EXPIRE raising must not wedge the key.
Earlier code only called EXPIRE when ``used == 1``. If that EXPIRE
raised, the counter persisted with no TTL and every subsequent
call hit ``used > 1`` without re-seeding — locking the user out
until an operator DEL'd the key. The fix calls EXPIRE on every
successful INCR so the next call still re-seeds the TTL.
"""
from application.api.events.routes import _allow_replay
with patch("application.api.events.routes.settings") as mock_settings:
mock_settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW = 5
mock_settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS = 60
redis = MagicMock()
counter = {"v": 0}
def _incr(_key):
counter["v"] += 1
return counter["v"]
redis.incr.side_effect = _incr
# First EXPIRE raises (the seeding call that would have
# wedged the key under the old gated logic). Second EXPIRE
# succeeds — and crucially, must still run.
redis.expire.side_effect = [Exception("expire blip"), True]
cursor = "1735682400000-0"
# First call: INCR=1 succeeds, EXPIRE raises -> outer except
# returns True (fail-open) for this call.
assert _allow_replay(redis, "alice", cursor) is True
# Second call: INCR=2, EXPIRE succeeds -> still under budget,
# and the TTL is now seeded (no permanent lockout).
assert _allow_replay(redis, "alice", cursor) is True
assert redis.expire.call_count == 2
# Both EXPIRE calls used the configured window.
for call in redis.expire.call_args_list:
assert call.args[1] == 60
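# A minimal sketch of the limiter shape these tests pin down (an
# assumption about the implementation, not a copy of it):
#
#     used = redis_client.incr(budget_key)
#     redis_client.expire(budget_key, window_seconds)  # every call
#     return used <= budget
#
# with the whole body wrapped in a try/except that returns True, so
# any Redis error fails open.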
def test_replay_backlog_passes_count_to_xrange(self):
from application.api.events.routes import _replay_backlog
redis = MagicMock()
redis.xrange.return_value = []
# Drain the iterator so xrange is actually called.
list(_replay_backlog(redis, "alice", None, 200))
redis.xrange.assert_called_once()
kwargs = redis.xrange.call_args.kwargs
assert kwargs.get("count") == 200
def test_returns_429_when_replay_budget_exhausted(self):
"""Route refuses the connection rather than serving live tail
only. Earlier behavior silently skipped replay and let the
client advance ``lastEventId`` via id-bearing live frames,
permanently stranding the un-replayed window. The 429 keeps
the cursor pinned so the next reconnect (after the budget
window slides) can replay normally.
"""
from application.api.events import routes as events_module
app = _make_app()
redis_client = MagicMock()
def _incr(key):
if key == "user:alice:sse_count":
return 1
# Budget counter: report over-limit.
return 31
redis_client.incr.side_effect = _incr
with patch.object(
events_module, "get_redis_instance", return_value=redis_client
), patch.object(
events_module.settings,
"EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW",
30,
):
with app.test_client() as c:
r = c.get(
"/api/events",
headers={
"X-Test-Sub": "alice",
"Last-Event-ID": "1735682300000-0",
},
)
assert r.status_code == 429
# Concurrency slot is released so a budget-denied request
# doesn't permanently consume a connection from the cap.
redis_client.decr.assert_called_once_with("user:alice:sse_count")
class TestFormatHelpers:
def test_format_sse_two_terminating_newlines(self):
from application.api.events.routes import _format_sse
out = _format_sse("hello", event_id="1-0")
assert out.endswith("\n\n")
# Exactly one ``id:`` and one ``data:``.
lines = out.rstrip("\n").split("\n")
assert lines == ["id: 1-0", "data: hello"]
@pytest.mark.parametrize(
"candidate, expected",
[
("1234", "1234"),
("1234-5", "1234-5"),
(" 1234-0 ", "1234-0"),
(None, None),
("", None),
(" ", None),
("nope", None),
("1234-foo", None),
],
)
def test_normalize_last_event_id(self, candidate, expected):
from application.api.events.routes import _normalize_last_event_id
assert _normalize_last_event_id(candidate) == expected
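# Accepted cursors follow the Redis stream-id shape ``<ms>[-<seq>]``
# after surrounding whitespace is stripped; anything else normalizes
# to None.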

View File

@@ -1,220 +0,0 @@
"""Tests for ``application/api/answer/routes/messages.py``.
Reconnect endpoint: GET /api/messages/<id>/events. Auth gate, ownership
gate, malformed-id rejection, Last-Event-ID normalisation, and a smoke
test that the SSE response shape matches the user-events endpoint.
"""
from __future__ import annotations
from unittest.mock import patch
from flask import Flask, request
from application.api.answer.routes.messages import (
_MESSAGE_ID_RE,
_normalise_last_event_id,
messages_bp,
)
def _make_app(decoded_token=None):
app = Flask(__name__)
app.register_blueprint(messages_bp)
app.config["TESTING"] = True
@app.before_request
def _shim_auth():
request.decoded_token = decoded_token
return app
VALID_UUID = "67d65e8f-e7fb-4df1-9e6e-99ea6c830206"
class TestNormaliseLastEventId:
def test_none_passthrough(self):
assert _normalise_last_event_id(None) is None
def test_empty_string(self):
assert _normalise_last_event_id("") is None
def test_whitespace_only(self):
assert _normalise_last_event_id(" ") is None
def test_valid_int(self):
assert _normalise_last_event_id("42") == 42
def test_stripped_whitespace(self):
assert _normalise_last_event_id(" 7 ") == 7
def test_zero_is_valid(self):
assert _normalise_last_event_id("0") == 0
def test_negative_rejected(self):
# We expose only non-negative cursors; -1 is reserved for the
# snapshot-failure synthetic terminal event and shouldn't
# round-trip back.
assert _normalise_last_event_id("-1") is None
def test_non_numeric_rejected(self):
for bad in ("foo", "1.5", "1e3", "abc-123", "null"):
assert _normalise_last_event_id(bad) is None, bad
class TestMessageIdRegex:
def test_canonical_uuid_accepted(self):
assert _MESSAGE_ID_RE.match(VALID_UUID)
def test_uppercase_uuid_accepted(self):
assert _MESSAGE_ID_RE.match(VALID_UUID.upper())
def test_no_dashes_rejected(self):
assert not _MESSAGE_ID_RE.match(VALID_UUID.replace("-", ""))
def test_legacy_mongo_id_rejected(self):
# 24-char hex with no dashes — a Mongo objectid-shaped string
# that happened to leak through somewhere.
assert not _MESSAGE_ID_RE.match("507f1f77bcf86cd799439011")
class TestAuthGate:
def test_401_when_no_decoded_token(self):
app = _make_app(decoded_token=None)
with app.test_client() as c:
r = c.get(f"/api/messages/{VALID_UUID}/events")
assert r.status_code == 401
def test_401_when_decoded_token_missing_sub(self):
app = _make_app(decoded_token={"email": "x@y"})
with app.test_client() as c:
r = c.get(f"/api/messages/{VALID_UUID}/events")
assert r.status_code == 401
class TestMessageIdValidation:
def test_400_on_malformed_id(self):
app = _make_app(decoded_token={"sub": "alice"})
with app.test_client() as c:
r = c.get("/api/messages/not-a-uuid/events")
assert r.status_code == 400
class TestOwnershipGate:
def test_404_when_user_does_not_own_message(self):
from application.api.answer.routes import messages as messages_module
app = _make_app(decoded_token={"sub": "alice"})
with patch.object(
messages_module, "_user_owns_message", return_value=False
):
with app.test_client() as c:
r = c.get(f"/api/messages/{VALID_UUID}/events")
assert r.status_code == 404
def test_200_when_user_owns_message(self):
from application.api.answer.routes import messages as messages_module
app = _make_app(decoded_token={"sub": "alice"})
# Have build_message_event_stream yield just the prelude then
# exit so the test can drain the response without blocking on
# a live pubsub subscription.
def _fake_builder(message_id, last_event_id=None, **kwargs):
yield ": connected\n\n"
with patch.object(
messages_module, "_user_owns_message", return_value=True
), patch.object(
messages_module, "build_message_event_stream", _fake_builder
):
with app.test_client() as c:
r = c.get(f"/api/messages/{VALID_UUID}/events")
assert r.status_code == 200
assert r.mimetype == "text/event-stream"
assert r.headers.get("Cache-Control") == "no-store"
assert r.headers.get("X-Accel-Buffering") == "no"
body = b""
for chunk in r.iter_encoded():
body += chunk
if b": connected" in body:
break
r.close()
assert b": connected" in body
class TestLastEventIdParsing:
def test_header_passes_through_to_builder(self):
from application.api.answer.routes import messages as messages_module
captured = {}
def _fake_builder(message_id, last_event_id=None, **kwargs):
captured["message_id"] = message_id
captured["last_event_id"] = last_event_id
yield ": connected\n\n"
app = _make_app(decoded_token={"sub": "alice"})
with patch.object(
messages_module, "_user_owns_message", return_value=True
), patch.object(
messages_module, "build_message_event_stream", _fake_builder
):
with app.test_client() as c:
r = c.get(
f"/api/messages/{VALID_UUID}/events",
headers={"Last-Event-ID": "12"},
)
# Drain a tick.
next(iter(r.iter_encoded()), None)
r.close()
assert captured["message_id"] == VALID_UUID
assert captured["last_event_id"] == 12
def test_query_param_fallback(self):
from application.api.answer.routes import messages as messages_module
captured = {}
def _fake_builder(message_id, last_event_id=None, **kwargs):
captured["last_event_id"] = last_event_id
yield ": connected\n\n"
app = _make_app(decoded_token={"sub": "alice"})
with patch.object(
messages_module, "_user_owns_message", return_value=True
), patch.object(
messages_module, "build_message_event_stream", _fake_builder
):
with app.test_client() as c:
r = c.get(
f"/api/messages/{VALID_UUID}/events?last_event_id=5"
)
next(iter(r.iter_encoded()), None)
r.close()
assert captured["last_event_id"] == 5
def test_invalid_cursor_normalised_to_none(self):
from application.api.answer.routes import messages as messages_module
captured = {}
def _fake_builder(message_id, last_event_id=None, **kwargs):
captured["last_event_id"] = last_event_id
yield ": connected\n\n"
app = _make_app(decoded_token={"sub": "alice"})
with patch.object(
messages_module, "_user_owns_message", return_value=True
), patch.object(
messages_module, "build_message_event_stream", _fake_builder
):
with app.test_client() as c:
r = c.get(
f"/api/messages/{VALID_UUID}/events",
headers={"Last-Event-ID": "definitely-not-a-number"},
)
next(iter(r.iter_encoded()), None)
r.close()
assert captured["last_event_id"] is None

View File

@@ -543,84 +543,6 @@ class TestRemoteIdempotency:
assert response.status_code == 400
assert mock_apply.call_count == 0
def test_no_header_returns_source_id_matching_worker_kwarg(
self, app, pg_conn,
):
"""Regression: without an ``Idempotency-Key``, the route must
still return a ``source_id`` AND pass that same id to the worker
as ``source_id`` so SSE envelopes line up with what the
frontend already has. Previously the route omitted ``source_id``
entirely on the no-key path and the worker minted its own
random uuid, breaking push correlation for the default upload
flow.
"""
from application.api.user.sources.upload import UploadRemote
apply_mock = _apply_async_mock()
with _patch_db(pg_conn), patch(
"application.api.user.sources.upload.ingest_remote.apply_async",
apply_mock,
), app.test_request_context(
"/api/remote", method="POST",
data={
"user": "u", "source": "github", "name": "g",
"data": json.dumps({"repo_url": "https://github.com/x/y"}),
},
content_type="multipart/form-data",
):
from flask import request
request.decoded_token = {"sub": "u"}
response = UploadRemote().post()
assert response.status_code == 200
assert "source_id" in response.json
assert (
apply_mock.call_args.kwargs["kwargs"]["source_id"]
== response.json["source_id"]
)
def test_no_header_connector_returns_source_id_matching_worker_kwarg(
self, app, pg_conn,
):
"""Same regression as above for the connector branch
(``ingest_connector_task``). The connector path hit the same
no-key gap independently of the plain remote path."""
from application.api.user.sources.upload import UploadRemote
apply_mock = _apply_async_mock()
# Pick any registered connector — the route only branches on
# ``ConnectorCreator.get_supported_connectors()``.
from application.parser.connectors.connector_creator import (
ConnectorCreator,
)
supported = ConnectorCreator.get_supported_connectors()
if not supported:
pytest.skip("no connectors registered in this build")
connector_source = next(iter(supported))
with _patch_db(pg_conn), patch(
"application.api.user.sources.upload.ingest_connector_task.apply_async",
apply_mock,
), app.test_request_context(
"/api/remote", method="POST",
data={
"user": "u", "source": connector_source, "name": "g",
"data": json.dumps({
"session_token": "tok",
"file_ids": ["f1"],
}),
},
content_type="multipart/form-data",
):
from flask import request
request.decoded_token = {"sub": "u"}
response = UploadRemote().post()
assert response.status_code == 200
assert "source_id" in response.json
assert (
apply_mock.call_args.kwargs["kwargs"]["source_id"]
== response.json["source_id"]
)
def _seed_source(pg_conn, user="u", **kw):
from application.storage.db.repositories.sources import SourcesRepository
@@ -751,140 +673,10 @@ class TestManageSourceFilesIdempotency:
assert apply_mock.call_count == 1
# Loser's response carries the winner's task_id, not the
# original 200-with-added_files payload.
# ``manage_source_files`` aliases ``task_id`` ->
# ``reingest_task_id`` in the cached payload so the dedup
# response shape matches the fresh-request response (FileTree
# keys reingest correlation on ``reingest_task_id`` /
# ``source_id``).
assert second.json["reingest_task_id"] == first.json["reingest_task_id"]
# Cached ``source_id`` must equal the real source row id (not
# the helper's uuid5-of-key) so FileTree's SSE correlation on
# ``event.scope.id === result.source_id`` keeps working.
assert second.json["source_id"] == first.json["source_id"]
assert second.json["source_id"] == str(src["id"])
assert second.json["task_id"] == first.json["reingest_task_id"]
# Confirm the loser never invoked the file-save path.
assert fake_storage.save_file.call_count == 1
def test_remove_same_key_second_post_returns_real_source_id(
self, app, pg_conn
):
"""Regression: the ``remove`` cached branch used to leave the
helper's synthetic ``source_id`` (uuid5 of the scoped key) in
place. The reingest worker publishes SSE events tagged with the
real source row id, so the cached response had to be patched to
match what the fresh response returns — otherwise FileTree's
SSE correlation silently fails on every idempotent retry and
the user never sees the directory refresh.
"""
from application.api.user.sources.upload import ManageSourceFiles
user = "alice-mgr-rmrep"
src = _seed_source(
pg_conn,
user=user,
file_path="/data",
file_name_map={"a.txt": "a.txt"},
)
fake_storage = MagicMock()
fake_storage.file_exists.return_value = True
apply_mock = _apply_async_mock()
def _do_remove():
return app.test_request_context(
"/api/manage_source_files",
method="POST",
data={
"source_id": str(src["id"]),
"operation": "remove",
"file_paths": json.dumps(["a.txt"]),
},
content_type="multipart/form-data",
headers={"Idempotency-Key": "mgr-rmrep"},
)
with _patch_db(pg_conn), patch(
"application.api.user.sources.upload.StorageCreator.get_storage",
return_value=fake_storage,
), patch(
"application.api.user.tasks.reingest_source_task.apply_async",
apply_mock,
):
with _do_remove():
from flask import request
request.decoded_token = {"sub": user}
first = ManageSourceFiles().post()
with _do_remove():
from flask import request
request.decoded_token = {"sub": user}
second = ManageSourceFiles().post()
assert first.status_code == 200
assert second.status_code == 200
assert apply_mock.call_count == 1
assert second.json["reingest_task_id"] == first.json["reingest_task_id"]
# The contract under test: cached source_id matches the fresh
# response (the real source row id), not the helper's uuid5.
assert second.json["source_id"] == first.json["source_id"]
assert second.json["source_id"] == str(src["id"])
def test_remove_directory_same_key_second_post_returns_real_source_id(
self, app, pg_conn
):
"""Same regression as the ``remove`` test, for the
``remove_directory`` branch.
"""
from application.api.user.sources.upload import ManageSourceFiles
user = "alice-mgr-rmdir-rep"
src = _seed_source(
pg_conn,
user=user,
file_path="/data",
file_name_map={"sub/a.txt": "a.txt"},
)
fake_storage = MagicMock()
fake_storage.is_directory.return_value = True
fake_storage.remove_directory.return_value = True
apply_mock = _apply_async_mock()
def _do_remove_dir():
return app.test_request_context(
"/api/manage_source_files",
method="POST",
data={
"source_id": str(src["id"]),
"operation": "remove_directory",
"directory_path": "sub",
},
content_type="multipart/form-data",
headers={"Idempotency-Key": "mgr-rmdir-rep"},
)
with _patch_db(pg_conn), patch(
"application.api.user.sources.upload.StorageCreator.get_storage",
return_value=fake_storage,
), patch(
"application.api.user.tasks.reingest_source_task.apply_async",
apply_mock,
):
with _do_remove_dir():
from flask import request
request.decoded_token = {"sub": user}
first = ManageSourceFiles().post()
with _do_remove_dir():
from flask import request
request.decoded_token = {"sub": user}
second = ManageSourceFiles().post()
assert first.status_code == 200
assert second.status_code == 200
assert apply_mock.call_count == 1
assert second.json["reingest_task_id"] == first.json["reingest_task_id"]
assert second.json["source_id"] == first.json["source_id"]
assert second.json["source_id"] == str(src["id"])
def test_concurrent_same_key_only_one_apply_async(self, app, pg_engine):
"""N parallel same-key POSTs → exactly one apply_async."""
from concurrent.futures import ThreadPoolExecutor

View File

@@ -506,60 +506,6 @@ class TestGetMessageTail:
assert response.status_code == 404
def test_streaming_row_returns_partial_from_journal(self, app, pg_conn):
"""Mid-stream rows must rebuild from message_events, not return the placeholder."""
from application.api.user.conversations.routes import GetMessageTail
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
owner = "user-tail-partial"
_, msg_id = self._seed_in_flight_message(pg_conn, owner)
events_repo = MessageEventsRepository(pg_conn)
events_repo.record(msg_id, 0, "message_id", {"type": "message_id"})
events_repo.record(msg_id, 1, "answer", {"type": "answer", "answer": "Hello"})
events_repo.record(msg_id, 2, "answer", {"type": "answer", "answer": ", world"})
events_repo.record(
msg_id, 3, "source", {"type": "source", "source": [{"id": "s1"}]}
)
with _patch_conversations_db(pg_conn), app.test_request_context(
f"/api/messages/{msg_id}/tail"
):
from flask import request
request.decoded_token = {"sub": owner}
response = GetMessageTail().get(msg_id)
assert response.status_code == 200
assert response.json["status"] == "streaming"
assert response.json["response"] == "Hello, world"
assert response.json["sources"] == [{"id": "s1"}]
assert "terminated prior to completion" not in (
response.json["response"] or ""
)
def test_streaming_row_with_empty_journal_returns_empty_response(
self, app, pg_conn
):
"""Empty journal returns empty response, not the placeholder."""
from application.api.user.conversations.routes import GetMessageTail
owner = "user-tail-empty"
_, msg_id = self._seed_in_flight_message(pg_conn, owner)
with _patch_conversations_db(pg_conn), app.test_request_context(
f"/api/messages/{msg_id}/tail"
):
from flask import request
request.decoded_token = {"sub": owner}
response = GetMessageTail().get(msg_id)
assert response.status_code == 200
assert response.json["status"] == "streaming"
assert response.json["response"] == ""
class TestUpdateConversationNameHappy:
def test_returns_401_unauthenticated(self, app):

View File

@@ -33,7 +33,7 @@ class TestIngestTask:
mock_worker.assert_called_once_with(
ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
file_name_map=None, idempotency_key=None, source_id=None,
file_name_map=None, idempotency_key=None,
)
assert result == {"status": "ok"}
@@ -50,7 +50,7 @@ class TestIngestTask:
mock_worker.assert_called_once_with(
ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
file_name_map=name_map, idempotency_key=None, source_id=None,
file_name_map=name_map, idempotency_key=None,
)
@@ -66,7 +66,7 @@ class TestIngestRemoteTask:
mock_worker.assert_called_once_with(
ANY, {"url": "http://x"}, "job1", "user1", "web",
idempotency_key=None, source_id=None,
idempotency_key=None,
)
assert result == {"status": "ok"}
@@ -169,7 +169,6 @@ class TestIngestConnectorTask:
doc_id=None,
sync_frequency="never",
idempotency_key=None,
source_id=None,
)
assert result == {"status": "ok"}
@@ -208,7 +207,6 @@ class TestIngestConnectorTask:
doc_id="doc1",
sync_frequency="daily",
idempotency_key=None,
source_id=None,
)
assert result == {"status": "ok"}
@@ -222,7 +220,7 @@ class TestSetupPeriodicTasks:
setup_periodic_tasks(sender)
assert sender.add_periodic_task.call_count == 8
assert sender.add_periodic_task.call_count == 7
calls = sender.add_periodic_task.call_args_list
@@ -243,9 +241,6 @@ class TestSetupPeriodicTasks:
assert calls[5][1].get("name") == "reconciliation"
# version-check (every 7h)
assert calls[6][0][0] == timedelta(hours=7)
# message_events retention sweep (24h)
assert calls[7][0][0] == timedelta(hours=24)
assert calls[7][1].get("name") == "cleanup-message-events"
class TestMcpOauthTask:
@@ -262,6 +257,20 @@ class TestMcpOauthTask:
assert result == {"url": "http://auth"}
class TestMcpOauthStatusTask:
@pytest.mark.unit
@patch("application.api.user.tasks.mcp_oauth_status")
def test_calls_mcp_oauth_status(self, mock_worker):
from application.api.user.tasks import mcp_oauth_status_task
mock_worker.return_value = {"status": "authorized"}
result = mcp_oauth_status_task("task123")
mock_worker.assert_called_once_with(ANY, "task123")
assert result == {"status": "authorized"}
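# A plausible shape for the task under test, inferred from the mock
# expectations above (an assumption, not the real body): a thin bound
# Celery wrapper that forwards (self, task_id) to the mcp_oauth_status
# worker, which is what assert_called_once_with(ANY, "task123") checks.
# The decorator and app object name are assumed.
@celery.task(bind=True)
def mcp_oauth_status_task(self, task_id):
    return mcp_oauth_status(self, task_id)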
class TestDurableTaskRetryPolicy:
"""The long-running tasks share a uniform retry policy."""
@@ -293,6 +302,7 @@ class TestDurableTaskRetryPolicy:
"schedule_syncs",
"sync_source",
"mcp_oauth_task",
"mcp_oauth_status_task",
"cleanup_pending_tool_state",
"reconciliation_task",
"version_check_task",
@@ -428,93 +438,6 @@ class TestCleanupPendingToolState:
}
class TestCleanupMessageEventsTask:
"""Retention janitor delegates to MessageEventsRepository.cleanup_older_than."""
@pytest.mark.unit
def test_skips_when_postgres_uri_missing(self, monkeypatch):
from application.api.user.tasks import cleanup_message_events
from application.core.settings import settings
monkeypatch.setattr(settings, "POSTGRES_URI", None, raising=False)
result = cleanup_message_events.run()
assert result == {"deleted": 0, "skipped": "POSTGRES_URI not set"}
@pytest.mark.unit
def test_deletes_rows_past_retention_window(self, pg_conn, monkeypatch):
import uuid
from sqlalchemy import text as _text
from application.api.user.tasks import cleanup_message_events
from application.core.settings import settings
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
# Seed parent rows so the FK on message_events holds.
user_id = f"user-{uuid.uuid4().hex[:8]}"
conv_id = uuid.uuid4()
msg_id = uuid.uuid4()
pg_conn.execute(
_text("INSERT INTO users (user_id) VALUES (:u)"),
{"u": user_id},
)
pg_conn.execute(
_text(
"INSERT INTO conversations (id, user_id, name) "
"VALUES (:id, :u, 'test')"
),
{"id": conv_id, "u": user_id},
)
pg_conn.execute(
_text(
"INSERT INTO conversation_messages (id, conversation_id, "
"user_id, position) VALUES (:id, :c, :u, 0)"
),
{"id": msg_id, "c": conv_id, "u": user_id},
)
repo = MessageEventsRepository(pg_conn)
repo.record(str(msg_id), 0, "answer", {"chunk": "stale"})
repo.record(str(msg_id), 1, "answer", {"chunk": "fresh"})
# Backdate seq=0 past the default 14-day retention so the
# janitor catches it; seq=1 stays at "now" and must survive.
pg_conn.execute(
_text(
"UPDATE message_events SET created_at = now() - interval '20 days' "
"WHERE message_id = CAST(:id AS uuid) AND sequence_no = 0"
),
{"id": str(msg_id)},
)
monkeypatch.setattr(
settings, "POSTGRES_URI", "postgresql://stub", raising=False
)
@contextmanager
def _fake_begin():
yield pg_conn
fake_engine = MagicMock()
fake_engine.begin = _fake_begin
with patch(
"application.storage.db.engine.get_engine",
return_value=fake_engine,
):
result = cleanup_message_events.run()
assert result == {
"deleted": 1,
"ttl_days": settings.MESSAGE_EVENTS_RETENTION_DAYS,
}
# Only the fresh row survives.
rows = repo.read_after(str(msg_id))
assert [r["sequence_no"] for r in rows] == [1]
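# A minimal sketch of the janitor these (now-removed) tests exercised,
# assuming the repository method named in the class docstring and the
# settings knobs the assertions read; the deleted production body may
# have differed.
def cleanup_message_events_sketch():
    if not settings.POSTGRES_URI:
        return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
    ttl_days = settings.MESSAGE_EVENTS_RETENTION_DAYS
    # get_engine().begin() yields a transactional connection, matching
    # the _fake_begin context manager patched in above.
    with get_engine().begin() as conn:
        deleted = MessageEventsRepository(conn).cleanup_older_than(ttl_days)
    return {"deleted": deleted, "ttl_days": ttl_days}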
class TestIngestIdempotency:
"""Same short-circuit applies to the ingest task path."""
@@ -526,7 +449,7 @@ class TestIngestIdempotency:
def _fake_worker(self, directory, formats, job_name, file_path,
filename, user, file_name_map=None,
idempotency_key=None, source_id=None):
idempotency_key=None):
worker_calls.append(filename)
return {"status": "ok", "directory": directory}

View File

@@ -363,6 +363,47 @@ class TestMCPServerSave:
assert response.status_code in (200, 201)
class TestMCPOAuthStatus:
def test_returns_pending_when_no_data(self, app):
from application.api.user.tools.mcp import MCPOAuthStatus
fake_redis = MagicMock()
fake_redis.get.return_value = None
with patch(
"application.api.user.tools.mcp.get_redis_instance",
return_value=fake_redis,
), app.test_request_context("/api/mcp_oauth_status/t1"):
response = MCPOAuthStatus().get("t1")
assert response.status_code == 200
assert response.json["status"] == "pending"
def test_returns_status_from_redis(self, app):
from application.api.user.tools.mcp import MCPOAuthStatus
fake_redis = MagicMock()
fake_redis.get.return_value = '{"status": "completed", "tools": [{"name": "t1", "description": "d"}]}'
with patch(
"application.api.user.tools.mcp.get_redis_instance",
return_value=fake_redis,
), app.test_request_context("/api/mcp_oauth_status/t1"):
response = MCPOAuthStatus().get("t1")
assert response.status_code == 200
assert response.json["status"] == "completed"
assert response.json["tools"][0]["name"] == "t1"
def test_redis_error_returns_500(self, app):
from application.api.user.tools.mcp import MCPOAuthStatus
with patch(
"application.api.user.tools.mcp.get_redis_instance",
side_effect=RuntimeError("boom"),
), app.test_request_context("/api/mcp_oauth_status/t1"):
response = MCPOAuthStatus().get("t1")
assert response.status_code == 500
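# A sketch of the handler contract these three tests pin down (assumed
# method shape and Redis key naming, not the actual route code).
# Imports assumed: json, plus flask's jsonify and make_response.
def get(self, task_id):
    try:
        redis_client = get_redis_instance()
        raw = redis_client.get(f"mcp_oauth_status:{task_id}")
    except Exception:
        # Any Redis failure surfaces as a 500, per test_redis_error_returns_500.
        return make_response(jsonify({"success": False}), 500)
    if raw is None:
        # No status written yet: report "pending" rather than erroring.
        return make_response(jsonify({"status": "pending"}), 200)
    return make_response(jsonify(json.loads(raw)), 200)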
class TestMCPOAuthCallback:
def test_error_param_redirects_error(self, app):
from application.api.user.tools.mcp import MCPOAuthCallback

View File

@@ -513,73 +513,4 @@ class TestChatCompletionsHappyPath:
assert resp.status_code == 200
assert resp.mimetype == "text/event-stream"
def test_stream_handles_id_prefixed_chunks(self, pg_conn):
"""``complete_stream`` emits ``id: <seq>\\n`` before each
``data:`` line. The v1 streaming consumer must skip the id
header and the informational ``message_id`` event, not silently
drop every chunk.
"""
app = _build_app()
def _fake_translate(data, api_key):
return {"question": "hi"}
fake_processor = MagicMock()
fake_processor.decoded_token = {"sub": "u"}
fake_processor.conversation_id = "conv-1"
fake_processor.agent_config = {}
fake_processor.agent_id = None
fake_processor.model_id = "m"
def _fake_helper_complete_stream(**kw):
# Mirror the new wire format: id-prefixed records, plus
# the informational message_id event the v1 path doesn't
# have an analog for.
yield 'id: 0\ndata: {"type": "message_id", "message_id": "m-1"}\n\n'
yield 'id: 1\ndata: {"type": "id", "id": "conv-1"}\n\n'
yield 'id: 2\ndata: {"type": "answer", "answer": "hi"}\n\n'
translated_chunks: list = []
def _fake_translate_stream_event(event_data, completion_id, model_name):
translated_chunks.append(event_data)
return ['data: x\n\n']
fake_helper = MagicMock()
fake_helper.check_usage.return_value = None
fake_helper.complete_stream.side_effect = _fake_helper_complete_stream
with _patch_v1_db(pg_conn), patch(
"application.api.v1.routes.translate_request",
side_effect=_fake_translate,
), patch(
"application.api.v1.routes.StreamProcessor",
return_value=fake_processor,
), patch(
"application.api.v1.routes._V1AnswerHelper",
return_value=fake_helper,
), patch(
"application.api.v1.routes.translate_stream_event",
side_effect=_fake_translate_stream_event,
):
with app.test_client() as c:
resp = c.post(
"/v1/chat/completions",
headers={"Authorization": "Bearer x"},
json={
"messages": [{"role": "user", "content": "Hi"}],
"stream": True,
},
)
# Drain the response so the generator runs to completion.
list(resp.iter_encoded())
assert resp.status_code == 200
# message_id event is skipped (no v1 analog); id + answer are
# decoded and forwarded to the translator.
types_translated = [c.get("type") for c in translated_chunks]
assert "message_id" not in types_translated
assert "id" in types_translated
assert "answer" in types_translated

View File

@@ -15,9 +15,9 @@ their own conftest to point at a real, long-running Postgres instance
tests and are marked with ``@pytest.mark.integration``.
No mongomock. The ``mock_mongo_db`` fixture that used to live here was
removed as part of the Mongo→Postgres cutover. Tests that still
reference it will fail with "fixture not found" until the corresponding
route handler is migrated to a repository read.
removed as part of the Phase 4/5 Mongo→Postgres cutover. Tests that
still reference it will fail with "fixture not found" until the
corresponding route handler is migrated to a repository read.
"""
from __future__ import annotations

View File

@@ -1,4 +1,4 @@
"""Regression tests for the YAML-driven ModelRegistry.
"""Phase 1 regression tests for the YAML-driven ModelRegistry.
These tests encode the contract that persisted agent / workflow /
conversation references depend on: every model id and core capability

View File

@@ -1,4 +1,4 @@
"""Tests for the operator MODELS_CONFIG_DIR.
"""Phase 3 tests: operator MODELS_CONFIG_DIR.
Covers the operator-supplied directory of model YAMLs that's loaded
after the built-in catalog. Operators use this to add new

View File

@@ -1,4 +1,4 @@
"""Tests for the openai_compatible provider.
"""Phase 2 tests for the openai_compatible provider.
Covers YAML loading from a temp directory, multiple coexisting catalogs
(Mistral + Together), env-var-based credential resolution, the legacy

View File

@@ -1,6 +1,6 @@
/**
* Shared agent provisioning for specs that need a published agent
* (with a real api_key) for subsequent /stream or /search
* Phase 2 helper — shared agent provisioning for specs that need a
* published agent (with a real api_key) for subsequent /stream or /search
* calls. A PUBLISHED classic agent requires name, description, chunks,
* retriever, prompt_id AND a source — otherwise `/api/create_agent`
* returns 400.

View File

@@ -1,4 +1,5 @@
/**
* Phase 1 helper — see e2e-plan.md §P1-B.
* Pre-authenticated Playwright APIRequestContext pointed at the e2e Flask.
*/

View File

@@ -1,4 +1,5 @@
/**
* Phase 1 helper — see e2e-plan.md §P1-B.
* JWT signing + per-test BrowserContext seeding for DocsGPT e2e.
*/

View File

@@ -1,4 +1,5 @@
/**
* Phase 1 helper — see e2e-plan.md §P1-B.
* Thin pg wrapper + typed row helpers for DB assertions in specs.
*/

View File

@@ -1,4 +1,5 @@
/**
* Phase 1 helper — see e2e-plan.md §P1-B.
* Per-test TRUNCATE — preserves `alembic_version`, wipes every other table.
*/

View File

@@ -1,5 +1,5 @@
/**
* Shared streaming primitives for SSE specs.
* Phase 2 helper — shared streaming primitives.
*
* The backend exposes two streaming endpoints:
* - POST /stream (answer_ns path="/" + route "/stream") — SSE body

Some files were not shown because too many files have changed in this diff.