From bccbb15177137f5ea7246b2211791f2481c41279 Mon Sep 17 00:00:00 2001 From: Tijs Zwinkels Date: Wed, 20 Mar 2024 16:29:01 +0100 Subject: [PATCH] Move creation of OnlineASRProcessor inside the factory method Preventing more code duplication between whisper_online.py and whisper_online_server.py --- whisper_online.py | 35 ++++++++++++++++++----------------- whisper_online_server.py | 17 +---------------- 2 files changed, 19 insertions(+), 33 deletions(-) diff --git a/whisper_online.py b/whisper_online.py index a00547e..c4a90e3 100644 --- a/whisper_online.py +++ b/whisper_online.py @@ -551,7 +551,7 @@ def add_shared_args(parser): def asr_factory(args, logfile=sys.stderr): """ - Creates and configures an ASR instance based on the specified backend and arguments. + Creates and configures an ASR and ASR Online instance based on the specified backend and arguments. """ backend = args.backend if backend == "openai-api": @@ -576,8 +576,23 @@ def asr_factory(args, logfile=sys.stderr): print("Setting VAD filter", file=logfile) asr.use_vad() - return asr + language = args.lan + if args.task == "translate": + asr.set_translate_task() + tgt_language = "en" # Whisper translates into English + else: + tgt_language = language # Whisper transcribes in this language + # Create the tokenizer + if args.buffer_trimming == "sentence": + tokenizer = create_tokenizer(tgt_language) + else: + tokenizer = None + + # Create the OnlineASRProcessor + online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec)) + + return asr, online ## main: if __name__ == "__main__": @@ -605,22 +620,8 @@ if __name__ == "__main__": duration = len(load_audio(audio_path))/SAMPLING_RATE print("Audio duration is: %2.2f seconds" % duration, file=logfile) - asr = asr_factory(args, logfile=logfile) - language = args.lan - if args.task == "translate": - asr.set_translate_task() - tgt_language = "en" # Whisper translates into English - else: - tgt_language = language # Whisper transcribes in this language - - + asr, online = asr_factory(args, logfile=logfile) min_chunk = args.min_chunk_size - if args.buffer_trimming == "sentence": - tokenizer = create_tokenizer(tgt_language) - else: - tokenizer = None - online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec)) - # load the audio into the LRU cache before we start the timer a = load_audio_chunk(audio_path,0,1) diff --git a/whisper_online_server.py b/whisper_online_server.py index 7f81caa..188038a 100644 --- a/whisper_online_server.py +++ b/whisper_online_server.py @@ -23,24 +23,9 @@ SAMPLING_RATE = 16000 size = args.model language = args.lan - -asr = asr_factory(args) -if args.task == "translate": - asr.set_translate_task() - tgt_language = "en" -else: - tgt_language = language - +asr, online = asr_factory(args) min_chunk = args.min_chunk_size -if args.buffer_trimming == "sentence": - tokenizer = create_tokenizer(tgt_language) -else: - tokenizer = None -online = OnlineASRProcessor(asr,tokenizer,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec)) - - - demo_audio_path = "cs-maji-2.16k.wav" if os.path.exists(demo_audio_path): # load the audio into the LRU cache before we start the timer