From 3c15246fc0566d10064307b3cf41fe42c1003e89 Mon Sep 17 00:00:00 2001
From: Quentin Fuxa
Date: Fri, 20 Feb 2026 20:46:37 +0100
Subject: [PATCH] voxtral hf v0

---
 pyproject.toml                         |   1 +
 whisperlivekit/backend_support.py      |   6 +
 whisperlivekit/core.py                 |  10 +-
 whisperlivekit/parse_args.py           |   4 +-
 whisperlivekit/voxtral_hf_streaming.py | 389 +++++++++++++++++++++++++
 5 files changed, 407 insertions(+), 3 deletions(-)
 create mode 100644 whisperlivekit/voxtral_hf_streaming.py

diff --git a/pyproject.toml b/pyproject.toml
index f8a4527..74ade12 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,7 @@ dependencies = [
 [project.optional-dependencies]
 translation = ["nllw"]
 sentence_tokenizer = ["mosestokenizer", "wtpsplit"]
+voxtral-hf = ["transformers>=5.2.0", "mistral-common[audio]"]
 
 [project.urls]
 Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
diff --git a/whisperlivekit/backend_support.py b/whisperlivekit/backend_support.py
index a64770a..32733cf 100644
--- a/whisperlivekit/backend_support.py
+++ b/whisperlivekit/backend_support.py
@@ -29,6 +29,12 @@ def mlx_backend_available(warn_on_missing = False):
     return available
 
 
+def voxtral_hf_backend_available():
+    """Return True if the HF Transformers Voxtral backend is available."""
+    return module_available("transformers")
+
+
+
 def faster_backend_available(warn_on_missing = False):
     available = module_available("faster_whisper")
     if not available and warn_on_missing and platform.system() != "Darwin":
diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py
index f96fa3c..7cf1041 100644
--- a/whisperlivekit/core.py
+++ b/whisperlivekit/core.py
@@ -92,7 +92,12 @@ class TranscriptionEngine:
         }
 
         if config.transcription:
-            if config.backend_policy == "simulstreaming":
+            if config.backend == "voxtral":
+                from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingASR
+                self.tokenizer = None
+                self.asr = VoxtralHFStreamingASR(**transcription_common_params)
+                logger.info("Using Voxtral HF Transformers streaming backend")
+            elif config.backend_policy == "simulstreaming":
                 simulstreaming_params = {
                     "disable_fast_encoder": config.disable_fast_encoder,
                     "custom_alignment_heads": config.custom_alignment_heads,
@@ -164,6 +169,9 @@ class TranscriptionEngine:
 
 
 def online_factory(args, asr):
+    if getattr(args, 'backend', None) == "voxtral":
+        from whisperlivekit.voxtral_hf_streaming import VoxtralHFStreamingOnlineProcessor
+        return VoxtralHFStreamingOnlineProcessor(asr)
     if args.backend_policy == "simulstreaming":
         from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
         return SimulStreamingOnlineProcessor(asr)
diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py
index 0f5f394..d89aaca 100644
--- a/whisperlivekit/parse_args.py
+++ b/whisperlivekit/parse_args.py
@@ -147,8 +147,8 @@ def parse_args():
         "--backend",
         type=str,
         default="auto",
-        choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"],
-        help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper).",
+        choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral"],
+        help="Select the ASR backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'voxtral' for Voxtral streaming via HuggingFace Transformers (CUDA/CPU/MPS).",
     )
     parser.add_argument(
         "--no-vac",
diff --git a/whisperlivekit/voxtral_hf_streaming.py b/whisperlivekit/voxtral_hf_streaming.py
new file mode 100644
index 0000000..2fee95f
--- /dev/null
+++ b/whisperlivekit/voxtral_hf_streaming.py
@@ -0,0 +1,389 @@
+"""
+Voxtral Mini Realtime streaming backend using HuggingFace Transformers.
+
+Uses VoxtralRealtimeForConditionalGeneration with a background generate thread
+and queue-based audio feeding for real-time streaming transcription.
+Supports CUDA, CPU, and MPS devices.
+"""
+
+import logging
+import queue
+import sys
+import threading
+import time
+from typing import List, Optional, Tuple
+
+import numpy as np
+
+from whisperlivekit.timed_objects import ASRToken, Transcript
+
+logger = logging.getLogger(__name__)
+
+
+class VoxtralHFStreamingASR:
+    """Voxtral model holder using HuggingFace Transformers."""
+
+    sep = " "
+
+    def __init__(self, logfile=sys.stderr, **kwargs):
+        import torch
+        from transformers import (
+            AutoProcessor,
+            VoxtralRealtimeForConditionalGeneration,
+        )
+
+        self.logfile = logfile
+        self.transcribe_kargs = {}
+
+        lan = kwargs.get("lan", "auto")
+        self.original_language = None if lan == "auto" else lan
+
+        DEFAULT_MODEL = "mistralai/Voxtral-Mini-4B-Realtime-2602"
+        model_path = kwargs.get("model_dir") or kwargs.get("model_path")
+        if not model_path:
+            model_size = kwargs.get("model_size", "")
+            if model_size and ("/" in model_size or model_size.startswith(".")):
+                model_path = model_size
+            else:
+                model_path = DEFAULT_MODEL
+
+        t = time.time()
+        logger.info(f"Loading Voxtral model '{model_path}' via HF Transformers...")
+        self.processor = AutoProcessor.from_pretrained(model_path)
+        self.model = VoxtralRealtimeForConditionalGeneration.from_pretrained(
+            model_path,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+        )
+        logger.info(f"Voxtral HF model loaded in {time.time() - t:.2f}s on {self.model.device}")
+
+        self.backend_choice = "voxtral"
+        self.tokenizer = None  # sentence tokenizer; not needed for streaming
+
+    def transcribe(self, audio):
+        pass
+
+
+class VoxtralHFStreamingOnlineProcessor:
+    """
+    Online processor for Voxtral streaming ASR via HuggingFace Transformers.
+
+    Uses a background thread running model.generate() with a queue-based
+    input_features_generator and TextIteratorStreamer for real-time output.
+    Each decoded token corresponds to ~80ms of audio.
+    """
+
+    SAMPLING_RATE = 16000
+
+    def __init__(self, asr: VoxtralHFStreamingASR, logfile=sys.stderr):
+        self.asr = asr
+        self.logfile = logfile
+        self.end = 0.0
+        self.buffer = []
+        self.audio_buffer = np.array([], dtype=np.float32)
+
+        processor = asr.processor
+        self._first_chunk_samples = processor.num_samples_first_audio_chunk
+        self._chunk_samples = processor.num_samples_per_audio_chunk
+        self._chunk_step = processor.num_samples_per_audio_chunk_step
+        self._right_pad_samples = int(
+            processor.num_right_pad_tokens * processor.raw_audio_length_per_tok
+        )
+        self._seconds_per_token = processor.raw_audio_length_per_tok / self.SAMPLING_RATE
+
+        self._reset_state()
+
+        logger.info(
+            f"[voxtral-hf] Initialized. first_chunk={self._first_chunk_samples} samples, "
+            f"chunk={self._chunk_samples}, step={self._chunk_step}, "
+            f"right_pad={self._right_pad_samples}"
+        )
+
+    def _reset_state(self):
+        self._pending_audio = np.zeros(0, dtype=np.float32)
+        self._audio_queue: queue.Queue = queue.Queue()
+        self._generate_thread: Optional[threading.Thread] = None
+        self._generate_started = False
+        self._generate_finished = False
+        self._generate_error: Optional[Exception] = None
+
+        # Text accumulation and word extraction
+        self._accumulated_text = ""
+        self._n_text_tokens_received = 0
+        self._n_committed_words = 0
+        self._global_time_offset = 0.0
+
+        # Lock for text state accessed from both generate thread and main thread
+        self._text_lock = threading.Lock()
+
+    # ── Interface methods ──
+
+    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
+        self.end = audio_stream_end_time
+        self._pending_audio = np.append(self._pending_audio, audio)
+        self.audio_buffer = self._pending_audio
+
+    def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
+        try:
+            return self._process_iter_inner(is_last)
+        except Exception as e:
+            logger.warning(f"[voxtral-hf] process_iter exception: {e}", exc_info=True)
+            return [], self.end
+
+    def get_buffer(self) -> Transcript:
+        """Return all uncommitted text as buffer."""
+        with self._text_lock:
+            text = self._accumulated_text
+        if not text:
+            return Transcript(start=None, end=None, text="")
+
+        words = text.split()
+        uncommitted = words[self._n_committed_words:]
+        if uncommitted:
+            return Transcript(start=self.end, end=self.end, text=" ".join(uncommitted))
+        return Transcript(start=None, end=None, text="")
+
+    def start_silence(self) -> Tuple[List[ASRToken], float]:
+        """Flush all uncommitted words when silence starts."""
+        self._drain_streamer()
+        words = self._flush_all_pending_words()
+        logger.info(f"[voxtral-hf] start_silence: flushed {len(words)} words")
+        return words, self.end
+
+    def end_silence(self, silence_duration: float, offset: float):
+        self._global_time_offset += silence_duration
+        self.end += silence_duration
+
+    def new_speaker(self, change_speaker):
+        self.start_silence()
+
+    def warmup(self, audio, init_prompt=""):
+        pass
+
+    def finish(self) -> Tuple[List[ASRToken], float]:
+        """Flush remaining audio with right-padding and stop the generate thread."""
+        # Add right-padding so the model can finish decoding
+        if self._right_pad_samples > 0:
+            right_pad = np.zeros(self._right_pad_samples, dtype=np.float32)
+            self._pending_audio = np.append(self._pending_audio, right_pad)
+
+        # Feed remaining audio
+        if self._generate_started and not self._generate_finished:
+            self._feed_pending_audio()
+            # Signal end of audio
+            self._audio_queue.put(None)
+            # Wait for generate to finish
+            if self._generate_thread is not None:
+                self._generate_thread.join(timeout=30.0)
+        elif not self._generate_started and len(self._pending_audio) >= self._first_chunk_samples:
+            # Never started but enough audio is buffered; start and immediately finish
+            self._start_generate_thread()
+            self._feed_pending_audio()
+            self._audio_queue.put(None)
+            if self._generate_thread is not None:
+                self._generate_thread.join(timeout=30.0)
+
+        self._drain_streamer()
+        words = self._flush_all_pending_words()
+        logger.info(f"[voxtral-hf] finish: flushed {len(words)} words")
+        return words, self.end
+
+    # ── Generate thread management ──
+
+    def _start_generate_thread(self):
+        """Start model.generate() in a background thread with streaming."""
+        import torch
+        from transformers import TextIteratorStreamer
+
+        processor = self.asr.processor
+        model = self.asr.model
+
+        # Extract first chunk
+        first_chunk_audio = self._pending_audio[:self._first_chunk_samples]
+        self._pending_audio = self._pending_audio[self._first_chunk_samples:]
+
+        first_inputs = processor(
+            first_chunk_audio,
+            is_streaming=True,
+            is_first_audio_chunk=True,
+            return_tensors="pt",
+        )
+        first_inputs = first_inputs.to(model.device, dtype=model.dtype)
+
+        streamer = TextIteratorStreamer(
+            processor.tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=True,
+        )
+        self._streamer = streamer
+
+        audio_queue = self._audio_queue
+
+        def input_features_gen():
+            yield first_inputs.input_features
+            while True:
+                chunk_audio = audio_queue.get()
+                if chunk_audio is None:
+                    break
+                inputs = processor(
+                    chunk_audio,
+                    is_streaming=True,
+                    is_first_audio_chunk=False,
+                    return_tensors="pt",
+                )
+                inputs = inputs.to(model.device, dtype=model.dtype)
+                yield inputs.input_features
+
+        def run_generate():
+            try:
+                with torch.no_grad():
+                    model.generate(
+                        input_features_generator=input_features_gen(),
+                        streamer=streamer,
+                        **first_inputs,
+                    )
+            except Exception as e:
+                logger.error(f"[voxtral-hf] generate error: {e}", exc_info=True)
+                self._generate_error = e
+            finally:
+                self._generate_finished = True
+
+        self._generate_thread = threading.Thread(target=run_generate, daemon=True)
+        self._generate_thread.start()
+        self._generate_started = True
+        logger.info("[voxtral-hf] generate thread started")
+
+    def _feed_pending_audio(self):
+        """Convert pending audio into properly-sized chunks for the generator."""
+        chunk_size = self._chunk_samples
+        step_size = self._chunk_step
+
+        while len(self._pending_audio) >= chunk_size:
+            chunk = self._pending_audio[:chunk_size]
+            self._audio_queue.put(chunk)
+            self._pending_audio = self._pending_audio[step_size:]
+
+        self.audio_buffer = self._pending_audio
+
+    def _drain_streamer(self):
+        """Non-blocking drain of all available text from the streamer."""
+        if not self._generate_started:
+            return
+
+        streamer = self._streamer
+        # Read the underlying queue directly: iterating the streamer would
+        # block on an empty queue, and this method must never block.
+        while not streamer.text_queue.empty():
+            try:
+                text_fragment = streamer.text_queue.get_nowait()
+            except queue.Empty:
+                break
+            if text_fragment == streamer.stop_signal:
+                break
+            if text_fragment:
+                with self._text_lock:
+                    self._accumulated_text += text_fragment
+                    self._n_text_tokens_received += 1
+
+    # ── Word extraction ──
+
+    def _pos_to_time(self, token_position: int) -> float:
+        """Convert token position to seconds."""
+        return token_position * self._seconds_per_token + self._global_time_offset
+
+    def _extract_new_words(self) -> List[ASRToken]:
+        """Extract complete words (all but the last, which may still be growing)."""
+        with self._text_lock:
+            text = self._accumulated_text
+        if not text:
+            return []
+
+        words = text.split()
+        new_words: List[ASRToken] = []
+        n_tokens = self._n_text_tokens_received
+        n_words_total = len(words)
+
+        while len(words) > self._n_committed_words + 1:
+            word = words[self._n_committed_words]
+            word_idx = self._n_committed_words
+
+            tok_start = int(word_idx / n_words_total * n_tokens) if n_words_total > 0 else 0
+            tok_end = int((word_idx + 1) / n_words_total * n_tokens) if n_words_total > 0 else 0
+
+            start_time = self._pos_to_time(tok_start)
+            end_time = self._pos_to_time(tok_end)
+
+            text_out = word if self._n_committed_words == 0 else " " + word
+            new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
+            self._n_committed_words += 1
+
+        return new_words
+
+    def _flush_all_pending_words(self) -> List[ASRToken]:
+        """Flush ALL words including the last partial one."""
+        with self._text_lock:
+            text = self._accumulated_text
+        if not text:
+            return []
+
+        words = text.split()
+        new_words: List[ASRToken] = []
+        n_tokens = max(self._n_text_tokens_received, 1)
+        n_words_total = max(len(words), 1)
+
+        while self._n_committed_words < len(words):
+            word = words[self._n_committed_words]
+            word_idx = self._n_committed_words
+
+            tok_start = int(word_idx / n_words_total * n_tokens)
+            tok_end = int((word_idx + 1) / n_words_total * n_tokens)
+
+            start_time = self._pos_to_time(tok_start)
+            end_time = self._pos_to_time(tok_end)
+
+            text_out = word if self._n_committed_words == 0 else " " + word
+            new_words.append(ASRToken(start=start_time, end=end_time, text=text_out))
+            self._n_committed_words += 1
+
+        return new_words
+
+    # ── Core processing ──
+
+    def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]:
+        # Start generate thread when enough audio is buffered
+        if not self._generate_started:
+            if len(self._pending_audio) >= self._first_chunk_samples:
+                self._start_generate_thread()
+                self._feed_pending_audio()
+            else:
+                return [], self.end
+
+        # Feed any new pending audio
+        if self._generate_started and not self._generate_finished:
+            self._feed_pending_audio()
+
+        # If generate finished unexpectedly (EOS) but new audio arrived, restart
+        if self._generate_finished and len(self._pending_audio) >= self._first_chunk_samples:
+            self._drain_streamer()
+            flush_words = self._flush_all_pending_words()
+            # Reset for new utterance, preserving time offset and pending audio
+            old_offset = self._global_time_offset
+            pending_audio = self._pending_audio  # _reset_state() clears this
+            self._reset_state()
+            self._global_time_offset = old_offset
+            self._pending_audio = pending_audio
+            self._start_generate_thread()
+            self._feed_pending_audio()
+            return flush_words, self.end
+
+        # Drain available text from streamer
+        self._drain_streamer()
+
+        # Extract complete words
+        new_words = self._extract_new_words()
+
+        if new_words:
+            logger.info(f"[voxtral-hf] returning {len(new_words)} words: {[w.text for w in new_words]}")
+
+        self.buffer = []
+        return new_words, self.end
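
Usage sketch (not part of the patch): assuming the package's existing
whisperlivekit-server entry point, and assuming --model is the flag that
populates the model_size kwarg which VoxtralHFStreamingASR inspects (a "/"
makes it treat the value as a repo id), the backend would be exercised like
this. The model id is the DEFAULT_MODEL hard-coded in the new module.

    pip install "whisperlivekit[voxtral-hf]"
    whisperlivekit-server --backend voxtral \
        --model mistralai/Voxtral-Mini-4B-Realtime-2602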
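
For direct testing without the server, a minimal driver sketch using only the
classes added by this patch (the audio here is synthetic silence; a real
caller would feed 16 kHz float32 PCM and meaningful stream times):

    import numpy as np
    from whisperlivekit.voxtral_hf_streaming import (
        VoxtralHFStreamingASR,
        VoxtralHFStreamingOnlineProcessor,
    )

    asr = VoxtralHFStreamingASR(lan="auto")          # loads the 4B model
    online = VoxtralHFStreamingOnlineProcessor(asr)

    sr = VoxtralHFStreamingOnlineProcessor.SAMPLING_RATE
    t = 0.0
    for _ in range(20):
        chunk = np.zeros(sr // 2, dtype=np.float32)  # 0.5 s per chunk
        t += 0.5
        online.insert_audio_chunk(chunk, t)
        words, end = online.process_iter()           # commits all but the last word
        for w in words:
            print(f"{w.start:6.2f}-{w.end:6.2f} {w.text!r}")

    words, end = online.finish()                     # right-pads, joins the thread
    print("final:", "".join(w.text for w in words))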
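
On the word timestamps: _extract_new_words interpolates linearly over the
count of streamer fragments, so with 10 words seen after 50 fragments the
5th word (index 4) gets tok_start = int(4/10 * 50) = 20 and tok_end = 25,
i.e. roughly 1.60 s to 2.00 s at the ~80 ms per token the module docstring
mentions, plus the accumulated silence offset. Note that TextIteratorStreamer
releases text at word boundaries rather than one fragment per decoded token,
so _n_text_tokens_received only approximates the token count and these
timings are coarse.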