DocsGPT/application/streaming/message_journal.py
Alex 249f9f9fe0 perf(streaming): batch message_events INSERTs per stream
complete_stream previously opened a fresh db_session() per yielded
event, doing one Postgres INSERT + commit per chunk on the WSGI
thread. Streaming answers emit hundreds of answer chunks per
response, so the route was paying hundreds of PG roundtrips per
stream, serialized on commit latency.

New BatchedJournalWriter in application/streaming/message_journal.py
accumulates rows per stream and flushes on three triggers:
- size: buffer reaches 16 entries
- time: 100ms elapsed since the last flush
- lifecycle: close() at end-of-stream

Live pubsub publishes still fire synchronously per record(), so
subscribers see events in real time — only the durable journal write
is amortized. On bulk INSERT IntegrityError the writer falls back to
per-row INSERTs with the existing seq+1 retry so a single colliding
seq doesn't drop the rest of the batch.

complete_stream wires journal_writer.close() into every exit path
(happy end, tool-approval-paused end, GeneratorExit, error handler)
so the terminal event is committed before the generator returns —
otherwise a reconnecting client could snapshot up to the last flush
boundary and live-tail waiting for an end that's still in memory.
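
Shape of the wiring (illustrative sketch; only
journal_writer.close() and BatchedJournalWriter are literal names,
the rest is paraphrased):

    def complete_stream(message_id, events):
        journal_writer = BatchedJournalWriter(message_id)
        seq = -1
        try:
            for seq, (event_type, payload) in enumerate(events):
                journal_writer.record(seq, event_type, payload)
                yield payload
            journal_writer.close()  # happy / tool-approval-paused end
        except GeneratorExit:
            journal_writer.close()  # client disconnected mid-stream
            raise
        except Exception:
            journal_writer.record(seq + 1, "error", {"reason": "failed"})
            journal_writer.close()  # terminal event committed first
            raise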

Repository gets bulk_record() — one SQLAlchemy executemany INSERT
for the bulk path. All-or-nothing on collision (Postgres aborts the
whole batch); the writer's per-row fallback handles recovery.
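
Sketch of that repository method (assumes a SQLAlchemy Core Table
for message_events with these column names; the real code lives in
MessageEventsRepository):

    from sqlalchemy import insert

    def bulk_record(self, message_id, rows):
        # A list of parameter dicts drives SQLAlchemy's executemany
        # path: one INSERT statement, one roundtrip, one commit.
        self._conn.execute(
            insert(message_events_table),
            [
                {
                    "message_id": message_id,
                    "sequence_no": seq,
                    "event_type": event_type,
                    "payload": payload,
                }
                for seq, event_type, payload in rows
            ],
        )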
2026-05-12 18:20:19 +01:00


"""Per-yield journal write for the chat-stream snapshot+tail pattern.
``complete_stream`` calls ``record_event`` once per SSE event it
yields. The hook does two things:
1. Insert a row into ``message_events`` (the durable snapshot used by
reconnecting clients reading from a *different* connection).
2. Publish a JSON envelope to ``channel:{message_id}`` so any client
currently subscribed receives the event live.

Both are best-effort: failures are logged and swallowed, never raised
back into the streaming loop. A missed journal write means a client
that reconnects between this event and the next won't see this one in
their snapshot — degraded UX, not corrupted state. A missed publish
means currently-subscribed reconnect viewers miss the live tick;
they'll catch up via the snapshot on their next reconnect (or after
their poll-timeout cycle if they're already attached).

Each ``record_event`` opens its own short-lived ``db_session()`` so
the INSERT commits before the matching publish — without that ordering
a fast-reconnecting client could hit the snapshot read on a separate
connection and miss the row that's still uncommitted on the streaming
connection. ``BatchedJournalWriter`` deliberately relaxes that
ordering on the hot path: it publishes immediately and journals on the
next flush, so the snapshot can lag live delivery by roughly the batch
interval.
"""
from __future__ import annotations
import logging
import time
from typing import Any, Optional
from sqlalchemy.exc import IntegrityError
from application.storage.db.repositories.message_events import (
MessageEventsRepository,
)
from application.storage.db.session import db_readonly, db_session
from application.streaming.broadcast_channel import Topic
from application.streaming.event_replay import encode_pubsub_message
from application.streaming.keys import message_topic_name
logger = logging.getLogger(__name__)
# Tunables for ``BatchedJournalWriter``. A streaming answer emits
# hundreds of ``answer`` chunks per response; without batching, that's
# one PG transaction per yield in the WSGI thread. With these defaults,
# roughly 10x fewer commits, at the cost of a reconnect-visibility lag
# of about the batch interval for any event still sitting in the buffer
# (the timer is only checked on the next ``record()`` or at ``close()``).
DEFAULT_BATCH_SIZE = 16
DEFAULT_BATCH_INTERVAL_MS = 100
def record_event(
message_id: str,
sequence_no: int,
event_type: str,
payload: Optional[dict[str, Any]] = None,
) -> bool:
"""Journal one SSE event and publish it live. Best-effort.
``payload`` must be a ``dict`` (or ``None``). Passing a list,
scalar, or any other shape is a contract violation: the live path
in ``base.py::_emit`` and the replay path in
``event_replay`` previously reconstructed non-dicts differently
(``{"value": payload}`` vs. ``{"type": event_type}``), so a
reconnecting client would receive a different envelope than the
one originally streamed. Rejecting non-dicts at this gate keeps
the two paths byte-identical.
Returns ``True`` when the journal INSERT committed (the publish is
attempted regardless of insert outcome and isn't reflected in the
return value). Never raises — every failure path logs and swallows.
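
Illustrative call, one per yielded SSE event (the payload shape
here is hypothetical; real envelopes come from the streaming
layer)::

    committed = record_event("msg-123", 0, "answer", {"text": "Hi"})
    if not committed:
        # Event will be missing from reconnect snapshots; the
        # stream itself keeps going.
        ...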
"""
if not message_id or not event_type:
logger.warning(
"record_event called without message_id/event_type "
"(message_id=%r, event_type=%r)",
message_id,
event_type,
)
return False
if payload is None:
materialised_payload: dict[str, Any] = {}
elif isinstance(payload, dict):
materialised_payload = payload
else:
logger.warning(
"record_event called with non-dict payload "
"(message_id=%s seq=%s type=%s payload_type=%s) — dropping",
message_id,
sequence_no,
event_type,
type(payload).__name__,
)
return False
journal_committed = False
# The seq we actually managed to write. Diverges from
# ``sequence_no`` only on the IntegrityError-retry path below.
materialised_seq = sequence_no
try:
# Short-lived per-event transaction. Critical for visibility:
# the reconnect endpoint reads the journal from a separate
# connection and only sees committed rows.
with db_session() as conn:
MessageEventsRepository(conn).record(
message_id, sequence_no, event_type, materialised_payload
)
journal_committed = True
except IntegrityError:
# Composite-PK collision on (message_id, sequence_no). Most
# likely cause is a stale ``latest_sequence_no`` seed on a
# continuation retry — the route read MAX(seq) from a separate
# connection before another writer committed past it. Look up
# the live latest and retry once with latest+1 so the event is
# not silently lost. Bounded to a single retry — if two
# writers keep racing in lockstep the route-level retry will
# converge them across attempts.
try:
with db_readonly() as conn:
latest = MessageEventsRepository(conn).latest_sequence_no(
message_id
)
materialised_seq = (latest if latest is not None else -1) + 1
with db_session() as conn:
MessageEventsRepository(conn).record(
message_id,
materialised_seq,
event_type,
materialised_payload,
)
journal_committed = True
logger.info(
"record_event: collision at seq=%s recovered → wrote at "
"seq=%s message_id=%s type=%s",
sequence_no,
materialised_seq,
message_id,
event_type,
)
except IntegrityError:
# Second collision under the same retry — give up and log.
# The route's nonlocal counter will continue at
# ``sequence_no+1`` on the next emit; the next call may
# land cleanly past the contended window.
logger.warning(
"record_event: IntegrityError persists after seq+1 retry; "
"dropping. message_id=%s original_seq=%s retry_seq=%s "
"type=%s",
message_id,
sequence_no,
materialised_seq,
event_type,
)
except Exception:
logger.exception(
"record_event: retry path failed unexpectedly "
"(message_id=%s seq=%s type=%s)",
message_id,
sequence_no,
event_type,
)
except Exception:
logger.exception(
"message_events INSERT failed: message_id=%s seq=%s type=%s",
message_id,
sequence_no,
event_type,
)
try:
# Publish using ``materialised_seq`` so the live pubsub frame
# matches the journal row that other clients will snapshot on
# reconnect. The original POST stream's SSE ``id:`` still
# carries the caller's ``sequence_no`` — a reconnect from that
# client will receive the same event at ``materialised_seq``
# on the snapshot, which is a benign duplicate (the slice's
# ``max_replayed_seq`` advances past it). No-collision case:
# ``materialised_seq == sequence_no`` and this is identical to
# the prior behaviour.
wire = encode_pubsub_message(
message_id, materialised_seq, event_type, materialised_payload
)
Topic(message_topic_name(message_id)).publish(wire)
except Exception:
logger.exception(
"channel:%s publish failed: seq=%s type=%s",
message_id,
materialised_seq,
event_type,
)
return journal_committed
class BatchedJournalWriter:
"""Per-stream journal writer that batches PG INSERTs.
Replaces the per-emit synchronous ``record_event`` call inside the
streaming hot path (``complete_stream``). One writer is created per
``message_id``; each yield calls ``record()``, and the writer flushes
on three independent triggers:
1. **Size** — buffer reaches ``batch_size`` entries.
2. **Time** — ``batch_interval_ms`` elapsed since the last flush.
3. **Lifecycle** — caller invokes ``close()`` at end of stream.
Live pubsub publishes still fire synchronously per ``record()`` call,
so subscribers see events in real time — only the durable journal
write is amortized. The cost is a small reconnect-visibility lag
(about ``batch_interval_ms``) for events still sitting in the buffer.

Flush is best-effort: a failed batch logs and continues. On
``IntegrityError`` (typically a stale-seq seed on a continuation
retry) the writer falls back to per-row writes that reuse
``record_event``'s seq+1 retry, so a single colliding seq doesn't
drop the rest of the batch.
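
Illustrative lifecycle (a sketch; the real call site is
``complete_stream``, which also wires ``close()`` into its
``GeneratorExit`` and error paths)::

    writer = BatchedJournalWriter("msg-123")
    try:
        for seq, (event_type, payload) in enumerate(events):
            writer.record(seq, event_type, payload)
    finally:
        writer.close()  # final flush; idempotent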
"""
def __init__(
self,
message_id: str,
*,
batch_size: int = DEFAULT_BATCH_SIZE,
batch_interval_ms: int = DEFAULT_BATCH_INTERVAL_MS,
) -> None:
self._message_id = message_id
self._batch_size = batch_size
self._batch_interval_ms = batch_interval_ms
self._buffer: list[tuple[int, str, dict[str, Any]]] = []
self._last_flush_mono_ms = time.monotonic() * 1000.0
self._closed = False
def record(
self,
sequence_no: int,
event_type: str,
payload: Optional[dict[str, Any]] = None,
) -> bool:
"""Buffer one event; publish live; maybe flush.
Returns ``True`` if the event was accepted into the buffer.
``False`` rejects come from contract violations (empty
``event_type``, non-dict payload) — same gates as
``record_event``. The event reaches the journal only on a later
flush; callers that need synchronous visibility (e.g. terminal
events written from an abort handler outside the streaming loop)
should call ``flush()`` then use ``record_event`` directly, as
sketched below.
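
For example, from an abort handler (``next_seq`` and the payload
shape are illustrative)::

    writer.flush()  # drain buffered rows first
    record_event(message_id, next_seq, "error", {"reason": "aborted"})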
"""
if self._closed:
logger.warning(
"BatchedJournalWriter.record after close: "
"message_id=%s seq=%s type=%s",
self._message_id,
sequence_no,
event_type,
)
return False
if not event_type:
logger.warning(
"BatchedJournalWriter.record without event_type: "
"message_id=%s seq=%s",
self._message_id,
sequence_no,
)
return False
if payload is None:
materialised: dict[str, Any] = {}
elif isinstance(payload, dict):
materialised = payload
else:
# Same contract as ``record_event`` — non-dict payloads
# are rejected so the live and replay paths can't diverge
# on envelope reconstruction.
logger.warning(
"BatchedJournalWriter.record with non-dict payload: "
"message_id=%s seq=%s type=%s payload_type=%s — dropping",
self._message_id,
sequence_no,
event_type,
type(payload).__name__,
)
return False
self._buffer.append((sequence_no, event_type, materialised))
# Publish live synchronously so subscribers see events in
# real time. Failures are logged and don't block the buffer.
try:
wire = encode_pubsub_message(
self._message_id, sequence_no, event_type, materialised
)
Topic(message_topic_name(self._message_id)).publish(wire)
except Exception:
logger.exception(
"channel:%s publish failed: seq=%s type=%s",
self._message_id,
sequence_no,
event_type,
)
if self._should_flush():
self.flush()
return True
def _should_flush(self) -> bool:
if len(self._buffer) >= self._batch_size:
return True
elapsed_ms = (time.monotonic() * 1000.0) - self._last_flush_mono_ms
return elapsed_ms >= self._batch_interval_ms and len(self._buffer) > 0
def flush(self) -> None:
"""Commit buffered events to PG. Best-effort.
Tries one bulk INSERT first; on ``IntegrityError`` (composite
PK collision — typically a stale continuation seed) falls back
to per-row ``record_event`` so one bad seq doesn't drop the
rest of the batch. Always clears the buffer to bound memory,
even on failure — a journaled event missing from a snapshot
is degraded UX, but a runaway buffer is corruption.
"""
if not self._buffer:
self._last_flush_mono_ms = time.monotonic() * 1000.0
return
# Snapshot and clear before the I/O so a concurrent record()
# call would land in a fresh buffer rather than racing the
# flush. ``complete_stream`` is single-threaded per stream, so
# this is belt-and-suspenders for any future change.
pending = self._buffer
self._buffer = []
self._last_flush_mono_ms = time.monotonic() * 1000.0
try:
with db_session() as conn:
MessageEventsRepository(conn).bulk_record(
self._message_id, pending
)
except IntegrityError:
logger.info(
"BatchedJournalWriter: bulk INSERT collided for "
"message_id=%s n=%d; falling back to per-row writes",
self._message_id,
len(pending),
)
self._flush_per_row(pending)
except Exception:
logger.exception(
"BatchedJournalWriter: bulk INSERT failed for "
"message_id=%s n=%d; events dropped from journal",
self._message_id,
len(pending),
)
def _flush_per_row(
self, pending: list[tuple[int, str, dict[str, Any]]]
) -> None:
"""Per-row fallback after a bulk collision.
Each row goes through ``record_event`` which already handles
``IntegrityError`` with a single seq+1 retry. The live publish
is skipped here — ``record()`` already fired it when the row
first entered the buffer, so re-publishing on the fallback
path would double-deliver.
"""
for seq, event_type, payload in pending:
journal_committed = False
try:
with db_session() as conn:
MessageEventsRepository(conn).record(
self._message_id, seq, event_type, payload
)
journal_committed = True
except IntegrityError:
try:
with db_readonly() as conn:
latest = MessageEventsRepository(
conn
).latest_sequence_no(self._message_id)
retry_seq = (latest if latest is not None else -1) + 1
with db_session() as conn:
MessageEventsRepository(conn).record(
self._message_id, retry_seq, event_type, payload
)
journal_committed = True
except IntegrityError:
logger.warning(
"BatchedJournalWriter: IntegrityError persists "
"after seq+1 retry; dropping. message_id=%s "
"original_seq=%s type=%s",
self._message_id,
seq,
event_type,
)
except Exception:
logger.exception(
"BatchedJournalWriter: per-row retry failed "
"(message_id=%s seq=%s type=%s)",
self._message_id,
seq,
event_type,
)
except Exception:
logger.exception(
"BatchedJournalWriter: per-row INSERT failed "
"(message_id=%s seq=%s type=%s)",
self._message_id,
seq,
event_type,
)
if not journal_committed:
# Stay silent in the loop — the per-row logger above
# already captured the failure. This branch is here
# to make the control flow explicit for readers.
pass
def close(self) -> None:
"""Final flush. Idempotent — safe to call from multiple
finally clauses.
"""
if self._closed:
return
self.flush()
self._closed = True