import gc
import logging
import os
import platform
import sys
from pathlib import Path
from typing import List, Tuple, Optional

import numpy as np
import torch

from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
from whisperlivekit.warmup import load_file
from whisperlivekit.whisper import load_model, tokenizer
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
from whisperlivekit.backend_support import (
    mlx_backend_available,
    faster_backend_available,
)
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper

logger = logging.getLogger(__name__)

HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
if HAS_MLX_WHISPER:
    from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
else:
    mlx_model_mapping = {}

HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
if HAS_FASTER_WHISPER:
    from faster_whisper import WhisperModel
else:
    WhisperModel = None


class SimulStreamingOnlineProcessor:
    SAMPLING_RATE = 16000

    def __init__(
        self,
        asr,
        logfile=sys.stderr,
    ):
        self.asr = asr
        self.logfile = logfile
        self.end = 0.0
        self.buffer = []
        self.committed: List[ASRToken] = []
        self.last_result_tokens: List[ASRToken] = []
        self.load_new_backend()

        # Attach the translation tokenizer if one was configured (could be moved into load_new_backend).
        if asr.tokenizer:
            self.model.tokenizer = asr.tokenizer

    def load_new_backend(self):
        model = self.asr.get_new_model_instance()
        self.model = PaddedAlignAttWhisper(
            cfg=self.asr.cfg,
            loaded_model=model,
            mlx_encoder=self.asr.mlx_encoder,
            fw_encoder=self.asr.fw_encoder,
        )

    def insert_silence(self, silence_duration, offset):
        """
        Handle a detected silence.

        For silences shorter than 5 s, we simply insert the corresponding amount of silent
        audio, which shifts the last_attend_frame. For longer silences, we fully process
        what remains in the buffer, clear the context completely, and move the global time offset.
        """
        if silence_duration < 5:
            gap_silence = torch.zeros(int(self.SAMPLING_RATE * silence_duration))
            self.model.insert_audio(gap_silence)
            # self.global_time_offset += silence_duration
        else:
            self.process_iter(is_last=True)  # fully process what remains in the buffer
            self.model.refresh_segment(complete=True)
            self.model.global_time_offset = silence_duration + offset

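    # Illustration of the threshold above (hypothetical durations, not part of the
    # original logic): a 2 s gap only pads the stream with 2 s of zero samples, while
    # an 8 s gap flushes the pending buffer, clears the decoder context and sets the
    # model's global_time_offset to 8 + offset.
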
    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
        """Append an audio chunk to be processed by SimulStreaming."""

        # Convert the numpy array to a torch tensor.
        audio_tensor = torch.from_numpy(audio).float()
        self.end = audio_stream_end_time  # kept only to stay aligned with the WhisperStreaming backend
        self.model.insert_audio(audio_tensor)

    def new_speaker(self, change_speaker: ChangeSpeaker):
        self.process_iter(is_last=True)
        self.model.refresh_segment(complete=True)
        self.model.speaker = change_speaker.speaker
        self.model.global_time_offset = change_speaker.start

    def get_buffer(self):
        concat_buffer = Transcript.from_tokens(tokens=self.buffer, sep='')
        return concat_buffer

    def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
        """
        Process the accumulated audio chunks with SimulStreaming.

        Returns a tuple: (list of newly committed ASRToken objects, audio time processed up to).
        """
        try:
            timestamped_words = self.model.infer(is_last=is_last)
            if self.model.cfg.language == "auto" and timestamped_words and timestamped_words[0].detected_language is None:
                self.buffer.extend(timestamped_words)
                return [], self.end

            self.committed.extend(timestamped_words)
            self.buffer = []
            return timestamped_words, self.end

        except Exception as e:
            logger.exception(f"SimulStreaming processing error: {e}")
            return [], self.end

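    # Typical call pattern, as a hedged sketch (the real driver lives elsewhere in
    # WhisperLiveKit; the variable names below are illustrative):
    #
    #     processor = SimulStreamingOnlineProcessor(asr)        # `asr` is a SimulStreamingASR
    #     processor.insert_audio_chunk(pcm_chunk, stream_time)  # float32 mono at 16 kHz
    #     new_tokens, processed_up_to = processor.process_iter()
    #     partial = processor.get_buffer()                      # uncommitted Transcript
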
    def warmup(self, audio, init_prompt=""):
        """Warm up the SimulStreaming model."""
        try:
            self.model.insert_audio(audio)
            self.model.infer(is_last=True)
            self.model.refresh_segment(complete=True)
            logger.info("SimulStreaming model warmed up successfully")
        except Exception as e:
            logger.exception(f"SimulStreaming warmup failed: {e}")

    def __del__(self):
        # Free the per-connection model and, optionally, push a fresh one back onto the stack.
        # del self.model
        if hasattr(self, "model"):
            self.model.remove_hooks()
        gc.collect()
        torch.cuda.empty_cache()
        # self.asr.new_model_to_stack()


class SimulStreamingASR:
    """SimulStreaming backend with the AlignAtt policy."""
    sep = ""

    def __init__(self, logfile=sys.stderr, **kwargs):
        self.logfile = logfile
        self.transcribe_kargs = {}

        for key, value in kwargs.items():
            setattr(self, key, value)

        if self.decoder_type is None:
            self.decoder_type = 'greedy' if self.beams == 1 else 'beam'

        self.fast_encoder = False
        self._resolved_model_path = None
        self.encoder_backend = "whisper"
        preferred_backend = getattr(self, "backend", "auto")
        self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
        if self.model_path:
            resolved_model_path = resolve_model_path(self.model_path)
            self._resolved_model_path = resolved_model_path
            self.model_path = str(resolved_model_path)
            self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(resolved_model_path)
            if self.pytorch_path:
                self.model_name = self.pytorch_path.stem
            else:
                self.model_name = Path(self.model_path).stem
                raise FileNotFoundError(
                    f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
                )
        elif self.model_size is not None:
            model_mapping = {
                'tiny': './tiny.pt',
                'base': './base.pt',
                'small': './small.pt',
                'medium': './medium.pt',
                'medium.en': './medium.en.pt',
                'large-v1': './large-v1.pt',
                'base.en': './base.en.pt',
                'small.en': './small.en.pt',
                'tiny.en': './tiny.en.pt',
                'large-v2': './large-v2.pt',
                'large-v3': './large-v3.pt',
                'large': './large-v3.pt'
            }
            self.model_name = self.model_size
        else:
            raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")

        is_multilingual = not self.model_name.endswith(".en")

        self.encoder_backend = self._resolve_encoder_backend(
            preferred_backend,
            compatible_whisper_mlx,
            compatible_faster_whisper,
        )
        self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
        if self.encoder_backend == "whisper":
            self.disable_fast_encoder = True

        self.cfg = AlignAttConfig(
            tokenizer_is_multilingual=is_multilingual,
            segment_length=self.min_chunk_size,
            frame_threshold=self.frame_threshold,
            language=self.lan,
            audio_max_len=self.audio_max_len,
            audio_min_len=self.audio_min_len,
            cif_ckpt_path=self.cif_ckpt_path,
            decoder_type="beam",
            beam_size=self.beams,
            task=self.direct_english_translation,
            never_fire=self.never_fire,
            init_prompt=self.init_prompt,
            max_context_tokens=self.max_context_tokens,
            static_init_prompt=self.static_init_prompt,
        )

        # Set up the tokenizer for translation if needed.
        if self.direct_english_translation:
            self.tokenizer = self.set_translate_task()
        else:
            self.tokenizer = None

        self.mlx_encoder, self.fw_encoder = None, None
        if self.encoder_backend == "mlx-whisper":
            logger.info('SimulStreaming will use MLX Whisper to speed up encoding.')
            if self._resolved_model_path is not None:
                mlx_model = str(self._resolved_model_path)
            else:
                mlx_model = mlx_model_mapping.get(self.model_name)
            if not mlx_model:
                raise FileNotFoundError(
                    f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
                )
            self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
        elif self.encoder_backend == "faster-whisper":
            logger.info('SimulStreaming will use Faster-Whisper for the encoder.')
            if self._resolved_model_path is not None:
                fw_model = str(self._resolved_model_path)
            else:
                fw_model = self.model_name
            self.fw_encoder = WhisperModel(
                fw_model,
                device='auto',
                compute_type='auto',
            )

        self.models = [self.load_model() for _ in range(self.preload_model_count)]

    def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
        choice = preferred_backend or "auto"
        if self.disable_fast_encoder:
            return "whisper"
        if choice == "whisper":
            return "whisper"
        if choice == "mlx-whisper":
            if not self._can_use_mlx(compatible_whisper_mlx):
                raise RuntimeError("mlx-whisper backend requested but MLX Whisper is unavailable or incompatible with the provided model.")
            return "mlx-whisper"
        if choice == "faster-whisper":
            if not self._can_use_faster(compatible_faster_whisper):
                raise RuntimeError("faster-whisper backend requested but Faster-Whisper is unavailable or incompatible with the provided model.")
            return "faster-whisper"
        if choice == "openai-api":
            raise ValueError("openai-api backend is only supported with the LocalAgreement policy.")
        # auto mode
        if platform.system() == "Darwin" and self._can_use_mlx(compatible_whisper_mlx):
            return "mlx-whisper"
        if self._can_use_faster(compatible_faster_whisper):
            return "faster-whisper"
        return "whisper"

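    # Resolution order in "auto" mode, as a sketch of the logic above: on macOS with MLX
    # weights available -> "mlx-whisper"; otherwise, if Faster-Whisper is usable ->
    # "faster-whisper"; else -> "whisper". An explicit backend choice is honoured, or
    # raises if its dependencies or weights are missing.
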
    def _has_custom_model_path(self):
        return self._resolved_model_path is not None

    def _can_use_mlx(self, compatible_whisper_mlx):
        if not HAS_MLX_WHISPER:
            return False
        if self._has_custom_model_path():
            return compatible_whisper_mlx
        return self.model_name in mlx_model_mapping

    def _can_use_faster(self, compatible_faster_whisper):
        if not HAS_FASTER_WHISPER:
            return False
        if self._has_custom_model_path():
            return compatible_faster_whisper
        return True

    def load_model(self):
        whisper_model = load_model(
            name=self.pytorch_path if self.pytorch_path else self.model_name,
            download_root=self.model_path,
            decoder_only=self.fast_encoder,
            custom_alignment_heads=self.custom_alignment_heads
        )
        warmup_audio = load_file(self.warmup_file)
        if warmup_audio is not None:
            if self.fast_encoder:
                temp_model = PaddedAlignAttWhisper(
                    cfg=self.cfg,
                    loaded_model=whisper_model,
                    mlx_encoder=self.mlx_encoder,
                    fw_encoder=self.fw_encoder,
                )
                temp_model.warmup(torch.from_numpy(warmup_audio).float())
                temp_model.remove_hooks()
            else:
                # For the standard encoder, use the original transcribe() warmup.
                whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None)
        return whisper_model

    def get_new_model_instance(self):
        """
        SimulStreaming cannot share a single backend between users because it registers global
        forward hooks on the attention layers. Each user therefore needs a separate model
        instance, which can be memory-intensive; to keep latency low, instances are preloaded
        into memory.
        """
        if len(self.models) == 0:
            self.models.append(self.load_model())
        new_model = self.models.pop()
        return new_model

    def new_model_to_stack(self):
        self.models.append(self.load_model())

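    # Sketch of the pool behaviour (values are illustrative): with preload_model_count=2
    # the constructor fills self.models with two warmed-up instances; each new connection
    # pops one via get_new_model_instance(), and new_model_to_stack() can be called to
    # refill the pool once a connection ends.
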
    def set_translate_task(self):
        """Set up the tokenizer for the translation task."""
        if self.cfg.language == 'auto':
            raise ValueError("Translation cannot be used with language='auto'; set an explicit source language.")
        return tokenizer.get_tokenizer(
            multilingual=True,
            language=self.cfg.language,
            num_languages=99,
            task="translate"
        )

    def transcribe(self, audio):
        """No-op: warmup is done directly in load_model."""
        pass
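
# Hedged usage sketch (not executed here): WhisperLiveKit normally builds these objects
# from its server/CLI options, so the keyword set and values below are assumptions derived
# from the attributes __init__ reads.
#
#     asr = SimulStreamingASR(
#         model_size="large-v3", model_path=None, lan="en", min_chunk_size=1.0,
#         frame_threshold=25, beams=1, decoder_type=None, audio_max_len=30.0,
#         audio_min_len=0.0, cif_ckpt_path=None, never_fire=False, init_prompt=None,
#         static_init_prompt=None, max_context_tokens=None,
#         direct_english_translation=False, disable_fast_encoder=False,
#         preload_model_count=1, warmup_file=None, custom_alignment_heads=None,
#     )
#     online = SimulStreamingOnlineProcessor(asr)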