mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Compare commits
10 Commits
0.2.17.pos
...
feature/vo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ea507ed8e | ||
|
|
e7e82f7c19 | ||
|
|
8c799fa4d1 | ||
|
|
8923337380 | ||
|
|
aded1649ae | ||
|
|
3b535e857a | ||
|
|
d649250b9a | ||
|
|
7735478286 | ||
|
|
b9e72d2b9a | ||
|
|
e5b01033af |
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "whisperlivekit"
|
||||
version = "0.2.17.post1"
|
||||
version = "0.2.18"
|
||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
@@ -69,4 +69,5 @@ packages = [
|
||||
[tool.setuptools.package-data]
|
||||
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
||||
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||
"whisperlivekit.whisper.normalizers" = ["*.json"]
|
||||
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]
|
||||
|
||||
@@ -29,6 +29,13 @@ def mlx_backend_available(warn_on_missing = False):
|
||||
return available
|
||||
|
||||
|
||||
def voxmlx_backend_available():
    """Tell whether the Voxtral MLX backend (``voxmlx``) can be used here.

    The backend is reported as available only when the host reports
    ``Darwin`` as its system, ``arm64`` as its machine type, and the
    ``voxmlx`` module is importable according to ``module_available``.
    """
    on_apple_silicon = (
        platform.system() == "Darwin" and platform.machine() == "arm64"
    )
    return on_apple_silicon and module_available("voxmlx")
|
||||
|
||||
|
||||
def faster_backend_available(warn_on_missing = False):
|
||||
available = module_available("faster_whisper")
|
||||
if not available and warn_on_missing and platform.system() != "Darwin":
|
||||
|
||||
@@ -104,7 +104,12 @@ class TranscriptionEngine:
|
||||
)
|
||||
backend_policy = self.args.backend_policy
|
||||
if self.args.transcription:
|
||||
if backend_policy == "simulstreaming":
|
||||
if self.args.backend == "voxtral-mlx":
|
||||
from whisperlivekit.voxtral_streaming import VoxtralStreamingASR
|
||||
self.tokenizer = None
|
||||
self.asr = VoxtralStreamingASR(**transcription_common_params)
|
||||
logger.info("Using Voxtral MLX streaming backend")
|
||||
elif backend_policy == "simulstreaming":
|
||||
simulstreaming_params = {
|
||||
"disable_fast_encoder": False,
|
||||
"custom_alignment_heads": None,
|
||||
@@ -186,6 +191,9 @@ class TranscriptionEngine:
|
||||
|
||||
|
||||
def online_factory(args, asr):
|
||||
if getattr(args, 'backend', None) == "voxtral-mlx":
|
||||
from whisperlivekit.voxtral_streaming import VoxtralStreamingOnlineProcessor
|
||||
return VoxtralStreamingOnlineProcessor(asr)
|
||||
if args.backend_policy == "simulstreaming":
|
||||
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||
return SimulStreamingOnlineProcessor(asr)
|
||||
|
||||
@@ -202,14 +202,14 @@ class DiartDiarization:
|
||||
def insert_silence(self, silence_duration):
|
||||
self.observer.global_time_offset += silence_duration
|
||||
|
||||
async def diarize(self, pcm_array: np.ndarray):
|
||||
"""
|
||||
Process audio data for diarization.
|
||||
Only used when working with WebSocketAudioSource.
|
||||
"""
|
||||
def insert_audio_chunk(self, pcm_array: np.ndarray):
|
||||
"""Buffer audio for the next diarization step."""
|
||||
if self.custom_source:
|
||||
self.custom_source.push_audio(pcm_array)
|
||||
# self.observer.clear_old_segments()
|
||||
self.custom_source.push_audio(pcm_array)
|
||||
|
||||
async def diarize(self):
|
||||
"""Return the current speaker segments from the diarization pipeline."""
|
||||
return self.observer.get_segments()
|
||||
|
||||
def close(self):
|
||||
"""Close the audio source."""
|
||||
|
||||
@@ -151,7 +151,7 @@ class FasterWhisperASR(ASRBase):
|
||||
if segment.no_speech_prob > 0.9:
|
||||
continue
|
||||
for word in segment.words:
|
||||
token = ASRToken(word.start, word.end, word.word)
|
||||
token = ASRToken(word.start, word.end, word.word, probability=word.probability)
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
@@ -249,6 +249,7 @@ class OpenaiApiASR(ASRBase):
|
||||
self.load_model()
|
||||
self.use_vad_opt = False
|
||||
self.direct_english_translation = False
|
||||
self.task = "transcribe"
|
||||
|
||||
def load_model(self, *args, **kwargs):
|
||||
from openai import OpenAI
|
||||
@@ -294,7 +295,8 @@ class OpenaiApiASR(ASRBase):
|
||||
params["language"] = self.original_language
|
||||
if prompt:
|
||||
params["prompt"] = prompt
|
||||
proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
|
||||
task = self.transcribe_kargs.get("task", self.task)
|
||||
proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions
|
||||
transcript = proc.create(**params)
|
||||
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
|
||||
return transcript
|
||||
|
||||
@@ -146,6 +146,7 @@ def backend_factory(
|
||||
|
||||
if direct_english_translation:
|
||||
tgt_language = "en" # Whisper translates into English
|
||||
asr.transcribe_kargs["task"] = "translate"
|
||||
else:
|
||||
tgt_language = lan # Whisper transcribes in this language
|
||||
|
||||
@@ -154,9 +155,9 @@ def backend_factory(
|
||||
tokenizer = create_tokenizer(tgt_language)
|
||||
else:
|
||||
tokenizer = None
|
||||
|
||||
|
||||
warmup_asr(asr, warmup_file)
|
||||
|
||||
|
||||
asr.confidence_validation = confidence_validation
|
||||
asr.tokenizer = tokenizer
|
||||
asr.buffer_trimming = buffer_trimming
|
||||
|
||||
@@ -147,8 +147,8 @@ def parse_args():
|
||||
"--backend",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"],
|
||||
help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API.",
|
||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral-mlx"],
|
||||
help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API. Use 'voxtral-mlx' for Voxtral streaming on Apple Silicon.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-vac",
|
||||
|
||||
@@ -46,8 +46,6 @@ class SimulStreamingOnlineProcessor:
|
||||
self.logfile = logfile
|
||||
self.end = 0.0
|
||||
self.buffer = []
|
||||
self.committed: List[ASRToken] = []
|
||||
self.last_result_tokens: List[ASRToken] = []
|
||||
self.model = self._create_alignatt()
|
||||
|
||||
if asr.tokenizer:
|
||||
@@ -122,7 +120,6 @@ class SimulStreamingOnlineProcessor:
|
||||
self.buffer.extend(timestamped_words)
|
||||
return [], self.end
|
||||
|
||||
self.committed.extend(timestamped_words)
|
||||
self.buffer = []
|
||||
return timestamped_words, self.end
|
||||
except Exception as e:
|
||||
@@ -217,7 +214,7 @@ class SimulStreamingASR:
|
||||
cif_ckpt_path=self.cif_ckpt_path,
|
||||
decoder_type="beam",
|
||||
beam_size=self.beams,
|
||||
task=self.direct_english_translation,
|
||||
task="translate" if self.direct_english_translation else "transcribe",
|
||||
never_fire=self.never_fire,
|
||||
init_prompt=self.init_prompt,
|
||||
max_context_tokens=self.max_context_tokens,
|
||||
@@ -330,7 +327,7 @@ class SimulStreamingASR:
|
||||
lora_path = getattr(self, 'lora_path', None)
|
||||
whisper_model = load_model(
|
||||
name=model_ref,
|
||||
download_root=None,
|
||||
download_root=getattr(self, 'model_cache_dir', None),
|
||||
decoder_only=self.fast_encoder,
|
||||
custom_alignment_heads=self.custom_alignment_heads,
|
||||
lora_path=lora_path,
|
||||
|
||||
@@ -532,7 +532,9 @@ class MLXAlignAtt:
|
||||
accumulated_cross_attns = []
|
||||
|
||||
audio_duration_s = self.segments_len()
|
||||
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0))
|
||||
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
|
||||
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
|
||||
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
|
||||
tokens_produced_this_chunk = 0
|
||||
|
||||
while not completed and current_tokens.shape[1] < self.max_text_len:
|
||||
@@ -558,6 +560,8 @@ class MLXAlignAtt:
|
||||
mx.eval(logits)
|
||||
|
||||
accumulated_cross_attns.append(cross_qk)
|
||||
if len(accumulated_cross_attns) > 16:
|
||||
accumulated_cross_attns = accumulated_cross_attns[-16:]
|
||||
|
||||
if new_segment and self.tokenizer.no_speech is not None:
|
||||
probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
|
||||
|
||||
@@ -390,7 +390,6 @@ class AlignAtt:
|
||||
return []
|
||||
if not self._apply_minseglen():
|
||||
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||
input_segments = torch.cat(self.state.segments, dim=0)
|
||||
return []
|
||||
|
||||
# input_segments is concatenation of audio, it's one array
|
||||
@@ -485,7 +484,9 @@ class AlignAtt:
|
||||
accumulated_cross_attns = []
|
||||
|
||||
audio_duration_s = self.segments_len()
|
||||
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0)) # 2x margin, min 50
|
||||
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
|
||||
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
|
||||
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
|
||||
tokens_produced_this_chunk = 0
|
||||
|
||||
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
||||
@@ -506,8 +507,12 @@ class AlignAtt:
|
||||
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
|
||||
logits, cross_attns = result
|
||||
|
||||
# Accumulate cross-attention from this forward pass
|
||||
# Accumulate cross-attention from this forward pass (rolling window to
|
||||
# bound VRAM — only the last entry matters for alignment, and the
|
||||
# median_filter kernel is 7, so 16 entries is more than enough).
|
||||
accumulated_cross_attns.append(cross_attns)
|
||||
if len(accumulated_cross_attns) > 16:
|
||||
accumulated_cross_attns = accumulated_cross_attns[-16:]
|
||||
|
||||
if new_segment and self.tokenizer.no_speech is not None:
|
||||
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
|
||||
|
||||
@@ -39,10 +39,11 @@ class TimedText(Timed):
|
||||
|
||||
@dataclass()
|
||||
class ASRToken(TimedText):
|
||||
|
||||
probability: Optional[float] = None
|
||||
|
||||
def with_offset(self, offset: float) -> "ASRToken":
|
||||
"""Return a new token with the time offset added."""
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language, probability=self.probability)
|
||||
|
||||
def is_silence(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -53,7 +53,8 @@ class TokensAlignment:
|
||||
segment.translation = ''
|
||||
for ts in self.all_translation_segments:
|
||||
if ts.is_within(segment):
|
||||
segment.translation += ts.text + (self.sep if ts.text else '')
|
||||
if ts.text:
|
||||
segment.translation += ts.text + self.sep
|
||||
elif segment.translation:
|
||||
break
|
||||
|
||||
@@ -185,11 +186,11 @@ class TokensAlignment:
|
||||
else:
|
||||
diarization_buffer = ''
|
||||
for token in self.new_tokens:
|
||||
if token.is_silence():
|
||||
if isinstance(token, Silence):
|
||||
if self.current_line_tokens:
|
||||
self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
|
||||
self.validated_segments.append(Segment.from_tokens(self.current_line_tokens))
|
||||
self.current_line_tokens = []
|
||||
|
||||
|
||||
end_silence = token.end if token.has_ended else time() - self.beg_loop
|
||||
if self.validated_segments and self.validated_segments[-1].is_silence():
|
||||
self.validated_segments[-1].end = end_silence
|
||||
@@ -203,7 +204,7 @@ class TokensAlignment:
|
||||
|
||||
segments = list(self.validated_segments)
|
||||
if self.current_line_tokens:
|
||||
segments.append(Segment().from_tokens(self.current_line_tokens))
|
||||
segments.append(Segment.from_tokens(self.current_line_tokens))
|
||||
|
||||
if current_silence:
|
||||
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
|
||||
|
||||
484
whisperlivekit/voxtral_streaming.py
Normal file
484
whisperlivekit/voxtral_streaming.py
Normal file
@@ -0,0 +1,484 @@
|
||||
"""
|
||||
Voxtral Mini Realtime streaming backend using voxmlx's incremental encode/decode.
|
||||
|
||||
Uses model.encode_step() for incremental audio encoding and token-by-token
|
||||
autoregressive decoding, matching voxmlx's native streaming pipeline.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
N_LEFT_PAD_TOKENS = 32
|
||||
N_RIGHT_PAD_TOKENS = 17
|
||||
|
||||
|
||||
class VoxtralStreamingASR:
    """Holds the Voxtral model, tokenizer and prompt used by the streaming pipeline.

    The heavy lifting (incremental encode/decode) is done by
    ``VoxtralStreamingOnlineProcessor``; this object only loads the model
    through ``voxmlx`` and exposes it.
    """

    sep = " "

    def __init__(self, logfile=sys.stderr, **kwargs):
        """Load the Voxtral model via voxmlx.

        Recognised keyword arguments (all optional):
            lan: language code, or ``"auto"`` (default) for no fixed language.
            model_dir / model_path: explicit model location; takes priority.
            model_size: used as the model reference only when it contains a
                ``/`` or starts with ``.`` (i.e. looks like a HF repo or a
                path rather than a Whisper size name).
        """
        from voxmlx import _build_prompt_tokens
        from voxmlx import load_model as vox_load_model

        self.logfile = logfile
        self.transcribe_kargs = {}

        language = kwargs.get("lan", "auto")
        self.original_language = language if language != "auto" else None

        model_path = kwargs.get("model_dir") or kwargs.get("model_path")
        if not model_path:
            candidate = kwargs.get("model_size", "")
            # Whisper size names ("base", "large-v3", ...) are ignored here;
            # only repo ids and filesystem paths are accepted.
            if candidate and ("/" in candidate or candidate.startswith(".")):
                model_path = candidate
            else:
                model_path = "mlx-community/Voxtral-Mini-4B-Realtime-6bit"

        started_at = time.time()
        logger.info(f"Loading Voxtral model '{model_path}' via voxmlx...")
        loaded = vox_load_model(model_path)
        self.model, self._tokenizer, self._config = loaded
        prompt_info = _build_prompt_tokens(self._tokenizer)
        self._prompt_tokens, self._n_delay_tokens = prompt_info
        logger.info(f"Voxtral model loaded in {time.time() - started_at:.2f}s")

        self.backend_choice = "voxtral-mlx"
        # Sentence tokenizer — not needed for streaming.
        self.tokenizer = None

    def transcribe(self, audio):
        """No-op; this backend has no batch transcription and returns None."""
        return None
|
||||
|
||||
|
||||
class VoxtralStreamingOnlineProcessor:
    """
    Online processor for Voxtral streaming ASR.

    Uses voxmlx's incremental encoding (encode_step) and token-by-token
    autoregressive decoding. Each decode step corresponds to 80ms of audio.

    Audio is buffered by ``insert_audio_chunk`` and consumed by
    ``process_iter`` in whole multiples of ``SAMPLES_PER_TOKEN``. Decoded
    token ids are accumulated and re-decoded as one sequence so that word
    spacing is correct; words are then committed as ``ASRToken`` objects
    with timestamps estimated from the decode position of their tokens.
    """

    SAMPLING_RATE = 16000

    def __init__(self, asr: "VoxtralStreamingASR", logfile=sys.stderr):
        """Bind to a loaded ``VoxtralStreamingASR`` and start from a clean state."""
        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy

        self.asr = asr
        self.logfile = logfile
        self.end = 0.0
        self.buffer = []
        self.audio_buffer = np.array([], dtype=np.float32)  # for logging compat
        self._special_token_policy = SpecialTokenPolicy.IGNORE
        self._reset_state()
        logger.info(
            f"[voxtral] Initialized. eos_id={asr._tokenizer.eos_id}, "
            f"prefix_len={len(asr._prompt_tokens)}, "
            f"n_delay={asr._n_delay_tokens}"
        )

    def _reset_state(self):
        """Reset every piece of encoder, decoder and word-tracking state."""
        from voxmlx.audio import SAMPLES_PER_TOKEN

        self._samples_per_token = SAMPLES_PER_TOKEN

        # Incremental encoder state
        self._audio_tail = None
        self._conv1_tail = None
        self._conv2_tail = None
        self._encoder_cache = None
        self._ds_buf = None

        # Decoder state
        self._decoder_cache = None
        self._y = None  # last sampled token (mx.array scalar)
        self._t_cond = None
        self._text_embeds = None

        # Audio / decode tracking
        self._pending_audio = np.zeros(0, dtype=np.float32)
        self._audio_embeds = None
        self._n_audio_samples_fed = 0
        self._n_total_decoded = 0
        self._first_cycle = True
        self._prefilled = False

        # Word extraction: accumulate token IDs, full-sequence decode for correct spacing
        self._output_token_ids: List[int] = []
        self._token_positions: List[int] = []  # decode position for each token
        self._n_committed_words = 0
        self._global_time_offset = 0.0
        self._y_flushed_to_output = False  # True after start_silence flushes pending _y

    # ── Interface methods (same as SimulStreamingOnlineProcessor) ──

    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
        """Queue ``audio`` for the next ``process_iter`` and record the stream end time."""
        self.end = audio_stream_end_time
        self._pending_audio = np.append(self._pending_audio, audio)
        self.audio_buffer = self._pending_audio  # for logging compat

    def process_iter(self, is_last=False) -> "Tuple[List[ASRToken], float]":
        """Run one encode/decode cycle and return ``(new_words, end_time)``.

        Any exception raised by the inner step is logged and swallowed, in
        which case an empty word list is returned.
        """
        try:
            return self._process_iter_inner(is_last)
        except Exception as e:
            logger.warning(f"[voxtral] process_iter exception: {e}", exc_info=True)
            return [], self.end

    def _get_full_text(self) -> str:
        """Decode all accumulated token IDs at once for correct spacing."""
        if not self._output_token_ids:
            return ""
        sp = self.asr._tokenizer
        return sp.decode(self._output_token_ids, special_token_policy=self._special_token_policy)

    def get_buffer(self) -> "Transcript":
        """Return all uncommitted text as buffer, including pending _y token."""
        sp = self.asr._tokenizer
        # Work on a copy so the pending token is shown without being recorded.
        ids = list(self._output_token_ids)
        if self._y is not None and not self._y_flushed_to_output:
            token_id = self._y.item()
            if token_id != sp.eos_id:
                ids.append(token_id)
        if not ids:
            return Transcript(start=None, end=None, text="")

        full_text = sp.decode(ids, special_token_policy=self._special_token_policy)
        uncommitted = full_text.split()[self._n_committed_words:]
        if uncommitted:
            return Transcript(start=self.end, end=self.end, text=" ".join(uncommitted))
        return Transcript(start=None, end=None, text="")

    def start_silence(self) -> "Tuple[List[ASRToken], float]":
        """Flush all uncommitted words when silence starts."""
        self._flush_last_y()  # Include the pending _y token before flushing
        words = self._flush_all_pending_words()
        logger.info(f"[voxtral] start_silence: flushed {len(words)} words")
        return words, self.end

    def end_silence(self, silence_duration: float, offset: float):
        """Shift the clock by ``silence_duration``; ``offset`` is accepted but unused."""
        self._global_time_offset += silence_duration
        self.end += silence_duration

    def new_speaker(self, change_speaker):
        """Flush pending words on a speaker change.

        NOTE(review): the words returned by ``start_silence()`` are discarded
        here although they are marked as committed — confirm callers do not
        need them.
        """
        self.start_silence()

    def warmup(self, audio, init_prompt=""):
        """No-op kept for interface compatibility."""
        pass

    def finish(self) -> "Tuple[List[ASRToken], float]":
        """Flush remaining audio with right-padding to let the model finish decoding."""
        right_pad = np.zeros(
            N_RIGHT_PAD_TOKENS * self._samples_per_token, dtype=np.float32
        )
        self._pending_audio = np.append(self._pending_audio, right_pad)
        # NOTE(review): the padding is counted again when the encoder consumes
        # it from _pending_audio; this only loosens the safe_total bound used
        # in _process_iter_inner. Kept as in the original.
        self._n_audio_samples_fed += len(right_pad)

        final_words, _ = self._process_iter_inner(is_last=True)
        # Flush the last pending self._y token (like voxmlx's finally block)
        self._flush_last_y()
        final_words.extend(self._flush_all_pending_words())
        return final_words, self.end

    # ── Word extraction ──

    def _pos_to_time(self, pos: int) -> float:
        """Convert a decode position to seconds relative to audio start."""
        audio_pos = pos - N_LEFT_PAD_TOKENS  # positions inside the left pad clamp to 0
        return max(0.0, audio_pos * self._samples_per_token / self.SAMPLING_RATE)

    def _flush_last_y(self):
        """Flush the last pending self._y token that hasn't been processed yet."""
        if self._y is None or self._y_flushed_to_output:
            return
        token_id = self._y.item()
        if token_id != self.asr._tokenizer.eos_id:
            self._output_token_ids.append(token_id)
            self._token_positions.append(self._n_total_decoded)
        self._y_flushed_to_output = True

    def _commit_words(self, words: List[str], limit: int) -> "List[ASRToken]":
        """Commit ``words[self._n_committed_words:limit]`` as timestamped tokens.

        Token ranges are assigned to words proportionally (an approximation),
        then mapped to time through the recorded decode positions. If no
        positions are recorded, the current global time offset is used.
        """
        committed: "List[ASRToken]" = []
        n_tokens = len(self._output_token_ids)
        n_words_total = max(len(words), 1)
        last_idx = len(self._token_positions) - 1

        while self._n_committed_words < limit:
            word_idx = self._n_committed_words
            word = words[word_idx]

            if last_idx >= 0:
                tok_start = min(int(word_idx / n_words_total * n_tokens), last_idx)
                tok_end = min(int((word_idx + 1) / n_words_total * n_tokens), last_idx)
                start_time = self._pos_to_time(self._token_positions[tok_start]) + self._global_time_offset
                end_time = self._pos_to_time(self._token_positions[tok_end]) + self._global_time_offset
            else:
                start_time = self._global_time_offset
                end_time = self._global_time_offset

            # Prepend space to match Whisper convention (Segment.from_tokens joins with '')
            text = word if word_idx == 0 else " " + word
            committed.append(ASRToken(start=start_time, end=end_time, text=text))
            self._n_committed_words += 1

        return committed

    def _extract_new_words(self) -> "List[ASRToken]":
        """
        Split accumulated text into words and return new complete words
        (all but the last, which may still be growing).
        """
        if not self._output_token_ids:
            return []
        words = self._get_full_text().split()
        return self._commit_words(words, len(words) - 1)

    def _flush_all_pending_words(self) -> "List[ASRToken]":
        """Flush ALL words including the last partial one."""
        if not self._output_token_ids:
            return []
        words = self._get_full_text().split()
        return self._commit_words(words, len(words))

    # ── Core streaming logic ──

    def _encode_pending_audio(self) -> None:
        """Feed whole-token chunks of pending audio through the incremental encoder.

        Does nothing until at least one token's worth of samples is pending.
        On the first cycle the chunk is prefixed with left padding. New
        embeddings are appended to ``self._audio_embeds``.
        """
        import mlx.core as mx

        from voxmlx.audio import log_mel_spectrogram_step

        spt = self._samples_per_token
        if len(self._pending_audio) < spt:
            return

        first_cycle = self._first_cycle
        n_feed = (len(self._pending_audio) // spt) * spt
        chunk = self._pending_audio[:n_feed]
        if first_cycle:
            left_pad = np.zeros(N_LEFT_PAD_TOKENS * spt, dtype=np.float32)
            chunk = np.concatenate([left_pad, chunk])
        self._pending_audio = self._pending_audio[n_feed:]
        self._n_audio_samples_fed += n_feed

        mel, self._audio_tail = log_mel_spectrogram_step(chunk, self._audio_tail)
        (
            new_embeds,
            self._conv1_tail,
            self._conv2_tail,
            self._encoder_cache,
            self._ds_buf,
        ) = self.asr.model.encode_step(
            mel,
            self._conv1_tail,
            self._conv2_tail,
            self._encoder_cache,
            self._ds_buf,
        )

        if new_embeds is not None:
            mx.eval(new_embeds)
            if self._audio_embeds is None:
                self._audio_embeds = new_embeds
            else:
                self._audio_embeds = mx.concatenate([self._audio_embeds, new_embeds])

        if first_cycle:
            if new_embeds is not None:
                logger.info(f"[voxtral] first encode: {new_embeds.shape[0]} embeds from {n_feed} samples")
            else:
                logger.info(f"[voxtral] first encode: no embeds from {n_feed} samples")
            self._first_cycle = False

    def _process_iter_inner(self, is_last: bool) -> "Tuple[List[ASRToken], float]":
        """Encode pending audio, prefill once, then decode every available position.

        Args:
            is_last: accepted for interface compatibility; not read here.

        Returns:
            ``(new_words, self.end)`` where ``new_words`` are the words that
            became complete during this call (all pending words if EOS was hit).
        """
        import mlx.core as mx

        from voxmlx.cache import RotatingKVCache

        model = self.asr.model
        sp = self.asr._tokenizer
        prompt_tokens = self.asr._prompt_tokens
        prefix_len = len(prompt_tokens)
        SPT = self._samples_per_token

        # ── Phase 1: Encode new audio ──
        self._encode_pending_audio()
        self.audio_buffer = self._pending_audio  # for logging compat

        if self._audio_embeds is None:
            return [], self.end

        # Safety: don't decode ahead of encoded audio
        safe_total = N_LEFT_PAD_TOKENS + self._n_audio_samples_fed // SPT
        n_decodable = min(
            self._audio_embeds.shape[0], safe_total - self._n_total_decoded
        )
        if n_decodable <= 0:
            return [], self.end

        # ── Phase 2: Prefill (once per utterance) ──
        if not self._prefilled:
            if self._n_total_decoded + self._audio_embeds.shape[0] < prefix_len:
                logger.info(
                    f"[voxtral] waiting for prefill: have {self._audio_embeds.shape[0]} embeds, need {prefix_len}"
                )
                return [], self.end

            n_layers = len(model.language_model.layers)
            self._decoder_cache = [RotatingKVCache(8192) for _ in range(n_layers)]

            self._t_cond = model.time_embedding(
                mx.array([self.asr._n_delay_tokens], dtype=mx.float32)
            )

            prompt_ids = mx.array([prompt_tokens])
            self._text_embeds = model.language_model.embed(prompt_ids)[0]

            prefix_embeds = (
                self._text_embeds + self._audio_embeds[:prefix_len]
            )[None, :, :]

            logits = model.decode(
                prefix_embeds, self._t_cond, "causal", self._decoder_cache
            )
            mx.eval(
                logits,
                *[x for c in self._decoder_cache for x in (c.keys, c.values)],
            )

            self._y = mx.argmax(logits[0, -1:], axis=-1).squeeze()
            mx.async_eval(self._y)

            self._audio_embeds = self._audio_embeds[prefix_len:]
            self._n_total_decoded = prefix_len
            self._prefilled = True
            logger.info(f"[voxtral] prefill done, first token y={self._y.item()}")

            # The prefill consumed embeddings, so recompute what is left.
            n_decodable = min(
                self._audio_embeds.shape[0], safe_total - self._n_total_decoded
            )
            if n_decodable <= 0:
                return [], self.end

        # ── Phase 3: Decode new positions ──
        eos_id = sp.eos_id
        hit_eos = False
        n_consumed = 0

        for i in range(n_decodable):
            # Launch the next step before reading the current token so the
            # async evaluation overlaps with the Python-side bookkeeping.
            token_embed = model.language_model.embed(self._y.reshape(1, 1))[0, 0]
            step_embed = (self._audio_embeds[i] + token_embed)[None, None, :]
            logits = model.decode(
                step_embed, self._t_cond, mask=None, cache=self._decoder_cache
            )
            next_y = mx.argmax(logits[0, -1:], axis=-1).squeeze()
            mx.async_eval(next_y)

            token_id = self._y.item()
            n_consumed = i + 1

            if token_id == eos_id:
                hit_eos = True
                logger.info("[voxtral] hit EOS")
                break

            # Accumulate token ID — full-sequence decode produces correct spacing
            # Skip if this _y was already flushed by start_silence()
            if self._y_flushed_to_output:
                self._y_flushed_to_output = False
            else:
                self._output_token_ids.append(token_id)
                # Track position for timestamp estimation
                self._token_positions.append(self._n_total_decoded + i)

            if i > 0 and i % 256 == 0:
                mx.clear_cache()

            self._y = next_y

        self._n_total_decoded += n_consumed

        # Trim consumed embeddings
        if self._audio_embeds.shape[0] > n_consumed:
            self._audio_embeds = self._audio_embeds[n_consumed:]
        else:
            self._audio_embeds = None

        # Log decode results. Decoding the full token history only to build
        # the message is wasted work when INFO logging is disabled.
        if logger.isEnabledFor(logging.INFO):
            full_text = self._get_full_text()
            logger.info(
                f"[voxtral] decoded {n_consumed} tokens | "
                f"total_decoded={self._n_total_decoded} | "
                f"text='{full_text[-80:]}' | "
                f"n_words={len(full_text.split())} committed={self._n_committed_words}"
            )

        # Extract complete words from the decoded token sequence
        new_words = self._extract_new_words()

        if hit_eos:
            new_words.extend(self._flush_all_pending_words())
            # _reset_state() zeroes the time offset and drops buffered audio.
            # Decode positions restart from zero afterwards, so carry the
            # stream clock (previous offset + audio fed so far) and any
            # not-yet-encoded samples over to the next utterance; otherwise
            # later timestamps would restart near 0.
            next_offset = (
                self._global_time_offset
                + self._n_audio_samples_fed / self.SAMPLING_RATE
            )
            leftover_audio = self._pending_audio
            self._reset_state()
            self._global_time_offset = next_offset
            self._pending_audio = leftover_audio
            self.audio_buffer = leftover_audio  # for logging compat

        if new_words:
            logger.info(f"[voxtral] returning {len(new_words)} words: {[w.text for w in new_words]}")

        self.buffer = []
        return new_words, self.end
|
||||
Reference in New Issue
Block a user