From 50f1b94856cd916d7e5bca1650fc8fc8ff3104f2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dominik=20Mach=C3=A1=C4=8Dek?=
Date: Thu, 25 Jan 2024 16:49:25 +0100
Subject: [PATCH] missing features in openai-api, PR #52

Review fixes folded into this patch:
- the request dict stays named `transcription_params`: the original patch
  renamed it to `params` but kept the context lines that populate
  `transcription_params`, which would raise NameError at runtime;
- the VAD flag is stored as `self.use_vad_opt` so the instance attribute
  does not shadow the `use_vad()` method (an attribute named `use_vad`
  would make `asr.use_vad()` fail with "'bool' object is not callable").
---
 whisper_online.py | 88 ++++++++++++++++++++++++++++++-----------------
 1 file changed, 56 insertions(+), 32 deletions(-)

diff --git a/whisper_online.py b/whisper_online.py
index 860c82d..e4bfff8 100644
--- a/whisper_online.py
+++ b/whisper_online.py
@@ -6,8 +6,7 @@ from functools import lru_cache
 import time
 import io
 import soundfile as sf
-
-
+import math
 
 @lru_cache
 def load_audio(fname):
@@ -147,24 +146,34 @@ class FasterWhisperASR(ASRBase):
 
 class OpenaiApiASR(ASRBase):
     """Uses OpenAI's Whisper API for audio transcription."""
 
-    def __init__(self, modelsize=None, lan=None, cache_dir=None, model_dir=None, response_format="verbose_json", temperature=0):
-        self.modelname = "whisper-1" # modelsize is not used but kept for interface consistency
+    def __init__(self, lan=None, response_format="verbose_json", temperature=0, logfile=sys.stderr):
+        self.logfile = logfile
+
+        self.modelname = "whisper-1"
         self.language = lan # ISO-639-1 language code
         self.response_format = response_format
         self.temperature = temperature
-        self.model = self.load_model(modelsize, cache_dir, model_dir)
+
+        self.load_model()
+
+        self.use_vad_opt = False
+
+        # reset the task in set_translate_task
+        self.task = "transcribe"
 
     def load_model(self, *args, **kwargs):
         from openai import OpenAI
         self.client = OpenAI()
-        # Since we're using the OpenAI API, there's no model to load locally.
-        print("Model configuration is set to use the OpenAI Whisper API.")
+
+        self.transcribed_seconds = 0  # for logging how many seconds were processed by API, to know the cost
+
 
     def ts_words(self, segments):
         o = []
         for segment in segments:
-            # Skip segments containing no speech
-            if segment["no_speech_prob"] > 0.8:
+            # If VAD on, skip segments containing no speech.
+            # TODO: threshold can be set from outside
+            if self.use_vad_opt and segment["no_speech_prob"] > 0.8:
                 continue
 
             # Splitting the text into words and filtering out empty strings
@@ -197,23 +206,39 @@ class OpenaiApiASR(ASRBase):
         sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
         buffer.seek(0)  # Reset buffer's position to the beginning
 
-        # Prepare transcription parameters
-        transcription_params = {
+        self.transcribed_seconds += math.ceil(len(audio_data)/16000)  # it rounds up to the whole seconds
+
+        transcription_params = {
             "model": self.modelname,
             "file": buffer,
             "response_format": self.response_format,
             "temperature": self.temperature
         }
-        if self.language:
+        if self.task != "translate" and self.language:
             transcription_params["language"] = self.language
         if prompt:
             transcription_params["prompt"] = prompt
 
-        # Perform the transcription
-        transcript = self.client.audio.transcriptions.create(**transcription_params)
+        if self.task == "translate":
+            proc = self.client.audio.translations
+        else:
+            proc = self.client.audio.transcriptions
+
+        # Process transcription/translation
+
+        transcript = proc.create(**transcription_params)
+        print(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds ",file=self.logfile)
 
         return transcript.segments
 
+    def use_vad(self):
+        self.use_vad_opt = True
+
+    def set_translate_task(self):
+        self.task = "translate"
+
+
+
 
 
 class HypothesisBuffer:
@@ -557,20 +582,27 @@ if __name__ == "__main__":
     duration = len(load_audio(audio_path))/SAMPLING_RATE
     print("Audio duration is: %2.2f seconds" % duration, file=logfile)
 
-    size = args.model
     language = args.lan
 
-    t = time.time()
-    print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
-
-    if args.backend == "faster-whisper":
-        asr_cls = FasterWhisperASR
-    elif args.backend == "openai-api":
-        asr_cls = OpenaiApiASR
+    if args.backend == "openai-api":
+        print("Using OpenAI API.",file=logfile)
+        asr = OpenaiApiASR(lan=language)
     else:
-        asr_cls = WhisperTimestampedASR
+        if args.backend == "faster-whisper":
+            asr_cls = FasterWhisperASR
+        else:
+            asr_cls = WhisperTimestampedASR
 
-    asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+        size = args.model
+        t = time.time()
+        print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
+        asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+        e = time.time()
+        print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
+
+    if args.vad:
+        print("setting VAD filter",file=logfile)
+        asr.use_vad()
 
     if args.task == "translate":
         asr.set_translate_task()
@@ -578,14 +610,6 @@ if __name__ == "__main__":
     else:
         tgt_language = language # Whisper transcribes in this language
 
-
-    e = time.time()
-    print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
-
-    if args.vad:
-        print("setting VAD filter",file=logfile)
-        asr.use_vad()
-
     min_chunk = args.min_chunk_size
     if args.buffer_trimming == "sentence":