mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-04-26 08:06:15 +00:00
Lines to Segments. Merging dataclasses
This commit is contained in:
@@ -12,7 +12,7 @@ from whisperlivekit.core import (TranscriptionEngine,
|
||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||
from whisperlivekit.timed_objects import (ASRToken, ChangeSpeaker, FrontData,
|
||||
Line, Silence, State, Transcript)
|
||||
Segment, Silence, State, Transcript)
|
||||
from whisperlivekit.tokens_alignment import TokensAlignment
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
@@ -114,6 +114,9 @@ class Segment(TimedText):
|
||||
end: Optional[float]
|
||||
text: Optional[str]
|
||||
speaker: Optional[str]
|
||||
tokens: Optional[ASRToken] = None
|
||||
translation: Optional[Translation] = None
|
||||
|
||||
@classmethod
|
||||
def from_tokens(
|
||||
cls,
|
||||
@@ -141,17 +144,13 @@ class Segment(TimedText):
|
||||
speaker=-1,
|
||||
detected_language=start_token.detected_language
|
||||
)
|
||||
|
||||
def is_silence(self) -> bool:
|
||||
"""True when this segment represents a silence gap."""
|
||||
return self.speaker == -2
|
||||
|
||||
|
||||
@dataclass
|
||||
class Line(TimedText):
|
||||
translation: str = ''
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize the line for frontend consumption."""
|
||||
"""Serialize the segment for frontend consumption."""
|
||||
_dict: Dict[str, Any] = {
|
||||
'speaker': int(self.speaker) if self.speaker != -1 else 1,
|
||||
'text': self.text,
|
||||
@@ -163,29 +162,13 @@ class Line(TimedText):
|
||||
if self.detected_language:
|
||||
_dict['detected_language'] = self.detected_language
|
||||
return _dict
|
||||
|
||||
def build_from_tokens(self, tokens: List[ASRToken]) -> "Line":
|
||||
"""Populate line attributes from a contiguous token list."""
|
||||
self.text = ''.join([token.text for token in tokens])
|
||||
self.start = tokens[0].start
|
||||
self.end = tokens[-1].end
|
||||
self.speaker = 1
|
||||
self.detected_language = tokens[0].detected_language
|
||||
return self
|
||||
|
||||
def build_from_segment(self, segment: Segment) -> "Line":
|
||||
"""Populate the line fields from a pre-built segment."""
|
||||
self.text = segment.text
|
||||
self.start = segment.start
|
||||
self.end = segment.end
|
||||
self.speaker = segment.speaker
|
||||
self.detected_language = segment.detected_language
|
||||
return self
|
||||
|
||||
def is_silent(self) -> bool:
|
||||
return self.speaker == -2
|
||||
@dataclass
|
||||
class PuncSegment(Segment):
|
||||
pass
|
||||
|
||||
class SilentLine(Line):
|
||||
class SilentSegment(Segment):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.speaker = -2
|
||||
@@ -196,7 +179,7 @@ class SilentLine(Line):
|
||||
class FrontData():
|
||||
status: str = ''
|
||||
error: str = ''
|
||||
lines: list[Line] = field(default_factory=list)
|
||||
lines: list[Segment] = field(default_factory=list)
|
||||
buffer_transcription: str = ''
|
||||
buffer_diarization: str = ''
|
||||
buffer_translation: str = ''
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from time import time
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
from whisperlivekit.timed_objects import (ASRToken, Line, Segment, Silence,
|
||||
SilentLine, SpeakerSegment,
|
||||
from whisperlivekit.timed_objects import (ASRToken, Segment, PuncSegment, Silence,
|
||||
SilentSegment, SpeakerSegment,
|
||||
TimedText)
|
||||
|
||||
|
||||
@@ -27,6 +27,14 @@ class TokensAlignment:
|
||||
self.sep: str = sep if sep is not None else ' '
|
||||
self.beg_loop: Optional[float] = None
|
||||
|
||||
self.validated_segments: List[Segment] = []
|
||||
self.current_line_tokens: List[ASRToken] = []
|
||||
self.diarization_buffer: List[ASRToken] = []
|
||||
|
||||
self.last_punctuation = None
|
||||
self.last_uncompleted_punc_segment: PuncSegment = None
|
||||
self.unvalidated_tokens: PuncSegment = []
|
||||
|
||||
def update(self) -> None:
|
||||
"""Drain state buffers into the running alignment context."""
|
||||
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
|
||||
@@ -39,27 +47,27 @@ class TokensAlignment:
|
||||
self.all_translation_segments.extend(self.new_translation)
|
||||
self.new_translation_buffer = self.state.new_translation_buffer
|
||||
|
||||
def add_translation(self, line: Line) -> None:
|
||||
"""Append translated text segments that overlap with a line."""
|
||||
def add_translation(self, segment: Segment) -> None:
|
||||
"""Append translated text segments that overlap with a segment."""
|
||||
for ts in self.all_translation_segments:
|
||||
if ts.is_within(line):
|
||||
line.translation += ts.text + (self.sep if ts.text else '')
|
||||
elif line.translation:
|
||||
if ts.is_within(segment):
|
||||
segment.translation += ts.text + (self.sep if ts.text else '')
|
||||
elif segment.translation:
|
||||
break
|
||||
|
||||
|
||||
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[Segment]:
|
||||
def compute_punctuations_segments(self, tokens: Optional[List[ASRToken]] = None) -> List[PuncSegment]:
|
||||
"""Group tokens into segments split by punctuation and explicit silence."""
|
||||
segments = []
|
||||
segment_start_idx = 0
|
||||
for i, token in enumerate(self.all_tokens):
|
||||
if token.is_silence():
|
||||
previous_segment = Segment.from_tokens(
|
||||
previous_segment = PuncSegment.from_tokens(
|
||||
tokens=self.all_tokens[segment_start_idx: i],
|
||||
)
|
||||
if previous_segment:
|
||||
segments.append(previous_segment)
|
||||
segment = Segment.from_tokens(
|
||||
segment = PuncSegment.from_tokens(
|
||||
tokens=[token],
|
||||
is_silence=True
|
||||
)
|
||||
@@ -67,19 +75,47 @@ class TokensAlignment:
|
||||
segment_start_idx = i+1
|
||||
else:
|
||||
if token.has_punctuation():
|
||||
segment = Segment.from_tokens(
|
||||
segment = PuncSegment.from_tokens(
|
||||
tokens=self.all_tokens[segment_start_idx: i+1],
|
||||
)
|
||||
segments.append(segment)
|
||||
segment_start_idx = i+1
|
||||
|
||||
final_segment = Segment.from_tokens(
|
||||
final_segment = PuncSegment.from_tokens(
|
||||
tokens=self.all_tokens[segment_start_idx:],
|
||||
)
|
||||
if final_segment:
|
||||
segments.append(final_segment)
|
||||
return segments
|
||||
|
||||
def compute_new_punctuations_segments(self) -> List[PuncSegment]:
|
||||
new_punc_segments = []
|
||||
segment_start_idx = 0
|
||||
self.unvalidated_tokens += self.new_tokens
|
||||
for i, token in enumerate(self.unvalidated_tokens):
|
||||
if token.is_silence():
|
||||
previous_segment = PuncSegment.from_tokens(
|
||||
tokens=self.unvalidated_tokens[segment_start_idx: i],
|
||||
)
|
||||
if previous_segment:
|
||||
new_punc_segments.append(previous_segment)
|
||||
segment = PuncSegment.from_tokens(
|
||||
tokens=[token],
|
||||
is_silence=True
|
||||
)
|
||||
new_punc_segments.append(segment)
|
||||
segment_start_idx = i+1
|
||||
else:
|
||||
if token.has_punctuation():
|
||||
segment = PuncSegment.from_tokens(
|
||||
tokens=self.unvalidated_tokens[segment_start_idx: i+1],
|
||||
)
|
||||
new_punc_segments.append(segment)
|
||||
segment_start_idx = i+1
|
||||
|
||||
self.unvalidated_tokens = self.unvalidated_tokens[segment_start_idx:]
|
||||
return new_punc_segments
|
||||
|
||||
|
||||
def concatenate_diar_segments(self) -> List[SpeakerSegment]:
|
||||
"""Merge consecutive diarization slices that share the same speaker."""
|
||||
@@ -102,8 +138,8 @@ class TokensAlignment:
|
||||
|
||||
return max(0, end - start)
|
||||
|
||||
def get_lines_diarization(self) -> Tuple[List[Line], str]:
|
||||
"""Build lines when diarization is enabled and track overflow buffer."""
|
||||
def get_lines_diarization(self) -> Tuple[List[Segment], str]:
|
||||
"""Build segments when diarization is enabled and track overflow buffer."""
|
||||
diarization_buffer = ''
|
||||
punctuation_segments = self.compute_punctuations_segments()
|
||||
diarization_segments = self.concatenate_diar_segments()
|
||||
@@ -121,18 +157,18 @@ class TokensAlignment:
|
||||
max_overlap_speaker = diarization_segment.speaker + 1
|
||||
punctuation_segment.speaker = max_overlap_speaker
|
||||
|
||||
lines = []
|
||||
segments = []
|
||||
if punctuation_segments:
|
||||
lines = [Line().build_from_segment(punctuation_segments[0])]
|
||||
segments = [punctuation_segments[0]]
|
||||
for segment in punctuation_segments[1:]:
|
||||
if segment.speaker == lines[-1].speaker:
|
||||
if lines[-1].text:
|
||||
lines[-1].text += segment.text
|
||||
lines[-1].end = segment.end
|
||||
if segment.speaker == segments[-1].speaker:
|
||||
if segments[-1].text:
|
||||
segments[-1].text += segment.text
|
||||
segments[-1].end = segment.end
|
||||
else:
|
||||
lines.append(Line().build_from_segment(segment))
|
||||
segments.append(segment)
|
||||
|
||||
return lines, diarization_buffer
|
||||
return segments, diarization_buffer
|
||||
|
||||
|
||||
def get_lines(
|
||||
@@ -140,40 +176,42 @@ class TokensAlignment:
|
||||
diarization: bool = False,
|
||||
translation: bool = False,
|
||||
current_silence: Optional[Silence] = None
|
||||
) -> Tuple[List[Line], str, Union[str, TimedText]]:
|
||||
"""Return the formatted lines plus buffers, optionally with diarization/translation."""
|
||||
) -> Tuple[List[Segment], str, Union[str, TimedText]]:
|
||||
"""Return the formatted segments plus buffers, optionally with diarization/translation."""
|
||||
if diarization:
|
||||
lines, diarization_buffer = self.get_lines_diarization()
|
||||
segments, diarization_buffer = self.get_lines_diarization()
|
||||
else:
|
||||
diarization_buffer = ''
|
||||
lines = []
|
||||
current_line_tokens = []
|
||||
for token in self.all_tokens:
|
||||
for token in self.new_tokens:
|
||||
if token.is_silence():
|
||||
if current_line_tokens:
|
||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
||||
current_line_tokens = []
|
||||
if self.current_line_tokens:
|
||||
self.validated_segments.append(Segment().from_tokens(self.current_line_tokens))
|
||||
self.current_line_tokens = []
|
||||
|
||||
end_silence = token.end if token.has_ended else time() - self.beg_loop
|
||||
if lines and lines[-1].is_silent():
|
||||
lines[-1].end = end_silence
|
||||
if self.validated_segments and self.validated_segments[-1].is_silence():
|
||||
self.validated_segments[-1].end = end_silence
|
||||
else:
|
||||
lines.append(SilentLine(
|
||||
start = token.start,
|
||||
end = end_silence
|
||||
self.validated_segments.append(SilentSegment(
|
||||
start=token.start,
|
||||
end=end_silence
|
||||
))
|
||||
else:
|
||||
current_line_tokens.append(token)
|
||||
if current_line_tokens:
|
||||
lines.append(Line().build_from_tokens(current_line_tokens))
|
||||
self.current_line_tokens.append(token)
|
||||
|
||||
segments = list(self.validated_segments)
|
||||
if self.current_line_tokens:
|
||||
segments.append(Segment().from_tokens(self.current_line_tokens))
|
||||
|
||||
if current_silence:
|
||||
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
|
||||
if lines and lines[-1].is_silent():
|
||||
lines[-1].end = end_silence
|
||||
if segments and segments[-1].is_silence():
|
||||
segments[-1] = SilentSegment(start=segments[-1].start, end=end_silence)
|
||||
else:
|
||||
lines.append(SilentLine(
|
||||
start = current_silence.start,
|
||||
end = end_silence
|
||||
segments.append(SilentSegment(
|
||||
start=current_silence.start,
|
||||
end=end_silence
|
||||
))
|
||||
if translation:
|
||||
[self.add_translation(line) for line in lines if not type(line) == Silence]
|
||||
return lines, diarization_buffer, self.new_translation_buffer.text
|
||||
[self.add_translation(segment) for segment in segments if not segment.is_silence()]
|
||||
return segments, diarization_buffer, self.new_translation_buffer.text
|
||||
|
||||
Reference in New Issue
Block a user