faster-whisper as an optional encoder alternative for simulstreaming

This commit is contained in:
Quentin Fuxa
2025-08-30 23:50:16 +02:00
parent 1d926f2e67
commit 199e21b3ef
2 changed files with 95 additions and 3 deletions

DEV_NOTES.md (new file, 64 lines)

@@ -0,0 +1,64 @@
# 1. Simulstreaming: Decouple the encoder for faster inference
Encoder-time experiments for Simulstreaming (`whisperlivekit/simul_whisper/simul_whisper.py`, l. 397):
On macOS, Apple Silicon M4:
| Encoder | base.en | small |
|--------|---------|-------|
| WHISPER (no modification) | 0.35s | 1.09s |
| FASTER_WHISPER | 0.4s | 1.20s |
| MLX_WHISPER | 0.07s | 0.20s |
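Timings like those above can be collected with a simple wall-clock harness around the encoder call. A minimal sketch, in which `time_encoder` and the matmul stand-in are illustrative placeholders, not the project's API:

```python
import time

import numpy as np

def time_encoder(encode, mel, n_runs=5):
    """Average wall-clock time of an encoder forward pass over n_runs."""
    # One warm-up call so lazy initialization does not skew the measurement.
    encode(mel)
    start = time.perf_counter()
    for _ in range(n_runs):
        encode(mel)
    return (time.perf_counter() - start) / n_runs

# Placeholder "encoder": a matmul standing in for the real Whisper encoder.
weights = np.random.rand(80, 384)
mel = np.random.rand(3000, 80)   # ~30 s of audio: 100 frames/s, 80 mel bins
avg = time_encoder(lambda m: m @ weights, mel)
print(f"avg encode time: {avg:.4f}s")
```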
# 2. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm
Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.
## Problem Statement
- Input: `self.total_preds` with shape `(x, x, 4)`: predictions for 4 speakers
- Output: constrained predictions with shape `(x, x, 2)`: predictions for 2 speakers
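At the shape level, the goal is just a choice of 2 of the 4 speaker columns per step. A toy sketch, where the fixed `[0, 1]` selection stands in for the dynamic mapping derived below:

```python
import numpy as np

# Toy predictions: (time, frames, 4) scores for 4 speaker slots.
total_preds = np.random.rand(6, 10, 4)

# A 4-to-2 constraint selects 2 of the 4 speaker columns; a fixed [0, 1]
# selection stands in here for the dynamic, time-evolving mapping.
mapping = [0, 1]
constrained = total_preds[:, :, mapping]

print(constrained.shape)  # -> (6, 10, 2)
```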
### Initial Setup
For each time step `i`, we have a ranking of 4 speaker predictions (1-4). When only 2 speakers are present, the model will have close predictions for the 2 active speaker positions.
Instead of `np.argmax(preds_np, axis=1)`, we take the top 2 predictions and build a dynamic 4→2 mapping that can evolve over time.
### Algorithm
```python
top_2_speakers = np.argsort(preds_np, axis=1)[:, -2:]
```
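For example, on a toy `preds_np` of shape `(T, 4)` (values here are illustrative): `argsort` is ascending, so the last two columns hold the indices of the two highest-scoring speaker slots.

```python
import numpy as np

preds_np = np.array([
    [0.1, 0.8, 0.7, 0.0],   # slots 1 and 2 score close: the 2 active speakers
    [0.9, 0.1, 0.6, 0.2],
])

# [:, -1] is the top detected speaker, [:, -2] the runner-up.
top_2_speakers = np.argsort(preds_np, axis=1)[:, -2:]
print(top_2_speakers)  # -> [[2 1]
                       #     [2 0]]
```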
- `DS_a_{i}`: Top detected speaker for prediction i
- `DS_b_{i}`: Second detected speaker for prediction i
- `AS_{i}`: Attributed speaker for prediction i
- `GTS_A`: Ground truth speaker A
- `GTS_B`: Ground truth speaker B
- `DIST(a, b)`: Distance between detected speakers a and b
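`DIST` is not pinned down in the note; one plausible choice, sketched here as an assumption, is the absolute gap between the prediction scores of the two detections:

```python
import numpy as np

def dist(preds_np, i, j, spk_i, spk_j):
    """Distance between detected speaker spk_i at step i and spk_j at step j.
    Assumed definition (the note leaves DIST unspecified): absolute
    difference of the two prediction scores."""
    return abs(preds_np[i, spk_i] - preds_np[j, spk_j])

preds_np = np.array([
    [0.1, 0.8, 0.7, 0.0],
    [0.2, 0.7, 0.1, 0.0],
])
print(dist(preds_np, 0, 1, 1, 1))  # |0.8 - 0.7| -> ~0.1
```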
### Attribution Logic
```
AS_0 ← A
AS_1 ← B
IF DIST(DS_a_0, DS_a_1) < DIST(DS_a_0, DS_a_2) AND
   DIST(DS_a_0, DS_a_1) < DIST(DS_a_1, DS_a_2):
    # Likely that DS_a_0 = DS_a_1 (same speaker)
    AS_1 ← A
    AS_2 ← B
ELIF DIST(DS_a_0, DS_a_2) < DIST(DS_a_0, DS_a_1) AND
     DIST(DS_a_0, DS_a_2) < DIST(DS_a_1, DS_a_2):
    AS_2 ← A
ELSE:
    AS_2 ← B
# to finish
```
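The pseudocode above (still marked "to finish" in the note) can be sketched in Python. The score list, the score-gap `dist`, and the final `ELSE` branch behavior are all assumptions, not the project's implementation:

```python
def attribute_speakers(scores):
    """Attribute labels A/B to the first three detections, following the
    attribution pseudocode above. scores: per-step score of the top
    detected speaker; DIST is assumed to be the absolute score gap."""
    def dist(i, j):
        # Assumed DIST: absolute score difference between two detections.
        return abs(scores[i] - scores[j])

    attribution = {0: "A", 1: "B"}
    if dist(0, 1) < dist(0, 2) and dist(0, 1) < dist(1, 2):
        # Steps 0 and 1 likely come from the same speaker.
        attribution[1] = "A"
        attribution[2] = "B"
    elif dist(0, 2) < dist(0, 1) and dist(0, 2) < dist(1, 2):
        attribution[2] = "A"
    else:
        attribution[2] = "B"
    return attribution

print(attribute_speakers([0.80, 0.78, 0.60]))  # -> {0: 'A', 1: 'A', 2: 'B'}
```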

whisperlivekit/simul_whisper/simul_whisper.py
@@ -35,6 +35,17 @@ try:
except ImportError:
    HAS_MLX_WHISPER = False
try:
    from faster_whisper import WhisperModel
    from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
    from faster_whisper.feature_extractor import FeatureExtractor
    HAS_FASTER_WHISPER = True
except ImportError:
    HAS_FASTER_WHISPER = False

# HAS_MLX_WHISPER = False
HAS_FASTER_WHISPER = False  # Time to determine whether it's really faster
# New features added to the original version of Simul-Whisper:
# - large-v3 model support
@@ -56,6 +67,14 @@ class PaddedAlignAttWhisper:
            print('Simulstreaming will use MLX whisper for a faster encoder.')
            mlx_model_name = model_mapping[model_name]
            self.mlx_model = load_models.load_model(path_or_hf_repo=mlx_model_name)
        elif HAS_FASTER_WHISPER:
            print('Simulstreaming will use Faster Whisper for the encoder.')
            self.fw_model = WhisperModel(
                model_name,
                device='auto',
                compute_type='auto',
            )
            self.feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
        logger.info(f"Model dimensions: {self.model.dims}")
@@ -375,7 +394,7 @@ class PaddedAlignAttWhisper:
        input_segments = self.segments[0]
        # NEW: we can use a different encoder before using standard whisper for cross-attention with the hooks on the decoder
        beg_encode = time()
        if HAS_MLX_WHISPER:
            mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
            mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
@@ -383,6 +402,15 @@ class PaddedAlignAttWhisper:
            encoder_feature = torch.tensor(np.array(mlx_encoder_feature))
            content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
            device = 'cpu'
        elif HAS_FASTER_WHISPER:
            audio_length_seconds = len(input_segments) / 16000
            content_mel_len = int(audio_length_seconds * 100) // 2
            # padded_audio = pad_or_trim(input_segments.detach(), N_SAMPLES)
            mel_padded_2 = self.feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
            mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
            encoder_feature_ctranslate = self.fw_model.encode(mel)
            encoder_feature = torch.Tensor(np.array(encoder_feature_ctranslate))
            device = 'cpu'
        else:
            # mel + padding to 30s
            mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
@@ -393,8 +421,8 @@ class PaddedAlignAttWhisper:
            content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
            encoder_feature = self.model.encoder(mel)
            device = mel.device
        end_encode = time()
        # print('Encoder duration:', end_encode-beg_encode)