diff --git a/.gitignore b/.gitignore index 6bb0fc4..a015198 100644 --- a/.gitignore +++ b/.gitignore @@ -54,21 +54,6 @@ coverage.xml # Translations *.mo *.pot -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ # PyBuilder target/ @@ -138,4 +123,5 @@ test_*.py launch.json .DS_Store test/* -nllb-200-distilled-600M-ctranslate2/* \ No newline at end of file +nllb-200-distilled-600M-ctranslate2/* +*.mp3 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b023646..4e8ddb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,8 +50,20 @@ Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit" whisperlivekit-server = "whisperlivekit.basic_server:main" [tool.setuptools] -packages = ["whisperlivekit", "whisperlivekit.diarization", "whisperlivekit.simul_whisper", "whisperlivekit.simul_whisper.whisper", "whisperlivekit.simul_whisper.whisper.assets", "whisperlivekit.simul_whisper.whisper.normalizers", "whisperlivekit.web", "whisperlivekit.whisper_streaming_custom", "whisperlivekit.translation"] +packages = [ + "whisperlivekit", + "whisperlivekit.diarization", + "whisperlivekit.simul_whisper", + "whisperlivekit.simul_whisper.whisper", + "whisperlivekit.simul_whisper.whisper.assets", + "whisperlivekit.simul_whisper.whisper.normalizers", + "whisperlivekit.web", + "whisperlivekit.whisper_streaming_custom", + "whisperlivekit.translation", + "whisperlivekit.vad_models" +] [tool.setuptools.package-data] whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"] "whisperlivekit.simul_whisper.whisper.assets" = ["*.tiktoken", "*.npz"] +"whisperlivekit.vad_models" = ["*.jit", "*.onnx"] diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index 4af585c..955a3aa 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -33,6 +33,7 @@ class TranscriptionEngine: "punctuation_split": False, "target_language": "", "vac": True, + "vac_onnx": False, "vac_chunk_size": 0.04, "log_level": "DEBUG", "ssl_certfile": None, @@ -75,8 +76,10 @@ class TranscriptionEngine: self.vac_model = None if self.args.vac: - import torch - self.vac_model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad") + from whisperlivekit.silero_vad_iterator import load_silero_vad + # Use ONNX if specified, otherwise use JIT (default) + use_onnx = kwargs.get('vac_onnx', False) + self.vac_model = load_silero_vad(onnx=use_onnx) if self.args.transcription: if self.args.backend == "simulstreaming": @@ -173,4 +176,4 @@ def online_translation_factory(args, translation_model): #one shared nllb model for all speaker #one tokenizer per speaker/language from whisperlivekit.translation.translation import OnlineTranslation - return OnlineTranslation(translation_model, [args.lan], [args.target_language]) \ No newline at end of file + return OnlineTranslation(translation_model, [args.lan], [args.target_language]) diff --git a/whisperlivekit/silero_vad_iterator.py b/whisperlivekit/silero_vad_iterator.py index 6f633da..ad2d2ba 100644 --- a/whisperlivekit/silero_vad_iterator.py +++ b/whisperlivekit/silero_vad_iterator.py @@ -1,27 +1,182 @@ import torch +import numpy as np +import warnings +from pathlib import Path -# This is copied from silero-vad's vad_utils.py: -# https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/utils_vad.py#L340 -# (except changed defaults) +""" +Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad +""" -# Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE +def init_jit_model(model_path: str, device=torch.device('cpu')): + """Load a JIT model from file.""" + model = torch.jit.load(model_path, map_location=device) + model.eval() + return model + + +class OnnxWrapper(): + """ONNX Runtime wrapper for Silero VAD model.""" + + def __init__(self, path, force_onnx_cpu=False): + global np + import numpy as np + import onnxruntime + + opts = onnxruntime.SessionOptions() + opts.inter_op_num_threads = 1 + opts.intra_op_num_threads = 1 + + if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers(): + self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts) + else: + self.session = onnxruntime.InferenceSession(path, sess_options=opts) + + self.reset_states() + if '16k' in path: + warnings.warn('This model support only 16000 sampling rate!') + self.sample_rates = [16000] + else: + self.sample_rates = [8000, 16000] + + def _validate_input(self, x, sr: int): + if x.dim() == 1: + x = x.unsqueeze(0) + if x.dim() > 2: + raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}") + + if sr != 16000 and (sr % 16000 == 0): + step = sr // 16000 + x = x[:,::step] + sr = 16000 + + if sr not in self.sample_rates: + raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)") + if sr / x.shape[1] > 31.25: + raise ValueError("Input audio chunk is too short") + + return x, sr + + def reset_states(self, batch_size=1): + self._state = torch.zeros((2, batch_size, 128)).float() + self._context = torch.zeros(0) + self._last_sr = 0 + self._last_batch_size = 0 + + def __call__(self, x, sr: int): + + x, sr = self._validate_input(x, sr) + num_samples = 512 if sr == 16000 else 256 + + if x.shape[-1] != num_samples: + raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)") + + batch_size = x.shape[0] + context_size = 64 if sr == 16000 else 32 + + if not self._last_batch_size: + self.reset_states(batch_size) + if (self._last_sr) and (self._last_sr != sr): + self.reset_states(batch_size) + if (self._last_batch_size) and (self._last_batch_size != batch_size): + self.reset_states(batch_size) + + if not len(self._context): + self._context = torch.zeros(batch_size, context_size) + + x = torch.cat([self._context, x], dim=1) + if sr in [8000, 16000]: + ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')} + ort_outs = self.session.run(None, ort_inputs) + out, state = ort_outs + self._state = torch.from_numpy(state) + else: + raise ValueError() + + self._context = x[..., -context_size:] + self._last_sr = sr + self._last_batch_size = batch_size + + out = torch.from_numpy(out) + return out + + +def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16): + """ + Load Silero VAD model (JIT or ONNX). + + Parameters + ---------- + model_path : str, optional + Path to model file. If None, uses default bundled model. + onnx : bool, default False + Whether to use ONNX runtime (requires onnxruntime package). + opset_version : int, default 16 + ONNX opset version (15 or 16). Only used if onnx=True. + + Returns + ------- + model + Loaded VAD model (JIT or ONNX wrapper) + """ + available_ops = [15, 16] + if onnx and opset_version not in available_ops: + raise Exception(f'Available ONNX opset_version: {available_ops}') + if model_path is None: + current_dir = Path(__file__).parent + data_dir = current_dir / 'vad_models' + + if onnx: + if opset_version == 16: + model_name = 'silero_vad.onnx' + else: + model_name = f'silero_vad_16k_op{opset_version}.onnx' + else: + model_name = 'silero_vad.jit' + + model_path = data_dir / model_name + + if not model_path.exists(): + raise FileNotFoundError( + f"Model file not found: {model_path}\n" + f"Please ensure the whisperlivekit/vad_models/ directory contains the model files." + ) + else: + model_path = Path(model_path) + if onnx: + try: + model = OnnxWrapper(str(model_path), force_onnx_cpu=True) + except ImportError: + raise ImportError( + "ONNX runtime not available. Install with: pip install onnxruntime\n" + "Or use JIT model by setting onnx=False" + ) + else: + model = init_jit_model(str(model_path)) + + return model class VADIterator: - def __init__( - self, - model, - threshold: float = 0.5, - sampling_rate: int = 16000, - min_silence_duration_ms: int = 500, # makes sense on one recording that I checked - speech_pad_ms: int = 100, # same - ): + """ + Voice Activity Detection iterator for streaming audio. + + This is the Silero VAD v6 implementation. + """ + + def __init__(self, + model, + threshold: float = 0.5, + sampling_rate: int = 16000, + min_silence_duration_ms: int = 100, + speech_pad_ms: int = 30 + ): + """ Class for stream imitation Parameters ---------- - model: preloaded .jit silero VAD model + model: preloaded .jit/.onnx silero VAD model threshold: float (default - 0.5) Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH. @@ -42,9 +197,7 @@ class VADIterator: self.sampling_rate = sampling_rate if sampling_rate not in [8000, 16000]: - raise ValueError( - "VADIterator does not support sampling rates other than [8000, 16000]" - ) + raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]') self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000 @@ -57,13 +210,17 @@ class VADIterator: self.temp_end = 0 self.current_sample = 0 - def __call__(self, x, return_seconds=False): + @torch.no_grad() + def __call__(self, x, return_seconds=False, time_resolution: int = 1): """ x: torch.Tensor audio chunk (see examples in repo) return_seconds: bool (default - False) whether return timestamps in seconds (default - samples) + + time_resolution: int (default - 1) + time resolution of speech coordinates when requested as seconds """ if not torch.is_tensor(x): @@ -82,14 +239,8 @@ class VADIterator: if (speech_prob >= self.threshold) and not self.triggered: self.triggered = True - speech_start = self.current_sample - self.speech_pad_samples - return { - "start": ( - int(speech_start) - if not return_seconds - else round(speech_start / self.sampling_rate, 1) - ) - } + speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples) + return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)} if (speech_prob < self.threshold - 0.15) and self.triggered: if not self.temp_end: @@ -97,30 +248,17 @@ class VADIterator: if self.current_sample - self.temp_end < self.min_silence_samples: return None else: - speech_end = self.temp_end + self.speech_pad_samples + speech_end = self.temp_end + self.speech_pad_samples - window_size_samples self.temp_end = 0 self.triggered = False - return { - "end": ( - int(speech_end) - if not return_seconds - else round(speech_end / self.sampling_rate, 1) - ) - } + return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)} return None -####################### -# because Silero now requires exactly 512-sized audio chunks - -import numpy as np - - class FixedVADIterator(VADIterator): - """It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once. - If audio to be processed at once is long and multiple voiced segments detected, - then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment. + """ + Fixed VAD Iterator that handles variable-length audio chunks, not only exactly 512 frames at once. """ def reset_states(self): @@ -137,27 +275,20 @@ class FixedVADIterator(VADIterator): ret = r elif r is not None: if "end" in r: - ret["end"] = r["end"] # the latter end - if "start" in r and "end" in ret: # there is an earlier start. - # Remove end, merging this segment with the previous one. + ret["end"] = r["end"] + if "start" in r and "end" in ret: del ret["end"] return ret if ret != {} else None if __name__ == "__main__": - # test/demonstrate the need for FixedVADIterator: - - import torch - - model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad") - vac = FixedVADIterator(model) - # vac = VADIterator(model) # the second case crashes with this - - # this works: for both - audio_buffer = np.array([0] * (512), dtype=np.float32) - vac(audio_buffer) - - # this crashes on the non FixedVADIterator with - # ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError") - audio_buffer = np.array([0] * (512 - 1), dtype=np.float32) - vac(audio_buffer) + model = load_silero_vad(onnx=False) + vad = FixedVADIterator(model) + + audio_buffer = np.array([0] * 512, dtype=np.float32) + result = vad(audio_buffer) + print(f" 512 samples: {result}") + + # test with 511 samples + audio_buffer = np.array([0] * 511, dtype=np.float32) + result = vad(audio_buffer) \ No newline at end of file diff --git a/whisperlivekit/vad_models/__init__.py b/whisperlivekit/vad_models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/whisperlivekit/vad_models/silero_vad.jit b/whisperlivekit/vad_models/silero_vad.jit new file mode 100644 index 0000000..30868c9 Binary files /dev/null and b/whisperlivekit/vad_models/silero_vad.jit differ diff --git a/whisperlivekit/vad_models/silero_vad.onnx b/whisperlivekit/vad_models/silero_vad.onnx new file mode 100644 index 0000000..cb60519 Binary files /dev/null and b/whisperlivekit/vad_models/silero_vad.onnx differ diff --git a/whisperlivekit/vad_models/silero_vad_16k_op15.onnx b/whisperlivekit/vad_models/silero_vad_16k_op15.onnx new file mode 100644 index 0000000..3036c9e Binary files /dev/null and b/whisperlivekit/vad_models/silero_vad_16k_op15.onnx differ diff --git a/whisperlivekit/vad_models/silero_vad_half.onnx b/whisperlivekit/vad_models/silero_vad_half.onnx new file mode 100644 index 0000000..97e39fb Binary files /dev/null and b/whisperlivekit/vad_models/silero_vad_half.onnx differ