segment attribution in result formatter

This commit is contained in:
Quentin Fuxa
2025-11-19 21:10:28 +01:00
parent c1bb9c2bde
commit b7d20a0ff0
6 changed files with 174 additions and 149 deletions

View File

@@ -417,14 +417,12 @@ class AudioProcessor:
to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec)
await self.transcription_queue.put(to_transcript)
self.last_end = last_segment.end
elif not self.diarization_before_transcription:
elif not self.diarization_before_transcription:
segments = diarization_obj.get_segments()
async with self.lock:
self.state.tokens = diarization_obj.assign_speakers_to_tokens(
self.state.tokens,
use_punctuation_split=self.args.punctuation_split
)
if len(self.state.tokens) > 0:
self.state.end_attributed_speaker = max(self.state.tokens[-1].end, self.state.end_attributed_speaker)
self.state.speaker_segments = segments.copy()
if segments:
self.state.end_attributed_speaker = max(seg.end for seg in segments)
except Exception as e:
logger.warning(f"Exception in diarization_processor: {e}")

View File

@@ -178,7 +178,6 @@ class DiartDiarization:
self.pipeline = SpeakerDiarization(config=config)
self.observer = DiarizationObserver()
self.lag_diart = None
if use_microphone:
self.source = MicrophoneAudioSource(block_duration=block_duration)
@@ -217,32 +216,6 @@ class DiartDiarization:
if self.custom_source:
self.custom_source.close()
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> float:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
Uses the segments collected by the observer.
If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.
"""
segments = self.observer.get_segments()
# Debug logging
logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
logger.debug(f"Available segments: {len(segments)}")
for i, seg in enumerate(segments[:5]): # Show first 5 segments
logger.debug(f" Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")
if not self.lag_diart and segments and tokens:
self.lag_diart = segments[0].start - tokens[0].start
if not use_punctuation_split:
for token in tokens:
for segment in segments:
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
token.speaker = extract_number(segment.speaker) + 1
else:
tokens = add_speaker_to_tokens(segments, tokens)
return tokens
def concatenate_speakers(segments):
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]

View File

@@ -279,119 +279,6 @@ class SortformerDiarizationOnline:
except Exception as e:
logger.error(f"Error processing predictions: {e}")
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
Args:
tokens: List of tokens with timing information
use_punctuation_split: Whether to use punctuation for boundary refinement
Returns:
List of tokens with speaker assignments
Last speaker_segment
"""
with self.segment_lock:
segments = self.speaker_segments.copy()
if not segments or not tokens:
logger.debug("No segments or tokens available for speaker assignment")
return tokens
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
use_punctuation_split = False
if not use_punctuation_split:
# Simple overlap-based assignment
for token in tokens:
token.speaker = -1 # Default to no speaker
for segment in segments:
# Check for timing overlap
if not (segment.end <= token.start or segment.start >= token.end):
token.speaker = segment.speaker + 1 # Convert to 1-based indexing
break
else:
# Use punctuation-aware assignment (similar to diart_backend)
tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens)
return tokens
def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
"""
Assign speakers to tokens with punctuation-aware boundary adjustment.
Args:
segments: List of speaker segments
tokens: List of tokens to assign speakers to
Returns:
List of tokens with speaker assignments
"""
punctuation_marks = {'.', '!', '?'}
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
# Convert segments to concatenated format
segments_concatenated = self._concatenate_speakers(segments)
# Adjust segment boundaries based on punctuation
for ind, segment in enumerate(segments_concatenated):
for i, punctuation_token in enumerate(punctuation_tokens):
if punctuation_token.start > segment['end']:
after_length = punctuation_token.start - segment['end']
before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
if before_length > after_length:
segment['end'] = punctuation_token.start
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
else:
segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
break
# Ensure non-overlapping tokens
last_end = 0.0
for token in tokens:
start = max(last_end + 0.01, token.start)
token.start = start
token.end = max(start, token.end)
last_end = token.end
# Assign speakers based on adjusted segments
ind_last_speaker = 0
for segment in segments_concatenated:
for i, token in enumerate(tokens[ind_last_speaker:]):
if token.end <= segment['end']:
token.speaker = segment['speaker']
ind_last_speaker = i + 1
elif token.start > segment['end']:
break
return tokens
def _concatenate_speakers(self, segments: List[SpeakerSegment]) -> List[dict]:
"""
Concatenate consecutive segments from the same speaker.
Args:
segments: List of speaker segments
Returns:
List of concatenated speaker segments
"""
if not segments:
return []
segments_concatenated = [{"speaker": segments[0].speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
for segment in segments[1:]:
speaker = segment.speaker + 1
if segments_concatenated[-1]['speaker'] != speaker:
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
else:
segments_concatenated[-1]['end'] = segment.end
return segments_concatenated
def get_segments(self) -> List[SpeakerSegment]:
"""Get a copy of the current speaker segments."""

View File

@@ -0,0 +1,60 @@
########### WHAT IS PRODUCED: ###########
SPEAKER 1 0:00:04 - 0:00:06
Transcription technology has improved so much in the past
SPEAKER 1 0:00:07 - 0:00:12
years. Have you noticed how accurate real-time speech detects is now?
SPEAKER 2 0:00:12 - 0:00:12
Absolutely
SPEAKER 1 0:00:13 - 0:00:13
.
SPEAKER 2 0:00:14 - 0:00:14
I
SPEAKER 1 0:00:14 - 0:00:17
use it all the time for taking notes during meetings.
SPEAKER 2 0:00:17 - 0:00:17
It
SPEAKER 1 0:00:17 - 0:00:22
's amazing how it can recognize different speakers, and even add punctuation.
SPEAKER 2 0:00:22 - 0:00:22
Yeah
SPEAKER 1 0:00:23 - 0:00:26
, but sometimes noise can still cause mistakes.
SPEAKER 3 0:00:26 - 0:00:27
Does
SPEAKER 1 0:00:27 - 0:00:28
this system handle that
SPEAKER 1 0:00:29 - 0:00:29
?
SPEAKER 3 0:00:29 - 0:00:29
It
SPEAKER 1 0:00:29 - 0:00:33
does a pretty good job filtering noise, especially with models that use voice activity
########### WHAT SHOULD BE PRODUCED: ###########
SPEAKER 1 0:00:04 - 0:00:12
Transcription technology has improved so much in the past years. Have you noticed how accurate real-time speech detects is now?
SPEAKER 2 0:00:12 - 0:00:22
Absolutely. I use it all the time for taking notes during meetings. It's amazing how it can recognize different speakers, and even add punctuation.
SPEAKER 3 0:00:22 - 0:00:28
Yeah, but sometimes noise can still cause mistakes. Does this system handle that well?
SPEAKER 1 0:00:29 - 0:00:29
It does a pretty good job filtering noise, especially with models that use voice activity

View File

@@ -1,7 +1,9 @@
import logging
import re
from whisperlivekit.remove_silences import handle_silences
from whisperlivekit.timed_objects import Line, format_time
from whisperlivekit.timed_objects import Line, format_time, SpeakerSegment
from typing import List
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -44,7 +46,106 @@ def append_token_to_last_line(lines, sep, token):
lines[-1].end = token.end
if not lines[-1].detected_language and token.detected_language:
lines[-1].detected_language = token.detected_language
def extract_number(s) -> int:
"""Extract number from speaker string (for diart compatibility)."""
if isinstance(s, int):
return s
m = re.search(r'\d+', str(s))
return int(m.group()) if m else 0
def concatenate_speakers(segments: List[SpeakerSegment]) -> List[dict]:
"""Concatenate consecutive segments from the same speaker."""
if not segments:
return []
# Get speaker number from first segment
first_speaker = extract_number(segments[0].speaker)
segments_concatenated = [{"speaker": first_speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
for segment in segments[1:]:
speaker = extract_number(segment.speaker) + 1
if segments_concatenated[-1]['speaker'] != speaker:
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
else:
segments_concatenated[-1]['end'] = segment.end
return segments_concatenated
def add_speaker_to_tokens_with_punctuation(segments: List[SpeakerSegment], tokens: list) -> list:
"""Assign speakers to tokens with punctuation-aware boundary adjustment."""
punctuation_marks = {'.', '!', '?'}
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
segments_concatenated = concatenate_speakers(segments)
for ind, segment in enumerate(segments_concatenated):
for i, punctuation_token in enumerate(punctuation_tokens):
if punctuation_token.start > segment['end']:
after_length = punctuation_token.start - segment['end']
before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
if before_length > after_length:
segment['end'] = punctuation_token.start
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
else:
segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
break
# Ensure non-overlapping tokens
last_end = 0.0
for token in tokens:
start = max(last_end + 0.01, token.start)
token.start = start
token.end = max(start, token.end)
last_end = token.end
# Assign speakers based on adjusted segments
ind_last_speaker = 0
for segment in segments_concatenated:
for i, token in enumerate(tokens[ind_last_speaker:]):
if token.end <= segment['end']:
token.speaker = segment['speaker']
ind_last_speaker = i + 1
elif token.start > segment['end']:
break
return tokens
def assign_speakers_to_tokens(tokens: list, segments: List[SpeakerSegment], use_punctuation_split: bool = False) -> list:
"""
Assign speakers to tokens based on timing overlap with speaker segments.
Args:
tokens: List of tokens with timing information
segments: List of speaker segments
use_punctuation_split: Whether to use punctuation for boundary refinement
Returns:
List of tokens with speaker assignments
"""
if not segments or not tokens:
logger.debug("No segments or tokens available for speaker assignment")
return tokens
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
if not use_punctuation_split:
# Simple overlap-based assignment
for token in tokens:
token.speaker = -1 # Default to no speaker
for segment in segments:
# Check for timing overlap
if not (segment.end <= token.start or segment.start >= token.end):
speaker_num = extract_number(segment.speaker)
token.speaker = speaker_num + 1 # Convert to 1-based indexing
break
else:
# Use punctuation-aware assignment
tokens = add_speaker_to_tokens_with_punctuation(segments, tokens)
return tokens
def format_output(state, silence, args, sep):
diarization = args.diarization
@@ -56,6 +157,11 @@ def format_output(state, silence, args, sep):
last_speaker = abs(state.last_speaker)
undiarized_text = []
tokens = handle_silences(tokens, state.beg_loop, silence)
# Assign speakers to tokens based on segments stored in state
if diarization and state.speaker_segments:
use_punctuation_split = args.punctuation_split if hasattr(args, 'punctuation_split') else False
tokens = assign_speakers_to_tokens(tokens, state.speaker_segments, use_punctuation_split=use_punctuation_split)
for i in range(last_validated_token, len(tokens)):
token = tokens[i]
speaker = int(token.speaker)

View File

@@ -185,6 +185,7 @@ class State():
translation_validated_segments: list = field(default_factory=list)
buffer_translation: str = field(default_factory=Transcript)
buffer_transcription: str = field(default_factory=Transcript)
speaker_segments: list = field(default_factory=list)
end_buffer: float = 0.0
end_attributed_speaker: float = 0.0
remaining_time_transcription: float = 0.0