diff --git a/whisper_online.py b/whisper_online.py
index d79396a..c90babb 100644
--- a/whisper_online.py
+++ b/whisper_online.py
@@ -548,6 +548,42 @@ def add_shared_args(parser):
     parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.')
     parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.')
 
+def asr_factory(args, logfile=sys.stderr):
+    """
+    Creates and configures an ASR instance based on the specified backend and arguments.
+
+    args: parsed arguments; reads backend, lan, model, model_cache_dir, model_dir
+        and, when present, vad.
+    logfile: stream the progress messages are printed to.
+    Returns the ASR object, with the VAD filter enabled if args.vad is set.
+    """
+    backend = args.backend
+    if backend == "openai-api":
+        print("Using OpenAI API.", file=logfile)
+        asr = OpenaiApiASR(lan=args.lan)
+    else:
+        # Both classes are already module-level names in this file (the previous
+        # inline setup used them without any import), so nothing is imported here.
+        if backend == "faster-whisper":
+            asr_cls = FasterWhisperASR
+        else:
+            asr_cls = WhisperTimestampedASR
+
+        # Only for FasterWhisperASR and WhisperTimestampedASR
+        size = args.model
+        t = time.time()
+        print(f"Loading Whisper {size} model for {args.lan}...", file=logfile, end=" ", flush=True)
+        asr = asr_cls(modelsize=size, lan=args.lan, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+        e = time.time()
+        print(f"done. It took {round(e-t,2)} seconds.", file=logfile)
+
+    # Apply common configurations
+    if getattr(args, 'vad', False):  # Checks if VAD argument is present and True
+        print("Setting VAD filter", file=logfile)
+        asr.use_vad()
+
+    return asr
+
 ## main:
 
 if __name__ == "__main__":
@@ -575,28 +611,8 @@ if __name__ == "__main__":
     duration = len(load_audio(audio_path))/SAMPLING_RATE
     print("Audio duration is: %2.2f seconds" % duration, file=logfile)
 
+    asr = asr_factory(args, logfile=logfile)
     language = args.lan
-
-    if args.backend == "openai-api":
-        print("Using OpenAI API.",file=logfile)
-        asr = OpenaiApiASR(lan=language)
-    else:
-        if args.backend == "faster-whisper":
-            asr_cls = FasterWhisperASR
-        else:
-            asr_cls = WhisperTimestampedASR
-
-        size = args.model
-        t = time.time()
-        print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
-        asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
-        e = time.time()
-        print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
-
-    if args.vad:
-        print("setting VAD filter",file=logfile)
-        asr.use_vad()
-
     if args.task == "translate":
         asr.set_translate_task()
         tgt_language = "en"  # Whisper translates into English
diff --git a/whisper_online_server.py b/whisper_online_server.py
index 0cdc97d..7f81caa 100644
--- a/whisper_online_server.py
+++ b/whisper_online_server.py
@@ -24,36 +24,13 @@ SAMPLING_RATE = 16000
 size = args.model
 language = args.lan
 
-t = time.time()
-print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
-
-if args.backend == "faster-whisper":
-    from faster_whisper import WhisperModel
-    asr_cls = FasterWhisperASR
-elif args.backend == "openai-api":
-    asr_cls = OpenaiApiASR
-else:
-    import whisper
-    import whisper_timestamped
-#    from whisper_timestamped_model import WhisperTimestampedASR
-    asr_cls = WhisperTimestampedASR
-
-asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
-
+asr = asr_factory(args)
 if args.task == "translate":
     asr.set_translate_task()
     tgt_language = "en"
 else:
     tgt_language = language
 
-e = time.time()
-print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
-
-if args.vad:
-    print("setting VAD filter",file=sys.stderr)
-    asr.use_vad()
-
-
 min_chunk = args.min_chunk_size
 
 if args.buffer_trimming == "sentence":