# Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git
# Synced 2026-03-30 01:07:10 +00:00 (106 lines, 3.6 KiB, Python).
"""Diff-based WebSocket output protocol for WhisperLiveKit.
|
|
|
|
Instead of sending the full FrontData state on every update, the DiffTracker
|
|
computes incremental diffs — only sending new/changed lines and volatile fields.
|
|
|
|
Protocol
|
|
--------
|
|
Opt-in via query parameter: ``ws://host:port/asr?mode=diff``
|
|
|
|
First message from server:
|
|
``{"type": "snapshot", "seq": 1, ...full state...}``
|
|
|
|
Subsequent messages:
|
|
``{"type": "diff", "seq": N, "new_lines": [...], ...}``
|
|
|
|
The client reconstructs state by:
|
|
1. On ``"snapshot"``: replace all state.
|
|
2. On ``"diff"``:
|
|
- If ``lines_pruned`` > 0: drop that many lines from the front.
|
|
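- Truncate the remaining lines to ``n_lines - len(new_lines)``, treating a
  missing ``new_lines`` as empty; retained lines past that point changed or
  were removed since the previous message.
- Append ``new_lines`` to the end; the result then has exactly ``n_lines``
  lines.
- Replace ``status``, ``buffer_*`` and ``remaining_time_*`` fields
  (``error`` is included only when set).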
|
|
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, List
|
|
|
|
from whisperlivekit.timed_objects import FrontData
|
|
|
|
|
|
@dataclass
class DiffTracker:
    """Turn successive FrontData states into one snapshot followed by diffs.

    Message shapes produced by :meth:`to_message`:

    * ``{"type": "snapshot", "seq": N, **front_data.to_dict()}`` on the first
      call after construction or :meth:`reset`.
    * ``{"type": "diff", "seq": N, "status", "n_lines", "buffer_*",
      "remaining_time_*"}`` afterwards, plus ``"error"`` when it is truthy,
      ``"lines_pruned"`` when greater than zero and ``"new_lines"`` when
      non-empty.

    To apply a diff, a client drops ``lines_pruned`` lines from the front,
    truncates what is left to ``n_lines - len(new_lines)`` and appends
    ``new_lines``. The lines it keeps are exactly those that compared equal
    (``==``) to the head of the current list, so the rebuilt list matches the
    server's.
    """

    # Sequence number of the most recently produced message (0 = none yet).
    seq: int = 0
    # Copy of the line list sent last time; the baseline for the next diff.
    _prev_lines: List[Dict[str, Any]] = field(default_factory=list)
    # Whether the snapshot has been emitted since construction/reset.
    _sent_snapshot: bool = False

    @staticmethod
    def _find_prune_offset(
        previous: List[Dict[str, Any]], current: List[Dict[str, Any]]
    ) -> int:
        """Return how many leading ``previous`` lines the client should drop.

        The offset is the index of the first ``previous`` line equal to
        ``current[0]``; when no such line exists, every previous line counts
        as pruned. With no current lines everything is pruned; with no
        previous lines nothing is.
        """
        if not current:
            return len(previous)
        if not previous:
            return 0
        head = current[0]
        return next(
            (idx for idx, line in enumerate(previous) if line == head),
            len(previous),
        )

    @staticmethod
    def _common_prefix_len(
        previous: List[Dict[str, Any]], start: int, current: List[Dict[str, Any]]
    ) -> int:
        """Count how many leading ``current`` lines equal ``previous[start:]``."""
        matched = 0
        for old_line, new_line in zip(previous[start:], current):
            if old_line == new_line:
                matched += 1
            else:
                break
        return matched

    def to_message(self, front_data: FrontData) -> Dict[str, Any]:
        """Build the next protocol message for ``front_data``.

        Args:
            front_data: Current state. Only its ``to_dict()`` result is used.
                The diff path indexes ``lines``, ``status``, the three
                ``buffer_*`` keys and both ``remaining_time_*`` keys, so they
                must be present. ``error`` is optional.

        Returns:
            A snapshot dict on the first call after construction or
            :meth:`reset`, otherwise a diff dict (see the class docstring).
            ``seq`` grows by one on every call.
        """
        self.seq += 1
        state = front_data.to_dict()
        lines = state["lines"]

        if not self._sent_snapshot:
            self._sent_snapshot = True
            self._prev_lines = lines[:]
            return {"type": "snapshot", "seq": self.seq, **state}

        diff: Dict[str, Any] = {
            "type": "diff",
            "seq": self.seq,
            "status": state["status"],
            "n_lines": len(lines),
        }
        # Volatile fields are resent on every diff, always in this order.
        for key in (
            "buffer_transcription",
            "buffer_diarization",
            "buffer_translation",
            "remaining_time_transcription",
            "remaining_time_diarization",
        ):
            diff[key] = state[key]
        if state.get("error"):
            diff["error"] = state["error"]

        pruned = self._find_prune_offset(self._prev_lines, lines)
        if pruned > 0:
            diff["lines_pruned"] = pruned

        kept = self._common_prefix_len(self._prev_lines, pruned, lines)
        # Everything past the shared prefix is new or changed; resend it whole.
        changed = lines[kept:]
        if changed:
            diff["new_lines"] = changed

        self._prev_lines = lines[:]
        return diff

    def reset(self) -> None:
        """Forget all tracked state; the next message is a snapshot with ``seq`` 1."""
        self.seq, self._prev_lines, self._sent_snapshot = 0, [], False