Files
WhisperLiveKit/whisperlivekit/simul_whisper/backend.py

365 lines
14 KiB
Python

import sys
import numpy as np
import logging
from typing import List, Tuple, Optional
import logging
import platform
from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.warmup import load_file
from whisperlivekit.simul_whisper.license_simulstreaming import SIMULSTREAMING_LICENSE
from .whisper import load_model, tokenizer
from .whisper.audio import TOKENS_PER_SECOND
import os
import gc
logger = logging.getLogger(__name__)
import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.whisper import tokenizer
try:
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
HAS_MLX_WHISPER = True
except ImportError:
if platform.system() == "Darwin" and platform.machine() == "arm64":
print('MLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper')
HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
HAS_FASTER_WHISPER = False
else:
try:
from faster_whisper import WhisperModel
HAS_FASTER_WHISPER = True
except ImportError:
HAS_FASTER_WHISPER = False
# TOO_MANY_REPETITIONS = 3
class SimulStreamingOnlineProcessor:
SAMPLING_RATE = 16000
def __init__(
self,
asr,
logfile=sys.stderr,
warmup_file=None
):
self.asr = asr
self.logfile = logfile
self.end = 0.0
self.global_time_offset = 0.0
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.load_new_backend()
#can be moved
if asr.tokenizer:
self.model.tokenizer = asr.tokenizer
def load_new_backend(self):
model = self.asr.get_new_model_instance()
self.model = PaddedAlignAttWhisper(
cfg=self.asr.cfg,
loaded_model=model,
mlx_encoder=self.asr.mlx_encoder,
fw_encoder=self.asr.fw_encoder,
)
def insert_silence(self, silence_duration, offset):
"""
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
"""
if silence_duration < 5:
gap_silence = torch.zeros(int(16000*silence_duration))
self.model.insert_audio(gap_silence)
# self.global_time_offset += silence_duration
else:
self.process_iter(is_last=True) #we want to totally process what remains in the buffer.
self.model.refresh_segment(complete=True)
self.global_time_offset = silence_duration + offset
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
"""Append an audio chunk to be processed by SimulStreaming."""
# Convert numpy array to torch tensor
audio_tensor = torch.from_numpy(audio).float()
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
self.model.insert_audio(audio_tensor)
def get_buffer(self):
return Transcript(
start=None,
end=None,
text='',
probability=None
)
def timestamped_text(self, tokens, generation):
"""
generate timestamped text from tokens and generation data.
args:
tokens: List of tokens to process
generation: Dictionary containing generation progress and optionally results
returns:
List of tuples containing (start_time, end_time, word) for each word
"""
FRAME_DURATION = 0.02
if "result" in generation:
split_words = generation["result"]["split_words"]
split_tokens = generation["result"]["split_tokens"]
else:
split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
progress = generation["progress"]
frames = [p["most_attended_frames"][0] for p in progress]
absolute_timestamps = [p["absolute_timestamps"][0] for p in progress]
tokens_queue = tokens.copy()
timestamped_words = []
for word, word_tokens in zip(split_words, split_tokens):
# start_frame = None
# end_frame = None
for expected_token in word_tokens:
if not tokens_queue or not frames:
raise ValueError(f"Insufficient tokens or frames for word '{word}'")
actual_token = tokens_queue.pop(0)
current_frame = frames.pop(0)
current_timestamp = absolute_timestamps.pop(0)
if actual_token != expected_token:
raise ValueError(
f"Token mismatch: expected '{expected_token}', "
f"got '{actual_token}' at frame {current_frame}"
)
# if start_frame is None:
# start_frame = current_frame
# end_frame = current_frame
# start_time = start_frame * FRAME_DURATION
# end_time = end_frame * FRAME_DURATION
start_time = current_timestamp
end_time = current_timestamp + 0.1
timestamp_entry = (start_time, end_time, word)
timestamped_words.append(timestamp_entry)
logger.debug(f"TS-WORD:\t{start_time:.2f}\t{end_time:.2f}\t{word}")
return timestamped_words
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
"""
Process accumulated audio chunks using SimulStreaming.
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
try:
tokens, generation_progress = self.model.infer(is_last=is_last)
ts_words = self.timestamped_text(tokens, generation_progress)
new_tokens = []
for ts_word in ts_words:
start, end, word = ts_word
token = ASRToken(
start=start,
end=end,
text=word,
probability=0.95 # fake prob. Maybe we can extract it from the model?
).with_offset(
self.global_time_offset
)
new_tokens.append(token)
# identical_tokens = 0
# n_new_tokens = len(new_tokens)
# if n_new_tokens:
self.committed.extend(new_tokens)
# if token in self.committed:
# pos = len(self.committed) - 1 - self.committed[::-1].index(token)
# if pos:
# for i in range(len(self.committed) - n_new_tokens, -1, -n_new_tokens):
# commited_segment = self.committed[i:i+n_new_tokens]
# if commited_segment == new_tokens:
# identical_segments +=1
# if identical_tokens >= TOO_MANY_REPETITIONS:
# logger.warning('Too many repetition, model is stuck. Load a new one')
# self.committed = self.committed[:i]
# self.load_new_backend()
# return [], self.end
# pos = self.committed.rindex(token)
return new_tokens, self.end
except Exception as e:
logger.exception(f"SimulStreaming processing error: {e}")
return [], self.end
def warmup(self, audio, init_prompt=""):
"""Warmup the SimulStreaming model."""
try:
self.model.insert_audio(audio)
self.model.infer(True)
self.model.refresh_segment(complete=True)
logger.info("SimulStreaming model warmed up successfully")
except Exception as e:
logger.exception(f"SimulStreaming warmup failed: {e}")
def __del__(self):
# free the model and add a new model to stack.
# del self.model
gc.collect()
torch.cuda.empty_cache()
# self.asr.new_model_to_stack()
self.model.remove_hooks()
class SimulStreamingASR():
"""SimulStreaming backend with AlignAtt policy."""
sep = ""
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
logger.warning(SIMULSTREAMING_LICENSE)
self.logfile = logfile
self.transcribe_kargs = {}
self.original_language = lan
self.model_path = kwargs.get('model_path', './large-v3.pt')
self.frame_threshold = kwargs.get('frame_threshold', 25)
self.audio_max_len = kwargs.get('audio_max_len', 20.0)
self.audio_min_len = kwargs.get('audio_min_len', 0.0)
self.segment_length = kwargs.get('segment_length', 0.5)
self.beams = kwargs.get('beams', 1)
self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
self.task = kwargs.get('task', 'transcribe')
self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
self.never_fire = kwargs.get('never_fire', False)
self.init_prompt = kwargs.get('init_prompt', None)
self.static_init_prompt = kwargs.get('static_init_prompt', None)
self.max_context_tokens = kwargs.get('max_context_tokens', None)
self.warmup_file = kwargs.get('warmup_file', None)
self.preload_model_count = kwargs.get('preload_model_count', 1)
self.disable_fast_encoder = kwargs.get('disable_fast_encoder', False)
self.fast_encoder = False
if model_dir is not None:
self.model_path = model_dir
elif modelsize is not None:
model_mapping = {
'tiny': './tiny.pt',
'base': './base.pt',
'small': './small.pt',
'medium': './medium.pt',
'medium.en': './medium.en.pt',
'large-v1': './large-v1.pt',
'base.en': './base.en.pt',
'small.en': './small.en.pt',
'tiny.en': './tiny.en.pt',
'large-v2': './large-v2.pt',
'large-v3': './large-v3.pt',
'large': './large-v3.pt'
}
self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
self.cfg = AlignAttConfig(
model_path=self.model_path,
segment_length=self.segment_length,
frame_threshold=self.frame_threshold,
language=self.original_language,
audio_max_len=self.audio_max_len,
audio_min_len=self.audio_min_len,
cif_ckpt_path=self.cif_ckpt_path,
decoder_type="beam",
beam_size=self.beams,
task=self.task,
never_fire=self.never_fire,
init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens,
static_init_prompt=self.static_init_prompt,
)
# Set up tokenizer for translation if needed
if self.task == "translate":
self.tokenizer = self.set_translate_task()
else:
self.tokenizer = None
self.model_name = os.path.basename(self.cfg.model_path).replace(".pt", "")
self.model_path = os.path.dirname(os.path.abspath(self.cfg.model_path))
self.mlx_encoder, self.fw_encoder = None, None
if not self.disable_fast_encoder:
if HAS_MLX_WHISPER:
print('Simulstreaming will use MLX whisper for a faster encoder.')
mlx_model_name = mlx_model_mapping[self.model_name]
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_name)
self.fast_encoder = True
elif HAS_FASTER_WHISPER:
print('Simulstreaming will use Faster Whisper for the encoder.')
self.fw_encoder = WhisperModel(
self.model_name,
device='auto',
compute_type='auto',
)
self.fast_encoder = True
self.models = [self.load_model() for i in range(self.preload_model_count)]
def load_model(self):
whisper_model = load_model(name=self.model_name, download_root=self.model_path, decoder_only=self.fast_encoder)
warmup_audio = load_file(self.warmup_file)
if warmup_audio is not None:
warmup_audio = torch.from_numpy(warmup_audio).float()
if self.fast_encoder:
temp_model = PaddedAlignAttWhisper(
cfg=self.cfg,
loaded_model=whisper_model,
mlx_encoder=self.mlx_encoder,
fw_encoder=self.fw_encoder,
)
temp_model.warmup(warmup_audio)
temp_model.remove_hooks()
else:
# For standard encoder, use the original transcribe warmup
warmup_audio = load_file(self.warmup_file)
whisper_model.transcribe(warmup_audio, language=self.original_language if self.original_language != 'auto' else None)
return whisper_model
def get_new_model_instance(self):
"""
SimulStreaming cannot share the same backend because it uses global forward hooks on the attention layers.
Therefore, each user requires a separate model instance, which can be memory-intensive. To maintain speed, we preload the models into memory.
"""
if len(self.models) == 0:
self.models.append(self.load_model())
new_model = self.models.pop()
return new_model
# self.models[0]
def new_model_to_stack(self):
self.models.append(self.load_model())
def set_translate_task(self):
"""Set up translation task."""
if self.cfg.language == 'auto':
raise Exception('Translation cannot be done with language = auto')
return tokenizer.get_tokenizer(
multilingual=True,
language=self.cfg.language,
num_languages=99,
task="translate"
)
def transcribe(self, audio):
"""
Warmup is done directly in load_model
"""
pass