mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
missing features in openai-api, PR #52
This commit is contained in:
@@ -6,8 +6,7 @@ from functools import lru_cache
|
||||
import time
|
||||
import io
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
import math
|
||||
|
||||
@lru_cache
|
||||
def load_audio(fname):
|
||||
@@ -147,24 +146,34 @@ class FasterWhisperASR(ASRBase):
|
||||
class OpenaiApiASR(ASRBase):
|
||||
"""Uses OpenAI's Whisper API for audio transcription."""
|
||||
|
||||
def __init__(self, modelsize=None, lan=None, cache_dir=None, model_dir=None, response_format="verbose_json", temperature=0):
|
||||
self.modelname = "whisper-1" # modelsize is not used but kept for interface consistency
|
||||
def __init__(self, lan=None, response_format="verbose_json", temperature=0, logfile=sys.stderr):
|
||||
self.logfile = logfile
|
||||
|
||||
self.modelname = "whisper-1"
|
||||
self.language = lan # ISO-639-1 language code
|
||||
self.response_format = response_format
|
||||
self.temperature = temperature
|
||||
self.model = self.load_model(modelsize, cache_dir, model_dir)
|
||||
|
||||
self.load_model()
|
||||
|
||||
self.use_vad = False
|
||||
|
||||
# reset the task in set_translate_task
|
||||
self.task = "transcribe"
|
||||
|
||||
def load_model(self, *args, **kwargs):
|
||||
from openai import OpenAI
|
||||
self.client = OpenAI()
|
||||
# Since we're using the OpenAI API, there's no model to load locally.
|
||||
print("Model configuration is set to use the OpenAI Whisper API.")
|
||||
|
||||
self.transcribed_seconds = 0 # for logging how many seconds were processed by API, to know the cost
|
||||
|
||||
|
||||
def ts_words(self, segments):
|
||||
o = []
|
||||
for segment in segments:
|
||||
# Skip segments containing no speech
|
||||
if segment["no_speech_prob"] > 0.8:
|
||||
# If VAD on, skip segments containing no speech.
|
||||
# TODO: threshold can be set from outside
|
||||
if self.use_vad and segment["no_speech_prob"] > 0.8:
|
||||
continue
|
||||
|
||||
# Splitting the text into words and filtering out empty strings
|
||||
@@ -197,23 +206,39 @@ class OpenaiApiASR(ASRBase):
|
||||
sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
|
||||
buffer.seek(0) # Reset buffer's position to the beginning
|
||||
|
||||
# Prepare transcription parameters
|
||||
transcription_params = {
|
||||
self.transcribed_seconds += math.ceil(len(audio_data)/16000) # it rounds up to the whole seconds
|
||||
|
||||
params = {
|
||||
"model": self.modelname,
|
||||
"file": buffer,
|
||||
"response_format": self.response_format,
|
||||
"temperature": self.temperature
|
||||
}
|
||||
if self.language:
|
||||
if self.task != "translate" and self.language:
|
||||
transcription_params["language"] = self.language
|
||||
if prompt:
|
||||
transcription_params["prompt"] = prompt
|
||||
|
||||
# Perform the transcription
|
||||
transcript = self.client.audio.transcriptions.create(**transcription_params)
|
||||
if self.task == "translate":
|
||||
proc = self.client.audio.translations
|
||||
else:
|
||||
proc = self.client.audio.transcriptions
|
||||
|
||||
# Process transcription/translation
|
||||
|
||||
transcript = proc.create(**params)
|
||||
print(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds ",file=self.logfile)
|
||||
|
||||
return transcript.segments
|
||||
|
||||
def use_vad(self):
|
||||
self.use_vad = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.task = "translate"
|
||||
|
||||
|
||||
|
||||
|
||||
class HypothesisBuffer:
|
||||
|
||||
@@ -557,20 +582,27 @@ if __name__ == "__main__":
|
||||
duration = len(load_audio(audio_path))/SAMPLING_RATE
|
||||
print("Audio duration is: %2.2f seconds" % duration, file=logfile)
|
||||
|
||||
size = args.model
|
||||
language = args.lan
|
||||
|
||||
t = time.time()
|
||||
print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
|
||||
|
||||
if args.backend == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
elif args.backend == "openai-api":
|
||||
asr_cls = OpenaiApiASR
|
||||
if args.backend == "openai-api":
|
||||
print("Using OpenAI API.",file=logfile)
|
||||
asr = OpenaiApiASR(lan=language)
|
||||
else:
|
||||
asr_cls = WhisperTimestampedASR
|
||||
if args.backend == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
else:
|
||||
asr_cls = WhisperTimestampedASR
|
||||
|
||||
asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
|
||||
size = args.model
|
||||
t = time.time()
|
||||
print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
|
||||
asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
|
||||
e = time.time()
|
||||
print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
|
||||
|
||||
if args.vad:
|
||||
print("setting VAD filter",file=logfile)
|
||||
asr.use_vad()
|
||||
|
||||
if args.task == "translate":
|
||||
asr.set_translate_task()
|
||||
@@ -578,14 +610,6 @@ if __name__ == "__main__":
|
||||
else:
|
||||
tgt_language = language # Whisper transcribes in this language
|
||||
|
||||
|
||||
e = time.time()
|
||||
print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
|
||||
|
||||
if args.vad:
|
||||
print("setting VAD filter",file=logfile)
|
||||
asr.use_vad()
|
||||
|
||||
|
||||
min_chunk = args.min_chunk_size
|
||||
if args.buffer_trimming == "sentence":
|
||||
|
||||
Reference in New Issue
Block a user