Merge branch 'tijszwinkels-online-from-factory'

PR #71
This commit is contained in:
Dominik Macháček
2024-04-17 15:22:11 +02:00
2 changed files with 20 additions and 25 deletions

View File

@@ -551,7 +551,7 @@ def add_shared_args(parser):
def asr_factory(args, logfile=sys.stderr):
"""
Creates and configures an ASR instance based on the specified backend and arguments.
Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
"""
backend = args.backend
if backend == "openai-api":
@@ -576,8 +576,23 @@ def asr_factory(args, logfile=sys.stderr):
print("Setting VAD filter", file=logfile)
asr.use_vad()
return asr
language = args.lan
if args.task == "translate":
asr.set_translate_task()
tgt_language = "en" # Whisper translates into English
else:
tgt_language = language # Whisper transcribes in this language
# Create the tokenizer
if args.buffer_trimming == "sentence":
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
# Create the OnlineASRProcessor
online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
return asr, online
## main:
if __name__ == "__main__":
@@ -605,22 +620,8 @@ if __name__ == "__main__":
duration = len(load_audio(audio_path))/SAMPLING_RATE
print("Audio duration is: %2.2f seconds" % duration, file=logfile)
asr = asr_factory(args, logfile=logfile)
language = args.lan
if args.task == "translate":
asr.set_translate_task()
tgt_language = "en" # Whisper translates into English
else:
tgt_language = language # Whisper transcribes in this language
asr, online = asr_factory(args, logfile=logfile)
min_chunk = args.min_chunk_size
if args.buffer_trimming == "sentence":
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
# load the audio into the LRU cache before we start the timer
a = load_audio_chunk(audio_path,0,1)

View File

@@ -25,16 +25,10 @@ SAMPLING_RATE = 16000
size = args.model
language = args.lan
asr = asr_factory(args)
if args.task == "translate":
asr.set_translate_task()
tgt_language = "en"
else:
tgt_language = language
asr, online = asr_factory(args)
min_chunk = args.min_chunk_size
if args.buffer_trimming == "sentence":
tokenizer = create_tokenizer(tgt_language)
else: