mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-22 13:25:08 +00:00
Compare commits
4 Commits
fix-stuck-
...
convsearch
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
36bb1a5b38 | ||
|
|
fec4f5b336 | ||
|
|
8b9eb5cffe | ||
|
|
1a764c6ee8 |
@@ -8,7 +8,7 @@ RUN apt-get update && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Verify Python installation and setup symlink
|
||||
RUN if [ -f /usr/bin/python3.12 ]; then \
|
||||
@@ -73,7 +73,7 @@ COPY --from=builder /models /app/models
|
||||
COPY . /app/application
|
||||
|
||||
# Change the ownership of the /app directory to the appuser
|
||||
|
||||
|
||||
RUN mkdir -p /app/application/inputs/local
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
@@ -82,11 +82,6 @@ ENV FLASK_APP=app.py \
|
||||
FLASK_DEBUG=true \
|
||||
PATH="/venv/bin:$PATH"
|
||||
|
||||
ENV MALLOC_ARENA_MAX=2 \
|
||||
OMP_NUM_THREADS=4 \
|
||||
MKL_NUM_THREADS=4 \
|
||||
OPENBLAS_NUM_THREADS=4
|
||||
|
||||
# Expose the port the app runs on
|
||||
EXPOSE 7091
|
||||
|
||||
|
||||
@@ -114,8 +114,6 @@ class BaseAgent(ABC):
|
||||
self.compressed_summary = compressed_summary
|
||||
self.current_token_count = 0
|
||||
self.context_limit_reached = False
|
||||
self.conversation_id: Optional[str] = None
|
||||
self.initial_user_id: Optional[str] = None
|
||||
|
||||
@log_activity()
|
||||
def gen(
|
||||
|
||||
@@ -20,11 +20,10 @@ from pydantic import AnyHttpUrl, ValidationError
|
||||
from redis import Redis
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
from application.api.user.tasks import mcp_oauth_task
|
||||
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.core.url_validation import SSRFError, validate_url
|
||||
from application.events.keys import stream_key
|
||||
from application.security.encryption import decrypt_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -77,12 +76,6 @@ class MCPTool(Tool):
|
||||
self.oauth_task_id = config.get("oauth_task_id", None)
|
||||
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
|
||||
self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))
|
||||
# Pulled out of ``config`` (rather than left in ``self.config``)
|
||||
# because it is a callable supplied by the OAuth worker — not
|
||||
# something the rest of the tool plumbing should marshal or
|
||||
# serialize. ``DocsGPTOAuth`` invokes it from ``redirect_handler``
|
||||
# so the SSE envelope can carry ``authorization_url``.
|
||||
self.oauth_redirect_publish = config.pop("oauth_redirect_publish", None)
|
||||
|
||||
self.available_tools = []
|
||||
self._cache_key = self._generate_cache_key()
|
||||
@@ -174,7 +167,6 @@ class MCPTool(Tool):
|
||||
redirect_uri=self.redirect_uri,
|
||||
task_id=self.oauth_task_id,
|
||||
user_id=self.user_id,
|
||||
redirect_publish=self.oauth_redirect_publish,
|
||||
)
|
||||
elif self.auth_type == "bearer":
|
||||
token = self.auth_credentials.get(
|
||||
@@ -687,17 +679,12 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
user_id=None,
|
||||
additional_client_metadata: dict[str, Any] | None = None,
|
||||
skip_redirect_validation: bool = False,
|
||||
redirect_publish=None,
|
||||
):
|
||||
self.redirect_uri = redirect_uri
|
||||
self.redis_client = redis_client
|
||||
self.redis_prefix = redis_prefix
|
||||
self.task_id = task_id
|
||||
self.user_id = user_id
|
||||
# Worker-supplied callback. Invoked from ``redirect_handler``
|
||||
# once the authorization URL is known so the SSE envelope can
|
||||
# carry it. ``None`` for any non-worker entrypoint.
|
||||
self.redirect_publish = redirect_publish
|
||||
|
||||
parsed_url = urlparse(mcp_url)
|
||||
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
@@ -757,19 +744,17 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
self.redis_client.setex(key, 600, auth_url)
|
||||
logger.info("Stored auth_url in Redis: %s", key)
|
||||
|
||||
if self.redirect_publish is not None:
|
||||
# Best-effort: a publish failure must not abort the OAuth
|
||||
# handshake — the user can still authorize via the popup
|
||||
# opened from the legacy polling fallback if the SSE
|
||||
# envelope is lost.
|
||||
try:
|
||||
self.redirect_publish(auth_url)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"redirect_publish callback raised for task_id=%s",
|
||||
self.task_id,
|
||||
exc_info=True,
|
||||
)
|
||||
if self.task_id:
|
||||
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||
status_data = {
|
||||
"status": "requires_redirect",
|
||||
"message": "Authorization required",
|
||||
"authorization_url": self.auth_url,
|
||||
"state": self.extracted_state,
|
||||
"requires_oauth": True,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
|
||||
async def callback_handler(self) -> tuple[str, str | None]:
|
||||
"""Wait for auth code from Redis using the state value."""
|
||||
@@ -779,6 +764,17 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
max_wait_time = 300
|
||||
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
|
||||
|
||||
if self.task_id:
|
||||
status_key = f"mcp_oauth_status:{self.task_id}"
|
||||
status_data = {
|
||||
"status": "awaiting_callback",
|
||||
"message": "Waiting for authorization...",
|
||||
"authorization_url": self.auth_url,
|
||||
"state": self.extracted_state,
|
||||
"requires_oauth": True,
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
code_data = self.redis_client.get(code_key)
|
||||
@@ -793,6 +789,14 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
self.redis_client.delete(
|
||||
f"{self.redis_prefix}state:{self.extracted_state}"
|
||||
)
|
||||
|
||||
if self.task_id:
|
||||
status_data = {
|
||||
"status": "callback_received",
|
||||
"message": "Completing authentication...",
|
||||
"task_id": self.task_id,
|
||||
}
|
||||
self.redis_client.setex(status_key, 600, json.dumps(status_data))
|
||||
return code, returned_state
|
||||
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
|
||||
error_data = self.redis_client.get(error_key)
|
||||
@@ -1034,73 +1038,8 @@ class MCPOAuthManager:
|
||||
logger.error("Error handling OAuth callback: %s", e)
|
||||
return False
|
||||
|
||||
def get_oauth_status(self, task_id: str, user_id: str) -> Dict[str, Any]:
|
||||
"""Return the latest OAuth status for ``task_id`` from the user's SSE journal.
|
||||
|
||||
Mirrors the legacy polling contract: ``status`` derived from the
|
||||
``mcp.oauth.*`` event-type suffix, with payload fields surfaced
|
||||
(e.g. ``tools``/``tools_count`` on ``completed``).
|
||||
"""
|
||||
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""Get current status of OAuth flow using provided task_id."""
|
||||
if not task_id:
|
||||
return {"status": "not_started", "message": "OAuth flow not started"}
|
||||
if not user_id:
|
||||
return {"status": "not_found", "message": "User not provided"}
|
||||
if self.redis_client is None:
|
||||
return {"status": "not_found", "message": "Redis unavailable"}
|
||||
|
||||
try:
|
||||
# OAuth flows are short-lived but a concurrent source
|
||||
# ingest can flood the user channel between the OAuth
|
||||
# popup completing and the user clicking Save, pushing the
|
||||
# completion envelope outside the read window. Bound the
|
||||
# scan by the configured stream cap so we cover the full
|
||||
# journal — XADD MAXLEN keeps that bounded too.
|
||||
scan_count = max(settings.EVENTS_STREAM_MAXLEN, 200)
|
||||
entries = self.redis_client.xrevrange(
|
||||
stream_key(user_id), count=scan_count
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"xrevrange failed for oauth status: user_id=%s task_id=%s",
|
||||
user_id,
|
||||
task_id,
|
||||
)
|
||||
return {"status": "not_found", "message": "Status unavailable"}
|
||||
|
||||
for _entry_id, fields in entries:
|
||||
if not isinstance(fields, dict):
|
||||
continue
|
||||
# decode_responses=False ⇒ bytes keys; the string-key fallback
|
||||
# covers a future flip of that default without a forced refactor.
|
||||
event_raw = fields.get(b"event")
|
||||
if event_raw is None:
|
||||
event_raw = fields.get("event")
|
||||
if event_raw is None:
|
||||
continue
|
||||
if isinstance(event_raw, bytes):
|
||||
try:
|
||||
event_raw = event_raw.decode("utf-8")
|
||||
except Exception:
|
||||
continue
|
||||
try:
|
||||
envelope = json.loads(event_raw)
|
||||
except Exception:
|
||||
continue
|
||||
if not isinstance(envelope, dict):
|
||||
continue
|
||||
event_type = envelope.get("type", "")
|
||||
if not isinstance(event_type, str) or not event_type.startswith(
|
||||
"mcp.oauth."
|
||||
):
|
||||
continue
|
||||
scope = envelope.get("scope") or {}
|
||||
if scope.get("kind") != "mcp_oauth" or scope.get("id") != task_id:
|
||||
continue
|
||||
payload = envelope.get("payload") or {}
|
||||
return {
|
||||
"status": event_type[len("mcp.oauth."):],
|
||||
"task_id": task_id,
|
||||
**payload,
|
||||
}
|
||||
|
||||
return {"status": "not_found", "message": "Status not found"}
|
||||
return mcp_oauth_status_task(task_id)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""0001 initial schema — consolidated baseline for user-data tables.
|
||||
"""0001 initial schema — consolidated Phase-1..3 baseline.
|
||||
|
||||
Revision ID: 0001_initial
|
||||
Revises:
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
"""0007 message_events — durable journal of chat-stream events.
|
||||
|
||||
Snapshot half of the chat-stream snapshot+tail pattern. Composite PK
|
||||
``(message_id, sequence_no)``, ``created_at`` indexed for retention
|
||||
sweeps, ``ON DELETE CASCADE`` from ``conversation_messages``.
|
||||
|
||||
Revision ID: 0007_message_events
|
||||
Revises: 0006_idempotency_lease
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0007_message_events"
|
||||
down_revision: Union[str, None] = "0006_idempotency_lease"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE message_events (
|
||||
message_id UUID NOT NULL REFERENCES conversation_messages(id) ON DELETE CASCADE,
|
||||
sequence_no INTEGER NOT NULL,
|
||||
event_type TEXT NOT NULL,
|
||||
payload JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
PRIMARY KEY (message_id, sequence_no)
|
||||
);
|
||||
CREATE INDEX message_events_created_at_idx ON message_events(created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS message_events_created_at_idx;")
|
||||
op.execute("DROP TABLE IF EXISTS message_events;")
|
||||
@@ -1,44 +0,0 @@
|
||||
"""0008 ingest_chunk_progress.status — terminal flag for stalled ingests.
|
||||
|
||||
The reconciler's stalled-ingest sweep had no terminal write, so a dead
|
||||
ingest re-alerted every ~30 min forever. ``status`` lets it escalate a
|
||||
stalled checkpoint to ``'stalled'`` once and stop re-selecting it;
|
||||
``init_progress`` resets it to ``'active'`` on reingest.
|
||||
|
||||
Revision ID: 0008_ingest_progress_status
|
||||
Revises: 0007_message_events
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0008_ingest_progress_status"
|
||||
down_revision: Union[str, None] = "0007_message_events"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Constant DEFAULT — metadata-only ADD COLUMN, no table rewrite.
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE ingest_chunk_progress
|
||||
ADD COLUMN status TEXT NOT NULL DEFAULT 'active'
|
||||
CHECK (status IN ('active', 'stalled'));
|
||||
"""
|
||||
)
|
||||
# Partial index for the reconciler's stalled-ingest sweep.
|
||||
op.execute(
|
||||
"CREATE INDEX ingest_chunk_progress_active_idx "
|
||||
"ON ingest_chunk_progress (last_updated) "
|
||||
"WHERE status = 'active';"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS ingest_chunk_progress_active_idx;")
|
||||
op.execute(
|
||||
"ALTER TABLE ingest_chunk_progress DROP COLUMN IF EXISTS status;"
|
||||
)
|
||||
@@ -23,16 +23,9 @@ from application.core.settings import settings
|
||||
from application.error import sanitize_api_error
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.conversations import MessageUpdateOutcome
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
from application.storage.db.repositories.user_logs import UserLogsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.events.publisher import publish_user_event
|
||||
from application.streaming.event_replay import format_sse_event
|
||||
from application.streaming.message_journal import (
|
||||
BatchedJournalWriter,
|
||||
record_event,
|
||||
)
|
||||
from application.utils import check_required_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -284,17 +277,6 @@ class BaseAnswerResource:
|
||||
"update_message_status streaming failed for %s",
|
||||
reserved_message_id,
|
||||
)
|
||||
# Seed last_heartbeat_at so watchdog doesn't fall back to `timestamp`
|
||||
# (creation time) before the first STREAM_HEARTBEAT_INTERVAL tick.
|
||||
try:
|
||||
self.conversation_service.heartbeat_message(
|
||||
reserved_message_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"initial heartbeat seed failed for %s",
|
||||
reserved_message_id,
|
||||
)
|
||||
streaming_marked = True
|
||||
last_heartbeat_at = time.monotonic()
|
||||
|
||||
@@ -321,73 +303,13 @@ class BaseAnswerResource:
|
||||
try:
|
||||
agent.tool_executor.message_id = reserved_message_id
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Could not set tool_executor.message_id; tool-call correlation will be missing for message_id=%s",
|
||||
reserved_message_id,
|
||||
)
|
||||
|
||||
# Per-stream monotonic SSE event id. Allocated by ``_emit`` and
|
||||
# threaded through both the wire format (``id: <seq>\\n``) and
|
||||
# the journal write so a reconnecting client can ``Last-Event-
|
||||
# ID`` past anything they already saw. Continuations resume
|
||||
# against the original ``reserved_message_id`` — seed the
|
||||
# allocator from the journal's high-water mark so we don't
|
||||
# collide on the duplicate-PK and silently lose every emit
|
||||
# past the resume point.
|
||||
sequence_no = -1
|
||||
if _continuation and reserved_message_id:
|
||||
try:
|
||||
from application.storage.db.repositories.message_events import (
|
||||
MessageEventsRepository,
|
||||
)
|
||||
|
||||
with db_readonly() as conn:
|
||||
latest = MessageEventsRepository(conn).latest_sequence_no(
|
||||
reserved_message_id
|
||||
)
|
||||
if latest is not None:
|
||||
sequence_no = latest
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Continuation seq seed lookup failed for message_id=%s; "
|
||||
"falling back to seq=-1 (duplicate-PK collisions will "
|
||||
"be swallowed)",
|
||||
reserved_message_id,
|
||||
)
|
||||
|
||||
# One batched journal writer per stream.
|
||||
journal_writer: Optional[BatchedJournalWriter] = (
|
||||
BatchedJournalWriter(reserved_message_id)
|
||||
if reserved_message_id
|
||||
else None
|
||||
)
|
||||
|
||||
def _emit(payload: dict) -> str:
|
||||
"""Format-and-journal one SSE event.
|
||||
|
||||
With a reserved ``message_id``, buffers into the journal and
|
||||
emits ``id: <seq>``-tagged SSE frames; otherwise falls back to
|
||||
legacy ``data: ...\\n\\n`` framing.
|
||||
"""
|
||||
nonlocal sequence_no
|
||||
if not reserved_message_id or journal_writer is None:
|
||||
return f"data: {json.dumps(payload)}\n\n"
|
||||
sequence_no += 1
|
||||
seq = sequence_no
|
||||
event_type = (
|
||||
payload.get("type", "data")
|
||||
if isinstance(payload, dict)
|
||||
else "data"
|
||||
)
|
||||
normalised = payload if isinstance(payload, dict) else {"value": payload}
|
||||
journal_writer.record(seq, event_type, normalised)
|
||||
return format_sse_event(normalised, seq)
|
||||
pass
|
||||
|
||||
try:
|
||||
# Surface the placeholder id before any LLM tokens so a
|
||||
# mid-handshake disconnect still has a row to tail-poll.
|
||||
if reserved_message_id:
|
||||
yield _emit(
|
||||
early_event = json.dumps(
|
||||
{
|
||||
"type": "message_id",
|
||||
"message_id": reserved_message_id,
|
||||
@@ -397,6 +319,7 @@ class BaseAnswerResource:
|
||||
"request_id": request_id,
|
||||
}
|
||||
)
|
||||
yield f"data: {early_event}\n\n"
|
||||
|
||||
if _continuation:
|
||||
gen_iter = agent.gen_continuation(
|
||||
@@ -422,9 +345,8 @@ class BaseAnswerResource:
|
||||
schema_info = line.get("schema")
|
||||
structured_chunks.append(line["answer"])
|
||||
else:
|
||||
yield _emit(
|
||||
{"type": "answer", "answer": line["answer"]}
|
||||
)
|
||||
data = json.dumps({"type": "answer", "answer": line["answer"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "sources" in line:
|
||||
_mark_streaming_once()
|
||||
truncated_sources = []
|
||||
@@ -437,40 +359,43 @@ class BaseAnswerResource:
|
||||
)
|
||||
truncated_sources.append(truncated_source)
|
||||
if truncated_sources:
|
||||
yield _emit(
|
||||
data = json.dumps(
|
||||
{"type": "source", "source": truncated_sources}
|
||||
)
|
||||
yield f"data: {data}\n\n"
|
||||
elif "tool_calls" in line:
|
||||
tool_calls = line["tool_calls"]
|
||||
yield _emit({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "thought" in line:
|
||||
thought += line["thought"]
|
||||
yield _emit({"type": "thought", "thought": line["thought"]})
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "type" in line:
|
||||
if line.get("type") == "tool_calls_pending":
|
||||
# Save continuation state and end the stream
|
||||
paused = True
|
||||
yield _emit(line)
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
elif line.get("type") == "error":
|
||||
yield _emit(
|
||||
{
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(
|
||||
line.get("error", "An error occurred")
|
||||
),
|
||||
}
|
||||
)
|
||||
sanitized_error = {
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(line.get("error", "An error occurred"))
|
||||
}
|
||||
data = json.dumps(sanitized_error)
|
||||
yield f"data: {data}\n\n"
|
||||
else:
|
||||
yield _emit(line)
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
if is_structured and structured_chunks:
|
||||
yield _emit(
|
||||
{
|
||||
"type": "structured_answer",
|
||||
"answer": response_full,
|
||||
"structured": True,
|
||||
"schema": schema_info,
|
||||
}
|
||||
)
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
"answer": response_full,
|
||||
"structured": True,
|
||||
"schema": schema_info,
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# ---- Paused: save continuation state and end stream early ----
|
||||
if paused:
|
||||
@@ -527,7 +452,6 @@ class BaseAnswerResource:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
state_saved = False
|
||||
if conversation_id:
|
||||
try:
|
||||
cont_service = ContinuationService()
|
||||
@@ -561,65 +485,18 @@ class BaseAnswerResource:
|
||||
agent.tool_executor, "client_tools", None
|
||||
),
|
||||
)
|
||||
state_saved = True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save continuation state: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Notify the user out-of-band so they can navigate
|
||||
# back to the conversation and decide on the
|
||||
# pending tool calls. Gated on ``state_saved``: a
|
||||
# missing pending_tool_state row would 404 the
|
||||
# resume endpoint, so an unfulfillable notification
|
||||
# is worse than no notification.
|
||||
user_id_for_event = (
|
||||
decoded_token.get("sub") if decoded_token else None
|
||||
)
|
||||
if state_saved and user_id_for_event and conversation_id:
|
||||
pending_calls = continuation.get(
|
||||
"pending_tool_calls", []
|
||||
) if continuation else []
|
||||
# Trim each pending tool call to its identifying
|
||||
# metadata so a tool with a multi-MB argument
|
||||
# doesn't blow out the per-event payload size
|
||||
# cap. The resume page fetches full args from
|
||||
# ``pending_tool_state`` regardless.
|
||||
pending_summaries = [
|
||||
{
|
||||
k: tc.get(k)
|
||||
for k in (
|
||||
"call_id",
|
||||
"tool_name",
|
||||
"action_name",
|
||||
"name",
|
||||
)
|
||||
if isinstance(tc, dict) and tc.get(k) is not None
|
||||
}
|
||||
for tc in (pending_calls or [])
|
||||
if isinstance(tc, dict)
|
||||
]
|
||||
publish_user_event(
|
||||
user_id_for_event,
|
||||
"tool.approval.required",
|
||||
{
|
||||
"conversation_id": str(conversation_id),
|
||||
"message_id": reserved_message_id,
|
||||
"pending_tool_calls": pending_summaries,
|
||||
},
|
||||
scope={
|
||||
"kind": "conversation",
|
||||
"id": str(conversation_id),
|
||||
},
|
||||
)
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
yield _emit({"type": "id", "id": str(conversation_id)})
|
||||
yield _emit({"type": "end"})
|
||||
# Drain the terminal ``end`` so a reconnecting client
|
||||
# sees it on snapshot — same reason as the main exit.
|
||||
if journal_writer is not None:
|
||||
journal_writer.close()
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
if isNoneDoc:
|
||||
@@ -726,7 +603,9 @@ class BaseAnswerResource:
|
||||
f"completion: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
yield _emit({"type": "id", "id": str(conversation_id)})
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
|
||||
getattr(agent, "tool_calls", tool_calls) or tool_calls
|
||||
@@ -767,33 +646,12 @@ class BaseAnswerResource:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
yield _emit({"type": "end"})
|
||||
# Drain the journal buffer so the terminal ``end`` event is
|
||||
# visible to any reconnecting client. Without this the
|
||||
# client could snapshot up to the last flush boundary and
|
||||
# then live-tail waiting for an ``end`` that's still
|
||||
# sitting in memory.
|
||||
if journal_writer is not None:
|
||||
journal_writer.close()
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
except GeneratorExit:
|
||||
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
|
||||
# Drain any buffered events before the terminal one-shot
|
||||
# ``record_event`` below — keeps the journal's seq order
|
||||
# contiguous (buffered events ... terminal event). ``close``
|
||||
# is idempotent; pairing it with ``flush`` matches the
|
||||
# normal-exit and error branches so any future ``record()``
|
||||
# past this point would log instead of silently buffering.
|
||||
if journal_writer is not None:
|
||||
journal_writer.flush()
|
||||
journal_writer.close()
|
||||
# Save partial response
|
||||
|
||||
# Whether the DB row was flipped to ``complete`` during this
|
||||
# abort handler. Drives the choice of terminal journal event
|
||||
# below: journal ``end`` only when the row actually matches,
|
||||
# else journal ``error`` so a reconnecting client sees a
|
||||
# failed terminal state instead of a blank "success".
|
||||
finalized_complete = False
|
||||
if should_save_conversation and response_full:
|
||||
try:
|
||||
if isNoneDoc:
|
||||
@@ -828,7 +686,7 @@ class BaseAnswerResource:
|
||||
)
|
||||
llm._token_usage_source = "title"
|
||||
if reserved_message_id is not None:
|
||||
outcome = self.conversation_service.finalize_message(
|
||||
self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full,
|
||||
thought=thought,
|
||||
@@ -847,15 +705,6 @@ class BaseAnswerResource:
|
||||
),
|
||||
},
|
||||
)
|
||||
# ``ALREADY_COMPLETE`` means the normal-path
|
||||
# finalize at line 632 won the race: the DB row
|
||||
# is already at ``complete`` and the reconnect
|
||||
# journal should reflect that with ``end``,
|
||||
# not a spurious ``error``.
|
||||
finalized_complete = outcome in (
|
||||
MessageUpdateOutcome.UPDATED,
|
||||
MessageUpdateOutcome.ALREADY_COMPLETE,
|
||||
)
|
||||
else:
|
||||
self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
@@ -875,9 +724,6 @@ class BaseAnswerResource:
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
# No journal row to gate, but flag the save as
|
||||
# successful for symmetry with the WAL path.
|
||||
finalized_complete = True
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
if conversation_id and compression_meta and not compression_saved:
|
||||
@@ -901,63 +747,6 @@ class BaseAnswerResource:
|
||||
logger.error(
|
||||
f"Error saving partial response: {str(e)}", exc_info=True
|
||||
)
|
||||
# Journal a terminal event so reconnecting clients stop tailing;
|
||||
# ``end`` only when the row is ``complete``, else ``error``.
|
||||
if reserved_message_id is not None:
|
||||
try:
|
||||
sequence_no += 1
|
||||
if finalized_complete:
|
||||
# Match the wire shape ``_emit({"type": "end"})``
|
||||
# uses on the normal path — the replay terminal
|
||||
# check at ``event_replay._payload_is_terminal``
|
||||
# reads ``payload.type``, and the frontend parses
|
||||
# the same key off ``data:``.
|
||||
record_event(
|
||||
reserved_message_id,
|
||||
sequence_no,
|
||||
"end",
|
||||
{"type": "end"},
|
||||
)
|
||||
else:
|
||||
# Nothing was persisted under the complete status
|
||||
# — mark the row failed so the reconciler doesn't
|
||||
# need to sweep it, and journal an ``error`` so a
|
||||
# reconnecting client surfaces the same failure
|
||||
# the UI would show on a live error.
|
||||
try:
|
||||
self.conversation_service.finalize_message(
|
||||
reserved_message_id,
|
||||
response_full or TERMINATED_RESPONSE_PLACEHOLDER,
|
||||
thought=thought,
|
||||
sources=source_log_docs,
|
||||
tool_calls=tool_calls,
|
||||
model_id=model_id or self.default_model_id,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
status="failed",
|
||||
error=ConnectionError(
|
||||
"client disconnected before response was persisted"
|
||||
),
|
||||
)
|
||||
except Exception as fin_err:
|
||||
logger.error(
|
||||
f"Failed to mark aborted message failed: {fin_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
record_event(
|
||||
reserved_message_id,
|
||||
sequence_no,
|
||||
"error",
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Stream aborted before any response was produced.",
|
||||
"code": "client_disconnect",
|
||||
},
|
||||
)
|
||||
except Exception as journal_err:
|
||||
logger.error(
|
||||
f"Failed to journal terminal event on abort: {journal_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
||||
@@ -979,16 +768,13 @@ class BaseAnswerResource:
|
||||
f"Failed to finalize errored message: {fin_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
yield _emit(
|
||||
data = json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Please try again later. We apologize for any inconvenience.",
|
||||
}
|
||||
)
|
||||
# Drain the terminal ``error`` event we just yielded so a
|
||||
# reconnecting client sees it on snapshot.
|
||||
if journal_writer is not None:
|
||||
journal_writer.close()
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
def process_response_stream(self, stream) -> Dict[str, Any]:
|
||||
@@ -1010,22 +796,8 @@ class BaseAnswerResource:
|
||||
|
||||
for line in stream:
|
||||
try:
|
||||
# Each chunk may carry an ``id: <seq>`` header before
|
||||
# the ``data:`` line. Pull just the ``data:`` body so
|
||||
# the JSON decode doesn't choke on the SSE framing.
|
||||
event_data = ""
|
||||
for raw in line.split("\n"):
|
||||
if raw.startswith("data:"):
|
||||
event_data = raw[len("data:") :].lstrip()
|
||||
break
|
||||
if not event_data:
|
||||
continue
|
||||
event_data = line.replace("data: ", "").strip()
|
||||
event = json.loads(event_data)
|
||||
# The ``message_id`` event is informational for the
|
||||
# streaming consumer and has no synchronous-API field;
|
||||
# skip it so the type-switch below doesn't KeyError.
|
||||
if event.get("type") == "message_id":
|
||||
continue
|
||||
|
||||
if event["type"] == "id":
|
||||
conversation_id = event["id"]
|
||||
|
||||
@@ -1,135 +0,0 @@
|
||||
"""GET /api/messages/<message_id>/events — chat-stream reconnect endpoint.
|
||||
|
||||
Authenticates the caller, verifies ``message_id`` belongs to the user,
|
||||
then hands off to ``build_message_event_stream`` for snapshot+tail.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.streaming.event_replay import (
|
||||
DEFAULT_KEEPALIVE_SECONDS,
|
||||
DEFAULT_POLL_TIMEOUT_SECONDS,
|
||||
build_message_event_stream,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
messages_bp = Blueprint("message_stream", __name__)
|
||||
|
||||
# A message_id is the canonical UUID hex format. Reject anything else
|
||||
# before the SQL layer so a malformed cookie can't surface as a 500.
|
||||
_MESSAGE_ID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
|
||||
r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
# ``sequence_no`` is a non-negative decimal integer. Anything else is
|
||||
# corrupt client state — fall through to a fresh-replay cursor and let
|
||||
# the snapshot reader catch the client up.
|
||||
_SEQUENCE_NO_RE = re.compile(r"^\d+$")
|
||||
|
||||
|
||||
def _normalise_last_event_id(raw: Optional[str]) -> Optional[int]:
|
||||
if raw is None:
|
||||
return None
|
||||
raw = raw.strip()
|
||||
if not raw or not _SEQUENCE_NO_RE.match(raw):
|
||||
return None
|
||||
return int(raw)
|
||||
|
||||
|
||||
def _user_owns_message(message_id: str, user_id: str) -> bool:
|
||||
"""Return True iff ``message_id`` belongs to ``user_id``."""
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
row = conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT 1 FROM conversation_messages
|
||||
WHERE id = CAST(:id AS uuid)
|
||||
AND user_id = :u
|
||||
LIMIT 1
|
||||
"""
|
||||
),
|
||||
{"id": message_id, "u": user_id},
|
||||
).first()
|
||||
return row is not None
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Ownership lookup failed for message_id=%s user_id=%s",
|
||||
message_id,
|
||||
user_id,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
@messages_bp.route("/api/messages/<message_id>/events", methods=["GET"])
|
||||
def stream_message_events(message_id: str) -> Response:
|
||||
decoded = getattr(request, "decoded_token", None)
|
||||
user_id = decoded.get("sub") if isinstance(decoded, dict) else None
|
||||
if not user_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}),
|
||||
401,
|
||||
)
|
||||
|
||||
if not _MESSAGE_ID_RE.match(message_id):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid message id"}),
|
||||
400,
|
||||
)
|
||||
|
||||
if not _user_owns_message(message_id, user_id):
|
||||
# Don't disclose whether the row exists — a malicious caller
|
||||
# gets the same 404 whether the id is bogus, taken by another
|
||||
# user, or simply gone.
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Not found"}),
|
||||
404,
|
||||
)
|
||||
|
||||
raw_cursor = request.headers.get("Last-Event-ID") or request.args.get(
|
||||
"last_event_id"
|
||||
)
|
||||
last_event_id = _normalise_last_event_id(raw_cursor)
|
||||
keepalive_seconds = float(
|
||||
getattr(settings, "SSE_KEEPALIVE_SECONDS", DEFAULT_KEEPALIVE_SECONDS)
|
||||
)
|
||||
|
||||
@stream_with_context
|
||||
def generate() -> Iterator[str]:
|
||||
try:
|
||||
yield from build_message_event_stream(
|
||||
message_id,
|
||||
last_event_id=last_event_id,
|
||||
keepalive_seconds=keepalive_seconds,
|
||||
poll_timeout_seconds=DEFAULT_POLL_TIMEOUT_SECONDS,
|
||||
)
|
||||
except GeneratorExit:
|
||||
return
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Reconnect stream crashed for message_id=%s user_id=%s",
|
||||
message_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
response = Response(generate(), mimetype="text/event-stream")
|
||||
response.headers["Cache-Control"] = "no-store"
|
||||
response.headers["X-Accel-Buffering"] = "no"
|
||||
response.headers["Connection"] = "keep-alive"
|
||||
logger.info(
|
||||
"message.event.connect message_id=%s user_id=%s last_event_id=%s",
|
||||
message_id,
|
||||
user_id,
|
||||
last_event_id if last_event_id is not None else "-",
|
||||
)
|
||||
return response
|
||||
@@ -15,10 +15,7 @@ from sqlalchemy import text as sql_text
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.conversations import (
|
||||
ConversationsRepository,
|
||||
MessageUpdateOutcome,
|
||||
)
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
@@ -308,17 +305,10 @@ class ConversationService:
|
||||
status: str = "complete",
|
||||
error: Optional[BaseException] = None,
|
||||
title_inputs: Optional[Dict[str, Any]] = None,
|
||||
) -> MessageUpdateOutcome:
|
||||
"""Commit the response and tool_call confirms in one transaction.
|
||||
|
||||
The outcome propagates directly from ``update_message_by_id`` so
|
||||
callers (notably the SSE abort handler) can tell a fresh
|
||||
finalize from "the row was already terminal" — the latter must
|
||||
still be treated as success when the prior state was
|
||||
``complete``.
|
||||
"""
|
||||
) -> bool:
|
||||
"""Commit the response and tool_call confirms in one transaction."""
|
||||
if not message_id:
|
||||
return MessageUpdateOutcome.INVALID
|
||||
return False
|
||||
sources = sources or []
|
||||
for source in sources:
|
||||
if "text" in source and isinstance(source["text"], str):
|
||||
@@ -346,16 +336,16 @@ class ConversationService:
|
||||
# retracting a row the reconciler already escalated.
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
outcome = repo.update_message_by_id(
|
||||
ok = repo.update_message_by_id(
|
||||
message_id, update_fields,
|
||||
only_if_non_terminal=True,
|
||||
)
|
||||
if outcome is not MessageUpdateOutcome.UPDATED:
|
||||
if not ok:
|
||||
logger.warning(
|
||||
f"finalize_message: no row updated for message_id={message_id} "
|
||||
f"(outcome={outcome.value} — possibly already terminal)"
|
||||
f"(possibly already terminal — reconciler may have escalated)"
|
||||
)
|
||||
return outcome
|
||||
return False
|
||||
repo.confirm_executed_tool_calls(message_id)
|
||||
|
||||
# Outside the txn — title-gen is a multi-second LLM round trip.
|
||||
@@ -368,7 +358,7 @@ class ConversationService:
|
||||
f"finalize_message title generation failed: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return MessageUpdateOutcome.UPDATED
|
||||
return True
|
||||
|
||||
def _maybe_generate_title(
|
||||
self,
|
||||
|
||||
@@ -1,504 +0,0 @@
|
||||
"""GET /api/events — user-scoped Server-Sent Events endpoint.
|
||||
|
||||
Subscribe-then-snapshot pattern: subscribe to ``user:{user_id}``
|
||||
pub/sub, snapshot the Redis Streams backlog past ``Last-Event-ID``
|
||||
inside the SUBSCRIBE-ack callback, flush snapshot, then tail live
|
||||
events (dedup'd by stream id). See ``docs/runbooks/sse-notifications.md``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from flask import Blueprint, Response, jsonify, make_response, request, stream_with_context
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.events.keys import (
|
||||
connection_counter_key,
|
||||
replay_budget_key,
|
||||
stream_id_compare,
|
||||
stream_key,
|
||||
topic_name,
|
||||
)
|
||||
from application.streaming.broadcast_channel import Topic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
events = Blueprint("event_stream", __name__)
|
||||
|
||||
SUBSCRIBE_POLL_INTERVAL_SECONDS = 1.0
|
||||
|
||||
# WHATWG SSE treats CRLF, CR, and LF equivalently as line terminators.
|
||||
_SSE_LINE_SPLIT = re.compile(r"\r\n|\r|\n")
|
||||
|
||||
# Redis Streams ids are ``ms`` or ``ms-seq`` where both halves are decimal.
|
||||
# Anything else is a corrupted client cookie / IndexedDB residue and must
|
||||
# not be passed to XRANGE — Redis would reject it and our truncation gate
|
||||
# would silently fail.
|
||||
_STREAM_ID_RE = re.compile(r"^\d+(-\d+)?$")
|
||||
|
||||
# Only emitted at most once per process so a misconfigured deployment
|
||||
# doesn't drown the logs.
|
||||
_local_user_warned = False
|
||||
|
||||
|
||||
def _format_sse(data: str, *, event_id: Optional[str] = None) -> str:
|
||||
"""Encode a payload as one SSE message terminated by a blank line.
|
||||
|
||||
Splits on any line-terminator variant (``\\r\\n``, ``\\r``, ``\\n``)
|
||||
so a stray CR in upstream content can't smuggle a premature line
|
||||
boundary into the wire format.
|
||||
"""
|
||||
lines: list[str] = []
|
||||
if event_id:
|
||||
lines.append(f"id: {event_id}")
|
||||
for line in _SSE_LINE_SPLIT.split(data):
|
||||
lines.append(f"data: {line}")
|
||||
return "\n".join(lines) + "\n\n"
|
||||
|
||||
|
||||
def _decode(value) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, (bytes, bytearray)):
|
||||
try:
|
||||
return value.decode("utf-8")
|
||||
except Exception:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
|
||||
def _oldest_retained_id(redis_client, user_id: str) -> Optional[str]:
|
||||
"""Return the id of the oldest entry still in the stream, or ``None``.
|
||||
|
||||
Used to detect ``Last-Event-ID`` having slid off the back of the
|
||||
MAXLEN'd window.
|
||||
"""
|
||||
try:
|
||||
info = redis_client.xinfo_stream(stream_key(user_id))
|
||||
except Exception:
|
||||
return None
|
||||
if not isinstance(info, dict):
|
||||
return None
|
||||
# redis-py 7.4 returns str-keyed dicts here; the bytes-key probe is
|
||||
# defence in depth in case ``decode_responses`` is ever flipped.
|
||||
first_entry = info.get("first-entry") or info.get(b"first-entry")
|
||||
if not first_entry:
|
||||
return None
|
||||
# XINFO STREAM returns first-entry as [id, [field, value, ...]]
|
||||
try:
|
||||
return _decode(first_entry[0])
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _allow_replay(
|
||||
redis_client, user_id: str, last_event_id: Optional[str]
|
||||
) -> bool:
|
||||
"""Per-user sliding-window snapshot-replay budget.
|
||||
|
||||
Fails open on Redis errors or when the budget is disabled. Empty-backlog
|
||||
no-cursor connects skip INCR so dev double-mounts don't trip 429.
|
||||
"""
|
||||
budget = int(settings.EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW)
|
||||
if budget <= 0:
|
||||
return True
|
||||
if redis_client is None:
|
||||
return True
|
||||
|
||||
# Cheap pre-check: only INCR when we might actually replay. XLEN
|
||||
# is one Redis op; the alternative (INCR every connect) is two
|
||||
# ops AND wrongly counts no-op probes. The check is conservative:
|
||||
# if ``last_event_id`` is set we always INCR, even if the cursor
|
||||
# has already overtaken the latest entry — that case is rare and
|
||||
# short-lived, and probing further would mean a redundant XRANGE.
|
||||
if last_event_id is None:
|
||||
try:
|
||||
if int(redis_client.xlen(stream_key(user_id))) == 0:
|
||||
return True
|
||||
except Exception:
|
||||
# XLEN probe failed; fall through to the INCR path so a
|
||||
# transient Redis hiccup can't bypass the budget.
|
||||
logger.debug(
|
||||
"XLEN probe failed for replay budget check user=%s; "
|
||||
"proceeding to INCR",
|
||||
user_id,
|
||||
)
|
||||
|
||||
window = max(1, int(settings.EVENTS_REPLAY_BUDGET_WINDOW_SECONDS))
|
||||
key = replay_budget_key(user_id)
|
||||
try:
|
||||
used = int(redis_client.incr(key))
|
||||
# Always (re)seed the TTL. Gating on ``used == 1`` would wedge
|
||||
# the counter forever if INCR succeeds but EXPIRE raises on
|
||||
# the seeding call. EXPIRE on an existing key resets the TTL
|
||||
# to ``window`` — within ±1s of the per-window budget semantic.
|
||||
redis_client.expire(key, window)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"replay budget probe failed for user=%s; failing open",
|
||||
user_id,
|
||||
)
|
||||
return True
|
||||
return used <= budget
|
||||
|
||||
|
||||
def _normalize_last_event_id(raw: Optional[str]) -> Optional[str]:
|
||||
"""Validate the ``Last-Event-ID`` header / query param.
|
||||
|
||||
Returns the value unchanged when it parses as a Redis Streams id,
|
||||
otherwise ``None`` — callers treat ``None`` as "client has nothing"
|
||||
and replay from the start of the retained window. Invalid ids would
|
||||
otherwise pass straight to XRANGE and surface as a quiet replay
|
||||
failure plus broken truncation detection.
|
||||
"""
|
||||
if raw is None:
|
||||
return None
|
||||
raw = raw.strip()
|
||||
if not raw or not _STREAM_ID_RE.match(raw):
|
||||
return None
|
||||
return raw
|
||||
|
||||
|
||||
def _replay_backlog(
|
||||
redis_client, user_id: str, last_event_id: Optional[str], max_count: int
|
||||
) -> Iterator[tuple[str, str]]:
|
||||
"""Yield ``(entry_id, sse_line)`` for backlog entries past ``last_event_id``.
|
||||
|
||||
Capped at ``max_count`` rows; clients catch up across reconnects.
|
||||
Parse failures are skipped; the Streams id is injected into the
|
||||
envelope so replay matches live-tail shape.
|
||||
"""
|
||||
# Exclusive start: '(<id>' skips the already-delivered entry.
|
||||
start = f"({last_event_id}" if last_event_id else "-"
|
||||
try:
|
||||
entries = redis_client.xrange(
|
||||
stream_key(user_id), min=start, max="+", count=max_count
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"xrange replay failed for user=%s last_id=%s err=%s",
|
||||
user_id,
|
||||
last_event_id or "-",
|
||||
exc,
|
||||
)
|
||||
return
|
||||
|
||||
for entry_id, fields in entries:
|
||||
entry_id_str = _decode(entry_id)
|
||||
if not entry_id_str:
|
||||
continue
|
||||
# decode_responses=False on the cache client ⇒ field keys/values
|
||||
# are bytes. The string-key fallback covers a future flip of that
|
||||
# default without a forced refactor here.
|
||||
raw_event = None
|
||||
if isinstance(fields, dict):
|
||||
raw_event = fields.get(b"event")
|
||||
if raw_event is None:
|
||||
raw_event = fields.get("event")
|
||||
event_str = _decode(raw_event)
|
||||
if not event_str:
|
||||
continue
|
||||
try:
|
||||
envelope = json.loads(event_str)
|
||||
if isinstance(envelope, dict):
|
||||
envelope["id"] = entry_id_str
|
||||
event_str = json.dumps(envelope)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Replay envelope parse failed for entry %s; passing through raw",
|
||||
entry_id_str,
|
||||
)
|
||||
yield entry_id_str, _format_sse(event_str, event_id=entry_id_str)
|
||||
|
||||
|
||||
def _truncation_notice_line(oldest_id: str) -> str:
|
||||
"""SSE event the frontend can react to with a full-state refetch."""
|
||||
return _format_sse(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "backlog.truncated",
|
||||
"payload": {"oldest_retained_id": oldest_id},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@events.route("/api/events", methods=["GET"])
|
||||
def stream_events() -> Response:
|
||||
decoded = getattr(request, "decoded_token", None)
|
||||
user_id = decoded.get("sub") if isinstance(decoded, dict) else None
|
||||
if not user_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Authentication required"}),
|
||||
401,
|
||||
)
|
||||
|
||||
# In dev deployments without AUTH_TYPE configured, every request
|
||||
# resolves to user_id="local" and shares one stream. Surface this so
|
||||
# an accidentally-multi-user dev box doesn't silently cross-stream.
|
||||
global _local_user_warned
|
||||
if user_id == "local" and not _local_user_warned:
|
||||
logger.warning(
|
||||
"SSE serving user_id='local' (AUTH_TYPE not set). "
|
||||
"All clients on this deployment will share one event stream."
|
||||
)
|
||||
_local_user_warned = True
|
||||
|
||||
raw_last_event_id = request.headers.get("Last-Event-ID") or request.args.get(
|
||||
"last_event_id"
|
||||
)
|
||||
last_event_id = _normalize_last_event_id(raw_last_event_id)
|
||||
last_event_id_invalid = raw_last_event_id is not None and last_event_id is None
|
||||
|
||||
keepalive_seconds = float(settings.SSE_KEEPALIVE_SECONDS)
|
||||
push_enabled = settings.ENABLE_SSE_PUSH
|
||||
cap = int(settings.SSE_MAX_CONCURRENT_PER_USER)
|
||||
|
||||
redis_client = get_redis_instance()
|
||||
counter_key = connection_counter_key(user_id)
|
||||
counted = False
|
||||
|
||||
if push_enabled and redis_client is not None and cap > 0:
|
||||
try:
|
||||
current = int(redis_client.incr(counter_key))
|
||||
counted = True
|
||||
except Exception:
|
||||
current = 0
|
||||
logger.debug(
|
||||
"SSE connection counter INCR failed for user=%s", user_id
|
||||
)
|
||||
if counted:
|
||||
# 1h safety TTL — orphaned counts from hard crashes self-heal.
|
||||
# EXPIRE failure must NOT clobber ``current`` and bypass the cap.
|
||||
try:
|
||||
redis_client.expire(counter_key, 3600)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter EXPIRE failed for user=%s", user_id
|
||||
)
|
||||
if current > cap:
|
||||
try:
|
||||
redis_client.decr(counter_key)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter DECR failed for user=%s",
|
||||
user_id,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Too many concurrent SSE connections",
|
||||
}
|
||||
),
|
||||
429,
|
||||
)
|
||||
|
||||
# Replay budget is checked here, before the generator opens the
|
||||
# stream, so a denial can surface as HTTP 429 instead of a silent
|
||||
# snapshot skip. The earlier in-generator skip lost events between
|
||||
# the client's cursor and the first live-tailed entry: the live
|
||||
# tail still carried ``id:`` headers, the frontend advanced
|
||||
# ``lastEventId`` to one of those ids, and the events in between
|
||||
# were never reachable on the next reconnect. 429 keeps the
|
||||
# cursor pinned and lets the frontend back off until the window
|
||||
# slides (eventStreamClient.ts treats 429 as escalated backoff).
|
||||
if push_enabled and redis_client is not None and not _allow_replay(
|
||||
redis_client, user_id, last_event_id
|
||||
):
|
||||
if counted:
|
||||
try:
|
||||
redis_client.decr(counter_key)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter DECR failed for user=%s",
|
||||
user_id,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Replay budget exhausted",
|
||||
}
|
||||
),
|
||||
429,
|
||||
)
|
||||
|
||||
@stream_with_context
|
||||
def generate() -> Iterator[str]:
|
||||
connect_ts = time.monotonic()
|
||||
replayed_count = 0
|
||||
try:
|
||||
# First frame primes intermediaries (Cloudflare, nginx) so they
|
||||
# don't sit on a buffer waiting for body bytes.
|
||||
yield ": connected\n\n"
|
||||
|
||||
if not push_enabled:
|
||||
yield ": push_disabled\n\n"
|
||||
return
|
||||
|
||||
replay_lines: list[str] = []
|
||||
max_replayed_id: Optional[str] = None
|
||||
replay_done = False
|
||||
|
||||
# If the client sent a malformed Last-Event-ID, surface the
|
||||
# truncation notice synchronously *before* the subscribe
|
||||
# loop. Buffering it into ``replay_lines`` would lose it
|
||||
# when ``Topic.subscribe`` returns immediately (Redis down)
|
||||
# — the loop body never runs, and the flush at line ~335
|
||||
# never fires.
|
||||
if last_event_id_invalid:
|
||||
yield _truncation_notice_line("")
|
||||
replayed_count += 1
|
||||
|
||||
def _on_subscribe_callback() -> None:
|
||||
# Runs synchronously inside Topic.subscribe after the
|
||||
# SUBSCRIBE is acked. By doing XRANGE here, any publisher
|
||||
# firing between SUBSCRIBE-send and SUBSCRIBE-ack has its
|
||||
# XADD captured by XRANGE *and* its PUBLISH buffered at
|
||||
# the connection layer until we read it — closing the
|
||||
# replay/subscribe race the design doc warns about.
|
||||
#
|
||||
# Truncation contract: ``backlog.truncated`` is emitted
|
||||
# ONLY when the client's ``Last-Event-ID`` has slid off
|
||||
# the MAXLEN'd window — that's the case where the
|
||||
# journal is genuinely gone past the cursor and the
|
||||
# frontend should clear its slice cursor and refetch
|
||||
# state. Cap-hit skips the snapshot silently: the
|
||||
# cursor advances via the per-entry ``id:`` headers
|
||||
# and the frontend's slice keeps the latest id so the
|
||||
# next reconnect resumes from there. Budget-exhausted
|
||||
# never reaches this callback — the route 429s before
|
||||
# opening the stream, keeping the cursor pinned.
|
||||
# Conflating these with stale-cursor truncation would
|
||||
# tell the client to clear its cursor and re-receive
|
||||
# the same oldest-N entries on every reconnect —
|
||||
# locking the user out of entries past N.
|
||||
nonlocal max_replayed_id, replay_done
|
||||
try:
|
||||
if redis_client is None:
|
||||
return
|
||||
oldest = _oldest_retained_id(redis_client, user_id)
|
||||
if (
|
||||
last_event_id
|
||||
and oldest
|
||||
and stream_id_compare(last_event_id, oldest) < 0
|
||||
):
|
||||
# The Last-Event-ID has slid off the MAXLEN window.
|
||||
# Tell the client so it can fetch full state.
|
||||
replay_lines.append(_truncation_notice_line(oldest))
|
||||
replay_cap = int(settings.EVENTS_REPLAY_MAX_PER_REQUEST)
|
||||
for entry_id, sse_line in _replay_backlog(
|
||||
redis_client, user_id, last_event_id, replay_cap
|
||||
):
|
||||
replay_lines.append(sse_line)
|
||||
max_replayed_id = entry_id
|
||||
finally:
|
||||
# Always flip the flag — even on partial-replay failure
|
||||
# the outer loop must reach the flush step so we don't
|
||||
# silently strand whatever entries did land.
|
||||
replay_done = True
|
||||
|
||||
topic = Topic(topic_name(user_id))
|
||||
last_keepalive = time.monotonic()
|
||||
for payload in topic.subscribe(
|
||||
on_subscribe=_on_subscribe_callback,
|
||||
poll_timeout=SUBSCRIBE_POLL_INTERVAL_SECONDS,
|
||||
):
|
||||
# Flush snapshot on the first iteration after the SUBSCRIBE
|
||||
# callback ran. This runs at most once per connection.
|
||||
if replay_done and replay_lines:
|
||||
for line in replay_lines:
|
||||
yield line
|
||||
replayed_count += 1
|
||||
replay_lines.clear()
|
||||
|
||||
now = time.monotonic()
|
||||
if payload is None:
|
||||
if now - last_keepalive >= keepalive_seconds:
|
||||
yield ": keepalive\n\n"
|
||||
last_keepalive = now
|
||||
continue
|
||||
|
||||
event_str = _decode(payload) or ""
|
||||
event_id: Optional[str] = None
|
||||
try:
|
||||
envelope = json.loads(event_str)
|
||||
if isinstance(envelope, dict):
|
||||
candidate = envelope.get("id")
|
||||
# Only trust ids that look like real Redis Streams
|
||||
# ids (``ms`` or ``ms-seq``). A malformed or
|
||||
# adversarial publisher could otherwise pin
|
||||
# dedupe forever — a lex-greater bogus id would
|
||||
# make every legitimate later id compare ``<=``
|
||||
# and get dropped silently.
|
||||
if isinstance(candidate, str) and _STREAM_ID_RE.match(
|
||||
candidate
|
||||
):
|
||||
event_id = candidate
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Dedupe: if this id was already covered by replay, drop it.
|
||||
if (
|
||||
event_id is not None
|
||||
and max_replayed_id is not None
|
||||
and stream_id_compare(event_id, max_replayed_id) <= 0
|
||||
):
|
||||
continue
|
||||
|
||||
yield _format_sse(event_str, event_id=event_id)
|
||||
last_keepalive = now
|
||||
|
||||
# Topic.subscribe exited before the first yield (transient
|
||||
# Redis hiccup between SUBSCRIBE-ack and the first poll, or
|
||||
# an immediate Redis-down return). The callback may already
|
||||
# have populated the snapshot — flush it so the client gets
|
||||
# the backlog instead of a silent drop. Safe no-op when the
|
||||
# in-loop flush ran (it clear()'d the buffer) and when the
|
||||
# callback never fired (replay_done stays False).
|
||||
if replay_done and replay_lines:
|
||||
for line in replay_lines:
|
||||
yield line
|
||||
replayed_count += 1
|
||||
replay_lines.clear()
|
||||
except GeneratorExit:
|
||||
return
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"SSE event-stream generator crashed for user=%s", user_id
|
||||
)
|
||||
finally:
|
||||
duration_s = time.monotonic() - connect_ts
|
||||
logger.info(
|
||||
"event.disconnect user=%s duration_s=%.1f replayed=%d",
|
||||
user_id,
|
||||
duration_s,
|
||||
replayed_count,
|
||||
)
|
||||
if counted and redis_client is not None:
|
||||
try:
|
||||
redis_client.decr(counter_key)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"SSE connection counter DECR failed for user=%s on disconnect",
|
||||
user_id,
|
||||
)
|
||||
|
||||
response = Response(generate(), mimetype="text/event-stream")
|
||||
response.headers["Cache-Control"] = "no-store"
|
||||
response.headers["X-Accel-Buffering"] = "no"
|
||||
response.headers["Connection"] = "keep-alive"
|
||||
logger.info(
|
||||
"event.connect user=%s last_event_id=%s%s",
|
||||
user_id,
|
||||
last_event_id or "-",
|
||||
" (rejected_invalid)" if last_event_id_invalid else "",
|
||||
)
|
||||
return response
|
||||
@@ -214,10 +214,6 @@ class StoreAttachment(Resource):
|
||||
{
|
||||
"success": True,
|
||||
"task_id": tasks[0]["task_id"],
|
||||
# Surface the attachment_id so the frontend
|
||||
# can correlate ``attachment.*`` SSE events
|
||||
# to this row and skip the polling fallback.
|
||||
"attachment_id": tasks[0]["attachment_id"],
|
||||
"message": "File uploaded successfully. Processing started.",
|
||||
}
|
||||
),
|
||||
|
||||
@@ -7,13 +7,9 @@ from flask_restx import fields, Namespace, Resource
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.api.answer.services.conversation_service import (
|
||||
TERMINATED_RESPONSE_PLACEHOLDER,
|
||||
)
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.repositories.attachments import AttachmentsRepository
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.message_events import MessageEventsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import check_required_fields
|
||||
|
||||
@@ -110,6 +106,85 @@ class GetConversations(Resource):
|
||||
return make_response(jsonify(list_conversations), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/search_conversations")
|
||||
class SearchConversations(Resource):
|
||||
@staticmethod
|
||||
def _build_match_snippet(text_value: str, query: str, radius: int = 60) -> str:
|
||||
if not text_value:
|
||||
return ""
|
||||
idx = text_value.lower().find(query.lower())
|
||||
if idx == -1:
|
||||
snippet = text_value[: radius * 2]
|
||||
return snippet + ("…" if len(text_value) > len(snippet) else "")
|
||||
start = max(0, idx - radius)
|
||||
end = min(len(text_value), idx + len(query) + radius)
|
||||
snippet = text_value[start:end]
|
||||
if start > 0:
|
||||
snippet = "…" + snippet
|
||||
if end < len(text_value):
|
||||
snippet = snippet + "…"
|
||||
return snippet
|
||||
|
||||
@api.doc(
|
||||
description=(
|
||||
"Search the authenticated user's conversations by name or "
|
||||
"message content (case-insensitive substring match). Mirrors "
|
||||
"the visibility filter and response shape of /get_conversations, "
|
||||
"and additionally returns ``match_field`` (``name``, ``prompt`` "
|
||||
"or ``response``) and ``match_snippet`` (a short excerpt of the "
|
||||
"matched text centered on the query) for each result."
|
||||
),
|
||||
params={
|
||||
"q": "Search term (required)",
|
||||
"limit": "Maximum number of results to return (default 30, max 100)",
|
||||
},
|
||||
)
|
||||
def get(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
query = (request.args.get("q") or "").strip()
|
||||
if not query:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "q is required"}), 400
|
||||
)
|
||||
try:
|
||||
limit = int(request.args.get("limit", 30))
|
||||
except (TypeError, ValueError):
|
||||
limit = 30
|
||||
limit = max(1, min(limit, 100))
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
conversations = ConversationsRepository(conn).search_for_user(
|
||||
user_id, query, limit=limit
|
||||
)
|
||||
list_conversations = [
|
||||
{
|
||||
"id": str(conversation["id"]),
|
||||
"name": conversation["name"],
|
||||
"agent_id": (
|
||||
str(conversation["agent_id"])
|
||||
if conversation.get("agent_id")
|
||||
else None
|
||||
),
|
||||
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||
"shared_token": conversation.get("shared_token", None),
|
||||
"match_field": conversation.get("match_field"),
|
||||
"match_snippet": self._build_match_snippet(
|
||||
conversation.get("match_text") or "", query
|
||||
),
|
||||
}
|
||||
for conversation in conversations
|
||||
]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error searching conversations: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify(list_conversations), 200)
|
||||
|
||||
|
||||
@conversations_ns.route("/get_single_conversation")
|
||||
class GetSingleConversation(Resource):
|
||||
@api.doc(
|
||||
@@ -350,25 +425,6 @@ class GetMessageTail(Resource):
|
||||
if row is None:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
msg = row_to_dict(row)
|
||||
# Mid-stream the row's response is the placeholder; rebuild
|
||||
# the live partial from the journal so /tail mirrors SSE.
|
||||
status = msg.get("status")
|
||||
response = msg.get("response")
|
||||
thought = msg.get("thought")
|
||||
sources = msg.get("sources") or []
|
||||
tool_calls = msg.get("tool_calls") or []
|
||||
if status in ("pending", "streaming") and (
|
||||
response == TERMINATED_RESPONSE_PLACEHOLDER
|
||||
):
|
||||
partial = MessageEventsRepository(conn).reconstruct_partial(
|
||||
message_id
|
||||
)
|
||||
response = partial["response"]
|
||||
thought = partial["thought"] or thought
|
||||
if partial["sources"]:
|
||||
sources = partial["sources"]
|
||||
if partial["tool_calls"]:
|
||||
tool_calls = partial["tool_calls"]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error tailing message {message_id}: {err}", exc_info=True
|
||||
@@ -379,11 +435,11 @@ class GetMessageTail(Resource):
|
||||
jsonify(
|
||||
{
|
||||
"message_id": str(msg["id"]),
|
||||
"status": status,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"status": msg.get("status"),
|
||||
"response": msg.get("response"),
|
||||
"thought": msg.get("thought"),
|
||||
"sources": msg.get("sources") or [],
|
||||
"tool_calls": msg.get("tool_calls") or [],
|
||||
"request_id": msg.get("request_id"),
|
||||
"last_heartbeat_at": metadata.get("last_heartbeat_at"),
|
||||
"error": metadata.get("error"),
|
||||
|
||||
@@ -114,11 +114,11 @@ def run_reconciliation() -> Dict[str, Any]:
|
||||
},
|
||||
)
|
||||
|
||||
# Q4: ingest checkpoints whose heartbeat has gone silent. Each is
|
||||
# escalated to terminal ``status='stalled'`` and alerted once — no
|
||||
# worker kill, no rollback of the partial embed. The 'stalled' flag
|
||||
# ends the re-alert loop and drives the "indexing failed" badge the
|
||||
# sources list derives from this row.
|
||||
# Q4: ingest checkpoints whose heartbeat has gone silent. The
|
||||
# reconciler only escalates (alerts) — it doesn't kill the worker
|
||||
# or roll back the partial embed. The next dispatch resumes from
|
||||
# ``last_index`` thanks to the per-chunk checkpoint, so this is an
|
||||
# observability sweep, not a recovery action.
|
||||
with engine.begin() as conn:
|
||||
repo = ReconciliationRepository(conn)
|
||||
for row in repo.find_and_lock_stalled_ingests():
|
||||
@@ -134,7 +134,8 @@ def run_reconciliation() -> Dict[str, Any]:
|
||||
"last_updated": str(row.get("last_updated")),
|
||||
},
|
||||
)
|
||||
repo.mark_ingest_stalled(str(row["source_id"]))
|
||||
# Bump the heartbeat so we don't re-alert every tick.
|
||||
repo.touch_ingest_progress(str(row["source_id"]))
|
||||
|
||||
# Q5: idempotency rows whose lease expired with attempts exhausted.
|
||||
# The wrapper's poison-loop guard normally finalises these, but if
|
||||
|
||||
@@ -7,12 +7,8 @@ from flask import current_app, jsonify, make_response, redirect, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.tasks import reingest_source_task, sync_source
|
||||
from application.api.user.tasks import sync_source
|
||||
from application.core.settings import settings
|
||||
from application.parser.remote.remote_creator import normalize_remote_data
|
||||
from application.storage.db.repositories.ingest_chunk_progress import (
|
||||
IngestChunkProgressRepository,
|
||||
)
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
@@ -143,8 +139,6 @@ class PaginatedSources(Resource):
|
||||
"provider": provider,
|
||||
"isNested": bool(doc.get("directory_structure")),
|
||||
"type": doc.get("type", "file"),
|
||||
# Derived in SourcesRepository.list_for_user.
|
||||
"ingestStatus": doc.get("ingest_status"),
|
||||
}
|
||||
)
|
||||
response = {
|
||||
@@ -328,7 +322,7 @@ class SyncSource(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
source_data = normalize_remote_data(source_type, doc.get("remote_data"))
|
||||
source_data = doc.get("remote_data")
|
||||
if not source_data:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Source is not syncable"}), 400
|
||||
@@ -352,70 +346,6 @@ class SyncSource(Resource):
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/sources/reingest")
|
||||
class ReingestSource(Resource):
|
||||
reingest_source_model = api.model(
|
||||
"ReingestSourceModel",
|
||||
{"source_id": fields.String(required=True, description="Source ID")},
|
||||
)
|
||||
|
||||
@api.expect(reingest_source_model)
|
||||
@api.doc(
|
||||
description="Re-run ingestion for a source — e.g. to recover a "
|
||||
"stalled embed flagged by the reconciler."
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json() or {}
|
||||
missing_fields = check_required_fields(data, ["source_id"])
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
source_id = data["source_id"]
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
doc = SourcesRepository(conn).get_any(source_id, user)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error looking up source: {err}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid source ID"}), 400
|
||||
)
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Source not found"}), 404
|
||||
)
|
||||
resolved_source_id = str(doc["id"])
|
||||
# Drop the stale chunk-progress row so the sources list stops
|
||||
# deriving a 'failed' status; reingest never rewrites it itself.
|
||||
try:
|
||||
with db_session() as conn:
|
||||
IngestChunkProgressRepository(conn).delete(resolved_source_id)
|
||||
except Exception as err:
|
||||
current_app.logger.warning(
|
||||
f"Could not clear ingest progress for {resolved_source_id}: "
|
||||
f"{err}",
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
# Scoped key so repeated clicks collapse onto one reingest.
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id,
|
||||
user=user,
|
||||
idempotency_key=f"reingest-source:{user}:{resolved_source_id}",
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error starting reingest for source {source_id}: {err}",
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/directory_structure")
|
||||
class DirectoryStructure(Resource):
|
||||
@api.doc(
|
||||
|
||||
@@ -13,7 +13,6 @@ from sqlalchemy import text as sql_text
|
||||
from application.api import api
|
||||
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.source_ids import derive_source_id as _derive_source_id
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
|
||||
from application.storage.db.repositories.idempotency import IdempotencyRepository
|
||||
@@ -70,13 +69,7 @@ def _claim_task_or_get_cached(key, task_name):
|
||||
|
||||
Pre-generates the celery task_id so a losing writer sees the same
|
||||
id immediately. Returns ``(task_id, cached_response)``; non-None
|
||||
cached means the caller should return without enqueuing. The
|
||||
cached payload mirrors the fresh-request response shape (including
|
||||
``source_id``) so the frontend can correlate SSE ingest events to
|
||||
the cached upload task without an extra round-trip — but only when
|
||||
the cached row actually exists; the "deduplicated" sentinel
|
||||
deliberately omits ``source_id`` so the frontend doesn't bind to a
|
||||
phantom source.
|
||||
cached means the caller should return without enqueuing.
|
||||
"""
|
||||
predetermined_id = str(uuid.uuid4())
|
||||
with db_session() as conn:
|
||||
@@ -88,16 +81,10 @@ def _claim_task_or_get_cached(key, task_name):
|
||||
with db_readonly() as conn:
|
||||
existing = IdempotencyRepository(conn).get_task(key)
|
||||
cached_id = existing.get("task_id") if existing else None
|
||||
payload: dict = {
|
||||
return None, {
|
||||
"success": True,
|
||||
"task_id": cached_id or "deduplicated",
|
||||
}
|
||||
# Only surface ``source_id`` when there's a real winner whose worker
|
||||
# is publishing SSE events tagged with that id. The "deduplicated"
|
||||
# branch means the lock row vanished — we have nothing to correlate.
|
||||
if cached_id is not None:
|
||||
payload["source_id"] = str(_derive_source_id(key))
|
||||
return None, payload
|
||||
|
||||
|
||||
def _release_claim(key):
|
||||
@@ -249,15 +236,6 @@ class UploadFile(Resource):
|
||||
file_path = f"{base_path}/{safe_file}"
|
||||
with open(temp_file_path, "rb") as f:
|
||||
storage.save_file(f, file_path)
|
||||
# Mint the source UUID up here so the HTTP response and the
|
||||
# worker's SSE envelopes share one id. With an idempotency
|
||||
# key we reuse the deterministic uuid5 (retried task lands on
|
||||
# the same source row); without a key we fall back to uuid4.
|
||||
# The worker is told to use this id verbatim — see
|
||||
# ``ingest_worker(source_id=...)``.
|
||||
source_uuid = (
|
||||
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
|
||||
)
|
||||
ingest_kwargs = dict(
|
||||
args=(
|
||||
settings.UPLOAD_FOLDER,
|
||||
@@ -271,7 +249,6 @@ class UploadFile(Resource):
|
||||
"file_name_map": file_name_map,
|
||||
# Scoped so the worker dedup row matches the HTTP claim.
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
"source_id": str(source_uuid),
|
||||
},
|
||||
)
|
||||
if predetermined_task_id is not None:
|
||||
@@ -296,15 +273,7 @@ class UploadFile(Resource):
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
# Predetermined id matches the dedup-claim row; loser GET sees same.
|
||||
response_task_id = predetermined_task_id or task.id
|
||||
# ``source_uuid`` was minted above and passed to the worker as
|
||||
# ``source_id``; the worker uses it verbatim for every SSE event,
|
||||
# so the frontend can correlate inbound ``source.ingest.*`` to
|
||||
# this upload regardless of whether an idempotency key was set.
|
||||
response_payload: dict = {
|
||||
"success": True,
|
||||
"task_id": response_task_id,
|
||||
"source_id": str(source_uuid),
|
||||
}
|
||||
response_payload = {"success": True, "task_id": response_task_id}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
|
||||
|
||||
@@ -357,18 +326,6 @@ class UploadRemote(Resource):
|
||||
)
|
||||
if cached is not None:
|
||||
return make_response(jsonify(cached), 200)
|
||||
# Mint the source UUID up here so the HTTP response and the
|
||||
# worker's SSE envelopes share one id. Same pattern as
|
||||
# ``UploadFile.post``: with an idempotency key we reuse the
|
||||
# deterministic uuid5 (retried task lands on the same source
|
||||
# row); without a key we fall back to uuid4. The worker is told
|
||||
# to use this id verbatim — see ``remote_worker`` and
|
||||
# ``ingest_connector``. Without this the no-key path would mint
|
||||
# a random uuid4 inside the worker that the frontend has no way
|
||||
# to correlate SSE events to.
|
||||
source_uuid = (
|
||||
_derive_source_id(scoped_key) if scoped_key else uuid.uuid4()
|
||||
)
|
||||
try:
|
||||
config = json.loads(data["data"])
|
||||
source_data = None
|
||||
@@ -425,23 +382,13 @@ class UploadRemote(Resource):
|
||||
"recursive": config.get("recursive", False),
|
||||
"retriever": config.get("retriever", "classic"),
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
"source_id": str(source_uuid),
|
||||
},
|
||||
}
|
||||
if predetermined_task_id is not None:
|
||||
connector_kwargs["task_id"] = predetermined_task_id
|
||||
task = ingest_connector_task.apply_async(**connector_kwargs)
|
||||
response_task_id = predetermined_task_id or task.id
|
||||
# ``source_uuid`` was minted above and passed to the
|
||||
# worker as ``source_id``; the worker uses it verbatim
|
||||
# for every SSE event, so the frontend can correlate
|
||||
# inbound ``source.ingest.*`` regardless of whether an
|
||||
# idempotency key was set.
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"task_id": response_task_id,
|
||||
"source_id": str(source_uuid),
|
||||
}
|
||||
response_payload = {"success": True, "task_id": response_task_id}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
remote_kwargs = {
|
||||
"kwargs": {
|
||||
@@ -450,7 +397,6 @@ class UploadRemote(Resource):
|
||||
"user": user,
|
||||
"loader": data["source"],
|
||||
"idempotency_key": scoped_key or idempotency_key,
|
||||
"source_id": str(source_uuid),
|
||||
},
|
||||
}
|
||||
if predetermined_task_id is not None:
|
||||
@@ -464,11 +410,7 @@ class UploadRemote(Resource):
|
||||
_release_claim(scoped_key)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
response_task_id = predetermined_task_id or task.id
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"task_id": response_task_id,
|
||||
"source_id": str(source_uuid),
|
||||
}
|
||||
response_payload = {"success": True, "task_id": response_task_id}
|
||||
return make_response(jsonify(response_payload), 200)
|
||||
|
||||
|
||||
@@ -611,19 +553,6 @@ class ManageSourceFiles(Resource):
|
||||
scoped_key, "reingest_source_task",
|
||||
)
|
||||
if cached is not None:
|
||||
# Frontend keys reingest polling on
|
||||
# ``reingest_task_id``; the shared cache helper
|
||||
# writes ``task_id``. Alias here so a dedup
|
||||
# response doesn't silently break FileTree's
|
||||
# poller. Override ``source_id`` too — the
|
||||
# helper derives it from the scoped key, which
|
||||
# is correct for upload but wrong for reingest
|
||||
# (the worker publishes events scoped to the
|
||||
# actual source row id).
|
||||
cached_task_id = cached.pop("task_id", None)
|
||||
if cached_task_id is not None:
|
||||
cached["reingest_task_id"] = cached_task_id
|
||||
cached["source_id"] = resolved_source_id
|
||||
return make_response(jsonify(cached), 200)
|
||||
|
||||
added_files = []
|
||||
@@ -679,12 +608,6 @@ class ManageSourceFiles(Resource):
|
||||
"added_files": added_files,
|
||||
"parent_dir": parent_dir,
|
||||
"reingest_task_id": task.id,
|
||||
# ``source_id`` lets the frontend correlate
|
||||
# inbound ``source.ingest.*`` SSE events
|
||||
# (emitted by ``reingest_source_worker``)
|
||||
# back to the reingest task — matches the
|
||||
# upload route's source-id contract.
|
||||
"source_id": resolved_source_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
@@ -736,15 +659,6 @@ class ManageSourceFiles(Resource):
|
||||
scoped_key, "reingest_source_task",
|
||||
)
|
||||
if cached is not None:
|
||||
cached_task_id = cached.pop("task_id", None)
|
||||
if cached_task_id is not None:
|
||||
cached["reingest_task_id"] = cached_task_id
|
||||
# Override the helper's synthetic source_id (uuid5
|
||||
# of the scoped key) with the real source row id
|
||||
# — the reingest worker publishes SSE events
|
||||
# scoped to ``resolved_source_id`` and FileTree
|
||||
# correlates on it.
|
||||
cached["source_id"] = resolved_source_id
|
||||
return make_response(jsonify(cached), 200)
|
||||
|
||||
# Remove files from storage and directory structure
|
||||
@@ -790,7 +704,6 @@ class ManageSourceFiles(Resource):
|
||||
"message": f"Removed {len(removed_files)} files",
|
||||
"removed_files": removed_files,
|
||||
"reingest_task_id": task.id,
|
||||
"source_id": resolved_source_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
@@ -849,14 +762,6 @@ class ManageSourceFiles(Resource):
|
||||
scoped_key, "reingest_source_task",
|
||||
)
|
||||
if cached is not None:
|
||||
cached_task_id = cached.pop("task_id", None)
|
||||
if cached_task_id is not None:
|
||||
cached["reingest_task_id"] = cached_task_id
|
||||
# Same source_id override as the ``remove`` /
|
||||
# ``add`` cached branches — the helper's synthetic
|
||||
# id doesn't match what reingest_source_worker
|
||||
# tags its SSE events with.
|
||||
cached["source_id"] = resolved_source_id
|
||||
return make_response(jsonify(cached), 200)
|
||||
|
||||
success = storage.remove_directory(full_directory_path)
|
||||
@@ -920,7 +825,6 @@ class ManageSourceFiles(Resource):
|
||||
"message": f"Successfully removed directory: {directory_path}",
|
||||
"removed_directory": directory_path,
|
||||
"reingest_task_id": task.id,
|
||||
"source_id": resolved_source_id,
|
||||
}
|
||||
),
|
||||
200,
|
||||
|
||||
@@ -7,6 +7,7 @@ from application.worker import (
|
||||
attachment_worker,
|
||||
ingest_worker,
|
||||
mcp_oauth,
|
||||
mcp_oauth_status,
|
||||
remote_worker,
|
||||
sync,
|
||||
sync_worker,
|
||||
@@ -39,7 +40,6 @@ def ingest(
|
||||
filename,
|
||||
file_name_map=None,
|
||||
idempotency_key=None,
|
||||
source_id=None,
|
||||
):
|
||||
resp = ingest_worker(
|
||||
self,
|
||||
@@ -51,21 +51,16 @@ def ingest(
|
||||
user,
|
||||
file_name_map=file_name_map,
|
||||
idempotency_key=idempotency_key,
|
||||
source_id=source_id,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(**DURABLE_TASK)
|
||||
@with_idempotency(task_name="ingest_remote")
|
||||
def ingest_remote(
|
||||
self, source_data, job_name, user, loader,
|
||||
idempotency_key=None, source_id=None,
|
||||
):
|
||||
def ingest_remote(self, source_data, job_name, user, loader, idempotency_key=None):
|
||||
resp = remote_worker(
|
||||
self, source_data, job_name, user, loader,
|
||||
idempotency_key=idempotency_key,
|
||||
source_id=source_id,
|
||||
)
|
||||
return resp
|
||||
|
||||
@@ -143,7 +138,6 @@ def ingest_connector_task(
|
||||
doc_id=None,
|
||||
sync_frequency="never",
|
||||
idempotency_key=None,
|
||||
source_id=None,
|
||||
):
|
||||
from application.worker import ingest_connector
|
||||
|
||||
@@ -161,7 +155,6 @@ def ingest_connector_task(
|
||||
doc_id=doc_id,
|
||||
sync_frequency=sync_frequency,
|
||||
idempotency_key=idempotency_key,
|
||||
source_id=source_id,
|
||||
)
|
||||
return resp
|
||||
|
||||
@@ -204,15 +197,6 @@ def setup_periodic_tasks(sender, **kwargs):
|
||||
version_check_task.s(),
|
||||
name="version-check",
|
||||
)
|
||||
# Bound ``message_events`` growth — every streamed SSE chunk writes
|
||||
# one row, so retained chats accumulate hundreds of rows per
|
||||
# message. Reconnect-replay is only meaningful for streams the user
|
||||
# could plausibly still be waiting on, so 14 days is generous.
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=24),
|
||||
cleanup_message_events.s(),
|
||||
name="cleanup-message-events",
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@@ -221,6 +205,12 @@ def mcp_oauth_task(self, config, user):
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def mcp_oauth_status_task(self, task_id):
|
||||
resp = mcp_oauth_status(self, task_id)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def cleanup_pending_tool_state(self):
|
||||
"""Revert stale ``resuming`` rows, then delete TTL-expired rows."""
|
||||
@@ -275,32 +265,6 @@ def reconciliation_task(self):
|
||||
return run_reconciliation()
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def cleanup_message_events(self):
|
||||
"""Delete ``message_events`` rows older than the retention window.
|
||||
|
||||
Streamed answer responses write one journal row per SSE yield,
|
||||
so unbounded growth would dominate Postgres for any retained-
|
||||
conversations deployment. The reconnect-replay path only needs
|
||||
rows for in-flight streams; 14 days covers paused/tool-action
|
||||
flows comfortably.
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
if not settings.POSTGRES_URI:
|
||||
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
|
||||
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.message_events import (
|
||||
MessageEventsRepository,
|
||||
)
|
||||
|
||||
ttl_days = settings.MESSAGE_EVENTS_RETENTION_DAYS
|
||||
engine = get_engine()
|
||||
with engine.begin() as conn:
|
||||
deleted = MessageEventsRepository(conn).cleanup_older_than(ttl_days)
|
||||
return {"deleted": deleted, "ttl_days": ttl_days}
|
||||
|
||||
|
||||
@celery.task(bind=True, acks_late=False)
|
||||
def version_check_task(self):
|
||||
"""Periodic anonymous version check.
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tool management MCP server integration."""
|
||||
|
||||
import json
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
@@ -225,9 +226,7 @@ class MCPServerSave(Resource):
|
||||
)
|
||||
redis_client = get_redis_instance()
|
||||
manager = MCPOAuthManager(redis_client)
|
||||
result = manager.get_oauth_status(
|
||||
config["oauth_task_id"], user
|
||||
)
|
||||
result = manager.get_oauth_status(config["oauth_task_id"])
|
||||
if not result.get("status") == "completed":
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -439,6 +438,56 @@ class MCPOAuthCallback(Resource):
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
|
||||
class MCPOAuthStatus(Resource):
|
||||
def get(self, task_id):
|
||||
try:
|
||||
redis_client = get_redis_instance()
|
||||
status_key = f"mcp_oauth_status:{task_id}"
|
||||
status_data = redis_client.get(status_key)
|
||||
|
||||
if status_data:
|
||||
status = json.loads(status_data)
|
||||
if "tools" in status and isinstance(status["tools"], list):
|
||||
status["tools"] = [
|
||||
{
|
||||
"name": t.get("name", "unknown"),
|
||||
"description": t.get("description", ""),
|
||||
}
|
||||
for t in status["tools"]
|
||||
]
|
||||
return make_response(
|
||||
jsonify({"success": True, "task_id": task_id, **status})
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"task_id": task_id,
|
||||
"status": "pending",
|
||||
"message": "Waiting for OAuth to start...",
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error getting OAuth status for task {task_id}: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Failed to get OAuth status",
|
||||
"task_id": task_id,
|
||||
}
|
||||
),
|
||||
500,
|
||||
)
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/auth_status")
|
||||
class MCPAuthStatus(Resource):
|
||||
@api.doc(
|
||||
|
||||
@@ -222,26 +222,13 @@ def _stream_response(
|
||||
for line in internal_stream:
|
||||
if not line.strip():
|
||||
continue
|
||||
# ``complete_stream`` prefixes each frame with ``id: <seq>\n``
|
||||
# before the ``data:`` line. Extract just the data line so JSON
|
||||
# decode doesn't choke on the SSE framing.
|
||||
event_str = ""
|
||||
for raw in line.split("\n"):
|
||||
if raw.startswith("data:"):
|
||||
event_str = raw[len("data:") :].lstrip()
|
||||
break
|
||||
if not event_str:
|
||||
continue
|
||||
# Parse the internal SSE event
|
||||
event_str = line.replace("data: ", "").strip()
|
||||
try:
|
||||
event_data = json.loads(event_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
|
||||
# Skip the informational ``message_id`` event — it has no v1 /
|
||||
# OpenAI-compatible analog.
|
||||
if event_data.get("type") == "message_id":
|
||||
continue
|
||||
|
||||
# Update completion_id when we get the conversation id
|
||||
if event_data.get("type") == "id":
|
||||
conv_id = event_data.get("id", "")
|
||||
|
||||
@@ -16,8 +16,6 @@ setup_logging()
|
||||
|
||||
from application.api import api # noqa: E402
|
||||
from application.api.answer import answer # noqa: E402
|
||||
from application.api.answer.routes.messages import messages_bp # noqa: E402
|
||||
from application.api.events.routes import events # noqa: E402
|
||||
from application.api.internal.routes import internal # noqa: E402
|
||||
from application.api.user.routes import user # noqa: E402
|
||||
from application.api.connector.routes import connector # noqa: E402
|
||||
@@ -51,8 +49,6 @@ ensure_database_ready(
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(user)
|
||||
app.register_blueprint(answer)
|
||||
app.register_blueprint(events)
|
||||
app.register_blueprint(messages_bp)
|
||||
app.register_blueprint(internal)
|
||||
app.register_blueprint(connector)
|
||||
app.register_blueprint(v1_bp)
|
||||
|
||||
@@ -29,17 +29,8 @@ def get_redis_instance():
|
||||
with _instance_lock:
|
||||
if _redis_instance is None and not _redis_creation_failed:
|
||||
try:
|
||||
# ``health_check_interval`` makes redis-py ping the
|
||||
# connection every N seconds when otherwise idle.
|
||||
# Without it, a half-open TCP (NAT silently dropped
|
||||
# state, ELB idle-close) can hang the SSE generator
|
||||
# in ``pubsub.get_message`` past its keepalive
|
||||
# cadence — the kernel never surfaces the dead
|
||||
# socket because no payload is in flight.
|
||||
_redis_instance = redis.Redis.from_url(
|
||||
settings.CACHE_REDIS_URL,
|
||||
socket_connect_timeout=2,
|
||||
health_check_interval=10,
|
||||
settings.CACHE_REDIS_URL, socket_connect_timeout=2
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid Redis URL: {e}")
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import ctypes
|
||||
import gc
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from celery import Celery
|
||||
@@ -101,34 +98,6 @@ def _unbind_task_log_context(task_id, **_):
|
||||
)
|
||||
|
||||
|
||||
def _trim_native_heap() -> None:
|
||||
"""Return freed glibc heap pages to the OS (Linux only; no-op elsewhere)."""
|
||||
# docling/torch parsing makes large transient allocations; glibc keeps the
|
||||
# freed pages in per-thread malloc arenas rather than returning them, so a
|
||||
# long-lived worker child's RSS only ever climbs. malloc_trim hands them
|
||||
# back. The symbol is glibc-only — absent in macOS libc.
|
||||
if not sys.platform.startswith("linux"):
|
||||
return
|
||||
try:
|
||||
ctypes.CDLL("libc.so.6").malloc_trim(0)
|
||||
except (OSError, AttributeError):
|
||||
pass
|
||||
|
||||
|
||||
@task_postrun.connect
|
||||
def _reclaim_memory_after_task(*args, **kwargs):
|
||||
"""Drop per-task allocations so the prefork child's RSS doesn't ratchet."""
|
||||
gc.collect()
|
||||
torch = sys.modules.get("torch")
|
||||
if torch is not None:
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
_trim_native_heap()
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def _run_version_check(*args, **kwargs):
|
||||
"""Kick off the anonymous version check on worker startup.
|
||||
|
||||
@@ -31,10 +31,3 @@ worker_prefetch_multiplier = settings.CELERY_WORKER_PREFETCH_MULTIPLIER
|
||||
broker_transport_options = {"visibility_timeout": settings.CELERY_VISIBILITY_TIMEOUT}
|
||||
result_expires = 86400 * 7
|
||||
task_track_started = True
|
||||
|
||||
# Recycle the prefork worker child to bound native-heap growth from
|
||||
# docling/torch parsing. Left unset (Celery's unlimited default) when 0.
|
||||
if settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD > 0:
|
||||
worker_max_memory_per_child = settings.CELERY_WORKER_MAX_MEMORY_PER_CHILD
|
||||
if settings.CELERY_WORKER_MAX_TASKS_PER_CHILD > 0:
|
||||
worker_max_tasks_per_child = settings.CELERY_WORKER_MAX_TASKS_PER_CHILD
|
||||
|
||||
@@ -36,11 +36,6 @@ class Settings(BaseSettings):
|
||||
# and Dify defaults; long ingests can override via env.
|
||||
CELERY_WORKER_PREFETCH_MULTIPLIER: int = 1
|
||||
CELERY_VISIBILITY_TIMEOUT: int = 3600
|
||||
# Recycle the prefork worker child once its resident size crosses this many
|
||||
# kilobytes — backstops native-heap growth from docling/torch parsing. 0 disables.
|
||||
CELERY_WORKER_MAX_MEMORY_PER_CHILD: int = 4194304
|
||||
# Recycle the child after this many tasks; 0 disables (memory cap is the primary knob).
|
||||
CELERY_WORKER_MAX_TASKS_PER_CHILD: int = 0
|
||||
# Only consulted when VECTOR_STORE=mongodb or when running scripts/db/backfill.py; user data lives in Postgres.
|
||||
MONGO_URI: Optional[str] = None
|
||||
# User-data Postgres DB.
|
||||
@@ -193,42 +188,6 @@ class Settings(BaseSettings):
|
||||
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
|
||||
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
|
||||
|
||||
# Internal SSE push channel (notifications + durable replay journal)
|
||||
# Master switch — when False, /api/events emits a "push_disabled" comment
|
||||
# and returns; clients fall back to polling. Publisher becomes a no-op.
|
||||
ENABLE_SSE_PUSH: bool = True
|
||||
# Per-user durable backlog cap (~entries). At typical event rates this
|
||||
# gives ~24h of replay; tune up for verbose feeds, down for memory.
|
||||
EVENTS_STREAM_MAXLEN: int = 1000
|
||||
# SSE keepalive comment cadence. Must sit under Cloudflare's 100s idle
|
||||
# close and iOS Safari's ~60s — 15s gives generous headroom.
|
||||
SSE_KEEPALIVE_SECONDS: int = 15
|
||||
# Cap on simultaneous SSE connections per user. Each connection holds
|
||||
# one WSGI thread (32 per gunicorn worker) and one Redis pub/sub
|
||||
# connection. 8 covers normal multi-tab use without letting one user
|
||||
# starve the pool. Set to 0 to disable the cap.
|
||||
SSE_MAX_CONCURRENT_PER_USER: int = 8
|
||||
# Per-request cap on the number of backlog entries XRANGE returns
|
||||
# for ``/api/events`` snapshots. Bounds the bytes a single replay
|
||||
# can move from Redis to the wire — a malicious client looping
|
||||
# ``Last-Event-ID=<oldest>`` reconnects can only enumerate this
|
||||
# many entries per round-trip. Combined with the per-user
|
||||
# connection cap above and the windowed budget below, total
|
||||
# enumeration throughput is bounded.
|
||||
EVENTS_REPLAY_MAX_PER_REQUEST: int = 200
|
||||
# Sliding-window cap on snapshot replays per user. Once the budget
|
||||
# is exhausted the route returns HTTP 429 with the cursor pinned;
|
||||
# the client backs off and retries after the window rolls over.
|
||||
EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW: int = 30
|
||||
EVENTS_REPLAY_BUDGET_WINDOW_SECONDS: int = 60
|
||||
|
||||
# Retention for the ``message_events`` journal. The ``cleanup_message_events``
|
||||
# beat task deletes rows older than this. Reconnect-replay only
|
||||
# needs the journal for streams a client could still be tailing,
|
||||
# so 14 days is a generous default that covers paused/tool-action
|
||||
# flows without unbounded table growth.
|
||||
MESSAGE_EVENTS_RETENTION_DAYS: int = 14
|
||||
|
||||
@field_validator("POSTGRES_URI", mode="before")
|
||||
@classmethod
|
||||
def _normalize_postgres_uri_validator(cls, v):
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
"""Stream/topic key derivations shared by publisher and SSE consumer.
|
||||
|
||||
Single source of truth for the per-user Redis Streams key and pub/sub
|
||||
topic name. Both must agree exactly — a typo here splits the
|
||||
publisher's writes from the consumer's reads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def stream_key(user_id: str) -> str:
|
||||
"""Redis Streams key holding the durable backlog for ``user_id``."""
|
||||
return f"user:{user_id}:stream"
|
||||
|
||||
|
||||
def topic_name(user_id: str) -> str:
|
||||
"""Redis pub/sub channel used for live fan-out to ``user_id``."""
|
||||
return f"user:{user_id}"
|
||||
|
||||
|
||||
def connection_counter_key(user_id: str) -> str:
|
||||
"""Redis counter tracking active SSE connections for ``user_id``."""
|
||||
return f"user:{user_id}:sse_count"
|
||||
|
||||
|
||||
def replay_budget_key(user_id: str) -> str:
|
||||
"""Redis counter tracking snapshot replays for ``user_id`` in the
|
||||
rolling rate-limit window."""
|
||||
return f"user:{user_id}:replay_count"
|
||||
|
||||
|
||||
def stream_id_compare(a: str, b: str) -> int:
|
||||
"""Compare two Redis Streams ids. Returns -1, 0, 1 like ``cmp``.
|
||||
|
||||
Stream ids are ``ms-seq`` strings; comparing as strings would be wrong
|
||||
once ``ms`` straddles digit-count boundaries. We parse and compare
|
||||
as ``(int, int)`` tuples.
|
||||
|
||||
Raises ``ValueError`` on malformed input. Callers must pre-validate
|
||||
against ``_STREAM_ID_RE`` (or equivalent) — a lex fallback here let
|
||||
a malformed id compare lex-greater than a real one and silently pin
|
||||
dedup forever.
|
||||
"""
|
||||
a_ms, _, a_seq = a.partition("-")
|
||||
b_ms, _, b_seq = b.partition("-")
|
||||
a_tuple = (int(a_ms), int(a_seq) if a_seq else 0)
|
||||
b_tuple = (int(b_ms), int(b_seq) if b_seq else 0)
|
||||
if a_tuple < b_tuple:
|
||||
return -1
|
||||
if a_tuple > b_tuple:
|
||||
return 1
|
||||
return 0
|
||||
@@ -1,144 +0,0 @@
|
||||
"""User-scoped event publisher: durable backlog + live fan-out.
|
||||
|
||||
Each ``publish_user_event`` call writes twice:
|
||||
|
||||
1. ``XADD user:{user_id}:stream MAXLEN ~ <cap> * event <json>`` — the
|
||||
durable backlog used by SSE reconnect (``Last-Event-ID``) and stream
|
||||
replay. Bounded by ``EVENTS_STREAM_MAXLEN`` (~24h at typical event
|
||||
rates) so the per-user footprint stays predictable.
|
||||
2. ``PUBLISH user:{user_id} <json-with-id>`` — live fan-out to every
|
||||
currently connected SSE generator for the user, across instances.
|
||||
|
||||
Together they give a snapshot-plus-tail story: a reconnecting client
|
||||
reads ``XRANGE`` from its last seen id and then transitions onto the
|
||||
live pub/sub. The Redis Streams entry id (e.g. ``1735682400000-0``) is
|
||||
the canonical, monotonically increasing event id and is what
|
||||
``Last-Event-ID`` carries.
|
||||
|
||||
Failures are logged and swallowed: the caller is typically a Celery
|
||||
task whose primary work has already succeeded, and a notification
|
||||
delivery miss should not surface as a task failure.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.events.keys import stream_key, topic_name
|
||||
from application.streaming.broadcast_channel import Topic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _iso_now() -> str:
|
||||
"""ISO 8601 UTC with millisecond precision and Z suffix."""
|
||||
return (
|
||||
datetime.now(timezone.utc)
|
||||
.isoformat(timespec="milliseconds")
|
||||
.replace("+00:00", "Z")
|
||||
)
|
||||
|
||||
|
||||
def publish_user_event(
|
||||
user_id: str,
|
||||
event_type: str,
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
scope: Optional[dict[str, Any]] = None,
|
||||
) -> Optional[str]:
|
||||
"""Publish a user-scoped event; return the Redis Streams id or ``None``.
|
||||
|
||||
Fire-and-forget: never raises. ``None`` means the event reached
|
||||
neither the journal nor live subscribers (see runbook for causes).
|
||||
"""
|
||||
if not user_id or not event_type:
|
||||
logger.warning(
|
||||
"publish_user_event called without user_id or event_type "
|
||||
"(user_id=%r, event_type=%r)",
|
||||
user_id,
|
||||
event_type,
|
||||
)
|
||||
return None
|
||||
if not settings.ENABLE_SSE_PUSH:
|
||||
return None
|
||||
|
||||
envelope_partial: dict[str, Any] = {
|
||||
"type": event_type,
|
||||
"ts": _iso_now(),
|
||||
"user_id": user_id,
|
||||
"topic": topic_name(user_id),
|
||||
"scope": scope or {},
|
||||
"payload": payload,
|
||||
}
|
||||
|
||||
try:
|
||||
envelope_partial_json = json.dumps(envelope_partial)
|
||||
except (TypeError, ValueError) as exc:
|
||||
logger.warning(
|
||||
"publish_user_event payload not JSON-serializable: "
|
||||
"user=%s type=%s err=%s",
|
||||
user_id,
|
||||
event_type,
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
redis = get_redis_instance()
|
||||
if redis is None:
|
||||
logger.debug("Redis unavailable; skipping publish_user_event")
|
||||
return None
|
||||
|
||||
maxlen = settings.EVENTS_STREAM_MAXLEN
|
||||
stream_id: Optional[str] = None
|
||||
try:
|
||||
# Auto-id ('*') gives a monotonic ms-seq id that doubles as the
|
||||
# SSE event id. ``approximate=True`` lets Redis trim in chunks
|
||||
# for performance; the cap is treated as ~MAXLEN, never <.
|
||||
result = redis.xadd(
|
||||
stream_key(user_id),
|
||||
{"event": envelope_partial_json},
|
||||
maxlen=maxlen,
|
||||
approximate=True,
|
||||
)
|
||||
stream_id = (
|
||||
result.decode("utf-8")
|
||||
if isinstance(result, (bytes, bytearray))
|
||||
else str(result)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"xadd failed for user=%s event_type=%s", user_id, event_type
|
||||
)
|
||||
|
||||
# If the durable journal write failed there is no canonical id to
|
||||
# ship — publishing the envelope live would put an id-less record
|
||||
# on the wire that bypasses the SSE route's dedup floor and breaks
|
||||
# ``Last-Event-ID`` semantics for any reconnect. Best-effort
|
||||
# delivery means dropping consistently, not delivering inconsistent
|
||||
# state.
|
||||
if stream_id is None:
|
||||
return None
|
||||
|
||||
envelope = dict(envelope_partial)
|
||||
envelope["id"] = stream_id
|
||||
|
||||
try:
|
||||
Topic(topic_name(user_id)).publish(json.dumps(envelope))
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"publish failed for user=%s event_type=%s", user_id, event_type
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"event.published topic=%s type=%s id=%s",
|
||||
topic_name(user_id),
|
||||
event_type,
|
||||
stream_id,
|
||||
)
|
||||
|
||||
return stream_id
|
||||
@@ -4,7 +4,6 @@ from typing import Any, List, Optional
|
||||
from retry import retry
|
||||
from tqdm import tqdm
|
||||
from application.core.settings import settings
|
||||
from application.events.publisher import publish_user_event
|
||||
from application.storage.db.repositories.ingest_chunk_progress import (
|
||||
IngestChunkProgressRepository,
|
||||
)
|
||||
@@ -153,9 +152,6 @@ def embed_and_store_documents(
|
||||
task_status: Any,
|
||||
*,
|
||||
attempt_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
progress_start: int = 0,
|
||||
progress_end: int = 100,
|
||||
) -> None:
|
||||
"""Embeds documents and stores them in a vector store.
|
||||
|
||||
@@ -174,15 +170,6 @@ def embed_and_store_documents(
|
||||
attempt_id: Stable id of the current task invocation,
|
||||
typically ``self.request.id`` from the Celery task body.
|
||||
``None`` is treated as a fresh attempt every time.
|
||||
user_id: When provided, per-percent SSE progress events are
|
||||
published to ``user:{user_id}`` for the in-app upload toast.
|
||||
``None`` is the safe default — workers without a user
|
||||
context (e.g. background syncs) skip the publish.
|
||||
progress_start: Percent the reported progress maps to at chunk 0.
|
||||
Lets a caller reserve the lower band for an earlier stage
|
||||
(e.g. parsing). Defaults to ``0`` (embed owns the whole bar).
|
||||
progress_end: Percent the reported progress maps to at the final
|
||||
chunk. Defaults to ``100``.
|
||||
|
||||
Returns:
|
||||
None
|
||||
@@ -262,9 +249,6 @@ def embed_and_store_documents(
|
||||
# Process and embed documents
|
||||
chunk_error: Exception | None = None
|
||||
failed_idx: int | None = None
|
||||
last_published_pct = -1
|
||||
source_id_str = str(source_id)
|
||||
progress_span = progress_end - progress_start
|
||||
for idx in tqdm(
|
||||
range(loop_start, total_docs),
|
||||
desc="Embedding 🦖",
|
||||
@@ -274,30 +258,10 @@ def embed_and_store_documents(
|
||||
):
|
||||
doc = docs[idx]
|
||||
try:
|
||||
# Map the embed loop into [progress_start, progress_end].
|
||||
progress = progress_start + int(
|
||||
((idx + 1) / total_docs) * progress_span
|
||||
)
|
||||
# Update task status for progress tracking
|
||||
progress = int(((idx + 1) / total_docs) * 100)
|
||||
task_status.update_state(state="PROGRESS", meta={"current": progress})
|
||||
|
||||
# SSE push for sub-second upload-toast updates. Throttled to one
|
||||
# event per percent so a 10k-chunk ingest emits ~100 events,
|
||||
# not 10k. The Celery update_state above stays the source of
|
||||
# truth for the polling-fallback path.
|
||||
if user_id and progress > last_published_pct:
|
||||
publish_user_event(
|
||||
user_id,
|
||||
"source.ingest.progress",
|
||||
{
|
||||
"current": progress,
|
||||
"total": total_docs,
|
||||
"embedded_chunks": idx + 1,
|
||||
"stage": "embedding",
|
||||
},
|
||||
scope={"kind": "source", "id": source_id_str},
|
||||
)
|
||||
last_published_pct = progress
|
||||
|
||||
# Add document to vector store
|
||||
add_text_to_store_with_retry(store, doc, source_id)
|
||||
_record_progress(source_id, last_index=idx, embedded_chunks=idx + 1)
|
||||
|
||||
@@ -211,22 +211,13 @@ class SimpleDirectoryReader(BaseReader):
|
||||
|
||||
return new_input_files
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
concatenate: bool = False,
|
||||
progress_callback: Optional[Callable[[int, int], None]] = None,
|
||||
) -> List[Document]:
|
||||
def load_data(self, concatenate: bool = False) -> List[Document]:
|
||||
"""Load data from the input directory.
|
||||
|
||||
Args:
|
||||
concatenate (bool): whether to concatenate all files into one document.
|
||||
If set to True, file metadata is ignored.
|
||||
False by default.
|
||||
progress_callback (Optional[Callable[[int, int], None]]): Called
|
||||
after each file is parsed with ``(files_done, total_files)``.
|
||||
Lets callers surface parse/OCR progress before embedding
|
||||
begins. Exceptions raised by the callback are swallowed so
|
||||
progress reporting can never fail ingestion.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of documents.
|
||||
@@ -235,9 +226,8 @@ class SimpleDirectoryReader(BaseReader):
|
||||
data_list: List[str] = []
|
||||
metadata_list = []
|
||||
self.file_token_counts = {}
|
||||
|
||||
total_files = len(self.input_files)
|
||||
for file_index, input_file in enumerate(self.input_files):
|
||||
|
||||
for input_file in self.input_files:
|
||||
suffix_lower = input_file.suffix.lower()
|
||||
parser_metadata = {}
|
||||
if suffix_lower in self.file_extractor:
|
||||
@@ -287,15 +277,7 @@ class SimpleDirectoryReader(BaseReader):
|
||||
else:
|
||||
data_list.append(str(data))
|
||||
metadata_list.append(base_metadata)
|
||||
|
||||
if progress_callback is not None:
|
||||
try:
|
||||
progress_callback(file_index + 1, total_files)
|
||||
except Exception:
|
||||
logging.warning(
|
||||
"load_data progress callback failed", exc_info=True
|
||||
)
|
||||
|
||||
|
||||
# Build directory structure if input_dir is provided
|
||||
if hasattr(self, 'input_dir'):
|
||||
self.directory_structure = self.build_directory_structure(self.input_dir)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import json
|
||||
|
||||
from application.parser.remote.sitemap_loader import SitemapLoader
|
||||
from application.parser.remote.crawler_loader import CrawlerLoader
|
||||
from application.parser.remote.web_loader import WebLoader
|
||||
@@ -34,59 +32,3 @@ class RemoteCreator:
|
||||
if not loader_class:
|
||||
raise ValueError(f"No loader class found for type {type}")
|
||||
return loader_class(*args, **kwargs)
|
||||
|
||||
|
||||
# Loader types whose load_data expects a URL string, not a config dict.
|
||||
_URL_LOADER_TYPES = {"url", "crawler", "sitemap", "github"}
|
||||
|
||||
# Keys a remote_data dict may hold the URL under (``raw`` is the legacy shape).
|
||||
_URL_DATA_KEYS = ("url", "urls", "repo_url", "raw")
|
||||
|
||||
|
||||
def normalize_remote_data(source_type, remote_data):
|
||||
"""Convert a stored ``sources.remote_data`` JSONB value into the
|
||||
``source_data`` shape the matching loader expects.
|
||||
|
||||
Args:
|
||||
source_type: The ``sources.type`` value (the loader name).
|
||||
remote_data: The stored ``remote_data`` (dict, list, str, or None).
|
||||
|
||||
Returns:
|
||||
Loader input: a URL string or list for url/crawler/sitemap/github,
|
||||
a JSON string for reddit, a dict for s3; ``None`` when the row has
|
||||
nothing syncable.
|
||||
"""
|
||||
if remote_data is None:
|
||||
return None
|
||||
|
||||
# Some legacy rows stored the JSON itself as a string.
|
||||
if isinstance(remote_data, str):
|
||||
stripped = remote_data.strip()
|
||||
if stripped[:1] in ("{", "["):
|
||||
try:
|
||||
remote_data = json.loads(stripped)
|
||||
except json.JSONDecodeError:
|
||||
# Not actually JSON — leave remote_data as the original
|
||||
# string; the per-loader branches below handle a string.
|
||||
pass
|
||||
|
||||
loader = (source_type or "").lower()
|
||||
|
||||
if loader in _URL_LOADER_TYPES:
|
||||
if isinstance(remote_data, dict):
|
||||
for key in _URL_DATA_KEYS:
|
||||
value = remote_data.get(key)
|
||||
if value:
|
||||
return value
|
||||
# No URL key — None keeps the loader off the dict-crash path.
|
||||
return None
|
||||
return remote_data
|
||||
|
||||
if loader == "reddit":
|
||||
# reddit's loader runs json.loads() on its input — needs a string.
|
||||
if isinstance(remote_data, (dict, list)):
|
||||
return json.dumps(remote_data)
|
||||
return remote_data
|
||||
|
||||
# s3's loader accepts a dict or JSON string; pass it through unchanged.
|
||||
return remote_data
|
||||
|
||||
@@ -34,7 +34,7 @@ from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, JSONB, UUID
|
||||
metadata = MetaData()
|
||||
|
||||
|
||||
# --- Users, prompts, tools, logs --------------------------------------------
|
||||
# --- Phase 1, Tier 1 --------------------------------------------------------
|
||||
|
||||
users_table = Table(
|
||||
"users",
|
||||
@@ -138,7 +138,7 @@ app_metadata_table = Table(
|
||||
)
|
||||
|
||||
|
||||
# --- Agents, sources, attachments, artifacts --------------------------------
|
||||
# --- Phase 2, Tier 2 --------------------------------------------------------
|
||||
|
||||
agent_folders_table = Table(
|
||||
"agent_folders",
|
||||
@@ -307,7 +307,7 @@ connector_sessions_table = Table(
|
||||
)
|
||||
|
||||
|
||||
# --- Conversations, messages, workflows -------------------------------------
|
||||
# --- Phase 3, Tier 3 --------------------------------------------------------
|
||||
|
||||
conversations_table = Table(
|
||||
"conversations",
|
||||
@@ -363,36 +363,6 @@ conversation_messages_table = Table(
|
||||
UniqueConstraint("conversation_id", "position", name="conversation_messages_conv_pos_uidx"),
|
||||
)
|
||||
|
||||
# Per-yield journal of chat-stream events, used by the snapshot+tail
|
||||
# reconnect: the route's GET reconnect endpoint reads
|
||||
# ``WHERE message_id = ? AND sequence_no > ?`` from this table before
|
||||
# tailing the live ``channel:{message_id}`` pub/sub. See
|
||||
# ``application/streaming/event_replay.py`` and migration 0007.
|
||||
message_events_table = Table(
|
||||
"message_events",
|
||||
metadata,
|
||||
# PK is the composite ``(message_id, sequence_no)`` — it doubles as
|
||||
# the snapshot read index (covering range scan on
|
||||
# ``WHERE message_id = ? AND sequence_no > ?``).
|
||||
Column(
|
||||
"message_id",
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("conversation_messages.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
nullable=False,
|
||||
),
|
||||
# Strictly monotonic per ``message_id``. Allocated by the route as it
|
||||
# yields, so the writer is single-threaded for the lifetime of one
|
||||
# stream — no contention, no SERIAL needed.
|
||||
Column("sequence_no", Integer, primary_key=True, nullable=False),
|
||||
Column("event_type", Text, nullable=False),
|
||||
Column("payload", JSONB, nullable=False, server_default="{}"),
|
||||
Column(
|
||||
"created_at", DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
shared_conversations_table = Table(
|
||||
"shared_conversations",
|
||||
metadata,
|
||||
@@ -433,7 +403,7 @@ pending_tool_state_table = Table(
|
||||
)
|
||||
|
||||
|
||||
# --- Durability foundation (idempotency / journals, migration 0004) ---------
|
||||
# --- Tier 1 durability foundation (migration 0004) --------------------------
|
||||
# CHECK constraints (status enums) and partial indexes are intentionally
|
||||
# omitted from these declarations — the DB is the authority. Repositories
|
||||
# use raw ``text(...)`` SQL against these tables, not the Core objects.
|
||||
@@ -514,9 +484,6 @@ ingest_chunk_progress_table = Table(
|
||||
# same task resumes from the checkpoint, but a separate invocation
|
||||
# (manual reingest, scheduled sync) resets to a clean re-index.
|
||||
Column("attempt_id", Text),
|
||||
# Added in ``0008_ingest_progress_status``. The reconciler flips
|
||||
# this to 'stalled'; ``init_progress`` resets it to 'active'.
|
||||
Column("status", Text, nullable=False, server_default="active"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ Covers every operation the legacy Mongo code performs on
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
@@ -26,22 +25,6 @@ from application.storage.db.models import conversations_table, conversation_mess
|
||||
from application.storage.db.serialization import PGNativeJSONEncoder
|
||||
|
||||
|
||||
class MessageUpdateOutcome(str, Enum):
|
||||
"""Discriminated result of ``update_message_by_id``.
|
||||
|
||||
Distinguishes the row-actually-updated case from the row-already-at-
|
||||
the-requested-terminal-state case so an abort handler can journal
|
||||
``end`` instead of ``error`` when the normal-path finalize already
|
||||
flipped the row to ``complete``.
|
||||
"""
|
||||
|
||||
UPDATED = "updated"
|
||||
ALREADY_COMPLETE = "already_complete"
|
||||
ALREADY_FAILED = "already_failed"
|
||||
NOT_FOUND = "not_found"
|
||||
INVALID = "invalid"
|
||||
|
||||
|
||||
def _message_row_to_dict(row) -> dict:
|
||||
"""Like ``row_to_dict`` but renames the DB column ``message_metadata``
|
||||
back to the public API key ``metadata`` so callers keep the Mongo-era
|
||||
@@ -75,8 +58,8 @@ class ConversationsRepository:
|
||||
- Already-UUID-shaped → returned as-is.
|
||||
- Otherwise treated as a Mongo ObjectId and looked up via
|
||||
``agents.legacy_mongo_id``. Returns ``None`` if no PG row
|
||||
exists yet (e.g. the agent was created before the backfill
|
||||
ran).
|
||||
exists yet (e.g. the agent was created before Phase 1
|
||||
backfill).
|
||||
"""
|
||||
if not agent_id_raw:
|
||||
return None
|
||||
@@ -262,6 +245,57 @@ class ConversationsRepository:
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def search_for_user(
|
||||
self, user_id: str, query: str, limit: int = 30,
|
||||
) -> list[dict]:
|
||||
"""Search a user's conversations by name or message content.
|
||||
|
||||
Same visibility filter as :meth:`list_for_user`. Matches against
|
||||
``conversations.name`` or any of the conversation's messages'
|
||||
``prompt`` / ``response`` columns (case-insensitive substring).
|
||||
|
||||
Each returned row includes ``match_field`` (one of ``name``,
|
||||
``prompt``, ``response``) and ``match_text`` (the full text of the
|
||||
first matching field, ``name`` taking precedence over messages,
|
||||
``prompt`` over ``response``) so callers can render a snippet.
|
||||
"""
|
||||
if not query:
|
||||
return []
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"SELECT c.*, mt.match_field, mt.match_text "
|
||||
"FROM conversations c "
|
||||
"JOIN LATERAL ( "
|
||||
" SELECT field AS match_field, txt AS match_text "
|
||||
" FROM ( "
|
||||
" SELECT 'name'::text AS field, c.name AS txt, 0 AS prio "
|
||||
" WHERE c.name ILIKE :pattern "
|
||||
" UNION ALL "
|
||||
" SELECT 'prompt'::text, m.prompt, 1 "
|
||||
" FROM conversation_messages m "
|
||||
" WHERE m.conversation_id = c.id "
|
||||
" AND m.prompt ILIKE :pattern "
|
||||
" UNION ALL "
|
||||
" SELECT 'response'::text, m.response, 2 "
|
||||
" FROM conversation_messages m "
|
||||
" WHERE m.conversation_id = c.id "
|
||||
" AND m.response ILIKE :pattern "
|
||||
" ) s "
|
||||
" ORDER BY prio "
|
||||
" LIMIT 1 "
|
||||
") mt ON TRUE "
|
||||
"WHERE c.user_id = :user_id "
|
||||
"AND (c.api_key IS NULL OR c.agent_id IS NOT NULL) "
|
||||
"ORDER BY c.date DESC LIMIT :limit"
|
||||
),
|
||||
{
|
||||
"user_id": user_id,
|
||||
"pattern": f"%{query}%",
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def rename(self, conversation_id: str, user_id: str, name: str) -> bool:
|
||||
# Shape-gate so a non-UUID id (legacy Mongo ObjectId still floating
|
||||
# around in client-side state during the cutover) never reaches the
|
||||
@@ -714,7 +748,7 @@ class ConversationsRepository:
|
||||
def update_message_by_id(
|
||||
self, message_id: str, fields: dict,
|
||||
*, only_if_non_terminal: bool = False,
|
||||
) -> MessageUpdateOutcome:
|
||||
) -> bool:
|
||||
"""Update specific fields on a message identified by its UUID.
|
||||
|
||||
``metadata`` is merged into the existing JSONB rather than
|
||||
@@ -722,13 +756,9 @@ class ConversationsRepository:
|
||||
a successful late finalize. When ``only_if_non_terminal`` is
|
||||
True, the update is gated so a late finalize cannot retract a
|
||||
reconciler-set ``failed`` (or a prior ``complete``).
|
||||
|
||||
The return value discriminates "I updated the row" from "the
|
||||
row was already at a terminal state" so the abort handler can
|
||||
journal ``end`` when the normal-path finalize already ran.
|
||||
"""
|
||||
if not looks_like_uuid(message_id):
|
||||
return MessageUpdateOutcome.INVALID
|
||||
return False
|
||||
allowed = {
|
||||
"prompt", "response", "thought", "sources", "tool_calls",
|
||||
"attachments", "model_id", "metadata", "timestamp", "status",
|
||||
@@ -736,7 +766,7 @@ class ConversationsRepository:
|
||||
}
|
||||
filtered = {k: v for k, v in fields.items() if k in allowed}
|
||||
if not filtered:
|
||||
return MessageUpdateOutcome.INVALID
|
||||
return False
|
||||
|
||||
api_to_col = {"metadata": "message_metadata"}
|
||||
|
||||
@@ -773,44 +803,15 @@ class ConversationsRepository:
|
||||
params[col] = val
|
||||
|
||||
set_parts.append("updated_at = now()")
|
||||
update_where = ["id = CAST(:id AS uuid)"]
|
||||
where_clauses = ["id = CAST(:id AS uuid)"]
|
||||
if only_if_non_terminal:
|
||||
update_where.append("status NOT IN ('complete', 'failed')")
|
||||
# Single-statement attempt + prior-status probe. Both CTEs see
|
||||
# the same MVCC snapshot, so ``prior.status`` reflects the row
|
||||
# state before the UPDATE — exactly what we need to tell
|
||||
# ``ALREADY_COMPLETE`` apart from ``ALREADY_FAILED`` apart from
|
||||
# ``NOT_FOUND`` without a follow-up SELECT.
|
||||
where_clauses.append("status NOT IN ('complete', 'failed')")
|
||||
sql = (
|
||||
"WITH attempted AS ("
|
||||
f" UPDATE conversation_messages SET {', '.join(set_parts)} "
|
||||
f" WHERE {' AND '.join(update_where)} "
|
||||
" RETURNING 1 AS updated"
|
||||
"), "
|
||||
"prior AS ("
|
||||
" SELECT status FROM conversation_messages "
|
||||
" WHERE id = CAST(:id AS uuid)"
|
||||
") "
|
||||
"SELECT (SELECT updated FROM attempted) AS updated, "
|
||||
" (SELECT status FROM prior) AS prior_status"
|
||||
f"UPDATE conversation_messages SET {', '.join(set_parts)} "
|
||||
f"WHERE {' AND '.join(where_clauses)}"
|
||||
)
|
||||
row = self._conn.execute(text(sql), params).fetchone()
|
||||
if row is None:
|
||||
return MessageUpdateOutcome.NOT_FOUND
|
||||
updated, prior_status = row[0], row[1]
|
||||
if updated:
|
||||
return MessageUpdateOutcome.UPDATED
|
||||
if prior_status is None:
|
||||
return MessageUpdateOutcome.NOT_FOUND
|
||||
if prior_status == "complete":
|
||||
return MessageUpdateOutcome.ALREADY_COMPLETE
|
||||
if prior_status == "failed":
|
||||
return MessageUpdateOutcome.ALREADY_FAILED
|
||||
# ``only_if_non_terminal=False`` always updates an existing row,
|
||||
# so reaching here means the gate excluded it for some status
|
||||
# the terminal set doesn't cover — treat as "not found" rather
|
||||
# than inventing a new variant.
|
||||
return MessageUpdateOutcome.NOT_FOUND
|
||||
result = self._conn.execute(text(sql), params)
|
||||
return result.rowcount > 0
|
||||
|
||||
def update_message_status(
|
||||
self, message_id: str, status: str,
|
||||
|
||||
@@ -41,9 +41,6 @@ class IngestChunkProgressRepository:
|
||||
rows with NULL ``attempt_id`` resume against another NULL
|
||||
caller (e.g. test fixtures), but get reset the moment a real
|
||||
``attempt_id`` arrives.
|
||||
|
||||
Both branches also reset ``status`` to ``'active'``, clearing a
|
||||
prior reconciler ``'stalled'`` escalation.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
@@ -71,8 +68,7 @@ class IngestChunkProgressRepository:
|
||||
THEN ingest_chunk_progress.embedded_chunks
|
||||
ELSE 0
|
||||
END,
|
||||
attempt_id = EXCLUDED.attempt_id,
|
||||
status = 'active'
|
||||
attempt_id = EXCLUDED.attempt_id
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
@@ -117,23 +113,6 @@ class IngestChunkProgressRepository:
|
||||
row = result.fetchone()
|
||||
return row_to_dict(row) if row is not None else None
|
||||
|
||||
def delete(self, source_id: str) -> bool:
|
||||
"""Delete the progress row for ``source_id``.
|
||||
|
||||
A manual reingest supersedes any prior ingest state — including a
|
||||
reconciler ``'stalled'`` escalation — so dropping the row clears
|
||||
the derived ``failed`` ingest status the sources list shows.
|
||||
Returns ``True`` when a row was removed.
|
||||
"""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"DELETE FROM ingest_chunk_progress "
|
||||
"WHERE source_id = CAST(:source_id AS uuid)"
|
||||
),
|
||||
{"source_id": str(source_id)},
|
||||
)
|
||||
return result.rowcount > 0
|
||||
|
||||
def bump_heartbeat(self, source_id: str) -> None:
|
||||
"""Refresh ``last_updated`` so the row looks alive to the reconciler."""
|
||||
self._conn.execute(
|
||||
|
||||
@@ -1,248 +0,0 @@
|
||||
"""Repository for ``message_events`` — the chat-stream snapshot journal.
|
||||
|
||||
``record`` / ``bulk_record`` write per-yield events; ``read_after``
|
||||
replays rows past a cursor for reconnect snapshots. Composite PK
|
||||
``(message_id, sequence_no)`` raises ``IntegrityError`` on duplicates.
|
||||
Callers must use short-lived per-call transactions — long-lived
|
||||
transactions hide writes from reconnecting clients on a separate
|
||||
connection and turn one bad row into ``InFailedSqlTransaction``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import Connection, text
|
||||
|
||||
from application.storage.db.base_repository import row_to_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageEventsRepository:
|
||||
"""Read/write helpers for ``message_events``."""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
def record(
|
||||
self,
|
||||
message_id: str,
|
||||
sequence_no: int,
|
||||
event_type: str,
|
||||
payload: Optional[Any] = None,
|
||||
) -> None:
|
||||
"""Append a single event to the journal.
|
||||
|
||||
At this raw repo layer ``payload`` is preserved as-is when not
|
||||
``None`` (lists, scalars, and dicts all round-trip via JSONB);
|
||||
``None`` substitutes an empty object so the column's NOT NULL
|
||||
invariant holds. The streaming-route wrapper
|
||||
``application/streaming/message_journal.py::record_event``
|
||||
tightens this contract to dicts only — the live and replay
|
||||
paths reconstruct non-dict payloads differently, so the wrapper
|
||||
rejects them at the gate. Direct callers of this repo method
|
||||
(cleanup tasks, tests, future ad-hoc consumers) keep the wider
|
||||
JSONB-compatible surface.
|
||||
|
||||
Raises ``sqlalchemy.exc.IntegrityError`` on duplicate
|
||||
``(message_id, sequence_no)`` and ``DataError`` on a malformed
|
||||
``message_id`` UUID. Both abort the surrounding transaction —
|
||||
callers must run inside a short-lived per-event session
|
||||
(see module docstring).
|
||||
"""
|
||||
if not event_type:
|
||||
raise ValueError("event_type must be a non-empty string")
|
||||
materialised_payload = payload if payload is not None else {}
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO message_events (
|
||||
message_id, sequence_no, event_type, payload
|
||||
) VALUES (
|
||||
CAST(:message_id AS uuid), :sequence_no, :event_type,
|
||||
CAST(:payload AS jsonb)
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"message_id": str(message_id),
|
||||
"sequence_no": int(sequence_no),
|
||||
"event_type": event_type,
|
||||
"payload": json.dumps(materialised_payload),
|
||||
},
|
||||
)
|
||||
|
||||
def bulk_record(
|
||||
self,
|
||||
message_id: str,
|
||||
events: list[tuple[int, str, dict]],
|
||||
) -> None:
|
||||
"""Append multiple events for ``message_id`` in one INSERT.
|
||||
|
||||
``events`` is a list of ``(sequence_no, event_type, payload)``
|
||||
tuples. SQLAlchemy ``executemany`` issues one bulk INSERT;
|
||||
Postgres treats the whole batch as one statement, so an
|
||||
IntegrityError on any row aborts the entire batch.
|
||||
|
||||
Caller contract: on IntegrityError, do NOT retry this method
|
||||
with the same batch — fall back to per-row ``record()`` calls
|
||||
(each in its own short-lived session) so a single colliding
|
||||
seq doesn't drop the rest of the batch. ``BatchedJournalWriter``
|
||||
in ``application/streaming/message_journal.py`` is the canonical
|
||||
consumer.
|
||||
"""
|
||||
if not events:
|
||||
return
|
||||
params = [
|
||||
{
|
||||
"message_id": str(message_id),
|
||||
"sequence_no": int(seq),
|
||||
"event_type": event_type,
|
||||
"payload": json.dumps(payload if payload is not None else {}),
|
||||
}
|
||||
for seq, event_type, payload in events
|
||||
]
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO message_events (
|
||||
message_id, sequence_no, event_type, payload
|
||||
) VALUES (
|
||||
CAST(:message_id AS uuid), :sequence_no, :event_type,
|
||||
CAST(:payload AS jsonb)
|
||||
)
|
||||
"""
|
||||
),
|
||||
params,
|
||||
)
|
||||
|
||||
def read_after(
|
||||
self,
|
||||
message_id: str,
|
||||
last_sequence_no: Optional[int] = None,
|
||||
) -> list[dict]:
|
||||
"""Return events with ``sequence_no > last_sequence_no``.
|
||||
|
||||
``last_sequence_no=None`` returns the full backlog. Rows are
|
||||
returned in ascending ``sequence_no`` order. The composite PK
|
||||
is the snapshot read index for this scan — Postgres typically
|
||||
picks an in-order index range scan, though for highly mixed
|
||||
data the planner may pick a bitmap+sort. Either way the result
|
||||
is sorted on ``sequence_no``.
|
||||
|
||||
Returns a ``list`` (not a generator) so the underlying
|
||||
``Result`` is fully drained before the caller can issue
|
||||
another query on the same connection.
|
||||
"""
|
||||
cursor = -1 if last_sequence_no is None else int(last_sequence_no)
|
||||
rows = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT message_id, sequence_no, event_type, payload, created_at
|
||||
FROM message_events
|
||||
WHERE message_id = CAST(:message_id AS uuid)
|
||||
AND sequence_no > :cursor
|
||||
ORDER BY sequence_no ASC
|
||||
"""
|
||||
),
|
||||
{"message_id": str(message_id), "cursor": cursor},
|
||||
).fetchall()
|
||||
return [row_to_dict(row) for row in rows]
|
||||
|
||||
def cleanup_older_than(self, ttl_days: int) -> int:
|
||||
"""Delete journal rows older than ``ttl_days``. Returns row count.
|
||||
|
||||
Reconnect-replay is meaningful only for streams the client
|
||||
could plausibly still be waiting on, so old rows are dead
|
||||
weight. The ``message_events_created_at_idx`` btree makes the
|
||||
range delete a cheap index scan even on large tables.
|
||||
"""
|
||||
if ttl_days <= 0:
|
||||
raise ValueError("ttl_days must be positive")
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
DELETE FROM message_events
|
||||
WHERE created_at < now() - make_interval(days => :ttl_days)
|
||||
"""
|
||||
),
|
||||
{"ttl_days": int(ttl_days)},
|
||||
)
|
||||
return int(result.rowcount or 0)
|
||||
|
||||
def reconstruct_partial(self, message_id: str) -> dict:
|
||||
"""Rebuild partial response/thought/sources/tool_calls from journal events.
|
||||
|
||||
``answer``/``thought`` chunks concat in seq order; ``source``/
|
||||
``tool_calls`` carry the full list at emit time (last-wins).
|
||||
"""
|
||||
rows = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT sequence_no, event_type, payload
|
||||
FROM message_events
|
||||
WHERE message_id = CAST(:message_id AS uuid)
|
||||
ORDER BY sequence_no ASC
|
||||
"""
|
||||
),
|
||||
{"message_id": str(message_id)},
|
||||
).fetchall()
|
||||
|
||||
response_parts: list[str] = []
|
||||
thought_parts: list[str] = []
|
||||
sources: list = []
|
||||
tool_calls: list = []
|
||||
|
||||
for row in rows:
|
||||
payload = row.payload
|
||||
if not isinstance(payload, dict):
|
||||
continue
|
||||
etype = row.event_type
|
||||
if etype == "answer":
|
||||
chunk = payload.get("answer")
|
||||
if isinstance(chunk, str):
|
||||
response_parts.append(chunk)
|
||||
elif etype == "thought":
|
||||
chunk = payload.get("thought")
|
||||
if isinstance(chunk, str):
|
||||
thought_parts.append(chunk)
|
||||
elif etype == "source":
|
||||
src = payload.get("source")
|
||||
if isinstance(src, list):
|
||||
sources = src
|
||||
elif etype == "tool_calls":
|
||||
tcs = payload.get("tool_calls")
|
||||
if isinstance(tcs, list):
|
||||
tool_calls = tcs
|
||||
|
||||
return {
|
||||
"response": "".join(response_parts),
|
||||
"thought": "".join(thought_parts),
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
|
||||
def latest_sequence_no(self, message_id: str) -> Optional[int]:
|
||||
"""Largest ``sequence_no`` recorded for ``message_id``, or ``None``.
|
||||
|
||||
Used by the route to seed the per-stream allocator on retry /
|
||||
process restart so a re-run continues numbering instead of
|
||||
trampling earlier entries with duplicate sequence_no.
|
||||
"""
|
||||
# ``MAX`` always returns one row — NULL when the journal is
|
||||
# empty — so we test the value, not the row presence.
|
||||
row = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT MAX(sequence_no) AS s
|
||||
FROM message_events
|
||||
WHERE message_id = CAST(:message_id AS uuid)
|
||||
"""
|
||||
),
|
||||
{"message_id": str(message_id)},
|
||||
).first()
|
||||
value = row[0] if row is not None else None
|
||||
return int(value) if value is not None else None
|
||||
@@ -107,11 +107,7 @@ class ReconciliationRepository:
|
||||
def find_and_lock_stalled_ingests(
|
||||
self, *, age_minutes: int = 30, limit: int = 100,
|
||||
) -> list[dict]:
|
||||
"""Lock still-active ingest checkpoints with a silent heartbeat.
|
||||
|
||||
The ``status = 'active'`` filter skips rows already escalated to
|
||||
``'stalled'``, so a dead ingest is alerted once, not every tick.
|
||||
"""
|
||||
"""Lock ingest checkpoints whose heartbeat hasn't ticked recently."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
@@ -120,7 +116,6 @@ class ReconciliationRepository:
|
||||
FROM ingest_chunk_progress
|
||||
WHERE last_updated < now() - make_interval(mins => :age)
|
||||
AND embedded_chunks < total_chunks
|
||||
AND status = 'active'
|
||||
ORDER BY last_updated ASC
|
||||
LIMIT :limit
|
||||
FOR UPDATE SKIP LOCKED
|
||||
@@ -130,15 +125,11 @@ class ReconciliationRepository:
|
||||
)
|
||||
return [row_to_dict(r) for r in result.fetchall()]
|
||||
|
||||
def mark_ingest_stalled(self, source_id: str) -> bool:
|
||||
"""Escalate a stalled checkpoint to terminal ``status='stalled'``.
|
||||
|
||||
Drops the row out of the sweep so the reconciler alerts once;
|
||||
``init_progress`` flips it back to ``'active'`` on reingest.
|
||||
"""
|
||||
def touch_ingest_progress(self, source_id: str) -> bool:
|
||||
"""Bump ``last_updated`` so a once-stalled ingest re-enters the watch window."""
|
||||
result = self._conn.execute(
|
||||
text(
|
||||
"UPDATE ingest_chunk_progress SET status = 'stalled' "
|
||||
"UPDATE ingest_chunk_progress SET last_updated = now() "
|
||||
"WHERE source_id = CAST(:sid AS uuid)"
|
||||
),
|
||||
{"sid": str(source_id)},
|
||||
|
||||
@@ -5,10 +5,10 @@ from __future__ import annotations
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import case, Connection, func, select, text
|
||||
from sqlalchemy import Connection, func, select, text
|
||||
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.models import ingest_chunk_progress_table, sources_table
|
||||
from application.storage.db.models import sources_table
|
||||
|
||||
|
||||
_SCALAR_COLUMNS = {
|
||||
@@ -61,21 +61,6 @@ def _coerce_jsonb(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _ingest_status_case():
|
||||
"""Derive a user-facing ingest status from the joined progress row.
|
||||
|
||||
``failed`` — reconciler-escalated stall. ``processing`` — embed in
|
||||
flight. ``None`` — no progress row, or the embed completed.
|
||||
"""
|
||||
icp = ingest_chunk_progress_table
|
||||
return case(
|
||||
(icp.c.source_id.is_(None), None),
|
||||
(icp.c.status == "stalled", "failed"),
|
||||
(icp.c.embedded_chunks < icp.c.total_chunks, "processing"),
|
||||
else_=None,
|
||||
).label("ingest_status")
|
||||
|
||||
|
||||
class SourcesRepository:
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
@@ -207,25 +192,13 @@ class SourcesRepository:
|
||||
as ``"desc"``.
|
||||
|
||||
Returns:
|
||||
A list of source rows as plain dicts (via ``row_to_dict``),
|
||||
each carrying a derived ``ingest_status`` (``failed`` /
|
||||
``processing`` / ``None``) from the joined progress row.
|
||||
A list of source rows as plain dicts (via ``row_to_dict``).
|
||||
"""
|
||||
column_name = sort_field if sort_field in _SORTABLE_COLUMNS else "date"
|
||||
sort_column = sources_table.c[column_name]
|
||||
ascending = sort_order.lower() == "asc"
|
||||
|
||||
stmt = (
|
||||
select(sources_table, _ingest_status_case())
|
||||
.select_from(
|
||||
sources_table.outerjoin(
|
||||
ingest_chunk_progress_table,
|
||||
ingest_chunk_progress_table.c.source_id
|
||||
== sources_table.c.id,
|
||||
)
|
||||
)
|
||||
.where(sources_table.c.user_id == user_id)
|
||||
)
|
||||
stmt = select(sources_table).where(sources_table.c.user_id == user_id)
|
||||
if search_term:
|
||||
stmt = stmt.where(
|
||||
sources_table.c.name.ilike(
|
||||
|
||||
@@ -63,8 +63,7 @@ class ToolCallAttemptsRepository:
|
||||
message_id: Optional[str] = None,
|
||||
artifact_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Insert OR upgrade a row to ``executed`` — or ``confirmed`` when
|
||||
there is no ``message_id``, as in ``mark_executed``.
|
||||
"""Insert OR upgrade a row to ``executed``.
|
||||
|
||||
Used as a fallback when ``record_proposed`` failed (DB outage)
|
||||
and the tool ran anyway — preserves the journal so the
|
||||
@@ -73,7 +72,6 @@ class ToolCallAttemptsRepository:
|
||||
result_payload: dict = {"result": result}
|
||||
if artifact_id:
|
||||
result_payload["artifact_id"] = artifact_id
|
||||
status = "executed" if message_id is not None else "confirmed"
|
||||
self._conn.execute(
|
||||
text(
|
||||
"""
|
||||
@@ -84,9 +82,9 @@ class ToolCallAttemptsRepository:
|
||||
(:call_id, CAST(:tool_id AS uuid), :tool_name,
|
||||
:action_name, CAST(:arguments AS jsonb),
|
||||
CAST(:result AS jsonb), CAST(:message_id AS uuid),
|
||||
:status)
|
||||
'executed')
|
||||
ON CONFLICT (call_id) DO UPDATE
|
||||
SET status = :status,
|
||||
SET status = 'executed',
|
||||
result = EXCLUDED.result,
|
||||
message_id = COALESCE(EXCLUDED.message_id, tool_call_attempts.message_id)
|
||||
"""
|
||||
@@ -99,7 +97,6 @@ class ToolCallAttemptsRepository:
|
||||
"arguments": json.dumps(arguments if arguments is not None else {}, cls=PGNativeJSONEncoder),
|
||||
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
|
||||
"message_id": message_id,
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -111,9 +108,7 @@ class ToolCallAttemptsRepository:
|
||||
message_id: Optional[str] = None,
|
||||
artifact_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Flip ``proposed`` → ``executed``, or straight to ``confirmed``
|
||||
when there is no ``message_id`` (a ``save_conversation=False``
|
||||
request reserves no message, so no finalize will confirm it).
|
||||
"""Flip ``proposed`` → ``executed`` with the tool result.
|
||||
|
||||
``artifact_id`` (when present) is stored alongside ``result`` in
|
||||
the JSONB as audit data — the reconciler reads it for diagnostic
|
||||
@@ -122,14 +117,12 @@ class ToolCallAttemptsRepository:
|
||||
result_payload: dict = {"result": result}
|
||||
if artifact_id:
|
||||
result_payload["artifact_id"] = artifact_id
|
||||
status = "executed" if message_id is not None else "confirmed"
|
||||
sql = (
|
||||
"UPDATE tool_call_attempts SET "
|
||||
"status = :status, result = CAST(:result AS jsonb)"
|
||||
"status = 'executed', result = CAST(:result AS jsonb)"
|
||||
)
|
||||
params: dict[str, Any] = {
|
||||
"call_id": call_id,
|
||||
"status": status,
|
||||
"result": json.dumps(result_payload, cls=PGNativeJSONEncoder),
|
||||
}
|
||||
if message_id is not None:
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
"""Deterministic source-id derivation for idempotent ingest.
|
||||
|
||||
DO NOT CHANGE the pinned UUID namespace — it backs cross-deploy
|
||||
idempotency keys.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
# DO NOT CHANGE. See module docstring.
|
||||
DOCSGPT_INGEST_NAMESPACE = uuid.UUID("fa25d5d1-398b-46df-ac89-8d1c360b9bea")
|
||||
|
||||
|
||||
def derive_source_id(idempotency_key) -> uuid.UUID:
|
||||
"""``uuid5(NS, key)`` when a key is supplied; ``uuid4()`` otherwise.
|
||||
|
||||
A non-string / empty key falls back to ``uuid4()`` so the caller
|
||||
always gets a fresh id rather than a TypeError mid-route.
|
||||
"""
|
||||
if isinstance(idempotency_key, str) and idempotency_key:
|
||||
return uuid.uuid5(DOCSGPT_INGEST_NAMESPACE, idempotency_key)
|
||||
return uuid.uuid4()
|
||||
@@ -1,126 +0,0 @@
|
||||
"""Redis pub/sub Topic abstraction for SSE fan-out.
|
||||
|
||||
A Topic is a named channel for one-shot live event delivery. Canonical uses:
|
||||
|
||||
- ``user:{user_id}`` for per-user notifications
|
||||
- ``channel:{message_id}`` for per-chat-message streams
|
||||
|
||||
Subscription is race-free via ``on_subscribe``: the callback fires only
|
||||
after Redis acknowledges ``SUBSCRIBE``, so a publisher dispatched inside
|
||||
the callback cannot lose its first event to a not-yet-registered
|
||||
subscriber.
|
||||
|
||||
The subscribe iterator yields ``None`` on poll timeout so the caller can
|
||||
emit SSE keepalive comments without spawning a separate timer thread.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, Iterator, Optional
|
||||
|
||||
from application.cache import get_redis_instance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Topic:
|
||||
"""A pub/sub channel identified by a string name."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
def publish(self, payload: str | bytes) -> int:
|
||||
"""Fan out a payload to currently subscribed clients.
|
||||
|
||||
Returns the number Redis reports as receiving the message (limited
|
||||
to subscribers connected to *this* Redis instance), or 0 if Redis
|
||||
is unavailable. Never raises.
|
||||
"""
|
||||
redis = get_redis_instance()
|
||||
if redis is None:
|
||||
logger.debug("Redis unavailable; dropping publish to %s", self.name)
|
||||
return 0
|
||||
try:
|
||||
return int(redis.publish(self.name, payload))
|
||||
except Exception:
|
||||
logger.exception("Topic.publish failed for %s", self.name)
|
||||
return 0
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
on_subscribe: Optional[Callable[[], None]] = None,
|
||||
poll_timeout: float = 1.0,
|
||||
) -> Iterator[Optional[bytes]]:
|
||||
"""Subscribe to the topic; yield raw payloads or ``None`` on tick.
|
||||
|
||||
Yields ``None`` every ``poll_timeout`` seconds while idle so the
|
||||
caller can emit keepalive frames or check cancellation. Yields
|
||||
``bytes`` for each delivered message.
|
||||
|
||||
``on_subscribe`` runs synchronously after Redis acknowledges the
|
||||
SUBSCRIBE — use it to seed any state (e.g. read backlog) that
|
||||
must be ordered after the subscriber is live but before the
|
||||
first pub/sub message is processed.
|
||||
|
||||
If Redis is unavailable, returns immediately without yielding.
|
||||
Cleanly unsubscribes on ``GeneratorExit`` (client disconnect).
|
||||
"""
|
||||
redis = get_redis_instance()
|
||||
if redis is None:
|
||||
logger.debug("Redis unavailable; subscribe to %s yielded nothing", self.name)
|
||||
return
|
||||
pubsub = None
|
||||
on_subscribe_fired = False
|
||||
try:
|
||||
pubsub = redis.pubsub()
|
||||
try:
|
||||
pubsub.subscribe(self.name)
|
||||
except Exception:
|
||||
# Subscribe failure (transient Redis hiccup, conn reset, etc.)
|
||||
# is treated like "Redis unavailable": yield nothing, let the
|
||||
# caller fall back to its own resilience strategy. The finally
|
||||
# block will still tear down the pubsub object cleanly.
|
||||
logger.exception("pubsub.subscribe failed for %s", self.name)
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
msg = pubsub.get_message(timeout=poll_timeout)
|
||||
except Exception:
|
||||
logger.exception("pubsub.get_message failed for %s", self.name)
|
||||
return
|
||||
if msg is None:
|
||||
yield None
|
||||
continue
|
||||
msg_type = msg.get("type")
|
||||
if msg_type == "subscribe":
|
||||
if not on_subscribe_fired and on_subscribe is not None:
|
||||
try:
|
||||
on_subscribe()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"on_subscribe callback failed for %s", self.name
|
||||
)
|
||||
on_subscribe_fired = True
|
||||
continue
|
||||
if msg_type != "message":
|
||||
continue
|
||||
data = msg.get("data")
|
||||
if data is None:
|
||||
continue
|
||||
yield data if isinstance(data, bytes) else str(data).encode("utf-8")
|
||||
finally:
|
||||
if pubsub is not None:
|
||||
if on_subscribe_fired:
|
||||
try:
|
||||
pubsub.unsubscribe(self.name)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"pubsub unsubscribe error for %s",
|
||||
self.name,
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
pubsub.close()
|
||||
except Exception:
|
||||
logger.debug("pubsub close error for %s", self.name, exc_info=True)
|
||||
@@ -1,434 +0,0 @@
|
||||
"""Snapshot+tail iterator for chat-stream reconnect-after-disconnect.
|
||||
|
||||
Subscribe to ``channel:{message_id}``, snapshot ``message_events``
|
||||
rows past ``last_event_id`` inside the SUBSCRIBE-ack callback, flush
|
||||
snapshot, then tail live pub/sub (dedup'd by ``sequence_no``). See
|
||||
``docs/runbooks/sse-notifications.md``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.storage.db.repositories.message_events import (
|
||||
MessageEventsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.streaming.broadcast_channel import Topic
|
||||
from application.streaming.keys import message_topic_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_KEEPALIVE_SECONDS = 15.0
|
||||
DEFAULT_POLL_TIMEOUT_SECONDS = 1.0
|
||||
# When the live tail has no events and no terminal in snapshot, fall
|
||||
# back to checking ``conversation_messages`` directly. If the row has
|
||||
# already gone terminal (worker journaled ``end``/``error`` to the DB
|
||||
# but the matching pub/sub publish was lost, or the row was finalized
|
||||
# without a journal write at all) we surface a terminal event so the
|
||||
# client doesn't hang on keepalives. If the row is still non-terminal
|
||||
# but the producer heartbeat is older than ``PRODUCER_IDLE_SECONDS``
|
||||
# the producer is presumed dead (worker crash / recycle between chunks
|
||||
# and finalize) and we emit a terminal ``error`` so the UI can recover.
|
||||
DEFAULT_WATCHDOG_INTERVAL_SECONDS = 5.0
|
||||
# 1.5× the route's 60s heartbeat interval — long enough that a normal
|
||||
# heartbeat skew doesn't false-positive, short enough that a stuck
|
||||
# stream surfaces before the 5-minute reconciler sweep escalates.
|
||||
DEFAULT_PRODUCER_IDLE_SECONDS = 90.0
|
||||
|
||||
# WHATWG SSE accepts CRLF, CR, LF — split on any of them so a stray CR
|
||||
# can't smuggle a record boundary into the wire format.
|
||||
_SSE_LINE_SPLIT_PATTERN = re.compile(r"\r\n|\r|\n")
|
||||
|
||||
# Event types that mark the end of a chat answer. After delivering one
|
||||
# we close the reconnect stream — keeping the connection open past a
|
||||
# terminal event would leak both the client's reconnect promise and
|
||||
# the server's WSGI thread waiting on keepalives that the user no
|
||||
# longer cares about. The agent loop emits ``end`` for normal /
|
||||
# tool-paused completion and ``error`` for the catch-all failure path
|
||||
# (which doesn't get a trailing ``end``).
|
||||
_TERMINAL_EVENT_TYPES = frozenset({"end", "error"})
|
||||
|
||||
|
||||
def _payload_is_terminal(
|
||||
payload: object, event_type: Optional[str] = None
|
||||
) -> bool:
|
||||
"""True if ``payload['type']`` or ``event_type`` is a terminal sentinel."""
|
||||
if isinstance(payload, dict) and payload.get("type") in _TERMINAL_EVENT_TYPES:
|
||||
return True
|
||||
return event_type in _TERMINAL_EVENT_TYPES
|
||||
|
||||
|
||||
def format_sse_event(payload: dict, sequence_no: int) -> str:
|
||||
"""Encode a journal event as one ``id:``/``data:`` SSE record.
|
||||
|
||||
The body is the payload's JSON serialisation. ``complete_stream``
|
||||
payloads are flat JSON dicts with no embedded newlines, so a
|
||||
single ``data:`` line is sufficient — but we still split on any
|
||||
line terminator in case a future caller passes a multi-line string.
|
||||
"""
|
||||
body = json.dumps(payload)
|
||||
lines = [f"id: {sequence_no}"]
|
||||
for line in _SSE_LINE_SPLIT_PATTERN.split(body):
|
||||
lines.append(f"data: {line}")
|
||||
return "\n".join(lines) + "\n\n"
|
||||
|
||||
|
||||
def _check_producer_liveness(
|
||||
message_id: str, idle_seconds: float
|
||||
) -> Optional[dict]:
|
||||
"""Inspect ``conversation_messages`` and return a terminal SSE
|
||||
payload when the producer is no longer alive, else ``None``.
|
||||
|
||||
Three terminal cases collapse into a single DB round-trip:
|
||||
|
||||
- ``status='complete'`` — the live finalize ran but its journal
|
||||
terminal write didn't reach us (or never happened). Synthesise
|
||||
``end`` so the client closes cleanly on the row's user-visible
|
||||
state.
|
||||
- ``status='failed'`` — same, but for the failure path. Carry the
|
||||
stashed ``error`` from ``message_metadata`` so the UI shows the
|
||||
real reason.
|
||||
- non-terminal status and ``last_heartbeat_at`` (or ``timestamp``)
|
||||
older than ``idle_seconds`` — the producing worker is gone.
|
||||
Synthesise ``error`` so the client doesn't hang on keepalives
|
||||
until the proxy idle-timeout kicks in.
|
||||
"""
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
row = conn.execute(
|
||||
sql_text(
|
||||
"""
|
||||
SELECT
|
||||
status,
|
||||
message_metadata->>'error' AS err,
|
||||
GREATEST(
|
||||
timestamp,
|
||||
COALESCE(
|
||||
(message_metadata->>'last_heartbeat_at')
|
||||
::timestamptz,
|
||||
timestamp
|
||||
)
|
||||
) < now() - make_interval(secs => :idle_secs)
|
||||
AS is_stale
|
||||
FROM conversation_messages
|
||||
WHERE id = CAST(:id AS uuid)
|
||||
"""
|
||||
),
|
||||
{"id": message_id, "idle_secs": float(idle_seconds)},
|
||||
).first()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Watchdog liveness check failed for message_id=%s", message_id
|
||||
)
|
||||
return None
|
||||
|
||||
if row is None:
|
||||
# Row deleted out from under us — treat as terminal so the
|
||||
# client doesn't keep tailing a message that no longer exists.
|
||||
return {
|
||||
"type": "error",
|
||||
"error": "Message no longer exists; please refresh.",
|
||||
"code": "message_missing",
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
status, err, is_stale = row[0], row[1], bool(row[2])
|
||||
if status == "complete":
|
||||
return {"type": "end"}
|
||||
if status == "failed":
|
||||
return {
|
||||
"type": "error",
|
||||
"error": err or "Stream failed; please try again.",
|
||||
"code": "producer_failed",
|
||||
"message_id": message_id,
|
||||
}
|
||||
if is_stale:
|
||||
return {
|
||||
"type": "error",
|
||||
"error": (
|
||||
"Stream producer is no longer responding; please try again."
|
||||
),
|
||||
"code": "producer_stale",
|
||||
"message_id": message_id,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def build_message_event_stream(
|
||||
message_id: str,
|
||||
last_event_id: Optional[int] = None,
|
||||
*,
|
||||
keepalive_seconds: float = DEFAULT_KEEPALIVE_SECONDS,
|
||||
poll_timeout_seconds: float = DEFAULT_POLL_TIMEOUT_SECONDS,
|
||||
watchdog_interval_seconds: float = DEFAULT_WATCHDOG_INTERVAL_SECONDS,
|
||||
producer_idle_seconds: float = DEFAULT_PRODUCER_IDLE_SECONDS,
|
||||
) -> Iterator[str]:
|
||||
"""Yield SSE-formatted lines for one ``message_id`` reconnect stream.
|
||||
|
||||
First frame is ``: connected``; subsequent frames are snapshot rows,
|
||||
live-tail events, or ``: keepalive`` comments. Runs until the client
|
||||
disconnects.
|
||||
"""
|
||||
yield ": connected\n\n"
|
||||
|
||||
# Replay buffer — populated inside ``_on_subscribe`` (or the
|
||||
# Redis-unavailable fallback below), drained on the first iteration
|
||||
# of the subscribe loop after the callback runs.
|
||||
replay_buffer: list[str] = []
|
||||
# Dedup floor: seeded with the client's cursor so an empty snapshot
|
||||
# still rejects re-published live events with seq <= last_event_id.
|
||||
# Advanced by snapshot rows AND by yielded live events, so any
|
||||
# republish past the snapshot ceiling is also dropped.
|
||||
max_replayed_seq: Optional[int] = last_event_id
|
||||
replay_done = False
|
||||
replay_failed = False
|
||||
# Set when a snapshot row carries a terminal ``end`` / ``error``
|
||||
# event. After flushing the buffer the generator returns; if we
|
||||
# kept tailing we'd loop on keepalives forever for a stream that
|
||||
# already finished.
|
||||
terminal_in_snapshot = False
|
||||
|
||||
def _read_snapshot_into_buffer() -> None:
|
||||
nonlocal max_replayed_seq, replay_failed, terminal_in_snapshot
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
rows = MessageEventsRepository(conn).read_after(
|
||||
message_id, last_sequence_no=last_event_id
|
||||
)
|
||||
for row in rows:
|
||||
seq = int(row["sequence_no"])
|
||||
payload = row.get("payload")
|
||||
if not isinstance(payload, dict):
|
||||
# ``record_event`` rejects non-dict payloads at the
|
||||
# write gate, so this can only be a legacy row from
|
||||
# before that contract or a direct SQL insert. The
|
||||
# original synthetic fallback (``{"type": event_type}``)
|
||||
# used to ship a malformed envelope here — drop the
|
||||
# row instead so a corrupt journal entry doesn't
|
||||
# poison a reconnect.
|
||||
logger.warning(
|
||||
"Skipping non-dict payload from message_events: "
|
||||
"message_id=%s seq=%s type=%s",
|
||||
message_id,
|
||||
seq,
|
||||
row.get("event_type"),
|
||||
)
|
||||
continue
|
||||
replay_buffer.append(format_sse_event(payload, seq))
|
||||
if max_replayed_seq is None or seq > max_replayed_seq:
|
||||
max_replayed_seq = seq
|
||||
if _payload_is_terminal(payload, row.get("event_type")):
|
||||
terminal_in_snapshot = True
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Snapshot read failed for message_id=%s last_event_id=%s",
|
||||
message_id,
|
||||
last_event_id,
|
||||
)
|
||||
replay_failed = True
|
||||
|
||||
def _on_subscribe() -> None:
|
||||
# SUBSCRIBE has been acked — Postgres reads from this point
|
||||
# capture every row that's been committed. Pub/sub messages
|
||||
# published after this point are queued at the connection level
|
||||
# until the outer loop calls ``get_message`` again.
|
||||
nonlocal replay_done
|
||||
try:
|
||||
_read_snapshot_into_buffer()
|
||||
finally:
|
||||
# Flip even on failure so the outer loop continues to live
|
||||
# tail and the client doesn't hang waiting for a snapshot
|
||||
# flush that will never come.
|
||||
replay_done = True
|
||||
|
||||
topic = Topic(message_topic_name(message_id))
|
||||
last_keepalive = time.monotonic()
|
||||
# Rate-limit the watchdog's DB hit. ``-inf`` makes the first idle
|
||||
# tick after replay_done fire immediately so a snapshot-already-
|
||||
# terminal-in-DB case is surfaced before any keepalive cadence.
|
||||
# Subsequent checks are gated by ``watchdog_interval_seconds``.
|
||||
last_watchdog_check = float("-inf")
|
||||
# Synthetic terminal events emitted by the watchdog use the same
|
||||
# ``sequence_no=-1`` convention as the snapshot-failure path so the
|
||||
# frontend's strict ``\d+`` cursor regex rejects them as a
|
||||
# ``Last-Event-ID`` for any future reconnect. The chosen
|
||||
# discriminator ensures a manual page refresh after a watchdog-fired
|
||||
# error doesn't loop on the same synthetic id.
|
||||
watchdog_synthetic_seq = -1
|
||||
|
||||
try:
|
||||
for payload in topic.subscribe(
|
||||
on_subscribe=_on_subscribe,
|
||||
poll_timeout=poll_timeout_seconds,
|
||||
):
|
||||
# Flush snapshot exactly once after the SUBSCRIBE callback
|
||||
# has run and produced a buffer.
|
||||
if replay_done and replay_buffer:
|
||||
for line in replay_buffer:
|
||||
yield line
|
||||
replay_buffer.clear()
|
||||
if terminal_in_snapshot:
|
||||
# The original stream already finished; tailing
|
||||
# would just emit keepalives forever and pin both a
|
||||
# client reconnect promise and a server WSGI thread.
|
||||
return
|
||||
|
||||
if replay_failed:
|
||||
# Snapshot read failed (DB blip / transient timeout). Emit a
|
||||
# terminal ``error`` event and return — the client only
|
||||
# reconnects after the original stream has already moved on,
|
||||
# so without a snapshot there's nothing live left to tail and
|
||||
# holding the connection open would just emit keepalives
|
||||
# until the proxy idle-timeout fires. ``code`` preserves the
|
||||
# snapshot-vs-agent-loop distinction so a future client can
|
||||
# opt into a refetch instead of a hard failure.
|
||||
yield format_sse_event(
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Stream replay failed; please refresh to load the latest state.",
|
||||
"code": "snapshot_failed",
|
||||
"message_id": message_id,
|
||||
},
|
||||
sequence_no=-1,
|
||||
)
|
||||
return
|
||||
|
||||
now = time.monotonic()
|
||||
if payload is None:
|
||||
# Idle tick — check both keepalive and watchdog. The
|
||||
# watchdog only kicks in once the snapshot half has been
|
||||
# flushed (``replay_done``) so we don't race the
|
||||
# snapshot read on the first iteration.
|
||||
if (
|
||||
replay_done
|
||||
and watchdog_interval_seconds >= 0
|
||||
and now - last_watchdog_check >= watchdog_interval_seconds
|
||||
):
|
||||
last_watchdog_check = now
|
||||
terminal_payload = _check_producer_liveness(
|
||||
message_id, producer_idle_seconds
|
||||
)
|
||||
if terminal_payload is not None:
|
||||
yield format_sse_event(
|
||||
terminal_payload,
|
||||
sequence_no=watchdog_synthetic_seq,
|
||||
)
|
||||
return
|
||||
if now - last_keepalive >= keepalive_seconds:
|
||||
yield ": keepalive\n\n"
|
||||
last_keepalive = now
|
||||
continue
|
||||
|
||||
envelope = _decode_pubsub_message(payload)
|
||||
if envelope is None:
|
||||
continue
|
||||
seq = envelope.get("sequence_no")
|
||||
inner = envelope.get("payload")
|
||||
if (
|
||||
not isinstance(seq, int)
|
||||
or isinstance(seq, bool)
|
||||
or not isinstance(inner, dict)
|
||||
):
|
||||
continue
|
||||
if max_replayed_seq is not None and seq <= max_replayed_seq:
|
||||
# Snapshot already covered this id — drop the duplicate.
|
||||
continue
|
||||
yield format_sse_event(inner, seq)
|
||||
# Advance the dedup floor on the live path too, so a stale
|
||||
# republish of an already-yielded seq (process restart, retry
|
||||
# tool, etc.) is dropped on a later iteration.
|
||||
max_replayed_seq = seq
|
||||
last_keepalive = now
|
||||
if _payload_is_terminal(inner, envelope.get("event_type")):
|
||||
# Live tail just delivered the terminal event — close
|
||||
# out the reconnect stream so the client's drain
|
||||
# promise resolves and the WSGI thread is freed.
|
||||
return
|
||||
|
||||
# Subscribe exited without ever yielding (Redis unavailable,
|
||||
# ``pubsub.subscribe`` raised, or the inner loop died between
|
||||
# SUBSCRIBE-ack and the first poll). The snapshot half is in
|
||||
# Postgres and is still serviceable — read it directly so a
|
||||
# Redis-only outage doesn't cost the client their reconnect
|
||||
# backlog. Gate the read on ``replay_done`` rather than
|
||||
# ``subscribe_started``: if ``_on_subscribe`` already populated
|
||||
# the buffer, re-reading would append the same rows twice and
|
||||
# double the answer chunks on the client (the per-message
|
||||
# reconnect dispatcher does not dedup by ``id``).
|
||||
if not replay_done:
|
||||
_read_snapshot_into_buffer()
|
||||
replay_done = True
|
||||
for line in replay_buffer:
|
||||
yield line
|
||||
replay_buffer.clear()
|
||||
if replay_failed:
|
||||
# Mirror the live-tail branch: emit a terminal ``error`` so
|
||||
# the frontend's existing end/error handling drives the UI
|
||||
# to a failed state instead of relying on the proxy timeout.
|
||||
yield format_sse_event(
|
||||
{
|
||||
"type": "error",
|
||||
"error": "Stream replay failed; please refresh to load the latest state.",
|
||||
"code": "snapshot_failed",
|
||||
"message_id": message_id,
|
||||
},
|
||||
sequence_no=-1,
|
||||
)
|
||||
return
|
||||
# Same close-on-terminal contract as the live-tail branch.
|
||||
# Without it a Redis-down + already-completed-stream client
|
||||
# would also hang on a never-ending generator.
|
||||
if terminal_in_snapshot:
|
||||
return
|
||||
except GeneratorExit:
|
||||
# Client disconnect — let the underlying ``Topic.subscribe``
|
||||
# ``finally`` block tear down its pubsub cleanly.
|
||||
return
|
||||
|
||||
|
||||
def _decode_pubsub_message(raw) -> Optional[dict]:
|
||||
"""Parse a ``Topic.publish`` payload to ``{sequence_no, payload, ...}``.
|
||||
|
||||
Returns ``None`` for malformed messages (drop silently — the
|
||||
journal is still authoritative on reconnect).
|
||||
"""
|
||||
try:
|
||||
if isinstance(raw, (bytes, bytearray)):
|
||||
text_value = raw.decode("utf-8")
|
||||
else:
|
||||
text_value = str(raw)
|
||||
envelope = json.loads(text_value)
|
||||
except Exception:
|
||||
return None
|
||||
if not isinstance(envelope, dict):
|
||||
return None
|
||||
return envelope
|
||||
|
||||
|
||||
def encode_pubsub_message(
|
||||
message_id: str,
|
||||
sequence_no: int,
|
||||
event_type: str,
|
||||
payload: dict,
|
||||
) -> str:
|
||||
"""Build the JSON envelope used for ``channel:{message_id}`` publishes.
|
||||
|
||||
Kept here (not in ``message_journal.py``) so the encode/decode pair
|
||||
stays in one file — replay's ``_decode_pubsub_message`` and the
|
||||
journal's publish must agree on the shape exactly.
|
||||
"""
|
||||
return json.dumps(
|
||||
{
|
||||
"message_id": str(message_id),
|
||||
"sequence_no": int(sequence_no),
|
||||
"event_type": event_type,
|
||||
"payload": payload,
|
||||
}
|
||||
)
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Per-chat-message stream key derivations.
|
||||
|
||||
Single source of truth for the Redis pub/sub topic name and any
|
||||
auxiliary keys that the chat-stream snapshot+tail reconnect path
|
||||
shares between the writer (``complete_stream`` + journal) and the
|
||||
reader (``/api/messages/<id>/events`` reconnect endpoint).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def message_topic_name(message_id: str) -> str:
|
||||
"""Redis pub/sub channel for live fan-out of one chat message.
|
||||
|
||||
Subscribers tail this topic for every event that ``complete_stream``
|
||||
yielded after the SUBSCRIBE-ack arrived; older events are recovered
|
||||
from the ``message_events`` snapshot half of the pattern.
|
||||
"""
|
||||
return f"channel:{message_id}"
|
||||
@@ -1,400 +0,0 @@
|
||||
"""Per-yield journal write for the chat-stream snapshot+tail pattern.
|
||||
|
||||
``record_event`` inserts into ``message_events`` and publishes to
|
||||
``channel:{message_id}``. Both are best-effort; the INSERT commits
|
||||
before the publish so a fast reconnect sees the row. See
|
||||
``docs/runbooks/sse-notifications.md``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from application.storage.db.repositories.message_events import (
|
||||
MessageEventsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.streaming.broadcast_channel import Topic
|
||||
from application.streaming.event_replay import encode_pubsub_message
|
||||
from application.streaming.keys import message_topic_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tunables for ``BatchedJournalWriter``. A streaming answer emits ~100s
|
||||
# of ``answer`` chunks per response; without batching, that's one PG
|
||||
# transaction per yield in the WSGI thread. With these defaults, ~10x
|
||||
# fewer commits at the cost of a ≤100ms reconnect-visibility lag for
|
||||
# any event still sitting in the buffer.
|
||||
DEFAULT_BATCH_SIZE = 16
|
||||
DEFAULT_BATCH_INTERVAL_MS = 100
|
||||
|
||||
|
||||
def _strip_null_bytes(value: Any) -> Any:
|
||||
"""Recursively strip ``\\x00`` from string keys/values in ``value``.
|
||||
|
||||
Postgres JSONB rejects the NUL escape; an LLM emitting a stray NUL
|
||||
in a chunk would otherwise raise ``DataError`` at INSERT and the row
|
||||
would be lost from the journal (live stream proceeds, reconnect
|
||||
snapshot misses the chunk). Mirrors the strip already done in
|
||||
``parser/embedding_pipeline.py`` and
|
||||
``api/user/attachments/routes.py``.
|
||||
"""
|
||||
if isinstance(value, str):
|
||||
return value.replace("\x00", "") if "\x00" in value else value
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
(k.replace("\x00", "") if isinstance(k, str) and "\x00" in k else k):
|
||||
_strip_null_bytes(v)
|
||||
for k, v in value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [_strip_null_bytes(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return tuple(_strip_null_bytes(item) for item in value)
|
||||
return value
|
||||
|
||||
|
||||
def record_event(
|
||||
message_id: str,
|
||||
sequence_no: int,
|
||||
event_type: str,
|
||||
payload: Optional[dict[str, Any]] = None,
|
||||
) -> bool:
|
||||
"""Journal one SSE event and publish it live. Best-effort.
|
||||
|
||||
``payload`` must be a ``dict`` or ``None`` (non-dicts are dropped so
|
||||
live and replay envelopes stay byte-identical). Returns ``True`` when
|
||||
the journal INSERT committed. Never raises.
|
||||
"""
|
||||
if not message_id or not event_type:
|
||||
logger.warning(
|
||||
"record_event called without message_id/event_type "
|
||||
"(message_id=%r, event_type=%r)",
|
||||
message_id,
|
||||
event_type,
|
||||
)
|
||||
return False
|
||||
|
||||
if payload is None:
|
||||
materialised_payload: dict[str, Any] = {}
|
||||
elif isinstance(payload, dict):
|
||||
materialised_payload = _strip_null_bytes(payload)
|
||||
else:
|
||||
logger.warning(
|
||||
"record_event called with non-dict payload "
|
||||
"(message_id=%s seq=%s type=%s payload_type=%s) — dropping",
|
||||
message_id,
|
||||
sequence_no,
|
||||
event_type,
|
||||
type(payload).__name__,
|
||||
)
|
||||
return False
|
||||
|
||||
journal_committed = False
|
||||
# The seq we actually managed to write. Diverges from
|
||||
# ``sequence_no`` only on the IntegrityError-retry path below.
|
||||
materialised_seq = sequence_no
|
||||
try:
|
||||
# Short-lived per-event transaction. Critical for visibility:
|
||||
# the reconnect endpoint reads the journal from a separate
|
||||
# connection and only sees committed rows.
|
||||
with db_session() as conn:
|
||||
MessageEventsRepository(conn).record(
|
||||
message_id, sequence_no, event_type, materialised_payload
|
||||
)
|
||||
journal_committed = True
|
||||
except IntegrityError:
|
||||
# Composite-PK collision on (message_id, sequence_no). Most
|
||||
# likely cause is a stale ``latest_sequence_no`` seed on a
|
||||
# continuation retry — the route read MAX(seq) from a separate
|
||||
# connection before another writer committed past it. Look up
|
||||
# the live latest and retry once with latest+1 so the event is
|
||||
# not silently lost. Bounded to a single retry — if two
|
||||
# writers keep racing in lockstep the route-level retry will
|
||||
# converge them across attempts.
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
latest = MessageEventsRepository(conn).latest_sequence_no(
|
||||
message_id
|
||||
)
|
||||
materialised_seq = (latest if latest is not None else -1) + 1
|
||||
with db_session() as conn:
|
||||
MessageEventsRepository(conn).record(
|
||||
message_id,
|
||||
materialised_seq,
|
||||
event_type,
|
||||
materialised_payload,
|
||||
)
|
||||
journal_committed = True
|
||||
logger.info(
|
||||
"record_event: collision at seq=%s recovered → wrote at "
|
||||
"seq=%s message_id=%s type=%s",
|
||||
sequence_no,
|
||||
materialised_seq,
|
||||
message_id,
|
||||
event_type,
|
||||
)
|
||||
except IntegrityError:
|
||||
# Second collision under the same retry — give up and log.
|
||||
# The route's nonlocal counter will continue at
|
||||
# ``sequence_no+1`` on the next emit; the next call may
|
||||
# land cleanly past the contended window.
|
||||
logger.warning(
|
||||
"record_event: IntegrityError persists after seq+1 retry; "
|
||||
"dropping. message_id=%s original_seq=%s retry_seq=%s "
|
||||
"type=%s",
|
||||
message_id,
|
||||
sequence_no,
|
||||
materialised_seq,
|
||||
event_type,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"record_event: retry path failed unexpectedly "
|
||||
"(message_id=%s seq=%s type=%s)",
|
||||
message_id,
|
||||
sequence_no,
|
||||
event_type,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"message_events INSERT failed: message_id=%s seq=%s type=%s",
|
||||
message_id,
|
||||
sequence_no,
|
||||
event_type,
|
||||
)
|
||||
|
||||
try:
|
||||
# Publish using ``materialised_seq`` so the live pubsub frame
|
||||
# matches the journal row that other clients will snapshot on
|
||||
# reconnect. The original POST stream's SSE ``id:`` still
|
||||
# carries the caller's ``sequence_no`` — a reconnect from that
|
||||
# client will receive the same event at ``materialised_seq``
|
||||
# on the snapshot, which is a benign duplicate (the slice's
|
||||
# ``max_replayed_seq`` advances past it). No-collision case:
|
||||
# ``materialised_seq == sequence_no`` and this is identical to
|
||||
# the prior behaviour.
|
||||
wire = encode_pubsub_message(
|
||||
message_id, materialised_seq, event_type, materialised_payload
|
||||
)
|
||||
Topic(message_topic_name(message_id)).publish(wire)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"channel:%s publish failed: seq=%s type=%s",
|
||||
message_id,
|
||||
materialised_seq,
|
||||
event_type,
|
||||
)
|
||||
|
||||
return journal_committed
|
||||
|
||||
|
||||
class BatchedJournalWriter:
|
||||
"""Per-stream journal writer that batches PG INSERTs.
|
||||
|
||||
One writer per ``message_id``; ``record()`` buffers events and flushes
|
||||
on size/time/``close()`` triggers. Pubsub publishes fire only after the
|
||||
INSERT commits. On ``IntegrityError`` falls back to per-row writes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
*,
|
||||
batch_size: int = DEFAULT_BATCH_SIZE,
|
||||
batch_interval_ms: int = DEFAULT_BATCH_INTERVAL_MS,
|
||||
) -> None:
|
||||
self._message_id = message_id
|
||||
self._batch_size = batch_size
|
||||
self._batch_interval_ms = batch_interval_ms
|
||||
self._buffer: list[tuple[int, str, dict[str, Any]]] = []
|
||||
self._last_flush_mono_ms = time.monotonic() * 1000.0
|
||||
self._closed = False
|
||||
|
||||
def record(
|
||||
self,
|
||||
sequence_no: int,
|
||||
event_type: str,
|
||||
payload: Optional[dict[str, Any]] = None,
|
||||
) -> bool:
|
||||
"""Buffer one event; maybe flush. Publish happens after journal commit."""
|
||||
if self._closed:
|
||||
logger.warning(
|
||||
"BatchedJournalWriter.record after close: "
|
||||
"message_id=%s seq=%s type=%s",
|
||||
self._message_id,
|
||||
sequence_no,
|
||||
event_type,
|
||||
)
|
||||
return False
|
||||
if not event_type:
|
||||
logger.warning(
|
||||
"BatchedJournalWriter.record without event_type: "
|
||||
"message_id=%s seq=%s",
|
||||
self._message_id,
|
||||
sequence_no,
|
||||
)
|
||||
return False
|
||||
if payload is None:
|
||||
materialised: dict[str, Any] = {}
|
||||
elif isinstance(payload, dict):
|
||||
materialised = _strip_null_bytes(payload)
|
||||
else:
|
||||
# Same contract as ``record_event`` — non-dict payloads
|
||||
# are rejected so the live and replay paths can't diverge
|
||||
# on envelope reconstruction.
|
||||
logger.warning(
|
||||
"BatchedJournalWriter.record with non-dict payload: "
|
||||
"message_id=%s seq=%s type=%s payload_type=%s — dropping",
|
||||
self._message_id,
|
||||
sequence_no,
|
||||
event_type,
|
||||
type(payload).__name__,
|
||||
)
|
||||
return False
|
||||
|
||||
self._buffer.append((sequence_no, event_type, materialised))
|
||||
|
||||
if self._should_flush():
|
||||
self.flush()
|
||||
return True
|
||||
|
||||
def _should_flush(self) -> bool:
|
||||
if len(self._buffer) >= self._batch_size:
|
||||
return True
|
||||
elapsed_ms = (time.monotonic() * 1000.0) - self._last_flush_mono_ms
|
||||
return elapsed_ms >= self._batch_interval_ms and len(self._buffer) > 0
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Commit buffered events to PG. Best-effort.
|
||||
|
||||
Tries one bulk INSERT first; on ``IntegrityError`` (composite
|
||||
PK collision — typically a stale continuation seed) falls back
|
||||
to per-row ``record_event`` so one bad seq doesn't drop the
|
||||
rest of the batch. Always clears the buffer to bound memory,
|
||||
even on failure — a journaled event missing from a snapshot
|
||||
is degraded UX, but a runaway buffer is corruption.
|
||||
"""
|
||||
if not self._buffer:
|
||||
self._last_flush_mono_ms = time.monotonic() * 1000.0
|
||||
return
|
||||
|
||||
# Snapshot and clear before the I/O so a concurrent record()
|
||||
# call would land in a fresh buffer rather than racing the
|
||||
# flush. ``complete_stream`` is single-threaded per stream, so
|
||||
# this is belt-and-suspenders for any future change.
|
||||
pending = self._buffer
|
||||
self._buffer = []
|
||||
self._last_flush_mono_ms = time.monotonic() * 1000.0
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
MessageEventsRepository(conn).bulk_record(
|
||||
self._message_id, pending
|
||||
)
|
||||
except IntegrityError:
|
||||
logger.info(
|
||||
"BatchedJournalWriter: bulk INSERT collided for "
|
||||
"message_id=%s n=%d; falling back to per-row writes",
|
||||
self._message_id,
|
||||
len(pending),
|
||||
)
|
||||
self._flush_per_row(pending)
|
||||
return
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"BatchedJournalWriter: bulk INSERT failed for "
|
||||
"message_id=%s n=%d; events dropped from journal",
|
||||
self._message_id,
|
||||
len(pending),
|
||||
)
|
||||
return
|
||||
|
||||
# Bulk INSERT committed — publish each frame in order. Best-effort:
|
||||
# one failed publish must not poison the rest of the batch.
|
||||
for seq, event_type, payload in pending:
|
||||
self._publish(seq, event_type, payload)
|
||||
|
||||
def _flush_per_row(
|
||||
self, pending: list[tuple[int, str, dict[str, Any]]]
|
||||
) -> None:
|
||||
"""Per-row fallback after a bulk collision. Publishes after each commit."""
|
||||
for seq, event_type, payload in pending:
|
||||
committed_seq: Optional[int] = None
|
||||
try:
|
||||
with db_session() as conn:
|
||||
MessageEventsRepository(conn).record(
|
||||
self._message_id, seq, event_type, payload
|
||||
)
|
||||
committed_seq = seq
|
||||
except IntegrityError:
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
latest = MessageEventsRepository(
|
||||
conn
|
||||
).latest_sequence_no(self._message_id)
|
||||
retry_seq = (latest if latest is not None else -1) + 1
|
||||
with db_session() as conn:
|
||||
MessageEventsRepository(conn).record(
|
||||
self._message_id, retry_seq, event_type, payload
|
||||
)
|
||||
committed_seq = retry_seq
|
||||
except IntegrityError:
|
||||
logger.warning(
|
||||
"BatchedJournalWriter: IntegrityError persists "
|
||||
"after seq+1 retry; dropping. message_id=%s "
|
||||
"original_seq=%s type=%s",
|
||||
self._message_id,
|
||||
seq,
|
||||
event_type,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"BatchedJournalWriter: per-row retry failed "
|
||||
"(message_id=%s seq=%s type=%s)",
|
||||
self._message_id,
|
||||
seq,
|
||||
event_type,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"BatchedJournalWriter: per-row INSERT failed "
|
||||
"(message_id=%s seq=%s type=%s)",
|
||||
self._message_id,
|
||||
seq,
|
||||
event_type,
|
||||
)
|
||||
|
||||
if committed_seq is not None:
|
||||
self._publish(committed_seq, event_type, payload)
|
||||
|
||||
def _publish(
|
||||
self, sequence_no: int, event_type: str, payload: dict[str, Any]
|
||||
) -> None:
|
||||
"""Publish one frame to the per-message pubsub channel. Best-effort."""
|
||||
try:
|
||||
wire = encode_pubsub_message(
|
||||
self._message_id, sequence_no, event_type, payload
|
||||
)
|
||||
Topic(message_topic_name(self._message_id)).publish(wire)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"channel:%s publish failed: seq=%s type=%s",
|
||||
self._message_id,
|
||||
sequence_no,
|
||||
event_type,
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Final flush. Idempotent — safe to call from multiple
|
||||
finally clauses.
|
||||
"""
|
||||
if self._closed:
|
||||
return
|
||||
self.flush()
|
||||
self._closed = True
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,385 +0,0 @@
|
||||
# SSE Notifications Runbook
|
||||
|
||||
> Operations guide for "user says they didn't get a notification" — and
|
||||
> the related "the bell never lights up" / "my upload toast hangs" /
|
||||
> "the chat answer doesn't reconnect" symptoms.
|
||||
|
||||
The user-facing notifications channel is the SSE pipe at
|
||||
`/api/events` plus per-message reconnects at
|
||||
`/api/messages/<id>/events`. This document maps a user complaint to
|
||||
the diagnostic that surfaces the cause.
|
||||
|
||||
---
|
||||
|
||||
## TL;DR — first 60 seconds
|
||||
|
||||
Run these three commands in parallel before anything else:
|
||||
|
||||
```bash
|
||||
# 1) Is Redis up and serving the pipe? Should print PONG instantly.
|
||||
redis-cli -n 2 PING
|
||||
|
||||
# 2) Anyone subscribed to the channel right now? Numbers per channel.
|
||||
redis-cli -n 2 PUBSUB NUMSUB user:<user_id>
|
||||
|
||||
# 3) Is the user's backlog populated? Returns the count of journaled events.
|
||||
redis-cli -n 2 XLEN user:<user_id>:stream
|
||||
```
|
||||
|
||||
- `PING` failing → Redis is the problem. Skip to "Redis-down".
|
||||
- `NUMSUB user:<user_id>` returns 0 → no client connected. Skip to "Client never connects".
|
||||
- `XLEN user:<user_id>:stream` returns 0 or low → publisher isn't writing. Skip to "Publisher silent".
|
||||
- All three look healthy → the events are flowing on the wire; the issue is downstream of the slice (UI rendering, toast suppression, etc.). Skip to "Events flowing but UI silent".
|
||||
|
||||
---
|
||||
|
||||
## Architecture cheat-sheet
|
||||
|
||||
```
|
||||
Worker (publish_user_event) Frontend tab
|
||||
│ ▲
|
||||
▼ │ GET /api/events SSE
|
||||
Redis Streams: XADD Flask route
|
||||
user:<id>:stream ──────────────► replay_backlog (snapshot)
|
||||
│ +
|
||||
▼ Topic.subscribe (live tail)
|
||||
Redis pub/sub: PUBLISH │
|
||||
user:<id> ────────────────────────────────┘
|
||||
```
|
||||
|
||||
**Source of truth:**
|
||||
- Persistent journal: Redis Stream `user:<user_id>:stream`, capped at
|
||||
`EVENTS_STREAM_MAXLEN` (default 1000) entries via `MAXLEN ~`. ~24h
|
||||
at typical event rates.
|
||||
- Live fan-out: Redis pub/sub channel `user:<user_id>`. No durability;
|
||||
subscribers must be attached at publish time.
|
||||
|
||||
The chat-stream pipe is separate, parallel infrastructure:
|
||||
- Journal: Postgres `message_events` table.
|
||||
- Live fan-out: Redis pub/sub `channel:<message_id>`.
|
||||
|
||||
Same patterns, different durability layer. This doc covers both;
|
||||
they share most diagnostic commands.
|
||||
|
||||
---
|
||||
|
||||
## Symptom → diagnostic map
|
||||
|
||||
### A. "I uploaded a source and the toast never appeared"
|
||||
|
||||
User flow: chat → upload → expect toast.
|
||||
|
||||
| Step | Command | Expect |
|
||||
| ------------------------------------------------- | ------------------------------------------------------------- | ----------------------------------------------- |
|
||||
| Worker received the task | `tail -f celery.log` filtered by user | `ingest_worker` start log line |
|
||||
| Worker published the queued event | `redis-cli -n 2 XREVRANGE user:<id>:stream + - COUNT 5` | A `source.ingest.queued` entry within seconds |
|
||||
| Frontend got it | DevTools → Network → `/api/events` → EventStream tab | `data: {"type":"source.ingest.queued",...}` |
|
||||
| Slice updated | Redux DevTools → state.upload.tasks | Task with matching `sourceId`, `status:'training'` |
|
||||
|
||||
If the worker's queued log line is there but the XADD didn't land →
|
||||
look for a `publish_user_event payload not JSON-serializable` warning
|
||||
in the worker log (the publisher swallows `TypeError`).
|
||||
|
||||
If the XADD landed but the frontend never received it → check
|
||||
`PUBSUB NUMSUB user:<id>` while the user is on the page. If 0, the
|
||||
SSE connection isn't subscribed; skip to "Client never connects".
|
||||
|
||||
If the frontend received it but the toast didn't render → the
|
||||
`uploadSlice` extraReducer requires `task.sourceId` to match the
|
||||
event's `scope.id`. Check the upload route returned `source_id` in
|
||||
its POST response (the upload, connector, and reingest paths all
|
||||
include it). Idempotent / cached responses must also include
|
||||
`source_id` (`_claim_task_or_get_cached`).
|
||||
|
||||
### B. "The bell badge never goes up"
|
||||
|
||||
There is no bell — the global notifications surface is per-event
|
||||
toasts, not an aggregated counter. If the user is on an old build,
|
||||
`Cmd-Shift-R` to bypass cache. The surfaces they're looking for are
|
||||
`UploadToast` for source uploads and `ToolApprovalToast` for
|
||||
tool-approval events.
|
||||
|
||||
### C. "My chat answer froze mid-stream and never recovered"
|
||||
|
||||
User flow: ask question → answer streaming → network blip → answer
|
||||
stops; should reconnect.
|
||||
|
||||
```bash
|
||||
# Was the original message reserved in PG?
|
||||
psql -c "SELECT id, status, prompt FROM conversation_messages \
|
||||
WHERE user_id = '<user>' ORDER BY timestamp DESC LIMIT 5;"
|
||||
|
||||
# Did the journal capture events past the user's last-seen seq?
|
||||
psql -c "SELECT sequence_no, event_type FROM message_events \
|
||||
WHERE message_id = '<id>' ORDER BY sequence_no;"
|
||||
|
||||
# Is the live tail still producing? (subscribe and watch)
|
||||
redis-cli -n 2 SUBSCRIBE channel:<message_id>
|
||||
```
|
||||
|
||||
The frontend should reconnect via `GET /api/messages/<id>/events`
|
||||
when the original POST stream closes without a typed `end` or
|
||||
`error` event. If it's not reconnecting, `console.warn('Stream
|
||||
reconnect failed', ...)` will be in the browser console — the
|
||||
reconnect HTTP errored. Common cases:
|
||||
|
||||
- The user's JWT rotated mid-stream → 401 on the GET. Frontend
|
||||
doesn't auto-refresh; the user reloads.
|
||||
- The user is on a different host than the API and CORS is rejecting
|
||||
the GET → check `application/asgi.py` allow-headers.
|
||||
|
||||
### D. "The dev install never delivers any notifications at all"
|
||||
|
||||
Default `AUTH_TYPE` unset means `decoded_token = {"sub": "local"}`
|
||||
for every request. The SSE client connects without the
|
||||
`Authorization` header in this case, and `user:local:stream` is
|
||||
the shared channel everything goes to. If the user has multiple dev
|
||||
machines pointing at the same Redis, they will see each other's
|
||||
events. Confirm with:
|
||||
|
||||
```bash
|
||||
redis-cli -n 2 KEYS 'user:local:*'
|
||||
```
|
||||
|
||||
If multiple deployments share the Redis, document that as a known
|
||||
multi-user-on-local-channel limitation. Set `AUTH_TYPE=simple_jwt`
|
||||
to scope per-user.
|
||||
|
||||
### E. "The notifications channel was working, then suddenly stopped after the user reloaded the page"
|
||||
|
||||
Likely path: `backlog.truncated` event fired, the slice cleared
|
||||
`lastEventId` to null, the closure was carrying the same stale id and
|
||||
re-tripped the same truncation on every reconnect. **Verify the user
|
||||
is on a current build — `eventStreamClient.ts` must re-read
|
||||
`lastEventId = opts.getLastEventId();` without a truthy guard so the
|
||||
null clear propagates into the next reconnect.**
|
||||
|
||||
### F. "I keep getting 429 on /api/events"
|
||||
|
||||
The per-user concurrent-connection cap (`SSE_MAX_CONCURRENT_PER_USER`,
|
||||
default 8) refused the connection. User has too many tabs open or a
|
||||
runaway reconnect loop. `redis-cli -n 2 GET user:<id>:sse_count`
|
||||
shows the live counter; the TTL is 1h from the last connection
|
||||
attempt (rolling — every INCR re-seeds it), so the key only ages
|
||||
out after the user stops reconnecting for a full hour.
|
||||
|
||||
If the count is wedged high without explanation, the
|
||||
counter-DECR-in-finally path didn't run (worker SIGKILL, OOM). Wait
|
||||
for the TTL or `redis-cli -n 2 DEL user:<id>:sse_count` to reset.
|
||||
|
||||
### G. "Replay snapshot stops at 200 events"
|
||||
|
||||
The route caps each replay at `EVENTS_REPLAY_MAX_PER_REQUEST`
|
||||
(default 200). The cap is intentionally **silent** — the route does NOT
|
||||
emit a `backlog.truncated` notice for cap-hit. The 200 entries each
|
||||
carry their own `id:` header, so the frontend's slice cursor
|
||||
advances to the most-recent delivered id. Next reconnect sends
|
||||
`last_event_id=<max_replayed>` and the snapshot resumes from there.
|
||||
A user that was 1000 entries behind catches up over ~5 reconnects.
|
||||
|
||||
If the user reports getting HTTP 429 on `/api/events` despite being
|
||||
well under `SSE_MAX_CONCURRENT_PER_USER`, they hit the windowed
|
||||
replay budget (`EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW`, default
|
||||
30 / `EVENTS_REPLAY_BUDGET_WINDOW_SECONDS` 60s). The route refuses
|
||||
the connection so the slice cursor stays pinned at whatever value
|
||||
it had; the frontend backs off and the next reconnect (after the
|
||||
window rolls) gets the proper snapshot. Serving the live tail
|
||||
without a snapshot used to be the behavior here, but that let the
|
||||
client advance `lastEventId` past entries it never received,
|
||||
permanently stranding the un-replayed window — so the route now
|
||||
429s instead. `redis-cli -n 2 GET user:<id>:replay_count` shows the
|
||||
current counter; TTL is the window size.
|
||||
|
||||
`backlog.truncated` is emitted ONLY when the client's
|
||||
`Last-Event-ID` has slid off the MAXLEN'd window — i.e. the journal
|
||||
is genuinely gone past the cursor and the frontend should clear the
|
||||
slice cursor and refetch state. Treating cap-hit or
|
||||
budget-exhaustion the same way would lock the user into re-receiving
|
||||
the oldest 200 entries on every reconnect (the cursor would clear,
|
||||
the snapshot would re-serve from the start, the cap would re-trip).
|
||||
|
||||
### H. "User says push notifications stopped after a deploy"
|
||||
|
||||
- Pull `event.published topic=user:<id> type=...` from the worker
|
||||
logs to confirm the publisher is still firing.
|
||||
- Pull `event.connect user=<id>` from the API logs to confirm the
|
||||
client is reconnecting.
|
||||
- Check the gunicorn worker count and `WSGIMiddleware(workers=32)` —
|
||||
if the deploy reduced worker count, the per-user cap is still 8
|
||||
but total concurrent SSE connections are bounded by `gunicorn
|
||||
workers × 32`. A capacity miss looks like users randomly getting
|
||||
429'd.
|
||||
|
||||
---
|
||||
|
||||
## Common failure modes
|
||||
|
||||
### Redis-down
|
||||
|
||||
Symptoms: `/api/events` returns 200 but emits only `: connected`
|
||||
then the body closes. `XLEN` and `PUBLISH` both fail. The publisher's
|
||||
`record_event` swallows the failure and returns False; the live tail
|
||||
publish also drops on the floor. Frontend retries forever with
|
||||
exponential backoff.
|
||||
|
||||
Resolution: bring Redis back. The journal is gone (was in-memory
|
||||
only — Streams persist within a single Redis instance, no replication
|
||||
configured). New events flow as soon as Redis comes back.
|
||||
|
||||
### `AUTH_TYPE` misconfigured = sub:"local" cross-stream
|
||||
|
||||
Symptoms: every user shares `user:local:stream`. Any user sees
|
||||
everyone else's notifications.
|
||||
|
||||
Resolution: set `AUTH_TYPE=simple_jwt` (or `session_jwt`) in `.env`.
|
||||
The events route logs a one-time WARNING per process when
|
||||
`sub == "local"` is observed. A repeat WARNING after a restart
|
||||
confirms the misconfiguration.
|
||||
|
||||
### MAXLEN trimmed past Last-Event-ID
|
||||
|
||||
Symptoms: client reconnects with `last_event_id=X`, snapshot returns
|
||||
the entire MAXLEN'd backlog (because X is older than the oldest
|
||||
retained entry). Old events appear duplicated.
|
||||
|
||||
Detection: the route's `_oldest_retained_id` check emits
|
||||
`backlog.truncated` when this case fires. Frontend's
|
||||
`dispatchSSEEvent` clears `lastEventId` so the next reconnect starts
|
||||
fresh.
|
||||
|
||||
If the WARNING isn't firing but symptoms match: the user's client
|
||||
may have a corrupt cached `lastEventId`. `localStorage` doesn't
|
||||
store this state; check Redux state via DevTools.
|
||||
|
||||
### Stale event-stream client
|
||||
|
||||
Symptoms: events visible in `XRANGE` but the frontend slice doesn't
|
||||
update.
|
||||
|
||||
```bash
|
||||
# Is the client subscribed?
|
||||
redis-cli -n 2 PUBSUB NUMSUB user:<id>
|
||||
|
||||
# When did its connection start?
|
||||
grep "event.connect user=<id>" /var/log/docsgpt.log | tail -3
|
||||
```
|
||||
|
||||
If `NUMSUB` is 0 and no recent `event.connect`, the user's tab is
|
||||
closed or the connection died and never reconnected. Push them to
|
||||
reload.
|
||||
|
||||
### Publisher silent
|
||||
|
||||
Symptoms: worker is processing the task (Celery says SUCCESS), but
|
||||
no XADD and no PUBLISH. User sees no events.
|
||||
|
||||
```bash
|
||||
# Was the publisher import error suppressed?
|
||||
grep "publish_user_event" /var/log/celery.log | grep -i "warn\|error" | tail -20
|
||||
|
||||
# Is push disabled?
|
||||
grep "ENABLE_SSE_PUSH" /var/log/docsgpt.log | tail -5
|
||||
```
|
||||
|
||||
`ENABLE_SSE_PUSH=False` in `.env` would silence the publisher
|
||||
globally. Useful for incident response if a runaway publisher is
|
||||
DoS'ing Redis; toggle off, fix root cause, toggle on.
|
||||
|
||||
---
|
||||
|
||||
## Useful one-liners
|
||||
|
||||
```bash
|
||||
# Watch a user's live event stream in real time (all events, all types)
|
||||
redis-cli -n 2 PSUBSCRIBE 'user:*' | grep "user:<id>"
|
||||
|
||||
# Last 10 events the user would see on reconnect
|
||||
redis-cli -n 2 XREVRANGE user:<id>:stream + - COUNT 10
|
||||
|
||||
# Live count of subscribed clients per user
|
||||
redis-cli -n 2 PUBSUB NUMSUB $(redis-cli -n 2 PUBSUB CHANNELS 'user:*')
|
||||
|
||||
# Trim a runaway stream (CAREFUL — destroys backlog for all current
|
||||
# subscribers; OK after explaining to the user)
|
||||
redis-cli -n 2 XTRIM user:<id>:stream MAXLEN 0
|
||||
|
||||
# Clear a wedged concurrent-connection counter
|
||||
redis-cli -n 2 DEL user:<id>:sse_count
|
||||
|
||||
# Force-flip every client to re-snapshot (drop the stream key entirely
|
||||
# — destroys the backlog; clients reconnect with their last id and
|
||||
# get a backlog.truncated)
|
||||
redis-cli -n 2 DEL user:<id>:stream
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Settings reference
|
||||
|
||||
Everything in `application/core/settings.py`:
|
||||
|
||||
| Setting | Default | Purpose |
|
||||
| --------------------------------------------- | ------- | --------------------------------------------- |
|
||||
| `ENABLE_SSE_PUSH` | `True` | Master switch. False = publisher no-ops, route serves "push_disabled" comment. |
|
||||
| `EVENTS_STREAM_MAXLEN` | `1000` | Per-user backlog cap. Approximate via `XADD MAXLEN ~`. |
|
||||
| `SSE_KEEPALIVE_SECONDS` | `15` | Comment-frame cadence. Must sit under reverse-proxy idle close. |
|
||||
| `SSE_MAX_CONCURRENT_PER_USER` | `8` | Cap on simultaneous SSE connections per user. 0 = disabled. |
|
||||
| `EVENTS_REPLAY_MAX_PER_REQUEST` | `200` | Hard cap on snapshot rows per request. |
|
||||
| `EVENTS_REPLAY_BUDGET_REQUESTS_PER_WINDOW` | `30` | Per-user replays per window. 0 = disabled. |
|
||||
| `EVENTS_REPLAY_BUDGET_WINDOW_SECONDS` | `60` | Window length. |
|
||||
| `MESSAGE_EVENTS_RETENTION_DAYS` | `14` | Retention for the `message_events` journal; `cleanup_message_events` beat task deletes older rows. |
|
||||
|
||||
---
|
||||
|
||||
## Known limitations
|
||||
|
||||
### Each tab runs its own SSE connection
|
||||
|
||||
There is no cross-tab dedup. Every tab open to the app holds its
|
||||
own SSE connection and dispatches every received event into its
|
||||
own Redux store, so a user with N tabs open will see N copies of
|
||||
each toast. With `SSE_MAX_CONCURRENT_PER_USER=8` (the default) a
|
||||
heavy multi-tab user can also hit the connection cap and start
|
||||
seeing 429s. Cross-tab dedup via a `BroadcastChannel` ring +
|
||||
`navigator.locks`-based leader election is tracked as future work.
|
||||
|
||||
### `/c/<unknown-id>` normalises to `/c/new`
|
||||
|
||||
If a user navigates to a conversation id that isn't in their
|
||||
loaded list, the conversation route rewrites the URL to `/c/new`.
|
||||
`ToolApprovalToast`'s gate uses `useMatch('/c/:conversationId')`,
|
||||
so for the brief window after the rewrite the toast may surface
|
||||
for a conversation the user *thought* they were already viewing.
|
||||
Pre-existing route behaviour; not a notifications regression.
|
||||
|
||||
### Terminal events un-dismiss running uploads
|
||||
|
||||
`frontend/src/upload/uploadSlice.ts` sets `dismissed: false` when
|
||||
an upload reaches `completed` or `failed`. If the user dismissed a
|
||||
running task and the terminal SSE arrives later, the toast pops
|
||||
back. Intentional ("notify the user it's done"); revisit if the
|
||||
re-surface UX is too aggressive for v2.
|
||||
|
||||
### Werkzeug doesn't auto-reload route files
|
||||
|
||||
The dev server (`flask run`) doesn't watch
|
||||
`application/api/events/routes.py` for changes by default.
|
||||
After editing the route, restart Flask manually — `--reload`
|
||||
isn't on. (Production gunicorn reloads via deploy.)
|
||||
|
||||
### MCP OAuth completion can fall outside the user stream's MAXLEN window
|
||||
|
||||
`get_oauth_status` scans up to `EVENTS_STREAM_MAXLEN` (~1000) entries via `XREVRANGE`. If the user has a high-rate ingest running concurrent with the OAuth handshake, the `mcp.oauth.completed` envelope can be trimmed off the back before they click Save. Symptom: backend returns "OAuth failed or not completed" even though the popup completed successfully.
|
||||
|
||||
Mitigation today: bump `EVENTS_STREAM_MAXLEN` per-deployment if your users routinely flood the channel during OAuth flows. A dedicated short-TTL Redis key for OAuth task results is tracked as a follow-up.
|
||||
|
||||
### React StrictMode double-mounts SSE
|
||||
|
||||
In dev, React 18 StrictMode mounts → unmounts → remounts every
|
||||
component, briefly opening two SSE connections per tab before the
|
||||
first is aborted. With `SSE_MAX_CONCURRENT_PER_USER=8` and 4–5
|
||||
tabs open concurrently you can transiently hit the cap and see
|
||||
HTTP 429 on cold-load. The first connection's counter increment
|
||||
fires before the AbortController-induced disconnect can decrement
|
||||
it. Production (single mount, no StrictMode) is unaffected; raise
|
||||
the cap in dev or accept transient 429s.
|
||||
1531
frontend/package-lock.json
generated
1531
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -7,8 +7,6 @@
|
||||
"dev": "vite",
|
||||
"build": "tsc && vite build",
|
||||
"preview": "vite preview",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest",
|
||||
"lint": "eslint ./src --ext .jsx,.js,.ts,.tsx",
|
||||
"lint-fix": "eslint ./src --ext .jsx,.js,.ts,.tsx --fix",
|
||||
"format": "prettier ./src --write",
|
||||
@@ -71,7 +69,6 @@
|
||||
"eslint-plugin-promise": "^6.6.0",
|
||||
"eslint-plugin-react": "^7.37.5",
|
||||
"eslint-plugin-unused-imports": "^4.1.4",
|
||||
"happy-dom": "^17.6.3",
|
||||
"husky": "^9.1.7",
|
||||
"lint-staged": "^16.4.0",
|
||||
"postcss": "^8.5.12",
|
||||
@@ -81,7 +78,6 @@
|
||||
"tw-animate-css": "^1.4.0",
|
||||
"typescript": "^6.0.3",
|
||||
"vite": "^8.0.10",
|
||||
"vite-plugin-svgr": "^4.3.0",
|
||||
"vitest": "^3.2.4"
|
||||
"vite-plugin-svgr": "^4.3.0"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import Spinner from './components/Spinner';
|
||||
import UploadToast from './components/UploadToast';
|
||||
import Conversation from './conversation/Conversation';
|
||||
import { SharedConversation } from './conversation/SharedConversation';
|
||||
import { EventStreamProvider } from './events/EventStreamProvider';
|
||||
import { useDarkTheme, useMediaQuery } from './hooks';
|
||||
import useDataInitializer from './hooks/useDataInitializer';
|
||||
import useTokenAuth from './hooks/useTokenAuth';
|
||||
@@ -18,7 +17,6 @@ import Navigation from './Navigation';
|
||||
import PageNotFound from './PageNotFound';
|
||||
import Setting from './settings';
|
||||
import Notification from './components/Notification';
|
||||
import ToolApprovalToast from './notifications/ToolApprovalToast';
|
||||
|
||||
function AuthWrapper({ children }: { children: React.ReactNode }) {
|
||||
const { isAuthLoading } = useTokenAuth();
|
||||
@@ -31,7 +29,7 @@ function AuthWrapper({ children }: { children: React.ReactNode }) {
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return <EventStreamProvider>{children}</EventStreamProvider>;
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
function MainLayout() {
|
||||
@@ -52,7 +50,6 @@ function MainLayout() {
|
||||
<Outlet />
|
||||
</div>
|
||||
<UploadToast />
|
||||
<ToolApprovalToast />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import Github from './assets/git_nav.svg';
|
||||
import Hamburger from './assets/hamburger.svg';
|
||||
import openNewChat from './assets/openNewChat.svg';
|
||||
import Pin from './assets/pin.svg';
|
||||
import SearchIcon from './assets/search.svg';
|
||||
import AgentImage from './components/AgentImage';
|
||||
import SettingGear from './assets/settingGear.svg';
|
||||
import Spark from './assets/spark.svg';
|
||||
@@ -35,6 +36,7 @@ import { useDarkTheme, useMediaQuery } from './hooks';
|
||||
import useTokenAuth from './hooks/useTokenAuth';
|
||||
import DeleteConvModal from './modals/DeleteConvModal';
|
||||
import JWTModal from './modals/JWTModal';
|
||||
import SearchConversationsModal from './modals/SearchConversationsModal';
|
||||
import { ActiveState } from './models/misc';
|
||||
import { getConversations } from './preferences/preferenceApi';
|
||||
import {
|
||||
@@ -82,6 +84,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
const [uploadModalState, setUploadModalState] =
|
||||
useState<ActiveState>('INACTIVE');
|
||||
const [recentAgents, setRecentAgents] = useState<Agent[]>([]);
|
||||
const [searchOpen, setSearchOpen] = useState(false);
|
||||
|
||||
const navRef = useRef<HTMLDivElement>(null);
|
||||
useEffect(() => {
|
||||
@@ -104,6 +107,19 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
};
|
||||
}
|
||||
}, [navOpen, isMobile, isTablet, setNavOpen]);
|
||||
|
||||
useEffect(() => {
|
||||
function handleSearchShortcut(event: KeyboardEvent) {
|
||||
if ((event.ctrlKey || event.metaKey) && event.key.toLowerCase() === 'k') {
|
||||
event.preventDefault();
|
||||
setSearchOpen(true);
|
||||
}
|
||||
}
|
||||
|
||||
document.addEventListener('keydown', handleSearchShortcut);
|
||||
return () => document.removeEventListener('keydown', handleSearchShortcut);
|
||||
}, []);
|
||||
|
||||
async function fetchRecentAgents() {
|
||||
try {
|
||||
const response = await userService.getPinnedAgents(token);
|
||||
@@ -506,11 +522,23 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
)}
|
||||
{conversations?.data && conversations.data.length > 0 ? (
|
||||
<div className="mt-7">
|
||||
<div className="mx-4 my-auto mt-2 flex h-6 items-center justify-between gap-4 rounded-3xl">
|
||||
<div className="mx-4 my-auto mt-2 flex h-8 items-center justify-between gap-4 rounded-3xl">
|
||||
<p className="mt-1 ml-4 text-sm font-semibold">{t('chats')}</p>
|
||||
<button
|
||||
onClick={() => setSearchOpen(true)}
|
||||
className="hover:bg-sidebar-accent mr-2 flex h-7 w-7 items-center justify-center rounded-full"
|
||||
aria-label={t('modals.searchConversations.searchPlaceholder')}
|
||||
title={t('modals.searchConversations.searchPlaceholder')}
|
||||
>
|
||||
<img
|
||||
src={SearchIcon}
|
||||
alt="search"
|
||||
className="h-4 w-4 opacity-70"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
<div className="conversations-container">
|
||||
{conversations.data?.map((conversation) => (
|
||||
{(conversations.data ?? []).map((conversation) => (
|
||||
<ConversationTile
|
||||
key={conversation.id}
|
||||
conversation={conversation}
|
||||
@@ -644,6 +672,17 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
||||
modalState={showTokenModal ? 'ACTIVE' : 'INACTIVE'}
|
||||
handleTokenSubmit={handleTokenSubmit}
|
||||
/>
|
||||
{searchOpen && (
|
||||
<SearchConversationsModal
|
||||
close={() => setSearchOpen(false)}
|
||||
conversations={conversations?.data ?? []}
|
||||
token={token}
|
||||
onSelectConversation={(id) => {
|
||||
handleConversationClick(id);
|
||||
if (isMobile || isTablet) setNavOpen(false);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -28,13 +28,13 @@ const endpoints = {
|
||||
UPDATE_PROMPT: '/api/update_prompt',
|
||||
SINGLE_PROMPT: (id: string) => `/api/get_single_prompt?id=${id}`,
|
||||
DELETE_PATH: (docPath: string) => `/api/delete_old?source_id=${docPath}`,
|
||||
TASK_STATUS: (task_id: string) => `/api/task_status?task_id=${task_id}`,
|
||||
MESSAGE_ANALYTICS: '/api/get_message_analytics',
|
||||
TOKEN_ANALYTICS: '/api/get_token_analytics',
|
||||
FEEDBACK_ANALYTICS: '/api/get_feedback_analytics',
|
||||
LOGS: `/api/get_user_logs`,
|
||||
MANAGE_SYNC: '/api/manage_sync',
|
||||
SYNC_SOURCE: '/api/sync_source',
|
||||
REINGEST_SOURCE: '/api/sources/reingest',
|
||||
GET_AVAILABLE_TOOLS: '/api/available_tools',
|
||||
GET_USER_TOOLS: '/api/get_tools',
|
||||
CREATE_TOOL: '/api/create_tool',
|
||||
@@ -73,6 +73,8 @@ const endpoints = {
|
||||
MANAGE_SOURCE_FILES: '/api/manage_source_files',
|
||||
MCP_TEST_CONNECTION: '/api/mcp_server/test',
|
||||
MCP_SAVE_SERVER: '/api/mcp_server/save',
|
||||
MCP_OAUTH_STATUS: (task_id: string) =>
|
||||
`/api/mcp_server/oauth_status/${task_id}`,
|
||||
MCP_AUTH_STATUS: '/api/mcp_server/auth_status',
|
||||
AGENT_FOLDERS: '/api/agents/folders/',
|
||||
AGENT_FOLDER: (id: string) => `/api/agents/folders/${id}`,
|
||||
@@ -96,6 +98,8 @@ const endpoints = {
|
||||
FEEDBACK: '/api/feedback',
|
||||
CONVERSATION: (id: string) => `/api/get_single_conversation?id=${id}`,
|
||||
CONVERSATIONS: '/api/get_conversations',
|
||||
SEARCH_CONVERSATIONS: (q: string, limit = 30) =>
|
||||
`/api/search_conversations?q=${encodeURIComponent(q)}&limit=${limit}`,
|
||||
MESSAGE_TAIL: (messageId: string) => `/api/messages/${messageId}/tail`,
|
||||
SHARE_CONVERSATION: (isPromptable: boolean) =>
|
||||
`/api/share?isPromptable=${isPromptable}`,
|
||||
|
||||
@@ -32,6 +32,16 @@ const conversationService = {
|
||||
apiClient.get(endpoints.CONVERSATION.MESSAGE_TAIL(messageId), token, {}),
|
||||
getConversations: (token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.CONVERSATION.CONVERSATIONS, token, {}),
|
||||
searchConversations: (
|
||||
query: string,
|
||||
token: string | null,
|
||||
limit = 30,
|
||||
): Promise<any> =>
|
||||
apiClient.get(
|
||||
endpoints.CONVERSATION.SEARCH_CONVERSATIONS(query, limit),
|
||||
token,
|
||||
{},
|
||||
),
|
||||
shareConversation: (
|
||||
isPromptable: boolean,
|
||||
data: any,
|
||||
|
||||
@@ -61,6 +61,8 @@ const userService = {
|
||||
apiClient.get(endpoints.USER.SINGLE_PROMPT(id), token),
|
||||
deletePath: (docPath: string, token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.DELETE_PATH(docPath), token),
|
||||
getTaskStatus: (task_id: string, token: string | null): Promise<any> =>
|
||||
throttledApiClient.get(endpoints.USER.TASK_STATUS(task_id), token),
|
||||
getMessageAnalytics: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.MESSAGE_ANALYTICS, data, token),
|
||||
getTokenAnalytics: (data: any, token: string | null): Promise<any> =>
|
||||
@@ -73,8 +75,6 @@ const userService = {
|
||||
apiClient.post(endpoints.USER.MANAGE_SYNC, data, token),
|
||||
syncSource: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.SYNC_SOURCE, data, token),
|
||||
reingestSource: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.REINGEST_SOURCE, data, token),
|
||||
getAvailableTools: (token: string | null): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS, token),
|
||||
getUserTools: (token: string | null): Promise<any> =>
|
||||
@@ -172,6 +172,8 @@ const userService = {
|
||||
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
|
||||
saveMCPServer: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
|
||||
getMCPOAuthStatus: (task_id: string, token: string | null): Promise<any> =>
|
||||
throttledApiClient.get(endpoints.USER.MCP_OAUTH_STATUS(task_id), token),
|
||||
getMCPAuthStatus: (token: string | null): Promise<any> =>
|
||||
throttledApiClient.get(endpoints.USER.MCP_AUTH_STATUS, token),
|
||||
syncConnector: (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useEffect, useRef } from 'react';
|
||||
import React, { useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
|
||||
@@ -32,24 +32,13 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
const [isDarkTheme] = useDarkTheme();
|
||||
const completedRef = useRef(false);
|
||||
const intervalRef = useRef<number | null>(null);
|
||||
const authWindowRef = useRef<Window | null>(null);
|
||||
// Hold the exact listener identity so unmount cleanup removes the same fn.
|
||||
const messageHandlerRef = useRef<((event: MessageEvent) => void) | null>(
|
||||
null,
|
||||
);
|
||||
// Tracks mount status so async ``fetch`` resolves after unmount don't
|
||||
// call ``onSuccess`` / ``onError`` on a vanished parent.
|
||||
const mountedRef = useRef(true);
|
||||
|
||||
const cleanup = () => {
|
||||
if (intervalRef.current) {
|
||||
clearInterval(intervalRef.current);
|
||||
intervalRef.current = null;
|
||||
}
|
||||
if (messageHandlerRef.current) {
|
||||
window.removeEventListener('message', messageHandlerRef.current as any);
|
||||
messageHandlerRef.current = null;
|
||||
}
|
||||
window.removeEventListener('message', handleAuthMessage as any);
|
||||
};
|
||||
|
||||
const handleAuthMessage = (event: MessageEvent) => {
|
||||
@@ -60,7 +49,6 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
if (successGeneric || successProvider) {
|
||||
completedRef.current = true;
|
||||
cleanup();
|
||||
authWindowRef.current = null;
|
||||
onSuccess({
|
||||
session_token: event.data.session_token,
|
||||
user_email:
|
||||
@@ -70,7 +58,6 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
} else if (errorProvider) {
|
||||
completedRef.current = true;
|
||||
cleanup();
|
||||
authWindowRef.current = null;
|
||||
onError(
|
||||
event.data.error || t('modals.uploadDoc.connectors.auth.authFailed'),
|
||||
);
|
||||
@@ -80,20 +67,12 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
const handleAuth = async () => {
|
||||
try {
|
||||
completedRef.current = false;
|
||||
// Close any popup left over from a previous click before wiping
|
||||
// the ref — otherwise the old window keeps living with no
|
||||
// interval watching it and no listener handling its messages.
|
||||
if (authWindowRef.current && !authWindowRef.current.closed) {
|
||||
authWindowRef.current.close();
|
||||
}
|
||||
authWindowRef.current = null;
|
||||
cleanup();
|
||||
|
||||
const authResponse = await userService.getConnectorAuthUrl(
|
||||
provider,
|
||||
token,
|
||||
);
|
||||
if (!mountedRef.current) return;
|
||||
|
||||
if (!authResponse.ok) {
|
||||
throw new Error(
|
||||
@@ -102,7 +81,6 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
}
|
||||
|
||||
const authData = await authResponse.json();
|
||||
if (!mountedRef.current) return;
|
||||
if (!authData.success || !authData.authorization_url) {
|
||||
throw new Error(
|
||||
authData.error || t('modals.uploadDoc.connectors.auth.authUrlFailed'),
|
||||
@@ -117,23 +95,13 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
if (!authWindow) {
|
||||
throw new Error(t('modals.uploadDoc.connectors.auth.popupBlocked'));
|
||||
}
|
||||
authWindowRef.current = authWindow;
|
||||
|
||||
messageHandlerRef.current = handleAuthMessage;
|
||||
window.addEventListener('message', handleAuthMessage as any);
|
||||
|
||||
const checkClosed = window.setInterval(() => {
|
||||
if (authWindow.closed) {
|
||||
clearInterval(checkClosed);
|
||||
intervalRef.current = null;
|
||||
if (messageHandlerRef.current) {
|
||||
window.removeEventListener(
|
||||
'message',
|
||||
messageHandlerRef.current as any,
|
||||
);
|
||||
messageHandlerRef.current = null;
|
||||
}
|
||||
authWindowRef.current = null;
|
||||
window.removeEventListener('message', handleAuthMessage as any);
|
||||
if (!completedRef.current) {
|
||||
onError(t('modals.uploadDoc.connectors.auth.authCancelled'));
|
||||
}
|
||||
@@ -141,7 +109,6 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
}, 1000);
|
||||
intervalRef.current = checkClosed;
|
||||
} catch (error) {
|
||||
if (!mountedRef.current) return;
|
||||
onError(
|
||||
error instanceof Error
|
||||
? error.message
|
||||
@@ -150,18 +117,6 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
}
|
||||
};
|
||||
|
||||
// Release interval, message listener, and popup on unmount only.
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
mountedRef.current = false;
|
||||
cleanup();
|
||||
if (authWindowRef.current && !authWindowRef.current.closed) {
|
||||
authWindowRef.current.close();
|
||||
}
|
||||
authWindowRef.current = null;
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
{errorMessage && (
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import React, { useEffect, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector, useStore } from 'react-redux';
|
||||
import { useSelector } from 'react-redux';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import ArrowLeft from '../assets/arrow-left.svg';
|
||||
@@ -14,7 +14,6 @@ import { useLoaderState, useOutsideAlerter } from '../hooks';
|
||||
import ConfirmationModal from '../modals/ConfirmationModal';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
import type { RootState } from '../store';
|
||||
import { formatBytes } from '../utils/stringUtils';
|
||||
import Chunks from './Chunks';
|
||||
import ContextMenu, { MenuOption } from './ContextMenu';
|
||||
@@ -65,7 +64,6 @@ const ConnectorTree: React.FC<ConnectorTreeProps> = ({
|
||||
useState<DirectoryStructure | null>(null);
|
||||
const [currentPath, setCurrentPath] = useState<string[]>([]);
|
||||
const token = useSelector(selectToken);
|
||||
const store = useStore<RootState>();
|
||||
const [activeMenuId, setActiveMenuId] = useState<string | null>(null);
|
||||
const menuRefs = useRef<{
|
||||
[key: string]: React.RefObject<HTMLDivElement | null>;
|
||||
@@ -83,25 +81,6 @@ const ConnectorTree: React.FC<ConnectorTreeProps> = ({
|
||||
const [syncDone, setSyncDone] = useState<boolean>(false);
|
||||
const [syncConfirmationModal, setSyncConfirmationModal] =
|
||||
useState<ActiveState>('INACTIVE');
|
||||
const mountedRef = useRef(true);
|
||||
const syncUnsubscribeRef = useRef<(() => void) | null>(null);
|
||||
// Holds the 5-minute SSE-wait timer so the unmount cleanup can clear
|
||||
// it — otherwise the timer fires up to 5 min after unmount and
|
||||
// resolves an abandoned Promise.
|
||||
const syncTimerRef = useRef<number | null>(null);
|
||||
|
||||
useEffect(
|
||||
() => () => {
|
||||
mountedRef.current = false;
|
||||
syncUnsubscribeRef.current?.();
|
||||
syncUnsubscribeRef.current = null;
|
||||
if (syncTimerRef.current !== null) {
|
||||
window.clearTimeout(syncTimerRef.current);
|
||||
syncTimerRef.current = null;
|
||||
}
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
useOutsideAlerter(
|
||||
searchDropdownRef,
|
||||
@@ -137,108 +116,67 @@ const ConnectorTree: React.FC<ConnectorTreeProps> = ({
|
||||
console.log('Sync started successfully:', data.task_id);
|
||||
setSyncProgress(10);
|
||||
|
||||
// The connector worker (``ingest_connector`` in
|
||||
// ``application/worker.py``) publishes
|
||||
// ``source.ingest.{queued,completed,failed}`` envelopes keyed on
|
||||
// ``scope.id == docId`` (sync mode reuses the source uuid). Wait
|
||||
// on the bounded ``notifications.recentEvents`` ring for a
|
||||
// terminal envelope rather than polling ``/task_status``.
|
||||
// Mirrors FileTree's slice-walking pattern, including the
|
||||
// ``opStartedAt`` guard so a stale terminal event from a prior
|
||||
// sync of this same source can't short-circuit the current op.
|
||||
const opStartedAt = Date.now();
|
||||
|
||||
const terminalFromSse = (): 'completed' | 'failed' | null => {
|
||||
const events = store.getState().notifications.recentEvents;
|
||||
for (const event of events) {
|
||||
if (event.scope?.id !== docId) continue;
|
||||
const ts = event.ts ? Date.parse(event.ts) : NaN;
|
||||
if (!Number.isFinite(ts) || ts < opStartedAt) continue;
|
||||
if (event.type === 'source.ingest.completed') return 'completed';
|
||||
if (event.type === 'source.ingest.failed') return 'failed';
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
const MAX_WAIT_MS = 5 * 60_000;
|
||||
const terminal = await new Promise<
|
||||
'completed' | 'failed' | 'timeout' | 'unmounted'
|
||||
>((resolve) => {
|
||||
// Cover the race where the event landed between the POST
|
||||
// returning and the subscribe call.
|
||||
const initial = terminalFromSse();
|
||||
if (initial) {
|
||||
resolve(initial);
|
||||
return;
|
||||
}
|
||||
if (!mountedRef.current) {
|
||||
resolve('unmounted');
|
||||
return;
|
||||
}
|
||||
let settled = false;
|
||||
const finish = (
|
||||
value: 'completed' | 'failed' | 'timeout' | 'unmounted',
|
||||
) => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
if (syncTimerRef.current !== null) {
|
||||
window.clearTimeout(syncTimerRef.current);
|
||||
syncTimerRef.current = null;
|
||||
}
|
||||
if (syncUnsubscribeRef.current) {
|
||||
syncUnsubscribeRef.current();
|
||||
syncUnsubscribeRef.current = null;
|
||||
}
|
||||
resolve(value);
|
||||
};
|
||||
syncTimerRef.current = window.setTimeout(
|
||||
() => finish('timeout'),
|
||||
MAX_WAIT_MS,
|
||||
);
|
||||
syncUnsubscribeRef.current = store.subscribe(() => {
|
||||
if (!mountedRef.current) {
|
||||
finish('unmounted');
|
||||
return;
|
||||
}
|
||||
const next = terminalFromSse();
|
||||
if (next) finish(next);
|
||||
});
|
||||
});
|
||||
|
||||
if (terminal === 'timeout') {
|
||||
console.error('Sync timed out waiting for SSE terminal');
|
||||
} else if (terminal === 'unmounted') {
|
||||
return;
|
||||
}
|
||||
|
||||
if (terminal === 'completed') {
|
||||
// The "no files downloaded" early-return path publishes
|
||||
// ``completed`` with ``no_changes: true`` — treated as success
|
||||
// here; refreshing the directory is cheap and idempotent.
|
||||
setSyncProgress(100);
|
||||
console.log('Sync completed successfully');
|
||||
// Poll task status using userService
|
||||
const maxAttempts = 30;
|
||||
const pollInterval = 2000;
|
||||
|
||||
for (let attempt = 0; attempt < maxAttempts; attempt++) {
|
||||
try {
|
||||
const refreshResponse = await userService.getDirectoryStructure(
|
||||
docId,
|
||||
const statusResponse = await userService.getTaskStatus(
|
||||
data.task_id,
|
||||
token,
|
||||
);
|
||||
const refreshData = await refreshResponse.json();
|
||||
if (refreshData && refreshData.directory_structure) {
|
||||
setDirectoryStructure(refreshData.directory_structure);
|
||||
setCurrentPath([]);
|
||||
}
|
||||
if (refreshData && refreshData.provider) {
|
||||
setSourceProvider(refreshData.provider);
|
||||
const statusData = await statusResponse.json();
|
||||
|
||||
console.log(
|
||||
`Task status (attempt ${attempt + 1}):`,
|
||||
statusData.status,
|
||||
);
|
||||
|
||||
if (statusData.status === 'SUCCESS') {
|
||||
setSyncProgress(100);
|
||||
console.log('Sync completed successfully');
|
||||
|
||||
// Refresh directory structure
|
||||
try {
|
||||
const refreshResponse = await userService.getDirectoryStructure(
|
||||
docId,
|
||||
token,
|
||||
);
|
||||
const refreshData = await refreshResponse.json();
|
||||
if (refreshData && refreshData.directory_structure) {
|
||||
setDirectoryStructure(refreshData.directory_structure);
|
||||
setCurrentPath([]);
|
||||
}
|
||||
if (refreshData && refreshData.provider) {
|
||||
setSourceProvider(refreshData.provider);
|
||||
}
|
||||
|
||||
setSyncDone(true);
|
||||
setTimeout(() => setSyncDone(false), 5000);
|
||||
} catch (err) {
|
||||
console.error('Error refreshing directory structure:', err);
|
||||
}
|
||||
break;
|
||||
} else if (statusData.status === 'FAILURE') {
|
||||
console.error('Sync task failed:', statusData.result);
|
||||
break;
|
||||
} else if (statusData.status === 'PROGRESS') {
|
||||
const progress = Number(
|
||||
statusData.result && statusData.result.current != null
|
||||
? statusData.result.current
|
||||
: statusData.meta && statusData.meta.current != null
|
||||
? statusData.meta.current
|
||||
: 0,
|
||||
);
|
||||
setSyncProgress(Math.max(10, progress));
|
||||
}
|
||||
|
||||
setSyncDone(true);
|
||||
setTimeout(() => setSyncDone(false), 5000);
|
||||
} catch (err) {
|
||||
console.error('Error refreshing directory structure:', err);
|
||||
await new Promise((resolve) => setTimeout(resolve, pollInterval));
|
||||
} catch (error) {
|
||||
console.error('Error polling task status:', error);
|
||||
break;
|
||||
}
|
||||
} else if (terminal === 'failed') {
|
||||
console.error('Sync task failed (per SSE)');
|
||||
}
|
||||
} else {
|
||||
console.error('Sync failed:', data.error);
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import React, { useState, useRef, useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector, useStore } from 'react-redux';
|
||||
import { useSelector } from 'react-redux';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
import type { RootState } from '../store';
|
||||
import { formatBytes } from '../utils/stringUtils';
|
||||
import Chunks from './Chunks';
|
||||
import ContextMenu, { MenuOption } from './ContextMenu';
|
||||
@@ -57,7 +56,6 @@ const FileTree: React.FC<FileTreeProps> = ({
|
||||
onBackToDocuments,
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const store = useStore<RootState>();
|
||||
const [loading, setLoading] = useLoaderState(true, 500);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [directoryStructure, setDirectoryStructure] =
|
||||
@@ -97,25 +95,6 @@ const FileTree: React.FC<FileTreeProps> = ({
|
||||
const opQueueRef = useRef<QueuedOperation[]>([]);
|
||||
const processingRef = useRef(false);
|
||||
const [queueLength, setQueueLength] = useState(0);
|
||||
const mountedRef = useRef(true);
|
||||
const waitUnsubscribeRef = useRef<(() => void) | null>(null);
|
||||
// Holds the 5-minute SSE-wait timer so the unmount cleanup can clear
|
||||
// it — otherwise the timer fires up to 5 min after unmount and
|
||||
// resolves an abandoned Promise.
|
||||
const waitTimerRef = useRef<number | null>(null);
|
||||
|
||||
useEffect(
|
||||
() => () => {
|
||||
mountedRef.current = false;
|
||||
waitUnsubscribeRef.current?.();
|
||||
waitUnsubscribeRef.current = null;
|
||||
if (waitTimerRef.current !== null) {
|
||||
window.clearTimeout(waitTimerRef.current);
|
||||
waitTimerRef.current = null;
|
||||
}
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
useOutsideAlerter(
|
||||
searchDropdownRef,
|
||||
@@ -334,103 +313,47 @@ const FileTree: React.FC<FileTreeProps> = ({
|
||||
}
|
||||
console.log('Reingest task started:', result.reingest_task_id);
|
||||
|
||||
// SSE is the sole driver here. The backend's
|
||||
// ``reingest_source_worker`` publishes ``source.ingest.*``
|
||||
// keyed on the resolved ``source_id`` (the
|
||||
// ``manage_source_files`` route returns it explicitly so we
|
||||
// can match without consulting any slice). Subscribe to the
|
||||
// store and resolve when a terminal event tagged with our
|
||||
// source lands in ``notifications.recentEvents``. Re-checking
|
||||
// on every dispatch (rather than polling on a timer) avoids
|
||||
// races where a terminal could roll off the bounded ring
|
||||
// before the next tick observes it in chatty sessions.
|
||||
const reingestSourceId: string | undefined = result.source_id;
|
||||
// Cutoff so we don't pick up terminal events from a *previous*
|
||||
// reingest of the same source — the backend's
|
||||
// ``source.ingest.*`` payload doesn't carry a Celery task id,
|
||||
// so source_id alone is ambiguous when ops repeat.
|
||||
const opStartedAt = Date.now();
|
||||
const MAX_WAIT_MS = 5 * 60_000;
|
||||
const maxAttempts = 30;
|
||||
const pollInterval = 2000;
|
||||
|
||||
const terminalFromSse = (): 'completed' | 'failed' | null => {
|
||||
if (!reingestSourceId) return null;
|
||||
const events = store.getState().notifications.recentEvents;
|
||||
for (const event of events) {
|
||||
if (event.scope?.id !== reingestSourceId) continue;
|
||||
const ts = event.ts ? Date.parse(event.ts) : NaN;
|
||||
if (!Number.isFinite(ts) || ts < opStartedAt) continue;
|
||||
if (event.type === 'source.ingest.completed') return 'completed';
|
||||
if (event.type === 'source.ingest.failed') return 'failed';
|
||||
}
|
||||
return null;
|
||||
};
|
||||
for (let attempt = 0; attempt < maxAttempts; attempt++) {
|
||||
try {
|
||||
const statusResponse = await userService.getTaskStatus(
|
||||
result.reingest_task_id,
|
||||
token,
|
||||
);
|
||||
const statusData = await statusResponse.json();
|
||||
|
||||
const refreshStructure = async (): Promise<boolean> => {
|
||||
const structureResponse = await userService.getDirectoryStructure(
|
||||
docId,
|
||||
token,
|
||||
);
|
||||
const structureData = await structureResponse.json();
|
||||
if (!mountedRef.current) return false;
|
||||
if (structureData && structureData.directory_structure) {
|
||||
setDirectoryStructure(structureData.directory_structure);
|
||||
currentOpRef.current = null;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
console.log(
|
||||
`Task status (attempt ${attempt + 1}):`,
|
||||
statusData.status,
|
||||
);
|
||||
|
||||
const terminal = await new Promise<
|
||||
'completed' | 'failed' | 'timeout' | 'unmounted'
|
||||
>((resolve) => {
|
||||
if (!mountedRef.current) {
|
||||
resolve('unmounted');
|
||||
return;
|
||||
}
|
||||
// Cover the race where the terminal event landed between
|
||||
// the POST returning and the subscribe call.
|
||||
const initial = terminalFromSse();
|
||||
if (initial) {
|
||||
resolve(initial);
|
||||
return;
|
||||
}
|
||||
const timer = window.setTimeout(() => {
|
||||
waitUnsubscribeRef.current?.();
|
||||
waitUnsubscribeRef.current = null;
|
||||
waitTimerRef.current = null;
|
||||
resolve('timeout');
|
||||
}, MAX_WAIT_MS);
|
||||
waitTimerRef.current = timer;
|
||||
waitUnsubscribeRef.current = store.subscribe(() => {
|
||||
if (!mountedRef.current) {
|
||||
window.clearTimeout(timer);
|
||||
waitTimerRef.current = null;
|
||||
waitUnsubscribeRef.current?.();
|
||||
waitUnsubscribeRef.current = null;
|
||||
resolve('unmounted');
|
||||
return;
|
||||
if (statusData.status === 'SUCCESS') {
|
||||
console.log('Task completed successfully');
|
||||
|
||||
const structureResponse = await userService.getDirectoryStructure(
|
||||
docId,
|
||||
token,
|
||||
);
|
||||
const structureData = await structureResponse.json();
|
||||
|
||||
if (structureData && structureData.directory_structure) {
|
||||
setDirectoryStructure(structureData.directory_structure);
|
||||
currentOpRef.current = null;
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
} else if (statusData.status === 'FAILURE') {
|
||||
console.error('Task failed');
|
||||
break;
|
||||
}
|
||||
const next = terminalFromSse();
|
||||
if (next) {
|
||||
window.clearTimeout(timer);
|
||||
waitTimerRef.current = null;
|
||||
waitUnsubscribeRef.current?.();
|
||||
waitUnsubscribeRef.current = null;
|
||||
resolve(next);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
if (!mountedRef.current) return false;
|
||||
|
||||
if (terminal === 'completed') {
|
||||
if (await refreshStructure()) return true;
|
||||
} else if (terminal === 'failed') {
|
||||
console.error('Reingest task failed (per SSE)');
|
||||
} else if (terminal === 'unmounted') {
|
||||
return false;
|
||||
} else {
|
||||
console.error('Reingest timed out waiting for SSE terminal');
|
||||
await new Promise((resolve) => setTimeout(resolve, pollInterval));
|
||||
} catch (error) {
|
||||
console.error('Error polling task status:', error);
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw new Error(
|
||||
@@ -451,7 +374,7 @@ const FileTree: React.FC<FileTreeProps> = ({
|
||||
? 'delete directory'
|
||||
: 'delete file(s)';
|
||||
console.error(`Error ${actionText}:`, error);
|
||||
if (mountedRef.current) setError(`Failed to ${errorText}`);
|
||||
setError(`Failed to ${errorText}`);
|
||||
} finally {
|
||||
currentOpRef.current = null;
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ import { createPortal } from 'react-dom';
|
||||
import { LoaderCircle, Mic, Square } from 'lucide-react';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDispatch, useSelector, useStore } from 'react-redux';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
|
||||
import endpoints from '../api/endpoints';
|
||||
import userService from '../api/services/userService';
|
||||
@@ -28,9 +28,8 @@ import {
|
||||
selectSelectedDocs,
|
||||
selectToken,
|
||||
} from '../preferences/preferenceSlice';
|
||||
import type { RootState } from '../store';
|
||||
import Upload from '../upload/Upload';
|
||||
import { getOS, isTouchDevice } from '../utils/browserUtils';
|
||||
import { isTouchDevice } from '../utils/browserUtils';
|
||||
import SourcesPopup from './SourcesPopup';
|
||||
import ToolsPopup from './ToolsPopup';
|
||||
import { handleAbort } from '../conversation/conversationSlice';
|
||||
@@ -317,7 +316,6 @@ export default function MessageInput({
|
||||
const attachments = useSelector(selectAttachments);
|
||||
|
||||
const dispatch = useDispatch();
|
||||
const store = useStore<RootState>();
|
||||
const mediaStreamRef = useRef<MediaStream | null>(null);
|
||||
const audioContextRef = useRef<AudioContext | null>(null);
|
||||
const audioSourceNodeRef = useRef<MediaStreamAudioSourceNode | null>(null);
|
||||
@@ -337,7 +335,6 @@ export default function MessageInput({
|
||||
const voiceBaseValueRef = useRef('');
|
||||
const liveTranscriptRef = useRef('');
|
||||
|
||||
const browserOS = getOS();
|
||||
const isTouch = isTouchDevice();
|
||||
|
||||
const stopMediaStream = () => {
|
||||
@@ -386,25 +383,6 @@ export default function MessageInput({
|
||||
liveTranscriptRef.current = '';
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (
|
||||
((browserOS === 'win' || browserOS === 'linux') &&
|
||||
event.ctrlKey &&
|
||||
event.key === 'k') ||
|
||||
(browserOS === 'mac' && event.metaKey && event.key === 'k')
|
||||
) {
|
||||
event.preventDefault();
|
||||
setIsSourcesPopupOpen((s) => !s);
|
||||
}
|
||||
};
|
||||
|
||||
document.addEventListener('keydown', handleKeyDown);
|
||||
return () => {
|
||||
document.removeEventListener('keydown', handleKeyDown);
|
||||
};
|
||||
}, [browserOS]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
stopAudioProcessing();
|
||||
@@ -412,86 +390,6 @@ export default function MessageInput({
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Recover the race where attachment.* SSE arrives before the upload
|
||||
// XHR's onload sets ``attachmentId``: walk recentEvents and watchdog
|
||||
// the row so it can't stay stuck on 'processing'. Mirrors
|
||||
// Upload.tsx's ``trackTraining``.
|
||||
const trackAttachment = useCallback(
|
||||
(clientId: string, attachmentId: string) => {
|
||||
let handled = false;
|
||||
|
||||
const check = () => {
|
||||
const state = store.getState();
|
||||
const row = state.upload.attachments.find((a) => a.id === clientId);
|
||||
if (!row) return true; // removed by user; stop tracking
|
||||
if (row.status === 'completed' || row.status === 'failed') {
|
||||
handled = true;
|
||||
return true;
|
||||
}
|
||||
for (const event of state.notifications.recentEvents) {
|
||||
if (event.scope?.id !== attachmentId) continue;
|
||||
if (event.type === 'attachment.completed') {
|
||||
const payload = (event.payload || {}) as Record<string, unknown>;
|
||||
const tokenCount = Number(payload.token_count);
|
||||
handled = true;
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: clientId,
|
||||
updates: {
|
||||
status: 'completed',
|
||||
progress: 100,
|
||||
...(Number.isFinite(tokenCount)
|
||||
? { token_count: tokenCount }
|
||||
: {}),
|
||||
},
|
||||
}),
|
||||
);
|
||||
return true;
|
||||
}
|
||||
if (event.type === 'attachment.failed') {
|
||||
handled = true;
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: clientId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
if (check()) return;
|
||||
const MAX_WAIT_MS = 5 * 60_000;
|
||||
let unsubscribe: (() => void) | null = null;
|
||||
const timer = window.setTimeout(() => {
|
||||
unsubscribe?.();
|
||||
if (!handled) {
|
||||
handled = true;
|
||||
console.warn(
|
||||
'trackAttachment: timed out waiting for terminal SSE',
|
||||
clientId,
|
||||
attachmentId,
|
||||
);
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: clientId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
}
|
||||
}, MAX_WAIT_MS);
|
||||
unsubscribe = store.subscribe(() => {
|
||||
if (check()) {
|
||||
window.clearTimeout(timer);
|
||||
unsubscribe?.();
|
||||
}
|
||||
});
|
||||
},
|
||||
[dispatch, store],
|
||||
);
|
||||
|
||||
const uploadFiles = useCallback(
|
||||
(files: File[]) => {
|
||||
if (!files || files.length === 0) return;
|
||||
@@ -592,19 +490,11 @@ export default function MessageInput({
|
||||
id: uiId,
|
||||
updates: {
|
||||
taskId: task.task_id,
|
||||
// Stash the server's attachment id so SSE
|
||||
// ``attachment.*`` events can match this
|
||||
// row by ``scope.id`` and drive the
|
||||
// per-attachment push-fresh poll gate.
|
||||
attachmentId: task.attachment_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
if (task.attachment_id) {
|
||||
trackAttachment(uiId, task.attachment_id);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -635,15 +525,11 @@ export default function MessageInput({
|
||||
id: uiId,
|
||||
updates: {
|
||||
taskId: t.task_id,
|
||||
attachmentId: t.attachment_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
if (t.attachment_id) {
|
||||
trackAttachment(uiId, t.attachment_id);
|
||||
}
|
||||
} else {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
@@ -677,15 +563,11 @@ export default function MessageInput({
|
||||
id: uiId,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
attachmentId: response.attachment_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
if (response.attachment_id) {
|
||||
trackAttachment(uiId, response.attachment_id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
console.warn(
|
||||
@@ -812,15 +694,11 @@ export default function MessageInput({
|
||||
id: uniqueId,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
attachmentId: response.attachment_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
if (response.attachment_id) {
|
||||
trackAttachment(uniqueId, response.attachment_id);
|
||||
}
|
||||
} else {
|
||||
// If backend returned tasks[] for single-file, handle gracefully:
|
||||
if (
|
||||
@@ -832,15 +710,11 @@ export default function MessageInput({
|
||||
id: uniqueId,
|
||||
updates: {
|
||||
taskId: response.tasks[0].task_id,
|
||||
attachmentId: response.tasks[0].attachment_id,
|
||||
status: 'processing',
|
||||
progress: 10,
|
||||
},
|
||||
}),
|
||||
);
|
||||
if (response.tasks[0].attachment_id) {
|
||||
trackAttachment(uniqueId, response.tasks[0].attachment_id);
|
||||
}
|
||||
} else {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
@@ -887,7 +761,7 @@ export default function MessageInput({
|
||||
xhr.send(formData);
|
||||
});
|
||||
},
|
||||
[dispatch, token, trackAttachment],
|
||||
[dispatch, token],
|
||||
);
|
||||
|
||||
const handleFileAttachment = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
@@ -922,6 +796,65 @@ export default function MessageInput({
|
||||
accept: FILE_UPLOAD_ACCEPT,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
const checkTaskStatus = () => {
|
||||
const processingAttachments = attachments.filter(
|
||||
(att) => att.status === 'processing' && att.taskId,
|
||||
);
|
||||
|
||||
processingAttachments.forEach((attachment) => {
|
||||
userService
|
||||
.getTaskStatus(attachment.taskId!, null)
|
||||
.then((data) => data.json())
|
||||
.then((data) => {
|
||||
if (data.status === 'SUCCESS') {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
updates: {
|
||||
status: 'completed',
|
||||
progress: 100,
|
||||
id: data.result?.attachment_id,
|
||||
token_count: data.result?.token_count,
|
||||
},
|
||||
}),
|
||||
);
|
||||
} else if (data.status === 'FAILURE') {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
} else if (data.status === 'PROGRESS' && data.result?.current) {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
updates: { progress: data.result.current },
|
||||
}),
|
||||
);
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
dispatch(
|
||||
updateAttachment({
|
||||
id: attachment.id,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
const interval = setInterval(() => {
|
||||
if (attachments.some((att) => att.status === 'processing')) {
|
||||
checkTaskStatus();
|
||||
}
|
||||
}, 2000);
|
||||
|
||||
return () => clearInterval(interval);
|
||||
}, [attachments, dispatch]);
|
||||
|
||||
const handleInput = useCallback(() => {
|
||||
if (inputRef.current) {
|
||||
if (window.innerWidth < 350) inputRef.current.style.height = 'auto';
|
||||
@@ -1632,11 +1565,6 @@ export default function MessageInput({
|
||||
: `${selectedDocs.length} sources selected`
|
||||
: t('conversation.sources.title')}
|
||||
</span>
|
||||
{!isTouch && (
|
||||
<span className="ml-1 hidden text-[10px] text-gray-500 sm:inline-block dark:text-gray-400">
|
||||
{browserOS === 'mac' ? '(⌘K)' : '(ctrl+K)'}
|
||||
</span>
|
||||
)}
|
||||
</button>
|
||||
)}
|
||||
|
||||
|
||||
@@ -5,54 +5,41 @@ import { useDispatch, useSelector } from 'react-redux';
|
||||
import CheckCircleFilled from '../assets/check-circle-filled.svg';
|
||||
import ChevronDown from '../assets/chevron-down.svg';
|
||||
import WarnIcon from '../assets/warn.svg';
|
||||
import {
|
||||
dismissUploadTask,
|
||||
selectUploadTasks,
|
||||
type UploadTask,
|
||||
} from '../upload/uploadSlice';
|
||||
import { dismissUploadTask, selectUploadTasks } from '../upload/uploadSlice';
|
||||
|
||||
const PROGRESS_RADIUS = 10;
|
||||
const PROGRESS_CIRCUMFERENCE = 2 * Math.PI * PROGRESS_RADIUS;
|
||||
|
||||
const IN_PROGRESS_STATUSES = new Set<UploadTask['status']>([
|
||||
'preparing',
|
||||
'uploading',
|
||||
'training',
|
||||
]);
|
||||
|
||||
/**
|
||||
* Single merged upload card — Google-Drive style. Multiple in-flight
|
||||
* uploads share one toast with a list of rows; the header reflects
|
||||
* the *primary* task's status (the newest still-running task, or the
|
||||
* newest task overall if all are terminal). Per-task progress lives
|
||||
* on each row.
|
||||
*
|
||||
* Dismissal: the header X dismisses every visible task at once
|
||||
* (mirrors the GDrive panel close — keeps the surface tidy without
|
||||
* per-row controls). The chevron collapses the row list.
|
||||
*/
|
||||
export default function UploadToast() {
|
||||
const [collapsed, setCollapsed] = useState(false);
|
||||
const [collapsedTasks, setCollapsedTasks] = useState<Record<string, boolean>>(
|
||||
{},
|
||||
);
|
||||
|
||||
const toggleTaskCollapse = (taskId: string) => {
|
||||
setCollapsedTasks((prev) => ({
|
||||
...prev,
|
||||
[taskId]: !prev[taskId],
|
||||
}));
|
||||
};
|
||||
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useDispatch();
|
||||
const uploadTasks = useSelector(selectUploadTasks);
|
||||
|
||||
const visibleTasks = uploadTasks.filter((task) => !task.dismissed);
|
||||
if (visibleTasks.length === 0) return null;
|
||||
|
||||
// Pick the task that drives the header status: prefer a still-
|
||||
// running task (most-recent first since the slice unshifts), and
|
||||
// fall back to whatever's most-recent if everything is terminal.
|
||||
const primaryTask =
|
||||
visibleTasks.find((task) => IN_PROGRESS_STATUSES.has(task.status)) ??
|
||||
visibleTasks[0];
|
||||
|
||||
const headerLabel = getStatusHeading(primaryTask.status, t);
|
||||
|
||||
const dismissAll = () => {
|
||||
for (const task of visibleTasks) {
|
||||
dispatch(dismissUploadTask(task.id));
|
||||
const getStatusHeading = (status: string) => {
|
||||
switch (status) {
|
||||
case 'preparing':
|
||||
return t('modals.uploadDoc.progress.wait');
|
||||
case 'uploading':
|
||||
return t('modals.uploadDoc.progress.upload');
|
||||
case 'training':
|
||||
return t('modals.uploadDoc.progress.upload');
|
||||
case 'completed':
|
||||
return t('modals.uploadDoc.progress.completed');
|
||||
case 'failed':
|
||||
return t('modals.uploadDoc.progress.failed');
|
||||
default:
|
||||
return t('modals.uploadDoc.progress.preparing');
|
||||
}
|
||||
};
|
||||
|
||||
@@ -60,212 +47,180 @@ export default function UploadToast() {
|
||||
<div
|
||||
className="fixed right-4 bottom-4 z-50 flex max-w-md flex-col gap-2"
|
||||
onMouseDown={(e) => e.stopPropagation()}
|
||||
role="status"
|
||||
aria-live="polite"
|
||||
aria-atomic="false"
|
||||
>
|
||||
<div
|
||||
className={`border-border bg-card w-[271px] overflow-hidden rounded-2xl border shadow-[0px_24px_48px_0px_#00000029] transition-all duration-300`}
|
||||
>
|
||||
<div
|
||||
className={`flex items-center justify-between px-4 py-3 ${
|
||||
primaryTask.status !== 'failed'
|
||||
? 'bg-accent/50 dark:bg-muted'
|
||||
: 'bg-destructive/10 dark:bg-destructive/10'
|
||||
}`}
|
||||
>
|
||||
<h3 className="font-inter dark:text-foreground text-[14px] leading-[16.5px] font-medium text-black">
|
||||
{headerLabel}
|
||||
</h3>
|
||||
<div className="flex items-center gap-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setCollapsed((prev) => !prev)}
|
||||
aria-label={
|
||||
collapsed
|
||||
? t('modals.uploadDoc.progress.expandDetails')
|
||||
: t('modals.uploadDoc.progress.collapseDetails')
|
||||
}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
>
|
||||
<img
|
||||
src={ChevronDown}
|
||||
alt=""
|
||||
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${
|
||||
collapsed ? 'rotate-180' : ''
|
||||
}`}
|
||||
/>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={dismissAll}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
aria-label={t('modals.uploadDoc.progress.dismiss')}
|
||||
>
|
||||
<svg
|
||||
width="16"
|
||||
height="16"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-4 w-4"
|
||||
>
|
||||
<path
|
||||
d="M18 6L6 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M6 6L18 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{uploadTasks
|
||||
.filter((task) => !task.dismissed)
|
||||
.map((task) => {
|
||||
const shouldShowProgress = [
|
||||
'preparing',
|
||||
'uploading',
|
||||
'training',
|
||||
].includes(task.status);
|
||||
const rawProgress = Math.min(Math.max(task.progress ?? 0, 0), 100);
|
||||
const formattedProgress = Math.round(rawProgress);
|
||||
const progressOffset =
|
||||
PROGRESS_CIRCUMFERENCE * (1 - rawProgress / 100);
|
||||
const isCollapsed = collapsedTasks[task.id] ?? false;
|
||||
|
||||
<div
|
||||
className="grid overflow-hidden transition-[grid-template-rows] duration-300 ease-out"
|
||||
style={{ gridTemplateRows: collapsed ? '0fr' : '1fr' }}
|
||||
>
|
||||
<div
|
||||
className={`min-h-0 overflow-hidden transition-opacity duration-300 ${
|
||||
collapsed ? 'opacity-0' : 'opacity-100'
|
||||
}`}
|
||||
>
|
||||
<ul className="max-h-72 overflow-y-auto">
|
||||
{visibleTasks.map((task) => (
|
||||
<UploadRow key={task.id} task={task} t={t} />
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
return (
|
||||
<div
|
||||
key={task.id}
|
||||
className={`border-border bg-card w-[271px] overflow-hidden rounded-2xl border shadow-[0px_24px_48px_0px_#00000029] transition-all duration-300`}
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
<div
|
||||
className={`flex items-center justify-between px-4 py-3 ${
|
||||
task.status !== 'failed'
|
||||
? 'bg-accent/50 dark:bg-muted'
|
||||
: 'bg-destructive/10 dark:bg-destructive/10'
|
||||
}`}
|
||||
>
|
||||
<h3 className="font-inter dark:text-foreground text-[14px] leading-[16.5px] font-medium text-black">
|
||||
{getStatusHeading(task.status)}
|
||||
</h3>
|
||||
<div className="flex items-center gap-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => toggleTaskCollapse(task.id)}
|
||||
aria-label={
|
||||
isCollapsed
|
||||
? t('modals.uploadDoc.progress.expandDetails')
|
||||
: t('modals.uploadDoc.progress.collapseDetails')
|
||||
}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
>
|
||||
<img
|
||||
src={ChevronDown}
|
||||
alt=""
|
||||
className={`h-4 w-4 transform transition-transform duration-200 dark:invert ${
|
||||
isCollapsed ? 'rotate-180' : ''
|
||||
}`}
|
||||
/>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => dispatch(dismissUploadTask(task.id))}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
aria-label={t('modals.uploadDoc.progress.dismiss')}
|
||||
>
|
||||
<svg
|
||||
width="16"
|
||||
height="16"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-4 w-4"
|
||||
>
|
||||
<path
|
||||
d="M18 6L6 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M6 6L18 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
className="grid overflow-hidden transition-[grid-template-rows] duration-300 ease-out"
|
||||
style={{ gridTemplateRows: isCollapsed ? '0fr' : '1fr' }}
|
||||
>
|
||||
<div
|
||||
className={`min-h-0 overflow-hidden transition-opacity duration-300 ${
|
||||
isCollapsed ? 'opacity-0' : 'opacity-100'
|
||||
}`}
|
||||
>
|
||||
<div className="flex items-center justify-between px-5 py-3">
|
||||
<p
|
||||
className="font-inter dark:text-muted-foreground max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black"
|
||||
title={task.fileName}
|
||||
>
|
||||
{task.fileName}
|
||||
</p>
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
{shouldShowProgress && (
|
||||
<svg
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
className="h-6 w-6 shrink-0 text-[#7D54D1]"
|
||||
role="progressbar"
|
||||
aria-valuemin={0}
|
||||
aria-valuemax={100}
|
||||
aria-valuenow={formattedProgress}
|
||||
aria-label={t(
|
||||
'modals.uploadDoc.progress.uploadProgress',
|
||||
{
|
||||
progress: formattedProgress,
|
||||
},
|
||||
)}
|
||||
>
|
||||
<circle
|
||||
className="text-muted dark:text-muted-foreground/30"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
/>
|
||||
<circle
|
||||
className="text-[#7D54D1]"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeDasharray={PROGRESS_CIRCUMFERENCE}
|
||||
strokeDashoffset={progressOffset}
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
transform="rotate(-90 12 12)"
|
||||
/>
|
||||
</svg>
|
||||
)}
|
||||
|
||||
{task.status === 'completed' && (
|
||||
<img
|
||||
src={CheckCircleFilled}
|
||||
alt=""
|
||||
className="h-6 w-6 shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
|
||||
{task.status === 'failed' && (
|
||||
<img
|
||||
src={WarnIcon}
|
||||
alt=""
|
||||
className="h-6 w-6 shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{task.status === 'failed' && task.errorMessage && (
|
||||
<span className="block px-5 pb-3 text-xs text-red-500">
|
||||
{task.errorMessage}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function UploadRow({
|
||||
task,
|
||||
t,
|
||||
}: {
|
||||
task: UploadTask;
|
||||
t: ReturnType<typeof useTranslation>['t'];
|
||||
}) {
|
||||
const showProgress = IN_PROGRESS_STATUSES.has(task.status);
|
||||
const rawProgress = Math.min(Math.max(task.progress ?? 0, 0), 100);
|
||||
const formattedProgress = Math.round(rawProgress);
|
||||
const progressOffset = PROGRESS_CIRCUMFERENCE * (1 - rawProgress / 100);
|
||||
|
||||
return (
|
||||
<li className="border-border/50 border-b last:border-b-0">
|
||||
<div className="flex items-center justify-between px-5 py-3">
|
||||
<div className="flex min-w-0 flex-col">
|
||||
<p
|
||||
className="font-inter dark:text-muted-foreground max-w-[200px] truncate text-[13px] leading-[16.5px] font-normal text-black"
|
||||
title={task.fileName}
|
||||
>
|
||||
{task.fileName}
|
||||
</p>
|
||||
{task.status === 'training' && task.stage && (
|
||||
<span className="font-inter text-muted-foreground mt-0.5 text-[11px] leading-[14px]">
|
||||
{t(`modals.uploadDoc.progress.${task.stage}`)}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
{showProgress && (
|
||||
<svg
|
||||
width="24"
|
||||
height="24"
|
||||
viewBox="0 0 24 24"
|
||||
className="h-6 w-6 shrink-0 text-[#7D54D1]"
|
||||
role="progressbar"
|
||||
aria-valuemin={0}
|
||||
aria-valuemax={100}
|
||||
aria-valuenow={formattedProgress}
|
||||
aria-label={t('modals.uploadDoc.progress.uploadProgress', {
|
||||
progress: formattedProgress,
|
||||
})}
|
||||
>
|
||||
<circle
|
||||
className="text-muted dark:text-muted-foreground/30"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
/>
|
||||
<circle
|
||||
className="text-[#7D54D1]"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeDasharray={PROGRESS_CIRCUMFERENCE}
|
||||
strokeDashoffset={progressOffset}
|
||||
cx="12"
|
||||
cy="12"
|
||||
r={PROGRESS_RADIUS}
|
||||
fill="none"
|
||||
transform="rotate(-90 12 12)"
|
||||
/>
|
||||
</svg>
|
||||
)}
|
||||
|
||||
{task.status === 'completed' && (
|
||||
<img
|
||||
src={CheckCircleFilled}
|
||||
alt=""
|
||||
className="h-6 w-6 shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
|
||||
{task.status === 'failed' && (
|
||||
<img
|
||||
src={WarnIcon}
|
||||
alt=""
|
||||
className="h-6 w-6 shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{task.status === 'failed' &&
|
||||
(task.tokenLimitReached || task.errorMessage) && (
|
||||
<span className="block px-5 pb-3 text-xs text-red-500">
|
||||
{task.tokenLimitReached
|
||||
? t('modals.uploadDoc.progress.tokenLimit')
|
||||
: task.errorMessage}
|
||||
</span>
|
||||
)}
|
||||
</li>
|
||||
);
|
||||
}
|
||||
|
||||
function getStatusHeading(
|
||||
status: UploadTask['status'],
|
||||
t: ReturnType<typeof useTranslation>['t'],
|
||||
): string {
|
||||
switch (status) {
|
||||
case 'preparing':
|
||||
return t('modals.uploadDoc.progress.wait');
|
||||
case 'uploading':
|
||||
case 'training':
|
||||
return t('modals.uploadDoc.progress.upload');
|
||||
case 'completed':
|
||||
return t('modals.uploadDoc.progress.completed');
|
||||
case 'failed':
|
||||
return t('modals.uploadDoc.progress.failed');
|
||||
default:
|
||||
return t('modals.uploadDoc.progress.preparing');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,10 +154,10 @@ const ConversationBubble = forwardRef<
|
||||
<img
|
||||
src={DocumentationDark}
|
||||
alt="Attachment"
|
||||
className="h-3.75 w-3.75 object-fill"
|
||||
className="h-[15px] w-[15px] object-fill"
|
||||
/>
|
||||
</div>
|
||||
<span className="max-w-37.5 truncate font-normal">
|
||||
<span className="max-w-[150px] truncate font-normal">
|
||||
{file.fileName}
|
||||
</span>
|
||||
</div>
|
||||
@@ -328,7 +328,7 @@ const ConversationBubble = forwardRef<
|
||||
<div className="mb-4 flex flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
<div className="my-2 flex flex-row items-center justify-center gap-3">
|
||||
<Avatar
|
||||
className="h-6.5 w-7.5 text-xl"
|
||||
className="h-[26px] w-[30px] text-xl"
|
||||
avatar={
|
||||
<img
|
||||
src={Sources}
|
||||
@@ -376,7 +376,7 @@ const ConversationBubble = forwardRef<
|
||||
<img
|
||||
src={Document}
|
||||
alt="Document"
|
||||
className="h-4.25 w-4.25 object-fill"
|
||||
className="h-[17px] w-[17px] object-fill"
|
||||
/>
|
||||
<p
|
||||
className="mt-0.5 truncate text-xs"
|
||||
@@ -394,11 +394,11 @@ const ConversationBubble = forwardRef<
|
||||
</div>
|
||||
{activeTooltip === index && (
|
||||
<div
|
||||
className={`dark:bg-card dark:text-foreground absolute left-1/2 z-50 max-h-48 w-40 translate-x-[-50%] translate-y-0.75 rounded-xl bg-[#FBFBFB] p-4 text-black shadow-xl sm:w-56`}
|
||||
className={`dark:bg-card dark:text-foreground absolute left-1/2 z-50 max-h-48 w-40 translate-x-[-50%] translate-y-[3px] rounded-xl bg-[#FBFBFB] p-4 text-black shadow-xl sm:w-56`}
|
||||
onMouseOver={() => setActiveTooltip(index)}
|
||||
onMouseOut={() => setActiveTooltip(null)}
|
||||
>
|
||||
<p className="line-clamp-6 max-h-41 overflow-hidden rounded-md text-sm wrap-break-word text-ellipsis">
|
||||
<p className="line-clamp-6 max-h-[164px] overflow-hidden rounded-md text-sm wrap-break-word text-ellipsis">
|
||||
{source.text}
|
||||
</p>
|
||||
</div>
|
||||
@@ -471,7 +471,7 @@ const ConversationBubble = forwardRef<
|
||||
<div className="flex max-w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
<div className="my-2 flex flex-row items-center justify-center gap-3">
|
||||
<Avatar
|
||||
className="h-8.5 w-8.5 text-2xl"
|
||||
className="h-[34px] w-[34px] text-2xl"
|
||||
avatar={
|
||||
<img
|
||||
src={DocsGPT3}
|
||||
@@ -1023,7 +1023,7 @@ function ToolCalls({
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="mb-4 relative flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
<div className="mb-4 flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
{/* Approval bars — always visible, compact inline */}
|
||||
{awaitingCalls.length > 0 && (
|
||||
<div className="fade-in mt-4 ml-3 w-[90vw] md:w-[70vw] lg:w-full">
|
||||
@@ -1042,7 +1042,7 @@ function ToolCalls({
|
||||
<>
|
||||
<div className="my-2 flex flex-row items-center justify-center gap-3">
|
||||
<Avatar
|
||||
className="h-6.5 w-7.5 text-xl"
|
||||
className="h-[26px] w-[30px] text-xl"
|
||||
avatar={
|
||||
<img
|
||||
src={Sources}
|
||||
@@ -1089,7 +1089,7 @@ function ToolCalls({
|
||||
</p>
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="dark:text-muted-foreground leading-5.75 text-black"
|
||||
className="dark:text-muted-foreground leading-[23px] text-black"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{JSON.stringify(toolCall.arguments, null, 2)}
|
||||
@@ -1117,7 +1117,7 @@ function ToolCalls({
|
||||
{toolCall.status === 'completed' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="dark:text-muted-foreground leading-5.75 text-black"
|
||||
className="dark:text-muted-foreground leading-[23px] text-black"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{JSON.stringify(toolCall.result, null, 2)}
|
||||
@@ -1127,7 +1127,7 @@ function ToolCalls({
|
||||
{toolCall.status === 'error' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="text-destructive leading-5.75"
|
||||
className="text-destructive leading-[23px]"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{toolCall.error}
|
||||
@@ -1137,7 +1137,7 @@ function ToolCalls({
|
||||
{toolCall.status === 'denied' && (
|
||||
<p className="dark:bg-card rounded-b-2xl p-2 font-mono text-sm wrap-break-word">
|
||||
<span
|
||||
className="text-muted-foreground leading-5.75"
|
||||
className="text-muted-foreground leading-[23px]"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
Denied by user
|
||||
@@ -1172,7 +1172,7 @@ function Thought({
|
||||
<div className="mb-4 flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
<div className="my-2 flex flex-row items-center justify-center gap-3">
|
||||
<Avatar
|
||||
className="h-6.5 w-7.5 text-xl"
|
||||
className="h-[26px] w-[30px] text-xl"
|
||||
avatar={
|
||||
<img
|
||||
src={Cloud}
|
||||
@@ -1197,7 +1197,7 @@ function Thought({
|
||||
</div>
|
||||
{isThoughtOpen && (
|
||||
<div className="fade-in mr-5 ml-2 max-w-[90vw] md:max-w-[70vw] lg:max-w-[50vw]">
|
||||
<div className="bg-muted dark:bg-answer-bubble rounded-[28px] px-7 py-4.5">
|
||||
<div className="bg-muted dark:bg-answer-bubble rounded-[28px] px-7 py-[18px]">
|
||||
<ReactMarkdown
|
||||
className="fade-in leading-normal wrap-break-word whitespace-pre-wrap"
|
||||
remarkPlugins={[remarkGfm, remarkMath]}
|
||||
|
||||
@@ -1,155 +1,8 @@
|
||||
import { baseURL } from '../api/client';
|
||||
import conversationService from '../api/services/conversationService';
|
||||
import { Doc } from '../models/misc';
|
||||
import { Answer, FEEDBACK, RetrievalPayload } from './conversationModels';
|
||||
import { ToolCallsType } from './types';
|
||||
|
||||
/**
|
||||
* Mirrors the backend's ``_SEQUENCE_NO_RE`` (application/api/answer/
|
||||
* routes/messages.py) — only non-negative decimal integers are valid
|
||||
* cursors. Rejects empty strings (Number("") === 0), hex literals,
|
||||
* exponential notation, and anything else that ``Number(...)`` would
|
||||
* happily coerce.
|
||||
*/
|
||||
const _SEQUENCE_NO_RE = /^\d+$/;
|
||||
|
||||
/**
|
||||
* Drain an SSE response body, forwarding each ``data:`` line to
|
||||
* ``onData`` and tracking the most recent ``id:`` header. Returns
|
||||
* when the body ends, the signal aborts, or ``shouldStop()`` returns
|
||||
* true (e.g. a terminal ``end``/``error`` event was dispatched —
|
||||
* the reconnect endpoint is a live tail that doesn't close on its
|
||||
* own past terminal replay).
|
||||
*/
|
||||
/**
|
||||
* Convert a non-SSE pre-stream HTTP failure (e.g. ``check_usage``'s
|
||||
* 429 JSON response) into a synthetic typed ``error`` frame so the
|
||||
* caller's slice sees the actual server message instead of the
|
||||
* generic "Connection lost" synthesised when the drainer finishes
|
||||
* with zero events. Returns true if a frame was dispatched and the
|
||||
* caller should skip ``_drainSseBody`` entirely.
|
||||
*
|
||||
* SSE-shaped error bodies (``mimetype="text/event-stream"``) are
|
||||
* left alone — the drainer parses the typed ``error`` frame they
|
||||
* carry through the normal path.
|
||||
*/
|
||||
async function _handlePreStreamHttpError(
|
||||
response: Response,
|
||||
dispatch: (data: string) => void,
|
||||
): Promise<boolean> {
|
||||
if (response.ok) return false;
|
||||
const contentType = (
|
||||
response.headers.get('content-type') ?? ''
|
||||
).toLowerCase();
|
||||
if (contentType.includes('text/event-stream')) return false;
|
||||
let message: string | null = null;
|
||||
try {
|
||||
const text = await response.text();
|
||||
if (text) {
|
||||
try {
|
||||
const parsed = JSON.parse(text);
|
||||
if (parsed && typeof parsed === 'object') {
|
||||
message =
|
||||
(typeof parsed.message === 'string' && parsed.message) ||
|
||||
(typeof parsed.error === 'string' && parsed.error) ||
|
||||
(typeof parsed.detail === 'string' && parsed.detail) ||
|
||||
null;
|
||||
}
|
||||
} catch {
|
||||
message = text.slice(0, 500);
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Body already consumed or unreadable — fall through to the
|
||||
// status-line fallback below.
|
||||
}
|
||||
if (!message) {
|
||||
message = `HTTP ${response.status} ${response.statusText}`.trim();
|
||||
}
|
||||
dispatch(JSON.stringify({ type: 'error', error: message }));
|
||||
return true;
|
||||
}
|
||||
|
||||
async function _drainSseBody(
|
||||
body: ReadableStream<Uint8Array>,
|
||||
signal: AbortSignal,
|
||||
onData: (data: string) => void,
|
||||
onId: (id: number) => void,
|
||||
shouldStop?: () => boolean,
|
||||
): Promise<void> {
|
||||
const reader = body.getReader();
|
||||
const decoder = new TextDecoder('utf-8');
|
||||
let buffer = '';
|
||||
let stoppedEarly = false;
|
||||
try {
|
||||
while (true) {
|
||||
if (signal.aborted) break;
|
||||
if (shouldStop?.()) {
|
||||
stoppedEarly = true;
|
||||
break;
|
||||
}
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
// Normalise mixed line terminators so a stray CR can't smuggle
|
||||
// a record boundary inside a JSON payload.
|
||||
buffer = buffer.replace(/\r\n/g, '\n').replace(/\r/g, '\n');
|
||||
let boundary = buffer.indexOf('\n\n');
|
||||
while (boundary !== -1) {
|
||||
const record = buffer.slice(0, boundary);
|
||||
buffer = buffer.slice(boundary + 2);
|
||||
boundary = buffer.indexOf('\n\n');
|
||||
if (record.length === 0) continue;
|
||||
const dataParts: string[] = [];
|
||||
let sawDataField = false;
|
||||
for (const line of record.split('\n')) {
|
||||
if (line.length === 0) continue;
|
||||
if (line.startsWith(':')) continue; // SSE comment / keepalive
|
||||
const colonIdx = line.indexOf(':');
|
||||
const field = colonIdx === -1 ? line : line.slice(0, colonIdx);
|
||||
let value = colonIdx === -1 ? '' : line.slice(colonIdx + 1);
|
||||
if (value.startsWith(' ')) value = value.slice(1);
|
||||
if (field === 'id') {
|
||||
// Strict regex match — empty value, hex, ``-1`` (the
|
||||
// backend's terminal snapshot-failure synthetic), and
|
||||
// exponent forms are all rejected so they can't silently
|
||||
// rewrite the reconnect cursor.
|
||||
if (_SEQUENCE_NO_RE.test(value)) onId(parseInt(value, 10));
|
||||
} else if (field === 'data') {
|
||||
sawDataField = true;
|
||||
dataParts.push(value);
|
||||
}
|
||||
}
|
||||
if (!sawDataField) continue;
|
||||
const data = dataParts.join('\n').trim();
|
||||
if (data.length === 0) continue;
|
||||
onData(data);
|
||||
if (shouldStop?.()) {
|
||||
stoppedEarly = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (stoppedEarly) break;
|
||||
}
|
||||
} finally {
|
||||
if (stoppedEarly) {
|
||||
// Ask the runtime to tear the underlying response body down so
|
||||
// the server-side WSGI thread isn't pinned waiting on
|
||||
// keepalives. ``releaseLock`` alone leaves the body half-open.
|
||||
try {
|
||||
await reader.cancel();
|
||||
} catch {
|
||||
// Already errored / closed.
|
||||
}
|
||||
}
|
||||
try {
|
||||
reader.releaseLock();
|
||||
} catch {
|
||||
// Already released.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function handleFetchAnswer(
|
||||
question: string,
|
||||
signal: AbortSignal,
|
||||
@@ -290,153 +143,54 @@ export function handleFetchAnswerSteaming(
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
if (idempotencyKey) headers['Idempotency-Key'] = idempotencyKey;
|
||||
|
||||
// Per-stream state used for reconnect-after-disconnect.
|
||||
let messageId: string | null = null;
|
||||
let lastEventId: number | null = null;
|
||||
// The single JSON.parse below feeds both the message_id capture and
|
||||
// the termination flag — cheaper and stricter than substring
|
||||
// matching the wire bytes.
|
||||
let endReceived = false;
|
||||
|
||||
const dispatch = (data: string) => {
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
if (parsed && typeof parsed === 'object') {
|
||||
if (parsed.type === 'message_id' && parsed.message_id) {
|
||||
messageId = parsed.message_id;
|
||||
} else if (parsed.type === 'end' || parsed.type === 'error') {
|
||||
endReceived = true;
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Not JSON — pass through anyway; the caller handles raw lines.
|
||||
}
|
||||
onEvent(new MessageEvent('message', { data }));
|
||||
};
|
||||
|
||||
const runInitialPost = async (): Promise<void> => {
|
||||
const response = await conversationService.answerStream(
|
||||
payload,
|
||||
token,
|
||||
signal,
|
||||
headers,
|
||||
);
|
||||
// Pre-stream HTTP failures with non-SSE bodies (e.g. ``check_usage``
|
||||
// returning a JSON 429) drain as zero events and would otherwise
|
||||
// be masked by the generic "Connection lost" synthetic. Convert
|
||||
// them into a typed ``error`` frame so the real message surfaces.
|
||||
if (await _handlePreStreamHttpError(response, dispatch)) return;
|
||||
if (!response.body) throw new Error('No response body');
|
||||
await _drainSseBody(response.body, signal, dispatch, (id) => {
|
||||
lastEventId = id;
|
||||
});
|
||||
};
|
||||
|
||||
// Reconnect's stop predicate: as soon as ``dispatch`` flips
|
||||
// ``endReceived`` (typed ``end`` or ``error`` event seen — both
|
||||
// are terminal per the backend's contract). Without this the
|
||||
// live-tail endpoint would emit keepalives indefinitely and the
|
||||
// await would never return.
|
||||
const reconnectShouldStop = () => endReceived;
|
||||
|
||||
const runReconnect = async (): Promise<void> => {
|
||||
if (!messageId) {
|
||||
throw new Error('reconnect: no message_id captured');
|
||||
}
|
||||
const url = new URL(`${baseURL}/api/messages/${messageId}/events`);
|
||||
if (lastEventId !== null) {
|
||||
url.searchParams.set('last_event_id', String(lastEventId));
|
||||
}
|
||||
const reconnectHeaders: Record<string, string> = {
|
||||
Accept: 'text/event-stream',
|
||||
};
|
||||
if (token) reconnectHeaders.Authorization = `Bearer ${token}`;
|
||||
// NB: there is no slice consumer for a synthetic ``reconnecting``
|
||||
// event yet — surface only the underlying network reality. The
|
||||
// user-visible ``Reconnecting…`` affordance is a follow-up that
|
||||
// needs ``conversationSlice`` to gain a status case.
|
||||
const response = await fetch(url.toString(), {
|
||||
method: 'GET',
|
||||
headers: reconnectHeaders,
|
||||
signal,
|
||||
cache: 'no-store',
|
||||
});
|
||||
if (!response.ok || !response.body) {
|
||||
throw new Error(
|
||||
`reconnect: HTTP ${response.status} ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
await _drainSseBody(
|
||||
response.body,
|
||||
signal,
|
||||
dispatch,
|
||||
(id) => {
|
||||
lastEventId = id;
|
||||
},
|
||||
reconnectShouldStop,
|
||||
);
|
||||
};
|
||||
|
||||
return new Promise<Answer>((resolve, reject) => {
|
||||
(async () => {
|
||||
try {
|
||||
try {
|
||||
await runInitialPost();
|
||||
} catch (initialErr) {
|
||||
// Mid-stream network failures (WiFi blip, worker recycle,
|
||||
// body reader rejecting) surface as a thrown error — not a
|
||||
// graceful EOF. If the stream had already started (we have a
|
||||
// ``messageId``), fall through to the reconnect path so the
|
||||
// journal-backed replay can finish what the live socket
|
||||
// couldn't. Pre-stream failures (auth, DNS, server 4xx/5xx
|
||||
// before any yield) lack a messageId and bubble up.
|
||||
if (signal.aborted || !messageId) throw initialErr;
|
||||
console.warn(
|
||||
'Initial stream failed mid-flight, attempting reconnect:',
|
||||
initialErr,
|
||||
);
|
||||
}
|
||||
// The backend ends the stream cleanly with a typed ``end``
|
||||
// event. Anything else (network drop, gunicorn worker recycle,
|
||||
// load-balancer timeout) is a "premature close" — try one
|
||||
// reconnect via the GET /api/messages/<id>/events endpoint.
|
||||
if (!endReceived && !signal.aborted && messageId) {
|
||||
try {
|
||||
await runReconnect();
|
||||
} catch (reconnectErr) {
|
||||
console.warn('Stream reconnect failed:', reconnectErr);
|
||||
conversationService
|
||||
.answerStream(payload, token, signal, headers)
|
||||
.then((response) => {
|
||||
if (!response.body) throw Error('No response body');
|
||||
|
||||
let buffer = '';
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder('utf-8');
|
||||
let counterrr = 0;
|
||||
const processStream = ({
|
||||
done,
|
||||
value,
|
||||
}: ReadableStreamReadResult<Uint8Array>) => {
|
||||
if (done) return;
|
||||
|
||||
counterrr += 1;
|
||||
|
||||
const chunk = decoder.decode(value);
|
||||
buffer += chunk;
|
||||
|
||||
const events = buffer.split('\n\n');
|
||||
buffer = events.pop() ?? '';
|
||||
|
||||
for (const event of events) {
|
||||
if (event.trim().startsWith('data:')) {
|
||||
const dataLine: string = event
|
||||
.split('\n')
|
||||
.map((line: string) => line.replace(/^data:\s?/, ''))
|
||||
.join('');
|
||||
|
||||
const messageEvent = new MessageEvent('message', {
|
||||
data: dataLine.trim(),
|
||||
});
|
||||
|
||||
onEvent(messageEvent);
|
||||
}
|
||||
}
|
||||
}
|
||||
// If we never observed a terminal frame (reconnect 4xx/5xx,
|
||||
// network drop during reconnect, or live tail still silent),
|
||||
// synthesize one through the same ``dispatch`` path the wire
|
||||
// events use. Without this the caller's slice never transitions
|
||||
// out of ``streaming`` and the UI stays in a loading spinner
|
||||
// forever — the conversationSlice handles ``data.type === 'error'``
|
||||
// by setting status=failed.
|
||||
if (!endReceived && !signal.aborted) {
|
||||
dispatch(
|
||||
JSON.stringify({
|
||||
type: 'error',
|
||||
error:
|
||||
'Connection lost. The response could not be resumed; please try again.',
|
||||
}),
|
||||
);
|
||||
}
|
||||
// The handler historically never explicitly resolved with a
|
||||
// value — callers consume the streamed events via ``onEvent``
|
||||
// and read final state from Redux. Preserve that contract.
|
||||
resolve(undefined as unknown as Answer);
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
resolve(undefined as unknown as Answer);
|
||||
return;
|
||||
}
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
};
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Connection failed:', error);
|
||||
reject(error);
|
||||
}
|
||||
})();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -460,149 +214,52 @@ export function handleSubmitToolActions(
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
if (idempotencyKey) headers['Idempotency-Key'] = idempotencyKey;
|
||||
|
||||
// Tool-action submissions resume against the original
|
||||
// ``reserved_message_id``, so the backend's continuation path emits
|
||||
// ``id:`` prefixed records that the legacy parser would silently
|
||||
// drop. Use the shared SSE drainer — and the same reconnect-on-
|
||||
// premature-close pattern as ``handleFetchAnswerSteaming`` so a
|
||||
// dropped tool-action stream can pick up after the disconnect.
|
||||
let messageId: string | null = null;
|
||||
let lastEventId: number | null = null;
|
||||
|
||||
// Track whether the typed ``end`` event was observed. The single
|
||||
// JSON.parse below feeds both the message_id capture and the
|
||||
// termination flag — cheaper and stricter than substring matching
|
||||
// the wire bytes.
|
||||
let endReceived = false;
|
||||
|
||||
const dispatch = (data: string) => {
|
||||
try {
|
||||
const parsed = JSON.parse(data);
|
||||
if (parsed && typeof parsed === 'object') {
|
||||
if (parsed.type === 'message_id' && parsed.message_id) {
|
||||
messageId = parsed.message_id;
|
||||
} else if (parsed.type === 'end' || parsed.type === 'error') {
|
||||
// Match the backend's terminal set in
|
||||
// ``application/streaming/event_replay.py``: the agent's
|
||||
// catch-all path emits ``error`` *without* a trailing
|
||||
// ``end``, so treating only ``end`` as terminal would
|
||||
// trigger a reconnect against an already-finished stream
|
||||
// and hang on keepalives.
|
||||
endReceived = true;
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Not JSON — pass through anyway; the caller handles raw lines.
|
||||
}
|
||||
onEvent(new MessageEvent('message', { data }));
|
||||
};
|
||||
|
||||
const runInitial = async (): Promise<void> => {
|
||||
const response = await conversationService.answerStream(
|
||||
payload,
|
||||
token,
|
||||
signal,
|
||||
headers,
|
||||
);
|
||||
// See ``handleFetchAnswerSteaming`` for the rationale: non-SSE
|
||||
// HTTP failures (e.g. ``check_usage`` 429 JSON) need to be lifted
|
||||
// into a typed ``error`` frame before they reach the drainer.
|
||||
if (await _handlePreStreamHttpError(response, dispatch)) return;
|
||||
if (!response.body) throw new Error('No response body');
|
||||
await _drainSseBody(response.body, signal, dispatch, (id) => {
|
||||
lastEventId = id;
|
||||
});
|
||||
};
|
||||
|
||||
// Reconnect's stop predicate: as soon as ``dispatch`` flips
|
||||
// ``endReceived`` (typed ``end`` or ``error`` event seen — both
|
||||
// are terminal per the backend's contract). Without this the
|
||||
// live-tail endpoint would emit keepalives indefinitely and the
|
||||
// await would never return.
|
||||
const reconnectShouldStop = () => endReceived;
|
||||
|
||||
const runReconnect = async (): Promise<void> => {
|
||||
if (!messageId) {
|
||||
throw new Error('reconnect: no message_id captured');
|
||||
}
|
||||
const url = new URL(`${baseURL}/api/messages/${messageId}/events`);
|
||||
if (lastEventId !== null) {
|
||||
url.searchParams.set('last_event_id', String(lastEventId));
|
||||
}
|
||||
const reconnectHeaders: Record<string, string> = {
|
||||
Accept: 'text/event-stream',
|
||||
};
|
||||
if (token) reconnectHeaders.Authorization = `Bearer ${token}`;
|
||||
const response = await fetch(url.toString(), {
|
||||
method: 'GET',
|
||||
headers: reconnectHeaders,
|
||||
signal,
|
||||
cache: 'no-store',
|
||||
});
|
||||
if (!response.ok || !response.body) {
|
||||
throw new Error(
|
||||
`reconnect: HTTP ${response.status} ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
await _drainSseBody(
|
||||
response.body,
|
||||
signal,
|
||||
dispatch,
|
||||
(id) => {
|
||||
lastEventId = id;
|
||||
},
|
||||
reconnectShouldStop,
|
||||
);
|
||||
};
|
||||
|
||||
return new Promise<Answer>((resolve, reject) => {
|
||||
(async () => {
|
||||
try {
|
||||
try {
|
||||
await runInitial();
|
||||
} catch (initialErr) {
|
||||
// Same premature-close handling as
|
||||
// ``handleFetchAnswerSteaming``: a thrown reader error after
|
||||
// the message_id frame still warrants one reconnect attempt
|
||||
// against the journal. Pre-stream failures lack a messageId
|
||||
// and bubble up.
|
||||
if (signal.aborted || !messageId) throw initialErr;
|
||||
console.warn(
|
||||
'Tool-actions stream failed mid-flight, attempting reconnect:',
|
||||
initialErr,
|
||||
);
|
||||
}
|
||||
if (!endReceived && !signal.aborted && messageId) {
|
||||
try {
|
||||
await runReconnect();
|
||||
} catch (reconnectErr) {
|
||||
console.warn('Tool-actions reconnect failed:', reconnectErr);
|
||||
conversationService
|
||||
.answerStream(payload, token, signal, headers)
|
||||
.then((response) => {
|
||||
if (!response.body) throw Error('No response body');
|
||||
|
||||
let buffer = '';
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder('utf-8');
|
||||
|
||||
const processStream = ({
|
||||
done,
|
||||
value,
|
||||
}: ReadableStreamReadResult<Uint8Array>) => {
|
||||
if (done) return;
|
||||
|
||||
const chunk = decoder.decode(value);
|
||||
buffer += chunk;
|
||||
|
||||
const events = buffer.split('\n\n');
|
||||
buffer = events.pop() ?? '';
|
||||
|
||||
for (const event of events) {
|
||||
if (event.trim().startsWith('data:')) {
|
||||
const dataLine: string = event
|
||||
.split('\n')
|
||||
.map((line: string) => line.replace(/^data:\s?/, ''))
|
||||
.join('');
|
||||
|
||||
const messageEvent = new MessageEvent('message', {
|
||||
data: dataLine.trim(),
|
||||
});
|
||||
|
||||
onEvent(messageEvent);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Synthesize a terminal error if reconnect couldn't deliver one
|
||||
// (4xx/5xx, network drop, silent live tail). Same reasoning as
|
||||
// ``handleFetchAnswerSteaming``: the caller's slice only exits
|
||||
// the streaming state on a terminal frame.
|
||||
if (!endReceived && !signal.aborted) {
|
||||
dispatch(
|
||||
JSON.stringify({
|
||||
type: 'error',
|
||||
error:
|
||||
'Connection lost. The tool response could not be resumed; please try again.',
|
||||
}),
|
||||
);
|
||||
}
|
||||
resolve(undefined as unknown as Answer);
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
resolve(undefined as unknown as Answer);
|
||||
return;
|
||||
}
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
};
|
||||
|
||||
reader.read().then(processStream).catch(reject);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Tool actions submission failed:', error);
|
||||
reject(error);
|
||||
}
|
||||
})();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,153 +0,0 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import reducer, {
|
||||
applyMessageTail,
|
||||
setConversation,
|
||||
} from './conversationSlice';
|
||||
|
||||
const baseQuery = {
|
||||
prompt: 'tell me a poem',
|
||||
messageId: 'm-1',
|
||||
messageStatus: 'pending' as const,
|
||||
};
|
||||
|
||||
const seedSlice = () => reducer(undefined, setConversation([baseQuery]));
|
||||
|
||||
describe('applyMessageTail — streaming partial', () => {
|
||||
it('writes response to the query while status is streaming', () => {
|
||||
const state = seedSlice();
|
||||
const next = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'Hello, par',
|
||||
thought: null,
|
||||
sources: [],
|
||||
tool_calls: [],
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(next.queries[0].messageStatus).toBe('streaming');
|
||||
expect(next.queries[0].response).toBe('Hello, par');
|
||||
});
|
||||
|
||||
it('updates response on each successive tail tick', () => {
|
||||
let state = seedSlice();
|
||||
state = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'Hello',
|
||||
sources: [],
|
||||
tool_calls: [],
|
||||
},
|
||||
}),
|
||||
);
|
||||
state = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'Hello, world',
|
||||
sources: [],
|
||||
tool_calls: [],
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(state.queries[0].response).toBe('Hello, world');
|
||||
});
|
||||
|
||||
it('applies sources and tool_calls when they appear mid-stream', () => {
|
||||
const state = seedSlice();
|
||||
const next = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'partial',
|
||||
sources: [{ id: 's1', title: 'doc' }],
|
||||
tool_calls: [{ name: 'search' }],
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(next.queries[0].sources).toEqual([{ id: 's1', title: 'doc' }]);
|
||||
expect(next.queries[0].tool_calls).toEqual([{ name: 'search' }]);
|
||||
});
|
||||
|
||||
it('ignores empty sources / tool_calls arrays so the renderer stays clean', () => {
|
||||
const state = seedSlice();
|
||||
const next = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'partial',
|
||||
sources: [],
|
||||
tool_calls: [],
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(next.queries[0].sources).toBeUndefined();
|
||||
expect(next.queries[0].tool_calls).toBeUndefined();
|
||||
});
|
||||
|
||||
it('promotes to complete with the final response and clears any error', () => {
|
||||
let state = seedSlice();
|
||||
state = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'streaming',
|
||||
response: 'partial',
|
||||
},
|
||||
}),
|
||||
);
|
||||
state = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'complete',
|
||||
response: 'Final answer.',
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(state.queries[0].messageStatus).toBe('complete');
|
||||
expect(state.queries[0].response).toBe('Final answer.');
|
||||
expect(state.queries[0].error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('surfaces failed status as error and clears response', () => {
|
||||
const state = seedSlice();
|
||||
const next = reducer(
|
||||
state,
|
||||
applyMessageTail({
|
||||
index: 0,
|
||||
tail: {
|
||||
message_id: 'm-1',
|
||||
status: 'failed',
|
||||
response: 'whatever',
|
||||
error: 'worker died',
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(next.queries[0].messageStatus).toBe('failed');
|
||||
expect(next.queries[0].error).toBe('worker died');
|
||||
expect(next.queries[0].response).toBeUndefined();
|
||||
});
|
||||
});
|
||||
@@ -957,34 +957,20 @@ export const conversationSlice = createSlice({
|
||||
const status = tail?.status as MessageStatus | undefined;
|
||||
query.messageStatus = status;
|
||||
query.lastHeartbeatAt = tail?.last_heartbeat_at ?? query.lastHeartbeatAt;
|
||||
if (status === 'failed') {
|
||||
if (status === 'complete') {
|
||||
query.response = tail?.response ?? '';
|
||||
query.thought = tail?.thought ?? query.thought;
|
||||
query.sources = tail?.sources ?? query.sources;
|
||||
query.tool_calls = tail?.tool_calls ?? query.tool_calls;
|
||||
delete query.error;
|
||||
} else if (status === 'failed') {
|
||||
// Surface as error so the placeholder text never renders.
|
||||
query.error =
|
||||
(typeof tail?.error === 'string' && tail.error) ||
|
||||
'Generation failed before completing.';
|
||||
delete query.response;
|
||||
return;
|
||||
}
|
||||
// /tail returns reconstructed partials mid-stream so a second tab
|
||||
// can render the in-flight bubble; spinner is driven by status.
|
||||
const incomingResponse = tail?.response;
|
||||
if (typeof incomingResponse === 'string') {
|
||||
query.response = incomingResponse;
|
||||
} else if (status === 'complete') {
|
||||
query.response = '';
|
||||
}
|
||||
if (typeof tail?.thought === 'string') {
|
||||
query.thought = tail.thought;
|
||||
}
|
||||
if (Array.isArray(tail?.sources) && tail.sources.length > 0) {
|
||||
query.sources = tail.sources;
|
||||
}
|
||||
if (Array.isArray(tail?.tool_calls) && tail.tool_calls.length > 0) {
|
||||
query.tool_calls = tail.tool_calls;
|
||||
}
|
||||
if (status === 'complete') {
|
||||
delete query.error;
|
||||
}
|
||||
// pending / streaming: untouched; spinner keeps showing.
|
||||
},
|
||||
raiseError(
|
||||
state,
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
import React from 'react';
|
||||
|
||||
import { useEventStream } from './useEventStream';
|
||||
|
||||
/**
|
||||
* Mount-once provider that opens the user's SSE connection. Place
|
||||
* inside ``AuthWrapper`` so it sees a populated token, and wrap the
|
||||
* authenticated-app subtree so the connection lives for the user's
|
||||
* whole session.
|
||||
*/
|
||||
export function EventStreamProvider({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
}): React.ReactElement {
|
||||
useEventStream();
|
||||
return <>{children}</>;
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import type { AppDispatch } from '../store';
|
||||
import {
|
||||
sseEventReceived,
|
||||
sseLastEventIdReset,
|
||||
} from '../notifications/notificationsSlice';
|
||||
import { dispatchSSEEvent } from './dispatchEvent';
|
||||
|
||||
describe('dispatchSSEEvent', () => {
|
||||
let debugSpy: ReturnType<typeof vi.spyOn>;
|
||||
|
||||
beforeEach(() => {
|
||||
debugSpy = vi.spyOn(console, 'debug').mockImplementation(() => undefined);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
debugSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('dispatches sseLastEventIdReset AND sseEventReceived for backlog.truncated', () => {
|
||||
const dispatch = vi.fn() as unknown as AppDispatch;
|
||||
const envelope = { type: 'backlog.truncated' as const };
|
||||
|
||||
dispatchSSEEvent(envelope, dispatch);
|
||||
|
||||
const calls = (dispatch as unknown as { mock: { calls: unknown[][] } }).mock
|
||||
.calls;
|
||||
expect(calls).toHaveLength(2);
|
||||
expect(calls[0][0]).toEqual(sseLastEventIdReset());
|
||||
expect(calls[1][0]).toEqual(sseEventReceived(envelope));
|
||||
});
|
||||
|
||||
it('does not log a debug line for known envelope types', () => {
|
||||
const dispatch = vi.fn() as unknown as AppDispatch;
|
||||
dispatchSSEEvent({ id: 'e-1', type: 'source.ingest.progress' }, dispatch);
|
||||
expect(debugSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('logs a debug line for unknown envelope types', () => {
|
||||
const dispatch = vi.fn() as unknown as AppDispatch;
|
||||
dispatchSSEEvent({ id: 'e-2', type: 'mystery.event' }, dispatch);
|
||||
expect(debugSpy).toHaveBeenCalledTimes(1);
|
||||
expect(debugSpy.mock.calls[0]).toEqual([
|
||||
'[dispatchSSEEvent] unknown envelope type',
|
||||
'mystery.event',
|
||||
]);
|
||||
});
|
||||
});
|
||||
@@ -1,58 +0,0 @@
|
||||
import type { AppDispatch } from '../store';
|
||||
import {
|
||||
sseEventReceived,
|
||||
sseLastEventIdReset,
|
||||
type SSEEvent,
|
||||
} from '../notifications/notificationsSlice';
|
||||
|
||||
// Envelope types this build knows about. Hitting an unknown type means
|
||||
// the backend published something the frontend hasn't been taught yet
|
||||
// — worth a single debug line so it's visible in devtools without
|
||||
// drowning the console in known per-progress traffic.
|
||||
const KNOWN_TYPES: ReadonlySet<string> = new Set([
|
||||
'backlog.truncated',
|
||||
'source.ingest.queued',
|
||||
'source.ingest.progress',
|
||||
'source.ingest.completed',
|
||||
'source.ingest.failed',
|
||||
'attachment.queued',
|
||||
'attachment.progress',
|
||||
'attachment.completed',
|
||||
'attachment.failed',
|
||||
'mcp.oauth.awaiting_redirect',
|
||||
'mcp.oauth.in_progress',
|
||||
'mcp.oauth.completed',
|
||||
'mcp.oauth.failed',
|
||||
'tool.approval.required',
|
||||
]);
|
||||
|
||||
/**
|
||||
* Single fan-out point for inbound SSE envelopes. Always dispatches
|
||||
* ``sseEventReceived`` so any slice can ``extraReducers``-listen
|
||||
* (uploadSlice does this for source-ingest events), then handles the
|
||||
* small set of envelope-types that need centralised side effects (e.g.
|
||||
* ``backlog.truncated``).
|
||||
*/
|
||||
export function dispatchSSEEvent(
|
||||
envelope: SSEEvent,
|
||||
dispatch: AppDispatch,
|
||||
): void {
|
||||
if (!KNOWN_TYPES.has(envelope.type)) {
|
||||
console.debug('[dispatchSSEEvent] unknown envelope type', envelope.type);
|
||||
}
|
||||
|
||||
switch (envelope.type) {
|
||||
case 'backlog.truncated':
|
||||
// Backlog window slid past the client's Last-Event-ID. Drop the
|
||||
// cursor so the next reconnect doesn't try to resume past the
|
||||
// retained window. Slices that care about full-state freshness
|
||||
// can subscribe to ``sseEventReceived`` and refetch.
|
||||
dispatch(sseLastEventIdReset());
|
||||
break;
|
||||
default:
|
||||
// No central side effect; rely on slice-level extraReducers.
|
||||
break;
|
||||
}
|
||||
|
||||
dispatch(sseEventReceived(envelope));
|
||||
}
|
||||
@@ -1,386 +0,0 @@
|
||||
import { baseURL } from '../api/client';
|
||||
import type { SSEEvent } from '../notifications/notificationsSlice';
|
||||
|
||||
/**
|
||||
* Connection state surfaced to the consumer. Maps directly to the
|
||||
* ``PushHealth`` machine in ``notificationsSlice``.
|
||||
*/
|
||||
export type EventStreamHealth = 'connecting' | 'healthy' | 'unhealthy';
|
||||
|
||||
export interface EventStreamOptions {
|
||||
/** Bearer token; ``null`` short-circuits to ``unhealthy`` (auth required). */
|
||||
token: string | null;
|
||||
/**
|
||||
* Lazy getter for the current ``Last-Event-ID``. Called once at the
|
||||
* top of each connect attempt so token rotations / remounts read
|
||||
* the freshest cursor from Redux instead of a stale mount-time
|
||||
* snapshot. Return ``null`` for a fresh connect.
|
||||
*/
|
||||
getLastEventId: () => string | null;
|
||||
onEvent: (event: SSEEvent) => void;
|
||||
onHealthChange: (health: EventStreamHealth) => void;
|
||||
/** Called with the most recently received id so the caller can persist it. */
|
||||
onLastEventId?: (id: string) => void;
|
||||
/**
|
||||
* Called when the server emitted an ``id:`` line with an empty value
|
||||
* (WHATWG SSE cursor reset). Distinct from ``onLastEventId('')`` so
|
||||
* callers can dispatch ``sseLastEventIdReset`` without overloading
|
||||
* the normal advance path.
|
||||
*/
|
||||
onLastEventIdReset?: () => void;
|
||||
/**
|
||||
* Invoked once after ``MAX_CONSECUTIVE_401`` back-to-back 401s. The
|
||||
* reconnect loop then bails out, so the caller is responsible for
|
||||
* refreshing the token / signalling logout. Without this, an expired
|
||||
* token spins forever at the 30s backoff cap.
|
||||
*/
|
||||
onAuthFailure?: () => void;
|
||||
/**
|
||||
* Invoked once when the reconnect loop bails out after
|
||||
* ``MAX_CONSECUTIVE_ERRORS`` non-401 failures. Lets the caller surface
|
||||
* a warning instead of the connection silently going dark.
|
||||
*/
|
||||
onPermanentFailure?: () => void;
|
||||
}
|
||||
|
||||
export interface EventStreamConnection {
|
||||
close(): void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Backoff schedule (ms) for reconnect attempts. Capped at 30s so a long
|
||||
* outage doesn't push retries past Cloudflare's typical 100s idle-close
|
||||
* envelope. The schedule resets to 0 after a stream stays healthy for
|
||||
* ``HEALTHY_DEBOUNCE_MS``.
|
||||
*/
|
||||
const BACKOFF_SCHEDULE_MS = [0, 1_000, 2_000, 4_000, 8_000, 16_000, 30_000];
|
||||
const HEALTHY_DEBOUNCE_MS = 2_000;
|
||||
/**
|
||||
* Reconnect ceilings. Without these, the ``while (!closed)`` loop spins
|
||||
* forever on a persistently-failing endpoint — expired token (401s) or
|
||||
* sustained server outage (5xx). Both counters reset on a successful
|
||||
* stream open. Untested (no frontend test harness); behaviour verified
|
||||
* by manual trace of the loop in ``connectEventStream``.
|
||||
*/
|
||||
const MAX_CONSECUTIVE_401 = 3;
|
||||
const MAX_CONSECUTIVE_ERRORS = 20;
|
||||
|
||||
/** Up-to-±20% random jitter so N tabs reconnecting in lockstep stagger. */
|
||||
function withJitter(delayMs: number): number {
|
||||
if (delayMs <= 0) return 0;
|
||||
const span = delayMs * 0.2;
|
||||
return Math.max(0, Math.round(delayMs + (Math.random() * 2 - 1) * span));
|
||||
}
|
||||
|
||||
/**
|
||||
* Open a long-lived SSE connection to ``GET /api/events`` with
|
||||
* fetch-streaming semantics that mirror ``conversationHandlers.ts``.
|
||||
*
|
||||
* Returns immediately with an opaque handle; the connection lives in a
|
||||
* background async loop until ``close()`` is called or the underlying
|
||||
* ``AbortController`` fires.
|
||||
*
|
||||
* The ``Last-Event-ID`` cursor rides on the URL (``?last_event_id=...``)
|
||||
* rather than as a header so the request stays a CORS-simple GET — a
|
||||
* custom header would force a preflight OPTIONS that the production
|
||||
* cross-origin deployment isn't allowlisted for.
|
||||
*/
|
||||
export function connectEventStream(
|
||||
opts: EventStreamOptions,
|
||||
): EventStreamConnection {
|
||||
const controller = new AbortController();
|
||||
let closed = false;
|
||||
let attempt = 0;
|
||||
let consecutive401 = 0;
|
||||
let consecutiveErrors = 0;
|
||||
// Closure cursor. Seeded from the store on each connect attempt so
|
||||
// mid-session reconnects use the freshest id, but kept here too so
|
||||
// an in-flight stream's reconnect doesn't lose progress between ids
|
||||
// that the store hasn't seen yet (e.g. id-only frames).
|
||||
let lastEventId: string | null = opts.getLastEventId();
|
||||
|
||||
const notifyHealth = (h: EventStreamHealth) => {
|
||||
if (closed) return;
|
||||
opts.onHealthChange(h);
|
||||
};
|
||||
|
||||
void (async () => {
|
||||
while (!closed) {
|
||||
const baseDelay =
|
||||
BACKOFF_SCHEDULE_MS[Math.min(attempt, BACKOFF_SCHEDULE_MS.length - 1)];
|
||||
const delay = withJitter(baseDelay);
|
||||
if (delay > 0) {
|
||||
try {
|
||||
await sleep(delay, controller.signal);
|
||||
} catch {
|
||||
return; // aborted while waiting
|
||||
}
|
||||
if (closed) return;
|
||||
}
|
||||
|
||||
notifyHealth('connecting');
|
||||
|
||||
// Always re-read the store cursor before reconnecting and copy
|
||||
// it verbatim — including null. A null cursor isn't "leave
|
||||
// alone": ``backlog.truncated`` events fire ``sseLastEventIdReset``
|
||||
// to clear the slice, and the client must respect that on the
|
||||
// next attempt by sending no cursor (full-backlog replay) instead
|
||||
// of resending the stale one and re-tripping the same truncation.
|
||||
lastEventId = opts.getLastEventId();
|
||||
|
||||
const url = new URL(`${baseURL}/api/events`);
|
||||
if (lastEventId) url.searchParams.set('last_event_id', lastEventId);
|
||||
|
||||
// Auth header is omitted when token is null. Self-hosted dev
|
||||
// installs run with ``AUTH_TYPE`` unset; the backend maps those
|
||||
// requests to ``{"sub": "local"}`` so the SSE connection works
|
||||
// tokenless. When auth IS required, a missing header surfaces
|
||||
// as a 401 and the response.ok check below flips the health
|
||||
// back to unhealthy.
|
||||
const headers: Record<string, string> = {
|
||||
Accept: 'text/event-stream',
|
||||
};
|
||||
if (opts.token) {
|
||||
headers.Authorization = `Bearer ${opts.token}`;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(url.toString(), {
|
||||
method: 'GET',
|
||||
headers,
|
||||
signal: controller.signal,
|
||||
// SSE must not be cached.
|
||||
cache: 'no-store',
|
||||
});
|
||||
|
||||
if (!response.ok || !response.body) {
|
||||
notifyHealth('unhealthy');
|
||||
// 401 typically means token expired. Bail out after N
|
||||
// consecutive 401s so the loop doesn't spin forever at the
|
||||
// 30s backoff cap with a stale token; the caller is
|
||||
// responsible for refreshing auth via ``onAuthFailure``.
|
||||
if (response.status === 401) {
|
||||
consecutive401 += 1;
|
||||
consecutiveErrors += 1;
|
||||
if (consecutive401 >= MAX_CONSECUTIVE_401) {
|
||||
opts.onAuthFailure?.();
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
consecutive401 = 0;
|
||||
consecutiveErrors += 1;
|
||||
}
|
||||
if (consecutiveErrors >= MAX_CONSECUTIVE_ERRORS) {
|
||||
opts.onPermanentFailure?.();
|
||||
return;
|
||||
}
|
||||
// 429: server-side per-user concurrency cap; backoff harder.
|
||||
if (response.status === 429) attempt = Math.max(attempt, 4);
|
||||
else attempt = Math.min(attempt + 1, BACKOFF_SCHEDULE_MS.length - 1);
|
||||
continue;
|
||||
}
|
||||
consecutive401 = 0;
|
||||
|
||||
// Connection is open. Mark healthy after either:
|
||||
// - 2s of open response body (covers servers that go silent), or
|
||||
// - first record received past the 2s mark.
|
||||
// The setTimeout path means a backend that never emits a single
|
||||
// record after sending the 200 still flips us out of `connecting`.
|
||||
let healthyMarked = false;
|
||||
const markHealthy = () => {
|
||||
if (healthyMarked) return;
|
||||
healthyMarked = true;
|
||||
notifyHealth('healthy');
|
||||
attempt = 0;
|
||||
consecutiveErrors = 0;
|
||||
};
|
||||
const debounceTimer = setTimeout(markHealthy, HEALTHY_DEBOUNCE_MS);
|
||||
|
||||
try {
|
||||
await readSSEStream(response.body, controller.signal, (record) => {
|
||||
if (record.id !== undefined) {
|
||||
lastEventId = record.id || null;
|
||||
if (record.id) opts.onLastEventId?.(record.id);
|
||||
else opts.onLastEventIdReset?.();
|
||||
}
|
||||
if (record.data === undefined) {
|
||||
// Keepalive comment, id-only frame, or comment line.
|
||||
// The cursor was already advanced via ``onLastEventId``
|
||||
// above so the slice tracks ids even on frames we don't
|
||||
// dispatch as events.
|
||||
return;
|
||||
}
|
||||
// Empty data line is technically valid SSE but useless; skip.
|
||||
if (record.data.trim().length === 0) return;
|
||||
let envelope: SSEEvent | null = null;
|
||||
try {
|
||||
envelope = JSON.parse(record.data) as SSEEvent;
|
||||
} catch {
|
||||
// Malformed payload; skip.
|
||||
return;
|
||||
}
|
||||
// Defensive shape validation — the cast above lies if the
|
||||
// server (or a man-in-the-middle) sends garbage.
|
||||
if (
|
||||
!envelope ||
|
||||
typeof envelope !== 'object' ||
|
||||
typeof envelope.type !== 'string'
|
||||
) {
|
||||
return;
|
||||
}
|
||||
if (record.id && !envelope.id) envelope.id = record.id;
|
||||
// Receiving a real envelope post-debounce-window flips
|
||||
// healthy if the timer hasn't already.
|
||||
markHealthy();
|
||||
// Every tab dispatches every envelope it receives into its
|
||||
// own Redux store. With N tabs open this means N copies of
|
||||
// the same toast — accepted as a v1 limitation; cross-tab
|
||||
// dedup via BroadcastChannel + navigator.locks is future
|
||||
// work. Toast-level suppression can be handled per surface.
|
||||
opts.onEvent(envelope);
|
||||
});
|
||||
} finally {
|
||||
clearTimeout(debounceTimer);
|
||||
}
|
||||
|
||||
// The reader returned without abort — server closed the stream.
|
||||
// Fall through to reconnect.
|
||||
notifyHealth('unhealthy');
|
||||
consecutiveErrors += 1;
|
||||
if (consecutiveErrors >= MAX_CONSECUTIVE_ERRORS) {
|
||||
opts.onPermanentFailure?.();
|
||||
return;
|
||||
}
|
||||
attempt = Math.min(attempt + 1, BACKOFF_SCHEDULE_MS.length - 1);
|
||||
} catch (err) {
|
||||
if (
|
||||
closed ||
|
||||
(err instanceof DOMException && err.name === 'AbortError')
|
||||
) {
|
||||
return;
|
||||
}
|
||||
notifyHealth('unhealthy');
|
||||
consecutiveErrors += 1;
|
||||
if (consecutiveErrors >= MAX_CONSECUTIVE_ERRORS) {
|
||||
opts.onPermanentFailure?.();
|
||||
return;
|
||||
}
|
||||
attempt = Math.min(attempt + 1, BACKOFF_SCHEDULE_MS.length - 1);
|
||||
}
|
||||
}
|
||||
})();
|
||||
|
||||
return {
|
||||
close() {
|
||||
if (closed) return;
|
||||
closed = true;
|
||||
controller.abort();
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
interface ParsedSSERecord {
|
||||
/**
|
||||
* ``undefined`` when the record had no ``id`` field at all. An empty
|
||||
* string means the server explicitly reset the cursor (an ``id:``
|
||||
* line with no value, per WHATWG SSE).
|
||||
*/
|
||||
id?: string;
|
||||
/** ``undefined`` for keepalive comments / id-only frames. */
|
||||
data?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Drain a ``ReadableStream<Uint8Array>`` of ``\n\n``-delimited SSE records,
|
||||
* forwarding each parsed record to ``onRecord``. Honours the WHATWG SSE
|
||||
* spec's mixed line-terminator handling and SSE comment lines.
|
||||
*/
|
||||
async function readSSEStream(
|
||||
body: ReadableStream<Uint8Array>,
|
||||
signal: AbortSignal,
|
||||
onRecord: (record: ParsedSSERecord) => void,
|
||||
): Promise<void> {
|
||||
const reader = body.getReader();
|
||||
const decoder = new TextDecoder('utf-8');
|
||||
let buffer = '';
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
if (signal.aborted) return;
|
||||
const { done, value } = await reader.read();
|
||||
if (done) return;
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
// SSE records are separated by a blank line. WHATWG spec accepts
|
||||
// CRLF, CR, or LF — normalise so a stray CR can't smuggle a
|
||||
// boundary mid-record.
|
||||
buffer = buffer.replace(/\r\n/g, '\n').replace(/\r/g, '\n');
|
||||
|
||||
let boundary = buffer.indexOf('\n\n');
|
||||
while (boundary !== -1) {
|
||||
const raw = buffer.slice(0, boundary);
|
||||
buffer = buffer.slice(boundary + 2);
|
||||
const record = parseSSERecord(raw);
|
||||
if (record) onRecord(record);
|
||||
boundary = buffer.indexOf('\n\n');
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
try {
|
||||
reader.releaseLock();
|
||||
} catch {
|
||||
// Already released.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function parseSSERecord(raw: string): ParsedSSERecord | null {
|
||||
if (raw.length === 0) return null;
|
||||
const lines = raw.split('\n');
|
||||
let id: string | undefined;
|
||||
const dataParts: string[] = [];
|
||||
let sawDataField = false;
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.length === 0) continue;
|
||||
if (line.startsWith(':')) continue; // SSE comment / keepalive
|
||||
const colonIdx = line.indexOf(':');
|
||||
const field = colonIdx === -1 ? line : line.slice(0, colonIdx);
|
||||
let value = colonIdx === -1 ? '' : line.slice(colonIdx + 1);
|
||||
// SSE: value may be prefixed by exactly one optional space.
|
||||
if (value.startsWith(' ')) value = value.slice(1);
|
||||
|
||||
if (field === 'id') {
|
||||
id = value;
|
||||
} else if (field === 'data') {
|
||||
sawDataField = true;
|
||||
dataParts.push(value);
|
||||
}
|
||||
// Other field names ('event', 'retry') are ignored for now.
|
||||
}
|
||||
|
||||
if (!sawDataField && id === undefined) return null;
|
||||
return {
|
||||
id,
|
||||
data: sawDataField ? dataParts.join('\n') : undefined,
|
||||
};
|
||||
}
|
||||
|
||||
function sleep(ms: number, signal: AbortSignal): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (signal.aborted) {
|
||||
reject(new DOMException('Aborted', 'AbortError'));
|
||||
return;
|
||||
}
|
||||
const timer = setTimeout(() => {
|
||||
signal.removeEventListener('abort', onAbort);
|
||||
resolve();
|
||||
}, ms);
|
||||
const onAbort = () => {
|
||||
clearTimeout(timer);
|
||||
signal.removeEventListener('abort', onAbort);
|
||||
reject(new DOMException('Aborted', 'AbortError'));
|
||||
};
|
||||
signal.addEventListener('abort', onAbort, { once: true });
|
||||
});
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
import { useEffect } from 'react';
|
||||
import { useDispatch, useSelector, useStore } from 'react-redux';
|
||||
|
||||
import {
|
||||
selectLastEventId,
|
||||
sseHealthChanged,
|
||||
sseLastEventIdAdvanced,
|
||||
sseLastEventIdReset,
|
||||
} from '../notifications/notificationsSlice';
|
||||
import { selectToken, setToken } from '../preferences/preferenceSlice';
|
||||
import type { AppDispatch, RootState } from '../store';
|
||||
|
||||
import { connectEventStream } from './eventStreamClient';
|
||||
import { dispatchSSEEvent } from './dispatchEvent';
|
||||
|
||||
/**
|
||||
* Open the SSE connection for the current token and keep it alive for
|
||||
* the lifetime of the host component. Recreates the connection on
|
||||
* token change (login / refresh).
|
||||
*
|
||||
* The ``lastEventId`` cursor is read lazily from the slice on each
|
||||
* connect attempt via ``store.getState()`` — capturing it at mount time
|
||||
* would silently re-replay the entire 24h backlog on token rotation,
|
||||
* since the slice's id advances during the previous connection's
|
||||
* lifetime but a snapshot ref would still hold the value seen at
|
||||
* first mount.
|
||||
*/
|
||||
export function useEventStream(): void {
|
||||
const dispatch = useDispatch<AppDispatch>();
|
||||
const token = useSelector(selectToken);
|
||||
const store = useStore<RootState>();
|
||||
|
||||
useEffect(() => {
|
||||
// Connect even when token is null. Self-hosted dev installs run
|
||||
// with ``AUTH_TYPE`` unset, where ``handle_auth`` maps every
|
||||
// request to ``{"sub": "local"}`` regardless of headers — gating
|
||||
// the connection on a populated token would silently disable push
|
||||
// notifications for the most common configuration. When auth IS
|
||||
// required and token is null, the backend will 401 and the
|
||||
// health state will flip to ``unhealthy`` via the response check
|
||||
// inside ``connectEventStream``.
|
||||
const conn = connectEventStream({
|
||||
token,
|
||||
getLastEventId: () => selectLastEventId(store.getState()),
|
||||
onEvent: (envelope) => dispatchSSEEvent(envelope, dispatch),
|
||||
// Advance the slice cursor for every id-bearing frame. Each tab
|
||||
// owns an independent SSE connection and Redux store, so every
|
||||
// active tab tracks its own replay cursor.
|
||||
onLastEventId: (id) => dispatch(sseLastEventIdAdvanced(id)),
|
||||
// Server emitted ``id:`` with an empty value — WHATWG cursor reset.
|
||||
// Mirror the slice so the next reconnect doesn't resend a stale id.
|
||||
onLastEventIdReset: () => dispatch(sseLastEventIdReset()),
|
||||
onHealthChange: (health) => dispatch(sseHealthChanged(health)),
|
||||
// SSE 401 loop bail-out. Clear the stored token AND dispatch
|
||||
// ``setToken(null)`` so ``useAuth`` regenerates a fresh
|
||||
// ``session_jwt`` in-session; the Redux change also flips this
|
||||
// hook's ``[token]`` dep, tearing down and respawning the
|
||||
// connection with the new token. Without the dispatch a
|
||||
// ``session_jwt`` user is stuck until a hard reload.
|
||||
onAuthFailure: () => {
|
||||
console.error(
|
||||
'[useEventStream] giving up after repeated 401s on /api/events',
|
||||
);
|
||||
try {
|
||||
localStorage.removeItem('authToken');
|
||||
} catch {
|
||||
// localStorage unavailable (private mode, etc.) — nothing to do.
|
||||
}
|
||||
dispatch(setToken(null));
|
||||
},
|
||||
// Surface a warning when the non-401 error budget is exhausted so
|
||||
// the connection going dark isn't completely silent. Doesn't block
|
||||
// UI — just observable in devtools.
|
||||
onPermanentFailure: () => {
|
||||
console.warn(
|
||||
'[useEventStream] SSE connection failed permanently after repeated errors',
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
return () => {
|
||||
conn.close();
|
||||
};
|
||||
}, [token, dispatch, store]);
|
||||
}
|
||||
@@ -1,4 +1,10 @@
|
||||
import { useCallback, useEffect, useRef, useState, RefObject } from 'react';
|
||||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
RefObject,
|
||||
} from 'react';
|
||||
|
||||
export function useOutsideAlerter<T extends HTMLElement>(
|
||||
ref: RefObject<T | null>,
|
||||
|
||||
@@ -15,39 +15,24 @@ export default function useAuth() {
|
||||
const generateNewToken = async () => {
|
||||
if (isGeneratingToken.current) return;
|
||||
isGeneratingToken.current = true;
|
||||
try {
|
||||
const response = await userService.getNewToken();
|
||||
const { token: newToken } = await response.json();
|
||||
localStorage.setItem('authToken', newToken);
|
||||
dispatch(setToken(newToken));
|
||||
setIsAuthLoading(false);
|
||||
return newToken;
|
||||
} finally {
|
||||
// Reset so a subsequent ``setToken(null)`` (SSE 401 recovery)
|
||||
// can trigger another generation. Without this the in-flight
|
||||
// guard would latch true forever after the first call.
|
||||
isGeneratingToken.current = false;
|
||||
}
|
||||
const response = await userService.getNewToken();
|
||||
const { token: newToken } = await response.json();
|
||||
localStorage.setItem('authToken', newToken);
|
||||
dispatch(setToken(newToken));
|
||||
setIsAuthLoading(false);
|
||||
return newToken;
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
// Re-fires when ``token`` flips to null mid-session (e.g.
|
||||
// ``useEventStream`` dispatches ``setToken(null)`` after repeated
|
||||
// SSE 401s) so ``session_jwt`` users get a fresh token without a
|
||||
// hard reload. ``authType`` short-circuits on subsequent runs.
|
||||
const initializeAuth = async () => {
|
||||
try {
|
||||
let resolvedAuthType = authType;
|
||||
if (resolvedAuthType === null) {
|
||||
const configRes = await userService.getConfig();
|
||||
const config = await configRes.json();
|
||||
resolvedAuthType = config.auth_type;
|
||||
setAuthType(resolvedAuthType);
|
||||
}
|
||||
const configRes = await userService.getConfig();
|
||||
const config = await configRes.json();
|
||||
setAuthType(config.auth_type);
|
||||
|
||||
if (resolvedAuthType === 'session_jwt' && !token) {
|
||||
if (config.auth_type === 'session_jwt' && !token) {
|
||||
await generateNewToken();
|
||||
} else if (resolvedAuthType === 'simple_jwt' && !token) {
|
||||
} else if (config.auth_type === 'simple_jwt' && !token) {
|
||||
setShowTokenModal(true);
|
||||
setIsAuthLoading(false);
|
||||
} else {
|
||||
@@ -59,7 +44,7 @@ export default function useAuth() {
|
||||
}
|
||||
};
|
||||
initializeAuth();
|
||||
}, [token, authType]);
|
||||
}, []);
|
||||
|
||||
const handleTokenSubmit = (enteredToken: string) => {
|
||||
localStorage.setItem('authToken', enteredToken);
|
||||
|
||||
@@ -70,9 +70,6 @@
|
||||
"sync": "Synchronisieren",
|
||||
"syncNow": "Jetzt synchronisieren",
|
||||
"syncing": "Synchronisiere...",
|
||||
"reingest": "Erneut indexieren",
|
||||
"ingestFailed": "Indexierung fehlgeschlagen",
|
||||
"ingestProcessing": "Indexierung...",
|
||||
"syncConfirmation": "Bist du sicher, dass du \"{{sourceName}}\" synchronisieren möchtest? Dies aktualisiert den Inhalt mit deinem Cloud-Speicher und kann Änderungen an einzelnen Chunks überschreiben.",
|
||||
"syncFrequency": {
|
||||
"never": "Nie",
|
||||
@@ -356,8 +353,6 @@
|
||||
"failed": "Upload fehlgeschlagen",
|
||||
"wait": "Dies kann einige Minuten dauern",
|
||||
"preparing": "Upload wird vorbereitet",
|
||||
"parsing": "Dateien werden verarbeitet...",
|
||||
"embedding": "Einbettung...",
|
||||
"tokenLimit": "Token-Limit überschritten, bitte lade ein kleineres Dokument hoch",
|
||||
"expandDetails": "Upload-Details erweitern",
|
||||
"collapseDetails": "Upload-Details einklappen",
|
||||
@@ -461,6 +456,11 @@
|
||||
"create": "Erstellen",
|
||||
"option": "Benutzern weitere Eingaben erlauben"
|
||||
},
|
||||
"searchConversations": {
|
||||
"searchPlaceholder": "Konversationen durchsuchen",
|
||||
"noResults": "Keine Ergebnisse gefunden",
|
||||
"loading": "Laden..."
|
||||
},
|
||||
"configTool": {
|
||||
"title": "Werkzeug-Konfiguration",
|
||||
"type": "Typ",
|
||||
|
||||
@@ -70,9 +70,6 @@
|
||||
"sync": "Sync",
|
||||
"syncNow": "Sync now",
|
||||
"syncing": "Syncing...",
|
||||
"reingest": "Reingest",
|
||||
"ingestFailed": "Indexing failed",
|
||||
"ingestProcessing": "Indexing…",
|
||||
"syncConfirmation": "Are you sure you want to sync \"{{sourceName}}\"? This will update the content with your cloud storage and may override any edits you made to individual chunks.",
|
||||
"syncFrequency": {
|
||||
"never": "Never",
|
||||
@@ -368,8 +365,6 @@
|
||||
"failed": "Upload failed",
|
||||
"wait": "This may take several minutes",
|
||||
"preparing": "Preparing upload",
|
||||
"parsing": "Parsing files…",
|
||||
"embedding": "Embedding…",
|
||||
"tokenLimit": "Over the token limit, please consider uploading smaller document",
|
||||
"expandDetails": "Expand upload details",
|
||||
"collapseDetails": "Collapse upload details",
|
||||
@@ -491,6 +486,11 @@
|
||||
"create": "Create",
|
||||
"option": "Allow users to prompt further"
|
||||
},
|
||||
"searchConversations": {
|
||||
"searchPlaceholder": "Search conversations",
|
||||
"noResults": "No results found",
|
||||
"loading": "Loading..."
|
||||
},
|
||||
"configTool": {
|
||||
"title": "Tool Config",
|
||||
"type": "Type",
|
||||
|
||||
@@ -70,9 +70,6 @@
|
||||
"sync": "Sincronizar",
|
||||
"syncNow": "Sincronizar ahora",
|
||||
"syncing": "Sincronizando...",
|
||||
"reingest": "Reindexar",
|
||||
"ingestFailed": "Error de indexación",
|
||||
"ingestProcessing": "Indexando...",
|
||||
"syncConfirmation": "¿Estás seguro de que deseas sincronizar \"{{sourceName}}\"? Esto actualizará el contenido con tu almacenamiento en la nube y puede anular cualquier edición que hayas realizado en fragmentos individuales.",
|
||||
"syncFrequency": {
|
||||
"never": "Nunca",
|
||||
@@ -356,8 +353,6 @@
|
||||
"failed": "Error al subir",
|
||||
"wait": "Esto puede tardar varios minutos",
|
||||
"preparing": "Preparando subida",
|
||||
"parsing": "Analizando archivos...",
|
||||
"embedding": "Generando incrustaciones...",
|
||||
"tokenLimit": "Excede el límite de tokens, considere cargar un documento más pequeño",
|
||||
"expandDetails": "Expandir detalles de subida",
|
||||
"collapseDetails": "Contraer detalles de subida",
|
||||
@@ -479,6 +474,11 @@
|
||||
"create": "Crear",
|
||||
"option": "Permitir a los usuarios realizar más consultas"
|
||||
},
|
||||
"searchConversations": {
|
||||
"searchPlaceholder": "Buscar conversaciones",
|
||||
"noResults": "No se encontraron resultados",
|
||||
"loading": "Cargando..."
|
||||
},
|
||||
"configTool": {
|
||||
"title": "Configuración de la Herramienta",
|
||||
"type": "Tipo",
|
||||
|
||||
@@ -70,9 +70,6 @@
|
||||
"sync": "同期",
|
||||
"syncNow": "今すぐ同期",
|
||||
"syncing": "同期中...",
|
||||
"reingest": "再インデックス",
|
||||
"ingestFailed": "インデックス作成に失敗しました",
|
||||
"ingestProcessing": "インデックス作成中...",
|
||||
"syncConfirmation": "\"{{sourceName}}\"を同期してもよろしいですか?これにより、コンテンツがクラウドストレージで更新され、個々のチャンクに加えた編集が上書きされる可能性があります。",
|
||||
"syncFrequency": {
|
||||
"never": "なし",
|
||||
@@ -356,8 +353,6 @@
|
||||
"failed": "アップロード失敗",
|
||||
"wait": "数分かかる場合があります",
|
||||
"preparing": "アップロードを準備中",
|
||||
"parsing": "ファイルを解析中...",
|
||||
"embedding": "埋め込み処理中...",
|
||||
"tokenLimit": "トークン制限を超えています。より小さいドキュメントをアップロードしてください",
|
||||
"expandDetails": "アップロードの詳細を展開",
|
||||
"collapseDetails": "アップロードの詳細を折りたたむ",
|
||||
@@ -479,6 +474,11 @@
|
||||
"create": "作成",
|
||||
"option": "ユーザーがより多くのクエリを実行できるようにします。"
|
||||
},
|
||||
"searchConversations": {
|
||||
"searchPlaceholder": "会話を検索",
|
||||
"noResults": "結果が見つかりません",
|
||||
"loading": "読み込み中..."
|
||||
},
|
||||
"configTool": {
|
||||
"title": "ツール設定",
|
||||
"type": "タイプ",
|
||||
|
||||
@@ -70,9 +70,6 @@
|
||||
"sync": "Синхронизация",
|
||||
"syncNow": "Синхронизировать сейчас",
|
||||
"syncing": "Синхронизация...",
|
||||
"reingest": "Переиндексировать",
|
||||
"ingestFailed": "Ошибка индексации",
|
||||
"ingestProcessing": "Индексация...",
|
||||
"syncConfirmation": "Вы уверены, что хотите синхронизировать \"{{sourceName}}\"? Это обновит содержимое с вашим облачным хранилищем и может перезаписать любые изменения, внесенные вами в отдельные фрагменты.",
|
||||
"syncFrequency": {
|
||||
"never": "Никогда",
|
||||
@@ -356,8 +353,6 @@
|
||||
"failed": "Ошибка загрузки",
|
||||
"wait": "Это может занять несколько минут",
|
||||
"preparing": "Подготовка загрузки",
|
||||
"parsing": "Обработка файлов...",
|
||||
"embedding": "Создание эмбеддингов...",
|
||||
"tokenLimit": "Превышен лимит токенов, рассмотрите возможность загрузки документа меньшего размера",
|
||||
"expandDetails": "Развернуть детали загрузки",
|
||||
"collapseDetails": "Свернуть детали загрузки",
|
||||
@@ -479,6 +474,11 @@
|
||||
"create": "Создать",
|
||||
"option": "Позволить пользователям делать дополнительные запросы."
|
||||
},
|
||||
"searchConversations": {
|
||||
"searchPlaceholder": "Поиск разговоров",
|
||||
"noResults": "Результаты не найдены",
|
||||
"loading": "Загрузка..."
|
||||
},
|
||||
"configTool": {
|
||||
"title": "Настройка инструмента",
|
||||
"type": "Тип",
|
||||
|
||||
@@ -70,9 +70,6 @@
|
||||
"sync": "同步",
|
||||
"syncNow": "立即同步",
|
||||
"syncing": "同步中...",
|
||||
"reingest": "重新索引",
|
||||
"ingestFailed": "索引失敗",
|
||||
"ingestProcessing": "索引中...",
|
||||
"syncConfirmation": "您確定要同步 \"{{sourceName}}\" 嗎?這將使用您的雲端儲存更新內容,並可能覆蓋您對個別文本塊所做的任何編輯。",
|
||||
"syncFrequency": {
|
||||
"never": "從不",
|
||||
@@ -356,8 +353,6 @@
|
||||
"failed": "上傳失敗",
|
||||
"wait": "這可能需要幾分鐘",
|
||||
"preparing": "準備上傳",
|
||||
"parsing": "正在解析檔案...",
|
||||
"embedding": "正在生成嵌入...",
|
||||
"tokenLimit": "超出令牌限制,請考慮上傳較小的文檔",
|
||||
"expandDetails": "展開上傳詳情",
|
||||
"collapseDetails": "摺疊上傳詳情",
|
||||
@@ -479,6 +474,11 @@
|
||||
"create": "建立",
|
||||
"option": "允許使用者進行更多查詢"
|
||||
},
|
||||
"searchConversations": {
|
||||
"searchPlaceholder": "搜尋對話",
|
||||
"noResults": "未找到結果",
|
||||
"loading": "載入中..."
|
||||
},
|
||||
"configTool": {
|
||||
"title": "工具設定",
|
||||
"type": "類型",
|
||||
|
||||
@@ -70,9 +70,6 @@
|
||||
"sync": "同步",
|
||||
"syncNow": "立即同步",
|
||||
"syncing": "同步中...",
|
||||
"reingest": "重新索引",
|
||||
"ingestFailed": "索引失败",
|
||||
"ingestProcessing": "索引中...",
|
||||
"syncConfirmation": "您确定要同步 \"{{sourceName}}\" 吗?这将使用您的云存储更新内容,并可能覆盖您对单个文本块所做的任何编辑。",
|
||||
"syncFrequency": {
|
||||
"never": "从不",
|
||||
@@ -356,8 +353,6 @@
|
||||
"failed": "上传失败",
|
||||
"wait": "这可能需要几分钟",
|
||||
"preparing": "准备上传",
|
||||
"parsing": "正在解析文件...",
|
||||
"embedding": "正在生成嵌入...",
|
||||
"tokenLimit": "超出令牌限制,请考虑上传较小的文档",
|
||||
"expandDetails": "展开上传详情",
|
||||
"collapseDetails": "折叠上传详情",
|
||||
@@ -479,6 +474,11 @@
|
||||
"create": "创建",
|
||||
"option": "允许用户进行更多查询。"
|
||||
},
|
||||
"searchConversations": {
|
||||
"searchPlaceholder": "搜索对话",
|
||||
"noResults": "未找到结果",
|
||||
"loading": "加载中..."
|
||||
},
|
||||
"configTool": {
|
||||
"title": "工具配置",
|
||||
"type": "类型",
|
||||
|
||||
@@ -15,7 +15,6 @@ import {
|
||||
SelectValue,
|
||||
} from '../components/ui/select';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import { selectRecentEvents } from '../notifications/notificationsSlice';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
import WrapperComponent from './WrapperModal';
|
||||
|
||||
@@ -34,7 +33,6 @@ export default function MCPServerModal({
|
||||
}: MCPServerModalProps) {
|
||||
const { t } = useTranslation();
|
||||
const token = useSelector(selectToken);
|
||||
const recentEvents = useSelector(selectRecentEvents);
|
||||
|
||||
const authTypes = [
|
||||
{ label: t('settings.tools.mcp.authTypes.none'), value: 'none' },
|
||||
@@ -73,29 +71,17 @@ export default function MCPServerModal({
|
||||
>([]);
|
||||
const [errors, setErrors] = useState<{ [key: string]: string }>({});
|
||||
const oauthPopupRef = useRef<Window | null>(null);
|
||||
// Set after ``test_mcp_connection`` returns ``task_id``. The SSE
|
||||
// effect filters ``recentEvents`` to envelopes matching this id and
|
||||
// drives the OAuth UI (popup open / completion / failure) from the
|
||||
// push stream rather than polling the legacy status endpoint.
|
||||
const [oauthTaskId, setOauthTaskId] = useState<string | null>(null);
|
||||
// Highest event id we have already reacted to for this taskId. Each
|
||||
// mcp.oauth.* envelope must fire its side-effect once; without this
|
||||
// any later re-render that re-evaluates ``recentEvents`` would
|
||||
// re-open the popup or re-fire onComplete.
|
||||
const handledEventIdsRef = useRef<Set<string>>(new Set());
|
||||
// Holds the ``testConnection`` ``onComplete`` for the current
|
||||
// task id so the SSE effect can invoke it when the terminal event
|
||||
// lands. Reset to ``null`` on cancel / new test / unmount.
|
||||
const onCompleteRef = useRef<((result: any) => void) | null>(null);
|
||||
const popupOpenedRef = useRef(false);
|
||||
const pollingCancelledRef = useRef(false);
|
||||
const pollTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
const [oauthCompleted, setOAuthCompleted] = useState(false);
|
||||
const [saveActive, setSaveActive] = useState(false);
|
||||
|
||||
const cleanupOAuthListener = useCallback(() => {
|
||||
setOauthTaskId(null);
|
||||
handledEventIdsRef.current = new Set();
|
||||
onCompleteRef.current = null;
|
||||
popupOpenedRef.current = false;
|
||||
const cleanupPolling = useCallback(() => {
|
||||
pollingCancelledRef.current = true;
|
||||
if (pollTimerRef.current) {
|
||||
clearTimeout(pollTimerRef.current);
|
||||
pollTimerRef.current = null;
|
||||
}
|
||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
||||
oauthPopupRef.current.close();
|
||||
}
|
||||
@@ -103,8 +89,8 @@ export default function MCPServerModal({
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
return cleanupOAuthListener;
|
||||
}, [cleanupOAuthListener]);
|
||||
return cleanupPolling;
|
||||
}, [cleanupPolling]);
|
||||
|
||||
useEffect(() => {
|
||||
if (modalState === 'ACTIVE' && server) {
|
||||
@@ -133,7 +119,7 @@ export default function MCPServerModal({
|
||||
}, [modalState, server]);
|
||||
|
||||
const resetForm = () => {
|
||||
cleanupOAuthListener();
|
||||
cleanupPolling();
|
||||
setFormData({
|
||||
name: t('settings.tools.mcp.defaultServerName'),
|
||||
server_url: '',
|
||||
@@ -242,123 +228,114 @@ export default function MCPServerModal({
|
||||
return config;
|
||||
};
|
||||
|
||||
/**
|
||||
* Drive the OAuth handshake straight from the SSE stream:
|
||||
*
|
||||
* - ``mcp.oauth.awaiting_redirect`` → open the popup with the
|
||||
* ``authorization_url`` carried on the envelope. Previously this URL
|
||||
* came from polling ``/api/mcp_server/oauth_status/<task_id>``; the
|
||||
* worker now publishes it inline so we never need to poll.
|
||||
* - ``mcp.oauth.completed`` → enable Save, surface discovered tools,
|
||||
* invoke ``onComplete`` (resolves ``testConnection``'s pending state).
|
||||
* - ``mcp.oauth.failed`` → surface the error and reset Save.
|
||||
*
|
||||
* Each event is matched to the active task id via ``scope.id``. The
|
||||
* publisher is best-effort: a lost ``awaiting_redirect`` envelope
|
||||
* means the popup never opens, the user retries, and we accept that
|
||||
* over the prior 1s × 60 polling loop.
|
||||
*/
|
||||
useEffect(() => {
|
||||
if (!oauthTaskId) return;
|
||||
// ``recentEvents`` is newest-first (the slice ``unshift``s on
|
||||
// arrival). Walk it oldest-first so we observe the natural OAuth
|
||||
// ordering (``awaiting_redirect`` → ``completed``) when both
|
||||
// arrive between effect runs — otherwise we would short-circuit
|
||||
// on ``completed`` and never open the popup for the
|
||||
// ``awaiting_redirect`` envelope that was already buffered.
|
||||
for (let i = recentEvents.length - 1; i >= 0; i--) {
|
||||
const event = recentEvents[i];
|
||||
if (event.scope?.id !== oauthTaskId) continue;
|
||||
if (!event.id || handledEventIdsRef.current.has(event.id)) continue;
|
||||
const pollOAuthStatus = async (
|
||||
taskId: string,
|
||||
onComplete: (result: any) => void,
|
||||
) => {
|
||||
let attempts = 0;
|
||||
const maxAttempts = 60;
|
||||
let popupOpened = false;
|
||||
pollingCancelledRef.current = false;
|
||||
|
||||
const payload = (event.payload || {}) as Record<string, unknown>;
|
||||
const poll = async () => {
|
||||
if (pollingCancelledRef.current) return;
|
||||
try {
|
||||
const resp = await userService.getMCPOAuthStatus(taskId, token);
|
||||
if (pollingCancelledRef.current) return;
|
||||
const data = await resp.json();
|
||||
if (pollingCancelledRef.current) return;
|
||||
|
||||
if (event.type === 'mcp.oauth.awaiting_redirect') {
|
||||
handledEventIdsRef.current.add(event.id);
|
||||
const authUrl = payload.authorization_url as string | undefined;
|
||||
if (authUrl && !popupOpenedRef.current) {
|
||||
popupOpenedRef.current = true;
|
||||
if (data.authorization_url && !popupOpened) {
|
||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
||||
oauthPopupRef.current.close();
|
||||
}
|
||||
oauthPopupRef.current = window.open(
|
||||
authUrl,
|
||||
data.authorization_url,
|
||||
'oauthPopup',
|
||||
'width=600,height=700',
|
||||
);
|
||||
popupOpened = true;
|
||||
|
||||
if (!oauthPopupRef.current) {
|
||||
// Popup blocked — surface the URL inline so the user can
|
||||
// click through manually. Browsers gate ``window.open``
|
||||
// outside of a user gesture, and the SSE event arrives
|
||||
// asynchronously, so a blocked popup is expected on
|
||||
// some browsers / configs.
|
||||
setTestResult({
|
||||
success: true,
|
||||
message: t('settings.tools.mcp.oauthPopupBlocked', {
|
||||
defaultValue:
|
||||
'Popup blocked by browser. Click below to authorize:',
|
||||
}),
|
||||
authorization_url: authUrl,
|
||||
authorization_url: data.authorization_url,
|
||||
});
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (event.type === 'mcp.oauth.completed') {
|
||||
handledEventIdsRef.current.add(event.id);
|
||||
const tools = Array.isArray(payload.tools) ? payload.tools : [];
|
||||
const toolsCount =
|
||||
(payload.tools_count as number | undefined) ?? tools.length;
|
||||
setOAuthCompleted(true);
|
||||
setSaveActive(true);
|
||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
||||
oauthPopupRef.current.close();
|
||||
}
|
||||
const cb = onCompleteRef.current;
|
||||
onCompleteRef.current = null;
|
||||
setOauthTaskId(null);
|
||||
if (cb) {
|
||||
cb({
|
||||
status: 'completed',
|
||||
task_id: oauthTaskId,
|
||||
tools,
|
||||
tools_count: toolsCount,
|
||||
const callbackReceived =
|
||||
data.status === 'callback_received' || data.status === 'completed';
|
||||
|
||||
if (data.status === 'completed') {
|
||||
setOAuthCompleted(true);
|
||||
setSaveActive(true);
|
||||
onComplete({
|
||||
...data,
|
||||
success: true,
|
||||
message: t('settings.tools.mcp.oauthCompleted'),
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (event.type === 'mcp.oauth.failed') {
|
||||
handledEventIdsRef.current.add(event.id);
|
||||
const message =
|
||||
(payload.error as string) ??
|
||||
t('settings.tools.mcp.errors.oauthFailed');
|
||||
setSaveActive(false);
|
||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
||||
oauthPopupRef.current.close();
|
||||
}
|
||||
const cb = onCompleteRef.current;
|
||||
onCompleteRef.current = null;
|
||||
setOauthTaskId(null);
|
||||
if (cb) {
|
||||
cb({
|
||||
status: 'error',
|
||||
task_id: oauthTaskId,
|
||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
||||
oauthPopupRef.current.close();
|
||||
}
|
||||
} else if (data.status === 'error' || data.success === false) {
|
||||
setSaveActive(false);
|
||||
onComplete({
|
||||
...data,
|
||||
success: false,
|
||||
message,
|
||||
message: data.message || t('settings.tools.mcp.errors.oauthFailed'),
|
||||
});
|
||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
||||
oauthPopupRef.current.close();
|
||||
}
|
||||
} else {
|
||||
if (++attempts < maxAttempts) {
|
||||
if (
|
||||
oauthPopupRef.current &&
|
||||
oauthPopupRef.current.closed &&
|
||||
popupOpened &&
|
||||
!callbackReceived
|
||||
) {
|
||||
setSaveActive(false);
|
||||
onComplete({
|
||||
success: false,
|
||||
message: t('settings.tools.mcp.errors.oauthFailed'),
|
||||
});
|
||||
return;
|
||||
}
|
||||
pollTimerRef.current = setTimeout(poll, 1000);
|
||||
} else {
|
||||
setSaveActive(false);
|
||||
cleanupPolling();
|
||||
onComplete({
|
||||
success: false,
|
||||
message: t('settings.tools.mcp.errors.oauthTimeout'),
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
if (pollingCancelledRef.current) return;
|
||||
if (++attempts < maxAttempts) {
|
||||
pollTimerRef.current = setTimeout(poll, 1000);
|
||||
} else {
|
||||
cleanupPolling();
|
||||
onComplete({
|
||||
success: false,
|
||||
message: t('settings.tools.mcp.errors.oauthTimeout'),
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}, [recentEvents, oauthTaskId, t]);
|
||||
};
|
||||
poll();
|
||||
};
|
||||
|
||||
const testConnection = async () => {
|
||||
if (!validateForm()) return;
|
||||
cleanupOAuthListener();
|
||||
cleanupPolling();
|
||||
setTesting(true);
|
||||
setTestResult(null);
|
||||
setDiscoveredTools([]);
|
||||
@@ -378,7 +355,7 @@ export default function MCPServerModal({
|
||||
message: t('settings.tools.mcp.oauthInProgress'),
|
||||
});
|
||||
setSaveActive(false);
|
||||
onCompleteRef.current = (finalResult: any) => {
|
||||
pollOAuthStatus(result.task_id, (finalResult) => {
|
||||
setTestResult(finalResult);
|
||||
if (finalResult.tools && Array.isArray(finalResult.tools)) {
|
||||
setDiscoveredTools(finalResult.tools);
|
||||
@@ -388,11 +365,7 @@ export default function MCPServerModal({
|
||||
oauth_task_id: result.task_id || '',
|
||||
}));
|
||||
setTesting(false);
|
||||
};
|
||||
// Activate the SSE listener for this task id. The effect above
|
||||
// will react when ``mcp.oauth.{awaiting_redirect,completed,failed}``
|
||||
// arrives.
|
||||
setOauthTaskId(result.task_id);
|
||||
});
|
||||
} else {
|
||||
setTestResult(result);
|
||||
if (result.success && result.tools && Array.isArray(result.tools)) {
|
||||
|
||||
232
frontend/src/modals/SearchConversationsModal.tsx
Normal file
232
frontend/src/modals/SearchConversationsModal.tsx
Normal file
@@ -0,0 +1,232 @@
|
||||
import {
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
type KeyboardEvent,
|
||||
} from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import SearchIcon from '../assets/search.svg';
|
||||
import { searchConversations } from '../preferences/preferenceApi';
|
||||
import WrapperModal from './WrapperModal';
|
||||
|
||||
type ConversationListItem = {
|
||||
id: string;
|
||||
name: string;
|
||||
match_field?: 'name' | 'prompt' | 'response' | null;
|
||||
match_snippet?: string | null;
|
||||
};
|
||||
|
||||
type SearchConversationsModalProps = {
|
||||
close: () => void;
|
||||
conversations: ConversationListItem[];
|
||||
token: string | null;
|
||||
onSelectConversation: (id: string) => void;
|
||||
};
|
||||
|
||||
// Escape regex metacharacters so the user query can be used in a RegExp.
|
||||
function escapeRegExp(value: string): string {
|
||||
return value.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
}
|
||||
|
||||
function HighlightedText({ text, query }: { text: string; query: string }) {
|
||||
const trimmed = query.trim();
|
||||
if (!trimmed) return <>{text}</>;
|
||||
const parts = text.split(new RegExp(`(${escapeRegExp(trimmed)})`, 'gi'));
|
||||
return (
|
||||
<>
|
||||
{parts.map((part, idx) =>
|
||||
part.toLowerCase() === trimmed.toLowerCase() ? (
|
||||
<mark
|
||||
key={idx}
|
||||
className="text-purple-30 bg-transparent font-semibold"
|
||||
>
|
||||
{part}
|
||||
</mark>
|
||||
) : (
|
||||
<span key={idx}>{part}</span>
|
||||
),
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default function SearchConversationsModal({
|
||||
close,
|
||||
conversations,
|
||||
token,
|
||||
onSelectConversation,
|
||||
}: SearchConversationsModalProps) {
|
||||
const { t } = useTranslation();
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const resultRefs = useRef<Array<HTMLButtonElement | null>>([]);
|
||||
|
||||
const [query, setQuery] = useState('');
|
||||
const [results, setResults] = useState<ConversationListItem[] | null>(null);
|
||||
const [isSearching, setIsSearching] = useState(false);
|
||||
const [activeIndex, setActiveIndex] = useState(-1);
|
||||
|
||||
useEffect(() => {
|
||||
inputRef.current?.focus();
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const trimmed = query.trim();
|
||||
if (!trimmed) {
|
||||
setResults(null);
|
||||
setIsSearching(false);
|
||||
return;
|
||||
}
|
||||
setIsSearching(true);
|
||||
const handle = setTimeout(() => {
|
||||
searchConversations(trimmed, token).then((result) => {
|
||||
setResults(result.data ?? []);
|
||||
setIsSearching(false);
|
||||
});
|
||||
}, 300);
|
||||
return () => clearTimeout(handle);
|
||||
}, [query, token]);
|
||||
|
||||
const visibleConversations = useMemo(() => {
|
||||
if (!query.trim()) return conversations;
|
||||
return results ?? [];
|
||||
}, [query, results, conversations]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isSearching || visibleConversations.length === 0) {
|
||||
setActiveIndex(-1);
|
||||
return;
|
||||
}
|
||||
|
||||
setActiveIndex((currentIndex) => {
|
||||
if (currentIndex >= 0 && currentIndex < visibleConversations.length) {
|
||||
return currentIndex;
|
||||
}
|
||||
|
||||
return 0;
|
||||
});
|
||||
}, [isSearching, visibleConversations]);
|
||||
|
||||
useEffect(() => {
|
||||
if (activeIndex < 0) return;
|
||||
|
||||
resultRefs.current[activeIndex]?.scrollIntoView({
|
||||
block: 'nearest',
|
||||
});
|
||||
}, [activeIndex]);
|
||||
|
||||
const handleSelect = (id: string) => {
|
||||
onSelectConversation(id);
|
||||
close();
|
||||
};
|
||||
|
||||
const handleInputKeyDown = (event: KeyboardEvent<HTMLInputElement>) => {
|
||||
if (visibleConversations.length === 0 || isSearching) return;
|
||||
|
||||
if (event.key === 'ArrowDown') {
|
||||
event.preventDefault();
|
||||
setActiveIndex((currentIndex) =>
|
||||
currentIndex < visibleConversations.length - 1 ? currentIndex + 1 : 0,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.key === 'ArrowUp') {
|
||||
event.preventDefault();
|
||||
setActiveIndex((currentIndex) =>
|
||||
currentIndex > 0 ? currentIndex - 1 : visibleConversations.length - 1,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.key === 'Enter' && activeIndex >= 0) {
|
||||
event.preventDefault();
|
||||
handleSelect(visibleConversations[activeIndex].id);
|
||||
}
|
||||
};
|
||||
|
||||
const showEmptyState =
|
||||
!!query.trim() && !isSearching && visibleConversations.length === 0;
|
||||
|
||||
return (
|
||||
<WrapperModal
|
||||
close={close}
|
||||
className="w-[92vw] max-w-xl p-0"
|
||||
contentClassName="max-h-[70vh]"
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
<div className="border-sidebar-border flex items-center gap-2 border-b px-5 py-4">
|
||||
<img src={SearchIcon} alt="search" className="h-4 w-4 opacity-60" />
|
||||
<input
|
||||
ref={inputRef}
|
||||
type="text"
|
||||
value={query}
|
||||
onChange={(e) => setQuery(e.target.value)}
|
||||
onKeyDown={handleInputKeyDown}
|
||||
placeholder={t('modals.searchConversations.searchPlaceholder')}
|
||||
className="text-foreground placeholder:text-muted-foreground w-full bg-transparent text-sm outline-none"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="max-h-[55vh] overflow-y-auto py-2" role="listbox">
|
||||
{isSearching && (
|
||||
<div className="text-muted-foreground px-5 py-3 text-xs">
|
||||
{t('modals.searchConversations.loading')}
|
||||
</div>
|
||||
)}
|
||||
{showEmptyState && (
|
||||
<div className="text-muted-foreground px-5 py-3 text-xs">
|
||||
{t('modals.searchConversations.noResults')}
|
||||
</div>
|
||||
)}
|
||||
{!isSearching &&
|
||||
visibleConversations.map((conversation, index) => {
|
||||
const trimmedQuery = query.trim();
|
||||
const showSnippet =
|
||||
!!trimmedQuery &&
|
||||
!!conversation.match_snippet &&
|
||||
conversation.match_field !== 'name';
|
||||
const isActive = index === activeIndex;
|
||||
|
||||
return (
|
||||
<button
|
||||
key={conversation.id}
|
||||
type="button"
|
||||
ref={(element) => {
|
||||
resultRefs.current[index] = element;
|
||||
}}
|
||||
onClick={() => handleSelect(conversation.id)}
|
||||
onMouseEnter={() => setActiveIndex(index)}
|
||||
role="option"
|
||||
aria-selected={isActive}
|
||||
className={`text-foreground flex w-full flex-col items-start gap-0.5 px-5 py-2.5 text-left text-sm ${
|
||||
isActive ? 'bg-sidebar-accent' : 'hover:bg-sidebar-accent'
|
||||
}`}
|
||||
>
|
||||
<span className="w-full truncate">
|
||||
{trimmedQuery ? (
|
||||
<HighlightedText
|
||||
text={conversation.name}
|
||||
query={trimmedQuery}
|
||||
/>
|
||||
) : (
|
||||
conversation.name
|
||||
)}
|
||||
</span>
|
||||
{showSnippet && (
|
||||
<span className="text-muted-foreground line-clamp-2 w-full text-xs">
|
||||
<HighlightedText
|
||||
text={conversation.match_snippet as string}
|
||||
query={trimmedQuery}
|
||||
/>
|
||||
</span>
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
</WrapperModal>
|
||||
);
|
||||
}
|
||||
@@ -14,8 +14,6 @@ export type Doc = {
|
||||
syncFrequency?: string;
|
||||
isNested?: boolean;
|
||||
provider?: string;
|
||||
// Derived server-side from ingest_chunk_progress (sources API).
|
||||
ingestStatus?: 'processing' | 'failed';
|
||||
};
|
||||
|
||||
export type GetDocsResponse = {
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
import { useMatch, useNavigate } from 'react-router-dom';
|
||||
|
||||
import WarnIcon from '../assets/warn.svg';
|
||||
import type { RootState } from '../store';
|
||||
|
||||
import {
|
||||
dismissToolApproval,
|
||||
selectDismissedToolApprovals,
|
||||
selectRecentEvents,
|
||||
} from './notificationsSlice';
|
||||
|
||||
/**
|
||||
* Surface ``tool.approval.required`` events as toasts that look like
|
||||
* ``UploadToast`` (same fixed bottom-right rail) — but only when the
|
||||
* user is NOT already on the conversation that needs the approval.
|
||||
*
|
||||
* - Dedup by ``conversation_id`` (the SSE ``scope.id``): keep only
|
||||
* the newest pending event per conversation, so multiple paused
|
||||
* tools in one conversation collapse to one toast.
|
||||
* - Dismissal is per-event-id so a *new* pause of the same
|
||||
* conversation will re-surface (different event id).
|
||||
* - Clicking "Review" navigates to ``/c/<id>`` and dismisses.
|
||||
*/
|
||||
export default function ToolApprovalToast() {
|
||||
const dispatch = useDispatch();
|
||||
const navigate = useNavigate();
|
||||
const events = useSelector(selectRecentEvents);
|
||||
const dismissed = useSelector(selectDismissedToolApprovals);
|
||||
|
||||
// Pull the active conversation id off the route. Two route shapes
|
||||
// place a conversation in view: the bare ``/c/:conversationId`` and
|
||||
// the agent-scoped ``/agents/:agentId/c/:conversationId``. Hooks
|
||||
// are unconditional; the toast just respects whichever matches.
|
||||
//
|
||||
// ``/c/new`` is the conversation route's literal-string placeholder
|
||||
// for "unknown / not-yet-loaded conversation" (see the rewrite in
|
||||
// the conversation route). Treat it the same as no match — the
|
||||
// user isn't viewing any specific conversation yet, so an approval
|
||||
// toast for any conversation should still surface.
|
||||
const plainMatch = useMatch('/c/:conversationId');
|
||||
const agentMatch = useMatch('/agents/:agentId/c/:conversationId');
|
||||
const matchedConversationId =
|
||||
plainMatch?.params.conversationId ??
|
||||
agentMatch?.params.conversationId ??
|
||||
null;
|
||||
const currentConversationId =
|
||||
matchedConversationId === 'new' ? null : matchedConversationId;
|
||||
|
||||
// Conversation name lookup — best-effort. The slice's
|
||||
// ``preference.conversations.data`` is populated by
|
||||
// ``useDataInitializer`` once auth resolves; until then we fall
|
||||
// back to the conversation id.
|
||||
const conversations = useSelector(
|
||||
(state: RootState) => state.preference.conversations.data,
|
||||
);
|
||||
|
||||
const dismissedSet = new Set(dismissed);
|
||||
const pendingByConversation = new Map<
|
||||
string,
|
||||
{ eventId: string; conversationId: string }
|
||||
>();
|
||||
for (const event of events) {
|
||||
if (event.type !== 'tool.approval.required') continue;
|
||||
if (!event.id) continue; // can't dismiss without an id
|
||||
if (dismissedSet.has(event.id)) continue;
|
||||
const conversationId = event.scope?.id;
|
||||
if (!conversationId) continue;
|
||||
if (currentConversationId && conversationId === currentConversationId) {
|
||||
continue;
|
||||
}
|
||||
if (pendingByConversation.has(conversationId)) continue;
|
||||
// ``recentEvents`` is newest-first, so the first match per convId
|
||||
// is the most recent unhandled approval.
|
||||
pendingByConversation.set(conversationId, {
|
||||
eventId: event.id,
|
||||
conversationId,
|
||||
});
|
||||
}
|
||||
|
||||
if (pendingByConversation.size === 0) return null;
|
||||
|
||||
const conversationName = (conversationId: string): string => {
|
||||
const found = conversations?.find((c) => c.id === conversationId);
|
||||
return found?.name ?? 'Conversation';
|
||||
};
|
||||
|
||||
return (
|
||||
// Sit above ``UploadToast`` (which owns ``bottom-4 right-4``)
|
||||
// rather than overlapping it. ``bottom-24`` ≈ 96px clears one
|
||||
// standard-height upload toast; multiple in-flight uploads will
|
||||
// stack into the gap, at which point the approval toast still
|
||||
// floats on top via ``z-50``. Acceptable v1 layout — the two
|
||||
// surfaces are rarely competing.
|
||||
<div
|
||||
className="fixed right-4 bottom-24 z-50 flex max-w-md flex-col gap-2"
|
||||
onMouseDown={(e) => e.stopPropagation()}
|
||||
role="status"
|
||||
aria-live="polite"
|
||||
aria-atomic="true"
|
||||
>
|
||||
{Array.from(pendingByConversation.values()).map(
|
||||
({ eventId, conversationId }) => (
|
||||
<div
|
||||
key={eventId}
|
||||
className="border-border bg-card w-[271px] overflow-hidden rounded-2xl border shadow-[0px_24px_48px_0px_#00000029]"
|
||||
>
|
||||
<div className="bg-accent/50 dark:bg-muted flex items-center justify-between px-4 py-3">
|
||||
<h3 className="font-inter dark:text-foreground text-[14px] leading-[16.5px] font-medium text-black">
|
||||
Tool approval needed
|
||||
</h3>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => dispatch(dismissToolApproval(eventId))}
|
||||
className="flex h-8 items-center justify-center p-0 text-black opacity-70 transition-opacity hover:opacity-100 dark:text-white"
|
||||
aria-label="Dismiss"
|
||||
>
|
||||
<svg
|
||||
width="16"
|
||||
height="16"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-4 w-4"
|
||||
>
|
||||
<path
|
||||
d="M18 6L6 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M6 6L18 18"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
<div className="flex items-center justify-between gap-3 px-5 py-3">
|
||||
<div className="flex min-w-0 items-center gap-2">
|
||||
<img
|
||||
src={WarnIcon}
|
||||
alt=""
|
||||
className="h-5 w-5 shrink-0"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
<p
|
||||
className="font-inter dark:text-muted-foreground max-w-[140px] truncate text-[13px] leading-[16.5px] font-normal text-black"
|
||||
title={conversationName(conversationId)}
|
||||
>
|
||||
{conversationName(conversationId)}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
dispatch(dismissToolApproval(eventId));
|
||||
navigate(`/c/${conversationId}`);
|
||||
}}
|
||||
className="rounded-full bg-[#7D54D1] px-3 py-1 text-[12px] font-medium text-white shadow-sm hover:bg-[#6a45b8]"
|
||||
>
|
||||
Review
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
),
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
import { beforeEach, describe, expect, it } from 'vitest';
|
||||
|
||||
import {
|
||||
isDismissed,
|
||||
loadDismissed,
|
||||
saveDismissed,
|
||||
} from './dismissedPersistence';
|
||||
|
||||
const KEY = 'test:dismissed';
|
||||
const TTL = 24 * 60 * 60 * 1000;
|
||||
|
||||
describe('dismissedPersistence', () => {
|
||||
beforeEach(() => {
|
||||
localStorage.clear();
|
||||
});
|
||||
|
||||
it('saveDismissed + loadDismissed round-trips entries', () => {
|
||||
const now = Date.now();
|
||||
saveDismissed(KEY, [
|
||||
{ id: 'a', at: now },
|
||||
{ id: 'b', at: now - 1000 },
|
||||
]);
|
||||
const loaded = loadDismissed(KEY, TTL);
|
||||
expect(loaded).toEqual([
|
||||
{ id: 'a', at: now },
|
||||
{ id: 'b', at: now - 1000 },
|
||||
]);
|
||||
});
|
||||
|
||||
it('loadDismissed returns [] when key absent', () => {
|
||||
expect(loadDismissed(KEY, TTL)).toEqual([]);
|
||||
});
|
||||
|
||||
it('loadDismissed drops entries past the TTL cutoff', () => {
|
||||
const now = Date.now();
|
||||
saveDismissed(KEY, [
|
||||
{ id: 'fresh', at: now - 1000 },
|
||||
{ id: 'stale', at: now - (TTL + 1000) },
|
||||
]);
|
||||
const loaded = loadDismissed(KEY, TTL);
|
||||
expect(loaded.map((e) => e.id)).toEqual(['fresh']);
|
||||
});
|
||||
|
||||
it('loadDismissed returns [] on malformed JSON without throwing', () => {
|
||||
localStorage.setItem(KEY, '{not json');
|
||||
expect(loadDismissed(KEY, TTL)).toEqual([]);
|
||||
});
|
||||
|
||||
it('loadDismissed filters out entries with wrong shape', () => {
|
||||
const now = Date.now();
|
||||
localStorage.setItem(
|
||||
KEY,
|
||||
JSON.stringify([
|
||||
{ id: 'good', at: now },
|
||||
{ id: 123, at: now },
|
||||
{ id: 'bad-at', at: 'nope' },
|
||||
null,
|
||||
'string-entry',
|
||||
]),
|
||||
);
|
||||
const loaded = loadDismissed(KEY, TTL);
|
||||
expect(loaded.map((e) => e.id)).toEqual(['good']);
|
||||
});
|
||||
|
||||
it('isDismissed matches by id', () => {
|
||||
const list = [{ id: 'a', at: 1 }];
|
||||
expect(isDismissed(list, 'a')).toBe(true);
|
||||
expect(isDismissed(list, 'b')).toBe(false);
|
||||
expect(isDismissed([], 'a')).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -1,42 +0,0 @@
|
||||
// Persisted dismissal lists for SSE-driven toasts. Without persistence
|
||||
// the next page's backlog replay re-fires the events and pops the
|
||||
// toast back. TTL matches the backend's stream retention.
|
||||
|
||||
export interface DismissedEntry {
|
||||
id: string;
|
||||
at: number;
|
||||
}
|
||||
|
||||
export function loadDismissed(key: string, ttlMs: number): DismissedEntry[] {
|
||||
if (typeof localStorage === 'undefined') return [];
|
||||
try {
|
||||
const raw = localStorage.getItem(key);
|
||||
if (!raw) return [];
|
||||
const parsed = JSON.parse(raw);
|
||||
if (!Array.isArray(parsed)) return [];
|
||||
const cutoff = Date.now() - ttlMs;
|
||||
return parsed.filter(
|
||||
(e): e is DismissedEntry =>
|
||||
!!e &&
|
||||
typeof e === 'object' &&
|
||||
typeof (e as DismissedEntry).id === 'string' &&
|
||||
typeof (e as DismissedEntry).at === 'number' &&
|
||||
(e as DismissedEntry).at >= cutoff,
|
||||
);
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
export function saveDismissed(key: string, list: DismissedEntry[]): void {
|
||||
if (typeof localStorage === 'undefined') return;
|
||||
try {
|
||||
localStorage.setItem(key, JSON.stringify(list));
|
||||
} catch {
|
||||
// Best-effort: ignore quota / private-mode errors.
|
||||
}
|
||||
}
|
||||
|
||||
export function isDismissed(list: DismissedEntry[], id: string): boolean {
|
||||
return list.some((e) => e.id === id);
|
||||
}
|
||||
@@ -1,109 +0,0 @@
|
||||
import { describe, expect, it, vi, afterEach } from 'vitest';
|
||||
|
||||
import reducer, {
|
||||
dismissToolApproval,
|
||||
sseEventReceived,
|
||||
sseLastEventIdReset,
|
||||
type SSEEvent,
|
||||
} from './notificationsSlice';
|
||||
|
||||
const baseEvent = (overrides: Partial<SSEEvent> = {}): SSEEvent => ({
|
||||
id: 'evt-1',
|
||||
type: 'source.ingest.progress',
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const seedState = () => reducer(undefined, { type: '@@INIT' });
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
describe('sseEventReceived', () => {
|
||||
it('dedupes by id when the same envelope arrives twice', () => {
|
||||
let state = seedState();
|
||||
state = reducer(state, sseEventReceived(baseEvent({ id: 'a' })));
|
||||
state = reducer(state, sseEventReceived(baseEvent({ id: 'a' })));
|
||||
expect(state.recentEvents).toHaveLength(1);
|
||||
expect(state.recentEvents[0].id).toBe('a');
|
||||
});
|
||||
|
||||
it('does not update lastEventId for envelopes without an id (backlog.truncated)', () => {
|
||||
let state = seedState();
|
||||
state = reducer(state, sseEventReceived(baseEvent({ id: 'cursor-1' })));
|
||||
expect(state.lastEventId).toBe('cursor-1');
|
||||
state = reducer(
|
||||
state,
|
||||
sseEventReceived({ type: 'backlog.truncated' } as SSEEvent),
|
||||
);
|
||||
expect(state.lastEventId).toBe('cursor-1');
|
||||
});
|
||||
|
||||
it('caps recentEvents at 100 entries (oldest evicted)', () => {
|
||||
let state = seedState();
|
||||
for (let i = 0; i < 105; i += 1) {
|
||||
state = reducer(state, sseEventReceived(baseEvent({ id: `e-${i}` })));
|
||||
}
|
||||
expect(state.recentEvents).toHaveLength(100);
|
||||
// Newest first.
|
||||
expect(state.recentEvents[0].id).toBe('e-104');
|
||||
expect(state.recentEvents[state.recentEvents.length - 1].id).toBe('e-5');
|
||||
});
|
||||
});
|
||||
|
||||
describe('sseLastEventIdReset', () => {
|
||||
it('clears lastEventId back to null', () => {
|
||||
let state = seedState();
|
||||
state = reducer(state, sseEventReceived(baseEvent({ id: 'x' })));
|
||||
expect(state.lastEventId).toBe('x');
|
||||
state = reducer(state, sseLastEventIdReset());
|
||||
expect(state.lastEventId).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('dismissToolApproval', () => {
|
||||
it('dedupes by id and refreshes the timestamp', () => {
|
||||
vi.useFakeTimers();
|
||||
vi.setSystemTime(new Date('2026-01-01T00:00:00Z'));
|
||||
let state = seedState();
|
||||
state = reducer(state, dismissToolApproval('approval-1'));
|
||||
const firstAt = state.dismissedToolApprovals[0].at;
|
||||
|
||||
vi.setSystemTime(new Date('2026-01-01T00:05:00Z'));
|
||||
state = reducer(state, dismissToolApproval('approval-1'));
|
||||
expect(state.dismissedToolApprovals).toHaveLength(1);
|
||||
expect(state.dismissedToolApprovals[0].at).toBeGreaterThan(firstAt);
|
||||
});
|
||||
|
||||
it('evicts entries older than the 24h TTL', () => {
|
||||
vi.useFakeTimers();
|
||||
vi.setSystemTime(new Date('2026-01-01T00:00:00Z'));
|
||||
let state = seedState();
|
||||
state = reducer(state, dismissToolApproval('old-1'));
|
||||
|
||||
// Move past the 24h TTL window.
|
||||
vi.setSystemTime(new Date('2026-01-02T00:00:01Z'));
|
||||
state = reducer(state, dismissToolApproval('fresh-1'));
|
||||
const ids = state.dismissedToolApprovals.map((entry) => entry.id);
|
||||
expect(ids).toEqual(['fresh-1']);
|
||||
});
|
||||
|
||||
it('applies the 200-entry cap as a backstop after TTL filtering', () => {
|
||||
vi.useFakeTimers();
|
||||
vi.setSystemTime(new Date('2026-01-01T00:00:00Z'));
|
||||
let state = seedState();
|
||||
// Insert 205 distinct ids within the TTL window.
|
||||
for (let i = 0; i < 205; i += 1) {
|
||||
// Advance time slightly so the at-values are distinct but well
|
||||
// inside the 24h TTL.
|
||||
vi.setSystemTime(
|
||||
new Date(`2026-01-01T00:00:${(i % 60).toString().padStart(2, '0')}Z`),
|
||||
);
|
||||
state = reducer(state, dismissToolApproval(`id-${i}`));
|
||||
}
|
||||
expect(state.dismissedToolApprovals).toHaveLength(200);
|
||||
// The 200-cap keeps the most recently pushed ids.
|
||||
expect(state.dismissedToolApprovals[0].id).toBe('id-5');
|
||||
expect(state.dismissedToolApprovals[199].id).toBe('id-204');
|
||||
});
|
||||
});
|
||||
@@ -1,200 +0,0 @@
|
||||
import { createSelector, createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
|
||||
import { RootState } from '../store';
|
||||
import { loadDismissed, saveDismissed } from './dismissedPersistence';
|
||||
|
||||
const DISMISSED_TOOL_APPROVALS_STORAGE_KEY = 'docsgpt:dismissedToolApprovals';
|
||||
|
||||
/**
|
||||
* Envelope shape published by the backend SSE endpoint
|
||||
* (`application/events/publisher.py`). Mirrors the wire JSON 1:1.
|
||||
*/
|
||||
export interface SSEEvent<P = Record<string, unknown>> {
|
||||
id?: string;
|
||||
type: string;
|
||||
ts?: string;
|
||||
user_id?: string;
|
||||
topic?: string;
|
||||
scope?: { kind: string; id: string };
|
||||
payload?: P;
|
||||
}
|
||||
|
||||
/**
|
||||
* Connection-health state machine the rest of the app reads via
|
||||
* ``selectPushChannelHealthy`` to gate polling-fallback behaviour.
|
||||
*
|
||||
* - ``connecting`` — initial fetch in flight, or reconnecting after drop.
|
||||
* - ``healthy`` — at least one event (data or keepalive) received and
|
||||
* the stream has been open >2s.
|
||||
* - ``unhealthy`` — last attempt failed or has been dropped without a
|
||||
* successful re-establish; fall back to polling.
|
||||
*/
|
||||
export type PushHealth = 'connecting' | 'healthy' | 'unhealthy';
|
||||
|
||||
interface NotificationsState {
|
||||
health: PushHealth;
|
||||
/** Most-recent server-issued id; sent back as ``Last-Event-ID`` on reconnect. */
|
||||
lastEventId: string | null;
|
||||
/** Bounded ring of recent events for the in-app notifications surface. */
|
||||
recentEvents: SSEEvent[];
|
||||
/**
|
||||
* Wallclock ms of last received data-bearing event. SSE keepalives
|
||||
* are comment lines (no ``id:``/``data:``) and do NOT update this —
|
||||
* they're filtered out at the parser level.
|
||||
*/
|
||||
lastEventReceivedAt: number | null;
|
||||
/**
|
||||
* Event ids of ``tool.approval.required`` notifications the user
|
||||
* dismissed (close button or by navigating into the conversation),
|
||||
* each tagged with the wallclock ms at which it was dismissed.
|
||||
* Keyed by the SSE event id so a *new* approval for the same
|
||||
* conversation re-surfaces; the dismissal only suppresses the one
|
||||
* specific paused-tool prompt.
|
||||
*
|
||||
* Entries are evicted by TTL first (anything older than
|
||||
* ``DISMISSED_TOOL_APPROVALS_TTL_MS``), then by FIFO cap. The TTL
|
||||
* matters because a pure FIFO with a small cap can evict a *still-
|
||||
* pending* approval id before the user acts on it — re-popping the
|
||||
* toast on the next dispatch. The 24h ceiling is longer than any
|
||||
* plausible approval-pending window.
|
||||
*/
|
||||
dismissedToolApprovals: Array<{ id: string; at: number }>;
|
||||
}
|
||||
|
||||
const RECENT_EVENTS_CAP = 100;
|
||||
const DISMISSED_TOOL_APPROVALS_CAP = 200;
|
||||
const DISMISSED_TOOL_APPROVALS_TTL_MS = 24 * 60 * 60 * 1000;
|
||||
|
||||
const initialState: NotificationsState = {
|
||||
health: 'connecting',
|
||||
lastEventId: null,
|
||||
recentEvents: [],
|
||||
lastEventReceivedAt: null,
|
||||
// Hydrate from localStorage: SSE backlog replay re-delivers the
|
||||
// originating ``tool.approval.required`` envelopes on reload.
|
||||
dismissedToolApprovals: loadDismissed(
|
||||
DISMISSED_TOOL_APPROVALS_STORAGE_KEY,
|
||||
DISMISSED_TOOL_APPROVALS_TTL_MS,
|
||||
),
|
||||
};
|
||||
|
||||
export const notificationsSlice = createSlice({
|
||||
name: 'notifications',
|
||||
initialState,
|
||||
reducers: {
|
||||
sseEventReceived: (state, action: PayloadAction<SSEEvent>) => {
|
||||
const e = action.payload;
|
||||
// Drop immediate duplicates. Snapshot replay + live tail can
|
||||
// both deliver the same id when the live pubsub frame and the
|
||||
// replay XRANGE overlap, and consumers that walk
|
||||
// ``recentEvents`` (FileTree, ConnectorTree, MCPServerModal,
|
||||
// ToolApprovalToast) would otherwise act on the same envelope
|
||||
// twice. The route's dedup floor catches the common case; this
|
||||
// is a belt-and-suspenders for in-tab StrictMode double-mounts
|
||||
// and any envelope that slips through with the same id.
|
||||
if (e.id && state.recentEvents[0]?.id === e.id) return;
|
||||
state.recentEvents.unshift(e);
|
||||
if (state.recentEvents.length > RECENT_EVENTS_CAP) {
|
||||
state.recentEvents.length = RECENT_EVENTS_CAP;
|
||||
}
|
||||
if (e.id) state.lastEventId = e.id;
|
||||
state.lastEventReceivedAt = Date.now();
|
||||
},
|
||||
sseHealthChanged: (state, action: PayloadAction<PushHealth>) => {
|
||||
state.health = action.payload;
|
||||
},
|
||||
/**
|
||||
* Lifecycle helper used by reconnect bookkeeping — does not record
|
||||
* an event, just stamps "we heard from the server" so the polling
|
||||
* fallback stays disabled while keepalives arrive.
|
||||
*/
|
||||
sseHeartbeat: (state) => {
|
||||
state.lastEventReceivedAt = Date.now();
|
||||
},
|
||||
sseLastEventIdReset: (state) => {
|
||||
// Backlog truncated — drop the cursor so the next reconnect
|
||||
// doesn't try to resume past the retained window.
|
||||
state.lastEventId = null;
|
||||
},
|
||||
/**
|
||||
* Advance the cursor without recording an event. Fired for every
|
||||
* id-bearing frame including keepalives and id-only comments,
|
||||
* so the slice cursor tracks the freshest id the wire has
|
||||
* delivered even when no envelope was dispatched. Without this,
|
||||
* ``lastEventId`` would only advance via ``sseEventReceived`` and
|
||||
* a long quiet period of keepalives would leave it stale —
|
||||
* eventually re-snapshotting the same backlog on each reconnect
|
||||
* and exhausting the per-user replay budget.
|
||||
*/
|
||||
sseLastEventIdAdvanced: (state, action: PayloadAction<string>) => {
|
||||
state.lastEventId = action.payload;
|
||||
},
|
||||
clearRecentEvents: (state) => {
|
||||
state.recentEvents = [];
|
||||
},
|
||||
/**
|
||||
* Suppress a ``tool.approval.required`` notification by the SSE
|
||||
* event id. The toast surface filters dismissed ids out; a *new*
|
||||
* approval event for the same conversation has a different id
|
||||
* and re-surfaces, which is the desired UX (each pause is its
|
||||
* own decision).
|
||||
*/
|
||||
dismissToolApproval: (state, action: PayloadAction<string>) => {
|
||||
const id = action.payload;
|
||||
const now = Date.now();
|
||||
// Evict expired entries first so the TTL — not the FIFO cap —
|
||||
// governs when stale ids drop, keeping still-pending approvals
|
||||
// suppressed.
|
||||
const cutoff = now - DISMISSED_TOOL_APPROVALS_TTL_MS;
|
||||
state.dismissedToolApprovals = state.dismissedToolApprovals.filter(
|
||||
(entry) => entry.at >= cutoff && entry.id !== id,
|
||||
);
|
||||
state.dismissedToolApprovals.push({ id, at: now });
|
||||
if (state.dismissedToolApprovals.length > DISMISSED_TOOL_APPROVALS_CAP) {
|
||||
state.dismissedToolApprovals = state.dismissedToolApprovals.slice(
|
||||
-DISMISSED_TOOL_APPROVALS_CAP,
|
||||
);
|
||||
}
|
||||
saveDismissed(
|
||||
DISMISSED_TOOL_APPROVALS_STORAGE_KEY,
|
||||
state.dismissedToolApprovals,
|
||||
);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
sseEventReceived,
|
||||
sseHealthChanged,
|
||||
sseHeartbeat,
|
||||
sseLastEventIdReset,
|
||||
sseLastEventIdAdvanced,
|
||||
clearRecentEvents,
|
||||
dismissToolApproval,
|
||||
} = notificationsSlice.actions;
|
||||
|
||||
export const selectSseHealth = (state: RootState): PushHealth =>
|
||||
state.notifications.health;
|
||||
|
||||
export const selectPushChannelHealthy = (state: RootState): boolean =>
|
||||
state.notifications.health === 'healthy';
|
||||
|
||||
export const selectLastEventId = (state: RootState): string | null =>
|
||||
state.notifications.lastEventId;
|
||||
|
||||
export const selectLastEventReceivedAt = (state: RootState): number | null =>
|
||||
state.notifications.lastEventReceivedAt;
|
||||
|
||||
export const selectRecentEvents = (state: RootState): SSEEvent[] =>
|
||||
state.notifications.recentEvents;
|
||||
|
||||
// Memoised so ``useSelector`` consumers don't re-render on every
|
||||
// unrelated ``notifications`` state change — the underlying ``{id,at}``
|
||||
// array only changes when ``dismissToolApproval`` runs, but the
|
||||
// projected ``.map`` would otherwise return a fresh array each call.
|
||||
export const selectDismissedToolApprovals = createSelector(
|
||||
(state: RootState) => state.notifications.dismissedToolApprovals,
|
||||
(entries) => entries.map((entry) => entry.id),
|
||||
);
|
||||
|
||||
export default notificationsSlice.reducer;
|
||||
@@ -85,6 +85,49 @@ export async function getConversations(
|
||||
}
|
||||
}
|
||||
|
||||
export async function searchConversations(
|
||||
query: string,
|
||||
token: string | null,
|
||||
limit = 30,
|
||||
): Promise<GetConversationsResult> {
|
||||
try {
|
||||
const response = await conversationService.searchConversations(
|
||||
query,
|
||||
token,
|
||||
limit,
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
console.error('Error searching conversations:', response.statusText);
|
||||
return { data: null, loading: false };
|
||||
}
|
||||
|
||||
const rawData: unknown = await response.json();
|
||||
if (!Array.isArray(rawData)) {
|
||||
console.error(
|
||||
'Invalid data format received from API: Expected an array.',
|
||||
rawData,
|
||||
);
|
||||
return { data: null, loading: false };
|
||||
}
|
||||
|
||||
const conversations: ConversationSummary[] = rawData.map((item: any) => ({
|
||||
id: item.id,
|
||||
name: item.name,
|
||||
agent_id: item.agent_id ?? null,
|
||||
match_field: item.match_field ?? null,
|
||||
match_snippet: item.match_snippet ?? null,
|
||||
}));
|
||||
return { data: conversations, loading: false };
|
||||
} catch (error) {
|
||||
console.error(
|
||||
'An unexpected error occurred while searching conversations:',
|
||||
error,
|
||||
);
|
||||
return { data: null, loading: false };
|
||||
}
|
||||
}
|
||||
|
||||
export function getLocalApiKey(): string | null {
|
||||
const key = localStorage.getItem('DocsGPTApiKey');
|
||||
return key;
|
||||
|
||||
@@ -2,6 +2,8 @@ export type ConversationSummary = {
|
||||
id: string;
|
||||
name: string;
|
||||
agent_id: string | null;
|
||||
match_field?: 'name' | 'prompt' | 'response' | null;
|
||||
match_snippet?: string | null;
|
||||
};
|
||||
|
||||
export type GetConversationsResult = {
|
||||
|
||||
@@ -27,12 +27,6 @@ import {
|
||||
setSourceDocs,
|
||||
} from '../preferences/preferenceSlice';
|
||||
import Upload from '../upload/Upload';
|
||||
import {
|
||||
addUploadTask,
|
||||
removeUploadTask,
|
||||
selectUploadTasks,
|
||||
updateUploadTask,
|
||||
} from '../upload/uploadSlice';
|
||||
import { formatDate } from '../utils/dateTimeUtils';
|
||||
import FileTree from '../components/FileTree';
|
||||
import ConnectorTree from '../components/ConnectorTree';
|
||||
@@ -62,7 +56,6 @@ export default function Sources({
|
||||
const [isDarkTheme] = useDarkTheme();
|
||||
const dispatch = useDispatch();
|
||||
const token = useSelector(selectToken);
|
||||
const uploadTasks = useSelector(selectUploadTasks);
|
||||
|
||||
const [searchTerm, setSearchTerm] = useState<string>('');
|
||||
const debouncedSearchTerm = useDebouncedValue(searchTerm, 500);
|
||||
@@ -256,57 +249,6 @@ export default function Sources({
|
||||
}
|
||||
};
|
||||
|
||||
const handleReingest = async (doc: Doc) => {
|
||||
if (!doc.id) {
|
||||
return;
|
||||
}
|
||||
const sourceId = doc.id;
|
||||
// Drop stale toast rows for this source (a finished/dismissed task
|
||||
// would swallow the reingest's SSE events), then open a fresh one.
|
||||
uploadTasks
|
||||
.filter((task) => task.sourceId === sourceId)
|
||||
.forEach((task) => dispatch(removeUploadTask(task.id)));
|
||||
const reingestTaskId = `reingest-${sourceId}-${Date.now()}`;
|
||||
dispatch(
|
||||
addUploadTask({
|
||||
id: reingestTaskId,
|
||||
fileName: doc.name || sourceId,
|
||||
progress: 0,
|
||||
status: 'training',
|
||||
sourceId,
|
||||
}),
|
||||
);
|
||||
try {
|
||||
const response = await userService.reingestSource(
|
||||
{ source_id: sourceId },
|
||||
token,
|
||||
);
|
||||
const data = await response.json();
|
||||
if (!data.success) {
|
||||
console.error('Reingest failed:', data.error || data.message);
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: reingestTaskId,
|
||||
updates: {
|
||||
status: 'failed',
|
||||
errorMessage: data.error || data.message,
|
||||
},
|
||||
}),
|
||||
);
|
||||
return;
|
||||
}
|
||||
refreshDocs(undefined, currentPage, rowsPerPage);
|
||||
} catch (error) {
|
||||
console.error('Error reingesting source:', error);
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: reingestTaskId,
|
||||
updates: { status: 'failed' },
|
||||
}),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const [documentToDelete, setDocumentToDelete] = useState<{
|
||||
index: number;
|
||||
document: Doc;
|
||||
@@ -341,19 +283,6 @@ export default function Sources({
|
||||
},
|
||||
];
|
||||
|
||||
if (document.ingestStatus === 'failed') {
|
||||
actions.push({
|
||||
icon: SyncIcon,
|
||||
label: t('settings.sources.reingest'),
|
||||
onClick: () => {
|
||||
handleReingest(document);
|
||||
},
|
||||
iconWidth: 14,
|
||||
iconHeight: 14,
|
||||
variant: 'primary',
|
||||
});
|
||||
}
|
||||
|
||||
if (document.syncFrequency) {
|
||||
actions.push({
|
||||
icon: SyncIcon,
|
||||
@@ -554,16 +483,6 @@ export default function Sources({
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col items-start justify-start gap-1">
|
||||
{document.ingestStatus === 'failed' && (
|
||||
<span className="rounded-full bg-red-100 px-2 py-0.5 text-[11px] leading-[16px] font-medium text-red-700 dark:bg-red-900/30 dark:text-red-400">
|
||||
{t('settings.sources.ingestFailed')}
|
||||
</span>
|
||||
)}
|
||||
{document.ingestStatus === 'processing' && (
|
||||
<span className="bg-muted-foreground/10 text-muted-foreground rounded-full px-2 py-0.5 text-[11px] leading-[16px] font-medium">
|
||||
{t('settings.sources.ingestProcessing')}
|
||||
</span>
|
||||
)}
|
||||
<div className="flex items-center gap-2">
|
||||
<img
|
||||
src={CalendarIcon}
|
||||
|
||||
@@ -4,7 +4,6 @@ import agentPreviewReducer from './agents/agentPreviewSlice';
|
||||
import workflowPreviewReducer from './agents/workflow/workflowPreviewSlice';
|
||||
import { conversationSlice } from './conversation/conversationSlice';
|
||||
import { sharedConversationSlice } from './conversation/sharedConversationSlice';
|
||||
import notificationsReducer from './notifications/notificationsSlice';
|
||||
import { getStoredRecentDocs } from './preferences/preferenceApi';
|
||||
import {
|
||||
Preference,
|
||||
@@ -68,7 +67,6 @@ const store = configureStore({
|
||||
upload: uploadReducer,
|
||||
agentPreview: agentPreviewReducer,
|
||||
workflowPreview: workflowPreviewReducer,
|
||||
notifications: notificationsReducer,
|
||||
},
|
||||
middleware: (getDefaultMiddleware) =>
|
||||
getDefaultMiddleware().concat(prefListenerMiddleware.middleware),
|
||||
|
||||
@@ -2,9 +2,9 @@ import { useCallback, useState } from 'react';
|
||||
import { nanoid } from '@reduxjs/toolkit';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDispatch, useSelector, useStore } from 'react-redux';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
|
||||
import type { RootState } from '../store';
|
||||
import userService from '../api/services/userService';
|
||||
import { getSessionToken } from '../utils/providerUtils';
|
||||
import Dropdown from '../components/Dropdown';
|
||||
import Input from '../components/Input';
|
||||
@@ -298,7 +298,6 @@ function Upload({
|
||||
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useDispatch();
|
||||
const store = useStore<RootState>();
|
||||
|
||||
const ingestorOptions: IngestorOption[] = IngestorFormSchemas.filter(
|
||||
(schema) => (schema.validate ? schema.validate() : true),
|
||||
@@ -335,113 +334,110 @@ function Upload({
|
||||
[dispatch],
|
||||
);
|
||||
|
||||
/**
|
||||
* Wait for the source.ingest.* SSE pipeline to flip this task into a
|
||||
* terminal state, then run the post-completion side effects: refresh
|
||||
* the global source list, auto-select the new doc, and fire the
|
||||
* caller's ``onSuccessfulUpload`` hook. The slice's extraReducer
|
||||
* (uploadSlice.ts) is the sole driver of the task's status; we only
|
||||
* subscribe so the side effects can fire after the modal has closed.
|
||||
*/
|
||||
const trackTraining = useCallback(
|
||||
(clientTaskId: string) => {
|
||||
let handled = false;
|
||||
(backendTaskId: string, clientTaskId: string) => {
|
||||
let timeoutId: number | null = null;
|
||||
|
||||
const handleTerminal = (status: 'completed' | 'failed') => {
|
||||
if (handled) return;
|
||||
handled = true;
|
||||
if (status !== 'completed') return;
|
||||
getDocs(token)
|
||||
.then((docs) => {
|
||||
dispatch(setSourceDocs(docs));
|
||||
if (Array.isArray(docs)) {
|
||||
const existingDocIds = new Set(
|
||||
(Array.isArray(sourceDocs) ? sourceDocs : [])
|
||||
.map((doc: Doc) => doc?.id)
|
||||
.filter((id): id is string => Boolean(id)),
|
||||
);
|
||||
const newDoc = docs.find(
|
||||
(doc: Doc) => doc.id && !existingDocIds.has(doc.id),
|
||||
);
|
||||
if (newDoc) {
|
||||
// If only one doc is selected, replace it completely
|
||||
// If multiple docs are selected, append the new doc
|
||||
if (selectedDocs.length === 1) {
|
||||
dispatch(setSelectedDocs([newDoc]));
|
||||
} else {
|
||||
dispatch(setSelectedDocs([...selectedDocs, newDoc]));
|
||||
const poll = () => {
|
||||
userService
|
||||
.getTaskStatus(backendTaskId, null)
|
||||
.then((response) => response.json())
|
||||
.then(async (data) => {
|
||||
if (!data.success && data.message) {
|
||||
if (timeoutId !== null) {
|
||||
clearTimeout(timeoutId);
|
||||
timeoutId = null;
|
||||
}
|
||||
handleTaskFailure(clientTaskId, data.message);
|
||||
return;
|
||||
}
|
||||
|
||||
if (data.status === 'SUCCESS') {
|
||||
if (timeoutId !== null) {
|
||||
clearTimeout(timeoutId);
|
||||
timeoutId = null;
|
||||
}
|
||||
|
||||
const docs = await getDocs(token);
|
||||
dispatch(setSourceDocs(docs));
|
||||
|
||||
if (Array.isArray(docs)) {
|
||||
const existingDocIds = new Set(
|
||||
(Array.isArray(sourceDocs) ? sourceDocs : [])
|
||||
.map((doc: Doc) => doc?.id)
|
||||
.filter((id): id is string => Boolean(id)),
|
||||
);
|
||||
const newDoc = docs.find(
|
||||
(doc: Doc) => doc.id && !existingDocIds.has(doc.id),
|
||||
);
|
||||
if (newDoc) {
|
||||
// If only one doc is selected, replace it completely
|
||||
// If multiple docs are selected, append the new doc
|
||||
if (selectedDocs.length === 1) {
|
||||
dispatch(setSelectedDocs([newDoc]));
|
||||
} else {
|
||||
dispatch(setSelectedDocs([...selectedDocs, newDoc]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (data.result?.limited) {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
status: 'failed',
|
||||
progress: 100,
|
||||
errorMessage: t('modals.uploadDoc.progress.tokenLimit'),
|
||||
},
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
status: 'completed',
|
||||
progress: 100,
|
||||
errorMessage: undefined,
|
||||
},
|
||||
}),
|
||||
);
|
||||
onSuccessfulUpload?.();
|
||||
}
|
||||
} else if (data.status === 'FAILURE') {
|
||||
if (timeoutId !== null) {
|
||||
clearTimeout(timeoutId);
|
||||
timeoutId = null;
|
||||
}
|
||||
handleTaskFailure(clientTaskId, data.result?.message);
|
||||
} else if (data.status === 'PROGRESS') {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
status: 'training',
|
||||
progress: Math.min(100, data.result?.current ?? 0),
|
||||
},
|
||||
}),
|
||||
);
|
||||
timeoutId = window.setTimeout(poll, 5000);
|
||||
} else {
|
||||
timeoutId = window.setTimeout(poll, 5000);
|
||||
}
|
||||
onSuccessfulUpload?.();
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error(
|
||||
'SSE-driven post-completion source-list refresh failed:',
|
||||
err,
|
||||
);
|
||||
.catch((error) => {
|
||||
if (timeoutId !== null) {
|
||||
clearTimeout(timeoutId);
|
||||
timeoutId = null;
|
||||
}
|
||||
handleTaskFailure(clientTaskId, error?.message);
|
||||
});
|
||||
};
|
||||
|
||||
const check = () => {
|
||||
const state = store.getState();
|
||||
const task = state.upload.tasks.find((t) => t.id === clientTaskId);
|
||||
if (!task) return false;
|
||||
if (task.status === 'completed' || task.status === 'failed') {
|
||||
handleTerminal(task.status);
|
||||
return true;
|
||||
}
|
||||
// Recover from the race where the terminal SSE landed before
|
||||
// ``xhr.onload`` populated ``task.sourceId`` — the slice
|
||||
// silently drops such events (no task to match by sourceId).
|
||||
// Mirrors ConnectorTree/FileTree's ``recentEvents`` walk.
|
||||
if (task.sourceId) {
|
||||
for (const event of state.notifications.recentEvents) {
|
||||
if (event.scope?.id !== task.sourceId) continue;
|
||||
if (event.type === 'source.ingest.completed') {
|
||||
handleTerminal('completed');
|
||||
return true;
|
||||
}
|
||||
if (event.type === 'source.ingest.failed') {
|
||||
handleTerminal('failed');
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
if (check()) return;
|
||||
const MAX_WAIT_MS = 5 * 60_000;
|
||||
let unsubscribe: (() => void) | null = null;
|
||||
const timer = window.setTimeout(() => {
|
||||
unsubscribe?.();
|
||||
if (!handled) {
|
||||
handled = true;
|
||||
console.warn(
|
||||
'trackTraining: timed out waiting for terminal SSE',
|
||||
clientTaskId,
|
||||
);
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
status: 'failed',
|
||||
errorMessage:
|
||||
'Timed out waiting for ingest completion. The ingest may still be running — please refresh to check.',
|
||||
},
|
||||
}),
|
||||
);
|
||||
}
|
||||
}, MAX_WAIT_MS);
|
||||
unsubscribe = store.subscribe(() => {
|
||||
if (check()) {
|
||||
window.clearTimeout(timer);
|
||||
unsubscribe?.();
|
||||
}
|
||||
});
|
||||
timeoutId = window.setTimeout(poll, 3000);
|
||||
},
|
||||
[dispatch, onSuccessfulUpload, selectedDocs, sourceDocs, store, token],
|
||||
[dispatch, handleTaskFailure, onSuccessfulUpload, sourceDocs, t, token],
|
||||
);
|
||||
|
||||
const onDrop = useCallback(
|
||||
@@ -503,23 +499,19 @@ function Upload({
|
||||
xhr.onload = () => {
|
||||
if (xhr.status >= 200 && xhr.status < 300) {
|
||||
try {
|
||||
const parsed = JSON.parse(xhr.responseText) as {
|
||||
task_id?: string;
|
||||
source_id?: string;
|
||||
};
|
||||
const parsed = JSON.parse(xhr.responseText) as { task_id?: string };
|
||||
if (parsed.task_id) {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
taskId: parsed.task_id,
|
||||
sourceId: parsed.source_id,
|
||||
status: 'training',
|
||||
progress: 0,
|
||||
},
|
||||
}),
|
||||
);
|
||||
trackTraining(clientTaskId);
|
||||
trackTraining(parsed.task_id, clientTaskId);
|
||||
} else {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
@@ -635,23 +627,19 @@ function Upload({
|
||||
xhr.onload = () => {
|
||||
if (xhr.status >= 200 && xhr.status < 300) {
|
||||
try {
|
||||
const response = JSON.parse(xhr.responseText) as {
|
||||
task_id?: string;
|
||||
source_id?: string;
|
||||
};
|
||||
const response = JSON.parse(xhr.responseText) as { task_id?: string };
|
||||
if (response.task_id) {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
id: clientTaskId,
|
||||
updates: {
|
||||
taskId: response.task_id,
|
||||
sourceId: response.source_id,
|
||||
status: 'training',
|
||||
progress: 0,
|
||||
},
|
||||
}),
|
||||
);
|
||||
trackTraining(clientTaskId);
|
||||
trackTraining(response.task_id, clientTaskId);
|
||||
} else {
|
||||
dispatch(
|
||||
updateUploadTask({
|
||||
|
||||
@@ -1,478 +0,0 @@
|
||||
import { beforeEach, describe, expect, it } from 'vitest';
|
||||
|
||||
import notificationsReducer, {
|
||||
sseEventReceived,
|
||||
type SSEEvent,
|
||||
} from '../notifications/notificationsSlice';
|
||||
import reducer, {
|
||||
addAttachment,
|
||||
addUploadTask,
|
||||
dismissUploadTask,
|
||||
updateAttachment,
|
||||
updateUploadTask,
|
||||
type Attachment,
|
||||
type UploadTask,
|
||||
} from './uploadSlice';
|
||||
|
||||
const SOURCE_ID = 'src-1';
|
||||
|
||||
const makeTask = (overrides: Partial<UploadTask> = {}): UploadTask => ({
|
||||
id: 't-1',
|
||||
fileName: 'doc.pdf',
|
||||
progress: 0,
|
||||
status: 'preparing',
|
||||
sourceId: SOURCE_ID,
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const stateWithTask = (task: UploadTask) =>
|
||||
reducer(undefined, addUploadTask(task));
|
||||
|
||||
const ingest = (
|
||||
type: string,
|
||||
payload: Record<string, unknown> = {},
|
||||
scopeId = SOURCE_ID,
|
||||
) =>
|
||||
sseEventReceived({
|
||||
id: `id-${type}`,
|
||||
type,
|
||||
scope: { kind: 'source', id: scopeId },
|
||||
payload,
|
||||
});
|
||||
|
||||
describe('dismissal persistence across reload', () => {
|
||||
const STORAGE_KEY = 'docsgpt:dismissedUploadSourceIds';
|
||||
const SRC = 'src-persisted';
|
||||
|
||||
// Mirrors initialState as if hydrated from localStorage.
|
||||
const seedState = (entries: { id: string; at: number }[]) =>
|
||||
reducer(
|
||||
{ attachments: [], tasks: [], dismissedSourceIds: entries },
|
||||
{ type: '@@INIT' },
|
||||
);
|
||||
|
||||
beforeEach(() => {
|
||||
localStorage.clear();
|
||||
});
|
||||
|
||||
it('dismissUploadTask writes the task sourceId to localStorage', () => {
|
||||
let state = stateWithTask(
|
||||
makeTask({ id: 't-dismiss', sourceId: SRC, status: 'completed' }),
|
||||
);
|
||||
state = reducer(state, dismissUploadTask('t-dismiss'));
|
||||
expect(state.tasks[0].dismissed).toBe(true);
|
||||
expect(state.dismissedSourceIds).toHaveLength(1);
|
||||
expect(state.dismissedSourceIds[0].id).toBe(SRC);
|
||||
const persisted = JSON.parse(localStorage.getItem(STORAGE_KEY) ?? '[]');
|
||||
expect(persisted).toHaveLength(1);
|
||||
expect(persisted[0].id).toBe(SRC);
|
||||
});
|
||||
|
||||
it('skips persistence when the task has no sourceId yet', () => {
|
||||
let state = stateWithTask(
|
||||
makeTask({ id: 't-no-src', sourceId: undefined, status: 'preparing' }),
|
||||
);
|
||||
state = reducer(state, dismissUploadTask('t-no-src'));
|
||||
expect(state.tasks[0].dismissed).toBe(true);
|
||||
expect(state.dismissedSourceIds).toHaveLength(0);
|
||||
expect(localStorage.getItem(STORAGE_KEY)).toBeNull();
|
||||
});
|
||||
|
||||
it('auto-create on refresh marks the task dismissed when sourceId is in the persisted list', () => {
|
||||
const state = seedState([{ id: SRC, at: Date.now() }]);
|
||||
const next = reducer(
|
||||
state,
|
||||
ingest('source.ingest.progress', { current: 40 }, SRC),
|
||||
);
|
||||
const recovered = next.tasks.find((t) => t.sourceId === SRC);
|
||||
expect(recovered).toBeDefined();
|
||||
expect(recovered!.dismissed).toBe(true);
|
||||
expect(recovered!.progress).toBe(40);
|
||||
});
|
||||
|
||||
it('terminal events do NOT un-dismiss a task whose sourceId was previously dismissed', () => {
|
||||
const state = seedState([{ id: SRC, at: Date.now() }]);
|
||||
let next = reducer(state, ingest('source.ingest.queued', {}, SRC));
|
||||
expect(next.tasks[0].dismissed).toBe(true);
|
||||
next = reducer(next, ingest('source.ingest.completed', {}, SRC));
|
||||
// Even though `wasTerminal` is false (just transitioned), the
|
||||
// persisted dismissal keeps the toast closed.
|
||||
expect(next.tasks[0].status).toBe('completed');
|
||||
expect(next.tasks[0].dismissed).toBe(true);
|
||||
});
|
||||
|
||||
it('updateUploadTask does not un-dismiss on terminal when sourceId was previously dismissed', () => {
|
||||
const state = reducer(
|
||||
{
|
||||
attachments: [],
|
||||
tasks: [
|
||||
makeTask({
|
||||
id: 't-1',
|
||||
sourceId: SRC,
|
||||
status: 'training',
|
||||
dismissed: true,
|
||||
}),
|
||||
],
|
||||
dismissedSourceIds: [{ id: SRC, at: Date.now() }],
|
||||
},
|
||||
{ type: '@@INIT' },
|
||||
);
|
||||
const next = reducer(
|
||||
state,
|
||||
updateUploadTask({
|
||||
id: 't-1',
|
||||
updates: { status: 'completed' },
|
||||
}),
|
||||
);
|
||||
expect(next.tasks[0].status).toBe('completed');
|
||||
expect(next.tasks[0].dismissed).toBe(true);
|
||||
});
|
||||
|
||||
it('un-dismisses normally for sourceIds NOT in the persisted list', () => {
|
||||
const state = reducer(undefined, { type: '@@INIT' });
|
||||
const populated = reducer(
|
||||
state,
|
||||
addUploadTask(
|
||||
makeTask({
|
||||
id: 't-fresh',
|
||||
sourceId: 'src-fresh',
|
||||
status: 'training',
|
||||
dismissed: true,
|
||||
}),
|
||||
),
|
||||
);
|
||||
const next = reducer(
|
||||
populated,
|
||||
updateUploadTask({ id: 't-fresh', updates: { status: 'completed' } }),
|
||||
);
|
||||
expect(next.tasks[0].dismissed).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('refresh recovery — auto-create from SSE when no task matches', () => {
|
||||
it('creates a task on queued for an unknown sourceId', () => {
|
||||
let state = reducer(undefined, addUploadTask(makeTask({ id: 'other' })));
|
||||
state = reducer(
|
||||
state,
|
||||
ingest(
|
||||
'source.ingest.queued',
|
||||
{ filename: 'crawler.json', job_name: 'docs' },
|
||||
'src-recovery',
|
||||
),
|
||||
);
|
||||
const recovered = state.tasks.find((t) => t.sourceId === 'src-recovery');
|
||||
expect(recovered).toBeDefined();
|
||||
expect(recovered!.status).toBe('training');
|
||||
expect(recovered!.fileName).toBe('crawler.json');
|
||||
expect(recovered!.dismissed).toBe(false);
|
||||
});
|
||||
|
||||
it('creates a task on progress for an unknown sourceId and applies the percent', () => {
|
||||
let state: ReturnType<typeof reducer> = reducer(undefined, {
|
||||
type: '@@INIT',
|
||||
});
|
||||
state = reducer(
|
||||
state,
|
||||
ingest(
|
||||
'source.ingest.progress',
|
||||
{ current: 55, total: 10, embedded_chunks: 5 },
|
||||
'src-progress',
|
||||
),
|
||||
);
|
||||
expect(state.tasks).toHaveLength(1);
|
||||
expect(state.tasks[0].sourceId).toBe('src-progress');
|
||||
expect(state.tasks[0].status).toBe('training');
|
||||
expect(state.tasks[0].progress).toBe(55);
|
||||
});
|
||||
|
||||
it('does NOT create a task on completed for an unknown sourceId (avoids backlog toast spam)', () => {
|
||||
let state: ReturnType<typeof reducer> = reducer(undefined, {
|
||||
type: '@@INIT',
|
||||
});
|
||||
state = reducer(
|
||||
state,
|
||||
ingest('source.ingest.completed', { tokens: 16 }, 'src-stale'),
|
||||
);
|
||||
expect(state.tasks).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('creates a task on failed for an unknown sourceId so error surfaces post-refresh', () => {
|
||||
let state: ReturnType<typeof reducer> = reducer(undefined, {
|
||||
type: '@@INIT',
|
||||
});
|
||||
state = reducer(
|
||||
state,
|
||||
ingest(
|
||||
'source.ingest.failed',
|
||||
{ error: 'embed worker died' },
|
||||
'src-failed',
|
||||
),
|
||||
);
|
||||
expect(state.tasks).toHaveLength(1);
|
||||
expect(state.tasks[0].status).toBe('failed');
|
||||
expect(state.tasks[0].errorMessage).toBe('embed worker died');
|
||||
expect(state.tasks[0].dismissed).toBe(false);
|
||||
});
|
||||
|
||||
it('falls back to job_name then sourceId when filename is absent', () => {
|
||||
let state: ReturnType<typeof reducer> = reducer(undefined, {
|
||||
type: '@@INIT',
|
||||
});
|
||||
state = reducer(
|
||||
state,
|
||||
ingest('source.ingest.queued', { job_name: 'docs-docs' }, 'src-jn'),
|
||||
);
|
||||
expect(state.tasks[0].fileName).toBe('docs-docs');
|
||||
state = reducer(state, ingest('source.ingest.queued', {}, 'src-only'));
|
||||
expect(state.tasks.find((t) => t.sourceId === 'src-only')!.fileName).toBe(
|
||||
'src-only',
|
||||
);
|
||||
});
|
||||
|
||||
it('subsequent progress events update the recovered task in place', () => {
|
||||
let state: ReturnType<typeof reducer> = reducer(undefined, {
|
||||
type: '@@INIT',
|
||||
});
|
||||
state = reducer(
|
||||
state,
|
||||
ingest('source.ingest.queued', { filename: 'a.txt' }, 'src-flow'),
|
||||
);
|
||||
state = reducer(
|
||||
state,
|
||||
ingest('source.ingest.progress', { current: 30 }, 'src-flow'),
|
||||
);
|
||||
state = reducer(
|
||||
state,
|
||||
ingest('source.ingest.progress', { current: 80 }, 'src-flow'),
|
||||
);
|
||||
state = reducer(
|
||||
state,
|
||||
ingest('source.ingest.completed', { tokens: 100 }, 'src-flow'),
|
||||
);
|
||||
expect(state.tasks).toHaveLength(1);
|
||||
expect(state.tasks[0].status).toBe('completed');
|
||||
expect(state.tasks[0].progress).toBe(100);
|
||||
});
|
||||
});
|
||||
|
||||
describe('source.ingest.queued', () => {
|
||||
it('does not regress a task already in training', () => {
|
||||
let state = stateWithTask(makeTask({ status: 'training', progress: 42 }));
|
||||
state = reducer(state, ingest('source.ingest.queued'));
|
||||
expect(state.tasks[0].status).toBe('training');
|
||||
expect(state.tasks[0].progress).toBe(42);
|
||||
});
|
||||
|
||||
it('transitions preparing -> training and zeros progress', () => {
|
||||
let state = stateWithTask(makeTask({ status: 'preparing', progress: 12 }));
|
||||
state = reducer(state, ingest('source.ingest.queued'));
|
||||
expect(state.tasks[0].status).toBe('training');
|
||||
expect(state.tasks[0].progress).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('source.ingest.progress', () => {
|
||||
it('clamps to 0..100 and is monotonic', () => {
|
||||
let state = stateWithTask(makeTask({ status: 'training' }));
|
||||
state = reducer(state, ingest('source.ingest.progress', { current: 30 }));
|
||||
expect(state.tasks[0].progress).toBe(30);
|
||||
// Higher value advances.
|
||||
state = reducer(state, ingest('source.ingest.progress', { current: 150 }));
|
||||
expect(state.tasks[0].progress).toBe(100);
|
||||
// Lower value never regresses.
|
||||
state = reducer(state, ingest('source.ingest.progress', { current: 50 }));
|
||||
expect(state.tasks[0].progress).toBe(100);
|
||||
// Negative gets clamped at 0 but still cannot regress already-higher.
|
||||
state = reducer(state, ingest('source.ingest.progress', { current: -10 }));
|
||||
expect(state.tasks[0].progress).toBe(100);
|
||||
});
|
||||
|
||||
it('records the ingest stage from the payload', () => {
|
||||
let state = stateWithTask(makeTask({ status: 'training' }));
|
||||
state = reducer(
|
||||
state,
|
||||
ingest('source.ingest.progress', { current: 20, stage: 'parsing' }),
|
||||
);
|
||||
expect(state.tasks[0].stage).toBe('parsing');
|
||||
state = reducer(
|
||||
state,
|
||||
ingest('source.ingest.progress', { current: 70, stage: 'embedding' }),
|
||||
);
|
||||
expect(state.tasks[0].stage).toBe('embedding');
|
||||
// An unknown/absent stage leaves the last known value intact.
|
||||
state = reducer(
|
||||
state,
|
||||
ingest('source.ingest.progress', { current: 80, stage: 'bogus' }),
|
||||
);
|
||||
expect(state.tasks[0].stage).toBe('embedding');
|
||||
});
|
||||
});
|
||||
|
||||
describe('source.ingest.completed', () => {
|
||||
it('transitions training -> completed and sets dismissed=false', () => {
|
||||
let state = stateWithTask(
|
||||
makeTask({ status: 'training', dismissed: true }),
|
||||
);
|
||||
state = reducer(state, ingest('source.ingest.completed'));
|
||||
expect(state.tasks[0].status).toBe('completed');
|
||||
expect(state.tasks[0].progress).toBe(100);
|
||||
expect(state.tasks[0].dismissed).toBe(false);
|
||||
expect(state.tasks[0].tokenLimitReached).toBe(false);
|
||||
});
|
||||
|
||||
it('with limited=true transitions training -> failed and flags tokenLimitReached', () => {
|
||||
let state = stateWithTask(
|
||||
makeTask({ status: 'training', dismissed: true }),
|
||||
);
|
||||
state = reducer(
|
||||
state,
|
||||
ingest('source.ingest.completed', { limited: true }),
|
||||
);
|
||||
expect(state.tasks[0].status).toBe('failed');
|
||||
expect(state.tasks[0].progress).toBe(100);
|
||||
expect(state.tasks[0].tokenLimitReached).toBe(true);
|
||||
expect(state.tasks[0].dismissed).toBe(false);
|
||||
});
|
||||
|
||||
it('does not re-un-dismiss when a duplicate terminal event arrives', () => {
|
||||
// Initial terminal — wasTerminal=false, dismissed flipped to false.
|
||||
let state = stateWithTask(makeTask({ status: 'training' }));
|
||||
state = reducer(state, ingest('source.ingest.completed'));
|
||||
expect(state.tasks[0].dismissed).toBe(false);
|
||||
|
||||
// User dismisses the toast manually.
|
||||
state = {
|
||||
...state,
|
||||
tasks: state.tasks.map((t) => ({ ...t, dismissed: true })),
|
||||
};
|
||||
|
||||
// Duplicate terminal envelope (StrictMode remount, reconnect overlap).
|
||||
state = reducer(state, ingest('source.ingest.completed'));
|
||||
expect(state.tasks[0].status).toBe('completed');
|
||||
expect(state.tasks[0].dismissed).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('source.ingest.failed', () => {
|
||||
it('transitions training -> failed with the error message', () => {
|
||||
let state = stateWithTask(makeTask({ status: 'training' }));
|
||||
state = reducer(
|
||||
state,
|
||||
ingest('source.ingest.failed', { error: 'parser blew up' }),
|
||||
);
|
||||
expect(state.tasks[0].status).toBe('failed');
|
||||
expect(state.tasks[0].errorMessage).toBe('parser blew up');
|
||||
expect(state.tasks[0].dismissed).toBe(false);
|
||||
});
|
||||
|
||||
it('does not re-un-dismiss when a duplicate failed event arrives', () => {
|
||||
let state = stateWithTask(makeTask({ status: 'training' }));
|
||||
state = reducer(state, ingest('source.ingest.failed', { error: 'oops' }));
|
||||
state = {
|
||||
...state,
|
||||
tasks: state.tasks.map((t) => ({ ...t, dismissed: true })),
|
||||
};
|
||||
state = reducer(state, ingest('source.ingest.failed', { error: 'oops' }));
|
||||
expect(state.tasks[0].dismissed).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('attachment race recovery', () => {
|
||||
const ATTACHMENT_ID = 'att-1';
|
||||
const CLIENT_ID = 'ui-1';
|
||||
|
||||
const attEvent = (
|
||||
type: string,
|
||||
payload: Record<string, unknown> = {},
|
||||
): SSEEvent => ({
|
||||
id: `id-${type}`,
|
||||
type,
|
||||
scope: { kind: 'attachment', id: ATTACHMENT_ID },
|
||||
payload,
|
||||
});
|
||||
|
||||
const makeAttachment = (overrides: Partial<Attachment> = {}): Attachment => ({
|
||||
id: CLIENT_ID,
|
||||
fileName: 'small.pdf',
|
||||
progress: 10,
|
||||
status: 'processing',
|
||||
taskId: 'celery-1',
|
||||
attachmentId: ATTACHMENT_ID,
|
||||
...overrides,
|
||||
});
|
||||
|
||||
it('drops attachment.completed silently when no row matches attachmentId', () => {
|
||||
const state = reducer(
|
||||
undefined,
|
||||
sseEventReceived(attEvent('attachment.completed', { token_count: 42 })),
|
||||
);
|
||||
expect(state.attachments).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('lands the terminal envelope in notifications.recentEvents for later recovery', () => {
|
||||
const notifState = notificationsReducer(
|
||||
undefined,
|
||||
sseEventReceived(attEvent('attachment.completed', { token_count: 7 })),
|
||||
);
|
||||
expect(notifState.recentEvents).toHaveLength(1);
|
||||
expect(notifState.recentEvents[0].scope?.id).toBe(ATTACHMENT_ID);
|
||||
expect(notifState.recentEvents[0].type).toBe('attachment.completed');
|
||||
});
|
||||
|
||||
it('reconciler dispatch flips the row to completed after the late row addition', () => {
|
||||
// Full race: terminal SSE first, then xhr.onload adds the row,
|
||||
// then trackAttachment.check() walks recentEvents and dispatches.
|
||||
const terminal = attEvent('attachment.completed', { token_count: 99 });
|
||||
const notifState = notificationsReducer(
|
||||
undefined,
|
||||
sseEventReceived(terminal),
|
||||
);
|
||||
|
||||
let uploadState = reducer(undefined, addAttachment(makeAttachment()));
|
||||
expect(uploadState.attachments[0].status).toBe('processing');
|
||||
|
||||
const found = notifState.recentEvents.find(
|
||||
(e) => e.scope?.id === ATTACHMENT_ID && e.type === 'attachment.completed',
|
||||
);
|
||||
expect(found).toBeDefined();
|
||||
const tokenCount = Number(
|
||||
(found?.payload as { token_count?: unknown })?.token_count,
|
||||
);
|
||||
uploadState = reducer(
|
||||
uploadState,
|
||||
updateAttachment({
|
||||
id: CLIENT_ID,
|
||||
updates: {
|
||||
status: 'completed',
|
||||
progress: 100,
|
||||
...(Number.isFinite(tokenCount) ? { token_count: tokenCount } : {}),
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
expect(uploadState.attachments[0].status).toBe('completed');
|
||||
expect(uploadState.attachments[0].progress).toBe(100);
|
||||
expect(uploadState.attachments[0].token_count).toBe(99);
|
||||
});
|
||||
|
||||
it('attachment.failed envelope can drive a stuck row to failed via reconciler', () => {
|
||||
const failed = attEvent('attachment.failed', { error: 'docling boom' });
|
||||
const notifState = notificationsReducer(
|
||||
undefined,
|
||||
sseEventReceived(failed),
|
||||
);
|
||||
|
||||
let uploadState = reducer(undefined, addAttachment(makeAttachment()));
|
||||
const found = notifState.recentEvents.find(
|
||||
(e) => e.scope?.id === ATTACHMENT_ID && e.type === 'attachment.failed',
|
||||
);
|
||||
expect(found).toBeDefined();
|
||||
|
||||
uploadState = reducer(
|
||||
uploadState,
|
||||
updateAttachment({ id: CLIENT_ID, updates: { status: 'failed' } }),
|
||||
);
|
||||
|
||||
expect(uploadState.attachments[0].status).toBe('failed');
|
||||
});
|
||||
});
|
||||
@@ -1,46 +1,12 @@
|
||||
import { createSelector, createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
|
||||
import {
|
||||
DismissedEntry,
|
||||
isDismissed,
|
||||
loadDismissed,
|
||||
saveDismissed,
|
||||
} from '../notifications/dismissedPersistence';
|
||||
import { sseEventReceived } from '../notifications/notificationsSlice';
|
||||
import { RootState } from '../store';
|
||||
|
||||
const DISMISSED_SOURCE_IDS_STORAGE_KEY = 'docsgpt:dismissedUploadSourceIds';
|
||||
const DISMISSED_SOURCE_IDS_TTL_MS = 24 * 60 * 60 * 1000;
|
||||
const DISMISSED_SOURCE_IDS_CAP = 200;
|
||||
|
||||
function recordDismissedSourceId(
|
||||
list: DismissedEntry[],
|
||||
sourceId: string,
|
||||
): DismissedEntry[] {
|
||||
const now = Date.now();
|
||||
const cutoff = now - DISMISSED_SOURCE_IDS_TTL_MS;
|
||||
const next = list.filter(
|
||||
(entry) => entry.at >= cutoff && entry.id !== sourceId,
|
||||
);
|
||||
next.push({ id: sourceId, at: now });
|
||||
return next.length > DISMISSED_SOURCE_IDS_CAP
|
||||
? next.slice(-DISMISSED_SOURCE_IDS_CAP)
|
||||
: next;
|
||||
}
|
||||
|
||||
export interface Attachment {
|
||||
id: string; // Client-side state-management id (uuid generated in MessageInput)
|
||||
id: string; // Unique identifier for the attachment (required for state management)
|
||||
fileName: string;
|
||||
progress: number;
|
||||
status: 'uploading' | 'processing' | 'completed' | 'failed';
|
||||
taskId: string; // Server-assigned celery task ID (used for API calls)
|
||||
/**
|
||||
* Server-assigned attachment id (stable across the lifecycle —
|
||||
* ``attachment.*`` SSE events use this in ``scope.id``). Set as
|
||||
* soon as the upload response returns. Distinct from ``id``
|
||||
* (client) and ``taskId`` (celery).
|
||||
*/
|
||||
attachmentId?: string;
|
||||
taskId: string; // Server-assigned task ID (used for API calls)
|
||||
token_count?: number;
|
||||
}
|
||||
|
||||
@@ -57,46 +23,18 @@ export interface UploadTask {
|
||||
progress: number;
|
||||
status: UploadTaskStatus;
|
||||
taskId?: string;
|
||||
/**
|
||||
* Server-derived source id (uuid5 over the idempotency key) returned by
|
||||
* the upload endpoint. Used to correlate inbound SSE ingest events
|
||||
* (``source.ingest.*``) back to this task without consulting the
|
||||
* polling endpoint.
|
||||
*/
|
||||
sourceId?: string;
|
||||
errorMessage?: string;
|
||||
dismissed?: boolean;
|
||||
/**
|
||||
* Ingest phase from the latest ``source.ingest.progress`` event:
|
||||
* ``parsing`` (parse/OCR, lower band of the bar) or ``embedding``
|
||||
* (upper band). Drives the phase label in ``UploadToast``.
|
||||
*/
|
||||
stage?: 'parsing' | 'embedding';
|
||||
/**
|
||||
* Flipped when ``source.ingest.completed`` carries
|
||||
* ``payload.limited === true`` (the worker hit a token cap during
|
||||
* ingest). The slice routes such events to a failed status and
|
||||
* sets this flag so ``UploadToast`` can surface a translated
|
||||
* token-limit message instead of a generic error. Forward-looking:
|
||||
* no worker code path sets ``limited: true`` today.
|
||||
*/
|
||||
tokenLimitReached?: boolean;
|
||||
}
|
||||
|
||||
interface UploadState {
|
||||
attachments: Attachment[];
|
||||
tasks: UploadTask[];
|
||||
/** Persisted dismissed sourceIds; keeps backlog-replay auto-creates silent. */
|
||||
dismissedSourceIds: DismissedEntry[];
|
||||
}
|
||||
|
||||
const initialState: UploadState = {
|
||||
attachments: [],
|
||||
tasks: [],
|
||||
dismissedSourceIds: loadDismissed(
|
||||
DISMISSED_SOURCE_IDS_STORAGE_KEY,
|
||||
DISMISSED_SOURCE_IDS_TTL_MS,
|
||||
),
|
||||
};
|
||||
|
||||
export const uploadSlice = createSlice({
|
||||
@@ -165,19 +103,9 @@ export const uploadSlice = createSlice({
|
||||
);
|
||||
if (index !== -1) {
|
||||
const updates = action.payload.updates;
|
||||
const existingSourceId = state.tasks[index].sourceId;
|
||||
const incomingSourceId = updates.sourceId;
|
||||
const effectiveSourceId = incomingSourceId ?? existingSourceId;
|
||||
const sourceWasDismissed =
|
||||
!!effectiveSourceId &&
|
||||
isDismissed(state.dismissedSourceIds, effectiveSourceId);
|
||||
|
||||
// Re-surface on terminal status, except when the sourceId was
|
||||
// dismissed in a prior session (don't re-pop after reload).
|
||||
if (
|
||||
(updates.status === 'completed' || updates.status === 'failed') &&
|
||||
!sourceWasDismissed
|
||||
) {
|
||||
// When task completes or fails, set dismissed to false to notify user
|
||||
if (updates.status === 'completed' || updates.status === 'failed') {
|
||||
state.tasks[index] = {
|
||||
...state.tasks[index],
|
||||
...updates,
|
||||
@@ -198,187 +126,12 @@ export const uploadSlice = createSlice({
|
||||
...state.tasks[index],
|
||||
dismissed: true,
|
||||
};
|
||||
// Persist by sourceId so a reload doesn't re-pop via the SSE
|
||||
// backlog replay's auto-create branch.
|
||||
const sourceId = state.tasks[index].sourceId;
|
||||
if (sourceId) {
|
||||
state.dismissedSourceIds = recordDismissedSourceId(
|
||||
state.dismissedSourceIds,
|
||||
sourceId,
|
||||
);
|
||||
saveDismissed(
|
||||
DISMISSED_SOURCE_IDS_STORAGE_KEY,
|
||||
state.dismissedSourceIds,
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
removeUploadTask: (state, action: PayloadAction<string>) => {
|
||||
state.tasks = state.tasks.filter((task) => task.id !== action.payload);
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
// Consume backend SSE ingest events for sub-second progress and
|
||||
// terminal-status updates. The match is by ``sourceId`` (set by
|
||||
// the upload endpoint's response). Polling stays as the
|
||||
// correctness-of-record fallback in Upload.tsx.
|
||||
builder.addCase(sseEventReceived, (state, action) => {
|
||||
const e = action.payload;
|
||||
const scopeId =
|
||||
typeof e.scope?.id === 'string' && e.scope.id.length > 0
|
||||
? e.scope.id
|
||||
: undefined;
|
||||
|
||||
// Attachment events flow through the same SSE pipe; route them
|
||||
// to ``state.attachments`` matched by ``attachmentId``. SSE is
|
||||
// the sole driver of attachment state transitions — polling
|
||||
// has been removed. Events for attachments uploaded in another
|
||||
// session are silently dropped.
|
||||
if (e.type.startsWith('attachment.') && scopeId) {
|
||||
const attachment = state.attachments.find(
|
||||
(a) => a.attachmentId === scopeId,
|
||||
);
|
||||
if (attachment) {
|
||||
const payload = (e.payload || {}) as Record<string, unknown>;
|
||||
switch (e.type) {
|
||||
case 'attachment.queued':
|
||||
case 'attachment.progress': {
|
||||
if (
|
||||
attachment.status === 'completed' ||
|
||||
attachment.status === 'failed'
|
||||
) {
|
||||
break;
|
||||
}
|
||||
attachment.status = 'processing';
|
||||
const current = Number(payload.current);
|
||||
if (Number.isFinite(current)) {
|
||||
const clamped = Math.max(0, Math.min(100, current));
|
||||
if (clamped > attachment.progress) {
|
||||
attachment.progress = clamped;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'attachment.completed': {
|
||||
attachment.status = 'completed';
|
||||
attachment.progress = 100;
|
||||
// Replace the client-generated uuid with the server's
|
||||
// attachment id so question submission
|
||||
// (Conversation.tsx:174) sends an id the backend can
|
||||
// resolve. Without this the backend would silently drop
|
||||
// the attachment from the message context.
|
||||
attachment.id = scopeId;
|
||||
const tokenCount = Number(payload.token_count);
|
||||
if (Number.isFinite(tokenCount)) {
|
||||
attachment.token_count = tokenCount;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'attachment.failed': {
|
||||
attachment.status = 'failed';
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (!e.type.startsWith('source.ingest.')) return;
|
||||
if (!scopeId) return;
|
||||
// ``source.ingest.*`` payloads do not carry a celery task_id
|
||||
// (see ``application/worker.py`` — only ``source_id`` /
|
||||
// ``job_name`` / ``filename`` are published), so ``sourceId``
|
||||
// is the only correlation key.
|
||||
const payload = (e.payload || {}) as Record<string, unknown>;
|
||||
let task = state.tasks.find((t) => t.sourceId === scopeId);
|
||||
if (!task) {
|
||||
// Auto-create on backlog/live SSE when the local task is gone
|
||||
// (page refresh mid-ingest). Skip ``completed`` so the historical
|
||||
// backlog doesn't pop toasts for already-finished sources.
|
||||
if (e.type === 'source.ingest.completed') return;
|
||||
const fileName =
|
||||
(typeof payload.filename === 'string' && payload.filename) ||
|
||||
(typeof payload.job_name === 'string' && payload.job_name) ||
|
||||
scopeId;
|
||||
const previouslyDismissed = isDismissed(
|
||||
state.dismissedSourceIds,
|
||||
scopeId,
|
||||
);
|
||||
task = {
|
||||
id: scopeId,
|
||||
fileName,
|
||||
progress: 0,
|
||||
status: 'training',
|
||||
sourceId: scopeId,
|
||||
dismissed: previouslyDismissed,
|
||||
};
|
||||
state.tasks.push(task);
|
||||
task = state.tasks[state.tasks.length - 1];
|
||||
}
|
||||
const wasTerminal =
|
||||
task.status === 'completed' || task.status === 'failed';
|
||||
const sourceWasDismissed =
|
||||
!!task.sourceId && isDismissed(state.dismissedSourceIds, task.sourceId);
|
||||
|
||||
switch (e.type) {
|
||||
case 'source.ingest.queued':
|
||||
// Don't regress a task already past 'training' (e.g. the
|
||||
// queued event arrives after the upload XHR finished and
|
||||
// status flipped to 'training'). Idempotent on retries.
|
||||
if (task.status === 'preparing' || task.status === 'uploading') {
|
||||
task.status = 'training';
|
||||
task.progress = 0;
|
||||
}
|
||||
break;
|
||||
case 'source.ingest.progress': {
|
||||
const current = Number(payload.current);
|
||||
if (!Number.isFinite(current)) break;
|
||||
// Clamp + monotonic — never regress an already-higher value.
|
||||
const clamped = Math.max(0, Math.min(100, current));
|
||||
if (task.status === 'completed' || task.status === 'failed') break;
|
||||
task.status = 'training';
|
||||
if (clamped > task.progress) task.progress = clamped;
|
||||
if (payload.stage === 'parsing' || payload.stage === 'embedding') {
|
||||
task.stage = payload.stage;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'source.ingest.completed':
|
||||
if (payload.limited === true) {
|
||||
// Token-cap reached during ingest — surface as a failure
|
||||
// so the toast shows the translated limit message rather
|
||||
// than a misleading success state.
|
||||
task.status = 'failed';
|
||||
task.progress = 100;
|
||||
task.tokenLimitReached = true;
|
||||
task.errorMessage = undefined;
|
||||
// Only un-dismiss on the initial terminal transition;
|
||||
// duplicate envelopes (StrictMode remount, reconnect
|
||||
// overlap) must not re-pop a user-dismissed toast.
|
||||
if (!wasTerminal && !sourceWasDismissed) task.dismissed = false;
|
||||
} else {
|
||||
task.status = 'completed';
|
||||
task.progress = 100;
|
||||
task.errorMessage = undefined;
|
||||
task.tokenLimitReached = false;
|
||||
if (!wasTerminal && !sourceWasDismissed) task.dismissed = false;
|
||||
}
|
||||
break;
|
||||
case 'source.ingest.failed':
|
||||
task.status = 'failed';
|
||||
task.errorMessage =
|
||||
typeof payload.error === 'string'
|
||||
? payload.error
|
||||
: 'Ingestion failed.';
|
||||
if (!wasTerminal && !sourceWasDismissed) task.dismissed = false;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "Bundler",
|
||||
"resolveJsonModule": true,
|
||||
"types": ["vite-plugin-svgr/client", "vitest/globals"],
|
||||
"types": ["vite-plugin-svgr/client"],
|
||||
"isolatedModules": true,
|
||||
"noEmit": true,
|
||||
"jsx": "react-jsx",
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
/// <reference types="vitest" />
|
||||
import { defineConfig } from 'vite';
|
||||
import react from '@vitejs/plugin-react';
|
||||
import svgr from 'vite-plugin-svgr';
|
||||
@@ -12,10 +11,4 @@ export default defineConfig({
|
||||
'@': path.resolve(__dirname, './src'),
|
||||
},
|
||||
},
|
||||
test: {
|
||||
environment: 'happy-dom',
|
||||
globals: true,
|
||||
include: ['src/**/*.test.{ts,tsx}'],
|
||||
setupFiles: ['./vitest.setup.ts'],
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
// happy-dom's localStorage lives on `window`; slices read the bare
|
||||
// global at module-load. Install a Map-backed shim on globalThis.
|
||||
const store = new Map<string, string>();
|
||||
const shim = {
|
||||
getItem: (k: string) => (store.has(k) ? store.get(k)! : null),
|
||||
setItem: (k: string, v: string) => {
|
||||
store.set(k, String(v));
|
||||
},
|
||||
removeItem: (k: string) => {
|
||||
store.delete(k);
|
||||
},
|
||||
clear: () => store.clear(),
|
||||
key: (i: number) => Array.from(store.keys())[i] ?? null,
|
||||
get length() {
|
||||
return store.size;
|
||||
},
|
||||
};
|
||||
// Force-override: some happy-dom versions expose a Storage stub
|
||||
// without getItem, so simple assignment isn't enough.
|
||||
Object.defineProperty(globalThis, 'localStorage', {
|
||||
value: shim,
|
||||
writable: true,
|
||||
configurable: true,
|
||||
});
|
||||
@@ -4,24 +4,19 @@ Fixed 5-second generation (100 tokens × 50 ms/token). No auth. Emits SSE
|
||||
chunks in OpenAI's chat.completions streaming format, or a single response
|
||||
when stream=false. Run on 127.0.0.1:8090 — point DocsGPT at it via
|
||||
OPENAI_BASE_URL=http://127.0.0.1:8090/v1.
|
||||
|
||||
Flags:
|
||||
--tool-calls First response returns a tool call instead of text.
|
||||
Subsequent responses (after a tool_result) return text.
|
||||
Useful for triggering the tool-execution loop.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from flask import Flask, Response, request, jsonify
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
TOKEN_COUNT = 100
|
||||
TOKEN_DELAY_S = 0.05 # 100 * 0.05 = 5.0 s
|
||||
TOOL_CALL_MODE = False
|
||||
|
||||
logger = logging.getLogger("mock_llm")
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s mock: %(message)s")
|
||||
@@ -44,7 +39,7 @@ FILLER_TOKENS = [
|
||||
".",
|
||||
]
|
||||
|
||||
app = Flask(__name__)
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def _token_stream_id() -> str:
|
||||
@@ -68,57 +63,11 @@ def _sse_chunk(completion_id: str, model: str, delta: dict, finish_reason=None)
|
||||
return f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
|
||||
def _gen_tool_call_stream(model: str, req_id: str):
|
||||
"""Emit two tool_calls (search) in streaming format.
|
||||
|
||||
Two calls ensure the handler executes the first (which can return a
|
||||
huge result), then hits _check_context_limit before the second.
|
||||
"""
|
||||
completion_id = _token_stream_id()
|
||||
call_id_1 = f"call_{uuid.uuid4().hex[:12]}"
|
||||
call_id_2 = f"call_{uuid.uuid4().hex[:12]}"
|
||||
|
||||
yield _sse_chunk(completion_id, model, {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"id": call_id_1,
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": ""},
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"id": call_id_2,
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": ""},
|
||||
},
|
||||
],
|
||||
})
|
||||
args_json = json.dumps({"query": "Python programming basics"})
|
||||
for ch in args_json:
|
||||
time.sleep(TOKEN_DELAY_S)
|
||||
yield _sse_chunk(completion_id, model, {
|
||||
"tool_calls": [
|
||||
{"index": 0, "function": {"arguments": ch}},
|
||||
{"index": 1, "function": {"arguments": ch}},
|
||||
],
|
||||
})
|
||||
yield _sse_chunk(completion_id, model, {}, finish_reason="tool_calls")
|
||||
yield "data: [DONE]\n\n"
|
||||
logger.info("[%s] tool_call stream done (ids=%s, %s)", req_id, call_id_1, call_id_2)
|
||||
|
||||
|
||||
def _has_tool_result(messages: list) -> bool:
|
||||
return any(m.get("role") == "tool" for m in messages)
|
||||
|
||||
|
||||
def _gen_text_stream(model: str, req_id: str):
|
||||
async def _stream_response(model: str, req_id: str):
|
||||
completion_id = _token_stream_id()
|
||||
yield _sse_chunk(completion_id, model, {"role": "assistant", "content": ""})
|
||||
for tok in FILLER_TOKENS[:TOKEN_COUNT]:
|
||||
time.sleep(TOKEN_DELAY_S)
|
||||
for i, tok in enumerate(FILLER_TOKENS[:TOKEN_COUNT]):
|
||||
await asyncio.sleep(TOKEN_DELAY_S)
|
||||
yield _sse_chunk(completion_id, model, {"content": tok})
|
||||
yield _sse_chunk(completion_id, model, {}, finish_reason="stop")
|
||||
yield "data: [DONE]\n\n"
|
||||
@@ -126,84 +75,63 @@ def _gen_text_stream(model: str, req_id: str):
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
def chat_completions():
|
||||
body = request.get_json(force=True)
|
||||
async def chat_completions(request: Request):
|
||||
body = await request.json()
|
||||
model = body.get("model", "mock")
|
||||
stream = bool(body.get("stream", False))
|
||||
messages = body.get("messages", [])
|
||||
tools = body.get("tools")
|
||||
req_id = uuid.uuid4().hex[:8]
|
||||
logger.info(
|
||||
"[%s] /chat/completions stream=%s model=%s tools=%s msgs=%d",
|
||||
req_id, stream, model, bool(tools), len(messages),
|
||||
)
|
||||
|
||||
use_tool_call = (
|
||||
TOOL_CALL_MODE
|
||||
and tools
|
||||
and not _has_tool_result(messages)
|
||||
)
|
||||
logger.info("[%s] /chat/completions stream=%s model=%s max_tokens=%s", req_id, stream, model, body.get("max_tokens"))
|
||||
|
||||
if stream:
|
||||
gen = (
|
||||
_gen_tool_call_stream(model, req_id) if use_tool_call
|
||||
else _gen_text_stream(model, req_id)
|
||||
)
|
||||
return Response(
|
||||
gen,
|
||||
mimetype="text/event-stream",
|
||||
return StreamingResponse(
|
||||
_stream_response(model, req_id),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
time.sleep(TOKEN_COUNT * TOKEN_DELAY_S)
|
||||
await asyncio.sleep(TOKEN_COUNT * TOKEN_DELAY_S)
|
||||
logger.info("[%s] non-stream done", req_id)
|
||||
text = "".join(FILLER_TOKENS[:TOKEN_COUNT])
|
||||
completion_id = _token_stream_id()
|
||||
return jsonify({
|
||||
"id": completion_id,
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": text},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": TOKEN_COUNT,
|
||||
"total_tokens": 10 + TOKEN_COUNT,
|
||||
},
|
||||
})
|
||||
return JSONResponse(
|
||||
{
|
||||
"id": completion_id,
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": text},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": TOKEN_COUNT,
|
||||
"total_tokens": 10 + TOKEN_COUNT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
def list_models():
|
||||
return jsonify({
|
||||
async def list_models():
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [{"id": "mock", "object": "model", "owned_by": "mock"}],
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return jsonify({"status": "ok"})
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--tool-calls", action="store_true",
|
||||
help="First response returns a tool_call; subsequent responses return text.",
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=8090)
|
||||
args = parser.parse_args()
|
||||
TOOL_CALL_MODE = args.tool_calls
|
||||
if TOOL_CALL_MODE:
|
||||
logger.info("Tool-call mode enabled")
|
||||
app.run(host="127.0.0.1", port=args.port, debug=False, threaded=True)
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="127.0.0.1", port=8090, log_level="info")
|
||||
|
||||
@@ -492,7 +492,7 @@ class TestMCPOAuthManager:
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager
|
||||
|
||||
manager = MCPOAuthManager(MagicMock())
|
||||
result = manager.get_oauth_status("", "alice")
|
||||
result = manager.get_oauth_status("")
|
||||
assert result["status"] == "not_started"
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Tests for the journaled execute path on ToolExecutor.
|
||||
|
||||
Each tool call inserts a ``tool_call_attempts`` row and flips it
|
||||
``proposed → executed`` (or ``→ failed``). With a ``message_id`` it
|
||||
stays ``executed`` for the finalize path to confirm; without one
|
||||
(``save_conversation=False``) it goes straight to ``confirmed``.
|
||||
Each tool call inserts a row into ``tool_call_attempts`` then flips
|
||||
through ``proposed → executed`` (or ``proposed → failed``). The flip
|
||||
to ``confirmed`` is owned by the message-finalize path and is only
|
||||
asserted indirectly here (rows stay in ``executed`` so the reconciler
|
||||
can pick them up).
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
@@ -74,24 +75,11 @@ def _make_call(name="test_action_t1", call_id="c1"):
|
||||
return call
|
||||
|
||||
|
||||
_TOOLS_DICT = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestExecuteJournaling:
|
||||
def test_no_message_id_proposed_then_confirmed(
|
||||
def test_happy_path_proposed_then_executed(
|
||||
self, pg_conn, mock_tool_manager, monkeypatch
|
||||
):
|
||||
"""No reserved message (``save_conversation=False``) → row lands ``confirmed``, not ``executed``."""
|
||||
executor = ToolExecutor(user="u")
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
@@ -101,12 +89,23 @@ class TestExecuteJournaling:
|
||||
)
|
||||
_patch_db(monkeypatch, pg_conn)
|
||||
|
||||
events, result = _drain(executor.execute(_TOOLS_DICT, _make_call(), "MockLLM"))
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
events, result = _drain(executor.execute(tools_dict, _make_call(), "MockLLM"))
|
||||
assert result[0] == "Tool result"
|
||||
|
||||
row = _select_attempt(pg_conn, "c1")
|
||||
assert row is not None
|
||||
assert row["status"] == "confirmed"
|
||||
assert row["status"] == "executed"
|
||||
assert row["tool_name"] == "test_tool"
|
||||
assert row["action_name"] == "test_action"
|
||||
assert row["arguments"] == {"q": "v"}
|
||||
@@ -118,7 +117,10 @@ class TestExecuteJournaling:
|
||||
def test_executor_message_id_is_persisted_on_executed_row(
|
||||
self, pg_conn, mock_tool_manager, monkeypatch
|
||||
):
|
||||
"""The executor's message_id is carried onto the journal row, which stays ``executed``."""
|
||||
"""When the route stamps a placeholder message_id on the executor,
|
||||
the journal row carries it forward so ``confirm_executed_tool_calls``
|
||||
can later flip it to ``confirmed``.
|
||||
"""
|
||||
from application.storage.db.repositories.conversations import (
|
||||
ConversationsRepository,
|
||||
)
|
||||
@@ -145,7 +147,18 @@ class TestExecuteJournaling:
|
||||
)
|
||||
_patch_db(monkeypatch, pg_conn)
|
||||
|
||||
_drain(executor.execute(_TOOLS_DICT, _make_call(call_id="cm1"), "MockLLM"))
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
_drain(executor.execute(tools_dict, _make_call(call_id="cm1"), "MockLLM"))
|
||||
|
||||
row = _select_attempt(pg_conn, "cm1")
|
||||
assert row is not None
|
||||
@@ -167,7 +180,18 @@ class TestExecuteJournaling:
|
||||
RuntimeError("boom")
|
||||
)
|
||||
|
||||
gen = executor.execute(_TOOLS_DICT, _make_call(call_id="c2"), "MockLLM")
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
gen = executor.execute(tools_dict, _make_call(call_id="c2"), "MockLLM")
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
_drain(gen)
|
||||
|
||||
@@ -176,10 +200,42 @@ class TestExecuteJournaling:
|
||||
assert row["status"] == "failed"
|
||||
assert row["error"] == "boom"
|
||||
|
||||
def test_executed_row_lingers_for_reconciler_when_no_confirm(
|
||||
self, pg_conn, mock_tool_manager, monkeypatch
|
||||
):
|
||||
"""No finalize_message call → row sits in ``executed``."""
|
||||
executor = ToolExecutor(user="u")
|
||||
monkeypatch.setattr(
|
||||
"application.agents.tool_executor.ToolActionParser",
|
||||
lambda _cls, **kw: Mock(
|
||||
parse_args=Mock(return_value=("t1", "test_action", {}))
|
||||
),
|
||||
)
|
||||
_patch_db(monkeypatch, pg_conn)
|
||||
|
||||
tools_dict = {
|
||||
"t1": {
|
||||
"id": "00000000-0000-0000-0000-000000000001",
|
||||
"name": "test_tool",
|
||||
"config": {"key": "val"},
|
||||
"actions": [
|
||||
{"name": "test_action", "description": "T", "parameters": {"properties": {}}},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
_drain(executor.execute(tools_dict, _make_call(call_id="c3"), "MockLLM"))
|
||||
|
||||
row = _select_attempt(pg_conn, "c3")
|
||||
assert row["status"] == "executed"
|
||||
# Partial index `tool_call_attempts_pending_ts_idx` selects rows
|
||||
# in ('proposed','executed') — the reconciler reads those.
|
||||
assert row["status"] in ("proposed", "executed")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRepository:
|
||||
def test_proposed_then_confirmed_when_no_message(self, pg_conn):
|
||||
def test_proposed_then_executed_round_trip(self, pg_conn):
|
||||
from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
@@ -193,50 +249,7 @@ class TestRepository:
|
||||
|
||||
assert repo.mark_executed("c-x", {"out": "ok"}) is True
|
||||
row = _select_attempt(pg_conn, "c-x")
|
||||
assert row["status"] == "confirmed"
|
||||
assert row["message_id"] is None
|
||||
assert row["result"] == {"result": {"out": "ok"}}
|
||||
|
||||
def test_mark_executed_with_message_stays_executed(self, pg_conn):
|
||||
from application.storage.db.repositories.conversations import (
|
||||
ConversationsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
|
||||
# FK constraint: message_id must reference a real row.
|
||||
conv_repo = ConversationsRepository(pg_conn)
|
||||
conv = conv_repo.create("u-repo", "repo-msg-test")
|
||||
msg = conv_repo.reserve_message(
|
||||
str(conv["id"]),
|
||||
prompt="q?",
|
||||
placeholder_response="...",
|
||||
request_id="req-repo-1",
|
||||
status="pending",
|
||||
)
|
||||
message_uuid = str(msg["id"])
|
||||
|
||||
repo = ToolCallAttemptsRepository(pg_conn)
|
||||
repo.record_proposed("c-m", "tool", "act", {})
|
||||
assert (
|
||||
repo.mark_executed("c-m", {"out": "ok"}, message_id=message_uuid) is True
|
||||
)
|
||||
row = _select_attempt(pg_conn, "c-m")
|
||||
assert row["status"] == "executed"
|
||||
assert str(row["message_id"]) == message_uuid
|
||||
|
||||
def test_upsert_executed_without_message_confirms(self, pg_conn):
|
||||
"""``upsert_executed`` (DB-outage fallback) with no ``message_id`` lands ``confirmed``."""
|
||||
from application.storage.db.repositories.tool_call_attempts import (
|
||||
ToolCallAttemptsRepository,
|
||||
)
|
||||
|
||||
repo = ToolCallAttemptsRepository(pg_conn)
|
||||
repo.upsert_executed("c-up", "tool", "act", {"a": 1}, {"out": "ok"})
|
||||
row = _select_attempt(pg_conn, "c-up")
|
||||
assert row["status"] == "confirmed"
|
||||
assert row["message_id"] is None
|
||||
assert row["result"] == {"result": {"out": "ok"}}
|
||||
|
||||
def test_mark_failed_sets_error(self, pg_conn):
|
||||
|
||||
@@ -8,7 +8,6 @@ MCPOAuthManager.
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -870,110 +869,17 @@ class TestMCPOAuthManager:
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager
|
||||
|
||||
manager = MCPOAuthManager(MagicMock())
|
||||
result = manager.get_oauth_status("", "alice")
|
||||
result = manager.get_oauth_status("")
|
||||
assert result["status"] == "not_started"
|
||||
|
||||
def test_get_oauth_status_no_user(self):
|
||||
def test_get_oauth_status_with_task(self):
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager
|
||||
|
||||
manager = MCPOAuthManager(MagicMock())
|
||||
result = manager.get_oauth_status("task123", "")
|
||||
# Without a user id we can't address the per-user SSE journal,
|
||||
# so the manager refuses rather than scanning every user's
|
||||
# stream.
|
||||
assert result["status"] == "not_found"
|
||||
|
||||
def test_get_oauth_status_reads_completed_envelope_from_journal(self):
|
||||
"""The manager walks the user's SSE Streams journal and
|
||||
surfaces the latest ``mcp.oauth.*`` envelope for the task.
|
||||
|
||||
Verifies the full polling-contract surface: ``status`` is
|
||||
derived from the event-type suffix, and the completed
|
||||
payload's ``tools`` / ``tools_count`` fields are passed
|
||||
through unchanged so ``mcp.py``'s ``connect_mcp`` can use them
|
||||
without further plumbing.
|
||||
"""
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager
|
||||
|
||||
completed_envelope = json.dumps(
|
||||
{
|
||||
"type": "mcp.oauth.completed",
|
||||
"scope": {"kind": "mcp_oauth", "id": "task123"},
|
||||
"payload": {
|
||||
"task_id": "task123",
|
||||
"tools": [{"name": "t1", "description": "d"}],
|
||||
"tools_count": 1,
|
||||
},
|
||||
}
|
||||
).encode("utf-8")
|
||||
# xrevrange yields newest-first; an unrelated older entry
|
||||
# should be skipped on the way to the matching one.
|
||||
unrelated_envelope = json.dumps(
|
||||
{
|
||||
"type": "source.ingest.completed",
|
||||
"scope": {"kind": "source", "id": "abc"},
|
||||
"payload": {},
|
||||
}
|
||||
).encode("utf-8")
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.xrevrange.return_value = [
|
||||
(b"1735682400000-0", {b"event": completed_envelope}),
|
||||
(b"1735682300000-0", {b"event": unrelated_envelope}),
|
||||
]
|
||||
|
||||
manager = MCPOAuthManager(mock_redis)
|
||||
result = manager.get_oauth_status("task123", "alice")
|
||||
|
||||
assert result["status"] == "completed"
|
||||
assert result["tools"] == [{"name": "t1", "description": "d"}]
|
||||
assert result["tools_count"] == 1
|
||||
# The stream key is scoped per-user — never global.
|
||||
mock_redis.xrevrange.assert_called_once()
|
||||
call_args = mock_redis.xrevrange.call_args
|
||||
assert "user:alice:stream" in call_args.args
|
||||
# Scan window must cover the full bounded stream so a flood of
|
||||
# concurrent source-ingest events between popup-completed and
|
||||
# Save can't push the OAuth envelope out of view.
|
||||
from application.core.settings import settings
|
||||
|
||||
assert call_args.kwargs.get("count") >= settings.EVENTS_STREAM_MAXLEN
|
||||
|
||||
def test_get_oauth_status_scan_covers_events_stream_maxlen(self):
|
||||
"""Regression: count must scale with ``EVENTS_STREAM_MAXLEN``
|
||||
so the OAuth completion envelope is reachable even after
|
||||
concurrent source-ingest events flood the user stream."""
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager
|
||||
from application.core.settings import settings
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.xrevrange.return_value = []
|
||||
|
||||
manager = MCPOAuthManager(mock_redis)
|
||||
manager.get_oauth_status("task123", "alice")
|
||||
|
||||
call_args = mock_redis.xrevrange.call_args
|
||||
assert call_args.kwargs.get("count") >= settings.EVENTS_STREAM_MAXLEN
|
||||
|
||||
def test_get_oauth_status_returns_not_found_when_no_match(self):
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager
|
||||
|
||||
mock_redis = MagicMock()
|
||||
# Stream is non-empty but has nothing matching this task.
|
||||
unrelated_envelope = json.dumps(
|
||||
{
|
||||
"type": "mcp.oauth.completed",
|
||||
"scope": {"kind": "mcp_oauth", "id": "other-task"},
|
||||
"payload": {"task_id": "other-task"},
|
||||
}
|
||||
).encode("utf-8")
|
||||
mock_redis.xrevrange.return_value = [
|
||||
(b"1735682400000-0", {b"event": unrelated_envelope}),
|
||||
]
|
||||
|
||||
manager = MCPOAuthManager(mock_redis)
|
||||
result = manager.get_oauth_status("task123", "alice")
|
||||
|
||||
assert result["status"] == "not_found"
|
||||
with patch("application.agents.tools.mcp_tool.mcp_oauth_status_task",
|
||||
return_value={"status": "complete"}):
|
||||
manager = MCPOAuthManager(MagicMock())
|
||||
result = manager.get_oauth_status("task123")
|
||||
assert result["status"] == "complete"
|
||||
|
||||
|
||||
# =====================================================================
|
||||
@@ -1419,7 +1325,6 @@ class TestSetupClientExtended:
|
||||
tool._client = None
|
||||
tool.available_tools = []
|
||||
tool.user_id = "user1"
|
||||
tool.oauth_redirect_publish = None
|
||||
|
||||
mock_client = MagicMock()
|
||||
with patch.object(MCPTool, "_create_transport", return_value=MagicMock()), \
|
||||
@@ -2352,21 +2257,16 @@ class TestMCPOAuthManagerExtended:
|
||||
# Should store error in redis
|
||||
mock_redis.setex.assert_called()
|
||||
|
||||
def test_get_oauth_status_returns_not_found_on_redis_error(self):
|
||||
"""A failure inside ``xrevrange`` (Redis down, network blip) is
|
||||
swallowed — the manager returns ``not_found`` so the caller
|
||||
can present a clean "OAuth failed, try again" message rather
|
||||
than a 500.
|
||||
"""
|
||||
def test_get_oauth_status_task_error(self):
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.xrevrange.side_effect = Exception("Redis went away")
|
||||
|
||||
manager = MCPOAuthManager(mock_redis)
|
||||
result = manager.get_oauth_status("task123", "alice")
|
||||
|
||||
assert result["status"] == "not_found"
|
||||
with patch(
|
||||
"application.agents.tools.mcp_tool.mcp_oauth_status_task",
|
||||
side_effect=Exception("task failed"),
|
||||
):
|
||||
manager = MCPOAuthManager(MagicMock())
|
||||
with pytest.raises(Exception, match="task failed"):
|
||||
manager.get_oauth_status("task123")
|
||||
|
||||
|
||||
# =====================================================================
|
||||
|
||||
@@ -449,9 +449,6 @@ class TestFinalizeMessage:
|
||||
question="q",
|
||||
decoded_token={"sub": user},
|
||||
)
|
||||
from application.storage.db.repositories.conversations import (
|
||||
MessageUpdateOutcome,
|
||||
)
|
||||
assert svc.finalize_message(
|
||||
res["message_id"],
|
||||
"real answer",
|
||||
@@ -461,7 +458,7 @@ class TestFinalizeMessage:
|
||||
model_id="gpt-4",
|
||||
metadata={"foo": "bar"},
|
||||
status="complete",
|
||||
) is MessageUpdateOutcome.UPDATED
|
||||
) is True
|
||||
|
||||
msgs = ConversationsRepository(pg_conn).get_messages(
|
||||
res["conversation_id"],
|
||||
@@ -491,15 +488,12 @@ class TestFinalizeMessage:
|
||||
decoded_token={"sub": user},
|
||||
)
|
||||
err = RuntimeError("provider down")
|
||||
from application.storage.db.repositories.conversations import (
|
||||
MessageUpdateOutcome,
|
||||
)
|
||||
assert svc.finalize_message(
|
||||
res["message_id"],
|
||||
"fallback text",
|
||||
status="failed",
|
||||
error=err,
|
||||
) is MessageUpdateOutcome.UPDATED
|
||||
) is True
|
||||
|
||||
msgs = ConversationsRepository(pg_conn).get_messages(
|
||||
res["conversation_id"],
|
||||
@@ -532,12 +526,9 @@ class TestFinalizeMessage:
|
||||
),
|
||||
{"cid": "c1", "mid": res["message_id"]},
|
||||
)
|
||||
from application.storage.db.repositories.conversations import (
|
||||
MessageUpdateOutcome,
|
||||
)
|
||||
assert svc.finalize_message(
|
||||
res["message_id"], "ans", status="complete",
|
||||
) is MessageUpdateOutcome.UPDATED
|
||||
) is True
|
||||
|
||||
status = pg_conn.execute(
|
||||
sql_text("SELECT status FROM tool_call_attempts WHERE call_id = :cid"),
|
||||
@@ -545,19 +536,16 @@ class TestFinalizeMessage:
|
||||
).scalar()
|
||||
assert status == "confirmed"
|
||||
|
||||
def test_finalize_returns_not_found_for_unknown_message(self, pg_conn):
|
||||
def test_finalize_returns_false_for_unknown_message(self, pg_conn):
|
||||
from application.api.answer.services.conversation_service import (
|
||||
ConversationService,
|
||||
)
|
||||
from application.storage.db.repositories.conversations import (
|
||||
MessageUpdateOutcome,
|
||||
)
|
||||
with _patch_db(pg_conn):
|
||||
assert ConversationService().finalize_message(
|
||||
"00000000-0000-0000-0000-000000000000",
|
||||
"x",
|
||||
status="complete",
|
||||
) is MessageUpdateOutcome.NOT_FOUND
|
||||
) is False
|
||||
|
||||
def test_finalize_rolls_back_tool_call_confirm_on_message_update_failure(
|
||||
self, pg_conn
|
||||
@@ -663,9 +651,6 @@ class TestFinalizeMessage:
|
||||
question="long question that becomes the fallback name",
|
||||
decoded_token={"sub": user},
|
||||
)
|
||||
from application.storage.db.repositories.conversations import (
|
||||
MessageUpdateOutcome,
|
||||
)
|
||||
assert svc.finalize_message(
|
||||
res["message_id"],
|
||||
"answer",
|
||||
@@ -679,7 +664,7 @@ class TestFinalizeMessage:
|
||||
"long question that becomes the fallback name"[:50]
|
||||
),
|
||||
},
|
||||
) is MessageUpdateOutcome.UPDATED
|
||||
) is True
|
||||
|
||||
repo = ConversationsRepository(pg_conn)
|
||||
conv = repo.get_any(res["conversation_id"], user)
|
||||
|
||||
@@ -343,50 +343,6 @@ class TestProcessResponseStreamExtended:
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
assert result["thought"] == "thinking..."
|
||||
|
||||
def test_handles_id_prefixed_chunks(self, mock_mongo_db, flask_app):
|
||||
"""``complete_stream`` emits ``id: <seq>\\n`` before each
|
||||
``data:`` line so reconnects can resume. The non-streaming
|
||||
``/api/answer`` consumer must skip the ``id:`` header (and the
|
||||
informational ``message_id`` event) without breaking JSON
|
||||
decoding.
|
||||
"""
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
stream = [
|
||||
'id: 0\n'
|
||||
f'data: {json.dumps({"type": "message_id", "message_id": "abc"})}\n\n',
|
||||
'id: 1\n'
|
||||
f'data: {json.dumps({"type": "answer", "answer": "Hello "})}\n\n',
|
||||
'id: 2\n'
|
||||
f'data: {json.dumps({"type": "answer", "answer": "world"})}\n\n',
|
||||
'id: 3\n'
|
||||
f'data: {json.dumps({"type": "id", "id": "conv-1"})}\n\n',
|
||||
'id: 4\n'
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
assert result["error"] is None
|
||||
assert result["answer"] == "Hello world"
|
||||
assert result["conversation_id"] == "conv-1"
|
||||
|
||||
def test_skips_keepalive_comment_lines(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
stream = [
|
||||
': keepalive\n\n',
|
||||
'id: 0\n'
|
||||
f'data: {json.dumps({"type": "answer", "answer": "ok"})}\n\n',
|
||||
'id: 1\n'
|
||||
f'data: {json.dumps({"type": "end"})}\n\n',
|
||||
]
|
||||
result = resource.process_response_stream(iter(stream))
|
||||
assert result["answer"] == "ok"
|
||||
assert result["error"] is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCheckUsageStringBooleans:
|
||||
@@ -637,251 +593,6 @@ class TestCompleteStreamGeneratorExit:
|
||||
next(gen)
|
||||
gen.close() # Should not crash even with save error
|
||||
|
||||
def test_generator_exit_before_any_response_journals_error_not_end(
|
||||
self, mock_mongo_db, flask_app,
|
||||
):
|
||||
"""A client disconnect right after the early ``message_id`` frame
|
||||
leaves ``response_full`` empty, so finalize never runs. The abort
|
||||
handler must journal ``error`` (not ``end``) and flip the row to
|
||||
``failed`` — otherwise a reconnecting client sees a terminal
|
||||
``end`` for a row whose DB status is still non-terminal and the
|
||||
UI parks on a blank successful answer.
|
||||
"""
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
|
||||
# An agent that's primed but never yields anything before the
|
||||
# client disconnects — keeps ``response_full`` empty.
|
||||
def gen_never_yields():
|
||||
if False:
|
||||
yield # pragma: no cover
|
||||
return
|
||||
|
||||
mock_agent.gen.return_value = gen_never_yields()
|
||||
mock_agent.compression_metadata = None
|
||||
mock_agent.compression_saved = False
|
||||
mock_agent.tool_calls = []
|
||||
|
||||
resource.conversation_service = MagicMock()
|
||||
resource.conversation_service.save_user_question.return_value = {
|
||||
"conversation_id": "conv1",
|
||||
"message_id": "msg1",
|
||||
"request_id": "req1",
|
||||
}
|
||||
|
||||
journaled: list[tuple] = []
|
||||
|
||||
def _capture_record(message_id, sequence_no, event_type, payload):
|
||||
journaled.append((message_id, sequence_no, event_type, payload))
|
||||
return True
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.base.record_event",
|
||||
side_effect=_capture_record,
|
||||
):
|
||||
gen = resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id="conv1",
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
# Drain the early ``message_id`` event, then close before
|
||||
# the agent yields anything.
|
||||
next(gen)
|
||||
gen.close()
|
||||
|
||||
# The early message_id frame got journaled (seq 0); the abort
|
||||
# handler must follow with an ``error`` event (NOT ``end``).
|
||||
terminal_events = [
|
||||
(et, pl) for (_, _, et, pl) in journaled if et in ("end", "error")
|
||||
]
|
||||
assert len(terminal_events) == 1, (
|
||||
f"expected exactly one terminal journal write, got {terminal_events}"
|
||||
)
|
||||
assert terminal_events[0][0] == "error", (
|
||||
f"expected ``error`` terminal but got ``end``: {terminal_events}"
|
||||
)
|
||||
payload = terminal_events[0][1]
|
||||
assert payload.get("type") == "error"
|
||||
assert payload.get("code") == "client_disconnect"
|
||||
|
||||
# And the DB row should have been flipped to ``failed`` via
|
||||
# finalize_message. The mocked service records the call.
|
||||
finalize_calls = (
|
||||
resource.conversation_service.finalize_message.call_args_list
|
||||
)
|
||||
assert len(finalize_calls) == 1
|
||||
assert finalize_calls[0].kwargs.get("status") == "failed"
|
||||
|
||||
def test_generator_exit_after_response_still_journals_end(
|
||||
self, mock_mongo_db, flask_app,
|
||||
):
|
||||
"""Regression guard: a disconnect AFTER partial response was
|
||||
produced and ``finalize_message`` succeeded must still journal
|
||||
``end`` (the row matches ``complete``). Only the empty-response
|
||||
branch flips to ``error``.
|
||||
"""
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.storage.db.repositories.conversations import (
|
||||
MessageUpdateOutcome,
|
||||
)
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
|
||||
def gen_answers():
|
||||
yield {"answer": "partial"}
|
||||
|
||||
mock_agent.gen.return_value = gen_answers()
|
||||
mock_agent.compression_metadata = None
|
||||
mock_agent.compression_saved = False
|
||||
mock_agent.tool_calls = []
|
||||
|
||||
resource.conversation_service = MagicMock()
|
||||
resource.conversation_service.finalize_message.return_value = (
|
||||
MessageUpdateOutcome.UPDATED
|
||||
)
|
||||
resource.conversation_service.save_user_question.return_value = {
|
||||
"conversation_id": "conv1",
|
||||
"message_id": "msg1",
|
||||
"request_id": "req1",
|
||||
}
|
||||
|
||||
journaled: list[tuple] = []
|
||||
|
||||
def _capture_record(message_id, sequence_no, event_type, payload):
|
||||
journaled.append((message_id, sequence_no, event_type, payload))
|
||||
return True
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.base.record_event",
|
||||
side_effect=_capture_record,
|
||||
):
|
||||
gen = resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id="conv1",
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
next(gen) # message_id frame
|
||||
next(gen) # answer frame (consumes ``partial``)
|
||||
gen.close()
|
||||
|
||||
terminal_events = [
|
||||
(et, pl) for (_, _, et, pl) in journaled if et in ("end", "error")
|
||||
]
|
||||
assert terminal_events and terminal_events[0][0] == "end", (
|
||||
f"finalize succeeded but abort journaled {terminal_events}"
|
||||
)
|
||||
|
||||
def test_generator_exit_after_normal_finalize_already_complete_journals_end(
|
||||
self, mock_mongo_db, flask_app,
|
||||
):
|
||||
"""Regression for the race where the normal-path finalize wins
|
||||
against a client disconnect.
|
||||
|
||||
Trace: agent finishes, ``complete_stream`` runs the normal-path
|
||||
``finalize_message`` at base.py:632 and flips the row to
|
||||
``complete``. The client TCP-resets before the ``end`` frame can
|
||||
be journaled. The GeneratorExit handler calls ``finalize_message``
|
||||
again — and the repository, gated by ``only_if_non_terminal``,
|
||||
reports ``ALREADY_COMPLETE`` because the row is already at the
|
||||
target state. The abort handler must journal ``end``, not
|
||||
``error``: the DB says ``complete``, the reconnecting client
|
||||
must see the same.
|
||||
|
||||
Without the fix the abort handler treated ``ALREADY_COMPLETE``
|
||||
as a failure and journaled ``error`` — a reconnect would then
|
||||
replay a successful completion as a failed answer.
|
||||
"""
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.storage.db.repositories.conversations import (
|
||||
MessageUpdateOutcome,
|
||||
)
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
|
||||
def gen_answers():
|
||||
yield {"answer": "partial"}
|
||||
|
||||
mock_agent.gen.return_value = gen_answers()
|
||||
mock_agent.compression_metadata = None
|
||||
mock_agent.compression_saved = False
|
||||
mock_agent.tool_calls = []
|
||||
|
||||
resource.conversation_service = MagicMock()
|
||||
# Normal-path call returns UPDATED; abort-handler call sees
|
||||
# the row already terminal and returns ALREADY_COMPLETE.
|
||||
resource.conversation_service.finalize_message.side_effect = [
|
||||
MessageUpdateOutcome.UPDATED,
|
||||
MessageUpdateOutcome.ALREADY_COMPLETE,
|
||||
]
|
||||
resource.conversation_service.save_user_question.return_value = {
|
||||
"conversation_id": "conv1",
|
||||
"message_id": "msg1",
|
||||
"request_id": "req1",
|
||||
}
|
||||
|
||||
journaled: list[tuple] = []
|
||||
|
||||
def _capture_record(message_id, sequence_no, event_type, payload):
|
||||
journaled.append((message_id, sequence_no, event_type, payload))
|
||||
return True
|
||||
|
||||
with patch(
|
||||
"application.api.answer.routes.base.record_event",
|
||||
side_effect=_capture_record,
|
||||
):
|
||||
gen = resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id="conv1",
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
next(gen) # message_id frame
|
||||
next(gen) # answer frame
|
||||
# Pull the ``id`` frame so the normal-path finalize at
|
||||
# line 632 actually runs (it sits between the answer
|
||||
# frame yield and the id frame yield).
|
||||
next(gen) # id frame — runs normal-path finalize first
|
||||
gen.close() # GeneratorExit at the id frame yield
|
||||
|
||||
terminal_events = [
|
||||
(et, pl) for (_, _, et, pl) in journaled if et in ("end", "error")
|
||||
]
|
||||
assert len(terminal_events) == 1, (
|
||||
f"expected exactly one terminal journal write, got "
|
||||
f"{terminal_events}"
|
||||
)
|
||||
assert terminal_events[0][0] == "end", (
|
||||
f"row was already ``complete`` but abort journaled "
|
||||
f"{terminal_events[0]} — reconnect would surface a "
|
||||
f"successful answer as failed"
|
||||
)
|
||||
|
||||
# Both finalize_message calls were made: the normal-path
|
||||
# one and the abort-handler one. Asserting on side_effect
|
||||
# consumption ensures the test really exercised both
|
||||
# branches.
|
||||
assert (
|
||||
resource.conversation_service.finalize_message.call_count == 2
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_db_session(conn):
|
||||
@@ -895,34 +606,10 @@ def _patch_db_session(conn):
|
||||
), patch(
|
||||
"application.api.answer.services.conversation_service.db_readonly",
|
||||
_yield,
|
||||
), patch(
|
||||
# ``record_event`` opens its own short-lived ``db_session`` for
|
||||
# cross-connection visibility. In tests we route it back to the
|
||||
# same ``pg_conn`` so the journal write can see the message row
|
||||
# the conversation_service just wrote in this transaction.
|
||||
"application.streaming.message_journal.db_session",
|
||||
_yield,
|
||||
), patch(
|
||||
# ``complete_stream`` reads ``latest_sequence_no`` via
|
||||
# ``db_readonly`` to seed continuation runs. Same patch reason
|
||||
# as the journal — keep the read on the same pg_conn so it sees
|
||||
# uncommitted writes from this transaction.
|
||||
"application.api.answer.routes.base.db_readonly",
|
||||
_yield,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def _extract_sse_data(chunk: str) -> str:
|
||||
"""Pull the ``data:`` payload from an SSE record, ignoring any
|
||||
``id:`` header introduced by the journal wiring.
|
||||
"""
|
||||
for line in chunk.split("\n"):
|
||||
if line.startswith("data:"):
|
||||
return line[len("data:") :].lstrip()
|
||||
return ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCompleteStreamWalAcceptance:
|
||||
"""Acceptance for the WAL pre-persist behaviour: when the LLM raises
|
||||
@@ -972,158 +659,6 @@ class TestCompleteStreamWalAcceptance:
|
||||
assert "RuntimeError" in msgs[0]["metadata"]["error"]
|
||||
assert "LLM upstream failed" in msgs[0]["metadata"]["error"]
|
||||
|
||||
def test_tool_approval_event_only_fires_when_state_saved(
|
||||
self, pg_conn, flask_app,
|
||||
):
|
||||
"""A `tool.approval.required` notification with no resumable
|
||||
``pending_tool_state`` row would deep-link the user to a 404.
|
||||
Gate the publish on save_state actually committing.
|
||||
"""
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
# Drive the agent into the paused branch with a single
|
||||
# ``tool_calls_pending`` event.
|
||||
mock_agent.gen.return_value = iter(
|
||||
[
|
||||
{
|
||||
"type": "tool_calls_pending",
|
||||
"data": {"pending_tool_calls": [{"call_id": "c1"}]},
|
||||
}
|
||||
]
|
||||
)
|
||||
mock_agent._pending_continuation = {
|
||||
"messages": [],
|
||||
"tools_dict": {},
|
||||
"pending_tool_calls": [{"call_id": "c1"}],
|
||||
}
|
||||
mock_agent.tool_calls = []
|
||||
mock_agent.compression_metadata = None
|
||||
mock_agent.compression_saved = False
|
||||
|
||||
published: list = []
|
||||
|
||||
def _capture(*args, **kwargs):
|
||||
published.append(args)
|
||||
|
||||
with _patch_db_session(pg_conn), patch(
|
||||
"application.api.answer.routes.base.publish_user_event",
|
||||
side_effect=_capture,
|
||||
), patch(
|
||||
"application.api.answer.services.continuation_service."
|
||||
"ContinuationService.save_state",
|
||||
side_effect=RuntimeError("PG outage"),
|
||||
):
|
||||
list(
|
||||
resource.complete_stream(
|
||||
question="run my tool",
|
||||
agent=mock_agent,
|
||||
conversation_id=None,
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u-tap"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
)
|
||||
|
||||
# No tool.approval.required publish when save_state failed.
|
||||
event_types = [a[1] for a in published if len(a) >= 2]
|
||||
assert "tool.approval.required" not in event_types
|
||||
|
||||
def test_continuation_seeds_sequence_no_from_journal_high_water_mark(
|
||||
self, pg_conn, flask_app,
|
||||
):
|
||||
"""A resumed (tool-actions) stream must continue numbering past
|
||||
the original run's max ``sequence_no``. Otherwise the second
|
||||
invocation collides on the duplicate-PK and silently drops
|
||||
every journal write past the resume point.
|
||||
|
||||
We pre-seed the journal with synthetic rows simulating an
|
||||
original run, then invoke ``complete_stream`` with
|
||||
``_continuation`` set and assert seq numbering picks up past
|
||||
the high-water mark.
|
||||
"""
|
||||
import uuid as _uuid
|
||||
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.storage.db.repositories.message_events import (
|
||||
MessageEventsRepository,
|
||||
)
|
||||
|
||||
# Pre-seed: insert a parent conversation + message + a few
|
||||
# journal rows so latest_sequence_no returns 7.
|
||||
user_id = "u-resume"
|
||||
conv_id = _uuid.uuid4()
|
||||
message_id = _uuid.uuid4()
|
||||
pg_conn.execute(
|
||||
sql_text("INSERT INTO users (user_id) VALUES (:u)"),
|
||||
{"u": user_id},
|
||||
)
|
||||
pg_conn.execute(
|
||||
sql_text(
|
||||
"INSERT INTO conversations (id, user_id, name) "
|
||||
"VALUES (:id, :u, 'pre-seed')"
|
||||
),
|
||||
{"id": conv_id, "u": user_id},
|
||||
)
|
||||
pg_conn.execute(
|
||||
sql_text(
|
||||
"INSERT INTO conversation_messages "
|
||||
"(id, conversation_id, user_id, position, prompt) "
|
||||
"VALUES (:id, :c, :u, 0, 'q')"
|
||||
),
|
||||
{"id": message_id, "c": conv_id, "u": user_id},
|
||||
)
|
||||
repo = MessageEventsRepository(pg_conn)
|
||||
for seq in range(8): # rows 0..7
|
||||
repo.record(str(message_id), seq, "answer", {"chunk": str(seq)})
|
||||
original_max = repo.latest_sequence_no(str(message_id))
|
||||
assert original_max == 7
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
cont_agent = MagicMock()
|
||||
cont_agent.gen_continuation.return_value = iter(
|
||||
[{"answer": "d"}, {"answer": "e"}]
|
||||
)
|
||||
cont_agent.tool_calls = []
|
||||
cont_agent.compression_metadata = None
|
||||
cont_agent.compression_saved = False
|
||||
cont_agent.tool_executor = None
|
||||
|
||||
with _patch_db_session(pg_conn):
|
||||
list(
|
||||
resource.complete_stream(
|
||||
question="",
|
||||
agent=cont_agent,
|
||||
conversation_id=str(conv_id),
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": user_id},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
_continuation={
|
||||
"messages": [],
|
||||
"tools_dict": {},
|
||||
"pending_tool_calls": [],
|
||||
"tool_actions": [],
|
||||
"reserved_message_id": str(message_id),
|
||||
"request_id": "req-resume",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
new_max = MessageEventsRepository(pg_conn).latest_sequence_no(
|
||||
str(message_id)
|
||||
)
|
||||
# Continuation extended the journal — the new high-water mark
|
||||
# is strictly greater than the seeded ``original_max=7``,
|
||||
# confirming the allocator picked up past the resume point.
|
||||
assert new_max is not None and new_max > original_max
|
||||
|
||||
def test_request_id_consistent_across_sse_event_and_wal_row(
|
||||
self, pg_conn, flask_app,
|
||||
):
|
||||
@@ -1159,9 +694,9 @@ class TestCompleteStreamWalAcceptance:
|
||||
)
|
||||
|
||||
sse_events = [
|
||||
json.loads(_extract_sse_data(s))
|
||||
json.loads(s.replace("data: ", "").strip())
|
||||
for s in stream
|
||||
if "data:" in s
|
||||
if s.startswith("data: ")
|
||||
]
|
||||
early_events = [e for e in sse_events if e.get("type") == "message_id"]
|
||||
assert len(early_events) == 1
|
||||
@@ -1177,102 +712,3 @@ class TestCompleteStreamWalAcceptance:
|
||||
msgs = ConversationsRepository(pg_conn).get_messages(str(convs[0][0]))
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0]["request_id"] == sse_request_id
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestStreamingHeartbeatSeed:
|
||||
"""Regression guard: when the row flips to ``streaming`` we must seed
|
||||
``last_heartbeat_at`` so the watchdog doesn't fall back to
|
||||
``timestamp`` (creation time) on slow LLM cold-starts (>idle_secs).
|
||||
"""
|
||||
|
||||
def test_heartbeat_seeded_on_first_chunk(self, mock_mongo_db, flask_app):
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter(
|
||||
[{"answer": "a"}, {"answer": "b"}]
|
||||
)
|
||||
mock_agent.compression_metadata = None
|
||||
mock_agent.compression_saved = False
|
||||
mock_agent.tool_calls = []
|
||||
mock_agent.tool_executor = None
|
||||
|
||||
resource.conversation_service = MagicMock()
|
||||
resource.conversation_service.save_conversation.return_value = "conv1"
|
||||
resource.conversation_service.save_user_question.return_value = {
|
||||
"conversation_id": "conv1",
|
||||
"message_id": "msg1",
|
||||
"request_id": "req1",
|
||||
}
|
||||
|
||||
list(
|
||||
resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id="conv1",
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
)
|
||||
|
||||
# update_message_status flips the row to ``streaming`` exactly
|
||||
# once (idempotent via streaming_marked).
|
||||
status_calls = [
|
||||
c for c in resource.conversation_service
|
||||
.update_message_status.call_args_list
|
||||
if c.args[1] == "streaming"
|
||||
]
|
||||
assert len(status_calls) == 1
|
||||
assert status_calls[0].args[0] == "msg1"
|
||||
|
||||
# heartbeat seed runs exactly once at the same flip — multiple
|
||||
# chunks don't re-stamp inside this window (the throttled
|
||||
# _heartbeat_streaming path is gated by STREAM_HEARTBEAT_INTERVAL).
|
||||
hb_calls = (
|
||||
resource.conversation_service.heartbeat_message.call_args_list
|
||||
)
|
||||
assert len(hb_calls) == 1
|
||||
assert hb_calls[0].args[0] == "msg1"
|
||||
|
||||
def test_heartbeat_seed_skipped_without_reserved_message_id(
|
||||
self, mock_mongo_db, flask_app,
|
||||
):
|
||||
"""No DB-backed message row → no heartbeat call (and no error)."""
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
|
||||
with flask_app.app_context():
|
||||
resource = BaseAnswerResource()
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.gen.return_value = iter([{"answer": "a"}])
|
||||
mock_agent.compression_metadata = None
|
||||
mock_agent.compression_saved = False
|
||||
mock_agent.tool_calls = []
|
||||
mock_agent.tool_executor = None
|
||||
|
||||
resource.conversation_service = MagicMock()
|
||||
resource.conversation_service.save_conversation.return_value = "conv1"
|
||||
# save_user_question returns no message_id → reservation absent
|
||||
resource.conversation_service.save_user_question.return_value = {
|
||||
"conversation_id": "conv1",
|
||||
"message_id": None,
|
||||
"request_id": "req1",
|
||||
}
|
||||
|
||||
list(
|
||||
resource.complete_stream(
|
||||
question="Q",
|
||||
agent=mock_agent,
|
||||
conversation_id="conv1",
|
||||
user_api_key=None,
|
||||
decoded_token={"sub": "u"},
|
||||
should_save_conversation=True,
|
||||
model_id="gpt-4",
|
||||
)
|
||||
)
|
||||
|
||||
resource.conversation_service.heartbeat_message.assert_not_called()
|
||||
|
||||
@@ -1,233 +0,0 @@
|
||||
"""Integration tests for the end-to-end snapshot+tail handoff.
|
||||
|
||||
Exercises the publisher → journal → reconnect endpoint round-trip
|
||||
without mocking the journal layer, so a regression in any of:
|
||||
- complete_stream's _emit closure
|
||||
- record_event's commit-per-call contract
|
||||
- build_message_event_stream's snapshot-from-DB path
|
||||
- the reconnect route's auth + ownership gates
|
||||
- message_events repo SQL
|
||||
|
||||
would surface here as a failed integration assertion.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid as _uuid
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_journal_session(conn):
|
||||
@contextmanager
|
||||
def _yield():
|
||||
yield conn
|
||||
|
||||
with patch(
|
||||
"application.streaming.message_journal.db_session", _yield
|
||||
), patch(
|
||||
"application.streaming.event_replay.db_readonly", _yield
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def _seed_message(conn, user_id: str | None = None):
|
||||
user_id = user_id or f"u-{_uuid.uuid4().hex[:6]}"
|
||||
conv_id = _uuid.uuid4()
|
||||
msg_id = _uuid.uuid4()
|
||||
conn.execute(sql_text("INSERT INTO users (user_id) VALUES (:u)"), {"u": user_id})
|
||||
conn.execute(
|
||||
sql_text(
|
||||
"INSERT INTO conversations (id, user_id, name) VALUES (:id, :u, 't')"
|
||||
),
|
||||
{"id": conv_id, "u": user_id},
|
||||
)
|
||||
conn.execute(
|
||||
sql_text(
|
||||
"INSERT INTO conversation_messages (id, conversation_id, user_id, position) "
|
||||
"VALUES (:id, :c, :u, 0)"
|
||||
),
|
||||
{"id": msg_id, "c": conv_id, "u": user_id},
|
||||
)
|
||||
return user_id, str(msg_id)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestSnapshotPlusTailRoundTrip:
|
||||
def test_record_event_then_snapshot_returns_what_was_journaled(
|
||||
self, pg_conn,
|
||||
):
|
||||
"""End-to-end of the journal half: ``record_event`` writes through
|
||||
a real ``MessageEventsRepository``, ``build_message_event_stream``
|
||||
reads the snapshot back via the same repo + a stub Topic.
|
||||
"""
|
||||
from application.streaming import event_replay
|
||||
from application.streaming.message_journal import record_event
|
||||
|
||||
_, message_id = _seed_message(pg_conn)
|
||||
|
||||
with _patch_journal_session(pg_conn):
|
||||
# Three events stamp the journal.
|
||||
record_event(message_id, 0, "answer", {"type": "answer", "answer": "A"})
|
||||
record_event(message_id, 1, "answer", {"type": "answer", "answer": "B"})
|
||||
record_event(message_id, 2, "end", {"type": "end"})
|
||||
|
||||
# Reconnect path: subscribe yields nothing (Redis-down
|
||||
# branch); the post-loop fallback runs the snapshot read
|
||||
# synchronously and yields the journal contents.
|
||||
def _empty_subscribe(self, on_subscribe=None, poll_timeout=1.0):
|
||||
return
|
||||
yield # pragma: no cover
|
||||
|
||||
with patch.object(
|
||||
event_replay.Topic,
|
||||
"subscribe",
|
||||
_empty_subscribe,
|
||||
create=False,
|
||||
):
|
||||
gen = event_replay.build_message_event_stream(
|
||||
message_id,
|
||||
last_event_id=None,
|
||||
keepalive_seconds=0.05,
|
||||
poll_timeout_seconds=0.01,
|
||||
)
|
||||
out = list(gen)
|
||||
|
||||
# Prelude + 3 snapshot frames.
|
||||
assert out[0] == ": connected\n\n"
|
||||
assert "id: 0" in out[1] and '"answer": "A"' in out[1]
|
||||
assert "id: 1" in out[2] and '"answer": "B"' in out[2]
|
||||
assert "id: 2" in out[3] and '"type": "end"' in out[3]
|
||||
|
||||
def test_snapshot_resumes_past_last_event_id(self, pg_conn):
|
||||
from application.streaming import event_replay
|
||||
from application.streaming.message_journal import record_event
|
||||
|
||||
_, message_id = _seed_message(pg_conn)
|
||||
|
||||
with _patch_journal_session(pg_conn):
|
||||
for seq in range(5):
|
||||
record_event(
|
||||
message_id, seq, "answer", {"type": "answer", "answer": str(seq)}
|
||||
)
|
||||
|
||||
def _empty_subscribe(self, on_subscribe=None, poll_timeout=1.0):
|
||||
return
|
||||
yield # pragma: no cover
|
||||
|
||||
with patch.object(
|
||||
event_replay.Topic,
|
||||
"subscribe",
|
||||
_empty_subscribe,
|
||||
create=False,
|
||||
):
|
||||
# Client says it has seen up through seq=2; expect 3 + 4.
|
||||
out = list(
|
||||
event_replay.build_message_event_stream(
|
||||
message_id,
|
||||
last_event_id=2,
|
||||
keepalive_seconds=0.05,
|
||||
poll_timeout_seconds=0.01,
|
||||
)
|
||||
)
|
||||
|
||||
ids_seen = [line for line in out if line.startswith("id: ")]
|
||||
# Multi-line records: extract the id integers we delivered.
|
||||
emitted = sorted(
|
||||
int(line.split(": ", 1)[1].split("\n")[0])
|
||||
for line in out
|
||||
if line.startswith("id: ")
|
||||
)
|
||||
# Filter for non-negative (the snapshot-failure synthetic uses -1).
|
||||
emitted = [e for e in emitted if e >= 0]
|
||||
assert emitted == [3, 4]
|
||||
assert ids_seen # sanity
|
||||
|
||||
def test_reconnect_route_round_trip(self, pg_conn, flask_app):
|
||||
"""``/api/messages/<id>/events`` returns the journaled events
|
||||
for an authenticated owner.
|
||||
"""
|
||||
from flask import Flask, request
|
||||
|
||||
from application.api.answer.routes.messages import messages_bp
|
||||
from application.streaming.message_journal import record_event
|
||||
|
||||
# Build a fresh Flask app routing to the reconnect blueprint
|
||||
# plus a tiny auth shim that injects the test user.
|
||||
user_id, message_id = _seed_message(pg_conn)
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(messages_bp)
|
||||
app.config["TESTING"] = True
|
||||
|
||||
@app.before_request
|
||||
def _shim_auth():
|
||||
request.decoded_token = {"sub": user_id}
|
||||
|
||||
with _patch_journal_session(pg_conn):
|
||||
record_event(message_id, 0, "answer", {"type": "answer", "answer": "x"})
|
||||
record_event(message_id, 1, "end", {"type": "end"})
|
||||
|
||||
from application.streaming import event_replay
|
||||
|
||||
def _empty_subscribe(self, on_subscribe=None, poll_timeout=1.0):
|
||||
return
|
||||
yield # pragma: no cover
|
||||
|
||||
with patch.object(
|
||||
event_replay.Topic,
|
||||
"subscribe",
|
||||
_empty_subscribe,
|
||||
create=False,
|
||||
), patch(
|
||||
"application.api.answer.routes.messages.db_readonly"
|
||||
) as ro:
|
||||
ro.return_value.__enter__.return_value = pg_conn
|
||||
|
||||
with app.test_client() as c:
|
||||
r = c.get(f"/api/messages/{message_id}/events")
|
||||
assert r.status_code == 200
|
||||
body = b""
|
||||
for chunk in r.iter_encoded():
|
||||
body += chunk
|
||||
if body.count(b"\n\n") >= 4:
|
||||
break
|
||||
r.close()
|
||||
# Both journaled events present in the response.
|
||||
text = body.decode("utf-8")
|
||||
assert ": connected" in text
|
||||
assert '"answer": "x"' in text
|
||||
assert '"type": "end"' in text
|
||||
# The seq lines are correct.
|
||||
assert "id: 0" in text and "id: 1" in text
|
||||
|
||||
def test_reconnect_rejects_non_owner(self, pg_conn, flask_app):
|
||||
from flask import Flask, request
|
||||
|
||||
from application.api.answer.routes.messages import messages_bp
|
||||
|
||||
user_id, message_id = _seed_message(pg_conn)
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(messages_bp)
|
||||
|
||||
@app.before_request
|
||||
def _shim_auth():
|
||||
request.decoded_token = {"sub": "different-user"}
|
||||
|
||||
# Make the ownership check use the test connection.
|
||||
with patch(
|
||||
"application.api.answer.routes.messages.db_readonly"
|
||||
) as ro:
|
||||
from contextlib import contextmanager as _cm
|
||||
|
||||
@_cm
|
||||
def _yield():
|
||||
yield pg_conn
|
||||
|
||||
ro.side_effect = lambda: _yield()
|
||||
with app.test_client() as c:
|
||||
r = c.get(f"/api/messages/{message_id}/events")
|
||||
assert r.status_code == 404
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user