Mirror of https://github.com/QuentinFuxa/WhisperLiveKit.git (synced 2026-03-07 22:33:36 +00:00)

Commit: adapt backend for the new classes
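Note on the new classes: the diff below converts each backend's ts_words output from (start, end, text) tuples to ASRToken objects imported from src.whisper_streaming.asr_token. For orientation, here is a minimal sketch of that class, reconstructed from the constructor calls and the with_offset/__repr__ bodies visible in this diff; the dataclass form and the absence of extra fields are assumptions, not the committed definition:

    # Minimal sketch of ASRToken, inferred from this diff; see assumptions above.
    from dataclasses import dataclass

    @dataclass
    class ASRToken:
        start: float
        end: float
        text: str

        def with_offset(self, offset: float) -> "ASRToken":
            # Shift both timestamps by a fixed offset, e.g. the audio buffer's start time.
            return ASRToken(self.start + offset, self.end + offset, self.text)

        def __repr__(self):
            return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"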
@@ -1,45 +1,47 @@
 import sys
 import logging

 import io
 import soundfile as sf
 import math
 import torch
+from typing import List
+import numpy as np
+from src.whisper_streaming.asr_token import ASRToken

 logger = logging.getLogger(__name__)

 class ASRBase:
     sep = " "  # join transcribe words with this character (" " for whisper_timestamped,
-    # "" for faster-whisper because it emits the spaces when neeeded)
+    # "" for faster-whisper because it emits the spaces when needed)

-    def __init__(
-        self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr
-    ):
+    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
         self.logfile = logfile

         self.transcribe_kargs = {}
         if lan == "auto":
             self.original_language = None
         else:
             self.original_language = lan

         self.model = self.load_model(modelsize, cache_dir, model_dir)

-    def load_model(self, modelsize, cache_dir):
-        raise NotImplemented("must be implemented in the child class")
+    def with_offset(self, offset: float) -> ASRToken:
+        # This method is kept for compatibility (typically you will use ASRToken.with_offset)
+        return ASRToken(self.start + offset, self.end + offset, self.text)
+
+    def __repr__(self):
+        return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
+
+    def load_model(self, modelsize, cache_dir, model_dir):
+        raise NotImplementedError("must be implemented in the child class")

     def transcribe(self, audio, init_prompt=""):
-        raise NotImplemented("must be implemented in the child class")
+        raise NotImplementedError("must be implemented in the child class")

     def use_vad(self):
-        raise NotImplemented("must be implemented in the child class")
+        raise NotImplementedError("must be implemented in the child class")


 class WhisperTimestampedASR(ASRBase):
-    """Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
-    On the other hand, the installation for GPU could be easier.
-    """
-
+    """Uses whisper_timestamped as the backend."""
     sep = " "

     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
@@ -64,17 +66,19 @@ class WhisperTimestampedASR(ASRBase):
         )
         return result

-    def ts_words(self, r):
-        # return: transcribe result object to [(beg,end,"word1"), ...]
-        o = []
-        for s in r["segments"]:
-            for w in s["words"]:
-                t = (w["start"], w["end"], w["text"])
-                o.append(t)
-        return o
+    def ts_words(self, r) -> List[ASRToken]:
+        """
+        Converts the whisper_timestamped result to a list of ASRToken objects.
+        """
+        tokens = []
+        for segment in r["segments"]:
+            for word in segment["words"]:
+                token = ASRToken(word["start"], word["end"], word["text"])
+                tokens.append(token)
+        return tokens

-    def segments_end_ts(self, res):
-        return [s["end"] for s in res["segments"]]
+    def segments_end_ts(self, res) -> List[float]:
+        return [segment["end"] for segment in res["segments"]]

     def use_vad(self):
         self.transcribe_kargs["vad"] = True
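To make the new ts_words contract concrete, here is a small usage sketch (not part of the diff) against an invented whisper_timestamped-style result; only the keys the method actually reads are shown, and the timings are made up:

    # Invented result dict; real whisper_timestamped output carries more keys.
    r = {"segments": [{"words": [
        {"start": 0.0, "end": 0.4, "text": " Hello"},
        {"start": 0.4, "end": 0.8, "text": " world"},
    ]}]}
    tokens = asr.ts_words(r)  # assuming `asr` is a loaded WhisperTimestampedASR
    # -> [ASRToken(start=0.00, end=0.40, text=' Hello'),
    #     ASRToken(start=0.40, end=0.80, text=' world')]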
@@ -84,24 +88,20 @@ class WhisperTimestampedASR(ASRBase):


 class FasterWhisperASR(ASRBase):
-    """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version."""
-
+    """Uses faster-whisper as the backend."""
     sep = ""

     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
         from faster_whisper import WhisperModel

         # logging.getLogger("faster_whisper").setLevel(logger.level)
         if model_dir is not None:
-            logger.debug(
-                f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used."
-            )
+            logger.debug(f"Loading whisper model from model_dir {model_dir}. "
+                         f"modelsize and cache_dir parameters are not used.")
             model_size_or_path = model_dir
         elif modelsize is not None:
             model_size_or_path = modelsize
         else:
-            raise ValueError("modelsize or model_dir parameter must be set")
-
+            raise ValueError("Either modelsize or model_dir must be set")
         device = "cuda" if torch.cuda.is_available() else "cpu"
         compute_type = "float16" if device == "cuda" else "float32"

@@ -111,19 +111,9 @@ class FasterWhisperASR(ASRBase):
             compute_type=compute_type,
             download_root=cache_dir,
         )
-
-        # or run on GPU with INT8
-        # tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
-        # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
-
-        # or run on CPU with INT8
-        # tested: works, but slow, appx 10-times than cuda FP16
-        # model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
         return model

-    def transcribe(self, audio, init_prompt=""):
-
-        # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
+    def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
         segments, info = self.model.transcribe(
             audio,
             language=self.original_language,
@@ -133,24 +123,20 @@ class FasterWhisperASR(ASRBase):
             condition_on_previous_text=True,
             **self.transcribe_kargs,
         )
-        # print(info)  # info contains language detection result
-
         return list(segments)

-    def ts_words(self, segments):
-        o = []
+    def ts_words(self, segments) -> List[ASRToken]:
+        tokens = []
         for segment in segments:
+            if segment.no_speech_prob > 0.9:
+                continue
             for word in segment.words:
-                if segment.no_speech_prob > 0.9:
-                    continue
-                # not stripping the spaces -- should not be merged with them!
-                w = word.word
-                t = (word.start, word.end, w)
-                o.append(t)
-        return o
+                token = ASRToken(word.start, word.end, word.word)
+                tokens.append(token)
+        return tokens

-    def segments_end_ts(self, res):
-        return [s.end for s in res]
+    def segments_end_ts(self, segments) -> List[float]:
+        return [segment.end for segment in segments]

     def use_vad(self):
         self.transcribe_kargs["vad_filter"] = True
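Worth noting in the FasterWhisperASR hunk above: the no_speech_prob filter moves from the inner word loop to the segment level, so a non-speech segment is now skipped before its words are iterated; the filtering outcome is the same, with fewer redundant checks. A hedged end-to-end sketch of the new flow (the model size and the 16 kHz mono float32 input are assumptions for illustration, and the transcribe arguments elided by the hunk are presumed to request word timestamps):

    import numpy as np

    asr = FasterWhisperASR(lan="en", modelsize="tiny")
    audio = np.zeros(16000, dtype=np.float32)   # one second of silence at 16 kHz
    segments = asr.transcribe(audio)            # list of faster-whisper segments
    tokens = asr.ts_words(segments)             # List[ASRToken], non-speech segments dropped
    ends = asr.segments_end_ts(segments)        # List[float] of segment end times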
@@ -161,60 +147,29 @@ class FasterWhisperASR(ASRBase):

 class MLXWhisper(ASRBase):
     """
-    Uses MPX Whisper library as the backend, optimized for Apple Silicon.
-    Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
-    Significantly faster than faster-whisper (without CUDA) on Apple M1.
+    Uses MLX Whisper optimized for Apple Silicon.
     """

-    sep = ""  # In my experience in french it should also be no space.
+    sep = ""

     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
-        """
-        Loads the MLX-compatible Whisper model.
-
-        Args:
-            modelsize (str, optional): The size or name of the Whisper model to load.
-                If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
-                Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
-            cache_dir (str, optional): Path to the directory for caching models.
-                **Note**: This is not supported by MLX Whisper and will be ignored.
-            model_dir (str, optional): Direct path to a custom model directory.
-                If specified, it overrides the `modelsize` parameter.
-        """
         from mlx_whisper.transcribe import ModelHolder, transcribe
         import mlx.core as mx

         if model_dir is not None:
-            logger.debug(
-                f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used."
-            )
+            logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
             model_size_or_path = model_dir
         elif modelsize is not None:
             model_size_or_path = self.translate_model_name(modelsize)
-            logger.debug(
-                f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used."
-            )
+            logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
         else:
             raise ValueError("Either modelsize or model_dir must be set")

-        # In mlx_whisper.transcribe, dtype is defined as:
-        # dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
-        # Since we do not use decode_options in self.transcribe, we will set dtype to mx.float16
-        dtype = mx.float16
+        self.model_size_or_path = model_size_or_path
+
+        dtype = mx.float16
         ModelHolder.get_model(model_size_or_path, dtype)
         return transcribe

     def translate_model_name(self, model_name):
         """
         Translates a given model name to its corresponding MLX-compatible model path.

         Args:
             model_name (str): The name of the model to translate.

         Returns:
             str: The MLX-compatible model path.
         """
-        # Dictionary mapping model names to MLX-compatible paths
         model_mapping = {
             "tiny.en": "mlx-community/whisper-tiny.en-mlx",
             "tiny": "mlx-community/whisper-tiny-mlx",
@@ -230,16 +185,11 @@ class MLXWhisper(ASRBase):
             "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
             "large": "mlx-community/whisper-large-mlx",
         }

-        # Retrieve the corresponding MLX model path
         mlx_model_path = model_mapping.get(model_name)

         if mlx_model_path:
             return mlx_model_path
         else:
-            raise ValueError(
-                f"Model name '{model_name}' is not recognized or not supported."
-            )
+            raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")

     def transcribe(self, audio, init_prompt=""):
         if self.transcribe_kargs:
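The mapping-based lookup in translate_model_name means a short alias resolves to a full mlx-community repository path, and unknown names fail fast. A quick illustration (the __new__ call sidesteps __init__, which would otherwise load a model; this is purely for demonstration):

    asr = MLXWhisper.__new__(MLXWhisper)  # skip model loading just to call the helper
    asr.translate_model_name("large-v3-turbo")  # -> "mlx-community/whisper-large-v3-turbo"
    asr.translate_model_name("not-a-model")     # raises ValueError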
@@ -254,18 +204,17 @@ class MLXWhisper(ASRBase):
         )
         return segments.get("segments", [])

-    def ts_words(self, segments):
-        """
-        Extract timestamped words from transcription segments and skips words with high no-speech probability.
-        """
-        return [
-            (word["start"], word["end"], word["word"])
-            for segment in segments
-            for word in segment.get("words", [])
-            if segment.get("no_speech_prob", 0) <= 0.9
-        ]
+    def ts_words(self, segments) -> List[ASRToken]:
+        tokens = []
+        for segment in segments:
+            if segment.get("no_speech_prob", 0) > 0.9:
+                continue
+            for word in segment.get("words", []):
+                token = ASRToken(word["start"], word["end"], word["word"])
+                tokens.append(token)
+        return tokens

-    def segments_end_ts(self, res):
+    def segments_end_ts(self, res) -> List[float]:
         return [s["end"] for s in res]

     def use_vad(self):
@@ -276,68 +225,50 @@ class MLXWhisper(ASRBase):


 class OpenaiApiASR(ASRBase):
-    """Uses OpenAI's Whisper API for audio transcription."""
-
+    """Uses OpenAI's Whisper API for transcription."""
     def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
         self.logfile = logfile

         self.modelname = "whisper-1"
-        self.original_language = (
-            None if lan == "auto" else lan
-        )  # ISO-639-1 language code
+        self.original_language = None if lan == "auto" else lan
         self.response_format = "verbose_json"
         self.temperature = temperature

         self.load_model()

         self.use_vad_opt = False

         # reset the task in set_translate_task
         self.task = "transcribe"

     def load_model(self, *args, **kwargs):
         from openai import OpenAI

         self.client = OpenAI()
-
-        self.transcribed_seconds = (
-            0  # for logging how many seconds were processed by API, to know the cost
-        )
+        self.transcribed_seconds = 0

-    def ts_words(self, segments):
+    def ts_words(self, segments) -> List[ASRToken]:
+        """
+        Converts OpenAI API response words into ASRToken objects while
+        optionally skipping words that fall into no-speech segments.
+        """
         no_speech_segments = []
         if self.use_vad_opt:
             for segment in segments.segments:
                 # TODO: threshold can be set from outside
                 if segment["no_speech_prob"] > 0.8:
-                    no_speech_segments.append(
-                        (segment.get("start"), segment.get("end"))
-                    )
-
-        o = []
+                    no_speech_segments.append((segment.get("start"), segment.get("end")))
+        tokens = []
         for word in segments.words:
             start = word.start
             end = word.end
             if any(s[0] <= start <= s[1] for s in no_speech_segments):
-                # print("Skipping word", word.get("word"), "because it's in a no-speech segment")
                 continue
-            o.append((start, end, word.word))
-        return o
+            tokens.append(ASRToken(start, end, word.word))
+        return tokens

-    def segments_end_ts(self, res):
+    def segments_end_ts(self, res) -> List[float]:
         return [s.end for s in res.words]

     def transcribe(self, audio_data, prompt=None, *args, **kwargs):
         # Write the audio data to a buffer
         buffer = io.BytesIO()
         buffer.name = "temp.wav"
         sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
-        buffer.seek(0)  # Reset buffer's position to the beginning
-
-        self.transcribed_seconds += math.ceil(
-            len(audio_data) / 16000
-        )  # it rounds up to the whole seconds
-
+        buffer.seek(0)
+        self.transcribed_seconds += math.ceil(len(audio_data) / 16000)
         params = {
             "model": self.modelname,
             "file": buffer,
@@ -349,22 +280,13 @@ class OpenaiApiASR(ASRBase):
             params["language"] = self.original_language
         if prompt:
             params["prompt"] = prompt

-        if self.task == "translate":
-            proc = self.client.audio.translations
-        else:
-            proc = self.client.audio.transcriptions
-
-        # Process transcription/translation
+        proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
         transcript = proc.create(**params)
-        logger.debug(
-            f"OpenAI API processed accumulated {self.transcribed_seconds} seconds"
-        )
-
+        logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
         return transcript

     def use_vad(self):
         self.use_vad_opt = True

     def set_translate_task(self):
-        self.task = "translate"
+        self.task = "translate"
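One detail of OpenaiApiASR.transcribe worth spelling out: the audio is written to an in-memory WAV buffer (nothing touches disk), and self.transcribed_seconds accrues math.ceil(len(audio_data) / 16000), i.e. the clip duration at 16 kHz rounded up to whole seconds, as a rough cost counter. A worked example of that arithmetic (the 2.3-second clip length is invented):

    import math

    samples = int(2.3 * 16000)     # a 2.3-second clip at 16 kHz -> 36800 samples
    math.ceil(samples / 16000)     # -> 3, counted as three whole seconds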