mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 14:23:18 +00:00
segment attribution in result formatter
This commit is contained in:
@@ -417,14 +417,12 @@ class AudioProcessor:
|
||||
to_transcript, self.cumulative_pcm = cut_at(self.cumulative_pcm, cut_sec)
|
||||
await self.transcription_queue.put(to_transcript)
|
||||
self.last_end = last_segment.end
|
||||
elif not self.diarization_before_transcription:
|
||||
elif not self.diarization_before_transcription:
|
||||
segments = diarization_obj.get_segments()
|
||||
async with self.lock:
|
||||
self.state.tokens = diarization_obj.assign_speakers_to_tokens(
|
||||
self.state.tokens,
|
||||
use_punctuation_split=self.args.punctuation_split
|
||||
)
|
||||
if len(self.state.tokens) > 0:
|
||||
self.state.end_attributed_speaker = max(self.state.tokens[-1].end, self.state.end_attributed_speaker)
|
||||
self.state.speaker_segments = segments.copy()
|
||||
if segments:
|
||||
self.state.end_attributed_speaker = max(seg.end for seg in segments)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in diarization_processor: {e}")
|
||||
|
||||
@@ -178,7 +178,6 @@ class DiartDiarization:
|
||||
|
||||
self.pipeline = SpeakerDiarization(config=config)
|
||||
self.observer = DiarizationObserver()
|
||||
self.lag_diart = None
|
||||
|
||||
if use_microphone:
|
||||
self.source = MicrophoneAudioSource(block_duration=block_duration)
|
||||
@@ -217,32 +216,6 @@ class DiartDiarization:
|
||||
if self.custom_source:
|
||||
self.custom_source.close()
|
||||
|
||||
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
    """
    Assign speakers to tokens based on timing overlap with speaker segments.

    Uses the segments collected by the observer.

    If use_punctuation_split is True, uses punctuation marks to refine speaker boundaries.

    Args:
        tokens: tokens with `.start`/`.end` timing; `.speaker` is set in place.
        use_punctuation_split: whether to refine boundaries at punctuation.

    Returns:
        The token list with speaker assignments (fix: annotation was `-> float`
        although the method has always returned the token list).
    """
    segments = self.observer.get_segments()

    # Debug logging
    logger.debug(f"assign_speakers_to_tokens called with {len(tokens)} tokens")
    logger.debug(f"Available segments: {len(segments)}")
    for i, seg in enumerate(segments[:5]):  # Show first 5 segments
        logger.debug(f"  Segment {i}: {seg.speaker} [{seg.start:.2f}-{seg.end:.2f}]")

    # Estimate the constant lag between the diarization clock and the
    # transcription clock from the first segment/token pair; computed once.
    if not self.lag_diart and segments and tokens:
        self.lag_diart = segments[0].start - tokens[0].start

    if not use_punctuation_split:
        for token in tokens:
            for segment in segments:
                # Overlap test with the diarization lag applied to token times.
                if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
                    token.speaker = extract_number(segment.speaker) + 1
    else:
        tokens = add_speaker_to_tokens(segments, tokens)
    return tokens
|
||||
|
||||
def concatenate_speakers(segments):
|
||||
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
|
||||
|
||||
@@ -279,119 +279,6 @@ class SortformerDiarizationOnline:
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing predictions: {e}")
|
||||
|
||||
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
    """
    Assign speakers to tokens based on timing overlap with speaker segments.

    Args:
        tokens: List of tokens with timing information
        use_punctuation_split: Whether to use punctuation for boundary refinement

    Returns:
        List of tokens with speaker assignments
    """
    # Snapshot the segments under the lock so the diarization thread can keep
    # appending while we work on a consistent copy.
    with self.segment_lock:
        segments = self.speaker_segments.copy()

    if not segments or not tokens:
        logger.debug("No segments or tokens available for speaker assignment")
        return tokens

    logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")

    # Fix: the previous version forced `use_punctuation_split = False` here,
    # which silently ignored the caller's argument and made the
    # punctuation-aware path unreachable.
    if not use_punctuation_split:
        # Simple overlap-based assignment: first overlapping segment wins.
        for token in tokens:
            token.speaker = -1  # Default to no speaker
            for segment in segments:
                # Check for timing overlap
                if not (segment.end <= token.start or segment.start >= token.end):
                    token.speaker = segment.speaker + 1  # Convert to 1-based indexing
                    break
    else:
        # Use punctuation-aware assignment (similar to diart_backend)
        tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens)

    return tokens
|
||||
|
||||
def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
    """
    Assign speakers to tokens with punctuation-aware boundary adjustment.

    Diarization boundaries rarely fall exactly on sentence ends; this snaps
    each concatenated segment boundary to the nearest sentence-ending
    punctuation token, then assigns speakers using the adjusted boundaries.

    Args:
        segments: List of speaker segments
        tokens: List of tokens to assign speakers to (mutated in place)

    Returns:
        List of tokens with speaker assignments
    """
    # Sentence-ending marks used as candidate speaker-change boundaries.
    punctuation_marks = {'.', '!', '?'}
    punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]

    # Convert segments to concatenated format
    segments_concatenated = self._concatenate_speakers(segments)

    # Adjust segment boundaries based on punctuation
    for ind, segment in enumerate(segments_concatenated):
        for i, punctuation_token in enumerate(punctuation_tokens):
            if punctuation_token.start > segment['end']:
                # Gap from the segment end forward to this punctuation mark...
                after_length = punctuation_token.start - segment['end']
                # ...and back to the previous one (infinite when none exists).
                before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')

                if before_length > after_length:
                    # Next punctuation is closer: move the boundary forward,
                    # and keep the following segment's start in sync.
                    segment['end'] = punctuation_token.start
                    if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
                        segments_concatenated[ind + 1]['begin'] = punctuation_token.start
                else:
                    # Previous punctuation is closer: move the boundary back.
                    # NOTE(review): when i == 0 the inner guard still reads
                    # punctuation_tokens[i - 1] (i.e. [-1]) — confirm intended.
                    segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
                    if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
                        segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
                break

    # Ensure non-overlapping tokens
    last_end = 0.0
    for token in tokens:
        # Force a minimum 10 ms gap so token spans never overlap.
        start = max(last_end + 0.01, token.start)
        token.start = start
        token.end = max(start, token.end)
        last_end = token.end

    # Assign speakers based on adjusted segments
    ind_last_speaker = 0
    for segment in segments_concatenated:
        for i, token in enumerate(tokens[ind_last_speaker:]):
            if token.end <= segment['end']:
                token.speaker = segment['speaker']
                # NOTE(review): `i` indexes the slice, not `tokens`, so this
                # resets rather than advances the window — presumably should
                # be `ind_last_speaker + i + 1`; confirm before changing.
                ind_last_speaker = i + 1
            elif token.start > segment['end']:
                break

    return tokens
|
||||
|
||||
def _concatenate_speakers(self, segments: List[SpeakerSegment]) -> List[dict]:
    """
    Concatenate consecutive segments from the same speaker.

    Args:
        segments: List of speaker segments

    Returns:
        List of concatenated speaker segments as dicts with
        'speaker' (1-based), 'begin' and 'end' keys
    """
    if not segments:
        return []

    head = segments[0]
    merged = [{"speaker": head.speaker + 1, "begin": head.start, "end": head.end}]

    for seg in segments[1:]:
        current = seg.speaker + 1
        if merged[-1]['speaker'] == current:
            # Same speaker keeps talking: stretch the previous entry.
            merged[-1]['end'] = seg.end
        else:
            # Speaker changed: open a new entry.
            merged.append({"speaker": current, "begin": seg.start, "end": seg.end})

    return merged
|
||||
|
||||
def get_segments(self) -> List[SpeakerSegment]:
|
||||
"""Get a copy of the current speaker segments."""
|
||||
|
||||
60
whisperlivekit/result_diarization.md
Normal file
60
whisperlivekit/result_diarization.md
Normal file
@@ -0,0 +1,60 @@
|
||||
########### WHAT IS PRODUCED: ###########
|
||||
|
||||
SPEAKER 1 0:00:04 - 0:00:06
|
||||
Transcription technology has improved so much in the past
|
||||
|
||||
SPEAKER 1 0:00:07 - 0:00:12
|
||||
years. Have you noticed how accurate real-time speech detects is now?
|
||||
|
||||
SPEAKER 2 0:00:12 - 0:00:12
|
||||
Absolutely
|
||||
|
||||
SPEAKER 1 0:00:13 - 0:00:13
|
||||
.
|
||||
|
||||
SPEAKER 2 0:00:14 - 0:00:14
|
||||
I
|
||||
|
||||
SPEAKER 1 0:00:14 - 0:00:17
|
||||
use it all the time for taking notes during meetings.
|
||||
|
||||
SPEAKER 2 0:00:17 - 0:00:17
|
||||
It
|
||||
|
||||
SPEAKER 1 0:00:17 - 0:00:22
|
||||
's amazing how it can recognize different speakers, and even add punctuation.
|
||||
|
||||
SPEAKER 2 0:00:22 - 0:00:22
|
||||
Yeah
|
||||
|
||||
SPEAKER 1 0:00:23 - 0:00:26
|
||||
, but sometimes noise can still cause mistakes.
|
||||
|
||||
SPEAKER 3 0:00:26 - 0:00:27
|
||||
Does
|
||||
|
||||
SPEAKER 1 0:00:27 - 0:00:28
|
||||
this system handle that
|
||||
|
||||
SPEAKER 1 0:00:29 - 0:00:29
|
||||
?
|
||||
|
||||
SPEAKER 3 0:00:29 - 0:00:29
|
||||
It
|
||||
|
||||
SPEAKER 1 0:00:29 - 0:00:33
|
||||
does a pretty good job filtering noise, especially with models that use voice activity
|
||||
|
||||
########### WHAT SHOULD BE PRODUCED: ###########
|
||||
|
||||
SPEAKER 1 0:00:04 - 0:00:12
|
||||
Transcription technology has improved so much in the past years. Have you noticed how accurate real-time speech detects is now?
|
||||
|
||||
SPEAKER 2 0:00:12 - 0:00:22
|
||||
Absolutely. I use it all the time for taking notes during meetings. It's amazing how it can recognize different speakers, and even add punctuation.
|
||||
|
||||
SPEAKER 3 0:00:22 - 0:00:28
|
||||
Yeah, but sometimes noise can still cause mistakes. Does this system handle that well?
|
||||
|
||||
SPEAKER 1 0:00:29 - 0:00:33
|
||||
It does a pretty good job filtering noise, especially with models that use voice activity
|
||||
@@ -1,7 +1,9 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from whisperlivekit.remove_silences import handle_silences
|
||||
from whisperlivekit.timed_objects import Line, format_time
|
||||
from whisperlivekit.timed_objects import Line, format_time, SpeakerSegment
|
||||
from typing import List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
@@ -44,7 +46,106 @@ def append_token_to_last_line(lines, sep, token):
|
||||
lines[-1].end = token.end
|
||||
if not lines[-1].detected_language and token.detected_language:
|
||||
lines[-1].detected_language = token.detected_language
|
||||
|
||||
|
||||
def extract_number(s) -> int:
    """Extract number from speaker string (for diart compatibility)."""
    # Integers pass through untouched.
    if isinstance(s, int):
        return s
    # Otherwise pull the first run of digits out of the string form.
    match = re.search(r'\d+', str(s))
    if match is None:
        return 0
    return int(match.group())
|
||||
|
||||
def concatenate_speakers(segments: List[SpeakerSegment]) -> List[dict]:
    """Concatenate consecutive segments from the same speaker."""
    if not segments:
        return []

    # Seed the merged list with the first segment (1-based speaker number).
    head = segments[0]
    merged = [{"speaker": extract_number(head.speaker) + 1, "begin": head.start, "end": head.end}]

    for seg in segments[1:]:
        current = extract_number(seg.speaker) + 1
        if merged[-1]['speaker'] == current:
            # Same speaker continues: extend the last entry's end time.
            merged[-1]['end'] = seg.end
        else:
            merged.append({"speaker": current, "begin": seg.start, "end": seg.end})

    return merged
|
||||
|
||||
def add_speaker_to_tokens_with_punctuation(segments: List[SpeakerSegment], tokens: list) -> list:
    """Assign speakers to tokens with punctuation-aware boundary adjustment.

    Snaps each concatenated speaker-segment boundary to the nearest
    sentence-ending punctuation token, then assigns speakers to tokens from
    the adjusted boundaries. Tokens are mutated in place.
    """
    # Sentence-ending marks treated as candidate speaker-change boundaries.
    punctuation_marks = {'.', '!', '?'}
    punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
    segments_concatenated = concatenate_speakers(segments)

    for ind, segment in enumerate(segments_concatenated):
        for i, punctuation_token in enumerate(punctuation_tokens):
            if punctuation_token.start > segment['end']:
                # Gap from the segment end forward to this punctuation mark...
                after_length = punctuation_token.start - segment['end']
                # ...and back to the previous one (infinite when none exists).
                before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
                if before_length > after_length:
                    # Next punctuation is closer: move the boundary forward
                    # and keep the following segment's start in sync.
                    segment['end'] = punctuation_token.start
                    if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
                        segments_concatenated[ind + 1]['begin'] = punctuation_token.start
                else:
                    # Previous punctuation is closer: move the boundary back.
                    segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
                    if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
                        segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
                break

    # Ensure non-overlapping tokens
    last_end = 0.0
    for token in tokens:
        # Force a minimum 10 ms gap so token spans never overlap.
        start = max(last_end + 0.01, token.start)
        token.start = start
        token.end = max(start, token.end)
        last_end = token.end

    # Assign speakers based on adjusted segments
    ind_last_speaker = 0
    for segment in segments_concatenated:
        for i, token in enumerate(tokens[ind_last_speaker:]):
            if token.end <= segment['end']:
                token.speaker = segment['speaker']
                # NOTE(review): `i` is relative to the slice, so this resets the
                # window instead of advancing it — presumably should be
                # `ind_last_speaker + i + 1`; confirm before changing.
                ind_last_speaker = i + 1
            elif token.start > segment['end']:
                break

    return tokens
|
||||
|
||||
def assign_speakers_to_tokens(tokens: list, segments: List[SpeakerSegment], use_punctuation_split: bool = False) -> list:
    """
    Assign speakers to tokens based on timing overlap with speaker segments.

    Args:
        tokens: List of tokens with timing information
        segments: List of speaker segments
        use_punctuation_split: Whether to use punctuation for boundary refinement

    Returns:
        List of tokens with speaker assignments
    """
    # Guard clause: nothing to assign without both inputs.
    if not segments or not tokens:
        logger.debug("No segments or tokens available for speaker assignment")
        return tokens

    logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")

    if use_punctuation_split:
        # Punctuation-aware path: refine boundaries at sentence ends.
        return add_speaker_to_tokens_with_punctuation(segments, tokens)

    # Plain overlap path: the first segment overlapping a token wins.
    for token in tokens:
        token.speaker = -1  # No speaker until an overlap is found
        for segment in segments:
            # Overlap iff the segment ends after the token starts AND
            # starts before the token ends (De Morgan of the original test).
            if segment.end > token.start and segment.start < token.end:
                token.speaker = extract_number(segment.speaker) + 1  # 1-based
                break
    return tokens
|
||||
|
||||
def format_output(state, silence, args, sep):
|
||||
diarization = args.diarization
|
||||
@@ -56,6 +157,11 @@ def format_output(state, silence, args, sep):
|
||||
last_speaker = abs(state.last_speaker)
|
||||
undiarized_text = []
|
||||
tokens = handle_silences(tokens, state.beg_loop, silence)
|
||||
|
||||
# Assign speakers to tokens based on segments stored in state
|
||||
if diarization and state.speaker_segments:
|
||||
use_punctuation_split = args.punctuation_split if hasattr(args, 'punctuation_split') else False
|
||||
tokens = assign_speakers_to_tokens(tokens, state.speaker_segments, use_punctuation_split=use_punctuation_split)
|
||||
for i in range(last_validated_token, len(tokens)):
|
||||
token = tokens[i]
|
||||
speaker = int(token.speaker)
|
||||
|
||||
@@ -185,6 +185,7 @@ class State():
|
||||
translation_validated_segments: list = field(default_factory=list)
|
||||
buffer_translation: str = field(default_factory=Transcript)
|
||||
buffer_transcription: str = field(default_factory=Transcript)
|
||||
speaker_segments: list = field(default_factory=list)
|
||||
end_buffer: float = 0.0
|
||||
end_attributed_speaker: float = 0.0
|
||||
remaining_time_transcription: float = 0.0
|
||||
|
||||
Reference in New Issue
Block a user