0.2.8 : only the decoder of whisper is loaded in memory when a different encoder is used

This commit is contained in:
Quentin Fuxa
2025-09-02 21:12:25 +02:00
parent 50b0527858
commit 3bd2122eb4
11 changed files with 112 additions and 47 deletions

View File

@@ -9,7 +9,7 @@
<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>
@@ -67,10 +67,10 @@ pip install whisperlivekit
| Optional | `pip install` |
|-----------|-------------|
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Speaker diarization with Diart | `diart` |
| Original Whisper backend | `whisper` |
| Improved timestamps backend | `whisper-timestamped` |
| Apple Silicon optimization backend | `mlx-whisper` |
| **Apple Silicon optimized backend** | `mlx-whisper` |
| *[Not recommended]* Speaker diarization with Diart | `diart` |
| *[Not recommended]* Original Whisper backend | `whisper` |
| *[Not recommended]* Improved timestamps backend | `whisper-timestamped` |
| OpenAI API backend | `openai` |
See **Parameters & Configuration** below on how to use them.
@@ -138,6 +138,7 @@ An important list of parameters can be changed. But what *should* you change?
- the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English.
- the `--backend` ? you can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements.
- `--warmup-file`, if you have one
- `--task translate`, to translate into English
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
- `--diarization`, if you want to use it.
@@ -159,14 +160,9 @@ The rest I don't recommend. But below are your options.
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
| WhisperStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
| SimulStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
@@ -180,6 +176,12 @@ The rest I don't recommend. But below are your options.
| `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
| `--preloaded-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
| WhisperStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
| Diarization options | Description | Default |
|-----------|-------------|---------|
| `--diarization` | Enable speaker identification | `False` |

Binary file not shown.

Before

Width:  |  Height:  |  Size: 367 KiB

After

Width:  |  Height:  |  Size: 368 KiB

View File

@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
[project]
name = "whisperlivekit"
version = "0.2.7"
description = "Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization"
version = "0.2.8"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [
{ name = "Quentin Fuxa" }
@@ -18,6 +18,11 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: 3.15",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Multimedia :: Sound/Audio :: Speech"
]
@@ -28,7 +33,8 @@ dependencies = [
"faster-whisper",
"uvicorn",
"websockets",
"torch",
"torchaudio>=2.0.0",
"torch>=2.0.0",
"tqdm",
"tiktoken",
'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'

View File

@@ -19,6 +19,15 @@ transcription_engine = None
@asynccontextmanager
async def lifespan(app: FastAPI):
#to remove after 0.2.8
if args.backend == "simulstreaming" and not args.disable_fast_encoder:
logger.warning(f"""
{'='*50}
WhisperLiveKit 0.2.8 has introduced a new fast encoder feature using MLX Whisper or Faster Whisper for improved speed. Use --disable-fast-encoder to disable if you encounter issues.
{'='*50}
""")
global transcription_engine
transcription_engine = TranscriptionEngine(
**vars(args),

View File

@@ -46,6 +46,7 @@ class TranscriptionEngine:
"confidence_validation": False,
"buffer_trimming_sec": 15,
# simulstreaming params:
"disable_fast_encoder": False,
"frame_threshold": 25,
"beams": 1,
"decoder_type": None,
@@ -60,7 +61,7 @@ class TranscriptionEngine:
"diarization_backend": "sortformer",
# diart params:
"segmentation_model": "pyannote/segmentation-3.0",
"embedding_model": "pyannote/embedding",
"embedding_model": "pyannote/embedding",
}
config_dict = {**defaults, **kwargs}
@@ -97,7 +98,7 @@ class TranscriptionEngine:
simulstreaming_kwargs = {}
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count']:
'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count', 'disable_fast_encoder']:
if hasattr(self.args, attr):
simulstreaming_kwargs[attr] = getattr(self.args, attr)

View File

@@ -161,6 +161,14 @@ def parse_args():
# SimulStreaming-specific arguments
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
simulstreaming_group.add_argument(
"--disable-fast-encoder",
action="store_true",
default=False,
dest="disable_fast_encoder",
help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited",
)
simulstreaming_group.add_argument(
"--frame-threshold",

View File

@@ -244,7 +244,8 @@ class SimulStreamingASR():
self.max_context_tokens = kwargs.get('max_context_tokens', None)
self.warmup_file = kwargs.get('warmup_file', None)
self.preload_model_count = kwargs.get('preload_model_count', 1)
self.disable_fast_encoder = kwargs.get('disable_fast_encoder', False)
self.fast_encoder = False
if model_dir is not None:
self.model_path = model_dir
elif modelsize is not None:
@@ -289,25 +290,44 @@ class SimulStreamingASR():
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
self.models = [self.load_model() for i in range(self.preload_model_count)]
self.mlx_encoder, self.fw_encoder = None, None
if HAS_MLX_WHISPER:
print('Simulstreaming will use MLX whisper for a faster encoder.')
mlx_model_name = mlx_model_mapping[self.model_name]
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
elif HAS_FASTER_WHISPER:
print('Simulstreaming will use Faster Whisper for the encoder.')
self.fw_encoder = WhisperModel(
self.model_name,
device='auto',
compute_type='auto',
)
if not self.disable_fast_encoder:
if HAS_MLX_WHISPER:
print('Simulstreaming will use MLX whisper for a faster encoder.')
mlx_model_name = mlx_model_mapping[self.model_name]
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
self.fast_encoder = True
elif HAS_FASTER_WHISPER:
print('Simulstreaming will use Faster Whisper for the encoder.')
self.fw_encoder = WhisperModel(
self.model_name,
device='auto',
compute_type='auto',
)
self.fast_encoder = True
self.models = [self.load_model() for i in range(self.preload_model_count)]
def load_model(self):
whisper_model = load_model(name=self.model_name, download_root=self.model_path)
whisper_model = load_model(name=self.model_name, download_root=self.model_path, decoder_only=self.fast_encoder)
warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None)
if warmup_audio is not None:
warmup_audio = torch.from_numpy(warmup_audio).float()
if self.fast_encoder:
temp_model = PaddedAlignAttWhisper(
cfg=self.cfg,
loaded_model=whisper_model,
mlx_encoder=self.mlx_encoder,
fw_encoder=self.fw_encoder,
)
temp_model.warmup(warmup_audio)
temp_model.remove_hooks()
else:
# For standard encoder, use the original transcribe warmup
warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None)
return whisper_model
def get_new_model_instance(self):

View File

@@ -64,7 +64,7 @@ class PaddedAlignAttWhisper:
self.mlx_encoder = mlx_encoder
self.fw_encoder = fw_encoder
if HAS_FASTER_WHISPER:
if fw_encoder:
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
logger.info(f"Model dimensions: {self.model.dims}")
@@ -176,6 +176,15 @@ class PaddedAlignAttWhisper:
for hook in self.l_hooks:
hook.remove()
def warmup(self, audio):
try:
self.insert_audio(audio)
self.infer(is_last=True)
self.refresh_segment(complete=True)
logger.info("Model warmed up successfully")
except Exception as e:
logger.exception(f"Model warmup failed: {e}")
def create_tokenizer(self, language=None):
self.tokenizer = tokenizer.get_tokenizer(
multilingual=self.tokenizer_is_multilingual,
@@ -386,14 +395,14 @@ class PaddedAlignAttWhisper:
# NEW : we can use a different encoder, before using standard whisper for cross attention with the hooks on the decoder
beg_encode = time()
if HAS_MLX_WHISPER:
if self.mlx_encoder:
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
encoder_feature = torch.tensor(np.array(mlx_encoder_feature))
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
device = 'cpu'
elif HAS_FASTER_WHISPER:
elif self.fw_encoder:
audio_length_seconds = len(input_segments) / 16000
content_mel_len = int(audio_length_seconds * 100)//2
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]

View File

@@ -105,6 +105,7 @@ def load_model(
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
decoder_only=False
) -> Whisper:
"""
Load a Whisper ASR model
@@ -151,7 +152,14 @@ def load_model(
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)
model = Whisper(dims, decoder_only=decoder_only)
if decoder_only:
checkpoint["model_state_dict"] = {
k: v for k, v in checkpoint["model_state_dict"].items()
if 'encoder' not in k
}
model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None:

View File

@@ -253,16 +253,18 @@ class TextDecoder(nn.Module):
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions):
def __init__(self, dims: ModelDimensions, decoder_only: bool = False):
super().__init__()
self.dims = dims
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
if not decoder_only:
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
self.dims.n_text_ctx,

View File

@@ -31,21 +31,21 @@ def load_file(warmup_file=None, timeout=5):
logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
except (urllib.error.URLError, socket.timeout) as e:
logger.warning(f"Download failed: {e}. Proceeding without warmup.")
return False
return None
finally:
socket.setdefaulttimeout(original_timeout)
elif not warmup_file:
return False
return None
if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
logger.warning(f"Warmup file {warmup_file} invalid or missing.")
return False
return None
try:
audio, sr = librosa.load(warmup_file, sr=16000)
except Exception as e:
logger.warning(f"Failed to load audio file: {e}")
return False
return None
return audio
def warmup_asr(asr, warmup_file=None, timeout=5):