mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Make --vad work with --backend openai-api
This commit is contained in:
@@ -162,7 +162,7 @@ class OpenaiApiASR(ASRBase):
|
||||
|
||||
self.load_model()
|
||||
|
||||
self.use_vad = False
|
||||
self.use_vad_opt = False
|
||||
|
||||
# reset the task in set_translate_task
|
||||
self.task = "transcribe"
|
||||
@@ -175,21 +175,27 @@ class OpenaiApiASR(ASRBase):
|
||||
|
||||
|
||||
def ts_words(self, segments):
|
||||
o = []
|
||||
# If VAD on, skip segments containing no speech.
|
||||
# TODO: threshold can be set from outside
|
||||
# TODO: Make VAD work again with word-level timestamps
|
||||
#if self.use_vad and segment["no_speech_prob"] > 0.8:
|
||||
# continue
|
||||
no_speech_segments = []
|
||||
if self.use_vad_opt:
|
||||
for segment in segments.segments:
|
||||
# TODO: threshold can be set from outside
|
||||
if segment["no_speech_prob"] > 0.8:
|
||||
no_speech_segments.append((segment.get("start"), segment.get("end")))
|
||||
|
||||
for word in segments:
|
||||
o.append((word.get("start"), word.get("end"), word.get("word")))
|
||||
o = []
|
||||
for word in segments.words:
|
||||
start = word.get("start")
|
||||
end = word.get("end")
|
||||
if any(s[0] <= start <= s[1] for s in no_speech_segments):
|
||||
# print("Skipping word", word.get("word"), "because it's in a no-speech segment")
|
||||
continue
|
||||
o.append((start, end, word.get("word")))
|
||||
|
||||
return o
|
||||
|
||||
|
||||
def segments_end_ts(self, res):
|
||||
return [s["end"] for s in res]
|
||||
return [s["end"] for s in res.words]
|
||||
|
||||
def transcribe(self, audio_data, prompt=None, *args, **kwargs):
|
||||
# Write the audio data to a buffer
|
||||
@@ -205,7 +211,7 @@ class OpenaiApiASR(ASRBase):
|
||||
"file": buffer,
|
||||
"response_format": self.response_format,
|
||||
"temperature": self.temperature,
|
||||
"timestamp_granularities": ["word"]
|
||||
"timestamp_granularities": ["word", "segment"]
|
||||
}
|
||||
if self.task != "translate" and self.language:
|
||||
params["language"] = self.language
|
||||
@@ -221,10 +227,10 @@ class OpenaiApiASR(ASRBase):
|
||||
transcript = proc.create(**params)
|
||||
print(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds",file=self.logfile)
|
||||
|
||||
return transcript.words
|
||||
return transcript
|
||||
|
||||
def use_vad(self):
|
||||
self.use_vad = True
|
||||
self.use_vad_opt = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.task = "translate"
|
||||
@@ -592,9 +598,9 @@ if __name__ == "__main__":
|
||||
e = time.time()
|
||||
print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
|
||||
|
||||
if args.vad:
|
||||
print("setting VAD filter",file=logfile)
|
||||
asr.use_vad()
|
||||
if args.vad:
|
||||
print("setting VAD filter",file=logfile)
|
||||
asr.use_vad()
|
||||
|
||||
if args.task == "translate":
|
||||
asr.set_translate_task()
|
||||
|
||||
Reference in New Issue
Block a user