mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Fix crash when using openai-api with whisper_online_server
+ refactored creation of the ASR into a factory method
This commit is contained in:
@@ -548,6 +548,37 @@ def add_shared_args(parser):
|
||||
parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.')
|
||||
parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.')
|
||||
|
||||
def asr_factory(args, logfile=sys.stderr):
|
||||
"""
|
||||
Creates and configures an ASR instance based on the specified backend and arguments.
|
||||
"""
|
||||
backend = args.backend
|
||||
if backend == "openai-api":
|
||||
print("Using OpenAI API.", file=logfile)
|
||||
asr = OpenaiApiASR(lan=args.lan)
|
||||
else:
|
||||
if backend == "faster-whisper":
|
||||
from faster_whisper import FasterWhisperASR
|
||||
asr_cls = FasterWhisperASR
|
||||
else:
|
||||
from whisper_timestamped import WhisperTimestampedASR
|
||||
asr_cls = WhisperTimestampedASR
|
||||
|
||||
# Only for FasterWhisperASR and WhisperTimestampedASR
|
||||
size = args.model
|
||||
t = time.time()
|
||||
print(f"Loading Whisper {size} model for {args.lan}...", file=logfile, end=" ", flush=True)
|
||||
asr = asr_cls(modelsize=size, lan=args.lan, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
|
||||
e = time.time()
|
||||
print(f"done. It took {round(e-t,2)} seconds.", file=logfile)
|
||||
|
||||
# Apply common configurations
|
||||
if getattr(args, 'vad', False): # Checks if VAD argument is present and True
|
||||
print("Setting VAD filter", file=logfile)
|
||||
asr.use_vad()
|
||||
|
||||
return asr
|
||||
|
||||
## main:
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -575,28 +606,8 @@ if __name__ == "__main__":
|
||||
duration = len(load_audio(audio_path))/SAMPLING_RATE
|
||||
print("Audio duration is: %2.2f seconds" % duration, file=logfile)
|
||||
|
||||
asr = asr_factory(args, logfile=logfile)
|
||||
language = args.lan
|
||||
|
||||
if args.backend == "openai-api":
|
||||
print("Using OpenAI API.",file=logfile)
|
||||
asr = OpenaiApiASR(lan=language)
|
||||
else:
|
||||
if args.backend == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
else:
|
||||
asr_cls = WhisperTimestampedASR
|
||||
|
||||
size = args.model
|
||||
t = time.time()
|
||||
print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
|
||||
asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
|
||||
e = time.time()
|
||||
print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
|
||||
|
||||
if args.vad:
|
||||
print("setting VAD filter",file=logfile)
|
||||
asr.use_vad()
|
||||
|
||||
if args.task == "translate":
|
||||
asr.set_translate_task()
|
||||
tgt_language = "en" # Whisper translates into English
|
||||
|
||||
@@ -24,36 +24,13 @@ SAMPLING_RATE = 16000
|
||||
size = args.model
|
||||
language = args.lan
|
||||
|
||||
t = time.time()
|
||||
print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
|
||||
|
||||
if args.backend == "faster-whisper":
|
||||
from faster_whisper import WhisperModel
|
||||
asr_cls = FasterWhisperASR
|
||||
elif args.backend == "openai-api":
|
||||
asr_cls = OpenaiApiASR
|
||||
else:
|
||||
import whisper
|
||||
import whisper_timestamped
|
||||
# from whisper_timestamped_model import WhisperTimestampedASR
|
||||
asr_cls = WhisperTimestampedASR
|
||||
|
||||
asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
|
||||
|
||||
asr = asr_factory(args)
|
||||
if args.task == "translate":
|
||||
asr.set_translate_task()
|
||||
tgt_language = "en"
|
||||
else:
|
||||
tgt_language = language
|
||||
|
||||
e = time.time()
|
||||
print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
|
||||
|
||||
if args.vad:
|
||||
print("setting VAD filter",file=sys.stderr)
|
||||
asr.use_vad()
|
||||
|
||||
|
||||
min_chunk = args.min_chunk_size
|
||||
|
||||
if args.buffer_trimming == "sentence":
|
||||
|
||||
Reference in New Issue
Block a user