mirror of https://github.com/arc53/DocsGPT.git, synced 2026-05-17 02:25:38 +00:00
complete_stream previously opened a fresh db_session() per yielded event, doing one Postgres INSERT + commit per chunk on the WSGI thread. Streaming answers emit ~100s of answer chunks per response, so the route was paying ~100 PG roundtrips per stream, serialized on commit latency.

New BatchedJournalWriter in application/streaming/message_journal.py accumulates rows per stream and flushes on three triggers:

- size: buffer reaches 16 entries
- time: 100ms elapsed since the last flush
- lifecycle: close() at end-of-stream

Live pubsub publishes still fire synchronously per record(), so subscribers see events in real time — only the durable journal write is amortized. On bulk INSERT IntegrityError the writer falls back to per-row record() with the existing seq+1 retry, so a single colliding seq doesn't drop the rest of the batch.

complete_stream wires journal_writer.close() into every exit path (happy end, tool-approval-paused end, GeneratorExit, error handler) so the terminal event is committed before the generator returns — otherwise a reconnecting client could snapshot up to the last flush boundary and live-tail waiting for an end that's still in memory. A sketch of that wiring, and of the bulk path, follows below.

Repository gets bulk_record() — one SQLAlchemy executemany INSERT for the bulk path. All-or-nothing on collision (Postgres aborts the whole batch); the writer's per-row fallback handles recovery.
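The exit-path wiring is the easy part to get wrong in a generator, so here is a minimal sketch of the shape it takes. Only BatchedJournalWriter and its record()/close() contract come from this change; the signature, the event source, and the format_sse helper are illustrative.

def complete_stream(message_id, events):  # illustrative signature
    journal_writer = BatchedJournalWriter(message_id)
    seq = 0
    try:
        for event_type, payload in events:  # ~100s of answer chunks
            # Live publish fires now; the journal INSERT is batched.
            journal_writer.record(seq, event_type, payload)
            seq += 1
            yield format_sse(event_type, payload)  # hypothetical SSE helper
    finally:
        # Covers every exit path: happy end, tool-approval pause,
        # GeneratorExit on client disconnect, and the error handler.
        # The terminal event is flushed before the generator returns,
        # and close() is idempotent, so extra calls are safe.
        journal_writer.close()

The repository-side bulk_record() is not in this file. A minimal sketch of the shape the commit message describes, assuming a MessageEvent ORM model and a self._session handle (the real names in message_events.py may differ):

from sqlalchemy import insert

def bulk_record(self, message_id, rows):
    """rows: a list of (sequence_no, event_type, payload) tuples."""
    # Passing a list of parameter dicts to execute() takes SQLAlchemy's
    # executemany path: one INSERT statement for the whole batch.
    # Postgres aborts the entire batch on any (message_id, sequence_no)
    # collision, which is why the writer keeps a per-row fallback.
    self._session.execute(
        insert(MessageEvent),  # assumed ORM model for message_events
        [
            {
                "message_id": message_id,
                "sequence_no": seq,
                "event_type": event_type,
                "payload": payload,
            }
            for seq, event_type, payload in rows
        ],
    )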
427 lines · 17 KiB · Python
"""Per-yield journal write for the chat-stream snapshot+tail pattern.
|
|
|
|
``complete_stream`` calls ``record_event`` once per SSE event it
|
|
yields. The hook does two things:
|
|
|
|
1. Insert a row into ``message_events`` (the durable snapshot used by
|
|
reconnecting clients reading from a *different* connection).
|
|
2. Publish a JSON envelope to ``channel:{message_id}`` so any client
|
|
currently subscribed receives the event live.
|
|
|
|
Both are best-effort: failures are logged and swallowed, never raised
|
|
back into the streaming loop. A missed journal write means a client
|
|
that reconnects between this event and the next won't see this one in
|
|
their snapshot — degraded UX, not corrupted state. A missed publish
|
|
means currently-subscribed reconnect viewers miss the live tick;
|
|
they'll catch up via the snapshot on their next reconnect (or after
|
|
their poll-timeout cycle if they're already attached).
|
|
|
|
Each ``record_event`` opens its own short-lived ``db_session()`` so
|
|
the INSERT commits before the matching publish — without that ordering
|
|
a fast-reconnecting client could hit the snapshot read on a separate
|
|
connection and miss the row that's still uncommitted on the streaming
|
|
connection.
|
|
"""
from __future__ import annotations

import logging
import time
from typing import Any, Optional

from sqlalchemy.exc import IntegrityError

from application.storage.db.repositories.message_events import (
    MessageEventsRepository,
)
from application.storage.db.session import db_readonly, db_session
from application.streaming.broadcast_channel import Topic
from application.streaming.event_replay import encode_pubsub_message
from application.streaming.keys import message_topic_name

logger = logging.getLogger(__name__)


# Tunables for ``BatchedJournalWriter``. A streaming answer emits ~100s
# of ``answer`` chunks per response; without batching, that's one PG
# transaction per yield in the WSGI thread. With these defaults, ~10x
# fewer commits at the cost of a ≤100ms reconnect-visibility lag for
# any event still sitting in the buffer.
DEFAULT_BATCH_SIZE = 16
DEFAULT_BATCH_INTERVAL_MS = 100


def record_event(
    message_id: str,
    sequence_no: int,
    event_type: str,
    payload: Optional[dict[str, Any]] = None,
) -> bool:
    """Journal one SSE event and publish it live. Best-effort.

    ``payload`` must be a ``dict`` (or ``None``). Passing a list,
    scalar, or any other shape is a contract violation: the live path
    in ``base.py::_emit`` and the replay path in
    ``event_replay`` previously reconstructed non-dicts differently
    (``{"value": payload}`` vs. ``{"type": event_type}``), so a
    reconnecting client would receive a different envelope than the
    one originally streamed. Rejecting non-dicts at this gate keeps
    the two paths byte-identical.

    Returns ``True`` when the journal INSERT committed (the publish is
    attempted regardless of insert outcome and isn't reflected in the
    return value). Never raises — every failure path logs and swallows.
    """
    if not message_id or not event_type:
        logger.warning(
            "record_event called without message_id/event_type "
            "(message_id=%r, event_type=%r)",
            message_id,
            event_type,
        )
        return False

    if payload is None:
        materialised_payload: dict[str, Any] = {}
    elif isinstance(payload, dict):
        materialised_payload = payload
    else:
        logger.warning(
            "record_event called with non-dict payload "
            "(message_id=%s seq=%s type=%s payload_type=%s) — dropping",
            message_id,
            sequence_no,
            event_type,
            type(payload).__name__,
        )
        return False

    journal_committed = False
    # The seq we actually managed to write. Diverges from
    # ``sequence_no`` only on the IntegrityError-retry path below.
    materialised_seq = sequence_no
    try:
        # Short-lived per-event transaction. Critical for visibility:
        # the reconnect endpoint reads the journal from a separate
        # connection and only sees committed rows.
        with db_session() as conn:
            MessageEventsRepository(conn).record(
                message_id, sequence_no, event_type, materialised_payload
            )
        journal_committed = True
    except IntegrityError:
        # Composite-PK collision on (message_id, sequence_no). Most
        # likely cause is a stale ``latest_sequence_no`` seed on a
        # continuation retry — the route read MAX(seq) from a separate
        # connection before another writer committed past it. Look up
        # the live latest and retry once with latest+1 so the event is
        # not silently lost. Bounded to a single retry — if two
        # writers keep racing in lockstep the route-level retry will
        # converge them across attempts.
        try:
            with db_readonly() as conn:
                latest = MessageEventsRepository(conn).latest_sequence_no(
                    message_id
                )
            materialised_seq = (latest if latest is not None else -1) + 1
            with db_session() as conn:
                MessageEventsRepository(conn).record(
                    message_id,
                    materialised_seq,
                    event_type,
                    materialised_payload,
                )
            journal_committed = True
            logger.info(
                "record_event: collision at seq=%s recovered → wrote at "
                "seq=%s message_id=%s type=%s",
                sequence_no,
                materialised_seq,
                message_id,
                event_type,
            )
        except IntegrityError:
            # Second collision under the same retry — give up and log.
            # The route's nonlocal counter will continue at
            # ``sequence_no+1`` on the next emit; the next call may
            # land cleanly past the contended window.
            logger.warning(
                "record_event: IntegrityError persists after seq+1 retry; "
                "dropping. message_id=%s original_seq=%s retry_seq=%s "
                "type=%s",
                message_id,
                sequence_no,
                materialised_seq,
                event_type,
            )
        except Exception:
            logger.exception(
                "record_event: retry path failed unexpectedly "
                "(message_id=%s seq=%s type=%s)",
                message_id,
                sequence_no,
                event_type,
            )
    except Exception:
        logger.exception(
            "message_events INSERT failed: message_id=%s seq=%s type=%s",
            message_id,
            sequence_no,
            event_type,
        )

    try:
        # Publish using ``materialised_seq`` so the live pubsub frame
        # matches the journal row that other clients will snapshot on
        # reconnect. The original POST stream's SSE ``id:`` still
        # carries the caller's ``sequence_no`` — a reconnect from that
        # client will receive the same event at ``materialised_seq``
        # on the snapshot, which is a benign duplicate (the slice's
        # ``max_replayed_seq`` advances past it). No-collision case:
        # ``materialised_seq == sequence_no`` and this is identical to
        # the prior behaviour.
        wire = encode_pubsub_message(
            message_id, materialised_seq, event_type, materialised_payload
        )
        Topic(message_topic_name(message_id)).publish(wire)
    except Exception:
        logger.exception(
            "channel:%s publish failed: seq=%s type=%s",
            message_id,
            materialised_seq,
            event_type,
        )

    return journal_committed

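# A quick contract illustration for ``record_event`` (hypothetical
# values; not executed):
#
#   record_event("m1", 0, "answer", {"text": "hi"})  # True: journaled + published
#   record_event("m1", 1, "answer", ["not a dict"])  # False: rejected at the gate
#   record_event("", 2, "answer", {})                # False: missing message_id

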
class BatchedJournalWriter:
    """Per-stream journal writer that batches PG INSERTs.

    Replaces the per-emit synchronous ``record_event`` call inside the
    streaming hot path (``complete_stream``). One writer is created per
    ``message_id``; each yield calls ``record()``, and the writer flushes
    on three independent triggers:

    1. **Size** — buffer reaches ``batch_size`` entries.
    2. **Time** — ``batch_interval_ms`` elapsed since the last flush.
    3. **Lifecycle** — caller invokes ``close()`` at end of stream.

    Live pubsub publishes still fire synchronously per ``record()`` call,
    so subscribers see events in real time — only the durable journal
    write is amortized. The cost is a small reconnect-visibility lag
    (≤ ``batch_interval_ms``) for events still sitting in the buffer.

    Flush is best-effort: a failed batch logs and continues. On
    ``IntegrityError`` (typically a stale-seq seed on a continuation
    retry) the writer falls back to per-row writes so a single
    colliding seq doesn't drop the rest of the batch.
    """

    def __init__(
        self,
        message_id: str,
        *,
        batch_size: int = DEFAULT_BATCH_SIZE,
        batch_interval_ms: int = DEFAULT_BATCH_INTERVAL_MS,
    ) -> None:
        self._message_id = message_id
        self._batch_size = batch_size
        self._batch_interval_ms = batch_interval_ms
        self._buffer: list[tuple[int, str, dict[str, Any]]] = []
        self._last_flush_mono_ms = time.monotonic() * 1000.0
        self._closed = False

    def record(
        self,
        sequence_no: int,
        event_type: str,
        payload: Optional[dict[str, Any]] = None,
    ) -> bool:
        """Buffer one event; publish live; maybe flush.

        Returns ``True`` if the event was accepted into the buffer.
        ``False`` rejects come from contract violations (empty
        ``event_type``, non-dict payload) — same gates as
        ``record_event``. The event reaches the journal asynchronously
        on the next flush; callers that need synchronous visibility
        (e.g. terminal events written from an abort handler outside
        the streaming loop) should call ``flush()`` then use
        ``record_event`` directly.
        """
        if self._closed:
            logger.warning(
                "BatchedJournalWriter.record after close: "
                "message_id=%s seq=%s type=%s",
                self._message_id,
                sequence_no,
                event_type,
            )
            return False
        if not event_type:
            logger.warning(
                "BatchedJournalWriter.record without event_type: "
                "message_id=%s seq=%s",
                self._message_id,
                sequence_no,
            )
            return False
        if payload is None:
            materialised: dict[str, Any] = {}
        elif isinstance(payload, dict):
            materialised = payload
        else:
            # Same contract as ``record_event`` — non-dict payloads
            # are rejected so the live and replay paths can't diverge
            # on envelope reconstruction.
            logger.warning(
                "BatchedJournalWriter.record with non-dict payload: "
                "message_id=%s seq=%s type=%s payload_type=%s — dropping",
                self._message_id,
                sequence_no,
                event_type,
                type(payload).__name__,
            )
            return False

        self._buffer.append((sequence_no, event_type, materialised))

        # Publish live synchronously so subscribers see events in
        # real time. Failures are logged and don't block the buffer.
        try:
            wire = encode_pubsub_message(
                self._message_id, sequence_no, event_type, materialised
            )
            Topic(message_topic_name(self._message_id)).publish(wire)
        except Exception:
            logger.exception(
                "channel:%s publish failed: seq=%s type=%s",
                self._message_id,
                sequence_no,
                event_type,
            )

        if self._should_flush():
            self.flush()
        return True

    def _should_flush(self) -> bool:
        if len(self._buffer) >= self._batch_size:
            return True
        elapsed_ms = (time.monotonic() * 1000.0) - self._last_flush_mono_ms
        return elapsed_ms >= self._batch_interval_ms and len(self._buffer) > 0

    def flush(self) -> None:
        """Commit buffered events to PG. Best-effort.

        Tries one bulk INSERT first; on ``IntegrityError`` (composite
        PK collision — typically a stale continuation seed) falls back
        to per-row writes so one bad seq doesn't drop the rest of the
        batch. Always clears the buffer to bound memory, even on
        failure — an event missing from the journal snapshot is
        degraded UX, but a runaway buffer is corruption.
        """
        if not self._buffer:
            self._last_flush_mono_ms = time.monotonic() * 1000.0
            return

        # Snapshot and clear before the I/O so a concurrent record()
        # call would land in a fresh buffer rather than racing the
        # flush. ``complete_stream`` is single-threaded per stream, so
        # this is belt-and-suspenders for any future change.
        pending = self._buffer
        self._buffer = []
        self._last_flush_mono_ms = time.monotonic() * 1000.0

        try:
            with db_session() as conn:
                MessageEventsRepository(conn).bulk_record(
                    self._message_id, pending
                )
        except IntegrityError:
            logger.info(
                "BatchedJournalWriter: bulk INSERT collided for "
                "message_id=%s n=%d; falling back to per-row writes",
                self._message_id,
                len(pending),
            )
            self._flush_per_row(pending)
        except Exception:
            logger.exception(
                "BatchedJournalWriter: bulk INSERT failed for "
                "message_id=%s n=%d; events dropped from journal",
                self._message_id,
                len(pending),
            )

    def _flush_per_row(
        self, pending: list[tuple[int, str, dict[str, Any]]]
    ) -> None:
        """Per-row fallback after a bulk collision.

        Each row gets the same single seq+1 retry on ``IntegrityError``
        that ``record_event`` applies, inlined here. The live publish
        is skipped — ``record()`` already fired it when the row first
        entered the buffer, so re-publishing on the fallback path
        would double-deliver.
        """
        for seq, event_type, payload in pending:
            journal_committed = False
            try:
                with db_session() as conn:
                    MessageEventsRepository(conn).record(
                        self._message_id, seq, event_type, payload
                    )
                journal_committed = True
            except IntegrityError:
                try:
                    with db_readonly() as conn:
                        latest = MessageEventsRepository(
                            conn
                        ).latest_sequence_no(self._message_id)
                    retry_seq = (latest if latest is not None else -1) + 1
                    with db_session() as conn:
                        MessageEventsRepository(conn).record(
                            self._message_id, retry_seq, event_type, payload
                        )
                    journal_committed = True
                except IntegrityError:
                    logger.warning(
                        "BatchedJournalWriter: IntegrityError persists "
                        "after seq+1 retry; dropping. message_id=%s "
                        "original_seq=%s type=%s",
                        self._message_id,
                        seq,
                        event_type,
                    )
                except Exception:
                    logger.exception(
                        "BatchedJournalWriter: per-row retry failed "
                        "(message_id=%s seq=%s type=%s)",
                        self._message_id,
                        seq,
                        event_type,
                    )
            except Exception:
                logger.exception(
                    "BatchedJournalWriter: per-row INSERT failed "
                    "(message_id=%s seq=%s type=%s)",
                    self._message_id,
                    seq,
                    event_type,
                )
            if not journal_committed:
                # Stay silent in the loop — the per-row logger above
                # already captured the failure. This branch is here
                # to make the control flow explicit for readers.
                pass

    def close(self) -> None:
        """Final flush. Idempotent — safe to call from multiple
        finally clauses.
        """
        if self._closed:
            return
        self.flush()
        self._closed = True
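
A pytest-style sketch of the flush triggers in isolation, mocking the DB session, repository, and pubsub topic. The module path and patch targets assume the file above; the test name and values are illustrative. encode_pubsub_message runs unmocked, but record() swallows publish failures, so the assertions hold either way.

from unittest.mock import MagicMock, patch

import application.streaming.message_journal as mj


def test_flush_triggers():
    repo = MagicMock()
    with patch.object(mj, "db_session", MagicMock()), patch.object(
        mj, "MessageEventsRepository", return_value=repo
    ), patch.object(mj, "Topic"):
        # A long interval keeps the time trigger out of the way here;
        # it would fire the same flush once batch_interval_ms elapsed.
        w = mj.BatchedJournalWriter("m1", batch_size=2, batch_interval_ms=60_000)
        w.record(0, "answer", {"text": "a"})  # buffered, no flush yet
        w.record(1, "answer", {"text": "b"})  # size trigger: first bulk write
        assert repo.bulk_record.call_count == 1
        w.record(2, "answer", {"text": "c"})  # buffered again
        w.close()  # lifecycle trigger: final flush before end-of-stream
        assert repo.bulk_record.call_count == 2
        assert w.record(3, "end", {}) is False  # writer refuses after close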