diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index 1962bab..e5ca35c 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -10,7 +10,7 @@ from whisperlivekit.core import (TranscriptionEngine, online_diarization_factory, online_factory, online_translation_factory) from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState -from whisperlivekit.silero_vad_iterator import FixedVADIterator +from whisperlivekit.silero_vad_iterator import FixedVADIterator, OnnxWrapper, load_jit_vad from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData, Segment, Silence, State, Transcript) from whisperlivekit.tokens_alignment import TokensAlignment @@ -85,12 +85,14 @@ class AudioProcessor: # Models and processing self.asr: Any = models.asr - self.vac_model: Any = models.vac_model + self.vac: Optional[FixedVADIterator] = None + if self.args.vac: - self.vac: Optional[FixedVADIterator] = FixedVADIterator(models.vac_model) - else: - self.vac: Optional[FixedVADIterator] = None - + if models.vac_session is not None: + vac_model = OnnxWrapper(session=models.vac_session) + self.vac = FixedVADIterator(vac_model) + else: + self.vac = FixedVADIterator(load_jit_vad()) self.ffmpeg_manager: Optional[FFmpegManager] = None self.ffmpeg_reader_task: Optional[asyncio.Task] = None self._ffmpeg_error: Optional[str] = None diff --git a/whisperlivekit/core.py b/whisperlivekit/core.py index d510f63..71a667e 100644 --- a/whisperlivekit/core.py +++ b/whisperlivekit/core.py @@ -36,7 +36,6 @@ class TranscriptionEngine: "punctuation_split": False, "target_language": "", "vac": True, - "vac_onnx": False, "vac_chunk_size": 0.04, "log_level": "DEBUG", "ssl_certfile": None, @@ -79,15 +78,19 @@ class TranscriptionEngine: self.asr = None self.tokenizer = None self.diarization = None - self.vac_model = None + self.vac_session = None if self.args.vac: - from whisperlivekit.silero_vad_iterator import load_silero_vad - - # Use ONNX if specified, otherwise use JIT (default) - use_onnx = kwargs.get('vac_onnx', False) - self.vac_model = load_silero_vad(onnx=use_onnx) - + from whisperlivekit.silero_vad_iterator import is_onnx_available + + if is_onnx_available(): + from whisperlivekit.silero_vad_iterator import load_onnx_session + self.vac_session = load_onnx_session() + else: + logger.warning( + "onnxruntime not installed. VAC will use JIT model which is loaded per-session. " + "For multi-user scenarios, install onnxruntime: pip install onnxruntime" + ) backend_policy = self.args.backend_policy if self.args.transcription: if backend_policy == "simulstreaming": diff --git a/whisperlivekit/silero_vad_iterator.py b/whisperlivekit/silero_vad_iterator.py index 45ca63c..9ece6a2 100644 --- a/whisperlivekit/silero_vad_iterator.py +++ b/whisperlivekit/silero_vad_iterator.py @@ -8,6 +8,15 @@ import torch Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad """ +def is_onnx_available() -> bool: + """Check if onnxruntime is installed.""" + try: + import onnxruntime + return True + except ImportError: + return False + + def init_jit_model(model_path: str, device=torch.device('cpu')): """Load a JIT model from file.""" model = torch.jit.load(model_path, map_location=device) @@ -15,12 +24,12 @@ def init_jit_model(model_path: str, device=torch.device('cpu')): return model -class OnnxWrapper(): - """ONNX Runtime wrapper for Silero VAD model.""" +class OnnxSession(): + """ + Shared ONNX session for Silero VAD model (stateless). + """ def __init__(self, path, force_onnx_cpu=False): - global np - import numpy as np import onnxruntime opts = onnxruntime.SessionOptions() @@ -32,13 +41,28 @@ class OnnxWrapper(): else: self.session = onnxruntime.InferenceSession(path, sess_options=opts) - self.reset_states() + self.path = path if '16k' in path: warnings.warn('This model support only 16000 sampling rate!') self.sample_rates = [16000] else: self.sample_rates = [8000, 16000] + +class OnnxWrapper(): + """ + ONNX Runtime wrapper for Silero VAD model with per-instance state. + """ + + def __init__(self, session: OnnxSession = None, force_onnx_cpu=False): + self._shared_session = session + self.sample_rates = session.sample_rates + self.reset_states() + + @property + def session(self): + return self._shared_session.session + def _validate_input(self, x, sr: int): if x.dim() == 1: x = x.unsqueeze(0) @@ -101,38 +125,20 @@ class OnnxWrapper(): return out -def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16): - """ - Load Silero VAD model (JIT or ONNX). - - Parameters - ---------- - model_path : str, optional - Path to model file. If None, uses default bundled model. - onnx : bool, default False - Whether to use ONNX runtime (requires onnxruntime package). - opset_version : int, default 16 - ONNX opset version (15 or 16). Only used if onnx=True. - - Returns - ------- - model - Loaded VAD model (JIT or ONNX wrapper) - """ +def _get_onnx_model_path(model_path: str = None, opset_version: int = 16) -> Path: + """Get the path to the ONNX model file.""" available_ops = [15, 16] - if onnx and opset_version not in available_ops: + if opset_version not in available_ops: raise Exception(f'Available ONNX opset_version: {available_ops}') + if model_path is None: current_dir = Path(__file__).parent data_dir = current_dir / 'silero_vad_models' - if onnx: - if opset_version == 16: - model_name = 'silero_vad.onnx' - else: - model_name = f'silero_vad_16k_op{opset_version}.onnx' + if opset_version == 16: + model_name = 'silero_vad.onnx' else: - model_name = 'silero_vad.jit' + model_name = f'silero_vad_16k_op{opset_version}.onnx' model_path = data_dir / model_name @@ -143,17 +149,39 @@ def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: i ) else: model_path = Path(model_path) - if onnx: - try: - model = OnnxWrapper(str(model_path), force_onnx_cpu=True) - except ImportError: - raise ImportError( - "ONNX runtime not available. Install with: pip install onnxruntime\n" - "Or use JIT model by setting onnx=False" + + return model_path + + +def load_onnx_session(model_path: str = None, opset_version: int = 16, force_onnx_cpu: bool = True) -> OnnxSession: + """ + Load a shared ONNX session for Silero VAD. + """ + path = _get_onnx_model_path(model_path, opset_version) + return OnnxSession(str(path), force_onnx_cpu=force_onnx_cpu) + + +def load_jit_vad(model_path: str = None): + """ + Load Silero VAD model in JIT format. + """ + if model_path is None: + current_dir = Path(__file__).parent + data_dir = current_dir / 'silero_vad_models' + model_name = 'silero_vad.jit' + + model_path = data_dir / model_name + + if not model_path.exists(): + raise FileNotFoundError( + f"Model file not found: {model_path}\n" + f"Please ensure the whisperlivekit/silero_vad_models/ directory contains the model files." ) else: - model = init_jit_model(str(model_path)) + model_path = Path(model_path) + model = init_jit_model(str(model_path)) + return model @@ -285,7 +313,9 @@ class FixedVADIterator(VADIterator): if __name__ == "__main__": - model = load_silero_vad(onnx=False) + # Test JIT model + print("Testing JIT model...") + model = load_jit_vad() vad = FixedVADIterator(model) audio_buffer = np.array([0] * 512, dtype=np.float32) @@ -294,4 +324,25 @@ if __name__ == "__main__": # test with 511 samples audio_buffer = np.array([0] * 511, dtype=np.float32) - result = vad(audio_buffer) \ No newline at end of file + result = vad(audio_buffer) + print(f" 511 samples: {result}") + + # Test ONNX with shared session + print("\nTesting ONNX with shared session...") + shared_session = load_onnx_session() + + # Create two independent VAD iterators sharing the same session + vad1 = FixedVADIterator(OnnxWrapper(session=shared_session)) + vad2 = FixedVADIterator(OnnxWrapper(session=shared_session)) + + # Both should work independently + audio_buffer = np.array([0] * 512, dtype=np.float32) + result1 = vad1(audio_buffer) + result2 = vad2(audio_buffer) + print(f" VAD1 result: {result1}") + print(f" VAD2 result: {result2}") + + # Verify they have separate states + print(f" VAD1 and VAD2 share session: {vad1.model._shared_session is vad2.model._shared_session}") + print(f" VAD1 and VAD2 have separate state: {vad1.model._state is not vad2.model._state}") + print("\nAll tests passed!") \ No newline at end of file