From 3ad3683ca7d5ab038f4be7f31f49d6ccfa955c26 Mon Sep 17 00:00:00 2001
From: Quentin Fuxa
Date: Tue, 15 Jul 2025 14:38:53 +0200
Subject: [PATCH] Refactor speaker assignment in DiartDiarization for clarity
 and punctuation awareness

---
 whisperlivekit/audio_processor.py                |  6 +-
 .../diarization/diarization_online.py            | 88 +++++++++----------
 2 files changed, 45 insertions(+), 49 deletions(-)

diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py
index bc65da6..ddb1ca5 100644
--- a/whisperlivekit/audio_processor.py
+++ b/whisperlivekit/audio_processor.py
@@ -325,12 +325,12 @@ class AudioProcessor:
         await diarization_obj.diarize(pcm_array)
 
         async with self.lock:
-            new_end = diarization_obj.assign_speakers_to_tokens(
-                self.end_attributed_speaker,
+            self.tokens = diarization_obj.assign_speakers_to_tokens(
                 self.tokens,
                 use_punctuation_split=self.args.punctuation_split
             )
-            self.end_attributed_speaker = new_end
+            if len(self.tokens) > 0:
+                self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
             if buffer_diarization:
                 self.buffer_diarization = buffer_diarization
 
diff --git a/whisperlivekit/diarization/diarization_online.py b/whisperlivekit/diarization/diarization_online.py
index 9391104..cf995fb 100644
--- a/whisperlivekit/diarization/diarization_online.py
+++ b/whisperlivekit/diarization/diarization_online.py
@@ -214,7 +214,7 @@ class DiartDiarization:
         if self.custom_source:
             self.custom_source.close()
 
-    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
+    def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
         """
         Assign speakers to tokens based on timing overlap with speaker segments.
         Uses the segments collected by the observer.
@@ -231,29 +231,8 @@
         if not self.lag_diart and segments and tokens:
             self.lag_diart = segments[0].start - tokens[0].start
-        for token in tokens:
-            for segment in segments:
-                if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
-                    token.speaker = extract_number(segment.speaker) + 1
-                    end_attributed_speaker = max(token.end, end_attributed_speaker)
-
-        if use_punctuation_split and len(tokens) > 1:
-            pass
-        return end_attributed_speaker
-
-
-def visualize_tokens(tokens):
-    conversation = [{"speaker": -1, "text": ""}]
-    for token in tokens:
-        speaker = conversation[-1]['speaker']
-        if token.speaker != speaker:
-            conversation.append({"speaker": token.speaker, "text": token.text})
-        else:
-            conversation[-1]['text'] += token.text
-    print("Conversation:")
-    for entry in conversation:
-        print(f"Speaker {entry['speaker']}: {entry['text']}")
-
+        tokens = add_speaker_to_tokens(segments, tokens)
+        return tokens
 
 def concatenate_speakers(segments):
     segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
@@ -270,41 +249,58 @@
 def add_speaker_to_tokens(segments, tokens):
+    """
+    Assign speakers to tokens based on diarization segments, with punctuation-aware boundary adjustment.
+    Refactored for clarity; behavior unchanged.
+    """
     punctuation_marks = {'.', '!', '?'}
+    punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
     segments_concatenated = concatenate_speakers(segments)
-    punctuation_tokens = []
-    for token in tokens:
-        if token.text.strip() in punctuation_marks:
-            punctuation_tokens.append(token)
-
     for ind, segment in enumerate(segments_concatenated):
-        for i, punctuation_token in enumerate(punctuation_tokens):
-            if punctuation_token.start > segment['end']:
-                after_length = punctuation_token.start - segment['end']
-                before_length = segment['end'] - punctuation_tokens[i - 1].end
-                if before_length > after_length:
-                    segment['end'] = punctuation_token.start
-                    if i < len(punctuation_tokens) - 1:
-                        segments_concatenated[ind+1]['begin'] = punctuation_token.start
-                else:
-                    segment['end'] = punctuation_tokens[i - 1].end
-                    if i < len(punctuation_tokens) - 1:
-                        segments_concatenated[ind-1]['begin'] = punctuation_tokens[i - 1].end
-                break
-
+        for i, punctuation_token in enumerate(punctuation_tokens):
+            if punctuation_token.start > segment['end']:
+                after_length = punctuation_token.start - segment['end']
+                before_length = segment['end'] - punctuation_tokens[i - 1].end
+                if before_length > after_length:
+                    segment['end'] = punctuation_token.start
+                    if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
+                        segments_concatenated[ind + 1]['begin'] = punctuation_token.start
+                else:
+                    segment['end'] = punctuation_tokens[i - 1].end
+                    if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
+                        segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
+                break
+
     last_end = 0.0
     for token in tokens:
         start = max(last_end + 0.01, token.start)
         token.start = start
         token.end = max(start, token.end)
         last_end = token.end
-
+
     ind_last_speaker = 0
     for segment in segments_concatenated:
         for i, token in enumerate(tokens[ind_last_speaker:]):
             if token.end <= segment['end']:
                 token.speaker = segment['speaker']
                 ind_last_speaker = i + 1
-                print(f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})")
+                print(
+                    f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) "
+                    f"assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})"
+                )
             elif token.start > segment['end']:
-                break
\ No newline at end of file
+                break
+    return tokens
+
+
+def visualize_tokens(tokens):
+    conversation = [{"speaker": -1, "text": ""}]
+    for token in tokens:
+        speaker = conversation[-1]['speaker']
+        if token.speaker != speaker:
+            conversation.append({"speaker": token.speaker, "text": token.text})
+        else:
+            conversation[-1]['text'] += token.text
+    print("Conversation:")
+    for entry in conversation:
+        print(f"Speaker {entry['speaker']}: {entry['text']}")
\ No newline at end of file