import io import logging import math import sys from typing import List import numpy as np import soundfile as sf from whisperlivekit.model_paths import detect_model_format, resolve_model_path from whisperlivekit.timed_objects import ASRToken from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe logger = logging.getLogger(__name__) class ASRBase: sep = " " # join transcribe words with this character (" " for whisper_timestamped, # "" for faster-whisper because it emits the spaces when needed) def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, lora_path=None, logfile=sys.stderr): self.logfile = logfile self.transcribe_kargs = {} self.lora_path = lora_path if lan == "auto": self.original_language = None else: self.original_language = lan self.model = self.load_model(model_size, cache_dir, model_dir) def load_model(self, model_size, cache_dir, model_dir): raise NotImplementedError("must be implemented in the child class") def transcribe(self, audio, init_prompt=""): raise NotImplementedError("must be implemented in the child class") def use_vad(self): raise NotImplementedError("must be implemented in the child class") class WhisperASR(ASRBase): """Uses WhisperLiveKit's built-in Whisper implementation.""" sep = " " def load_model(self, model_size=None, cache_dir=None, model_dir=None): from whisperlivekit.whisper import load_model as load_whisper_model if model_dir is not None: resolved_path = resolve_model_path(model_dir) if resolved_path.is_dir(): model_info = detect_model_format(resolved_path) if not model_info.has_pytorch: raise FileNotFoundError( f"No supported PyTorch checkpoint found under {resolved_path}" ) logger.debug(f"Loading Whisper model from custom path {resolved_path}") return load_whisper_model(str(resolved_path), lora_path=self.lora_path) if model_size is None: raise ValueError("Either model_size or model_dir must be set for WhisperASR") return load_whisper_model(model_size, download_root=cache_dir, lora_path=self.lora_path) def transcribe(self, audio, init_prompt=""): options = dict(self.transcribe_kargs) options.pop("vad", None) options.pop("vad_filter", None) language = self.original_language if self.original_language else None result = whisper_transcribe( self.model, audio, language=language, initial_prompt=init_prompt, condition_on_previous_text=True, word_timestamps=True, **options, ) return result def ts_words(self, r) -> List[ASRToken]: """ Converts the Whisper result to a list of ASRToken objects. """ tokens = [] for segment in r["segments"]: for word in segment["words"]: token = ASRToken( word["start"], word["end"], word["word"], probability=word.get("probability"), ) tokens.append(token) return tokens def segments_end_ts(self, res) -> List[float]: return [segment["end"] for segment in res["segments"]] def use_vad(self): logger.warning("VAD is not currently supported for WhisperASR backend and will be ignored.") class FasterWhisperASR(ASRBase): """Uses faster-whisper as the backend.""" sep = "" def load_model(self, model_size=None, cache_dir=None, model_dir=None): from faster_whisper import WhisperModel if model_dir is not None: resolved_path = resolve_model_path(model_dir) logger.debug(f"Loading faster-whisper model from {resolved_path}. " f"model_size and cache_dir parameters are not used.") model_size_or_path = str(resolved_path) elif model_size is not None: model_size_or_path = model_size else: raise ValueError("Either model_size or model_dir must be set") device = "auto" # Allow CTranslate2 to decide available device compute_type = "auto" # Allow CTranslate2 to decide faster compute type model = WhisperModel( model_size_or_path, device=device, compute_type=compute_type, download_root=cache_dir, ) return model def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list: segments, info = self.model.transcribe( audio, language=self.original_language, initial_prompt=init_prompt, beam_size=5, word_timestamps=True, condition_on_previous_text=True, **self.transcribe_kargs, ) return list(segments) def ts_words(self, segments) -> List[ASRToken]: tokens = [] for segment in segments: if segment.no_speech_prob > 0.9: continue for word in segment.words: token = ASRToken(word.start, word.end, word.word, probability=word.probability) tokens.append(token) return tokens def segments_end_ts(self, segments) -> List[float]: return [segment.end for segment in segments] def use_vad(self): self.transcribe_kargs["vad_filter"] = True class MLXWhisper(ASRBase): """ Uses MLX Whisper optimized for Apple Silicon. """ sep = "" def load_model(self, model_size=None, cache_dir=None, model_dir=None): import mlx.core as mx from mlx_whisper.transcribe import ModelHolder, transcribe if model_dir is not None: resolved_path = resolve_model_path(model_dir) logger.debug(f"Loading MLX Whisper model from {resolved_path}. model_size parameter is not used.") model_size_or_path = str(resolved_path) elif model_size is not None: model_size_or_path = self.translate_model_name(model_size) logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.") else: raise ValueError("Either model_size or model_dir must be set") self.model_size_or_path = model_size_or_path dtype = mx.float16 ModelHolder.get_model(model_size_or_path, dtype) return transcribe def translate_model_name(self, model_name): from whisperlivekit.model_mapping import MLX_MODEL_MAPPING mlx_model_path = MLX_MODEL_MAPPING.get(model_name) if mlx_model_path: return mlx_model_path else: raise ValueError(f"Model name '{model_name}' is not recognized or not supported.") def transcribe(self, audio, init_prompt=""): if self.transcribe_kargs: logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.") segments = self.model( audio, language=self.original_language, initial_prompt=init_prompt, word_timestamps=True, condition_on_previous_text=True, path_or_hf_repo=self.model_size_or_path, ) return segments.get("segments", []) def ts_words(self, segments) -> List[ASRToken]: tokens = [] for segment in segments: if segment.get("no_speech_prob", 0) > 0.9: continue for word in segment.get("words", []): token = ASRToken(word["start"], word["end"], word["word"]) tokens.append(token) return tokens def segments_end_ts(self, res) -> List[float]: return [s["end"] for s in res] def use_vad(self): self.transcribe_kargs["vad_filter"] = True class OpenaiApiASR(ASRBase): """Uses OpenAI's Whisper API for transcription.""" def __init__(self, lan=None, temperature=0, logfile=sys.stderr): self.logfile = logfile self.modelname = "whisper-1" self.original_language = None if lan == "auto" else lan self.response_format = "verbose_json" self.temperature = temperature self.load_model() self.use_vad_opt = False self.direct_english_translation = False self.task = "transcribe" def load_model(self, *args, **kwargs): from openai import OpenAI self.client = OpenAI() self.transcribed_seconds = 0 def ts_words(self, segments) -> List[ASRToken]: """ Converts OpenAI API response words into ASRToken objects while optionally skipping words that fall into no-speech segments. """ no_speech_segments = [] if self.use_vad_opt: for segment in segments.segments: if segment.no_speech_prob > 0.8: no_speech_segments.append((segment.start, segment.end)) tokens = [] for word in segments.words: start = word.start end = word.end if any(s[0] <= start <= s[1] for s in no_speech_segments): continue tokens.append(ASRToken(start, end, word.word)) return tokens def segments_end_ts(self, res) -> List[float]: return [s.end for s in res.words] def transcribe(self, audio_data, prompt=None, *args, **kwargs): buffer = io.BytesIO() buffer.name = "temp.wav" sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16") buffer.seek(0) self.transcribed_seconds += math.ceil(len(audio_data) / 16000) params = { "model": self.modelname, "file": buffer, "response_format": self.response_format, "temperature": self.temperature, "timestamp_granularities": ["word", "segment"], } if not self.direct_english_translation and self.original_language: params["language"] = self.original_language if prompt: params["prompt"] = prompt task = self.transcribe_kargs.get("task", self.task) proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions transcript = proc.create(**params) logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds") return transcript def use_vad(self): self.use_vad_opt = True