diff --git a/README.md b/README.md
index 1d49cad..7986247 100644
--- a/README.md
+++ b/README.md
@@ -219,6 +219,7 @@ WhisperLiveKit offers extensive configuration options:
 | `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
 | `--max-context-tokens` | Maximum context tokens | `None` |
 | `--model-path` | Direct path to .pt model file. Download it if not found | `./base.pt` |
+| `--preloaded-model-count` | Number of model instances to preload in memory to speed up session startup (set to the expected number of concurrent users) | `1` |
 
 ## 🔧 How It Works
 
diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py
index 0081e39..da7fdab 100644
--- a/whisperlivekit/core.py
+++ b/whisperlivekit/core.py
@@ -90,7 +90,7 @@ class TranscriptionEngine:
             simulstreaming_kwargs = {}
             for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
                          'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
-                         'max_context_tokens', 'model_path']:
+                         'max_context_tokens', 'model_path', 'warmup_file', 'preloaded_model_count']:
                 if hasattr(self.args, attr):
                     simulstreaming_kwargs[attr] = getattr(self.args, attr)
 
diff --git a/whisperlivekit/parse_args.py b/whisperlivekit/parse_args.py
index f30a6df..0806ff5 100644
--- a/whisperlivekit/parse_args.py
+++ b/whisperlivekit/parse_args.py
@@ -242,6 +242,14 @@ def parse_args():
         dest="model_path",
         help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
     )
+
+    simulstreaming_group.add_argument(
+        "--preloaded-model-count",
+        type=int,
+        default=1,
+        dest="preloaded_model_count",
+        help="Number of model instances to preload in memory to speed up session startup (set to the expected number of concurrent users).",
+    )
 
     args = parser.parse_args()
 
diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py
index 01a5ad6..90fb900 100644
--- a/whisperlivekit/simul_whisper/backend.py
+++ b/whisperlivekit/simul_whisper/backend.py
@@ -4,9 +4,10 @@ import logging
 from typing import List, Tuple, Optional
-import logging
 from whisperlivekit.timed_objects import ASRToken, Transcript
+from whisperlivekit.warmup import load_file
 from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
 from .whisper import load_model, tokenizer
 import os
+import gc
 
 logger = logging.getLogger(__name__)
 try:
@@ -36,10 +37,11 @@ class SimulStreamingOnlineProcessor:
         self.cumulative_audio_duration = 0.0
 
         self.committed: List[ASRToken] = []
         self.last_result_tokens: List[ASRToken] = []
+        model = asr.get_new_model_instance()
         self.model = PaddedAlignAttWhisper(
             cfg=asr.cfg,
-            loaded_model=asr.whisper_model)
+            loaded_model=model)
         if asr.tokenizer:
             self.model.tokenizer = asr.tokenizer
 
@@ -132,6 +134,12 @@ class SimulStreamingOnlineProcessor:
         except Exception as e:
             logger.exception(f"SimulStreaming warmup failed: {e}")
 
+    def __del__(self):
+        # Free this session's model, release GPU memory, and refill the pool.
+        del self.model
+        gc.collect()
+        torch.cuda.empty_cache()
+        self.asr.new_model_to_stack()
 
 class SimulStreamingASR():
     """SimulStreaming backend with AlignAtt policy."""
@@ -156,6 +164,8 @@ class SimulStreamingASR():
         self.init_prompt = kwargs.get('init_prompt', None)
         self.static_init_prompt = kwargs.get('static_init_prompt', None)
         self.max_context_tokens = kwargs.get('max_context_tokens', None)
+        self.warmup_file = kwargs.get('warmup_file', None)
+        self.preloaded_model_count = kwargs.get('preloaded_model_count', 1)
 
         if model_dir is not None:
             self.model_path = model_dir
@@ -176,16 +186,12 @@ class SimulStreamingASR():
         }
         self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
 
-        self.model = self.load_model(modelsize)
-
         # Set up tokenizer for translation if needed
         if self.task == "translate":
             self.tokenizer = self.set_translate_task()
         else:
             self.tokenizer = None
-
-
-    def load_model(self, modelsize):
+
         self.cfg = AlignAttConfig(
             model_path=self.model_path,
             segment_length=self.segment_length,
@@ -201,10 +207,34 @@ class SimulStreamingASR():
             init_prompt=self.init_prompt,
             max_context_tokens=self.max_context_tokens,
             static_init_prompt=self.static_init_prompt,
         )
-        model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
-        model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
-        self.whisper_model = load_model(name=model_name, download_root=model_path)
+
+        self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
+        self.model_dir = os.path.dirname(os.path.abspath(self.cfg.model_path))
+        self.models = [self.load_model() for _ in range(self.preloaded_model_count)]
+
+    def load_model(self):
+        """Load and warm up a fresh Whisper model instance."""
+        whisper_model = load_model(name=self.model_name, download_root=self.model_dir)
+        warmup_audio = load_file(self.warmup_file)
+        if warmup_audio is not None:
+            whisper_model.transcribe(warmup_audio, language=self.original_language)
+        return whisper_model
+
+    def get_new_model_instance(self):
+        """
+        SimulStreaming cannot share one model across sessions because it registers
+        global forward hooks on the attention layers. Each user therefore needs a
+        dedicated model instance, which is memory-intensive; preloading instances
+        keeps session startup fast.
+        """
+        if not self.models:
+            self.models.append(self.load_model())
+        return self.models.pop()
+
+    def new_model_to_stack(self):
+        """Refill the pool with a fresh, warmed-up model instance."""
+        self.models.append(self.load_model())
 
 
     def set_translate_task(self):
@@ -218,6 +248,6 @@ class SimulStreamingASR():
 
     def transcribe(self, audio):
         """
-        Only used for warmup. It's a direct whisper call, not a simulstreaming call
+        Warmup now happens in load_model; kept as a no-op for API compatibility.
         """
-        self.whisper_model.transcribe(audio, language=self.original_language)
\ No newline at end of file
+        pass
\ No newline at end of file
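
Note for reviewers: a minimal, self-contained sketch of the preload-pool pattern this diff introduces. The names here (`ModelPool`, `expensive_load`) are illustrative only, not the repo's API; in the actual change the pool lives on `SimulStreamingASR.models`, acquisition happens in `SimulStreamingOnlineProcessor.__init__`, and refill happens in its `__del__`.

    import queue

    def expensive_load():
        """Stand-in for SimulStreamingASR.load_model(): load + warm up a model."""
        return object()

    class ModelPool:
        """Pay the model-load cost up front so each new session starts instantly."""

        def __init__(self, loader, preloaded_count=1):
            self.loader = loader
            self.pool = queue.SimpleQueue()  # thread-safe, unlike a bare list
            for _ in range(preloaded_count):
                self.pool.put(loader())

        def acquire(self):
            try:
                return self.pool.get_nowait()  # fast path: a warmed-up instance
            except queue.Empty:
                return self.loader()           # pool exhausted: load on demand

        def release(self):
            # Used instances carry per-session state (attention hooks), so they
            # are discarded; the pool is refilled with a fresh instance instead.
            self.pool.put(self.loader())

    # Usage: one instance per concurrent session.
    pool = ModelPool(expensive_load, preloaded_count=2)
    model = pool.acquire()
    # ... run the session ...
    del model
    pool.release()

The diff's list-based pool relies on the GIL for atomic append/pop; a queue as above makes the thread safety explicit if sessions are ever created from multiple threads. Note also that refilling from `__del__` loads a model synchronously, so session teardown bears the load cost.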