diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index 5e74819..213cf52 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -67,6 +67,8 @@ class AudioProcessor: self.is_stopping = False self.silence = False self.silence_duration = 0.0 + self.start_silence = None + self.last_silence_dispatch_time = None self.state = State() self.lock = asyncio.Lock() self.sep = " " # Default separator @@ -128,6 +130,34 @@ class AudioProcessor: if models.translation_model: self.translation = online_translation_factory(self.args, models.translation_model) + async def _push_silence_event(self, silence_buffer: Silence): + if not self.diarization_before_transcription and self.transcription_queue: + await self.transcription_queue.put(silence_buffer) + if self.args.diarization and self.diarization_queue: + await self.diarization_queue.put(silence_buffer) + if self.translation_queue: + await self.translation_queue.put(silence_buffer) + + async def _begin_silence(self): + if self.silence: + return + self.silence = True + now = time() + self.start_silence = now + self.last_silence_dispatch_time = now + await self._push_silence_event(Silence(is_starting=True)) + + async def _end_silence(self): + if not self.silence: + return + now = time() + duration = now - self.last_silence_dispatch_time + await self._push_silence_event(Silence(duration=duration, has_ended=True)) + self.last_silence_dispatch_time = now + self.silence = False + self.start_silence = None + self.last_silence_dispatch_time = None + def convert_pcm_to_float(self, pcm_buffer): """Convert PCM buffer in s16le format to normalized NumPy array.""" return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0 @@ -225,28 +255,42 @@ class AudioProcessor: asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE transcription_lag_s = max(0.0, time() - self.state.beg_loop - self.state.end_buffer) asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |" - if type(item) is Silence: - asr_processing_logs += f" + Silence of = {item.duration:.2f}s" + stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time + new_tokens = [] + current_audio_processed_upto = self.state.end_buffer + + if isinstance(item, Silence): + if item.is_starting: + new_tokens, current_audio_processed_upto = await asyncio.to_thread( + self.transcription.start_silence + ) + asr_processing_logs += f" + Silence starting" + if item.has_ended: + asr_processing_logs += f" + Silence of = {item.duration:.2f}s" + cumulative_pcm_duration_stream_time += item.duration + current_audio_processed_upto = cumulative_pcm_duration_stream_time + self.transcription.end_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0) if self.state.tokens: asr_processing_logs += f" | last_end = {self.state.tokens[-1].end} |" logger.info(asr_processing_logs) - cumulative_pcm_duration_stream_time += item.duration - self.transcription.insert_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0) - continue + new_tokens = new_tokens or [] + current_audio_processed_upto = max(current_audio_processed_upto, stream_time_end_of_current_pcm) elif isinstance(item, ChangeSpeaker): self.transcription.new_speaker(item) + self.transcription_queue.task_done() + continue elif isinstance(item, np.ndarray): pcm_array = item - - logger.info(asr_processing_logs) - - duration_this_chunk = len(pcm_array) / self.sample_rate - cumulative_pcm_duration_stream_time += duration_this_chunk - stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time + logger.info(asr_processing_logs) + cumulative_pcm_duration_stream_time += len(pcm_array) / self.sample_rate + stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time + self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm) + new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter) + new_tokens = new_tokens or [] + else: + self.transcription_queue.task_done() + continue - self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm) - new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter) - _buffer_transcript = self.transcription.get_buffer() buffer_text = _buffer_transcript.text @@ -304,7 +348,7 @@ class AudioProcessor: logger.debug("Diarization processor received sentinel. Finishing.") self.diarization_queue.task_done() break - elif type(item) is Silence: + elif type(item) is Silence and item.has_ended: diarization_obj.insert_silence(item.duration) continue elif isinstance(item, np.ndarray): @@ -380,7 +424,7 @@ class AudioProcessor: if additional_token is SENTINEL: sentinel_found = True break - elif type(additional_token) is Silence: + elif type(additional_token) is Silence and additional_token.has_ended: self.translation.insert_silence(additional_token.duration) continue else: @@ -640,26 +684,15 @@ class AudioProcessor: self.pcm_buffer = self.pcm_buffer[aligned_chunk_size:] res = None - end_of_audio = False - silence_buffer = None - if self.args.vac: res = self.vac(pcm_array) if res is not None: - if res.get("end", 0) > res.get("start", 0): - end_of_audio = True - elif self.silence: #end of silence - self.silence = False - silence_buffer = Silence(duration=time() - self.start_silence) + if res.get("end", 0) > res.get("start", 0) and not self.silence: + await self._begin_silence() + elif self.silence: + await self._end_silence() - if silence_buffer: - if not self.diarization_before_transcription and self.transcription_queue: - await self.transcription_queue.put(silence_buffer) - if self.args.diarization and self.diarization_queue: - await self.diarization_queue.put(silence_buffer) - if self.translation_queue: - await self.translation_queue.put(silence_buffer) if not self.silence: if not self.diarization_before_transcription and self.transcription_queue: @@ -670,9 +703,5 @@ class AudioProcessor: self.silence_duration = 0.0 - if end_of_audio: - self.silence = True - self.start_silence = time() - if not self.args.transcription and not self.args.diarization: await asyncio.sleep(0.1) diff --git a/whisperlivekit/diarization/sortformer_backend.py b/whisperlivekit/diarization/sortformer_backend.py index 09651f1..d835df0 100644 --- a/whisperlivekit/diarization/sortformer_backend.py +++ b/whisperlivekit/diarization/sortformer_backend.py @@ -160,7 +160,7 @@ class SortformerDiarizationOnline: # Initialize total predictions tensor self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device) - def insert_silence(self, silence_duration: float): + def insert_silence(self, silence_duration: Optional[float]): """ Insert silence period by adjusting the global time offset. diff --git a/whisperlivekit/local_agreement/online_asr.py b/whisperlivekit/local_agreement/online_asr.py index deca515..40a4551 100644 --- a/whisperlivekit/local_agreement/online_asr.py +++ b/whisperlivekit/local_agreement/online_asr.py @@ -151,21 +151,32 @@ class OnlineASRProcessor: """Append an audio chunk (a numpy array) to the current audio buffer.""" self.audio_buffer = np.append(self.audio_buffer, audio) - def insert_silence(self, silence_duration, offset): - """ - If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame - """ - # if self.transcript_buffer.buffer: - # self.committed.extend(self.transcript_buffer.buffer) - # self.transcript_buffer.buffer = [] - - if True: #silence_duration < 3: #we want the last audio to be treated to not have a gap. could also be handled in the future in ends_with_silence. - gap_silence = np.zeros(int(16000 * silence_duration), dtype=np.int16) - self.insert_audio_chunk(gap_silence) + def start_silence(self): + if self.audio_buffer.size == 0: + return [], self.get_audio_buffer_end_time() + return self.process_iter() + + def end_silence(self, silence_duration: Optional[float], offset: float): + if not silence_duration or silence_duration <= 0: + return + + long_silence = silence_duration >= 5 + if not long_silence: + gap_samples = int(self.SAMPLING_RATE * silence_duration) + if gap_samples > 0: + gap_silence = np.zeros(gap_samples, dtype=np.float32) + self.insert_audio_chunk(gap_silence) else: self.init(offset=silence_duration + offset) + self.global_time_offset += silence_duration + def insert_silence(self, silence_duration, offset): + """ + Backwards compatibility shim for legacy callers that still use insert_silence. + """ + self.end_silence(silence_duration, offset) + def prompt(self) -> Tuple[str, str]: """ Returns a tuple: (prompt, context), where: diff --git a/whisperlivekit/simul_whisper/backend.py b/whisperlivekit/simul_whisper/backend.py index 3416898..2c18823 100644 --- a/whisperlivekit/simul_whisper/backend.py +++ b/whisperlivekit/simul_whisper/backend.py @@ -63,16 +63,22 @@ class SimulStreamingOnlineProcessor: fw_encoder=self.asr.fw_encoder, ) - def insert_silence(self, silence_duration, offset): + def start_silence(self): + tokens, processed_upto = self.process_iter(is_last=True) + return tokens, processed_upto + + def end_silence(self, silence_duration, offset): """ If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame """ - if silence_duration < 5: - gap_silence = torch.zeros(int(16000*silence_duration)) - self.model.insert_audio(gap_silence) - # self.global_time_offset += silence_duration - else: - self.process_iter(is_last=True) #we want to totally process what remains in the buffer. + self.end += silence_duration + long_silence = silence_duration >= 5 + if not long_silence: + gap_len = int(16000 * silence_duration) + if gap_len > 0: + gap_silence = torch.zeros(gap_len) + self.model.insert_audio(gap_silence) + if long_silence: self.model.refresh_segment(complete=True) self.model.global_time_offset = silence_duration + offset diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index 6d9aff9..03222b0 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -123,7 +123,9 @@ class Translation(TimedText): @dataclass class Silence(): - duration: float + duration: Optional[float] = None + is_starting: bool = False + has_ended: bool = False @dataclass