diff --git a/whisperlivekit/local_agreement/backends.py b/whisperlivekit/local_agreement/backends.py index 001e13e..95ab0e8 100644 --- a/whisperlivekit/local_agreement/backends.py +++ b/whisperlivekit/local_agreement/backends.py @@ -249,6 +249,7 @@ class OpenaiApiASR(ASRBase): self.load_model() self.use_vad_opt = False self.direct_english_translation = False + self.task = "transcribe" def load_model(self, *args, **kwargs): from openai import OpenAI @@ -294,7 +295,8 @@ class OpenaiApiASR(ASRBase): params["language"] = self.original_language if prompt: params["prompt"] = prompt - proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions + task = self.transcribe_kargs.get("task", self.task) + proc = self.client.audio.translations if task == "translate" else self.client.audio.transcriptions transcript = proc.create(**params) logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds") return transcript diff --git a/whisperlivekit/local_agreement/whisper_online.py b/whisperlivekit/local_agreement/whisper_online.py index d74ac54..5093196 100644 --- a/whisperlivekit/local_agreement/whisper_online.py +++ b/whisperlivekit/local_agreement/whisper_online.py @@ -146,6 +146,7 @@ def backend_factory( if direct_english_translation: tgt_language = "en" # Whisper translates into English + asr.transcribe_kargs["task"] = "translate" else: tgt_language = lan # Whisper transcribes in this language @@ -154,9 +155,9 @@ def backend_factory( tokenizer = create_tokenizer(tgt_language) else: tokenizer = None - + warmup_asr(asr, warmup_file) - + asr.confidence_validation = confidence_validation asr.tokenizer = tokenizer asr.buffer_trimming = buffer_trimming