diff --git a/docs/alignement_principles.md b/docs/alignement_principles.md new file mode 100644 index 0000000..4ce8fae --- /dev/null +++ b/docs/alignement_principles.md @@ -0,0 +1,71 @@ +### Alignment between STT Tokens and Diarization Segments + +- Example 1: The punctuation from STT and the speaker change from Diarization both come in prediction `t` +- Example 2: The punctuation from STT comes from prediction `t`, but the speaker change from Diarization comes in prediction `t-1` +- Example 3: The punctuation from STT comes from prediction `t-1`, but the speaker change from Diarization comes in prediction `t` + +> `#` is the split between the `t-1` prediction and the `t` prediction. + + +## Example 1: +```text +punctuations_segments : __#_______.__________________!____ +diarization_segments: +SPK1 __#____________ +SPK2 # ___________________ +--> +ALIGNED SPK1 __#_______. +ALIGNED SPK2 # __________________!____ + +t-1 output: +SPK1: __# +SPK2: NO +DIARIZATION BUFFER: NO + +t output: +SPK1: __#__. +SPK2: __________________!____ +DIARIZATION BUFFER: No +``` + +## Example 2: +```text +punctuations_segments : _____#__.___________ +diarization_segments: +SPK1 ___ # +SPK2 __#______________ +--> +ALIGNED SPK1 _____#__. +ALIGNED SPK2 # ___________ + +t-1 output: +SPK1: ___ # +SPK2: +DIARIZATION BUFFER: __# + +t output: +SPK1: __#__. +SPK2: ___________ +DIARIZATION BUFFER: No +``` + +## Example 3: +```text +punctuations_segments : ___.__#__________ +diarization_segments: +SPK1 ______#__ +SPK2 # ________ +--> +ALIGNED SPK1 ___. # +ALIGNED SPK2 __#__________ + +t-1 output: +SPK1: ___. 
# +SPK2: +DIARIZATION BUFFER: __# + +t output: +SPK1: # +SPK2: __#___________ +DIARIZATION BUFFER: NO +``` diff --git a/whisperlivekit/TokensAlignment.py b/whisperlivekit/TokensAlignment.py new file mode 100644 index 0000000..46b2c59 --- /dev/null +++ b/whisperlivekit/TokensAlignment.py @@ -0,0 +1,168 @@ +class TokensAlignment: + + def __init__(self, state_light, silence=None, args=None): + self.state_light = state_light + self.silence = silence + self.args = args + + self._tokens_index = 0 + self._diarization_index = 0 + self._translation_index = 0 + + def update(self): + pass + + + def compute_punctuations_segments(self): + punctuations_breaks = [] + new_tokens = self.state.tokens[self.state.last_validated_token:] + for i in range(len(new_tokens)): + token = new_tokens[i] + if token.is_punctuation(): + punctuations_breaks.append({ + 'token_index': i, + 'token': token, + 'start': token.start, + 'end': token.end, + }) + punctuations_segments = [] + for i, break_info in enumerate(punctuations_breaks): + start = punctuations_breaks[i - 1]['end'] if i > 0 else 0.0 + end = break_info['end'] + punctuations_segments.append({ + 'start': start, + 'end': end, + 'token_index': break_info['token_index'], + 'token': break_info['token'] + }) + return punctuations_segments + + def concatenate_diar_segments(self): + diarization_segments = self.state.diarization_segments + +if __name__ == "__main__": + from whisperlivekit.timed_objects import State, ASRToken, SpeakerSegment, Transcript, Silence + + # Reconstruct the state from the backup data + tokens = [ + ASRToken(start=1.38, end=1.48, text=' The'), + ASRToken(start=1.42, end=1.52, text=' description'), + ASRToken(start=1.82, end=1.92, text=' technology'), + ASRToken(start=2.54, end=2.64, text=' has'), + ASRToken(start=2.7, end=2.8, text=' improved'), + ASRToken(start=3.24, end=3.34, text=' so'), + ASRToken(start=3.66, end=3.76, text=' much'), + ASRToken(start=4.02, end=4.12, text=' in'), + ASRToken(start=4.08, end=4.18, text=' 
the'), + ASRToken(start=4.26, end=4.36, text=' past'), + ASRToken(start=4.48, end=4.58, text=' few'), + ASRToken(start=4.76, end=4.86, text=' years'), + ASRToken(start=5.76, end=5.86, text='.'), + ASRToken(start=5.72, end=5.82, text=' Have'), + ASRToken(start=5.92, end=6.02, text=' you'), + ASRToken(start=6.08, end=6.18, text=' noticed'), + ASRToken(start=6.52, end=6.62, text=' how'), + ASRToken(start=6.8, end=6.9, text=' accurate'), + ASRToken(start=7.46, end=7.56, text=' real'), + ASRToken(start=7.72, end=7.82, text='-time'), + ASRToken(start=8.06, end=8.16, text=' speech'), + ASRToken(start=8.48, end=8.58, text=' to'), + ASRToken(start=8.68, end=8.78, text=' text'), + ASRToken(start=9.0, end=9.1, text=' is'), + ASRToken(start=9.24, end=9.34, text=' now'), + ASRToken(start=9.82, end=9.92, text='?'), + ASRToken(start=9.86, end=9.96, text=' Absolutely'), + ASRToken(start=11.26, end=11.36, text='.'), + ASRToken(start=11.36, end=11.46, text=' I'), + ASRToken(start=11.58, end=11.68, text=' use'), + ASRToken(start=11.78, end=11.88, text=' it'), + ASRToken(start=11.94, end=12.04, text=' all'), + ASRToken(start=12.08, end=12.18, text=' the'), + ASRToken(start=12.32, end=12.42, text=' time'), + ASRToken(start=12.58, end=12.68, text=' for'), + ASRToken(start=12.78, end=12.88, text=' taking'), + ASRToken(start=13.14, end=13.24, text=' notes'), + ASRToken(start=13.4, end=13.5, text=' during'), + ASRToken(start=13.78, end=13.88, text=' meetings'), + ASRToken(start=14.6, end=14.7, text='.'), + ASRToken(start=14.82, end=14.92, text=' It'), + ASRToken(start=14.92, end=15.02, text="'s"), + ASRToken(start=15.04, end=15.14, text=' amazing'), + ASRToken(start=15.5, end=15.6, text=' how'), + ASRToken(start=15.66, end=15.76, text=' it'), + ASRToken(start=15.8, end=15.9, text=' can'), + ASRToken(start=15.96, end=16.06, text=' recognize'), + ASRToken(start=16.58, end=16.68, text=' different'), + ASRToken(start=16.94, end=17.04, text=' speakers'), + ASRToken(start=17.82, end=17.92, 
text=' and'), + ASRToken(start=18.0, end=18.1, text=' even'), + ASRToken(start=18.42, end=18.52, text=' add'), + ASRToken(start=18.74, end=18.84, text=' punct'), + ASRToken(start=19.02, end=19.12, text='uation'), + ASRToken(start=19.68, end=19.78, text='.'), + ASRToken(start=20.04, end=20.14, text=' Yeah'), + ASRToken(start=20.5, end=20.6, text=','), + ASRToken(start=20.6, end=20.7, text=' but'), + ASRToken(start=20.76, end=20.86, text=' sometimes'), + ASRToken(start=21.42, end=21.52, text=' noise'), + ASRToken(start=21.82, end=21.92, text=' can'), + ASRToken(start=22.08, end=22.18, text=' still'), + ASRToken(start=22.38, end=22.48, text=' cause'), + ASRToken(start=22.72, end=22.82, text=' mistakes'), + ASRToken(start=23.74, end=23.84, text='.'), + ASRToken(start=23.96, end=24.06, text=' Does'), + ASRToken(start=24.16, end=24.26, text=' this'), + ASRToken(start=24.4, end=24.5, text=' system'), + ASRToken(start=24.76, end=24.86, text=' handle'), + ASRToken(start=25.12, end=25.22, text=' that'), + ASRToken(start=25.38, end=25.48, text=' well'), + ASRToken(start=25.68, end=25.78, text='?'), + ASRToken(start=26.4, end=26.5, text=' It'), + ASRToken(start=26.5, end=26.6, text=' does'), + ASRToken(start=26.7, end=26.8, text=' a'), + ASRToken(start=27.08, end=27.18, text=' pretty'), + ASRToken(start=27.12, end=27.22, text=' good'), + ASRToken(start=27.34, end=27.44, text=' job'), + ASRToken(start=27.64, end=27.74, text=' filtering'), + ASRToken(start=28.1, end=28.2, text=' noise'), + ASRToken(start=28.64, end=28.74, text=','), + ASRToken(start=28.78, end=28.88, text=' especially'), + ASRToken(start=29.3, end=29.4, text=' with'), + ASRToken(start=29.51, end=29.61, text=' models'), + ASRToken(start=29.99, end=30.09, text=' that'), + ASRToken(start=30.21, end=30.31, text=' use'), + ASRToken(start=30.51, end=30.61, text=' voice'), + ASRToken(start=30.83, end=30.93, text=' activity'), + ] + + diarization_segments = [ + SpeakerSegment(start=1.3255040645599365, 
end=4.3255040645599365, speaker=0), + SpeakerSegment(start=4.806154012680054, end=9.806154012680054, speaker=0), + SpeakerSegment(start=9.806154012680054, end=10.806154012680054, speaker=1), + SpeakerSegment(start=11.168735027313232, end=14.168735027313232, speaker=1), + SpeakerSegment(start=14.41029405593872, end=17.41029405593872, speaker=1), + SpeakerSegment(start=17.52983808517456, end=19.52983808517456, speaker=1), + SpeakerSegment(start=19.64953374862671, end=20.066200415293377, speaker=1), + SpeakerSegment(start=20.066200415293377, end=22.64953374862671, speaker=2), + SpeakerSegment(start=23.012792587280273, end=25.012792587280273, speaker=2), + SpeakerSegment(start=25.495875597000122, end=26.41254226366679, speaker=2), + SpeakerSegment(start=26.41254226366679, end=30.495875597000122, speaker=0), + ] + + state = State( + tokens=tokens, + last_validated_token=72, + last_speaker=-1, + last_punctuation_index=71, + translation_validated_segments=[], + buffer_translation=Transcript(start=0, end=0, speaker=-1), + buffer_transcription=Transcript(start=None, end=None, speaker=-1), + diarization_segments=diarization_segments, + end_buffer=31.21587559700018, + end_attributed_speaker=30.495875597000122, + remaining_time_transcription=0.4, + remaining_time_diarization=0.7, + beg_loop=1763627603.968919 + ) + + alignment = TokensAlignment(state) \ No newline at end of file diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index 5bbba26..353924d 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -4,12 +4,12 @@ from time import time, sleep import math import logging import traceback -from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, Transcript, ChangeSpeaker +from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, StateLight, Transcript, ChangeSpeaker from whisperlivekit.core import TranscriptionEngine, online_factory, 
online_diarization_factory, online_translation_factory from whisperlivekit.silero_vad_iterator import FixedVADIterator from whisperlivekit.results_formater import format_output from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState - +from whisperlivekit.TokensAlignment import TokensAlignment logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -86,21 +86,16 @@ class AudioProcessor: self.start_silence = None self.last_silence_dispatch_time = None self.state = State() + self.state_light = StateLight() self.lock = asyncio.Lock() self.sep = " " # Default separator self.last_response_content = FrontData() self.last_detected_speaker = None self.speaker_languages = {} - self.diarization_before_transcription = False - self.segments = [] - + self.tokens_alignment = TokensAlignment(self.state_light, self.args, self.sep) + self.beg_loop = None - if self.diarization_before_transcription: - self.cumulative_pcm = [] - self.last_start = 0.0 - self.last_end = 0.0 - # Models and processing self.asr = models.asr self.vac_model = models.vac_model @@ -128,7 +123,7 @@ class AudioProcessor: self.translation_queue = asyncio.Queue() if self.args.target_language else None self.pcm_buffer = bytearray() self.total_pcm_samples = 0 - + self.end_buffer = 0.0 self.transcription_task = None self.diarization_task = None self.translation_task = None @@ -148,7 +143,7 @@ class AudioProcessor: self.translation = online_translation_factory(self.args, models.translation_model) async def _push_silence_event(self, silence_buffer: Silence): - if not self.diarization_before_transcription and self.transcription_queue: + if self.transcription_queue: await self.transcription_queue.put(silence_buffer) if self.args.diarization and self.diarization_queue: await self.diarization_queue.put(silence_buffer) @@ -168,7 +163,7 @@ class AudioProcessor: if not self.silence: return now = time() - 
duration = now - (self.last_silence_dispatch_time if self.last_silence_dispatch_time else self.state.beg_loop) + duration = now - (self.last_silence_dispatch_time if self.last_silence_dispatch_time else self.beg_loop) await self._push_silence_event(Silence(duration=duration, has_ended=True)) self.last_silence_dispatch_time = now self.silence = False @@ -178,7 +173,7 @@ class AudioProcessor: async def _enqueue_active_audio(self, pcm_chunk: np.ndarray): if pcm_chunk is None or pcm_chunk.size == 0: return - if not self.diarization_before_transcription and self.transcription_queue: + if self.transcription_queue: await self.transcription_queue.put(pcm_chunk.copy()) if self.args.diarization and self.diarization_queue: await self.diarization_queue.put(pcm_chunk.copy()) @@ -198,15 +193,6 @@ class AudioProcessor: def convert_pcm_to_float(self, pcm_buffer): """Convert PCM buffer in s16le format to normalized NumPy array.""" return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0 - - async def add_dummy_token(self): - """Placeholder token when no transcription is available.""" - async with self.lock: - current_time = time() - self.state.beg_loop - self.state.tokens.append(ASRToken( - start=current_time, end=current_time + 1, - text=".", speaker=-1, is_dummy=True - )) async def get_current_state(self): """Get current state.""" @@ -214,12 +200,12 @@ class AudioProcessor: current_time = time() remaining_transcription = 0 - if self.state.end_buffer > 0: - remaining_transcription = max(0, round(current_time - self.state.beg_loop - self.state.end_buffer, 1)) + if self.end_buffer > 0: + remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1)) remaining_diarization = 0 if self.state.tokens: - latest_end = max(self.state.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0) + latest_end = max(self.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0) remaining_diarization = max(0, round(latest_end - 
self.state.end_attributed_speaker, 1)) self.state.remaining_time_transcription = remaining_transcription @@ -270,7 +256,7 @@ class AudioProcessor: await asyncio.sleep(0.2) logger.info("FFmpeg stdout processing finished. Signaling downstream processors if needed.") - if not self.diarization_before_transcription and self.transcription_queue: + if self.transcription_queue: await self.transcription_queue.put(SENTINEL) if self.diarization: await self.diarization_queue.put(SENTINEL) @@ -290,11 +276,11 @@ class AudioProcessor: break asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE - transcription_lag_s = max(0.0, time() - self.state.beg_loop - self.state.end_buffer) + transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer) asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |" stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time new_tokens = [] - current_audio_processed_upto = self.state.end_buffer + current_audio_processed_upto = self.end_buffer if isinstance(item, Silence): if item.is_starting: @@ -314,7 +300,6 @@ class AudioProcessor: current_audio_processed_upto = max(current_audio_processed_upto, stream_time_end_of_current_pcm) elif isinstance(item, ChangeSpeaker): self.transcription.new_speaker(item) - # self.transcription_queue.task_done() continue elif isinstance(item, np.ndarray): pcm_array = item @@ -324,9 +309,6 @@ class AudioProcessor: self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm) new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter) new_tokens = new_tokens or [] - else: - continue - return _buffer_transcript = self.transcription.get_buffer() buffer_text = _buffer_transcript.text @@ -336,7 +318,7 @@ class AudioProcessor: if buffer_text.startswith(validated_text): _buffer_transcript.text = 
buffer_text[len(validated_text):].lstrip() - candidate_end_times = [self.state.end_buffer] + candidate_end_times = [self.end_buffer] if new_tokens: candidate_end_times.append(new_tokens[-1].end) @@ -349,8 +331,11 @@ class AudioProcessor: async with self.lock: self.state.tokens.extend(new_tokens) self.state.buffer_transcription = _buffer_transcript - self.state.end_buffer = max(candidate_end_times) - + self.end_buffer = max(candidate_end_times) + self.state_light.new_tokens = new_tokens + self.state_light.new_tokens += 1 + self.state_light.new_tokens_buffer = _buffer_transcript + if self.translation_queue: for token in new_tokens: await self.translation_queue.put(token) @@ -370,59 +355,21 @@ class AudioProcessor: logger.info("Transcription processor task finished.") - async def diarization_processor(self, diarization_obj): - """Process audio chunks for speaker diarization.""" - if self.diarization_before_transcription: - self.current_speaker = 0 - await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=0.0)) + async def diarization_processor(self): while True: try: item = await get_all_from_queue(self.diarization_queue) if item is SENTINEL: - logger.debug("Diarization processor received sentinel. 
Finishing.") break elif type(item) is Silence: if item.has_ended: - diarization_obj.insert_silence(item.duration) + self.diarization.insert_silence(item.duration) continue - elif isinstance(item, np.ndarray): - pcm_array = item - else: - raise Exception('item should be pcm_array') - - - - # Process diarization - await diarization_obj.diarize(pcm_array) - if self.diarization_before_transcription: - segments = diarization_obj.get_segments() - self.cumulative_pcm.append(pcm_array) - if segments: - last_segment = segments[-1] - if last_segment.speaker != self.current_speaker: - cut_sec = last_segment.start - self.last_end - to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec) - await self.transcription_queue.put(to_transcript) - - self.current_speaker = last_segment.speaker - await self.transcription_queue.put(ChangeSpeaker(speaker=self.current_speaker, start=last_segment.start)) - - cut_sec = last_segment.end - last_segment.start - to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec) - await self.transcription_queue.put(to_transcript) - self.last_start = last_segment.start - self.last_end = last_segment.end - else: - cut_sec = last_segment.end - self.last_end - to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec) - await self.transcription_queue.put(to_transcript) - self.last_end = last_segment.end - elif not self.diarization_before_transcription: - segments = diarization_obj.get_segments() - async with self.lock: - self.state.speaker_segments = segments.copy() - if segments: - self.state.end_attributed_speaker = max(seg.end for seg in segments) + + self.diarization.insert_audio_chunk(item) + diarization_segments = await self.diarization.diarize() + self.state_light.new_diarization = diarization_segments + self.state_light.new_diarization_index += 1 except Exception as e: logger.warning(f"Exception in diarization_processor: {e}") @@ -466,7 +413,7 @@ class AudioProcessor: continue state = await 
self.get_current_state() - + self.tokens_alignment.compute_punctuations_segments() lines, undiarized_text = format_output( state, self.silence, @@ -553,7 +500,7 @@ class AudioProcessor: processing_tasks_for_watchdog.append(self.transcription_task) if self.diarization: - self.diarization_task = asyncio.create_task(self.diarization_processor(self.diarization)) + self.diarization_task = asyncio.create_task(self.diarization_processor()) self.all_tasks_for_cleanup.append(self.diarization_task) processing_tasks_for_watchdog.append(self.diarization_task) @@ -632,8 +579,8 @@ class AudioProcessor: async def process_audio(self, message): """Process incoming audio data.""" - if not self.state.beg_loop: - self.state.beg_loop = time() + if not self.beg_loop: + self.beg_loop = time() if not message: logger.info("Empty audio message received, initiating stop sequence.") diff --git a/whisperlivekit/diarization/diart_backend.py b/whisperlivekit/diarization/diart_backend.py index 55a7d3c..0525973 100644 --- a/whisperlivekit/diarization/diart_backend.py +++ b/whisperlivekit/diarization/diart_backend.py @@ -26,7 +26,7 @@ class DiarizationObserver(Observer): """Observer that logs all data emitted by the diarization pipeline and stores speaker segments.""" def __init__(self): - self.speaker_segments = [] + self.diarization_segments = [] self.processed_time = 0 self.segment_lock = threading.Lock() self.global_time_offset = 0.0 @@ -48,7 +48,7 @@ class DiarizationObserver(Observer): for speaker, label in annotation._labels.items(): for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]): print(f" {speaker}: {start:.2f}s-{end:.2f}s") - self.speaker_segments.append(SpeakerSegment( + self.diarization_segments.append(SpeakerSegment( speaker=speaker, start=start + self.global_time_offset, end=end + self.global_time_offset @@ -59,14 +59,14 @@ class DiarizationObserver(Observer): def get_segments(self) -> List[SpeakerSegment]: """Get a copy of the current speaker 
segments.""" with self.segment_lock: - return self.speaker_segments.copy() + return self.diarization_segments.copy() def clear_old_segments(self, older_than: float = 30.0): """Clear segments older than the specified time.""" with self.segment_lock: current_time = self.processed_time - self.speaker_segments = [ - segment for segment in self.speaker_segments + self.diarization_segments = [ + segment for segment in self.diarization_segments if current_time - segment.end < older_than ] diff --git a/whisperlivekit/diarization/sortformer_backend.py b/whisperlivekit/diarization/sortformer_backend.py index 4d79156..4fb4627 100644 --- a/whisperlivekit/diarization/sortformer_backend.py +++ b/whisperlivekit/diarization/sortformer_backend.py @@ -94,11 +94,11 @@ class SortformerDiarizationOnline: model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2") """ self.sample_rate = sample_rate - self.speaker_segments = [] + self.diarization_segments = [] + self.diar_segments = [] self.buffer_audio = np.array([], dtype=np.float32) self.segment_lock = threading.Lock() self.global_time_offset = 0.0 - self.processed_time = 0.0 self.debug = False self.diar_model = shared_model.diar_model @@ -155,9 +155,7 @@ class SortformerDiarizationOnline: ) self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device) self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device) - self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device) - - # Initialize total predictions tensor + self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device) self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device) def insert_silence(self, silence_duration: Optional[float]): @@ -171,135 +169,111 @@ class SortformerDiarizationOnline: self.global_time_offset += silence_duration 
logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s") - async def diarize(self, pcm_array: np.ndarray): + def insert_audio_chunk(self, pcm_array: np.ndarray): + if self.debug: + self.audio_buffer.append(pcm_array.copy()) + self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()]) + + + async def diarize(self): """ Process audio data for diarization in streaming fashion. Args: pcm_array: Audio data as numpy array """ - try: - if self.debug: - self.audio_buffer.append(pcm_array.copy()) - threshold = int(self.chunk_duration_seconds * self.sample_rate) + threshold = int(self.chunk_duration_seconds * self.sample_rate) + + if not len(self.buffer_audio) >= threshold: + return [] + + audio = self.buffer_audio[:threshold] + self.buffer_audio = self.buffer_audio[threshold:] + + device = self.diar_model.device + audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0) + audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device) + + processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features( + audio_signal_chunk, audio_signal_length_chunk + ) + processed_signal_chunk = processed_signal_chunk.to(device) + processed_signal_length_chunk = processed_signal_length_chunk.to(device) + + if self._previous_chunk_features is not None: + to_add = self._previous_chunk_features[:, :, -99:].to(device) + total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device) + else: + total_features = processed_signal_chunk.to(device) + + self._previous_chunk_features = processed_signal_chunk.to(device) + + chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device) + + with torch.inference_mode(): + left_offset = 8 if self._chunk_index > 0 else 0 + right_offset = 8 - self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()]) - if not len(self.buffer_audio) >= threshold: - return - - audio = self.buffer_audio[:threshold] - 
self.buffer_audio = self.buffer_audio[threshold:] - - device = self.diar_model.device - audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0) - audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device) - - processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features( - audio_signal_chunk, audio_signal_length_chunk - ) - processed_signal_chunk = processed_signal_chunk.to(device) - processed_signal_length_chunk = processed_signal_length_chunk.to(device) - - if self._previous_chunk_features is not None: - to_add = self._previous_chunk_features[:, :, -99:].to(device) - total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device) - else: - total_features = processed_signal_chunk.to(device) - - self._previous_chunk_features = processed_signal_chunk.to(device) - - chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device) - - with torch.inference_mode(): - left_offset = 8 if self._chunk_index > 0 else 0 - right_offset = 8 - - self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step( - processed_signal=chunk_feat_seq_t, - processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device), - streaming_state=self.streaming_state, - total_preds=self.total_preds, - left_offset=left_offset, - right_offset=right_offset, - ) - - # Convert predictions to speaker segments - self._process_predictions() - - self._chunk_index += 1 - - except Exception as e: - logger.error(f"Error in diarize: {e}") - raise - - # TODO: Handle case when stream ends with partial buffer (accumulated_duration > 0 but < chunk_duration_seconds) + self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step( + processed_signal=chunk_feat_seq_t, + processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device), + streaming_state=self.streaming_state, + total_preds=self.total_preds, + left_offset=left_offset, + right_offset=right_offset, + ) + 
new_segments = self._process_predictions() + + self._chunk_index += 1 + return new_segments def _process_predictions(self): """Process model predictions and convert to speaker segments.""" - try: - preds_np = self.total_preds[0].cpu().numpy() - active_speakers = np.argmax(preds_np, axis=1) - - if self._len_prediction is None: - self._len_prediction = len(active_speakers) - - # Get predictions for current chunk - frame_duration = self.chunk_duration_seconds / self._len_prediction - current_chunk_preds = active_speakers[-self._len_prediction:] - - with self.segment_lock: - # Process predictions into segments - base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset - - for idx, spk in enumerate(current_chunk_preds): - start_time = base_time + idx * frame_duration - end_time = base_time + (idx + 1) * frame_duration - - # Check if this continues the last segment or starts a new one - if (self.speaker_segments and - self.speaker_segments[-1].speaker == spk and - abs(self.speaker_segments[-1].end - start_time) < frame_duration * 0.5): - # Continue existing segment - self.speaker_segments[-1].end = end_time - else: - - # Create new segment - self.speaker_segments.append(SpeakerSegment( - speaker=spk, - start=start_time, - end=end_time - )) - - # Update processed time - self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds) - - logger.debug(f"Processed chunk {self._chunk_index}, total segments: {len(self.speaker_segments)}") - - except Exception as e: - logger.error(f"Error processing predictions: {e}") - + preds_np = self.total_preds[0].cpu().numpy() + active_speakers = np.argmax(preds_np, axis=1) + + if self._len_prediction is None: + self._len_prediction = len(active_speakers) #12 + + frame_duration = self.chunk_duration_seconds / self._len_prediction + current_chunk_preds = active_speakers[-self._len_prediction:] + + new_segments = [] + with self.segment_lock: + base_time = self._chunk_index * 
self.chunk_duration_seconds + self.global_time_offset + current_spk = current_chunk_preds[0] + start_time = round(base_time, 2) + for idx, spk in enumerate(current_chunk_preds): + current_time = round(base_time + idx * frame_duration, 2) + if spk != current_spk: + new_segments.append(SpeakerSegment( + speaker=current_spk, + start=start_time, + end=current_time + )) + start_time = current_time + current_spk = spk + new_segments.append( + SpeakerSegment( + speaker=current_spk, + start=start_time, + end=current_time + ) + ) + return new_segments + def get_segments(self) -> List[SpeakerSegment]: """Get a copy of the current speaker segments.""" with self.segment_lock: - return self.speaker_segments.copy() - - def clear_old_segments(self, older_than: float = 30.0): - """Clear segments older than the specified time.""" - with self.segment_lock: - current_time = self.processed_time - self.speaker_segments = [ - segment for segment in self.speaker_segments - if current_time - segment.end < older_than - ] - logger.debug(f"Cleared old segments, remaining: {len(self.speaker_segments)}") + return self.diarization_segments.copy() def close(self): """Close the diarization system and clean up resources.""" logger.info("Closing SortformerDiarization") with self.segment_lock: - self.speaker_segments.clear() + self.diarization_segments.clear() if self.debug: concatenated_audio = np.concatenate(self.audio_buffer) @@ -325,7 +299,7 @@ if __name__ == '__main__': async def main(): """TEST ONLY.""" - an4_audio = 'audio_test.mp3' + an4_audio = 'diarization_audio.wav' signal, sr = librosa.load(an4_audio, sr=16000) signal = signal[:16000*30] @@ -337,13 +311,15 @@ if __name__ == '__main__': print("Speaker 0: 0:25 - 0:30") print("=" * 50) - diarization = SortformerDiarization(sample_rate=16000) + diarization_backend = SortformerDiarization() + diarization = SortformerDiarizationOnline(shared_model = diarization_backend) chunk_size = 1600 for i in range(0, len(signal), chunk_size): chunk = 
signal[i:i+chunk_size] - await diarization.diarize(chunk) + new_segments = await diarization.diarize(chunk) print(f"Processed chunk {i // chunk_size + 1}") + print(new_segments) segments = diarization.get_segments() print("\nDiarization results:") diff --git a/whisperlivekit/diarization/sortformer_backend_offline.py b/whisperlivekit/diarization/sortformer_backend_offline.py deleted file mode 100644 index 2619154..0000000 --- a/whisperlivekit/diarization/sortformer_backend_offline.py +++ /dev/null @@ -1,205 +0,0 @@ -import numpy as np -import torch -import logging - -from nemo.collections.asr.models import SortformerEncLabelModel -from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor -import librosa - -logger = logging.getLogger(__name__) - -def load_model(): - - diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2") - diar_model.eval() - - if torch.cuda.is_available(): - diar_model.to(torch.device("cuda")) - - #we target 1 second lag for the moment. chunk_len could be reduced. - diar_model.sortformer_modules.chunk_len = 10 - diar_model.sortformer_modules.subsampling_factor = 10 #8 would be better ideally - - diar_model.sortformer_modules.chunk_right_context = 0 #no. - diar_model.sortformer_modules.chunk_left_context = 10 #big so it compensiate the problem with no padding later. - - diar_model.sortformer_modules.spkcache_len = 188 - diar_model.sortformer_modules.fifo_len = 188 - diar_model.sortformer_modules.spkcache_update_period = 144 - diar_model.sortformer_modules.log = False - diar_model.sortformer_modules._check_streaming_parameters() - - - audio2mel = AudioToMelSpectrogramPreprocessor( - window_size= 0.025, - normalize="NA", - n_fft=512, - features=128, - pad_to=0) #pad_to 16 works better than 0. On test audio, we detect a third speaker for 1 second with pad_to=0. To solve that : increase left context to 10. 
- - return diar_model, audio2mel - -diar_model, audio2mel = load_model() - -class StreamingSortformerState: - """ - This class creates a class instance that will be used to store the state of the - streaming Sortformer model. - - Attributes: - spkcache (torch.Tensor): Speaker cache to store embeddings from start - spkcache_lengths (torch.Tensor): Lengths of the speaker cache - spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts - fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks - fifo_lengths (torch.Tensor): Lengths of the FIFO queue - fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts - spk_perm (torch.Tensor): Speaker permutation information for the speaker cache - mean_sil_emb (torch.Tensor): Mean silence embedding - n_sil_frames (torch.Tensor): Number of silence frames - """ - - spkcache = None # Speaker cache to store embeddings from start - spkcache_lengths = None # - spkcache_preds = None # speaker cache predictions - fifo = None # to save the embedding from the latest chunks - fifo_lengths = None - fifo_preds = None - spk_perm = None - mean_sil_emb = None - n_sil_frames = None - - -def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None): - """ - Initializes StreamingSortformerState with empty tensors or zero-valued tensors. 
- - Args: - batch_size (int): Batch size for tensors in streaming state - async_streaming (bool): True for asynchronous update, False for synchronous update - device (torch.device): Device for tensors in streaming state - - Returns: - streaming_state (SortformerStreamingState): initialized streaming state - """ - streaming_state = StreamingSortformerState() - if async_streaming: - streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device) - streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device) - streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device) - streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device) - streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device) - else: - streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device) - streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device) - streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device) - streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device) - return streaming_state - - -def process_diarization(chunks): - """ - what it does: - 1. Preprocessing: Applies dithering and pre-emphasis (high-pass filter) if enabled - 2. STFT: Computes the Short-Time Fourier Transform using: - - the window of window_size=0.025 --> size of a window : 400 samples - - the hop parameter : n_window_stride = 0.01 -> every 160 samples, a new window - 3. Magnitude Calculation: Converts complex STFT output to magnitude spectrogram - 4. Mel Conversion: Applies Mel filterbanks (128 filters in this case) to get Mel spectrogram - 5. Logarithm: Takes the log of the Mel spectrogram (if `log=True`) - 6. Normalization: Skips normalization since `normalize="NA"` - 7. 
Padding: Pads the time dimension to a multiple of `pad_to` (default 16) - """ - previous_chunk = None - l_chunk_feat_seq_t = [] - for chunk in chunks: - audio_signal_chunk = torch.tensor(chunk).unsqueeze(0).to(diar_model.device) - audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(diar_model.device) - processed_signal_chunk, processed_signal_length_chunk = audio2mel.get_features(audio_signal_chunk, audio_signal_length_chunk) - if previous_chunk is not None: - to_add = previous_chunk[:, :, -99:] - total = torch.concat([to_add, processed_signal_chunk], dim=2) - else: - total = processed_signal_chunk - previous_chunk = processed_signal_chunk - l_chunk_feat_seq_t.append(torch.transpose(total, 1, 2)) - - batch_size = 1 - streaming_state = init_streaming_state(diar_model.sortformer_modules, - batch_size = batch_size, - async_streaming = True, - device = diar_model.device - ) - total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device) - - chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride - - l_speakers = [ - {'start_time': 0, - 'end_time': 0, - 'speaker': 0 - } - ] - len_prediction = None - left_offset = 0 - right_offset = 8 - for i, chunk_feat_seq_t in enumerate(l_chunk_feat_seq_t): - with torch.inference_mode(): - streaming_state, total_preds = diar_model.forward_streaming_step( - processed_signal=chunk_feat_seq_t, - processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]), - streaming_state=streaming_state, - total_preds=total_preds, - left_offset=left_offset, - right_offset=right_offset, - ) - left_offset = 8 - preds_np = total_preds[0].cpu().numpy() - active_speakers = np.argmax(preds_np, axis=1) - if len_prediction is None: - len_prediction = len(active_speakers) # we want to get the len of 1 prediction - frame_duration = chunk_duration_seconds / len_prediction - 
active_speakers = active_speakers[-len_prediction:] - for idx, spk in enumerate(active_speakers): - if spk != l_speakers[-1]['speaker']: - l_speakers.append( - {'start_time': (i * chunk_duration_seconds + idx * frame_duration), - 'end_time': (i * chunk_duration_seconds + (idx + 1) * frame_duration), - 'speaker': spk - }) - else: - l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration - - - """ - Should print - [{'start_time': 0, 'end_time': 8.72, 'speaker': 0}, - {'start_time': 8.72, 'end_time': 18.88, 'speaker': 1}, - {'start_time': 18.88, 'end_time': 24.96, 'speaker': 2}, - {'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}] - """ - for speaker in l_speakers: - print(f"Speaker {speaker['speaker']}: {speaker['start_time']:.2f}s - {speaker['end_time']:.2f}s") - - -if __name__ == '__main__': - - an4_audio = 'audio_test.mp3' - signal, sr = librosa.load(an4_audio, sr=16000) - signal = signal[:16000*30] - # signal = signal[:-(len(signal)%16000)] - - print("\n" + "=" * 50) - print("Expected ground truth:") - print("Speaker 0: 0:00 - 0:09") - print("Speaker 1: 0:09 - 0:19") - print("Speaker 2: 0:19 - 0:25") - print("Speaker 0: 0:25 - 0:30") - print("=" * 50) - - chunk_size = 16000 # 1 second - chunks = [] - for i in range(0, len(signal), chunk_size): - chunk = signal[i:i+chunk_size] - chunks.append(chunk) - - process_diarization(chunks) \ No newline at end of file diff --git a/whisperlivekit/remove_silences.py b/whisperlivekit/remove_silences.py index 785c793..368d34c 100644 --- a/whisperlivekit/remove_silences.py +++ b/whisperlivekit/remove_silences.py @@ -36,7 +36,6 @@ def blank_to_silence(tokens): start=token.start, end=token.end, speaker=-2, - probability=0.95 ) else: if silence_token: #there was silence but no more @@ -70,7 +69,6 @@ def no_token_to_silence(tokens): start=last_end, end=token.start, speaker=-2, - probability=0.95 ) new_tokens.append(silence_token) @@ -90,7 +88,6 @@ def ends_with_silence(tokens, beg_loop, 
vac_detected_silence): start=tokens[-1].end, end=current_time, speaker=-2, - probability=0.95 ) ) return tokens diff --git a/whisperlivekit/results_formater.py b/whisperlivekit/results_formater.py index df83601..07cbbb5 100644 --- a/whisperlivekit/results_formater.py +++ b/whisperlivekit/results_formater.py @@ -159,14 +159,14 @@ def format_output(state, silence, args, sep): tokens = handle_silences(tokens, state.beg_loop, silence) # Assign speakers to tokens based on segments stored in state - if diarization and state.speaker_segments: + if False and diarization and state.diarization_segments: use_punctuation_split = args.punctuation_split if hasattr(args, 'punctuation_split') else False - tokens = assign_speakers_to_tokens(tokens, state.speaker_segments, use_punctuation_split=use_punctuation_split) + tokens = assign_speakers_to_tokens(tokens, state.diarization_segments, use_punctuation_split=use_punctuation_split) for i in range(last_validated_token, len(tokens)): token = tokens[i] speaker = int(token.speaker) token.corrected_speaker = speaker - if not diarization: + if True or not diarization: if speaker == -1: #Speaker -1 means no attributed by diarization. 
In the frontend, it should appear under 'Speaker 1' token.corrected_speaker = 1 token.validated_speaker = True diff --git a/whisperlivekit/simul_whisper/simul_whisper.py b/whisperlivekit/simul_whisper/simul_whisper.py index 692639d..0a7d5e3 100644 --- a/whisperlivekit/simul_whisper/simul_whisper.py +++ b/whisperlivekit/simul_whisper/simul_whisper.py @@ -641,10 +641,9 @@ class AlignAtt: timestamp_idx += len(word_tokens) timestamp_entry = ASRToken( - start=current_timestamp, - end=current_timestamp + 0.1, + start=round(current_timestamp, 2), + end=round(current_timestamp + 0.1, 2), text= word, - probability=0.95, speaker=self.speaker, detected_language=self.detected_language ).with_offset( diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index 66e1b14..69f1542 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -8,15 +8,15 @@ def format_time(seconds: float) -> str: """Format seconds as HH:MM:SS.""" return str(timedelta(seconds=int(seconds))) - @dataclass -class TimedText: +class Timed: start: Optional[float] = 0 end: Optional[float] = 0 + +@dataclass +class TimedText(Timed): text: Optional[str] = '' speaker: Optional[int] = -1 - probability: Optional[float] = None - is_dummy: Optional[bool] = False detected_language: Optional[str] = None def is_punctuation(self): @@ -51,7 +51,7 @@ class ASRToken(TimedText): def with_offset(self, offset: float) -> "ASRToken": """Return a new token with the time offset added.""" - return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability, detected_language=self.detected_language) + return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language) @dataclass class Sentence(TimedText): @@ -72,21 +72,21 @@ class Transcript(TimedText): ) -> "Transcript": sep = sep if sep is not None else ' ' text = sep.join(token.text for token in tokens) - probability = sum(token.probability for 
token in tokens if token.probability) / len(tokens) if tokens else None if tokens: start = offset + tokens[0].start end = offset + tokens[-1].end else: start = None end = None - return cls(start, end, text, probability=probability) + return cls(start, end, text) @dataclass -class SpeakerSegment(TimedText): +class SpeakerSegment(Timed): """Represents a segment of audio attributed to a specific speaker. No text nor probability is associated with this segment. """ + speaker: Optional[int] = -1 pass @dataclass @@ -185,9 +185,19 @@ class State(): translation_validated_segments: list = field(default_factory=list) buffer_translation: str = field(default_factory=Transcript) buffer_transcription: str = field(default_factory=Transcript) - speaker_segments: list = field(default_factory=list) + diarization_segments: list = field(default_factory=list) end_buffer: float = 0.0 end_attributed_speaker: float = 0.0 remaining_time_transcription: float = 0.0 remaining_time_diarization: float = 0.0 - beg_loop: Optional[int] = None + + +@dataclass +class StateLight(): + new_tokens: list = field(default_factory=list) + new_translation: list = field(default_factory=list) + new_diarization: list = field(default_factory=list) + new_tokens_buffer: list = field(default_factory=list) #only when local agreement + new_tokens_index = 0 + new_translation_index = 0 + new_diarization_index = 0