mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
375 lines
16 KiB
Python
375 lines
16 KiB
Python
import copy
from time import time
from typing import Any, List, Optional, Tuple, Union

from whisperlivekit.timed_objects import (ASRToken, Segment, SegmentBuffer, PuncSegment, Silence,
                                          SilentSegment, SpeakerSegment,
                                          TimedText)
|
|
|
|
|
|
class TokensAlignment:
    """Align ASR tokens with diarization and translation output into display segments.

    New data is pulled from ``state`` by :meth:`update`; :meth:`get_lines` then turns the
    accumulated tokens into a list of ``Segment`` / ``SilentSegment`` objects, assigns
    them ids, and attaches the transient buffers to the last spoken segment.
    """

    # Minimum duration (seconds) for a silence to be displayed
    MIN_SILENCE_DISPLAY_DURATION = 2.0

    def __init__(self, state: Any, args: Any, sep: Optional[str]) -> None:
        """Create an aligner bound to ``state``.

        Args:
            state: Object whose ``new_tokens``, ``new_diarization``, ``new_translation``,
                ``new_tokens_buffer`` and ``new_translation_buffer`` attributes are
                drained by :meth:`update`.
            args: Configuration object; only ``args.diarization`` is read here.
            sep: Separator appended after each translated chunk; ``None`` means a space.
        """
        self.state = state
        self.diarization = args.diarization
        self._tokens_index: int = 0
        self._diarization_index: int = 0
        self._translation_index: int = 0

        # Everything received so far
        self.all_tokens: List["ASRToken"] = []
        self.all_diarization_segments: List["SpeakerSegment"] = []
        self.all_translation_segments: List[Any] = []

        # Only what arrived with the most recent update()
        self.new_tokens: List["ASRToken"] = []
        self.new_diarization: List["SpeakerSegment"] = []
        self.new_translation: List[Any] = []
        self.new_translation_buffer: Union["TimedText", str] = TimedText()
        self.new_tokens_buffer: List[Any] = []
        self.sep: str = sep if sep is not None else ' '
        # NOTE(review): get_lines computes ``time() - self.beg_loop``, so the owner is
        # expected to set this to a wall-clock reference before silences are handled.
        self.beg_loop: Optional[float] = None

        # Non-diarization mode: finished segments and the tokens of the line in progress
        self.validated_segments: List["Segment"] = []
        self.current_line_tokens: List["ASRToken"] = []
        self.diarization_buffer: List["ASRToken"] = []

        # Incremental punctuation segmentation state
        self.last_punctuation = None
        self.last_uncompleted_punc_segment: Optional["PuncSegment"] = None
        self.tokens_after_last_punctuation: List["ASRToken"] = []
        # Diarization mode: segments whose tokens already have a speaker
        self.all_validated_segments: List["Segment"] = []

        # For token-by-token validation with diarization
        self.pending_tokens: List["ASRToken"] = []
        self.last_validated_token_end: float = 0.0

        # Segment ID counter for the new API
        self._next_segment_id: int = 1

    def update(self) -> None:
        """Drain the state's pending buffers and fold them into the running history."""
        state = self.state
        self.new_tokens, state.new_tokens = state.new_tokens, []
        self.new_diarization, state.new_diarization = state.new_diarization, []
        self.new_translation, state.new_translation = state.new_translation, []
        self.new_tokens_buffer, state.new_tokens_buffer = state.new_tokens_buffer, []

        self.all_tokens.extend(self.new_tokens)
        self.all_diarization_segments.extend(self.new_diarization)
        self.all_translation_segments.extend(self.new_translation)
        self.new_translation_buffer = state.new_translation_buffer

    def add_translation(self, segment: "Segment") -> None:
        """Rebuild ``segment.translation`` from the translated chunks lying within it.

        The value is recomputed from scratch: segments are long-lived objects that
        ``get_lines`` revisits on every update, so appending to the previous value would
        repeat the same text each time.
        """
        segment.translation = ''
        for chunk in self.all_translation_segments:
            if chunk.is_within(segment):
                if chunk.text:
                    segment.translation += chunk.text + self.sep
            elif segment.translation:
                # Chunks are presumably time-ordered: after a run of matches, the first
                # non-matching chunk ends the run.
                break

    @staticmethod
    def _split_on_punctuation(tokens: List["ASRToken"]) -> Tuple[List["PuncSegment"], int]:
        """Split ``tokens`` at silences and punctuation.

        Returns:
            The closed segments (silences become their own segment) and the index of the
            first token not consumed by any closed segment.
        """
        segments = []
        start = 0
        for i, token in enumerate(tokens):
            if token.is_silence():
                preceding = PuncSegment.from_tokens(tokens=tokens[start:i])
                if preceding:
                    segments.append(preceding)
                segments.append(PuncSegment.from_tokens(tokens=[token], is_silence=True))
                start = i + 1
            elif token.has_punctuation():
                segments.append(PuncSegment.from_tokens(tokens=tokens[start:i + 1]))
                start = i + 1
        return segments, start

    def compute_punctuations_segments(self, tokens: Optional[List["ASRToken"]] = None) -> List["PuncSegment"]:
        """Group tokens into segments split by punctuation and explicit silence.

        Args:
            tokens: Tokens to segment; defaults to every token seen so far (``all_tokens``).

        Returns:
            All segments, including a trailing unterminated one if any tokens remain.
        """
        source = self.all_tokens if tokens is None else tokens
        segments, start = self._split_on_punctuation(source)
        trailing = PuncSegment.from_tokens(tokens=source[start:])
        if trailing:
            segments.append(trailing)
        return segments

    def compute_new_punctuations_segments(self) -> List["PuncSegment"]:
        """Return segments closed by the newest tokens, keeping the unterminated tail.

        Tokens after the last closing punctuation/silence are carried over in
        ``tokens_after_last_punctuation`` and re-examined on the next call.
        """
        self.tokens_after_last_punctuation += self.new_tokens
        segments, start = self._split_on_punctuation(self.tokens_after_last_punctuation)
        self.tokens_after_last_punctuation = self.tokens_after_last_punctuation[start:]
        return segments

    def concatenate_diar_segments(self) -> List["SpeakerSegment"]:
        """Merge consecutive diarization slices that share the same speaker.

        The stored ``all_diarization_segments`` are never modified: a slice is copied
        before another slice is merged into it.
        """
        merged: List["SpeakerSegment"] = []
        last_is_copy = False
        for segment in self.all_diarization_segments:
            if merged and segment.speaker == merged[-1].speaker:
                if not last_is_copy:
                    merged[-1] = copy.copy(merged[-1])
                    last_is_copy = True
                # max() guards against an overlapping later slice shrinking the span
                merged[-1].end = max(merged[-1].end, segment.end)
            else:
                merged.append(segment)
                last_is_copy = False
        return merged

    @staticmethod
    def intersection_duration(seg1: "TimedText", seg2: "TimedText") -> float:
        """Return the overlap duration between two timed segments (0 if disjoint)."""
        start = max(seg1.start, seg2.start)
        end = min(seg1.end, seg2.end)
        return max(0.0, end - start)

    def _get_speaker_for_token(self, token: "ASRToken", diarization_segments: List["SpeakerSegment"]) -> Optional[int]:
        """Return the 1-indexed speaker overlapping ``token`` the most, or None.

        None means one of three things: there are no diarization segments, the token
        starts at or after the end of the last segment, or it overlaps no segment.
        """
        if not diarization_segments:
            return None
        if token.start >= diarization_segments[-1].end:
            return None
        best_overlap = 0.0
        best_speaker = None
        for diar_seg in diarization_segments:
            overlap = self.intersection_duration(token, diar_seg)
            if overlap > best_overlap:
                best_overlap = overlap
                best_speaker = diar_seg.speaker + 1  # 1-indexed
        return best_speaker

    def _validate_silence(self, token: "ASRToken") -> None:
        """Record a long silence token in ``all_validated_segments``; drop short ones."""
        duration = (token.end or 0) - (token.start or 0)
        if duration < self.MIN_SILENCE_DISPLAY_DURATION:
            return
        validated = self.all_validated_segments
        if validated and validated[-1].is_silence():
            validated[-1].end = token.end
        else:
            validated.append(SilentSegment(start=token.start, end=token.end))

    def _validate_token(self, token: "ASRToken", speaker: int) -> None:
        """Append a speaker-attributed token, extending the last segment if same speaker."""
        validated = self.all_validated_segments
        last = validated[-1] if validated else None
        if last is not None and not last.is_silence() and last.speaker == speaker:
            last.text += token.text
            last.end = token.end
        else:
            validated.append(Segment(
                start=token.start,
                end=token.end,
                text=token.text,
                speaker=speaker,
                start_speaker=token.start,
                detected_language=token.detected_language
            ))
        self.last_validated_token_end = token.end

    def get_lines_diarization(self) -> Tuple[List["Segment"], str]:
        """Validate pending tokens whose speaker the diarization can now determine.

        Tokens are processed in arrival order. A silence token is held back while an
        earlier word is still waiting for diarization coverage, so that the silence is
        never displayed ahead of words spoken before it.

        Returns:
            The persistent list of validated segments, and the concatenated text of the
            non-silence tokens still waiting for a speaker.
        """
        diarization_segments = self.concatenate_diar_segments()
        coverage_end = diarization_segments[-1].end if diarization_segments else None

        self.pending_tokens.extend(self.new_tokens)

        still_pending = []
        awaiting_coverage = False
        for token in self.pending_tokens:
            if token.is_silence():
                if awaiting_coverage:
                    still_pending.append(token)
                else:
                    self._validate_silence(token)
                continue

            speaker = self._get_speaker_for_token(token, diarization_segments)
            if speaker is None:
                still_pending.append(token)
                # Tokens beyond the current coverage will resolve once diarization
                # catches up; later silences must wait for them.
                if coverage_end is None or token.start >= coverage_end:
                    awaiting_coverage = True
                continue

            self._validate_token(token, speaker)

        self.pending_tokens = still_pending
        diarization_buffer = ''.join(t.text for t in self.pending_tokens if not t.is_silence())
        return self.all_validated_segments, diarization_buffer

    def _assign_segment_ids(self, segments: List["Segment"]) -> None:
        """Give every segment that has no id the next value of the id counter."""
        for segment in segments:
            if segment.id is None:
                segment.id = self._next_segment_id
                self._next_segment_id += 1

    def _assign_buffers_to_last_segment(
        self,
        segments: List["Segment"],
        buffer_transcription: str,
        buffer_diarization: str,
        buffer_translation: str
    ) -> None:
        """Reset every segment's buffer, then give the last non-silent one the live buffers."""
        # Buffers are ephemeral: clear any left over from a previous call
        for segment in segments:
            segment.buffer = SegmentBuffer()

        last_spoken = next((s for s in reversed(segments) if not s.is_silence()), None)
        if last_spoken is not None:
            last_spoken.buffer = SegmentBuffer(
                transcription=buffer_transcription,
                diarization=buffer_diarization,
                translation=buffer_translation
            )

    def _filter_and_merge_segments(self, segments: List["Segment"]) -> List["Segment"]:
        """Drop short silences and merge consecutive silences / same-speaker segments.

        Input segments are never modified. A segment is shallow-copied before anything
        is merged into it, because the inputs are persistent objects revisited by every
        ``get_lines`` call. Merging in place would repeat the merged text on each call.
        """
        if not segments:
            return segments

        result = []
        last_is_copy = False  # whether result[-1] is our own copy, safe to mutate
        for seg in segments:
            if seg.is_silence():
                duration = (seg.end or 0) - (seg.start or 0)
                if duration < self.MIN_SILENCE_DISPLAY_DURATION:
                    continue
                if result and result[-1].is_silence():
                    if not last_is_copy:
                        result[-1] = copy.copy(result[-1])
                        last_is_copy = True
                    result[-1].end = seg.end
                    continue
            elif result and not result[-1].is_silence() and result[-1].speaker == seg.speaker:
                # Same speaker on both sides of a filtered silence: merge
                if not last_is_copy:
                    result[-1] = copy.copy(result[-1])
                    last_is_copy = True
                last = result[-1]
                last.text += seg.text
                last.end = seg.end
                if getattr(seg, 'translation', None):
                    last.translation = (last.translation or '') + seg.translation
                continue

            result.append(seg)
            last_is_copy = False

        return result

    def _translation_buffer_text(self) -> str:
        """Return the pending translation text; the buffer may be a TimedText or a str."""
        buffer = self.new_translation_buffer
        if isinstance(buffer, str):
            return buffer
        if buffer:
            return buffer.text or ''
        return ''

    def _consume_tokens_without_diarization(self) -> None:
        """Fold ``new_tokens`` into the current line, closing it on a long silence.

        NOTE(review): this consumes ``self.new_tokens`` without clearing it, so it
        assumes ``get_lines`` runs once per ``update()``.
        """
        for token in self.new_tokens:
            if not token.is_silence():
                self.current_line_tokens.append(token)
                continue

            silence_duration = (token.end or 0) - (token.start or 0)
            if silence_duration < self.MIN_SILENCE_DISPLAY_DURATION:
                continue  # short pauses are dropped

            if self.current_line_tokens:
                self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
                self.current_line_tokens = []

            end_silence = token.end if token.has_ended else time() - self.beg_loop
            if self.validated_segments and self.validated_segments[-1].is_silence():
                self.validated_segments[-1].end = end_silence
            else:
                self.validated_segments.append(SilentSegment(
                    start=token.start,
                    end=end_silence
                ))

    def _append_ongoing_silence(self, segments: List["Segment"], current_silence: "Silence") -> None:
        """Show a still-running silence at the end of ``segments`` once it is long enough.

        Only ``segments`` (a working copy) is modified. When the last entry is already
        a silence, it is replaced by an extended copy that keeps the same id.
        """
        silence_duration = (current_silence.end or time() - self.beg_loop) - (current_silence.start or 0)
        if silence_duration < self.MIN_SILENCE_DISPLAY_DURATION:
            return

        end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
        if segments and segments[-1].is_silence():
            previous = segments[-1]
            extended = SilentSegment(start=previous.start, end=end_silence)
            # Same displayed silence, only longer: keep its id stable across updates
            extended.id = previous.id
            segments[-1] = extended
        else:
            segments.append(SilentSegment(
                start=current_silence.start,
                end=end_silence
            ))

    def get_lines(
        self,
        diarization: bool = False,
        translation: bool = False,
        current_silence: Optional["Silence"] = None,
        buffer_transcription: str = ''
    ) -> List["Segment"]:
        """Return the display segments, with per-segment buffers.

        Args:
            diarization: Attribute tokens to speakers via diarization.
            translation: Attach translated text to every non-silent segment.
            current_silence: A silence still in progress; displayed if long enough.
            buffer_transcription: Unvalidated transcription text for the last segment.

        Returns:
            A fresh list of segments, filtered and merged, with ids and buffers assigned.
        """
        diarization_buffer = ''

        if diarization:
            validated, diarization_buffer = self.get_lines_diarization()
            # Work on a copy so the provisional ongoing silence added below does not
            # leak into the persistent ``all_validated_segments``.
            segments = list(validated)
        else:
            self._consume_tokens_without_diarization()
            segments = list(self.validated_segments)
            if self.current_line_tokens:
                segments.append(Segment().from_tokens(self.current_line_tokens))

        if current_silence:
            self._append_ongoing_silence(segments, current_silence)

        if translation:
            for segment in segments:
                if not segment.is_silence():
                    self.add_translation(segment)

        translation_buffer = self._translation_buffer_text()

        # Filter parasitic silences and merge same-speaker segments
        segments = self._filter_and_merge_segments(segments)

        # Assign unique IDs to all segments
        self._assign_segment_ids(segments)

        # Assign buffers to the last active segment
        self._assign_buffers_to_last_segment(
            segments,
            buffer_transcription=buffer_transcription,
            buffer_diarization=diarization_buffer,
            buffer_translation=translation_buffer
        )

        return segments