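"""ASR backend wrappers used by WhisperLiveKit.

Each backend (the built-in Whisper implementation, faster-whisper, MLX Whisper,
and the OpenAI Whisper API) is exposed through the small common interface
defined by ASRBase. (Added module docstring; summary derived from the classes below.)
"""
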
import io
import logging
import math
import sys
from typing import List

import numpy as np
import soundfile as sf

from whisperlivekit.model_paths import detect_model_format, resolve_model_path
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe

logger = logging.getLogger(__name__)

class ASRBase:
    sep = " "  # join transcribe words with this character (" " for whisper_timestamped,
               # "" for faster-whisper because it emits the spaces when needed)

    def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, lora_path=None, logfile=sys.stderr):
        self.logfile = logfile
        self.transcribe_kargs = {}
        self.lora_path = lora_path
        if lan == "auto":
            self.original_language = None
        else:
            self.original_language = lan
        self.model = self.load_model(model_size, cache_dir, model_dir)

    def load_model(self, model_size, cache_dir, model_dir):
        raise NotImplementedError("must be implemented in the child class")

    def transcribe(self, audio, init_prompt=""):
        raise NotImplementedError("must be implemented in the child class")

    def use_vad(self):
        raise NotImplementedError("must be implemented in the child class")
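
# Note (added commentary): besides load_model/transcribe/use_vad above, every
# concrete backend in this module also implements ts_words(result) -> List[ASRToken]
# and segments_end_ts(result) -> List[float], so callers can treat the backends
# interchangeably.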


class WhisperASR(ASRBase):
    """Uses WhisperLiveKit's built-in Whisper implementation."""
    sep = " "

    def load_model(self, model_size=None, cache_dir=None, model_dir=None):
        from whisperlivekit.whisper import load_model as load_whisper_model

        if model_dir is not None:
            resolved_path = resolve_model_path(model_dir)
            if resolved_path.is_dir():
                model_info = detect_model_format(resolved_path)
                if not model_info.has_pytorch:
                    raise FileNotFoundError(
                        f"No supported PyTorch checkpoint found under {resolved_path}"
                    )
            logger.debug(f"Loading Whisper model from custom path {resolved_path}")
            return load_whisper_model(str(resolved_path), lora_path=self.lora_path)

        if model_size is None:
            raise ValueError("Either model_size or model_dir must be set for WhisperASR")

        return load_whisper_model(model_size, download_root=cache_dir, lora_path=self.lora_path)

    def transcribe(self, audio, init_prompt=""):
        options = dict(self.transcribe_kargs)
        options.pop("vad", None)
        options.pop("vad_filter", None)
        language = self.original_language if self.original_language else None

        result = whisper_transcribe(
            self.model,
            audio,
            language=language,
            initial_prompt=init_prompt,
            condition_on_previous_text=True,
            word_timestamps=True,
            **options,
        )
        return result

    def ts_words(self, r) -> List[ASRToken]:
        """
        Converts the Whisper result to a list of ASRToken objects.
        """
        tokens = []
        for segment in r["segments"]:
            for word in segment["words"]:
                token = ASRToken(
                    word["start"],
                    word["end"],
                    word["word"],
                    probability=word.get("probability"),
                )
                tokens.append(token)
        return tokens

    def segments_end_ts(self, res) -> List[float]:
        return [segment["end"] for segment in res["segments"]]

    def use_vad(self):
        logger.warning("VAD is not currently supported for WhisperASR backend and will be ignored.")
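
# Usage sketch (illustrative only; "base" is an example model size and `audio`
# is a placeholder for a 16 kHz mono float32 numpy array):
#
#   asr = WhisperASR(lan="en", model_size="base")
#   result = asr.transcribe(audio, init_prompt="")
#   tokens = asr.ts_words(result)          # List[ASRToken] with word timestamps
#   ends = asr.segments_end_ts(result)     # end time of each segment, in seconds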


class FasterWhisperASR(ASRBase):
    """Uses faster-whisper as the backend."""
    sep = ""

    def load_model(self, model_size=None, cache_dir=None, model_dir=None):
        from faster_whisper import WhisperModel

        if model_dir is not None:
            resolved_path = resolve_model_path(model_dir)
            logger.debug(f"Loading faster-whisper model from {resolved_path}. "
                         f"model_size and cache_dir parameters are not used.")
            model_size_or_path = str(resolved_path)
        elif model_size is not None:
            model_size_or_path = model_size
        else:
            raise ValueError("Either model_size or model_dir must be set")

        device = "auto"  # Allow CTranslate2 to decide available device
        compute_type = "auto"  # Allow CTranslate2 to decide faster compute type

        model = WhisperModel(
            model_size_or_path,
            device=device,
            compute_type=compute_type,
            download_root=cache_dir,
        )
        return model

    def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
        segments, info = self.model.transcribe(
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            beam_size=5,
            word_timestamps=True,
            condition_on_previous_text=True,
            **self.transcribe_kargs,
        )
        return list(segments)

    def ts_words(self, segments) -> List[ASRToken]:
        tokens = []
        for segment in segments:
            if segment.no_speech_prob > 0.9:
                continue
            for word in segment.words:
                token = ASRToken(word.start, word.end, word.word, probability=word.probability)
                tokens.append(token)
        return tokens

    def segments_end_ts(self, segments) -> List[float]:
        return [segment.end for segment in segments]

    def use_vad(self):
        self.transcribe_kargs["vad_filter"] = True
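
# Usage sketch (illustrative only; "large-v3" and "auto" are example values,
# `audio` is a placeholder for a 16 kHz mono float32 numpy array):
#
#   asr = FasterWhisperASR(lan="auto", model_size="large-v3")
#   asr.use_vad()                           # enables faster-whisper's vad_filter
#   segments = asr.transcribe(audio)
#   tokens = asr.ts_words(segments)         # words with probabilities, no-speech segments skipped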


class MLXWhisper(ASRBase):
    """
    Uses MLX Whisper optimized for Apple Silicon.
    """
    sep = ""

    def load_model(self, model_size=None, cache_dir=None, model_dir=None):
        import mlx.core as mx
        from mlx_whisper.transcribe import ModelHolder, transcribe

        if model_dir is not None:
            resolved_path = resolve_model_path(model_dir)
            logger.debug(f"Loading MLX Whisper model from {resolved_path}. model_size parameter is not used.")
            model_size_or_path = str(resolved_path)
        elif model_size is not None:
            model_size_or_path = self.translate_model_name(model_size)
            logger.debug(f"Loading Whisper model {model_size}. MLX Whisper is in use, so {model_size_or_path} will be loaded.")
        else:
            raise ValueError("Either model_size or model_dir must be set")

        self.model_size_or_path = model_size_or_path
        dtype = mx.float16
        ModelHolder.get_model(model_size_or_path, dtype)
        return transcribe

    def translate_model_name(self, model_name):
        from whisperlivekit.model_mapping import MLX_MODEL_MAPPING

        mlx_model_path = MLX_MODEL_MAPPING.get(model_name)
        if mlx_model_path:
            return mlx_model_path
        else:
            raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")

    def transcribe(self, audio, init_prompt=""):
        if self.transcribe_kargs:
            logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
        segments = self.model(
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            word_timestamps=True,
            condition_on_previous_text=True,
            path_or_hf_repo=self.model_size_or_path,
        )
        return segments.get("segments", [])

    def ts_words(self, segments) -> List[ASRToken]:
        tokens = []
        for segment in segments:
            if segment.get("no_speech_prob", 0) > 0.9:
                continue
            for word in segment.get("words", []):
                token = ASRToken(word["start"], word["end"], word["word"])
                tokens.append(token)
        return tokens

    def segments_end_ts(self, res) -> List[float]:
        return [s["end"] for s in res]

    def use_vad(self):
        self.transcribe_kargs["vad_filter"] = True
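
# Usage sketch (illustrative only; requires Apple Silicon with the mlx_whisper
# package installed, and the model size must have an entry in MLX_MODEL_MAPPING):
#
#   asr = MLXWhisper(lan="en", model_size="base")
#   segments = asr.transcribe(audio)        # list of segment dicts with "words"
#   tokens = asr.ts_words(segments)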


class OpenaiApiASR(ASRBase):
    """Uses OpenAI's Whisper API for transcription."""

    def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
        self.logfile = logfile
        self.modelname = "whisper-1"
        self.original_language = None if lan == "auto" else lan
        self.response_format = "verbose_json"
        self.temperature = temperature
        self.transcribe_kargs = {}  # read by transcribe(); this __init__ does not call ASRBase.__init__
        self.load_model()
        self.use_vad_opt = False
        self.direct_english_translation = False
        self.task = "transcribe"

    def load_model(self, *args, **kwargs):
        from openai import OpenAI

        self.client = OpenAI()
        self.transcribed_seconds = 0

    def ts_words(self, segments) -> List[ASRToken]:
        """
        Converts OpenAI API response words into ASRToken objects while
        optionally skipping words that fall into no-speech segments.
        """
        no_speech_segments = []
        if self.use_vad_opt:
            for segment in segments.segments:
                if segment.no_speech_prob > 0.8:
                    no_speech_segments.append((segment.start, segment.end))

        tokens = []
        for word in segments.words:
            start = word.start
            end = word.end
            if any(s[0] <= start <= s[1] for s in no_speech_segments):
                continue
            tokens.append(ASRToken(start, end, word.word))
        return tokens

    def segments_end_ts(self, res) -> List[float]:
        return [s.end for s in res.words]

    def transcribe(self, audio_data, prompt=None, *args, **kwargs):
        buffer = io.BytesIO()
        buffer.name = "temp.wav"
        sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
        buffer.seek(0)
        self.transcribed_seconds += math.ceil(len(audio_data) / 16000)

        params = {
            "model": self.modelname,
            "file": buffer,
            "response_format": self.response_format,
            "temperature": self.temperature,
            "timestamp_granularities": ["word", "segment"],
        }
        if not self.direct_english_translation and self.original_language:
            params["language"] = self.original_language
        if prompt:
            params["prompt"] = prompt

        task = self.transcribe_kargs.get("task", self.task)
        proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions
        transcript = proc.create(**params)
        logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
        return transcript

    def use_vad(self):
        self.use_vad_opt = True
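
# Usage sketch (illustrative only; assumes the OPENAI_API_KEY environment
# variable is set so that OpenAI() can authenticate, and `audio` is a
# placeholder for a 16 kHz mono float32 numpy array):
#
#   asr = OpenaiApiASR(lan="en")
#   transcript = asr.transcribe(audio)
#   tokens = asr.ts_words(transcript)
#   print(asr.transcribed_seconds, "seconds sent to the API so far")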