diff --git a/whisperlivekit/diarization/sortformer_backend.py b/whisperlivekit/diarization/sortformer_backend.py
index c3298aa..5115c91 100644
--- a/whisperlivekit/diarization/sortformer_backend.py
+++ b/whisperlivekit/diarization/sortformer_backend.py
@@ -8,138 +8,4 @@ logger = logging.getLogger(__name__)
 try:
     from nemo.collections.asr.models import SortformerEncLabelModel
 except ImportError:
-    raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
-
-class SortformerDiarization:
-    def __init__(self, model_name="nvidia/diar_streaming_sortformer_4spk-v2"):
-        self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
-        self.diar_model.eval()
-
-        if torch.cuda.is_available():
-            self.diar_model.to(torch.device("cuda"))
-
-        # Streaming parameters for speed
-        self.diar_model.sortformer_modules.chunk_len = 12
-        self.diar_model.sortformer_modules.chunk_right_context = 1
-        self.diar_model.sortformer_modules.spkcache_len = 188
-        self.diar_model.sortformer_modules.fifo_len = 188
-        self.diar_model.sortformer_modules.spkcache_update_period = 144
-        self.diar_model.sortformer_modules.log = False
-        self.diar_model.sortformer_modules._check_streaming_parameters()
-
-        self.batch_size = 1
-        self.processed_signal_offset = torch.zeros((self.batch_size,), dtype=torch.long, device=self.diar_model.device)
-
-        self.audio_buffer = np.array([], dtype=np.float32)
-        self.sample_rate = 16000
-        self.speaker_segments = []
-
-        self.streaming_state = self.diar_model.sortformer_modules.init_streaming_state(
-            batch_size=self.batch_size,
-            async_streaming=True,
-            device=self.diar_model.device
-        )
-        self.total_preds = torch.zeros((self.batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=self.diar_model.device)
-
-
-    def _prepare_audio_signal(self, signal):
-        audio_signal = torch.tensor(signal).unsqueeze(0).to(self.diar_model.device)
-        audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(self.diar_model.device)
-        processed_signal, processed_signal_length = self.diar_model.preprocessor(input_signal=audio_signal, length=audio_signal_length)
-        return processed_signal, processed_signal_length
-
-    def _create_streaming_loader(self, processed_signal, processed_signal_length):
-        streaming_loader = self.diar_model.sortformer_modules.streaming_feat_loader(
-            feat_seq=processed_signal,
-            feat_seq_length=processed_signal_length,
-            feat_seq_offset=self.processed_signal_offset,
-        )
-        return streaming_loader
-
-    async def diarize(self, pcm_array: np.ndarray):
-        """
-        Process an incoming audio chunk for diarization.
-        """
-        self.audio_buffer = np.concatenate([self.audio_buffer, pcm_array])
-
-        # Process in fixed-size chunks (e.g., 1 second)
-        chunk_size = self.sample_rate  # 1 second of audio
-
-        while len(self.audio_buffer) >= chunk_size:
-            chunk_to_process = self.audio_buffer[:chunk_size]
-            self.audio_buffer = self.audio_buffer[chunk_size:]
-
-            processed_signal, processed_signal_length = self._prepare_audio_signal(chunk_to_process)
-
-            current_offset_seconds = self.processed_signal_offset.item() * self.diar_model.preprocessor._cfg.window_stride
-
-            streaming_loader = self._create_streaming_loader(processed_signal, processed_signal_length)
-
-            frame_duration_s = self.diar_model.sortformer_modules.subsampling_factor * self.diar_model.preprocessor._cfg.window_stride
-            chunk_duration_seconds = self.diar_model.sortformer_modules.chunk_len * frame_duration_s
-
-            for i, chunk_feat_seq_t, feat_lengths, left_offset, right_offset in streaming_loader:
-                with torch.inference_mode():
-                    self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
-                        processed_signal=chunk_feat_seq_t,
-                        processed_signal_length=feat_lengths,
-                        streaming_state=self.streaming_state,
-                        total_preds=self.total_preds,
-                        left_offset=left_offset,
-                        right_offset=right_offset,
-                    )
-
-                num_new_frames = feat_lengths[0].item()
-
-                # Get predictions for the current chunk from the end of total_preds
-                preds_np = self.total_preds[0, -num_new_frames:].cpu().numpy()
-                active_speakers = np.argmax(preds_np, axis=1)
-
-                for idx, spk in enumerate(active_speakers):
-                    start_time = current_offset_seconds + (i * chunk_duration_seconds) + (idx * frame_duration_s)
-                    end_time = start_time + frame_duration_s
-
-                    if self.speaker_segments and self.speaker_segments[-1].speaker == spk + 1:
-                        self.speaker_segments[-1].end = end_time
-                    else:
-                        self.speaker_segments.append(SpeakerSegment(
-                            speaker=int(spk + 1),
-                            start=start_time,
-                            end=end_time
-                        ))
-
-            self.processed_signal_offset += processed_signal_length
-
-
-    def assign_speakers_to_tokens(self, tokens: list, **kwargs) -> list:
-        """
-        Assign speakers to tokens based on timing overlap with speaker segments.
-        """
-        for token in tokens:
-            for segment in self.speaker_segments:
-                if not (segment.end <= token.start or segment.start >= token.end):
-                    token.speaker = segment.speaker
-        return tokens
-
-    def close(self):
-        """
-        Cleanup resources.
-        """
-        logger.info("Closing SortformerDiarization.")
-
-if __name__ == '__main__':
-    import librosa
-    an4_audio = 'new_audio_test.mp3'
-    signal, sr = librosa.load(an4_audio, sr=16000)
-
-    diarization_pipeline = SortformerDiarization()
-
-    # Simulate streaming
-    chunk_size = 16000  # 1 second
-    for i in range(0, len(signal), chunk_size):
-        chunk = signal[i:i+chunk_size]
-        import asyncio
-        asyncio.run(diarization_pipeline.diarize(chunk))
-
-    for segment in diarization_pipeline.speaker_segments:
-        print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
\ No newline at end of file
+    raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
\ No newline at end of file
diff --git a/whisperlivekit/diarization/sortformer_backend_2.py b/whisperlivekit/diarization/sortformer_backend_2.py
deleted file mode 100644
index ede3bfa..0000000
--- a/whisperlivekit/diarization/sortformer_backend_2.py
+++ /dev/null
@@ -1,257 +0,0 @@
-import numpy as np
-import torch
-import logging
-import math
-logger = logging.getLogger(__name__)
-
-try:
-    from nemo.collections.asr.models import SortformerEncLabelModel
-except ImportError:
-    raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
-
-
-diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
-diar_model.eval()
-
-if torch.cuda.is_available():
-    diar_model.to(torch.device("cuda"))
-
-# Set the streaming parameters corresponding to 1.04s latency setup. This will affect the streaming feat loader.
-# diar_model.sortformer_modules.chunk_len = 6
-# diar_model.sortformer_modules.spkcache_len = 188
-# diar_model.sortformer_modules.chunk_right_context = 7
-# diar_model.sortformer_modules.fifo_len = 188
-# diar_model.sortformer_modules.spkcache_update_period = 144
-# diar_model.sortformer_modules.log = False
-
-
-# here we change the settings for our goal: speed!
-# we want batches of around 1 second. one frame is 0.08s, so 1s is 12.5 frames. we take 12.
-diar_model.sortformer_modules.chunk_len = 12
-
-# for more speed, we reduce the 'right context'. it's like looking less into the future.
-diar_model.sortformer_modules.chunk_right_context = 1
-
-# we keep the rest same for now
-diar_model.sortformer_modules.spkcache_len = 188
-diar_model.sortformer_modules.fifo_len = 188
-diar_model.sortformer_modules.spkcache_update_period = 144
-diar_model.sortformer_modules.log = False
-diar_model.sortformer_modules._check_streaming_parameters()
-
-batch_size = 1
-processed_signal_offset = torch.zeros((batch_size,), dtype=torch.long, device=diar_model.device)
-
-# from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures
-# from nemo.collections.asr.modules.audio_preprocessing import get_features
-from nemo.collections.asr.modules.audio_preprocessing import AudioToMelSpectrogramPreprocessor
-
-
-def prepare_audio_signal(signal):
-    audio_signal = torch.tensor(signal).unsqueeze(0).to(diar_model.device)
-    audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(diar_model.device)
-    processed_signal, processed_signal_length = AudioToMelSpectrogramPreprocessor(
-        window_size= 0.025,
-        normalize="NA",
-        n_fft=512,
-        features=128).get_features(audio_signal, audio_signal_length)
-    return processed_signal, processed_signal_length
-
-
-def streaming_feat_loader(
-    feat_seq, feat_seq_length, feat_seq_offset
-):
-    """
-    Load a chunk of feature sequence for streaming inference.
-
-    Args:
-        feat_seq (torch.Tensor): Tensor containing feature sequence
-            Shape: (batch_size, feat_dim, feat frame count)
-        feat_seq_length (torch.Tensor): Tensor containing feature sequence lengths
-            Shape: (batch_size,)
-        feat_seq_offset (torch.Tensor): Tensor containing feature sequence offsets
-            Shape: (batch_size,)
-
-    Returns:
-        chunk_idx (int): Index of the current chunk
-        chunk_feat_seq (torch.Tensor): Tensor containing the chunk of feature sequence
-            Shape: (batch_size, diar frame count, feat_dim)
-        feat_lengths (torch.Tensor): Tensor containing lengths of the chunk of feature sequence
-            Shape: (batch_size,)
-    """
-    feat_len = feat_seq.shape[2]
-    num_chunks = math.ceil(feat_len / (diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor))
-    if False:
-        logging.info(
-            f"feat_len={feat_len}, num_chunks={num_chunks}, "
-            f"feat_seq_length={feat_seq_length}, feat_seq_offset={feat_seq_offset}"
-        )
-
-    stt_feat, end_feat, chunk_idx = 0, 0, 0
-    while end_feat < feat_len:
-        left_offset = min(diar_model.sortformer_modules.chunk_left_context * diar_model.sortformer_modules.subsampling_factor, stt_feat)
-        end_feat = min(stt_feat + diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor, feat_len)
-        right_offset = min(diar_model.sortformer_modules.chunk_right_context * diar_model.sortformer_modules.subsampling_factor, feat_len - end_feat)
-        chunk_feat_seq = feat_seq[:, :, stt_feat - left_offset : end_feat + right_offset]
-        feat_lengths = (feat_seq_length + feat_seq_offset - stt_feat + left_offset).clamp(
-            0, chunk_feat_seq.shape[2]
-        )
-        feat_lengths = feat_lengths * (feat_seq_offset < end_feat)
-        stt_feat = end_feat
-        chunk_feat_seq_t = torch.transpose(chunk_feat_seq, 1, 2)
-        if False:
-            logging.info(
-                f"chunk_idx: {chunk_idx}, "
-                f"chunk_feat_seq_t shape: {chunk_feat_seq_t.shape}, "
-                f"chunk_feat_lengths: {feat_lengths}"
-            )
-        yield chunk_idx, chunk_feat_seq_t, feat_lengths, left_offset, right_offset
-        chunk_idx += 1
-
-
-class StreamingSortformerState:
-    """
-    This class creates a class instance that will be used to store the state of the
-    streaming Sortformer model.
-
-    Attributes:
-        spkcache (torch.Tensor): Speaker cache to store embeddings from start
-        spkcache_lengths (torch.Tensor): Lengths of the speaker cache
-        spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
-        fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
-        fifo_lengths (torch.Tensor): Lengths of the FIFO queue
-        fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
-        spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
-        mean_sil_emb (torch.Tensor): Mean silence embedding
-        n_sil_frames (torch.Tensor): Number of silence frames
-    """
-
-    spkcache = None  # Speaker cache to store embeddings from start
-    spkcache_lengths = None  #
-    spkcache_preds = None  # speaker cache predictions
-    fifo = None  # to save the embedding from the latest chunks
-    fifo_lengths = None
-    fifo_preds = None
-    spk_perm = None
-    mean_sil_emb = None
-    n_sil_frames = None
-
-
-def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
-    """
-    Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
-
-    Args:
-        batch_size (int): Batch size for tensors in streaming state
-        async_streaming (bool): True for asynchronous update, False for synchronous update
-        device (torch.device): Device for tensors in streaming state
-
-    Returns:
-        streaming_state (SortformerStreamingState): initialized streaming state
-    """
-    streaming_state = StreamingSortformerState()
-    if async_streaming:
-        streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
-        streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
-        streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
-        streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
-        streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
-    else:
-        streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
-        streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
-    streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
-    streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
-    return streaming_state
-
-def process_diarization(signal, chunks):
-
-    audio_signal = torch.tensor(signal).unsqueeze(0).to(diar_model.device)
-    audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(diar_model.device)
-    processed_signal, processed_signal_length = AudioToMelSpectrogramPreprocessor(
-        window_size= 0.025,
-        normalize="NA",
-        n_fft=512,
-        features=128).get_features(audio_signal, audio_signal_length)
-
-
-    streaming_loader = streaming_feat_loader(processed_signal, processed_signal_length, processed_signal_offset)
-
-
-    streaming_state = init_streaming_state(diar_model.sortformer_modules,
-                                           batch_size = batch_size,
-                                           async_streaming = True,
-                                           device = diar_model.device
-                                           )
-    total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
-
-
-    chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
-    print(f"Chunk duration: {chunk_duration_seconds} seconds")
-
-    l_speakers = [
-        {'start_time': 0,
-         'end_time': 0,
-         'speaker': 0
-         }
-    ]
-    len_prediction = None
-    left_offset = 0
-    right_offset = 8
-    for i, chunk_feat_seq_t, _, _, _ in streaming_loader:
-        with torch.inference_mode():
-            streaming_state, total_preds = diar_model.forward_streaming_step(
-                processed_signal=chunk_feat_seq_t,
-                processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
-                streaming_state=streaming_state,
-                total_preds=total_preds,
-                left_offset=left_offset,
-                right_offset=right_offset,
-            )
-        left_offset = 8
-        preds_np = total_preds[0].cpu().numpy()
-        active_speakers = np.argmax(preds_np, axis=1)
-        if len_prediction is None:
-            len_prediction = len(active_speakers)  # we want to get the len of 1 prediction
-            frame_duration = chunk_duration_seconds / len_prediction
-        active_speakers = active_speakers[-len_prediction:]
-        print(chunk_feat_seq_t.shape, total_preds.shape)
-        for idx, spk in enumerate(active_speakers):
-            if spk != l_speakers[-1]['speaker']:
-                l_speakers.append(
-                    {'start_time': i * chunk_duration_seconds + idx * frame_duration,
-                     'end_time': i * chunk_duration_seconds + (idx + 1) * frame_duration,
-                     'speaker': spk
-                     })
-            else:
-                l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
-
-    print(l_speakers)
-    """
-    Should print
-    [{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
-    {'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
-    {'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
-    {'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
-    """
-
-if __name__ == '__main__':
-    import librosa
-    an4_audio = 'new_audio_test.mp3'
-    signal, sr = librosa.load(an4_audio,sr=16000)
-
-    """
-    ground truth:
-    speaker 0 : 0:00 - 0:09
-    speaker 1 : 0:09 - 0:19
-    speaker 2 : 0:19 - 0:25
-    speaker 0 : 0:25 - end
-    """
-
-    # Simulate streaming
-    chunk_size = 16000  # 1 second
-    chunks = []
-    for i in range(0, len(signal), chunk_size):
-        chunk = signal[i:i+chunk_size]
-        chunks.append(chunk)
-
-    process_diarization(signal, chunks)
\ No newline at end of file
diff --git a/whisperlivekit/diarization/sortformer_backend_offline.py b/whisperlivekit/diarization/sortformer_backend_offline.py
new file mode 100644
index 0000000..2619154
--- /dev/null
+++ b/whisperlivekit/diarization/sortformer_backend_offline.py
@@ -0,0 +1,205 @@
+import numpy as np
+import torch
+import logging
+
+from nemo.collections.asr.models import SortformerEncLabelModel
+from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
+import librosa
+
+logger = logging.getLogger(__name__)
+
+def load_model():
+
+    diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
+    diar_model.eval()
+
+    if torch.cuda.is_available():
+        diar_model.to(torch.device("cuda"))
+
+    # We target about 1 second of lag for now; chunk_len could be reduced further.
+    diar_model.sortformer_modules.chunk_len = 10
+    diar_model.sortformer_modules.subsampling_factor = 10  # 8 would ideally be better
+
+    diar_model.sortformer_modules.chunk_right_context = 0  # no look-ahead, to keep latency low
+    diar_model.sortformer_modules.chunk_left_context = 10  # large, to compensate for the lack of padding later
+
+    diar_model.sortformer_modules.spkcache_len = 188
+    diar_model.sortformer_modules.fifo_len = 188
+    diar_model.sortformer_modules.spkcache_update_period = 144
+    diar_model.sortformer_modules.log = False
+    diar_model.sortformer_modules._check_streaming_parameters()
+
+
+    audio2mel = AudioToMelSpectrogramPreprocessor(
+        window_size= 0.025,
+        normalize="NA",
+        n_fft=512,
+        features=128,
+        pad_to=0)  # pad_to=16 works better than 0: with pad_to=0 a spurious third speaker is detected for ~1 second on the test audio. Increasing chunk_left_context to 10 compensates for this.
+
+    return diar_model, audio2mel
+
+diar_model, audio2mel = load_model()
+
+class StreamingSortformerState:
+    """
+    This class creates a class instance that will be used to store the state of the
+    streaming Sortformer model.
+
+    Attributes:
+        spkcache (torch.Tensor): Speaker cache to store embeddings from start
+        spkcache_lengths (torch.Tensor): Lengths of the speaker cache
+        spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
+        fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
+        fifo_lengths (torch.Tensor): Lengths of the FIFO queue
+        fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
+        spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
+        mean_sil_emb (torch.Tensor): Mean silence embedding
+        n_sil_frames (torch.Tensor): Number of silence frames
+    """
+
+    spkcache = None  # Speaker cache to store embeddings from start
+    spkcache_lengths = None  #
+    spkcache_preds = None  # speaker cache predictions
+    fifo = None  # to save the embedding from the latest chunks
+    fifo_lengths = None
+    fifo_preds = None
+    spk_perm = None
+    mean_sil_emb = None
+    n_sil_frames = None
+
+
+def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
+    """
+    Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
+
+    Args:
+        batch_size (int): Batch size for tensors in streaming state
+        async_streaming (bool): True for asynchronous update, False for synchronous update
+        device (torch.device): Device for tensors in streaming state
+
+    Returns:
+        streaming_state (SortformerStreamingState): initialized streaming state
+    """
+    streaming_state = StreamingSortformerState()
+    if async_streaming:
+        streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
+        streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
+        streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
+        streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
+        streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
+    else:
+        streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
+        streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
+    streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
+    streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
+    return streaming_state
+
+
+def process_diarization(chunks):
+    """
+    What the mel-spectrogram preprocessing does:
+    1. Preprocessing: applies dithering and pre-emphasis (high-pass filter) if enabled
+    2. STFT: computes the Short-Time Fourier Transform using:
+        - a window of window_size=0.025 s --> 400 samples per window
+        - a hop of n_window_stride=0.01 s --> a new window every 160 samples
+    3. Magnitude calculation: converts the complex STFT output to a magnitude spectrogram
+    4. Mel conversion: applies Mel filterbanks (128 filters here) to get the Mel spectrogram
+    5. Logarithm: takes the log of the Mel spectrogram (if `log=True`)
+    6. Normalization: skipped, since `normalize="NA"`
+    7. Padding: pads the time dimension to a multiple of `pad_to` (default 16)
+    """
+    previous_chunk = None
+    l_chunk_feat_seq_t = []
+    for chunk in chunks:
+        audio_signal_chunk = torch.tensor(chunk).unsqueeze(0).to(diar_model.device)
+        audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(diar_model.device)
+        processed_signal_chunk, processed_signal_length_chunk = audio2mel.get_features(audio_signal_chunk, audio_signal_length_chunk)
+        if previous_chunk is not None:
+            to_add = previous_chunk[:, :, -99:]
+            total = torch.concat([to_add, processed_signal_chunk], dim=2)
+        else:
+            total = processed_signal_chunk
+        previous_chunk = processed_signal_chunk
+        l_chunk_feat_seq_t.append(torch.transpose(total, 1, 2))
+
+    batch_size = 1
+    streaming_state = init_streaming_state(diar_model.sortformer_modules,
+                                           batch_size = batch_size,
+                                           async_streaming = True,
+                                           device = diar_model.device
+                                           )
+    total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
+
+    chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
+
+    l_speakers = [
+        {'start_time': 0,
+         'end_time': 0,
+         'speaker': 0
+         }
+    ]
+    len_prediction = None
+    left_offset = 0
+    right_offset = 8
+    for i, chunk_feat_seq_t in enumerate(l_chunk_feat_seq_t):
+        with torch.inference_mode():
+            streaming_state, total_preds = diar_model.forward_streaming_step(
+                processed_signal=chunk_feat_seq_t,
+                processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
+                streaming_state=streaming_state,
+                total_preds=total_preds,
+                left_offset=left_offset,
+                right_offset=right_offset,
+            )
+        left_offset = 8
+        preds_np = total_preds[0].cpu().numpy()
+        active_speakers = np.argmax(preds_np, axis=1)
+        if len_prediction is None:
+            len_prediction = len(active_speakers)  # length of the prediction for a single chunk
+            frame_duration = chunk_duration_seconds / len_prediction
+        active_speakers = active_speakers[-len_prediction:]
+        for idx, spk in enumerate(active_speakers):
+            if spk != l_speakers[-1]['speaker']:
+                l_speakers.append(
+                    {'start_time': (i * chunk_duration_seconds + idx * frame_duration),
+                     'end_time': (i * chunk_duration_seconds + (idx + 1) * frame_duration),
+                     'speaker': spk
+                     })
+            else:
+                l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
+
+
+    """
+    Should print:
+    [{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
+    {'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
+    {'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
+    {'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
+    """
+    for speaker in l_speakers:
+        print(f"Speaker {speaker['speaker']}: {speaker['start_time']:.2f}s - {speaker['end_time']:.2f}s")
+
+
+if __name__ == '__main__':
+
+    an4_audio = 'audio_test.mp3'
+    signal, sr = librosa.load(an4_audio, sr=16000)
+    signal = signal[:16000*30]
+    # signal = signal[:-(len(signal)%16000)]
+
+    print("\n" + "=" * 50)
+    print("Expected ground truth:")
+    print("Speaker 0: 0:00 - 0:09")
+    print("Speaker 1: 0:09 - 0:19")
+    print("Speaker 2: 0:19 - 0:25")
+    print("Speaker 0: 0:25 - 0:30")
+    print("=" * 50)
+
+    chunk_size = 16000  # 1 second
+    chunks = []
+    for i in range(0, len(signal), chunk_size):
+        chunk = signal[i:i+chunk_size]
+        chunks.append(chunk)
+
+    process_diarization(chunks)
\ No newline at end of file
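For reference, the segment-merging step inside process_diarization boils down to the minimal standalone sketch below (numpy only, no NeMo). The helper name preds_to_segments and the 0.08 s frame duration are illustrative assumptions, not part of the new module.

import numpy as np

def preds_to_segments(frame_preds: np.ndarray, frame_duration: float = 0.08):
    """Collapse per-frame speaker scores of shape (n_frames, n_spk) into merged segments."""
    active = np.argmax(frame_preds, axis=1)  # dominant speaker index per frame
    segments = []
    for idx, spk in enumerate(active):
        start, end = idx * frame_duration, (idx + 1) * frame_duration
        if segments and segments[-1]['speaker'] == spk:
            segments[-1]['end_time'] = end  # same speaker: extend the running segment
        else:
            segments.append({'speaker': int(spk), 'start_time': start, 'end_time': end})
    return segments

# toy check: 5 frames dominated by speaker 0, then 5 frames dominated by speaker 1
toy = np.zeros((10, 4))
toy[:5, 0] = 1.0
toy[5:, 1] = 1.0
print(preds_to_segments(toy))  # two segments, 0.00-0.40 s (speaker 0) and 0.40-0.80 s (speaker 1)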