Use distinct backend model instances for SimulStreaming and add --preloaded-model-count to preload them

Quentin Fuxa
2025-08-15 23:03:55 +02:00
parent 349c7dcb9e
commit 1652db9a2d
4 changed files with 52 additions and 14 deletions

View File

@@ -219,6 +219,7 @@ WhisperLiveKit offers extensive configuration options:
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
| `--max-context-tokens` | Maximum context tokens | `None` |
| `--model-path` | Direct path to the .pt model file; downloaded automatically if not found | `./base.pt` |
+| `--preloaded-model-count` | Optional. Number of model instances to preload in memory to speed up session start (set to the expected number of concurrent users) | `1` |

## 🔧 How It Works

View File

@@ -90,7 +90,7 @@ class TranscriptionEngine:
simulstreaming_kwargs = {}
for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
-'max_context_tokens', 'model_path']:
+'max_context_tokens', 'model_path', 'warmup_file', 'preloaded_model_count']:
if hasattr(self.args, attr):
simulstreaming_kwargs[attr] = getattr(self.args, attr)
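For clarity: the `hasattr` guard means the engine forwards only the options the front-end actually defines, so older parsers that lack the new flag keep working. A minimal, self-contained sketch of the same pattern (`SimpleNamespace` stands in for the parsed args, and the attribute list is abbreviated):

```python
# Sketch of the forwarding pattern above; SimpleNamespace replaces argparse's result.
from types import SimpleNamespace

args = SimpleNamespace(model_path="./base.pt", preloaded_model_count=4)
simulstreaming_kwargs = {
    attr: getattr(args, attr)
    for attr in ("model_path", "warmup_file", "preloaded_model_count")
    if hasattr(args, attr)  # options missing from older front-ends are skipped
}
print(simulstreaming_kwargs)  # {'model_path': './base.pt', 'preloaded_model_count': 4}
```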

View File

@@ -242,6 +242,14 @@ def parse_args():
dest="model_path",
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
)
+simulstreaming_group.add_argument(
+    "--preloaded-model-count",
+    type=int,
+    default=1,
+    dest="preloaded_model_count",
+    help="Optional. Number of model instances to preload in memory to speed up session start (set to the expected number of concurrent users).",
+)
args = parser.parse_args()
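Standalone, the new option behaves as below (a sketch with a throwaway parser, not the project's full `parse_args`):

```python
# Throwaway parser to show how the new option parses; the real definition
# lives in the simulstreaming_group above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--preloaded-model-count",
    type=int,
    default=1,
    dest="preloaded_model_count",
)
args = parser.parse_args(["--preloaded-model-count", "4"])
assert args.preloaded_model_count == 4
```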

View File

@@ -4,9 +4,11 @@ import logging
from typing import List, Tuple, Optional
from whisperlivekit.timed_objects import ASRToken, Transcript
+from whisperlivekit.warmup import load_file
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
from .whisper import load_model, tokenizer
import os
+import gc
logger = logging.getLogger(__name__)
try:
@@ -36,10 +38,11 @@ class SimulStreamingOnlineProcessor:
self.cumulative_audio_duration = 0.0
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
+model = asr.get_new_model_instance()
self.model = PaddedAlignAttWhisper(
cfg=asr.cfg,
-loaded_model=asr.whisper_model)
+loaded_model=model)
if asr.tokenizer:
self.model.tokenizer = asr.tokenizer
@@ -132,6 +135,12 @@ class SimulStreamingOnlineProcessor:
except Exception as e:
logger.exception(f"SimulStreaming warmup failed: {e}")
+def __del__(self):
+    # Free this session's model instance and preload a replacement onto the stack.
+    del self.model
+    gc.collect()
+    torch.cuda.empty_cache()
+    self.asr.new_model_to_stack()
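Note that this recycling relies on `__del__` running promptly, which CPython's reference counting guarantees as soon as the last reference to the processor drops. A minimal illustration (plain Python, no torch):

```python
# Plain-Python illustration of the lifecycle __del__ relies on.
class Session:
    def __del__(self):
        print("recycled")  # stands in for asr.new_model_to_stack()

s = Session()
del s  # prints "recycled" immediately under CPython's refcounting
# Other interpreters (e.g. PyPy) may defer this until a garbage-collection pass.
```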
class SimulStreamingASR():
"""SimulStreaming backend with AlignAtt policy."""
@@ -156,6 +165,8 @@ class SimulStreamingASR():
self.init_prompt = kwargs.get('init_prompt', None)
self.static_init_prompt = kwargs.get('static_init_prompt', None)
self.max_context_tokens = kwargs.get('max_context_tokens', None)
+self.warmup_file = kwargs.get('warmup_file', None)
+self.preloaded_model_count = kwargs.get('preloaded_model_count', 1)
if model_dir is not None:
self.model_path = model_dir
@@ -176,16 +187,11 @@ class SimulStreamingASR():
}
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
-self.model = self.load_model(modelsize)
# Set up tokenizer for translation if needed
if self.task == "translate":
self.tokenizer = self.set_translate_task()
else:
self.tokenizer = None
-def load_model(self, modelsize):
self.cfg = AlignAttConfig(
model_path=self.model_path,
segment_length=self.segment_length,
@@ -201,10 +207,33 @@ class SimulStreamingASR():
init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens,
static_init_prompt=self.static_init_prompt,
)
-model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
-model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
-self.whisper_model = load_model(name=model_name, download_root=model_path)
+self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
+self.model_dir = os.path.dirname(os.path.abspath(self.cfg.model_path))
+self.models = [self.load_model() for _ in range(self.preloaded_model_count)]

+def load_model(self):
+    whisper_model = load_model(name=self.model_name, download_root=self.model_dir)
+    warmup_audio = load_file(self.warmup_file)
+    if warmup_audio is not None:  # skip warmup if no warmup audio could be loaded
+        whisper_model.transcribe(warmup_audio, language=self.original_language)
+    return whisper_model
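The load-then-warm step looks roughly like this when written against the standalone openai-whisper package (a hedged sketch: the diff uses the project's vendored `.whisper` loader, and synthetic silence stands in here for `load_file`'s warmup audio):

```python
# Sketch using the standalone openai-whisper package, not the vendored loader.
import numpy as np
import whisper  # pip install openai-whisper

def load_warm_model(name="base", language=None):
    model = whisper.load_model(name)
    silence = np.zeros(16000, dtype=np.float32)  # 1 s at 16 kHz, replaces load_file()
    model.transcribe(silence, language=language)  # first call pays the warmup cost
    return model
```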
+def get_new_model_instance(self):
+    """
+    SimulStreaming sessions cannot share a single backend model because each one
+    registers global forward hooks on the attention layers. Every user therefore
+    needs a separate model instance, which can be memory-intensive; preloading
+    the instances keeps session start fast.
+    """
+    if len(self.models) == 0:  # pool exhausted: fall back to a cold load
+        self.models.append(self.load_model())
+    return self.models.pop()

+def new_model_to_stack(self):
+    # Preload a replacement for the instance that was just released.
+    self.models.append(self.load_model())
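Taken together, the two methods implement a simple preload-and-recycle pool. A self-contained sketch of the same pattern (`ModelPool` and the lambda loader are illustrative names, not project API):

```python
# Illustrative pool; load() stands in for the expensive Whisper load above.
class ModelPool:
    def __init__(self, load, preloaded_model_count=1):
        self._load = load
        self.models = [load() for _ in range(preloaded_model_count)]

    def acquire(self):
        if not self.models:  # pool exhausted: pay a cold load
            self.models.append(self._load())
        return self.models.pop()

    def release(self):
        # A fresh instance is loaded instead of reusing the returned one,
        # since a used instance still carries its session's forward hooks.
        self.models.append(self._load())

pool = ModelPool(load=lambda: object(), preloaded_model_count=2)
model = pool.acquire()  # instant: comes from the preloaded stack
pool.release()          # tops the stack back up for the next session
```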
def set_translate_task(self):
@@ -218,6 +247,6 @@ class SimulStreamingASR():
def transcribe(self, audio):
"""
-Only used for warmup. It's a direct Whisper call, not a SimulStreaming call.
+No-op kept for interface compatibility; warmup now happens directly in load_model().
"""
-self.whisper_model.transcribe(audio, language=self.original_language)
+pass