diff --git a/whisperlivekit/backend_support.py b/whisperlivekit/backend_support.py index a64770a..e565d31 100644 --- a/whisperlivekit/backend_support.py +++ b/whisperlivekit/backend_support.py @@ -29,6 +29,13 @@ def mlx_backend_available(warn_on_missing = False): return available +def voxmlx_backend_available(): + """Return True if voxmlx (Voxtral MLX backend) is available.""" + is_macos = platform.system() == "Darwin" + is_arm = platform.machine() == "arm64" + return is_macos and is_arm and module_available("voxmlx") + + def faster_backend_available(warn_on_missing = False): available = module_available("faster_whisper") if not available and warn_on_missing and platform.system() != "Darwin": diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index 133ccbf..ff1bbf6 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -104,7 +104,12 @@ class TranscriptionEngine: ) backend_policy = self.args.backend_policy if self.args.transcription: - if backend_policy == "simulstreaming": + if self.args.backend == "voxtral-mlx": + from whisperlivekit.voxtral_streaming import VoxtralStreamingASR + self.tokenizer = None + self.asr = VoxtralStreamingASR(**transcription_common_params) + logger.info("Using Voxtral MLX streaming backend") + elif backend_policy == "simulstreaming": simulstreaming_params = { "disable_fast_encoder": False, "custom_alignment_heads": None, @@ -186,6 +191,9 @@ class TranscriptionEngine: def online_factory(args, asr): + if getattr(args, 'backend', None) == "voxtral-mlx": + from whisperlivekit.voxtral_streaming import VoxtralStreamingOnlineProcessor + return VoxtralStreamingOnlineProcessor(asr) if args.backend_policy == "simulstreaming": from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor return SimulStreamingOnlineProcessor(asr) diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index 9b5da4d..54beda2 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -147,8 +147,8 @@ def parse_args(): "--backend", type=str, default="auto", - choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"], - help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API.", + choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral-mlx"], + help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API. Use 'voxtral-mlx' for Voxtral streaming on Apple Silicon.", ) parser.add_argument( "--no-vac", diff --git a/whisperlivekit/voxtral_streaming.py b/whisperlivekit/voxtral_streaming.py new file mode 100644 index 0000000..5c6619b --- /dev/null +++ b/whisperlivekit/voxtral_streaming.py @@ -0,0 +1,484 @@ +""" +Voxtral Mini Realtime streaming backend using voxmlx's incremental encode/decode. + +Uses model.encode_step() for incremental audio encoding and token-by-token +autoregressive decoding, matching voxmlx's native streaming pipeline. 
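+
+Illustrative lifecycle (variable names here are hypothetical; real wiring
+goes through TranscriptionEngine / online_factory in core.py):
+
+    asr = VoxtralStreamingASR(lan="en")
+    online = VoxtralStreamingOnlineProcessor(asr)
+    online.insert_audio_chunk(pcm_float32_16khz, audio_stream_end_time)
+    committed_words, end = online.process_iter()
+    final_words, end = online.finish()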
+""" + +import logging +import sys +import time +from typing import List, Optional, Tuple + +import numpy as np + +from whisperlivekit.timed_objects import ASRToken, Transcript + +logger = logging.getLogger(__name__) + +N_LEFT_PAD_TOKENS = 32 +N_RIGHT_PAD_TOKENS = 17 + + +class VoxtralStreamingASR: + """Voxtral model holder for the streaming pipeline.""" + + sep = " " + + def __init__(self, logfile=sys.stderr, **kwargs): + from voxmlx import _build_prompt_tokens + from voxmlx import load_model as vox_load_model + + self.logfile = logfile + self.transcribe_kargs = {} + + lan = kwargs.get("lan", "auto") + self.original_language = None if lan == "auto" else lan + + DEFAULT_MODEL = "mlx-community/Voxtral-Mini-4B-Realtime-6bit" + model_path = kwargs.get("model_dir") or kwargs.get("model_path") + if not model_path: + model_size = kwargs.get("model_size", "") + # Only use model_size if it looks like a HF repo or a path, not a Whisper size name + if model_size and ("/" in model_size or model_size.startswith(".")): + model_path = model_size + else: + model_path = DEFAULT_MODEL + + t = time.time() + logger.info(f"Loading Voxtral model '{model_path}' via voxmlx...") + self.model, self._tokenizer, self._config = vox_load_model(model_path) + self._prompt_tokens, self._n_delay_tokens = _build_prompt_tokens( + self._tokenizer + ) + logger.info(f"Voxtral model loaded in {time.time() - t:.2f}s") + + self.backend_choice = "voxtral-mlx" + self.tokenizer = None # sentence tokenizer — not needed for streaming + + def transcribe(self, audio): + pass + + +class VoxtralStreamingOnlineProcessor: + """ + Online processor for Voxtral streaming ASR. + + Uses voxmlx's incremental encoding (encode_step) and token-by-token + autoregressive decoding. Each decode step corresponds to 80ms of audio. + """ + + SAMPLING_RATE = 16000 + + def __init__(self, asr: VoxtralStreamingASR, logfile=sys.stderr): + from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy + + self.asr = asr + self.logfile = logfile + self.end = 0.0 + self.buffer = [] + self.audio_buffer = np.array([], dtype=np.float32) # for logging compat + self._special_token_policy = SpecialTokenPolicy.IGNORE + self._reset_state() + logger.info( + f"[voxtral] Initialized. 
+            f"[voxtral] Initialized. eos_id={asr._tokenizer.eos_id}, "
+            f"prefix_len={len(asr._prompt_tokens)}, "
+            f"n_delay={asr._n_delay_tokens}"
+        )
+
+    def _reset_state(self):
+        from voxmlx.audio import SAMPLES_PER_TOKEN
+
+        self._samples_per_token = SAMPLES_PER_TOKEN
+
+        # Incremental encoder state
+        self._audio_tail = None
+        self._conv1_tail = None
+        self._conv2_tail = None
+        self._encoder_cache = None
+        self._ds_buf = None
+
+        # Decoder state
+        self._decoder_cache = None
+        self._y = None  # last sampled token (mx.array scalar)
+        self._t_cond = None
+        self._text_embeds = None
+
+        # Audio / decode tracking
+        self._pending_audio = np.zeros(0, dtype=np.float32)
+        self._audio_embeds = None
+        self._n_audio_samples_fed = 0
+        self._n_total_decoded = 0
+        self._first_cycle = True
+        self._prefilled = False
+
+        # Word extraction: accumulate token IDs, full-sequence decode for correct spacing
+        self._output_token_ids: List[int] = []
+        self._token_positions: List[int] = []  # decode position for each token
+        self._n_committed_words = 0
+        self._global_time_offset = 0.0
+        self._y_flushed_to_output = False  # True after start_silence flushes pending _y
+
+    # ── Interface methods (same as SimulStreamingOnlineProcessor) ──
+
+    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
+        self.end = audio_stream_end_time
+        self._pending_audio = np.append(self._pending_audio, audio)
+        self.audio_buffer = self._pending_audio  # for logging compat
+
+    def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
+        try:
+            return self._process_iter_inner(is_last)
+        except Exception as e:
+            logger.warning(f"[voxtral] process_iter exception: {e}", exc_info=True)
+            return [], self.end
+
+    def _get_full_text(self) -> str:
+        """Decode all accumulated token IDs at once for correct spacing."""
+        if not self._output_token_ids:
+            return ""
+        sp = self.asr._tokenizer
+        return sp.decode(self._output_token_ids, special_token_policy=self._special_token_policy)
+
+    def get_buffer(self) -> Transcript:
+        """Return all uncommitted text as buffer, including the pending _y token."""
+        sp = self.asr._tokenizer
+        # Temporarily include pending _y for buffer display
+        ids = list(self._output_token_ids)
+        if self._y is not None and not self._y_flushed_to_output:
+            token_id = self._y.item()
+            if token_id != sp.eos_id:
+                ids.append(token_id)
+        if not ids:
+            return Transcript(start=None, end=None, text="")
+        full_text = sp.decode(ids, special_token_policy=self._special_token_policy)
+        words = full_text.split()
+        uncommitted = words[self._n_committed_words:]
+        if uncommitted:
+            text = " ".join(uncommitted)
+            return Transcript(start=self.end, end=self.end, text=text)
+        return Transcript(start=None, end=None, text="")
+
+    def start_silence(self) -> Tuple[List[ASRToken], float]:
+        """Flush all uncommitted words when silence starts."""
+        self._flush_last_y()  # Include the pending _y token before flushing
+        words = self._flush_all_pending_words()
+        logger.info(f"[voxtral] start_silence: flushed {len(words)} words")
+        return words, self.end
+
+    def end_silence(self, silence_duration: float, offset: float):
+        self._global_time_offset += silence_duration
+        self.end += silence_duration
+
+    def new_speaker(self, change_speaker):
+        # Flush word state at the speaker boundary; the flushed tokens
+        # themselves are discarded here.
+        self.start_silence()
+
+    def warmup(self, audio, init_prompt=""):
+        """No-op: this backend needs no warmup transcription."""
+        pass
+
+    def finish(self) -> Tuple[List[ASRToken], float]:
+        """Flush remaining audio with right-padding to let the model finish decoding."""
+        right_pad = np.zeros(
+            N_RIGHT_PAD_TOKENS * self._samples_per_token, dtype=np.float32
+        )
+        self._pending_audio = np.append(self._pending_audio, right_pad)
+        self._n_audio_samples_fed += len(right_pad)
+
+        final_words, _ = self._process_iter_inner(is_last=True)
+        # Flush the last pending self._y token (like voxmlx's finally block)
+        self._flush_last_y()
+        final_words.extend(self._flush_all_pending_words())
+        return final_words, self.end
+
+    # ── Word extraction ──
+
+    def _pos_to_time(self, pos: int) -> float:
+        """Convert a decode position to seconds relative to audio start."""
+        SPT = self._samples_per_token
+        return max(0.0, (pos - N_LEFT_PAD_TOKENS) * SPT / self.SAMPLING_RATE)
+
+    def _flush_last_y(self):
+        """Flush the last pending self._y token that hasn't been processed yet."""
+        if self._y is None or self._y_flushed_to_output:
+            return
+        sp = self.asr._tokenizer
+        token_id = self._y.item()
+        if token_id != sp.eos_id:
+            self._output_token_ids.append(token_id)
+            self._token_positions.append(self._n_total_decoded)
+            self._y_flushed_to_output = True
+
+    def _estimate_word_times(self, word_idx: int, n_words_total: int) -> Tuple[float, float]:
+        """Approximate a word's time range by assigning token ranges proportionally."""
+        if not self._token_positions:
+            return self._global_time_offset, self._global_time_offset
+        n_tokens = len(self._output_token_ids)
+        last = len(self._token_positions) - 1
+        tok_start = min(int(word_idx / n_words_total * n_tokens), last)
+        tok_end = min(int((word_idx + 1) / n_words_total * n_tokens), last)
+        start_time = self._pos_to_time(self._token_positions[tok_start]) + self._global_time_offset
+        end_time = self._pos_to_time(self._token_positions[tok_end]) + self._global_time_offset
+        return start_time, end_time
+
+    def _extract_new_words(self) -> List[ASRToken]:
+        """
+        Split accumulated text into words and return new complete words
+        (all but the last, which may still be growing).
+        """
+        if not self._output_token_ids:
+            return []
+
+        words = self._get_full_text().split()
+
+        new_words: List[ASRToken] = []
+        # All words except the last are guaranteed complete
+        while len(words) > self._n_committed_words + 1:
+            word = words[self._n_committed_words]
+            start_time, end_time = self._estimate_word_times(
+                self._n_committed_words, len(words)
+            )
+            # Prepend space to match Whisper convention (Segment.from_tokens joins with '')
+            text = word if self._n_committed_words == 0 else " " + word
+            new_words.append(ASRToken(start=start_time, end=end_time, text=text))
+            self._n_committed_words += 1
+
+        return new_words
+
+    def _flush_all_pending_words(self) -> List[ASRToken]:
+        """Flush ALL words, including the last, possibly partial, one."""
+        if not self._output_token_ids:
+            return []
+
+        words = self._get_full_text().split()
+        n_words_total = max(len(words), 1)
+
+        new_words: List[ASRToken] = []
+        while self._n_committed_words < len(words):
+            word = words[self._n_committed_words]
+            start_time, end_time = self._estimate_word_times(
+                self._n_committed_words, n_words_total
+            )
+            # Prepend space to match Whisper convention (Segment.from_tokens joins with '')
+            text = word if self._n_committed_words == 0 else " " + word
+            new_words.append(ASRToken(start=start_time, end=end_time, text=text))
+            self._n_committed_words += 1
+
+        return new_words
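+
+    # Worked example of the proportional mapping above (illustrative
+    # numbers, not from a real run): with 12 accumulated tokens and 4
+    # words, word index 1 maps to token range [3, 6); and with 80 ms per
+    # decode position, _pos_to_time(35) = (35 - N_LEFT_PAD_TOKENS) * SPT
+    # / SAMPLING_RATE = 3 * 0.08 s = 0.24 s.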
+
+    # ── Core streaming logic ──
+
+    def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]:
+        import mlx.core as mx
+
+        from voxmlx.audio import log_mel_spectrogram_step
+        from voxmlx.cache import RotatingKVCache
+
+        model = self.asr.model
+        sp = self.asr._tokenizer
+        prompt_tokens = self.asr._prompt_tokens
+        prefix_len = len(prompt_tokens)
+        SPT = self._samples_per_token
+
+        # ── Phase 1: Encode new audio ──
+        if len(self._pending_audio) >= SPT:
+            n_feed = (len(self._pending_audio) // SPT) * SPT
+            chunk = self._pending_audio[:n_feed]
+            self._pending_audio = self._pending_audio[n_feed:]
+            self._n_audio_samples_fed += n_feed
+
+            if self._first_cycle:
+                # Left-pad the very first chunk so the encoder has context
+                # before the first real samples (mirrored by the right pad
+                # applied in finish()).
+                chunk = np.concatenate(
+                    [np.zeros(N_LEFT_PAD_TOKENS * SPT, dtype=np.float32), chunk]
+                )
+
+            mel, self._audio_tail = log_mel_spectrogram_step(
+                chunk, self._audio_tail
+            )
+            (
+                new_embeds,
+                self._conv1_tail,
+                self._conv2_tail,
+                self._encoder_cache,
+                self._ds_buf,
+            ) = model.encode_step(
+                mel,
+                self._conv1_tail,
+                self._conv2_tail,
+                self._encoder_cache,
+                self._ds_buf,
+            )
+            if new_embeds is not None:
+                mx.eval(new_embeds)
+                if self._audio_embeds is not None:
+                    self._audio_embeds = mx.concatenate(
+                        [self._audio_embeds, new_embeds]
+                    )
+                else:
+                    self._audio_embeds = new_embeds
+                if self._first_cycle:
+                    logger.info(
+                        f"[voxtral] first encode: {new_embeds.shape[0]} embeds from {n_feed} samples"
+                    )
+            elif self._first_cycle:
+                logger.info(f"[voxtral] first encode: no embeds from {n_feed} samples")
+            self._first_cycle = False
+
+        self.audio_buffer = self._pending_audio  # for logging compat
+
+        if self._audio_embeds is None:
+            return [], self.end
+
+        # Safety: don't decode ahead of encoded audio
+        safe_total = (
+            N_LEFT_PAD_TOKENS + self._n_audio_samples_fed // SPT
+        )
+        n_decodable = min(
+            self._audio_embeds.shape[0], safe_total - self._n_total_decoded
+        )
+
+        if n_decodable <= 0:
+            return [], self.end
+
+        # ── Phase 2: Prefill (once per utterance) ──
+        if not self._prefilled:
+            if self._n_total_decoded + self._audio_embeds.shape[0] < prefix_len:
+                logger.info(
+                    f"[voxtral] waiting for prefill: have {self._audio_embeds.shape[0]} embeds, need {prefix_len}"
+                )
+                return [], self.end
+
+            n_layers = len(model.language_model.layers)
+            self._decoder_cache = [RotatingKVCache(8192) for _ in range(n_layers)]
+
+            self._t_cond = model.time_embedding(
+                mx.array([self.asr._n_delay_tokens], dtype=mx.float32)
+            )
+
+            prompt_ids = mx.array([prompt_tokens])
+            self._text_embeds = model.language_model.embed(prompt_ids)[0]
+
+            prefix_embeds = (
+                self._text_embeds + self._audio_embeds[:prefix_len]
+            )[None, :, :]
+
+            logits = model.decode(
+                prefix_embeds, self._t_cond, "causal", self._decoder_cache
+            )
+            mx.eval(
+                logits,
+                *[x for c in self._decoder_cache for x in (c.keys, c.values)],
+            )
+
+            self._y = mx.argmax(logits[0, -1:], axis=-1).squeeze()
+            mx.async_eval(self._y)
+
+            self._audio_embeds = self._audio_embeds[prefix_len:]
+            self._n_total_decoded = prefix_len
+            self._prefilled = True
+            logger.info(f"[voxtral] prefill done, first token y={self._y.item()}")
+
+            n_decodable = min(
+                self._audio_embeds.shape[0], safe_total - self._n_total_decoded
+            )
+
+            if n_decodable <= 0:
+                return [], self.end
+
+        # ── Phase 3: Decode new positions ──
+        eos_id = sp.eos_id
+        hit_eos = False
+        n_consumed = 0
+
+        for i in range(n_decodable):
+            token_embed = model.language_model.embed(self._y.reshape(1, 1))[0, 0]
+            step_embed = (self._audio_embeds[i] + token_embed)[None, None, :]
+            logits = model.decode(
+                step_embed, self._t_cond, mask=None, cache=self._decoder_cache
+            )
+            next_y = mx.argmax(logits[0, -1:], axis=-1).squeeze()
+            mx.async_eval(next_y)
+
+            token_id = self._y.item()
+            n_consumed = i + 1
+
+            if token_id == eos_id:
+                hit_eos = True
+                logger.info("[voxtral] hit EOS")
+                break
+
+            # Accumulate token ID — full-sequence decode produces correct spacing
+            # Skip if this _y was already flushed by start_silence()
+            if self._y_flushed_to_output:
+                self._y_flushed_to_output = False
+            else:
+                self._output_token_ids.append(token_id)
+                # Track position for timestamp estimation
+                pos = self._n_total_decoded + i
+                self._token_positions.append(pos)
+
+            if i > 0 and i % 256 == 0:
+                mx.clear_cache()
+
+            self._y = next_y
+
+        self._n_total_decoded += n_consumed
+
+        # Trim consumed embeddings
+        if self._audio_embeds.shape[0] > n_consumed:
+            self._audio_embeds = self._audio_embeds[n_consumed:]
+        else:
+            self._audio_embeds = None
+
+        # Log decode results
+        full_text = self._get_full_text()
+        logger.info(
+            f"[voxtral] decoded {n_consumed} tokens | "
+            f"total_decoded={self._n_total_decoded} | "
+            f"text='{full_text[-80:]}' | "
+            f"n_words={len(full_text.split())} committed={self._n_committed_words}"
+        )
+
+        # Extract complete words from the decoded token sequence
+        new_words = self._extract_new_words()
+
+        if hit_eos:
+            new_words.extend(self._flush_all_pending_words())
+            self._reset_state()
+
+        if new_words:
+            logger.info(f"[voxtral] returning {len(new_words)} words: {[w.text for w in new_words]}")
+
+        self.buffer = []
+        return new_words, self.end
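
Note: a quick smoke test for the routing added in core.py (a minimal sketch;
SimpleNamespace stands in for the parsed CLI args, and the model is
downloaded on first construction):

    from types import SimpleNamespace

    import numpy as np

    from whisperlivekit.core import online_factory
    from whisperlivekit.voxtral_streaming import VoxtralStreamingASR

    # backend="voxtral-mlx" short-circuits online_factory before
    # backend_policy is consulted; lan="en" pins the language instead
    # of the default "auto".
    asr = VoxtralStreamingASR(lan="en")
    args = SimpleNamespace(backend="voxtral-mlx", backend_policy="simulstreaming")
    online = online_factory(args, asr)

    # One second of 16 kHz mono float32 PCM.
    online.insert_audio_chunk(np.zeros(16000, dtype=np.float32), 1.0)
    words, end = online.process_iter()
    final_words, _ = online.finish()
    print([w.text for w in words + final_words])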