From 15089c80fdd31e0d1c3f192a002d0c3bfa4b1038 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Wed, 12 Feb 2025 23:36:52 +0000 Subject: [PATCH] Add get_buffer method to retrieve unvalidated buffer in string format --- src/whisper_streaming/online_asr.py | 16 +++++++++++++++- whisper_fastapi_online_server.py | 8 ++------ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/whisper_streaming/online_asr.py b/src/whisper_streaming/online_asr.py index 605f9a7..0ad0bed 100644 --- a/src/whisper_streaming/online_asr.py +++ b/src/whisper_streaming/online_asr.py @@ -85,6 +85,7 @@ class HypothesisBuffer: self.committed_in_buffer.pop(0) + class OnlineASRProcessor: """ Processes incoming audio in a streaming fashion, calling the ASR system @@ -163,6 +164,13 @@ class OnlineASRProcessor: context_text = self.asr.sep.join(token.text for token in non_prompt_tokens) return self.asr.sep.join(prompt_list[::-1]), context_text + def get_buffer(self): + """ + Get the unvalidated buffer in string format. + """ + return self.concatenate_tokens(self.transcript_buffer.buffer).text + + def process_iter(self) -> Transcript: """ Processes the current audio buffer. @@ -413,4 +421,10 @@ class VACOnlineASRProcessor: result = self.online.finish() self.current_online_chunk_buffer_size = 0 self.is_currently_final = False - return result \ No newline at end of file + return result + + def get_buffer(self): + """ + Get the unvalidated buffer in string format. + """ + return self.online.concatenate_tokens(self.online.transcript_buffer.buffer).text diff --git a/whisper_fastapi_online_server.py b/whisper_fastapi_online_server.py index fbdbe12..1ea8731 100644 --- a/whisper_fastapi_online_server.py +++ b/whisper_fastapi_online_server.py @@ -158,12 +158,8 @@ async def websocket_endpoint(websocket: WebSocket): }) full_transcription += transcription.text - if args.vac: - transcript = online.online.concatenate_tokens(online.online.transcript_buffer.buffer) - else: - transcript = online.concatenate_tokens(online.transcript_buffer.buffer) - - buffer = transcript.text + buffer = online.get_buffer() + if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed buffer = ""