diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py
index fa550da..de306bc 100644
--- a/whisperlivekit/simul_whisper/backend.py
+++ b/whisperlivekit/simul_whisper/backend.py
@@ -38,7 +38,7 @@ class SimulStreamingOnlineProcessor:
         self.logfile = logfile
         self.is_last = False
         self.end = 0.0
-        self.cumulative_audio_duration = 0.0
+        self.global_time_offset = 0.0
         self.committed: List[ASRToken] = []
         self.last_result_tokens: List[ASRToken] = []
 
@@ -59,10 +59,9 @@ class SimulStreamingOnlineProcessor:
         if silence_duration < 5:
             gap_silence = torch.zeros(int(16000*min(silence_duration, 1.0)))
             self.model.insert_audio(gap_silence)
-            self.model.last_attend_frame += int(TOKENS_PER_SECOND * (min(silence_duration, 1.0) - 1.0))
         else:
             self.model.refresh_segment(complete=True)
-            self.model.last_attend_frame += int(TOKENS_PER_SECOND * silence_duration)
+            self.global_time_offset += silence_duration
 
@@ -83,29 +82,51 @@ class SimulStreamingOnlineProcessor:
         )
 
     def timestamped_text(self, tokens, generation):
-        # From the simulstreaming repo. self.model to self.asr.model
-        pr = generation["progress"]
-        if "result" not in generation:
-            split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
+        """
+        Generate timestamped text from tokens and generation data.
+
+        Args:
+            tokens: List of tokens to process
+            generation: Dictionary containing generation progress and optionally results
+
+        Returns:
+            List of tuples containing (start_time, end_time, word) for each word
+        """
+        FRAME_DURATION = 0.02
+        if "result" in generation:
+            split_words = generation["result"]["split_words"]
+            split_tokens = generation["result"]["split_tokens"]
         else:
-            split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]
-
-        frames = [p["most_attended_frames"][0] for p in pr]
-        tokens = tokens.copy()
-        ret = []
-        for sw,st in zip(split_words, split_tokens):
-            b = None
-            for stt in st:
-                t,f = tokens.pop(0), frames.pop(0)
-                if t != stt:
-                    raise ValueError(f"Token mismatch: {t} != {stt} at frame {f}.")
-                if b is None:
-                    b = f
-                e = f
-            out = (b*0.02, e*0.02, sw)
-            ret.append(out)
-            logger.debug(f"TS-WORD:\t{' '.join(map(str, out))}")
-        return ret
+            split_words, split_tokens = self.model.tokenizer.split_to_word_tokens(tokens)
+        progress = generation["progress"]
+        frames = [p["most_attended_frames"][0] for p in progress]
+        tokens_queue = tokens.copy()
+        timestamped_words = []
+
+        for word, word_tokens in zip(split_words, split_tokens):
+            start_frame = None
+            end_frame = None
+            for expected_token in word_tokens:
+                if not tokens_queue or not frames:
+                    raise ValueError(f"Insufficient tokens or frames for word '{word}'")
+
+                actual_token = tokens_queue.pop(0)
+                current_frame = frames.pop(0)
+                if actual_token != expected_token:
+                    raise ValueError(
+                        f"Token mismatch: expected '{expected_token}', "
+                        f"got '{actual_token}' at frame {current_frame}"
+                    )
+                if start_frame is None:
+                    start_frame = current_frame
+                end_frame = current_frame
+            start_time = start_frame * FRAME_DURATION
+            end_time = end_frame * FRAME_DURATION
+
+            timestamp_entry = (start_time, end_time, word)
+            timestamped_words.append(timestamp_entry)
+            logger.debug(f"TS-WORD:\t{start_time:.2f}\t{end_time:.2f}\t{word}")
+        return timestamped_words
 
     def process_iter(self) -> Tuple[List[ASRToken], float]:
         """
@@ -126,6 +147,8 @@ class SimulStreamingOnlineProcessor:
                     end=end,
                     text=word,
                     probability=0.95 # fake prob. Maybe we can extract it from the model?
+                ).with_offset(
+                    self.global_time_offset
                 )
                 new_tokens.append(token)
 
diff --git a/whisperlivekit/whisper_streaming_custom/online_asr.py b/whisperlivekit/whisper_streaming_custom/online_asr.py
index d575035..1f8a2a0 100644
--- a/whisperlivekit/whisper_streaming_custom/online_asr.py
+++ b/whisperlivekit/whisper_streaming_custom/online_asr.py
@@ -122,6 +122,7 @@ class OnlineASRProcessor:
         self.tokenize = tokenize_method
         self.logfile = logfile
         self.confidence_validation = confidence_validation
+        self.global_time_offset = 0.0
         self.init()
 
         self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
@@ -152,6 +153,17 @@ class OnlineASRProcessor:
         """Append an audio chunk (a numpy array) to the current audio buffer."""
         self.audio_buffer = np.append(self.audio_buffer, audio)
 
+    def insert_silence(self, silence_duration):
+        """
+        Gaps shorter than 3 s are bridged by inserting real silence into the buffer; longer gaps clear the context and are carried by global_time_offset.
+        """
+        if silence_duration < 3:
+            gap_silence = np.zeros(int(self.SAMPLING_RATE * silence_duration), dtype=np.float32)
+            self.insert_audio_chunk(gap_silence)
+        else:
+            self.init(offset=self.buffer_time_offset)
+            self.global_time_offset += silence_duration
+
     def prompt(self) -> Tuple[str, str]:
         """
         Returns a tuple: (prompt, context), where:
@@ -230,6 +242,9 @@ class OnlineASRProcessor:
         logger.debug(
             f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
         )
+        if self.global_time_offset:  # with_offset returns a new token; reassign, don't discard
+            committed_tokens = [t.with_offset(self.global_time_offset)
+                                for t in committed_tokens]
         return committed_tokens, current_audio_processed_upto
 
     def chunk_completed_sentence(self):