mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
adapt online for mlx detection
This commit is contained in:
@@ -176,12 +176,10 @@ class TranscriptionEngine:
|
||||
|
||||
|
||||
def online_factory(args, asr):
|
||||
if args.backend_policy == "simulstreaming":
|
||||
if args.backend_policy == "simulstreaming":
|
||||
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||
online = SimulStreamingOnlineProcessor(asr)
|
||||
else:
|
||||
online = OnlineASRProcessor(asr)
|
||||
return online
|
||||
return SimulStreamingOnlineProcessor(asr)
|
||||
return OnlineASRProcessor(asr)
|
||||
|
||||
|
||||
def online_diarization_factory(args, diarization_backend):
|
||||
|
||||
@@ -24,9 +24,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
||||
if HAS_MLX_WHISPER:
|
||||
from .mlx_encoder import load_mlx_encoder, mlx_model_mapping
|
||||
from .mlx_encoder import load_mlx_encoder, load_mlx_model, mlx_model_mapping
|
||||
from .mlx import MLXAlignAtt
|
||||
else:
|
||||
mlx_model_mapping = {}
|
||||
MLXAlignAtt = None
|
||||
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
||||
if HAS_FASTER_WHISPER:
|
||||
from faster_whisper import WhisperModel
|
||||
@@ -36,50 +38,49 @@ else:
|
||||
MIN_DURATION_REAL_SILENCE = 5
|
||||
|
||||
class SimulStreamingOnlineProcessor:
|
||||
"""Online processor for SimulStreaming ASR."""
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
asr,
|
||||
logfile=sys.stderr,
|
||||
):
|
||||
def __init__(self, asr, logfile=sys.stderr):
|
||||
self.asr = asr
|
||||
self.logfile = logfile
|
||||
self.end = 0.0
|
||||
self.buffer = []
|
||||
self.committed: List[ASRToken] = []
|
||||
self.last_result_tokens: List[ASRToken] = []
|
||||
self.load_new_alignatt_instance()
|
||||
self.last_result_tokens: List[ASRToken] = []
|
||||
self.model = self._create_alignatt()
|
||||
|
||||
if asr.tokenizer:
|
||||
self.model.tokenizer = asr.tokenizer
|
||||
self.model.state.tokenizer = asr.tokenizer
|
||||
|
||||
def load_new_alignatt_instance(self):
|
||||
"""Initialize AlignAtt decoder using the shared model."""
|
||||
self.model = AlignAtt(
|
||||
cfg=self.asr.cfg,
|
||||
loaded_model=self.asr.shared_model,
|
||||
mlx_encoder=self.asr.mlx_encoder,
|
||||
fw_encoder=self.asr.fw_encoder,
|
||||
)
|
||||
def _create_alignatt(self):
|
||||
"""Create the AlignAtt decoder instance based on ASR mode."""
|
||||
if self.asr.use_full_mlx and HAS_MLX_WHISPER:
|
||||
return MLXAlignAtt(cfg=self.asr.cfg, mlx_model=self.asr.mlx_model)
|
||||
else:
|
||||
return AlignAtt(
|
||||
cfg=self.asr.cfg,
|
||||
loaded_model=self.asr.shared_model,
|
||||
mlx_encoder=self.asr.mlx_encoder,
|
||||
fw_encoder=self.asr.fw_encoder,
|
||||
)
|
||||
|
||||
def start_silence(self):
|
||||
tokens, processed_upto = self.process_iter(is_last=True)
|
||||
return tokens, processed_upto
|
||||
|
||||
def end_silence(self, silence_duration, offset):
|
||||
"""
|
||||
Handle silence period.
|
||||
|
||||
If silence > MIN_DURATION_REAL_SILENCE, do a complete context clear.
|
||||
Otherwise, insert a small silence and shift the last_attend_frame.
|
||||
"""
|
||||
"""Handle silence period."""
|
||||
self.end += silence_duration
|
||||
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
|
||||
if not long_silence:
|
||||
gap_len = int(16000 * silence_duration)
|
||||
if gap_len > 0:
|
||||
gap_silence = torch.zeros(gap_len)
|
||||
if self.asr.use_full_mlx:
|
||||
gap_silence = np.zeros(gap_len, dtype=np.float32)
|
||||
else:
|
||||
gap_silence = torch.zeros(gap_len)
|
||||
self.model.insert_audio(gap_silence)
|
||||
if long_silence:
|
||||
self.model.refresh_segment(complete=True)
|
||||
@@ -87,11 +88,12 @@ class SimulStreamingOnlineProcessor:
|
||||
|
||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
|
||||
"""Append an audio chunk to be processed by SimulStreaming."""
|
||||
|
||||
# Convert numpy array to torch tensor
|
||||
audio_tensor = torch.from_numpy(audio).float()
|
||||
self.end = audio_stream_end_time # Aligned with whisperstreaming backend behavior
|
||||
self.model.insert_audio(audio_tensor)
|
||||
self.end = audio_stream_end_time
|
||||
if self.asr.use_full_mlx:
|
||||
self.model.insert_audio(audio)
|
||||
else:
|
||||
audio_tensor = torch.from_numpy(audio).float()
|
||||
self.model.insert_audio(audio_tensor)
|
||||
|
||||
def new_speaker(self, change_speaker: ChangeSpeaker):
|
||||
"""Handle speaker change event."""
|
||||
@@ -130,6 +132,10 @@ class SimulStreamingOnlineProcessor:
|
||||
def warmup(self, audio, init_prompt=""):
|
||||
"""Warmup the SimulStreaming model."""
|
||||
try:
|
||||
if self.asr.use_full_mlx:
|
||||
# MLX mode: ensure numpy array
|
||||
if hasattr(audio, 'numpy'):
|
||||
audio = audio.numpy()
|
||||
self.model.insert_audio(audio)
|
||||
self.model.infer(True)
|
||||
self.model.refresh_segment(complete=True)
|
||||
@@ -139,9 +145,14 @@ class SimulStreamingOnlineProcessor:
|
||||
|
||||
def __del__(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
if not getattr(self.asr, 'use_full_mlx', True) and torch is not None:
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
class SimulStreamingASR():
|
||||
|
||||
class SimulStreamingASR:
|
||||
"""SimulStreaming backend with AlignAtt policy."""
|
||||
sep = ""
|
||||
|
||||
@@ -158,6 +169,7 @@ class SimulStreamingASR():
|
||||
self.fast_encoder = False
|
||||
self._resolved_model_path = None
|
||||
self.encoder_backend = "whisper"
|
||||
self.use_full_mlx = getattr(self, "use_full_mlx", False)
|
||||
preferred_backend = getattr(self, "backend", "auto")
|
||||
compatible_whisper_mlx, compatible_faster_whisper = True, True
|
||||
|
||||
@@ -170,7 +182,7 @@ class SimulStreamingASR():
|
||||
compatible_whisper_mlx = model_info.compatible_whisper_mlx
|
||||
compatible_faster_whisper = model_info.compatible_faster_whisper
|
||||
|
||||
if not model_info.has_pytorch:
|
||||
if not self.use_full_mlx and not model_info.has_pytorch:
|
||||
raise FileNotFoundError(
|
||||
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
|
||||
)
|
||||
@@ -190,6 +202,10 @@ class SimulStreamingASR():
|
||||
self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
|
||||
if self.encoder_backend == "whisper":
|
||||
self.disable_fast_encoder = True
|
||||
|
||||
if self.encoder_backend == "mlx-whisper" and platform.system() == "Darwin":
|
||||
if not hasattr(self, '_full_mlx_disabled'):
|
||||
self.use_full_mlx = True
|
||||
|
||||
self.cfg = AlignAttConfig(
|
||||
tokenizer_is_multilingual= is_multilingual,
|
||||
@@ -214,20 +230,36 @@ class SimulStreamingASR():
|
||||
else:
|
||||
self.tokenizer = None
|
||||
|
||||
self.mlx_encoder, self.fw_encoder = None, None
|
||||
if self.encoder_backend == "mlx-whisper":
|
||||
print('Simulstreaming will use MLX whisper to increase encoding speed.')
|
||||
self.mlx_encoder, self.fw_encoder, self.mlx_model = None, None, None
|
||||
self.shared_model = None
|
||||
|
||||
if self.use_full_mlx and HAS_MLX_WHISPER:
|
||||
logger.info('MLX Whisper backend used.')
|
||||
if self._resolved_model_path is not None:
|
||||
mlx_model = str(self._resolved_model_path)
|
||||
mlx_model_path = str(self._resolved_model_path)
|
||||
else:
|
||||
mlx_model = mlx_model_mapping.get(self.model_name)
|
||||
if not mlx_model:
|
||||
mlx_model_path = mlx_model_mapping.get(self.model_name)
|
||||
if not mlx_model_path:
|
||||
raise FileNotFoundError(
|
||||
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
||||
)
|
||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
|
||||
self.mlx_model = load_mlx_model(path_or_hf_repo=mlx_model_path)
|
||||
self._warmup_mlx_model()
|
||||
elif self.encoder_backend == "mlx-whisper":
|
||||
# hybrid mode: mlx encoder + pytorch decoder
|
||||
logger.info('SimulStreaming will use MLX Whisper encoder with PyTorch decoder.')
|
||||
if self._resolved_model_path is not None:
|
||||
mlx_model_path = str(self._resolved_model_path)
|
||||
else:
|
||||
mlx_model_path = mlx_model_mapping.get(self.model_name)
|
||||
if not mlx_model_path:
|
||||
raise FileNotFoundError(
|
||||
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
||||
)
|
||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model_path)
|
||||
self.shared_model = self.load_model()
|
||||
elif self.encoder_backend == "faster-whisper":
|
||||
print('Simulstreaming will use Faster Whisper for the encoder.')
|
||||
print('SimulStreaming will use Faster Whisper for the encoder.')
|
||||
if self._resolved_model_path is not None:
|
||||
fw_model = str(self._resolved_model_path)
|
||||
else:
|
||||
@@ -237,7 +269,20 @@ class SimulStreamingASR():
|
||||
device='auto',
|
||||
compute_type='auto',
|
||||
)
|
||||
self.shared_model = self.load_model()
|
||||
self.shared_model = self.load_model()
|
||||
else:
|
||||
self.shared_model = self.load_model()
|
||||
|
||||
def _warmup_mlx_model(self):
|
||||
"""Warmup the full MLX model."""
|
||||
warmup_audio = load_file(self.warmup_file)
|
||||
if warmup_audio is not None:
|
||||
temp_model = MLXAlignAtt(
|
||||
cfg=self.cfg,
|
||||
mlx_model=self.mlx_model,
|
||||
)
|
||||
temp_model.warmup(warmup_audio)
|
||||
logger.info("Full MLX model warmed up successfully")
|
||||
|
||||
|
||||
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
||||
|
||||
Reference in New Issue
Block a user