diff --git a/README.md b/README.md index 90c23e5..4b2a633 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@

PyPI Version PyPI Downloads -Python Versions +Python Versions License

@@ -67,10 +67,10 @@ pip install whisperlivekit | Optional | `pip install` | |-----------|-------------| | **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` | -| Speaker diarization with Diart | `diart` | -| Original Whisper backend | `whisper` | -| Improved timestamps backend | `whisper-timestamped` | -| Apple Silicon optimization backend | `mlx-whisper` | +| **Apple Silicon optimized backend** | `mlx-whisper` | +| *[Not recommended]* Speaker diarization with Diart | `diart` | +| *[Not recommended]* Original Whisper backend | `whisper` | +| *[Not recommended]* Improved timestamps backend | `whisper-timestamped` | | OpenAI API backend | `openai` | See **Parameters & Configuration** below on how to use them. @@ -138,6 +138,7 @@ An important list of parameters can be changed. But what *should* you change? - the `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. - the `--backend` ? you can switch to `--backend faster-whisper` if `simulstreaming` does not work correctly or if you prefer to avoid the dual-license requirements. - `--warmup-file`, if you have one +- `--task translate`, to translate into English - `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server - `--diarization`, if you want to use it. @@ -159,14 +160,9 @@ The rest I don't recommend. But below are your options. 
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` | -| WhisperStreaming backend options | Description | Default | -|-----------|-------------|---------| -| `--confidence-validation` | Use confidence scores for faster validation | `False` | -| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` | - - | SimulStreaming backend options | Description | Default | |-----------|-------------|---------| +| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` | | `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` | | `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` | | `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` | @@ -180,6 +176,12 @@ The rest I don't recommend. But below are your options. | `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` | | `--preloaded-model-count` | Optional. 
Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` | + +| WhisperStreaming backend options | Description | Default | +|-----------|-------------|---------| +| `--confidence-validation` | Use confidence scores for faster validation | `False` | +| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` | + | Diarization options | Description | Default | |-----------|-------------|---------| | `--diarization` | Enable speaker identification | `False` | diff --git a/architecture.png b/architecture.png index 213df2a..b9aa73f 100644 Binary files a/architecture.png and b/architecture.png differ diff --git a/pyproject.toml b/pyproject.toml index cf09dc4..e943cf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta" [project] name = "whisperlivekit" -version = "0.2.7" -description = "Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization" +version = "0.2.8" +description = "Real-time speech-to-text with speaker diarization using Whisper" readme = "README.md" authors = [ { name = "Quentin Fuxa" } @@ -18,6 +18,11 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Programming Language :: Python :: 3.15", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Multimedia :: Sound/Audio :: Speech" ] @@ -28,7 +33,8 @@ dependencies = [ "faster-whisper", "uvicorn", "websockets", - "torch", + "torchaudio>=2.0.0", + "torch>=2.0.0", "tqdm", "tiktoken", 'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")' diff --git a/whisperlivekit/basic_server.py b/whisperlivekit/basic_server.py index 
b66eefd..fa4b9da 100644 --- a/whisperlivekit/basic_server.py +++ b/whisperlivekit/basic_server.py @@ -19,6 +19,15 @@ transcription_engine = None @asynccontextmanager async def lifespan(app: FastAPI): + + #to remove after 0.2.8 + if args.backend == "simulstreaming" and not args.disable_fast_encoder: + logger.warning(f""" +{'='*50} +WhisperLiveKit 0.2.8 has introduced a new fast encoder feature using MLX Whisper or Faster Whisper for improved speed. Use --disable-fast-encoder to disable if you encounter issues. +{'='*50} + """) + global transcription_engine transcription_engine = TranscriptionEngine( **vars(args), diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index bcab83f..8ce714b 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -46,6 +46,7 @@ class TranscriptionEngine: "confidence_validation": False, "buffer_trimming_sec": 15, # simulstreaming params: + "disable_fast_encoder": False, "frame_threshold": 25, "beams": 1, "decoder_type": None, @@ -60,7 +61,7 @@ class TranscriptionEngine: "diarization_backend": "sortformer", # diart params: "segmentation_model": "pyannote/segmentation-3.0", - "embedding_model": "pyannote/embedding", + "embedding_model": "pyannote/embedding", } config_dict = {**defaults, **kwargs} @@ -97,7 +98,7 @@ class TranscriptionEngine: simulstreaming_kwargs = {} for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len', 'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt', - 'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count']: + 'max_context_tokens', 'model_path', 'warmup_file', 'preload_model_count', 'disable_fast_encoder']: if hasattr(self.args, attr): simulstreaming_kwargs[attr] = getattr(self.args, attr) diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py index c8d0ce5..023f951 100644 --- a/whisperlivekit/parse_args.py +++ b/whisperlivekit/parse_args.py @@ -161,6 +161,14 @@ def parse_args(): # SimulStreaming-specific 
arguments simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)') + + simulstreaming_group.add_argument( + "--disable-fast-encoder", + action="store_true", + default=False, + dest="disable_fast_encoder", + help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited", + ) simulstreaming_group.add_argument( "--frame-threshold", diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index 7e1482e..f6439b5 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -244,7 +244,8 @@ class SimulStreamingASR(): self.max_context_tokens = kwargs.get('max_context_tokens', None) self.warmup_file = kwargs.get('warmup_file', None) self.preload_model_count = kwargs.get('preload_model_count', 1) - + self.disable_fast_encoder = kwargs.get('disable_fast_encoder', False) + self.fast_encoder = False if model_dir is not None: self.model_path = model_dir elif modelsize is not None: @@ -289,25 +290,44 @@ class SimulStreamingASR(): self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "") self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path)) - self.models = [self.load_model() for i in range(self.preload_model_count)] self.mlx_encoder, self.fw_encoder = None, None - if HAS_MLX_WHISPER: - print('Simulstreaming will use MLX whisper for a faster encoder.') - mlx_model_name = mlx_model_mapping[self.model_name] - self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name) - elif HAS_FASTER_WHISPER: - print('Simulstreaming will use Faster Whisper for the encoder.') - self.fw_encoder = WhisperModel( - self.model_name, - device='auto', - compute_type='auto', - ) + if not self.disable_fast_encoder: + if HAS_MLX_WHISPER: + print('Simulstreaming will use MLX whisper for a faster encoder.') + mlx_model_name = mlx_model_mapping[self.model_name] + 
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name) + self.fast_encoder = True + elif HAS_FASTER_WHISPER: + print('Simulstreaming will use Faster Whisper for the encoder.') + self.fw_encoder = WhisperModel( + self.model_name, + device='auto', + compute_type='auto', + ) + self.fast_encoder = True + + self.models = [self.load_model() for i in range(self.preload_model_count)] + def load_model(self): - whisper_model = load_model(name=self.model_name, download_root=self.model_path) + whisper_model = load_model(name=self.model_name, download_root=self.model_path, decoder_only=self.fast_encoder) warmup_audio = load_file(self.warmup_file) - whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None) + if warmup_audio is not None: + warmup_audio = torch.from_numpy(warmup_audio).float() + if self.fast_encoder: + temp_model = PaddedAlignAttWhisper( + cfg=self.cfg, + loaded_model=whisper_model, + mlx_encoder=self.mlx_encoder, + fw_encoder=self.fw_encoder, + ) + temp_model.warmup(warmup_audio) + temp_model.remove_hooks() + else: + # For standard encoder, use the original transcribe warmup + warmup_audio = load_file(self.warmup_file) + whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None) return whisper_model def get_new_model_instance(self): diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index 39101bf..c1f8c2e 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -64,7 +64,7 @@ class PaddedAlignAttWhisper: self.mlx_encoder = mlx_encoder self.fw_encoder = fw_encoder - if HAS_FASTER_WHISPER: + if fw_encoder: self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels) logger.info(f"Model dimensions: {self.model.dims}") @@ -176,6 +176,15 @@ class PaddedAlignAttWhisper: for hook in self.l_hooks: 
hook.remove() + def warmup(self, audio): + try: + self.insert_audio(audio) + self.infer(is_last=True) + self.refresh_segment(complete=True) + logger.info("Model warmed up successfully") + except Exception as e: + logger.exception(f"Model warmup failed: {e}") + def create_tokenizer(self, language=None): self.tokenizer = tokenizer.get_tokenizer( multilingual=self.tokenizer_is_multilingual, @@ -386,14 +395,14 @@ class PaddedAlignAttWhisper: # NEW : we can use a different encoder, before using standart whisper for cross attention with the hooks on the decoder beg_encode = time() - if HAS_MLX_WHISPER: + if self.mlx_encoder: mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES) mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2) mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None]) encoder_feature = torch.tensor(np.array(mlx_encoder_feature)) content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2) device = 'cpu' - elif HAS_FASTER_WHISPER: + elif self.fw_encoder: audio_length_seconds = len(input_segments) / 16000 content_mel_len = int(audio_length_seconds * 100)//2 mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :] diff --git a/whisperlivekit/simul_whisper/whisper/__init__.py b/whisperlivekit/simul_whisper/whisper/__init__.py index e210718..069ddbb 100644 --- a/whisperlivekit/simul_whisper/whisper/__init__.py +++ b/whisperlivekit/simul_whisper/whisper/__init__.py @@ -105,6 +105,7 @@ def load_model( device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False, + decoder_only=False ) -> Whisper: """ Load a Whisper ASR model @@ -151,7 +152,14 @@ def load_model( del checkpoint_file dims = ModelDimensions(**checkpoint["dims"]) - model = Whisper(dims) + model = Whisper(dims, decoder_only=decoder_only) + + if decoder_only: + checkpoint["model_state_dict"] = { + k: v for k, v in 
checkpoint["model_state_dict"].items() + if 'encoder' not in k + } + model.load_state_dict(checkpoint["model_state_dict"]) if alignment_heads is not None: diff --git a/whisperlivekit/simul_whisper/whisper/model.py b/whisperlivekit/simul_whisper/whisper/model.py index 7fb887e..b6482a6 100644 --- a/whisperlivekit/simul_whisper/whisper/model.py +++ b/whisperlivekit/simul_whisper/whisper/model.py @@ -253,16 +253,18 @@ class TextDecoder(nn.Module): class Whisper(nn.Module): - def __init__(self, dims: ModelDimensions): + def __init__(self, dims: ModelDimensions, decoder_only: bool = False): super().__init__() self.dims = dims - self.encoder = AudioEncoder( - self.dims.n_mels, - self.dims.n_audio_ctx, - self.dims.n_audio_state, - self.dims.n_audio_head, - self.dims.n_audio_layer, - ) + + if not decoder_only: + self.encoder = AudioEncoder( + self.dims.n_mels, + self.dims.n_audio_ctx, + self.dims.n_audio_state, + self.dims.n_audio_head, + self.dims.n_audio_layer, + ) self.decoder = TextDecoder( self.dims.n_vocab, self.dims.n_text_ctx, diff --git a/whisperlivekit/warmup.py b/whisperlivekit/warmup.py index 5003fe2..4c14586 100644 --- a/whisperlivekit/warmup.py +++ b/whisperlivekit/warmup.py @@ -31,21 +31,21 @@ def load_file(warmup_file=None, timeout=5): logger.debug(f"Download successful in {time.time() - start_time:.2f}s") except (urllib.error.URLError, socket.timeout) as e: logger.warning(f"Download failed: {e}. Proceeding without warmup.") - return False + return None finally: socket.setdefaulttimeout(original_timeout) elif not warmup_file: - return False + return None if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0: logger.warning(f"Warmup file {warmup_file} invalid or missing.") - return False + return None try: audio, sr = librosa.load(warmup_file, sr=16000) except Exception as e: logger.warning(f"Failed to load audio file: {e}") - return False + return None return audio def warmup_asr(asr, warmup_file=None, timeout=5):