10 Commits

Author SHA1 Message Date
Quentin Fuxa
7ea507ed8e Add Voxtral MLX streaming backend
Integrates the voxmlx-based Voxtral Mini Realtime streaming pipeline:
- VoxtralStreamingASR and VoxtralStreamingOnlineProcessor
- Incremental audio encoding and token-by-token autoregressive decoding
- Selectable via --backend voxtral-mlx

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 09:20:28 +01:00
Quentin Fuxa
e7e82f7c19 bump to 0.2.18 2026-02-11 22:10:00 +01:00
Quentin Fuxa
8c799fa4d1 fix simulstreaming vram leak: cap cross-attn accumulation + token budget
fixes #283, fixes #275

- accumulated_cross_attns was growing unboundedly during decoding loop,
  using up to ~5GB for repetition loops. now capped to rolling window of 16
- max_tokens_per_chunk was using TOKENS_PER_SECOND (mel frame rate = 50)
  instead of actual text token rate (~15/s), allowing 10-40x too many
  decoding steps
- removed unused torch.cat on early return path
- removed dead self.committed/last_result_tokens lists (never read)
- same fixes applied to mlx variant
2026-02-11 22:10:00 +01:00
Quentin Fuxa
8923337380 fix --direct-english-translation not setting task=translate for localagreement backends
the flag was only used for tokenizer language selection but never
actually passed to whisper/faster-whisper transcribe calls. also init
OpenaiApiASR.task and read from transcribe_kargs.

fixes #306
2026-02-11 22:10:00 +01:00
Quentin Fuxa
aded1649ae fix model_cache_dir + direct_english_translation task in simulstreaming
pass actual cache dir instead of None, and use proper task string
instead of boolean for AlignAttConfig

fixes #310
2026-02-11 22:10:00 +01:00
Quentin Fuxa
3b535e857a fix NoneType concatenation in add_translation
fixes #296
2026-02-11 22:10:00 +01:00
Quentin Fuxa
d649250b9a fix Segment classmethod call + isinstance type narrowing
fixes #331, fixes #329
2026-02-11 22:10:00 +01:00
Quentin Fuxa
7735478286 add insert_audio_chunk to DiartDiarization
fixes #332
2026-02-11 22:10:00 +01:00
Quentin Fuxa
b9e72d2b9a add probability field to ASRToken
fixes #330, fixes #313
2026-02-11 22:10:00 +01:00
Quentin Fuxa
e5b01033af add json normalizers for english language in build 2026-01-16 10:47:46 +01:00
13 changed files with 542 additions and 31 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.17.post1"
version = "0.2.18"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [
@@ -69,4 +69,5 @@ packages = [
[tool.setuptools.package-data]
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
"whisperlivekit.whisper.normalizers" = ["*.json"]
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]

View File

@@ -29,6 +29,13 @@ def mlx_backend_available(warn_on_missing = False):
return available
def voxmlx_backend_available():
"""Return True if voxmlx (Voxtral MLX backend) is available."""
is_macos = platform.system() == "Darwin"
is_arm = platform.machine() == "arm64"
return is_macos and is_arm and module_available("voxmlx")
def faster_backend_available(warn_on_missing = False):
available = module_available("faster_whisper")
if not available and warn_on_missing and platform.system() != "Darwin":

View File

@@ -104,7 +104,12 @@ class TranscriptionEngine:
)
backend_policy = self.args.backend_policy
if self.args.transcription:
if backend_policy == "simulstreaming":
if self.args.backend == "voxtral-mlx":
from whisperlivekit.voxtral_streaming import VoxtralStreamingASR
self.tokenizer = None
self.asr = VoxtralStreamingASR(**transcription_common_params)
logger.info("Using Voxtral MLX streaming backend")
elif backend_policy == "simulstreaming":
simulstreaming_params = {
"disable_fast_encoder": False,
"custom_alignment_heads": None,
@@ -186,6 +191,9 @@ class TranscriptionEngine:
def online_factory(args, asr):
if getattr(args, 'backend', None) == "voxtral-mlx":
from whisperlivekit.voxtral_streaming import VoxtralStreamingOnlineProcessor
return VoxtralStreamingOnlineProcessor(asr)
if args.backend_policy == "simulstreaming":
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
return SimulStreamingOnlineProcessor(asr)

View File

@@ -202,14 +202,14 @@ class DiartDiarization:
def insert_silence(self, silence_duration):
self.observer.global_time_offset += silence_duration
async def diarize(self, pcm_array: np.ndarray):
"""
Process audio data for diarization.
Only used when working with WebSocketAudioSource.
"""
def insert_audio_chunk(self, pcm_array: np.ndarray):
"""Buffer audio for the next diarization step."""
if self.custom_source:
self.custom_source.push_audio(pcm_array)
# self.observer.clear_old_segments()
self.custom_source.push_audio(pcm_array)
async def diarize(self):
"""Return the current speaker segments from the diarization pipeline."""
return self.observer.get_segments()
def close(self):
"""Close the audio source."""

View File

@@ -151,7 +151,7 @@ class FasterWhisperASR(ASRBase):
if segment.no_speech_prob > 0.9:
continue
for word in segment.words:
token = ASRToken(word.start, word.end, word.word)
token = ASRToken(word.start, word.end, word.word, probability=word.probability)
tokens.append(token)
return tokens
@@ -249,6 +249,7 @@ class OpenaiApiASR(ASRBase):
self.load_model()
self.use_vad_opt = False
self.direct_english_translation = False
self.task = "transcribe"
def load_model(self, *args, **kwargs):
from openai import OpenAI
@@ -294,7 +295,8 @@ class OpenaiApiASR(ASRBase):
params["language"] = self.original_language
if prompt:
params["prompt"] = prompt
proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
task = self.transcribe_kargs.get("task", self.task)
proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions
transcript = proc.create(**params)
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
return transcript

View File

@@ -146,6 +146,7 @@ def backend_factory(
if direct_english_translation:
tgt_language = "en" # Whisper translates into English
asr.transcribe_kargs["task"] = "translate"
else:
tgt_language = lan # Whisper transcribes in this language
@@ -154,9 +155,9 @@ def backend_factory(
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
warmup_asr(asr, warmup_file)
asr.confidence_validation = confidence_validation
asr.tokenizer = tokenizer
asr.buffer_trimming = buffer_trimming

View File

@@ -147,8 +147,8 @@ def parse_args():
"--backend",
type=str,
default="auto",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"],
help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API.",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral-mlx"],
help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API. Use 'voxtral-mlx' for Voxtral streaming on Apple Silicon.",
)
parser.add_argument(
"--no-vac",

View File

@@ -46,8 +46,6 @@ class SimulStreamingOnlineProcessor:
self.logfile = logfile
self.end = 0.0
self.buffer = []
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.model = self._create_alignatt()
if asr.tokenizer:
@@ -122,7 +120,6 @@ class SimulStreamingOnlineProcessor:
self.buffer.extend(timestamped_words)
return [], self.end
self.committed.extend(timestamped_words)
self.buffer = []
return timestamped_words, self.end
except Exception as e:
@@ -217,7 +214,7 @@ class SimulStreamingASR:
cif_ckpt_path=self.cif_ckpt_path,
decoder_type="beam",
beam_size=self.beams,
task=self.direct_english_translation,
task="translate" if self.direct_english_translation else "transcribe",
never_fire=self.never_fire,
init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens,
@@ -330,7 +327,7 @@ class SimulStreamingASR:
lora_path = getattr(self, 'lora_path', None)
whisper_model = load_model(
name=model_ref,
download_root=None,
download_root=getattr(self, 'model_cache_dir', None),
decoder_only=self.fast_encoder,
custom_alignment_heads=self.custom_alignment_heads,
lora_path=lora_path,

View File

@@ -532,7 +532,9 @@ class MLXAlignAtt:
accumulated_cross_attns = []
audio_duration_s = self.segments_len()
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0))
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
tokens_produced_this_chunk = 0
while not completed and current_tokens.shape[1] < self.max_text_len:
@@ -558,6 +560,8 @@ class MLXAlignAtt:
mx.eval(logits)
accumulated_cross_attns.append(cross_qk)
if len(accumulated_cross_attns) > 16:
accumulated_cross_attns = accumulated_cross_attns[-16:]
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)

View File

@@ -390,7 +390,6 @@ class AlignAtt:
return []
if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.state.segments, dim=0)
return []
# input_segments is concatenation of audio, it's one array
@@ -485,7 +484,9 @@ class AlignAtt:
accumulated_cross_attns = []
audio_duration_s = self.segments_len()
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0)) # 2x margin, min 50
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
tokens_produced_this_chunk = 0
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
@@ -506,8 +507,12 @@ class AlignAtt:
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
logits, cross_attns = result
# Accumulate cross-attention from this forward pass
# Accumulate cross-attention from this forward pass (rolling window to
# bound VRAM — only the last entry matters for alignment, and the
# median_filter kernel is 7, so 16 entries is more than enough).
accumulated_cross_attns.append(cross_attns)
if len(accumulated_cross_attns) > 16:
accumulated_cross_attns = accumulated_cross_attns[-16:]
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)

View File

@@ -39,10 +39,11 @@ class TimedText(Timed):
@dataclass()
class ASRToken(TimedText):
probability: Optional[float] = None
def with_offset(self, offset: float) -> "ASRToken":
"""Return a new token with the time offset added."""
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language, probability=self.probability)
def is_silence(self) -> bool:
return False

View File

@@ -53,7 +53,8 @@ class TokensAlignment:
segment.translation = ''
for ts in self.all_translation_segments:
if ts.is_within(segment):
segment.translation += ts.text + (self.sep if ts.text else '')
if ts.text:
segment.translation += ts.text + self.sep
elif segment.translation:
break
@@ -185,11 +186,11 @@ class TokensAlignment:
else:
diarization_buffer = ''
for token in self.new_tokens:
if token.is_silence():
if isinstance(token, Silence):
if self.current_line_tokens:
self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
self.validated_segments.append(Segment.from_tokens(self.current_line_tokens))
self.current_line_tokens = []
end_silence = token.end if token.has_ended else time() - self.beg_loop
if self.validated_segments and self.validated_segments[-1].is_silence():
self.validated_segments[-1].end = end_silence
@@ -203,7 +204,7 @@ class TokensAlignment:
segments = list(self.validated_segments)
if self.current_line_tokens:
segments.append(Segment().from_tokens(self.current_line_tokens))
segments.append(Segment.from_tokens(self.current_line_tokens))
if current_silence:
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop

View File

@@ -0,0 +1,484 @@
"""
Voxtral Mini Realtime streaming backend using voxmlx's incremental encode/decode.
Uses model.encode_step() for incremental audio encoding and token-by-token
autoregressive decoding, matching voxmlx's native streaming pipeline.
"""
import logging
import sys
import time
from typing import List, Optional, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Transcript
logger = logging.getLogger(__name__)
N_LEFT_PAD_TOKENS = 32
N_RIGHT_PAD_TOKENS = 17
class VoxtralStreamingASR:
"""Voxtral model holder for the streaming pipeline."""
sep = " "
def __init__(self, logfile=sys.stderr, **kwargs):
from voxmlx import _build_prompt_tokens
from voxmlx import load_model as vox_load_model
self.logfile = logfile
self.transcribe_kargs = {}
lan = kwargs.get("lan", "auto")
self.original_language = None if lan == "auto" else lan
DEFAULT_MODEL = "mlx-community/Voxtral-Mini-4B-Realtime-6bit"
model_path = kwargs.get("model_dir") or kwargs.get("model_path")
if not model_path:
model_size = kwargs.get("model_size", "")
# Only use model_size if it looks like a HF repo or a path, not a Whisper size name
if model_size and ("/" in model_size or model_size.startswith(".")):
model_path = model_size
else:
model_path = DEFAULT_MODEL
t = time.time()
logger.info(f"Loading Voxtral model '{model_path}' via voxmlx...")
self.model, self._tokenizer, self._config = vox_load_model(model_path)
self._prompt_tokens, self._n_delay_tokens = _build_prompt_tokens(
self._tokenizer
)
logger.info(f"Voxtral model loaded in {time.time() - t:.2f}s")
self.backend_choice = "voxtral-mlx"
self.tokenizer = None # sentence tokenizer — not needed for streaming
def transcribe(self, audio):
pass
class VoxtralStreamingOnlineProcessor:
"""
Online processor for Voxtral streaming ASR.
Uses voxmlx's incremental encoding (encode_step) and token-by-token
autoregressive decoding. Each decode step corresponds to 80ms of audio.
"""
SAMPLING_RATE = 16000
def __init__(self, asr: VoxtralStreamingASR, logfile=sys.stderr):
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.buffer = []
self.audio_buffer = np.array([], dtype=np.float32) # for logging compat
self._special_token_policy = SpecialTokenPolicy.IGNORE
self._reset_state()
logger.info(
f"[voxtral] Initialized. eos_id={asr._tokenizer.eos_id}, "
f"prefix_len={len(asr._prompt_tokens)}, "
f"n_delay={asr._n_delay_tokens}"
)
def _reset_state(self):
from voxmlx.audio import SAMPLES_PER_TOKEN
self._samples_per_token = SAMPLES_PER_TOKEN
# Incremental encoder state
self._audio_tail = None
self._conv1_tail = None
self._conv2_tail = None
self._encoder_cache = None
self._ds_buf = None
# Decoder state
self._decoder_cache = None
self._y = None # last sampled token (mx.array scalar)
self._t_cond = None
self._text_embeds = None
# Audio / decode tracking
self._pending_audio = np.zeros(0, dtype=np.float32)
self._audio_embeds = None
self._n_audio_samples_fed = 0
self._n_total_decoded = 0
self._first_cycle = True
self._prefilled = False
# Word extraction: accumulate token IDs, full-sequence decode for correct spacing
self._output_token_ids: List[int] = []
self._token_positions: List[int] = [] # decode position for each token
self._n_committed_words = 0
self._global_time_offset = 0.0
self._y_flushed_to_output = False # True after start_silence flushes pending _y
# ── Interface methods (same as SimulStreamingOnlineProcessor) ──
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
self.end = audio_stream_end_time
self._pending_audio = np.append(self._pending_audio, audio)
self.audio_buffer = self._pending_audio # for logging compat
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
try:
return self._process_iter_inner(is_last)
except Exception as e:
logger.warning(f"[voxtral] process_iter exception: {e}", exc_info=True)
return [], self.end
def _get_full_text(self) -> str:
"""Decode all accumulated token IDs at once for correct spacing."""
if not self._output_token_ids:
return ""
sp = self.asr._tokenizer
return sp.decode(self._output_token_ids, special_token_policy=self._special_token_policy)
def get_buffer(self) -> Transcript:
"""Return all uncommitted text as buffer, including pending _y token."""
# Temporarily include pending _y for buffer display
ids = list(self._output_token_ids)
if self._y is not None and not self._y_flushed_to_output:
sp = self.asr._tokenizer
token_id = self._y.item()
if token_id != sp.eos_id:
ids.append(token_id)
if not ids:
return Transcript(start=None, end=None, text="")
sp = self.asr._tokenizer
full_text = sp.decode(ids, special_token_policy=self._special_token_policy)
words = full_text.split()
uncommitted = words[self._n_committed_words:]
if uncommitted:
text = " ".join(uncommitted)
return Transcript(start=self.end, end=self.end, text=text)
return Transcript(start=None, end=None, text="")
def start_silence(self) -> Tuple[List[ASRToken], float]:
"""Flush all uncommitted words when silence starts."""
self._flush_last_y() # Include the pending _y token before flushing
words = self._flush_all_pending_words()
logger.info(f"[voxtral] start_silence: flushed {len(words)} words")
return words, self.end
def end_silence(self, silence_duration: float, offset: float):
self._global_time_offset += silence_duration
self.end += silence_duration
def new_speaker(self, change_speaker):
self.start_silence()
def warmup(self, audio, init_prompt=""):
pass
def finish(self) -> Tuple[List[ASRToken], float]:
"""Flush remaining audio with right-padding to let the model finish decoding."""
right_pad = np.zeros(
N_RIGHT_PAD_TOKENS * self._samples_per_token, dtype=np.float32
)
self._pending_audio = np.append(self._pending_audio, right_pad)
self._n_audio_samples_fed += len(right_pad)
final_words, _ = self._process_iter_inner(is_last=True)
# Flush the last pending self._y token (like voxmlx's finally block)
self._flush_last_y()
final_words.extend(self._flush_all_pending_words())
return final_words, self.end
# ── Word extraction ──
def _pos_to_time(self, pos: int) -> float:
"""Convert a decode position to seconds relative to audio start."""
SPT = self._samples_per_token
return max(0.0, (pos - N_LEFT_PAD_TOKENS) * SPT / self.SAMPLING_RATE)
def _flush_last_y(self):
"""Flush the last pending self._y token that hasn't been processed yet."""
if self._y is None or self._y_flushed_to_output:
return
sp = self.asr._tokenizer
token_id = self._y.item()
if token_id != sp.eos_id:
self._output_token_ids.append(token_id)
self._token_positions.append(self._n_total_decoded)
self._y_flushed_to_output = True
def _extract_new_words(self) -> List[ASRToken]:
"""
Split accumulated text into words and return new complete words
(all but the last, which may still be growing).
"""
if not self._output_token_ids:
return []
full_text = self._get_full_text()
words = full_text.split()
new_words: List[ASRToken] = []
n_tokens = len(self._output_token_ids)
# All words except the last are guaranteed complete
while len(words) > self._n_committed_words + 1:
word = words[self._n_committed_words]
word_idx = self._n_committed_words
n_words_total = len(words)
# Approximate: assign token range proportionally
tok_start = int(word_idx / n_words_total * n_tokens)
tok_end = int((word_idx + 1) / n_words_total * n_tokens)
tok_start = min(tok_start, len(self._token_positions) - 1)
tok_end = min(tok_end, len(self._token_positions) - 1)
start_time = self._pos_to_time(self._token_positions[tok_start]) + self._global_time_offset
end_time = self._pos_to_time(self._token_positions[tok_end]) + self._global_time_offset
# Prepend space to match Whisper convention (Segment.from_tokens joins with '')
text = word if self._n_committed_words == 0 else " " + word
new_words.append(ASRToken(start=start_time, end=end_time, text=text))
self._n_committed_words += 1
return new_words
def _flush_all_pending_words(self) -> List[ASRToken]:
"""Flush ALL words including the last partial one."""
if not self._output_token_ids:
return []
full_text = self._get_full_text()
words = full_text.split()
new_words: List[ASRToken] = []
n_tokens = len(self._output_token_ids)
n_words_total = max(len(words), 1)
while self._n_committed_words < len(words):
word = words[self._n_committed_words]
word_idx = self._n_committed_words
tok_start = int(word_idx / n_words_total * n_tokens)
tok_end = int((word_idx + 1) / n_words_total * n_tokens)
tok_start = min(tok_start, max(len(self._token_positions) - 1, 0))
tok_end = min(tok_end, max(len(self._token_positions) - 1, 0))
if self._token_positions:
start_time = self._pos_to_time(self._token_positions[tok_start]) + self._global_time_offset
end_time = self._pos_to_time(self._token_positions[tok_end]) + self._global_time_offset
else:
start_time = self._global_time_offset
end_time = self._global_time_offset
# Prepend space to match Whisper convention (Segment.from_tokens joins with '')
text = word if self._n_committed_words == 0 else " " + word
new_words.append(ASRToken(start=start_time, end=end_time, text=text))
self._n_committed_words += 1
return new_words
# ── Core streaming logic ──
def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]:
import mlx.core as mx
from voxmlx.audio import log_mel_spectrogram_step
from voxmlx.cache import RotatingKVCache
model = self.asr.model
sp = self.asr._tokenizer
prompt_tokens = self.asr._prompt_tokens
prefix_len = len(prompt_tokens)
SPT = self._samples_per_token
# ── Phase 1: Encode new audio ──
if self._first_cycle and len(self._pending_audio) >= SPT:
left_pad = np.zeros(N_LEFT_PAD_TOKENS * SPT, dtype=np.float32)
n_feed = (len(self._pending_audio) // SPT) * SPT
chunk = np.concatenate([left_pad, self._pending_audio[:n_feed]])
self._pending_audio = self._pending_audio[n_feed:]
self._n_audio_samples_fed += n_feed
mel, self._audio_tail = log_mel_spectrogram_step(
chunk, self._audio_tail
)
(
new_embeds,
self._conv1_tail,
self._conv2_tail,
self._encoder_cache,
self._ds_buf,
) = model.encode_step(
mel,
self._conv1_tail,
self._conv2_tail,
self._encoder_cache,
self._ds_buf,
)
if new_embeds is not None:
mx.eval(new_embeds)
self._audio_embeds = new_embeds
logger.info(f"[voxtral] first encode: {new_embeds.shape[0]} embeds from {n_feed} samples")
else:
logger.info(f"[voxtral] first encode: no embeds from {n_feed} samples")
self._first_cycle = False
elif not self._first_cycle and len(self._pending_audio) >= SPT:
n_feed = (len(self._pending_audio) // SPT) * SPT
chunk = self._pending_audio[:n_feed]
self._pending_audio = self._pending_audio[n_feed:]
self._n_audio_samples_fed += n_feed
mel, self._audio_tail = log_mel_spectrogram_step(
chunk, self._audio_tail
)
(
new_embeds,
self._conv1_tail,
self._conv2_tail,
self._encoder_cache,
self._ds_buf,
) = model.encode_step(
mel,
self._conv1_tail,
self._conv2_tail,
self._encoder_cache,
self._ds_buf,
)
if new_embeds is not None:
mx.eval(new_embeds)
if self._audio_embeds is not None:
self._audio_embeds = mx.concatenate(
[self._audio_embeds, new_embeds]
)
else:
self._audio_embeds = new_embeds
self.audio_buffer = self._pending_audio # for logging compat
if self._audio_embeds is None:
return [], self.end
# Safety: don't decode ahead of encoded audio
safe_total = (
N_LEFT_PAD_TOKENS + self._n_audio_samples_fed // SPT
)
n_decodable = min(
self._audio_embeds.shape[0], safe_total - self._n_total_decoded
)
if n_decodable <= 0:
return [], self.end
# ── Phase 2: Prefill (once per utterance) ──
if not self._prefilled:
if self._n_total_decoded + self._audio_embeds.shape[0] < prefix_len:
logger.info(
f"[voxtral] waiting for prefill: have {self._audio_embeds.shape[0]} embeds, need {prefix_len}"
)
return [], self.end
n_layers = len(model.language_model.layers)
self._decoder_cache = [RotatingKVCache(8192) for _ in range(n_layers)]
self._t_cond = model.time_embedding(
mx.array([self.asr._n_delay_tokens], dtype=mx.float32)
)
prompt_ids = mx.array([prompt_tokens])
self._text_embeds = model.language_model.embed(prompt_ids)[0]
prefix_embeds = (
self._text_embeds + self._audio_embeds[:prefix_len]
)[None, :, :]
logits = model.decode(
prefix_embeds, self._t_cond, "causal", self._decoder_cache
)
mx.eval(
logits,
*[x for c in self._decoder_cache for x in (c.keys, c.values)],
)
self._y = mx.argmax(logits[0, -1:], axis=-1).squeeze()
mx.async_eval(self._y)
self._audio_embeds = self._audio_embeds[prefix_len:]
self._n_total_decoded = prefix_len
self._prefilled = True
logger.info(f"[voxtral] prefill done, first token y={self._y.item()}")
n_decodable = min(
self._audio_embeds.shape[0], safe_total - self._n_total_decoded
)
if n_decodable <= 0:
return [], self.end
# ── Phase 3: Decode new positions ──
eos_id = sp.eos_id
hit_eos = False
n_consumed = 0
for i in range(n_decodable):
token_embed = model.language_model.embed(self._y.reshape(1, 1))[0, 0]
step_embed = (self._audio_embeds[i] + token_embed)[None, None, :]
logits = model.decode(
step_embed, self._t_cond, mask=None, cache=self._decoder_cache
)
next_y = mx.argmax(logits[0, -1:], axis=-1).squeeze()
mx.async_eval(next_y)
token_id = self._y.item()
n_consumed = i + 1
if token_id == eos_id:
hit_eos = True
logger.info("[voxtral] hit EOS")
break
# Accumulate token ID — full-sequence decode produces correct spacing
# Skip if this _y was already flushed by start_silence()
if self._y_flushed_to_output:
self._y_flushed_to_output = False
else:
self._output_token_ids.append(token_id)
# Track position for timestamp estimation
pos = self._n_total_decoded + i
self._token_positions.append(pos)
if i > 0 and i % 256 == 0:
mx.clear_cache()
self._y = next_y
self._n_total_decoded += n_consumed
# Trim consumed embeddings
if self._audio_embeds.shape[0] > n_consumed:
self._audio_embeds = self._audio_embeds[n_consumed:]
else:
self._audio_embeds = None
# Log decode results
full_text = self._get_full_text()
logger.info(
f"[voxtral] decoded {n_consumed} tokens | "
f"total_decoded={self._n_total_decoded} | "
f"text='{full_text[-80:]}' | "
f"n_words={len(full_text.split())} committed={self._n_committed_words}"
)
# Extract complete words from the decoded token sequence
new_words = self._extract_new_words()
if hit_eos:
new_words.extend(self._flush_all_pending_words())
self._reset_state()
if new_words:
logger.info(f"[voxtral] returning {len(new_words)} words: {[w.text for w in new_words]}")
self.buffer = []
return new_words, self.end