diff --git a/whisper_online.py b/whisper_online.py index c872b23..651ceb4 100644 --- a/whisper_online.py +++ b/whisper_online.py @@ -551,7 +551,7 @@ def add_shared_args(parser): def asr_factory(args, logfile=sys.stderr): """ - Creates and configures an ASR instance based on the specified backend and arguments. + Creates and configures an ASR and ASR Online instance based on the specified backend and arguments. """ backend = args.backend if backend == "openai-api": @@ -576,8 +576,23 @@ def asr_factory(args, logfile=sys.stderr): print("Setting VAD filter", file=logfile) asr.use_vad() - return asr + language = args.lan + if args.task == "translate": + asr.set_translate_task() + tgt_language = "en" # Whisper translates into English + else: + tgt_language = language # Whisper transcribes in this language + # Create the tokenizer + if args.buffer_trimming == "sentence": + tokenizer = create_tokenizer(tgt_language) + else: + tokenizer = None + + # Create the OnlineASRProcessor + online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec)) + + return asr, online ## main: if __name__ == "__main__": @@ -605,22 +620,8 @@ if __name__ == "__main__": duration = len(load_audio(audio_path))/SAMPLING_RATE print("Audio duration is: %2.2f seconds" % duration, file=logfile) - asr = asr_factory(args, logfile=logfile) - language = args.lan - if args.task == "translate": - asr.set_translate_task() - tgt_language = "en" # Whisper translates into English - else: - tgt_language = language # Whisper transcribes in this language - - + asr, online = asr_factory(args, logfile=logfile) min_chunk = args.min_chunk_size - if args.buffer_trimming == "sentence": - tokenizer = create_tokenizer(tgt_language) - else: - tokenizer = None - online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec)) - # load the audio into the LRU cache before we start the timer a = load_audio_chunk(audio_path,0,1) diff --git a/whisper_online_server.py b/whisper_online_server.py index 263ab75..f652a75 100644 --- a/whisper_online_server.py +++ b/whisper_online_server.py @@ -25,16 +25,10 @@ SAMPLING_RATE = 16000 size = args.model language = args.lan - -asr = asr_factory(args) -if args.task == "translate": - asr.set_translate_task() - tgt_language = "en" -else: - tgt_language = language - +asr, online = asr_factory(args) min_chunk = args.min_chunk_size + if args.buffer_trimming == "sentence": tokenizer = create_tokenizer(tgt_language) else: