Compare commits

..

6 Commits

Author SHA1 Message Date
Alex
26e2e7d353 fix: tests 2026-05-04 23:00:21 +01:00
Alex
42c33f4e0d fix: better json validation 2026-05-04 18:09:40 +01:00
Alex
073f9fc003 fix: mini issues 2026-05-04 17:51:09 +01:00
Alex
9b974af210 fix: tests 2026-05-04 17:17:15 +01:00
Alex
fe1edc6b79 feat: more durable frontend 2026-05-04 16:32:28 +01:00
Alex
e550b11f39 feat: durability and idempotency keys 2026-05-03 18:36:02 +01:00
193 changed files with 1969 additions and 17163 deletions

View File

@@ -8,7 +8,7 @@ RUN apt-get update && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && \
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
rm -rf /var/lib/apt/lists/*
rm -rf /var/lib/apt/lists/*
# Verify Python installation and setup symlink
RUN if [ -f /usr/bin/python3.12 ]; then \
@@ -73,7 +73,7 @@ COPY --from=builder /models /app/models
COPY . /app/application
# Change the ownership of the /app directory to the appuser
RUN mkdir -p /app/application/inputs/local
RUN chown -R appuser:appuser /app
@@ -82,11 +82,6 @@ ENV FLASK_APP=app.py \
FLASK_DEBUG=true \
PATH="/venv/bin:$PATH"
ENV MALLOC_ARENA_MAX=2 \
OMP_NUM_THREADS=4 \
MKL_NUM_THREADS=4 \
OPENBLAS_NUM_THREADS=4
# Expose the port the app runs on
EXPOSE 7091

View File

@@ -114,8 +114,6 @@ class BaseAgent(ABC):
self.compressed_summary = compressed_summary
self.current_token_count = 0
self.context_limit_reached = False
self.conversation_id: Optional[str] = None
self.initial_user_id: Optional[str] = None
@log_activity()
def gen(

View File

@@ -2,7 +2,7 @@ import json
import logging
import re
from typing import Any, Dict, Optional
from urllib.parse import quote, urlencode
from urllib.parse import urlencode
import requests
@@ -11,7 +11,7 @@ from application.agents.tools.api_body_serializer import (
RequestBodySerializer,
)
from application.agents.tools.base import Tool
from application.security.safe_url import UnsafeUserUrlError, pinned_request
from application.core.url_validation import validate_url, SSRFError
logger = logging.getLogger(__name__)
@@ -70,16 +70,18 @@ class APITool(Tool):
Returns:
Dict with status_code, data, and message
"""
_VALID_METHODS = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
request_url = url
request_headers = headers.copy() if headers else {}
response = None
if method.upper() not in _VALID_METHODS:
# Validate URL to prevent SSRF attacks
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed: {e}")
return {
"status_code": None,
"message": f"Unsupported HTTP method: {method}",
"message": f"URL validation error: {e}",
"data": None,
}
@@ -89,9 +91,8 @@ class APITool(Tool):
for match in re.finditer(r"\{([^}]+)\}", request_url):
param_name = match.group(1)
if param_name in query_params:
safe_value = quote(str(query_params[param_name]), safe="")
request_url = request_url.replace(
f"{{{param_name}}}", safe_value
f"{{{param_name}}}", str(query_params[param_name])
)
path_params_used.add(param_name)
remaining_params = {
@@ -102,6 +103,19 @@ class APITool(Tool):
separator = "&" if "?" in request_url else "?"
request_url = f"{request_url}{separator}{query_string}"
# Re-validate URL after parameter substitution to prevent SSRF via path params
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed after parameter substitution: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
# Serialize body based on content type
if body and body != {}:
try:
serialized_body, body_headers = RequestBodySerializer.serialize(
@@ -127,13 +141,49 @@ class APITool(Tool):
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
)
response = pinned_request(
method,
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
if method.upper() == "GET":
response = requests.get(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "POST":
response = requests.post(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "PUT":
response = requests.put(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "DELETE":
response = requests.delete(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "PATCH":
response = requests.patch(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "HEAD":
response = requests.head(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "OPTIONS":
response = requests.options(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
else:
return {
"status_code": None,
"message": f"Unsupported HTTP method: {method}",
"data": None,
}
response.raise_for_status()
data = self._parse_response(response)
@@ -143,13 +193,6 @@ class APITool(Tool):
"data": data,
"message": "API call successful.",
}
except UnsafeUserUrlError as e:
logger.error(f"URL validation failed: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
except requests.exceptions.Timeout:
logger.error(f"Request timeout for {request_url}")
return {

View File

@@ -20,11 +20,10 @@ from pydantic import AnyHttpUrl, ValidationError
from redis import Redis
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_task
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.events.keys import stream_key
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
@@ -77,12 +76,6 @@ class MCPTool(Tool):
self.oauth_task_id = config.get("oauth_task_id", None)
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))
# Pulled out of ``config`` (rather than left in ``self.config``)
# because it is a callable supplied by the OAuth worker — not
# something the rest of the tool plumbing should marshal or
# serialize. ``DocsGPTOAuth`` invokes it from ``redirect_handler``
# so the SSE envelope can carry ``authorization_url``.
self.oauth_redirect_publish = config.pop("oauth_redirect_publish", None)
self.available_tools = []
self._cache_key = self._generate_cache_key()
@@ -174,7 +167,6 @@ class MCPTool(Tool):
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
user_id=self.user_id,
redirect_publish=self.oauth_redirect_publish,
)
elif self.auth_type == "bearer":
token = self.auth_credentials.get(
@@ -687,17 +679,12 @@ class DocsGPTOAuth(OAuthClientProvider):
user_id=None,
additional_client_metadata: dict[str, Any] | None = None,
skip_redirect_validation: bool = False,
redirect_publish=None,
):
self.redirect_uri = redirect_uri
self.redis_client = redis_client
self.redis_prefix = redis_prefix
self.task_id = task_id
self.user_id = user_id
# Worker-supplied callback. Invoked from ``redirect_handler``
# once the authorization URL is known so the SSE envelope can
# carry it. ``None`` for any non-worker entrypoint.
self.redirect_publish = redirect_publish
parsed_url = urlparse(mcp_url)
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
@@ -757,19 +744,17 @@ class DocsGPTOAuth(OAuthClientProvider):
self.redis_client.setex(key, 600, auth_url)
logger.info("Stored auth_url in Redis: %s", key)
if self.redirect_publish is not None:
# Best-effort: a publish failure must not abort the OAuth
# handshake — the user can still authorize via the popup
# opened from the legacy polling fallback if the SSE
# envelope is lost.
try:
self.redirect_publish(auth_url)
except Exception:
logger.warning(
"redirect_publish callback raised for task_id=%s",
self.task_id,
exc_info=True,
)
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "requires_redirect",
"message": "Authorization required",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
async def callback_handler(self) -> tuple[str, str | None]:
"""Wait for auth code from Redis using the state value."""
@@ -779,6 +764,17 @@ class DocsGPTOAuth(OAuthClientProvider):
max_wait_time = 300
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "awaiting_callback",
"message": "Waiting for authorization...",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
start_time = time.time()
while time.time() - start_time < max_wait_time:
code_data = self.redis_client.get(code_key)
@@ -793,6 +789,14 @@ class DocsGPTOAuth(OAuthClientProvider):
self.redis_client.delete(
f"{self.redis_prefix}state:{self.extracted_state}"
)
if self.task_id:
status_data = {
"status": "callback_received",
"message": "Completing authentication...",
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
return code, returned_state
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
error_data = self.redis_client.get(error_key)
@@ -1034,73 +1038,8 @@ class MCPOAuthManager:
logger.error("Error handling OAuth callback: %s", e)
return False
def get_oauth_status(self, task_id: str, user_id: str) -> Dict[str, Any]:
"""Return the latest OAuth status for ``task_id`` from the user's SSE journal.
Mirrors the legacy polling contract: ``status`` derived from the
``mcp.oauth.*`` event-type suffix, with payload fields surfaced
(e.g. ``tools``/``tools_count`` on ``completed``).
"""
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
"""Get current status of OAuth flow using provided task_id."""
if not task_id:
return {"status": "not_started", "message": "OAuth flow not started"}
if not user_id:
return {"status": "not_found", "message": "User not provided"}
if self.redis_client is None:
return {"status": "not_found", "message": "Redis unavailable"}
try:
# OAuth flows are short-lived but a concurrent source
# ingest can flood the user channel between the OAuth
# popup completing and the user clicking Save, pushing the
# completion envelope outside the read window. Bound the
# scan by the configured stream cap so we cover the full
# journal — XADD MAXLEN keeps that bounded too.
scan_count = max(settings.EVENTS_STREAM_MAXLEN, 200)
entries = self.redis_client.xrevrange(
stream_key(user_id), count=scan_count
)
except Exception:
logger.exception(
"xrevrange failed for oauth status: user_id=%s task_id=%s",
user_id,
task_id,
)
return {"status": "not_found", "message": "Status unavailable"}
for _entry_id, fields in entries:
if not isinstance(fields, dict):
continue
# decode_responses=False ⇒ bytes keys; the string-key fallback
# covers a future flip of that default without a forced refactor.
event_raw = fields.get(b"event")
if event_raw is None:
event_raw = fields.get("event")
if event_raw is None:
continue
if isinstance(event_raw, bytes):
try:
event_raw = event_raw.decode("utf-8")
except Exception:
continue
try:
envelope = json.loads(event_raw)
except Exception:
continue
if not isinstance(envelope, dict):
continue
event_type = envelope.get("type", "")
if not isinstance(event_type, str) or not event_type.startswith(
"mcp.oauth."
):
continue
scope = envelope.get("scope") or {}
if scope.get("kind") != "mcp_oauth" or scope.get("id") != task_id:
continue
payload = envelope.get("payload") or {}
return {
"status": event_type[len("mcp.oauth."):],
"task_id": task_id,
**payload,
}
return {"status": "not_found", "message": "Status not found"}
return mcp_oauth_status_task(task_id)

View File

@@ -1,5 +1,5 @@
import requests
from application.agents.tools.base import Tool
from application.security.safe_url import UnsafeUserUrlError, pinned_request
class NtfyTool(Tool):
"""
@@ -71,12 +71,7 @@ class NtfyTool(Tool):
if self.token:
headers["Authorization"] = f"Basic {self.token}"
data = message.encode("utf-8")
try:
response = pinned_request(
"POST", url, data=data, headers=headers, timeout=100,
)
except UnsafeUserUrlError as e:
return {"status_code": None, "message": f"URL validation error: {e}"}
response = requests.post(url, headers=headers, data=data, timeout=100)
return {"status_code": response.status_code, "message": "Message sent"}
def get_actions_metadata(self):

View File

@@ -1,6 +1,7 @@
import requests
from markdownify import markdownify
from application.agents.tools.base import Tool
from application.security.safe_url import UnsafeUserUrlError, pinned_request
from application.core.url_validation import validate_url, SSRFError
class ReadWebpageTool(Tool):
"""
@@ -30,24 +31,28 @@ class ReadWebpageTool(Tool):
if not url:
return "Error: URL parameter is missing."
# Validate URL to prevent SSRF attacks
try:
response = pinned_request(
"GET",
url,
headers={'User-Agent': 'DocsGPT-Agent/1.0'},
timeout=10,
)
response.raise_for_status()
url = validate_url(url)
except SSRFError as e:
return f"Error: URL validation failed - {e}"
try:
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
html_content = response.text
#soup = BeautifulSoup(html_content, 'html.parser')
markdown_content = markdownify(html_content, heading_style="ATX", newline_style="BACKSLASH")
return markdown_content
except UnsafeUserUrlError as e:
return f"Error: URL validation failed - {e}"
except Exception as e:
except requests.exceptions.RequestException as e:
return f"Error fetching URL {url}: {e}"
except Exception as e:
return f"Error processing URL {url}: {e}"
def get_actions_metadata(self):
"""

View File

@@ -1,4 +1,4 @@
"""0001 initial schema — consolidated baseline for user-data tables.
"""0001 initial schema — consolidated Phase-1..3 baseline.
Revision ID: 0001_initial
Revises:

View File

@@ -1,40 +0,0 @@
"""0007 message_events — durable journal of chat-stream events.
Snapshot half of the chat-stream snapshot+tail pattern. Composite PK
``(message_id, sequence_no)``, ``created_at`` indexed for retention
sweeps, ``ON DELETE CASCADE`` from ``conversation_messages``.
Revision ID: 0007_message_events
Revises: 0006_idempotency_lease
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0007_message_events"
down_revision: Union[str, None] = "0006_idempotency_lease"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the ``message_events`` table and its ``created_at`` index.

    The table is keyed on ``(message_id, sequence_no)`` and references
    ``conversation_messages(id)`` with ``ON DELETE CASCADE``.
    """
    # Both statements are sent in a single ``op.execute`` call.
    create_journal_sql = """
CREATE TABLE message_events (
message_id UUID NOT NULL REFERENCES conversation_messages(id) ON DELETE CASCADE,
sequence_no INTEGER NOT NULL,
event_type TEXT NOT NULL,
payload JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
PRIMARY KEY (message_id, sequence_no)
);
CREATE INDEX message_events_created_at_idx ON message_events(created_at);
"""
    op.execute(create_journal_sql)
def downgrade() -> None:
    """Drop the ``created_at`` index, then the ``message_events`` table."""
    # Order matters only for readability: both statements use IF EXISTS.
    teardown_statements = (
        "DROP INDEX IF EXISTS message_events_created_at_idx;",
        "DROP TABLE IF EXISTS message_events;",
    )
    for statement in teardown_statements:
        op.execute(statement)

View File

@@ -1,44 +0,0 @@
"""0008 ingest_chunk_progress.status — terminal flag for stalled ingests.
The reconciler's stalled-ingest sweep had no terminal write, so a dead
ingest re-alerted every ~30 min forever. ``status`` lets it escalate a
stalled checkpoint to ``'stalled'`` once and stop re-selecting it;
``init_progress`` resets it to ``'active'`` on reingest.
Revision ID: 0008_ingest_progress_status
Revises: 0007_message_events
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0008_ingest_progress_status"
down_revision: Union[str, None] = "0007_message_events"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add ``ingest_chunk_progress.status`` and a partial index on active rows.

    ``status`` is a NOT NULL text column defaulting to ``'active'`` and
    constrained to ``'active'`` or ``'stalled'``.
    """
    # Constant DEFAULT — metadata-only ADD COLUMN, no table rewrite.
    add_status_column_sql = """
ALTER TABLE ingest_chunk_progress
ADD COLUMN status TEXT NOT NULL DEFAULT 'active'
CHECK (status IN ('active', 'stalled'));
"""
    # Partial index for the reconciler's stalled-ingest sweep.
    create_active_index_sql = (
        "CREATE INDEX ingest_chunk_progress_active_idx "
        "ON ingest_chunk_progress (last_updated) "
        "WHERE status = 'active';"
    )
    op.execute(add_status_column_sql)
    op.execute(create_active_index_sql)
def downgrade() -> None:
    """Remove the active-rows partial index, then the ``status`` column."""
    rollback_statements = (
        "DROP INDEX IF EXISTS ingest_chunk_progress_active_idx;",
        "ALTER TABLE ingest_chunk_progress DROP COLUMN IF EXISTS status;",
    )
    for statement in rollback_statements:
        op.execute(statement)

View File

@@ -23,16 +23,9 @@ from application.core.settings import settings
from application.error import sanitize_api_error
from application.llm.llm_creator import LLMCreator
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.conversations import MessageUpdateOutcome
from application.storage.db.repositories.token_usage import TokenUsageRepository
from application.storage.db.repositories.user_logs import UserLogsRepository
from application.storage.db.session import db_readonly, db_session
from application.events.publisher import publish_user_event
from application.streaming.event_replay import format_sse_event
from application.streaming.message_journal import (
BatchedJournalWriter,
record_event,
)
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
@@ -284,17 +277,6 @@ class BaseAnswerResource:
"update_message_status streaming failed for %s",
reserved_message_id,
)
# Seed last_heartbeat_at so watchdog doesn't fall back to `timestamp`
# (creation time) before the first STREAM_HEARTBEAT_INTERVAL tick.
try:
self.conversation_service.heartbeat_message(
reserved_message_id,
)
except Exception:
logger.exception(
"initial heartbeat seed failed for %s",
reserved_message_id,
)
streaming_marked = True
last_heartbeat_at = time.monotonic()
@@ -321,73 +303,13 @@ class BaseAnswerResource:
try:
agent.tool_executor.message_id = reserved_message_id
except Exception:
logger.debug(
"Could not set tool_executor.message_id; tool-call correlation will be missing for message_id=%s",
reserved_message_id,
)
# Per-stream monotonic SSE event id. Allocated by ``_emit`` and
# threaded through both the wire format (``id: <seq>\\n``) and
# the journal write so a reconnecting client can ``Last-Event-
# ID`` past anything they already saw. Continuations resume
# against the original ``reserved_message_id`` — seed the
# allocator from the journal's high-water mark so we don't
# collide on the duplicate-PK and silently lose every emit
# past the resume point.
sequence_no = -1
if _continuation and reserved_message_id:
try:
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
with db_readonly() as conn:
latest = MessageEventsRepository(conn).latest_sequence_no(
reserved_message_id
)
if latest is not None:
sequence_no = latest
except Exception:
logger.exception(
"Continuation seq seed lookup failed for message_id=%s; "
"falling back to seq=-1 (duplicate-PK collisions will "
"be swallowed)",
reserved_message_id,
)
# One batched journal writer per stream.
journal_writer: Optional[BatchedJournalWriter] = (
BatchedJournalWriter(reserved_message_id)
if reserved_message_id
else None
)
def _emit(payload: dict) -> str:
"""Format-and-journal one SSE event.
With a reserved ``message_id``, buffers into the journal and
emits ``id: <seq>``-tagged SSE frames; otherwise falls back to
legacy ``data: ...\\n\\n`` framing.
"""
nonlocal sequence_no
if not reserved_message_id or journal_writer is None:
return f"data: {json.dumps(payload)}\n\n"
sequence_no += 1
seq = sequence_no
event_type = (
payload.get("type", "data")
if isinstance(payload, dict)
else "data"
)
normalised = payload if isinstance(payload, dict) else {"value": payload}
journal_writer.record(seq, event_type, normalised)
return format_sse_event(normalised, seq)
pass
try:
# Surface the placeholder id before any LLM tokens so a
# mid-handshake disconnect still has a row to tail-poll.
if reserved_message_id:
yield _emit(
early_event = json.dumps(
{
"type": "message_id",
"message_id": reserved_message_id,
@@ -397,6 +319,7 @@ class BaseAnswerResource:
"request_id": request_id,
}
)
yield f"data: {early_event}\n\n"
if _continuation:
gen_iter = agent.gen_continuation(
@@ -422,9 +345,8 @@ class BaseAnswerResource:
schema_info = line.get("schema")
structured_chunks.append(line["answer"])
else:
yield _emit(
{"type": "answer", "answer": line["answer"]}
)
data = json.dumps({"type": "answer", "answer": line["answer"]})
yield f"data: {data}\n\n"
elif "sources" in line:
_mark_streaming_once()
truncated_sources = []
@@ -437,40 +359,43 @@ class BaseAnswerResource:
)
truncated_sources.append(truncated_source)
if truncated_sources:
yield _emit(
data = json.dumps(
{"type": "source", "source": truncated_sources}
)
yield f"data: {data}\n\n"
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
yield _emit({"type": "tool_calls", "tool_calls": tool_calls})
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
yield f"data: {data}\n\n"
elif "thought" in line:
thought += line["thought"]
yield _emit({"type": "thought", "thought": line["thought"]})
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
if line.get("type") == "tool_calls_pending":
# Save continuation state and end the stream
paused = True
yield _emit(line)
data = json.dumps(line)
yield f"data: {data}\n\n"
elif line.get("type") == "error":
yield _emit(
{
"type": "error",
"error": sanitize_api_error(
line.get("error", "An error occurred")
),
}
)
sanitized_error = {
"type": "error",
"error": sanitize_api_error(line.get("error", "An error occurred"))
}
data = json.dumps(sanitized_error)
yield f"data: {data}\n\n"
else:
yield _emit(line)
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
yield _emit(
{
"type": "structured_answer",
"answer": response_full,
"structured": True,
"schema": schema_info,
}
)
structured_data = {
"type": "structured_answer",
"answer": response_full,
"structured": True,
"schema": schema_info,
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
# ---- Paused: save continuation state and end stream early ----
if paused:
@@ -527,7 +452,6 @@ class BaseAnswerResource:
exc_info=True,
)
state_saved = False
if conversation_id:
try:
cont_service = ContinuationService()
@@ -561,65 +485,18 @@ class BaseAnswerResource:
agent.tool_executor, "client_tools", None
),
)
state_saved = True
except Exception as e:
logger.error(
f"Failed to save continuation state: {str(e)}",
exc_info=True,
)
# Notify the user out-of-band so they can navigate
# back to the conversation and decide on the
# pending tool calls. Gated on ``state_saved``: a
# missing pending_tool_state row would 404 the
# resume endpoint, so an unfulfillable notification
# is worse than no notification.
user_id_for_event = (
decoded_token.get("sub") if decoded_token else None
)
if state_saved and user_id_for_event and conversation_id:
pending_calls = continuation.get(
"pending_tool_calls", []
) if continuation else []
# Trim each pending tool call to its identifying
# metadata so a tool with a multi-MB argument
# doesn't blow out the per-event payload size
# cap. The resume page fetches full args from
# ``pending_tool_state`` regardless.
pending_summaries = [
{
k: tc.get(k)
for k in (
"call_id",
"tool_name",
"action_name",
"name",
)
if isinstance(tc, dict) and tc.get(k) is not None
}
for tc in (pending_calls or [])
if isinstance(tc, dict)
]
publish_user_event(
user_id_for_event,
"tool.approval.required",
{
"conversation_id": str(conversation_id),
"message_id": reserved_message_id,
"pending_tool_calls": pending_summaries,
},
scope={
"kind": "conversation",
"id": str(conversation_id),
},
)
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
yield _emit({"type": "id", "id": str(conversation_id)})
yield _emit({"type": "end"})
# Drain the terminal ``end`` so a reconnecting client
# sees it on snapshot — same reason as the main exit.
if journal_writer is not None:
journal_writer.close()
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
return
if isNoneDoc:
@@ -726,7 +603,9 @@ class BaseAnswerResource:
f"completion: {e}",
exc_info=True,
)
yield _emit({"type": "id", "id": str(conversation_id)})
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
getattr(agent, "tool_calls", tool_calls) or tool_calls
@@ -767,33 +646,12 @@ class BaseAnswerResource:
exc_info=True,
)
yield _emit({"type": "end"})
# Drain the journal buffer so the terminal ``end`` event is
# visible to any reconnecting client. Without this the
# client could snapshot up to the last flush boundary and
# then live-tail waiting for an ``end`` that's still
# sitting in memory.
if journal_writer is not None:
journal_writer.close()
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except GeneratorExit:
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
# Drain any buffered events before the terminal one-shot
# ``record_event`` below — keeps the journal's seq order
# contiguous (buffered events ... terminal event). ``close``
# is idempotent; pairing it with ``flush`` matches the
# normal-exit and error branches so any future ``record()``
# past this point would log instead of silently buffering.
if journal_writer is not None:
journal_writer.flush()
journal_writer.close()
# Save partial response
# Whether the DB row was flipped to ``complete`` during this
# abort handler. Drives the choice of terminal journal event
# below: journal ``end`` only when the row actually matches,
# else journal ``error`` so a reconnecting client sees a
# failed terminal state instead of a blank "success".
finalized_complete = False
if should_save_conversation and response_full:
try:
if isNoneDoc:
@@ -828,7 +686,7 @@ class BaseAnswerResource:
)
llm._token_usage_source = "title"
if reserved_message_id is not None:
outcome = self.conversation_service.finalize_message(
self.conversation_service.finalize_message(
reserved_message_id,
response_full,
thought=thought,
@@ -847,15 +705,6 @@ class BaseAnswerResource:
),
},
)
# ``ALREADY_COMPLETE`` means the normal-path
# finalize at line 632 won the race: the DB row
# is already at ``complete`` and the reconnect
# journal should reflect that with ``end``,
# not a spurious ``error``.
finalized_complete = outcome in (
MessageUpdateOutcome.UPDATED,
MessageUpdateOutcome.ALREADY_COMPLETE,
)
else:
self.conversation_service.save_conversation(
conversation_id,
@@ -875,9 +724,6 @@ class BaseAnswerResource:
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
# No journal row to gate, but flag the save as
# successful for symmetry with the WAL path.
finalized_complete = True
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
@@ -901,63 +747,6 @@ class BaseAnswerResource:
logger.error(
f"Error saving partial response: {str(e)}", exc_info=True
)
# Journal a terminal event so reconnecting clients stop tailing;
# ``end`` only when the row is ``complete``, else ``error``.
if reserved_message_id is not None:
try:
sequence_no += 1
if finalized_complete:
# Match the wire shape ``_emit({"type": "end"})``
# uses on the normal path — the replay terminal
# check at ``event_replay._payload_is_terminal``
# reads ``payload.type``, and the frontend parses
# the same key off ``data:``.
record_event(
reserved_message_id,
sequence_no,
"end",
{"type": "end"},
)
else:
# Nothing was persisted under the complete status
# — mark the row failed so the reconciler doesn't
# need to sweep it, and journal an ``error`` so a
# reconnecting client surfaces the same failure
# the UI would show on a live error.
try:
self.conversation_service.finalize_message(
reserved_message_id,
response_full or TERMINATED_RESPONSE_PLACEHOLDER,
thought=thought,
sources=source_log_docs,
tool_calls=tool_calls,
model_id=model_id or self.default_model_id,
metadata=query_metadata if query_metadata else None,
status="failed",
error=ConnectionError(
"client disconnected before response was persisted"
),
)
except Exception as fin_err:
logger.error(
f"Failed to mark aborted message failed: {fin_err}",
exc_info=True,
)
record_event(
reserved_message_id,
sequence_no,
"error",
{
"type": "error",
"error": "Stream aborted before any response was produced.",
"code": "client_disconnect",
},
)
except Exception as journal_err:
logger.error(
f"Failed to journal terminal event on abort: {journal_err}",
exc_info=True,
)
raise
except Exception as e:
logger.error(f"Error in stream: {str(e)}", exc_info=True)
@@ -979,16 +768,13 @@ class BaseAnswerResource:
f"Failed to finalize errored message: {fin_err}",
exc_info=True,
)
yield _emit(
data = json.dumps(
{
"type": "error",
"error": "Please try again later. We apologize for any inconvenience.",
}
)
# Drain the terminal ``error`` event we just yielded so a
# reconnecting client sees it on snapshot.
if journal_writer is not None:
journal_writer.close()
yield f"data: {data}\n\n"
return
def process_response_stream(self, stream) -> Dict[str, Any]:
@@ -1010,22 +796,8 @@ class BaseAnswerResource:
for line in stream:
try:
# Each chunk may carry an ``id: <seq>`` header before
# the ``data:`` line. Pull just the ``data:`` body so
# the JSON decode doesn't choke on the SSE framing.
event_data = ""
for raw in line.split("\n"):
if raw.startswith("data:"):
event_data = raw[len("data:") :].lstrip()
break
if not event_data:
continue
event_data = line.replace("data: ", "").strip()
event = json.loads(event_data)
# The ``message_id`` event is informational for the
# streaming consumer and has no synchronous-API field;
# skip it so the type-switch below doesn't KeyError.
if event.get("type") == "message_id":
continue
if event["type"] == "id":
conversation_id = event["id"]

View File

@@ -1,135 +0,0 @@
"""GET /api/messages/<message_id>/events — chat-stream reconnect endpoint.
Authenticates the caller, verifies ``message_id`` belongs to the user,
then hands off to ``build_message_event_stream`` for snapshot+tail.
"""
from __future__ import annotations
import logging
import re
from typing import Iterator, Optional
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
from sqlalchemy import text
from application.core.settings import settings
from application.storage.db.session import db_readonly
from application.streaming.event_replay import (
DEFAULT_KEEPALIVE_SECONDS,
DEFAULT_POLL_TIMEOUT_SECONDS,
build_message_event_stream,
)
logger = logging.getLogger(__name__)
messages_bp = Blueprint("message_stream", __name__)
# A message_id is the canonical UUID hex format. Reject anything else
# before the SQL layer so a malformed cookie can't surface as a 500.
_MESSAGE_ID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
)
# ``sequence_no`` is a non-negative decimal integer. Anything else is
# corrupt client state — fall through to a fresh-replay cursor and let
# the snapshot reader catch the client up.
_SEQUENCE_NO_RE = re.compile(r"^\d+$")
def _normalise_last_event_id(raw: Optional[str]) -> Optional[int]:
    """Parse a client-supplied replay cursor into an integer sequence number.

    Args:
        raw: The cursor string as received from the client, or ``None``.

    Returns:
        The cursor as an ``int``, or ``None`` when ``raw`` is ``None``, is
        blank after stripping, does not match ``_SEQUENCE_NO_RE``, or
        cannot be converted by ``int()``.
    """
    if raw is None:
        return None
    raw = raw.strip()
    if not raw or not _SEQUENCE_NO_RE.match(raw):
        return None
    try:
        return int(raw)
    except ValueError:
        # CPython limits str -> int conversion length (4300 digits by
        # default), so an absurdly long all-digit cursor passes the regex
        # but still raises. Treat it like any other corrupt cursor instead
        # of letting the request fail.
        return None
def _user_owns_message(message_id: str, user_id: str) -> bool:
    """Report whether a ``conversation_messages`` row pairs this id and user.

    Any exception raised while running the lookup is logged and reported
    as ``False``.
    """
    params = {"id": message_id, "u": user_id}
    try:
        with db_readonly() as conn:
            query = text(
                """
SELECT 1 FROM conversation_messages
WHERE id = CAST(:id AS uuid)
AND user_id = :u
LIMIT 1
"""
            )
            match = conn.execute(query, params).first()
        return match is not None
    except Exception:
        logger.exception(
            "Ownership lookup failed for message_id=%s user_id=%s",
            message_id,
            user_id,
        )
        return False
@messages_bp.route("/api/messages/<message_id>/events", methods=["GET"])
def stream_message_events(message_id: str) -> Response:
    """Serve the reconnect SSE stream for a single message.

    Responds 401 without a ``sub`` claim on the decoded token, 400 when
    ``message_id`` does not match the UUID pattern, and 404 when the
    ownership lookup fails. Otherwise streams whatever
    ``build_message_event_stream`` yields, starting after the cursor
    taken from ``Last-Event-ID`` or the ``last_event_id`` query argument.
    """

    def _reject(message: str, status: int) -> Response:
        # Every refusal shares the same JSON envelope.
        return make_response(
            jsonify({"success": False, "message": message}), status
        )

    token = getattr(request, "decoded_token", None)
    user_id = token.get("sub") if isinstance(token, dict) else None
    if not user_id:
        return _reject("Authentication required", 401)
    if _MESSAGE_ID_RE.match(message_id) is None:
        return _reject("Invalid message id", 400)
    if not _user_owns_message(message_id, user_id):
        # One answer for "bogus id", "someone else's id" and "gone", so
        # the response never reveals whether the row exists.
        return _reject("Not found", 404)

    header_cursor = request.headers.get("Last-Event-ID")
    if header_cursor:
        raw_cursor = header_cursor
    else:
        raw_cursor = request.args.get("last_event_id")
    last_event_id = _normalise_last_event_id(raw_cursor)
    keepalive_seconds = float(
        getattr(settings, "SSE_KEEPALIVE_SECONDS", DEFAULT_KEEPALIVE_SECONDS)
    )

    @stream_with_context
    def generate() -> Iterator[str]:
        try:
            yield from build_message_event_stream(
                message_id,
                last_event_id=last_event_id,
                keepalive_seconds=keepalive_seconds,
                poll_timeout_seconds=DEFAULT_POLL_TIMEOUT_SECONDS,
            )
        except GeneratorExit:
            # Client went away; nothing to report.
            return
        except Exception:
            logger.exception(
                "Reconnect stream crashed for message_id=%s user_id=%s",
                message_id,
                user_id,
            )

    response = Response(generate(), mimetype="text/event-stream")
    for header_name, header_value in (
        ("Cache-Control", "no-store"),
        ("X-Accel-Buffering", "no"),
        ("Connection", "keep-alive"),
    ):
        response.headers[header_name] = header_value

    cursor_label = last_event_id if last_event_id is not None else "-"
    logger.info(
        "message.event.connect message_id=%s user_id=%s last_event_id=%s",
        message_id,
        user_id,
        cursor_label,
    )
    return response

View File

@@ -15,10 +15,7 @@ from sqlalchemy import text as sql_text
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.conversations import (
ConversationsRepository,
MessageUpdateOutcome,
)
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.session import db_readonly, db_session
@@ -308,17 +305,10 @@ class ConversationService:
status: str = "complete",
error: Optional[BaseException] = None,
title_inputs: Optional[Dict[str, Any]] = None,
) -> MessageUpdateOutcome:
"""Commit the response and tool_call confirms in one transaction.
The outcome propagates directly from ``update_message_by_id`` so
callers (notably the SSE abort handler) can tell a fresh
finalize from "the row was already terminal" — the latter must
still be treated as success when the prior state was
``complete``.
"""
) -> bool:
"""Commit the response and tool_call confirms in one transaction."""
if not message_id:
return MessageUpdateOutcome.INVALID
return False
sources = sources or []
for source in sources:
if "text" in source and isinstance(source["text"], str):
@@ -346,16 +336,16 @@ class ConversationService:
# retracting a row the reconciler already escalated.
with db_session() as conn:
repo = ConversationsRepository(conn)
outcome = repo.update_message_by_id(
ok = repo.update_message_by_id(
message_id, update_fields,
only_if_non_terminal=True,
)
if outcome is not MessageUpdateOutcome.UPDATED:
if not ok:
logger.warning(
f"finalize_message: no row updated for message_id={message_id} "
f"(outcome={outcome.value} — possibly already terminal)"
f"(possibly already terminal — reconciler may have escalated)"
)
return outcome
return False
repo.confirm_executed_tool_calls(message_id)
# Outside the txn — title-gen is a multi-second LLM round trip.
@@ -368,7 +358,7 @@ class ConversationService:
f"finalize_message title generation failed: {e}",
exc_info=True,
)
return MessageUpdateOutcome.UPDATED
return True
def _maybe_generate_title(
self,

View File

@@ -1,504 +0,0 @@
"""GET /api/events — user-scoped Server-Sent Events endpoint.
Subscribe-then-snapshot pattern: subscribe to ``user:{user_id}``
pub/sub, snapshot the Redis Streams backlog past ``Last-Event-ID``
inside the SUBSCRIBE-ack callback, flush snapshot, then tail live
events (dedup'd by stream id). See ``docs/runbooks/sse-notifications.md``.
"""
from __future__ import annotations
import json
import logging
import re
import time
from typing import Iterator, Optional
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
from application.cache import get_redis_instance
from application.core.settings import settings
from application.events.keys import (
connection_counter_key,
replay_budget_key,
stream_id_compare,
stream_key,
topic_name,
)
from application.streaming.broadcast_channel import Topic
logger = logging.getLogger(__name__)
events = Blueprint("event_stream", __name__)
SUBSCRIBE_POLL_INTERVAL_SECONDS = 1.0
# WHATWG SSE treats CRLF, CR, and LF equivalently as line terminators.
_SSE_LINE_SPLIT = re.compile(r"\r\n|\r|\n")
# Redis Streams ids are ``ms`` or ``ms-seq`` where both halves are decimal.
# Anything else is a corrupted client cookie / IndexedDB residue and must
# not be passed to XRANGE — Redis would reject it and our truncation gate
# would silently fail.
_STREAM_ID_RE = re.compile(r"^\d+(-\d+)?$")
# Only emitted at most once per process so a misconfigured deployment
# doesn't drown the logs.
_local_user_warned = False
def _format_sse(data: str, *, event_id: Optional[str] = None) -> str:
    """Encode a payload as one SSE message terminated by a blank line.

    ``data`` is split on any line-terminator variant (CRLF, CR, LF) and
    each piece is emitted as its own ``data:`` field, so a stray CR in
    upstream content can't smuggle a premature line boundary into the
    wire format.

    Args:
        data: The message payload.
        event_id: Optional value for the ``id:`` field. It is omitted
            when empty or when it contains a line terminator.

    Returns:
        The encoded message, ending in ``"\\n\\n"``.
    """
    lines: list[str] = []
    # An id carrying CR/LF would end the ``id:`` field early and let the
    # remainder be parsed as extra fields (or a blank line that ends the
    # message). Drop such an id rather than emit a corrupt frame; the
    # payload is still delivered.
    if event_id and not _SSE_LINE_SPLIT.search(event_id):
        lines.append(f"id: {event_id}")
    lines.extend(f"data: {line}" for line in _SSE_LINE_SPLIT.split(data))
    return "\n".join(lines) + "\n\n"
def _decode(value) -> Optional[str]:
if value is None:
return None
if isinstance(value, (bytes, bytearray)):
try:
return value.decode("utf-8")
except Exception:
return None
return str(value)
def _oldest_retained_id(redis_client, user_id: str) -> Optional[str]:
    """Look up the id of the first entry still held in the user's stream.

    Returns ``None`` when ``XINFO STREAM`` raises, when its reply is not
    a dict, when there is no first entry, or when the entry's id cannot
    be read. Callers use the id to tell whether a ``Last-Event-ID`` has
    fallen off the back of the MAXLEN'd window.
    """
    try:
        info = redis_client.xinfo_stream(stream_key(user_id))
    except Exception:
        return None

    if isinstance(info, dict):
        # Probe the str key first, then the bytes key, so the lookup
        # works whichever way the client decodes responses.
        head = info.get("first-entry") or info.get(b"first-entry")
    else:
        head = None
    if not head:
        return None

    # The first entry is indexed as ``[id, fields]``; only the id is used.
    try:
        return _decode(head[0])
    except Exception:
        return None
def _allow_replay(
    redis_client, user_id: str, last_event_id: Optional[str]
) -> bool:
    """Charge one snapshot replay against the user's per-window budget.

    Args:
        redis_client: Redis client, or ``None`` when Redis is unavailable.
        user_id: Owner of the stream being replayed.
        last_event_id: The client's validated cursor, or ``None``.

    Returns:
        ``True`` when the connect may proceed, ``False`` when the budget
        for the current window is spent. Fails open (``True``) when the
        budget is disabled (``<= 0``), when there is no Redis client, or
        when updating the counter raises. A connect without a cursor
        against an empty stream is not charged, so dev double-mounts
        don't trip 429.
    """
    budget = int(settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW)
    if budget <= 0 or redis_client is None:
        return True
    # Cheap pre-check: only INCR when we might actually replay. The
    # check is conservative: if ``last_event_id`` is set we always INCR,
    # even if the cursor has already overtaken the latest entry — that
    # case is rare and short-lived, and probing further would mean a
    # redundant XRANGE.
    if last_event_id is None:
        try:
            if int(redis_client.xlen(stream_key(user_id))) == 0:
                return True
        except Exception:
            # XLEN probe failed; fall through to the INCR path so a
            # transient Redis hiccup can't bypass the budget.
            logger.debug(
                "XLEN probe failed for replay budget check user=%s; "
                "proceeding to INCR",
                user_id,
            )
    window = max(1, int(settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS))
    key = replay_budget_key(user_id)
    try:
        used = int(redis_client.incr(key))
        # Seed the TTL only when the key does not have one. Re-arming it
        # on every call would restart the window on each connect —
        # denied ones included — so a client that reconnects more often
        # than ``window`` would never see its counter reset. Checking
        # TTL (negative = no expiry) instead of gating purely on
        # ``used == 1`` still repairs a counter whose seeding EXPIRE
        # failed on an earlier call.
        if used == 1 or int(redis_client.ttl(key)) < 0:
            redis_client.expire(key, window)
    except Exception:
        logger.debug(
            "replay budget probe failed for user=%s; failing open",
            user_id,
        )
        return True
    return used <= budget
def _normalize_last_event_id(raw: Optional[str]) -> Optional[str]:
    """Accept a ``Last-Event-ID`` only if it looks like a Redis Streams id.

    The value is stripped of surrounding whitespace and returned when it
    matches ``_STREAM_ID_RE``. Anything else — absent, blank, or
    malformed — becomes ``None``, which callers treat as "client has
    nothing" and answer with a replay from the start of the retained
    window.
    """
    if raw is None:
        return None
    candidate = raw.strip()
    if candidate and _STREAM_ID_RE.match(candidate):
        return candidate
    return None
def _replay_backlog(
    redis_client, user_id: str, last_event_id: Optional[str], max_count: int
) -> Iterator[tuple[str, str]]:
    """Yield ``(entry_id, sse_line)`` for stream entries after the cursor.

    At most ``max_count`` entries are read in one ``XRANGE``. If that
    call raises, a warning is logged and nothing is yielded. Entries
    whose id or ``event`` field is missing or undecodable are skipped.
    When the event parses as a JSON object the entry id is written into
    its ``id`` key before encoding; otherwise the payload is passed
    through unchanged.
    """
    # A leading '(' makes the lower bound exclusive, so the entry the
    # client already holds is not sent again.
    lower_bound = "-" if not last_event_id else f"({last_event_id}"
    try:
        entries = redis_client.xrange(
            stream_key(user_id), min=lower_bound, max="+", count=max_count
        )
    except Exception as exc:
        logger.warning(
            "xrange replay failed for user=%s last_id=%s err=%s",
            user_id,
            last_event_id or "-",
            exc,
        )
        return

    for entry_id, fields in entries:
        entry_id_str = _decode(entry_id)
        if not entry_id_str:
            continue

        # Field names arrive as bytes or str depending on how the
        # client decodes responses; try the bytes key first.
        raw_event = None
        if isinstance(fields, dict):
            for field_name in (b"event", "event"):
                raw_event = fields.get(field_name)
                if raw_event is not None:
                    break

        event_str = _decode(raw_event)
        if not event_str:
            continue

        try:
            envelope = json.loads(event_str)
            if isinstance(envelope, dict):
                envelope["id"] = entry_id_str
                event_str = json.dumps(envelope)
        except Exception:
            logger.debug(
                "Replay envelope parse failed for entry %s; passing through raw",
                entry_id_str,
            )
        yield entry_id_str, _format_sse(event_str, event_id=entry_id_str)
def _truncation_notice_line(oldest_id: str) -> str:
    """Build the ``backlog.truncated`` SSE message carrying ``oldest_id``.

    The message has no ``id:`` field. The frontend can react to it with
    a full-state refetch.
    """
    notice = {
        "type": "backlog.truncated",
        "payload": {"oldest_retained_id": oldest_id},
    }
    return _format_sse(json.dumps(notice))
@events.route("/api/events", methods=["GET"])
def stream_events() -> Response:
    """Open the user-scoped SSE stream.

    Responds 401 when the decoded token carries no ``sub``. Responds 429
    when the per-user connection counter exceeds
    ``SSE_MAX_CONCURRENT_PER_USER`` or when ``_allow_replay`` denies the
    connect. Otherwise returns a ``text/event-stream`` response whose
    generator replays the backlog past the client's cursor and then
    tails live events, emitting keepalive comments while idle.
    """
    decoded = getattr(request, "decoded_token", None)
    user_id = decoded.get("sub") if isinstance(decoded, dict) else None
    if not user_id:
        return make_response(
            jsonify({"success": False, "message": "Authentication required"}),
            401,
        )
    # In dev deployments without AUTH_TYPE configured, every request
    # resolves to user_id="local" and shares one stream. Surface this so
    # an accidentally-multi-user dev box doesn't silently cross-stream.
    global _local_user_warned
    if user_id == "local" and not _local_user_warned:
        logger.warning(
            "SSE serving user_id='local' (AUTH_TYPE not set). "
            "All clients on this deployment will share one event stream."
        )
        _local_user_warned = True
    # The header wins; the query argument is consulted only when the
    # header is absent or empty.
    raw_last_event_id = request.headers.get("Last-Event-ID") or request.args.get(
        "last_event_id"
    )
    last_event_id = _normalize_last_event_id(raw_last_event_id)
    # True when the client sent a cursor that failed validation.
    last_event_id_invalid = raw_last_event_id is not None and last_event_id is None
    keepalive_seconds = float(settings.SSE_KEEPALIVE_SECONDS)
    push_enabled = settings.ENABLE_SSE_PUSH
    cap = int(settings.SSE_MAX_CONCURRENT_PER_USER)
    redis_client = get_redis_instance()
    counter_key = connection_counter_key(user_id)
    # ``counted`` records whether this request incremented the connection
    # counter, i.e. whether a matching DECR is owed on every exit path.
    counted = False
    if push_enabled and redis_client is not None and cap > 0:
        try:
            current = int(redis_client.incr(counter_key))
            counted = True
        except Exception:
            # Counter unavailable: ``current = 0`` lets the connect through.
            current = 0
            logger.debug(
                "SSE connection counter INCR failed for user=%s", user_id
            )
        if counted:
            # 1h safety TTL — orphaned counts from hard crashes self-heal.
            # EXPIRE failure must NOT clobber ``current`` and bypass the cap.
            try:
                redis_client.expire(counter_key, 3600)
            except Exception:
                logger.debug(
                    "SSE connection counter EXPIRE failed for user=%s", user_id
                )
        if current > cap:
            # Over the cap: give the slot back before refusing.
            try:
                redis_client.decr(counter_key)
            except Exception:
                logger.debug(
                    "SSE connection counter DECR failed for user=%s",
                    user_id,
                )
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": "Too many concurrent SSE connections",
                    }
                ),
                429,
            )
    # Replay budget is checked here, before the generator opens the
    # stream, so a denial can surface as HTTP 429 instead of a silent
    # snapshot skip. The earlier in-generator skip lost events between
    # the client's cursor and the first live-tailed entry: the live
    # tail still carried ``id:`` headers, the frontend advanced
    # ``lastEventId`` to one of those ids, and the events in between
    # were never reachable on the next reconnect. 429 keeps the
    # cursor pinned and lets the frontend back off until the window
    # slides (eventStreamClient.ts treats 429 as escalated backoff).
    if push_enabled and redis_client is not None and not _allow_replay(
        redis_client, user_id, last_event_id
    ):
        if counted:
            try:
                redis_client.decr(counter_key)
            except Exception:
                logger.debug(
                    "SSE connection counter DECR failed for user=%s",
                    user_id,
                )
        return make_response(
            jsonify(
                {
                    "success": False,
                    "message": "Replay budget exhausted",
                }
            ),
            429,
        )

    @stream_with_context
    def generate() -> Iterator[str]:
        """Yield the SSE frames for this connection.

        Logs the disconnect and releases the connection-counter slot in
        ``finally``.
        """
        connect_ts = time.monotonic()
        replayed_count = 0
        try:
            # First frame primes intermediaries (Cloudflare, nginx) so they
            # don't sit on a buffer waiting for body bytes.
            yield ": connected\n\n"
            if not push_enabled:
                yield ": push_disabled\n\n"
                return
            replay_lines: list[str] = []
            max_replayed_id: Optional[str] = None
            replay_done = False
            # If the client sent a malformed Last-Event-ID, surface the
            # truncation notice synchronously *before* the subscribe
            # loop. Buffering it into ``replay_lines`` would lose it
            # when ``Topic.subscribe`` returns immediately (Redis down)
            # — the loop body never runs, and the post-loop flush below
            # only fires once the subscribe callback has set
            # ``replay_done``.
            if last_event_id_invalid:
                yield _truncation_notice_line("")
                replayed_count += 1

            def _on_subscribe_callback() -> None:
                # Runs synchronously inside Topic.subscribe after the
                # SUBSCRIBE is acked. By doing XRANGE here, any publisher
                # firing between SUBSCRIBE-send and SUBSCRIBE-ack has its
                # XADD captured by XRANGE *and* its PUBLISH buffered at
                # the connection layer until we read it — closing the
                # replay/subscribe race the design doc warns about.
                #
                # Truncation contract: ``backlog.truncated`` is emitted
                # ONLY when the client's ``Last-Event-ID`` has slid off
                # the MAXLEN'd window — that's the case where the
                # journal is genuinely gone past the cursor and the
                # frontend should clear its slice cursor and refetch
                # state. Cap-hit skips the snapshot silently: the
                # cursor advances via the per-entry ``id:`` headers
                # and the frontend's slice keeps the latest id so the
                # next reconnect resumes from there. Budget-exhausted
                # never reaches this callback — the route 429s before
                # opening the stream, keeping the cursor pinned.
                # Conflating these with stale-cursor truncation would
                # tell the client to clear its cursor and re-receive
                # the same oldest-N entries on every reconnect —
                # locking the user out of entries past N.
                nonlocal max_replayed_id, replay_done
                try:
                    if redis_client is None:
                        return
                    oldest = _oldest_retained_id(redis_client, user_id)
                    if (
                        last_event_id
                        and oldest
                        and stream_id_compare(last_event_id, oldest) < 0
                    ):
                        # The Last-Event-ID has slid off the MAXLEN window.
                        # Tell the client so it can fetch full state.
                        replay_lines.append(_truncation_notice_line(oldest))
                    replay_cap = int(settings.EVENTS_REPLAY_MAX_PER_REQUEST)
                    for entry_id, sse_line in _replay_backlog(
                        redis_client, user_id, last_event_id, replay_cap
                    ):
                        replay_lines.append(sse_line)
                        # Entries are taken in XRANGE order, so the last
                        # id assigned here is the dedupe threshold used
                        # by the live tail.
                        max_replayed_id = entry_id
                finally:
                    # Always flip the flag — even on partial-replay failure
                    # the outer loop must reach the flush step so we don't
                    # silently strand whatever entries did land.
                    replay_done = True

            topic = Topic(topic_name(user_id))
            last_keepalive = time.monotonic()
            for payload in topic.subscribe(
                on_subscribe=_on_subscribe_callback,
                poll_timeout=SUBSCRIBE_POLL_INTERVAL_SECONDS,
            ):
                # Flush snapshot on the first iteration after the SUBSCRIBE
                # callback ran. This runs at most once per connection.
                if replay_done and replay_lines:
                    for line in replay_lines:
                        yield line
                        replayed_count += 1
                    replay_lines.clear()
                now = time.monotonic()
                # ``None`` is an idle poll tick: nothing to deliver, but
                # a keepalive may be due.
                if payload is None:
                    if now - last_keepalive >= keepalive_seconds:
                        yield ": keepalive\n\n"
                        last_keepalive = now
                    continue
                event_str = _decode(payload) or ""
                event_id: Optional[str] = None
                try:
                    envelope = json.loads(event_str)
                    if isinstance(envelope, dict):
                        candidate = envelope.get("id")
                        # Only trust ids that look like real Redis Streams
                        # ids (``ms`` or ``ms-seq``). A malformed or
                        # adversarial publisher could otherwise pin
                        # dedupe forever — a lex-greater bogus id would
                        # make every legitimate later id compare ``<=``
                        # and get dropped silently.
                        # NOTE(review): ``match`` with a ``$``-anchored
                        # pattern also accepts one trailing newline;
                        # ``fullmatch`` would be stricter — confirm
                        # whether that matters for publishers here.
                        if isinstance(candidate, str) and _STREAM_ID_RE.match(
                            candidate
                        ):
                            event_id = candidate
                except Exception:
                    # Non-JSON payloads are forwarded below without an id.
                    pass
                # Dedupe: if this id was already covered by replay, drop it.
                if (
                    event_id is not None
                    and max_replayed_id is not None
                    and stream_id_compare(event_id, max_replayed_id) <= 0
                ):
                    continue
                yield _format_sse(event_str, event_id=event_id)
                # A delivered event also counts as traffic for keepalive
                # purposes.
                last_keepalive = now
            # Topic.subscribe exited before the first yield (transient
            # Redis hiccup between SUBSCRIBE-ack and the first poll, or
            # an immediate Redis-down return). The callback may already
            # have populated the snapshot — flush it so the client gets
            # the backlog instead of a silent drop. Safe no-op when the
            # in-loop flush ran (it clear()'d the buffer) and when the
            # callback never fired (replay_done stays False).
            if replay_done and replay_lines:
                for line in replay_lines:
                    yield line
                    replayed_count += 1
                replay_lines.clear()
        except GeneratorExit:
            # Client disconnected; ``finally`` still runs.
            return
        except Exception:
            logger.exception(
                "SSE event-stream generator crashed for user=%s", user_id
            )
        finally:
            duration_s = time.monotonic() - connect_ts
            logger.info(
                "event.disconnect user=%s duration_s=%.1f replayed=%d",
                user_id,
                duration_s,
                replayed_count,
            )
            # Release the slot taken by the INCR at connect time.
            if counted and redis_client is not None:
                try:
                    redis_client.decr(counter_key)
                except Exception:
                    logger.debug(
                        "SSE connection counter DECR failed for user=%s on disconnect",
                        user_id,
                    )

    response = Response(generate(), mimetype="text/event-stream")
    response.headers["Cache-Control"] = "no-store"
    response.headers["X-Accel-Buffering"] = "no"
    response.headers["Connection"] = "keep-alive"
    logger.info(
        "event.connect user=%s last_event_id=%s%s",
        user_id,
        last_event_id or "-",
        " (rejected_invalid)" if last_event_id_invalid else "",
    )
    return response

View File

@@ -214,10 +214,6 @@ class StoreAttachment(Resource):
{
"success": True,
"task_id": tasks[0]["task_id"],
# Surface the attachment_id so the frontend
# can correlate ``attachment.*`` SSE events
# to this row and skip the polling fallback.
"attachment_id": tasks[0]["attachment_id"],
"message": "File uploaded successfully. Processing started.",
}
),

View File

@@ -7,13 +7,9 @@ from flask_restx import fields, Namespace, Resource
from sqlalchemy import text as sql_text
from application.api import api
from application.api.answer.services.conversation_service import (
TERMINATED_RESPONSE_PLACEHOLDER,
)
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.message_events import MessageEventsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
@@ -350,25 +346,6 @@ class GetMessageTail(Resource):
if row is None:
return make_response(jsonify({"status": "not found"}), 404)
msg = row_to_dict(row)
# Mid-stream the row's response is the placeholder; rebuild
# the live partial from the journal so /tail mirrors SSE.
status = msg.get("status")
response = msg.get("response")
thought = msg.get("thought")
sources = msg.get("sources") or []
tool_calls = msg.get("tool_calls") or []
if status in ("pending", "streaming") and (
response == TERMINATED_RESPONSE_PLACEHOLDER
):
partial = MessageEventsRepository(conn).reconstruct_partial(
message_id
)
response = partial["response"]
thought = partial["thought"] or thought
if partial["sources"]:
sources = partial["sources"]
if partial["tool_calls"]:
tool_calls = partial["tool_calls"]
except Exception as err:
current_app.logger.error(
f"Error tailing message {message_id}: {err}", exc_info=True
@@ -379,11 +356,11 @@ class GetMessageTail(Resource):
jsonify(
{
"message_id": str(msg["id"]),
"status": status,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"status": msg.get("status"),
"response": msg.get("response"),
"thought": msg.get("thought"),
"sources": msg.get("sources") or [],
"tool_calls": msg.get("tool_calls") or [],
"request_id": msg.get("request_id"),
"last_heartbeat_at": metadata.get("last_heartbeat_at"),
"error": metadata.get("error"),

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import functools
import inspect
import logging
import threading
import uuid
@@ -27,20 +26,13 @@ LEASE_HEARTBEAT_INTERVAL = 30
LEASE_RETRY_MAX = 10
def with_idempotency(
task_name: str,
*,
on_poison: Optional[Callable[[str, dict], None]] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def with_idempotency(task_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Short-circuit on completed key; gate concurrent runs via a lease.
The guard key is the caller's ``idempotency_key``, or one synthesized
from ``source_id`` so a keyless dispatch is still poison-guarded.
Entry short-circuits:
- completed row → return cached result
- live lease held → retry(countdown=LEASE_TTL_SECONDS)
- attempt_count > MAX_TASK_ATTEMPTS → poison alert; ``on_poison`` fires
- attempt_count > MAX_TASK_ATTEMPTS → poison-loop alert
Success writes ``completed``; exceptions leave ``pending`` for
autoretry until the poison-loop guard trips.
"""
@@ -48,14 +40,7 @@ def with_idempotency(
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(fn)
def wrapper(self, *args: Any, idempotency_key: Any = None, **kwargs: Any) -> Any:
explicit_key = (
idempotency_key
if isinstance(idempotency_key, str) and idempotency_key
else None
)
# A keyless dispatch still gets the guard via a synthesized key;
# None means no anchor exists — run unguarded, as before.
key = explicit_key or _synthesize_guard_key(task_name, kwargs)
key = idempotency_key if isinstance(idempotency_key, str) and idempotency_key else None
if key is None:
return fn(self, *args, idempotency_key=idempotency_key, **kwargs)
@@ -103,9 +88,6 @@ def with_idempotency(
"attempts": attempt,
}
_finalize(key, poisoned, status="failed")
_run_poison_hook(
on_poison, task_name, fn, self, args, kwargs, idempotency_key,
)
return poisoned
heartbeat_thread, heartbeat_stop = _start_lease_heartbeat(
@@ -127,45 +109,6 @@ def with_idempotency(
return decorator
def _synthesize_guard_key(task_name: str, kwargs: dict) -> Optional[str]:
"""Derive a deterministic guard key from ``source_id`` for a keyless dispatch.
``source_id`` is stable across broker redeliveries and unique per
upload, so the poison-loop counter survives an OOM SIGKILL. Returns
``None`` when absent — the dispatch then runs unguarded as before.
"""
source_id = kwargs.get("source_id")
if source_id:
return f"auto:{task_name}:{source_id}"
return None
def _run_poison_hook(
on_poison: Optional[Callable[[str, dict], None]],
task_name: str,
fn: Callable[..., Any],
task_self: Any,
args: tuple,
kwargs: dict,
idempotency_key: Any,
) -> None:
"""Invoke a task's poison-path hook with named call args; swallow failures.
A hook failure must never change the poison-guard outcome.
"""
if on_poison is None:
return
try:
bound = inspect.signature(fn).bind_partial(
task_self, *args, idempotency_key=idempotency_key, **kwargs,
)
on_poison(task_name, dict(bound.arguments))
except Exception:
logger.exception(
"idempotency: poison hook failed for task=%s", task_name,
)
def _lookup_completed(key: str) -> Any:
"""Return cached ``result_json`` if a completed row exists for ``key``, else None."""
with db_readonly() as conn:

View File

@@ -114,11 +114,11 @@ def run_reconciliation() -> Dict[str, Any]:
},
)
# Q4: ingest checkpoints whose heartbeat has gone silent. Each is
# escalated to terminal ``status='stalled'`` and alerted once — no
# worker kill, no rollback of the partial embed. The 'stalled' flag
# ends the re-alert loop and drives the "indexing failed" badge the
# sources list derives from this row.
# Q4: ingest checkpoints whose heartbeat has gone silent. The
# reconciler only escalates (alerts) — it doesn't kill the worker
# or roll back the partial embed. The next dispatch resumes from
# ``last_index`` thanks to the per-chunk checkpoint, so this is an
# observability sweep, not a recovery action.
with engine.begin() as conn:
repo = ReconciliationRepository(conn)
for row in repo.find_and_lock_stalled_ingests():
@@ -134,7 +134,8 @@ def run_reconciliation() -> Dict[str, Any]:
"last_updated": str(row.get("last_updated")),
},
)
repo.mark_ingest_stalled(str(row["source_id"]))
# Bump the heartbeat so we don't re-alert every tick.
repo.touch_ingest_progress(str(row["source_id"]))
# Q5: idempotency rows whose lease expired with attempts exhausted.
# The wrapper's poison-loop guard normally finalises these, but if

View File

@@ -7,12 +7,8 @@ from flask import current_app, jsonify, make_response, redirect, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.tasks import reingest_source_task, sync_source
from application.api.user.tasks import sync_source
from application.core.settings import settings
from application.parser.remote.remote_creator import normalize_remote_data
from application.storage.db.repositories.ingest_chunk_progress import (
IngestChunkProgressRepository,
)
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly, db_session
from application.storage.storage_creator import StorageCreator
@@ -143,8 +139,6 @@ class PaginatedSources(Resource):
"provider": provider,
"isNested": bool(doc.get("directory_structure")),
"type": doc.get("type", "file"),
# Derived in SourcesRepository.list_for_user.
"ingestStatus": doc.get("ingest_status"),
}
)
response = {
@@ -328,7 +322,7 @@ class SyncSource(Resource):
),
400,
)
source_data = normalize_remote_data(source_type, doc.get("remote_data"))
source_data = doc.get("remote_data")
if not source_data:
return make_response(
jsonify({"success": False, "message": "Source is not syncable"}), 400
@@ -352,70 +346,6 @@ class SyncSource(Resource):
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@sources_ns.route("/sources/reingest")
class ReingestSource(Resource):
    """POST endpoint that queues a fresh ingestion run for a source."""

    reingest_source_model = api.model(
        "ReingestSourceModel",
        {"source_id": fields.String(required=True, description="Source ID")},
    )

    @api.expect(reingest_source_model)
    @api.doc(
        description="Re-run ingestion for a source — e.g. to recover a "
        "stalled embed flagged by the reconciler."
    )
    def post(self):
        """Validate the request, clear stale progress and enqueue the task.

        Responds 401 without a decoded token, 400 when the lookup or the
        enqueue raises, 404 when the source is not found, and 200 with
        the task id on success. A missing ``source_id`` returns whatever
        ``check_required_fields`` produced.
        """

        def _fail(status, message=None):
            # Failure bodies carry ``message`` only when one is given.
            body = {"success": False}
            if message is not None:
                body["message"] = message
            return make_response(jsonify(body), status)

        token = request.decoded_token
        if not token:
            return _fail(401)
        user = token.get("sub")

        payload = request.get_json() or {}
        missing = check_required_fields(payload, ["source_id"])
        if missing:
            return missing
        source_id = payload["source_id"]

        try:
            with db_readonly() as ro_conn:
                source_row = SourcesRepository(ro_conn).get_any(source_id, user)
        except Exception as err:
            current_app.logger.error(
                f"Error looking up source: {err}", exc_info=True
            )
            return _fail(400, "Invalid source ID")
        if not source_row:
            return _fail(404, "Source not found")
        resolved_source_id = str(source_row["id"])

        # Remove the old chunk-progress row so the sources list no longer
        # derives a 'failed' status from it. Best effort: a failure here
        # is logged and the reingest is still queued.
        try:
            with db_session() as rw_conn:
                IngestChunkProgressRepository(rw_conn).delete(resolved_source_id)
        except Exception as err:
            current_app.logger.warning(
                f"Could not clear ingest progress for {resolved_source_id}: "
                f"{err}",
                exc_info=True,
            )

        try:
            # The key is scoped to user and source so repeated clicks
            # collapse onto one reingest.
            queued = reingest_source_task.delay(
                source_id=resolved_source_id,
                user=user,
                idempotency_key=f"reingest-source:{user}:{resolved_source_id}",
            )
        except Exception as err:
            current_app.logger.error(
                f"Error starting reingest for source {source_id}: {err}",
                exc_info=True,
            )
            return _fail(400)
        return make_response(
            jsonify({"success": True, "task_id": queued.id}), 200
        )
@sources_ns.route("/directory_structure")
class DirectoryStructure(Resource):
@api.doc(

View File

@@ -13,7 +13,6 @@ from sqlalchemy import text as sql_text
from application.api import api
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
from application.core.settings import settings
from application.storage.db.source_ids import derive_source_id as _derive_source_id
from application.parser.connectors.connector_creator import ConnectorCreator
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
from application.storage.db.repositories.idempotency import IdempotencyRepository
@@ -70,13 +69,7 @@ def _claim_task_or_get_cached(key, task_name):
Pre-generates the celery task_id so a losing writer sees the same
id immediately. Returns ``(task_id, cached_response)``; non-None
cached means the caller should return without enqueuing. The
cached payload mirrors the fresh-request response shape (including
``source_id``) so the frontend can correlate SSE ingest events to
the cached upload task without an extra round-trip — but only when
the cached row actually exists; the "deduplicated" sentinel
deliberately omits ``source_id`` so the frontend doesn't bind to a
phantom source.
cached means the caller should return without enqueuing.
"""
predetermined_id = str(uuid.uuid4())
with db_session() as conn:
@@ -88,16 +81,10 @@ def _claim_task_or_get_cached(key, task_name):
with db_readonly() as conn:
existing = IdempotencyRepository(conn).get_task(key)
cached_id = existing.get("task_id") if existing else None
payload: dict = {
return None, {
"success": True,
"task_id": cached_id or "deduplicated",
}
# Only surface ``source_id`` when there's a real winner whose worker
# is publishing SSE events tagged with that id. The "deduplicated"
# branch means the lock row vanished — we have nothing to correlate.
if cached_id is not None:
payload["source_id"] = str(_derive_source_id(key))
return None, payload
def _release_claim(key):
@@ -249,15 +236,6 @@ class UploadFile(Resource):
file_path = f"{base_path}/{safe_file}"
with open(temp_file_path, "rb") as f:
storage.save_file(f, file_path)
# Mint the source UUID up here so the HTTP response and the
# worker's SSE envelopes share one id. With an idempotency
# key we reuse the deterministic uuid5 (retried task lands on
# the same source row); without a key we fall back to uuid4.
# The worker is told to use this id verbatim — see
# ``ingest_worker(source_id=...)``.
source_uuid = (
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
)
ingest_kwargs = dict(
args=(
settings.UPLOAD_FOLDER,
@@ -271,7 +249,6 @@ class UploadFile(Resource):
"file_name_map": file_name_map,
# Scoped so the worker dedup row matches the HTTP claim.
"idempotency_key": scoped_key or idempotency_key,
"source_id": str(source_uuid),
},
)
if predetermined_task_id is not None:
@@ -296,15 +273,7 @@ class UploadFile(Resource):
return make_response(jsonify({"success": False}), 400)
# Predetermined id matches the dedup-claim row; loser GET sees same.
response_task_id = predetermined_task_id or task.id
# ``source_uuid`` was minted above and passed to the worker as
# ``source_id``; the worker uses it verbatim for every SSE event,
# so the frontend can correlate inbound ``source.ingest.*`` to
# this upload regardless of whether an idempotency key was set.
response_payload: dict = {
"success": True,
"task_id": response_task_id,
"source_id": str(source_uuid),
}
response_payload = {"success": True, "task_id": response_task_id}
return make_response(jsonify(response_payload), 200)
@@ -357,18 +326,6 @@ class UploadRemote(Resource):
)
if cached is not None:
return make_response(jsonify(cached), 200)
# Mint the source UUID up here so the HTTP response and the
# worker's SSE envelopes share one id. Same pattern as
# ``UploadFile.post``: with an idempotency key we reuse the
# deterministic uuid5 (retried task lands on the same source
# row); without a key we fall back to uuid4. The worker is told
# to use this id verbatim — see ``remote_worker`` and
# ``ingest_connector``. Without this the no-key path would mint
# a random uuid4 inside the worker that the frontend has no way
# to correlate SSE events to.
source_uuid = (
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
)
try:
config = json.loads(data["data"])
source_data = None
@@ -425,23 +382,13 @@ class UploadRemote(Resource):
"recursive": config.get("recursive", False),
"retriever": config.get("retriever", "classic"),
"idempotency_key": scoped_key or idempotency_key,
"source_id": str(source_uuid),
},
}
if predetermined_task_id is not None:
connector_kwargs["task_id"] = predetermined_task_id
task = ingest_connector_task.apply_async(**connector_kwargs)
response_task_id = predetermined_task_id or task.id
# ``source_uuid`` was minted above and passed to the
# worker as ``source_id``; the worker uses it verbatim
# for every SSE event, so the frontend can correlate
# inbound ``source.ingest.*`` regardless of whether an
# idempotency key was set.
response_payload = {
"success": True,
"task_id": response_task_id,
"source_id": str(source_uuid),
}
response_payload = {"success": True, "task_id": response_task_id}
return make_response(jsonify(response_payload), 200)
remote_kwargs = {
"kwargs": {
@@ -450,7 +397,6 @@ class UploadRemote(Resource):
"user": user,
"loader": data["source"],
"idempotency_key": scoped_key or idempotency_key,
"source_id": str(source_uuid),
},
}
if predetermined_task_id is not None:
@@ -464,11 +410,7 @@ class UploadRemote(Resource):
_release_claim(scoped_key)
return make_response(jsonify({"success": False}), 400)
response_task_id = predetermined_task_id or task.id
response_payload = {
"success": True,
"task_id": response_task_id,
"source_id": str(source_uuid),
}
response_payload = {"success": True, "task_id": response_task_id}
return make_response(jsonify(response_payload), 200)
@@ -611,19 +553,6 @@ class ManageSourceFiles(Resource):
scoped_key, "reingest_source_task",
)
if cached is not None:
# Frontend keys reingest polling on
# ``reingest_task_id``; the shared cache helper
# writes ``task_id``. Alias here so a dedup
# response doesn't silently break FileTree's
# poller. Override ``source_id`` too — the
# helper derives it from the scoped key, which
# is correct for upload but wrong for reingest
# (the worker publishes events scoped to the
# actual source row id).
cached_task_id = cached.pop("task_id", None)
if cached_task_id is not None:
cached["reingest_task_id"] = cached_task_id
cached["source_id"] = resolved_source_id
return make_response(jsonify(cached), 200)
added_files = []
@@ -679,12 +608,6 @@ class ManageSourceFiles(Resource):
"added_files": added_files,
"parent_dir": parent_dir,
"reingest_task_id": task.id,
# ``source_id`` lets the frontend correlate
# inbound ``source.ingest.*`` SSE events
# (emitted by ``reingest_source_worker``)
# back to the reingest task — matches the
# upload route's source-id contract.
"source_id": resolved_source_id,
}
),
200,
@@ -736,15 +659,6 @@ class ManageSourceFiles(Resource):
scoped_key, "reingest_source_task",
)
if cached is not None:
cached_task_id = cached.pop("task_id", None)
if cached_task_id is not None:
cached["reingest_task_id"] = cached_task_id
# Override the helper's synthetic source_id (uuid5
# of the scoped key) with the real source row id
# — the reingest worker publishes SSE events
# scoped to ``resolved_source_id`` and FileTree
# correlates on it.
cached["source_id"] = resolved_source_id
return make_response(jsonify(cached), 200)
# Remove files from storage and directory structure
@@ -790,7 +704,6 @@ class ManageSourceFiles(Resource):
"message": f"Removed {len(removed_files)} files",
"removed_files": removed_files,
"reingest_task_id": task.id,
"source_id": resolved_source_id,
}
),
200,
@@ -849,14 +762,6 @@ class ManageSourceFiles(Resource):
scoped_key, "reingest_source_task",
)
if cached is not None:
cached_task_id = cached.pop("task_id", None)
if cached_task_id is not None:
cached["reingest_task_id"] = cached_task_id
# Same source_id override as the ``remove`` /
# ``add`` cached branches — the helper's synthetic
# id doesn't match what reingest_source_worker
# tags its SSE events with.
cached["source_id"] = resolved_source_id
return make_response(jsonify(cached), 200)
success = storage.remove_directory(full_directory_path)
@@ -920,7 +825,6 @@ class ManageSourceFiles(Resource):
"message": f"Successfully removed directory: {directory_path}",
"removed_directory": directory_path,
"reingest_task_id": task.id,
"source_id": resolved_source_id,
}
),
200,

View File

@@ -7,6 +7,7 @@ from application.worker import (
attachment_worker,
ingest_worker,
mcp_oauth,
mcp_oauth_status,
remote_worker,
sync,
sync_worker,
@@ -27,42 +28,8 @@ DURABLE_TASK = dict(
)
# operation tag for the poison-path source.ingest.failed event, per task.
_INGEST_POISON_OPERATION = {
"ingest": "upload",
"ingest_remote": "upload",
"ingest_connector_task": "upload",
"reingest_source_task": "reingest",
}
def _emit_ingest_poison_event(task_name, bound):
    """Announce a terminal ``source.ingest.failed`` for a poisoned task.

    The poison-guard short-circuits before the worker body runs, so the
    worker never gets to emit its own failure event; this fills that gap
    so the upload toast does not spin on "training" forever.

    Args:
        task_name: Name of the task; selects the ``operation`` tag.
        bound: Mapping of the task's bound arguments. ``user`` and
            ``source_id`` are required for anything to be published;
            ``filename`` is optional.
    """
    recipient = bound.get("user")
    raw_source_id = bound.get("source_id")
    # Without both a recipient and a source there is nothing to correlate.
    if not (recipient and raw_source_id):
        return

    # Imported lazily, only once we know an event will be sent.
    from application.events.publisher import publish_user_event

    source_ref = str(raw_source_id)
    failure_payload = {
        "source_id": source_ref,
        "filename": bound.get("filename") or "",
        "operation": _INGEST_POISON_OPERATION.get(task_name, "upload"),
        "error": "Ingestion stopped after repeated failures.",
    }
    publish_user_event(
        recipient,
        "source.ingest.failed",
        failure_payload,
        scope={"kind": "source", "id": source_ref},
    )
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="ingest", on_poison=_emit_ingest_poison_event)
@with_idempotency(task_name="ingest")
def ingest(
self,
directory,
@@ -73,7 +40,6 @@ def ingest(
filename,
file_name_map=None,
idempotency_key=None,
source_id=None,
):
resp = ingest_worker(
self,
@@ -85,29 +51,22 @@ def ingest(
user,
file_name_map=file_name_map,
idempotency_key=idempotency_key,
source_id=source_id,
)
return resp
@celery.task(**DURABLE_TASK)
@with_idempotency(task_name="ingest_remote", on_poison=_emit_ingest_poison_event)
def ingest_remote(
self, source_data, job_name, user, loader,
idempotency_key=None, source_id=None,
):
@with_idempotency(task_name="ingest_remote")
def ingest_remote(self, source_data, job_name, user, loader, idempotency_key=None):
resp = remote_worker(
self, source_data, job_name, user, loader,
idempotency_key=idempotency_key,
source_id=source_id,
)
return resp
@celery.task(**DURABLE_TASK)
@with_idempotency(
task_name="reingest_source_task", on_poison=_emit_ingest_poison_event,
)
@with_idempotency(task_name="reingest_source_task")
def reingest_source_task(self, source_id, user, idempotency_key=None):
from application.worker import reingest_source_worker
@@ -164,9 +123,7 @@ def process_agent_webhook(self, agent_id, payload, idempotency_key=None):
@celery.task(**DURABLE_TASK)
@with_idempotency(
task_name="ingest_connector_task", on_poison=_emit_ingest_poison_event,
)
@with_idempotency(task_name="ingest_connector_task")
def ingest_connector_task(
self,
job_name,
@@ -181,7 +138,6 @@ def ingest_connector_task(
doc_id=None,
sync_frequency="never",
idempotency_key=None,
source_id=None,
):
from application.worker import ingest_connector
@@ -199,7 +155,6 @@ def ingest_connector_task(
doc_id=doc_id,
sync_frequency=sync_frequency,
idempotency_key=idempotency_key,
source_id=source_id,
)
return resp
@@ -242,15 +197,6 @@ def setup_periodic_tasks(sender, **kwargs):
version_check_task.s(),
name="version-check",
)
# Bound ``message_events`` growth — every streamed SSE chunk writes
# one row, so retained chats accumulate hundreds of rows per
# message. Reconnect-replay is only meaningful for streams the user
# could plausibly still be waiting on, so 14 days is generous.
sender.add_periodic_task(
timedelta(hours=24),
cleanup_message_events.s(),
name="cleanup-message-events",
)
@celery.task(bind=True)
@@ -259,6 +205,12 @@ def mcp_oauth_task(self, config, user):
return resp
@celery.task(bind=True)
def mcp_oauth_status_task(self, task_id):
    """Bound Celery task that delegates to ``mcp_oauth_status``.

    Passes the task instance and ``task_id`` straight through and returns
    whatever the worker function returns.
    """
    return mcp_oauth_status(self, task_id)
@celery.task(bind=True, acks_late=False)
def cleanup_pending_tool_state(self):
"""Revert stale ``resuming`` rows, then delete TTL-expired rows."""
@@ -313,32 +265,6 @@ def reconciliation_task(self):
return run_reconciliation()
@celery.task(bind=True, acks_late=False)
def cleanup_message_events(self):
    """Purge ``message_events`` rows that have aged out of retention.

    Every SSE yield of a streamed answer writes one journal row, so the
    table would otherwise grow without bound on deployments that retain
    conversations. Reconnect-replay only needs rows for in-flight streams;
    the 14-day default comfortably covers paused/tool-action flows.

    Returns:
        ``{"deleted": <count>, "ttl_days": <days>}`` after a cleanup, or
        ``{"deleted": 0, "skipped": ...}`` when no Postgres URI is set.
    """
    from application.core.settings import settings

    # No Postgres configured: report a skip instead of touching the engine.
    if not settings.POSTGRES_URI:
        return {"deleted": 0, "skipped": "POSTGRES_URI not set"}

    from application.storage.db.engine import get_engine
    from application.storage.db.repositories.message_events import (
        MessageEventsRepository,
    )

    retention_days = settings.MESSAGE_EVENTS_RETENTION_DAYS
    with get_engine().begin() as connection:
        repository = MessageEventsRepository(connection)
        removed = repository.cleanup_older_than(retention_days)
    return {"deleted": removed, "ttl_days": retention_days}
@celery.task(bind=True, acks_late=False)
def version_check_task(self):
"""Periodic anonymous version check.

View File

@@ -1,5 +1,6 @@
"""Tool management MCP server integration."""
import json
from urllib.parse import urlencode, urlparse
from flask import current_app, jsonify, make_response, redirect, request
@@ -225,9 +226,7 @@ class MCPServerSave(Resource):
)
redis_client = get_redis_instance()
manager = MCPOAuthManager(redis_client)
result = manager.get_oauth_status(
config["oauth_task_id"], user
)
result = manager.get_oauth_status(config["oauth_task_id"])
if not result.get("status") == "completed":
return make_response(
jsonify(
@@ -439,6 +438,56 @@ class MCPOAuthCallback(Resource):
)
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
    """Report the status of an MCP OAuth flow as recorded in Redis."""

    def get(self, task_id):
        """Return the status stored under ``mcp_oauth_status:<task_id>``.

        When the key exists, its JSON object is merged into the response
        and any ``tools`` list is reduced to ``name``/``description``
        pairs. When the key is absent, a ``pending`` placeholder is
        returned. Any failure (Redis error, malformed JSON, a payload
        that is not a JSON object) is logged and answered with HTTP 500.

        Args:
            task_id: Identifier of the OAuth task, taken from the URL.
        """
        # NOTE(review): nothing visible here ties ``task_id`` to the calling
        # user -- confirm that exposing status by task id alone is intended.
        try:
            redis_client = get_redis_instance()
            status_data = redis_client.get(f"mcp_oauth_status:{task_id}")
            if not status_data:
                return make_response(
                    jsonify(
                        {
                            "success": True,
                            "task_id": task_id,
                            "status": "pending",
                            "message": "Waiting for OAuth to start...",
                        }
                    ),
                    200,
                )

            status = json.loads(status_data)
            if not isinstance(status, dict):
                # The merge below needs a mapping; fail with a clear message
                # instead of an incidental TypeError from ``**status``.
                raise ValueError("OAuth status payload is not a JSON object")

            tools = status.get("tools")
            if isinstance(tools, list):
                # Expose only name/description. A malformed (non-dict) entry
                # gets the same defaults as a dict with missing fields, so
                # one bad entry no longer turns the whole request into a 500.
                status["tools"] = [
                    {
                        "name": tool.get("name", "unknown"),
                        "description": tool.get("description", ""),
                    }
                    if isinstance(tool, dict)
                    else {"name": "unknown", "description": ""}
                    for tool in tools
                ]
            return make_response(
                jsonify({"success": True, "task_id": task_id, **status}),
                200,
            )
        except Exception as e:
            current_app.logger.error(
                f"Error getting OAuth status for task {task_id}: {str(e)}",
                exc_info=True,
            )
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "error": "Failed to get OAuth status",
                        "task_id": task_id,
                    }
                ),
                500,
            )
@tools_mcp_ns.route("/mcp_server/auth_status")
class MCPAuthStatus(Resource):
@api.doc(

View File

@@ -222,26 +222,13 @@ def _stream_response(
for line in internal_stream:
if not line.strip():
continue
# ``complete_stream`` prefixes each frame with ``id: <seq>\n``
# before the ``data:`` line. Extract just the data line so JSON
# decode doesn't choke on the SSE framing.
event_str = ""
for raw in line.split("\n"):
if raw.startswith("data:"):
event_str = raw[len("data:") :].lstrip()
break
if not event_str:
continue
# Parse the internal SSE event
event_str = line.replace("data: ", "").strip()
try:
event_data = json.loads(event_str)
except (json.JSONDecodeError, TypeError):
continue
# Skip the informational ``message_id`` event — it has no v1 /
# OpenAI-compatible analog.
if event_data.get("type") == "message_id":
continue
# Update completion_id when we get the conversation id
if event_data.get("type") == "id":
conv_id = event_data.get("id", "")

View File

@@ -16,8 +16,6 @@ setup_logging()
from application.api import api # noqa: E402
from application.api.answer import answer # noqa: E402
from application.api.answer.routes.messages import messages_bp # noqa: E402
from application.api.events.routes import events # noqa: E402
from application.api.internal.routes import internal # noqa: E402
from application.api.user.routes import user # noqa: E402
from application.api.connector.routes import connector # noqa: E402
@@ -51,8 +49,6 @@ ensure_database_ready(
app = Flask(__name__)
app.register_blueprint(user)
app.register_blueprint(answer)
app.register_blueprint(events)
app.register_blueprint(messages_bp)
app.register_blueprint(internal)
app.register_blueprint(connector)
app.register_blueprint(v1_bp)

View File

@@ -29,17 +29,8 @@ def get_redis_instance():
with _instance_lock:
if _redis_instance is None and not _redis_creation_failed:
try:
# ``health_check_interval`` makes redis-py ping the
# connection every N seconds when otherwise idle.
# Without it, a half-open TCP (NAT silently dropped
# state, ELB idle-close) can hang the SSE generator
# in ``pubsub.get_message`` past its keepalive
# cadence — the kernel never surfaces the dead
# socket because no payload is in flight.
_redis_instance = redis.Redis.from_url(
settings.CACHE_REDIS_URL,
socket_connect_timeout=2,
health_check_interval=10,
settings.CACHE_REDIS_URL, socket_connect_timeout=2
)
except ValueError as e:
logger.error(f"Invalid Redis URL: {e}")

View File

@@ -1,8 +1,5 @@
import ctypes
import gc
import inspect
import logging
import sys
import threading
from celery import Celery
@@ -101,34 +98,6 @@ def _unbind_task_log_context(task_id, **_):
)
def _trim_native_heap() -> None:
    """Hand freed glibc heap pages back to the OS; does nothing off Linux."""
    # Parsing with docling/torch makes big short-lived allocations. glibc
    # parks the freed pages in per-thread malloc arenas instead of returning
    # them, so a long-lived worker child's RSS only climbs. ``malloc_trim``
    # gives them back. The symbol exists only in glibc (not in macOS libc).
    if sys.platform.startswith("linux"):
        try:
            libc = ctypes.CDLL("libc.so.6")
            libc.malloc_trim(0)
        except (OSError, AttributeError):
            # Library or symbol unavailable: trimming is best-effort.
            return
@task_postrun.connect
def _reclaim_memory_after_task(*args, **kwargs):
    """Drop per-task allocations so the prefork child's RSS doesn't ratchet.

    Connected to ``task_postrun``. Runs a Python garbage collection,
    releases the CUDA cache when torch is already loaded and CUDA is
    available, then trims the native heap. Signal arguments are ignored.
    """
    gc.collect()
    # Look torch up in ``sys.modules`` rather than importing it, so cleanup
    # never loads torch into a worker that did not already use it.
    torch = sys.modules.get("torch")
    if torch is not None:
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            # Still best-effort, but leave a trace instead of hiding it.
            logging.getLogger(__name__).debug(
                "torch.cuda cache release failed after task", exc_info=True
            )
    _trim_native_heap()
@worker_ready.connect
def _run_version_check(*args, **kwargs):
"""Kick off the anonymous version check on worker startup.

View File

@@ -31,10 +31,3 @@ worker_prefetch_multiplier = settings.CELERY_WORKER_PREFETCH_MULTIPLIER
broker_transport_options = {"visibility_timeout": settings.CELERY_VISIBILITY_TIMEOUT}
result_expires = 86400 * 7
task_track_started = True
# Recycle the prefork worker child to bound native-heap growth from
# docling/torch parsing. Left unset (Celery's unlimited default) when 0.
if settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD > 0:
worker_max_memory_per_child = settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD
if settings.CELERY_WORKER_MAX_TASKS_PER_CHILD > 0:
worker_max_tasks_per_child = settings.CELERY_WORKER_MAX_TASKS_PER_CHILD

View File

@@ -36,11 +36,6 @@ class Settings(BaseSettings):
# and Dify defaults; long ingests can override via env.
CELERY_WORKER_PREFETCH_MULTIPLIER: int = 1
CELERY_VISIBILITY_TIMEOUT: int = 3600
# Recycle the prefork worker child once its resident size crosses this many
# kilobytes — backstops native-heap growth from docling/torch parsing. 0 disables.
CELERY_WORKER_MAX_MEMORY_PER_CHILD: int = 4194304
# Recycle the child after this many tasks; 0 disables (memory cap is the primary knob).
CELERY_WORKER_MAX_TASKS_PER_CHILD: int = 0
# Only consulted when VECTOR_STORE=mongodb or when running scripts/db/backfill.py; user data lives in Postgres.
MONGO_URI: Optional[str] = None
# User-data Postgres DB.
@@ -66,9 +61,6 @@ class Settings(BaseSettings):
PARSE_IMAGE_REMOTE: bool = False
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
# Pages docling's threaded pipeline buffers in flight; the library
# default (100) drives worker RSS to ~3 GB on a mid-size PDF.
DOCLING_PIPELINE_QUEUE_MAX_SIZE: int = 2
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
RETRIEVERS_ENABLED: list = ["classic_rag"]
AGENT_NAME: str = "classic"
@@ -196,42 +188,6 @@ class Settings(BaseSettings):
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
# Internal SSE push channel (notifications + durable replay journal)
# Master switch — when False, /api/events emits a "push_disabled" comment
# and returns; clients fall back to polling. Publisher becomes a no-op.
ENABLE_SSE_PUSH: bool = True
# Per-user durable backlog cap (~entries). At typical event rates this
# gives ~24h of replay; tune up for verbose feeds, down for memory.
EVENTS_STREAM_MAXLEN: int = 1000
# SSE keepalive comment cadence. Must sit under Cloudflare's 100s idle
# close and iOS Safari's ~60s — 15s gives generous headroom.
SSE_KEEPALIVE_SECONDS: int = 15
# Cap on simultaneous SSE connections per user. Each connection holds
# one WSGI thread (32 per gunicorn worker) and one Redis pub/sub
# connection. 8 covers normal multi-tab use without letting one user
# starve the pool. Set to 0 to disable the cap.
SSE_MAX_CONCURRENT_PER_USER: int = 8
# Per-request cap on the number of backlog entries XRANGE returns
# for ``/api/events`` snapshots. Bounds the bytes a single replay
# can move from Redis to the wire — a malicious client looping
# ``Last-Event-ID=<oldest>`` reconnects can only enumerate this
# many entries per round-trip. Combined with the per-user
# connection cap above and the windowed budget below, total
# enumeration throughput is bounded.
EVENTS_REPLAY_MAX_PER_REQUEST: int = 200
# Sliding-window cap on snapshot replays per user. Once the budget
# is exhausted the route returns HTTP 429 with the cursor pinned;
# the client backs off and retries after the window rolls over.
EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW: int = 30
EVENTS_REPLAY_BUDGET_WINDOW_SECONDS: int = 60
# Retention for the ``message_events`` journal. The ``cleanup_message_events``
# beat task deletes rows older than this. Reconnect-replay only
# needs the journal for streams a client could still be tailing,
# so 14 days is a generous default that covers paused/tool-action
# flows without unbounded table growth.
MESSAGE_EVENTS_RETENTION_DAYS: int = 14
@field_validator("POSTGRES_URI", mode="before")
@classmethod
def _normalize_postgres_uri_validator(cls, v):

View File

@@ -1,52 +0,0 @@
"""Stream/topic key derivations shared by publisher and SSE consumer.
Single source of truth for the per-user Redis Streams key and pub/sub
topic name. Both must agree exactly — a typo here splits the
publisher's writes from the consumer's reads.
"""
from __future__ import annotations
def stream_key(user_id: str) -> str:
    """Name of the Redis Streams key storing ``user_id``'s durable backlog."""
    return "user:{}:stream".format(user_id)
def topic_name(user_id: str) -> str:
    """Name of the Redis pub/sub channel for live fan-out to ``user_id``."""
    return "user:{}".format(user_id)
def connection_counter_key(user_id: str) -> str:
    """Name of the Redis counter of ``user_id``'s active SSE connections."""
    key = "user:{}:sse_count".format(user_id)
    return key
def replay_budget_key(user_id: str) -> str:
    """Name of the Redis counter of ``user_id``'s snapshot replays.

    The counter covers the rolling rate-limit window.
    """
    key = "user:{}:replay_count".format(user_id)
    return key
def stream_id_compare(a: str, b: str) -> int:
    """Order two Redis Streams ids, ``cmp``-style.

    Ids have the form ``ms-seq``. Plain string comparison goes wrong as
    soon as the ``ms`` parts differ in digit count, so both ids are parsed
    into ``(int, int)`` tuples first. A missing ``seq`` counts as ``0``.

    Returns:
        ``-1`` if ``a`` sorts before ``b``, ``1`` if after, ``0`` if equal.

    Raises:
        ValueError: On malformed input. Callers must pre-validate against
            ``_STREAM_ID_RE`` (or equivalent) -- a lexical fallback here
            once let a malformed id compare greater than a real one and
            silently pin dedup forever.
    """
    # Split both ids before converting either, then parse left to right.
    split_ids = [stream_id.partition("-") for stream_id in (a, b)]
    left, right = (
        (int(millis), int(sequence) if sequence else 0)
        for millis, _, sequence in split_ids
    )
    return (left > right) - (left < right)

View File

@@ -1,144 +0,0 @@
"""User-scoped event publisher: durable backlog + live fan-out.
Each ``publish_user_event`` call writes twice:
1. ``XADD user:{user_id}:stream MAXLEN ~ <cap> * event <json>`` — the
durable backlog used by SSE reconnect (``Last-Event-ID``) and stream
replay. Bounded by ``EVENTS_STREAM_MAXLEN`` (~24h at typical event
rates) so the per-user footprint stays predictable.
2. ``PUBLISH user:{user_id} <json-with-id>`` — live fan-out to every
currently connected SSE generator for the user, across instances.
Together they give a snapshot-plus-tail story: a reconnecting client
reads ``XRANGE`` from its last seen id and then transitions onto the
live pub/sub. The Redis Streams entry id (e.g. ``1735682400000-0``) is
the canonical, monotonically increasing event id and is what
``Last-Event-ID`` carries.
Failures are logged and swallowed: the caller is typically a Celery
task whose primary work has already succeeded, and a notification
delivery miss should not surface as a task failure.
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from typing import Any, Optional
from application.cache import get_redis_instance
from application.core.settings import settings
from application.events.keys import stream_key, topic_name
from application.streaming.broadcast_channel import Topic
logger = logging.getLogger(__name__)
def _iso_now() -> str:
    """Current UTC time as ISO 8601 with millisecond precision and a ``Z``."""
    stamp = datetime.now(timezone.utc).isoformat(timespec="milliseconds")
    # ``isoformat`` renders UTC as "+00:00"; swap in the shorter "Z" suffix.
    return stamp.replace("+00:00", "Z")
def publish_user_event(
    user_id: str,
    event_type: str,
    payload: dict[str, Any],
    *,
    scope: Optional[dict[str, Any]] = None,
) -> Optional[str]:
    """Journal and broadcast one user-scoped event.

    The event is first appended to the user's Redis stream, then published
    on the user's pub/sub topic with the stream id attached.

    Fire-and-forget: never raises. A ``None`` result means the event
    reached neither the journal nor live subscribers (see runbook for
    causes).

    Args:
        user_id: Recipient of the event; must be non-empty.
        event_type: Event type string; must be non-empty.
        payload: JSON-serializable event body.
        scope: Optional scope descriptor; an empty dict is sent when omitted.

    Returns:
        The Redis Streams id of the journal entry, or ``None``.
    """
    if not (user_id and event_type):
        logger.warning(
            "publish_user_event called without user_id or event_type "
            "(user_id=%r, event_type=%r)",
            user_id,
            event_type,
        )
        return None
    if not settings.ENABLE_SSE_PUSH:
        return None

    body: dict[str, Any] = {
        "type": event_type,
        "ts": _iso_now(),
        "user_id": user_id,
        "topic": topic_name(user_id),
        "scope": scope or {},
        "payload": payload,
    }
    channel = body["topic"]

    try:
        body_json = json.dumps(body)
    except (TypeError, ValueError) as exc:
        logger.warning(
            "publish_user_event payload not JSON-serializable: "
            "user=%s type=%s err=%s",
            user_id,
            event_type,
            exc,
        )
        return None

    client = get_redis_instance()
    if client is None:
        logger.debug("Redis unavailable; skipping publish_user_event")
        return None

    backlog_cap = settings.EVENTS_STREAM_MAXLEN
    try:
        # The auto-generated id is monotonic (ms-seq) and doubles as the SSE
        # event id. ``approximate=True`` lets Redis trim in chunks for
        # performance; the cap is treated as ~MAXLEN, never less.
        raw_id = client.xadd(
            stream_key(user_id),
            {"event": body_json},
            maxlen=backlog_cap,
            approximate=True,
        )
        if isinstance(raw_id, (bytes, bytearray)):
            entry_id = raw_id.decode("utf-8")
        else:
            entry_id = str(raw_id)
    except Exception:
        logger.exception(
            "xadd failed for user=%s event_type=%s", user_id, event_type
        )
        # No journal entry means no canonical id. Publishing live anyway
        # would put an id-less record on the wire that bypasses the SSE
        # route's dedup floor and breaks ``Last-Event-ID`` semantics on
        # reconnect. Best-effort means dropping consistently, not
        # delivering inconsistent state.
        return None

    live_envelope = {**body, "id": entry_id}
    try:
        Topic(channel).publish(json.dumps(live_envelope))
    except Exception:
        logger.exception(
            "publish failed for user=%s event_type=%s", user_id, event_type
        )

    logger.debug(
        "event.published topic=%s type=%s id=%s",
        channel,
        event_type,
        entry_id,
    )
    return entry_id

View File

@@ -4,7 +4,6 @@ from typing import Any, List, Optional
from retry import retry
from tqdm import tqdm
from application.core.settings import settings
from application.events.publisher import publish_user_event
from application.storage.db.repositories.ingest_chunk_progress import (
IngestChunkProgressRepository,
)
@@ -153,9 +152,6 @@ def embed_and_store_documents(
task_status: Any,
*,
attempt_id: Optional[str] = None,
user_id: Optional[str] = None,
progress_start: int = 0,
progress_end: int = 100,
) -> None:
"""Embeds documents and stores them in a vector store.
@@ -174,15 +170,6 @@ def embed_and_store_documents(
attempt_id: Stable id of the current task invocation,
typically ``self.request.id`` from the Celery task body.
``None`` is treated as a fresh attempt every time.
user_id: When provided, per-percent SSE progress events are
published to ``user:{user_id}`` for the in-app upload toast.
``None`` is the safe default — workers without a user
context (e.g. background syncs) skip the publish.
progress_start: Percent the reported progress maps to at chunk 0.
Lets a caller reserve the lower band for an earlier stage
(e.g. parsing). Defaults to ``0`` (embed owns the whole bar).
progress_end: Percent the reported progress maps to at the final
chunk. Defaults to ``100``.
Returns:
None
@@ -262,9 +249,6 @@ def embed_and_store_documents(
# Process and embed documents
chunk_error: Exception | None = None
failed_idx: int | None = None
last_published_pct = -1
source_id_str = str(source_id)
progress_span = progress_end - progress_start
for idx in tqdm(
range(loop_start, total_docs),
desc="Embedding 🦖",
@@ -274,30 +258,10 @@ def embed_and_store_documents(
):
doc = docs[idx]
try:
# Map the embed loop into [progress_start, progress_end].
progress = progress_start + int(
((idx + 1) / total_docs) * progress_span
)
# Update task status for progress tracking
progress = int(((idx + 1) / total_docs) * 100)
task_status.update_state(state="PROGRESS", meta={"current": progress})
# SSE push for sub-second upload-toast updates. Throttled to one
# event per percent so a 10k-chunk ingest emits ~100 events,
# not 10k. The Celery update_state above stays the source of
# truth for the polling-fallback path.
if user_id and progress > last_published_pct:
publish_user_event(
user_id,
"source.ingest.progress",
{
"current": progress,
"total": total_docs,
"embedded_chunks": idx + 1,
"stage": "embedding",
},
scope={"kind": "source", "id": source_id_str},
)
last_published_pct = progress
# Add document to vector store
add_text_to_store_with_retry(store, doc, source_id)
_record_progress(source_id, last_index=idx, embedded_chunks=idx + 1)

View File

@@ -211,22 +211,13 @@ class SimpleDirectoryReader(BaseReader):
return new_input_files
def load_data(
self,
concatenate: bool = False,
progress_callback: Optional[Callable[[int, int], None]] = None,
) -> List[Document]:
def load_data(self, concatenate: bool = False) -> List[Document]:
"""Load data from the input directory.
Args:
concatenate (bool): whether to concatenate all files into one document.
If set to True, file metadata is ignored.
False by default.
progress_callback (Optional[Callable[[int, int], None]]): Called
after each file is parsed with ``(files_done, total_files)``.
Lets callers surface parse/OCR progress before embedding
begins. Exceptions raised by the callback are swallowed so
progress reporting can never fail ingestion.
Returns:
List[Document]: A list of documents.
@@ -235,9 +226,8 @@ class SimpleDirectoryReader(BaseReader):
data_list: List[str] = []
metadata_list = []
self.file_token_counts = {}
total_files = len(self.input_files)
for file_index, input_file in enumerate(self.input_files):
for input_file in self.input_files:
suffix_lower = input_file.suffix.lower()
parser_metadata = {}
if suffix_lower in self.file_extractor:
@@ -287,15 +277,7 @@ class SimpleDirectoryReader(BaseReader):
else:
data_list.append(str(data))
metadata_list.append(base_metadata)
if progress_callback is not None:
try:
progress_callback(file_index + 1, total_files)
except Exception:
logging.warning(
"load_data progress callback failed", exc_info=True
)
# Build directory structure if input_dir is provided
if hasattr(self, 'input_dir'):
self.directory_structure = self.build_directory_structure(self.input_dir)

View File

@@ -16,29 +16,6 @@ from application.parser.file.base_parser import BaseParser
logger = logging.getLogger(__name__)
# Per-stage batch size for docling's threaded pipeline; 1 holds the
# concurrent working set to a single page (see _apply_pipeline_caps).
_PIPELINE_BATCH_SIZE = 1
def _apply_pipeline_caps(pipeline_options) -> None:
    """Clamp queue depth and batch sizes on docling's threaded pipeline.

    ``pipeline_options`` is mutated in place. Each knob is set only if the
    object already has that attribute, so docling builds without these
    knobs are left untouched.
    """
    from application.core.settings import settings

    # The queue depth comes from settings but is never allowed below 1.
    queue_depth = max(1, settings.DOCLING_PIPELINE_QUEUE_MAX_SIZE)
    limits = (
        ("queue_max_size", queue_depth),
        ("layout_batch_size", _PIPELINE_BATCH_SIZE),
        ("table_batch_size", _PIPELINE_BATCH_SIZE),
        ("ocr_batch_size", _PIPELINE_BATCH_SIZE),
    )
    for attribute, limit in limits:
        if not hasattr(pipeline_options, attribute):
            continue
        setattr(pipeline_options, attribute, limit)
class DoclingParser(BaseParser):
"""Parser using docling for advanced document processing.
@@ -109,7 +86,6 @@ class DoclingParser(BaseParser):
do_ocr=self.ocr_enabled,
do_table_structure=self.table_structure,
)
_apply_pipeline_caps(pipeline_options)
if self.ocr_enabled:
ocr_options = self._get_ocr_options()

View File

@@ -1,11 +1,11 @@
import logging
import os
import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.parser.schema.base import Document
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
from langchain_community.document_loaders import WebBaseLoader
class CrawlerLoader(BaseRemote):
@@ -35,7 +35,14 @@ class CrawlerLoader(BaseRemote):
visited_urls.add(current_url)
try:
response = pinned_request("GET", current_url, timeout=30)
# Validate each URL before making requests
try:
validate_url(current_url)
except SSRFError as e:
logging.warning(f"Skipping URL due to validation failure: {current_url} - {e}")
continue
response = requests.get(current_url, timeout=30)
response.raise_for_status()
loader = self.loader([current_url])
docs = loader.load()

View File

@@ -1,8 +1,8 @@
import requests
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
import re
from markdownify import markdownify
from application.parser.schema.base import Document
@@ -20,6 +20,7 @@ class CrawlerLoader(BaseRemote):
"""
self.limit = limit
self.allow_subdomains = allow_subdomains
self.session = requests.Session()
def load_data(self, inputs):
url = inputs
@@ -90,13 +91,15 @@ class CrawlerLoader(BaseRemote):
def _fetch_page(self, url):
try:
response = pinned_request("GET", url, timeout=10)
# Validate URL before fetching to prevent SSRF
validate_url(url)
response = self.session.get(url, timeout=10)
response.raise_for_status()
return response.text
except UnsafeUserUrlError as e:
except SSRFError as e:
print(f"URL validation failed for {url}: {e}")
return None
except Exception as e:
except requests.exceptions.RequestException as e:
print(f"Error fetching URL {url}: {e}")
return None

View File

@@ -1,5 +1,3 @@
import json
from application.parser.remote.sitemap_loader import SitemapLoader
from application.parser.remote.crawler_loader import CrawlerLoader
from application.parser.remote.web_loader import WebLoader
@@ -34,59 +32,3 @@ class RemoteCreator:
if not loader_class:
raise ValueError(f"No loader class found for type {type}")
return loader_class(*args, **kwargs)
# Loader types whose load_data expects a URL string, not a config dict.
_URL_LOADER_TYPES = {"url", "crawler", "sitemap", "github"}

# Keys a remote_data dict may hold the URL under (``raw`` is the legacy shape).
_URL_DATA_KEYS = ("url", "urls", "repo_url", "raw")


def normalize_remote_data(source_type, remote_data):
    """Turn a stored ``sources.remote_data`` value into loader input.

    Args:
        source_type: The ``sources.type`` value (the loader name); matched
            case-insensitively, ``None`` is treated as an empty name.
        remote_data: The stored ``remote_data`` (dict, list, str, or None).

    Returns:
        For url/crawler/sitemap/github: the first truthy value found under
        the known URL keys of a dict (``None`` if there is none), or the
        non-dict value unchanged. For reddit: a JSON string. For anything
        else (e.g. s3): the value as-is. ``None`` input yields ``None``.
    """
    if remote_data is None:
        return None

    payload = remote_data
    if isinstance(payload, str):
        # Some legacy rows stored the JSON document itself as a string.
        candidate = payload.strip()
        if candidate.startswith(("{", "[")):
            try:
                payload = json.loads(candidate)
            except json.JSONDecodeError:
                # Only looked like JSON; keep the original string so the
                # per-loader handling below sees it untouched.
                payload = remote_data

    kind = (source_type or "").lower()

    if kind in _URL_LOADER_TYPES:
        if not isinstance(payload, dict):
            return payload
        # A dict without any usable URL key maps to None, which keeps the
        # loader off the dict-crash path.
        return next(
            (found for found in map(payload.get, _URL_DATA_KEYS) if found),
            None,
        )

    # reddit's loader runs json.loads() on its input, so it needs a string.
    if kind == "reddit" and isinstance(payload, (dict, list)):
        return json.dumps(payload)

    # Everything else (including s3, which takes a dict or JSON string)
    # passes through unchanged.
    return payload

View File

@@ -1,9 +1,9 @@
import logging
import re
import requests
import re # Import regular expression library
import defusedxml.ElementTree as ET
from application.parser.remote.base import BaseRemote
from application.core.url_validation import validate_url, SSRFError
from application.security.safe_url import UnsafeUserUrlError, pinned_request
class SitemapLoader(BaseRemote):
def __init__(self, limit=20):
@@ -53,12 +53,14 @@ class SitemapLoader(BaseRemote):
def _extract_urls(self, sitemap_url):
try:
response = pinned_request("GET", sitemap_url, timeout=30)
response.raise_for_status()
except UnsafeUserUrlError as e:
# Validate URL before fetching to prevent SSRF
validate_url(sitemap_url)
response = requests.get(sitemap_url, timeout=30)
response.raise_for_status() # Raise an exception for HTTP errors
except SSRFError as e:
print(f"URL validation failed for sitemap: {sitemap_url}. Error: {e}")
return []
except Exception as e:
except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError) as e:
print(f"Failed to fetch sitemap: {sitemap_url}. Error: {e}")
return []
@@ -95,6 +97,13 @@ class SitemapLoader(BaseRemote):
nested_sitemap_url = sitemap.text
if not nested_sitemap_url:
continue
try:
nested_sitemap_url = validate_url(nested_sitemap_url)
except SSRFError as e:
logging.error(
f"URL validation failed for nested sitemap {nested_sitemap_url}: {e}"
)
continue
urls.extend(self._extract_urls(nested_sitemap_url))
return urls

View File

@@ -291,55 +291,6 @@ def _ip_to_url_host(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> str:
return str(ip)
def pinned_request(
    method: str,
    url: str,
    *,
    data: Any = None,
    json: Any = None,
    headers: dict[str, str] | None = None,
    timeout: float = 90.0,
    allow_redirects: bool = False,
) -> requests.Response:
    """Issue an HTTP request whose connection targets a pre-validated IP.

    The URL is validated once, its host is swapped for the chosen IP, and
    the original host travels in the ``Host`` header. This closes the
    DNS-rebinding TOCTOU window left by validating a URL and then letting
    ``requests`` resolve it again.

    Args:
        method: HTTP verb; upper-cased before sending.
        url: The user-supplied URL to validate and fetch.
        data: Forwarded to ``Session.request`` as ``data``.
        json: Forwarded to ``Session.request`` as ``json``.
        headers: Extra request headers; copied, never mutated. Any ``Host``
            entry is overwritten.
        timeout: Request timeout in seconds.
        allow_redirects: Whether ``requests`` may follow redirects.

    Returns:
        The ``requests.Response`` for the pinned request.

    Raises:
        UnsafeUserUrlError: If the URL fails the SSRF guard.
        requests.RequestException: For network-level failures.
    """
    host, ip, parts = _validate_and_pick_ip(url)

    # Same ":port" suffix is needed on both the pinned netloc and the
    # Host header; it is empty when the URL carries no explicit port.
    ip_host = _ip_to_url_host(ip)
    port_suffix = "" if parts.port is None else f":{parts.port}"
    target = urlunsplit(
        (
            parts.scheme,
            f"{ip_host}{port_suffix}",
            parts.path,
            parts.query,
            parts.fragment,
        )
    )

    outgoing = dict(headers or {})
    outgoing["Host"] = f"{host}{port_suffix}"

    # One throwaway session per call; leaving the ``with`` closes it.
    with requests.Session() as session:
        if parts.scheme == "https":
            # NOTE(review): presumably the adapter keeps TLS checks bound to
            # the original hostname rather than the IP — confirm in
            # _PinnedHostAdapter.
            session.mount("https://", _PinnedHostAdapter(host))
        return session.request(
            method=method.upper(),
            url=target,
            data=data,
            json=json,
            headers=outgoing,
            timeout=timeout,
            allow_redirects=allow_redirects,
        )
def pinned_post(
url: str,
*,
@@ -377,15 +328,33 @@ def pinned_post(
requests.RequestException: For network-level failures.
"""
return pinned_request(
"POST",
url,
json=json,
headers=headers,
timeout=timeout,
allow_redirects=allow_redirects,
host, ip, parts = _validate_and_pick_ip(url)
netloc = _ip_to_url_host(ip)
if parts.port is not None:
netloc = f"{netloc}:{parts.port}"
pinned_url = urlunsplit(
(parts.scheme, netloc, parts.path, parts.query, parts.fragment)
)
request_headers = dict(headers or {})
host_header = host if parts.port is None else f"{host}:{parts.port}"
request_headers["Host"] = host_header
session = requests.Session()
if parts.scheme == "https":
session.mount("https://", _PinnedHostAdapter(host))
try:
return session.post(
pinned_url,
json=json,
headers=request_headers,
timeout=timeout,
allow_redirects=allow_redirects,
)
finally:
session.close()
class _PinnedHTTPSTransport(httpx.HTTPTransport):
"""``httpx`` transport pinned to a single validated IP literal.

View File

@@ -34,7 +34,7 @@ from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, JSONB, UUID
metadata = MetaData()
# --- Users, prompts, tools, logs --------------------------------------------
# --- Phase 1, Tier 1 --------------------------------------------------------
users_table = Table(
"users",
@@ -138,7 +138,7 @@ app_metadata_table = Table(
)
# --- Agents, sources, attachments, artifacts --------------------------------
# --- Phase 2, Tier 2 --------------------------------------------------------
agent_folders_table = Table(
"agent_folders",
@@ -307,7 +307,7 @@ connector_sessions_table = Table(
)
# --- Conversations, messages, workflows -------------------------------------
# --- Phase 3, Tier 3 --------------------------------------------------------
conversations_table = Table(
"conversations",
@@ -363,36 +363,6 @@ conversation_messages_table = Table(
UniqueConstraint("conversation_id", "position", name="conversation_messages_conv_pos_uidx"),
)
# Per-yield journal of chat-stream events, used by the snapshot+tail
# reconnect: the route's GET reconnect endpoint reads
# ``WHERE message_id = ? AND sequence_no > ?`` from this table before
# tailing the live ``channel:{message_id}`` pub/sub. See
# ``application/streaming/event_replay.py`` and migration 0007.
message_events_table = Table(
"message_events",
metadata,
# PK is the composite ``(message_id, sequence_no)`` — it doubles as
# the snapshot read index (covering range scan on
# ``WHERE message_id = ? AND sequence_no > ?``).
# ``ondelete="CASCADE"``: journal rows are removed together with their
# parent ``conversation_messages`` row.
Column(
"message_id",
UUID(as_uuid=True),
ForeignKey("conversation_messages.id", ondelete="CASCADE"),
primary_key=True,
nullable=False,
),
# Strictly monotonic per ``message_id``. Allocated by the route as it
# yields, so the writer is single-threaded for the lifetime of one
# stream — no contention, no SERIAL needed.
Column("sequence_no", Integer, primary_key=True, nullable=False),
# Event kind label; required.
Column("event_type", Text, nullable=False),
# Event body as JSONB; the server-side default is an empty object.
Column("payload", JSONB, nullable=False, server_default="{}"),
# Timezone-aware insert time, filled in server-side via ``now()``.
Column(
"created_at", DateTime(timezone=True), nullable=False, server_default=func.now()
),
)
shared_conversations_table = Table(
"shared_conversations",
metadata,
@@ -433,7 +403,7 @@ pending_tool_state_table = Table(
)
# --- Durability foundation (idempotency / journals, migration 0004) ---------
# --- Tier 1 durability foundation (migration 0004) --------------------------
# CHECK constraints (status enums) and partial indexes are intentionally
# omitted from these declarations — the DB is the authority. Repositories
# use raw ``text(...)`` SQL against these tables, not the Core objects.
@@ -514,9 +484,6 @@ ingest_chunk_progress_table = Table(
# same task resumes from the checkpoint, but a separate invocation
# (manual reingest, scheduled sync) resets to a clean re-index.
Column("attempt_id", Text),
# Added in ``0008_ingest_progress_status``. The reconciler flips
# this to 'stalled'; ``init_progress`` resets it to 'active'.
Column("status", Text, nullable=False, server_default="active"),
)

View File

@@ -15,7 +15,6 @@ Covers every operation the legacy Mongo code performs on
from __future__ import annotations
import json
from enum import Enum
from typing import Optional
from sqlalchemy import Connection, text
@@ -26,22 +25,6 @@ from application.storage.db.models import conversations_table, conversation_mess
from application.storage.db.serialization import PGNativeJSONEncoder
class MessageUpdateOutcome(str, Enum):
    """String-valued result codes returned by ``update_message_by_id``.

    The codes separate "this call changed the row" from "the row had
    already reached a terminal state", which lets an abort handler journal
    ``end`` rather than ``error`` when the normal-path finalize has already
    flipped the row to ``complete``.
    """

    # This call's UPDATE modified the row.
    UPDATED = "updated"

    # Nothing was changed: the row was already ``complete``.
    ALREADY_COMPLETE = "already_complete"

    # Nothing was changed: the row was already ``failed``.
    ALREADY_FAILED = "already_failed"

    # No row matched the given message id.
    NOT_FOUND = "not_found"

    # The request itself was unusable (bad id or nothing to update).
    INVALID = "invalid"
def _message_row_to_dict(row) -> dict:
"""Like ``row_to_dict`` but renames the DB column ``message_metadata``
back to the public API key ``metadata`` so callers keep the Mongo-era
@@ -75,8 +58,8 @@ class ConversationsRepository:
- Already-UUID-shaped → returned as-is.
- Otherwise treated as a Mongo ObjectId and looked up via
``agents.legacy_mongo_id``. Returns ``None`` if no PG row
exists yet (e.g. the agent was created before the backfill
ran).
exists yet (e.g. the agent was created before Phase 1
backfill).
"""
if not agent_id_raw:
return None
@@ -714,7 +697,7 @@ class ConversationsRepository:
def update_message_by_id(
self, message_id: str, fields: dict,
*, only_if_non_terminal: bool = False,
) -> MessageUpdateOutcome:
) -> bool:
"""Update specific fields on a message identified by its UUID.
``metadata`` is merged into the existing JSONB rather than
@@ -722,13 +705,9 @@ class ConversationsRepository:
a successful late finalize. When ``only_if_non_terminal`` is
True, the update is gated so a late finalize cannot retract a
reconciler-set ``failed`` (or a prior ``complete``).
The return value discriminates "I updated the row" from "the
row was already at a terminal state" so the abort handler can
journal ``end`` when the normal-path finalize already ran.
"""
if not looks_like_uuid(message_id):
return MessageUpdateOutcome.INVALID
return False
allowed = {
"prompt", "response", "thought", "sources", "tool_calls",
"attachments", "model_id", "metadata", "timestamp", "status",
@@ -736,7 +715,7 @@ class ConversationsRepository:
}
filtered = {k: v for k, v in fields.items() if k in allowed}
if not filtered:
return MessageUpdateOutcome.INVALID
return False
api_to_col = {"metadata": "message_metadata"}
@@ -773,44 +752,15 @@ class ConversationsRepository:
params[col] = val
set_parts.append("updated_at = now()")
update_where = ["id = CAST(:id AS uuid)"]
where_clauses = ["id = CAST(:id AS uuid)"]
if only_if_non_terminal:
update_where.append("status NOT IN ('complete', 'failed')")
# Single-statement attempt + prior-status probe. Both CTEs see
# the same MVCC snapshot, so ``prior.status`` reflects the row
# state before the UPDATE — exactly what we need to tell
# ``ALREADY_COMPLETE`` apart from ``ALREADY_FAILED`` apart from
# ``NOT_FOUND`` without a follow-up SELECT.
where_clauses.append("status NOT IN ('complete', 'failed')")
sql = (
"WITH attempted AS ("
f" UPDATE conversation_messages SET {', '.join(set_parts)} "
f" WHERE {' AND '.join(update_where)} "
" RETURNING 1 AS updated"
"), "
"prior AS ("
" SELECT status FROM conversation_messages "
" WHERE id = CAST(:id AS uuid)"
") "
"SELECT (SELECT updated FROM attempted) AS updated, "
" (SELECT status FROM prior) AS prior_status"
f"UPDATE conversation_messages SET {', '.join(set_parts)} "
f"WHERE {' AND '.join(where_clauses)}"
)
row = self._conn.execute(text(sql), params).fetchone()
if row is None:
return MessageUpdateOutcome.NOT_FOUND
updated, prior_status = row[0], row[1]
if updated:
return MessageUpdateOutcome.UPDATED
if prior_status is None:
return MessageUpdateOutcome.NOT_FOUND
if prior_status == "complete":
return MessageUpdateOutcome.ALREADY_COMPLETE
if prior_status == "failed":
return MessageUpdateOutcome.ALREADY_FAILED
# ``only_if_non_terminal=False`` always updates an existing row,
# so reaching here means the gate excluded it for some status
# the terminal set doesn't cover — treat as "not found" rather
# than inventing a new variant.
return MessageUpdateOutcome.NOT_FOUND
result = self._conn.execute(text(sql), params)
return result.rowcount > 0
def update_message_status(
self, message_id: str, status: str,

View File

@@ -41,9 +41,6 @@ class IngestChunkProgressRepository:
rows with NULL ``attempt_id`` resume against another NULL
caller (e.g. test fixtures), but get reset the moment a real
``attempt_id`` arrives.
Both branches also reset ``status`` to ``'active'``, clearing a
prior reconciler ``'stalled'`` escalation.
"""
result = self._conn.execute(
text(
@@ -71,8 +68,7 @@ class IngestChunkProgressRepository:
THEN ingest_chunk_progress.embedded_chunks
ELSE 0
END,
attempt_id = EXCLUDED.attempt_id,
status = 'active'
attempt_id = EXCLUDED.attempt_id
RETURNING *
"""
),
@@ -117,23 +113,6 @@ class IngestChunkProgressRepository:
row = result.fetchone()
return row_to_dict(row) if row is not None else None
def delete(self, source_id: str) -> bool:
    """Remove the progress row belonging to ``source_id``.

    A manual reingest supersedes any earlier ingest state — a reconciler
    ``'stalled'`` escalation included — so removing the row also clears
    the derived ``failed`` ingest status shown in the sources list.

    Args:
        source_id: Source identifier; stringified and cast to ``uuid``
            in SQL.

    Returns:
        ``True`` when a row was removed, ``False`` otherwise.
    """
    stmt = text(
        "DELETE FROM ingest_chunk_progress "
        "WHERE source_id = CAST(:source_id AS uuid)"
    )
    outcome = self._conn.execute(stmt, {"source_id": str(source_id)})
    return outcome.rowcount > 0
def bump_heartbeat(self, source_id: str) -> None:
"""Refresh ``last_updated`` so the row looks alive to the reconciler."""
self._conn.execute(

View File

@@ -1,248 +0,0 @@
"""Repository for ``message_events`` — the chat-stream snapshot journal.
``record`` / ``bulk_record`` write per-yield events; ``read_after``
replays rows past a cursor for reconnect snapshots. Composite PK
``(message_id, sequence_no)`` raises ``IntegrityError`` on duplicates.
Callers must use short-lived per-call transactions — long-lived
transactions hide writes from reconnecting clients on a separate
connection and turn one bad row into ``InFailedSqlTransaction``.
"""
from __future__ import annotations
import json
import logging
from typing import Any, Optional
from sqlalchemy import Connection, text
from application.storage.db.base_repository import row_to_dict
logger = logging.getLogger(__name__)
class MessageEventsRepository:
    """Thin SQL access layer over the ``message_events`` journal table."""

    def __init__(self, conn: Connection) -> None:
        """Keep ``conn``; every method runs its statement on it."""
        self._conn = conn

    @staticmethod
    def _insert_params(message_id, sequence_no, event_type, payload) -> dict:
        """Build the bind parameters for one ``message_events`` INSERT."""
        return {
            "message_id": str(message_id),
            "sequence_no": int(sequence_no),
            "event_type": event_type,
            # ``None`` is stored as an empty JSON object so the NOT NULL
            # payload column is always satisfied.
            "payload": json.dumps({} if payload is None else payload),
        }

    def record(
        self,
        message_id: str,
        sequence_no: int,
        event_type: str,
        payload: Optional[Any] = None,
    ) -> None:
        """Insert one journal event.

        Any non-``None`` ``payload`` is serialised as given (dicts, lists
        and scalars alike); ``None`` becomes ``{}``. The streaming wrapper
        ``application/streaming/message_journal.py::record_event`` narrows
        this to dicts only, while direct callers of this method keep the
        wider JSON-compatible surface.

        Raises:
            ValueError: If ``event_type`` is empty.
            sqlalchemy.exc.IntegrityError: On a duplicate
                ``(message_id, sequence_no)``.
            sqlalchemy.exc.DataError: On a malformed ``message_id`` UUID.

        Both database errors abort the surrounding transaction, so callers
        should run each event in its own short-lived session.
        """
        if not event_type:
            raise ValueError("event_type must be a non-empty string")

        bind = self._insert_params(message_id, sequence_no, event_type, payload)
        self._conn.execute(
            text(
                """
                INSERT INTO message_events (
                    message_id, sequence_no, event_type, payload
                ) VALUES (
                    CAST(:message_id AS uuid), :sequence_no, :event_type,
                    CAST(:payload AS jsonb)
                )
                """
            ),
            bind,
        )

    def bulk_record(
        self,
        message_id: str,
        events: list[tuple[int, str, dict]],
    ) -> None:
        """Insert several events for ``message_id`` with one execute call.

        ``events`` holds ``(sequence_no, event_type, payload)`` tuples and
        is passed to the connection as a list of parameter sets
        (``executemany``). An empty list is a no-op.

        Caller contract: after an ``IntegrityError`` do NOT retry the same
        batch through this method. Fall back to one ``record()`` call per
        row, each in its own short-lived session, so a single colliding
        sequence number does not drop the rest. ``BatchedJournalWriter`` in
        ``application/streaming/message_journal.py`` is the canonical
        consumer.
        """
        if not events:
            return

        batch = [
            self._insert_params(message_id, seq, kind, body)
            for seq, kind, body in events
        ]
        self._conn.execute(
            text(
                """
                INSERT INTO message_events (
                    message_id, sequence_no, event_type, payload
                ) VALUES (
                    CAST(:message_id AS uuid), :sequence_no, :event_type,
                    CAST(:payload AS jsonb)
                )
                """
            ),
            batch,
        )

    def read_after(
        self,
        message_id: str,
        last_sequence_no: Optional[int] = None,
    ) -> list[dict]:
        """Fetch events whose ``sequence_no`` is above ``last_sequence_no``.

        Passing ``None`` uses ``-1`` as the lower bound. Rows come back in
        ascending ``sequence_no`` order as plain dicts.

        The result is a fully built ``list`` rather than a generator, so
        the underlying result set is drained before the caller can issue
        another query on the same connection.
        """
        lower_bound = int(last_sequence_no) if last_sequence_no is not None else -1
        fetched = self._conn.execute(
            text(
                """
                SELECT message_id, sequence_no, event_type, payload, created_at
                FROM message_events
                WHERE message_id = CAST(:message_id AS uuid)
                  AND sequence_no > :cursor
                ORDER BY sequence_no ASC
                """
            ),
            {"message_id": str(message_id), "cursor": lower_bound},
        ).fetchall()

        events = []
        for record in fetched:
            events.append(row_to_dict(record))
        return events

    def cleanup_older_than(self, ttl_days: int) -> int:
        """Purge journal rows created more than ``ttl_days`` days ago.

        Old rows are only useful for reconnect replay of streams a client
        could still be waiting on, so anything past the TTL is dead weight.

        Returns:
            The number of rows deleted (``0`` when the driver reports none).

        Raises:
            ValueError: If ``ttl_days`` is zero or negative.
        """
        if ttl_days <= 0:
            raise ValueError("ttl_days must be positive")

        outcome = self._conn.execute(
            text(
                """
                DELETE FROM message_events
                WHERE created_at < now() - make_interval(days => :ttl_days)
                """
            ),
            {"ttl_days": int(ttl_days)},
        )
        deleted = outcome.rowcount or 0
        return int(deleted)

    def reconstruct_partial(self, message_id: str) -> dict:
        """Rebuild partial response/thought/sources/tool_calls from the journal.

        ``answer`` and ``thought`` events contribute string chunks that are
        concatenated in ``sequence_no`` order. ``source`` and ``tool_calls``
        events each carry a whole list, and the last one seen wins. Rows
        whose payload is not a dict, and values of the wrong type, are
        ignored.

        Returns:
            A dict with keys ``response``, ``thought``, ``sources`` and
            ``tool_calls``.
        """
        journal = self._conn.execute(
            text(
                """
                SELECT sequence_no, event_type, payload
                FROM message_events
                WHERE message_id = CAST(:message_id AS uuid)
                ORDER BY sequence_no ASC
                """
            ),
            {"message_id": str(message_id)},
        ).fetchall()

        # For all four event types the payload key equals the event type.
        chunks: dict = {"answer": [], "thought": []}
        latest: dict = {"source": [], "tool_calls": []}

        for entry in journal:
            body = entry.payload
            if not isinstance(body, dict):
                continue
            kind = entry.event_type
            if kind in ("answer", "thought"):
                piece = body.get(kind)
                if isinstance(piece, str):
                    chunks[kind].append(piece)
            elif kind in ("source", "tool_calls"):
                snapshot = body.get(kind)
                if isinstance(snapshot, list):
                    latest[kind] = snapshot

        return {
            "response": "".join(chunks["answer"]),
            "thought": "".join(chunks["thought"]),
            "sources": latest["source"],
            "tool_calls": latest["tool_calls"],
        }

    def latest_sequence_no(self, message_id: str) -> Optional[int]:
        """Return the highest ``sequence_no`` stored for ``message_id``.

        Returns ``None`` when the journal holds nothing for that message.
        The route uses this to seed its per-stream allocator after a retry
        or process restart, so numbering continues instead of colliding
        with earlier entries.
        """
        top = self._conn.execute(
            text(
                """
                SELECT MAX(sequence_no) AS s
                FROM message_events
                WHERE message_id = CAST(:message_id AS uuid)
                """
            ),
            {"message_id": str(message_id)},
        ).first()
        # ``MAX`` yields a single row whose value is NULL for an empty
        # journal, so the value — not just the row — has to be checked.
        if top is None or top[0] is None:
            return None
        return int(top[0])

View File

@@ -107,11 +107,7 @@ class ReconciliationRepository:
def find_and_lock_stalled_ingests(
self, *, age_minutes: int = 30, limit: int = 100,
) -> list[dict]:
"""Lock still-active ingest checkpoints with a silent heartbeat.
The ``status = 'active'`` filter skips rows already escalated to
``'stalled'``, so a dead ingest is alerted once, not every tick.
"""
"""Lock ingest checkpoints whose heartbeat hasn't ticked recently."""
result = self._conn.execute(
text(
"""
@@ -120,7 +116,6 @@ class ReconciliationRepository:
FROM ingest_chunk_progress
WHERE last_updated < now() - make_interval(mins => :age)
AND embedded_chunks < total_chunks
AND status = 'active'
ORDER BY last_updated ASC
LIMIT :limit
FOR UPDATE SKIP LOCKED
@@ -130,15 +125,11 @@ class ReconciliationRepository:
)
return [row_to_dict(r) for r in result.fetchall()]
def mark_ingest_stalled(self, source_id: str) -> bool:
"""Escalate a stalled checkpoint to terminal ``status='stalled'``.
Drops the row out of the sweep so the reconciler alerts once;
``init_progress`` flips it back to ``'active'`` on reingest.
"""
def touch_ingest_progress(self, source_id: str) -> bool:
"""Bump ``last_updated`` so a once-stalled ingest re-enters the watch window."""
result = self._conn.execute(
text(
"UPDATE ingest_chunk_progress SET status = 'stalled' "
"UPDATE ingest_chunk_progress SET last_updated = now() "
"WHERE source_id = CAST(:sid AS uuid)"
),
{"sid": str(source_id)},

View File

@@ -5,10 +5,10 @@ from __future__ import annotations
import json
from typing import Any, Optional
from sqlalchemy import case, Connection, func, select, text
from sqlalchemy import Connection, func, select, text
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.models import ingest_chunk_progress_table, sources_table
from application.storage.db.models import sources_table
_SCALAR_COLUMNS = {
@@ -61,21 +61,6 @@ def _coerce_jsonb(value: Any) -> Any:
return value
def _ingest_status_case():
    """Build the labelled ``ingest_status`` CASE over the progress table.

    Branches are checked in order:

    - no progress row joined (``source_id`` IS NULL) -> ``None``
    - ``status`` is ``'stalled'`` (reconciler escalation) -> ``"failed"``
    - fewer embedded chunks than total chunks -> ``"processing"``
    - anything else (embed completed) -> ``None``
    """
    cols = ingest_chunk_progress_table.c
    branches = [
        (cols.source_id.is_(None), None),
        (cols.status == "stalled", "failed"),
        (cols.embedded_chunks < cols.total_chunks, "processing"),
    ]
    status_expr = case(*branches, else_=None)
    return status_expr.label("ingest_status")
class SourcesRepository:
def __init__(self, conn: Connection) -> None:
self._conn = conn
@@ -207,25 +192,13 @@ class SourcesRepository:
as ``"desc"``.
Returns:
A list of source rows as plain dicts (via ``row_to_dict``),
each carrying a derived ``ingest_status`` (``failed`` /
``processing`` / ``None``) from the joined progress row.
A list of source rows as plain dicts (via ``row_to_dict``).
"""
column_name = sort_field if sort_field in _SORTABLE_COLUMNS else "date"
sort_column = sources_table.c[column_name]
ascending = sort_order.lower() == "asc"
stmt = (
select(sources_table, _ingest_status_case())
.select_from(
sources_table.outerjoin(
ingest_chunk_progress_table,
ingest_chunk_progress_table.c.source_id
== sources_table.c.id,
)
)
.where(sources_table.c.user_id == user_id)
)
stmt = select(sources_table).where(sources_table.c.user_id == user_id)
if search_term:
stmt = stmt.where(
sources_table.c.name.ilike(

View File

@@ -63,8 +63,7 @@ class ToolCallAttemptsRepository:
message_id: Optional[str] = None,
artifact_id: Optional[str] = None,
) -> None:
"""Insert OR upgrade a row to ``executed`` — or ``confirmed`` when
there is no ``message_id``, as in ``mark_executed``.
"""Insert OR upgrade a row to ``executed``.
Used as a fallback when ``record_proposed`` failed (DB outage)
and the tool ran anyway — preserves the journal so the
@@ -73,7 +72,6 @@ class ToolCallAttemptsRepository:
result_payload: dict = {"result": result}
if artifact_id:
result_payload["artifact_id"] = artifact_id
status = "executed" if message_id is not None else "confirmed"
self._conn.execute(
text(
"""
@@ -84,9 +82,9 @@ class ToolCallAttemptsRepository:
(:call_id, CAST(:tool_id AS uuid), :tool_name,
:action_name, CAST(:arguments AS jsonb),
CAST(:result AS jsonb), CAST(:message_id AS uuid),
:status)
'executed')
ON CONFLICT (call_id) DO UPDATE
SET status = :status,
SET status = 'executed',
result = EXCLUDED.result,
message_id = COALESCE(EXCLUDED.message_id, tool_call_attempts.message_id)
"""
@@ -99,7 +97,6 @@ class ToolCallAttemptsRepository:
"arguments": json.dumps(arguments if arguments is not None else {}, cls=PGNativeJSONEncoder),
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
"message_id": message_id,
"status": status,
},
)
@@ -111,9 +108,7 @@ class ToolCallAttemptsRepository:
message_id: Optional[str] = None,
artifact_id: Optional[str] = None,
) -> bool:
"""Flip ``proposed`` → ``executed``, or straight to ``confirmed``
when there is no ``message_id`` (a ``save_conversation=False``
request reserves no message, so no finalize will confirm it).
"""Flip ``proposed`` → ``executed`` with the tool result.
``artifact_id`` (when present) is stored alongside ``result`` in
the JSONB as audit data — the reconciler reads it for diagnostic
@@ -122,14 +117,12 @@ class ToolCallAttemptsRepository:
result_payload: dict = {"result": result}
if artifact_id:
result_payload["artifact_id"] = artifact_id
status = "executed" if message_id is not None else "confirmed"
sql = (
"UPDATE tool_call_attempts SET "
"status = :status, result = CAST(:result AS jsonb)"
"status = 'executed', result = CAST(:result AS jsonb)"
)
params: dict[str, Any] = {
"call_id": call_id,
"status": status,
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
}
if message_id is not None:

View File

@@ -1,23 +0,0 @@
"""Deterministic source-id derivation for idempotent ingest.
DO NOT CHANGE the pinned UUID namespace — it backs cross-deploy
idempotency keys.
"""
from __future__ import annotations
import uuid
# DO NOT CHANGE. See module docstring.
DOCSGPT_INGEST_NAMESPACE = uuid.UUID("fa25d5d1-398b-46df-ac89-8d1c360b9bea")


def derive_source_id(idempotency_key) -> uuid.UUID:
    """Derive a source id from ``idempotency_key``.

    A non-empty string key maps deterministically to
    ``uuid5(DOCSGPT_INGEST_NAMESPACE, key)``. Anything else (``None``, an
    empty string, a non-string value) yields a random ``uuid4()``, so the
    caller always receives an id instead of hitting a ``TypeError``
    mid-route.
    """
    has_usable_key = isinstance(idempotency_key, str) and bool(idempotency_key)
    if not has_usable_key:
        return uuid.uuid4()
    return uuid.uuid5(DOCSGPT_INGEST_NAMESPACE, idempotency_key)

View File

@@ -1,126 +0,0 @@
"""Redis pub/sub Topic abstraction for SSE fan-out.
A Topic is a named channel for one-shot live event delivery. Canonical uses:
- ``user:{user_id}`` for per-user notifications
- ``channel:{message_id}`` for per-chat-message streams
Subscription is race-free via ``on_subscribe``: the callback fires only
after Redis acknowledges ``SUBSCRIBE``, so a publisher dispatched inside
the callback cannot lose its first event to a not-yet-registered
subscriber.
The subscribe iterator yields ``None`` on poll timeout so the caller can
emit SSE keepalive comments without spawning a separate timer thread.
"""
from __future__ import annotations
import logging
from typing import Callable, Iterator, Optional
from application.cache import get_redis_instance
logger = logging.getLogger(__name__)
class Topic:
"""A pub/sub channel identified by a string name."""
def __init__(self, name: str) -> None:
"""Store ``name``, the channel string this topic publishes and subscribes on."""
self.name = name
def publish(self, payload: str | bytes) -> int:
    """Send ``payload`` to whoever is currently subscribed to this topic.

    Never raises: a missing Redis client or a failing publish call both
    result in ``0``.

    Args:
        payload: The message body handed to Redis ``PUBLISH``.

    Returns:
        The receiver count reported by Redis (only subscribers connected
        to this Redis instance are counted), or ``0`` when Redis is
        unavailable or the publish failed.
    """
    client = get_redis_instance()
    if client is not None:
        try:
            delivered = int(client.publish(self.name, payload))
        except Exception:
            # Best-effort fan-out: log with traceback and report nothing
            # delivered instead of propagating to the caller.
            logger.exception("Topic.publish failed for %s", self.name)
            delivered = 0
        return delivered

    logger.debug("Redis unavailable; dropping publish to %s", self.name)
    return 0
def subscribe(
self,
on_subscribe: Optional[Callable[[], None]] = None,
poll_timeout: float = 1.0,
) -> Iterator[Optional[bytes]]:
"""Subscribe to the topic; yield raw payloads or ``None`` on tick.
Yields ``None`` every ``poll_timeout`` seconds while idle so the
caller can emit keepalive frames or check cancellation. Yields
``bytes`` for each delivered message.
``on_subscribe`` runs synchronously after Redis acknowledges the
SUBSCRIBE — use it to seed any state (e.g. read backlog) that
must be ordered after the subscriber is live but before the
first pub/sub message is processed.
If Redis is unavailable, returns immediately without yielding.
Cleanly unsubscribes on ``GeneratorExit`` (client disconnect).
"""
redis = get_redis_instance()
if redis is None:
logger.debug("Redis unavailable; subscribe to %s yielded nothing", self.name)
return
pubsub = None
on_subscribe_fired = False
try:
pubsub = redis.pubsub()
try:
pubsub.subscribe(self.name)
except Exception:
# Subscribe failure (transient Redis hiccup, conn reset, etc.)
# is treated like "Redis unavailable": yield nothing, let the
# caller fall back to its own resilience strategy. The finally
# block will still tear down the pubsub object cleanly.
logger.exception("pubsub.subscribe failed for %s", self.name)
return
while True:
try:
msg = pubsub.get_message(timeout=poll_timeout)
except Exception:
logger.exception("pubsub.get_message failed for %s", self.name)
return
if msg is None:
yield None
continue
msg_type = msg.get("type")
if msg_type == "subscribe":
if not on_subscribe_fired and on_subscribe is not None:
try:
on_subscribe()
except Exception:
logger.exception(
"on_subscribe callback failed for %s", self.name
)
on_subscribe_fired = True
continue
if msg_type != "message":
continue
data = msg.get("data")
if data is None:
continue
yield data if isinstance(data, bytes) else str(data).encode("utf-8")
finally:
if pubsub is not None:
if on_subscribe_fired:
try:
pubsub.unsubscribe(self.name)
except Exception:
logger.debug(
"pubsub unsubscribe error for %s",
self.name,
exc_info=True,
)
try:
pubsub.close()
except Exception:
logger.debug("pubsub close error for %s", self.name, exc_info=True)

View File

@@ -1,434 +0,0 @@
"""Snapshot+tail iterator for chat-stream reconnect-after-disconnect.
Subscribe to ``channel:{message_id}``, snapshot ``message_events``
rows past ``last_event_id`` inside the SUBSCRIBE-ack callback, flush
snapshot, then tail live pub/sub (dedup'd by ``sequence_no``). See
``docs/runbooks/sse-notifications.md``.
"""
from __future__ import annotations
import json
import logging
import re
import time
from typing import Iterator, Optional
from sqlalchemy import text as sql_text
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
from application.storage.db.session import db_readonly
from application.streaming.broadcast_channel import Topic
from application.streaming.keys import message_topic_name
logger = logging.getLogger(__name__)
# Default minimum idle time between ``: keepalive`` SSE comments emitted by
# ``build_message_event_stream``.
DEFAULT_KEEPALIVE_SECONDS = 15.0
# Default ``poll_timeout`` handed to ``Topic.subscribe``; bounds how long
# the stream loop waits before it gets an idle tick.
DEFAULT_POLL_TIMEOUT_SECONDS = 1.0
# When the live tail has no events and no terminal in snapshot, fall
# back to checking ``conversation_messages`` directly. If the row has
# already gone terminal (worker journaled ``end``/``error`` to the DB
# but the matching pub/sub publish was lost, or the row was finalized
# without a journal write at all) we surface a terminal event so the
# client doesn't hang on keepalives. If the row is still non-terminal
# but the producer heartbeat is older than ``PRODUCER_IDLE_SECONDS``
# the producer is presumed dead (worker crash / recycle between chunks
# and finalize) and we emit a terminal ``error`` so the UI can recover.
DEFAULT_WATCHDOG_INTERVAL_SECONDS = 5.0
# 1.5× the route's 60s heartbeat interval — long enough that a normal
# heartbeat skew doesn't false-positive, short enough that a stuck
# stream surfaces before the 5-minute reconciler sweep escalates.
DEFAULT_PRODUCER_IDLE_SECONDS = 90.0
# WHATWG SSE accepts CRLF, CR, LF — split on any of them so a stray CR
# can't smuggle a record boundary into the wire format.
_SSE_LINE_SPLIT_PATTERN = re.compile(r"\r\n|\r|\n")
# Event types that mark the end of a chat answer. After delivering one
# we close the reconnect stream — keeping the connection open past a
# terminal event would leak both the client's reconnect promise and
# the server's WSGI thread waiting on keepalives that the user no
# longer cares about. The agent loop emits ``end`` for normal /
# tool-paused completion and ``error`` for the catch-all failure path
# (which doesn't get a trailing ``end``).
_TERMINAL_EVENT_TYPES = frozenset({"end", "error"})
def _payload_is_terminal(
    payload: object, event_type: Optional[str] = None
) -> bool:
    """True if ``payload['type']`` or ``event_type`` is a terminal sentinel.

    Args:
        payload: Event payload; only inspected when it is a ``dict``.
        event_type: Optional event type carried alongside the payload.

    Returns:
        ``True`` when either candidate is a string found in
        ``_TERMINAL_EVENT_TYPES``, else ``False``.

    Only ``str`` candidates are tested for membership. A malformed frame
    whose ``type`` / ``event_type`` decoded to an unhashable JSON value
    (list or dict) would otherwise raise ``TypeError`` from the frozenset
    lookup and abort the stream; it is treated as non-terminal instead.
    """
    candidates = (
        payload.get("type") if isinstance(payload, dict) else None,
        event_type,
    )
    return any(
        isinstance(candidate, str) and candidate in _TERMINAL_EVENT_TYPES
        for candidate in candidates
    )
def format_sse_event(payload: dict, sequence_no: int) -> str:
    """Render one journal event as an SSE record with ``id:`` and ``data:``.

    The payload is serialised with ``json.dumps`` and emitted as one
    ``data:`` line per line of the serialised body. The body is split on
    any line terminator (CRLF, CR, LF) so an embedded terminator can
    never introduce an extra record boundary. The record ends with a
    blank line.
    """
    serialised = json.dumps(payload)
    data_lines = [
        f"data: {chunk}" for chunk in _SSE_LINE_SPLIT_PATTERN.split(serialised)
    ]
    record = "\n".join([f"id: {sequence_no}", *data_lines])
    return record + "\n\n"
def _check_producer_liveness(
    message_id: str, idle_seconds: float
) -> Optional[dict]:
    """Return a terminal SSE payload if the producer is done or dead.

    Reads the ``conversation_messages`` row for ``message_id`` in a single
    query and maps it to one of:

    - row missing — ``error`` payload with code ``message_missing``.
    - ``status == 'complete'`` — ``{"type": "end"}``.
    - ``status == 'failed'`` — ``error`` payload with code
      ``producer_failed``, carrying ``message_metadata->>'error'`` when
      present.
    - any other status whose newest of ``timestamp`` /
      ``last_heartbeat_at`` is older than ``idle_seconds`` — ``error``
      payload with code ``producer_stale``.

    Returns ``None`` when the producer still looks alive, or when the
    query itself fails (the failure is logged, not raised).
    """

    def _terminal_error(text: str, code: str) -> dict:
        # Shared shape for every synthetic ``error`` payload built here.
        return {
            "type": "error",
            "error": text,
            "code": code,
            "message_id": message_id,
        }

    try:
        with db_readonly() as conn:
            liveness_query = sql_text(
                """
SELECT
status,
message_metadata->>'error' AS err,
GREATEST(
timestamp,
COALESCE(
(message_metadata->>'last_heartbeat_at')
::timestamptz,
timestamp
)
) < now() - make_interval(secs => :idle_secs)
AS is_stale
FROM conversation_messages
WHERE id = CAST(:id AS uuid)
"""
            )
            bind_params = {"id": message_id, "idle_secs": float(idle_seconds)}
            row = conn.execute(liveness_query, bind_params).first()
    except Exception:
        logger.exception(
            "Watchdog liveness check failed for message_id=%s", message_id
        )
        return None

    if row is None:
        # Row deleted out from under us — treat as terminal so the
        # client doesn't keep tailing a message that no longer exists.
        return _terminal_error(
            "Message no longer exists; please refresh.", "message_missing"
        )

    status, err, is_stale = row[0], row[1], bool(row[2])
    if status == "complete":
        return {"type": "end"}
    if status == "failed":
        return _terminal_error(
            err or "Stream failed; please try again.", "producer_failed"
        )
    if not is_stale:
        return None
    return _terminal_error(
        "Stream producer is no longer responding; please try again.",
        "producer_stale",
    )
def build_message_event_stream(
    message_id: str,
    last_event_id: Optional[int] = None,
    *,
    keepalive_seconds: float = DEFAULT_KEEPALIVE_SECONDS,
    poll_timeout_seconds: float = DEFAULT_POLL_TIMEOUT_SECONDS,
    watchdog_interval_seconds: float = DEFAULT_WATCHDOG_INTERVAL_SECONDS,
    producer_idle_seconds: float = DEFAULT_PRODUCER_IDLE_SECONDS,
) -> Iterator[str]:
    """Yield SSE-formatted frames for one ``message_id`` reconnect stream.

    First frame is ``: connected``; subsequent frames are snapshot rows,
    live-tail events, or ``: keepalive`` comments.

    Args:
        message_id: Message whose journal and pub/sub topic are streamed.
        last_event_id: Client cursor; snapshot rows are read after it and
            live events with ``sequence_no <=`` the current floor are
            dropped. ``None`` means no cursor.
        keepalive_seconds: Minimum idle time between ``: keepalive``
            comments.
        poll_timeout_seconds: Passed to ``Topic.subscribe`` as
            ``poll_timeout``.
        watchdog_interval_seconds: Minimum time between
            ``_check_producer_liveness`` calls on idle ticks; a negative
            value disables the watchdog.
        producer_idle_seconds: Staleness threshold handed to
            ``_check_producer_liveness``.

    The generator ends when any of the following happens: a terminal
    (``end`` / ``error``) event is delivered from the snapshot or the
    live tail; the snapshot read fails (a synthetic ``error`` with
    ``id: -1`` is emitted first); the watchdog reports a terminal state
    (likewise emitted with ``id: -1``); the underlying subscription
    iterator ends; or the client disconnects (``GeneratorExit``).
    """
    yield ": connected\n\n"
    # Replay buffer — populated inside ``_on_subscribe`` (or the
    # Redis-unavailable fallback below), drained on the first iteration
    # of the subscribe loop after the callback runs.
    replay_buffer: list[str] = []
    # Dedup floor: seeded with the client's cursor so an empty snapshot
    # still rejects re-published live events with seq <= last_event_id.
    # Advanced by snapshot rows AND by yielded live events, so any
    # republish past the snapshot ceiling is also dropped.
    max_replayed_seq: Optional[int] = last_event_id
    replay_done = False
    replay_failed = False
    # Set when a snapshot row carries a terminal ``end`` / ``error``
    # event. After flushing the buffer the generator returns; if we
    # kept tailing we'd loop on keepalives forever for a stream that
    # already finished.
    terminal_in_snapshot = False

    def _read_snapshot_into_buffer() -> None:
        # Load journal rows past ``last_event_id`` into ``replay_buffer``,
        # advancing the dedup floor and noting a terminal row. Any failure
        # is logged and recorded in ``replay_failed`` instead of raised.
        nonlocal max_replayed_seq, replay_failed, terminal_in_snapshot
        try:
            with db_readonly() as conn:
                rows = MessageEventsRepository(conn).read_after(
                    message_id, last_sequence_no=last_event_id
                )
                for row in rows:
                    seq = int(row["sequence_no"])
                    payload = row.get("payload")
                    if not isinstance(payload, dict):
                        # ``record_event`` rejects non-dict payloads at the
                        # write gate, so this can only be a legacy row from
                        # before that contract or a direct SQL insert. The
                        # original synthetic fallback (``{"type": event_type}``)
                        # used to ship a malformed envelope here — drop the
                        # row instead so a corrupt journal entry doesn't
                        # poison a reconnect.
                        logger.warning(
                            "Skipping non-dict payload from message_events: "
                            "message_id=%s seq=%s type=%s",
                            message_id,
                            seq,
                            row.get("event_type"),
                        )
                        continue
                    replay_buffer.append(format_sse_event(payload, seq))
                    if max_replayed_seq is None or seq > max_replayed_seq:
                        max_replayed_seq = seq
                    if _payload_is_terminal(payload, row.get("event_type")):
                        terminal_in_snapshot = True
        except Exception:
            logger.exception(
                "Snapshot read failed for message_id=%s last_event_id=%s",
                message_id,
                last_event_id,
            )
            replay_failed = True

    def _on_subscribe() -> None:
        # SUBSCRIBE has been acked — Postgres reads from this point
        # capture every row that's been committed. Pub/sub messages
        # published after this point are queued at the connection level
        # until the outer loop calls ``get_message`` again.
        nonlocal replay_done
        try:
            _read_snapshot_into_buffer()
        finally:
            # Flip even on failure so the outer loop continues to live
            # tail and the client doesn't hang waiting for a snapshot
            # flush that will never come.
            replay_done = True

    topic = Topic(message_topic_name(message_id))
    last_keepalive = time.monotonic()
    # Rate-limit the watchdog's DB hit. ``-inf`` makes the first idle
    # tick after replay_done fire immediately so a snapshot-already-
    # terminal-in-DB case is surfaced before any keepalive cadence.
    # Subsequent checks are gated by ``watchdog_interval_seconds``.
    last_watchdog_check = float("-inf")
    # Synthetic terminal events emitted by the watchdog use the same
    # ``sequence_no=-1`` convention as the snapshot-failure path so the
    # frontend's strict ``\d+`` cursor regex rejects them as a
    # ``Last-Event-ID`` for any future reconnect. The chosen
    # discriminator ensures a manual page refresh after a watchdog-fired
    # error doesn't loop on the same synthetic id.
    watchdog_synthetic_seq = -1
    try:
        for payload in topic.subscribe(
            on_subscribe=_on_subscribe,
            poll_timeout=poll_timeout_seconds,
        ):
            # Flush snapshot exactly once after the SUBSCRIBE callback
            # has run and produced a buffer.
            if replay_done and replay_buffer:
                for line in replay_buffer:
                    yield line
                replay_buffer.clear()
                if terminal_in_snapshot:
                    # The original stream already finished; tailing
                    # would just emit keepalives forever and pin both a
                    # client reconnect promise and a server WSGI thread.
                    return
            if replay_failed:
                # Snapshot read failed (DB blip / transient timeout). Emit a
                # terminal ``error`` event and return — the client only
                # reconnects after the original stream has already moved on,
                # so without a snapshot there's nothing live left to tail and
                # holding the connection open would just emit keepalives
                # until the proxy idle-timeout fires. ``code`` preserves the
                # snapshot-vs-agent-loop distinction so a future client can
                # opt into a refetch instead of a hard failure.
                yield format_sse_event(
                    {
                        "type": "error",
                        "error": "Stream replay failed; please refresh to load the latest state.",
                        "code": "snapshot_failed",
                        "message_id": message_id,
                    },
                    sequence_no=-1,
                )
                return
            now = time.monotonic()
            if payload is None:
                # Idle tick — check both keepalive and watchdog. The
                # watchdog only kicks in once the snapshot half has been
                # flushed (``replay_done``) so we don't race the
                # snapshot read on the first iteration.
                if (
                    replay_done
                    and watchdog_interval_seconds >= 0
                    and now - last_watchdog_check >= watchdog_interval_seconds
                ):
                    last_watchdog_check = now
                    terminal_payload = _check_producer_liveness(
                        message_id, producer_idle_seconds
                    )
                    if terminal_payload is not None:
                        yield format_sse_event(
                            terminal_payload,
                            sequence_no=watchdog_synthetic_seq,
                        )
                        return
                if now - last_keepalive >= keepalive_seconds:
                    yield ": keepalive\n\n"
                    last_keepalive = now
                continue
            envelope = _decode_pubsub_message(payload)
            if envelope is None:
                continue
            seq = envelope.get("sequence_no")
            inner = envelope.get("payload")
            # ``bool`` is a subclass of ``int``; reject it explicitly so
            # ``true`` / ``false`` can never pass as a sequence number.
            if (
                not isinstance(seq, int)
                or isinstance(seq, bool)
                or not isinstance(inner, dict)
            ):
                continue
            if max_replayed_seq is not None and seq <= max_replayed_seq:
                # Snapshot already covered this id — drop the duplicate.
                continue
            yield format_sse_event(inner, seq)
            # Advance the dedup floor on the live path too, so a stale
            # republish of an already-yielded seq (process restart, retry
            # tool, etc.) is dropped on a later iteration.
            max_replayed_seq = seq
            # A delivered event counts as activity for keepalive purposes.
            last_keepalive = now
            if _payload_is_terminal(inner, envelope.get("event_type")):
                # Live tail just delivered the terminal event — close
                # out the reconnect stream so the client's drain
                # promise resolves and the WSGI thread is freed.
                return
        # Subscribe exited without ever yielding (Redis unavailable,
        # ``pubsub.subscribe`` raised, or the inner loop died between
        # SUBSCRIBE-ack and the first poll). The snapshot half is in
        # Postgres and is still serviceable — read it directly so a
        # Redis-only outage doesn't cost the client their reconnect
        # backlog. Gate the read on ``replay_done`` rather than
        # ``subscribe_started``: if ``_on_subscribe`` already populated
        # the buffer, re-reading would append the same rows twice and
        # double the answer chunks on the client (the per-message
        # reconnect dispatcher does not dedup by ``id``).
        if not replay_done:
            _read_snapshot_into_buffer()
            replay_done = True
        for line in replay_buffer:
            yield line
        replay_buffer.clear()
        if replay_failed:
            # Mirror the live-tail branch: emit a terminal ``error`` so
            # the frontend's existing end/error handling drives the UI
            # to a failed state instead of relying on the proxy timeout.
            yield format_sse_event(
                {
                    "type": "error",
                    "error": "Stream replay failed; please refresh to load the latest state.",
                    "code": "snapshot_failed",
                    "message_id": message_id,
                },
                sequence_no=-1,
            )
            return
        # Same close-on-terminal contract as the live-tail branch.
        # Without it a Redis-down + already-completed-stream client
        # would also hang on a never-ending generator.
        if terminal_in_snapshot:
            return
    except GeneratorExit:
        # Client disconnect — let the underlying ``Topic.subscribe``
        # ``finally`` block tear down its pubsub cleanly.
        return
def _decode_pubsub_message(raw) -> Optional[dict]:
"""Parse a ``Topic.publish`` payload to ``{sequence_no, payload, ...}``.
Returns ``None`` for malformed messages (drop silently — the
journal is still authoritative on reconnect).
"""
try:
if isinstance(raw, (bytes, bytearray)):
text_value = raw.decode("utf-8")
else:
text_value = str(raw)
envelope = json.loads(text_value)
except Exception:
return None
if not isinstance(envelope, dict):
return None
return envelope
def encode_pubsub_message(
    message_id: str,
    sequence_no: int,
    event_type: str,
    payload: dict,
) -> str:
    """Serialise the JSON envelope published on ``channel:{message_id}``.

    The envelope carries ``message_id`` (stringified), ``sequence_no``
    (coerced to ``int``), ``event_type`` and ``payload``, in that key
    order. It lives next to ``_decode_pubsub_message`` so the writer and
    the reader of the wire shape stay in one file.
    """
    envelope = {
        "message_id": str(message_id),
        "sequence_no": int(sequence_no),
        "event_type": event_type,
        "payload": payload,
    }
    return json.dumps(envelope)

View File

@@ -1,19 +0,0 @@
"""Per-chat-message stream key derivations.
Single source of truth for the Redis pub/sub topic name and any
auxiliary keys that the chat-stream snapshot+tail reconnect path
shares between the writer (``complete_stream`` + journal) and the
reader (``/api/messages/<id>/events`` reconnect endpoint).
"""
from __future__ import annotations
def message_topic_name(message_id: str) -> str:
    """Return the Redis pub/sub channel name for one chat message.

    The name is ``channel:`` followed by ``message_id``. Writer and
    reader both derive the channel through this function so they agree
    on it.
    """
    return "channel:{}".format(message_id)

View File

@@ -1,400 +0,0 @@
"""Per-yield journal write for the chat-stream snapshot+tail pattern.
``record_event`` inserts into ``message_events`` and publishes to
``channel:{message_id}``. Both are best-effort; the INSERT commits
before the publish so a fast reconnect sees the row. See
``docs/runbooks/sse-notifications.md``.
"""
from __future__ import annotations
import logging
import time
from typing import Any, Optional
from sqlalchemy.exc import IntegrityError
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
from application.storage.db.session import db_readonly, db_session
from application.streaming.broadcast_channel import Topic
from application.streaming.event_replay import encode_pubsub_message
from application.streaming.keys import message_topic_name
logger = logging.getLogger(__name__)
# Tunables for ``BatchedJournalWriter``. A streaming answer emits ~100s
# of ``answer`` chunks per response; without batching, that's one PG
# transaction per yield in the WSGI thread. With these defaults, ~10x
# fewer commits at the cost of a ≤100ms reconnect-visibility lag for
# any event still sitting in the buffer.
DEFAULT_BATCH_SIZE = 16
DEFAULT_BATCH_INTERVAL_MS = 100
def _strip_null_bytes(value: Any) -> Any:
"""Recursively strip ``\\x00`` from string keys/values in ``value``.
Postgres JSONB rejects the NUL escape; an LLM emitting a stray NUL
in a chunk would otherwise raise ``DataError`` at INSERT and the row
would be lost from the journal (live stream proceeds, reconnect
snapshot misses the chunk). Mirrors the strip already done in
``parser/embedding_pipeline.py`` and
``api/user/attachments/routes.py``.
"""
if isinstance(value, str):
return value.replace("\x00", "") if "\x00" in value else value
if isinstance(value, dict):
return {
(k.replace("\x00", "") if isinstance(k, str) and "\x00" in k else k):
_strip_null_bytes(v)
for k, v in value.items()
}
if isinstance(value, list):
return [_strip_null_bytes(item) for item in value]
if isinstance(value, tuple):
return tuple(_strip_null_bytes(item) for item in value)
return value
def record_event(
    message_id: str,
    sequence_no: int,
    event_type: str,
    payload: Optional[dict[str, Any]] = None,
) -> bool:
    """Journal one SSE event and publish it live. Best-effort.

    Args:
        message_id: Message the event belongs to; must be truthy.
        sequence_no: Sequence number to write the journal row at.
        event_type: Event type label; must be truthy.
        payload: ``dict`` or ``None`` (treated as ``{}``). NUL bytes are
            stripped from nested strings before writing. Non-dicts are
            dropped so live and replay envelopes stay byte-identical.

    Returns:
        ``True`` when the journal INSERT committed (possibly at a
        different sequence number after one collision retry), ``False``
        otherwise. Never raises.

    The publish step runs after the INSERT attempt whether or not the
    INSERT committed; only the early validation failures skip it.
    """
    if not message_id or not event_type:
        logger.warning(
            "record_event called without message_id/event_type "
            "(message_id=%r, event_type=%r)",
            message_id,
            event_type,
        )
        return False
    if payload is None:
        materialised_payload: dict[str, Any] = {}
    elif isinstance(payload, dict):
        materialised_payload = _strip_null_bytes(payload)
    else:
        logger.warning(
            "record_event called with non-dict payload "
            "(message_id=%s seq=%s type=%s payload_type=%s) — dropping",
            message_id,
            sequence_no,
            event_type,
            type(payload).__name__,
        )
        return False
    journal_committed = False
    # The seq we actually managed to write. Diverges from
    # ``sequence_no`` only on the IntegrityError-retry path below.
    materialised_seq = sequence_no
    try:
        # Short-lived per-event transaction. Critical for visibility:
        # the reconnect endpoint reads the journal from a separate
        # connection and only sees committed rows.
        with db_session() as conn:
            MessageEventsRepository(conn).record(
                message_id, sequence_no, event_type, materialised_payload
            )
        journal_committed = True
    except IntegrityError:
        # Composite-PK collision on (message_id, sequence_no). Most
        # likely cause is a stale ``latest_sequence_no`` seed on a
        # continuation retry — the route read MAX(seq) from a separate
        # connection before another writer committed past it. Look up
        # the live latest and retry once with latest+1 so the event is
        # not silently lost. Bounded to a single retry — if two
        # writers keep racing in lockstep the route-level retry will
        # converge them across attempts.
        try:
            with db_readonly() as conn:
                latest = MessageEventsRepository(conn).latest_sequence_no(
                    message_id
                )
            # No rows yet (``None``) maps to seq 0.
            materialised_seq = (latest if latest is not None else -1) + 1
            with db_session() as conn:
                MessageEventsRepository(conn).record(
                    message_id,
                    materialised_seq,
                    event_type,
                    materialised_payload,
                )
            journal_committed = True
            logger.info(
                "record_event: collision at seq=%s recovered → wrote at "
                "seq=%s message_id=%s type=%s",
                sequence_no,
                materialised_seq,
                message_id,
                event_type,
            )
        except IntegrityError:
            # Second collision under the same retry — give up and log.
            # The route's nonlocal counter will continue at
            # ``sequence_no+1`` on the next emit; the next call may
            # land cleanly past the contended window.
            logger.warning(
                "record_event: IntegrityError persists after seq+1 retry; "
                "dropping. message_id=%s original_seq=%s retry_seq=%s "
                "type=%s",
                message_id,
                sequence_no,
                materialised_seq,
                event_type,
            )
        except Exception:
            logger.exception(
                "record_event: retry path failed unexpectedly "
                "(message_id=%s seq=%s type=%s)",
                message_id,
                sequence_no,
                event_type,
            )
    except Exception:
        logger.exception(
            "message_events INSERT failed: message_id=%s seq=%s type=%s",
            message_id,
            sequence_no,
            event_type,
        )
    try:
        # Publish using ``materialised_seq`` so the live pubsub frame
        # matches the journal row that other clients will snapshot on
        # reconnect. The original POST stream's SSE ``id:`` still
        # carries the caller's ``sequence_no`` — a reconnect from that
        # client will receive the same event at ``materialised_seq``
        # on the snapshot, which is a benign duplicate (the slice's
        # ``max_replayed_seq`` advances past it). No-collision case:
        # ``materialised_seq == sequence_no`` and this is identical to
        # the prior behaviour.
        wire = encode_pubsub_message(
            message_id, materialised_seq, event_type, materialised_payload
        )
        Topic(message_topic_name(message_id)).publish(wire)
    except Exception:
        logger.exception(
            "channel:%s publish failed: seq=%s type=%s",
            message_id,
            materialised_seq,
            event_type,
        )
    return journal_committed
class BatchedJournalWriter:
    """Per-stream journal writer that batches PG INSERTs.

    One writer per ``message_id``; ``record()`` buffers events and flushes
    on size/time/``close()`` triggers. Pubsub publishes fire only after the
    INSERT commits. On ``IntegrityError`` falls back to per-row writes.

    The time trigger is evaluated only inside ``record()``; there is no
    background timer, so a buffered event waits for the next ``record()``,
    ``flush()`` or ``close()`` call.
    """

    def __init__(
        self,
        message_id: str,
        *,
        batch_size: int = DEFAULT_BATCH_SIZE,
        batch_interval_ms: int = DEFAULT_BATCH_INTERVAL_MS,
    ) -> None:
        """Create a writer for ``message_id``.

        Args:
            message_id: Message every buffered event is journaled under.
            batch_size: Buffer length that triggers a flush.
            batch_interval_ms: Milliseconds since the last flush after
                which a non-empty buffer is flushed.
        """
        self._message_id = message_id
        self._batch_size = batch_size
        self._batch_interval_ms = batch_interval_ms
        # Pending ``(sequence_no, event_type, payload)`` tuples.
        self._buffer: list[tuple[int, str, dict[str, Any]]] = []
        # Monotonic timestamp (ms) of the last flush attempt.
        self._last_flush_mono_ms = time.monotonic() * 1000.0
        self._closed = False

    def record(
        self,
        sequence_no: int,
        event_type: str,
        payload: Optional[dict[str, Any]] = None,
    ) -> bool:
        """Buffer one event; maybe flush. Publish happens after journal commit.

        Returns ``True`` when the event was accepted into the buffer, not
        when it was committed. Returns ``False`` when the writer is
        closed, ``event_type`` is falsy, or ``payload`` is neither a
        ``dict`` nor ``None``.
        """
        if self._closed:
            logger.warning(
                "BatchedJournalWriter.record after close: "
                "message_id=%s seq=%s type=%s",
                self._message_id,
                sequence_no,
                event_type,
            )
            return False
        if not event_type:
            logger.warning(
                "BatchedJournalWriter.record without event_type: "
                "message_id=%s seq=%s",
                self._message_id,
                sequence_no,
            )
            return False
        if payload is None:
            materialised: dict[str, Any] = {}
        elif isinstance(payload, dict):
            materialised = _strip_null_bytes(payload)
        else:
            # Same contract as ``record_event`` — non-dict payloads
            # are rejected so the live and replay paths can't diverge
            # on envelope reconstruction.
            logger.warning(
                "BatchedJournalWriter.record with non-dict payload: "
                "message_id=%s seq=%s type=%s payload_type=%s — dropping",
                self._message_id,
                sequence_no,
                event_type,
                type(payload).__name__,
            )
            return False
        self._buffer.append((sequence_no, event_type, materialised))
        if self._should_flush():
            self.flush()
        return True

    def _should_flush(self) -> bool:
        """True when the buffer is full, or non-empty and the interval elapsed."""
        if len(self._buffer) >= self._batch_size:
            return True
        elapsed_ms = (time.monotonic() * 1000.0) - self._last_flush_mono_ms
        return elapsed_ms >= self._batch_interval_ms and len(self._buffer) > 0

    def flush(self) -> None:
        """Commit buffered events to PG. Best-effort.

        Tries one bulk INSERT first; on ``IntegrityError`` (composite
        PK collision — typically a stale continuation seed) falls back
        to per-row writes via ``_flush_per_row`` so one bad seq doesn't
        drop the rest of the batch. Any other bulk failure is logged and
        the batch is dropped from the journal without being published.
        Always clears the buffer to bound memory, even on failure — a
        journaled event missing from a snapshot is degraded UX, but a
        runaway buffer is corruption.
        """
        if not self._buffer:
            # Nothing to write; still reset the interval clock.
            self._last_flush_mono_ms = time.monotonic() * 1000.0
            return
        # Snapshot and clear before the I/O so a concurrent record()
        # call would land in a fresh buffer rather than racing the
        # flush. ``complete_stream`` is single-threaded per stream, so
        # this is belt-and-suspenders for any future change.
        pending = self._buffer
        self._buffer = []
        self._last_flush_mono_ms = time.monotonic() * 1000.0
        try:
            with db_session() as conn:
                MessageEventsRepository(conn).bulk_record(
                    self._message_id, pending
                )
        except IntegrityError:
            logger.info(
                "BatchedJournalWriter: bulk INSERT collided for "
                "message_id=%s n=%d; falling back to per-row writes",
                self._message_id,
                len(pending),
            )
            self._flush_per_row(pending)
            return
        except Exception:
            logger.exception(
                "BatchedJournalWriter: bulk INSERT failed for "
                "message_id=%s n=%d; events dropped from journal",
                self._message_id,
                len(pending),
            )
            return
        # Bulk INSERT committed — publish each frame in order. Best-effort:
        # one failed publish must not poison the rest of the batch.
        for seq, event_type, payload in pending:
            self._publish(seq, event_type, payload)

    def _flush_per_row(
        self, pending: list[tuple[int, str, dict[str, Any]]]
    ) -> None:
        """Per-row fallback after a bulk collision. Publishes after each commit.

        Each row is written in its own transaction. On a collision the
        row is retried once at ``latest_sequence_no + 1``. A row that
        still cannot be written is logged and skipped, and is not
        published.
        """
        for seq, event_type, payload in pending:
            # Sequence number the row was actually committed at, or
            # ``None`` when neither attempt committed.
            committed_seq: Optional[int] = None
            try:
                with db_session() as conn:
                    MessageEventsRepository(conn).record(
                        self._message_id, seq, event_type, payload
                    )
                committed_seq = seq
            except IntegrityError:
                try:
                    with db_readonly() as conn:
                        latest = MessageEventsRepository(
                            conn
                        ).latest_sequence_no(self._message_id)
                    # No rows yet (``None``) maps to seq 0.
                    retry_seq = (latest if latest is not None else -1) + 1
                    with db_session() as conn:
                        MessageEventsRepository(conn).record(
                            self._message_id, retry_seq, event_type, payload
                        )
                    committed_seq = retry_seq
                except IntegrityError:
                    logger.warning(
                        "BatchedJournalWriter: IntegrityError persists "
                        "after seq+1 retry; dropping. message_id=%s "
                        "original_seq=%s type=%s",
                        self._message_id,
                        seq,
                        event_type,
                    )
                except Exception:
                    logger.exception(
                        "BatchedJournalWriter: per-row retry failed "
                        "(message_id=%s seq=%s type=%s)",
                        self._message_id,
                        seq,
                        event_type,
                    )
            except Exception:
                logger.exception(
                    "BatchedJournalWriter: per-row INSERT failed "
                    "(message_id=%s seq=%s type=%s)",
                    self._message_id,
                    seq,
                    event_type,
                )
            if committed_seq is not None:
                self._publish(committed_seq, event_type, payload)

    def _publish(
        self, sequence_no: int, event_type: str, payload: dict[str, Any]
    ) -> None:
        """Publish one frame to the per-message pubsub channel. Best-effort."""
        try:
            wire = encode_pubsub_message(
                self._message_id, sequence_no, event_type, payload
            )
            Topic(message_topic_name(self._message_id)).publish(wire)
        except Exception:
            logger.exception(
                "channel:%s publish failed: seq=%s type=%s",
                self._message_id,
                sequence_no,
                event_type,
            )

    def close(self) -> None:
        """Final flush. Idempotent — safe to call from multiple
        finally clauses.

        After the first call the writer is marked closed and further
        ``record()`` calls are rejected.
        """
        if self._closed:
            return
        self.flush()
        self._closed = True

View File

@@ -1,8 +1,5 @@
import logging
from typing import List, Optional, Any, Dict
from psycopg.types.json import Jsonb
from application.core.settings import settings
from application.vectorstore.base import BaseVectorStore
from application.vectorstore.document_class import Document
@@ -178,7 +175,7 @@ class PGVectorStore(BaseVectorStore):
for text, embedding, metadata in zip(texts, embeddings, metadatas):
cursor.execute(
insert_query,
(text, embedding, Jsonb(metadata), self._source_id)
(text, embedding, metadata, self._source_id)
)
inserted_id = cursor.fetchone()[0]
inserted_ids.append(str(inserted_id))
@@ -269,7 +266,7 @@ class PGVectorStore(BaseVectorStore):
cursor.execute(
insert_query,
(text, embeddings[0], Jsonb(final_metadata), self._source_id)
(text, embeddings[0], final_metadata, self._source_id)
)
inserted_id = cursor.fetchone()[0]
conn.commit()

File diff suppressed because it is too large Load Diff

View File

@@ -1,385 +0,0 @@
# SSE Notifications Runbook
> Operations guide for "user says they didn't get a notification" — and
> the related "the bell never lights up" / "my upload toast hangs" /
> "the chat answer doesn't reconnect" symptoms.
The user-facing notifications channel is the SSE pipe at
`/api/events` plus per-message reconnects at
`/api/messages/<id>/events`. This document maps a user complaint to
the diagnostic that surfaces the cause.
---
## TL;DR — first 60 seconds
Run these three commands in parallel before anything else:
```bash
# 1) Is Redis up and serving the pipe? Should print PONG instantly.
redis-cli -n 2 PING
# 2) Anyone subscribed to the channel right now? Numbers per channel.
redis-cli -n 2 PUBSUB NUMSUB user:<user_id>
# 3) Is the user's backlog populated? Returns the count of journaled events.
redis-cli -n 2 XLEN user:<user_id>:stream
```
- `PING` failing → Redis is the problem. Skip to "Redis-down".
- `NUMSUB user:<user_id>` returns 0 → no client connected. Skip to "Stale event-stream client".
- `XLEN user:<user_id>:stream` returns 0 or low → publisher isn't writing. Skip to "Publisher silent".
- All three look healthy → the events are flowing on the wire; the issue is downstream of the slice (UI rendering, toast suppression, etc.). Work through the "Symptom → diagnostic map" below.
---
## Architecture cheat-sheet
```
  Worker (publish_user_event)               Frontend tab
        │                                        ▲
        ▼                                        │ GET /api/events SSE
  Redis Streams: XADD                       Flask route
    user:<id>:stream ──────────────► replay_backlog (snapshot)
        │                                        +
        ▼                             Topic.subscribe (live tail)
  Redis pub/sub: PUBLISH                         │
    user:<id> ───────────────────────────────────┘
```
**Source of truth:**
- Persistent journal: Redis Stream `user:<user_id>:stream`, capped at
`EVENTS_STREAM_MAXLEN` (default 1000) entries via `MAXLEN ~`. ~24h
at typical event rates.
- Live fan-out: Redis pub/sub channel `user:<user_id>`. No durability;
subscribers must be attached at publish time.
The chat-stream pipe is separate, parallel infrastructure:
- Journal: Postgres `message_events` table.
- Live fan-out: Redis pub/sub `channel:<message_id>`.
Same patterns, different durability layer. This doc covers both;
they share most diagnostic commands.
---
## Symptom → diagnostic map
### A. "I uploaded a source and the toast never appeared"
User flow: chat → upload → expect toast.
| Step | Command | Expect |
| ------------------------------------------------- | ------------------------------------------------------------- | ----------------------------------------------- |
| Worker received the task | `tail -f celery.log` filtered by user | `ingest_worker` start log line |
| Worker published the queued event | `redis-cli -n 2 XREVRANGE user:<id>:stream + - COUNT 5` | A `source.ingest.queued` entry within seconds |
| Frontend got it | DevTools → Network → `/api/events` → EventStream tab | `data: {"type":"source.ingest.queued",...}` |
| Slice updated | Redux DevTools → state.upload.tasks | Task with matching `sourceId`, `status:'training'` |
If the worker's queued log line is there but the XADD didn't land →
look for a `publish_user_event payload not JSON-serializable` warning
in the worker log (the publisher swallows `TypeError`).
If the XADD landed but the frontend never received it → check
`PUBSUB NUMSUB user:<id>` while the user is on the page. If 0, the
SSE connection isn't subscribed; skip to "Stale event-stream client".
If the frontend received it but the toast didn't render → the
`uploadSlice` extraReducer requires `task.sourceId` to match the
event's `scope.id`. Check the upload route returned `source_id` in
its POST response (the upload, connector, and reingest paths all
include it). Idempotent / cached responses must also include
`source_id` (`_claim_task_or_get_cached`).
### B. "The bell badge never goes up"
There is no bell — the global notifications surface is per-event
toasts, not an aggregated counter. If the user is on an old build,
`Cmd-Shift-R` to bypass cache. The surfaces they're looking for are
`UploadToast` for source uploads and `ToolApprovalToast` for
tool-approval events.
### C. "My chat answer froze mid-stream and never recovered"
User flow: ask question → answer streaming → network blip → answer
stops; should reconnect.
```bash
# Was the original message reserved in PG?
psql -c "SELECT id, status, prompt FROM conversation_messages \
WHERE user_id = '<user>' ORDER BY timestamp DESC LIMIT 5;"
# Did the journal capture events past the user's last-seen seq?
psql -c "SELECT sequence_no, event_type FROM message_events \
WHERE message_id = '<id>' ORDER BY sequence_no;"
# Is the live tail still producing? (subscribe and watch)
redis-cli -n 2 SUBSCRIBE channel:<message_id>
```
The frontend should reconnect via `GET /api/messages/<id>/events`
when the original POST stream closes without a typed `end` or
`error` event. If it's not reconnecting, `console.warn('Stream
reconnect failed', ...)` will be in the browser console — the
reconnect HTTP errored. Common cases:
- The user's JWT rotated mid-stream → 401 on the GET. Frontend
doesn't auto-refresh; the user reloads.
- The user is on a different host than the API and CORS is rejecting
the GET → check `application/asgi.py` allow-headers.
### D. "The dev install never delivers any notifications at all"
Default `AUTH_TYPE` unset means `decoded_token = {"sub": "local"}`
for every request. The SSE client connects without the
`Authorization` header in this case, and `user:local:stream` is
the shared channel everything goes to. If the user has multiple dev
machines pointing at the same Redis, they will see each other's
events. Confirm with:
```bash
redis-cli -n 2 KEYS 'user:local:*'
```
If multiple deployments share the Redis, document that as a known
multi-user-on-local-channel limitation. Set `AUTH_TYPE=simple_jwt`
to scope per-user.
### E. "The notifications channel was working, then suddenly stopped after the user reloaded the page"
Likely path: `backlog.truncated` event fired, the slice cleared
`lastEventId` to null, the closure was carrying the same stale id and
re-tripped the same truncation on every reconnect. **Verify the user
is on a current build — `eventStreamClient.ts` must re-read
`lastEventId = opts.getLastEventId();` without a truthy guard so the
null clear propagates into the next reconnect.**
### F. "I keep getting 429 on /api/events"
The per-user concurrent-connection cap (`SSE_MAX_CONCURRENT_PER_USER`,
default 8) refused the connection. User has too many tabs open or a
runaway reconnect loop. `redis-cli -n 2 GET user:<id>:sse_count`
shows the live counter; the TTL is 1h from the last connection
attempt (rolling — every INCR re-seeds it), so the key only ages
out after the user stops reconnecting for a full hour.
If the count is wedged high without explanation, the
counter-DECR-in-finally path didn't run (worker SIGKILL, OOM). Wait
for the TTL or `redis-cli -n 2 DEL user:<id>:sse_count` to reset.
### G. "Replay snapshot stops at 200 events"
The route caps each replay at `EVENTS_REPLAY_MAX_PER_REQUEST`
(default 200). The cap is intentionally **silent** — the route does NOT
emit a `backlog.truncated` notice for cap-hit. The 200 entries each
carry their own `id:` header, so the frontend's slice cursor
advances to the most-recent delivered id. Next reconnect sends
`last_event_id=<max_replayed>` and the snapshot resumes from there.
A user that was 1000 entries behind catches up over ~5 reconnects.
If the user reports getting HTTP 429 on `/api/events` despite being
well under `SSE_MAX_CONCURRENT_PER_USER`, they hit the windowed
replay budget (`EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW`, default
30 / `EVENTS_REPLAY_BUDGET_WINDOW_SECONDS` 60s). The route refuses
the connection so the slice cursor stays pinned at whatever value
it had; the frontend backs off and the next reconnect (after the
window rolls) gets the proper snapshot. Serving the live tail
without a snapshot used to be the behavior here, but that let the
client advance `lastEventId` past entries it never received,
permanently stranding the un-replayed window — so the route now
429s instead. `redis-cli -n 2 GET user:<id>:replay_count` shows the
current counter; TTL is the window size.
`backlog.truncated` is emitted ONLY when the client's
`Last-Event-ID` has slid off the MAXLEN'd window — i.e. the journal
is genuinely gone past the cursor and the frontend should clear the
slice cursor and refetch state. Treating cap-hit or
budget-exhaustion the same way would lock the user into re-receiving
the oldest 200 entries on every reconnect (the cursor would clear,
the snapshot would re-serve from the start, the cap would re-trip).
### H. "User says push notifications stopped after a deploy"
- Pull `event.published topic=user:<id> type=...` from the worker
logs to confirm the publisher is still firing.
- Pull `event.connect user=<id>` from the API logs to confirm the
client is reconnecting.
- Check the gunicorn worker count and `WSGIMiddleware(workers=32)`:
if the deploy reduced the worker count, the per-user cap is still 8,
but total concurrent SSE connections are bounded by `gunicorn
workers × 32`. A capacity miss looks like users randomly getting
429'd.
---
## Common failure modes
### Redis-down
Symptoms: `/api/events` returns 200 but emits only `: connected`
then the body closes. `XLEN` and `PUBLISH` both fail. The publisher's
`record_event` swallows the failure and returns False; the live tail
publish also drops on the floor. Frontend retries forever with
exponential backoff.
Resolution: bring Redis back. The journal is gone (was in-memory
only — Streams persist within a single Redis instance, no replication
configured). New events flow as soon as Redis comes back.
### `AUTH_TYPE` misconfigured = sub:"local" cross-stream
Symptoms: every user shares `user:local:stream`. Any user sees
everyone else's notifications.
Resolution: set `AUTH_TYPE=simple_jwt` (or `session_jwt`) in `.env`.
The events route logs a one-time WARNING per process when
`sub == "local"` is observed. A repeat WARNING after a restart
confirms the misconfiguration.
### MAXLEN trimmed past Last-Event-ID
Symptoms: client reconnects with `last_event_id=X`, snapshot returns
the entire MAXLEN'd backlog (because X is older than the oldest
retained entry). Old events appear duplicated.
Detection: the route's `_oldest_retained_id` check emits
`backlog.truncated` when this case fires. Frontend's
`dispatchSSEEvent` clears `lastEventId` so the next reconnect starts
fresh.
If the WARNING isn't firing but symptoms match: the user's client
may have a corrupt cached `lastEventId`. `localStorage` doesn't
store this state; check Redux state via DevTools.
### Stale event-stream client
Symptoms: events visible in `XRANGE` but the frontend slice doesn't
update.
```bash
# Is the client subscribed?
redis-cli -n 2 PUBSUB NUMSUB user:<id>
# When did its connection start?
grep "event.connect user=<id>" /var/log/docsgpt.log | tail -3
```
If `NUMSUB` is 0 and no recent `event.connect`, the user's tab is
closed or the connection died and never reconnected. Push them to
reload.
### Publisher silent
Symptoms: worker is processing the task (Celery says SUCCESS), but
no XADD and no PUBLISH. User sees no events.
```bash
# Was the publisher import error suppressed?
grep "publish_user_event" /var/log/celery.log | grep -i "warn\|error" | tail -20
# Is push disabled?
grep "ENABLE_SSE_PUSH" /var/log/docsgpt.log | tail -5
```
`ENABLE_SSE_PUSH=False` in `.env` would silence the publisher
globally. Useful for incident response if a runaway publisher is
DoS'ing Redis; toggle off, fix root cause, toggle on.
---
## Useful one-liners
```bash
# Watch a user's live event stream in real time (all events, all types)
redis-cli -n 2 PSUBSCRIBE 'user:*' | grep "user:<id>"
# Last 10 events the user would see on reconnect
redis-cli -n 2 XREVRANGE user:<id>:stream + - COUNT 10
# Live count of subscribed clients per user
redis-cli -n 2 PUBSUB NUMSUB $(redis-cli -n 2 PUBSUB CHANNELS 'user:*')
# Trim a runaway stream (CAREFUL — destroys backlog for all current
# subscribers; OK after explaining to the user)
redis-cli -n 2 XTRIM user:<id>:stream MAXLEN 0
# Clear a wedged concurrent-connection counter
redis-cli -n 2 DEL user:<id>:sse_count
# Force-flip every client to re-snapshot (drop the stream key entirely
# — destroys the backlog; clients reconnect with their last id and
# get a backlog.truncated)
redis-cli -n 2 DEL user:<id>:stream
```
---
## Settings reference
Everything in `application/core/settings.py`:
| Setting | Default | Purpose |
| --------------------------------------------- | ------- | --------------------------------------------- |
| `ENABLE_SSE_PUSH` | `True` | Master switch. False = publisher no-ops, route serves "push_disabled" comment. |
| `EVENTS_STREAM_MAXLEN` | `1000` | Per-user backlog cap. Approximate via `XADD MAXLEN ~`. |
| `SSE_KEEPALIVE_SECONDS` | `15` | Comment-frame cadence. Must sit under reverse-proxy idle close. |
| `SSE_MAX_CONCURRENT_PER_USER` | `8` | Cap on simultaneous SSE connections per user. 0 = disabled. |
| `EVENTS_REPLAY_MAX_PER_REQUEST` | `200` | Hard cap on snapshot rows per request. |
| `EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW` | `30` | Per-user replays per window. 0 = disabled. |
| `EVENTS_REPLAY_BUDGET_WINDOW_SECONDS` | `60` | Window length. |
| `MESSAGE_EVENTS_RETENTION_DAYS` | `14` | Retention for the `message_events` journal; `cleanup_message_events` beat task deletes older rows. |
---
## Known limitations
### Each tab runs its own SSE connection
There is no cross-tab dedup. Every tab open to the app holds its
own SSE connection and dispatches every received event into its
own Redux store, so a user with N tabs open will see N copies of
each toast. With `SSE_MAX_CONCURRENT_PER_USER=8` (the default) a
heavy multi-tab user can also hit the connection cap and start
seeing 429s. Cross-tab dedup via a `BroadcastChannel` ring +
`navigator.locks`-based leader election is tracked as future work.
### `/c/<unknown-id>` normalises to `/c/new`
If a user navigates to a conversation id that isn't in their
loaded list, the conversation route rewrites the URL to `/c/new`.
`ToolApprovalToast`'s gate uses `useMatch('/c/:conversationId')`,
so for the brief window after the rewrite the toast may surface
for a conversation the user *thought* they were already viewing.
Pre-existing route behaviour; not a notifications regression.
### Terminal events un-dismiss running uploads
`frontend/src/upload/uploadSlice.ts` sets `dismissed: false` when
an upload reaches `completed` or `failed`. If the user dismissed a
running task and the terminal SSE arrives later, the toast pops
back. Intentional ("notify the user it's done"); revisit if the
re-surface UX is too aggressive for v2.
### Werkzeug doesn't auto-reload route files
The dev server (`flask run`) doesn't watch
`application/api/events/routes.py` for changes by default.
After editing the route, restart Flask manually — `--reload`
isn't on. (Production gunicorn reloads via deploy.)
### MCP OAuth completion can fall outside the user stream's MAXLEN window
`get_oauth_status` scans up to `EVENTS_STREAM_MAXLEN` (~1000) entries via `XREVRANGE`. If the user has a high-rate ingest running concurrent with the OAuth handshake, the `mcp.oauth.completed` envelope can be trimmed off the back before they click Save. Symptom: backend returns "OAuth failed or not completed" even though the popup completed successfully.
Mitigation today: bump `EVENTS_STREAM_MAXLEN` per-deployment if your users routinely flood the channel during OAuth flows. A dedicated short-TTL Redis key for OAuth task results is tracked as a follow-up.
### React StrictMode double-mounts SSE
In dev, React 18 StrictMode mounts → unmounts → remounts every
component, briefly opening two SSE connections per tab before the
first is aborted. With `SSE_MAX_CONCURRENT_PER_USER=8` and 4–5
tabs open concurrently you can transiently hit the cap and see
HTTP 429 on cold-load. The first connection's counter increment
fires before the AbortController-induced disconnect can decrement
it. Production (single mount, no StrictMode) is unaffected; raise
the cap in dev or accept transient 429s.

File diff suppressed because it is too large Load Diff

View File

@@ -7,8 +7,6 @@
"dev": "vite",
"build": "tsc && vite build",
"preview": "vite preview",
"test": "vitest run",
"test:watch": "vitest",
"lint": "eslint ./src --ext .jsx,.js,.ts,.tsx",
"lint-fix": "eslint ./src --ext .jsx,.js,.ts,.tsx --fix",
"format": "prettier ./src --write",
@@ -71,7 +69,6 @@
"eslint-plugin-promise": "^6.6.0",
"eslint-plugin-react": "^7.37.5",
"eslint-plugin-unused-imports": "^4.1.4",
"happy-dom": "^17.6.3",
"husky": "^9.1.7",
"lint-staged": "^16.4.0",
"postcss": "^8.5.12",
@@ -81,7 +78,6 @@
"tw-animate-css": "^1.4.0",
"typescript": "^6.0.3",
"vite": "^8.0.10",
"vite-plugin-svgr": "^4.3.0",
"vitest": "^3.2.4"
"vite-plugin-svgr": "^4.3.0"
}
}

View File

@@ -10,7 +10,6 @@ import Spinner from './components/Spinner';
import UploadToast from './components/UploadToast';
import Conversation from './conversation/Conversation';
import { SharedConversation } from './conversation/SharedConversation';
import { EventStreamProvider } from './events/EventStreamProvider';
import { useDarkTheme, useMediaQuery } from './hooks';
import useDataInitializer from './hooks/useDataInitializer';
import useTokenAuth from './hooks/useTokenAuth';
@@ -18,7 +17,6 @@ import Navigation from './Navigation';
import PageNotFound from './PageNotFound';
import Setting from './settings';
import Notification from './components/Notification';
import ToolApprovalToast from './notifications/ToolApprovalToast';
function AuthWrapper({ children }: { children: React.ReactNode }) {
const { isAuthLoading } = useTokenAuth();
@@ -31,7 +29,7 @@ function AuthWrapper({ children }: { children: React.ReactNode }) {
</div>
);
}
return <EventStreamProvider>{children}</EventStreamProvider>;
return <>{children}</>;
}
function MainLayout() {
@@ -52,7 +50,6 @@ function MainLayout() {
<Outlet />
</div>
<UploadToast />
<ToolApprovalToast />
</div>
);
}

View File

@@ -813,11 +813,7 @@ function WorkflowBuilderInner() {
const response = await userService.getWorkflow(workflowId, token);
if (!response.ok) throw new Error('Failed to fetch workflow');
const responseData = await response.json();
const {
workflow,
nodes: apiNodes,
edges: apiEdges,
} = responseData.data;
const { workflow, nodes: apiNodes, edges: apiEdges } = responseData.data;
const nextWorkflowName = workflow.name;
const nextWorkflowDescription = workflow.description || '';
const mappedNodes = apiNodes.map((n: WorkflowNode) => {
@@ -1476,9 +1472,7 @@ function WorkflowBuilderInner() {
{t('agents.form.advanced.systemPromptOverride')}
</label>
<p className="mt-0.5 text-[11px] text-gray-500 dark:text-gray-400">
{t(
'agents.form.advanced.systemPromptOverrideDescription',
)}
{t('agents.form.advanced.systemPromptOverrideDescription')}
</p>
</div>
<button

View File

@@ -1,5 +1,3 @@
import { withThrottle, type FetchLike } from './throttle';
export const baseURL =
import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
@@ -20,121 +18,112 @@ const getHeaders = (
return headers;
};
const createClient = (transport: FetchLike) => {
const request = (url: string, init: RequestInit): Promise<Response> =>
transport(`${baseURL}${url}`, init);
const apiClient = {
get: (
url: string,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'GET',
headers: getHeaders(token, headers),
signal,
}).then((response) => {
return response;
}),
return {
get: (
url: string,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
request(url, {
method: 'GET',
headers: getHeaders(token, headers),
signal,
}),
post: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'POST',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}).then((response) => {
return response;
}),
post: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
request(url, {
method: 'POST',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}),
postFormData: (
url: string,
formData: FormData,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<Response> => {
return fetch(`${baseURL}${url}`, {
method: 'POST',
headers: getHeaders(token, headers, true),
body: formData,
signal,
});
},
postFormData: (
url: string,
formData: FormData,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<Response> =>
request(url, {
method: 'POST',
headers: getHeaders(token, headers, true),
body: formData,
signal,
}),
put: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'PUT',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}).then((response) => {
return response;
}),
put: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
request(url, {
method: 'PUT',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}),
patch: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'PATCH',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}).then((response) => {
return response;
}),
patch: (
url: string,
data: any,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
request(url, {
method: 'PATCH',
headers: getHeaders(token, headers),
body: JSON.stringify(data),
signal,
}),
putFormData: (
url: string,
formData: FormData,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<Response> => {
return fetch(`${baseURL}${url}`, {
method: 'PUT',
headers: getHeaders(token, headers, true),
body: formData,
signal,
});
},
putFormData: (
url: string,
formData: FormData,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<Response> =>
request(url, {
method: 'PUT',
headers: getHeaders(token, headers, true),
body: formData,
signal,
}),
delete: (
url: string,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
request(url, {
method: 'DELETE',
headers: getHeaders(token, headers),
signal,
}),
};
delete: (
url: string,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<any> =>
fetch(`${baseURL}${url}`, {
method: 'DELETE',
headers: getHeaders(token, headers),
signal,
}).then((response) => {
return response;
}),
};
const apiClient = createClient((url, init) => fetch(url, init));
// Throttled client for endpoints that fan out, are polled, or are commonly
// requested concurrently from multiple components. Shares a single concurrency
// budget and de-duplicates identical in-flight GETs.
export const throttledApiClient = createClient(
withThrottle((url, init) => fetch(url, init), { debugLabel: 'api' }),
);
if (import.meta.env.DEV && typeof window !== 'undefined') {
(window as unknown as Record<string, unknown>).__apiClient = apiClient;
(window as unknown as Record<string, unknown>).__throttledApiClient =
throttledApiClient;
(window as unknown as Record<string, unknown>).__baseURL = baseURL;
}
export default apiClient;

View File

@@ -28,13 +28,13 @@ const endpoints = {
UPDATE_PROMPT: '/api/update_prompt',
SINGLE_PROMPT: (id: string) => `/api/get_single_prompt?id=${id}`,
DELETE_PATH: (docPath: string) => `/api/delete_old?source_id=${docPath}`,
TASK_STATUS: (task_id: string) => `/api/task_status?task_id=${task_id}`,
MESSAGE_ANALYTICS: '/api/get_message_analytics',
TOKEN_ANALYTICS: '/api/get_token_analytics',
FEEDBACK_ANALYTICS: '/api/get_feedback_analytics',
LOGS: `/api/get_user_logs`,
MANAGE_SYNC: '/api/manage_sync',
SYNC_SOURCE: '/api/sync_source',
REINGEST_SOURCE: '/api/sources/reingest',
GET_AVAILABLE_TOOLS: '/api/available_tools',
GET_USER_TOOLS: '/api/get_tools',
CREATE_TOOL: '/api/create_tool',
@@ -43,11 +43,6 @@ const endpoints = {
DELETE_TOOL: '/api/delete_tool',
PARSE_SPEC: '/api/parse_spec',
SYNC_CONNECTOR: '/api/connectors/sync',
CONNECTOR_AUTH: (provider: string) =>
`/api/connectors/auth?provider=${provider}`,
CONNECTOR_FILES: '/api/connectors/files',
CONNECTOR_VALIDATE_SESSION: '/api/connectors/validate-session',
CONNECTOR_DISCONNECT: '/api/connectors/disconnect',
GET_CHUNKS: (
docId: string,
page: number,
@@ -64,7 +59,6 @@ const endpoints = {
UPDATE_CHUNK: '/api/update_chunk',
STORE_ATTACHMENT: '/api/store_attachment',
STT: '/api/stt',
TTS: '/api/tts',
LIVE_STT_START: '/api/stt/live/start',
LIVE_STT_CHUNK: '/api/stt/live/chunk',
LIVE_STT_FINISH: '/api/stt/live/finish',
@@ -73,6 +67,8 @@ const endpoints = {
MANAGE_SOURCE_FILES: '/api/manage_source_files',
MCP_TEST_CONNECTION: '/api/mcp_server/test',
MCP_SAVE_SERVER: '/api/mcp_server/save',
MCP_OAUTH_STATUS: (task_id: string) =>
`/api/mcp_server/oauth_status/${task_id}`,
MCP_AUTH_STATUS: '/api/mcp_server/auth_status',
AGENT_FOLDERS: '/api/agents/folders/',
AGENT_FOLDER: (id: string) => `/api/agents/folders/${id}`,

View File

@@ -1,12 +1,11 @@
import { getSessionToken } from '../../utils/providerUtils';
import apiClient, { throttledApiClient } from '../client';
import apiClient from '../client';
import endpoints from '../endpoints';
const userService = {
getConfig: (): Promise<any> =>
throttledApiClient.get(endpoints.USER.CONFIG, null),
getConfig: (): Promise<any> => apiClient.get(endpoints.USER.CONFIG, null),
getNewToken: (): Promise<any> =>
throttledApiClient.get(endpoints.USER.NEW_TOKEN, null),
apiClient.get(endpoints.USER.NEW_TOKEN, null),
getDocs: (token: string | null): Promise<any> =>
apiClient.get(`${endpoints.USER.DOCS}`, token),
getDocsWithPagination: (query: string, token: string | null): Promise<any> =>
@@ -18,9 +17,9 @@ const userService = {
deleteAPIKey: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.DELETE_API_KEY, data, token),
getAgent: (id: string, token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.AGENT(id), token),
apiClient.get(endpoints.USER.AGENT(id), token),
getAgents: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.AGENTS, token),
apiClient.get(endpoints.USER.AGENTS, token),
createAgent: (data: any, token: string | null): Promise<any> =>
apiClient.postFormData(endpoints.USER.CREATE_AGENT, data, token),
updateAgent: (
@@ -32,19 +31,19 @@ const userService = {
deleteAgent: (id: string, token: string | null): Promise<any> =>
apiClient.delete(endpoints.USER.DELETE_AGENT(id), token),
getPinnedAgents: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.PINNED_AGENTS, token),
apiClient.get(endpoints.USER.PINNED_AGENTS, token),
togglePinAgent: (id: string, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.TOGGLE_PIN_AGENT(id), {}, token),
getSharedAgent: (id: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.SHARED_AGENT(id), token),
getSharedAgents: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.SHARED_AGENTS, token),
apiClient.get(endpoints.USER.SHARED_AGENTS, token),
shareAgent: (data: any, token: string | null): Promise<any> =>
apiClient.put(endpoints.USER.SHARE_AGENT, data, token),
removeSharedAgent: (id: string, token: string | null): Promise<any> =>
apiClient.delete(endpoints.USER.REMOVE_SHARED_AGENT(id), token),
getTemplateAgents: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.TEMPLATE_AGENTS, token),
apiClient.get(endpoints.USER.TEMPLATE_AGENTS, token),
adoptAgent: (id: string, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.ADOPT_AGENT(id), {}, token),
getAgentWebhook: (id: string, token: string | null): Promise<any> =>
@@ -61,6 +60,8 @@ const userService = {
apiClient.get(endpoints.USER.SINGLE_PROMPT(id), token),
deletePath: (docPath: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.DELETE_PATH(docPath), token),
getTaskStatus: (task_id: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.TASK_STATUS(task_id), token),
getMessageAnalytics: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MESSAGE_ANALYTICS, data, token),
getTokenAnalytics: (data: any, token: string | null): Promise<any> =>
@@ -73,8 +74,6 @@ const userService = {
apiClient.post(endpoints.USER.MANAGE_SYNC, data, token),
syncSource: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.SYNC_SOURCE, data, token),
reingestSource: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.REINGEST_SOURCE, data, token),
getAvailableTools: (token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS, token),
getUserTools: (token: string | null): Promise<any> =>
@@ -150,7 +149,7 @@ const userService = {
path?: string,
search?: string,
): Promise<any> =>
throttledApiClient.get(
apiClient.get(
endpoints.USER.GET_CHUNKS(docId, page, perPage, path, search),
token,
),
@@ -165,15 +164,17 @@ const userService = {
updateChunk: (data: any, token: string | null): Promise<any> =>
apiClient.put(endpoints.USER.UPDATE_CHUNK, data, token),
getDirectoryStructure: (docId: string, token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
manageSourceFiles: (data: FormData, token: string | null): Promise<any> =>
apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token),
testMCPConnection: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
saveMCPServer: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
getMCPOAuthStatus: (task_id: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.MCP_OAUTH_STATUS(task_id), token),
getMCPAuthStatus: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.MCP_AUTH_STATUS, token),
apiClient.get(endpoints.USER.MCP_AUTH_STATUS, token),
syncConnector: (
docId: string,
provider: string,
@@ -190,50 +191,8 @@ const userService = {
token,
);
},
getConnectorAuthUrl: (provider: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.CONNECTOR_AUTH(provider), token),
getConnectorFiles: (
data: any,
token: string | null,
signal?: AbortSignal,
): Promise<any> =>
throttledApiClient.post(
endpoints.USER.CONNECTOR_FILES,
data,
token,
{},
signal,
),
validateConnectorSession: (
provider: string,
token: string | null,
): Promise<any> =>
apiClient.post(
endpoints.USER.CONNECTOR_VALIDATE_SESSION,
{
provider,
session_token: getSessionToken(provider),
},
token,
),
disconnectConnector: (
provider: string,
sessionToken: string,
token: string | null,
): Promise<any> =>
apiClient.post(
endpoints.USER.CONNECTOR_DISCONNECT,
{ provider, session_token: sessionToken },
token,
),
textToSpeech: (
text: string,
token: string | null,
signal?: AbortSignal,
): Promise<any> =>
apiClient.post(endpoints.USER.TTS, { text }, token, {}, signal),
getAgentFolders: (token: string | null): Promise<any> =>
throttledApiClient.get(endpoints.USER.AGENT_FOLDERS, token),
apiClient.get(endpoints.USER.AGENT_FOLDERS, token),
createAgentFolder: (
data: { name: string; parent_id?: string },
token: string | null,

View File

@@ -1,223 +0,0 @@
/**
* Transport-layer middleware factory for the frontend API layer.
*/
/**
 * Minimal fetch-shaped transport: takes a URL string and an optional
 * `RequestInit`, and resolves to a `Response`.
 */
export type FetchLike = (
  input: string,
  init?: RequestInit,
) => Promise<Response>;
/** Options accepted by `withThrottle`. Every field is optional. */
export interface ThrottleConfig {
  // Cap on requests in flight across all routes; falls back to
  // DEFAULT_MAX_CONCURRENT_GLOBAL when omitted.
  maxConcurrentGlobal?: number;
  // Cap on requests in flight for a single route key; falls back to
  // DEFAULT_MAX_CONCURRENT_PER_ROUTE when omitted.
  maxConcurrentPerRoute?: number;
  // De-duplication stays enabled unless this is explicitly `false`.
  dedupe?: boolean;
  // NOTE(review): presumably computes the key used to match duplicate
  // requests, with `false` opting a request out — confirm in withThrottle.
  dedupeKey?: (url: string, init?: RequestInit) => string | false;
  // When set, debug log lines are tagged `[throttle:<debugLabel>]`.
  debugLabel?: string;
}
// Fallback concurrency caps used by `withThrottle` when the config does not
// provide `maxConcurrentGlobal` / `maxConcurrentPerRoute`.
const DEFAULT_MAX_CONCURRENT_GLOBAL = 8;
const DEFAULT_MAX_CONCURRENT_PER_ROUTE = 3;
/**
 * One entry in a per-route wait queue (see `ThrottleState.perRouteQueues`).
 *
 * NOTE(review): the code that enqueues and drains these items is outside
 * this excerpt; the field notes below are inferred from names — confirm.
 */
type QueueItem = {
  // Presumably starts the deferred request once a slot is free.
  run: () => void;
  // Presumably the caller's abort signal, if one was supplied.
  signal?: AbortSignal;
  // Presumably the abort listener registered on `signal`.
  onAbort: () => void;
};
/**
 * Build the per-route bucket key, `"<METHOD> <pathname>"`.
 *
 * The query string never contributes to the key. Relative URLs are
 * resolved against a placeholder origin purely so the URL parser can
 * extract the pathname; if parsing fails, everything before the first
 * `?` is used instead.
 */
function routeKey(method: string, url: string): string {
  const extractPath = (raw: string): string => {
    try {
      return new URL(raw, 'http://_').pathname;
    } catch {
      return raw.split('?')[0];
    }
  };
  const path = extractPath(url);
  return [method.toUpperCase(), path].join(' ');
}
/**
 * Create the rejection value used when a waiter is aborted: a
 * `DOMException` whose `name` is `'AbortError'`.
 */
function abortError(): DOMException {
  const message = 'The operation was aborted.';
  const name = 'AbortError';
  return new DOMException(message, name);
}
/** Mutable bookkeeping owned by a single `withThrottle` instance. */
interface ThrottleState {
// Waiters not yet dispatched, keyed by route key.
perRouteQueues: Map<string, QueueItem[]>;
// Number of requests currently holding a slot, keyed by route key.
inflightPerRoute: Map<string, number>;
// In-flight deduped GET promises, keyed by dedupe key.
inflightGets: Map<string, Promise<Response>>;
// Number of requests currently holding a slot across all routes.
inflightGlobal: number;
}
/** Produce an empty ThrottleState: no queues, no in-flight requests. */
function createState(): ThrottleState {
  const perRouteQueues = new Map<string, QueueItem[]>();
  const inflightPerRoute = new Map<string, number>();
  const inflightGets = new Map<string, Promise<Response>>();
  return { perRouteQueues, inflightPerRoute, inflightGets, inflightGlobal: 0 };
}
/**
 * Wrap a fetch-like function with concurrency limiting and in-flight GET
 * dedupe.
 *
 * - At most `maxConcurrentGlobal` requests run at once overall, and at most
 *   `maxConcurrentPerRoute` per "METHOD pathname" route; excess requests
 *   wait in per-route FIFO queues. A queued request whose `AbortSignal`
 *   fires is removed from the queue and rejects with an `AbortError`.
 * - Unless disabled, concurrent body-less GETs without an `AbortSignal`
 *   that share a dedupe key share one underlying fetch; each caller
 *   receives its own `Response.clone()`.
 *   NOTE(review): the default dedupe key is the URL only, so requests that
 *   differ just in headers share a response — supply `dedupeKey` if that
 *   matters for a route.
 *
 * @param fetchLike Underlying transport to call once a slot is available.
 * @param config    Optional limits / dedupe settings (see ThrottleConfig).
 * @returns A fetch-like function with an extra `__reset()` that clears all
 *          bookkeeping. `__reset()` does not settle already-queued waiters
 *          or cancel requests already in flight.
 */
export function withThrottle(
  fetchLike: FetchLike,
  config: ThrottleConfig = {},
): FetchLike & { __reset: () => void } {
  // A limit below 1 (or NaN) would make canDispatch() permanently false and
  // leave every request queued forever, so clamp to at least one slot.
  // `Infinity` still passes through as "no limit".
  const atLeastOne = (limit: number): number => (limit >= 1 ? limit : 1);
  const maxGlobal = atLeastOne(
    config.maxConcurrentGlobal ?? DEFAULT_MAX_CONCURRENT_GLOBAL,
  );
  const maxPerRoute = atLeastOne(
    config.maxConcurrentPerRoute ?? DEFAULT_MAX_CONCURRENT_PER_ROUTE,
  );
  const dedupeEnabled = config.dedupe !== false;
  const state = createState();

  // Toggle in DevTools with: localStorage.setItem('debug:throttle', '1')
  const isDebug = (): boolean => {
    try {
      return (
        typeof localStorage !== 'undefined' &&
        localStorage.getItem('debug:throttle') === '1'
      );
    } catch {
      // Accessing localStorage can throw; treat that as "debug off".
      return false;
    }
  };

  const log = (
    event: string,
    key: string,
    extra?: Record<string, unknown>,
  ): void => {
    if (!isDebug()) return;
    const queued = state.perRouteQueues.get(key)?.length ?? 0;
    const perRoute = state.inflightPerRoute.get(key) ?? 0;
    const tag = config.debugLabel
      ? `[throttle:${config.debugLabel}]`
      : '[throttle]';
    console.debug(
      `${tag} ${event} ${key} | inflight=${state.inflightGlobal}/${maxGlobal} route=${perRoute}/${maxPerRoute} queued=${queued}`,
      extra ?? '',
    );
  };

  // True when both the global and the per-route budget have a free slot.
  const canDispatch = (key: string): boolean => {
    const perRoute = state.inflightPerRoute.get(key) ?? 0;
    return state.inflightGlobal < maxGlobal && perRoute < maxPerRoute;
  };

  // Dispatch as many queued waiters as the current budgets allow.
  const pumpQueues = (): void => {
    for (const [key, queue] of state.perRouteQueues) {
      while (queue.length > 0 && canDispatch(key)) {
        const item = queue.shift()!;
        // Once dispatched, aborting is the underlying fetch's job.
        item.signal?.removeEventListener('abort', item.onAbort);
        item.run();
      }
      if (queue.length === 0) state.perRouteQueues.delete(key);
    }
  };

  const enqueue = (key: string, item: QueueItem): void => {
    let queue = state.perRouteQueues.get(key);
    if (!queue) {
      queue = [];
      state.perRouteQueues.set(key, queue);
    }
    queue.push(item);
  };

  // Resolves once the caller holds a slot for `key`; rejects with an
  // AbortError if `signal` is already aborted or aborts while queued.
  const acquireSlot = (key: string, signal?: AbortSignal): Promise<void> =>
    new Promise((resolve, reject) => {
      if (signal?.aborted) {
        reject(abortError());
        return;
      }
      const item: QueueItem = {
        signal,
        run: () => {
          state.inflightGlobal += 1;
          state.inflightPerRoute.set(
            key,
            (state.inflightPerRoute.get(key) ?? 0) + 1,
          );
          resolve();
        },
        onAbort: () => {
          const queue = state.perRouteQueues.get(key);
          if (queue) {
            const idx = queue.indexOf(item);
            if (idx >= 0) queue.splice(idx, 1);
            // Don't keep an empty queue around waiting for the next release.
            if (queue.length === 0) state.perRouteQueues.delete(key);
          }
          log('abort-queued', key);
          reject(abortError());
        },
      };
      const queued = state.perRouteQueues.get(key);
      // Only skip the queue when nobody on this route is already waiting,
      // so per-route ordering stays FIFO.
      if ((!queued || queued.length === 0) && canDispatch(key)) {
        item.run();
        log('dispatch', key);
        return;
      }
      signal?.addEventListener('abort', item.onAbort, { once: true });
      enqueue(key, item);
      log('queued', key);
    });

  const releaseSlot = (key: string): void => {
    // Math.max guards against going negative after a __reset() that ran
    // while requests were still in flight.
    state.inflightGlobal = Math.max(0, state.inflightGlobal - 1);
    const next = (state.inflightPerRoute.get(key) ?? 1) - 1;
    if (next <= 0) state.inflightPerRoute.delete(key);
    else state.inflightPerRoute.set(key, next);
    log('release', key);
    pumpQueues();
  };

  const wrapped = (async (url: string, init: RequestInit = {}) => {
    const method = (init.method ?? 'GET').toUpperCase();
    const signal = init.signal ?? undefined;
    const key = routeKey(method, url);

    // Dedupe is restricted to GETs without a caller-supplied AbortSignal:
    // sharing a single underlying fetch across waiters means an abort by one
    // caller would reject the others, which is not the contract callers expect.
    const customKey = config.dedupeKey?.(url, init);
    const dedupeAllowed =
      dedupeEnabled &&
      customKey !== false &&
      method === 'GET' &&
      !init.body &&
      !signal;
    const dedupeKey = typeof customKey === 'string' ? customKey : `GET ${url}`;

    if (dedupeAllowed) {
      const existing = state.inflightGets.get(dedupeKey);
      if (existing) {
        log('dedupe-hit', key, { dedupeKey });
        return existing.then((r) => r.clone());
      }
    }

    const run = async (): Promise<Response> => {
      await acquireSlot(key, signal);
      try {
        return await fetchLike(url, init);
      } finally {
        releaseSlot(key);
      }
    };

    if (dedupeAllowed) {
      const promise = run();
      state.inflightGets.set(dedupeKey, promise);
      const forget = (): void => {
        // Identity check: a __reset() plus a newer request may have replaced
        // the entry under the same key.
        if (state.inflightGets.get(dedupeKey) === promise) {
          state.inflightGets.delete(dedupeKey);
        }
      };
      // Use then(onFulfilled, onRejected) instead of finally(): finally()
      // returns a derived promise that re-rejects with the original reason,
      // and since nothing observes that derived promise a failed fetch would
      // surface as an unhandled rejection on top of the caller's own.
      promise.then(forget, forget);
      return promise.then((r) => r.clone());
    }

    return run();
  }) as FetchLike & { __reset: () => void };

  wrapped.__reset = () => {
    state.perRouteQueues.clear();
    state.inflightPerRoute.clear();
    state.inflightGets.clear();
    state.inflightGlobal = 0;
  };

  return wrapped;
}

View File

@@ -11,7 +11,6 @@ import NoFilesIcon from '../assets/no-files.svg';
import SearchIcon from '../assets/search.svg';
import {
useDarkTheme,
useDebouncedValue,
useLoaderState,
useMediaQuery,
useOutsideAlerter,
@@ -131,7 +130,6 @@ const Chunks: React.FC<ChunksProps> = ({
const [totalChunks, setTotalChunks] = useState(0);
const [loading, setLoading] = useLoaderState(true);
const [searchTerm, setSearchTerm] = useState<string>('');
const debouncedSearchTerm = useDebouncedValue(searchTerm, 300);
const [editingChunk, setEditingChunk] = useState<ChunkType | null>(null);
const [editingTitle, setEditingTitle] = useState('');
const [editingText, setEditingText] = useState('');
@@ -153,7 +151,7 @@ const Chunks: React.FC<ChunksProps> = ({
perPage,
token,
path,
debouncedSearchTerm,
searchTerm,
);
if (!response.ok) {
@@ -278,12 +276,16 @@ const Chunks: React.FC<ChunksProps> = ({
};
useEffect(() => {
if (page !== 1) {
setPage(1);
} else {
fetchChunks();
}
}, [debouncedSearchTerm]);
const delayDebounceFn = setTimeout(() => {
if (page !== 1) {
setPage(1);
} else {
fetchChunks();
}
}, 300);
return () => clearTimeout(delayDebounceFn);
}, [searchTerm]);
useEffect(() => {
!loading && fetchChunks();

View File

@@ -1,8 +1,7 @@
import React, { useEffect, useRef } from 'react';
import React, { useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelector } from 'react-redux';
import userService from '../api/services/userService';
import { useDarkTheme } from '../hooks';
import { selectToken } from '../preferences/preferenceSlice';
@@ -32,24 +31,13 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
const [isDarkTheme] = useDarkTheme();
const completedRef = useRef(false);
const intervalRef = useRef<number | null>(null);
const authWindowRef = useRef<Window | null>(null);
// Hold the exact listener identity so unmount cleanup removes the same fn.
const messageHandlerRef = useRef<((event: MessageEvent) => void) | null>(
null,
);
// Tracks mount status so async ``fetch`` resolves after unmount don't
// call ``onSuccess`` / ``onError`` on a vanished parent.
const mountedRef = useRef(true);
const cleanup = () => {
if (intervalRef.current) {
clearInterval(intervalRef.current);
intervalRef.current = null;
}
if (messageHandlerRef.current) {
window.removeEventListener('message', messageHandlerRef.current as any);
messageHandlerRef.current = null;
}
window.removeEventListener('message', handleAuthMessage as any);
};
const handleAuthMessage = (event: MessageEvent) => {
@@ -60,7 +48,6 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
if (successGeneric || successProvider) {
completedRef.current = true;
cleanup();
authWindowRef.current = null;
onSuccess({
session_token: event.data.session_token,
user_email:
@@ -70,7 +57,6 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
} else if (errorProvider) {
completedRef.current = true;
cleanup();
authWindowRef.current = null;
onError(
event.data.error || t('modals.uploadDoc.connectors.auth.authFailed'),
);
@@ -80,20 +66,15 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
const handleAuth = async () => {
try {
completedRef.current = false;
// Close any popup left over from a previous click before wiping
// the ref — otherwise the old window keeps living with no
// interval watching it and no listener handling its messages.
if (authWindowRef.current && !authWindowRef.current.closed) {
authWindowRef.current.close();
}
authWindowRef.current = null;
cleanup();
const authResponse = await userService.getConnectorAuthUrl(
provider,
token,
const apiHost = import.meta.env.VITE_API_HOST;
const authResponse = await fetch(
`${apiHost}/api/connectors/auth?provider=${provider}`,
{
headers: { Authorization: `Bearer ${token}` },
},
);
if (!mountedRef.current) return;
if (!authResponse.ok) {
throw new Error(
@@ -102,7 +83,6 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
}
const authData = await authResponse.json();
if (!mountedRef.current) return;
if (!authData.success || !authData.authorization_url) {
throw new Error(
authData.error || t('modals.uploadDoc.connectors.auth.authUrlFailed'),
@@ -117,23 +97,13 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
if (!authWindow) {
throw new Error(t('modals.uploadDoc.connectors.auth.popupBlocked'));
}
authWindowRef.current = authWindow;
messageHandlerRef.current = handleAuthMessage;
window.addEventListener('message', handleAuthMessage as any);
const checkClosed = window.setInterval(() => {
if (authWindow.closed) {
clearInterval(checkClosed);
intervalRef.current = null;
if (messageHandlerRef.current) {
window.removeEventListener(
'message',
messageHandlerRef.current as any,
);
messageHandlerRef.current = null;
}
authWindowRef.current = null;
window.removeEventListener('message', handleAuthMessage as any);
if (!completedRef.current) {
onError(t('modals.uploadDoc.connectors.auth.authCancelled'));
}
@@ -141,7 +111,6 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
}, 1000);
intervalRef.current = checkClosed;
} catch (error) {
if (!mountedRef.current) return;
onError(
error instanceof Error
? error.message
@@ -150,18 +119,6 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
}
};
// Release interval, message listener, and popup on unmount only.
useEffect(() => {
return () => {
mountedRef.current = false;
cleanup();
if (authWindowRef.current && !authWindowRef.current.closed) {
authWindowRef.current.close();
}
authWindowRef.current = null;
};
}, []);
return (
<>
{errorMessage && (

View File

@@ -1,6 +1,6 @@
import React, { useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelector, useStore } from 'react-redux';
import { useSelector } from 'react-redux';
import userService from '../api/services/userService';
import ArrowLeft from '../assets/arrow-left.svg';
@@ -14,7 +14,6 @@ import { useLoaderState, useOutsideAlerter } from '../hooks';
import ConfirmationModal from '../modals/ConfirmationModal';
import { ActiveState } from '../models/misc';
import { selectToken } from '../preferences/preferenceSlice';
import type { RootState } from '../store';
import { formatBytes } from '../utils/stringUtils';
import Chunks from './Chunks';
import ContextMenu, { MenuOption } from './ContextMenu';
@@ -65,7 +64,6 @@ const ConnectorTree: React.FC<ConnectorTreeProps> = ({
useState<DirectoryStructure | null>(null);
const [currentPath, setCurrentPath] = useState<string[]>([]);
const token = useSelector(selectToken);
const store = useStore<RootState>();
const [activeMenuId, setActiveMenuId] = useState<string | null>(null);
const menuRefs = useRef<{
[key: string]: React.RefObject<HTMLDivElement | null>;
@@ -83,25 +81,6 @@ const ConnectorTree: React.FC<ConnectorTreeProps> = ({
const [syncDone, setSyncDone] = useState<boolean>(false);
const [syncConfirmationModal, setSyncConfirmationModal] =
useState<ActiveState>('INACTIVE');
const mountedRef = useRef(true);
const syncUnsubscribeRef = useRef<(() => void) | null>(null);
// Holds the 5-minute SSE-wait timer so the unmount cleanup can clear
// it — otherwise the timer fires up to 5 min after unmount and
// resolves an abandoned Promise.
const syncTimerRef = useRef<number | null>(null);
useEffect(
() => () => {
mountedRef.current = false;
syncUnsubscribeRef.current?.();
syncUnsubscribeRef.current = null;
if (syncTimerRef.current !== null) {
window.clearTimeout(syncTimerRef.current);
syncTimerRef.current = null;
}
},
[],
);
useOutsideAlerter(
searchDropdownRef,
@@ -137,108 +116,67 @@ const ConnectorTree: React.FC<ConnectorTreeProps> = ({
console.log('Sync started successfully:', data.task_id);
setSyncProgress(10);
// The connector worker (``ingest_connector`` in
// ``application/worker.py``) publishes
// ``source.ingest.{queued,completed,failed}`` envelopes keyed on
// ``scope.id == docId`` (sync mode reuses the source uuid). Wait
// on the bounded ``notifications.recentEvents`` ring for a
// terminal envelope rather than polling ``/task_status``.
// Mirrors FileTree's slice-walking pattern, including the
// ``opStartedAt`` guard so a stale terminal event from a prior
// sync of this same source can't short-circuit the current op.
const opStartedAt = Date.now();
const terminalFromSse = (): 'completed' | 'failed' | null => {
const events = store.getState().notifications.recentEvents;
for (const event of events) {
if (event.scope?.id !== docId) continue;
const ts = event.ts ? Date.parse(event.ts) : NaN;
if (!Number.isFinite(ts) || ts < opStartedAt) continue;
if (event.type === 'source.ingest.completed') return 'completed';
if (event.type === 'source.ingest.failed') return 'failed';
}
return null;
};
const MAX_WAIT_MS = 5 * 60_000;
const terminal = await new Promise<
'completed' | 'failed' | 'timeout' | 'unmounted'
>((resolve) => {
// Cover the race where the event landed between the POST
// returning and the subscribe call.
const initial = terminalFromSse();
if (initial) {
resolve(initial);
return;
}
if (!mountedRef.current) {
resolve('unmounted');
return;
}
let settled = false;
const finish = (
value: 'completed' | 'failed' | 'timeout' | 'unmounted',
) => {
if (settled) return;
settled = true;
if (syncTimerRef.current !== null) {
window.clearTimeout(syncTimerRef.current);
syncTimerRef.current = null;
}
if (syncUnsubscribeRef.current) {
syncUnsubscribeRef.current();
syncUnsubscribeRef.current = null;
}
resolve(value);
};
syncTimerRef.current = window.setTimeout(
() => finish('timeout'),
MAX_WAIT_MS,
);
syncUnsubscribeRef.current = store.subscribe(() => {
if (!mountedRef.current) {
finish('unmounted');
return;
}
const next = terminalFromSse();
if (next) finish(next);
});
});
if (terminal === 'timeout') {
console.error('Sync timed out waiting for SSE terminal');
} else if (terminal === 'unmounted') {
return;
}
if (terminal === 'completed') {
// The "no files downloaded" early-return path publishes
// ``completed`` with ``no_changes: true`` — treated as success
// here; refreshing the directory is cheap and idempotent.
setSyncProgress(100);
console.log('Sync completed successfully');
// Poll task status using userService
const maxAttempts = 30;
const pollInterval = 2000;
for (let attempt = 0; attempt < maxAttempts; attempt++) {
try {
const refreshResponse = await userService.getDirectoryStructure(
docId,
const statusResponse = await userService.getTaskStatus(
data.task_id,
token,
);
const refreshData = await refreshResponse.json();
if (refreshData && refreshData.directory_structure) {
setDirectoryStructure(refreshData.directory_structure);
setCurrentPath([]);
}
if (refreshData && refreshData.provider) {
setSourceProvider(refreshData.provider);
const statusData = await statusResponse.json();
console.log(
`Task status (attempt ${attempt + 1}):`,
statusData.status,
);
if (statusData.status === 'SUCCESS') {
setSyncProgress(100);
console.log('Sync completed successfully');
// Refresh directory structure
try {
const refreshResponse = await userService.getDirectoryStructure(
docId,
token,
);
const refreshData = await refreshResponse.json();
if (refreshData && refreshData.directory_structure) {
setDirectoryStructure(refreshData.directory_structure);
setCurrentPath([]);
}
if (refreshData && refreshData.provider) {
setSourceProvider(refreshData.provider);
}
setSyncDone(true);
setTimeout(() => setSyncDone(false), 5000);
} catch (err) {
console.error('Error refreshing directory structure:', err);
}
break;
} else if (statusData.status === 'FAILURE') {
console.error('Sync task failed:', statusData.result);
break;
} else if (statusData.status === 'PROGRESS') {
const progress = Number(
statusData.result && statusData.result.current != null
? statusData.result.current
: statusData.meta && statusData.meta.current != null
? statusData.meta.current
: 0,
);
setSyncProgress(Math.max(10, progress));
}
setSyncDone(true);
setTimeout(() => setSyncDone(false), 5000);
} catch (err) {
console.error('Error refreshing directory structure:', err);
await new Promise((resolve) => setTimeout(resolve, pollInterval));
} catch (error) {
console.error('Error polling task status:', error);
break;
}
} else if (terminal === 'failed') {
console.error('Sync task failed (per SSE)');
}
} else {
console.error('Sync failed:', data.error);

View File

@@ -1,6 +1,5 @@
import React, { useState, useEffect, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import userService from '../api/services/userService';
import { formatBytes } from '../utils/stringUtils';
import { formatDate } from '../utils/dateTimeUtils';
import {
@@ -23,7 +22,6 @@ import {
TableHeader,
TableCell,
} from './Table';
import { useDebouncedCallback } from '../hooks';
interface CloudFile {
id: string;
@@ -102,6 +100,7 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
const [activeTab, setActiveTab] = useState<'my_files' | 'shared'>('my_files');
const scrollContainerRef = useRef<HTMLDivElement>(null);
const searchTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
const isFolder = (file: CloudFile) => {
@@ -127,6 +126,7 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
setIsLoading(true);
const apiHost = import.meta.env.VITE_API_HOST;
if (!pageToken) {
setFiles([]);
}
@@ -141,11 +141,15 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
search_query: searchQuery,
shared: shared,
};
const response = await userService.getConnectorFiles(
body,
token,
controller.signal,
);
const response = await fetch(`${apiHost}/api/connectors/files`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`,
},
body: JSON.stringify(body),
signal: controller.signal,
});
const data = await response.json();
if (data.success) {
@@ -183,9 +187,20 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
}
try {
const validateResponse = await userService.validateConnectorSession(
provider,
token,
const apiHost = import.meta.env.VITE_API_HOST;
const validateResponse = await fetch(
`${apiHost}/api/connectors/validate-session`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`,
},
body: JSON.stringify({
provider: provider,
session_token: sessionToken,
}),
},
);
if (!validateResponse.ok) {
@@ -277,26 +292,32 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
useEffect(() => {
return () => {
if (searchTimeoutRef.current) {
clearTimeout(searchTimeoutRef.current);
}
abortControllerRef.current?.abort();
};
}, []);
const debouncedLoadFiles = useDebouncedCallback((query: string) => {
const sessionToken = getSessionToken(provider);
if (sessionToken) {
loadCloudFiles(
sessionToken,
currentFolderId,
undefined,
query,
activeTab === 'shared' && !currentFolderId,
);
}
}, 300);
const handleSearchChange = (query: string) => {
setSearchQuery(query);
debouncedLoadFiles(query);
if (searchTimeoutRef.current) {
clearTimeout(searchTimeoutRef.current);
}
searchTimeoutRef.current = setTimeout(() => {
const sessionToken = getSessionToken(provider);
if (sessionToken) {
loadCloudFiles(
sessionToken,
currentFolderId,
undefined,
query,
activeTab === 'shared' && !currentFolderId,
);
}
}, 300);
};
const handleFolderClick = (folderId: string, folderName: string) => {
@@ -403,14 +424,23 @@ export const FilePicker: React.FC<CloudFilePickerProps> = ({
onDisconnect={() => {
const sessionToken = getSessionToken(provider);
if (sessionToken) {
userService
.disconnectConnector(provider, sessionToken, token)
.catch((err) =>
console.error(
`Error disconnecting from ${getProviderConfig(provider).displayName}:`,
err,
),
);
const apiHost = import.meta.env.VITE_API_HOST;
fetch(`${apiHost}/api/connectors/disconnect`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`,
},
body: JSON.stringify({
provider: provider,
session_token: sessionToken,
}),
}).catch((err) =>
console.error(
`Error disconnecting from ${getProviderConfig(provider).displayName}:`,
err,
),
);
}
removeSessionToken(provider);

View File

@@ -1,8 +1,7 @@
import React, { useState, useRef, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelector, useStore } from 'react-redux';
import { useSelector } from 'react-redux';
import { selectToken } from '../preferences/preferenceSlice';
import type { RootState } from '../store';
import { formatBytes } from '../utils/stringUtils';
import Chunks from './Chunks';
import ContextMenu, { MenuOption } from './ContextMenu';
@@ -57,7 +56,6 @@ const FileTree: React.FC<FileTreeProps> = ({
onBackToDocuments,
}) => {
const { t } = useTranslation();
const store = useStore<RootState>();
const [loading, setLoading] = useLoaderState(true, 500);
const [error, setError] = useState<string | null>(null);
const [directoryStructure, setDirectoryStructure] =
@@ -97,25 +95,6 @@ const FileTree: React.FC<FileTreeProps> = ({
const opQueueRef = useRef<QueuedOperation[]>([]);
const processingRef = useRef(false);
const [queueLength, setQueueLength] = useState(0);
const mountedRef = useRef(true);
const waitUnsubscribeRef = useRef<(() => void) | null>(null);
// Holds the 5-minute SSE-wait timer so the unmount cleanup can clear
// it — otherwise the timer fires up to 5 min after unmount and
// resolves an abandoned Promise.
const waitTimerRef = useRef<number | null>(null);
useEffect(
() => () => {
mountedRef.current = false;
waitUnsubscribeRef.current?.();
waitUnsubscribeRef.current = null;
if (waitTimerRef.current !== null) {
window.clearTimeout(waitTimerRef.current);
waitTimerRef.current = null;
}
},
[],
);
useOutsideAlerter(
searchDropdownRef,
@@ -334,103 +313,47 @@ const FileTree: React.FC<FileTreeProps> = ({
}
console.log('Reingest task started:', result.reingest_task_id);
// SSE is the sole driver here. The backend's
// ``reingest_source_worker`` publishes ``source.ingest.*``
// keyed on the resolved ``source_id`` (the
// ``manage_source_files`` route returns it explicitly so we
// can match without consulting any slice). Subscribe to the
// store and resolve when a terminal event tagged with our
// source lands in ``notifications.recentEvents``. Re-checking
// on every dispatch (rather than polling on a timer) avoids
// races where a terminal could roll off the bounded ring
// before the next tick observes it in chatty sessions.
const reingestSourceId: string | undefined = result.source_id;
// Cutoff so we don't pick up terminal events from a *previous*
// reingest of the same source — the backend's
// ``source.ingest.*`` payload doesn't carry a Celery task id,
// so source_id alone is ambiguous when ops repeat.
const opStartedAt = Date.now();
const MAX_WAIT_MS = 5 * 60_000;
const maxAttempts = 30;
const pollInterval = 2000;
const terminalFromSse = (): 'completed' | 'failed' | null => {
if (!reingestSourceId) return null;
const events = store.getState().notifications.recentEvents;
for (const event of events) {
if (event.scope?.id !== reingestSourceId) continue;
const ts = event.ts ? Date.parse(event.ts) : NaN;
if (!Number.isFinite(ts) || ts < opStartedAt) continue;
if (event.type === 'source.ingest.completed') return 'completed';
if (event.type === 'source.ingest.failed') return 'failed';
}
return null;
};
for (let attempt = 0; attempt < maxAttempts; attempt++) {
try {
const statusResponse = await userService.getTaskStatus(
result.reingest_task_id,
token,
);
const statusData = await statusResponse.json();
const refreshStructure = async (): Promise<boolean> => {
const structureResponse = await userService.getDirectoryStructure(
docId,
token,
);
const structureData = await structureResponse.json();
if (!mountedRef.current) return false;
if (structureData && structureData.directory_structure) {
setDirectoryStructure(structureData.directory_structure);
currentOpRef.current = null;
return true;
}
return false;
};
console.log(
`Task status (attempt ${attempt + 1}):`,
statusData.status,
);
const terminal = await new Promise<
'completed' | 'failed' | 'timeout' | 'unmounted'
>((resolve) => {
if (!mountedRef.current) {
resolve('unmounted');
return;
}
// Cover the race where the terminal event landed between
// the POST returning and the subscribe call.
const initial = terminalFromSse();
if (initial) {
resolve(initial);
return;
}
const timer = window.setTimeout(() => {
waitUnsubscribeRef.current?.();
waitUnsubscribeRef.current = null;
waitTimerRef.current = null;
resolve('timeout');
}, MAX_WAIT_MS);
waitTimerRef.current = timer;
waitUnsubscribeRef.current = store.subscribe(() => {
if (!mountedRef.current) {
window.clearTimeout(timer);
waitTimerRef.current = null;
waitUnsubscribeRef.current?.();
waitUnsubscribeRef.current = null;
resolve('unmounted');
return;
if (statusData.status === 'SUCCESS') {
console.log('Task completed successfully');
const structureResponse = await userService.getDirectoryStructure(
docId,
token,
);
const structureData = await structureResponse.json();
if (structureData && structureData.directory_structure) {
setDirectoryStructure(structureData.directory_structure);
currentOpRef.current = null;
return true;
}
break;
} else if (statusData.status === 'FAILURE') {
console.error('Task failed');
break;
}
const next = terminalFromSse();
if (next) {
window.clearTimeout(timer);
waitTimerRef.current = null;
waitUnsubscribeRef.current?.();
waitUnsubscribeRef.current = null;
resolve(next);
}
});
});
if (!mountedRef.current) return false;
if (terminal === 'completed') {
if (await refreshStructure()) return true;
} else if (terminal === 'failed') {
console.error('Reingest task failed (per SSE)');
} else if (terminal === 'unmounted') {
return false;
} else {
console.error('Reingest timed out waiting for SSE terminal');
await new Promise((resolve) => setTimeout(resolve, pollInterval));
} catch (error) {
console.error('Error polling task status:', error);
break;
}
}
} else {
throw new Error(
@@ -451,7 +374,7 @@ const FileTree: React.FC<FileTreeProps> = ({
? 'delete directory'
: 'delete file(s)';
console.error(`Error ${actionText}:`, error);
if (mountedRef.current) setError(`Failed to ${errorText}`);
setError(`Failed to ${errorText}`);
} finally {
currentOpRef.current = null;
}

View File

@@ -2,7 +2,6 @@ import React, { useState, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import useDrivePicker from 'react-google-drive-picker';
import userService from '../api/services/userService';
import ConnectorAuth from './ConnectorAuth';
import {
getSessionToken,
@@ -200,11 +199,18 @@ const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
const sessionToken = getSessionToken('google_drive');
if (sessionToken) {
try {
await userService.disconnectConnector(
'google_drive',
sessionToken,
token,
);
const apiHost = import.meta.env.VITE_API_HOST;
await fetch(`${apiHost}/api/connectors/disconnect`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`,
},
body: JSON.stringify({
provider: 'google_drive',
session_token: sessionToken,
}),
});
} catch (err) {
console.error('Error disconnecting from Google Drive:', err);
}

View File

@@ -3,7 +3,7 @@ import { createPortal } from 'react-dom';
import { LoaderCircle, Mic, Square } from 'lucide-react';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { useDispatch, useSelector, useStore } from 'react-redux';
import { useDispatch, useSelector } from 'react-redux';
import endpoints from '../api/endpoints';
import userService from '../api/services/userService';
@@ -28,7 +28,6 @@ import {
selectSelectedDocs,
selectToken,
} from '../preferences/preferenceSlice';
import type { RootState } from '../store';
import Upload from '../upload/Upload';
import { getOS, isTouchDevice } from '../utils/browserUtils';
import SourcesPopup from './SourcesPopup';
@@ -317,7 +316,6 @@ export default function MessageInput({
const attachments = useSelector(selectAttachments);
const dispatch = useDispatch();
const store = useStore<RootState>();
const mediaStreamRef = useRef<MediaStream | null>(null);
const audioContextRef = useRef<AudioContext | null>(null);
const audioSourceNodeRef = useRef<MediaStreamAudioSourceNode | null>(null);
@@ -412,86 +410,6 @@ export default function MessageInput({
};
}, []);
// Recover the race where attachment.* SSE arrives before the upload
// XHR's onload sets ``attachmentId``: walk recentEvents and watchdog
// the row so it can't stay stuck on 'processing'. Mirrors
// Upload.tsx's ``trackTraining``.
const trackAttachment = useCallback(
(clientId: string, attachmentId: string) => {
let handled = false;
const check = () => {
const state = store.getState();
const row = state.upload.attachments.find((a) => a.id === clientId);
if (!row) return true; // removed by user; stop tracking
if (row.status === 'completed' || row.status === 'failed') {
handled = true;
return true;
}
for (const event of state.notifications.recentEvents) {
if (event.scope?.id !== attachmentId) continue;
if (event.type === 'attachment.completed') {
const payload = (event.payload || {}) as Record<string, unknown>;
const tokenCount = Number(payload.token_count);
handled = true;
dispatch(
updateAttachment({
id: clientId,
updates: {
status: 'completed',
progress: 100,
...(Number.isFinite(tokenCount)
? { token_count: tokenCount }
: {}),
},
}),
);
return true;
}
if (event.type === 'attachment.failed') {
handled = true;
dispatch(
updateAttachment({
id: clientId,
updates: { status: 'failed' },
}),
);
return true;
}
}
return false;
};
if (check()) return;
const MAX_WAIT_MS = 5 * 60_000;
let unsubscribe: (() => void) | null = null;
const timer = window.setTimeout(() => {
unsubscribe?.();
if (!handled) {
handled = true;
console.warn(
'trackAttachment: timed out waiting for terminal SSE',
clientId,
attachmentId,
);
dispatch(
updateAttachment({
id: clientId,
updates: { status: 'failed' },
}),
);
}
}, MAX_WAIT_MS);
unsubscribe = store.subscribe(() => {
if (check()) {
window.clearTimeout(timer);
unsubscribe?.();
}
});
},
[dispatch, store],
);
const uploadFiles = useCallback(
(files: File[]) => {
if (!files || files.length === 0) return;
@@ -592,19 +510,11 @@ export default function MessageInput({
id: uiId,
updates: {
taskId: task.task_id,
// Stash the server's attachment id so SSE
// ``attachment.*`` events can match this
// row by ``scope.id`` and drive the
// per-attachment push-fresh poll gate.
attachmentId: task.attachment_id,
status: 'processing',
progress: 10,
},
}),
);
if (task.attachment_id) {
trackAttachment(uiId, task.attachment_id);
}
return;
}
@@ -635,15 +545,11 @@ export default function MessageInput({
id: uiId,
updates: {
taskId: t.task_id,
attachmentId: t.attachment_id,
status: 'processing',
progress: 10,
},
}),
);
if (t.attachment_id) {
trackAttachment(uiId, t.attachment_id);
}
} else {
dispatch(
updateAttachment({
@@ -677,15 +583,11 @@ export default function MessageInput({
id: uiId,
updates: {
taskId: response.task_id,
attachmentId: response.attachment_id,
status: 'processing',
progress: 10,
},
}),
);
if (response.attachment_id) {
trackAttachment(uiId, response.attachment_id);
}
}
} else {
console.warn(
@@ -812,15 +714,11 @@ export default function MessageInput({
id: uniqueId,
updates: {
taskId: response.task_id,
attachmentId: response.attachment_id,
status: 'processing',
progress: 10,
},
}),
);
if (response.attachment_id) {
trackAttachment(uniqueId, response.attachment_id);
}
} else {
// If backend returned tasks[] for single-file, handle gracefully:
if (
@@ -832,15 +730,11 @@ export default function MessageInput({
id: uniqueId,
updates: {
taskId: response.tasks[0].task_id,
attachmentId: response.tasks[0].attachment_id,
status: 'processing',
progress: 10,
},
}),
);
if (response.tasks[0].attachment_id) {
trackAttachment(uniqueId, response.tasks[0].attachment_id);
}
} else {
dispatch(
updateAttachment({
@@ -887,7 +781,7 @@ export default function MessageInput({
xhr.send(formData);
});
},
[dispatch, token, trackAttachment],
[dispatch, token],
);
const handleFileAttachment = (e: React.ChangeEvent<HTMLInputElement>) => {
@@ -922,6 +816,65 @@ export default function MessageInput({
accept: FILE_UPLOAD_ACCEPT,
});
// Every 2 seconds, ask the backend for the task status of each attachment
// that is still ``processing`` and has a ``taskId``, and mirror the answer
// into the store. The timer is torn down and re-created whenever
// ``attachments`` or ``dispatch`` changes.
useEffect(() => {
  const pollAttachment = (attachment: (typeof attachments)[number]) => {
    const markFailed = () =>
      dispatch(
        updateAttachment({
          id: attachment.id,
          updates: { status: 'failed' },
        }),
      );

    userService
      .getTaskStatus(attachment.taskId!, null)
      .then((response) => response.json())
      .then((task) => {
        switch (task.status) {
          case 'SUCCESS':
            dispatch(
              updateAttachment({
                id: attachment.id,
                updates: {
                  status: 'completed',
                  progress: 100,
                  id: task.result?.attachment_id,
                  token_count: task.result?.token_count,
                },
              }),
            );
            break;
          case 'FAILURE':
            markFailed();
            break;
          case 'PROGRESS':
            // A PROGRESS report without a truthy ``current`` is ignored.
            if (task.result?.current) {
              dispatch(
                updateAttachment({
                  id: attachment.id,
                  updates: { progress: task.result.current },
                }),
              );
            }
            break;
        }
      })
      .catch(() => {
        // Any rejection in the chain (request, JSON parsing, handler)
        // marks the attachment as failed.
        markFailed();
      });
  };

  const timer = setInterval(() => {
    const stillProcessing = attachments.filter(
      (att) => att.status === 'processing' && att.taskId,
    );
    for (const attachment of stillProcessing) {
      pollAttachment(attachment);
    }
  }, 2000);

  return () => clearInterval(timer);
}, [attachments, dispatch]);
const handleInput = useCallback(() => {
if (inputRef.current) {
if (window.innerWidth < 350) inputRef.current.style.height = 'auto';

View File

@@ -2,7 +2,8 @@ import { useState, useRef, useEffect } from 'react';
import Speaker from '../assets/speaker.svg?react';
import Stopspeech from '../assets/stopspeech.svg?react';
import LoadingIcon from '../assets/Loading.svg?react'; // Add a loading icon SVG here
import userService from '../api/services/userService';
const apiHost = import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
let currentlyPlayingAudio: {
audio: HTMLAudioElement;
@@ -113,11 +114,12 @@ export default function SpeakButton({ text }: { text: string }) {
},
};
const response = await userService.textToSpeech(
text,
null,
abortController.signal,
);
const response = await fetch(apiHost + '/api/tts', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ text }),
signal: abortController.signal,
});
const data = await response.json();
abortControllerRef.current = null;

View File

@@ -5,54 +5,41 @@ import { useDispatch, useSelector } from 'react-redux';
import CheckCircleFilled from '../assets/check-circle-filled.svg';
import ChevronDown from '../assets/chevron-down.svg';
import WarnIcon from '../assets/warn.svg';
import {
dismissUploadTask,
selectUploadTasks,
type UploadTask,
} from '../upload/uploadSlice';
import { dismissUploadTask, selectUploadTasks } from '../upload/uploadSlice';
const PROGRESS_RADIUS = 10;
const PROGRESS_CIRCUMFERENCE = 2 * Math.PI * PROGRESS_RADIUS;
const IN_PROGRESS_STATUSES = new Set<UploadTask['status']>([
'preparing',
'uploading',
'training',
]);
/**
* Single merged upload card — Google-Drive style. Multiple in-flight
* uploads share one toast with a list of rows; the header reflects
* the *primary* task's status (the newest still-running task, or the
* newest task overall if all are terminal). Per-task progress lives
* on each row.
*
* Dismissal: the header X dismisses every visible task at once
* (mirrors the GDrive panel close — keeps the surface tidy without
* per-row controls). The chevron collapses the row list.
*/
export default function UploadToast() {
const [collapsed, setCollapsed] = useState(false);
const [collapsedTasks, setCollapsedTasks] = useState<Record<string, boolean>>(
{},
);
const toggleTaskCollapse = (taskId: string) => {
setCollapsedTasks((prev) => ({
...prev,
[taskId]: !prev[taskId],
}));
};
const { t } = useTranslation();
const dispatch = useDispatch();
const uploadTasks = useSelector(selectUploadTasks);
const visibleTasks = uploadTasks.filter((task) => !task.dismissed);
if (visibleTasks.length === 0) return null;
// Pick the task that drives the header status: prefer a still-
// running task (most-recent first since the slice unshifts), and
// fall back to whatever's most-recent if everything is terminal.
const primaryTask =
visibleTasks.find((task) => IN_PROGRESS_STATUSES.has(task.status)) ??
visibleTasks[0];
const headerLabel = getStatusHeading(primaryTask.status, t);
const dismissAll = () => {
for (const task of visibleTasks) {
dispatch(dismissUploadTask(task.id));
const getStatusHeading = (status: string) => {
switch (status) {
case 'preparing':
return t('modals.uploadDoc.progress.wait');
case 'uploading':
return t('modals.uploadDoc.progress.upload');
case 'training':
return t('modals.uploadDoc.progress.upload');
case 'completed':
return t('modals.uploadDoc.progress.completed');
case 'failed':
return t('modals.uploadDoc.progress.failed');
default:
return t('modals.uploadDoc.progress.preparing');
}
};
@@ -60,212 +47,180 @@ export default function UploadToast() {
<div
className="fixed right-4 bottom-4 z-50 flex max-w-md flex-col gap-2"
onMouseDown={(e) => e.stopPropagation()}
role="status"
aria-live="polite"
aria-atomic="false"
>
<div
className={`border-border bg-card w-[271px] overflow-hidden rounded-2xl border shadow-[0px_24px_48px_0px_#00000029] transition-all duration-300`}
>
<div
className={`flex items-center justify-between px-4 py-3 ${
primaryTask.status !== 'failed'
? 'bg-accent/50 dark:bg-muted'
: 'bg-destructive/10 dark:bg-destructive/10'
}`}
>
<h3 className="font-inter dark:text-foreground text-[14px] leading-[16.5px] font-medium text-black">
{headerLabel}
</h3>
<div className="flex items-center gap-1">
<button
type="button"
onClick={() => setCollapsed((prev) => !prev)}
aria-label={
collapsed
? t('modals.uploadDoc.progress.expandDetails')
: t('modals.uploadDoc.progress.collapseDetails')
}
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
>
<img
src={ChevronDown}
alt=""
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${
collapsed ? 'rotate-180' : ''
}`}
/>
</button>
<button
type="button"
onClick={dismissAll}
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
aria-label={t('modals.uploadDoc.progress.dismiss')}
>
<svg
width="16"
height="16"
viewBox="0 0 24 24"
fill="none"
xmlns="http://www.w3.org/2000/svg"
className="h-4 w-4"
>
<path
d="M18 6L6 18"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M6 6L18 18"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
</button>
</div>
</div>
{uploadTasks
.filter((task) => !task.dismissed)
.map((task) => {
const shouldShowProgress = [
'preparing',
'uploading',
'training',
].includes(task.status);
const rawProgress = Math.min(Math.max(task.progress ?? 0, 0), 100);
const formattedProgress = Math.round(rawProgress);
const progressOffset =
PROGRESS_CIRCUMFERENCE * (1 - rawProgress / 100);
const isCollapsed = collapsedTasks[task.id] ?? false;
<div
className="grid overflow-hidden transition-[grid-template-rows] duration-300 ease-out"
style={{ gridTemplateRows: collapsed ? '0fr' : '1fr' }}
>
<div
className={`min-h-0 overflow-hidden transition-opacity duration-300 ${
collapsed ? 'opacity-0' : 'opacity-100'
}`}
>
<ul className="max-h-72 overflow-y-auto">
{visibleTasks.map((task) => (
<UploadRow key={task.id} task={task} t={t} />
))}
</ul>
</div>
</div>
</div>
return (
<div
key={task.id}
className={`border-border bg-card w-[271px] overflow-hidden rounded-2xl border shadow-[0px_24px_48px_0px_#00000029] transition-all duration-300`}
>
<div className="flex flex-col">
<div
className={`flex items-center justify-between px-4 py-3 ${
task.status !== 'failed'
? 'bg-accent/50 dark:bg-muted'
: 'bg-destructive/10 dark:bg-destructive/10'
}`}
>
<h3 className="font-inter dark:text-foreground text-[14px] leading-[16.5px] font-medium text-black">
{getStatusHeading(task.status)}
</h3>
<div className="flex items-center gap-1">
<button
type="button"
onClick={() => toggleTaskCollapse(task.id)}
aria-label={
isCollapsed
? t('modals.uploadDoc.progress.expandDetails')
: t('modals.uploadDoc.progress.collapseDetails')
}
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
>
<img
src={ChevronDown}
alt=""
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${
isCollapsed ? 'rotate-180' : ''
}`}
/>
</button>
<button
type="button"
onClick={() => dispatch(dismissUploadTask(task.id))}
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
aria-label={t('modals.uploadDoc.progress.dismiss')}
>
<svg
width="16"
height="16"
viewBox="0 0 24 24"
fill="none"
xmlns="http://www.w3.org/2000/svg"
className="h-4 w-4"
>
<path
d="M18 6L6 18"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M6 6L18 18"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
</button>
</div>
</div>
<div
className="grid overflow-hidden transition-[grid-template-rows] duration-300 ease-out"
style={{ gridTemplateRows: isCollapsed ? '0fr' : '1fr' }}
>
<div
className={`min-h-0 overflow-hidden transition-opacity duration-300 ${
isCollapsed ? 'opacity-0' : 'opacity-100'
}`}
>
<div className="flex items-center justify-between px-5 py-3">
<p
className="font-inter dark:text-muted-foreground max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black"
title={task.fileName}
>
{task.fileName}
</p>
<div className="flex items-center gap-2">
{shouldShowProgress && (
<svg
width="24"
height="24"
viewBox="0 0 24 24"
className="h-6 w-6 shrink-0 text-[#7D54D1]"
role="progressbar"
aria-valuemin={0}
aria-valuemax={100}
aria-valuenow={formattedProgress}
aria-label={t(
'modals.uploadDoc.progress.uploadProgress',
{
progress: formattedProgress,
},
)}
>
<circle
className="text-muted dark:text-muted-foreground/30"
stroke="currentColor"
strokeWidth="2"
cx="12"
cy="12"
r={PROGRESS_RADIUS}
fill="none"
/>
<circle
className="text-[#7D54D1]"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeDasharray={PROGRESS_CIRCUMFERENCE}
strokeDashoffset={progressOffset}
cx="12"
cy="12"
r={PROGRESS_RADIUS}
fill="none"
transform="rotate(-90 12 12)"
/>
</svg>
)}
{task.status === 'completed' && (
<img
src={CheckCircleFilled}
alt=""
className="h-6 w-6 shrink-0"
aria-hidden="true"
/>
)}
{task.status === 'failed' && (
<img
src={WarnIcon}
alt=""
className="h-6 w-6 shrink-0"
aria-hidden="true"
/>
)}
</div>
</div>
{task.status === 'failed' && task.errorMessage && (
<span className="block px-5 pb-3 text-xs text-red-500">
{task.errorMessage}
</span>
)}
</div>
</div>
</div>
</div>
);
})}
</div>
);
}
/**
 * Render one upload task as a list row: file name, an optional stage
 * caption while the task is ``training``, a circular progress ring while
 * the task is in one of ``IN_PROGRESS_STATUSES``, a terminal icon for
 * ``completed`` / ``failed``, and a failure detail line when one exists.
 *
 * @param task - The upload task to display.
 * @param t - Translation function used for the captions and aria label.
 */
function UploadRow({
  task,
  t,
}: {
  task: UploadTask;
  t: ReturnType<typeof useTranslation>['t'];
}) {
  const isRunning = IN_PROGRESS_STATUSES.has(task.status);
  // Clamp to [0, 100]; a missing progress value counts as 0.
  const clampedProgress = Math.max(0, Math.min(100, task.progress ?? 0));
  const percent = Math.round(clampedProgress);
  const dashOffset = PROGRESS_CIRCUMFERENCE * (1 - clampedProgress / 100);

  const progressRing = isRunning && (
    <svg
      width="24"
      height="24"
      viewBox="0 0 24 24"
      className="h-6 w-6 shrink-0 text-[#7D54D1]"
      role="progressbar"
      aria-valuemin={0}
      aria-valuemax={100}
      aria-valuenow={percent}
      aria-label={t('modals.uploadDoc.progress.uploadProgress', {
        progress: percent,
      })}
    >
      <circle
        className="text-muted dark:text-muted-foreground/30"
        stroke="currentColor"
        strokeWidth="2"
        cx="12"
        cy="12"
        r={PROGRESS_RADIUS}
        fill="none"
      />
      <circle
        className="text-[#7D54D1]"
        stroke="currentColor"
        strokeWidth="2"
        strokeLinecap="round"
        strokeDasharray={PROGRESS_CIRCUMFERENCE}
        strokeDashoffset={dashOffset}
        cx="12"
        cy="12"
        r={PROGRESS_RADIUS}
        fill="none"
        transform="rotate(-90 12 12)"
      />
    </svg>
  );

  const completedIcon = task.status === 'completed' && (
    <img
      src={CheckCircleFilled}
      alt=""
      className="h-6 w-6 shrink-0"
      aria-hidden="true"
    />
  );

  const failedIcon = task.status === 'failed' && (
    <img
      src={WarnIcon}
      alt=""
      className="h-6 w-6 shrink-0"
      aria-hidden="true"
    />
  );

  return (
    <li className="border-border/50 border-b last:border-b-0">
      <div className="flex items-center justify-between px-5 py-3">
        <div className="flex min-w-0 flex-col">
          <p
            className="font-inter dark:text-muted-foreground max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black"
            title={task.fileName}
          >
            {task.fileName}
          </p>
          {task.status === 'training' && task.stage && (
            <span className="font-inter text-muted-foreground mt-0.5 text-[11px] leading-[14px]">
              {t(`modals.uploadDoc.progress.${task.stage}`)}
            </span>
          )}
        </div>
        <div className="flex items-center gap-2">
          {progressRing}
          {completedIcon}
          {failedIcon}
        </div>
      </div>
      {/* The token-limit notice takes precedence over the raw error message. */}
      {task.status === 'failed' &&
        (task.tokenLimitReached || task.errorMessage) && (
          <span className="block px-5 pb-3 text-xs text-red-500">
            {task.tokenLimitReached
              ? t('modals.uploadDoc.progress.tokenLimit')
              : task.errorMessage}
          </span>
        )}
    </li>
  );
}
/**
 * Map an upload task status to its translated header label.
 *
 * ``uploading`` and ``training`` share one label; any status not listed
 * here falls back to the ``preparing`` label.
 *
 * @param status - Status of the task that drives the header.
 * @param t - Translation function.
 * @returns The translated heading text.
 */
function getStatusHeading(
  status: UploadTask['status'],
  t: ReturnType<typeof useTranslation>['t'],
): string {
  if (status === 'preparing') {
    return t('modals.uploadDoc.progress.wait');
  }
  if (status === 'uploading' || status === 'training') {
    return t('modals.uploadDoc.progress.upload');
  }
  if (status === 'completed') {
    return t('modals.uploadDoc.progress.completed');
  }
  if (status === 'failed') {
    return t('modals.uploadDoc.progress.failed');
  }
  return t('modals.uploadDoc.progress.preparing');
}

View File

@@ -322,7 +322,7 @@ export default function Conversation() {
isSplitArtifactOpen ? 'w-[60%] px-6' : 'w-full'
}`}
>
<div className="relative min-h-0 flex-1">
<div className="relative min-h-0 flex-1 ">
<ConversationMessages
handleQuestion={handleQuestion}
handleQuestionSubmission={handleQuestionSubmission}

View File

@@ -132,8 +132,6 @@ const ConversationBubble = forwardRef<
}, [message]);
const handleEditClick = () => {
if (!editInputBox.trim() || editInputBox.trim() === (message ?? '').trim())
return;
setIsEditClicked(false);
handleUpdatedQuestionSubmission?.(editInputBox, true, questionNumber);
};
@@ -154,10 +152,10 @@ const ConversationBubble = forwardRef<
<img
src={DocumentationDark}
alt="Attachment"
className="h-3.75 w-3.75 object-fill"
className="h-[15px] w-[15px] object-fill"
/>
</div>
<span className="max-w-37.5 truncate font-normal">
<span className="max-w-[150px] truncate font-normal">
{file.fileName}
</span>
</div>
@@ -244,12 +242,8 @@ const ConversationBubble = forwardRef<
{t('conversation.edit.cancel')}
</button>
<button
className="bg-primary not-disabled:hover:bg-primary/90 not-disabled:dark:hover:bg-primary/90 disabled:bg-primary/30 rounded-full px-4 py-2 text-sm font-medium text-white transition-colors disabled:cursor-not-allowed"
className="bg-primary hover:bg-primary/90 dark:hover:bg-primary/90 rounded-full px-4 py-2 text-sm font-medium text-white transition-colors"
onClick={handleEditClick}
disabled={
!editInputBox.trim() ||
editInputBox.trim() === (message ?? '').trim()
}
>
{t('conversation.edit.update')}
</button>
@@ -328,7 +322,7 @@ const ConversationBubble = forwardRef<
<div className="mb-4 flex flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
className="h-6.5 w-7.5 text-xl"
className="h-[26px] w-[30px] text-xl"
avatar={
<img
src={Sources}
@@ -376,7 +370,7 @@ const ConversationBubble = forwardRef<
<img
src={Document}
alt="Document"
className="h-4.25 w-4.25 object-fill"
className="h-[17px] w-[17px] object-fill"
/>
<p
className="mt-0.5 truncate text-xs"
@@ -394,11 +388,11 @@ const ConversationBubble = forwardRef<
</div>
{activeTooltip === index && (
<div
className={`dark:bg-card dark:text-foreground absolute left-1/2 z-50 max-h-48 w-40 translate-x-[-50%] translate-y-0.75 rounded-xl bg-[#FBFBFB] p-4 text-black shadow-xl sm:w-56`}
className={`dark:bg-card dark:text-foreground absolute left-1/2 z-50 max-h-48 w-40 translate-x-[-50%] translate-y-[3px] rounded-xl bg-[#FBFBFB] p-4 text-black shadow-xl sm:w-56`}
onMouseOver={() => setActiveTooltip(index)}
onMouseOut={() => setActiveTooltip(null)}
>
<p className="line-clamp-6 max-h-41 overflow-hidden rounded-md text-sm wrap-break-word text-ellipsis">
<p className="line-clamp-6 max-h-[164px] overflow-hidden rounded-md text-sm wrap-break-word text-ellipsis">
{source.text}
</p>
</div>
@@ -471,7 +465,7 @@ const ConversationBubble = forwardRef<
<div className="flex max-w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
className="h-8.5 w-8.5 text-2xl"
className="h-[34px] w-[34px] text-2xl"
avatar={
<img
src={DocsGPT3}
@@ -1023,7 +1017,7 @@ function ToolCalls({
);
return (
<div className="mb-4 relative flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="mb-4 flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
{/* Approval bars — always visible, compact inline */}
{awaitingCalls.length > 0 && (
<div className="fade-in mt-4 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
@@ -1042,7 +1036,7 @@ function ToolCalls({
<>
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
className="h-6.5 w-7.5 text-xl"
className="h-[26px] w-[30px] text-xl"
avatar={
<img
src={Sources}
@@ -1089,7 +1083,7 @@ function ToolCalls({
</p>
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="dark:text-muted-foreground leading-5.75 text-black"
className="dark:text-muted-foreground leading-[23px] text-black"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{JSON.stringify(toolCall.arguments, null, 2)}
@@ -1117,7 +1111,7 @@ function ToolCalls({
{toolCall.status === 'completed' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="dark:text-muted-foreground leading-5.75 text-black"
className="dark:text-muted-foreground leading-[23px] text-black"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{JSON.stringify(toolCall.result, null, 2)}
@@ -1127,7 +1121,7 @@ function ToolCalls({
{toolCall.status === 'error' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="text-destructive leading-5.75"
className="text-destructive leading-[23px]"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{toolCall.error}
@@ -1137,7 +1131,7 @@ function ToolCalls({
{toolCall.status === 'denied' && (
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
<span
className="text-muted-foreground leading-5.75"
className="text-muted-foreground leading-[23px]"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
Denied by user
@@ -1172,7 +1166,7 @@ function Thought({
<div className="mb-4 flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
<div className="my-2 flex flex-row items-center justify-center gap-3">
<Avatar
className="h-6.5 w-7.5 text-xl"
className="h-[26px] w-[30px] text-xl"
avatar={
<img
src={Cloud}
@@ -1197,7 +1191,7 @@ function Thought({
</div>
{isThoughtOpen && (
<div className="fade-in mr-5 ml-2 max-w-[90vw] md:max-w-[70vw] lg:max-w-[50vw]">
<div className="bg-muted dark:bg-answer-bubble rounded-[28px] px-7 py-4.5">
<div className="bg-muted dark:bg-answer-bubble rounded-[28px] px-7 py-[18px]">
<ReactMarkdown
className="fade-in leading-normal wrap-break-word whitespace-pre-wrap"
remarkPlugins={[remarkGfm, remarkMath]}

View File

@@ -64,10 +64,7 @@ export default function ConversationTile({
}
function handleSaveConversation(changedConversation: ConversationProps) {
if (
changedConversation.name.trim().length &&
changedConversation.name.trim() !== conversation.name.trim()
) {
if (changedConversation.name.trim().length) {
onSave(changedConversation);
setIsEdit(false);
} else {

View File

@@ -1,155 +1,8 @@
import { baseURL } from '../api/client';
import conversationService from '../api/services/conversationService';
import { Doc } from '../models/misc';
import { Answer, FEEDBACK, RetrievalPayload } from './conversationModels';
import { ToolCallsType } from './types';
/**
* Mirrors the backend's ``_SEQUENCE_NO_RE`` (application/api/answer/
* routes/messages.py) — only non-negative decimal integers are valid
* cursors. Rejects empty strings (Number("") === 0), hex literals,
* exponential notation, and anything else that ``Number(...)`` would
* happily coerce.
*/
const _SEQUENCE_NO_RE = /^\d+$/;
/**
 * Note on ``_drainSseBody`` (defined further down in this module):
 * it drains an SSE response body, forwarding each ``data:`` payload to
 * ``onData`` and reporting the most recent valid ``id:`` value through
 * ``onId``. It returns when the body ends, the signal aborts, or
 * ``shouldStop()`` returns true (e.g. a terminal ``end``/``error``
 * event was dispatched — the reconnect endpoint is a live tail that
 * doesn't close on its own past terminal replay).
 */
/**
 * Convert a non-SSE pre-stream HTTP failure (e.g. ``check_usage``'s
 * 429 JSON response) into a synthetic typed ``error`` frame so the
 * caller's slice sees the actual server message instead of the
 * generic "Connection lost" synthesised when the drainer finishes
 * with zero events.
 *
 * SSE-shaped error bodies (``text/event-stream``) are left alone — the
 * drainer parses the typed ``error`` frame they carry through the
 * normal path.
 *
 * Message selection, in order:
 *   1. the first non-blank string among the JSON body's ``message``,
 *      ``error`` and ``detail`` fields;
 *   2. for a body that is not valid JSON, its trimmed text capped at
 *      500 characters;
 *   3. otherwise ``HTTP <status> <statusText>``.
 * Blank (empty or whitespace-only) candidates are skipped so the user
 * never sees an empty error.
 *
 * @param response - The response returned by the initial request.
 * @param dispatch - Receives the JSON-encoded ``{type: 'error', error}`` frame.
 * @returns ``true`` if a frame was dispatched (the caller should skip
 *   draining the body), ``false`` if the response is OK or SSE-shaped.
 */
async function _handlePreStreamHttpError(
  response: Response,
  dispatch: (data: string) => void,
): Promise<boolean> {
  if (response.ok) return false;
  const contentType = (
    response.headers.get('content-type') ?? ''
  ).toLowerCase();
  if (contentType.includes('text/event-stream')) return false;

  let message: string | null = null;
  try {
    // Trim first: a body of only whitespace carries no message and must
    // fall through to the status-line fallback.
    const text = (await response.text()).trim();
    if (text) {
      try {
        const parsed = JSON.parse(text);
        if (parsed && typeof parsed === 'object') {
          for (const key of ['message', 'error', 'detail']) {
            const candidate = (parsed as Record<string, unknown>)[key];
            if (typeof candidate === 'string' && candidate.trim()) {
              message = candidate;
              break;
            }
          }
        }
      } catch {
        // Not JSON — surface the raw text, bounded in length.
        message = text.slice(0, 500);
      }
    }
  } catch {
    // Body already consumed or unreadable — fall through to the
    // status-line fallback below.
  }
  if (!message) {
    message = `HTTP ${response.status} ${response.statusText}`.trim();
  }
  dispatch(JSON.stringify({ type: 'error', error: message }));
  return true;
}
/**
 * Read an SSE response body and deliver its records to the callbacks.
 *
 * Records are separated by a blank line. Within a record, lines that
 * start with ``:`` are skipped, ``id`` values matching
 * ``_SEQUENCE_NO_RE`` are parsed as base-10 integers and passed to
 * ``onId``, and ``data`` values are joined with ``\n``, trimmed and —
 * when non-empty — passed to ``onData``. A partial record still in the
 * buffer when the body ends is dropped.
 *
 * Line terminators (LF, CRLF, CR) are normalised to LF. A CRLF pair
 * that straddles two network chunks is treated as ONE terminator;
 * normalising each chunk in isolation would turn it into two LFs and
 * fabricate a record boundary in the middle of an event.
 *
 * @param body - The response body stream to read.
 * @param signal - Checked before every read; the loop exits once aborted.
 * @param onData - Called with each non-empty ``data`` payload.
 * @param onId - Called with each valid ``id`` value.
 * @param shouldStop - Optional predicate checked before every read and
 *   after every ``onData`` call; when it returns true the reader is
 *   cancelled and the function returns.
 */
async function _drainSseBody(
  body: ReadableStream<Uint8Array>,
  signal: AbortSignal,
  onData: (data: string) => void,
  onId: (id: number) => void,
  shouldStop?: () => boolean,
): Promise<void> {
  const reader = body.getReader();
  const decoder = new TextDecoder('utf-8');
  let buffer = '';
  let stoppedEarly = false;
  // True when the previous non-empty decoded chunk ended with a CR that
  // has already been written to ``buffer`` as an LF.
  let prevChunkEndedWithCr = false;

  try {
    while (true) {
      if (signal.aborted) break;
      if (shouldStop?.()) {
        stoppedEarly = true;
        break;
      }
      const { done, value } = await reader.read();
      if (done) break;

      let chunk = decoder.decode(value, { stream: true });
      if (chunk.length > 0) {
        // Second half of a CRLF whose CR arrived in the previous chunk:
        // that CR already produced the LF, so drop this one.
        if (prevChunkEndedWithCr && chunk.startsWith('\n')) {
          chunk = chunk.slice(1);
        }
        prevChunkEndedWithCr = chunk.endsWith('\r');
      }
      // Normalise mixed line terminators so a stray CR can't smuggle
      // a record boundary inside a JSON payload.
      buffer += chunk.replace(/\r\n|\r/g, '\n');

      let boundary = buffer.indexOf('\n\n');
      while (boundary !== -1) {
        const record = buffer.slice(0, boundary);
        buffer = buffer.slice(boundary + 2);
        boundary = buffer.indexOf('\n\n');
        if (record.length === 0) continue;

        const dataParts: string[] = [];
        let sawDataField = false;
        for (const line of record.split('\n')) {
          if (line.length === 0) continue;
          if (line.startsWith(':')) continue; // SSE comment / keepalive
          const colonIdx = line.indexOf(':');
          const field = colonIdx === -1 ? line : line.slice(0, colonIdx);
          let fieldValue = colonIdx === -1 ? '' : line.slice(colonIdx + 1);
          // A single leading space after the colon is not part of the value.
          if (fieldValue.startsWith(' ')) fieldValue = fieldValue.slice(1);
          if (field === 'id') {
            // Strict regex match — empty value, hex, ``-1`` (the
            // backend's terminal snapshot-failure synthetic), and
            // exponent forms are all rejected so they can't silently
            // rewrite the reconnect cursor.
            if (_SEQUENCE_NO_RE.test(fieldValue)) {
              onId(parseInt(fieldValue, 10));
            }
          } else if (field === 'data') {
            sawDataField = true;
            dataParts.push(fieldValue);
          }
        }
        if (!sawDataField) continue;
        const data = dataParts.join('\n').trim();
        if (data.length === 0) continue;
        onData(data);
        if (shouldStop?.()) {
          stoppedEarly = true;
          break;
        }
      }
      if (stoppedEarly) break;
    }
  } finally {
    if (stoppedEarly) {
      // Ask the runtime to tear the underlying response body down so
      // the server-side WSGI thread isn't pinned waiting on
      // keepalives. ``releaseLock`` alone leaves the body half-open.
      try {
        await reader.cancel();
      } catch {
        // Already errored / closed.
      }
    }
    try {
      reader.releaseLock();
    } catch {
      // Already released.
    }
  }
}
export function handleFetchAnswer(
question: string,
signal: AbortSignal,
@@ -290,153 +143,54 @@ export function handleFetchAnswerSteaming(
const headers: Record<string, string> = {};
if (idempotencyKey) headers['Idempotency-Key'] = idempotencyKey;
// Per-stream state used for reconnect-after-disconnect.
let messageId: string | null = null;
let lastEventId: number | null = null;
// The single JSON.parse below feeds both the message_id capture and
// the termination flag — cheaper and stricter than substring
// matching the wire bytes.
let endReceived = false;
const dispatch = (data: string) => {
try {
const parsed = JSON.parse(data);
if (parsed && typeof parsed === 'object') {
if (parsed.type === 'message_id' && parsed.message_id) {
messageId = parsed.message_id;
} else if (parsed.type === 'end' || parsed.type === 'error') {
endReceived = true;
}
}
} catch {
// Not JSON — pass through anyway; the caller handles raw lines.
}
onEvent(new MessageEvent('message', { data }));
};
const runInitialPost = async (): Promise<void> => {
const response = await conversationService.answerStream(
payload,
token,
signal,
headers,
);
// Pre-stream HTTP failures with non-SSE bodies (e.g. ``check_usage``
// returning a JSON 429) drain as zero events and would otherwise
// be masked by the generic "Connection lost" synthetic. Convert
// them into a typed ``error`` frame so the real message surfaces.
if (await _handlePreStreamHttpError(response, dispatch)) return;
if (!response.body) throw new Error('No response body');
await _drainSseBody(response.body, signal, dispatch, (id) => {
lastEventId = id;
});
};
// Reconnect's stop predicate: as soon as ``dispatch`` flips
// ``endReceived`` (typed ``end`` or ``error`` event seen — both
// are terminal per the backend's contract). Without this the
// live-tail endpoint would emit keepalives indefinitely and the
// await would never return.
const reconnectShouldStop = () => endReceived;
const runReconnect = async (): Promise<void> => {
if (!messageId) {
throw new Error('reconnect: no message_id captured');
}
const url = new URL(`${baseURL}/api/messages/${messageId}/events`);
if (lastEventId !== null) {
url.searchParams.set('last_event_id', String(lastEventId));
}
const reconnectHeaders: Record<string, string> = {
Accept: 'text/event-stream',
};
if (token) reconnectHeaders.Authorization = `Bearer ${token}`;
// NB: there is no slice consumer for a synthetic ``reconnecting``
// event yet — surface only the underlying network reality. The
// user-visible ``Reconnecting…`` affordance is a follow-up that
// needs ``conversationSlice`` to gain a status case.
const response = await fetch(url.toString(), {
method: 'GET',
headers: reconnectHeaders,
signal,
cache: 'no-store',
});
if (!response.ok || !response.body) {
throw new Error(
`reconnect: HTTP ${response.status} ${response.statusText}`,
);
}
await _drainSseBody(
response.body,
signal,
dispatch,
(id) => {
lastEventId = id;
},
reconnectShouldStop,
);
};
return new Promise<Answer>((resolve, reject) => {
(async () => {
try {
try {
await runInitialPost();
} catch (initialErr) {
// Mid-stream network failures (WiFi blip, worker recycle,
// body reader rejecting) surface as a thrown error — not a
// graceful EOF. If the stream had already started (we have a
// ``messageId``), fall through to the reconnect path so the
// journal-backed replay can finish what the live socket
// couldn't. Pre-stream failures (auth, DNS, server 4xx/5xx
// before any yield) lack a messageId and bubble up.
if (signal.aborted || !messageId) throw initialErr;
console.warn(
'Initial stream failed mid-flight, attempting reconnect:',
initialErr,
);
}
// The backend ends the stream cleanly with a typed ``end``
// event. Anything else (network drop, gunicorn worker recycle,
// load-balancer timeout) is a "premature close" — try one
// reconnect via the GET /api/messages/<id>/events endpoint.
if (!endReceived && !signal.aborted && messageId) {
try {
await runReconnect();
} catch (reconnectErr) {
console.warn('Stream reconnect failed:', reconnectErr);
conversationService
.answerStream(payload, token, signal, headers)
.then((response) => {
if (!response.body) throw Error('No response body');
let buffer = '';
const reader = response.body.getReader();
const decoder = new TextDecoder('utf-8');
let counterrr = 0;
const processStream = ({
done,
value,
}: ReadableStreamReadResult<Uint8Array>) => {
if (done) return;
counterrr += 1;
const chunk = decoder.decode(value);
buffer += chunk;
const events = buffer.split('\n\n');
buffer = events.pop() ?? '';
for (const event of events) {
if (event.trim().startsWith('data:')) {
const dataLine: string = event
.split('\n')
.map((line: string) => line.replace(/^data:\s?/, ''))
.join('');
const messageEvent = new MessageEvent('message', {
data: dataLine.trim(),
});
onEvent(messageEvent);
}
}
}
// If we never observed a terminal frame (reconnect 4xx/5xx,
// network drop during reconnect, or live tail still silent),
// synthesize one through the same ``dispatch`` path the wire
// events use. Without this the caller's slice never transitions
// out of ``streaming`` and the UI stays in a loading spinner
// forever — the conversationSlice handles ``data.type === 'error'``
// by setting status=failed.
if (!endReceived && !signal.aborted) {
dispatch(
JSON.stringify({
type: 'error',
error:
'Connection lost. The response could not be resumed; please try again.',
}),
);
}
// The handler historically never explicitly resolved with a
// value — callers consume the streamed events via ``onEvent``
// and read final state from Redux. Preserve that contract.
resolve(undefined as unknown as Answer);
} catch (error) {
if (signal.aborted) {
resolve(undefined as unknown as Answer);
return;
}
reader.read().then(processStream).catch(reject);
};
reader.read().then(processStream).catch(reject);
})
.catch((error) => {
console.error('Connection failed:', error);
reject(error);
}
})();
});
});
}
@@ -460,149 +214,52 @@ export function handleSubmitToolActions(
const headers: Record<string, string> = {};
if (idempotencyKey) headers['Idempotency-Key'] = idempotencyKey;
// Tool-action submissions resume against the original
// ``reserved_message_id``, so the backend's continuation path emits
// ``id:`` prefixed records that the legacy parser would silently
// drop. Use the shared SSE drainer — and the same reconnect-on-
// premature-close pattern as ``handleFetchAnswerSteaming`` so a
// dropped tool-action stream can pick up after the disconnect.
let messageId: string | null = null;
let lastEventId: number | null = null;
// Track whether the typed ``end`` event was observed. The single
// JSON.parse below feeds both the message_id capture and the
// termination flag — cheaper and stricter than substring matching
// the wire bytes.
let endReceived = false;
const dispatch = (data: string) => {
try {
const parsed = JSON.parse(data);
if (parsed && typeof parsed === 'object') {
if (parsed.type === 'message_id' && parsed.message_id) {
messageId = parsed.message_id;
} else if (parsed.type === 'end' || parsed.type === 'error') {
// Match the backend's terminal set in
// ``application/streaming/event_replay.py``: the agent's
// catch-all path emits ``error`` *without* a trailing
// ``end``, so treating only ``end`` as terminal would
// trigger a reconnect against an already-finished stream
// and hang on keepalives.
endReceived = true;
}
}
} catch {
// Not JSON — pass through anyway; the caller handles raw lines.
}
onEvent(new MessageEvent('message', { data }));
};
const runInitial = async (): Promise<void> => {
const response = await conversationService.answerStream(
payload,
token,
signal,
headers,
);
// See ``handleFetchAnswerSteaming`` for the rationale: non-SSE
// HTTP failures (e.g. ``check_usage`` 429 JSON) need to be lifted
// into a typed ``error`` frame before they reach the drainer.
if (await _handlePreStreamHttpError(response, dispatch)) return;
if (!response.body) throw new Error('No response body');
await _drainSseBody(response.body, signal, dispatch, (id) => {
lastEventId = id;
});
};
// Reconnect's stop predicate: as soon as ``dispatch`` flips
// ``endReceived`` (typed ``end`` or ``error`` event seen — both
// are terminal per the backend's contract). Without this the
// live-tail endpoint would emit keepalives indefinitely and the
// await would never return.
const reconnectShouldStop = () => endReceived;
const runReconnect = async (): Promise<void> => {
if (!messageId) {
throw new Error('reconnect: no message_id captured');
}
const url = new URL(`${baseURL}/api/messages/${messageId}/events`);
if (lastEventId !== null) {
url.searchParams.set('last_event_id', String(lastEventId));
}
const reconnectHeaders: Record<string, string> = {
Accept: 'text/event-stream',
};
if (token) reconnectHeaders.Authorization = `Bearer ${token}`;
const response = await fetch(url.toString(), {
method: 'GET',
headers: reconnectHeaders,
signal,
cache: 'no-store',
});
if (!response.ok || !response.body) {
throw new Error(
`reconnect: HTTP ${response.status} ${response.statusText}`,
);
}
await _drainSseBody(
response.body,
signal,
dispatch,
(id) => {
lastEventId = id;
},
reconnectShouldStop,
);
};
return new Promise<Answer>((resolve, reject) => {
(async () => {
try {
try {
await runInitial();
} catch (initialErr) {
// Same premature-close handling as
// ``handleFetchAnswerSteaming``: a thrown reader error after
// the message_id frame still warrants one reconnect attempt
// against the journal. Pre-stream failures lack a messageId
// and bubble up.
if (signal.aborted || !messageId) throw initialErr;
console.warn(
'Tool-actions stream failed mid-flight, attempting reconnect:',
initialErr,
);
}
if (!endReceived && !signal.aborted && messageId) {
try {
await runReconnect();
} catch (reconnectErr) {
console.warn('Tool-actions reconnect failed:', reconnectErr);
conversationService
.answerStream(payload, token, signal, headers)
.then((response) => {
if (!response.body) throw Error('No response body');
let buffer = '';
const reader = response.body.getReader();
const decoder = new TextDecoder('utf-8');
const processStream = ({
done,
value,
}: ReadableStreamReadResult<Uint8Array>) => {
if (done) return;
const chunk = decoder.decode(value);
buffer += chunk;
const events = buffer.split('\n\n');
buffer = events.pop() ?? '';
for (const event of events) {
if (event.trim().startsWith('data:')) {
const dataLine: string = event
.split('\n')
.map((line: string) => line.replace(/^data:\s?/, ''))
.join('');
const messageEvent = new MessageEvent('message', {
data: dataLine.trim(),
});
onEvent(messageEvent);
}
}
}
// Synthesize a terminal error if reconnect couldn't deliver one
// (4xx/5xx, network drop, silent live tail). Same reasoning as
// ``handleFetchAnswerSteaming``: the caller's slice only exits
// the streaming state on a terminal frame.
if (!endReceived && !signal.aborted) {
dispatch(
JSON.stringify({
type: 'error',
error:
'Connection lost. The tool response could not be resumed; please try again.',
}),
);
}
resolve(undefined as unknown as Answer);
} catch (error) {
if (signal.aborted) {
resolve(undefined as unknown as Answer);
return;
}
reader.read().then(processStream).catch(reject);
};
reader.read().then(processStream).catch(reject);
})
.catch((error) => {
console.error('Tool actions submission failed:', error);
reject(error);
}
})();
});
});
}

View File

@@ -1,153 +0,0 @@
import { describe, expect, it } from 'vitest';
import reducer, {
applyMessageTail,
setConversation,
} from './conversationSlice';
// Single pending query shared by every test in this file.
const baseQuery = {
  prompt: 'tell me a poem',
  messageId: 'm-1',
  messageStatus: 'pending' as const,
};

// Produce a fresh slice state that contains only ``baseQuery``.
const seedSlice = () => {
  const seedAction = setConversation([baseQuery]);
  return reducer(undefined, seedAction);
};
// Reducer-level checks for ``applyMessageTail``: each case seeds the slice
// with one pending query, applies one or two tail payloads to index 0 and
// inspects the resulting query.
describe('applyMessageTail — streaming partial', () => {
  it('writes response to the query while status is streaming', () => {
    const seeded = seedSlice();
    const tailAction = applyMessageTail({
      index: 0,
      tail: {
        message_id: 'm-1',
        status: 'streaming',
        response: 'Hello, par',
        thought: null,
        sources: [],
        tool_calls: [],
      },
    });
    const [query] = reducer(seeded, tailAction).queries;
    expect(query.messageStatus).toBe('streaming');
    expect(query.response).toBe('Hello, par');
  });

  it('updates response on each successive tail tick', () => {
    const afterFirstTick = reducer(
      seedSlice(),
      applyMessageTail({
        index: 0,
        tail: {
          message_id: 'm-1',
          status: 'streaming',
          response: 'Hello',
          sources: [],
          tool_calls: [],
        },
      }),
    );
    const afterSecondTick = reducer(
      afterFirstTick,
      applyMessageTail({
        index: 0,
        tail: {
          message_id: 'm-1',
          status: 'streaming',
          response: 'Hello, world',
          sources: [],
          tool_calls: [],
        },
      }),
    );
    expect(afterSecondTick.queries[0].response).toBe('Hello, world');
  });

  it('applies sources and tool_calls when they appear mid-stream', () => {
    const seeded = seedSlice();
    const tailAction = applyMessageTail({
      index: 0,
      tail: {
        message_id: 'm-1',
        status: 'streaming',
        response: 'partial',
        sources: [{ id: 's1', title: 'doc' }],
        tool_calls: [{ name: 'search' }],
      },
    });
    const [query] = reducer(seeded, tailAction).queries;
    expect(query.sources).toEqual([{ id: 's1', title: 'doc' }]);
    expect(query.tool_calls).toEqual([{ name: 'search' }]);
  });

  it('ignores empty sources / tool_calls arrays so the renderer stays clean', () => {
    const seeded = seedSlice();
    const tailAction = applyMessageTail({
      index: 0,
      tail: {
        message_id: 'm-1',
        status: 'streaming',
        response: 'partial',
        sources: [],
        tool_calls: [],
      },
    });
    const [query] = reducer(seeded, tailAction).queries;
    expect(query.sources).toBeUndefined();
    expect(query.tool_calls).toBeUndefined();
  });

  it('promotes to complete with the final response and clears any error', () => {
    // First tick leaves the message mid-stream ...
    const whileStreaming = reducer(
      seedSlice(),
      applyMessageTail({
        index: 0,
        tail: {
          message_id: 'm-1',
          status: 'streaming',
          response: 'partial',
        },
      }),
    );
    // ... second tick delivers the terminal ``complete`` payload.
    const onceComplete = reducer(
      whileStreaming,
      applyMessageTail({
        index: 0,
        tail: {
          message_id: 'm-1',
          status: 'complete',
          response: 'Final answer.',
        },
      }),
    );
    const [query] = onceComplete.queries;
    expect(query.messageStatus).toBe('complete');
    expect(query.response).toBe('Final answer.');
    expect(query.error).toBeUndefined();
  });

  it('surfaces failed status as error and clears response', () => {
    const seeded = seedSlice();
    const tailAction = applyMessageTail({
      index: 0,
      tail: {
        message_id: 'm-1',
        status: 'failed',
        response: 'whatever',
        error: 'worker died',
      },
    });
    const [query] = reducer(seeded, tailAction).queries;
    expect(query.messageStatus).toBe('failed');
    expect(query.error).toBe('worker died');
    expect(query.response).toBeUndefined();
  });
});

View File

@@ -957,34 +957,20 @@ export const conversationSlice = createSlice({
const status = tail?.status as MessageStatus | undefined;
query.messageStatus = status;
query.lastHeartbeatAt = tail?.last_heartbeat_at ?? query.lastHeartbeatAt;
if (status === 'failed') {
if (status === 'complete') {
query.response = tail?.response ?? '';
query.thought = tail?.thought ?? query.thought;
query.sources = tail?.sources ?? query.sources;
query.tool_calls = tail?.tool_calls ?? query.tool_calls;
delete query.error;
} else if (status === 'failed') {
// Surface as error so the placeholder text never renders.
query.error =
(typeof tail?.error === 'string' && tail.error) ||
'Generation failed before completing.';
delete query.response;
return;
}
// /tail returns reconstructed partials mid-stream so a second tab
// can render the in-flight bubble; spinner is driven by status.
const incomingResponse = tail?.response;
if (typeof incomingResponse === 'string') {
query.response = incomingResponse;
} else if (status === 'complete') {
query.response = '';
}
if (typeof tail?.thought === 'string') {
query.thought = tail.thought;
}
if (Array.isArray(tail?.sources) && tail.sources.length > 0) {
query.sources = tail.sources;
}
if (Array.isArray(tail?.tool_calls) && tail.tool_calls.length > 0) {
query.tool_calls = tail.tool_calls;
}
if (status === 'complete') {
delete query.error;
}
// pending / streaming: untouched; spinner keeps showing.
},
raiseError(
state,

View File

@@ -1,18 +0,0 @@
import React from 'react';
import { useEventStream } from './useEventStream';
/**
 * Provider whose only job is to call ``useEventStream`` once per mount
 * and render its children unchanged.
 *
 * Intended placement (per the original author): inside ``AuthWrapper``
 * so a token is available, wrapping the authenticated-app subtree so
 * the SSE connection lives for the user's whole session.
 */
export function EventStreamProvider(props: {
  children: React.ReactNode;
}): React.ReactElement {
  useEventStream();
  const { children } = props;
  return <React.Fragment>{children}</React.Fragment>;
}

View File

@@ -1,49 +0,0 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import type { AppDispatch } from '../store';
import {
sseEventReceived,
sseLastEventIdReset,
} from '../notifications/notificationsSlice';
import { dispatchSSEEvent } from './dispatchEvent';
// Unit tests for ``dispatchSSEEvent``: ``console.debug`` is silenced and
// spied on for every case so the unknown-type logging can be asserted.
describe('dispatchSSEEvent', () => {
  let debugSpy: ReturnType<typeof vi.spyOn>;

  beforeEach(() => {
    debugSpy = vi.spyOn(console, 'debug').mockImplementation(() => undefined);
  });

  afterEach(() => {
    debugSpy.mockRestore();
  });

  it('dispatches sseLastEventIdReset AND sseEventReceived for backlog.truncated', () => {
    const dispatchMock = vi.fn();
    const envelope = { type: 'backlog.truncated' as const };

    dispatchSSEEvent(envelope, dispatchMock as unknown as AppDispatch);

    // Order matters: the cursor reset is dispatched before the envelope.
    const recorded: unknown[][] = dispatchMock.mock.calls;
    expect(recorded).toHaveLength(2);
    expect(recorded[0][0]).toEqual(sseLastEventIdReset());
    expect(recorded[1][0]).toEqual(sseEventReceived(envelope));
  });

  it('does not log a debug line for known envelope types', () => {
    const dispatchMock = vi.fn();
    dispatchSSEEvent(
      { id: 'e-1', type: 'source.ingest.progress' },
      dispatchMock as unknown as AppDispatch,
    );
    expect(debugSpy).not.toHaveBeenCalled();
  });

  it('logs a debug line for unknown envelope types', () => {
    const dispatchMock = vi.fn();
    dispatchSSEEvent(
      { id: 'e-2', type: 'mystery.event' },
      dispatchMock as unknown as AppDispatch,
    );
    expect(debugSpy).toHaveBeenCalledTimes(1);
    const [firstLogArgs] = debugSpy.mock.calls;
    expect(firstLogArgs).toEqual([
      '[dispatchSSEEvent] unknown envelope type',
      'mystery.event',
    ]);
  });
});

View File

@@ -1,58 +0,0 @@
import type { AppDispatch } from '../store';
import {
sseEventReceived,
sseLastEventIdReset,
type SSEEvent,
} from '../notifications/notificationsSlice';
// Envelope types this build recognises. Anything outside this set gets a
// single ``console.debug`` line from ``dispatchSSEEvent`` so a type the
// frontend has not been taught yet is visible in devtools, while known
// high-volume traffic stays quiet.
const KNOWN_TYPES: ReadonlySet<string> = new Set<string>([
  // Replay-window signalling.
  'backlog.truncated',
  // Source ingestion lifecycle.
  'source.ingest.queued',
  'source.ingest.progress',
  'source.ingest.completed',
  'source.ingest.failed',
  // Attachment processing lifecycle.
  'attachment.queued',
  'attachment.progress',
  'attachment.completed',
  'attachment.failed',
  // MCP OAuth flow.
  'mcp.oauth.awaiting_redirect',
  'mcp.oauth.in_progress',
  'mcp.oauth.completed',
  'mcp.oauth.failed',
  // Tool approval prompts.
  'tool.approval.required',
]);
/**
 * Single fan-out point for inbound SSE envelopes.
 *
 * Steps, in order:
 *  1. Emit one ``console.debug`` line when the envelope type is not in
 *     ``KNOWN_TYPES``.
 *  2. Run the centralised side effect for the type, if it has one.
 *     Today that is only ``backlog.truncated``, which dispatches
 *     ``sseLastEventIdReset`` so the next reconnect does not try to
 *     resume past the retained backlog window.
 *  3. Always dispatch ``sseEventReceived`` so any slice can react via
 *     ``extraReducers`` (and refetch if it needs full-state freshness).
 *
 * @param envelope Parsed SSE envelope to fan out.
 * @param dispatch Redux dispatch of the app store.
 */
export function dispatchSSEEvent(
  envelope: SSEEvent,
  dispatch: AppDispatch,
): void {
  const { type } = envelope;

  if (!KNOWN_TYPES.has(type)) {
    console.debug('[dispatchSSEEvent] unknown envelope type', type);
  }

  // Every other type has no central side effect and relies on
  // slice-level listeners only.
  if (type === 'backlog.truncated') {
    dispatch(sseLastEventIdReset());
  }

  dispatch(sseEventReceived(envelope));
}

View File

@@ -1,386 +0,0 @@
import { baseURL } from '../api/client';
import type { SSEEvent } from '../notifications/notificationsSlice';
/**
* Connection state surfaced to the consumer through ``onHealthChange``.
* Per the original author this maps directly to the ``PushHealth``
* machine in ``notificationsSlice``.
*/
export type EventStreamHealth = 'connecting' | 'healthy' | 'unhealthy';
/** Options accepted by ``connectEventStream``. */
export interface EventStreamOptions {
/**
* Bearer token. When ``null`` the ``Authorization`` header is simply
* omitted and the connection is still attempted; if the backend
* requires auth the non-2xx response flips health to ``unhealthy``.
*/
token: string | null;
/**
* Lazy getter for the current ``Last-Event-ID``. Called once at the
* top of each connect attempt so token rotations / remounts read
* the freshest cursor from Redux instead of a stale mount-time
* snapshot. Return ``null`` for a fresh connect (no cursor is sent).
*/
getLastEventId: () => string | null;
/** Called for every envelope that parsed as JSON and has a string ``type``. */
onEvent: (event: SSEEvent) => void;
/** Called on every health transition until the connection is closed. */
onHealthChange: (health: EventStreamHealth) => void;
/** Called with the most recently received id so the caller can persist it. */
onLastEventId?: (id: string) => void;
/**
* Called when the server emitted an ``id:`` line with an empty value
* (WHATWG SSE cursor reset). Distinct from ``onLastEventId('')`` so
* callers can dispatch ``sseLastEventIdReset`` without overloading
* the normal advance path.
*/
onLastEventIdReset?: () => void;
/**
* Invoked once after ``MAX_CONSECUTIVE_401`` back-to-back 401s. The
* reconnect loop then bails out, so the caller is responsible for
* refreshing the token / signalling logout. The bail-out exists so an
* expired token does not spin forever at the 30s backoff cap.
*/
onAuthFailure?: () => void;
/**
* Invoked once when the reconnect loop bails out after
* ``MAX_CONSECUTIVE_ERRORS`` consecutive failures (HTTP errors — 401s
* included in the count — stream closes and network errors). Lets the
* caller surface a warning instead of the connection silently going dark.
*/
onPermanentFailure?: () => void;
}
/** Opaque handle returned by ``connectEventStream``. */
export interface EventStreamConnection {
/** Abort the in-flight request and stop the reconnect loop; safe to call repeatedly. */
close(): void;
}
/**
* Backoff schedule (ms) for reconnect attempts, indexed by the attempt
* counter and clamped to the last entry. Capped at 30s so a long
* outage doesn't push retries past Cloudflare's typical 100s idle-close
* envelope. The attempt counter resets to 0 once a stream is marked
* healthy: after ``HEALTHY_DEBOUNCE_MS`` of an open response body or on
* the first valid envelope, whichever comes first.
*/
const BACKOFF_SCHEDULE_MS = [0, 1_000, 2_000, 4_000, 8_000, 16_000, 30_000];
/** How long (ms) an open response body must survive before it is marked healthy. */
const HEALTHY_DEBOUNCE_MS = 2_000;
/**
* Reconnect ceilings. Without these, the ``while (!closed)`` loop spins
* forever on a persistently-failing endpoint — expired token (401s) or
* sustained server outage (5xx). The 401 counter resets on any non-401
* response; the general error counter resets when a stream is marked
* healthy. Untested (no frontend test harness); behaviour verified
* by manual trace of the loop in ``connectEventStream``.
*/
const MAX_CONSECUTIVE_401 = 3;
const MAX_CONSECUTIVE_ERRORS = 20;
/**
 * Spread ``delayMs`` by a random offset of up to ±20% so several tabs
 * reconnecting in lockstep stagger. Non-positive delays return 0; the
 * result is rounded to an integer and never negative.
 */
function withJitter(delayMs: number): number {
  if (delayMs <= 0) {
    return 0;
  }
  const maxOffset = delayMs * 0.2;
  // Uniform in [-1, 1).
  const signedUnit = Math.random() * 2 - 1;
  const jittered = Math.round(delayMs + signedUnit * maxOffset);
  return Math.max(0, jittered);
}
/**
 * Open a long-lived SSE connection to ``GET /api/events`` with
 * fetch-streaming semantics that mirror ``conversationHandlers.ts``.
 *
 * Returns immediately with an opaque handle; the connection lives in a
 * background async loop until ``close()`` is called, or until the loop
 * gives up after ``MAX_CONSECUTIVE_401`` consecutive 401s
 * (``onAuthFailure``) or ``MAX_CONSECUTIVE_ERRORS`` consecutive
 * failures (``onPermanentFailure``).
 *
 * The ``Last-Event-ID`` cursor rides on the URL (``?last_event_id=...``)
 * rather than as a header so the request stays a CORS-simple GET — a
 * custom header would force a preflight OPTIONS that the production
 * cross-origin deployment isn't allowlisted for.
 *
 * @param opts Token, cursor getter and callbacks; see ``EventStreamOptions``.
 * @returns Handle whose ``close()`` aborts the request and stops the loop.
 */
export function connectEventStream(
  opts: EventStreamOptions,
): EventStreamConnection {
  const controller = new AbortController();
  let closed = false;
  let attempt = 0;
  let consecutive401 = 0;
  let consecutiveErrors = 0;
  // Cursor sent on the next connect attempt. It is re-read from the
  // caller (``opts.getLastEventId``) at the top of every attempt, so the
  // store is the source of truth; the in-stream updates below only
  // mirror what was also reported through ``onLastEventId`` /
  // ``onLastEventIdReset``.
  let lastEventId: string | null = opts.getLastEventId();

  const notifyHealth = (h: EventStreamHealth) => {
    if (closed) return;
    opts.onHealthChange(h);
  };

  void (async () => {
    while (!closed) {
      const baseDelay =
        BACKOFF_SCHEDULE_MS[Math.min(attempt, BACKOFF_SCHEDULE_MS.length - 1)];
      const delay = withJitter(baseDelay);
      if (delay > 0) {
        try {
          await sleep(delay, controller.signal);
        } catch {
          return; // aborted while waiting
        }
        if (closed) return;
      }
      notifyHealth('connecting');
      // Always re-read the store cursor before reconnecting and copy
      // it verbatim — including null. A null cursor isn't "leave
      // alone": ``backlog.truncated`` events fire ``sseLastEventIdReset``
      // to clear the slice, and the client must respect that on the
      // next attempt by sending no cursor (full-backlog replay) instead
      // of resending the stale one and re-tripping the same truncation.
      lastEventId = opts.getLastEventId();
      const url = new URL(`${baseURL}/api/events`);
      if (lastEventId) url.searchParams.set('last_event_id', lastEventId);
      // Auth header is omitted when token is null. Self-hosted dev
      // installs run with ``AUTH_TYPE`` unset; the backend maps those
      // requests to ``{"sub": "local"}`` so the SSE connection works
      // tokenless. When auth IS required, a missing header surfaces
      // as a 401 and the response.ok check below flips the health
      // back to unhealthy.
      const headers: Record<string, string> = {
        Accept: 'text/event-stream',
      };
      if (opts.token) {
        headers.Authorization = `Bearer ${opts.token}`;
      }
      try {
        const response = await fetch(url.toString(), {
          method: 'GET',
          headers,
          signal: controller.signal,
          // SSE must not be cached.
          cache: 'no-store',
        });
        if (!response.ok || !response.body) {
          // Discard the unread error body. Leaving it unconsumed keeps
          // the underlying connection pinned for the whole backoff
          // window; a rejected cancel is irrelevant here.
          void response.body?.cancel().catch(() => undefined);
          notifyHealth('unhealthy');
          // 401 typically means token expired. Bail out after N
          // consecutive 401s so the loop doesn't spin forever at the
          // 30s backoff cap with a stale token; the caller is
          // responsible for refreshing auth via ``onAuthFailure``.
          if (response.status === 401) {
            consecutive401 += 1;
            consecutiveErrors += 1;
            if (consecutive401 >= MAX_CONSECUTIVE_401) {
              opts.onAuthFailure?.();
              return;
            }
          } else {
            consecutive401 = 0;
            consecutiveErrors += 1;
          }
          if (consecutiveErrors >= MAX_CONSECUTIVE_ERRORS) {
            opts.onPermanentFailure?.();
            return;
          }
          // 429: server-side per-user concurrency cap; backoff harder.
          if (response.status === 429) attempt = Math.max(attempt, 4);
          else attempt = Math.min(attempt + 1, BACKOFF_SCHEDULE_MS.length - 1);
          continue;
        }
        consecutive401 = 0;
        // Connection is open. Mark healthy on whichever happens first:
        //   - ``HEALTHY_DEBOUNCE_MS`` of open response body (covers
        //     servers that go silent after the 200), or
        //   - the first valid envelope.
        // Marking healthy also resets the backoff and error budget.
        let healthyMarked = false;
        const markHealthy = () => {
          if (healthyMarked) return;
          healthyMarked = true;
          notifyHealth('healthy');
          attempt = 0;
          consecutiveErrors = 0;
        };
        const debounceTimer = setTimeout(markHealthy, HEALTHY_DEBOUNCE_MS);
        try {
          await readSSEStream(response.body, controller.signal, (record) => {
            if (record.id !== undefined) {
              lastEventId = record.id || null;
              if (record.id) opts.onLastEventId?.(record.id);
              else opts.onLastEventIdReset?.();
            }
            if (record.data === undefined) {
              // Id-only frame. The cursor was already reported via
              // ``onLastEventId`` above so the slice tracks ids even on
              // frames we don't dispatch as events.
              return;
            }
            // Empty data line is technically valid SSE but useless; skip.
            if (record.data.trim().length === 0) return;
            let envelope: SSEEvent | null = null;
            try {
              envelope = JSON.parse(record.data) as SSEEvent;
            } catch {
              // Malformed payload; skip.
              return;
            }
            // Defensive shape validation — the cast above lies if the
            // server (or a man-in-the-middle) sends garbage.
            if (
              !envelope ||
              typeof envelope !== 'object' ||
              typeof envelope.type !== 'string'
            ) {
              return;
            }
            if (record.id && !envelope.id) envelope.id = record.id;
            // A real envelope flips healthy immediately if the debounce
            // timer hasn't already.
            markHealthy();
            // Every tab dispatches every envelope it receives into its
            // own Redux store. With N tabs open this means N copies of
            // the same toast — accepted as a v1 limitation; cross-tab
            // dedup via BroadcastChannel + navigator.locks is future
            // work. Toast-level suppression can be handled per surface.
            opts.onEvent(envelope);
          });
        } finally {
          clearTimeout(debounceTimer);
        }
        // The reader returned without throwing — server closed the
        // stream. Count it and fall through to reconnect.
        notifyHealth('unhealthy');
        consecutiveErrors += 1;
        if (consecutiveErrors >= MAX_CONSECUTIVE_ERRORS) {
          opts.onPermanentFailure?.();
          return;
        }
        attempt = Math.min(attempt + 1, BACKOFF_SCHEDULE_MS.length - 1);
      } catch (err) {
        // ``close()`` aborts the controller; that is a normal shutdown,
        // not a failure to count.
        if (
          closed ||
          (err instanceof DOMException && err.name === 'AbortError')
        ) {
          return;
        }
        notifyHealth('unhealthy');
        consecutiveErrors += 1;
        if (consecutiveErrors >= MAX_CONSECUTIVE_ERRORS) {
          opts.onPermanentFailure?.();
          return;
        }
        attempt = Math.min(attempt + 1, BACKOFF_SCHEDULE_MS.length - 1);
      }
    }
  })();

  return {
    close() {
      if (closed) return;
      closed = true;
      controller.abort();
    },
  };
}
/**
* One ``\n\n``-delimited SSE record after field parsing. Records that
* contain neither an ``id`` nor a ``data`` field (e.g. comment-only
* keepalives) are never represented — ``parseSSERecord`` returns
* ``null`` for those.
*/
interface ParsedSSERecord {
/**
* ``undefined`` when the record had no ``id`` field at all. An empty
* string means the server explicitly reset the cursor (an ``id:``
* line with no value, per WHATWG SSE).
*/
id?: string;
/**
* All ``data`` field values joined with ``\n``; ``undefined`` when the
* record carried no ``data`` field (id-only frame).
*/
data?: string;
}
/**
 * Drain a ``ReadableStream<Uint8Array>`` of blank-line-delimited SSE
 * records, forwarding each parsed record to ``onRecord``.
 *
 * Line terminators are normalised to LF as the WHATWG SSE spec accepts
 * CRLF, CR, or LF. A CRLF pair that straddles two network chunks is
 * treated as ONE terminator: the CR at the end of a chunk is converted
 * immediately (so CR-only streams see no added latency) and a LF that
 * opens the following chunk is then dropped. Normalising each chunk in
 * isolation would turn such a pair into ``\n\n`` and fabricate a record
 * boundary in the middle of a record.
 *
 * A trailing record that is not terminated by a blank line when the
 * stream ends is discarded.
 *
 * @param body     Response body to read; its reader lock is released on exit.
 * @param signal   Checked before every read; an aborted signal ends the drain.
 * @param onRecord Called synchronously for each record that parsed non-null.
 */
async function readSSEStream(
  body: ReadableStream<Uint8Array>,
  signal: AbortSignal,
  onRecord: (record: ParsedSSERecord) => void,
): Promise<void> {
  const reader = body.getReader();
  const decoder = new TextDecoder('utf-8');
  let buffer = '';
  // True when the previously decoded text ended in CR, i.e. a LF at the
  // start of the next text is the second half of a CRLF pair.
  let prevEndedWithCR = false;
  try {
    while (true) {
      if (signal.aborted) return;
      const { done, value } = await reader.read();
      if (done) return;
      let text = decoder.decode(value, { stream: true });
      // Nothing decodable yet (e.g. chunk ended inside a multi-byte
      // sequence); keep the CR flag as-is and wait for more bytes.
      if (text.length === 0) continue;
      if (prevEndedWithCR && text.startsWith('\n')) {
        text = text.slice(1);
      }
      prevEndedWithCR = text.endsWith('\r');
      // Only the new text needs normalising; the buffer already is.
      buffer += text.replace(/\r\n?/g, '\n');
      let boundary = buffer.indexOf('\n\n');
      while (boundary !== -1) {
        const raw = buffer.slice(0, boundary);
        buffer = buffer.slice(boundary + 2);
        const record = parseSSERecord(raw);
        if (record) onRecord(record);
        boundary = buffer.indexOf('\n\n');
      }
    }
  } finally {
    try {
      reader.releaseLock();
    } catch {
      // Already released.
    }
  }
}
/**
 * Parse the text of one SSE record (terminators already normalised to
 * LF, without the trailing blank line) into its ``id`` / ``data`` fields.
 *
 * Blank lines and ``:``-prefixed comment lines are skipped; a line
 * without a colon is a field name with an empty value; one leading
 * space of a value is stripped. Multiple ``data`` lines are joined with
 * ``\n``. Field names other than ``id`` and ``data`` are ignored.
 *
 * @returns ``null`` when the record has neither an ``id`` nor a ``data``
 *          field, otherwise the parsed record.
 */
function parseSSERecord(raw: string): ParsedSSERecord | null {
  if (raw === '') return null;

  let id: string | undefined;
  // ``null`` until the first ``data`` field is seen.
  let dataLines: string[] | null = null;

  for (const line of raw.split('\n')) {
    if (line === '' || line[0] === ':') continue;

    const sep = line.indexOf(':');
    const field = sep < 0 ? line : line.slice(0, sep);
    let value = sep < 0 ? '' : line.slice(sep + 1);
    if (value[0] === ' ') value = value.slice(1);

    switch (field) {
      case 'id':
        id = value;
        break;
      case 'data':
        if (dataLines === null) dataLines = [];
        dataLines.push(value);
        break;
      default:
        // 'event', 'retry' and unknown fields are not used.
        break;
    }
  }

  if (dataLines === null && id === undefined) return null;
  return {
    id,
    data: dataLines === null ? undefined : dataLines.join('\n'),
  };
}
/**
 * Wait ``ms`` milliseconds, abortably.
 *
 * Resolves when the timer fires. Rejects with an ``AbortError``
 * ``DOMException`` if ``signal`` is already aborted on entry or aborts
 * before the timer fires; in both outcomes the timer and the abort
 * listener are cleaned up.
 */
function sleep(ms: number, signal: AbortSignal): Promise<void> {
  return new Promise<void>((resolve, reject) => {
    if (signal.aborted) {
      reject(new DOMException('Aborted', 'AbortError'));
      return;
    }

    const handleAbort = () => {
      clearTimeout(timeoutId);
      signal.removeEventListener('abort', handleAbort);
      reject(new DOMException('Aborted', 'AbortError'));
    };

    const timeoutId = setTimeout(() => {
      signal.removeEventListener('abort', handleAbort);
      resolve();
    }, ms);

    signal.addEventListener('abort', handleAbort, { once: true });
  });
}

View File

@@ -1,85 +0,0 @@
import { useEffect } from 'react';
import { useDispatch, useSelector, useStore } from 'react-redux';
import {
selectLastEventId,
sseHealthChanged,
sseLastEventIdAdvanced,
sseLastEventIdReset,
} from '../notifications/notificationsSlice';
import { selectToken, setToken } from '../preferences/preferenceSlice';
import type { AppDispatch, RootState } from '../store';
import { connectEventStream } from './eventStreamClient';
import { dispatchSSEEvent } from './dispatchEvent';
/**
 * Keep one SSE connection open for the lifetime of the host component,
 * tearing it down and reopening it whenever the token changes
 * (login / refresh).
 *
 * The ``lastEventId`` cursor is read lazily via ``store.getState()`` on
 * every connect attempt. A mount-time snapshot would go stale as the
 * slice advances during a connection's lifetime, so a token rotation
 * would replay the whole retained backlog.
 */
export function useEventStream(): void {
  const dispatch = useDispatch<AppDispatch>();
  const token = useSelector(selectToken);
  const store = useStore<RootState>();

  useEffect(() => {
    // The connection is opened even when ``token`` is null. Self-hosted
    // dev installs run with ``AUTH_TYPE`` unset, where ``handle_auth``
    // maps every request to ``{"sub": "local"}`` regardless of headers,
    // so gating on a token would silently disable push notifications
    // there. When auth IS required the backend answers 401 and
    // ``connectEventStream`` reports ``unhealthy``.

    const readCursor = () => selectLastEventId(store.getState());

    // Repeated-401 bail-out: drop the stored token AND dispatch
    // ``setToken(null)`` so ``useAuth`` can mint a fresh ``session_jwt``
    // in-session. The Redux change also flips this effect's ``token``
    // dependency, which respawns the connection with the new token;
    // without the dispatch a ``session_jwt`` user is stuck until a hard
    // reload.
    const handleAuthFailure = () => {
      console.error(
        '[useEventStream] giving up after repeated 401s on /api/events',
      );
      try {
        localStorage.removeItem('authToken');
      } catch {
        // localStorage unavailable (private mode, etc.) — nothing to do.
      }
      dispatch(setToken(null));
    };

    // Error budget exhausted: leave a devtools-visible trace so the
    // connection going dark is not completely silent. Does not block UI.
    const handlePermanentFailure = () => {
      console.warn(
        '[useEventStream] SSE connection failed permanently after repeated errors',
      );
    };

    const connection = connectEventStream({
      token,
      getLastEventId: readCursor,
      onEvent: (envelope) => dispatchSSEEvent(envelope, dispatch),
      // Each tab owns its own connection and store, so every tab tracks
      // its own replay cursor from the id-bearing frames it sees.
      onLastEventId: (id) => dispatch(sseLastEventIdAdvanced(id)),
      // Empty ``id:`` from the server is a cursor reset; mirror it so the
      // next reconnect does not resend a stale id.
      onLastEventIdReset: () => dispatch(sseLastEventIdReset()),
      onHealthChange: (health) => dispatch(sseHealthChanged(health)),
      onAuthFailure: handleAuthFailure,
      onPermanentFailure: handlePermanentFailure,
    });

    return () => {
      connection.close();
    };
  }, [token, dispatch, store]);
}

View File

@@ -1,4 +1,4 @@
import { useCallback, useEffect, useRef, useState, RefObject } from 'react';
import { useEffect, RefObject, useState } from 'react';
export function useOutsideAlerter<T extends HTMLElement>(
ref: RefObject<T | null>,
@@ -113,51 +113,6 @@ export function useDarkTheme() {
return [isDarkTheme, toggleTheme, componentMounted] as const;
}
/**
 * Return a copy of ``value`` that only updates once ``value`` has stayed
 * unchanged for ``delay`` milliseconds (default 300). Every change of
 * ``value`` or ``delay`` cancels the pending update and restarts the timer.
 */
export function useDebouncedValue<T>(value: T, delay = 300): T {
  const [settled, setSettled] = useState<T>(value);

  useEffect(() => {
    const handle = setTimeout(() => setSettled(value), delay);
    return () => {
      clearTimeout(handle);
    };
  }, [value, delay]);

  return settled;
}
/**
* Return a debounced wrapper around ``callback``: each call cancels the
* pending invocation and schedules a new one ``delay`` ms later
* (default 300), so only the last call's arguments are delivered.
*
* The returned function also carries a ``cancel()`` method that drops
* the pending invocation. The wrapper's identity only changes when
* ``delay`` changes; the latest ``callback`` is always the one invoked.
*/
export function useDebouncedCallback<A extends unknown[]>(
callback: (...args: A) => void,
delay = 300,
): ((...args: A) => void) & { cancel: () => void } {
// Latest callback, kept in a ref so a new callback identity does not
// recreate the debounced wrapper.
const callbackRef = useRef(callback);
// Handle of the pending timer, or null when nothing is scheduled.
const timerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
useEffect(() => {
callbackRef.current = callback;
}, [callback]);
const cancel = useCallback(() => {
if (timerRef.current) {
clearTimeout(timerRef.current);
timerRef.current = null;
}
}, []);
// ``cancel`` is returned as the effect cleanup, so a pending invocation
// is dropped on unmount.
useEffect(() => cancel, [cancel]);
const debounced = useCallback(
(...args: A) => {
cancel();
timerRef.current = setTimeout(() => {
// Clear the handle before invoking so ``cancel()`` from inside the
// callback is a no-op.
timerRef.current = null;
callbackRef.current(...args);
}, delay);
},
[delay, cancel],
);
// Attach ``cancel`` onto the memoized function itself.
return Object.assign(debounced, { cancel });
}
export function useLoaderState(
initialState = false,
delay = 250,

View File

@@ -15,39 +15,24 @@ export default function useAuth() {
const generateNewToken = async () => {
if (isGeneratingToken.current) return;
isGeneratingToken.current = true;
try {
const response = await userService.getNewToken();
const { token: newToken } = await response.json();
localStorage.setItem('authToken', newToken);
dispatch(setToken(newToken));
setIsAuthLoading(false);
return newToken;
} finally {
// Reset so a subsequent ``setToken(null)`` (SSE 401 recovery)
// can trigger another generation. Without this the in-flight
// guard would latch true forever after the first call.
isGeneratingToken.current = false;
}
const response = await userService.getNewToken();
const { token: newToken } = await response.json();
localStorage.setItem('authToken', newToken);
dispatch(setToken(newToken));
setIsAuthLoading(false);
return newToken;
};
useEffect(() => {
// Re-fires when ``token`` flips to null mid-session (e.g.
// ``useEventStream`` dispatches ``setToken(null)`` after repeated
// SSE 401s) so ``session_jwt`` users get a fresh token without a
// hard reload. ``authType`` short-circuits on subsequent runs.
const initializeAuth = async () => {
try {
let resolvedAuthType = authType;
if (resolvedAuthType === null) {
const configRes = await userService.getConfig();
const config = await configRes.json();
resolvedAuthType = config.auth_type;
setAuthType(resolvedAuthType);
}
const configRes = await userService.getConfig();
const config = await configRes.json();
setAuthType(config.auth_type);
if (resolvedAuthType === 'session_jwt' && !token) {
if (config.auth_type === 'session_jwt' && !token) {
await generateNewToken();
} else if (resolvedAuthType === 'simple_jwt' && !token) {
} else if (config.auth_type === 'simple_jwt' && !token) {
setShowTokenModal(true);
setIsAuthLoading(false);
} else {
@@ -59,7 +44,7 @@ export default function useAuth() {
}
};
initializeAuth();
}, [token, authType]);
}, []);
const handleTokenSubmit = (enteredToken: string) => {
localStorage.setItem('authToken', enteredToken);

View File

@@ -70,9 +70,6 @@
"sync": "Synchronisieren",
"syncNow": "Jetzt synchronisieren",
"syncing": "Synchronisiere...",
"reingest": "Erneut indexieren",
"ingestFailed": "Indexierung fehlgeschlagen",
"ingestProcessing": "Indexierung...",
"syncConfirmation": "Bist du sicher, dass du \"{{sourceName}}\" synchronisieren möchtest? Dies aktualisiert den Inhalt mit deinem Cloud-Speicher und kann Änderungen an einzelnen Chunks überschreiben.",
"syncFrequency": {
"never": "Nie",
@@ -356,8 +353,6 @@
"failed": "Upload fehlgeschlagen",
"wait": "Dies kann einige Minuten dauern",
"preparing": "Upload wird vorbereitet",
"parsing": "Dateien werden verarbeitet...",
"embedding": "Einbettung...",
"tokenLimit": "Token-Limit überschritten, bitte lade ein kleineres Dokument hoch",
"expandDetails": "Upload-Details erweitern",
"collapseDetails": "Upload-Details einklappen",

View File

@@ -70,9 +70,6 @@
"sync": "Sync",
"syncNow": "Sync now",
"syncing": "Syncing...",
"reingest": "Reingest",
"ingestFailed": "Indexing failed",
"ingestProcessing": "Indexing…",
"syncConfirmation": "Are you sure you want to sync \"{{sourceName}}\"? This will update the content with your cloud storage and may override any edits you made to individual chunks.",
"syncFrequency": {
"never": "Never",
@@ -368,8 +365,6 @@
"failed": "Upload failed",
"wait": "This may take several minutes",
"preparing": "Preparing upload",
"parsing": "Parsing files…",
"embedding": "Embedding…",
"tokenLimit": "Over the token limit, please consider uploading smaller document",
"expandDetails": "Expand upload details",
"collapseDetails": "Collapse upload details",

View File

@@ -70,9 +70,6 @@
"sync": "Sincronizar",
"syncNow": "Sincronizar ahora",
"syncing": "Sincronizando...",
"reingest": "Reindexar",
"ingestFailed": "Error de indexación",
"ingestProcessing": "Indexando...",
"syncConfirmation": "¿Estás seguro de que deseas sincronizar \"{{sourceName}}\"? Esto actualizará el contenido con tu almacenamiento en la nube y puede anular cualquier edición que hayas realizado en fragmentos individuales.",
"syncFrequency": {
"never": "Nunca",
@@ -356,8 +353,6 @@
"failed": "Error al subir",
"wait": "Esto puede tardar varios minutos",
"preparing": "Preparando subida",
"parsing": "Analizando archivos...",
"embedding": "Generando incrustaciones...",
"tokenLimit": "Excede el límite de tokens, considere cargar un documento más pequeño",
"expandDetails": "Expandir detalles de subida",
"collapseDetails": "Contraer detalles de subida",

View File

@@ -70,9 +70,6 @@
"sync": "同期",
"syncNow": "今すぐ同期",
"syncing": "同期中...",
"reingest": "再インデックス",
"ingestFailed": "インデックス作成に失敗しました",
"ingestProcessing": "インデックス作成中...",
"syncConfirmation": "\"{{sourceName}}\"を同期してもよろしいですか?これにより、コンテンツがクラウドストレージで更新され、個々のチャンクに加えた編集が上書きされる可能性があります。",
"syncFrequency": {
"never": "なし",
@@ -356,8 +353,6 @@
"failed": "アップロード失敗",
"wait": "数分かかる場合があります",
"preparing": "アップロードを準備中",
"parsing": "ファイルを解析中...",
"embedding": "埋め込み処理中...",
"tokenLimit": "トークン制限を超えています。より小さいドキュメントをアップロードしてください",
"expandDetails": "アップロードの詳細を展開",
"collapseDetails": "アップロードの詳細を折りたたむ",

View File

@@ -70,9 +70,6 @@
"sync": "Синхронизация",
"syncNow": "Синхронизировать сейчас",
"syncing": "Синхронизация...",
"reingest": "Переиндексировать",
"ingestFailed": "Ошибка индексации",
"ingestProcessing": "Индексация...",
"syncConfirmation": "Вы уверены, что хотите синхронизировать \"{{sourceName}}\"? Это обновит содержимое с вашим облачным хранилищем и может перезаписать любые изменения, внесенные вами в отдельные фрагменты.",
"syncFrequency": {
"never": "Никогда",
@@ -356,8 +353,6 @@
"failed": "Ошибка загрузки",
"wait": "Это может занять несколько минут",
"preparing": "Подготовка загрузки",
"parsing": "Обработка файлов...",
"embedding": "Создание эмбеддингов...",
"tokenLimit": "Превышен лимит токенов, рассмотрите возможность загрузки документа меньшего размера",
"expandDetails": "Развернуть детали загрузки",
"collapseDetails": "Свернуть детали загрузки",

View File

@@ -70,9 +70,6 @@
"sync": "同步",
"syncNow": "立即同步",
"syncing": "同步中...",
"reingest": "重新索引",
"ingestFailed": "索引失敗",
"ingestProcessing": "索引中...",
"syncConfirmation": "您確定要同步 \"{{sourceName}}\" 嗎?這將使用您的雲端儲存更新內容,並可能覆蓋您對個別文本塊所做的任何編輯。",
"syncFrequency": {
"never": "從不",
@@ -356,8 +353,6 @@
"failed": "上傳失敗",
"wait": "這可能需要幾分鐘",
"preparing": "準備上傳",
"parsing": "正在解析檔案...",
"embedding": "正在生成嵌入...",
"tokenLimit": "超出令牌限制,請考慮上傳較小的文檔",
"expandDetails": "展開上傳詳情",
"collapseDetails": "摺疊上傳詳情",

View File

@@ -70,9 +70,6 @@
"sync": "同步",
"syncNow": "立即同步",
"syncing": "同步中...",
"reingest": "重新索引",
"ingestFailed": "索引失败",
"ingestProcessing": "索引中...",
"syncConfirmation": "您确定要同步 \"{{sourceName}}\" 吗?这将使用您的云存储更新内容,并可能覆盖您对单个文本块所做的任何编辑。",
"syncFrequency": {
"never": "从不",
@@ -356,8 +353,6 @@
"failed": "上传失败",
"wait": "这可能需要几分钟",
"preparing": "准备上传",
"parsing": "正在解析文件...",
"embedding": "正在生成嵌入...",
"tokenLimit": "超出令牌限制,请考虑上传较小的文档",
"expandDetails": "展开上传详情",
"collapseDetails": "折叠上传详情",

View File

@@ -15,7 +15,6 @@ import {
SelectValue,
} from '../components/ui/select';
import { ActiveState } from '../models/misc';
import { selectRecentEvents } from '../notifications/notificationsSlice';
import { selectToken } from '../preferences/preferenceSlice';
import WrapperComponent from './WrapperModal';
@@ -34,7 +33,6 @@ export default function MCPServerModal({
}: MCPServerModalProps) {
const { t } = useTranslation();
const token = useSelector(selectToken);
const recentEvents = useSelector(selectRecentEvents);
const authTypes = [
{ label: t('settings.tools.mcp.authTypes.none'), value: 'none' },
@@ -73,29 +71,17 @@ export default function MCPServerModal({
>([]);
const [errors, setErrors] = useState<{ [key: string]: string }>({});
const oauthPopupRef = useRef<Window | null>(null);
// Set after ``test_mcp_connection`` returns ``task_id``. The SSE
// effect filters ``recentEvents`` to envelopes matching this id and
// drives the OAuth UI (popup open / completion / failure) from the
// push stream rather than polling the legacy status endpoint.
const [oauthTaskId, setOauthTaskId] = useState<string | null>(null);
// Ids of the events we have already reacted to for this taskId. Each
// mcp.oauth.* envelope must fire its side-effect once; without this
// any later re-render that re-evaluates ``recentEvents`` would
// re-open the popup or re-fire onComplete.
const handledEventIdsRef = useRef<Set<string>>(new Set());
// Holds the ``testConnection`` ``onComplete`` for the current
// task id so the SSE effect can invoke it when the terminal event
// lands. Reset to ``null`` on cancel / new test / unmount.
const onCompleteRef = useRef<((result: any) => void) | null>(null);
const popupOpenedRef = useRef(false);
const pollingCancelledRef = useRef(false);
const pollTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
const [oauthCompleted, setOAuthCompleted] = useState(false);
const [saveActive, setSaveActive] = useState(false);
const cleanupOAuthListener = useCallback(() => {
setOauthTaskId(null);
handledEventIdsRef.current = new Set();
onCompleteRef.current = null;
popupOpenedRef.current = false;
const cleanupPolling = useCallback(() => {
pollingCancelledRef.current = true;
if (pollTimerRef.current) {
clearTimeout(pollTimerRef.current);
pollTimerRef.current = null;
}
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
@@ -103,8 +89,8 @@ export default function MCPServerModal({
}, []);
useEffect(() => {
return cleanupOAuthListener;
}, [cleanupOAuthListener]);
return cleanupPolling;
}, [cleanupPolling]);
useEffect(() => {
if (modalState === 'ACTIVE' && server) {
@@ -133,7 +119,7 @@ export default function MCPServerModal({
}, [modalState, server]);
const resetForm = () => {
cleanupOAuthListener();
cleanupPolling();
setFormData({
name: t('settings.tools.mcp.defaultServerName'),
server_url: '',
@@ -242,123 +228,114 @@ export default function MCPServerModal({
return config;
};
/**
* Drive the OAuth handshake straight from the SSE stream:
*
* - ``mcp.oauth.awaiting_redirect`` → open the popup with the
* ``authorization_url`` carried on the envelope. Previously this URL
* came from polling ``/api/mcp_server/oauth_status/<task_id>``; the
* worker now publishes it inline so we never need to poll.
* - ``mcp.oauth.completed`` → enable Save, surface discovered tools,
* invoke ``onComplete`` (resolves ``testConnection``'s pending state).
* - ``mcp.oauth.failed`` → surface the error and reset Save.
*
* Each event is matched to the active task id via ``scope.id``. The
* publisher is best-effort: a lost ``awaiting_redirect`` envelope
* means the popup never opens, the user retries, and we accept that
* over the prior 1s × 60 polling loop.
*/
useEffect(() => {
if (!oauthTaskId) return;
// ``recentEvents`` is newest-first (the slice ``unshift``s on
// arrival). Walk it oldest-first so we observe the natural OAuth
// ordering (``awaiting_redirect`` → ``completed``) when both
// arrive between effect runs — otherwise we would short-circuit
// on ``completed`` and never open the popup for the
// ``awaiting_redirect`` envelope that was already buffered.
for (let i = recentEvents.length - 1; i >= 0; i--) {
const event = recentEvents[i];
if (event.scope?.id !== oauthTaskId) continue;
if (!event.id || handledEventIdsRef.current.has(event.id)) continue;
const pollOAuthStatus = async (
taskId: string,
onComplete: (result: any) => void,
) => {
let attempts = 0;
const maxAttempts = 60;
let popupOpened = false;
pollingCancelledRef.current = false;
const payload = (event.payload || {}) as Record<string, unknown>;
const poll = async () => {
if (pollingCancelledRef.current) return;
try {
const resp = await userService.getMCPOAuthStatus(taskId, token);
if (pollingCancelledRef.current) return;
const data = await resp.json();
if (pollingCancelledRef.current) return;
if (event.type === 'mcp.oauth.awaiting_redirect') {
handledEventIdsRef.current.add(event.id);
const authUrl = payload.authorization_url as string | undefined;
if (authUrl && !popupOpenedRef.current) {
popupOpenedRef.current = true;
if (data.authorization_url && !popupOpened) {
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
oauthPopupRef.current = window.open(
authUrl,
data.authorization_url,
'oauthPopup',
'width=600,height=700',
);
popupOpened = true;
if (!oauthPopupRef.current) {
// Popup blocked — surface the URL inline so the user can
// click through manually. Browsers gate ``window.open``
// outside of a user gesture, and the SSE event arrives
// asynchronously, so a blocked popup is expected on
// some browsers / configs.
setTestResult({
success: true,
message: t('settings.tools.mcp.oauthPopupBlocked', {
defaultValue:
'Popup blocked by browser. Click below to authorize:',
}),
authorization_url: authUrl,
authorization_url: data.authorization_url,
});
}
}
continue;
}
if (event.type === 'mcp.oauth.completed') {
handledEventIdsRef.current.add(event.id);
const tools = Array.isArray(payload.tools) ? payload.tools : [];
const toolsCount =
(payload.tools_count as number | undefined) ?? tools.length;
setOAuthCompleted(true);
setSaveActive(true);
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
const cb = onCompleteRef.current;
onCompleteRef.current = null;
setOauthTaskId(null);
if (cb) {
cb({
status: 'completed',
task_id: oauthTaskId,
tools,
tools_count: toolsCount,
const callbackReceived =
data.status === 'callback_received' || data.status === 'completed';
if (data.status === 'completed') {
setOAuthCompleted(true);
setSaveActive(true);
onComplete({
...data,
success: true,
message: t('settings.tools.mcp.oauthCompleted'),
});
}
continue;
}
if (event.type === 'mcp.oauth.failed') {
handledEventIdsRef.current.add(event.id);
const message =
(payload.error as string) ??
t('settings.tools.mcp.errors.oauthFailed');
setSaveActive(false);
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
const cb = onCompleteRef.current;
onCompleteRef.current = null;
setOauthTaskId(null);
if (cb) {
cb({
status: 'error',
task_id: oauthTaskId,
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
} else if (data.status === 'error' || data.success === false) {
setSaveActive(false);
onComplete({
...data,
success: false,
message,
message: data.message || t('settings.tools.mcp.errors.oauthFailed'),
});
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
oauthPopupRef.current.close();
}
} else {
if (++attempts < maxAttempts) {
if (
oauthPopupRef.current &&
oauthPopupRef.current.closed &&
popupOpened &&
!callbackReceived
) {
setSaveActive(false);
onComplete({
success: false,
message: t('settings.tools.mcp.errors.oauthFailed'),
});
return;
}
pollTimerRef.current = setTimeout(poll, 1000);
} else {
setSaveActive(false);
cleanupPolling();
onComplete({
success: false,
message: t('settings.tools.mcp.errors.oauthTimeout'),
});
}
}
} catch {
if (pollingCancelledRef.current) return;
if (++attempts < maxAttempts) {
pollTimerRef.current = setTimeout(poll, 1000);
} else {
cleanupPolling();
onComplete({
success: false,
message: t('settings.tools.mcp.errors.oauthTimeout'),
});
}
continue;
}
}
}, [recentEvents, oauthTaskId, t]);
};
poll();
};
const testConnection = async () => {
if (!validateForm()) return;
cleanupOAuthListener();
cleanupPolling();
setTesting(true);
setTestResult(null);
setDiscoveredTools([]);
@@ -378,7 +355,7 @@ export default function MCPServerModal({
message: t('settings.tools.mcp.oauthInProgress'),
});
setSaveActive(false);
onCompleteRef.current = (finalResult: any) => {
pollOAuthStatus(result.task_id, (finalResult) => {
setTestResult(finalResult);
if (finalResult.tools && Array.isArray(finalResult.tools)) {
setDiscoveredTools(finalResult.tools);
@@ -388,11 +365,7 @@ export default function MCPServerModal({
oauth_task_id: result.task_id || '',
}));
setTesting(false);
};
// Activate the SSE listener for this task id. The effect above
// will react when ``mcp.oauth.{awaiting_redirect,completed,failed}``
// arrives.
setOauthTaskId(result.task_id);
});
} else {
setTestResult(result);
if (result.success && result.tools && Array.isArray(result.tools)) {

View File

@@ -14,8 +14,6 @@ export type Doc = {
syncFrequency?: string;
isNested?: boolean;
provider?: string;
// Derived server-side from ingest_chunk_progress (sources API).
ingestStatus?: 'processing' | 'failed';
};
export type GetDocsResponse = {

View File

@@ -1,174 +0,0 @@
import { useDispatch, useSelector } from 'react-redux';
import { useMatch, useNavigate } from 'react-router-dom';
import WarnIcon from '../assets/warn.svg';
import type { RootState } from '../store';
import {
dismissToolApproval,
selectDismissedToolApprovals,
selectRecentEvents,
} from './notificationsSlice';
/**
* Surface ``tool.approval.required`` events as toasts that look like
* ``UploadToast`` (same fixed bottom-right rail) — but only when the
* user is NOT already on the conversation that needs the approval.
*
* - Dedup by ``conversation_id`` (the SSE ``scope.id``): keep only
* the newest pending event per conversation, so multiple paused
* tools in one conversation collapse to one toast.
* - Dismissal is per-event-id so a *new* pause of the same
* conversation will re-surface (different event id).
* - Clicking "Review" navigates to ``/c/<id>`` and dismisses.
*
* Renders ``null`` when no pending approval survives the filters.
*/
export default function ToolApprovalToast() {
const dispatch = useDispatch();
const navigate = useNavigate();
const events = useSelector(selectRecentEvents);
const dismissed = useSelector(selectDismissedToolApprovals);
// Pull the active conversation id off the route. Two route shapes
// place a conversation in view: the bare ``/c/:conversationId`` and
// the agent-scoped ``/agents/:agentId/c/:conversationId``. Hooks
// are unconditional; the toast just respects whichever matches.
//
// ``/c/new`` is the conversation route's literal-string placeholder
// for "unknown / not-yet-loaded conversation" (see the rewrite in
// the conversation route). Treat it the same as no match — the
// user isn't viewing any specific conversation yet, so an approval
// toast for any conversation should still surface.
const plainMatch = useMatch('/c/:conversationId');
const agentMatch = useMatch('/agents/:agentId/c/:conversationId');
const matchedConversationId =
plainMatch?.params.conversationId ??
agentMatch?.params.conversationId ??
null;
const currentConversationId =
matchedConversationId === 'new' ? null : matchedConversationId;
// Conversation name lookup — best-effort. The slice's
// ``preference.conversations.data`` is populated by
// ``useDataInitializer`` once auth resolves; until then (or when the
// id is not in that list) we fall back to the generic label
// ``'Conversation'``.
const conversations = useSelector(
(state: RootState) => state.preference.conversations.data,
);
// Set for O(1) membership checks while walking the event list.
const dismissedSet = new Set(dismissed);
// One pending approval per conversation id; insertion order is the
// order the toasts are rendered in.
const pendingByConversation = new Map<
string,
{ eventId: string; conversationId: string }
>();
for (const event of events) {
if (event.type !== 'tool.approval.required') continue;
if (!event.id) continue; // can't dismiss without an id
if (dismissedSet.has(event.id)) continue;
const conversationId = event.scope?.id;
if (!conversationId) continue;
// Skip approvals for the conversation the user is already viewing.
if (currentConversationId && conversationId === currentConversationId) {
continue;
}
if (pendingByConversation.has(conversationId)) continue;
// ``recentEvents`` is newest-first, so the first match per convId
// is the most recent unhandled approval.
pendingByConversation.set(conversationId, {
eventId: event.id,
conversationId,
});
}
if (pendingByConversation.size === 0) return null;
// Resolve a display name for the toast body; see the fallback note above.
const conversationName = (conversationId: string): string => {
const found = conversations?.find((c) => c.id === conversationId);
return found?.name ?? 'Conversation';
};
return (
// Sit above ``UploadToast`` (which owns ``bottom-4 right-4``)
// rather than overlapping it. ``bottom-24`` ≈ 96px clears one
// standard-height upload toast; multiple in-flight uploads will
// stack into the gap, at which point the approval toast still
// floats on top via ``z-50``. Acceptable v1 layout — the two
// surfaces are rarely competing.
<div
className="fixed right-4 bottom-24 z-50 flex max-w-md flex-col gap-2"
onMouseDown={(e) => e.stopPropagation()}
role="status"
aria-live="polite"
aria-atomic="true"
>
{Array.from(pendingByConversation.values()).map(
({ eventId, conversationId }) => (
<div
key={eventId}
className="border-border bg-card w-[271px] overflow-hidden rounded-2xl border shadow-[0px_24px_48px_0px_#00000029]"
>
<div className="bg-accent/50 dark:bg-muted flex items-center justify-between px-4 py-3">
<h3 className="font-inter dark:text-foreground text-[14px] leading-[16.5px] font-medium text-black">
Tool approval needed
</h3>
<button
type="button"
onClick={() => dispatch(dismissToolApproval(eventId))}
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
aria-label="Dismiss"
>
<svg
width="16"
height="16"
viewBox="0 0 24 24"
fill="none"
xmlns="http://www.w3.org/2000/svg"
className="h-4 w-4"
>
<path
d="M18 6L6 18"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M6 6L18 18"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
</button>
</div>
<div className="flex items-center justify-between gap-3 px-5 py-3">
<div className="flex min-w-0 items-center gap-2">
<img
src={WarnIcon}
alt=""
className="h-5 w-5 shrink-0"
aria-hidden="true"
/>
<p
className="font-inter dark:text-muted-foreground max-w-[140px] truncate text-[13px] leading-[16.5px] font-normal text-black"
title={conversationName(conversationId)}
>
{conversationName(conversationId)}
</p>
</div>
<button
type="button"
onClick={() => {
dispatch(dismissToolApproval(eventId));
navigate(`/c/${conversationId}`);
}}
className="rounded-full bg-[#7D54D1] px-3 py-1 text-[12px] font-medium text-white shadow-sm hover:bg-[#6a45b8]"
>
Review
</button>
</div>
</div>
),
)}
</div>
);
}

View File

@@ -1,71 +0,0 @@
import { beforeEach, describe, expect, it } from 'vitest';
import {
isDismissed,
loadDismissed,
saveDismissed,
} from './dismissedPersistence';
const KEY = 'test:dismissed';
const TTL = 24 * 60 * 60 * 1000;
// Exercises the localStorage-backed dismissal helpers: round-tripping,
// TTL eviction, and tolerance of corrupt or mis-shaped stored data.
describe('dismissedPersistence', () => {
  // Every case starts from an empty store so keys never leak between tests.
  beforeEach(() => {
    localStorage.clear();
  });

  it('saveDismissed + loadDismissed round-trips entries', () => {
    const stamp = Date.now();
    saveDismissed(KEY, [
      { id: 'a', at: stamp },
      { id: 'b', at: stamp - 1000 },
    ]);
    const roundTripped = loadDismissed(KEY, TTL);
    expect(roundTripped).toEqual([
      { id: 'a', at: stamp },
      { id: 'b', at: stamp - 1000 },
    ]);
  });

  it('loadDismissed returns [] when key absent', () => {
    const result = loadDismissed(KEY, TTL);
    expect(result).toEqual([]);
  });

  it('loadDismissed drops entries past the TTL cutoff', () => {
    const stamp = Date.now();
    const fresh = { id: 'fresh', at: stamp - 1000 };
    const stale = { id: 'stale', at: stamp - (TTL + 1000) };
    saveDismissed(KEY, [fresh, stale]);
    const survivingIds = loadDismissed(KEY, TTL).map((entry) => entry.id);
    expect(survivingIds).toEqual(['fresh']);
  });

  it('loadDismissed returns [] on malformed JSON without throwing', () => {
    localStorage.setItem(KEY, '{not json');
    const result = loadDismissed(KEY, TTL);
    expect(result).toEqual([]);
  });

  it('loadDismissed filters out entries with wrong shape', () => {
    const stamp = Date.now();
    // Only the first element has a string id and a numeric timestamp.
    const stored = JSON.stringify([
      { id: 'good', at: stamp },
      { id: 123, at: stamp },
      { id: 'bad-at', at: 'nope' },
      null,
      'string-entry',
    ]);
    localStorage.setItem(KEY, stored);
    const survivingIds = loadDismissed(KEY, TTL).map((entry) => entry.id);
    expect(survivingIds).toEqual(['good']);
  });

  it('isDismissed matches by id', () => {
    const entries = [{ id: 'a', at: 1 }];
    expect(isDismissed(entries, 'a')).toBe(true);
    expect(isDismissed(entries, 'b')).toBe(false);
    expect(isDismissed([], 'a')).toBe(false);
  });
});

View File

@@ -1,42 +0,0 @@
// Persisted dismissal lists for SSE-driven toasts. Without persistence
// the next page's backlog replay re-fires the events and pops the
// toast back. TTL matches the backend's stream retention.
/** One persisted dismissal record, as stored in the JSON array under a storage key. */
export interface DismissedEntry {
// Identifier of the dismissed item; ``isDismissed`` matches on this field.
id: string;
// Timestamp in milliseconds on the ``Date.now()`` scale; ``loadDismissed``
// compares it against ``Date.now() - ttlMs`` to drop expired entries.
at: number;
}
/**
 * Read the dismissal list persisted under ``key``, keeping only entries
 * that are well-formed and no older than ``ttlMs``.
 *
 * @param key localStorage key the list was saved under.
 * @param ttlMs Maximum entry age in milliseconds; older entries are dropped.
 * @returns The surviving entries in stored order, or ``[]`` when storage is
 *   unavailable, the key is absent or empty, the value is not a JSON array,
 *   or reading/parsing throws.
 */
export function loadDismissed(key: string, ttlMs: number): DismissedEntry[] {
  if (typeof localStorage === 'undefined') return [];
  try {
    const stored = localStorage.getItem(key);
    if (!stored) return [];
    const decoded: unknown = JSON.parse(stored);
    if (!Array.isArray(decoded)) return [];
    const oldestAllowed = Date.now() - ttlMs;
    const survivors: DismissedEntry[] = [];
    decoded.forEach((candidate: unknown) => {
      // Reject null, primitives and anything lacking the {id, at} shape.
      if (!candidate || typeof candidate !== 'object') return;
      const record = candidate as DismissedEntry;
      if (typeof record.id !== 'string') return;
      if (typeof record.at !== 'number') return;
      // Written as a negated ``>=`` so a NaN comparison drops the entry.
      if (!(record.at >= oldestAllowed)) return;
      survivors.push(record);
    });
    return survivors;
  } catch {
    // Corrupt JSON or a storage access error: behave as if nothing is stored.
    return [];
  }
}
/**
 * Persist ``list`` as JSON under ``key`` in localStorage.
 *
 * Does nothing when localStorage is not defined, and swallows any error
 * raised while serialising or writing.
 *
 * @param key localStorage key to write.
 * @param list Dismissal entries to store.
 */
export function saveDismissed(key: string, list: DismissedEntry[]): void {
  const storageMissing = typeof localStorage === 'undefined';
  if (storageMissing) return;
  try {
    const serialized = JSON.stringify(list);
    localStorage.setItem(key, serialized);
  } catch {
    // Best-effort: ignore quota / private-mode errors.
  }
}
/**
 * Report whether ``list`` contains an entry whose ``id`` equals ``id``.
 *
 * @param list Dismissal entries to search.
 * @param id Identifier to look for.
 * @returns ``true`` on the first match, ``false`` otherwise (including for an empty list).
 */
export function isDismissed(list: DismissedEntry[], id: string): boolean {
  const hasSameId = (entry: DismissedEntry): boolean => entry.id === id;
  return list.some(hasSameId);
}

View File

@@ -1,109 +0,0 @@
import { describe, expect, it, vi, afterEach } from 'vitest';
import reducer, {
dismissToolApproval,
sseEventReceived,
sseLastEventIdReset,
type SSEEvent,
} from './notificationsSlice';
// Build an SSE envelope with default ``id``/``type``; any field supplied
// in ``overrides`` wins over the default.
const baseEvent = (overrides: Partial<SSEEvent> = {}): SSEEvent => {
  const defaults: SSEEvent = { id: 'evt-1', type: 'source.ingest.progress' };
  return { ...defaults, ...overrides };
};

// Ask the reducer for its initial state by sending an action it ignores.
const seedState = () => {
  return reducer(undefined, { type: '@@INIT' });
};

// Tests that install fake timers must not leak them into later tests.
afterEach(() => {
  vi.useRealTimers();
});
// Covers dedupe, cursor handling for id-less envelopes, and the ring cap.
describe('sseEventReceived', () => {
  it('dedupes by id when the same envelope arrives twice', () => {
    let current = seedState();
    // Deliver two separate envelopes that carry the same id.
    for (let delivery = 0; delivery < 2; delivery += 1) {
      current = reducer(current, sseEventReceived(baseEvent({ id: 'a' })));
    }
    const { recentEvents } = current;
    expect(recentEvents).toHaveLength(1);
    expect(recentEvents[0].id).toBe('a');
  });

  it('does not update lastEventId for envelopes without an id (backlog.truncated)', () => {
    const withCursor = reducer(
      seedState(),
      sseEventReceived(baseEvent({ id: 'cursor-1' })),
    );
    expect(withCursor.lastEventId).toBe('cursor-1');
    const idlessEnvelope = { type: 'backlog.truncated' } as SSEEvent;
    const afterIdless = reducer(withCursor, sseEventReceived(idlessEnvelope));
    expect(afterIdless.lastEventId).toBe('cursor-1');
  });

  it('caps recentEvents at 100 entries (oldest evicted)', () => {
    const ids = Array.from({ length: 105 }, (_, index) => `e-${index}`);
    const finalState = ids.reduce(
      (acc, id) => reducer(acc, sseEventReceived(baseEvent({ id }))),
      seedState(),
    );
    const { recentEvents } = finalState;
    expect(recentEvents).toHaveLength(100);
    // Newest first.
    const newest = recentEvents[0];
    const oldest = recentEvents[recentEvents.length - 1];
    expect(newest.id).toBe('e-104');
    expect(oldest.id).toBe('e-5');
  });
});
// The reset action must drop the resume cursor that an event set earlier.
describe('sseLastEventIdReset', () => {
  it('clears lastEventId back to null', () => {
    const withCursor = reducer(
      seedState(),
      sseEventReceived(baseEvent({ id: 'x' })),
    );
    expect(withCursor.lastEventId).toBe('x');
    const afterReset = reducer(withCursor, sseLastEventIdReset());
    expect(afterReset.lastEventId).toBeNull();
  });
});
// Covers re-dismissal of the same id, TTL eviction, and the size cap.
describe('dismissToolApproval', () => {
  it('dedupes by id and refreshes the timestamp', () => {
    vi.useFakeTimers();
    vi.setSystemTime(new Date('2026-01-01T00:00:00Z'));
    const afterFirst = reducer(seedState(), dismissToolApproval('approval-1'));
    const firstAt = afterFirst.dismissedToolApprovals[0].at;
    vi.setSystemTime(new Date('2026-01-01T00:05:00Z'));
    const afterSecond = reducer(afterFirst, dismissToolApproval('approval-1'));
    const { dismissedToolApprovals } = afterSecond;
    expect(dismissedToolApprovals).toHaveLength(1);
    expect(dismissedToolApprovals[0].at).toBeGreaterThan(firstAt);
  });

  it('evicts entries older than the 24h TTL', () => {
    vi.useFakeTimers();
    vi.setSystemTime(new Date('2026-01-01T00:00:00Z'));
    const withOld = reducer(seedState(), dismissToolApproval('old-1'));
    // Move past the 24h TTL window.
    vi.setSystemTime(new Date('2026-01-02T00:00:01Z'));
    const withFresh = reducer(withOld, dismissToolApproval('fresh-1'));
    const remainingIds = withFresh.dismissedToolApprovals.map(
      (entry) => entry.id,
    );
    expect(remainingIds).toEqual(['fresh-1']);
  });

  it('applies the 200-entry cap as a backstop after TTL filtering', () => {
    vi.useFakeTimers();
    vi.setSystemTime(new Date('2026-01-01T00:00:00Z'));
    let current = seedState();
    // Insert 205 distinct ids within the TTL window.
    let index = 0;
    while (index < 205) {
      // The seconds field cycles through 00-59 (``index % 60``), so the
      // clock varies between dismissals but every timestamp stays well
      // inside the 24h TTL.
      const seconds = (index % 60).toString().padStart(2, '0');
      vi.setSystemTime(new Date(`2026-01-01T00:00:${seconds}Z`));
      current = reducer(current, dismissToolApproval(`id-${index}`));
      index += 1;
    }
    const { dismissedToolApprovals } = current;
    expect(dismissedToolApprovals).toHaveLength(200);
    // The 200-cap keeps the most recently pushed ids.
    expect(dismissedToolApprovals[0].id).toBe('id-5');
    expect(dismissedToolApprovals[199].id).toBe('id-204');
  });
});

View File

@@ -1,200 +0,0 @@
import { createSelector, createSlice, PayloadAction } from '@reduxjs/toolkit';
import { RootState } from '../store';
import { loadDismissed, saveDismissed } from './dismissedPersistence';
const DISMISSED_TOOL_APPROVALS_STORAGE_KEY = 'docsgpt:dismissedToolApprovals';
/**
* Envelope shape published by the backend SSE endpoint
* (`application/events/publisher.py`). Mirrors the wire JSON 1:1.
*
* The type parameter ``P`` is the payload shape and defaults to an
* open string-keyed record.
*/
export interface SSEEvent<P = Record<string, unknown>> {
// Event id. Optional: the reducer only advances ``lastEventId`` and
// only dedupes when an id is present.
id?: string;
// Event name, e.g. ``tool.approval.required``.
type: string;
// NOTE(review): presumably the publish timestamp as a string — format
// is not visible from this file; confirm against the publisher.
ts?: string;
user_id?: string;
topic?: string;
// What the event is about. For ``tool.approval.required`` events,
// ``scope.id`` is the conversation id.
scope?: { kind: string; id: string };
// Event-type-specific data.
payload?: P;
}
/**
* Connection-health state machine the rest of the app reads via
* ``selectPushChannelHealthy`` to gate polling-fallback behaviour.
*
* - ``connecting`` — initial fetch in flight, or reconnecting after drop.
* - ``healthy`` — at least one event (data or keepalive) received and
* the stream has been open >2s.
* - ``unhealthy`` — last attempt failed or has been dropped without a
* successful re-establish; fall back to polling.
*/
export type PushHealth = 'connecting' | 'healthy' | 'unhealthy';
/** State owned by the ``notifications`` slice. */
interface NotificationsState {
/** Current push-connection health; set by ``sseHealthChanged``. */
health: PushHealth;
/** Most-recent server-issued id; sent back as ``Last-Event-ID`` on reconnect. */
lastEventId: string | null;
/** Bounded ring of recent events for the in-app notifications surface. */
recentEvents: SSEEvent[];
/**
* Wallclock ms of last received data-bearing event. SSE keepalives
* are comment lines (no ``id:``/``data:``) and do NOT update this —
* they're filtered out at the parser level.
*/
lastEventReceivedAt: number | null;
/**
* Event ids of ``tool.approval.required`` notifications the user
* dismissed (close button or by navigating into the conversation),
* each tagged with the wallclock ms at which it was dismissed.
* Keyed by the SSE event id so a *new* approval for the same
* conversation re-surfaces; the dismissal only suppresses the one
* specific paused-tool prompt.
*
* Entries are evicted by TTL first (anything older than
* ``DISMISSED_TOOL_APPROVALS_TTL_MS``), then by FIFO cap. The TTL
* matters because a pure FIFO with a small cap can evict a *still-
* pending* approval id before the user acts on it — re-popping the
* toast on the next dispatch. The 24h ceiling is longer than any
* plausible approval-pending window.
*/
dismissedToolApprovals: Array<{ id: string; at: number }>;
}
// Maximum number of envelopes kept in ``recentEvents``; the array is
// truncated to this length after each insert, dropping the oldest.
const RECENT_EVENTS_CAP = 100;
// Upper bound on ``dismissedToolApprovals``, applied after TTL eviction;
// only the most recently pushed entries are kept.
const DISMISSED_TOOL_APPROVALS_CAP = 200;
// 24 hours in milliseconds; dismissals older than this are evicted.
const DISMISSED_TOOL_APPROVALS_TTL_MS = 24 * 60 * 60 * 1000;
// Starting state for the slice. Everything begins empty except the
// dismissal list, which is read back from localStorage at module load
// (entries past the TTL are filtered out by ``loadDismissed``).
const initialState: NotificationsState = {
health: 'connecting',
lastEventId: null,
recentEvents: [],
lastEventReceivedAt: null,
// Hydrate from localStorage: SSE backlog replay re-delivers the
// originating ``tool.approval.required`` envelopes on reload.
dismissedToolApprovals: loadDismissed(
DISMISSED_TOOL_APPROVALS_STORAGE_KEY,
DISMISSED_TOOL_APPROVALS_TTL_MS,
),
};
/**
 * Redux slice holding SSE push-channel state: connection health, the
 * resume cursor, a bounded newest-first ring of recent events, and the
 * persisted list of dismissed tool-approval prompts.
 */
export const notificationsSlice = createSlice({
  name: 'notifications',
  initialState,
  reducers: {
    /**
     * Record an incoming envelope at the front of ``recentEvents``
     * (newest-first), trim the ring to ``RECENT_EVENTS_CAP``, advance
     * ``lastEventId`` when the envelope carries an id, and stamp
     * ``lastEventReceivedAt``. An envelope whose id is already in the
     * ring is ignored entirely (no state change).
     */
    sseEventReceived: (state, action: PayloadAction<SSEEvent>) => {
      const e = action.payload;
      // Drop duplicates. Snapshot replay + live tail can both deliver
      // the same id when the live pubsub frame and the replay XRANGE
      // overlap, and consumers that walk ``recentEvents`` (FileTree,
      // ConnectorTree, MCPServerModal, ToolApprovalToast) would
      // otherwise act on the same envelope twice. The route's dedup
      // floor catches the common case; this is a belt-and-suspenders
      // for in-tab StrictMode double-mounts and any envelope that slips
      // through with the same id.
      //
      // The check covers the whole ring, not just the head: a replayed
      // envelope can arrive after newer events have been recorded, and
      // re-inserting it would both duplicate it and move ``lastEventId``
      // back to the older id. The ring is capped at
      // ``RECENT_EVENTS_CAP`` so the scan is bounded.
      if (e.id && state.recentEvents.some((seen) => seen.id === e.id)) return;
      state.recentEvents.unshift(e);
      if (state.recentEvents.length > RECENT_EVENTS_CAP) {
        state.recentEvents.length = RECENT_EVENTS_CAP;
      }
      if (e.id) state.lastEventId = e.id;
      state.lastEventReceivedAt = Date.now();
    },
    /** Replace the connection-health value. */
    sseHealthChanged: (state, action: PayloadAction<PushHealth>) => {
      state.health = action.payload;
    },
    /**
     * Lifecycle helper used by reconnect bookkeeping — does not record
     * an event, just stamps "we heard from the server" so the polling
     * fallback stays disabled while keepalives arrive.
     */
    sseHeartbeat: (state) => {
      state.lastEventReceivedAt = Date.now();
    },
    /** Clear the resume cursor. */
    sseLastEventIdReset: (state) => {
      // Backlog truncated — drop the cursor so the next reconnect
      // doesn't try to resume past the retained window.
      state.lastEventId = null;
    },
    /**
     * Advance the cursor without recording an event. Fired for every
     * id-bearing frame including keepalives and id-only comments,
     * so the slice cursor tracks the freshest id the wire has
     * delivered even when no envelope was dispatched. Without this,
     * ``lastEventId`` would only advance via ``sseEventReceived`` and
     * a long quiet period of keepalives would leave it stale —
     * eventually re-snapshotting the same backlog on each reconnect
     * and exhausting the per-user replay budget.
     */
    sseLastEventIdAdvanced: (state, action: PayloadAction<string>) => {
      state.lastEventId = action.payload;
    },
    /** Empty the recent-events ring; the cursor and health are untouched. */
    clearRecentEvents: (state) => {
      state.recentEvents = [];
    },
    /**
     * Suppress a ``tool.approval.required`` notification by the SSE
     * event id. The toast surface filters dismissed ids out; a *new*
     * approval event for the same conversation has a different id
     * and re-surfaces, which is the desired UX (each pause is its
     * own decision).
     *
     * Re-dismissing an id replaces its entry with a fresh timestamp.
     * The resulting list is written to localStorage on every call.
     */
    dismissToolApproval: (state, action: PayloadAction<string>) => {
      const id = action.payload;
      const now = Date.now();
      // Evict expired entries first so the TTL — not the FIFO cap —
      // governs when stale ids drop, keeping still-pending approvals
      // suppressed.
      const cutoff = now - DISMISSED_TOOL_APPROVALS_TTL_MS;
      state.dismissedToolApprovals = state.dismissedToolApprovals.filter(
        (entry) => entry.at >= cutoff && entry.id !== id,
      );
      state.dismissedToolApprovals.push({ id, at: now });
      // Backstop: keep only the most recently pushed entries.
      if (state.dismissedToolApprovals.length > DISMISSED_TOOL_APPROVALS_CAP) {
        state.dismissedToolApprovals = state.dismissedToolApprovals.slice(
          -DISMISSED_TOOL_APPROVALS_CAP,
        );
      }
      saveDismissed(
        DISMISSED_TOOL_APPROVALS_STORAGE_KEY,
        state.dismissedToolApprovals,
      );
    },
  },
});
// Action creators exposed by `notificationsSlice.actions`, re-exported
// individually (listed alphabetically).
export const {
  clearRecentEvents,
  dismissToolApproval,
  sseEventReceived,
  sseHealthChanged,
  sseHeartbeat,
  sseLastEventIdAdvanced,
  sseLastEventIdReset,
} = notificationsSlice.actions;
/** Returns the health value stored at `notifications.health`. */
export const selectSseHealth = ({ notifications }: RootState): PushHealth => {
  return notifications.health;
};
/** True only when `notifications.health` is exactly `'healthy'`. */
export const selectPushChannelHealthy = ({
  notifications: { health },
}: RootState): boolean => {
  return health === 'healthy';
};
/** Returns the stored event-id cursor, or `null` when none is held. */
export const selectLastEventId = ({
  notifications,
}: RootState): string | null => {
  return notifications.lastEventId;
};
/** Returns the stored `lastEventReceivedAt` timestamp, or `null`. */
export const selectLastEventReceivedAt = ({
  notifications,
}: RootState): number | null => {
  return notifications.lastEventReceivedAt;
};
/** Returns the `recentEvents` array held on the notifications slice. */
export const selectRecentEvents = ({ notifications }: RootState): SSEEvent[] => {
  return notifications.recentEvents;
};
// Memoised so ``useSelector`` consumers don't re-render on every
// unrelated ``notifications`` state change — the underlying ``{id,at}``
// array only changes when ``dismissToolApproval`` runs, but the
// projected ``.map`` would otherwise return a fresh array each call.
export const selectDismissedToolApprovals = createSelector(
  // Input: the stored `{ id, at }` entries.
  ({ notifications }: RootState) => notifications.dismissedToolApprovals,
  // Output: just the ids, recomputed only when the input array changes.
  (dismissed) => dismissed.map(({ id }) => id),
);
export default notificationsSlice.reducer;

View File

@@ -17,7 +17,7 @@ import ContextMenu, { MenuOption } from '../components/ContextMenu';
import Pagination from '../components/DocumentPagination';
import DropdownMenu from '../components/DropdownMenu';
import SkeletonLoader from '../components/SkeletonLoader';
import { useDarkTheme, useDebouncedValue, useLoaderState } from '../hooks';
import { useDarkTheme, useLoaderState } from '../hooks';
import ConfirmationModal from '../modals/ConfirmationModal';
import { ActiveState, Doc, DocumentsProps } from '../models/misc';
import { getDocs, getDocsWithPagination } from '../preferences/preferenceApi';
@@ -27,12 +27,6 @@ import {
setSourceDocs,
} from '../preferences/preferenceSlice';
import Upload from '../upload/Upload';
import {
addUploadTask,
removeUploadTask,
selectUploadTasks,
updateUploadTask,
} from '../upload/uploadSlice';
import { formatDate } from '../utils/dateTimeUtils';
import FileTree from '../components/FileTree';
import ConnectorTree from '../components/ConnectorTree';
@@ -62,10 +56,9 @@ export default function Sources({
const [isDarkTheme] = useDarkTheme();
const dispatch = useDispatch();
const token = useSelector(selectToken);
const uploadTasks = useSelector(selectUploadTasks);
const [searchTerm, setSearchTerm] = useState<string>('');
const debouncedSearchTerm = useDebouncedValue(searchTerm, 500);
const [debouncedSearchTerm, setDebouncedSearchTerm] = useState<string>('');
const [modalState, setModalState] = useState<ActiveState>('INACTIVE');
const [isOnboarding, setIsOnboarding] = useState<boolean>(false);
const [loading, setLoading] = useLoaderState(false);
@@ -124,6 +117,14 @@ export default function Sources({
document: null,
});
useEffect(() => {
  // Debounce: copy the search term into the debounced state 500 ms
  // after the latest change; the cleanup cancels a pending copy when
  // the term changes again or the component unmounts.
  const pending = setTimeout(() => setDebouncedSearchTerm(searchTerm), 500);
  return () => {
    clearTimeout(pending);
  };
}, [searchTerm]);
const refreshDocs = useCallback(
(
field: 'date' | 'tokens' | undefined,
@@ -256,57 +257,6 @@ export default function Sources({
}
};
const handleReingest = async (doc: Doc) => {
  // Nothing to do for a document without an id.
  const sourceId = doc.id;
  if (!sourceId) {
    return;
  }

  // Remove any existing upload-task rows for this source first: a
  // finished/dismissed task would swallow the reingest's SSE events.
  for (const staleTask of uploadTasks) {
    if (staleTask.sourceId === sourceId) {
      dispatch(removeUploadTask(staleTask.id));
    }
  }

  // Open a fresh task row that tracks this reingest.
  const reingestTaskId = `reingest-${sourceId}-${Date.now()}`;
  dispatch(
    addUploadTask({
      id: reingestTaskId,
      fileName: doc.name || sourceId,
      progress: 0,
      status: 'training',
      sourceId,
    }),
  );

  try {
    const response = await userService.reingestSource(
      { source_id: sourceId },
      token,
    );
    const result = await response.json();

    if (result.success) {
      refreshDocs(undefined, currentPage, rowsPerPage);
      return;
    }

    // The server reported a failure: log it and mark the task failed
    // with whatever message the response carried.
    const reason = result.error || result.message;
    console.error('Reingest failed:', reason);
    dispatch(
      updateUploadTask({
        id: reingestTaskId,
        updates: {
          status: 'failed',
          errorMessage: reason,
        },
      }),
    );
  } catch (error) {
    // Request or JSON parsing threw: mark the task failed without
    // setting an error message.
    console.error('Error reingesting source:', error);
    dispatch(
      updateUploadTask({
        id: reingestTaskId,
        updates: { status: 'failed' },
      }),
    );
  }
};
const [documentToDelete, setDocumentToDelete] = useState<{
index: number;
document: Doc;
@@ -341,19 +291,6 @@ export default function Sources({
},
];
if (document.ingestStatus === 'failed') {
actions.push({
icon: SyncIcon,
label: t('settings.sources.reingest'),
onClick: () => {
handleReingest(document);
},
iconWidth: 14,
iconHeight: 14,
variant: 'primary',
});
}
if (document.syncFrequency) {
actions.push({
icon: SyncIcon,
@@ -554,16 +491,6 @@ export default function Sources({
</div>
<div className="flex flex-col items-start justify-start gap-1">
{document.ingestStatus === 'failed' && (
<span className="rounded-full bg-red-100 px-2 py-0.5 text-[11px] leading-[16px] font-medium text-red-700 dark:bg-red-900/30 dark:text-red-400">
{t('settings.sources.ingestFailed')}
</span>
)}
{document.ingestStatus === 'processing' && (
<span className="bg-muted-foreground/10 text-muted-foreground rounded-full px-2 py-0.5 text-[11px] leading-[16px] font-medium">
{t('settings.sources.ingestProcessing')}
</span>
)}
<div className="flex items-center gap-2">
<img
src={CalendarIcon}

Some files were not shown because too many files have changed in this diff Show More