clean SimulStreamingOnlineProcessor initialization + audio processing

This commit is contained in:
Quentin Fuxa
2025-08-09 20:16:27 +02:00
parent b05297a96d
commit 5491964e81
2 changed files with 6 additions and 40 deletions

View File

@@ -125,10 +125,7 @@ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
from simul_whisper import SimulStreamingOnlineProcessor
online = SimulStreamingOnlineProcessor(
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
confidence_validation=args.confidence_validation
)
elif args.vac:
online = VACOnlineASRProcessor(

View File

@@ -23,45 +23,34 @@ class SimulStreamingOnlineProcessor:
def __init__(
self,
asr,
tokenize_method: Optional[callable] = None,
buffer_trimming: Tuple[str, float] = ("segment", 15),
confidence_validation = False,
logfile=sys.stderr,
):
self.asr = asr
self.logfile = logfile
self.confidence_validation = confidence_validation
# buffer does not work yet
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
self.audio_chunks = []
self.offset = 0.0
self.is_last = False
self.beg = self.offset
self.end = self.offset
self.beg = 0.0
self.end = 0.0
self.cumulative_audio_duration = 0.0
self.last_audio_stream_end_time = self.offset
self.committed: List[ASRToken] = []
self.last_result_tokens: List[ASRToken] = []
self.buffer_content = ""
self.processed_audio_duration = 0.0
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
"""Append an audio chunk to be processed by SimulStreaming."""
# Convert numpy array to torch tensor
audio_tensor = torch.from_numpy(audio).float()
self.audio_chunks.append(audio_tensor)
# Update timing
chunk_duration = len(audio) / self.SAMPLING_RATE
self.cumulative_audio_duration += chunk_duration
if audio_stream_end_time is not None:
self.last_audio_stream_end_time = audio_stream_end_time
self.end = audio_stream_end_time
else:
self.end = self.offset + self.cumulative_audio_duration
self.end = self.cumulative_audio_duration
self.asr.model.insert_audio(audio_tensor)
def get_buffer(self):
"""
@@ -106,25 +95,7 @@ class SimulStreamingOnlineProcessor:
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
"""
if not self.audio_chunks:
return [], self.end
try:
# concatenate all audio chunks
if len(self.audio_chunks) == 1:
audio = self.audio_chunks[0]
else:
audio = torch.cat(self.audio_chunks, dim=0)
audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0
self.processed_audio_duration += audio_duration
self.audio_chunks = []
logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")
self.asr.model.insert_audio(audio)
try:
tokens, generation_progress = self.asr.model.infer(is_last=self.is_last)
ts_words = self.timestamped_text(tokens, generation_progress)
@@ -214,9 +185,7 @@ class SimulStreamingASR():
init_prompt=self.init_prompt,
max_context_tokens=self.max_context_tokens,
static_init_prompt=self.static_init_prompt,
)
logger.info(f"Loading SimulStreaming model with language: {self.original_language}")
)
model = PaddedAlignAttWhisper(cfg)
return model