missing features in openai-api, PR #52

Author: Dominik Macháček
Date: 2024-01-25 16:49:25 +01:00
parent ab27bfb361
commit 50f1b94856

whisper_online.py

@@ -6,8 +6,7 @@ from functools import lru_cache
 import time
 import io
 import soundfile as sf
+import math

 @lru_cache
 def load_audio(fname):
@@ -147,24 +146,34 @@ class FasterWhisperASR(ASRBase):
 class OpenaiApiASR(ASRBase):
     """Uses OpenAI's Whisper API for audio transcription."""

-    def __init__(self, modelsize=None, lan=None, cache_dir=None, model_dir=None, response_format="verbose_json", temperature=0):
-        self.modelname = "whisper-1"  # modelsize is not used but kept for interface consistency
+    def __init__(self, lan=None, response_format="verbose_json", temperature=0, logfile=sys.stderr):
+        self.logfile = logfile
+        self.modelname = "whisper-1"
         self.language = lan  # ISO-639-1 language code
         self.response_format = response_format
         self.temperature = temperature
-        self.model = self.load_model(modelsize, cache_dir, model_dir)
+        self.load_model()
+
+        self.use_vad_opt = False
+
+        # reset the task in set_translate_task
+        self.task = "transcribe"

     def load_model(self, *args, **kwargs):
         from openai import OpenAI
         self.client = OpenAI()
-        # Since we're using the OpenAI API, there's no model to load locally.
-        print("Model configuration is set to use the OpenAI Whisper API.")
+        self.transcribed_seconds = 0  # for logging how many seconds were processed by API, to know the cost

     def ts_words(self, segments):
         o = []
         for segment in segments:
-            # Skip segments containing no speech
-            if segment["no_speech_prob"] > 0.8:
+            # If VAD is on, skip segments containing no speech.
+            # TODO: threshold can be set from outside
+            if self.use_vad_opt and segment["no_speech_prob"] > 0.8:
                 continue
             # Splitting the text into words and filtering out empty strings
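For context, a minimal usage sketch of the reworked constructor and the VAD opt-in (not part of the diff; the import path is an assumption, and the openai-python client reads OPENAI_API_KEY from the environment by default):

    import sys
    from whisper_online import OpenaiApiASR  # assumed module name for this file

    # OpenAI() inside load_model() picks up OPENAI_API_KEY from the environment.
    asr = OpenaiApiASR(lan="en", logfile=sys.stderr)
    asr.use_vad()  # opt in: ts_words() then drops segments with no_speech_prob > 0.8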
@@ -197,23 +206,39 @@ class OpenaiApiASR(ASRBase):
         sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
         buffer.seek(0)  # Reset buffer's position to the beginning

-        # Prepare transcription parameters
-        transcription_params = {
+        self.transcribed_seconds += math.ceil(len(audio_data)/16000)  # rounds up to whole seconds
+
+        params = {
             "model": self.modelname,
             "file": buffer,
             "response_format": self.response_format,
             "temperature": self.temperature
         }
-        if self.language:
-            transcription_params["language"] = self.language
+        if self.task != "translate" and self.language:
+            params["language"] = self.language
         if prompt:
-            transcription_params["prompt"] = prompt
+            params["prompt"] = prompt

-        # Perform the transcription
-        transcript = self.client.audio.transcriptions.create(**transcription_params)
+        if self.task == "translate":
+            proc = self.client.audio.translations
+        else:
+            proc = self.client.audio.transcriptions
+
+        # Process transcription/translation
+        transcript = proc.create(**params)
+        print(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds", file=self.logfile)

         return transcript.segments

+    def use_vad(self):
+        self.use_vad_opt = True
+
+    def set_translate_task(self):
+        self.task = "translate"

 class HypothesisBuffer:
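The transcribed_seconds counter rounds every chunk up to whole seconds before accumulating, so the logged total is an upper bound on the audio the API bills for. The new guard on "language" reflects that the translations endpoint does not accept a language parameter. A worked example of the rounding (sketch, not part of the commit):

    import math

    SAMPLING_RATE = 16000
    audio_data = [0.0] * 19200  # a 1.2-second chunk at 16 kHz
    print(math.ceil(len(audio_data) / SAMPLING_RATE))  # 2 -- each call rounds up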
@@ -557,20 +582,27 @@ if __name__ == "__main__":
     duration = len(load_audio(audio_path))/SAMPLING_RATE
     print("Audio duration is: %2.2f seconds" % duration, file=logfile)

-    size = args.model
     language = args.lan

-    t = time.time()
-    print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
-    if args.backend == "faster-whisper":
-        asr_cls = FasterWhisperASR
-    elif args.backend == "openai-api":
-        asr_cls = OpenaiApiASR
+    if args.backend == "openai-api":
+        print("Using OpenAI API.",file=logfile)
+        asr = OpenaiApiASR(lan=language)
     else:
-        asr_cls = WhisperTimestampedASR
-    asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+        if args.backend == "faster-whisper":
+            asr_cls = FasterWhisperASR
+        else:
+            asr_cls = WhisperTimestampedASR
+
+        size = args.model
+        t = time.time()
+        print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
+        asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+        e = time.time()
+        print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
+
+    if args.vad:
+        print("setting VAD filter",file=logfile)
+        asr.use_vad()
+
+    if args.task == "translate":
+        asr.set_translate_task()
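With this restructuring, the openai-api backend skips local model loading entirely. An example invocation, assuming the flag names defined by this script's argparse setup (audio.wav is a placeholder):

    python3 whisper_online.py audio.wav --backend openai-api --lan en --task translate --vad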
@@ -578,14 +610,6 @@ if __name__ == "__main__":
     else:
         tgt_language = language  # Whisper transcribes in this language

-    e = time.time()
-    print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
-
-    if args.vad:
-        print("setting VAD filter",file=logfile)
-        asr.use_vad()

     min_chunk = args.min_chunk_size
     if args.buffer_trimming == "sentence":