sortformer diar implementation v0.2

This commit is contained in:
Quentin Fuxa
2025-08-24 18:32:01 +02:00
parent 5b2ddeccdb
commit 3393a08f7e
3 changed files with 206 additions and 392 deletions

View File

@@ -8,138 +8,4 @@ logger = logging.getLogger(__name__)
try:
from nemo.collections.asr.models import SortformerEncLabelModel
except ImportError:
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
class SortformerDiarization:
def __init__(self, model_name="nvidia/diar_streaming_sortformer_4spk-v2"):
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
self.diar_model.eval()
if torch.cuda.is_available():
self.diar_model.to(torch.device("cuda"))
# Streaming parameters for speed
self.diar_model.sortformer_modules.chunk_len = 12
self.diar_model.sortformer_modules.chunk_right_context = 1
self.diar_model.sortformer_modules.spkcache_len = 188
self.diar_model.sortformer_modules.fifo_len = 188
self.diar_model.sortformer_modules.spkcache_update_period = 144
self.diar_model.sortformer_modules.log = False
self.diar_model.sortformer_modules._check_streaming_parameters()
self.batch_size = 1
self.processed_signal_offset = torch.zeros((self.batch_size,), dtype=torch.long, device=self.diar_model.device)
self.audio_buffer = np.array([], dtype=np.float32)
self.sample_rate = 16000
self.speaker_segments = []
self.streaming_state = self.diar_model.sortformer_modules.init_streaming_state(
batch_size=self.batch_size,
async_streaming=True,
device=self.diar_model.device
)
self.total_preds = torch.zeros((self.batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=self.diar_model.device)
def _prepare_audio_signal(self, signal):
audio_signal = torch.tensor(signal).unsqueeze(0).to(self.diar_model.device)
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(self.diar_model.device)
processed_signal, processed_signal_length = self.diar_model.preprocessor(input_signal=audio_signal, length=audio_signal_length)
return processed_signal, processed_signal_length
def _create_streaming_loader(self, processed_signal, processed_signal_length):
streaming_loader = self.diar_model.sortformer_modules.streaming_feat_loader(
feat_seq=processed_signal,
feat_seq_length=processed_signal_length,
feat_seq_offset=self.processed_signal_offset,
)
return streaming_loader
async def diarize(self, pcm_array: np.ndarray):
"""
Process an incoming audio chunk for diarization.
"""
self.audio_buffer = np.concatenate([self.audio_buffer, pcm_array])
# Process in fixed-size chunks (e.g., 1 second)
chunk_size = self.sample_rate # 1 second of audio
while len(self.audio_buffer) >= chunk_size:
chunk_to_process = self.audio_buffer[:chunk_size]
self.audio_buffer = self.audio_buffer[chunk_size:]
processed_signal, processed_signal_length = self._prepare_audio_signal(chunk_to_process)
current_offset_seconds = self.processed_signal_offset.item() * self.diar_model.preprocessor._cfg.window_stride
streaming_loader = self._create_streaming_loader(processed_signal, processed_signal_length)
frame_duration_s = self.diar_model.sortformer_modules.subsampling_factor * self.diar_model.preprocessor._cfg.window_stride
chunk_duration_seconds = self.diar_model.sortformer_modules.chunk_len * frame_duration_s
for i, chunk_feat_seq_t, feat_lengths, left_offset, right_offset in streaming_loader:
with torch.inference_mode():
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=feat_lengths,
streaming_state=self.streaming_state,
total_preds=self.total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
num_new_frames = feat_lengths[0].item()
# Get predictions for the current chunk from the end of total_preds
preds_np = self.total_preds[0, -num_new_frames:].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1)
for idx, spk in enumerate(active_speakers):
start_time = current_offset_seconds + (i * chunk_duration_seconds) + (idx * frame_duration_s)
end_time = start_time + frame_duration_s
if self.speaker_segments and self.speaker_segments[-1].speaker == spk + 1:
self.speaker_segments[-1].end = end_time
else:
self.speaker_segments.append(SpeakerSegment(
speaker=int(spk + 1),
start=start_time,
end=end_time
))
self.processed_signal_offset += processed_signal_length
def assign_speakers_to_tokens(self, tokens: list, **kwargs) -> list:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
"""
for token in tokens:
for segment in self.speaker_segments:
if not (segment.end <= token.start or segment.start >= token.end):
token.speaker = segment.speaker
return tokens
def close(self):
"""
Cleanup resources.
"""
logger.info("Closing SortformerDiarization.")
if __name__ == '__main__':
import librosa
an4_audio = 'new_audio_test.mp3'
signal, sr = librosa.load(an4_audio, sr=16000)
diarization_pipeline = SortformerDiarization()
# Simulate streaming
chunk_size = 16000 # 1 second
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
import asyncio
asyncio.run(diarization_pipeline.diarize(chunk))
for segment in diarization_pipeline.speaker_segments:
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")

View File

@@ -1,257 +0,0 @@
import numpy as np
import torch
import logging
import math
logger = logging.getLogger(__name__)
try:
from nemo.collections.asr.models import SortformerEncLabelModel
except ImportError:
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
diar_model.eval()
if torch.cuda.is_available():
diar_model.to(torch.device("cuda"))
# Set the streaming parameters corresponding to 1.04s latency setup. This will affect the streaming feat loader.
# diar_model.sortformer_modules.chunk_len = 6
# diar_model.sortformer_modules.spkcache_len = 188
# diar_model.sortformer_modules.chunk_right_context = 7
# diar_model.sortformer_modules.fifo_len = 188
# diar_model.sortformer_modules.spkcache_update_period = 144
# diar_model.sortformer_modules.log = False
# here we change the settings for our goal: speed!
# we want batches of around 1 second. one frame is 0.08s, so 1s is 12.5 frames. we take 12.
diar_model.sortformer_modules.chunk_len = 12
# for more speed, we reduce the 'right context'. it's like looking less into the future.
diar_model.sortformer_modules.chunk_right_context = 1
# we keep the rest same for now
diar_model.sortformer_modules.spkcache_len = 188
diar_model.sortformer_modules.fifo_len = 188
diar_model.sortformer_modules.spkcache_update_period = 144
diar_model.sortformer_modules.log = False
diar_model.sortformer_modules._check_streaming_parameters()
batch_size = 1
processed_signal_offset = torch.zeros((batch_size,), dtype=torch.long, device=diar_model.device)
# from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures
# from nemo.collections.asr.modules.audio_preprocessing import get_features
from nemo.collections.asr.modules.audio_preprocessing import AudioToMelSpectrogramPreprocessor
def prepare_audio_signal(signal):
audio_signal = torch.tensor(signal).unsqueeze(0).to(diar_model.device)
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(diar_model.device)
processed_signal, processed_signal_length = AudioToMelSpectrogramPreprocessor(
window_size= 0.025,
normalize="NA",
n_fft=512,
features=128).get_features(audio_signal, audio_signal_length)
return processed_signal, processed_signal_length
def streaming_feat_loader(
feat_seq, feat_seq_length, feat_seq_offset
):
"""
Load a chunk of feature sequence for streaming inference.
Args:
feat_seq (torch.Tensor): Tensor containing feature sequence
Shape: (batch_size, feat_dim, feat frame count)
feat_seq_length (torch.Tensor): Tensor containing feature sequence lengths
Shape: (batch_size,)
feat_seq_offset (torch.Tensor): Tensor containing feature sequence offsets
Shape: (batch_size,)
Returns:
chunk_idx (int): Index of the current chunk
chunk_feat_seq (torch.Tensor): Tensor containing the chunk of feature sequence
Shape: (batch_size, diar frame count, feat_dim)
feat_lengths (torch.Tensor): Tensor containing lengths of the chunk of feature sequence
Shape: (batch_size,)
"""
feat_len = feat_seq.shape[2]
num_chunks = math.ceil(feat_len / (diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor))
if False:
logging.info(
f"feat_len={feat_len}, num_chunks={num_chunks}, "
f"feat_seq_length={feat_seq_length}, feat_seq_offset={feat_seq_offset}"
)
stt_feat, end_feat, chunk_idx = 0, 0, 0
while end_feat < feat_len:
left_offset = min(diar_model.sortformer_modules.chunk_left_context * diar_model.sortformer_modules.subsampling_factor, stt_feat)
end_feat = min(stt_feat + diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor, feat_len)
right_offset = min(diar_model.sortformer_modules.chunk_right_context * diar_model.sortformer_modules.subsampling_factor, feat_len - end_feat)
chunk_feat_seq = feat_seq[:, :, stt_feat - left_offset : end_feat + right_offset]
feat_lengths = (feat_seq_length + feat_seq_offset - stt_feat + left_offset).clamp(
0, chunk_feat_seq.shape[2]
)
feat_lengths = feat_lengths * (feat_seq_offset < end_feat)
stt_feat = end_feat
chunk_feat_seq_t = torch.transpose(chunk_feat_seq, 1, 2)
if False:
logging.info(
f"chunk_idx: {chunk_idx}, "
f"chunk_feat_seq_t shape: {chunk_feat_seq_t.shape}, "
f"chunk_feat_lengths: {feat_lengths}"
)
yield chunk_idx, chunk_feat_seq_t, feat_lengths, left_offset, right_offset
chunk_idx += 1
class StreamingSortformerState:
"""
This class creates a class instance that will be used to store the state of the
streaming Sortformer model.
Attributes:
spkcache (torch.Tensor): Speaker cache to store embeddings from start
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
mean_sil_emb (torch.Tensor): Mean silence embedding
n_sil_frames (torch.Tensor): Number of silence frames
"""
spkcache = None # Speaker cache to store embeddings from start
spkcache_lengths = None #
spkcache_preds = None # speaker cache predictions
fifo = None # to save the embedding from the latest chunks
fifo_lengths = None
fifo_preds = None
spk_perm = None
mean_sil_emb = None
n_sil_frames = None
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
"""
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
Args:
batch_size (int): Batch size for tensors in streaming state
async_streaming (bool): True for asynchronous update, False for synchronous update
device (torch.device): Device for tensors in streaming state
Returns:
streaming_state (SortformerStreamingState): initialized streaming state
"""
streaming_state = StreamingSortformerState()
if async_streaming:
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
else:
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
return streaming_state
def process_diarization(signal, chunks):
audio_signal = torch.tensor(signal).unsqueeze(0).to(diar_model.device)
audio_signal_length = torch.tensor([audio_signal.shape[1]]).to(diar_model.device)
processed_signal, processed_signal_length = AudioToMelSpectrogramPreprocessor(
window_size= 0.025,
normalize="NA",
n_fft=512,
features=128).get_features(audio_signal, audio_signal_length)
streaming_loader = streaming_feat_loader(processed_signal, processed_signal_length, processed_signal_offset)
streaming_state = init_streaming_state(diar_model.sortformer_modules,
batch_size = batch_size,
async_streaming = True,
device = diar_model.device
)
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
print(f"Chunk duration: {chunk_duration_seconds} seconds")
l_speakers = [
{'start_time': 0,
'end_time': 0,
'speaker': 0
}
]
len_prediction = None
left_offset = 0
right_offset = 8
for i, chunk_feat_seq_t, _, _, _ in streaming_loader:
with torch.inference_mode():
streaming_state, total_preds = diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
streaming_state=streaming_state,
total_preds=total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
left_offset = 8
preds_np = total_preds[0].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1)
if len_prediction is None:
len_prediction = len(active_speakers) # we want to get the len of 1 prediction
frame_duration = chunk_duration_seconds / len_prediction
active_speakers = active_speakers[-len_prediction:]
print(chunk_feat_seq_t.shape, total_preds.shape)
for idx, spk in enumerate(active_speakers):
if spk != l_speakers[-1]['speaker']:
l_speakers.append(
{'start_time': i * chunk_duration_seconds + idx * frame_duration,
'end_time': i * chunk_duration_seconds + (idx + 1) * frame_duration,
'speaker': spk
})
else:
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
print(l_speakers)
"""
Should print
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
"""
if __name__ == '__main__':
import librosa
an4_audio = 'new_audio_test.mp3'
signal, sr = librosa.load(an4_audio,sr=16000)
"""
ground truth:
speaker 0 : 0:00 - 0:09
speaker 1 : 0:09 - 0:19
speaker 2 : 0:19 - 0:25
speaker 0 : 0:25 - end
"""
# Simulate streaming
chunk_size = 16000 # 1 second
chunks = []
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
chunks.append(chunk)
process_diarization(signal, chunks)

View File

@@ -0,0 +1,205 @@
import numpy as np
import torch
import logging
from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
import librosa
logger = logging.getLogger(__name__)
def load_model():
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
diar_model.eval()
if torch.cuda.is_available():
diar_model.to(torch.device("cuda"))
#we target 1 second lag for the moment. chunk_len could be reduced.
diar_model.sortformer_modules.chunk_len = 10
diar_model.sortformer_modules.subsampling_factor = 10 #8 would be better ideally
diar_model.sortformer_modules.chunk_right_context = 0 #no.
diar_model.sortformer_modules.chunk_left_context = 10 #big so it compensiate the problem with no padding later.
diar_model.sortformer_modules.spkcache_len = 188
diar_model.sortformer_modules.fifo_len = 188
diar_model.sortformer_modules.spkcache_update_period = 144
diar_model.sortformer_modules.log = False
diar_model.sortformer_modules._check_streaming_parameters()
audio2mel = AudioToMelSpectrogramPreprocessor(
window_size= 0.025,
normalize="NA",
n_fft=512,
features=128,
pad_to=0) #pad_to 16 works better than 0. On test audio, we detect a third speaker for 1 second with pad_to=0. To solve that : increase left context to 10.
return diar_model, audio2mel
diar_model, audio2mel = load_model()
class StreamingSortformerState:
"""
This class creates a class instance that will be used to store the state of the
streaming Sortformer model.
Attributes:
spkcache (torch.Tensor): Speaker cache to store embeddings from start
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
mean_sil_emb (torch.Tensor): Mean silence embedding
n_sil_frames (torch.Tensor): Number of silence frames
"""
spkcache = None # Speaker cache to store embeddings from start
spkcache_lengths = None #
spkcache_preds = None # speaker cache predictions
fifo = None # to save the embedding from the latest chunks
fifo_lengths = None
fifo_preds = None
spk_perm = None
mean_sil_emb = None
n_sil_frames = None
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
"""
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
Args:
batch_size (int): Batch size for tensors in streaming state
async_streaming (bool): True for asynchronous update, False for synchronous update
device (torch.device): Device for tensors in streaming state
Returns:
streaming_state (SortformerStreamingState): initialized streaming state
"""
streaming_state = StreamingSortformerState()
if async_streaming:
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
else:
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
return streaming_state
def process_diarization(chunks):
"""
what it does:
1. Preprocessing: Applies dithering and pre-emphasis (high-pass filter) if enabled
2. STFT: Computes the Short-Time Fourier Transform using:
- the window of window_size=0.025 --> size of a window : 400 samples
- the hop parameter : n_window_stride = 0.01 -> every 160 samples, a new window
3. Magnitude Calculation: Converts complex STFT output to magnitude spectrogram
4. Mel Conversion: Applies Mel filterbanks (128 filters in this case) to get Mel spectrogram
5. Logarithm: Takes the log of the Mel spectrogram (if `log=True`)
6. Normalization: Skips normalization since `normalize="NA"`
7. Padding: Pads the time dimension to a multiple of `pad_to` (default 16)
"""
previous_chunk = None
l_chunk_feat_seq_t = []
for chunk in chunks:
audio_signal_chunk = torch.tensor(chunk).unsqueeze(0).to(diar_model.device)
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(diar_model.device)
processed_signal_chunk, processed_signal_length_chunk = audio2mel.get_features(audio_signal_chunk, audio_signal_length_chunk)
if previous_chunk is not None:
to_add = previous_chunk[:, :, -99:]
total = torch.concat([to_add, processed_signal_chunk], dim=2)
else:
total = processed_signal_chunk
previous_chunk = processed_signal_chunk
l_chunk_feat_seq_t.append(torch.transpose(total, 1, 2))
batch_size = 1
streaming_state = init_streaming_state(diar_model.sortformer_modules,
batch_size = batch_size,
async_streaming = True,
device = diar_model.device
)
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
l_speakers = [
{'start_time': 0,
'end_time': 0,
'speaker': 0
}
]
len_prediction = None
left_offset = 0
right_offset = 8
for i, chunk_feat_seq_t in enumerate(l_chunk_feat_seq_t):
with torch.inference_mode():
streaming_state, total_preds = diar_model.forward_streaming_step(
processed_signal=chunk_feat_seq_t,
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
streaming_state=streaming_state,
total_preds=total_preds,
left_offset=left_offset,
right_offset=right_offset,
)
left_offset = 8
preds_np = total_preds[0].cpu().numpy()
active_speakers = np.argmax(preds_np, axis=1)
if len_prediction is None:
len_prediction = len(active_speakers) # we want to get the len of 1 prediction
frame_duration = chunk_duration_seconds / len_prediction
active_speakers = active_speakers[-len_prediction:]
for idx, spk in enumerate(active_speakers):
if spk != l_speakers[-1]['speaker']:
l_speakers.append(
{'start_time': (i * chunk_duration_seconds + idx * frame_duration),
'end_time': (i * chunk_duration_seconds + (idx + 1) * frame_duration),
'speaker': spk
})
else:
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
"""
Should print
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
"""
for speaker in l_speakers:
print(f"Speaker {speaker['speaker']}: {speaker['start_time']:.2f}s - {speaker['end_time']:.2f}s")
if __name__ == '__main__':
an4_audio = 'audio_test.mp3'
signal, sr = librosa.load(an4_audio, sr=16000)
signal = signal[:16000*30]
# signal = signal[:-(len(signal)%16000)]
print("\n" + "=" * 50)
print("Expected ground truth:")
print("Speaker 0: 0:00 - 0:09")
print("Speaker 1: 0:09 - 0:19")
print("Speaker 2: 0:19 - 0:25")
print("Speaker 0: 0:25 - 0:30")
print("=" * 50)
chunk_size = 16000 # 1 second
chunks = []
for i in range(0, len(signal), chunk_size):
chunk = signal[i:i+chunk_size]
chunks.append(chunk)
process_diarization(chunks)