Make --vad work with --backend openai-api

This commit is contained in:
Tijs Zwinkels
2024-02-10 15:29:18 +01:00
parent 3696fef2b1
commit f0a24cd5e1

View File

@@ -162,7 +162,7 @@ class OpenaiApiASR(ASRBase):
self.load_model()
self.use_vad = False
self.use_vad_opt = False
# reset the task in set_translate_task
self.task = "transcribe"
@@ -175,21 +175,27 @@ class OpenaiApiASR(ASRBase):
def ts_words(self, segments):
o = []
# If VAD on, skip segments containing no speech.
# TODO: threshold can be set from outside
# TODO: Make VAD work again with word-level timestamps
#if self.use_vad and segment["no_speech_prob"] > 0.8:
# continue
no_speech_segments = []
if self.use_vad_opt:
for segment in segments.segments:
# TODO: threshold can be set from outside
if segment["no_speech_prob"] > 0.8:
no_speech_segments.append((segment.get("start"), segment.get("end")))
for word in segments:
o.append((word.get("start"), word.get("end"), word.get("word")))
o = []
for word in segments.words:
start = word.get("start")
end = word.get("end")
if any(s[0] <= start <= s[1] for s in no_speech_segments):
# print("Skipping word", word.get("word"), "because it's in a no-speech segment")
continue
o.append((start, end, word.get("word")))
return o
def segments_end_ts(self, res):
return [s["end"] for s in res]
return [s["end"] for s in res.words]
def transcribe(self, audio_data, prompt=None, *args, **kwargs):
# Write the audio data to a buffer
@@ -205,7 +211,7 @@ class OpenaiApiASR(ASRBase):
"file": buffer,
"response_format": self.response_format,
"temperature": self.temperature,
"timestamp_granularities": ["word"]
"timestamp_granularities": ["word", "segment"]
}
if self.task != "translate" and self.language:
params["language"] = self.language
@@ -221,10 +227,10 @@ class OpenaiApiASR(ASRBase):
transcript = proc.create(**params)
print(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds",file=self.logfile)
return transcript.words
return transcript
def use_vad(self):
self.use_vad = True
self.use_vad_opt = True
def set_translate_task(self):
self.task = "translate"
@@ -592,9 +598,9 @@ if __name__ == "__main__":
e = time.time()
print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
if args.vad:
print("setting VAD filter",file=logfile)
asr.use_vad()
if args.vad:
print("setting VAD filter",file=logfile)
asr.use_vad()
if args.task == "translate":
asr.set_translate_task()