Mirror of https://github.com/arc53/DocsGPT.git (synced 2026-05-14 08:03:19 +00:00)

Compare commits: convsearch...feat-notif (28 commits)
Commits:
827a0bb382, b04cb44ab5, 42384a0e92, 0bce35ad29, 9de8bb4499, cdbd3f061d, 2ac46fd858, daa4320da2, e70a7a5115, 150d9f4e37, 746bcbc5f9, aa91117fbf, abbd56cb66, 85d8375e6c, 7e98d21b61, 249f9f9fe0, 6c4346eb84, cb3ca8a36b, 4c8230fb6c, 649557798d, afe8354ca5, 5483eb0e27, bd2985db47, b99147ba83, c3023f8b71, c168a530f5, 2d539f3199, ed9444cf3d
@@ -20,10 +20,11 @@ from pydantic import AnyHttpUrl, ValidationError
from redis import Redis

from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.api.user.tasks import mcp_oauth_task
from application.cache import get_redis_instance
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.events.keys import stream_key
from application.security.encryption import decrypt_credentials

logger = logging.getLogger(__name__)
@@ -76,6 +77,12 @@ class MCPTool(Tool):
self.oauth_task_id = config.get("oauth_task_id", None)
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))
# Pulled out of ``config`` (rather than left in ``self.config``)
# because it is a callable supplied by the OAuth worker — not
# something the rest of the tool plumbing should marshal or
# serialize. ``DocsGPTOAuth`` invokes it from ``redirect_handler``
# so the SSE envelope can carry ``authorization_url``.
self.oauth_redirect_publish = config.pop("oauth_redirect_publish", None)

self.available_tools = []
self._cache_key = self._generate_cache_key()
@@ -167,6 +174,7 @@ class MCPTool(Tool):
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
user_id=self.user_id,
redirect_publish=self.oauth_redirect_publish,
)
elif self.auth_type == "bearer":
token = self.auth_credentials.get(
@@ -679,12 +687,17 @@ class DocsGPTOAuth(OAuthClientProvider):
user_id=None,
additional_client_metadata: dict[str, Any] | None = None,
skip_redirect_validation: bool = False,
redirect_publish=None,
):
self.redirect_uri = redirect_uri
self.redis_client = redis_client
self.redis_prefix = redis_prefix
self.task_id = task_id
self.user_id = user_id
# Worker-supplied callback. Invoked from ``redirect_handler``
# once the authorization URL is known so the SSE envelope can
# carry it. ``None`` for any non-worker entrypoint.
self.redirect_publish = redirect_publish

parsed_url = urlparse(mcp_url)
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
@@ -744,17 +757,19 @@ class DocsGPTOAuth(OAuthClientProvider):
self.redis_client.setex(key, 600, auth_url)
logger.info("Stored auth_url in Redis: %s", key)

if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "requires_redirect",
"message": "Authorization required",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
if self.redirect_publish is not None:
# Best-effort: a publish failure must not abort the OAuth
# handshake — the user can still authorize via the popup
# opened from the legacy polling fallback if the SSE
# envelope is lost.
try:
self.redirect_publish(auth_url)
except Exception:
logger.warning(
"redirect_publish callback raised for task_id=%s",
self.task_id,
exc_info=True,
)

async def callback_handler(self) -> tuple[str, str | None]:
"""Wait for auth code from Redis using the state value."""
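For orientation, a minimal sketch of the worker-supplied ``redirect_publish`` callback this handler invokes. The ``publish_user_event`` call shape and the ``mcp_oauth`` scope follow code elsewhere in this diff; the factory function and the exact event-type suffix are assumptions.

from application.events.publisher import publish_user_event


def make_redirect_publish(user_id: str, task_id: str):
    """Hypothetical factory for the callable MCPTool receives as ``oauth_redirect_publish``."""

    def _publish(auth_url: str) -> None:
        # One envelope per authorization URL; the caller already wraps this
        # in try/except, so a failure only costs the SSE shortcut, not the flow.
        publish_user_event(
            user_id,
            "mcp.oauth.requires_redirect",  # assumed suffix, mirrors the Redis status above
            {"authorization_url": auth_url, "task_id": task_id},
            scope={"kind": "mcp_oauth", "id": task_id},
        )

    return _publish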
@@ -764,17 +779,6 @@ class DocsGPTOAuth(OAuthClientProvider):
max_wait_time = 300
code_key = f"{self.redis_prefix}code:{self.extracted_state}"

if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "awaiting_callback",
"message": "Waiting for authorization...",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
start_time = time.time()
while time.time() - start_time < max_wait_time:
code_data = self.redis_client.get(code_key)
@@ -789,14 +793,6 @@ class DocsGPTOAuth(OAuthClientProvider):
self.redis_client.delete(
f"{self.redis_prefix}state:{self.extracted_state}"
)

if self.task_id:
status_data = {
"status": "callback_received",
"message": "Completing authentication...",
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
return code, returned_state
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
error_data = self.redis_client.get(error_key)
@@ -1038,8 +1034,73 @@ class MCPOAuthManager:
logger.error("Error handling OAuth callback: %s", e)
return False

def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
"""Get current status of OAuth flow using provided task_id."""
def get_oauth_status(self, task_id: str, user_id: str) -> Dict[str, Any]:
"""Return the latest OAuth status for ``task_id`` from the user's SSE journal.

Mirrors the legacy polling contract: ``status`` derived from the
``mcp.oauth.*`` event-type suffix, with payload fields surfaced
(e.g. ``tools``/``tools_count`` on ``completed``).
"""
if not task_id:
return {"status": "not_started", "message": "OAuth flow not started"}
return mcp_oauth_status_task(task_id)
if not user_id:
return {"status": "not_found", "message": "User not provided"}
if self.redis_client is None:
return {"status": "not_found", "message": "Redis unavailable"}

try:
# OAuth flows are short-lived but a concurrent source
# ingest can flood the user channel between the OAuth
# popup completing and the user clicking Save, pushing the
# completion envelope outside the read window. Bound the
# scan by the configured stream cap so we cover the full
# journal — XADD MAXLEN keeps that bounded too.
scan_count = max(settings.EVENTS_STREAM_MAXLEN, 200)
entries = self.redis_client.xrevrange(
stream_key(user_id), count=scan_count
)
except Exception:
logger.exception(
"xrevrange failed for oauth status: user_id=%s task_id=%s",
user_id,
task_id,
)
return {"status": "not_found", "message": "Status unavailable"}

for _entry_id, fields in entries:
if not isinstance(fields, dict):
continue
# decode_responses=False ⇒ bytes keys; the string-key fallback
# covers a future flip of that default without a forced refactor.
event_raw = fields.get(b"event")
if event_raw is None:
event_raw = fields.get("event")
if event_raw is None:
continue
if isinstance(event_raw, bytes):
try:
event_raw = event_raw.decode("utf-8")
except Exception:
continue
try:
envelope = json.loads(event_raw)
except Exception:
continue
if not isinstance(envelope, dict):
continue
event_type = envelope.get("type", "")
if not isinstance(event_type, str) or not event_type.startswith(
"mcp.oauth."
):
continue
scope = envelope.get("scope") or {}
if scope.get("kind") != "mcp_oauth" or scope.get("id") != task_id:
continue
payload = envelope.get("payload") or {}
return {
"status": event_type[len("mcp.oauth."):],
"task_id": task_id,
**payload,
}

return {"status": "not_found", "message": "Status not found"}
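To make the journal contract above concrete, a small self-contained example of how an ``mcp.oauth.*`` envelope maps back to the legacy polling shape. The envelope literal is illustrative; only the ``type``/``scope``/``payload`` keys and the ``tools_count`` field are taken from the surrounding code.

import json

envelope = {
    "type": "mcp.oauth.completed",
    "scope": {"kind": "mcp_oauth", "id": "task-123"},
    "payload": {"tools_count": 4},
}

# Same derivation get_oauth_status performs on the newest matching entry.
status = {
    "status": envelope["type"][len("mcp.oauth."):],  # -> "completed"
    "task_id": envelope["scope"]["id"],
    **(envelope.get("payload") or {}),
}
print(json.dumps(status))
# {"status": "completed", "task_id": "task-123", "tools_count": 4}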
@@ -1,4 +1,4 @@
"""0001 initial schema — consolidated Phase-1..3 baseline.
"""0001 initial schema — consolidated baseline for user-data tables.

Revision ID: 0001_initial
Revises:

application/alembic/versions/0007_message_events.py (new file, 40 lines)
@@ -0,0 +1,40 @@
"""0007 message_events — durable journal of chat-stream events.

Snapshot half of the chat-stream snapshot+tail pattern. Composite PK
``(message_id, sequence_no)``, ``created_at`` indexed for retention
sweeps, ``ON DELETE CASCADE`` from ``conversation_messages``.

Revision ID: 0007_message_events
Revises: 0006_idempotency_lease
"""

from typing import Sequence, Union

from alembic import op


revision: str = "0007_message_events"
down_revision: Union[str, None] = "0006_idempotency_lease"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.execute(
        """
        CREATE TABLE message_events (
            message_id UUID NOT NULL REFERENCES conversation_messages(id) ON DELETE CASCADE,
            sequence_no INTEGER NOT NULL,
            event_type TEXT NOT NULL,
            payload JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            PRIMARY KEY (message_id, sequence_no)
        );
        CREATE INDEX message_events_created_at_idx ON message_events(created_at);
        """
    )


def downgrade() -> None:
    op.execute("DROP INDEX IF EXISTS message_events_created_at_idx;")
    op.execute("DROP TABLE IF EXISTS message_events;")
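As a reading aid for the table above, a sketch of the snapshot query a reconnecting client's replay would run. The DSN and helper function are assumptions; the application itself goes through ``db_readonly()`` and ``MessageEventsRepository``.

from sqlalchemy import create_engine, text

# Hypothetical DSN; the application resolves its own engine internally.
engine = create_engine("postgresql+psycopg://app@localhost/docsgpt")


def snapshot_events(message_id: str, last_event_id: int | None):
    """Read the journal rows a reconnecting client has not seen yet.

    Sketch of the snapshot half of snapshot+tail: the composite PK
    ``(message_id, sequence_no)`` makes this an index-ordered range scan.
    """
    cursor = -1 if last_event_id is None else last_event_id
    with engine.connect() as conn:
        rows = conn.execute(
            text(
                """
                SELECT sequence_no, event_type, payload
                FROM message_events
                WHERE message_id = CAST(:id AS uuid)
                  AND sequence_no > :cursor
                ORDER BY sequence_no
                """
            ),
            {"id": message_id, "cursor": cursor},
        )
        return [dict(r._mapping) for r in rows]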
@@ -23,9 +23,16 @@ from application.core.settings import settings
from application.error import sanitize_api_error
from application.llm.llm_creator import LLMCreator
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.conversations import MessageUpdateOutcome
from application.storage.db.repositories.token_usage import TokenUsageRepository
from application.storage.db.repositories.user_logs import UserLogsRepository
from application.storage.db.session import db_readonly, db_session
from application.events.publisher import publish_user_event
from application.streaming.event_replay import format_sse_event
from application.streaming.message_journal import (
    BatchedJournalWriter,
    record_event,
)
from application.utils import check_required_fields

logger = logging.getLogger(__name__)
@@ -277,6 +284,17 @@ class BaseAnswerResource:
"update_message_status streaming failed for %s",
reserved_message_id,
)
# Seed last_heartbeat_at so watchdog doesn't fall back to `timestamp`
# (creation time) before the first STREAM_HEARTBEAT_INTERVAL tick.
try:
self.conversation_service.heartbeat_message(
reserved_message_id,
)
except Exception:
logger.exception(
"initial heartbeat seed failed for %s",
reserved_message_id,
)
streaming_marked = True
last_heartbeat_at = time.monotonic()
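A sketch of the staleness check the seeded ``last_heartbeat_at`` protects against; the watchdog itself is not part of this diff, so the function name, grace factor, and inputs are assumptions.

from datetime import datetime, timedelta, timezone


def is_stream_stale(last_heartbeat_at: datetime | None,
                    created_at: datetime,
                    heartbeat_interval_s: float,
                    grace_factor: float = 3.0) -> bool:
    """Staleness test a watchdog/reconciler might apply to a streaming row.

    Assumes timezone-aware datetimes. Falling back to ``created_at`` is
    exactly the false positive the seed above avoids: a slow first LLM
    token would otherwise make a healthy stream look abandoned.
    """
    reference = last_heartbeat_at or created_at
    cutoff = datetime.now(timezone.utc) - timedelta(
        seconds=heartbeat_interval_s * grace_factor
    )
    return reference < cutoff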
@@ -303,13 +321,73 @@ class BaseAnswerResource:
|
||||
try:
|
||||
agent.tool_executor.message_id = reserved_message_id
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug(
|
||||
"Could not set tool_executor.message_id; tool-call correlation will be missing for message_id=%s",
|
||||
reserved_message_id,
|
||||
)
|
||||
|
||||
# Per-stream monotonic SSE event id. Allocated by ``_emit`` and
|
||||
# threaded through both the wire format (``id: <seq>\\n``) and
|
||||
# the journal write so a reconnecting client can ``Last-Event-
|
||||
# ID`` past anything they already saw. Continuations resume
|
||||
# against the original ``reserved_message_id`` — seed the
|
||||
# allocator from the journal's high-water mark so we don't
|
||||
# collide on the duplicate-PK and silently lose every emit
|
||||
# past the resume point.
|
||||
sequence_no = -1
|
||||
if _continuation and reserved_message_id:
|
||||
try:
|
||||
from application.storage.db.repositories.message_events import (
|
||||
MessageEventsRepository,
|
||||
)
|
||||
|
||||
with db_readonly() as conn:
|
||||
latest = MessageEventsRepository(conn).latest_sequence_no(
|
||||
reserved_message_id
|
||||
)
|
||||
if latest is not None:
|
||||
sequence_no = latest
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Continuation seq seed lookup failed for message_id=%s; "
|
||||
"falling back to seq=-1 (duplicate-PK collisions will "
|
||||
"be swallowed)",
|
||||
reserved_message_id,
|
||||
)
|
||||
|
||||
# One batched journal writer per stream.
|
||||
journal_writer: Optional[BatchedJournalWriter] = (
|
||||
BatchedJournalWriter(reserved_message_id)
|
||||
if reserved_message_id
|
||||
else None
|
||||
)
|
||||
|
||||
def _emit(payload: dict) -> str:
|
||||
"""Format-and-journal one SSE event.
|
||||
|
||||
With a reserved ``message_id``, buffers into the journal and
|
||||
emits ``id: <seq>``-tagged SSE frames; otherwise falls back to
|
||||
legacy ``data: ...\\n\\n`` framing.
|
||||
"""
|
||||
nonlocal sequence_no
|
||||
if not reserved_message_id or journal_writer is None:
|
||||
return f"data: {json.dumps(payload)}\n\n"
|
||||
sequence_no += 1
|
||||
seq = sequence_no
|
||||
event_type = (
|
||||
payload.get("type", "data")
|
||||
if isinstance(payload, dict)
|
||||
else "data"
|
||||
)
|
||||
normalised = payload if isinstance(payload, dict) else {"value": payload}
|
||||
journal_writer.record(seq, event_type, normalised)
|
||||
return format_sse_event(normalised, seq)
|
||||
|
||||
try:
|
||||
# Surface the placeholder id before any LLM tokens so a
|
||||
# mid-handshake disconnect still has a row to tail-poll.
|
||||
if reserved_message_id:
|
||||
early_event = json.dumps(
|
||||
yield _emit(
|
||||
{
|
||||
"type": "message_id",
|
||||
"message_id": reserved_message_id,
|
||||
@@ -319,7 +397,6 @@ class BaseAnswerResource:
|
||||
"request_id": request_id,
|
||||
}
|
||||
)
|
||||
yield f"data: {early_event}\n\n"
|
||||
|
||||
if _continuation:
|
||||
gen_iter = agent.gen_continuation(
|
||||
@@ -345,8 +422,9 @@ class BaseAnswerResource:
|
||||
schema_info = line.get("schema")
|
||||
structured_chunks.append(line["answer"])
|
||||
else:
|
||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(
|
||||
{"type": "answer", "answer": line["answer"]}
|
||||
)
|
||||
elif "sources" in line:
|
||||
_mark_streaming_once()
|
||||
truncated_sources = []
|
||||
@@ -359,43 +437,40 @@ class BaseAnswerResource:
|
||||
)
|
||||
truncated_sources.append(truncated_source)
|
||||
if truncated_sources:
|
||||
data = json.dumps(
|
||||
yield _emit(
|
||||
{"type": "source", "source": truncated_sources}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
elif "tool_calls" in line:
|
||||
tool_calls = line["tool_calls"]
|
||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
elif "thought" in line:
|
||||
thought += line["thought"]
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "thought", "thought": line["thought"]})
|
||||
elif "type" in line:
|
||||
if line.get("type") == "tool_calls_pending":
|
||||
# Save continuation state and end the stream
|
||||
paused = True
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(line)
|
||||
elif line.get("type") == "error":
|
||||
sanitized_error = {
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(line.get("error", "An error occurred"))
|
||||
}
|
||||
data = json.dumps(sanitized_error)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(
|
||||
{
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(
|
||||
line.get("error", "An error occurred")
|
||||
),
|
||||
}
|
||||
)
|
||||
else:
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(line)
|
||||
if is_structured and structured_chunks:
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
"answer": response_full,
|
||||
"structured": True,
|
||||
"schema": schema_info,
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit(
|
||||
{
|
||||
"type": "structured_answer",
|
||||
"answer": response_full,
|
||||
"structured": True,
|
||||
"schema": schema_info,
|
||||
}
|
||||
)
|
||||
|
||||
# ---- Paused: save continuation state and end stream early ----
|
||||
if paused:
|
||||
@@ -452,6 +527,7 @@ class BaseAnswerResource:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
state_saved = False
|
||||
if conversation_id:
|
||||
try:
|
||||
cont_service = ContinuationService()
|
||||
@@ -485,18 +561,65 @@ class BaseAnswerResource:
|
||||
agent.tool_executor, "client_tools", None
|
||||
),
|
||||
)
|
||||
state_saved = True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save continuation state: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
# Notify the user out-of-band so they can navigate
|
||||
# back to the conversation and decide on the
|
||||
# pending tool calls. Gated on ``state_saved``: a
|
||||
# missing pending_tool_state row would 404 the
|
||||
# resume endpoint, so an unfulfillable notification
|
||||
# is worse than no notification.
|
||||
user_id_for_event = (
|
||||
decoded_token.get("sub") if decoded_token else None
|
||||
)
|
||||
if state_saved and user_id_for_event and conversation_id:
|
||||
pending_calls = continuation.get(
|
||||
"pending_tool_calls", []
|
||||
) if continuation else []
|
||||
# Trim each pending tool call to its identifying
|
||||
# metadata so a tool with a multi-MB argument
|
||||
# doesn't blow out the per-event payload size
|
||||
# cap. The resume page fetches full args from
|
||||
# ``pending_tool_state`` regardless.
|
||||
pending_summaries = [
|
||||
{
|
||||
k: tc.get(k)
|
||||
for k in (
|
||||
"call_id",
|
||||
"tool_name",
|
||||
"action_name",
|
||||
"name",
|
||||
)
|
||||
if isinstance(tc, dict) and tc.get(k) is not None
|
||||
}
|
||||
for tc in (pending_calls or [])
|
||||
if isinstance(tc, dict)
|
||||
]
|
||||
publish_user_event(
|
||||
user_id_for_event,
|
||||
"tool.approval.required",
|
||||
{
|
||||
"conversation_id": str(conversation_id),
|
||||
"message_id": reserved_message_id,
|
||||
"pending_tool_calls": pending_summaries,
|
||||
},
|
||||
scope={
|
||||
"kind": "conversation",
|
||||
"id": str(conversation_id),
|
||||
},
|
||||
)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "id", "id": str(conversation_id)})
|
||||
yield _emit({"type": "end"})
|
||||
# Drain the terminal ``end`` so a reconnecting client
|
||||
# sees it on snapshot — same reason as the main exit.
|
||||
if journal_writer is not None:
|
||||
journal_writer.close()
|
||||
return
|
||||
|
||||
if isNoneDoc:
|
||||
@@ -603,9 +726,7 @@ class BaseAnswerResource:
|
||||
f"completion: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "id", "id": str(conversation_id)})
|
||||
|
||||
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
|
||||
getattr(agent, "tool_calls", tool_calls) or tool_calls
|
||||
@@ -646,12 +767,33 @@ class BaseAnswerResource:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
yield _emit({"type": "end"})
|
||||
# Drain the journal buffer so the terminal ``end`` event is
|
||||
# visible to any reconnecting client. Without this the
|
||||
# client could snapshot up to the last flush boundary and
|
||||
# then live-tail waiting for an ``end`` that's still
|
||||
# sitting in memory.
|
||||
if journal_writer is not None:
|
||||
journal_writer.close()
|
||||
except GeneratorExit:
|
||||
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
|
||||
# Drain any buffered events before the terminal one-shot
|
||||
# ``record_event`` below — keeps the journal's seq order
|
||||
# contiguous (buffered events ... terminal event). ``close``
|
||||
# is idempotent; pairing it with ``flush`` matches the
|
||||
# normal-exit and error branches so any future ``record()``
|
||||
# past this point would log instead of silently buffering.
|
||||
if journal_writer is not None:
|
||||
journal_writer.flush()
|
||||
journal_writer.close()
|
||||
# Save partial response
|
||||
|
||||
# Whether the DB row was flipped to ``complete`` during this
|
||||
# abort handler. Drives the choice of terminal journal event
|
||||
# below: journal ``end`` only when the row actually matches,
|
||||
# else journal ``error`` so a reconnecting client sees a
|
||||
# failed terminal state instead of a blank "success".
|
||||
finalized_complete = False
|
||||
if should_save_conversation and response_full:
|
||||
try:
|
||||
if isNoneDoc:
|
||||
@@ -686,7 +828,7 @@ class BaseAnswerResource:
|
||||
)
|
||||
llm._token_usage_source = "title"
|
||||
if reserved_message_id is not None:
|
||||
self.conversation_service.finalize_message(
|
||||
outcome = self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full,
|
||||
thought=thought,
|
||||
@@ -705,6 +847,15 @@ class BaseAnswerResource:
|
||||
),
|
||||
},
|
||||
)
|
||||
# ``ALREADY_COMPLETE`` means the normal-path
|
||||
# finalize at line 632 won the race: the DB row
|
||||
# is already at ``complete`` and the reconnect
|
||||
# journal should reflect that with ``end``,
|
||||
# not a spurious ``error``.
|
||||
finalized_complete = outcome in (
|
||||
MessageUpdateOutcome.UPDATED,
|
||||
MessageUpdateOutcome.ALREADY_COMPLETE,
|
||||
)
|
||||
else:
|
||||
self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
@@ -724,6 +875,9 @@ class BaseAnswerResource:
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
# No journal row to gate, but flag the save as
|
||||
# successful for symmetry with the WAL path.
|
||||
finalized_complete = True
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
if conversation_id and compression_meta and not compression_saved:
|
||||
@@ -747,6 +901,63 @@ class BaseAnswerResource:
|
||||
logger.error(
|
||||
f"Error saving partial response: {str(e)}", exc_info=True
|
||||
)
|
||||
# Journal a terminal event so reconnecting clients stop tailing;
|
||||
# ``end`` only when the row is ``complete``, else ``error``.
|
||||
if reserved_message_id is not None:
|
||||
try:
|
||||
sequence_no += 1
|
||||
if finalized_complete:
|
||||
# Match the wire shape ``_emit({"type": "end"})``
|
||||
# uses on the normal path — the replay terminal
|
||||
# check at ``event_replay._payload_is_terminal``
|
||||
# reads ``payload.type``, and the frontend parses
|
||||
# the same key off ``data:``.
|
||||
record_event(
|
||||
reserved_message_id,
|
||||
sequence_no,
|
||||
"end",
|
||||
{"type": "end"},
|
||||
)
|
||||
else:
|
||||
# Nothing was persisted under the complete status
|
||||
# — mark the row failed so the reconciler doesn't
|
||||
# need to sweep it, and journal an ``error`` so a
|
||||
# reconnecting client surfaces the same failure
|
||||
# the UI would show on a live error.
|
||||
try:
|
||||
self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full or TERMINATED_RESPONSE_PLACEHOLDER,
|
||||
thought=thought,
|
||||
sources=source_log_docs,
|
||||
tool_calls=tool_calls,
|
||||
model_id=model_id or self.default_model_id,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
status="failed",
|
||||
error=ConnectionError(
|
||||
"client disconnected before response was persisted"
|
||||
),
|
||||
)
|
||||
except Exception as fin_err:
|
||||
logger.error(
|
||||
f"Failed to mark aborted message failed: {fin_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
record_event(
|
||||
reserved_message_id,
|
||||
sequence_no,
|
||||
"error",
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Stream aborted before any response was produced.",
|
||||
"code": "client_disconnect",
|
||||
},
|
||||
)
|
||||
except Exception as journal_err:
|
||||
logger.error(
|
||||
f"Failed to journal terminal event on abort: {journal_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
||||
@@ -768,13 +979,16 @@ class BaseAnswerResource:
|
||||
f"Failed to finalize errored message: {fin_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
data = json.dumps(
|
||||
yield _emit(
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Please try again later. We apologize for any inconvenience.",
|
||||
}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
# Drain the terminal ``error`` event we just yielded so a
|
||||
# reconnecting client sees it on snapshot.
|
||||
if journal_writer is not None:
|
||||
journal_writer.close()
|
||||
return
|
||||
|
||||
def process_response_stream(self, stream) -> Dict[str, Any]:
|
||||
@@ -796,8 +1010,22 @@ class BaseAnswerResource:
|
||||
|
||||
for line in stream:
|
||||
try:
|
||||
event_data = line.replace("data: ", "").strip()
|
||||
# Each chunk may carry an ``id: <seq>`` header before
|
||||
# the ``data:`` line. Pull just the ``data:`` body so
|
||||
# the JSON decode doesn't choke on the SSE framing.
|
||||
event_data = ""
|
||||
for raw in line.split("\n"):
|
||||
if raw.startswith("data:"):
|
||||
event_data = raw[len("data:") :].lstrip()
|
||||
break
|
||||
if not event_data:
|
||||
continue
|
||||
event = json.loads(event_data)
|
||||
# The ``message_id`` event is informational for the
|
||||
# streaming consumer and has no synchronous-API field;
|
||||
# skip it so the type-switch below doesn't KeyError.
|
||||
if event.get("type") == "message_id":
|
||||
continue
|
||||
|
||||
if event["type"] == "id":
|
||||
conversation_id = event["id"]
|
||||
|
||||
application/api/answer/routes/messages.py (new file, 135 lines)
@@ -0,0 +1,135 @@
|
||||
"""GET /api/messages/<message_id>/events — chat-stream reconnect endpoint.
|
||||
|
||||
Authenticates the caller, verifies ``message_id`` belongs to the user,
|
||||
then hands off to ``build_message_event_stream`` for snapshot+tail.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.streaming.event_replay import (
|
||||
DEFAULT_KEEPALIVE_SECONDS,
|
||||
DEFAULT_POLL_TIMEOUT_SECONDS,
|
||||
build_message_event_stream,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
messages_bp = Blueprint("message_stream", __name__)
|
||||
|
||||
# A message_id is the canonical UUID hex format. Reject anything else
|
||||
# before the SQL layer so a malformed cookie can't surface as a 500.
|
||||
_MESSAGE_ID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
|
||||
r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
# ``sequence_no`` is a non-negative decimal integer. Anything else is
|
||||
# corrupt client state — fall through to a fresh-replay cursor and let
|
||||
# the snapshot reader catch the client up.
|
||||
_SEQUENCE_NO_RE = re.compile(r"^\d+$")
|
||||
|
||||
|
||||
def _normalise_last_event_id(raw: Optional[str]) -> Optional[int]:
|
||||
if raw is None:
|
||||
return None
|
||||
raw = raw.strip()
|
||||
if not raw or not _SEQUENCE_NO_RE.match(raw):
|
||||
return None
|
||||
return int(raw)
|
||||
|
||||
|
||||
def _user_owns_message(message_id: str, user_id: str) -> bool:
|
||||
"""Return True iff ``message_id`` belongs to ``user_id``."""
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
row = conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT 1 FROM conversation_messages
|
||||
WHERE id = CAST(:id AS uuid)
|
||||
AND user_id = :u
|
||||
LIMIT 1
|
||||
"""
|
||||
),
|
||||
{"id": message_id, "u": user_id},
|
||||
).first()
|
||||
return row is not None
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Ownership lookup failed for message_id=%s user_id=%s",
|
||||
message_id,
|
||||
user_id,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
@messages_bp.route("/api/messages/<message_id>/events", methods=["GET"])
|
||||
def stream_message_events(message_id: str) -> Response:
|
||||
decoded = getattr(request, "decoded_token", None)
|
||||
user_id = decoded.get("sub") if isinstance(decoded, dict) else None
|
||||
if not user_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}),
|
||||
401,
|
||||
)
|
||||
|
||||
if not _MESSAGE_ID_RE.match(message_id):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid message id"}),
|
||||
400,
|
||||
)
|
||||
|
||||
if not _user_owns_message(message_id, user_id):
|
||||
# Don't disclose whether the row exists — a malicious caller
|
||||
# gets the same 404 whether the id is bogus, taken by another
|
||||
# user, or simply gone.
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Not found"}),
|
||||
404,
|
||||
)
|
||||
|
||||
raw_cursor = request.headers.get("Last-Event-ID") or request.args.get(
|
||||
"last_event_id"
|
||||
)
|
||||
last_event_id = _normalise_last_event_id(raw_cursor)
|
||||
keepalive_seconds = float(
|
||||
getattr(settings, "SSE_KEEPALIVE_SECONDS", DEFAULT_KEEPALIVE_SECONDS)
|
||||
)
|
||||
|
||||
@stream_with_context
|
||||
def generate() -> Iterator[str]:
|
||||
try:
|
||||
yield from build_message_event_stream(
|
||||
message_id,
|
||||
last_event_id=last_event_id,
|
||||
keepalive_seconds=keepalive_seconds,
|
||||
poll_timeout_seconds=DEFAULT_POLL_TIMEOUT_SECONDS,
|
||||
)
|
||||
except GeneratorExit:
|
||||
return
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Reconnect stream crashed for message_id=%s user_id=%s",
|
||||
message_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
response = Response(generate(), mimetype="text/event-stream")
|
||||
response.headers["Cache-Control"] = "no-store"
|
||||
response.headers["X-Accel-Buffering"] = "no"
|
||||
response.headers["Connection"] = "keep-alive"
|
||||
logger.info(
|
||||
"message.event.connect message_id=%s user_id=%s last_event_id=%s",
|
||||
message_id,
|
||||
user_id,
|
||||
last_event_id if last_event_id is not None else "-",
|
||||
)
|
||||
return response
|
||||
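A client-side sketch of the reconnect contract the route above implements: resume with ``Last-Event-ID``, track the ``id:`` header of each frame, and parse the ``data:`` body. The ``requests`` usage and the bearer auth header are assumptions; the in-tree frontend uses ``EventSource`` instead.

import requests


def tail_message(base_url: str, message_id: str, token: str, last_seq: int | None = None):
    """Reconnect to /api/messages/<id>/events and resume past ``last_seq``."""
    headers = {"Authorization": f"Bearer {token}", "Accept": "text/event-stream"}
    if last_seq is not None:
        headers["Last-Event-ID"] = str(last_seq)
    with requests.get(
        f"{base_url}/api/messages/{message_id}/events",
        headers=headers,
        stream=True,
        timeout=(5, None),
    ) as resp:
        resp.raise_for_status()
        for raw in resp.iter_lines(decode_unicode=True):
            if not raw:
                continue  # blank separators and keepalive gaps
            if raw.startswith("id:"):
                last_seq = int(raw[len("id:"):].strip())
            elif raw.startswith("data:"):
                yield last_seq, raw[len("data:"):].strip()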
@@ -15,7 +15,10 @@ from sqlalchemy import text as sql_text
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.conversations import (
    ConversationsRepository,
    MessageUpdateOutcome,
)
from application.storage.db.session import db_readonly, db_session
@@ -305,10 +308,17 @@ class ConversationService:
status: str = "complete",
error: Optional[BaseException] = None,
title_inputs: Optional[Dict[str, Any]] = None,
) -> bool:
"""Commit the response and tool_call confirms in one transaction."""
) -> MessageUpdateOutcome:
"""Commit the response and tool_call confirms in one transaction.

The outcome propagates directly from ``update_message_by_id`` so
callers (notably the SSE abort handler) can tell a fresh
finalize from "the row was already terminal" — the latter must
still be treated as success when the prior state was
``complete``.
"""
if not message_id:
return False
return MessageUpdateOutcome.INVALID
sources = sources or []
for source in sources:
if "text" in source and isinstance(source["text"], str):
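For reference, a sketch of the ``MessageUpdateOutcome`` enum implied by the new return type. Only the three members referenced in this diff are shown, and their string values are guesses.

from enum import Enum


class MessageUpdateOutcome(Enum):
    """Hypothetical shape of the repository's outcome enum."""

    INVALID = "invalid"
    UPDATED = "updated"
    ALREADY_COMPLETE = "already_complete"
    # The real enum may define further members (e.g. a not-found case).


def finalized_ok(outcome: MessageUpdateOutcome) -> bool:
    # Mirrors the abort handler: UPDATED and ALREADY_COMPLETE both count
    # as success, anything else means the finalize did not land.
    return outcome in (MessageUpdateOutcome.UPDATED, MessageUpdateOutcome.ALREADY_COMPLETE)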
@@ -336,16 +346,16 @@ class ConversationService:
# retracting a row the reconciler already escalated.
with db_session() as conn:
repo = ConversationsRepository(conn)
ok = repo.update_message_by_id(
outcome = repo.update_message_by_id(
message_id, update_fields,
only_if_non_terminal=True,
)
if not ok:
if outcome is not MessageUpdateOutcome.UPDATED:
logger.warning(
f"finalize_message: no row updated for message_id={message_id} "
f"(possibly already terminal — reconciler may have escalated)"
f"(outcome={outcome.value} — possibly already terminal)"
)
return False
return outcome
repo.confirm_executed_tool_calls(message_id)

# Outside the txn — title-gen is a multi-second LLM round trip.
@@ -358,7 +368,7 @@ class ConversationService:
f"finalize_message title generation failed: {e}",
exc_info=True,
)
return True
return MessageUpdateOutcome.UPDATED

def _maybe_generate_title(
self,

application/api/events/__init__.py (new file, 0 lines)
application/api/events/routes.py (new file, 504 lines)
@@ -0,0 +1,504 @@
|
||||
"""GET /api/events — user-scoped Server-Sent Events endpoint.
|
||||
|
||||
Subscribe-then-snapshot pattern: subscribe to ``user:{user_id}``
|
||||
pub/sub, snapshot the Redis Streams backlog past ``Last-Event-ID``
|
||||
inside the SUBSCRIBE-ack callback, flush snapshot, then tail live
|
||||
events (dedup'd by stream id). See ``docs/runbooks/sse-notifications.md``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.events.keys import (
|
||||
connection_counter_key,
|
||||
replay_budget_key,
|
||||
stream_id_compare,
|
||||
stream_key,
|
||||
topic_name,
|
||||
)
|
||||
from application.streaming.broadcast_channel import Topic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
events = Blueprint("event_stream", __name__)
|
||||
|
||||
SUBSCRIBE_POLL_INTERVAL_SECONDS = 1.0
|
||||
|
||||
# WHATWG SSE treats CRLF, CR, and LF equivalently as line terminators.
|
||||
_SSE_LINE_SPLIT = re.compile(r"\r\n|\r|\n")
|
||||
|
||||
# Redis Streams ids are ``ms`` or ``ms-seq`` where both halves are decimal.
|
||||
# Anything else is a corrupted client cookie / IndexedDB residue and must
|
||||
# not be passed to XRANGE — Redis would reject it and our truncation gate
|
||||
# would silently fail.
|
||||
_STREAM_ID_RE = re.compile(r"^\d+(-\d+)?$")
|
||||
|
||||
# Only emitted at most once per process so a misconfigured deployment
|
||||
# doesn't drown the logs.
|
||||
_local_user_warned = False
|
||||
|
||||
|
||||
def _format_sse(data: str, *, event_id: Optional[str] = None) -> str:
|
||||
"""Encode a payload as one SSE message terminated by a blank line.
|
||||
|
||||
Splits on any line-terminator variant (``\\r\\n``, ``\\r``, ``\\n``)
|
||||
so a stray CR in upstream content can't smuggle a premature line
|
||||
boundary into the wire format.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
if event_id:
|
||||
lines.append(f"id: {event_id}")
|
||||
for line in _SSE_LINE_SPLIT.split(data):
|
||||
lines.append(f"data: {line}")
|
||||
return "\n".join(lines) + "\n\n"
|
||||
|
||||
|
||||
def _decode(value) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, (bytes, bytearray)):
|
||||
try:
|
||||
return value.decode("utf-8")
|
||||
except Exception:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
|
||||
def _oldest_retained_id(redis_client, user_id: str) -> Optional[str]:
|
||||
"""Return the id of the oldest entry still in the stream, or ``None``.
|
||||
|
||||
Used to detect ``Last-Event-ID`` having slid off the back of the
|
||||
MAXLEN'd window.
|
||||
"""
|
||||
try:
|
||||
info = redis_client.xinfo_stream(stream_key(user_id))
|
||||
except Exception:
|
||||
return None
|
||||
if not isinstance(info, dict):
|
||||
return None
|
||||
# redis-py 7.4 returns str-keyed dicts here; the bytes-key probe is
|
||||
# defence in depth in case ``decode_responses`` is ever flipped.
|
||||
first_entry = info.get("first-entry") or info.get(b"first-entry")
|
||||
if not first_entry:
|
||||
return None
|
||||
# XINFO STREAM returns first-entry as [id, [field, value, ...]]
|
||||
try:
|
||||
return _decode(first_entry[0])
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _allow_replay(
|
||||
redis_client, user_id: str, last_event_id: Optional[str]
|
||||
) -> bool:
|
||||
"""Per-user sliding-window snapshot-replay budget.
|
||||
|
||||
Fails open on Redis errors or when the budget is disabled. Empty-backlog
|
||||
no-cursor connects skip INCR so dev double-mounts don't trip 429.
|
||||
"""
|
||||
budget = int(settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW)
|
||||
if budget <= 0:
|
||||
return True
|
||||
if redis_client is None:
|
||||
return True
|
||||
|
||||
# Cheap pre-check: only INCR when we might actually replay. XLEN
|
||||
# is one Redis op; the alternative (INCR every connect) is two
|
||||
# ops AND wrongly counts no-op probes. The check is conservative:
|
||||
# if ``last_event_id`` is set we always INCR, even if the cursor
|
||||
# has already overtaken the latest entry — that case is rare and
|
||||
# short-lived, and probing further would mean a redundant XRANGE.
|
||||
if last_event_id is None:
|
||||
try:
|
||||
if int(redis_client.xlen(stream_key(user_id))) == 0:
|
||||
return True
|
||||
except Exception:
|
||||
# XLEN probe failed; fall through to the INCR path so a
|
||||
# transient Redis hiccup can't bypass the budget.
|
||||
logger.debug(
|
||||
"XLEN probe failed for replay budget check user=%s; "
|
||||
"proceeding to INCR",
|
||||
user_id,
|
||||
)
|
||||
|
||||
window = max(1, int(settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS))
|
||||
key = replay_budget_key(user_id)
|
||||
try:
|
||||
used = int(redis_client.incr(key))
|
||||
# Always (re)seed the TTL. Gating on ``used == 1`` would wedge
|
||||
# the counter forever if INCR succeeds but EXPIRE raises on
|
||||
# the seeding call. EXPIRE on an existing key resets the TTL
|
||||
# to ``window`` — within ±1s of the per-window budget semantic.
|
||||
redis_client.expire(key, window)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"replay budget probe failed for user=%s; failing open",
|
||||
user_id,
|
||||
)
|
||||
return True
|
||||
return used <= budget
|
||||
|
||||
|
||||
def _normalize_last_event_id(raw: Optional[str]) -> Optional[str]:
|
||||
"""Validate the ``Last-Event-ID`` header / query param.
|
||||
|
||||
Returns the value unchanged when it parses as a Redis Streams id,
|
||||
otherwise ``None`` — callers treat ``None`` as "client has nothing"
|
||||
and replay from the start of the retained window. Invalid ids would
|
||||
otherwise pass straight to XRANGE and surface as a quiet replay
|
||||
failure plus broken truncation detection.
|
||||
"""
|
||||
if raw is None:
|
||||
return None
|
||||
raw = raw.strip()
|
||||
if not raw or not _STREAM_ID_RE.match(raw):
|
||||
return None
|
||||
return raw
|
||||
|
||||
|
||||
def _replay_backlog(
|
||||
redis_client, user_id: str, last_event_id: Optional[str], max_count: int
|
||||
) -> Iterator[tuple[str, str]]:
|
||||
"""Yield ``(entry_id, sse_line)`` for backlog entries past ``last_event_id``.
|
||||
|
||||
Capped at ``max_count`` rows; clients catch up across reconnects.
|
||||
Parse failures are skipped; the Streams id is injected into the
|
||||
envelope so replay matches live-tail shape.
|
||||
"""
|
||||
# Exclusive start: '(<id>' skips the already-delivered entry.
|
||||
start = f"({last_event_id}" if last_event_id else "-"
|
||||
try:
|
||||
entries = redis_client.xrange(
|
||||
stream_key(user_id), min=start, max="+", count=max_count
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"xrange replay failed for user=%s last_id=%s err=%s",
|
||||
user_id,
|
||||
last_event_id or "-",
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
for entry_id, fields in entries:
|
||||
entry_id_str = _decode(entry_id)
|
||||
if not entry_id_str:
|
||||
continue
|
||||
# decode_responses=False on the cache client ⇒ field keys/values
|
||||
# are bytes. The string-key fallback covers a future flip of that
|
||||
# default without a forced refactor here.
|
||||
raw_event = None
|
||||
if isinstance(fields, dict):
|
||||
raw_event = fields.get(b"event")
|
||||
if raw_event is None:
|
||||
raw_event = fields.get("event")
|
||||
event_str = _decode(raw_event)
|
||||
if not event_str:
|
||||
continue
|
||||
try:
|
||||
envelope = json.loads(event_str)
|
||||
if isinstance(envelope, dict):
|
||||
envelope["id"] = entry_id_str
|
||||
event_str = json.dumps(envelope)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Replay envelope parse failed for entry %s; passing through raw",
|
||||
entry_id_str,
|
||||
)
|
||||
yield entry_id_str, _format_sse(event_str, event_id=entry_id_str)
|
||||
|
||||
|
||||
def _truncation_notice_line(oldest_id: str) -> str:
|
||||
"""SSE event the frontend can react to with a full-state refetch."""
|
||||
return _format_sse(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "backlog.truncated",
|
||||
"payload": {"oldest_retained_id": oldest_id},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@events.route("/api/events", methods=["GET"])
|
||||
def stream_events() -> Response:
|
||||
decoded = getattr(request, "decoded_token", None)
|
||||
user_id = decoded.get("sub") if isinstance(decoded, dict) else None
|
||||
if not user_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}),
|
||||
401,
|
||||
)
|
||||
|
||||
# In dev deployments without AUTH_TYPE configured, every request
|
||||
# resolves to user_id="local" and shares one stream. Surface this so
|
||||
# an accidentally-multi-user dev box doesn't silently cross-stream.
|
||||
global _local_user_warned
|
||||
if user_id == "local" and not _local_user_warned:
|
||||
logger.warning(
|
||||
"SSE serving user_id='local' (AUTH_TYPE not set). "
|
||||
"All clients on this deployment will share one event stream."
|
||||
)
|
||||
_local_user_warned = True
|
||||
|
||||
raw_last_event_id = request.headers.get("Last-Event-ID") or request.args.get(
|
||||
"last_event_id"
|
||||
)
|
||||
last_event_id = _normalize_last_event_id(raw_last_event_id)
|
||||
last_event_id_invalid = raw_last_event_id is not None and last_event_id is None
|
||||
|
||||
keepalive_seconds = float(settings.SSE_KEEPALIVE_SECONDS)
|
||||
push_enabled = settings.ENABLE_SSE_PUSH
|
||||
cap = int(settings.SSE_MAX_CONCURRENT_PER_USER)
|
||||
|
||||
redis_client = get_redis_instance()
|
||||
counter_key = connection_counter_key(user_id)
|
||||
counted = False
|
||||
|
||||
if push_enabled and redis_client is not None and cap > 0:
|
||||
try:
|
||||
current = int(redis_client.incr(counter_key))
|
||||
counted = True
|
||||
except Exception:
|
||||
current = 0
|
||||
logger.debug(
|
||||
"SSE connection counter INCR failed for user=%s", user_id
|
||||
)
|
||||
if counted:
|
||||
# 1h safety TTL — orphaned counts from hard crashes self-heal.
|
||||
# EXPIRE failure must NOT clobber ``current`` and bypass the cap.
|
||||
try:
|
||||
redis_client.expire(counter_key, 3600)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter EXPIRE failed for user=%s", user_id
|
||||
)
|
||||
if current > cap:
|
||||
try:
|
||||
redis_client.decr(counter_key)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter DECR failed for user=%s",
|
||||
user_id,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Too many concurrent SSE connections",
|
||||
}
|
||||
),
|
||||
429,
|
||||
)
|
||||
|
||||
# Replay budget is checked here, before the generator opens the
|
||||
# stream, so a denial can surface as HTTP 429 instead of a silent
|
||||
# snapshot skip. The earlier in-generator skip lost events between
|
||||
# the client's cursor and the first live-tailed entry: the live
|
||||
# tail still carried ``id:`` headers, the frontend advanced
|
||||
# ``lastEventId`` to one of those ids, and the events in between
|
||||
# were never reachable on the next reconnect. 429 keeps the
|
||||
# cursor pinned and lets the frontend back off until the window
|
||||
# slides (eventStreamClient.ts treats 429 as escalated backoff).
|
||||
if push_enabled and redis_client is not None and not _allow_replay(
|
||||
redis_client, user_id, last_event_id
|
||||
):
|
||||
if counted:
|
||||
try:
|
||||
redis_client.decr(counter_key)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter DECR failed for user=%s",
|
||||
user_id,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Replay budget exhausted",
|
||||
}
|
||||
),
|
||||
429,
|
||||
)
|
||||
|
||||
@stream_with_context
|
||||
def generate() -> Iterator[str]:
|
||||
connect_ts = time.monotonic()
|
||||
replayed_count = 0
|
||||
try:
|
||||
# First frame primes intermediaries (Cloudflare, nginx) so they
|
||||
# don't sit on a buffer waiting for body bytes.
|
||||
yield ": connected\n\n"
|
||||
|
||||
if not push_enabled:
|
||||
yield ": push_disabled\n\n"
|
||||
return
|
||||
|
||||
replay_lines: list[str] = []
|
||||
max_replayed_id: Optional[str] = None
|
||||
replay_done = False
|
||||
|
||||
# If the client sent a malformed Last-Event-ID, surface the
|
||||
# truncation notice synchronously *before* the subscribe
|
||||
# loop. Buffering it into ``replay_lines`` would lose it
|
||||
# when ``Topic.subscribe`` returns immediately (Redis down)
|
||||
# — the loop body never runs, and the flush at line ~335
|
||||
# never fires.
|
||||
if last_event_id_invalid:
|
||||
yield _truncation_notice_line("")
|
||||
replayed_count += 1
|
||||
|
||||
def _on_subscribe_callback() -> None:
|
||||
# Runs synchronously inside Topic.subscribe after the
|
||||
# SUBSCRIBE is acked. By doing XRANGE here, any publisher
|
||||
# firing between SUBSCRIBE-send and SUBSCRIBE-ack has its
|
||||
# XADD captured by XRANGE *and* its PUBLISH buffered at
|
||||
# the connection layer until we read it — closing the
|
||||
# replay/subscribe race the design doc warns about.
|
||||
#
|
||||
# Truncation contract: ``backlog.truncated`` is emitted
|
||||
# ONLY when the client's ``Last-Event-ID`` has slid off
|
||||
# the MAXLEN'd window — that's the case where the
|
||||
# journal is genuinely gone past the cursor and the
|
||||
# frontend should clear its slice cursor and refetch
|
||||
# state. Cap-hit skips the snapshot silently: the
|
||||
# cursor advances via the per-entry ``id:`` headers
|
||||
# and the frontend's slice keeps the latest id so the
|
||||
# next reconnect resumes from there. Budget-exhausted
|
||||
# never reaches this callback — the route 429s before
|
||||
# opening the stream, keeping the cursor pinned.
|
||||
# Conflating these with stale-cursor truncation would
|
||||
# tell the client to clear its cursor and re-receive
|
||||
# the same oldest-N entries on every reconnect —
|
||||
# locking the user out of entries past N.
|
||||
nonlocal max_replayed_id, replay_done
|
||||
try:
|
||||
if redis_client is None:
|
||||
return
|
||||
oldest = _oldest_retained_id(redis_client, user_id)
|
||||
if (
|
||||
last_event_id
|
||||
and oldest
|
||||
and stream_id_compare(last_event_id, oldest) < 0
|
||||
):
|
||||
# The Last-Event-ID has slid off the MAXLEN window.
|
||||
# Tell the client so it can fetch full state.
|
||||
replay_lines.append(_truncation_notice_line(oldest))
|
||||
replay_cap = int(settings.EVENTS_REPLAY_MAX_PER_REQUEST)
|
||||
for entry_id, sse_line in _replay_backlog(
|
||||
redis_client, user_id, last_event_id, replay_cap
|
||||
):
|
||||
replay_lines.append(sse_line)
|
||||
max_replayed_id = entry_id
|
||||
finally:
|
||||
# Always flip the flag — even on partial-replay failure
|
||||
# the outer loop must reach the flush step so we don't
|
||||
# silently strand whatever entries did land.
|
||||
replay_done = True
|
||||
|
||||
topic = Topic(topic_name(user_id))
|
||||
last_keepalive = time.monotonic()
|
||||
for payload in topic.subscribe(
|
||||
on_subscribe=_on_subscribe_callback,
|
||||
poll_timeout=SUBSCRIBE_POLL_INTERVAL_SECONDS,
|
||||
):
|
||||
# Flush snapshot on the first iteration after the SUBSCRIBE
|
||||
# callback ran. This runs at most once per connection.
|
||||
if replay_done and replay_lines:
|
||||
for line in replay_lines:
|
||||
yield line
|
||||
replayed_count += 1
|
||||
replay_lines.clear()
|
||||
|
||||
now = time.monotonic()
|
||||
if payload is None:
|
||||
if now - last_keepalive >= keepalive_seconds:
|
||||
yield ": keepalive\n\n"
|
||||
last_keepalive = now
|
||||
continue
|
||||
|
||||
event_str = _decode(payload) or ""
|
||||
event_id: Optional[str] = None
|
||||
try:
|
||||
envelope = json.loads(event_str)
|
||||
if isinstance(envelope, dict):
|
||||
candidate = envelope.get("id")
|
||||
# Only trust ids that look like real Redis Streams
|
||||
# ids (``ms`` or ``ms-seq``). A malformed or
|
||||
# adversarial publisher could otherwise pin
|
||||
# dedupe forever — a lex-greater bogus id would
|
||||
# make every legitimate later id compare ``<=``
|
||||
# and get dropped silently.
|
||||
if isinstance(candidate, str) and _STREAM_ID_RE.match(
|
||||
candidate
|
||||
):
|
||||
event_id = candidate
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Dedupe: if this id was already covered by replay, drop it.
|
||||
if (
|
||||
event_id is not None
|
||||
and max_replayed_id is not None
|
||||
and stream_id_compare(event_id, max_replayed_id) <= 0
|
||||
):
|
||||
continue
|
||||
|
||||
yield _format_sse(event_str, event_id=event_id)
|
||||
last_keepalive = now
|
||||
|
||||
# Topic.subscribe exited before the first yield (transient
|
||||
# Redis hiccup between SUBSCRIBE-ack and the first poll, or
|
||||
# an immediate Redis-down return). The callback may already
|
||||
# have populated the snapshot — flush it so the client gets
|
||||
# the backlog instead of a silent drop. Safe no-op when the
|
||||
# in-loop flush ran (it clear()'d the buffer) and when the
|
||||
# callback never fired (replay_done stays False).
|
||||
if replay_done and replay_lines:
|
||||
for line in replay_lines:
|
||||
yield line
|
||||
replayed_count += 1
|
||||
replay_lines.clear()
|
||||
except GeneratorExit:
|
||||
return
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"SSE event-stream generator crashed for user=%s", user_id
|
||||
)
|
||||
finally:
|
||||
duration_s = time.monotonic() - connect_ts
|
||||
logger.info(
|
||||
"event.disconnect user=%s duration_s=%.1f replayed=%d",
|
||||
user_id,
|
||||
duration_s,
|
||||
replayed_count,
|
||||
)
|
||||
if counted and redis_client is not None:
|
||||
try:
|
||||
redis_client.decr(counter_key)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter DECR failed for user=%s on disconnect",
|
||||
user_id,
|
||||
)
|
||||
|
||||
response = Response(generate(), mimetype="text/event-stream")
|
||||
response.headers["Cache-Control"] = "no-store"
|
||||
response.headers["X-Accel-Buffering"] = "no"
|
||||
response.headers["Connection"] = "keep-alive"
|
||||
logger.info(
|
||||
"event.connect user=%s last_event_id=%s%s",
|
||||
user_id,
|
||||
last_event_id or "-",
|
||||
" (rejected_invalid)" if last_event_id_invalid else "",
|
||||
)
|
||||
return response
|
||||
@@ -214,6 +214,10 @@ class StoreAttachment(Resource):
{
"success": True,
"task_id": tasks[0]["task_id"],
# Surface the attachment_id so the frontend
# can correlate ``attachment.*`` SSE events
# to this row and skip the polling fallback.
"attachment_id": tasks[0]["attachment_id"],
"message": "File uploaded successfully. Processing started.",
}
),
@@ -7,9 +7,13 @@ from flask_restx import fields, Namespace, Resource
from sqlalchemy import text as sql_text

from application.api import api
from application.api.answer.services.conversation_service import (
    TERMINATED_RESPONSE_PLACEHOLDER,
)
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.message_events import MessageEventsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
@@ -106,85 +110,6 @@ class GetConversations(Resource):
|
||||
return make_response(jsonify(list_conversations), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/search_conversations")
|
||||
class SearchConversations(Resource):
|
||||
@staticmethod
|
||||
def _build_match_snippet(text_value: str, query: str, radius: int = 60) -> str:
|
||||
if not text_value:
|
||||
return ""
|
||||
idx = text_value.lower().find(query.lower())
|
||||
if idx == -1:
|
||||
snippet = text_value[: radius * 2]
|
||||
return snippet + ("…" if len(text_value) > len(snippet) else "")
|
||||
start = max(0, idx - radius)
|
||||
end = min(len(text_value), idx + len(query) + radius)
|
||||
snippet = text_value[start:end]
|
||||
if start > 0:
|
||||
snippet = "…" + snippet
|
||||
if end < len(text_value):
|
||||
snippet = snippet + "…"
|
||||
return snippet
|
||||
|
||||
@api.doc(
|
||||
description=(
|
||||
"Search the authenticated user's conversations by name or "
|
||||
"message content (case-insensitive substring match). Mirrors "
|
||||
"the visibility filter and response shape of /get_conversations, "
|
||||
"and additionally returns ``match_field`` (``name``, ``prompt`` "
|
||||
"or ``response``) and ``match_snippet`` (a short excerpt of the "
|
||||
"matched text centered on the query) for each result."
|
||||
),
|
||||
params={
|
||||
"q": "Search term (required)",
|
||||
"limit": "Maximum number of results to return (default 30, max 100)",
|
||||
},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
query = (request.args.get("q") or "").strip()
|
||||
if not query:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "q is required"}), 400
|
||||
)
|
||||
try:
|
||||
limit = int(request.args.get("limit", 30))
|
||||
except (TypeError, ValueError):
|
||||
limit = 30
|
||||
limit = max(1, min(limit, 100))
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
conversations = ConversationsRepository(conn).search_for_user(
|
||||
user_id, query, limit=limit
|
||||
)
|
||||
list_conversations = [
|
||||
{
|
||||
"id": str(conversation["id"]),
|
||||
"name": conversation["name"],
|
||||
"agent_id": (
|
||||
str(conversation["agent_id"])
|
||||
if conversation.get("agent_id")
|
||||
else None
|
||||
),
|
||||
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||
"shared_token": conversation.get("shared_token", None),
|
||||
"match_field": conversation.get("match_field"),
|
||||
"match_snippet": self._build_match_snippet(
|
||||
conversation.get("match_text") or "", query
|
||||
),
|
||||
}
|
||||
for conversation in conversations
|
||||
]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error searching conversations: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(list_conversations), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/get_single_conversation")
|
||||
class GetSingleConversation(Resource):
|
||||
@api.doc(
|
||||
@@ -425,6 +350,25 @@ class GetMessageTail(Resource):
|
||||
if row is None:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
msg = row_to_dict(row)
|
||||
# Mid-stream the row's response is the placeholder; rebuild
|
||||
# the live partial from the journal so /tail mirrors SSE.
|
||||
status = msg.get("status")
|
||||
response = msg.get("response")
|
||||
thought = msg.get("thought")
|
||||
sources = msg.get("sources") or []
|
||||
tool_calls = msg.get("tool_calls") or []
|
||||
if status in ("pending", "streaming") and (
|
||||
response == TERMINATED_RESPONSE_PLACEHOLDER
|
||||
):
|
||||
partial = MessageEventsRepository(conn).reconstruct_partial(
|
||||
message_id
|
||||
)
|
||||
response = partial["response"]
|
||||
thought = partial["thought"] or thought
|
||||
if partial["sources"]:
|
||||
sources = partial["sources"]
|
||||
if partial["tool_calls"]:
|
||||
tool_calls = partial["tool_calls"]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error tailing message {message_id}: {err}", exc_info=True
|
||||
@@ -435,11 +379,11 @@ class GetMessageTail(Resource):
|
||||
jsonify(
|
||||
{
|
||||
"message_id": str(msg["id"]),
|
||||
"status": msg.get("status"),
|
||||
"response": msg.get("response"),
|
||||
"thought": msg.get("thought"),
|
||||
"sources": msg.get("sources") or [],
|
||||
"tool_calls": msg.get("tool_calls") or [],
|
||||
"status": status,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"request_id": msg.get("request_id"),
|
||||
"last_heartbeat_at": metadata.get("last_heartbeat_at"),
|
||||
"error": metadata.get("error"),
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy import text as sql_text
|
||||
from application.api import api
|
||||
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.source_ids import derive_source_id as _derive_source_id
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
|
||||
from application.storage.db.repositories.idempotency import IdempotencyRepository
|
||||
@@ -69,7 +70,13 @@ def _claim_task_or_get_cached(key, task_name):
|
||||
|
||||
Pre-generates the celery task_id so a losing writer sees the same
|
||||
id immediately. Returns ``(task_id, cached_response)``; non-None
|
||||
cached means the caller should return without enqueuing.
|
||||
cached means the caller should return without enqueuing. The
|
||||
cached payload mirrors the fresh-request response shape (including
|
||||
``source_id``) so the frontend can correlate SSE ingest events to
|
||||
the cached upload task without an extra round-trip — but only when
|
||||
the cached row actually exists; the "deduplicated" sentinel
|
||||
deliberately omits ``source_id`` so the frontend doesn't bind to a
|
||||
phantom source.
|
||||
"""
|
||||
predetermined_id = str(uuid.uuid4())
|
||||
with db_session() as conn:
|
||||
@@ -81,10 +88,16 @@ def _claim_task_or_get_cached(key, task_name):
|
||||
with db_readonly() as conn:
|
||||
existing = IdempotencyRepository(conn).get_task(key)
|
||||
cached_id = existing.get("task_id") if existing else None
|
||||
return None, {
|
||||
payload: dict = {
|
||||
"success": True,
|
||||
"task_id": cached_id or "deduplicated",
|
||||
}
|
||||
# Only surface ``source_id`` when there's a real winner whose worker
|
||||
# is publishing SSE events tagged with that id. The "deduplicated"
|
||||
# branch means the lock row vanished — we have nothing to correlate.
|
||||
if cached_id is not None:
|
||||
payload["source_id"] = str(_derive_source_id(key))
|
||||
return None, payload
|
||||
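To make the two cached shapes concrete, a hedged sketch of what a losing writer is handed back (the ids are invented; the keys follow the helper above):

# Cached row still present: the winner's task_id plus the uuid5-derived
# source_id, so the frontend can start correlating source.ingest.* events.
cached_hit = {
    "success": True,
    "task_id": "c6a1e0d2-0c7e-4b7d-9a61-2f60f0f7b111",    # invented example id
    "source_id": "6fa459ea-ee8a-3ca4-894e-db77e160355e",  # uuid5(namespace, key) style
}

# Lock row vanished: the "deduplicated" sentinel deliberately carries no
# source_id, because no worker is publishing events for this claim.
deduplicated = {"success": True, "task_id": "deduplicated"}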
|
||||
|
||||
def _release_claim(key):
|
||||
@@ -236,6 +249,15 @@ class UploadFile(Resource):
|
||||
file_path = f"{base_path}/{safe_file}"
|
||||
with open(temp_file_path, "rb") as f:
|
||||
storage.save_file(f, file_path)
|
||||
# Mint the source UUID up here so the HTTP response and the
|
||||
# worker's SSE envelopes share one id. With an idempotency
|
||||
# key we reuse the deterministic uuid5 (retried task lands on
|
||||
# the same source row); without a key we fall back to uuid4.
|
||||
# The worker is told to use this id verbatim — see
|
||||
# ``ingest_worker(source_id=...)``.
|
||||
source_uuid = (
|
||||
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
|
||||
)
|
||||
ingest_kwargs = dict(
|
||||
args=(
|
||||
settings.UPLOAD_FOLDER,
|
||||
@@ -249,6 +271,7 @@ class UploadFile(Resource):
|
||||
"file_name_map": file_name_map,
|
||||
# Scoped so the worker dedup row matches the HTTP claim.
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
"source_id": str(source_uuid),
|
||||
},
|
||||
)
|
||||
if predetermined_task_id is not None:
|
||||
@@ -273,7 +296,15 @@ class UploadFile(Resource):
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
# Predetermined id matches the dedup-claim row, so a losing request's GET sees the same id.
|
||||
response_task_id = predetermined_task_id or task.id
|
||||
response_payload = {"success": True, "task_id": response_task_id}
|
||||
# ``source_uuid`` was minted above and passed to the worker as
|
||||
# ``source_id``; the worker uses it verbatim for every SSE event,
|
||||
# so the frontend can correlate inbound ``source.ingest.*`` to
|
||||
# this upload regardless of whether an idempotency key was set.
|
||||
response_payload: dict = {
|
||||
"success": True,
|
||||
"task_id": response_task_id,
|
||||
"source_id": str(source_uuid),
|
||||
}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
|
||||
|
||||
@@ -326,6 +357,18 @@ class UploadRemote(Resource):
|
||||
)
|
||||
if cached is not None:
|
||||
return make_response(jsonify(cached), 200)
|
||||
# Mint the source UUID up here so the HTTP response and the
|
||||
# worker's SSE envelopes share one id. Same pattern as
|
||||
# ``UploadFile.post``: with an idempotency key we reuse the
|
||||
# deterministic uuid5 (retried task lands on the same source
|
||||
# row); without a key we fall back to uuid4. The worker is told
|
||||
# to use this id verbatim — see ``remote_worker`` and
|
||||
# ``ingest_connector``. Without this the no-key path would mint
|
||||
# a random uuid4 inside the worker that the frontend has no way
|
||||
# to correlate SSE events to.
|
||||
source_uuid = (
|
||||
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
|
||||
)
|
||||
try:
|
||||
config = json.loads(data["data"])
|
||||
source_data = None
|
||||
@@ -382,13 +425,23 @@ class UploadRemote(Resource):
|
||||
"recursive": config.get("recursive", False),
|
||||
"retriever": config.get("retriever", "classic"),
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
"source_id": str(source_uuid),
|
||||
},
|
||||
}
|
||||
if predetermined_task_id is not None:
|
||||
connector_kwargs["task_id"] = predetermined_task_id
|
||||
task = ingest_connector_task.apply_async(**connector_kwargs)
|
||||
response_task_id = predetermined_task_id or task.id
|
||||
response_payload = {"success": True, "task_id": response_task_id}
|
||||
# ``source_uuid`` was minted above and passed to the
|
||||
# worker as ``source_id``; the worker uses it verbatim
|
||||
# for every SSE event, so the frontend can correlate
|
||||
# inbound ``source.ingest.*`` regardless of whether an
|
||||
# idempotency key was set.
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"task_id": response_task_id,
|
||||
"source_id": str(source_uuid),
|
||||
}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
remote_kwargs = {
|
||||
"kwargs": {
|
||||
@@ -397,6 +450,7 @@ class UploadRemote(Resource):
|
||||
"user": user,
|
||||
"loader": data["source"],
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
"source_id": str(source_uuid),
|
||||
},
|
||||
}
|
||||
if predetermined_task_id is not None:
|
||||
@@ -410,7 +464,11 @@ class UploadRemote(Resource):
|
||||
_release_claim(scoped_key)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
response_task_id = predetermined_task_id or task.id
|
||||
response_payload = {"success": True, "task_id": response_task_id}
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"task_id": response_task_id,
|
||||
"source_id": str(source_uuid),
|
||||
}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
|
||||
|
||||
@@ -553,6 +611,19 @@ class ManageSourceFiles(Resource):
|
||||
scoped_key, "reingest_source_task",
|
||||
)
|
||||
if cached is not None:
|
||||
# Frontend keys reingest polling on
|
||||
# ``reingest_task_id``; the shared cache helper
|
||||
# writes ``task_id``. Alias here so a dedup
|
||||
# response doesn't silently break FileTree's
|
||||
# poller. Override ``source_id`` too — the
|
||||
# helper derives it from the scoped key, which
|
||||
# is correct for upload but wrong for reingest
|
||||
# (the worker publishes events scoped to the
|
||||
# actual source row id).
|
||||
cached_task_id = cached.pop("task_id", None)
|
||||
if cached_task_id is not None:
|
||||
cached["reingest_task_id"] = cached_task_id
|
||||
cached["source_id"] = resolved_source_id
|
||||
return make_response(jsonify(cached), 200)
|
||||
|
||||
added_files = []
|
||||
@@ -608,6 +679,12 @@ class ManageSourceFiles(Resource):
|
||||
"added_files": added_files,
|
||||
"parent_dir": parent_dir,
|
||||
"reingest_task_id": task.id,
|
||||
# ``source_id`` lets the frontend correlate
|
||||
# inbound ``source.ingest.*`` SSE events
|
||||
# (emitted by ``reingest_source_worker``)
|
||||
# back to the reingest task — matches the
|
||||
# upload route's source-id contract.
|
||||
"source_id": resolved_source_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
@@ -659,6 +736,15 @@ class ManageSourceFiles(Resource):
|
||||
scoped_key, "reingest_source_task",
|
||||
)
|
||||
if cached is not None:
|
||||
cached_task_id = cached.pop("task_id", None)
|
||||
if cached_task_id is not None:
|
||||
cached["reingest_task_id"] = cached_task_id
|
||||
# Override the helper's synthetic source_id (uuid5
|
||||
# of the scoped key) with the real source row id
|
||||
# — the reingest worker publishes SSE events
|
||||
# scoped to ``resolved_source_id`` and FileTree
|
||||
# correlates on it.
|
||||
cached["source_id"] = resolved_source_id
|
||||
return make_response(jsonify(cached), 200)
|
||||
|
||||
# Remove files from storage and directory structure
|
||||
@@ -704,6 +790,7 @@ class ManageSourceFiles(Resource):
|
||||
"message": f"Removed {len(removed_files)} files",
|
||||
"removed_files": removed_files,
|
||||
"reingest_task_id": task.id,
|
||||
"source_id": resolved_source_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
@@ -762,6 +849,14 @@ class ManageSourceFiles(Resource):
|
||||
scoped_key, "reingest_source_task",
|
||||
)
|
||||
if cached is not None:
|
||||
cached_task_id = cached.pop("task_id", None)
|
||||
if cached_task_id is not None:
|
||||
cached["reingest_task_id"] = cached_task_id
|
||||
# Same source_id override as the ``remove`` /
|
||||
# ``add`` cached branches — the helper's synthetic
|
||||
# id doesn't match what reingest_source_worker
|
||||
# tags its SSE events with.
|
||||
cached["source_id"] = resolved_source_id
|
||||
return make_response(jsonify(cached), 200)
|
||||
|
||||
success = storage.remove_directory(full_directory_path)
|
||||
@@ -825,6 +920,7 @@ class ManageSourceFiles(Resource):
|
||||
"message": f"Successfully removed directory: {directory_path}",
|
||||
"removed_directory": directory_path,
|
||||
"reingest_task_id": task.id,
|
||||
"source_id": resolved_source_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
|
||||
@@ -7,7 +7,6 @@ from application.worker import (
|
||||
attachment_worker,
|
||||
ingest_worker,
|
||||
mcp_oauth,
|
||||
mcp_oauth_status,
|
||||
remote_worker,
|
||||
sync,
|
||||
sync_worker,
|
||||
@@ -40,6 +39,7 @@ def ingest(
|
||||
filename,
|
||||
file_name_map=None,
|
||||
idempotency_key=None,
|
||||
source_id=None,
|
||||
):
|
||||
resp = ingest_worker(
|
||||
self,
|
||||
@@ -51,16 +51,21 @@ def ingest(
|
||||
user,
|
||||
file_name_map=file_name_map,
|
||||
idempotency_key=idempotency_key,
|
||||
source_id=source_id,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest_remote")
|
||||
def ingest_remote(self, source_data, job_name, user, loader, idempotency_key=None):
|
||||
def ingest_remote(
|
||||
self, source_data, job_name, user, loader,
|
||||
idempotency_key=None, source_id=None,
|
||||
):
|
||||
resp = remote_worker(
|
||||
self, source_data, job_name, user, loader,
|
||||
idempotency_key=idempotency_key,
|
||||
source_id=source_id,
|
||||
)
|
||||
return resp
|
||||
|
||||
@@ -138,6 +143,7 @@ def ingest_connector_task(
|
||||
doc_id=None,
|
||||
sync_frequency="never",
|
||||
idempotency_key=None,
|
||||
source_id=None,
|
||||
):
|
||||
from application.worker import ingest_connector
|
||||
|
||||
@@ -155,6 +161,7 @@ def ingest_connector_task(
|
||||
doc_id=doc_id,
|
||||
sync_frequency=sync_frequency,
|
||||
idempotency_key=idempotency_key,
|
||||
source_id=source_id,
|
||||
)
|
||||
return resp
|
||||
|
||||
@@ -197,6 +204,15 @@ def setup_periodic_tasks(sender, **kwargs):
|
||||
version_check_task.s(),
|
||||
name="version-check",
|
||||
)
|
||||
# Bound ``message_events`` growth — every streamed SSE chunk writes
|
||||
# one row, so retained chats accumulate hundreds of rows per
|
||||
# message. Reconnect-replay is only meaningful for streams the user
|
||||
# could plausibly still be waiting on, so 14 days is generous.
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=24),
|
||||
cleanup_message_events.s(),
|
||||
name="cleanup-message-events",
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@@ -205,12 +221,6 @@ def mcp_oauth_task(self, config, user):
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def mcp_oauth_status_task(self, task_id):
|
||||
resp = mcp_oauth_status(self, task_id)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def cleanup_pending_tool_state(self):
|
||||
"""Revert stale ``resuming`` rows, then delete TTL-expired rows."""
|
||||
@@ -265,6 +275,32 @@ def reconciliation_task(self):
|
||||
return run_reconciliation()
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def cleanup_message_events(self):
|
||||
"""Delete ``message_events`` rows older than the retention window.
|
||||
|
||||
Streamed answer responses write one journal row per SSE yield,
|
||||
so unbounded growth would dominate Postgres for any retained-
|
||||
conversations deployment. The reconnect-replay path only needs
|
||||
rows for in-flight streams; 14 days covers paused/tool-action
|
||||
flows comfortably.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
if not settings.POSTGRES_URI:
|
||||
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
|
||||
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.message_events import (
|
||||
MessageEventsRepository,
|
||||
)
|
||||
|
||||
ttl_days = settings.MESSAGE_EVENTS_RETENTION_DAYS
|
||||
engine = get_engine()
|
||||
with engine.begin() as conn:
|
||||
deleted = MessageEventsRepository(conn).cleanup_older_than(ttl_days)
|
||||
return {"deleted": deleted, "ttl_days": ttl_days}
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def version_check_task(self):
|
||||
"""Periodic anonymous version check.
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Tool management MCP server integration."""
|
||||
|
||||
import json
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
@@ -226,7 +225,9 @@ class MCPServerSave(Resource):
|
||||
)
|
||||
redis_client = get_redis_instance()
|
||||
manager = MCPOAuthManager(redis_client)
|
||||
result = manager.get_oauth_status(config["oauth_task_id"])
|
||||
result = manager.get_oauth_status(
|
||||
config["oauth_task_id"], user
|
||||
)
|
||||
if not result.get("status") == "completed":
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -438,56 +439,6 @@ class MCPOAuthCallback(Resource):
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
|
||||
class MCPOAuthStatus(Resource):
|
||||
def get(self, task_id):
|
||||
try:
|
||||
redis_client = get_redis_instance()
|
||||
status_key = f"mcp_oauth_status:{task_id}"
|
||||
status_data = redis_client.get(status_key)
|
||||
|
||||
if status_data:
|
||||
status = json.loads(status_data)
|
||||
if "tools" in status and isinstance(status["tools"], list):
|
||||
status["tools"] = [
|
||||
{
|
||||
"name": t.get("name", "unknown"),
|
||||
"description": t.get("description", ""),
|
||||
}
|
||||
for t in status["tools"]
|
||||
]
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": task_id, **status})
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"task_id": task_id,
|
||||
"status": "pending",
|
||||
"message": "Waiting for OAuth to start...",
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error getting OAuth status for task {task_id}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Failed to get OAuth status",
|
||||
"task_id": task_id,
|
||||
}
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/auth_status")
|
||||
class MCPAuthStatus(Resource):
|
||||
@api.doc(
|
||||
|
||||
@@ -222,13 +222,26 @@ def _stream_response(
|
||||
for line in internal_stream:
|
||||
if not line.strip():
|
||||
continue
|
||||
# Parse the internal SSE event
|
||||
event_str = line.replace("data: ", "").strip()
|
||||
# ``complete_stream`` prefixes each frame with ``id: <seq>\n``
|
||||
# before the ``data:`` line. Extract just the data line so JSON
|
||||
# decode doesn't choke on the SSE framing.
|
||||
event_str = ""
|
||||
for raw in line.split("\n"):
|
||||
if raw.startswith("data:"):
|
||||
event_str = raw[len("data:") :].lstrip()
|
||||
break
|
||||
if not event_str:
|
||||
continue
|
||||
try:
|
||||
event_data = json.loads(event_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
|
||||
# Skip the informational ``message_id`` event — it has no v1 /
|
||||
# OpenAI-compatible analog.
|
||||
if event_data.get("type") == "message_id":
|
||||
continue
|
||||
|
||||
# Update completion_id when we get the conversation id
|
||||
if event_data.get("type") == "id":
|
||||
conv_id = event_data.get("id", "")
|
||||
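For reference, a minimal sketch of the framing the parsing added in this hunk now tolerates. The frame text is invented; the ``id:`` line followed by a ``data:`` line is the layout the comment above attributes to ``complete_stream``:

import json

# One internal SSE frame: a sequence id line, then the JSON data line the
# v1 route actually wants.
frame = 'id: 42\ndata: {"type": "answer", "answer": "Hello"}\n'

event_str = ""
for raw in frame.split("\n"):
    if raw.startswith("data:"):
        event_str = raw[len("data:"):].lstrip()
        break

event_data = json.loads(event_str)
assert event_data == {"type": "answer", "answer": "Hello"}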
|
||||
@@ -16,6 +16,8 @@ setup_logging()
|
||||
|
||||
from application.api import api # noqa: E402
|
||||
from application.api.answer import answer # noqa: E402
|
||||
from application.api.answer.routes.messages import messages_bp # noqa: E402
|
||||
from application.api.events.routes import events # noqa: E402
|
||||
from application.api.internal.routes import internal # noqa: E402
|
||||
from application.api.user.routes import user # noqa: E402
|
||||
from application.api.connector.routes import connector # noqa: E402
|
||||
@@ -49,6 +51,8 @@ ensure_database_ready(
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(user)
|
||||
app.register_blueprint(answer)
|
||||
app.register_blueprint(events)
|
||||
app.register_blueprint(messages_bp)
|
||||
app.register_blueprint(internal)
|
||||
app.register_blueprint(connector)
|
||||
app.register_blueprint(v1_bp)
|
||||
|
||||
@@ -29,8 +29,17 @@ def get_redis_instance():
|
||||
with _instance_lock:
|
||||
if _redis_instance is None and not _redis_creation_failed:
|
||||
try:
|
||||
# ``health_check_interval`` makes redis-py ping the
|
||||
# connection every N seconds when otherwise idle.
|
||||
# Without it, a half-open TCP connection (NAT silently dropped
|
||||
# state, ELB idle-close) can hang the SSE generator
|
||||
# in ``pubsub.get_message`` past its keepalive
|
||||
# cadence — the kernel never surfaces the dead
|
||||
# socket because no payload is in flight.
|
||||
_redis_instance = redis.Redis.from_url(
|
||||
settings.CACHE_REDIS_URL, socket_connect_timeout=2
|
||||
settings.CACHE_REDIS_URL,
|
||||
socket_connect_timeout=2,
|
||||
health_check_interval=10,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid Redis URL: {e}")
|
||||
|
||||
@@ -188,6 +188,42 @@ class Settings(BaseSettings):
|
||||
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
|
||||
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
|
||||
|
||||
# Internal SSE push channel (notifications + durable replay journal)
|
||||
# Master switch — when False, /api/events emits a "push_disabled" comment
|
||||
# and returns; clients fall back to polling. Publisher becomes a no-op.
|
||||
ENABLE_SSE_PUSH: bool = True
|
||||
# Per-user durable backlog cap (~entries). At typical event rates this
|
||||
# gives ~24h of replay; tune up for verbose feeds, down for memory.
|
||||
EVENTS_STREAM_MAXLEN: int = 1000
|
||||
# SSE keepalive comment cadence. Must sit under Cloudflare's 100s idle
|
||||
# close and iOS Safari's ~60s — 15s gives generous headroom.
|
||||
SSE_KEEPALIVE_SECONDS: int = 15
|
||||
# Cap on simultaneous SSE connections per user. Each connection holds
|
||||
# one WSGI thread (32 per gunicorn worker) and one Redis pub/sub
|
||||
# connection. 8 covers normal multi-tab use without letting one user
|
||||
# starve the pool. Set to 0 to disable the cap.
|
||||
SSE_MAX_CONCURRENT_PER_USER: int = 8
|
||||
# Per-request cap on the number of backlog entries XRANGE returns
|
||||
# for ``/api/events`` snapshots. Bounds the bytes a single replay
|
||||
# can move from Redis to the wire — a malicious client looping
|
||||
# ``Last-Event-ID=<oldest>`` reconnects can only enumerate this
|
||||
# many entries per round-trip. Combined with the per-user
|
||||
# connection cap above and the windowed budget below, total
|
||||
# enumeration throughput is bounded.
|
||||
EVENTS_REPLAY_MAX_PER_REQUEST: int = 200
|
||||
# Sliding-window cap on snapshot replays per user. Once the budget
|
||||
# is exhausted the route returns HTTP 429 with the cursor pinned;
|
||||
# the client backs off and retries after the window rolls over.
|
||||
EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW: int = 30
|
||||
EVENTS_REPLAY_BUDGET_WINDOW_SECONDS: int = 60
|
||||
|
||||
# Retention for the ``message_events`` journal. The ``cleanup_message_events``
|
||||
# beat task deletes rows older than this. Reconnect-replay only
|
||||
# needs the journal for streams a client could still be tailing,
|
||||
# so 14 days is a generous default that covers paused/tool-action
|
||||
# flows without unbounded table growth.
|
||||
MESSAGE_EVENTS_RETENTION_DAYS: int = 14
|
||||
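A hedged deployment sketch for the new knobs above. The values are illustrative only, and it assumes the usual pydantic ``BaseSettings`` environment loading (set before the settings module is first imported):

import os

# Example overrides for a chatty multi-tab deployment; the names match the
# new Settings fields, but whether these values suit you is workload-dependent.
os.environ["EVENTS_STREAM_MAXLEN"] = "5000"        # deeper replay backlog
os.environ["SSE_KEEPALIVE_SECONDS"] = "10"         # tighter keepalive cadence
os.environ["SSE_MAX_CONCURRENT_PER_USER"] = "0"    # disable the per-user cap
os.environ["MESSAGE_EVENTS_RETENTION_DAYS"] = "7"  # shorter journal retention

from application.core.settings import settings  # noqa: E402

assert settings.SSE_MAX_CONCURRENT_PER_USER == 0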
|
||||
@field_validator("POSTGRES_URI", mode="before")
|
||||
@classmethod
|
||||
def _normalize_postgres_uri_validator(cls, v):
|
||||
|
||||
0
application/events/__init__.py
Normal file
52
application/events/keys.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Stream/topic key derivations shared by publisher and SSE consumer.
|
||||
|
||||
Single source of truth for the per-user Redis Streams key and pub/sub
|
||||
topic name. Both must agree exactly — a typo here splits the
|
||||
publisher's writes from the consumer's reads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def stream_key(user_id: str) -> str:
|
||||
"""Redis Streams key holding the durable backlog for ``user_id``."""
|
||||
return f"user:{user_id}:stream"
|
||||
|
||||
|
||||
def topic_name(user_id: str) -> str:
|
||||
"""Redis pub/sub channel used for live fan-out to ``user_id``."""
|
||||
return f"user:{user_id}"
|
||||
|
||||
|
||||
def connection_counter_key(user_id: str) -> str:
|
||||
"""Redis counter tracking active SSE connections for ``user_id``."""
|
||||
return f"user:{user_id}:sse_count"
|
||||
|
||||
|
||||
def replay_budget_key(user_id: str) -> str:
|
||||
"""Redis counter tracking snapshot replays for ``user_id`` in the
|
||||
rolling rate-limit window."""
|
||||
return f"user:{user_id}:replay_count"
|
||||
|
||||
|
||||
def stream_id_compare(a: str, b: str) -> int:
|
||||
"""Compare two Redis Streams ids. Returns -1, 0, 1 like ``cmp``.
|
||||
|
||||
Stream ids are ``ms-seq`` strings; comparing as strings would be wrong
|
||||
once ``ms`` straddles digit-count boundaries. We parse and compare
|
||||
as ``(int, int)`` tuples.
|
||||
|
||||
Raises ``ValueError`` on malformed input. Callers must pre-validate
|
||||
against ``_STREAM_ID_RE`` (or equivalent) — a lex fallback here would let
|
||||
a malformed id compare lex-greater than a real one and silently pin
|
||||
dedup forever.
|
||||
"""
|
||||
a_ms, _, a_seq = a.partition("-")
|
||||
b_ms, _, b_seq = b.partition("-")
|
||||
a_tuple = (int(a_ms), int(a_seq) if a_seq else 0)
|
||||
b_tuple = (int(b_ms), int(b_seq) if b_seq else 0)
|
||||
if a_tuple < b_tuple:
|
||||
return -1
|
||||
if a_tuple > b_tuple:
|
||||
return 1
|
||||
return 0
|
||||
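A small worked example of why the numeric comparison above matters (standalone, assuming the module is importable as ``application.events.keys``):

from application.events.keys import stream_id_compare

# Lexicographic comparison gets this pair wrong: "999-1" > "1000-0" as
# strings, but 999 ms is older than 1000 ms as a stream id.
assert stream_id_compare("999-1", "1000-0") == -1
assert stream_id_compare("1000-0", "1000-0") == 0
assert stream_id_compare("1735682400000-3", "1735682400000-2") == 1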
144
application/events/publisher.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""User-scoped event publisher: durable backlog + live fan-out.
|
||||
|
||||
Each ``publish_user_event`` call writes twice:
|
||||
|
||||
1. ``XADD user:{user_id}:stream MAXLEN ~ <cap> * event <json>`` — the
|
||||
durable backlog used by SSE reconnect (``Last-Event-ID``) and stream
|
||||
replay. Bounded by ``EVENTS_STREAM_MAXLEN`` (~24h at typical event
|
||||
rates) so the per-user footprint stays predictable.
|
||||
2. ``PUBLISH user:{user_id} <json-with-id>`` — live fan-out to every
|
||||
currently connected SSE generator for the user, across instances.
|
||||
|
||||
Together they give a snapshot-plus-tail story: a reconnecting client
|
||||
reads ``XRANGE`` from its last seen id and then transitions onto the
|
||||
live pub/sub. The Redis Streams entry id (e.g. ``1735682400000-0``) is
|
||||
the canonical, monotonically increasing event id and is what
|
||||
``Last-Event-ID`` carries.
|
||||
|
||||
Failures are logged and swallowed: the caller is typically a Celery
|
||||
task whose primary work has already succeeded, and a notification
|
||||
delivery miss should not surface as a task failure.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.events.keys import stream_key, topic_name
|
||||
from application.streaming.broadcast_channel import Topic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _iso_now() -> str:
|
||||
"""ISO 8601 UTC with millisecond precision and Z suffix."""
|
||||
return (
|
||||
datetime.now(timezone.utc)
|
||||
.isoformat(timespec="milliseconds")
|
||||
.replace("+00:00", "Z")
|
||||
)
|
||||
|
||||
|
||||
def publish_user_event(
|
||||
user_id: str,
|
||||
event_type: str,
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
scope: Optional[dict[str, Any]] = None,
|
||||
) -> Optional[str]:
|
||||
"""Publish a user-scoped event; return the Redis Streams id or ``None``.
|
||||
|
||||
Fire-and-forget: never raises. ``None`` means the event reached
|
||||
neither the journal nor live subscribers (see runbook for causes).
|
||||
"""
|
||||
if not user_id or not event_type:
|
||||
logger.warning(
|
||||
"publish_user_event called without user_id or event_type "
|
||||
"(user_id=%r, event_type=%r)",
|
||||
user_id,
|
||||
event_type,
|
||||
)
|
||||
return None
|
||||
if not settings.ENABLE_SSE_PUSH:
|
||||
return None
|
||||
|
||||
envelope_partial: dict[str, Any] = {
|
||||
"type": event_type,
|
||||
"ts": _iso_now(),
|
||||
"user_id": user_id,
|
||||
"topic": topic_name(user_id),
|
||||
"scope": scope or {},
|
||||
"payload": payload,
|
||||
}
|
||||
|
||||
try:
|
||||
envelope_partial_json = json.dumps(envelope_partial)
|
||||
except (TypeError, ValueError) as exc:
|
||||
logger.warning(
|
||||
"publish_user_event payload not JSON-serializable: "
|
||||
"user=%s type=%s err=%s",
|
||||
user_id,
|
||||
event_type,
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
redis = get_redis_instance()
|
||||
if redis is None:
|
||||
logger.debug("Redis unavailable; skipping publish_user_event")
|
||||
return None
|
||||
|
||||
maxlen = settings.EVENTS_STREAM_MAXLEN
|
||||
stream_id: Optional[str] = None
|
||||
try:
|
||||
# Auto-id ('*') gives a monotonic ms-seq id that doubles as the
|
||||
# SSE event id. ``approximate=True`` lets Redis trim in chunks
|
||||
# for performance; the approximate trim may leave a few entries over MAXLEN but never cuts below it.
|
||||
result = redis.xadd(
|
||||
stream_key(user_id),
|
||||
{"event": envelope_partial_json},
|
||||
maxlen=maxlen,
|
||||
approximate=True,
|
||||
)
|
||||
stream_id = (
|
||||
result.decode("utf-8")
|
||||
if isinstance(result, (bytes, bytearray))
|
||||
else str(result)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"xadd failed for user=%s event_type=%s", user_id, event_type
|
||||
)
|
||||
|
||||
# If the durable journal write failed there is no canonical id to
|
||||
# ship — publishing the envelope live would put an id-less record
|
||||
# on the wire that bypasses the SSE route's dedup floor and breaks
|
||||
# ``Last-Event-ID`` semantics for any reconnect. Best-effort
|
||||
# delivery means dropping consistently, not delivering inconsistent
|
||||
# state.
|
||||
if stream_id is None:
|
||||
return None
|
||||
|
||||
envelope = dict(envelope_partial)
|
||||
envelope["id"] = stream_id
|
||||
|
||||
try:
|
||||
Topic(topic_name(user_id)).publish(json.dumps(envelope))
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"publish failed for user=%s event_type=%s", user_id, event_type
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"event.published topic=%s type=%s id=%s",
|
||||
topic_name(user_id),
|
||||
event_type,
|
||||
stream_id,
|
||||
)
|
||||
|
||||
return stream_id
|
||||
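A minimal caller sketch, e.g. from the tail of a Celery task. The user id and payload are invented; the signature, the ``scope`` keyword and the fire-and-forget contract are as defined above:

from application.events.publisher import publish_user_event

stream_id = publish_user_event(
    "local",                                # invented user id
    "source.ingest.progress",               # event type used elsewhere in this change
    {"current": 42, "total": 100, "stage": "embedding"},
    scope={"kind": "source", "id": "6fa459ea-ee8a-3ca4-894e-db77e160355e"},
)

# None means neither the journal nor live subscribers got the event; the
# caller's own work is unaffected either way.
if stream_id is not None:
    print("journaled as", stream_id)        # e.g. "1735682400000-0"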
@@ -4,6 +4,7 @@ from typing import Any, List, Optional
|
||||
from retry import retry
|
||||
from tqdm import tqdm
|
||||
from application.core.settings import settings
|
||||
from application.events.publisher import publish_user_event
|
||||
from application.storage.db.repositories.ingest_chunk_progress import (
|
||||
IngestChunkProgressRepository,
|
||||
)
|
||||
@@ -152,6 +153,7 @@ def embed_and_store_documents(
|
||||
task_status: Any,
|
||||
*,
|
||||
attempt_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Embeds documents and stores them in a vector store.
|
||||
|
||||
@@ -170,6 +172,10 @@ def embed_and_store_documents(
|
||||
attempt_id: Stable id of the current task invocation,
|
||||
typically ``self.request.id`` from the Celery task body.
|
||||
``None`` is treated as a fresh attempt every time.
|
||||
user_id: When provided, per-percent SSE progress events are
|
||||
published to ``user:{user_id}`` for the in-app upload toast.
|
||||
``None`` is the safe default — workers without a user
|
||||
context (e.g. background syncs) skip the publish.
|
||||
|
||||
Returns:
|
||||
None
|
||||
@@ -249,6 +255,8 @@ def embed_and_store_documents(
|
||||
# Process and embed documents
|
||||
chunk_error: Exception | None = None
|
||||
failed_idx: int | None = None
|
||||
last_published_pct = -1
|
||||
source_id_str = str(source_id)
|
||||
for idx in tqdm(
|
||||
range(loop_start, total_docs),
|
||||
desc="Embedding 🦖",
|
||||
@@ -262,6 +270,24 @@ def embed_and_store_documents(
|
||||
progress = int(((idx + 1) / total_docs) * 100)
|
||||
task_status.update_state(state="PROGRESS", meta={"current": progress})
|
||||
|
||||
# SSE push for sub-second upload-toast updates. Throttled to one
|
||||
# event per percent so a 10k-chunk ingest emits ~100 events,
|
||||
# not 10k. The Celery update_state above stays the source of
|
||||
# truth for the polling-fallback path.
|
||||
if user_id and progress > last_published_pct:
|
||||
publish_user_event(
|
||||
user_id,
|
||||
"source.ingest.progress",
|
||||
{
|
||||
"current": progress,
|
||||
"total": total_docs,
|
||||
"embedded_chunks": idx + 1,
|
||||
"stage": "embedding",
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_str},
|
||||
)
|
||||
last_published_pct = progress
|
||||
|
||||
# Add document to vector store
|
||||
add_text_to_store_with_retry(store, doc, source_id)
|
||||
_record_progress(source_id, last_index=idx, embedded_chunks=idx + 1)
|
||||
|
||||
@@ -34,7 +34,7 @@ from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, JSONB, UUID
|
||||
metadata = MetaData()
|
||||
|
||||
|
||||
# --- Phase 1, Tier 1 --------------------------------------------------------
|
||||
# --- Users, prompts, tools, logs --------------------------------------------
|
||||
|
||||
users_table = Table(
|
||||
"users",
|
||||
@@ -138,7 +138,7 @@ app_metadata_table = Table(
|
||||
)
|
||||
|
||||
|
||||
# --- Phase 2, Tier 2 --------------------------------------------------------
|
||||
# --- Agents, sources, attachments, artifacts --------------------------------
|
||||
|
||||
agent_folders_table = Table(
|
||||
"agent_folders",
|
||||
@@ -307,7 +307,7 @@ connector_sessions_table = Table(
|
||||
)
|
||||
|
||||
|
||||
# --- Phase 3, Tier 3 --------------------------------------------------------
|
||||
# --- Conversations, messages, workflows -------------------------------------
|
||||
|
||||
conversations_table = Table(
|
||||
"conversations",
|
||||
@@ -363,6 +363,36 @@ conversation_messages_table = Table(
|
||||
UniqueConstraint("conversation_id", "position", name="conversation_messages_conv_pos_uidx"),
|
||||
)
|
||||
|
||||
# Per-yield journal of chat-stream events, used by the snapshot+tail
|
||||
# reconnect: the route's GET reconnect endpoint reads
|
||||
# ``WHERE message_id = ? AND sequence_no > ?`` from this table before
|
||||
# tailing the live ``channel:{message_id}`` pub/sub. See
|
||||
# ``application/streaming/event_replay.py`` and migration 0007.
|
||||
message_events_table = Table(
|
||||
"message_events",
|
||||
metadata,
|
||||
# PK is the composite ``(message_id, sequence_no)`` — it doubles as
|
||||
# the snapshot read index (covering range scan on
|
||||
# ``WHERE message_id = ? AND sequence_no > ?``).
|
||||
Column(
|
||||
"message_id",
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("conversation_messages.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
nullable=False,
|
||||
),
|
||||
# Strictly monotonic per ``message_id``. Allocated by the route as it
|
||||
# yields, so the writer is single-threaded for the lifetime of one
|
||||
# stream — no contention, no SERIAL needed.
|
||||
Column("sequence_no", Integer, primary_key=True, nullable=False),
|
||||
Column("event_type", Text, nullable=False),
|
||||
Column("payload", JSONB, nullable=False, server_default="{}"),
|
||||
Column(
|
||||
"created_at", DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
shared_conversations_table = Table(
|
||||
"shared_conversations",
|
||||
metadata,
|
||||
@@ -403,7 +433,7 @@ pending_tool_state_table = Table(
|
||||
)
|
||||
|
||||
|
||||
# --- Tier 1 durability foundation (migration 0004) --------------------------
|
||||
# --- Durability foundation (idempotency / journals, migration 0004) ---------
|
||||
# CHECK constraints (status enums) and partial indexes are intentionally
|
||||
# omitted from these declarations — the DB is the authority. Repositories
|
||||
# use raw ``text(...)`` SQL against these tables, not the Core objects.
|
||||
|
||||
@@ -15,6 +15,7 @@ Covers every operation the legacy Mongo code performs on
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
@@ -25,6 +26,22 @@ from application.storage.db.models import conversations_table, conversation_mess
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
|
||||
class MessageUpdateOutcome(str, Enum):
|
||||
"""Discriminated result of ``update_message_by_id``.
|
||||
|
||||
Distinguishes the row-actually-updated case from the row-already-at-
|
||||
the-requested-terminal-state case so an abort handler can journal
|
||||
``end`` instead of ``error`` when the normal-path finalize already
|
||||
flipped the row to ``complete``.
|
||||
"""
|
||||
|
||||
UPDATED = "updated"
|
||||
ALREADY_COMPLETE = "already_complete"
|
||||
ALREADY_FAILED = "already_failed"
|
||||
NOT_FOUND = "not_found"
|
||||
INVALID = "invalid"
|
||||
|
||||
|
||||
def _message_row_to_dict(row) -> dict:
|
||||
"""Like ``row_to_dict`` but renames the DB column ``message_metadata``
|
||||
back to the public API key ``metadata`` so callers keep the Mongo-era
|
||||
@@ -58,8 +75,8 @@ class ConversationsRepository:
|
||||
- Already-UUID-shaped → returned as-is.
|
||||
- Otherwise treated as a Mongo ObjectId and looked up via
|
||||
``agents.legacy_mongo_id``. Returns ``None`` if no PG row
|
||||
exists yet (e.g. the agent was created before Phase 1
|
||||
backfill).
|
||||
exists yet (e.g. the agent was created before the backfill
|
||||
ran).
|
||||
"""
|
||||
if not agent_id_raw:
|
||||
return None
|
||||
@@ -245,57 +262,6 @@ class ConversationsRepository:
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def search_for_user(
|
||||
self, user_id: str, query: str, limit: int = 30,
|
||||
) -> list[dict]:
|
||||
"""Search a user's conversations by name or message content.
|
||||
|
||||
Same visibility filter as :meth:`list_for_user`. Matches against
|
||||
``conversations.name`` or any of the conversation's messages'
|
||||
``prompt`` / ``response`` columns (case-insensitive substring).
|
||||
|
||||
Each returned row includes ``match_field`` (one of ``name``,
|
||||
``prompt``, ``response``) and ``match_text`` (the full text of the
|
||||
first matching field, ``name`` taking precedence over messages,
|
||||
``prompt`` over ``response``) so callers can render a snippet.
|
||||
"""
|
||||
if not query:
|
||||
return []
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT c.*, mt.match_field, mt.match_text "
|
||||
"FROM conversations c "
|
||||
"JOIN LATERAL ( "
|
||||
" SELECT field AS match_field, txt AS match_text "
|
||||
" FROM ( "
|
||||
" SELECT 'name'::text AS field, c.name AS txt, 0 AS prio "
|
||||
" WHERE c.name ILIKE :pattern "
|
||||
" UNION ALL "
|
||||
" SELECT 'prompt'::text, m.prompt, 1 "
|
||||
" FROM conversation_messages m "
|
||||
" WHERE m.conversation_id = c.id "
|
||||
" AND m.prompt ILIKE :pattern "
|
||||
" UNION ALL "
|
||||
" SELECT 'response'::text, m.response, 2 "
|
||||
" FROM conversation_messages m "
|
||||
" WHERE m.conversation_id = c.id "
|
||||
" AND m.response ILIKE :pattern "
|
||||
" ) s "
|
||||
" ORDER BY prio "
|
||||
" LIMIT 1 "
|
||||
") mt ON TRUE "
|
||||
"WHERE c.user_id = :user_id "
|
||||
"AND (c.api_key IS NULL OR c.agent_id IS NOT NULL) "
|
||||
"ORDER BY c.date DESC LIMIT :limit"
|
||||
),
|
||||
{
|
||||
"user_id": user_id,
|
||||
"pattern": f"%{query}%",
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def rename(self, conversation_id: str, user_id: str, name: str) -> bool:
|
||||
# Shape-gate so a non-UUID id (legacy Mongo ObjectId still floating
|
||||
# around in client-side state during the cutover) never reaches the
|
||||
@@ -748,7 +714,7 @@ class ConversationsRepository:
|
||||
def update_message_by_id(
|
||||
self, message_id: str, fields: dict,
|
||||
*, only_if_non_terminal: bool = False,
|
||||
) -> bool:
|
||||
) -> MessageUpdateOutcome:
|
||||
"""Update specific fields on a message identified by its UUID.
|
||||
|
||||
``metadata`` is merged into the existing JSONB rather than
|
||||
@@ -756,9 +722,13 @@ class ConversationsRepository:
|
||||
a successful late finalize. When ``only_if_non_terminal`` is
|
||||
True, the update is gated so a late finalize cannot retract a
|
||||
reconciler-set ``failed`` (or a prior ``complete``).
|
||||
|
||||
The return value discriminates "I updated the row" from "the
|
||||
row was already at a terminal state" so the abort handler can
|
||||
journal ``end`` when the normal-path finalize already ran.
|
||||
"""
|
||||
if not looks_like_uuid(message_id):
|
||||
return False
|
||||
return MessageUpdateOutcome.INVALID
|
||||
allowed = {
|
||||
"prompt", "response", "thought", "sources", "tool_calls",
|
||||
"attachments", "model_id", "metadata", "timestamp", "status",
|
||||
@@ -766,7 +736,7 @@ class ConversationsRepository:
|
||||
}
|
||||
filtered = {k: v for k, v in fields.items() if k in allowed}
|
||||
if not filtered:
|
||||
return False
|
||||
return MessageUpdateOutcome.INVALID
|
||||
|
||||
api_to_col = {"metadata": "message_metadata"}
|
||||
|
||||
@@ -803,15 +773,44 @@ class ConversationsRepository:
|
||||
params[col] = val
|
||||
|
||||
set_parts.append("updated_at = now()")
|
||||
where_clauses = ["id = CAST(:id AS uuid)"]
|
||||
update_where = ["id = CAST(:id AS uuid)"]
|
||||
if only_if_non_terminal:
|
||||
where_clauses.append("status NOT IN ('complete', 'failed')")
|
||||
update_where.append("status NOT IN ('complete', 'failed')")
|
||||
# Single-statement attempt + prior-status probe. Both CTEs see
|
||||
# the same MVCC snapshot, so ``prior.status`` reflects the row
|
||||
# state before the UPDATE — exactly what we need to distinguish
|
||||
# ``ALREADY_COMPLETE`` from ``ALREADY_FAILED`` from
|
||||
# ``NOT_FOUND`` without a follow-up SELECT.
|
||||
sql = (
|
||||
f"UPDATE conversation_messages SET {', '.join(set_parts)} "
|
||||
f"WHERE {' AND '.join(where_clauses)}"
|
||||
"WITH attempted AS ("
|
||||
f" UPDATE conversation_messages SET {', '.join(set_parts)} "
|
||||
f" WHERE {' AND '.join(update_where)} "
|
||||
" RETURNING 1 AS updated"
|
||||
"), "
|
||||
"prior AS ("
|
||||
" SELECT status FROM conversation_messages "
|
||||
" WHERE id = CAST(:id AS uuid)"
|
||||
") "
|
||||
"SELECT (SELECT updated FROM attempted) AS updated, "
|
||||
" (SELECT status FROM prior) AS prior_status"
|
||||
)
|
||||
result = self._conn.execute(text(sql), params)
|
||||
return result.rowcount > 0
|
||||
row = self._conn.execute(text(sql), params).fetchone()
|
||||
if row is None:
|
||||
return MessageUpdateOutcome.NOT_FOUND
|
||||
updated, prior_status = row[0], row[1]
|
||||
if updated:
|
||||
return MessageUpdateOutcome.UPDATED
|
||||
if prior_status is None:
|
||||
return MessageUpdateOutcome.NOT_FOUND
|
||||
if prior_status == "complete":
|
||||
return MessageUpdateOutcome.ALREADY_COMPLETE
|
||||
if prior_status == "failed":
|
||||
return MessageUpdateOutcome.ALREADY_FAILED
|
||||
# ``only_if_non_terminal=False`` always updates an existing row,
|
||||
# so reaching here means the gate excluded it for some status
|
||||
# the terminal set doesn't cover — treat as "not found" rather
|
||||
# than inventing a new variant.
|
||||
return MessageUpdateOutcome.NOT_FOUND
|
||||
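A hedged sketch of the abort-handler branching the docstring describes. The ``finalize_abort`` helper and its ``journal`` callable are hypothetical; the enum values and the ``only_if_non_terminal`` gate are as defined above:

from application.storage.db.repositories.conversations import (
    ConversationsRepository,
    MessageUpdateOutcome,
)


def finalize_abort(conn, message_id, journal):
    outcome = ConversationsRepository(conn).update_message_by_id(
        message_id,
        {"status": "failed"},
        only_if_non_terminal=True,
    )
    if outcome is MessageUpdateOutcome.UPDATED:
        journal("error")   # this abort really did terminate the message
    elif outcome is MessageUpdateOutcome.ALREADY_COMPLETE:
        journal("end")     # the normal-path finalize won the race
    # ALREADY_FAILED / NOT_FOUND / INVALID: nothing further to journal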
|
||||
def update_message_status(
|
||||
self, message_id: str, status: str,
|
||||
|
||||
248
application/storage/db/repositories/message_events.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Repository for ``message_events`` — the chat-stream snapshot journal.
|
||||
|
||||
``record`` / ``bulk_record`` write per-yield events; ``read_after``
|
||||
replays rows past a cursor for reconnect snapshots. Composite PK
|
||||
``(message_id, sequence_no)`` raises ``IntegrityError`` on duplicates.
|
||||
Callers must use short-lived per-call transactions — long-lived
|
||||
transactions hide writes from reconnecting clients on a separate
|
||||
connection and turn one bad row into ``InFailedSqlTransaction``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageEventsRepository:
|
||||
"""Read/write helpers for ``message_events``."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def record(
|
||||
self,
|
||||
message_id: str,
|
||||
sequence_no: int,
|
||||
event_type: str,
|
||||
payload: Optional[Any] = None,
|
||||
) -> None:
|
||||
"""Append a single event to the journal.
|
||||
|
||||
At this raw repo layer ``payload`` is preserved as-is when not
|
||||
``None`` (lists, scalars, and dicts all round-trip via JSONB);
|
||||
``None`` substitutes an empty object so the column's NOT NULL
|
||||
invariant holds. The streaming-route wrapper
|
||||
``application/streaming/message_journal.py::record_event``
|
||||
tightens this contract to dicts only — the live and replay
|
||||
paths reconstruct non-dict payloads differently, so the wrapper
|
||||
rejects them at the gate. Direct callers of this repo method
|
||||
(cleanup tasks, tests, future ad-hoc consumers) keep the wider
|
||||
JSONB-compatible surface.
|
||||
|
||||
Raises ``sqlalchemy.exc.IntegrityError`` on duplicate
|
||||
``(message_id, sequence_no)`` and ``DataError`` on a malformed
|
||||
``message_id`` UUID. Both abort the surrounding transaction —
|
||||
callers must run inside a short-lived per-event session
|
||||
(see module docstring).
|
||||
"""
|
||||
if not event_type:
|
||||
raise ValueError("event_type must be a non-empty string")
|
||||
materialised_payload = payload if payload is not None else {}
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO message_events (
|
||||
message_id, sequence_no, event_type, payload
|
||||
) VALUES (
|
||||
CAST(:message_id AS uuid), :sequence_no, :event_type,
|
||||
CAST(:payload AS jsonb)
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"message_id": str(message_id),
|
||||
"sequence_no": int(sequence_no),
|
||||
"event_type": event_type,
|
||||
"payload": json.dumps(materialised_payload),
|
||||
},
|
||||
)
|
||||
|
||||
def bulk_record(
|
||||
self,
|
||||
message_id: str,
|
||||
events: list[tuple[int, str, dict]],
|
||||
) -> None:
|
||||
"""Append multiple events for ``message_id`` in one INSERT.
|
||||
|
||||
``events`` is a list of ``(sequence_no, event_type, payload)``
|
||||
tuples. SQLAlchemy ``executemany`` issues one bulk INSERT;
|
||||
Postgres treats the whole batch as one statement, so an
|
||||
IntegrityError on any row aborts the entire batch.
|
||||
|
||||
Caller contract: on IntegrityError, do NOT retry this method
|
||||
with the same batch — fall back to per-row ``record()`` calls
|
||||
(each in its own short-lived session) so a single colliding
|
||||
seq doesn't drop the rest of the batch. ``BatchedJournalWriter``
|
||||
in ``application/streaming/message_journal.py`` is the canonical
|
||||
consumer.
|
||||
"""
|
||||
if not events:
|
||||
return
|
||||
params = [
|
||||
{
|
||||
"message_id": str(message_id),
|
||||
"sequence_no": int(seq),
|
||||
"event_type": event_type,
|
||||
"payload": json.dumps(payload if payload is not None else {}),
|
||||
}
|
||||
for seq, event_type, payload in events
|
||||
]
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO message_events (
|
||||
message_id, sequence_no, event_type, payload
|
||||
) VALUES (
|
||||
CAST(:message_id AS uuid), :sequence_no, :event_type,
|
||||
CAST(:payload AS jsonb)
|
||||
)
|
||||
"""
|
||||
),
|
||||
params,
|
||||
)
|
||||
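A minimal sketch of the caller contract above, assuming sqlalchemy's ``IntegrityError`` and the short-lived ``db_session`` context manager used elsewhere in this change; the ``flush_batch`` helper itself is hypothetical:

from sqlalchemy.exc import IntegrityError

from application.storage.db.repositories.message_events import (
    MessageEventsRepository,
)
from application.storage.db.session import db_session


def flush_batch(message_id, events):
    """events: list of (sequence_no, event_type, payload) tuples."""
    try:
        with db_session() as conn:
            MessageEventsRepository(conn).bulk_record(message_id, events)
    except IntegrityError:
        # One colliding sequence_no aborted the whole batch; replay the rows
        # one at a time so the rest of the batch still lands.
        for seq, event_type, payload in events:
            try:
                with db_session() as conn:
                    MessageEventsRepository(conn).record(
                        message_id, seq, event_type, payload
                    )
            except IntegrityError:
                continue  # that sequence_no already exists; skip it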
|
||||
def read_after(
|
||||
self,
|
||||
message_id: str,
|
||||
last_sequence_no: Optional[int] = None,
|
||||
) -> list[dict]:
|
||||
"""Return events with ``sequence_no > last_sequence_no``.
|
||||
|
||||
``last_sequence_no=None`` returns the full backlog. Rows are
|
||||
returned in ascending ``sequence_no`` order. The composite PK
|
||||
is the snapshot read index for this scan — Postgres typically
|
||||
picks an in-order index range scan, though for highly mixed
|
||||
data the planner may pick a bitmap+sort. Either way the result
|
||||
is sorted on ``sequence_no``.
|
||||
|
||||
Returns a ``list`` (not a generator) so the underlying
|
||||
``Result`` is fully drained before the caller can issue
|
||||
another query on the same connection.
|
||||
"""
|
||||
cursor = -1 if last_sequence_no is None else int(last_sequence_no)
|
||||
rows = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT message_id, sequence_no, event_type, payload, created_at
|
||||
FROM message_events
|
||||
WHERE message_id = CAST(:message_id AS uuid)
|
||||
AND sequence_no > :cursor
|
||||
ORDER BY sequence_no ASC
|
||||
"""
|
||||
),
|
||||
{"message_id": str(message_id), "cursor": cursor},
|
||||
).fetchall()
|
||||
return [row_to_dict(row) for row in rows]
|
||||
|
||||
def cleanup_older_than(self, ttl_days: int) -> int:
|
||||
"""Delete journal rows older than ``ttl_days``. Returns row count.
|
||||
|
||||
Reconnect-replay is meaningful only for streams the client
|
||||
could plausibly still be waiting on, so old rows are dead
|
||||
weight. The ``message_events_created_at_idx`` btree makes the
|
||||
range delete a cheap index scan even on large tables.
|
||||
"""
|
||||
if ttl_days <= 0:
|
||||
raise ValueError("ttl_days must be positive")
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
DELETE FROM message_events
|
||||
WHERE created_at < now() - make_interval(days => :ttl_days)
|
||||
"""
|
||||
),
|
||||
{"ttl_days": int(ttl_days)},
|
||||
)
|
||||
return int(result.rowcount or 0)
|
||||
|
||||
def reconstruct_partial(self, message_id: str) -> dict:
|
||||
"""Rebuild partial response/thought/sources/tool_calls from journal events.
|
||||
|
||||
``answer``/``thought`` chunks concat in seq order; ``source``/
|
||||
``tool_calls`` carry the full list at emit time (last-wins).
|
||||
"""
|
||||
rows = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT sequence_no, event_type, payload
|
||||
FROM message_events
|
||||
WHERE message_id = CAST(:message_id AS uuid)
|
||||
ORDER BY sequence_no ASC
|
||||
"""
|
||||
),
|
||||
{"message_id": str(message_id)},
|
||||
).fetchall()
|
||||
|
||||
response_parts: list[str] = []
|
||||
thought_parts: list[str] = []
|
||||
sources: list = []
|
||||
tool_calls: list = []
|
||||
|
||||
for row in rows:
|
||||
payload = row.payload
|
||||
if not isinstance(payload, dict):
|
||||
continue
|
||||
etype = row.event_type
|
||||
if etype == "answer":
|
||||
chunk = payload.get("answer")
|
||||
if isinstance(chunk, str):
|
||||
response_parts.append(chunk)
|
||||
elif etype == "thought":
|
||||
chunk = payload.get("thought")
|
||||
if isinstance(chunk, str):
|
||||
thought_parts.append(chunk)
|
||||
elif etype == "source":
|
||||
src = payload.get("source")
|
||||
if isinstance(src, list):
|
||||
sources = src
|
||||
elif etype == "tool_calls":
|
||||
tcs = payload.get("tool_calls")
|
||||
if isinstance(tcs, list):
|
||||
tool_calls = tcs
|
||||
|
||||
return {
|
||||
"response": "".join(response_parts),
|
||||
"thought": "".join(thought_parts),
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
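A worked example of the merge rule above. The journal rows are invented; the expected result follows ``reconstruct_partial`` exactly:

# Invented journal rows for one message, in sequence_no order:
events = [
    (1, "answer", {"answer": "Hel"}),
    (2, "answer", {"answer": "lo"}),
    (3, "thought", {"thought": "checking sources"}),
    (4, "source", {"source": [{"title": "doc A"}]}),
    (5, "source", {"source": [{"title": "doc A"}, {"title": "doc B"}]}),
]

# reconstruct_partial() over those rows returns:
expected = {
    "response": "Hello",          # answer chunks concatenated in seq order
    "thought": "checking sources",
    "sources": [{"title": "doc A"}, {"title": "doc B"}],  # last "source" wins
    "tool_calls": [],
}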
|
||||
def latest_sequence_no(self, message_id: str) -> Optional[int]:
|
||||
"""Largest ``sequence_no`` recorded for ``message_id``, or ``None``.
|
||||
|
||||
Used by the route to seed the per-stream allocator on retry /
|
||||
process restart so a re-run continues numbering instead of
|
||||
trampling earlier entries with duplicate sequence_no.
|
||||
"""
|
||||
# ``MAX`` always returns one row — NULL when the journal is
|
||||
# empty — so we test the value, not the row presence.
|
||||
row = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT MAX(sequence_no) AS s
|
||||
FROM message_events
|
||||
WHERE message_id = CAST(:message_id AS uuid)
|
||||
"""
|
||||
),
|
||||
{"message_id": str(message_id)},
|
||||
).first()
|
||||
value = row[0] if row is not None else None
|
||||
return int(value) if value is not None else None
|
||||
23
application/storage/db/source_ids.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Deterministic source-id derivation for idempotent ingest.
|
||||
|
||||
DO NOT CHANGE the pinned UUID namespace — it backs cross-deploy
|
||||
idempotency keys.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
# DO NOT CHANGE. See module docstring.
|
||||
DOCSGPT_INGEST_NAMESPACE = uuid.UUID("fa25d5d1-398b-46df-ac89-8d1c360b9bea")
|
||||
|
||||
|
||||
def derive_source_id(idempotency_key) -> uuid.UUID:
|
||||
"""``uuid5(NS, key)`` when a key is supplied; ``uuid4()`` otherwise.
|
||||
|
||||
A non-string / empty key falls back to ``uuid4()`` so the caller
|
||||
always gets a fresh id rather than a TypeError mid-route.
|
||||
"""
|
||||
if isinstance(idempotency_key, str) and idempotency_key:
|
||||
return uuid.uuid5(DOCSGPT_INGEST_NAMESPACE, idempotency_key)
|
||||
return uuid.uuid4()
|
||||
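The deterministic behaviour is easiest to see directly (standalone; the key string is invented):

from application.storage.db.source_ids import derive_source_id

key = "user-123:upload:handbook.pdf"   # invented idempotency key

# Same key, same UUID, on every call and every deploy, so a retried task
# lands on the same source row the HTTP response already announced.
assert derive_source_id(key) == derive_source_id(key)

# No usable key: a fresh random uuid4 each time.
assert derive_source_id(None) != derive_source_id(None)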
0
application/streaming/__init__.py
Normal file
126
application/streaming/broadcast_channel.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Redis pub/sub Topic abstraction for SSE fan-out.
|
||||
|
||||
A Topic is a named channel for one-shot live event delivery. Canonical uses:
|
||||
|
||||
- ``user:{user_id}`` for per-user notifications
|
||||
- ``channel:{message_id}`` for per-chat-message streams
|
||||
|
||||
Subscription is race-free via ``on_subscribe``: the callback fires only
|
||||
after Redis acknowledges ``SUBSCRIBE``, so a publisher dispatched inside
|
||||
the callback cannot lose its first event to a not-yet-registered
|
||||
subscriber.
|
||||
|
||||
The subscribe iterator yields ``None`` on poll timeout so the caller can
|
||||
emit SSE keepalive comments without spawning a separate timer thread.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, Iterator, Optional
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Topic:
|
||||
"""A pub/sub channel identified by a string name."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
def publish(self, payload: str | bytes) -> int:
|
||||
"""Fan out a payload to currently subscribed clients.
|
||||
|
||||
Returns the number of clients Redis reports as receiving the message (limited
|
||||
to subscribers connected to *this* Redis instance), or 0 if Redis
|
||||
is unavailable. Never raises.
|
||||
"""
|
||||
redis = get_redis_instance()
|
||||
if redis is None:
|
||||
logger.debug("Redis unavailable; dropping publish to %s", self.name)
|
||||
return 0
|
||||
try:
|
||||
return int(redis.publish(self.name, payload))
|
||||
except Exception:
|
||||
logger.exception("Topic.publish failed for %s", self.name)
|
||||
return 0
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
on_subscribe: Optional[Callable[[], None]] = None,
|
||||
poll_timeout: float = 1.0,
|
||||
) -> Iterator[Optional[bytes]]:
|
||||
"""Subscribe to the topic; yield raw payloads or ``None`` on tick.
|
||||
|
||||
Yields ``None`` every ``poll_timeout`` seconds while idle so the
|
||||
caller can emit keepalive frames or check cancellation. Yields
|
||||
``bytes`` for each delivered message.
|
||||
|
||||
``on_subscribe`` runs synchronously after Redis acknowledges the
|
||||
SUBSCRIBE — use it to seed any state (e.g. read backlog) that
|
||||
must be ordered after the subscriber is live but before the
|
||||
first pub/sub message is processed.
|
||||
|
||||
If Redis is unavailable, returns immediately without yielding.
|
||||
Cleanly unsubscribes on ``GeneratorExit`` (client disconnect).
|
||||
"""
|
||||
redis = get_redis_instance()
|
||||
if redis is None:
|
||||
logger.debug("Redis unavailable; subscribe to %s yielded nothing", self.name)
|
||||
return
|
||||
pubsub = None
|
||||
on_subscribe_fired = False
|
||||
try:
|
||||
pubsub = redis.pubsub()
|
||||
try:
|
||||
pubsub.subscribe(self.name)
|
||||
except Exception:
|
||||
# Subscribe failure (transient Redis hiccup, conn reset, etc.)
|
||||
# is treated like "Redis unavailable": yield nothing, let the
|
||||
# caller fall back to its own resilience strategy. The finally
|
||||
# block will still tear down the pubsub object cleanly.
|
||||
logger.exception("pubsub.subscribe failed for %s", self.name)
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
msg = pubsub.get_message(timeout=poll_timeout)
|
||||
except Exception:
|
||||
logger.exception("pubsub.get_message failed for %s", self.name)
|
||||
return
|
||||
if msg is None:
|
||||
yield None
|
||||
continue
|
||||
msg_type = msg.get("type")
|
||||
if msg_type == "subscribe":
|
||||
if not on_subscribe_fired and on_subscribe is not None:
|
||||
try:
|
||||
on_subscribe()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"on_subscribe callback failed for %s", self.name
|
||||
)
|
||||
on_subscribe_fired = True
|
||||
continue
|
||||
if msg_type != "message":
|
||||
continue
|
||||
data = msg.get("data")
|
||||
if data is None:
|
||||
continue
|
||||
yield data if isinstance(data, bytes) else str(data).encode("utf-8")
|
||||
finally:
|
||||
if pubsub is not None:
|
||||
if on_subscribe_fired:
|
||||
try:
|
||||
pubsub.unsubscribe(self.name)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"pubsub unsubscribe error for %s",
|
||||
self.name,
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
pubsub.close()
|
||||
except Exception:
|
||||
logger.debug("pubsub close error for %s", self.name, exc_info=True)
|
||||
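A minimal consumer sketch for the Topic API above, assuming only what the docstrings state; the topic name, payload, and print calls are illustrative (in DocsGPT the reconnect generator in event_replay.py plays this role):

from application.streaming.broadcast_channel import Topic

topic = Topic("user:alice")   # illustrative channel name
backlog: list[bytes] = []

def seed_backlog() -> None:
    # Runs only after SUBSCRIBE is acknowledged, so anything published from
    # this point on is already queued for us: no missed-first-event race.
    backlog.append(b'{"type": "hello", "replayed": true}')

for payload in topic.subscribe(on_subscribe=seed_backlog, poll_timeout=1.0):
    while backlog:
        print("replay:", backlog.pop(0))
    if payload is None:
        print(": keepalive")   # idle tick: emit an SSE comment, check cancellation
        continue
    print("live:", payload)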
application/streaming/event_replay.py (new file, 434 lines)
@@ -0,0 +1,434 @@
"""Snapshot+tail iterator for chat-stream reconnect-after-disconnect.

Subscribe to ``channel:{message_id}``, snapshot ``message_events``
rows past ``last_event_id`` inside the SUBSCRIBE-ack callback, flush
snapshot, then tail live pub/sub (dedup'd by ``sequence_no``). See
``docs/runbooks/sse-notifications.md``.
"""

from __future__ import annotations

import json
import logging
import re
import time
from typing import Iterator, Optional

from sqlalchemy import text as sql_text

from application.storage.db.repositories.message_events import (
    MessageEventsRepository,
)
from application.storage.db.session import db_readonly
from application.streaming.broadcast_channel import Topic
from application.streaming.keys import message_topic_name

logger = logging.getLogger(__name__)

DEFAULT_KEEPALIVE_SECONDS = 15.0
DEFAULT_POLL_TIMEOUT_SECONDS = 1.0
# When the live tail has no events and no terminal in snapshot, fall
# back to checking ``conversation_messages`` directly. If the row has
# already gone terminal (worker journaled ``end``/``error`` to the DB
# but the matching pub/sub publish was lost, or the row was finalized
# without a journal write at all) we surface a terminal event so the
# client doesn't hang on keepalives. If the row is still non-terminal
# but the producer heartbeat is older than ``PRODUCER_IDLE_SECONDS``
# the producer is presumed dead (worker crash / recycle between chunks
# and finalize) and we emit a terminal ``error`` so the UI can recover.
DEFAULT_WATCHDOG_INTERVAL_SECONDS = 5.0
# 1.5× the route's 60s heartbeat interval — long enough that a normal
# heartbeat skew doesn't false-positive, short enough that a stuck
# stream surfaces before the 5-minute reconciler sweep escalates.
DEFAULT_PRODUCER_IDLE_SECONDS = 90.0

# WHATWG SSE accepts CRLF, CR, LF — split on any of them so a stray CR
# can't smuggle a record boundary into the wire format.
_SSE_LINE_SPLIT_PATTERN = re.compile(r"\r\n|\r|\n")

# Event types that mark the end of a chat answer. After delivering one
# we close the reconnect stream — keeping the connection open past a
# terminal event would leak both the client's reconnect promise and
# the server's WSGI thread waiting on keepalives that the user no
# longer cares about. The agent loop emits ``end`` for normal /
# tool-paused completion and ``error`` for the catch-all failure path
# (which doesn't get a trailing ``end``).
_TERMINAL_EVENT_TYPES = frozenset({"end", "error"})


def _payload_is_terminal(
    payload: object, event_type: Optional[str] = None
) -> bool:
    """True if ``payload['type']`` or ``event_type`` is a terminal sentinel."""
    if isinstance(payload, dict) and payload.get("type") in _TERMINAL_EVENT_TYPES:
        return True
    return event_type in _TERMINAL_EVENT_TYPES


def format_sse_event(payload: dict, sequence_no: int) -> str:
    """Encode a journal event as one ``id:``/``data:`` SSE record.

    The body is the payload's JSON serialisation. ``complete_stream``
    payloads are flat JSON dicts with no embedded newlines, so a
    single ``data:`` line is sufficient — but we still split on any
    line terminator in case a future caller passes a multi-line string.
    """
    body = json.dumps(payload)
    lines = [f"id: {sequence_no}"]
    for line in _SSE_LINE_SPLIT_PATTERN.split(body):
        lines.append(f"data: {line}")
    return "\n".join(lines) + "\n\n"


def _check_producer_liveness(
    message_id: str, idle_seconds: float
) -> Optional[dict]:
    """Inspect ``conversation_messages`` and return a terminal SSE
    payload when the producer is no longer alive, else ``None``.

    Three terminal cases collapse into a single DB round-trip:

    - ``status='complete'`` — the live finalize ran but its journal
      terminal write didn't reach us (or never happened). Synthesise
      ``end`` so the client closes cleanly on the row's user-visible
      state.
    - ``status='failed'`` — same, but for the failure path. Carry the
      stashed ``error`` from ``message_metadata`` so the UI shows the
      real reason.
    - non-terminal status and ``last_heartbeat_at`` (or ``timestamp``)
      older than ``idle_seconds`` — the producing worker is gone.
      Synthesise ``error`` so the client doesn't hang on keepalives
      until the proxy idle-timeout kicks in.
    """
    try:
        with db_readonly() as conn:
            row = conn.execute(
                sql_text(
                    """
                    SELECT
                        status,
                        message_metadata->>'error' AS err,
                        GREATEST(
                            timestamp,
                            COALESCE(
                                (message_metadata->>'last_heartbeat_at')
                                    ::timestamptz,
                                timestamp
                            )
                        ) < now() - make_interval(secs => :idle_secs)
                            AS is_stale
                    FROM conversation_messages
                    WHERE id = CAST(:id AS uuid)
                    """
                ),
                {"id": message_id, "idle_secs": float(idle_seconds)},
            ).first()
    except Exception:
        logger.exception(
            "Watchdog liveness check failed for message_id=%s", message_id
        )
        return None

    if row is None:
        # Row deleted out from under us — treat as terminal so the
        # client doesn't keep tailing a message that no longer exists.
        return {
            "type": "error",
            "error": "Message no longer exists; please refresh.",
            "code": "message_missing",
            "message_id": message_id,
        }

    status, err, is_stale = row[0], row[1], bool(row[2])
    if status == "complete":
        return {"type": "end"}
    if status == "failed":
        return {
            "type": "error",
            "error": err or "Stream failed; please try again.",
            "code": "producer_failed",
            "message_id": message_id,
        }
    if is_stale:
        return {
            "type": "error",
            "error": (
                "Stream producer is no longer responding; please try again."
            ),
            "code": "producer_stale",
            "message_id": message_id,
        }
    return None


def build_message_event_stream(
    message_id: str,
    last_event_id: Optional[int] = None,
    *,
    keepalive_seconds: float = DEFAULT_KEEPALIVE_SECONDS,
    poll_timeout_seconds: float = DEFAULT_POLL_TIMEOUT_SECONDS,
    watchdog_interval_seconds: float = DEFAULT_WATCHDOG_INTERVAL_SECONDS,
    producer_idle_seconds: float = DEFAULT_PRODUCER_IDLE_SECONDS,
) -> Iterator[str]:
    """Yield SSE-formatted lines for one ``message_id`` reconnect stream.

    First frame is ``: connected``; subsequent frames are snapshot rows,
    live-tail events, or ``: keepalive`` comments. Runs until the client
    disconnects.
    """
    yield ": connected\n\n"

    # Replay buffer — populated inside ``_on_subscribe`` (or the
    # Redis-unavailable fallback below), drained on the first iteration
    # of the subscribe loop after the callback runs.
    replay_buffer: list[str] = []
    # Dedup floor: seeded with the client's cursor so an empty snapshot
    # still rejects re-published live events with seq <= last_event_id.
    # Advanced by snapshot rows AND by yielded live events, so any
    # republish past the snapshot ceiling is also dropped.
    max_replayed_seq: Optional[int] = last_event_id
    replay_done = False
    replay_failed = False
    # Set when a snapshot row carries a terminal ``end`` / ``error``
    # event. After flushing the buffer the generator returns; if we
    # kept tailing we'd loop on keepalives forever for a stream that
    # already finished.
    terminal_in_snapshot = False

    def _read_snapshot_into_buffer() -> None:
        nonlocal max_replayed_seq, replay_failed, terminal_in_snapshot
        try:
            with db_readonly() as conn:
                rows = MessageEventsRepository(conn).read_after(
                    message_id, last_sequence_no=last_event_id
                )
                for row in rows:
                    seq = int(row["sequence_no"])
                    payload = row.get("payload")
                    if not isinstance(payload, dict):
                        # ``record_event`` rejects non-dict payloads at the
                        # write gate, so this can only be a legacy row from
                        # before that contract or a direct SQL insert. The
                        # original synthetic fallback (``{"type": event_type}``)
                        # used to ship a malformed envelope here — drop the
                        # row instead so a corrupt journal entry doesn't
                        # poison a reconnect.
                        logger.warning(
                            "Skipping non-dict payload from message_events: "
                            "message_id=%s seq=%s type=%s",
                            message_id,
                            seq,
                            row.get("event_type"),
                        )
                        continue
                    replay_buffer.append(format_sse_event(payload, seq))
                    if max_replayed_seq is None or seq > max_replayed_seq:
                        max_replayed_seq = seq
                    if _payload_is_terminal(payload, row.get("event_type")):
                        terminal_in_snapshot = True
        except Exception:
            logger.exception(
                "Snapshot read failed for message_id=%s last_event_id=%s",
                message_id,
                last_event_id,
            )
            replay_failed = True

    def _on_subscribe() -> None:
        # SUBSCRIBE has been acked — Postgres reads from this point
        # capture every row that's been committed. Pub/sub messages
        # published after this point are queued at the connection level
        # until the outer loop calls ``get_message`` again.
        nonlocal replay_done
        try:
            _read_snapshot_into_buffer()
        finally:
            # Flip even on failure so the outer loop continues to live
            # tail and the client doesn't hang waiting for a snapshot
            # flush that will never come.
            replay_done = True

    topic = Topic(message_topic_name(message_id))
    last_keepalive = time.monotonic()
    # Rate-limit the watchdog's DB hit. ``-inf`` makes the first idle
    # tick after replay_done fire immediately so a snapshot-already-
    # terminal-in-DB case is surfaced before any keepalive cadence.
    # Subsequent checks are gated by ``watchdog_interval_seconds``.
    last_watchdog_check = float("-inf")
    # Synthetic terminal events emitted by the watchdog use the same
    # ``sequence_no=-1`` convention as the snapshot-failure path so the
    # frontend's strict ``\d+`` cursor regex rejects them as a
    # ``Last-Event-ID`` for any future reconnect. The chosen
    # discriminator ensures a manual page refresh after a watchdog-fired
    # error doesn't loop on the same synthetic id.
    watchdog_synthetic_seq = -1

    try:
        for payload in topic.subscribe(
            on_subscribe=_on_subscribe,
            poll_timeout=poll_timeout_seconds,
        ):
            # Flush snapshot exactly once after the SUBSCRIBE callback
            # has run and produced a buffer.
            if replay_done and replay_buffer:
                for line in replay_buffer:
                    yield line
                replay_buffer.clear()
                if terminal_in_snapshot:
                    # The original stream already finished; tailing
                    # would just emit keepalives forever and pin both a
                    # client reconnect promise and a server WSGI thread.
                    return

            if replay_failed:
                # Snapshot read failed (DB blip / transient timeout). Emit a
                # terminal ``error`` event and return — the client only
                # reconnects after the original stream has already moved on,
                # so without a snapshot there's nothing live left to tail and
                # holding the connection open would just emit keepalives
                # until the proxy idle-timeout fires. ``code`` preserves the
                # snapshot-vs-agent-loop distinction so a future client can
                # opt into a refetch instead of a hard failure.
                yield format_sse_event(
                    {
                        "type": "error",
                        "error": "Stream replay failed; please refresh to load the latest state.",
                        "code": "snapshot_failed",
                        "message_id": message_id,
                    },
                    sequence_no=-1,
                )
                return

            now = time.monotonic()
            if payload is None:
                # Idle tick — check both keepalive and watchdog. The
                # watchdog only kicks in once the snapshot half has been
                # flushed (``replay_done``) so we don't race the
                # snapshot read on the first iteration.
                if (
                    replay_done
                    and watchdog_interval_seconds >= 0
                    and now - last_watchdog_check >= watchdog_interval_seconds
                ):
                    last_watchdog_check = now
                    terminal_payload = _check_producer_liveness(
                        message_id, producer_idle_seconds
                    )
                    if terminal_payload is not None:
                        yield format_sse_event(
                            terminal_payload,
                            sequence_no=watchdog_synthetic_seq,
                        )
                        return
                if now - last_keepalive >= keepalive_seconds:
                    yield ": keepalive\n\n"
                    last_keepalive = now
                continue

            envelope = _decode_pubsub_message(payload)
            if envelope is None:
                continue
            seq = envelope.get("sequence_no")
            inner = envelope.get("payload")
            if (
                not isinstance(seq, int)
                or isinstance(seq, bool)
                or not isinstance(inner, dict)
            ):
                continue
            if max_replayed_seq is not None and seq <= max_replayed_seq:
                # Snapshot already covered this id — drop the duplicate.
                continue
            yield format_sse_event(inner, seq)
            # Advance the dedup floor on the live path too, so a stale
            # republish of an already-yielded seq (process restart, retry
            # tool, etc.) is dropped on a later iteration.
            max_replayed_seq = seq
            last_keepalive = now
            if _payload_is_terminal(inner, envelope.get("event_type")):
                # Live tail just delivered the terminal event — close
                # out the reconnect stream so the client's drain
                # promise resolves and the WSGI thread is freed.
                return

        # Subscribe exited without ever yielding (Redis unavailable,
        # ``pubsub.subscribe`` raised, or the inner loop died between
        # SUBSCRIBE-ack and the first poll). The snapshot half is in
        # Postgres and is still serviceable — read it directly so a
        # Redis-only outage doesn't cost the client their reconnect
        # backlog. Gate the read on ``replay_done`` rather than
        # ``subscribe_started``: if ``_on_subscribe`` already populated
        # the buffer, re-reading would append the same rows twice and
        # double the answer chunks on the client (the per-message
        # reconnect dispatcher does not dedup by ``id``).
        if not replay_done:
            _read_snapshot_into_buffer()
            replay_done = True
        for line in replay_buffer:
            yield line
        replay_buffer.clear()
        if replay_failed:
            # Mirror the live-tail branch: emit a terminal ``error`` so
            # the frontend's existing end/error handling drives the UI
            # to a failed state instead of relying on the proxy timeout.
            yield format_sse_event(
                {
                    "type": "error",
                    "error": "Stream replay failed; please refresh to load the latest state.",
                    "code": "snapshot_failed",
                    "message_id": message_id,
                },
                sequence_no=-1,
            )
            return
        # Same close-on-terminal contract as the live-tail branch.
        # Without it a Redis-down + already-completed-stream client
        # would also hang on a never-ending generator.
        if terminal_in_snapshot:
            return
    except GeneratorExit:
        # Client disconnect — let the underlying ``Topic.subscribe``
        # ``finally`` block tear down its pubsub cleanly.
        return


def _decode_pubsub_message(raw) -> Optional[dict]:
    """Parse a ``Topic.publish`` payload to ``{sequence_no, payload, ...}``.

    Returns ``None`` for malformed messages (drop silently — the
    journal is still authoritative on reconnect).
    """
    try:
        if isinstance(raw, (bytes, bytearray)):
            text_value = raw.decode("utf-8")
        else:
            text_value = str(raw)
        envelope = json.loads(text_value)
    except Exception:
        return None
    if not isinstance(envelope, dict):
        return None
    return envelope


def encode_pubsub_message(
    message_id: str,
    sequence_no: int,
    event_type: str,
    payload: dict,
) -> str:
    """Build the JSON envelope used for ``channel:{message_id}`` publishes.

    Kept here (not in ``message_journal.py``) so the encode/decode pair
    stays in one file — replay's ``_decode_pubsub_message`` and the
    journal's publish must agree on the shape exactly.
    """
    return json.dumps(
        {
            "message_id": str(message_id),
            "sequence_no": int(sequence_no),
            "event_type": event_type,
            "payload": payload,
        }
    )
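For orientation, this is roughly how the reconnect endpoint referenced in keys.py (``/api/messages/<id>/events``) could wire the iterator into a Flask SSE response. The blueprint name, route wiring, and header handling below are assumptions for illustration, not the repository's actual route code:

# Hypothetical Flask wiring for the reconnect endpoint; a sketch only.
from flask import Blueprint, Response, request

from application.streaming.event_replay import build_message_event_stream

messages_sse = Blueprint("messages_sse", __name__)  # assumed blueprint name

@messages_sse.route("/api/messages/<message_id>/events")
def message_events(message_id: str):
    raw_cursor = request.headers.get("Last-Event-ID", "")
    # The frontend's cursor is a strict non-negative integer; synthetic
    # sequence_no=-1 frames are intentionally rejected by this check.
    last_event_id = int(raw_cursor) if raw_cursor.isdigit() else None
    return Response(
        build_message_event_stream(message_id, last_event_id),
        mimetype="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )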
application/streaming/keys.py (new file, 19 lines)
@@ -0,0 +1,19 @@
"""Per-chat-message stream key derivations.

Single source of truth for the Redis pub/sub topic name and any
auxiliary keys that the chat-stream snapshot+tail reconnect path
shares between the writer (``complete_stream`` + journal) and the
reader (``/api/messages/<id>/events`` reconnect endpoint).
"""

from __future__ import annotations


def message_topic_name(message_id: str) -> str:
    """Redis pub/sub channel for live fan-out of one chat message.

    Subscribers tail this topic for every event that ``complete_stream``
    yielded after the SUBSCRIBE-ack arrived; older events are recovered
    from the ``message_events`` snapshot half of the pattern.
    """
    return f"channel:{message_id}"
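A two-line illustration of the single-source-of-truth property described above: the journal writer and the replay reader both derive the channel through this helper, so the strings cannot drift. The message id below is made up:

from application.streaming.keys import message_topic_name

mid = "7f6e4b1c-9a22-4d0e-8c3a-0d5f2b9e1a77"   # hypothetical message id
assert message_topic_name(mid) == f"channel:{mid}"   # shape used by publish and subscribe alike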
application/streaming/message_journal.py (new file, 400 lines)
@@ -0,0 +1,400 @@
"""Per-yield journal write for the chat-stream snapshot+tail pattern.

``record_event`` inserts into ``message_events`` and publishes to
``channel:{message_id}``. Both are best-effort; the INSERT commits
before the publish so a fast reconnect sees the row. See
``docs/runbooks/sse-notifications.md``.
"""

from __future__ import annotations

import logging
import time
from typing import Any, Optional

from sqlalchemy.exc import IntegrityError

from application.storage.db.repositories.message_events import (
    MessageEventsRepository,
)
from application.storage.db.session import db_readonly, db_session
from application.streaming.broadcast_channel import Topic
from application.streaming.event_replay import encode_pubsub_message
from application.streaming.keys import message_topic_name

logger = logging.getLogger(__name__)


# Tunables for ``BatchedJournalWriter``. A streaming answer emits ~100s
# of ``answer`` chunks per response; without batching, that's one PG
# transaction per yield in the WSGI thread. With these defaults, ~10x
# fewer commits at the cost of a ≤100ms reconnect-visibility lag for
# any event still sitting in the buffer.
DEFAULT_BATCH_SIZE = 16
DEFAULT_BATCH_INTERVAL_MS = 100


def _strip_null_bytes(value: Any) -> Any:
    """Recursively strip ``\\x00`` from string keys/values in ``value``.

    Postgres JSONB rejects the NUL escape; an LLM emitting a stray NUL
    in a chunk would otherwise raise ``DataError`` at INSERT and the row
    would be lost from the journal (live stream proceeds, reconnect
    snapshot misses the chunk). Mirrors the strip already done in
    ``parser/embedding_pipeline.py`` and
    ``api/user/attachments/routes.py``.
    """
    if isinstance(value, str):
        return value.replace("\x00", "") if "\x00" in value else value
    if isinstance(value, dict):
        return {
            (k.replace("\x00", "") if isinstance(k, str) and "\x00" in k else k):
                _strip_null_bytes(v)
            for k, v in value.items()
        }
    if isinstance(value, list):
        return [_strip_null_bytes(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_strip_null_bytes(item) for item in value)
    return value


def record_event(
    message_id: str,
    sequence_no: int,
    event_type: str,
    payload: Optional[dict[str, Any]] = None,
) -> bool:
    """Journal one SSE event and publish it live. Best-effort.

    ``payload`` must be a ``dict`` or ``None`` (non-dicts are dropped so
    live and replay envelopes stay byte-identical). Returns ``True`` when
    the journal INSERT committed. Never raises.
    """
    if not message_id or not event_type:
        logger.warning(
            "record_event called without message_id/event_type "
            "(message_id=%r, event_type=%r)",
            message_id,
            event_type,
        )
        return False

    if payload is None:
        materialised_payload: dict[str, Any] = {}
    elif isinstance(payload, dict):
        materialised_payload = _strip_null_bytes(payload)
    else:
        logger.warning(
            "record_event called with non-dict payload "
            "(message_id=%s seq=%s type=%s payload_type=%s) — dropping",
            message_id,
            sequence_no,
            event_type,
            type(payload).__name__,
        )
        return False

    journal_committed = False
    # The seq we actually managed to write. Diverges from
    # ``sequence_no`` only on the IntegrityError-retry path below.
    materialised_seq = sequence_no
    try:
        # Short-lived per-event transaction. Critical for visibility:
        # the reconnect endpoint reads the journal from a separate
        # connection and only sees committed rows.
        with db_session() as conn:
            MessageEventsRepository(conn).record(
                message_id, sequence_no, event_type, materialised_payload
            )
        journal_committed = True
    except IntegrityError:
        # Composite-PK collision on (message_id, sequence_no). Most
        # likely cause is a stale ``latest_sequence_no`` seed on a
        # continuation retry — the route read MAX(seq) from a separate
        # connection before another writer committed past it. Look up
        # the live latest and retry once with latest+1 so the event is
        # not silently lost. Bounded to a single retry — if two
        # writers keep racing in lockstep the route-level retry will
        # converge them across attempts.
        try:
            with db_readonly() as conn:
                latest = MessageEventsRepository(conn).latest_sequence_no(
                    message_id
                )
            materialised_seq = (latest if latest is not None else -1) + 1
            with db_session() as conn:
                MessageEventsRepository(conn).record(
                    message_id,
                    materialised_seq,
                    event_type,
                    materialised_payload,
                )
            journal_committed = True
            logger.info(
                "record_event: collision at seq=%s recovered → wrote at "
                "seq=%s message_id=%s type=%s",
                sequence_no,
                materialised_seq,
                message_id,
                event_type,
            )
        except IntegrityError:
            # Second collision under the same retry — give up and log.
            # The route's nonlocal counter will continue at
            # ``sequence_no+1`` on the next emit; the next call may
            # land cleanly past the contended window.
            logger.warning(
                "record_event: IntegrityError persists after seq+1 retry; "
                "dropping. message_id=%s original_seq=%s retry_seq=%s "
                "type=%s",
                message_id,
                sequence_no,
                materialised_seq,
                event_type,
            )
        except Exception:
            logger.exception(
                "record_event: retry path failed unexpectedly "
                "(message_id=%s seq=%s type=%s)",
                message_id,
                sequence_no,
                event_type,
            )
    except Exception:
        logger.exception(
            "message_events INSERT failed: message_id=%s seq=%s type=%s",
            message_id,
            sequence_no,
            event_type,
        )

    try:
        # Publish using ``materialised_seq`` so the live pubsub frame
        # matches the journal row that other clients will snapshot on
        # reconnect. The original POST stream's SSE ``id:`` still
        # carries the caller's ``sequence_no`` — a reconnect from that
        # client will receive the same event at ``materialised_seq``
        # on the snapshot, which is a benign duplicate (the slice's
        # ``max_replayed_seq`` advances past it). No-collision case:
        # ``materialised_seq == sequence_no`` and this is identical to
        # the prior behaviour.
        wire = encode_pubsub_message(
            message_id, materialised_seq, event_type, materialised_payload
        )
        Topic(message_topic_name(message_id)).publish(wire)
    except Exception:
        logger.exception(
            "channel:%s publish failed: seq=%s type=%s",
            message_id,
            materialised_seq,
            event_type,
        )

    return journal_committed


class BatchedJournalWriter:
    """Per-stream journal writer that batches PG INSERTs.

    One writer per ``message_id``; ``record()`` buffers events and flushes
    on size/time/``close()`` triggers. Pubsub publishes fire only after the
    INSERT commits. On ``IntegrityError`` falls back to per-row writes.
    """

    def __init__(
        self,
        message_id: str,
        *,
        batch_size: int = DEFAULT_BATCH_SIZE,
        batch_interval_ms: int = DEFAULT_BATCH_INTERVAL_MS,
    ) -> None:
        self._message_id = message_id
        self._batch_size = batch_size
        self._batch_interval_ms = batch_interval_ms
        self._buffer: list[tuple[int, str, dict[str, Any]]] = []
        self._last_flush_mono_ms = time.monotonic() * 1000.0
        self._closed = False

    def record(
        self,
        sequence_no: int,
        event_type: str,
        payload: Optional[dict[str, Any]] = None,
    ) -> bool:
        """Buffer one event; maybe flush. Publish happens after journal commit."""
        if self._closed:
            logger.warning(
                "BatchedJournalWriter.record after close: "
                "message_id=%s seq=%s type=%s",
                self._message_id,
                sequence_no,
                event_type,
            )
            return False
        if not event_type:
            logger.warning(
                "BatchedJournalWriter.record without event_type: "
                "message_id=%s seq=%s",
                self._message_id,
                sequence_no,
            )
            return False
        if payload is None:
            materialised: dict[str, Any] = {}
        elif isinstance(payload, dict):
            materialised = _strip_null_bytes(payload)
        else:
            # Same contract as ``record_event`` — non-dict payloads
            # are rejected so the live and replay paths can't diverge
            # on envelope reconstruction.
            logger.warning(
                "BatchedJournalWriter.record with non-dict payload: "
                "message_id=%s seq=%s type=%s payload_type=%s — dropping",
                self._message_id,
                sequence_no,
                event_type,
                type(payload).__name__,
            )
            return False

        self._buffer.append((sequence_no, event_type, materialised))

        if self._should_flush():
            self.flush()
        return True

    def _should_flush(self) -> bool:
        if len(self._buffer) >= self._batch_size:
            return True
        elapsed_ms = (time.monotonic() * 1000.0) - self._last_flush_mono_ms
        return elapsed_ms >= self._batch_interval_ms and len(self._buffer) > 0

    def flush(self) -> None:
        """Commit buffered events to PG. Best-effort.

        Tries one bulk INSERT first; on ``IntegrityError`` (composite
        PK collision — typically a stale continuation seed) falls back
        to per-row ``record_event`` so one bad seq doesn't drop the
        rest of the batch. Always clears the buffer to bound memory,
        even on failure — a journaled event missing from a snapshot
        is degraded UX, but a runaway buffer is corruption.
        """
        if not self._buffer:
            self._last_flush_mono_ms = time.monotonic() * 1000.0
            return

        # Snapshot and clear before the I/O so a concurrent record()
        # call would land in a fresh buffer rather than racing the
        # flush. ``complete_stream`` is single-threaded per stream, so
        # this is belt-and-suspenders for any future change.
        pending = self._buffer
        self._buffer = []
        self._last_flush_mono_ms = time.monotonic() * 1000.0

        try:
            with db_session() as conn:
                MessageEventsRepository(conn).bulk_record(
                    self._message_id, pending
                )
        except IntegrityError:
            logger.info(
                "BatchedJournalWriter: bulk INSERT collided for "
                "message_id=%s n=%d; falling back to per-row writes",
                self._message_id,
                len(pending),
            )
            self._flush_per_row(pending)
            return
        except Exception:
            logger.exception(
                "BatchedJournalWriter: bulk INSERT failed for "
                "message_id=%s n=%d; events dropped from journal",
                self._message_id,
                len(pending),
            )
            return

        # Bulk INSERT committed — publish each frame in order. Best-effort:
        # one failed publish must not poison the rest of the batch.
        for seq, event_type, payload in pending:
            self._publish(seq, event_type, payload)

    def _flush_per_row(
        self, pending: list[tuple[int, str, dict[str, Any]]]
    ) -> None:
        """Per-row fallback after a bulk collision. Publishes after each commit."""
        for seq, event_type, payload in pending:
            committed_seq: Optional[int] = None
            try:
                with db_session() as conn:
                    MessageEventsRepository(conn).record(
                        self._message_id, seq, event_type, payload
                    )
                committed_seq = seq
            except IntegrityError:
                try:
                    with db_readonly() as conn:
                        latest = MessageEventsRepository(
                            conn
                        ).latest_sequence_no(self._message_id)
                    retry_seq = (latest if latest is not None else -1) + 1
                    with db_session() as conn:
                        MessageEventsRepository(conn).record(
                            self._message_id, retry_seq, event_type, payload
                        )
                    committed_seq = retry_seq
                except IntegrityError:
                    logger.warning(
                        "BatchedJournalWriter: IntegrityError persists "
                        "after seq+1 retry; dropping. message_id=%s "
                        "original_seq=%s type=%s",
                        self._message_id,
                        seq,
                        event_type,
                    )
                except Exception:
                    logger.exception(
                        "BatchedJournalWriter: per-row retry failed "
                        "(message_id=%s seq=%s type=%s)",
                        self._message_id,
                        seq,
                        event_type,
                    )
            except Exception:
                logger.exception(
                    "BatchedJournalWriter: per-row INSERT failed "
                    "(message_id=%s seq=%s type=%s)",
                    self._message_id,
                    seq,
                    event_type,
                )

            if committed_seq is not None:
                self._publish(committed_seq, event_type, payload)

    def _publish(
        self, sequence_no: int, event_type: str, payload: dict[str, Any]
    ) -> None:
        """Publish one frame to the per-message pubsub channel. Best-effort."""
        try:
            wire = encode_pubsub_message(
                self._message_id, sequence_no, event_type, payload
            )
            Topic(message_topic_name(self._message_id)).publish(wire)
        except Exception:
            logger.exception(
                "channel:%s publish failed: seq=%s type=%s",
                self._message_id,
                sequence_no,
                event_type,
            )

    def close(self) -> None:
        """Final flush. Idempotent — safe to call from multiple
        finally clauses.
        """
        if self._closed:
            return
        self.flush()
        self._closed = True
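A sketch of how a streaming producer could drive the batched writer above; the generator name, chunk source, and payload shapes are illustrative rather than the actual ``complete_stream`` code:

# Illustrative producer loop: journal each yielded SSE event, and always
# close() so the tail of the buffer is flushed even when the generator is
# abandoned mid-stream.
from application.streaming.message_journal import BatchedJournalWriter

def stream_answer(message_id: str, chunks):
    writer = BatchedJournalWriter(message_id)
    seq = 0
    try:
        for chunk in chunks:
            writer.record(seq, "answer", {"type": "answer", "answer": chunk})
            yield chunk
            seq += 1
        writer.record(seq, "end", {"type": "end"})
    finally:
        writer.close()  # idempotent final flush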
application/worker.py (modified)
@@ -19,8 +19,8 @@ import requests
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.stream_processor import get_prompt

from application.cache import get_redis_instance
from application.core.settings import settings
from application.events.publisher import publish_user_event
from application.parser.chunking import Chunker
from application.parser.connectors.connector_creator import ConnectorCreator
from application.parser.embedding_pipeline import (
@@ -52,18 +52,16 @@ MAX_TOKENS = 1250
RECURSION_DEPTH = 2
INGEST_HEARTBEAT_INTERVAL_SECONDS = 30

# Stable namespace for deterministic source IDs derived from idempotency keys.
# Pinned literal — do not change. Re-rolling this would mint different
# source_ids for the same idempotency_keys across deploys, defeating the
# retry-resume contract.
DOCSGPT_INGEST_NAMESPACE = uuid.UUID("fa25d5d1-398b-46df-ac89-8d1c360b9bea")


def _derive_source_id(idempotency_key):
    """``uuid5(NS, key)`` when a key is supplied; ``uuid4()`` otherwise."""
    if isinstance(idempotency_key, str) and idempotency_key:
        return uuid.uuid5(DOCSGPT_INGEST_NAMESPACE, idempotency_key)
    return uuid.uuid4()
# Re-exported here for backward-compatible imports
# (``from application.worker import _derive_source_id`` /
# ``DOCSGPT_INGEST_NAMESPACE``) from tests and any other in-tree callers.
# New code should import from ``application.storage.db.source_ids``
# directly to avoid pulling this Celery worker module into the API
# process at import time.
from application.storage.db.source_ids import (  # noqa: E402, F401
    DOCSGPT_INGEST_NAMESPACE,
    derive_source_id as _derive_source_id,
)


def _ingest_heartbeat_loop(source_id, stop_event, interval=INGEST_HEARTBEAT_INTERVAL_SECONDS):
@@ -510,6 +508,7 @@ def ingest_worker(
    retriever="classic",
    file_name_map=None,
    idempotency_key=None,
    source_id=None,
):
    """
    Ingest and process documents.
@@ -527,6 +526,11 @@ def ingest_worker(
        idempotency_key (str|None): When provided, the ``source_id`` is derived
            deterministically from the key so a retried task reuses the same
            source row instead of duplicating it.
        source_id (str|None): UUID minted by the HTTP route and returned in
            its response. When supplied, the worker uses it verbatim so SSE
            envelopes carry the same id the frontend already has — required
            for non-idempotent uploads where the route can't predict
            ``_derive_source_id(idempotency_key)``.

    Returns:
        dict: Information about the completed ingestion task, including input parameters and a "limited" flag.
@@ -541,10 +545,41 @@ def ingest_worker(

    logging.info(f"Ingest path: {file_path}", extra={"user": user, "job": job_name})

    # Create temporary working directory
    # Source id resolution order:
    # 1. Caller-supplied ``source_id`` (HTTP route minted + returned to
    #    the frontend) — keeps the route response and the SSE event
    #    payloads in lockstep on the non-idempotent path.
    # 2. Deterministic uuid5 from ``idempotency_key`` — retried tasks
    #    reuse the original source row instead of duplicating it.
    # 3. Fresh uuid4 (caller has neither) — opaque, single-shot only.
    if source_id:
        source_uuid = uuid.UUID(source_id)
    else:
        source_uuid = _derive_source_id(idempotency_key)
    source_id_for_events = str(source_uuid)
    # Only emit ``queued`` on the original attempt. Celery retries re-run
    # the body, and re-publishing here would oscillate the toast through
    # ``queued`` again between ``failed`` and ``completed``.
    if self.request.retries == 0:
        publish_user_event(
            user,
            "source.ingest.queued",
            {
                "job_name": job_name,
                "filename": filename,
                "source_id": source_id_for_events,
                "operation": "upload",
            },
            scope={"kind": "source", "id": source_id_for_events},
        )

    with tempfile.TemporaryDirectory() as temp_dir:
        try:
    # Wrap the entire body in try/except so a failure between the
    # ``queued`` publish above and the inner work (e.g. tempdir
    # creation, OS-level resource exhaustion) still emits a terminal
    # ``failed`` event rather than leaving the toast wedged on
    # 'training' until the polling fallback rescues it 30s later.
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            os.makedirs(temp_dir, exist_ok=True)

            if storage.is_directory(file_path):
@@ -633,23 +668,22 @@ def ingest_worker(

            docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]

            id = _derive_source_id(idempotency_key)

            vector_store_path = os.path.join(temp_dir, "vector_store")
            os.makedirs(vector_store_path, exist_ok=True)

            heartbeat_thread, heartbeat_stop = _start_ingest_heartbeat(id)
            heartbeat_thread, heartbeat_stop = _start_ingest_heartbeat(source_uuid)
            try:
                embed_and_store_documents(
                    docs, vector_store_path, id, self,
                    docs, vector_store_path, source_uuid, self,
                    attempt_id=getattr(self.request, "id", None),
                    user_id=user,
                )
            finally:
                _stop_ingest_heartbeat(heartbeat_thread, heartbeat_stop)
            # Defense-in-depth: chunk-progress is the authoritative
            # record of how many chunks landed; mismatch raises so the
            # task fails loud rather than caching a partial index.
            assert_index_complete(id)
            assert_index_complete(source_uuid)

            tokens = count_tokens_docs(docs)

@@ -664,7 +698,7 @@ def ingest_worker(
                "user": user,
                "tokens": tokens,
                "retriever": retriever,
                "id": str(id),
                "id": source_id_for_events,
                "type": "local",
                "file_path": file_path,
                "directory_structure": json.dumps(directory_structure),
@@ -673,9 +707,36 @@ def ingest_worker(
                file_data["file_name_map"] = json.dumps(file_name_map)

            upload_index(vector_store_path, file_data)
    except Exception as e:
        logging.error(f"Error in ingest_worker: {e}", exc_info=True)
        raise
            publish_user_event(
                user,
                "source.ingest.completed",
                {
                    "source_id": source_id_for_events,
                    "filename": filename,
                    "tokens": tokens,
                    "operation": "upload",
                    # Forward-looking contract: ``limited`` is always
                    # ``False`` today but is carried on the wire so a
                    # future token-cap detection path can flip it and
                    # the frontend slice / UploadToast already react.
                    "limited": False,
                },
                scope={"kind": "source", "id": source_id_for_events},
            )
    except Exception as e:
        logging.error(f"Error in ingest_worker: {e}", exc_info=True)
        publish_user_event(
            user,
            "source.ingest.failed",
            {
                "source_id": source_id_for_events,
                "filename": filename,
                "operation": "upload",
                "error": str(e)[:1024],
            },
            scope={"kind": "source", "id": source_id_for_events},
        )
        raise
    return {
        "directory": directory,
        "formats": formats,
@@ -699,7 +760,23 @@ def reingest_source_worker(self, source_id, user):

    Returns:
        dict: Information about the re-ingestion task

    Note:
        Reingest does its own ``vector_store.add_chunk`` work rather
        than going through ``embed_and_store_documents`` so it does
        *not* emit per-percent SSE progress events — only ``queued``,
        ``completed`` (carrying ``chunks_added`` / ``chunks_deleted``),
        or ``failed``. v1 limitation; revisit if reingest gains a
        progress-driven UI.
    """
    # Declared at the function scope so the outer except can include
    # ``name`` in the failed event payload when the failure happens
    # after the source lookup. Empty string until the lookup succeeds.
    source_name = ""
    # Tracks inner-block failures so a ``completed`` event reflects
    # partial-success accurately rather than masking it.
    inner_warnings: list[str] = []

    try:
        from application.vectorstore.vector_creator import VectorCreator

@@ -713,6 +790,27 @@ def reingest_source_worker(self, source_id, user):
        if not source:
            raise ValueError(f"Source {source_id} not found or access denied")
        source_id = str(source["id"])
        source_name = source.get("name") or ""

        # Publish ``queued`` *after* canonicalising ``source_id`` so the
        # event references the same id as the source row. Trade-off
        # documented: a Celery-backend or PG-lookup hiccup before this
        # publish means the toast may see only a ``failed`` event with
        # no preceding ``queued`` — acceptable for v1 since both
        # conditions also imply broader system trouble. Gate on first
        # attempt only so Celery retries don't re-emit ``queued`` after
        # a prior attempt already published ``failed``.
        if self.request.retries == 0:
            publish_user_event(
                user,
                "source.ingest.queued",
                {
                    "source_id": source_id,
                    "name": source_name,
                    "operation": "reingest",
                },
                scope={"kind": "source", "id": source_id},
            )

        storage = StorageCreator.get_storage()
        source_file_path = source.get("file_path", "")
@@ -810,6 +908,19 @@ def reingest_source_worker(self, source_id, user):
        try:
            if not added_files and not removed_files:
                logging.info("No changes detected.")
                publish_user_event(
                    user,
                    "source.ingest.completed",
                    {
                        "source_id": source_id,
                        "name": source_name,
                        "operation": "reingest",
                        "no_changes": True,
                        "chunks_added": 0,
                        "chunks_deleted": 0,
                    },
                    scope={"kind": "source", "id": source_id},
                )
                return {
                    "source_id": source_id,
                    "user": user,
@@ -861,6 +972,9 @@ def reingest_source_worker(self, source_id, user):
                    f"Error during deletion of removed file chunks: {e}",
                    exc_info=True,
                )
                inner_warnings.append(
                    f"deletion failed: {str(e)[:200]}"
                )

            # 2) Add chunks from new files
            added = 0
@@ -953,6 +1067,9 @@ def reingest_source_worker(self, source_id, user):
                logging.error(
                    f"Error during ingestion of new files: {e}", exc_info=True
                )
                inner_warnings.append(
                    f"add failed: {str(e)[:200]}"
                )

            # 3) Update source directory structure timestamp
            try:
@@ -981,6 +1098,25 @@ def reingest_source_worker(self, source_id, user):
                meta={"current": 100, "status": "Re-ingestion completed"},
            )

            completed_payload: dict = {
                "source_id": source_id,
                "name": source_name,
                "operation": "reingest",
                "chunks_added": added,
                "chunks_deleted": deleted,
                "tokens": int(total_tokens) if "total_tokens" in locals() else 0,
            }
            if inner_warnings:
                # Surface the per-block failures so the toast can warn
                # rather than claim a clean success.
                completed_payload["warnings"] = inner_warnings
            publish_user_event(
                user,
                "source.ingest.completed",
                completed_payload,
                scope={"kind": "source", "id": source_id},
            )

            return {
                "source_id": source_id,
                "user": user,
@@ -998,6 +1134,17 @@ def reingest_source_worker(self, source_id, user):

    except Exception as e:
        logging.error(f"Error in reingest_source_worker: {e}", exc_info=True)
        publish_user_event(
            user,
            "source.ingest.failed",
            {
                "source_id": str(source_id),
                "name": source_name,
                "operation": "reingest",
                "error": str(e)[:1024],
            },
            scope={"kind": "source", "id": str(source_id)},
        )
        raise

@@ -1013,12 +1160,51 @@ def remote_worker(
    operation_mode="upload",
    doc_id=None,
    idempotency_key=None,
    source_id=None,
):
    safe_user = safe_filename(user)
    full_path = os.path.join(directory, safe_user, uuid.uuid4().hex)
    os.makedirs(full_path, exist_ok=True)
    self.update_state(state="PROGRESS", meta={"current": 1})

    # Source id resolution order matches ``ingest_worker``:
    # 1. ``operation_mode == "sync"`` reuses the existing source's ``doc_id``.
    # 2. Caller-supplied ``source_id`` (the HTTP route minted it and
    #    already returned it to the frontend) — keeps the route
    #    response and the SSE event payloads in lockstep on the
    #    no-idempotency-key path.
    # 3. Deterministic uuid5 from ``idempotency_key`` — retried tasks
    #    reuse the original source row instead of duplicating it.
    # 4. Fresh uuid4 — opaque, single-shot only.
    if operation_mode == "sync" and doc_id:
        source_uuid = str(doc_id)
    elif source_id:
        source_uuid = uuid.UUID(source_id)
    else:
        source_uuid = _derive_source_id(idempotency_key)
    source_id_for_events = str(source_uuid)

    # Emit the queued event before any work that could fail (including
    # ``update_state``) so the toast UI always sees a queued envelope
    # before any subsequent failed event. Gated on first attempt so
    # Celery retries don't re-emit ``queued`` after a prior ``failed``.
    if self.request.retries == 0:
        publish_user_event(
            user,
            "source.ingest.queued",
            {
                "source_id": source_id_for_events,
                "job_name": name_job,
                "loader": loader,
                "operation": operation_mode,
            },
            scope={"kind": "source", "id": source_id_for_events},
        )

    # Wrap ``update_state`` plus the entire body so any pre-loader
    # failure (Celery backend down, OS resource issue) still emits a
    # terminal ``failed`` event rather than wedging the toast.
    try:
        self.update_state(state="PROGRESS", meta={"current": 1})
        logging.info("Initializing remote loader with type: %s", loader)
        remote_loader = RemoteCreator.create_loader(loader)
        raw_docs = remote_loader.load_data(source_data)
@@ -1105,22 +1291,22 @@ def remote_worker(
        )

        if operation_mode == "upload":
            id = _derive_source_id(idempotency_key)
            embed_and_store_documents(
                docs, full_path, id, self,
                docs, full_path, source_uuid, self,
                attempt_id=getattr(self.request, "id", None),
                user_id=user,
            )
            assert_index_complete(id)
            assert_index_complete(source_uuid)
        elif operation_mode == "sync":
            if not doc_id:
                logging.error("Invalid doc_id provided for sync operation: %s", doc_id)
                raise ValueError("doc_id must be provided for sync operation.")
            id = str(doc_id)
            embed_and_store_documents(
                docs, full_path, id, self,
                docs, full_path, source_uuid, self,
                attempt_id=getattr(self.request, "id", None),
                user_id=user,
            )
            assert_index_complete(id)
            assert_index_complete(source_uuid)
        self.update_state(state="PROGRESS", meta={"current": 100})

        # Serialize remote_data as JSON if it's a dict (for S3, Reddit, etc.)
@@ -1132,7 +1318,7 @@ def remote_worker(
            "user": user,
            "tokens": tokens,
            "retriever": retriever,
            "id": str(id),
            "id": source_id_for_events,
            "type": loader,
            "remote_data": remote_data_serialized,
            "sync_frequency": sync_frequency,
@@ -1146,23 +1332,49 @@ def remote_worker(
        try:
            with db_session() as conn:
                repo = SourcesRepository(conn)
                src = repo.get_any(str(id), user)
                src = repo.get_any(source_id_for_events, user)
                if src is not None:
                    repo.update(str(src["id"]), user, {"date": last_sync_now})
        except Exception as upd_err:
            logging.warning(
                f"Failed to update last_sync for source {id}: {upd_err}"
                f"Failed to update last_sync for source {source_id_for_events}: {upd_err}"
            )
        upload_index(full_path, file_data)
        publish_user_event(
            user,
            "source.ingest.completed",
            {
                "source_id": source_id_for_events,
                "job_name": name_job,
                "loader": loader,
                "operation": operation_mode,
                "tokens": tokens,
                # Forward-looking contract: see ingest_worker.
                "limited": False,
            },
            scope={"kind": "source", "id": source_id_for_events},
        )
    except Exception as e:
        logging.error("Error in remote_worker task: %s", str(e), exc_info=True)
        publish_user_event(
            user,
            "source.ingest.failed",
            {
                "source_id": source_id_for_events,
                "job_name": name_job,
                "loader": loader,
                "operation": operation_mode,
                "error": str(e)[:1024],
            },
            scope={"kind": "source", "id": source_id_for_events},
        )
        raise
    finally:
        if os.path.exists(full_path):
            shutil.rmtree(full_path)
    logging.info("remote_worker task completed successfully")
    return {
        "id": str(id),
        "id": source_id_for_events,
        "urls": source_data,
        "name_job": name_job,
        "user": user,
@@ -1245,6 +1457,13 @@ def attachment_worker(self, file_info, user):
    relative_path = file_info["path"]
    metadata = file_info.get("metadata", {})

    publish_user_event(
        user,
        "attachment.queued",
        {"attachment_id": str(attachment_id), "filename": filename},
        scope={"kind": "attachment", "id": str(attachment_id)},
    )

    try:
        self.update_state(state="PROGRESS", meta={"current": 10})
        storage = StorageCreator.get_storage()
@@ -1252,6 +1471,17 @@ def attachment_worker(self, file_info, user):
        self.update_state(
            state="PROGRESS", meta={"current": 30, "status": "Processing content"}
        )
        publish_user_event(
            user,
            "attachment.progress",
            {
                "attachment_id": str(attachment_id),
                "filename": filename,
                "current": 30,
                "stage": "processing",
            },
            scope={"kind": "attachment", "id": str(attachment_id)},
        )

        file_extractor = get_default_file_extractor(
            ocr_enabled=settings.DOCLING_OCR_ATTACHMENTS_ENABLED
@@ -1284,6 +1514,17 @@ def attachment_worker(self, file_info, user):
        self.update_state(
            state="PROGRESS", meta={"current": 80, "status": "Storing in database"}
        )
        publish_user_event(
            user,
            "attachment.progress",
            {
                "attachment_id": str(attachment_id),
                "filename": filename,
                "current": 80,
                "stage": "storing",
            },
            scope={"kind": "attachment", "id": str(attachment_id)},
        )

        mime_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"

@@ -1308,6 +1549,18 @@ def attachment_worker(self, file_info, user):

        self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"})

        publish_user_event(
            user,
            "attachment.completed",
            {
                "attachment_id": str(attachment_id),
                "filename": filename,
                "token_count": token_count,
                "mime_type": mime_type,
            },
            scope={"kind": "attachment", "id": str(attachment_id)},
        )

        return {
            "filename": filename,
            "path": relative_path,
@@ -1322,6 +1575,16 @@ def attachment_worker(self, file_info, user):
            extra={"user": user},
            exc_info=True,
        )
        publish_user_event(
            user,
            "attachment.failed",
            {
                "attachment_id": str(attachment_id),
                "filename": filename,
                "error": str(e)[:1024],
            },
            scope={"kind": "attachment", "id": str(attachment_id)},
        )
        raise


@@ -1385,6 +1648,7 @@ def ingest_connector(
    doc_id=None,
    sync_frequency: str = "never",
    idempotency_key=None,
    source_id=None,
) -> Dict[str, Any]:
    """
    Ingestion for internal knowledge bases (GoogleDrive, etc.).
@@ -1403,14 +1667,50 @@ def ingest_connector(
        sync_frequency: How often to sync ("never", "daily", "weekly", "monthly")
        idempotency_key: When provided, the ``source_id`` is derived
            deterministically so a retried upload reuses the same source row.
        source_id: When supplied, the worker uses it verbatim so SSE envelopes
            carry the same id the HTTP route already returned to the frontend
            — required for non-idempotent uploads where the route can't
            predict ``_derive_source_id(idempotency_key)``.
    """
    logging.info(
        f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}"
    )

    # Source id resolution mirrors ``ingest_worker`` / ``remote_worker``:
    # sync mode reuses ``doc_id``; otherwise the caller-supplied
    # ``source_id`` (minted by the HTTP route and already echoed to the
    # client) wins; fall back to ``_derive_source_id`` only when neither
    # is supplied. Without rule (2) the no-idempotency-key path would
    # mint a fresh uuid4 here that the frontend has no way to correlate
    # SSE envelopes to.
    if operation_mode == "sync" and doc_id:
        source_uuid = str(doc_id)
    elif source_id:
        source_uuid = uuid.UUID(source_id)
    else:
        source_uuid = _derive_source_id(idempotency_key)
    source_id_for_events = str(source_uuid)

    # First-attempt gate: Celery retries re-run the body, and a
    # repeated ``queued`` here would oscillate the toast through
    # ``queued`` again between ``failed`` and ``completed``.
    if self.request.retries == 0:
        publish_user_event(
            user,
            "source.ingest.queued",
            {
                "source_id": source_id_for_events,
                "job_name": job_name,
                "loader": source_type,
                "operation": operation_mode,
            },
            scope={"kind": "source", "id": source_id_for_events},
        )

    self.update_state(state="PROGRESS", meta={"current": 1})

    with tempfile.TemporaryDirectory() as temp_dir:
        try:
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            # Step 1: Initialize the appropriate loader
            self.update_state(
                state="PROGRESS",
@@ -1448,6 +1748,22 @@ def ingest_connector(
                "files_downloaded", 0
            ):
                logging.warning(f"No files were downloaded from {source_type}")
                # Connector returned no files — surface as a benign
                # ``completed`` event with zero tokens so the toast
                # closes out cleanly instead of waiting on polling.
                publish_user_event(
                    user,
                    "source.ingest.completed",
                    {
                        "source_id": source_id_for_events,
                        "job_name": job_name,
                        "loader": source_type,
                        "operation": operation_mode,
                        "tokens": 0,
                        "no_changes": True,
                    },
                    scope={"kind": "source", "id": source_id_for_events},
                )
                # Create empty result directly instead of calling a separate method
                return {
                    "name": job_name,
@@ -1497,16 +1813,16 @@ def ingest_connector(

            docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]

            if operation_mode == "upload":
                id = _derive_source_id(idempotency_key)
            elif operation_mode == "sync":
                if not doc_id:
                    logging.error(
                        "Invalid doc_id provided for sync operation: %s", doc_id
                    )
                    raise ValueError("doc_id must be provided for sync operation.")
                id = str(doc_id)
            else:
            # Validate operation_mode here too (the source_uuid path
            # at the top of the function only branches on the
            # sync+doc_id combination; surfacing the wrong-mode error
            # this far in matches the legacy behaviour).
            if operation_mode == "sync" and not doc_id:
                logging.error(
                    "Invalid doc_id provided for sync operation: %s", doc_id
                )
                raise ValueError("doc_id must be provided for sync operation.")
            if operation_mode not in ("upload", "sync"):
                raise ValueError(f"Invalid operation_mode: {operation_mode}")

            vector_store_path = os.path.join(temp_dir, "vector_store")
@@ -1516,10 +1832,11 @@ def ingest_connector(
                state="PROGRESS", meta={"current": 80, "status": "Storing documents"}
            )
            embed_and_store_documents(
|
||||
docs, vector_store_path, id, self,
|
||||
docs, vector_store_path, source_uuid, self,
|
||||
attempt_id=getattr(self.request, "id", None),
|
||||
user_id=user,
|
||||
)
|
||||
assert_index_complete(id)
|
||||
assert_index_complete(source_uuid)
|
||||
|
||||
tokens = count_tokens_docs(docs)
|
||||
|
||||
@@ -1529,7 +1846,7 @@ def ingest_connector(
|
||||
"name": job_name,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"id": str(id),
|
||||
"id": source_id_for_events,
|
||||
"type": "connector:file",
|
||||
"remote_data": json.dumps(
|
||||
{"provider": source_type, **api_source_config}
|
||||
@@ -1538,16 +1855,13 @@ def ingest_connector(
|
||||
"sync_frequency": sync_frequency,
|
||||
}
|
||||
|
||||
if operation_mode == "sync":
|
||||
file_data["last_sync"] = datetime.datetime.now()
|
||||
else:
|
||||
file_data["last_sync"] = datetime.datetime.now()
|
||||
file_data["last_sync"] = datetime.datetime.now()
|
||||
|
||||
if operation_mode == "sync":
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = SourcesRepository(conn)
|
||||
src = repo.get_any(str(id), user)
|
||||
src = repo.get_any(source_id_for_events, user)
|
||||
if src is not None:
|
||||
repo.update(
|
||||
str(src["id"]), user,
|
||||
@@ -1555,7 +1869,9 @@ def ingest_connector(
|
||||
)
|
||||
except Exception as upd_err:
|
||||
logging.warning(
|
||||
f"Failed to update last_sync for source {id}: {upd_err}"
|
||||
"Failed to update last_sync for source %s: %s",
|
||||
source_id_for_events,
|
||||
upd_err,
|
||||
)
|
||||
|
||||
upload_index(vector_store_path, file_data)
|
||||
@@ -1567,45 +1883,104 @@ def ingest_connector(
|
||||
|
||||
logging.info(f"Remote ingestion completed: {job_name}")
|
||||
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.completed",
|
||||
{
|
||||
"source_id": source_id_for_events,
|
||||
"job_name": job_name,
|
||||
"loader": source_type,
|
||||
"operation": operation_mode,
|
||||
"tokens": tokens,
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_for_events},
|
||||
)
|
||||
|
||||
return {
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"tokens": tokens,
|
||||
"type": source_type,
|
||||
"id": str(id),
|
||||
"id": source_id_for_events,
|
||||
"status": "complete",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
|
||||
raise
|
||||
except Exception as e:
|
||||
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
|
||||
publish_user_event(
|
||||
user,
|
||||
"source.ingest.failed",
|
||||
{
|
||||
"source_id": source_id_for_events,
|
||||
"job_name": job_name,
|
||||
"loader": source_type,
|
||||
"operation": operation_mode,
|
||||
"error": str(e)[:1024],
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_for_events},
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
|
||||
"""Worker to handle MCP OAuth flow asynchronously."""
|
||||
"""Worker to handle MCP OAuth flow asynchronously.
|
||||
|
||||
Publishes SSE events at each phase boundary so the frontend can
|
||||
drive the OAuth popup directly from the push channel. The
|
||||
``mcp.oauth.awaiting_redirect`` envelope carries the
|
||||
``authorization_url`` once the upstream OAuth client surfaces it,
|
||||
eliminating the prior polling-only path for that URL.
|
||||
"""
|
||||
|
||||
# Bind ``task_id`` and the publish helpers OUTSIDE the outer try so
|
||||
# the ``except`` handler at the bottom can reach them even when an
|
||||
# early statement raises. Without this, ``publish_oauth`` would
|
||||
# UnboundLocalError on top of the original failure.
|
||||
task_id = self.request.id if getattr(self, "request", None) else None
|
||||
|
||||
def publish_oauth(event_type: str, payload: Dict[str, Any]) -> None:
|
||||
# MCP OAuth can be invoked without a route-bound user_id by
|
||||
# legacy paths. Skip the SSE publish in that case — the caller
|
||||
# has no per-user channel to subscribe to, and the status is
|
||||
# surfaced via the task's return value.
|
||||
if not user_id or task_id is None:
|
||||
return
|
||||
publish_user_event(
|
||||
user_id,
|
||||
event_type,
|
||||
{"task_id": task_id, **payload},
|
||||
scope={"kind": "mcp_oauth", "id": task_id},
|
||||
)
|
||||
|
||||
def publish_awaiting_redirect(authorization_url: str) -> None:
|
||||
"""Callback invoked by ``DocsGPTOAuth.redirect_handler`` once
|
||||
the OAuth client has minted the authorization URL.
|
||||
|
||||
Carrying the URL on the SSE envelope lets the frontend open the
|
||||
popup directly from the event — the prior polling-only path
|
||||
for the URL is gone.
|
||||
"""
|
||||
publish_oauth(
|
||||
"mcp.oauth.awaiting_redirect",
|
||||
{
|
||||
"message": "Awaiting OAuth redirect...",
|
||||
"authorization_url": authorization_url,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
from application.agents.tools.mcp_tool import MCPTool
|
||||
|
||||
task_id = self.request.id
|
||||
redis_client = get_redis_instance()
|
||||
|
||||
def update_status(status_data: Dict[str, Any]):
|
||||
status_key = f"mcp_oauth_status:{task_id}"
|
||||
redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
|
||||
update_status(
|
||||
{
|
||||
"status": "in_progress",
|
||||
"message": "Starting OAuth...",
|
||||
"task_id": task_id,
|
||||
}
|
||||
)
|
||||
publish_oauth("mcp.oauth.in_progress", {"message": "Starting OAuth..."})
|
||||
|
||||
tool_config = config.copy()
|
||||
tool_config["oauth_task_id"] = task_id
|
||||
# Inject the awaiting-redirect publish callback. ``MCPTool`` pops
|
||||
# it out of the config and threads it into ``DocsGPTOAuth`` so
|
||||
# the publish fires synchronously from inside
|
||||
# ``redirect_handler`` — the only point where the URL is known.
|
||||
tool_config["oauth_redirect_publish"] = publish_awaiting_redirect
|
||||
mcp_tool = MCPTool(tool_config, user_id)
|
||||
|
||||
async def run_oauth_discovery():
|
||||
@@ -1613,14 +1988,6 @@ def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, An
|
||||
mcp_tool._setup_client()
|
||||
return await mcp_tool._execute_with_client("list_tools")
|
||||
|
||||
update_status(
|
||||
{
|
||||
"status": "awaiting_redirect",
|
||||
"message": "Awaiting OAuth redirect...",
|
||||
"task_id": task_id,
|
||||
}
|
||||
)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
@@ -1628,49 +1995,21 @@ def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, An
|
||||
loop.run_until_complete(run_oauth_discovery())
|
||||
tools = mcp_tool.get_actions_metadata()
|
||||
|
||||
update_status(
|
||||
{
|
||||
"status": "completed",
|
||||
"message": f"Connected \u2014 found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
|
||||
"tools": tools,
|
||||
"tools_count": len(tools),
|
||||
"task_id": task_id,
|
||||
}
|
||||
publish_oauth(
|
||||
"mcp.oauth.completed",
|
||||
{"tools": tools, "tools_count": len(tools)},
|
||||
)
|
||||
|
||||
return {"success": True, "tools": tools, "tools_count": len(tools)}
|
||||
except Exception as e:
|
||||
error_msg = f"OAuth failed: {str(e)}"
|
||||
logging.error("MCP OAuth discovery failed: %s", error_msg, exc_info=True)
|
||||
update_status(
|
||||
{
|
||||
"status": "error",
|
||||
"message": error_msg,
|
||||
"task_id": task_id,
|
||||
}
|
||||
)
|
||||
publish_oauth("mcp.oauth.failed", {"error": error_msg[:1024]})
|
||||
return {"success": False, "error": error_msg}
|
||||
finally:
|
||||
loop.close()
|
||||
except Exception as e:
|
||||
error_msg = f"OAuth init failed: {str(e)}"
|
||||
logging.error("MCP OAuth init failed: %s", error_msg, exc_info=True)
|
||||
update_status(
|
||||
{
|
||||
"status": "error",
|
||||
"message": error_msg,
|
||||
"task_id": task_id,
|
||||
}
|
||||
)
|
||||
publish_oauth("mcp.oauth.failed", {"error": error_msg[:1024]})
|
||||
return {"success": False, "error": error_msg}
|
||||
|
||||
|
||||
def mcp_oauth_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""Check the status of an MCP OAuth flow."""
|
||||
redis_client = get_redis_instance()
|
||||
status_key = f"mcp_oauth_status:{task_id}"
|
||||
|
||||
status_data = redis_client.get(status_key)
|
||||
if status_data:
|
||||
return json.loads(status_data)
|
||||
return {"status": "not_found", "message": "Status not found"}
|
||||
|
||||
385 docs/runbooks/sse-notifications.md (new file)
@@ -0,0 +1,385 @@
# SSE Notifications Runbook

> Operations guide for "user says they didn't get a notification" — and
> the related "the bell never lights up" / "my upload toast hangs" /
> "the chat answer doesn't reconnect" symptoms.

The user-facing notifications channel is the SSE pipe at
`/api/events` plus per-message reconnects at
`/api/messages/<id>/events`. This document maps a user complaint to
the diagnostic that surfaces the cause.

---

## TL;DR — first 60 seconds

Run these three commands in parallel before anything else:

```bash
# 1) Is Redis up and serving the pipe? Should print PONG instantly.
redis-cli -n 2 PING

# 2) Anyone subscribed to the channel right now? Numbers per channel.
redis-cli -n 2 PUBSUB NUMSUB user:<user_id>

# 3) Is the user's backlog populated? Returns the count of journaled events.
redis-cli -n 2 XLEN user:<user_id>:stream
```

- `PING` failing → Redis is the problem. Skip to "Redis-down".
- `NUMSUB user:<user_id>` returns 0 → no client is connected. Skip to "Stale event-stream client".
- `XLEN user:<user_id>:stream` returns 0 or stays low → the publisher isn't writing. Skip to "Publisher silent".
- All three look healthy → events are flowing on the wire; the issue is downstream in the frontend (slice handling, toast rendering, toast suppression, etc.). Work through the symptom A table below.
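
The same first-minute checks are easy to script when triaging several users at once. A rough sketch using redis-py, assuming Redis on localhost and db 2 as in the commands above; the returned strings just name the section to jump to:

```python
from redis import Redis


def triage(user_id: str, redis_url: str = "redis://localhost:6379/2") -> str:
    """Run the three TL;DR checks and name the runbook section to jump to."""
    r = Redis.from_url(redis_url)
    try:
        r.ping()
    except Exception:
        return "Redis-down"
    # pubsub_numsub returns [(channel, subscriber_count), ...]
    _, subscribers = r.pubsub_numsub(f"user:{user_id}")[0]
    if subscribers == 0:
        return "Stale event-stream client (no subscriber on the live tail)"
    if r.xlen(f"user:{user_id}:stream") == 0:
        return "Publisher silent (journal is empty)"
    return "Pipe looks healthy; check the frontend slice (symptom A table)"
```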

---

## Architecture cheat-sheet

```
Worker (publish_user_event)                       Frontend tab
      │                                                ▲
      ▼                                                │ GET /api/events (SSE)
Redis Stream: XADD                                Flask route
  user:<id>:stream ─────────────────► replay_backlog (snapshot)
      │                                                +
      ▼                                    Topic.subscribe (live tail)
Redis pub/sub: PUBLISH                                 │
  user:<id> ───────────────────────────────────────────┘
```

**Source of truth:**

- Persistent journal: Redis Stream `user:<user_id>:stream`, capped at
  `EVENTS_STREAM_MAXLEN` (default 1000) entries via `MAXLEN ~`. Roughly
  24h of history at typical event rates.
- Live fan-out: Redis pub/sub channel `user:<user_id>`. No durability;
  subscribers must be attached at publish time.

The chat-stream pipe is separate, parallel infrastructure:

- Journal: Postgres `message_events` table.
- Live fan-out: Redis pub/sub `channel:<message_id>`.

Same patterns, different durability layer. This doc covers both;
they share most diagnostic commands.
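
The dual write behind this picture (journal `XADD` first, then a best-effort pub/sub `PUBLISH`) can be sketched in a few lines. The helper name, envelope fields, and trim arguments below are illustrative assumptions; the real `publish_user_event` in the backend may differ:

```python
import json
import time
import uuid

from redis import Redis

EVENTS_STREAM_MAXLEN = 1000  # mirrors the setting described above


def publish_user_event_sketch(
    r: Redis, user_id: str, event_type: str, payload: dict, scope: dict | None = None
) -> bool:
    """Illustrative dual write: durable journal entry, then best-effort live fan-out."""
    envelope = {
        "id": str(uuid.uuid4()),
        "type": event_type,
        "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "payload": payload,
        "scope": scope or {},
    }
    try:
        body = json.dumps(envelope)
    except TypeError:
        # The real publisher logs and swallows non-serializable payloads.
        return False
    # 1) Journal: approximate trim keeps the per-user backlog bounded.
    r.xadd(
        f"user:{user_id}:stream",
        {"event": body},
        maxlen=EVENTS_STREAM_MAXLEN,
        approximate=True,
    )
    # 2) Live tail: only SSE connections subscribed right now receive this.
    r.publish(f"user:{user_id}", body)
    return True
```

Writing the journal entry before the fan-out is what makes the replay-on-reconnect story work: a client that misses the `PUBLISH` can still pick the entry up from the next snapshot.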

---

## Symptom → diagnostic map

### A. "I uploaded a source and the toast never appeared"

User flow: chat → upload → expect toast.

| Step | Command | Expect |
| --- | --- | --- |
| Worker received the task | `tail -f celery.log` filtered by user | `ingest_worker` start log line |
| Worker published the queued event | `redis-cli -n 2 XREVRANGE user:<id>:stream + - COUNT 5` | A `source.ingest.queued` entry within seconds |
| Frontend got it | DevTools → Network → `/api/events` → EventStream tab | `data: {"type":"source.ingest.queued",...}` |
| Slice updated | Redux DevTools → `state.upload.tasks` | Task with matching `sourceId`, `status: 'training'` |

If the worker's queued log line is there but the XADD didn't land →
look for a `publish_user_event payload not JSON-serializable` warning
in the worker log (the publisher swallows `TypeError`).

If the XADD landed but the frontend never received it → check
`PUBSUB NUMSUB user:<id>` while the user is on the page. If it is 0, the
SSE connection isn't subscribed; skip to "Stale event-stream client".

If the frontend received it but the toast didn't render → the
`uploadSlice` extraReducer requires `task.sourceId` to match the
event's `scope.id`. Check that the upload route returned `source_id` in
its POST response (the upload, connector, and reingest paths all
include it). Idempotent / cached responses must also include
`source_id` (`_claim_task_or_get_cached`).
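
When working the `XREVRANGE` row of the table by hand gets tedious, the journal can be filtered for just the source in question. A sketch under an assumed envelope layout (a JSON blob stored in an `event` field with `type`, `payload`, and `scope.id` keys); adjust the field names to whatever your deployment actually journals:

```python
import json

from redis import Redis


def recent_events_for_source(r: Redis, user_id: str, source_id: str, count: int = 50):
    """Yield journaled envelopes whose scope.id matches the source being ingested."""
    entries = r.xrevrange(f"user:{user_id}:stream", max="+", min="-", count=count)
    for entry_id, fields in entries:
        raw = fields.get(b"event") or fields.get("event")
        if not raw:
            continue
        envelope = json.loads(raw)
        if envelope.get("scope", {}).get("id") == source_id:
            yield entry_id, envelope.get("type"), envelope.get("payload")
```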

### B. "The bell badge never goes up"

There is no bell — the global notifications surface is per-event
toasts, not an aggregated counter. If the user is on an old build,
`Cmd-Shift-R` to bypass the cache. The surfaces they're looking for are
`UploadToast` for source uploads and `ToolApprovalToast` for
tool-approval events.

### C. "My chat answer froze mid-stream and never recovered"

User flow: ask question → answer streaming → network blip → answer
stops; it should reconnect.

```bash
# Was the original message reserved in PG?
psql -c "SELECT id, status, prompt FROM conversation_messages \
  WHERE user_id = '<user>' ORDER BY timestamp DESC LIMIT 5;"

# Did the journal capture events past the user's last-seen seq?
psql -c "SELECT sequence_no, event_type FROM message_events \
  WHERE message_id = '<id>' ORDER BY sequence_no;"

# Is the live tail still producing? (subscribe and watch)
redis-cli -n 2 SUBSCRIBE channel:<message_id>
```

The frontend should reconnect via `GET /api/messages/<id>/events`
when the original POST stream closes without a typed `end` or
`error` event. If it's not reconnecting, `console.warn('Stream
reconnect failed', ...)` will be in the browser console — the
reconnect HTTP call errored. Common cases:

- The user's JWT rotated mid-stream → 401 on the GET. The frontend
  doesn't auto-refresh; the user reloads.
- The user is on a different host than the API and CORS is rejecting
  the GET → check `application/asgi.py` allow-headers.

### D. "The dev install never delivers any notifications at all"

With `AUTH_TYPE` unset, every request decodes to `decoded_token = {"sub": "local"}`.
The SSE client connects without the `Authorization` header in this
case, and `user:local:stream` is the shared channel everything goes
to. If the user has multiple dev machines pointing at the same Redis,
they will see each other's events. Confirm with:

```bash
redis-cli -n 2 KEYS 'user:local:*'
```

If multiple deployments share the Redis, document that as a known
multi-user-on-local-channel limitation. Set `AUTH_TYPE=simple_jwt`
to scope events per-user.

### E. "The notifications channel was working, then suddenly stopped after the user reloaded the page"

Likely path: a `backlog.truncated` event fired, the slice cleared
`lastEventId` to null, but the reconnect closure was still carrying the
same stale id and re-tripped the same truncation on every reconnect.
**Verify the user is on a current build — `eventStreamClient.ts` must
re-read `lastEventId = opts.getLastEventId();` without a truthy guard
so the null clear propagates into the next reconnect.**

### F. "I keep getting 429 on /api/events"

The per-user concurrent-connection cap (`SSE_MAX_CONCURRENT_PER_USER`,
default 8) refused the connection. The user has too many tabs open or a
runaway reconnect loop. `redis-cli -n 2 GET user:<id>:sse_count`
shows the live counter; the TTL is 1h from the last connection
attempt (rolling — every INCR re-seeds it), so the key only ages
out after the user stops reconnecting for a full hour.

If the count is wedged high without explanation, the
counter-DECR-in-finally path didn't run (worker SIGKILL, OOM). Wait
for the TTL, or run `redis-cli -n 2 DEL user:<id>:sse_count` to reset it.
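
The wedge described above follows from the counter being a plain `INCR` with a rolling TTL and a `DECR` that only runs on clean disconnect. A sketch of that shape, with the key name taken from the command above and the cap and TTL mirroring the settings table; the real route code is not reproduced here:

```python
from contextlib import contextmanager

from redis import Redis

SSE_MAX_CONCURRENT_PER_USER = 8
COUNTER_TTL_SECONDS = 3600  # re-seeded on every connection attempt


class TooManyConnections(Exception):
    """Surfaces to the client as HTTP 429."""


@contextmanager
def sse_connection_slot(r: Redis, user_id: str):
    key = f"user:{user_id}:sse_count"
    current = r.incr(key)
    r.expire(key, COUNTER_TTL_SECONDS)  # rolling TTL: every attempt pushes it out
    if current > SSE_MAX_CONCURRENT_PER_USER:
        r.decr(key)
        raise TooManyConnections()
    try:
        yield
    finally:
        # If the worker is SIGKILLed before this runs, the counter stays
        # inflated until the TTL expires (the wedged-counter case above).
        r.decr(key)
```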

### G. "Replay snapshot stops at 200 events"

The route caps each replay at `EVENTS_REPLAY_MAX_PER_REQUEST`
(default 200). The cap is intentionally **silent** — the route does NOT
emit a `backlog.truncated` notice for a cap hit. The 200 entries each
carry their own `id:` header, so the frontend's slice cursor
advances to the most recently delivered id. The next reconnect sends
`last_event_id=<max_replayed>` and the snapshot resumes from there.
A user who was 1000 entries behind catches up over ~5 reconnects.

If the user reports HTTP 429 on `/api/events` despite being
well under `SSE_MAX_CONCURRENT_PER_USER`, they hit the windowed
replay budget (`EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW`, default
30, per `EVENTS_REPLAY_BUDGET_WINDOW_SECONDS`, default 60s). The route
refuses the connection so the slice cursor stays pinned at whatever
value it had; the frontend backs off and the next reconnect (after the
window rolls) gets a proper snapshot. Serving the live tail
without a snapshot used to be the behavior here, but that let the
client advance `lastEventId` past entries it never received,
permanently stranding the un-replayed window — so the route now
returns 429 instead. `redis-cli -n 2 GET user:<id>:replay_count` shows
the current counter; its TTL is the window size.

`backlog.truncated` is emitted ONLY when the client's
`Last-Event-ID` has slid off the MAXLEN'd window — i.e. the journal
is genuinely gone past the cursor and the frontend should clear its
slice cursor and refetch state. Treating a cap hit or
budget exhaustion the same way would lock the user into re-receiving
the oldest 200 entries on every reconnect (the cursor would clear,
the snapshot would re-serve from the start, the cap would re-trip).
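
The cap-and-cursor hand-off amounts to reading the stream strictly after the client's last id, bounded by the per-request cap. A sketch, assuming the journal layout above; the exclusive-start `(` prefix is standard Redis Streams syntax, but the route's actual helper names are not shown in this document:

```python
from redis import Redis

EVENTS_REPLAY_MAX_PER_REQUEST = 200


def replay_snapshot(r: Redis, user_id: str, last_event_id: str | None):
    """Return up to 200 journal entries strictly after the client's cursor."""
    start = f"({last_event_id}" if last_event_id else "-"  # "(" = exclusive start
    entries = r.xrange(
        f"user:{user_id}:stream",
        min=start,
        max="+",
        count=EVENTS_REPLAY_MAX_PER_REQUEST,
    )
    # Each entry id becomes the SSE "id:" header; the client stores the last
    # one it saw, so the next reconnect resumes from there instead of
    # re-reading the whole backlog.
    return entries
```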

### H. "User says push notifications stopped after a deploy"

- Pull `event.published topic=user:<id> type=...` from the worker
  logs to confirm the publisher is still firing.
- Pull `event.connect user=<id>` from the API logs to confirm the
  client is reconnecting.
- Check the gunicorn worker count and `WSGIMiddleware(workers=32)` —
  if the deploy reduced the worker count, the per-user cap is still 8
  but total concurrent SSE connections are bounded by
  `gunicorn workers × 32`. A capacity miss looks like users randomly
  getting 429'd.

---

## Common failure modes

### Redis-down

Symptoms: `/api/events` returns 200 but emits only `: connected`,
then the body closes. `XLEN` and `PUBLISH` both fail. The publisher's
`record_event` swallows the failure and returns False; the live-tail
publish also drops on the floor. The frontend retries forever with
exponential backoff.

Resolution: bring Redis back. The journal is gone (it was in-memory
only — Streams persist within a single Redis instance, and no
replication is configured). New events flow as soon as Redis comes back.

### `AUTH_TYPE` misconfigured = sub:"local" cross-stream

Symptoms: every user shares `user:local:stream`. Any user sees
everyone else's notifications.

Resolution: set `AUTH_TYPE=simple_jwt` (or `session_jwt`) in `.env`.
The events route logs a one-time WARNING per process when
`sub == "local"` is observed. A repeat WARNING after a restart
confirms the misconfiguration.

### MAXLEN trimmed past Last-Event-ID

Symptoms: the client reconnects with `last_event_id=X` and the snapshot
returns the entire MAXLEN'd backlog (because X is older than the oldest
retained entry). Old events appear duplicated.

Detection: the route's `_oldest_retained_id` check emits
`backlog.truncated` when this case fires. The frontend's
`dispatchSSEEvent` clears `lastEventId` so the next reconnect starts
fresh.

If the WARNING isn't firing but symptoms match, the user's client
may have a corrupt cached `lastEventId`. `localStorage` doesn't
store this state; check the Redux state via DevTools.
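
The detection reduces to comparing the client's cursor with the oldest entry still retained in the stream. A sketch of that comparison; `_oldest_retained_id`'s real implementation is not shown in this document, so the names and exact comparison below are assumptions:

```python
from redis import Redis


def cursor_fell_off_backlog(r: Redis, user_id: str, last_event_id: str) -> bool:
    """True when the journal no longer contains the entry the client last saw."""
    oldest = r.xrange(f"user:{user_id}:stream", min="-", max="+", count=1)
    if not oldest:
        return True  # empty stream: everything the client saw is gone
    oldest_id = oldest[0][0]
    if isinstance(oldest_id, bytes):
        oldest_id = oldest_id.decode()

    def as_tuple(stream_id: str):
        # Stream ids are "<ms>-<seq>" and compare correctly as (int, int).
        ms, _, seq = stream_id.partition("-")
        return int(ms), int(seq or 0)

    return as_tuple(last_event_id) < as_tuple(oldest_id)
```

When this returns True, the route emits `backlog.truncated` so the client clears its cursor and refetches state.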

### Stale event-stream client

Symptoms: events are visible in `XRANGE` but the frontend slice doesn't
update.

```bash
# Is the client subscribed?
redis-cli -n 2 PUBSUB NUMSUB user:<id>

# When did its connection start?
grep "event.connect user=<id>" /var/log/docsgpt.log | tail -3
```

If `NUMSUB` is 0 and there is no recent `event.connect`, the user's tab
is closed or the connection died and never reconnected. Push them to
reload.

### Publisher silent

Symptoms: the worker is processing the task (Celery says SUCCESS), but
there is no XADD and no PUBLISH. The user sees no events.

```bash
# Was the publisher import error suppressed?
grep "publish_user_event" /var/log/celery.log | grep -i "warn\|error" | tail -20

# Is push disabled?
grep "ENABLE_SSE_PUSH" /var/log/docsgpt.log | tail -5
```

`ENABLE_SSE_PUSH=False` in `.env` silences the publisher
globally. Useful for incident response if a runaway publisher is
DoS'ing Redis; toggle it off, fix the root cause, toggle it back on.

---

## Useful one-liners

```bash
# Watch a user's live event stream in real time (all events, all types)
redis-cli -n 2 PSUBSCRIBE 'user:*' | grep "user:<id>"

# Last 10 events the user would see on reconnect
redis-cli -n 2 XREVRANGE user:<id>:stream + - COUNT 10

# Live count of subscribed clients per user
redis-cli -n 2 PUBSUB NUMSUB $(redis-cli -n 2 PUBSUB CHANNELS 'user:*')

# Trim a runaway stream (CAREFUL — destroys the backlog for all current
# subscribers; OK after explaining to the user)
redis-cli -n 2 XTRIM user:<id>:stream MAXLEN 0

# Clear a wedged concurrent-connection counter
redis-cli -n 2 DEL user:<id>:sse_count

# Force every client to re-snapshot (drops the stream key entirely —
# destroys the backlog; clients reconnect with their last id and
# get a backlog.truncated)
redis-cli -n 2 DEL user:<id>:stream
```

---

## Settings reference

Everything in `application/core/settings.py`:

| Setting | Default | Purpose |
| --- | --- | --- |
| `ENABLE_SSE_PUSH` | `True` | Master switch. False = publisher no-ops, route serves a "push_disabled" comment. |
| `EVENTS_STREAM_MAXLEN` | `1000` | Per-user backlog cap. Approximate via `XADD MAXLEN ~`. |
| `SSE_KEEPALIVE_SECONDS` | `15` | Comment-frame cadence. Must sit under the reverse-proxy idle close. |
| `SSE_MAX_CONCURRENT_PER_USER` | `8` | Cap on simultaneous SSE connections per user. 0 = disabled. |
| `EVENTS_REPLAY_MAX_PER_REQUEST` | `200` | Hard cap on snapshot rows per request. |
| `EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW` | `30` | Per-user replays per window. 0 = disabled. |
| `EVENTS_REPLAY_BUDGET_WINDOW_SECONDS` | `60` | Window length. |
| `MESSAGE_EVENTS_RETENTION_DAYS` | `14` | Retention for the `message_events` journal; the `cleanup_message_events` beat task deletes older rows. |

---

## Known limitations

### Each tab runs its own SSE connection

There is no cross-tab dedup. Every tab open to the app holds its
own SSE connection and dispatches every received event into its
own Redux store, so a user with N tabs open will see N copies of
each toast. With `SSE_MAX_CONCURRENT_PER_USER=8` (the default) a
heavy multi-tab user can also hit the connection cap and start
seeing 429s. Cross-tab dedup via a `BroadcastChannel` ring +
`navigator.locks`-based leader election is tracked as future work.

### `/c/<unknown-id>` normalises to `/c/new`

If a user navigates to a conversation id that isn't in their
loaded list, the conversation route rewrites the URL to `/c/new`.
`ToolApprovalToast`'s gate uses `useMatch('/c/:conversationId')`,
so for the brief window after the rewrite the toast may surface
for a conversation the user *thought* they were already viewing.
Pre-existing route behaviour; not a notifications regression.

### Terminal events un-dismiss running uploads

`frontend/src/upload/uploadSlice.ts` sets `dismissed: false` when
an upload reaches `completed` or `failed`. If the user dismissed a
running task and the terminal SSE event arrives later, the toast pops
back up. Intentional ("notify the user it's done"); revisit if the
re-surface UX is too aggressive for v2.

### Werkzeug doesn't auto-reload route files

The dev server (`flask run`) doesn't watch
`application/api/events/routes.py` for changes by default.
After editing the route, restart Flask manually — `--reload`
isn't on. (Production gunicorn reloads via deploy.)

### MCP OAuth completion can fall outside the user stream's MAXLEN window

`get_oauth_status` scans up to `EVENTS_STREAM_MAXLEN` (~1000) entries
via `XREVRANGE`. If the user has a high-rate ingest running concurrently
with the OAuth handshake, the `mcp.oauth.completed` envelope can be
trimmed off the back before they click Save. Symptom: the backend
returns "OAuth failed or not completed" even though the popup completed
successfully.

Mitigation today: bump `EVENTS_STREAM_MAXLEN` per deployment if your
users routinely flood the channel during OAuth flows. A dedicated
short-TTL Redis key for OAuth task results is tracked as a follow-up.
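
For reference, the scan the status check performs amounts to walking the journal newest-first for a matching OAuth envelope, which is exactly why a trimmed entry turns into a false "not completed". A sketch under the same envelope assumptions as earlier; the real `get_oauth_status` may filter differently:

```python
import json

from redis import Redis


def find_oauth_result(r: Redis, user_id: str, task_id: str, scan_limit: int = 1000):
    """Return the newest mcp.oauth.* envelope for this task still in the journal."""
    entries = r.xrevrange(f"user:{user_id}:stream", max="+", min="-", count=scan_limit)
    for _, fields in entries:
        raw = fields.get(b"event") or fields.get("event")
        if not raw:
            continue
        envelope = json.loads(raw)
        if not str(envelope.get("type", "")).startswith("mcp.oauth."):
            continue
        if envelope.get("scope", {}).get("id") == task_id:
            return envelope  # completed / failed / awaiting_redirect
    return None  # trimmed off the back, or never published
```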

### React StrictMode double-mounts SSE

In dev, React 18 StrictMode mounts → unmounts → remounts every
component, briefly opening two SSE connections per tab before the
first is aborted. With `SSE_MAX_CONCURRENT_PER_USER=8` and 4–5
tabs open concurrently you can transiently hit the cap and see
HTTP 429 on cold load. The first connection's counter increment
fires before the AbortController-induced disconnect can decrement
it. Production (single mount, no StrictMode) is unaffected; raise
the cap in dev or accept the transient 429s.
1531 frontend/package-lock.json (generated)
File diff suppressed because it is too large
@@ -7,6 +7,8 @@
|
||||
"dev": "vite",
|
||||
"build": "tsc && vite build",
|
||||
"preview": "vite preview",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest",
|
||||
"lint": "eslint ./src --ext .jsx,.js,.ts,.tsx",
|
||||
"lint-fix": "eslint ./src --ext .jsx,.js,.ts,.tsx --fix",
|
||||
"format": "prettier ./src --write",
|
||||
@@ -69,6 +71,7 @@
|
||||
"eslint-plugin-promise": "^6.6.0",
|
||||
"eslint-plugin-react": "^7.37.5",
|
||||
"eslint-plugin-unused-imports": "^4.1.4",
|
||||
"happy-dom": "^17.6.3",
|
||||
"husky": "^9.1.7",
|
||||
"lint-staged": "^16.4.0",
|
||||
"postcss": "^8.5.12",
|
||||
@@ -78,6 +81,7 @@
|
||||
"tw-animate-css": "^1.4.0",
|
||||
"typescript": "^6.0.3",
|
||||
"vite": "^8.0.10",
|
||||
"vite-plugin-svgr": "^4.3.0"
|
||||
"vite-plugin-svgr": "^4.3.0",
|
||||
"vitest": "^3.2.4"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import Spinner from './components/Spinner';
|
||||
import UploadToast from './components/UploadToast';
|
||||
import Conversation from './conversation/Conversation';
|
||||
import { SharedConversation } from './conversation/SharedConversation';
|
||||
import { EventStreamProvider } from './events/EventStreamProvider';
|
||||
import { useDarkTheme, useMediaQuery } from './hooks';
|
||||
import useDataInitializer from './hooks/useDataInitializer';
|
||||
import useTokenAuth from './hooks/useTokenAuth';
|
||||
@@ -17,6 +18,7 @@ import Navigation from './Navigation';
|
||||
import PageNotFound from './PageNotFound';
|
||||
import Setting from './settings';
|
||||
import Notification from './components/Notification';
|
||||
import ToolApprovalToast from './notifications/ToolApprovalToast';
|
||||
|
||||
function AuthWrapper({ children }: { children: React.ReactNode }) {
|
||||
const { isAuthLoading } = useTokenAuth();
|
||||
@@ -29,7 +31,7 @@ function AuthWrapper({ children }: { children: React.ReactNode }) {
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return <>{children}</>;
|
||||
return <EventStreamProvider>{children}</EventStreamProvider>;
|
||||
}
|
||||
|
||||
function MainLayout() {
|
||||
@@ -50,6 +52,7 @@ function MainLayout() {
|
||||
<Outlet />
|
||||
</div>
|
||||
<UploadToast />
|
||||
<ToolApprovalToast />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ import Github from './assets/git_nav.svg';
|
||||
import Hamburger from './assets/hamburger.svg';
|
||||
import openNewChat from './assets/openNewChat.svg';
|
||||
import Pin from './assets/pin.svg';
|
||||
import SearchIcon from './assets/search.svg';
|
||||
import AgentImage from './components/AgentImage';
|
||||
import SettingGear from './assets/settingGear.svg';
|
||||
import Spark from './assets/spark.svg';
|
||||
@@ -36,7 +35,6 @@ import { useDarkTheme, useMediaQuery } from './hooks';
|
||||
import useTokenAuth from './hooks/useTokenAuth';
|
||||
import DeleteConvModal from './modals/DeleteConvModal';
|
||||
import JWTModal from './modals/JWTModal';
|
||||
import SearchConversationsModal from './modals/SearchConversationsModal';
|
||||
import { ActiveState } from './models/misc';
|
||||
import { getConversations } from './preferences/preferenceApi';
|
||||
import {
|
||||
@@ -84,7 +82,6 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
const [uploadModalState, setUploadModalState] =
|
||||
useState<ActiveState>('INACTIVE');
|
||||
const [recentAgents, setRecentAgents] = useState<Agent[]>([]);
|
||||
const [searchOpen, setSearchOpen] = useState(false);
|
||||
|
||||
const navRef = useRef<HTMLDivElement>(null);
|
||||
useEffect(() => {
|
||||
@@ -509,23 +506,11 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
)}
|
||||
{conversations?.data && conversations.data.length > 0 ? (
|
||||
<div className="mt-7">
|
||||
<div className="mx-4 my-auto mt-2 flex h-8 items-center justify-between gap-4 rounded-3xl">
|
||||
<div className="mx-4 my-auto mt-2 flex h-6 items-center justify-between gap-4 rounded-3xl">
|
||||
<p className="mt-1 ml-4 text-sm font-semibold">{t('chats')}</p>
|
||||
<button
|
||||
onClick={() => setSearchOpen(true)}
|
||||
className="hover:bg-sidebar-accent mr-2 flex h-7 w-7 items-center justify-center rounded-full"
|
||||
aria-label={t('modals.searchConversations.searchPlaceholder')}
|
||||
title={t('modals.searchConversations.searchPlaceholder')}
|
||||
>
|
||||
<img
|
||||
src={SearchIcon}
|
||||
alt="search"
|
||||
className="h-4 w-4 opacity-70"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
<div className="conversations-container">
|
||||
{(conversations.data ?? []).map((conversation) => (
|
||||
{conversations.data?.map((conversation) => (
|
||||
<ConversationTile
|
||||
key={conversation.id}
|
||||
conversation={conversation}
|
||||
@@ -659,17 +644,6 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
modalState={showTokenModal ? 'ACTIVE' : 'INACTIVE'}
|
||||
handleTokenSubmit={handleTokenSubmit}
|
||||
/>
|
||||
{searchOpen && (
|
||||
<SearchConversationsModal
|
||||
close={() => setSearchOpen(false)}
|
||||
conversations={conversations?.data ?? []}
|
||||
token={token}
|
||||
onSelectConversation={(id) => {
|
||||
handleConversationClick(id);
|
||||
if (isMobile || isTablet) setNavOpen(false);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@ const endpoints = {
|
||||
UPDATE_PROMPT: '/api/update_prompt',
|
||||
SINGLE_PROMPT: (id: string) => `/api/get_single_prompt?id=${id}`,
|
||||
DELETE_PATH: (docPath: string) => `/api/delete_old?source_id=${docPath}`,
|
||||
TASK_STATUS: (task_id: string) => `/api/task_status?task_id=${task_id}`,
|
||||
MESSAGE_ANALYTICS: '/api/get_message_analytics',
|
||||
TOKEN_ANALYTICS: '/api/get_token_analytics',
|
||||
FEEDBACK_ANALYTICS: '/api/get_feedback_analytics',
|
||||
@@ -73,8 +72,6 @@ const endpoints = {
|
||||
MANAGE_SOURCE_FILES: '/api/manage_source_files',
|
||||
MCP_TEST_CONNECTION: '/api/mcp_server/test',
|
||||
MCP_SAVE_SERVER: '/api/mcp_server/save',
|
||||
MCP_OAUTH_STATUS: (task_id: string) =>
|
||||
`/api/mcp_server/oauth_status/${task_id}`,
|
||||
MCP_AUTH_STATUS: '/api/mcp_server/auth_status',
|
||||
AGENT_FOLDERS: '/api/agents/folders/',
|
||||
AGENT_FOLDER: (id: string) => `/api/agents/folders/${id}`,
|
||||
@@ -98,8 +95,6 @@ const endpoints = {
|
||||
FEEDBACK: '/api/feedback',
|
||||
CONVERSATION: (id: string) => `/api/get_single_conversation?id=${id}`,
|
||||
CONVERSATIONS: '/api/get_conversations',
|
||||
SEARCH_CONVERSATIONS: (q: string, limit = 30) =>
|
||||
`/api/search_conversations?q=${encodeURIComponent(q)}&limit=${limit}`,
|
||||
MESSAGE_TAIL: (messageId: string) => `/api/messages/${messageId}/tail`,
|
||||
SHARE_CONVERSATION: (isPromptable: boolean) =>
|
||||
`/api/share?isPromptable=${isPromptable}`,
|
||||
|
||||
@@ -32,16 +32,6 @@ const conversationService = {
|
||||
apiClient.get(endpoints.CONVERSATION.MESSAGE_TAIL(messageId), token, {}),
|
||||
getConversations: (token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.CONVERSATION.CONVERSATIONS, token, {}),
|
||||
searchConversations: (
|
||||
query: string,
|
||||
token: string | null,
|
||||
limit = 30,
|
||||
): Promise<any> =>
|
||||
apiClient.get(
|
||||
endpoints.CONVERSATION.SEARCH_CONVERSATIONS(query, limit),
|
||||
token,
|
||||
{},
|
||||
),
|
||||
shareConversation: (
|
||||
isPromptable: boolean,
|
||||
data: any,
|
||||
|
||||
@@ -61,8 +61,6 @@ const userService = {
|
||||
apiClient.get(endpoints.USER.SINGLE_PROMPT(id), token),
|
||||
deletePath: (docPath: string, token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.DELETE_PATH(docPath), token),
|
||||
getTaskStatus: (task_id: string, token: string | null): Promise<any> =>
|
||||
throttledApiClient.get(endpoints.USER.TASK_STATUS(task_id), token),
|
||||
getMessageAnalytics: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.MESSAGE_ANALYTICS, data, token),
|
||||
getTokenAnalytics: (data: any, token: string | null): Promise<any> =>
|
||||
@@ -172,8 +170,6 @@ const userService = {
|
||||
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
|
||||
saveMCPServer: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
|
||||
getMCPOAuthStatus: (task_id: string, token: string | null): Promise<any> =>
|
||||
throttledApiClient.get(endpoints.USER.MCP_OAUTH_STATUS(task_id), token),
|
||||
getMCPAuthStatus: (token: string | null): Promise<any> =>
|
||||
throttledApiClient.get(endpoints.USER.MCP_AUTH_STATUS, token),
|
||||
syncConnector: (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useRef } from 'react';
|
||||
import React, { useEffect, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
|
||||
@@ -32,13 +32,24 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
const [isDarkTheme] = useDarkTheme();
|
||||
const completedRef = useRef(false);
|
||||
const intervalRef = useRef<number | null>(null);
|
||||
const authWindowRef = useRef<Window | null>(null);
|
||||
// Hold the exact listener identity so unmount cleanup removes the same fn.
|
||||
const messageHandlerRef = useRef<((event: MessageEvent) => void) | null>(
|
||||
null,
|
||||
);
|
||||
// Tracks mount status so async ``fetch`` resolves after unmount don't
|
||||
// call ``onSuccess`` / ``onError`` on a vanished parent.
|
||||
const mountedRef = useRef(true);
|
||||
|
||||
const cleanup = () => {
|
||||
if (intervalRef.current) {
|
||||
clearInterval(intervalRef.current);
|
||||
intervalRef.current = null;
|
||||
}
|
||||
window.removeEventListener('message', handleAuthMessage as any);
|
||||
if (messageHandlerRef.current) {
|
||||
window.removeEventListener('message', messageHandlerRef.current as any);
|
||||
messageHandlerRef.current = null;
|
||||
}
|
||||
};
|
||||
|
||||
const handleAuthMessage = (event: MessageEvent) => {
|
||||
@@ -49,6 +60,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
if (successGeneric || successProvider) {
|
||||
completedRef.current = true;
|
||||
cleanup();
|
||||
authWindowRef.current = null;
|
||||
onSuccess({
|
||||
session_token: event.data.session_token,
|
||||
user_email:
|
||||
@@ -58,6 +70,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
} else if (errorProvider) {
|
||||
completedRef.current = true;
|
||||
cleanup();
|
||||
authWindowRef.current = null;
|
||||
onError(
|
||||
event.data.error || t('modals.uploadDoc.connectors.auth.authFailed'),
|
||||
);
|
||||
@@ -67,12 +80,20 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
const handleAuth = async () => {
|
||||
try {
|
||||
completedRef.current = false;
|
||||
// Close any popup left over from a previous click before wiping
|
||||
// the ref — otherwise the old window keeps living with no
|
||||
// interval watching it and no listener handling its messages.
|
||||
if (authWindowRef.current && !authWindowRef.current.closed) {
|
||||
authWindowRef.current.close();
|
||||
}
|
||||
authWindowRef.current = null;
|
||||
cleanup();
|
||||
|
||||
const authResponse = await userService.getConnectorAuthUrl(
|
||||
provider,
|
||||
token,
|
||||
);
|
||||
if (!mountedRef.current) return;
|
||||
|
||||
if (!authResponse.ok) {
|
||||
throw new Error(
|
||||
@@ -81,6 +102,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
}
|
||||
|
||||
const authData = await authResponse.json();
|
||||
if (!mountedRef.current) return;
|
||||
if (!authData.success || !authData.authorization_url) {
|
||||
throw new Error(
|
||||
authData.error || t('modals.uploadDoc.connectors.auth.authUrlFailed'),
|
||||
@@ -95,13 +117,23 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
if (!authWindow) {
|
||||
throw new Error(t('modals.uploadDoc.connectors.auth.popupBlocked'));
|
||||
}
|
||||
authWindowRef.current = authWindow;
|
||||
|
||||
messageHandlerRef.current = handleAuthMessage;
|
||||
window.addEventListener('message', handleAuthMessage as any);
|
||||
|
||||
const checkClosed = window.setInterval(() => {
|
||||
if (authWindow.closed) {
|
||||
clearInterval(checkClosed);
|
||||
window.removeEventListener('message', handleAuthMessage as any);
|
||||
intervalRef.current = null;
|
||||
if (messageHandlerRef.current) {
|
||||
window.removeEventListener(
|
||||
'message',
|
||||
messageHandlerRef.current as any,
|
||||
);
|
||||
messageHandlerRef.current = null;
|
||||
}
|
||||
authWindowRef.current = null;
|
||||
if (!completedRef.current) {
|
||||
onError(t('modals.uploadDoc.connectors.auth.authCancelled'));
|
||||
}
|
||||
@@ -109,6 +141,7 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
}, 1000);
|
||||
intervalRef.current = checkClosed;
|
||||
} catch (error) {
|
||||
if (!mountedRef.current) return;
|
||||
onError(
|
||||
error instanceof Error
|
||||
? error.message
|
||||
@@ -117,6 +150,18 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
}
|
||||
};
|
||||
|
||||
// Release interval, message listener, and popup on unmount only.
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
mountedRef.current = false;
|
||||
cleanup();
|
||||
if (authWindowRef.current && !authWindowRef.current.closed) {
|
||||
authWindowRef.current.close();
|
||||
}
|
||||
authWindowRef.current = null;
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
{errorMessage && (
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import React, { useEffect, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { useSelector, useStore } from 'react-redux';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import ArrowLeft from '../assets/arrow-left.svg';
|
||||
@@ -14,6 +14,7 @@ import { useLoaderState, useOutsideAlerter } from '../hooks';
|
||||
import ConfirmationModal from '../modals/ConfirmationModal';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
import type { RootState } from '../store';
|
||||
import { formatBytes } from '../utils/stringUtils';
|
||||
import Chunks from './Chunks';
|
||||
import ContextMenu, { MenuOption } from './ContextMenu';
|
||||
@@ -64,6 +65,7 @@ const ConnectorTree: React.FC<ConnectorTreeProps> = ({
|
||||
useState<DirectoryStructure | null>(null);
|
||||
const [currentPath, setCurrentPath] = useState<string[]>([]);
|
||||
const token = useSelector(selectToken);
|
||||
const store = useStore<RootState>();
|
||||
const [activeMenuId, setActiveMenuId] = useState<string | null>(null);
|
||||
const menuRefs = useRef<{
|
||||
[key: string]: React.RefObject<HTMLDivElement | null>;
|
||||
@@ -81,6 +83,25 @@ const ConnectorTree: React.FC<ConnectorTreeProps> = ({
|
||||
const [syncDone, setSyncDone] = useState<boolean>(false);
|
||||
const [syncConfirmationModal, setSyncConfirmationModal] =
|
||||
useState<ActiveState>('INACTIVE');
|
||||
const mountedRef = useRef(true);
|
||||
const syncUnsubscribeRef = useRef<(() => void) | null>(null);
|
||||
// Holds the 5-minute SSE-wait timer so the unmount cleanup can clear
|
||||
// it — otherwise the timer fires up to 5 min after unmount and
|
||||
// resolves an abandoned Promise.
|
||||
const syncTimerRef = useRef<number | null>(null);
|
||||
|
||||
useEffect(
|
||||
() => () => {
|
||||
mountedRef.current = false;
|
||||
syncUnsubscribeRef.current?.();
|
||||
syncUnsubscribeRef.current = null;
|
||||
if (syncTimerRef.current !== null) {
|
||||
window.clearTimeout(syncTimerRef.current);
|
||||
syncTimerRef.current = null;
|
||||
}
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
useOutsideAlerter(
|
||||
searchDropdownRef,
|
||||
@@ -116,67 +137,108 @@ const ConnectorTree: React.FC<ConnectorTreeProps> = ({
|
||||
console.log('Sync started successfully:', data.task_id);
|
||||
setSyncProgress(10);
|
||||
|
||||
// Poll task status using userService
|
||||
const maxAttempts = 30;
|
||||
const pollInterval = 2000;
|
||||
// The connector worker (``ingest_connector`` in
|
||||
// ``application/worker.py``) publishes
|
||||
// ``source.ingest.{queued,completed,failed}`` envelopes keyed on
|
||||
// ``scope.id == docId`` (sync mode reuses the source uuid). Wait
|
||||
// on the bounded ``notifications.recentEvents`` ring for a
|
||||
// terminal envelope rather than polling ``/task_status``.
|
||||
// Mirrors FileTree's slice-walking pattern, including the
|
||||
// ``opStartedAt`` guard so a stale terminal event from a prior
|
||||
// sync of this same source can't short-circuit the current op.
|
||||
const opStartedAt = Date.now();
|
||||
|
||||
const terminalFromSse = (): 'completed' | 'failed' | null => {
|
||||
const events = store.getState().notifications.recentEvents;
|
||||
for (const event of events) {
|
||||
if (event.scope?.id !== docId) continue;
|
||||
const ts = event.ts ? Date.parse(event.ts) : NaN;
|
||||
if (!Number.isFinite(ts) || ts < opStartedAt) continue;
|
||||
if (event.type === 'source.ingest.completed') return 'completed';
|
||||
if (event.type === 'source.ingest.failed') return 'failed';
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
const MAX_WAIT_MS = 5 * 60_000;
|
||||
const terminal = await new Promise<
|
||||
'completed' | 'failed' | 'timeout' | 'unmounted'
|
||||
>((resolve) => {
|
||||
// Cover the race where the event landed between the POST
|
||||
// returning and the subscribe call.
|
||||
const initial = terminalFromSse();
|
||||
if (initial) {
|
||||
resolve(initial);
|
||||
return;
|
||||
}
|
||||
if (!mountedRef.current) {
|
||||
resolve('unmounted');
|
||||
return;
|
||||
}
|
||||
let settled = false;
|
||||
const finish = (
|
||||
value: 'completed' | 'failed' | 'timeout' | 'unmounted',
|
||||
) => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
if (syncTimerRef.current !== null) {
|
||||
window.clearTimeout(syncTimerRef.current);
|
||||
syncTimerRef.current = null;
|
||||
}
|
||||
if (syncUnsubscribeRef.current) {
|
||||
syncUnsubscribeRef.current();
|
||||
syncUnsubscribeRef.current = null;
|
||||
}
|
||||
resolve(value);
|
||||
};
|
||||
syncTimerRef.current = window.setTimeout(
|
||||
() => finish('timeout'),
|
||||
MAX_WAIT_MS,
|
||||
);
|
||||
syncUnsubscribeRef.current = store.subscribe(() => {
|
||||
if (!mountedRef.current) {
|
||||
finish('unmounted');
|
||||
return;
|
||||
}
|
||||
const next = terminalFromSse();
|
||||
if (next) finish(next);
|
||||
});
|
||||
});
|
||||
|
||||
if (terminal === 'timeout') {
|
||||
console.error('Sync timed out waiting for SSE terminal');
|
||||
} else if (terminal === 'unmounted') {
|
||||
return;
|
||||
}
|
||||
|
||||
if (terminal === 'completed') {
|
||||
// The "no files downloaded" early-return path publishes
|
||||
// ``completed`` with ``no_changes: true`` — treated as success
|
||||
// here; refreshing the directory is cheap and idempotent.
|
||||
setSyncProgress(100);
|
||||
console.log('Sync completed successfully');
|
||||
|
||||
for (let attempt = 0; attempt < maxAttempts; attempt++) {
|
||||
try {
|
||||
const statusResponse = await userService.getTaskStatus(
|
||||
data.task_id,
|
||||
const refreshResponse = await userService.getDirectoryStructure(
|
||||
docId,
|
||||
token,
|
||||
);
|
||||
const statusData = await statusResponse.json();
|
||||
|
||||
console.log(
|
||||
`Task status (attempt ${attempt + 1}):`,
|
||||
statusData.status,
|
||||
);
|
||||
|
||||
if (statusData.status === 'SUCCESS') {
|
||||
setSyncProgress(100);
|
||||
console.log('Sync completed successfully');
|
||||
|
||||
// Refresh directory structure
|
||||
try {
|
||||
const refreshResponse = await userService.getDirectoryStructure(
|
||||
docId,
|
||||
token,
|
||||
);
|
||||
const refreshData = await refreshResponse.json();
|
||||
if (refreshData && refreshData.directory_structure) {
|
||||
setDirectoryStructure(refreshData.directory_structure);
|
||||
setCurrentPath([]);
|
||||
}
|
||||
if (refreshData && refreshData.provider) {
|
||||
setSourceProvider(refreshData.provider);
|
||||
}
|
||||
|
||||
setSyncDone(true);
|
||||
setTimeout(() => setSyncDone(false), 5000);
|
||||
} catch (err) {
|
||||
console.error('Error refreshing directory structure:', err);
|
||||
}
|
||||
break;
|
||||
} else if (statusData.status === 'FAILURE') {
|
||||
console.error('Sync task failed:', statusData.result);
|
||||
break;
|
||||
} else if (statusData.status === 'PROGRESS') {
|
||||
const progress = Number(
|
||||
statusData.result && statusData.result.current != null
|
||||
? statusData.result.current
|
||||
: statusData.meta && statusData.meta.current != null
|
||||
? statusData.meta.current
|
||||
: 0,
|
||||
);
|
||||
setSyncProgress(Math.max(10, progress));
|
||||
const refreshData = await refreshResponse.json();
|
||||
if (refreshData && refreshData.directory_structure) {
|
||||
setDirectoryStructure(refreshData.directory_structure);
|
||||
setCurrentPath([]);
|
||||
}
|
||||
if (refreshData && refreshData.provider) {
|
||||
setSourceProvider(refreshData.provider);
|
||||
}
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, pollInterval));
|
||||
} catch (error) {
|
||||
console.error('Error polling task status:', error);
|
||||
break;
|
||||
setSyncDone(true);
|
||||
setTimeout(() => setSyncDone(false), 5000);
|
||||
} catch (err) {
|
||||
console.error('Error refreshing directory structure:', err);
|
||||
}
|
||||
} else if (terminal === 'failed') {
|
||||
console.error('Sync task failed (per SSE)');
|
||||
}
|
||||
} else {
|
||||
console.error('Sync failed:', data.error);
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import React, { useState, useRef, useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { useSelector, useStore } from 'react-redux';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
import type { RootState } from '../store';
|
||||
import { formatBytes } from '../utils/stringUtils';
|
||||
import Chunks from './Chunks';
|
||||
import ContextMenu, { MenuOption } from './ContextMenu';
|
||||
@@ -56,6 +57,7 @@ const FileTree: React.FC<FileTreeProps> = ({
|
||||
onBackToDocuments,
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const store = useStore<RootState>();
|
||||
const [loading, setLoading] = useLoaderState(true, 500);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [directoryStructure, setDirectoryStructure] =
|
||||
@@ -95,6 +97,25 @@ const FileTree: React.FC<FileTreeProps> = ({
|
||||
const opQueueRef = useRef<QueuedOperation[]>([]);
|
||||
const processingRef = useRef(false);
|
||||
const [queueLength, setQueueLength] = useState(0);
|
||||
const mountedRef = useRef(true);
|
||||
const waitUnsubscribeRef = useRef<(() => void) | null>(null);
|
||||
// Holds the 5-minute SSE-wait timer so the unmount cleanup can clear
|
||||
// it — otherwise the timer fires up to 5 min after unmount and
|
||||
// resolves an abandoned Promise.
|
||||
const waitTimerRef = useRef<number | null>(null);
|
||||
|
||||
useEffect(
|
||||
() => () => {
|
||||
mountedRef.current = false;
|
||||
waitUnsubscribeRef.current?.();
|
||||
waitUnsubscribeRef.current = null;
|
||||
if (waitTimerRef.current !== null) {
|
||||
window.clearTimeout(waitTimerRef.current);
|
||||
waitTimerRef.current = null;
|
||||
}
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
useOutsideAlerter(
|
||||
searchDropdownRef,
|
||||
@@ -313,47 +334,103 @@ const FileTree: React.FC<FileTreeProps> = ({
|
||||
}
|
||||
console.log('Reingest task started:', result.reingest_task_id);
|
||||
|
||||
const maxAttempts = 30;
|
||||
const pollInterval = 2000;
|
||||
// SSE is the sole driver here. The backend's
|
||||
// ``reingest_source_worker`` publishes ``source.ingest.*``
|
||||
// keyed on the resolved ``source_id`` (the
|
||||
// ``manage_source_files`` route returns it explicitly so we
|
||||
// can match without consulting any slice). Subscribe to the
|
||||
// store and resolve when a terminal event tagged with our
|
||||
// source lands in ``notifications.recentEvents``. Re-checking
|
||||
// on every dispatch (rather than polling on a timer) avoids
|
||||
// races where a terminal could roll off the bounded ring
|
||||
// before the next tick observes it in chatty sessions.
|
||||
const reingestSourceId: string | undefined = result.source_id;
|
||||
// Cutoff so we don't pick up terminal events from a *previous*
|
||||
// reingest of the same source — the backend's
|
||||
// ``source.ingest.*`` payload doesn't carry a Celery task id,
|
||||
// so source_id alone is ambiguous when ops repeat.
|
||||
const opStartedAt = Date.now();
|
||||
const MAX_WAIT_MS = 5 * 60_000;
|
||||
|
||||
for (let attempt = 0; attempt < maxAttempts; attempt++) {
|
||||
try {
|
||||
const statusResponse = await userService.getTaskStatus(
|
||||
result.reingest_task_id,
|
||||
token,
|
||||
);
|
||||
const statusData = await statusResponse.json();
|
||||
|
||||
console.log(
|
||||
`Task status (attempt ${attempt + 1}):`,
|
||||
statusData.status,
|
||||
);
|
||||
|
||||
if (statusData.status === 'SUCCESS') {
|
||||
console.log('Task completed successfully');
|
||||
|
||||
const structureResponse = await userService.getDirectoryStructure(
|
||||
docId,
|
||||
token,
|
||||
);
|
||||
const structureData = await structureResponse.json();
|
||||
|
||||
if (structureData && structureData.directory_structure) {
|
||||
setDirectoryStructure(structureData.directory_structure);
|
||||
currentOpRef.current = null;
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
} else if (statusData.status === 'FAILURE') {
|
||||
console.error('Task failed');
|
||||
break;
|
||||
}
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, pollInterval));
|
||||
} catch (error) {
|
||||
console.error('Error polling task status:', error);
|
||||
break;
|
||||
const terminalFromSse = (): 'completed' | 'failed' | null => {
|
||||
if (!reingestSourceId) return null;
|
||||
const events = store.getState().notifications.recentEvents;
|
||||
for (const event of events) {
|
||||
if (event.scope?.id !== reingestSourceId) continue;
|
||||
const ts = event.ts ? Date.parse(event.ts) : NaN;
|
||||
if (!Number.isFinite(ts) || ts < opStartedAt) continue;
|
||||
if (event.type === 'source.ingest.completed') return 'completed';
|
||||
if (event.type === 'source.ingest.failed') return 'failed';
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
const refreshStructure = async (): Promise<boolean> => {
|
||||
const structureResponse = await userService.getDirectoryStructure(
|
||||
docId,
|
||||
token,
|
||||
);
|
||||
const structureData = await structureResponse.json();
|
||||
if (!mountedRef.current) return false;
|
||||
if (structureData && structureData.directory_structure) {
|
||||
setDirectoryStructure(structureData.directory_structure);
|
||||
currentOpRef.current = null;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
const terminal = await new Promise<
|
||||
'completed' | 'failed' | 'timeout' | 'unmounted'
|
||||
>((resolve) => {
|
||||
if (!mountedRef.current) {
|
||||
resolve('unmounted');
|
||||
return;
|
||||
}
|
||||
// Cover the race where the terminal event landed between
|
||||
// the POST returning and the subscribe call.
|
||||
const initial = terminalFromSse();
|
||||
if (initial) {
|
||||
resolve(initial);
|
||||
return;
|
||||
}
|
||||
const timer = window.setTimeout(() => {
|
||||
waitUnsubscribeRef.current?.();
|
||||
waitUnsubscribeRef.current = null;
|
||||
waitTimerRef.current = null;
|
||||
resolve('timeout');
|
||||
}, MAX_WAIT_MS);
|
||||
waitTimerRef.current = timer;
|
||||
waitUnsubscribeRef.current = store.subscribe(() => {
|
||||
if (!mountedRef.current) {
|
||||
window.clearTimeout(timer);
|
||||
waitTimerRef.current = null;
|
||||
waitUnsubscribeRef.current?.();
|
||||
waitUnsubscribeRef.current = null;
|
||||
resolve('unmounted');
|
||||
return;
|
||||
}
|
||||
const next = terminalFromSse();
|
||||
if (next) {
|
||||
window.clearTimeout(timer);
|
||||
waitTimerRef.current = null;
|
||||
waitUnsubscribeRef.current?.();
|
||||
waitUnsubscribeRef.current = null;
|
||||
resolve(next);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
if (!mountedRef.current) return false;
|
||||
|
||||
if (terminal === 'completed') {
|
||||
if (await refreshStructure()) return true;
|
||||
} else if (terminal === 'failed') {
|
||||
console.error('Reingest task failed (per SSE)');
|
||||
} else if (terminal === 'unmounted') {
|
||||
return false;
|
||||
} else {
|
||||
console.error('Reingest timed out waiting for SSE terminal');
|
||||
}
|
||||
} else {
|
||||
throw new Error(
|
||||
@@ -374,7 +451,7 @@ const FileTree: React.FC<FileTreeProps> = ({
|
||||
? 'delete directory'
|
||||
: 'delete file(s)';
|
||||
console.error(`Error ${actionText}:`, error);
|
||||
setError(`Failed to ${errorText}`);
|
||||
if (mountedRef.current) setError(`Failed to ${errorText}`);
|
||||
} finally {
|
||||
currentOpRef.current = null;
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ import { createPortal } from 'react-dom';
|
||||
import { LoaderCircle, Mic, Square } from 'lucide-react';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { useDispatch, useSelector, useStore } from 'react-redux';
|
||||
|
||||
import endpoints from '../api/endpoints';
|
||||
import userService from '../api/services/userService';
|
||||
@@ -28,6 +28,7 @@ import {
|
||||
selectSelectedDocs,
|
||||
selectToken,
|
||||
} from '../preferences/preferenceSlice';
|
||||
import type { RootState } from '../store';
|
||||
import Upload from '../upload/Upload';
|
||||
import { getOS, isTouchDevice } from '../utils/browserUtils';
|
||||
import SourcesPopup from './SourcesPopup';
|
||||
@@ -316,6 +317,7 @@ export default function MessageInput({
|
||||
const attachments = useSelector(selectAttachments);
|
||||
|
||||
const dispatch = useDispatch();
|
||||
const store = useStore<RootState>();
|
||||
const mediaStreamRef = useRef<MediaStream | null>(null);
|
||||
const audioContextRef = useRef<AudioContext | null>(null);
|
||||
const audioSourceNodeRef = useRef<MediaStreamAudioSourceNode | null>(null);
|
||||
@@ -410,6 +412,86 @@ export default function MessageInput({
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Recover the race where attachment.* SSE arrives before the upload
|
||||
// XHR's onload sets ``attachmentId``: walk recentEvents and watchdog
|
||||
// the row so it can't stay stuck on 'processing'. Mirrors
|
||||
// Upload.tsx's ``trackTraining``.
|
||||
const trackAttachment = useCallback(
|
||||
(clientId: string, attachmentId: string) => {
|
||||
let handled = false;
|
||||
|
||||
const check = () => {
|
||||
const state = store.getState();
|
||||
const row = state.upload.attachments.find((a) => a.id === clientId);
|
||||
if (!row) return true; // removed by user; stop tracking
|
||||
if (row.status === 'completed' || row.status === 'failed') {
|
||||
handled = true;
|
||||
return true;
|
||||
}
|
||||
for (const event of state.notifications.recentEvents) {
|
||||
if (event.scope?.id !== attachmentId) continue;
|
||||
if (event.type === 'attachment.completed') {
|
||||
const payload = (event.payload || {}) as Record<string, unknown>;
|
||||
const tokenCount = Number(payload.token_count);
|
||||
handled = true;
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: clientId,
|
||||
updates: {
|
||||
status: 'completed',
|
||||
progress: 100,
|
||||
...(Number.isFinite(tokenCount)
|
||||
? { token_count: tokenCount }
|
||||
: {}),
|
||||
},
|
||||
}),
|
||||
);
|
||||
return true;
|
||||
}
|
||||
if (event.type === 'attachment.failed') {
|
||||
handled = true;
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: clientId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
if (check()) return;
|
||||
const MAX_WAIT_MS = 5 * 60_000;
|
||||
let unsubscribe: (() => void) | null = null;
|
||||
const timer = window.setTimeout(() => {
|
||||
unsubscribe?.();
|
||||
if (!handled) {
|
||||
handled = true;
|
||||
console.warn(
|
||||
'trackAttachment: timed out waiting for terminal SSE',
|
||||
clientId,
|
||||
attachmentId,
|
||||
);
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: clientId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
}
|
||||
}, MAX_WAIT_MS);
|
||||
unsubscribe = store.subscribe(() => {
|
||||
if (check()) {
|
||||
window.clearTimeout(timer);
|
||||
unsubscribe?.();
|
||||
}
|
||||
});
|
||||
},
|
||||
[dispatch, store],
|
||||
);
|
||||
|
||||
const uploadFiles = useCallback(
|
||||
(files: File[]) => {
|
||||
if (!files || files.length === 0) return;
|
||||
@@ -510,11 +592,19 @@ export default function MessageInput({
|
||||
id: uiId,
|
||||
updates: {
|
||||
taskId: task.task_id,
|
||||
// Stash the server's attachment id so SSE
|
||||
// ``attachment.*`` events can match this
|
||||
// row by ``scope.id`` and drive the
|
||||
// per-attachment push-fresh poll gate.
|
||||
attachmentId: task.attachment_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
if (task.attachment_id) {
|
||||
trackAttachment(uiId, task.attachment_id);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -545,11 +635,15 @@ export default function MessageInput({
|
||||
id: uiId,
|
||||
updates: {
|
||||
taskId: t.task_id,
|
||||
attachmentId: t.attachment_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
if (t.attachment_id) {
|
||||
trackAttachment(uiId, t.attachment_id);
|
||||
}
|
||||
} else {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
@@ -583,11 +677,15 @@ export default function MessageInput({
|
||||
id: uiId,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
attachmentId: response.attachment_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
if (response.attachment_id) {
|
||||
trackAttachment(uiId, response.attachment_id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
console.warn(
|
||||
@@ -714,11 +812,15 @@ export default function MessageInput({
|
||||
id: uniqueId,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
attachmentId: response.attachment_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
if (response.attachment_id) {
|
||||
trackAttachment(uniqueId, response.attachment_id);
|
||||
}
|
||||
} else {
|
||||
// If backend returned tasks[] for single-file, handle gracefully:
|
||||
if (
|
||||
@@ -730,11 +832,15 @@ export default function MessageInput({
|
||||
id: uniqueId,
|
||||
updates: {
|
||||
taskId: response.tasks[0].task_id,
|
||||
attachmentId: response.tasks[0].attachment_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
if (response.tasks[0].attachment_id) {
|
||||
trackAttachment(uniqueId, response.tasks[0].attachment_id);
|
||||
}
|
||||
} else {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
@@ -781,7 +887,7 @@ export default function MessageInput({
|
||||
xhr.send(formData);
|
||||
});
|
||||
},
|
||||
[dispatch, token],
|
||||
[dispatch, token, trackAttachment],
|
||||
);
|
||||
|
||||
const handleFileAttachment = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
@@ -816,65 +922,6 @@ export default function MessageInput({
|
||||
accept: FILE_UPLOAD_ACCEPT,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
const checkTaskStatus = () => {
|
||||
const processingAttachments = attachments.filter(
|
||||
(att) => att.status === 'processing' && att.taskId,
|
||||
);
|
||||
|
||||
processingAttachments.forEach((attachment) => {
|
||||
userService
|
||||
.getTaskStatus(attachment.taskId!, null)
|
||||
.then((data) => data.json())
|
||||
.then((data) => {
|
||||
if (data.status === 'SUCCESS') {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
updates: {
|
||||
status: 'completed',
|
||||
progress: 100,
|
||||
id: data.result?.attachment_id,
|
||||
token_count: data.result?.token_count,
|
||||
},
|
||||
}),
|
||||
);
|
||||
} else if (data.status === 'FAILURE') {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
} else if (data.status === 'PROGRESS' && data.result?.current) {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
updates: { progress: data.result.current },
|
||||
}),
|
||||
);
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
const interval = setInterval(() => {
|
||||
if (attachments.some((att) => att.status === 'processing')) {
|
||||
checkTaskStatus();
|
||||
}
|
||||
}, 2000);
|
||||
|
||||
return () => clearInterval(interval);
|
||||
}, [attachments, dispatch]);
|
||||
|
||||
const handleInput = useCallback(() => {
|
||||
if (inputRef.current) {
|
||||
if (window.innerWidth < 350) inputRef.current.style.height = 'auto';
|
||||
|
||||
@@ -5,41 +5,54 @@ import { useDispatch, useSelector } from 'react-redux';
|
||||
import CheckCircleFilled from '../assets/check-circle-filled.svg';
|
||||
import ChevronDown from '../assets/chevron-down.svg';
|
||||
import WarnIcon from '../assets/warn.svg';
|
||||
import { dismissUploadTask, selectUploadTasks } from '../upload/uploadSlice';
|
||||
import {
|
||||
dismissUploadTask,
|
||||
selectUploadTasks,
|
||||
type UploadTask,
|
||||
} from '../upload/uploadSlice';
|
||||
|
||||
const PROGRESS_RADIUS = 10;
|
||||
const PROGRESS_CIRCUMFERENCE = 2 * Math.PI * PROGRESS_RADIUS;
|
||||
|
||||
export default function UploadToast() {
|
||||
const [collapsedTasks, setCollapsedTasks] = useState<Record<string, boolean>>(
|
||||
{},
|
||||
);
|
||||
const IN_PROGRESS_STATUSES = new Set<UploadTask['status']>([
|
||||
'preparing',
|
||||
'uploading',
|
||||
'training',
|
||||
]);
|
||||
|
||||
const toggleTaskCollapse = (taskId: string) => {
|
||||
setCollapsedTasks((prev) => ({
|
||||
...prev,
|
||||
[taskId]: !prev[taskId],
|
||||
}));
|
||||
};
|
||||
/**
|
||||
* Single merged upload card — Google-Drive style. Multiple in-flight
|
||||
* uploads share one toast with a list of rows; the header reflects
|
||||
* the *primary* task's status (the newest still-running task, or the
|
||||
* newest task overall if all are terminal). Per-task progress lives
|
||||
* on each row.
|
||||
*
|
||||
* Dismissal: the header X dismisses every visible task at once
|
||||
* (mirrors the GDrive panel close — keeps the surface tidy without
|
||||
* per-row controls). The chevron collapses the row list.
|
||||
*/
|
||||
export default function UploadToast() {
|
||||
const [collapsed, setCollapsed] = useState(false);
|
||||
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useDispatch();
|
||||
const uploadTasks = useSelector(selectUploadTasks);
|
||||
|
||||
const getStatusHeading = (status: string) => {
|
||||
switch (status) {
|
||||
case 'preparing':
|
||||
return t('modals.uploadDoc.progress.wait');
|
||||
case 'uploading':
|
||||
return t('modals.uploadDoc.progress.upload');
|
||||
case 'training':
|
||||
return t('modals.uploadDoc.progress.upload');
|
||||
case 'completed':
|
||||
return t('modals.uploadDoc.progress.completed');
|
||||
case 'failed':
|
||||
return t('modals.uploadDoc.progress.failed');
|
||||
default:
|
||||
return t('modals.uploadDoc.progress.preparing');
|
||||
const visibleTasks = uploadTasks.filter((task) => !task.dismissed);
|
||||
if (visibleTasks.length === 0) return null;
|
||||
|
||||
// Pick the task that drives the header status: prefer a still-
|
||||
// running task (most-recent first since the slice unshifts), and
|
||||
// fall back to whatever's most-recent if everything is terminal.
|
||||
const primaryTask =
|
||||
visibleTasks.find((task) => IN_PROGRESS_STATUSES.has(task.status)) ??
|
||||
visibleTasks[0];
|
||||
|
||||
const headerLabel = getStatusHeading(primaryTask.status, t);
|
||||
|
||||
const dismissAll = () => {
|
||||
for (const task of visibleTasks) {
|
||||
dispatch(dismissUploadTask(task.id));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -47,180 +60,205 @@ export default function UploadToast() {
|
||||
<div
|
||||
className="fixed right-4 bottom-4 z-50 flex max-w-md flex-col gap-2"
|
||||
onMouseDown={(e) => e.stopPropagation()}
|
||||
role="status"
|
||||
aria-live="polite"
|
||||
aria-atomic="false"
|
||||
>
|
||||
{uploadTasks
|
||||
.filter((task) => !task.dismissed)
|
||||
.map((task) => {
|
||||
const shouldShowProgress = [
|
||||
'preparing',
|
||||
'uploading',
|
||||
'training',
|
||||
].includes(task.status);
|
||||
const rawProgress = Math.min(Math.max(task.progress ?? 0, 0), 100);
|
||||
const formattedProgress = Math.round(rawProgress);
|
||||
const progressOffset =
|
||||
PROGRESS_CIRCUMFERENCE * (1 - rawProgress / 100);
|
||||
const isCollapsed = collapsedTasks[task.id] ?? false;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={task.id}
|
||||
className={`border-border bg-card w-[271px] overflow-hidden rounded-2xl border shadow-[0px_24px_48px_0px_#00000029] transition-all duration-300`}
|
||||
<div
|
||||
className={`border-border bg-card w-[271px] overflow-hidden rounded-2xl border shadow-[0px_24px_48px_0px_#00000029] transition-all duration-300`}
|
||||
>
|
||||
<div
|
||||
className={`flex items-center justify-between px-4 py-3 ${
|
||||
primaryTask.status !== 'failed'
|
||||
? 'bg-accent/50 dark:bg-muted'
|
||||
: 'bg-destructive/10 dark:bg-destructive/10'
|
||||
}`}
|
||||
>
|
||||
<h3 className="font-inter dark:text-foreground text-[14px] leading-[16.5px] font-medium text-black">
|
||||
{headerLabel}
|
||||
</h3>
|
||||
<div className="flex items-center gap-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setCollapsed((prev) => !prev)}
|
||||
aria-label={
|
||||
collapsed
|
||||
? t('modals.uploadDoc.progress.expandDetails')
|
||||
: t('modals.uploadDoc.progress.collapseDetails')
|
||||
}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
<div
|
||||
className={`flex items-center justify-between px-4 py-3 ${
|
||||
task.status !== 'failed'
|
||||
? 'bg-accent/50 dark:bg-muted'
|
||||
: 'bg-destructive/10 dark:bg-destructive/10'
|
||||
}`}
|
||||
>
|
||||
<h3 className="font-inter dark:text-foreground text-[14px] leading-[16.5px] font-medium text-black">
|
||||
{getStatusHeading(task.status)}
|
||||
</h3>
|
||||
<div className="flex items-center gap-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => toggleTaskCollapse(task.id)}
|
||||
aria-label={
|
||||
isCollapsed
|
||||
? t('modals.uploadDoc.progress.expandDetails')
|
||||
: t('modals.uploadDoc.progress.collapseDetails')
|
||||
}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
>
|
||||
<img
|
||||
src={ChevronDown}
|
||||
alt=""
|
||||
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${
|
||||
isCollapsed ? 'rotate-180' : ''
|
||||
}`}
|
||||
/>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => dispatch(dismissUploadTask(task.id))}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
aria-label={t('modals.uploadDoc.progress.dismiss')}
|
||||
>
|
||||
<svg
|
||||
width="16"
|
||||
height="16"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-4 w-4"
|
||||
>
|
||||
<path
|
||||
d="M18 6L6 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M6 6L18 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<img
|
||||
src={ChevronDown}
|
||||
alt=""
|
||||
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${
|
||||
collapsed ? 'rotate-180' : ''
|
||||
}`}
|
||||
/>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={dismissAll}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
aria-label={t('modals.uploadDoc.progress.dismiss')}
|
||||
>
|
||||
<svg
|
||||
width="16"
|
||||
height="16"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-4 w-4"
|
||||
>
|
||||
<path
|
||||
d="M18 6L6 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M6 6L18 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
className="grid overflow-hidden transition-[grid-template-rows] duration-300 ease-out"
|
||||
style={{ gridTemplateRows: isCollapsed ? '0fr' : '1fr' }}
|
||||
>
|
||||
<div
|
||||
className={`min-h-0 overflow-hidden transition-opacity duration-300 ${
|
||||
isCollapsed ? 'opacity-0' : 'opacity-100'
|
||||
}`}
|
||||
>
|
||||
<div className="flex items-center justify-between px-5 py-3">
|
||||
<p
|
||||
className="font-inter dark:text-muted-foreground max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black"
|
||||
title={task.fileName}
|
||||
>
|
||||
{task.fileName}
|
||||
</p>
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
{shouldShowProgress && (
|
||||
<svg
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
className="h-6 w-6 shrink-0 text-[#7D54D1]"
|
||||
role="progressbar"
|
||||
aria-valuemin={0}
|
||||
aria-valuemax={100}
|
||||
aria-valuenow={formattedProgress}
|
||||
aria-label={t(
|
||||
'modals.uploadDoc.progress.uploadProgress',
|
||||
{
|
||||
progress: formattedProgress,
|
||||
},
|
||||
)}
|
||||
>
|
||||
<circle
|
||||
className="text-muted dark:text-muted-foreground/30"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
/>
|
||||
<circle
|
||||
className="text-[#7D54D1]"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeDasharray={PROGRESS_CIRCUMFERENCE}
|
||||
strokeDashoffset={progressOffset}
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
transform="rotate(-90 12 12)"
|
||||
/>
|
||||
</svg>
|
||||
)}
|
||||
|
||||
{task.status === 'completed' && (
|
||||
<img
|
||||
src={CheckCircleFilled}
|
||||
alt=""
|
||||
className="h-6 w-6 shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
|
||||
{task.status === 'failed' && (
|
||||
<img
|
||||
src={WarnIcon}
|
||||
alt=""
|
||||
className="h-6 w-6 shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{task.status === 'failed' && task.errorMessage && (
|
||||
<span className="block px-5 pb-3 text-xs text-red-500">
|
||||
{task.errorMessage}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
<div
|
||||
className="grid overflow-hidden transition-[grid-template-rows] duration-300 ease-out"
|
||||
style={{ gridTemplateRows: collapsed ? '0fr' : '1fr' }}
|
||||
>
|
||||
<div
|
||||
className={`min-h-0 overflow-hidden transition-opacity duration-300 ${
|
||||
collapsed ? 'opacity-0' : 'opacity-100'
|
||||
}`}
|
||||
>
|
||||
<ul className="max-h-72 overflow-y-auto">
|
||||
{visibleTasks.map((task) => (
|
||||
<UploadRow key={task.id} task={task} t={t} />
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function UploadRow({
|
||||
task,
|
||||
t,
|
||||
}: {
|
||||
task: UploadTask;
|
||||
t: ReturnType<typeof useTranslation>['t'];
|
||||
}) {
|
||||
const showProgress = IN_PROGRESS_STATUSES.has(task.status);
|
||||
const rawProgress = Math.min(Math.max(task.progress ?? 0, 0), 100);
|
||||
const formattedProgress = Math.round(rawProgress);
|
||||
const progressOffset = PROGRESS_CIRCUMFERENCE * (1 - rawProgress / 100);
|
||||
|
||||
return (
|
||||
<li className="border-border/50 border-b last:border-b-0">
|
||||
<div className="flex items-center justify-between px-5 py-3">
|
||||
<p
|
||||
className="font-inter dark:text-muted-foreground max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black"
|
||||
title={task.fileName}
|
||||
>
|
||||
{task.fileName}
|
||||
</p>
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
{showProgress && (
|
||||
<svg
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
className="h-6 w-6 shrink-0 text-[#7D54D1]"
|
||||
role="progressbar"
|
||||
aria-valuemin={0}
|
||||
aria-valuemax={100}
|
||||
aria-valuenow={formattedProgress}
|
||||
aria-label={t('modals.uploadDoc.progress.uploadProgress', {
|
||||
progress: formattedProgress,
|
||||
})}
|
||||
>
|
||||
<circle
|
||||
className="text-muted dark:text-muted-foreground/30"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
/>
|
||||
<circle
|
||||
className="text-[#7D54D1]"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeDasharray={PROGRESS_CIRCUMFERENCE}
|
||||
strokeDashoffset={progressOffset}
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
transform="rotate(-90 12 12)"
|
||||
/>
|
||||
</svg>
|
||||
)}
|
||||
|
||||
{task.status === 'completed' && (
|
||||
<img
|
||||
src={CheckCircleFilled}
|
||||
alt=""
|
||||
className="h-6 w-6 shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
|
||||
{task.status === 'failed' && (
|
||||
<img
|
||||
src={WarnIcon}
|
||||
alt=""
|
||||
className="h-6 w-6 shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{task.status === 'failed' &&
|
||||
(task.tokenLimitReached || task.errorMessage) && (
|
||||
<span className="block px-5 pb-3 text-xs text-red-500">
|
||||
{task.tokenLimitReached
|
||||
? t('modals.uploadDoc.progress.tokenLimit')
|
||||
: task.errorMessage}
|
||||
</span>
|
||||
)}
|
||||
</li>
|
||||
);
|
||||
}
|
||||
|
||||
function getStatusHeading(
|
||||
status: UploadTask['status'],
|
||||
t: ReturnType<typeof useTranslation>['t'],
|
||||
): string {
|
||||
switch (status) {
|
||||
case 'preparing':
|
||||
return t('modals.uploadDoc.progress.wait');
|
||||
case 'uploading':
|
||||
case 'training':
|
||||
return t('modals.uploadDoc.progress.upload');
|
||||
case 'completed':
|
||||
return t('modals.uploadDoc.progress.completed');
|
||||
case 'failed':
|
||||
return t('modals.uploadDoc.progress.failed');
|
||||
default:
|
||||
return t('modals.uploadDoc.progress.preparing');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,155 @@
|
||||
import { baseURL } from '../api/client';
|
||||
import conversationService from '../api/services/conversationService';
|
||||
import { Doc } from '../models/misc';
|
||||
import { Answer, FEEDBACK, RetrievalPayload } from './conversationModels';
|
||||
import { ToolCallsType } from './types';
|
||||
|
||||
/**
|
||||
* Mirrors the backend's ``_SEQUENCE_NO_RE`` (application/api/answer/
|
||||
* routes/messages.py) — only non-negative decimal integers are valid
|
||||
* cursors. Rejects empty strings (Number("") === 0), hex literals,
|
||||
* exponential notation, and anything else that ``Number(...)`` would
|
||||
* happily coerce.
|
||||
*/
|
||||
const _SEQUENCE_NO_RE = /^\d+$/;
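// Illustration (sketch, not used by the module): the coercions the strict
// regex rejects that Number(...) would accept.
const isValidCursor = (value: string): boolean => _SEQUENCE_NO_RE.test(value);
// isValidCursor('42')   -> true
// isValidCursor('')     -> false, while Number('') === 0
// isValidCursor('0x10') -> false, while Number('0x10') === 16
// isValidCursor('1e3')  -> false, while Number('1e3') === 1000
// isValidCursor('-1')   -> false, the backend's snapshot-failure synthetic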
|
||||
|
||||
/**
|
||||
* Drain an SSE response body, forwarding each ``data:`` line to
|
||||
* ``onData`` and tracking the most recent ``id:`` header. Returns
|
||||
* when the body ends, the signal aborts, or ``shouldStop()`` returns
|
||||
* true (e.g. a terminal ``end``/``error`` event was dispatched —
|
||||
* the reconnect endpoint is a live tail that doesn't close on its
|
||||
* own past terminal replay).
|
||||
*/
|
||||
/**
|
||||
* Convert a non-SSE pre-stream HTTP failure (e.g. ``check_usage``'s
|
||||
* 429 JSON response) into a synthetic typed ``error`` frame so the
|
||||
* caller's slice sees the actual server message instead of the
|
||||
* generic "Connection lost" synthesised when the drainer finishes
|
||||
* with zero events. Returns true if a frame was dispatched and the
|
||||
* caller should skip ``_drainSseBody`` entirely.
|
||||
*
|
||||
* SSE-shaped error bodies (``mimetype="text/event-stream"``) are
|
||||
* left alone — the drainer parses the typed ``error`` frame they
|
||||
* carry through the normal path.
|
||||
*/
|
||||
async function _handlePreStreamHttpError(
|
||||
response: Response,
|
||||
dispatch: (data: string) => void,
|
||||
): Promise<boolean> {
|
||||
if (response.ok) return false;
|
||||
const contentType = (
|
||||
response.headers.get('content-type') ?? ''
|
||||
).toLowerCase();
|
||||
if (contentType.includes('text/event-stream')) return false;
|
||||
let message: string | null = null;
|
||||
try {
|
||||
const text = await response.text();
|
||||
if (text) {
|
||||
try {
|
||||
const parsed = JSON.parse(text);
|
||||
if (parsed && typeof parsed === 'object') {
|
||||
message =
|
||||
(typeof parsed.message === 'string' && parsed.message) ||
|
||||
(typeof parsed.error === 'string' && parsed.error) ||
|
||||
(typeof parsed.detail === 'string' && parsed.detail) ||
|
||||
null;
|
||||
}
|
||||
} catch {
|
||||
message = text.slice(0, 500);
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Body already consumed or unreadable — fall through to the
|
||||
// status-line fallback below.
|
||||
}
|
||||
if (!message) {
|
||||
message = `HTTP ${response.status} ${response.statusText}`.trim();
|
||||
}
|
||||
dispatch(JSON.stringify({ type: 'error', error: message }));
|
||||
return true;
|
||||
}
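// Sketch of the intended call order (mirrors runInitialPost/runInitial
// below; startStreamSketch and doPost are illustrative names only): lift a
// non-SSE HTTP failure into a typed error frame first, and only then hand
// the body to the drainer.
async function startStreamSketch(
  doPost: () => Promise<Response>,
  signal: AbortSignal,
  dispatch: (data: string) => void,
): Promise<void> {
  const response = await doPost();
  if (await _handlePreStreamHttpError(response, dispatch)) return;
  if (!response.body) throw new Error('No response body');
  await _drainSseBody(response.body, signal, dispatch, () => undefined);
}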
|
||||
|
||||
async function _drainSseBody(
|
||||
body: ReadableStream<Uint8Array>,
|
||||
signal: AbortSignal,
|
||||
onData: (data: string) => void,
|
||||
onId: (id: number) => void,
|
||||
shouldStop?: () => boolean,
|
||||
): Promise<void> {
|
||||
const reader = body.getReader();
|
||||
const decoder = new TextDecoder('utf-8');
|
||||
let buffer = '';
|
||||
let stoppedEarly = false;
|
||||
try {
|
||||
while (true) {
|
||||
if (signal.aborted) break;
|
||||
if (shouldStop?.()) {
|
||||
stoppedEarly = true;
|
||||
break;
|
||||
}
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
// Normalise mixed line terminators so a stray CR can't smuggle
|
||||
// a record boundary inside a JSON payload.
|
||||
buffer = buffer.replace(/\r\n/g, '\n').replace(/\r/g, '\n');
|
||||
let boundary = buffer.indexOf('\n\n');
|
||||
while (boundary !== -1) {
|
||||
const record = buffer.slice(0, boundary);
|
||||
buffer = buffer.slice(boundary + 2);
|
||||
boundary = buffer.indexOf('\n\n');
|
||||
if (record.length === 0) continue;
|
||||
const dataParts: string[] = [];
|
||||
let sawDataField = false;
|
||||
for (const line of record.split('\n')) {
|
||||
if (line.length === 0) continue;
|
||||
if (line.startsWith(':')) continue; // SSE comment / keepalive
|
||||
const colonIdx = line.indexOf(':');
|
||||
const field = colonIdx === -1 ? line : line.slice(0, colonIdx);
|
||||
let value = colonIdx === -1 ? '' : line.slice(colonIdx + 1);
|
||||
if (value.startsWith(' ')) value = value.slice(1);
|
||||
if (field === 'id') {
|
||||
// Strict regex match — empty value, hex, ``-1`` (the
|
||||
// backend's terminal snapshot-failure synthetic), and
|
||||
// exponent forms are all rejected so they can't silently
|
||||
// rewrite the reconnect cursor.
|
||||
if (_SEQUENCE_NO_RE.test(value)) onId(parseInt(value, 10));
|
||||
} else if (field === 'data') {
|
||||
sawDataField = true;
|
||||
dataParts.push(value);
|
||||
}
|
||||
}
|
||||
if (!sawDataField) continue;
|
||||
const data = dataParts.join('\n').trim();
|
||||
if (data.length === 0) continue;
|
||||
onData(data);
|
||||
if (shouldStop?.()) {
|
||||
stoppedEarly = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (stoppedEarly) break;
|
||||
}
|
||||
} finally {
|
||||
if (stoppedEarly) {
|
||||
// Ask the runtime to tear the underlying response body down so
|
||||
// the server-side WSGI thread isn't pinned waiting on
|
||||
// keepalives. ``releaseLock`` alone leaves the body half-open.
|
||||
try {
|
||||
await reader.cancel();
|
||||
} catch {
|
||||
// Already errored / closed.
|
||||
}
|
||||
}
|
||||
try {
|
||||
reader.releaseLock();
|
||||
} catch {
|
||||
// Already released.
|
||||
}
|
||||
}
|
||||
}
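// Illustrative wire sample the drainer above is written against (assumed
// framing; real frames come from the answer-stream endpoints). Records are
// separated by a blank line, ':' lines are keepalive comments, and 'id:'
// carries the reconnect cursor.
const SAMPLE_SSE_RECORDS =
  'id: 7\n' +
  'data: {"type":"message_id","message_id":"m-1"}\n' +
  '\n' +
  ': keepalive\n' +
  '\n' +
  'id: 8\n' +
  'data: {"type":"end"}\n' +
  '\n';
// Draining this calls onData twice (the two JSON payloads) and onId with 7
// and then 8; the keepalive record is skipped because it carries no data
// field.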
|
||||
|
||||
export function handleFetchAnswer(
|
||||
question: string,
|
||||
signal: AbortSignal,
|
||||
@@ -143,54 +290,153 @@ export function handleFetchAnswerSteaming(
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
if (idempotencyKey) headers['Idempotency-Key'] = idempotencyKey;
|
||||
|
||||
// Per-stream state used for reconnect-after-disconnect.
|
||||
let messageId: string | null = null;
|
||||
let lastEventId: number | null = null;
|
||||
// The single JSON.parse below feeds both the message_id capture and
|
||||
// the termination flag — cheaper and stricter than substring
|
||||
// matching the wire bytes.
|
||||
let endReceived = false;
|
||||
|
||||
const dispatch = (data: string) => {
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
if (parsed && typeof parsed === 'object') {
|
||||
if (parsed.type === 'message_id' && parsed.message_id) {
|
||||
messageId = parsed.message_id;
|
||||
} else if (parsed.type === 'end' || parsed.type === 'error') {
|
||||
endReceived = true;
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Not JSON — pass through anyway; the caller handles raw lines.
|
||||
}
|
||||
onEvent(new MessageEvent('message', { data }));
|
||||
};
|
||||
|
||||
const runInitialPost = async (): Promise<void> => {
|
||||
const response = await conversationService.answerStream(
|
||||
payload,
|
||||
token,
|
||||
signal,
|
||||
headers,
|
||||
);
|
||||
// Pre-stream HTTP failures with non-SSE bodies (e.g. ``check_usage``
|
||||
// returning a JSON 429) drain as zero events and would otherwise
|
||||
// be masked by the generic "Connection lost" synthetic. Convert
|
||||
// them into a typed ``error`` frame so the real message surfaces.
|
||||
if (await _handlePreStreamHttpError(response, dispatch)) return;
|
||||
if (!response.body) throw new Error('No response body');
|
||||
await _drainSseBody(response.body, signal, dispatch, (id) => {
|
||||
lastEventId = id;
|
||||
});
|
||||
};
|
||||
|
||||
// Reconnect's stop predicate: as soon as ``dispatch`` flips
|
||||
// ``endReceived`` (typed ``end`` or ``error`` event seen — both
|
||||
// are terminal per the backend's contract). Without this the
|
||||
// live-tail endpoint would emit keepalives indefinitely and the
|
||||
// await would never return.
|
||||
const reconnectShouldStop = () => endReceived;
|
||||
|
||||
const runReconnect = async (): Promise<void> => {
|
||||
if (!messageId) {
|
||||
throw new Error('reconnect: no message_id captured');
|
||||
}
|
||||
const url = new URL(`${baseURL}/api/messages/${messageId}/events`);
|
||||
if (lastEventId !== null) {
|
||||
url.searchParams.set('last_event_id', String(lastEventId));
|
||||
}
|
||||
const reconnectHeaders: Record<string, string> = {
|
||||
Accept: 'text/event-stream',
|
||||
};
|
||||
if (token) reconnectHeaders.Authorization = `Bearer ${token}`;
|
||||
// NB: there is no slice consumer for a synthetic ``reconnecting``
|
||||
// event yet — surface only the underlying network reality. The
|
||||
// user-visible ``Reconnecting…`` affordance is a follow-up that
|
||||
// needs ``conversationSlice`` to gain a status case.
|
||||
const response = await fetch(url.toString(), {
|
||||
method: 'GET',
|
||||
headers: reconnectHeaders,
|
||||
signal,
|
||||
cache: 'no-store',
|
||||
});
|
||||
if (!response.ok || !response.body) {
|
||||
throw new Error(
|
||||
`reconnect: HTTP ${response.status} ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
await _drainSseBody(
|
||||
response.body,
|
||||
signal,
|
||||
dispatch,
|
||||
(id) => {
|
||||
lastEventId = id;
|
||||
},
|
||||
reconnectShouldStop,
|
||||
);
|
||||
};
|
||||
|
||||
return new Promise<Answer>((resolve, reject) => {
|
||||
conversationService
|
||||
.answerStream(payload, token, signal, headers)
|
||||
.then((response) => {
|
||||
if (!response.body) throw Error('No response body');
|
||||
|
||||
let buffer = '';
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder('utf-8');
|
||||
let counterrr = 0;
|
||||
const processStream = ({
|
||||
done,
|
||||
value,
|
||||
}: ReadableStreamReadResult<Uint8Array>) => {
|
||||
if (done) return;
|
||||
|
||||
counterrr += 1;
|
||||
|
||||
const chunk = decoder.decode(value);
|
||||
buffer += chunk;
|
||||
|
||||
const events = buffer.split('\n\n');
|
||||
buffer = events.pop() ?? '';
|
||||
|
||||
for (const event of events) {
|
||||
if (event.trim().startsWith('data:')) {
|
||||
const dataLine: string = event
|
||||
.split('\n')
|
||||
.map((line: string) => line.replace(/^data:\s?/, ''))
|
||||
.join('');
|
||||
|
||||
const messageEvent = new MessageEvent('message', {
|
||||
data: dataLine.trim(),
|
||||
});
|
||||
|
||||
onEvent(messageEvent);
|
||||
}
|
||||
(async () => {
|
||||
try {
|
||||
try {
|
||||
await runInitialPost();
|
||||
} catch (initialErr) {
|
||||
// Mid-stream network failures (WiFi blip, worker recycle,
|
||||
// body reader rejecting) surface as a thrown error — not a
|
||||
// graceful EOF. If the stream had already started (we have a
|
||||
// ``messageId``), fall through to the reconnect path so the
|
||||
// journal-backed replay can finish what the live socket
|
||||
// couldn't. Pre-stream failures (auth, DNS, server 4xx/5xx
|
||||
// before any yield) lack a messageId and bubble up.
|
||||
if (signal.aborted || !messageId) throw initialErr;
|
||||
console.warn(
|
||||
'Initial stream failed mid-flight, attempting reconnect:',
|
||||
initialErr,
|
||||
);
|
||||
}
|
||||
// The backend ends the stream cleanly with a typed ``end``
|
||||
// event. Anything else (network drop, gunicorn worker recycle,
|
||||
// load-balancer timeout) is a "premature close" — try one
|
||||
// reconnect via the GET /api/messages/<id>/events endpoint.
|
||||
if (!endReceived && !signal.aborted && messageId) {
|
||||
try {
|
||||
await runReconnect();
|
||||
} catch (reconnectErr) {
|
||||
console.warn('Stream reconnect failed:', reconnectErr);
|
||||
}
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
};
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
})
|
||||
.catch((error) => {
|
||||
}
|
||||
// If we never observed a terminal frame (reconnect 4xx/5xx,
|
||||
// network drop during reconnect, or live tail still silent),
|
||||
// synthesize one through the same ``dispatch`` path the wire
|
||||
// events use. Without this the caller's slice never transitions
|
||||
// out of ``streaming`` and the UI stays in a loading spinner
|
||||
// forever — the conversationSlice handles ``data.type === 'error'``
|
||||
// by setting status=failed.
|
||||
if (!endReceived && !signal.aborted) {
|
||||
dispatch(
|
||||
JSON.stringify({
|
||||
type: 'error',
|
||||
error:
|
||||
'Connection lost. The response could not be resumed; please try again.',
|
||||
}),
|
||||
);
|
||||
}
|
||||
// The handler historically never explicitly resolved with a
|
||||
// value — callers consume the streamed events via ``onEvent``
|
||||
// and read final state from Redux. Preserve that contract.
|
||||
resolve(undefined as unknown as Answer);
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
resolve(undefined as unknown as Answer);
|
||||
return;
|
||||
}
|
||||
console.error('Connection failed:', error);
|
||||
reject(error);
|
||||
});
|
||||
}
|
||||
})();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -214,52 +460,149 @@ export function handleSubmitToolActions(
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
if (idempotencyKey) headers['Idempotency-Key'] = idempotencyKey;
|
||||
|
||||
// Tool-action submissions resume against the original
|
||||
// ``reserved_message_id``, so the backend's continuation path emits
|
||||
// ``id:`` prefixed records that the legacy parser would silently
|
||||
// drop. Use the shared SSE drainer — and the same reconnect-on-
|
||||
// premature-close pattern as ``handleFetchAnswerSteaming`` so a
|
||||
// dropped tool-action stream can pick up after the disconnect.
|
||||
let messageId: string | null = null;
|
||||
let lastEventId: number | null = null;
|
||||
|
||||
// Track whether the typed ``end`` event was observed. The single
|
||||
// JSON.parse below feeds both the message_id capture and the
|
||||
// termination flag — cheaper and stricter than substring matching
|
||||
// the wire bytes.
|
||||
let endReceived = false;
|
||||
|
||||
const dispatch = (data: string) => {
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
if (parsed && typeof parsed === 'object') {
|
||||
if (parsed.type === 'message_id' && parsed.message_id) {
|
||||
messageId = parsed.message_id;
|
||||
} else if (parsed.type === 'end' || parsed.type === 'error') {
|
||||
// Match the backend's terminal set in
|
||||
// ``application/streaming/event_replay.py``: the agent's
|
||||
// catch-all path emits ``error`` *without* a trailing
|
||||
// ``end``, so treating only ``end`` as terminal would
|
||||
// trigger a reconnect against an already-finished stream
|
||||
// and hang on keepalives.
|
||||
endReceived = true;
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Not JSON — pass through anyway; the caller handles raw lines.
|
||||
}
|
||||
onEvent(new MessageEvent('message', { data }));
|
||||
};
|
||||
|
||||
const runInitial = async (): Promise<void> => {
|
||||
const response = await conversationService.answerStream(
|
||||
payload,
|
||||
token,
|
||||
signal,
|
||||
headers,
|
||||
);
|
||||
// See ``handleFetchAnswerSteaming`` for the rationale: non-SSE
|
||||
// HTTP failures (e.g. ``check_usage`` 429 JSON) need to be lifted
|
||||
// into a typed ``error`` frame before they reach the drainer.
|
||||
if (await _handlePreStreamHttpError(response, dispatch)) return;
|
||||
if (!response.body) throw new Error('No response body');
|
||||
await _drainSseBody(response.body, signal, dispatch, (id) => {
|
||||
lastEventId = id;
|
||||
});
|
||||
};
|
||||
|
||||
// Reconnect's stop predicate: as soon as ``dispatch`` flips
|
||||
// ``endReceived`` (typed ``end`` or ``error`` event seen — both
|
||||
// are terminal per the backend's contract). Without this the
|
||||
// live-tail endpoint would emit keepalives indefinitely and the
|
||||
// await would never return.
|
||||
const reconnectShouldStop = () => endReceived;
|
||||
|
||||
const runReconnect = async (): Promise<void> => {
|
||||
if (!messageId) {
|
||||
throw new Error('reconnect: no message_id captured');
|
||||
}
|
||||
const url = new URL(`${baseURL}/api/messages/${messageId}/events`);
|
||||
if (lastEventId !== null) {
|
||||
url.searchParams.set('last_event_id', String(lastEventId));
|
||||
}
|
||||
const reconnectHeaders: Record<string, string> = {
|
||||
Accept: 'text/event-stream',
|
||||
};
|
||||
if (token) reconnectHeaders.Authorization = `Bearer ${token}`;
|
||||
const response = await fetch(url.toString(), {
|
||||
method: 'GET',
|
||||
headers: reconnectHeaders,
|
||||
signal,
|
||||
cache: 'no-store',
|
||||
});
|
||||
if (!response.ok || !response.body) {
|
||||
throw new Error(
|
||||
`reconnect: HTTP ${response.status} ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
await _drainSseBody(
|
||||
response.body,
|
||||
signal,
|
||||
dispatch,
|
||||
(id) => {
|
||||
lastEventId = id;
|
||||
},
|
||||
reconnectShouldStop,
|
||||
);
|
||||
};
|
||||
|
||||
return new Promise<Answer>((resolve, reject) => {
|
||||
conversationService
|
||||
.answerStream(payload, token, signal, headers)
|
||||
.then((response) => {
|
||||
if (!response.body) throw Error('No response body');
|
||||
|
||||
let buffer = '';
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder('utf-8');
|
||||
|
||||
const processStream = ({
|
||||
done,
|
||||
value,
|
||||
}: ReadableStreamReadResult<Uint8Array>) => {
|
||||
if (done) return;
|
||||
|
||||
const chunk = decoder.decode(value);
|
||||
buffer += chunk;
|
||||
|
||||
const events = buffer.split('\n\n');
|
||||
buffer = events.pop() ?? '';
|
||||
|
||||
for (const event of events) {
|
||||
if (event.trim().startsWith('data:')) {
|
||||
const dataLine: string = event
|
||||
.split('\n')
|
||||
.map((line: string) => line.replace(/^data:\s?/, ''))
|
||||
.join('');
|
||||
|
||||
const messageEvent = new MessageEvent('message', {
|
||||
data: dataLine.trim(),
|
||||
});
|
||||
|
||||
onEvent(messageEvent);
|
||||
}
|
||||
(async () => {
|
||||
try {
|
||||
try {
|
||||
await runInitial();
|
||||
} catch (initialErr) {
|
||||
// Same premature-close handling as
|
||||
// ``handleFetchAnswerSteaming``: a thrown reader error after
|
||||
// the message_id frame still warrants one reconnect attempt
|
||||
// against the journal. Pre-stream failures lack a messageId
|
||||
// and bubble up.
|
||||
if (signal.aborted || !messageId) throw initialErr;
|
||||
console.warn(
|
||||
'Tool-actions stream failed mid-flight, attempting reconnect:',
|
||||
initialErr,
|
||||
);
|
||||
}
|
||||
if (!endReceived && !signal.aborted && messageId) {
|
||||
try {
|
||||
await runReconnect();
|
||||
} catch (reconnectErr) {
|
||||
console.warn('Tool-actions reconnect failed:', reconnectErr);
|
||||
}
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
};
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
})
|
||||
.catch((error) => {
|
||||
}
|
||||
// Synthesize a terminal error if reconnect couldn't deliver one
|
||||
// (4xx/5xx, network drop, silent live tail). Same reasoning as
|
||||
// ``handleFetchAnswerSteaming``: the caller's slice only exits
|
||||
// the streaming state on a terminal frame.
|
||||
if (!endReceived && !signal.aborted) {
|
||||
dispatch(
|
||||
JSON.stringify({
|
||||
type: 'error',
|
||||
error:
|
||||
'Connection lost. The tool response could not be resumed; please try again.',
|
||||
}),
|
||||
);
|
||||
}
|
||||
resolve(undefined as unknown as Answer);
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
resolve(undefined as unknown as Answer);
|
||||
return;
|
||||
}
|
||||
console.error('Tool actions submission failed:', error);
|
||||
reject(error);
|
||||
});
|
||||
}
|
||||
})();
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
153
frontend/src/conversation/conversationSlice.test.ts (new file)
@@ -0,0 +1,153 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import reducer, {
|
||||
applyMessageTail,
|
||||
setConversation,
|
||||
} from './conversationSlice';
|
||||
|
||||
const baseQuery = {
|
||||
prompt: 'tell me a poem',
|
||||
messageId: 'm-1',
|
||||
messageStatus: 'pending' as const,
|
||||
};
|
||||
|
||||
const seedSlice = () => reducer(undefined, setConversation([baseQuery]));
|
||||
|
||||
describe('applyMessageTail — streaming partial', () => {
|
||||
it('writes response to the query while status is streaming', () => {
|
||||
const state = seedSlice();
|
||||
const next = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'Hello, par',
|
||||
thought: null,
|
||||
sources: [],
|
||||
tool_calls: [],
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(next.queries[0].messageStatus).toBe('streaming');
|
||||
expect(next.queries[0].response).toBe('Hello, par');
|
||||
});
|
||||
|
||||
it('updates response on each successive tail tick', () => {
|
||||
let state = seedSlice();
|
||||
state = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'Hello',
|
||||
sources: [],
|
||||
tool_calls: [],
|
||||
},
|
||||
}),
|
||||
);
|
||||
state = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'Hello, world',
|
||||
sources: [],
|
||||
tool_calls: [],
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(state.queries[0].response).toBe('Hello, world');
|
||||
});
|
||||
|
||||
it('applies sources and tool_calls when they appear mid-stream', () => {
|
||||
const state = seedSlice();
|
||||
const next = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'partial',
|
||||
sources: [{ id: 's1', title: 'doc' }],
|
||||
tool_calls: [{ name: 'search' }],
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(next.queries[0].sources).toEqual([{ id: 's1', title: 'doc' }]);
|
||||
expect(next.queries[0].tool_calls).toEqual([{ name: 'search' }]);
|
||||
});
|
||||
|
||||
it('ignores empty sources / tool_calls arrays so the renderer stays clean', () => {
|
||||
const state = seedSlice();
|
||||
const next = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'partial',
|
||||
sources: [],
|
||||
tool_calls: [],
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(next.queries[0].sources).toBeUndefined();
|
||||
expect(next.queries[0].tool_calls).toBeUndefined();
|
||||
});
|
||||
|
||||
it('promotes to complete with the final response and clears any error', () => {
|
||||
let state = seedSlice();
|
||||
state = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'partial',
|
||||
},
|
||||
}),
|
||||
);
|
||||
state = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'complete',
|
||||
response: 'Final answer.',
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(state.queries[0].messageStatus).toBe('complete');
|
||||
expect(state.queries[0].response).toBe('Final answer.');
|
||||
expect(state.queries[0].error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('surfaces failed status as error and clears response', () => {
|
||||
const state = seedSlice();
|
||||
const next = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'failed',
|
||||
response: 'whatever',
|
||||
error: 'worker died',
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(next.queries[0].messageStatus).toBe('failed');
|
||||
expect(next.queries[0].error).toBe('worker died');
|
||||
expect(next.queries[0].response).toBeUndefined();
|
||||
});
|
||||
});
|
||||
@@ -957,20 +957,34 @@ export const conversationSlice = createSlice({
|
||||
const status = tail?.status as MessageStatus | undefined;
|
||||
query.messageStatus = status;
|
||||
query.lastHeartbeatAt = tail?.last_heartbeat_at ?? query.lastHeartbeatAt;
|
||||
if (status === 'complete') {
|
||||
query.response = tail?.response ?? '';
|
||||
query.thought = tail?.thought ?? query.thought;
|
||||
query.sources = tail?.sources ?? query.sources;
|
||||
query.tool_calls = tail?.tool_calls ?? query.tool_calls;
|
||||
delete query.error;
|
||||
} else if (status === 'failed') {
|
||||
if (status === 'failed') {
|
||||
// Surface as error so the placeholder text never renders.
|
||||
query.error =
|
||||
(typeof tail?.error === 'string' && tail.error) ||
|
||||
'Generation failed before completing.';
|
||||
delete query.response;
|
||||
return;
|
||||
}
|
||||
// /tail returns reconstructed partials mid-stream so a second tab
|
||||
// can render the in-flight bubble; spinner is driven by status.
|
||||
const incomingResponse = tail?.response;
|
||||
if (typeof incomingResponse === 'string') {
|
||||
query.response = incomingResponse;
|
||||
} else if (status === 'complete') {
|
||||
query.response = '';
|
||||
}
|
||||
if (typeof tail?.thought === 'string') {
|
||||
query.thought = tail.thought;
|
||||
}
|
||||
if (Array.isArray(tail?.sources) && tail.sources.length > 0) {
|
||||
query.sources = tail.sources;
|
||||
}
|
||||
if (Array.isArray(tail?.tool_calls) && tail.tool_calls.length > 0) {
|
||||
query.tool_calls = tail.tool_calls;
|
||||
}
|
||||
if (status === 'complete') {
|
||||
delete query.error;
|
||||
}
|
||||
// pending / streaming: untouched; spinner keeps showing.
|
||||
},
|
||||
raiseError(
|
||||
state,
|
||||
|
||||
18
frontend/src/events/EventStreamProvider.tsx (new file)
@@ -0,0 +1,18 @@
import React from 'react';

import { useEventStream } from './useEventStream';

/**
 * Mount-once provider that opens the user's SSE connection. Place
 * inside ``AuthWrapper`` so it sees a populated token, and wrap the
 * authenticated-app subtree so the connection lives for the user's
 * whole session.
 */
export function EventStreamProvider({
  children,
}: {
  children: React.ReactNode;
}): React.ReactElement {
  useEventStream();
  return <>{children}</>;
}
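// Hypothetical mounting sketch (AuthenticatedApp is an illustrative name,
// not a component from this diff): wrap the authenticated subtree once so
// the SSE connection lives for the user's whole session.
function AuthenticatedApp({ children }: { children: React.ReactNode }) {
  return <EventStreamProvider>{children}</EventStreamProvider>;
}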
|
||||
49
frontend/src/events/dispatchEvent.test.ts (new file)
@@ -0,0 +1,49 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import type { AppDispatch } from '../store';
|
||||
import {
|
||||
sseEventReceived,
|
||||
sseLastEventIdReset,
|
||||
} from '../notifications/notificationsSlice';
|
||||
import { dispatchSSEEvent } from './dispatchEvent';
|
||||
|
||||
describe('dispatchSSEEvent', () => {
|
||||
let debugSpy: ReturnType<typeof vi.spyOn>;
|
||||
|
||||
beforeEach(() => {
|
||||
debugSpy = vi.spyOn(console, 'debug').mockImplementation(() => undefined);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
debugSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('dispatches sseLastEventIdReset AND sseEventReceived for backlog.truncated', () => {
|
||||
const dispatch = vi.fn() as unknown as AppDispatch;
|
||||
const envelope = { type: 'backlog.truncated' as const };
|
||||
|
||||
dispatchSSEEvent(envelope, dispatch);
|
||||
|
||||
const calls = (dispatch as unknown as { mock: { calls: unknown[][] } }).mock
|
||||
.calls;
|
||||
expect(calls).toHaveLength(2);
|
||||
expect(calls[0][0]).toEqual(sseLastEventIdReset());
|
||||
expect(calls[1][0]).toEqual(sseEventReceived(envelope));
|
||||
});
|
||||
|
||||
it('does not log a debug line for known envelope types', () => {
|
||||
const dispatch = vi.fn() as unknown as AppDispatch;
|
||||
dispatchSSEEvent({ id: 'e-1', type: 'source.ingest.progress' }, dispatch);
|
||||
expect(debugSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('logs a debug line for unknown envelope types', () => {
|
||||
const dispatch = vi.fn() as unknown as AppDispatch;
|
||||
dispatchSSEEvent({ id: 'e-2', type: 'mystery.event' }, dispatch);
|
||||
expect(debugSpy).toHaveBeenCalledTimes(1);
|
||||
expect(debugSpy.mock.calls[0]).toEqual([
|
||||
'[dispatchSSEEvent] unknown envelope type',
|
||||
'mystery.event',
|
||||
]);
|
||||
});
|
||||
});
|
||||
58
frontend/src/events/dispatchEvent.ts (new file)
@@ -0,0 +1,58 @@
|
||||
import type { AppDispatch } from '../store';
|
||||
import {
|
||||
sseEventReceived,
|
||||
sseLastEventIdReset,
|
||||
type SSEEvent,
|
||||
} from '../notifications/notificationsSlice';
|
||||
|
||||
// Envelope types this build knows about. Hitting an unknown type means
|
||||
// the backend published something the frontend hasn't been taught yet
|
||||
// — worth a single debug line so it's visible in devtools without
|
||||
// drowning the console in known per-progress traffic.
|
||||
const KNOWN_TYPES: ReadonlySet<string> = new Set([
|
||||
'backlog.truncated',
|
||||
'source.ingest.queued',
|
||||
'source.ingest.progress',
|
||||
'source.ingest.completed',
|
||||
'source.ingest.failed',
|
||||
'attachment.queued',
|
||||
'attachment.progress',
|
||||
'attachment.completed',
|
||||
'attachment.failed',
|
||||
'mcp.oauth.awaiting_redirect',
|
||||
'mcp.oauth.in_progress',
|
||||
'mcp.oauth.completed',
|
||||
'mcp.oauth.failed',
|
||||
'tool.approval.required',
|
||||
]);
|
||||
|
||||
/**
|
||||
* Single fan-out point for inbound SSE envelopes. Always dispatches
|
||||
* ``sseEventReceived`` so any slice can ``extraReducers``-listen
|
||||
* (uploadSlice does this for source-ingest events), then handles the
|
||||
* small set of envelope-types that need centralised side effects (e.g.
|
||||
* ``backlog.truncated``).
|
||||
*/
|
||||
export function dispatchSSEEvent(
|
||||
envelope: SSEEvent,
|
||||
dispatch: AppDispatch,
|
||||
): void {
|
||||
if (!KNOWN_TYPES.has(envelope.type)) {
|
||||
console.debug('[dispatchSSEEvent] unknown envelope type', envelope.type);
|
||||
}
|
||||
|
||||
switch (envelope.type) {
|
||||
case 'backlog.truncated':
|
||||
// Backlog window slid past the client's Last-Event-ID. Drop the
|
||||
// cursor so the next reconnect doesn't try to resume past the
|
||||
// retained window. Slices that care about full-state freshness
|
||||
// can subscribe to ``sseEventReceived`` and refetch.
|
||||
dispatch(sseLastEventIdReset());
|
||||
break;
|
||||
default:
|
||||
// No central side effect; rely on slice-level extraReducers.
|
||||
break;
|
||||
}
|
||||
|
||||
dispatch(sseEventReceived(envelope));
|
||||
}
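// Minimal sketch of the slice-level listening the fan-out above relies on
// (hypothetical slice; uploadSlice's real reducers differ): every envelope
// reaches slices through the sseEventReceived action imported above.
import { createSlice } from '@reduxjs/toolkit';

const ingestMirrorSlice = createSlice({
  name: 'ingestMirror',
  initialState: { lastIngestEvent: null as string | null },
  reducers: {},
  extraReducers: (builder) => {
    builder.addCase(sseEventReceived, (state, action) => {
      if (action.payload.type.startsWith('source.ingest.')) {
        state.lastIngestEvent = action.payload.type;
      }
    });
  },
});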
|
||||
386
frontend/src/events/eventStreamClient.ts (new file)
@@ -0,0 +1,386 @@
|
||||
import { baseURL } from '../api/client';
|
||||
import type { SSEEvent } from '../notifications/notificationsSlice';
|
||||
|
||||
/**
|
||||
* Connection state surfaced to the consumer. Maps directly to the
|
||||
* ``PushHealth`` machine in ``notificationsSlice``.
|
||||
*/
|
||||
export type EventStreamHealth = 'connecting' | 'healthy' | 'unhealthy';
|
||||
|
||||
export interface EventStreamOptions {
|
||||
/** Bearer token; ``null`` short-circuits to ``unhealthy`` (auth required). */
|
||||
token: string | null;
|
||||
/**
|
||||
* Lazy getter for the current ``Last-Event-ID``. Called once at the
|
||||
* top of each connect attempt so token rotations / remounts read
|
||||
* the freshest cursor from Redux instead of a stale mount-time
|
||||
* snapshot. Return ``null`` for a fresh connect.
|
||||
*/
|
||||
getLastEventId: () => string | null;
|
||||
onEvent: (event: SSEEvent) => void;
|
||||
onHealthChange: (health: EventStreamHealth) => void;
|
||||
/** Called with the most recently received id so the caller can persist it. */
|
||||
onLastEventId?: (id: string) => void;
|
||||
/**
|
||||
* Called when the server emitted an ``id:`` line with an empty value
|
||||
* (WHATWG SSE cursor reset). Distinct from ``onLastEventId('')`` so
|
||||
* callers can dispatch ``sseLastEventIdReset`` without overloading
|
||||
* the normal advance path.
|
||||
*/
|
||||
onLastEventIdReset?: () => void;
|
||||
/**
|
||||
* Invoked once after ``MAX_CONSECUTIVE_401`` back-to-back 401s. The
|
||||
* reconnect loop then bails out, so the caller is responsible for
|
||||
* refreshing the token / signalling logout. Without this, an expired
|
||||
* token spins forever at the 30s backoff cap.
|
||||
*/
|
||||
onAuthFailure?: () => void;
|
||||
/**
|
||||
* Invoked once when the reconnect loop bails out after
|
||||
* ``MAX_CONSECUTIVE_ERRORS`` non-401 failures. Lets the caller surface
|
||||
* a warning instead of the connection silently going dark.
|
||||
*/
|
||||
onPermanentFailure?: () => void;
|
||||
}
|
||||
|
||||
export interface EventStreamConnection {
|
||||
close(): void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Backoff schedule (ms) for reconnect attempts. Capped at 30s so a long
|
||||
* outage doesn't push retries past Cloudflare's typical 100s idle-close
|
||||
* envelope. The schedule resets to 0 after a stream stays healthy for
|
||||
* ``HEALTHY_DEBOUNCE_MS``.
|
||||
*/
|
||||
const BACKOFF_SCHEDULE_MS = [0, 1_000, 2_000, 4_000, 8_000, 16_000, 30_000];
|
||||
const HEALTHY_DEBOUNCE_MS = 2_000;
|
||||
/**
|
||||
* Reconnect ceilings. Without these, the ``while (!closed)`` loop spins
|
||||
* forever on a persistently-failing endpoint — expired token (401s) or
|
||||
* sustained server outage (5xx). Both counters reset on a successful
|
||||
* stream open. Untested (no frontend test harness); behaviour verified
|
||||
* by manual trace of the loop in ``connectEventStream``.
|
||||
*/
|
||||
const MAX_CONSECUTIVE_401 = 3;
|
||||
const MAX_CONSECUTIVE_ERRORS = 20;
|
||||
|
||||
/** Up-to-±20% random jitter so N tabs reconnecting in lockstep stagger. */
|
||||
function withJitter(delayMs: number): number {
|
||||
if (delayMs <= 0) return 0;
|
||||
const span = delayMs * 0.2;
|
||||
return Math.max(0, Math.round(delayMs + (Math.random() * 2 - 1) * span));
|
||||
}
|
||||
|
||||
/**
|
||||
* Open a long-lived SSE connection to ``GET /api/events`` with
|
||||
* fetch-streaming semantics that mirror ``conversationHandlers.ts``.
|
||||
*
|
||||
* Returns immediately with an opaque handle; the connection lives in a
|
||||
* background async loop until ``close()`` is called or the underlying
|
||||
* ``AbortController`` fires.
|
||||
*
|
||||
* The ``Last-Event-ID`` cursor rides on the URL (``?last_event_id=...``)
|
||||
* rather than as a header so the request stays a CORS-simple GET — a
|
||||
* custom header would force a preflight OPTIONS that the production
|
||||
* cross-origin deployment isn't allowlisted for.
|
||||
*/
|
||||
export function connectEventStream(
|
||||
opts: EventStreamOptions,
|
||||
): EventStreamConnection {
|
||||
const controller = new AbortController();
|
||||
let closed = false;
|
||||
let attempt = 0;
|
||||
let consecutive401 = 0;
|
||||
let consecutiveErrors = 0;
|
||||
// Closure cursor. Seeded from the store on each connect attempt so
|
||||
// mid-session reconnects use the freshest id, but kept here too so
|
||||
// an in-flight stream's reconnect doesn't lose progress between ids
|
||||
// that the store hasn't seen yet (e.g. id-only frames).
|
||||
let lastEventId: string | null = opts.getLastEventId();
|
||||
|
||||
const notifyHealth = (h: EventStreamHealth) => {
|
||||
if (closed) return;
|
||||
opts.onHealthChange(h);
|
||||
};
|
||||
|
||||
void (async () => {
|
||||
while (!closed) {
|
||||
const baseDelay =
|
||||
BACKOFF_SCHEDULE_MS[Math.min(attempt, BACKOFF_SCHEDULE_MS.length - 1)];
|
||||
const delay = withJitter(baseDelay);
|
||||
if (delay > 0) {
|
||||
try {
|
||||
await sleep(delay, controller.signal);
|
||||
} catch {
|
||||
return; // aborted while waiting
|
||||
}
|
||||
if (closed) return;
|
||||
}
|
||||
|
||||
notifyHealth('connecting');
|
||||
|
||||
// Always re-read the store cursor before reconnecting and copy
|
||||
// it verbatim — including null. A null cursor isn't "leave
|
||||
// alone": ``backlog.truncated`` events fire ``sseLastEventIdReset``
|
||||
// to clear the slice, and the client must respect that on the
|
||||
// next attempt by sending no cursor (full-backlog replay) instead
|
||||
// of resending the stale one and re-tripping the same truncation.
|
||||
lastEventId = opts.getLastEventId();
|
||||
|
||||
const url = new URL(`${baseURL}/api/events`);
|
||||
if (lastEventId) url.searchParams.set('last_event_id', lastEventId);
|
||||
|
||||
// Auth header is omitted when token is null. Self-hosted dev
|
||||
// installs run with ``AUTH_TYPE`` unset; the backend maps those
|
||||
// requests to ``{"sub": "local"}`` so the SSE connection works
|
||||
// tokenless. When auth IS required, a missing header surfaces
|
||||
// as a 401 and the response.ok check below flips the health
|
||||
// back to unhealthy.
|
||||
const headers: Record<string, string> = {
|
||||
Accept: 'text/event-stream',
|
||||
};
|
||||
if (opts.token) {
|
||||
headers.Authorization = `Bearer ${opts.token}`;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(url.toString(), {
|
||||
method: 'GET',
|
||||
headers,
|
||||
signal: controller.signal,
|
||||
// SSE must not be cached.
|
||||
cache: 'no-store',
|
||||
});
|
||||
|
||||
if (!response.ok || !response.body) {
|
||||
notifyHealth('unhealthy');
|
||||
// 401 typically means token expired. Bail out after N
|
||||
// consecutive 401s so the loop doesn't spin forever at the
|
||||
// 30s backoff cap with a stale token; the caller is
|
||||
// responsible for refreshing auth via ``onAuthFailure``.
|
||||
if (response.status === 401) {
|
||||
consecutive401 += 1;
|
||||
consecutiveErrors += 1;
|
||||
if (consecutive401 >= MAX_CONSECUTIVE_401) {
|
||||
opts.onAuthFailure?.();
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
consecutive401 = 0;
|
||||
consecutiveErrors += 1;
|
||||
}
|
||||
if (consecutiveErrors >= MAX_CONSECUTIVE_ERRORS) {
|
||||
opts.onPermanentFailure?.();
|
||||
return;
|
||||
}
|
||||
// 429: server-side per-user concurrency cap; backoff harder.
|
||||
if (response.status === 429) attempt = Math.max(attempt, 4);
|
||||
else attempt = Math.min(attempt + 1, BACKOFF_SCHEDULE_MS.length - 1);
|
||||
continue;
|
||||
}
|
||||
consecutive401 = 0;
|
||||
|
||||
// Connection is open. Mark healthy after either:
|
||||
// - 2s of open response body (covers servers that go silent), or
|
||||
// - first record received past the 2s mark.
|
||||
// The setTimeout path means a backend that never emits a single
|
||||
// record after sending the 200 still flips us out of `connecting`.
|
||||
let healthyMarked = false;
|
||||
const markHealthy = () => {
|
||||
if (healthyMarked) return;
|
||||
healthyMarked = true;
|
||||
notifyHealth('healthy');
|
||||
attempt = 0;
|
||||
consecutiveErrors = 0;
|
||||
};
|
||||
const debounceTimer = setTimeout(markHealthy, HEALTHY_DEBOUNCE_MS);
|
||||
|
||||
try {
|
||||
await readSSEStream(response.body, controller.signal, (record) => {
|
||||
if (record.id !== undefined) {
|
||||
lastEventId = record.id || null;
|
||||
if (record.id) opts.onLastEventId?.(record.id);
|
||||
else opts.onLastEventIdReset?.();
|
||||
}
|
||||
if (record.data === undefined) {
|
||||
// Keepalive comment, id-only frame, or comment line.
|
||||
// The cursor was already advanced via ``onLastEventId``
|
||||
// above so the slice tracks ids even on frames we don't
|
||||
// dispatch as events.
|
||||
return;
|
||||
}
|
||||
// Empty data line is technically valid SSE but useless; skip.
|
||||
if (record.data.trim().length === 0) return;
|
||||
let envelope: SSEEvent | null = null;
|
||||
try {
|
||||
envelope = JSON.parse(record.data) as SSEEvent;
|
||||
} catch {
|
||||
// Malformed payload; skip.
|
||||
return;
|
||||
}
|
||||
// Defensive shape validation — the cast above lies if the
|
||||
// server (or a man-in-the-middle) sends garbage.
|
||||
if (
|
||||
!envelope ||
|
||||
typeof envelope !== 'object' ||
|
||||
typeof envelope.type !== 'string'
|
||||
) {
|
||||
return;
|
||||
}
|
||||
if (record.id && !envelope.id) envelope.id = record.id;
|
||||
// Receiving a real envelope post-debounce-window flips
|
||||
// healthy if the timer hasn't already.
|
||||
markHealthy();
|
||||
// Every tab dispatches every envelope it receives into its
|
||||
// own Redux store. With N tabs open this means N copies of
|
||||
// the same toast — accepted as a v1 limitation; cross-tab
|
||||
// dedup via BroadcastChannel + navigator.locks is future
|
||||
// work. Toast-level suppression can be handled per surface.
|
||||
opts.onEvent(envelope);
|
||||
});
|
||||
} finally {
|
||||
clearTimeout(debounceTimer);
|
||||
}
|
||||
|
||||
// The reader returned without abort — server closed the stream.
|
||||
// Fall through to reconnect.
|
||||
notifyHealth('unhealthy');
|
||||
consecutiveErrors += 1;
|
||||
if (consecutiveErrors >= MAX_CONSECUTIVE_ERRORS) {
|
||||
opts.onPermanentFailure?.();
|
||||
return;
|
||||
}
|
||||
attempt = Math.min(attempt + 1, BACKOFF_SCHEDULE_MS.length - 1);
|
||||
} catch (err) {
|
||||
if (
|
||||
closed ||
|
||||
(err instanceof DOMException && err.name === 'AbortError')
|
||||
) {
|
||||
return;
|
||||
}
|
||||
notifyHealth('unhealthy');
|
||||
consecutiveErrors += 1;
|
||||
if (consecutiveErrors >= MAX_CONSECUTIVE_ERRORS) {
|
||||
opts.onPermanentFailure?.();
|
||||
return;
|
||||
}
|
||||
attempt = Math.min(attempt + 1, BACKOFF_SCHEDULE_MS.length - 1);
|
||||
}
|
||||
}
|
||||
})();
|
||||
|
||||
return {
|
||||
close() {
|
||||
if (closed) return;
|
||||
closed = true;
|
||||
controller.abort();
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
interface ParsedSSERecord {
|
||||
/**
|
||||
* ``undefined`` when the record had no ``id`` field at all. An empty
|
||||
* string means the server explicitly reset the cursor (an ``id:``
|
||||
* line with no value, per WHATWG SSE).
|
||||
*/
|
||||
id?: string;
|
||||
/** ``undefined`` for keepalive comments / id-only frames. */
|
||||
data?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Drain a ``ReadableStream<Uint8Array>`` of ``\n\n``-delimited SSE records,
|
||||
* forwarding each parsed record to ``onRecord``. Honours the WHATWG SSE
|
||||
* spec's mixed line-terminator handling and SSE comment lines.
|
||||
*/
|
||||
async function readSSEStream(
|
||||
body: ReadableStream<Uint8Array>,
|
||||
signal: AbortSignal,
|
||||
onRecord: (record: ParsedSSERecord) => void,
|
||||
): Promise<void> {
|
||||
const reader = body.getReader();
|
||||
const decoder = new TextDecoder('utf-8');
|
||||
let buffer = '';
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
if (signal.aborted) return;
|
||||
const { done, value } = await reader.read();
|
||||
if (done) return;
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
// SSE records are separated by a blank line. WHATWG spec accepts
|
||||
// CRLF, CR, or LF — normalise so a stray CR can't smuggle a
|
||||
// boundary mid-record.
|
||||
buffer = buffer.replace(/\r\n/g, '\n').replace(/\r/g, '\n');
|
||||
|
||||
let boundary = buffer.indexOf('\n\n');
|
||||
while (boundary !== -1) {
|
||||
const raw = buffer.slice(0, boundary);
|
||||
buffer = buffer.slice(boundary + 2);
|
||||
const record = parseSSERecord(raw);
|
||||
if (record) onRecord(record);
|
||||
boundary = buffer.indexOf('\n\n');
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
try {
|
||||
reader.releaseLock();
|
||||
} catch {
|
||||
// Already released.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function parseSSERecord(raw: string): ParsedSSERecord | null {
|
||||
if (raw.length === 0) return null;
|
||||
const lines = raw.split('\n');
|
||||
let id: string | undefined;
|
||||
const dataParts: string[] = [];
|
||||
let sawDataField = false;
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.length === 0) continue;
|
||||
if (line.startsWith(':')) continue; // SSE comment / keepalive
|
||||
const colonIdx = line.indexOf(':');
|
||||
const field = colonIdx === -1 ? line : line.slice(0, colonIdx);
|
||||
let value = colonIdx === -1 ? '' : line.slice(colonIdx + 1);
|
||||
// SSE: value may be prefixed by exactly one optional space.
|
||||
if (value.startsWith(' ')) value = value.slice(1);
|
||||
|
||||
if (field === 'id') {
|
||||
id = value;
|
||||
} else if (field === 'data') {
|
||||
sawDataField = true;
|
||||
dataParts.push(value);
|
||||
}
|
||||
// Other field names ('event', 'retry') are ignored for now.
|
||||
}
|
||||
|
||||
if (!sawDataField && id === undefined) return null;
|
||||
return {
|
||||
id,
|
||||
data: sawDataField ? dataParts.join('\n') : undefined,
|
||||
};
|
||||
}
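// Illustration (not part of the diff): what parseSSERecord above returns
// for a few invented raw records, lines already normalised to '\n'.
//   parseSSERecord('id: 42\ndata: {"type":"mystery.event"}')
//     -> { id: '42', data: '{"type":"mystery.event"}' }
//   parseSSERecord(': ping')
//     -> null (comment-only keepalive; no id, no data)
//   parseSSERecord('id:')
//     -> { id: '', data: undefined } (explicit WHATWG cursor reset)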
|
||||
|
||||
function sleep(ms: number, signal: AbortSignal): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (signal.aborted) {
|
||||
reject(new DOMException('Aborted', 'AbortError'));
|
||||
return;
|
||||
}
|
||||
const timer = setTimeout(() => {
|
||||
signal.removeEventListener('abort', onAbort);
|
||||
resolve();
|
||||
}, ms);
|
||||
const onAbort = () => {
|
||||
clearTimeout(timer);
|
||||
signal.removeEventListener('abort', onAbort);
|
||||
reject(new DOMException('Aborted', 'AbortError'));
|
||||
};
|
||||
signal.addEventListener('abort', onAbort, { once: true });
|
||||
});
|
||||
}
|
||||
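Before the hook that owns this connection in the app, a standalone sketch (not part of the diff) of opening and closing a stream with only the required options:

import { connectEventStream } from './eventStreamClient';

const conn = connectEventStream({
  token: null, // tokenless is fine for self-hosted installs with AUTH_TYPE unset
  getLastEventId: () => null, // fresh connect; no replay cursor
  onEvent: (envelope) => console.log('[events]', envelope.type),
  onHealthChange: (health) => console.log('[events] health:', health),
});
// Later, e.g. on teardown:
conn.close();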
frontend/src/events/useEventStream.ts (new file, 85 lines)
@@ -0,0 +1,85 @@
import { useEffect } from 'react';
import { useDispatch, useSelector, useStore } from 'react-redux';

import {
  selectLastEventId,
  sseHealthChanged,
  sseLastEventIdAdvanced,
  sseLastEventIdReset,
} from '../notifications/notificationsSlice';
import { selectToken, setToken } from '../preferences/preferenceSlice';
import type { AppDispatch, RootState } from '../store';

import { connectEventStream } from './eventStreamClient';
import { dispatchSSEEvent } from './dispatchEvent';

/**
 * Open the SSE connection for the current token and keep it alive for
 * the lifetime of the host component. Recreates the connection on
 * token change (login / refresh).
 *
 * The ``lastEventId`` cursor is read lazily from the slice on each
 * connect attempt via ``store.getState()`` — capturing it at mount time
 * would silently re-replay the entire 24h backlog on token rotation,
 * since the slice's id advances during the previous connection's
 * lifetime but a snapshot ref would still hold the value seen at
 * first mount.
 */
export function useEventStream(): void {
  const dispatch = useDispatch<AppDispatch>();
  const token = useSelector(selectToken);
  const store = useStore<RootState>();

  useEffect(() => {
    // Connect even when token is null. Self-hosted dev installs run
    // with ``AUTH_TYPE`` unset, where ``handle_auth`` maps every
    // request to ``{"sub": "local"}`` regardless of headers — gating
    // the connection on a populated token would silently disable push
    // notifications for the most common configuration. When auth IS
    // required and token is null, the backend will 401 and the
    // health state will flip to ``unhealthy`` via the response check
    // inside ``connectEventStream``.
    const conn = connectEventStream({
      token,
      getLastEventId: () => selectLastEventId(store.getState()),
      onEvent: (envelope) => dispatchSSEEvent(envelope, dispatch),
      // Advance the slice cursor for every id-bearing frame. Each tab
      // owns an independent SSE connection and Redux store, so every
      // active tab tracks its own replay cursor.
      onLastEventId: (id) => dispatch(sseLastEventIdAdvanced(id)),
      // Server emitted ``id:`` with an empty value — WHATWG cursor reset.
      // Mirror the slice so the next reconnect doesn't resend a stale id.
      onLastEventIdReset: () => dispatch(sseLastEventIdReset()),
      onHealthChange: (health) => dispatch(sseHealthChanged(health)),
      // SSE 401 loop bail-out. Clear the stored token AND dispatch
      // ``setToken(null)`` so ``useAuth`` regenerates a fresh
      // ``session_jwt`` in-session; the Redux change also flips this
      // hook's ``[token]`` dep, tearing down and respawning the
      // connection with the new token. Without the dispatch a
      // ``session_jwt`` user is stuck until a hard reload.
      onAuthFailure: () => {
        console.error(
          '[useEventStream] giving up after repeated 401s on /api/events',
        );
        try {
          localStorage.removeItem('authToken');
        } catch {
          // localStorage unavailable (private mode, etc.) — nothing to do.
        }
        dispatch(setToken(null));
      },
      // Surface a warning when the non-401 error budget is exhausted so
      // the connection going dark isn't completely silent. Doesn't block
      // UI — just observable in devtools.
      onPermanentFailure: () => {
        console.warn(
          '[useEventStream] SSE connection failed permanently after repeated errors',
        );
      },
    });

    return () => {
      conn.close();
    };
  }, [token, dispatch, store]);
}
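A sketch of mounting the hook once near the app root; the AppLayout component name is illustrative, not from the diff:

import type { ReactNode } from 'react';
import { useEventStream } from './events/useEventStream';

export default function AppLayout({ children }: { children: ReactNode }) {
  useEventStream(); // one SSE connection per tab, torn down on unmount
  return <>{children}</>;
}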
@@ -15,24 +15,39 @@ export default function useAuth() {
|
||||
const generateNewToken = async () => {
|
||||
if (isGeneratingToken.current) return;
|
||||
isGeneratingToken.current = true;
|
||||
const response = await userService.getNewToken();
|
||||
const { token: newToken } = await response.json();
|
||||
localStorage.setItem('authToken', newToken);
|
||||
dispatch(setToken(newToken));
|
||||
setIsAuthLoading(false);
|
||||
return newToken;
|
||||
try {
|
||||
const response = await userService.getNewToken();
|
||||
const { token: newToken } = await response.json();
|
||||
localStorage.setItem('authToken', newToken);
|
||||
dispatch(setToken(newToken));
|
||||
setIsAuthLoading(false);
|
||||
return newToken;
|
||||
} finally {
|
||||
// Reset so a subsequent ``setToken(null)`` (SSE 401 recovery)
|
||||
// can trigger another generation. Without this the in-flight
|
||||
// guard would latch true forever after the first call.
|
||||
isGeneratingToken.current = false;
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
// Re-fires when ``token`` flips to null mid-session (e.g.
|
||||
// ``useEventStream`` dispatches ``setToken(null)`` after repeated
|
||||
// SSE 401s) so ``session_jwt`` users get a fresh token without a
|
||||
// hard reload. ``authType`` short-circuits on subsequent runs.
|
||||
const initializeAuth = async () => {
|
||||
try {
|
||||
const configRes = await userService.getConfig();
|
||||
const config = await configRes.json();
|
||||
setAuthType(config.auth_type);
|
||||
let resolvedAuthType = authType;
|
||||
if (resolvedAuthType === null) {
|
||||
const configRes = await userService.getConfig();
|
||||
const config = await configRes.json();
|
||||
resolvedAuthType = config.auth_type;
|
||||
setAuthType(resolvedAuthType);
|
||||
}
|
||||
|
||||
if (config.auth_type === 'session_jwt' && !token) {
|
||||
if (resolvedAuthType === 'session_jwt' && !token) {
|
||||
await generateNewToken();
|
||||
} else if (config.auth_type === 'simple_jwt' && !token) {
|
||||
} else if (resolvedAuthType === 'simple_jwt' && !token) {
|
||||
setShowTokenModal(true);
|
||||
setIsAuthLoading(false);
|
||||
} else {
|
||||
@@ -44,7 +59,7 @@ export default function useAuth() {
|
||||
}
|
||||
};
|
||||
initializeAuth();
|
||||
}, []);
|
||||
}, [token, authType]);
|
||||
|
||||
const handleTokenSubmit = (enteredToken: string) => {
|
||||
localStorage.setItem('authToken', enteredToken);
|
||||
|
||||
@@ -456,11 +456,6 @@
|
||||
"create": "Erstellen",
|
||||
"option": "Benutzern weitere Eingaben erlauben"
|
||||
},
|
||||
"searchConversations": {
|
||||
"searchPlaceholder": "Konversationen durchsuchen",
|
||||
"noResults": "Keine Ergebnisse gefunden",
|
||||
"loading": "Laden..."
|
||||
},
|
||||
"configTool": {
|
||||
"title": "Werkzeug-Konfiguration",
|
||||
"type": "Typ",
|
||||
|
||||
@@ -486,11 +486,6 @@
|
||||
"create": "Create",
|
||||
"option": "Allow users to prompt further"
|
||||
},
|
||||
"searchConversations": {
|
||||
"searchPlaceholder": "Search conversations",
|
||||
"noResults": "No results found",
|
||||
"loading": "Loading..."
|
||||
},
|
||||
"configTool": {
|
||||
"title": "Tool Config",
|
||||
"type": "Type",
|
||||
|
||||
@@ -474,11 +474,6 @@
|
||||
"create": "Crear",
|
||||
"option": "Permitir a los usuarios realizar más consultas"
|
||||
},
|
||||
"searchConversations": {
|
||||
"searchPlaceholder": "Buscar conversaciones",
|
||||
"noResults": "No se encontraron resultados",
|
||||
"loading": "Cargando..."
|
||||
},
|
||||
"configTool": {
|
||||
"title": "Configuración de la Herramienta",
|
||||
"type": "Tipo",
|
||||
|
||||
@@ -474,11 +474,6 @@
|
||||
"create": "作成",
|
||||
"option": "ユーザーがより多くのクエリを実行できるようにします。"
|
||||
},
|
||||
"searchConversations": {
|
||||
"searchPlaceholder": "会話を検索",
|
||||
"noResults": "結果が見つかりません",
|
||||
"loading": "読み込み中..."
|
||||
},
|
||||
"configTool": {
|
||||
"title": "ツール設定",
|
||||
"type": "タイプ",
|
||||
|
||||
@@ -474,11 +474,6 @@
|
||||
"create": "Создать",
|
||||
"option": "Позволить пользователям делать дополнительные запросы."
|
||||
},
|
||||
"searchConversations": {
|
||||
"searchPlaceholder": "Поиск разговоров",
|
||||
"noResults": "Результаты не найдены",
|
||||
"loading": "Загрузка..."
|
||||
},
|
||||
"configTool": {
|
||||
"title": "Настройка инструмента",
|
||||
"type": "Тип",
|
||||
|
||||
@@ -474,11 +474,6 @@
|
||||
"create": "建立",
|
||||
"option": "允許使用者進行更多查詢"
|
||||
},
|
||||
"searchConversations": {
|
||||
"searchPlaceholder": "搜尋對話",
|
||||
"noResults": "未找到結果",
|
||||
"loading": "載入中..."
|
||||
},
|
||||
"configTool": {
|
||||
"title": "工具設定",
|
||||
"type": "類型",
|
||||
|
||||
@@ -474,11 +474,6 @@
|
||||
"create": "创建",
|
||||
"option": "允许用户进行更多查询。"
|
||||
},
|
||||
"searchConversations": {
|
||||
"searchPlaceholder": "搜索对话",
|
||||
"noResults": "未找到结果",
|
||||
"loading": "加载中..."
|
||||
},
|
||||
"configTool": {
|
||||
"title": "工具配置",
|
||||
"type": "类型",
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
SelectValue,
|
||||
} from '../components/ui/select';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import { selectRecentEvents } from '../notifications/notificationsSlice';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
import WrapperComponent from './WrapperModal';
|
||||
|
||||
@@ -33,6 +34,7 @@ export default function MCPServerModal({
|
||||
}: MCPServerModalProps) {
|
||||
const { t } = useTranslation();
|
||||
const token = useSelector(selectToken);
|
||||
const recentEvents = useSelector(selectRecentEvents);
|
||||
|
||||
const authTypes = [
|
||||
{ label: t('settings.tools.mcp.authTypes.none'), value: 'none' },
|
||||
@@ -71,17 +73,29 @@ export default function MCPServerModal({
|
||||
>([]);
|
||||
const [errors, setErrors] = useState<{ [key: string]: string }>({});
|
||||
const oauthPopupRef = useRef<Window | null>(null);
|
||||
const pollingCancelledRef = useRef(false);
|
||||
const pollTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
// Set after ``test_mcp_connection`` returns ``task_id``. The SSE
|
||||
// effect filters ``recentEvents`` to envelopes matching this id and
|
||||
// drives the OAuth UI (popup open / completion / failure) from the
|
||||
// push stream rather than polling the legacy status endpoint.
|
||||
const [oauthTaskId, setOauthTaskId] = useState<string | null>(null);
|
||||
// Highest event id we have already reacted to for this taskId. Each
|
||||
// mcp.oauth.* envelope must fire its side-effect once; without this
|
||||
// any later re-render that re-evaluates ``recentEvents`` would
|
||||
// re-open the popup or re-fire onComplete.
|
||||
const handledEventIdsRef = useRef<Set<string>>(new Set());
|
||||
// Holds the ``testConnection`` ``onComplete`` for the current
|
||||
// task id so the SSE effect can invoke it when the terminal event
|
||||
// lands. Reset to ``null`` on cancel / new test / unmount.
|
||||
const onCompleteRef = useRef<((result: any) => void) | null>(null);
|
||||
const popupOpenedRef = useRef(false);
|
||||
const [oauthCompleted, setOAuthCompleted] = useState(false);
|
||||
const [saveActive, setSaveActive] = useState(false);
|
||||
|
||||
const cleanupPolling = useCallback(() => {
|
||||
pollingCancelledRef.current = true;
|
||||
if (pollTimerRef.current) {
|
||||
clearTimeout(pollTimerRef.current);
|
||||
pollTimerRef.current = null;
|
||||
}
|
||||
const cleanupOAuthListener = useCallback(() => {
|
||||
setOauthTaskId(null);
|
||||
handledEventIdsRef.current = new Set();
|
||||
onCompleteRef.current = null;
|
||||
popupOpenedRef.current = false;
|
||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
||||
oauthPopupRef.current.close();
|
||||
}
|
||||
@@ -89,8 +103,8 @@ export default function MCPServerModal({
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
return cleanupPolling;
|
||||
}, [cleanupPolling]);
|
||||
return cleanupOAuthListener;
|
||||
}, [cleanupOAuthListener]);
|
||||
|
||||
useEffect(() => {
|
||||
if (modalState === 'ACTIVE' && server) {
|
||||
@@ -119,7 +133,7 @@ export default function MCPServerModal({
|
||||
}, [modalState, server]);
|
||||
|
||||
const resetForm = () => {
|
||||
cleanupPolling();
|
||||
cleanupOAuthListener();
|
||||
setFormData({
|
||||
name: t('settings.tools.mcp.defaultServerName'),
|
||||
server_url: '',
|
||||
@@ -228,114 +242,123 @@ export default function MCPServerModal({
|
||||
return config;
|
||||
};
|
||||
|
||||
const pollOAuthStatus = async (
|
||||
taskId: string,
|
||||
onComplete: (result: any) => void,
|
||||
) => {
|
||||
let attempts = 0;
|
||||
const maxAttempts = 60;
|
||||
let popupOpened = false;
|
||||
pollingCancelledRef.current = false;
|
||||
/**
|
||||
* Drive the OAuth handshake straight from the SSE stream:
|
||||
*
|
||||
* - ``mcp.oauth.awaiting_redirect`` → open the popup with the
|
||||
* ``authorization_url`` carried on the envelope. Previously this URL
|
||||
* came from polling ``/api/mcp_server/oauth_status/<task_id>``; the
|
||||
* worker now publishes it inline so we never need to poll.
|
||||
* - ``mcp.oauth.completed`` → enable Save, surface discovered tools,
|
||||
* invoke ``onComplete`` (resolves ``testConnection``'s pending state).
|
||||
* - ``mcp.oauth.failed`` → surface the error and reset Save.
|
||||
*
|
||||
* Each event is matched to the active task id via ``scope.id``. The
|
||||
* publisher is best-effort: a lost ``awaiting_redirect`` envelope
|
||||
* means the popup never opens, the user retries, and we accept that
|
||||
* over the prior 1s × 60 polling loop.
|
||||
*/
|
||||
useEffect(() => {
|
||||
if (!oauthTaskId) return;
|
||||
// ``recentEvents`` is newest-first (the slice ``unshift``s on
|
||||
// arrival). Walk it oldest-first so we observe the natural OAuth
|
||||
// ordering (``awaiting_redirect`` → ``completed``) when both
|
||||
// arrive between effect runs — otherwise we would short-circuit
|
||||
// on ``completed`` and never open the popup for the
|
||||
// ``awaiting_redirect`` envelope that was already buffered.
|
||||
for (let i = recentEvents.length - 1; i >= 0; i--) {
|
||||
const event = recentEvents[i];
|
||||
if (event.scope?.id !== oauthTaskId) continue;
|
||||
if (!event.id || handledEventIdsRef.current.has(event.id)) continue;
|
||||
|
||||
const poll = async () => {
|
||||
if (pollingCancelledRef.current) return;
|
||||
try {
|
||||
const resp = await userService.getMCPOAuthStatus(taskId, token);
|
||||
if (pollingCancelledRef.current) return;
|
||||
const data = await resp.json();
|
||||
if (pollingCancelledRef.current) return;
|
||||
const payload = (event.payload || {}) as Record<string, unknown>;
|
||||
|
||||
if (data.authorization_url && !popupOpened) {
|
||||
if (event.type === 'mcp.oauth.awaiting_redirect') {
|
||||
handledEventIdsRef.current.add(event.id);
|
||||
const authUrl = payload.authorization_url as string | undefined;
|
||||
if (authUrl && !popupOpenedRef.current) {
|
||||
popupOpenedRef.current = true;
|
||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
||||
oauthPopupRef.current.close();
|
||||
}
|
||||
oauthPopupRef.current = window.open(
|
||||
data.authorization_url,
|
||||
authUrl,
|
||||
'oauthPopup',
|
||||
'width=600,height=700',
|
||||
);
|
||||
popupOpened = true;
|
||||
|
||||
if (!oauthPopupRef.current) {
|
||||
// Popup blocked — surface the URL inline so the user can
|
||||
// click through manually. Browsers gate ``window.open``
|
||||
// outside of a user gesture, and the SSE event arrives
|
||||
// asynchronously, so a blocked popup is expected on
|
||||
// some browsers / configs.
|
||||
setTestResult({
|
||||
success: true,
|
||||
message: t('settings.tools.mcp.oauthPopupBlocked', {
|
||||
defaultValue:
|
||||
'Popup blocked by browser. Click below to authorize:',
|
||||
}),
|
||||
authorization_url: data.authorization_url,
|
||||
authorization_url: authUrl,
|
||||
});
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const callbackReceived =
|
||||
data.status === 'callback_received' || data.status === 'completed';
|
||||
|
||||
if (data.status === 'completed') {
|
||||
setOAuthCompleted(true);
|
||||
setSaveActive(true);
|
||||
onComplete({
|
||||
...data,
|
||||
if (event.type === 'mcp.oauth.completed') {
|
||||
handledEventIdsRef.current.add(event.id);
|
||||
const tools = Array.isArray(payload.tools) ? payload.tools : [];
|
||||
const toolsCount =
|
||||
(payload.tools_count as number | undefined) ?? tools.length;
|
||||
setOAuthCompleted(true);
|
||||
setSaveActive(true);
|
||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
||||
oauthPopupRef.current.close();
|
||||
}
|
||||
const cb = onCompleteRef.current;
|
||||
onCompleteRef.current = null;
|
||||
setOauthTaskId(null);
|
||||
if (cb) {
|
||||
cb({
|
||||
status: 'completed',
|
||||
task_id: oauthTaskId,
|
||||
tools,
|
||||
tools_count: toolsCount,
|
||||
success: true,
|
||||
message: t('settings.tools.mcp.oauthCompleted'),
|
||||
});
|
||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
||||
oauthPopupRef.current.close();
|
||||
}
|
||||
} else if (data.status === 'error' || data.success === false) {
|
||||
setSaveActive(false);
|
||||
onComplete({
|
||||
...data,
|
||||
success: false,
|
||||
message: data.message || t('settings.tools.mcp.errors.oauthFailed'),
|
||||
});
|
||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
||||
oauthPopupRef.current.close();
|
||||
}
|
||||
} else {
|
||||
if (++attempts < maxAttempts) {
|
||||
if (
|
||||
oauthPopupRef.current &&
|
||||
oauthPopupRef.current.closed &&
|
||||
popupOpened &&
|
||||
!callbackReceived
|
||||
) {
|
||||
setSaveActive(false);
|
||||
onComplete({
|
||||
success: false,
|
||||
message: t('settings.tools.mcp.errors.oauthFailed'),
|
||||
});
|
||||
return;
|
||||
}
|
||||
pollTimerRef.current = setTimeout(poll, 1000);
|
||||
} else {
|
||||
setSaveActive(false);
|
||||
cleanupPolling();
|
||||
onComplete({
|
||||
success: false,
|
||||
message: t('settings.tools.mcp.errors.oauthTimeout'),
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
if (pollingCancelledRef.current) return;
|
||||
if (++attempts < maxAttempts) {
|
||||
pollTimerRef.current = setTimeout(poll, 1000);
|
||||
} else {
|
||||
cleanupPolling();
|
||||
onComplete({
|
||||
success: false,
|
||||
message: t('settings.tools.mcp.errors.oauthTimeout'),
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
poll();
|
||||
};
|
||||
|
||||
if (event.type === 'mcp.oauth.failed') {
|
||||
handledEventIdsRef.current.add(event.id);
|
||||
const message =
|
||||
(payload.error as string) ??
|
||||
t('settings.tools.mcp.errors.oauthFailed');
|
||||
setSaveActive(false);
|
||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
||||
oauthPopupRef.current.close();
|
||||
}
|
||||
const cb = onCompleteRef.current;
|
||||
onCompleteRef.current = null;
|
||||
setOauthTaskId(null);
|
||||
if (cb) {
|
||||
cb({
|
||||
status: 'error',
|
||||
task_id: oauthTaskId,
|
||||
success: false,
|
||||
message,
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}, [recentEvents, oauthTaskId, t]);
|
||||
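// Illustration (not part of the diff): the shape of an envelope this
// effect reacts to; the id, scope.kind, and URL below are invented.
//   {
//     id: '1700000000000-0',
//     type: 'mcp.oauth.awaiting_redirect',
//     scope: { kind: 'task', id: '<oauth_task_id>' },
//     payload: { authorization_url: 'https://idp.example.com/authorize?...' },
//   }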
|
||||
const testConnection = async () => {
|
||||
if (!validateForm()) return;
|
||||
cleanupPolling();
|
||||
cleanupOAuthListener();
|
||||
setTesting(true);
|
||||
setTestResult(null);
|
||||
setDiscoveredTools([]);
|
||||
@@ -355,7 +378,7 @@ export default function MCPServerModal({
|
||||
message: t('settings.tools.mcp.oauthInProgress'),
|
||||
});
|
||||
setSaveActive(false);
|
||||
pollOAuthStatus(result.task_id, (finalResult) => {
|
||||
onCompleteRef.current = (finalResult: any) => {
|
||||
setTestResult(finalResult);
|
||||
if (finalResult.tools && Array.isArray(finalResult.tools)) {
|
||||
setDiscoveredTools(finalResult.tools);
|
||||
@@ -365,7 +388,11 @@ export default function MCPServerModal({
|
||||
oauth_task_id: result.task_id || '',
|
||||
}));
|
||||
setTesting(false);
|
||||
});
|
||||
};
|
||||
// Activate the SSE listener for this task id. The effect above
|
||||
// will react when ``mcp.oauth.{awaiting_redirect,completed,failed}``
|
||||
// arrives.
|
||||
setOauthTaskId(result.task_id);
|
||||
} else {
|
||||
setTestResult(result);
|
||||
if (result.success && result.tools && Array.isArray(result.tools)) {
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
import { useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import SearchIcon from '../assets/search.svg';
|
||||
import { searchConversations } from '../preferences/preferenceApi';
|
||||
import WrapperModal from './WrapperModal';
|
||||
|
||||
type ConversationListItem = {
|
||||
id: string;
|
||||
name: string;
|
||||
match_field?: 'name' | 'prompt' | 'response' | null;
|
||||
match_snippet?: string | null;
|
||||
};
|
||||
|
||||
type SearchConversationsModalProps = {
|
||||
close: () => void;
|
||||
conversations: ConversationListItem[];
|
||||
token: string | null;
|
||||
onSelectConversation: (id: string) => void;
|
||||
};
|
||||
|
||||
// Escape regex metacharacters so the user query can be used in a RegExp.
|
||||
function escapeRegExp(value: string): string {
|
||||
return value.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
}
|
||||
|
||||
function HighlightedText({ text, query }: { text: string; query: string }) {
|
||||
const trimmed = query.trim();
|
||||
if (!trimmed) return <>{text}</>;
|
||||
const parts = text.split(new RegExp(`(${escapeRegExp(trimmed)})`, 'gi'));
|
||||
return (
|
||||
<>
|
||||
{parts.map((part, idx) =>
|
||||
part.toLowerCase() === trimmed.toLowerCase() ? (
|
||||
<mark
|
||||
key={idx}
|
||||
className="bg-transparent font-semibold text-purple-30"
|
||||
>
|
||||
{part}
|
||||
</mark>
|
||||
) : (
|
||||
<span key={idx}>{part}</span>
|
||||
),
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default function SearchConversationsModal({
|
||||
close,
|
||||
conversations,
|
||||
token,
|
||||
onSelectConversation,
|
||||
}: SearchConversationsModalProps) {
|
||||
const { t } = useTranslation();
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const [query, setQuery] = useState('');
|
||||
const [results, setResults] = useState<ConversationListItem[] | null>(null);
|
||||
const [isSearching, setIsSearching] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
inputRef.current?.focus();
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const trimmed = query.trim();
|
||||
if (!trimmed) {
|
||||
setResults(null);
|
||||
setIsSearching(false);
|
||||
return;
|
||||
}
|
||||
setIsSearching(true);
|
||||
const handle = setTimeout(() => {
|
||||
searchConversations(trimmed, token).then((result) => {
|
||||
setResults(result.data ?? []);
|
||||
setIsSearching(false);
|
||||
});
|
||||
}, 300);
|
||||
return () => clearTimeout(handle);
|
||||
}, [query, token]);
|
||||
|
||||
const visibleConversations = useMemo(() => {
|
||||
if (!query.trim()) return conversations;
|
||||
return results ?? [];
|
||||
}, [query, results, conversations]);
|
||||
|
||||
const handleSelect = (id: string) => {
|
||||
onSelectConversation(id);
|
||||
close();
|
||||
};
|
||||
|
||||
const showEmptyState =
|
||||
!!query.trim() && !isSearching && visibleConversations.length === 0;
|
||||
|
||||
return (
|
||||
<WrapperModal
|
||||
close={close}
|
||||
className="w-[92vw] max-w-xl p-0"
|
||||
contentClassName="max-h-[70vh]"
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
<div className="border-sidebar-border flex items-center gap-2 border-b px-5 py-4">
|
||||
<img src={SearchIcon} alt="search" className="h-4 w-4 opacity-60" />
|
||||
<input
|
||||
ref={inputRef}
|
||||
type="text"
|
||||
value={query}
|
||||
onChange={(e) => setQuery(e.target.value)}
|
||||
placeholder={t('modals.searchConversations.searchPlaceholder')}
|
||||
className="text-foreground placeholder:text-muted-foreground w-full bg-transparent text-sm outline-none"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="max-h-[55vh] overflow-y-auto py-2">
|
||||
{isSearching && (
|
||||
<div className="text-muted-foreground px-5 py-3 text-xs">
|
||||
{t('modals.searchConversations.loading')}
|
||||
</div>
|
||||
)}
|
||||
{showEmptyState && (
|
||||
<div className="text-muted-foreground px-5 py-3 text-xs">
|
||||
{t('modals.searchConversations.noResults')}
|
||||
</div>
|
||||
)}
|
||||
{!isSearching &&
|
||||
visibleConversations.map((conversation) => {
|
||||
const trimmedQuery = query.trim();
|
||||
const showSnippet =
|
||||
!!trimmedQuery &&
|
||||
!!conversation.match_snippet &&
|
||||
conversation.match_field !== 'name';
|
||||
return (
|
||||
<button
|
||||
key={conversation.id}
|
||||
type="button"
|
||||
onClick={() => handleSelect(conversation.id)}
|
||||
className="hover:bg-sidebar-accent text-foreground flex w-full flex-col items-start gap-0.5 px-5 py-2.5 text-left text-sm"
|
||||
>
|
||||
<span className="w-full truncate">
|
||||
{trimmedQuery ? (
|
||||
<HighlightedText
|
||||
text={conversation.name}
|
||||
query={trimmedQuery}
|
||||
/>
|
||||
) : (
|
||||
conversation.name
|
||||
)}
|
||||
</span>
|
||||
{showSnippet && (
|
||||
<span className="text-muted-foreground line-clamp-2 w-full text-xs">
|
||||
<HighlightedText
|
||||
text={conversation.match_snippet as string}
|
||||
query={trimmedQuery}
|
||||
/>
|
||||
</span>
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
</WrapperModal>
|
||||
);
|
||||
}
|
||||
frontend/src/notifications/ToolApprovalToast.tsx (new file, 174 lines)
@@ -0,0 +1,174 @@
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { useMatch, useNavigate } from 'react-router-dom';
|
||||
|
||||
import WarnIcon from '../assets/warn.svg';
|
||||
import type { RootState } from '../store';
|
||||
|
||||
import {
|
||||
dismissToolApproval,
|
||||
selectDismissedToolApprovals,
|
||||
selectRecentEvents,
|
||||
} from './notificationsSlice';
|
||||
|
||||
/**
|
||||
* Surface ``tool.approval.required`` events as toasts that look like
|
||||
* ``UploadToast`` (same fixed bottom-right rail) — but only when the
|
||||
* user is NOT already on the conversation that needs the approval.
|
||||
*
|
||||
* - Dedup by ``conversation_id`` (the SSE ``scope.id``): keep only
|
||||
* the newest pending event per conversation, so multiple paused
|
||||
* tools in one conversation collapse to one toast.
|
||||
* - Dismissal is per-event-id so a *new* pause of the same
|
||||
* conversation will re-surface (different event id).
|
||||
* - Clicking "Review" navigates to ``/c/<id>`` and dismisses.
|
||||
*/
|
||||
export default function ToolApprovalToast() {
|
||||
const dispatch = useDispatch();
|
||||
const navigate = useNavigate();
|
||||
const events = useSelector(selectRecentEvents);
|
||||
const dismissed = useSelector(selectDismissedToolApprovals);
|
||||
|
||||
// Pull the active conversation id off the route. Two route shapes
|
||||
// place a conversation in view: the bare ``/c/:conversationId`` and
|
||||
// the agent-scoped ``/agents/:agentId/c/:conversationId``. Hooks
|
||||
// are unconditional; the toast just respects whichever matches.
|
||||
//
|
||||
// ``/c/new`` is the conversation route's literal-string placeholder
|
||||
// for "unknown / not-yet-loaded conversation" (see the rewrite in
|
||||
// the conversation route). Treat it the same as no match — the
|
||||
// user isn't viewing any specific conversation yet, so an approval
|
||||
// toast for any conversation should still surface.
|
||||
const plainMatch = useMatch('/c/:conversationId');
|
||||
const agentMatch = useMatch('/agents/:agentId/c/:conversationId');
|
||||
const matchedConversationId =
|
||||
plainMatch?.params.conversationId ??
|
||||
agentMatch?.params.conversationId ??
|
||||
null;
|
||||
const currentConversationId =
|
||||
matchedConversationId === 'new' ? null : matchedConversationId;
|
||||
|
||||
// Conversation name lookup — best-effort. The slice's
|
||||
// ``preference.conversations.data`` is populated by
|
||||
// ``useDataInitializer`` once auth resolves; until then we fall
|
||||
// back to the conversation id.
|
||||
const conversations = useSelector(
|
||||
(state: RootState) => state.preference.conversations.data,
|
||||
);
|
||||
|
||||
const dismissedSet = new Set(dismissed);
|
||||
const pendingByConversation = new Map<
|
||||
string,
|
||||
{ eventId: string; conversationId: string }
|
||||
>();
|
||||
for (const event of events) {
|
||||
if (event.type !== 'tool.approval.required') continue;
|
||||
if (!event.id) continue; // can't dismiss without an id
|
||||
if (dismissedSet.has(event.id)) continue;
|
||||
const conversationId = event.scope?.id;
|
||||
if (!conversationId) continue;
|
||||
if (currentConversationId && conversationId === currentConversationId) {
|
||||
continue;
|
||||
}
|
||||
if (pendingByConversation.has(conversationId)) continue;
|
||||
// ``recentEvents`` is newest-first, so the first match per convId
|
||||
// is the most recent unhandled approval.
|
||||
pendingByConversation.set(conversationId, {
|
||||
eventId: event.id,
|
||||
conversationId,
|
||||
});
|
||||
}
|
||||
|
||||
if (pendingByConversation.size === 0) return null;
|
||||
|
||||
const conversationName = (conversationId: string): string => {
|
||||
const found = conversations?.find((c) => c.id === conversationId);
|
||||
return found?.name ?? 'Conversation';
|
||||
};
|
||||
|
||||
return (
|
||||
// Sit above ``UploadToast`` (which owns ``bottom-4 right-4``)
|
||||
// rather than overlapping it. ``bottom-24`` ≈ 96px clears one
|
||||
// standard-height upload toast; multiple in-flight uploads will
|
||||
// stack into the gap, at which point the approval toast still
|
||||
// floats on top via ``z-50``. Acceptable v1 layout — the two
|
||||
// surfaces are rarely competing.
|
||||
<div
|
||||
className="fixed right-4 bottom-24 z-50 flex max-w-md flex-col gap-2"
|
||||
onMouseDown={(e) => e.stopPropagation()}
|
||||
role="status"
|
||||
aria-live="polite"
|
||||
aria-atomic="true"
|
||||
>
|
||||
{Array.from(pendingByConversation.values()).map(
|
||||
({ eventId, conversationId }) => (
|
||||
<div
|
||||
key={eventId}
|
||||
className="border-border bg-card w-[271px] overflow-hidden rounded-2xl border shadow-[0px_24px_48px_0px_#00000029]"
|
||||
>
|
||||
<div className="bg-accent/50 dark:bg-muted flex items-center justify-between px-4 py-3">
|
||||
<h3 className="font-inter dark:text-foreground text-[14px] leading-[16.5px] font-medium text-black">
|
||||
Tool approval needed
|
||||
</h3>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => dispatch(dismissToolApproval(eventId))}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
aria-label="Dismiss"
|
||||
>
|
||||
<svg
|
||||
width="16"
|
||||
height="16"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-4 w-4"
|
||||
>
|
||||
<path
|
||||
d="M18 6L6 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M6 6L18 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
<div className="flex items-center justify-between gap-3 px-5 py-3">
|
||||
<div className="flex min-w-0 items-center gap-2">
|
||||
<img
|
||||
src={WarnIcon}
|
||||
alt=""
|
||||
className="h-5 w-5 shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
<p
|
||||
className="font-inter dark:text-muted-foreground max-w-[140px] truncate text-[13px] leading-[16.5px] font-normal text-black"
|
||||
title={conversationName(conversationId)}
|
||||
>
|
||||
{conversationName(conversationId)}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
dispatch(dismissToolApproval(eventId));
|
||||
navigate(`/c/${conversationId}`);
|
||||
}}
|
||||
className="rounded-full bg-[#7D54D1] px-3 py-1 text-[12px] font-medium text-white shadow-sm hover:bg-[#6a45b8]"
|
||||
>
|
||||
Review
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
),
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
frontend/src/notifications/dismissedPersistence.test.ts (new file, 71 lines)
@@ -0,0 +1,71 @@
|
||||
import { beforeEach, describe, expect, it } from 'vitest';
|
||||
|
||||
import {
|
||||
isDismissed,
|
||||
loadDismissed,
|
||||
saveDismissed,
|
||||
} from './dismissedPersistence';
|
||||
|
||||
const KEY = 'test:dismissed';
|
||||
const TTL = 24 * 60 * 60 * 1000;
|
||||
|
||||
describe('dismissedPersistence', () => {
|
||||
beforeEach(() => {
|
||||
localStorage.clear();
|
||||
});
|
||||
|
||||
it('saveDismissed + loadDismissed round-trips entries', () => {
|
||||
const now = Date.now();
|
||||
saveDismissed(KEY, [
|
||||
{ id: 'a', at: now },
|
||||
{ id: 'b', at: now - 1000 },
|
||||
]);
|
||||
const loaded = loadDismissed(KEY, TTL);
|
||||
expect(loaded).toEqual([
|
||||
{ id: 'a', at: now },
|
||||
{ id: 'b', at: now - 1000 },
|
||||
]);
|
||||
});
|
||||
|
||||
it('loadDismissed returns [] when key absent', () => {
|
||||
expect(loadDismissed(KEY, TTL)).toEqual([]);
|
||||
});
|
||||
|
||||
it('loadDismissed drops entries past the TTL cutoff', () => {
|
||||
const now = Date.now();
|
||||
saveDismissed(KEY, [
|
||||
{ id: 'fresh', at: now - 1000 },
|
||||
{ id: 'stale', at: now - (TTL + 1000) },
|
||||
]);
|
||||
const loaded = loadDismissed(KEY, TTL);
|
||||
expect(loaded.map((e) => e.id)).toEqual(['fresh']);
|
||||
});
|
||||
|
||||
it('loadDismissed returns [] on malformed JSON without throwing', () => {
|
||||
localStorage.setItem(KEY, '{not json');
|
||||
expect(loadDismissed(KEY, TTL)).toEqual([]);
|
||||
});
|
||||
|
||||
it('loadDismissed filters out entries with wrong shape', () => {
|
||||
const now = Date.now();
|
||||
localStorage.setItem(
|
||||
KEY,
|
||||
JSON.stringify([
|
||||
{ id: 'good', at: now },
|
||||
{ id: 123, at: now },
|
||||
{ id: 'bad-at', at: 'nope' },
|
||||
null,
|
||||
'string-entry',
|
||||
]),
|
||||
);
|
||||
const loaded = loadDismissed(KEY, TTL);
|
||||
expect(loaded.map((e) => e.id)).toEqual(['good']);
|
||||
});
|
||||
|
||||
it('isDismissed matches by id', () => {
|
||||
const list = [{ id: 'a', at: 1 }];
|
||||
expect(isDismissed(list, 'a')).toBe(true);
|
||||
expect(isDismissed(list, 'b')).toBe(false);
|
||||
expect(isDismissed([], 'a')).toBe(false);
|
||||
});
|
||||
});
|
||||
frontend/src/notifications/dismissedPersistence.ts (new file, 42 lines)
@@ -0,0 +1,42 @@
// Persisted dismissal lists for SSE-driven toasts. Without persistence
// the next page's backlog replay re-fires the events and pops the
// toast back. TTL matches the backend's stream retention.

export interface DismissedEntry {
  id: string;
  at: number;
}

export function loadDismissed(key: string, ttlMs: number): DismissedEntry[] {
  if (typeof localStorage === 'undefined') return [];
  try {
    const raw = localStorage.getItem(key);
    if (!raw) return [];
    const parsed = JSON.parse(raw);
    if (!Array.isArray(parsed)) return [];
    const cutoff = Date.now() - ttlMs;
    return parsed.filter(
      (e): e is DismissedEntry =>
        !!e &&
        typeof e === 'object' &&
        typeof (e as DismissedEntry).id === 'string' &&
        typeof (e as DismissedEntry).at === 'number' &&
        (e as DismissedEntry).at >= cutoff,
    );
  } catch {
    return [];
  }
}

export function saveDismissed(key: string, list: DismissedEntry[]): void {
  if (typeof localStorage === 'undefined') return;
  try {
    localStorage.setItem(key, JSON.stringify(list));
  } catch {
    // Best-effort: ignore quota / private-mode errors.
  }
}

export function isDismissed(list: DismissedEntry[], id: string): boolean {
  return list.some((e) => e.id === id);
}
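A short usage sketch, not part of the diff, tying the helpers above to the storage key and 24h TTL that notificationsSlice.ts uses below; the event id is invented:

import { isDismissed, loadDismissed, saveDismissed } from './dismissedPersistence';

const KEY = 'docsgpt:dismissedToolApprovals';
const TTL_MS = 24 * 60 * 60 * 1000;

// Hydrate, check, and persist one dismissal; stale entries fall out on load.
const entries = loadDismissed(KEY, TTL_MS);
if (!isDismissed(entries, 'evt-123')) {
  saveDismissed(KEY, [...entries, { id: 'evt-123', at: Date.now() }]);
}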
frontend/src/notifications/notificationsSlice.test.ts (new file, 109 lines)
@@ -0,0 +1,109 @@
|
||||
import { describe, expect, it, vi, afterEach } from 'vitest';
|
||||
|
||||
import reducer, {
|
||||
dismissToolApproval,
|
||||
sseEventReceived,
|
||||
sseLastEventIdReset,
|
||||
type SSEEvent,
|
||||
} from './notificationsSlice';
|
||||
|
||||
const baseEvent = (overrides: Partial<SSEEvent> = {}): SSEEvent => ({
|
||||
id: 'evt-1',
|
||||
type: 'source.ingest.progress',
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const seedState = () => reducer(undefined, { type: '@@INIT' });
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
describe('sseEventReceived', () => {
|
||||
it('dedupes by id when the same envelope arrives twice', () => {
|
||||
let state = seedState();
|
||||
state = reducer(state, sseEventReceived(baseEvent({ id: 'a' })));
|
||||
state = reducer(state, sseEventReceived(baseEvent({ id: 'a' })));
|
||||
expect(state.recentEvents).toHaveLength(1);
|
||||
expect(state.recentEvents[0].id).toBe('a');
|
||||
});
|
||||
|
||||
it('does not update lastEventId for envelopes without an id (backlog.truncated)', () => {
|
||||
let state = seedState();
|
||||
state = reducer(state, sseEventReceived(baseEvent({ id: 'cursor-1' })));
|
||||
expect(state.lastEventId).toBe('cursor-1');
|
||||
state = reducer(
|
||||
state,
|
||||
sseEventReceived({ type: 'backlog.truncated' } as SSEEvent),
|
||||
);
|
||||
expect(state.lastEventId).toBe('cursor-1');
|
||||
});
|
||||
|
||||
it('caps recentEvents at 100 entries (oldest evicted)', () => {
|
||||
let state = seedState();
|
||||
for (let i = 0; i < 105; i += 1) {
|
||||
state = reducer(state, sseEventReceived(baseEvent({ id: `e-${i}` })));
|
||||
}
|
||||
expect(state.recentEvents).toHaveLength(100);
|
||||
// Newest first.
|
||||
expect(state.recentEvents[0].id).toBe('e-104');
|
||||
expect(state.recentEvents[state.recentEvents.length - 1].id).toBe('e-5');
|
||||
});
|
||||
});
|
||||
|
||||
describe('sseLastEventIdReset', () => {
|
||||
it('clears lastEventId back to null', () => {
|
||||
let state = seedState();
|
||||
state = reducer(state, sseEventReceived(baseEvent({ id: 'x' })));
|
||||
expect(state.lastEventId).toBe('x');
|
||||
state = reducer(state, sseLastEventIdReset());
|
||||
expect(state.lastEventId).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('dismissToolApproval', () => {
|
||||
it('dedupes by id and refreshes the timestamp', () => {
|
||||
vi.useFakeTimers();
|
||||
vi.setSystemTime(new Date('2026-01-01T00:00:00Z'));
|
||||
let state = seedState();
|
||||
state = reducer(state, dismissToolApproval('approval-1'));
|
||||
const firstAt = state.dismissedToolApprovals[0].at;
|
||||
|
||||
vi.setSystemTime(new Date('2026-01-01T00:05:00Z'));
|
||||
state = reducer(state, dismissToolApproval('approval-1'));
|
||||
expect(state.dismissedToolApprovals).toHaveLength(1);
|
||||
expect(state.dismissedToolApprovals[0].at).toBeGreaterThan(firstAt);
|
||||
});
|
||||
|
||||
it('evicts entries older than the 24h TTL', () => {
|
||||
vi.useFakeTimers();
|
||||
vi.setSystemTime(new Date('2026-01-01T00:00:00Z'));
|
||||
let state = seedState();
|
||||
state = reducer(state, dismissToolApproval('old-1'));
|
||||
|
||||
// Move past the 24h TTL window.
|
||||
vi.setSystemTime(new Date('2026-01-02T00:00:01Z'));
|
||||
state = reducer(state, dismissToolApproval('fresh-1'));
|
||||
const ids = state.dismissedToolApprovals.map((entry) => entry.id);
|
||||
expect(ids).toEqual(['fresh-1']);
|
||||
});
|
||||
|
||||
it('applies the 200-entry cap as a backstop after TTL filtering', () => {
|
||||
vi.useFakeTimers();
|
||||
vi.setSystemTime(new Date('2026-01-01T00:00:00Z'));
|
||||
let state = seedState();
|
||||
// Insert 205 distinct ids within the TTL window.
|
||||
for (let i = 0; i < 205; i += 1) {
|
||||
// Advance time slightly so the at-values are distinct but well
|
||||
// inside the 24h TTL.
|
||||
vi.setSystemTime(
|
||||
new Date(`2026-01-01T00:00:${(i % 60).toString().padStart(2, '0')}Z`),
|
||||
);
|
||||
state = reducer(state, dismissToolApproval(`id-${i}`));
|
||||
}
|
||||
expect(state.dismissedToolApprovals).toHaveLength(200);
|
||||
// The 200-cap keeps the most recently pushed ids.
|
||||
expect(state.dismissedToolApprovals[0].id).toBe('id-5');
|
||||
expect(state.dismissedToolApprovals[199].id).toBe('id-204');
|
||||
});
|
||||
});
|
||||
frontend/src/notifications/notificationsSlice.ts (new file, 200 lines)
@@ -0,0 +1,200 @@
|
||||
import { createSelector, createSlice, PayloadAction } from '@reduxjs/toolkit';

import { RootState } from '../store';
import { loadDismissed, saveDismissed } from './dismissedPersistence';

const DISMISSED_TOOL_APPROVALS_STORAGE_KEY = 'docsgpt:dismissedToolApprovals';

/**
 * Envelope shape published by the backend SSE endpoint
 * (`application/events/publisher.py`). Mirrors the wire JSON 1:1.
 */
export interface SSEEvent<P = Record<string, unknown>> {
id?: string;
type: string;
ts?: string;
user_id?: string;
topic?: string;
scope?: { kind: string; id: string };
payload?: P;
}

/**
 * Connection-health state machine the rest of the app reads via
 * ``selectPushChannelHealthy`` to gate polling-fallback behaviour.
 *
 * - ``connecting`` — initial fetch in flight, or reconnecting after drop.
 * - ``healthy`` — at least one event (data or keepalive) received and
 *   the stream has been open >2s.
 * - ``unhealthy`` — last attempt failed or has been dropped without a
 *   successful re-establish; fall back to polling.
 */
export type PushHealth = 'connecting' | 'healthy' | 'unhealthy';

interface NotificationsState {
health: PushHealth;
/** Most-recent server-issued id; sent back as ``Last-Event-ID`` on reconnect. */
lastEventId: string | null;
/** Bounded ring of recent events for the in-app notifications surface. */
recentEvents: SSEEvent[];
/**
 * Wallclock ms of last received data-bearing event. SSE keepalives
 * are comment lines (no ``id:``/``data:``) and do NOT update this —
 * they're filtered out at the parser level.
 */
lastEventReceivedAt: number | null;
/**
 * Event ids of ``tool.approval.required`` notifications the user
 * dismissed (close button or by navigating into the conversation),
 * each tagged with the wallclock ms at which it was dismissed.
 * Keyed by the SSE event id so a *new* approval for the same
 * conversation re-surfaces; the dismissal only suppresses the one
 * specific paused-tool prompt.
 *
 * Entries are evicted by TTL first (anything older than
 * ``DISMISSED_TOOL_APPROVALS_TTL_MS``), then by FIFO cap. The TTL
 * matters because a pure FIFO with a small cap can evict a *still-
 * pending* approval id before the user acts on it — re-popping the
 * toast on the next dispatch. The 24h ceiling is longer than any
 * plausible approval-pending window.
 */
dismissedToolApprovals: Array<{ id: string; at: number }>;
}

const RECENT_EVENTS_CAP = 100;
const DISMISSED_TOOL_APPROVALS_CAP = 200;
const DISMISSED_TOOL_APPROVALS_TTL_MS = 24 * 60 * 60 * 1000;

const initialState: NotificationsState = {
health: 'connecting',
lastEventId: null,
recentEvents: [],
lastEventReceivedAt: null,
// Hydrate from localStorage: SSE backlog replay re-delivers the
// originating ``tool.approval.required`` envelopes on reload.
dismissedToolApprovals: loadDismissed(
DISMISSED_TOOL_APPROVALS_STORAGE_KEY,
DISMISSED_TOOL_APPROVALS_TTL_MS,
),
};

export const notificationsSlice = createSlice({
name: 'notifications',
initialState,
reducers: {
sseEventReceived: (state, action: PayloadAction<SSEEvent>) => {
const e = action.payload;
// Drop immediate duplicates. Snapshot replay + live tail can
// both deliver the same id when the live pubsub frame and the
// replay XRANGE overlap, and consumers that walk
// ``recentEvents`` (FileTree, ConnectorTree, MCPServerModal,
// ToolApprovalToast) would otherwise act on the same envelope
// twice. The route's dedup floor catches the common case; this
// is a belt-and-suspenders for in-tab StrictMode double-mounts
// and any envelope that slips through with the same id.
if (e.id && state.recentEvents[0]?.id === e.id) return;
state.recentEvents.unshift(e);
if (state.recentEvents.length > RECENT_EVENTS_CAP) {
state.recentEvents.length = RECENT_EVENTS_CAP;
}
if (e.id) state.lastEventId = e.id;
state.lastEventReceivedAt = Date.now();
},
sseHealthChanged: (state, action: PayloadAction<PushHealth>) => {
state.health = action.payload;
},
/**
 * Lifecycle helper used by reconnect bookkeeping — does not record
 * an event, just stamps "we heard from the server" so the polling
 * fallback stays disabled while keepalives arrive.
 */
sseHeartbeat: (state) => {
state.lastEventReceivedAt = Date.now();
},
sseLastEventIdReset: (state) => {
// Backlog truncated — drop the cursor so the next reconnect
// doesn't try to resume past the retained window.
state.lastEventId = null;
},
/**
 * Advance the cursor without recording an event. Fired for every
 * id-bearing frame including keepalives and id-only comments,
 * so the slice cursor tracks the freshest id the wire has
 * delivered even when no envelope was dispatched. Without this,
 * ``lastEventId`` would only advance via ``sseEventReceived`` and
 * a long quiet period of keepalives would leave it stale —
 * eventually re-snapshotting the same backlog on each reconnect
 * and exhausting the per-user replay budget.
 */
sseLastEventIdAdvanced: (state, action: PayloadAction<string>) => {
state.lastEventId = action.payload;
},
clearRecentEvents: (state) => {
state.recentEvents = [];
},
/**
 * Suppress a ``tool.approval.required`` notification by the SSE
 * event id. The toast surface filters dismissed ids out; a *new*
 * approval event for the same conversation has a different id
 * and re-surfaces, which is the desired UX (each pause is its
 * own decision).
 */
dismissToolApproval: (state, action: PayloadAction<string>) => {
const id = action.payload;
const now = Date.now();
// Evict expired entries first so the TTL — not the FIFO cap —
// governs when stale ids drop, keeping still-pending approvals
// suppressed.
const cutoff = now - DISMISSED_TOOL_APPROVALS_TTL_MS;
state.dismissedToolApprovals = state.dismissedToolApprovals.filter(
(entry) => entry.at >= cutoff && entry.id !== id,
);
state.dismissedToolApprovals.push({ id, at: now });
if (state.dismissedToolApprovals.length > DISMISSED_TOOL_APPROVALS_CAP) {
state.dismissedToolApprovals = state.dismissedToolApprovals.slice(
-DISMISSED_TOOL_APPROVALS_CAP,
);
}
saveDismissed(
DISMISSED_TOOL_APPROVALS_STORAGE_KEY,
state.dismissedToolApprovals,
);
},
},
});

export const {
sseEventReceived,
sseHealthChanged,
sseHeartbeat,
sseLastEventIdReset,
sseLastEventIdAdvanced,
clearRecentEvents,
dismissToolApproval,
} = notificationsSlice.actions;

export const selectSseHealth = (state: RootState): PushHealth =>
state.notifications.health;

export const selectPushChannelHealthy = (state: RootState): boolean =>
state.notifications.health === 'healthy';

export const selectLastEventId = (state: RootState): string | null =>
state.notifications.lastEventId;

export const selectLastEventReceivedAt = (state: RootState): number | null =>
state.notifications.lastEventReceivedAt;

export const selectRecentEvents = (state: RootState): SSEEvent[] =>
state.notifications.recentEvents;

// Memoised so ``useSelector`` consumers don't re-render on every
// unrelated ``notifications`` state change — the underlying ``{id,at}``
// array only changes when ``dismissToolApproval`` runs, but the
// projected ``.map`` would otherwise return a fresh array each call.
export const selectDismissedToolApprovals = createSelector(
(state: RootState) => state.notifications.dismissedToolApprovals,
(entries) => entries.map((entry) => entry.id),
);

export default notificationsSlice.reducer;
@@ -85,49 +85,6 @@ export async function getConversations(
}
}

export async function searchConversations(
query: string,
token: string | null,
limit = 30,
): Promise<GetConversationsResult> {
try {
const response = await conversationService.searchConversations(
query,
token,
limit,
);

if (!response.ok) {
console.error('Error searching conversations:', response.statusText);
return { data: null, loading: false };
}

const rawData: unknown = await response.json();
if (!Array.isArray(rawData)) {
console.error(
'Invalid data format received from API: Expected an array.',
rawData,
);
return { data: null, loading: false };
}

const conversations: ConversationSummary[] = rawData.map((item: any) => ({
id: item.id,
name: item.name,
agent_id: item.agent_id ?? null,
match_field: item.match_field ?? null,
match_snippet: item.match_snippet ?? null,
}));
return { data: conversations, loading: false };
} catch (error) {
console.error(
'An unexpected error occurred while searching conversations:',
error,
);
return { data: null, loading: false };
}
}

export function getLocalApiKey(): string | null {
const key = localStorage.getItem('DocsGPTApiKey');
return key;
@@ -2,8 +2,6 @@ export type ConversationSummary = {
id: string;
name: string;
agent_id: string | null;
match_field?: 'name' | 'prompt' | 'response' | null;
match_snippet?: string | null;
};

export type GetConversationsResult = {
@@ -4,6 +4,7 @@ import agentPreviewReducer from './agents/agentPreviewSlice';
import workflowPreviewReducer from './agents/workflow/workflowPreviewSlice';
import { conversationSlice } from './conversation/conversationSlice';
import { sharedConversationSlice } from './conversation/sharedConversationSlice';
import notificationsReducer from './notifications/notificationsSlice';
import { getStoredRecentDocs } from './preferences/preferenceApi';
import {
Preference,
@@ -67,6 +68,7 @@ const store = configureStore({
upload: uploadReducer,
agentPreview: agentPreviewReducer,
workflowPreview: workflowPreviewReducer,
notifications: notificationsReducer,
},
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware().concat(prefListenerMiddleware.middleware),
@@ -2,9 +2,9 @@ import { useCallback, useState } from 'react';
import { nanoid } from '@reduxjs/toolkit';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { useDispatch, useSelector } from 'react-redux';
import { useDispatch, useSelector, useStore } from 'react-redux';

import userService from '../api/services/userService';
import type { RootState } from '../store';
import { getSessionToken } from '../utils/providerUtils';
import Dropdown from '../components/Dropdown';
import Input from '../components/Input';
@@ -298,6 +298,7 @@ function Upload({

const { t } = useTranslation();
const dispatch = useDispatch();
const store = useStore<RootState>();

const ingestorOptions: IngestorOption[] = IngestorFormSchemas.filter(
(schema) => (schema.validate ? schema.validate() : true),
@@ -334,110 +335,113 @@ function Upload({
[dispatch],
);

/**
 * Wait for the source.ingest.* SSE pipeline to flip this task into a
 * terminal state, then run the post-completion side effects: refresh
 * the global source list, auto-select the new doc, and fire the
 * caller's ``onSuccessfulUpload`` hook. The slice's extraReducer
 * (uploadSlice.ts) is the sole driver of the task's status; we only
 * subscribe so the side effects can fire after the modal has closed.
 */
const trackTraining = useCallback(
(backendTaskId: string, clientTaskId: string) => {
let timeoutId: number | null = null;
(clientTaskId: string) => {
let handled = false;

const poll = () => {
userService
.getTaskStatus(backendTaskId, null)
.then((response) => response.json())
.then(async (data) => {
if (!data.success && data.message) {
if (timeoutId !== null) {
clearTimeout(timeoutId);
timeoutId = null;
}
handleTaskFailure(clientTaskId, data.message);
return;
}

if (data.status === 'SUCCESS') {
if (timeoutId !== null) {
clearTimeout(timeoutId);
timeoutId = null;
}

const docs = await getDocs(token);
dispatch(setSourceDocs(docs));

if (Array.isArray(docs)) {
const existingDocIds = new Set(
(Array.isArray(sourceDocs) ? sourceDocs : [])
.map((doc: Doc) => doc?.id)
.filter((id): id is string => Boolean(id)),
);
const newDoc = docs.find(
(doc: Doc) => doc.id && !existingDocIds.has(doc.id),
);
if (newDoc) {
// If only one doc is selected, replace it completely
// If multiple docs are selected, append the new doc
if (selectedDocs.length === 1) {
dispatch(setSelectedDocs([newDoc]));
} else {
dispatch(setSelectedDocs([...selectedDocs, newDoc]));
}
const handleTerminal = (status: 'completed' | 'failed') => {
if (handled) return;
handled = true;
if (status !== 'completed') return;
getDocs(token)
.then((docs) => {
dispatch(setSourceDocs(docs));
if (Array.isArray(docs)) {
const existingDocIds = new Set(
(Array.isArray(sourceDocs) ? sourceDocs : [])
.map((doc: Doc) => doc?.id)
.filter((id): id is string => Boolean(id)),
);
const newDoc = docs.find(
(doc: Doc) => doc.id && !existingDocIds.has(doc.id),
);
if (newDoc) {
// If only one doc is selected, replace it completely
// If multiple docs are selected, append the new doc
if (selectedDocs.length === 1) {
dispatch(setSelectedDocs([newDoc]));
} else {
dispatch(setSelectedDocs([...selectedDocs, newDoc]));
}
}

if (data.result?.limited) {
dispatch(
updateUploadTask({
id: clientTaskId,
updates: {
status: 'failed',
progress: 100,
errorMessage: t('modals.uploadDoc.progress.tokenLimit'),
},
}),
);
} else {
dispatch(
updateUploadTask({
id: clientTaskId,
updates: {
status: 'completed',
progress: 100,
errorMessage: undefined,
},
}),
);
onSuccessfulUpload?.();
}
} else if (data.status === 'FAILURE') {
if (timeoutId !== null) {
clearTimeout(timeoutId);
timeoutId = null;
}
handleTaskFailure(clientTaskId, data.result?.message);
} else if (data.status === 'PROGRESS') {
dispatch(
updateUploadTask({
id: clientTaskId,
updates: {
status: 'training',
progress: Math.min(100, data.result?.current ?? 0),
},
}),
);
timeoutId = window.setTimeout(poll, 5000);
} else {
timeoutId = window.setTimeout(poll, 5000);
}
onSuccessfulUpload?.();
})
.catch((error) => {
if (timeoutId !== null) {
clearTimeout(timeoutId);
timeoutId = null;
}
handleTaskFailure(clientTaskId, error?.message);
.catch((err) => {
console.error(
'SSE-driven post-completion source-list refresh failed:',
err,
);
});
};

timeoutId = window.setTimeout(poll, 3000);
const check = () => {
const state = store.getState();
const task = state.upload.tasks.find((t) => t.id === clientTaskId);
if (!task) return false;
if (task.status === 'completed' || task.status === 'failed') {
handleTerminal(task.status);
return true;
}
// Recover from the race where the terminal SSE landed before
// ``xhr.onload`` populated ``task.sourceId`` — the slice
// silently drops such events (no task to match by sourceId).
// Mirrors ConnectorTree/FileTree's ``recentEvents`` walk.
if (task.sourceId) {
for (const event of state.notifications.recentEvents) {
if (event.scope?.id !== task.sourceId) continue;
if (event.type === 'source.ingest.completed') {
handleTerminal('completed');
return true;
}
if (event.type === 'source.ingest.failed') {
handleTerminal('failed');
return true;
}
}
}
return false;
};

if (check()) return;
const MAX_WAIT_MS = 5 * 60_000;
let unsubscribe: (() => void) | null = null;
const timer = window.setTimeout(() => {
unsubscribe?.();
if (!handled) {
handled = true;
console.warn(
'trackTraining: timed out waiting for terminal SSE',
clientTaskId,
);
dispatch(
updateUploadTask({
id: clientTaskId,
updates: {
status: 'failed',
errorMessage:
'Timed out waiting for ingest completion. The ingest may still be running — please refresh to check.',
},
}),
);
}
}, MAX_WAIT_MS);
unsubscribe = store.subscribe(() => {
if (check()) {
window.clearTimeout(timer);
unsubscribe?.();
}
});
},
[dispatch, handleTaskFailure, onSuccessfulUpload, sourceDocs, t, token],
[dispatch, onSuccessfulUpload, selectedDocs, sourceDocs, store, token],
);

const onDrop = useCallback(
@@ -499,19 +503,23 @@ function Upload({
xhr.onload = () => {
if (xhr.status >= 200 && xhr.status < 300) {
try {
const parsed = JSON.parse(xhr.responseText) as { task_id?: string };
const parsed = JSON.parse(xhr.responseText) as {
task_id?: string;
source_id?: string;
};
if (parsed.task_id) {
dispatch(
updateUploadTask({
id: clientTaskId,
updates: {
taskId: parsed.task_id,
sourceId: parsed.source_id,
status: 'training',
progress: 0,
},
}),
);
trackTraining(parsed.task_id, clientTaskId);
trackTraining(clientTaskId);
} else {
dispatch(
updateUploadTask({
@@ -627,19 +635,23 @@ function Upload({
xhr.onload = () => {
if (xhr.status >= 200 && xhr.status < 300) {
try {
const response = JSON.parse(xhr.responseText) as { task_id?: string };
const response = JSON.parse(xhr.responseText) as {
task_id?: string;
source_id?: string;
};
if (response.task_id) {
dispatch(
updateUploadTask({
id: clientTaskId,
updates: {
taskId: response.task_id,
sourceId: response.source_id,
status: 'training',
progress: 0,
},
}),
);
trackTraining(response.task_id, clientTaskId);
trackTraining(clientTaskId);
} else {
dispatch(
updateUploadTask({
458  frontend/src/upload/uploadSlice.test.ts  Normal file
@@ -0,0 +1,458 @@
import { beforeEach, describe, expect, it } from 'vitest';

import notificationsReducer, {
sseEventReceived,
type SSEEvent,
} from '../notifications/notificationsSlice';
import reducer, {
addAttachment,
addUploadTask,
dismissUploadTask,
updateAttachment,
updateUploadTask,
type Attachment,
type UploadTask,
} from './uploadSlice';

const SOURCE_ID = 'src-1';

const makeTask = (overrides: Partial<UploadTask> = {}): UploadTask => ({
id: 't-1',
fileName: 'doc.pdf',
progress: 0,
status: 'preparing',
sourceId: SOURCE_ID,
...overrides,
});

const stateWithTask = (task: UploadTask) =>
reducer(undefined, addUploadTask(task));

const ingest = (
type: string,
payload: Record<string, unknown> = {},
scopeId = SOURCE_ID,
) =>
sseEventReceived({
id: `id-${type}`,
type,
scope: { kind: 'source', id: scopeId },
payload,
});

describe('dismissal persistence across reload', () => {
const STORAGE_KEY = 'docsgpt:dismissedUploadSourceIds';
const SRC = 'src-persisted';

// Mirrors initialState as if hydrated from localStorage.
const seedState = (entries: { id: string; at: number }[]) =>
reducer(
{ attachments: [], tasks: [], dismissedSourceIds: entries },
{ type: '@@INIT' },
);

beforeEach(() => {
localStorage.clear();
});

it('dismissUploadTask writes the task sourceId to localStorage', () => {
let state = stateWithTask(
makeTask({ id: 't-dismiss', sourceId: SRC, status: 'completed' }),
);
state = reducer(state, dismissUploadTask('t-dismiss'));
expect(state.tasks[0].dismissed).toBe(true);
expect(state.dismissedSourceIds).toHaveLength(1);
expect(state.dismissedSourceIds[0].id).toBe(SRC);
const persisted = JSON.parse(localStorage.getItem(STORAGE_KEY) ?? '[]');
expect(persisted).toHaveLength(1);
expect(persisted[0].id).toBe(SRC);
});

it('skips persistence when the task has no sourceId yet', () => {
let state = stateWithTask(
makeTask({ id: 't-no-src', sourceId: undefined, status: 'preparing' }),
);
state = reducer(state, dismissUploadTask('t-no-src'));
expect(state.tasks[0].dismissed).toBe(true);
expect(state.dismissedSourceIds).toHaveLength(0);
expect(localStorage.getItem(STORAGE_KEY)).toBeNull();
});

it('auto-create on refresh marks the task dismissed when sourceId is in the persisted list', () => {
const state = seedState([{ id: SRC, at: Date.now() }]);
const next = reducer(
state,
ingest('source.ingest.progress', { current: 40 }, SRC),
);
const recovered = next.tasks.find((t) => t.sourceId === SRC);
expect(recovered).toBeDefined();
expect(recovered!.dismissed).toBe(true);
expect(recovered!.progress).toBe(40);
});

it('terminal events do NOT un-dismiss a task whose sourceId was previously dismissed', () => {
const state = seedState([{ id: SRC, at: Date.now() }]);
let next = reducer(state, ingest('source.ingest.queued', {}, SRC));
expect(next.tasks[0].dismissed).toBe(true);
next = reducer(next, ingest('source.ingest.completed', {}, SRC));
// Even though `wasTerminal` is false (just transitioned), the
// persisted dismissal keeps the toast closed.
expect(next.tasks[0].status).toBe('completed');
expect(next.tasks[0].dismissed).toBe(true);
});

it('updateUploadTask does not un-dismiss on terminal when sourceId was previously dismissed', () => {
const state = reducer(
{
attachments: [],
tasks: [
makeTask({
id: 't-1',
sourceId: SRC,
status: 'training',
dismissed: true,
}),
],
dismissedSourceIds: [{ id: SRC, at: Date.now() }],
},
{ type: '@@INIT' },
);
const next = reducer(
state,
updateUploadTask({
id: 't-1',
updates: { status: 'completed' },
}),
);
expect(next.tasks[0].status).toBe('completed');
expect(next.tasks[0].dismissed).toBe(true);
});

it('un-dismisses normally for sourceIds NOT in the persisted list', () => {
const state = reducer(undefined, { type: '@@INIT' });
const populated = reducer(
state,
addUploadTask(
makeTask({
id: 't-fresh',
sourceId: 'src-fresh',
status: 'training',
dismissed: true,
}),
),
);
const next = reducer(
populated,
updateUploadTask({ id: 't-fresh', updates: { status: 'completed' } }),
);
expect(next.tasks[0].dismissed).toBe(false);
});
});

describe('refresh recovery — auto-create from SSE when no task matches', () => {
it('creates a task on queued for an unknown sourceId', () => {
let state = reducer(undefined, addUploadTask(makeTask({ id: 'other' })));
state = reducer(
state,
ingest(
'source.ingest.queued',
{ filename: 'crawler.json', job_name: 'docs' },
'src-recovery',
),
);
const recovered = state.tasks.find((t) => t.sourceId === 'src-recovery');
expect(recovered).toBeDefined();
expect(recovered!.status).toBe('training');
expect(recovered!.fileName).toBe('crawler.json');
expect(recovered!.dismissed).toBe(false);
});

it('creates a task on progress for an unknown sourceId and applies the percent', () => {
let state: ReturnType<typeof reducer> = reducer(undefined, {
type: '@@INIT',
});
state = reducer(
state,
ingest(
'source.ingest.progress',
{ current: 55, total: 10, embedded_chunks: 5 },
'src-progress',
),
);
expect(state.tasks).toHaveLength(1);
expect(state.tasks[0].sourceId).toBe('src-progress');
expect(state.tasks[0].status).toBe('training');
expect(state.tasks[0].progress).toBe(55);
});

it('does NOT create a task on completed for an unknown sourceId (avoids backlog toast spam)', () => {
let state: ReturnType<typeof reducer> = reducer(undefined, {
type: '@@INIT',
});
state = reducer(
state,
ingest('source.ingest.completed', { tokens: 16 }, 'src-stale'),
);
expect(state.tasks).toHaveLength(0);
});

it('creates a task on failed for an unknown sourceId so error surfaces post-refresh', () => {
let state: ReturnType<typeof reducer> = reducer(undefined, {
type: '@@INIT',
});
state = reducer(
state,
ingest(
'source.ingest.failed',
{ error: 'embed worker died' },
'src-failed',
),
);
expect(state.tasks).toHaveLength(1);
expect(state.tasks[0].status).toBe('failed');
expect(state.tasks[0].errorMessage).toBe('embed worker died');
expect(state.tasks[0].dismissed).toBe(false);
});

it('falls back to job_name then sourceId when filename is absent', () => {
let state: ReturnType<typeof reducer> = reducer(undefined, {
type: '@@INIT',
});
state = reducer(
state,
ingest('source.ingest.queued', { job_name: 'docs-docs' }, 'src-jn'),
);
expect(state.tasks[0].fileName).toBe('docs-docs');
state = reducer(state, ingest('source.ingest.queued', {}, 'src-only'));
expect(state.tasks.find((t) => t.sourceId === 'src-only')!.fileName).toBe(
'src-only',
);
});

it('subsequent progress events update the recovered task in place', () => {
let state: ReturnType<typeof reducer> = reducer(undefined, {
type: '@@INIT',
});
state = reducer(
state,
ingest('source.ingest.queued', { filename: 'a.txt' }, 'src-flow'),
);
state = reducer(
state,
ingest('source.ingest.progress', { current: 30 }, 'src-flow'),
);
state = reducer(
state,
ingest('source.ingest.progress', { current: 80 }, 'src-flow'),
);
state = reducer(
state,
ingest('source.ingest.completed', { tokens: 100 }, 'src-flow'),
);
expect(state.tasks).toHaveLength(1);
expect(state.tasks[0].status).toBe('completed');
expect(state.tasks[0].progress).toBe(100);
});
});

describe('source.ingest.queued', () => {
it('does not regress a task already in training', () => {
let state = stateWithTask(makeTask({ status: 'training', progress: 42 }));
state = reducer(state, ingest('source.ingest.queued'));
expect(state.tasks[0].status).toBe('training');
expect(state.tasks[0].progress).toBe(42);
});

it('transitions preparing -> training and zeros progress', () => {
let state = stateWithTask(makeTask({ status: 'preparing', progress: 12 }));
state = reducer(state, ingest('source.ingest.queued'));
expect(state.tasks[0].status).toBe('training');
expect(state.tasks[0].progress).toBe(0);
});
});

describe('source.ingest.progress', () => {
it('clamps to 0..100 and is monotonic', () => {
let state = stateWithTask(makeTask({ status: 'training' }));
state = reducer(state, ingest('source.ingest.progress', { current: 30 }));
expect(state.tasks[0].progress).toBe(30);
// Higher value advances.
state = reducer(state, ingest('source.ingest.progress', { current: 150 }));
expect(state.tasks[0].progress).toBe(100);
// Lower value never regresses.
state = reducer(state, ingest('source.ingest.progress', { current: 50 }));
expect(state.tasks[0].progress).toBe(100);
// Negative gets clamped at 0 but still cannot regress already-higher.
state = reducer(state, ingest('source.ingest.progress', { current: -10 }));
expect(state.tasks[0].progress).toBe(100);
});
});

describe('source.ingest.completed', () => {
it('transitions training -> completed and sets dismissed=false', () => {
let state = stateWithTask(
makeTask({ status: 'training', dismissed: true }),
);
state = reducer(state, ingest('source.ingest.completed'));
expect(state.tasks[0].status).toBe('completed');
expect(state.tasks[0].progress).toBe(100);
expect(state.tasks[0].dismissed).toBe(false);
expect(state.tasks[0].tokenLimitReached).toBe(false);
});

it('with limited=true transitions training -> failed and flags tokenLimitReached', () => {
let state = stateWithTask(
makeTask({ status: 'training', dismissed: true }),
);
state = reducer(
state,
ingest('source.ingest.completed', { limited: true }),
);
expect(state.tasks[0].status).toBe('failed');
expect(state.tasks[0].progress).toBe(100);
expect(state.tasks[0].tokenLimitReached).toBe(true);
expect(state.tasks[0].dismissed).toBe(false);
});

it('does not re-un-dismiss when a duplicate terminal event arrives', () => {
// Initial terminal — wasTerminal=false, dismissed flipped to false.
let state = stateWithTask(makeTask({ status: 'training' }));
state = reducer(state, ingest('source.ingest.completed'));
expect(state.tasks[0].dismissed).toBe(false);

// User dismisses the toast manually.
state = {
...state,
tasks: state.tasks.map((t) => ({ ...t, dismissed: true })),
};

// Duplicate terminal envelope (StrictMode remount, reconnect overlap).
state = reducer(state, ingest('source.ingest.completed'));
expect(state.tasks[0].status).toBe('completed');
expect(state.tasks[0].dismissed).toBe(true);
});
});

describe('source.ingest.failed', () => {
it('transitions training -> failed with the error message', () => {
let state = stateWithTask(makeTask({ status: 'training' }));
state = reducer(
state,
ingest('source.ingest.failed', { error: 'parser blew up' }),
);
expect(state.tasks[0].status).toBe('failed');
expect(state.tasks[0].errorMessage).toBe('parser blew up');
expect(state.tasks[0].dismissed).toBe(false);
});

it('does not re-un-dismiss when a duplicate failed event arrives', () => {
let state = stateWithTask(makeTask({ status: 'training' }));
state = reducer(state, ingest('source.ingest.failed', { error: 'oops' }));
state = {
...state,
tasks: state.tasks.map((t) => ({ ...t, dismissed: true })),
};
state = reducer(state, ingest('source.ingest.failed', { error: 'oops' }));
expect(state.tasks[0].dismissed).toBe(true);
});
});

describe('attachment race recovery', () => {
const ATTACHMENT_ID = 'att-1';
const CLIENT_ID = 'ui-1';

const attEvent = (
type: string,
payload: Record<string, unknown> = {},
): SSEEvent => ({
id: `id-${type}`,
type,
scope: { kind: 'attachment', id: ATTACHMENT_ID },
payload,
});

const makeAttachment = (overrides: Partial<Attachment> = {}): Attachment => ({
id: CLIENT_ID,
fileName: 'small.pdf',
progress: 10,
status: 'processing',
taskId: 'celery-1',
attachmentId: ATTACHMENT_ID,
...overrides,
});

it('drops attachment.completed silently when no row matches attachmentId', () => {
const state = reducer(
undefined,
sseEventReceived(attEvent('attachment.completed', { token_count: 42 })),
);
expect(state.attachments).toHaveLength(0);
});

it('lands the terminal envelope in notifications.recentEvents for later recovery', () => {
const notifState = notificationsReducer(
undefined,
sseEventReceived(attEvent('attachment.completed', { token_count: 7 })),
);
expect(notifState.recentEvents).toHaveLength(1);
expect(notifState.recentEvents[0].scope?.id).toBe(ATTACHMENT_ID);
expect(notifState.recentEvents[0].type).toBe('attachment.completed');
});

it('reconciler dispatch flips the row to completed after the late row addition', () => {
// Full race: terminal SSE first, then xhr.onload adds the row,
// then trackAttachment.check() walks recentEvents and dispatches.
const terminal = attEvent('attachment.completed', { token_count: 99 });
const notifState = notificationsReducer(
undefined,
sseEventReceived(terminal),
);

let uploadState = reducer(undefined, addAttachment(makeAttachment()));
expect(uploadState.attachments[0].status).toBe('processing');

const found = notifState.recentEvents.find(
(e) => e.scope?.id === ATTACHMENT_ID && e.type === 'attachment.completed',
);
expect(found).toBeDefined();
const tokenCount = Number(
(found?.payload as { token_count?: unknown })?.token_count,
);
uploadState = reducer(
uploadState,
updateAttachment({
id: CLIENT_ID,
updates: {
status: 'completed',
progress: 100,
...(Number.isFinite(tokenCount) ? { token_count: tokenCount } : {}),
},
}),
);

expect(uploadState.attachments[0].status).toBe('completed');
expect(uploadState.attachments[0].progress).toBe(100);
expect(uploadState.attachments[0].token_count).toBe(99);
});

it('attachment.failed envelope can drive a stuck row to failed via reconciler', () => {
const failed = attEvent('attachment.failed', { error: 'docling boom' });
const notifState = notificationsReducer(
undefined,
sseEventReceived(failed),
);

let uploadState = reducer(undefined, addAttachment(makeAttachment()));
const found = notifState.recentEvents.find(
(e) => e.scope?.id === ATTACHMENT_ID && e.type === 'attachment.failed',
);
expect(found).toBeDefined();

uploadState = reducer(
uploadState,
updateAttachment({ id: CLIENT_ID, updates: { status: 'failed' } }),
);

expect(uploadState.attachments[0].status).toBe('failed');
});
});
@@ -1,12 +1,46 @@
import { createSelector, createSlice, PayloadAction } from '@reduxjs/toolkit';

import {
DismissedEntry,
isDismissed,
loadDismissed,
saveDismissed,
} from '../notifications/dismissedPersistence';
import { sseEventReceived } from '../notifications/notificationsSlice';
import { RootState } from '../store';

const DISMISSED_SOURCE_IDS_STORAGE_KEY = 'docsgpt:dismissedUploadSourceIds';
const DISMISSED_SOURCE_IDS_TTL_MS = 24 * 60 * 60 * 1000;
const DISMISSED_SOURCE_IDS_CAP = 200;

function recordDismissedSourceId(
list: DismissedEntry[],
sourceId: string,
): DismissedEntry[] {
const now = Date.now();
const cutoff = now - DISMISSED_SOURCE_IDS_TTL_MS;
const next = list.filter(
(entry) => entry.at >= cutoff && entry.id !== sourceId,
);
next.push({ id: sourceId, at: now });
return next.length > DISMISSED_SOURCE_IDS_CAP
? next.slice(-DISMISSED_SOURCE_IDS_CAP)
: next;
}

export interface Attachment {
id: string; // Unique identifier for the attachment (required for state management)
id: string; // Client-side state-management id (uuid generated in MessageInput)
fileName: string;
progress: number;
status: 'uploading' | 'processing' | 'completed' | 'failed';
taskId: string; // Server-assigned task ID (used for API calls)
taskId: string; // Server-assigned celery task ID (used for API calls)
/**
 * Server-assigned attachment id (stable across the lifecycle —
 * ``attachment.*`` SSE events use this in ``scope.id``). Set as
 * soon as the upload response returns. Distinct from ``id``
 * (client) and ``taskId`` (celery).
 */
attachmentId?: string;
token_count?: number;
}

@@ -23,18 +57,40 @@ export interface UploadTask {
progress: number;
status: UploadTaskStatus;
taskId?: string;
/**
 * Server-derived source id (uuid5 over the idempotency key) returned by
 * the upload endpoint. Used to correlate inbound SSE ingest events
 * (``source.ingest.*``) back to this task without consulting the
 * polling endpoint.
 */
sourceId?: string;
errorMessage?: string;
dismissed?: boolean;
/**
 * Flipped when ``source.ingest.completed`` carries
 * ``payload.limited === true`` (the worker hit a token cap during
 * ingest). The slice routes such events to a failed status and
 * sets this flag so ``UploadToast`` can surface a translated
 * token-limit message instead of a generic error. Forward-looking:
 * no worker code path sets ``limited: true`` today.
 */
tokenLimitReached?: boolean;
}

interface UploadState {
attachments: Attachment[];
tasks: UploadTask[];
/** Persisted dismissed sourceIds; keeps backlog-replay auto-creates silent. */
dismissedSourceIds: DismissedEntry[];
}

const initialState: UploadState = {
attachments: [],
tasks: [],
dismissedSourceIds: loadDismissed(
DISMISSED_SOURCE_IDS_STORAGE_KEY,
DISMISSED_SOURCE_IDS_TTL_MS,
),
};

export const uploadSlice = createSlice({
@@ -103,9 +159,19 @@ export const uploadSlice = createSlice({
);
if (index !== -1) {
const updates = action.payload.updates;
const existingSourceId = state.tasks[index].sourceId;
const incomingSourceId = updates.sourceId;
const effectiveSourceId = incomingSourceId ?? existingSourceId;
const sourceWasDismissed =
!!effectiveSourceId &&
isDismissed(state.dismissedSourceIds, effectiveSourceId);

// When task completes or fails, set dismissed to false to notify user
if (updates.status === 'completed' || updates.status === 'failed') {
// Re-surface on terminal status, except when the sourceId was
// dismissed in a prior session (don't re-pop after reload).
if (
(updates.status === 'completed' || updates.status === 'failed') &&
!sourceWasDismissed
) {
state.tasks[index] = {
...state.tasks[index],
...updates,
@@ -126,12 +192,184 @@ export const uploadSlice = createSlice({
...state.tasks[index],
dismissed: true,
};
// Persist by sourceId so a reload doesn't re-pop via the SSE
// backlog replay's auto-create branch.
const sourceId = state.tasks[index].sourceId;
if (sourceId) {
state.dismissedSourceIds = recordDismissedSourceId(
state.dismissedSourceIds,
sourceId,
);
saveDismissed(
DISMISSED_SOURCE_IDS_STORAGE_KEY,
state.dismissedSourceIds,
);
}
}
},
removeUploadTask: (state, action: PayloadAction<string>) => {
state.tasks = state.tasks.filter((task) => task.id !== action.payload);
},
},
extraReducers: (builder) => {
// Consume backend SSE ingest events for sub-second progress and
// terminal-status updates. The match is by ``sourceId`` (set by
// the upload endpoint's response). Polling stays as the
// correctness-of-record fallback in Upload.tsx.
builder.addCase(sseEventReceived, (state, action) => {
const e = action.payload;
const scopeId =
typeof e.scope?.id === 'string' && e.scope.id.length > 0
? e.scope.id
: undefined;

// Attachment events flow through the same SSE pipe; route them
// to ``state.attachments`` matched by ``attachmentId``. SSE is
// the sole driver of attachment state transitions — polling
// has been removed. Events for attachments uploaded in another
// session are silently dropped.
if (e.type.startsWith('attachment.') && scopeId) {
const attachment = state.attachments.find(
(a) => a.attachmentId === scopeId,
);
if (attachment) {
const payload = (e.payload || {}) as Record<string, unknown>;
switch (e.type) {
case 'attachment.queued':
case 'attachment.progress': {
if (
attachment.status === 'completed' ||
attachment.status === 'failed'
) {
break;
}
attachment.status = 'processing';
const current = Number(payload.current);
if (Number.isFinite(current)) {
const clamped = Math.max(0, Math.min(100, current));
if (clamped > attachment.progress) {
attachment.progress = clamped;
}
}
break;
}
case 'attachment.completed': {
attachment.status = 'completed';
attachment.progress = 100;
// Replace the client-generated uuid with the server's
// attachment id so question submission
// (Conversation.tsx:174) sends an id the backend can
// resolve. Without this the backend would silently drop
// the attachment from the message context.
attachment.id = scopeId;
const tokenCount = Number(payload.token_count);
if (Number.isFinite(tokenCount)) {
attachment.token_count = tokenCount;
}
break;
}
case 'attachment.failed': {
attachment.status = 'failed';
break;
}
default:
break;
}
}
return;
}

if (!e.type.startsWith('source.ingest.')) return;
if (!scopeId) return;
// ``source.ingest.*`` payloads do not carry a celery task_id
// (see ``application/worker.py`` — only ``source_id`` /
// ``job_name`` / ``filename`` are published), so ``sourceId``
// is the only correlation key.
const payload = (e.payload || {}) as Record<string, unknown>;
let task = state.tasks.find((t) => t.sourceId === scopeId);
if (!task) {
// Auto-create on backlog/live SSE when the local task is gone
// (page refresh mid-ingest). Skip ``completed`` so the historical
// backlog doesn't pop toasts for already-finished sources.
if (e.type === 'source.ingest.completed') return;
const fileName =
(typeof payload.filename === 'string' && payload.filename) ||
(typeof payload.job_name === 'string' && payload.job_name) ||
scopeId;
const previouslyDismissed = isDismissed(
state.dismissedSourceIds,
scopeId,
);
task = {
id: scopeId,
fileName,
progress: 0,
status: 'training',
sourceId: scopeId,
dismissed: previouslyDismissed,
};
state.tasks.push(task);
task = state.tasks[state.tasks.length - 1];
}
const wasTerminal =
task.status === 'completed' || task.status === 'failed';
const sourceWasDismissed =
!!task.sourceId && isDismissed(state.dismissedSourceIds, task.sourceId);

switch (e.type) {
case 'source.ingest.queued':
// Don't regress a task already past 'training' (e.g. the
// queued event arrives after the upload XHR finished and
// status flipped to 'training'). Idempotent on retries.
if (task.status === 'preparing' || task.status === 'uploading') {
task.status = 'training';
task.progress = 0;
}
break;
case 'source.ingest.progress': {
const current = Number(payload.current);
if (!Number.isFinite(current)) break;
// Clamp + monotonic — never regress an already-higher value.
const clamped = Math.max(0, Math.min(100, current));
if (task.status === 'completed' || task.status === 'failed') break;
task.status = 'training';
if (clamped > task.progress) task.progress = clamped;
break;
}
case 'source.ingest.completed':
if (payload.limited === true) {
// Token-cap reached during ingest — surface as a failure
// so the toast shows the translated limit message rather
// than a misleading success state.
task.status = 'failed';
task.progress = 100;
task.tokenLimitReached = true;
task.errorMessage = undefined;
// Only un-dismiss on the initial terminal transition;
// duplicate envelopes (StrictMode remount, reconnect
// overlap) must not re-pop a user-dismissed toast.
if (!wasTerminal && !sourceWasDismissed) task.dismissed = false;
} else {
task.status = 'completed';
task.progress = 100;
task.errorMessage = undefined;
task.tokenLimitReached = false;
if (!wasTerminal && !sourceWasDismissed) task.dismissed = false;
}
break;
case 'source.ingest.failed':
task.status = 'failed';
task.errorMessage =
typeof payload.error === 'string'
? payload.error
: 'Ingestion failed.';
if (!wasTerminal && !sourceWasDismissed) task.dismissed = false;
break;
default:
break;
}
});
},
});

export const {
@@ -12,7 +12,7 @@
"module": "ESNext",
"moduleResolution": "Bundler",
"resolveJsonModule": true,
"types": ["vite-plugin-svgr/client"],
"types": ["vite-plugin-svgr/client", "vitest/globals"],
"isolatedModules": true,
"noEmit": true,
"jsx": "react-jsx",
@@ -1,3 +1,4 @@
/// <reference types="vitest" />
import { defineConfig } from 'vite';
import react from '@vitejs/plugin-react';
import svgr from 'vite-plugin-svgr';
@@ -11,4 +12,10 @@ export default defineConfig({
'@': path.resolve(__dirname, './src'),
},
},
test: {
environment: 'happy-dom',
globals: true,
include: ['src/**/*.test.{ts,tsx}'],
setupFiles: ['./vitest.setup.ts'],
},
});
24  frontend/vitest.setup.ts  Normal file
@@ -0,0 +1,24 @@
// happy-dom's localStorage lives on `window`; slices read the bare
// global at module-load. Install a Map-backed shim on globalThis.
const store = new Map<string, string>();
const shim = {
getItem: (k: string) => (store.has(k) ? store.get(k)! : null),
setItem: (k: string, v: string) => {
store.set(k, String(v));
},
removeItem: (k: string) => {
store.delete(k);
},
clear: () => store.clear(),
key: (i: number) => Array.from(store.keys())[i] ?? null,
get length() {
return store.size;
},
};
// Force-override: some happy-dom versions expose a Storage stub
// without getItem, so simple assignment isn't enough.
Object.defineProperty(globalThis, 'localStorage', {
value: shim,
writable: true,
configurable: true,
});
@@ -492,7 +492,7 @@ class TestMCPOAuthManager:
from application.agents.tools.mcp_tool import MCPOAuthManager

manager = MCPOAuthManager(MagicMock())
result = manager.get_oauth_status("")
result = manager.get_oauth_status("", "alice")
assert result["status"] == "not_started"
@@ -8,6 +8,7 @@ MCPOAuthManager.

import asyncio
import concurrent.futures
import json
from unittest.mock import MagicMock, patch

import pytest
@@ -869,17 +870,110 @@ class TestMCPOAuthManager:
from application.agents.tools.mcp_tool import MCPOAuthManager

manager = MCPOAuthManager(MagicMock())
result = manager.get_oauth_status("")
result = manager.get_oauth_status("", "alice")
assert result["status"] == "not_started"

def test_get_oauth_status_with_task(self):
def test_get_oauth_status_no_user(self):
from application.agents.tools.mcp_tool import MCPOAuthManager

with patch("application.agents.tools.mcp_tool.mcp_oauth_status_task",
return_value={"status": "complete"}):
manager = MCPOAuthManager(MagicMock())
result = manager.get_oauth_status("task123")
assert result["status"] == "complete"
manager = MCPOAuthManager(MagicMock())
result = manager.get_oauth_status("task123", "")
# Without a user id we can't address the per-user SSE journal,
# so the manager refuses rather than scanning every user's
# stream.
assert result["status"] == "not_found"

def test_get_oauth_status_reads_completed_envelope_from_journal(self):
"""The manager walks the user's SSE Streams journal and
surfaces the latest ``mcp.oauth.*`` envelope for the task.

Verifies the full polling-contract surface: ``status`` is
derived from the event-type suffix, and the completed
payload's ``tools`` / ``tools_count`` fields are passed
through unchanged so ``mcp.py``'s ``connect_mcp`` can use them
without further plumbing.
"""
from application.agents.tools.mcp_tool import MCPOAuthManager

completed_envelope = json.dumps(
{
"type": "mcp.oauth.completed",
"scope": {"kind": "mcp_oauth", "id": "task123"},
"payload": {
"task_id": "task123",
"tools": [{"name": "t1", "description": "d"}],
"tools_count": 1,
},
}
).encode("utf-8")
# xrevrange yields newest-first; an unrelated older entry
# should be skipped on the way to the matching one.
unrelated_envelope = json.dumps(
{
"type": "source.ingest.completed",
"scope": {"kind": "source", "id": "abc"},
"payload": {},
}
).encode("utf-8")
mock_redis = MagicMock()
mock_redis.xrevrange.return_value = [
(b"1735682400000-0", {b"event": completed_envelope}),
(b"1735682300000-0", {b"event": unrelated_envelope}),
]

manager = MCPOAuthManager(mock_redis)
result = manager.get_oauth_status("task123", "alice")

assert result["status"] == "completed"
assert result["tools"] == [{"name": "t1", "description": "d"}]
assert result["tools_count"] == 1
# The stream key is scoped per-user — never global.
mock_redis.xrevrange.assert_called_once()
call_args = mock_redis.xrevrange.call_args
assert "user:alice:stream" in call_args.args
# Scan window must cover the full bounded stream so a flood of
# concurrent source-ingest events between popup-completed and
# Save can't push the OAuth envelope out of view.
from application.core.settings import settings

assert call_args.kwargs.get("count") >= settings.EVENTS_STREAM_MAXLEN

def test_get_oauth_status_scan_covers_events_stream_maxlen(self):
"""Regression: count must scale with ``EVENTS_STREAM_MAXLEN``
so the OAuth completion envelope is reachable even after
concurrent source-ingest events flood the user stream."""
from application.agents.tools.mcp_tool import MCPOAuthManager
from application.core.settings import settings

mock_redis = MagicMock()
mock_redis.xrevrange.return_value = []

manager = MCPOAuthManager(mock_redis)
manager.get_oauth_status("task123", "alice")

call_args = mock_redis.xrevrange.call_args
assert call_args.kwargs.get("count") >= settings.EVENTS_STREAM_MAXLEN

def test_get_oauth_status_returns_not_found_when_no_match(self):
from application.agents.tools.mcp_tool import MCPOAuthManager

mock_redis = MagicMock()
# Stream is non-empty but has nothing matching this task.
unrelated_envelope = json.dumps(
{
"type": "mcp.oauth.completed",
"scope": {"kind": "mcp_oauth", "id": "other-task"},
"payload": {"task_id": "other-task"},
}
).encode("utf-8")
mock_redis.xrevrange.return_value = [
(b"1735682400000-0", {b"event": unrelated_envelope}),
]

manager = MCPOAuthManager(mock_redis)
result = manager.get_oauth_status("task123", "alice")

assert result["status"] == "not_found"


# =====================================================================
@@ -1325,6 +1419,7 @@ class TestSetupClientExtended:
tool._client = None
tool.available_tools = []
tool.user_id = "user1"
tool.oauth_redirect_publish = None

mock_client = MagicMock()
with patch.object(MCPTool, "_create_transport", return_value=MagicMock()), \
@@ -2257,16 +2352,21 @@ class TestMCPOAuthManagerExtended:
# Should store error in redis
mock_redis.setex.assert_called()

def test_get_oauth_status_task_error(self):
def test_get_oauth_status_returns_not_found_on_redis_error(self):
"""A failure inside ``xrevrange`` (Redis down, network blip) is
swallowed — the manager returns ``not_found`` so the caller
can present a clean "OAuth failed, try again" message rather
than a 500.
"""
from application.agents.tools.mcp_tool import MCPOAuthManager

with patch(
"application.agents.tools.mcp_tool.mcp_oauth_status_task",
side_effect=Exception("task failed"),
):
manager = MCPOAuthManager(MagicMock())
with pytest.raises(Exception, match="task failed"):
manager.get_oauth_status("task123")
mock_redis = MagicMock()
mock_redis.xrevrange.side_effect = Exception("Redis went away")

manager = MCPOAuthManager(mock_redis)
result = manager.get_oauth_status("task123", "alice")

assert result["status"] == "not_found"


# =====================================================================
@@ -449,6 +449,9 @@ class TestFinalizeMessage:
question="q",
decoded_token={"sub": user},
)
from application.storage.db.repositories.conversations import (
MessageUpdateOutcome,
)
assert svc.finalize_message(
res["message_id"],
"real answer",
@@ -458,7 +461,7 @@
model_id="gpt-4",
metadata={"foo": "bar"},
status="complete",
) is True
) is MessageUpdateOutcome.UPDATED

msgs = ConversationsRepository(pg_conn).get_messages(
res["conversation_id"],
@@ -488,12 +491,15 @@ class TestFinalizeMessage:
decoded_token={"sub": user},
)
err = RuntimeError("provider down")
from application.storage.db.repositories.conversations import (
MessageUpdateOutcome,
)
assert svc.finalize_message(
res["message_id"],
"fallback text",
status="failed",
error=err,
) is True
) is MessageUpdateOutcome.UPDATED

msgs = ConversationsRepository(pg_conn).get_messages(
res["conversation_id"],
@@ -526,9 +532,12 @@ class TestFinalizeMessage:
),
{"cid": "c1", "mid": res["message_id"]},
)
from application.storage.db.repositories.conversations import (
MessageUpdateOutcome,
)
assert svc.finalize_message(
res["message_id"], "ans", status="complete",
) is True
) is MessageUpdateOutcome.UPDATED

status = pg_conn.execute(
sql_text("SELECT status FROM tool_call_attempts WHERE call_id = :cid"),
@@ -536,16 +545,19 @@ class TestFinalizeMessage:
).scalar()
assert status == "confirmed"

def test_finalize_returns_false_for_unknown_message(self, pg_conn):
def test_finalize_returns_not_found_for_unknown_message(self, pg_conn):
from application.api.answer.services.conversation_service import (
ConversationService,
)
from application.storage.db.repositories.conversations import (
MessageUpdateOutcome,
)
with _patch_db(pg_conn):
assert ConversationService().finalize_message(
"00000000-0000-0000-0000-000000000000",
"x",
status="complete",
) is False
) is MessageUpdateOutcome.NOT_FOUND

def test_finalize_rolls_back_tool_call_confirm_on_message_update_failure(
self, pg_conn
@@ -651,6 +663,9 @@ class TestFinalizeMessage:
question="long question that becomes the fallback name",
decoded_token={"sub": user},
)
from application.storage.db.repositories.conversations import (
MessageUpdateOutcome,
)
assert svc.finalize_message(
res["message_id"],
"answer",
@@ -664,7 +679,7 @@ class TestFinalizeMessage:
"long question that becomes the fallback name"[:50]
),
},
) is True
) is MessageUpdateOutcome.UPDATED

repo = ConversationsRepository(pg_conn)
conv = repo.get_any(res["conversation_id"], user)
@@ -343,6 +343,50 @@ class TestProcessResponseStreamExtended:
        result = resource.process_response_stream(iter(stream))
        assert result["thought"] == "thinking..."

    def test_handles_id_prefixed_chunks(self, mock_mongo_db, flask_app):
        """``complete_stream`` emits ``id: <seq>\\n`` before each
        ``data:`` line so reconnects can resume. The non-streaming
        ``/api/answer`` consumer must skip the ``id:`` header (and the
        informational ``message_id`` event) without breaking JSON
        decoding.
        """
        from application.api.answer.routes.base import BaseAnswerResource

        with flask_app.app_context():
            resource = BaseAnswerResource()
            stream = [
                'id: 0\n'
                f'data: {json.dumps({"type": "message_id", "message_id": "abc"})}\n\n',
                'id: 1\n'
                f'data: {json.dumps({"type": "answer", "answer": "Hello "})}\n\n',
                'id: 2\n'
                f'data: {json.dumps({"type": "answer", "answer": "world"})}\n\n',
                'id: 3\n'
                f'data: {json.dumps({"type": "id", "id": "conv-1"})}\n\n',
                'id: 4\n'
                f'data: {json.dumps({"type": "end"})}\n\n',
            ]
            result = resource.process_response_stream(iter(stream))
            assert result["error"] is None
            assert result["answer"] == "Hello world"
            assert result["conversation_id"] == "conv-1"

    def test_skips_keepalive_comment_lines(self, mock_mongo_db, flask_app):
        from application.api.answer.routes.base import BaseAnswerResource

        with flask_app.app_context():
            resource = BaseAnswerResource()
            stream = [
                ': keepalive\n\n',
                'id: 0\n'
                f'data: {json.dumps({"type": "answer", "answer": "ok"})}\n\n',
                'id: 1\n'
                f'data: {json.dumps({"type": "end"})}\n\n',
            ]
            result = resource.process_response_stream(iter(stream))
            assert result["answer"] == "ok"
            assert result["error"] is None

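The two tests above fix the contract for the non-streaming consumer. A rough sketch of a reader with that tolerance (a hypothetical helper, not the actual process_response_stream):

# Hypothetical sketch of the tolerance the two tests above require:
# ignore ``id:`` headers and ``:`` keepalive comments, decode only
# ``data:`` payloads, and accumulate the answer across frames.
import json


def consume_sse_records(chunks):
    answer, conversation_id, error = "", None, None
    for chunk in chunks:
        for line in chunk.split("\n"):
            if not line.startswith("data:"):
                continue  # skips ``id: N`` headers and ``: keepalive`` comments
            event = json.loads(line[len("data:"):].strip())
            if event.get("type") == "answer":
                answer += event.get("answer", "")
            elif event.get("type") == "id":
                conversation_id = event.get("id")
            elif event.get("type") == "error":
                error = event.get("error")
    return {"answer": answer, "conversation_id": conversation_id, "error": error}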


@pytest.mark.unit
class TestCheckUsageStringBooleans:
@@ -593,6 +637,251 @@ class TestCompleteStreamGeneratorExit:
|
||||
next(gen)
|
||||
gen.close() # Should not crash even with save error
|
||||
|
||||
def test_generator_exit_before_any_response_journals_error_not_end(
|
||||
self, mock_mongo_db, flask_app,
|
||||
):
|
||||
"""A client disconnect right after the early ``message_id`` frame
|
||||
leaves ``response_full`` empty, so finalize never runs. The abort
|
||||
handler must journal ``error`` (not ``end``) and flip the row to
|
||||
``failed`` — otherwise a reconnecting client sees a terminal
|
||||
``end`` for a row whose DB status is still non-terminal and the
|
||||
UI parks on a blank successful answer.
|
||||
"""
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
|
||||
# An agent that's primed but never yields anything before the
|
||||
# client disconnects — keeps ``response_full`` empty.
|
||||
def gen_never_yields():
|
||||
if False:
|
||||
yield # pragma: no cover
|
||||
return
|
||||
|
||||
mock_agent.gen.return_value = gen_never_yields()
|
||||
mock_agent.compression_metadata = None
|
||||
mock_agent.compression_saved = False
|
||||
mock_agent.tool_calls = []
|
||||
|
||||
resource.conversation_service = MagicMock()
|
||||
resource.conversation_service.save_user_question.return_value = {
|
||||
"conversation_id": "conv1",
|
||||
"message_id": "msg1",
|
||||
"request_id": "req1",
|
||||
}
|
||||
|
||||
journaled: list[tuple] = []
|
||||
|
||||
def _capture_record(message_id, sequence_no, event_type, payload):
|
||||
journaled.append((message_id, sequence_no, event_type, payload))
|
||||
return True
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.base.record_event",
|
||||
side_effect=_capture_record,
|
||||
):
|
||||
gen = resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id="conv1",
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
# Drain the early ``message_id`` event, then close before
|
||||
# the agent yields anything.
|
||||
next(gen)
|
||||
gen.close()
|
||||
|
||||
# The early message_id frame got journaled (seq 0); the abort
|
||||
# handler must follow with an ``error`` event (NOT ``end``).
|
||||
terminal_events = [
|
||||
(et, pl) for (_, _, et, pl) in journaled if et in ("end", "error")
|
||||
]
|
||||
assert len(terminal_events) == 1, (
|
||||
f"expected exactly one terminal journal write, got {terminal_events}"
|
||||
)
|
||||
assert terminal_events[0][0] == "error", (
|
||||
f"expected ``error`` terminal but got ``end``: {terminal_events}"
|
||||
)
|
||||
payload = terminal_events[0][1]
|
||||
assert payload.get("type") == "error"
|
||||
assert payload.get("code") == "client_disconnect"
|
||||
|
||||
# And the DB row should have been flipped to ``failed`` via
|
||||
# finalize_message. The mocked service records the call.
|
||||
finalize_calls = (
|
||||
resource.conversation_service.finalize_message.call_args_list
|
||||
)
|
||||
assert len(finalize_calls) == 1
|
||||
assert finalize_calls[0].kwargs.get("status") == "failed"
|
||||
|
||||
def test_generator_exit_after_response_still_journals_end(
|
||||
self, mock_mongo_db, flask_app,
|
||||
):
|
||||
"""Regression guard: a disconnect AFTER partial response was
|
||||
produced and ``finalize_message`` succeeded must still journal
|
||||
``end`` (the row matches ``complete``). Only the empty-response
|
||||
branch flips to ``error``.
|
||||
"""
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.storage.db.repositories.conversations import (
|
||||
MessageUpdateOutcome,
|
||||
)
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
|
||||
def gen_answers():
|
||||
yield {"answer": "partial"}
|
||||
|
||||
mock_agent.gen.return_value = gen_answers()
|
||||
mock_agent.compression_metadata = None
|
||||
mock_agent.compression_saved = False
|
||||
mock_agent.tool_calls = []
|
||||
|
||||
resource.conversation_service = MagicMock()
|
||||
resource.conversation_service.finalize_message.return_value = (
|
||||
MessageUpdateOutcome.UPDATED
|
||||
)
|
||||
resource.conversation_service.save_user_question.return_value = {
|
||||
"conversation_id": "conv1",
|
||||
"message_id": "msg1",
|
||||
"request_id": "req1",
|
||||
}
|
||||
|
||||
journaled: list[tuple] = []
|
||||
|
||||
def _capture_record(message_id, sequence_no, event_type, payload):
|
||||
journaled.append((message_id, sequence_no, event_type, payload))
|
||||
return True
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.base.record_event",
|
||||
side_effect=_capture_record,
|
||||
):
|
||||
gen = resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id="conv1",
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
next(gen) # message_id frame
|
||||
next(gen) # answer frame (consumes ``partial``)
|
||||
gen.close()
|
||||
|
||||
terminal_events = [
|
||||
(et, pl) for (_, _, et, pl) in journaled if et in ("end", "error")
|
||||
]
|
||||
assert terminal_events and terminal_events[0][0] == "end", (
|
||||
f"finalize succeeded but abort journaled {terminal_events}"
|
||||
)
|
||||
|
||||
def test_generator_exit_after_normal_finalize_already_complete_journals_end(
|
||||
self, mock_mongo_db, flask_app,
|
||||
):
|
||||
"""Regression for the race where the normal-path finalize wins
|
||||
against a client disconnect.
|
||||
|
||||
Trace: agent finishes, ``complete_stream`` runs the normal-path
|
||||
``finalize_message`` at base.py:632 and flips the row to
|
||||
``complete``. The client TCP-resets before the ``end`` frame can
|
||||
be journaled. The GeneratorExit handler calls ``finalize_message``
|
||||
again — and the repository, gated by ``only_if_non_terminal``,
|
||||
reports ``ALREADY_COMPLETE`` because the row is already at the
|
||||
target state. The abort handler must journal ``end``, not
|
||||
``error``: the DB says ``complete``, the reconnecting client
|
||||
must see the same.
|
||||
|
||||
Without the fix the abort handler treated ``ALREADY_COMPLETE``
|
||||
as a failure and journaled ``error`` — a reconnect would then
|
||||
replay a successful completion as a failed answer.
|
||||
"""
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.storage.db.repositories.conversations import (
|
||||
MessageUpdateOutcome,
|
||||
)
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
|
||||
def gen_answers():
|
||||
yield {"answer": "partial"}
|
||||
|
||||
mock_agent.gen.return_value = gen_answers()
|
||||
mock_agent.compression_metadata = None
|
||||
mock_agent.compression_saved = False
|
||||
mock_agent.tool_calls = []
|
||||
|
||||
resource.conversation_service = MagicMock()
|
||||
# Normal-path call returns UPDATED; abort-handler call sees
|
||||
# the row already terminal and returns ALREADY_COMPLETE.
|
||||
resource.conversation_service.finalize_message.side_effect = [
|
||||
MessageUpdateOutcome.UPDATED,
|
||||
MessageUpdateOutcome.ALREADY_COMPLETE,
|
||||
]
|
||||
resource.conversation_service.save_user_question.return_value = {
|
||||
"conversation_id": "conv1",
|
||||
"message_id": "msg1",
|
||||
"request_id": "req1",
|
||||
}
|
||||
|
||||
journaled: list[tuple] = []
|
||||
|
||||
def _capture_record(message_id, sequence_no, event_type, payload):
|
||||
journaled.append((message_id, sequence_no, event_type, payload))
|
||||
return True
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.base.record_event",
|
||||
side_effect=_capture_record,
|
||||
):
|
||||
gen = resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id="conv1",
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
next(gen) # message_id frame
|
||||
next(gen) # answer frame
|
||||
# Pull the ``id`` frame so the normal-path finalize at
|
||||
# line 632 actually runs (it sits between the answer
|
||||
# frame yield and the id frame yield).
|
||||
next(gen) # id frame — runs normal-path finalize first
|
||||
gen.close() # GeneratorExit at the id frame yield
|
||||
|
||||
terminal_events = [
|
||||
(et, pl) for (_, _, et, pl) in journaled if et in ("end", "error")
|
||||
]
|
||||
assert len(terminal_events) == 1, (
|
||||
f"expected exactly one terminal journal write, got "
|
||||
f"{terminal_events}"
|
||||
)
|
||||
assert terminal_events[0][0] == "end", (
|
||||
f"row was already ``complete`` but abort journaled "
|
||||
f"{terminal_events[0]} — reconnect would surface a "
|
||||
f"successful answer as failed"
|
||||
)
|
||||
|
||||
# Both finalize_message calls were made: the normal-path
|
||||
# one and the abort-handler one. Asserting on side_effect
|
||||
# consumption ensures the test really exercised both
|
||||
# branches.
|
||||
assert (
|
||||
resource.conversation_service.finalize_message.call_count == 2
|
||||
)
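
Taken together, the three disconnect tests pin down a small decision table for the GeneratorExit handler. A sketch of that policy follows; the names are illustrative stand-ins, not the real complete_stream code.

# Sketch only; the callables are stand-ins for the conversation service
# and the journal writer exercised by the tests above.
def handle_client_disconnect(response_full, finalize_message, journal_event):
    if not response_full:
        # Nothing was produced: flip the row to ``failed`` and journal a
        # terminal ``error`` so a reconnect never replays a blank ``end``.
        finalize_message(status="failed")
        journal_event("error", {"type": "error", "code": "client_disconnect"})
        return
    outcome = finalize_message(status="complete")
    if outcome in ("UPDATED", "ALREADY_COMPLETE"):
        # Either way the DB row reads ``complete``; a reconnecting client
        # must see the matching terminal ``end``.
        journal_event("end", {"type": "end"})
    # Other outcomes (e.g. NOT_FOUND) are outside what these tests pin down.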
|
||||
|
||||
|
||||
@contextmanager
def _patch_db_session(conn):
@@ -606,10 +895,34 @@ def _patch_db_session(conn):
    ), patch(
        "application.api.answer.services.conversation_service.db_readonly",
        _yield,
    ), patch(
        # ``record_event`` opens its own short-lived ``db_session`` for
        # cross-connection visibility. In tests we route it back to the
        # same ``pg_conn`` so the journal write can see the message row
        # the conversation_service just wrote in this transaction.
        "application.streaming.message_journal.db_session",
        _yield,
    ), patch(
        # ``complete_stream`` reads ``latest_sequence_no`` via
        # ``db_readonly`` to seed continuation runs. Same patch reason
        # as the journal — keep the read on the same pg_conn so it sees
        # uncommitted writes from this transaction.
        "application.api.answer.routes.base.db_readonly",
        _yield,
    ):
        yield


def _extract_sse_data(chunk: str) -> str:
    """Pull the ``data:`` payload from an SSE record, ignoring any
    ``id:`` header introduced by the journal wiring.
    """
    for line in chunk.split("\n"):
        if line.startswith("data:"):
            return line[len("data:") :].lstrip()
    return ""


@pytest.mark.unit
class TestCompleteStreamWalAcceptance:
    """Acceptance for the WAL pre-persist behaviour: when the LLM raises
@@ -659,6 +972,158 @@ class TestCompleteStreamWalAcceptance:
|
||||
assert "RuntimeError" in msgs[0]["metadata"]["error"]
|
||||
assert "LLM upstream failed" in msgs[0]["metadata"]["error"]
|
||||
|
||||
def test_tool_approval_event_only_fires_when_state_saved(
|
||||
self, pg_conn, flask_app,
|
||||
):
|
||||
"""A `tool.approval.required` notification with no resumable
|
||||
``pending_tool_state`` row would deep-link the user to a 404.
|
||||
Gate the publish on save_state actually committing.
|
||||
"""
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
# Drive the agent into the paused branch with a single
|
||||
# ``tool_calls_pending`` event.
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{
|
||||
"type": "tool_calls_pending",
|
||||
"data": {"pending_tool_calls": [{"call_id": "c1"}]},
|
||||
}
|
||||
]
|
||||
)
|
||||
mock_agent._pending_continuation = {
|
||||
"messages": [],
|
||||
"tools_dict": {},
|
||||
"pending_tool_calls": [{"call_id": "c1"}],
|
||||
}
|
||||
mock_agent.tool_calls = []
|
||||
mock_agent.compression_metadata = None
|
||||
mock_agent.compression_saved = False
|
||||
|
||||
published: list = []
|
||||
|
||||
def _capture(*args, **kwargs):
|
||||
published.append(args)
|
||||
|
||||
with _patch_db_session(pg_conn), patch(
|
||||
"application.api.answer.routes.base.publish_user_event",
|
||||
side_effect=_capture,
|
||||
), patch(
|
||||
"application.api.answer.services.continuation_service."
|
||||
"ContinuationService.save_state",
|
||||
side_effect=RuntimeError("PG outage"),
|
||||
):
|
||||
list(
|
||||
resource.complete_stream(
|
||||
question="run my tool",
|
||||
agent=mock_agent,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u-tap"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
)
|
||||
|
||||
# No tool.approval.required publish when save_state failed.
|
||||
event_types = [a[1] for a in published if len(a) >= 2]
|
||||
assert "tool.approval.required" not in event_types
|
||||
|
||||
def test_continuation_seeds_sequence_no_from_journal_high_water_mark(
|
||||
self, pg_conn, flask_app,
|
||||
):
|
||||
"""A resumed (tool-actions) stream must continue numbering past
|
||||
the original run's max ``sequence_no``. Otherwise the second
|
||||
invocation collides on the duplicate-PK and silently drops
|
||||
every journal write past the resume point.
|
||||
|
||||
We pre-seed the journal with synthetic rows simulating an
|
||||
original run, then invoke ``complete_stream`` with
|
||||
``_continuation`` set and assert seq numbering picks up past
|
||||
the high-water mark.
|
||||
"""
|
||||
import uuid as _uuid
|
||||
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.storage.db.repositories.message_events import (
|
||||
MessageEventsRepository,
|
||||
)
|
||||
|
||||
# Pre-seed: insert a parent conversation + message + a few
|
||||
# journal rows so latest_sequence_no returns 7.
|
||||
user_id = "u-resume"
|
||||
conv_id = _uuid.uuid4()
|
||||
message_id = _uuid.uuid4()
|
||||
pg_conn.execute(
|
||||
sql_text("INSERT INTO users (user_id) VALUES (:u)"),
|
||||
{"u": user_id},
|
||||
)
|
||||
pg_conn.execute(
|
||||
sql_text(
|
||||
"INSERT INTO conversations (id, user_id, name) "
|
||||
"VALUES (:id, :u, 'pre-seed')"
|
||||
),
|
||||
{"id": conv_id, "u": user_id},
|
||||
)
|
||||
pg_conn.execute(
|
||||
sql_text(
|
||||
"INSERT INTO conversation_messages "
|
||||
"(id, conversation_id, user_id, position, prompt) "
|
||||
"VALUES (:id, :c, :u, 0, 'q')"
|
||||
),
|
||||
{"id": message_id, "c": conv_id, "u": user_id},
|
||||
)
|
||||
repo = MessageEventsRepository(pg_conn)
|
||||
for seq in range(8): # rows 0..7
|
||||
repo.record(str(message_id), seq, "answer", {"chunk": str(seq)})
|
||||
original_max = repo.latest_sequence_no(str(message_id))
|
||||
assert original_max == 7
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
cont_agent = MagicMock()
|
||||
cont_agent.gen_continuation.return_value = iter(
|
||||
[{"answer": "d"}, {"answer": "e"}]
|
||||
)
|
||||
cont_agent.tool_calls = []
|
||||
cont_agent.compression_metadata = None
|
||||
cont_agent.compression_saved = False
|
||||
cont_agent.tool_executor = None
|
||||
|
||||
with _patch_db_session(pg_conn):
|
||||
list(
|
||||
resource.complete_stream(
|
||||
question="",
|
||||
agent=cont_agent,
|
||||
conversation_id=str(conv_id),
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": user_id},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
_continuation={
|
||||
"messages": [],
|
||||
"tools_dict": {},
|
||||
"pending_tool_calls": [],
|
||||
"tool_actions": [],
|
||||
"reserved_message_id": str(message_id),
|
||||
"request_id": "req-resume",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
new_max = MessageEventsRepository(pg_conn).latest_sequence_no(
|
||||
str(message_id)
|
||||
)
|
||||
# Continuation extended the journal — the new high-water mark
|
||||
# is strictly greater than the seeded ``original_max=7``,
|
||||
# confirming the allocator picked up past the resume point.
|
||||
assert new_max is not None and new_max > original_max
|
||||
|
||||
def test_request_id_consistent_across_sse_event_and_wal_row(
|
||||
self, pg_conn, flask_app,
|
||||
):
|
||||
@@ -694,9 +1159,9 @@ class TestCompleteStreamWalAcceptance:
         )
 
         sse_events = [
-            json.loads(s.replace("data: ", "").strip())
+            json.loads(_extract_sse_data(s))
             for s in stream
-            if s.startswith("data: ")
+            if "data:" in s
         ]
         early_events = [e for e in sse_events if e.get("type") == "message_id"]
         assert len(early_events) == 1
@@ -712,3 +1177,102 @@ class TestCompleteStreamWalAcceptance:
|
||||
msgs = ConversationsRepository(pg_conn).get_messages(str(convs[0][0]))
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0]["request_id"] == sse_request_id
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamingHeartbeatSeed:
|
||||
"""Regression guard: when the row flips to ``streaming`` we must seed
|
||||
``last_heartbeat_at`` so the watchdog doesn't fall back to
|
||||
``timestamp`` (creation time) on slow LLM cold-starts (>idle_secs).
|
||||
"""
|
||||
|
||||
def test_heartbeat_seeded_on_first_chunk(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[{"answer": "a"}, {"answer": "b"}]
|
||||
)
|
||||
mock_agent.compression_metadata = None
|
||||
mock_agent.compression_saved = False
|
||||
mock_agent.tool_calls = []
|
||||
mock_agent.tool_executor = None
|
||||
|
||||
resource.conversation_service = MagicMock()
|
||||
resource.conversation_service.save_conversation.return_value = "conv1"
|
||||
resource.conversation_service.save_user_question.return_value = {
|
||||
"conversation_id": "conv1",
|
||||
"message_id": "msg1",
|
||||
"request_id": "req1",
|
||||
}
|
||||
|
||||
list(
|
||||
resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id="conv1",
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
)
|
||||
|
||||
# update_message_status flips the row to ``streaming`` exactly
|
||||
# once (idempotent via streaming_marked).
|
||||
status_calls = [
|
||||
c for c in resource.conversation_service
|
||||
.update_message_status.call_args_list
|
||||
if c.args[1] == "streaming"
|
||||
]
|
||||
assert len(status_calls) == 1
|
||||
assert status_calls[0].args[0] == "msg1"
|
||||
|
||||
# heartbeat seed runs exactly once at the same flip — multiple
|
||||
# chunks don't re-stamp inside this window (the throttled
|
||||
# _heartbeat_streaming path is gated by STREAM_HEARTBEAT_INTERVAL).
|
||||
hb_calls = (
|
||||
resource.conversation_service.heartbeat_message.call_args_list
|
||||
)
|
||||
assert len(hb_calls) == 1
|
||||
assert hb_calls[0].args[0] == "msg1"
|
||||
|
||||
def test_heartbeat_seed_skipped_without_reserved_message_id(
|
||||
self, mock_mongo_db, flask_app,
|
||||
):
|
||||
"""No DB-backed message row → no heartbeat call (and no error)."""
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter([{"answer": "a"}])
|
||||
mock_agent.compression_metadata = None
|
||||
mock_agent.compression_saved = False
|
||||
mock_agent.tool_calls = []
|
||||
mock_agent.tool_executor = None
|
||||
|
||||
resource.conversation_service = MagicMock()
|
||||
resource.conversation_service.save_conversation.return_value = "conv1"
|
||||
# save_user_question returns no message_id → reservation absent
|
||||
resource.conversation_service.save_user_question.return_value = {
|
||||
"conversation_id": "conv1",
|
||||
"message_id": None,
|
||||
"request_id": "req1",
|
||||
}
|
||||
|
||||
list(
|
||||
resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id="conv1",
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
)
|
||||
|
||||
resource.conversation_service.heartbeat_message.assert_not_called()
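
A compact sketch of the first-chunk gate both heartbeat tests describe; the shape is assumed from the asserts above, the real code sits in complete_stream's chunk loop.

# Sketch: flip the row to ``streaming`` and seed the heartbeat exactly once,
# and only when a DB-backed message row was actually reserved.
def make_first_chunk_hook(conversation_service, reserved_message_id):
    state = {"streaming_marked": False}

    def on_chunk():
        if state["streaming_marked"] or not reserved_message_id:
            return
        conversation_service.update_message_status(reserved_message_id, "streaming")
        conversation_service.heartbeat_message(reserved_message_id)
        state["streaming_marked"] = True

    return on_chunk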
|
||||
|
||||
233
tests/api/answer/test_snapshot_tail_integration.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""Integration tests for the end-to-end snapshot+tail handoff.
|
||||
|
||||
Exercises the publisher → journal → reconnect endpoint round-trip
|
||||
without mocking the journal layer, so a regression in any of:
|
||||
- complete_stream's _emit closure
|
||||
- record_event's commit-per-call contract
|
||||
- build_message_event_stream's snapshot-from-DB path
|
||||
- the reconnect route's auth + ownership gates
|
||||
- message_events repo SQL
|
||||
|
||||
would surface here as a failed integration assertion.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid as _uuid
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_journal_session(conn):
|
||||
@contextmanager
|
||||
def _yield():
|
||||
yield conn
|
||||
|
||||
with patch(
|
||||
"application.streaming.message_journal.db_session", _yield
|
||||
), patch(
|
||||
"application.streaming.event_replay.db_readonly", _yield
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def _seed_message(conn, user_id: str | None = None):
|
||||
user_id = user_id or f"u-{_uuid.uuid4().hex[:6]}"
|
||||
conv_id = _uuid.uuid4()
|
||||
msg_id = _uuid.uuid4()
|
||||
conn.execute(sql_text("INSERT INTO users (user_id) VALUES (:u)"), {"u": user_id})
|
||||
conn.execute(
|
||||
sql_text(
|
||||
"INSERT INTO conversations (id, user_id, name) VALUES (:id, :u, 't')"
|
||||
),
|
||||
{"id": conv_id, "u": user_id},
|
||||
)
|
||||
conn.execute(
|
||||
sql_text(
|
||||
"INSERT INTO conversation_messages (id, conversation_id, user_id, position) "
|
||||
"VALUES (:id, :c, :u, 0)"
|
||||
),
|
||||
{"id": msg_id, "c": conv_id, "u": user_id},
|
||||
)
|
||||
return user_id, str(msg_id)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestSnapshotPlusTailRoundTrip:
|
||||
def test_record_event_then_snapshot_returns_what_was_journaled(
|
||||
self, pg_conn,
|
||||
):
|
||||
"""End-to-end of the journal half: ``record_event`` writes through
|
||||
a real ``MessageEventsRepository``, ``build_message_event_stream``
|
||||
reads the snapshot back via the same repo + a stub Topic.
|
||||
"""
|
||||
from application.streaming import event_replay
|
||||
from application.streaming.message_journal import record_event
|
||||
|
||||
_, message_id = _seed_message(pg_conn)
|
||||
|
||||
with _patch_journal_session(pg_conn):
|
||||
# Three events stamp the journal.
|
||||
record_event(message_id, 0, "answer", {"type": "answer", "answer": "A"})
|
||||
record_event(message_id, 1, "answer", {"type": "answer", "answer": "B"})
|
||||
record_event(message_id, 2, "end", {"type": "end"})
|
||||
|
||||
# Reconnect path: subscribe yields nothing (Redis-down
|
||||
# branch); the post-loop fallback runs the snapshot read
|
||||
# synchronously and yields the journal contents.
|
||||
def _empty_subscribe(self, on_subscribe=None, poll_timeout=1.0):
|
||||
return
|
||||
yield # pragma: no cover
|
||||
|
||||
with patch.object(
|
||||
event_replay.Topic,
|
||||
"subscribe",
|
||||
_empty_subscribe,
|
||||
create=False,
|
||||
):
|
||||
gen = event_replay.build_message_event_stream(
|
||||
message_id,
|
||||
last_event_id=None,
|
||||
keepalive_seconds=0.05,
|
||||
poll_timeout_seconds=0.01,
|
||||
)
|
||||
out = list(gen)
|
||||
|
||||
# Prelude + 3 snapshot frames.
|
||||
assert out[0] == ": connected\n\n"
|
||||
assert "id: 0" in out[1] and '"answer": "A"' in out[1]
|
||||
assert "id: 1" in out[2] and '"answer": "B"' in out[2]
|
||||
assert "id: 2" in out[3] and '"type": "end"' in out[3]
|
||||
|
||||
def test_snapshot_resumes_past_last_event_id(self, pg_conn):
|
||||
from application.streaming import event_replay
|
||||
from application.streaming.message_journal import record_event
|
||||
|
||||
_, message_id = _seed_message(pg_conn)
|
||||
|
||||
with _patch_journal_session(pg_conn):
|
||||
for seq in range(5):
|
||||
record_event(
|
||||
message_id, seq, "answer", {"type": "answer", "answer": str(seq)}
|
||||
)
|
||||
|
||||
def _empty_subscribe(self, on_subscribe=None, poll_timeout=1.0):
|
||||
return
|
||||
yield # pragma: no cover
|
||||
|
||||
with patch.object(
|
||||
event_replay.Topic,
|
||||
"subscribe",
|
||||
_empty_subscribe,
|
||||
create=False,
|
||||
):
|
||||
# Client says it has seen up through seq=2; expect 3 + 4.
|
||||
out = list(
|
||||
event_replay.build_message_event_stream(
|
||||
message_id,
|
||||
last_event_id=2,
|
||||
keepalive_seconds=0.05,
|
||||
poll_timeout_seconds=0.01,
|
||||
)
|
||||
)
|
||||
|
||||
ids_seen = [line for line in out if line.startswith("id: ")]
|
||||
# Multi-line records: extract the id integers we delivered.
|
||||
emitted = sorted(
|
||||
int(line.split(": ", 1)[1].split("\n")[0])
|
||||
for line in out
|
||||
if line.startswith("id: ")
|
||||
)
|
||||
# Filter for non-negative (the snapshot-failure synthetic uses -1).
|
||||
emitted = [e for e in emitted if e >= 0]
|
||||
assert emitted == [3, 4]
|
||||
assert ids_seen # sanity
|
||||
|
||||
def test_reconnect_route_round_trip(self, pg_conn, flask_app):
|
||||
"""``/api/messages/<id>/events`` returns the journaled events
|
||||
for an authenticated owner.
|
||||
"""
|
||||
from flask import Flask, request
|
||||
|
||||
from application.api.answer.routes.messages import messages_bp
|
||||
from application.streaming.message_journal import record_event
|
||||
|
||||
# Build a fresh Flask app routing to the reconnect blueprint
|
||||
# plus a tiny auth shim that injects the test user.
|
||||
user_id, message_id = _seed_message(pg_conn)
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(messages_bp)
|
||||
app.config["TESTING"] = True
|
||||
|
||||
@app.before_request
|
||||
def _shim_auth():
|
||||
request.decoded_token = {"sub": user_id}
|
||||
|
||||
with _patch_journal_session(pg_conn):
|
||||
record_event(message_id, 0, "answer", {"type": "answer", "answer": "x"})
|
||||
record_event(message_id, 1, "end", {"type": "end"})
|
||||
|
||||
from application.streaming import event_replay
|
||||
|
||||
def _empty_subscribe(self, on_subscribe=None, poll_timeout=1.0):
|
||||
return
|
||||
yield # pragma: no cover
|
||||
|
||||
with patch.object(
|
||||
event_replay.Topic,
|
||||
"subscribe",
|
||||
_empty_subscribe,
|
||||
create=False,
|
||||
), patch(
|
||||
"application.api.answer.routes.messages.db_readonly"
|
||||
) as ro:
|
||||
ro.return_value.__enter__.return_value = pg_conn
|
||||
|
||||
with app.test_client() as c:
|
||||
r = c.get(f"/api/messages/{message_id}/events")
|
||||
assert r.status_code == 200
|
||||
body = b""
|
||||
for chunk in r.iter_encoded():
|
||||
body += chunk
|
||||
if body.count(b"\n\n") >= 4:
|
||||
break
|
||||
r.close()
|
||||
# Both journaled events present in the response.
|
||||
text = body.decode("utf-8")
|
||||
assert ": connected" in text
|
||||
assert '"answer": "x"' in text
|
||||
assert '"type": "end"' in text
|
||||
# The seq lines are correct.
|
||||
assert "id: 0" in text and "id: 1" in text
|
||||
|
||||
def test_reconnect_rejects_non_owner(self, pg_conn, flask_app):
|
||||
from flask import Flask, request
|
||||
|
||||
from application.api.answer.routes.messages import messages_bp
|
||||
|
||||
user_id, message_id = _seed_message(pg_conn)
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(messages_bp)
|
||||
|
||||
@app.before_request
|
||||
def _shim_auth():
|
||||
request.decoded_token = {"sub": "different-user"}
|
||||
|
||||
# Make the ownership check use the test connection.
|
||||
with patch(
|
||||
"application.api.answer.routes.messages.db_readonly"
|
||||
) as ro:
|
||||
from contextlib import contextmanager as _cm
|
||||
|
||||
@_cm
|
||||
def _yield():
|
||||
yield pg_conn
|
||||
|
||||
ro.side_effect = lambda: _yield()
|
||||
with app.test_client() as c:
|
||||
r = c.get(f"/api/messages/{message_id}/events")
|
||||
assert r.status_code == 404
|
||||
649
tests/api/test_events_routes.py
Normal file
@@ -0,0 +1,649 @@
|
||||
"""Tests for application/api/events/routes.py — the SSE endpoint.
|
||||
|
||||
The SSE generator runs in a separate thread under the WSGI test client;
|
||||
we drive it with mocked Redis (the ``pubsub.get_message`` and ``xrange``
|
||||
sequences) and read the response body until we have enough records to
|
||||
assert on, then close the response to terminate the generator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import threading
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request
|
||||
|
||||
|
||||
def _make_app():
|
||||
"""Mount the events blueprint on a bare Flask app + JWT shim.
|
||||
|
||||
The shim mimics ``application/app.py`` populating
|
||||
``request.decoded_token`` so the SSE handler's auth gate sees a
|
||||
user-id without requiring the full app stack.
|
||||
"""
|
||||
from application.api.events.routes import events
|
||||
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(events)
|
||||
app.config["TESTING"] = True
|
||||
|
||||
@app.before_request
|
||||
def _shim_auth(): # noqa: D401
|
||||
header = request.headers.get("X-Test-Sub")
|
||||
request.decoded_token = {"sub": header} if header else None
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class _FakePubSub:
|
||||
"""Minimal Redis pub/sub stand-in for the SSE handler.
|
||||
|
||||
``messages`` is a list of message dicts the generator should see in
|
||||
order. After exhausting it, ``get_message`` returns ``None`` (poll
|
||||
timeout) so the generator stays alive emitting keepalives until the
|
||||
test closes the response.
|
||||
"""
|
||||
|
||||
def __init__(self, messages: list[dict[str, Any]]):
|
||||
self._messages = list(messages)
|
||||
self.subscribed: list[str] = []
|
||||
self.unsubscribed: list[str] = []
|
||||
self.closed = False
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def subscribe(self, name: str):
|
||||
self.subscribed.append(name)
|
||||
|
||||
def unsubscribe(self, name: str):
|
||||
self.unsubscribed.append(name)
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
def get_message(self, timeout: float = 0):
|
||||
with self._lock:
|
||||
if self._messages:
|
||||
return self._messages.pop(0)
|
||||
return None
|
||||
|
||||
|
||||
def _drain_until(response, predicate, max_chunks: int = 200) -> bytes:
|
||||
"""Consume the streamed response until ``predicate(buf)`` is true.
|
||||
|
||||
Returns the accumulated bytes. Closes the response so the generator
|
||||
exits cleanly via GeneratorExit.
|
||||
"""
|
||||
buf = b""
|
||||
iterator = response.iter_encoded()
|
||||
for _ in range(max_chunks):
|
||||
try:
|
||||
chunk = next(iterator)
|
||||
except StopIteration:
|
||||
break
|
||||
if not chunk:
|
||||
continue
|
||||
buf += chunk
|
||||
if predicate(buf):
|
||||
break
|
||||
response.close()
|
||||
return buf
|
||||
|
||||
|
||||
# ── auth gate ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAuthGate:
|
||||
def test_rejects_when_no_decoded_token(self):
|
||||
app = _make_app()
|
||||
with app.test_client() as c:
|
||||
r = c.get("/api/events")
|
||||
assert r.status_code == 401
|
||||
|
||||
def test_rejects_when_decoded_token_missing_sub(self):
|
||||
from application.api.events import routes as events_module
|
||||
|
||||
app = _make_app()
|
||||
|
||||
# Clear the shim's behavior — supply a decoded_token without sub.
|
||||
@app.before_request
|
||||
def _override():
|
||||
request.decoded_token = {"email": "x@y.z"}
|
||||
|
||||
with patch.object(events_module, "get_redis_instance", return_value=None):
|
||||
with app.test_client() as c:
|
||||
r = c.get("/api/events")
|
||||
assert r.status_code == 401
|
||||
|
||||
|
||||
# ── streaming response shape ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStreamShape:
|
||||
def test_returns_event_stream_mimetype_and_no_buffering_header(self):
|
||||
from application.api.events import routes as events_module
|
||||
|
||||
app = _make_app()
|
||||
with patch.object(events_module, "get_redis_instance", return_value=None):
|
||||
with app.test_client() as c:
|
||||
r = c.get("/api/events", headers={"X-Test-Sub": "alice"})
|
||||
assert r.status_code == 200
|
||||
assert r.mimetype == "text/event-stream"
|
||||
assert r.headers.get("Cache-Control") == "no-store"
|
||||
assert r.headers.get("X-Accel-Buffering") == "no"
|
||||
# Drain enough to see the prelude comment then close.
|
||||
body = _drain_until(r, lambda b: b": connected" in b)
|
||||
assert b": connected" in body
|
||||
|
||||
def test_emits_push_disabled_when_setting_off(self):
|
||||
from application.api.events import routes as events_module
|
||||
|
||||
app = _make_app()
|
||||
with patch.object(events_module, "get_redis_instance", return_value=None), \
|
||||
patch.object(events_module.settings, "ENABLE_SSE_PUSH", False):
|
||||
with app.test_client() as c:
|
||||
r = c.get("/api/events", headers={"X-Test-Sub": "alice"})
|
||||
body = _drain_until(r, lambda b: b": push_disabled" in b)
|
||||
assert b": push_disabled" in body
|
||||
assert b": connected" in body # prelude still emitted
|
||||
|
||||
|
||||
# ── concurrency cap ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestConcurrencyCap:
|
||||
def test_returns_429_when_user_over_cap(self):
|
||||
from application.api.events import routes as events_module
|
||||
|
||||
app = _make_app()
|
||||
redis_client = MagicMock()
|
||||
# First INCR returns 9 (over cap of 8).
|
||||
redis_client.incr.return_value = 9
|
||||
|
||||
with patch.object(events_module, "get_redis_instance", return_value=redis_client), \
|
||||
patch.object(events_module.settings, "SSE_MAX_CONCURRENT_PER_USER", 8):
|
||||
with app.test_client() as c:
|
||||
r = c.get("/api/events", headers={"X-Test-Sub": "alice"})
|
||||
assert r.status_code == 429
|
||||
# DECR fired to release the over-cap increment.
|
||||
redis_client.decr.assert_called_once_with("user:alice:sse_count")
|
||||
|
||||
def test_skips_cap_when_zero_disabled(self):
|
||||
from application.api.events import routes as events_module
|
||||
|
||||
app = _make_app()
|
||||
redis_client = MagicMock()
|
||||
|
||||
with patch.object(events_module, "get_redis_instance", return_value=redis_client), \
|
||||
patch.object(events_module.settings, "SSE_MAX_CONCURRENT_PER_USER", 0), \
|
||||
patch.object(events_module, "Topic") as mock_topic_cls:
|
||||
mock_topic = MagicMock()
|
||||
mock_topic.subscribe.return_value = iter([])
|
||||
mock_topic_cls.return_value = mock_topic
|
||||
redis_client.xinfo_stream.side_effect = Exception("no stream")
|
||||
redis_client.xrange.return_value = []
|
||||
with app.test_client() as c:
|
||||
r = c.get("/api/events", headers={"X-Test-Sub": "alice"})
|
||||
assert r.status_code == 200
|
||||
# Concurrency counter not touched when cap is 0. The
|
||||
# replay-budget INCR is unrelated and may still fire.
|
||||
incr_keys = [
|
||||
call.args[0] for call in redis_client.incr.call_args_list
|
||||
]
|
||||
assert "user:alice:sse_count" not in incr_keys
|
||||
_drain_until(r, lambda b: b": connected" in b)
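
A minimal sketch of the per-user cap these two tests exercise; the key name comes from the asserts, the rest is assumed.

# Sketch: INCR the per-user counter, refuse when over the cap, and DECR so a
# refused request does not hold a slot. A cap of 0 disables the check.
def acquire_sse_slot(redis_client, user_id: str, max_per_user: int) -> bool:
    if max_per_user <= 0:
        return True
    key = f"user:{user_id}:sse_count"
    if redis_client.incr(key) > max_per_user:
        redis_client.decr(key)
        return False
    return True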
|
||||
|
||||
|
||||
# ── replay + live tail ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestReplayAndTail:
|
||||
def test_replay_yields_xrange_entries_with_injected_id(self):
|
||||
from application.api.events import routes as events_module
|
||||
|
||||
app = _make_app()
|
||||
redis_client = MagicMock()
|
||||
redis_client.incr.return_value = 1
|
||||
# Empty stream (no truncation).
|
||||
redis_client.xinfo_stream.side_effect = Exception("nope")
|
||||
# XRANGE returns one stored envelope (without ``id``); the route
|
||||
# injects the entry id on the way out.
|
||||
stored_event = json.dumps(
|
||||
{
|
||||
"type": "source.ingest.progress",
|
||||
"ts": "2026-04-28T00:00:00.000Z",
|
||||
"user_id": "alice",
|
||||
"topic": "user:alice",
|
||||
"scope": {"kind": "source", "id": "src-1"},
|
||||
"payload": {"current": 25, "total": 100},
|
||||
}
|
||||
).encode()
|
||||
redis_client.xrange.return_value = [
|
||||
(b"1735682400000-0", {b"event": stored_event}),
|
||||
]
|
||||
|
||||
# Topic.subscribe yields an immediate timeout so the generator
|
||||
# keeps running long enough to flush replay; subsequent calls
|
||||
# also return None.
|
||||
from application.api.events.routes import _SSE_LINE_SPLIT # noqa: F401
|
||||
|
||||
# Fake the broadcast Topic to invoke on_subscribe immediately
|
||||
# then yield None ticks until close.
|
||||
def _fake_subscribe(self, on_subscribe=None, poll_timeout=1.0):
|
||||
if on_subscribe is not None:
|
||||
on_subscribe()
|
||||
while True:
|
||||
yield None
|
||||
|
||||
with patch.object(events_module, "get_redis_instance", return_value=redis_client), \
|
||||
patch.object(
|
||||
events_module.Topic, "subscribe", _fake_subscribe, create=False
|
||||
):
|
||||
with app.test_client() as c:
|
||||
r = c.get(
|
||||
"/api/events",
|
||||
headers={"X-Test-Sub": "alice", "Last-Event-ID": "1735682300000-0"},
|
||||
)
|
||||
body = _drain_until(
|
||||
r,
|
||||
lambda b: b'"current": 25' in b or b'"current":25' in b,
|
||||
max_chunks=80,
|
||||
)
|
||||
# Replay yields the entry id as the SSE id field.
|
||||
assert b"id: 1735682400000-0" in body
|
||||
# Envelope was rewritten to include the injected id.
|
||||
assert b'"id": "1735682400000-0"' in body or b'"id":"1735682400000-0"' in body
|
||||
# The connect log fires before replay.
|
||||
assert b": connected" in body
|
||||
|
||||
def test_snapshot_flushed_when_subscribe_dies_after_callback(self):
|
||||
"""Regression: if ``on_subscribe`` populated ``replay_lines`` but
|
||||
``Topic.subscribe`` exits before yielding once (transient Redis
|
||||
hiccup between SUBSCRIBE-ack and the first poll), the snapshot
|
||||
must still reach the client. Prior to the fix the in-loop flush
|
||||
was the only flush, so the backlog was silently dropped.
|
||||
"""
|
||||
from application.api.events import routes as events_module
|
||||
|
||||
app = _make_app()
|
||||
redis_client = MagicMock()
|
||||
redis_client.incr.return_value = 1
|
||||
redis_client.xinfo_stream.side_effect = Exception("nope")
|
||||
stored_event = json.dumps(
|
||||
{
|
||||
"type": "notification",
|
||||
"payload": {"text": "from snapshot"},
|
||||
}
|
||||
).encode()
|
||||
redis_client.xrange.return_value = [
|
||||
(b"1735682400000-0", {b"event": stored_event}),
|
||||
]
|
||||
|
||||
# Mimic the broadcast_channel race: SUBSCRIBE acks, on_subscribe
|
||||
# runs, then the next get_message raises and the generator
|
||||
# returns without ever yielding.
|
||||
def _subscribe_dies_after_callback(
|
||||
self, on_subscribe=None, poll_timeout=1.0
|
||||
):
|
||||
if on_subscribe is not None:
|
||||
on_subscribe()
|
||||
return
|
||||
yield # pragma: no cover (make the function a generator)
|
||||
|
||||
with patch.object(events_module, "get_redis_instance", return_value=redis_client), \
|
||||
patch.object(
|
||||
events_module.Topic,
|
||||
"subscribe",
|
||||
_subscribe_dies_after_callback,
|
||||
create=False,
|
||||
):
|
||||
with app.test_client() as c:
|
||||
r = c.get(
|
||||
"/api/events",
|
||||
headers={
|
||||
"X-Test-Sub": "alice",
|
||||
"Last-Event-ID": "1735682300000-0",
|
||||
},
|
||||
)
|
||||
body = _drain_until(
|
||||
r,
|
||||
lambda b: b"from snapshot" in b,
|
||||
max_chunks=80,
|
||||
)
|
||||
# Snapshot frame must have been flushed via the post-loop
|
||||
# safety net even though Topic.subscribe exited before
|
||||
# the in-loop flush could fire.
|
||||
assert b"id: 1735682400000-0" in body
|
||||
assert b"from snapshot" in body
|
||||
# XRANGE was issued exactly once (no double-flush).
|
||||
redis_client.xrange.assert_called_once()
|
||||
|
||||
def test_invalid_last_event_id_emits_truncation_notice(self):
|
||||
from application.api.events import routes as events_module
|
||||
|
||||
app = _make_app()
|
||||
redis_client = MagicMock()
|
||||
redis_client.incr.return_value = 1
|
||||
redis_client.xinfo_stream.return_value = {"first-entry": [b"1-0", []]}
|
||||
redis_client.xrange.return_value = []
|
||||
|
||||
def _fake_subscribe(self, on_subscribe=None, poll_timeout=1.0):
|
||||
if on_subscribe is not None:
|
||||
on_subscribe()
|
||||
while True:
|
||||
yield None
|
||||
|
||||
with patch.object(events_module, "get_redis_instance", return_value=redis_client), \
|
||||
patch.object(events_module.Topic, "subscribe", _fake_subscribe, create=False):
|
||||
with app.test_client() as c:
|
||||
r = c.get(
|
||||
"/api/events",
|
||||
headers={"X-Test-Sub": "alice", "Last-Event-ID": "definitely-not-an-id"},
|
||||
)
|
||||
body = _drain_until(
|
||||
r, lambda b: b"backlog.truncated" in b, max_chunks=80
|
||||
)
|
||||
assert b"backlog.truncated" in body
|
||||
|
||||
def test_live_tail_rejects_malformed_event_id_for_dedupe(self):
|
||||
"""A pub/sub envelope carrying a non-Redis-Streams ``id`` must not
|
||||
seed the dedup floor. Otherwise an adversarial or buggy publisher
|
||||
could ship ``id="9999999999999-9"`` (lex-greater than any real
|
||||
id) and pin every subsequent legitimate event below the floor,
|
||||
silently dropping the user's notifications.
|
||||
|
||||
The event itself should still be delivered to the client — we
|
||||
just refuse to use the bogus id for ordering, so it ships
|
||||
without an SSE ``id:`` header and ``max_replayed_id`` stays put.
|
||||
"""
|
||||
from application.api.events import routes as events_module
|
||||
|
||||
app = _make_app()
|
||||
redis_client = MagicMock()
|
||||
redis_client.incr.return_value = 1
|
||||
redis_client.xinfo_stream.side_effect = Exception("nope")
|
||||
# Snapshot covers ids up to 1735682400000-0; max_replayed_id
|
||||
# becomes that value after the in-loop flush.
|
||||
replay_event = json.dumps({
|
||||
"type": "source.ingest.progress",
|
||||
"payload": {"step": "replay"},
|
||||
}).encode()
|
||||
redis_client.xrange.return_value = [
|
||||
(b"1735682400000-0", {b"event": replay_event}),
|
||||
]
|
||||
|
||||
live_bogus = json.dumps({
|
||||
"id": "definitely-not-an-id",
|
||||
"type": "source.ingest.completed",
|
||||
"payload": {"step": "live-bogus"},
|
||||
})
|
||||
live_real = json.dumps({
|
||||
"id": "1735682500000-0",
|
||||
"type": "source.ingest.completed",
|
||||
"payload": {"step": "live-real"},
|
||||
})
|
||||
|
||||
def _fake_subscribe(self, on_subscribe=None, poll_timeout=1.0):
|
||||
# ``Topic.subscribe`` already unpacks redis-py pubsub dicts
|
||||
# and yields the raw ``data`` bytes (or ``None`` on poll
|
||||
# timeout). Mirror that contract.
|
||||
if on_subscribe is not None:
|
||||
on_subscribe()
|
||||
yield live_bogus.encode()
|
||||
yield live_real.encode()
|
||||
while True:
|
||||
yield None
|
||||
|
||||
with patch.object(
|
||||
events_module, "get_redis_instance", return_value=redis_client
|
||||
), patch.object(
|
||||
events_module.Topic, "subscribe", _fake_subscribe, create=False
|
||||
):
|
||||
with app.test_client() as c:
|
||||
r = c.get(
|
||||
"/api/events",
|
||||
headers={
|
||||
"X-Test-Sub": "alice",
|
||||
"Last-Event-ID": "1735682300000-0",
|
||||
},
|
||||
)
|
||||
body = _drain_until(
|
||||
r, lambda b: b"live-real" in b, max_chunks=80
|
||||
)
|
||||
|
||||
# Live-real arrived (its id is strictly greater than the
|
||||
# replayed snapshot's id), with its valid id surfaced as
|
||||
# the SSE ``id:`` header so the frontend can advance.
|
||||
assert b"live-real" in body
|
||||
assert b"id: 1735682500000-0" in body
|
||||
|
||||
# The bogus-id event was still delivered to the client,
|
||||
# but no ``id: definitely-not-an-id`` line was emitted —
|
||||
# the malformed id never reached the SSE wire and so
|
||||
# could not pin the dedup floor.
|
||||
assert b"live-bogus" in body
|
||||
assert b"id: definitely-not-an-id" not in body
|
||||
|
||||
|
||||
# ── format helpers (already covered in test_events_substrate but
|
||||
# duplicated here as a smoke for the route's surface) ─────────────────
|
||||
|
||||
|
||||
class TestReplayRateLimit:
|
||||
"""Enumeration defenses on the per-user backlog."""
|
||||
|
||||
def test_allow_replay_returns_true_when_budget_disabled(self):
|
||||
from application.api.events.routes import _allow_replay
|
||||
|
||||
with patch("application.api.events.routes.settings") as mock_settings:
|
||||
mock_settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW = 0
|
||||
mock_settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS = 60
|
||||
assert _allow_replay(MagicMock(), "alice", "1735682400000-0") is True
|
||||
|
||||
def test_allow_replay_returns_true_when_redis_unavailable(self):
|
||||
from application.api.events.routes import _allow_replay
|
||||
|
||||
with patch("application.api.events.routes.settings") as mock_settings:
|
||||
mock_settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW = 5
|
||||
mock_settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS = 60
|
||||
assert _allow_replay(None, "alice", "1735682400000-0") is True
|
||||
|
||||
def test_allow_replay_skips_incr_when_no_cursor_and_empty_backlog(self):
|
||||
"""Fresh client with no cursor and an empty user stream cannot
|
||||
do snapshot work — INCR'ing the counter would needlessly
|
||||
burn budget. Catches the React-StrictMode dev-burst case where
|
||||
double-mounted components would otherwise 429 in 5 connects.
|
||||
"""
|
||||
from application.api.events.routes import _allow_replay
|
||||
|
||||
with patch("application.api.events.routes.settings") as mock_settings:
|
||||
mock_settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW = 3
|
||||
mock_settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS = 60
|
||||
redis = MagicMock()
|
||||
redis.xlen.return_value = 0
|
||||
|
||||
# 5 connects in a row, all with no cursor — none consume
|
||||
# budget because the backlog is empty.
|
||||
for _ in range(5):
|
||||
assert _allow_replay(redis, "alice", None) is True
|
||||
|
||||
redis.xlen.assert_called()
|
||||
redis.incr.assert_not_called()
|
||||
|
||||
def test_allow_replay_incrs_when_no_cursor_but_backlog_present(self):
|
||||
"""A no-cursor connect against a non-empty backlog *will* do
|
||||
snapshot work, so it consumes budget normally.
|
||||
"""
|
||||
from application.api.events.routes import _allow_replay
|
||||
|
||||
with patch("application.api.events.routes.settings") as mock_settings:
|
||||
mock_settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW = 5
|
||||
mock_settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS = 60
|
||||
redis = MagicMock()
|
||||
redis.xlen.return_value = 42
|
||||
redis.incr.return_value = 1
|
||||
|
||||
assert _allow_replay(redis, "alice", None) is True
|
||||
redis.incr.assert_called_once()
|
||||
|
||||
def test_allow_replay_passes_until_budget_exhausted(self):
|
||||
from application.api.events.routes import _allow_replay
|
||||
|
||||
with patch("application.api.events.routes.settings") as mock_settings:
|
||||
mock_settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW = 3
|
||||
mock_settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS = 60
|
||||
redis = MagicMock()
|
||||
counter = {"v": 0}
|
||||
|
||||
def _incr(_key):
|
||||
counter["v"] += 1
|
||||
return counter["v"]
|
||||
|
||||
redis.incr.side_effect = _incr
|
||||
|
||||
# Cursor set → XLEN short-circuit doesn't fire, INCR always runs.
|
||||
cursor = "1735682400000-0"
|
||||
# First three pass.
|
||||
assert _allow_replay(redis, "alice", cursor) is True
|
||||
assert _allow_replay(redis, "alice", cursor) is True
|
||||
assert _allow_replay(redis, "alice", cursor) is True
|
||||
# Fourth refused.
|
||||
assert _allow_replay(redis, "alice", cursor) is False
|
||||
# TTL re-seeded on every successful INCR so a transient
|
||||
# EXPIRE failure on the seeding call can't wedge the key.
|
||||
assert redis.expire.call_count == 4
|
||||
for call in redis.expire.call_args_list:
|
||||
assert call.args[1] == 60
|
||||
|
||||
def test_allow_replay_fail_open_on_redis_error(self):
|
||||
from application.api.events.routes import _allow_replay
|
||||
|
||||
with patch("application.api.events.routes.settings") as mock_settings:
|
||||
mock_settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW = 5
|
||||
mock_settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS = 60
|
||||
redis = MagicMock()
|
||||
redis.incr.side_effect = Exception("redis down")
|
||||
assert _allow_replay(redis, "alice", "1735682400000-0") is True
|
||||
|
||||
def test_allow_replay_recovers_when_seeding_expire_raises(self):
|
||||
"""Regression: INCR=1 then EXPIRE raising must not wedge the key.
|
||||
|
||||
Earlier code only called EXPIRE when ``used == 1``. If that EXPIRE
|
||||
raised, the counter persisted with no TTL and every subsequent
|
||||
call hit ``used > 1`` without re-seeding — locking the user out
|
||||
until an operator DEL'd the key. The fix calls EXPIRE on every
|
||||
successful INCR so the next call still re-seeds the TTL.
|
||||
"""
|
||||
from application.api.events.routes import _allow_replay
|
||||
|
||||
with patch("application.api.events.routes.settings") as mock_settings:
|
||||
mock_settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW = 5
|
||||
mock_settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS = 60
|
||||
redis = MagicMock()
|
||||
counter = {"v": 0}
|
||||
|
||||
def _incr(_key):
|
||||
counter["v"] += 1
|
||||
return counter["v"]
|
||||
|
||||
redis.incr.side_effect = _incr
|
||||
# First EXPIRE raises (the seeding call that would have
|
||||
# wedged the key under the old gated logic). Second EXPIRE
|
||||
# succeeds — and crucially, must still run.
|
||||
redis.expire.side_effect = [Exception("expire blip"), True]
|
||||
|
||||
cursor = "1735682400000-0"
|
||||
# First call: INCR=1 succeeds, EXPIRE raises -> outer except
|
||||
# returns True (fail-open) for this call.
|
||||
assert _allow_replay(redis, "alice", cursor) is True
|
||||
# Second call: INCR=2, EXPIRE succeeds -> still under budget,
|
||||
# and the TTL is now seeded (no permanent lockout).
|
||||
assert _allow_replay(redis, "alice", cursor) is True
|
||||
|
||||
assert redis.expire.call_count == 2
|
||||
# Both EXPIRE calls used the configured window.
|
||||
for call in redis.expire.call_args_list:
|
||||
assert call.args[1] == 60
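
Putting the cases above together, the budget gate behaves roughly like this sketch. It fails open throughout; key names and the helper name are assumptions, not the actual _allow_replay source.

# Sketch assembled from the tests above: a disabled budget or missing Redis
# passes, a cursor-less connect against an empty backlog costs nothing, the
# TTL is re-seeded on every successful INCR, and any Redis error fails open.
def allow_replay_sketch(redis_client, user_id, last_event_id, budget, window_s):
    if budget <= 0 or redis_client is None:
        return True
    try:
        stream = f"user:{user_id}:events"          # key name assumed
        if last_event_id is None and redis_client.xlen(stream) == 0:
            return True  # no snapshot work possible, so no budget spent
        key = f"user:{user_id}:replay_budget"      # key name assumed
        used = redis_client.incr(key)
        redis_client.expire(key, window_s)         # re-seed TTL on every call
        return used <= budget
    except Exception:
        return True  # fail open on any Redis hiccup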
    def test_replay_backlog_passes_count_to_xrange(self):
        from application.api.events.routes import _replay_backlog

        redis = MagicMock()
        redis.xrange.return_value = []
        # Drain the iterator so xrange is actually called.
        list(_replay_backlog(redis, "alice", None, 200))
        redis.xrange.assert_called_once()
        kwargs = redis.xrange.call_args.kwargs
        assert kwargs.get("count") == 200

    def test_returns_429_when_replay_budget_exhausted(self):
        """Route refuses the connection rather than serving live tail
        only. Earlier behavior silently skipped replay and let the
        client advance ``lastEventId`` via id-bearing live frames,
        permanently stranding the un-replayed window. The 429 keeps
        the cursor pinned so the next reconnect (after the budget
        window slides) can replay normally.
        """
        from application.api.events import routes as events_module

        app = _make_app()
        redis_client = MagicMock()

        def _incr(key):
            if key == "user:alice:sse_count":
                return 1
            # Budget counter: report over-limit.
            return 31

        redis_client.incr.side_effect = _incr

        with patch.object(
            events_module, "get_redis_instance", return_value=redis_client
        ), patch.object(
            events_module.settings,
            "EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW",
            30,
        ):
            with app.test_client() as c:
                r = c.get(
                    "/api/events",
                    headers={
                        "X-Test-Sub": "alice",
                        "Last-Event-ID": "1735682300000-0",
                    },
                )
                assert r.status_code == 429
                # Concurrency slot is released so a budget-denied request
                # doesn't permanently consume a connection from the cap.
                redis_client.decr.assert_called_once_with("user:alice:sse_count")


class TestFormatHelpers:
    def test_format_sse_two_terminating_newlines(self):
        from application.api.events.routes import _format_sse

        out = _format_sse("hello", event_id="1-0")
        assert out.endswith("\n\n")
        # Exactly one ``id:`` and one ``data:``.
        lines = out.rstrip("\n").split("\n")
        assert lines == ["id: 1-0", "data: hello"]

    @pytest.mark.parametrize(
        "candidate, expected",
        [
            ("1234", "1234"),
            ("1234-5", "1234-5"),
            (" 1234-0 ", "1234-0"),
            (None, None),
            ("", None),
            (" ", None),
            ("nope", None),
            ("1234-foo", None),
        ],
    )
    def test_normalize_last_event_id(self, candidate, expected):
        from application.api.events.routes import _normalize_last_event_id

        assert _normalize_last_event_id(candidate) == expected
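These two helpers are small enough to sketch from the assertions above (hypothetical; the real implementations in application/api/events/routes.py may differ in naming and details):

import re

_CURSOR_RE = re.compile(r"^\d+(-\d+)?$")  # hypothetical pattern name

def _format_sse(data, event_id=None):
    # One optional ``id:`` line, one ``data:`` line, blank-line terminator.
    prefix = f"id: {event_id}\n" if event_id is not None else ""
    return f"{prefix}data: {data}\n\n"

def _normalize_last_event_id(candidate):
    # Accept Redis-stream-shaped cursors ("1234" or "1234-5"); blanks and
    # junk normalize to None so replay is skipped rather than mis-seeked.
    if not candidate or not candidate.strip():
        return None
    candidate = candidate.strip()
    return candidate if _CURSOR_RE.match(candidate) else None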
tests/api/test_message_stream_routes.py (new file, 220 lines)
@@ -0,0 +1,220 @@
"""Tests for ``application/api/answer/routes/messages.py``.

Reconnect endpoint: GET /api/messages/<id>/events. Auth gate, ownership
gate, malformed-id rejection, Last-Event-ID normalisation, and a smoke
test that the SSE response shape matches the user-events endpoint.
"""

from __future__ import annotations

from unittest.mock import patch

from flask import Flask, request

from application.api.answer.routes.messages import (
    _MESSAGE_ID_RE,
    _normalise_last_event_id,
    messages_bp,
)


def _make_app(decoded_token=None):
    app = Flask(__name__)
    app.register_blueprint(messages_bp)
    app.config["TESTING"] = True

    @app.before_request
    def _shim_auth():
        request.decoded_token = decoded_token

    return app


VALID_UUID = "67d65e8f-e7fb-4df1-9e6e-99ea6c830206"


class TestNormaliseLastEventId:
    def test_none_passthrough(self):
        assert _normalise_last_event_id(None) is None

    def test_empty_string(self):
        assert _normalise_last_event_id("") is None

    def test_whitespace_only(self):
        assert _normalise_last_event_id(" ") is None

    def test_valid_int(self):
        assert _normalise_last_event_id("42") == 42

    def test_stripped_whitespace(self):
        assert _normalise_last_event_id(" 7 ") == 7

    def test_zero_is_valid(self):
        assert _normalise_last_event_id("0") == 0

    def test_negative_rejected(self):
        # We expose only non-negative cursors; -1 is reserved for the
        # snapshot-failure synthetic terminal event and shouldn't
        # round-trip back.
        assert _normalise_last_event_id("-1") is None

    def test_non_numeric_rejected(self):
        for bad in ("foo", "1.5", "1e3", "abc-123", "null"):
            assert _normalise_last_event_id(bad) is None, bad


class TestMessageIdRegex:
    def test_canonical_uuid_accepted(self):
        assert _MESSAGE_ID_RE.match(VALID_UUID)

    def test_uppercase_uuid_accepted(self):
        assert _MESSAGE_ID_RE.match(VALID_UUID.upper())

    def test_no_dashes_rejected(self):
        assert not _MESSAGE_ID_RE.match(VALID_UUID.replace("-", ""))

    def test_legacy_mongo_id_rejected(self):
        # 24-char hex with no dashes — a Mongo objectid-shaped string
        # that happened to leak through somewhere.
        assert not _MESSAGE_ID_RE.match("507f1f77bcf86cd799439011")


class TestAuthGate:
    def test_401_when_no_decoded_token(self):
        app = _make_app(decoded_token=None)
        with app.test_client() as c:
            r = c.get(f"/api/messages/{VALID_UUID}/events")
            assert r.status_code == 401

    def test_401_when_decoded_token_missing_sub(self):
        app = _make_app(decoded_token={"email": "x@y"})
        with app.test_client() as c:
            r = c.get(f"/api/messages/{VALID_UUID}/events")
            assert r.status_code == 401


class TestMessageIdValidation:
    def test_400_on_malformed_id(self):
        app = _make_app(decoded_token={"sub": "alice"})
        with app.test_client() as c:
            r = c.get("/api/messages/not-a-uuid/events")
            assert r.status_code == 400


class TestOwnershipGate:
    def test_404_when_user_does_not_own_message(self):
        from application.api.answer.routes import messages as messages_module

        app = _make_app(decoded_token={"sub": "alice"})
        with patch.object(
            messages_module, "_user_owns_message", return_value=False
        ):
            with app.test_client() as c:
                r = c.get(f"/api/messages/{VALID_UUID}/events")
                assert r.status_code == 404

    def test_200_when_user_owns_message(self):
        from application.api.answer.routes import messages as messages_module

        app = _make_app(decoded_token={"sub": "alice"})

        # Have build_message_event_stream yield just the prelude then
        # exit so the test can drain the response without blocking on
        # a live pubsub subscription.
        def _fake_builder(message_id, last_event_id=None, **kwargs):
            yield ": connected\n\n"

        with patch.object(
            messages_module, "_user_owns_message", return_value=True
        ), patch.object(
            messages_module, "build_message_event_stream", _fake_builder
        ):
            with app.test_client() as c:
                r = c.get(f"/api/messages/{VALID_UUID}/events")
                assert r.status_code == 200
                assert r.mimetype == "text/event-stream"
                assert r.headers.get("Cache-Control") == "no-store"
                assert r.headers.get("X-Accel-Buffering") == "no"
                body = b""
                for chunk in r.iter_encoded():
                    body += chunk
                    if b": connected" in body:
                        break
                r.close()
                assert b": connected" in body


class TestLastEventIdParsing:
    def test_header_passes_through_to_builder(self):
        from application.api.answer.routes import messages as messages_module

        captured = {}

        def _fake_builder(message_id, last_event_id=None, **kwargs):
            captured["message_id"] = message_id
            captured["last_event_id"] = last_event_id
            yield ": connected\n\n"

        app = _make_app(decoded_token={"sub": "alice"})
        with patch.object(
            messages_module, "_user_owns_message", return_value=True
        ), patch.object(
            messages_module, "build_message_event_stream", _fake_builder
        ):
            with app.test_client() as c:
                r = c.get(
                    f"/api/messages/{VALID_UUID}/events",
                    headers={"Last-Event-ID": "12"},
                )
                # Drain a tick.
                next(iter(r.iter_encoded()), None)
                r.close()
                assert captured["message_id"] == VALID_UUID
                assert captured["last_event_id"] == 12

    def test_query_param_fallback(self):
        from application.api.answer.routes import messages as messages_module

        captured = {}

        def _fake_builder(message_id, last_event_id=None, **kwargs):
            captured["last_event_id"] = last_event_id
            yield ": connected\n\n"

        app = _make_app(decoded_token={"sub": "alice"})
        with patch.object(
            messages_module, "_user_owns_message", return_value=True
        ), patch.object(
            messages_module, "build_message_event_stream", _fake_builder
        ):
            with app.test_client() as c:
                r = c.get(
                    f"/api/messages/{VALID_UUID}/events?last_event_id=5"
                )
                next(iter(r.iter_encoded()), None)
                r.close()
                assert captured["last_event_id"] == 5

    def test_invalid_cursor_normalised_to_none(self):
        from application.api.answer.routes import messages as messages_module

        captured = {}

        def _fake_builder(message_id, last_event_id=None, **kwargs):
            captured["last_event_id"] = last_event_id
            yield ": connected\n\n"

        app = _make_app(decoded_token={"sub": "alice"})
        with patch.object(
            messages_module, "_user_owns_message", return_value=True
        ), patch.object(
            messages_module, "build_message_event_stream", _fake_builder
        ):
            with app.test_client() as c:
                r = c.get(
                    f"/api/messages/{VALID_UUID}/events",
                    headers={"Last-Event-ID": "definitely-not-a-number"},
                )
                next(iter(r.iter_encoded()), None)
                r.close()
                assert captured["last_event_id"] is None
@@ -543,6 +543,84 @@ class TestRemoteIdempotency:
            assert response.status_code == 400
            assert mock_apply.call_count == 0

    def test_no_header_returns_source_id_matching_worker_kwarg(
        self, app, pg_conn,
    ):
        """Regression: without an ``Idempotency-Key``, the route must
        still return a ``source_id`` AND pass that same id to the worker
        as ``source_id`` so SSE envelopes line up with what the
        frontend already has. Previously the route omitted ``source_id``
        entirely on the no-key path and the worker minted its own
        random uuid, breaking push correlation for the default upload
        flow.
        """
        from application.api.user.sources.upload import UploadRemote

        apply_mock = _apply_async_mock()
        with _patch_db(pg_conn), patch(
            "application.api.user.sources.upload.ingest_remote.apply_async",
            apply_mock,
        ), app.test_request_context(
            "/api/remote", method="POST",
            data={
                "user": "u", "source": "github", "name": "g",
                "data": json.dumps({"repo_url": "https://github.com/x/y"}),
            },
            content_type="multipart/form-data",
        ):
            from flask import request
            request.decoded_token = {"sub": "u"}
            response = UploadRemote().post()
            assert response.status_code == 200
            assert "source_id" in response.json
            assert (
                apply_mock.call_args.kwargs["kwargs"]["source_id"]
                == response.json["source_id"]
            )

    def test_no_header_connector_returns_source_id_matching_worker_kwarg(
        self, app, pg_conn,
    ):
        """Same regression as above for the connector branch
        (``ingest_connector_task``). The connector path took the
        no-key gap independently of the plain remote path."""
        from application.api.user.sources.upload import UploadRemote

        apply_mock = _apply_async_mock()
        # Pick any registered connector — the route only branches on
        # ``ConnectorCreator.get_supported_connectors()``.
        from application.parser.connectors.connector_creator import (
            ConnectorCreator,
        )
        supported = ConnectorCreator.get_supported_connectors()
        if not supported:
            pytest.skip("no connectors registered in this build")
        connector_source = next(iter(supported))

        with _patch_db(pg_conn), patch(
            "application.api.user.sources.upload.ingest_connector_task.apply_async",
            apply_mock,
        ), app.test_request_context(
            "/api/remote", method="POST",
            data={
                "user": "u", "source": connector_source, "name": "g",
                "data": json.dumps({
                    "session_token": "tok",
                    "file_ids": ["f1"],
                }),
            },
            content_type="multipart/form-data",
        ):
            from flask import request
            request.decoded_token = {"sub": "u"}
            response = UploadRemote().post()
            assert response.status_code == 200
            assert "source_id" in response.json
            assert (
                apply_mock.call_args.kwargs["kwargs"]["source_id"]
                == response.json["source_id"]
            )
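The contract both tests pin down can be sketched as follows (hypothetical; the worker kwarg names other than source_id are placeholders, not the route's actual variable names):

import uuid

def _dispatch_remote_ingest(source_data, job_name, user, loader):
    # Hypothetical sketch: mint the source_id on the no-key path too and
    # hand the same value to the worker, so SSE events and the HTTP
    # response carry one id the frontend can correlate on.
    source_id = str(uuid.uuid4())
    task = ingest_remote.apply_async(
        kwargs={
            "source_data": source_data,
            "job_name": job_name,
            "user": user,
            "loader": loader,
            "source_id": source_id,
        }
    )
    return {"success": True, "task_id": task.id, "source_id": source_id}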
def _seed_source(pg_conn, user="u", **kw):
    from application.storage.db.repositories.sources import SourcesRepository

@@ -673,10 +751,140 @@ class TestManageSourceFilesIdempotency:
        assert apply_mock.call_count == 1
        # Loser's response carries the winner's task_id, not the
        # original 200-with-added_files payload.
        assert second.json["task_id"] == first.json["reingest_task_id"]
        # ``manage_source_files`` aliases ``task_id`` ->
        # ``reingest_task_id`` in the cached payload so the dedup
        # response shape matches the fresh-request response (FileTree
        # keys reingest correlation on ``reingest_task_id`` /
        # ``source_id``).
        assert second.json["reingest_task_id"] == first.json["reingest_task_id"]
        # Cached ``source_id`` must equal the real source row id (not
        # the helper's uuid5-of-key) so FileTree's SSE correlation on
        # ``event.scope.id === result.source_id`` keeps working.
        assert second.json["source_id"] == first.json["source_id"]
        assert second.json["source_id"] == str(src["id"])
        # Confirm the loser never invoked the file-save path.
        assert fake_storage.save_file.call_count == 1

    def test_remove_same_key_second_post_returns_real_source_id(
        self, app, pg_conn
    ):
        """Regression: the ``remove`` cached branch used to leave the
        helper's synthetic ``source_id`` (uuid5 of the scoped key) in
        place. The reingest worker publishes SSE events tagged with the
        real source row id, so the cached response had to be patched to
        match what the fresh response returns — otherwise FileTree's
        SSE correlation silently fails on every idempotent retry and
        the user never sees the directory refresh.
        """
        from application.api.user.sources.upload import ManageSourceFiles

        user = "alice-mgr-rmrep"
        src = _seed_source(
            pg_conn,
            user=user,
            file_path="/data",
            file_name_map={"a.txt": "a.txt"},
        )

        fake_storage = MagicMock()
        fake_storage.file_exists.return_value = True
        apply_mock = _apply_async_mock()

        def _do_remove():
            return app.test_request_context(
                "/api/manage_source_files",
                method="POST",
                data={
                    "source_id": str(src["id"]),
                    "operation": "remove",
                    "file_paths": json.dumps(["a.txt"]),
                },
                content_type="multipart/form-data",
                headers={"Idempotency-Key": "mgr-rmrep"},
            )

        with _patch_db(pg_conn), patch(
            "application.api.user.sources.upload.StorageCreator.get_storage",
            return_value=fake_storage,
        ), patch(
            "application.api.user.tasks.reingest_source_task.apply_async",
            apply_mock,
        ):
            with _do_remove():
                from flask import request
                request.decoded_token = {"sub": user}
                first = ManageSourceFiles().post()
            with _do_remove():
                from flask import request
                request.decoded_token = {"sub": user}
                second = ManageSourceFiles().post()

        assert first.status_code == 200
        assert second.status_code == 200
        assert apply_mock.call_count == 1
        assert second.json["reingest_task_id"] == first.json["reingest_task_id"]
        # The contract under test: cached source_id matches the fresh
        # response (the real source row id), not the helper's uuid5.
        assert second.json["source_id"] == first.json["source_id"]
        assert second.json["source_id"] == str(src["id"])

    def test_remove_directory_same_key_second_post_returns_real_source_id(
        self, app, pg_conn
    ):
        """Same regression as the ``remove`` test, for the
        ``remove_directory`` branch.
        """
        from application.api.user.sources.upload import ManageSourceFiles

        user = "alice-mgr-rmdir-rep"
        src = _seed_source(
            pg_conn,
            user=user,
            file_path="/data",
            file_name_map={"sub/a.txt": "a.txt"},
        )

        fake_storage = MagicMock()
        fake_storage.is_directory.return_value = True
        fake_storage.remove_directory.return_value = True
        apply_mock = _apply_async_mock()

        def _do_remove_dir():
            return app.test_request_context(
                "/api/manage_source_files",
                method="POST",
                data={
                    "source_id": str(src["id"]),
                    "operation": "remove_directory",
                    "directory_path": "sub",
                },
                content_type="multipart/form-data",
                headers={"Idempotency-Key": "mgr-rmdir-rep"},
            )

        with _patch_db(pg_conn), patch(
            "application.api.user.sources.upload.StorageCreator.get_storage",
            return_value=fake_storage,
        ), patch(
            "application.api.user.tasks.reingest_source_task.apply_async",
            apply_mock,
        ):
            with _do_remove_dir():
                from flask import request
                request.decoded_token = {"sub": user}
                first = ManageSourceFiles().post()
            with _do_remove_dir():
                from flask import request
                request.decoded_token = {"sub": user}
                second = ManageSourceFiles().post()

        assert first.status_code == 200
        assert second.status_code == 200
        assert apply_mock.call_count == 1
        assert second.json["reingest_task_id"] == first.json["reingest_task_id"]
        assert second.json["source_id"] == first.json["source_id"]
        assert second.json["source_id"] == str(src["id"])

    def test_concurrent_same_key_only_one_apply_async(self, app, pg_engine):
        """N parallel same-key POSTs → exactly one apply_async."""
        from concurrent.futures import ThreadPoolExecutor

@@ -506,6 +506,60 @@ class TestGetMessageTail:

        assert response.status_code == 404

    def test_streaming_row_returns_partial_from_journal(self, app, pg_conn):
        """Mid-stream rows must rebuild from message_events, not return the placeholder."""
        from application.api.user.conversations.routes import GetMessageTail
        from application.storage.db.repositories.message_events import (
            MessageEventsRepository,
        )

        owner = "user-tail-partial"
        _, msg_id = self._seed_in_flight_message(pg_conn, owner)
        events_repo = MessageEventsRepository(pg_conn)
        events_repo.record(msg_id, 0, "message_id", {"type": "message_id"})
        events_repo.record(msg_id, 1, "answer", {"type": "answer", "answer": "Hello"})
        events_repo.record(msg_id, 2, "answer", {"type": "answer", "answer": ", world"})
        events_repo.record(
            msg_id, 3, "source", {"type": "source", "source": [{"id": "s1"}]}
        )

        with _patch_conversations_db(pg_conn), app.test_request_context(
            f"/api/messages/{msg_id}/tail"
        ):
            from flask import request

            request.decoded_token = {"sub": owner}
            response = GetMessageTail().get(msg_id)

        assert response.status_code == 200
        assert response.json["status"] == "streaming"
        assert response.json["response"] == "Hello, world"
        assert response.json["sources"] == [{"id": "s1"}]
        assert "terminated prior to completion" not in (
            response.json["response"] or ""
        )
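A plausible shape for the journal fold the tail endpoint performs (hypothetical sketch; the row keys other than ``sequence_no`` are assumptions, not taken from the repository code):

def _rebuild_partial(event_rows):
    # Hypothetical fold: concatenate answer chunks in sequence order, keep
    # the latest sources payload, ignore bookkeeping events (message_id).
    answer, sources = "", []
    for row in sorted(event_rows, key=lambda r: r["sequence_no"]):
        payload = row.get("payload", {})
        if payload.get("type") == "answer":
            answer += payload.get("answer", "")
        elif payload.get("type") == "source":
            sources = payload.get("source", [])
    return {"status": "streaming", "response": answer, "sources": sources}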
    def test_streaming_row_with_empty_journal_returns_empty_response(
        self, app, pg_conn
    ):
        """Empty journal returns empty response, not the placeholder."""
        from application.api.user.conversations.routes import GetMessageTail

        owner = "user-tail-empty"
        _, msg_id = self._seed_in_flight_message(pg_conn, owner)

        with _patch_conversations_db(pg_conn), app.test_request_context(
            f"/api/messages/{msg_id}/tail"
        ):
            from flask import request

            request.decoded_token = {"sub": owner}
            response = GetMessageTail().get(msg_id)

        assert response.status_code == 200
        assert response.json["status"] == "streaming"
        assert response.json["response"] == ""


class TestUpdateConversationNameHappy:
    def test_returns_401_unauthenticated(self, app):
@@ -33,7 +33,7 @@ class TestIngestTask:

        mock_worker.assert_called_once_with(
            ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
-           file_name_map=None, idempotency_key=None,
+           file_name_map=None, idempotency_key=None, source_id=None,
        )
        assert result == {"status": "ok"}

@@ -50,7 +50,7 @@ class TestIngestTask:

        mock_worker.assert_called_once_with(
            ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
-           file_name_map=name_map, idempotency_key=None,
+           file_name_map=name_map, idempotency_key=None, source_id=None,
        )


@@ -66,7 +66,7 @@ class TestIngestRemoteTask:

        mock_worker.assert_called_once_with(
            ANY, {"url": "http://x"}, "job1", "user1", "web",
-           idempotency_key=None,
+           idempotency_key=None, source_id=None,
        )
        assert result == {"status": "ok"}

@@ -169,6 +169,7 @@ class TestIngestConnectorTask:
            doc_id=None,
            sync_frequency="never",
            idempotency_key=None,
+           source_id=None,
        )
        assert result == {"status": "ok"}

@@ -207,6 +208,7 @@ class TestIngestConnectorTask:
            doc_id="doc1",
            sync_frequency="daily",
            idempotency_key=None,
+           source_id=None,
        )
        assert result == {"status": "ok"}

@@ -220,7 +222,7 @@ class TestSetupPeriodicTasks:

        setup_periodic_tasks(sender)

-       assert sender.add_periodic_task.call_count == 7
+       assert sender.add_periodic_task.call_count == 8

        calls = sender.add_periodic_task.call_args_list

@@ -241,6 +243,9 @@ class TestSetupPeriodicTasks:
        assert calls[5][1].get("name") == "reconciliation"
        # version-check (every 7h)
        assert calls[6][0][0] == timedelta(hours=7)
+       # message_events retention sweep (24h)
+       assert calls[7][0][0] == timedelta(hours=24)
+       assert calls[7][1].get("name") == "cleanup-message-events"


class TestMcpOauthTask:

@@ -257,20 +262,6 @@ class TestMcpOauthTask:
        assert result == {"url": "http://auth"}

-
-class TestMcpOauthStatusTask:
-    @pytest.mark.unit
-    @patch("application.api.user.tasks.mcp_oauth_status")
-    def test_calls_mcp_oauth_status(self, mock_worker):
-        from application.api.user.tasks import mcp_oauth_status_task
-
-        mock_worker.return_value = {"status": "authorized"}
-
-        result = mcp_oauth_status_task("task123")
-
-        mock_worker.assert_called_once_with(ANY, "task123")
-        assert result == {"status": "authorized"}
-

class TestDurableTaskRetryPolicy:
    """The long-running tasks share a uniform retry policy."""

@@ -302,7 +293,6 @@ class TestDurableTaskRetryPolicy:
            "schedule_syncs",
            "sync_source",
            "mcp_oauth_task",
-           "mcp_oauth_status_task",
            "cleanup_pending_tool_state",
            "reconciliation_task",
            "version_check_task",

@@ -438,6 +428,93 @@ class TestCleanupPendingToolState:
        }


class TestCleanupMessageEventsTask:
    """Retention janitor delegates to MessageEventsRepository.cleanup_older_than."""

    @pytest.mark.unit
    def test_skips_when_postgres_uri_missing(self, monkeypatch):
        from application.api.user.tasks import cleanup_message_events
        from application.core.settings import settings

        monkeypatch.setattr(settings, "POSTGRES_URI", None, raising=False)

        result = cleanup_message_events.run()
        assert result == {"deleted": 0, "skipped": "POSTGRES_URI not set"}

    @pytest.mark.unit
    def test_deletes_rows_past_retention_window(self, pg_conn, monkeypatch):
        import uuid

        from sqlalchemy import text as _text

        from application.api.user.tasks import cleanup_message_events
        from application.core.settings import settings
        from application.storage.db.repositories.message_events import (
            MessageEventsRepository,
        )

        # Seed parent rows so the FK on message_events holds.
        user_id = f"user-{uuid.uuid4().hex[:8]}"
        conv_id = uuid.uuid4()
        msg_id = uuid.uuid4()
        pg_conn.execute(
            _text("INSERT INTO users (user_id) VALUES (:u)"),
            {"u": user_id},
        )
        pg_conn.execute(
            _text(
                "INSERT INTO conversations (id, user_id, name) "
                "VALUES (:id, :u, 'test')"
            ),
            {"id": conv_id, "u": user_id},
        )
        pg_conn.execute(
            _text(
                "INSERT INTO conversation_messages (id, conversation_id, "
                "user_id, position) VALUES (:id, :c, :u, 0)"
            ),
            {"id": msg_id, "c": conv_id, "u": user_id},
        )

        repo = MessageEventsRepository(pg_conn)
        repo.record(str(msg_id), 0, "answer", {"chunk": "stale"})
        repo.record(str(msg_id), 1, "answer", {"chunk": "fresh"})
        # Backdate seq=0 past the default 14-day retention so the
        # janitor catches it; seq=1 stays at "now" and must survive.
        pg_conn.execute(
            _text(
                "UPDATE message_events SET created_at = now() - interval '20 days' "
                "WHERE message_id = CAST(:id AS uuid) AND sequence_no = 0"
            ),
            {"id": str(msg_id)},
        )

        monkeypatch.setattr(
            settings, "POSTGRES_URI", "postgresql://stub", raising=False
        )

        @contextmanager
        def _fake_begin():
            yield pg_conn

        fake_engine = MagicMock()
        fake_engine.begin = _fake_begin

        with patch(
            "application.storage.db.engine.get_engine",
            return_value=fake_engine,
        ):
            result = cleanup_message_events.run()

        assert result == {
            "deleted": 1,
            "ttl_days": settings.MESSAGE_EVENTS_RETENTION_DAYS,
        }
        # Only the fresh row survives.
        rows = repo.read_after(str(msg_id))
        assert [r["sequence_no"] for r in rows] == [1]
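A sketch of the repository call the janitor delegates to (hypothetical; only the table name and the ``created_at`` cutoff follow from the test above, the method body itself is assumed):

from sqlalchemy import text

def cleanup_older_than(conn, ttl_days):
    # Hypothetical sketch of MessageEventsRepository.cleanup_older_than:
    # one bulk DELETE keyed on created_at, returning the rows removed.
    result = conn.execute(
        text(
            "DELETE FROM message_events "
            "WHERE created_at < now() - (:days * interval '1 day')"
        ),
        {"days": ttl_days},
    )
    return result.rowcount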
class TestIngestIdempotency:
    """Same short-circuit applies to the ingest task path."""

@@ -449,7 +526,7 @@ class TestIngestIdempotency:

        def _fake_worker(self, directory, formats, job_name, file_path,
                         filename, user, file_name_map=None,
-                        idempotency_key=None):
+                        idempotency_key=None, source_id=None):
            worker_calls.append(filename)
            return {"status": "ok", "directory": directory}


@@ -363,47 +363,6 @@ class TestMCPServerSave:
        assert response.status_code in (200, 201)

-
-class TestMCPOAuthStatus:
-    def test_returns_pending_when_no_data(self, app):
-        from application.api.user.tools.mcp import MCPOAuthStatus
-
-        fake_redis = MagicMock()
-        fake_redis.get.return_value = None
-
-        with patch(
-            "application.api.user.tools.mcp.get_redis_instance",
-            return_value=fake_redis,
-        ), app.test_request_context("/api/mcp_oauth_status/t1"):
-            response = MCPOAuthStatus().get("t1")
-        assert response.status_code == 200
-        assert response.json["status"] == "pending"
-
-    def test_returns_status_from_redis(self, app):
-        from application.api.user.tools.mcp import MCPOAuthStatus
-
-        fake_redis = MagicMock()
-        fake_redis.get.return_value = '{"status": "completed", "tools": [{"name": "t1", "description": "d"}]}'
-
-        with patch(
-            "application.api.user.tools.mcp.get_redis_instance",
-            return_value=fake_redis,
-        ), app.test_request_context("/api/mcp_oauth_status/t1"):
-            response = MCPOAuthStatus().get("t1")
-        assert response.status_code == 200
-        assert response.json["status"] == "completed"
-        assert response.json["tools"][0]["name"] == "t1"
-
-    def test_redis_error_returns_500(self, app):
-        from application.api.user.tools.mcp import MCPOAuthStatus
-
-        with patch(
-            "application.api.user.tools.mcp.get_redis_instance",
-            side_effect=RuntimeError("boom"),
-        ), app.test_request_context("/api/mcp_oauth_status/t1"):
-            response = MCPOAuthStatus().get("t1")
-        assert response.status_code == 500
-

class TestMCPOAuthCallback:
    def test_error_param_redirects_error(self, app):
        from application.api.user.tools.mcp import MCPOAuthCallback
@@ -513,4 +513,73 @@ class TestChatCompletionsHappyPath:
        assert resp.status_code == 200
        assert resp.mimetype == "text/event-stream"

    def test_stream_handles_id_prefixed_chunks(self, pg_conn):
        """``complete_stream`` emits ``id: <seq>\\n`` before each
        ``data:`` line. The v1 streaming consumer must skip the id
        header and the informational ``message_id`` event, not silently
        drop every chunk.
        """
        app = _build_app()

        def _fake_translate(data, api_key):
            return {"question": "hi"}

        fake_processor = MagicMock()
        fake_processor.decoded_token = {"sub": "u"}
        fake_processor.conversation_id = "conv-1"
        fake_processor.agent_config = {}
        fake_processor.agent_id = None
        fake_processor.model_id = "m"

        def _fake_helper_complete_stream(**kw):
            # Mirror the new wire format: id-prefixed records, plus
            # the informational message_id event the v1 path doesn't
            # have an analog for.
            yield 'id: 0\ndata: {"type": "message_id", "message_id": "m-1"}\n\n'
            yield 'id: 1\ndata: {"type": "id", "id": "conv-1"}\n\n'
            yield 'id: 2\ndata: {"type": "answer", "answer": "hi"}\n\n'

        translated_chunks: list = []

        def _fake_translate_stream_event(event_data, completion_id, model_name):
            translated_chunks.append(event_data)
            return ['data: x\n\n']

        fake_helper = MagicMock()
        fake_helper.check_usage.return_value = None
        fake_helper.complete_stream.side_effect = _fake_helper_complete_stream

        with _patch_v1_db(pg_conn), patch(
            "application.api.v1.routes.translate_request",
            side_effect=_fake_translate,
        ), patch(
            "application.api.v1.routes.StreamProcessor",
            return_value=fake_processor,
        ), patch(
            "application.api.v1.routes._V1AnswerHelper",
            return_value=fake_helper,
        ), patch(
            "application.api.v1.routes.translate_stream_event",
            side_effect=_fake_translate_stream_event,
        ):
            with app.test_client() as c:
                resp = c.post(
                    "/v1/chat/completions",
                    headers={"Authorization": "Bearer x"},
                    json={
                        "messages": [{"role": "user", "content": "Hi"}],
                        "stream": True,
                    },
                )
                # Drain the response so the generator runs to completion.
                list(resp.iter_encoded())

        assert resp.status_code == 200
        # message_id event is skipped (no v1 analog); id + answer are
        # decoded and forwarded to the translator.
        types_translated = [c.get("type") for c in translated_chunks]
        assert "message_id" not in types_translated
        assert "id" in types_translated
        assert "answer" in types_translated
@@ -15,9 +15,9 @@ their own conftest to point at a real, long-running Postgres instance
tests and are marked with ``@pytest.mark.integration``.

No mongomock. The ``mock_mongo_db`` fixture that used to live here was
-removed as part of the Phase 4/5 Mongo→Postgres cutover. Tests that
-still reference it will fail with "fixture not found" until the
-corresponding route handler is migrated to a repository read.
+removed as part of the Mongo→Postgres cutover. Tests that still
+reference it will fail with "fixture not found" until the corresponding
+route handler is migrated to a repository read.
"""

from __future__ import annotations

@@ -1,4 +1,4 @@
-"""Phase 1 regression tests for the YAML-driven ModelRegistry.
+"""Regression tests for the YAML-driven ModelRegistry.

These tests encode the contract that persisted agent / workflow /
conversation references depend on: every model id and core capability

@@ -1,4 +1,4 @@
-"""Phase 3 tests: operator MODELS_CONFIG_DIR.
+"""Tests for the operator MODELS_CONFIG_DIR.

Covers the operator-supplied directory of model YAMLs that's loaded
after the built-in catalog. Operators use this to add new

@@ -1,4 +1,4 @@
-"""Phase 2 tests for the openai_compatible provider.
+"""Tests for the openai_compatible provider.

Covers YAML loading from a temp directory, multiple coexisting catalogs
(Mistral + Together), env-var-based credential resolution, the legacy

@@ -1,6 +1,6 @@
/**
- * Phase 2 helper — shared agent provisioning for specs that need a
- * published agent (with a real api_key) for subsequent /stream or /search
+ * Shared agent provisioning for specs that need a published agent
+ * (with a real api_key) for subsequent /stream or /search
 * calls. A PUBLISHED classic agent requires name, description, chunks,
 * retriever, prompt_id AND a source — otherwise `/api/create_agent`
 * returns 400.

@@ -1,5 +1,4 @@
/**
- * Phase 1 helper — see e2e-plan.md §P1-B.
 * Pre-authenticated Playwright APIRequestContext pointed at the e2e Flask.
 */

@@ -1,5 +1,4 @@
/**
- * Phase 1 helper — see e2e-plan.md §P1-B.
 * JWT signing + per-test BrowserContext seeding for DocsGPT e2e.
 */

@@ -1,5 +1,4 @@
/**
- * Phase 1 helper — see e2e-plan.md §P1-B.
 * Thin pg wrapper + typed row helpers for DB assertions in specs.
 */

@@ -1,5 +1,4 @@
/**
- * Phase 1 helper — see e2e-plan.md §P1-B.
 * Per-test TRUNCATE — preserves `alembic_version`, wipes every other table.
 */

@@ -1,5 +1,5 @@
/**
- * Phase 2 helper — shared streaming primitives.
+ * Shared streaming primitives for SSE specs.
 *
 * The backend exposes two streaming endpoints:
 * - POST /stream (answer_ns path="/" + route "/stream") — SSE body

@@ -1,5 +1,5 @@
/**
- * Phase 3 helper — upload / task_status polling primitives for Tier-B specs.
+ * Upload / task_status polling primitives shared across upload specs.
 *
 * These helpers wrap the two patterns that recur across B1/B3/B4/B5:
 *

@@ -2,7 +2,7 @@
// mock LLM, Postgres reset) is handled by `scripts/e2e/up.sh` and
// `scripts/e2e/down.sh`. Playwright's built-in `webServer` can only manage one
// process and would fight with the four-service native setup this suite needs.
-// See `e2e-plan.md` → "Phase 0 — Foundation" → P0-A for the orchestration
+// See `e2e-plan.md` → "Foundation" → P0-A for the orchestration
// contract.

import { defineConfig, devices } from '@playwright/test';