Update benchmark with Qwen3, which reuses the KV cache

This commit is contained in:
Quentin Fuxa
2026-03-15 22:32:01 +01:00
parent dd48997674
commit a6a85431f6
7 changed files with 150 additions and 96 deletions

View File

@@ -126,13 +126,13 @@ uv sync --extra cu129 --extra voxtral-hf --extra translation
See **Parameters & Configuration** below on how to use them.
<p align="center">
<img src="benchmark_scatter_en_unaware.png" alt="Speed vs Accuracy — English, compute-unaware" width="700">
<img src="benchmark_scatter_en_aware.png" alt="Speed vs Accuracy — English" width="700">
</p>
<p align="center">
<img src="benchmark_scatter_en_aware.png" alt="Speed vs Accuracy — English, compute-aware" width="700">
<img src="benchmark_scatter_fr_aware.png" alt="Speed vs Accuracy — French" width="700">
</p>
Benchmarks use public audio from [LibriSpeech](https://huggingface.co/datasets/openslr/librispeech_asr) and [Multilingual LibriSpeech](https://huggingface.co/datasets/facebook/multilingual_librispeech) — fully reproducible with `python scripts/run_scatter_benchmark.py`.
Benchmarks use 6 minutes of public [LibriVox](https://librivox.org/) audiobook recordings per language (30s + 60s + 120s + 180s), with ground truth from [Project Gutenberg](https://www.gutenberg.org/). Fully reproducible with `python scripts/run_scatter_benchmark.py`.
We are actively looking for benchmark results on other hardware (NVIDIA GPUs, different Apple Silicon chips, cloud instances). If you run the benchmarks on your machine, please share your results via an issue or PR!

Binary file not shown.

Before

Width:  |  Height:  |  Size: 89 KiB

After

Width:  |  Height:  |  Size: 95 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 94 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 92 KiB

After

Width:  |  Height:  |  Size: 95 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 91 KiB

View File

@@ -227,28 +227,33 @@ def generate_scatter(results, system_info, output_path, n_samples, lang="en",
fig, ax = plt.subplots(figsize=(12, 7), facecolor="white")
ax.set_facecolor("#fafafa")
# Separate main cluster from outliers (RTF > 1.0)
main = [r for r in results if r["rtf"] <= 1.0]
slow = [r for r in results if r["rtf"] > 1.0]
# Show ALL points on chart (no outlier exclusion)
main = results
slow = []
# Axis limits: tight around main data
# Axis limits: fit all data
if main:
xmax = max(r["rtf"] for r in main) * 1.6
ymax = max(r["wer_pct"] for r in main) * 1.5 + 1
xmax = max(r["rtf"] for r in main) * 1.15
ymax = max(r["wer_pct"] for r in main) * 1.15 + 1
else:
xmax, ymax = 0.5, 10
xmax = max(xmax, 0.45)
xmax = max(xmax, 1.15) # always show the real-time line
ymax = max(ymax, 8)
# Sweet spot zone
sweet_x = xmax * 0.85
sweet_y = ymax * 0.55
# Sweet spot zone: RTF < 1.0 (real-time) and WER < 12%
sweet_x = min(1.0, xmax * 0.85)
sweet_y = min(12, ymax * 0.45)
rect = plt.Rectangle((0, 0), sweet_x, sweet_y, alpha=0.07, color="#4ecca3",
zorder=0, linewidth=0)
ax.add_patch(rect)
ax.text(sweet_x - 0.005, sweet_y - 0.15, "sweet spot", ha="right", va="top",
fontsize=10, color="#2ecc71", fontstyle="italic", fontweight="bold", alpha=0.5)
# Real-time limit line
ax.axvline(x=1.0, color="#e94560", linestyle="--", linewidth=1.5, alpha=0.4, zorder=1)
ax.text(1.02, ymax * 0.97, "real-time\nlimit", fontsize=8, color="#e94560",
va="top", alpha=0.6)
# Manual label offsets keyed by label name — hand-tuned
OFFSETS = {
"fw LA base": (8, 8),

View File

@@ -34,6 +34,11 @@ logger = logging.getLogger(__name__)
# Decoder sliding-window size (matches the model's training configuration).
_DECODER_WINDOW = 8192
# Maximum continuous decoding positions before forcing a reset.
# Beyond ~20s of continuous audio the autoregressive context drifts and
# produces hallucination. 20s / 80ms per token = 250 tokens.
_MAX_CONTINUOUS_POSITIONS = 250
def _prompt_tokens(tokenizer, n_left_pad=LEFT_PAD_TOKENS, n_delay=6):
"""Build the prompt token sequence and return ``(token_ids, n_delay)``."""
@@ -152,6 +157,7 @@ class VoxtralMLXOnlineProcessor:
self._last_token: mx.array | None = None
# Bookkeeping
self._samples_encoded = 0
self._real_samples_encoded = 0 # only real audio, excludes silence padding
self._positions_decoded = 0
self._prefilled = False
self._first_chunk = True
@@ -191,6 +197,7 @@ class VoxtralMLXOnlineProcessor:
self.end = audio_stream_end_time
self._pending_chunks.append(audio)
self._pending_len += len(audio)
self._real_samples_encoded += len(audio)
self.audio_buffer = audio # diagnostic only
# -- core processing --
@@ -203,14 +210,28 @@ class VoxtralMLXOnlineProcessor:
return [], self.end
def _step(self, is_last: bool) -> Tuple[List[ASRToken], float]:
# 0. Safety cap: if continuous decoding exceeds the limit, force a
# flush+reset to prevent hallucination even without VAD silence.
if self._prefilled and self._positions_decoded >= _MAX_CONTINUOUS_POSITIONS + self._prefix_len:
logger.info(
"[voxtral-mlx] continuous decoding cap hit at %d positions — "
"forcing flush+reset",
self._positions_decoded,
)
words = self._flush_and_reset()
return words, self.end
# 1. Encode any new audio
self._encode_pending()
if self._audio_embeds is None:
return [], self.end
# 2. Compute how many positions we can safely decode
total_safe = LEFT_PAD_TOKENS + self._samples_encoded // SAMPLES_PER_TOKEN
# 2. Compute how many positions we can safely decode.
# The safe boundary prevents the decoder from running ahead of the
# audio encoder. _samples_encoded tracks only real audio (not
# silence padding), so positions beyond this produce hallucination.
total_safe = LEFT_PAD_TOKENS + self._real_samples_encoded // SAMPLES_PER_TOKEN
n_available = self._audio_embeds.shape[0]
n_decodable = min(n_available, total_safe - self._positions_decoded)
@@ -229,11 +250,19 @@ class VoxtralMLXOnlineProcessor:
if n_decodable <= 0 or self._audio_embeds is None:
return [], self.end
# Clamp to the continuous decoding cap so we don't overshoot
max_left = _MAX_CONTINUOUS_POSITIONS + self._prefix_len - self._positions_decoded
if max_left > 0:
n_decodable = min(n_decodable, max_left)
else:
# Will be caught by the cap check on the next call
return self._extract_committed_words(), self.end
# 4. Decode available positions
hit_eos = self._decode_positions(n_decodable)
if hit_eos:
# Flush words, reset for next utterance
# Flush words, then full reset for next utterance
words = self._flush_all_words()
logger.debug(
"[voxtral-mlx] EOS hit during stream: flushed %d words, "
@@ -242,9 +271,12 @@ class VoxtralMLXOnlineProcessor:
self._samples_encoded / self.SAMPLING_RATE,
self._full_text[-60:] if self._full_text else "",
)
saved_offset = self._time_offset
new_offset = self._time_offset + self._real_samples_encoded / self.SAMPLING_RATE
saved_end = self.end
self._reset_state()
self._time_offset = saved_offset
self._time_offset = new_offset
self.end = saved_end
mx.clear_cache()
return words, self.end
# 5. Extract committed words (all but the last, which may still grow)
@@ -451,12 +483,66 @@ class VoxtralMLXOnlineProcessor:
return Transcript(start=self.end, end=self.end, text=" ".join(remaining))
return Transcript(start=None, end=None, text="")
def start_silence(self) -> Tuple[List[ASRToken], float]:
"""Flush all pending words when silence starts.
def _safe_decode_remaining(self):
"""Decode remaining audio embeddings, respecting the safe boundary.
Adds right-padding silence and forces a full decode pass so the
decoder emits tokens for the last words of speech. Without this,
the model holds back the final tokens waiting for future context.
Uses the same guard as ``_step`` to avoid decoding positions that
are beyond the real audio frontier, which causes hallucination.
"""
if self._audio_embeds is None or not self._prefilled:
return
# Use the same formula as _step() — this excludes padding positions
total_safe = LEFT_PAD_TOKENS + self._samples_encoded // SAMPLES_PER_TOKEN
n_available = self._audio_embeds.shape[0]
n_decodable = min(n_available, max(0, total_safe - self._positions_decoded))
# Cap at RIGHT_PAD_TOKENS to only decode the padding needed for
# the model to emit final tokens, not all accumulated padding
n_decodable = min(n_decodable, RIGHT_PAD_TOKENS)
if n_decodable > 0:
self._decode_positions(n_decodable)
def _flush_last_token_text(self):
"""Add the last pending token's text (if not EOS) to _full_text."""
if self._last_token is None:
return
tid = self._last_token.item()
if tid == self._eos_id:
return
text = self._tokenizer.decode(
[tid], special_token_policy=SpecialTokenPolicy.IGNORE
)
if not text:
return
last_pos = self._positions_decoded - self._prefix_len
if text.lstrip() != text or not self._full_text:
if self._current_word_pos is not None:
self._word_audio_ends.append(last_pos)
self._word_audio_starts.append(last_pos)
self._current_word_pos = last_pos
elif self._current_word_pos is None:
self._word_audio_starts.append(last_pos)
self._current_word_pos = last_pos
self._full_text += text
self._n_text_tokens += 1
def _close_current_word(self):
"""Close the last word if one is being built."""
if self._current_word_pos is not None:
last_pos = self._positions_decoded - self._prefix_len
self._word_audio_ends.append(last_pos)
self._current_word_pos = None
def _flush_and_reset(self) -> List[ASRToken]:
"""Flush pending audio, decode remaining, extract all words, then
fully reset both encoder and decoder state.
Used at silence boundaries and when the continuous decoding cap is
hit. A full reset (encoder + decoder) is necessary because the
encoder's incremental state (conv tails, KV caches) contains history
that would produce embeddings incompatible with a freshly-initialised
decoder. After reset ``_first_chunk=True``, so the next audio chunk
receives proper left-padding and both encoder and decoder start in
sync.
"""
# Align pending audio to SAMPLES_PER_TOKEN boundary
remainder = self._pending_len % SAMPLES_PER_TOKEN
@@ -471,37 +557,40 @@ class VoxtralMLXOnlineProcessor:
# Encode remaining audio (including right-padding)
self._encode_pending()
# Decode everything that's left
if self._audio_embeds is not None and self._prefilled:
self._decode_positions(self._audio_embeds.shape[0])
# Decode only positions backed by real audio
self._safe_decode_remaining()
# Flush last token if it wasn't EOS
if self._last_token is not None:
tid = self._last_token.item()
if tid != self._eos_id:
text = self._tokenizer.decode(
[tid], special_token_policy=SpecialTokenPolicy.IGNORE
)
if text:
last_pos = self._positions_decoded - self._prefix_len
if text.lstrip() != text or not self._full_text:
if self._current_word_pos is not None:
self._word_audio_ends.append(last_pos)
self._word_audio_starts.append(last_pos)
self._current_word_pos = last_pos
elif self._current_word_pos is None:
self._word_audio_starts.append(last_pos)
self._current_word_pos = last_pos
self._full_text += text
self._n_text_tokens += 1
# Close the last word if still open
if self._current_word_pos is not None:
last_pos = self._positions_decoded - self._prefix_len
self._word_audio_ends.append(last_pos)
self._current_word_pos = None
self._flush_last_token_text()
self._close_current_word()
words = self._flush_all_words()
# Compute time offset: the decoded audio covers up to this point
new_offset = self._time_offset + self._real_samples_encoded / self.SAMPLING_RATE
saved_end = self.end
# Full reset — encoder AND decoder. The encoder's incremental
# state (conv tails, transformer KV caches) carries history from
# the previous segment; keeping it would make the next set of
# embeddings incompatible with a fresh decoder prefill.
self._reset_state()
self._time_offset = new_offset
self.end = saved_end
# Free MLX caches eagerly
mx.clear_cache()
return words
def start_silence(self) -> Tuple[List[ASRToken], float]:
    """Flush all pending words when silence starts, then fully reset.

    Adds right-padding silence and forces a decode pass so the
    decoder emits tokens for the last words of speech. After flushing,
    resets both encoder and decoder state to prevent hallucination from
    accumulated autoregressive context drift on long audio.
    """
    flushed = self._flush_and_reset()
    logger.info("[voxtral-mlx] start_silence: flushed %d words", len(flushed))
    return flushed, self.end
@@ -529,10 +618,7 @@ class VoxtralMLXOnlineProcessor:
# Align pending audio to SAMPLES_PER_TOKEN boundary so nothing is lost
remainder = self._pending_len % SAMPLES_PER_TOKEN
if remainder > 0:
align_pad = SAMPLES_PER_TOKEN - remainder
else:
align_pad = 0
align_pad = (SAMPLES_PER_TOKEN - remainder) if remainder > 0 else 0
# Add alignment + right-padding silence
total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
@@ -543,48 +629,11 @@ class VoxtralMLXOnlineProcessor:
# Encode remaining audio (including right-padding)
self._encode_pending()
logger.debug(
"[voxtral-mlx] finish after encode: audio_embeds=%s, pending=%d",
self._audio_embeds.shape if self._audio_embeds is not None else None,
self._pending_len,
)
# Decode only positions backed by real audio
self._safe_decode_remaining()
hit_eos = False
# Decode everything that's left from right-padding
if self._audio_embeds is not None and self._prefilled:
hit_eos = self._decode_positions(self._audio_embeds.shape[0])
logger.debug(
"[voxtral-mlx] finish decode: hit_eos=%s, text='%s'",
hit_eos, self._full_text[-80:] if self._full_text else "",
)
# Flush last token if it wasn't EOS
if self._last_token is not None:
tid = self._last_token.item()
if tid != self._eos_id:
text = self._tokenizer.decode(
[tid], special_token_policy=SpecialTokenPolicy.IGNORE
)
if text:
last_pos = self._positions_decoded - self._prefix_len
# Check if this starts a new word
if text.lstrip() != text or not self._full_text:
if self._current_word_pos is not None:
self._word_audio_ends.append(last_pos)
self._word_audio_starts.append(last_pos)
self._current_word_pos = last_pos
elif self._current_word_pos is None:
self._word_audio_starts.append(last_pos)
self._current_word_pos = last_pos
self._full_text += text
self._n_text_tokens += 1
# Close the last word if still open
if self._current_word_pos is not None:
last_pos = self._positions_decoded - self._prefix_len
self._word_audio_ends.append(last_pos)
self._current_word_pos = None
self._flush_last_token_text()
self._close_current_word()
words = self._flush_all_words()
logger.info("[voxtral-mlx] finish: flushed %d words", len(words))