mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
session parameter required in OnnxWrapper
This commit is contained in:
@@ -54,7 +54,7 @@ class OnnxWrapper():
|
||||
ONNX Runtime wrapper for Silero VAD model with per-instance state.
|
||||
"""
|
||||
|
||||
def __init__(self, session: OnnxSession = None, force_onnx_cpu=False):
|
||||
def __init__(self, session: OnnxSession, force_onnx_cpu=False):
|
||||
self._shared_session = session
|
||||
self.sample_rates = session.sample_rates
|
||||
self.reset_states()
|
||||
@@ -313,11 +313,9 @@ class FixedVADIterator(VADIterator):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test JIT model
|
||||
print("Testing JIT model...")
|
||||
model = load_jit_vad()
|
||||
vad = FixedVADIterator(model)
|
||||
|
||||
# vad = FixedVADIterator(load_jit_vad())
|
||||
vad = FixedVADIterator(OnnxWrapper(session=load_onnx_session()))
|
||||
|
||||
audio_buffer = np.array([0] * 512, dtype=np.float32)
|
||||
result = vad(audio_buffer)
|
||||
print(f" 512 samples: {result}")
|
||||
@@ -325,24 +323,4 @@ if __name__ == "__main__":
|
||||
# test with 511 samples
|
||||
audio_buffer = np.array([0] * 511, dtype=np.float32)
|
||||
result = vad(audio_buffer)
|
||||
print(f" 511 samples: {result}")
|
||||
|
||||
# Test ONNX with shared session
|
||||
print("\nTesting ONNX with shared session...")
|
||||
shared_session = load_onnx_session()
|
||||
|
||||
# Create two independent VAD iterators sharing the same session
|
||||
vad1 = FixedVADIterator(OnnxWrapper(session=shared_session))
|
||||
vad2 = FixedVADIterator(OnnxWrapper(session=shared_session))
|
||||
|
||||
# Both should work independently
|
||||
audio_buffer = np.array([0] * 512, dtype=np.float32)
|
||||
result1 = vad1(audio_buffer)
|
||||
result2 = vad2(audio_buffer)
|
||||
print(f" VAD1 result: {result1}")
|
||||
print(f" VAD2 result: {result2}")
|
||||
|
||||
# Verify they have separate states
|
||||
print(f" VAD1 and VAD2 share session: {vad1.model._shared_session is vad2.model._shared_session}")
|
||||
print(f" VAD1 and VAD2 have separate state: {vad1.model._state is not vad2.model._state}")
|
||||
print("\nAll tests passed!")
|
||||
print(f" 511 samples: {result}")
|
||||
Reference in New Issue
Block a user