correct error when using VAC

This commit is contained in:
Quentin Fuxa
2025-05-28 11:43:18 +02:00
parent b23ef3ec3e
commit 46770efd6c

View File

@@ -343,15 +343,15 @@ class OnlineASRProcessor:
)
sentences.append(sentence)
return sentences
def finish(self) -> Transcript:
def finish(self) -> List[ASRToken]:
"""
Flush the remaining transcript when processing ends.
"""
remaining_tokens = self.transcript_buffer.buffer
final_transcript = self.concatenate_tokens(remaining_tokens)
logger.debug(f"Final non-committed transcript: {final_transcript}")
logger.debug(f"Final non-committed tokens: {remaining_tokens}")
self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
return final_transcript
return remaining_tokens
def concatenate_tokens(
self,
@@ -384,7 +384,8 @@ class VACOnlineASRProcessor:
def __init__(self, online_chunk_size: float, *args, **kwargs):
self.online_chunk_size = online_chunk_size
self.online = OnlineASRProcessor(*args, **kwargs)
self.asr = self.online.asr
# Load a VAD model (e.g. Silero VAD)
import torch
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
@@ -455,7 +456,7 @@ class VACOnlineASRProcessor:
self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
def process_iter(self) -> Transcript:
def process_iter(self) -> List[ASRToken]:
"""
Depending on the VAD status and the amount of accumulated audio,
process the current audio chunk.
@@ -467,9 +468,9 @@ class VACOnlineASRProcessor:
return self.online.process_iter()
else:
logger.debug("No online update, only VAD")
return Transcript(None, None, "")
return []
def finish(self) -> Transcript:
def finish(self) -> List[ASRToken]:
"""Finish processing by flushing any remaining text."""
result = self.online.finish()
self.current_online_chunk_buffer_size = 0
@@ -480,4 +481,4 @@ class VACOnlineASRProcessor:
"""
Get the unvalidated buffer in string format.
"""
return self.online.concatenate_tokens(self.online.transcript_buffer.buffer).text
return self.online.concatenate_tokens(self.online.transcript_buffer.buffer)