0.2.8 : only the decoder of whisper is loaded in memory when a different encoder is used

This commit is contained in:
Quentin Fuxa
2025-09-02 21:12:25 +02:00
parent 50b0527858
commit 3bd2122eb4
11 changed files with 112 additions and 47 deletions

View File

@@ -9,7 +9,7 @@
<p align="center"> <p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a> <a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a> <a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a> <a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a> <a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p> </p>
@@ -67,10 +67,10 @@ pip install whisperlivekit
| Optional | `pip install` | | Optional | `pip install` |
|-----------|-------------| |-----------|-------------|
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` | | **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Speaker diarization with Diart | `diart` | | **Apple Silicon optimized backend** | `mlx-whisper` |
| Original Whisper backend | `whisper` | | *[Not recommended]* Speaker diarization with Diart | `diart` |
| Improved timestamps backend | `whisper-timestamped` | | *[Not recommended]* Original Whisper backend | `whisper` |
| Apple Silicon optimization backend | `mlx-whisper` | | *[Not recommended]* Improved timestamps backend | `whisper-timestamped` |
| OpenAI API backend | `openai` | | OpenAI API backend | `openai` |
See **Parameters & Configuration** below on how to use them. See **Parameters & Configuration** below on how to use them.
@@ -138,6 +138,7 @@ An important list of parameters can be changed. But what *should* you change?
- the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. - the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English.
- the `--backend` ? you can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements. - the `--backend` ? you can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements.
- `--warmup-file`, if you have one - `--warmup-file`, if you have one
- `--task translate`, to translate into English
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server - `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
- `--diarization`, if you want to use it. - `--diarization`, if you want to use it.
@@ -159,14 +160,9 @@ The rest I don't recommend. But below are your options.
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` | | `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| WhisperStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
| SimulStreaming backend options | Description | Default | | SimulStreaming backend options | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` | | `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` | | `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` | | `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
@@ -180,6 +176,12 @@ The rest I don't recommend. But below are your options.
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` | | `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
| `--preloaded-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` | | `--preloaded-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
| WhisperStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
| Diarization options | Description | Default | | Diarization options | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--diarization` | Enable speaker identification | `False` | | `--diarization` | Enable speaker identification | `False` |

Binary file not shown.

Before

Width:  |  Height:  |  Size: 367 KiB

After

Width:  |  Height:  |  Size: 368 KiB

View File

@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "whisperlivekit" name = "whisperlivekit"
version = "0.2.7" version = "0.2.8"
description = "Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization" description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md" readme = "README.md"
authors = [ authors = [
{ name = "Quentin Fuxa" } { name = "Quentin Fuxa" }
@@ -18,6 +18,11 @@ classifiers = [
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: 3.15",
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Multimedia :: Sound/Audio :: Speech" "Topic :: Multimedia :: Sound/Audio :: Speech"
] ]
@@ -28,7 +33,8 @@ dependencies = [
"faster-whisper", "faster-whisper",
"uvicorn", "uvicorn",
"websockets", "websockets",
"torch", "torchaudio>=2.0.0",
"torch>=2.0.0",
"tqdm", "tqdm",
"tiktoken", "tiktoken",
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")' 'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'

View File

@@ -19,6 +19,15 @@ transcription_engine = None
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
#to remove after 0.2.8
if args.backend == "simulstreaming" and not args.disable_fast_encoder:
logger.warning(f"""
{'='*50}
WhisperLiveKit 0.2.8 has introduced a new fast encoder feature using MLX Whisper or Faster Whisper for improved speed. Use --disable-fast-encoder to disable if you encounter issues.
{'='*50}
""")
global transcription_engine global transcription_engine
transcription_engine = TranscriptionEngine( transcription_engine = TranscriptionEngine(
**vars(args), **vars(args),

View File

@@ -46,6 +46,7 @@ class TranscriptionEngine:
"confidence_validation": False, "confidence_validation": False,
"buffer_trimming_sec": 15, "buffer_trimming_sec": 15,
# simulstreaming params: # simulstreaming params:
"disable_fast_encoder": False,
"frame_threshold": 25, "frame_threshold": 25,
"beams": 1, "beams": 1,
"decoder_type": None, "decoder_type": None,
@@ -60,7 +61,7 @@ class TranscriptionEngine:
"diarization_backend": "sortformer", "diarization_backend": "sortformer",
# diart params: # diart params:
"segmentation_model": "pyannote/segmentation-3.0", "segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding", "embedding_model": "pyannote/embedding",
} }
config_dict = {**defaults, **kwargs} config_dict = {**defaults, **kwargs}
@@ -97,7 +98,7 @@ class TranscriptionEngine:
simulstreaming_kwargs = {} simulstreaming_kwargs = {}
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len', for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt', 'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count']: 'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count', 'disable_fast_encoder']:
if hasattr(self.args, attr): if hasattr(self.args, attr):
simulstreaming_kwargs[attr] = getattr(self.args, attr) simulstreaming_kwargs[attr] = getattr(self.args, attr)

View File

@@ -161,6 +161,14 @@ def parse_args():
# SimulStreaming-specific arguments # SimulStreaming-specific arguments
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)') simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
simulstreaming_group.add_argument(
"--disable-fast-encoder",
action="store_true",
default=False,
dest="disable_fast_encoder",
help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited",
)
simulstreaming_group.add_argument( simulstreaming_group.add_argument(
"--frame-threshold", "--frame-threshold",

View File

@@ -244,7 +244,8 @@ class SimulStreamingASR():
self.max_context_tokens = kwargs.get('max_context_tokens', None) self.max_context_tokens = kwargs.get('max_context_tokens', None)
self.warmup_file = kwargs.get('warmup_file', None) self.warmup_file = kwargs.get('warmup_file', None)
self.preload_model_count = kwargs.get('preload_model_count', 1) self.preload_model_count = kwargs.get('preload_model_count', 1)
self.disable_fast_encoder = kwargs.get('disable_fast_encoder', False)
self.fast_encoder = False
if model_dir is not None: if model_dir is not None:
self.model_path = model_dir self.model_path = model_dir
elif modelsize is not None: elif modelsize is not None:
@@ -289,25 +290,44 @@ class SimulStreamingASR():
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "") self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path)) self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
self.models = [self.load_model() for i in range(self.preload_model_count)]
self.mlx_encoder, self.fw_encoder = None, None self.mlx_encoder, self.fw_encoder = None, None
if HAS_MLX_WHISPER: if not self.disable_fast_encoder:
print('Simulstreaming will use MLX whisper for a faster encoder.') if HAS_MLX_WHISPER:
mlx_model_name = mlx_model_mapping[self.model_name] print('Simulstreaming will use MLX whisper for a faster encoder.')
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name) mlx_model_name = mlx_model_mapping[self.model_name]
elif HAS_FASTER_WHISPER: self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
print('Simulstreaming will use Faster Whisper for the encoder.') self.fast_encoder = True
self.fw_encoder = WhisperModel( elif HAS_FASTER_WHISPER:
self.model_name, print('Simulstreaming will use Faster Whisper for the encoder.')
device='auto', self.fw_encoder = WhisperModel(
compute_type='auto', self.model_name,
) device='auto',
compute_type='auto',
)
self.fast_encoder = True
self.models = [self.load_model() for i in range(self.preload_model_count)]
def load_model(self): def load_model(self):
whisper_model = load_model(name=self.model_name, download_root=self.model_path) whisper_model = load_model(name=self.model_name, download_root=self.model_path, decoder_only=self.fast_encoder)
warmup_audio = load_file(self.warmup_file) warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None) if warmup_audio is not None:
warmup_audio = torch.from_numpy(warmup_audio).float()
if self.fast_encoder:
temp_model = PaddedAlignAttWhisper(
cfg=self.cfg,
loaded_model=whisper_model,
mlx_encoder=self.mlx_encoder,
fw_encoder=self.fw_encoder,
)
temp_model.warmup(warmup_audio)
temp_model.remove_hooks()
else:
# For standard encoder, use the original transcribe warmup
warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None)
return whisper_model return whisper_model
def get_new_model_instance(self): def get_new_model_instance(self):

View File

@@ -64,7 +64,7 @@ class PaddedAlignAttWhisper:
self.mlx_encoder = mlx_encoder self.mlx_encoder = mlx_encoder
self.fw_encoder = fw_encoder self.fw_encoder = fw_encoder
if HAS_FASTER_WHISPER: if fw_encoder:
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels) self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
logger.info(f"Model dimensions: {self.model.dims}") logger.info(f"Model dimensions: {self.model.dims}")
@@ -176,6 +176,15 @@ class PaddedAlignAttWhisper:
for hook in self.l_hooks: for hook in self.l_hooks:
hook.remove() hook.remove()
def warmup(self, audio):
try:
self.insert_audio(audio)
self.infer(is_last=True)
self.refresh_segment(complete=True)
logger.info("Model warmed up successfully")
except Exception as e:
logger.exception(f"Model warmup failed: {e}")
def create_tokenizer(self, language=None): def create_tokenizer(self, language=None):
self.tokenizer = tokenizer.get_tokenizer( self.tokenizer = tokenizer.get_tokenizer(
multilingual=self.tokenizer_is_multilingual, multilingual=self.tokenizer_is_multilingual,
@@ -386,14 +395,14 @@ class PaddedAlignAttWhisper:
# NEW : we can use a different encoder, before using standard Whisper for cross attention with the hooks on the decoder # NEW : we can use a different encoder, before using standard Whisper for cross attention with the hooks on the decoder
beg_encode = time() beg_encode = time()
if HAS_MLX_WHISPER: if self.mlx_encoder:
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES) mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2) mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None]) mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
encoder_feature = torch.tensor(np.array(mlx_encoder_feature)) encoder_feature = torch.tensor(np.array(mlx_encoder_feature))
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2) content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
device = 'cpu' device = 'cpu'
elif HAS_FASTER_WHISPER: elif self.fw_encoder:
audio_length_seconds = len(input_segments) / 16000 audio_length_seconds = len(input_segments) / 16000
content_mel_len = int(audio_length_seconds * 100)//2 content_mel_len = int(audio_length_seconds * 100)//2
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :] mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]

View File

@@ -105,6 +105,7 @@ def load_model(
device: Optional[Union[str, torch.device]] = None, device: Optional[Union[str, torch.device]] = None,
download_root: str = None, download_root: str = None,
in_memory: bool = False, in_memory: bool = False,
decoder_only=False
) -> Whisper: ) -> Whisper:
""" """
Load a Whisper ASR model Load a Whisper ASR model
@@ -151,7 +152,14 @@ def load_model(
del checkpoint_file del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"]) dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims) model = Whisper(dims, decoder_only=decoder_only)
if decoder_only:
checkpoint["model_state_dict"] = {
k: v for k, v in checkpoint["model_state_dict"].items()
if 'encoder' not in k
}
model.load_state_dict(checkpoint["model_state_dict"]) model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None: if alignment_heads is not None:

View File

@@ -253,16 +253,18 @@ class TextDecoder(nn.Module):
class Whisper(nn.Module): class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions): def __init__(self, dims: ModelDimensions, decoder_only: bool = False):
super().__init__() super().__init__()
self.dims = dims self.dims = dims
self.encoder = AudioEncoder(
self.dims.n_mels, if not decoder_only:
self.dims.n_audio_ctx, self.encoder = AudioEncoder(
self.dims.n_audio_state, self.dims.n_mels,
self.dims.n_audio_head, self.dims.n_audio_ctx,
self.dims.n_audio_layer, self.dims.n_audio_state,
) self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder( self.decoder = TextDecoder(
self.dims.n_vocab, self.dims.n_vocab,
self.dims.n_text_ctx, self.dims.n_text_ctx,

View File

@@ -31,21 +31,21 @@ def load_file(warmup_file=None, timeout=5):
logger.debug(f"Download successful in {time.time() - start_time:.2f}s") logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
except (urllib.error.URLError, socket.timeout) as e: except (urllib.error.URLError, socket.timeout) as e:
logger.warning(f"Download failed: {e}. Proceeding without warmup.") logger.warning(f"Download failed: {e}. Proceeding without warmup.")
return False return None
finally: finally:
socket.setdefaulttimeout(original_timeout) socket.setdefaulttimeout(original_timeout)
elif not warmup_file: elif not warmup_file:
return False return None
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0: if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
logger.warning(f"Warmup file {warmup_file} invalid or missing.") logger.warning(f"Warmup file {warmup_file} invalid or missing.")
return False return None
try: try:
audio, sr = librosa.load(warmup_file, sr=16000) audio, sr = librosa.load(warmup_file, sr=16000)
except Exception as e: except Exception as e:
logger.warning(f"Failed to load audio file: {e}") logger.warning(f"Failed to load audio file: {e}")
return False return None
return audio return audio
def warmup_asr(asr, warmup_file=None, timeout=5): def warmup_asr(asr, warmup_file=None, timeout=5):