From b82cc3b61396a64d6bad9f25efaa6fcb0f2f6ea1 Mon Sep 17 00:00:00 2001
From: Quentin Fuxa
Date: Fri, 7 Feb 2025 12:24:37 +0100
Subject: [PATCH] adapt backend for the new classes

---
 src/whisper_streaming/backends.py | 229 ++++++++++--------------------
 1 file changed, 72 insertions(+), 157 deletions(-)

diff --git a/src/whisper_streaming/backends.py b/src/whisper_streaming/backends.py
index 20522ed..514dded 100644
--- a/src/whisper_streaming/backends.py
+++ b/src/whisper_streaming/backends.py
@@ -1,45 +1,40 @@
 import sys
 import logging
-
 import io
 import soundfile as sf
 import math
 import torch
+from typing import List
+import numpy as np
+from src.whisper_streaming.asr_token import ASRToken
 
 logger = logging.getLogger(__name__)
 
 class ASRBase:
     sep = " "  # join transcribe words with this character (" " for whisper_timestamped,
-    # "" for faster-whisper because it emits the spaces when neeeded)
+    # "" for faster-whisper because it emits the spaces when needed)
 
-    def __init__(
-        self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr
-    ):
+    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
         self.logfile = logfile
-
         self.transcribe_kargs = {}
         if lan == "auto":
             self.original_language = None
         else:
             self.original_language = lan
-
         self.model = self.load_model(modelsize, cache_dir, model_dir)
 
-    def load_model(self, modelsize, cache_dir):
-        raise NotImplemented("must be implemented in the child class")
+    def load_model(self, modelsize, cache_dir, model_dir):
+        raise NotImplementedError("must be implemented in the child class")
 
     def transcribe(self, audio, init_prompt=""):
-        raise NotImplemented("must be implemented in the child class")
+        raise NotImplementedError("must be implemented in the child class")
 
     def use_vad(self):
-        raise NotImplemented("must be implemented in the child class")
+        raise NotImplementedError("must be implemented in the child class")
 
 
 class WhisperTimestampedASR(ASRBase):
-    """Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
-    On the other hand, the installation for GPU could be easier.
-    """
-
+    """Uses whisper_timestamped as the backend."""
     sep = " "
 
     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
@@ -64,17 +66,19 @@ class WhisperTimestampedASR(ASRBase):
         )
         return result
 
-    def ts_words(self, r):
-        # return: transcribe result object to [(beg,end,"word1"), ...]
-        o = []
-        for s in r["segments"]:
-            for w in s["words"]:
-                t = (w["start"], w["end"], w["text"])
-                o.append(t)
-        return o
+    def ts_words(self, r) -> List[ASRToken]:
+        """
+        Converts the whisper_timestamped result to a list of ASRToken objects.
+        """
+        tokens = []
+        for segment in r["segments"]:
+            for word in segment["words"]:
+                token = ASRToken(word["start"], word["end"], word["text"])
+                tokens.append(token)
+        return tokens
 
-    def segments_end_ts(self, res):
-        return [s["end"] for s in res["segments"]]
+    def segments_end_ts(self, res) -> List[float]:
+        return [segment["end"] for segment in res["segments"]]
 
     def use_vad(self):
         self.transcribe_kargs["vad"] = True
@@ -84,24 +88,20 @@ class WhisperTimestampedASR(ASRBase):
 
 
 class FasterWhisperASR(ASRBase):
-    """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version."""
-
+    """Uses faster-whisper as the backend."""
     sep = ""
 
     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
         from faster_whisper import WhisperModel
 
-        # logging.getLogger("faster_whisper").setLevel(logger.level)
         if model_dir is not None:
-            logger.debug(
-                f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used."
-            )
+            logger.debug(f"Loading whisper model from model_dir {model_dir}. "
+                         f"modelsize and cache_dir parameters are not used.")
             model_size_or_path = model_dir
         elif modelsize is not None:
             model_size_or_path = modelsize
         else:
-            raise ValueError("modelsize or model_dir parameter must be set")
-
+            raise ValueError("Either modelsize or model_dir must be set")
         device = "cuda" if torch.cuda.is_available() else "cpu"
         compute_type = "float16" if device == "cuda" else "float32"
 
@@ -111,19 +111,9 @@ class FasterWhisperASR(ASRBase):
             compute_type=compute_type,
             download_root=cache_dir,
         )
-
-        # or run on GPU with INT8
-        # tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
-        # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
-
-        # or run on CPU with INT8
-        # tested: works, but slow, appx 10-times than cuda FP16
-        # model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
         return model
 
-    def transcribe(self, audio, init_prompt=""):
-
-        # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
+    def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
         segments, info = self.model.transcribe(
             audio,
             language=self.original_language,
@@ -133,24 +123,20 @@ class FasterWhisperASR(ASRBase):
             condition_on_previous_text=True,
             **self.transcribe_kargs,
         )
-        # print(info)  # info contains language detection result
-
         return list(segments)
 
-    def ts_words(self, segments):
-        o = []
+    def ts_words(self, segments) -> List[ASRToken]:
+        tokens = []
         for segment in segments:
+            if segment.no_speech_prob > 0.9:
+                continue
             for word in segment.words:
-                if segment.no_speech_prob > 0.9:
-                    continue
-                # not stripping the spaces -- should not be merged with them!
-                w = word.word
-                t = (word.start, word.end, w)
-                o.append(t)
-        return o
+                token = ASRToken(word.start, word.end, word.word)
+                tokens.append(token)
+        return tokens
 
-    def segments_end_ts(self, res):
-        return [s.end for s in res]
+    def segments_end_ts(self, segments) -> List[float]:
+        return [segment.end for segment in segments]
 
     def use_vad(self):
         self.transcribe_kargs["vad_filter"] = True
@@ -161,60 +147,29 @@ class FasterWhisperASR(ASRBase):
 
 class MLXWhisper(ASRBase):
     """
-    Uses MPX Whisper library as the backend, optimized for Apple Silicon.
-    Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
-    Significantly faster than faster-whisper (without CUDA) on Apple M1.
+    Uses MLX Whisper optimized for Apple Silicon.
     """
-
-    sep = ""  # In my experience in french it should also be no space.
+    sep = ""
 
     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
-        """
-        Loads the MLX-compatible Whisper model.
-
-        Args:
-            modelsize (str, optional): The size or name of the Whisper model to load.
-                If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
-                Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
-            cache_dir (str, optional): Path to the directory for caching models.
-                **Note**: This is not supported by MLX Whisper and will be ignored.
-            model_dir (str, optional): Direct path to a custom model directory.
-                If specified, it overrides the `modelsize` parameter.
-        """
         from mlx_whisper.transcribe import ModelHolder, transcribe
         import mlx.core as mx
 
         if model_dir is not None:
-            logger.debug(
-                f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used."
-            )
+            logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
             model_size_or_path = model_dir
         elif modelsize is not None:
            model_size_or_path = self.translate_model_name(modelsize)
-            logger.debug(
-                f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used."
-            )
+            logger.debug(f"Loading whisper model {modelsize}. Since mlx_whisper is used, {model_size_or_path} will be loaded.")
+        else:
+            raise ValueError("Either modelsize or model_dir must be set")
 
         self.model_size_or_path = model_size_or_path
-
-        # In mlx_whisper.transcribe, dtype is defined as:
-        # dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
-        # Since we do not use decode_options in self.transcribe, we will set dtype to mx.float16
-        dtype = mx.float16
+        dtype = mx.float16
         ModelHolder.get_model(model_size_or_path, dtype)
         return transcribe
 
     def translate_model_name(self, model_name):
-        """
-        Translates a given model name to its corresponding MLX-compatible model path.
-
-        Args:
-            model_name (str): The name of the model to translate.
-
-        Returns:
-            str: The MLX-compatible model path.
-        """
-        # Dictionary mapping model names to MLX-compatible paths
         model_mapping = {
             "tiny.en": "mlx-community/whisper-tiny.en-mlx",
             "tiny": "mlx-community/whisper-tiny-mlx",
@@ -230,16 +185,11 @@ class MLXWhisper(ASRBase):
             "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
             "large": "mlx-community/whisper-large-mlx",
         }
-
-        # Retrieve the corresponding MLX model path
         mlx_model_path = model_mapping.get(model_name)
-
         if mlx_model_path:
             return mlx_model_path
         else:
-            raise ValueError(
-                f"Model name '{model_name}' is not recognized or not supported."
-            )
+            raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
 
     def transcribe(self, audio, init_prompt=""):
         if self.transcribe_kargs:
@@ -254,18 +204,17 @@ class MLXWhisper(ASRBase):
         )
         return segments.get("segments", [])
 
-    def ts_words(self, segments):
-        """
-        Extract timestamped words from transcription segments and skips words with high no-speech probability.
-        """
-        return [
-            (word["start"], word["end"], word["word"])
-            for segment in segments
-            for word in segment.get("words", [])
-            if segment.get("no_speech_prob", 0) <= 0.9
-        ]
+    def ts_words(self, segments) -> List[ASRToken]:
+        tokens = []
+        for segment in segments:
+            if segment.get("no_speech_prob", 0) > 0.9:
+                continue
+            for word in segment.get("words", []):
+                token = ASRToken(word["start"], word["end"], word["word"])
+                tokens.append(token)
+        return tokens
 
-    def segments_end_ts(self, res):
+    def segments_end_ts(self, res) -> List[float]:
         return [s["end"] for s in res]
 
     def use_vad(self):
@@ -276,68 +225,50 @@ class MLXWhisper(ASRBase):
 
 
 class OpenaiApiASR(ASRBase):
-    """Uses OpenAI's Whisper API for audio transcription."""
-
+    """Uses OpenAI's Whisper API for transcription."""
     def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
         self.logfile = logfile
-
         self.modelname = "whisper-1"
-        self.original_language = (
-            None if lan == "auto" else lan
-        )  # ISO-639-1 language code
+        self.original_language = None if lan == "auto" else lan
         self.response_format = "verbose_json"
         self.temperature = temperature
-
         self.load_model()
-
         self.use_vad_opt = False
-
-        # reset the task in set_translate_task
         self.task = "transcribe"
 
     def load_model(self, *args, **kwargs):
         from openai import OpenAI
-
         self.client = OpenAI()
+        self.transcribed_seconds = 0
 
-        self.transcribed_seconds = (
-            0  # for logging how many seconds were processed by API, to know the cost
-        )
-
-    def ts_words(self, segments):
+    def ts_words(self, segments) -> List[ASRToken]:
+        """
+        Converts OpenAI API response words into ASRToken objects while
+        optionally skipping words that fall into no-speech segments.
+        """
         no_speech_segments = []
         if self.use_vad_opt:
             for segment in segments.segments:
-                # TODO: threshold can be set from outside
                 if segment["no_speech_prob"] > 0.8:
-                    no_speech_segments.append(
-                        (segment.get("start"), segment.get("end"))
-                    )
-
-        o = []
+                    no_speech_segments.append((segment.get("start"), segment.get("end")))
        tokens = []
         for word in segments.words:
             start = word.start
             end = word.end
             if any(s[0] <= start <= s[1] for s in no_speech_segments):
-                # print("Skipping word", word.get("word"), "because it's in a no-speech segment")
                 continue
-            o.append((start, end, word.word))
-        return o
+            tokens.append(ASRToken(start, end, word.word))
+        return tokens
 
-    def segments_end_ts(self, res):
+    def segments_end_ts(self, res) -> List[float]:
         return [s.end for s in res.words]
 
     def transcribe(self, audio_data, prompt=None, *args, **kwargs):
-        # Write the audio data to a buffer
         buffer = io.BytesIO()
         buffer.name = "temp.wav"
         sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
-        buffer.seek(0)  # Reset buffer's position to the beginning
-
-        self.transcribed_seconds += math.ceil(
-            len(audio_data) / 16000
-        )  # it rounds up to the whole seconds
-
+        buffer.seek(0)
+        self.transcribed_seconds += math.ceil(len(audio_data) / 16000)
         params = {
             "model": self.modelname,
             "file": buffer,
@@ -349,22 +280,13 @@ class OpenaiApiASR(ASRBase):
             params["language"] = self.original_language
         if prompt:
             params["prompt"] = prompt
-
-        if self.task == "translate":
-            proc = self.client.audio.translations
-        else:
-            proc = self.client.audio.transcriptions
-
-        # Process transcription/translation
+        proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
         transcript = proc.create(**params)
-        logger.debug(
-            f"OpenAI API processed accumulated {self.transcribed_seconds} seconds"
-        )
-
+        logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
         return transcript
 
     def use_vad(self):
         self.use_vad_opt = True
 
     def set_translate_task(self):
-        self.task = "translate"
+        self.task = "translate"
\ No newline at end of file
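
Note: the new src/whisper_streaming/asr_token.py module that backends.py now imports is not included in this diff. The backends only rely on a small surface: an ASRToken(start, end, text) constructor, the .start/.end/.text attributes, and a with_offset() helper (referenced by the compatibility comment removed above). A minimal sketch of that assumed interface, for review purposes only -- the real module may differ:

    # Sketch only; field names are inferred from the ASRToken(...) call
    # sites in backends.py and from the with_offset() comment in the diff.
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ASRToken:
        start: float  # word start time in seconds
        end: float    # word end time in seconds
        text: str     # word text; spacing is preserved (see ASRBase.sep)

        def with_offset(self, offset: float) -> "ASRToken":
            # Shift a buffer-relative token by a fixed offset, e.g. to
            # convert it to absolute stream time.
            return ASRToken(self.start + offset, self.end + offset, self.text)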
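
How a caller consumes the new return type (illustrative sketch only; assumes the faster-whisper backend, a downloadable "tiny" model, and a 16 kHz mono float32 buffer):

    import numpy as np
    from src.whisper_streaming.backends import FasterWhisperASR

    asr = FasterWhisperASR(lan="en", modelsize="tiny")
    audio = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz

    segments = asr.transcribe(audio)
    tokens = asr.ts_words(segments)  # List[ASRToken] instead of (beg, end, text) tuples

    # Timestamps are relative to the transcribed buffer; shift them to
    # absolute stream time and join the token texts with the backend's sep.
    shifted = [token.with_offset(42.0) for token in tokens]
    transcript = asr.sep.join(token.text for token in tokens)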