From 46770efd6c445c1cbf00aba62fc4f2bffbc4b0bf Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Wed, 28 May 2025 11:43:18 +0200 Subject: [PATCH] correct error when using VAC --- .../whisper_streaming_custom/online_asr.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/whisperlivekit/whisper_streaming_custom/online_asr.py b/whisperlivekit/whisper_streaming_custom/online_asr.py index 1af06de..8432662 100644 --- a/whisperlivekit/whisper_streaming_custom/online_asr.py +++ b/whisperlivekit/whisper_streaming_custom/online_asr.py @@ -343,15 +343,15 @@ class OnlineASRProcessor: ) sentences.append(sentence) return sentences - def finish(self) -> Transcript: + + def finish(self) -> List[ASRToken]: """ Flush the remaining transcript when processing ends. """ remaining_tokens = self.transcript_buffer.buffer - final_transcript = self.concatenate_tokens(remaining_tokens) - logger.debug(f"Final non-committed transcript: {final_transcript}") + logger.debug(f"Final non-committed tokens: {remaining_tokens}") self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE - return final_transcript + return remaining_tokens def concatenate_tokens( self, @@ -384,7 +384,8 @@ class VACOnlineASRProcessor: def __init__(self, online_chunk_size: float, *args, **kwargs): self.online_chunk_size = online_chunk_size self.online = OnlineASRProcessor(*args, **kwargs) - + self.asr = self.online.asr + # Load a VAD model (e.g. Silero VAD) import torch model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad") @@ -455,7 +456,7 @@ class VACOnlineASRProcessor: self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE) self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:] - def process_iter(self) -> Transcript: + def process_iter(self) -> List[ASRToken]: """ Depending on the VAD status and the amount of accumulated audio, process the current audio chunk. @@ -467,9 +468,9 @@ class VACOnlineASRProcessor: return self.online.process_iter() else: logger.debug("No online update, only VAD") - return Transcript(None, None, "") + return [] - def finish(self) -> Transcript: + def finish(self) -> List[ASRToken]: """Finish processing by flushing any remaining text.""" result = self.online.finish() self.current_online_chunk_buffer_size = 0 @@ -480,4 +481,4 @@ class VACOnlineASRProcessor: """ Get the unvalidated buffer in string format. """ - return self.online.concatenate_tokens(self.online.transcript_buffer.buffer).text + return self.online.concatenate_tokens(self.online.transcript_buffer.buffer)