mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Compare commits
33 Commits
api_live
...
feature/vo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ea507ed8e | ||
|
|
e7e82f7c19 | ||
|
|
8c799fa4d1 | ||
|
|
8923337380 | ||
|
|
aded1649ae | ||
|
|
3b535e857a | ||
|
|
d649250b9a | ||
|
|
7735478286 | ||
|
|
b9e72d2b9a | ||
|
|
e5b01033af | ||
|
|
6ae545bcb1 | ||
|
|
04980d3f5e | ||
|
|
79a705c969 | ||
|
|
34e4abd455 | ||
|
|
d59ddbaeae | ||
|
|
4dd66e7766 | ||
|
|
3db5d81a20 | ||
|
|
b67ddea494 | ||
|
|
3192553e20 | ||
|
|
f379a243fe | ||
|
|
ec09898a9f | ||
|
|
befbae56c7 | ||
|
|
719e8b1a20 | ||
|
|
f1b47178d8 | ||
|
|
59db08e961 | ||
|
|
6fc20b9562 | ||
|
|
fac8659161 | ||
|
|
4d9332ce7d | ||
|
|
62444ce746 | ||
|
|
2431a6bf91 | ||
|
|
d1263e7228 | ||
|
|
30ddd522a4 | ||
|
|
635bace09e |
@@ -37,9 +37,10 @@ RUN pip3 install --upgrade pip setuptools wheel && \
|
|||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
# Install WhisperLiveKit directly, allowing for optional dependencies
|
# Install WhisperLiveKit directly, allowing for optional dependencies
|
||||||
|
# Example: --build-arg EXTRAS="translation"
|
||||||
RUN if [ -n "$EXTRAS" ]; then \
|
RUN if [ -n "$EXTRAS" ]; then \
|
||||||
echo "Installing with extras: [$EXTRAS]"; \
|
echo "Installing with extras: [$EXTRAS]"; \
|
||||||
pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
|
pip install --no-cache-dir "whisperlivekit[$EXTRAS]"; \
|
||||||
else \
|
else \
|
||||||
echo "Installing base package only"; \
|
echo "Installing base package only"; \
|
||||||
pip install --no-cache-dir whisperlivekit; \
|
pip install --no-cache-dir whisperlivekit; \
|
||||||
|
|||||||
@@ -147,8 +147,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `small` |
|
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `small` |
|
||||||
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `None` |
|
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/default_and_custom_models.md) | `None` |
|
||||||
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
| `--language` | List [here](docs/supported_languages.md). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||||
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
|
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
|
||||||
| `--diarization` | Enable speaker identification | `False` |
|
| `--diarization` | Enable speaker identification | `False` |
|
||||||
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
||||||
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
|
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
|
||||||
@@ -267,7 +267,7 @@ docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
|
|||||||
#### Customization
|
#### Customization
|
||||||
|
|
||||||
- `--build-arg` Options:
|
- `--build-arg` Options:
|
||||||
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
|
- `EXTRAS="translation"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
|
||||||
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
|
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
|
||||||
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models
|
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models
|
||||||
|
|
||||||
|
|||||||
BIN
architecture.png
BIN
architecture.png
Binary file not shown.
|
Before Width: | Height: | Size: 422 KiB After Width: | Height: | Size: 422 KiB |
@@ -6,7 +6,7 @@ Capture the audio of your current tab, transcribe diarize and translate it using
|
|||||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
|
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
|
||||||
|
|
||||||
## Running this extension
|
## Running this extension
|
||||||
1. Run `python sync_extension.py` to copy frontend files to the `chrome-extension` directory.
|
1. Run `python scripts/sync_extension.py` to copy frontend files to the `chrome-extension` directory.
|
||||||
2. Load the `chrome-extension` directory in Chrome as an unpacked extension.
|
2. Load the `chrome-extension` directory in Chrome as an unpacked extension.
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "whisperlivekit"
|
name = "whisperlivekit"
|
||||||
version = "0.2.16.dev0"
|
version = "0.2.18"
|
||||||
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
description = "Real-time speech-to-text with speaker diarization using Whisper"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
@@ -35,6 +35,7 @@ dependencies = [
|
|||||||
"torchaudio>=2.0.0",
|
"torchaudio>=2.0.0",
|
||||||
"torch>=2.0.0",
|
"torch>=2.0.0",
|
||||||
"huggingface-hub>=0.25.0",
|
"huggingface-hub>=0.25.0",
|
||||||
|
"faster-whisper>=1.2.0",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
"tiktoken",
|
"tiktoken",
|
||||||
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
|
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
|
||||||
@@ -56,6 +57,7 @@ packages = [
|
|||||||
"whisperlivekit",
|
"whisperlivekit",
|
||||||
"whisperlivekit.diarization",
|
"whisperlivekit.diarization",
|
||||||
"whisperlivekit.simul_whisper",
|
"whisperlivekit.simul_whisper",
|
||||||
|
"whisperlivekit.simul_whisper.mlx",
|
||||||
"whisperlivekit.whisper",
|
"whisperlivekit.whisper",
|
||||||
"whisperlivekit.whisper.assets",
|
"whisperlivekit.whisper.assets",
|
||||||
"whisperlivekit.whisper.normalizers",
|
"whisperlivekit.whisper.normalizers",
|
||||||
@@ -67,4 +69,5 @@ packages = [
|
|||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
||||||
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
|
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||||
|
"whisperlivekit.whisper.normalizers" = ["*.json"]
|
||||||
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]
|
"whisperlivekit.silero_vad_models" = ["*.jit", "*.onnx"]
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from whisperlivekit.core import (TranscriptionEngine,
|
|||||||
online_diarization_factory, online_factory,
|
online_diarization_factory, online_factory,
|
||||||
online_translation_factory)
|
online_translation_factory)
|
||||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
from whisperlivekit.silero_vad_iterator import FixedVADIterator, OnnxWrapper, load_jit_vad
|
||||||
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
|
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
|
||||||
Segment, Silence, State, Transcript)
|
Segment, Silence, State, Transcript)
|
||||||
from whisperlivekit.tokens_alignment import TokensAlignment
|
from whisperlivekit.tokens_alignment import TokensAlignment
|
||||||
@@ -85,12 +85,14 @@ class AudioProcessor:
|
|||||||
|
|
||||||
# Models and processing
|
# Models and processing
|
||||||
self.asr: Any = models.asr
|
self.asr: Any = models.asr
|
||||||
self.vac_model: Any = models.vac_model
|
self.vac: Optional[FixedVADIterator] = None
|
||||||
if self.args.vac:
|
|
||||||
self.vac: Optional[FixedVADIterator] = FixedVADIterator(models.vac_model)
|
|
||||||
else:
|
|
||||||
self.vac: Optional[FixedVADIterator] = None
|
|
||||||
|
|
||||||
|
if self.args.vac:
|
||||||
|
if models.vac_session is not None:
|
||||||
|
vac_model = OnnxWrapper(session=models.vac_session)
|
||||||
|
self.vac = FixedVADIterator(vac_model)
|
||||||
|
else:
|
||||||
|
self.vac = FixedVADIterator(load_jit_vad())
|
||||||
self.ffmpeg_manager: Optional[FFmpegManager] = None
|
self.ffmpeg_manager: Optional[FFmpegManager] = None
|
||||||
self.ffmpeg_reader_task: Optional[asyncio.Task] = None
|
self.ffmpeg_reader_task: Optional[asyncio.Task] = None
|
||||||
self._ffmpeg_error: Optional[str] = None
|
self._ffmpeg_error: Optional[str] = None
|
||||||
@@ -351,11 +353,14 @@ class AudioProcessor:
|
|||||||
if item.has_ended:
|
if item.has_ended:
|
||||||
self.diarization.insert_silence(item.duration)
|
self.diarization.insert_silence(item.duration)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.diarization.insert_audio_chunk(item)
|
self.diarization.insert_audio_chunk(item)
|
||||||
diarization_segments = await self.diarization.diarize()
|
diarization_segments = await self.diarization.diarize()
|
||||||
self.state.new_diarization = diarization_segments
|
diar_end = 0.0
|
||||||
|
if diarization_segments:
|
||||||
|
diar_end = max(getattr(s, "end", 0.0) for s in diarization_segments)
|
||||||
|
async with self.lock:
|
||||||
|
self.state.new_diarization = diarization_segments
|
||||||
|
self.state.end_attributed_speaker = max(self.state.end_attributed_speaker, diar_end)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Exception in diarization_processor: {e}")
|
logger.warning(f"Exception in diarization_processor: {e}")
|
||||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||||
|
|||||||
@@ -29,6 +29,13 @@ def mlx_backend_available(warn_on_missing = False):
|
|||||||
return available
|
return available
|
||||||
|
|
||||||
|
|
||||||
|
def voxmlx_backend_available():
|
||||||
|
"""Return True if voxmlx (Voxtral MLX backend) is available."""
|
||||||
|
is_macos = platform.system() == "Darwin"
|
||||||
|
is_arm = platform.machine() == "arm64"
|
||||||
|
return is_macos and is_arm and module_available("voxmlx")
|
||||||
|
|
||||||
|
|
||||||
def faster_backend_available(warn_on_missing = False):
|
def faster_backend_available(warn_on_missing = False):
|
||||||
available = module_available("faster_whisper")
|
available = module_available("faster_whisper")
|
||||||
if not available and warn_on_missing and platform.system() != "Darwin":
|
if not available and warn_on_missing and platform.system() != "Darwin":
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
||||||
@@ -19,16 +20,26 @@ logger = logging.getLogger(__name__)
|
|||||||
class TranscriptionEngine:
|
class TranscriptionEngine:
|
||||||
_instance = None
|
_instance = None
|
||||||
_initialized = False
|
_initialized = False
|
||||||
|
_lock = threading.Lock() # Thread-safe singleton lock
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
|
# Double-checked locking pattern for thread-safe singleton
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
with cls._lock:
|
||||||
|
# Check again inside lock to prevent race condition
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
if TranscriptionEngine._initialized:
|
# Thread-safe initialization check
|
||||||
return
|
with TranscriptionEngine._lock:
|
||||||
|
if TranscriptionEngine._initialized:
|
||||||
|
return
|
||||||
|
# Set flag immediately to prevent re-initialization
|
||||||
|
TranscriptionEngine._initialized = True
|
||||||
|
|
||||||
|
# Perform initialization outside lock to avoid holding lock during slow operations
|
||||||
global_params = {
|
global_params = {
|
||||||
"host": "localhost",
|
"host": "localhost",
|
||||||
"port": 8000,
|
"port": 8000,
|
||||||
@@ -36,7 +47,6 @@ class TranscriptionEngine:
|
|||||||
"punctuation_split": False,
|
"punctuation_split": False,
|
||||||
"target_language": "",
|
"target_language": "",
|
||||||
"vac": True,
|
"vac": True,
|
||||||
"vac_onnx": False,
|
|
||||||
"vac_chunk_size": 0.04,
|
"vac_chunk_size": 0.04,
|
||||||
"log_level": "DEBUG",
|
"log_level": "DEBUG",
|
||||||
"ssl_certfile": None,
|
"ssl_certfile": None,
|
||||||
@@ -79,18 +89,27 @@ class TranscriptionEngine:
|
|||||||
self.asr = None
|
self.asr = None
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.diarization = None
|
self.diarization = None
|
||||||
self.vac_model = None
|
self.vac_session = None
|
||||||
|
|
||||||
if self.args.vac:
|
if self.args.vac:
|
||||||
from whisperlivekit.silero_vad_iterator import load_silero_vad
|
from whisperlivekit.silero_vad_iterator import is_onnx_available
|
||||||
|
|
||||||
# Use ONNX if specified, otherwise use JIT (default)
|
|
||||||
use_onnx = kwargs.get('vac_onnx', False)
|
|
||||||
self.vac_model = load_silero_vad(onnx=use_onnx)
|
|
||||||
|
|
||||||
|
if is_onnx_available():
|
||||||
|
from whisperlivekit.silero_vad_iterator import load_onnx_session
|
||||||
|
self.vac_session = load_onnx_session()
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"onnxruntime not installed. VAC will use JIT model which is loaded per-session. "
|
||||||
|
"For multi-user scenarios, install onnxruntime: pip install onnxruntime"
|
||||||
|
)
|
||||||
backend_policy = self.args.backend_policy
|
backend_policy = self.args.backend_policy
|
||||||
if self.args.transcription:
|
if self.args.transcription:
|
||||||
if backend_policy == "simulstreaming":
|
if self.args.backend == "voxtral-mlx":
|
||||||
|
from whisperlivekit.voxtral_streaming import VoxtralStreamingASR
|
||||||
|
self.tokenizer = None
|
||||||
|
self.asr = VoxtralStreamingASR(**transcription_common_params)
|
||||||
|
logger.info("Using Voxtral MLX streaming backend")
|
||||||
|
elif backend_policy == "simulstreaming":
|
||||||
simulstreaming_params = {
|
simulstreaming_params = {
|
||||||
"disable_fast_encoder": False,
|
"disable_fast_encoder": False,
|
||||||
"custom_alignment_heads": None,
|
"custom_alignment_heads": None,
|
||||||
@@ -169,16 +188,16 @@ class TranscriptionEngine:
|
|||||||
}
|
}
|
||||||
translation_params = update_with_kwargs(translation_params, kwargs)
|
translation_params = update_with_kwargs(translation_params, kwargs)
|
||||||
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
|
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
|
||||||
TranscriptionEngine._initialized = True
|
|
||||||
|
|
||||||
|
|
||||||
def online_factory(args, asr):
|
def online_factory(args, asr):
|
||||||
|
if getattr(args, 'backend', None) == "voxtral-mlx":
|
||||||
|
from whisperlivekit.voxtral_streaming import VoxtralStreamingOnlineProcessor
|
||||||
|
return VoxtralStreamingOnlineProcessor(asr)
|
||||||
if args.backend_policy == "simulstreaming":
|
if args.backend_policy == "simulstreaming":
|
||||||
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||||
online = SimulStreamingOnlineProcessor(asr)
|
return SimulStreamingOnlineProcessor(asr)
|
||||||
else:
|
return OnlineASRProcessor(asr)
|
||||||
online = OnlineASRProcessor(asr)
|
|
||||||
return online
|
|
||||||
|
|
||||||
|
|
||||||
def online_diarization_factory(args, diarization_backend):
|
def online_diarization_factory(args, diarization_backend):
|
||||||
|
|||||||
@@ -202,14 +202,14 @@ class DiartDiarization:
|
|||||||
def insert_silence(self, silence_duration):
|
def insert_silence(self, silence_duration):
|
||||||
self.observer.global_time_offset += silence_duration
|
self.observer.global_time_offset += silence_duration
|
||||||
|
|
||||||
async def diarize(self, pcm_array: np.ndarray):
|
def insert_audio_chunk(self, pcm_array: np.ndarray):
|
||||||
"""
|
"""Buffer audio for the next diarization step."""
|
||||||
Process audio data for diarization.
|
|
||||||
Only used when working with WebSocketAudioSource.
|
|
||||||
"""
|
|
||||||
if self.custom_source:
|
if self.custom_source:
|
||||||
self.custom_source.push_audio(pcm_array)
|
self.custom_source.push_audio(pcm_array)
|
||||||
# self.observer.clear_old_segments()
|
|
||||||
|
async def diarize(self):
|
||||||
|
"""Return the current speaker segments from the diarization pipeline."""
|
||||||
|
return self.observer.get_segments()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Close the audio source."""
|
"""Close the audio source."""
|
||||||
|
|||||||
@@ -249,6 +249,7 @@ class OpenaiApiASR(ASRBase):
|
|||||||
self.load_model()
|
self.load_model()
|
||||||
self.use_vad_opt = False
|
self.use_vad_opt = False
|
||||||
self.direct_english_translation = False
|
self.direct_english_translation = False
|
||||||
|
self.task = "transcribe"
|
||||||
|
|
||||||
def load_model(self, *args, **kwargs):
|
def load_model(self, *args, **kwargs):
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
@@ -294,7 +295,8 @@ class OpenaiApiASR(ASRBase):
|
|||||||
params["language"] = self.original_language
|
params["language"] = self.original_language
|
||||||
if prompt:
|
if prompt:
|
||||||
params["prompt"] = prompt
|
params["prompt"] = prompt
|
||||||
proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
|
task = self.transcribe_kargs.get("task", self.task)
|
||||||
|
proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions
|
||||||
transcript = proc.create(**params)
|
transcript = proc.create(**params)
|
||||||
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
|
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
|
||||||
return transcript
|
return transcript
|
||||||
|
|||||||
@@ -146,6 +146,7 @@ def backend_factory(
|
|||||||
|
|
||||||
if direct_english_translation:
|
if direct_english_translation:
|
||||||
tgt_language = "en" # Whisper translates into English
|
tgt_language = "en" # Whisper translates into English
|
||||||
|
asr.transcribe_kargs["task"] = "translate"
|
||||||
else:
|
else:
|
||||||
tgt_language = lan # Whisper transcribes in this language
|
tgt_language = lan # Whisper transcribes in this language
|
||||||
|
|
||||||
|
|||||||
@@ -147,8 +147,8 @@ def parse_args():
|
|||||||
"--backend",
|
"--backend",
|
||||||
type=str,
|
type=str,
|
||||||
default="auto",
|
default="auto",
|
||||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"],
|
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api", "voxtral-mlx"],
|
||||||
help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API.",
|
help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API. Use 'voxtral-mlx' for Voxtral streaming on Apple Silicon.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no-vac",
|
"--no-vac",
|
||||||
|
|||||||
@@ -8,6 +8,15 @@ import torch
|
|||||||
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def is_onnx_available() -> bool:
|
||||||
|
"""Check if onnxruntime is installed."""
|
||||||
|
try:
|
||||||
|
import onnxruntime
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def init_jit_model(model_path: str, device=torch.device('cpu')):
|
def init_jit_model(model_path: str, device=torch.device('cpu')):
|
||||||
"""Load a JIT model from file."""
|
"""Load a JIT model from file."""
|
||||||
model = torch.jit.load(model_path, map_location=device)
|
model = torch.jit.load(model_path, map_location=device)
|
||||||
@@ -15,12 +24,12 @@ def init_jit_model(model_path: str, device=torch.device('cpu')):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
class OnnxWrapper():
|
class OnnxSession():
|
||||||
"""ONNX Runtime wrapper for Silero VAD model."""
|
"""
|
||||||
|
Shared ONNX session for Silero VAD model (stateless).
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, path, force_onnx_cpu=False):
|
def __init__(self, path, force_onnx_cpu=False):
|
||||||
global np
|
|
||||||
import numpy as np
|
|
||||||
import onnxruntime
|
import onnxruntime
|
||||||
|
|
||||||
opts = onnxruntime.SessionOptions()
|
opts = onnxruntime.SessionOptions()
|
||||||
@@ -32,13 +41,28 @@ class OnnxWrapper():
|
|||||||
else:
|
else:
|
||||||
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
||||||
|
|
||||||
self.reset_states()
|
self.path = path
|
||||||
if '16k' in path:
|
if '16k' in path:
|
||||||
warnings.warn('This model support only 16000 sampling rate!')
|
warnings.warn('This model support only 16000 sampling rate!')
|
||||||
self.sample_rates = [16000]
|
self.sample_rates = [16000]
|
||||||
else:
|
else:
|
||||||
self.sample_rates = [8000, 16000]
|
self.sample_rates = [8000, 16000]
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxWrapper():
|
||||||
|
"""
|
||||||
|
ONNX Runtime wrapper for Silero VAD model with per-instance state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, session: OnnxSession, force_onnx_cpu=False):
|
||||||
|
self._shared_session = session
|
||||||
|
self.sample_rates = session.sample_rates
|
||||||
|
self.reset_states()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session(self):
|
||||||
|
return self._shared_session.session
|
||||||
|
|
||||||
def _validate_input(self, x, sr: int):
|
def _validate_input(self, x, sr: int):
|
||||||
if x.dim() == 1:
|
if x.dim() == 1:
|
||||||
x = x.unsqueeze(0)
|
x = x.unsqueeze(0)
|
||||||
@@ -101,38 +125,20 @@ class OnnxWrapper():
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16):
|
def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Path:
|
||||||
"""
|
"""Get the path to the ONNX model file."""
|
||||||
Load Silero VAD model (JIT or ONNX).
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
model_path : str, optional
|
|
||||||
Path to model file. If None, uses default bundled model.
|
|
||||||
onnx : bool, default False
|
|
||||||
Whether to use ONNX runtime (requires onnxruntime package).
|
|
||||||
opset_version : int, default 16
|
|
||||||
ONNX opset version (15 or 16). Only used if onnx=True.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
model
|
|
||||||
Loaded VAD model (JIT or ONNX wrapper)
|
|
||||||
"""
|
|
||||||
available_ops = [15, 16]
|
available_ops = [15, 16]
|
||||||
if onnx and opset_version not in available_ops:
|
if opset_version not in available_ops:
|
||||||
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
||||||
|
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
current_dir = Path(__file__).parent
|
current_dir = Path(__file__).parent
|
||||||
data_dir = current_dir / 'silero_vad_models'
|
data_dir = current_dir / 'silero_vad_models'
|
||||||
|
|
||||||
if onnx:
|
if opset_version == 16:
|
||||||
if opset_version == 16:
|
model_name = 'silero_vad.onnx'
|
||||||
model_name = 'silero_vad.onnx'
|
|
||||||
else:
|
|
||||||
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
|
||||||
else:
|
else:
|
||||||
model_name = 'silero_vad.jit'
|
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
||||||
|
|
||||||
model_path = data_dir / model_name
|
model_path = data_dir / model_name
|
||||||
|
|
||||||
@@ -143,16 +149,38 @@ def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: i
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
if onnx:
|
|
||||||
try:
|
return model_path
|
||||||
model = OnnxWrapper(str(model_path), force_onnx_cpu=True)
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
def load_onnx_session(model_path: str = None, opset_version: int = 16, force_onnx_cpu: bool = True) -> OnnxSession:
|
||||||
"ONNX runtime not available. Install with: pip install onnxruntime\n"
|
"""
|
||||||
"Or use JIT model by setting onnx=False"
|
Load a shared ONNX session for Silero VAD.
|
||||||
|
"""
|
||||||
|
path = _get_onnx_model_path(model_path, opset_version)
|
||||||
|
return OnnxSession(str(path), force_onnx_cpu=force_onnx_cpu)
|
||||||
|
|
||||||
|
|
||||||
|
def load_jit_vad(model_path: str = None):
|
||||||
|
"""
|
||||||
|
Load Silero VAD model in JIT format.
|
||||||
|
"""
|
||||||
|
if model_path is None:
|
||||||
|
current_dir = Path(__file__).parent
|
||||||
|
data_dir = current_dir / 'silero_vad_models'
|
||||||
|
model_name = 'silero_vad.jit'
|
||||||
|
|
||||||
|
model_path = data_dir / model_name
|
||||||
|
|
||||||
|
if not model_path.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Model file not found: {model_path}\n"
|
||||||
|
f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = init_jit_model(str(model_path))
|
model_path = Path(model_path)
|
||||||
|
|
||||||
|
model = init_jit_model(str(model_path))
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@@ -285,8 +313,8 @@ class FixedVADIterator(VADIterator):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
model = load_silero_vad(onnx=False)
|
# vad = FixedVADIterator(load_jit_vad())
|
||||||
vad = FixedVADIterator(model)
|
vad = FixedVADIterator(OnnxWrapper(session=load_onnx_session()))
|
||||||
|
|
||||||
audio_buffer = np.array([0] * 512, dtype=np.float32)
|
audio_buffer = np.array([0] * 512, dtype=np.float32)
|
||||||
result = vad(audio_buffer)
|
result = vad(audio_buffer)
|
||||||
@@ -295,3 +323,4 @@ if __name__ == "__main__":
|
|||||||
# test with 511 samples
|
# test with 511 samples
|
||||||
audio_buffer = np.array([0] * 511, dtype=np.float32)
|
audio_buffer = np.array([0] * 511, dtype=np.float32)
|
||||||
result = vad(audio_buffer)
|
result = vad(audio_buffer)
|
||||||
|
print(f" 511 samples: {result}")
|
||||||
@@ -24,9 +24,11 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
||||||
if HAS_MLX_WHISPER:
|
if HAS_MLX_WHISPER:
|
||||||
from .mlx_encoder import load_mlx_encoder, mlx_model_mapping
|
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
|
||||||
|
from .mlx import MLXAlignAtt
|
||||||
else:
|
else:
|
||||||
mlx_model_mapping = {}
|
mlx_model_mapping = {}
|
||||||
|
MLXAlignAtt = None
|
||||||
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
||||||
if HAS_FASTER_WHISPER:
|
if HAS_FASTER_WHISPER:
|
||||||
from faster_whisper import WhisperModel
|
from faster_whisper import WhisperModel
|
||||||
@@ -36,50 +38,47 @@ else:
|
|||||||
MIN_DURATION_REAL_SILENCE = 5
|
MIN_DURATION_REAL_SILENCE = 5
|
||||||
|
|
||||||
class SimulStreamingOnlineProcessor:
|
class SimulStreamingOnlineProcessor:
|
||||||
|
"""Online processor for SimulStreaming ASR."""
|
||||||
SAMPLING_RATE = 16000
|
SAMPLING_RATE = 16000
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, asr, logfile=sys.stderr):
|
||||||
self,
|
|
||||||
asr,
|
|
||||||
logfile=sys.stderr,
|
|
||||||
):
|
|
||||||
self.asr = asr
|
self.asr = asr
|
||||||
self.logfile = logfile
|
self.logfile = logfile
|
||||||
self.end = 0.0
|
self.end = 0.0
|
||||||
self.buffer = []
|
self.buffer = []
|
||||||
self.committed: List[ASRToken] = []
|
self.model = self._create_alignatt()
|
||||||
self.last_result_tokens: List[ASRToken] = []
|
|
||||||
self.load_new_alignatt_instance()
|
|
||||||
|
|
||||||
if asr.tokenizer:
|
if asr.tokenizer:
|
||||||
self.model.tokenizer = asr.tokenizer
|
self.model.tokenizer = asr.tokenizer
|
||||||
|
self.model.state.tokenizer = asr.tokenizer
|
||||||
|
|
||||||
def load_new_alignatt_instance(self):
|
def _create_alignatt(self):
|
||||||
"""Initialize AlignAtt decoder using the shared model."""
|
"""Create the AlignAtt decoder instance based on ASR mode."""
|
||||||
self.model = AlignAtt(
|
if self.asr.use_full_mlx and HAS_MLX_WHISPER:
|
||||||
cfg=self.asr.cfg,
|
return MLXAlignAtt(cfg=self.asr.cfg, mlx_model=self.asr.mlx_model)
|
||||||
loaded_model=self.asr.shared_model,
|
else:
|
||||||
mlx_encoder=self.asr.mlx_encoder,
|
return AlignAtt(
|
||||||
fw_encoder=self.asr.fw_encoder,
|
cfg=self.asr.cfg,
|
||||||
)
|
loaded_model=self.asr.shared_model,
|
||||||
|
mlx_encoder=self.asr.mlx_encoder,
|
||||||
|
fw_encoder=self.asr.fw_encoder,
|
||||||
|
)
|
||||||
|
|
||||||
def start_silence(self):
|
def start_silence(self):
|
||||||
tokens, processed_upto = self.process_iter(is_last=True)
|
tokens, processed_upto = self.process_iter(is_last=True)
|
||||||
return tokens, processed_upto
|
return tokens, processed_upto
|
||||||
|
|
||||||
def end_silence(self, silence_duration, offset):
|
def end_silence(self, silence_duration, offset):
|
||||||
"""
|
"""Handle silence period."""
|
||||||
Handle silence period.
|
|
||||||
|
|
||||||
If silence > MIN_DURATION_REAL_SILENCE, do a complete context clear.
|
|
||||||
Otherwise, insert a small silence and shift the last_attend_frame.
|
|
||||||
"""
|
|
||||||
self.end += silence_duration
|
self.end += silence_duration
|
||||||
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
|
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
|
||||||
if not long_silence:
|
if not long_silence:
|
||||||
gap_len = int(16000 * silence_duration)
|
gap_len = int(16000 * silence_duration)
|
||||||
if gap_len > 0:
|
if gap_len > 0:
|
||||||
gap_silence = torch.zeros(gap_len)
|
if self.asr.use_full_mlx:
|
||||||
|
gap_silence = np.zeros(gap_len, dtype=np.float32)
|
||||||
|
else:
|
||||||
|
gap_silence = torch.zeros(gap_len)
|
||||||
self.model.insert_audio(gap_silence)
|
self.model.insert_audio(gap_silence)
|
||||||
if long_silence:
|
if long_silence:
|
||||||
self.model.refresh_segment(complete=True)
|
self.model.refresh_segment(complete=True)
|
||||||
@@ -87,11 +86,12 @@ class SimulStreamingOnlineProcessor:
|
|||||||
|
|
||||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
|
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
|
||||||
"""Append an audio chunk to be processed by SimulStreaming."""
|
"""Append an audio chunk to be processed by SimulStreaming."""
|
||||||
|
self.end = audio_stream_end_time
|
||||||
# Convert numpy array to torch tensor
|
if self.asr.use_full_mlx:
|
||||||
audio_tensor = torch.from_numpy(audio).float()
|
self.model.insert_audio(audio)
|
||||||
self.end = audio_stream_end_time # Aligned with whisperstreaming backend behavior
|
else:
|
||||||
self.model.insert_audio(audio_tensor)
|
audio_tensor = torch.from_numpy(audio).float()
|
||||||
|
self.model.insert_audio(audio_tensor)
|
||||||
|
|
||||||
def new_speaker(self, change_speaker: ChangeSpeaker):
|
def new_speaker(self, change_speaker: ChangeSpeaker):
|
||||||
"""Handle speaker change event."""
|
"""Handle speaker change event."""
|
||||||
@@ -120,7 +120,6 @@ class SimulStreamingOnlineProcessor:
|
|||||||
self.buffer.extend(timestamped_words)
|
self.buffer.extend(timestamped_words)
|
||||||
return [], self.end
|
return [], self.end
|
||||||
|
|
||||||
self.committed.extend(timestamped_words)
|
|
||||||
self.buffer = []
|
self.buffer = []
|
||||||
return timestamped_words, self.end
|
return timestamped_words, self.end
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -130,6 +129,10 @@ class SimulStreamingOnlineProcessor:
|
|||||||
def warmup(self, audio, init_prompt=""):
|
def warmup(self, audio, init_prompt=""):
|
||||||
"""Warmup the SimulStreaming model."""
|
"""Warmup the SimulStreaming model."""
|
||||||
try:
|
try:
|
||||||
|
if self.asr.use_full_mlx:
|
||||||
|
# MLX mode: ensure numpy array
|
||||||
|
if hasattr(audio, 'numpy'):
|
||||||
|
audio = audio.numpy()
|
||||||
self.model.insert_audio(audio)
|
self.model.insert_audio(audio)
|
||||||
self.model.infer(True)
|
self.model.infer(True)
|
||||||
self.model.refresh_segment(complete=True)
|
self.model.refresh_segment(complete=True)
|
||||||
@@ -139,9 +142,14 @@ class SimulStreamingOnlineProcessor:
|
|||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
if not getattr(self.asr, 'use_full_mlx', True) and torch is not None:
|
||||||
|
try:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
class SimulStreamingASR():
|
|
||||||
|
class SimulStreamingASR:
|
||||||
"""SimulStreaming backend with AlignAtt policy."""
|
"""SimulStreaming backend with AlignAtt policy."""
|
||||||
sep = ""
|
sep = ""
|
||||||
|
|
||||||
@@ -158,6 +166,7 @@ class SimulStreamingASR():
|
|||||||
self.fast_encoder = False
|
self.fast_encoder = False
|
||||||
self._resolved_model_path = None
|
self._resolved_model_path = None
|
||||||
self.encoder_backend = "whisper"
|
self.encoder_backend = "whisper"
|
||||||
|
self.use_full_mlx = getattr(self, "use_full_mlx", False)
|
||||||
preferred_backend = getattr(self, "backend", "auto")
|
preferred_backend = getattr(self, "backend", "auto")
|
||||||
compatible_whisper_mlx, compatible_faster_whisper = True, True
|
compatible_whisper_mlx, compatible_faster_whisper = True, True
|
||||||
|
|
||||||
@@ -170,7 +179,7 @@ class SimulStreamingASR():
|
|||||||
compatible_whisper_mlx = model_info.compatible_whisper_mlx
|
compatible_whisper_mlx = model_info.compatible_whisper_mlx
|
||||||
compatible_faster_whisper = model_info.compatible_faster_whisper
|
compatible_faster_whisper = model_info.compatible_faster_whisper
|
||||||
|
|
||||||
if not model_info.has_pytorch:
|
if not self.use_full_mlx and not model_info.has_pytorch:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
|
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
|
||||||
)
|
)
|
||||||
@@ -191,6 +200,10 @@ class SimulStreamingASR():
|
|||||||
if self.encoder_backend == "whisper":
|
if self.encoder_backend == "whisper":
|
||||||
self.disable_fast_encoder = True
|
self.disable_fast_encoder = True
|
||||||
|
|
||||||
|
if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
|
||||||
|
if not hasattr(self, '_full_mlx_disabled'):
|
||||||
|
self.use_full_mlx = True
|
||||||
|
|
||||||
self.cfg = AlignAttConfig(
|
self.cfg = AlignAttConfig(
|
||||||
tokenizer_is_multilingual= is_multilingual,
|
tokenizer_is_multilingual= is_multilingual,
|
||||||
segment_length=self.min_chunk_size,
|
segment_length=self.min_chunk_size,
|
||||||
@@ -201,7 +214,7 @@ class SimulStreamingASR():
|
|||||||
cif_ckpt_path=self.cif_ckpt_path,
|
cif_ckpt_path=self.cif_ckpt_path,
|
||||||
decoder_type="beam",
|
decoder_type="beam",
|
||||||
beam_size=self.beams,
|
beam_size=self.beams,
|
||||||
task=self.direct_english_translation,
|
task="translate" if self.direct_english_translation else "transcribe",
|
||||||
never_fire=self.never_fire,
|
never_fire=self.never_fire,
|
||||||
init_prompt=self.init_prompt,
|
init_prompt=self.init_prompt,
|
||||||
max_context_tokens=self.max_context_tokens,
|
max_context_tokens=self.max_context_tokens,
|
||||||
@@ -214,20 +227,36 @@ class SimulStreamingASR():
|
|||||||
else:
|
else:
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
|
|
||||||
self.mlx_encoder, self.fw_encoder = None, None
|
self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None
|
||||||
if self.encoder_backend == "mlx-whisper":
|
self.shared_model = None
|
||||||
print('Simulstreaming will use MLX whisper to increase encoding speed.')
|
|
||||||
|
if self.use_full_mlx and HAS_MLX_WHISPER:
|
||||||
|
logger.info('MLX Whisper backend used.')
|
||||||
if self._resolved_model_path is not None:
|
if self._resolved_model_path is not None:
|
||||||
mlx_model = str(self._resolved_model_path)
|
mlx_model_path = str(self._resolved_model_path)
|
||||||
else:
|
else:
|
||||||
mlx_model = mlx_model_mapping.get(self.model_name)
|
mlx_model_path = mlx_model_mapping.get(self.model_name)
|
||||||
if not mlx_model:
|
if not mlx_model_path:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
||||||
)
|
)
|
||||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
|
self.mlx_model = load_mlx_model(path_or_hf_repo=mlx_model_path)
|
||||||
|
self._warmup_mlx_model()
|
||||||
|
elif self.encoder_backend == "mlx-whisper":
|
||||||
|
# hybrid mode: mlx encoder + pytorch decoder
|
||||||
|
logger.info('SimulStreaming will use MLX Whisper encoder with PyTorch decoder.')
|
||||||
|
if self._resolved_model_path is not None:
|
||||||
|
mlx_model_path = str(self._resolved_model_path)
|
||||||
|
else:
|
||||||
|
mlx_model_path = mlx_model_mapping.get(self.model_name)
|
||||||
|
if not mlx_model_path:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
||||||
|
)
|
||||||
|
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
|
||||||
|
self.shared_model = self.load_model()
|
||||||
elif self.encoder_backend == "faster-whisper":
|
elif self.encoder_backend == "faster-whisper":
|
||||||
print('Simulstreaming will use Faster Whisper for the encoder.')
|
print('SimulStreaming will use Faster Whisper for the encoder.')
|
||||||
if self._resolved_model_path is not None:
|
if self._resolved_model_path is not None:
|
||||||
fw_model = str(self._resolved_model_path)
|
fw_model = str(self._resolved_model_path)
|
||||||
else:
|
else:
|
||||||
@@ -237,7 +266,20 @@ class SimulStreamingASR():
|
|||||||
device='auto',
|
device='auto',
|
||||||
compute_type='auto',
|
compute_type='auto',
|
||||||
)
|
)
|
||||||
self.shared_model = self.load_model()
|
self.shared_model = self.load_model()
|
||||||
|
else:
|
||||||
|
self.shared_model = self.load_model()
|
||||||
|
|
||||||
|
def _warmup_mlx_model(self):
|
||||||
|
"""Warmup the full MLX model."""
|
||||||
|
warmup_audio = load_file(self.warmup_file)
|
||||||
|
if warmup_audio is not None:
|
||||||
|
temp_model = MLXAlignAtt(
|
||||||
|
cfg=self.cfg,
|
||||||
|
mlx_model=self.mlx_model,
|
||||||
|
)
|
||||||
|
temp_model.warmup(warmup_audio)
|
||||||
|
logger.info("Full MLX model warmed up successfully")
|
||||||
|
|
||||||
|
|
||||||
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
||||||
@@ -285,7 +327,7 @@ class SimulStreamingASR():
|
|||||||
lora_path = getattr(self, 'lora_path', None)
|
lora_path = getattr(self, 'lora_path', None)
|
||||||
whisper_model = load_model(
|
whisper_model = load_model(
|
||||||
name=model_ref,
|
name=model_ref,
|
||||||
download_root=None,
|
download_root=getattr(self, 'model_cache_dir', None),
|
||||||
decoder_only=self.fast_encoder,
|
decoder_only=self.fast_encoder,
|
||||||
custom_alignment_heads=self.custom_alignment_heads,
|
custom_alignment_heads=self.custom_alignment_heads,
|
||||||
lora_path=lora_path,
|
lora_path=lora_path,
|
||||||
|
|||||||
@@ -47,9 +47,24 @@ class DecoderState:
|
|||||||
|
|
||||||
def clean_cache(self):
|
def clean_cache(self):
|
||||||
"""Clean the kv_cache after each inference step."""
|
"""Clean the kv_cache after each inference step."""
|
||||||
self.kv_cache = {}
|
# Explicitly delete tensor references to free GPU memory
|
||||||
|
if self.kv_cache:
|
||||||
|
for key in list(self.kv_cache.keys()):
|
||||||
|
tensor = self.kv_cache.pop(key, None)
|
||||||
|
if tensor is not None:
|
||||||
|
del tensor
|
||||||
|
|
||||||
|
# Clear the dict
|
||||||
|
self.kv_cache.clear()
|
||||||
|
|
||||||
|
# Force GPU cache cleanup (only if CUDA is available)
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
if self.decoder_type == "beam" and self.inference is not None:
|
if self.decoder_type == "beam" and self.inference is not None:
|
||||||
self.inference.kv_cache = self.kv_cache
|
# Create NEW dict instead of sharing reference
|
||||||
|
self.inference.kv_cache = {}
|
||||||
if self.token_decoder is not None:
|
if self.token_decoder is not None:
|
||||||
self.token_decoder.reset()
|
self.token_decoder.reset()
|
||||||
|
|
||||||
|
|||||||
11
whisperlivekit/simul_whisper/mlx/__init__.py
Normal file
11
whisperlivekit/simul_whisper/mlx/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from .decoder_state import MLXDecoderState
|
||||||
|
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
|
||||||
|
from .simul_whisper import MLXAlignAtt
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MLXAlignAtt",
|
||||||
|
"MLXBeamSearchDecoder",
|
||||||
|
"MLXDecoderState",
|
||||||
|
"MLXGreedyDecoder",
|
||||||
|
"MLXInference",
|
||||||
|
]
|
||||||
76
whisperlivekit/simul_whisper/mlx/decoder_state.py
Normal file
76
whisperlivekit/simul_whisper/mlx/decoder_state.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MLXDecoderState:
|
||||||
|
"""
|
||||||
|
mlx kv cache format: List of ((k, v), (cross_k, cross_v)) tuples per layer,
|
||||||
|
where each element is a tuple of mx.arrays.
|
||||||
|
"""
|
||||||
|
|
||||||
|
kv_cache: Optional[List[Tuple[Tuple[mx.array, mx.array], Tuple[mx.array, mx.array]]]] = None
|
||||||
|
|
||||||
|
tokenizer: Any = None
|
||||||
|
detected_language: Optional[str] = None
|
||||||
|
reset_tokenizer_to_auto_next_call: bool = False
|
||||||
|
|
||||||
|
tokens: List[mx.array] = field(default_factory=list)
|
||||||
|
initial_tokens: Optional[mx.array] = None
|
||||||
|
initial_token_length: int = 0
|
||||||
|
sot_index: int = 0
|
||||||
|
align_source: Dict[int, List[Tuple[int, int]]] = field(default_factory=dict)
|
||||||
|
num_align_heads: int = 0
|
||||||
|
segments: List[np.ndarray] = field(default_factory=list)
|
||||||
|
|
||||||
|
context: Any = None
|
||||||
|
|
||||||
|
pending_incomplete_tokens: List[int] = field(default_factory=list)
|
||||||
|
|
||||||
|
global_time_offset: float = 0.0
|
||||||
|
cumulative_time_offset: float = 0.0
|
||||||
|
first_timestamp: Optional[float] = None
|
||||||
|
last_attend_frame: int = 0
|
||||||
|
|
||||||
|
speaker: int = -1
|
||||||
|
log_segments: int = 0
|
||||||
|
cif_weights: Optional[mx.array] = None
|
||||||
|
always_fire: bool = False
|
||||||
|
never_fire: bool = False
|
||||||
|
|
||||||
|
suppress_tokens: Optional[Tuple[int, ...]] = None
|
||||||
|
|
||||||
|
token_decoder: Any = None
|
||||||
|
decoder_type: str = "greedy"
|
||||||
|
|
||||||
|
inference: Any = None
|
||||||
|
|
||||||
|
def clean_cache(self):
|
||||||
|
self.kv_cache = None
|
||||||
|
if self.decoder_type == "beam" and self.inference is not None:
|
||||||
|
self.inference.kv_cache = None
|
||||||
|
if self.token_decoder is not None:
|
||||||
|
self.token_decoder.reset()
|
||||||
|
|
||||||
|
def reset(self, rewind_threshold: int = 200):
|
||||||
|
self.last_attend_frame = -rewind_threshold
|
||||||
|
self.cumulative_time_offset = 0.0
|
||||||
|
self.pending_incomplete_tokens = []
|
||||||
|
self.log_segments += 1
|
||||||
|
|
||||||
|
def full_reset(self, rewind_threshold: int = 200):
|
||||||
|
"""
|
||||||
|
Full reset including audio segments and tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rewind_threshold: Value for resetting last_attend_frame
|
||||||
|
"""
|
||||||
|
self.reset(rewind_threshold)
|
||||||
|
self.segments = []
|
||||||
|
self.tokens = []
|
||||||
|
self.kv_cache = None
|
||||||
|
self.first_timestamp = None
|
||||||
|
|
||||||
219
whisperlivekit/simul_whisper/mlx/decoders.py
Normal file
219
whisperlivekit/simul_whisper/mlx/decoders.py
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
"""
|
||||||
|
MLX-native token decoders for streaming ASR.
|
||||||
|
"""
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class MLXGreedyDecoder:
|
||||||
|
"""Greedy decoder using MLX operations."""
|
||||||
|
|
||||||
|
def __init__(self, temperature: float, eot: int):
|
||||||
|
self.temperature = temperature
|
||||||
|
self.eot = eot
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
|
||||||
|
) -> Tuple[mx.array, bool]:
|
||||||
|
"""
|
||||||
|
Update tokens with next predicted token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: Current token sequence, shape (batch, seq_len)
|
||||||
|
logits: Logits for next token, shape (batch, vocab_size)
|
||||||
|
sum_logprobs: Cumulative log probabilities, shape (batch,)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated tokens and completion flag
|
||||||
|
"""
|
||||||
|
if self.temperature == 0:
|
||||||
|
next_tokens = mx.argmax(logits, axis=-1)
|
||||||
|
else:
|
||||||
|
probs = mx.softmax(logits / self.temperature, axis=-1)
|
||||||
|
next_tokens = mx.random.categorical(mx.log(probs + 1e-10))
|
||||||
|
|
||||||
|
logprobs = mx.softmax(logits, axis=-1)
|
||||||
|
logprobs = mx.log(logprobs + 1e-10)
|
||||||
|
batch_size = logprobs.shape[0]
|
||||||
|
current_logprobs = logprobs[mx.arange(batch_size), next_tokens]
|
||||||
|
mask = (tokens[:, -1] != self.eot).astype(mx.float32)
|
||||||
|
sum_logprobs = sum_logprobs + current_logprobs * mask
|
||||||
|
eot_mask = (tokens[:, -1] == self.eot)
|
||||||
|
next_tokens = mx.where(eot_mask, mx.array(self.eot), next_tokens)
|
||||||
|
tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=1)
|
||||||
|
completed = bool(mx.all(tokens[:, -1] == self.eot))
|
||||||
|
|
||||||
|
return tokens, completed
|
||||||
|
|
||||||
|
def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
|
||||||
|
"""Finalize decoding by ensuring EOT at end."""
|
||||||
|
eot_column = mx.full((tokens.shape[0], 1), self.eot, dtype=tokens.dtype)
|
||||||
|
tokens = mx.concatenate([tokens, eot_column], axis=1)
|
||||||
|
return tokens, sum_logprobs.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
class MLXBeamSearchDecoder:
|
||||||
|
"""Beam search decoder using MLX operations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
beam_size: int,
|
||||||
|
eot: int,
|
||||||
|
inference: Any,
|
||||||
|
patience: Optional[float] = None,
|
||||||
|
):
|
||||||
|
self.beam_size = beam_size
|
||||||
|
self.eot = eot
|
||||||
|
self.inference = inference
|
||||||
|
self.patience = patience or 1.0
|
||||||
|
self.max_candidates: int = round(beam_size * self.patience)
|
||||||
|
self.finished_sequences: Optional[List[Dict]] = None
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.max_candidates > 0
|
||||||
|
), f"Invalid beam size ({beam_size}) or patience ({patience})"
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Reset finished sequences for new segment."""
|
||||||
|
self.finished_sequences = None
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
|
||||||
|
) -> Tuple[mx.array, bool]:
|
||||||
|
"""
|
||||||
|
Update tokens using beam search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: Current token sequences, shape (batch * beam_size, seq_len)
|
||||||
|
logits: Logits for next token, shape (batch * beam_size, vocab_size)
|
||||||
|
sum_logprobs: Cumulative log probabilities, shape (batch * beam_size,)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated tokens and completion flag
|
||||||
|
"""
|
||||||
|
if tokens.shape[0] % self.beam_size != 0:
|
||||||
|
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
||||||
|
|
||||||
|
n_audio = tokens.shape[0] // self.beam_size
|
||||||
|
if self.finished_sequences is None:
|
||||||
|
self.finished_sequences = [{} for _ in range(n_audio)]
|
||||||
|
logprobs = mx.softmax(logits, axis=-1)
|
||||||
|
logprobs = mx.log(logprobs + 1e-10)
|
||||||
|
logprobs_np = np.array(logprobs)
|
||||||
|
tokens_np = np.array(tokens)
|
||||||
|
sum_logprobs_np = np.array(sum_logprobs)
|
||||||
|
|
||||||
|
next_tokens, source_indices, finished_sequences = [], [], []
|
||||||
|
new_sum_logprobs = []
|
||||||
|
|
||||||
|
for i in range(n_audio):
|
||||||
|
scores, sources, finished = {}, {}, {}
|
||||||
|
for j in range(self.beam_size):
|
||||||
|
idx = i * self.beam_size + j
|
||||||
|
prefix = tokens_np[idx].tolist()
|
||||||
|
top_k_indices = np.argsort(logprobs_np[idx])[-self.beam_size - 1:][::-1]
|
||||||
|
|
||||||
|
for token_idx in top_k_indices:
|
||||||
|
logprob = logprobs_np[idx, token_idx]
|
||||||
|
new_logprob = sum_logprobs_np[idx] + logprob
|
||||||
|
sequence = tuple(prefix + [int(token_idx)])
|
||||||
|
scores[sequence] = new_logprob
|
||||||
|
sources[sequence] = idx
|
||||||
|
saved = 0
|
||||||
|
for sequence in sorted(scores, key=scores.get, reverse=True):
|
||||||
|
if sequence[-1] == self.eot:
|
||||||
|
finished[sequence] = scores[sequence]
|
||||||
|
else:
|
||||||
|
new_sum_logprobs.append(scores[sequence])
|
||||||
|
next_tokens.append(sequence)
|
||||||
|
source_indices.append(sources[sequence])
|
||||||
|
|
||||||
|
saved += 1
|
||||||
|
if saved == self.beam_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
finished_sequences.append(finished)
|
||||||
|
tokens = mx.array(np.array(next_tokens, dtype=np.int32))
|
||||||
|
sum_logprobs = mx.array(np.array(new_sum_logprobs, dtype=np.float32))
|
||||||
|
self.inference.rearrange_kv_cache(source_indices)
|
||||||
|
assert len(self.finished_sequences) == len(finished_sequences)
|
||||||
|
for previously_finished, newly_finished in zip(
|
||||||
|
self.finished_sequences, finished_sequences
|
||||||
|
):
|
||||||
|
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
||||||
|
if len(previously_finished) >= self.max_candidates:
|
||||||
|
break
|
||||||
|
previously_finished[seq] = newly_finished[seq]
|
||||||
|
completed = all(
|
||||||
|
len(sequences) >= self.max_candidates
|
||||||
|
for sequences in self.finished_sequences
|
||||||
|
)
|
||||||
|
|
||||||
|
return tokens, completed
|
||||||
|
|
||||||
|
def finalize(self, preceding_tokens: mx.array, sum_logprobs: mx.array):
|
||||||
|
"""Finalize beam search by selecting best sequences."""
|
||||||
|
preceding_tokens_np = np.array(preceding_tokens)
|
||||||
|
sum_logprobs_np = np.array(sum_logprobs)
|
||||||
|
|
||||||
|
n_audio = preceding_tokens_np.shape[0] // self.beam_size
|
||||||
|
tokens_list: List[List[int]] = [[] for _ in range(n_audio)]
|
||||||
|
sum_logprobs_list: List[float] = [0.0] * n_audio
|
||||||
|
|
||||||
|
for i, sequences in enumerate(self.finished_sequences):
|
||||||
|
if sequences:
|
||||||
|
best_seq = max(sequences, key=sequences.get)
|
||||||
|
tokens_list[i] = list(best_seq)
|
||||||
|
sum_logprobs_list[i] = sequences[best_seq]
|
||||||
|
else:
|
||||||
|
idx = i * self.beam_size
|
||||||
|
tokens_list[i] = preceding_tokens_np[idx].tolist() + [self.eot]
|
||||||
|
sum_logprobs_list[i] = float(sum_logprobs_np[idx])
|
||||||
|
max_len = max(len(t) for t in tokens_list)
|
||||||
|
for i, t in enumerate(tokens_list):
|
||||||
|
tokens_list[i] = t + [self.eot] * (max_len - len(t))
|
||||||
|
|
||||||
|
tokens = mx.array(np.array(tokens_list, dtype=np.int32))
|
||||||
|
return tokens, sum_logprobs_list
|
||||||
|
|
||||||
|
|
||||||
|
class MLXInference:
|
||||||
|
"""MLX inference wrapper for beam search KV cache management."""
|
||||||
|
|
||||||
|
def __init__(self, model, initial_token_length: int):
|
||||||
|
self.model = model
|
||||||
|
self.initial_token_length = initial_token_length
|
||||||
|
self.kv_cache = None
|
||||||
|
|
||||||
|
def rearrange_kv_cache(self, source_indices: List[int]):
|
||||||
|
"""Rearrange KV cache based on beam search source indices."""
|
||||||
|
if self.kv_cache is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if source_indices == list(range(len(source_indices))):
|
||||||
|
return
|
||||||
|
|
||||||
|
source_indices_mx = mx.array(source_indices, dtype=mx.int32)
|
||||||
|
|
||||||
|
new_cache = []
|
||||||
|
for layer_cache in self.kv_cache:
|
||||||
|
(k, v), (cross_k, cross_v) = layer_cache
|
||||||
|
new_k = k[source_indices_mx]
|
||||||
|
new_v = v[source_indices_mx]
|
||||||
|
new_cache.append(((new_k, new_v), (cross_k, cross_v)))
|
||||||
|
|
||||||
|
self.kv_cache = new_cache
|
||||||
|
|
||||||
|
def logits(
|
||||||
|
self,
|
||||||
|
tokens: mx.array,
|
||||||
|
audio_features: mx.array,
|
||||||
|
) -> Tuple[mx.array, List]:
|
||||||
|
"""Get logits from decoder with KV cache."""
|
||||||
|
logits, self.kv_cache, cross_qk = self.model.decoder(
|
||||||
|
tokens, audio_features, kv_cache=self.kv_cache
|
||||||
|
)
|
||||||
|
return logits, cross_qk
|
||||||
|
|
||||||
756
whisperlivekit/simul_whisper/mlx/simul_whisper.py
Normal file
756
whisperlivekit/simul_whisper/mlx/simul_whisper.py
Normal file
@@ -0,0 +1,756 @@
|
|||||||
|
"""
|
||||||
|
MLX whisper AlignAtt streaming decoder
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from time import time
|
||||||
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||||
|
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||||
|
|
||||||
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
|
from whisperlivekit.whisper import DecodingOptions, tokenizer
|
||||||
|
from whisperlivekit.whisper.audio import N_FRAMES, N_SAMPLES, TOKENS_PER_SECOND
|
||||||
|
|
||||||
|
from ..config import AlignAttConfig
|
||||||
|
from .decoder_state import MLXDecoderState
|
||||||
|
from .decoders import MLXBeamSearchDecoder, MLXGreedyDecoder, MLXInference
|
||||||
|
|
||||||
|
DEC_PAD = 50257
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MLXTokenBuffer: #should try to make it heritate from classic simul whisper class
|
||||||
|
"""Token buffer for MLX-based decoding."""
|
||||||
|
|
||||||
|
def __init__(self, text="", tokenizer=None, prefix_token_ids=None):
|
||||||
|
self.text = text
|
||||||
|
self.prefix_token_ids = prefix_token_ids or []
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.pending_token_ids = []
|
||||||
|
|
||||||
|
def as_token_ids(self, tokenizer=None):
|
||||||
|
if tokenizer is None:
|
||||||
|
tokenizer = self.tokenizer
|
||||||
|
if tokenizer is None:
|
||||||
|
raise ValueError("Tokenizer is not set.")
|
||||||
|
return self.prefix_token_ids + tokenizer.encode(self.text)
|
||||||
|
|
||||||
|
def as_mlx_array(self) -> mx.array:
|
||||||
|
"""Return tokens as MLX array."""
|
||||||
|
tok_ids = self.as_token_ids()
|
||||||
|
return mx.array([tok_ids], dtype=mx.int32)
|
||||||
|
|
||||||
|
def as_mlx_array_beam(self, beam: int) -> mx.array:
|
||||||
|
"""Return tokens as MLX array repeated for beam search."""
|
||||||
|
t = self.as_mlx_array()
|
||||||
|
return mx.repeat(t, beam, axis=0)
|
||||||
|
|
||||||
|
def as_text(self):
|
||||||
|
return self.text
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def empty(*a, **kw):
|
||||||
|
return MLXTokenBuffer(*a, **kw)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_text(text, *a, **kw):
|
||||||
|
return MLXTokenBuffer(*a, text=text, **kw)
|
||||||
|
|
||||||
|
def is_empty(self):
|
||||||
|
return self.text is None or self.text == ""
|
||||||
|
|
||||||
|
def trim_words(self, num=1, after=0):
|
||||||
|
"""Trim words from the beginning of the context."""
|
||||||
|
tokenizer = self.tokenizer
|
||||||
|
assert tokenizer is not None, "Tokenizer is not set."
|
||||||
|
|
||||||
|
ids = tokenizer.encode(self.text[after:])
|
||||||
|
words, wids = self.tokenizer.split_to_word_tokens(ids)
|
||||||
|
if not words:
|
||||||
|
return 0
|
||||||
|
self.text = self.text[:after] + "".join(words[num:])
|
||||||
|
return sum(len(wi) for wi in wids[:num])
|
||||||
|
|
||||||
|
def append_token_ids(self, token_ids):
|
||||||
|
"""Append token IDs to the buffer, handling incomplete UTF-8."""
|
||||||
|
tokenizer = self.tokenizer
|
||||||
|
assert tokenizer is not None, "Tokenizer is not set."
|
||||||
|
|
||||||
|
all_tokens = self.pending_token_ids + token_ids
|
||||||
|
decoded = tokenizer.decode(all_tokens)
|
||||||
|
replacement_char = "\ufffd"
|
||||||
|
|
||||||
|
if replacement_char in decoded:
|
||||||
|
if len(all_tokens) > 1:
|
||||||
|
decoded_partial = tokenizer.decode(all_tokens[:-1])
|
||||||
|
if replacement_char not in decoded_partial:
|
||||||
|
self.text += decoded_partial
|
||||||
|
self.pending_token_ids = [all_tokens[-1]]
|
||||||
|
else:
|
||||||
|
self.pending_token_ids = all_tokens
|
||||||
|
else:
|
||||||
|
self.pending_token_ids = all_tokens
|
||||||
|
else:
|
||||||
|
self.text += decoded
|
||||||
|
self.pending_token_ids = []
|
||||||
|
|
||||||
|
|
||||||
|
def mlx_median_filter(x: mx.array, filter_width: int) -> mx.array:
|
||||||
|
"""
|
||||||
|
Apply median filter along the last axis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input array of shape (..., T)
|
||||||
|
filter_width: Width of the median filter (should be odd)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered array of same shape
|
||||||
|
"""
|
||||||
|
if filter_width <= 1:
|
||||||
|
return x
|
||||||
|
|
||||||
|
pad_width = filter_width // 2
|
||||||
|
shape = x.shape
|
||||||
|
|
||||||
|
left_pad = mx.repeat(x[..., :1], pad_width, axis=-1)
|
||||||
|
right_pad = mx.repeat(x[..., -1:], pad_width, axis=-1)
|
||||||
|
x_padded = mx.concatenate([left_pad, x, right_pad], axis=-1)
|
||||||
|
|
||||||
|
result_shape = list(shape)
|
||||||
|
result = []
|
||||||
|
|
||||||
|
for i in range(shape[-1]):
|
||||||
|
window = x_padded[..., i:i + filter_width]
|
||||||
|
sorted_window = mx.sort(window, axis=-1)
|
||||||
|
median_val = sorted_window[..., filter_width // 2:filter_width // 2 + 1]
|
||||||
|
result.append(median_val)
|
||||||
|
|
||||||
|
return mx.concatenate(result, axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class MLXAlignAtt:
|
||||||
|
"""
|
||||||
|
MLX-native Alignment-based Attention decoder for SimulStreaming.
|
||||||
|
|
||||||
|
This class runs entirely on MLX, with no PyTorch dependencies for inference.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def speaker(self):
|
||||||
|
return self.state.speaker
|
||||||
|
|
||||||
|
@speaker.setter
|
||||||
|
def speaker(self, value):
|
||||||
|
self.state.speaker = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def global_time_offset(self):
|
||||||
|
return self.state.global_time_offset
|
||||||
|
|
||||||
|
@global_time_offset.setter
|
||||||
|
def global_time_offset(self, value):
|
||||||
|
self.state.global_time_offset = value
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg: AlignAttConfig,
|
||||||
|
mlx_model: Any,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize MLX AlignAtt decoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: AlignAtt configuration
|
||||||
|
mlx_model: MLX Whisper model (full model, not just encoder)
|
||||||
|
"""
|
||||||
|
self.model = mlx_model
|
||||||
|
self.cfg = cfg
|
||||||
|
|
||||||
|
logger.info(f"MLX Model dimensions: {self.model.dims}")
|
||||||
|
|
||||||
|
self.decode_options = DecodingOptions(
|
||||||
|
language=cfg.language,
|
||||||
|
without_timestamps=True,
|
||||||
|
task=cfg.task
|
||||||
|
)
|
||||||
|
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
|
||||||
|
|
||||||
|
self.max_text_len = self.model.dims.n_text_ctx
|
||||||
|
self.num_decoder_layers = len(self.model.decoder.blocks)
|
||||||
|
|
||||||
|
if self.cfg.max_context_tokens is None:
|
||||||
|
self.max_context_tokens = self.max_text_len
|
||||||
|
else:
|
||||||
|
self.max_context_tokens = self.cfg.max_context_tokens
|
||||||
|
|
||||||
|
# Initialize per-session state
|
||||||
|
self.state = MLXDecoderState()
|
||||||
|
self._init_state(cfg)
|
||||||
|
|
||||||
|
    def _init_state(self, cfg: AlignAttConfig):
        """Initialize the per-session decoder state."""
        # Defer the tokenizer's language choice when auto-detecting.
        self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
        self.state.tokenizer = self.tokenizer
        self.state.detected_language = cfg.language if cfg.language != "auto" else None
        self.state.global_time_offset = 0.0
        # Negative sentinel so the very first attended frame can never look
        # like a rewind.
        self.state.last_attend_frame = -cfg.rewind_threshold
        self.state.speaker = -1

        # Word-boundary firing policy. The CIF boundary model is not ported to
        # MLX, so a provided checkpoint degrades to always_fire with a warning.
        if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
            if cfg.never_fire:
                self.state.never_fire = True
                self.state.always_fire = False
            else:
                self.state.always_fire = True
                self.state.never_fire = False
        else:
            logger.warning("CIF checkpoint provided but MLX CIF not implemented. Using always_fire=True")
            self.state.always_fire = True
            self.state.never_fire = cfg.never_fire

        self._build_alignment_source()

        # Control tokens (and all language tokens) that must never be sampled
        # during transcription.
        suppress_tokens = [
            self.tokenizer.transcribe,
            self.tokenizer.translate,
            self.tokenizer.sot,
            self.tokenizer.sot_prev,
            self.tokenizer.sot_lm,
            self.tokenizer.no_timestamps,
        ] + list(self.tokenizer.all_language_tokens)
        if self.tokenizer.no_speech is not None:
            suppress_tokens.append(self.tokenizer.no_speech)
        self.state.suppress_tokens = tuple(sorted(set(suppress_tokens)))
        logger.debug(f"Suppress tokens: {self.state.suppress_tokens}")

        self.init_tokens()
        self.init_context()

        # Token decoder: greedy (temperature 0) or beam search.
        self.state.decoder_type = cfg.decoder_type
        if cfg.decoder_type == "greedy":
            logger.info("Using MLX greedy decoder")
            self.state.token_decoder = MLXGreedyDecoder(0.0, self.tokenizer.eot)
        elif cfg.decoder_type == "beam":
            logger.info("Using MLX beam decoder")
            self.state.inference = MLXInference(self.model, self.state.initial_token_length)
            self.state.token_decoder = MLXBeamSearchDecoder(
                inference=self.state.inference,
                eot=self.tokenizer.eot,
                beam_size=cfg.beam_size
            )
|
||||||
|
|
||||||
|
def _build_alignment_source(self):
|
||||||
|
"""Build alignment source mapping from model's alignment_heads."""
|
||||||
|
self.state.align_source = {}
|
||||||
|
self.state.num_align_heads = 0
|
||||||
|
|
||||||
|
alignment_heads = self.model.alignment_heads
|
||||||
|
|
||||||
|
if alignment_heads is None:
|
||||||
|
logger.warning("No alignment heads found in model")
|
||||||
|
return
|
||||||
|
|
||||||
|
if hasattr(alignment_heads, 'tolist'):
|
||||||
|
heads_list = alignment_heads.tolist()
|
||||||
|
else:
|
||||||
|
heads_list = np.array(alignment_heads).tolist()
|
||||||
|
|
||||||
|
for layer_rank, head_id in heads_list:
|
||||||
|
layer_rank = int(layer_rank)
|
||||||
|
head_id = int(head_id)
|
||||||
|
heads = self.state.align_source.get(layer_rank, [])
|
||||||
|
heads.append((self.state.num_align_heads, head_id))
|
||||||
|
self.state.align_source[layer_rank] = heads
|
||||||
|
self.state.num_align_heads += 1
|
||||||
|
|
||||||
|
    def warmup(self, audio: np.ndarray):
        """Warmup the model with sample audio.

        Runs one full insert/infer/refresh cycle so later real-time calls do
        not pay the first-inference compilation/graph-build cost. Failures are
        logged but deliberately swallowed: warmup is best-effort.
        """
        try:
            self.insert_audio(audio)
            self.infer(is_last=True)
            self.refresh_segment(complete=True)
            logger.info("MLX model warmed up successfully")
        except Exception as e:
            logger.exception(f"MLX model warmup failed: {e}")
|
||||||
|
|
||||||
|
def create_tokenizer(self, language=None):
|
||||||
|
"""Create tokenizer for the given language."""
|
||||||
|
self.tokenizer = tokenizer.get_tokenizer(
|
||||||
|
multilingual=self.tokenizer_is_multilingual,
|
||||||
|
language=language,
|
||||||
|
num_languages=self.model.num_languages,
|
||||||
|
task=self.decode_options.task
|
||||||
|
)
|
||||||
|
self.state.tokenizer = self.tokenizer
|
||||||
|
|
||||||
|
def init_context(self):
|
||||||
|
"""Initialize context buffer."""
|
||||||
|
kw = {
|
||||||
|
'tokenizer': self.tokenizer,
|
||||||
|
'prefix_token_ids': [self.tokenizer.sot_prev]
|
||||||
|
}
|
||||||
|
self.state.context = MLXTokenBuffer.empty(**kw)
|
||||||
|
if self.cfg.static_init_prompt is not None:
|
||||||
|
self.state.context = MLXTokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
|
||||||
|
if self.cfg.init_prompt is not None:
|
||||||
|
self.state.context.text += self.cfg.init_prompt
|
||||||
|
|
||||||
|
    def init_tokens(self):
        """Reset the decoded-token sequence to just the SOT prompt."""
        logger.debug(f"init tokens, {len(self.state.segments)}")
        # Shape (1, prompt_len); the no-timestamps prompt variant is used since
        # timing is recovered from cross-attention alignment instead.
        self.state.initial_tokens = mx.array(
            [self.tokenizer.sot_sequence_including_notimestamps],
            dtype=mx.int32
        )
        self.state.initial_token_length = self.state.initial_tokens.shape[1]
        # Position of SOT inside the prompt, used by the no-speech probe.
        self.state.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
        logger.debug(f"init tokens after, {len(self.state.segments)}")
        self.state.tokens = [self.state.initial_tokens]
|
||||||
|
|
||||||
|
def trim_context(self):
|
||||||
|
"""Trim context if too long."""
|
||||||
|
logger.info("Trimming context")
|
||||||
|
c = len(self.state.context.as_token_ids()) - len(self.state.context.prefix_token_ids)
|
||||||
|
logger.info(f"Context text: {self.state.context.as_text()}")
|
||||||
|
l = sum(t.shape[1] for t in self.state.tokens) + c
|
||||||
|
if self.cfg.static_init_prompt is None:
|
||||||
|
after = 0
|
||||||
|
else:
|
||||||
|
after = len(self.cfg.static_init_prompt)
|
||||||
|
while c > self.max_context_tokens or l > self.max_text_len - 20:
|
||||||
|
t = self.state.context.trim_words(after=after)
|
||||||
|
l -= t
|
||||||
|
c -= t
|
||||||
|
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||||
|
if t == 0:
|
||||||
|
break
|
||||||
|
logger.info(f"Context after trim: {self.state.context.text} (len: {l})")
|
||||||
|
|
||||||
|
def refresh_segment(self, complete=False):
|
||||||
|
"""Refresh segment state."""
|
||||||
|
logger.debug("Refreshing segment:")
|
||||||
|
self.init_tokens()
|
||||||
|
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
|
self.state.cumulative_time_offset = 0.0
|
||||||
|
self.init_context()
|
||||||
|
logger.debug(f"Context: {self.state.context}")
|
||||||
|
if not complete and len(self.state.segments) > 2:
|
||||||
|
self.state.segments = self.state.segments[-2:]
|
||||||
|
else:
|
||||||
|
logger.debug("removing all segments.")
|
||||||
|
self.state.segments = []
|
||||||
|
self.state.log_segments += 1
|
||||||
|
self.state.pending_incomplete_tokens = []
|
||||||
|
|
||||||
|
def fire_at_boundary(self, chunked_encoder_feature: mx.array) -> bool:
|
||||||
|
"""Check if we should fire at word boundary (CIF-based)."""
|
||||||
|
if self.state.always_fire:
|
||||||
|
return True
|
||||||
|
if self.state.never_fire:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
    def _current_tokens(self) -> mx.array:
        """Assemble the decoder input: [text context] + prompt + committed tokens.

        Returns:
            mx.array of shape (beam_size, seq_len).
        """
        toks = self.state.tokens

        # Broadcast the single-row prompt to beam width once. NOTE: `toks`
        # aliases self.state.tokens, so this mutates the stored prompt in place
        # — intentional, it only needs to happen on the first call.
        if toks[0].shape[0] == 1:
            toks[0] = mx.repeat(toks[0], self.cfg.beam_size, axis=0)

        if not self.state.context.is_empty():
            context_toks = self.state.context.as_mlx_array_beam(self.cfg.beam_size)
            toks = [context_toks] + toks

        # Concatenate all tokens
        if len(toks) > 1:
            current_tokens = mx.concatenate(toks, axis=1)
        else:
            current_tokens = toks[0]

        logger.debug("debug print current_tokens:")
        self.debug_print_tokens(current_tokens)
        return current_tokens
|
||||||
|
|
||||||
|
def debug_print_tokens(self, tokens: mx.array):
|
||||||
|
"""Debug print token sequences."""
|
||||||
|
tokens_np = np.array(tokens)
|
||||||
|
for i in range(min(self.cfg.beam_size, tokens_np.shape[0])):
|
||||||
|
logger.debug(self.tokenizer.decode_with_timestamps(tokens_np[i].tolist()))
|
||||||
|
|
||||||
|
def segments_len(self) -> float:
|
||||||
|
"""Get total length of audio segments in seconds."""
|
||||||
|
return sum(s.shape[0] for s in self.state.segments) / 16000
|
||||||
|
|
||||||
|
def _apply_minseglen(self) -> bool:
|
||||||
|
"""Check if we have enough audio to process."""
|
||||||
|
segments_len = self.segments_len()
|
||||||
|
if segments_len < self.cfg.audio_min_len:
|
||||||
|
logger.debug("waiting for next segment")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
    def insert_audio(self, segment: np.ndarray = None):
        """Append *segment* to the audio buffer and evict audio beyond the cap.

        While total buffered audio exceeds cfg.audio_max_len seconds (always
        keeping at least one segment), the oldest segment is dropped; the
        alignment frame pointer and absolute-time base are shifted
        accordingly, and the committed tokens of the dropped window are folded
        into the text context.

        Returns:
            Duration (seconds) of the last evicted segment, 0 if none.
            NOTE(review): when several segments are evicted in one call only
            the *last* one's length is returned, not the cumulative amount —
            confirm callers expect this.
        """
        if segment is not None:
            # Accept torch tensors as well as numpy arrays.
            if hasattr(segment, 'numpy'):
                segment = segment.numpy()
            self.state.segments.append(segment)

        removed_len = 0
        segments_len = self.segments_len()

        while len(self.state.segments) > 1 and segments_len > self.cfg.audio_max_len:
            removed_len = self.state.segments[0].shape[0] / 16000
            segments_len -= removed_len
            # TOKENS_PER_SECOND is the mel-frame rate (50/s) used for the
            # attention pointer; cumulative_time_offset keeps timestamps absolute.
            self.state.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
            self.state.cumulative_time_offset += removed_len
            self.state.segments = self.state.segments[1:]
            logger.debug(f"remove segments: {len(self.state.segments)} {len(self.state.tokens)}, cumulative offset: {self.state.cumulative_time_offset:.2f}s")

            # Fold the oldest committed token block into the text context so the
            # decoder keeps its wording without re-attending the evicted audio.
            # NOTE(review): original indentation is ambiguous in this copy;
            # placed inside the eviction loop (one token block per evicted
            # segment) — confirm against upstream.
            if len(self.state.tokens) > 1:
                # Convert MLX array to list for context
                token_list = np.array(self.state.tokens[1][0, :]).tolist()
                self.state.context.append_token_ids(token_list)
                self.state.tokens = [self.state.initial_tokens] + self.state.tokens[2:]

        return removed_len
|
||||||
|
|
||||||
|
    def _clean_cache(self):
        """Clean the kv_cache after each inference step."""
        # Delegates to the state object, which owns the KV-cache buffers.
        self.state.clean_cache()
|
||||||
|
|
||||||
|
    def _suppress_tokens(self, logits: mx.array) -> mx.array:
        """Force the configured control/language tokens to -inf in *logits*.

        Args:
            logits: Last-step logits, shape (beam, vocab).

        Returns:
            Logits with suppressed positions driven to -inf (adding -inf to a
            finite float yields -inf).
        """
        if self.state.suppress_tokens:
            suppress_indices = mx.array(list(self.state.suppress_tokens), dtype=mx.int32)
            logits = logits.at[:, suppress_indices].add(-float('inf'))
        return logits
|
||||||
|
|
||||||
|
def lang_id(self, encoder_features: mx.array) -> Tuple[mx.array, List[dict]]:
|
||||||
|
"""Language detection from encoder features."""
|
||||||
|
n_audio = encoder_features.shape[0]
|
||||||
|
x = mx.array([[self.tokenizer.sot]] * n_audio, dtype=mx.int32)
|
||||||
|
|
||||||
|
logits, _, _ = self.model.decoder(x, encoder_features, kv_cache=None)
|
||||||
|
logits = logits[:, 0]
|
||||||
|
|
||||||
|
mask = mx.ones(logits.shape[-1], dtype=mx.bool_)
|
||||||
|
language_token_indices = mx.array(list(self.tokenizer.all_language_tokens), dtype=mx.int32)
|
||||||
|
mask = mask.at[language_token_indices].add(False)
|
||||||
|
|
||||||
|
logits = mx.where(mask, mx.array(-float('inf')), logits)
|
||||||
|
|
||||||
|
language_tokens = mx.argmax(logits, axis=-1)
|
||||||
|
language_token_probs = mx.softmax(logits, axis=-1)
|
||||||
|
|
||||||
|
probs_np = np.array(language_token_probs)
|
||||||
|
|
||||||
|
language_probs = [
|
||||||
|
{
|
||||||
|
c: float(probs_np[i, j])
|
||||||
|
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
|
||||||
|
}
|
||||||
|
for i in range(n_audio)
|
||||||
|
]
|
||||||
|
|
||||||
|
self._clean_cache()
|
||||||
|
return language_tokens, language_probs
|
||||||
|
|
||||||
|
def infer(self, is_last: bool = False) -> List[ASRToken]:
|
||||||
|
"""
|
||||||
|
Main inference method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
is_last: Whether this is the final chunk
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of timestamped ASR tokens
|
||||||
|
"""
|
||||||
|
new_segment = True
|
||||||
|
|
||||||
|
if len(self.state.segments) == 0:
|
||||||
|
logger.debug("No segments, nothing to do")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not self._apply_minseglen():
|
||||||
|
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if len(self.state.segments) > 1:
|
||||||
|
input_segments = np.concatenate(self.state.segments, axis=0)
|
||||||
|
else:
|
||||||
|
input_segments = self.state.segments[0]
|
||||||
|
|
||||||
|
beg_encode = time()
|
||||||
|
|
||||||
|
mlx_mel_padded = mlx_log_mel_spectrogram(
|
||||||
|
audio=input_segments,
|
||||||
|
n_mels=self.model.dims.n_mels,
|
||||||
|
padding=N_SAMPLES
|
||||||
|
)
|
||||||
|
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
|
||||||
|
encoder_feature = self.model.encoder(mlx_mel[None])
|
||||||
|
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2)
|
||||||
|
|
||||||
|
mx.eval(encoder_feature)
|
||||||
|
|
||||||
|
end_encode = time()
|
||||||
|
logger.debug(f'MLX Encoder duration: {end_encode - beg_encode:.3f}s')
|
||||||
|
|
||||||
|
if self.cfg.language == "auto" and self.state.detected_language is None and self.state.first_timestamp:
|
||||||
|
seconds_since_start = self.segments_len() - self.state.first_timestamp
|
||||||
|
if seconds_since_start >= 2.0:
|
||||||
|
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||||
|
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||||
|
print(f"Detected language: {top_lan} with p={p:.4f}")
|
||||||
|
self.create_tokenizer(top_lan)
|
||||||
|
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
|
self.state.cumulative_time_offset = 0.0
|
||||||
|
self.init_tokens()
|
||||||
|
self.init_context()
|
||||||
|
self.state.detected_language = top_lan
|
||||||
|
logger.info(f"Tokenizer language: {self.tokenizer.language}")
|
||||||
|
|
||||||
|
self.trim_context()
|
||||||
|
current_tokens = self._current_tokens()
|
||||||
|
|
||||||
|
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
||||||
|
|
||||||
|
sum_logprobs = mx.zeros((self.cfg.beam_size,), dtype=mx.float32)
|
||||||
|
completed = False
|
||||||
|
|
||||||
|
attn_of_alignment_heads = None
|
||||||
|
most_attended_frame = None
|
||||||
|
|
||||||
|
token_len_before_decoding = current_tokens.shape[1]
|
||||||
|
|
||||||
|
l_absolute_timestamps = []
|
||||||
|
accumulated_cross_attns = []
|
||||||
|
|
||||||
|
audio_duration_s = self.segments_len()
|
||||||
|
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
|
||||||
|
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
|
||||||
|
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
|
||||||
|
tokens_produced_this_chunk = 0
|
||||||
|
|
||||||
|
while not completed and current_tokens.shape[1] < self.max_text_len:
|
||||||
|
tokens_produced_this_chunk += 1
|
||||||
|
|
||||||
|
if tokens_produced_this_chunk > max_tokens_per_chunk:
|
||||||
|
logger.warning(f"[Loop Detection] Too many tokens ({tokens_produced_this_chunk}) for {audio_duration_s:.2f}s audio. Breaking.")
|
||||||
|
current_tokens = current_tokens[:, :token_len_before_decoding]
|
||||||
|
break
|
||||||
|
|
||||||
|
if new_segment:
|
||||||
|
tokens_for_logits = current_tokens
|
||||||
|
else:
|
||||||
|
tokens_for_logits = current_tokens[:, -1:]
|
||||||
|
|
||||||
|
if self.state.decoder_type == "greedy":
|
||||||
|
logits, self.state.kv_cache, cross_qk = self.model.decoder(
|
||||||
|
tokens_for_logits, encoder_feature, kv_cache=self.state.kv_cache
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits, cross_qk = self.state.inference.logits(tokens_for_logits, encoder_feature)
|
||||||
|
|
||||||
|
mx.eval(logits)
|
||||||
|
|
||||||
|
accumulated_cross_attns.append(cross_qk)
|
||||||
|
if len(accumulated_cross_attns) > 16:
|
||||||
|
accumulated_cross_attns = accumulated_cross_attns[-16:]
|
||||||
|
|
||||||
|
if new_segment and self.tokenizer.no_speech is not None:
|
||||||
|
probs_at_sot = mx.softmax(logits[:, self.state.sot_index, :], axis=-1)
|
||||||
|
no_speech_probs = np.array(probs_at_sot[:, self.tokenizer.no_speech]).tolist()
|
||||||
|
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
||||||
|
logger.info("no speech, stop")
|
||||||
|
break
|
||||||
|
|
||||||
|
logits = logits[:, -1, :] # Last token logits
|
||||||
|
|
||||||
|
# Suppress tokens at segment start
|
||||||
|
if new_segment:
|
||||||
|
blank_tokens = self.tokenizer.encode(" ") + [self.tokenizer.eot]
|
||||||
|
logits = logits.at[:, blank_tokens].add(-float('inf'))
|
||||||
|
new_segment = False
|
||||||
|
|
||||||
|
logits = self._suppress_tokens(logits)
|
||||||
|
|
||||||
|
current_tokens, completed = self.state.token_decoder.update(
|
||||||
|
current_tokens, logits, sum_logprobs
|
||||||
|
)
|
||||||
|
mx.eval(current_tokens)
|
||||||
|
|
||||||
|
logger.debug(f"Decoding completed: {completed}")
|
||||||
|
self.debug_print_tokens(current_tokens)
|
||||||
|
|
||||||
|
attn_of_alignment_heads = self._process_cross_attention(
|
||||||
|
accumulated_cross_attns, content_mel_len
|
||||||
|
)
|
||||||
|
|
||||||
|
most_attended_frames = mx.argmax(attn_of_alignment_heads[:, -1, :], axis=-1)
|
||||||
|
most_attended_frames_np = np.array(most_attended_frames)
|
||||||
|
|
||||||
|
absolute_timestamps = [
|
||||||
|
(frame * 0.02 + self.state.cumulative_time_offset)
|
||||||
|
for frame in most_attended_frames_np.tolist()
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.debug(str(most_attended_frames_np.tolist()) + " most att frames")
|
||||||
|
logger.debug(f"Absolute timestamps: {absolute_timestamps}")
|
||||||
|
|
||||||
|
most_attended_frame = int(most_attended_frames_np[0])
|
||||||
|
l_absolute_timestamps.append(absolute_timestamps[0])
|
||||||
|
|
||||||
|
if completed:
|
||||||
|
current_tokens = current_tokens[:, :-1]
|
||||||
|
break
|
||||||
|
if not is_last and self.state.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
|
||||||
|
current_tokens_np = np.array(current_tokens)
|
||||||
|
if current_tokens.shape[1] > 1 and current_tokens_np[0, -2] >= DEC_PAD:
|
||||||
|
logger.debug("omit rewinding from special tokens")
|
||||||
|
self.state.last_attend_frame = most_attended_frame
|
||||||
|
else:
|
||||||
|
logger.debug(f"[rewind detected] current: {most_attended_frame}, last: {self.state.last_attend_frame}")
|
||||||
|
self.state.last_attend_frame = -self.cfg.rewind_threshold
|
||||||
|
current_tokens = mx.concatenate(self.state.tokens, axis=1) if len(self.state.tokens) > 0 else self.state.tokens[0]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self.state.last_attend_frame = most_attended_frame
|
||||||
|
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
|
||||||
|
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
|
||||||
|
current_tokens = current_tokens[:, :-1]
|
||||||
|
break
|
||||||
|
tokens_to_split = np.array(current_tokens[0, token_len_before_decoding:]).tolist()
|
||||||
|
if self.state.pending_incomplete_tokens:
|
||||||
|
logger.debug(f"[UTF-8 Fix] Prepending pending tokens: {self.state.pending_incomplete_tokens}")
|
||||||
|
tokens_to_split = self.state.pending_incomplete_tokens + tokens_to_split
|
||||||
|
|
||||||
|
if fire_detected or is_last:
|
||||||
|
new_hypothesis = tokens_to_split
|
||||||
|
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
||||||
|
else:
|
||||||
|
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split)
|
||||||
|
if len(split_words) > 1:
|
||||||
|
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
|
||||||
|
else:
|
||||||
|
new_hypothesis = []
|
||||||
|
|
||||||
|
logger.debug(f"new_hypothesis: {new_hypothesis}")
|
||||||
|
new_tokens = mx.array([new_hypothesis], dtype=mx.int32)
|
||||||
|
new_tokens = mx.repeat(new_tokens, self.cfg.beam_size, axis=0)
|
||||||
|
self.state.tokens.append(new_tokens)
|
||||||
|
|
||||||
|
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
||||||
|
|
||||||
|
self._clean_cache()
|
||||||
|
|
||||||
|
if len(l_absolute_timestamps) >= 2 and self.state.first_timestamp is None:
|
||||||
|
self.state.first_timestamp = l_absolute_timestamps[0]
|
||||||
|
timestamped_words = []
|
||||||
|
timestamp_idx = 0
|
||||||
|
replacement_char = "\ufffd"
|
||||||
|
|
||||||
|
for word, word_tokens in zip(split_words, split_tokens):
|
||||||
|
if replacement_char in word:
|
||||||
|
logger.warning(f"[UTF-8 Filter] Skipping: {repr(word)}")
|
||||||
|
timestamp_idx += len(word_tokens)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
timestamp_idx += len(word_tokens)
|
||||||
|
|
||||||
|
timestamp_entry = ASRToken(
|
||||||
|
start=round(current_timestamp, 2),
|
||||||
|
end=round(current_timestamp + 0.1, 2),
|
||||||
|
text=word,
|
||||||
|
speaker=self.state.speaker,
|
||||||
|
detected_language=self.state.detected_language
|
||||||
|
).with_offset(self.state.global_time_offset)
|
||||||
|
timestamped_words.append(timestamp_entry)
|
||||||
|
self.state.pending_incomplete_tokens = []
|
||||||
|
MAX_PENDING_TOKENS = 10
|
||||||
|
if split_words and replacement_char in split_words[-1]:
|
||||||
|
if len(split_tokens[-1]) <= MAX_PENDING_TOKENS:
|
||||||
|
self.state.pending_incomplete_tokens = split_tokens[-1]
|
||||||
|
logger.debug(f"[UTF-8 Fix] Holding incomplete tokens")
|
||||||
|
else:
|
||||||
|
logger.warning(f"[UTF-8 Fix] Skipping too many tokens")
|
||||||
|
|
||||||
|
return timestamped_words
|
||||||
|
|
||||||
|
    def _process_cross_attention(
        self,
        cross_attns: List[List[mx.array]],
        content_mel_len: int
    ) -> mx.array:
        """
        Process cross-attention weights for alignment.

        Args:
            cross_attns: List of cross-attention from each forward pass
                Each element is a list of mx.arrays per layer
            content_mel_len: Length of actual audio content

        Returns:
            Processed attention tensor, shape (batch, seq_len, content_mel_len)
        """
        attn_of_alignment_heads = [[] for _ in range(self.state.num_align_heads)]
        num_decoder_layers = self.num_decoder_layers

        # Flatten [pass][layer] into one list; layer identity is recovered
        # below as idx % num_decoder_layers.
        if cross_attns and isinstance(cross_attns[0], list):
            flattened_attns = [attn for layer_list in cross_attns for attn in layer_list]
        else:
            flattened_attns = cross_attns

        for idx, attn_mat in enumerate(flattened_attns):
            if attn_mat is None:
                continue

            layer_rank = idx % num_decoder_layers
            align_heads_in_layer = self.state.align_source.get(layer_rank, [])

            if len(align_heads_in_layer) == 0:
                continue
            # Raw QK scores -> attention probabilities over audio frames.
            attn_mat = mx.softmax(attn_mat, axis=-1)

            for align_head_rank, head_id in align_heads_in_layer:
                if self.cfg.beam_size == 1:
                    # Handle both (batch, head, q, k) and (head, q, k) layouts.
                    if attn_mat.ndim == 4:
                        a = attn_mat[0, head_id, :, :]
                    else:
                        a = attn_mat[head_id, :, :]
                    a = a[None, :, :]
                else:
                    a = attn_mat[:, head_id, :, :]
                attn_of_alignment_heads[align_head_rank].append(a)
        # Concatenate each head's per-pass attention along the query axis.
        tmp = []
        for mat in attn_of_alignment_heads:
            if mat:
                t = mx.concatenate(mat, axis=1)
                tmp.append(t)

        if not tmp:
            # No usable attention: a zero map makes the caller's argmax pick frame 0.
            return mx.zeros((self.cfg.beam_size, 1, content_mel_len))
        attn_of_alignment_heads = mx.stack(tmp, axis=1)

        # Standardize per query position, median-filter along frames
        # (kernel 7), then average across alignment heads — the standard
        # Whisper word-alignment recipe.
        std = mx.std(attn_of_alignment_heads, axis=-2, keepdims=True)
        mean = mx.mean(attn_of_alignment_heads, axis=-2, keepdims=True)
        attn_of_alignment_heads = (attn_of_alignment_heads - mean) / (std + 1e-8)

        attn_of_alignment_heads = mlx_median_filter(attn_of_alignment_heads, 7)

        attn_of_alignment_heads = mx.mean(attn_of_alignment_heads, axis=1)

        # Drop padding frames beyond the real audio content.
        attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]

        mx.eval(attn_of_alignment_heads)
        return attn_of_alignment_heads
|
||||||
|
|
||||||
@@ -69,3 +69,39 @@ def load_mlx_encoder(
|
|||||||
model.update(encoder_weights)
|
model.update(encoder_weights)
|
||||||
mx.eval(model.parameters())
|
mx.eval(model.parameters())
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_mlx_model(
    path_or_hf_repo: str,
    dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
    """Load a full MLX Whisper model from a local path or a Hugging Face repo id.

    Downloads a snapshot when the path does not exist locally, applies any
    stored quantization config, and returns the model with weights loaded and
    evaluated.
    """
    local_path = Path(path_or_hf_repo)
    if not local_path.exists():
        local_path = Path(snapshot_download(repo_id=path_or_hf_repo))

    with open(str(local_path / "config.json"), "r") as f:
        config = json.loads(f.read())
    config.pop("model_type", None)
    quantization = config.pop("quantization", None)

    model_args = whisper.ModelDimensions(**config)

    # Prefer safetensors; fall back to the legacy npz layout.
    weights_file = local_path / "weights.safetensors"
    if not weights_file.exists():
        weights_file = local_path / "weights.npz"
    weights = mx.load(str(weights_file))

    model = whisper.Whisper(model_args, dtype)

    if quantization is not None:
        # Quantize only modules whose scales were actually stored.
        def _should_quantize(p, m):
            return isinstance(m, (nn.Linear, nn.Embedding)) and f"{p}.scales" in weights

        nn.quantize(model, **quantization, class_predicate=_should_quantize)

    weights = tree_unflatten(list(weights.items()))

    model.update(weights)
    mx.eval(model.parameters())
    return model
|
||||||
@@ -390,7 +390,6 @@ class AlignAtt:
|
|||||||
return []
|
return []
|
||||||
if not self._apply_minseglen():
|
if not self._apply_minseglen():
|
||||||
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||||
input_segments = torch.cat(self.state.segments, dim=0)
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# input_segments is concatenation of audio, it's one array
|
# input_segments is concatenation of audio, it's one array
|
||||||
@@ -485,7 +484,9 @@ class AlignAtt:
|
|||||||
accumulated_cross_attns = []
|
accumulated_cross_attns = []
|
||||||
|
|
||||||
audio_duration_s = self.segments_len()
|
audio_duration_s = self.segments_len()
|
||||||
max_tokens_per_chunk = max(50, int(audio_duration_s * TOKENS_PER_SECOND * 2.0)) # 2x margin, min 50
|
# ~15 text tokens/s is a generous upper bound for speech; TOKENS_PER_SECOND (50)
|
||||||
|
# is the mel-frame rate and was causing 10-40x over-allocation on repetition loops.
|
||||||
|
max_tokens_per_chunk = max(50, int(audio_duration_s * 15 * 1.5))
|
||||||
tokens_produced_this_chunk = 0
|
tokens_produced_this_chunk = 0
|
||||||
|
|
||||||
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
||||||
@@ -506,8 +507,12 @@ class AlignAtt:
|
|||||||
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
|
result = self.logits(tokens_for_logits, encoder_feature, return_cross_attn=True)
|
||||||
logits, cross_attns = result
|
logits, cross_attns = result
|
||||||
|
|
||||||
# Accumulate cross-attention from this forward pass
|
# Accumulate cross-attention from this forward pass (rolling window to
|
||||||
|
# bound VRAM — only the last entry matters for alignment, and the
|
||||||
|
# median_filter kernel is 7, so 16 entries is more than enough).
|
||||||
accumulated_cross_attns.append(cross_attns)
|
accumulated_cross_attns.append(cross_attns)
|
||||||
|
if len(accumulated_cross_attns) > 16:
|
||||||
|
accumulated_cross_attns = accumulated_cross_attns[-16:]
|
||||||
|
|
||||||
if new_segment and self.tokenizer.no_speech is not None:
|
if new_segment and self.tokenizer.no_speech is not None:
|
||||||
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
|
probs_at_sot = logits[:, self.state.sot_index, :].float().softmax(dim=-1)
|
||||||
@@ -626,8 +631,10 @@ class AlignAtt:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
||||||
except:
|
except IndexError:
|
||||||
pass
|
# Use last timestamp if index out of range
|
||||||
|
logger.warning(f"Timestamp index {timestamp_idx} out of range, using last timestamp")
|
||||||
|
current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
|
||||||
timestamp_idx += len(word_tokens)
|
timestamp_idx += len(word_tokens)
|
||||||
|
|
||||||
timestamp_entry = ASRToken(
|
timestamp_entry = ASRToken(
|
||||||
|
|||||||
139
whisperlivekit/thread_safety.py
Normal file
139
whisperlivekit/thread_safety.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""
|
||||||
|
Thread Safety Configuration for WhisperLiveKit
|
||||||
|
|
||||||
|
This module provides thread safety configuration and utilities.
|
||||||
|
|
||||||
|
Environment Variables:
|
||||||
|
WHISPERLIVEKIT_MODEL_LOCK: Enable/disable model locking (default: 1)
|
||||||
|
Set to "0" to disable for single-connection deployments
|
||||||
|
|
||||||
|
WHISPERLIVEKIT_LOCK_TIMEOUT: Lock acquisition timeout in seconds (default: 30)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Enable model locking (default)
|
||||||
|
export WHISPERLIVEKIT_MODEL_LOCK=1
|
||||||
|
|
||||||
|
# Disable for single-connection deployment
|
||||||
|
export WHISPERLIVEKIT_MODEL_LOCK=0
|
||||||
|
|
||||||
|
# Custom timeout
|
||||||
|
export WHISPERLIVEKIT_LOCK_TIMEOUT=60
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
# Configuration (environment-overridable).
# WHISPERLIVEKIT_MODEL_LOCK: "1" (default) serializes model access per process.
USE_MODEL_LOCK = os.environ.get("WHISPERLIVEKIT_MODEL_LOCK", "1") == "1"
# WHISPERLIVEKIT_LOCK_TIMEOUT: lock acquisition timeout, in seconds.
LOCK_TIMEOUT = float(os.environ.get("WHISPERLIVEKIT_LOCK_TIMEOUT", "30.0"))

# Global model lock shared by all connections in this process.
_model_lock = threading.Lock()

# Log configuration on import so deployments can see the active mode.
if USE_MODEL_LOCK:
    logger.info(f"Model locking ENABLED (timeout: {LOCK_TIMEOUT}s)")
    logger.info("For single-connection deployments, set WHISPERLIVEKIT_MODEL_LOCK=0")
else:
    logger.warning("Model locking DISABLED - only safe for single-connection deployments")
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_lock():
    """Return the process-wide model lock instance (a threading.Lock)."""
    return _model_lock
|
||||||
|
|
||||||
|
|
||||||
|
def acquire_model_lock(timeout=None):
    """
    Acquire the global model lock with a timeout.

    Args:
        timeout: Lock acquisition timeout in seconds (default: LOCK_TIMEOUT).
            An explicit 0 performs a non-blocking attempt.

    Returns:
        bool: True if the lock was acquired (or locking is disabled),
        False on timeout.
    """
    if not USE_MODEL_LOCK:
        return True

    # BUGFIX: `timeout = timeout or LOCK_TIMEOUT` silently converted an
    # explicit timeout of 0 (non-blocking attempt) into the default timeout;
    # compare against None instead.
    if timeout is None:
        timeout = LOCK_TIMEOUT
    acquired = _model_lock.acquire(timeout=timeout)

    if not acquired:
        logger.error(f"Failed to acquire model lock within {timeout}s")

    return acquired
|
||||||
|
|
||||||
|
|
||||||
|
def release_model_lock():
    """Release the global model lock.

    No-op when locking is disabled; releasing an unheld lock is swallowed to
    keep this helper best-effort (threading.Lock.release raises RuntimeError
    in that case).
    """
    if not USE_MODEL_LOCK:
        return

    try:
        _model_lock.release()
    except RuntimeError:
        # Lock not held - this is fine
        pass
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLockContext:
    """``with``-statement wrapper around acquire/release of the model lock.

    ``__enter__`` returns the acquisition result (a bool) instead of raising
    on timeout, so callers that care must check the yielded value.
    """

    def __init__(self, timeout=None):
        # None means "use the module-level LOCK_TIMEOUT".
        self.timeout = timeout
        self.acquired = False

    def __enter__(self):
        self.acquired = acquire_model_lock(self.timeout)
        return self.acquired

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Release only what we actually acquired; never swallow exceptions.
        if self.acquired:
            release_model_lock()
        return False
|
||||||
|
|
||||||
|
|
||||||
|
# Concurrency recommendations.
# NOTE: the original expression was `1 if USE_MODEL_LOCK else 1`, which
# evaluates to 1 on both branches; collapsed to the plain constant.
RECOMMENDED_CONNECTIONS_PER_WORKER = 1
RECOMMENDED_WORKERS = 4
|
||||||
|
|
||||||
|
def print_deployment_recommendations():
    """Print recommended deployment configuration for the active lock mode."""
    # Collect every output line, then emit them in one write; the resulting
    # stdout bytes are identical to printing each line individually.
    lines = [
        "\n" + "=" * 60,
        "WhisperLiveKit Deployment Recommendations",
        "=" * 60,
    ]
    if USE_MODEL_LOCK:
        lines += [
            "⚠️ Model locking is ENABLED",
            " This serializes inference across connections.",
            "",
            "Recommended deployment:",
            f" gunicorn -w {RECOMMENDED_WORKERS} \\",
            " -k uvicorn.workers.UvicornWorker \\",
            " --worker-connections 1 \\",
            " whisperlivekit.basic_server:app",
            "",
            "Expected capacity:",
            f" - {RECOMMENDED_WORKERS} concurrent users (1 per worker)",
            f" - Memory: ~{RECOMMENDED_WORKERS}x model size",
        ]
    else:
        lines += [
            "✅ Model locking is DISABLED",
            " ⚠️ ONLY safe for single-connection deployments",
            "",
            "Recommended deployment:",
            " uvicorn whisperlivekit.basic_server:app \\",
            " --host 0.0.0.0 --port 8000 \\",
            " --workers 1",
            "",
            "Expected capacity:",
            " - 1 concurrent user only",
        ]
    lines.append("=" * 60 + "\n")
    print("\n".join(lines))


if __name__ == "__main__":
    print_deployment_recommendations()
|
||||||
@@ -39,10 +39,11 @@ class TimedText(Timed):
|
|||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class ASRToken(TimedText):
|
class ASRToken(TimedText):
|
||||||
|
probability: Optional[float] = None
|
||||||
|
|
||||||
def with_offset(self, offset: float) -> "ASRToken":
|
def with_offset(self, offset: float) -> "ASRToken":
|
||||||
"""Return a new token with the time offset added."""
|
"""Return a new token with the time offset added."""
|
||||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
|
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language, probability=self.probability)
|
||||||
|
|
||||||
def is_silence(self) -> bool:
|
def is_silence(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -49,9 +49,12 @@ class TokensAlignment:
|
|||||||
|
|
||||||
def add_translation(self, segment: Segment) -> None:
|
def add_translation(self, segment: Segment) -> None:
|
||||||
"""Append translated text segments that overlap with a segment."""
|
"""Append translated text segments that overlap with a segment."""
|
||||||
|
if segment.translation is None:
|
||||||
|
segment.translation = ''
|
||||||
for ts in self.all_translation_segments:
|
for ts in self.all_translation_segments:
|
||||||
if ts.is_within(segment):
|
if ts.is_within(segment):
|
||||||
segment.translation += ts.text + (self.sep if ts.text else '')
|
if ts.text:
|
||||||
|
segment.translation += ts.text + self.sep
|
||||||
elif segment.translation:
|
elif segment.translation:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -183,9 +186,9 @@ class TokensAlignment:
|
|||||||
else:
|
else:
|
||||||
diarization_buffer = ''
|
diarization_buffer = ''
|
||||||
for token in self.new_tokens:
|
for token in self.new_tokens:
|
||||||
if token.is_silence():
|
if isinstance(token, Silence):
|
||||||
if self.current_line_tokens:
|
if self.current_line_tokens:
|
||||||
self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
|
self.validated_segments.append(Segment.from_tokens(self.current_line_tokens))
|
||||||
self.current_line_tokens = []
|
self.current_line_tokens = []
|
||||||
|
|
||||||
end_silence = token.end if token.has_ended else time() - self.beg_loop
|
end_silence = token.end if token.has_ended else time() - self.beg_loop
|
||||||
@@ -201,7 +204,7 @@ class TokensAlignment:
|
|||||||
|
|
||||||
segments = list(self.validated_segments)
|
segments = list(self.validated_segments)
|
||||||
if self.current_line_tokens:
|
if self.current_line_tokens:
|
||||||
segments.append(Segment().from_tokens(self.current_line_tokens))
|
segments.append(Segment.from_tokens(self.current_line_tokens))
|
||||||
|
|
||||||
if current_silence:
|
if current_silence:
|
||||||
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
|
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
|
||||||
|
|||||||
484
whisperlivekit/voxtral_streaming.py
Normal file
484
whisperlivekit/voxtral_streaming.py
Normal file
@@ -0,0 +1,484 @@
|
|||||||
|
"""
|
||||||
|
Voxtral Mini Realtime streaming backend using voxmlx's incremental encode/decode.
|
||||||
|
|
||||||
|
Uses model.encode_step() for incremental audio encoding and token-by-token
|
||||||
|
autoregressive decoding, matching voxmlx's native streaming pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from whisperlivekit.timed_objects import ASRToken, Transcript
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Padding, measured in encoder tokens, applied around the audio stream:
# zeros prepended before the first chunk, and appended at finish() so the
# decoder can run to the end of the utterance.
N_LEFT_PAD_TOKENS = 32
N_RIGHT_PAD_TOKENS = 17
|
||||||
|
|
||||||
|
|
||||||
|
class VoxtralStreamingASR:
    """Holds the voxmlx-loaded Voxtral model for the streaming pipeline."""

    sep = " "

    def __init__(self, logfile=sys.stderr, **kwargs):
        from voxmlx import _build_prompt_tokens
        from voxmlx import load_model as vox_load_model

        self.logfile = logfile
        self.transcribe_kargs = {}

        requested_language = kwargs.get("lan", "auto")
        self.original_language = None if requested_language == "auto" else requested_language

        default_model = "mlx-community/Voxtral-Mini-4B-Realtime-6bit"
        model_path = kwargs.get("model_dir") or kwargs.get("model_path")
        if not model_path:
            candidate = kwargs.get("model_size", "")
            # A Whisper size name like "small" is not a usable Voxtral
            # checkpoint; only accept HF-repo-style ("org/name") or
            # filesystem-path-style values here.
            looks_like_repo_or_path = bool(candidate) and ("/" in candidate or candidate.startswith("."))
            model_path = candidate if looks_like_repo_or_path else default_model

        load_started = time.time()
        logger.info(f"Loading Voxtral model '{model_path}' via voxmlx...")
        self.model, self._tokenizer, self._config = vox_load_model(model_path)
        self._prompt_tokens, self._n_delay_tokens = _build_prompt_tokens(
            self._tokenizer
        )
        logger.info(f"Voxtral model loaded in {time.time() - load_started:.2f}s")

        self.backend_choice = "voxtral-mlx"
        self.tokenizer = None  # sentence tokenizer — not needed for streaming

    def transcribe(self, audio):
        # Batch transcription is not supported; this backend is streaming-only.
        pass
|
||||||
|
|
||||||
|
|
||||||
|
class VoxtralStreamingOnlineProcessor:
    """
    Online processor for Voxtral streaming ASR.

    Uses voxmlx's incremental encoding (encode_step) and token-by-token
    autoregressive decoding. Each decode step corresponds to 80ms of audio.

    Processing happens in three phases (see _process_iter_inner):
      1. encode whole-token multiples of buffered PCM into audio embeddings;
      2. once enough embeddings exist, prefill the decoder with the prompt;
      3. decode one text token per audio embedding, accumulating token IDs
         that are later split into timestamped words.
    """

    SAMPLING_RATE = 16000  # input PCM sample rate

    def __init__(self, asr: VoxtralStreamingASR, logfile=sys.stderr):
        """
        Args:
            asr: loaded VoxtralStreamingASR holding model/tokenizer/prompt.
            logfile: kept for interface compatibility with other processors.
        """
        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy

        self.asr = asr
        self.logfile = logfile
        self.end = 0.0  # stream time (s) of the last inserted audio chunk
        self.buffer = []
        self.audio_buffer = np.array([], dtype=np.float32)  # for logging compat
        self._special_token_policy = SpecialTokenPolicy.IGNORE
        self._reset_state()
        logger.info(
            f"[voxtral] Initialized. eos_id={asr._tokenizer.eos_id}, "
            f"prefix_len={len(asr._prompt_tokens)}, "
            f"n_delay={asr._n_delay_tokens}"
        )

    def _reset_state(self):
        """Reset all encoder/decoder/word-tracking state (start of an utterance)."""
        from voxmlx.audio import SAMPLES_PER_TOKEN

        self._samples_per_token = SAMPLES_PER_TOKEN

        # Incremental encoder state
        self._audio_tail = None
        self._conv1_tail = None
        self._conv2_tail = None
        self._encoder_cache = None
        self._ds_buf = None

        # Decoder state
        self._decoder_cache = None
        self._y = None  # last sampled token (mx.array scalar)
        self._t_cond = None
        self._text_embeds = None

        # Audio / decode tracking
        self._pending_audio = np.zeros(0, dtype=np.float32)
        self._audio_embeds = None
        self._n_audio_samples_fed = 0
        self._n_total_decoded = 0
        self._first_cycle = True
        self._prefilled = False

        # Word extraction: accumulate token IDs, full-sequence decode for correct spacing
        self._output_token_ids: List[int] = []
        self._token_positions: List[int] = []  # decode position for each token
        self._n_committed_words = 0
        self._global_time_offset = 0.0
        self._y_flushed_to_output = False  # True after start_silence flushes pending _y

    # ── Interface methods (same as SimulStreamingOnlineProcessor) ──

    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
        """Buffer a new PCM chunk and advance the stream-end timestamp."""
        self.end = audio_stream_end_time
        self._pending_audio = np.append(self._pending_audio, audio)
        self.audio_buffer = self._pending_audio  # for logging compat

    def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
        """Run one encode/prefill/decode cycle; never raises (logs and returns empty)."""
        try:
            return self._process_iter_inner(is_last)
        except Exception as e:
            logger.warning(f"[voxtral] process_iter exception: {e}", exc_info=True)
            return [], self.end

    def _get_full_text(self) -> str:
        """Decode all accumulated token IDs at once for correct spacing."""
        if not self._output_token_ids:
            return ""
        sp = self.asr._tokenizer
        return sp.decode(self._output_token_ids, special_token_policy=self._special_token_policy)

    def get_buffer(self) -> Transcript:
        """Return all uncommitted text as buffer, including the pending _y token."""
        sp = self.asr._tokenizer
        ids = list(self._output_token_ids)
        # Temporarily include the not-yet-flushed last sampled token so the
        # displayed buffer is as fresh as possible.
        if self._y is not None and not self._y_flushed_to_output:
            token_id = self._y.item()
            if token_id != sp.eos_id:
                ids.append(token_id)
        if not ids:
            return Transcript(start=None, end=None, text="")
        full_text = sp.decode(ids, special_token_policy=self._special_token_policy)
        uncommitted = full_text.split()[self._n_committed_words:]
        if uncommitted:
            return Transcript(start=self.end, end=self.end, text=" ".join(uncommitted))
        return Transcript(start=None, end=None, text="")

    def start_silence(self) -> Tuple[List[ASRToken], float]:
        """Flush all uncommitted words when silence starts."""
        self._flush_last_y()  # include the pending _y token before flushing
        words = self._flush_all_pending_words()
        logger.info(f"[voxtral] start_silence: flushed {len(words)} words")
        return words, self.end

    def end_silence(self, silence_duration: float, offset: float):
        """Shift word timestamps and the stream end past a silence gap."""
        self._global_time_offset += silence_duration
        self.end += silence_duration

    def new_speaker(self, change_speaker):
        """On speaker change, flush pending words exactly like a silence start."""
        self.start_silence()

    def warmup(self, audio, init_prompt=""):
        """No-op for this backend."""
        pass

    def finish(self) -> Tuple[List[ASRToken], float]:
        """Flush remaining audio with right-padding to let the model finish decoding."""
        right_pad = np.zeros(
            N_RIGHT_PAD_TOKENS * self._samples_per_token, dtype=np.float32
        )
        self._pending_audio = np.append(self._pending_audio, right_pad)
        # NOTE(review): _encode_pending also counts these samples when it
        # consumes them, so the pad appears twice in _n_audio_samples_fed;
        # that inflates safe_total and lets the decoder run to the end of the
        # padded audio — confirm this is intentional.
        self._n_audio_samples_fed += len(right_pad)

        final_words, _ = self._process_iter_inner(is_last=True)
        # Flush the last pending self._y token (like voxmlx's finally block)
        self._flush_last_y()
        final_words.extend(self._flush_all_pending_words())
        return final_words, self.end

    # ── Word extraction ──

    def _pos_to_time(self, pos: int) -> float:
        """Convert a decode position to seconds relative to audio start."""
        SPT = self._samples_per_token
        return max(0.0, (pos - N_LEFT_PAD_TOKENS) * SPT / self.SAMPLING_RATE)

    def _flush_last_y(self):
        """Flush the last pending self._y token that hasn't been processed yet."""
        if self._y is None or self._y_flushed_to_output:
            return
        sp = self.asr._tokenizer
        token_id = self._y.item()
        if token_id != sp.eos_id:
            self._output_token_ids.append(token_id)
            self._token_positions.append(self._n_total_decoded)
        self._y_flushed_to_output = True

    def _commit_words(self, include_last: bool) -> List[ASRToken]:
        """
        Turn decoded-but-uncommitted words into timestamped ASRTokens.

        Timestamps are approximate: each word is assigned a token range
        proportional to its word index, and the endpoints are mapped through
        the recorded decode positions.

        Args:
            include_last: also commit the final word, which may still be
                growing while decoding continues.
        """
        if not self._output_token_ids:
            return []

        words = self._get_full_text().split()
        limit = len(words) if include_last else len(words) - 1

        new_words: List[ASRToken] = []
        n_tokens = len(self._output_token_ids)
        n_words_total = max(len(words), 1)
        last_idx = max(len(self._token_positions) - 1, 0)

        while self._n_committed_words < limit:
            word_idx = self._n_committed_words
            word = words[word_idx]

            # Approximate: assign token range proportionally
            tok_start = min(int(word_idx / n_words_total * n_tokens), last_idx)
            tok_end = min(int((word_idx + 1) / n_words_total * n_tokens), last_idx)

            if self._token_positions:
                start_time = self._pos_to_time(self._token_positions[tok_start]) + self._global_time_offset
                end_time = self._pos_to_time(self._token_positions[tok_end]) + self._global_time_offset
            else:
                start_time = self._global_time_offset
                end_time = self._global_time_offset

            # Prepend space to match Whisper convention (Segment.from_tokens joins with '')
            text = word if word_idx == 0 else " " + word
            new_words.append(ASRToken(start=start_time, end=end_time, text=text))
            self._n_committed_words += 1

        return new_words

    def _extract_new_words(self) -> List[ASRToken]:
        """Commit every complete word (all but the possibly-growing last one)."""
        return self._commit_words(include_last=False)

    def _flush_all_pending_words(self) -> List[ASRToken]:
        """Flush ALL words including the last, possibly partial, one."""
        return self._commit_words(include_last=True)

    # ── Core streaming logic ──

    def _encode_pending(self) -> None:
        """
        Feed whole-token multiples of buffered PCM through the incremental
        encoder and append the resulting embeddings to self._audio_embeds.

        On the very first cycle, N_LEFT_PAD_TOKENS worth of zeros are
        prepended before the audio.
        """
        import mlx.core as mx

        from voxmlx.audio import log_mel_spectrogram_step

        SPT = self._samples_per_token
        if len(self._pending_audio) < SPT:
            return

        n_feed = (len(self._pending_audio) // SPT) * SPT
        chunk = self._pending_audio[:n_feed]
        first = self._first_cycle
        if first:
            left_pad = np.zeros(N_LEFT_PAD_TOKENS * SPT, dtype=np.float32)
            chunk = np.concatenate([left_pad, chunk])
        self._pending_audio = self._pending_audio[n_feed:]
        self._n_audio_samples_fed += n_feed

        mel, self._audio_tail = log_mel_spectrogram_step(chunk, self._audio_tail)
        (
            new_embeds,
            self._conv1_tail,
            self._conv2_tail,
            self._encoder_cache,
            self._ds_buf,
        ) = self.asr.model.encode_step(
            mel,
            self._conv1_tail,
            self._conv2_tail,
            self._encoder_cache,
            self._ds_buf,
        )
        if new_embeds is not None:
            mx.eval(new_embeds)
            if self._audio_embeds is None:
                self._audio_embeds = new_embeds
            else:
                self._audio_embeds = mx.concatenate([self._audio_embeds, new_embeds])
            if first:
                logger.info(f"[voxtral] first encode: {new_embeds.shape[0]} embeds from {n_feed} samples")
        elif first:
            logger.info(f"[voxtral] first encode: no embeds from {n_feed} samples")
        if first:
            self._first_cycle = False

    def _process_iter_inner(self, is_last: bool) -> Tuple[List[ASRToken], float]:
        """Encode pending audio, prefill once, decode new positions, emit words."""
        import mlx.core as mx

        from voxmlx.cache import RotatingKVCache

        model = self.asr.model
        sp = self.asr._tokenizer
        prompt_tokens = self.asr._prompt_tokens
        prefix_len = len(prompt_tokens)
        SPT = self._samples_per_token

        # ── Phase 1: Encode new audio ──
        self._encode_pending()
        self.audio_buffer = self._pending_audio  # for logging compat

        if self._audio_embeds is None:
            return [], self.end

        # Safety: don't decode ahead of encoded audio
        safe_total = (
            N_LEFT_PAD_TOKENS + self._n_audio_samples_fed // SPT
        )
        n_decodable = min(
            self._audio_embeds.shape[0], safe_total - self._n_total_decoded
        )

        if n_decodable <= 0:
            return [], self.end

        # ── Phase 2: Prefill (once per utterance) ──
        if not self._prefilled:
            if self._n_total_decoded + self._audio_embeds.shape[0] < prefix_len:
                logger.info(
                    f"[voxtral] waiting for prefill: have {self._audio_embeds.shape[0]} embeds, need {prefix_len}"
                )
                return [], self.end

            n_layers = len(model.language_model.layers)
            self._decoder_cache = [RotatingKVCache(8192) for _ in range(n_layers)]

            self._t_cond = model.time_embedding(
                mx.array([self.asr._n_delay_tokens], dtype=mx.float32)
            )

            prompt_ids = mx.array([prompt_tokens])
            self._text_embeds = model.language_model.embed(prompt_ids)[0]

            # The decoder prompt is the text prompt embeddings summed with the
            # first prefix_len audio embeddings.
            prefix_embeds = (
                self._text_embeds + self._audio_embeds[:prefix_len]
            )[None, :, :]

            logits = model.decode(
                prefix_embeds, self._t_cond, "causal", self._decoder_cache
            )
            mx.eval(
                logits,
                *[x for c in self._decoder_cache for x in (c.keys, c.values)],
            )

            self._y = mx.argmax(logits[0, -1:], axis=-1).squeeze()
            mx.async_eval(self._y)

            self._audio_embeds = self._audio_embeds[prefix_len:]
            self._n_total_decoded = prefix_len
            self._prefilled = True
            logger.info(f"[voxtral] prefill done, first token y={self._y.item()}")

            n_decodable = min(
                self._audio_embeds.shape[0], safe_total - self._n_total_decoded
            )

            if n_decodable <= 0:
                return [], self.end

        # ── Phase 3: Decode new positions ──
        eos_id = sp.eos_id
        hit_eos = False
        n_consumed = 0

        for i in range(n_decodable):
            token_embed = model.language_model.embed(self._y.reshape(1, 1))[0, 0]
            step_embed = (self._audio_embeds[i] + token_embed)[None, None, :]
            logits = model.decode(
                step_embed, self._t_cond, mask=None, cache=self._decoder_cache
            )
            next_y = mx.argmax(logits[0, -1:], axis=-1).squeeze()
            mx.async_eval(next_y)

            token_id = self._y.item()
            n_consumed = i + 1

            if token_id == eos_id:
                hit_eos = True
                logger.info("[voxtral] hit EOS")
                break

            # Accumulate token ID — full-sequence decode produces correct spacing
            # Skip if this _y was already flushed by start_silence()
            if self._y_flushed_to_output:
                self._y_flushed_to_output = False
            else:
                self._output_token_ids.append(token_id)
                # Track position for timestamp estimation
                self._token_positions.append(self._n_total_decoded + i)

            if i > 0 and i % 256 == 0:
                mx.clear_cache()

            self._y = next_y

        self._n_total_decoded += n_consumed

        # Trim consumed embeddings
        if self._audio_embeds.shape[0] > n_consumed:
            self._audio_embeds = self._audio_embeds[n_consumed:]
        else:
            self._audio_embeds = None

        # Log decode results
        full_text = self._get_full_text()
        logger.info(
            f"[voxtral] decoded {n_consumed} tokens | "
            f"total_decoded={self._n_total_decoded} | "
            f"text='{full_text[-80:]}' | "
            f"n_words={len(full_text.split())} committed={self._n_committed_words}"
        )

        # Extract complete words from the decoded token sequence
        new_words = self._extract_new_words()

        if hit_eos:
            new_words.extend(self._flush_all_pending_words())
            self._reset_state()

        if new_words:
            logger.info(f"[voxtral] returning {len(new_words)} words: {[w.text for w in new_words]}")

        self.buffer = []
        return new_words, self.end
|
||||||
@@ -108,7 +108,7 @@ def available_models() -> List[str]:
|
|||||||
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
|
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
|
||||||
"""
|
"""
|
||||||
attempt to infer ModelDimensions from a HF style config.json located
|
attempt to infer ModelDimensions from a HF style config.json located
|
||||||
next to the given checkpoint, usefull for distilled models
|
next to the given checkpoint, usefull for distilled models/MLX models.
|
||||||
"""
|
"""
|
||||||
candidates = []
|
candidates = []
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
@@ -122,6 +122,25 @@ def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
|
|||||||
with open(candidate, "r", encoding="utf-8") as f:
|
with open(candidate, "r", encoding="utf-8") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
|
||||||
|
# native Whisper format
|
||||||
|
native_keys = ["n_mels", "n_audio_ctx", "n_audio_state", "n_audio_head",
|
||||||
|
"n_audio_layer", "n_vocab", "n_text_ctx", "n_text_state",
|
||||||
|
"n_text_head", "n_text_layer"]
|
||||||
|
if all(k in config for k in native_keys):
|
||||||
|
return ModelDimensions(
|
||||||
|
n_mels=config["n_mels"],
|
||||||
|
n_audio_ctx=config["n_audio_ctx"],
|
||||||
|
n_audio_state=config["n_audio_state"],
|
||||||
|
n_audio_head=config["n_audio_head"],
|
||||||
|
n_audio_layer=config["n_audio_layer"],
|
||||||
|
n_vocab=config["n_vocab"],
|
||||||
|
n_text_ctx=config["n_text_ctx"],
|
||||||
|
n_text_state=config["n_text_state"],
|
||||||
|
n_text_head=config["n_text_head"],
|
||||||
|
n_text_layer=config["n_text_layer"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# HuggingFace format
|
||||||
try:
|
try:
|
||||||
return ModelDimensions(
|
return ModelDimensions(
|
||||||
n_mels=config["num_mel_bins"],
|
n_mels=config["num_mel_bins"],
|
||||||
@@ -236,6 +255,24 @@ def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, tor
|
|||||||
return converted if converted else state_dict
|
return converted if converted else state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_mlx_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Converts an mlx whisper checkpoint to a default openai whisper one
|
||||||
|
"""
|
||||||
|
if not any("mlp1" in k or "mlp2" in k for k in state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
converted = {}
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if key == "alignment_heads":
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_key = key.replace(".mlp1.", ".mlp.0.").replace(".mlp2.", ".mlp.2.")
|
||||||
|
converted[new_key] = value
|
||||||
|
|
||||||
|
return converted
|
||||||
|
|
||||||
|
|
||||||
def _load_lora_state(lora_path: str):
|
def _load_lora_state(lora_path: str):
|
||||||
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
|
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
|
||||||
bin_path = os.path.join(lora_path, "adapter_model.bin")
|
bin_path = os.path.join(lora_path, "adapter_model.bin")
|
||||||
@@ -520,7 +557,12 @@ def load_model(
|
|||||||
state_dict = checkpoint["model_state_dict"]
|
state_dict = checkpoint["model_state_dict"]
|
||||||
else:
|
else:
|
||||||
state_dict = checkpoint
|
state_dict = checkpoint
|
||||||
|
|
||||||
|
if alignment_heads is None and "alignment_heads" in state_dict:
|
||||||
|
alignment_heads = state_dict["alignment_heads"]
|
||||||
|
|
||||||
state_dict = _convert_hf_state_dict(state_dict)
|
state_dict = _convert_hf_state_dict(state_dict)
|
||||||
|
state_dict = _convert_mlx_state_dict(state_dict)
|
||||||
_apply_lora_adapter(state_dict, lora_path)
|
_apply_lora_adapter(state_dict, lora_path)
|
||||||
|
|
||||||
if dims_cfg is not None:
|
if dims_cfg is not None:
|
||||||
@@ -546,8 +588,13 @@ def load_model(
|
|||||||
model.load_state_dict(state_dict)
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
if alignment_heads is not None:
|
if alignment_heads is not None:
|
||||||
model.set_alignment_heads(alignment_heads)
|
if isinstance(alignment_heads, bytes):
|
||||||
|
model.set_alignment_heads(alignment_heads)
|
||||||
|
elif isinstance(alignment_heads, torch.Tensor): #for mlx whisper
|
||||||
|
mask = torch.zeros(dims.n_text_layer, dims.n_text_head, dtype=torch.bool)
|
||||||
|
for layer, head in alignment_heads.tolist():
|
||||||
|
mask[layer, head] = True
|
||||||
|
model.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
|
||||||
return model.to(device)
|
return model.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user