Refactor speaker assignment in DiartDiarization for clarity and punctuation awareness

Quentin Fuxa
2025-07-15 14:38:53 +02:00
parent 1599bd87a0
commit 3ad3683ca7
2 changed files with 45 additions and 49 deletions


@@ -325,12 +325,12 @@ class AudioProcessor:
         await diarization_obj.diarize(pcm_array)
         async with self.lock:
-            new_end = diarization_obj.assign_speakers_to_tokens(
-                self.end_attributed_speaker,
+            self.tokens = diarization_obj.assign_speakers_to_tokens(
                 self.tokens,
                 use_punctuation_split=self.args.punctuation_split
             )
-            self.end_attributed_speaker = new_end
+            if len(self.tokens) > 0:
+                self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
             if buffer_diarization:
                 self.buffer_diarization = buffer_diarization

@@ -214,7 +214,7 @@ class DiartDiarization:
         if self.custom_source:
             self.custom_source.close()
 
-    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
+    def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
         """
         Assign speakers to tokens based on timing overlap with speaker segments.
         Uses the segments collected by the observer.
@@ -231,29 +231,8 @@
-        if not self.lag_diart and segments and tokens:
-            self.lag_diart = segments[0].start - tokens[0].start
-        for token in tokens:
-            for segment in segments:
-                if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
-                    token.speaker = extract_number(segment.speaker) + 1
-                    end_attributed_speaker = max(token.end, end_attributed_speaker)
-        if use_punctuation_split and len(tokens) > 1:
-            pass
-        return end_attributed_speaker
-
-
-def visualize_tokens(tokens):
-    conversation = [{"speaker": -1, "text": ""}]
-    for token in tokens:
-        speaker = conversation[-1]['speaker']
-        if token.speaker != speaker:
-            conversation.append({"speaker": token.speaker, "text": token.text})
-        else:
-            conversation[-1]['text'] += token.text
-    print("Conversation:")
-    for entry in conversation:
-        print(f"Speaker {entry['speaker']}: {entry['text']}")
+        tokens = add_speaker_to_tokens(segments, tokens)
+        return tokens
 
 
 def concatenate_speakers(segments):
     segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
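The next hunk reworks add_speaker_to_tokens, whose boundary rule snaps a diarization segment end to the nearest sentence-final punctuation token. A toy illustration of that rule with made-up timings (the numbers are not taken from the diff):

# A speaker segment ends at t=5.2 s; the previous '.' token ends at 4.8 s and the
# next '.' token starts at 6.1 s. The boundary moves to whichever mark is closer.
segment_end = 5.2
prev_punct_end = 4.8
next_punct_start = 6.1

before_length = segment_end - prev_punct_end    # 0.4 s back to the previous '.'
after_length = next_punct_start - segment_end   # 0.9 s forward to the next '.'

new_end = next_punct_start if before_length > after_length else prev_punct_end
print(new_end)  # -> 4.8: the boundary snaps back to the earlier sentence end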
@@ -270,41 +249,58 @@ def concatenate_speakers(segments):
 def add_speaker_to_tokens(segments, tokens):
     """
     Assign speakers to tokens based on diarization segments, with punctuation-aware boundary adjustment.
+    Refactored for clarity; behavior unchanged.
     """
     punctuation_marks = {'.', '!', '?'}
+    punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
     segments_concatenated = concatenate_speakers(segments)
-    punctuation_tokens = []
-    for token in tokens:
-        if token.text.strip() in punctuation_marks:
-            punctuation_tokens.append(token)
     for ind, segment in enumerate(segments_concatenated):
-        for i, punctuation_token in enumerate(punctuation_tokens):
-            if punctuation_token.start > segment['end']:
-                after_length = punctuation_token.start - segment['end']
-                before_length = segment['end'] - punctuation_tokens[i - 1].end
-                if before_length > after_length:
-                    segment['end'] = punctuation_token.start
-                    if i < len(punctuation_tokens) - 1:
-                        segments_concatenated[ind+1]['begin'] = punctuation_token.start
-                else:
-                    segment['end'] = punctuation_tokens[i - 1].end
-                    if i < len(punctuation_tokens) - 1:
-                        segments_concatenated[ind-1]['begin'] = punctuation_tokens[i - 1].end
-                break
+        for i, punctuation_token in enumerate(punctuation_tokens):
+            if punctuation_token.start > segment['end']:
+                after_length = punctuation_token.start - segment['end']
+                before_length = segment['end'] - punctuation_tokens[i - 1].end
+                if before_length > after_length:
+                    segment['end'] = punctuation_token.start
+                    if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
+                        segments_concatenated[ind + 1]['begin'] = punctuation_token.start
+                else:
+                    segment['end'] = punctuation_tokens[i - 1].end
+                    if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
+                        segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
+                break
     last_end = 0.0
     for token in tokens:
         start = max(last_end + 0.01, token.start)
         token.start = start
         token.end = max(start, token.end)
         last_end = token.end
     ind_last_speaker = 0
     for segment in segments_concatenated:
         for i, token in enumerate(tokens[ind_last_speaker:]):
             if token.end <= segment['end']:
                 token.speaker = segment['speaker']
                 ind_last_speaker = i + 1
-                print(f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})")
+                print(
+                    f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) "
+                    f"assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})"
+                )
             elif token.start > segment['end']:
                 break
         break
     return tokens
+
+
+def visualize_tokens(tokens):
+    conversation = [{"speaker": -1, "text": ""}]
+    for token in tokens:
+        speaker = conversation[-1]['speaker']
+        if token.speaker != speaker:
+            conversation.append({"speaker": token.speaker, "text": token.text})
+        else:
+            conversation[-1]['text'] += token.text
+    print("Conversation:")
+    for entry in conversation:
+        print(f"Speaker {entry['speaker']}: {entry['text']}")