From 199e21b3ef4215738497505e4fc6fcd470089ed6 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Sat, 30 Aug 2025 23:50:16 +0200 Subject: [PATCH] faster-whisper as an optional encoder alternative for simulstreaming --- DEV_NOTES.md | 64 +++++++++++++++++++ whisperlivekit/simul_whisper/simul_whisper.py | 34 +++++++++- 2 files changed, 95 insertions(+), 3 deletions(-) create mode 100644 DEV_NOTES.md diff --git a/DEV_NOTES.md b/DEV_NOTES.md new file mode 100644 index 0000000..9bf286d --- /dev/null +++ b/DEV_NOTES.md @@ -0,0 +1,64 @@ +# 1. Simulstreaming: Decouple the encoder for faster inference + +Simulstreaming encoder time (whisperlivekit/simul_whisper/simul_whisper.py l. 397) experiments: + +On macOS Apple Silicon M4: + +| Encoder | base.en | small | +|--------|---------|-------| +| WHISPER (no modification) | 0.35s | 1.09s | +| FASTER_WHISPER | 0.4s | 1.20s | +| MLX_WHISPER | 0.07s | 0.20s | + + + + +# 2. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm + +Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions. + +## Problem Statement +- Input: `self.total_preds` with shape `(x, x, 4)` - predictions for 4 speakers +- Output: Constrained predictions with shape `(x, x, 2)` - predictions for 2 speakers + +# +### Initial Setup +For each time step `i`, we have a ranking of 4 speaker predictions (1-4). When only 2 speakers are present, the model will have close predictions for the 2 active speaker positions. + +Instead of `np.argmax(preds_np, axis=1)`, we take the top 2 predictions and build a dynamic 4→2 mapping that can evolve over time.
+ +### Algorithm + +```python +top_2_speakers = np.argsort(preds_np, axis=1)[:, -2:] +``` + +- `DS_a_{i}`: Top detected speaker for prediction i +- `DS_b_{i}`: Second detected speaker for prediction i +- `AS_{i}`: Attributed speaker for prediction i +- `GTS_A`: Ground truth speaker A +- `GTS_B`: Ground truth speaker B +- `DIST(a, b)`: Distance between detected speakers a and b + +3. **Attribution Logic** + +``` +AS_0 ← A + +AS_1 ← B + +IF DIST(DS_a_0, DS_a_1) < DIST(DS_a_0, DS_a_2) AND + DIST(DS_a_0, DS_a_1) < DIST(DS_a_1, DS_a_2): + # Likely that DS_a_0 = DS_a_1 (same speaker) + AS_1 ← A + AS_2 ← B + +ELIF DIST(DS_a_0, DS_a_2) < DIST(DS_a_0, DS_a_1) AND + DIST(DS_a_0, DS_a_2) < DIST(DS_a_1, DS_a_2): + AS_2 ← A + +ELSE: + AS_2 ← B + +to finish +``` \ No newline at end of file diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index 104f379..50719ae 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -35,6 +35,17 @@ try: except ImportError: HAS_MLX_WHISPER = False +try: + from faster_whisper import WhisperModel + from faster_whisper.audio import pad_or_trim as fw_pad_or_trim + from faster_whisper.feature_extractor import FeatureExtractor + HAS_FASTER_WHISPER = True +except ImportError: + HAS_FASTER_WHISPER = False + +# HAS_MLX_WHISPER = False +HAS_FASTER_WHISPER = False #Time to determine if that's really faster + # New features added to the original version of Simul-Whisper: # - large-v3 model support @@ -56,6 +67,14 @@ class PaddedAlignAttWhisper: print('Simulstreaming will use MLX whisper for a faster encoder.') mlx_model_name = model_mapping[model_name] self.mlx_model = load_models.load_model(path_or_hf_repo=mlx_model_name) + elif HAS_FASTER_WHISPER: + print('Simulstreaming will use Faster Whisper for the encoder.') + self.fw_model = WhisperModel( + model_name, + device='auto', + compute_type='auto', + ) + self.feature_extractor = 
FeatureExtractor(feature_size=self.model.dims.n_mels) logger.info(f"Model dimensions: {self.model.dims}") @@ -375,7 +394,7 @@ class PaddedAlignAttWhisper: input_segments = self.segments[0] # NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder - # beg_encode = time() + beg_encode = time() if HAS_MLX_WHISPER: mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES) mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2) @@ -383,6 +402,15 @@ class PaddedAlignAttWhisper: encoder_feature = torch.tensor(np.array(mlx_encoder_feature)) content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2) device = 'cpu' + elif HAS_FASTER_WHISPER: + audio_length_seconds = len(input_segments) / 16000 + content_mel_len = int(audio_length_seconds * 100)//2 + # padded_audio = pad_or_trim(input_segments.detach(), N_SAMPLES) + mel_padded_2 = self.feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :] + mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1) + encoder_feature_ctranslate = self.fw_model.encode(mel) + encoder_feature = torch.Tensor(np.array(encoder_feature_ctranslate)) + device = 'cpu' else: # mel + padding to 30s mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES, @@ -393,8 +421,8 @@ class PaddedAlignAttWhisper: content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2) encoder_feature = self.model.encoder(mel) device = mel.device - # end_encode = time() - # print('Encode time whisper', HAS_MLX_WHISPER, end_encode-beg_encode) + end_encode = time() + # print('Encoder duration:', end_encode-beg_encode)