mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-08 06:44:09 +00:00
migration to silero vad v6: supports onnx
This commit is contained in:
18
.gitignore
vendored
18
.gitignore
vendored
@@ -54,21 +54,6 @@ coverage.xml
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
@@ -138,4 +123,5 @@ test_*.py
|
||||
launch.json
|
||||
.DS_Store
|
||||
test/*
|
||||
nllb-200-distilled-600M-ctranslate2/*
|
||||
nllb-200-distilled-600M-ctranslate2/*
|
||||
*.mp3
|
||||
@@ -50,8 +50,20 @@ Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"
|
||||
whisperlivekit-server = "whisperlivekit.basic_server:main"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["whisperlivekit", "whisperlivekit.diarization", "whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom", "whisperlivekit.translation"]
|
||||
packages = [
|
||||
"whisperlivekit",
|
||||
"whisperlivekit.diarization",
|
||||
"whisperlivekit.simul_whisper",
|
||||
"whisperlivekit.simul_whisper.whisper",
|
||||
"whisperlivekit.simul_whisper.whisper.assets",
|
||||
"whisperlivekit.simul_whisper.whisper.normalizers",
|
||||
"whisperlivekit.web",
|
||||
"whisperlivekit.whisper_streaming_custom",
|
||||
"whisperlivekit.translation",
|
||||
"whisperlivekit.vad_models"
|
||||
]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
|
||||
"whisperlivekit.simul_whisper.whisper.assets" = ["*.tiktoken", "*.npz"]
|
||||
"whisperlivekit.vad_models" = ["*.jit", "*.onnx"]
|
||||
|
||||
@@ -33,6 +33,7 @@ class TranscriptionEngine:
|
||||
"punctuation_split": False,
|
||||
"target_language": "",
|
||||
"vac": True,
|
||||
"vac_onnx": False,
|
||||
"vac_chunk_size": 0.04,
|
||||
"log_level": "DEBUG",
|
||||
"ssl_certfile": None,
|
||||
@@ -75,8 +76,10 @@ class TranscriptionEngine:
|
||||
self.vac_model = None
|
||||
|
||||
if self.args.vac:
|
||||
import torch
|
||||
self.vac_model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
||||
from whisperlivekit.silero_vad_iterator import load_silero_vad
|
||||
# Use ONNX if specified, otherwise use JIT (default)
|
||||
use_onnx = kwargs.get('vac_onnx', False)
|
||||
self.vac_model = load_silero_vad(onnx=use_onnx)
|
||||
|
||||
if self.args.transcription:
|
||||
if self.args.backend == "simulstreaming":
|
||||
@@ -173,4 +176,4 @@ def online_translation_factory(args, translation_model):
|
||||
#one shared nllb model for all speaker
|
||||
#one tokenizer per speaker/language
|
||||
from whisperlivekit.translation.translation import OnlineTranslation
|
||||
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
||||
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
||||
|
||||
@@ -1,27 +1,182 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
# This is copied from silero-vad's vad_utils.py:
|
||||
# https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/utils_vad.py#L340
|
||||
# (except changed defaults)
|
||||
"""
|
||||
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
||||
"""
|
||||
|
||||
# Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
|
||||
def init_jit_model(model_path: str, device=torch.device('cpu')):
|
||||
"""Load a JIT model from file."""
|
||||
model = torch.jit.load(model_path, map_location=device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class OnnxWrapper():
|
||||
"""ONNX Runtime wrapper for Silero VAD model."""
|
||||
|
||||
def __init__(self, path, force_onnx_cpu=False):
|
||||
global np
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
|
||||
opts = onnxruntime.SessionOptions()
|
||||
opts.inter_op_num_threads = 1
|
||||
opts.intra_op_num_threads = 1
|
||||
|
||||
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
|
||||
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
|
||||
else:
|
||||
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
||||
|
||||
self.reset_states()
|
||||
if '16k' in path:
|
||||
warnings.warn('This model support only 16000 sampling rate!')
|
||||
self.sample_rates = [16000]
|
||||
else:
|
||||
self.sample_rates = [8000, 16000]
|
||||
|
||||
def _validate_input(self, x, sr: int):
|
||||
if x.dim() == 1:
|
||||
x = x.unsqueeze(0)
|
||||
if x.dim() > 2:
|
||||
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
|
||||
|
||||
if sr != 16000 and (sr % 16000 == 0):
|
||||
step = sr // 16000
|
||||
x = x[:,::step]
|
||||
sr = 16000
|
||||
|
||||
if sr not in self.sample_rates:
|
||||
raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
|
||||
if sr / x.shape[1] > 31.25:
|
||||
raise ValueError("Input audio chunk is too short")
|
||||
|
||||
return x, sr
|
||||
|
||||
def reset_states(self, batch_size=1):
|
||||
self._state = torch.zeros((2, batch_size, 128)).float()
|
||||
self._context = torch.zeros(0)
|
||||
self._last_sr = 0
|
||||
self._last_batch_size = 0
|
||||
|
||||
def __call__(self, x, sr: int):
|
||||
|
||||
x, sr = self._validate_input(x, sr)
|
||||
num_samples = 512 if sr == 16000 else 256
|
||||
|
||||
if x.shape[-1] != num_samples:
|
||||
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
|
||||
|
||||
batch_size = x.shape[0]
|
||||
context_size = 64 if sr == 16000 else 32
|
||||
|
||||
if not self._last_batch_size:
|
||||
self.reset_states(batch_size)
|
||||
if (self._last_sr) and (self._last_sr != sr):
|
||||
self.reset_states(batch_size)
|
||||
if (self._last_batch_size) and (self._last_batch_size != batch_size):
|
||||
self.reset_states(batch_size)
|
||||
|
||||
if not len(self._context):
|
||||
self._context = torch.zeros(batch_size, context_size)
|
||||
|
||||
x = torch.cat([self._context, x], dim=1)
|
||||
if sr in [8000, 16000]:
|
||||
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
|
||||
ort_outs = self.session.run(None, ort_inputs)
|
||||
out, state = ort_outs
|
||||
self._state = torch.from_numpy(state)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
self._context = x[..., -context_size:]
|
||||
self._last_sr = sr
|
||||
self._last_batch_size = batch_size
|
||||
|
||||
out = torch.from_numpy(out)
|
||||
return out
|
||||
|
||||
|
||||
def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16):
|
||||
"""
|
||||
Load Silero VAD model (JIT or ONNX).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_path : str, optional
|
||||
Path to model file. If None, uses default bundled model.
|
||||
onnx : bool, default False
|
||||
Whether to use ONNX runtime (requires onnxruntime package).
|
||||
opset_version : int, default 16
|
||||
ONNX opset version (15 or 16). Only used if onnx=True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
model
|
||||
Loaded VAD model (JIT or ONNX wrapper)
|
||||
"""
|
||||
available_ops = [15, 16]
|
||||
if onnx and opset_version not in available_ops:
|
||||
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
||||
if model_path is None:
|
||||
current_dir = Path(__file__).parent
|
||||
data_dir = current_dir / 'vad_models'
|
||||
|
||||
if onnx:
|
||||
if opset_version == 16:
|
||||
model_name = 'silero_vad.onnx'
|
||||
else:
|
||||
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
||||
else:
|
||||
model_name = 'silero_vad.jit'
|
||||
|
||||
model_path = data_dir / model_name
|
||||
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Model file not found: {model_path}\n"
|
||||
f"Please ensure the whisperlivekit/vad_models/ directory contains the model files."
|
||||
)
|
||||
else:
|
||||
model_path = Path(model_path)
|
||||
if onnx:
|
||||
try:
|
||||
model = OnnxWrapper(str(model_path), force_onnx_cpu=True)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"ONNX runtime not available. Install with: pip install onnxruntime\n"
|
||||
"Or use JIT model by setting onnx=False"
|
||||
)
|
||||
else:
|
||||
model = init_jit_model(str(model_path))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class VADIterator:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
threshold: float = 0.5,
|
||||
sampling_rate: int = 16000,
|
||||
min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
|
||||
speech_pad_ms: int = 100, # same
|
||||
):
|
||||
"""
|
||||
Voice Activity Detection iterator for streaming audio.
|
||||
|
||||
This is the Silero VAD v6 implementation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
threshold: float = 0.5,
|
||||
sampling_rate: int = 16000,
|
||||
min_silence_duration_ms: int = 100,
|
||||
speech_pad_ms: int = 30
|
||||
):
|
||||
|
||||
"""
|
||||
Class for stream imitation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: preloaded .jit silero VAD model
|
||||
model: preloaded .jit/.onnx silero VAD model
|
||||
|
||||
threshold: float (default - 0.5)
|
||||
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
|
||||
@@ -42,9 +197,7 @@ class VADIterator:
|
||||
self.sampling_rate = sampling_rate
|
||||
|
||||
if sampling_rate not in [8000, 16000]:
|
||||
raise ValueError(
|
||||
"VADIterator does not support sampling rates other than [8000, 16000]"
|
||||
)
|
||||
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
|
||||
|
||||
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
||||
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
||||
@@ -57,13 +210,17 @@ class VADIterator:
|
||||
self.temp_end = 0
|
||||
self.current_sample = 0
|
||||
|
||||
def __call__(self, x, return_seconds=False):
|
||||
@torch.no_grad()
|
||||
def __call__(self, x, return_seconds=False, time_resolution: int = 1):
|
||||
"""
|
||||
x: torch.Tensor
|
||||
audio chunk (see examples in repo)
|
||||
|
||||
return_seconds: bool (default - False)
|
||||
whether return timestamps in seconds (default - samples)
|
||||
|
||||
time_resolution: int (default - 1)
|
||||
time resolution of speech coordinates when requested as seconds
|
||||
"""
|
||||
|
||||
if not torch.is_tensor(x):
|
||||
@@ -82,14 +239,8 @@ class VADIterator:
|
||||
|
||||
if (speech_prob >= self.threshold) and not self.triggered:
|
||||
self.triggered = True
|
||||
speech_start = self.current_sample - self.speech_pad_samples
|
||||
return {
|
||||
"start": (
|
||||
int(speech_start)
|
||||
if not return_seconds
|
||||
else round(speech_start / self.sampling_rate, 1)
|
||||
)
|
||||
}
|
||||
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
|
||||
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}
|
||||
|
||||
if (speech_prob < self.threshold - 0.15) and self.triggered:
|
||||
if not self.temp_end:
|
||||
@@ -97,30 +248,17 @@ class VADIterator:
|
||||
if self.current_sample - self.temp_end < self.min_silence_samples:
|
||||
return None
|
||||
else:
|
||||
speech_end = self.temp_end + self.speech_pad_samples
|
||||
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
|
||||
self.temp_end = 0
|
||||
self.triggered = False
|
||||
return {
|
||||
"end": (
|
||||
int(speech_end)
|
||||
if not return_seconds
|
||||
else round(speech_end / self.sampling_rate, 1)
|
||||
)
|
||||
}
|
||||
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
#######################
|
||||
# because Silero now requires exactly 512-sized audio chunks
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class FixedVADIterator(VADIterator):
|
||||
"""It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
|
||||
If audio to be processed at once is long and multiple voiced segments detected,
|
||||
then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
|
||||
"""
|
||||
Fixed VAD Iterator that handles variable-length audio chunks, not only exactly 512 frames at once.
|
||||
"""
|
||||
|
||||
def reset_states(self):
|
||||
@@ -137,27 +275,20 @@ class FixedVADIterator(VADIterator):
|
||||
ret = r
|
||||
elif r is not None:
|
||||
if "end" in r:
|
||||
ret["end"] = r["end"] # the latter end
|
||||
if "start" in r and "end" in ret: # there is an earlier start.
|
||||
# Remove end, merging this segment with the previous one.
|
||||
ret["end"] = r["end"]
|
||||
if "start" in r and "end" in ret:
|
||||
del ret["end"]
|
||||
return ret if ret != {} else None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test/demonstrate the need for FixedVADIterator:
|
||||
|
||||
import torch
|
||||
|
||||
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
||||
vac = FixedVADIterator(model)
|
||||
# vac = VADIterator(model) # the second case crashes with this
|
||||
|
||||
# this works: for both
|
||||
audio_buffer = np.array([0] * (512), dtype=np.float32)
|
||||
vac(audio_buffer)
|
||||
|
||||
# this crashes on the non FixedVADIterator with
|
||||
# ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
|
||||
audio_buffer = np.array([0] * (512 - 1), dtype=np.float32)
|
||||
vac(audio_buffer)
|
||||
model = load_silero_vad(onnx=False)
|
||||
vad = FixedVADIterator(model)
|
||||
|
||||
audio_buffer = np.array([0] * 512, dtype=np.float32)
|
||||
result = vad(audio_buffer)
|
||||
print(f" 512 samples: {result}")
|
||||
|
||||
# test with 511 samples
|
||||
audio_buffer = np.array([0] * 511, dtype=np.float32)
|
||||
result = vad(audio_buffer)
|
||||
0
whisperlivekit/vad_models/__init__.py
Normal file
0
whisperlivekit/vad_models/__init__.py
Normal file
BIN
whisperlivekit/vad_models/silero_vad.jit
Normal file
BIN
whisperlivekit/vad_models/silero_vad.jit
Normal file
Binary file not shown.
BIN
whisperlivekit/vad_models/silero_vad.onnx
Normal file
BIN
whisperlivekit/vad_models/silero_vad.onnx
Normal file
Binary file not shown.
BIN
whisperlivekit/vad_models/silero_vad_16k_op15.onnx
Normal file
BIN
whisperlivekit/vad_models/silero_vad_16k_op15.onnx
Normal file
Binary file not shown.
BIN
whisperlivekit/vad_models/silero_vad_half.onnx
Normal file
BIN
whisperlivekit/vad_models/silero_vad_half.onnx
Normal file
Binary file not shown.
Reference in New Issue
Block a user