mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-05-06 15:16:27 +00:00
update benchmark with qwen3 which reuses kv cache
This commit is contained in:
@@ -126,13 +126,13 @@ uv sync --extra cu129 --extra voxtral-hf --extra translation
|
||||
See **Parameters & Configuration** below on how to use them.
|
||||
|
||||
<p align="center">
|
||||
<img src="benchmark_scatter_en_unaware.png" alt="Speed vs Accuracy — English, compute-unaware" width="700">
|
||||
<img src="benchmark_scatter_en_aware.png" alt="Speed vs Accuracy — English" width="700">
|
||||
</p>
|
||||
<p align="center">
|
||||
<img src="benchmark_scatter_en_aware.png" alt="Speed vs Accuracy — English, compute-aware" width="700">
|
||||
<img src="benchmark_scatter_fr_aware.png" alt="Speed vs Accuracy — French" width="700">
|
||||
</p>
|
||||
|
||||
Benchmarks use public audio from [LibriSpeech](https://huggingface.co/datasets/openslr/librispeech_asr) and [Multilingual LibriSpeech](https://huggingface.co/datasets/facebook/multilingual_librispeech) — fully reproducible with `python scripts/run_scatter_benchmark.py`.
|
||||
Benchmarks use 6 minutes of public [LibriVox](https://librivox.org/) audiobook recordings per language (30s + 60s + 120s + 180s), with ground truth from [Project Gutenberg](https://www.gutenberg.org/). Fully reproducible with `python scripts/run_scatter_benchmark.py`.
|
||||
We are actively looking for benchmark results on other hardware (NVIDIA GPUs, different Apple Silicon chips, cloud instances). If you run the benchmarks on your machine, please share your results via an issue or PR!
|
||||
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 89 KiB After Width: | Height: | Size: 95 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 94 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 92 KiB After Width: | Height: | Size: 95 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 91 KiB |
@@ -227,28 +227,33 @@ def generate_scatter(results, system_info, output_path, n_samples, lang="en",
|
||||
fig, ax = plt.subplots(figsize=(12, 7), facecolor="white")
|
||||
ax.set_facecolor("#fafafa")
|
||||
|
||||
# Separate main cluster from outliers (RTF > 1.0)
|
||||
main = [r for r in results if r["rtf"] <= 1.0]
|
||||
slow = [r for r in results if r["rtf"] > 1.0]
|
||||
# Show ALL points on chart (no outlier exclusion)
|
||||
main = results
|
||||
slow = []
|
||||
|
||||
# Axis limits: tight around main data
|
||||
# Axis limits: fit all data
|
||||
if main:
|
||||
xmax = max(r["rtf"] for r in main) * 1.6
|
||||
ymax = max(r["wer_pct"] for r in main) * 1.5 + 1
|
||||
xmax = max(r["rtf"] for r in main) * 1.15
|
||||
ymax = max(r["wer_pct"] for r in main) * 1.15 + 1
|
||||
else:
|
||||
xmax, ymax = 0.5, 10
|
||||
xmax = max(xmax, 0.45)
|
||||
xmax = max(xmax, 1.15) # always show the real-time line
|
||||
ymax = max(ymax, 8)
|
||||
|
||||
# Sweet spot zone
|
||||
sweet_x = xmax * 0.85
|
||||
sweet_y = ymax * 0.55
|
||||
# Sweet spot zone: capped at RTF 1.0 (real-time) and WER 12%, shrunk further if the axes are smaller
|
||||
sweet_x = min(1.0, xmax * 0.85)
|
||||
sweet_y = min(12, ymax * 0.45)
|
||||
rect = plt.Rectangle((0, 0), sweet_x, sweet_y, alpha=0.07, color="#4ecca3",
|
||||
zorder=0, linewidth=0)
|
||||
ax.add_patch(rect)
|
||||
ax.text(sweet_x - 0.005, sweet_y - 0.15, "sweet spot", ha="right", va="top",
|
||||
fontsize=10, color="#2ecc71", fontstyle="italic", fontweight="bold", alpha=0.5)
|
||||
|
||||
# Real-time limit line
|
||||
ax.axvline(x=1.0, color="#e94560", linestyle="--", linewidth=1.5, alpha=0.4, zorder=1)
|
||||
ax.text(1.02, ymax * 0.97, "real-time\nlimit", fontsize=8, color="#e94560",
|
||||
va="top", alpha=0.6)
|
||||
|
||||
# Manual label offsets keyed by label name — hand-tuned
|
||||
OFFSETS = {
|
||||
"fw LA base": (8, 8),
|
||||
|
||||
@@ -34,6 +34,11 @@ logger = logging.getLogger(__name__)
|
||||
# Decoder sliding-window size (matches the model's training configuration).
|
||||
_DECODER_WINDOW = 8192
|
||||
|
||||
# Maximum continuous decoding positions before forcing a reset.
|
||||
# Beyond ~20s of continuous audio the autoregressive context drifts and
|
||||
# produces hallucinations. 20 s / 80 ms per token = 250 tokens.
|
||||
_MAX_CONTINUOUS_POSITIONS = 250
|
||||
|
||||
|
||||
def _prompt_tokens(tokenizer, n_left_pad=LEFT_PAD_TOKENS, n_delay=6):
|
||||
"""Build the prompt token sequence and return ``(token_ids, n_delay)``."""
|
||||
@@ -152,6 +157,7 @@ class VoxtralMLXOnlineProcessor:
|
||||
self._last_token: mx.array | None = None
|
||||
# Bookkeeping
|
||||
self._samples_encoded = 0
|
||||
self._real_samples_encoded = 0 # only real audio, excludes silence padding
|
||||
self._positions_decoded = 0
|
||||
self._prefilled = False
|
||||
self._first_chunk = True
|
||||
@@ -191,6 +197,7 @@ class VoxtralMLXOnlineProcessor:
|
||||
self.end = audio_stream_end_time
|
||||
self._pending_chunks.append(audio)
|
||||
self._pending_len += len(audio)
|
||||
self._real_samples_encoded += len(audio)
|
||||
self.audio_buffer = audio # diagnostic only
|
||||
|
||||
# -- core processing --
|
||||
@@ -203,14 +210,28 @@ class VoxtralMLXOnlineProcessor:
|
||||
return [], self.end
|
||||
|
||||
def _step(self, is_last: bool) -> Tuple[List[ASRToken], float]:
|
||||
# 0. Safety cap: if continuous decoding exceeds the limit, force a
|
||||
# flush+reset to prevent hallucination even without VAD silence.
|
||||
if self._prefilled and self._positions_decoded >= _MAX_CONTINUOUS_POSITIONS + self._prefix_len:
|
||||
logger.info(
|
||||
"[voxtral-mlx] continuous decoding cap hit at %d positions — "
|
||||
"forcing flush+reset",
|
||||
self._positions_decoded,
|
||||
)
|
||||
words = self._flush_and_reset()
|
||||
return words, self.end
|
||||
|
||||
# 1. Encode any new audio
|
||||
self._encode_pending()
|
||||
|
||||
if self._audio_embeds is None:
|
||||
return [], self.end
|
||||
|
||||
# 2. Compute how many positions we can safely decode
|
||||
total_safe = LEFT_PAD_TOKENS + self._samples_encoded // SAMPLES_PER_TOKEN
|
||||
# 2. Compute how many positions we can safely decode.
|
||||
# The safe boundary prevents the decoder from running ahead of the
|
||||
# audio encoder. _real_samples_encoded tracks only real audio (not
|
||||
# silence padding), so positions beyond this produce hallucination.
|
||||
total_safe = LEFT_PAD_TOKENS + self._real_samples_encoded // SAMPLES_PER_TOKEN
|
||||
n_available = self._audio_embeds.shape[0]
|
||||
n_decodable = min(n_available, total_safe - self._positions_decoded)
|
||||
|
||||
@@ -229,11 +250,19 @@ class VoxtralMLXOnlineProcessor:
|
||||
if n_decodable <= 0 or self._audio_embeds is None:
|
||||
return [], self.end
|
||||
|
||||
# Clamp to the continuous decoding cap so we don't overshoot
|
||||
max_left = _MAX_CONTINUOUS_POSITIONS + self._prefix_len - self._positions_decoded
|
||||
if max_left > 0:
|
||||
n_decodable = min(n_decodable, max_left)
|
||||
else:
|
||||
# Will be caught by the cap check on the next call
|
||||
return self._extract_committed_words(), self.end
|
||||
|
||||
# 4. Decode available positions
|
||||
hit_eos = self._decode_positions(n_decodable)
|
||||
|
||||
if hit_eos:
|
||||
# Flush words, reset for next utterance
|
||||
# Flush words, then full reset for next utterance
|
||||
words = self._flush_all_words()
|
||||
logger.debug(
|
||||
"[voxtral-mlx] EOS hit during stream: flushed %d words, "
|
||||
@@ -242,9 +271,12 @@ class VoxtralMLXOnlineProcessor:
|
||||
self._samples_encoded / self.SAMPLING_RATE,
|
||||
self._full_text[-60:] if self._full_text else "",
|
||||
)
|
||||
saved_offset = self._time_offset
|
||||
new_offset = self._time_offset + self._real_samples_encoded / self.SAMPLING_RATE
|
||||
saved_end = self.end
|
||||
self._reset_state()
|
||||
self._time_offset = saved_offset
|
||||
self._time_offset = new_offset
|
||||
self.end = saved_end
|
||||
mx.clear_cache()
|
||||
return words, self.end
|
||||
|
||||
# 5. Extract committed words (all but the last, which may still grow)
|
||||
@@ -451,12 +483,66 @@ class VoxtralMLXOnlineProcessor:
|
||||
return Transcript(start=self.end, end=self.end, text=" ".join(remaining))
|
||||
return Transcript(start=None, end=None, text="")
|
||||
|
||||
def start_silence(self) -> Tuple[List[ASRToken], float]:
|
||||
"""Flush all pending words when silence starts.
|
||||
def _safe_decode_remaining(self):
|
||||
"""Decode remaining audio embeddings, respecting the safe boundary.
|
||||
|
||||
Adds right-padding silence and forces a full decode pass so the
|
||||
decoder emits tokens for the last words of speech. Without this,
|
||||
the model holds back the final tokens waiting for future context.
|
||||
Uses the same guard as ``_step`` to avoid decoding positions that
|
||||
are beyond the real audio frontier, which causes hallucination.
|
||||
"""
|
||||
if self._audio_embeds is None or not self._prefilled:
|
||||
return
|
||||
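# Same shape as the _step() boundary, but computed from _samples_encoded where _step() uses _real_samples_encoded — NOTE(review): confirm padding positions are really excluded here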
|
||||
total_safe = LEFT_PAD_TOKENS + self._samples_encoded // SAMPLES_PER_TOKEN
|
||||
n_available = self._audio_embeds.shape[0]
|
||||
n_decodable = min(n_available, max(0, total_safe - self._positions_decoded))
|
||||
# Cap at RIGHT_PAD_TOKENS to only decode the padding needed for
|
||||
# the model to emit final tokens, not all accumulated padding
|
||||
n_decodable = min(n_decodable, RIGHT_PAD_TOKENS)
|
||||
if n_decodable > 0:
|
||||
self._decode_positions(n_decodable)
|
||||
|
||||
def _flush_last_token_text(self):
|
||||
"""Add the last pending token's text (if not EOS) to _full_text."""
|
||||
if self._last_token is None:
|
||||
return
|
||||
tid = self._last_token.item()
|
||||
if tid == self._eos_id:
|
||||
return
|
||||
text = self._tokenizer.decode(
|
||||
[tid], special_token_policy=SpecialTokenPolicy.IGNORE
|
||||
)
|
||||
if not text:
|
||||
return
|
||||
last_pos = self._positions_decoded - self._prefix_len
|
||||
if text.lstrip() != text or not self._full_text:
|
||||
if self._current_word_pos is not None:
|
||||
self._word_audio_ends.append(last_pos)
|
||||
self._word_audio_starts.append(last_pos)
|
||||
self._current_word_pos = last_pos
|
||||
elif self._current_word_pos is None:
|
||||
self._word_audio_starts.append(last_pos)
|
||||
self._current_word_pos = last_pos
|
||||
self._full_text += text
|
||||
self._n_text_tokens += 1
|
||||
|
||||
def _close_current_word(self):
|
||||
"""Close the last word if one is being built."""
|
||||
if self._current_word_pos is not None:
|
||||
last_pos = self._positions_decoded - self._prefix_len
|
||||
self._word_audio_ends.append(last_pos)
|
||||
self._current_word_pos = None
|
||||
|
||||
def _flush_and_reset(self) -> List[ASRToken]:
|
||||
"""Flush pending audio, decode remaining, extract all words, then
|
||||
fully reset both encoder and decoder state.
|
||||
|
||||
Used at silence boundaries and when the continuous decoding cap is
|
||||
hit. A full reset (encoder + decoder) is necessary because the
|
||||
encoder's incremental state (conv tails, KV caches) contains history
|
||||
that would produce embeddings incompatible with a freshly-initialised
|
||||
decoder. After reset ``_first_chunk=True``, so the next audio chunk
|
||||
receives proper left-padding and both encoder and decoder start in
|
||||
sync.
|
||||
"""
|
||||
# Align pending audio to SAMPLES_PER_TOKEN boundary
|
||||
remainder = self._pending_len % SAMPLES_PER_TOKEN
|
||||
@@ -471,37 +557,40 @@ class VoxtralMLXOnlineProcessor:
|
||||
# Encode remaining audio (including right-padding)
|
||||
self._encode_pending()
|
||||
|
||||
# Decode everything that's left
|
||||
if self._audio_embeds is not None and self._prefilled:
|
||||
self._decode_positions(self._audio_embeds.shape[0])
|
||||
# Decode only positions backed by real audio
|
||||
self._safe_decode_remaining()
|
||||
|
||||
# Flush last token if it wasn't EOS
|
||||
if self._last_token is not None:
|
||||
tid = self._last_token.item()
|
||||
if tid != self._eos_id:
|
||||
text = self._tokenizer.decode(
|
||||
[tid], special_token_policy=SpecialTokenPolicy.IGNORE
|
||||
)
|
||||
if text:
|
||||
last_pos = self._positions_decoded - self._prefix_len
|
||||
if text.lstrip() != text or not self._full_text:
|
||||
if self._current_word_pos is not None:
|
||||
self._word_audio_ends.append(last_pos)
|
||||
self._word_audio_starts.append(last_pos)
|
||||
self._current_word_pos = last_pos
|
||||
elif self._current_word_pos is None:
|
||||
self._word_audio_starts.append(last_pos)
|
||||
self._current_word_pos = last_pos
|
||||
self._full_text += text
|
||||
self._n_text_tokens += 1
|
||||
|
||||
# Close the last word if still open
|
||||
if self._current_word_pos is not None:
|
||||
last_pos = self._positions_decoded - self._prefix_len
|
||||
self._word_audio_ends.append(last_pos)
|
||||
self._current_word_pos = None
|
||||
self._flush_last_token_text()
|
||||
self._close_current_word()
|
||||
|
||||
words = self._flush_all_words()
|
||||
|
||||
# Compute time offset: the decoded audio covers up to this point
|
||||
new_offset = self._time_offset + self._real_samples_encoded / self.SAMPLING_RATE
|
||||
saved_end = self.end
|
||||
|
||||
# Full reset — encoder AND decoder. The encoder's incremental
|
||||
# state (conv tails, transformer KV caches) carries history from
|
||||
# the previous segment; keeping it would make the next set of
|
||||
# embeddings incompatible with a fresh decoder prefill.
|
||||
self._reset_state()
|
||||
self._time_offset = new_offset
|
||||
self.end = saved_end
|
||||
|
||||
# Free MLX caches eagerly
|
||||
mx.clear_cache()
|
||||
|
||||
return words
|
||||
|
||||
def start_silence(self) -> Tuple[List[ASRToken], float]:
|
||||
"""Flush all pending words when silence starts, then fully reset.
|
||||
|
||||
Adds right-padding silence and forces a decode pass so the
|
||||
decoder emits tokens for the last words of speech. After flushing,
|
||||
resets both encoder and decoder state to prevent hallucination from
|
||||
accumulated autoregressive context drift on long audio.
|
||||
"""
|
||||
words = self._flush_and_reset()
|
||||
logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words))
|
||||
return words, self.end
|
||||
|
||||
@@ -529,10 +618,7 @@ class VoxtralMLXOnlineProcessor:
|
||||
|
||||
# Align pending audio to SAMPLES_PER_TOKEN boundary so nothing is lost
|
||||
remainder = self._pending_len % SAMPLES_PER_TOKEN
|
||||
if remainder > 0:
|
||||
align_pad = SAMPLES_PER_TOKEN - remainder
|
||||
else:
|
||||
align_pad = 0
|
||||
align_pad = (SAMPLES_PER_TOKEN - remainder) if remainder > 0 else 0
|
||||
|
||||
# Add alignment + right-padding silence
|
||||
total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
|
||||
@@ -543,48 +629,11 @@ class VoxtralMLXOnlineProcessor:
|
||||
# Encode remaining audio (including right-padding)
|
||||
self._encode_pending()
|
||||
|
||||
logger.debug(
|
||||
"[voxtral-mlx] finish after encode: audio_embeds=%s, pending=%d",
|
||||
self._audio_embeds.shape if self._audio_embeds is not None else None,
|
||||
self._pending_len,
|
||||
)
|
||||
# Decode only positions backed by real audio
|
||||
self._safe_decode_remaining()
|
||||
|
||||
hit_eos = False
|
||||
|
||||
# Decode everything that's left from right-padding
|
||||
if self._audio_embeds is not None and self._prefilled:
|
||||
hit_eos = self._decode_positions(self._audio_embeds.shape[0])
|
||||
logger.debug(
|
||||
"[voxtral-mlx] finish decode: hit_eos=%s, text='%s'",
|
||||
hit_eos, self._full_text[-80:] if self._full_text else "",
|
||||
)
|
||||
|
||||
# Flush last token if it wasn't EOS
|
||||
if self._last_token is not None:
|
||||
tid = self._last_token.item()
|
||||
if tid != self._eos_id:
|
||||
text = self._tokenizer.decode(
|
||||
[tid], special_token_policy=SpecialTokenPolicy.IGNORE
|
||||
)
|
||||
if text:
|
||||
last_pos = self._positions_decoded - self._prefix_len
|
||||
# Check if this starts a new word
|
||||
if text.lstrip() != text or not self._full_text:
|
||||
if self._current_word_pos is not None:
|
||||
self._word_audio_ends.append(last_pos)
|
||||
self._word_audio_starts.append(last_pos)
|
||||
self._current_word_pos = last_pos
|
||||
elif self._current_word_pos is None:
|
||||
self._word_audio_starts.append(last_pos)
|
||||
self._current_word_pos = last_pos
|
||||
self._full_text += text
|
||||
self._n_text_tokens += 1
|
||||
|
||||
# Close the last word if still open
|
||||
if self._current_word_pos is not None:
|
||||
last_pos = self._positions_decoded - self._prefix_len
|
||||
self._word_audio_ends.append(last_pos)
|
||||
self._current_word_pos = None
|
||||
self._flush_last_token_text()
|
||||
self._close_current_word()
|
||||
|
||||
words = self._flush_all_words()
|
||||
logger.info("[voxtral-mlx] finish: flushed %d words", len(words))
|
||||
|
||||
Reference in New Issue
Block a user