diff --git a/whisperlivekit/whisper_streaming_custom/backends.py b/whisperlivekit/whisper_streaming_custom/backends.py index 91c11a7..32b5044 100644 --- a/whisperlivekit/whisper_streaming_custom/backends.py +++ b/whisperlivekit/whisper_streaming_custom/backends.py @@ -425,8 +425,10 @@ class SimulStreamingASR(ASRBase): # We dont have word-level timestamps here. 1rst approach, should be improved later. words = text.strip().split() if not words: - return tokens - duration_per_word = 0.5 # rough estimate of 0.5 seconds per word... Not so great + return tokens + + duration_per_word = 0.1 # this will be modified based on actual audio duration + #with the SimulStreamingOnlineProcessor for i, word in enumerate(words): start_time = i * duration_per_word diff --git a/whisperlivekit/whisper_streaming_custom/online_asr.py b/whisperlivekit/whisper_streaming_custom/online_asr.py index bc7c69f..a17e94b 100644 --- a/whisperlivekit/whisper_streaming_custom/online_asr.py +++ b/whisperlivekit/whisper_streaming_custom/online_asr.py @@ -544,10 +544,12 @@ class SimulStreamingOnlineProcessor: self.beg = self.offset self.end = self.offset self.cumulative_audio_duration = 0.0 + self.last_audio_stream_end_time = self.offset self.committed: List[ASRToken] = [] self.last_result_tokens: List[ASRToken] = [] self.buffer_content = "" + self.processed_audio_duration = 0.0 def get_audio_buffer_end_time(self) -> float: """Returns the absolute end time of the current audio buffer.""" @@ -565,7 +567,12 @@ class SimulStreamingOnlineProcessor: # Update timing chunk_duration = len(audio) / self.SAMPLING_RATE self.cumulative_audio_duration += chunk_duration - self.end = self.offset + self.cumulative_audio_duration + + if audio_stream_end_time is not None: + self.last_audio_stream_end_time = audio_stream_end_time + self.end = audio_stream_end_time + else: + self.end = self.offset + self.cumulative_audio_duration def prompt(self) -> Tuple[str, str]: """ @@ -578,9 +585,10 @@ class SimulStreamingOnlineProcessor: """ Get the unvalidated buffer content. """ + buffer_end = self.end if hasattr(self, 'end') else None return Transcript( start=None, - end=None, + end=buffer_end, text=self.buffer_content, probability=None ) @@ -601,12 +609,13 @@ class SimulStreamingOnlineProcessor: else: audio = torch.cat(self.audio_chunks, dim=0) - if audio.shape[0] > 0: - self.end = self.offset + (audio.shape[0] / self.SAMPLING_RATE) + audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0 + self.processed_audio_duration += audio_duration self.audio_chunks = [] - logger.debug(f"SimulStreaming processing audio shape: {audio.shape}") + logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s") + logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s") result = self.asr.model.infer(audio, is_last=self.is_last) @@ -624,27 +633,33 @@ class SimulStreamingOnlineProcessor: words = decoded_text.strip().split() new_tokens = [] - current_time = self.beg - word_duration = 0.3 # Not great should be improved. - - for word in words: - token_start = current_time - token_end = current_time + word_duration - token = ASRToken( - start=token_start, - end=token_end, - text=word, - probability=0.95 # fake prob. Maybe we can extract it from the model? - ) - new_tokens.append(token) - current_time = token_end + num_words = len(words) + if num_words > 0: + # distribute words evenly across the processed audio duration + # we NEED that for when we use diarization. Even if that s not perfect + start_time = self.end - audio_duration + time_per_word = audio_duration / num_words if num_words > 1 else audio_duration + + for i, word in enumerate(words): + token_start = start_time + (i * time_per_word) + token_end = start_time + ((i + 1) * time_per_word) + + token_end = min(token_end, self.end) + + token = ASRToken( + start=token_start, + end=token_end, + text=word, + probability=0.95 # fake prob. Maybe we can extract it from the model? + ) + new_tokens.append(token) self.beg = self.end self.committed.extend(new_tokens) self.last_result_tokens = new_tokens - logger.debug(f"SimulStreaming generated {len(new_tokens)} tokens") + logger.debug(f"SimulStreaming generated {len(new_tokens)} tokens with end time: {self.end:.2f}s") return new_tokens, self.end return [], self.end