From b7d20a0ff0feb06ba7171a2055c2d1830572768c Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Wed, 19 Nov 2025 21:10:28 +0100 Subject: [PATCH] segment attribution in result formatter --- whisperlivekit/audio_processor.py | 12 +- whisperlivekit/diarization/diart_backend.py | 27 ----- .../diarization/sortformer_backend.py | 113 ------------------ whisperlivekit/result_diarization.md | 60 ++++++++++ whisperlivekit/results_formater.py | 110 ++++++++++++++++- whisperlivekit/timed_objects.py | 1 + 6 files changed, 174 insertions(+), 149 deletions(-) create mode 100644 whisperlivekit/result_diarization.md diff --git a/whisperlivekit/audio_processor.py b/whisperlivekit/audio_processor.py index 3150c7e..5bbba26 100644 --- a/whisperlivekit/audio_processor.py +++ b/whisperlivekit/audio_processor.py @@ -417,14 +417,12 @@ class AudioProcessor: to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec) await self.transcription_queue.put(to_transcript) self.last_end = last_segment.end - elif not self.diarization_before_transcription: + elif not self.diarization_before_transcription: + segments = diarization_obj.get_segments() async with self.lock: - self.state.tokens = diarization_obj.assign_speakers_to_tokens( - self.state.tokens, - use_punctuation_split=self.args.punctuation_split - ) - if len(self.state.tokens) > 0: - self.state.end_attributed_speaker = max(self.state.tokens[-1].end, self.state.end_attributed_speaker) + self.state.speaker_segments = segments.copy() + if segments: + self.state.end_attributed_speaker = max(seg.end for seg in segments) except Exception as e: logger.warning(f"Exception in diarization_processor: {e}") diff --git a/whisperlivekit/diarization/diart_backend.py b/whisperlivekit/diarization/diart_backend.py index 6c578cb..55a7d3c 100644 --- a/whisperlivekit/diarization/diart_backend.py +++ b/whisperlivekit/diarization/diart_backend.py @@ -178,7 +178,6 @@ class DiartDiarization: self.pipeline = SpeakerDiarization(config=config) self.observer = DiarizationObserver() - self.lag_diart = None if use_microphone: self.source = MicrophoneAudioSource(block_duration=block_duration) @@ -217,32 +216,6 @@ class DiartDiarization: if self.custom_source: self.custom_source.close() - def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> float: - """ - Assign speakers to tokens based on timing overlap with speaker segments. - Uses the segments collected by the observer. - - If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries. - """ - segments = self.observer.get_segments() - - # Debug logging - logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens") - logger.debug(f"Available segments: {len(segments)}") - for i, seg in enumerate(segments[:5]): # Show first 5 segments - logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]") - - if not self.lag_diart and segments and tokens: - self.lag_diart = segments[0].start - tokens[0].start - - if not use_punctuation_split: - for token in tokens: - for segment in segments: - if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart): - token.speaker = extract_number(segment.speaker) + 1 - else: - tokens = add_speaker_to_tokens(segments, tokens) - return tokens def concatenate_speakers(segments): segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}] diff --git a/whisperlivekit/diarization/sortformer_backend.py b/whisperlivekit/diarization/sortformer_backend.py index d835df0..4d79156 100644 --- a/whisperlivekit/diarization/sortformer_backend.py +++ b/whisperlivekit/diarization/sortformer_backend.py @@ -279,119 +279,6 @@ class SortformerDiarizationOnline: except Exception as e: logger.error(f"Error processing predictions: {e}") - def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list: - """ - Assign speakers to tokens based on timing overlap with speaker segments. - - Args: - tokens: List of tokens with timing information - use_punctuation_split: Whether to use punctuation for boundary refinement - - Returns: - List of tokens with speaker assignments - Last speaker_segment - """ - with self.segment_lock: - segments = self.speaker_segments.copy() - - if not segments or not tokens: - logger.debug("No segments or tokens available for speaker assignment") - return tokens - - logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments") - use_punctuation_split = False - if not use_punctuation_split: - # Simple overlap-based assignment - for token in tokens: - token.speaker = -1 # Default to no speaker - for segment in segments: - # Check for timing overlap - if not (segment.end <= token.start or segment.start >= token.end): - token.speaker = segment.speaker + 1 # Convert to 1-based indexing - break - else: - # Use punctuation-aware assignment (similar to diart_backend) - tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens) - - return tokens - - def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list: - """ - Assign speakers to tokens with punctuation-aware boundary adjustment. - - Args: - segments: List of speaker segments - tokens: List of tokens to assign speakers to - - Returns: - List of tokens with speaker assignments - """ - punctuation_marks = {'.', '!', '?'} - punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks] - - # Convert segments to concatenated format - segments_concatenated = self._concatenate_speakers(segments) - - # Adjust segment boundaries based on punctuation - for ind, segment in enumerate(segments_concatenated): - for i, punctuation_token in enumerate(punctuation_tokens): - if punctuation_token.start > segment['end']: - after_length = punctuation_token.start - segment['end'] - before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf') - - if before_length > after_length: - segment['end'] = punctuation_token.start - if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated): - segments_concatenated[ind + 1]['begin'] = punctuation_token.start - else: - segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end'] - if i < len(punctuation_tokens) - 1 and ind - 1 >= 0: - segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end - break - - # Ensure non-overlapping tokens - last_end = 0.0 - for token in tokens: - start = max(last_end + 0.01, token.start) - token.start = start - token.end = max(start, token.end) - last_end = token.end - - # Assign speakers based on adjusted segments - ind_last_speaker = 0 - for segment in segments_concatenated: - for i, token in enumerate(tokens[ind_last_speaker:]): - if token.end <= segment['end']: - token.speaker = segment['speaker'] - ind_last_speaker = i + 1 - elif token.start > segment['end']: - break - - return tokens - - def _concatenate_speakers(self, segments: List[SpeakerSegment]) -> List[dict]: - """ - Concatenate consecutive segments from the same speaker. - - Args: - segments: List of speaker segments - - Returns: - List of concatenated speaker segments - """ - if not segments: - return [] - - segments_concatenated = [{"speaker": segments[0].speaker + 1, "begin": segments[0].start, "end": segments[0].end}] - - for segment in segments[1:]: - speaker = segment.speaker + 1 - if segments_concatenated[-1]['speaker'] != speaker: - segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end}) - else: - segments_concatenated[-1]['end'] = segment.end - - return segments_concatenated def get_segments(self) -> List[SpeakerSegment]: """Get a copy of the current speaker segments.""" diff --git a/whisperlivekit/result_diarization.md b/whisperlivekit/result_diarization.md new file mode 100644 index 0000000..78607b3 --- /dev/null +++ b/whisperlivekit/result_diarization.md @@ -0,0 +1,60 @@ +########### WHAT IS PRODUCED: ########### + +SPEAKER 1 0:00:04 - 0:00:06 +Transcription technology has improved so much in the past + +SPEAKER 1 0:00:07 - 0:00:12 +years. Have you noticed how accurate real-time speech detects is now? + +SPEAKER 2 0:00:12 - 0:00:12 +Absolutely + +SPEAKER 1 0:00:13 - 0:00:13 +. + +SPEAKER 2 0:00:14 - 0:00:14 +I + +SPEAKER 1 0:00:14 - 0:00:17 +use it all the time for taking notes during meetings. + +SPEAKER 2 0:00:17 - 0:00:17 +It + +SPEAKER 1 0:00:17 - 0:00:22 +'s amazing how it can recognize different speakers, and even add punctuation. + +SPEAKER 2 0:00:22 - 0:00:22 +Yeah + +SPEAKER 1 0:00:23 - 0:00:26 +, but sometimes noise can still cause mistakes. + +SPEAKER 3 0:00:26 - 0:00:27 +Does + +SPEAKER 1 0:00:27 - 0:00:28 +this system handle that + +SPEAKER 1 0:00:29 - 0:00:29 +? + +SPEAKER 3 0:00:29 - 0:00:29 +It + +SPEAKER 1 0:00:29 - 0:00:33 +does a pretty good job filtering noise, especially with models that use voice activity + +########### WHAT SHOULD BE PRODUCED: ########### + +SPEAKER 1 0:00:04 - 0:00:12 +Transcription technology has improved so much in the past years. Have you noticed how accurate real-time speech detects is now? + +SPEAKER 2 0:00:12 - 0:00:22 +Absolutely. I use it all the time for taking notes during meetings. It's amazing how it can recognize different speakers, and even add punctuation. + +SPEAKER 3 0:00:22 - 0:00:28 +Yeah, but sometimes noise can still cause mistakes. Does this system handle that well? + +SPEAKER 1 0:00:29 - 0:00:29 +It does a pretty good job filtering noise, especially with models that use voice activity \ No newline at end of file diff --git a/whisperlivekit/results_formater.py b/whisperlivekit/results_formater.py index 86622fb..df83601 100644 --- a/whisperlivekit/results_formater.py +++ b/whisperlivekit/results_formater.py @@ -1,7 +1,9 @@ import logging +import re from whisperlivekit.remove_silences import handle_silences -from whisperlivekit.timed_objects import Line, format_time +from whisperlivekit.timed_objects import Line, format_time, SpeakerSegment +from typing import List logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -44,7 +46,106 @@ def append_token_to_last_line(lines, sep, token): lines[-1].end = token.end if not lines[-1].detected_language and token.detected_language: lines[-1].detected_language = token.detected_language - + +def extract_number(s) -> int: + """Extract number from speaker string (for diart compatibility).""" + if isinstance(s, int): + return s + m = re.search(r'\d+', str(s)) + return int(m.group()) if m else 0 + +def concatenate_speakers(segments: List[SpeakerSegment]) -> List[dict]: + """Concatenate consecutive segments from the same speaker.""" + if not segments: + return [] + + # Get speaker number from first segment + first_speaker = extract_number(segments[0].speaker) + segments_concatenated = [{"speaker": first_speaker + 1, "begin": segments[0].start, "end": segments[0].end}] + + for segment in segments[1:]: + speaker = extract_number(segment.speaker) + 1 + if segments_concatenated[-1]['speaker'] != speaker: + segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end}) + else: + segments_concatenated[-1]['end'] = segment.end + + return segments_concatenated + +def add_speaker_to_tokens_with_punctuation(segments: List[SpeakerSegment], tokens: list) -> list: + """Assign speakers to tokens with punctuation-aware boundary adjustment.""" + punctuation_marks = {'.', '!', '?'} + punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks] + segments_concatenated = concatenate_speakers(segments) + + for ind, segment in enumerate(segments_concatenated): + for i, punctuation_token in enumerate(punctuation_tokens): + if punctuation_token.start > segment['end']: + after_length = punctuation_token.start - segment['end'] + before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf') + if before_length > after_length: + segment['end'] = punctuation_token.start + if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated): + segments_concatenated[ind + 1]['begin'] = punctuation_token.start + else: + segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end'] + if i < len(punctuation_tokens) - 1 and ind - 1 >= 0: + segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end + break + + # Ensure non-overlapping tokens + last_end = 0.0 + for token in tokens: + start = max(last_end + 0.01, token.start) + token.start = start + token.end = max(start, token.end) + last_end = token.end + + # Assign speakers based on adjusted segments + ind_last_speaker = 0 + for segment in segments_concatenated: + for i, token in enumerate(tokens[ind_last_speaker:]): + if token.end <= segment['end']: + token.speaker = segment['speaker'] + ind_last_speaker = i + 1 + elif token.start > segment['end']: + break + + return tokens + +def assign_speakers_to_tokens(tokens: list, segments: List[SpeakerSegment], use_punctuation_split: bool = False) -> list: + """ + Assign speakers to tokens based on timing overlap with speaker segments. + + Args: + tokens: List of tokens with timing information + segments: List of speaker segments + use_punctuation_split: Whether to use punctuation for boundary refinement + + Returns: + List of tokens with speaker assignments + """ + if not segments or not tokens: + logger.debug("No segments or tokens available for speaker assignment") + return tokens + + logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments") + + if not use_punctuation_split: + # Simple overlap-based assignment + for token in tokens: + token.speaker = -1 # Default to no speaker + for segment in segments: + # Check for timing overlap + if not (segment.end <= token.start or segment.start >= token.end): + speaker_num = extract_number(segment.speaker) + token.speaker = speaker_num + 1 # Convert to 1-based indexing + break + else: + # Use punctuation-aware assignment + tokens = add_speaker_to_tokens_with_punctuation(segments, tokens) + + return tokens def format_output(state, silence, args, sep): diarization = args.diarization @@ -56,6 +157,11 @@ def format_output(state, silence, args, sep): last_speaker = abs(state.last_speaker) undiarized_text = [] tokens = handle_silences(tokens, state.beg_loop, silence) + + # Assign speakers to tokens based on segments stored in state + if diarization and state.speaker_segments: + use_punctuation_split = args.punctuation_split if hasattr(args, 'punctuation_split') else False + tokens = assign_speakers_to_tokens(tokens, state.speaker_segments, use_punctuation_split=use_punctuation_split) for i in range(last_validated_token, len(tokens)): token = tokens[i] speaker = int(token.speaker) diff --git a/whisperlivekit/timed_objects.py b/whisperlivekit/timed_objects.py index 2a73e8b..66e1b14 100644 --- a/whisperlivekit/timed_objects.py +++ b/whisperlivekit/timed_objects.py @@ -185,6 +185,7 @@ class State(): translation_validated_segments: list = field(default_factory=list) buffer_translation: str = field(default_factory=Transcript) buffer_transcription: str = field(default_factory=Transcript) + speaker_segments: list = field(default_factory=list) end_buffer: float = 0.0 end_attributed_speaker: float = 0.0 remaining_time_transcription: float = 0.0