33 Commits

Author SHA1 Message Date
Quentin Fuxa
7ea507ed8e Add Voxtral MLX streaming backend
Integrates the voxmlx-based Voxtral Mini Realtime streaming pipeline:
- VoxtralStreamingASR and VoxtralStreamingOnlineProcessor
- Incremental audio encoding and token-by-token autoregressive decoding
- Selectable via --backend voxtral-mlx

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 09:20:28 +01:00
Quentin Fuxa
e7e82f7c19 bump to 0.2.18 2026-02-11 22:10:00 +01:00
Quentin Fuxa
8c799fa4d1 fix simulstreaming vram leak: cap cross-attn accumulation + token budget
fixes #283, fixes #275

- accumulated_cross_attns was growing unboundedly during decoding loop,
  using up to ~5GB for repetition loops. now capped to rolling window of 16
- max_tokens_per_chunk was using TOKENS_PER_SECOND (mel frame rate = 50)
  instead of actual text token rate (~15/s), allowing 10-40x too many
  decoding steps
- removed unused torch.cat on early return path
- removed dead self.committed/last_result_tokens lists (never read)
- same fixes applied to mlx variant
2026-02-11 22:10:00 +01:00
Quentin Fuxa
8923337380 fix --direct-english-translation not setting task=translate for localagreement backends
the flag was only used for tokenizer language selection but never
actually passed to whisper/faster-whisper transcribe calls. also init
OpenaiApiASR.task and read from transcribe_kargs.

fixes #306
2026-02-11 22:10:00 +01:00
Quentin Fuxa
aded1649ae fix model_cache_dir + direct_english_translation task in simulstreaming
pass actual cache dir instead of None, and use proper task string
instead of boolean for AlignAttConfig

fixes #310
2026-02-11 22:10:00 +01:00
Quentin Fuxa
3b535e857a fix NoneType concatenation in add_translation
fixes #296
2026-02-11 22:10:00 +01:00
Quentin Fuxa
d649250b9a fix Segment classmethod call + isinstance type narrowing
fixes #331, fixes #329
2026-02-11 22:10:00 +01:00
Quentin Fuxa
7735478286 add insert_audio_chunk to DiartDiarization
fixes #332
2026-02-11 22:10:00 +01:00
Quentin Fuxa
b9e72d2b9a add probability field to ASRToken
fixes #330, fixes #313
2026-02-11 22:10:00 +01:00
Quentin Fuxa
e5b01033af add json normalizers for english language in build 2026-01-16 10:47:46 +01:00
Quentin Fuxa
6ae545bcb1 bump to 0.2.17.post1 2026-01-16 10:43:52 +01:00
Quentin Fuxa
04980d3f5e Merge branch 'main' of https://github.com/QuentinFuxa/WhisperLiveKit 2026-01-16 10:38:29 +01:00
Quentin Fuxa
79a705c969 fixes #323 2026-01-16 10:38:07 +01:00
Quentin Fuxa
34e4abd455 Merge pull request #322 from eschmidbauer/fix/thread-safety-issues
Fix kv cache not being properly cleaned between sessions
2026-01-09 19:23:35 +01:00
Emmanuel Schmidbauer
d59ddbaeae Fix critical thread safety issues 2026-01-09 11:23:19 -05:00
Quentin Fuxa
4dd66e7766 Merge pull request #317 from jantonj/fix-bug-diarization-lag
update diarization lag after stream analysed
2025-12-19 17:43:07 +01:00
Anton Jacobson
3db5d81a20 update diarization lag after stream analysed 2025-12-18 14:13:28 +01:00
Quentin Fuxa
b67ddea494 bump to 0.2.17 2025-12-08 23:52:00 +01:00
Quentin Fuxa
3192553e20 fixes #307 2025-12-09 10:27:49 +01:00
Quentin Fuxa
f379a243fe Merge pull request #274 from blakkd/patch-1
minor path change
2025-12-09 10:10:32 +01:00
Quentin Fuxa
ec09898a9f fixes #301 2025-12-06 10:19:50 +01:00
blakkd
befbae56c7 minor path change
prevents

```
FileNotFoundError: [Errno 2] No such file or directory: 'whisperlivekit/web/live_transcription.html'
```
2025-11-16 23:47:58 +01:00
Quentin Fuxa
719e8b1a20 adapt online for mlx detection 2024-11-25 23:52:00 +01:00
Quentin Fuxa
f1b47178d8 adapt online for mlx detection 2024-11-25 23:52:00 +01:00
Quentin Fuxa
59db08e961 loader for full mlx 2024-11-25 23:52:00 +01:00
Quentin Fuxa
6fc20b9562 new dec class 2024-11-21 23:52:00 +01:00
Quentin Fuxa
fac8659161 uses native mlx function for attention 2024-11-21 23:52:00 +01:00
Quentin Fuxa
4d9332ce7d fixes #299 2025-12-05 17:54:14 +01:00
Quentin Fuxa
62444ce746 session parameter required in OnnxWrapper 2025-12-05 15:37:18 +01:00
Quentin Fuxa
2431a6bf91 isolated VAD states per user: .onnx: share a stateless model. .jit: require duplicating the model.
Co-authored-by: eschmidbauer <eschmidbauer@gmail.com>
2025-12-05 15:27:14 +01:00
Quentin Fuxa
d1263e7228 Merge pull request #308 from gzz2000/main
Fix local agreement backend, removing excess parameter, #295
2025-12-05 11:34:05 +01:00
Zizheng Guo
30ddd522a4 Fix local agreement backend, removing excess parameter, fixes https://github.com/QuentinFuxa/WhisperLiveKit/issues/295 2025-12-04 16:45:23 +08:00
Quentin Fuxa
635bace09e update archi 2025-11-30 18:39:10 +01:00
26 changed files with 2087 additions and 184 deletions

View File

@@ -37,9 +37,10 @@ RUN pip3 install --upgrade pip setuptools wheel && \
COPY . . COPY . .
# Install WhisperLiveKit directly, allowing for optional dependencies # Install WhisperLiveKit directly, allowing for optional dependencies
# Example: --build-arg EXTRAS="translation"
RUN if [ -n "$EXTRAS" ]; then \ RUN if [ -n "$EXTRAS" ]; then \
echo "Installing with extras: [$EXTRAS]"; \ echo "Installing with extras: [$EXTRAS]"; \
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \ pip install --no-cache-dir "whisperlivekit[$EXTRAS]"; \
else \ else \
echo "Installing base package only"; \ echo "Installing base package only"; \
pip install --no-cache-dir whisperlivekit; \ pip install --no-cache-dir whisperlivekit; \

View File

@@ -147,8 +147,8 @@ async def websocket_endpoint(websocket: WebSocket):
|-----------|-------------|---------| |-----------|-------------|---------|
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `small` | | `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `small` |
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `None` | | `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `None` |
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` | | `--language` | List [here](docs/supported_languages.md). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` | | `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
| `--diarization` | Enable speaker identification | `False` | | `--diarization` | Enable speaker identification | `False` |
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` | | `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` | | `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
@@ -267,7 +267,7 @@ docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
#### Customization #### Customization
- `--build-arg` Options: - `--build-arg` Options:
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options! - `EXTRAS="translation"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start - `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models - `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models

Binary file not shown.

Before

Width:  |  Height:  |  Size: 422 KiB

After

Width:  |  Height:  |  Size: 422 KiB

View File

@@ -6,7 +6,7 @@ Capture the audio of your current tab, transcribe diarize and translate it using
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730"> <img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
## Running this extension ## Running this extension
1. Run `python sync_extension.py` to copy frontend files to the `chrome-extension` directory. 1. Run `python scripts/sync_extension.py` to copy frontend files to the `chrome-extension` directory.
2. Load the `chrome-extension` directory in Chrome as an unpacked extension. 2. Load the `chrome-extension` directory in Chrome as an unpacked extension.

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "whisperlivekit" name = "whisperlivekit"
version = "0.2.16.dev0" version = "0.2.18"
description = "Real-time speech-to-text with speaker diarization using Whisper" description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md" readme = "README.md"
authors = [ authors = [
@@ -35,6 +35,7 @@ dependencies = [
"torchaudio>=2.0.0", "torchaudio>=2.0.0",
"torch>=2.0.0", "torch>=2.0.0",
"huggingface-hub>=0.25.0", "huggingface-hub>=0.25.0",
"faster-whisper>=1.2.0",
"tqdm", "tqdm",
"tiktoken", "tiktoken",
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")' 'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
@@ -56,6 +57,7 @@ packages = [
"whisperlivekit", "whisperlivekit",
"whisperlivekit.diarization", "whisperlivekit.diarization",
"whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper",
"whisperlivekit.simul_whisper.mlx",
"whisperlivekit.whisper", "whisperlivekit.whisper",
"whisperlivekit.whisper.assets", "whisperlivekit.whisper.assets",
"whisperlivekit.whisper.normalizers", "whisperlivekit.whisper.normalizers",
@@ -67,4 +69,5 @@ packages = [
[tool.setuptools.package-data] [tool.setuptools.package-data]
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"] whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"] "whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
"whisperlivekit.whisper.normalizers" = ["*.json"]
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"] "whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]

View File

@@ -10,7 +10,7 @@ from whisperlivekit.core import (TranscriptionEngine,
online_diarization_factory, online_factory, online_diarization_factory, online_factory,
online_translation_factory) online_translation_factory)
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.silero_vad_iterator import FixedVADIterator from whisperlivekit.silero_vad_iterator import FixedVADIterator, OnnxWrapper, load_jit_vad
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData, from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
Segment, Silence, State, Transcript) Segment, Silence, State, Transcript)
from whisperlivekit.tokens_alignment import TokensAlignment from whisperlivekit.tokens_alignment import TokensAlignment
@@ -32,7 +32,7 @@ async def get_all_from_queue(queue: asyncio.Queue) -> Union[object, Silence, np.
if isinstance(first_item, Silence): if isinstance(first_item, Silence):
return first_item return first_item
items.append(first_item) items.append(first_item)
while True: while True:
if not queue._queue: if not queue._queue:
break break
@@ -53,15 +53,15 @@ class AudioProcessor:
Processes audio streams for transcription and diarization. Processes audio streams for transcription and diarization.
Handles audio processing, state management, and result formatting. Handles audio processing, state management, and result formatting.
""" """
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
"""Initialize the audio processor with configuration, models, and state.""" """Initialize the audio processor with configuration, models, and state."""
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine): if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
models = kwargs['transcription_engine'] models = kwargs['transcription_engine']
else: else:
models = TranscriptionEngine(**kwargs) models = TranscriptionEngine(**kwargs)
# Audio processing settings # Audio processing settings
self.args = models.args self.args = models.args
self.sample_rate = 16000 self.sample_rate = 16000
@@ -85,12 +85,14 @@ class AudioProcessor:
# Models and processing # Models and processing
self.asr: Any = models.asr self.asr: Any = models.asr
self.vac_model: Any = models.vac_model self.vac: Optional[FixedVADIterator] = None
if self.args.vac: if self.args.vac:
self.vac: Optional[FixedVADIterator] = FixedVADIterator(models.vac_model) if models.vac_session is not None:
else: vac_model = OnnxWrapper(session=models.vac_session)
self.vac: Optional[FixedVADIterator] = None self.vac = FixedVADIterator(vac_model)
else:
self.vac = FixedVADIterator(load_jit_vad())
self.ffmpeg_manager: Optional[FFmpegManager] = None self.ffmpeg_manager: Optional[FFmpegManager] = None
self.ffmpeg_reader_task: Optional[asyncio.Task] = None self.ffmpeg_reader_task: Optional[asyncio.Task] = None
self._ffmpeg_error: Optional[str] = None self._ffmpeg_error: Optional[str] = None
@@ -104,7 +106,7 @@ class AudioProcessor:
logger.error(f"FFmpeg error: {error_type}") logger.error(f"FFmpeg error: {error_type}")
self._ffmpeg_error = error_type self._ffmpeg_error = error_type
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None self.transcription_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.transcription else None
self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None self.diarization_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.diarization else None
self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None self.translation_queue: Optional[asyncio.Queue] = asyncio.Queue() if self.args.target_language else None
@@ -115,14 +117,14 @@ class AudioProcessor:
self.translation_task: Optional[asyncio.Task] = None self.translation_task: Optional[asyncio.Task] = None
self.watchdog_task: Optional[asyncio.Task] = None self.watchdog_task: Optional[asyncio.Task] = None
self.all_tasks_for_cleanup: List[asyncio.Task] = [] self.all_tasks_for_cleanup: List[asyncio.Task] = []
self.transcription: Optional[Any] = None self.transcription: Optional[Any] = None
self.translation: Optional[Any] = None self.translation: Optional[Any] = None
self.diarization: Optional[Any] = None self.diarization: Optional[Any] = None
if self.args.transcription: if self.args.transcription:
self.transcription = online_factory(self.args, models.asr) self.transcription = online_factory(self.args, models.asr)
self.sep = self.transcription.asr.sep self.sep = self.transcription.asr.sep
if self.args.diarization: if self.args.diarization:
self.diarization = online_diarization_factory(self.args, models.diarization_model) self.diarization = online_diarization_factory(self.args, models.diarization_model)
if models.translation_model: if models.translation_model:
@@ -180,24 +182,24 @@ class AudioProcessor:
def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray: def convert_pcm_to_float(self, pcm_buffer: Union[bytes, bytearray]) -> np.ndarray:
"""Convert PCM buffer in s16le format to normalized NumPy array.""" """Convert PCM buffer in s16le format to normalized NumPy array."""
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0 return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
async def get_current_state(self) -> State: async def get_current_state(self) -> State:
"""Get current state.""" """Get current state."""
async with self.lock: async with self.lock:
current_time = time() current_time = time()
remaining_transcription = 0 remaining_transcription = 0
if self.state.end_buffer > 0: if self.state.end_buffer > 0:
remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1)) remaining_transcription = max(0, round(current_time - self.beg_loop - self.state.end_buffer, 1))
remaining_diarization = 0 remaining_diarization = 0
if self.state.tokens: if self.state.tokens:
latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0) latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1)) remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
self.state.remaining_time_transcription = remaining_transcription self.state.remaining_time_transcription = remaining_transcription
self.state.remaining_time_diarization = remaining_diarization self.state.remaining_time_diarization = remaining_diarization
return self.state return self.state
async def ffmpeg_stdout_reader(self) -> None: async def ffmpeg_stdout_reader(self) -> None:
@@ -253,7 +255,7 @@ class AudioProcessor:
async def transcription_processor(self) -> None: async def transcription_processor(self) -> None:
"""Process audio chunks for transcription.""" """Process audio chunks for transcription."""
cumulative_pcm_duration_stream_time = 0.0 cumulative_pcm_duration_stream_time = 0.0
while True: while True:
try: try:
# item = await self.transcription_queue.get() # item = await self.transcription_queue.get()
@@ -309,12 +311,12 @@ class AudioProcessor:
if new_tokens: if new_tokens:
candidate_end_times.append(new_tokens[-1].end) candidate_end_times.append(new_tokens[-1].end)
if _buffer_transcript.end is not None: if _buffer_transcript.end is not None:
candidate_end_times.append(_buffer_transcript.end) candidate_end_times.append(_buffer_transcript.end)
candidate_end_times.append(current_audio_processed_upto) candidate_end_times.append(current_audio_processed_upto)
async with self.lock: async with self.lock:
self.state.tokens.extend(new_tokens) self.state.tokens.extend(new_tokens)
self.state.buffer_transcription = _buffer_transcript self.state.buffer_transcription = _buffer_transcript
@@ -324,13 +326,13 @@ class AudioProcessor:
if self.translation_queue: if self.translation_queue:
for token in new_tokens: for token in new_tokens:
await self.translation_queue.put(token) await self.translation_queue.put(token)
except Exception as e: except Exception as e:
logger.warning(f"Exception in transcription_processor: {e}") logger.warning(f"Exception in transcription_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue if 'pcm_array' in locals() and pcm_array is not SENTINEL : # Check if pcm_array was assigned from queue
self.transcription_queue.task_done() self.transcription_queue.task_done()
if self.is_stopping: if self.is_stopping:
logger.info("Transcription processor finishing due to stopping flag.") logger.info("Transcription processor finishing due to stopping flag.")
if self.diarization_queue: if self.diarization_queue:
@@ -351,18 +353,21 @@ class AudioProcessor:
if item.has_ended: if item.has_ended:
self.diarization.insert_silence(item.duration) self.diarization.insert_silence(item.duration)
continue continue
self.diarization.insert_audio_chunk(item) self.diarization.insert_audio_chunk(item)
diarization_segments = await self.diarization.diarize() diarization_segments = await self.diarization.diarize()
self.state.new_diarization = diarization_segments diar_end = 0.0
if diarization_segments:
diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
async with self.lock:
self.state.new_diarization = diarization_segments
self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
except Exception as e: except Exception as e:
logger.warning(f"Exception in diarization_processor: {e}") logger.warning(f"Exception in diarization_processor: {e}")
logger.warning(f"Traceback: {traceback.format_exc()}") logger.warning(f"Traceback: {traceback.format_exc()}")
logger.info("Diarization processor task finished.") logger.info("Diarization processor task finished.")
async def translation_processor(self) -> None: async def translation_processor(self) -> None:
# the idea is to ignore diarization for the moment. We use only transcription tokens. # the idea is to ignore diarization for the moment. We use only transcription tokens.
# And the speaker is attributed given the segments used for the translation # And the speaker is attributed given the segments used for the translation
# in the future we want to have different languages for each speaker etc, so it will be more complex. # in the future we want to have different languages for each speaker etc, so it will be more complex.
while True: while True:
@@ -424,22 +429,22 @@ class AudioProcessor:
remaining_time_transcription=state.remaining_time_transcription, remaining_time_transcription=state.remaining_time_transcription,
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0 remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
) )
should_push = (response != self.last_response_content) should_push = (response != self.last_response_content)
if should_push: if should_push:
yield response yield response
self.last_response_content = response self.last_response_content = response
if self.is_stopping and self._processing_tasks_done(): if self.is_stopping and self._processing_tasks_done():
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.") logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
return return
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
except Exception as e: except Exception as e:
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}") logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
async def create_tasks(self) -> AsyncGenerator[FrontData, None]: async def create_tasks(self) -> AsyncGenerator[FrontData, None]:
"""Create and start processing tasks.""" """Create and start processing tasks."""
self.all_tasks_for_cleanup = [] self.all_tasks_for_cleanup = []
@@ -464,21 +469,21 @@ class AudioProcessor:
self.transcription_task = asyncio.create_task(self.transcription_processor()) self.transcription_task = asyncio.create_task(self.transcription_processor())
self.all_tasks_for_cleanup.append(self.transcription_task) self.all_tasks_for_cleanup.append(self.transcription_task)
processing_tasks_for_watchdog.append(self.transcription_task) processing_tasks_for_watchdog.append(self.transcription_task)
if self.diarization: if self.diarization:
self.diarization_task = asyncio.create_task(self.diarization_processor()) self.diarization_task = asyncio.create_task(self.diarization_processor())
self.all_tasks_for_cleanup.append(self.diarization_task) self.all_tasks_for_cleanup.append(self.diarization_task)
processing_tasks_for_watchdog.append(self.diarization_task) processing_tasks_for_watchdog.append(self.diarization_task)
if self.translation: if self.translation:
self.translation_task = asyncio.create_task(self.translation_processor()) self.translation_task = asyncio.create_task(self.translation_processor())
self.all_tasks_for_cleanup.append(self.translation_task) self.all_tasks_for_cleanup.append(self.translation_task)
processing_tasks_for_watchdog.append(self.translation_task) processing_tasks_for_watchdog.append(self.translation_task)
# Monitor overall system health # Monitor overall system health
self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog)) self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog))
self.all_tasks_for_cleanup.append(self.watchdog_task) self.all_tasks_for_cleanup.append(self.watchdog_task)
return self.results_formatter() return self.results_formatter()
async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None: async def watchdog(self, tasks_to_monitor: List[asyncio.Task]) -> None:
@@ -491,7 +496,7 @@ class AudioProcessor:
return return
await asyncio.sleep(10) await asyncio.sleep(10)
for i, task in enumerate(list(tasks_remaining)): for i, task in enumerate(list(tasks_remaining)):
if task.done(): if task.done():
exc = task.exception() exc = task.exception()
@@ -501,13 +506,13 @@ class AudioProcessor:
else: else:
logger.info(f"{task_name} completed normally.") logger.info(f"{task_name} completed normally.")
tasks_remaining.remove(task) tasks_remaining.remove(task)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("Watchdog task cancelled.") logger.info("Watchdog task cancelled.")
break break
except Exception as e: except Exception as e:
logger.error(f"Error in watchdog task: {e}", exc_info=True) logger.error(f"Error in watchdog task: {e}", exc_info=True)
async def cleanup(self) -> None: async def cleanup(self) -> None:
"""Clean up resources when processing is complete.""" """Clean up resources when processing is complete."""
logger.info("Starting cleanup of AudioProcessor resources.") logger.info("Starting cleanup of AudioProcessor resources.")
@@ -515,7 +520,7 @@ class AudioProcessor:
for task in self.all_tasks_for_cleanup: for task in self.all_tasks_for_cleanup:
if task and not task.done(): if task and not task.done():
task.cancel() task.cancel()
created_tasks = [t for t in self.all_tasks_for_cleanup if t] created_tasks = [t for t in self.all_tasks_for_cleanup if t]
if created_tasks: if created_tasks:
await asyncio.gather(*created_tasks, return_exceptions=True) await asyncio.gather(*created_tasks, return_exceptions=True)
@@ -553,7 +558,7 @@ class AudioProcessor:
if not message: if not message:
logger.info("Empty audio message received, initiating stop sequence.") logger.info("Empty audio message received, initiating stop sequence.")
self.is_stopping = True self.is_stopping = True
if self.transcription_queue: if self.transcription_queue:
await self.transcription_queue.put(SENTINEL) await self.transcription_queue.put(SENTINEL)
@@ -594,7 +599,7 @@ class AudioProcessor:
chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec) chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec)
aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample
if aligned_chunk_size == 0: if aligned_chunk_size == 0:
return return
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size]) pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size])
@@ -611,7 +616,7 @@ class AudioProcessor:
if res is not None: if res is not None:
if "start" in res and self.current_silence: if "start" in res and self.current_silence:
await self._end_silence() await self._end_silence()
if "end" in res and not self.current_silence: if "end" in res and not self.current_silence:
pre_silence_chunk = self._slice_before_silence( pre_silence_chunk = self._slice_before_silence(
pcm_array, chunk_sample_start, res.get("end") pcm_array, chunk_sample_start, res.get("end")

View File

@@ -29,6 +29,13 @@ def mlx_backend_available(warn_on_missing = False):
return available return available
def voxmlx_backend_available():
"""Return True if voxmlx (Voxtral MLX backend) is available."""
is_macos = platform.system() == "Darwin"
is_arm = platform.machine() == "arm64"
return is_macos and is_arm and module_available("voxmlx")
def faster_backend_available(warn_on_missing = False): def faster_backend_available(warn_on_missing = False):
available = module_available("faster_whisper") available = module_available("faster_whisper")
if not available and warn_on_missing and platform.system() != "Darwin": if not available and warn_on_missing and platform.system() != "Darwin":

View File

@@ -1,5 +1,6 @@
import logging import logging
import sys import sys
import threading
from argparse import Namespace from argparse import Namespace
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
@@ -19,16 +20,26 @@ logger = logging.getLogger(__name__)
class TranscriptionEngine: class TranscriptionEngine:
_instance = None _instance = None
_initialized = False _initialized = False
_lock = threading.Lock() # Thread-safe singleton lock
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
# Double-checked locking pattern for thread-safe singleton
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) with cls._lock:
# Check again inside lock to prevent race condition
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance return cls._instance
def __init__(self, **kwargs): def __init__(self, **kwargs):
if TranscriptionEngine._initialized: # Thread-safe initialization check
return with TranscriptionEngine._lock:
if TranscriptionEngine._initialized:
return
# Set flag immediately to prevent re-initialization
TranscriptionEngine._initialized = True
# Perform initialization outside lock to avoid holding lock during slow operations
global_params = { global_params = {
"host": "localhost", "host": "localhost",
"port": 8000, "port": 8000,
@@ -36,7 +47,6 @@ class TranscriptionEngine:
"punctuation_split": False, "punctuation_split": False,
"target_language": "", "target_language": "",
"vac": True, "vac": True,
"vac_onnx": False,
"vac_chunk_size": 0.04, "vac_chunk_size": 0.04,
"log_level": "DEBUG", "log_level": "DEBUG",
"ssl_certfile": None, "ssl_certfile": None,
@@ -79,18 +89,27 @@ class TranscriptionEngine:
self.asr = None self.asr = None
self.tokenizer = None self.tokenizer = None
self.diarization = None self.diarization = None
self.vac_model = None self.vac_session = None
if self.args.vac: if self.args.vac:
from whisperlivekit.silero_vad_iterator import load_silero_vad from whisperlivekit.silero_vad_iterator import is_onnx_available
# Use ONNX if specified, otherwise use JIT (default) if is_onnx_available():
use_onnx = kwargs.get('vac_onnx', False) from whisperlivekit.silero_vad_iterator import load_onnx_session
self.vac_model = load_silero_vad(onnx=use_onnx) self.vac_session = load_onnx_session()
else:
logger.warning(
"onnxruntime not installed. VAC will use JIT model which is loaded per-session. "
"For multi-user scenarios, install onnxruntime: pip install onnxruntime"
)
backend_policy = self.args.backend_policy backend_policy = self.args.backend_policy
if self.args.transcription: if self.args.transcription:
if backend_policy == "simulstreaming": if self.args.backend == "voxtral-mlx":
from whisperlivekit.voxtral_streaming import VoxtralStreamingASR
self.tokenizer = None
self.asr = VoxtralStreamingASR(**transcription_common_params)
logger.info("Using Voxtral MLX streaming backend")
elif backend_policy == "simulstreaming":
simulstreaming_params = { simulstreaming_params = {
"disable_fast_encoder": False, "disable_fast_encoder": False,
"custom_alignment_heads": None, "custom_alignment_heads": None,
@@ -169,16 +188,16 @@ class TranscriptionEngine:
} }
translation_params = update_with_kwargs(translation_params, kwargs) translation_params = update_with_kwargs(translation_params, kwargs)
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
TranscriptionEngine._initialized = True
def online_factory(args, asr): def online_factory(args, asr):
if args.backend_policy == "simulstreaming": if getattr(args, 'backend', None) == "voxtral-mlx":
from whisperlivekit.voxtral_streaming import VoxtralStreamingOnlineProcessor
return VoxtralStreamingOnlineProcessor(asr)
if args.backend_policy == "simulstreaming":
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
online = SimulStreamingOnlineProcessor(asr) return SimulStreamingOnlineProcessor(asr)
else: return OnlineASRProcessor(asr)
online = OnlineASRProcessor(asr)
return online
def online_diarization_factory(args, diarization_backend): def online_diarization_factory(args, diarization_backend):

View File

@@ -202,14 +202,14 @@ class DiartDiarization:
def insert_silence(self, silence_duration): def insert_silence(self, silence_duration):
self.observer.global_time_offset += silence_duration self.observer.global_time_offset += silence_duration
async def diarize(self, pcm_array: np.ndarray): def insert_audio_chunk(self, pcm_array: np.ndarray):
""" """Buffer audio for the next diarization step."""
Process audio data for diarization.
Only used when working with WebSocketAudioSource.
"""
if self.custom_source: if self.custom_source:
self.custom_source.push_audio(pcm_array) self.custom_source.push_audio(pcm_array)
# self.observer.clear_old_segments()
async def diarize(self):
"""Return the current speaker segments from the diarization pipeline."""
return self.observer.get_segments()
def close(self): def close(self):
"""Close the audio source.""" """Close the audio source."""

View File

@@ -249,6 +249,7 @@ class OpenaiApiASR(ASRBase):
self.load_model() self.load_model()
self.use_vad_opt = False self.use_vad_opt = False
self.direct_english_translation = False self.direct_english_translation = False
self.task = "transcribe"
def load_model(self, *args, **kwargs): def load_model(self, *args, **kwargs):
from openai import OpenAI from openai import OpenAI
@@ -294,7 +295,8 @@ class OpenaiApiASR(ASRBase):
params["language"] = self.original_language params["language"] = self.original_language
if prompt: if prompt:
params["prompt"] = prompt params["prompt"] = prompt
proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions task = self.transcribe_kargs.get("task", self.task)
proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions
transcript = proc.create(**params) transcript = proc.create(**params)
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds") logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
return transcript return transcript

View File

@@ -146,6 +146,7 @@ def backend_factory(
if direct_english_translation: if direct_english_translation:
tgt_language = "en" # Whisper translates into English tgt_language = "en" # Whisper translates into English
asr.transcribe_kargs["task"] = "translate"
else: else:
tgt_language = lan # Whisper transcribes in this language tgt_language = lan # Whisper transcribes in this language
@@ -154,9 +155,9 @@ def backend_factory(
tokenizer = create_tokenizer(tgt_language) tokenizer = create_tokenizer(tgt_language)
else: else:
tokenizer = None tokenizer = None
warmup_asr(asr, warmup_file) warmup_asr(asr, warmup_file)
asr.confidence_validation = confidence_validation asr.confidence_validation = confidence_validation
asr.tokenizer = tokenizer asr.tokenizer = tokenizer
asr.buffer_trimming = buffer_trimming asr.buffer_trimming = buffer_trimming

View File

@@ -147,8 +147,8 @@ def parse_args():
"--backend", "--backend",
type=str, type=str,
default="auto", default="auto",
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"], choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral-mlx"],
help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API.", help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API. Use 'voxtral-mlx' for Voxtral streaming on Apple Silicon.",
) )
parser.add_argument( parser.add_argument(
"--no-vac", "--no-vac",

View File

@@ -8,6 +8,15 @@ import torch
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
""" """
def is_onnx_available() -> bool:
"""Check if onnxruntime is installed."""
try:
import onnxruntime
return True
except ImportError:
return False
def init_jit_model(model_path: str, device=torch.device('cpu')): def init_jit_model(model_path: str, device=torch.device('cpu')):
"""Load a JIT model from file.""" """Load a JIT model from file."""
model = torch.jit.load(model_path, map_location=device) model = torch.jit.load(model_path, map_location=device)
@@ -15,12 +24,12 @@ def init_jit_model(model_path: str, device=torch.device('cpu')):
return model return model
class OnnxWrapper(): class OnnxSession():
"""ONNX Runtime wrapper for Silero VAD model.""" """
Shared ONNX session for Silero VAD model (stateless).
"""
def __init__(self, path, force_onnx_cpu=False): def __init__(self, path, force_onnx_cpu=False):
global np
import numpy as np
import onnxruntime import onnxruntime
opts = onnxruntime.SessionOptions() opts = onnxruntime.SessionOptions()
@@ -32,13 +41,28 @@ class OnnxWrapper():
else: else:
self.session = onnxruntime.InferenceSession(path, sess_options=opts) self.session = onnxruntime.InferenceSession(path, sess_options=opts)
self.reset_states() self.path = path
if '16k' in path: if '16k' in path:
warnings.warn('This model support only 16000 sampling rate!') warnings.warn('This model support only 16000 sampling rate!')
self.sample_rates = [16000] self.sample_rates = [16000]
else: else:
self.sample_rates = [8000, 16000] self.sample_rates = [8000, 16000]
class OnnxWrapper():
"""
ONNX Runtime wrapper for Silero VAD model with per-instance state.
"""
def __init__(self, session: OnnxSession, force_onnx_cpu=False):
self._shared_session = session
self.sample_rates = session.sample_rates
self.reset_states()
@property
def session(self):
return self._shared_session.session
def _validate_input(self, x, sr: int): def _validate_input(self, x, sr: int):
if x.dim() == 1: if x.dim() == 1:
x = x.unsqueeze(0) x = x.unsqueeze(0)
@@ -101,38 +125,20 @@ class OnnxWrapper():
return out return out
def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16): def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Path:
""" """Get the path to the ONNX model file."""
Load Silero VAD model (JIT or ONNX).
Parameters
----------
model_path : str, optional
Path to model file. If None, uses default bundled model.
onnx : bool, default False
Whether to use ONNX runtime (requires onnxruntime package).
opset_version : int, default 16
ONNX opset version (15 or 16). Only used if onnx=True.
Returns
-------
model
Loaded VAD model (JIT or ONNX wrapper)
"""
available_ops = [15, 16] available_ops = [15, 16]
if onnx and opset_version not in available_ops: if opset_version not in available_ops:
raise Exception(f'Available ONNX opset_version: {available_ops}') raise Exception(f'Available ONNX opset_version: {available_ops}')
if model_path is None: if model_path is None:
current_dir = Path(__file__).parent current_dir = Path(__file__).parent
data_dir = current_dir / 'silero_vad_models' data_dir = current_dir / 'silero_vad_models'
if onnx: if opset_version == 16:
if opset_version == 16: model_name = 'silero_vad.onnx'
model_name = 'silero_vad.onnx'
else:
model_name = f'silero_vad_16k_op{opset_version}.onnx'
else: else:
model_name = 'silero_vad.jit' model_name = f'silero_vad_16k_op{opset_version}.onnx'
model_path = data_dir / model_name model_path = data_dir / model_name
@@ -143,17 +149,39 @@ def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: i
) )
else: else:
model_path = Path(model_path) model_path = Path(model_path)
if onnx:
try: return model_path
model = OnnxWrapper(str(model_path), force_onnx_cpu=True)
except ImportError:
raise ImportError( def load_onnx_session(model_path: str = None, opset_version: int = 16, force_onnx_cpu: bool = True) -> OnnxSession:
"ONNX runtime not available. Install with: pip install onnxruntime\n" """
"Or use JIT model by setting onnx=False" Load a shared ONNX session for Silero VAD.
"""
path = _get_onnx_model_path(model_path, opset_version)
return OnnxSession(str(path), force_onnx_cpu=force_onnx_cpu)
def load_jit_vad(model_path: str = None):
"""
Load Silero VAD model in JIT format.
"""
if model_path is None:
current_dir = Path(__file__).parent
data_dir = current_dir / 'silero_vad_models'
model_name = 'silero_vad.jit'
model_path = data_dir / model_name
if not model_path.exists():
raise FileNotFoundError(
f"Model file not found: {model_path}\n"
f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
) )
else: else:
model = init_jit_model(str(model_path)) model_path = Path(model_path)
model = init_jit_model(str(model_path))
return model return model
@@ -285,13 +313,14 @@ class FixedVADIterator(VADIterator):
if __name__ == "__main__": if __name__ == "__main__":
model = load_silero_vad(onnx=False) # vad = FixedVADIterator(load_jit_vad())
vad = FixedVADIterator(model) vad = FixedVADIterator(OnnxWrapper(session=load_onnx_session()))
audio_buffer = np.array([0] * 512, dtype=np.float32) audio_buffer = np.array([0] * 512, dtype=np.float32)
result = vad(audio_buffer) result = vad(audio_buffer)
print(f" 512 samples: {result}") print(f" 512 samples: {result}")
# test with 511 samples # test with 511 samples
audio_buffer = np.array([0] * 511, dtype=np.float32) audio_buffer = np.array([0] * 511, dtype=np.float32)
result = vad(audio_buffer) result = vad(audio_buffer)
print(f" 511 samples: {result}")

View File

@@ -24,9 +24,11 @@ logger = logging.getLogger(__name__)
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True) HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
if HAS_MLX_WHISPER: if HAS_MLX_WHISPER:
from .mlx_encoder import load_mlx_encoder, mlx_model_mapping from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
from .mlx import MLXAlignAtt
else: else:
mlx_model_mapping = {} mlx_model_mapping = {}
MLXAlignAtt = None
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER) HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
if HAS_FASTER_WHISPER: if HAS_FASTER_WHISPER:
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
@@ -36,50 +38,47 @@ else:
MIN_DURATION_REAL_SILENCE = 5 MIN_DURATION_REAL_SILENCE = 5
class SimulStreamingOnlineProcessor: class SimulStreamingOnlineProcessor:
"""Online processor for SimulStreaming ASR."""
SAMPLING_RATE = 16000 SAMPLING_RATE = 16000
def __init__( def __init__(self, asr, logfile=sys.stderr):
self,
asr,
logfile=sys.stderr,
):
self.asr = asr self.asr = asr
self.logfile = logfile self.logfile = logfile
self.end = 0.0 self.end = 0.0
self.buffer = [] self.buffer = []
self.committed: List[ASRToken] = [] self.model = self._create_alignatt()
self.last_result_tokens: List[ASRToken] = []
self.load_new_alignatt_instance()
if asr.tokenizer: if asr.tokenizer:
self.model.tokenizer = asr.tokenizer self.model.tokenizer = asr.tokenizer
self.model.state.tokenizer = asr.tokenizer
def load_new_alignatt_instance(self): def _create_alignatt(self):
"""Initialize AlignAtt decoder using the shared model.""" """Create the AlignAtt decoder instance based on ASR mode."""
self.model = AlignAtt( if self.asr.use_full_mlx and HAS_MLX_WHISPER:
cfg=self.asr.cfg, return MLXAlignAtt(cfg=self.asr.cfg, mlx_model=self.asr.mlx_model)
loaded_model=self.asr.shared_model, else:
mlx_encoder=self.asr.mlx_encoder, return AlignAtt(
fw_encoder=self.asr.fw_encoder, cfg=self.asr.cfg,
) loaded_model=self.asr.shared_model,
mlx_encoder=self.asr.mlx_encoder,
fw_encoder=self.asr.fw_encoder,
)
def start_silence(self): def start_silence(self):
tokens, processed_upto = self.process_iter(is_last=True) tokens, processed_upto = self.process_iter(is_last=True)
return tokens, processed_upto return tokens, processed_upto
def end_silence(self, silence_duration, offset): def end_silence(self, silence_duration, offset):
""" """Handle silence period."""
Handle silence period.
If silence > MIN_DURATION_REAL_SILENCE, do a complete context clear.
Otherwise, insert a small silence and shift the last_attend_frame.
"""
self.end += silence_duration self.end += silence_duration
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
if not long_silence: if not long_silence:
gap_len = int(16000 * silence_duration) gap_len = int(16000 * silence_duration)
if gap_len > 0: if gap_len > 0:
gap_silence = torch.zeros(gap_len) if self.asr.use_full_mlx:
gap_silence = np.zeros(gap_len, dtype=np.float32)
else:
gap_silence = torch.zeros(gap_len)
self.model.insert_audio(gap_silence) self.model.insert_audio(gap_silence)
if long_silence: if long_silence:
self.model.refresh_segment(complete=True) self.model.refresh_segment(complete=True)
@@ -87,11 +86,12 @@ class SimulStreamingOnlineProcessor:
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time): def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
"""Append an audio chunk to be processed by SimulStreaming.""" """Append an audio chunk to be processed by SimulStreaming."""
self.end = audio_stream_end_time
# Convert numpy array to torch tensor if self.asr.use_full_mlx:
audio_tensor = torch.from_numpy(audio).float() self.model.insert_audio(audio)
self.end = audio_stream_end_time # Aligned with whisperstreaming backend behavior else:
self.model.insert_audio(audio_tensor) audio_tensor = torch.from_numpy(audio).float()
self.model.insert_audio(audio_tensor)
def new_speaker(self, change_speaker: ChangeSpeaker): def new_speaker(self, change_speaker: ChangeSpeaker):
"""Handle speaker change event.""" """Handle speaker change event."""
@@ -120,7 +120,6 @@ class SimulStreamingOnlineProcessor:
self.buffer.extend(timestamped_words) self.buffer.extend(timestamped_words)
return [], self.end return [], self.end
self.committed.extend(timestamped_words)
self.buffer = [] self.buffer = []
return timestamped_words, self.end return timestamped_words, self.end
except Exception as e: except Exception as e:
@@ -130,6 +129,10 @@ class SimulStreamingOnlineProcessor:
def warmup(self, audio, init_prompt=""): def warmup(self, audio, init_prompt=""):
"""Warmup the SimulStreaming model.""" """Warmup the SimulStreaming model."""
try: try:
if self.asr.use_full_mlx:
# MLX mode: ensure numpy array
if hasattr(audio, 'numpy'):
audio = audio.numpy()
self.model.insert_audio(audio) self.model.insert_audio(audio)
self.model.infer(True) self.model.infer(True)
self.model.refresh_segment(complete=True) self.model.refresh_segment(complete=True)
@@ -139,9 +142,14 @@ class SimulStreamingOnlineProcessor:
def __del__(self): def __del__(self):
gc.collect() gc.collect()
torch.cuda.empty_cache() if not getattr(self.asr, 'use_full_mlx', True) and torch is not None:
try:
torch.cuda.empty_cache()
except Exception:
pass
class SimulStreamingASR():
class SimulStreamingASR:
"""SimulStreaming backend with AlignAtt policy.""" """SimulStreaming backend with AlignAtt policy."""
sep = "" sep = ""
@@ -158,6 +166,7 @@ class SimulStreamingASR():
self.fast_encoder = False self.fast_encoder = False
self._resolved_model_path = None self._resolved_model_path = None
self.encoder_backend = "whisper" self.encoder_backend = "whisper"
self.use_full_mlx = getattr(self, "use_full_mlx", False)
preferred_backend = getattr(self, "backend", "auto") preferred_backend = getattr(self, "backend", "auto")
compatible_whisper_mlx, compatible_faster_whisper = True, True compatible_whisper_mlx, compatible_faster_whisper = True, True
@@ -170,7 +179,7 @@ class SimulStreamingASR():
compatible_whisper_mlx = model_info.compatible_whisper_mlx compatible_whisper_mlx = model_info.compatible_whisper_mlx
compatible_faster_whisper = model_info.compatible_faster_whisper compatible_faster_whisper = model_info.compatible_faster_whisper
if not model_info.has_pytorch: if not self.use_full_mlx and not model_info.has_pytorch:
raise FileNotFoundError( raise FileNotFoundError(
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}" f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
) )
@@ -190,6 +199,10 @@ class SimulStreamingASR():
self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper") self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
if self.encoder_backend == "whisper": if self.encoder_backend == "whisper":
self.disable_fast_encoder = True self.disable_fast_encoder = True
if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
if not hasattr(self, '_full_mlx_disabled'):
self.use_full_mlx = True
self.cfg = AlignAttConfig( self.cfg = AlignAttConfig(
tokenizer_is_multilingual= is_multilingual, tokenizer_is_multilingual= is_multilingual,
@@ -201,7 +214,7 @@ class SimulStreamingASR():
cif_ckpt_path=self.cif_ckpt_path, cif_ckpt_path=self.cif_ckpt_path,
decoder_type="beam", decoder_type="beam",
beam_size=self.beams, beam_size=self.beams,
task=self.direct_english_translation, task="translate" if self.direct_english_translation else "transcribe",
never_fire=self.never_fire, never_fire=self.never_fire,
init_prompt=self.init_prompt, init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens, max_context_tokens=self.max_context_tokens,
@@ -214,20 +227,36 @@ class SimulStreamingASR():
else: else:
self.tokenizer = None self.tokenizer = None
self.mlx_encoder, self.fw_encoder = None, None self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None
if self.encoder_backend == "mlx-whisper": self.shared_model = None
print('Simulstreaming will use MLX whisper to increase encoding speed.')
if self.use_full_mlx and HAS_MLX_WHISPER:
logger.info('MLX Whisper backend used.')
if self._resolved_model_path is not None: if self._resolved_model_path is not None:
mlx_model = str(self._resolved_model_path) mlx_model_path = str(self._resolved_model_path)
else: else:
mlx_model = mlx_model_mapping.get(self.model_name) mlx_model_path = mlx_model_mapping.get(self.model_name)
if not mlx_model: if not mlx_model_path:
raise FileNotFoundError( raise FileNotFoundError(
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'." f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
) )
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model) self.mlx_model = load_mlx_model(path_or_hf_repo=mlx_model_path)
self._warmup_mlx_model()
elif self.encoder_backend == "mlx-whisper":
# hybrid mode: mlx encoder + pytorch decoder
logger.info('SimulStreaming will use MLX Whisper encoder with PyTorch decoder.')
if self._resolved_model_path is not None:
mlx_model_path = str(self._resolved_model_path)
else:
mlx_model_path = mlx_model_mapping.get(self.model_name)
if not mlx_model_path:
raise FileNotFoundError(
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
)
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
self.shared_model = self.load_model()
elif self.encoder_backend == "faster-whisper": elif self.encoder_backend == "faster-whisper":
print('Simulstreaming will use Faster Whisper for the encoder.') print('SimulStreaming will use Faster Whisper for the encoder.')
if self._resolved_model_path is not None: if self._resolved_model_path is not None:
fw_model = str(self._resolved_model_path) fw_model = str(self._resolved_model_path)
else: else:
@@ -237,7 +266,20 @@ class SimulStreamingASR():
device='auto', device='auto',
compute_type='auto', compute_type='auto',
) )
self.shared_model = self.load_model() self.shared_model = self.load_model()
else:
self.shared_model = self.load_model()
def _warmup_mlx_model(self):
"""Warmup the full MLX model."""
warmup_audio = load_file(self.warmup_file)
if warmup_audio is not None:
temp_model = MLXAlignAtt(
cfg=self.cfg,
mlx_model=self.mlx_model,
)
temp_model.warmup(warmup_audio)
logger.info("Full MLX model warmed up successfully")
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper): def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
@@ -285,7 +327,7 @@ class SimulStreamingASR():
lora_path = getattr(self, 'lora_path', None) lora_path = getattr(self, 'lora_path', None)
whisper_model = load_model( whisper_model = load_model(
name=model_ref, name=model_ref,
download_root=None, download_root=getattr(self, 'model_cache_dir', None),
decoder_only=self.fast_encoder, decoder_only=self.fast_encoder,
custom_alignment_heads=self.custom_alignment_heads, custom_alignment_heads=self.custom_alignment_heads,
lora_path=lora_path, lora_path=lora_path,

View File

@@ -47,9 +47,24 @@ class DecoderState:
def clean_cache(self): def clean_cache(self):
"""Clean the kv_cache after each inference step.""" """Clean the kv_cache after each inference step."""
self.kv_cache = {} # Explicitly delete tensor references to free GPU memory
if self.kv_cache:
for key in list(self.kv_cache.keys()):
tensor = self.kv_cache.pop(key, None)
if tensor is not None:
del tensor
# Clear the dict
self.kv_cache.clear()
# Force GPU cache cleanup (only if CUDA is available)
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
if self.decoder_type == "beam" and self.inference is not None: if self.decoder_type == "beam" and self.inference is not None:
self.inference.kv_cache = self.kv_cache # Create NEW dict instead of sharing reference
self.inference.kv_cache = {}
if self.token_decoder is not None: if self.token_decoder is not None:
self.token_decoder.reset() self.token_decoder.reset()

View File

@@ -0,0 +1,11 @@
from .decoder_state import MLXDecoderState
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
from .simul_whisper import MLXAlignAtt
__all__ = [
"MLXAlignAtt",
"MLXBeamSearchDecoder",
"MLXDecoderState",
"MLXGreedyDecoder",
"MLXInference",
]

View File

@@ -0,0 +1,76 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import numpy as np
@dataclass
class MLXDecoderState:
"""
mlx kv cache format: List of ((k, v), (cross_k, cross_v)) tuples per layer,
where each element is a tuple of mx.arrays.
"""
kv_cache: Optional[List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]] = None
tokenizer: Any = None
detected_language: Optional[str] = None
reset_tokenizer_to_auto_next_call: bool = False
tokens: List[mx.array] = field(default_factory=list)
initial_tokens: Optional[mx.array] = None
initial_token_length: int = 0
sot_index: int = 0
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
num_align_heads: int = 0
segments: List[np.ndarray] = field(default_factory=list)
context: Any = None
pending_incomplete_tokens: List[int] = field(default_factory=list)
global_time_offset: float = 0.0
cumulative_time_offset: float = 0.0
first_timestamp: Optional[float] = None
last_attend_frame: int = 0
speaker: int = -1
log_segments: int = 0
cif_weights: Optional[mx.array] = None
always_fire: bool = False
never_fire: bool = False
suppress_tokens: Optional[Tuple[int, ...]] = None
token_decoder: Any = None
decoder_type: str = "greedy"
inference: Any = None
def clean_cache(self):
self.kv_cache = None
if self.decoder_type == "beam" and self.inference is not None:
self.inference.kv_cache = None
if self.token_decoder is not None:
self.token_decoder.reset()
def reset(self, rewind_threshold: int = 200):
self.last_attend_frame = -rewind_threshold
self.cumulative_time_offset = 0.0
self.pending_incomplete_tokens = []
self.log_segments += 1
def full_reset(self, rewind_threshold: int = 200):
"""
Full reset including audio segments and tokens.
Args:
rewind_threshold: Value for resetting last_attend_frame
"""
self.reset(rewind_threshold)
self.segments = []
self.tokens = []
self.kv_cache = None
self.first_timestamp = None

View File

@@ -0,0 +1,219 @@
"""
MLX-native token decoders for streaming ASR.
"""
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import numpy as np
class MLXGreedyDecoder:
"""Greedy decoder using MLX operations."""
def __init__(self, temperature: float, eot: int):
self.temperature = temperature
self.eot = eot
def update(
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
) -> Tuple[mx.array, bool]:
"""
Update tokens with next predicted token.
Args:
tokens: Current token sequence, shape (batch, seq_len)
logits: Logits for next token, shape (batch, vocab_size)
sum_logprobs: Cumulative log probabilities, shape (batch,)
Returns:
Updated tokens and completion flag
"""
if self.temperature == 0:
next_tokens = mx.argmax(logits, axis=-1)
else:
probs = mx.softmax(logits / self.temperature, axis=-1)
next_tokens = mx.random.categorical(mx.log(probs + 1e-10))
logprobs = mx.softmax(logits, axis=-1)
logprobs = mx.log(logprobs + 1e-10)
batch_size = logprobs.shape[0]
current_logprobs = logprobs[mx.arange(batch_size), next_tokens]
mask = (tokens[:, -1] != self.eot).astype(mx.float32)
sum_logprobs = sum_logprobs + current_logprobs * mask
eot_mask = (tokens[:, -1] == self.eot)
next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens)
tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1)
completed = bool(mx.all(tokens[:, -1] == self.eot))
return tokens, completed
def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
"""Finalize decoding by ensuring EOT at end."""
eot_column = mx.full((tokens.shape[0], 1), self.eot, dtype=tokens.dtype)
tokens = mx.concatenate([tokens, eot_column], axis=1)
return tokens, sum_logprobs.tolist()
class MLXBeamSearchDecoder:
"""Beam search decoder using MLX operations."""
def __init__(
self,
beam_size: int,
eot: int,
inference: Any,
patience: Optional[float] = None,
):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
self.patience = patience or 1.0
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences: Optional[List[Dict]] = None
assert (
self.max_candidates > 0
), f"Invalid beam size ({beam_size}) or patience ({patience})"
def reset(self):
"""Reset finished sequences for new segment."""
self.finished_sequences = None
def update(
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
) -> Tuple[mx.array, bool]:
"""
Update tokens using beam search.
Args:
tokens: Current token sequences, shape (batch * beam_size, seq_len)
logits: Logits for next token, shape (batch * beam_size, vocab_size)
sum_logprobs: Cumulative log probabilities, shape (batch * beam_size,)
Returns:
Updated tokens and completion flag
"""
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
n_audio = tokens.shape[0] // self.beam_size
if self.finished_sequences is None:
self.finished_sequences = [{} for _ in range(n_audio)]
logprobs = mx.softmax(logits, axis=-1)
logprobs = mx.log(logprobs + 1e-10)
logprobs_np = np.array(logprobs)
tokens_np = np.array(tokens)
sum_logprobs_np = np.array(sum_logprobs)
next_tokens, source_indices, finished_sequences = [], [], []
new_sum_logprobs = []
for i in range(n_audio):
scores, sources, finished = {}, {}, {}
for j in range(self.beam_size):
idx = i * self.beam_size + j
prefix = tokens_np[idx].tolist()
top_k_indices = np.argsort(logprobs_np[idx])[-self.beam_size - 1:][::-1]
for token_idx in top_k_indices:
logprob = logprobs_np[idx, token_idx]
new_logprob = sum_logprobs_np[idx] + logprob
sequence = tuple(prefix + [int(token_idx)])
scores[sequence] = new_logprob
sources[sequence] = idx
saved = 0
for sequence in sorted(scores, key=scores.get, reverse=True):
if sequence[-1] == self.eot:
finished[sequence] = scores[sequence]
else:
new_sum_logprobs.append(scores[sequence])
next_tokens.append(sequence)
source_indices.append(sources[sequence])
saved += 1
if saved == self.beam_size:
break
finished_sequences.append(finished)
tokens = mx.array(np.array(next_tokens, dtype=np.int32))
sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32))
self.inference.rearrange_kv_cache(source_indices)
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(
self.finished_sequences, finished_sequences
):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break
previously_finished[seq] = newly_finished[seq]
completed = all(
len(sequences) >= self.max_candidates
for sequences in self.finished_sequences
)
return tokens, completed
def finalize(self, preceding_tokens: mx.array, sum_logprobs: mx.array):
"""Finalize beam search by selecting best sequences."""
preceding_tokens_np = np.array(preceding_tokens)
sum_logprobs_np = np.array(sum_logprobs)
n_audio = preceding_tokens_np.shape[0] // self.beam_size
tokens_list: List[List[int]] = [[] for _ in range(n_audio)]
sum_logprobs_list: List[float] = [0.0] * n_audio
for i, sequences in enumerate(self.finished_sequences):
if sequences:
best_seq = max(sequences, key=sequences.get)
tokens_list[i] = list(best_seq)
sum_logprobs_list[i] = sequences[best_seq]
else:
idx = i * self.beam_size
tokens_list[i] = preceding_tokens_np[idx].tolist() + [self.eot]
sum_logprobs_list[i] = float(sum_logprobs_np[idx])
max_len = max(len(t) for t in tokens_list)
for i, t in enumerate(tokens_list):
tokens_list[i] = t + [self.eot] * (max_len - len(t))
tokens = mx.array(np.array(tokens_list, dtype=np.int32))
return tokens, sum_logprobs_list
class MLXInference:
"""MLX inference wrapper for beam search KV cache management."""
def __init__(self, model, initial_token_length: int):
self.model = model
self.initial_token_length = initial_token_length
self.kv_cache = None
def rearrange_kv_cache(self, source_indices: List[int]):
"""Rearrange KV cache based on beam search source indices."""
if self.kv_cache is None:
return
if source_indices == list(range(len(source_indices))):
return
source_indices_mx = mx.array(source_indices, dtype=mx.int32)
new_cache = []
for layer_cache in self.kv_cache:
(k, v), (cross_k, cross_v) = layer_cache
new_k = k[source_indices_mx]
new_v = v[source_indices_mx]
new_cache.append(((new_k, new_v), (cross_k, cross_v)))
self.kv_cache = new_cache
def logits(
self,
tokens: mx.array,
audio_features: mx.array,
) -> Tuple[mx.array, List]:
"""Get logits from decoder with KV cache."""
logits, self.kv_cache, cross_qk = self.model.decoder(
tokens, audio_features, kv_cache=self.kv_cache
)
return logits, cross_qk

View File

@@ -0,0 +1,756 @@
"""
MLX whisper AlignAtt streaming decoder
"""
import logging
from time import time
from typing import Any, List, Optional, Tuple
import mlx.core as mx
import numpy as np
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper import DecodingOptions, tokenizer
from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND
from ..config import AlignAttConfig
from .decoder_state import MLXDecoderState
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
DEC_PAD = 50257
logger = logging.getLogger(__name__)
class MLXTokenBuffer: #should try to make it heritate from classic simul whisper class
"""Token buffer for MLX-based decoding."""
def __init__(self, text="", tokenizer=None, prefix_token_ids=None):
self.text = text
self.prefix_token_ids = prefix_token_ids or []
self.tokenizer = tokenizer
self.pending_token_ids = []
def as_token_ids(self, tokenizer=None):
if tokenizer is None:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer is not set.")
return self.prefix_token_ids + tokenizer.encode(self.text)
def as_mlx_array(self) -> mx.array:
"""Return tokens as MLX array."""
tok_ids = self.as_token_ids()
return mx.array([tok_ids], dtype=mx.int32)
def as_mlx_array_beam(self, beam: int) -> mx.array:
"""Return tokens as MLX array repeated for beam search."""
t = self.as_mlx_array()
return mx.repeat(t, beam, axis=0)
def as_text(self):
return self.text
@staticmethod
def empty(*a, **kw):
return MLXTokenBuffer(*a, **kw)
@staticmethod
def from_text(text, *a, **kw):
return MLXTokenBuffer(*a, text=text, **kw)
def is_empty(self):
return self.text is None or self.text == ""
def trim_words(self, num=1, after=0):
"""Trim words from the beginning of the context."""
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
ids = tokenizer.encode(self.text[after:])
words, wids = self.tokenizer.split_to_word_tokens(ids)
if not words:
return 0
self.text = self.text[:after] + "".join(words[num:])
return sum(len(wi) for wi in wids[:num])
def append_token_ids(self, token_ids):
"""Append token IDs to the buffer, handling incomplete UTF-8."""
tokenizer = self.tokenizer
assert tokenizer is not None, "Tokenizer is not set."
all_tokens = self.pending_token_ids + token_ids
decoded = tokenizer.decode(all_tokens)
replacement_char = "\ufffd"
if replacement_char in decoded:
if len(all_tokens) > 1:
decoded_partial = tokenizer.decode(all_tokens[:-1])
if replacement_char not in decoded_partial:
self.text += decoded_partial
self.pending_token_ids = [all_tokens[-1]]
else:
self.pending_token_ids = all_tokens
else:
self.pending_token_ids = all_tokens
else:
self.text += decoded
self.pending_token_ids = []
def mlx_median_filter(x: mx.array, filter_width: int) -> mx.array:
"""
Apply median filter along the last axis.
Args:
x: Input array of shape (..., T)
filter_width: Width of the median filter (should be odd)
Returns:
Filtered array of same shape
"""
if filter_width <= 1:
return x
pad_width = filter_width // 2
shape = x.shape
left_pad = mx.repeat(x[..., :1], pad_width, axis=-1)
right_pad = mx.repeat(x[..., -1:], pad_width, axis=-1)
x_padded = mx.concatenate([left_pad, x, right_pad], axis=-1)
result_shape = list(shape)
result = []
for i in range(shape[-1]):
window = x_padded[..., i:i + filter_width]
sorted_window = mx.sort(window, axis=-1)
median_val = sorted_window[..., filter_width // 2:filter_width // 2 + 1]
result.append(median_val)
return mx.concatenate(result, axis=-1)
class MLXAlignAtt:
"""
MLX-native Alignment-based Attention decoder for SimulStreaming.
This class runs entirely on MLX, with no PyTorch dependencies for inference.
"""
@property
def speaker(self):
return self.state.speaker
@speaker.setter
def speaker(self, value):
self.state.speaker = value
@property
def global_time_offset(self):
return self.state.global_time_offset
@global_time_offset.setter
def global_time_offset(self, value):
self.state.global_time_offset = value
def __init__(
self,
cfg: AlignAttConfig,
mlx_model: Any,
) -> None:
"""
Initialize MLX AlignAtt decoder.
Args:
cfg: AlignAtt configuration
mlx_model: MLX Whisper model (full model, not just encoder)
"""
self.model = mlx_model
self.cfg = cfg
logger.info(f"MLX Model dimensions: {self.model.dims}")
self.decode_options = DecodingOptions(
language=cfg.language,
without_timestamps=True,
task=cfg.task
)
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
self.max_text_len = self.model.dims.n_text_ctx
self.num_decoder_layers = len(self.model.decoder.blocks)
if self.cfg.max_context_tokens is None:
self.max_context_tokens = self.max_text_len
else:
self.max_context_tokens = self.cfg.max_context_tokens
# Initialize per-session state
self.state = MLXDecoderState()
self._init_state(cfg)
def _init_state(self, cfg: AlignAttConfig):
"""Initialize the per-session decoder state."""
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
self.state.tokenizer = self.tokenizer
self.state.detected_language = cfg.language if cfg.language != "auto" else None
self.state.global_time_offset = 0.0
self.state.last_attend_frame = -cfg.rewind_threshold
self.state.speaker = -1
if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
if cfg.never_fire:
self.state.never_fire = True
self.state.always_fire = False
else:
self.state.always_fire = True
self.state.never_fire = False
else:
logger.warning("CIF checkpoint provided but MLX CIF not implemented. Using always_fire=True")
self.state.always_fire = True
self.state.never_fire = cfg.never_fire
self._build_alignment_source()
suppress_tokens = [
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
self.tokenizer.no_timestamps,
] + list(self.tokenizer.all_language_tokens)
if self.tokenizer.no_speech is not None:
suppress_tokens.append(self.tokenizer.no_speech)
self.state.suppress_tokens = tuple(sorted(set(suppress_tokens)))
logger.debug(f"Suppress tokens: {self.state.suppress_tokens}")
self.init_tokens()
self.init_context()
self.state.decoder_type = cfg.decoder_type
if cfg.decoder_type == "greedy":
logger.info("Using MLX greedy decoder")
self.state.token_decoder = MLXGreedyDecoder(0.0, self.tokenizer.eot)
elif cfg.decoder_type == "beam":
logger.info("Using MLX beam decoder")
self.state.inference = MLXInference(self.model, self.state.initial_token_length)
self.state.token_decoder = MLXBeamSearchDecoder(
inference=self.state.inference,
eot=self.tokenizer.eot,
beam_size=cfg.beam_size
)
def _build_alignment_source(self):
"""Build alignment source mapping from model's alignment_heads."""
self.state.align_source = {}
self.state.num_align_heads = 0
alignment_heads = self.model.alignment_heads
if alignment_heads is None:
logger.warning("No alignment heads found in model")
return
if hasattr(alignment_heads, 'tolist'):
heads_list = alignment_heads.tolist()
else:
heads_list = np.array(alignment_heads).tolist()
for layer_rank, head_id in heads_list:
layer_rank = int(layer_rank)
head_id = int(head_id)
heads = self.state.align_source.get(layer_rank, [])
heads.append((self.state.num_align_heads, head_id))
self.state.align_source[layer_rank] = heads
self.state.num_align_heads += 1
def warmup(self, audio: np.ndarray):
"""Warmup the model with sample audio."""
try:
self.insert_audio(audio)
self.infer(is_last=True)
self.refresh_segment(complete=True)
logger.info("MLX model warmed up successfully")
except Exception as e:
logger.exception(f"MLX model warmup failed: {e}")
def create_tokenizer(self, language=None):
"""Create tokenizer for the given language."""
self.tokenizer = tokenizer.get_tokenizer(
multilingual=self.tokenizer_is_multilingual,
language=language,
num_languages=self.model.num_languages,
task=self.decode_options.task
)
self.state.tokenizer = self.tokenizer
def init_context(self):
"""Initialize context buffer."""
kw = {
'tokenizer': self.tokenizer,
'prefix_token_ids': [self.tokenizer.sot_prev]
}
self.state.context = MLXTokenBuffer.empty(**kw)
if self.cfg.static_init_prompt is not None:
self.state.context = MLXTokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
if self.cfg.init_prompt is not None:
self.state.context.text += self.cfg.init_prompt
def init_tokens(self):
"""Initialize token sequence."""
logger.debug(f"init tokens, {len(self.state.segments)}")
self.state.initial_tokens = mx.array(
[self.tokenizer.sot_sequence_including_notimestamps],
dtype=mx.int32
)
self.state.initial_token_length = self.state.initial_tokens.shape[1]
self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
logger.debug(f"init tokens after, {len(self.state.segments)}")
self.state.tokens = [self.state.initial_tokens]
def trim_context(self):
"""Trim context if too long."""
logger.info("Trimming context")
c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
logger.info(f"Context text: {self.state.context.as_text()}")
l = sum(t.shape[1] for t in self.state.tokens) + c
if self.cfg.static_init_prompt is None:
after = 0
else:
after = len(self.cfg.static_init_prompt)
while c > self.max_context_tokens or l > self.max_text_len - 20:
t = self.state.context.trim_words(after=after)
l -= t
c -= t
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
if t == 0:
break
logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
def refresh_segment(self, complete=False):
"""Refresh segment state."""
logger.debug("Refreshing segment:")
self.init_tokens()
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_context()
logger.debug(f"Context: {self.state.context}")
if not complete and len(self.state.segments) > 2:
self.state.segments = self.state.segments[-2:]
else:
logger.debug("removing all segments.")
self.state.segments = []
self.state.log_segments += 1
self.state.pending_incomplete_tokens = []
def fire_at_boundary(self, chunked_encoder_feature: mx.array) -> bool:
"""Check if we should fire at word boundary (CIF-based)."""
if self.state.always_fire:
return True
if self.state.never_fire:
return False
return True
def _current_tokens(self) -> mx.array:
"""Get current token sequence for decoding."""
toks = self.state.tokens
if toks[0].shape[0] == 1:
toks[0] = mx.repeat(toks[0], self.cfg.beam_size, axis=0)
if not self.state.context.is_empty():
context_toks = self.state.context.as_mlx_array_beam(self.cfg.beam_size)
toks = [context_toks] + toks
# Concatenate all tokens
if len(toks) > 1:
current_tokens = mx.concatenate(toks, axis=1)
else:
current_tokens = toks[0]
logger.debug("debug print current_tokens:")
self.debug_print_tokens(current_tokens)
return current_tokens
def debug_print_tokens(self, tokens: mx.array):
"""Debug print token sequences."""
tokens_np = np.array(tokens)
for i in range(min(self.cfg.beam_size, tokens_np.shape[0])):
logger.debug(self.tokenizer.decode_with_timestamps(tokens_np[i].tolist()))
def segments_len(self) -> float:
"""Get total length of audio segments in seconds."""
return sum(s.shape[0] for s in self.state.segments) / 16000
def _apply_minseglen(self) -> bool:
"""Check if we have enough audio to process."""
segments_len = self.segments_len()
if segments_len < self.cfg.audio_min_len:
logger.debug("waiting for next segment")
return False
return True
def insert_audio(self, segment: np.ndarray = None):
"""Insert audio segment into buffer."""
if segment is not None:
if hasattr(segment, 'numpy'):
segment = segment.numpy()
self.state.segments.append(segment)
removed_len = 0
segments_len = self.segments_len()
while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
removed_len = self.state.segments[0].shape[0] / 16000
segments_len -= removed_len
self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
self.state.cumulative_time_offset += removed_len
self.state.segments = self.state.segments[1:]
logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s")
if len(self.state.tokens) > 1:
# Convert MLX array to list for context
token_list = np.array(self.state.tokens[1][0, :]).tolist()
self.state.context.append_token_ids(token_list)
self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]
return removed_len
def _clean_cache(self):
"""Clean the kv_cache after each inference step."""
self.state.clean_cache()
def _suppress_tokens(self, logits: mx.array) -> mx.array:
"""Apply token suppression to logits."""
if self.state.suppress_tokens:
suppress_indices = mx.array(list(self.state.suppress_tokens), dtype=mx.int32)
logits = logits.at[:, suppress_indices].add(-float('inf'))
return logits
def lang_id(self, encoder_features: mx.array) -> Tuple[mx.array, List[dict]]:
"""Language detection from encoder features."""
n_audio = encoder_features.shape[0]
x = mx.array([[self.tokenizer.sot]] * n_audio, dtype=mx.int32)
logits, _, _ = self.model.decoder(x, encoder_features, kv_cache=None)
logits = logits[:, 0]
mask = mx.ones(logits.shape[-1], dtype=mx.bool_)
language_token_indices = mx.array(list(self.tokenizer.all_language_tokens), dtype=mx.int32)
mask = mask.at[language_token_indices].add(False)
logits = mx.where(mask, mx.array(-float('inf')), logits)
language_tokens = mx.argmax(logits, axis=-1)
language_token_probs = mx.softmax(logits, axis=-1)
probs_np = np.array(language_token_probs)
language_probs = [
{
c: float(probs_np[i, j])
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
}
for i in range(n_audio)
]
self._clean_cache()
return language_tokens, language_probs
def infer(self, is_last: bool = False) -> List[ASRToken]:
"""
Main inference method.
Args:
is_last: Whether this is the final chunk
Returns:
List of timestamped ASR tokens
"""
new_segment = True
if len(self.state.segments) == 0:
logger.debug("No segments, nothing to do")
return []
if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
return []
if len(self.state.segments) > 1:
input_segments = np.concatenate(self.state.segments, axis=0)
else:
input_segments = self.state.segments[0]
beg_encode = time()
mlx_mel_padded = mlx_log_mel_spectrogram(
audio=input_segments,
n_mels=self.model.dims.n_mels,
padding=N_SAMPLES
)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
encoder_feature = self.model.encoder(mlx_mel[None])
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2)
mx.eval(encoder_feature)
end_encode = time()
logger.debug(f'MLX Encoder duration: {end_encode - beg_encode:.3f}s')
if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp:
seconds_since_start = self.segments_len() - self.state.first_timestamp
if seconds_since_start >= 2.0:
language_tokens, language_probs = self.lang_id(encoder_feature)
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
print(f"Detected language: {top_lan} with p={p:.4f}")
self.create_tokenizer(top_lan)
self.state.last_attend_frame = -self.cfg.rewind_threshold
self.state.cumulative_time_offset = 0.0
self.init_tokens()
self.init_context()
self.state.detected_language = top_lan
logger.info(f"Tokenizer language: {self.tokenizer.language}")
self.trim_context()
current_tokens = self._current_tokens()
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
sum_logprobs = mx.zeros((self.cfg.beam_size,), dtype=mx.float32)
completed = False
attn_of_alignment_heads = None
most_attended_frame = None
token_len_before_decoding = current_tokens.shape[1]
l_absolute_timestamps = []
accumulated_cross_attns = []
audio_duration_s = self.segments_len()
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
tokens_produced_this_chunk = 0
while not completed and current_tokens.shape[1] < self.max_text_len:
tokens_produced_this_chunk += 1
if tokens_produced_this_chunk > max_tokens_per_chunk:
logger.warning(f"[Loop Detection] Too many tokens ({tokens_produced_this_chunk}) for {audio_duration_s:.2f}s audio. Breaking.")
current_tokens = current_tokens[:, :token_len_before_decoding]
break
if new_segment:
tokens_for_logits = current_tokens
else:
tokens_for_logits = current_tokens[:, -1:]
if self.state.decoder_type == "greedy":
logits, self.state.kv_cache, cross_qk = self.model.decoder(
tokens_for_logits, encoder_feature, kv_cache=self.state.kv_cache
)
else:
logits, cross_qk = self.state.inference.logits(tokens_for_logits, encoder_feature)
mx.eval(logits)
accumulated_cross_attns.append(cross_qk)
if len(accumulated_cross_attns) > 16:
accumulated_cross_attns = accumulated_cross_attns[-16:]
if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
no_speech_probs = np.array(probs_at_sot[:, self.tokenizer.no_speech]).tolist()
if no_speech_probs[0] > self.cfg.nonspeech_prob:
logger.info("no speech, stop")
break
logits = logits[:, -1, :] # Last token logits
# Suppress tokens at segment start
if new_segment:
blank_tokens = self.tokenizer.encode(" ") + [self.tokenizer.eot]
logits = logits.at[:, blank_tokens].add(-float('inf'))
new_segment = False
logits = self._suppress_tokens(logits)
current_tokens, completed = self.state.token_decoder.update(
current_tokens, logits, sum_logprobs
)
mx.eval(current_tokens)
logger.debug(f"Decoding completed: {completed}")
self.debug_print_tokens(current_tokens)
attn_of_alignment_heads = self._process_cross_attention(
accumulated_cross_attns, content_mel_len
)
most_attended_frames = mx.argmax(attn_of_alignment_heads[:, -1, :], axis=-1)
most_attended_frames_np = np.array(most_attended_frames)
absolute_timestamps = [
(frame * 0.02 + self.state.cumulative_time_offset)
for frame in most_attended_frames_np.tolist()
]
logger.debug(str(most_attended_frames_np.tolist()) + " most att frames")
logger.debug(f"Absolute timestamps: {absolute_timestamps}")
most_attended_frame = int(most_attended_frames_np[0])
l_absolute_timestamps.append(absolute_timestamps[0])
if completed:
current_tokens = current_tokens[:, :-1]
break
if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
current_tokens_np = np.array(current_tokens)
if current_tokens.shape[1] > 1 and current_tokens_np[0, -2] >= DEC_PAD:
logger.debug("omit rewinding from special tokens")
self.state.last_attend_frame = most_attended_frame
else:
logger.debug(f"[rewind detected] current: {most_attended_frame}, last: {self.state.last_attend_frame}")
self.state.last_attend_frame = -self.cfg.rewind_threshold
current_tokens = mx.concatenate(self.state.tokens, axis=1) if len(self.state.tokens) > 0 else self.state.tokens[0]
break
else:
self.state.last_attend_frame = most_attended_frame
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
current_tokens = current_tokens[:, :-1]
break
tokens_to_split = np.array(current_tokens[0, token_len_before_decoding:]).tolist()
if self.state.pending_incomplete_tokens:
logger.debug(f"[UTF-8 Fix] Prepending pending tokens: {self.state.pending_incomplete_tokens}")
tokens_to_split = self.state.pending_incomplete_tokens + tokens_to_split
if fire_detected or is_last:
new_hypothesis = tokens_to_split
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
else:
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split)
if len(split_words) > 1:
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
else:
new_hypothesis = []
logger.debug(f"new_hypothesis: {new_hypothesis}")
new_tokens = mx.array([new_hypothesis], dtype=mx.int32)
new_tokens = mx.repeat(new_tokens, self.cfg.beam_size, axis=0)
self.state.tokens.append(new_tokens)
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
self._clean_cache()
if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
self.state.first_timestamp = l_absolute_timestamps[0]
timestamped_words = []
timestamp_idx = 0
replacement_char = "\ufffd"
for word, word_tokens in zip(split_words, split_tokens):
if replacement_char in word:
logger.warning(f"[UTF-8 Filter] Skipping: {repr(word)}")
timestamp_idx += len(word_tokens)
continue
try:
current_timestamp = l_absolute_timestamps[timestamp_idx]
except IndexError:
pass
timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken(
start=round(current_timestamp, 2),
end=round(current_timestamp + 0.1, 2),
text=word,
speaker=self.state.speaker,
detected_language=self.state.detected_language
).with_offset(self.state.global_time_offset)
timestamped_words.append(timestamp_entry)
self.state.pending_incomplete_tokens = []
MAX_PENDING_TOKENS = 10
if split_words and replacement_char in split_words[-1]:
if len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
self.state.pending_incomplete_tokens = split_tokens[-1]
logger.debug(f"[UTF-8 Fix] Holding incomplete tokens")
else:
logger.warning(f"[UTF-8 Fix] Skipping too many tokens")
return timestamped_words
def _process_cross_attention(
self,
cross_attns: List[List[mx.array]],
content_mel_len: int
) -> mx.array:
"""
Process cross-attention weights for alignment.
Args:
cross_attns: List of cross-attention from each forward pass
Each element is a list of mx.arrays per layer
content_mel_len: Length of actual audio content
Returns:
Processed attention tensor, shape (batch, seq_len, content_mel_len)
"""
attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
num_decoder_layers = self.num_decoder_layers
if cross_attns and isinstance(cross_attns[0], list):
flattened_attns = [attn for layer_list in cross_attns for attn in layer_list]
else:
flattened_attns = cross_attns
for idx, attn_mat in enumerate(flattened_attns):
if attn_mat is None:
continue
layer_rank = idx % num_decoder_layers
align_heads_in_layer = self.state.align_source.get(layer_rank, [])
if len(align_heads_in_layer) == 0:
continue
attn_mat = mx.softmax(attn_mat, axis=-1)
for align_head_rank, head_id in align_heads_in_layer:
if self.cfg.beam_size == 1:
if attn_mat.ndim == 4:
a = attn_mat[0, head_id, :, :]
else:
a = attn_mat[head_id, :, :]
a = a[None, :, :]
else:
a = attn_mat[:, head_id, :, :]
attn_of_alignment_heads[align_head_rank].append(a)
tmp = []
for mat in attn_of_alignment_heads:
if mat:
t = mx.concatenate(mat, axis=1)
tmp.append(t)
if not tmp:
return mx.zeros((self.cfg.beam_size, 1, content_mel_len))
attn_of_alignment_heads = mx.stack(tmp, axis=1)
std = mx.std(attn_of_alignment_heads, axis=-2, keepdims=True)
mean = mx.mean(attn_of_alignment_heads, axis=-2, keepdims=True)
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)
attn_of_alignment_heads = mlx_median_filter(attn_of_alignment_heads, 7)
attn_of_alignment_heads = mx.mean(attn_of_alignment_heads, axis=1)
attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]
mx.eval(attn_of_alignment_heads)
return attn_of_alignment_heads

View File

@@ -68,4 +68,40 @@ def load_mlx_encoder(
model.update(encoder_weights) model.update(encoder_weights)
mx.eval(model.parameters()) mx.eval(model.parameters())
return model
def load_mlx_model(
path_or_hf_repo: str,
dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
with open(str(model_path / "config.json"), "r") as f:
config = json.loads(f.read())
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = whisper.ModelDimensions(**config)
wf = model_path / "weights.safetensors"
if not wf.exists():
wf = model_path / "weights.npz"
weights = mx.load(str(wf))
model = whisper.Whisper(model_args, dtype)
if quantization is not None:
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(model, **quantization, class_predicate=class_predicate)
weights = tree_unflatten(list(weights.items()))
model.update(weights)
mx.eval(model.parameters())
return model return model

View File

@@ -390,7 +390,6 @@ class AlignAtt:
return [] return []
if not self._apply_minseglen(): if not self._apply_minseglen():
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.") logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
input_segments = torch.cat(self.state.segments, dim=0)
return [] return []
# input_segments is concatenation of audio, it's one array # input_segments is concatenation of audio, it's one array
@@ -485,7 +484,9 @@ class AlignAtt:
accumulated_cross_attns = [] accumulated_cross_attns = []
audio_duration_s = self.segments_len() audio_duration_s = self.segments_len()
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0)) # 2x margin, min 50 # ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
tokens_produced_this_chunk = 0 tokens_produced_this_chunk = 0
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
@@ -506,8 +507,12 @@ class AlignAtt:
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True) result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
logits, cross_attns = result logits, cross_attns = result
# Accumulate cross-attention from this forward pass # Accumulate cross-attention from this forward pass (rolling window to
# bound VRAM — only the last entry matters for alignment, and the
# median_filter kernel is 7, so 16 entries is more than enough).
accumulated_cross_attns.append(cross_attns) accumulated_cross_attns.append(cross_attns)
if len(accumulated_cross_attns) > 16:
accumulated_cross_attns = accumulated_cross_attns[-16:]
if new_segment and self.tokenizer.no_speech is not None: if new_segment and self.tokenizer.no_speech is not None:
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1) probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
@@ -626,8 +631,10 @@ class AlignAtt:
try: try:
current_timestamp = l_absolute_timestamps[timestamp_idx] current_timestamp = l_absolute_timestamps[timestamp_idx]
except: except IndexError:
pass # Use last timestamp if index out of range
logger.warning(f"Timestamp index {timestamp_idx} out of range, using last timestamp")
current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
timestamp_idx += len(word_tokens) timestamp_idx += len(word_tokens)
timestamp_entry = ASRToken( timestamp_entry = ASRToken(

View File

@@ -0,0 +1,139 @@
"""
Thread Safety Configuration for WhisperLiveKit
This module provides thread safety configuration and utilities.
Environment Variables:
WHISPERLIVEKIT_MODEL_LOCK: Enable/disable model locking (default: 1)
Set to "0" to disable for single-connection deployments
WHISPERLIVEKIT_LOCK_TIMEOUT: Lock acquisition timeout in seconds (default: 30)
Usage:
# Enable model locking (default)
export WHISPERLIVEKIT_MODEL_LOCK=1
# Disable for single-connection deployment
export WHISPERLIVEKIT_MODEL_LOCK=0
# Custom timeout
export WHISPERLIVEKIT_LOCK_TIMEOUT=60
"""
import os
import logging
import threading
logger = logging.getLogger(__name__)
# Configuration
USE_MODEL_LOCK = os.environ.get("WHISPERLIVEKIT_MODEL_LOCK", "1") == "1"
LOCK_TIMEOUT = float(os.environ.get("WHISPERLIVEKIT_LOCK_TIMEOUT", "30.0"))
# Global model lock
_model_lock = threading.Lock()
# Log configuration on import
if USE_MODEL_LOCK:
logger.info(f"Model locking ENABLED (timeout: {LOCK_TIMEOUT}s)")
logger.info("For single-connection deployments, set WHISPERLIVEKIT_MODEL_LOCK=0")
else:
logger.warning("Model locking DISABLED - only safe for single-connection deployments")
def get_model_lock():
"""Get the global model lock instance"""
return _model_lock
def acquire_model_lock(timeout=None):
"""
Acquire model lock with timeout.
Args:
timeout: Lock acquisition timeout (default: use LOCK_TIMEOUT)
Returns:
bool: True if lock acquired, False on timeout
"""
if not USE_MODEL_LOCK:
return True
timeout = timeout or LOCK_TIMEOUT
acquired = _model_lock.acquire(timeout=timeout)
if not acquired:
logger.error(f"Failed to acquire model lock within {timeout}s")
return acquired
def release_model_lock():
"""Release model lock"""
if not USE_MODEL_LOCK:
return
try:
_model_lock.release()
except RuntimeError:
# Lock not held - this is fine
pass
class ModelLockContext:
"""Context manager for model lock"""
def __init__(self, timeout=None):
self.timeout = timeout
self.acquired = False
def __enter__(self):
self.acquired = acquire_model_lock(self.timeout)
return self.acquired
def __exit__(self, exc_type, exc_val, exc_tb):
if self.acquired:
release_model_lock()
return False
# Concurrency recommendations
RECOMMENDED_CONNECTIONS_PER_WORKER = 1 if USE_MODEL_LOCK else 1
RECOMMENDED_WORKERS = 4
def print_deployment_recommendations():
"""Print recommended deployment configuration"""
print("\n" + "="*60)
print("WhisperLiveKit Deployment Recommendations")
print("="*60)
if USE_MODEL_LOCK:
print("⚠️ Model locking is ENABLED")
print(" This serializes inference across connections.")
print()
print("Recommended deployment:")
print(f" gunicorn -w {RECOMMENDED_WORKERS} \\")
print(" -k uvicorn.workers.UvicornWorker \\")
print(" --worker-connections 1 \\")
print(" whisperlivekit.basic_server:app")
print()
print("Expected capacity:")
print(f" - {RECOMMENDED_WORKERS} concurrent users (1 per worker)")
print(f" - Memory: ~{RECOMMENDED_WORKERS}x model size")
else:
print("✅ Model locking is DISABLED")
print(" ⚠️ ONLY safe for single-connection deployments")
print()
print("Recommended deployment:")
print(" uvicorn whisperlivekit.basic_server:app \\")
print(" --host 0.0.0.0 --port 8000 \\")
print(" --workers 1")
print()
print("Expected capacity:")
print(" - 1 concurrent user only")
print("="*60 + "\n")
if __name__ == "__main__":
print_deployment_recommendations()

View File

@@ -39,10 +39,11 @@ class TimedText(Timed):
@dataclass() @dataclass()
class ASRToken(TimedText): class ASRToken(TimedText):
probability: Optional[float] = None
def with_offset(self, offset: float) -> "ASRToken": def with_offset(self, offset: float) -> "ASRToken":
"""Return a new token with the time offset added.""" """Return a new token with the time offset added."""
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language) return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language, probability=self.probability)
def is_silence(self) -> bool: def is_silence(self) -> bool:
return False return False

View File

@@ -49,9 +49,12 @@ class TokensAlignment:
def add_translation(self, segment: Segment) -> None: def add_translation(self, segment: Segment) -> None:
"""Append translated text segments that overlap with a segment.""" """Append translated text segments that overlap with a segment."""
if segment.translation is None:
segment.translation = ''
for ts in self.all_translation_segments: for ts in self.all_translation_segments:
if ts.is_within(segment): if ts.is_within(segment):
segment.translation += ts.text + (self.sep if ts.text else '') if ts.text:
segment.translation += ts.text + self.sep
elif segment.translation: elif segment.translation:
break break
@@ -183,11 +186,11 @@ class TokensAlignment:
else: else:
diarization_buffer = '' diarization_buffer = ''
for token in self.new_tokens: for token in self.new_tokens:
if token.is_silence(): if isinstance(token, Silence):
if self.current_line_tokens: if self.current_line_tokens:
self.validated_segments.append(Segment().from_tokens(self.current_line_tokens)) self.validated_segments.append(Segment.from_tokens(self.current_line_tokens))
self.current_line_tokens = [] self.current_line_tokens = []
end_silence = token.end if token.has_ended else time() - self.beg_loop end_silence = token.end if token.has_ended else time() - self.beg_loop
if self.validated_segments and self.validated_segments[-1].is_silence(): if self.validated_segments and self.validated_segments[-1].is_silence():
self.validated_segments[-1].end = end_silence self.validated_segments[-1].end = end_silence
@@ -201,7 +204,7 @@ class TokensAlignment:
segments = list(self.validated_segments) segments = list(self.validated_segments)
if self.current_line_tokens: if self.current_line_tokens:
segments.append(Segment().from_tokens(self.current_line_tokens)) segments.append(Segment.from_tokens(self.current_line_tokens))
if current_silence: if current_silence:
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop

View File

@@ -0,0 +1,484 @@
"""
Voxtral Mini Realtime streaming backend using voxmlx's incremental encode/decode.
Uses model.encode_step() for incremental audio encoding and token-by-token
autoregressive decoding, matching voxmlx's native streaming pipeline.
"""
import logging
import sys
import time
from typing import List, Optional, Tuple
import numpy as np
from whisperlivekit.timed_objects import ASRToken, Transcript
logger = logging.getLogger(__name__)
N_LEFT_PAD_TOKENS = 32
N_RIGHT_PAD_TOKENS = 17
class VoxtralStreamingASR:
"""Voxtral model holder for the streaming pipeline."""
sep = " "
def __init__(self, logfile=sys.stderr, **kwargs):
from voxmlx import _build_prompt_tokens
from voxmlx import load_model as vox_load_model
self.logfile = logfile
self.transcribe_kargs = {}
lan = kwargs.get("lan", "auto")
self.original_language = None if lan == "auto" else lan
DEFAULT_MODEL = "mlx-community/Voxtral-Mini-4B-Realtime-6bit"
model_path = kwargs.get("model_dir") or kwargs.get("model_path")
if not model_path:
model_size = kwargs.get("model_size", "")
# Only use model_size if it looks like a HF repo or a path, not a Whisper size name
if model_size and ("/" in model_size or model_size.startswith(".")):
model_path = model_size
else:
model_path = DEFAULT_MODEL
t = time.time()
logger.info(f"Loading Voxtral model '{model_path}' via voxmlx...")
self.model, self._tokenizer, self._config = vox_load_model(model_path)
self._prompt_tokens, self._n_delay_tokens = _build_prompt_tokens(
self._tokenizer
)
logger.info(f"Voxtral model loaded in {time.time() - t:.2f}s")
self.backend_choice = "voxtral-mlx"
self.tokenizer = None # sentence tokenizer — not needed for streaming
def transcribe(self, audio):
pass
class VoxtralStreamingOnlineProcessor:
"""
Online processor for Voxtral streaming ASR.
Uses voxmlx's incremental encoding (encode_step) and token-by-token
autoregressive decoding. Each decode step corresponds to 80ms of audio.
"""
SAMPLING_RATE = 16000
def __init__(self, asr: VoxtralStreamingASR, logfile=sys.stderr):
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.buffer = []
self.audio_buffer = np.array([], dtype=np.float32) # for logging compat
self._special_token_policy = SpecialTokenPolicy.IGNORE
self._reset_state()
logger.info(
f"[voxtral] Initialized. eos_id={asr._tokenizer.eos_id}, "
f"prefix_len={len(asr._prompt_tokens)}, "
f"n_delay={asr._n_delay_tokens}"
)
def _reset_state(self):
from voxmlx.audio import SAMPLES_PER_TOKEN
self._samples_per_token = SAMPLES_PER_TOKEN
# Incremental encoder state
self._audio_tail = None
self._conv1_tail = None
self._conv2_tail = None
self._encoder_cache = None
self._ds_buf = None
# Decoder state
self._decoder_cache = None
self._y = None # last sampled token (mx.array scalar)
self._t_cond = None
self._text_embeds = None
# Audio / decode tracking
self._pending_audio = np.zeros(0, dtype=np.float32)
self._audio_embeds = None
self._n_audio_samples_fed = 0
self._n_total_decoded = 0
self._first_cycle = True
self._prefilled = False
# Word extraction: accumulate token IDs, full-sequence decode for correct spacing
self._output_token_ids: List[int] = []
self._token_positions: List[int] = [] # decode position for each token
self._n_committed_words = 0
self._global_time_offset = 0.0
self._y_flushed_to_output = False # True after start_silence flushes pending _y
# ── Interface methods (same as SimulStreamingOnlineProcessor) ──
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
self.end = audio_stream_end_time
self._pending_audio = np.append(self._pending_audio, audio)
self.audio_buffer = self._pending_audio # for logging compat
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
try:
return self._process_iter_inner(is_last)
except Exception as e:
logger.warning(f"[voxtral] process_iter exception: {e}", exc_info=True)
return [], self.end
def _get_full_text(self) -> str:
"""Decode all accumulated token IDs at once for correct spacing."""
if not self._output_token_ids:
return ""
sp = self.asr._tokenizer
return sp.decode(self._output_token_ids, special_token_policy=self._special_token_policy)
def get_buffer(self) -> Transcript:
"""Return all uncommitted text as buffer, including pending _y token."""
# Temporarily include pending _y for buffer display
ids = list(self._output_token_ids)
if self._y is not None and not self._y_flushed_to_output:
sp = self.asr._tokenizer
token_id = self._y.item()
if token_id != sp.eos_id:
ids.append(token_id)
if not ids:
return Transcript(start=None, end=None, text="")
sp = self.asr._tokenizer
full_text = sp.decode(ids, special_token_policy=self._special_token_policy)
words = full_text.split()
uncommitted = words[self._n_committed_words:]
if uncommitted:
text = " ".join(uncommitted)
return Transcript(start=self.end, end=self.end, text=text)
return Transcript(start=None, end=None, text="")
def start_silence(self) -> Tuple[List[ASRToken], float]:
"""Flush all uncommitted words when silence starts."""
self._flush_last_y() # Include the pending _y token before flushing
words = self._flush_all_pending_words()
logger.info(f"[voxtral] start_silence: flushed {len(words)} words")
return words, self.end
def end_silence(self, silence_duration: float, offset: float):
self._global_time_offset += silence_duration
self.end += silence_duration
def new_speaker(self, change_speaker):
self.start_silence()
def warmup(self, audio, init_prompt=""):
pass
def finish(self) -> Tuple[List[ASRToken], float]:
"""Flush remaining audio with right-padding to let the model finish decoding."""
right_pad = np.zeros(
N_RIGHT_PAD_TOKENS * self._samples_per_token, dtype=np.float32
)
self._pending_audio = np.append(self._pending_audio, right_pad)
self._n_audio_samples_fed += len(right_pad)
final_words, _ = self._process_iter_inner(is_last=True)
# Flush the last pending self._y token (like voxmlx's finally block)
self._flush_last_y()
final_words.extend(self._flush_all_pending_words())
return final_words, self.end
# ── Word extraction ──
def _pos_to_time(self, pos: int) -> float:
"""Convert a decode position to seconds relative to audio start."""
SPT = self._samples_per_token
return max(0.0, (pos - N_LEFT_PAD_TOKENS) * SPT / self.SAMPLING_RATE)
def _flush_last_y(self):
"""Flush the last pending self._y token that hasn't been processed yet."""
if self._y is None or self._y_flushed_to_output:
return
sp = self.asr._tokenizer
token_id = self._y.item()
if token_id != sp.eos_id:
self._output_token_ids.append(token_id)
self._token_positions.append(self._n_total_decoded)
self._y_flushed_to_output = True
def _extract_new_words(self) -> List[ASRToken]:
"""
Split accumulated text into words and return new complete words
(all but the last, which may still be growing).
"""
if not self._output_token_ids:
return []
full_text = self._get_full_text()
words = full_text.split()
new_words: List[ASRToken] = []
n_tokens = len(self._output_token_ids)
# All words except the last are guaranteed complete
while len(words) > self._n_committed_words + 1:
word = words[self._n_committed_words]
word_idx = self._n_committed_words
n_words_total = len(words)
# Approximate: assign token range proportionally
tok_start = int(word_idx / n_words_total * n_tokens)
tok_end = int((word_idx + 1) / n_words_total * n_tokens)
tok_start = min(tok_start, len(self._token_positions) - 1)
tok_end = min(tok_end, len(self._token_positions) - 1)
start_time = self._pos_to_time(self._token_positions[tok_start]) + self._global_time_offset
end_time = self._pos_to_time(self._token_positions[tok_end]) + self._global_time_offset
# Prepend space to match Whisper convention (Segment.from_tokens joins with '')
text = word if self._n_committed_words == 0 else " " + word
new_words.append(ASRToken(start=start_time, end=end_time, text=text))
self._n_committed_words += 1
return new_words
def _flush_all_pending_words(self) -> List[ASRToken]:
"""Flush ALL words including the last partial one."""
if not self._output_token_ids:
return []
full_text = self._get_full_text()
words = full_text.split()
new_words: List[ASRToken] = []
n_tokens = len(self._output_token_ids)
n_words_total = max(len(words), 1)
while self._n_committed_words < len(words):
word = words[self._n_committed_words]
word_idx = self._n_committed_words
tok_start = int(word_idx / n_words_total * n_tokens)
tok_end = int((word_idx + 1) / n_words_total * n_tokens)
tok_start = min(tok_start, max(len(self._token_positions) - 1, 0))
tok_end = min(tok_end, max(len(self._token_positions) - 1, 0))
if self._token_positions:
start_time = self._pos_to_time(self._token_positions[tok_start]) + self._global_time_offset
end_time = self._pos_to_time(self._token_positions[tok_end]) + self._global_time_offset
else:
start_time = self._global_time_offset
end_time = self._global_time_offset
# Prepend space to match Whisper convention (Segment.from_tokens joins with '')
text = word if self._n_committed_words == 0 else " " + word
new_words.append(ASRToken(start=start_time, end=end_time, text=text))
self._n_committed_words += 1
return new_words
# ── Core streaming logic ──
def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]:
import mlx.core as mx
from voxmlx.audio import log_mel_spectrogram_step
from voxmlx.cache import RotatingKVCache
model = self.asr.model
sp = self.asr._tokenizer
prompt_tokens = self.asr._prompt_tokens
prefix_len = len(prompt_tokens)
SPT = self._samples_per_token
# ── Phase 1: Encode new audio ──
if self._first_cycle and len(self._pending_audio) >= SPT:
left_pad = np.zeros(N_LEFT_PAD_TOKENS * SPT, dtype=np.float32)
n_feed = (len(self._pending_audio) // SPT) * SPT
chunk = np.concatenate([left_pad, self._pending_audio[:n_feed]])
self._pending_audio = self._pending_audio[n_feed:]
self._n_audio_samples_fed += n_feed
mel, self._audio_tail = log_mel_spectrogram_step(
chunk, self._audio_tail
)
(
new_embeds,
self._conv1_tail,
self._conv2_tail,
self._encoder_cache,
self._ds_buf,
) = model.encode_step(
mel,
self._conv1_tail,
self._conv2_tail,
self._encoder_cache,
self._ds_buf,
)
if new_embeds is not None:
mx.eval(new_embeds)
self._audio_embeds = new_embeds
logger.info(f"[voxtral] first encode: {new_embeds.shape[0]} embeds from {n_feed} samples")
else:
logger.info(f"[voxtral] first encode: no embeds from {n_feed} samples")
self._first_cycle = False
elif not self._first_cycle and len(self._pending_audio) >= SPT:
n_feed = (len(self._pending_audio) // SPT) * SPT
chunk = self._pending_audio[:n_feed]
self._pending_audio = self._pending_audio[n_feed:]
self._n_audio_samples_fed += n_feed
mel, self._audio_tail = log_mel_spectrogram_step(
chunk, self._audio_tail
)
(
new_embeds,
self._conv1_tail,
self._conv2_tail,
self._encoder_cache,
self._ds_buf,
) = model.encode_step(
mel,
self._conv1_tail,
self._conv2_tail,
self._encoder_cache,
self._ds_buf,
)
if new_embeds is not None:
mx.eval(new_embeds)
if self._audio_embeds is not None:
self._audio_embeds = mx.concatenate(
[self._audio_embeds, new_embeds]
)
else:
self._audio_embeds = new_embeds
self.audio_buffer = self._pending_audio # for logging compat
if self._audio_embeds is None:
return [], self.end
# Safety: don't decode ahead of encoded audio
safe_total = (
N_LEFT_PAD_TOKENS + self._n_audio_samples_fed // SPT
)
n_decodable = min(
self._audio_embeds.shape[0], safe_total - self._n_total_decoded
)
if n_decodable <= 0:
return [], self.end
# ── Phase 2: Prefill (once per utterance) ──
if not self._prefilled:
if self._n_total_decoded + self._audio_embeds.shape[0] < prefix_len:
logger.info(
f"[voxtral] waiting for prefill: have {self._audio_embeds.shape[0]} embeds, need {prefix_len}"
)
return [], self.end
n_layers = len(model.language_model.layers)
self._decoder_cache = [RotatingKVCache(8192) for _ in range(n_layers)]
self._t_cond = model.time_embedding(
mx.array([self.asr._n_delay_tokens], dtype=mx.float32)
)
prompt_ids = mx.array([prompt_tokens])
self._text_embeds = model.language_model.embed(prompt_ids)[0]
prefix_embeds = (
self._text_embeds + self._audio_embeds[:prefix_len]
)[None, :, :]
logits = model.decode(
prefix_embeds, self._t_cond, "causal", self._decoder_cache
)
mx.eval(
logits,
*[x for c in self._decoder_cache for x in (c.keys, c.values)],
)
self._y = mx.argmax(logits[0, -1:], axis=-1).squeeze()
mx.async_eval(self._y)
self._audio_embeds = self._audio_embeds[prefix_len:]
self._n_total_decoded = prefix_len
self._prefilled = True
logger.info(f"[voxtral] prefill done, first token y={self._y.item()}")
n_decodable = min(
self._audio_embeds.shape[0], safe_total - self._n_total_decoded
)
if n_decodable <= 0:
return [], self.end
# ── Phase 3: Decode new positions ──
eos_id = sp.eos_id
hit_eos = False
n_consumed = 0
for i in range(n_decodable):
token_embed = model.language_model.embed(self._y.reshape(1, 1))[0, 0]
step_embed = (self._audio_embeds[i] + token_embed)[None, None, :]
logits = model.decode(
step_embed, self._t_cond, mask=None, cache=self._decoder_cache
)
next_y = mx.argmax(logits[0, -1:], axis=-1).squeeze()
mx.async_eval(next_y)
token_id = self._y.item()
n_consumed = i + 1
if token_id == eos_id:
hit_eos = True
logger.info("[voxtral] hit EOS")
break
# Accumulate token ID — full-sequence decode produces correct spacing
# Skip if this _y was already flushed by start_silence()
if self._y_flushed_to_output:
self._y_flushed_to_output = False
else:
self._output_token_ids.append(token_id)
# Track position for timestamp estimation
pos = self._n_total_decoded + i
self._token_positions.append(pos)
if i > 0 and i % 256 == 0:
mx.clear_cache()
self._y = next_y
self._n_total_decoded += n_consumed
# Trim consumed embeddings
if self._audio_embeds.shape[0] > n_consumed:
self._audio_embeds = self._audio_embeds[n_consumed:]
else:
self._audio_embeds = None
# Log decode results
full_text = self._get_full_text()
logger.info(
f"[voxtral] decoded {n_consumed} tokens | "
f"total_decoded={self._n_total_decoded} | "
f"text='{full_text[-80:]}' | "
f"n_words={len(full_text.split())} committed={self._n_committed_words}"
)
# Extract complete words from the decoded token sequence
new_words = self._extract_new_words()
if hit_eos:
new_words.extend(self._flush_all_pending_words())
self._reset_state()
if new_words:
logger.info(f"[voxtral] returning {len(new_words)} words: {[w.text for w in new_words]}")
self.buffer = []
return new_words, self.end

View File

@@ -108,7 +108,7 @@ def available_models() -> List[str]:
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]: def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
""" """
attempt to infer ModelDimensions from a HF style config.json located attempt to infer ModelDimensions from a HF style config.json located
next to the given checkpoint, usefull for distilled models next to the given checkpoint, usefull for distilled models/MLX models.
""" """
candidates = [] candidates = []
if os.path.isdir(path): if os.path.isdir(path):
@@ -122,6 +122,25 @@ def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
with open(candidate, "r", encoding="utf-8") as f: with open(candidate, "r", encoding="utf-8") as f:
config = json.load(f) config = json.load(f)
# native Whisper format
native_keys = ["n_mels", "n_audio_ctx", "n_audio_state", "n_audio_head",
"n_audio_layer", "n_vocab", "n_text_ctx", "n_text_state",
"n_text_head", "n_text_layer"]
if all(k in config for k in native_keys):
return ModelDimensions(
n_mels=config["n_mels"],
n_audio_ctx=config["n_audio_ctx"],
n_audio_state=config["n_audio_state"],
n_audio_head=config["n_audio_head"],
n_audio_layer=config["n_audio_layer"],
n_vocab=config["n_vocab"],
n_text_ctx=config["n_text_ctx"],
n_text_state=config["n_text_state"],
n_text_head=config["n_text_head"],
n_text_layer=config["n_text_layer"],
)
# HuggingFace format
try: try:
return ModelDimensions( return ModelDimensions(
n_mels=config["num_mel_bins"], n_mels=config["num_mel_bins"],
@@ -236,6 +255,24 @@ def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, tor
return converted if converted else state_dict return converted if converted else state_dict
def _convert_mlx_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Converts an mlx whisper checkpoint to a default openai whisper one
"""
if not any("mlp1" in k or "mlp2" in k for k in state_dict):
return state_dict
converted = {}
for key, value in state_dict.items():
if key == "alignment_heads":
continue
new_key = key.replace(".mlp1.", ".mlp.0.").replace(".mlp2.", ".mlp.2.")
converted[new_key] = value
return converted
def _load_lora_state(lora_path: str): def _load_lora_state(lora_path: str):
safe_path = os.path.join(lora_path, "adapter_model.safetensors") safe_path = os.path.join(lora_path, "adapter_model.safetensors")
bin_path = os.path.join(lora_path, "adapter_model.bin") bin_path = os.path.join(lora_path, "adapter_model.bin")
@@ -520,7 +557,12 @@ def load_model(
state_dict = checkpoint["model_state_dict"] state_dict = checkpoint["model_state_dict"]
else: else:
state_dict = checkpoint state_dict = checkpoint
if alignment_heads is None and "alignment_heads" in state_dict:
alignment_heads = state_dict["alignment_heads"]
state_dict = _convert_hf_state_dict(state_dict) state_dict = _convert_hf_state_dict(state_dict)
state_dict = _convert_mlx_state_dict(state_dict)
_apply_lora_adapter(state_dict, lora_path) _apply_lora_adapter(state_dict, lora_path)
if dims_cfg is not None: if dims_cfg is not None:
@@ -546,8 +588,13 @@ def load_model(
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
if alignment_heads is not None: if alignment_heads is not None:
model.set_alignment_heads(alignment_heads) if isinstance(alignment_heads, bytes):
model.set_alignment_heads(alignment_heads)
elif isinstance(alignment_heads, torch.Tensor): #for mlx whisper
mask = torch.zeros(dims.n_text_layer, dims.n_text_head, dtype=torch.bool)
for layer, head in alignment_heads.tolist():
mask[layer, head] = True
model.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
return model.to(device) return model.to(device)