internal rework 4

This commit is contained in:
Quentin Fuxa
2025-11-20 23:45:20 +01:00
parent 270faf2069
commit 8e7aea4fcf
3 changed files with 169 additions and 61 deletions

View File

@@ -1,6 +1,10 @@
from whisperlivekit.timed_objects import Line, SilentLine, format_time, SpeakerSegment, Silence
from whisperlivekit.timed_objects import PunctuationSegment
from time import time
from typing import Optional
from whisperlivekit.timed_objects import Line, SilentLine, ASRToken, SpeakerSegment, Silence
from whisperlivekit.timed_objects import PunctuationSegment
ALIGNMENT_TIME_TOLERANCE = 0.2 # seconds
class TokensAlignment:
@@ -12,15 +16,16 @@ class TokensAlignment:
self._diarization_index = 0
self._translation_index = 0
self.all_tokens = []
self.all_diarization_segments = []
self.all_tokens : list[ASRToken] = []
self.all_diarization_segments: list[SpeakerSegment] = []
self.all_translation_segments = []
self.new_tokens = []
self.new_tokens : list[ASRToken] = []
self.new_diarization: list[SpeakerSegment] = []
self.new_translation = []
self.new_diarization = []
self.new_tokens_buffer = []
self.sep = ' '
self.sep = sep if sep is not None else ' '
self.beg_loop = None
def update(self):
self.new_tokens, self.state.new_tokens = self.state.new_tokens, []
@@ -32,7 +37,10 @@ class TokensAlignment:
self.all_diarization_segments.extend(self.new_diarization)
self.all_translation_segments.extend(self.new_translation)
def create_lines_from_tokens(self, current_silence, beg_loop):
def get_lines(self, current_silence):
"""
In the case without diarization
"""
lines = []
current_line_tokens = []
for token in self.all_tokens:
@@ -40,7 +48,7 @@ class TokensAlignment:
if current_line_tokens:
lines.append(Line().build_from_tokens(current_line_tokens))
current_line_tokens = []
end_silence = token.end if token.has_ended else time() - beg_loop
end_silence = token.end if token.has_ended else time() - self.beg_loop
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
else:
@@ -53,7 +61,7 @@ class TokensAlignment:
if current_line_tokens:
lines.append(Line().build_from_tokens(current_line_tokens))
if current_silence:
end_silence = current_silence.end if current_silence.has_ended else time() - beg_loop
end_silence = current_silence.end if current_silence.has_ended else time() - self.beg_loop
if lines and lines[-1].is_silent():
lines[-1].end = end_silence
else:
@@ -64,22 +72,104 @@ class TokensAlignment:
return lines
def align_tokens(self):
    # NOTE(review): no-op stub — alignment is performed by get_lines()
    # in this revision; only the self.diarization attribute access
    # remains. Confirm whether this method can be removed entirely.
    if not self.diarization:
        pass
    # return self.all_tokens
def compute_punctuations_segments(self):
def _get_asr_tokens(self) -> list[ASRToken]:
    """Return only the ``ASRToken`` entries of ``self.all_tokens``, in order.

    ``all_tokens`` may also carry non-ASR entries (e.g. Silence markers
    appended by the audio processor); those are skipped here.
    """
    return [entry for entry in self.all_tokens if isinstance(entry, ASRToken)]
def _tokens_to_text(self, tokens: list[ASRToken]) -> str:
    """Concatenate the raw ``text`` of *tokens* with no separator."""
    fragments = (token.text for token in tokens)
    return ''.join(fragments)
def _extract_detected_language(self, tokens: list[ASRToken]):
    """Return the first truthy ``detected_language`` among *tokens*, else None."""
    return next(
        (tok.detected_language for tok in tokens
         if getattr(tok, 'detected_language', None)),
        None,
    )
def _speaker_display_id(self, raw_speaker) -> int:
if isinstance(raw_speaker, int):
speaker_index = raw_speaker
else:
digits = ''.join(ch for ch in str(raw_speaker) if ch.isdigit())
speaker_index = int(digits) if digits else 0
return speaker_index + 1 if speaker_index >= 0 else 0
def _line_from_tokens(self, tokens: list[ASRToken], speaker: int) -> Line:
    """Build a ``Line`` from *tokens*, tagging it with *speaker* and,
    when one is found on the tokens, a detected language."""
    new_line = Line().build_from_tokens(tokens)
    new_line.speaker = speaker
    language = self._extract_detected_language(tokens)
    if language:
        new_line.detected_language = language
    return new_line
def _find_initial_diar_index(self, diar_segments: list[SpeakerSegment], start_time: float) -> int:
    """Index of the first diarization segment that may still cover *start_time*.

    A segment qualifies when its end, padded by ALIGNMENT_TIME_TOLERANCE,
    has not yet passed *start_time*.  When every segment ends earlier,
    ``len(diar_segments)`` is returned (i.e. "past the end").
    """
    return next(
        (idx for idx, seg in enumerate(diar_segments)
         if seg.end + ALIGNMENT_TIME_TOLERANCE >= start_time),
        len(diar_segments),
    )
def _find_speaker_for_token(self, token: ASRToken, diar_segments: list[SpeakerSegment], diar_idx: int):
    """Locate the diarization speaker active at *token*'s temporal midpoint.

    *diar_idx* is a monotonically advancing cursor maintained by the
    caller across tokens.  Returns ``(speaker, new_index)`` on a match,
    or ``(None, index)`` when no segment covers the midpoint within
    ALIGNMENT_TIME_TOLERANCE.
    """
    if not diar_segments:
        return None, diar_idx
    # Clamp the cursor in case the caller's index ran past the list end.
    idx = min(diar_idx, len(diar_segments) - 1)
    # Reference point is the token midpoint; fall back to its start
    # when the end timestamp is missing.
    midpoint = (token.start + token.end) / 2 if token.end is not None else token.start
    # Advance past segments that ended (tolerance included) before the midpoint.
    while idx < len(diar_segments) and diar_segments[idx].end + ALIGNMENT_TIME_TOLERANCE < midpoint:
        idx += 1
    # Check the segment at the cursor first, then the previous one —
    # the midpoint may fall in the gap between two segments.
    candidate_indices = []
    if idx < len(diar_segments):
        candidate_indices.append(idx)
    if idx > 0:
        candidate_indices.append(idx - 1)
    for candidate_idx in candidate_indices:
        segment = diar_segments[candidate_idx]
        # Widen the candidate by the tolerance on both sides; a missing
        # boundary timestamp is treated as 0.
        seg_start = (segment.start or 0) - ALIGNMENT_TIME_TOLERANCE
        seg_end = (segment.end or 0) + ALIGNMENT_TIME_TOLERANCE
        if seg_start <= midpoint <= seg_end:
            return segment.speaker, candidate_idx
    return None, idx
def _build_lines_for_tokens(self, tokens: list[ASRToken], diar_segments: list[SpeakerSegment], diar_idx: int):
    """Group *tokens* into ``Line`` objects split on speaker changes.

    Returns ``(lines, new_diar_index)``.  If any token cannot be
    attributed to a speaker, the entire batch is abandoned and
    ``([], diar_idx)`` is returned with the ORIGINAL cursor, so the
    caller can retry once more diarization has arrived.
    """
    if not tokens:
        return [], diar_idx
    segment_lines: list[Line] = []
    current_tokens: list[ASRToken] = []
    current_speaker = None
    pointer = diar_idx
    for token in tokens:
        speaker_raw, pointer = self._find_speaker_for_token(token, diar_segments, pointer)
        if speaker_raw is None:
            # Unattributable token: all-or-nothing bail (see docstring).
            return [], diar_idx
        speaker = self._speaker_display_id(speaker_raw)
        if current_speaker is None or current_speaker != speaker:
            # Speaker changed: flush the running group as one Line.
            if current_tokens:
                segment_lines.append(self._line_from_tokens(current_tokens, current_speaker))
            current_tokens = [token]
            current_speaker = speaker
        else:
            current_tokens.append(token)
    # Flush the trailing group.
    if current_tokens:
        segment_lines.append(self._line_from_tokens(current_tokens, current_speaker))
    return segment_lines, pointer
def compute_punctuations_segments(self, tokens: Optional[list[ASRToken]] = None):
"""Compute segments of text between punctuation marks.
Returns a list of PunctuationSegment objects, each representing
the text from the start (or previous punctuation) to the current punctuation mark.
"""
if not self.all_tokens:
tokens = tokens if tokens is not None else self._get_asr_tokens()
if not tokens:
return []
punctuation_indices = [
i for i, token in enumerate(self.all_tokens)
i for i, token in enumerate[ASRToken](tokens)
if token.is_punctuation()
]
if not punctuation_indices:
@@ -91,7 +181,7 @@ class TokensAlignment:
end_idx = punct_idx
if start_idx <= end_idx:
segment = PunctuationSegment.from_token_range(
tokens=self.all_tokens,
tokens=tokens,
token_index_start=start_idx,
token_index_end=end_idx,
punctuation_token_index=punct_idx
@@ -109,4 +199,42 @@ class TokensAlignment:
merged[-1].end = segment.end
else:
merged.append(segment)
return merged
return merged
def get_lines(self, diarization=False, translation=False):
    """
    Align diarization speaker segments with punctuation-delimited transcription
    segments (see docs/alignement_principles.md).

    Returns ``(lines, buffer_diarization)``: the speaker-attributed
    ``Line`` objects plus the raw text of tokens not yet covered by
    diarization.

    NOTE(review): *diarization* and *translation* are accepted but never
    read in this body, and the AudioProcessor call site unpacks THREE
    values from this call while two are returned here — confirm the
    intended return arity before release.
    """
    tokens = self._get_asr_tokens()
    if not tokens:
        return [], ''
    punctuation_segments = self.compute_punctuations_segments(tokens=tokens)
    diar_segments = self.concatenate_diar_segments()
    # Without both punctuation boundaries and diarization segments,
    # everything stays in the undiarized text buffer.
    if not punctuation_segments or not diar_segments:
        return [], self._tokens_to_text(tokens)
    max_diar_end = diar_segments[-1].end
    if max_diar_end is None:
        return [], self._tokens_to_text(tokens)
    lines: list[Line] = []
    last_consumed_index = -1
    diar_idx = self._find_initial_diar_index(diar_segments, tokens[0].start or 0)
    for segment in punctuation_segments:
        # Only emit sentences that diarization already fully covers.
        if segment.end is None or segment.end > max_diar_end:
            break
        slice_tokens = tokens[segment.token_index_start:segment.token_index_end + 1]
        segment_lines, diar_idx = self._build_lines_for_tokens(slice_tokens, diar_segments, diar_idx)
        if not segment_lines:
            break
        lines.extend(segment_lines)
        last_consumed_index = segment.token_index_end
    # Tokens after the last consumed sentence are still awaiting
    # diarization and are surfaced as plain buffered text.
    buffer_tokens = tokens[last_consumed_index + 1:] if last_consumed_index + 1 < len(tokens) else []
    buffer_diarization = self._tokens_to_text(buffer_tokens)
    return lines, buffer_diarization

View File

@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
SENTINEL = object() # unique sentinel object for end of stream marker
MILENCE_DURATION = 3
MIN_DURATION_REAL_SILENCE = 5
def cut_at(cumulative_pcm, cut_sec):
cumulative_len = 0
@@ -165,7 +165,7 @@ class AudioProcessor:
self.current_silence.is_starting=False
self.current_silence.has_ended=True
self.current_silence.compute_duration()
if self.current_silence.duration > MILENCE_DURATION:
if self.current_silence.duration > MIN_DURATION_REAL_SILENCE:
self.state_light.new_tokens.append(self.current_silence)
await self._push_silence_event()
self.current_silence = None
@@ -410,57 +410,32 @@ class AudioProcessor:
continue
self.tokens_alignment.update()
lines = self.tokens_alignment.create_lines_from_tokens(self.current_silence, self.beg_loop)
undiarized_text = ''
lines, buffer_diarization_text, buffer_translation_text = self.tokens_alignment.get_lines(
diarization=self.args.diarization,
translation=self.args.translation
)
state = await self.get_current_state()
# self.tokens_alignment.compute_punctuations_segments()
# lines, undiarized_text = format_output(
# state,
# self.current_silence,
# args = self.args,
# sep=self.sep
# )
if lines and lines[-1].speaker == -2:
buffer_transcription = Transcript()
else:
buffer_transcription = state.buffer_transcription
buffer_diarization = ''
if undiarized_text:
buffer_diarization = self.sep.join(undiarized_text)
async with self.lock:
self.state.end_attributed_speaker = state.end_attributed_speaker
buffer_translation_text = ''
if state.buffer_translation:
raw_buffer_translation = getattr(state.buffer_translation, 'text', state.buffer_translation)
if raw_buffer_translation:
buffer_translation_text = raw_buffer_translation.strip()
buffer_transcription_text = ''
buffer_diarization_text = ''
response_status = "active_transcription"
if not state.tokens and not buffer_transcription and not buffer_diarization:
if not lines and not buffer_transcription_text and not buffer_diarization_text:
response_status = "no_audio_detected"
lines = []
elif not lines:
lines = [Line(
speaker=1,
start=state.end_buffer,
end=state.end_buffer
)]
response = FrontData(
status=response_status,
lines=lines,
buffer_transcription=buffer_transcription.text.strip(),
buffer_diarization=buffer_diarization,
buffer_transcription=buffer_transcription_text,
buffer_diarization=buffer_diarization_text,
buffer_translation=buffer_translation_text,
remaining_time_transcription=state.remaining_time_transcription,
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
)
should_push = (response != self.last_response_content)
if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
if should_push:
yield response
self.last_response_content = response
@@ -582,6 +557,7 @@ class AudioProcessor:
if not self.beg_loop:
self.beg_loop = time()
self.current_silence = Silence(start=0.0, is_starting=True)
self.tokens_alignment.beg_loop = self.beg_loop
if not message:
logger.info("Empty audio message received, initiating stop sequence.")

View File

@@ -162,8 +162,10 @@ class Line(TimedText):
return self.speaker == -2
class SilentLine(Line):
speaker = -2
text = ''
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.speaker = -2
self.text = ''
@dataclass
@@ -192,8 +194,10 @@ class FrontData():
return _dict
@dataclass
class PunctuationSegment(TimedText):
class PunctuationSegment():
"""Represents a segment of text between punctuation marks."""
start: Optional[float]
end: Optional[float]
token_index_start: int
token_index_end: int
punctuation_token_index: int