split whisper_online.py into smaller files

This commit is contained in:
Quentin Fuxa
2025-01-14 20:52:53 +01:00
parent 9cbac96c44
commit ce56264241
7 changed files with 778 additions and 773 deletions

View File

@@ -12,12 +12,12 @@ This project extends the [Whisper Streaming](https://github.com/ufal/whisper_str
5. **MLX Whisper backend**: Integrates the alternative backend option MLX Whisper, optimized for efficient speech recognition on Apple silicon.
![Demo Screenshot](src/demo.png)
![Demo Screenshot](src/web/demo.png)
## Code Origins
This project reuses and extends code from the original Whisper Streaming repository:
- whisper_online.py: Contains code from whisper_streaming
- whisper_online.py, backends.py and online_asr.py: Contains code from whisper_streaming
- silero_vad_iterator.py: Originally from the Silero VAD repository, included in the whisper_streaming project.
## Installation

View File

Before

Width:  |  Height:  |  Size: 81 KiB

After

Width:  |  Height:  |  Size: 81 KiB

View File

@@ -0,0 +1,368 @@
import sys
import logging
import io
import soundfile as sf
import math
logger = logging.getLogger(__name__)
class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
# "" for faster-whisper because it emits the spaces when neeeded)
def __init__(
self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr
):
self.logfile = logfile
self.transcribe_kargs = {}
if lan == "auto":
self.original_language = None
else:
self.original_language = lan
self.model = self.load_model(modelsize, cache_dir, model_dir)
def load_model(self, modelsize, cache_dir):
raise NotImplemented("must be implemented in the child class")
def transcribe(self, audio, init_prompt=""):
raise NotImplemented("must be implemented in the child class")
def use_vad(self):
raise NotImplemented("must be implemented in the child class")
class WhisperTimestampedASR(ASRBase):
"""Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
On the other hand, the installation for GPU could be easier.
"""
sep = " "
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
import whisper
import whisper_timestamped
from whisper_timestamped import transcribe_timestamped
self.transcribe_timestamped = transcribe_timestamped
if model_dir is not None:
logger.debug("ignoring model_dir, not implemented")
return whisper.load_model(modelsize, download_root=cache_dir)
def transcribe(self, audio, init_prompt=""):
result = self.transcribe_timestamped(
self.model,
audio,
language=self.original_language,
initial_prompt=init_prompt,
verbose=None,
condition_on_previous_text=True,
**self.transcribe_kargs,
)
return result
def ts_words(self, r):
# return: transcribe result object to [(beg,end,"word1"), ...]
o = []
for s in r["segments"]:
for w in s["words"]:
t = (w["start"], w["end"], w["text"])
o.append(t)
return o
def segments_end_ts(self, res):
return [s["end"] for s in res["segments"]]
def use_vad(self):
self.transcribe_kargs["vad"] = True
def set_translate_task(self):
self.transcribe_kargs["task"] = "translate"
class FasterWhisperASR(ASRBase):
"""Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version."""
sep = ""
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
from faster_whisper import WhisperModel
# logging.getLogger("faster_whisper").setLevel(logger.level)
if model_dir is not None:
logger.debug(
f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used."
)
model_size_or_path = model_dir
elif modelsize is not None:
model_size_or_path = modelsize
else:
raise ValueError("modelsize or model_dir parameter must be set")
# this worked fast and reliably on NVIDIA L40
model = WhisperModel(
model_size_or_path,
device="cuda",
compute_type="float16",
download_root=cache_dir,
)
# or run on GPU with INT8
# tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
# model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
# or run on CPU with INT8
# tested: works, but slow, appx 10-times than cuda FP16
# model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
return model
def transcribe(self, audio, init_prompt=""):
# tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
segments, info = self.model.transcribe(
audio,
language=self.original_language,
initial_prompt=init_prompt,
beam_size=5,
word_timestamps=True,
condition_on_previous_text=True,
**self.transcribe_kargs,
)
# print(info) # info contains language detection result
return list(segments)
def ts_words(self, segments):
o = []
for segment in segments:
for word in segment.words:
if segment.no_speech_prob > 0.9:
continue
# not stripping the spaces -- should not be merged with them!
w = word.word
t = (word.start, word.end, w)
o.append(t)
return o
def segments_end_ts(self, res):
return [s.end for s in res]
def use_vad(self):
self.transcribe_kargs["vad_filter"] = True
def set_translate_task(self):
self.transcribe_kargs["task"] = "translate"
class MLXWhisper(ASRBase):
"""
Uses MPX Whisper library as the backend, optimized for Apple Silicon.
Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
Significantly faster than faster-whisper (without CUDA) on Apple M1.
"""
sep = " "
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
"""
Loads the MLX-compatible Whisper model.
Args:
modelsize (str, optional): The size or name of the Whisper model to load.
If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
cache_dir (str, optional): Path to the directory for caching models.
**Note**: This is not supported by MLX Whisper and will be ignored.
model_dir (str, optional): Direct path to a custom model directory.
If specified, it overrides the `modelsize` parameter.
"""
from mlx_whisper.transcribe import ModelHolder, transcribe
import mlx.core as mx
if model_dir is not None:
logger.debug(
f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used."
)
model_size_or_path = model_dir
elif modelsize is not None:
model_size_or_path = self.translate_model_name(modelsize)
logger.debug(
f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used."
)
self.model_size_or_path = model_size_or_path
# In mlx_whisper.transcribe, dtype is defined as:
# dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
# Since we do not use decode_options in self.transcribe, we will set dtype to mx.float16
dtype = mx.float16
ModelHolder.get_model(model_size_or_path, dtype)
return transcribe
def translate_model_name(self, model_name):
"""
Translates a given model name to its corresponding MLX-compatible model path.
Args:
model_name (str): The name of the model to translate.
Returns:
str: The MLX-compatible model path.
"""
# Dictionary mapping model names to MLX-compatible paths
model_mapping = {
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
"tiny": "mlx-community/whisper-tiny-mlx",
"base.en": "mlx-community/whisper-base.en-mlx",
"base": "mlx-community/whisper-base-mlx",
"small.en": "mlx-community/whisper-small.en-mlx",
"small": "mlx-community/whisper-small-mlx",
"medium.en": "mlx-community/whisper-medium.en-mlx",
"medium": "mlx-community/whisper-medium-mlx",
"large-v1": "mlx-community/whisper-large-v1-mlx",
"large-v2": "mlx-community/whisper-large-v2-mlx",
"large-v3": "mlx-community/whisper-large-v3-mlx",
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
"large": "mlx-community/whisper-large-mlx",
}
# Retrieve the corresponding MLX model path
mlx_model_path = model_mapping.get(model_name)
if mlx_model_path:
return mlx_model_path
else:
raise ValueError(
f"Model name '{model_name}' is not recognized or not supported."
)
def transcribe(self, audio, init_prompt=""):
if self.transcribe_kargs:
logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
segments = self.model(
audio,
language=self.original_language,
initial_prompt=init_prompt,
word_timestamps=True,
condition_on_previous_text=True,
path_or_hf_repo=self.model_size_or_path,
)
return segments.get("segments", [])
def ts_words(self, segments):
"""
Extract timestamped words from transcription segments and skips words with high no-speech probability.
"""
return [
(word["start"], word["end"], word["word"])
for segment in segments
for word in segment.get("words", [])
if segment.get("no_speech_prob", 0) <= 0.9
]
def segments_end_ts(self, res):
return [s["end"] for s in res]
def use_vad(self):
self.transcribe_kargs["vad_filter"] = True
def set_translate_task(self):
self.transcribe_kargs["task"] = "translate"
class OpenaiApiASR(ASRBase):
"""Uses OpenAI's Whisper API for audio transcription."""
def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
self.logfile = logfile
self.modelname = "whisper-1"
self.original_language = (
None if lan == "auto" else lan
) # ISO-639-1 language code
self.response_format = "verbose_json"
self.temperature = temperature
self.load_model()
self.use_vad_opt = False
# reset the task in set_translate_task
self.task = "transcribe"
def load_model(self, *args, **kwargs):
from openai import OpenAI
self.client = OpenAI()
self.transcribed_seconds = (
0 # for logging how many seconds were processed by API, to know the cost
)
def ts_words(self, segments):
no_speech_segments = []
if self.use_vad_opt:
for segment in segments.segments:
# TODO: threshold can be set from outside
if segment["no_speech_prob"] > 0.8:
no_speech_segments.append(
(segment.get("start"), segment.get("end"))
)
o = []
for word in segments.words:
start = word.start
end = word.end
if any(s[0] <= start <= s[1] for s in no_speech_segments):
# print("Skipping word", word.get("word"), "because it's in a no-speech segment")
continue
o.append((start, end, word.word))
return o
def segments_end_ts(self, res):
return [s.end for s in res.words]
def transcribe(self, audio_data, prompt=None, *args, **kwargs):
# Write the audio data to a buffer
buffer = io.BytesIO()
buffer.name = "temp.wav"
sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
buffer.seek(0) # Reset buffer's position to the beginning
self.transcribed_seconds += math.ceil(
len(audio_data) / 16000
) # it rounds up to the whole seconds
params = {
"model": self.modelname,
"file": buffer,
"response_format": self.response_format,
"temperature": self.temperature,
"timestamp_granularities": ["word", "segment"],
}
if self.task != "translate" and self.original_language:
params["language"] = self.original_language
if prompt:
params["prompt"] = prompt
if self.task == "translate":
proc = self.client.audio.translations
else:
proc = self.client.audio.transcriptions
# Process transcription/translation
transcript = proc.create(**params)
logger.debug(
f"OpenAI API processed accumulated {self.transcribed_seconds} seconds"
)
return transcript
def use_vad(self):
self.use_vad_opt = True
def set_translate_task(self):
self.task = "translate"

View File

@@ -0,0 +1,401 @@
import sys
import numpy as np
import logging
logger = logging.getLogger(__name__)
class HypothesisBuffer:
def __init__(self, logfile=sys.stderr):
self.commited_in_buffer = []
self.buffer = []
self.new = []
self.last_commited_time = 0
self.last_commited_word = None
self.logfile = logfile
def insert(self, new, offset):
# compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
# the new tail is added to self.new
new = [(a + offset, b + offset, t) for a, b, t in new]
self.new = [(a, b, t) for a, b, t in new if a > self.last_commited_time - 0.1]
if len(self.new) >= 1:
a, b, t = self.new[0]
if abs(a - self.last_commited_time) < 1:
if self.commited_in_buffer:
# it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped.
cn = len(self.commited_in_buffer)
nn = len(self.new)
for i in range(1, min(min(cn, nn), 5) + 1): # 5 is the maximum
c = " ".join(
[self.commited_in_buffer[-j][2] for j in range(1, i + 1)][
::-1
]
)
tail = " ".join(self.new[j - 1][2] for j in range(1, i + 1))
if c == tail:
words = []
for j in range(i):
words.append(repr(self.new.pop(0)))
words_msg = " ".join(words)
logger.debug(f"removing last {i} words: {words_msg}")
break
def flush(self):
# returns commited chunk = the longest common prefix of 2 last inserts.
commit = []
while self.new:
na, nb, nt = self.new[0]
if len(self.buffer) == 0:
break
if nt == self.buffer[0][2]:
commit.append((na, nb, nt))
self.last_commited_word = nt
self.last_commited_time = nb
self.buffer.pop(0)
self.new.pop(0)
else:
break
self.buffer = self.new
self.new = []
self.commited_in_buffer.extend(commit)
return commit
def pop_commited(self, time):
while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time:
self.commited_in_buffer.pop(0)
def complete(self):
return self.buffer
class OnlineASRProcessor:
SAMPLING_RATE = 16000
def __init__(
self,
asr,
tokenize_method=None,
buffer_trimming=("segment", 15),
logfile=sys.stderr,
):
"""asr: WhisperASR object
tokenize_method: sentence tokenizer function for the target language. Must be a callable and behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all.
("segment", 15)
buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. Buffer is trimmed if it is longer than "seconds" threshold. Default is the most recommended option.
logfile: where to store the log.
"""
self.asr = asr
self.tokenize = tokenize_method
self.logfile = logfile
self.init()
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
def init(self, offset=None):
"""run this when starting or restarting processing"""
self.audio_buffer = np.array([], dtype=np.float32)
self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
self.buffer_time_offset = 0
if offset is not None:
self.buffer_time_offset = offset
self.transcript_buffer.last_commited_time = self.buffer_time_offset
self.commited = []
def insert_audio_chunk(self, audio):
self.audio_buffer = np.append(self.audio_buffer, audio)
def prompt(self):
"""Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
"context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
"""
k = max(0, len(self.commited) - 1)
while k > 0 and self.commited[k - 1][1] > self.buffer_time_offset:
k -= 1
p = self.commited[:k]
p = [t for _, _, t in p]
prompt = []
l = 0
while p and l < 200: # 200 characters prompt size
x = p.pop(-1)
l += len(x) + 1
prompt.append(x)
non_prompt = self.commited[k:]
return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(
t for _, _, t in non_prompt
)
def process_iter(self):
"""Runs on the current audio buffer.
Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
The non-emty text is confirmed (committed) partial transcript.
"""
prompt, non_prompt = self.prompt()
logger.debug(f"PROMPT: {prompt}")
logger.debug(f"CONTEXT: {non_prompt}")
logger.debug(
f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}"
)
res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
# transform to [(beg,end,"word1"), ...]
tsw = self.asr.ts_words(res)
self.transcript_buffer.insert(tsw, self.buffer_time_offset)
o = self.transcript_buffer.flush()
self.commited.extend(o)
completed = self.to_flush(o)
logger.debug(f">>>>COMPLETE NOW: {completed[2]}")
the_rest = self.to_flush(self.transcript_buffer.complete())
logger.debug(f"INCOMPLETE: {the_rest[2]}")
# there is a newly confirmed text
if o and self.buffer_trimming_way == "sentence": # trim the completed sentences
if (
len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec
): # longer than this
self.chunk_completed_sentence()
if self.buffer_trimming_way == "segment":
s = self.buffer_trimming_sec # trim the completed segments longer than s,
else:
s = 30 # if the audio buffer is longer than 30s, trim it
if len(self.audio_buffer) / self.SAMPLING_RATE > s:
self.chunk_completed_segment(res)
# alternative: on any word
# l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
# let's find commited word that is less
# k = len(self.commited)-1
# while k>0 and self.commited[k][1] > l:
# k -= 1
# t = self.commited[k][1]
logger.debug("chunking segment")
# self.chunk_at(t)
logger.debug(
f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}"
)
return self.to_flush(o)
def chunk_completed_sentence(self):
if self.commited == []:
return
logger.debug("COMPLETED SENTENCE: ", [s[2] for s in self.commited])
sents = self.words_to_sentences(self.commited)
for s in sents:
logger.debug(f"\t\tSENT: {s}")
if len(sents) < 2:
return
while len(sents) > 2:
sents.pop(0)
# we will continue with audio processing at this timestamp
chunk_at = sents[-2][1]
logger.debug(f"--- sentence chunked at {chunk_at:2.2f}")
self.chunk_at(chunk_at)
def chunk_completed_segment(self, res):
if self.commited == []:
return
ends = self.asr.segments_end_ts(res)
t = self.commited[-1][1]
if len(ends) > 1:
e = ends[-2] + self.buffer_time_offset
while len(ends) > 2 and e > t:
ends.pop(-1)
e = ends[-2] + self.buffer_time_offset
if e <= t:
logger.debug(f"--- segment chunked at {e:2.2f}")
self.chunk_at(e)
else:
logger.debug(f"--- last segment not within commited area")
else:
logger.debug(f"--- not enough segments to chunk")
def chunk_at(self, time):
"""trims the hypothesis and audio buffer at "time" """
self.transcript_buffer.pop_commited(time)
cut_seconds = time - self.buffer_time_offset
self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE) :]
self.buffer_time_offset = time
def words_to_sentences(self, words):
"""Uses self.tokenize for sentence segmentation of words.
Returns: [(beg,end,"sentence 1"),...]
"""
cwords = [w for w in words]
t = " ".join(o[2] for o in cwords)
s = self.tokenize(t)
out = []
while s:
beg = None
end = None
sent = s.pop(0).strip()
fsent = sent
while cwords:
b, e, w = cwords.pop(0)
w = w.strip()
if beg is None and sent.startswith(w):
beg = b
elif end is None and sent == w:
end = e
out.append((beg, end, fsent))
break
sent = sent[len(w) :].strip()
return out
def finish(self):
"""Flush the incomplete text when the whole processing ends.
Returns: the same format as self.process_iter()
"""
o = self.transcript_buffer.complete()
f = self.to_flush(o)
logger.debug(f"last, noncommited: {f}")
self.buffer_time_offset += len(self.audio_buffer) / 16000
return f
def to_flush(
self,
sents,
sep=None,
offset=0,
):
# concatenates the timestamped words or sentences into one sequence that is flushed in one line
# sents: [(beg1, end1, "sentence1"), ...] or [] if empty
# return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
if sep is None:
sep = self.asr.sep
t = sep.join(s[2] for s in sents)
if len(sents) == 0:
b = None
e = None
else:
b = offset + sents[0][0]
e = offset + sents[-1][1]
return (b, e, t)
class VACOnlineASRProcessor(OnlineASRProcessor):
"""Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
it runs VAD and continuously detects whether there is speech or not.
When it detects end of speech (non-voice for 500ms), it makes OnlineASRProcessor to end the utterance immediately.
"""
def __init__(self, online_chunk_size, *a, **kw):
self.online_chunk_size = online_chunk_size
self.online = OnlineASRProcessor(*a, **kw)
# VAC:
import torch
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
from silero_vad_iterator import FixedVADIterator
self.vac = FixedVADIterator(
model
) # we use the default options there: 500ms silence, 100ms padding, etc.
self.logfile = self.online.logfile
self.init()
def init(self):
self.online.init()
self.vac.reset_states()
self.current_online_chunk_buffer_size = 0
self.is_currently_final = False
self.status = None # or "voice" or "nonvoice"
self.audio_buffer = np.array([], dtype=np.float32)
self.buffer_offset = 0 # in frames
def clear_buffer(self):
self.buffer_offset += len(self.audio_buffer)
self.audio_buffer = np.array([], dtype=np.float32)
def insert_audio_chunk(self, audio):
res = self.vac(audio)
self.audio_buffer = np.append(self.audio_buffer, audio)
if res is not None:
frame = list(res.values())[0] - self.buffer_offset
if "start" in res and "end" not in res:
self.status = "voice"
send_audio = self.audio_buffer[frame:]
self.online.init(
offset=(frame + self.buffer_offset) / self.SAMPLING_RATE
)
self.online.insert_audio_chunk(send_audio)
self.current_online_chunk_buffer_size += len(send_audio)
self.clear_buffer()
elif "end" in res and "start" not in res:
self.status = "nonvoice"
send_audio = self.audio_buffer[:frame]
self.online.insert_audio_chunk(send_audio)
self.current_online_chunk_buffer_size += len(send_audio)
self.is_currently_final = True
self.clear_buffer()
else:
beg = res["start"] - self.buffer_offset
end = res["end"] - self.buffer_offset
self.status = "nonvoice"
send_audio = self.audio_buffer[beg:end]
self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
self.online.insert_audio_chunk(send_audio)
self.current_online_chunk_buffer_size += len(send_audio)
self.is_currently_final = True
self.clear_buffer()
else:
if self.status == "voice":
self.online.insert_audio_chunk(self.audio_buffer)
self.current_online_chunk_buffer_size += len(self.audio_buffer)
self.clear_buffer()
else:
# We keep 1 second because VAD may later find start of voice in it.
# But we trim it to prevent OOM.
self.buffer_offset += max(
0, len(self.audio_buffer) - self.SAMPLING_RATE
)
self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE :]
def process_iter(self):
if self.is_currently_final:
return self.finish()
elif (
self.current_online_chunk_buffer_size
> self.SAMPLING_RATE * self.online_chunk_size
):
self.current_online_chunk_buffer_size = 0
ret = self.online.process_iter()
return ret
else:
print("no online update, only VAD", self.status, file=self.logfile)
return (None, None, "")
def finish(self):
ret = self.online.finish()
self.current_online_chunk_buffer_size = 0
self.is_currently_final = False
return ret

View File

@@ -43,7 +43,7 @@ args = parser.parse_args()
asr, tokenizer = backend_factory(args)
# Load demo HTML for the root endpoint
with open("src/live_transcription.html", "r", encoding="utf-8") as f:
with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
html = f.read()

View File

@@ -5,10 +5,8 @@ import librosa
from functools import lru_cache
import time
import logging
import io
import soundfile as sf
import math
from src.whisper_streaming.backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
from src.whisper_streaming.online_asr import OnlineASRProcessor, VACOnlineASRProcessor
logger = logging.getLogger(__name__)
@@ -25,768 +23,6 @@ def load_audio_chunk(fname, beg, end):
end_s = int(end * 16000)
return audio[beg_s:end_s]
# Whisper backend
class ASRBase:
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
# "" for faster-whisper because it emits the spaces when neeeded)
def __init__(
self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr
):
self.logfile = logfile
self.transcribe_kargs = {}
if lan == "auto":
self.original_language = None
else:
self.original_language = lan
self.model = self.load_model(modelsize, cache_dir, model_dir)
def load_model(self, modelsize, cache_dir):
raise NotImplemented("must be implemented in the child class")
def transcribe(self, audio, init_prompt=""):
raise NotImplemented("must be implemented in the child class")
def use_vad(self):
raise NotImplemented("must be implemented in the child class")
class WhisperTimestampedASR(ASRBase):
"""Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
On the other hand, the installation for GPU could be easier.
"""
sep = " "
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
import whisper
import whisper_timestamped
from whisper_timestamped import transcribe_timestamped
self.transcribe_timestamped = transcribe_timestamped
if model_dir is not None:
logger.debug("ignoring model_dir, not implemented")
return whisper.load_model(modelsize, download_root=cache_dir)
def transcribe(self, audio, init_prompt=""):
result = self.transcribe_timestamped(
self.model,
audio,
language=self.original_language,
initial_prompt=init_prompt,
verbose=None,
condition_on_previous_text=True,
**self.transcribe_kargs,
)
return result
def ts_words(self, r):
# return: transcribe result object to [(beg,end,"word1"), ...]
o = []
for s in r["segments"]:
for w in s["words"]:
t = (w["start"], w["end"], w["text"])
o.append(t)
return o
def segments_end_ts(self, res):
return [s["end"] for s in res["segments"]]
def use_vad(self):
self.transcribe_kargs["vad"] = True
def set_translate_task(self):
self.transcribe_kargs["task"] = "translate"
class FasterWhisperASR(ASRBase):
"""Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version."""
sep = ""
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
from faster_whisper import WhisperModel
# logging.getLogger("faster_whisper").setLevel(logger.level)
if model_dir is not None:
logger.debug(
f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used."
)
model_size_or_path = model_dir
elif modelsize is not None:
model_size_or_path = modelsize
else:
raise ValueError("modelsize or model_dir parameter must be set")
# this worked fast and reliably on NVIDIA L40
model = WhisperModel(
model_size_or_path,
device="cuda",
compute_type="float16",
download_root=cache_dir,
)
# or run on GPU with INT8
# tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
# model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
# or run on CPU with INT8
# tested: works, but slow, appx 10-times than cuda FP16
# model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
return model
def transcribe(self, audio, init_prompt=""):
# tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
segments, info = self.model.transcribe(
audio,
language=self.original_language,
initial_prompt=init_prompt,
beam_size=5,
word_timestamps=True,
condition_on_previous_text=True,
**self.transcribe_kargs,
)
# print(info) # info contains language detection result
return list(segments)
def ts_words(self, segments):
o = []
for segment in segments:
for word in segment.words:
if segment.no_speech_prob > 0.9:
continue
# not stripping the spaces -- should not be merged with them!
w = word.word
t = (word.start, word.end, w)
o.append(t)
return o
def segments_end_ts(self, res):
return [s.end for s in res]
def use_vad(self):
self.transcribe_kargs["vad_filter"] = True
def set_translate_task(self):
self.transcribe_kargs["task"] = "translate"
class MLXWhisper(ASRBase):
"""
Uses MPX Whisper library as the backend, optimized for Apple Silicon.
Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
Significantly faster than faster-whisper (without CUDA) on Apple M1.
"""
sep = " "
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
"""
Loads the MLX-compatible Whisper model.
Args:
modelsize (str, optional): The size or name of the Whisper model to load.
If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
cache_dir (str, optional): Path to the directory for caching models.
**Note**: This is not supported by MLX Whisper and will be ignored.
model_dir (str, optional): Direct path to a custom model directory.
If specified, it overrides the `modelsize` parameter.
"""
from mlx_whisper.transcribe import ModelHolder, transcribe
import mlx.core as mx
if model_dir is not None:
logger.debug(
f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used."
)
model_size_or_path = model_dir
elif modelsize is not None:
model_size_or_path = self.translate_model_name(modelsize)
logger.debug(
f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used."
)
self.model_size_or_path = model_size_or_path
# In mlx_whisper.transcribe, dtype is defined as:
# dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
# Since we do not use decode_options in self.transcribe, we will set dtype to mx.float16
dtype = mx.float16
ModelHolder.get_model(model_size_or_path, dtype)
return transcribe
def translate_model_name(self, model_name):
"""
Translates a given model name to its corresponding MLX-compatible model path.
Args:
model_name (str): The name of the model to translate.
Returns:
str: The MLX-compatible model path.
"""
# Dictionary mapping model names to MLX-compatible paths
model_mapping = {
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
"tiny": "mlx-community/whisper-tiny-mlx",
"base.en": "mlx-community/whisper-base.en-mlx",
"base": "mlx-community/whisper-base-mlx",
"small.en": "mlx-community/whisper-small.en-mlx",
"small": "mlx-community/whisper-small-mlx",
"medium.en": "mlx-community/whisper-medium.en-mlx",
"medium": "mlx-community/whisper-medium-mlx",
"large-v1": "mlx-community/whisper-large-v1-mlx",
"large-v2": "mlx-community/whisper-large-v2-mlx",
"large-v3": "mlx-community/whisper-large-v3-mlx",
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
"large": "mlx-community/whisper-large-mlx",
}
# Retrieve the corresponding MLX model path
mlx_model_path = model_mapping.get(model_name)
if mlx_model_path:
return mlx_model_path
else:
raise ValueError(
f"Model name '{model_name}' is not recognized or not supported."
)
def transcribe(self, audio, init_prompt=""):
if self.transcribe_kargs:
logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
segments = self.model(
audio,
language=self.original_language,
initial_prompt=init_prompt,
word_timestamps=True,
condition_on_previous_text=True,
path_or_hf_repo=self.model_size_or_path,
)
return segments.get("segments", [])
def ts_words(self, segments):
"""
Extract timestamped words from transcription segments and skips words with high no-speech probability.
"""
return [
(word["start"], word["end"], word["word"])
for segment in segments
for word in segment.get("words", [])
if segment.get("no_speech_prob", 0) <= 0.9
]
def segments_end_ts(self, res):
return [s["end"] for s in res]
def use_vad(self):
self.transcribe_kargs["vad_filter"] = True
def set_translate_task(self):
self.transcribe_kargs["task"] = "translate"
class OpenaiApiASR(ASRBase):
"""Uses OpenAI's Whisper API for audio transcription."""
def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
self.logfile = logfile
self.modelname = "whisper-1"
self.original_language = (
None if lan == "auto" else lan
) # ISO-639-1 language code
self.response_format = "verbose_json"
self.temperature = temperature
self.load_model()
self.use_vad_opt = False
# reset the task in set_translate_task
self.task = "transcribe"
def load_model(self, *args, **kwargs):
from openai import OpenAI
self.client = OpenAI()
self.transcribed_seconds = (
0 # for logging how many seconds were processed by API, to know the cost
)
def ts_words(self, segments):
no_speech_segments = []
if self.use_vad_opt:
for segment in segments.segments:
# TODO: threshold can be set from outside
if segment["no_speech_prob"] > 0.8:
no_speech_segments.append(
(segment.get("start"), segment.get("end"))
)
o = []
for word in segments.words:
start = word.start
end = word.end
if any(s[0] <= start <= s[1] for s in no_speech_segments):
# print("Skipping word", word.get("word"), "because it's in a no-speech segment")
continue
o.append((start, end, word.word))
return o
def segments_end_ts(self, res):
return [s.end for s in res.words]
def transcribe(self, audio_data, prompt=None, *args, **kwargs):
# Write the audio data to a buffer
buffer = io.BytesIO()
buffer.name = "temp.wav"
sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
buffer.seek(0) # Reset buffer's position to the beginning
self.transcribed_seconds += math.ceil(
len(audio_data) / 16000
) # it rounds up to the whole seconds
params = {
"model": self.modelname,
"file": buffer,
"response_format": self.response_format,
"temperature": self.temperature,
"timestamp_granularities": ["word", "segment"],
}
if self.task != "translate" and self.original_language:
params["language"] = self.original_language
if prompt:
params["prompt"] = prompt
if self.task == "translate":
proc = self.client.audio.translations
else:
proc = self.client.audio.transcriptions
# Process transcription/translation
transcript = proc.create(**params)
logger.debug(
f"OpenAI API processed accumulated {self.transcribed_seconds} seconds"
)
return transcript
def use_vad(self):
self.use_vad_opt = True
def set_translate_task(self):
self.task = "translate"
class HypothesisBuffer:
def __init__(self, logfile=sys.stderr):
self.commited_in_buffer = []
self.buffer = []
self.new = []
self.last_commited_time = 0
self.last_commited_word = None
self.logfile = logfile
def insert(self, new, offset):
# compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
# the new tail is added to self.new
new = [(a + offset, b + offset, t) for a, b, t in new]
self.new = [(a, b, t) for a, b, t in new if a > self.last_commited_time - 0.1]
if len(self.new) >= 1:
a, b, t = self.new[0]
if abs(a - self.last_commited_time) < 1:
if self.commited_in_buffer:
# it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped.
cn = len(self.commited_in_buffer)
nn = len(self.new)
for i in range(1, min(min(cn, nn), 5) + 1): # 5 is the maximum
c = " ".join(
[self.commited_in_buffer[-j][2] for j in range(1, i + 1)][
::-1
]
)
tail = " ".join(self.new[j - 1][2] for j in range(1, i + 1))
if c == tail:
words = []
for j in range(i):
words.append(repr(self.new.pop(0)))
words_msg = " ".join(words)
logger.debug(f"removing last {i} words: {words_msg}")
break
def flush(self):
# returns commited chunk = the longest common prefix of 2 last inserts.
commit = []
while self.new:
na, nb, nt = self.new[0]
if len(self.buffer) == 0:
break
if nt == self.buffer[0][2]:
commit.append((na, nb, nt))
self.last_commited_word = nt
self.last_commited_time = nb
self.buffer.pop(0)
self.new.pop(0)
else:
break
self.buffer = self.new
self.new = []
self.commited_in_buffer.extend(commit)
return commit
def pop_commited(self, time):
while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time:
self.commited_in_buffer.pop(0)
def complete(self):
return self.buffer
class OnlineASRProcessor:
SAMPLING_RATE = 16000
def __init__(
self,
asr,
tokenize_method=None,
buffer_trimming=("segment", 15),
logfile=sys.stderr,
):
"""asr: WhisperASR object
tokenize_method: sentence tokenizer function for the target language. Must be a callable and behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all.
("segment", 15)
buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. Buffer is trimmed if it is longer than "seconds" threshold. Default is the most recommended option.
logfile: where to store the log.
"""
self.asr = asr
self.tokenize = tokenize_method
self.logfile = logfile
self.init()
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
def init(self, offset=None):
"""run this when starting or restarting processing"""
self.audio_buffer = np.array([], dtype=np.float32)
self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
self.buffer_time_offset = 0
if offset is not None:
self.buffer_time_offset = offset
self.transcript_buffer.last_commited_time = self.buffer_time_offset
self.commited = []
def insert_audio_chunk(self, audio):
self.audio_buffer = np.append(self.audio_buffer, audio)
def prompt(self):
"""Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
"context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
"""
k = max(0, len(self.commited) - 1)
while k > 0 and self.commited[k - 1][1] > self.buffer_time_offset:
k -= 1
p = self.commited[:k]
p = [t for _, _, t in p]
prompt = []
l = 0
while p and l < 200: # 200 characters prompt size
x = p.pop(-1)
l += len(x) + 1
prompt.append(x)
non_prompt = self.commited[k:]
return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(
t for _, _, t in non_prompt
)
def process_iter(self):
"""Runs on the current audio buffer.
Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
The non-emty text is confirmed (committed) partial transcript.
"""
prompt, non_prompt = self.prompt()
logger.debug(f"PROMPT: {prompt}")
logger.debug(f"CONTEXT: {non_prompt}")
logger.debug(
f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}"
)
res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
# transform to [(beg,end,"word1"), ...]
tsw = self.asr.ts_words(res)
self.transcript_buffer.insert(tsw, self.buffer_time_offset)
o = self.transcript_buffer.flush()
self.commited.extend(o)
completed = self.to_flush(o)
logger.debug(f">>>>COMPLETE NOW: {completed[2]}")
the_rest = self.to_flush(self.transcript_buffer.complete())
logger.debug(f"INCOMPLETE: {the_rest[2]}")
# there is a newly confirmed text
if o and self.buffer_trimming_way == "sentence": # trim the completed sentences
if (
len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec
): # longer than this
self.chunk_completed_sentence()
if self.buffer_trimming_way == "segment":
s = self.buffer_trimming_sec # trim the completed segments longer than s,
else:
s = 30 # if the audio buffer is longer than 30s, trim it
if len(self.audio_buffer) / self.SAMPLING_RATE > s:
self.chunk_completed_segment(res)
# alternative: on any word
# l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
# let's find commited word that is less
# k = len(self.commited)-1
# while k>0 and self.commited[k][1] > l:
# k -= 1
# t = self.commited[k][1]
logger.debug("chunking segment")
# self.chunk_at(t)
logger.debug(
f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}"
)
return self.to_flush(o)
def chunk_completed_sentence(self):
if self.commited == []:
return
logger.debug("COMPLETED SENTENCE: ", [s[2] for s in self.commited])
sents = self.words_to_sentences(self.commited)
for s in sents:
logger.debug(f"\t\tSENT: {s}")
if len(sents) < 2:
return
while len(sents) > 2:
sents.pop(0)
# we will continue with audio processing at this timestamp
chunk_at = sents[-2][1]
logger.debug(f"--- sentence chunked at {chunk_at:2.2f}")
self.chunk_at(chunk_at)
def chunk_completed_segment(self, res):
if self.commited == []:
return
ends = self.asr.segments_end_ts(res)
t = self.commited[-1][1]
if len(ends) > 1:
e = ends[-2] + self.buffer_time_offset
while len(ends) > 2 and e > t:
ends.pop(-1)
e = ends[-2] + self.buffer_time_offset
if e <= t:
logger.debug(f"--- segment chunked at {e:2.2f}")
self.chunk_at(e)
else:
logger.debug(f"--- last segment not within commited area")
else:
logger.debug(f"--- not enough segments to chunk")
def chunk_at(self, time):
"""trims the hypothesis and audio buffer at "time" """
self.transcript_buffer.pop_commited(time)
cut_seconds = time - self.buffer_time_offset
self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE) :]
self.buffer_time_offset = time
def words_to_sentences(self, words):
"""Uses self.tokenize for sentence segmentation of words.
Returns: [(beg,end,"sentence 1"),...]
"""
cwords = [w for w in words]
t = " ".join(o[2] for o in cwords)
s = self.tokenize(t)
out = []
while s:
beg = None
end = None
sent = s.pop(0).strip()
fsent = sent
while cwords:
b, e, w = cwords.pop(0)
w = w.strip()
if beg is None and sent.startswith(w):
beg = b
elif end is None and sent == w:
end = e
out.append((beg, end, fsent))
break
sent = sent[len(w) :].strip()
return out
def finish(self):
"""Flush the incomplete text when the whole processing ends.
Returns: the same format as self.process_iter()
"""
o = self.transcript_buffer.complete()
f = self.to_flush(o)
logger.debug(f"last, noncommited: {f}")
self.buffer_time_offset += len(self.audio_buffer) / 16000
return f
def to_flush(
self,
sents,
sep=None,
offset=0,
):
# concatenates the timestamped words or sentences into one sequence that is flushed in one line
# sents: [(beg1, end1, "sentence1"), ...] or [] if empty
# return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
if sep is None:
sep = self.asr.sep
t = sep.join(s[2] for s in sents)
if len(sents) == 0:
b = None
e = None
else:
b = offset + sents[0][0]
e = offset + sents[-1][1]
return (b, e, t)
class VACOnlineASRProcessor(OnlineASRProcessor):
"""Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
it runs VAD and continuously detects whether there is speech or not.
When it detects end of speech (non-voice for 500ms), it makes OnlineASRProcessor to end the utterance immediately.
"""
def __init__(self, online_chunk_size, *a, **kw):
self.online_chunk_size = online_chunk_size
self.online = OnlineASRProcessor(*a, **kw)
# VAC:
import torch
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
from silero_vad_iterator import FixedVADIterator
self.vac = FixedVADIterator(
model
) # we use the default options there: 500ms silence, 100ms padding, etc.
self.logfile = self.online.logfile
self.init()
def init(self):
self.online.init()
self.vac.reset_states()
self.current_online_chunk_buffer_size = 0
self.is_currently_final = False
self.status = None # or "voice" or "nonvoice"
self.audio_buffer = np.array([], dtype=np.float32)
self.buffer_offset = 0 # in frames
def clear_buffer(self):
self.buffer_offset += len(self.audio_buffer)
self.audio_buffer = np.array([], dtype=np.float32)
def insert_audio_chunk(self, audio):
res = self.vac(audio)
self.audio_buffer = np.append(self.audio_buffer, audio)
if res is not None:
frame = list(res.values())[0] - self.buffer_offset
if "start" in res and "end" not in res:
self.status = "voice"
send_audio = self.audio_buffer[frame:]
self.online.init(
offset=(frame + self.buffer_offset) / self.SAMPLING_RATE
)
self.online.insert_audio_chunk(send_audio)
self.current_online_chunk_buffer_size += len(send_audio)
self.clear_buffer()
elif "end" in res and "start" not in res:
self.status = "nonvoice"
send_audio = self.audio_buffer[:frame]
self.online.insert_audio_chunk(send_audio)
self.current_online_chunk_buffer_size += len(send_audio)
self.is_currently_final = True
self.clear_buffer()
else:
beg = res["start"] - self.buffer_offset
end = res["end"] - self.buffer_offset
self.status = "nonvoice"
send_audio = self.audio_buffer[beg:end]
self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
self.online.insert_audio_chunk(send_audio)
self.current_online_chunk_buffer_size += len(send_audio)
self.is_currently_final = True
self.clear_buffer()
else:
if self.status == "voice":
self.online.insert_audio_chunk(self.audio_buffer)
self.current_online_chunk_buffer_size += len(self.audio_buffer)
self.clear_buffer()
else:
# We keep 1 second because VAD may later find start of voice in it.
# But we trim it to prevent OOM.
self.buffer_offset += max(
0, len(self.audio_buffer) - self.SAMPLING_RATE
)
self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE :]
def process_iter(self):
if self.is_currently_final:
return self.finish()
elif (
self.current_online_chunk_buffer_size
> self.SAMPLING_RATE * self.online_chunk_size
):
self.current_online_chunk_buffer_size = 0
ret = self.online.process_iter()
return ret
else:
print("no online update, only VAD", self.status, file=self.logfile)
return (None, None, "")
def finish(self):
ret = self.online.finish()
self.current_online_chunk_buffer_size = 0
self.is_currently_final = False
return ret
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
","
)
@@ -852,7 +88,7 @@ def add_shared_args(parser):
parser.add_argument(
"--model",
type=str,
default="tiny",
default="large-v3-turbo",
choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
","
),
@@ -887,14 +123,14 @@ def add_shared_args(parser):
parser.add_argument(
"--backend",
type=str,
default="mlx-whisper",
default="faster-whisper",
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
help="Load only this backend for Whisper processing.",
)
parser.add_argument(
"--vac",
action="store_true",
default=True,
default=False,
help="Use VAC = voice activity controller. Recommended. Requires torch.",
)
parser.add_argument(
@@ -903,7 +139,7 @@ def add_shared_args(parser):
parser.add_argument(
"--vad",
action="store_true",
default=True,
default=False,
help="Use VAD = voice activity detection, with the default parameters.",
)
parser.add_argument(