From 62bf28949edc75d8b8ac4c723ce1fee20a8fc35f Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Tue, 1 Jul 2025 20:10:45 +0200 Subject: [PATCH] compatible with the latest version of simulstreaming --- .../whisper_streaming_custom/backends.py | 3 +- .../whisper_streaming_custom/online_asr.py | 85 +++++++++---------- 2 files changed, 44 insertions(+), 44 deletions(-) diff --git a/whisperlivekit/whisper_streaming_custom/backends.py b/whisperlivekit/whisper_streaming_custom/backends.py index 851e39e..38eb79b 100644 --- a/whisperlivekit/whisper_streaming_custom/backends.py +++ b/whisperlivekit/whisper_streaming_custom/backends.py @@ -484,7 +484,8 @@ class SimulStreamingASR(ASRBase): try: if isinstance(audio, np.ndarray): audio = torch.from_numpy(audio).float() - self.model.infer(audio, True) + self.model.insert_audio(audio) + self.model.infer(True) self.model.refresh_segment(complete=True) logger.info("SimulStreaming model warmed up successfully") except Exception as e: diff --git a/whisperlivekit/whisper_streaming_custom/online_asr.py b/whisperlivekit/whisper_streaming_custom/online_asr.py index b689e0f..734a310 100644 --- a/whisperlivekit/whisper_streaming_custom/online_asr.py +++ b/whisperlivekit/whisper_streaming_custom/online_asr.py @@ -609,6 +609,31 @@ class SimulStreamingOnlineProcessor: probability=None ) + def timestamped_text(self, tokens, generation): + # From the simulstreaming repo. self.model to self.asr.model + pr = generation["progress"] + if "result" not in generation: + split_words, split_tokens = self.asr.model.tokenizer.split_to_word_tokens(tokens) + else: + split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"] + + frames = [p["most_attended_frames"][0] for p in pr] + tokens = tokens.copy() + ret = [] + for sw,st in zip(split_words,split_tokens): + b = None + for stt in st: + t,f = tokens.pop(0), frames.pop(0) + if t != stt: + raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.") + if b is None: + b = f + e = f + out = (b*0.02, e*0.02, sw) + ret.append(out) + logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}") + return ret + def process_iter(self) -> Tuple[List[ASRToken], float]: """ Process accumulated audio chunks using SimulStreaming. @@ -633,52 +658,26 @@ class SimulStreamingOnlineProcessor: logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s") logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s") - result = self.asr.model.infer(audio, is_last=self.is_last) + self.asr.model.insert_audio(audio) + tokens, generation_progress = self.asr.model.infer(is_last=self.is_last) + ts_words = self.timestamped_text(tokens, generation_progress) + text = self.asr.model.tokenizer.decode(tokens) - if torch.is_tensor(result): - # we filter out padding tokens as it s done in simul whisper - from simul_whisper.simul_whisper import DEC_PAD - result = result[result < DEC_PAD] + new_tokens = [] + for ts_word in ts_words: - # C/P from simul_whisper.simul_whisper.py - if len(result) > 0: - decoded_text = self.asr.model.tokenizer.decode(result.cpu().numpy()) - logger.debug(f"SimulStreaming decoded: {decoded_text}") - - if decoded_text.strip(): - words = decoded_text.strip().split() - new_tokens = [] - - num_words = len(words) - if num_words > 0: - # distribute words evenly across the processed audio duration - # we NEED that for when we use diarization. Even if that s not perfect - start_time = self.end - audio_duration - time_per_word = audio_duration / num_words if num_words > 1 else audio_duration - - for i, word in enumerate(words): - token_start = start_time + (i * time_per_word) - token_end = start_time + ((i + 1) * time_per_word) - - token_end = min(token_end, self.end) - - token = ASRToken( - start=token_start, - end=token_end, - text=word, - probability=0.95 # fake prob. Maybe we can extract it from the model? - ) - new_tokens.append(token) - - self.beg = self.end - - self.committed.extend(new_tokens) - self.last_result_tokens = new_tokens - - logger.debug(f"SimulStreaming generated {len(new_tokens)} tokens with end time: {self.end:.2f}s") - return new_tokens, self.end + start, end, word = ts_word + token = ASRToken( + start=start, + end=end, + text=word, + probability=0.95 # fake prob. Maybe we can extract it from the model? + ) + new_tokens.append(token) + self.committed.extend(new_tokens) - return [], self.end + return new_tokens, self.end + except Exception as e: logger.error(f"SimulStreaming processing error: {e}")