diff --git a/whisperlivekit/diarization/diarization_online.py b/whisperlivekit/diarization/diarization_online.py index 04ce0ac..9391104 100644 --- a/whisperlivekit/diarization/diarization_online.py +++ b/whisperlivekit/diarization/diarization_online.py @@ -238,78 +238,73 @@ class DiartDiarization: end_attributed_speaker = max(token.end, end_attributed_speaker) if use_punctuation_split and len(tokens) > 1: - punctuation_marks = {'.', '!', '?'} - - print("Here are the tokens:", - [(t.text, t.start, t.end, t.speaker) for t in tokens[:10]]) - - segment_map = [] - for segment in segments: - speaker_num = extract_number(segment.speaker) + 1 - segment_map.append((segment.start, segment.end, speaker_num)) - segment_map.sort(key=lambda x: x[0]) - - i = 0 - while i < len(tokens): - current_token = tokens[i] - - is_sentence_end = False - if current_token.text and current_token.text.strip(): - text = current_token.text.strip() - if text[-1] in punctuation_marks: - is_sentence_end = True - logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s") - - if is_sentence_end and current_token.speaker != -1: - punctuation_time = current_token.end - current_speaker = current_token.speaker - - j = i + 1 - next_sentence_tokens = [] - while j < len(tokens): - next_token = tokens[j] - next_sentence_tokens.append(j) - - # Check if this token ends the next sentence - if next_token.text and next_token.text.strip(): - if next_token.text.strip()[-1] in punctuation_marks: - break - j += 1 - - if next_sentence_tokens: - speaker_times = {} - - for idx in next_sentence_tokens: - token = tokens[idx] - # Find which segments overlap with this token - for seg_start, seg_end, seg_speaker in segment_map: - if not (seg_end <= token.start or seg_start >= token.end): - # Calculate overlap duration - overlap_start = max(seg_start, token.start) - overlap_end = min(seg_end, token.end) - overlap_duration = overlap_end - overlap_start - - if seg_speaker not in speaker_times: - 
def visualize_tokens(tokens):
    """Print the token stream grouped into consecutive same-speaker turns.

    Debugging helper. Adjacent tokens sharing the same ``speaker`` value are
    joined into a single ``Speaker <id>: <text>`` line.

    Args:
        tokens: Iterable of objects exposing ``speaker`` and ``text``.
    """
    from itertools import groupby

    print("Conversation:")
    # groupby only merges *adjacent* equal keys, which is exactly a "turn".
    for speaker, turn in groupby(tokens, key=lambda token: token.speaker):
        text = "".join(token.text for token in turn)
        print(f"Speaker {speaker}: {text}")


def concatenate_speakers(segments):
    """Merge consecutive diarization segments that belong to the same speaker.

    Args:
        segments: Iterable of objects exposing ``speaker`` (a label understood
            by ``extract_number``), ``start`` and ``end``.

    Returns:
        A list of ``{"speaker", "begin", "end"}`` dicts ordered by start time,
        where ``speaker`` is ``extract_number(label) + 1``. The list is empty
        when ``segments`` is empty.
    """
    segments_concatenated = []
    # Merging neighbours only makes sense in chronological order.
    for segment in sorted(segments, key=lambda seg: seg.start):
        speaker = extract_number(segment.speaker) + 1
        if segments_concatenated and segments_concatenated[-1]['speaker'] == speaker:
            current = segments_concatenated[-1]
            # max() keeps the turn's end when a later-starting segment is
            # fully contained in an earlier one.
            current['end'] = max(current['end'], segment.end)
        else:
            segments_concatenated.append(
                {"speaker": speaker, "begin": segment.start, "end": segment.end}
            )
    print("Segments concatenated:")
    for entry in segments_concatenated:
        print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s")
    return segments_concatenated


def _make_timestamps_monotonic(tokens):
    """Shift token times in place so tokens never overlap or run backwards.

    Each token starts at least 0.01 after the previous token's end, and never
    ends before it starts.
    """
    last_end = 0.0
    for token in tokens:
        token.start = max(last_end + 0.01, token.start)
        token.end = max(token.start, token.end)
        last_end = token.end


def _snap_segments_to_punctuation(segments_concatenated, punctuation_tokens):
    """Move each segment end onto the nearest sentence boundary, in place.

    For every segment, the first punctuation token starting after the segment
    end is compared with the punctuation token preceding it; the segment end
    is moved to whichever is closer, and the following segment (if any) is
    made to begin at that same instant.
    """
    last_segment = len(segments_concatenated) - 1
    for ind, segment in enumerate(segments_concatenated):
        for i, punctuation_token in enumerate(punctuation_tokens):
            if punctuation_token.start <= segment['end']:
                continue
            # Explicit check: punctuation_tokens[-1] would silently wrap to
            # the LAST punctuation token when i == 0.
            previous = punctuation_tokens[i - 1] if i > 0 else None
            after_length = punctuation_token.start - segment['end']
            if previous is None or segment['end'] - previous.end > after_length:
                # Next sentence end is closer (or the only candidate). Use the
                # token's end so the punctuation mark stays with its sentence.
                boundary = punctuation_token.end
            else:
                boundary = previous.end
            segment['end'] = boundary
            if ind < last_segment:
                segments_concatenated[ind + 1]['begin'] = boundary
            break


def _assign_speakers(segments_concatenated, tokens):
    """Give every token that ends within a segment that segment's speaker."""
    next_unassigned = 0
    for segment in segments_concatenated:
        # The slice start is fixed when the loop begins; `index` is absolute
        # thanks to enumerate(start=...), so the offset is never lost.
        for index, token in enumerate(tokens[next_unassigned:], start=next_unassigned):
            if token.end <= segment['end']:
                token.speaker = segment['speaker']
                next_unassigned = index + 1
                print(f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})")
            elif token.start > segment['end']:
                break


def add_speaker_to_tokens(segments, tokens):
    """Attribute a speaker to each token, aligning turns on sentence ends.

    Consecutive same-speaker segments are merged, token timestamps are made
    strictly increasing, segment boundaries are snapped to the nearest
    stand-alone punctuation token ('.', '!' or '?'), and finally each token
    ending inside a segment receives that segment's speaker number.

    Tokens are modified in place (``start``, ``end`` and ``speaker``); tokens
    ending after the last segment keep their current ``speaker``.

    Args:
        segments: Diarization segments, as accepted by
            ``concatenate_speakers``.
        tokens: List of objects exposing ``text``, ``start``, ``end`` and
            ``speaker``.
    """
    punctuation_marks = {'.', '!', '?'}
    segments_concatenated = concatenate_speakers(segments)
    punctuation_tokens = [
        token for token in tokens
        if token.text and token.text.strip() in punctuation_marks
    ]

    # Normalise first so boundaries and token times are on the same timeline.
    _make_timestamps_monotonic(tokens)
    _snap_segments_to_punctuation(segments_concatenated, punctuation_tokens)
    _assign_speakers(segments_concatenated, tokens)