From 63870987c0e0505ab0c594d56d1a6cce439030bc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dominik=20Mach=C3=A1=C4=8Dek?=
Date: Fri, 4 Oct 2024 17:14:55 +0200
Subject: [PATCH] FixedSileroVADIterator to support other than 512-sized chunks
 with v5 issue #116

---
 silero_vad.py     | 68 +++++++++++++++++++++++++++++++++++++++++++++++
 whisper_online.py |  2 +-
 2 files changed, 69 insertions(+), 1 deletion(-)

diff --git a/silero_vad.py b/silero_vad.py
index 7735215..9e79e9d 100644
--- a/silero_vad.py
+++ b/silero_vad.py
@@ -94,4 +94,72 @@ class VADIterator:
 
         return None
 
+#######################
+# This is our workaround for Silero v5, which only accepts fixed-size audio
+# windows (512 samples at 16 kHz), while callers may send chunks of any length.
+# (see https://github.com/ufal/whisper_streaming/issues/116 )
+import numpy as np
+class FixedVADIterator(VADIterator):
+    """VADIterator that accepts audio chunks of any length.
+
+    Incoming samples are buffered and forwarded to VADIterator in windows of
+    exactly WINDOW_SAMPLES samples; samples that do not fill a whole window
+    stay buffered until the next call.
+    """
+
+    WINDOW_SAMPLES = 512  # number of samples handed to VADIterator per step
+
+    def reset_states(self):
+        """Reset the inherited state and drop any buffered samples."""
+        super().reset_states()
+        self.buffer = np.array([],dtype=np.float32)
+
+    def __call__(self, x, return_seconds=False):
+        """Buffer x and run the VAD on every complete window.
+
+        x: audio samples of any length (np.append flattens them onto the buffer).
+        return_seconds: passed through to VADIterator.__call__.
+
+        Returns None when no complete window was available or nothing was
+        reported. Otherwise returns a dict holding the 'start' of the first
+        event and/or the 'end' of the last one reported during this call; an
+        'end' that is followed by a new 'start' is dropped (segments merged).
+        """
+        self.buffer = np.append(self.buffer, x)
+        ret = None
+        while len(self.buffer) >= self.WINDOW_SAMPLES:
+            window = self.buffer[:self.WINDOW_SAMPLES]
+            self.buffer = self.buffer[self.WINDOW_SAMPLES:]
+            r = super().__call__(window, return_seconds=return_seconds)
+            if r is None:
+                continue
+            if ret is None:
+                ret = r
+                continue
+            if 'end' in r:
+                ret['end'] = r['end']  # keep the latest end
+            if 'start' in r and 'end' in ret:
+                # speech resumed after an earlier end: merge the two segments
+                del ret['end']
+        return ret if ret else None
+
+if __name__ == "__main__":
+    # test/demonstrate the need for FixedVADIterator:
+
+    import torch
+    model, _ = torch.hub.load(
+        repo_or_dir='snakers4/silero-vad',
+        model='silero_vad'
+    )
+    vac = FixedVADIterator(model)
+#    vac = VADIterator(model) # the second case crashes with this
+
+    # this works: for both
+    audio_buffer = np.array([0]*(512),dtype=np.float32)
+    vac(audio_buffer)
+
+    # this crashes on the non FixedVADIterator with
+    # ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
+    audio_buffer = np.array([0]*(512-1),dtype=np.float32)
+    vac(audio_buffer)
 
diff --git a/whisper_online.py b/whisper_online.py
index 4ed7804..d3e4a5c 100644
--- a/whisper_online.py
+++ b/whisper_online.py
@@ -531,7 +531,7 @@ class VACOnlineASRProcessor(OnlineASRProcessor):
         # VAC:
         import torch
         model, _ = torch.hub.load(
-
repo_or_dir='snakers4/silero-vad:v4.0', + repo_or_dir='snakers4/silero-vad', model='silero_vad' ) from silero_vad import VADIterator