Lines to Segments. Merging dataclasses

This commit is contained in:
Quentin Fuxa
2025-11-27 21:54:58 +01:00
parent 34ddd2ac02
commit c0965c6c31
3 changed files with 95 additions and 74 deletions

View File

@@ -12,7 +12,7 @@ from whisperlivekit.core import (TranscriptionEngine,
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
from whisperlivekit.silero_vad_iterator import FixedVADIterator
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
Line, Silence, State, Transcript)
Segment, Silence, State, Transcript)
from whisperlivekit.tokens_alignment import TokensAlignment
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

View File

@@ -114,6 +114,9 @@ class Segment(TimedText):
end: Optional[float]
text: Optional[str]
speaker: Optional[str]
tokens: Optional[ASRToken] = None
translation: Optional[Translation] = None
@classmethod
def from_tokens(
cls,
@@ -141,17 +144,13 @@ class Segment(TimedText):
speaker=-1,
detected_language=start_token.detected_language
)
def is_silence(self) -> bool:
"""True when this segment represents a silence gap."""
return self.speaker == -2
@dataclass
class Line(TimedText):
translation: str = ''
def to_dict(self) -> Dict[str, Any]:
"""Serialize the line for frontend consumption."""
"""Serialize the segment for frontend consumption."""
_dict: Dict[str, Any] = {
'speaker': int(self.speaker) if self.speaker != -1 else 1,
'text': self.text,
@@ -163,29 +162,13 @@ class Line(TimedText):
if self.detected_language:
_dict['detected_language'] = self.detected_language
return _dict
def build_from_tokens(self, tokens: List[ASRToken]) -> "Line":
"""Populate line attributes from a contiguous token list."""
self.text = ''.join([token.text for token in tokens])
self.start = tokens[0].start
self.end = tokens[-1].end
self.speaker = 1
self.detected_language = tokens[0].detected_language
return self
def build_from_segment(self, segment: Segment) -> "Line":
"""Populate the line fields from a pre-built segment."""
self.text = segment.text
self.start = segment.start
self.end = segment.end
self.speaker = segment.speaker
self.detected_language = segment.detected_language
return self
def is_silent(self) -> bool:
return self.speaker == -2
@dataclass
class PuncSegment(Segment):
pass
class SilentLine(Line):
class SilentSegment(Segment):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.speaker = -2
@@ -196,7 +179,7 @@ class SilentLine(Line):
class FrontData():
status: str = ''
error: str = ''
lines: list[Line] = field(default_factory=list)
lines: list[Segment] = field(default_factory=list)
buffer_transcription: str = ''
buffer_diarization: str = ''
buffer_translation: str = ''

View File

@@ -1,8 +1,8 @@
from time import time
from typing import Any, List, Optional, Tuple, Union
from whisperlivekit.timed_objects import (ASRToken, Line, Segment, Silence,
SilentLine, SpeakerSegment,
from whisperlivekit.timed_objects import (ASRToken, Segment, PuncSegment, Silence,
SilentSegment, SpeakerSegment,
TimedText)
@@ -27,6 +27,14 @@ class TokensAlignment:
self.sep: str = sep if sep is not None else ' '
self.beg_loop: Optional[float] = None
self.validated_segments: List[Segment] = []
self.current_line_tokens: List[ASRToken] = []
self.diarization_buffer: List[ASRToken] = []
self.last_punctuation = None
self.last_uncompleted_punc_segment: PuncSegment = None
self.unvalidated_tokens: PuncSegment = []
def update(self) -> None:
"""Drain state buffers into the running alignment context."""
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
@@ -39,27 +47,27 @@ class TokensAlignment:
self.all_translation_segments.extend(self.new_translation)
self.new_translation_buffer = self.state.new_translation_buffer
def add_translation(self, line: Line) -> None:
"""Append translated text segments that overlap with a line."""
def add_translation(self, segment: Segment) -> None:
"""Append translated text segments that overlap with a segment."""
for ts in self.all_translation_segments:
if ts.is_within(line):
line.translation += ts.text + (self.sep if ts.text else '')
elif line.translation:
if ts.is_within(segment):
segment.translation += ts.text + (self.sep if ts.text else '')
elif segment.translation:
break
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[Segment]:
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[PuncSegment]:
"""Group tokens into segments split by punctuation and explicit silence."""
segments = []
segment_start_idx = 0
for i, token in enumerate(self.all_tokens):
if token.is_silence():
previous_segment = Segment.from_tokens(
previous_segment = PuncSegment.from_tokens(
tokens=self.all_tokens[segment_start_idx: i],
)
if previous_segment:
segments.append(previous_segment)
segment = Segment.from_tokens(
segment = PuncSegment.from_tokens(
tokens=[token],
is_silence=True
)
@@ -67,19 +75,47 @@ class TokensAlignment:
segment_start_idx = i+1
else:
if token.has_punctuation():
segment = Segment.from_tokens(
segment = PuncSegment.from_tokens(
tokens=self.all_tokens[segment_start_idx: i+1],
)
segments.append(segment)
segment_start_idx = i+1
final_segment = Segment.from_tokens(
final_segment = PuncSegment.from_tokens(
tokens=self.all_tokens[segment_start_idx:],
)
if final_segment:
segments.append(final_segment)
return segments
def compute_new_punctuations_segments(self) -> List[PuncSegment]:
new_punc_segments = []
segment_start_idx = 0
self.unvalidated_tokens += self.new_tokens
for i, token in enumerate(self.unvalidated_tokens):
if token.is_silence():
previous_segment = PuncSegment.from_tokens(
tokens=self.unvalidated_tokens[segment_start_idx: i],
)
if previous_segment:
new_punc_segments.append(previous_segment)
segment = PuncSegment.from_tokens(
tokens=[token],
is_silence=True
)
new_punc_segments.append(segment)
segment_start_idx = i+1
else:
if token.has_punctuation():
segment = PuncSegment.from_tokens(
tokens=self.unvalidated_tokens[segment_start_idx: i+1],
)
new_punc_segments.append(segment)
segment_start_idx = i+1
self.unvalidated_tokens = self.unvalidated_tokens[segment_start_idx:]
return new_punc_segments
def concatenate_diar_segments(self) -> List[SpeakerSegment]:
"""Merge consecutive diarization slices that share the same speaker."""
@@ -102,8 +138,8 @@ class TokensAlignment:
return max(0, end - start)
def get_lines_diarization(self) -> Tuple[List[Line], str]:
"""Build lines when diarization is enabled and track overflow buffer."""
def get_lines_diarization(self) -> Tuple[List[Segment], str]:
"""Build segments when diarization is enabled and track overflow buffer."""
diarization_buffer = ''
punctuation_segments = self.compute_punctuations_segments()
diarization_segments = self.concatenate_diar_segments()
@@ -121,18 +157,18 @@ class TokensAlignment:
max_overlap_speaker = diarization_segment.speaker + 1
punctuation_segment.speaker = max_overlap_speaker
lines = []
segments = []
if punctuation_segments:
lines = [Line().build_from_segment(punctuation_segments[0])]
segments = [punctuation_segments[0]]
for segment in punctuation_segments[1:]:
if segment.speaker == lines[-1].speaker:
if lines[-1].text:
lines[-1].text += segment.text
lines[-1].end = segment.end
if segment.speaker == segments[-1].speaker:
if segments[-1].text:
segments[-1].text += segment.text
segments[-1].end = segment.end
else:
lines.append(Line().build_from_segment(segment))
segments.append(segment)
return lines, diarization_buffer
return segments, diarization_buffer
def get_lines(
@@ -140,40 +176,42 @@ class TokensAlignment:
diarization: bool = False,
translation: bool = False,
current_silence: Optional[Silence] = None
) -> Tuple[List[Line], str, Union[str, TimedText]]:
"""Return the formatted lines plus buffers, optionally with diarization/translation."""
) -> Tuple[List[Segment], str, Union[str, TimedText]]:
"""Return the formatted segments plus buffers, optionally with diarization/translation."""
if diarization:
lines, diarization_buffer = self.get_lines_diarization()
segments, diarization_buffer = self.get_lines_diarization()
else:
diarization_buffer = ''
lines = []
current_line_tokens = []
for token in self.all_tokens:
for token in self.new_tokens:
if token.is_silence():
if current_line_tokens:
lines.append(Line().build_from_tokens(current_line_tokens))
current_line_tokens = []
if self.current_line_tokens:
self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
self.current_line_tokens = []
end_silence = token.end if token.has_ended else time() - self.beg_loop
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
if self.validated_segments and self.validated_segments[-1].is_silence():
self.validated_segments[-1].end = end_silence
else:
lines.append(SilentLine(
start = token.start,
end = end_silence
self.validated_segments.append(SilentSegment(
start=token.start,
end=end_silence
))
else:
current_line_tokens.append(token)
if current_line_tokens:
lines.append(Line().build_from_tokens(current_line_tokens))
self.current_line_tokens.append(token)
segments = list(self.validated_segments)
if self.current_line_tokens:
segments.append(Segment().from_tokens(self.current_line_tokens))
if current_silence:
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
if segments and segments[-1].is_silence():
segments[-1] = SilentSegment(start=segments[-1].start, end=end_silence)
else:
lines.append(SilentLine(
start = current_silence.start,
end = end_silence
segments.append(SilentSegment(
start=current_silence.start,
end=end_silence
))
if translation:
[self.add_translation(line) for line in lines if not type(line) == Silence]
return lines, diarization_buffer, self.new_translation_buffer.text
[self.add_translation(segment) for segment in segments if not segment.is_silence()]
return segments, diarization_buffer, self.new_translation_buffer.text