33 Commits

Author SHA1 Message Date
Quentin Fuxa
7ea507ed8e Add Voxtral MLX streaming backend
Integrates the voxmlx-based Voxtral Mini Realtime streaming pipeline:
- VoxtralStreamingASR and VoxtralStreamingOnlineProcessor
- Incremental audio encoding and token-by-token autoregressive decoding
- Selectable via --backend voxtral-mlx

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 09:20:28 +01:00
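For orientation, a minimal sketch of how the new backend gets dispatched; it mirrors the `online_factory` change in the `core.py` diff further down, and the `Namespace` fields shown are assumptions rather than the full argument set:

```python
# Sketch only: dispatch as in the core.py diff below (fields are assumptions).
from argparse import Namespace
from whisperlivekit.core import online_factory

args = Namespace(backend="voxtral-mlx", backend_policy="simulstreaming")
online = online_factory(args, asr=None)  # would return a VoxtralStreamingOnlineProcessor
```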
Quentin Fuxa
e7e82f7c19 bump to 0.2.18 2026-02-11 22:10:00 +01:00
Quentin Fuxa
8c799fa4d1 fix simulstreaming vram leak: cap cross-attn accumulation + token budget
fixes #283, fixes #275

- accumulated_cross_attns was growing unboundedly during decoding loop,
  using up to ~5GB for repetition loops. now capped to rolling window of 16
- max_tokens_per_chunk was using TOKENS_PER_SECOND (mel frame rate = 50)
  instead of actual text token rate (~15/s), allowing 10-40x too many
  decoding steps
- removed unused torch.cat on early return path
- removed dead self.committed/last_result_tokens lists (never read)
- same fixes applied to mlx variant
2026-02-11 22:10:00 +01:00
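A minimal sketch of the two caps described above; the constants come from the commit message, while the loop and helper are hypothetical stand-ins:

```python
# Hypothetical decode loop illustrating this commit's two fixes.
CROSS_ATTN_WINDOW = 16        # rolling window instead of unbounded accumulation
TEXT_TOKENS_PER_SECOND = 15   # actual text token rate, not the 50/s mel frame rate
chunk_seconds = 2.0           # illustrative chunk length

def decode_one_step():
    """Stand-in for one autoregressive decode step returning a cross-attn map."""
    return [0.0]

accumulated_cross_attns = []
max_tokens_per_chunk = int(chunk_seconds * TEXT_TOKENS_PER_SECOND)  # token budget
for _ in range(max_tokens_per_chunk):
    accumulated_cross_attns.append(decode_one_step())
    if len(accumulated_cross_attns) > CROSS_ATTN_WINDOW:
        accumulated_cross_attns.pop(0)  # bounded even in repetition loops
```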
Quentin Fuxa
8923337380 fix --direct-english-translation not setting task=translate for localagreement backends
the flag was only used for tokenizer language selection but never
actually passed to whisper/faster-whisper transcribe calls. also init
OpenaiApiASR.task and read from transcribe_kargs.

fixes #306
2026-02-11 22:10:00 +01:00
Quentin Fuxa
aded1649ae fix model_cache_dir + direct_english_translation task in simulstreaming
pass actual cache dir instead of None, and use proper task string
instead of boolean for AlignAttConfig

fixes #310
2026-02-11 22:10:00 +01:00
Quentin Fuxa
3b535e857a fix NoneType concatenation in add_translation
fixes #296
2026-02-11 22:10:00 +01:00
Quentin Fuxa
d649250b9a fix Segment classmethod call + isinstance type narrowing
fixes #331, fixes #329
2026-02-11 22:10:00 +01:00
Quentin Fuxa
7735478286 add insert_audio_chunk to DiartDiarization
fixes #332
2026-02-11 22:10:00 +01:00
Quentin Fuxa
b9e72d2b9a add probability field to ASRToken
fixes #330, fixes #313
2026-02-11 22:10:00 +01:00
Quentin Fuxa
e5b01033af add json normalizers for english language in build 2026-01-16 10:47:46 +01:00
Quentin Fuxa
6ae545bcb1 bump to 0.2.17.post1 2026-01-16 10:43:52 +01:00
Quentin Fuxa
04980d3f5e Merge branch 'main' of https://github.com/QuentinFuxa/WhisperLiveKit 2026-01-16 10:38:29 +01:00
Quentin Fuxa
79a705c969 fixes #323 2026-01-16 10:38:07 +01:00
Quentin Fuxa
34e4abd455 Merge pull request #322 from eschmidbauer/fix/thread-safety-issues
Fix kv cache not being properly cleaned between sessions
2026-01-09 19:23:35 +01:00
Emmanuel Schmidbauer
d59ddbaeae Fix critical thread safety issues 2026-01-09 11:23:19 -05:00
Quentin Fuxa
4dd66e7766 Merge pull request #317 from jantonj/fix-bug-diarization-lag
update diarization lag after stream analysed
2025-12-19 17:43:07 +01:00
Anton Jacobson
3db5d81a20 update diarization lag after stream analysed 2025-12-18 14:13:28 +01:00
Quentin Fuxa
b67ddea494 bump to 0.2.17 2025-12-08 23:52:00 +01:00
Quentin Fuxa
3192553e20 fixes #307 2025-12-09 10:27:49 +01:00
Quentin Fuxa
f379a243fe Merge pull request #274 from blakkd/patch-1
minor path change
2025-12-09 10:10:32 +01:00
Quentin Fuxa
ec09898a9f fixes #301 2025-12-06 10:19:50 +01:00
blakkd
befbae56c7 minor path change
prevents

```
FileNotFoundError: [Errno 2] No such file or directory: 'whisperlivekit/web/live_transcription.html'
```
2025-11-16 23:47:58 +01:00
Quentin Fuxa
719e8b1a20 adapt online for mlx detection 2024-11-25 23:52:00 +01:00
Quentin Fuxa
f1b47178d8 adapt online for mlx detection 2024-11-25 23:52:00 +01:00
Quentin Fuxa
59db08e961 loader for full mlx 2024-11-25 23:52:00 +01:00
Quentin Fuxa
6fc20b9562 new dec class 2024-11-21 23:52:00 +01:00
Quentin Fuxa
fac8659161 uses native mlx function for attention 2024-11-21 23:52:00 +01:00
Quentin Fuxa
4d9332ce7d fixes #299 2025-12-05 17:54:14 +01:00
Quentin Fuxa
62444ce746 session parameter required in OnnxWrapper 2025-12-05 15:37:18 +01:00
Quentin Fuxa
2431a6bf91 isolate VAD state per user: .onnx shares a stateless model; .jit requires duplicating the model.
Co-authored-by: eschmidbauer <eschmidbauer@gmail.com>
2025-12-05 15:27:14 +01:00
Quentin Fuxa
d1263e7228 Merge pull request #308 from gzz2000/main
Fix local agreement backend, removing excess parameter, #295
2025-12-05 11:34:05 +01:00
Zizheng Guo
30ddd522a4 Fix local agreement backend, removing excess parameter, fixes https://github.com/QuentinFuxa/WhisperLiveKit/issues/295 2025-12-04 16:45:23 +08:00
Quentin Fuxa
635bace09e update archi 2025-11-30 18:39:10 +01:00
26 changed files with 2087 additions and 184 deletions

View File

@@ -37,9 +37,10 @@ RUN pip3 install --upgrade pip setuptools wheel && \
 COPY . .
 # Install WhisperLiveKit directly, allowing for optional dependencies
+# Example: --build-arg EXTRAS="translation"
 RUN if [ -n "$EXTRAS" ]; then \
     echo "Installing with extras: [$EXTRAS]"; \
-    pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
+    pip install --no-cache-dir "whisperlivekit[$EXTRAS]"; \
 else \
     echo "Installing base package only"; \
     pip install --no-cache-dir whisperlivekit; \

View File

@@ -147,8 +147,8 @@ async def websocket_endpoint(websocket: WebSocket):
 |-----------|-------------|---------|
 | `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `small` |
 | `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `None` |
-| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
-| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
+| `--language` | List [here](docs/supported_languages.md). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
+| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
 | `--diarization` | Enable speaker identification | `False` |
 | `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
 | `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
@@ -267,7 +267,7 @@ docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
 #### Customization
 - `--build-arg` Options:
-  - `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
+  - `EXTRAS="translation"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
 - `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
 - `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models

Binary file not shown: image changed (422 KiB before, 422 KiB after).

View File

@@ -6,7 +6,7 @@ Capture the audio of your current tab, transcribe diarize and translate it using
 <img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
 ## Running this extension
-1. Run `python sync_extension.py` to copy frontend files to the `chrome-extension` directory.
+1. Run `python scripts/sync_extension.py` to copy frontend files to the `chrome-extension` directory.
 2. Load the `chrome-extension` directory in Chrome as an unpacked extension.

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "whisperlivekit"
-version = "0.2.16.dev0"
+version = "0.2.18"
 description = "Real-time speech-to-text with speaker diarization using Whisper"
 readme = "README.md"
 authors = [
@@ -35,6 +35,7 @@ dependencies = [
     "torchaudio>=2.0.0",
     "torch>=2.0.0",
     "huggingface-hub>=0.25.0",
+    "faster-whisper>=1.2.0",
     "tqdm",
     "tiktoken",
     'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
@@ -56,6 +57,7 @@ packages = [
     "whisperlivekit",
     "whisperlivekit.diarization",
     "whisperlivekit.simul_whisper",
+    "whisperlivekit.simul_whisper.mlx",
     "whisperlivekit.whisper",
     "whisperlivekit.whisper.assets",
     "whisperlivekit.whisper.normalizers",
@@ -67,4 +69,5 @@ packages = [
 [tool.setuptools.package-data]
 whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
 "whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
+"whisperlivekit.whisper.normalizers" = ["*.json"]
 "whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]

View File

@@ -10,7 +10,7 @@ from whisperlivekit.core import (TranscriptionEngine,
                                  online_diarization_factory, online_factory,
                                  online_translation_factory)
 from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
-from whisperlivekit.silero_vad_iterator import FixedVADIterator
+from whisperlivekit.silero_vad_iterator import FixedVADIterator, OnnxWrapper, load_jit_vad
 from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
                                           Segment, Silence, State, Transcript)
 from whisperlivekit.tokens_alignment import TokensAlignment
@@ -85,12 +85,14 @@ class AudioProcessor:
         # Models and processing
         self.asr: Any = models.asr
-        self.vac_model: Any = models.vac_model
-        if self.args.vac:
-            self.vac: Optional[FixedVADIterator] = FixedVADIterator(models.vac_model)
-        else:
-            self.vac: Optional[FixedVADIterator] = None
+        self.vac: Optional[FixedVADIterator] = None
+        if self.args.vac:
+            if models.vac_session is not None:
+                vac_model = OnnxWrapper(session=models.vac_session)
+                self.vac = FixedVADIterator(vac_model)
+            else:
+                self.vac = FixedVADIterator(load_jit_vad())
         self.ffmpeg_manager: Optional[FFmpegManager] = None
         self.ffmpeg_reader_task: Optional[asyncio.Task] = None
         self._ffmpeg_error: Optional[str] = None
@@ -351,11 +353,14 @@ class AudioProcessor:
                 if item.has_ended:
                     self.diarization.insert_silence(item.duration)
                     continue
                 self.diarization.insert_audio_chunk(item)
                 diarization_segments = await self.diarization.diarize()
-                self.state.new_diarization = diarization_segments
+                diar_end = 0.0
+                if diarization_segments:
+                    diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
+                async with self.lock:
+                    self.state.new_diarization = diarization_segments
+                    self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
             except Exception as e:
                 logger.warning(f"Exception in diarization_processor: {e}")
                 logger.warning(f"Traceback: {traceback.format_exc()}")

View File

@@ -29,6 +29,13 @@ def mlx_backend_available(warn_on_missing = False):
     return available
+def voxmlx_backend_available():
+    """Return True if voxmlx (Voxtral MLX backend) is available."""
+    is_macos = platform.system() == "Darwin"
+    is_arm = platform.machine() == "arm64"
+    return is_macos and is_arm and module_available("voxmlx")
 def faster_backend_available(warn_on_missing = False):
     available = module_available("faster_whisper")
     if not available and warn_on_missing and platform.system() != "Darwin":

View File

@@ -1,5 +1,6 @@
 import logging
 import sys
+import threading
 from argparse import Namespace
 from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
@@ -19,16 +20,26 @@ logger = logging.getLogger(__name__)
 class TranscriptionEngine:
     _instance = None
     _initialized = False
+    _lock = threading.Lock()  # Thread-safe singleton lock
     def __new__(cls, *args, **kwargs):
+        # Double-checked locking pattern for thread-safe singleton
         if cls._instance is None:
-            cls._instance = super().__new__(cls)
+            with cls._lock:
+                # Check again inside lock to prevent race condition
+                if cls._instance is None:
+                    cls._instance = super().__new__(cls)
         return cls._instance
     def __init__(self, **kwargs):
-        if TranscriptionEngine._initialized:
-            return
+        # Thread-safe initialization check
+        with TranscriptionEngine._lock:
+            if TranscriptionEngine._initialized:
+                return
+            # Set flag immediately to prevent re-initialization
+            TranscriptionEngine._initialized = True
+        # Perform initialization outside lock to avoid holding lock during slow operations
         global_params = {
             "host": "localhost",
             "port": 8000,
@@ -36,7 +47,6 @@ class TranscriptionEngine:
             "punctuation_split": False,
             "target_language": "",
             "vac": True,
-            "vac_onnx": False,
             "vac_chunk_size": 0.04,
             "log_level": "DEBUG",
             "ssl_certfile": None,
@@ -79,18 +89,27 @@ class TranscriptionEngine:
         self.asr = None
         self.tokenizer = None
         self.diarization = None
-        self.vac_model = None
+        self.vac_session = None
         if self.args.vac:
-            from whisperlivekit.silero_vad_iterator import load_silero_vad
-            # Use ONNX if specified, otherwise use JIT (default)
-            use_onnx = kwargs.get('vac_onnx', False)
-            self.vac_model = load_silero_vad(onnx=use_onnx)
+            from whisperlivekit.silero_vad_iterator import is_onnx_available
+            if is_onnx_available():
+                from whisperlivekit.silero_vad_iterator import load_onnx_session
+                self.vac_session = load_onnx_session()
+            else:
+                logger.warning(
+                    "onnxruntime not installed. VAC will use JIT model which is loaded per-session. "
+                    "For multi-user scenarios, install onnxruntime: pip install onnxruntime"
+                )
         backend_policy = self.args.backend_policy
         if self.args.transcription:
-            if backend_policy == "simulstreaming":
+            if self.args.backend == "voxtral-mlx":
+                from whisperlivekit.voxtral_streaming import VoxtralStreamingASR
+                self.tokenizer = None
+                self.asr = VoxtralStreamingASR(**transcription_common_params)
+                logger.info("Using Voxtral MLX streaming backend")
+            elif backend_policy == "simulstreaming":
                 simulstreaming_params = {
                     "disable_fast_encoder": False,
                     "custom_alignment_heads": None,
@@ -169,16 +188,16 @@ class TranscriptionEngine:
             }
             translation_params = update_with_kwargs(translation_params, kwargs)
             self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
-        TranscriptionEngine._initialized = True
 def online_factory(args, asr):
+    if getattr(args, 'backend', None) == "voxtral-mlx":
+        from whisperlivekit.voxtral_streaming import VoxtralStreamingOnlineProcessor
+        return VoxtralStreamingOnlineProcessor(asr)
     if args.backend_policy == "simulstreaming":
         from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
-        online = SimulStreamingOnlineProcessor(asr)
-    else:
-        online = OnlineASRProcessor(asr)
-    return online
+        return SimulStreamingOnlineProcessor(asr)
+    return OnlineASRProcessor(asr)
 def online_diarization_factory(args, diarization_backend):
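A self-contained sketch of the double-checked locking pattern adopted in `TranscriptionEngine` above (generic Python, not the actual class):

```python
import threading

class Singleton:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:          # fast path, no lock on the common case
            with cls._lock:
                if cls._instance is None:  # re-check: another thread may have won the race
                    cls._instance = super().__new__(cls)
        return cls._instance
```

The second check matters because two threads can both pass the first `if` before either takes the lock; the diff likewise sets `_initialized` inside the lock but performs the slow model loading outside it, keeping the critical section short.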

View File

@@ -202,14 +202,14 @@ class DiartDiarization:
     def insert_silence(self, silence_duration):
         self.observer.global_time_offset += silence_duration
-    async def diarize(self, pcm_array: np.ndarray):
-        """
-        Process audio data for diarization.
-        Only used when working with WebSocketAudioSource.
-        """
+    def insert_audio_chunk(self, pcm_array: np.ndarray):
+        """Buffer audio for the next diarization step."""
         if self.custom_source:
             self.custom_source.push_audio(pcm_array)
-        # self.observer.clear_old_segments()
+    async def diarize(self):
+        """Return the current speaker segments from the diarization pipeline."""
         return self.observer.get_segments()
     def close(self):
         """Close the audio source."""

View File

@@ -249,6 +249,7 @@ class OpenaiApiASR(ASRBase):
         self.load_model()
         self.use_vad_opt = False
         self.direct_english_translation = False
+        self.task = "transcribe"
     def load_model(self, *args, **kwargs):
         from openai import OpenAI
@@ -294,7 +295,8 @@
         params["language"] = self.original_language
         if prompt:
             params["prompt"] = prompt
-        proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
+        task = self.transcribe_kargs.get("task", self.task)
+        proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions
         transcript = proc.create(**params)
         logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
         return transcript

View File

@@ -146,6 +146,7 @@ def backend_factory(
     if direct_english_translation:
         tgt_language = "en"  # Whisper translates into English
+        asr.transcribe_kargs["task"] = "translate"
     else:
         tgt_language = lan  # Whisper transcribes in this language
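Taken together with the `OpenaiApiASR` hunk above, the flag now flows end to end. A simplified sketch, with the ASR and OpenAI client passed in:

```python
def pick_endpoint(asr, client):
    # backend_factory (this hunk) sets the flag:
    #   asr.transcribe_kargs["task"] = "translate"
    # OpenaiApiASR.transcribe (hunk above) then reads it back:
    task = asr.transcribe_kargs.get("task", asr.task)  # falls back to "transcribe"
    return client.audio.translations if task == "translate" else client.audio.transcriptions
```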

View File

@@ -147,8 +147,8 @@ def parse_args():
         "--backend",
         type=str,
         default="auto",
-        choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"],
-        help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API.",
+        choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral-mlx"],
+        help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API. Use 'voxtral-mlx' for Voxtral streaming on Apple Silicon.",
     )
     parser.add_argument(
         "--no-vac",

View File

@@ -8,6 +8,15 @@ import torch
 Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
 """
+def is_onnx_available() -> bool:
+    """Check if onnxruntime is installed."""
+    try:
+        import onnxruntime
+        return True
+    except ImportError:
+        return False
 def init_jit_model(model_path: str, device=torch.device('cpu')):
     """Load a JIT model from file."""
     model = torch.jit.load(model_path, map_location=device)
@@ -15,12 +24,12 @@ def init_jit_model(model_path: str, device=torch.device('cpu')):
     return model
-class OnnxWrapper():
-    """ONNX Runtime wrapper for Silero VAD model."""
+class OnnxSession():
+    """
+    Shared ONNX session for Silero VAD model (stateless).
+    """
     def __init__(self, path, force_onnx_cpu=False):
-        global np
-        import numpy as np
         import onnxruntime
         opts = onnxruntime.SessionOptions()
@@ -32,13 +41,28 @@
         else:
             self.session = onnxruntime.InferenceSession(path, sess_options=opts)
-        self.reset_states()
+        self.path = path
         if '16k' in path:
             warnings.warn('This model support only 16000 sampling rate!')
             self.sample_rates = [16000]
         else:
             self.sample_rates = [8000, 16000]
+class OnnxWrapper():
+    """
+    ONNX Runtime wrapper for Silero VAD model with per-instance state.
+    """
+    def __init__(self, session: OnnxSession, force_onnx_cpu=False):
+        self._shared_session = session
+        self.sample_rates = session.sample_rates
+        self.reset_states()
+    @property
+    def session(self):
+        return self._shared_session.session
     def _validate_input(self, x, sr: int):
         if x.dim() == 1:
             x = x.unsqueeze(0)
@@ -101,38 +125,20 @@
         return out
-def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16):
-    """
-    Load Silero VAD model (JIT or ONNX).
-    Parameters
-    ----------
-    model_path : str, optional
-        Path to model file. If None, uses default bundled model.
-    onnx : bool, default False
-        Whether to use ONNX runtime (requires onnxruntime package).
-    opset_version : int, default 16
-        ONNX opset version (15 or 16). Only used if onnx=True.
-    Returns
-    -------
-    model
-        Loaded VAD model (JIT or ONNX wrapper)
-    """
+def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Path:
+    """Get the path to the ONNX model file."""
     available_ops = [15, 16]
-    if onnx and opset_version not in available_ops:
+    if opset_version not in available_ops:
         raise Exception(f'Available ONNX opset_version: {available_ops}')
     if model_path is None:
         current_dir = Path(__file__).parent
         data_dir = current_dir / 'silero_vad_models'
-        if onnx:
-            if opset_version == 16:
-                model_name = 'silero_vad.onnx'
-            else:
-                model_name = f'silero_vad_16k_op{opset_version}.onnx'
-        else:
-            model_name = 'silero_vad.jit'
+        if opset_version == 16:
+            model_name = 'silero_vad.onnx'
+        else:
+            model_name = f'silero_vad_16k_op{opset_version}.onnx'
         model_path = data_dir / model_name
@@ -143,16 +149,38 @@ def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: i
             )
     else:
         model_path = Path(model_path)
-    if onnx:
-        try:
-            model = OnnxWrapper(str(model_path), force_onnx_cpu=True)
-        except ImportError:
-            raise ImportError(
-                "ONNX runtime not available. Install with: pip install onnxruntime\n"
-                "Or use JIT model by setting onnx=False"
-            )
-    else:
-        model = init_jit_model(str(model_path))
+    return model_path
+def load_onnx_session(model_path: str = None, opset_version: int = 16, force_onnx_cpu: bool = True) -> OnnxSession:
+    """
+    Load a shared ONNX session for Silero VAD.
+    """
+    path = _get_onnx_model_path(model_path, opset_version)
+    return OnnxSession(str(path), force_onnx_cpu=force_onnx_cpu)
+def load_jit_vad(model_path: str = None):
+    """
+    Load Silero VAD model in JIT format.
+    """
+    if model_path is None:
+        current_dir = Path(__file__).parent
+        data_dir = current_dir / 'silero_vad_models'
+        model_name = 'silero_vad.jit'
+        model_path = data_dir / model_name
+        if not model_path.exists():
+            raise FileNotFoundError(
+                f"Model file not found: {model_path}\n"
+                f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
+            )
+    else:
+        model_path = Path(model_path)
+    model = init_jit_model(str(model_path))
     return model
@@ -285,8 +313,8 @@ class FixedVADIterator(VADIterator):
 if __name__ == "__main__":
-    model = load_silero_vad(onnx=False)
-    vad = FixedVADIterator(model)
+    # vad = FixedVADIterator(load_jit_vad())
+    vad = FixedVADIterator(OnnxWrapper(session=load_onnx_session()))
     audio_buffer = np.array([0] * 512, dtype=np.float32)
     result = vad(audio_buffer)
@@ -295,3 +323,4 @@ if __name__ == "__main__":
     # test with 511 samples
     audio_buffer = np.array([0] * 511, dtype=np.float32)
     result = vad(audio_buffer)
+    print(f"  511 samples: {result}")

View File

@@ -24,9 +24,11 @@ logger = logging.getLogger(__name__)
 HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
 if HAS_MLX_WHISPER:
-    from .mlx_encoder import load_mlx_encoder, mlx_model_mapping
+    from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
+    from .mlx import MLXAlignAtt
 else:
     mlx_model_mapping = {}
+    MLXAlignAtt = None
 HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
 if HAS_FASTER_WHISPER:
     from faster_whisper import WhisperModel
@@ -36,50 +38,47 @@ else:
 MIN_DURATION_REAL_SILENCE = 5
 class SimulStreamingOnlineProcessor:
+    """Online processor for SimulStreaming ASR."""
     SAMPLING_RATE = 16000
-    def __init__(
-        self,
-        asr,
-        logfile=sys.stderr,
-    ):
+    def __init__(self, asr, logfile=sys.stderr):
         self.asr = asr
         self.logfile = logfile
         self.end = 0.0
         self.buffer = []
-        self.committed: List[ASRToken] = []
-        self.last_result_tokens: List[ASRToken] = []
-        self.load_new_alignatt_instance()
+        self.model = self._create_alignatt()
         if asr.tokenizer:
             self.model.tokenizer = asr.tokenizer
+            self.model.state.tokenizer = asr.tokenizer
-    def load_new_alignatt_instance(self):
-        """Initialize AlignAtt decoder using the shared model."""
-        self.model = AlignAtt(
-            cfg=self.asr.cfg,
-            loaded_model=self.asr.shared_model,
-            mlx_encoder=self.asr.mlx_encoder,
-            fw_encoder=self.asr.fw_encoder,
-        )
+    def _create_alignatt(self):
+        """Create the AlignAtt decoder instance based on ASR mode."""
+        if self.asr.use_full_mlx and HAS_MLX_WHISPER:
+            return MLXAlignAtt(cfg=self.asr.cfg, mlx_model=self.asr.mlx_model)
+        else:
+            return AlignAtt(
+                cfg=self.asr.cfg,
+                loaded_model=self.asr.shared_model,
+                mlx_encoder=self.asr.mlx_encoder,
+                fw_encoder=self.asr.fw_encoder,
+            )
     def start_silence(self):
         tokens, processed_upto = self.process_iter(is_last=True)
         return tokens, processed_upto
     def end_silence(self, silence_duration, offset):
-        """
-        Handle silence period.
-        If silence > MIN_DURATION_REAL_SILENCE, do a complete context clear.
-        Otherwise, insert a small silence and shift the last_attend_frame.
-        """
+        """Handle silence period."""
         self.end += silence_duration
         long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
         if not long_silence:
             gap_len = int(16000 * silence_duration)
             if gap_len > 0:
-                gap_silence = torch.zeros(gap_len)
+                if self.asr.use_full_mlx:
+                    gap_silence = np.zeros(gap_len, dtype=np.float32)
+                else:
+                    gap_silence = torch.zeros(gap_len)
                 self.model.insert_audio(gap_silence)
         if long_silence:
             self.model.refresh_segment(complete=True)
@@ -87,11 +86,12 @@ class SimulStreamingOnlineProcessor:
     def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
         """Append an audio chunk to be processed by SimulStreaming."""
-        # Convert numpy array to torch tensor
-        audio_tensor = torch.from_numpy(audio).float()
-        self.end = audio_stream_end_time # Aligned with whisperstreaming backend behavior
-        self.model.insert_audio(audio_tensor)
+        self.end = audio_stream_end_time
+        if self.asr.use_full_mlx:
+            self.model.insert_audio(audio)
+        else:
+            audio_tensor = torch.from_numpy(audio).float()
+            self.model.insert_audio(audio_tensor)
     def new_speaker(self, change_speaker: ChangeSpeaker):
         """Handle speaker change event."""
@@ -120,7 +120,6 @@ class SimulStreamingOnlineProcessor:
                 self.buffer.extend(timestamped_words)
                 return [], self.end
-            self.committed.extend(timestamped_words)
             self.buffer = []
             return timestamped_words, self.end
         except Exception as e:
@@ -130,6 +129,10 @@ class SimulStreamingOnlineProcessor:
     def warmup(self, audio, init_prompt=""):
         """Warmup the SimulStreaming model."""
         try:
+            if self.asr.use_full_mlx:
+                # MLX mode: ensure numpy array
+                if hasattr(audio, 'numpy'):
+                    audio = audio.numpy()
             self.model.insert_audio(audio)
             self.model.infer(True)
             self.model.refresh_segment(complete=True)
@@ -139,9 +142,14 @@ class SimulStreamingOnlineProcessor:
     def __del__(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        if not getattr(self.asr, 'use_full_mlx', True) and torch is not None:
+            try:
+                torch.cuda.empty_cache()
+            except Exception:
+                pass
-class SimulStreamingASR():
+class SimulStreamingASR:
     """SimulStreaming backend with AlignAtt policy."""
     sep = ""
@@ -158,6 +166,7 @@ class SimulStreamingASR():
         self.fast_encoder = False
         self._resolved_model_path = None
         self.encoder_backend = "whisper"
+        self.use_full_mlx = getattr(self, "use_full_mlx", False)
         preferred_backend = getattr(self, "backend", "auto")
         compatible_whisper_mlx, compatible_faster_whisper = True, True
@@ -170,7 +179,7 @@ class SimulStreamingASR():
             compatible_whisper_mlx = model_info.compatible_whisper_mlx
             compatible_faster_whisper = model_info.compatible_faster_whisper
-            if not model_info.has_pytorch:
+            if not self.use_full_mlx and not model_info.has_pytorch:
                 raise FileNotFoundError(
                     f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
                 )
@@ -191,6 +200,10 @@ class SimulStreamingASR():
         if self.encoder_backend == "whisper":
             self.disable_fast_encoder = True
+        if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
+            if not hasattr(self, '_full_mlx_disabled'):
+                self.use_full_mlx = True
         self.cfg = AlignAttConfig(
             tokenizer_is_multilingual= is_multilingual,
             segment_length=self.min_chunk_size,
@@ -201,7 +214,7 @@ class SimulStreamingASR():
             cif_ckpt_path=self.cif_ckpt_path,
             decoder_type="beam",
             beam_size=self.beams,
-            task=self.direct_english_translation,
+            task="translate" if self.direct_english_translation else "transcribe",
             never_fire=self.never_fire,
             init_prompt=self.init_prompt,
             max_context_tokens=self.max_context_tokens,
@@ -214,20 +227,36 @@ class SimulStreamingASR():
         else:
             self.tokenizer = None
-        self.mlx_encoder, self.fw_encoder = None, None
-        if self.encoder_backend == "mlx-whisper":
-            print('Simulstreaming will use MLX whisper to increase encoding speed.')
+        self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None
+        self.shared_model = None
+        if self.use_full_mlx and HAS_MLX_WHISPER:
+            logger.info('MLX Whisper backend used.')
             if self._resolved_model_path is not None:
-                mlx_model = str(self._resolved_model_path)
+                mlx_model_path = str(self._resolved_model_path)
             else:
-                mlx_model = mlx_model_mapping.get(self.model_name)
-            if not mlx_model:
+                mlx_model_path = mlx_model_mapping.get(self.model_name)
+            if not mlx_model_path:
                 raise FileNotFoundError(
                     f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
                 )
-            self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
+            self.mlx_model = load_mlx_model(path_or_hf_repo=mlx_model_path)
+            self._warmup_mlx_model()
+        elif self.encoder_backend == "mlx-whisper":
+            # hybrid mode: mlx encoder + pytorch decoder
+            logger.info('SimulStreaming will use MLX Whisper encoder with PyTorch decoder.')
+            if self._resolved_model_path is not None:
+                mlx_model_path = str(self._resolved_model_path)
+            else:
+                mlx_model_path = mlx_model_mapping.get(self.model_name)
+            if not mlx_model_path:
+                raise FileNotFoundError(
+                    f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
+                )
+            self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
+            self.shared_model = self.load_model()
         elif self.encoder_backend == "faster-whisper":
-            print('Simulstreaming will use Faster Whisper for the encoder.')
+            print('SimulStreaming will use Faster Whisper for the encoder.')
             if self._resolved_model_path is not None:
                 fw_model = str(self._resolved_model_path)
             else:
@@ -237,7 +266,20 @@ class SimulStreamingASR():
                 device='auto',
                 compute_type='auto',
             )
             self.shared_model = self.load_model()
+        else:
+            self.shared_model = self.load_model()
+    def _warmup_mlx_model(self):
+        """Warmup the full MLX model."""
+        warmup_audio = load_file(self.warmup_file)
+        if warmup_audio is not None:
+            temp_model = MLXAlignAtt(
+                cfg=self.cfg,
+                mlx_model=self.mlx_model,
+            )
+            temp_model.warmup(warmup_audio)
+            logger.info("Full MLX model warmed up successfully")
     def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
@@ -285,7 +327,7 @@ class SimulStreamingASR():
         lora_path = getattr(self, 'lora_path', None)
         whisper_model = load_model(
             name=model_ref,
-            download_root=None,
+            download_root=getattr(self, 'model_cache_dir', None),
             decoder_only=self.fast_encoder,
             custom_alignment_heads=self.custom_alignment_heads,
             lora_path=lora_path,

View File

@@ -47,9 +47,24 @@ class DecoderState:
     def clean_cache(self):
         """Clean the kv_cache after each inference step."""
-        self.kv_cache = {}
+        # Explicitly delete tensor references to free GPU memory
+        if self.kv_cache:
+            for key in list(self.kv_cache.keys()):
+                tensor = self.kv_cache.pop(key, None)
+                if tensor is not None:
+                    del tensor
+        # Clear the dict
+        self.kv_cache.clear()
+        # Force GPU cache cleanup (only if CUDA is available)
+        import torch
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
         if self.decoder_type == "beam" and self.inference is not None:
-            self.inference.kv_cache = self.kv_cache
+            # Create NEW dict instead of sharing reference
+            self.inference.kv_cache = {}
         if self.token_decoder is not None:
             self.token_decoder.reset()
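The "NEW dict" comment is about plain Python aliasing; a minimal illustration:

```python
kv_cache = {"layer0": "big tensor"}
inference_kv = kv_cache          # old behavior: one dict, two names
kv_cache.clear()                 # clears through both names...
inference_kv["layer0"] = "new"   # ...but writes through the alias repopulate it
inference_kv = {}                # new behavior: an independent dict, so the old
                                 # entries become unreachable and collectable
```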

View File

@@ -0,0 +1,11 @@
from .decoder_state import MLXDecoderState
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
from .simul_whisper import MLXAlignAtt
__all__ = [
"MLXAlignAtt",
"MLXBeamSearchDecoder",
"MLXDecoderState",
"MLXGreedyDecoder",
"MLXInference",
]

View File

@@ -0,0 +1,76 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import numpy as np
@dataclass
class MLXDecoderState:
"""
mlx kv cache format: List of ((k, v), (cross_k, cross_v)) tuples per layer,
where each element is a tuple of mx.arrays.
"""
kv_cache: Optional[List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]] = None
tokenizer: Any = None
detected_language: Optional[str] = None
reset_tokenizer_to_auto_next_call: bool = False
tokens: List[mx.array] = field(default_factory=list)
initial_tokens: Optional[mx.array] = None
initial_token_length: int = 0
sot_index: int = 0
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
num_align_heads: int = 0
segments: List[np.ndarray] = field(default_factory=list)
context: Any = None
pending_incomplete_tokens: List[int] = field(default_factory=list)
global_time_offset: float = 0.0
cumulative_time_offset: float = 0.0
first_timestamp: Optional[float] = None
last_attend_frame: int = 0
speaker: int = -1
log_segments: int = 0
cif_weights: Optional[mx.array] = None
always_fire: bool = False
never_fire: bool = False
suppress_tokens: Optional[Tuple[int, ...]] = None
token_decoder: Any = None
decoder_type: str = "greedy"
inference: Any = None
def clean_cache(self):
self.kv_cache = None
if self.decoder_type == "beam" and self.inference is not None:
self.inference.kv_cache = None
if self.token_decoder is not None:
self.token_decoder.reset()
def reset(self, rewind_threshold: int = 200):
self.last_attend_frame = -rewind_threshold
self.cumulative_time_offset = 0.0
self.pending_incomplete_tokens = []
self.log_segments += 1
def full_reset(self, rewind_threshold: int = 200):
"""
Full reset including audio segments and tokens.
Args:
rewind_threshold: Value for resetting last_attend_frame
"""
self.reset(rewind_threshold)
self.segments = []
self.tokens = []
self.kv_cache = None
self.first_timestamp = None
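Reset semantics of the state above, as a short usage note (requires `mlx` installed):

```python
state = MLXDecoderState()
state.reset()        # soft reset: rewinds last_attend_frame, drops pending
                     # incomplete tokens, keeps audio segments and the kv cache
state.full_reset()   # also clears segments, tokens, kv_cache and first_timestamp
```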

View File

@@ -0,0 +1,219 @@
"""
MLX-native token decoders for streaming ASR.
"""
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import numpy as np
class MLXGreedyDecoder:
"""Greedy decoder using MLX operations."""
def __init__(self, temperature: float, eot: int):
self.temperature = temperature
self.eot = eot
def update(
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
) -> Tuple[mx.array, bool]:
"""
Update tokens with next predicted token.
Args:
tokens: Current token sequence, shape (batch, seq_len)
logits: Logits for next token, shape (batch, vocab_size)
sum_logprobs: Cumulative log probabilities, shape (batch,)
Returns:
Updated tokens and completion flag
"""
if self.temperature == 0:
next_tokens = mx.argmax(logits, axis=-1)
else:
probs = mx.softmax(logits / self.temperature, axis=-1)
next_tokens = mx.random.categorical(mx.log(probs + 1e-10))
logprobs = mx.softmax(logits, axis=-1)
logprobs = mx.log(logprobs + 1e-10)
batch_size = logprobs.shape[0]
current_logprobs = logprobs[mx.arange(batch_size), next_tokens]
mask = (tokens[:, -1] != self.eot).astype(mx.float32)
sum_logprobs = sum_logprobs + current_logprobs * mask
eot_mask = (tokens[:, -1] == self.eot)
next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens)
tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1)
completed = bool(mx.all(tokens[:, -1] == self.eot))
return tokens, completed
def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
"""Finalize decoding by ensuring EOT at end."""
eot_column = mx.full((tokens.shape[0], 1), self.eot, dtype=tokens.dtype)
tokens = mx.concatenate([tokens, eot_column], axis=1)
return tokens, sum_logprobs.tolist()
class MLXBeamSearchDecoder:
"""Beam search decoder using MLX operations."""
def __init__(
self,
beam_size: int,
eot: int,
inference: Any,
patience: Optional[float] = None,
):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
self.patience = patience or 1.0
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences: Optional[List[Dict]] = None
assert (
self.max_candidates > 0
), f"Invalid beam size ({beam_size}) or patience ({patience})"
def reset(self):
"""Reset finished sequences for new segment."""
self.finished_sequences = None
def update(
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
) -> Tuple[mx.array, bool]:
"""
Update tokens using beam search.
Args:
tokens: Current token sequences, shape (batch * beam_size, seq_len)
logits: Logits for next token, shape (batch * beam_size, vocab_size)
sum_logprobs: Cumulative log probabilities, shape (batch * beam_size,)
Returns:
Updated tokens and completion flag
"""
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
n_audio = tokens.shape[0] // self.beam_size
if self.finished_sequences is None:
self.finished_sequences = [{} for _ in range(n_audio)]
logprobs = mx.softmax(logits, axis=-1)
logprobs = mx.log(logprobs + 1e-10)
logprobs_np = np.array(logprobs)
tokens_np = np.array(tokens)
sum_logprobs_np = np.array(sum_logprobs)
next_tokens, source_indices, finished_sequences = [], [], []
new_sum_logprobs = []
for i in range(n_audio):
scores, sources, finished = {}, {}, {}
for j in range(self.beam_size):
idx = i * self.beam_size + j
prefix = tokens_np[idx].tolist()
top_k_indices = np.argsort(logprobs_np[idx])[-self.beam_size - 1:][::-1]
for token_idx in top_k_indices:
logprob = logprobs_np[idx, token_idx]
new_logprob = sum_logprobs_np[idx] + logprob
sequence = tuple(prefix + [int(token_idx)])
scores[sequence] = new_logprob
sources[sequence] = idx
saved = 0
for sequence in sorted(scores, key=scores.get, reverse=True):
if sequence[-1] == self.eot:
finished[sequence] = scores[sequence]
else:
new_sum_logprobs.append(scores[sequence])
next_tokens.append(sequence)
source_indices.append(sources[sequence])
saved += 1
if saved == self.beam_size:
break
finished_sequences.append(finished)
tokens = mx.array(np.array(next_tokens, dtype=np.int32))
sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32))
self.inference.rearrange_kv_cache(source_indices)
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(
self.finished_sequences, finished_sequences
):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break
previously_finished[seq] = newly_finished[seq]
completed = all(
len(sequences) >= self.max_candidates
for sequences in self.finished_sequences
)
return tokens, completed
def finalize(self, preceding_tokens: mx.array, sum_logprobs: mx.array):
"""Finalize beam search by selecting best sequences."""
preceding_tokens_np = np.array(preceding_tokens)
sum_logprobs_np = np.array(sum_logprobs)
n_audio = preceding_tokens_np.shape[0] // self.beam_size
tokens_list: List[List[int]] = [[] for _ in range(n_audio)]
sum_logprobs_list: List[float] = [0.0] * n_audio
for i, sequences in enumerate(self.finished_sequences):
if sequences:
best_seq = max(sequences, key=sequences.get)
tokens_list[i] = list(best_seq)
sum_logprobs_list[i] = sequences[best_seq]
else:
idx = i * self.beam_size
tokens_list[i] = preceding_tokens_np[idx].tolist() + [self.eot]
sum_logprobs_list[i] = float(sum_logprobs_np[idx])
max_len = max(len(t) for t in tokens_list)
for i, t in enumerate(tokens_list):
tokens_list[i] = t + [self.eot] * (max_len - len(t))
tokens = mx.array(np.array(tokens_list, dtype=np.int32))
return tokens, sum_logprobs_list
class MLXInference:
"""MLX inference wrapper for beam search KV cache management."""
def __init__(self, model, initial_token_length: int):
self.model = model
self.initial_token_length = initial_token_length
self.kv_cache = None
def rearrange_kv_cache(self, source_indices: List[int]):
"""Rearrange KV cache based on beam search source indices."""
if self.kv_cache is None:
return
if source_indices == list(range(len(source_indices))):
return
source_indices_mx = mx.array(source_indices, dtype=mx.int32)
new_cache = []
for layer_cache in self.kv_cache:
(k, v), (cross_k, cross_v) = layer_cache
new_k = k[source_indices_mx]
new_v = v[source_indices_mx]
new_cache.append(((new_k, new_v), (cross_k, cross_v)))
self.kv_cache = new_cache
def logits(
self,
tokens: mx.array,
audio_features: mx.array,
) -> Tuple[mx.array, List]:
"""Get logits from decoder with KV cache."""
logits, self.kv_cache, cross_qk = self.model.decoder(
tokens, audio_features, kv_cache=self.kv_cache
)
return logits, cross_qk
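A toy invocation of the greedy decoder above (shapes and token ids are assumptions; requires `mlx`):

```python
import mlx.core as mx

decoder = MLXGreedyDecoder(temperature=0.0, eot=50257)
tokens = mx.array([[50258]], dtype=mx.int32)  # batch of 1, a single SOT-like token
logits = mx.zeros((1, 51865))                 # fake vocab-sized logits for one step
sum_logprobs = mx.zeros((1,))
tokens, completed = decoder.update(tokens, logits, sum_logprobs)
# tokens gained one argmax token; completed is True only once the last token is EOT
```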

View File

@@ -0,0 +1,756 @@
"""
MLX whisper AlignAtt streaming decoder
"""
import logging
from time import time
from typing import Any, List, Optional, Tuple
import mlx.core as mx
import numpy as np
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper import DecodingOptions, tokenizer
from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND
from ..config import AlignAttConfig
from .decoder_state import MLXDecoderState
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
DEC_PAD = 50257
logger = logging.getLogger(__name__)
class MLXTokenBuffer: # should try to make it inherit from the classic simul whisper class
"""Token buffer for MLX-based decoding."""
def __init__(self, text="", tokenizer=None, prefix_token_ids=None):
self.text = text
self.prefix_token_ids = prefix_token_ids or []
self.tokenizer = tokenizer
self.pending_token_ids = []
def as_token_ids(self, tokenizer=None):
if tokenizer is None:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer is not set.")
return self.prefix_token_ids + tokenizer.encode(self.text)
def as_mlx_array(self) -> mx.array:
"""Return tokens as MLX array."""
tok_ids = self.as_token_ids()
return mx.array([tok_ids], dtype=mx.int32)
def as_mlx_array_beam(self, beam: int) -> mx.array:
"""Return tokens as MLX array repeated for beam search."""
t = self.as_mlx_array()
return mx.repeat(t, beam, axis=0)
def as_text(self):
return self.text
@staticmethod
def empty(*a, **kw):
return MLXTokenBuffer(*a, **kw)
@staticmethod
def from_text(text, *a, **kw):
return MLXTokenBuffer(*a, text=text, **kw)
def is_empty(self):
return self.text is None or self.text == ""
def trim_words(self, num=1, after=0):
"""Trim words from the beginning of the context."""
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text[after:])
words, wids = self.tokenizer.split_to_word_tokens(ids)
if not words:
return 0
self.text = self.text[:after] + "".join(words[num:])
return sum(len(wi) for wi in wids[:num])
def append_token_ids(self, token_ids):
"""Append token IDs to the buffer, handling incomplete UTF-8."""
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
all_tokens = self.pending_token_ids + token_ids
decoded = tokenizer.decode(all_tokens)
replacement_char = "\ufffd"
if replacement_char in decoded:
if len(all_tokens) > 1:
decoded_partial = tokenizer.decode(all_tokens[:-1])
if replacement_char not in decoded_partial:
self.text += decoded_partial
self.pending_token_ids = [all_tokens[-1]]
else:
self.pending_token_ids = all_tokens
else:
self.pending_token_ids = all_tokens
else:
self.text += decoded
self.pending_token_ids = []
def mlx_median_filter(x: mx.array, filter_width: int) -> mx.array:
"""
Apply median filter along the last axis.
Args:
x: Input array of shape (..., T)
filter_width: Width of the median filter (should be odd)
Returns:
Filtered array of same shape
"""
if filter_width <= 1:
return x
pad_width = filter_width // 2
shape = x.shape
left_pad = mx.repeat(x[..., :1], pad_width, axis=-1)
right_pad = mx.repeat(x[..., -1:], pad_width, axis=-1)
x_padded = mx.concatenate([left_pad, x, right_pad], axis=-1)
result_shape = list(shape)
result = []
for i in range(shape[-1]):
window = x_padded[..., i:i + filter_width]
sorted_window = mx.sort(window, axis=-1)
median_val = sorted_window[..., filter_width // 2:filter_width // 2 + 1]
result.append(median_val)
return mx.concatenate(result, axis=-1)
class MLXAlignAtt:
"""
MLX-native Alignment-based Attention decoder for SimulStreaming.
This class runs entirely on MLX, with no PyTorch dependencies for inference.
"""
@property
def speaker(self):
return self.state.speaker
@speaker.setter
def speaker(self, value):
self.state.speaker = value
@property
def global_time_offset(self):
return self.state.global_time_offset
@global_time_offset.setter
def global_time_offset(self, value):
self.state.global_time_offset = value
def __init__(
self,
cfg: AlignAttConfig,
mlx_model: Any,
) -> None:
"""
Initialize MLX AlignAtt decoder.
Args:
cfg: AlignAtt configuration
mlx_model: MLX Whisper model (full model, not just encoder)
"""
self.model = mlx_model
self.cfg = cfg
logger.info(f"MLX Model dimensions: {self.model.dims}")
self.decode_options = DecodingOptions(
language=cfg.language,
without_timestamps=True,
task=cfg.task
)
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks)
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = self.cfg.max_context_tokens
# Initialize per-session state
self.state = MLXDecoderState()
self._init_state(cfg)
def _init_state(self, cfg: AlignAttConfig):
"""Initialize the per-session decoder state."""
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
self.state.tokenizer = self.tokenizer
self.state.detected_language = cfg.language if cfg.language != "auto" else None
self.state.global_time_offset = 0.0
self.state.last_attend_frame = -cfg.rewind_threshold
self.state.speaker = -1
if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
if cfg.never_fire:
self.state.never_fire = True
self.state.always_fire = False
else:
self.state.always_fire = True
self.state.never_fire = False
else:
logger.warning("CIF checkpoint provided but MLX CIF not implemented. Using always_fire=True")
self.state.always_fire = True
self.state.never_fire = cfg.never_fire
self._build_alignment_source()
suppress_tokens = [
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
self.tokenizer.no_timestamps,
] + list(self.tokenizer.all_language_tokens)
if self.tokenizer.no_speech is not None:
suppress_tokens.append(self.tokenizer.no_speech)
self.state.suppress_tokens = tuple(sorted(set(suppress_tokens)))
logger.debug(f"Suppress tokens: {self.state.suppress_tokens}")
self.init_tokens()
self.init_context()
self.state.decoder_type = cfg.decoder_type
if cfg.decoder_type == "greedy":
logger.info("Using MLX greedy decoder")
self.state.token_decoder = MLXGreedyDecoder(0.0, self.tokenizer.eot)
elif cfg.decoder_type == "beam":
logger.info("Using MLX beam decoder")
self.state.inference = MLXInference(self.model, self.state.initial_token_length)
self.state.token_decoder = MLXBeamSearchDecoder(
inference=self.state.inference,
eot=self.tokenizer.eot,
beam_size=cfg.beam_size
)
def _build_alignment_source(self):
"""Build alignment source mapping from model's alignment_heads."""
self.state.align_source = {}
self.state.num_align_heads = 0
alignment_heads = self.model.alignment_heads
if alignment_heads is None:
logger.warning("No alignment heads found in model")
return
if hasattr(alignment_heads, 'tolist'):
heads_list = alignment_heads.tolist()
else:
heads_list = np.array(alignment_heads).tolist()
for layer_rank, head_id in heads_list:
layer_rank = int(layer_rank)
head_id = int(head_id)
heads = self.state.align_source.get(layer_rank, [])
heads.append((self.state.num_align_heads, head_id))
self.state.align_source[layer_rank] = heads
self.state.num_align_heads += 1
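        # e.g. alignment_heads [[2, 3], [4, 1]] yields
        # align_source = {2: [(0, 3)], 4: [(1, 1)]} and num_align_heads = 2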
def warmup(self, audio: np.ndarray):
"""Warmup the model with sample audio."""
try:
self.insert_audio(audio)
self.infer(is_last=True)
self.refresh_segment(complete=True)
logger.info("MLX model warmed up successfully")
except Exception as e:
logger.exception(f"MLX model warmup failed: {e}")
def create_tokenizer(self, language=None):
"""Create tokenizer for the given language."""
self.tokenizer = tokenizer.get_tokenizer(
multilingual=self.tokenizer_is_multilingual,
language=language,
num_languages=self.model.num_languages,
task=self.decode_options.task
)
self.state.tokenizer = self.tokenizer
def init_context(self):
"""Initialize context buffer."""
kw = {
'tokenizer': self.tokenizer,
'prefix_token_ids': [self.tokenizer.sot_prev]
}
self.state.context = MLXTokenBuffer.empty(**kw)
if self.cfg.static_init_prompt is not None:
self.state.context = MLXTokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
if self.cfg.init_prompt is not None:
self.state.context.text += self.cfg.init_prompt
def init_tokens(self):
"""Initialize token sequence."""
logger.debug(f"init tokens, {len(self.state.segments)}")
self.state.initial_tokens = mx.array(
[self.tokenizer.sot_sequence_including_notimestamps],
dtype=mx.int32
)
self.state.initial_token_length = self.state.initial_tokens.shape[1]
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
logger.debug(f"init tokens after, {len(self.state.segments)}")
self.state.tokens = [self.state.initial_tokens]
def trim_context(self):
"""Trim context if too long."""
logger.info("Trimming context")
c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
logger.info(f"Context text: {self.state.context.as_text()}")
l = sum(t.shape[1] for t in self.state.tokens) + c
if self.cfg.static_init_prompt is None:
after = 0
else:
after = len(self.cfg.static_init_prompt)
while c > self.max_context_tokens or l > self.max_text_len - 20:
t = self.state.context.trim_words(after=after)
l -= t
c -= t
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
if t == 0:
break
logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
def refresh_segment(self, complete=False):
"""Refresh segment state."""
logger.debug("Refreshing segment:")
self.init_tokens()
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_context()
logger.debug(f"Context: {self.state.context}")
if not complete and len(self.state.segments) > 2:
self.state.segments = self.state.segments[-2:]
else:
logger.debug("removing all segments.")
self.state.segments = []
self.state.log_segments += 1
self.state.pending_incomplete_tokens = []
def fire_at_boundary(self, chunked_encoder_feature: mx.array) -> bool:
"""Check if we should fire at word boundary (CIF-based)."""
if self.state.always_fire:
return True
if self.state.never_fire:
return False
return True
def _current_tokens(self) -> mx.array:
"""Get current token sequence for decoding."""
toks = self.state.tokens
if toks[0].shape[0] == 1:
toks[0] = mx.repeat(toks[0], self.cfg.beam_size, axis=0)
if not self.state.context.is_empty():
context_toks = self.state.context.as_mlx_array_beam(self.cfg.beam_size)
toks = [context_toks] + toks
# Concatenate all tokens
if len(toks) > 1:
current_tokens = mx.concatenate(toks, axis=1)
else:
current_tokens = toks[0]
logger.debug("debug print current_tokens:")
self.debug_print_tokens(current_tokens)
return current_tokens
def debug_print_tokens(self, tokens: mx.array):
"""Debug print token sequences."""
tokens_np = np.array(tokens)
for i in range(min(self.cfg.beam_size, tokens_np.shape[0])):
logger.debug(self.tokenizer.decode_with_timestamps(tokens_np[i].tolist()))
def segments_len(self) -> float:
"""Get total length of audio segments in seconds."""
return sum(s.shape[0] for s in self.state.segments) / 16000
def _apply_minseglen(self) -> bool:
"""Check if we have enough audio to process."""
segments_len = self.segments_len()
if segments_len < self.cfg.audio_min_len:
logger.debug("waiting for next segment")
return False
return True
def insert_audio(self, segment: np.ndarray = None):
"""Insert audio segment into buffer."""
if segment is not None:
if hasattr(segment, 'numpy'):
segment = segment.numpy()
self.state.segments.append(segment)
removed_len = 0
segments_len = self.segments_len()
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.state.segments[0].shape[0] / 16000
segments_len -= removed_len
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
self.state.cumulative_time_offset += removed_len
self.state.segments = self.state.segments[1:]
logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s")
if len(self.state.tokens) > 1:
# Convert MLX array to list for context
token_list = np.array(self.state.tokens[1][0, :]).tolist()
self.state.context.append_token_ids(token_list)
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
return removed_len
def _clean_cache(self):
"""Clean the kv_cache after each inference step."""
self.state.clean_cache()
def _suppress_tokens(self, logits: mx.array) -> mx.array:
"""Apply token suppression to logits."""
if self.state.suppress_tokens:
suppress_indices = mx.array(list(self.state.suppress_tokens), dtype=mx.int32)
logits = logits.at[:, suppress_indices].add(-float('inf'))
return logits
def lang_id(self, encoder_features: mx.array) -> Tuple[mx.array, List[dict]]:
"""Language detection from encoder features."""
n_audio = encoder_features.shape[0]
x = mx.array([[self.tokenizer.sot]] * n_audio, dtype=mx.int32)
logits, _, _ = self.model.decoder(x, encoder_features, kv_cache=None)
logits = logits[:, 0]
        language_token_indices = mx.array(list(self.tokenizer.all_language_tokens), dtype=mx.int32)
        mask = mx.ones(logits.shape[-1], dtype=mx.bool_)
        mask[language_token_indices] = False  # keep only the language-token logits
        logits = mx.where(mask, mx.array(-float('inf')), logits)
language_tokens = mx.argmax(logits, axis=-1)
language_token_probs = mx.softmax(logits, axis=-1)
probs_np = np.array(language_token_probs)
language_probs = [
{
c: float(probs_np[i, j])
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
}
for i in range(n_audio)
]
self._clean_cache()
return language_tokens, language_probs
def infer(self, is_last: bool = False) -> List[ASRToken]:
"""
Main inference method.
Args:
is_last: Whether this is the final chunk
Returns:
List of timestamped ASR tokens
"""
new_segment = True
if len(self.state.segments) == 0:
logger.debug("No segments, nothing to do")
return []
if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
return []
if len(self.state.segments) > 1:
input_segments = np.concatenate(self.state.segments, axis=0)
else:
input_segments = self.state.segments[0]
beg_encode = time()
mlx_mel_padded = mlx_log_mel_spectrogram(
audio=input_segments,
n_mels=self.model.dims.n_mels,
padding=N_SAMPLES
)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
encoder_feature = self.model.encoder(mlx_mel[None])
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2)
mx.eval(encoder_feature)
end_encode = time()
logger.debug(f'MLX Encoder duration: {end_encode - beg_encode:.3f}s')
if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp:
seconds_since_start = self.segments_len() - self.state.first_timestamp
if seconds_since_start >= 2.0:
language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan)
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_tokens()
self.init_context()
self.state.detected_language = top_lan
logger.info(f"Tokenizer language: {self.tokenizer.language}")
self.trim_context()
current_tokens = self._current_tokens()
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
sum_logprobs = mx.zeros((self.cfg.beam_size,), dtype=mx.float32)
completed = False
attn_of_alignment_heads = None
most_attended_frame = None
token_len_before_decoding = current_tokens.shape[1]
l_absolute_timestamps = []
accumulated_cross_attns = []
audio_duration_s = self.segments_len()
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
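        # Worked example (illustrative): a 4 s chunk permits
        # max(50, int(4 * 15 * 1.5)) = 90 decoding steps, versus
        # max(50, int(4 * 50 * 2.0)) = 400 under the old mel-frame-rate formula.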
tokens_produced_this_chunk = 0
while not completed and current_tokens.shape[1] < self.max_text_len:
tokens_produced_this_chunk += 1
if tokens_produced_this_chunk > max_tokens_per_chunk:
logger.warning(f"[Loop Detection] Too many tokens ({tokens_produced_this_chunk}) for {audio_duration_s:.2f}s audio. Breaking.")
current_tokens = current_tokens[:, :token_len_before_decoding]
break
if new_segment:
tokens_for_logits = current_tokens
else:
tokens_for_logits = current_tokens[:, -1:]
if self.state.decoder_type == "greedy":
logits, self.state.kv_cache, cross_qk = self.model.decoder(
tokens_for_logits, encoder_feature, kv_cache=self.state.kv_cache
)
else:
logits, cross_qk = self.state.inference.logits(tokens_for_logits, encoder_feature)
mx.eval(logits)
accumulated_cross_attns.append(cross_qk)
if len(accumulated_cross_attns) > 16:
accumulated_cross_attns = accumulated_cross_attns[-16:]
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
no_speech_probs = np.array(probs_at_sot[:, self.tokenizer.no_speech]).tolist()
if no_speech_probs[0] > self.cfg.nonspeech_prob:
logger.info("no speech, stop")
break
logits = logits[:, -1, :] # Last token logits
# Suppress tokens at segment start
if new_segment:
                blank_tokens = mx.array(self.tokenizer.encode(" ") + [self.tokenizer.eot], dtype=mx.int32)
                logits = logits.at[:, blank_tokens].add(-float('inf'))
new_segment = False
logits = self._suppress_tokens(logits)
current_tokens, completed = self.state.token_decoder.update(
current_tokens, logits, sum_logprobs
)
mx.eval(current_tokens)
logger.debug(f"Decoding completed: {completed}")
self.debug_print_tokens(current_tokens)
attn_of_alignment_heads = self._process_cross_attention(
accumulated_cross_attns, content_mel_len
)
most_attended_frames = mx.argmax(attn_of_alignment_heads[:, -1, :], axis=-1)
most_attended_frames_np = np.array(most_attended_frames)
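            # Each encoder frame spans 0.02 s (the 50 Hz mel-frame rate);
            # cumulative_time_offset restores audio trimmed from the buffer head.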
absolute_timestamps = [
(frame * 0.02 + self.state.cumulative_time_offset)
for frame in most_attended_frames_np.tolist()
]
logger.debug(str(most_attended_frames_np.tolist()) + " most att frames")
logger.debug(f"Absolute timestamps: {absolute_timestamps}")
most_attended_frame = int(most_attended_frames_np[0])
l_absolute_timestamps.append(absolute_timestamps[0])
if completed:
current_tokens = current_tokens[:, :-1]
break
if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
current_tokens_np = np.array(current_tokens)
if current_tokens.shape[1] > 1 and current_tokens_np[0, -2] >= DEC_PAD:
logger.debug("omit rewinding from special tokens")
self.state.last_attend_frame = most_attended_frame
else:
logger.debug(f"[rewind detected] current: {most_attended_frame}, last: {self.state.last_attend_frame}")
self.state.last_attend_frame = -self.cfg.rewind_threshold
                    current_tokens = mx.concatenate(self.state.tokens, axis=1)  # state.tokens always holds at least the initial tokens
break
else:
self.state.last_attend_frame = most_attended_frame
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
current_tokens = current_tokens[:, :-1]
break
tokens_to_split = np.array(current_tokens[0, token_len_before_decoding:]).tolist()
if self.state.pending_incomplete_tokens:
logger.debug(f"[UTF-8 Fix] Prepending pending tokens: {self.state.pending_incomplete_tokens}")
tokens_to_split = self.state.pending_incomplete_tokens + tokens_to_split
if fire_detected or is_last:
new_hypothesis = tokens_to_split
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
else:
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split)
if len(split_words) > 1:
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
else:
new_hypothesis = []
logger.debug(f"new_hypothesis: {new_hypothesis}")
new_tokens = mx.array([new_hypothesis], dtype=mx.int32)
new_tokens = mx.repeat(new_tokens, self.cfg.beam_size, axis=0)
self.state.tokens.append(new_tokens)
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
self._clean_cache()
if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
self.state.first_timestamp = l_absolute_timestamps[0]
timestamped_words = []
timestamp_idx = 0
replacement_char = "\ufffd"
for word, word_tokens in zip(split_words, split_tokens):
if replacement_char in word:
logger.warning(f"[UTF-8 Filter] Skipping: {repr(word)}")
timestamp_idx += len(word_tokens)
continue
            try:
                current_timestamp = l_absolute_timestamps[timestamp_idx]
            except IndexError:
                # Fall back to the last timestamp, as in the torch variant
                logger.warning(f"Timestamp index {timestamp_idx} out of range, using last timestamp")
                current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken(
start=round(current_timestamp, 2),
end=round(current_timestamp + 0.1, 2),
text=word,
speaker=self.state.speaker,
detected_language=self.state.detected_language
).with_offset(self.state.global_time_offset)
timestamped_words.append(timestamp_entry)
self.state.pending_incomplete_tokens = []
MAX_PENDING_TOKENS = 10
        if split_words and replacement_char in split_words[-1]:
            if len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
                self.state.pending_incomplete_tokens = split_tokens[-1]
                logger.debug(f"[UTF-8 Fix] Holding {len(split_tokens[-1])} incomplete tokens for the next chunk")
            else:
                logger.warning(f"[UTF-8 Fix] Dropping {len(split_tokens[-1])} incomplete tokens: exceeds MAX_PENDING_TOKENS={MAX_PENDING_TOKENS}")
return timestamped_words
def _process_cross_attention(
self,
cross_attns: List[List[mx.array]],
content_mel_len: int
) -> mx.array:
"""
Process cross-attention weights for alignment.
Args:
cross_attns: List of cross-attention from each forward pass
Each element is a list of mx.arrays per layer
content_mel_len: Length of actual audio content
Returns:
Processed attention tensor, shape (batch, seq_len, content_mel_len)
"""
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
num_decoder_layers = self.num_decoder_layers
if cross_attns and isinstance(cross_attns[0], list):
flattened_attns = [attn for layer_list in cross_attns for attn in layer_list]
else:
flattened_attns = cross_attns
for idx, attn_mat in enumerate(flattened_attns):
if attn_mat is None:
continue
layer_rank = idx % num_decoder_layers
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
if len(align_heads_in_layer) == 0:
continue
attn_mat = mx.softmax(attn_mat, axis=-1)
for align_head_rank, head_id in align_heads_in_layer:
if self.cfg.beam_size == 1:
if attn_mat.ndim == 4:
a = attn_mat[0, head_id, :, :]
else:
a = attn_mat[head_id, :, :]
a = a[None, :, :]
else:
a = attn_mat[:, head_id, :, :]
attn_of_alignment_heads[align_head_rank].append(a)
tmp = []
for mat in attn_of_alignment_heads:
if mat:
t = mx.concatenate(mat, axis=1)
tmp.append(t)
if not tmp:
return mx.zeros((self.cfg.beam_size, 1, content_mel_len))
attn_of_alignment_heads = mx.stack(tmp, axis=1)
std = mx.std(attn_of_alignment_heads, axis=-2, keepdims=True)
mean = mx.mean(attn_of_alignment_heads, axis=-2, keepdims=True)
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
attn_of_alignment_heads = mlx_median_filter(attn_of_alignment_heads, 7)
attn_of_alignment_heads = mx.mean(attn_of_alignment_heads, axis=1)
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
mx.eval(attn_of_alignment_heads)
return attn_of_alignment_heads
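
# Minimal driving sketch (mirrors warmup(); constructing `cfg` and `mlx_model`
# is assumed to happen in the surrounding backend):
#
#     decoder = MLXAlignAtt(cfg, mlx_model)
#     decoder.insert_audio(chunk)            # np.float32 mono at 16 kHz
#     tokens = decoder.infer(is_last=False)  # -> list[ASRToken]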

View File

@@ -69,3 +69,39 @@ def load_mlx_encoder(
     model.update(encoder_weights)
     mx.eval(model.parameters())
     return model
+
+
+def load_mlx_model(
+    path_or_hf_repo: str,
+    dtype: mx.Dtype = mx.float32,
+) -> whisper.Whisper:
+    model_path = Path(path_or_hf_repo)
+    if not model_path.exists():
+        model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
+
+    with open(str(model_path / "config.json"), "r") as f:
+        config = json.loads(f.read())
+        config.pop("model_type", None)
+        quantization = config.pop("quantization", None)
+        model_args = whisper.ModelDimensions(**config)
+
+    wf = model_path / "weights.safetensors"
+    if not wf.exists():
+        wf = model_path / "weights.npz"
+    weights = mx.load(str(wf))
+
+    model = whisper.Whisper(model_args, dtype)
+
+    if quantization is not None:
+        class_predicate = (
+            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
+            and f"{p}.scales" in weights
+        )
+        nn.quantize(model, **quantization, class_predicate=class_predicate)
+
+    weights = tree_unflatten(list(weights.items()))
+    model.update(weights)
+    mx.eval(model.parameters())
+    return model

View File

@@ -390,7 +390,6 @@ class AlignAtt:
             return []
         if not self._apply_minseglen():
             logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
-            input_segments = torch.cat(self.state.segments, dim=0)
             return []

         # input_segments is concatenation of audio, it's one array
@@ -485,7 +484,9 @@ class AlignAtt:
         accumulated_cross_attns = []
         audio_duration_s = self.segments_len()
-        max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0))  # 2x margin, min 50
+        # ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
+        # is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
+        max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
         tokens_produced_this_chunk = 0

         while not completed and current_tokens.shape[1] < self.max_text_len:  # bos is 3 tokens
@@ -506,8 +507,12 @@ class AlignAtt:
                 result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
                 logits, cross_attns = result
-                # Accumulate cross-attention from this forward pass
+                # Accumulate cross-attention from this forward pass (rolling window to
+                # bound VRAM — only the last entry matters for alignment, and the
+                # median_filter kernel is 7, so 16 entries is more than enough).
                 accumulated_cross_attns.append(cross_attns)
+                if len(accumulated_cross_attns) > 16:
+                    accumulated_cross_attns = accumulated_cross_attns[-16:]

             if new_segment and self.tokenizer.no_speech is not None:
                 probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
@@ -626,8 +631,10 @@ class AlignAtt:
                 try:
                     current_timestamp = l_absolute_timestamps[timestamp_idx]
-                except:
-                    pass
+                except IndexError:
+                    # Use last timestamp if index out of range
+                    logger.warning(f"Timestamp index {timestamp_idx} out of range, using last timestamp")
+                    current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
                 timestamp_idx += len(word_tokens)
                 timestamp_entry = ASRToken(

View File

@@ -0,0 +1,139 @@
"""
Thread Safety Configuration for WhisperLiveKit
This module provides thread safety configuration and utilities.
Environment Variables:
WHISPERLIVEKIT_MODEL_LOCK: Enable/disable model locking (default: 1)
Set to "0" to disable for single-connection deployments
WHISPERLIVEKIT_LOCK_TIMEOUT: Lock acquisition timeout in seconds (default: 30)
Usage:
# Enable model locking (default)
export WHISPERLIVEKIT_MODEL_LOCK=1
# Disable for single-connection deployment
export WHISPERLIVEKIT_MODEL_LOCK=0
# Custom timeout
export WHISPERLIVEKIT_LOCK_TIMEOUT=60
"""
import os
import logging
import threading
logger = logging.getLogger(__name__)
# Configuration
USE_MODEL_LOCK = os.environ.get("WHISPERLIVEKIT_MODEL_LOCK", "1") == "1"
LOCK_TIMEOUT = float(os.environ.get("WHISPERLIVEKIT_LOCK_TIMEOUT", "30.0"))
# Global model lock
_model_lock = threading.Lock()
# Log configuration on import
if USE_MODEL_LOCK:
logger.info(f"Model locking ENABLED (timeout: {LOCK_TIMEOUT}s)")
logger.info("For single-connection deployments, set WHISPERLIVEKIT_MODEL_LOCK=0")
else:
logger.warning("Model locking DISABLED - only safe for single-connection deployments")
def get_model_lock():
"""Get the global model lock instance"""
return _model_lock
def acquire_model_lock(timeout=None):
"""
Acquire model lock with timeout.
Args:
timeout: Lock acquisition timeout (default: use LOCK_TIMEOUT)
Returns:
bool: True if lock acquired, False on timeout
"""
if not USE_MODEL_LOCK:
return True
    timeout = LOCK_TIMEOUT if timeout is None else timeout  # "or" would override an explicit 0
acquired = _model_lock.acquire(timeout=timeout)
if not acquired:
logger.error(f"Failed to acquire model lock within {timeout}s")
return acquired
def release_model_lock():
"""Release model lock"""
if not USE_MODEL_LOCK:
return
try:
_model_lock.release()
except RuntimeError:
# Lock not held - this is fine
pass
class ModelLockContext:
"""Context manager for model lock"""
def __init__(self, timeout=None):
self.timeout = timeout
self.acquired = False
def __enter__(self):
self.acquired = acquire_model_lock(self.timeout)
return self.acquired
def __exit__(self, exc_type, exc_val, exc_tb):
if self.acquired:
release_model_lock()
return False
# Concurrency recommendations
RECOMMENDED_CONNECTIONS_PER_WORKER = 1  # inference is serialized per worker either way
RECOMMENDED_WORKERS = 4
def print_deployment_recommendations():
"""Print recommended deployment configuration"""
print("\n" + "="*60)
print("WhisperLiveKit Deployment Recommendations")
print("="*60)
if USE_MODEL_LOCK:
print("⚠️ Model locking is ENABLED")
print(" This serializes inference across connections.")
print()
print("Recommended deployment:")
print(f" gunicorn -w {RECOMMENDED_WORKERS} \\")
print(" -k uvicorn.workers.UvicornWorker \\")
print(" --worker-connections 1 \\")
print(" whisperlivekit.basic_server:app")
print()
print("Expected capacity:")
print(f" - {RECOMMENDED_WORKERS} concurrent users (1 per worker)")
print(f" - Memory: ~{RECOMMENDED_WORKERS}x model size")
else:
print("✅ Model locking is DISABLED")
print(" ⚠️ ONLY safe for single-connection deployments")
print()
print("Recommended deployment:")
print(" uvicorn whisperlivekit.basic_server:app \\")
print(" --host 0.0.0.0 --port 8000 \\")
print(" --workers 1")
print()
print("Expected capacity:")
print(" - 1 concurrent user only")
print("="*60 + "\n")
if __name__ == "__main__":
print_deployment_recommendations()
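
# Minimal usage sketch (illustrative; `run_inference` is a hypothetical
# stand-in for whatever model call needs serializing):
#
#     with ModelLockContext(timeout=10) as acquired:
#         if acquired:
#             run_inference(audio_chunk)
#         else:
#             logger.error("Could not acquire model lock; server busy")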

View File

@@ -39,10 +39,11 @@ class TimedText(Timed):
 @dataclass()
 class ASRToken(TimedText):
+    probability: Optional[float] = None
     def with_offset(self, offset: float) -> "ASRToken":
         """Return a new token with the time offset added."""
-        return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
+        return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language, probability=self.probability)

     def is_silence(self) -> bool:
         return False

View File

@@ -49,9 +49,12 @@ class TokensAlignment:
     def add_translation(self, segment: Segment) -> None:
         """Append translated text segments that overlap with a segment."""
+        if segment.translation is None:
+            segment.translation = ''
         for ts in self.all_translation_segments:
             if ts.is_within(segment):
-                segment.translation += ts.text + (self.sep if ts.text else '')
+                if ts.text:
+                    segment.translation += ts.text + self.sep
             elif segment.translation:
                 break
@@ -183,9 +186,9 @@ class TokensAlignment:
         else:
             diarization_buffer = ''
         for token in self.new_tokens:
-            if token.is_silence():
+            if isinstance(token, Silence):
                 if self.current_line_tokens:
-                    self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
+                    self.validated_segments.append(Segment.from_tokens(self.current_line_tokens))
                     self.current_line_tokens = []
                 end_silence = token.end if token.has_ended else time() - self.beg_loop
@@ -201,7 +204,7 @@ class TokensAlignment:
         segments = list(self.validated_segments)
         if self.current_line_tokens:
-            segments.append(Segment().from_tokens(self.current_line_tokens))
+            segments.append(Segment.from_tokens(self.current_line_tokens))
         if current_silence:
             end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop

View File

@@ -0,0 +1,484 @@
"""
Voxtral Mini Realtime streaming backend using voxmlx's incremental encode/decode.
Uses model.encode_step() for incremental audio encoding and token-by-token
autoregressive decoding, matching voxmlx's native streaming pipeline.
"""
import logging
import sys
import time
from typing import List, Optional, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Transcript
logger = logging.getLogger(__name__)
N_LEFT_PAD_TOKENS = 32
N_RIGHT_PAD_TOKENS = 17
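# At 80 ms of audio per decode token, these pads correspond to
# 2.56 s of leading and 1.36 s of trailing silence around the utterance.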
class VoxtralStreamingASR:
"""Voxtral model holder for the streaming pipeline."""
sep = " "
def __init__(self, logfile=sys.stderr, **kwargs):
from voxmlx import _build_prompt_tokens
from voxmlx import load_model as vox_load_model
self.logfile = logfile
self.transcribe_kargs = {}
lan = kwargs.get("lan", "auto")
self.original_language = None if lan == "auto" else lan
DEFAULT_MODEL = "mlx-community/Voxtral-Mini-4B-Realtime-6bit"
model_path = kwargs.get("model_dir") or kwargs.get("model_path")
if not model_path:
model_size = kwargs.get("model_size", "")
# Only use model_size if it looks like a HF repo or a path, not a Whisper size name
if model_size and ("/" in model_size or model_size.startswith(".")):
model_path = model_size
else:
model_path = DEFAULT_MODEL
t = time.time()
logger.info(f"Loading Voxtral model '{model_path}' via voxmlx...")
self.model, self._tokenizer, self._config = vox_load_model(model_path)
self._prompt_tokens, self._n_delay_tokens = _build_prompt_tokens(
self._tokenizer
)
logger.info(f"Voxtral model loaded in {time.time() - t:.2f}s")
self.backend_choice = "voxtral-mlx"
self.tokenizer = None # sentence tokenizer — not needed for streaming
    def transcribe(self, audio):
        """Unused stub; streaming runs through VoxtralStreamingOnlineProcessor."""
class VoxtralStreamingOnlineProcessor:
"""
Online processor for Voxtral streaming ASR.
Uses voxmlx's incremental encoding (encode_step) and token-by-token
autoregressive decoding. Each decode step corresponds to 80ms of audio.
"""
SAMPLING_RATE = 16000
def __init__(self, asr: VoxtralStreamingASR, logfile=sys.stderr):
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.buffer = []
self.audio_buffer = np.array([], dtype=np.float32) # for logging compat
self._special_token_policy = SpecialTokenPolicy.IGNORE
self._reset_state()
logger.info(
f"[voxtral] Initialized. eos_id={asr._tokenizer.eos_id}, "
f"prefix_len={len(asr._prompt_tokens)}, "
f"n_delay={asr._n_delay_tokens}"
)
def _reset_state(self):
from voxmlx.audio import SAMPLES_PER_TOKEN
self._samples_per_token = SAMPLES_PER_TOKEN
# Incremental encoder state
self._audio_tail = None
self._conv1_tail = None
self._conv2_tail = None
self._encoder_cache = None
self._ds_buf = None
# Decoder state
self._decoder_cache = None
self._y = None # last sampled token (mx.array scalar)
self._t_cond = None
self._text_embeds = None
# Audio / decode tracking
self._pending_audio = np.zeros(0, dtype=np.float32)
self._audio_embeds = None
self._n_audio_samples_fed = 0
self._n_total_decoded = 0
self._first_cycle = True
self._prefilled = False
# Word extraction: accumulate token IDs, full-sequence decode for correct spacing
self._output_token_ids: List[int] = []
self._token_positions: List[int] = [] # decode position for each token
self._n_committed_words = 0
self._global_time_offset = 0.0
self._y_flushed_to_output = False # True after start_silence flushes pending _y
# ── Interface methods (same as SimulStreamingOnlineProcessor) ──
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
self.end = audio_stream_end_time
self._pending_audio = np.append(self._pending_audio, audio)
self.audio_buffer = self._pending_audio # for logging compat
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
try:
return self._process_iter_inner(is_last)
except Exception as e:
logger.warning(f"[voxtral] process_iter exception: {e}", exc_info=True)
return [], self.end
def _get_full_text(self) -> str:
"""Decode all accumulated token IDs at once for correct spacing."""
if not self._output_token_ids:
return ""
sp = self.asr._tokenizer
return sp.decode(self._output_token_ids, special_token_policy=self._special_token_policy)
def get_buffer(self) -> Transcript:
"""Return all uncommitted text as buffer, including pending _y token."""
# Temporarily include pending _y for buffer display
ids = list(self._output_token_ids)
if self._y is not None and not self._y_flushed_to_output:
sp = self.asr._tokenizer
token_id = self._y.item()
if token_id != sp.eos_id:
ids.append(token_id)
if not ids:
return Transcript(start=None, end=None, text="")
sp = self.asr._tokenizer
full_text = sp.decode(ids, special_token_policy=self._special_token_policy)
words = full_text.split()
uncommitted = words[self._n_committed_words:]
if uncommitted:
text = " ".join(uncommitted)
return Transcript(start=self.end, end=self.end, text=text)
return Transcript(start=None, end=None, text="")
def start_silence(self) -> Tuple[List[ASRToken], float]:
"""Flush all uncommitted words when silence starts."""
self._flush_last_y() # Include the pending _y token before flushing
words = self._flush_all_pending_words()
logger.info(f"[voxtral] start_silence: flushed {len(words)} words")
return words, self.end
def end_silence(self, silence_duration: float, offset: float):
self._global_time_offset += silence_duration
self.end += silence_duration
def new_speaker(self, change_speaker):
self.start_silence()
def warmup(self, audio, init_prompt=""):
pass
def finish(self) -> Tuple[List[ASRToken], float]:
"""Flush remaining audio with right-padding to let the model finish decoding."""
right_pad = np.zeros(
N_RIGHT_PAD_TOKENS * self._samples_per_token, dtype=np.float32
)
self._pending_audio = np.append(self._pending_audio, right_pad)
self._n_audio_samples_fed += len(right_pad)
final_words, _ = self._process_iter_inner(is_last=True)
# Flush the last pending self._y token (like voxmlx's finally block)
self._flush_last_y()
final_words.extend(self._flush_all_pending_words())
return final_words, self.end
# ── Word extraction ──
def _pos_to_time(self, pos: int) -> float:
"""Convert a decode position to seconds relative to audio start."""
SPT = self._samples_per_token
return max(0.0, (pos - N_LEFT_PAD_TOKENS) * SPT / self.SAMPLING_RATE)
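        # e.g. with 80 ms per token (SPT = 1280 samples at 16 kHz),
        # pos = 57 maps to (57 - 32) * 1280 / 16000 = 2.0 s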
def _flush_last_y(self):
"""Flush the last pending self._y token that hasn't been processed yet."""
if self._y is None or self._y_flushed_to_output:
return
sp = self.asr._tokenizer
token_id = self._y.item()
if token_id != sp.eos_id:
self._output_token_ids.append(token_id)
self._token_positions.append(self._n_total_decoded)
self._y_flushed_to_output = True
def _extract_new_words(self) -> List[ASRToken]:
"""
Split accumulated text into words and return new complete words
(all but the last, which may still be growing).
"""
if not self._output_token_ids:
return []
full_text = self._get_full_text()
words = full_text.split()
new_words: List[ASRToken] = []
n_tokens = len(self._output_token_ids)
# All words except the last are guaranteed complete
while len(words) > self._n_committed_words + 1:
word = words[self._n_committed_words]
word_idx = self._n_committed_words
n_words_total = len(words)
# Approximate: assign token range proportionally
tok_start = int(word_idx / n_words_total * n_tokens)
tok_end = int((word_idx + 1) / n_words_total * n_tokens)
tok_start = min(tok_start, len(self._token_positions) - 1)
tok_end = min(tok_end, len(self._token_positions) - 1)
start_time = self._pos_to_time(self._token_positions[tok_start]) + self._global_time_offset
end_time = self._pos_to_time(self._token_positions[tok_end]) + self._global_time_offset
# Prepend space to match Whisper convention (Segment.from_tokens joins with '')
text = word if self._n_committed_words == 0 else " " + word
new_words.append(ASRToken(start=start_time, end=end_time, text=text))
self._n_committed_words += 1
return new_words
def _flush_all_pending_words(self) -> List[ASRToken]:
"""Flush ALL words including the last partial one."""
if not self._output_token_ids:
return []
full_text = self._get_full_text()
words = full_text.split()
new_words: List[ASRToken] = []
n_tokens = len(self._output_token_ids)
n_words_total = max(len(words), 1)
while self._n_committed_words < len(words):
word = words[self._n_committed_words]
word_idx = self._n_committed_words
tok_start = int(word_idx / n_words_total * n_tokens)
tok_end = int((word_idx + 1) / n_words_total * n_tokens)
tok_start = min(tok_start, max(len(self._token_positions) - 1, 0))
tok_end = min(tok_end, max(len(self._token_positions) - 1, 0))
if self._token_positions:
start_time = self._pos_to_time(self._token_positions[tok_start]) + self._global_time_offset
end_time = self._pos_to_time(self._token_positions[tok_end]) + self._global_time_offset
else:
start_time = self._global_time_offset
end_time = self._global_time_offset
# Prepend space to match Whisper convention (Segment.from_tokens joins with '')
text = word if self._n_committed_words == 0 else " " + word
new_words.append(ASRToken(start=start_time, end=end_time, text=text))
self._n_committed_words += 1
return new_words
# ── Core streaming logic ──
def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]:
import mlx.core as mx
from voxmlx.audio import log_mel_spectrogram_step
from voxmlx.cache import RotatingKVCache
model = self.asr.model
sp = self.asr._tokenizer
prompt_tokens = self.asr._prompt_tokens
prefix_len = len(prompt_tokens)
SPT = self._samples_per_token
# ── Phase 1: Encode new audio ──
        if len(self._pending_audio) >= SPT:
            n_feed = (len(self._pending_audio) // SPT) * SPT
            chunk = self._pending_audio[:n_feed]
            if self._first_cycle:
                # Left-pad the very first chunk so the encoder sees the expected prefix
                left_pad = np.zeros(N_LEFT_PAD_TOKENS * SPT, dtype=np.float32)
                chunk = np.concatenate([left_pad, chunk])
            self._pending_audio = self._pending_audio[n_feed:]
            self._n_audio_samples_fed += n_feed
            mel, self._audio_tail = log_mel_spectrogram_step(
                chunk, self._audio_tail
            )
            (
                new_embeds,
                self._conv1_tail,
                self._conv2_tail,
                self._encoder_cache,
                self._ds_buf,
            ) = model.encode_step(
                mel,
                self._conv1_tail,
                self._conv2_tail,
                self._encoder_cache,
                self._ds_buf,
            )
            if new_embeds is not None:
                mx.eval(new_embeds)
                if self._audio_embeds is None:
                    self._audio_embeds = new_embeds
                else:
                    self._audio_embeds = mx.concatenate(
                        [self._audio_embeds, new_embeds]
                    )
                if self._first_cycle:
                    logger.info(f"[voxtral] first encode: {new_embeds.shape[0]} embeds from {n_feed} samples")
            elif self._first_cycle:
                logger.info(f"[voxtral] first encode: no embeds from {n_feed} samples")
            self._first_cycle = False
self.audio_buffer = self._pending_audio # for logging compat
if self._audio_embeds is None:
return [], self.end
# Safety: don't decode ahead of encoded audio
safe_total = (
N_LEFT_PAD_TOKENS + self._n_audio_samples_fed // SPT
)
n_decodable = min(
self._audio_embeds.shape[0], safe_total - self._n_total_decoded
)
if n_decodable <= 0:
return [], self.end
# ── Phase 2: Prefill (once per utterance) ──
if not self._prefilled:
if self._n_total_decoded + self._audio_embeds.shape[0] < prefix_len:
logger.info(
f"[voxtral] waiting for prefill: have {self._audio_embeds.shape[0]} embeds, need {prefix_len}"
)
return [], self.end
n_layers = len(model.language_model.layers)
self._decoder_cache = [RotatingKVCache(8192) for _ in range(n_layers)]
self._t_cond = model.time_embedding(
mx.array([self.asr._n_delay_tokens], dtype=mx.float32)
)
prompt_ids = mx.array([prompt_tokens])
self._text_embeds = model.language_model.embed(prompt_ids)[0]
prefix_embeds = (
self._text_embeds + self._audio_embeds[:prefix_len]
)[None, :, :]
logits = model.decode(
prefix_embeds, self._t_cond, "causal", self._decoder_cache
)
mx.eval(
logits,
*[x for c in self._decoder_cache for x in (c.keys, c.values)],
)
self._y = mx.argmax(logits[0, -1:], axis=-1).squeeze()
mx.async_eval(self._y)
self._audio_embeds = self._audio_embeds[prefix_len:]
self._n_total_decoded = prefix_len
self._prefilled = True
logger.info(f"[voxtral] prefill done, first token y={self._y.item()}")
n_decodable = min(
self._audio_embeds.shape[0], safe_total - self._n_total_decoded
)
if n_decodable <= 0:
return [], self.end
# ── Phase 3: Decode new positions ──
eos_id = sp.eos_id
hit_eos = False
n_consumed = 0
for i in range(n_decodable):
token_embed = model.language_model.embed(self._y.reshape(1, 1))[0, 0]
step_embed = (self._audio_embeds[i] + token_embed)[None, None, :]
logits = model.decode(
step_embed, self._t_cond, mask=None, cache=self._decoder_cache
)
next_y = mx.argmax(logits[0, -1:], axis=-1).squeeze()
mx.async_eval(next_y)
token_id = self._y.item()
n_consumed = i + 1
if token_id == eos_id:
hit_eos = True
logger.info("[voxtral] hit EOS")
break
# Accumulate token ID — full-sequence decode produces correct spacing
# Skip if this _y was already flushed by start_silence()
if self._y_flushed_to_output:
self._y_flushed_to_output = False
else:
self._output_token_ids.append(token_id)
# Track position for timestamp estimation
pos = self._n_total_decoded + i
self._token_positions.append(pos)
if i > 0 and i % 256 == 0:
mx.clear_cache()
self._y = next_y
self._n_total_decoded += n_consumed
# Trim consumed embeddings
if self._audio_embeds.shape[0] > n_consumed:
self._audio_embeds = self._audio_embeds[n_consumed:]
else:
self._audio_embeds = None
# Log decode results
full_text = self._get_full_text()
logger.info(
f"[voxtral] decoded {n_consumed} tokens | "
f"total_decoded={self._n_total_decoded} | "
f"text='{full_text[-80:]}' | "
f"n_words={len(full_text.split())} committed={self._n_committed_words}"
)
# Extract complete words from the decoded token sequence
new_words = self._extract_new_words()
if hit_eos:
new_words.extend(self._flush_all_pending_words())
self._reset_state()
if new_words:
logger.info(f"[voxtral] returning {len(new_words)} words: {[w.text for w in new_words]}")
self.buffer = []
return new_words, self.end
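
# Minimal driving-loop sketch (illustrative; the audio source and chunking
# are assumptions, not part of this file):
#
#     asr = VoxtralStreamingASR(lan="en")
#     proc = VoxtralStreamingOnlineProcessor(asr)
#     t = 0.0
#     for chunk in audio_chunks:  # np.float32 mono at 16 kHz
#         t += len(chunk) / proc.SAMPLING_RATE
#         proc.insert_audio_chunk(chunk, t)
#         words, end = proc.process_iter()
#     words, end = proc.finish()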

View File

@@ -108,7 +108,7 @@ def available_models() -> List[str]:
 def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
     """
     attempt to infer ModelDimensions from a HF style config.json located
-    next to the given checkpoint, usefull for distilled models
+    next to the given checkpoint, usefull for distilled models/MLX models.
     """
     candidates = []
     if os.path.isdir(path):
@@ -122,6 +122,25 @@ def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
         with open(candidate, "r", encoding="utf-8") as f:
             config = json.load(f)
+        # native Whisper format
+        native_keys = ["n_mels", "n_audio_ctx", "n_audio_state", "n_audio_head",
+                       "n_audio_layer", "n_vocab", "n_text_ctx", "n_text_state",
+                       "n_text_head", "n_text_layer"]
+        if all(k in config for k in native_keys):
+            return ModelDimensions(
+                n_mels=config["n_mels"],
+                n_audio_ctx=config["n_audio_ctx"],
+                n_audio_state=config["n_audio_state"],
+                n_audio_head=config["n_audio_head"],
+                n_audio_layer=config["n_audio_layer"],
+                n_vocab=config["n_vocab"],
+                n_text_ctx=config["n_text_ctx"],
+                n_text_state=config["n_text_state"],
+                n_text_head=config["n_text_head"],
+                n_text_layer=config["n_text_layer"],
+            )
+        # HuggingFace format
         try:
             return ModelDimensions(
                 n_mels=config["num_mel_bins"],
@@ -236,6 +255,24 @@ def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     return converted if converted else state_dict

+def _convert_mlx_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Converts an mlx whisper checkpoint to a default openai whisper one
+    """
+    if not any("mlp1" in k or "mlp2" in k for k in state_dict):
+        return state_dict
+    converted = {}
+    for key, value in state_dict.items():
+        if key == "alignment_heads":
+            continue
+        new_key = key.replace(".mlp1.", ".mlp.0.").replace(".mlp2.", ".mlp.2.")
+        converted[new_key] = value
+    return converted
+
 def _load_lora_state(lora_path: str):
     safe_path = os.path.join(lora_path, "adapter_model.safetensors")
     bin_path = os.path.join(lora_path, "adapter_model.bin")
@@ -520,7 +557,12 @@ def load_model(
         state_dict = checkpoint["model_state_dict"]
     else:
         state_dict = checkpoint
+    if alignment_heads is None and "alignment_heads" in state_dict:
+        alignment_heads = state_dict["alignment_heads"]
     state_dict = _convert_hf_state_dict(state_dict)
+    state_dict = _convert_mlx_state_dict(state_dict)
     _apply_lora_adapter(state_dict, lora_path)

     if dims_cfg is not None:
@@ -546,8 +588,13 @@ def load_model(
     model.load_state_dict(state_dict)

     if alignment_heads is not None:
-        model.set_alignment_heads(alignment_heads)
+        if isinstance(alignment_heads, bytes):
+            model.set_alignment_heads(alignment_heads)
+        elif isinstance(alignment_heads, torch.Tensor):  # for mlx whisper
+            mask = torch.zeros(dims.n_text_layer, dims.n_text_head, dtype=torch.bool)
+            for layer, head in alignment_heads.tolist():
+                mask[layer, head] = True
+            model.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

     return model.to(device)