mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
Refactor speaker assignment in DiartDiarization for clarity and punctuation awareness
This commit is contained in:
@@ -325,12 +325,12 @@ class AudioProcessor:
|
||||
await diarization_obj.diarize(pcm_array)
|
||||
|
||||
async with self.lock:
|
||||
new_end = diarization_obj.assign_speakers_to_tokens(
|
||||
self.end_attributed_speaker,
|
||||
self.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||
self.tokens,
|
||||
use_punctuation_split=self.args.punctuation_split
|
||||
)
|
||||
self.end_attributed_speaker = new_end
|
||||
if len(self.tokens) > 0:
|
||||
self.end_attributed_speaker = max(self.tokens[-1].end, self.end_attributed_speaker)
|
||||
if buffer_diarization:
|
||||
self.buffer_diarization = buffer_diarization
|
||||
|
||||
|
||||
@@ -214,7 +214,7 @@ class DiartDiarization:
|
||||
if self.custom_source:
|
||||
self.custom_source.close()
|
||||
|
||||
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
|
||||
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> float:
|
||||
"""
|
||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||
Uses the segments collected by the observer.
|
||||
@@ -231,29 +231,8 @@ class DiartDiarization:
|
||||
|
||||
if not self.lag_diart and segments and tokens:
|
||||
self.lag_diart = segments[0].start - tokens[0].start
|
||||
for token in tokens:
|
||||
for segment in segments:
|
||||
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
||||
token.speaker = extract_number(segment.speaker) + 1
|
||||
end_attributed_speaker = max(token.end, end_attributed_speaker)
|
||||
|
||||
if use_punctuation_split and len(tokens) > 1:
|
||||
pass
|
||||
return end_attributed_speaker
|
||||
|
||||
|
||||
def visualize_tokens(tokens):
|
||||
conversation = [{"speaker": -1, "text": ""}]
|
||||
for token in tokens:
|
||||
speaker = conversation[-1]['speaker']
|
||||
if token.speaker != speaker:
|
||||
conversation.append({"speaker": token.speaker, "text": token.text})
|
||||
else:
|
||||
conversation[-1]['text'] += token.text
|
||||
print("Conversation:")
|
||||
for entry in conversation:
|
||||
print(f"Speaker {entry['speaker']}: {entry['text']}")
|
||||
|
||||
tokens = add_speaker_to_tokens(segments, tokens)
|
||||
return tokens
|
||||
|
||||
def concatenate_speakers(segments):
|
||||
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
|
||||
@@ -270,41 +249,58 @@ def concatenate_speakers(segments):
|
||||
|
||||
|
||||
def add_speaker_to_tokens(segments, tokens):
|
||||
"""
|
||||
Assign speakers to tokens based on diarization segments, with punctuation-aware boundary adjustment.
|
||||
Refactored for clarity; behavior unchanged.
|
||||
"""
|
||||
punctuation_marks = {'.', '!', '?'}
|
||||
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
|
||||
segments_concatenated = concatenate_speakers(segments)
|
||||
punctuation_tokens = []
|
||||
for token in tokens:
|
||||
if token.text.strip() in punctuation_marks:
|
||||
punctuation_tokens.append(token)
|
||||
|
||||
for ind, segment in enumerate(segments_concatenated):
|
||||
for i, punctuation_token in enumerate(punctuation_tokens):
|
||||
if punctuation_token.start > segment['end']:
|
||||
after_length = punctuation_token.start - segment['end']
|
||||
before_length = segment['end'] - punctuation_tokens[i - 1].end
|
||||
if before_length > after_length:
|
||||
segment['end'] = punctuation_token.start
|
||||
if i < len(punctuation_tokens) - 1:
|
||||
segments_concatenated[ind+1]['begin'] = punctuation_token.start
|
||||
else:
|
||||
segment['end'] = punctuation_tokens[i - 1].end
|
||||
if i < len(punctuation_tokens) - 1:
|
||||
segments_concatenated[ind-1]['begin'] = punctuation_tokens[i - 1].end
|
||||
break
|
||||
|
||||
for i, punctuation_token in enumerate(punctuation_tokens):
|
||||
if punctuation_token.start > segment['end']:
|
||||
after_length = punctuation_token.start - segment['end']
|
||||
before_length = segment['end'] - punctuation_tokens[i - 1].end
|
||||
if before_length > after_length:
|
||||
segment['end'] = punctuation_token.start
|
||||
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
|
||||
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
|
||||
else:
|
||||
segment['end'] = punctuation_tokens[i - 1].end
|
||||
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
|
||||
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
|
||||
break
|
||||
|
||||
last_end = 0.0
|
||||
for token in tokens:
|
||||
start = max(last_end + 0.01, token.start)
|
||||
token.start = start
|
||||
token.end = max(start, token.end)
|
||||
last_end = token.end
|
||||
|
||||
|
||||
ind_last_speaker = 0
|
||||
for segment in segments_concatenated:
|
||||
for i, token in enumerate(tokens[ind_last_speaker:]):
|
||||
if token.end <= segment['end']:
|
||||
token.speaker = segment['speaker']
|
||||
ind_last_speaker = i + 1
|
||||
print(f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})")
|
||||
print(
|
||||
f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) "
|
||||
f"assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})"
|
||||
)
|
||||
elif token.start > segment['end']:
|
||||
break
|
||||
break
|
||||
return tokens
|
||||
|
||||
|
||||
def visualize_tokens(tokens):
|
||||
conversation = [{"speaker": -1, "text": ""}]
|
||||
for token in tokens:
|
||||
speaker = conversation[-1]['speaker']
|
||||
if token.speaker != speaker:
|
||||
conversation.append({"speaker": token.speaker, "text": token.text})
|
||||
else:
|
||||
conversation[-1]['text'] += token.text
|
||||
print("Conversation:")
|
||||
for entry in conversation:
|
||||
print(f"Speaker {entry['speaker']}: {entry['text']}")
|
||||
Reference in New Issue
Block a user