session parameter required in OnnxWrapper

This commit is contained in:
Quentin Fuxa
2025-12-05 15:37:18 +01:00
parent 2431a6bf91
commit 62444ce746

View File

@@ -54,7 +54,7 @@ class OnnxWrapper():
ONNX Runtime wrapper for Silero VAD model with per-instance state.
"""
def __init__(self, session: OnnxSession = None, force_onnx_cpu=False):
def __init__(self, session: OnnxSession, force_onnx_cpu=False):
self._shared_session = session
self.sample_rates = session.sample_rates
self.reset_states()
@@ -313,11 +313,9 @@ class FixedVADIterator(VADIterator):
if __name__ == "__main__":
# Test JIT model
print("Testing JIT model...")
model = load_jit_vad()
vad = FixedVADIterator(model)
# vad = FixedVADIterator(load_jit_vad())
vad = FixedVADIterator(OnnxWrapper(session=load_onnx_session()))
audio_buffer = np.array([0] * 512, dtype=np.float32)
result = vad(audio_buffer)
print(f" 512 samples: {result}")
@@ -325,24 +323,4 @@ if __name__ == "__main__":
# test with 511 samples
audio_buffer = np.array([0] * 511, dtype=np.float32)
result = vad(audio_buffer)
print(f" 511 samples: {result}")
# Test ONNX with shared session
print("\nTesting ONNX with shared session...")
shared_session = load_onnx_session()
# Create two independent VAD iterators sharing the same session
vad1 = FixedVADIterator(OnnxWrapper(session=shared_session))
vad2 = FixedVADIterator(OnnxWrapper(session=shared_session))
# Both should work independently
audio_buffer = np.array([0] * 512, dtype=np.float32)
result1 = vad1(audio_buffer)
result2 = vad2(audio_buffer)
print(f" VAD1 result: {result1}")
print(f" VAD2 result: {result2}")
# Verify they have separate states
print(f" VAD1 and VAD2 share session: {vad1.model._shared_session is vad2.model._shared_session}")
print(f" VAD1 and VAD2 have separate state: {vad1.model._state is not vad2.model._state}")
print("\nAll tests passed!")
print(f" 511 samples: {result}")